Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Use ja-v0.3.2 model, add submit button
Browse files
app.py
CHANGED
@@ -21,7 +21,7 @@ MAX_TEXT_FILE_LINES = 100
|
|
21 |
MAX_ENTITY_FILE_LINES = 1000
|
22 |
|
23 |
repo_id = "studio-ousia/luxe"
|
24 |
-
revision = "ja-v0.3.
|
25 |
|
26 |
nayose_repo_id = "studio-ousia/luxe-nayose-bm25"
|
27 |
|
@@ -218,7 +218,7 @@ def get_topk_entities_from_texts(
|
|
218 |
else:
|
219 |
topk_span_entities.append([])
|
220 |
|
221 |
-
return batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
|
222 |
|
223 |
|
224 |
def get_selected_entity(evt: gr.SelectData):
|
@@ -391,7 +391,8 @@ with gr.Blocks() as demo:
|
|
391 |
# cf. https://www.gradio.app/docs/gradio/state#param-state-value
|
392 |
models = gr.State((model, tokenizer, bm25_tokenizer, bm25_retriever))
|
393 |
|
394 |
-
|
|
|
395 |
|
396 |
entity_replaced_counts = gr.State(0)
|
397 |
|
@@ -414,12 +415,37 @@ with gr.Blocks() as demo:
|
|
414 |
with gr.Tab(label="直接入力"):
|
415 |
text_input = gr.Textbox(label=f"入力テキスト(最大{MAX_TEXT_LENGTH}文字)", max_length=MAX_TEXT_LENGTH)
|
416 |
with gr.Tab(label="ファイルアップロード"):
|
417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
418 |
|
419 |
with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
|
420 |
-
|
421 |
-
|
|
|
|
|
422 |
)
|
|
|
|
|
|
|
|
|
|
|
423 |
new_entity_text_pairs_input = gr.Dataframe(
|
424 |
# value=sample_new_entity_text_pairs,
|
425 |
headers=["entity", "text"],
|
@@ -429,6 +455,7 @@ with gr.Blocks() as demo:
|
|
429 |
interactive=True,
|
430 |
)
|
431 |
replace_entity_button = gr.Button(value="エンティティ語彙を置き換える")
|
|
|
432 |
|
433 |
new_entity_text_pairs_file.change(
|
434 |
fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
|
@@ -439,59 +466,29 @@ with gr.Blocks() as demo:
|
|
439 |
outputs=entity_replaced_counts,
|
440 |
)
|
441 |
|
442 |
-
|
443 |
-
|
444 |
-
entity_span_sensitivity_input = gr.Slider(
|
445 |
-
minimum=0.0, maximum=5.0, value=1.0, step=0.1, label="エンティティ検出の積極度", interactive=True
|
446 |
-
)
|
447 |
-
nayose_coef_input = gr.Slider(
|
448 |
-
minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="文字列一致の優先度", interactive=True
|
449 |
-
)
|
450 |
-
|
451 |
-
text_input.change(fn=lambda text: [normalize_text(text)], inputs=text_input, outputs=texts)
|
452 |
-
texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
|
453 |
-
topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)
|
454 |
-
entity_span_sensitivity_input.change(
|
455 |
-
fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
|
456 |
-
)
|
457 |
-
nayose_coef_input.change(fn=lambda val: val, inputs=nayose_coef_input, outputs=nayose_coef)
|
458 |
-
|
459 |
-
texts.change(
|
460 |
-
fn=get_topk_entities_from_texts,
|
461 |
-
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
|
462 |
-
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
463 |
-
)
|
464 |
-
topk.change(
|
465 |
-
fn=get_topk_entities_from_texts,
|
466 |
-
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
|
467 |
-
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
468 |
-
)
|
469 |
-
entity_span_sensitivity.change(
|
470 |
fn=get_topk_entities_from_texts,
|
471 |
-
inputs=[models,
|
472 |
-
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
473 |
)
|
474 |
-
|
475 |
fn=get_topk_entities_from_texts,
|
476 |
-
inputs=[models,
|
477 |
-
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
478 |
-
)
|
479 |
-
|
480 |
-
entity_replaced_counts.change(
|
481 |
-
fn=get_topk_entities_from_texts,
|
482 |
-
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
|
483 |
-
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
484 |
)
|
485 |
|
486 |
gr.Markdown("---")
|
487 |
gr.Markdown("## 出力エンティティ")
|
488 |
|
489 |
-
@gr.render(
|
|
|
|
|
490 |
def render_topk_entities(
|
491 |
-
|
492 |
):
|
493 |
for text, entity_spans, normal_entities, category_entities, span_entities in zip(
|
494 |
-
|
495 |
):
|
496 |
highlighted_text_value = []
|
497 |
cur = 0
|
|
|
21 |
MAX_ENTITY_FILE_LINES = 1000
|
22 |
|
23 |
repo_id = "studio-ousia/luxe"
|
24 |
+
revision = "ja-v0.3.2"
|
25 |
|
26 |
nayose_repo_id = "studio-ousia/luxe-nayose-bm25"
|
27 |
|
|
|
218 |
else:
|
219 |
topk_span_entities.append([])
|
220 |
|
221 |
+
return texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
|
222 |
|
223 |
|
224 |
def get_selected_entity(evt: gr.SelectData):
|
|
|
391 |
# cf. https://www.gradio.app/docs/gradio/state#param-state-value
|
392 |
models = gr.State((model, tokenizer, bm25_tokenizer, bm25_retriever))
|
393 |
|
394 |
+
input_texts = gr.State([])
|
395 |
+
output_texts = gr.State([])
|
396 |
|
397 |
entity_replaced_counts = gr.State(0)
|
398 |
|
|
|
415 |
with gr.Tab(label="直接入力"):
|
416 |
text_input = gr.Textbox(label=f"入力テキスト(最大{MAX_TEXT_LENGTH}文字)", max_length=MAX_TEXT_LENGTH)
|
417 |
with gr.Tab(label="ファイルアップロード"):
|
418 |
+
gr.Markdown(f"1行1事例のテキストファイル(最大{MAX_TEXT_FILE_LINES}行)をアップロードできます。")
|
419 |
+
texts_file = gr.File(label="入力テキストファイル")
|
420 |
+
|
421 |
+
with gr.Accordion(label="ハイパーパラメータ", open=False):
|
422 |
+
topk_input = gr.Number(5, label="エンティティ件数", interactive=True)
|
423 |
+
entity_span_sensitivity_input = gr.Slider(
|
424 |
+
minimum=0.0, maximum=5.0, value=1.0, step=0.1, label="エンティティ検出の積極度", interactive=True
|
425 |
+
)
|
426 |
+
nayose_coef_input = gr.Slider(
|
427 |
+
minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="文字列一致の優先度", interactive=True
|
428 |
+
)
|
429 |
+
|
430 |
+
text_input.change(fn=lambda text: [normalize_text(text)], inputs=text_input, outputs=input_texts)
|
431 |
+
texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=input_texts)
|
432 |
+
topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)
|
433 |
+
entity_span_sensitivity_input.change(
|
434 |
+
fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
|
435 |
+
)
|
436 |
+
nayose_coef_input.change(fn=lambda val: val, inputs=nayose_coef_input, outputs=nayose_coef)
|
437 |
|
438 |
with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
|
439 |
+
gr.Markdown(
|
440 |
+
"""LUXEのモデルのエンティティの語彙を任意のエンティティ集合に置き換えます。
|
441 |
+
エンティティと共に与えられるエンティティの説明文から、エンティティの埋め込みが計算されます。""",
|
442 |
+
line_breaks=True,
|
443 |
)
|
444 |
+
gr.Markdown(
|
445 |
+
f"「エンティティ」と「エンティティの説明文」の2列からなるCSVファイル(最大{MAX_ENTITY_FILE_LINES}行)をアップロードできます。"
|
446 |
+
)
|
447 |
+
new_entity_text_pairs_file = gr.File(label="エンティティと説明文のCSVファイル", height="128px")
|
448 |
+
gr.Markdown("CSVファイルから読み込まれた項目が以下の表に表示されます。表の内容を直接編集することも可能です。")
|
449 |
new_entity_text_pairs_input = gr.Dataframe(
|
450 |
# value=sample_new_entity_text_pairs,
|
451 |
headers=["entity", "text"],
|
|
|
455 |
interactive=True,
|
456 |
)
|
457 |
replace_entity_button = gr.Button(value="エンティティ語彙を置き換える")
|
458 |
+
gr.Markdown("LUXEのモデルのエンティティ語彙は、デモページの再読み込み時にリセットされます。")
|
459 |
|
460 |
new_entity_text_pairs_file.change(
|
461 |
fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
|
|
|
466 |
outputs=entity_replaced_counts,
|
467 |
)
|
468 |
|
469 |
+
submit_button = gr.Button(value="予測実行", variant="huggingface")
|
470 |
+
submit_button.click(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
471 |
fn=get_topk_entities_from_texts,
|
472 |
+
inputs=[models, input_texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
|
473 |
+
outputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
474 |
)
|
475 |
+
text_input.submit(
|
476 |
fn=get_topk_entities_from_texts,
|
477 |
+
inputs=[models, input_texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
|
478 |
+
outputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
|
|
|
|
|
|
|
|
|
|
|
|
479 |
)
|
480 |
|
481 |
gr.Markdown("---")
|
482 |
gr.Markdown("## 出力エンティティ")
|
483 |
|
484 |
+
@gr.render(
|
485 |
+
inputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities]
|
486 |
+
)
|
487 |
def render_topk_entities(
|
488 |
+
output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
|
489 |
):
|
490 |
for text, entity_spans, normal_entities, category_entities, span_entities in zip(
|
491 |
+
output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
|
492 |
):
|
493 |
highlighted_text_value = []
|
494 |
cur = 0
|