"""FastAPI service that exposes the CSM-1B text-to-speech model.

The model is downloaded and loaded once at import time; a single POST
endpoint, ``/synthesize``, turns text into a WAV audio response.
"""

import logging
import os
from io import BytesIO
from threading import Lock

import torch
import torchaudio
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from huggingface_hub import hf_hub_download
from pydantic import BaseModel

# Set watermark key to avoid errors in model's watermarking (use public GH key).
# A WATERMARK_KEY already present in the environment takes precedence.
os.environ.setdefault("WATERMARK_KEY", "212 211 146 56 201")

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("csm_app")

# Initialize FastAPI
app = FastAPI()

# Sample rate used only when the generator does not expose ``sample_rate``.
# CSM decodes audio through the Mimi codec, which is documented as 24 kHz; the
# previous 44100 fallback would have stamped a wrong rate into the WAV header.
# NOTE(review): confirm against the generator implementation in use.
_FALLBACK_SAMPLE_RATE = 24_000


class SynthesisRequest(BaseModel):
    """Request payload for the ``/synthesize`` endpoint."""

    # Text to convert to speech; rejected when empty or whitespace-only.
    text: str


# Load model at startup
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    logger.info("Downloading CSM-1B model from Hugging Face...")
    model_path = hf_hub_download(repo_id="sesame/csm-1b", filename="ckpt.pt")
    # Import model loader from CSM repository (requires generator.py from the
    # model's codebase). Kept inside the ``try`` so that a missing module is
    # reported through the same "Model loading failed" path.
    from generator import load_csm_1b

    generator = load_csm_1b(model_path, device)
    logger.info("CSM-1B model loaded on %s.", device)
except Exception as e:
    logger.exception("Failed to load the CSM-1B model")
    raise RuntimeError("Model loading failed") from e

# Lock for thread-safe generation (ensure one generation at a time). The
# endpoint below is a plain ``def``, which FastAPI runs in a worker thread
# pool, so concurrent requests can otherwise reach the model at once.
_generate_lock = Lock()


@app.post("/synthesize")
def synthesize(request: SynthesisRequest) -> StreamingResponse:
    """Synthesize speech from text and return WAV audio.

    Responds with HTTP 400 when the text is empty or whitespace-only, and
    with HTTP 500 when generation or WAV encoding fails.
    """
    text = request.text
    if not text.strip():
        logger.error("Received empty text input")
        raise HTTPException(status_code=400, detail="Text input is empty.")

    logger.info("Received synthesis request (text length=%d chars).", len(text))
    try:
        # Only the model call is serialized; encoding works on a tensor that
        # belongs to this request alone.
        with _generate_lock:
            audio = generator.generate(
                text=text,
                speaker=0,
                context=[],
                max_audio_length_ms=10000,
            )

        # Move to CPU (if on GPU) and prepare WAV bytes
        audio = audio.cpu()
        sample_rate = getattr(generator, "sample_rate", _FALLBACK_SAMPLE_RATE)
        wav_bytes = BytesIO()
        # torchaudio expects (channels, samples); add the channel dimension.
        torchaudio.save(wav_bytes, audio.unsqueeze(0), sample_rate, format="wav")
        wav_bytes.seek(0)
        logger.info("Audio generated successfully, returning WAV file.")
    except Exception as e:
        logger.exception("Error during audio generation")
        # Chain the cause so the original failure stays attached to the
        # HTTP error in tracebacks; the client still sees a generic message.
        raise HTTPException(
            status_code=500,
            detail="Internal server error during synthesis.",
        ) from e

    # Stream the WAV audio back to the client
    return StreamingResponse(wav_bytes, media_type="audio/wav")