Use ja-v0.2 model, ignore categories of some patterns
app.py CHANGED
@@ -1,3 +1,4 @@
+import re
 from pathlib import Path
 
 import gradio as gr
@@ -7,7 +8,18 @@ from transformers import AutoModelForPreTraining, AutoTokenizer
 
 
 repo_id = "studio-ousia/luxe"
-revision = "ja-v0.
+revision = "ja-v0.2"
+
+ignore_category_patterns = [
+    r"\d+年",
+    r"楽曲 [ぁ-ん]",
+    r"漫画作品 [ぁ-ん]",
+    r"アニメ作品 [ぁ-ん]",
+    r"アニメ作品 [ぁ-ん]",
+    r"の一覧",
+    r"各国の",
+    r"各年の",
+]
 
 model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
@@ -24,6 +36,11 @@ id2category_entity = {
     for entity, entity_id in tokenizer.entity_vocab.items()
     if entity_id >= num_normal_entities
 }
+ignore_category_entity_ids = [
+    entity_id - num_normal_entities
+    for entity, entity_id in tokenizer.entity_vocab.items()
+    if entity_id >= num_normal_entities and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
+]
 
 entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
 normal_entity_embeddings = entity_embeddings[:num_normal_entities]
@@ -94,6 +111,8 @@ def get_topk_entities_from_texts(
         tokenized_examples = tokenizer(text, entity_spans=noun_spans, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
 
+        model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
+
         _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
         topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
 
@@ -120,6 +139,9 @@ def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
     else:
         query_entity_id -= num_normal_entities
         topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
+
+        topk_entity_scores[ignore_category_entity_ids] = float("-inf")
+
         topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
         topk_entities = [id2category_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
 
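The new ignore_category_entity_ids comprehension can be checked in isolation. A minimal sketch, with a made-up toy vocabulary standing in for tokenizer.entity_vocab (the entity names, ids, and num_normal_entities below are illustrative only, not values from the app):

import re

ignore_category_patterns = [r"\d+年", r"の一覧"]

# Toy stand-in for tokenizer.entity_vocab: entity name -> entity id.
# Ids at or above num_normal_entities are category entities.
entity_vocab = {"東京": 0, "富士山": 1, "2005年": 2, "山の一覧": 3, "日本の山": 4}
num_normal_entities = 2

# Category-local ids of categories matching any ignore pattern,
# mirroring the comprehension added in app.py.
ignore_category_entity_ids = [
    entity_id - num_normal_entities
    for entity, entity_id in entity_vocab.items()
    if entity_id >= num_normal_entities
    and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
]

print(ignore_category_entity_ids)  # [0, 1] -> "2005年" and "山の一覧"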
|
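The float("-inf") assignments work because torch.Tensor.topk ranks purely by value, so masked positions can never be selected (as long as k does not exceed the number of unmasked entries). A self-contained sketch of the trick applied to both topic_category_logits and topk_entity_scores:

import torch

scores = torch.tensor([0.9, 0.5, 0.8, 0.1, 0.7])
ignore_ids = [0, 2]  # analogous to ignore_category_entity_ids

# In-place mask: these positions lose every comparison inside topk.
scores[ignore_ids] = float("-inf")

print(scores.topk(2).indices.tolist())  # [4, 1] -> best remaining positions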