singletongue committed
Commit dde7d2a · verified · 1 Parent(s): a4d01d8

Create app.py

Files changed (1)
app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForPreTraining, AutoTokenizer
+
+
+ repo_id = "studio-ousia/luxe"
+ revision = "ja-v0.1"
+
+ # trust_remote_code is required: the LUXE model class and its extra output
+ # heads (topic_entity_logits, topic_category_logits) live in the model repo.
+ model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
+
+ # The entity vocabulary lists normal entities first, then category entities.
+ num_normal_entities = len(tokenizer.entity_vocab) - model.config.num_category_entities
+ num_category_entities = model.config.num_category_entities
+
+ id2normal_entity = {
+     entity_id: entity for entity, entity_id in tokenizer.entity_vocab.items() if entity_id < num_normal_entities
+ }
+
+ # Category entity IDs are shifted so that each block is indexed from zero.
+ id2category_entity = {
+     entity_id - num_normal_entities: entity
+     for entity, entity_id in tokenizer.entity_vocab.items()
+     if entity_id >= num_normal_entities
+ }
+
+ # Split the entity embedding matrix into the two blocks for similarity search.
+ entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
+ normal_entity_embeddings = entity_embeddings[:num_normal_entities]
+ category_entity_embeddings = entity_embeddings[num_normal_entities:]
+
+
+ def get_texts_from_file(file_path):
+     """Read one input text per non-empty line of the uploaded file."""
+     texts = []
+     with open(file_path) as f:
+         for line in f:
+             line = line.strip()
+             if line:
+                 texts.append(line)
+
+     return texts
+
+
+ def get_topk_entities_from_texts(texts: list[str], k: int = 5) -> tuple[list[list[str]], list[list[str]]]:
+     """Predict the top-k normal and category entities for each input text."""
+     topk_normal_entities = []
+     topk_category_entities = []
+
+     for text in texts:
+         tokenized_examples = tokenizer(text, return_tensors="pt")
+         with torch.no_grad():
+             model_outputs = model(**tokenized_examples)
+
+         _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
+         topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
+
+         _, topk_category_entity_ids = model_outputs.topic_category_logits[0].topk(k)
+         topk_category_entities.append([id2category_entity[id_] for id_ in topk_category_entity_ids.tolist()])
+
+     return topk_normal_entities, topk_category_entities
+
+
+ def get_selected_entity(evt: gr.SelectData):
+     # evt.value is the selected Dataset sample, a one-element list of [entity].
+     return evt.value[0]
+
+
+ def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
+     """Return the k entities closest to the query by embedding dot product."""
+     query_entity_id = tokenizer.entity_vocab[query_entity]
+
+     if query_entity_id < num_normal_entities:
+         topk_entity_scores = normal_entity_embeddings[query_entity_id] @ normal_entity_embeddings.T
+         # topk(k + 1) because the best match is the query entity itself; skip it.
+         topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
+         topk_entities = [id2normal_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
+     else:
+         query_entity_id -= num_normal_entities
+         topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
+         topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
+         topk_entities = [id2category_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
+
+     return topk_entities
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## Text (direct input or file upload)")
+
+     # State holding the inputs, predictions, and the user's entity selection.
+     texts = gr.State([])
+     topk_normal_entities = gr.State([])
+     topk_category_entities = gr.State([])
+     selected_entity = gr.State()
+     similar_entities = gr.State([])
+
+     text_input = gr.Textbox(label="Input Text")
+     texts_file = gr.File(label="Input texts")
+
+     # Either input route updates `texts`, which in turn triggers entity prediction.
+     text_input.change(fn=lambda text: [text], inputs=text_input, outputs=texts)
+     texts_file.change(fn=get_texts_from_file, inputs=texts_file, outputs=texts)
+     texts.change(fn=get_topk_entities_from_texts, inputs=texts, outputs=[topk_normal_entities, topk_category_entities])
+
+     gr.Markdown("---")
+     gr.Markdown("## Output entities")
+
+     @gr.render(inputs=[texts, topk_normal_entities, topk_category_entities])
+     def render_topk_entities(texts, topk_normal_entities, topk_category_entities):
+         for text, normal_entities, category_entities in zip(texts, topk_normal_entities, topk_category_entities):
+             gr.Textbox(text, label="Text")
+             entities = gr.Dataset(
+                 label="Entities",
+                 components=["text"],
+                 samples=[[entity] for entity in normal_entities + category_entities],
+             )
+             entities.select(fn=get_selected_entity, outputs=selected_entity)
+
+     gr.Markdown("---")
+     gr.Markdown("## Entities similar to the selected entity")
+
+     selected_entity.change(fn=get_similar_entities, inputs=selected_entity, outputs=similar_entities)
+
+     @gr.render(inputs=[selected_entity, similar_entities])
+     def render_similar_entities(selected_entity, similar_entities):
+         gr.Textbox(selected_entity, label="Selected Entity")
+         gr.Dataset(label="Similar Entities", components=["text"], samples=[[entity] for entity in similar_entities])
+
+
+ demo.launch()
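
For reference, a minimal sketch of exercising the same prediction path outside the Gradio UI. This is not part of the commit; it assumes the same repo_id and revision as app.py and that the remote-code model exposes topic_entity_logits as used above, and the example sentence is arbitrary:

import torch
from transformers import AutoModelForPreTraining, AutoTokenizer

repo_id = "studio-ousia/luxe"
revision = "ja-v0.1"
model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)

# Rebuild the ID-to-entity mapping for normal (non-category) entities,
# mirroring the setup at the top of app.py.
num_normal_entities = len(tokenizer.entity_vocab) - model.config.num_category_entities
id2normal_entity = {i: e for e, i in tokenizer.entity_vocab.items() if i < num_normal_entities}

inputs = tokenizer("日本の首都は東京です。", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Top-5 predicted entities for the text, as in get_topk_entities_from_texts.
topk_ids = outputs.topic_entity_logits[0].topk(5).indices.tolist()
print([id2normal_entity[i] for i in topk_ids])

One design note: get_similar_entities ranks by raw dot product between embedding rows, so vector norms affect the ranking; cosine similarity would require L2-normalizing the embeddings first.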