# visqol-7 / app.py
# clash-linux's picture
# Upload 9 files
# 6230c1a verified
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
import subprocess
import tempfile
import os
import shutil
from pydantic import BaseModel
import sys
import numpy as np # ViSQOL 可能需要 numpy
import soundfile as sf # 用于读取音频
from typing import Optional, List # 导入 List
import librosa # Need librosa for resampling during conversion if soundfile fails
# FastAPI application object; the route decorators in this file register
# their handlers on it.
app = FastAPI(
    title="ViSQOL 音频质量 API",
)
# --- ViSQOL path configuration ---
# All paths are anchored to the directory that contains this file, so they
# stay valid regardless of the working directory the server is started from.
# (A bare "./build/visqol" would be resolved against the CWD instead.)
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
VISQOL_DIR = os.path.join(_APP_DIR, "build", "visqol")
VISQOL_LIB_PATH = os.path.join(VISQOL_DIR, "visqol_lib_py.so")
PB2_DIR = os.path.join(VISQOL_DIR, "pb2")  # directory holding the pb2 files
MODEL_DIR = os.path.join(VISQOL_DIR, "model")
# NOTE: the constant names do not match the mode that consumes them:
# evaluate_audio() passes SPEECH_MODEL_PATH (libsvm) for mode="audio" and
# AUDIO_MODEL_PATH (.tflite) for mode="speech".
SPEECH_MODEL_PATH = os.path.join(MODEL_DIR, "libsvm_nu_svr_model.txt")
AUDIO_MODEL_PATH = os.path.join(MODEL_DIR, "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite")
# --- end of path configuration ---
# Fail fast at import time if any ViSQOL build artifact is absent.
required_files = [VISQOL_LIB_PATH, SPEECH_MODEL_PATH, AUDIO_MODEL_PATH]
if any(not os.path.exists(f) for f in required_files):
    # Collect every absent path so the error names all of them at once.
    missing = [f for f in required_files if not os.path.exists(f)]
    raise FileNotFoundError(f"ViSQOL 必需文件未找到: {', '.join(missing)}")

# The pb2 location must exist and be a directory (isdir is False for a
# path that does not exist, so one test covers both conditions).
if not os.path.isdir(PB2_DIR):
    raise FileNotFoundError(f"ViSQOL pb2 目录未找到: {PB2_DIR}")
# Dynamically import the ViSQOL library and the pb2 modules.
try:
    # Put the visqol and pb2 directories at the front of the import search
    # path; VISQOL_DIR ends up ahead of PB2_DIR.
    sys.path[:0] = [os.path.abspath(VISQOL_DIR), os.path.abspath(PB2_DIR)]

    # Loading the .so requires that Python can locate it (or that it is on
    # LD_LIBRARY_PATH). sys.path is normally enough for pure-Python imports,
    # but a shared object may behave differently; the original author notes
    # that the library path is handled in the Dockerfile.
    import visqol_lib_py
    import similarity_result_pb2
    import visqol_config_pb2
    print("ViSQOL 库和 pb2 文件导入成功。")
except ImportError as import_err:
    print("错误:无法导入 ViSQOL 库或 pb2 文件。")
    print(f"Python 搜索路径: {sys.path}")
    print(f"错误详情: {import_err}")
    # Deliberately not re-raised: the original author wants the failure to
    # show up in the Hugging Face logs while the app keeps starting.
    visqol_lib_py = None  # marks the library as unavailable
# Response model for the /evaluate/ endpoint.
# Documented with comments instead of a class docstring so that the schema
# metadata generated from this model is left exactly as it was.
class VisqolResponse(BaseModel):
    # File names of the two uploads, echoed back to the caller.
    reference_filename: str
    degraded_filename: str
    # Requested evaluation mode.
    mode: str
    # MOS-LQO score; the endpoint initialises it to -1.0 and only
    # overwrites it when ViSQOL returns a result.
    moslqo: float
    # Optional vnsim value copied from the ViSQOL result when present.
    vnsim: Optional[float] = None
    # Optional fvnsim values (converted to a plain list) when present.
    fvnsim: Optional[List[float]] = None
    # Status text describing the outcome of the request.
    status: str
    # Error details; None when nothing went wrong.
    error_message: Optional[str] = None
# Function to convert and resample audio using ffmpeg
def convert_and_resample_audio(input_path: str, output_path: str, target_sr: int) -> bool:
    """Convert an audio file to mono WAV at ``target_sr`` using ffmpeg.

    Args:
        input_path: Path of the source audio file (any format ffmpeg reads).
        output_path: Path of the file to write; overwritten if it exists.
        target_sr: Target sample rate in Hz, passed to ffmpeg's ``-ar``.

    Returns:
        True when ffmpeg exits with status 0, otherwise False (ffmpeg not
        installed, non-zero exit status, or any other ``Exception``).
        Failures are printed, never raised.
    """
    cmd = [
        'ffmpeg',
        '-y',  # Overwrite output file if it exists
        '-i', input_path,
        '-ar', str(target_sr),  # Set target sample rate
        '-ac', '1',  # Force mono channel (ViSQOL often expects mono)
        output_path
    ]
    print(f"Running ffmpeg: {' '.join(cmd)}")
    try:
        subprocess.run(
            cmd,
            check=True,
            capture_output=True,
            text=True,
            encoding='utf-8',
            # ffmpeg echoes file metadata on stderr, which is not guaranteed
            # to be valid UTF-8. With strict decoding a stray byte raises
            # UnicodeDecodeError and a successful conversion is reported as
            # a failure, so undecodable bytes are replaced instead.
            errors='replace',
            # Do not let ffmpeg inherit (and read from) the server's stdin.
            stdin=subprocess.DEVNULL,
        )
    except FileNotFoundError:
        # The ffmpeg executable itself could not be started.
        print("错误: ffmpeg 未找到,无法转换音频。请确保已在 Docker 环境中安装 ffmpeg。")
        return False
    except subprocess.CalledProcessError as e:
        print(f"错误: ffmpeg 执行失败 (返回码 {e.returncode})。")
        print(f"ffmpeg stderr: {e.stderr}")
        return False
    except Exception as e:
        # Best-effort boundary: callers only look at the boolean result.
        print(f"转换音频时发生未知错误: {e}")
        return False
    else:
        print("ffmpeg conversion successful.")
        return True
def _safe_upload_name(filename, fallback):
    """Reduce a client-supplied file name to a single safe path component.

    Both '/' and '\\' are treated as separators and NUL bytes are removed.
    ``fallback`` is returned when nothing usable remains (missing name,
    empty string, '.' or '..').
    """
    cleaned = (filename or "").replace("\x00", "").replace("\\", "/")
    name = os.path.basename(cleaned)
    if name in ("", ".", ".."):
        return fallback
    return name


# POST /evaluate/: saves both uploads to a temporary directory, converts them
# with ffmpeg to mono WAV at the mode's sample rate, runs ViSQOL on the decoded
# samples and returns the scores.
#
# Only the two checks before the ``try`` block raise HTTPException to the
# client. Every failure inside the ``try`` block is caught and reported through
# the ``status`` / ``error_message`` fields of a normal response, with
# ``moslqo`` left at -1.0.
#
# NOTE(review): ffmpeg (subprocess.run) and the ViSQOL measurement are
# synchronous calls made inside this coroutine, so the event loop is blocked
# while they run.
@app.post("/evaluate/", response_model=VisqolResponse)
async def evaluate_audio(
    reference: UploadFile = File(..., description="参考音频文件"),
    degraded: UploadFile = File(..., description="待评估音频文件"),
    mode: str = "audio"  # 'audio' or 'speech'
):
    """
    Evaluate the perceptual similarity between two audio files with ViSQOL.
    Returns the predicted mean opinion score (MOS-LQO).
    """
    if visqol_lib_py is None:
        raise HTTPException(status_code=500, detail="ViSQOL 库未成功加载。")
    if mode not in ["audio", "speech"]:
        raise HTTPException(status_code=400, detail="模式参数 'mode' 必须是 'audio' 或 'speech'")

    temp_dir = tempfile.mkdtemp()
    # Save with the original extension first to help ffmpeg identify the
    # format. The client-controlled name is reduced to its final component so
    # that separators, '..' or NUL bytes cannot break (or redirect) the save.
    ref_temp_orig = os.path.join(temp_dir, f"ref_{_safe_upload_name(reference.filename, 'reference')}")
    deg_temp_orig = os.path.join(temp_dir, f"deg_{_safe_upload_name(degraded.filename, 'degraded')}")
    # Define final WAV paths
    ref_path_wav = os.path.join(temp_dir, "reference.wav")
    deg_path_wav = os.path.join(temp_dir, "degraded.wav")

    mos = -1.0
    vnsim_val = None  # filled in only when the result carries vnsim
    fvnsim_val = None  # filled in only when the result carries fvnsim
    status_msg = "处理失败"
    error_msg = None

    try:
        # 1. Save the raw uploads to disk.
        ref_content = await reference.read()
        with open(ref_temp_orig, "wb") as f:
            f.write(ref_content)
        deg_content = await degraded.read()
        with open(deg_temp_orig, "wb") as f:
            f.write(deg_content)
        await reference.close()
        await degraded.close()

        # 2. Pick the target sample rate and convert/resample both files.
        target_sr = 48000 if mode == 'audio' else 16000
        print(f"目标采样率: {target_sr} Hz for mode '{mode}'")
        conv_ref_ok = convert_and_resample_audio(ref_temp_orig, ref_path_wav, target_sr)
        conv_deg_ok = convert_and_resample_audio(deg_temp_orig, deg_path_wav, target_sr)
        if not (conv_ref_ok and conv_deg_ok):
            raise HTTPException(status_code=500, detail="使用 ffmpeg 转换或重采样音频文件失败。")

        # 3. Sanity-check the converted WAV files (a mismatch is only logged).
        try:
            ref_info = sf.info(ref_path_wav)
            deg_info = sf.info(deg_path_wav)
            if ref_info.samplerate != target_sr or deg_info.samplerate != target_sr:
                print(f"警告:ffmpeg 转换后的采样率 ({ref_info.samplerate}/{deg_info.samplerate}) 与目标 ({target_sr}) 不符,可能影响 ViSQOL 结果。")
        except Exception as audio_e:
            # If sf.info fails, the ffmpeg conversion probably went wrong.
            raise HTTPException(status_code=400, detail=f"无法读取转换后的 WAV 文件: {audio_e}")

        # 4. Load the converted/resampled audio samples.
        try:
            print(f"从 WAV 加载音频数据: {ref_path_wav}, {deg_path_wav}")
            # Read as float64 (corresponds to a C++ double).
            ref_data, sr_ref = sf.read(ref_path_wav, dtype='float64')
            deg_data, sr_deg = sf.read(deg_path_wav, dtype='float64')
            # Confirm the sample rate (ffmpeg should already have handled it);
            # a mismatch is logged and processing continues.
            if sr_ref != target_sr or sr_deg != target_sr:
                print(f"警告:读取的 WAV 文件采样率 ({sr_ref}/{sr_deg}) 与目标 ({target_sr}) 不符。")
            print("音频数据加载成功。")
        except Exception as read_e:
            raise HTTPException(status_code=500, detail=f"读取转换后的 WAV 文件时出错: {read_e}")

        # 5. Build the ViSQOL configuration.
        config = visqol_config_pb2.VisqolConfig()
        config.audio.sample_rate = target_sr  # use the target sample rate
        # Model selection follows the official example. The constant names
        # read the other way round: AUDIO_MODEL_PATH is the .tflite file and
        # SPEECH_MODEL_PATH is the libsvm .txt file.
        if mode == "speech":
            config.options.use_speech_scoring = True
            # Speech mode uses the TFLite model according to official example
            model_file_to_use = AUDIO_MODEL_PATH  # .tflite model
        else:  # audio mode
            config.options.use_speech_scoring = False
            # Audio mode uses the SVR model according to official example
            model_file_to_use = SPEECH_MODEL_PATH  # .txt model (libsvm)
        config.options.svr_model_path = os.path.abspath(model_file_to_use)
        print(f"使用模型: {model_file_to_use} for mode '{mode}'")

        # 6. Create the API instance and run the measurement on the loaded
        #    NumPy arrays (sample data is passed, not file paths).
        api = visqol_lib_py.VisqolApi()
        api.Create(config)
        similarity_result_msg = api.Measure(ref_data, deg_data)

        # 7. Extract the scores from the result message.
        if similarity_result_msg and hasattr(similarity_result_msg, 'moslqo'):
            mos = similarity_result_msg.moslqo
            status_msg = "处理成功"
            print(f"ViSQOL 评估完成: MOS-LQO = {mos}")

            # Try to extract vnsim.
            if hasattr(similarity_result_msg, 'vnsim'):
                vnsim_val = similarity_result_msg.vnsim
                print(f"VNSIM = {vnsim_val}")
            else:
                print("ViSQOL 结果中未找到 vnsim 字段。")

            # Try to extract fvnsim (converted to a plain Python list).
            if hasattr(similarity_result_msg, 'fvnsim') and similarity_result_msg.fvnsim:
                fvnsim_val = list(similarity_result_msg.fvnsim)
                # Log only the first element to keep the output short.
                print(f"FVNSIM (第一个元素): {fvnsim_val[0] if fvnsim_val else 'N/A'}")
            else:
                print("ViSQOL 结果中未找到 fvnsim 字段或为空。")
        else:
            error_msg = "ViSQOL 未返回有效的 MOS-LQO 结果。"
            print(f"错误: {error_msg}")

    except ImportError as e:
        status_msg = "导入错误"
        error_msg = f"无法导入 ViSQOL 库或依赖: {e}"
        print(f"错误: {error_msg}")
    except FileNotFoundError as e:
        status_msg = "文件未找到错误"
        error_msg = f"必需文件丢失: {e}"
        print(f"错误: {error_msg}")
    except HTTPException as e:  # the HTTP errors raised inside the try block
        status_msg = "请求错误"
        error_msg = str(e.detail)
        print(f"错误: {error_msg}")
    except Exception as e:
        status_msg = "运行时错误"
        error_msg = f"处理过程中发生错误: {type(e).__name__} - {e}"
        print(f"错误: {error_msg}")
    finally:
        # Cleanup must never mask the response: a failure while deleting the
        # temporary directory is ignored instead of propagating out of
        # ``finally``.
        shutil.rmtree(temp_dir, ignore_errors=True)

    return VisqolResponse(
        # The response model requires str; fall back to "" for a missing name.
        reference_filename=reference.filename or "",
        degraded_filename=degraded.filename or "",
        mode=mode,
        moslqo=mos,
        vnsim=vnsim_val,
        fvnsim=fvnsim_val,
        status=status_msg,
        error_message=error_msg
    )
@app.get("/", include_in_schema=False)
async def root():
    """Return a short welcome message that points callers at POST /evaluate/."""
    welcome_text = "欢迎使用 ViSQOL 音频质量评估 API。请使用 POST 方法访问 /evaluate/ 端点。"
    return {"message": welcome_text}
# Health-check endpoint.
@app.get("/healthz", status_code=200)
async def health_check():
    """Hugging Face Spaces health check endpoint."""
    # The HTTP status is always 200; a failed ViSQOL import is reported
    # through the response body only.
    library_loaded = visqol_lib_py is not None
    if library_loaded:
        return {"status": "ok"}
    return {"status": "error", "detail": "ViSQOL library not loaded"}
# Optional entry point for local testing when the script is executed directly.
if __name__ == "__main__":
    import uvicorn

    print("运行本地测试服务器: http://127.0.0.1:8000")
    # NOTE: running locally may require LD_LIBRARY_PATH to be set correctly,
    # or the .so file to be placed somewhere the system can find it.
    uvicorn.run(
        app,
        host="127.0.0.1",
        port=8000,
    )