Spaces:
Runtime error
Runtime error
Константин
Refactor FastAPI application to enhance audio synthesis functionality and improve logging. Introduced model loading with error handling, updated endpoint to '/synthesize', and implemented thread-safe audio generation. Modified Dockerfile to streamline dependency installation and updated requirements.txt to include additional packages.
5f36d04
import os | |
import logging | |
from io import BytesIO | |
from fastapi import FastAPI, HTTPException | |
from fastapi.responses import StreamingResponse | |
from pydantic import BaseModel | |
import torch | |
import torchaudio | |
from huggingface_hub import hf_hub_download | |
# The model's watermarker needs a key; fall back to the public key from the
# model's GitHub repo when the environment does not provide one.
os.environ.setdefault("WATERMARK_KEY", "212 211 146 56 201")

# Application-wide logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("csm_app")

# FastAPI application instance; routes are registered against this object.
app = FastAPI()
# Request model for input payload | |
class SynthesisRequest(BaseModel):
    """Request payload for the speech-synthesis endpoint."""

    # Text to be converted to speech; validated as non-empty by the endpoint.
    text: str
# Load the model once at import time so the first request does not pay the
# download/initialization cost.
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    logger.info("Downloading CSM-1B model from Hugging Face...")
    model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
    # Import model loader from the CSM repository (requires generator.py from
    # the model's codebase to be importable).
    from generator import load_csm_1b
    generator = load_csm_1b(model_path, device)
    # Lazy %-style args are the logging-module idiom: the message is only
    # formatted when the record is actually emitted.
    logger.info("CSM-1B model loaded on %s.", device)
except Exception as e:
    # logger.exception logs at ERROR level and attaches the traceback —
    # the idiomatic replacement for logger.error(..., exc_info=True).
    logger.exception("Failed to load the CSM-1B model")
    raise RuntimeError("Model loading failed") from e

# Serialize generation: one synthesis at a time. The generator is shared
# module state and is not shown to be safe for concurrent use.
from threading import Lock

_generate_lock = Lock()
@app.post("/synthesize")
def synthesize(request: SynthesisRequest):
    """Synthesize speech from text and stream the result back as WAV audio.

    Args:
        request: Payload whose ``text`` field is the text to synthesize.

    Returns:
        StreamingResponse with media type ``audio/wav``.

    Raises:
        HTTPException: 400 for empty/blank input, 500 on any generation failure.
    """
    # BUG FIX: the function was never registered as a route — without the
    # @app.post decorator the service exposed no endpoint at all.
    text = request.text
    if not text or not text.strip():
        logger.error("Received empty text input")
        raise HTTPException(status_code=400, detail="Text input is empty.")
    logger.info("Received synthesis request (text length=%d chars).", len(text))
    try:
        # Hold the lock only around model inference; the CPU transfer and WAV
        # encoding below do not touch the shared generator and need not be
        # serialized.
        with _generate_lock:
            audio = generator.generate(
                text=text,
                speaker=0,
                context=[],
                max_audio_length_ms=10000
            )
        # Move to CPU (no-op when already there) and encode as WAV in memory.
        audio = audio.cpu()
        # NOTE(review): 44100 Hz is only a fallback when the generator exposes
        # no sample_rate attribute — confirm it matches the model's true
        # output rate.
        sample_rate = getattr(generator, "sample_rate", 44100)
        wav_bytes = BytesIO()
        torchaudio.save(wav_bytes, audio.unsqueeze(0), sample_rate, format="wav")
        wav_bytes.seek(0)
        logger.info("Audio generated successfully, returning WAV file.")
    except Exception:
        # logger.exception attaches the traceback automatically inside except.
        logger.exception("Error during audio generation")
        raise HTTPException(status_code=500, detail="Internal server error during synthesis.")
    # Stream the WAV audio back to the client.
    return StreamingResponse(wav_bytes, media_type="audio/wav")