Spaces:
Sleeping
Sleeping
import time | |
import requests | |
import json | |
from volcenginesdkarkruntime import Ark | |
from util.config_util import read_config as config | |
from util.config_util import load_json | |
from util import logger | |
import volcenginesdkcore | |
import volcenginesdkark | |
from volcenginesdkcore.rest import ApiException | |
from util.logger_util import log_decorate | |
class DouBaoService:
    """Wrapper around the Volcengine Ark chat API used to generate test cases from PRDs.

    Configuration is read from ``./conf/config.json`` under the key
    ``"{model_name}ModelInfo"``. That entry must provide ``BASE_URL``,
    ``ACCESS_KEY``, ``SECRET_KEY`` and ``ENDPOINT_ID``. It may optionally
    provide ``REGION``, which is used by :meth:`get_api_key`.

    NOTE(review): the prompt constants used below (``PRD2KP_SYS``,
    ``PRD_CASE_SYS``, ``KP2CASE_SYS``, ``CASE_AGG_SYS``, ``PRD_CASE_1``,
    ``PRD_CASE_2``, ``MORE_CASE_PROMPT``) are neither defined nor imported
    in the visible source. Confirm they are provided at module level.
    """

    def __init__(self, model_name):
        """Load the ``{model_name}ModelInfo`` config section and build the Ark client.

        :param model_name: prefix of the config key, e.g. ``"DouBao128Pro"``.
        """
        # Use a distinct name so the module-level ``config`` import is not shadowed.
        all_conf = load_json('./conf/config.json')
        self.conf = all_conf[f"{model_name}ModelInfo"]
        self.client = self.init_client()
        # Per-instance sampling overrides, filled by set_complete_args().
        self._complete_args = {}

    def init_client(self):
        """Create an Ark client authenticated with the configured AK/SK pair."""
        return Ark(
            base_url=self.conf["BASE_URL"],
            ak=self.conf["ACCESS_KEY"],
            sk=self.conf["SECRET_KEY"],
        )

    def get_api_key(self):
        """Request a temporary API key scoped to this service's endpoint.

        Side effect: installs an AK/SK ``volcenginesdkcore.Configuration`` as
        the SDK's global default via ``Configuration.set_default``. This
        presumably applies to SDK API clients created later in this process.

        :return: the API key string on success, or ``None`` if the SDK raised
            ``ApiException``. In that case the error is logged.
        """
        configuration = volcenginesdkcore.Configuration()
        configuration.ak = self.conf["ACCESS_KEY"]
        configuration.sk = self.conf["SECRET_KEY"]
        # Previously hard-coded; keep "cn-beijing" as the default but allow overriding.
        configuration.region = self.conf.get("REGION", "cn-beijing")
        endpoint_id = self.conf["ENDPOINT_ID"]
        volcenginesdkcore.Configuration.set_default(configuration)
        # ARKApi() is built without explicit configuration, relying on the default set above.
        api_instance = volcenginesdkark.ARKApi()
        get_api_key_request = volcenginesdkark.GetApiKeyRequest(
            duration_seconds=30 * 24 * 3600,  # 30 days, expressed in seconds
            resource_type="endpoint",
            resource_ids=[endpoint_id],
        )
        try:
            resp = api_instance.get_api_key(get_api_key_request)
        except ApiException as e:
            logger.error(f"Exception when calling api: {e}")
            return None
        return resp.api_key

    def set_complete_args(self, temperature=None, top_p=None, max_token=None):
        """Override sampling parameters for subsequent :meth:`chat_complete` calls.

        Arguments left as ``None`` keep their current value. ``max_token`` is
        sent to the API under the key ``max_tokens``.
        """
        overrides = {"temperature": temperature, "top_p": top_p, "max_tokens": max_token}
        self._complete_args.update({key: value for key, value in overrides.items() if value is not None})

    def form_user_role(self, content):
        """Build a chat message with role ``user``."""
        return {"role": "user", "content": content}

    def form_sys_role(self, content):
        """Build a chat message with role ``system``."""
        return {"role": "system", "content": content}

    def form_assistant_role(self, content):
        """Build a chat message with role ``assistant``."""
        return {"role": "assistant", "content": content}

    def complete_args(self):
        """Return the sampling kwargs sent with each chat completion request.

        Starts from the defaults (``temperature=0.01``, ``top_p=0.7``) and
        applies any overrides stored by :meth:`set_complete_args`. A new dict
        is returned on every call, so callers may mutate it freely.
        """
        return {"temperature": 0.01, "top_p": 0.7, **self._complete_args}

    def chat_complete(self, messages):
        """Send ``messages`` to the configured endpoint and return the reply text.

        :param messages: list of ``{"role": ..., "content": ...}`` dicts.
        :return: the content of the first choice's message.
        """
        completion = self.client.chat.completions.create(
            model=self.conf["ENDPOINT_ID"],
            messages=messages,
            # complete_args is a method and must be called before unpacking;
            # unpacking the bound method itself raises TypeError.
            **self.complete_args(),
        )
        logger.info(f"complete doubao task, id: {completion.id}")
        return completion.choices[0].message.content

    def _single_turn(self, system_content, user_content):
        """Run a one-shot system+user conversation and return the reply text."""
        messages = [
            self.form_sys_role(system_content),
            self.form_user_role(user_content),
        ]
        return self.chat_complete(messages)

    def prd_to_keypoint(self, prd_content):
        """Ask the model to extract key points from a PRD document."""
        return self._single_turn(PRD2KP_SYS, prd_content)

    def prd_to_cases(self, prd_content, case_language="Chinese"):
        """Ask the model to write test cases for a PRD in ``case_language``."""
        return self._single_turn(PRD_CASE_SYS[case_language], prd_content)

    def keypoint_to_case(self, key_points):
        """Ask the model to turn extracted key points into test cases."""
        return self._single_turn(KP2CASE_SYS, key_points)

    def case_merge_together(self, case_suits):
        """Ask the model to merge several test-case suites into one.

        Each suite is serialized as indented JSON (non-ASCII kept as-is) and
        wrapped in a fenced ```json block. Each block carries a numbered
        header, "Test cases from junior test engineer N".
        """
        sections = []
        for index, case_suit in enumerate(case_suits, start=1):
            case_suit_expr = json.dumps(case_suit, indent=4, ensure_ascii=False)
            sections.append(f"来自初级测试工程师{index}的测试用例:\n```json\n{case_suit_expr}\n```\n")
        return self._single_turn(CASE_AGG_SYS, "".join(sections))

    def cycle_more_case(self, prd_content, case_language="Chinese"):
        """Run a multi-turn conversation that repeatedly asks for more test cases.

        The first user turn wraps the PRD between ``PRD_CASE_1`` and
        ``PRD_CASE_2``. Each entry of ``MORE_CASE_PROMPT[case_language]``
        then triggers one model call. A falsy entry adds no new user turn.
        Each reply is appended to the history so later turns build on it.

        NOTE(review): the original indentation was lost. This assumes only the
        user-turn append is guarded by the falsy check; confirm against history.

        :return: list of reply strings, one per ``MORE_CASE_PROMPT`` entry.
        """
        messages = [
            self.form_sys_role(PRD_CASE_SYS[case_language]),
            self.form_user_role(PRD_CASE_1[case_language] + prd_content + "\n" + PRD_CASE_2[case_language]),
        ]
        result = []
        # Named follow_up (not ``sys``) to avoid shadowing the stdlib module name.
        for follow_up in MORE_CASE_PROMPT[case_language]:
            if follow_up:
                messages.append(self.form_user_role(follow_up))
            reply = self.chat_complete(messages)
            result.append(reply)
            messages.append(self.form_assistant_role(reply))
            # Fixed pause after each call, presumably to stay under API rate limits.
            time.sleep(10)
        return result
if __name__ == "__main__":
    # Manual smoke test: send one short conversation to the configured
    # DouBao128Pro endpoint and print the model's reply.
    demo_service = DouBaoService("DouBao128Pro")
    demo_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Introduce LLM shortly."},
    ]
    demo_reply = demo_service.chat_complete(messages=demo_messages)
    print(demo_reply)