bytedancerneat committed on
Commit
929938f
·
verified ·
1 Parent(s): aff53fd

Upload folder using huggingface_hub

Browse files
Files changed (46) hide show
  1. .gitattributes +4 -0
  2. .gradio/certificate.pem +31 -0
  3. PROMPT_TEMPLATE.py +29 -0
  4. README.md +3 -9
  5. __pycache__/PROMPT_TEMPLATE.cpython-311.pyc +0 -0
  6. __pycache__/doubao_service.cpython-311.pyc +0 -0
  7. __pycache__/retriever.cpython-311.pyc +0 -0
  8. conf/config.ini +22 -0
  9. conf/logs.ini +28 -0
  10. doubao_service.py +166 -0
  11. interface.py +186 -0
  12. retriever.py +121 -0
  13. store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/data_level0.bin +3 -0
  14. store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/header.bin +3 -0
  15. store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/length.bin +3 -0
  16. store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/link_lists.bin +0 -0
  17. store/requirement_full_database/chroma.sqlite3 +3 -0
  18. store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/data_level0.bin +3 -0
  19. store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/header.bin +3 -0
  20. store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/length.bin +3 -0
  21. store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/link_lists.bin +0 -0
  22. store/requirement_v1_database/chroma.sqlite3 +3 -0
  23. store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/data_level0.bin +3 -0
  24. store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/header.bin +3 -0
  25. store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/length.bin +3 -0
  26. store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/link_lists.bin +0 -0
  27. store/requirement_v2_database/chroma.sqlite3 +3 -0
  28. store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/data_level0.bin +3 -0
  29. store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/header.bin +3 -0
  30. store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/length.bin +3 -0
  31. store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/link_lists.bin +0 -0
  32. store/safeguard_database/chroma.sqlite3 +3 -0
  33. test.ipynb +0 -0
  34. util/Embeddings.py +195 -0
  35. util/__init__.py +3 -0
  36. util/__pycache__/Embeddings.cpython-311.pyc +0 -0
  37. util/__pycache__/__init__.cpython-311.pyc +0 -0
  38. util/__pycache__/__init__.cpython-39.pyc +0 -0
  39. util/__pycache__/config_util.cpython-311.pyc +0 -0
  40. util/__pycache__/logger_util.cpython-311.pyc +0 -0
  41. util/__pycache__/logger_util.cpython-39.pyc +0 -0
  42. util/__pycache__/vector_base.cpython-311.pyc +0 -0
  43. util/__pycache__/vector_base.cpython-39.pyc +0 -0
  44. util/config_util.py +21 -0
  45. util/logger_util.py +23 -0
  46. util/vector_base.py +79 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ store/requirement_full_database/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
37
+ store/requirement_v1_database/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
38
+ store/requirement_v2_database/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
39
+ store/safeguard_database/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
PROMPT_TEMPLATE.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ former_string = """# Role
2
+ ## You are an expert in the field of law, and you are good at explaining why law requirements are related to their matching safeguards.
3
+
4
+ # Task
5
+ You need to analyze **requirement** and **privacy objective dict** provided by the user, each key in the **privacy objective dict** is a specific **privacy objective** and has corresponding **safeguards list**, you need to explain why each **safeguard** is related to the **requirement**.
6
+
7
+ # Output format
8
+ For each **safeguard** in the **safeguards list**, explain its association with the requirement in the following format:
9
+ {
10
+ "privacy objective":
11
+ [
12
+ {
13
+ "safeguard number": "xxx",
14
+ "safeguard description": "xxx",
15
+ "analysis": "xxx"
16
+ },
17
+ ...
18
+ ]
19
+ }
20
+ Please return your answers in JSON format."""
21
+ input_format = """
22
+ # Input
23
+ Requirement:
24
+ {requirement}
25
+ Safeguards list:
26
+ {safeguards}
27
+ """
28
+ def prompt_template(requirement, safeguards):
29
+ return former_string + input_format.format(requirement=requirement, safeguards=safeguards)
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: NLL Interface
3
- emoji: 🏆
4
- colorFrom: indigo
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 5.23.3
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: NLL_Interface
3
+ app_file: interface.py
 
 
4
  sdk: gradio
5
+ sdk_version: 5.23.2
 
 
6
  ---
 
 
__pycache__/PROMPT_TEMPLATE.cpython-311.pyc ADDED
Binary file (1.39 kB). View file
 
__pycache__/doubao_service.cpython-311.pyc ADDED
Binary file (7.68 kB). View file
 
__pycache__/retriever.cpython-311.pyc ADDED
Binary file (6.95 kB). View file
 
conf/config.ini ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ; SECURITY(review): live access keys, secret keys and API keys are committed
+ ; in this file — rotate these credentials and load them from environment
+ ; variables or a secret manager instead of the repository.
+ [DouBao128ProModelInfo]
2
+ ACCESS_KEY = AKLTYjI0OWNiMGVmZGEwNDNhYjk3YzJhNDdlYTI1NTA5M2M
3
+ SECRET_KEY = TldReVlUaGlNVE0wT0dVeE5ESTFOV0l3T1RKa1lXSm1aak0zTXpJeE5qVQ==
4
+ BASE_URL = https://ark.cn-beijing.volces.com/api/v3
5
+ API_KEY = 0c654012-8989-455f-8a5d-032fc067fbc8
6
+ ENDPOINT_ID = ep-20241223113321-g47rr
7
+ CYCLE_TIMES = 3
8
+ MERGER_RETRY_TIMES = 4
9
+ MAX_RETRY_TIMES = 4
10
+ MAX_THREAD_NUM = 2
11
+
12
+
13
+ [DouBaoPreviewModelInfo]
14
+ ACCESS_KEY = AKLTYjI0OWNiMGVmZGEwNDNhYjk3YzJhNDdlYTI1NTA5M2M
15
+ SECRET_KEY = TldReVlUaGlNVE0wT0dVeE5ESTFOV0l3T1RKa1lXSm1aak0zTXpJeE5qVQ==
16
+ BASE_URL = https://ark.cn-beijing.volces.com/api/v3
17
+ API_KEY = 0c654012-8989-455f-8a5d-032fc067fbc8
18
+ ENDPOINT_ID = ep-20240923111539-mbwqc
19
+ CYCLE_TIMES = 3
20
+ MERGER_RETRY_TIMES = 4
21
+ MAX_RETRY_TIMES = 4
22
+ MAX_THREAD_NUM = 2
conf/logs.ini ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [loggers]
2
+ keys = root,Robot
3
+
4
+ [handlers]
5
+ keys = consoleHandler
6
+
7
+ [formatters]
8
+ keys = simpleFormatter
9
+
10
+ [logger_root]
11
+ level = INFO
12
+ handlers= consoleHandler
13
+
14
+ [logger_Robot]
15
+ level= INFO
16
+ handlers = consoleHandler
17
+ qualname = Robot
18
+ propagate=0
19
+
20
+ [handler_consoleHandler]
21
+ class = StreamHandler
22
+ level = INFO
23
+ formatter = simpleFormatter
24
+
25
+
26
+ [formatter_simpleFormatter]
27
+ format = %(asctime)s %(levelname)s %(filename)s %(lineno)d %(message)s
28
+ datefmt = %Y-%m-%d %H:%M:%S
doubao_service.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import requests
3
+ import json
4
+ from volcenginesdkarkruntime import Ark
5
+ from util.config_util import read_config as config
6
+ from util import logger
7
+ import volcenginesdkcore
8
+ import volcenginesdkark
9
+ from volcenginesdkcore.rest import ApiException
10
+ from util.logger_util import log_decorate
11
+
12
+
13
+ class DouBaoService:
14
+
15
+ def __init__(self, model_name):
16
+ self.conf = config()[f"{model_name}ModelInfo"]
17
+ self.client = self.init_client()
18
+ self._complete_args = {}
19
+
20
+
21
+
22
+ def init_client(self):
23
+ base_url = self.conf["BASE_URL"]
24
+ ak = self.conf["ACCESS_KEY"]
25
+ sk = self.conf["SECRET_KEY"]
26
+ # api_key = self.conf["API_KEY"]
27
+ client = Ark(ak=ak, sk=sk, base_url=base_url)
28
+ # client = Ark(ak=api_key, base_url=base_url)
29
+ return client
30
+
31
+ def get_api_key(self):
32
+ configuration = volcenginesdkcore.Configuration()
33
+ configuration.ak = self.conf["ACCESS_KEY"]
34
+ configuration.sk = self.conf["SECRET_KEY"]
35
+ configuration.region = "cn-beijing"
36
+ endpoint_id = self.conf["ENDPOINT_ID"]
37
+
38
+ volcenginesdkcore.Configuration.set_default(configuration)
39
+
40
+ # use global default configuration
41
+ api_instance = volcenginesdkark.ARKApi()
42
+ get_api_key_request = volcenginesdkark.GetApiKeyRequest(
43
+ duration_seconds=30 * 24 * 3600,
44
+ resource_type="endpoint",
45
+ resource_ids=[
46
+ endpoint_id
47
+ ],
48
+ )
49
+
50
+ try:
51
+ resp = api_instance.get_api_key(get_api_key_request)
52
+ return resp.api_key
53
+ except ApiException as e:
54
+ logger.error(f"Exception when calling api: {e}")
55
+
56
+ def set_complete_args(self, temperature=None, top_p=None, max_token=None):
57
+ if temperature is not None:
58
+ self._complete_args["temperature"] = temperature
59
+ if top_p is not None:
60
+ self._complete_args["top_p"] = top_p
61
+ if max_token is not None:
62
+ self._complete_args["max_tokens"] = max_token
63
+
64
+ def form_user_role(self, content):
65
+ return {"role": "user", "content": content}
66
+
67
+ def form_sys_role(self, content):
68
+ return {"role": "system", "content": content}
69
+
70
+ def form_assistant_role(self, content):
71
+ return {"role": "assistant", "content": content}
72
+
73
+ @property
74
+ def complete_args(self):
75
+ return {"temperature": 0.01, "top_p": 0.7}
76
+
77
+ @log_decorate
78
+ def chat_complete(self, messages):
79
+
80
+ endpoint_id = self.conf["ENDPOINT_ID"]
81
+ completion = self.client.chat.completions.create(
82
+ model=endpoint_id,
83
+ messages=messages,
84
+ **self.complete_args
85
+ )
86
+ logger.info(f"complete doubao task, id: {completion.id}")
87
+ return completion.choices[0].message.content
88
+
89
+ def prd_to_keypoint(self, prd_content):
90
+
91
+ role_desc = {"role": "system", "content": PRD2KP_SYS}
92
+
93
+ messages = [
94
+ role_desc,
95
+ {"role": "user", "content": prd_content}
96
+ ]
97
+ return self.chat_complete(messages)
98
+
99
+ def prd_to_cases(self, prd_content, case_language="Chinese"):
100
+
101
+ role_desc = {"role": "system", "content": PRD_CASE_SYS[case_language]}
102
+
103
+ messages = [
104
+ role_desc,
105
+ {"role": "user", "content": prd_content}
106
+ ]
107
+ return self.chat_complete(messages)
108
+
109
+ def keypoint_to_case(self, key_points):
110
+
111
+ role_desc = {"role": "system", "content": KP2CASE_SYS}
112
+
113
+ messages = [
114
+ role_desc,
115
+ {"role": "user", "content": key_points}
116
+ ]
117
+ return self.chat_complete(messages)
118
+
119
+ def case_merge_together(self, case_suits):
120
+
121
+ role_desc = {"role": "system", "content": CASE_AGG_SYS}
122
+
123
+ content_case_suits = ""
124
+ for i, case_suit in enumerate(case_suits):
125
+ case_suit_expr = json.dumps(case_suit, indent=4, ensure_ascii=False)
126
+ content_case_suits += f"来自初级测试工程师{i + 1}的测试用例:\n```json\n{case_suit_expr}\n```\n"
127
+ messages = [
128
+ role_desc,
129
+ {"role": "user", "content": content_case_suits}
130
+ ]
131
+ completion = self.chat_complete(messages)
132
+ return completion
133
+
134
+ def cycle_more_case(self, prd_content, case_language="Chinese"):
135
+
136
+ role_desc = {"role": "system", "content": PRD_CASE_SYS[case_language]}
137
+
138
+ messages = [
139
+ role_desc,
140
+ {"role": "user", "content": PRD_CASE_1[case_language] + prd_content + "\n" + PRD_CASE_2[case_language]}
141
+ ]
142
+
143
+ result = []
144
+
145
+ for sys in MORE_CASE_PROMPT[case_language]:
146
+ if sys:
147
+ messages.append({"role": "user", "content": sys})
148
+ reply = self.chat_complete(messages)
149
+ result.append(reply)
150
+ messages.append({"role": "assistant", "content": reply})
151
+ time.sleep(10)
152
+ return result
153
+
154
+
155
+ if __name__ == "__main__":
156
+ cli = DouBaoService("DouBao128Pro")
157
+ # print(cli.get_api_key())
158
+ # prd_content = requests.get("https://tosv.byted.org/obj/music-qa-bucket/xmind-test/de3ebc67410c43603034e21bfefa76a0.md").text
159
+ # aa = cli.cycle_more_case(prd_content, "English")
160
+ # print(aa)
161
+
162
+ print(cli.chat_complete(messages=[
163
+ {"role": "system", "content": "You are a helpful assistant."},
164
+ {"role": "user", "content": "Introduce LLM shortly."},
165
+ ]))
166
+
interface.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ import re
4
+ from json import loads, JSONDecodeError
5
+ import sys
6
+ import os
7
+ import ast
8
+ from util.vector_base import EmbeddingFunction, get_or_create_vector_base
9
+ from doubao_service import DouBaoService
10
+ from PROMPT_TEMPLATE import prompt_template
11
+ from util.Embeddings import TextEmb3LargeEmbedding
12
+ from langchain_core.documents import Document
13
+ from FlagEmbedding import FlagReranker
14
+ from retriever import retriever
15
+ import time
16
+ from bm25s import BM25, tokenize
17
+ import contextlib
18
+ import io
19
+
20
+ import gradio as gr
21
+ import time
22
+
23
+ client = DouBaoService("DouBao128Pro")
24
+ embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
25
+ embedding = EmbeddingFunction(embeddingmodel)
26
+ safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)
27
+
28
+ # reranker_model = FlagReranker(
29
+ # 'C://Users//Admin//Desktop//PDPO//NLL_LLM//model//bge-reranker-v2-m3',
30
+ # use_fp16=True,
31
+ # devices=["cpu"],
32
+ # )
33
+
34
+ OPTIONS = ['AI Governance',
35
+ 'Data Accuracy',
36
+ 'Data Minimization & Purpose Limitation',
37
+ 'Data Retention',
38
+ 'Data Security',
39
+ 'Data Sharing',
40
+ 'Individual Rights',
41
+ 'Privacy by Design',
42
+ 'Transparency']
43
+
44
+
45
+ def format_model_output(raw_output):
46
+ """
47
+ 处理模型输出:
48
+ - 将 \n 转换为实际换行
49
+ - 提取 ```json ``` 中的内容并格式化为可折叠的 JSON
50
+ """
51
+ formatted = raw_output.replace('\\n', '\n')
52
+ def replace_json(match):
53
+ json_str = match.group(1).strip()
54
+ try:
55
+ json_obj = loads(json_str)
56
+ return f"```json\n{json.dumps(json_obj, indent=2, ensure_ascii=False)}\n```"
57
+ except JSONDecodeError:
58
+ return match.group(0)
59
+
60
+ formatted = re.sub(r'```json\n?(.*?)\n?```', replace_json, formatted, flags=re.DOTALL)
61
+ return ast.literal_eval(formatted)
62
+
63
+ def model_predict(input_text, if_split_po, topk, selected_items):
64
+ """
65
+ selected_items: 用户选择的项目(可能是["All"]或具体PO)
66
+ """
67
+ requirement = input_text
68
+ requirement = requirement.replace("\t", "").replace("\n", "").replace("\r", "")
69
+ if "All" in selected_items:
70
+ PO = OPTIONS
71
+ else:
72
+ PO = selected_items
73
+ if topk:
74
+ topk = int(topk)
75
+ else:
76
+ topk = 10
77
+ final_result = retriever(
78
+ requirement,
79
+ PO,
80
+ safeguard_vector_store,
81
+ reranker_model=None,
82
+ using_reranker=False,
83
+ using_BM25=False,
84
+ using_chroma=True,
85
+ k=topk,
86
+ if_split_po=if_split_po
87
+ )
88
+ mapping_safeguards = {}
89
+ for safeguard in final_result:
90
+ if safeguard[3] not in mapping_safeguards:
91
+ mapping_safeguards[safeguard[3]] = []
92
+ mapping_safeguards[safeguard[3]].append(
93
+ {
94
+ "Score": safeguard[0],
95
+ "Safeguard Number": safeguard[1],
96
+ "Safeguard Description": safeguard[2]
97
+ }
98
+ )
99
+ prompt = prompt_template(requirement, mapping_safeguards)
100
+ response = client.chat_complete(messages=[
101
+ {"role": "system", "content": "You are a helpful assistant."},
102
+ {"role": "user", "content": prompt},
103
+ ])
104
+ # return {"requirement": requirement, "safeguards": mapping_safeguards}
105
+ print("requirement:", requirement)
106
+ print("mapping safeguards:", mapping_safeguards)
107
+ print("response:", response)
108
+ return {"requirement": requirement, "safeguards": format_model_output(response)}
109
+
110
+ with gr.Blocks(title="New Law Landing") as demo:
111
+ gr.Markdown("## 🏙️ New Law Landing")
112
+
113
+ requirement = gr.Textbox(label="Input Requirements", placeholder="Example: Data Minimization Consent for incompatible purposes")
114
+ details = gr.Textbox(label="Input Details", placeholder="Example: Require consent for...")
115
+
116
+ # 修改为 Number 输入组件
117
+ topk = gr.Number(
118
+ label="Top K safeguards",
119
+ value=10,
120
+ precision=0,
121
+ minimum=1,
122
+ interactive=True
123
+ )
124
+
125
+ with gr.Row():
126
+ with gr.Column(scale=1):
127
+ if_split_po = gr.Checkbox(
128
+ label="If Split Privacy Objective",
129
+ value=True,
130
+ info="Recall K Safeguards for each Privacy Objective"
131
+ )
132
+ with gr.Column(scale=1):
133
+ all_checkbox = gr.Checkbox(
134
+ label="ALL Privacy Objective",
135
+ value=True,
136
+ info="No specific Privacy Objective is specified"
137
+ )
138
+ with gr.Column(scale=4):
139
+ PO_checklist = gr.CheckboxGroup(
140
+ label="Choose Privacy Objective",
141
+ choices=OPTIONS,
142
+ value=[],
143
+ interactive=True
144
+ )
145
+
146
+ submit_btn = gr.Button("Submit", variant="primary")
147
+ result_output = gr.JSON(label="Related safeguards", open=True)
148
+
149
+
150
+ def sync_checkboxes(selected_items, all_selected):
151
+ if len(selected_items) > 0:
152
+ return False
153
+ return all_selected
154
+
155
+ PO_checklist.change(
156
+ fn=sync_checkboxes,
157
+ inputs=[PO_checklist, all_checkbox],
158
+ outputs=all_checkbox
159
+ )
160
+
161
+ def sync_all(selected_all, current_selection):
162
+ if selected_all:
163
+ return []
164
+ return current_selection
165
+
166
+ all_checkbox.change(
167
+ fn=sync_all,
168
+ inputs=[all_checkbox, PO_checklist],
169
+ outputs=PO_checklist
170
+ )
171
+
172
+ def process_inputs(requirement, details, topk, if_split_po, all_selected, PO_selected):
173
+ input_text = requirement + ": " + details
174
+ if all_selected:
175
+ return model_predict(input_text, if_split_po, int(topk), ["All"])
176
+ else:
177
+ return model_predict(input_text, if_split_po, int(topk), PO_selected)
178
+
179
+ submit_btn.click(
180
+ fn=process_inputs,
181
+ inputs=[requirement, details, topk, if_split_po, all_checkbox, PO_checklist],
182
+ outputs=[result_output]
183
+ )
184
+
185
+ if __name__ == "__main__":
186
+ demo.launch(share=True)
retriever.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import json
3
+ import sys
4
+ import os
5
+ from collections import defaultdict
6
+ from util.vector_base import EmbeddingFunction, get_or_create_vector_base
7
+ from util.Embeddings import TextEmb3LargeEmbedding
8
+ from langchain_core.documents import Document
9
+ from FlagEmbedding import FlagReranker
10
+ import time
11
+ from bm25s import BM25, tokenize
12
+ import contextlib
13
+ import io
14
+ from tqdm import tqdm
15
+
16
+ def rrf(rankings, k = 60):
17
+ res = 0
18
+ for r in rankings:
19
+ res += 1 / (r + k)
20
+ return res
21
+
22
+ def retriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=False, using_BM25=False, using_chroma=True, k=20, if_split_po=True):
23
+ final_result = []
24
+ if not if_split_po:
25
+ final_result = multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
26
+ else:
27
+ for po in PO:
28
+ po_result = multiretriever(requirement, [po], safeguard_vector_store, reranker_model, using_reranker=using_reranker, using_BM25=using_BM25, using_chroma=using_chroma, k=k)
29
+ for safeguard in po_result:
30
+ final_result.append(safeguard)
31
+ return final_result
32
+
33
+ def multiretriever(requirement, PO, safeguard_vector_store, reranker_model, using_reranker=True, using_BM25=False, using_chroma=True, k=20):
34
+ """
35
+ requirements_dict: [
36
+ requirement: {
37
+ "PO": [],
38
+ "safeguard": []
39
+ }
40
+ ]
41
+ """
42
+ candidate_safeguards = []
43
+ po_list = [po.lower().rstrip() for po in PO if po]
44
+ if "young users" in po_list and len(po_list) == 1:
45
+ return []
46
+ candidate_safeguards = safeguard_vector_store.get(where={"po": {"$in": po_list}})
47
+ safeguard_dict, safeguard_content = {}, []
48
+ for id, content, metadata in zip(candidate_safeguards['ids'], candidate_safeguards['documents'], candidate_safeguards['metadatas']):
49
+ safeguard_dict[content] = {
50
+ "metadata": metadata,
51
+ "rank": [],
52
+ "rrf_score": 0
53
+ }
54
+ safeguard_content.append(content)
55
+
56
+ # Reranker
57
+ if using_reranker:
58
+ content_pairs, reranking_rank, reranking_results = [], [], []
59
+ for safeguard in safeguard_content:
60
+ content_pairs.append([requirement, safeguard])
61
+ safeguard_rerank_scores = reranker_model.compute_score(content_pairs)
62
+ for content_pair, score in zip(content_pairs, safeguard_rerank_scores):
63
+ reranking_rank.append((content_pair[1], score))
64
+ reranking_results = sorted(reranking_rank, key=lambda x: x[1], reverse=True)
65
+ for safeguard, score in reranking_results:
66
+ safeguard_dict[safeguard]['rank'].append(reranking_results.index((safeguard, score)) + 1)
67
+
68
+ # BM25
69
+ if using_BM25:
70
+ with contextlib.redirect_stdout(io.StringIO()):
71
+ bm25_retriever = BM25(corpus=safeguard_content)
72
+ bm25_retriever.index(tokenize(safeguard_content))
73
+ bm25_results, scores = bm25_retriever.retrieve(tokenize(requirement), k = len(safeguard_content))
74
+ bm25_retrieval_rank = 1
75
+ for safeguard in bm25_results[0]:
76
+ safeguard_dict[safeguard]['rank'].append(bm25_retrieval_rank)
77
+ bm25_retrieval_rank += 1
78
+
79
+ # chroma retrieval
80
+ if using_chroma:
81
+ retrieved_safeguards = safeguard_vector_store.similarity_search_with_score(query=requirement, k=len(candidate_safeguards['ids']), filter={"po": {"$in": po_list}})
82
+ retrieval_rank = 1
83
+ for safeguard in retrieved_safeguards:
84
+ safeguard_dict[safeguard[0].page_content]['rank'].append(retrieval_rank)
85
+ retrieval_rank += 1
86
+
87
+ final_result = []
88
+ for safeguard in safeguard_content:
89
+ safeguard_dict[safeguard]['rrf_score'] = rrf(safeguard_dict[safeguard]['rank'])
90
+ final_result.append((safeguard_dict[safeguard]['rrf_score'], safeguard_dict[safeguard]['metadata']['safeguard_number'], safeguard, safeguard_dict[safeguard]['metadata']['po']))
91
+ final_result.sort(key=lambda x: x[0], reverse=True)
92
+
93
+ # top k
94
+ topk_final_result = final_result[:k]
95
+
96
+ return topk_final_result
97
+
98
+ if __name__=="__main__":
99
+ embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
100
+ embedding = EmbeddingFunction(embeddingmodel)
101
+ safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)
102
+ reranker_model = FlagReranker(
103
+ '/root/PTR-LLM/tasks/pcf/model/bge-reranker-v2-m3',
104
+ use_fp16=True,
105
+ devices=["cpu"],
106
+ )
107
+ requirement = """
108
+ Data Minimization Consent for incompatible purposes: Require consent for additional use of personal information not reasonably necessary to or incompatible with original purpose disclosure.
109
+ """
110
+ PO = ["Data Minimization & Purpose Limitation", "Transparency"]
111
+ final_result = retriever(
112
+ requirement,
113
+ PO,
114
+ safeguard_vector_store,
115
+ reranker_model,
116
+ using_reranker=True,
117
+ using_BM25=False,
118
+ using_chroma=True,
119
+ k=10
120
+ )
121
+ print(final_result)
store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b8d4b3825a7c7a773e22fa3eeef0e7d15a695f5c4183aeff5beb07741a68679
3
+ size 12428000
store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8a3ec48846fc6fdfaef19f5ed2508f0bf3da4a3c93b0f6b3dd21f0a22ec1026
3
+ size 100
store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc19b1997119425765295aeab72d76faa6927d4f83985d328c26f20468d6cc76
3
+ size 4000
store/requirement_full_database/8879b034-d26b-4dd9-bdc6-9a0751f8eeeb/link_lists.bin ADDED
File without changes
store/requirement_full_database/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:debbda97fa444d1beed59205da5310caa07fc5e1fea98ee5d217bb1cd86b3312
3
+ size 2031616
store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b8d4b3825a7c7a773e22fa3eeef0e7d15a695f5c4183aeff5beb07741a68679
3
+ size 12428000
store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8a3ec48846fc6fdfaef19f5ed2508f0bf3da4a3c93b0f6b3dd21f0a22ec1026
3
+ size 100
store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0c0e05e944e59611aa54c4b0b708f835a8a0daf4baec11208e3f60773b22d89
3
+ size 4000
store/requirement_v1_database/6db99751-9b95-42b7-ae30-46ba43f95c27/link_lists.bin ADDED
File without changes
store/requirement_v1_database/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f177f30d3354a5cf2c14378bfe1b73755684d9c5fc38046f9e4339a1180af0a2
3
+ size 1212416
store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b8d4b3825a7c7a773e22fa3eeef0e7d15a695f5c4183aeff5beb07741a68679
3
+ size 12428000
store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8a3ec48846fc6fdfaef19f5ed2508f0bf3da4a3c93b0f6b3dd21f0a22ec1026
3
+ size 100
store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a911716b3f8450b156db9e04a3c81548395cabac0c846dd1e1eef832991c120
3
+ size 4000
store/requirement_v2_database/c1b4f057-aa88-49ff-a2ac-08fc9d60804c/link_lists.bin ADDED
File without changes
store/requirement_v2_database/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6d2f15577b5bcd6c8ce845d894cc17196e71c71c3909eec73f9bfedee400c90
3
+ size 991232
store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b8d4b3825a7c7a773e22fa3eeef0e7d15a695f5c4183aeff5beb07741a68679
3
+ size 12428000
store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8a3ec48846fc6fdfaef19f5ed2508f0bf3da4a3c93b0f6b3dd21f0a22ec1026
3
+ size 100
store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c710099c23f51c095ec65b43539cd7534bb745aebf55fef0a9c06229121caca
3
+ size 4000
store/safeguard_database/1ae9d702-e220-41de-95e3-e603f3a12409/link_lists.bin ADDED
File without changes
store/safeguard_database/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0730981819c475f93d173a347f891b5e28e02cce24a63a578ca43c98a13d9c10
3
+ size 5095424
test.ipynb ADDED
File without changes
util/Embeddings.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from copy import copy
3
+ from typing import Dict, List, Optional, Tuple, Union
4
+ import numpy as np
5
+ import time
6
+ from functools import wraps
7
+
8
+ os.environ['CURL_CA_BUNDLE'] = ''
9
+ from dotenv import load_dotenv, find_dotenv
10
+ _ = load_dotenv(find_dotenv())
11
+
12
+
13
class BaseEmbeddings:
    """Abstract parent for all embedding backends.

    Subclasses implement :meth:`get_embedding`; the similarity helper is
    shared by every backend.
    """

    def __init__(self, path: str, is_api: bool) -> None:
        # path: local model path/name (unused by pure-API backends);
        # is_api: True when embeddings come from a remote service.
        self.path = path
        self.is_api = is_api

    def get_embedding(self, text: str, model: str) -> List[float]:
        # Concrete backends must override this.
        raise NotImplementedError

    @classmethod
    def cosine_similarity(cls, vector1: List[float], vector2: List[float]) -> float:
        """Return the cosine similarity of two vectors (0 when either has zero norm)."""
        v1 = np.asarray(vector1)
        v2 = np.asarray(vector2)
        denom = np.linalg.norm(v1) * np.linalg.norm(v2)
        if not denom:
            # Degenerate case: at least one zero vector.
            return 0
        return np.dot(v1, v2) / denom
34
+
35
+
36
class OpenAIEmbedding(BaseEmbeddings):
    """OpenAI embedding backend (API only).

    Credentials are read from the OPENAI_API_KEY / OPENAI_BASE_URL
    environment variables.
    """

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            from openai import OpenAI
            # FIX: pass credentials to the constructor. OpenAI() validates
            # the API key at construction time, so assigning
            # client.api_key / client.base_url afterwards (as the original
            # did) is too late when OPENAI_API_KEY is absent from the env.
            self.client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_BASE_URL"),
            )

    def get_embedding(self, text: str, model: str = "text-embedding-3-large") -> List[float]:
        """Embed a single text; newlines are flattened per OpenAI's guidance."""
        if self.is_api:
            text = text.replace("\n", " ")
            return self.client.embeddings.create(input=[text], model=model).data[0].embedding
        else:
            raise NotImplementedError
54
+
55
class JinaEmbedding(BaseEmbeddings):
    """Local Jina embedding backend (runs the model in-process)."""

    def __init__(self, path: str = 'jinaai/jina-embeddings-v2-base-zh', is_api: bool = False) -> None:
        super().__init__(path, is_api)
        self._model = self.load_model()

    def get_embedding(self, text: str) -> List[float]:
        # encode() takes a batch; embed the single text and unwrap the row.
        batch_vectors = self._model.encode([text])
        return batch_vectors[0].tolist()

    def load_model(self):
        """Load the model onto GPU when available, otherwise CPU."""
        import torch
        from transformers import AutoModel
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return AutoModel.from_pretrained(self.path, trust_remote_code=True).to(device)
75
+
76
class ZhipuEmbedding(BaseEmbeddings):
    """Zhipu AI embedding backend (API only, model "embedding-2").

    Reads the credential from the ZHIPUAI_API_KEY environment variable.
    """

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            from zhipuai import ZhipuAI
            self.client = ZhipuAI(api_key=os.getenv("ZHIPUAI_API_KEY"))

    def get_embedding(self, text: str) -> List[float]:
        result = self.client.embeddings.create(model="embedding-2", input=text)
        return result.data[0].embedding
92
+
93
+
94
class DashscopeEmbedding(BaseEmbeddings):
    """Dashscope embedding backend (API only).

    Reads the credential from the DASHSCOPE_API_KEY environment variable.
    """

    def __init__(self, path: str = '', is_api: bool = True) -> None:
        super().__init__(path, is_api)
        if self.is_api:
            import dashscope
            dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
            # Keep a handle on the TextEmbedding entry point for later calls.
            self.client = dashscope.TextEmbedding

    def get_embedding(self, text: str, model: str = 'text-embedding-v1') -> List[float]:
        result = self.client.call(model=model, input=text)
        # Single input -> single embedding in the response payload.
        return result.output['embeddings'][0]['embedding']
111
+
112
+
113
class BgeEmbedding(BaseEmbeddings):
    """Local BGE embedding backend (CLS pooling, L2-normalised output)."""

    def __init__(self, path: str = 'BAAI/bge-en-icl', is_api: bool = False) -> None:
        super().__init__(path, is_api)
        self._model, self._tokenizer = self.load_model(path)

    def get_embedding(self, text: str) -> List[float]:
        import torch
        batch = self._tokenizer([text], padding=True, truncation=True, return_tensors='pt')
        # Move every input tensor to the model's device before the forward pass.
        batch = {name: tensor.to(self._model.device) for name, tensor in batch.items()}
        with torch.no_grad():
            output = self._model(**batch)
        # CLS-token pooling followed by unit-length normalisation.
        cls_vectors = output[0][:, 0]
        cls_vectors = torch.nn.functional.normalize(cls_vectors, p=2, dim=1)
        return cls_vectors[0].tolist()

    def load_model(self, path: str):
        """Load tokenizer + model (eval mode) onto GPU when available."""
        import torch
        from transformers import AutoModel, AutoTokenizer
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tokenizer = AutoTokenizer.from_pretrained(path)
        model = AutoModel.from_pretrained(path).to(device)
        model.eval()
        return model, tokenizer
143
+
144
+
145
def rate_limiter():
    """Decorator factory: throttle a method to ``self.max_qpm`` calls per minute.

    The decorated method's owner must expose ``max_qpm`` (int) and ``silent``
    (bool). The timestamp of the previous call is kept on the instance as
    ``_last_called``; when a call arrives too soon, the wrapper sleeps for the
    remainder of the minimum gap first.
    """
    def rate_limiter_decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            min_gap = 60 / self.max_qpm
            now = time.time()

            # First call on this instance: seed the timestamp.
            if not hasattr(self, '_last_called'):
                self._last_called = 0
            since_last = now - self._last_called
            if since_last < min_gap:
                pause = min_gap - since_last
                if self.silent is False:
                    print(f"## Rate limit reached. Waiting for {pause:.2f} seconds.")
                time.sleep(pause)
            outcome = func(self, *args, **kwargs)
            # Record completion time so the next call measures from here.
            self._last_called = time.time()
            return outcome
        return wrapper

    return rate_limiter_decorator
168
+
169
+
170
class TextEmb3LargeEmbedding(BaseEmbeddings):
    """text-embedding-3-large via the Azure OpenAI gateway, throttled to ``max_qpm``.

    Args:
        max_qpm: maximum queries per minute allowed against the endpoint.
        is_silent: when False, the rate limiter prints a message while waiting.
    """
    def __init__(self, max_qpm, is_silent=False):
        from langchain_openai import AzureOpenAIEmbeddings

        ## https://gpt.bytedance.net/gpt_openapi/
        base_url = "https://search-va.byteintl.net/gpt/openapi/online/v2/crawl"
        api_version = "2024-03-01-preview"
        # SECURITY FIX: the access key was previously hard-coded here.
        # Credentials must come from the environment, never source control;
        # the old committed key must be rotated/revoked.
        ak = os.getenv("GPT_OPENAPI_AK")
        if not ak:
            raise ValueError("GPT_OPENAPI_AK environment variable is not set")
        model_name = "text-embedding-3-large"
        api_type = "azure"
        self.llm = AzureOpenAIEmbeddings(
            azure_endpoint=base_url,
            openai_api_version=api_version,
            deployment=model_name,
            openai_api_key=ak,
            openai_api_type=api_type,
        )
        # Attributes consumed by the rate_limiter decorator.
        self.max_qpm = max_qpm
        self.silent = is_silent

    @rate_limiter()
    def get_embedding(self, text: str):
        """Embed one text; throttled so total calls stay under max_qpm."""
        return self.llm.embed_query(text)
util/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .logger_util import get_logger
2
+
3
+ logger = get_logger()
util/__pycache__/Embeddings.cpython-311.pyc ADDED
Binary file (13.7 kB). View file
 
util/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (258 Bytes). View file
 
util/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (205 Bytes). View file
 
util/__pycache__/config_util.cpython-311.pyc ADDED
Binary file (1.35 kB). View file
 
util/__pycache__/logger_util.cpython-311.pyc ADDED
Binary file (1.85 kB). View file
 
util/__pycache__/logger_util.cpython-39.pyc ADDED
Binary file (1.01 kB). View file
 
util/__pycache__/vector_base.cpython-311.pyc ADDED
Binary file (5.76 kB). View file
 
util/__pycache__/vector_base.cpython-39.pyc ADDED
Binary file (3.39 kB). View file
 
util/config_util.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ import bytedenv
5
+ import configparser
6
+
7
+ ROOT_PATH = os.path.split(os.path.split(os.path.abspath(__file__))[0])[0]
8
+
9
+
10
def read_config():
    """Parse ``conf/config.ini`` under the project root.

    Returns:
        configparser.ConfigParser: the parsed configuration. Note that
        ConfigParser.read() silently skips files it cannot open, so a
        missing config file yields an empty parser rather than an error.
    """
    # FIX: the original built the path with a Windows-only literal
    # (ROOT_PATH + "\conf\config.ini"), which breaks on POSIX systems;
    # os.path.join is portable. Also dropped the unused model_name local.
    config_file = os.path.join(ROOT_PATH, "conf", "config.ini")
    config_ini = configparser.ConfigParser()
    config_ini.read(config_file)
    return config_ini
16
+
17
def read_json(filepath):
    """Load and return the JSON document stored at ``filepath``.

    Args:
        filepath: path to a JSON file.

    Returns:
        The deserialized Python object (dict, list, etc.).
    """
    # Explicit encoding so behaviour does not depend on the platform's
    # default locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        result = json.load(f)
    return result
21
+
util/logger_util.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import logging.config
4
+ import traceback
5
+ from functools import wraps
6
+
7
+
8
def get_logger():
    """Return the "Robot" logger, configured from ``conf/logs.ini``.

    The config file is resolved relative to the project root (the parent of
    this module's directory). Calling this re-applies the file config.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    config_path = os.path.join(project_root, "conf", "logs.ini")
    logging.config.fileConfig(config_path)
    return logging.getLogger("Robot")
13
+
14
+
15
def log_decorate(func):
    """Decorator: log any exception raised by ``func`` instead of propagating it.

    NOTE(review): exceptions are swallowed — after logging, the wrapper
    implicitly returns None, so callers cannot distinguish a failure from a
    genuine None result. Confirm this best-effort semantic is intended.
    """
    @wraps(func)
    def log(*args, **kwargs):
        # A fresh logger per call; get_logger() re-applies the logging file
        # config each time this wrapper runs (before the wrapped call).
        logger = get_logger()
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logger.error(f"{func.__name__} is error, logId: {e.args}, errMsg is: {traceback.format_exc()}")
    return log
util/vector_base.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from langchain_chroma import Chroma
3
+ from langchain_core.documents import Document
4
+ sys.path.append('C://Users//Admin//Desktop//PDPO//NLL_LLM//util')
5
+ from Embeddings import TextEmb3LargeEmbedding
6
+ from pathlib import Path
7
+ import time
8
+
9
class EmbeddingFunction():
    """Adapter exposing LangChain's embedding interface (embed_query /
    embed_documents) on top of any object providing ``get_embedding(text)``."""

    def __init__(self, embeddingmodel):
        # Underlying model; must implement get_embedding(text) -> vector.
        self.embeddingmodel = embeddingmodel

    def embed_query(self, query):
        # One text in, one vector out (copied into a plain list).
        return list(self.embeddingmodel.get_embedding(query))

    def embed_documents(self, documents):
        # Batch of texts, embedded one at a time.
        vectors = []
        for text in documents:
            vectors.append(self.embeddingmodel.get_embedding(text))
        return vectors
16
+
17
def get_or_create_vector_base(collection_name: str, embedding, documents=None) -> Chroma:
    """Load a persisted Chroma collection, creating it from ``documents`` if absent.

    Documents are inserted one at a time with a one-second pause — rather than
    one bulk embed_documents call — so initialisation stays under the OpenAI
    embedding endpoint's rate limit.

    Args:
        collection_name: collection name; also the sub-directory under the
            local ``store`` folder used for persistence.
        embedding: object implementing embed_query / embed_documents.
        documents: Documents used to build the store when it does not exist.

    Returns:
        Chroma: the loaded or newly populated vector store.

    Raises:
        ValueError: when no persisted store exists and ``documents`` is empty.
    """
    persist_directory = "C://Users//Admin//Desktop//PDPO//NLL_LLM//store//" + collection_name
    persist_path = Path(persist_directory)
    # BUG FIX: the original tested `persist_path.exists` — the bound method
    # object, which is always truthy — so this guard could never fire and a
    # missing store with no documents fell through to `for document in None`.
    if not persist_path.exists() and not documents:
        raise ValueError("vector store does not exist and documents is empty")
    elif persist_path.exists():
        print("vector store already exists")
        vector_store = Chroma(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory
        )
    else:
        print("start creating vector store")
        vector_store = Chroma(
            collection_name=collection_name,
            embedding_function=embedding,
            persist_directory=persist_directory
        )
        for document in documents:
            vector_store.add_documents(documents=[document])
            # Pace insertions to avoid hitting the embedding API rate limit.
            time.sleep(1)
    return vector_store
45
+
46
if __name__=="__main__":
    # Script entry: build the "requirement_v2" Chroma store from the NLL
    # requirements CSV (hard-coded local path).
    import pandas as pd
    requirements_data = pd.read_csv("/root/PTR-LLM/tasks/pcf/reference/NLL_DATA_NEW_Test.csv")
    requirements_dict_v2 = {}
    # Deduplicate rows into one entry per requirement text, accumulating the
    # privacy objectives (PO) and safeguards seen across all rows.
    for index, row in requirements_data.iterrows():
        # "Category - Requirement": keep the part after "- ", then append Details.
        # NOTE(review): assumes every Requirement cell contains "- " exactly
        # once in that shape — verify against the CSV.
        requirement = row['Requirement'].split("- ")[1]
        requirement = requirement + ": " + row['Details']
        # Flatten whitespace so the same requirement dedupes reliably.
        requirement = requirement.replace('\n', ' ').replace('\r', ' ').replace('\t', ' ')
        if requirement not in requirements_dict_v2:
            requirements_dict_v2[requirement] = {
                'PO': set(),
                'safeguard': set()
            }
        # PO may be missing (NaN/float) in the CSV; only strings are kept,
        # non-strings are recorded as None and filtered out below.
        requirements_dict_v2[requirement]['PO'].add(row['PCF-Privacy Objective'].lower().rstrip() if isinstance(row['PCF-Privacy Objective'], str) else None)
        requirements_dict_v2[requirement]['safeguard'].add(row['Safeguard'].lower().rstrip())
    index = 0
    documents = []
    # One Document per unique requirement; PO/safeguard sets are stringified
    # because Chroma metadata values must be scalars.
    for key, value in requirements_dict_v2.items():
        page_content = key
        metadata = {
            "index": index,
            "version":2,
            "PO": str([po for po in value['PO'] if po]),
            "safeguard":str([safeguard for safeguard in value['safeguard']])
        }
        index += 1
        document=Document(
            page_content=page_content,
            metadata=metadata
        )
        documents.append(document)
    # 58 QPM keeps the embedding calls just under a 60-per-minute cap.
    embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
    embedding = EmbeddingFunction(embeddingmodel)
    requirement_v2_vector_store = get_or_create_vector_base('requirement_v2_database', embedding, documents)