Use ja-v0.3 model, introduce entity span sensitivity
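In short, this commit pins the model revision to "ja-v0.3", derives entity spans from the model's ner_logits instead of the earlier noun-based spans, and adds an "Entity Span Sensitivity" slider. In the new get_predicted_entity_spans (see the diff below), a sensitivity value s is turned into a probability cutoff of 10**(-s), so a higher setting keeps lower-confidence spans. The following is a minimal standalone sketch of that cutoff only, using a made-up logits matrix rather than real LUXE output:

import torch

# Made-up 4x4 matrix standing in for model_outputs.ner_logits[0];
# rows are start tokens, columns are end tokens.
ner_logits = torch.tensor([
    [ 2.0, -3.0, -5.0, -6.0],
    [-9.0,  1.0, -4.0, -7.0],
    [-9.0, -9.0, -2.0, -5.0],
    [-9.0, -9.0, -9.0,  0.5],
])

ner_probs = torch.sigmoid(ner_logits).triu()  # keep only start <= end candidates
for sensitivity in (0.5, 1.0, 2.0):
    threshold = 10.0 ** (-sensitivity)  # the same cutoff used in get_predicted_entity_spans
    kept = (ner_probs >= threshold).sum().item()
    print(f"sensitivity={sensitivity}: threshold={threshold:.3f}, candidate spans kept={kept}")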
app.py CHANGED
@@ -2,13 +2,14 @@ import re
 from pathlib import Path
 
 import gradio as gr
+import torch
 import unidic_lite
 from fugashi import GenericTagger
 from transformers import AutoModelForPreTraining, AutoTokenizer
 
 
 repo_id = "studio-ousia/luxe"
-revision = "ja-v0.
+revision = "ja-v0.3"
 
 ignore_category_patterns = [
     r"\d+年",
@@ -98,17 +99,63 @@ def get_noun_spans_from_text(text: str) -> list[tuple[int, int]]:
     return noun_spans
 
 
+def get_token_spans(text: str) -> list[tuple[int, int]]:
+    token_spans = []
+    end = 0
+    for token in tokenizer.tokenize(text):
+        token = token.removeprefix("##")
+        start = text.index(token, end)
+        end = start + len(token)
+        token_spans.append((start, end))
+
+    return [(0, 0)] + token_spans + [(end, end)]  # count for "[CLS]" and "[SEP]"
+
+
+def get_predicted_entity_spans(
+    ner_logits: torch.Tensor, token_spans: list[tuple[int, int]], entity_span_sensitivity: float = 1.0
+) -> list[tuple[int, int]]:
+    length = ner_logits.size(-1)
+    assert ner_logits.size() == (length, length)  # not batched
+
+    ner_probs = torch.sigmoid(ner_logits).triu()
+    probs_sorted, sort_idxs = ner_probs.flatten().sort(descending=True)
+
+    predicted_entity_spans = []
+    for p, i in zip(probs_sorted, sort_idxs.tolist()):
+        if p < 10.0 ** (-1.0 * entity_span_sensitivity):
+            break
+
+        start_idx = i // length
+        end_idx = i % length
+
+        start = token_spans[start_idx][0]
+        end = token_spans[end_idx][1]
+
+        for ex_start, ex_end in predicted_entity_spans:
+            if not (start < end <= ex_start or ex_end <= start < end):
+                break
+        else:
+            predicted_entity_spans.append((start, end))
+
+    return sorted(predicted_entity_spans)
+
+
 def get_topk_entities_from_texts(
-    texts: list[str], k: int = 5
-) -> tuple[list[list[str]], list[list[str]], list[list[list[str]]]]:
+    texts: list[str], k: int = 5, entity_span_sensitivity: float = 1.0
+) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
+    batch_entity_spans: list[list[tuple[int, int]]] = []
     topk_normal_entities: list[list[str]] = []
    topk_category_entities: list[list[str]] = []
     topk_span_entities: list[list[list[str]]] = []
 
     for text in texts:
-
+        tokenized_examples = tokenizer(text, return_tensors="pt")
+        model_outputs = model(**tokenized_examples)
+        token_spans = get_token_spans(text)
+        entity_spans = get_predicted_entity_spans(model_outputs.ner_logits[0], token_spans, entity_span_sensitivity)
+        batch_entity_spans.append(entity_spans)
 
-        tokenized_examples = tokenizer(text, entity_spans=
+        tokenized_examples = tokenizer(text, entity_spans=entity_spans or None, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
 
         model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
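A note on the hunk above: get_predicted_entity_spans visits candidate spans from highest to lowest probability and uses a for/else block to keep a candidate only if it overlaps none of the spans accepted so far. A self-contained sketch of just that overlap filter, with hand-picked spans instead of model scores:

# Standalone illustration of the for/else overlap check used above: a candidate span
# is kept only if it does not overlap any previously accepted span.
def keep_non_overlapping(candidates: list[tuple[int, int]]) -> list[tuple[int, int]]:
    accepted: list[tuple[int, int]] = []
    for start, end in candidates:  # assumed sorted best-first, like the probability-sorted spans
        for ex_start, ex_end in accepted:
            if not (start < end <= ex_start or ex_end <= start < end):
                break  # overlaps an accepted span, so it is discarded
        else:
            accepted.append((start, end))
    return sorted(accepted)

# (0, 5) is accepted first, so the overlapping (3, 8) is dropped and (5, 8) is kept.
print(keep_non_overlapping([(0, 5), (3, 8), (5, 8)]))  # [(0, 5), (5, 8)]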
@@ -119,10 +166,15 @@ def get_topk_entities_from_texts(
         _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
         topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
 
-
-
+        if model_outputs.entity_logits is not None:
+            _, topk_span_entity_ids = model_outputs.entity_logits[0, :, :500000].topk(k)
+            topk_span_entities.append(
+                [[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
+            )
+        else:
+            topk_span_entities.append([])
 
-    return topk_normal_entities, topk_category_entities, topk_span_entities
+    return batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
 
 
 def get_selected_entity(evt: gr.SelectData):
@@ -152,6 +204,9 @@ with gr.Blocks() as demo:
     gr.Markdown("## テキスト(直接入力またはファイルアップロード)")
 
     texts = gr.State([])
+    topk = gr.State(5)
+    entity_span_sensitivity = gr.State(1.0)
+    batch_entity_spans = gr.State([])
     topk_normal_entities = gr.State([])
     topk_category_entities = gr.State([])
     topk_span_entities = gr.State([])
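The hunk above only declares the new gr.State holders (topk, entity_span_sensitivity, batch_entity_spans); the next hunk wires visible controls and .change handlers to them. A standalone sketch of that State + .change pattern, with a placeholder recompute function standing in for get_topk_entities_from_texts:

import gradio as gr

# Minimal sketch of the pattern, not app.py itself: a visible control copies its value
# into a gr.State, and a change to either the texts or that State re-runs the same function.
def recompute(texts: list[str], k: float) -> str:
    return "\n".join(f"top-{int(k)} entities for: {t!r}" for t in texts)

with gr.Blocks() as demo:
    texts = gr.State([])
    topk = gr.State(5)

    text_input = gr.Textbox(label="Input Text")
    topk_input = gr.Number(5, label="Top K", interactive=True)
    output = gr.Textbox(label="Output")

    text_input.change(fn=lambda text: [text], inputs=text_input, outputs=texts)
    topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)

    texts.change(fn=recompute, inputs=[texts, topk], outputs=output)
    topk.change(fn=recompute, inputs=[texts, topk], outputs=output)

if __name__ == "__main__":
    demo.launch()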
@@ -159,31 +214,60 @@
     similar_entities = gr.State([])
 
     text_input = gr.Textbox(label="Input Text")
-    texts_file = gr.File(label="Input Texts")
-
     text_input.change(fn=lambda text: [text], inputs=text_input, outputs=texts)
+    texts_file = gr.File(label="Input Texts")
     texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
+    topk_input = gr.Number(5, label="Top K", interactive=True)
+    topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)
+    entity_span_sensitivity_input = gr.Slider(
+        minimum=0.1, maximum=5.0, value=1.0, step=0.1, label="Entity Span Sensitivity", interactive=True
+    )
+    entity_span_sensitivity_input.change(
+        fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
+    )
+
     texts.change(
         fn=get_topk_entities_from_texts,
-        inputs=texts,
-        outputs=[topk_normal_entities, topk_category_entities, topk_span_entities],
+        inputs=[texts, topk, entity_span_sensitivity],
+        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
+    topk.change(
+        fn=get_topk_entities_from_texts,
+        inputs=[texts, topk, entity_span_sensitivity],
+        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
+    entity_span_sensitivity.change(
+        fn=get_topk_entities_from_texts,
+        inputs=[texts, topk, entity_span_sensitivity],
+        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
+
+    topk_input.change(inputs=topk_input, outputs=topk)
 
     gr.Markdown("---")
     gr.Markdown("## 出力エンティティ")
 
-    @gr.render(inputs=[texts, topk_normal_entities, topk_category_entities, topk_span_entities])
-    def render_topk_entities(
-
-
+    @gr.render(inputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities])
+    def render_topk_entities(
+        texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
+    ):
+        for text, entity_spans, normal_entities, category_entities, span_entities in zip(
+            texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
         ):
+            highlighted_text_value = []
+            cur = 0
+            for start, end in entity_spans:
+                if cur < start:
+                    highlighted_text_value.append((text[cur:start], None))
+
+                highlighted_text_value.append((text[start:end], "Entity"))
+                cur = end
+
+            if cur < len(text):
+                highlighted_text_value.append((text[cur:], None))
+
             gr.HighlightedText(
-                value=
-                color_map={"名詞": "green"},
-                show_legend=True,
-                combine_adjacent=True,
-                adjacent_separator=" ",
-                label="Text",
+                value=highlighted_text_value, color_map={"Entity": "green"}, combine_adjacent=False, label="Text"
             )
 
             # gr.Textbox(text, label="Text")
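The render function added above turns each list of predicted character spans into the (segment, label) pairs that gr.HighlightedText expects, labelling entity spans "Entity" and leaving the text in between unlabelled. A standalone sketch of that conversion with a made-up text and spans:

# Made-up example text and spans; the app gets these from get_predicted_entity_spans.
text = "LUXE is developed by Studio Ousia."
entity_spans = [(0, 4), (21, 33)]

highlighted_text_value = []
cur = 0
for start, end in entity_spans:
    if cur < start:
        highlighted_text_value.append((text[cur:start], None))  # plain text between entities
    highlighted_text_value.append((text[start:end], "Entity"))  # highlighted entity span
    cur = end
if cur < len(text):
    highlighted_text_value.append((text[cur:], None))  # trailing plain text

print(highlighted_text_value)
# [('LUXE', 'Entity'), (' is developed by ', None), ('Studio Ousia', 'Entity'), ('.', None)]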
@@ -194,11 +278,12 @@ with gr.Blocks() as demo:
                 label="Topic Categories", components=["text"], samples=[[entity] for entity in category_entities]
             ).select(fn=get_selected_entity, outputs=selected_entity)
 
-
-
-            for noun, entities in zip(nouns, span_entities):
+            span_texts = [text[start:end] for start, end in entity_spans]
+            for span_text, entities in zip(span_texts, span_entities):
                 gr.Dataset(
-                    label=f"Span Entities for {
+                    label=f"Span Entities for {span_text}",
+                    components=["text"],
+                    samples=[[entity] for entity in entities],
                 ).select(fn=get_selected_entity, outputs=selected_entity)
 
             gr.Markdown("---")