singletongue committed on
Commit
3f97903
·
verified ·
1 Parent(s): c1e0e01

Set maximum values for input text length and number of lines in input files

Browse files
Files changed (1) hide show
  1. app.py +34 -18
app.py CHANGED
@@ -16,6 +16,9 @@ from transformers import AutoModelForPreTraining, AutoTokenizer
16
 
17
  ALIAS_SEP = "|"
18
  ENTITY_SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[MASK2]"]
 
 
 
19
 
20
  repo_id = "studio-ousia/luxe"
21
  revision = "ja-v0.3.1"
@@ -68,10 +71,14 @@ def get_texts_from_file(file_path: str | None):
68
  try:
69
  with open(file_path, newline="") as f:
70
  reader = csv.DictReader(f, fieldnames=["text"])
71
- for row in reader:
 
 
 
 
72
  text = normalize_text(row["text"]).strip()
73
  if text != "":
74
- texts.append(text)
75
  except Exception as e:
76
  gr.Warning("ファイルを正しく読み込めませんでした。")
77
  print(e)
@@ -144,7 +151,7 @@ def get_topk_entities_from_texts(
144
  k: int = 5,
145
  entity_span_sensitivity: float = 1.0,
146
  nayose_coef: float = 1.0,
147
- entities_are_replaced: bool = False,
148
  ) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
149
  model, tokenizer, bm25_tokenizer, bm25_retriever = models
150
 
@@ -196,7 +203,7 @@ def get_topk_entities_from_texts(
196
  if model_outputs.entity_logits is not None:
197
  span_entity_logits = model_outputs.entity_logits[0, :, :500000]
198
 
199
- if nayose_coef > 0.0 and not entities_are_replaced:
200
  nayose_queries = ["ja:" + text[start:end] for start, end in entity_spans]
201
  nayose_query_tokens = bm25_tokenizer.tokenize(nayose_queries)
202
  nayose_scores = torch.vstack(
@@ -265,7 +272,11 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]
265
  try:
266
  with open(file_path, newline="") as f:
267
  reader = csv.DictReader(f, fieldnames=["entity", "text"])
268
- for row in reader:
 
 
 
 
269
  entity = normalize_text(row["entity"]).strip()
270
  text = normalize_text(row["text"]).strip()
271
  if entity != "" and text != "":
@@ -281,6 +292,7 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]
281
  def replace_entities(
282
  models,
283
  new_entity_text_pairs: list[tuple[str, str]],
 
284
  new_num_category_entities: int = 0,
285
  new_entity_counts: list[int] | None = None,
286
  new_padding_idx: int = 0,
@@ -314,7 +326,7 @@ def replace_entities(
314
 
315
  for entity, text in new_entity_text_pairs:
316
  entity_id = tokenizer.entity_vocab[entity]
317
- tokenized_inputs = tokenizer(text, return_tensors="pt")
318
  model_outputs = model(**tokenized_inputs)
319
  entity_embeddings = model.entity_predictions.transform(model_outputs.last_hidden_state[:, 0])
320
  new_entity_embeddings_dict[entity_id].append(entity_embeddings[0])
@@ -363,7 +375,7 @@ def replace_entities(
363
 
364
  gr.Info("モデルとトークナイザのエンティティの置き換えが完了しました", duration=5)
365
 
366
- return True
367
 
368
 
369
  with gr.Blocks() as demo:
@@ -381,7 +393,7 @@ with gr.Blocks() as demo:
381
 
382
  texts = gr.State([])
383
 
384
- entities_are_replaced = gr.State(False)
385
 
386
  topk = gr.State(5)
387
  entity_span_sensitivity = gr.State(1.0)
@@ -400,12 +412,14 @@ with gr.Blocks() as demo:
400
  gr.Markdown("## 入力テキスト")
401
 
402
  with gr.Tab(label="直接入力"):
403
- text_input = gr.Textbox(label="入力テキスト")
404
  with gr.Tab(label="ファイルアップロード"):
405
- texts_file = gr.File(label="入力テキストファイル")
406
 
407
  with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
408
- new_entity_text_pairs_file = gr.File(label="エンティティと説明文のCSVファイル")
 
 
409
  new_entity_text_pairs_input = gr.Dataframe(
410
  # value=sample_new_entity_text_pairs,
411
  headers=["entity", "text"],
@@ -420,7 +434,9 @@ with gr.Blocks() as demo:
420
  fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
421
  )
422
  replace_entity_button.click(
423
- fn=replace_entities, inputs=[models, new_entity_text_pairs_input], outputs=entities_are_replaced
 
 
424
  )
425
 
426
  with gr.Accordion(label="ハイパーパラメータ", open=False):
@@ -442,28 +458,28 @@ with gr.Blocks() as demo:
442
 
443
  texts.change(
444
  fn=get_topk_entities_from_texts,
445
- inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
446
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
447
  )
448
  topk.change(
449
  fn=get_topk_entities_from_texts,
450
- inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
451
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
452
  )
453
  entity_span_sensitivity.change(
454
  fn=get_topk_entities_from_texts,
455
- inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
456
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
457
  )
458
  nayose_coef.change(
459
  fn=get_topk_entities_from_texts,
460
- inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
461
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
462
  )
463
 
464
- entities_are_replaced.change(
465
  fn=get_topk_entities_from_texts,
466
- inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
467
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
468
  )
469
 
 
16
 
17
  ALIAS_SEP = "|"
18
  ENTITY_SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[MASK2]"]
19
+ MAX_TEXT_LENGTH = 800
20
+ MAX_TEXT_FILE_LINES = 100
21
+ MAX_ENTITY_FILE_LINES = 1000
22
 
23
  repo_id = "studio-ousia/luxe"
24
  revision = "ja-v0.3.1"
 
71
  try:
72
  with open(file_path, newline="") as f:
73
  reader = csv.DictReader(f, fieldnames=["text"])
74
+ for i, row in enumerate(reader):
75
+ if i >= MAX_TEXT_FILE_LINES:
76
+ gr.Info(f"{MAX_TEXT_FILE_LINES}行目までのデータを読み込みました。")
77
+ break
78
+
79
  text = normalize_text(row["text"]).strip()
80
  if text != "":
81
+ texts.append(text[:MAX_TEXT_LENGTH])
82
  except Exception as e:
83
  gr.Warning("ファイルを正しく読み込めませんでした。")
84
  print(e)
 
151
  k: int = 5,
152
  entity_span_sensitivity: float = 1.0,
153
  nayose_coef: float = 1.0,
154
+ entity_replaced_counts: bool = False,
155
  ) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
156
  model, tokenizer, bm25_tokenizer, bm25_retriever = models
157
 
 
203
  if model_outputs.entity_logits is not None:
204
  span_entity_logits = model_outputs.entity_logits[0, :, :500000]
205
 
206
+ if nayose_coef > 0.0 and entity_replaced_counts == 0:
207
  nayose_queries = ["ja:" + text[start:end] for start, end in entity_spans]
208
  nayose_query_tokens = bm25_tokenizer.tokenize(nayose_queries)
209
  nayose_scores = torch.vstack(
 
272
  try:
273
  with open(file_path, newline="") as f:
274
  reader = csv.DictReader(f, fieldnames=["entity", "text"])
275
+ for i, row in enumerate(reader):
276
+ if i >= MAX_ENTITY_FILE_LINES:
277
+ gr.Info(f"{MAX_ENTITY_FILE_LINES}行目までのデータを読み込みました。")
278
+ break
279
+
280
  entity = normalize_text(row["entity"]).strip()
281
  text = normalize_text(row["text"]).strip()
282
  if entity != "" and text != "":
 
292
  def replace_entities(
293
  models,
294
  new_entity_text_pairs: list[tuple[str, str]],
295
+ entity_replaced_counts: int,
296
  new_num_category_entities: int = 0,
297
  new_entity_counts: list[int] | None = None,
298
  new_padding_idx: int = 0,
 
326
 
327
  for entity, text in new_entity_text_pairs:
328
  entity_id = tokenizer.entity_vocab[entity]
329
+ tokenized_inputs = tokenizer(text[:MAX_TEXT_LENGTH], return_tensors="pt")
330
  model_outputs = model(**tokenized_inputs)
331
  entity_embeddings = model.entity_predictions.transform(model_outputs.last_hidden_state[:, 0])
332
  new_entity_embeddings_dict[entity_id].append(entity_embeddings[0])
 
375
 
376
  gr.Info("モデルとトークナイザのエンティティの置き換えが完了しました", duration=5)
377
 
378
+ return entity_replaced_counts + 1
379
 
380
 
381
  with gr.Blocks() as demo:
 
393
 
394
  texts = gr.State([])
395
 
396
+ entity_replaced_counts = gr.State(0)
397
 
398
  topk = gr.State(5)
399
  entity_span_sensitivity = gr.State(1.0)
 
412
  gr.Markdown("## 入力テキスト")
413
 
414
  with gr.Tab(label="直接入力"):
415
+ text_input = gr.Textbox(label=f"入力テキスト(最大{MAX_TEXT_LENGTH}文字)", max_length=MAX_TEXT_LENGTH)
416
  with gr.Tab(label="ファイルアップロード"):
417
+ texts_file = gr.File(label=f"入力テキストファイル(最大{MAX_TEXT_FILE_LINES}行)")
418
 
419
  with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
420
+ new_entity_text_pairs_file = gr.File(
421
+ label=f"エンティティと説明文のCSVファイル(最大{MAX_ENTITY_FILE_LINES}行)"
422
+ )
423
  new_entity_text_pairs_input = gr.Dataframe(
424
  # value=sample_new_entity_text_pairs,
425
  headers=["entity", "text"],
 
434
  fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
435
  )
436
  replace_entity_button.click(
437
+ fn=replace_entities,
438
+ inputs=[models, new_entity_text_pairs_input, entity_replaced_counts],
439
+ outputs=entity_replaced_counts,
440
  )
441
 
442
  with gr.Accordion(label="ハイパーパラメータ", open=False):
 
458
 
459
  texts.change(
460
  fn=get_topk_entities_from_texts,
461
+ inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
462
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
463
  )
464
  topk.change(
465
  fn=get_topk_entities_from_texts,
466
+ inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
467
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
468
  )
469
  entity_span_sensitivity.change(
470
  fn=get_topk_entities_from_texts,
471
+ inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
472
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
473
  )
474
  nayose_coef.change(
475
  fn=get_topk_entities_from_texts,
476
+ inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
477
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
478
  )
479
 
480
+ entity_replaced_counts.change(
481
  fn=get_topk_entities_from_texts,
482
+ inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
483
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
484
  )
485