File size: 5,877 Bytes
31fdbaa
 
efe8448
31fdbaa
 
 
 
efe8448
31fdbaa
 
 
 
 
 
 
 
 
 
 
 
efe8448
31fdbaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efe8448
31fdbaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import io
import base64
import torch
import numpy as np
from transformers import BarkModel, BarkProcessor
from typing import Dict, List, Any

class EndpointHandler:
    """Inference-endpoint handler for the Bark text-to-speech model.

    Usage: construct once, then call the instance with a request dict
    ``{"inputs": <text>, "parameters": {...}}``. The model is loaded
    lazily on first use; the response carries base64-encoded WAV audio.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler for the Bark text-to-speech model.

        Args:
            path (str, optional): Path to the model directory. Defaults to "".
        """
        self.path = path
        self.model = None
        self.processor = None
        # Prefer GPU when available; Bark generation on CPU is very slow.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.initialized = False

    def setup(self, **kwargs):
        """
        Load the model and processor from ``self.path``.

        Args:
            **kwargs: Unused; accepted for interface compatibility.
        """
        self.model = BarkModel.from_pretrained(self.path)
        self.model.to(self.device)
        self.processor = BarkProcessor.from_pretrained(self.path)
        self.initialized = True
        print(f"Bark model loaded on {self.device}")

    def preprocess(self, request: Dict) -> Dict:
        """
        Extract text and generation options from the raw request.

        Args:
            request (Dict): Request data. ``"inputs"`` may be a string or a
                list of strings (only the first element is used). Optional
                ``"parameters"`` may carry ``speaker_id``/``voice_preset``
                and ``temperature``.

        Returns:
            Dict: Normalized inputs for :meth:`inference`.
        """
        if not self.initialized:
            self.setup()

        inputs: Dict[str, Any] = {}

        raw = request.get("inputs")
        if isinstance(raw, str):
            inputs["text"] = raw
        elif isinstance(raw, list) and raw:
            # Bark generates one utterance per call; use the first entry.
            # (Guarding on a non-empty list avoids an IndexError on [].)
            inputs["text"] = raw[0]

        params = request.get("parameters", {})

        # speaker_id and voice_preset are alternative request spellings for
        # the same Bark voice-preset concept; keep whichever the caller used.
        if "speaker_id" in params:
            inputs["speaker_id"] = params["speaker_id"]
        elif "voice_preset" in params:
            inputs["voice_preset"] = params["voice_preset"]

        if "temperature" in params:
            inputs["temperature"] = params["temperature"]

        return inputs

    def inference(self, inputs: Dict) -> Dict:
        """
        Run Bark speech generation on the processed inputs.

        Args:
            inputs (Dict): Output of :meth:`preprocess`.

        Returns:
            Dict: ``{"audio_array": ndarray, "sample_rate": int}`` on
            success, or ``{"error": str}`` when no text was provided.
        """
        text = inputs.get("text", "")
        if not text:
            return {"error": "No text provided for speech generation"}

        # Bark selects the voice through the processor's voice_preset
        # argument, not through generate(); accept speaker_id as an alias.
        voice_preset = inputs.get("voice_preset") or inputs.get("speaker_id")
        temperature = inputs.get("temperature", 0.7)

        # BarkProcessor returns a BatchEncoding (input_ids, attention_mask,
        # and the voice-preset history prompt when one is given); move its
        # tensors to the device and unpack the whole mapping into generate().
        if voice_preset:
            processed = self.processor(text, voice_preset=voice_preset)
        else:
            processed = self.processor(text)
        processed = processed.to(self.device)

        with torch.no_grad():
            # NOTE(review): temperature is forwarded as a generate kwarg;
            # confirm the installed transformers version accepts it for Bark.
            speech_output = self.model.generate(**processed, temperature=temperature)

        # (1, samples) tensor -> 1-D numpy array on the host.
        audio_array = speech_output.cpu().numpy().squeeze()

        return {
            "audio_array": audio_array,
            "sample_rate": self.model.generation_config.sample_rate,
        }

    def postprocess(self, inference_output: Dict) -> Dict:
        """
        Encode the generated waveform as base64 WAV for the response.

        Args:
            inference_output (Dict): Output of :meth:`inference`.

        Returns:
            Dict: ``{"audio": <base64 str>, "sample_rate": int,
            "format": "wav"}``, or ``{"error": str}`` on failure.
        """
        if "error" in inference_output:
            return {"error": inference_output["error"]}

        audio_array = inference_output.get("audio_array")
        sample_rate = inference_output.get("sample_rate", 24000)
        if audio_array is None:
            return {"error": "Error converting audio: no audio data produced"}

        try:
            import scipy.io.wavfile as wav

            # Normalize dtype: scipy writes IEEE-float WAV for float input;
            # float32 is the widely supported variant (float64 is not).
            audio_array = np.asarray(audio_array, dtype=np.float32)

            audio_buffer = io.BytesIO()
            wav.write(audio_buffer, sample_rate, audio_array)
            audio_buffer.seek(0)
            audio_data = audio_buffer.read()

            # Base64 keeps the binary payload JSON-safe.
            audio_base64 = base64.b64encode(audio_data).decode("utf-8")

            return {
                "audio": audio_base64,
                "sample_rate": sample_rate,
                "format": "wav",
            }
        except Exception as e:
            return {"error": f"Error converting audio: {str(e)}"}

    def __call__(self, data: Dict) -> Dict:
        """
        Main entry point: request dict in, response dict out.

        Args:
            data (Dict): Request data.

        Returns:
            Dict: Response data (audio payload or an ``"error"`` entry).
        """
        # Ensure the model is initialized before any processing.
        if not self.initialized:
            self.setup()

        # Top-level boundary: convert any failure into an error response
        # rather than letting the serving framework see the exception.
        try:
            inputs = self.preprocess(data)
            outputs = self.inference(inputs)
            return self.postprocess(outputs)
        except Exception as e:
            return {"error": f"Error processing request: {str(e)}"}