singletongue committed on
Commit aaaa32a · verified · 1 Parent(s): c5df237

Use ja-v0.3 model, introduce entity span sensitivity

Files changed (1): app.py +111 -26
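
Editor's note on the commit: it bumps the pinned model revision from ja-v0.2 to ja-v0.3 and replaces the old MeCab noun-chunk spans with model-predicted entity spans, gated by a new "entity span sensitivity" value s that is mapped to a probability cutoff of 10^(-s). A minimal sketch of that mapping (illustrative values only; see get_predicted_entity_spans in the diff below):

for s in (0.1, 1.0, 2.0, 5.0):
    # A span survives if sigmoid(ner_logit) >= 10 ** (-s), so higher
    # sensitivity admits lower-probability spans.
    print(f"sensitivity={s}: cutoff={10.0 ** (-1.0 * s):.5f}")
# sensitivity=1.0 keeps spans with probability >= 0.1;
# sensitivity=5.0 keeps nearly everything (cutoff 0.00001).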
app.py CHANGED
@@ -2,13 +2,14 @@ import re
 from pathlib import Path
 
 import gradio as gr
+import torch
 import unidic_lite
 from fugashi import GenericTagger
 from transformers import AutoModelForPreTraining, AutoTokenizer
 
 
 repo_id = "studio-ousia/luxe"
-revision = "ja-v0.2"
+revision = "ja-v0.3"
 
 ignore_category_patterns = [
     r"\d+年",
@@ -98,17 +99,63 @@ def get_noun_spans_from_text(text: str) -> list[tuple[int, int]]:
     return noun_spans
 
 
+def get_token_spans(text: str) -> list[tuple[int, int]]:
+    token_spans = []
+    end = 0
+    for token in tokenizer.tokenize(text):
+        token = token.removeprefix("##")
+        start = text.index(token, end)
+        end = start + len(token)
+        token_spans.append((start, end))
+
+    return [(0, 0)] + token_spans + [(end, end)]  # count for "[CLS]" and "[SEP]"
+
+
+def get_predicted_entity_spans(
+    ner_logits: torch.Tensor, token_spans: list[tuple[int, int]], entity_span_sensitivity: float = 1.0
+) -> list[tuple[int, int]]:
+    length = ner_logits.size(-1)
+    assert ner_logits.size() == (length, length)  # not batched
+
+    ner_probs = torch.sigmoid(ner_logits).triu()
+    probs_sorted, sort_idxs = ner_probs.flatten().sort(descending=True)
+
+    predicted_entity_spans = []
+    for p, i in zip(probs_sorted, sort_idxs.tolist()):
+        if p < 10.0 ** (-1.0 * entity_span_sensitivity):
+            break
+
+        start_idx = i // length
+        end_idx = i % length
+
+        start = token_spans[start_idx][0]
+        end = token_spans[end_idx][1]
+
+        for ex_start, ex_end in predicted_entity_spans:
+            if not (start < end <= ex_start or ex_end <= start < end):
+                break
+        else:
+            predicted_entity_spans.append((start, end))
+
+    return sorted(predicted_entity_spans)
+
+
 def get_topk_entities_from_texts(
-    texts: list[str], k: int = 5
-) -> tuple[list[list[str]], list[list[str]], list[list[list[str]]]]:
+    texts: list[str], k: int = 5, entity_span_sensitivity: float = 1.0
+) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
+    batch_entity_spans: list[list[tuple[int, int]]] = []
     topk_normal_entities: list[list[str]] = []
     topk_category_entities: list[list[str]] = []
     topk_span_entities: list[list[list[str]]] = []
 
     for text in texts:
-        noun_spans = get_noun_spans_from_text(text)
+        tokenized_examples = tokenizer(text, return_tensors="pt")
+        model_outputs = model(**tokenized_examples)
+        token_spans = get_token_spans(text)
+        entity_spans = get_predicted_entity_spans(model_outputs.ner_logits[0], token_spans, entity_span_sensitivity)
+        batch_entity_spans.append(entity_spans)
 
-        tokenized_examples = tokenizer(text, entity_spans=noun_spans, return_tensors="pt")
+        tokenized_examples = tokenizer(text, entity_spans=entity_spans or None, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
 
         model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
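
A toy walk-through of the new span selection (hand-picked logits, not model output): cells of the upper-triangular sigmoid matrix are visited in descending probability order, kept while they clear the 10^(-s) cutoff, and dropped if they overlap an already-kept span.

import torch

# Hand-picked 3x3 "ner_logits" for a 3-token text; cell (i, j) scores the
# span from token i to token j (upper triangle only).
ner_logits = torch.tensor([
    [2.0, -3.0, -5.0],
    [0.0,  1.0, -4.0],
    [0.0,  0.0,  0.5],
])
ner_probs = torch.sigmoid(ner_logits).triu()
# Diagonal probabilities are about 0.88, 0.73, 0.62; the off-diagonal cells
# fall below the sensitivity=1.0 cutoff of 0.1, so only the three disjoint
# single-token spans (0,0), (1,1), (2,2) survive the greedy overlap filter.
print(ner_probs)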
@@ -119,10 +166,15 @@ def get_topk_entities_from_texts(
         _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
         topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
 
-        _, topk_span_entity_ids = model_outputs.entity_logits[0, :, :500000].topk(k)
-        topk_span_entities.append([[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()])
+        if model_outputs.entity_logits is not None:
+            _, topk_span_entity_ids = model_outputs.entity_logits[0, :, :500000].topk(k)
+            topk_span_entities.append(
+                [[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
+            )
+        else:
+            topk_span_entities.append([])
 
-    return topk_normal_entities, topk_category_entities, topk_span_entities
+    return batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
 
 
 def get_selected_entity(evt: gr.SelectData):
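
The new entity_logits guard suggests the model returns no entity logits when no spans are passed (note the tokenizer is now called with entity_spans=entity_spans or None, so an empty prediction becomes None). A hedged usage sketch of the new four-tuple return (assumes app.py's module-level tokenizer and model are loaded; the example text and offsets are illustrative):

# Sketch: the function now also returns the predicted character spans per text.
texts = ["夏目漱石は『吾輩は猫である』を書いた。"]
batch_entity_spans, normal_entities, category_entities, span_entities = (
    get_topk_entities_from_texts(texts, k=5, entity_span_sensitivity=1.0)
)
print(batch_entity_spans[0])  # e.g. [(0, 4), (6, 14)] -- character offsets into texts[0]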
@@ -152,6 +204,9 @@ with gr.Blocks() as demo:
     gr.Markdown("## テキスト(直接入力またはファイルアップロード)")
 
     texts = gr.State([])
+    topk = gr.State(5)
+    entity_span_sensitivity = gr.State(1.0)
+    batch_entity_spans = gr.State([])
     topk_normal_entities = gr.State([])
     topk_category_entities = gr.State([])
     topk_span_entities = gr.State([])
@@ -159,31 +214,60 @@ with gr.Blocks() as demo:
     similar_entities = gr.State([])
 
     text_input = gr.Textbox(label="Input Text")
-    texts_file = gr.File(label="Input Texts")
-
     text_input.change(fn=lambda text: [text], inputs=text_input, outputs=texts)
+    texts_file = gr.File(label="Input Texts")
     texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
+    topk_input = gr.Number(5, label="Top K", interactive=True)
+    topk_input.change(fn=lambda val: val, inputs=topk_input, outputs=topk)
+    entity_span_sensitivity_input = gr.Slider(
+        minimum=0.1, maximum=5.0, value=1.0, step=0.1, label="Entity Span Sensitivity", interactive=True
+    )
+    entity_span_sensitivity_input.change(
+        fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
+    )
+
     texts.change(
         fn=get_topk_entities_from_texts,
-        inputs=texts,
-        outputs=[topk_normal_entities, topk_category_entities, topk_span_entities],
+        inputs=[texts, topk, entity_span_sensitivity],
+        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
+    topk.change(
+        fn=get_topk_entities_from_texts,
+        inputs=[texts, topk, entity_span_sensitivity],
+        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
     )
+    entity_span_sensitivity.change(
+        fn=get_topk_entities_from_texts,
+        inputs=[texts, topk, entity_span_sensitivity],
+        outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
+
+    topk_input.change(inputs=topk_input, outputs=topk)
 
     gr.Markdown("---")
     gr.Markdown("## 出力エンティティ")
 
-    @gr.render(inputs=[texts, topk_normal_entities, topk_category_entities, topk_span_entities])
-    def render_topk_entities(texts, topk_normal_entities, topk_category_entities, topk_span_entities):
-        for text, normal_entities, category_entities, span_entities in zip(
-            texts, topk_normal_entities, topk_category_entities, topk_span_entities
+    @gr.render(inputs=[texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities])
+    def render_topk_entities(
+        texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
+    ):
+        for text, entity_spans, normal_entities, category_entities, span_entities in zip(
+            texts, batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities
         ):
+            highlighted_text_value = []
+            cur = 0
+            for start, end in entity_spans:
+                if cur < start:
+                    highlighted_text_value.append((text[cur:start], None))
+
+                highlighted_text_value.append((text[start:end], "Entity"))
+                cur = end
+
+            if cur < len(text):
+                highlighted_text_value.append((text[cur:], None))
+
             gr.HighlightedText(
-                value=[(word, pos if pos == "名詞" else None) for word, pos, _ in mecab_tokenizer(text)],
-                color_map={"名詞": "green"},
-                show_legend=True,
-                combine_adjacent=True,
-                adjacent_separator=" ",
-                label="Text",
+                value=highlighted_text_value, color_map={"Entity": "green"}, combine_adjacent=False, label="Text"
             )
 
             # gr.Textbox(text, label="Text")
@@ -194,11 +278,12 @@ with gr.Blocks() as demo:
                 label="Topic Categories", components=["text"], samples=[[entity] for entity in category_entities]
             ).select(fn=get_selected_entity, outputs=selected_entity)
 
-            noun_spans = get_noun_spans_from_text(text)
-            nouns = [text[start:end] for start, end in noun_spans]
-            for noun, entities in zip(nouns, span_entities):
+            span_texts = [text[start:end] for start, end in entity_spans]
+            for span_text, entities in zip(span_texts, span_entities):
                 gr.Dataset(
-                    label=f"Span Entities for {noun}", components=["text"], samples=[[entity] for entity in entities]
+                    label=f"Span Entities for {span_text}",
+                    components=["text"],
+                    samples=[[entity] for entity in entities],
                 ).select(fn=get_selected_entity, outputs=selected_entity)
 
     gr.Markdown("---")
 