Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Set maximum values for input text length and number of lines in input files
Browse files
app.py
CHANGED
@@ -16,6 +16,9 @@ from transformers import AutoModelForPreTraining, AutoTokenizer
|
|
16 |
|
17 |
ALIAS_SEP = "|"
|
18 |
ENTITY_SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[MASK2]"]
|
|
|
|
|
|
|
19 |
|
20 |
repo_id = "studio-ousia/luxe"
|
21 |
revision = "ja-v0.3.1"
|
@@ -68,10 +71,14 @@ def get_texts_from_file(file_path: str | None):
|
|
68 |
try:
|
69 |
with open(file_path, newline="") as f:
|
70 |
reader = csv.DictReader(f, fieldnames=["text"])
|
71 |
-
for row in reader:
|
|
|
|
|
|
|
|
|
72 |
text = normalize_text(row["text"]).strip()
|
73 |
if text != "":
|
74 |
-
texts.append(text)
|
75 |
except Exception as e:
|
76 |
gr.Warning("ファイルを正しく読み込めませんでした。")
|
77 |
print(e)
|
@@ -144,7 +151,7 @@ def get_topk_entities_from_texts(
|
|
144 |
k: int = 5,
|
145 |
entity_span_sensitivity: float = 1.0,
|
146 |
nayose_coef: float = 1.0,
|
147 |
-
|
148 |
) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
|
149 |
model, tokenizer, bm25_tokenizer, bm25_retriever = models
|
150 |
|
@@ -196,7 +203,7 @@ def get_topk_entities_from_texts(
|
|
196 |
if model_outputs.entity_logits is not None:
|
197 |
span_entity_logits = model_outputs.entity_logits[0, :, :500000]
|
198 |
|
199 |
-
if nayose_coef > 0.0 and
|
200 |
nayose_queries = ["ja:" + text[start:end] for start, end in entity_spans]
|
201 |
nayose_query_tokens = bm25_tokenizer.tokenize(nayose_queries)
|
202 |
nayose_scores = torch.vstack(
|
@@ -265,7 +272,11 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]
|
|
265 |
try:
|
266 |
with open(file_path, newline="") as f:
|
267 |
reader = csv.DictReader(f, fieldnames=["entity", "text"])
|
268 |
-
for row in reader:
|
|
|
|
|
|
|
|
|
269 |
entity = normalize_text(row["entity"]).strip()
|
270 |
text = normalize_text(row["text"]).strip()
|
271 |
if entity != "" and text != "":
|
@@ -281,6 +292,7 @@ def get_new_entity_text_pairs_from_file(file_path: str | None) -> list[list[str]
|
|
281 |
def replace_entities(
|
282 |
models,
|
283 |
new_entity_text_pairs: list[tuple[str, str]],
|
|
|
284 |
new_num_category_entities: int = 0,
|
285 |
new_entity_counts: list[int] | None = None,
|
286 |
new_padding_idx: int = 0,
|
@@ -314,7 +326,7 @@ def replace_entities(
|
|
314 |
|
315 |
for entity, text in new_entity_text_pairs:
|
316 |
entity_id = tokenizer.entity_vocab[entity]
|
317 |
-
tokenized_inputs = tokenizer(text, return_tensors="pt")
|
318 |
model_outputs = model(**tokenized_inputs)
|
319 |
entity_embeddings = model.entity_predictions.transform(model_outputs.last_hidden_state[:, 0])
|
320 |
new_entity_embeddings_dict[entity_id].append(entity_embeddings[0])
|
@@ -363,7 +375,7 @@ def replace_entities(
|
|
363 |
|
364 |
gr.Info("モデルとトークナイザのエンティティの置き換えが完了しました", duration=5)
|
365 |
|
366 |
-
return
|
367 |
|
368 |
|
369 |
with gr.Blocks() as demo:
|
@@ -381,7 +393,7 @@ with gr.Blocks() as demo:
|
|
381 |
|
382 |
texts = gr.State([])
|
383 |
|
384 |
-
|
385 |
|
386 |
topk = gr.State(5)
|
387 |
entity_span_sensitivity = gr.State(1.0)
|
@@ -400,12 +412,14 @@ with gr.Blocks() as demo:
|
|
400 |
gr.Markdown("## 入力テキスト")
|
401 |
|
402 |
with gr.Tab(label="直接入力"):
|
403 |
-
text_input = gr.Textbox(label="
|
404 |
with gr.Tab(label="ファイルアップロード"):
|
405 |
-
texts_file = gr.File(label="
|
406 |
|
407 |
with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
|
408 |
-
new_entity_text_pairs_file = gr.File(
|
|
|
|
|
409 |
new_entity_text_pairs_input = gr.Dataframe(
|
410 |
# value=sample_new_entity_text_pairs,
|
411 |
headers=["entity", "text"],
|
@@ -420,7 +434,9 @@ with gr.Blocks() as demo:
|
|
420 |
fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
|
421 |
)
|
422 |
replace_entity_button.click(
|
423 |
-
fn=replace_entities,
|
|
|
|
|
424 |
)
|
425 |
|
426 |
with gr.Accordion(label="ハイパーパラメータ", open=False):
|
@@ -442,28 +458,28 @@ with gr.Blocks() as demo:
|
|
442 |
|
443 |
texts.change(
|
444 |
fn=get_topk_entities_from_texts,
|
445 |
-
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef,
|
446 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
447 |
)
|
448 |
topk.change(
|
449 |
fn=get_topk_entities_from_texts,
|
450 |
-
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef,
|
451 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
452 |
)
|
453 |
entity_span_sensitivity.change(
|
454 |
fn=get_topk_entities_from_texts,
|
455 |
-
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef,
|
456 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
457 |
)
|
458 |
nayose_coef.change(
|
459 |
fn=get_topk_entities_from_texts,
|
460 |
-
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef,
|
461 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
462 |
)
|
463 |
|
464 |
-
|
465 |
fn=get_topk_entities_from_texts,
|
466 |
-
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef,
|
467 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
468 |
)
|
469 |
|
|
|
16 |
|
17 |
ALIAS_SEP = "|"
|
18 |
ENTITY_SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[MASK]", "[MASK2]"]
|
19 |
+
MAX_TEXT_LENGTH = 800
|
20 |
+
MAX_TEXT_FILE_LINES = 100
|
21 |
+
MAX_ENTITY_FILE_LINES = 1000
|
22 |
|
23 |
repo_id = "studio-ousia/luxe"
|
24 |
revision = "ja-v0.3.1"
|
|
|
71 |
try:
|
72 |
with open(file_path, newline="") as f:
|
73 |
reader = csv.DictReader(f, fieldnames=["text"])
|
74 |
+
for i, row in enumerate(reader):
|
75 |
+
if i >= MAX_TEXT_FILE_LINES:
|
76 |
+
gr.Info(f"{MAX_TEXT_FILE_LINES}行目までのデータを読み込みました。")
|
77 |
+
break
|
78 |
+
|
79 |
text = normalize_text(row["text"]).strip()
|
80 |
if text != "":
|
81 |
+
texts.append(text[:MAX_TEXT_LENGTH])
|
82 |
except Exception as e:
|
83 |
gr.Warning("ファイルを正しく読み込めませんでした。")
|
84 |
print(e)
|
|
|
151 |
k: int = 5,
|
152 |
entity_span_sensitivity: float = 1.0,
|
153 |
nayose_coef: float = 1.0,
|
154 |
+
entity_replaced_counts: bool = False,
|
155 |
) -> tuple[list[list[tuple[int, int]]], list[list[str]], list[list[str]], list[list[list[str]]]]:
|
156 |
model, tokenizer, bm25_tokenizer, bm25_retriever = models
|
157 |
|
|
|
203 |
if model_outputs.entity_logits is not None:
|
204 |
span_entity_logits = model_outputs.entity_logits[0, :, :500000]
|
205 |
|
206 |
+
if nayose_coef > 0.0 and entity_replaced_counts == 0:
|
207 |
nayose_queries = ["ja:" + text[start:end] for start, end in entity_spans]
|
208 |
nayose_query_tokens = bm25_tokenizer.tokenize(nayose_queries)
|
209 |
nayose_scores = torch.vstack(
|
|
|
272 |
try:
|
273 |
with open(file_path, newline="") as f:
|
274 |
reader = csv.DictReader(f, fieldnames=["entity", "text"])
|
275 |
+
for i, row in enumerate(reader):
|
276 |
+
if i >= MAX_ENTITY_FILE_LINES:
|
277 |
+
gr.Info(f"{MAX_ENTITY_FILE_LINES}行目までのデータを読み込みました。")
|
278 |
+
break
|
279 |
+
|
280 |
entity = normalize_text(row["entity"]).strip()
|
281 |
text = normalize_text(row["text"]).strip()
|
282 |
if entity != "" and text != "":
|
|
|
292 |
def replace_entities(
|
293 |
models,
|
294 |
new_entity_text_pairs: list[tuple[str, str]],
|
295 |
+
entity_replaced_counts: int,
|
296 |
new_num_category_entities: int = 0,
|
297 |
new_entity_counts: list[int] | None = None,
|
298 |
new_padding_idx: int = 0,
|
|
|
326 |
|
327 |
for entity, text in new_entity_text_pairs:
|
328 |
entity_id = tokenizer.entity_vocab[entity]
|
329 |
+
tokenized_inputs = tokenizer(text[:MAX_TEXT_LENGTH], return_tensors="pt")
|
330 |
model_outputs = model(**tokenized_inputs)
|
331 |
entity_embeddings = model.entity_predictions.transform(model_outputs.last_hidden_state[:, 0])
|
332 |
new_entity_embeddings_dict[entity_id].append(entity_embeddings[0])
|
|
|
375 |
|
376 |
gr.Info("モデルとトークナイザのエンティティの置き換えが完了しました", duration=5)
|
377 |
|
378 |
+
return entity_replaced_counts + 1
|
379 |
|
380 |
|
381 |
with gr.Blocks() as demo:
|
|
|
393 |
|
394 |
texts = gr.State([])
|
395 |
|
396 |
+
entity_replaced_counts = gr.State(0)
|
397 |
|
398 |
topk = gr.State(5)
|
399 |
entity_span_sensitivity = gr.State(1.0)
|
|
|
412 |
gr.Markdown("## 入力テキスト")
|
413 |
|
414 |
with gr.Tab(label="直接入力"):
|
415 |
+
text_input = gr.Textbox(label=f"入力テキスト(最大{MAX_TEXT_LENGTH}文字)", max_length=MAX_TEXT_LENGTH)
|
416 |
with gr.Tab(label="ファイルアップロード"):
|
417 |
+
texts_file = gr.File(label=f"入力テキストファイル(最大{MAX_TEXT_FILE_LINES}行)")
|
418 |
|
419 |
with gr.Accordion(label="LUXEのエンティティ語彙を置き換える", open=False):
|
420 |
+
new_entity_text_pairs_file = gr.File(
|
421 |
+
label=f"エンティティと説明文のCSVファイル(最大{MAX_ENTITY_FILE_LINES}行)"
|
422 |
+
)
|
423 |
new_entity_text_pairs_input = gr.Dataframe(
|
424 |
# value=sample_new_entity_text_pairs,
|
425 |
headers=["entity", "text"],
|
|
|
434 |
fn=get_new_entity_text_pairs_from_file, inputs=new_entity_text_pairs_file, outputs=new_entity_text_pairs_input
|
435 |
)
|
436 |
replace_entity_button.click(
|
437 |
+
fn=replace_entities,
|
438 |
+
inputs=[models, new_entity_text_pairs_input, entity_replaced_counts],
|
439 |
+
outputs=entity_replaced_counts,
|
440 |
)
|
441 |
|
442 |
with gr.Accordion(label="ハイパーパラメータ", open=False):
|
|
|
458 |
|
459 |
texts.change(
|
460 |
fn=get_topk_entities_from_texts,
|
461 |
+
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
|
462 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
463 |
)
|
464 |
topk.change(
|
465 |
fn=get_topk_entities_from_texts,
|
466 |
+
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
|
467 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
468 |
)
|
469 |
entity_span_sensitivity.change(
|
470 |
fn=get_topk_entities_from_texts,
|
471 |
+
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
|
472 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
473 |
)
|
474 |
nayose_coef.change(
|
475 |
fn=get_topk_entities_from_texts,
|
476 |
+
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
|
477 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
478 |
)
|
479 |
|
480 |
+
entity_replaced_counts.change(
|
481 |
fn=get_topk_entities_from_texts,
|
482 |
+
inputs=[models, texts, topk, entity_span_sensitivity, nayose_coef, entity_replaced_counts],
|
483 |
outputs=[batch_entity_spans, topk_normal_entities, topk_category_entities, topk_span_entities],
|
484 |
)
|
485 |
|