# NLL_Interface / interface.py
import json
import re
import ast

import gradio as gr
# from FlagEmbedding import FlagReranker  # only needed if the reranker below is enabled
# from bm25s import BM25, tokenize        # only needed if BM25 retrieval is enabled

from util.vector_base import EmbeddingFunction, get_or_create_vector_base
from util.Embeddings import TextEmb3LargeEmbedding
from doubao_service import DouBaoService
from PROMPT_TEMPLATE import prompt_template
from retriever import retriever
# Chat client, embedding model, and safeguard vector store used at inference time.
client = DouBaoService("DouBao128Pro")
embeddingmodel = TextEmb3LargeEmbedding(max_qpm=58)
embedding = EmbeddingFunction(embeddingmodel)
safeguard_vector_store = get_or_create_vector_base('safeguard_database', embedding)

# Optional local reranker, currently disabled (retriever is called with using_reranker=False).
# reranker_model = FlagReranker(
#     'C://Users//Admin//Desktop//PDPO//NLL_LLM//model//bge-reranker-v2-m3',
#     use_fp16=True,
#     devices=["cpu"],
# )
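
# To re-enable reranking: restore the FlagReranker import, uncomment the block
# above, and pass reranker_model=reranker_model with using_reranker=True to
# retriever() below (a sketch based on the retriever signature used in this file).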

# The Privacy Objectives (POs) a safeguard can belong to.
OPTIONS = [
    'AI Governance',
    'Data Accuracy',
    'Data Minimization & Purpose Limitation',
    'Data Retention',
    'Data Security',
    'Data Sharing',
    'Individual Rights',
    'Privacy by Design',
    'Transparency',
]


def format_model_output(raw_output):
    """
    Post-process the raw model reply:
    - turn escaped '\\n' sequences into real newlines
    - pretty-print the contents of any ```json ... ``` fenced block
    """
    formatted = raw_output.replace('\\n', '\n')

    def replace_json(match):
        json_str = match.group(1).strip()
        try:
            json_obj = json.loads(json_str)
            return f"```json\n{json.dumps(json_obj, indent=2, ensure_ascii=False)}\n```"
        except json.JSONDecodeError:
            # Not valid JSON: keep the original block unchanged.
            return match.group(0)

    formatted = re.sub(r'```json\n?(.*?)\n?```', replace_json, formatted, flags=re.DOTALL)
    # The model is prompted to answer with a Python-literal dict as a string;
    # literal_eval converts it back into a dict for the gr.JSON output.
    return ast.literal_eval(formatted)
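
# Illustrative example (hypothetical model reply) showing the expected shape:
#   format_model_output("{'PO': 'Data Security', 'Safeguards': ['S-1']}")
#   -> {'PO': 'Data Security', 'Safeguards': ['S-1']}
# Note that ast.literal_eval raises if the reply is not a valid Python literal.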


def model_predict(input_text, if_split_po, topk, selected_items):
    """
    selected_items: the Privacy Objectives chosen by the user,
    either ["All"] or a list of specific POs.
    """
    # Collapse the query onto a single line before retrieval.
    requirement = input_text.replace("\t", "").replace("\n", "").replace("\r", "")
    PO = OPTIONS if "All" in selected_items else selected_items
    # Fall back to the top 10 safeguards when no K is given.
    topk = int(topk) if topk else 10
    final_result = retriever(
        requirement,
        PO,
        safeguard_vector_store,
        reranker_model=None,
        using_reranker=False,
        using_BM25=False,
        using_chroma=True,
        k=topk,
        if_split_po=if_split_po,
    )
    # Each hit is a (score, safeguard number, description, PO) tuple;
    # group the hits by Privacy Objective.
    mapping_safeguards = {}
    for safeguard in final_result:
        mapping_safeguards.setdefault(safeguard[3], []).append(
            {
                "Score": safeguard[0],
                "Safeguard Number": safeguard[1],
                "Safeguard Description": safeguard[2],
            }
        )
    prompt = prompt_template(requirement, mapping_safeguards)
    response = client.chat_complete(messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ])
    # Retrieval-only debug path (skips the LLM formatting step):
    # return {"requirement": requirement, "safeguards": mapping_safeguards}
    print("requirement:", requirement)
    print("mapping safeguards:", mapping_safeguards)
    print("response:", response)
    return {"requirement": requirement, "safeguards": format_model_output(response)}
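
# Example call (illustrative values; assumes the safeguard vector store is populated):
#   model_predict(
#       "Data Minimization: Require consent for incompatible purposes",
#       if_split_po=True,
#       topk=10,
#       selected_items=["All"],
#   )
# -> {"requirement": "...", "safeguards": {<PO>: [{"Score": ..., "Safeguard Number": ..., ...}]}}
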
with gr.Blocks(title="New Law Landing") as demo:
gr.Markdown("## 🏙️ New Law Landing")
    requirement = gr.Textbox(
        label="Input Requirements",
        placeholder="Example: Data Minimization Consent for incompatible purposes",
    )
    details = gr.Textbox(
        label="Input Details",
        placeholder="Example: Require consent for...",
    )
    # Integer input controlling how many safeguards are recalled.
    topk = gr.Number(
        label="Top K safeguards",
        value=10,
        precision=0,
        minimum=1,
        interactive=True,
    )
with gr.Row():
with gr.Column(scale=1):
            if_split_po = gr.Checkbox(
                label="Split by Privacy Objective",
                value=True,
                info="Recall K safeguards for each Privacy Objective",
            )
        with gr.Column(scale=1):
            all_checkbox = gr.Checkbox(
                label="All Privacy Objectives",
                value=True,
                info="No specific Privacy Objective is specified",
            )
with gr.Column(scale=4):
PO_checklist = gr.CheckboxGroup(
label="Choose Privacy Objective",
choices=OPTIONS,
value=[],
interactive=True
)
submit_btn = gr.Button("Submit", variant="primary")
result_output = gr.JSON(label="Related safeguards", open=True)
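
    # Note: open=True expands the JSON tree by default; this assumes a Gradio
    # version in which gr.JSON supports the `open` parameter.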
    def sync_checkboxes(selected_items, all_selected):
        # Untick "All" as soon as any specific Privacy Objective is chosen.
        if len(selected_items) > 0:
            return False
        return all_selected

    PO_checklist.change(
        fn=sync_checkboxes,
        inputs=[PO_checklist, all_checkbox],
        outputs=all_checkbox,
    )

    def sync_all(selected_all, current_selection):
        # Ticking "All" clears any specific selection, keeping the "All"
        # checkbox and the PO checklist mutually exclusive.
        if selected_all:
            return []
        return current_selection

    all_checkbox.change(
        fn=sync_all,
        inputs=[all_checkbox, PO_checklist],
        outputs=PO_checklist,
    )
    def process_inputs(requirement, details, topk, if_split_po, all_selected, PO_selected):
        # Combine the short requirement and its details into one query string,
        # e.g. "Data Minimization: Require consent for...".
        input_text = requirement + ": " + details
        selected = ["All"] if all_selected else PO_selected
        return model_predict(input_text, if_split_po, int(topk), selected)
submit_btn.click(
fn=process_inputs,
inputs=[requirement, details, topk, if_split_po, all_checkbox, PO_checklist],
outputs=[result_output]
)
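
    # Event flow: Submit -> process_inputs -> retriever + DouBao chat completion
    # -> format_model_output -> rendered in the "Related safeguards" JSON panel.
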
if __name__ == "__main__":
demo.launch(share=True)