update-github
This view is limited to 50 files because the commit contains too many changes. The rendered diff also elides the changed portions of some removed lines; those truncations are preserved as-is below.
- README.md +3 -4
- app.py +2 -3
- src/__pycache__/pipeline.cpython-311.pyc +0 -0
- src/__pycache__/pipeline.cpython-39.pyc +0 -0
- src/config.yaml +3 -2
- src/generate_memory.py +0 -181
- src/models/__init__.py +1 -1
- src/models/__pycache__/__init__.cpython-311.pyc +0 -0
- src/models/__pycache__/__init__.cpython-37.pyc +0 -0
- src/models/__pycache__/__init__.cpython-39.pyc +0 -0
- src/models/__pycache__/llm_def.cpython-311.pyc +0 -0
- src/models/__pycache__/llm_def.cpython-37.pyc +0 -0
- src/models/__pycache__/llm_def.cpython-39.pyc +0 -0
- src/models/__pycache__/prompt_example.cpython-311.pyc +0 -0
- src/models/__pycache__/prompt_example.cpython-39.pyc +0 -0
- src/models/__pycache__/prompt_template.cpython-311.pyc +0 -0
- src/models/__pycache__/prompt_template.cpython-39.pyc +0 -0
- src/models/llm_def.py +80 -13
- src/models/prompt_example.py +7 -7
- src/models/prompt_template.py +22 -1
- src/models/vllm_serve.py +34 -0
- src/modules/__pycache__/__init__.cpython-311.pyc +0 -0
- src/modules/__pycache__/__init__.cpython-39.pyc +0 -0
- src/modules/__pycache__/extraction_agent.cpython-311.pyc +0 -0
- src/modules/__pycache__/extraction_agent.cpython-39.pyc +0 -0
- src/modules/__pycache__/reflection_agent.cpython-311.pyc +0 -0
- src/modules/__pycache__/reflection_agent.cpython-39.pyc +0 -0
- src/modules/__pycache__/schema_agent.cpython-311.pyc +0 -0
- src/modules/__pycache__/schema_agent.cpython-39.pyc +0 -0
- src/modules/extraction_agent.py +28 -7
- src/modules/knowledge_base/__pycache__/case_repository.cpython-311.pyc +0 -0
- src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc +0 -0
- src/modules/knowledge_base/__pycache__/schema_repository.cpython-311.pyc +0 -0
- src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc +0 -0
- src/modules/knowledge_base/case_repository.py +135 -336
- src/modules/knowledge_base/schema_repository.py +1 -1
- src/modules/schema_agent.py +0 -3
- src/pipeline.py +43 -23
- src/run.py +21 -67
- src/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- src/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- src/utils/__pycache__/data_def.cpython-311.pyc +0 -0
- src/utils/__pycache__/data_def.cpython-39.pyc +0 -0
- src/utils/__pycache__/process.cpython-311.pyc +0 -0
- src/utils/__pycache__/process.cpython-39.pyc +0 -0
- src/utils/data_def.py +0 -1
- src/utils/process.py +58 -1
- src/{main.py → webui.py} +1 -5
- src/webui/__init__.py +0 -1
- src/webui/__pycache__/__init__.cpython-39.pyc +0 -0
README.md CHANGED

```diff
@@ -1,13 +1,12 @@
 ---
 title: OneKE
 emoji: 👌🏻
-colorFrom:
+colorFrom: blue
 colorTo: indigo
 sdk: gradio
 sdk_version: 5.8.0
 app_file: app.py
 pinned: false
 license: mit
-
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+short_description: Schema-Guided LLM Agent-based Knowledge Extraction System
+---
```
app.py CHANGED

```diff
@@ -1,7 +1,6 @@
-import nltk
 import subprocess
-
+import nltk
 nltk.download('punkt')
 nltk.download('punkt_tab')
 
-subprocess.run(["python", "src/main.py"])
+subprocess.run(["python", "src/webui.py"])
```

(The tail of the removed `subprocess.run` line is truncated in the rendered diff; `src/main.py` is inferred from the `src/{main.py → webui.py}` rename in this same commit.)
src/__pycache__/pipeline.cpython-311.pyc DELETED
Binary file (5.34 kB)

src/__pycache__/pipeline.cpython-39.pyc CHANGED
Binary files a/src/__pycache__/pipeline.cpython-39.pyc and b/src/__pycache__/pipeline.cpython-39.pyc differ
src/config.yaml CHANGED

```diff
@@ -1,3 +1,6 @@
+model:
+  embedding_model: all-MiniLM-L6-v2
+
 agent:
   default_schema: The final extraction result should be formatted as a JSON object.
   default_ner: Extract the Named Entities in the given text.
@@ -15,5 +18,3 @@ agent:
   customized:
     schema_agent: get_retrieved_schema
     extraction_agent: extract_information_direct
-
-
```
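The new `model.embedding_model` key is consumed by the restored `CaseRepository` (see `case_repository.py` below), which falls back to it when the embedding model bundled into the Docker image is absent. A minimal sketch of that resolution order, assuming the YAML is loaded with PyYAML and the script runs from the repo root:

```python
import yaml
from sentence_transformers import SentenceTransformer

with open("src/config.yaml") as f:  # path relative to the repo root (assumption)
    config = yaml.safe_load(f)

# Mirrors CaseRepository.__init__ in this commit: prefer the embedding model
# baked into the Docker image, else download it by the configured name.
docker_model_path = "/app/model/all-MiniLM-L6-v2"
try:
    embedder = SentenceTransformer(docker_model_path)
except Exception:
    embedder = SentenceTransformer(config["model"]["embedding_model"])
```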
src/generate_memory.py DELETED

The commit removes a one-off seeding script (181 lines) that pushed a hand-written "Ocean Festival" example into the case repository. The original file repeated the demo story literal twice (once for `data.instruction`, once for `data.chunk_text_list`); it is kept once here and reused, which is behavior-identical:

```python
from typing import Literal
from models import *
from utils import *
from modules import *


class Pipeline:
    def __init__(self, llm: BaseEngine):
        self.llm = llm
        self.case_repo = CaseRepositoryHandler(llm = llm)
        self.schema_agent = SchemaAgent(llm = llm)
        self.extraction_agent = ExtractionAgent(llm = llm, case_repo = self.case_repo)
        self.reflection_agent = ReflectionAgent(llm = llm, case_repo = self.case_repo)

    def __init_method(self, data: DataPoint, process_method):
        default_order = ["schema_agent", "extraction_agent", "reflection_agent"]
        if "schema_agent" not in process_method:
            process_method["schema_agent"] = "get_default_schema"
        if data.task != "Base":
            process_method["schema_agent"] = "get_retrieved_schema"
        if "extraction_agent" not in process_method:
            process_method["extraction_agent"] = "extract_information_direct"
        sorted_process_method = {key: process_method[key] for key in default_order if key in process_method}
        return sorted_process_method

    def __init_data(self, data: DataPoint):
        if data.task == "NER":
            data.instruction = config['agent']['default_ner']
            data.output_schema = "EntityList"
        elif data.task == "RE":
            data.instruction = config['agent']['default_re']
            data.output_schema = "RelationList"
        elif data.task == "EE":
            data.instruction = config['agent']['default_ee']
            data.output_schema = "EventList"
        return data

    # main entry
    def get_extract_result(self,
                           task: TaskType,
                           instruction: str = "",
                           text: str = "",
                           output_schema: str = "",
                           constraint: str = "",
                           use_file: bool = False,
                           truth: str = "",
                           mode: str = "quick",
                           update_case: bool = False
                           ):

        data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, truth=truth)
        data = self.__init_data(data)
        # Hard-coded demo document used to seed the case repository.
        data.instruction = "In the tranquil seaside town, the summer evening cast a golden glow over everything. The townsfolk gathered at the café by the pier, enjoying the sea breeze while eagerly anticipating the annual Ocean Festival's opening ceremony. \nFirst to arrive was Mayor William, dressed in a deep blue suit, holding a roll of his speech. He smiled and greeted the residents, who held deep respect for their community-minded mayor. Beside him trotted Max, his loyal golden retriever, wagging his tail excitedly at every familiar face he saw. \nFollowing closely was Emily, the town’s high school teacher, accompanied by a group of students ready to perform a musical piece they'd rehearsed. One of the girls carried Polly, a vibrant green parrot, on her shoulder. Polly occasionally chimed in with cheerful squawks, adding to the lively atmosphere. \nNot far away, Captain Jack, with his trusty pipe in hand, chatted with old friends about this year's catch. His fleet was the town’s economic backbone, and his seasoned face and towering presence were complemented by the presence of Whiskers, his orange tabby cat, who loved lounging on the dock, attentively watching the gentle waves. \nInside the café, Kate was bustling about, serving guests. As the owner, with her fiery red curls and vivacious spirit, she was the heart of the place. Her friend Susan, an artist living in a tiny cottage nearby, was helping her prepare refreshing beverages. Slinky, Susan's mischievous ferret, darted playfully between the tables, much to the delight of the children present. \nLeaning on the café's railing, a young boy named Tommy watched the sea with wide, gleaming eyes, filled with dreams of the future. By his side sat Daisy, a spirited little dachshund, barking excitedly at the seagulls flying overhead. Tommy's mother, Lucy, stood beside him, smiling softly as she held a seashell he had just found on the beach. \nAmong the crowd, a group of unnamed tourists snapped photos, capturing memories of the charming festival. Street vendors called out, selling their wares—handmade jewelry and sweet confections—as the scent of grilled seafood wafted through the air. \nSuddenly, a burst of laughter erupted—it was James and his band making their grand entrance. Accompanying them was Benny, a friendly border collie who \"performed\" with the band, delighting the crowd with his antics. Set to play a big concert after the opening ceremony, James, the town's star musician, had won the hearts of locals with his soulful tunes. \nAs dusk settled, lights were strung across the streets, casting a magical glow over the town. Mayor William took the stage to deliver his speech, with Max sitting proudly by his side. The festival atmosphere reached its vibrant peak, and in this small town, each person—and animal—carried their own dreams and stories, yet at this moment, they were united by the shared celebration."
        data.chunk_text_list.append(data.instruction)
        data.distilled_text = "This text is from the field of Slice of Life and represents the genre of Novel."
        data.pred = {
            "characters": [
                {"name": "Mayor William", "role": "Mayor"},
                {"name": "Max", "role": "Golden Retriever, Mayor William's dog"},
                {"name": "Emily", "role": "High school teacher"},
                {"name": "Polly", "role": "Parrot, accompanying a student"},
                {"name": "Captain Jack", "role": "Captain"},
                {"name": "Whiskers", "role": "Orange tabby cat, Captain Jack's pet"},
                {"name": "Kate", "role": "Café owner"},
                {"name": "Susan", "role": "Artist, Kate's friend"},
                {"name": "Slinky", "role": "Ferret, Susan's pet"},
                {"name": "Tommy", "role": "Young boy"},
                {"name": "Daisy", "role": "Dachshund, Tommy's pet"},
                {"name": "Lucy", "role": "Tommy's mother"},
                {"name": "James", "role": "Musician, band leader"},
                {"name": "Benny", "role": "Border Collie, accompanying James and his band"},
                {"name": "Unnamed Tourists", "role": "Visitors at the festival"},
                {"name": "Street Vendors", "role": "Sellers at the festival"}
            ]
        }

        data.truth = {
            "characters": [
                {"name": "Mayor William", "role": "The friendly and respected mayor of the seaside town."},
                {"name": "Emily", "role": "A high school teacher guiding students in a festival performance."},
                {"name": "Captain Jack", "role": "A seasoned sailor whose fleet supports the town."},
                {"name": "Kate", "role": "The welcoming owner of the local café."},
                {"name": "Susan", "role": "An artist known for her ocean-themed paintings."},
                {"name": "Tommy", "role": "A young boy with dreams of the sea."},
                {"name": "Lucy", "role": "Tommy's caring and supportive mother."},
                {"name": "James", "role": "A charismatic musician and band leader."}
            ]
        }

        # Case Update
        if update_case:
            if (data.truth == ""):  # never true here, since data.truth was just hard-coded above
                truth = input("Please enter the correct answer you prefer, or press Enter to accept the current answer: ")
                if truth.strip() == "":
                    data.truth = data.pred
                else:
                    data.truth = extract_json_dict(truth)
            self.case_repo.update_case(data)

        # return result
        result = data.pred
        trajectory = data.get_result_trajectory()

        return result, trajectory, "a", "b"


model = DeepSeek(model_name_or_path="deepseek-chat", api_key="")
pipeline = Pipeline(model)
result, trajectory, *_ = pipeline.get_extract_result(update_case=True, task="Base")
```
src/models/__init__.py CHANGED

```diff
@@ -1,3 +1,3 @@
-from .llm_def import
+from .llm_def import *
 from .prompt_example import *
 from .prompt_template import *
```
src/models/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (434 Bytes)

src/models/__pycache__/__init__.cpython-37.pyc DELETED
Binary file (315 Bytes)

src/models/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/src/models/__pycache__/__init__.cpython-39.pyc and b/src/models/__pycache__/__init__.cpython-39.pyc differ

src/models/__pycache__/llm_def.cpython-311.pyc DELETED
Binary file (11.8 kB)

src/models/__pycache__/llm_def.cpython-37.pyc DELETED
Binary file (7.14 kB)

src/models/__pycache__/llm_def.cpython-39.pyc CHANGED
Binary files a/src/models/__pycache__/llm_def.cpython-39.pyc and b/src/models/__pycache__/llm_def.cpython-39.pyc differ

src/models/__pycache__/prompt_example.cpython-311.pyc DELETED
Binary file (5.67 kB)

src/models/__pycache__/prompt_example.cpython-39.pyc CHANGED
Binary files a/src/models/__pycache__/prompt_example.cpython-39.pyc and b/src/models/__pycache__/prompt_example.cpython-39.pyc differ

src/models/__pycache__/prompt_template.cpython-311.pyc DELETED
Binary file (5.42 kB)

src/models/__pycache__/prompt_template.cpython-39.pyc CHANGED
Binary files a/src/models/__pycache__/prompt_template.cpython-39.pyc and b/src/models/__pycache__/prompt_template.cpython-39.pyc differ
src/models/llm_def.py CHANGED

```diff
@@ -6,7 +6,7 @@ Supports:
 """
 
 from transformers import pipeline
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, GenerationConfig
 import torch
 import openai
 import os
@@ -21,7 +21,8 @@ class BaseEngine:
         self.temperature = 0.2
         self.top_p = 0.9
         self.max_tokens = 1024
-
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
     def get_chat_response(self, prompt):
         raise NotImplementedError
 
@@ -29,7 +30,7 @@
         self.temperature = temperature
         self.top_p = top_p
         self.max_tokens = max_tokens
-
+
 class LLaMA(BaseEngine):
     def __init__(self, model_name_or_path: str):
         super().__init__(model_name_or_path)
@@ -60,7 +61,7 @@ class LLaMA(BaseEngine):
             top_p=self.top_p,
         )
         return outputs[0]["generated_text"][-1]['content'].strip()
-
+
 class Qwen(BaseEngine):
     def __init__(self, model_name_or_path: str):
         super().__init__(model_name_or_path)
@@ -71,7 +72,7 @@ class Qwen(BaseEngine):
             torch_dtype="auto",
             device_map="auto"
         )
-
+
     def get_chat_response(self, prompt):
         messages = [
             {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
@@ -82,7 +83,7 @@ class Qwen(BaseEngine):
             tokenize=False,
             add_generation_prompt=True
         )
-        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.
+        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
         generated_ids = self.model.generate(
             **model_inputs,
             temperature=self.temperature,
@@ -93,7 +94,7 @@ class Qwen(BaseEngine):
             output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
         ]
         response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-
+
         return response
 
 class MiniCPM(BaseEngine):
@@ -113,7 +114,7 @@ class MiniCPM(BaseEngine):
             {"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": prompt}
         ]
-        model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(self.
+        model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(self.device)
         model_outputs = self.model.generate(
             model_inputs,
             temperature=self.temperature,
@@ -124,7 +125,7 @@ class MiniCPM(BaseEngine):
             model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs))
         ]
         response = self.tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0].strip()
-
+
         return response
 
 class ChatGLM(BaseEngine):
@@ -145,7 +146,7 @@ class ChatGLM(BaseEngine):
             {"role": "system", "content": "You are a helpful assistant."},
             {"role": "user", "content": prompt}
         ]
-        model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True, tokenize=True).to(self.
+        model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True, tokenize=True).to(self.device)
         model_outputs = self.model.generate(
             **model_inputs,
             temperature=self.temperature,
@@ -154,9 +155,45 @@ class ChatGLM(BaseEngine):
         )
         model_outputs = model_outputs[:, model_inputs['input_ids'].shape[1]:]
         response = self.tokenizer.batch_decode(model_outputs, skip_special_tokens=True)[0].strip()
-
+
         return response
 
+class OneKE(BaseEngine):
+    def __init__(self, model_name_or_path: str):
+        super().__init__(model_name_or_path)
+        self.name = "OneKE"
+        self.model_id = model_name_or_path
+        config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
+        quantization_config=BitsAndBytesConfig(
+            load_in_4bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+            bnb_4bit_compute_dtype=torch.bfloat16,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_id,
+            config=config,
+            device_map="auto",
+            quantization_config=quantization_config,
+            torch_dtype=torch.bfloat16,
+            trust_remote_code=True,
+        )
+
+    def get_chat_response(self, prompt):
+        system_prompt = '<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n'
+        sintruct = '[INST] ' + system_prompt + prompt + '[/INST]'
+        input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
+        input_ids = self.tokenizer.encode(sintruct, return_tensors="pt").to(self.device)
+        input_length = input_ids.size(1)
+        generation_output = self.model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_length=1024, max_new_tokens=512, return_dict_in_generate=True, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id))
+        generation_output = generation_output.sequences[0]
+        generation_output = generation_output[input_length:]
+        response = self.tokenizer.decode(generation_output, skip_special_tokens=True)
+
+        return response
+
 class ChatGPT(BaseEngine):
     def __init__(self, model_name_or_path: str, api_key: str, base_url=openai.base_url):
         self.name = "ChatGPT"
@@ -170,7 +207,7 @@ class ChatGPT(BaseEngine):
         else:
             self.api_key = os.environ["OPENAI_API_KEY"]
         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
-
+
     def get_chat_response(self, input):
         response = self.client.chat.completions.create(
             model=self.model,
@@ -197,7 +234,7 @@ class DeepSeek(BaseEngine):
         else:
             self.api_key = os.environ["DEEPSEEK_API_KEY"]
         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
-
+
     def get_chat_response(self, input):
         response = self.client.chat.completions.create(
             model=self.model,
@@ -210,3 +247,33 @@ class DeepSeek(BaseEngine):
             stop=None
         )
         return response.choices[0].message.content
+
+class LocalServer(BaseEngine):
+    def __init__(self, model_name_or_path: str, base_url="http://localhost:8000/v1"):
+        self.name = model_name_or_path.split('/')[-1]
+        self.model = model_name_or_path
+        self.base_url = base_url
+        self.temperature = 0.2
+        self.top_p = 0.9
+        self.max_tokens = 1024
+        self.api_key = "EMPTY_API_KEY"
+        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+
+    def get_chat_response(self, input):
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {"role": "user", "content": input},
+                ],
+                stream=False,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                stop=None
+            )
+            return response.choices[0].message.content
+        except ConnectionError:
+            print("Error: Unable to connect to the server. Please check if the vllm service is running and the port is 8080.")
+        except Exception as e:
+            print(f"Error: {e}")
+
```
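Two additions carry the commit's main feature: `OneKE` loads an extraction-tuned checkpoint with 4-bit NF4 quantization, and `LocalServer` is an OpenAI-compatible client for a vLLM endpoint. A minimal usage sketch; the model ids are assumptions, `OneKE` needs a CUDA GPU with `bitsandbytes` installed, and it should be run from `src/` so the `models` package resolves:

```python
from models.llm_def import OneKE, LocalServer

# 4-bit quantized local inference. "zjunlp/OneKE" is an assumption about the
# intended checkpoint; any compatible local path can be passed instead.
engine = OneKE(model_name_or_path="zjunlp/OneKE")
print(engine.get_chat_response("Extract the named entities: Mayor William greeted the residents."))

# Client for a vLLM server such as the one started by src/models/vllm_serve.py
# (which listens on port 8000, matching the default base_url here).
server = LocalServer(model_name_or_path="Qwen/Qwen2.5-7B-Instruct")
print(server.get_chat_response("Hello"))
```

Note that `LocalServer.get_chat_response` defaults to port 8000 while its connection-error message mentions port 8080; that mismatch is in the committed code itself.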
src/models/prompt_example.py CHANGED

````diff
@@ -95,13 +95,13 @@ class Event(BaseModel):
     process: Optional[str] = Field(description="Details of the event process")
     result: Optional[str] = Field(default=None, description="Result or outcome of the event")
 
-class
-    title: str = Field(description="The title or headline of the news
-    summary: str = Field(description="A brief summary of the news
-    publication_date: Optional[str] = Field(description="The publication date of the
-    keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the
-    events: List[Event] = Field(description="Events covered in the
-    quotes: Optional[
+class NewsReport(BaseModel):
+    title: str = Field(description="The title or headline of the news report")
+    summary: str = Field(description="A brief summary of the news report")
+    publication_date: Optional[str] = Field(description="The publication date of the report")
+    keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the news report")
+    events: List[Event] = Field(description="Events covered in the news report")
+    quotes: Optional[dict] = Field(default=None, description="Quotes related to the news, with keys as the citation sources and values as the quoted content. ")
     viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
 ```
````
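In the file these classes live inside a prompt string, but they are written as valid pydantic models. For reference, a sketch of how the completed `NewsReport` schema behaves as real code; the `Event` class is trimmed to the two fields visible in the diff context, defaults are added so optional fields validate, and the sample values are invented:

```python
from typing import List, Optional
from pydantic import BaseModel, Field

class Event(BaseModel):
    # Trimmed to the two fields shown in the diff context above.
    process: Optional[str] = Field(default=None, description="Details of the event process")
    result: Optional[str] = Field(default=None, description="Result or outcome of the event")

class NewsReport(BaseModel):
    title: str = Field(description="The title or headline of the news report")
    summary: str = Field(description="A brief summary of the news report")
    publication_date: Optional[str] = Field(default=None, description="The publication date of the report")
    keywords: Optional[List[str]] = Field(default=None, description="Keywords or topics covered in the news report")
    events: List[Event] = Field(description="Events covered in the news report")
    quotes: Optional[dict] = Field(default=None, description="Quotes keyed by citation source")
    viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")

report = NewsReport(
    title="Ocean Festival opens",
    summary="The seaside town kicks off its annual festival.",
    events=[Event(process="Opening ceremony", result="Festival begins")],
    quotes={"Mayor William": "Welcome, everyone!"},
)
print(report.model_dump())  # pydantic v2; use report.dict() on v1
```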
src/models/prompt_template.py CHANGED

```diff
@@ -76,6 +76,25 @@ extract_instruction = PromptTemplate(
     template=EXTRACT_INSTRUCTION,
 )
 
+instruction_mapper = {
+    'NER': "You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string.",
+    'RE': "You are an expert in relationship extraction. Please extract relationship triples that match the schema definition from the input. Return an empty list for relationships that do not exist. Please respond in the format of a JSON string.",
+    'EE': "You are an expert in event extraction. Please extract events from the input that conform to the schema definition. Return an empty list for events that do not exist, and return NAN for arguments that do not exist. If an argument has multiple values, please return a list. Respond in the format of a JSON string.",
+}
+
+EXTRACT_INSTRUCTION_JSON = """
+{{
+    "instruction": {instruction},
+    "schema": {constraint},
+    "input": {input},
+}}
+"""
+
+extract_instruction_json = PromptTemplate(
+    input_variables=["instruction", "constraint", "input"],
+    template=EXTRACT_INSTRUCTION_JSON,
+)
+
 SUMMARIZE_INSTRUCTION = """
 **Instruction**: Below is a list of results obtained after segmenting and extracting information from a long article. Please consolidate all the answers to generate a final response.
 {examples}
@@ -84,7 +103,7 @@ SUMMARIZE_INSTRUCTION = """
 **Result List**: {answer_list}
 
 **Output Schema**: {schema}
-Now summarize all the information from the Result List.
+Now summarize all the information from the Result List. Filter or merge the redundant information.
 """
 summarize_instruction = PromptTemplate(
     input_variables=["instruction", "examples", "answer_list", "schema"],
@@ -92,6 +111,8 @@ summarize_instruction = PromptTemplate(
 )
 
 
+
+
 # ==================================================================== #
 #                            REFLECION AGENT                           #
 # ==================================================================== #
```
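The JSON-style template is the native input format of the OneKE checkpoint: instruction, schema constraint, and input are serialized into one JSON object. A quick render, assuming the same LangChain `PromptTemplate` used elsewhere in the file (the import sits outside the visible hunk) and invented sample values:

```python
import json
from langchain.prompts import PromptTemplate

EXTRACT_INSTRUCTION_JSON = """
{{
    "instruction": {instruction},
    "schema": {constraint},
    "input": {input},
}}
"""
prompt = PromptTemplate(
    input_variables=["instruction", "constraint", "input"],
    template=EXTRACT_INSTRUCTION_JSON,
)

# json.dumps quotes the values so the rendered prompt stays valid JSON-ish text.
print(prompt.format(
    instruction=json.dumps(instruction_mapper_text := "You are an expert in named entity recognition. ..."),
    constraint=json.dumps(["person", "location"]),
    input=json.dumps("Mayor William greeted the residents."),
))
```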
src/models/vllm_serve.py ADDED

```python
import argparse
import warnings
import subprocess
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils import *

def main():
    # Create command-line argument parser
    parser = argparse.ArgumentParser(description='Run the extraction model.')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to the YAML configuration file.')
    parser.add_argument('--tensor-parallel-size', type=int, default=2,
                        help='Tensor parallel size for the VLLM server.')
    parser.add_argument('--max-model-len', type=int, default=32768,
                        help='Maximum model length for the VLLM server.')

    # Parse command-line arguments
    args = parser.parse_args()

    # Load configuration
    config = load_extraction_config(args.config)
    # Model config
    model_config = config['model']
    if model_config['vllm_serve'] == False:
        warnings.warn("VLLM-deployed model will not be used for extraction. To enable VLLM, set vllm_serve to true in the configuration file.")
    model_name_or_path = model_config['model_name_or_path']
    command = f"vllm serve {model_name_or_path} --tensor-parallel-size {args.tensor_parallel_size} --max-model-len {args.max_model_len} --enforce-eager --port 8000"
    subprocess.run(command, shell=True)

if __name__ == "__main__":
    main()
```
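The script reads exactly two keys from the YAML passed via `--config`: `model.model_name_or_path` and `model.vllm_serve`. Those keys do not appear in the `src/config.yaml` diff above, so this snippet is an illustrative guess at the expected shape:

```yaml
# extraction config for vllm_serve.py (illustrative; only the two keys
# read by the script above are certain)
model:
  model_name_or_path: Qwen/Qwen2.5-7B-Instruct
  vllm_serve: true
```

Launching it with `python src/models/vllm_serve.py --config extraction.yaml --tensor-parallel-size 1` starts `vllm serve` on port 8000, which is where the new `LocalServer` engine in `llm_def.py` expects its OpenAI-compatible endpoint.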
src/modules/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (459 Bytes)

src/modules/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/src/modules/__pycache__/__init__.cpython-39.pyc and b/src/modules/__pycache__/__init__.cpython-39.pyc differ

src/modules/__pycache__/extraction_agent.cpython-311.pyc DELETED
Binary file (6.66 kB)

src/modules/__pycache__/extraction_agent.cpython-39.pyc CHANGED
Binary files a/src/modules/__pycache__/extraction_agent.cpython-39.pyc and b/src/modules/__pycache__/extraction_agent.cpython-39.pyc differ

src/modules/__pycache__/reflection_agent.cpython-311.pyc DELETED
Binary file (6.98 kB)

src/modules/__pycache__/reflection_agent.cpython-39.pyc CHANGED
Binary files a/src/modules/__pycache__/reflection_agent.cpython-39.pyc and b/src/modules/__pycache__/reflection_agent.cpython-39.pyc differ

src/modules/__pycache__/schema_agent.cpython-311.pyc DELETED
Binary file (10.7 kB)

src/modules/__pycache__/schema_agent.cpython-39.pyc CHANGED
Binary files a/src/modules/__pycache__/schema_agent.cpython-39.pyc and b/src/modules/__pycache__/schema_agent.cpython-39.pyc differ
src/modules/extraction_agent.py CHANGED

```diff
@@ -11,9 +11,13 @@ class InformationExtractor:
         prompt = extract_instruction.format(instruction=instruction, examples=examples, text=text, additional_info=additional_info, schema=schema)
         response = self.llm.get_chat_response(prompt)
         response = extract_json_dict(response)
-
-
-
+        return response
+
+    def extract_information_compatible(self, task="", text="", constraint=""):
+        instruction = instruction_mapper.get(task)
+        prompt = extract_instruction_json.format(instruction=instruction, constraint=constraint, input=text)
+        response = self.llm.get_chat_response(prompt)
+        response = extract_json_dict(response)
         return response
 
     def summarize_answer(self, instruction="", answer_list="", schema="", additional_info=""):
@@ -34,26 +38,43 @@ class ExtractionAgent:
             return data
         if data.task == "NER":
             constraint = json.dumps(data.constraint)
-            if "**Entity Type Constraint**" in constraint:
+            if "**Entity Type Constraint**" in constraint or self.llm.name == "OneKE":
                 return data
             data.constraint = f"\n**Entity Type Constraint**: The type of entities must be chosen from the following list.\n{constraint}\n"
         elif data.task == "RE":
             constraint = json.dumps(data.constraint)
-            if "**Relation Type Constraint**" in constraint:
+            if "**Relation Type Constraint**" in constraint or self.llm.name == "OneKE":
                 return data
             data.constraint = f"\n**Relation Type Constraint**: The type of relations must be chosen from the following list.\n{constraint}\n"
         elif data.task == "EE":
             constraint = json.dumps(data.constraint)
             if "**Event Extraction Constraint**" in constraint:
                 return data
-
+            if self.llm.name != "OneKE":
+                data.constraint = f"\n**Event Extraction Constraint**: The event type must be selected from the following dictionary keys, and its event arguments should be chosen from its corresponding dictionary values. \n{constraint}\n"
+            else:
+                try:
+                    result = [
+                        {
+                            "event_type": key,
+                            "trigger": True,
+                            "arguments": value
+                        }
+                        for key, value in data.constraint.items()
+                    ]
+                    data.constraint = json.dumps(result)
+                except:
+                    print("Invalid Constraint: Event Extraction constraint must be a dictionary with event types as keys and lists of arguments as values.", data.constraint)
         return data
 
     def extract_information_direct(self, data: DataPoint):
         data = self.__get_constraint(data)
         result_list = []
         for chunk_text in data.chunk_text_list:
-
+            if self.llm.name != "OneKE":
+                extract_direct_result = self.module.extract_information(instruction=data.instruction, text=chunk_text, schema=data.output_schema, examples="", additional_info=data.constraint)
+            else:
+                extract_direct_result = self.module.extract_information_compatible(task=data.task, text=chunk_text, constraint=data.constraint)
             result_list.append(extract_direct_result)
         function_name = current_function_name()
         data.set_result_list(result_list)
```

(The bodies of the removed lines in this file are blanked by the renderer; they are shown here as bare `-` lines.)
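The OneKE branch of `__get_constraint` rewrites a user-supplied EE constraint from a dict of event types into the list-of-event-objects form the checkpoint expects. A worked example with invented event types:

```python
import json

# User-supplied EE constraint: event types mapped to argument roles.
constraint = {
    "Marriage": ["bride", "groom", "date"],
    "Resignation": ["person", "position"],
}

# What the OneKE branch produces before serializing into the JSON prompt:
result = [
    {"event_type": key, "trigger": True, "arguments": value}
    for key, value in constraint.items()
]
print(json.dumps(result, indent=2))
# [
#   {"event_type": "Marriage", "trigger": true, "arguments": ["bride", "groom", "date"]},
#   {"event_type": "Resignation", "trigger": true, "arguments": ["person", "position"]}
# ]
```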
src/modules/knowledge_base/__pycache__/case_repository.cpython-311.pyc DELETED
Binary file (4.64 kB)

src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc CHANGED
Binary files a/src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc and b/src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc differ

src/modules/knowledge_base/__pycache__/schema_repository.cpython-311.pyc DELETED
Binary file (9.25 kB)

src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc CHANGED
Binary files a/src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc and b/src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc differ
src/modules/knowledge_base/case_repository.py
CHANGED
@@ -1,192 +1,3 @@
|
|
1 |
-
# import json
|
2 |
-
# import os
|
3 |
-
# import torch
|
4 |
-
# import numpy as np
|
5 |
-
# from utils import *
|
6 |
-
# from sentence_transformers import SentenceTransformer
|
7 |
-
# from rapidfuzz import process
|
8 |
-
# from models import *
|
9 |
-
# import copy
|
10 |
-
|
11 |
-
# import warnings
|
12 |
-
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
-
# warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
|
14 |
-
|
15 |
-
# class CaseRepository:
|
16 |
-
# def __init__(self):
|
17 |
-
# self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
|
18 |
-
# self.embedder.to(device)
|
19 |
-
# self.corpus = self.load_corpus()
|
20 |
-
# self.embedded_corpus = self.embed_corpus()
|
21 |
-
|
22 |
-
# def load_corpus(self):
|
23 |
-
# with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
|
24 |
-
# corpus = json.load(file)
|
25 |
-
# return corpus
|
26 |
-
|
27 |
-
# def update_corpus(self):
|
28 |
-
# try:
|
29 |
-
# with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
|
30 |
-
# json.dump(self.corpus, file, indent=2)
|
31 |
-
# except Exception as e:
|
32 |
-
# print(f"Error when updating corpus: {e}")
|
33 |
-
|
34 |
-
# def embed_corpus(self):
|
35 |
-
# embedded_corpus = {}
|
36 |
-
# for key, content in self.corpus.items():
|
37 |
-
# good_index = [item['index']['embed_index'] for item in content['good']]
|
38 |
-
# encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
|
39 |
-
# bad_index = [item['index']['embed_index'] for item in content['bad']]
|
40 |
-
# encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
|
41 |
-
# embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
|
42 |
-
# return embedded_corpus
|
43 |
-
|
44 |
-
# def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
|
45 |
-
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
46 |
-
# # Embedding similarity match
|
47 |
-
# encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
|
48 |
-
# embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
|
49 |
-
# embedding_similarity_scores = embedding_similarity_matrix[0].to(device)
|
50 |
-
|
51 |
-
# # String similarity match
|
52 |
-
# str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
|
53 |
-
# str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
|
54 |
-
# scores_dict = {match[0]: match[1] for match in str_similarity_results}
|
55 |
-
# scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
|
56 |
-
# str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
|
57 |
-
|
58 |
-
# # Normalize scores
|
59 |
-
# embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
|
60 |
-
# str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
|
61 |
-
# if embedding_score_range > 0:
|
62 |
-
# embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
|
63 |
-
# else:
|
64 |
-
# embed_norm_scores = embedding_similarity_scores
|
65 |
-
# if str_score_range > 0:
|
66 |
-
# str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
|
67 |
-
# else:
|
68 |
-
# str_norm_scores = str_similarity_scores / 100
|
69 |
-
|
70 |
-
# # Combine the scores with weights
|
71 |
-
# combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
|
72 |
-
# original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
|
73 |
-
|
74 |
-
# scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
|
75 |
-
# original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
|
76 |
-
# return scores, indices, original_scores, original_indices
|
77 |
-
|
78 |
-
# def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
|
79 |
-
# _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
|
80 |
-
# top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
|
81 |
-
# return top_matches
|
82 |
-
|
83 |
-
# def update_case(self, task: TaskType, embed_index="", str_index="", content="" ,case_type=""):
|
84 |
-
# self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
|
85 |
-
# self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
|
86 |
-
# print(f"Case updated for {task} task.")
|
87 |
-
|
88 |
-
# class CaseRepositoryHandler:
|
89 |
-
# def __init__(self, llm: BaseEngine):
|
90 |
-
# self.repository = CaseRepository()
|
91 |
-
# self.llm = llm
|
92 |
-
|
93 |
-
# def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
|
94 |
-
# prompt = good_case_analysis_instruction.format(
|
95 |
-
# instruction=instruction, text=text, result=result, additional_info=additional_info
|
96 |
-
# )
|
97 |
-
# for _ in range(3):
|
98 |
-
# response = self.llm.get_chat_response(prompt)
|
99 |
-
# response = extract_json_dict(response)
|
100 |
-
# if not isinstance(response, dict):
|
101 |
-
# return response
|
102 |
-
# return None
|
103 |
-
|
104 |
-
# def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
|
105 |
-
# prompt = bad_case_reflection_instruction.format(
|
106 |
-
# instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
|
107 |
-
# )
|
108 |
-
# for _ in range(3):
|
109 |
-
# response = self.llm.get_chat_response(prompt)
|
110 |
-
# response = extract_json_dict(response)
|
111 |
-
# if not isinstance(response, dict):
|
112 |
-
# return response
|
113 |
-
# return None
|
114 |
-
|
115 |
-
# def __get_index(self, data: DataPoint, case_type: str):
|
116 |
-
# # set embed_index
|
117 |
-
# embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
118 |
-
|
119 |
-
# # set str_index
|
120 |
-
# if data.task == "Base":
|
121 |
-
# str_index = f"**Task**: {data.instruction}"
|
122 |
-
# else:
|
123 |
-
# str_index = f"{data.constraint}"
|
124 |
-
|
125 |
-
# if case_type == "bad":
|
126 |
-
# str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
|
127 |
-
|
128 |
-
# return embed_index, str_index
|
129 |
-
|
130 |
-
# def query_good_case(self, data: DataPoint):
|
131 |
-
# embed_index, str_index = self.__get_index(data, "good")
|
132 |
-
# return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
|
133 |
-
|
134 |
-
# def query_bad_case(self, data: DataPoint):
|
135 |
-
# embed_index, str_index = self.__get_index(data, "bad")
|
136 |
-
# return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
|
137 |
-
|
138 |
-
# def update_good_case(self, data: DataPoint):
|
139 |
-
# if data.truth == "" :
|
140 |
-
# print("No truth value provided.")
|
141 |
-
# return
|
142 |
-
# embed_index, str_index = self.__get_index(data, "good")
|
143 |
-
# _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
|
144 |
-
# original_scores = original_scores.tolist()
|
145 |
-
# if original_scores[0] >= 0.9:
|
146 |
-
# print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
|
147 |
-
# return
|
148 |
-
# good_case_alaysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
|
149 |
-
# wrapped_good_case_analysis = f"**Analysis**: {good_case_alaysis}"
|
150 |
-
# wrapped_instruction = f"**Task**: {data.instruction}"
|
151 |
-
# wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
152 |
-
# wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
|
153 |
-
# if data.task == "Base":
|
154 |
-
# content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
|
155 |
-
# else:
|
156 |
-
# content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
|
157 |
-
# self.repository.update_case(data.task, embed_index, str_index, content, "good")
|
158 |
-
|
159 |
-
# def update_bad_case(self, data: DataPoint):
|
160 |
-
# if data.truth == "" :
|
161 |
-
# print("No truth value provided.")
|
162 |
-
# return
|
163 |
-
# if normalize_obj(data.pred) == normalize_obj(data.truth):
|
164 |
-
# return
|
165 |
-
# embed_index, str_index = self.__get_index(data, "bad")
|
166 |
-
# _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
|
167 |
-
# original_scores = original_scores.tolist()
|
168 |
-
# if original_scores[0] >= 0.9:
|
169 |
-
# print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
|
170 |
-
# return
|
171 |
-
# bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
|
172 |
-
# wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
|
173 |
-
# wrapper_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
|
174 |
-
# wrapper_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
|
175 |
-
# wrapped_instruction = f"**Task**: {data.instruction}"
|
176 |
-
# wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
|
177 |
-
# if data.task == "Base":
|
178 |
-
# content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
|
179 |
-
# else:
|
180 |
-
# content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
|
181 |
-
# self.repository.update_case(data.task, embed_index, str_index, content, "bad")
|
182 |
-
|
183 |
-
# def update_case(self, data: DataPoint):
|
184 |
-
# self.update_good_case(data)
|
185 |
-
# self.update_bad_case(data)
|
186 |
-
# self.repository.update_corpus()
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
import json
|
191 |
import os
|
192 |
import torch
|
@@ -199,87 +10,84 @@ import copy
|
|
199 |
|
200 |
import warnings
|
201 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
202 |
warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
|
203 |
|
204 |
class CaseRepository:
|
205 |
def __init__(self):
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
|
|
|
|
211 |
|
212 |
def load_corpus(self):
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
pass
|
217 |
|
218 |
def update_corpus(self):
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
pass
|
225 |
|
226 |
def embed_corpus(self):
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
pass
|
236 |
|
237 |
def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
|
238 |
-
|
239 |
-
#
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
#
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
|
251 |
-
#
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
#
|
264 |
-
|
265 |
-
|
266 |
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
pass
|
271 |
|
272 |
def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
pass
|
277 |
|
278 |
def update_case(self, task: TaskType, embed_index="", str_index="", content="" ,case_type=""):
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
pass
|
283 |
|
284 |
class CaseRepositoryHandler:
|
285 |
def __init__(self, llm: BaseEngine):
|
@@ -287,105 +95,96 @@ class CaseRepositoryHandler:
         self.llm = llm
 
     def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
-        …
-        pass
 
     def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
-        …
-        pass
 
     def __get_index(self, data: DataPoint, case_type: str):
         # set embed_index
-        …
-        pass
 
     def query_good_case(self, data: DataPoint):
-        …
-        pass
 
     def query_bad_case(self, data: DataPoint):
-        …
-        pass
 
     def update_good_case(self, data: DataPoint):
-        …
-        pass
 
     def update_bad_case(self, data: DataPoint):
-        …
-        pass
 
     def update_case(self, data: DataPoint):
-        …
-        pass
-
 import json
 import os
 import torch
 
 import warnings
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+docker_model_path = "/app/model/all-MiniLM-L6-v2"
 warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
 
 class CaseRepository:
     def __init__(self):
+        try:
+            self.embedder = SentenceTransformer(docker_model_path)
+        except:
+            self.embedder = SentenceTransformer(config['model']['embedding_model'])
+        self.embedder.to(device)
+        self.corpus = self.load_corpus()
+        self.embedded_corpus = self.embed_corpus()
 
     def load_corpus(self):
+        with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
+            corpus = json.load(file)
+        return corpus
 
     def update_corpus(self):
+        try:
+            with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
+                json.dump(self.corpus, file, indent=2)
+        except Exception as e:
+            print(f"Error when updating corpus: {e}")
 
     def embed_corpus(self):
+        embedded_corpus = {}
+        for key, content in self.corpus.items():
+            good_index = [item['index']['embed_index'] for item in content['good']]
+            encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
+            bad_index = [item['index']['embed_index'] for item in content['bad']]
+            encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
+            embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
+        return embedded_corpus
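(Aside: from load_corpus, embed_corpus, and update_case, the persisted case_repository.json is expected to be a task-keyed dict of "good"/"bad" case lists. A minimal sketch of that layout follows; the task key and all values are illustrative, inferred from the code above rather than copied from the shipped file:)

    {
      "NER": {
        "good": [
          {
            "index": {
              "embed_index": "**Text**: <distilled text>\n<first chunk>",
              "str_index": "<constraint, or **Task**: <instruction> for Base tasks>"
            },
            "content": "<wrapped example later injected into prompts>"
          }
        ],
        "bad": []
      }
    }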
 
     def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        # Embedding similarity match
+        encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
+        embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
+        embedding_similarity_scores = embedding_similarity_matrix[0].to(device)
+
+        # String similarity match
+        str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
+        str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
+        scores_dict = {match[0]: match[1] for match in str_similarity_results}
+        scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
+        str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
+
+        # Normalize scores
+        embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
+        str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
+        if embedding_score_range > 0:
+            embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
+        else:
+            embed_norm_scores = embedding_similarity_scores
+        if str_score_range > 0:
+            str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
+        else:
+            str_norm_scores = str_similarity_scores / 100
+
+        # Combine the scores with weights
+        combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
+        original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
+
+        scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
+        original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
+        return scores, indices, original_scores, original_indices
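(Aside: the core of get_similarity_scores is easy to verify in isolation: min-max normalize the embedding-similarity and fuzzy-string score vectors, average them with equal weight, then take the top-k. A self-contained sketch with invented dummy scores:)

    import torch

    def combine_scores(embed_scores, str_scores, top_k=2):
        # Min-max normalize one score vector; fall back to raw values on zero range.
        def norm(t):
            r = t.max() - t.min()
            return (t - t.min()) / r if r > 0 else t
        # Equal-weight combination, then top-k, mirroring get_similarity_scores.
        combined = 0.5 * norm(embed_scores) + 0.5 * norm(str_scores)
        return torch.topk(combined, k=min(top_k, combined.size(0)))

    # Dummy inputs: cosine similarities in [0, 1], fuzzy-match scores in [0, 100].
    embed = torch.tensor([0.82, 0.41, 0.77])
    fuzzy = torch.tensor([55.0, 90.0, 60.0])
    print(combine_scores(embed, fuzzy))  # values and indices of the two best cases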
 
     def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
+        _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
+        top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
+        return top_matches
 
     def update_case(self, task: TaskType, embed_index="", str_index="", content="", case_type=""):
+        self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
+        self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
+        print(f"A {case_type} case updated for {task} task.")
 
 class CaseRepositoryHandler:
     def __init__(self, llm: BaseEngine):
         self.llm = llm
 
     def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
+        prompt = good_case_analysis_instruction.format(
+            instruction=instruction, text=text, result=result, additional_info=additional_info
+        )
+        for _ in range(3):
+            response = self.llm.get_chat_response(prompt)
+            response = extract_json_dict(response)
+            if not isinstance(response, dict):
+                return response
+        return None
 
     def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
+        prompt = bad_case_reflection_instruction.format(
+            instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
+        )
+        for _ in range(3):
+            response = self.llm.get_chat_response(prompt)
+            response = extract_json_dict(response)
+            if not isinstance(response, dict):
+                return response
+        return None
 
     def __get_index(self, data: DataPoint, case_type: str):
         # set embed_index
+        embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
 
+        # set str_index
+        if data.task == "Base":
+            str_index = f"**Task**: {data.instruction}"
+        else:
+            str_index = f"{data.constraint}"
 
+        if case_type == "bad":
+            str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
 
+        return embed_index, str_index
 
     def query_good_case(self, data: DataPoint):
+        embed_index, str_index = self.__get_index(data, "good")
+        return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
 
     def query_bad_case(self, data: DataPoint):
+        embed_index, str_index = self.__get_index(data, "bad")
+        return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
 
     def update_good_case(self, data: DataPoint):
+        if data.truth == "":
+            print("No truth value provided.")
+            return
+        embed_index, str_index = self.__get_index(data, "good")
+        _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
+        original_scores = original_scores.tolist()
+        if original_scores[0] >= 0.9:
+            print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
+            return
+        good_case_analysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
+        wrapped_good_case_analysis = f"**Analysis**: {good_case_analysis}"
+        wrapped_instruction = f"**Task**: {data.instruction}"
+        wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
+        wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
+        if data.task == "Base":
+            content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
+        else:
+            content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
+        self.repository.update_case(data.task, embed_index, str_index, content, "good")
 
     def update_bad_case(self, data: DataPoint):
+        if data.truth == "":
+            print("No truth value provided.")
+            return
+        if normalize_obj(data.pred) == normalize_obj(data.truth):
+            return
+        embed_index, str_index = self.__get_index(data, "bad")
+        _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
+        original_scores = original_scores.tolist()
+        if original_scores[0] >= 0.9:
+            print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
+            return
+        bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
+        wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
+        wrapper_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
+        wrapper_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
+        wrapped_instruction = f"**Task**: {data.instruction}"
+        wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
+        if data.task == "Base":
+            content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
+        else:
+            content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
+        self.repository.update_case(data.task, embed_index, str_index, content, "bad")
 
     def update_case(self, data: DataPoint):
+        self.update_good_case(data)
+        self.update_bad_case(data)
+        self.repository.update_corpus()
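(Aside: a hedged sketch of the full handler loop; `model` stands for any BaseEngine implementation and `data` for a DataPoint populated as in src/pipeline.py, so the names here are placeholders rather than a fixed API surface:)

    handler = CaseRepositoryHandler(llm=model)
    good_examples = handler.query_good_case(data)   # similar past successes, usable as prompt examples
    bad_examples = handler.query_bad_case(data)     # similar past failures, usable for reflection
    # ... run extraction, then set data.pred (and data.truth when available) ...
    handler.update_case(data)                       # appends good/bad cases and rewrites case_repository.json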
src/modules/knowledge_base/schema_repository.py
CHANGED
@@ -85,7 +85,7 @@ class NewsReport(BaseModel):
     publication_date: Optional[str] = Field(description="The publication date of the report")
     keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the news report")
     events: List[Event] = Field(description="Events covered in the news report")
-    quotes: Optional[…
+    quotes: Optional[dict] = Field(default=None, description="Quotes related to the news, with keys as the citation sources and values as the quoted content. ")
     viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
 
 # --------- You can customize new extraction schemas below -------- #
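(Aside: the reworked quotes field maps citation sources to quoted content. A tiny runnable stand-in that mirrors just this field, not the full NewsReport schema; the class name and values are invented:)

    from typing import Optional
    from pydantic import BaseModel, Field

    class QuotedReport(BaseModel):
        quotes: Optional[dict] = Field(default=None, description="Citation source -> quoted content")

    r = QuotedReport(quotes={"Jane Doe, CEO": "We expect the rollout to finish this quarter."})
    print(r.quotes)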
src/modules/schema_agent.py
CHANGED
@@ -48,9 +48,6 @@ class SchemaAnalyzer:
     def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
         prompt = deduced_schema_code_instruction.format(examples=example_wrapper(code_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
         response = self.llm.get_chat_response(prompt)
-        print(f"schema prompt: {prompt}")
-        print("========================================")
-        print(f"schema response: {response}")
         code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
         if code_blocks:
             try:
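(Aside: with the debug prints removed, the method depends only on the fenced-code regex that remains. A quick standalone check of that exact pattern against an invented response:)

    import re

    response = "Here is the schema:\n```python\nclass Person(BaseModel):\n    name: str\n```\nDone."
    code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
    print(code_blocks[0])  # -> class Person(BaseModel):\n    name: str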
src/pipeline.py
CHANGED
@@ -3,6 +3,7 @@ from models import *
 from utils import *
 from modules import *
 
+
 class Pipeline:
     def __init__(self, llm: BaseEngine):
         self.llm = llm
@@ -11,17 +12,26 @@ class Pipeline:
         self.extraction_agent = ExtractionAgent(llm = llm, case_repo = self.case_repo)
         self.reflection_agent = ReflectionAgent(llm = llm, case_repo = self.case_repo)
 
-    def …
+    def __check_consistancy(self, llm, task, mode, update_case):
+        if llm.name == "OneKE":
+            if task == "Base":
+                raise ValueError("The finetuned OneKE only supports quick extraction mode for NER, RE and EE Task.")
+            else:
+                mode = "quick"
+                update_case = False
+                print("The fine-tuned OneKE defaults to quick extraction mode without case update.")
+            return mode, update_case
+        return mode, update_case
+
+    def __init_method(self, data: DataPoint, process_method2):
         default_order = ["schema_agent", "extraction_agent", "reflection_agent"]
-        if "schema_agent" not in …
-        if data.task == "Base":
-            process_method["schema_agent"] = "get_deduced_schema"
+        if "schema_agent" not in process_method2:
+            process_method2["schema_agent"] = "get_default_schema"
         if data.task != "Base":
+            process_method2["schema_agent"] = "get_retrieved_schema"
-        if "extraction_agent" not in …
+        if "extraction_agent" not in process_method2:
+            process_method2["extraction_agent"] = "extract_information_direct"
-        sorted_process_method = {key: …
+        sorted_process_method = {key: process_method2[key] for key in default_order if key in process_method2}
         return sorted_process_method
 
     def __init_data(self, data: DataPoint):
@@ -36,8 +46,6 @@ class Pipeline:
             data.output_schema = "EventList"
         return data
 
-
-
     # main entry
     def get_extract_result(self,
                            task: TaskType,
@@ -49,23 +57,29 @@ class Pipeline:
                            file_path: str = "",
                            truth: str = "",
                            mode: str = "quick",
-                           update_case: bool = False
-                           …
+                           update_case: bool = False,
+                           show_trajectory: bool = False
+                           ):
         print(f" task: {task},\n instruction: {instruction},\n text: {text},\n output_schema: {output_schema},\n constraint: {constraint},\n use_file: {use_file},\n file_path: {file_path},\n truth: {truth},\n mode: {mode},\n update_case: {update_case}")
+
+        # Check Consistency
+        mode, update_case = self.__check_consistancy(self.llm, task, mode, update_case)
+
+        # Load Data
         data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, file_path=file_path, truth=truth)
         data = self.__init_data(data)
         if mode in config['agent']['mode'].keys():
-            process_method = config['agent']['mode'][mode]
+            process_method = config['agent']['mode'][mode].copy()
         else:
             process_method = mode
-        print(f"data=================: {data.task}")
-        print(f"process_method=================: {process_method}")
         sorted_process_method = self.__init_method(data, process_method)
+        print("Process Method: ", sorted_process_method)
+
+        print_schema = False #
+        frontend_schema = "" #
+        frontend_res = "" #
+
         # Information Extract
-        print(f"sorted_process_method=================: {sorted_process_method}")
         for agent_name, method_name in sorted_process_method.items():
             agent = getattr(self, agent_name, None)
             if not agent:
@@ -74,17 +88,23 @@ class Pipeline:
             if not method:
                 raise AttributeError(f"Method '{method_name}' not found in {agent_name}.")
             data = method(data)
-            if not print_schema and data.print_schema:
+            if not print_schema and data.print_schema: #
                 print("Schema: \n", data.print_schema)
                 frontend_schema = data.print_schema
                 print_schema = True
         data = self.extraction_agent.summarize_answer(data)
+
+        # show result
+        if show_trajectory:
+            print("Extraction Trajectory: \n", json.dumps(data.get_result_trajectory(), indent=2))
         print("Extraction Result: \n", json.dumps(data.pred, indent=2))
-
+
+        frontend_res = data.pred #
+
         # Case Update
         if update_case:
             if (data.truth == ""):
-                truth = input("Please enter the correct answer you prefer, or press Enter to accept the current answer: ")
+                truth = input("Please enter the correct answer you prefer, or just press Enter to accept the current answer: ")
                 if truth.strip() == "":
                     data.truth = data.pred
                 else:
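(Aside: a sketch of driving the expanded entry point directly; it mirrors the keyword arguments and two-value unpacking used in src/run.py below, while `model`, the text, and all other values are placeholders:)

    pipeline = Pipeline(model)  # model: any engine accepted by Pipeline
    result, trajectory = pipeline.get_extract_result(
        task="NER",
        instruction="",
        text="Finally, every other year, ELRA organizes a major conference LREC.",
        output_schema="",
        constraint="",
        use_file=False,
        file_path="",
        truth="",
        mode="quick",
        update_case=False,
        show_trajectory=True,  # newly added switch: also print the extraction trajectory
    )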
src/run.py
CHANGED
@@ -8,81 +8,35 @@ from models import *
 from utils import *
 from modules import *
 
-def load_extraction_config(yaml_path):
-    # Read YAML content from the file path
-    if not os.path.exists(yaml_path):
-        print(f"Error: The config file '{yaml_path}' does not exist.")
-        return {}
-
-    with open(yaml_path, 'r') as file:
-        config = yaml.safe_load(file)
-
-    # Extract the 'extraction' configuration dictionary
-    model_config = config.get('model', {})
-    extraction_config = config.get('extraction', {})
-    # model config
-    model_name_or_path = model_config.get('model_name_or_path', "")
-    model_category = model_config.get('category', "")
-    api_key = model_config.get('api_key', "")
-    base_url = model_config.get('base_url', "")
-
-    # extraction config
-    task = extraction_config.get('task', "")
-    instruction = extraction_config.get('instruction', "")
-    text = extraction_config.get('text', "")
-    output_schema = extraction_config.get('output_schema', "")
-    constraint = extraction_config.get('constraint', "")
-    truth = extraction_config.get('truth', "")
-    use_file = extraction_config.get('use_file', False)
-    mode = extraction_config.get('mode', "quick")
-    update_case = extraction_config.get('update_case', False)
-
-    # Return a dictionary containing these variables
-    return {
-        "model": {
-            "model_name_or_path": model_name_or_path,
-            "category": model_category,
-            "api_key": api_key,
-            "base_url": base_url
-        },
-        "extraction": {
-            "task": task,
-            "instruction": instruction,
-            "text": text,
-            "output_schema": output_schema,
-            "constraint": constraint,
-            "truth": truth,
-            "use_file": use_file,
-            "mode": mode,
-            "update_case": update_case
-        }
-    }
-
-
 def main():
-    # …
-    parser = argparse.ArgumentParser(description='Run the extraction …
-    parser.add_argument('--config', type=str, required=True,
+    # Create command-line argument parser
+    parser = argparse.ArgumentParser(description='Run the extraction framework.')
+    parser.add_argument('--config', type=str, required=True,
                         help='Path to the YAML configuration file.')
 
-    # …
+    # Parse command-line arguments
     args = parser.parse_args()
 
-    # …
+    # Load configuration
     config = load_extraction_config(args.config)
+    # Model config
     model_config = config['model']
+    if model_config['vllm_serve'] == True:
+        model = LocalServer(model_config['model_name_or_path'])
-    …
-    if clazz is None:
-        print(f"Error: The model category '{model_config['category']}' is not supported.")
-        return
-    if model_config['api_key'] == "":
-        model = clazz(model_config['model_name_or_path'])
     else:
-        …
+        clazz = getattr(models, model_config['category'], None)
+        if clazz is None:
+            print(f"Error: The model category '{model_config['category']}' is not supported.")
+            return
+        if model_config['api_key'] == "":
+            model = clazz(model_config['model_name_or_path'])
+        else:
+            model = clazz(model_config['model_name_or_path'], model_config['api_key'], model_config['base_url'])
     pipeline = Pipeline(model)
-    …
+    # Extraction config
+    extraction_config = config['extraction']
+    result, trajectory = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], file_path=extraction_config['file_path'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'], show_trajectory=extraction_config['show_trajectory'])
+    return
 
 if __name__ == "__main__":
     main()
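(Aside: an illustrative YAML configuration matching the keys read by load_extraction_config in src/utils/process.py below; every value is a placeholder:)

    model:
      model_name_or_path: gpt-4o-mini   # placeholder
      category: ChatGPT                 # class name looked up in `models`; placeholder
      api_key: ""                       # empty -> model constructed without credentials
      base_url: ""
      vllm_serve: false                 # true -> model served via LocalServer

    extraction:
      task: NER
      instruction: ""
      text: "Example sentence to extract from."
      output_schema: ""
      constraint: ""
      truth: ""
      use_file: false
      file_path: ""
      mode: quick
      update_case: false
      show_trajectory: false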
src/utils/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (274 Bytes)

src/utils/__pycache__/__init__.cpython-39.pyc
CHANGED
Binary files a/src/utils/__pycache__/__init__.cpython-39.pyc and b/src/utils/__pycache__/__init__.cpython-39.pyc differ

src/utils/__pycache__/data_def.cpython-311.pyc
DELETED
Binary file (3.07 kB)

src/utils/__pycache__/data_def.cpython-39.pyc
CHANGED
Binary files a/src/utils/__pycache__/data_def.cpython-39.pyc and b/src/utils/__pycache__/data_def.cpython-39.pyc differ

src/utils/__pycache__/process.cpython-311.pyc
DELETED
Binary file (10.7 kB)

src/utils/__pycache__/process.cpython-39.pyc
CHANGED
Binary files a/src/utils/__pycache__/process.cpython-39.pyc and b/src/utils/__pycache__/process.cpython-39.pyc differ
src/utils/data_def.py
CHANGED
@@ -3,7 +3,6 @@ from models import *
 from .process import *
 # predefined processing logic for routine extraction tasks
 TaskType = Literal["NER", "RE", "EE", "Base"]
-ModelType = Literal["gpt-3.5-turbo", "gpt-4o"]
 
 class DataPoint:
     def __init__(self,
src/utils/process.py
CHANGED
@@ -17,7 +17,65 @@ import inspect
 import ast
 with open(os.path.join(os.path.dirname(__file__), "..", "config.yaml")) as file:
     config = yaml.safe_load(file)
+
+# Load configuration
+def load_extraction_config(yaml_path):
+    # Read YAML content from the file path
+    if not os.path.exists(yaml_path):
+        print(f"Error: The config file '{yaml_path}' does not exist.")
+        return {}
+
+    with open(yaml_path, 'r') as file:
+        config = yaml.safe_load(file)
+
+    # Extract the 'extraction' configuration dictionary
+    model_config = config.get('model', {})
+    extraction_config = config.get('extraction', {})
+
+    # Model config
+    model_name_or_path = model_config.get('model_name_or_path', "")
+    model_category = model_config.get('category', "")
+    api_key = model_config.get('api_key', "")
+    base_url = model_config.get('base_url', "")
+    vllm_serve = model_config.get('vllm_serve', False)
+
+    # Extraction config
+    task = extraction_config.get('task', "")
+    instruction = extraction_config.get('instruction', "")
+    text = extraction_config.get('text', "")
+    output_schema = extraction_config.get('output_schema', "")
+    constraint = extraction_config.get('constraint', "")
+    truth = extraction_config.get('truth', "")
+    use_file = extraction_config.get('use_file', False)
+    file_path = extraction_config.get('file_path', "")
+    mode = extraction_config.get('mode', "quick")
+    update_case = extraction_config.get('update_case', False)
+    show_trajectory = extraction_config.get('show_trajectory', False)
+
+    # Return a dictionary containing these variables
+    return {
+        "model": {
+            "model_name_or_path": model_name_or_path,
+            "category": model_category,
+            "api_key": api_key,
+            "base_url": base_url,
+            "vllm_serve": vllm_serve
+        },
+        "extraction": {
+            "task": task,
+            "instruction": instruction,
+            "text": text,
+            "output_schema": output_schema,
+            "constraint": constraint,
+            "truth": truth,
+            "use_file": use_file,
+            "file_path": file_path,
+            "mode": mode,
+            "update_case": update_case,
+            "show_trajectory": show_trajectory
+        }
+    }
+
 # Split the string text into chunks
 def chunk_str(text):
     sentences = sent_tokenize(text)
@@ -165,7 +223,6 @@ def normalize_obj(value):
     if isinstance(value, dict):
         return frozenset((k, normalize_obj(v)) for k, v in value.items())
     elif isinstance(value, (list, set, tuple)):
-        # Convert Counter to a tuple so it can be hashed
         return tuple(Counter(map(normalize_obj, value)).items())
     elif isinstance(value, str):
         return format_string(value)
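(Aside: the retained Counter branch compares sequences by element multiplicity, while the dict branch is order-insensitive via frozenset. A simplified, runnable stand-in; format_string is approximated by strip/lower here:)

    from collections import Counter

    def normalize(value):
        # Simplified version of normalize_obj: dicts -> frozenset, sequences -> Counter items, strings lowered.
        if isinstance(value, dict):
            return frozenset((k, normalize(v)) for k, v in value.items())
        if isinstance(value, (list, set, tuple)):
            return tuple(Counter(map(normalize, value)).items())
        if isinstance(value, str):
            return value.strip().lower()
        return value

    print(normalize({"b": 1, "a": 2}) == normalize({"a": 2, "b": 1}))  # True: dict key order ignored
    print(normalize(["a", "a"]) == normalize(["a"]))                   # False: multiplicity preserved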
src/{main.py → webui.py}
RENAMED
@@ -147,6 +147,7 @@ def create_interface():
                 use_file=use_file,
                 file_path=file_path,
                 text=text,
+                show_trajectory=False,
             )
 
             ger_frontend_schema = str(ger_frontend_schema)
@@ -159,8 +160,6 @@ def create_interface():
 
         def clear_all():
             return (
-                gr.update(value=""),  # model
-                gr.update(value=""),  # API Key
                 gr.update(value=""),  # task
                 gr.update(value="", visible=False),  # instruction
                 gr.update(value="", visible=False),  # constraint
@@ -223,9 +222,6 @@ def create_interface():
         clear_button.click(
             fn=clear_all,
             outputs=[
-                model_gr,
-                api_key_gr,
-                base_url_gr,
                 task_gr,
                 instruction_gr,
                 constraint_gr,
src/webui/__init__.py
DELETED
@@ -1 +0,0 @@
-from .interface import InterFace

src/webui/__pycache__/__init__.cpython-39.pyc
DELETED
Binary file (197 Bytes)