Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- .gitattributes +4 -0
- .gradio/certificate.pem +31 -0
- PROMPT_TEMPLATE.py +29 -0
- README.md +3 -9
- __pycache__/PROMPT_TEMPLATE.cpython-311.pyc +0 -0
- __pycache__/doubao_service.cpython-311.pyc +0 -0
- __pycache__/retriever.cpython-311.pyc +0 -0
- conf/config.ini +22 -0
- conf/logs.ini +28 -0
- doubao_service.py +166 -0
- interface.py +186 -0
- retriever.py +121 -0
- store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/data_level0.bin +3 -0
- store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/header.bin +3 -0
- store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/length.bin +3 -0
- store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/link_lists.bin +0 -0
- store/requirement_full_database/chroma.sqlite3 +3 -0
- store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/data_level0.bin +3 -0
- store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/header.bin +3 -0
- store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/length.bin +3 -0
- store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/link_lists.bin +0 -0
- store/requirement_v1_database/chroma.sqlite3 +3 -0
- store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/data_level0.bin +3 -0
- store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/header.bin +3 -0
- store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/length.bin +3 -0
- store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/link_lists.bin +0 -0
- store/requirement_v2_database/chroma.sqlite3 +3 -0
- store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/data_level0.bin +3 -0
- store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/header.bin +3 -0
- store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/length.bin +3 -0
- store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/link_lists.bin +0 -0
- store/safeguard_database/chroma.sqlite3 +3 -0
- test.ipynb +0 -0
- util/Embeddings.py +195 -0
- util/__init__.py +3 -0
- util/__pycache__/Embeddings.cpython-311.pyc +0 -0
- util/__pycache__/__init__.cpython-311.pyc +0 -0
- util/__pycache__/__init__.cpython-39.pyc +0 -0
- util/__pycache__/config_util.cpython-311.pyc +0 -0
- util/__pycache__/logger_util.cpython-311.pyc +0 -0
- util/__pycache__/logger_util.cpython-39.pyc +0 -0
- util/__pycache__/vector_base.cpython-311.pyc +0 -0
- util/__pycache__/vector_base.cpython-39.pyc +0 -0
- util/config_util.py +21 -0
- util/logger_util.py +23 -0
- util/vector_base.py +79 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
store/requirement_full_database/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
|
37 |
+
store/requirement_v1_database/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
|
38 |
+
store/requirement_v2_database/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
|
39 |
+
store/safeguard_database/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
|
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
31 |
+
-----END CERTIFICATE-----
|
PROMPT_TEMPLATE.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
former_string = """# Role
|
2 |
+
## You are an expert in the field of law, and you are good at explaining why law requirements are related to their matching safeguards.
|
3 |
+
|
4 |
+
# Task
|
5 |
+
You need to analyze **requirement** and **privacy objective dict** provided by the user, each key in the **privacy objective dict** is a specific **privacy objective** and has corresponding **safeguards list**, you need to explain why each **safeguard** is related to the **requirement**.
|
6 |
+
|
7 |
+
# Output format
|
8 |
+
For each **safeguard** in the **safeguards list**, explain its association with the requirement in the following format:
|
9 |
+
{
|
10 |
+
"privacy objective":
|
11 |
+
[
|
12 |
+
{
|
13 |
+
"safeguard number": "xxx",
|
14 |
+
"safeguard description": "xxx",
|
15 |
+
"analysis": "xxx"
|
16 |
+
},
|
17 |
+
...
|
18 |
+
]
|
19 |
+
}
|
20 |
+
Please return your answers in JSON format."""
|
21 |
+
input_format = """
|
22 |
+
# Input
|
23 |
+
Requirement:
|
24 |
+
{requirement}
|
25 |
+
Safeguards list:
|
26 |
+
{safeguards}
|
27 |
+
"""
|
28 |
+
def prompt_template(requirement, safeguards):
|
29 |
+
return former_string + input_format.format(requirement=requirement, safeguards=safeguards)
|
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: indigo
|
5 |
-
colorTo: blue
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.23.
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: NLL_Interface
|
3 |
+
app_file: interface.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 5.23.2
|
|
|
|
|
6 |
---
|
|
|
|
__pycache__/PROMPT_TEMPLATE.cpython-311.pyc
ADDED
Binary file (1.39 kB). View file
|
|
__pycache__/doubao_service.cpython-311.pyc
ADDED
Binary file (7.68 kB). View file
|
|
__pycache__/retriever.cpython-311.pyc
ADDED
Binary file (6.95 kB). View file
|
|
conf/config.ini
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[DouBao128ProModelInfo]
|
2 |
+
ACCESS_KEY = AKLTYjI0OWNiMGVmZGEwNDNhYjk3YzJhNDdlYTI1NTA5M2M
|
3 |
+
SECRET_KEY = TldReVlUaGlNVE0wT0dVeE5ESTFOV0l3T1RKa1lXSm1aak0zTXpJeE5qVQ==
|
4 |
+
BASE_URL = https://ark.cn-beijing.volces.com/api/v3
|
5 |
+
API_KEY = 0c654012-8989-455f-8a5d-032fc067fbc8
|
6 |
+
ENDPOINT_ID = ep-20241223113321-g47rr
|
7 |
+
CYCLE_TIMES = 3
|
8 |
+
MERGER_RETRY_TIMES = 4
|
9 |
+
MAX_RETRY_TIMES = 4
|
10 |
+
MAX_THREAD_NUM = 2
|
11 |
+
|
12 |
+
|
13 |
+
[DouBaoPreviewModelInfo]
|
14 |
+
ACCESS_KEY = AKLTYjI0OWNiMGVmZGEwNDNhYjk3YzJhNDdlYTI1NTA5M2M
|
15 |
+
SECRET_KEY = TldReVlUaGlNVE0wT0dVeE5ESTFOV0l3T1RKa1lXSm1aak0zTXpJeE5qVQ==
|
16 |
+
BASE_URL = https://ark.cn-beijing.volces.com/api/v3
|
17 |
+
API_KEY = 0c654012-8989-455f-8a5d-032fc067fbc8
|
18 |
+
ENDPOINT_ID = ep-20240923111539-mbwqc
|
19 |
+
CYCLE_TIMES = 3
|
20 |
+
MERGER_RETRY_TIMES = 4
|
21 |
+
MAX_RETRY_TIMES = 4
|
22 |
+
MAX_THREAD_NUM = 2
|
conf/logs.ini
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[loggers]
|
2 |
+
keys = root,Robot
|
3 |
+
|
4 |
+
[handlers]
|
5 |
+
keys = consoleHandler
|
6 |
+
|
7 |
+
[formatters]
|
8 |
+
keys = simpleFormatter
|
9 |
+
|
10 |
+
[logger_root]
|
11 |
+
level = INFO
|
12 |
+
handlers= consoleHandler
|
13 |
+
|
14 |
+
[logger_Robot]
|
15 |
+
level= INFO
|
16 |
+
handlers = consoleHandler
|
17 |
+
qualname = Robot
|
18 |
+
propagate=0
|
19 |
+
|
20 |
+
[handler_consoleHandler]
|
21 |
+
class = StreamHandler
|
22 |
+
level = INFO
|
23 |
+
formatter = simpleFormatter
|
24 |
+
|
25 |
+
|
26 |
+
[formatter_simpleFormatter]
|
27 |
+
format = %(asctime)s %(levelname)s %(filename)s %(lineno)d %(message)s
|
28 |
+
datefmt = %Y-%m-%d %H:%M:%S
|
doubao_service.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import requests
|
3 |
+
import json
|
4 |
+
from volcenginesdkarkruntime import Ark
|
5 |
+
from util.config_util import read_config as config
|
6 |
+
from util import logger
|
7 |
+
import volcenginesdkcore
|
8 |
+
import volcenginesdkark
|
9 |
+
from volcenginesdkcore.rest import ApiException
|
10 |
+
from util.logger_util import log_decorate
|
11 |
+
|
12 |
+
|
13 |
+
class DouBaoService:
|
14 |
+
|
15 |
+
def __init__(self, model_name):
|
16 |
+
self.conf = config()[f"{model_name}ModelInfo"]
|
17 |
+
self.client = self.init_client()
|
18 |
+
self._complete_args = {}
|
19 |
+
|
20 |
+
|
21 |
+
|
22 |
+
def init_client(self):
|
23 |
+
base_url = self.conf["BASE_URL"]
|
24 |
+
ak = self.conf["ACCESS_KEY"]
|
25 |
+
sk = self.conf["SECRET_KEY"]
|
26 |
+
# api_key = self.conf["API_KEY"]
|
27 |
+
client = Ark(ak=ak, sk=sk, base_url=base_url)
|
28 |
+
# client = Ark(ak=api_key, base_url=base_url)
|
29 |
+
return client
|
30 |
+
|
31 |
+
def get_api_key(self):
|
32 |
+
configuration = volcenginesdkcore.Configuration()
|
33 |
+
configuration.ak = self.conf["ACCESS_KEY"]
|
34 |
+
configuration.sk = self.conf["SECRET_KEY"]
|
35 |
+
configuration.region = "cn-beijing"
|
36 |
+
endpoint_id = self.conf["ENDPOINT_ID"]
|
37 |
+
|
38 |
+
volcenginesdkcore.Configuration.set_default(configuration)
|
39 |
+
|
40 |
+
# use global default configuration
|
41 |
+
api_instance = volcenginesdkark.ARKApi()
|
42 |
+
get_api_key_request = volcenginesdkark.GetApiKeyRequest(
|
43 |
+
duration_seconds=30 * 24 * 3600,
|
44 |
+
resource_type="endpoint",
|
45 |
+
resource_ids=[
|
46 |
+
endpoint_id
|
47 |
+
],
|
48 |
+
)
|
49 |
+
|
50 |
+
try:
|
51 |
+
resp = api_instance.get_api_key(get_api_key_request)
|
52 |
+
return resp.api_key
|
53 |
+
except ApiException as e:
|
54 |
+
logger.error(f"Exception when calling api: {e}")
|
55 |
+
|
56 |
+
def set_complete_args(self, temperature=None, top_p=None, max_token=None):
|
57 |
+
if temperature is not None:
|
58 |
+
self._complete_args["temperature"] = temperature
|
59 |
+
if top_p is not None:
|
60 |
+
self._complete_args["top_p"] = top_p
|
61 |
+
if max_token is not None:
|
62 |
+
self._complete_args["max_tokens"] = max_token
|
63 |
+
|
64 |
+
def form_user_role(self, content):
|
65 |
+
return {"role": "user", "content": content}
|
66 |
+
|
67 |
+
def form_sys_role(self, content):
|
68 |
+
return {"role": "system", "content": content}
|
69 |
+
|
70 |
+
def form_assistant_role(self, content):
|
71 |
+
return {"role": "assistant", "content": content}
|
72 |
+
|
73 |
+
@property
|
74 |
+
def complete_args(self):
|
75 |
+
return {"temperature": 0.01, "top_p": 0.7}
|
76 |
+
|
77 |
+
@log_decorate
|
78 |
+
def chat_complete(self, messages):
|
79 |
+
|
80 |
+
endpoint_id = self.conf["ENDPOINT_ID"]
|
81 |
+
completion = self.client.chat.completions.create(
|
82 |
+
model=endpoint_id,
|
83 |
+
messages=messages,
|
84 |
+
**self.complete_args
|
85 |
+
)
|
86 |
+
logger.info(f"complete doubao task, id: {completion.id}")
|
87 |
+
return completion.choices[0].message.content
|
88 |
+
|
89 |
+
def prd_to_keypoint(self, prd_content):
|
90 |
+
|
91 |
+
role_desc = {"role": "system", "content": PRD2KP_SYS}
|
92 |
+
|
93 |
+
messages = [
|
94 |
+
role_desc,
|
95 |
+
{"role": "user", "content": prd_content}
|
96 |
+
]
|
97 |
+
return self.chat_complete(messages)
|
98 |
+
|
99 |
+
def prd_to_cases(self, prd_content, case_language="Chinese"):
|
100 |
+
|
101 |
+
role_desc = {"role": "system", "content": PRD_CASE_SYS[case_language]}
|
102 |
+
|
103 |
+
messages = [
|
104 |
+
role_desc,
|
105 |
+
{"role": "user", "content": prd_content}
|
106 |
+
]
|
107 |
+
return self.chat_complete(messages)
|
108 |
+
|
109 |
+
def keypoint_to_case(self, key_points):
|
110 |
+
|
111 |
+
role_desc = {"role": "system", "content": KP2CASE_SYS}
|
112 |
+
|
113 |
+
messages = [
|
114 |
+
role_desc,
|
115 |
+
{"role": "user", "content": key_points}
|
116 |
+
]
|
117 |
+
return self.chat_complete(messages)
|
118 |
+
|
119 |
+
def case_merge_together(self, case_suits):
|
120 |
+
|
121 |
+
role_desc = {"role": "system", "content": CASE_AGG_SYS}
|
122 |
+
|
123 |
+
content_case_suits = ""
|
124 |
+
for i, case_suit in enumerate(case_suits):
|
125 |
+
case_suit_expr = json.dumps(case_suit, indent=4, ensure_ascii=False)
|
126 |
+
content_case_suits += f"来自初级测试工程师{i + 1}的测试用例:\n```json\n{case_suit_expr}\n```\n"
|
127 |
+
messages = [
|
128 |
+
role_desc,
|
129 |
+
{"role": "user", "content": content_case_suits}
|
130 |
+
]
|
131 |
+
completion = self.chat_complete(messages)
|
132 |
+
return completion
|
133 |
+
|
134 |
+
def cycle_more_case(self, prd_content, case_language="Chinese"):
|
135 |
+
|
136 |
+
role_desc = {"role": "system", "content": PRD_CASE_SYS[case_language]}
|
137 |
+
|
138 |
+
messages = [
|
139 |
+
role_desc,
|
140 |
+
{"role": "user", "content": PRD_CASE_1[case_language] + prd_content + "\n" + PRD_CASE_2[case_language]}
|
141 |
+
]
|
142 |
+
|
143 |
+
result = []
|
144 |
+
|
145 |
+
for sys in MORE_CASE_PROMPT[case_language]:
|
146 |
+
if sys:
|
147 |
+
messages.append({"role": "user", "content": sys})
|
148 |
+
reply = self.chat_complete(messages)
|
149 |
+
result.append(reply)
|
150 |
+
messages.append({"role": "assistant", "content": reply})
|
151 |
+
time.sleep(10)
|
152 |
+
return result
|
153 |
+
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
cli = DouBaoService("DouBao128Pro")
|
157 |
+
# print(cli.get_api_key())
|
158 |
+
# prd_content = requests.get("https://tosv.byted.org/obj/music-qa-bucket/xmind-test/de3ebc67410c43603034e21bfefa76a0.md").text
|
159 |
+
# aa = cli.cycle_more_case(prd_content, "English")
|
160 |
+
# print(aa)
|
161 |
+
|
162 |
+
print(cli.chat_complete(messages=[
|
163 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
164 |
+
{"role": "user", "content": "Introduce LLM shortly."},
|
165 |
+
]))
|
166 |
+
|
interface.py
ADDED
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
from json import loads, JSONDecodeError
|
5 |
+
import sys
|
6 |
+
import os
|
7 |
+
import ast
|
8 |
+
from util.vector_base import EmbeddingFunction, get_or_create_vector_base
|
9 |
+
from doubao_service import DouBaoService
|
10 |
+
from PROMPT_TEMPLATE import prompt_template
|
11 |
+
from util.Embeddings import TextEmb3LargeEmbedding
|
12 |
+
from langchain_core.documents import Document
|
13 |
+
from FlagEmbedding import FlagReranker
|
14 |
+
from retriever import retriever
|
15 |
+
import time
|
16 |
+
from bm25s import BM25, tokenize
|
17 |
+
import contextlib
|
18 |
+
import io
|
19 |
+
|
20 |
+
import gradio as gr
|
21 |
+
import time
|
22 |
+
|
23 |
+
client = DouBaoService("DouBao128Pro")
|
24 |
+
embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
|
25 |
+
embedding = EmbeddingFunction(embeddingmodel)
|
26 |
+
safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)
|
27 |
+
|
28 |
+
# reranker_model = FlagReranker(
|
29 |
+
# 'C://Users//Admin//Desktop//PDPO//NLL_LLM//model//bge-reranker-v2-m3',
|
30 |
+
# use_fp16=True,
|
31 |
+
# devices=["cpu"],
|
32 |
+
# )
|
33 |
+
|
34 |
+
OPTIONS = ['AI Governance',
|
35 |
+
'Data Accuracy',
|
36 |
+
'Data Minimization & Purpose Limitation',
|
37 |
+
'Data Retention',
|
38 |
+
'Data Security',
|
39 |
+
'Data Sharing',
|
40 |
+
'Individual Rights',
|
41 |
+
'Privacy by Design',
|
42 |
+
'Transparency']
|
43 |
+
|
44 |
+
|
45 |
+
def format_model_output(raw_output):
|
46 |
+
"""
|
47 |
+
处理模型输出:
|
48 |
+
- 将 \n 转换为实际换行
|
49 |
+
- 提取 ```json ``` 中的内容并格式化为可折叠的 JSON
|
50 |
+
"""
|
51 |
+
formatted = raw_output.replace('\\n', '\n')
|
52 |
+
def replace_json(match):
|
53 |
+
json_str = match.group(1).strip()
|
54 |
+
try:
|
55 |
+
json_obj = loads(json_str)
|
56 |
+
return f"```json\n{json.dumps(json_obj, indent=2, ensure_ascii=False)}\n```"
|
57 |
+
except JSONDecodeError:
|
58 |
+
return match.group(0)
|
59 |
+
|
60 |
+
formatted = re.sub(r'```json\n?(.*?)\n?```', replace_json, formatted, flags=re.DOTALL)
|
61 |
+
return ast.literal_eval(formatted)
|
62 |
+
|
63 |
+
def model_predict(input_text, if_split_po, topk, selected_items):
|
64 |
+
"""
|
65 |
+
selected_items: 用户选择的项目(可能是["All"]或具体PO)
|
66 |
+
"""
|
67 |
+
requirement = input_text
|
68 |
+
requirement = requirement.replace("\t", "").replace("\n", "").replace("\r", "")
|
69 |
+
if "All" in selected_items:
|
70 |
+
PO = OPTIONS
|
71 |
+
else:
|
72 |
+
PO = selected_items
|
73 |
+
if topk:
|
74 |
+
topk = int(topk)
|
75 |
+
else:
|
76 |
+
topk = 10
|
77 |
+
final_result = retriever(
|
78 |
+
requirement,
|
79 |
+
PO,
|
80 |
+
safeguard_vector_store,
|
81 |
+
reranker_model=None,
|
82 |
+
using_reranker=False,
|
83 |
+
using_BM25=False,
|
84 |
+
using_chroma=True,
|
85 |
+
k=topk,
|
86 |
+
if_split_po=if_split_po
|
87 |
+
)
|
88 |
+
mapping_safeguards = {}
|
89 |
+
for safeguard in final_result:
|
90 |
+
if safeguard[3] not in mapping_safeguards:
|
91 |
+
mapping_safeguards[safeguard[3]] = []
|
92 |
+
mapping_safeguards[safeguard[3]].append(
|
93 |
+
{
|
94 |
+
"Score": safeguard[0],
|
95 |
+
"Safeguard Number": safeguard[1],
|
96 |
+
"Safeguard Description": safeguard[2]
|
97 |
+
}
|
98 |
+
)
|
99 |
+
prompt = prompt_template(requirement, mapping_safeguards)
|
100 |
+
response = client.chat_complete(messages=[
|
101 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
102 |
+
{"role": "user", "content": prompt},
|
103 |
+
])
|
104 |
+
# return {"requirement": requirement, "safeguards": mapping_safeguards}
|
105 |
+
print("requirement:", requirement)
|
106 |
+
print("mapping safeguards:", mapping_safeguards)
|
107 |
+
print("response:", response)
|
108 |
+
return {"requirement": requirement, "safeguards": format_model_output(response)}
|
109 |
+
|
110 |
+
with gr.Blocks(title="New Law Landing") as demo:
|
111 |
+
gr.Markdown("## 🏙️ New Law Landing")
|
112 |
+
|
113 |
+
requirement = gr.Textbox(label="Input Requirements", placeholder="Example: Data Minimization Consent for incompatible purposes")
|
114 |
+
details = gr.Textbox(label="Input Details", placeholder="Example: Require consent for...")
|
115 |
+
|
116 |
+
# 修改为 Number 输入组件
|
117 |
+
topk = gr.Number(
|
118 |
+
label="Top K safeguards",
|
119 |
+
value=10,
|
120 |
+
precision=0,
|
121 |
+
minimum=1,
|
122 |
+
interactive=True
|
123 |
+
)
|
124 |
+
|
125 |
+
with gr.Row():
|
126 |
+
with gr.Column(scale=1):
|
127 |
+
if_split_po = gr.Checkbox(
|
128 |
+
label="If Split Privacy Objective",
|
129 |
+
value=True,
|
130 |
+
info="Recall K Safeguards for each Privacy Objective"
|
131 |
+
)
|
132 |
+
with gr.Column(scale=1):
|
133 |
+
all_checkbox = gr.Checkbox(
|
134 |
+
label="ALL Privacy Objective",
|
135 |
+
value=True,
|
136 |
+
info="No specific Privacy Objective is specified"
|
137 |
+
)
|
138 |
+
with gr.Column(scale=4):
|
139 |
+
PO_checklist = gr.CheckboxGroup(
|
140 |
+
label="Choose Privacy Objective",
|
141 |
+
choices=OPTIONS,
|
142 |
+
value=[],
|
143 |
+
interactive=True
|
144 |
+
)
|
145 |
+
|
146 |
+
submit_btn = gr.Button("Submit", variant="primary")
|
147 |
+
result_output = gr.JSON(label="Related safeguards", open=True)
|
148 |
+
|
149 |
+
|
150 |
+
def sync_checkboxes(selected_items, all_selected):
|
151 |
+
if len(selected_items) > 0:
|
152 |
+
return False
|
153 |
+
return all_selected
|
154 |
+
|
155 |
+
PO_checklist.change(
|
156 |
+
fn=sync_checkboxes,
|
157 |
+
inputs=[PO_checklist, all_checkbox],
|
158 |
+
outputs=all_checkbox
|
159 |
+
)
|
160 |
+
|
161 |
+
def sync_all(selected_all, current_selection):
|
162 |
+
if selected_all:
|
163 |
+
return []
|
164 |
+
return current_selection
|
165 |
+
|
166 |
+
all_checkbox.change(
|
167 |
+
fn=sync_all,
|
168 |
+
inputs=[all_checkbox, PO_checklist],
|
169 |
+
outputs=PO_checklist
|
170 |
+
)
|
171 |
+
|
172 |
+
def process_inputs(requirement, details, topk, if_split_po, all_selected, PO_selected):
|
173 |
+
input_text = requirement + ": " + details
|
174 |
+
if all_selected:
|
175 |
+
return model_predict(input_text, if_split_po, int(topk), ["All"])
|
176 |
+
else:
|
177 |
+
return model_predict(input_text, if_split_po, int(topk), PO_selected)
|
178 |
+
|
179 |
+
submit_btn.click(
|
180 |
+
fn=process_inputs,
|
181 |
+
inputs=[requirement, details, topk, if_split_po, all_checkbox, PO_checklist],
|
182 |
+
outputs=[result_output]
|
183 |
+
)
|
184 |
+
|
185 |
+
if __name__ == "__main__":
|
186 |
+
demo.launch(share=True)
|
retriever.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
from collections import defaultdict
|
6 |
+
from util.vector_base import EmbeddingFunction, get_or_create_vector_base
|
7 |
+
from util.Embeddings import TextEmb3LargeEmbedding
|
8 |
+
from langchain_core.documents import Document
|
9 |
+
from FlagEmbedding import FlagReranker
|
10 |
+
import time
|
11 |
+
from bm25s import BM25, tokenize
|
12 |
+
import contextlib
|
13 |
+
import io
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
def rrf(rankings, k = 60):
|
17 |
+
res = 0
|
18 |
+
for r in rankings:
|
19 |
+
res += 1 / (r + k)
|
20 |
+
return res
|
21 |
+
|
22 |
+
def retriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=False, using_BM25=False, using_chroma=True, k=20, if_split_po=True):
|
23 |
+
final_result = []
|
24 |
+
if not if_split_po:
|
25 |
+
final_result = multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
|
26 |
+
else:
|
27 |
+
for po in PO:
|
28 |
+
po_result = multiretriever(requirement, [po], safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
|
29 |
+
for safeguard in po_result:
|
30 |
+
final_result.append(safeguard)
|
31 |
+
return final_result
|
32 |
+
|
33 |
+
def multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=True, using_BM25=False, using_chroma=True, k=20):
|
34 |
+
"""
|
35 |
+
requirements_dict: [
|
36 |
+
requirement: {
|
37 |
+
"PO": [],
|
38 |
+
"safeguard": []
|
39 |
+
}
|
40 |
+
]
|
41 |
+
"""
|
42 |
+
candidate_safeguards = []
|
43 |
+
po_list = [po.lower().rstrip() for po in PO if po]
|
44 |
+
if "young users" in po_list and len(po_list) == 1:
|
45 |
+
return []
|
46 |
+
candidate_safeguards = safeguard_vector_store.get(where={"po": {"$in": po_list}})
|
47 |
+
safeguard_dict, safeguard_content = {}, []
|
48 |
+
for id, content, metadata in zip(candidate_safeguards['ids'], candidate_safeguards['documents'], candidate_safeguards['metadatas']):
|
49 |
+
safeguard_dict[content] = {
|
50 |
+
"metadata": metadata,
|
51 |
+
"rank": [],
|
52 |
+
"rrf_score": 0
|
53 |
+
}
|
54 |
+
safeguard_content.append(content)
|
55 |
+
|
56 |
+
# Reranker
|
57 |
+
if using_reranker:
|
58 |
+
content_pairs, reranking_rank, reranking_results = [], [], []
|
59 |
+
for safeguard in safeguard_content:
|
60 |
+
content_pairs.append([requirement, safeguard])
|
61 |
+
safeguard_rerank_scores = reranker_model.compute_score(content_pairs)
|
62 |
+
for content_pair, score in zip(content_pairs, safeguard_rerank_scores):
|
63 |
+
reranking_rank.append((content_pair[1], score))
|
64 |
+
reranking_results = sorted(reranking_rank, key=lambda x: x[1], reverse=True)
|
65 |
+
for safeguard, score in reranking_results:
|
66 |
+
safeguard_dict[safeguard]['rank'].append(reranking_results.index((safeguard, score)) + 1)
|
67 |
+
|
68 |
+
# BM25
|
69 |
+
if using_BM25:
|
70 |
+
with contextlib.redirect_stdout(io.StringIO()):
|
71 |
+
bm25_retriever = BM25(corpus=safeguard_content)
|
72 |
+
bm25_retriever.index(tokenize(safeguard_content))
|
73 |
+
bm25_results, scores = bm25_retriever.retrieve(tokenize(requirement), k = len(safeguard_content))
|
74 |
+
bm25_retrieval_rank = 1
|
75 |
+
for safeguard in bm25_results[0]:
|
76 |
+
safeguard_dict[safeguard]['rank'].append(bm25_retrieval_rank)
|
77 |
+
bm25_retrieval_rank += 1
|
78 |
+
|
79 |
+
# chroma retrieval
|
80 |
+
if using_chroma:
|
81 |
+
retrieved_safeguards = safeguard_vector_store.similarity_search_with_score(query=requirement, k=len(candidate_safeguards['ids']), filter={"po": {"$in": po_list}})
|
82 |
+
retrieval_rank = 1
|
83 |
+
for safeguard in retrieved_safeguards:
|
84 |
+
safeguard_dict[safeguard[0].page_content]['rank'].append(retrieval_rank)
|
85 |
+
retrieval_rank += 1
|
86 |
+
|
87 |
+
final_result = []
|
88 |
+
for safeguard in safeguard_content:
|
89 |
+
safeguard_dict[safeguard]['rrf_score'] = rrf(safeguard_dict[safeguard]['rank'])
|
90 |
+
final_result.append((safeguard_dict[safeguard]['rrf_score'], safeguard_dict[safeguard]['metadata']['safeguard_number'], safeguard, safeguard_dict[safeguard]['metadata']['po']))
|
91 |
+
final_result.sort(key=lambda x: x[0], reverse=True)
|
92 |
+
|
93 |
+
# top k
|
94 |
+
topk_final_result = final_result[:k]
|
95 |
+
|
96 |
+
return topk_final_result
|
97 |
+
|
98 |
+
if __name__=="__main__":
|
99 |
+
embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
|
100 |
+
embedding = EmbeddingFunction(embeddingmodel)
|
101 |
+
safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)
|
102 |
+
reranker_model = FlagReranker(
|
103 |
+
'/root/PTR-LLM/tasks/pcf/model/bge-reranker-v2-m3',
|
104 |
+
use_fp16=True,
|
105 |
+
devices=["cpu"],
|
106 |
+
)
|
107 |
+
requirement = """
|
108 |
+
Data Minimization Consent for incompatible purposes: Require consent for additional use of personal information not reasonably necessary to or incompatible with original purpose disclosure.
|
109 |
+
"""
|
110 |
+
PO = ["Data Minimization & Purpose Limitation", "Transparency"]
|
111 |
+
final_result = retriever(
|
112 |
+
requirement,
|
113 |
+
PO,
|
114 |
+
safeguard_vector_store,
|
115 |
+
reranker_model,
|
116 |
+
using_reranker=True,
|
117 |
+
using_BM25=False,
|
118 |
+
using_chroma=True,
|
119 |
+
k=10
|
120 |
+
)
|
121 |
+
print(final_result)
|
store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5b8d4b3825a7c7a773e22fa3eeef0e7d15a695f5c4183aeff5beb07741a68679
|
3 |
+
size 12428000
|
store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/header.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e8a3ec48846fc6fdfaef19f5ed2508f0bf3da4a3c93b0f6b3dd21f0a22ec1026
|
3 |
+
size 100
|
store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/length.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fc19b1997119425765295aeab72d76faa6927d4f83985d328c26f20468d6cc76
|
3 |
+
size 4000
|
store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/link_lists.bin
ADDED
File without changes
|
store/requirement_full_database/chroma.sqlite3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:debbda97fa444d1beed59205da5310caa07fc5e1fea98ee5d217bb1cd86b3312
|
3 |
+
size 2031616
|
store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5b8d4b3825a7c7a773e22fa3eeef0e7d15a695f5c4183aeff5beb07741a68679
|
3 |
+
size 12428000
|
store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/header.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e8a3ec48846fc6fdfaef19f5ed2508f0bf3da4a3c93b0f6b3dd21f0a22ec1026
|
3 |
+
size 100
|
store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/length.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e0c0e05e944e59611aa54c4b0b708f835a8a0daf4baec11208e3f60773b22d89
|
3 |
+
size 4000
|
store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/link_lists.bin
ADDED
File without changes
|
store/requirement_v1_database/chroma.sqlite3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f177f30d3354a5cf2c14378bfe1b73755684d9c5fc38046f9e4339a1180af0a2
|
3 |
+
size 1212416
|
store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5b8d4b3825a7c7a773e22fa3eeef0e7d15a695f5c4183aeff5beb07741a68679
|
3 |
+
size 12428000
|
store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/header.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e8a3ec48846fc6fdfaef19f5ed2508f0bf3da4a3c93b0f6b3dd21f0a22ec1026
|
3 |
+
size 100
|
store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/length.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9a911716b3f8450b156db9e04a3c81548395cabac0c846dd1e1eef832991c120
|
3 |
+
size 4000
|
store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/link_lists.bin
ADDED
File without changes
|
store/requirement_v2_database/chroma.sqlite3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b6d2f15577b5bcd6c8ce845d894cc17196e71c71c3909eec73f9bfedee400c90
|
3 |
+
size 991232
|
store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/data_level0.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5b8d4b3825a7c7a773e22fa3eeef0e7d15a695f5c4183aeff5beb07741a68679
|
3 |
+
size 12428000
|
store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/header.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e8a3ec48846fc6fdfaef19f5ed2508f0bf3da4a3c93b0f6b3dd21f0a22ec1026
|
3 |
+
size 100
|
store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/length.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6c710099c23f51c095ec65b43539cd7534bb745aebf55fef0a9c06229121caca
|
3 |
+
size 4000
|
store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/link_lists.bin
ADDED
File without changes
|
store/safeguard_database/chroma.sqlite3
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0730981819c475f93d173a347f891b5e28e02cce24a63a578ca43c98a13d9c10
|
3 |
+
size 5095424
|
test.ipynb
ADDED
File without changes
|
util/Embeddings.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from copy import copy
|
3 |
+
from typing import Dict, List, Optional, Tuple, Union
|
4 |
+
import numpy as np
|
5 |
+
import time
|
6 |
+
from functools import wraps
|
7 |
+
|
8 |
+
os.environ['CURL_CA_BUNDLE'] = ''
|
9 |
+
from dotenv import load_dotenv, find_dotenv
|
10 |
+
_ = load_dotenv(find_dotenv())
|
11 |
+
|
12 |
+
|
13 |
+
class BaseEmbeddings:
    """Common interface for all embedding backends in this module."""

    def __init__(self, path: str, is_api: bool) -> None:
        # `path` is a local/HF model identifier for offline backends;
        # API-backed subclasses leave it empty.
        self.path = path
        self.is_api = is_api

    def get_embedding(self, text: str, model: str) -> List[float]:
        """Return the embedding vector for `text`; implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
        """Cosine similarity of two vectors; 0 when either has zero norm."""
        numerator = np.dot(vector1, vector2)
        denominator = np.linalg.norm(vector1) * np.linalg.norm(vector2)
        return numerator / denominator if denominator else 0
|
34 |
+
|
35 |
+
|
36 |
+
class OpenAIEmbedding(BaseEmbeddings):
    """
    Embeddings backed by the OpenAI API (`text-embedding-3-large` default).
    """

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            from openai import OpenAI
            # Pass credentials to the constructor: OpenAI() validates the
            # key at construction time, so assigning `client.api_key`
            # afterwards (as the old code did) happens too late.
            self.client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_BASE_URL"),
            )

    def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
        """Embed `text` with the given model. API mode only.

        Raises:
            NotImplementedError: when constructed with is_api=False.
        """
        if not self.is_api:
            raise NotImplementedError
        # Newlines can degrade embedding quality per OpenAI's guidance.
        text = text.replace("\n", " ")
        return self.client.embeddings.create(input=[text], model=model).data[0].embedding
|
54 |
+
|
55 |
+
class JinaEmbedding(BaseEmbeddings):
    """Local Jina embeddings served via HuggingFace transformers."""

    def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
        super().__init__(path, is_api)
        self._model = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        """Encode a single text and return its vector as a plain list."""
        vectors = self._model.encode([text])
        return vectors[0].tolist()

    def load_model(self):
        """Load the model onto GPU when available, otherwise CPU."""
        import torch
        from transformers import AutoModel
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # trust_remote_code is required: the Jina model ships custom modeling code.
        return AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
|
75 |
+
|
76 |
+
class ZhipuEmbedding(BaseEmbeddings):
    """Embeddings backed by the ZhipuAI `embedding-2` API."""

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            from zhipuai import ZhipuAI
            # Key comes from the environment (loaded via dotenv at module import).
            self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))

    def get_embedding(self, text: str) -> List[float]:
        """Embed one text via the remote API and return the raw vector."""
        response = self.client.embeddings.create(model="embedding-2", input=text)
        return response.data[0].embedding
|
92 |
+
|
93 |
+
|
94 |
+
class DashscopeEmbedding(BaseEmbeddings):
    """Embeddings backed by Alibaba Dashscope's TextEmbedding API."""

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            import dashscope
            dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
            # dashscope exposes a module-level call interface rather than a client object.
            self.client = dashscope.TextEmbedding

    def get_embedding(self, text: str, model: str = 'text-embedding-v1') -> List[float]:
        """Embed one text and return the first embedding from the response."""
        resp = self.client.call(model=model, input=text)
        return resp.output['embeddings'][0]['embedding']
|
111 |
+
|
112 |
+
|
113 |
+
class BgeEmbedding(BaseEmbeddings):
    """Local BGE embeddings: CLS-token pooling followed by L2 normalisation."""

    def __init__(self, path: str = 'BAAI/bge-en-icl', is_api: bool = False) -> None:
        super().__init__(path, is_api)
        self._model, self._tokenizer = self.load_model(path)

    def get_embedding(self, text: str) -> List[float]:
        """Tokenise, run the encoder, take the [CLS] vector and normalise it."""
        import torch
        batch = self._tokenizer([text], padding=True, truncation=True, return_tensors='pt')
        batch = {name: tensor.to(self._model.device) for name, tensor in batch.items()}
        with torch.no_grad():
            output = self._model(**batch)
        # CLS pooling: first token of the last hidden state.
        cls_vec = output[0][:, 0]
        cls_vec = torch.nn.functional.normalize(cls_vec, p=2, dim=1)
        return cls_vec[0].tolist()

    def load_model(self, path: str):
        """Return (model, tokenizer) with the model in eval mode on GPU/CPU."""
        import torch
        from transformers import AutoModel, AutoTokenizer
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModel.from_pretrained(path).to(device)
        model.eval()
        return model, tokenizer
|
143 |
+
|
144 |
+
|
145 |
+
def rate_limiter():
    """
    Decorator factory throttling a method to `self.max_qpm` calls per minute.

    The decorated method's owner must expose `max_qpm` and `silent`
    attributes; the timestamp of the last completed call is kept on the
    instance as `_last_called`.
    """
    def rate_limiter_decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            min_interval = 60 / self.max_qpm
            now = time.time()
            # First call on this instance: pretend the last call was at epoch 0,
            # which always yields a huge elapsed time and therefore no wait.
            if not hasattr(self, '_last_called'):
                self._last_called = 0
            since_last = now - self._last_called
            if since_last < min_interval:
                time_to_wait = min_interval - since_last
                if self.silent is False:
                    print(f"## Rate limit reached. Waiting for {time_to_wait:.2f} seconds.")
                time.sleep(time_to_wait)
            result = func(self, *args, **kwargs)
            # Stamp AFTER the call completes (mirrors the original ordering).
            self._last_called = time.time()
            return result
        return wrapper

    return rate_limiter_decorator
|
168 |
+
|
169 |
+
|
170 |
+
class TextEmb3LargeEmbedding(BaseEmbeddings):
    """
    `text-embedding-3-large` served through an Azure-OpenAI-compatible
    gateway, throttled to `max_qpm` requests per minute.

    Args:
        max_qpm: maximum queries per minute allowed against the gateway.
        is_silent: when True, suppress the rate-limiter's wait messages.
    """

    def __init__(self, max_qpm, is_silent=False):
        from langchain_openai import AzureOpenAIEmbeddings

        ## https://gpt.bytedance.net/gpt_openapi/
        base_url = "https://search-va.byteintl.net/gpt/openapi/online/v2/crawl"
        api_version = "2024-03-01-preview"
        # SECURITY: the API key used to be hard-coded here. Prefer the
        # GPT_OPENAPI_AK environment variable; the old literal is kept only
        # as a backward-compatible fallback — rotate the key and remove it.
        ak = os.getenv("GPT_OPENAPI_AK", "5dXdIKxZc8JWVVgvX0DN92HWIYb9NfEb_GPT_AK")
        model_name = "text-embedding-3-large"
        api_type = "azure"
        self.llm = AzureOpenAIEmbeddings(
            azure_endpoint=base_url,
            openai_api_version=api_version,
            deployment=model_name,
            openai_api_key=ak,
            openai_api_type=api_type,
        )
        # Attributes consumed by the rate_limiter decorator below.
        self.max_qpm = max_qpm
        self.silent = is_silent

    @rate_limiter()
    def get_embedding(self, text: str):
        """Embed `text`; blocks as needed to respect the QPM budget."""
        return self.llm.embed_query(text)
|
util/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# Package initialisation: configure logging once at import time and expose
# a shared module-level `logger` for the rest of the package.
from .logger_util import get_logger

logger = get_logger()
|
util/__pycache__/Embeddings.cpython-311.pyc
ADDED
Binary file (13.7 kB). View file
|
|
util/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (258 Bytes). View file
|
|
util/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (205 Bytes). View file
|
|
util/__pycache__/config_util.cpython-311.pyc
ADDED
Binary file (1.35 kB). View file
|
|
util/__pycache__/logger_util.cpython-311.pyc
ADDED
Binary file (1.85 kB). View file
|
|
util/__pycache__/logger_util.cpython-39.pyc
ADDED
Binary file (1.01 kB). View file
|
|
util/__pycache__/vector_base.cpython-311.pyc
ADDED
Binary file (5.76 kB). View file
|
|
util/__pycache__/vector_base.cpython-39.pyc
ADDED
Binary file (3.39 kB). View file
|
|
util/config_util.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import bytedenv
|
5 |
+
import configparser
|
6 |
+
|
7 |
+
# Project root: two directory levels above this file (util/ -> project root).
ROOT_PATH = os.path.split(os.path.split(os.path.abspath(__file__))[0])[0]


def read_config():
    """
    Load `conf/config.ini` from the project root.

    Returns:
        configparser.ConfigParser: the parsed configuration (empty when the
        file is missing — ConfigParser.read silently skips unreadable paths).
    """
    # os.path.join is portable; the old `ROOT_PATH + "\conf\config.ini"`
    # literal only worked on Windows and relied on invalid escape sequences.
    config_file = os.path.join(ROOT_PATH, "conf", "config.ini")
    config_ini = configparser.ConfigParser()
    config_ini.read(config_file)
    return config_ini
|
16 |
+
|
17 |
+
def read_json(filepath):
    """
    Parse a JSON file and return the decoded object.

    Args:
        filepath: path to a JSON document.

    Returns:
        The deserialised Python object (dict, list, etc.).
    """
    # Explicit UTF-8: JSON files are near-universally UTF-8 and the
    # platform default encoding (e.g. a legacy Windows codepage) breaks them.
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)
|
21 |
+
|
util/logger_util.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import logging.config
|
4 |
+
import traceback
|
5 |
+
from functools import wraps
|
6 |
+
|
7 |
+
|
8 |
+
def get_logger():
    """Configure logging from conf/logs.ini and return the 'Robot' logger."""
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    ini_path = os.path.join(project_root, "conf", "logs.ini")
    logging.config.fileConfig(ini_path)
    return logging.getLogger("Robot")
|
13 |
+
|
14 |
+
|
15 |
+
def log_decorate(func):
    """
    Swallow-and-log decorator: exceptions raised by `func` are logged with a
    full traceback and the wrapper returns None instead of re-raising.
    """
    @wraps(func)
    def log(*args, **kwargs):
        # Logger is built on every call, matching the original behaviour
        # (get_logger re-applies the file config each time).
        logger = get_logger()
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logger.error(f"{func.__name__} is error, logId: {e.args}, errMsg is: {traceback.format_exc()}")

    return log
|
util/vector_base.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from langchain_chroma import Chroma
|
3 |
+
from langchain_core.documents import Document
|
4 |
+
sys.path.append('C://Users//Admin//Desktop//PDPO//NLL_LLM//util')
|
5 |
+
from Embeddings import TextEmb3LargeEmbedding
|
6 |
+
from pathlib import Path
|
7 |
+
import time
|
8 |
+
|
9 |
+
class EmbeddingFunction():
    """Adapter exposing the embed_query/embed_documents interface expected
    by langchain's Chroma wrapper, backed by one of our embedding models."""

    def __init__(self, embeddingmodel):
        self.embeddingmodel = embeddingmodel

    def embed_query(self, query):
        """Embed a single query string as a list of floats."""
        return list(self.embeddingmodel.get_embedding(query))

    def embed_documents(self, documents):
        """Embed every document, preserving input order."""
        return [self.embeddingmodel.get_embedding(doc) for doc in documents]
|
16 |
+
|
17 |
+
def get_or_create_vector_base(collection_name: str, embedding, documents=None,
                              store_root: str = "C://Users//Admin//Desktop//PDPO//NLL_LLM//store//") -> Chroma:
    """
    Open the persisted Chroma collection, creating and populating it when it
    does not exist yet.

    Documents are inserted one at a time with a 1-second pause — instead of a
    single embed_documents batch — to stay under the OpenAI API rate limit
    during initialisation.

    Args:
        collection_name: collection name; also the directory name under
            `store_root`.
        embedding: object providing embed_query/embed_documents.
        documents: langchain Documents to insert when the store is created.
        store_root: base directory holding all persisted collections
            (defaults to the historical hard-coded path for compatibility).

    Returns:
        Chroma: the opened (and possibly freshly populated) vector store.

    Raises:
        ValueError: when the store does not exist and no documents were given.
    """
    persist_directory = store_root + collection_name
    persist_path = Path(persist_directory)
    # BUG FIX: the original tested `persist_path.exists` — the bound method
    # object, which is always truthy — so this guard could never trigger.
    if not persist_path.exists() and not documents:
        raise ValueError("vector store does not exist and documents is empty")
    if persist_path.exists():
        print("vector store already exists")
        return Chroma(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory,
        )
    print("start creating vector store")
    vector_store = Chroma(
        collection_name=collection_name,
        embedding_function=embedding,
        persist_directory=persist_directory,
    )
    for document in documents:
        vector_store.add_documents(documents=[document])
        time.sleep(1)  # throttle so embedding-API rate limits are not hit
    return vector_store
|
45 |
+
|
46 |
+
if __name__ == "__main__":
    import pandas as pd

    # Aggregate the CSV rows into one entry per unique requirement text,
    # collecting the privacy objectives (PO) and safeguards attached to it.
    requirements_data = pd.read_csv("/root/PTR-LLM/tasks/pcf/reference/NLL_DATA_NEW_Test.csv")
    requirements_dict_v2 = {}
    for _, row in requirements_data.iterrows():
        requirement = row['Requirement'].split("- ")[1]
        requirement = requirement + ": " + row['Details']
        requirement = requirement.replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
        entry = requirements_dict_v2.setdefault(requirement, {'PO': set(), 'safeguard': set()})
        po_cell = row['PCF-Privacy Objective']
        # Non-string PO cells (NaN) are recorded as None and filtered out below.
        entry['PO'].add(po_cell.lower().rstrip() if isinstance(po_cell, str) else None)
        entry['safeguard'].add(row['Safeguard'].lower().rstrip())

    # One Document per requirement; metadata values must be scalars, hence str().
    documents = []
    for doc_index, (req_text, attrs) in enumerate(requirements_dict_v2.items()):
        documents.append(Document(
            page_content=req_text,
            metadata={
                "index": doc_index,
                "version": 2,
                "PO": str([po for po in attrs['PO'] if po]),
                "safeguard": str([safeguard for safeguard in attrs['safeguard']]),
            },
        ))

    embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
    embedding = EmbeddingFunction(embeddingmodel)
    requirement_v2_vector_store = get_or_create_vector_base('requirement_v2_database', embedding, documents)
|