Spaces:
Running
Running
import gradio as gr | |
import sys | |
import os | |
from pathlib import Path | |
from typing import List | |
# 添加项目根目录到 Python 路径 | |
sys.path.append(str(Path(__file__).parent)) | |
from src.api.search_api import BochaSearch | |
from src.core.document_processor import DocumentProcessor | |
from src.core.ranking import RankingSystem | |
from src.core.plan_generator import PlanGenerator | |
from src.core.embeddings import EmbeddingModel | |
from src.core.reranker import Reranker | |
from src.api.llm_api import DeepseekInterface, LLMInterface, OpenAIInterface | |
from src.utils.helpers import load_config | |
import logging | |
# Logging setup: configure the root logger at INFO once at import time;
# code in this module logs through the module-level `logger`.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class TravelRAGSystem:
    """Travel-planning RAG pipeline: web search, passage ranking, LLM plan generation.

    ``create_interface`` builds one instance and reuses it for every request,
    so ``process_query`` keeps per-request objects in local variables instead
    of reassigning attributes on ``self``.
    """

    # Shared font stack used by every styled element of the rendered plan.
    _FONT_STACK = (
        "-apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif"
    )
    # Markdown markers stripped from the LLM plan, applied in this order
    # ('###' before '##' before '# ').
    _PLAN_MARKUP = ('###', '##', '# ', '**')
    # Line prefix -> icon for per-activity detail lines.
    _INFO_ICONS = {
        "Location": "📍",
        "Activity": "🎯",
        "Transportation": "🚇",
        "Specific Guidance": "🗺️",
    }

    def __init__(self):
        self.config = load_config("config/config.yaml")
        # provider name -> LLM instance, for every provider enabled in the config
        self.llm_instances = {}
        # (provider, model) -> LLM instance for UI-selected models that differ
        # from the provider's configured default model
        self._model_llm_cache = {}
        # Only the "standard" retrieval method is supported
        self.retrieval_method = "standard"
        self.init_llm_instances()
        self.init_components()

    def init_components(self):
        """Create the search engine, ranking models and the default-LLM pipeline.

        Raises:
            ValueError: if ``default_provider`` has no entry in ``providers``.
            Exception: re-raised when the embedding or reranker model fails to load.
        """
        llm_settings = self.config['llm_settings']
        default_provider = llm_settings['default_provider']
        provider_config = next(
            (p for p in llm_settings['providers'] if p['name'] == default_provider),
            None,
        )
        if not provider_config:
            raise ValueError(f"未找到默认提供商 {default_provider} 的配置")

        self.llm = self.init_llm(provider_config['name'], provider_config['model'])

        self.search_engine = BochaSearch(
            api_key=self.config['bocha_api_key'],
            base_url=self.config['bocha_base_url'],
        )
        self.doc_processor = DocumentProcessor(self.llm)

        # Embedding model, loaded by Hugging Face model ID
        try:
            self.embedding_model = EmbeddingModel(model_name="BAAI/bge-m3")
            logger.info("成功加载嵌入模型")
        except Exception as e:
            logger.error(f"加载嵌入模型失败: {str(e)}")
            raise

        # Reranker, loaded by Hugging Face model ID
        try:
            self.reranker = Reranker(model_path="BAAI/bge-reranker-large")
            logger.info("成功加载重排序模型")
        except Exception as e:
            logger.error(f"加载重排序模型失败: {str(e)}")
            raise

        self.ranking_system = RankingSystem(self.embedding_model, self.reranker)
        self.plan_generator = PlanGenerator(self.llm)

    def init_llm(self, provider: str, model: str):
        """Construct an LLM client for *provider* using *model*.

        Args:
            provider: ``"openai"`` or ``"deepseek"``.
            model: model name forwarded to the client.

        Raises:
            ValueError: for an unsupported provider, or for ``"deepseek"`` when
                no deepseek entry with a ``base_url`` exists in the config.
        """
        if provider == "openai":
            return OpenAIInterface(
                api_key=self.config['openai_api_key'],
                model=model,
            )
        if provider == "deepseek":
            base_url = next(
                (p.get('base_url') for p in self.config['llm_settings']['providers']
                 if p['name'] == 'deepseek'),
                None,
            )
            if not base_url:
                raise ValueError("未找到 deepseek 提供商的 base_url 配置")
            return DeepseekInterface(
                api_key=self.config['deepseek_api_key'],
                base_url=base_url,
                model=model,
            )
        raise ValueError(f"不支持的LLM提供商: {provider}")

    def init_llm_instances(self):
        """Create one LLM instance per provider whose ``enabled`` flag is true.

        Instances are stored under their own provider name. A provider that
        fails to initialise, or that has an unsupported name, is logged and
        skipped, so the others still load.
        """
        for provider in self.config['llm_settings']['providers']:
            if not provider.get('enabled', False):
                continue
            name = provider['name']
            try:
                if name == "openai":
                    instance = OpenAIInterface(
                        api_key=self.config['openai_api_key'],
                        model=provider['model'],
                    )
                elif name == "deepseek":
                    instance = DeepseekInterface(
                        api_key=self.config['deepseek_api_key'],
                        base_url=provider['base_url'],
                        model=provider['model'],
                    )
                else:
                    # Previously any other name was silently stored as 'deepseek'.
                    logger.warning(f"不支持的LLM提供商: {name}")
                    continue
                self.llm_instances[name] = instance
                logger.info(f"成功初始化 {name} LLM")
            except Exception as e:
                logger.error(f"初始化 {name} LLM 失败: {str(e)}")

    def get_llm(self, provider_name: str = None) -> "LLMInterface":
        """Return the enabled LLM for *provider_name* (default provider when falsy).

        Raises:
            ValueError: if that provider is not among the enabled instances.
        """
        if not provider_name:
            provider_name = self.config['llm_settings']['default_provider']
        if provider_name not in self.llm_instances:
            raise ValueError(f"未找到或未启用的LLM提供商: {provider_name}")
        return self.llm_instances[provider_name]

    def _resolve_llm(self, llm_provider: str, llm_model: str = None):
        """Return the LLM for *llm_provider*, honouring the model picked in the UI.

        A falsy *llm_model*, or one equal to the provider's configured
        ``model``, reuses the instance from ``init_llm_instances``. Any other
        model gets its own instance, built once and cached.

        Raises:
            ValueError: if *llm_provider* is not enabled.
        """
        if llm_provider not in self.llm_instances:
            raise ValueError(f"LLM提供商 {llm_provider} 未启用或不可用")
        configured_model = next(
            (p.get('model') for p in self.config['llm_settings']['providers']
             if p.get('name') == llm_provider),
            None,
        )
        if not llm_model or llm_model == configured_model:
            return self.llm_instances[llm_provider]
        cache_key = (llm_provider, llm_model)
        if cache_key not in self._model_llm_cache:
            self._model_llm_cache[cache_key] = self.init_llm(llm_provider, llm_model)
        return self._model_llm_cache[cache_key]

    def process_query(
        self,
        query: str,
        days: int,
        llm_provider: str,
        llm_model: str,
        enable_images: bool = True,
        retrieval_method: str = None,
    ) -> tuple:
        """Run search -> document processing -> ranking -> plan generation.

        Args:
            query: free-text travel request.
            days: trip length; when > 0, ``" {days} days"`` is appended to the query.
            llm_provider: name of an enabled provider.
            llm_model: model selected in the UI (falsy = provider default).
            enable_images: whether to search for and verify reference images.
            retrieval_method: optional retrieval method to switch to.

        Returns:
            ``(plan_html, sources_markdown, images_html)``. On any error the
            first element is an error message and the other two are ``""``.
        """
        try:
            if retrieval_method and retrieval_method != self.retrieval_method:
                self.set_retrieval_method(retrieval_method)

            current_llm = self._resolve_llm(llm_provider, llm_model)
            # Local, per-request pipeline objects: assigning them to self would
            # let concurrent requests on the shared instance swap LLMs.
            doc_processor = DocumentProcessor(current_llm)
            plan_generator = PlanGenerator(current_llm)

            # Make sure the query mentions the trip length
            if days > 0:
                query = f"{query} {days} days"

            logger.info(f"执行搜索: {query}")
            search_results = self.search_engine.search(query)
            logger.info(f"搜索结果: {search_results}")

            passages = doc_processor.process_documents(search_results)
            logger.info(f"处理后的文档: {passages}")

            if hasattr(self, 'retriever'):
                # NOTE(review): the retriever is set to self.ranking_system in
                # set_retrieval_method; confirm RankingSystem provides retrieve().
                final_ranked = self.retriever.retrieve(query, passages)
            else:
                initial_ranked = self.ranking_system.initial_ranking(query, passages)
                final_ranked = self.ranking_system.rerank(query, initial_ranked)

            final_plan = plan_generator.generate_plan(query, final_ranked)
            logger.info(f"生成的计划: {final_plan}")

            sources = self._build_sources_table(final_ranked)
            logger.info(f"参考来源: {sources}")

            image_html = self._build_images_html(query) if enable_images else ""
            return self._format_plan_html(final_plan['plan']), sources, image_html
        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            return f"Sorry, an error occurred while processing your request: {str(e)}", "", ""

    def _build_sources_table(self, ranked_docs) -> str:
        """Render ranked passages as a Markdown table of links and scores.

        A missing title falls back to the URL's domain. Missing or ``None``
        scores render as ``0.000``. Characters that would break the table or
        link syntax are escaped or percent-encoded.
        """
        from urllib.parse import urlparse

        def score(doc, key):
            value = doc.get(key)
            return 0.0 if value is None else float(value)

        rows = [
            "| Reference URL | Relevance Score | Retrieval Score | Rerank Score |",
            "| --- | --- | --- | --- |",
        ]
        for doc in ranked_docs:
            url = doc.get('url') or ''
            title = (doc.get('title') or '').strip() or urlparse(url).netloc
            safe_title = title.replace('|', '\\|').replace('[', '\\[').replace(']', '\\]')
            safe_url = (
                url.replace(' ', '%20').replace('|', '%7C')
                .replace('(', '%28').replace(')', '%29')
            )
            rows.append(
                f"| [{safe_title}]({safe_url}) | "
                f"{score(doc, 'final_score'):.3f} | "
                f"{score(doc, 'retrieval_score'):.3f} | "
                f"{score(doc, 'rerank_score'):.3f} |"
            )
        return "\n".join(rows)

    def _build_images_html(self, query: str) -> str:
        """Return a column of up to 3 verified images for *query*, or ``""``.

        Image-search errors are logged and yield ``""``, so a failing image
        search never breaks plan generation.
        """
        from html import escape

        try:
            # Over-fetch because verify_image_url rejects many results
            images = self.search_engine.search_images(query, count=8)
            valid_images = []
            for img in images or []:
                img_url = img.get('url', '')
                if img_url and self.verify_image_url(img_url):
                    valid_images.append(img_url)
                    if len(valid_images) >= 3:  # only 3 valid images are needed
                        break
        except Exception as e:
            logger.warning(f"获取图片时出现错误: {str(e)}")
            return ""

        if not valid_images:
            return ""
        # URLs come from an external search API: escape them for the attribute.
        cards = "".join(
            f"""
<div style="border-radius: 8px; overflow: hidden;
            box-shadow: 0 4px 6px rgba(0, 0, 0, 0.3);">
    <div style="position: relative; padding-top: 66.67%;">
        <img src="{escape(url, quote=True)}"
             alt="旅行相关图片"
             style="position: absolute; top: 0; left: 0; width: 100%;
                    height: 100%; object-fit: cover; transition: transform 0.3s;"
             onerror="this.style.display='none'">
    </div>
</div>
"""
            for url in valid_images
        )
        return (
            '\n<div style="display: flex; flex-direction: column; gap: 15px;">\n'
            + cards
            + "</div>"
        )

    def _format_plan_html(self, plan_text: str) -> str:
        """Convert the LLM's plain-text plan into styled, HTML-escaped markup.

        Markdown markers are stripped. Each non-empty line is classified as a
        section heading, key/value line, time-slot heading, detail line or
        paragraph. All LLM text is escaped before it is embedded.
        """
        from html import escape

        plan_text = plan_text or ""
        for marker in self._PLAN_MARKUP:
            plan_text = plan_text.replace(marker, '')

        font = self._FONT_STACK

        def section_heading(top_margin, icon, text):
            return (
                f'<h2 style="color: #f3f4f6; margin: {top_margin} 0 12px 0; font-size: 1.15em; '
                'font-weight: 600; letter-spacing: 0.01em; line-height: 1.3; '
                'border-bottom: 1px solid rgba(99, 102, 241, 0.3); padding-bottom: 8px; '
                f'font-family: {font};">'
                f'{icon} {text}</h2>'
            )

        parts = []
        for raw_line in plan_text.split('\n'):
            line = raw_line.strip()
            if not line:
                continue
            safe = escape(line)
            if "Tour Overview" in line:
                parts.append(section_heading("10px", "📍", safe))
            elif ":" in line and any(x in line for x in ("Date", "Destination", "Key Attractions")):
                # Key information: bold value
                key, value = line.split(":", 1)
                parts.append(
                    '<div style="display: flex; align-items: start; margin: 4px 0; '
                    'padding-left: 4px;">'
                    '<span style="color: #818cf8; font-weight: 500; min-width: 70px; '
                    f'font-size: 0.92em;">{escape(key)}:</span>'
                    '<span style="flex: 1; line-height: 1.4; color: #e2e8f0; margin-left: 8px; '
                    f'font-family: {font}; '
                    'font-size: 0.92em; font-weight: 600; letter-spacing: 0.005em;">'
                    f'{escape(value.strip())}</span>'
                    '</div>'
                )
            elif "Daily Itinerary" in line:
                parts.append(section_heading("20px", "🕒️", safe))
            elif " - " in line:
                # Time-slot sub-heading
                parts.append(
                    '<h3 style="color: #e2e8f0; margin: 14px 0 6px 0; font-size: 1.05em; '
                    'font-weight: 600; letter-spacing: 0.01em; line-height: 1.4; padding-bottom: 4px; '
                    f'font-family: {font};">'
                    f'🕒 {safe}</h3>'
                )
            elif line.startswith(tuple(self._INFO_ICONS)) and ":" in line:
                # Detail line; the ":" guard avoids a ValueError on lines like "Activity ideas"
                key, value = line.split(":", 1)
                icon = self._INFO_ICONS.get(key.strip(), "•")
                parts.append(
                    '<div style="display: flex; align-items: start; margin: 4px 0; padding-left: 4px;">'
                    f'<span style="color: #818cf8; margin-right: 8px;">{icon}</span>'
                    '<span style="flex: 1; line-height: 1.4; color: #e2e8f0; '
                    f'font-family: {font}; '
                    'font-size: 0.92em; font-weight: 400; letter-spacing: 0.005em;">'
                    f'{escape(value.strip())}</span>'
                    '</div>'
                )
            else:
                parts.append(
                    '<p style="margin: 8px 0; line-height: 1.5; color: #e2e8f0; '
                    f'font-family: {font}; '
                    f'font-size: 0.95em; font-weight: 400; letter-spacing: 0.005em;">{safe}</p>'
                )

        plan_content = '\n'.join(parts)
        # Wrap everything in a dark-themed container
        return f"""
<div style="max-width: 100%; padding: 24px; background: rgba(17, 24, 39, 0.7);
            border-radius: 16px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2);">
    <div style="margin-top: 20px;">
        {plan_content}
    </div>
</div>
"""

    def verify_image_url(self, url: str) -> bool:
        """Return True if *url* is a reachable image passing the heuristic filters.

        Rejects non-200 responses, non-image content types, images smaller
        than 300px on a side, aspect ratios outside [0.5, 2.0], multi-channel
        images with pixel std < 30, grayscale edge density > 30, and mean HSV
        saturation > 200. Any error is logged and returns False.
        """
        try:
            import io

            import numpy as np
            import requests
            from PIL import Image, ImageFilter

            response = requests.get(url, timeout=3)
            if response.status_code != 200:
                return False

            content_type = response.headers.get('content-type', '')
            if 'image' not in content_type.lower():
                return False

            img = Image.open(io.BytesIO(response.content))

            # 1. Size filter: drop images that are too small
            width, height = img.size
            if width < 300 or height < 300:
                return False

            # 2. Aspect-ratio filter
            aspect_ratio = width / height
            if aspect_ratio < 0.5 or aspect_ratio > 2.0:
                return False

            img_array = np.array(img)
            is_multichannel = len(img_array.shape) == 3

            # 3. Low variance suggests a flat / text-only image
            if is_multichannel and np.std(img_array) < 30:
                return False

            # 4. High edge density suggests lots of text
            img_gray = img if img.mode == 'L' else img.convert('L')
            edges = img_gray.filter(ImageFilter.FIND_EDGES)
            if np.mean(np.array(edges)) > 30:
                return False

            # 5. Very high saturation suggests an advert. Convert to RGB first,
            #    because Pillow's HSV conversion rejects modes like RGBA/LA/CMYK.
            if is_multichannel:
                hsv = img.convert('RGB').convert('HSV')
                if np.mean(np.array(hsv)[:, :, 1]) > 200:
                    return False

            return True
        except Exception as e:
            logger.warning(f"图片验证失败: {str(e)}")
            return False

    def _format_images_html(self, images: List[str]) -> str:
        """Render *images* as a wrapping flex grid of fixed-height thumbnails.

        NOTE(review): not called anywhere in the visible code. URLs are
        HTML-escaped because they come from external search results.
        """
        from html import escape

        if not images:
            return ""
        html_out = (
            '\n<div style="display: flex; flex-wrap: wrap; gap: 10px; '
            'justify-content: center; margin-top: 20px;">\n'
        )
        for img_url in images:
            html_out += f"""
<div style="flex: 0 0 calc(50% - 10px); max-width: 300px; min-width: 200px;">
    <img
        src="{escape(img_url, quote=True)}"
        style="width: 100%; height: 200px; object-fit: cover; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);"
        onerror="this.onerror=null; this.src='https://via.placeholder.com/300x200?text=Image+Not+Available';"
    />
</div>
"""
        html_out += "</div>"
        logger.info(f"生成的图片HTML: {html_out[:200]}...")  # first 200 chars only
        return html_out

    def set_retrieval_method(self, method: str):
        """Switch the retrieval method (only ``"standard"`` is supported).

        Raises:
            ValueError: for any other method name.
        """
        if method not in ["standard"]:
            raise ValueError(f"不支持的检索方法: {method}")
        self.retrieval_method = method
        if method == "standard":
            self.retriever = self.ranking_system
def create_interface():
    """Build the Gradio UI around a single shared ``TravelRAGSystem``.

    Returns:
        The assembled (not yet launched) ``gr.Blocks`` app.
    """
    system = TravelRAGSystem()
    llm_settings = system.config['llm_settings']
    default_provider = llm_settings['default_provider']

    # Providers without an explicit `enabled` flag count as disabled, the same
    # default TravelRAGSystem.init_llm_instances uses.
    enabled_configs = [p for p in llm_settings['providers'] if p.get('enabled', False)]
    enabled_providers = [p['name'] for p in enabled_configs]

    def _models_for(provider_config):
        """Model choices for a provider: its `models` list, else its single `model`."""
        if 'models' in provider_config:
            return list(provider_config['models'])
        return [provider_config['model']] if 'model' in provider_config else []

    # provider name -> selectable model names
    provider_models = {p['name']: _models_for(p) for p in enabled_configs}

    css = """
    .gradio-container {
        font-family: "PingFang SC", "Microsoft YaHei", sans-serif;
    }
    /* 针对所有英文文本 */
    [class*="message-"] {
        font-family: 'Times New Roman', serif !important;
    }
    /* 确保英文和数字使用 Times New Roman */
    .gradio-container *:not(:lang(zh)) {
        font-family: 'Times New Roman', serif !important;
    }
    @keyframes spin {
        0% { transform: rotate(0deg); }
        100% { transform: rotate(360deg); }
    }
    /* 隐藏数字输入框的上下箭头 */
    input[type="number"]::-webkit-inner-spin-button,
    input[type="number"]::-webkit-outer-spin-button {
        -webkit-appearance: none;
        margin: 0;
    }
    input[type="number"] {
        -moz-appearance: textfield;
    }
    /* 隐藏默认的 processing 信息和箭头 */
    .progress-text, .meta-text-center, .progress-container {
        display: none !important;
    }
    /* 修改加载动画样式 */
    .loading {
        display: flex;
        align-items: center;
        justify-content: center;
        gap: 8px;
        font-size: 1.2em;
        color: rgb(192, 192, 255);
    }
    .loading::before {
        content: '🌍';
        display: inline-block;
        animation: spin 2s linear infinite;
        filter: brightness(1.5); /* 让地球图标更亮 */
    }
    /* 调整 Gradio 默认加载动画的位置 */
    .progress-text {
        display: block !important;
        order: 3;
        margin-top: 8px;
        opacity: 0.7;
    }
    .meta-text-center {
        display: block !important;
    }
    /* 确保加载容器使用 flex 布局 */
    .loading-container {
        display: flex;
        flex-direction: column;
        align-items: center;
    }
    /* 隐藏滑块右侧的上下箭头 */
    .num-input-plus, .num-input-minus {
        display: none !important;
    }
    /* 隐藏所有滚动箭头 */
    .scroll-hide,
    .output-markdown,
    .output-text,
    .markdown-text,
    .prose,
    .gr-box,
    .gr-panel {
        -ms-overflow-style: none !important;
        scrollbar-width: none !important;
        overflow-y: hidden !important;
        overflow: hidden !important;
    }
    .scroll-hide::-webkit-scrollbar,
    .output-markdown::-webkit-scrollbar,
    .output-text::-webkit-scrollbar,
    .markdown-text::-webkit-scrollbar,
    .prose::-webkit-scrollbar,
    .gr-box::-webkit-scrollbar,
    .gr-panel::-webkit-scrollbar {
        display: none !important;
        width: 0 !important;
        height: 0 !important;
    }
    /* 修改加载动画容器样式 */
    .loading-container {
        overflow: hidden !important;
        min-height: 60px;
    }
    /* 隐藏 Gradio 默认的滚动控件 */
    .wrap.svelte-byatnx,
    .contain.svelte-byatnx,
    [class*='svelte'],
    .gradio-container {
        overflow: hidden !important;
        overflow-y: hidden !important;
    }
    /* 禁用所有可能的滚动控件 */
    ::-webkit-scrollbar {
        display: none !important;
        width: 0 !important;
        height: 0 !important;
    }
    /* 移除 Group 组件的默认背景 */
    .custom-group {
        border: none !important;
        background: none !important;
        box-shadow: none !important;
    }
    .custom-group > div {
        border: none !important;
        background: none !important;
        box-shadow: none !important;
    }
    /* 添加图片容器样式 */
    .images-container {
        margin-top: 20px;
        padding: 10px;
        background: #fff;
        border-radius: 8px;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    .images-container img {
        transition: transform 0.3s ease;
    }
    .images-container img:hover {
        transform: scale(1.05);
    }
    /* 确保图片容器可见 */
    #component-13 {
        min-height: 200px;
        overflow: visible !important;
    }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=css) as interface:
        gr.Markdown("""
        # 🌟 Tourism Planning Assistant 🌟

        Welcome to the Smart Travel Planning Assistant! Simply input your travel requirements, and we'll generate a personalized travel plan for you.

        ### Instructions
        1. Describe your travel needs in the input box (e.g., 'One-day trip to Hong Kong Disneyland')
        2. Select the number of days for your plan
        3. Click the "Generate Plan" button
        """)

        with gr.Row():
            with gr.Column(scale=4):
                llm_provider = gr.Dropdown(
                    choices=enabled_providers,
                    value=default_provider,
                    label="Select LLM Provider",
                )
                llm_model = gr.Dropdown(
                    choices=provider_models.get(default_provider, []),
                    label="Select Model",
                )

                def update_model_choices(provider):
                    # Reset the value too, so a model from the previous provider
                    # is not left selected outside the new choices.
                    models = provider_models.get(provider, [])
                    return gr.Dropdown(choices=models, value=models[0] if models else None)

                llm_provider.change(
                    fn=update_model_choices,
                    inputs=[llm_provider],
                    outputs=[llm_model],
                )

                query_input = gr.Textbox(
                    label="Travel Requirements",
                    placeholder="Please enter your travel requirements, e.g.: One-day trip to Hong Kong Disneyland",
                    lines=2,
                )
                days_input = gr.Slider(
                    minimum=1,
                    maximum=7,
                    value=1,
                    step=1,
                    label="Number of Days",
                )
                show_images = gr.Checkbox(
                    label="Search Related Images",
                    value=True,
                    info="Whether to search and display related reference images",
                )
                # Only "standard" is supported, so the selector stays hidden.
                retrieval_method = gr.Radio(
                    choices=["standard"],
                    value="standard",
                    label="Retrieval Method",
                    info="Choose different retrieval strategies",
                    visible=False,
                )
                submit_btn = gr.Button("Generate Plan", variant="primary")
                loading_status = gr.Markdown("", elem_id="loading_status", show_label=False)

                # Image area in the left column
                images_container = gr.HTML(
                    value="",
                    visible=True,
                    label="Related Images",
                )
                # Clear the image area when images are disabled
                show_images.change(
                    fn=lambda x: "" if not x else "<div></div>",
                    inputs=[show_images],
                    outputs=[images_container],
                )

            with gr.Column(scale=6):
                with gr.Tabs():
                    with gr.TabItem("Travel Plan"):
                        plan_output = gr.HTML(label="Generated Travel Plan", show_label=False)
                    with gr.TabItem("References and Evaluation"):
                        sources_output = gr.Markdown(label="References and Evaluation", show_label=False)

        gr.Examples(
            examples=[
                ["One-day trip to Hong Kong Disneyland", 1],
                ["Family trip to Hong Kong Ocean Park", 1],
                ["Hong Kong Shopping and Food Tour", 2],
                ["Hong Kong Cultural Experience Tour", 3],
            ],
            inputs=[query_input, days_input],
            label="Example Queries",
        )

        def show_loading():
            """Show the spinner in the status area and plan tab; clear other outputs."""
            loading_html = (
                "<div class='loading-container'><p class='loading'>"
                "Generating your personalized travel plan...</p></div>"
            )
            return loading_html, loading_html, "", ""

        def process_with_images(query, days, llm_provider, llm_model, enable_images, retrieval_method):
            """Run the pipeline and return (plan_html, sources_md, images_html)."""
            plan_html, sources_md, images_html = system.process_query(
                query, days, llm_provider, llm_model,
                enable_images, retrieval_method,
            )
            logger.info(f"图片HTML长度: {len(images_html) if images_html else 0}")
            return plan_html, sources_md, images_html

        # Spinner -> pipeline -> clear spinner
        submit_btn.click(
            fn=show_loading,
            inputs=None,
            outputs=[loading_status, plan_output, sources_output, images_container],
        ).then(
            fn=process_with_images,
            inputs=[
                query_input,
                days_input,
                llm_provider,
                llm_model,
                show_images,
                retrieval_method,
            ],
            outputs=[plan_output, sources_output, images_container],
        ).then(
            fn=lambda: "",
            inputs=None,
            outputs=[loading_status],
        )

        gr.Markdown("""
        ### 📝 Notes
        - Plan generation may take some time, please be patient
        - Queries should include specific locations and activity preferences
        - All plans are AI-generated, please adjust according to actual circumstances

        Powered by RAG for Tourism system © 2024
        """)

    return interface
if __name__ == "__main__":
    # Listen on all interfaces. share=False because Hugging Face Spaces
    # already exposes the app publicly.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "debug": False,
        "ssr_mode": False,
    }
    demo = create_interface()
    demo.launch(**launch_kwargs)