Spaces:
Runtime error
Runtime error
from fastapi import FastAPI, HTTPException, BackgroundTasks | |
from fastapi.responses import StreamingResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
import os | |
import uuid | |
import torch | |
import torchaudio | |
import base64 | |
from io import BytesIO | |
from transformers import AutoModelForCausalLM | |
import sys | |
import subprocess | |
from datetime import datetime, timedelta | |
app = FastAPI(title="Nigerian TTS API") | |
# Add CORS middleware | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], # In production, set this to your Next.js domain | |
allow_credentials=True, | |
allow_methods=["*"], | |
allow_headers=["*"], | |
) | |
# Initialize necessary directories | |
os.makedirs("audio_files", exist_ok=True) | |
os.makedirs("models", exist_ok=True) | |
# Check if YarnGPT is installed, if not install it | |
try: | |
import yarngpt | |
from yarngpt.audiotokenizer import AudioTokenizerV2 | |
except ImportError: | |
print("Installing YarnGPT and dependencies...") | |
subprocess.check_call([sys.executable, "-m", "pip", "install", "git+https://github.com/saheedniyi02/yarngpt.git"]) | |
subprocess.check_call([sys.executable, "-m", "pip", "install", "outetts", "uroman", "transformers", "torchaudio"]) | |
from yarngpt.audiotokenizer import AudioTokenizerV2 | |
# Model configuration | |
tokenizer_path = "saheedniyi/YarnGPT2" | |
# Check if model files exist, if not download them | |
wav_tokenizer_config_path = "./models/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml" | |
wav_tokenizer_model_path = "./models/wavtokenizer_large_speech_320_24k.ckpt" | |
if not os.path.exists(wav_tokenizer_config_path): | |
print("Downloading model config file...") | |
subprocess.check_call([ | |
"wget", "-O", wav_tokenizer_config_path, | |
"https://huggingface.co/novateur/WavTokenizer-medium-speech-75token/resolve/main/wavtokenizer_mediumdata_frame75_3s_nq1_code4096_dim512_kmeans200_attn.yaml" | |
]) | |
if not os.path.exists(wav_tokenizer_model_path): | |
print("Downloading model checkpoint file...") | |
subprocess.check_call([ | |
"wget", "-O", wav_tokenizer_model_path, | |
"https://drive.google.com/uc?id=1-ASeEkrn4HY49yZWHTASgfGFNXdVnLTt&export=download" | |
]) | |
print("Loading YarnGPT model and tokenizer...") | |
audio_tokenizer = AudioTokenizerV2( | |
tokenizer_path, wav_tokenizer_model_path, wav_tokenizer_config_path | |
) | |
model = AutoModelForCausalLM.from_pretrained(tokenizer_path, torch_dtype="auto").to(audio_tokenizer.device) | |
print("Model loaded successfully!") | |
# Available voices and languages | |
AVAILABLE_VOICES = { | |
"female": ["zainab", "idera", "regina", "chinenye", "joke", "remi"], | |
"male": ["jude", "tayo", "umar", "osagie", "onye", "emma"] | |
} | |
AVAILABLE_LANGUAGES = ["english", "yoruba", "igbo", "hausa"] | |
# Input validation model | |
class TTSRequest(BaseModel): | |
text: str | |
language: str = "english" | |
voice: str = "idera" | |
# Output model with base64-encoded audio | |
class TTSResponse(BaseModel): | |
audio_base64: str # Base64-encoded audio data | |
audio_url: str # Keep for backward compatibility | |
text: str | |
voice: str | |
language: str | |
async def root(): | |
"""API health check and info""" | |
return { | |
"status": "ok", | |
"message": "Nigerian TTS API is running", | |
"available_languages": AVAILABLE_LANGUAGES, | |
"available_voices": AVAILABLE_VOICES | |
} | |
async def text_to_speech(request: TTSRequest, background_tasks: BackgroundTasks): | |
"""Convert text to Nigerian-accented speech""" | |
# Validate inputs | |
if request.language not in AVAILABLE_LANGUAGES: | |
raise HTTPException(status_code=400, detail=f"Language must be one of {AVAILABLE_LANGUAGES}") | |
all_voices = AVAILABLE_VOICES["female"] + AVAILABLE_VOICES["male"] | |
if request.voice not in all_voices: | |
raise HTTPException(status_code=400, detail=f"Voice must be one of {all_voices}") | |
# Generate unique filename | |
audio_id = str(uuid.uuid4()) | |
output_path = f"audio_files/{audio_id}.wav" | |
try: | |
# Create prompt and generate audio | |
prompt = audio_tokenizer.create_prompt(request.text, lang=request.language, speaker_name=request.voice) | |
input_ids = audio_tokenizer.tokenize_prompt(prompt) | |
output = model.generate( | |
input_ids=input_ids, | |
temperature=0.1, | |
repetition_penalty=1.1, | |
max_length=4000, | |
) | |
codes = audio_tokenizer.get_codes(output) | |
audio = audio_tokenizer.get_audio(codes) | |
# Save audio file | |
torchaudio.save(output_path, audio, sample_rate=24000) | |
# Read the file and encode as base64 | |
with open(output_path, "rb") as audio_file: | |
audio_bytes = audio_file.read() | |
audio_base64 = base64.b64encode(audio_bytes).decode('utf-8') | |
# Clean up old files after a while | |
background_tasks.add_task(cleanup_old_files) | |
return TTSResponse( | |
audio_base64=audio_base64, | |
audio_url=f"/audio/{audio_id}.wav", # Keep for compatibility | |
text=request.text, | |
voice=request.voice, | |
language=request.language | |
) | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}") | |
# File serving endpoint for direct audio access | |
async def get_audio(filename: str): | |
file_path = f"audio_files/{filename}" | |
if not os.path.exists(file_path): | |
raise HTTPException(status_code=404, detail="Audio file not found") | |
def iterfile(): | |
with open(file_path, "rb") as audio_file: | |
yield from audio_file | |
return StreamingResponse(iterfile(), media_type="audio/wav") | |
# Endpoint to stream audio directly from base64 (useful for debugging) | |
async def stream_audio(request: TTSRequest): | |
"""Stream audio directly without saving to disk""" | |
try: | |
# Create prompt and generate audio | |
prompt = audio_tokenizer.create_prompt(request.text, lang=request.language, speaker_name=request.voice) | |
input_ids = audio_tokenizer.tokenize_prompt(prompt) | |
output = model.generate( | |
input_ids=input_ids, | |
temperature=0.1, | |
repetition_penalty=1.1, | |
max_length=4000, | |
) | |
codes = audio_tokenizer.get_codes(output) | |
audio = audio_tokenizer.get_audio(codes) | |
# Create BytesIO object | |
buffer = BytesIO() | |
torchaudio.save(buffer, audio, sample_rate=24000, format="wav") | |
buffer.seek(0) | |
return StreamingResponse(buffer, media_type="audio/wav") | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Error generating audio: {str(e)}") | |
# Cleanup function to remove old files | |
def cleanup_old_files(): | |
"""Delete audio files older than 6 hours to manage disk space""" | |
try: | |
now = datetime.now() | |
audio_dir = "audio_files" | |
for filename in os.listdir(audio_dir): | |
if not filename.endswith(".wav"): | |
continue | |
file_path = os.path.join(audio_dir, filename) | |
file_mod_time = datetime.fromtimestamp(os.path.getmtime(file_path)) | |
# Delete files older than 6 hours | |
if now - file_mod_time > timedelta(hours=6): | |
os.remove(file_path) | |
print(f"Deleted old audio file: {filename}") | |
except Exception as e: | |
print(f"Error cleaning up old files: {e}") | |
# For running locally with uvicorn | |
if __name__ == "__main__": | |
import uvicorn | |
port = int(os.environ.get("PORT", 8000)) | |
uvicorn.run(app, host="0.0.0.0", port=port) |