Implement replacement of model/tokenizer entities
app.py
CHANGED
@@ -1,9 +1,13 @@
+import csv
 import re
 import unicodedata
+from collections import defaultdict
 from pathlib import Path
 
 import gradio as gr
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 import unidic_lite
 from bm25s.hf import BM25HF, TokenizerHF
 from fugashi import GenericTagger
@@ -11,6 +15,7 @@ from transformers import AutoModelForPreTraining, AutoTokenizer
 
 
 ALIAS_SEP = "|"
+ENTITY_SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[MASK2]"]
 
 repo_id = "studio-ousia/luxe"
 revision = "ja-v0.3.1"
@@ -31,28 +36,6 @@ ignore_category_patterns = [
 model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
 
-num_normal_entities = len(tokenizer.entity_vocab) - model.config.num_category_entities
-num_category_entities = model.config.num_category_entities
-
-id2normal_entity = {
-    entity_id: entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id < num_normal_entities
-}
-
-id2category_entity = {
-    entity_id - num_normal_entities: entity
-    for entity, entity_id in tokenizer.entity_vocab.items()
-    if entity_id >= num_normal_entities
-}
-ignore_category_entity_ids = [
-    entity_id - num_normal_entities
-    for entity, entity_id in tokenizer.entity_vocab.items()
-    if entity_id >= num_normal_entities and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
-]
-
-entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
-normal_entity_embeddings = entity_embeddings[:num_normal_entities]
-category_entity_embeddings = entity_embeddings[num_normal_entities:]
-
 
 class MecabTokenizer:
     def __init__(self):
@@ -87,13 +70,20 @@ bm25_tokenizer.load_vocab_from_hub("studio-ousia/luxe-nayose-bm25")
 bm25_retriever = BM25HF.load_from_hub("studio-ousia/luxe-nayose-bm25")
 
 
-def get_texts_from_file(file_path):
+def get_texts_from_file(file_path: str | None):
     texts = []
-    …
+    if file_path is not None:
+        try:
+            with open(file_path, newline="") as f:
+                reader = csv.DictReader(f, fieldnames=["text"])
+                for row in reader:
+                    text = normalize_text(row["text"]).strip()
+                    if text != "":
+                        texts.append(text)
+        except Exception as e:
+            gr.Warning("ファイルを正しく読み込めませんでした。")
+            print(e)
+            texts = []
 
     return texts
 
@@ -136,33 +126,55 @@ def get_predicted_entity_spans(
     probs_sorted, sort_idxs = ner_probs.flatten().sort(descending=True)
 
     predicted_entity_spans = []
-    for p, i in zip(probs_sorted, sort_idxs.tolist()):
-        if p < 10.0 ** (-1.0 * entity_span_sensitivity):
-            break
+    if entity_span_sensitivity > 0.0:
+        for p, i in zip(probs_sorted, sort_idxs.tolist()):
+            if p < 10.0 ** (-1.0 * entity_span_sensitivity):
+                break
 
-        start_idx = i // length
-        end_idx = i % length
+            start_idx = i // length
+            end_idx = i % length
 
-        start = token_spans[start_idx][0]
-        end = token_spans[end_idx][1]
+            start = token_spans[start_idx][0]
+            end = token_spans[end_idx][1]
 
-        for ex_start, ex_end in predicted_entity_spans:
-            if not (start < end <= ex_start or ex_end <= start < end):
-                break
-        else:
-            predicted_entity_spans.append((start, end))
+            for ex_start, ex_end in predicted_entity_spans:
+                if not (start < end <= ex_start or ex_end <= start < end):
+                    break
+            else:
+                predicted_entity_spans.append((start, end))
 
     return sorted(predicted_entity_spans)
 
 
 def get_topk_entities_from_texts(
-    texts: list[str], k: int = 5, entity_span_sensitivity: float = 1.0, nayose_coef: float = 1.0
+    texts: list[str],
+    k: int = 5,
+    entity_span_sensitivity: float = 1.0,
+    nayose_coef: float = 1.0,
+    entities_are_replaced: bool = False,
 ) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
     batch_entity_spans: list[list[tuple[int, int]]] = []
     topk_normal_entities: list[list[str]] = []
     topk_category_entities: list[list[str]] = []
     topk_span_entities: list[list[list[str]]] = []
 
+    id2normal_entity = {
+        entity_id: entity
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id < model.config.num_normal_entities
+    }
+    id2category_entity = {
+        entity_id - model.config.num_normal_entities: entity
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id >= model.config.num_normal_entities
+    }
+    ignore_category_entity_ids = [
+        entity_id - model.config.num_normal_entities
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id >= model.config.num_normal_entities
+        and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
+    ]
+
     for text in texts:
         tokenized_examples = tokenizer(text, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
@@ -173,18 +185,23 @@ def get_topk_entities_from_texts(
         tokenized_examples = tokenizer(text, entity_spans=entity_spans or None, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
 
-        model_outputs.…
-        …
+        if model_outputs.topic_entity_logits is not None:
+            _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
+            topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
+        else:
+            topk_normal_entities.append([])
 
-        …
+        if model_outputs.topic_category_logits is not None:
+            model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
+            _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
+            topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
+        else:
+            topk_category_entities.append([])
 
         if model_outputs.entity_logits is not None:
             span_entity_logits = model_outputs.entity_logits[0, :, :500000]
 
-            if nayose_coef > 0.0:
+            if nayose_coef > 0.0 and not entities_are_replaced:
                 nayose_queries = ["ja:" + text[start:end] for start, end in entity_spans]
                 nayose_query_tokens = bm25_tokenizer.tokenize(nayose_queries)
                 nayose_scores = torch.vstack(
@@ -209,12 +226,32 @@ def get_selected_entity(evt: gr.SelectData):
 def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
     query_entity_id = tokenizer.entity_vocab[query_entity]
 
-    if query_entity_id < num_normal_entities:
+    id2normal_entity = {
+        entity_id: entity
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id < model.config.num_normal_entities
+    }
+    id2category_entity = {
+        entity_id - model.config.num_normal_entities: entity
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id >= model.config.num_normal_entities
+    }
+    ignore_category_entity_ids = [
+        entity_id - model.config.num_normal_entities
+        for entity, entity_id in tokenizer.entity_vocab.items()
+        if entity_id >= model.config.num_normal_entities
+        and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
+    ]
+    entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
+    normal_entity_embeddings = entity_embeddings[: model.config.num_normal_entities]
+    category_entity_embeddings = entity_embeddings[model.config.num_normal_entities :]
+
+    if query_entity_id < model.config.num_normal_entities:
         topk_entity_scores = normal_entity_embeddings[query_entity_id] @ normal_entity_embeddings.T
         topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
         topk_entities = [id2normal_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
     else:
-        query_entity_id -= num_normal_entities
+        query_entity_id -= model.config.num_normal_entities
         topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
 
         topk_entity_scores[ignore_category_entity_ids] = float("-inf")
@@ -225,31 +262,157 @@ def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
     return topk_entities
 
 
-…
 
-…
+def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]]:
+    new_entity_text_pairs = []
+    if file_path is not None:
+        try:
+            with open(file_path, newline="") as f:
+                reader = csv.DictReader(f, fieldnames=["entity", "text"])
+                for row in reader:
+                    entity = normalize_text(row["entity"]).strip()
+                    text = normalize_text(row["text"]).strip()
+                    if entity != "" and text != "":
+                        new_entity_text_pairs.append([entity, text])
+        except Exception as e:
+            gr.Warning("ファイルを正しく読み込めませんでした。")
+            print(e)
+            new_entity_text_pairs = []
+
+    return new_entity_text_pairs
+
+
+def replace_entities(
+    new_entity_text_pairs: list[tuple[str, str]],
+    new_num_category_entities: int = 0,
+    new_entity_counts: list[int] | None = None,
+    new_padding_idx: int = 0,
+) -> bool:
+    gr.Info("トークナイザのエンティティの語彙を置き換えています...", duration=5)
+    new_entity_tokens = ENTITY_SPECIAL_TOKENS + [entity for entity, _ in new_entity_text_pairs]
+
+    new_entity_vocab = {}
+    for entity in new_entity_tokens:
+        if entity not in new_entity_vocab:
+            new_entity_vocab[entity] = len(new_entity_vocab)
+
+    tokenizer.entity_vocab = new_entity_vocab
+    tokenizer.entity_pad_token_id = tokenizer.entity_vocab["[PAD]"]
+    tokenizer.entity_unk_token_id = tokenizer.entity_vocab["[UNK]"]
+    tokenizer.entity_mask_token_id = tokenizer.entity_vocab["[MASK]"]
+    tokenizer.entity_mask2_token_id = tokenizer.entity_vocab["[MASK2]"]
+
+    gr.Info("モデルのエンティティの埋め込みを置き換えています...", duration=5)
+    new_entity_embeddings_dict = defaultdict(list)
+
+    for entity_special_token in ENTITY_SPECIAL_TOKENS:
+        entity_special_token_id = tokenizer.entity_vocab[entity_special_token]
+        new_entity_embeddings_dict[entity_special_token_id].append(
+            model.luke.entity_embeddings.entity_embeddings.weight.data[entity_special_token_id]
+        )
 
+    for entity, text in new_entity_text_pairs:
+        entity_id = tokenizer.entity_vocab[entity]
+        tokenized_inputs = tokenizer(text, return_tensors="pt")
+        model_outputs = model(**tokenized_inputs)
+        entity_embeddings = model.entity_predictions.transform(model_outputs.last_hidden_state[:, 0])
+        new_entity_embeddings_dict[entity_id].append(entity_embeddings[0])
+
+    assert len(new_entity_embeddings_dict) == len(tokenizer.entity_vocab)
+
+    new_entity_embeddings = torch.vstack(
+        [
+            sum(new_entity_embeddings_dict[i]) / len(new_entity_embeddings_dict[i])
+            for i in range(len(new_entity_embeddings_dict))
+        ]
+    )
+    new_entity_vocab_size, new_entity_emb_size = new_entity_embeddings.size()
+    assert new_entity_vocab_size == len(tokenizer.entity_vocab)
+
+    new_num_normal_entities = new_entity_vocab_size - new_num_category_entities
+
+    if new_entity_counts is not None and any(count < 1 for count in new_entity_counts):
+        raise ValueError("All items in new_entity_counts must be greater than zero")
+
+    if model.config.normalize_entity_embeddings:
+        new_entity_embeddings = F.normalize(new_entity_embeddings)
+
+    new_entity_embeddings_module = nn.Embedding(
+        new_entity_vocab_size,
+        new_entity_emb_size,
+        padding_idx=new_padding_idx,
+        device=model.luke.entity_embeddings.entity_embeddings.weight.device,
+        dtype=model.luke.entity_embeddings.entity_embeddings.weight.dtype,
+    )
+    new_entity_embeddings_module.weight.data = new_entity_embeddings.data
+    model.luke.entity_embeddings.entity_embeddings = new_entity_embeddings_module
+
+    new_entity_decoder_module = nn.Linear(new_entity_emb_size, new_entity_vocab_size, bias=False)
+    model.entity_predictions.decoder = new_entity_decoder_module
+    model.entity_predictions.bias = nn.Parameter(torch.zeros(new_entity_vocab_size))
+    model.tie_weights()
 
-…
+    if hasattr(model, "entity_log_probs"):
+        del model.entity_log_probs
+
+    model.config.entity_vocab_size = new_entity_vocab_size
+    model.config.num_normal_entities = new_num_normal_entities
+    model.config.num_category_entities = new_num_category_entities
+    model.config.entity_counts = new_entity_counts
+
+    gr.Info("モデルとトークナイザのエンティティの置き換えが完了しました", duration=5)
+
+    return True
+
+
+with gr.Blocks() as demo:
     texts = gr.State([])
+
+    entities_are_replaced = gr.State(False)
+
     topk = gr.State(5)
     entity_span_sensitivity = gr.State(1.0)
     nayose_coef = gr.State(1.0)
+
     batch_entity_spans = gr.State([])
     topk_normal_entities = gr.State([])
     topk_category_entities = gr.State([])
     topk_span_entities = gr.State([])
+
     selected_entity = gr.State()
     similar_entities = gr.State([])
 
+    gr.Markdown("# 📝 LUXE Demo")
+
+    gr.Markdown("## 入力テキスト")
+
     with gr.Tab(label="直接入力"):
         text_input = gr.Textbox(label="入力テキスト")
     with gr.Tab(label="ファイルアップロード"):
         texts_file = gr.File(label="入力テキストファイル")
 
+    with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
+        new_entity_text_pairs_file = gr.File(label="エンティティと説明文のCSVファイル")
+        new_entity_text_pairs_input = gr.Dataframe(
+            # value=sample_new_entity_text_pairs,
+            headers=["entity", "text"],
+            col_count=(2, "fixed"),
+            type="array",
+            label="エンティティと説明文",
+            interactive=True,
+        )
+        replace_entity_button = gr.Button(value="エンティティ語彙を置き換える")
+
+    new_entity_text_pairs_file.change(
+        fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
+    )
+    replace_entity_button.click(fn=replace_entities, inputs=new_entity_text_pairs_input, outputs=entities_are_replaced)
+
     with gr.Accordion(label="ハイパーパラメータ", open=False):
         topk_input = gr.Number(5, label="エンティティ件数", interactive=True)
         entity_span_sensitivity_input = gr.Slider(
-            minimum=0.…
+            minimum=0.0, maximum=5.0, value=1.0, step=0.1, label="エンティティ検出の積極度", interactive=True
         )
         nayose_coef_input = gr.Slider(
             minimum=0.0, maximum=2.0, value=1.0, step=0.1, label="文字列一致の優先度", interactive=True
@@ -265,22 +428,22 @@ with gr.Blocks() as demo:
 
     texts.change(
        fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
+        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     topk.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
+        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     entity_span_sensitivity.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
+        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     nayose_coef.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
+        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     topk_input.change(inputs=topk_input, outputs=topk)
@@ -312,17 +475,23 @@ with gr.Blocks() as demo:
         )
 
         # gr.Textbox(text, label="Text")
-        …
+        if normal_entities:
+            gr.Dataset(
+                label="テキスト全体に関連するエンティティ",
+                components=["text"],
+                samples=[[entity] for entity in normal_entities],
+            ).select(fn=get_selected_entity, outputs=selected_entity)
+        if category_entities:
+            gr.Dataset(
+                label="テキスト全体に関連するカテゴリ",
+                components=["text"],
+                samples=[[entity] for entity in category_entities],
+            ).select(fn=get_selected_entity, outputs=selected_entity)
 
         span_texts = [text[start:end] for start, end in entity_spans]
         for span_text, entities in zip(span_texts, span_entities):
             gr.Dataset(
-                label=f"…
+                label=f"「{span_text}」に対応するエンティティ",
                 components=["text"],
                 samples=[[entity] for entity in entities],
            ).select(fn=get_selected_entity, outputs=selected_entity)
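
The new replacement path can also be exercised outside the Gradio UI. Below is a minimal usage sketch, assuming the globals defined in app.py above (model, tokenizer, and this commit's new functions) are already loaded; the entity/description pairs are made-up placeholders standing in for the kind of CSV that get_new_entity_text_pairs_from_file parses:

    # Hypothetical entity/description pairs (placeholders, not part of the commit).
    new_pairs = [
        ("東京タワー", "東京都港区にある総合電波塔。"),
        ("通天閣", "大阪市浪速区にある展望塔。"),
    ]

    # Rebuilds tokenizer.entity_vocab (special tokens first) and swaps the
    # model's entity embedding and decoder modules in place; each new entity
    # embedding is derived from the encoder output for its description text.
    replace_entities(new_pairs)

    # With a replaced vocabulary, the BM25 nayose rescoring (built against the
    # original entity vocabulary) no longer applies, which is why the demo
    # threads entities_are_replaced=True through to skip it.
    spans, normal, category, span_entities = get_topk_entities_from_texts(
        ["東京タワーと通天閣はどちらも有名な塔です。"], k=2, entities_are_replaced=True
    )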