Spaces: Running on CPU Upgrade
Support NAYOSE for span entities using BM25 of entity name sub-tokens
Browse files
app.py
CHANGED
@@ -5,13 +5,18 @@ from pathlib import Path
|
|
5 |
import gradio as gr
|
6 |
import torch
|
7 |
import unidic_lite
|
|
|
8 |
from fugashi import GenericTagger
|
9 |
from transformers import AutoModelForPreTraining, AutoTokenizer
|
10 |
|
11 |
|
|
|
|
|
12 |
repo_id = "studio-ousia/luxe"
|
13 |
revision = "ja-v0.3"
|
14 |
|
|
|
|
|
15 |
ignore_category_patterns = [
|
16 |
r"\d+年",
|
17 |
r"楽曲 [ぁ-ん]",
|
@@ -77,6 +82,11 @@ def normalize_text(text: str) -> str:
|
|
77 |
return unicodedata.normalize("NFKC", text)
|
78 |
|
79 |
|
|
|
|
|
|
|
|
|
|
|
80 |
def get_texts_from_file(file_path):
|
81 |
texts = []
|
82 |
with open(file_path) as f:
|
@@ -146,7 +156,7 @@ def get_predicted_entity_spans(
|
|
146 |
|
147 |
|
148 |
def get_topk_entities_from_texts(
|
149 |
-
texts: list[str], k: int = 5, entity_span_sensitivity: float = 1.0
|
150 |
) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
|
151 |
batch_entity_spans: list[list[tuple[int, int]]] = []
|
152 |
topk_normal_entities: list[list[str]] = []
|
@@ -172,7 +182,17 @@ def get_topk_entities_from_texts(
|
|
172 |
topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
|
173 |
|
174 |
if model_outputs.entity_logits is not None:
|
175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
topk_span_entities.append(
|
177 |
[[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
|
178 |
)
|
@@ -211,6 +231,7 @@ with gr.Blocks() as demo:
|
|
211 |
texts = gr.State([])
|
212 |
topk = gr.State(5)
|
213 |
entity_span_sensitivity = gr.State(1.0)
|
|
|
214 |
batch_entity_spans = gr.State([])
|
215 |
topk_normal_entities = gr.State([])
|
216 |
topk_category_entities = gr.State([])
|
@@ -230,20 +251,29 @@ with gr.Blocks() as demo:
|
|
230 |
entity_span_sensitivity_input.change(
|
231 |
fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
|
232 |
)
|
|
|
|
|
|
|
|
|
233 |
|
234 |
texts.change(
|
235 |
fn=get_topk_entities_from_texts,
|
236 |
-
inputs=[texts, topk, entity_span_sensitivity],
|
237 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
238 |
)
|
239 |
topk.change(
|
240 |
fn=get_topk_entities_from_texts,
|
241 |
-
inputs=[texts, topk, entity_span_sensitivity],
|
242 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
243 |
)
|
244 |
entity_span_sensitivity.change(
|
245 |
fn=get_topk_entities_from_texts,
|
246 |
-
inputs=[texts, topk, entity_span_sensitivity],
|
|
|
|
|
|
|
|
|
|
|
247 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
248 |
)
|
249 |
|
|
|
5 |
import gradio as gr
|
6 |
import torch
|
7 |
import unidic_lite
|
8 |
+
from bm25s.hf import BM25HF, TokenizerHF
|
9 |
from fugashi import GenericTagger
|
10 |
from transformers import AutoModelForPreTraining, AutoTokenizer
|
11 |
|
12 |
|
13 |
+
ALIAS_SEP = "|"
|
14 |
+
|
15 |
repo_id = "studio-ousia/luxe"
|
16 |
revision = "ja-v0.3"
|
17 |
|
18 |
+
nayose_repo_id = "studio-ousia/luxe-nayose-bm25"
|
19 |
+
|
20 |
ignore_category_patterns = [
|
21 |
r"\d+年",
|
22 |
r"楽曲 [ぁ-ん]",
|
|
|
82 |
return unicodedata.normalize("NFKC", text)
|
83 |
|
84 |
|
85 |
+
bm25_tokenizer = TokenizerHF(lower=True, splitter=tokenizer.tokenize, stopwords=None, stemmer=None)
|
86 |
+
bm25_tokenizer.load_vocab_from_hub("studio-ousia/luxe-nayose-bm25")
|
87 |
+
bm25_retriever = BM25HF.load_from_hub("studio-ousia/luxe-nayose-bm25")
|
88 |
+
|
89 |
+
|
90 |
def get_texts_from_file(file_path):
|
91 |
texts = []
|
92 |
with open(file_path) as f:
|
|
|
156 |
|
157 |
|
158 |
def get_topk_entities_from_texts(
|
159 |
+
texts: list[str], k: int = 5, entity_span_sensitivity: float = 1.0, nayose_coef: float = 0.0
|
160 |
) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
|
161 |
batch_entity_spans: list[list[tuple[int, int]]] = []
|
162 |
topk_normal_entities: list[list[str]] = []
|
|
|
182 |
topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
|
183 |
|
184 |
if model_outputs.entity_logits is not None:
|
185 |
+
span_entity_logits = model_outputs.entity_logits[0, :, :500000]
|
186 |
+
|
187 |
+
if nayose_coef > 0.0:
|
188 |
+
nayose_queries = ["ja:" + text[start:end] for start, end in entity_spans]
|
189 |
+
nayose_query_tokens = bm25_tokenizer.tokenize(nayose_queries)
|
190 |
+
nayose_scores = torch.vstack(
|
191 |
+
[torch.from_numpy(bm25_retriever.get_scores(tokens)) for tokens in nayose_query_tokens]
|
192 |
+
)
|
193 |
+
span_entity_logits += nayose_coef * nayose_scores
|
194 |
+
|
195 |
+
_, topk_span_entity_ids = span_entity_logits.topk(k)
|
196 |
topk_span_entities.append(
|
197 |
[[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
|
198 |
)
|
|
|
231 |
texts = gr.State([])
|
232 |
topk = gr.State(5)
|
233 |
entity_span_sensitivity = gr.State(1.0)
|
234 |
+
nayose_coef = gr.State(0.0)
|
235 |
batch_entity_spans = gr.State([])
|
236 |
topk_normal_entities = gr.State([])
|
237 |
topk_category_entities = gr.State([])
|
|
|
251 |
entity_span_sensitivity_input.change(
|
252 |
fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
|
253 |
)
|
254 |
+
nayose_coef_input = gr.Slider(
|
255 |
+
minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="Nayose Coefficient", interactive=True
|
256 |
+
)
|
257 |
+
nayose_coef_input.change(fn=lambda val: val, inputs=nayose_coef_input, outputs=nayose_coef)
|
258 |
|
259 |
texts.change(
|
260 |
fn=get_topk_entities_from_texts,
|
261 |
+
inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
|
262 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
263 |
)
|
264 |
topk.change(
|
265 |
fn=get_topk_entities_from_texts,
|
266 |
+
inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
|
267 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
268 |
)
|
269 |
entity_span_sensitivity.change(
|
270 |
fn=get_topk_entities_from_texts,
|
271 |
+
inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
|
272 |
+
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
273 |
+
)
|
274 |
+
nayose_coef.change(
|
275 |
+
fn=get_topk_entities_from_texts,
|
276 |
+
inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
|
277 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
278 |
)
|
279 |
|