ShawnRu committed
Commit 4754e33 · 1 Parent(s): 132649c

update-github

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. README.md +3 -4
  2. app.py +2 -3
  3. src/__pycache__/pipeline.cpython-311.pyc +0 -0
  4. src/__pycache__/pipeline.cpython-39.pyc +0 -0
  5. src/config.yaml +3 -2
  6. src/generate_memory.py +0 -181
  7. src/models/__init__.py +1 -1
  8. src/models/__pycache__/__init__.cpython-311.pyc +0 -0
  9. src/models/__pycache__/__init__.cpython-37.pyc +0 -0
  10. src/models/__pycache__/__init__.cpython-39.pyc +0 -0
  11. src/models/__pycache__/llm_def.cpython-311.pyc +0 -0
  12. src/models/__pycache__/llm_def.cpython-37.pyc +0 -0
  13. src/models/__pycache__/llm_def.cpython-39.pyc +0 -0
  14. src/models/__pycache__/prompt_example.cpython-311.pyc +0 -0
  15. src/models/__pycache__/prompt_example.cpython-39.pyc +0 -0
  16. src/models/__pycache__/prompt_template.cpython-311.pyc +0 -0
  17. src/models/__pycache__/prompt_template.cpython-39.pyc +0 -0
  18. src/models/llm_def.py +80 -13
  19. src/models/prompt_example.py +7 -7
  20. src/models/prompt_template.py +22 -1
  21. src/models/vllm_serve.py +34 -0
  22. src/modules/__pycache__/__init__.cpython-311.pyc +0 -0
  23. src/modules/__pycache__/__init__.cpython-39.pyc +0 -0
  24. src/modules/__pycache__/extraction_agent.cpython-311.pyc +0 -0
  25. src/modules/__pycache__/extraction_agent.cpython-39.pyc +0 -0
  26. src/modules/__pycache__/reflection_agent.cpython-311.pyc +0 -0
  27. src/modules/__pycache__/reflection_agent.cpython-39.pyc +0 -0
  28. src/modules/__pycache__/schema_agent.cpython-311.pyc +0 -0
  29. src/modules/__pycache__/schema_agent.cpython-39.pyc +0 -0
  30. src/modules/extraction_agent.py +28 -7
  31. src/modules/knowledge_base/__pycache__/case_repository.cpython-311.pyc +0 -0
  32. src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc +0 -0
  33. src/modules/knowledge_base/__pycache__/schema_repository.cpython-311.pyc +0 -0
  34. src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc +0 -0
  35. src/modules/knowledge_base/case_repository.py +135 -336
  36. src/modules/knowledge_base/schema_repository.py +1 -1
  37. src/modules/schema_agent.py +0 -3
  38. src/pipeline.py +43 -23
  39. src/run.py +21 -67
  40. src/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  41. src/utils/__pycache__/__init__.cpython-39.pyc +0 -0
  42. src/utils/__pycache__/data_def.cpython-311.pyc +0 -0
  43. src/utils/__pycache__/data_def.cpython-39.pyc +0 -0
  44. src/utils/__pycache__/process.cpython-311.pyc +0 -0
  45. src/utils/__pycache__/process.cpython-39.pyc +0 -0
  46. src/utils/data_def.py +0 -1
  47. src/utils/process.py +58 -1
  48. src/{main.py → webui.py} +1 -5
  49. src/webui/__init__.py +0 -1
  50. src/webui/__pycache__/__init__.cpython-39.pyc +0 -0
README.md CHANGED
@@ -1,13 +1,12 @@
  ---
  title: OneKE
  emoji: 👌🏻
- colorFrom: indigo
+ colorFrom: blue
  colorTo: indigo
  sdk: gradio
  sdk_version: 5.8.0
  app_file: app.py
  pinned: false
  license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ short_description: Schema-Guided LLM Agent-based Knowledge Extraction System
+ ---
 
app.py CHANGED
@@ -1,7 +1,6 @@
- import nltk
  import subprocess
-
+ import nltk
  nltk.download('punkt')
  nltk.download('punkt_tab')
 
- subprocess.run(["python", "src/main.py"])
+ subprocess.run(["python", "src/webui.py"])
src/__pycache__/pipeline.cpython-311.pyc DELETED
Binary file (5.34 kB)
 
src/__pycache__/pipeline.cpython-39.pyc CHANGED
Binary files a/src/__pycache__/pipeline.cpython-39.pyc and b/src/__pycache__/pipeline.cpython-39.pyc differ
 
src/config.yaml CHANGED
@@ -1,3 +1,6 @@
+ model:
+   embedding_model: all-MiniLM-L6-v2
+
  agent:
    default_schema: The final extraction result should be formatted as a JSON object.
    default_ner: Extract the Named Entities in the given text.
@@ -15,5 +18,3 @@ agent:
    customized:
      schema_agent: get_retrieved_schema
      extraction_agent: extract_information_direct
-
-
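
The new `model.embedding_model` key makes the sentence-embedding model configurable rather than hard-coded. A minimal sketch of how the value might be consumed, assuming the file is read with PyYAML and fed to `sentence-transformers` (the `load_embedder` helper name is hypothetical):

```python
# Hypothetical helper: reads src/config.yaml and builds the embedder.
# Assumes PyYAML and sentence-transformers are installed.
import yaml
from sentence_transformers import SentenceTransformer

def load_embedder(config_path="src/config.yaml"):
    with open(config_path) as f:
        config = yaml.safe_load(f)
    # "all-MiniLM-L6-v2" comes from the new model.embedding_model key above
    return SentenceTransformer(config["model"]["embedding_model"])
```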
 
 
src/generate_memory.py DELETED
@@ -1,181 +0,0 @@
- from typing import Literal
- from models import *
- from utils import *
- from modules import *
-
-
- class Pipeline:
-     def __init__(self, llm: BaseEngine):
-         self.llm = llm
-         self.case_repo = CaseRepositoryHandler(llm = llm)
-         self.schema_agent = SchemaAgent(llm = llm)
-         self.extraction_agent = ExtractionAgent(llm = llm, case_repo = self.case_repo)
-         self.reflection_agent = ReflectionAgent(llm = llm, case_repo = self.case_repo)
-
-     def __init_method(self, data: DataPoint, process_method):
-         default_order = ["schema_agent", "extraction_agent", "reflection_agent"]
-         if "schema_agent" not in process_method:
-             process_method["schema_agent"] = "get_default_schema"
-         if data.task != "Base":
-             process_method["schema_agent"] = "get_retrieved_schema"
-         if "extraction_agent" not in process_method:
-             process_method["extraction_agent"] = "extract_information_direct"
-         sorted_process_method = {key: process_method[key] for key in default_order if key in process_method}
-         return sorted_process_method
-
-     def __init_data(self, data: DataPoint):
-         if data.task == "NER":
-             data.instruction = config['agent']['default_ner']
-             data.output_schema = "EntityList"
-         elif data.task == "RE":
-             data.instruction = config['agent']['default_re']
-             data.output_schema = "RelationList"
-         elif data.task == "EE":
-             data.instruction = config['agent']['default_ee']
-             data.output_schema = "EventList"
-         return data
-
-     # main entry
-     def get_extract_result(self,
-                            task: TaskType,
-                            instruction: str = "",
-                            text: str = "",
-                            output_schema: str = "",
-                            constraint: str = "",
-                            use_file: bool = False,
-                            truth: str = "",
-                            mode: str = "quick",
-                            update_case: bool = False
-                            ):
-
-         data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, truth=truth)
-         data = self.__init_data(data)
-         data.instruction = "In the tranquil seaside town, the summer evening cast a golden glow over everything. The townsfolk gathered at the café by the pier, enjoying the sea breeze while eagerly anticipating the annual Ocean Festival's opening ceremony. \nFirst to arrive was Mayor William, dressed in a deep blue suit, holding a roll of his speech. He smiled and greeted the residents, who held deep respect for their community-minded mayor. Beside him trotted Max, his loyal golden retriever, wagging his tail excitedly at every familiar face he saw. \nFollowing closely was Emily, the town’s high school teacher, accompanied by a group of students ready to perform a musical piece they'd rehearsed. One of the girls carried Polly, a vibrant green parrot, on her shoulder. Polly occasionally chimed in with cheerful squawks, adding to the lively atmosphere. \nNot far away, Captain Jack, with his trusty pipe in hand, chatted with old friends about this year's catch. His fleet was the town’s economic backbone, and his seasoned face and towering presence were complemented by the presence of Whiskers, his orange tabby cat, who loved lounging on the dock, attentively watching the gentle waves. \nInside the café, Kate was bustling about, serving guests. As the owner, with her fiery red curls and vivacious spirit, she was the heart of the place. Her friend Susan, an artist living in a tiny cottage nearby, was helping her prepare refreshing beverages. Slinky, Susan's mischievous ferret, darted playfully between the tables, much to the delight of the children present. \nLeaning on the café's railing, a young boy named Tommy watched the sea with wide, gleaming eyes, filled with dreams of the future. By his side sat Daisy, a spirited little dachshund, barking excitedly at the seagulls flying overhead. Tommy's mother, Lucy, stood beside him, smiling softly as she held a seashell he had just found on the beach. \nAmong the crowd, a group of unnamed tourists snapped photos, capturing memories of the charming festival. Street vendors called out, selling their wares—handmade jewelry and sweet confections—as the scent of grilled seafood wafted through the air. \nSuddenly, a burst of laughter erupted—it was James and his band making their grand entrance. Accompanying them was Benny, a friendly border collie who \"performed\" with the band, delighting the crowd with his antics. Set to play a big concert after the opening ceremony, James, the town's star musician, had won the hearts of locals with his soulful tunes. \nAs dusk settled, lights were strung across the streets, casting a magical glow over the town. Mayor William took the stage to deliver his speech, with Max sitting proudly by his side. The festival atmosphere reached its vibrant peak, and in this small town, each person—and animal—carried their own dreams and stories, yet at this moment, they were united by the shared celebration."
-         data.chunk_text_list.append("In the tranquil seaside town, the summer evening cast a golden glow over everything. The townsfolk gathered at the café by the pier, enjoying the sea breeze while eagerly anticipating the annual Ocean Festival's opening ceremony. \nFirst to arrive was Mayor William, dressed in a deep blue suit, holding a roll of his speech. He smiled and greeted the residents, who held deep respect for their community-minded mayor. Beside him trotted Max, his loyal golden retriever, wagging his tail excitedly at every familiar face he saw. \nFollowing closely was Emily, the town’s high school teacher, accompanied by a group of students ready to perform a musical piece they'd rehearsed. One of the girls carried Polly, a vibrant green parrot, on her shoulder. Polly occasionally chimed in with cheerful squawks, adding to the lively atmosphere. \nNot far away, Captain Jack, with his trusty pipe in hand, chatted with old friends about this year's catch. His fleet was the town’s economic backbone, and his seasoned face and towering presence were complemented by the presence of Whiskers, his orange tabby cat, who loved lounging on the dock, attentively watching the gentle waves. \nInside the café, Kate was bustling about, serving guests. As the owner, with her fiery red curls and vivacious spirit, she was the heart of the place. Her friend Susan, an artist living in a tiny cottage nearby, was helping her prepare refreshing beverages. Slinky, Susan's mischievous ferret, darted playfully between the tables, much to the delight of the children present. \nLeaning on the café's railing, a young boy named Tommy watched the sea with wide, gleaming eyes, filled with dreams of the future. By his side sat Daisy, a spirited little dachshund, barking excitedly at the seagulls flying overhead. Tommy's mother, Lucy, stood beside him, smiling softly as she held a seashell he had just found on the beach. \nAmong the crowd, a group of unnamed tourists snapped photos, capturing memories of the charming festival. Street vendors called out, selling their wares—handmade jewelry and sweet confections—as the scent of grilled seafood wafted through the air. \nSuddenly, a burst of laughter erupted—it was James and his band making their grand entrance. Accompanying them was Benny, a friendly border collie who \"performed\" with the band, delighting the crowd with his antics. Set to play a big concert after the opening ceremony, James, the town's star musician, had won the hearts of locals with his soulful tunes. \nAs dusk settled, lights were strung across the streets, casting a magical glow over the town. Mayor William took the stage to deliver his speech, with Max sitting proudly by his side. The festival atmosphere reached its vibrant peak, and in this small town, each person—and animal—carried their own dreams and stories, yet at this moment, they were united by the shared celebration.")
-         data.distilled_text = "This text is from the field of Slice of Life and represents the genre of Novel."
-         data.pred = {
-             "characters": [
-                 {"name": "Mayor William", "role": "Mayor"},
-                 {"name": "Max", "role": "Golden Retriever, Mayor William's dog"},
-                 {"name": "Emily", "role": "High school teacher"},
-                 {"name": "Polly", "role": "Parrot, accompanying a student"},
-                 {"name": "Captain Jack", "role": "Captain"},
-                 {"name": "Whiskers", "role": "Orange tabby cat, Captain Jack's pet"},
-                 {"name": "Kate", "role": "Café owner"},
-                 {"name": "Susan", "role": "Artist, Kate's friend"},
-                 {"name": "Slinky", "role": "Ferret, Susan's pet"},
-                 {"name": "Tommy", "role": "Young boy"},
-                 {"name": "Daisy", "role": "Dachshund, Tommy's pet"},
-                 {"name": "Lucy", "role": "Tommy's mother"},
-                 {"name": "James", "role": "Musician, band leader"},
-                 {"name": "Benny", "role": "Border Collie, accompanying James and his band"},
-                 {"name": "Unnamed Tourists", "role": "Visitors at the festival"},
-                 {"name": "Street Vendors", "role": "Sellers at the festival"}
-             ]
-         }
-
-         data.truth = {
-             "characters": [
-                 {"name": "Mayor William", "role": "The friendly and respected mayor of the seaside town."},
-                 {"name": "Emily", "role": "A high school teacher guiding students in a festival performance."},
-                 {"name": "Captain Jack", "role": "A seasoned sailor whose fleet supports the town."},
-                 {"name": "Kate", "role": "The welcoming owner of the local café."},
-                 {"name": "Susan", "role": "An artist known for her ocean-themed paintings."},
-                 {"name": "Tommy", "role": "A young boy with dreams of the sea."},
-                 {"name": "Lucy", "role": "Tommy's caring and supportive mother."},
-                 {"name": "James", "role": "A charismatic musician and band leader."}
-             ]
-         }
-
-
-         # Case Update
-         if update_case:
-             if (data.truth == ""):
-                 truth = input("Please enter the correct answer you prefer, or press Enter to accept the current answer: ")
-                 if truth.strip() == "":
-                     data.truth = data.pred
-                 else:
-                     data.truth = extract_json_dict(truth)
-             self.case_repo.update_case(data)
-
-         # return result
-         result = data.pred
-         trajectory = data.get_result_trajectory()
-
-         return result, trajectory, "a", "b"
-
- model = DeepSeek(model_name_or_path="deepseek-chat", api_key="")
- pipeline = Pipeline(model)
- result, trajectory, *_ = pipeline.get_extract_result(update_case=True, task="Base")
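
The deleted script hard-coded a demo document and predictions in order to seed the case repository, and its module-level entry point drove a one-off run. A rough sketch of the equivalent call through the remaining `src/pipeline.py`, assuming its `Pipeline` keeps the constructor and `get_extract_result` signature shown above (the API key and sample text are placeholders):

```python
# Sketch only: mirrors the entry point of the deleted script.
# Assumes src/pipeline.py exposes Pipeline(llm) and get_extract_result(...)
# with the same signature as above; the api_key value is a placeholder.
from models import DeepSeek
from pipeline import Pipeline

llm = DeepSeek(model_name_or_path="deepseek-chat", api_key="YOUR_KEY")
pipe = Pipeline(llm)
result, trajectory, *_ = pipe.get_extract_result(
    task="NER",
    text="Mayor William greeted the residents at the pier.",
    update_case=False,
)
```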
src/models/__init__.py CHANGED
@@ -1,3 +1,3 @@
- from .llm_def import BaseEngine, LLaMA, Qwen, MiniCPM, ChatGLM, ChatGPT, DeepSeek
+ from .llm_def import *
  from .prompt_example import *
  from .prompt_template import *
src/models/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (434 Bytes)
 
src/models/__pycache__/__init__.cpython-37.pyc DELETED
Binary file (315 Bytes)
 
src/models/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/src/models/__pycache__/__init__.cpython-39.pyc and b/src/models/__pycache__/__init__.cpython-39.pyc differ
 
src/models/__pycache__/llm_def.cpython-311.pyc DELETED
Binary file (11.8 kB)
 
src/models/__pycache__/llm_def.cpython-37.pyc DELETED
Binary file (7.14 kB)
 
src/models/__pycache__/llm_def.cpython-39.pyc CHANGED
Binary files a/src/models/__pycache__/llm_def.cpython-39.pyc and b/src/models/__pycache__/llm_def.cpython-39.pyc differ
 
src/models/__pycache__/prompt_example.cpython-311.pyc DELETED
Binary file (5.67 kB)
 
src/models/__pycache__/prompt_example.cpython-39.pyc CHANGED
Binary files a/src/models/__pycache__/prompt_example.cpython-39.pyc and b/src/models/__pycache__/prompt_example.cpython-39.pyc differ
 
src/models/__pycache__/prompt_template.cpython-311.pyc DELETED
Binary file (5.42 kB)
 
src/models/__pycache__/prompt_template.cpython-39.pyc CHANGED
Binary files a/src/models/__pycache__/prompt_template.cpython-39.pyc and b/src/models/__pycache__/prompt_template.cpython-39.pyc differ
 
src/models/llm_def.py CHANGED
@@ -6,7 +6,7 @@ Supports:
  """
 
  from transformers import pipeline
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoTokenizer
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, GenerationConfig
  import torch
  import openai
  import os
@@ -21,7 +21,8 @@ class BaseEngine:
          self.temperature = 0.2
          self.top_p = 0.9
          self.max_tokens = 1024
-
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
      def get_chat_response(self, prompt):
          raise NotImplementedError
 
@@ -29,7 +30,7 @@
          self.temperature = temperature
          self.top_p = top_p
          self.max_tokens = max_tokens
-
+
  class LLaMA(BaseEngine):
      def __init__(self, model_name_or_path: str):
          super().__init__(model_name_or_path)
@@ -60,7 +61,7 @@
              top_p=self.top_p,
          )
          return outputs[0]["generated_text"][-1]['content'].strip()
-
+
  class Qwen(BaseEngine):
      def __init__(self, model_name_or_path: str):
          super().__init__(model_name_or_path)
@@ -71,7 +72,7 @@
              torch_dtype="auto",
              device_map="auto"
          )
-
+
      def get_chat_response(self, prompt):
          messages = [
              {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
@@ -82,7 +83,7 @@
              tokenize=False,
              add_generation_prompt=True
          )
-         model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
+         model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
          generated_ids = self.model.generate(
              **model_inputs,
              temperature=self.temperature,
@@ -93,7 +94,7 @@
              output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
          ]
          response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-
+
          return response
 
  class MiniCPM(BaseEngine):
@@ -113,7 +114,7 @@
              {"role": "system", "content": "You are a helpful assistant."},
              {"role": "user", "content": prompt}
          ]
-         model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(self.model.device)
+         model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(self.device)
          model_outputs = self.model.generate(
              model_inputs,
              temperature=self.temperature,
@@ -124,7 +125,7 @@
              model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs))
          ]
          response = self.tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0].strip()
-
+
          return response
 
  class ChatGLM(BaseEngine):
@@ -145,7 +146,7 @@
              {"role": "system", "content": "You are a helpful assistant."},
              {"role": "user", "content": prompt}
          ]
-         model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True, tokenize=True).to(self.model.device)
+         model_inputs = self.tokenizer.apply_chat_template(messages, return_tensors="pt", return_dict=True, add_generation_prompt=True, tokenize=True).to(self.device)
          model_outputs = self.model.generate(
              **model_inputs,
              temperature=self.temperature,
@@ -154,9 +155,45 @@
          )
          model_outputs = model_outputs[:, model_inputs['input_ids'].shape[1]:]
          response = self.tokenizer.batch_decode(model_outputs, skip_special_tokens=True)[0].strip()
-
+
          return response
 
+ class OneKE(BaseEngine):
+     def __init__(self, model_name_or_path: str):
+         super().__init__(model_name_or_path)
+         self.name = "OneKE"
+         self.model_id = model_name_or_path
+         config = AutoConfig.from_pretrained(self.model_id, trust_remote_code=True)
+         quantization_config=BitsAndBytesConfig(
+             load_in_4bit=True,
+             llm_int8_threshold=6.0,
+             llm_int8_has_fp16_weight=False,
+             bnb_4bit_compute_dtype=torch.bfloat16,
+             bnb_4bit_use_double_quant=True,
+             bnb_4bit_quant_type="nf4",
+         )
+         self.model = AutoModelForCausalLM.from_pretrained(
+             self.model_id,
+             config=config,
+             device_map="auto",
+             quantization_config=quantization_config,
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+         )
+
+     def get_chat_response(self, prompt):
+         system_prompt = '<<SYS>>\nYou are a helpful assistant. 你是一个乐于助人的助手。\n<</SYS>>\n\n'
+         sintruct = '[INST] ' + system_prompt + prompt + '[/INST]'
+         input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
+         input_ids = self.tokenizer.encode(sintruct, return_tensors="pt").to(self.device)
+         input_length = input_ids.size(1)
+         generation_output = self.model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_length=1024, max_new_tokens=512, return_dict_in_generate=True, pad_token_id=self.tokenizer.pad_token_id, eos_token_id=self.tokenizer.eos_token_id))
+         generation_output = generation_output.sequences[0]
+         generation_output = generation_output[input_length:]
+         response = self.tokenizer.decode(generation_output, skip_special_tokens=True)
+
+         return response
+
  class ChatGPT(BaseEngine):
      def __init__(self, model_name_or_path: str, api_key: str, base_url=openai.base_url):
          self.name = "ChatGPT"
@@ -170,7 +207,7 @@
          else:
              self.api_key = os.environ["OPENAI_API_KEY"]
          self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
-
+
      def get_chat_response(self, input):
          response = self.client.chat.completions.create(
              model=self.model,
@@ -197,7 +234,7 @@
          else:
              self.api_key = os.environ["DEEPSEEK_API_KEY"]
          self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
-
+
      def get_chat_response(self, input):
          response = self.client.chat.completions.create(
              model=self.model,
@@ -210,3 +247,33 @@
              stop=None
          )
          return response.choices[0].message.content
+
+ class LocalServer(BaseEngine):
+     def __init__(self, model_name_or_path: str, base_url="http://localhost:8000/v1"):
+         self.name = model_name_or_path.split('/')[-1]
+         self.model = model_name_or_path
+         self.base_url = base_url
+         self.temperature = 0.2
+         self.top_p = 0.9
+         self.max_tokens = 1024
+         self.api_key = "EMPTY_API_KEY"
+         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+
+     def get_chat_response(self, input):
+         try:
+             response = self.client.chat.completions.create(
+                 model=self.model,
+                 messages=[
+                     {"role": "user", "content": input},
+                 ],
+                 stream=False,
+                 temperature=self.temperature,
+                 max_tokens=self.max_tokens,
+                 stop=None
+             )
+             return response.choices[0].message.content
+         except ConnectionError:
+             print("Error: Unable to connect to the server. Please check if the vllm service is running and the port is 8080.")
+         except Exception as e:
+             print(f"Error: {e}")
+
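
All engines above share the `BaseEngine` interface, so the agents only ever call `get_chat_response`. A minimal sketch of swapping backends behind that interface (module path as laid out in this repo with `src` on the path; API keys and the ChatGPT model id are placeholders):

```python
# Sketch: every engine exposes get_chat_response(prompt) -> str,
# so hosted and local backends are interchangeable to the agents.
from models.llm_def import ChatGPT, DeepSeek

llm = DeepSeek(model_name_or_path="deepseek-chat", api_key="YOUR_DEEPSEEK_KEY")
# llm = ChatGPT(model_name_or_path="gpt-4o-mini", api_key="YOUR_OPENAI_KEY")  # illustrative model id

print(llm.get_chat_response("Extract the named entities from: 'Kate owns the café by the pier.'"))
```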
src/models/prompt_example.py CHANGED
@@ -95,13 +95,13 @@ class Event(BaseModel):
      process: Optional[str] = Field(description="Details of the event process")
      result: Optional[str] = Field(default=None, description="Result or outcome of the event")
 
- class ExtractionTarget(BaseModel):
-     title: str = Field(description="The title or headline of the news article")
-     summary: str = Field(description="A brief summary of the news article")
-     publication_date: Optional[str] = Field(description="The publication date of the article")
-     keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the article")
-     events: List[Event] = Field(description="Events covered in the article")
-     quotes: Optional[List[str]] = Field(default=None, description="Quotes related to the news, if any")
+ class NewsReport(BaseModel):
+     title: str = Field(description="The title or headline of the news report")
+     summary: str = Field(description="A brief summary of the news report")
+     publication_date: Optional[str] = Field(description="The publication date of the report")
+     keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the news report")
+     events: List[Event] = Field(description="Events covered in the news report")
+     quotes: Optional[dict] = Field(default=None, description="Quotes related to the news, with keys as the citation sources and values as the quoted content. ")
      viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
  ```
src/models/prompt_template.py CHANGED
@@ -76,6 +76,25 @@ extract_instruction = PromptTemplate(
      template=EXTRACT_INSTRUCTION,
  )
 
+ instruction_mapper = {
+     'NER': "You are an expert in named entity recognition. Please extract entities that match the schema definition from the input. Return an empty list if the entity type does not exist. Please respond in the format of a JSON string.",
+     'RE': "You are an expert in relationship extraction. Please extract relationship triples that match the schema definition from the input. Return an empty list for relationships that do not exist. Please respond in the format of a JSON string.",
+     'EE': "You are an expert in event extraction. Please extract events from the input that conform to the schema definition. Return an empty list for events that do not exist, and return NAN for arguments that do not exist. If an argument has multiple values, please return a list. Respond in the format of a JSON string.",
+ }
+
+ EXTRACT_INSTRUCTION_JSON = """
+ {{
+     "instruction": {instruction},
+     "schema": {constraint},
+     "input": {input},
+ }}
+ """
+
+ extract_instruction_json = PromptTemplate(
+     input_variables=["instruction", "constraint", "input"],
+     template=EXTRACT_INSTRUCTION_JSON,
+ )
+
  SUMMARIZE_INSTRUCTION = """
  **Instruction**: Below is a list of results obtained after segmenting and extracting information from a long article. Please consolidate all the answers to generate a final response.
  {examples}
@@ -84,7 +103,7 @@ SUMMARIZE_INSTRUCTION = """
  **Result List**: {answer_list}
 
  **Output Schema**: {schema}
- Now summarize all the information from the Result List.
+ Now summarize all the information from the Result List. Filter or merge the redundant information.
  """
  summarize_instruction = PromptTemplate(
      input_variables=["instruction", "examples", "answer_list", "schema"],
@@ -92,6 +111,8 @@ summarize_instruction = PromptTemplate(
  )
 
 
+
+
  # ==================================================================== #
  #                            REFLECION AGENT                           #
  # ==================================================================== #
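
The new `instruction_mapper` and `extract_instruction_json` feed the JSON-shaped prompt used by the OneKE-compatible extraction path (see the extraction agent diff below). A small sketch of what the formatted prompt looks like, assuming `PromptTemplate` here is LangChain's and exposes `.format(...)` (the constraint and input strings are illustrative):

```python
# Sketch: building the OneKE-compatible JSON instruction for an NER request.
from models.prompt_template import instruction_mapper, extract_instruction_json

prompt = extract_instruction_json.format(
    instruction=instruction_mapper["NER"],
    constraint='["person", "place", "organization"]',  # illustrative schema
    input="Kate runs the café by the pier.",
)
print(prompt)  # a JSON-shaped string with "instruction", "schema" and "input" fields
```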
src/models/vllm_serve.py ADDED
@@ -0,0 +1,34 @@
+ import argparse
+ import warnings
+ import subprocess
+ import sys
+ import os
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ from utils import *
+
+ def main():
+     # Create command-line argument parser
+     parser = argparse.ArgumentParser(description='Run the extraction model.')
+     parser.add_argument('--config', type=str, required=True,
+                         help='Path to the YAML configuration file.')
+     parser.add_argument('--tensor-parallel-size', type=int, default=2,
+                         help='Tensor parallel size for the VLLM server.')
+     parser.add_argument('--max-model-len', type=int, default=32768,
+                         help='Maximum model length for the VLLM server.')
+
+     # Parse command-line arguments
+     args = parser.parse_args()
+
+     # Load configuration
+     config = load_extraction_config(args.config)
+     # Model config
+     model_config = config['model']
+     if model_config['vllm_serve'] == False:
+         warnings.warn("VLLM-deployed model will not be used for extraction. To enable VLLM, set vllm_serve to true in the configuration file.")
+     model_name_or_path = model_config['model_name_or_path']
+     command = f"vllm serve {model_name_or_path} --tensor-parallel-size {args.tensor_parallel_size} --max-model-len {args.max_model_len} --enforce-eager --port 8000"
+     subprocess.run(command, shell=True)
+
+ if __name__ == "__main__":
+     main()
+
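
Taken together with the new `LocalServer` engine, the intended flow appears to be: serve a local model with vLLM on port 8000, then point `LocalServer` at the OpenAI-compatible endpoint. A rough sketch (the config path and model id below are illustrative):

```python
# 1) In one shell, start the vLLM server that vllm_serve.py wraps:
#      python src/models/vllm_serve.py --config examples/config.yaml --tensor-parallel-size 2
#    which effectively runs:
#      vllm serve <model_name_or_path> --tensor-parallel-size 2 --max-model-len 32768 --enforce-eager --port 8000
#
# 2) In the application, talk to it through the OpenAI-compatible client:
from models.llm_def import LocalServer

llm = LocalServer(model_name_or_path="Qwen/Qwen2.5-7B-Instruct",  # illustrative model id
                  base_url="http://localhost:8000/v1")
print(llm.get_chat_response("Return a JSON list of the entities in: 'Captain Jack owns a fleet.'"))
```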
src/modules/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (459 Bytes)
 
src/modules/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/src/modules/__pycache__/__init__.cpython-39.pyc and b/src/modules/__pycache__/__init__.cpython-39.pyc differ
 
src/modules/__pycache__/extraction_agent.cpython-311.pyc DELETED
Binary file (6.66 kB)
 
src/modules/__pycache__/extraction_agent.cpython-39.pyc CHANGED
Binary files a/src/modules/__pycache__/extraction_agent.cpython-39.pyc and b/src/modules/__pycache__/extraction_agent.cpython-39.pyc differ
 
src/modules/__pycache__/reflection_agent.cpython-311.pyc DELETED
Binary file (6.98 kB)
 
src/modules/__pycache__/reflection_agent.cpython-39.pyc CHANGED
Binary files a/src/modules/__pycache__/reflection_agent.cpython-39.pyc and b/src/modules/__pycache__/reflection_agent.cpython-39.pyc differ
 
src/modules/__pycache__/schema_agent.cpython-311.pyc DELETED
Binary file (10.7 kB)
 
src/modules/__pycache__/schema_agent.cpython-39.pyc CHANGED
Binary files a/src/modules/__pycache__/schema_agent.cpython-39.pyc and b/src/modules/__pycache__/schema_agent.cpython-39.pyc differ
 
src/modules/extraction_agent.py CHANGED
@@ -11,9 +11,13 @@ class InformationExtractor:
          prompt = extract_instruction.format(instruction=instruction, examples=examples, text=text, additional_info=additional_info, schema=schema)
          response = self.llm.get_chat_response(prompt)
          response = extract_json_dict(response)
-         print(f"prompt: {prompt}")
-         print("========================================")
-         print(f"response: {response}")
+         return response
+
+     def extract_information_compatible(self, task="", text="", constraint=""):
+         instruction = instruction_mapper.get(task)
+         prompt = extract_instruction_json.format(instruction=instruction, constraint=constraint, input=text)
+         response = self.llm.get_chat_response(prompt)
+         response = extract_json_dict(response)
          return response
 
      def summarize_answer(self, instruction="", answer_list="", schema="", additional_info=""):
@@ -34,26 +38,43 @@ class ExtractionAgent:
              return data
          if data.task == "NER":
              constraint = json.dumps(data.constraint)
-             if "**Entity Type Constraint**" in constraint:
+             if "**Entity Type Constraint**" in constraint or self.llm.name == "OneKE":
                  return data
              data.constraint = f"\n**Entity Type Constraint**: The type of entities must be chosen from the following list.\n{constraint}\n"
          elif data.task == "RE":
              constraint = json.dumps(data.constraint)
-             if "**Relation Type Constraint**" in constraint:
+             if "**Relation Type Constraint**" in constraint or self.llm.name == "OneKE":
                  return data
              data.constraint = f"\n**Relation Type Constraint**: The type of relations must be chosen from the following list.\n{constraint}\n"
          elif data.task == "EE":
              constraint = json.dumps(data.constraint)
              if "**Event Extraction Constraint**" in constraint:
                  return data
-             data.constraint = f"\n**Event Extraction Constraint**: The event type must be selected from the following dictionary keys, and its event arguments should be chosen from its corresponding dictionary values. \n{constraint}\n"
+             if self.llm.name != "OneKE":
+                 data.constraint = f"\n**Event Extraction Constraint**: The event type must be selected from the following dictionary keys, and its event arguments should be chosen from its corresponding dictionary values. \n{constraint}\n"
+             else:
+                 try:
+                     result = [
+                         {
+                             "event_type": key,
+                             "trigger": True,
+                             "arguments": value
+                         }
+                         for key, value in data.constraint.items()
+                     ]
+                     data.constraint = json.dumps(result)
+                 except:
+                     print("Invalid Constraint: Event Extraction constraint must be a dictionary with event types as keys and lists of arguments as values.", data.constraint)
          return data
 
      def extract_information_direct(self, data: DataPoint):
          data = self.__get_constraint(data)
          result_list = []
          for chunk_text in data.chunk_text_list:
-             extract_direct_result = self.module.extract_information(instruction=data.instruction, text=chunk_text, schema=data.output_schema, examples="", additional_info=data.constraint)
+             if self.llm.name != "OneKE":
+                 extract_direct_result = self.module.extract_information(instruction=data.instruction, text=chunk_text, schema=data.output_schema, examples="", additional_info=data.constraint)
+             else:
+                 extract_direct_result = self.module.extract_information_compatible(task=data.task, text=chunk_text, constraint=data.constraint)
              result_list.append(extract_direct_result)
              function_name = current_function_name()
              data.set_result_list(result_list)
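
For the OneKE backend, the event-extraction constraint is rewritten from a `{event_type: [arguments]}` dictionary into the list-of-objects form used in the `else` branch above. A worked example of that transformation (the constraint values are illustrative):

```python
import json

# Constraint as the user supplies it (event types -> argument roles).
constraint = {"Conference": ["name", "date", "location"]}

# What the OneKE branch above turns it into before prompting.
oneke_constraint = json.dumps([
    {"event_type": key, "trigger": True, "arguments": value}
    for key, value in constraint.items()
])
# '[{"event_type": "Conference", "trigger": true, "arguments": ["name", "date", "location"]}]'
```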
src/modules/knowledge_base/__pycache__/case_repository.cpython-311.pyc DELETED
Binary file (4.64 kB)
 
src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc CHANGED
Binary files a/src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc and b/src/modules/knowledge_base/__pycache__/case_repository.cpython-39.pyc differ
 
src/modules/knowledge_base/__pycache__/schema_repository.cpython-311.pyc DELETED
Binary file (9.25 kB)
 
src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc CHANGED
Binary files a/src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc and b/src/modules/knowledge_base/__pycache__/schema_repository.cpython-39.pyc differ
 
src/modules/knowledge_base/case_repository.py CHANGED
@@ -1,192 +1,3 @@
- # import json
- # import os
- # import torch
- # import numpy as np
- # from utils import *
- # from sentence_transformers import SentenceTransformer
- # from rapidfuzz import process
- # from models import *
- # import copy
-
- # import warnings
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
-
- # class CaseRepository:
- #     def __init__(self):
- #         self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
- #         self.embedder.to(device)
- #         self.corpus = self.load_corpus()
- #         self.embedded_corpus = self.embed_corpus()
-
- #     def load_corpus(self):
- #         with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
- #             corpus = json.load(file)
- #         return corpus
-
- #     def update_corpus(self):
- #         try:
- #             with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
- #                 json.dump(self.corpus, file, indent=2)
- #         except Exception as e:
- #             print(f"Error when updating corpus: {e}")
-
- #     def embed_corpus(self):
- #         embedded_corpus = {}
- #         for key, content in self.corpus.items():
- #             good_index = [item['index']['embed_index'] for item in content['good']]
- #             encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
- #             bad_index = [item['index']['embed_index'] for item in content['bad']]
- #             encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
- #             embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
- #         return embedded_corpus
-
- #     def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
- #         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- #         # Embedding similarity match
- #         encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
- #         embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
- #         embedding_similarity_scores = embedding_similarity_matrix[0].to(device)
-
- #         # String similarity match
- #         str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
- #         str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
- #         scores_dict = {match[0]: match[1] for match in str_similarity_results}
- #         scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
- #         str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
-
- #         # Normalize scores
- #         embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
- #         str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
- #         if embedding_score_range > 0:
- #             embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
- #         else:
- #             embed_norm_scores = embedding_similarity_scores
- #         if str_score_range > 0:
- #             str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
- #         else:
- #             str_norm_scores = str_similarity_scores / 100
-
- #         # Combine the scores with weights
- #         combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
- #         original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
-
- #         scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
- #         original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
- #         return scores, indices, original_scores, original_indices
-
- #     def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
- #         _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
- #         top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
- #         return top_matches
-
- #     def update_case(self, task: TaskType, embed_index="", str_index="", content="" ,case_type=""):
- #         self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
- #         self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
- #         print(f"Case updated for {task} task.")
-
- # class CaseRepositoryHandler:
- #     def __init__(self, llm: BaseEngine):
- #         self.repository = CaseRepository()
- #         self.llm = llm
-
- #     def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
- #         prompt = good_case_analysis_instruction.format(
- #             instruction=instruction, text=text, result=result, additional_info=additional_info
- #         )
- #         for _ in range(3):
- #             response = self.llm.get_chat_response(prompt)
- #             response = extract_json_dict(response)
- #             if not isinstance(response, dict):
- #                 return response
- #         return None
-
- #     def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
- #         prompt = bad_case_reflection_instruction.format(
- #             instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
- #         )
- #         for _ in range(3):
- #             response = self.llm.get_chat_response(prompt)
- #             response = extract_json_dict(response)
- #             if not isinstance(response, dict):
- #                 return response
- #         return None
-
- #     def __get_index(self, data: DataPoint, case_type: str):
- #         # set embed_index
- #         embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
-
- #         # set str_index
- #         if data.task == "Base":
- #             str_index = f"**Task**: {data.instruction}"
- #         else:
- #             str_index = f"{data.constraint}"
-
- #         if case_type == "bad":
- #             str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
-
- #         return embed_index, str_index
-
- #     def query_good_case(self, data: DataPoint):
- #         embed_index, str_index = self.__get_index(data, "good")
- #         return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
-
- #     def query_bad_case(self, data: DataPoint):
- #         embed_index, str_index = self.__get_index(data, "bad")
- #         return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
-
- #     def update_good_case(self, data: DataPoint):
- #         if data.truth == "" :
- #             print("No truth value provided.")
- #             return
- #         embed_index, str_index = self.__get_index(data, "good")
- #         _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
- #         original_scores = original_scores.tolist()
- #         if original_scores[0] >= 0.9:
- #             print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
- #             return
- #         good_case_alaysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
- #         wrapped_good_case_analysis = f"**Analysis**: {good_case_alaysis}"
- #         wrapped_instruction = f"**Task**: {data.instruction}"
- #         wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
- #         wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
- #         if data.task == "Base":
- #             content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
- #         else:
- #             content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
- #         self.repository.update_case(data.task, embed_index, str_index, content, "good")
-
- #     def update_bad_case(self, data: DataPoint):
- #         if data.truth == "" :
- #             print("No truth value provided.")
- #             return
- #         if normalize_obj(data.pred) == normalize_obj(data.truth):
- #             return
- #         embed_index, str_index = self.__get_index(data, "bad")
- #         _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
- #         original_scores = original_scores.tolist()
- #         if original_scores[0] >= 0.9:
- #             print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
- #             return
- #         bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
- #         wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
- #         wrapper_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
- #         wrapper_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
- #         wrapped_instruction = f"**Task**: {data.instruction}"
- #         wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
- #         if data.task == "Base":
- #             content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
- #         else:
- #             content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
- #         self.repository.update_case(data.task, embed_index, str_index, content, "bad")
-
- #     def update_case(self, data: DataPoint):
- #         self.update_good_case(data)
- #         self.update_bad_case(data)
- #         self.repository.update_corpus()
-
-
-
  import json
  import os
  import torch
@@ -199,87 +10,84 @@ import copy
 
  import warnings
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
 
  class CaseRepository:
      def __init__(self):
-         # self.embedder = SentenceTransformer("all-MiniLM-L6-v2")
-         # self.embedder.to(device)
-         # self.corpus = self.load_corpus()
-         # self.embedded_corpus = self.embed_corpus()
-         pass
 
      def load_corpus(self):
-         # with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
-         #     corpus = json.load(file)
-         # return corpus
-         pass
 
      def update_corpus(self):
-         # try:
-         #     with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
-         #         json.dump(self.corpus, file, indent=2)
-         # except Exception as e:
-         #     print(f"Error when updating corpus: {e}")
-         pass
 
      def embed_corpus(self):
-         # embedded_corpus = {}
-         # for key, content in self.corpus.items():
-         #     good_index = [item['index']['embed_index'] for item in content['good']]
-         #     encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
-         #     bad_index = [item['index']['embed_index'] for item in content['bad']]
-         #     encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
-         #     embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
-         # return embedded_corpus
-         pass
 
      def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
-         # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-         # # Embedding similarity match
-         # encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
-         # embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
-         # embedding_similarity_scores = embedding_similarity_matrix[0].to(device)
-
-         # # String similarity match
-         # str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
-         # str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
-         # scores_dict = {match[0]: match[1] for match in str_similarity_results}
-         # scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
-         # str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
-
-         # # Normalize scores
-         # embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
-         # str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
-         # if embedding_score_range > 0:
-         #     embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
-         # else:
-         #     embed_norm_scores = embedding_similarity_scores
-         # if str_score_range > 0:
-         #     str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
-         # else:
-         #     str_norm_scores = str_similarity_scores / 100
-
-         # # Combine the scores with weights
-         # combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
-         # original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
-
-         # scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
-         # original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
-         # return scores, indices, original_scores, original_indices
-         pass
 
      def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
-         # _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
-         # top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
-         # return top_matches
-         pass
 
      def update_case(self, task: TaskType, embed_index="", str_index="", content="" ,case_type=""):
-         # self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
-         # self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
-         # print(f"Case updated for {task} task.")
-         pass
 
  class CaseRepositoryHandler:
      def __init__(self, llm: BaseEngine):
@@ -287,105 +95,96 @@ class CaseRepositoryHandler:
          self.llm = llm
 
      def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
-         # prompt = good_case_analysis_instruction.format(
-         #     instruction=instruction, text=text, result=result, additional_info=additional_info
-         # )
-         # for _ in range(3):
-         #     response = self.llm.get_chat_response(prompt)
-         #     response = extract_json_dict(response)
-         #     if not isinstance(response, dict):
-         #         return response
-         # return None
-         pass
 
      def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
-         # prompt = bad_case_reflection_instruction.format(
-         #     instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
-         # )
-         # for _ in range(3):
-         #     response = self.llm.get_chat_response(prompt)
-         #     response = extract_json_dict(response)
-         #     if not isinstance(response, dict):
-         #         return response
-         # return None
-         pass
 
      def __get_index(self, data: DataPoint, case_type: str):
          # set embed_index
-         # embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
 
-         # # set str_index
-         # if data.task == "Base":
-         #     str_index = f"**Task**: {data.instruction}"
-         # else:
-         #     str_index = f"{data.constraint}"
 
-         # if case_type == "bad":
-         #     str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
 
-         # return embed_index, str_index
-         pass
 
      def query_good_case(self, data: DataPoint):
-         # embed_index, str_index = self.__get_index(data, "good")
-         # return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
-         pass
 
      def query_bad_case(self, data: DataPoint):
-         # embed_index, str_index = self.__get_index(data, "bad")
-         # return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
-         pass
 
      def update_good_case(self, data: DataPoint):
-         # if data.truth == "" :
-         #     print("No truth value provided.")
-         #     return
-         # embed_index, str_index = self.__get_index(data, "good")
-         # _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
-         # original_scores = original_scores.tolist()
-         # if original_scores[0] >= 0.9:
-         #     print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
-         #     return
-         # good_case_alaysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
-         # wrapped_good_case_analysis = f"**Analysis**: {good_case_alaysis}"
-         # wrapped_instruction = f"**Task**: {data.instruction}"
-         # wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
-         # wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
-         # if data.task == "Base":
-         #     content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
-         # else:
-         #     content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
-         # self.repository.update_case(data.task, embed_index, str_index, content, "good")
-         pass
 
      def update_bad_case(self, data: DataPoint):
-         # if data.truth == "" :
-         #     print("No truth value provided.")
-         #     return
-         # if normalize_obj(data.pred) == normalize_obj(data.truth):
-         #     return
-         # embed_index, str_index = self.__get_index(data, "bad")
-         # _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
-         # original_scores = original_scores.tolist()
-         # if original_scores[0] >= 0.9:
-         #     print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
-         #     return
-         # bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
374
- # wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
375
- # wrapper_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
376
- # wrapper_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
377
- # wrapped_instruction = f"**Task**: {data.instruction}"
378
- # wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
379
- # if data.task == "Base":
380
- # content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
381
- # else:
382
- # content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
383
- # self.repository.update_case(data.task, embed_index, str_index, content, "bad")
384
- pass
385
 
386
  def update_case(self, data: DataPoint):
387
- # self.update_good_case(data)
388
- # self.update_bad_case(data)
389
- # self.repository.update_corpus()
390
- pass
391
-
 
1
  import json
2
  import os
3
  import torch
 
10
 
11
  import warnings
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ docker_model_path = "/app/model/all-MiniLM-L6-v2"
14
  warnings.filterwarnings("ignore", category=FutureWarning, message=r".*clean_up_tokenization_spaces*")
15
 
16
  class CaseRepository:
17
  def __init__(self):
18
+ try:
19
+ self.embedder = SentenceTransformer(docker_model_path)
20
+ except:
21
+ self.embedder = SentenceTransformer(config['model']['embedding_model'])
22
+ self.embedder.to(device)
23
+ self.corpus = self.load_corpus()
24
+ self.embedded_corpus = self.embed_corpus()
25
 
26
  def load_corpus(self):
27
+ with open(os.path.join(os.path.dirname(__file__), "case_repository.json")) as file:
28
+ corpus = json.load(file)
29
+ return corpus
 
30
 
31
  def update_corpus(self):
32
+ try:
33
+ with open(os.path.join(os.path.dirname(__file__), "case_repository.json"), "w") as file:
34
+ json.dump(self.corpus, file, indent=2)
35
+ except Exception as e:
36
+ print(f"Error when updating corpus: {e}")
 
37
 
38
  def embed_corpus(self):
39
+ embedded_corpus = {}
40
+ for key, content in self.corpus.items():
41
+ good_index = [item['index']['embed_index'] for item in content['good']]
42
+ encoded_good_index = self.embedder.encode(good_index, convert_to_tensor=True).to(device)
43
+ bad_index = [item['index']['embed_index'] for item in content['bad']]
44
+ encoded_bad_index = self.embedder.encode(bad_index, convert_to_tensor=True).to(device)
45
+ embedded_corpus[key] = {"good": encoded_good_index, "bad": encoded_bad_index}
46
+ return embedded_corpus
 
47
 
48
  def get_similarity_scores(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2):
49
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+ # Embedding similarity match
51
+ encoded_embed_query = self.embedder.encode(embed_index, convert_to_tensor=True).to(device)
52
+ embedding_similarity_matrix = self.embedder.similarity(encoded_embed_query, self.embedded_corpus[task][case_type])
53
+ embedding_similarity_scores = embedding_similarity_matrix[0].to(device)
54
+
55
+ # String similarity match
56
+ str_match_corpus = [item['index']['str_index'] for item in self.corpus[task][case_type]]
57
+ str_similarity_results = process.extract(str_index, str_match_corpus, limit=len(str_match_corpus))
58
+ scores_dict = {match[0]: match[1] for match in str_similarity_results}
59
+ scores_in_order = [scores_dict[candidate] for candidate in str_match_corpus]
60
+ str_similarity_scores = torch.tensor(scores_in_order, dtype=torch.float32).to(device)
61
 
62
+ # Normalize scores
63
+ embedding_score_range = embedding_similarity_scores.max() - embedding_similarity_scores.min()
64
+ str_score_range = str_similarity_scores.max() - str_similarity_scores.min()
65
+ if embedding_score_range > 0:
66
+ embed_norm_scores = (embedding_similarity_scores - embedding_similarity_scores.min()) / embedding_score_range
67
+ else:
68
+ embed_norm_scores = embedding_similarity_scores
69
+ if str_score_range > 0:
70
+ str_norm_scores = (str_similarity_scores - str_similarity_scores.min()) / str_score_range
71
+ else:
72
+ str_norm_scores = str_similarity_scores / 100
73
+
74
+ # Combine the scores with weights
75
+ combined_scores = 0.5 * embed_norm_scores + 0.5 * str_norm_scores
76
+ original_combined_scores = 0.5 * embedding_similarity_scores + 0.5 * str_similarity_scores / 100
77
 
78
+ scores, indices = torch.topk(combined_scores, k=min(top_k, combined_scores.size(0)))
79
+ original_scores, original_indices = torch.topk(original_combined_scores, k=min(top_k, original_combined_scores.size(0)))
80
+ return scores, indices, original_scores, original_indices
 
81
 
82
  def query_case(self, task: TaskType, embed_index="", str_index="", case_type="", top_k=2) -> list:
83
+ _, indices, _, _ = self.get_similarity_scores(task, embed_index, str_index, case_type, top_k)
84
+ top_matches = [self.corpus[task][case_type][idx]["content"] for idx in indices]
85
+ return top_matches
 
86
 
87
  def update_case(self, task: TaskType, embed_index="", str_index="", content="", case_type=""):
88
+ self.corpus[task][case_type].append({"index": {"embed_index": embed_index, "str_index": str_index}, "content": content})
89
+ self.embedded_corpus[task][case_type] = torch.cat([self.embedded_corpus[task][case_type], self.embedder.encode([embed_index], convert_to_tensor=True).to(device)], dim=0)
90
+ print(f"A {case_type} case updated for {task} task.")
 
91
 
92
  class CaseRepositoryHandler:
93
  def __init__(self, llm: BaseEngine):
 
95
  self.llm = llm
96
 
97
  def __get_good_case_analysis(self, instruction="", text="", result="", additional_info=""):
98
+ prompt = good_case_analysis_instruction.format(
99
+ instruction=instruction, text=text, result=result, additional_info=additional_info
100
+ )
101
+ for _ in range(3):
102
+ response = self.llm.get_chat_response(prompt)
103
+ response = extract_json_dict(response)
104
+ if not isinstance(response, dict):
105
+ return response
106
+ return None
 
107
 
108
  def __get_bad_case_reflection(self, instruction="", text="", original_answer="", correct_answer="", additional_info=""):
109
+ prompt = bad_case_reflection_instruction.format(
110
+ instruction=instruction, text=text, original_answer=original_answer, correct_answer=correct_answer, additional_info=additional_info
111
+ )
112
+ for _ in range(3):
113
+ response = self.llm.get_chat_response(prompt)
114
+ response = extract_json_dict(response)
115
+ if not isinstance(response, dict):
116
+ return response
117
+ return None
 
118
 
119
  def __get_index(self, data: DataPoint, case_type: str):
120
  # set embed_index
121
+ embed_index = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
122
 
123
+ # set str_index
124
+ if data.task == "Base":
125
+ str_index = f"**Task**: {data.instruction}"
126
+ else:
127
+ str_index = f"{data.constraint}"
128
 
129
+ if case_type == "bad":
130
+ str_index += f"\n\n**Original Result**: {json.dumps(data.pred)}"
131
 
132
+ return embed_index, str_index
 
133
 
134
  def query_good_case(self, data: DataPoint):
135
+ embed_index, str_index = self.__get_index(data, "good")
136
+ return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="good")
 
137
 
138
  def query_bad_case(self, data: DataPoint):
139
+ embed_index, str_index = self.__get_index(data, "bad")
140
+ return self.repository.query_case(task=data.task, embed_index=embed_index, str_index=str_index, case_type="bad")
 
141
 
142
  def update_good_case(self, data: DataPoint):
143
+ if data.truth == "":
144
+ print("No truth value provided.")
145
+ return
146
+ embed_index, str_index = self.__get_index(data, "good")
147
+ _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "good", 1)
148
+ original_scores = original_scores.tolist()
149
+ if original_scores[0] >= 0.9:
150
+ print("The similar good case is already in the corpus. Similarity Score: ", original_scores[0])
151
+ return
152
+ good_case_analysis = self.__get_good_case_analysis(instruction=data.instruction, text=data.distilled_text, result=data.truth, additional_info=data.constraint)
153
+ wrapped_good_case_analysis = f"**Analysis**: {good_case_analysis}"
154
+ wrapped_instruction = f"**Task**: {data.instruction}"
155
+ wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
156
+ wrapped_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
157
+ if data.task == "Base":
158
+ content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
159
+ else:
160
+ content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapped_good_case_analysis}\n\n{wrapped_answer}"
161
+ self.repository.update_case(data.task, embed_index, str_index, content, "good")
 
162
 
163
  def update_bad_case(self, data: DataPoint):
164
+ if data.truth == "":
165
+ print("No truth value provided.")
166
+ return
167
+ if normalize_obj(data.pred) == normalize_obj(data.truth):
168
+ return
169
+ embed_index, str_index = self.__get_index(data, "bad")
170
+ _, _, original_scores, _ = self.repository.get_similarity_scores(data.task, embed_index, str_index, "bad", 1)
171
+ original_scores = original_scores.tolist()
172
+ if original_scores[0] >= 0.9:
173
+ print("The similar bad case is already in the corpus. Similarity Score: ", original_scores[0])
174
+ return
175
+ bad_case_reflection = self.__get_bad_case_reflection(instruction=data.instruction, text=data.distilled_text, original_answer=data.pred, correct_answer=data.truth, additional_info=data.constraint)
176
+ wrapped_bad_case_reflection = f"**Reflection**: {bad_case_reflection}"
177
+ wrapper_original_answer = f"**Original Answer**: {json.dumps(data.pred)}"
178
+ wrapper_correct_answer = f"**Correct Answer**: {json.dumps(data.truth)}"
179
+ wrapped_instruction = f"**Task**: {data.instruction}"
180
+ wrapped_text = f"**Text**: {data.distilled_text}\n{data.chunk_text_list[0]}"
181
+ if data.task == "Base":
182
+ content = f"{wrapped_instruction}\n\n{wrapped_text}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
183
+ else:
184
+ content = f"{wrapped_text}\n\n{data.constraint}\n\n{wrapper_original_answer}\n\n{wrapped_bad_case_reflection}\n\n{wrapper_correct_answer}"
185
+ self.repository.update_case(data.task, embed_index, str_index, content, "bad")
 
186
 
187
  def update_case(self, data: DataPoint):
188
+ self.update_good_case(data)
189
+ self.update_bad_case(data)
190
+ self.repository.update_corpus()
 
 
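For reference, the restored CaseRepository.get_similarity_scores ranks stored cases by an equal blend of embedding similarity and fuzzy string similarity. Below is a minimal standalone sketch of that ranking logic, assuming sentence-transformers and rapidfuzz are installed; rank_cases, the model name, and the corpus arguments are illustrative, not part of the repository.

```python
# Minimal sketch of the combined ranking performed by get_similarity_scores.
# Assumes sentence-transformers and rapidfuzz; rank_cases is illustrative.
import torch
from rapidfuzz import fuzz
from sentence_transformers import SentenceTransformer, util

def rank_cases(embed_query, str_query, embed_corpus, str_corpus, top_k=2):
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    # Embedding (cosine) similarity between the query and every stored case.
    query_vec = embedder.encode(embed_query, convert_to_tensor=True)
    corpus_vecs = embedder.encode(embed_corpus, convert_to_tensor=True)
    embed_scores = util.cos_sim(query_vec, corpus_vecs)[0]
    # Fuzzy string similarity on the str_index side (0-100 scale).
    str_scores = torch.tensor(
        [fuzz.WRatio(str_query, c) for c in str_corpus],
        dtype=torch.float32, device=embed_scores.device,
    )

    def min_max(t):
        rng = t.max() - t.min()
        return (t - t.min()) / rng if rng > 0 else t

    # Equal-weight blend of the two normalized score sets, then take the top-k.
    combined = 0.5 * min_max(embed_scores) + 0.5 * min_max(str_scores)
    return torch.topk(combined, k=min(top_k, combined.size(0)))
```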
src/modules/knowledge_base/schema_repository.py CHANGED
@@ -85,7 +85,7 @@ class NewsReport(BaseModel):
85
  publication_date: Optional[str] = Field(description="The publication date of the report")
86
  keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the news report")
87
  events: List[Event] = Field(description="Events covered in the news report")
88
- quotes: Optional[List[str]] = Field(default=None, description="Quotes related to the news, if any")
89
  viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
90
 
91
  # --------- You can customize new extraction schemas below -------- #
 
85
  publication_date: Optional[str] = Field(description="The publication date of the report")
86
  keywords: Optional[List[str]] = Field(description="List of keywords or topics covered in the news report")
87
  events: List[Event] = Field(description="Events covered in the news report")
88
+ quotes: Optional[dict] = Field(default=None, description="Quotes related to the news, with keys as the citation sources and values as the quoted content. ")
89
  viewpoints: Optional[List[str]] = Field(default=None, description="Different viewpoints regarding the news")
90
 
91
  # --------- You can customize new extraction schemas below -------- #
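The quotes field on NewsReport changes from a flat list to a mapping of citation source to quoted content. An illustrative value of the new shape (the sources and quotes below are made up):

```python
# Illustrative instance of the updated NewsReport.quotes field:
# keys name the citation source, values hold the quoted content.
quotes = {
    "Reuters": "The company declined to comment on the merger.",
    "CEO Jane Doe": "We expect the deal to close by the third quarter.",
}
```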
src/modules/schema_agent.py CHANGED
@@ -48,9 +48,6 @@ class SchemaAnalyzer:
48
  def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
49
  prompt = deduced_schema_code_instruction.format(examples=example_wrapper(code_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
50
  response = self.llm.get_chat_response(prompt)
51
- print(f"schema prompt: {prompt}")
52
- print("========================================")
53
- print(f"schema response: {response}")
54
  code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
55
  if code_blocks:
56
  try:
 
48
  def get_deduced_schema_code(self, instruction: str, text: str, distilled_text: str):
49
  prompt = deduced_schema_code_instruction.format(examples=example_wrapper(code_schema_examples), instruction=instruction, distilled_text=distilled_text, text=text)
50
  response = self.llm.get_chat_response(prompt)
 
 
 
51
  code_blocks = re.findall(r'```[^\n]*\n(.*?)\n```', response, re.DOTALL)
52
  if code_blocks:
53
  try:
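With the debug prints removed, get_deduced_schema_code still pulls the generated schema out of the model reply with the fenced-code regex kept above. A quick self-contained check of that pattern; the sample reply is made up, and the `{3} spelling simply matches three literal backticks, equivalent to the pattern in the source.

```python
import re

# Sample LLM reply containing one fenced code block (illustrative).
FENCE = "`" * 3
response = (
    "Here is the deduced schema:\n"
    f"{FENCE}python\n"
    "class Person(BaseModel):\n    name: str\n"
    f"{FENCE}\n"
)
# Matches three literal backticks on each side, like the pattern used above.
code_blocks = re.findall(r"`{3}[^\n]*\n(.*?)\n`{3}", response, re.DOTALL)
print(code_blocks[0])  # prints the extracted class definition
```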
src/pipeline.py CHANGED
@@ -3,6 +3,7 @@ from models import *
3
  from utils import *
4
  from modules import *
5
 
 
6
  class Pipeline:
7
  def __init__(self, llm: BaseEngine):
8
  self.llm = llm
@@ -11,17 +12,26 @@ class Pipeline:
11
  self.extraction_agent = ExtractionAgent(llm = llm, case_repo = self.case_repo)
12
  self.reflection_agent = ReflectionAgent(llm = llm, case_repo = self.case_repo)
13
 
14
- def __init_method(self, data: DataPoint, process_method):
 
15
  default_order = ["schema_agent", "extraction_agent", "reflection_agent"]
16
- if "schema_agent" not in process_method:
17
- process_method["schema_agent"] = "get_default_schema"
18
- if data.task == "Base":
19
- process_method["schema_agent"] = "get_deduced_schema"
20
  if data.task != "Base":
21
- process_method["schema_agent"] = "get_retrieved_schema"
22
- if "extraction_agent" not in process_method:
23
- process_method["extraction_agent"] = "extract_information_direct"
24
- sorted_process_method = {key: process_method[key] for key in default_order if key in process_method}
25
  return sorted_process_method
26
 
27
  def __init_data(self, data: DataPoint):
@@ -36,8 +46,6 @@ class Pipeline:
36
  data.output_schema = "EventList"
37
  return data
38
 
39
-
40
-
41
  # main entry
42
  def get_extract_result(self,
43
  task: TaskType,
@@ -49,23 +57,29 @@ class Pipeline:
49
  file_path: str = "",
50
  truth: str = "",
51
  mode: str = "quick",
52
- update_case: bool = False
53
- ):
 
54
  print(f" task: {task},\n instruction: {instruction},\n text: {text},\n output_schema: {output_schema},\n constraint: {constraint},\n use_file: {use_file},\n file_path: {file_path},\n truth: {truth},\n mode: {mode},\n update_case: {update_case}")
 
 
 
 
 
55
  data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, file_path=file_path, truth=truth)
56
  data = self.__init_data(data)
57
  if mode in config['agent']['mode'].keys():
58
- process_method = config['agent']['mode'][mode]
59
  else:
60
  process_method = mode
61
- print(f"data=================: {data.task}")
62
- print(f"process_method=================: {process_method}")
63
  sorted_process_method = self.__init_method(data, process_method)
64
- print_schema = False
65
- frontend_schema = ""
66
- frontend_res = ""
 
 
 
67
  # Information Extract
68
- print(f"sorted_process_method=================: {sorted_process_method}")
69
  for agent_name, method_name in sorted_process_method.items():
70
  agent = getattr(self, agent_name, None)
71
  if not agent:
@@ -74,17 +88,23 @@ class Pipeline:
74
  if not method:
75
  raise AttributeError(f"Method '{method_name}' not found in {agent_name}.")
76
  data = method(data)
77
- if not print_schema and data.print_schema:
78
  print("Schema: \n", data.print_schema)
79
  frontend_schema = data.print_schema
80
  print_schema = True
81
  data = self.extraction_agent.summarize_answer(data)
 
 
 
 
82
  print("Extraction Result: \n", json.dumps(data.pred, indent=2))
83
- frontend_res = data.pred
 
 
84
  # Case Update
85
  if update_case:
86
  if (data.truth == ""):
87
- truth = input("Please enter the correct answer you prefer, or press Enter to accept the current answer: ")
88
  if truth.strip() == "":
89
  data.truth = data.pred
90
  else:
 
3
  from utils import *
4
  from modules import *
5
 
6
+
7
  class Pipeline:
8
  def __init__(self, llm: BaseEngine):
9
  self.llm = llm
 
12
  self.extraction_agent = ExtractionAgent(llm = llm, case_repo = self.case_repo)
13
  self.reflection_agent = ReflectionAgent(llm = llm, case_repo = self.case_repo)
14
 
15
+ def __check_consistency(self, llm, task, mode, update_case):
16
+ if llm.name == "OneKE":
17
+ if task == "Base":
18
+ raise ValueError("The finetuned OneKE only supports quick extraction mode for NER, RE and EE Task.")
19
+ else:
20
+ mode = "quick"
21
+ update_case = False
22
+ print("The fine-tuned OneKE defaults to quick extraction mode without case update.")
23
+ return mode, update_case
24
+ return mode, update_case
25
+
26
+ def __init_method(self, data: DataPoint, process_method2):
27
  default_order = ["schema_agent", "extraction_agent", "reflection_agent"]
28
+ if "schema_agent" not in process_method2:
29
+ process_method2["schema_agent"] = "get_default_schema"
 
 
30
  if data.task != "Base":
31
+ process_method2["schema_agent"] = "get_retrieved_schema"
32
+ if "extraction_agent" not in process_method2:
33
+ process_method2["extraction_agent"] = "extract_information_direct"
34
+ sorted_process_method = {key: process_method2[key] for key in default_order if key in process_method2}
35
  return sorted_process_method
36
 
37
  def __init_data(self, data: DataPoint):
 
46
  data.output_schema = "EventList"
47
  return data
48
 
 
 
49
  # main entry
50
  def get_extract_result(self,
51
  task: TaskType,
 
57
  file_path: str = "",
58
  truth: str = "",
59
  mode: str = "quick",
60
+ update_case: bool = False,
61
+ show_trajectory: bool = False
62
+ ):
63
  print(f" task: {task},\n instruction: {instruction},\n text: {text},\n output_schema: {output_schema},\n constraint: {constraint},\n use_file: {use_file},\n file_path: {file_path},\n truth: {truth},\n mode: {mode},\n update_case: {update_case}")
64
+
65
+ # Check consistency
66
+ mode, update_case = self.__check_consistency(self.llm, task, mode, update_case)
67
+
68
+ # Load Data
69
  data = DataPoint(task=task, instruction=instruction, text=text, output_schema=output_schema, constraint=constraint, use_file=use_file, file_path=file_path, truth=truth)
70
  data = self.__init_data(data)
71
  if mode in config['agent']['mode'].keys():
72
+ process_method = config['agent']['mode'][mode].copy()
73
  else:
74
  process_method = mode
 
 
75
  sorted_process_method = self.__init_method(data, process_method)
76
+ print("Process Method: ", sorted_process_method)
77
+
78
+ print_schema = False #
79
+ frontend_schema = "" #
80
+ frontend_res = "" #
81
+
82
  # Information Extract
 
83
  for agent_name, method_name in sorted_process_method.items():
84
  agent = getattr(self, agent_name, None)
85
  if not agent:
 
88
  if not method:
89
  raise AttributeError(f"Method '{method_name}' not found in {agent_name}.")
90
  data = method(data)
91
+ if not print_schema and data.print_schema: #
92
  print("Schema: \n", data.print_schema)
93
  frontend_schema = data.print_schema
94
  print_schema = True
95
  data = self.extraction_agent.summarize_answer(data)
96
+
97
+ # show result
98
+ if show_trajectory:
99
+ print("Extraction Trajectory: \n", json.dumps(data.get_result_trajectory(), indent=2))
100
  print("Extraction Result: \n", json.dumps(data.pred, indent=2))
101
+
102
+ frontend_res = data.pred #
103
+
104
  # Case Update
105
  if update_case:
106
  if (data.truth == ""):
107
+ truth = input("Please enter the correct answer you prefer, or just press Enter to accept the current answer: ")
108
  if truth.strip() == "":
109
  data.truth = data.pred
110
  else:
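A hedged usage sketch of the updated entry point, mirroring the call pattern in src/run.py. "ChatGPT" stands in for whichever BaseEngine subclass src/models actually exports, and the key, URL, sample text, and constraint are placeholders.

```python
# Hedged usage sketch; ChatGPT is a placeholder for any BaseEngine subclass
# exported by src/models, and all values below are illustrative.
from models import ChatGPT
from pipeline import Pipeline

llm = ChatGPT("gpt-4o-mini", "YOUR_API_KEY", "https://api.openai.com/v1")
pipeline = Pipeline(llm)
result, trajectory = pipeline.get_extract_result(
    task="NER",
    text="Finally, every other year, ELRA organizes a major conference LREC.",
    constraint='["person", "organization", "conference"]',
    mode="quick",
    update_case=False,
    show_trajectory=True,  # new flag added in this commit
)
```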
src/run.py CHANGED
@@ -8,81 +8,35 @@ from models import *
8
  from utils import *
9
  from modules import *
10
 
11
- def load_extraction_config(yaml_path):
12
- # Read the YAML content from the file path
13
- if not os.path.exists(yaml_path):
14
- print(f"Error: The config file '{yaml_path}' does not exist.")
15
- return {}
16
-
17
- with open(yaml_path, 'r') as file:
18
- config = yaml.safe_load(file)
19
-
20
- # Extract the 'extraction' configuration dictionary
21
- model_config = config.get('model', {})
22
- extraction_config = config.get('extraction', {})
23
- # model config
24
- model_name_or_path = model_config.get('model_name_or_path', "")
25
- model_category = model_config.get('category', "")
26
- api_key = model_config.get('api_key', "")
27
- base_url = model_config.get('base_url', "")
28
-
29
- # extraction config
30
- task = extraction_config.get('task', "")
31
- instruction = extraction_config.get('instruction', "")
32
- text = extraction_config.get('text', "")
33
- output_schema = extraction_config.get('output_schema', "")
34
- constraint = extraction_config.get('constraint', "")
35
- truth = extraction_config.get('truth', "")
36
- use_file = extraction_config.get('use_file', False)
37
- mode = extraction_config.get('mode', "quick")
38
- update_case = extraction_config.get('update_case', False)
39
-
40
- # Return a dictionary containing these variables
41
- return {
42
- "model": {
43
- "model_name_or_path": model_name_or_path,
44
- "category": model_category,
45
- "api_key": api_key,
46
- "base_url": base_url
47
- },
48
- "extraction": {
49
- "task": task,
50
- "instruction": instruction,
51
- "text": text,
52
- "output_schema": output_schema,
53
- "constraint": constraint,
54
- "truth": truth,
55
- "use_file": use_file,
56
- "mode": mode,
57
- "update_case": update_case
58
- }
59
- }
60
-
61
-
62
  def main():
63
- # Create the command-line argument parser
64
- parser = argparse.ArgumentParser(description='Run the extraction model.')
65
- parser.add_argument('--config', type=str, required=True,
66
  help='Path to the YAML configuration file.')
67
-
68
- # Parse the command-line arguments
69
  args = parser.parse_args()
70
 
71
- # Load the configuration
72
  config = load_extraction_config(args.config)
 
73
  model_config = config['model']
74
- extraction_config = config['extraction']
75
- clazz = getattr(models, model_config['category'], None)
76
- if clazz is None:
77
- print(f"Error: The model category '{model_config['category']}' is not supported.")
78
- return
79
- if model_config['api_key'] == "":
80
- model = clazz(model_config['model_name_or_path'])
81
  else:
82
- model = clazz(model_config['model_name_or_path'], model_config['api_key'], model_config['base_url'])
 
 
 
 
 
 
 
83
  pipeline = Pipeline(model)
84
- result, trajectory, *_ = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'])
85
- return
 
 
86
 
87
  if __name__ == "__main__":
88
  main()
 
8
  from utils import *
9
  from modules import *
10
 
 
11
  def main():
12
+ # Create command-line argument parser
13
+ parser = argparse.ArgumentParser(description='Run the extraction framework.')
14
+ parser.add_argument('--config', type=str, required=True,
15
  help='Path to the YAML configuration file.')
16
+
17
+ # Parse command-line arguments
18
  args = parser.parse_args()
19
 
20
+ # Load configuration
21
  config = load_extraction_config(args.config)
22
+ # Model config
23
  model_config = config['model']
24
+ if model_config['vllm_serve'] == True:
25
+ model = LocalServer(model_config['model_name_or_path'])
 
 
 
 
 
26
  else:
27
+ clazz = getattr(models, model_config['category'], None)
28
+ if clazz is None:
29
+ print(f"Error: The model category '{model_config['category']}' is not supported.")
30
+ return
31
+ if model_config['api_key'] == "":
32
+ model = clazz(model_config['model_name_or_path'])
33
+ else:
34
+ model = clazz(model_config['model_name_or_path'], model_config['api_key'], model_config['base_url'])
35
  pipeline = Pipeline(model)
36
+ # Extraction config
37
+ extraction_config = config['extraction']
38
+ result, trajectory = pipeline.get_extract_result(task=extraction_config['task'], instruction=extraction_config['instruction'], text=extraction_config['text'], output_schema=extraction_config['output_schema'], constraint=extraction_config['constraint'], use_file=extraction_config['use_file'], file_path=extraction_config['file_path'], truth=extraction_config['truth'], mode=extraction_config['mode'], update_case=extraction_config['update_case'], show_trajectory=extraction_config['show_trajectory'])
39
+ return
40
 
41
  if __name__ == "__main__":
42
  main()
src/utils/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (274 Bytes)
 
src/utils/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/src/utils/__pycache__/__init__.cpython-39.pyc and b/src/utils/__pycache__/__init__.cpython-39.pyc differ
 
src/utils/__pycache__/data_def.cpython-311.pyc DELETED
Binary file (3.07 kB)
 
src/utils/__pycache__/data_def.cpython-39.pyc CHANGED
Binary files a/src/utils/__pycache__/data_def.cpython-39.pyc and b/src/utils/__pycache__/data_def.cpython-39.pyc differ
 
src/utils/__pycache__/process.cpython-311.pyc DELETED
Binary file (10.7 kB)
 
src/utils/__pycache__/process.cpython-39.pyc CHANGED
Binary files a/src/utils/__pycache__/process.cpython-39.pyc and b/src/utils/__pycache__/process.cpython-39.pyc differ
 
src/utils/data_def.py CHANGED
@@ -3,7 +3,6 @@ from models import *
3
  from .process import *
4
  # predefined processing logic for routine extraction tasks
5
  TaskType = Literal["NER", "RE", "EE", "Base"]
6
- ModelType = Literal["gpt-3.5-turbo", "gpt-4o"]
7
 
8
  class DataPoint:
9
  def __init__(self,
 
3
  from .process import *
4
  # predefined processing logic for routine extraction tasks
5
  TaskType = Literal["NER", "RE", "EE", "Base"]
 
6
 
7
  class DataPoint:
8
  def __init__(self,
src/utils/process.py CHANGED
@@ -17,7 +17,65 @@ import inspect
17
  import ast
18
  with open(os.path.join(os.path.dirname(__file__), "..", "config.yaml")) as file:
19
  config = yaml.safe_load(file)
 
 
20
 
21
  # Split the string text into chunks
22
  def chunk_str(text):
23
  sentences = sent_tokenize(text)
@@ -165,7 +223,6 @@ def normalize_obj(value):
165
  if isinstance(value, dict):
166
  return frozenset((k, normalize_obj(v)) for k, v in value.items())
167
  elif isinstance(value, (list, set, tuple)):
168
- # Convert the Counter to a tuple so it can be hashed
169
  return tuple(Counter(map(normalize_obj, value)).items())
170
  elif isinstance(value, str):
171
  return format_string(value)
 
17
  import ast
18
  with open(os.path.join(os.path.dirname(__file__), "..", "config.yaml")) as file:
19
  config = yaml.safe_load(file)
20
+
21
+ # Load configuration
22
+ def load_extraction_config(yaml_path):
23
+ # Read YAML content from the file path
24
+ if not os.path.exists(yaml_path):
25
+ print(f"Error: The config file '{yaml_path}' does not exist.")
26
+ return {}
27
+
28
+ with open(yaml_path, 'r') as file:
29
+ config = yaml.safe_load(file)
30
+
31
+ # Extract the 'extraction' configuration dictionary
32
+ model_config = config.get('model', {})
33
+ extraction_config = config.get('extraction', {})
34
+
35
+ # Model config
36
+ model_name_or_path = model_config.get('model_name_or_path', "")
37
+ model_category = model_config.get('category', "")
38
+ api_key = model_config.get('api_key', "")
39
+ base_url = model_config.get('base_url', "")
40
+ vllm_serve = model_config.get('vllm_serve', False)
41
+
42
+ # Extraction config
43
+ task = extraction_config.get('task', "")
44
+ instruction = extraction_config.get('instruction', "")
45
+ text = extraction_config.get('text', "")
46
+ output_schema = extraction_config.get('output_schema', "")
47
+ constraint = extraction_config.get('constraint', "")
48
+ truth = extraction_config.get('truth', "")
49
+ use_file = extraction_config.get('use_file', False)
50
+ file_path = extraction_config.get('file_path', "")
51
+ mode = extraction_config.get('mode', "quick")
52
+ update_case = extraction_config.get('update_case', False)
53
+ show_trajectory = extraction_config.get('show_trajectory', False)
54
 
55
+ # Return a dictionary containing these variables
56
+ return {
57
+ "model": {
58
+ "model_name_or_path": model_name_or_path,
59
+ "category": model_category,
60
+ "api_key": api_key,
61
+ "base_url": base_url,
62
+ "vllm_serve": vllm_serve
63
+ },
64
+ "extraction": {
65
+ "task": task,
66
+ "instruction": instruction,
67
+ "text": text,
68
+ "output_schema": output_schema,
69
+ "constraint": constraint,
70
+ "truth": truth,
71
+ "use_file": use_file,
72
+ "file_path": file_path,
73
+ "mode": mode,
74
+ "update_case": update_case,
75
+ "show_trajectory": show_trajectory
76
+ }
77
+ }
78
+
79
  # Split the string text into chunks
80
  def chunk_str(text):
81
  sentences = sent_tokenize(text)
 
223
  if isinstance(value, dict):
224
  return frozenset((k, normalize_obj(v)) for k, v in value.items())
225
  elif isinstance(value, (list, set, tuple)):
 
226
  return tuple(Counter(map(normalize_obj, value)).items())
227
  elif isinstance(value, str):
228
  return format_string(value)
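For reference, a sketch of the dictionary shape that the relocated load_extraction_config now returns, including the new vllm_serve, file_path, and show_trajectory keys. Every value below is a placeholder, and the category string must name a class that src/models actually exports.

```python
# Illustrative return value of load_extraction_config (all values are placeholders).
example_config = {
    "model": {
        "model_name_or_path": "gpt-4o-mini",
        "category": "ChatGPT",      # resolved via getattr(models, category) in run.py
        "api_key": "",
        "base_url": "",
        "vllm_serve": False,        # new key: use LocalServer when True
    },
    "extraction": {
        "task": "NER",
        "instruction": "",
        "text": "",
        "output_schema": "",
        "constraint": "",
        "truth": "",
        "use_file": False,
        "file_path": "",            # new key
        "mode": "quick",
        "update_case": False,
        "show_trajectory": False,   # new key
    },
}
```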
src/{main.py → webui.py} RENAMED
@@ -147,6 +147,7 @@ def create_interface():
147
  use_file=use_file,
148
  file_path=file_path,
149
  text=text,
 
150
  )
151
 
152
  ger_frontend_schema = str(ger_frontend_schema)
@@ -159,8 +160,6 @@ def create_interface():
159
 
160
  def clear_all():
161
  return (
162
- gr.update(value=""), # model
163
- gr.update(value=""), # API Key
164
  gr.update(value=""), # task
165
  gr.update(value="", visible=False), # instruction
166
  gr.update(value="", visible=False), # constraint
@@ -223,9 +222,6 @@ def create_interface():
223
  clear_button.click(
224
  fn=clear_all,
225
  outputs=[
226
- model_gr,
227
- api_key_gr,
228
- base_url_gr,
229
  task_gr,
230
  instruction_gr,
231
  constraint_gr,
 
147
  use_file=use_file,
148
  file_path=file_path,
149
  text=text,
150
+ show_trajectory=False,
151
  )
152
 
153
  ger_frontend_schema = str(ger_frontend_schema)
 
160
 
161
  def clear_all():
162
  return (
 
 
163
  gr.update(value=""), # task
164
  gr.update(value="", visible=False), # instruction
165
  gr.update(value="", visible=False), # constraint
 
222
  clear_button.click(
223
  fn=clear_all,
224
  outputs=[
 
 
 
225
  task_gr,
226
  instruction_gr,
227
  constraint_gr,
src/webui/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .interface import InterFace
 
 
src/webui/__pycache__/__init__.cpython-39.pyc DELETED
Binary file (197 Bytes)