singletongue committed
Commit c1e0e01 · verified · 1 Parent(s): d6ab44d

Reset model and tokenizer when the demo is reloaded

Files changed (1): app.py (+37 −18)
app.py CHANGED
@@ -33,9 +33,6 @@ ignore_category_patterns = [
     r"各年の",
 ]
 
-model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
-tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
-
 
 class MecabTokenizer:
     def __init__(self):
@@ -65,11 +62,6 @@ def normalize_text(text: str) -> str:
     return unicodedata.normalize("NFKC", text)
 
 
-bm25_tokenizer = TokenizerHF(lower=True, splitter=tokenizer.tokenize, stopwords=None, stemmer=None)
-bm25_tokenizer.load_vocab_from_hub("studio-ousia/luxe-nayose-bm25")
-bm25_retriever = BM25HF.load_from_hub("studio-ousia/luxe-nayose-bm25")
-
-
 def get_texts_from_file(file_path: str | None):
     texts = []
     if file_path is not None:
@@ -104,7 +96,7 @@ def get_noun_spans_from_text(text: str) -> list[tuple[int, int]]:
     return noun_spans
 
 
-def get_token_spans(text: str) -> list[tuple[int, int]]:
+def get_token_spans(tokenizer, text: str) -> list[tuple[int, int]]:
     token_spans = []
     end = 0
     for token in tokenizer.tokenize(text):
@@ -147,12 +139,15 @@ def get_predicted_entity_spans(
 
 
 def get_topk_entities_from_texts(
+    models,
     texts: list[str],
     k: int = 5,
     entity_span_sensitivity: float = 1.0,
     nayose_coef: float = 1.0,
     entities_are_replaced: bool = False,
 ) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
+    model, tokenizer, bm25_tokenizer, bm25_retriever = models
+
     batch_entity_spans: list[list[tuple[int, int]]] = []
     topk_normal_entities: list[list[str]] = []
     topk_category_entities: list[list[str]] = []
@@ -178,7 +173,7 @@ def get_topk_entities_from_texts(
     for text in texts:
         tokenized_examples = tokenizer(text, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
-        token_spans = get_token_spans(text)
+        token_spans = get_token_spans(tokenizer, text)
         entity_spans = get_predicted_entity_spans(model_outputs.ner_logits[0], token_spans, entity_span_sensitivity)
         batch_entity_spans.append(entity_spans)
 
@@ -223,7 +218,9 @@ def get_selected_entity(evt: gr.SelectData):
     return evt.value[0]
 
 
-def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
+def get_similar_entities(models, query_entity: str, k: int = 10) -> list[str]:
+    model, tokenizer, _, _ = models
+
     query_entity_id = tokenizer.entity_vocab[query_entity]
 
     id2normal_entity = {
@@ -282,11 +279,14 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]]:
 
 
 def replace_entities(
+    models,
     new_entity_text_pairs: list[tuple[str, str]],
     new_num_category_entities: int = 0,
     new_entity_counts: list[int] | None = None,
     new_padding_idx: int = 0,
 ) -> True:
+    model, tokenizer, bm25_tokenizer, bm25_retriever = models
+
     gr.Info("トークナイザのエンティティの語彙を置き換えています...", duration=5)
     new_entity_tokens = ENTITY_SPECIAL_TOKENS + [entity for entity, _ in new_entity_text_pairs]
 
@@ -367,6 +367,18 @@ def replace_entities(
 
 
 with gr.Blocks() as demo:
+    model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
+    bm25_tokenizer = TokenizerHF(lower=True, splitter=tokenizer.tokenize, stopwords=None, stemmer=None)
+    bm25_tokenizer.load_vocab_from_hub("studio-ousia/luxe-nayose-bm25")
+    bm25_retriever = BM25HF.load_from_hub("studio-ousia/luxe-nayose-bm25")
+
+    # Hint: a callable passed to gr.State is treated as a function for computing the
+    # state's initial value and is called with no arguments, so passing model or
+    # tokenizer by itself raises an error. Passing the models as a (non-callable)
+    # tuple avoids this. cf. https://www.gradio.app/docs/gradio/state#param-state-value
+    models = gr.State((model, tokenizer, bm25_tokenizer, bm25_retriever))
+
     texts = gr.State([])
 
     entities_are_replaced = gr.State(False)
@@ -407,7 +419,9 @@ with gr.Blocks() as demo:
     new_entity_text_pairs_file.change(
         fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
     )
-    replace_entity_button.click(fn=replace_entities, inputs=new_entity_text_pairs_input, outputs=entities_are_replaced)
+    replace_entity_button.click(
+        fn=replace_entities, inputs=[models, new_entity_text_pairs_input], outputs=entities_are_replaced
+    )
 
     with gr.Accordion(label="ハイパーパラメータ", open=False):
         topk_input = gr.Number(5, label="エンティティ件数", interactive=True)
@@ -428,25 +442,30 @@ with gr.Blocks() as demo:
 
     texts.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
+        inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     topk.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
+        inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     entity_span_sensitivity.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
+        inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     nayose_coef.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
+        inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
+        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
+
+    entities_are_replaced.change(
+        fn=get_topk_entities_from_texts,
+        inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
-    topk_input.change(inputs=topk_input, outputs=topk)
 
     gr.Markdown("---")
     gr.Markdown("## 出力エンティティ")
@@ -499,7 +518,7 @@ with gr.Blocks() as demo:
     # gr.Markdown("---")
     # gr.Markdown("## 選択されたエンティティの類似エンティティ")
 
-    # selected_entity.change(fn=get_similar_entities, inputs=selected_entity, outputs=similar_entities)
+    # selected_entity.change(fn=get_similar_entities, inputs=[models, selected_entity], outputs=similar_entities)
 
     # @gr.render(inputs=[selected_entity, similar_entities])
     # def render_similar_entities(selected_entity, similar_entities):
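
For context on the fix: the gr.State pitfall described in the Hint comment is easy to reproduce in isolation. Below is a minimal, hypothetical sketch (not part of this commit); FakeModel stands in for the callable model and tokenizer objects that app.py actually loads.

import gradio as gr


class FakeModel:
    # Hypothetical stand-in: like a transformers model or tokenizer,
    # instances are callable objects.
    def __call__(self, text: str) -> str:
        return text.upper()


model = FakeModel()

with gr.Blocks() as demo:
    # gr.State treats a callable default as a factory for its initial value and
    # invokes it with no arguments when the app loads, so the line below would
    # raise a TypeError once a session starts:
    # models = gr.State(model)

    # A tuple is not callable, so it is stored as-is. gr.State deep-copies its
    # value for each new session, which is what hands a reloaded demo a fresh
    # model and tokenizer.
    models = gr.State((model,))

Wiring the state into events as inputs=[models, ...], as the diff does for every .change and .click handler, then gives each handler the current session's copy rather than a module-level global.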