Reset model and tokenizer when the demo is reloaded
app.py  CHANGED
@@ -33,9 +33,6 @@ ignore_category_patterns = [
     r"各年の",
 ]
 
-model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
-tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
-
 
 class MecabTokenizer:
     def __init__(self):
@@ -65,11 +62,6 @@ def normalize_text(text: str) -> str:
     return unicodedata.normalize("NFKC", text)
 
 
-bm25_tokenizer = TokenizerHF(lower=True, splitter=tokenizer.tokenize, stopwords=None, stemmer=None)
-bm25_tokenizer.load_vocab_from_hub("studio-ousia/luxe-nayose-bm25")
-bm25_retriever = BM25HF.load_from_hub("studio-ousia/luxe-nayose-bm25")
-
-
 def get_texts_from_file(file_path: str | None):
     texts = []
     if file_path is not None:
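These first two hunks delete all module-scope loading: the model and tokenizer, plus the BM25 tokenizer and retriever (whose splitter closes over that same global tokenizer), were built once at import time and shared by every session, so any state that replace_entities wrote into them survived a page reload. The commit rebuilds them inside the gr.Blocks context and stores them in gr.State (the @@ +367 hunk below); Gradio deep-copies a State's initial value for each user session, which is what makes a reload start fresh. A minimal sketch of the two lifetimes, with a hypothetical load() standing in for the real from_pretrained calls:

```python
import gradio as gr

def load():
    # stand-in for the expensive model/tokenizer/BM25 loading in app.py
    return {"entity_vocab": "initial"}

# Before: one object per process. Every session shares it, in-place
# mutations persist, and reloading the page does not reset it.
shared_resource = load()

with gr.Blocks() as demo:
    # After: a gr.State initial value is deep-copied per session, so each
    # page load starts from a freshly initialized value.
    resource = gr.State((load(),))
```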
@@ -104,7 +96,7 @@ def get_noun_spans_from_text(text: str) -> list[tuple[int, int]]:
     return noun_spans
 
 
-def get_token_spans(text: str) -> list[tuple[int, int]]:
+def get_token_spans(tokenizer, text: str) -> list[tuple[int, int]]:
     token_spans = []
     end = 0
     for token in tokenizer.tokenize(text):
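get_token_spans now receives the tokenizer explicitly instead of closing over the removed global. The body beyond the three lines shown lies outside this diff; the following is a self-contained sketch of what a character-span mapper of this shape typically does. The "##" subword-prefix handling and the WhitespaceTokenizer stand-in are assumptions for illustration, not code from app.py:

```python
def token_spans_sketch(tokenizer, text: str) -> list[tuple[int, int]]:
    spans = []
    end = 0
    for token in tokenizer.tokenize(text):
        surface = token.removeprefix("##")  # assumed WordPiece-style subwords
        start = text.find(surface, end)
        if start == -1:
            # token has no recoverable surface form (e.g. [UNK]); skip it
            continue
        end = start + len(surface)
        spans.append((start, end))
    return spans

class WhitespaceTokenizer:
    def tokenize(self, text):
        return text.split()

print(token_spans_sketch(WhitespaceTokenizer(), "entity linking demo"))
# [(0, 6), (7, 14), (15, 19)]
```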
@@ -147,12 +139,15 @@ def get_predicted_entity_spans(
 
 
 def get_topk_entities_from_texts(
+    models,
     texts: list[str],
     k: int = 5,
     entity_span_sensitivity: float = 1.0,
     nayose_coef: float = 1.0,
     entities_are_replaced: bool = False,
 ) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
+    model, tokenizer, bm25_tokenizer, bm25_retriever = models
+
     batch_entity_spans: list[list[tuple[int, int]]] = []
     topk_normal_entities: list[list[str]] = []
     topk_category_entities: list[list[str]] = []
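Every handler that needs the models now takes a leading models argument and unpacks the 4-tuple positionally. Since the unpacking is positional, a NamedTuple would document the slots without changing any call site; this is a hypothetical refactor, not part of the commit:

```python
from typing import Any, NamedTuple

class Models(NamedTuple):
    model: Any
    tokenizer: Any
    bm25_tokenizer: Any
    bm25_retriever: Any

# NamedTuple instances still unpack positionally, so the handlers' existing
# `model, tokenizer, bm25_tokenizer, bm25_retriever = models` keeps working:
models = Models("m", "t", "bt", "br")
model, tokenizer, bm25_tokenizer, bm25_retriever = models
assert model == "m" and bm25_retriever == "br"
```

A NamedTuple instance is also non-callable, so it would remain safe to pass to gr.State (see the note in the @@ +367 hunk).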
@@ -178,7 +173,7 @@ def get_topk_entities_from_texts(
     for text in texts:
         tokenized_examples = tokenizer(text, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
-        token_spans = get_token_spans(text)
+        token_spans = get_token_spans(tokenizer, text)
         entity_spans = get_predicted_entity_spans(model_outputs.ner_logits[0], token_spans, entity_span_sensitivity)
         batch_entity_spans.append(entity_spans)
 
@@ -223,7 +218,9 @@ def get_selected_entity(evt: gr.SelectData):
     return evt.value[0]
 
 
-def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
+def get_similar_entities(models, query_entity: str, k: int = 10) -> list[str]:
+    model, tokenizer, _, _ = models
+
     query_entity_id = tokenizer.entity_vocab[query_entity]
 
     id2normal_entity = {
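get_similar_entities needs only the model and tokenizer, hence the `_` placeholders for the BM25 pair. The body past the entity_vocab lookup is outside this diff; presumably it ranks the entity vocabulary by embedding similarity to the query entity. A generic top-k similarity sketch in that spirit, where the function name, shapes, and the choice of cosine similarity are all assumptions:

```python
import torch

def topk_similar_ids(query_vec: torch.Tensor, entity_embeddings: torch.Tensor, k: int = 10) -> list[int]:
    # cosine similarity between one query vector and every entity embedding
    sims = torch.nn.functional.cosine_similarity(query_vec.unsqueeze(0), entity_embeddings, dim=-1)
    return sims.topk(k).indices.tolist()

# usage sketch
embeddings = torch.randn(1000, 256)   # |entity_vocab| x hidden size
query_id = 42
print(topk_similar_ids(embeddings[query_id], embeddings, k=10))
```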
@@ -282,11 +279,14 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]
 
 
 def replace_entities(
+    models,
     new_entity_text_pairs: list[tuple[str, str]],
     new_num_category_entities: int = 0,
     new_entity_counts: list[int] | None = None,
     new_padding_idx: int = 0,
 ) -> True:
+    model, tokenizer, bm25_tokenizer, bm25_retriever = models
+
     gr.Info("トークナイザのエンティティの語彙を置き換えています...", duration=5)
     new_entity_tokens = ENTITY_SPECIAL_TOKENS + [entity for entity, _ in new_entity_text_pairs]
 
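One pre-existing wart this hunk leaves untouched is the return annotation `-> True`. It is legal at runtime, since annotations are arbitrary expressions, but it is presumably meant to be `-> bool` (the function feeds the boolean entities_are_replaced state), and a type checker such as mypy will flag it:

```python
# What the signature presumably intends; `-> True` parses but is not a type.
def replace_entities_ok() -> bool:
    return True

# Runtime accepts this too, which is why the mistake is easy to miss:
def replace_entities_odd() -> True:
    return True

print(replace_entities_odd.__annotations__)  # {'return': True}
```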
@@ -367,6 +367,18 @@ def replace_entities(
 
 
 with gr.Blocks() as demo:
+    model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
+    bm25_tokenizer = TokenizerHF(lower=True, splitter=tokenizer.tokenize, stopwords=None, stemmer=None)
+    bm25_tokenizer.load_vocab_from_hub("studio-ousia/luxe-nayose-bm25")
+    bm25_retriever = BM25HF.load_from_hub("studio-ousia/luxe-nayose-bm25")
+
+    # Hint: a callable passed to gr.State is treated as a function that sets the
+    # state's initial value and gets invoked with no arguments, so passing the model
+    # or tokenizer to gr.State on its own raises an error. Passing the whole set of
+    # models as a (non-callable) tuple avoids that error.
+    # cf. https://www.gradio.app/docs/gradio/state#param-state-value
+    models = gr.State((model, tokenizer, bm25_tokenizer, bm25_retriever))
+
     texts = gr.State([])
 
     entities_are_replaced = gr.State(False)
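This relocated block is the heart of the commit: loading now happens while the Blocks context is built, and the tuple goes into gr.State, whose initial value Gradio deep-copies for each user session. Reloading the page therefore resets the model and tokenizer to their freshly loaded contents, exactly what the commit title promises. The in-code hint explains why the tuple wrapper is needed; here is a runnable demonstration of the pitfall, using a stand-in for the callable model:

```python
import gradio as gr

class FakeModel:
    # HF models and tokenizers are callable, like this stand-in
    def __call__(self, inputs):
        return inputs

model = FakeModel()

with gr.Blocks() as demo:
    # gr.State treats a callable initial value as a factory and invokes it
    # with no arguments, so this line would raise a TypeError:
    #     broken = gr.State(model)
    # A tuple is not callable, so it is stored as-is:
    models = gr.State((model,))
```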
@@ -407,7 +419,9 @@ with gr.Blocks() as demo:
     new_entity_text_pairs_file.change(
         fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
     )
-    replace_entity_button.click(
+    replace_entity_button.click(
+        fn=replace_entities, inputs=[models, new_entity_text_pairs_input], outputs=entities_are_replaced
+    )
 
     with gr.Accordion(label="ハイパーパラメータ", open=False):
         topk_input = gr.Number(5, label="エンティティ件数", interactive=True)
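The rewired click handler shows the full round trip: replace_entities receives the per-session models, mutates them in place, and its boolean return value lands in the entities_are_replaced state, whose own change event (added in the next hunk) re-runs entity prediction. A minimal sketch of that button → state → change chain:

```python
import gradio as gr

with gr.Blocks() as demo:
    flag = gr.State(False)
    out = gr.Textbox(label="recomputed")
    btn = gr.Button("replace")

    def do_replace():
        return True  # becomes the new value of `flag`

    def recompute(flag_value):
        return f"recomputed (flag={flag_value})"

    btn.click(fn=do_replace, outputs=flag)
    flag.change(fn=recompute, inputs=flag, outputs=out)

demo.launch()
```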
@@ -428,25 +442,30 @@ with gr.Blocks() as demo:
 
     texts.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
+        inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     topk.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
+        inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     entity_span_sensitivity.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
+        inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
     nayose_coef.change(
         fn=get_topk_entities_from_texts,
-        inputs=[texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
+        inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
         outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
-    topk_input.change(inputs=topk_input, outputs=topk)
+
+    entities_are_replaced.change(
+        fn=get_topk_entities_from_texts,
+        inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
+        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
 
     gr.Markdown("---")
     gr.Markdown("## 出力エンティティ")
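Here the four existing .change handlers gain models as their first input, a fifth registration on entities_are_replaced is added so a vocabulary replacement triggers recomputation, and the bare topk_input.change(inputs=topk_input, outputs=topk) registration is dropped. The five registrations are now identical, so a hypothetical follow-up (not part of this commit, and assuming the same gr.Blocks scope as above) could register them in a loop with unchanged behavior:

```python
# Hypothetical refactor: one registration loop for the five components that
# should all trigger get_topk_entities_from_texts.
recompute_kwargs = dict(
    fn=get_topk_entities_from_texts,
    inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced],
    outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
)
for component in (texts, topk, entity_span_sensitivity, nayose_coef, entities_are_replaced):
    component.change(**recompute_kwargs)
```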
@@ -499,7 +518,7 @@ with gr.Blocks() as demo:
     # gr.Markdown("---")
     # gr.Markdown("## 選択されたエンティティの類似エンティティ")
 
-    # selected_entity.change(fn=get_similar_entities, inputs=selected_entity, outputs=similar_entities)
+    # selected_entity.change(fn=get_similar_entities, inputs=[models, selected_entity], outputs=similar_entities)
 
     # @gr.render(inputs=[selected_entity, similar_entities])
     # def render_similar_entities(selected_entity, similar_entities):
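Even the commented-out similar-entities wiring is kept in sync with the new signature (inputs=[models, selected_entity]), so re-enabling it later will not reintroduce the global. For reference, the @gr.render pattern the commented block uses rebuilds components whenever its input states change; a minimal self-contained sketch with illustrative names:

```python
import gradio as gr

with gr.Blocks() as demo:
    selected_entity = gr.State("東京")
    similar_entities = gr.State(["京都", "大阪", "横浜"])

    # Components created inside a @gr.render function are re-rendered
    # whenever any of the listed inputs change.
    @gr.render(inputs=[selected_entity, similar_entities])
    def render_similar_entities(selected, similar):
        gr.Markdown(f"### Entities similar to {selected}")
        for entity in similar:
            gr.Button(entity)

demo.launch()
```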
|
|