singletongue committed
Commit d6ab44d · verified · 1 Parent(s): 643182d

Implement replacing of model/tokenizer entities
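
This change adds an accordion to the demo for uploading entity/description pairs (as a CSV file or an editable dataframe) and a replace_entities function that rebuilds the tokenizer's entity vocabulary and the model's entity embeddings, decoder, and config from those descriptions; once entities have been replaced, the BM25-based nayose reranking is skipped. The snippet below is an illustrative usage sketch only, not part of the commit: it assumes the module-level definitions from app.py shown in the diff are loaded in the same session, and the entity/description strings are made-up examples.

    # Illustrative sketch only; assumes replace_entities and get_topk_entities_from_texts
    # from app.py (below) are in scope, along with the loaded model and tokenizer.
    new_entity_text_pairs = [
        ("東京都", "日本の首都。"),        # (entity name, description text) - hypothetical examples
        ("富士山", "日本で最も高い山。"),
    ]

    # Rebuild tokenizer.entity_vocab and the model's entity embedding/decoder
    # weights from the description texts.
    replace_entities(new_entity_text_pairs)

    # With entities_are_replaced=True, the nayose (BM25) reranking branch is skipped
    # and span entities are predicted from the new embeddings only.
    spans, normal_entities, category_entities, span_entities = get_topk_entities_from_texts(
        ["東京都で最も高い山はどこですか。"], k=5, entities_are_replaced=True
    )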

Files changed (1)
  1. app.py +234 -65
app.py CHANGED
@@ -1,9 +1,13 @@
+import csv
 import re
 import unicodedata
+from collections import defaultdict
 from pathlib import Path
 
 import gradio as gr
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 import unidic_lite
 from bm25s.hf import BM25HF, TokenizerHF
 from fugashi import GenericTagger
@@ -11,6 +15,7 @@ from transformers import AutoModelForPreTraining, AutoTokenizer
 
 
 ALIAS_SEP = "|"
+ENTITY_SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[MASK2]"]
 
 repo_id = "studio-ousia/luxe"
 revision = "ja-v0.3.1"
@@ -31,28 +36,6 @@ ignore_category_patterns = [
 model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
 
-num_normal_entities = len(tokenizer.entity_vocab) - model.config.num_category_entities
-num_category_entities = model.config.num_category_entities
-
-id2normal_entity = {
-    entity_id: entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id < num_normal_entities
-}
-
-id2category_entity = {
-    entity_id - num_normal_entities: entity
-    for entity, entity_id in tokenizer.entity_vocab.items()
-    if entity_id >= num_normal_entities
-}
-ignore_category_entity_ids = [
-    entity_id - num_normal_entities
-    for entity, entity_id in tokenizer.entity_vocab.items()
-    if entity_id >= num_normal_entities and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
-]
-
-entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
-normal_entity_embeddings = entity_embeddings[:num_normal_entities]
-category_entity_embeddings = entity_embeddings[num_normal_entities:]
-
 
 class MecabTokenizer:
     def __init__(self):
@@ -87,13 +70,20 @@ bm25_tokenizer.load_vocab_from_hub("studio-ousia/luxe-nayose-bm25")
 bm25_retriever = BM25HF.load_from_hub("studio-ousia/luxe-nayose-bm25")
 
 
-def get_texts_from_file(file_path):
+def get_texts_from_file(file_path: str | None):
     texts = []
-    with open(file_path) as f:
-        for line in f:
-            line = line.strip()
-            if line:
-                texts.append(normalize_text(line))
+    if file_path is not None:
+        try:
+            with open(file_path, newline="") as f:
+                reader = csv.DictReader(f, fieldnames=["text"])
+                for row in reader:
+                    text = normalize_text(row["text"]).strip()
+                    if text != "":
+                        texts.append(text)
+        except Exception as e:
+            gr.Warning("ファイルを正しく読み込めませんでした。")
+            print(e)
+            texts = []
 
     return texts
 
@@ -136,33 +126,55 @@ def get_predicted_entity_spans(
     probs_sorted, sort_idxs = ner_probs.flatten().sort(descending=True)
 
     predicted_entity_spans = []
-    for p, i in zip(probs_sorted, sort_idxs.tolist()):
-        if p < 10.0 ** (-1.0 * entity_span_sensitivity):
-            break
+    if entity_span_sensitivity > 0.0:
+        for p, i in zip(probs_sorted, sort_idxs.tolist()):
+            if p < 10.0 ** (-1.0 * entity_span_sensitivity):
+                break
 
-        start_idx = i // length
-        end_idx = i % length
+            start_idx = i // length
+            end_idx = i % length
 
-        start = token_spans[start_idx][0]
-        end = token_spans[end_idx][1]
+            start = token_spans[start_idx][0]
+            end = token_spans[end_idx][1]
 
-        for ex_start, ex_end in predicted_entity_spans:
-            if not (start < end <= ex_start or ex_end <= start < end):
-                break
-        else:
-            predicted_entity_spans.append((start, end))
+            for ex_start, ex_end in predicted_entity_spans:
+                if not (start < end <= ex_start or ex_end <= start < end):
+                    break
+            else:
+                predicted_entity_spans.append((start, end))
 
     return sorted(predicted_entity_spans)
 
 
 def get_topk_entities_from_texts(
-    texts: list[str], k: int = 5, entity_span_sensitivity: float = 1.0, nayose_coef: float = 1.0
+    texts: list[str],
+    k: int = 5,
+    entity_span_sensitivity: float = 1.0,
+    nayose_coef: float = 1.0,
+    entities_are_replaced: bool = False,
 ) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
     batch_entity_spans: list[list[tuple[int, int]]] = []
     topk_normal_entities: list[list[str]] = []
     topk_category_entities: list[list[str]] = []
     topk_span_entities: list[list[list[str]]] = []
 
+    id2normal_entity = {
+        entity_id: entity
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id < model.config.num_normal_entities
+    }
+    id2category_entity = {
+        entity_id - model.config.num_normal_entities: entity
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id >= model.config.num_normal_entities
+    }
+    ignore_category_entity_ids = [
+        entity_id - model.config.num_normal_entities
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id >= model.config.num_normal_entities
+        and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
+    ]
+
     for text in texts:
         tokenized_examples = tokenizer(text, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
@@ -173,18 +185,23 @@ def get_topk_entities_from_texts(
         tokenized_examples = tokenizer(text, entity_spans=entity_spans or None, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
 
-        model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
-
-        _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
-        topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
+        if model_outputs.topic_entity_logits is not None:
+            _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
+            topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
+        else:
+            topk_normal_entities.append([])
 
-        _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
-        topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
+        if model_outputs.topic_category_logits is not None:
+            model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
+            _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
+            topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
+        else:
+            topk_category_entities.append([])
 
         if model_outputs.entity_logits is not None:
             span_entity_logits = model_outputs.entity_logits[0, :, :500000]
 
-            if nayose_coef > 0.0:
+            if nayose_coef > 0.0 and not entities_are_replaced:
                 nayose_queries = ["ja:" + text[start:end] for start, end in entity_spans]
                 nayose_query_tokens = bm25_tokenizer.tokenize(nayose_queries)
                 nayose_scores = torch.vstack(
@@ -209,12 +226,32 @@ def get_selected_entity(evt: gr.SelectData):
 def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
     query_entity_id = tokenizer.entity_vocab[query_entity]
 
-    if query_entity_id < num_normal_entities:
+    id2normal_entity = {
+        entity_id: entity
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id < model.config.num_normal_entities
+    }
+    id2category_entity = {
+        entity_id - model.config.num_normal_entities: entity
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id >= model.config.num_normal_entities
+    }
+    ignore_category_entity_ids = [
+        entity_id - model.config.num_normal_entities
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id >= model.config.num_normal_entities
+        and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
+    ]
+    entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
+    normal_entity_embeddings = entity_embeddings[: model.config.num_normal_entities]
+    category_entity_embeddings = entity_embeddings[model.config.num_normal_entities :]
+
+    if query_entity_id < model.config.num_normal_entities:
         topk_entity_scores = normal_entity_embeddings[query_entity_id] @ normal_entity_embeddings.T
         topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
         topk_entities = [id2normal_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
     else:
-        query_entity_id -= num_normal_entities
+        query_entity_id -= model.config.num_normal_entities
         topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
 
         topk_entity_scores[ignore_category_entity_ids] = float("-inf")
@@ -225,31 +262,157 @@ def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
     return topk_entities
 
 
-with gr.Blocks() as demo:
-    gr.Markdown("# 📝 LUXE Demo")
-
-    gr.Markdown("## 入力テキスト")
-
+def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]]:
+    new_entity_text_pairs = []
+    if file_path is not None:
+        try:
+            with open(file_path, newline="") as f:
+                reader = csv.DictReader(f, fieldnames=["entity", "text"])
+                for row in reader:
+                    entity = normalize_text(row["entity"]).strip()
+                    text = normalize_text(row["text"]).strip()
+                    if entity != "" and text != "":
+                        new_entity_text_pairs.append([entity, text])
+        except Exception as e:
+            gr.Warning("ファイルを正しく読み込めませんでした。")
+            print(e)
+            new_entity_text_pairs = []
+
+    return new_entity_text_pairs
+
+
+def replace_entities(
+    new_entity_text_pairs: list[tuple[str, str]],
+    new_num_category_entities: int = 0,
+    new_entity_counts: list[int] | None = None,
+    new_padding_idx: int = 0,
+) -> True:
+    gr.Info("トークナイザのエンティティの語彙を置き換えています...", duration=5)
+    new_entity_tokens = ENTITY_SPECIAL_TOKENS + [entity for entity, _ in new_entity_text_pairs]
+
+    new_entity_vocab = {}
+    for entity in new_entity_tokens:
+        if entity not in new_entity_vocab:
+            new_entity_vocab[entity] = len(new_entity_vocab)
+
+    new_entity_vocab = {entity: entity_id for entity_id, entity in enumerate(new_entity_tokens)}
+
+    tokenizer.entity_vocab = new_entity_vocab
+    tokenizer.entity_pad_token_id = tokenizer.entity_vocab["[PAD]"]
+    tokenizer.entity_unk_token_id = tokenizer.entity_vocab["[UNK]"]
+    tokenizer.entity_mask_token_id = tokenizer.entity_vocab["[MASK]"]
+    tokenizer.entity_mask2_token_id = tokenizer.entity_vocab["[MASK2]"]
+
+    gr.Info("モデルのエンティティの埋め込みを置き換えています...", duration=5)
+    new_entity_embeddings_dict = defaultdict(list)
+
+    for entity_special_token in ENTITY_SPECIAL_TOKENS:
+        entity_special_token_id = tokenizer.entity_vocab[entity_special_token]
+        new_entity_embeddings_dict[entity_special_token_id].append(
+            model.luke.entity_embeddings.entity_embeddings.weight.data[entity_special_token_id]
+        )
+
+    for entity, text in new_entity_text_pairs:
+        entity_id = tokenizer.entity_vocab[entity]
+        tokenized_inputs = tokenizer(text, return_tensors="pt")
+        model_outputs = model(**tokenized_inputs)
+        entity_embeddings = model.entity_predictions.transform(model_outputs.last_hidden_state[:, 0])
+        new_entity_embeddings_dict[entity_id].append(entity_embeddings[0])
+
+    assert len(new_entity_embeddings_dict) == len(tokenizer.entity_vocab)
+
+    new_entity_embeddings = torch.vstack(
+        [
+            sum(new_entity_embeddings_dict[i]) / len(new_entity_embeddings_dict[i])
+            for i in range(len(new_entity_embeddings_dict))
+        ]
+    )
+    new_entity_vocab_size, new_entity_emb_size = new_entity_embeddings.size()
+    assert new_entity_vocab_size == len(tokenizer.entity_vocab)
+
+    new_num_normal_entities = new_entity_vocab_size - new_num_category_entities
+
+    if new_entity_counts is not None and any(count < 1 for count in new_entity_counts):
+        raise ValueError("All items in new_entity_counts must be greater than zero")
+
+    if model.config.normalize_entity_embeddings:
+        new_entity_embeddings = F.normalize(new_entity_embeddings)
+
+    new_entity_embeddings_module = nn.Embedding(
+        new_entity_vocab_size,
+        new_entity_emb_size,
+        padding_idx=new_padding_idx,
+        device=model.luke.entity_embeddings.entity_embeddings.weight.device,
+        dtype=model.luke.entity_embeddings.entity_embeddings.weight.dtype,
+    )
+    new_entity_embeddings_module.weight.data = new_entity_embeddings.data
+    model.luke.entity_embeddings.entity_embeddings = new_entity_embeddings_module
+
+    new_entity_decoder_module = nn.Linear(new_entity_emb_size, new_entity_vocab_size, bias=False)
+    model.entity_predictions.decoder = new_entity_decoder_module
+    model.entity_predictions.bias = nn.Parameter(torch.zeros(new_entity_vocab_size))
+    model.tie_weights()
+
+    if hasattr(model, "entity_log_probs"):
+        del model.entity_log_probs
+
+    model.config.entity_vocab_size = new_entity_vocab_size
+    model.config.num_normal_entities = new_num_normal_entities
+    model.config.num_category_entities = new_num_category_entities
+    model.config.entity_counts = new_entity_counts
+
+    gr.Info("モデルとトークナイザのエンティティの置き換えが完了しました", duration=5)
+
+    return True
+
+
+with gr.Blocks() as demo:
     texts = gr.State([])
+
+    entities_are_replaced = gr.State(False)
+
     topk = gr.State(5)
     entity_span_sensitivity = gr.State(1.0)
     nayose_coef = gr.State(1.0)
+
     batch_entity_spans = gr.State([])
     topk_normal_entities = gr.State([])
     topk_category_entities = gr.State([])
     topk_span_entities = gr.State([])
+
     selected_entity = gr.State()
     similar_entities = gr.State([])
 
+    gr.Markdown("# 📝 LUXE Demo")
+
+    gr.Markdown("## 入力テキスト")
+
     with gr.Tab(label="直接入力"):
         text_input = gr.Textbox(label="入力テキスト")
     with gr.Tab(label="ファイルアップロード"):
         texts_file = gr.File(label="入力テキストファイル")
 
+    with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
+        new_entity_text_pairs_file = gr.File(label="エンティティと説明文のCSVファイル")
+        new_entity_text_pairs_input = gr.Dataframe(
+            # value=sample_new_entity_text_pairs,
+            headers=["entity", "text"],
+            col_count=(2, "fixed"),
+            type="array",
+            label="エンティティと説明文",
+            interactive=True,
+        )
+        replace_entity_button = gr.Button(value="エンティティ語彙を置き換える")
+
+    new_entity_text_pairs_file.change(
+        fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
+    )
+    replace_entity_button.click(fn=replace_entities, inputs=new_entity_text_pairs_input, outputs=entities_are_replaced)
+
     with gr.Accordion(label="ハイパーパラメータ", open=False):
         topk_input = gr.Number(5, label="エンティティ件数", interactive=True)
         entity_span_sensitivity_input = gr.Slider(
-            minimum=0.1, maximum=5.0, value=1.0, step=0.1, label="エンティティ検出の積極度", interactive=True
+            minimum=0.0, maximum=5.0, value=1.0, step=0.1, label="エンティティ検出の積極度", interactive=True
         )
         nayose_coef_input = gr.Slider(
            minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="文字列一致の優先度", interactive=True
@@ -265,22 +428,22 @@ with gr.Blocks() as demo:
 
     texts.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
+        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     topk.change(
        fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
+        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     entity_span_sensitivity.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
+        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     nayose_coef.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
+        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     topk_input.change(inputs=topk_input, outputs=topk)
@@ -312,17 +475,23 @@ with gr.Blocks() as demo:
             )
 
             # gr.Textbox(text, label="Text")
-            gr.Dataset(
-                label="Topic Entities", components=["text"], samples=[[entity] for entity in normal_entities]
-            ).select(fn=get_selected_entity, outputs=selected_entity)
-            gr.Dataset(
-                label="Topic Categories", components=["text"], samples=[[entity] for entity in category_entities]
-            ).select(fn=get_selected_entity, outputs=selected_entity)
+            if normal_entities:
+                gr.Dataset(
+                    label="テキスト全体に関連するエンティティ",
+                    components=["text"],
+                    samples=[[entity] for entity in normal_entities],
+                ).select(fn=get_selected_entity, outputs=selected_entity)
+            if category_entities:
+                gr.Dataset(
+                    label="テキスト全体に関連するカテゴリ",
+                    components=["text"],
+                    samples=[[entity] for entity in category_entities],
+                ).select(fn=get_selected_entity, outputs=selected_entity)
 
             span_texts = [text[start:end] for start, end in entity_spans]
             for span_text, entities in zip(span_texts, span_entities):
                 gr.Dataset(
-                    label=f"Span Entities for {span_text}",
+                    label=f"「{span_text}」に対応するエンティティ",
                     components=["text"],
                     samples=[[entity] for entity in entities],
                 ).select(fn=get_selected_entity, outputs=selected_entity)