singletongue committed
Commit 647335c (verified)
Parent: ac150d7

Add span-based entity linking

Files changed (1): app.py (+91 -16)
app.py CHANGED
@@ -1,4 +1,8 @@
+from pathlib import Path
+
 import gradio as gr
+import unidic_lite
+from fugashi import GenericTagger
 from transformers import AutoModelForPreTraining, AutoTokenizer
 
 
@@ -26,6 +30,30 @@ normal_entity_embeddings = entity_embeddings[:num_normal_entities]
 category_entity_embeddings = entity_embeddings[num_normal_entities:]
 
 
+class MecabTokenizer:
+    def __init__(self):
+        unidic_dir = unidic_lite.DICDIR
+        mecabrc_file = Path(unidic_dir, "mecabrc")
+        mecab_option = f"-d {unidic_dir} -r {mecabrc_file}"
+        self.tagger = GenericTagger(mecab_option)
+
+    def __call__(self, text: str) -> list[tuple[str, str, tuple[int, int]]]:
+        outputs = []
+
+        end = 0
+        for node in self.tagger(text):
+            word = node.surface.strip()
+            pos = node.feature[0]
+            start = text.index(word, end)
+            end = start + len(word)
+            outputs.append((word, pos, (start, end)))
+
+        return outputs
+
+
+mecab_tokenizer = MecabTokenizer()
+
+
 def get_texts_from_file(file_path):
     texts = []
     with open(file_path) as f:
@@ -37,12 +65,33 @@ def get_texts_from_file(file_path):
     return texts
 
 
-def get_topk_entities_from_texts(texts: list[str], k: int = 5) -> tuple[list[list[str]], list[list[str]]]:
-    topk_normal_entities = []
-    topk_category_entities = []
+def get_noun_spans_from_text(text: str) -> list[tuple[int, int]]:
+    last_pos = None
+    noun_spans = []
+
+    for word, pos, (start, end) in mecab_tokenizer(text):
+        if pos == "名詞":
+            if len(noun_spans) > 0 and last_pos == "名詞":
+                noun_spans[-1] = (noun_spans[-1][0], end)
+            else:
+                noun_spans.append((start, end))
+
+        last_pos = pos
+
+    return noun_spans
+
+
+def get_topk_entities_from_texts(
+    texts: list[str], k: int = 5
+) -> tuple[list[list[str]], list[list[str]], list[list[list[str]]]]:
+    topk_normal_entities: list[list[str]] = []
+    topk_category_entities: list[list[str]] = []
+    topk_span_entities: list[list[list[str]]] = []
 
     for text in texts:
-        tokenized_examples = tokenizer(text, return_tensors="pt")
+        noun_spans = get_noun_spans_from_text(text)
+
+        tokenized_examples = tokenizer(text, entity_spans=noun_spans, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
 
         _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
@@ -51,7 +100,10 @@ def get_topk_entities_from_texts(texts: list[str], k: int = 5) -> tuple[list[list[str]], list[list[str]]]:
         _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
         topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
 
-    return topk_normal_entities, topk_category_entities
+        _, topk_span_entity_ids = model_outputs.entity_logits[0, :, :500000].topk(k)
+        topk_span_entities.append([[id2normal_entity[id_] for id_ in ids] for ids in topk_span_entity_ids.tolist()])
+
+    return topk_normal_entities, topk_category_entities, topk_span_entities
 
 
 def get_selected_entity(evt: gr.SelectData):
@@ -80,29 +132,52 @@ with gr.Blocks() as demo:
     texts = gr.State([])
    topk_normal_entities = gr.State([])
    topk_category_entities = gr.State([])
+    topk_span_entities = gr.State([])
     selected_entity = gr.State()
     similar_entities = gr.State([])
 
     text_input = gr.Textbox(label="Input Text")
-    texts_file = gr.File(label="Input texts")
+    texts_file = gr.File(label="Input Texts")
 
     text_input.change(fn=lambda text: [text], inputs=text_input, outputs=texts)
     texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
-    texts.change(fn=get_topk_entities_from_texts, inputs=texts, outputs=[topk_normal_entities, topk_category_entities])
+    texts.change(
+        fn=get_topk_entities_from_texts,
+        inputs=texts,
+        outputs=[topk_normal_entities, topk_category_entities, topk_span_entities],
+    )
 
     gr.Markdown("---")
     gr.Markdown("## 出力エンティティ")
 
-    @gr.render(inputs=[texts, topk_normal_entities, topk_category_entities])
-    def render_topk_entities(texts, topk_normal_entities, topk_category_entities):
-        for text, normal_entities, category_entities in zip(texts, topk_normal_entities, topk_category_entities):
-            gr.Textbox(text, label="Text")
-            entities = gr.Dataset(
-                label="Entities",
-                components=["text"],
-                samples=[[entity] for entity in normal_entities + category_entities],
+    @gr.render(inputs=[texts, topk_normal_entities, topk_category_entities, topk_span_entities])
+    def render_topk_entities(texts, topk_normal_entities, topk_category_entities, topk_span_entities):
+        for text, normal_entities, category_entities, span_entities in zip(
+            texts, topk_normal_entities, topk_category_entities, topk_span_entities
+        ):
+            gr.HighlightedText(
+                value=[(word, pos if pos == "名詞" else None) for word, pos, _ in mecab_tokenizer(text)],
+                color_map={"名詞": "green"},
+                show_legend=True,
+                combine_adjacent=True,
+                adjacent_separator=" ",
+                label="Text",
             )
-            entities.select(fn=get_selected_entity, outputs=selected_entity)
+
+            # gr.Textbox(text, label="Text")
+            gr.Dataset(
+                label="Topic Entities", components=["text"], samples=[[entity] for entity in normal_entities]
+            ).select(fn=get_selected_entity, outputs=selected_entity)
+            gr.Dataset(
+                label="Topic Categories", components=["text"], samples=[[entity] for entity in category_entities]
+            ).select(fn=get_selected_entity, outputs=selected_entity)
+
+            noun_spans = get_noun_spans_from_text(text)
+            nouns = [text[start:end] for start, end in noun_spans]
+            for noun, entities in zip(nouns, span_entities):
+                gr.Dataset(
+                    label=f"Span Entities for {noun}", components=["text"], samples=[[entity] for entity in entities]
+                ).select(fn=get_selected_entity, outputs=selected_entity)
 
     gr.Markdown("---")
     gr.Markdown("## 選択されたエンティティの類似エンティティ")