singletongue commited on
Commit
913e5a4
·
verified ·
1 Parent(s): 3f97903

Use ja-v0.3.2 model, add submit button

Browse files
Files changed (1) hide show
  1. app.py +45 -48
app.py CHANGED
@@ -21,7 +21,7 @@ MAX_TEXT_FILE_LINES = 100
21
  MAX_ENTITY_FILE_LINES = 1000
22
 
23
  repo_id = "studio-ousia/luxe"
24
- revision = "ja-v0.3.1"
25
 
26
  nayose_repo_id = "studio-ousia/luxe-nayose-bm25"
27
 
@@ -218,7 +218,7 @@ def get_topk_entities_from_texts(
218
  else:
219
  topk_span_entities.append([])
220
 
221
- return batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
222
 
223
 
224
  def get_selected_entity(evt: gr.SelectData):
@@ -391,7 +391,8 @@ with gr.Blocks() as demo:
391
  # cf. https://www.gradio.app/docs/gradio/state#param-state-value
392
  models = gr.State((model, tokenizer, bm25_tokenizer, bm25_retriever))
393
 
394
- texts = gr.State([])
 
395
 
396
  entity_replaced_counts = gr.State(0)
397
 
@@ -414,12 +415,37 @@ with gr.Blocks() as demo:
414
  with gr.Tab(label="直接入力"):
415
  text_input = gr.Textbox(label=f"入力テキスト(最大{MAX_TEXT_LENGTH}文字)", max_length=MAX_TEXT_LENGTH)
416
  with gr.Tab(label="ファイルアップロード"):
417
- texts_file = gr.File(label=f"入力テキストファイル(最大{MAX_TEXT_FILE_LINES}行)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
 
419
  with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
420
- new_entity_text_pairs_file = gr.File(
421
- label=f"エンティティと説明文のCSVファイル(最大{MAX_ENTITY_FILE_LINES}行)"
 
 
422
  )
 
 
 
 
 
423
  new_entity_text_pairs_input = gr.Dataframe(
424
  # value=sample_new_entity_text_pairs,
425
  headers=["entity", "text"],
@@ -429,6 +455,7 @@ with gr.Blocks() as demo:
429
  interactive=True,
430
  )
431
  replace_entity_button = gr.Button(value="エンティティ語彙を置き換える")
 
432
 
433
  new_entity_text_pairs_file.change(
434
  fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
@@ -439,59 +466,29 @@ with gr.Blocks() as demo:
439
  outputs=entity_replaced_counts,
440
  )
441
 
442
- with gr.Accordion(label="ハイパーパラメータ", open=False):
443
- topk_input = gr.Number(5, label="エンティティ件数", interactive=True)
444
- entity_span_sensitivity_input = gr.Slider(
445
- minimum=0.0, maximum=5.0, value=1.0, step=0.1, label="エンティティ検出の積極度", interactive=True
446
- )
447
- nayose_coef_input = gr.Slider(
448
- minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="文字列一致の優先度", interactive=True
449
- )
450
-
451
- text_input.change(fn=lambda text: [normalize_text(text)], inputs=text_input, outputs=texts)
452
- texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
453
- topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)
454
- entity_span_sensitivity_input.change(
455
- fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
456
- )
457
- nayose_coef_input.change(fn=lambda val: val, inputs=nayose_coef_input, outputs=nayose_coef)
458
-
459
- texts.change(
460
- fn=get_topk_entities_from_texts,
461
- inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
462
- outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
463
- )
464
- topk.change(
465
- fn=get_topk_entities_from_texts,
466
- inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
467
- outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
468
- )
469
- entity_span_sensitivity.change(
470
  fn=get_topk_entities_from_texts,
471
- inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
472
- outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
473
  )
474
- nayose_coef.change(
475
  fn=get_topk_entities_from_texts,
476
- inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
477
- outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
478
- )
479
-
480
- entity_replaced_counts.change(
481
- fn=get_topk_entities_from_texts,
482
- inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
483
- outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
484
  )
485
 
486
  gr.Markdown("---")
487
  gr.Markdown("## 出力エンティティ")
488
 
489
- @gr.render(inputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities])
 
 
490
  def render_topk_entities(
491
- texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
492
  ):
493
  for text, entity_spans, normal_entities, category_entities, span_entities in zip(
494
- texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
495
  ):
496
  highlighted_text_value = []
497
  cur = 0
 
21
  MAX_ENTITY_FILE_LINES = 1000
22
 
23
  repo_id = "studio-ousia/luxe"
24
+ revision = "ja-v0.3.2"
25
 
26
  nayose_repo_id = "studio-ousia/luxe-nayose-bm25"
27
 
 
218
  else:
219
  topk_span_entities.append([])
220
 
221
+ return texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
222
 
223
 
224
  def get_selected_entity(evt: gr.SelectData):
 
391
  # cf. https://www.gradio.app/docs/gradio/state#param-state-value
392
  models = gr.State((model, tokenizer, bm25_tokenizer, bm25_retriever))
393
 
394
+ input_texts = gr.State([])
395
+ output_texts = gr.State([])
396
 
397
  entity_replaced_counts = gr.State(0)
398
 
 
415
  with gr.Tab(label="直接入力"):
416
  text_input = gr.Textbox(label=f"入力テキスト(最大{MAX_TEXT_LENGTH}文字)", max_length=MAX_TEXT_LENGTH)
417
  with gr.Tab(label="ファイルアップロード"):
418
+ gr.Markdown(f"1行1事例のテキストファイル(最大{MAX_TEXT_FILE_LINES}行)をアップロードできます。")
419
+ texts_file = gr.File(label="入力テキストファイル")
420
+
421
+ with gr.Accordion(label="ハイパーパラメータ", open=False):
422
+ topk_input = gr.Number(5, label="エンティティ件数", interactive=True)
423
+ entity_span_sensitivity_input = gr.Slider(
424
+ minimum=0.0, maximum=5.0, value=1.0, step=0.1, label="エンティティ検出の積極度", interactive=True
425
+ )
426
+ nayose_coef_input = gr.Slider(
427
+ minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="文字列一致の優先度", interactive=True
428
+ )
429
+
430
+ text_input.change(fn=lambda text: [normalize_text(text)], inputs=text_input, outputs=input_texts)
431
+ texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=input_texts)
432
+ topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)
433
+ entity_span_sensitivity_input.change(
434
+ fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
435
+ )
436
+ nayose_coef_input.change(fn=lambda val: val, inputs=nayose_coef_input, outputs=nayose_coef)
437
 
438
  with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
439
+ gr.Markdown(
440
+ """LUXEのモデルのエンティティの語彙を任意のエンティティ集合に置き換えます。
441
+ エンティティと共に与えられるエンティティの説明文から、エンティティの埋め込みが計算されます。""",
442
+ line_breaks=True,
443
  )
444
+ gr.Markdown(
445
+ f"「エンティティ」と「エンティティの説明文」の2列からなるCSVファイル(最大{MAX_ENTITY_FILE_LINES}行)をアップロードできます。"
446
+ )
447
+ new_entity_text_pairs_file = gr.File(label="エンティティと説明文のCSVファイル", height="128px")
448
+ gr.Markdown("CSVファイルから読み込まれた項目が以下の表に表示されます。表の内容を直接編集することも可能です。")
449
  new_entity_text_pairs_input = gr.Dataframe(
450
  # value=sample_new_entity_text_pairs,
451
  headers=["entity", "text"],
 
455
  interactive=True,
456
  )
457
  replace_entity_button = gr.Button(value="エンティティ語彙を置き換える")
458
+ gr.Markdown("LUXEのモデルのエンティティ語彙は、デモページの再読み込み時にリセットされます。")
459
 
460
  new_entity_text_pairs_file.change(
461
  fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
 
466
  outputs=entity_replaced_counts,
467
  )
468
 
469
+ submit_button = gr.Button(value="予測実行", variant="huggingface")
470
+ submit_button.click(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  fn=get_topk_entities_from_texts,
472
+ inputs=[models, input_texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
473
+ outputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
474
  )
475
+ text_input.submit(
476
  fn=get_topk_entities_from_texts,
477
+ inputs=[models, input_texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
478
+ outputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
 
 
 
 
 
 
479
  )
480
 
481
  gr.Markdown("---")
482
  gr.Markdown("## 出力エンティティ")
483
 
484
+ @gr.render(
485
+ inputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities]
486
+ )
487
  def render_topk_entities(
488
+ output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
489
  ):
490
  for text, entity_spans, normal_entities, category_entities, span_entities in zip(
491
+ output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
492
  ):
493
  highlighted_text_value = []
494
  cur = 0