|
from fastapi import FastAPI, UploadFile, File, HTTPException |
|
from fastapi.responses import JSONResponse |
|
import subprocess |
|
import tempfile |
|
import os |
|
import shutil |
|
from pydantic import BaseModel |
|
import sys |
|
import numpy as np |
|
import soundfile as sf |
|
from typing import Optional, List |
|
import librosa |
|
|
|
# ASGI application object; the route handlers in this module register on it.
app = FastAPI(title="ViSQOL 音频质量 API")

# Locations of the ViSQOL build artifacts. VISQOL_DIR is a relative path, so
# it is resolved against the working directory of the server process.
VISQOL_DIR = "./build/visqol"
# Compiled Python extension module for ViSQOL.
VISQOL_LIB_PATH = os.path.join(VISQOL_DIR, "visqol_lib_py.so")
# Directory holding the generated protobuf modules (*_pb2.py).
PB2_DIR = os.path.join(VISQOL_DIR, "pb2")
MODEL_DIR = os.path.join(VISQOL_DIR, "model")
# NOTE(review): the two constant names below look swapped relative to the
# files they point at. evaluate_audio() passes AUDIO_MODEL_PATH (the lattice
# .tflite file) when mode == "speech" and SPEECH_MODEL_PATH (the libsvm file)
# when mode == "audio". Confirm against the ViSQOL documentation before
# renaming or "fixing" either side on its own.
SPEECH_MODEL_PATH = os.path.join(MODEL_DIR, "libsvm_nu_svr_model.txt")
AUDIO_MODEL_PATH = os.path.join(MODEL_DIR, "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite")
|
|
|
|
|
|
|
# Fail fast at import time when any ViSQOL artifact is absent, naming every
# missing path in a single error instead of stopping at the first one.
required_files = [VISQOL_LIB_PATH, SPEECH_MODEL_PATH, AUDIO_MODEL_PATH]
missing = [path for path in required_files if not os.path.exists(path)]
if missing:
    raise FileNotFoundError(f"ViSQOL 必需文件未找到: {', '.join(missing)}")

# isdir() is already False for a path that does not exist, so one check
# covers both "absent" and "present but not a directory".
if not os.path.isdir(PB2_DIR):
    raise FileNotFoundError(f"ViSQOL pb2 目录未找到: {PB2_DIR}")
|
|
|
|
|
try:
    # Put the ViSQOL build directory first and the pb2 directory second on
    # the module search path so the extension module and the generated
    # protobuf modules can be imported by their bare names.
    sys.path[:0] = [os.path.abspath(VISQOL_DIR), os.path.abspath(PB2_DIR)]

    import visqol_lib_py
    import similarity_result_pb2
    import visqol_config_pb2
except ImportError as e:
    print(
        "错误:无法导入 ViSQOL 库或 pb2 文件。",
        f"Python 搜索路径: {sys.path}",
        f"错误详情: {e}",
        sep="\n",
    )
    # Sentinel checked by the request handlers: the app still starts, but
    # the endpoints report that the library is unavailable.
    visqol_lib_py = None
else:
    print("ViSQOL 库和 pb2 文件导入成功。")
|
|
|
|
|
class VisqolResponse(BaseModel):
    # Response body of POST /evaluate/. Documented with comments rather than
    # a class docstring so the generated schema description stays unchanged.

    # Names of the two uploads, echoed back from the request.
    reference_filename: str
    degraded_filename: str
    # Scoring mode the request asked for ("audio" or "speech").
    mode: str
    # MOS-LQO score reported by ViSQOL; the endpoint initialises it to -1.0
    # and leaves it there when processing does not produce a score.
    moslqo: float
    # Optional similarity details copied from the ViSQOL result when present.
    vnsim: Optional[float] = None
    fvnsim: Optional[List[float]] = None
    # Human-readable outcome label set by the endpoint.
    status: str
    # Failure description; None when no error was recorded.
    error_message: Optional[str] = None
|
|
|
|
|
def convert_and_resample_audio(input_path, output_path, target_sr):
    """Convert an audio file to mono WAV at ``target_sr`` Hz using ffmpeg.

    Args:
        input_path: Path of the source audio file (``str`` or path-like),
            passed to ffmpeg via ``-i``.
        output_path: Destination path (``str`` or path-like); an existing
            file is overwritten because ``-y`` is passed.
        target_sr: Output sample rate in Hz, passed via ``-ar``.

    Returns:
        True when ffmpeg exits with status 0. False when the ffmpeg
        executable cannot be found, when it exits non-zero, or when any
        other error occurs while running it; the reason is printed.
    """
    cmd = [
        'ffmpeg',
        '-y',
        '-i', os.fspath(input_path),
        '-ar', str(target_sr),
        '-ac', '1',  # downmix to a single channel
        os.fspath(output_path),
    ]
    print(f"Running ffmpeg: {' '.join(cmd)}")
    try:
        # errors="replace": ffmpeg echoes file metadata on stderr, which is
        # not guaranteed to be valid UTF-8. With strict decoding a
        # successful conversion would raise UnicodeDecodeError here and be
        # reported as a failure by the catch-all handler below.
        subprocess.run(
            cmd,
            check=True,
            capture_output=True,
            text=True,
            encoding='utf-8',
            errors='replace',
        )
    except FileNotFoundError:
        print("错误: ffmpeg 未找到,无法转换音频。请确保已在 Docker 环境中安装 ffmpeg。")
        return False
    except subprocess.CalledProcessError as e:
        print(f"错误: ffmpeg 执行失败 (返回码 {e.returncode})。")
        print(f"ffmpeg stderr: {e.stderr}")
        return False
    except Exception as e:
        # Deliberate best-effort boundary: callers only need a yes/no.
        print(f"转换音频时发生未知错误: {e}")
        return False
    else:
        print("ffmpeg conversion successful.")
        return True
|
|
|
def _upload_suffix(filename: Optional[str]) -> str:
    """Return a sanitized file extension taken from a client-supplied name.

    Only the extension of the final path component is kept, filtered down to
    alphanumerics and dots and capped at 16 characters, so it can be appended
    to a server-chosen file name without letting the client influence where
    the file is written. Returns "" when ``filename`` is None or has no
    extension.
    """
    # Treat both separator styles as separators regardless of the host OS.
    base = os.path.basename((filename or "").replace("\\", "/"))
    ext = os.path.splitext(base)[1]
    return "".join(ch for ch in ext if ch.isalnum() or ch == ".")[:16]


@app.post("/evaluate/", response_model=VisqolResponse)
async def evaluate_audio(
    reference: UploadFile = File(..., description="参考音频文件"),
    degraded: UploadFile = File(..., description="待评估音频文件"),
    mode: str = "audio"
):
    """Evaluate the perceptual similarity of two audio files with ViSQOL.

    Returns the predicted mean opinion score (MOS-LQO).

    Both uploads are stored in a private temporary directory, converted by
    ffmpeg to mono WAV (48000 Hz for ``mode="audio"``, 16000 Hz for
    ``mode="speech"``) and passed to ViSQOL.

    Raises:
        HTTPException: 500 when the ViSQOL library failed to load at start-up,
            400 when ``mode`` is neither "audio" nor "speech".

    Any failure after those two checks is not raised: it is reported in the
    response body through ``status`` / ``error_message`` while ``moslqo``
    stays at -1.0.
    """
    if visqol_lib_py is None:
        raise HTTPException(status_code=500, detail="ViSQOL 库未成功加载。")

    if mode not in ("audio", "speech"):
        raise HTTPException(status_code=400, detail="模式参数 'mode' 必须是 'audio' 或 'speech'")

    temp_dir = tempfile.mkdtemp()

    # The upload file name is client-controlled (and may be missing). Never
    # join it into a path: use fixed base names and keep only a sanitized
    # extension, presumably useful to ffmpeg as a format hint.
    ref_temp_orig = os.path.join(temp_dir, "ref_input" + _upload_suffix(reference.filename))
    deg_temp_orig = os.path.join(temp_dir, "deg_input" + _upload_suffix(degraded.filename))

    ref_path_wav = os.path.join(temp_dir, "reference.wav")
    deg_path_wav = os.path.join(temp_dir, "degraded.wav")

    # Defaults describe a failed run; they are overwritten on success.
    mos = -1.0
    vnsim_val = None
    fvnsim_val = None
    status_msg = "处理失败"
    error_msg = None

    try:
        # Persist both uploads; close them even if a write fails.
        try:
            for upload, dest in ((reference, ref_temp_orig), (degraded, deg_temp_orig)):
                payload = await upload.read()
                with open(dest, "wb") as f:
                    f.write(payload)
        finally:
            await reference.close()
            await degraded.close()

        target_sr = 48000 if mode == 'audio' else 16000
        print(f"目标采样率: {target_sr} Hz for mode '{mode}'")

        conv_ref_ok = convert_and_resample_audio(ref_temp_orig, ref_path_wav, target_sr)
        conv_deg_ok = convert_and_resample_audio(deg_temp_orig, deg_path_wav, target_sr)

        if not (conv_ref_ok and conv_deg_ok):
            # Caught by the HTTPException handler below and turned into an
            # error payload rather than an HTTP error status.
            raise HTTPException(status_code=500, detail="使用 ffmpeg 转换或重采样音频文件失败。")

        # Sanity-check the converted files' headers before loading samples.
        try:
            ref_info = sf.info(ref_path_wav)
            deg_info = sf.info(deg_path_wav)
            if ref_info.samplerate != target_sr or deg_info.samplerate != target_sr:
                print(f"警告:ffmpeg 转换后的采样率 ({ref_info.samplerate}/{deg_info.samplerate}) 与目标 ({target_sr}) 不符,可能影响 ViSQOL 结果。")
        except Exception as audio_e:
            raise HTTPException(status_code=400, detail=f"无法读取转换后的 WAV 文件: {audio_e}")

        try:
            print(f"从 WAV 加载音频数据: {ref_path_wav}, {deg_path_wav}")

            ref_data, sr_ref = sf.read(ref_path_wav, dtype='float64')
            deg_data, sr_deg = sf.read(deg_path_wav, dtype='float64')

            if sr_ref != target_sr or sr_deg != target_sr:
                print(f"警告:读取的 WAV 文件采样率 ({sr_ref}/{sr_deg}) 与目标 ({target_sr}) 不符。")

            print("音频数据加载成功。")
        except Exception as read_e:
            raise HTTPException(status_code=500, detail=f"读取转换后的 WAV 文件时出错: {read_e}")

        config = visqol_config_pb2.VisqolConfig()
        config.audio.sample_rate = target_sr

        # NOTE(review): the constant names look swapped relative to the mode,
        # but the pairing itself (lattice .tflite file for speech scoring,
        # libsvm file for audio) appears to match the ViSQOL README example.
        # Confirm before renaming the constants or swapping them here.
        if mode == "speech":
            config.options.use_speech_scoring = True
            model_file_to_use = AUDIO_MODEL_PATH
        else:
            config.options.use_speech_scoring = False
            model_file_to_use = SPEECH_MODEL_PATH

        config.options.svr_model_path = os.path.abspath(model_file_to_use)
        print(f"使用模型: {model_file_to_use} for mode '{mode}'")

        api = visqol_lib_py.VisqolApi()
        api.Create(config)

        similarity_result_msg = api.Measure(ref_data, deg_data)

        if similarity_result_msg and hasattr(similarity_result_msg, 'moslqo'):
            mos = similarity_result_msg.moslqo
            status_msg = "处理成功"
            print(f"ViSQOL 评估完成: MOS-LQO = {mos}")

            if hasattr(similarity_result_msg, 'vnsim'):
                vnsim_val = similarity_result_msg.vnsim
                print(f"VNSIM = {vnsim_val}")
            else:
                print("ViSQOL 结果中未找到 vnsim 字段。")

            if hasattr(similarity_result_msg, 'fvnsim') and similarity_result_msg.fvnsim:
                fvnsim_val = list(similarity_result_msg.fvnsim)
                print(f"FVNSIM (第一个元素): {fvnsim_val[0] if fvnsim_val else 'N/A'}")
            else:
                print("ViSQOL 结果中未找到 fvnsim 字段或为空。")
        else:
            error_msg = "ViSQOL 未返回有效的 MOS-LQO 结果。"
            print(f"错误: {error_msg}")

    except ImportError as e:
        status_msg = "导入错误"
        error_msg = f"无法导入 ViSQOL 库或依赖: {e}"
        print(f"错误: {error_msg}")
    except FileNotFoundError as e:
        status_msg = "文件未找到错误"
        error_msg = f"必需文件丢失: {e}"
        print(f"错误: {error_msg}")
    except HTTPException as e:
        status_msg = "请求错误"
        error_msg = str(e.detail)
        print(f"错误: {error_msg}")
    except Exception as e:
        status_msg = "运行时错误"
        error_msg = f"处理过程中发生错误: {type(e).__name__} - {e}"
        print(f"错误: {error_msg}")
    finally:
        # Best-effort cleanup: a failure to delete the scratch directory must
        # not replace the response (or the original error) with a new one.
        shutil.rmtree(temp_dir, ignore_errors=True)

    return VisqolResponse(
        # The response fields are plain ``str``; a missing upload name would
        # otherwise fail response validation.
        reference_filename=reference.filename or "",
        degraded_filename=degraded.filename or "",
        mode=mode,
        moslqo=mos,
        vnsim=vnsim_val,
        fvnsim=fvnsim_val,
        status=status_msg,
        error_message=error_msg
    )
|
|
|
@app.get("/", include_in_schema=False)
async def root():
    """Return a welcome message that points callers at POST /evaluate/."""
    welcome = "欢迎使用 ViSQOL 音频质量评估 API。请使用 POST 方法访问 /evaluate/ 端点。"
    return {"message": welcome}
|
|
|
|
|
@app.get("/healthz", status_code=200)
async def health_check():
    """Hugging Face Spaces health check endpoint."""
    # NOTE(review): the HTTP status is 200 in both branches; a missing ViSQOL
    # library is only visible in the JSON body. Confirm that is intended for
    # whatever polls this endpoint.
    if visqol_lib_py is not None:
        return {"status": "ok"}
    return {"status": "error", "detail": "ViSQOL library not loaded"}
|
|
|
|
|
# Local development entry point: runs the app under uvicorn when the module
# is executed directly rather than imported by an ASGI server.
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when running this way.
    import uvicorn
    print("运行本地测试服务器: http://127.0.0.1:8000")
    # Bound to the loopback interface only.
    uvicorn.run(app, host="127.0.0.1", port=8000)