singletongue committed (verified)
Commit c5df237 · 1 Parent(s): 647335c

Use ja-v0.2 model, ignore categories of some patterns

Files changed (1):
  1. app.py +23 -1
app.py CHANGED

@@ -1,3 +1,4 @@
+import re
 from pathlib import Path
 
 import gradio as gr
@@ -7,7 +8,18 @@ from transformers import AutoModelForPreTraining, AutoTokenizer
 
 
 repo_id = "studio-ousia/luxe"
-revision = "ja-v0.1"
+revision = "ja-v0.2"
+
+ignore_category_patterns = [
+    r"\d+年",
+    r"楽曲 [ぁ-ん]",
+    r"漫画作品 [ぁ-ん]",
+    r"アニメ作品 [ぁ-ん]",
+    r"アニメ作品 [ぁ-ん]",
+    r"の一覧",
+    r"各国の",
+    r"各年の",
+]
 
 model = AutoModelForPreTraining.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(repo_id, revision=revision, trust_remote_code=True)
@@ -24,6 +36,11 @@ id2category_entity = {
     for entity, entity_id in tokenizer.entity_vocab.items()
     if entity_id >= num_normal_entities
 }
+ignore_category_entity_ids = [
+    entity_id - num_normal_entities
+    for entity, entity_id in tokenizer.entity_vocab.items()
+    if entity_id >= num_normal_entities and any(re.search(pattern, entity) for pattern in ignore_category_patterns)
+]
 
 entity_embeddings = model.luke.entity_embeddings.entity_embeddings.weight
 normal_entity_embeddings = entity_embeddings[:num_normal_entities]
@@ -94,6 +111,8 @@ def get_topk_entities_from_texts(
         tokenized_examples = tokenizer(text, entity_spans=noun_spans, return_tensors="pt")
         model_outputs = model(**tokenized_examples)
 
+        model_outputs.topic_category_logits[:, ignore_category_entity_ids] = float("-inf")
+
         _, topk_normal_entity_ids = model_outputs.topic_entity_logits[0].topk(k)
         topk_normal_entities.append([id2normal_entity[id_] for id_ in topk_normal_entity_ids.tolist()])
 
@@ -120,6 +139,9 @@ def get_similar_entities(query_entity: str, k: int = 10) -> list[str]:
     else:
         query_entity_id -= num_normal_entities
         topk_entity_scores = category_entity_embeddings[query_entity_id] @ category_entity_embeddings.T
+
+        topk_entity_scores[ignore_category_entity_ids] = float("-inf")
+
        topk_entity_ids = topk_entity_scores.topk(k + 1).indices[1:]
         topk_entities = [id2category_entity[entity_id] for entity_id in topk_entity_ids.tolist()]
 
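
The change hinges on two small mechanisms: collecting ignore_category_entity_ids by regex-matching entity names in the vocabulary, and masking those positions with -inf so topk can never return them. Below is a minimal, self-contained sketch of the same mechanism; the toy vocabulary, patterns, and scores are invented for illustration and are not the model's actual data.

import re

import torch

# Hypothetical stand-ins for the app's tokenizer.entity_vocab and
# ignore_category_patterns (illustration only).
entity_vocab = {"東京都": 0, "夏目漱石": 1, "1998年": 2, "楽曲 あ": 3, "の一覧": 4}
ignore_patterns = [r"\d+年", r"楽曲 [ぁ-ん]", r"の一覧"]

# Same shape as the commit's ignore_category_entity_ids comprehension:
# keep the ID of every entity whose name matches any ignore pattern.
ignore_ids = [
    entity_id
    for entity, entity_id in entity_vocab.items()
    if any(re.search(pattern, entity) for pattern in ignore_patterns)
]
assert ignore_ids == [2, 3, 4]

# Toy scores over the five entities; in the app these are model logits
# or embedding dot products.
scores = torch.tensor([0.1, 0.5, 2.0, 1.5, 0.9])

# Masking with -inf removes the ignored entities from any top-k result:
# every finite score ranks above -inf, so topk skips the masked positions.
scores[ignore_ids] = float("-inf")
assert scores.topk(2).indices.tolist() == [1, 0]

Because every finite score ranks above float("-inf"), the masked categories drop out of both the top-k category predictions and the nearest-neighbor category search without perturbing any other score.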