singletongue committed on
Commit
6234321
·
verified ·
1 Parent(s): a050369

Support NAYOSE for span entities using BM25 of entity name sub-tokens

Browse files
Files changed (1) hide show
  1. app.py +35 -5
app.py CHANGED
@@ -5,13 +5,18 @@ from pathlib import Path
5
  import gradio as gr
6
  import torch
7
  import unidic_lite
 
8
  from fugashi import GenericTagger
9
  from transformers import AutoModelForPreTraining, AutoTokenizer
10
 
11
 
 
 
12
  repo_id = "studio-ousia/luxe"
13
  revision = "ja-v0.3"
14
 
 
 
15
  ignore_category_patterns = [
16
  r"\d+年",
17
  r"楽曲 [ぁ-ん]",
@@ -77,6 +82,11 @@ def normalize_text(text: str) -> str:
77
  return unicodedata.normalize("NFKC", text)
78
 
79
 
 
 
 
 
 
80
  def get_texts_from_file(file_path):
81
  texts = []
82
  with open(file_path) as f:
@@ -146,7 +156,7 @@ def get_predicted_entity_spans(
146
 
147
 
148
  def get_topk_entities_from_texts(
149
- texts: list[str], k: int = 5, entity_span_sensitivity: float = 1.0
150
  ) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
151
  batch_entity_spans: list[list[tuple[int, int]]] = []
152
  topk_normal_entities: list[list[str]] = []
@@ -172,7 +182,17 @@ def get_topk_entities_from_texts(
172
  topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
173
 
174
  if model_outputs.entity_logits is not None:
175
- _, topk_span_entity_ids = model_outputs.entity_logits[0, :, :500000].topk(k)
 
 
 
 
 
 
 
 
 
 
176
  topk_span_entities.append(
177
  [[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
178
  )
@@ -211,6 +231,7 @@ with gr.Blocks() as demo:
211
  texts = gr.State([])
212
  topk = gr.State(5)
213
  entity_span_sensitivity = gr.State(1.0)
 
214
  batch_entity_spans = gr.State([])
215
  topk_normal_entities = gr.State([])
216
  topk_category_entities = gr.State([])
@@ -230,20 +251,29 @@ with gr.Blocks() as demo:
230
  entity_span_sensitivity_input.change(
231
  fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
232
  )
 
 
 
 
233
 
234
  texts.change(
235
  fn=get_topk_entities_from_texts,
236
- inputs=[texts, topk, entity_span_sensitivity],
237
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
238
  )
239
  topk.change(
240
  fn=get_topk_entities_from_texts,
241
- inputs=[texts, topk, entity_span_sensitivity],
242
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
243
  )
244
  entity_span_sensitivity.change(
245
  fn=get_topk_entities_from_texts,
246
- inputs=[texts, topk, entity_span_sensitivity],
 
 
 
 
 
247
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
248
  )
249
 
 
5
  import gradio as gr
6
  import torch
7
  import unidic_lite
8
+ from bm25s.hf import BM25HF, TokenizerHF
9
  from fugashi import GenericTagger
10
  from transformers import AutoModelForPreTraining, AutoTokenizer
11
 
12
 
13
+ ALIAS_SEP = "|"
14
+
15
  repo_id = "studio-ousia/luxe"
16
  revision = "ja-v0.3"
17
 
18
+ nayose_repo_id = "studio-ousia/luxe-nayose-bm25"
19
+
20
  ignore_category_patterns = [
21
  r"\d+年",
22
  r"楽曲 [ぁ-ん]",
 
82
  return unicodedata.normalize("NFKC", text)
83
 
84
 
85
+ bm25_tokenizer = TokenizerHF(lower=True, splitter=tokenizer.tokenize, stopwords=None, stemmer=None)
86
+ bm25_tokenizer.load_vocab_from_hub("studio-ousia/luxe-nayose-bm25")
87
+ bm25_retriever = BM25HF.load_from_hub("studio-ousia/luxe-nayose-bm25")
88
+
89
+
90
  def get_texts_from_file(file_path):
91
  texts = []
92
  with open(file_path) as f:
 
156
 
157
 
158
  def get_topk_entities_from_texts(
159
+ texts: list[str], k: int = 5, entity_span_sensitivity: float = 1.0, nayose_coef: float = 0.0
160
  ) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
161
  batch_entity_spans: list[list[tuple[int, int]]] = []
162
  topk_normal_entities: list[list[str]] = []
 
182
  topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
183
 
184
  if model_outputs.entity_logits is not None:
185
+ span_entity_logits = model_outputs.entity_logits[0, :, :500000]
186
+
187
+ if nayose_coef > 0.0:
188
+ nayose_queries = ["ja:" + text[start:end] for start, end in entity_spans]
189
+ nayose_query_tokens = bm25_tokenizer.tokenize(nayose_queries)
190
+ nayose_scores = torch.vstack(
191
+ [torch.from_numpy(bm25_retriever.get_scores(tokens)) for tokens in nayose_query_tokens]
192
+ )
193
+ span_entity_logits += nayose_coef * nayose_scores
194
+
195
+ _, topk_span_entity_ids = span_entity_logits.topk(k)
196
  topk_span_entities.append(
197
  [[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()]
198
  )
 
231
  texts = gr.State([])
232
  topk = gr.State(5)
233
  entity_span_sensitivity = gr.State(1.0)
234
+ nayose_coef = gr.State(0.0)
235
  batch_entity_spans = gr.State([])
236
  topk_normal_entities = gr.State([])
237
  topk_category_entities = gr.State([])
 
251
  entity_span_sensitivity_input.change(
252
  fn=lambda val: val, inputs=entity_span_sensitivity_input, outputs=entity_span_sensitivity
253
  )
254
+ nayose_coef_input = gr.Slider(
255
+ minimum=0.0, maximum=2.0, value=0.0, step=0.1, label="Nayose Coefficient", interactive=True
256
+ )
257
+ nayose_coef_input.change(fn=lambda val: val, inputs=nayose_coef_input, outputs=nayose_coef)
258
 
259
  texts.change(
260
  fn=get_topk_entities_from_texts,
261
+ inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
262
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
263
  )
264
  topk.change(
265
  fn=get_topk_entities_from_texts,
266
+ inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
267
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
268
  )
269
  entity_span_sensitivity.change(
270
  fn=get_topk_entities_from_texts,
271
+ inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
272
+ outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
273
+ )
274
+ nayose_coef.change(
275
+ fn=get_topk_entities_from_texts,
276
+ inputs=[texts, topk, entity_span_sensitivity, nayose_coef],
277
  outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
278
  )
279