Create app.py
app.py ADDED
@@ -0,0 +1,118 @@
import gradio as gr
from transformers import AutoModelForPreTraining, AutoTokenizer


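# Load LUXE from the Hugging Face Hub; trust_remote_code=True is needed because the
# model ships its own modeling/tokenization code with the repository.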
repo_id = "studio-ousia/luxe"
revision = "ja-v0.1"

model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)

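# The entity vocabulary lists normal entities first, followed by category entities.
# The config records the number of category entities, so the difference below gives
# the number of normal entities.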
num_normal_entities = len(tokenizer.entity_vocab) - model.config.num_category_entities
num_category_entities = model.config.num_category_entities

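# Reverse mappings from ID to entity name. Category-entity IDs are shifted down by
# num_normal_entities so they index directly into the category embedding block.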
id2normal_entity = {
    entity_id: entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id < num_normal_entities
}

id2category_entity = {
    entity_id - num_normal_entities: entity
    for entity, entity_id in tokenizer.entity_vocab.items()
    if entity_id >= num_normal_entities
}

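# Split the entity embedding table into its normal and category blocks; the
# similarity search below operates within one block at a time.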
entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
normal_entity_embeddings = entity_embeddings[:num_normal_entities]
category_entity_embeddings = entity_embeddings[num_normal_entities:]


def get_texts_from_file(file_path):
    texts = []
    with open(file_path) as f:
        for line in f:
            line = line.strip()
            if line:
                texts.append(line)

    return texts


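# For each text, run a single forward pass and keep the k highest-scoring entries of
# the two prediction heads in the model outputs: topic_entity_logits (normal
# entities) and topic_category_logits (category entities).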
def get_topk_entities_from_texts(texts: list[str], k: int = 5) -> tuple[list[list[str]], list[list[str]]]:
    topk_normal_entities = []
    topk_category_entities = []

    for text in texts:
        tokenized_examples = tokenizer(text, return_tensors="pt")
        model_outputs = model(**tokenized_examples)

        _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
        topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])

        _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
        topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])

    return topk_normal_entities, topk_category_entities


def get_selected_entity(evt: gr.SelectData):
    return evt.value[0]


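# Dot-product nearest neighbours within the matching embedding block. topk(k + 1) is
# taken because the top hit is the query entity itself, which [1:] drops.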
def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
    query_entity_id = tokenizer.entity_vocab[query_entity]

    if query_entity_id < num_normal_entities:
        topk_entity_scores = normal_entity_embeddings[query_entity_id] @ normal_entity_embeddings.T
        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
        topk_entities = [id2normal_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
    else:
        query_entity_id -= num_normal_entities
        topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
        topk_entities = [id2category_entity[entity_id] for entity_id in topk_entity_ids.tolist()]

    return topk_entities


with gr.Blocks() as demo:
    gr.Markdown("## テキスト(直接入力またはファイルアップロード)")  # "Text (direct input or file upload)"

    texts = gr.State([])
    topk_normal_entities = gr.State([])
    topk_category_entities = gr.State([])
    selected_entity = gr.State()
    similar_entities = gr.State([])

    text_input = gr.Textbox(label="Input Text")
    texts_file = gr.File(label="Input texts")

    text_input.change(fn=lambda text: [text], inputs=text_input, outputs=texts)
    texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
    texts.change(fn=get_topk_entities_from_texts, inputs=texts, outputs=[topk_normal_entities, topk_category_entities])

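    # @gr.render re-runs the decorated function whenever one of its inputs changes,
    # rebuilding a Textbox and a clickable Dataset for every input text.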
    gr.Markdown("---")
    gr.Markdown("## 出力エンティティ")  # "Output entities"

    @gr.render(inputs=[texts, topk_normal_entities, topk_category_entities])
    def render_topk_entities(texts, topk_normal_entities, topk_category_entities):
        for text, normal_entities, category_entities in zip(texts, topk_normal_entities, topk_category_entities):
            gr.Textbox(text, label="Text")
            entities = gr.Dataset(
                label="Entities",
                components=["text"],
                samples=[[entity] for entity in normal_entities + category_entities],
            )
            entities.select(fn=get_selected_entity, outputs=selected_entity)

    gr.Markdown("---")
    gr.Markdown("## 選択されたエンティティの類似エンティティ")  # "Entities similar to the selected entity"

    selected_entity.change(fn=get_similar_entities, inputs=selected_entity, outputs=similar_entities)

    @gr.render(inputs=[selected_entity, similar_entities])
    def render_similar_entities(selected_entity, similar_entities):
        gr.Textbox(selected_entity, label="Selected Entity")
        gr.Dataset(label="Similar Entities", components=["text"], samples=[[entity] for entity in similar_entities])


demo.launch()
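# A minimal way to run this app locally (assumed dependencies; this commit does not
# add a requirements file):
#   pip install gradio transformers torch sentencepiece
#   python app.py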