Toursim-Test / app.py
zhuhai111's picture
Update app.py
5ef0bd6 verified
raw
history blame contribute delete
31.1 kB
import gradio as gr
import sys
import os
from pathlib import Path
from typing import List
# Add this file's directory (the project root) to the Python import path so the
# `src` package imported below can be found.
sys.path.append(str(Path(__file__).parent))
from src.api.search_api import BochaSearch
from src.core.document_processor import DocumentProcessor
from src.core.ranking import RankingSystem
from src.core.plan_generator import PlanGenerator
from src.core.embeddings import EmbeddingModel
from src.core.reranker import Reranker
from src.api.llm_api import DeepseekInterface, LLMInterface, OpenAIInterface
from src.utils.helpers import load_config
import logging
# Logging setup: INFO level via basicConfig, plus a module-level logger that the
# rest of this file uses.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TravelRAGSystem:
    """Retrieval-augmented travel planner used by the Gradio interface.

    Construction loads ``config/config.yaml``, creates one LLM client per
    enabled provider and wires up the search engine, document processor,
    embedding model, reranker, ranking system and plan generator.
    ``process_query`` then runs search -> passage processing -> ranking ->
    plan generation and renders the results as HTML / Markdown.
    """

    # Icon shown in front of each labelled itinerary field.
    _FIELD_ICONS = {
        "Location": "📍",
        "Activity": "🎯",
        "Transportation": "🚇",
        "Specific Guidance": "🗺️",
    }
    # Labels rendered as "key: value" rows in the overview section.
    _SUMMARY_KEYS = ("Date", "Destination", "Key Attractions")
    # Markdown markers stripped from the generated plan, in this order
    # (longest heading marker first so shorter ones do not split it).
    _MARKDOWN_MARKERS = ('###', '##', '# ', '**')

    def __init__(self):
        self.config = load_config("config/config.yaml")
        # LLM clients keyed by provider name (configured model of each provider).
        self.llm_instances = {}
        # LLM clients for non-default models, keyed by (provider, model).
        self._model_llm_cache = {}
        # Only the "standard" retrieval method is supported.
        self.retrieval_method = "standard"
        self.init_llm_instances()
        self.init_components()

    def _provider_config(self, name):
        """Return the ``llm_settings.providers`` entry named ``name``, or None."""
        return next(
            (
                p for p in self.config['llm_settings']['providers']
                if p.get('name') == name
            ),
            None,
        )

    def init_components(self):
        """Create the default LLM, search engine, models and generators.

        Raises:
            ValueError: if the default provider has no configuration entry.
            Exception: whatever the embedding / reranker loaders raise; the
                error is logged and re-raised.
        """
        default_provider = self.config['llm_settings']['default_provider']
        provider_config = self._provider_config(default_provider)
        if not provider_config:
            raise ValueError(f"未找到默认提供商 {default_provider} 的配置")

        self.llm = self.init_llm(provider_config['name'], provider_config['model'])
        self.search_engine = BochaSearch(
            api_key=self.config['bocha_api_key'],
            base_url=self.config['bocha_base_url']
        )
        self.doc_processor = DocumentProcessor(self.llm)

        # Embedding model, referenced by its Hugging Face model ID.
        try:
            self.embedding_model = EmbeddingModel(model_name="BAAI/bge-m3")
            logger.info("成功加载嵌入模型")
        except Exception as e:
            logger.error(f"加载嵌入模型失败: {str(e)}")
            raise

        # Reranker, referenced by its Hugging Face model ID.
        try:
            self.reranker = Reranker(model_path="BAAI/bge-reranker-large")
            logger.info("成功加载重排序模型")
        except Exception as e:
            logger.error(f"加载重排序模型失败: {str(e)}")
            raise

        self.ranking_system = RankingSystem(self.embedding_model, self.reranker)
        self.plan_generator = PlanGenerator(self.llm)

    def init_llm(self, provider: str, model: str):
        """Build an LLM client for ``provider`` using ``model``.

        Raises:
            ValueError: for an unsupported provider, or when the deepseek
                provider entry / its ``base_url`` is missing from the config
                (previously this surfaced as a bare StopIteration).
        """
        if provider == "openai":
            return OpenAIInterface(
                api_key=self.config['openai_api_key'],
                model=model
            )
        if provider == "deepseek":
            entry = self._provider_config('deepseek')
            if not entry or 'base_url' not in entry:
                raise ValueError("未找到提供商 deepseek 的 base_url 配置")
            return DeepseekInterface(
                api_key=self.config['deepseek_api_key'],
                base_url=entry['base_url'],
                model=model
            )
        raise ValueError(f"不支持的LLM提供商: {provider}")

    def init_llm_instances(self):
        """Create one LLM client per enabled provider.

        Each client is stored in ``self.llm_instances`` under the provider's
        own name. A provider that fails to initialise (including an
        unsupported provider name) is logged and skipped, so one bad entry
        does not prevent the others from loading.
        """
        for provider in self.config['llm_settings']['providers']:
            if not provider.get('enabled', False):
                continue
            name = provider.get('name')
            try:
                # Delegate to init_llm so construction logic lives in one place
                # and the instance is keyed by its real provider name.
                self.llm_instances[name] = self.init_llm(name, provider['model'])
                logger.info(f"成功初始化 {name} LLM")
            except Exception as e:
                logger.error(f"初始化 {name} LLM 失败: {str(e)}")

    def get_llm(self, provider_name: str = None) -> "LLMInterface":
        """Return the LLM client for ``provider_name`` (default provider if falsy).

        Raises:
            ValueError: if no client was initialised for that provider.
        """
        if not provider_name:
            provider_name = self.config['llm_settings']['default_provider']
        if provider_name not in self.llm_instances:
            raise ValueError(f"未找到或未启用的LLM提供商: {provider_name}")
        return self.llm_instances[provider_name]

    def _resolve_llm(self, llm_provider, llm_model=None):
        """Pick the LLM client for one request.

        Falls back to the default provider when ``llm_provider`` has no
        initialised client. When ``llm_model`` differs from the provider's
        configured model, a client for that model is created through
        ``init_llm`` and cached per (provider, model).
        """
        if llm_provider not in self.llm_instances:
            logger.warning(f"LLM提供商 {llm_provider} 未启用或不可用,将使用默认提供商")
            # get_llm() raises ValueError if the default is unavailable too.
            return self.get_llm()

        base_llm = self.llm_instances[llm_provider]
        entry = self._provider_config(llm_provider) or {}
        if not llm_model or llm_model == entry.get('model'):
            return base_llm

        cache_key = (llm_provider, llm_model)
        if cache_key not in self._model_llm_cache:
            self._model_llm_cache[cache_key] = self.init_llm(llm_provider, llm_model)
        return self._model_llm_cache[cache_key]

    def _rank_passages(self, query, passages):
        """Rank ``passages`` for ``query`` and return the final ordering.

        Uses ``self.retriever.retrieve`` when a retriever exposing that method
        has been set; otherwise runs the ranking system's two stages
        (``initial_ranking`` then ``rerank``).
        """
        retriever = getattr(self, 'retriever', None)
        # set_retrieval_method("standard") assigns the ranking system itself as
        # the retriever; only call .retrieve when it really exists.
        if retriever is not None and hasattr(retriever, 'retrieve'):
            return retriever.retrieve(query, passages)
        initial_ranked = self.ranking_system.initial_ranking(query, passages)
        return self.ranking_system.rerank(query, initial_ranked)

    @classmethod
    def _format_sources_table(cls, ranked_docs):
        """Render ranked documents as a Markdown table of links and scores."""
        from urllib.parse import urlparse

        table_header = "| Reference URL | Relevance Score | Retrieval Score | Rerank Score |\n| --- | --- | --- | --- |"
        table_rows = []
        for doc in ranked_docs:
            url = (doc.get('url') or '').replace('|', '%7C')
            # Collapse whitespace so a multi-line title cannot break the row.
            title = " ".join((doc.get('title') or '').split())
            if not title:
                # No title: show the domain (or the raw URL) as the link text.
                title = urlparse(url).netloc or url
            # Escape characters that would terminate the cell or the link text.
            for ch in ('|', '[', ']'):
                title = title.replace(ch, '\\' + ch)
            table_rows.append(
                f"| [{title}]({url}) | "
                f"{doc.get('final_score', 0):.3f} | "
                f"{doc.get('retrieval_score', 0):.3f} | "
                f"{doc.get('rerank_score', 0):.3f} |"
            )
        return table_header + "\n" + "\n".join(table_rows)

    @staticmethod
    def _build_gallery_html(image_urls):
        """Return the stacked image gallery HTML, or "" when there are no images."""
        from html import escape

        if not image_urls:
            return ""
        parts = ["""
        <div style="display: flex; flex-direction: column; gap: 15px;">
        """]
        for img_url in image_urls:
            # URLs come from search results: escape them for attribute context.
            safe_url = escape(img_url, quote=True)
            parts.append(f"""
            <div style="border-radius: 8px; overflow: hidden;
            box-shadow: 0 4px 6px rgba(0, 0, 0, 0.3);">
            <div style="position: relative; padding-top: 66.67%;">
            <img src="{safe_url}"
            alt="旅行相关图片"
            style="position: absolute; top: 0; left: 0; width: 100%;
            height: 100%; object-fit: cover; transition: transform 0.3s;"
            onerror="this.style.display='none'">
            </div>
            </div>
            """)
        parts.append("</div>")
        return "".join(parts)

    def _collect_image_html(self, query, wanted=3, candidates=8):
        """Search images for ``query`` and render up to ``wanted`` valid ones.

        More candidates than needed are requested because many are filtered
        out by ``verify_image_url``. Any failure is logged and yields "".
        """
        try:
            images = self.search_engine.search_images(query, count=candidates)
            valid_images = []
            for img in images or []:
                img_url = img.get('url', '')
                if img_url and self.verify_image_url(img_url):
                    valid_images.append(img_url)
                    if len(valid_images) >= wanted:
                        break
            return self._build_gallery_html(valid_images)
        except Exception as e:
            logger.warning(f"获取图片时出现错误: {str(e)}")
            return ""

    @classmethod
    def _format_plan_html(cls, plan_text):
        """Convert the generated plan text into styled HTML, one block per line.

        Markdown markers are stripped, blank lines dropped and every remaining
        line is classified (section heading, overview row, labelled field,
        time-slot heading, plain paragraph). All text is HTML-escaped because
        it originates from the LLM / web content.
        """
        from html import escape

        text = plan_text or ""
        for marker in cls._MARKDOWN_MARKERS:
            text = text.replace(marker, '')

        field_labels = tuple(cls._FIELD_ICONS)
        blocks = []
        for raw_line in text.split('\n'):
            p = raw_line.strip()
            if not p:
                continue
            has_colon = ":" in p
            safe = escape(p, quote=False)

            if "Tour Overview" in p:
                # Main section heading.
                blocks.append(
                    f'<h2 style="color: #f3f4f6; margin: 10px 0 12px 0; font-size: 1.15em; '
                    f'font-weight: 600; letter-spacing: 0.01em; line-height: 1.3; '
                    f'border-bottom: 1px solid rgba(99, 102, 241, 0.3); padding-bottom: 8px; '
                    f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif;">'
                    f'📍 {safe}</h2>'
                )
            elif has_colon and any(x in p for x in cls._SUMMARY_KEYS):
                # Overview row: label on the left, emphasised value on the right.
                key, value = p.split(":", 1)
                blocks.append(
                    f'<div style="display: flex; align-items: start; margin: 4px 0; '
                    f'padding-left: 4px;">'
                    f'<span style="color: #818cf8; font-weight: 500; min-width: 70px; '
                    f'font-size: 0.92em;">{escape(key, quote=False)}:</span>'
                    f'<span style="flex: 1; line-height: 1.4; color: #e2e8f0; margin-left: 8px; '
                    f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif; '
                    f'font-size: 0.92em; font-weight: 600; letter-spacing: 0.005em;">{escape(value.strip(), quote=False)}</span>'
                    f'</div>'
                )
            elif "Daily Itinerary" in p:
                # Main section heading.
                blocks.append(
                    f'<h2 style="color: #f3f4f6; margin: 20px 0 12px 0; font-size: 1.15em; '
                    f'font-weight: 600; letter-spacing: 0.01em; line-height: 1.3; '
                    f'border-bottom: 1px solid rgba(99, 102, 241, 0.3); padding-bottom: 8px; '
                    f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif;">'
                    f'🕒️ {safe}</h2>'
                )
            elif has_colon and p.startswith(field_labels):
                # Labelled field. Checked before the " - " rule so a value such
                # as "MTR - 20 min" is not mistaken for a time-slot heading, and
                # only when a colon exists so the split below cannot fail.
                key, value = p.split(":", 1)
                icon = cls._FIELD_ICONS.get(key.strip(), "•")
                blocks.append(
                    f'<div style="display: flex; align-items: start; margin: 4px 0; padding-left: 4px;">'
                    f'<span style="color: #818cf8; margin-right: 8px;">{icon}</span>'
                    f'<span style="flex: 1; line-height: 1.4; color: #e2e8f0; '
                    f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif; '
                    f'font-size: 0.92em; font-weight: 400; letter-spacing: 0.005em;">{escape(value.strip(), quote=False)}</span>'
                    f'</div>'
                )
            elif " - " in p:
                # Time-slot sub-heading.
                blocks.append(
                    f'<h3 style="color: #e2e8f0; margin: 14px 0 6px 0; font-size: 1.05em; '
                    f'font-weight: 600; letter-spacing: 0.01em; line-height: 1.4; padding-bottom: 4px; '
                    f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif;">'
                    f'🕒 {safe}</h3>'
                )
            else:
                # Plain paragraph.
                blocks.append(
                    f'<p style="margin: 8px 0; line-height: 1.5; color: #e2e8f0; '
                    f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif; '
                    f'font-size: 0.95em; font-weight: 400; letter-spacing: 0.005em;">{safe}</p>'
                )
        return '\n'.join(blocks)

    def process_query(
        self,
        query: str,
        days: int,
        llm_provider: str,
        llm_model: str,
        enable_images: bool = True,
        retrieval_method: str = None
    ) -> tuple:
        """Generate a travel plan for ``query``.

        Args:
            query: Free-text travel requirements.
            days: Trip length; when positive it is appended to the search query.
            llm_provider: Provider to use; the default provider is used when
                this one has no initialised client.
            llm_model: Model to use with that provider; falsy or equal to the
                configured model means the pre-built client is used.
            enable_images: Whether to search for and render reference images.
            retrieval_method: Optional retrieval method to switch to first.

        Returns:
            ``(plan_html, sources_markdown, images_html)``. On any failure the
            first element is an error message and the other two are "".
        """
        from html import escape

        try:
            if not query or not query.strip():
                return "Please enter your travel requirements.", "", ""

            # Switch retrieval method only when a different one is requested.
            if retrieval_method and retrieval_method != self.retrieval_method:
                self.set_retrieval_method(retrieval_method)

            # Per-request helpers are kept local rather than stored on self, so
            # concurrent requests cannot overwrite each other's LLM choice.
            current_llm = self._resolve_llm(llm_provider, llm_model)
            doc_processor = DocumentProcessor(current_llm)
            plan_generator = PlanGenerator(current_llm)

            # Make sure the query mentions the trip length.
            if days and days > 0:
                query = f"{query} {days:g} days"

            logger.info(f"执行搜索: {query}")
            search_results = self.search_engine.search(query)
            logger.info(f"搜索结果: {search_results}")

            passages = doc_processor.process_documents(search_results)
            logger.info(f"处理后的文档: {passages}")

            final_ranked = self._rank_passages(query, passages)

            final_plan = plan_generator.generate_plan(query, final_ranked)
            logger.info(f"生成的计划: {final_plan}")

            sources = self._format_sources_table(final_ranked)
            logger.info(f"参考来源: {sources}")

            image_html = self._collect_image_html(query) if enable_images else ""

            plan_content = self._format_plan_html(final_plan['plan'])
            # Wrap everything in a dark-themed container.
            final_output = f"""
            <div style="max-width: 100%; padding: 24px; background: rgba(17, 24, 39, 0.7);
            border-radius: 16px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2);">
            <div style="margin-top: 20px;">
            {plan_content}
            </div>
            </div>
            """
            return final_output, sources, image_html
        except Exception as e:
            logger.exception(f"Error processing query: {str(e)}")
            # The message is rendered in an HTML component, so escape it.
            return (
                f"Sorry, an error occurred while processing your request: {escape(str(e))}",
                "",
                "",
            )

    def verify_image_url(self, url: str) -> bool:
        """Download ``url`` and decide whether it is a usable illustration.

        Returns False when the download fails, the response is not an image,
        or any heuristic rejects it (too small, extreme aspect ratio, too
        uniform, too edge-dense, too saturated). Any exception, including a
        missing optional dependency, is logged and treated as "not valid".
        """
        try:
            # Imported lazily inside the try so a missing dependency only
            # disables image validation instead of breaking module import.
            import io
            import numpy as np
            import requests
            from PIL import Image, ImageFilter

            # NOTE(review): the URL comes from search results and is fetched
            # server-side; only plain http(s) URLs are attempted.
            if not url.lower().startswith(('http://', 'https://')):
                return False

            response = requests.get(url, timeout=3)
            if response.status_code != 200:
                return False
            content_type = response.headers.get('content-type', '')
            if 'image' not in content_type.lower():
                return False

            img = Image.open(io.BytesIO(response.content))

            # 1. Reject images that are too small.
            width, height = img.size
            if width < 300 or height < 300:
                return False

            # 2. Reject extreme aspect ratios.
            aspect_ratio = width / height
            if aspect_ratio < 0.5 or aspect_ratio > 2.0:
                return False

            # 3. Reject overly uniform images (multi-channel arrays only).
            img_array = np.array(img)
            is_multichannel = len(img_array.shape) == 3
            if is_multichannel and np.std(img_array) < 30:
                return False

            # 4. Reject images with a high edge density (likely text-heavy).
            img_gray = img if img.mode == 'L' else img.convert('L')
            edges = img_gray.filter(ImageFilter.FIND_EDGES)
            if np.mean(np.array(edges)) > 30:
                return False

            # 5. Reject over-saturated images (likely advertisements).
            if is_multichannel:
                saturation = np.array(img.convert('HSV'))[:, :, 1]
                if np.mean(saturation) > 200:
                    return False
            return True
        except Exception as e:
            logger.warning(f"图片验证失败: {str(e)}")
            return False

    def _format_images_html(self, images: List[str]) -> str:
        """Render image URLs as a wrapping flex gallery; "" when ``images`` is empty.

        Each image falls back to a placeholder picture if it fails to load.
        """
        from html import escape

        if not images:
            return ""
        parts = ["""
        <div style="display: flex; flex-wrap: wrap; gap: 10px; justify-content: center; margin-top: 20px;">
        """]
        for img_url in images:
            parts.append(f"""
            <div style="flex: 0 0 calc(50% - 10px); max-width: 300px; min-width: 200px;">
            <img
            src="{escape(img_url, quote=True)}"
            style="width: 100%; height: 200px; object-fit: cover; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);"
            onerror="this.onerror=null; this.src='https://via.placeholder.com/300x200?text=Image+Not+Available';"
            />
            </div>
            """)
        parts.append("</div>")
        markup = "".join(parts)
        # Debug aid: log only the first 200 characters.
        logger.info(f"生成的图片HTML: {markup[:200]}...")
        return markup

    def set_retrieval_method(self, method: str):
        """Switch the retrieval method; only "standard" is supported.

        Raises:
            ValueError: for any other method name.
        """
        if method not in ["standard"]:
            raise ValueError(f"不支持的检索方法: {method}")
        self.retrieval_method = method
        if method == "standard":
            self.retriever = self.ranking_system
def create_interface():
    """Build the Gradio UI for the travel planner and return the Blocks app.

    Instantiates ``TravelRAGSystem``, derives the provider / model dropdown
    choices from the enabled entries of ``llm_settings.providers``, lays out
    the inputs (left column) and outputs (right column tabs), and wires the
    submit button to: show a loading message -> run ``process_query`` ->
    clear the loading message.
    """
    system = TravelRAGSystem()
    llm_settings = system.config['llm_settings']
    default_provider = llm_settings['default_provider']

    # Providers flagged as enabled; a missing flag counts as disabled, matching
    # how TravelRAGSystem.init_llm_instances reads the same entries.
    enabled = [p for p in llm_settings['providers'] if p.get('enabled', False)]
    enabled_providers = [p['name'] for p in enabled]

    def models_of(provider_entry):
        # Prefer the explicit `models` list; otherwise offer the single
        # configured `model`, if any.
        models = provider_entry.get('models')
        if models:
            return models
        return [provider_entry['model']] if 'model' in provider_entry else []

    # Provider name -> selectable model names.
    provider_models = {p['name']: models_of(p) for p in enabled}
    default_models = provider_models.get(default_provider, [])

    # Custom CSS for the whole interface.
    css = """
    .gradio-container {
    font-family: "PingFang SC", "Microsoft YaHei", sans-serif;
    }
    /* 针对所有英文文本 */
    [class*="message-"] {
    font-family: 'Times New Roman', serif !important;
    }
    /* 确保英文和数字使用 Times New Roman */
    .gradio-container *:not(:lang(zh)) {
    font-family: 'Times New Roman', serif !important;
    }
    @keyframes spin {
    0% { transform: rotate(0deg); }
    100% { transform: rotate(360deg); }
    }
    /* 隐藏数字输入框的上下箭头 */
    input[type="number"]::-webkit-inner-spin-button,
    input[type="number"]::-webkit-outer-spin-button {
    -webkit-appearance: none;
    margin: 0;
    }
    input[type="number"] {
    -moz-appearance: textfield;
    }
    /* 隐藏默认的 processing 信息和箭头 */
    .progress-text, .meta-text-center, .progress-container {
    display: none !important;
    }
    /* 修改加载动画样式 */
    .loading {
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 8px;
    font-size: 1.2em;
    color: rgb(192, 192, 255);
    }
    .loading::before {
    content: '🌍';
    display: inline-block;
    animation: spin 2s linear infinite;
    filter: brightness(1.5); /* 让地球图标更亮 */
    }
    /* 调整 Gradio 默认加载动画的位置 */
    .progress-text {
    display: block !important;
    order: 3;
    margin-top: 8px;
    opacity: 0.7;
    }
    .meta-text-center {
    display: block !important;
    }
    /* 确保加载容器使用 flex 布局 */
    .loading-container {
    display: flex;
    flex-direction: column;
    align-items: center;
    }
    /* 隐藏滑块右侧的上下箭头 */
    .num-input-plus, .num-input-minus {
    display: none !important;
    }
    /* 隐藏所有滚动箭头 */
    .scroll-hide,
    .output-markdown,
    .output-text,
    .markdown-text,
    .prose,
    .gr-box,
    .gr-panel {
    -ms-overflow-style: none !important;
    scrollbar-width: none !important;
    overflow-y: hidden !important;
    overflow: hidden !important;
    }
    .scroll-hide::-webkit-scrollbar,
    .output-markdown::-webkit-scrollbar,
    .output-text::-webkit-scrollbar,
    .markdown-text::-webkit-scrollbar,
    .prose::-webkit-scrollbar,
    .gr-box::-webkit-scrollbar,
    .gr-panel::-webkit-scrollbar {
    display: none !important;
    width: 0 !important;
    height: 0 !important;
    }
    /* 修改加载动画容器样式 */
    .loading-container {
    overflow: hidden !important;
    min-height: 60px;
    }
    /* 隐藏 Gradio 默认的滚动控件 */
    .wrap.svelte-byatnx,
    .contain.svelte-byatnx,
    [class*='svelte'],
    .gradio-container {
    overflow: hidden !important;
    overflow-y: hidden !important;
    }
    /* 禁用所有可能的滚动控件 */
    ::-webkit-scrollbar {
    display: none !important;
    width: 0 !important;
    height: 0 !important;
    }
    /* 移除 Group 组件的默认背景 */
    .custom-group {
    border: none !important;
    background: none !important;
    box-shadow: none !important;
    }
    .custom-group > div {
    border: none !important;
    background: none !important;
    box-shadow: none !important;
    }
    /* 添加图片容器样式 */
    .images-container {
    margin-top: 20px;
    padding: 10px;
    background: #fff;
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    .images-container img {
    transition: transform 0.3s ease;
    }
    .images-container img:hover {
    transform: scale(1.05);
    }
    /* 确保图片容器可见 */
    #component-13 {
    min-height: 200px;
    overflow: visible !important;
    }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=css) as interface:
        gr.Markdown(
            "\n"
            "# 🌟 Tourism Planning Assistant 🌟\n"
            "Welcome to the Smart Travel Planning Assistant! Simply input your travel requirements, and we'll generate a personalized travel plan for you.\n"
            "### Instructions\n"
            "1. Describe your travel needs in the input box (e.g., 'One-day trip to Hong Kong Disneyland')\n"
            "2. Select the number of days for your plan\n"
            "3. Click the \"Generate Plan\" button\n"
        )
        with gr.Row():
            # Left column: all inputs, the loading indicator and the images.
            with gr.Column(scale=4):
                llm_provider = gr.Dropdown(
                    choices=enabled_providers,
                    value=default_provider,
                    label="Select LLM Provider"
                )
                llm_model = gr.Dropdown(
                    choices=default_models,
                    value=default_models[0] if default_models else None,
                    label="Select Model"
                )

                def update_model_choices(provider):
                    # Replace the model list and reset the selection, so a
                    # model of the previous provider cannot stay selected.
                    models = provider_models.get(provider, [])
                    return gr.Dropdown(
                        choices=models,
                        value=models[0] if models else None
                    )

                # Refresh the model dropdown whenever the provider changes.
                llm_provider.change(
                    fn=update_model_choices,
                    inputs=[llm_provider],
                    outputs=[llm_model]
                )
                query_input = gr.Textbox(
                    label="Travel Requirements",
                    placeholder="Please enter your travel requirements, e.g.: One-day trip to Hong Kong Disneyland",
                    lines=2
                )
                days_input = gr.Slider(
                    minimum=1,
                    maximum=7,
                    value=1,
                    step=1,
                    label="Number of Days"
                )
                # Checkbox controlling whether images are searched and shown.
                show_images = gr.Checkbox(
                    label="Search Related Images",
                    value=True,
                    info="Whether to search and display related reference images"
                )
                # Only "standard" retrieval remains, so the selector is hidden.
                retrieval_method = gr.Radio(
                    choices=["standard"],
                    value="standard",
                    label="Retrieval Method",
                    info="Choose different retrieval strategies",
                    visible=False
                )
                submit_btn = gr.Button("Generate Plan", variant="primary")
                loading_status = gr.Markdown("", elem_id="loading_status", show_label=False)
                # Image display area, placed in the left column.
                images_container = gr.HTML(
                    value="",
                    visible=True,
                    label="Related Images"
                )
                # Toggling the checkbox resets the image area: empty when
                # disabled, an empty placeholder div when enabled.
                show_images.change(
                    fn=lambda x: "" if not x else "<div></div>",
                    inputs=[show_images],
                    outputs=[images_container]
                )
            # Right column: generated plan and references in separate tabs.
            with gr.Column(scale=6):
                with gr.Tabs():
                    with gr.TabItem("Travel Plan"):
                        plan_output = gr.HTML(label="Generated Travel Plan", show_label=False)
                    with gr.TabItem("References and Evaluation"):
                        sources_output = gr.Markdown(label="References and Evaluation", show_label=False)

        # NOTE(review): the nesting level of the examples was reconstructed;
        # they are placed below the two-column row - confirm against the
        # intended layout.
        gr.Examples(
            examples=[
                ["One-day trip to Hong Kong Disneyland", 1],
                ["Family trip to Hong Kong Ocean Park", 1],
                ["Hong Kong Shopping and Food Tour", 2],
                ["Hong Kong Cultural Experience Tour", 3]
            ],
            inputs=[query_input, days_input],
            label="Example Queries"
        )

        def show_loading():
            # Show the loading message in both the status line and the plan
            # panel, and clear the previous sources and images.
            loading_html = "<div class='loading-container'><p class='loading'>Generating your personalized travel plan...</p></div>"
            return loading_html, loading_html, "", ""

        def process_with_images(query, days, llm_provider, llm_model, enable_images, retrieval_method):
            plan_html, sources_md, images_html = system.process_query(
                query, days, llm_provider, llm_model,
                enable_images, retrieval_method
            )
            # Debug aid: size of the generated image HTML.
            logger.info(f"图片HTML长度: {len(images_html) if images_html else 0}")
            return plan_html, sources_md, images_html

        # Submit: loading state -> generate plan -> clear loading state.
        submit_btn.click(
            fn=show_loading,
            inputs=None,
            outputs=[loading_status, plan_output, sources_output, images_container]
        ).then(
            fn=process_with_images,
            inputs=[
                query_input,
                days_input,
                llm_provider,
                llm_model,
                show_images,
                retrieval_method
            ],
            # Order must match the tuple returned by process_with_images.
            outputs=[plan_output, sources_output, images_container]
        ).then(
            fn=lambda: "",
            inputs=None,
            outputs=[loading_status]
        )

        gr.Markdown(
            "\n"
            "### 📝 Notes\n"
            "- Plan generation may take some time, please be patient\n"
            "- Queries should include specific locations and activity preferences\n"
            "- All plans are AI-generated, please adjust according to actual circumstances\n"
            "Powered by RAG for Tourism system © 2024\n"
        )
    return interface
if __name__ == "__main__":
    demo = create_interface()
    # Bind address and port are read from Gradio's standard environment
    # variables, falling back to the values that were previously hard-coded
    # (all interfaces, port 7860).
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
        share=False,  # Hugging Face Spaces already provides public access
        debug=False,
        ssr_mode=False
    )