Text-To-Speech / app.py
Bmo411's picture
Update app.py
d0fe171 verified
raw
history blame contribute delete
7.24 kB
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow import keras
import torch
from huggingface_hub import hf_hub_download
from speechbrain.inference.TTS import Tacotron2
import os
# Cargar modelo Tacotron2
tacotron2 = Tacotron2.from_hparams(
source="speechbrain/tts-tacotron2-ljspeech",
savedir="tmpdir_tts",
run_opts={"device": "cpu"}
)
# Diccionario para almacenar los modelos cargados
loaded_models = {}
# Modelos disponibles - define aqu铆 las 茅pocas que quieres incluir
available_models = {
"脡poca 100": "generator_epoch_100.keras",
"脡poca 300": "generator_epoch_300.keras",
"脡poca 400": "generator_epoch_400.keras",
"脡poca 1000": "generator_epoch_1000.keras",
"脡poca 4200": "generator_epoch_4200.keras",
"脡poca 4700": "generator_epoch_4700.keras",
"脡poca 7700": "generator_epoch_7700.keras",
}
# Funci贸n para cargar un modelo espec铆fico
def load_generator_model(model_name):
if model_name in loaded_models:
return loaded_models[model_name]
try:
model_path = hf_hub_download(
repo_id="Bmo411/WGAN",
filename=model_name
)
model = keras.models.load_model(model_path, compile=False)
loaded_models[model_name] = model
print(f"Modelo {model_name} cargado correctamente")
return model
except Exception as e:
print(f"Error al cargar el modelo {model_name}: {e}")
# Si falla la carga, intentamos usar el modelo de la 茅poca 1000 como fallback
try:
fallback_model = "generator_epoch_1000.keras"
model_path = hf_hub_download(
repo_id="Bmo411/WGAN",
filename=fallback_model
)
model = keras.models.load_model(model_path, compile=False)
loaded_models[model_name] = model # Guardamos con el nombre original para evitar recargar
print(f"Usando modelo fallback {fallback_model}")
return model
except:
print("Error cr铆tico al cargar modelos. No hay modelos disponibles.")
return None
# Funci贸n para convertir texto a audio
def text_to_audio(text, model_epoch):
# Crear un array vac铆o por defecto en caso de error
default_audio = np.zeros(8000, dtype=np.float32)
sample_rate = 8000 # Ajusta seg煤n la configuraci贸n de tu modelo
if not text or not text.strip():
return (sample_rate, default_audio)
try:
# Obtener el nombre del archivo del modelo seleccionado
model_filename = available_models[model_epoch]
# Cargar el modelo generador correspondiente
generator = load_generator_model(model_filename)
if generator is None:
print("No se pudo cargar el generador")
return (sample_rate, default_audio)
# Convertir texto a mel-spectrograma con Tacotron2
mel_output, _, _ = tacotron2.encode_text(text)
mel = mel_output.detach().cpu().numpy().astype(np.float32)
# Imprimir forma original del mel para debugging
print(f"Forma original del mel: {mel.shape}")
# Reorganizar el mel para que coincida con la forma esperada (batch, 80, frames, 1)
# Si mel tiene forma (80, frames) - lo m谩s probable
if len(mel.shape) == 2:
mel_input = np.expand_dims(mel, axis=0) # (1, 80, frames)
mel_input = np.expand_dims(mel_input, axis=-1) # (1, 80, frames, 1)
# Si viene con otra forma, intentamos adaptarla
elif len(mel.shape) == 3 and mel.shape[0] == 1:
# Si es (1, 80, frames) o (1, frames, 80)
if mel.shape[1] == 80:
mel_input = np.expand_dims(mel, axis=-1) # (1, 80, frames, 1)
else:
mel_input = np.expand_dims(np.transpose(mel, (0, 2, 1)), axis=-1) # (1, 80, frames, 1)
else:
# Intento final de reorganizaci贸n
mel_input = np.expand_dims(np.expand_dims(mel, axis=0), axis=-1)
print(f"Forma del mel preparado: {mel_input.shape}")
# Generar audio
generated_audio = generator(mel_input, training=False)
# Procesar el audio generado
generated_audio = tf.squeeze(generated_audio).numpy()
# Asegurarse de que hay valores no cero antes de normalizar
if np.max(np.abs(generated_audio)) > 0:
generated_audio = generated_audio / np.max(np.abs(generated_audio))
# Convertir a float32 para gradio
generated_audio = generated_audio.astype(np.float32)
print(f"Forma del audio generado: {generated_audio.shape}")
current_length = len(generated_audio)
if current_length > 8000:
# Recortar si es m谩s largo de 2 segundos
print(f"Recortando audio de {current_length} a {8000} muestras")
final_audio = generated_audio[:8000]
else:
# Rellenar con ceros si es m谩s corto de 2 segundos
print(f"Rellenando audio de {current_length} a {8000} muestras")
final_audio = np.zeros(8000, dtype=np.float32)
final_audio[:current_length] = generated_audio
return (sample_rate, final_audio)
except Exception as e:
print(f"Error en la generaci贸n de audio: {e}")
# Si hay error, imprimir un traceback completo para mejor diagn贸stico
import traceback
traceback.print_exc()
return (sample_rate, default_audio)
# Crear interfaz en Gradio
with gr.Blocks(title="Demo de TTS con Tacotron2 + Generador") as interface:
gr.Markdown("# Demo de TTS con Tacotron2 + Generador")
gr.Markdown("Convierte texto en audio usando Tacotron2 + modelo Generator entrenado en diferentes 茅pocas.")
with gr.Row():
with gr.Column(scale=3):
text_input = gr.Textbox(lines=2, placeholder="Escribe nine-", label="Texto a convertir")
with gr.Column(scale=1):
model_selection = gr.Dropdown(
choices=list(available_models.keys()),
value="脡poca 1000",
label="Selecciona la 茅poca del modelo"
)
generate_btn = gr.Button("Generar Audio", variant="primary")
audio_output = gr.Audio(label="Audio generado")
# Configurar ejemplos
examples = gr.Examples(
examples=[
["nine", "脡poca 100"],
["nine", "脡poca 400"],
["nine", "脡poca 4700"]
],
inputs=[text_input, model_selection],
outputs=audio_output
)
# Conectar bot贸n a la funci贸n
generate_btn.click(fn=text_to_audio, inputs=[text_input, model_selection], outputs=audio_output)
# Tambi茅n permitir enviar con Enter desde el cuadro de texto
text_input.submit(fn=text_to_audio, inputs=[text_input, model_selection], outputs=audio_output)
# Lanzar aplicaci贸n
if __name__ == "__main__":
# Precargamos el modelo de la 茅poca 1000 para tenerlo disponible inmediatamente
load_generator_model(available_models["脡poca 1000"])
# Lanzamos la interfaz
interface.launch(debug=True)