import os
import io
import base64
import torch
import numpy as np
from transformers import BarkModel, BarkProcessor
from typing import Dict


class EndpointHandler:
    def __init__(self, path=""):
        """
        Initialize the handler for the Bark text-to-speech model.

        Args:
            path (str, optional): Path to the model directory. Defaults to "".
        """
        self.path = path
        self.model = None
        self.processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.initialized = False

    def setup(self, **kwargs):
        """
        Load the model and processor.

        Args:
            **kwargs: Additional arguments.
        """
        # Load model weights from the local directory and move them to the device
        self.model = BarkModel.from_pretrained(self.path)
        self.model.to(self.device)
        # Load the matching processor (tokenization plus voice-preset handling)
        self.processor = BarkProcessor.from_pretrained(self.path)
        self.initialized = True
        print(f"Bark model loaded on {self.device}")

    def preprocess(self, request: Dict) -> Dict:
        """
        Process the input request before inference.

        Args:
            request (Dict): The request data containing text to convert to speech.

        Returns:
            Dict: Processed inputs for the model.
        """
        if not self.initialized:
            self.setup()
        inputs = {}
        # Get text from the request
        if "inputs" in request:
            if isinstance(request["inputs"], str):
                # Single text input
                inputs["text"] = request["inputs"]
            elif isinstance(request["inputs"], list):
                # List of text inputs; only the first entry is synthesized
                inputs["text"] = request["inputs"][0]
        # Get optional parameters
        params = request.get("parameters", {})
        # Speaker ID / voice preset ("speaker_id" is accepted as an alias)
        if "speaker_id" in params:
            inputs["speaker_id"] = params["speaker_id"]
        elif "voice_preset" in params:
            inputs["voice_preset"] = params["voice_preset"]
        # Other generation parameters
        if "temperature" in params:
            inputs["temperature"] = params["temperature"]
        return inputs
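    # Example of the request shape preprocess expects (a sketch; the preset
    # name "v2/en_speaker_6" is one of Bark's published voice presets and is
    # used here purely for illustration):
    #
    #   {
    #       "inputs": "Hello from Bark!",
    #       "parameters": {"voice_preset": "v2/en_speaker_6", "temperature": 0.7}
    #   }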

    def inference(self, inputs: Dict) -> Dict:
        """
        Run model inference on the processed inputs.

        Args:
            inputs (Dict): Processed inputs for the model.

        Returns:
            Dict: Model outputs.
        """
        text = inputs.get("text", "")
        if not text:
            return {"error": "No text provided for speech generation"}
        # "speaker_id" is treated as an alias for a Bark voice preset
        # (e.g. "v2/en_speaker_6"); the processor converts a preset into
        # the history prompt the model conditions on, so presets are passed
        # to the processor rather than to generate()
        voice_preset = inputs.get("voice_preset") or inputs.get("speaker_id")
        temperature = inputs.get("temperature", 0.7)
        # Prepare inputs for the model and move all tensors to the device
        if voice_preset:
            model_inputs = self.processor(text, voice_preset=voice_preset)
        else:
            model_inputs = self.processor(text)
        model_inputs = {
            k: v.to(self.device) if hasattr(v, "to") else v
            for k, v in model_inputs.items()
        }
        # Generate speech; unprefixed generation kwargs such as temperature
        # are forwarded to each of Bark's sub-models
        with torch.no_grad():
            speech_output = self.model.generate(**model_inputs, temperature=temperature)
        # Convert to a 1-D numpy array of audio samples
        audio_array = speech_output.cpu().numpy().squeeze()
        return {
            "audio_array": audio_array,
            "sample_rate": self.model.generation_config.sample_rate,
        }

    def postprocess(self, inference_output: Dict) -> Dict:
        """
        Process the model outputs after inference.

        Args:
            inference_output (Dict): Model outputs.

        Returns:
            Dict: Processed outputs ready for the response.
        """
        if "error" in inference_output:
            return {"error": inference_output["error"]}
        audio_array = inference_output.get("audio_array")
        sample_rate = inference_output.get("sample_rate", 24000)
        # Convert the audio array to WAV format in memory
        try:
            import scipy.io.wavfile as wav

            audio_buffer = io.BytesIO()
            wav.write(audio_buffer, sample_rate, audio_array)
            audio_buffer.seek(0)
            audio_data = audio_buffer.read()
            # Encode the audio bytes to base64 for JSON transport
            audio_base64 = base64.b64encode(audio_data).decode("utf-8")
            return {
                "audio": audio_base64,
                "sample_rate": sample_rate,
                "format": "wav",
            }
        except Exception as e:
            return {"error": f"Error converting audio: {str(e)}"}

    def __call__(self, data: Dict) -> Dict:
        """
        Main entry point for the handler.

        Args:
            data (Dict): Request data.

        Returns:
            Dict: Response data.
        """
        # Ensure the model is initialized
        if not self.initialized:
            self.setup()
        # Process the request end to end
        try:
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return self.postprocess(outputs)
        except Exception as e:
            return {"error": f"Error processing request: {str(e)}"}