csm-1b-tts-demo / app.py
Refactor FastAPI application to enhance audio synthesis functionality and improve logging. Introduced model loading with error handling, updated endpoint to '/synthesize', and implemented thread-safe audio generation. Modified Dockerfile to streamline dependency installation and updated requirements.txt to include additional packages.
import os
import logging
from io import BytesIO
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import torch
import torchaudio
from huggingface_hub import hf_hub_download
# Provide a default key for the model's watermarking step (the public key from
# the CSM GitHub repository). A WATERMARK_KEY already present in the
# environment takes precedence and is left untouched.
os.environ.setdefault("WATERMARK_KEY", "212 211 146 56 201")
# Configure logging. The root logger defaults to WARNING with no handler, so
# without this the logger.info() progress messages below would never be
# emitted. basicConfig() is a no-op if the hosting process has already given
# the root logger handlers; setLevel() ensures this module's INFO records are
# still passed on in that case.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
)
logger = logging.getLogger("csm_app")
logger.setLevel(logging.INFO)
# Initialize FastAPI: the application instance (ASGI app) that HTTP routes
# are attached to and that the server imports from this module.
app = FastAPI()
# Request model for input payload: the JSON body accepted by synthesize().
class SynthesisRequest(BaseModel):
    # Text to convert to speech. Pydantic only enforces that it is a string;
    # the empty/whitespace-only check happens in synthesize().
    text: str
# Load model at startup: done once at import time so every request reuses the
# same loaded model. Any failure is logged with its traceback and re-raised as
# RuntimeError, which makes importing this module (and so starting the app) fail
# rather than serving an endpoint with no model behind it.
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    logger.info("Downloading CSM-1B model from Hugging Face...")
    model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
    # Import model loader from CSM repository (requires generator.py from the
    # model's codebase to be importable alongside this file).
    from generator import load_csm_1b

    generator = load_csm_1b(model_path, device)
    logger.info("CSM-1B model loaded on %s.", device)
except Exception as e:
    logger.exception("Failed to load the CSM-1B model")
    raise RuntimeError("Model loading failed") from e
# Optional lock for thread-safe generation (ensure one generation at a time).
# The loaded `generator` is a single shared module-level object; per the FastAPI
# docs, plain `def` endpoints run in a worker thread pool, so concurrent
# requests could otherwise call into the model at the same time.
from threading import Lock
_generate_lock = Lock()
@app.post("/synthesize")
def synthesize(request: SynthesisRequest):
    """Synthesize speech from text and return WAV audio.

    Args:
        request: Payload whose ``text`` field is the text to speak.

    Returns:
        StreamingResponse: the generated audio as an ``audio/wav`` body.

    Raises:
        HTTPException: 400 if ``text`` is empty or whitespace-only;
            500 if audio generation or WAV encoding fails.
    """
    text = request.text
    if not text or not text.strip():
        # Client error, not a server fault: log at warning level.
        logger.warning("Received empty text input")
        raise HTTPException(status_code=400, detail="Text input is empty.")
    logger.info("Received synthesis request (text length=%d chars).", len(text))

    try:
        # The model is shared module state and sync endpoints run in a thread
        # pool, so only one generation may touch it at a time.
        with _generate_lock:
            # NOTE(review): the original argument list was lost; these values
            # follow the CSM README example (speaker 0, no conversational
            # context, 10 s cap). Confirm they match the intended behavior.
            audio = generator.generate(
                text=text,
                speaker=0,
                context=[],
                max_audio_length_ms=10_000,
            )
        # Encoding does not use the model, so it happens outside the lock.
        # Move to CPU (if on GPU) before handing the tensor to torchaudio.
        audio = audio.cpu()
        # CSM decodes with the Mimi codec at 24 kHz; a 44.1 kHz fallback would
        # mislabel the audio and play it back at the wrong speed.
        sample_rate = getattr(generator, "sample_rate", 24_000)
        wav_bytes = BytesIO()
        # torchaudio expects (channels, frames); add the channel dimension.
        torchaudio.save(wav_bytes, audio.unsqueeze(0), sample_rate, format="wav")
        # save() leaves the cursor at the end of the buffer; rewind it, or the
        # response would stream an empty body.
        wav_bytes.seek(0)
        logger.info("Audio generated successfully, returning WAV file.")
    except Exception as e:
        logger.exception("Error during audio generation")
        raise HTTPException(
            status_code=500, detail="Internal server error during synthesis."
        ) from e

    # Stream the WAV audio back to the client
    return StreamingResponse(wav_bytes, media_type="audio/wav")