singletongue committed · Commit fd70e43 · verified · 1 parent: 913e5a4

Change MAX_TEXT_FILE_LINES to 10, clean up entity names, modify some UI components
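
Note: the "clean up entity names" part strips language prefixes ("ja:" and "Category:ja:") from the default entity vocabulary. A minimal sketch of the per-name mapping (mirroring the clean_default_entity_vocab function added in this commit; clean_entity_name is an illustrative helper, not part of the app):

    def clean_entity_name(entity: str) -> str:
        # "Category:ja:..." does not start with "ja:", so the check order is safe.
        if entity.startswith("ja:"):
            return entity.removeprefix("ja:")
        if entity.startswith("Category:ja:"):
            return "Category:" + entity.removeprefix("Category:ja:")
        return entity

    assert clean_entity_name("ja:東京都") == "東京都"
    assert clean_entity_name("Category:ja:日本の都市") == "Category:日本の都市"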

Files changed (1):
  1. app.py +182 -213
app.py CHANGED
@@ -2,22 +2,21 @@ import csv
 import re
 import unicodedata
 from collections import defaultdict
-from pathlib import Path
+from itertools import chain

 import gradio as gr
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import unidic_lite
 from bm25s.hf import BM25HF, TokenizerHF
-from fugashi import GenericTagger
 from transformers import AutoModelForPreTraining, AutoTokenizer


 ALIAS_SEP = "|"
+CATEGORY_ENTITY_PREFIX = "Category:"
 ENTITY_SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[MASK2]"]
 MAX_TEXT_LENGTH = 800
-MAX_TEXT_FILE_LINES = 100
+MAX_TEXT_FILE_LINES = 10
 MAX_ENTITY_FILE_LINES = 1000

 repo_id = "studio-ousia/luxe"
@@ -37,32 +36,21 @@ ignore_category_patterns = [
 ]


-class MecabTokenizer:
-    def __init__(self):
-        unidic_dir = unidic_lite.DICDIR
-        mecabrc_file = Path(unidic_dir, "mecabrc")
-        mecab_option = f"-d {unidic_dir} -r {mecabrc_file}"
-        self.tagger = GenericTagger(mecab_option)
+def clean_default_entity_vocab(tokenizer):
+    entity_vocab = {}
+    for entity, entity_id in tokenizer.entity_vocab.items():
+        if entity.startswith("ja:"):
+            entity = entity.removeprefix("ja:")
+        elif entity.startswith("Category:ja:"):
+            entity = "Category:" + entity.removeprefix("Category:ja:")

-    def __call__(self, text: str) -> list[tuple[str, str, tuple[int, int]]]:
-        outputs = []
+        entity_vocab[entity] = entity_id

-        end = 0
-        for node in self.tagger(text):
-            word = node.surface.strip()
-            pos = node.feature[0]
-            start = text.index(word, end)
-            end = start + len(word)
-            outputs.append((word, pos, (start, end)))
-
-        return outputs
-
-
-mecab_tokenizer = MecabTokenizer()
+    tokenizer.entity_vocab = entity_vocab


 def normalize_text(text: str) -> str:
-    return unicodedata.normalize("NFKC", text)
+    return unicodedata.normalize("NFKC", text).strip()


 def get_texts_from_file(file_path: str | None):
@@ -73,36 +61,20 @@ def get_texts_from_file(file_path: str | None):
             reader = csv.DictReader(f, fieldnames=["text"])
             for i, row in enumerate(reader):
                 if i >= MAX_TEXT_FILE_LINES:
-                    gr.Info(f"{MAX_TEXT_FILE_LINES}行目までのデータを読み込みました。")
+                    gr.Info(f"{MAX_TEXT_FILE_LINES}行目までのデータを読み込みました。", duration=5)
                     break

-                text = normalize_text(row["text"]).strip()
-                if text != "":
+                text = row["text"]
+                if text.strip() != "":
                     texts.append(text[:MAX_TEXT_LENGTH])
     except Exception as e:
-        gr.Warning("ファイルを正しく読み込めませんでした。")
+        gr.Warning("ファイルを正しく読み込めませんでした。", duration=5)
         print(e)
         texts = []

     return texts


-def get_noun_spans_from_text(text: str) -> list[tuple[int, int]]:
-    last_pos = None
-    noun_spans = []
-
-    for word, pos, (start, end) in mecab_tokenizer(text):
-        if pos == "名詞":
-            if len(noun_spans) > 0 and last_pos == "名詞":
-                noun_spans[-1] = (noun_spans[-1][0], end)
-            else:
-                noun_spans.append((start, end))
-
-        last_pos = pos
-
-    return noun_spans
-
-
 def get_token_spans(tokenizer, text: str) -> list[tuple[int, int]]:
     token_spans = []
     end = 0
@@ -147,12 +119,17 @@ def get_predicted_entity_spans(

 def get_topk_entities_from_texts(
     models,
-    texts: list[str],
+    texts: str | list[str],
     k: int = 5,
     entity_span_sensitivity: float = 1.0,
     nayose_coef: float = 1.0,
     entity_replaced_counts: bool = False,
 ) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
+    gr.Info("LUXEによる予測を実行しています。", duration=5)
+
+    if isinstance(texts, str):
+        texts = [texts]
+
     model, tokenizer, bm25_tokenizer, bm25_retriever = models

     batch_entity_spans: list[list[tuple[int, int]]] = []
@@ -177,7 +154,12 @@
         and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
     ]

+    entity_k = min(k, len(id2normal_entity))
+    category_k = min(k, len(id2category_entity))
+
     for text in texts:
+        text = normalize_text(text).strip()
+
         tokenized_examples = tokenizer(text, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
         token_spans = get_token_spans(tokenizer, text)
@@ -188,14 +170,14 @@
         model_outputs = model(**tokenized_examples)

         if model_outputs.topic_entity_logits is not None:
-            _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
+            _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(entity_k)
             topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
         else:
             topk_normal_entities.append([])

         if model_outputs.topic_category_logits is not None:
             model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
-            _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
+            _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(category_k)
             topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
         else:
             topk_category_entities.append([])
@@ -211,7 +193,7 @@
         )
         span_entity_logits += nayose_coef * nayose_scores

-        _, topk_span_entity_ids = span_entity_logits.topk(k)
+        _, topk_span_entity_ids = span_entity_logits.topk(entity_k)
         topk_span_entities.append(
             [[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
         )
@@ -221,51 +203,6 @@
     return texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities


-def get_selected_entity(evt: gr.SelectData):
-    return evt.value[0]
-
-
-def get_similar_entities(models, query_entity: str, k: int = 10) -> list[str]:
-    model, tokenizer, _, _ = models
-
-    query_entity_id = tokenizer.entity_vocab[query_entity]
-
-    id2normal_entity = {
-        entity_id: entity
-        for entity, entity_id in tokenizer.entity_vocab.items()
-        if entity_id < model.config.num_normal_entities
-    }
-    id2category_entity = {
-        entity_id - model.config.num_normal_entities: entity
-        for entity, entity_id in tokenizer.entity_vocab.items()
-        if entity_id >= model.config.num_normal_entities
-    }
-    ignore_category_entity_ids = [
-        entity_id - model.config.num_normal_entities
-        for entity, entity_id in tokenizer.entity_vocab.items()
-        if entity_id >= model.config.num_normal_entities
-        and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
-    ]
-    entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
-    normal_entity_embeddings = entity_embeddings[: model.config.num_normal_entities]
-    category_entity_embeddings = entity_embeddings[model.config.num_normal_entities :]
-
-    if query_entity_id < model.config.num_normal_entities:
-        topk_entity_scores = normal_entity_embeddings[query_entity_id] @ normal_entity_embeddings.T
-        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
-        topk_entities = [id2normal_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
-    else:
-        query_entity_id -= model.config.num_normal_entities
-        topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
-
-        topk_entity_scores[ignore_category_entity_ids] = float("-inf")
-
-        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
-        topk_entities = [id2category_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
-
-    return topk_entities
-
-
 def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]]:
     new_entity_text_pairs = []
     if file_path is not None:
@@ -274,7 +211,7 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]
             reader = csv.DictReader(f, fieldnames=["entity", "text"])
             for i, row in enumerate(reader):
                 if i >= MAX_ENTITY_FILE_LINES:
-                    gr.Info(f"{MAX_ENTITY_FILE_LINES}行目までのデータを読み込みました。")
+                    gr.Info(f"{MAX_ENTITY_FILE_LINES}行目までのデータを読み込みました。", duration=5)
                     break

                 entity = normalize_text(row["entity"]).strip()
@@ -282,7 +219,7 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]
                 if entity != "" and text != "":
                     new_entity_text_pairs.append([entity, text])
     except Exception as e:
-        gr.Warning("ファイルを正しく読み込めませんでした。")
+        gr.Warning("ファイルを正しく読み込めませんでした。", duration=5)
         print(e)
         new_entity_text_pairs = []

@@ -290,90 +227,109 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]


 def replace_entities(
-    models,
-    new_entity_text_pairs: list[tuple[str, str]],
-    entity_replaced_counts: int,
-    new_num_category_entities: int = 0,
-    new_entity_counts: list[int] | None = None,
-    new_padding_idx: int = 0,
-) -> True:
-    model, tokenizer, bm25_tokenizer, bm25_retriever = models
-
-    gr.Info("トークナイザのエンティティの語彙を置き換えています...", duration=5)
-    new_entity_tokens = ENTITY_SPECIAL_TOKENS + [entity for entity, _ in new_entity_text_pairs]
-
-    new_entity_vocab = {}
-    for entity in new_entity_tokens:
-        if entity not in new_entity_vocab:
-            new_entity_vocab[entity] = len(new_entity_vocab)
+    models, new_entity_text_pairs: list[tuple[str, str]], entity_replaced_counts: int, preserve_default_entities: bool
+) -> int:
+    if len(new_entity_text_pairs) == 0:
+        return entity_replaced_counts

-    new_entity_vocab = {entity: entity_id for entity_id, entity in enumerate(new_entity_tokens)}
+    gr.Info("LUXEのモデルとトークナイザのエンティティ語彙を更新しています。完了までお待ちください。", duration=5)

-    tokenizer.entity_vocab = new_entity_vocab
-    tokenizer.entity_pad_token_id = tokenizer.entity_vocab["[PAD]"]
-    tokenizer.entity_unk_token_id = tokenizer.entity_vocab["[UNK]"]
-    tokenizer.entity_mask_token_id = tokenizer.entity_vocab["[MASK]"]
-    tokenizer.entity_mask2_token_id = tokenizer.entity_vocab["[MASK2]"]
-
-    gr.Info("モデルのエンティティの埋め込みを置き換えています...", duration=5)
-    new_entity_embeddings_dict = defaultdict(list)
+    model, tokenizer, bm25_tokenizer, bm25_retriever = models

-    for entity_special_token in ENTITY_SPECIAL_TOKENS:
-        entity_special_token_id = tokenizer.entity_vocab[entity_special_token]
-        new_entity_embeddings_dict[entity_special_token_id].append(
-            model.luke.entity_embeddings.entity_embeddings.weight.data[entity_special_token_id]
-        )
+    normal_entity_embeddings = defaultdict(list)  # entity -> list of embeddings
+    category_entity_embeddings = defaultdict(list)  # entity -> list of embeddings
+    normal_entity_counts = {}  # entity -> count (int)
+    category_entity_counts = {}  # entity -> count (int)
+
+    for entity, entity_id in sorted(tokenizer.entity_vocab.items(), key=lambda x: x[1]):
+        if entity in ENTITY_SPECIAL_TOKENS or preserve_default_entities:
+            entity_embedding = model.luke.entity_embeddings.entity_embeddings.weight.data[entity_id]
+            if entity.startswith(CATEGORY_ENTITY_PREFIX):
+                category_entity_embeddings[entity].append(entity_embedding)
+                if model.config.entity_counts is not None:
+                    category_entity_counts[entity] = model.config.entity_counts[entity_id]
+                else:
+                    category_entity_counts[entity] = 1
+            else:
+                normal_entity_embeddings[entity].append(entity_embedding)
+                if model.config.entity_counts is not None:
+                    normal_entity_counts[entity] = model.config.entity_counts[entity_id]
+                else:
+                    normal_entity_counts[entity] = 1

     for entity, text in new_entity_text_pairs:
-        entity_id = tokenizer.entity_vocab[entity]
         tokenized_inputs = tokenizer(text[:MAX_TEXT_LENGTH], return_tensors="pt")
         model_outputs = model(**tokenized_inputs)
-        entity_embeddings = model.entity_predictions.transform(model_outputs.last_hidden_state[:, 0])
-        new_entity_embeddings_dict[entity_id].append(entity_embeddings[0])
+        entity_embedding = model.entity_predictions.transform(model_outputs.last_hidden_state[:, 0])[0]
+        if entity.startswith(CATEGORY_ENTITY_PREFIX):
+            category_entity_embeddings[entity].append(entity_embedding)
+            category_entity_counts.setdefault(entity, 1)
+        else:
+            normal_entity_embeddings[entity].append(entity_embedding)
+            normal_entity_counts.setdefault(entity, 1)

-    assert len(new_entity_embeddings_dict) == len(tokenizer.entity_vocab)
+    num_normal_entities = len(normal_entity_embeddings)
+    num_category_entities = len(category_entity_embeddings)

-    new_entity_embeddings = torch.vstack(
-        [
-            sum(new_entity_embeddings_dict[i]) / len(new_entity_embeddings_dict[i])
-            for i in range(len(new_entity_embeddings_dict))
-        ]
-    )
-    new_entity_vocab_size, new_entity_emb_size = new_entity_embeddings.size()
-    assert new_entity_vocab_size == len(tokenizer.entity_vocab)
+    entity_embeddings = {
+        entity: sum(embeddings) / len(embeddings)
+        for entity, embeddings in chain(normal_entity_embeddings.items(), category_entity_embeddings.items())
+    }
+    entity_vocab = {entity: entity_id for entity_id, entity in enumerate(entity_embeddings.keys())}
+
+    entity_counts = [
+        category_entity_counts[entity] if entity.startswith(CATEGORY_ENTITY_PREFIX) else normal_entity_counts[entity]
+        for entity in entity_vocab.keys()
+    ]

-    new_num_normal_entities = new_entity_vocab_size - new_num_category_entities
+    tokenizer.entity_vocab = entity_vocab
+    tokenizer.entity_pad_token_id = entity_vocab["[PAD]"]
+    tokenizer.entity_unk_token_id = entity_vocab["[UNK]"]
+    tokenizer.entity_mask_token_id = entity_vocab["[MASK]"]
+    tokenizer.entity_mask2_token_id = entity_vocab["[MASK2]"]

-    if new_entity_counts is not None and any(count < 1 for count in new_entity_counts):
-        raise ValueError("All items in new_entity_counts must be greater than zero")
+    entity_embeddings_tensor = torch.vstack(list(entity_embeddings.values()))

     if model.config.normalize_entity_embeddings:
-        new_entity_embeddings = F.normalize(new_entity_embeddings)
+        entity_embeddings_tensor = F.normalize(entity_embeddings_tensor)

-    new_entity_embeddings_module = nn.Embedding(
-        new_entity_vocab_size,
-        new_entity_emb_size,
-        padding_idx=new_padding_idx,
+    entity_vocab_size, entity_emb_size = entity_embeddings_tensor.size()
+
+    entity_embeddings_module = nn.Embedding(
+        entity_vocab_size,
+        entity_emb_size,
+        padding_idx=tokenizer.entity_pad_token_id,
         device=model.luke.entity_embeddings.entity_embeddings.weight.device,
         dtype=model.luke.entity_embeddings.entity_embeddings.weight.dtype,
     )
-    new_entity_embeddings_module.weight.data = new_entity_embeddings.data
-    model.luke.entity_embeddings.entity_embeddings = new_entity_embeddings_module
+    entity_embeddings_module.weight.data = entity_embeddings_tensor.data
+    model.luke.entity_embeddings.entity_embeddings = entity_embeddings_module

-    new_entity_decoder_module = nn.Linear(new_entity_emb_size, new_entity_vocab_size, bias=False)
-    model.entity_predictions.decoder = new_entity_decoder_module
-    model.entity_predictions.bias = nn.Parameter(torch.zeros(new_entity_vocab_size))
+    entity_decoder_module = nn.Linear(entity_emb_size, entity_vocab_size, bias=False)
+    model.entity_predictions.decoder = entity_decoder_module
+    model.entity_predictions.bias = nn.Parameter(torch.zeros(entity_vocab_size))
     model.tie_weights()

-    if hasattr(model, "entity_log_probs"):
-        del model.entity_log_probs
+    if model.config.entity_counts is not None:
+        total_normal_entity_count = sum(entity_counts[:num_normal_entities])
+        total_category_entity_count = sum(entity_counts[num_normal_entities:])
+
+        entity_counts_tensor = torch.tensor(entity_counts, dtype=model.dtype, device=model.device)
+        total_entity_counts = torch.tensor(
+            [total_normal_entity_count] * num_normal_entities + [total_category_entity_count] * num_category_entities,
+            dtype=model.dtype,
+            device=model.device,
+        )
+        entity_log_probs = torch.log(entity_counts_tensor / total_entity_counts)
+        model.entity_log_probs = entity_log_probs

-    model.config.entity_vocab_size = new_entity_vocab_size
-    model.config.num_normal_entities = new_num_normal_entities
-    model.config.num_category_entities = new_num_category_entities
-    model.config.entity_counts = new_entity_counts
+    model.config.entity_vocab_size = entity_vocab_size
+    model.config.num_normal_entities = num_normal_entities
+    model.config.num_category_entities = num_category_entities
+    if model.config.entity_counts is not None:
+        model.config.entity_counts = entity_counts

-    gr.Info("モデルとトークナイザのエンティティの置き換えが完了しました", duration=5)
+    gr.Info("LUXEのモデルとトークナイザのエンティティ語彙の更新が完了しました。", duration=5)

     return entity_replaced_counts + 1

@@ -385,14 +341,15 @@ with gr.Blocks() as demo:
     bm25_tokenizer.load_vocab_from_hub("studio-ousia/luxe-nayose-bm25")
     bm25_retriever = BM25HF.load_from_hub("studio-ousia/luxe-nayose-bm25")

+    clean_default_entity_vocab(tokenizer)
+
     # Hint: gr.State に callable を渡すと、それが state の初期値を設定するための関数とみなされて
     # __call__ が引数なしで実行されてしまうため、gr.State の引数に model や tokenizer を単体で渡すとエラーになってしまう。
     # ここでは、モデル一式のタプル(callable でない)を渡すことで、そのようなエラーを回避している。
     # cf. https://www.gradio.app/docs/gradio/state#param-state-value
     models = gr.State((model, tokenizer, bm25_tokenizer, bm25_retriever))

-    input_texts = gr.State([])
-    output_texts = gr.State([])
+    texts_input = gr.State([])

     entity_replaced_counts = gr.State(0)

@@ -400,26 +357,59 @@
     entity_span_sensitivity = gr.State(1.0)
     nayose_coef = gr.State(1.0)

+    texts = gr.State([])
     batch_entity_spans = gr.State([])
     topk_normal_entities = gr.State([])
     topk_category_entities = gr.State([])
     topk_span_entities = gr.State([])

-    selected_entity = gr.State()
-    similar_entities = gr.State([])
+    gr.Markdown("# 📝 LUXE Demo (β版)")
+
+    gr.Markdown(
+        """Studio Ousia で開発中の次世代知識強化言語モデル **LUXE** の動作デモです。
+        入力されたテキストに対して、テキスト中に出現するエンティティ(事物)と、テキスト全体の主題となるエンティティおよびカテゴリを予測します。

-    gr.Markdown("# 📝 LUXE Demo")
+        デフォルトのLUXEは、エンティティおよびカテゴリとして、それぞれ日本語 Wikipedia における被リンク数上位50万件および10万件の項目を使用しています。
+        予測対象のエンティティを任意のものに置き換えて推論を行うことも可能です(下記「LUXE のエンティティ語彙を置き換える」を参照してください)。""",
+        line_breaks=True,
+    )

     gr.Markdown("## 入力テキスト")

     with gr.Tab(label="直接入力"):
         text_input = gr.Textbox(label=f"入力テキスト(最大{MAX_TEXT_LENGTH}文字)", max_length=MAX_TEXT_LENGTH)
+        text_submit_button = gr.Button(value="予測実行", variant="huggingface")
     with gr.Tab(label="ファイルアップロード"):
-        gr.Markdown(f"1行1事例のテキストファイル(最大{MAX_TEXT_FILE_LINES}行)をアップロードできます。")
+        gr.Markdown(
+            f"""1行1事例のテキストファイル(最大{MAX_TEXT_FILE_LINES}行)をアップロードできます。
+            アップロードされたテキストのそれぞれに対して推論が実行されます。""",
+            line_breaks=True,
+        )
         texts_file = gr.File(label="入力テキストファイル")
+        texts_submit_button = gr.Button(value="予測実行", variant="huggingface")
+
+    text_input.submit(
+        fn=get_topk_entities_from_texts,
+        inputs=[models, text_input, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
+        outputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
+    text_submit_button.click(
+        fn=get_topk_entities_from_texts,
+        inputs=[models, text_input, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
+        outputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
+
+    texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts_input)
+    texts_submit_button.click(
+        fn=get_topk_entities_from_texts,
+        inputs=[models, texts_input, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
+        outputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
+
+    gr.Markdown("---")

     with gr.Accordion(label="ハイパーパラメータ", open=False):
-        topk_input = gr.Number(5, label="エンティティ件数", interactive=True)
+        topk_input = gr.Number(5, label="予測するエンティティの件数 (Top K)", interactive=True)
         entity_span_sensitivity_input = gr.Slider(
             minimum=0.0, maximum=5.0, value=1.0, step=0.1, label="エンティティ検出の積極度", interactive=True
         )
@@ -427,25 +417,23 @@
             minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="文字列一致の優先度", interactive=True
         )

-    text_input.change(fn=lambda text: [normalize_text(text)], inputs=text_input, outputs=input_texts)
-    texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=input_texts)
     topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)
     entity_span_sensitivity_input.change(
         fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
     )
     nayose_coef_input.change(fn=lambda val: val, inputs=nayose_coef_input, outputs=nayose_coef)

-    with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
+    with gr.Accordion(label="LUXE のエンティティ語彙を置き換える", open=False):
         gr.Markdown(
-            """LUXEのモデルのエンティティの語彙を任意のエンティティ集合に置き換えます。
-            エンティティと共に与えられるエンティティの説明文から、エンティティの埋め込みが計算されます。""",
+            """LUXE のモデルとトークナイザのエンティティ語彙を任意のエンティティ集合に置き換えます。
+            エンティティとともに与えられるエンティティの説明文から、エンティティの埋め込みが計算され、LUXE の推論に利用されます。""",
             line_breaks=True,
         )
         gr.Markdown(
-            f"「エンティティ」と「エンティティの説明文」の2列からなるCSVファイル(最大{MAX_ENTITY_FILE_LINES}行)をアップロードできます。"
+            f"「エンティティ」と「エンティティの説明文」の2列からなる CSV ファイル(最大{MAX_ENTITY_FILE_LINES}行)をアップロードできます。"
         )
-        new_entity_text_pairs_file = gr.File(label="エンティティと説明文のCSVファイル", height="128px")
-        gr.Markdown("CSVファイルから読み込まれた項目が以下の表に表示されます。表の内容を直接編集することも可能です。")
+        new_entity_text_pairs_file = gr.File(label="エンティティと説明文の CSV ファイル", height="128px")
+        gr.Markdown("CSV ファイルから読み込まれた項目が以下の表に表示されます。表の内容を直接編集することも可能です。")
         new_entity_text_pairs_input = gr.Dataframe(
             # value=sample_new_entity_text_pairs,
             headers=["entity", "text"],
@@ -454,41 +442,28 @@
             label="エンティティと説明文",
             interactive=True,
         )
+        preserve_default_entities_checkbox = gr.Checkbox(label="既存のエンティティを保持する", value=True)
         replace_entity_button = gr.Button(value="エンティティ語彙を置き換える")
-        gr.Markdown("LUXEのモデルのエンティティ語彙は、デモページの再読み込み時にリセットされます。")
+        gr.Markdown("LUXE のモデルのエンティティ語彙は、デモページの再読み込み時にリセットされます。")

     new_entity_text_pairs_file.change(
         fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
     )
     replace_entity_button.click(
         fn=replace_entities,
-        inputs=[models, new_entity_text_pairs_input, entity_replaced_counts],
+        inputs=[models, new_entity_text_pairs_input, entity_replaced_counts, preserve_default_entities_checkbox],
         outputs=entity_replaced_counts,
     )

-    submit_button = gr.Button(value="予測実行", variant="huggingface")
-    submit_button.click(
-        fn=get_topk_entities_from_texts,
-        inputs=[models, input_texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
-        outputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
-    )
-    text_input.submit(
-        fn=get_topk_entities_from_texts,
-        inputs=[models, input_texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
-        outputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
-    )
-
     gr.Markdown("---")
-    gr.Markdown("## 出力エンティティ")
+    gr.Markdown("## 予測されたエンティティとカテゴリ")

-    @gr.render(
-        inputs=[output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities]
-    )
+    @gr.render(inputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities])
     def render_topk_entities(
-        output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
+        texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
     ):
         for text, entity_spans, normal_entities, category_entities, span_entities in zip(
-            output_texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
+            texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
         ):
             highlighted_text_value = []
             cur = 0
@@ -503,7 +478,10 @@
                 highlighted_text_value.append((text[cur:], None))

             gr.HighlightedText(
-                value=highlighted_text_value, color_map={"Entity": "green"}, combine_adjacent=False, label="Text"
+                value=highlighted_text_value,
+                color_map={"Entity": "green"},
+                combine_adjacent=False,
+                label="予測されたエンティティのスパン",
             )

             # gr.Textbox(text, label="Text")
@@ -512,31 +490,22 @@
                 label="テキスト全体に関連するエンティティ",
                 components=["text"],
                 samples=[[entity] for entity in normal_entities],
-            ).select(fn=get_selected_entity, outputs=selected_entity)
+            )
             if category_entities:
                 gr.Dataset(
                     label="テキスト全体に関連するカテゴリ",
                     components=["text"],
                     samples=[[entity] for entity in category_entities],
-                ).select(fn=get_selected_entity, outputs=selected_entity)
-
-            span_texts = [text[start:end] for start, end in entity_spans]
-            for span_text, entities in zip(span_texts, span_entities):
-                gr.Dataset(
-                    label=f"「{span_text}」に対応するエンティティ",
-                    components=["text"],
-                    samples=[[entity] for entity in entities],
-                ).select(fn=get_selected_entity, outputs=selected_entity)
-
-    # gr.Markdown("---")
-    # gr.Markdown("## 選択されたエンティティの類似エンティティ")
-
-    # selected_entity.change(fn=get_similar_entities, inputs=[models, selected_entity], outputs=similar_entities)
+                )

-    # @gr.render(inputs=[selected_entity, similar_entities])
-    # def render_similar_entities(selected_entity, similar_entities):
-    #     gr.Textbox(selected_entity, label="Selected Entity")
-    #     gr.Dataset(label="Similar Entities", components=["text"], samples=[[entity] for entity in similar_entities])
+            with gr.Accordion(label="テキスト中のスパンに対応するエンティティ", open=len(texts) == 1):
+                span_texts = [text[start:end] for start, end in entity_spans]
+                for span_text, entities in zip(span_texts, span_entities):
+                    gr.Dataset(
+                        label=f"「{span_text}」に対応するエンティティ",
+                        components=["text"],
+                        samples=[[entity] for entity in entities],
+                    )


 demo.launch()
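
The rewritten replace_entities also refreshes model.entity_log_probs from the per-entity counts, normalizing the normal-entity block and the category-entity block separately before taking the log. A standalone sketch of that prior computation (entity_log_priors is an illustrative helper, not part of the app):

    import torch

    def entity_log_priors(entity_counts: list[int], num_normal: int) -> torch.Tensor:
        # Each count is divided by the total of its own block
        # (normal entities vs. category entities), then log-transformed.
        counts = torch.tensor(entity_counts, dtype=torch.float32)
        num_category = len(entity_counts) - num_normal
        totals = torch.tensor(
            [sum(entity_counts[:num_normal])] * num_normal
            + [sum(entity_counts[num_normal:])] * num_category,
            dtype=torch.float32,
        )
        return torch.log(counts / totals)

    # Example: three normal entities (counts 2, 1, 1) and two categories (3, 1)
    # yields log([0.5, 0.25, 0.25, 0.75, 0.25]).
    print(entity_log_priors([2, 1, 1, 3, 1], num_normal=3))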