---
license: cc-by-nc-4.0
language:
- zh
- en
base_model:
- meta-llama/Llama-3.2-3B-Instruct
tags:
- Text-to-Speech
pipeline_tag: text-to-speech
---

[![arXiv](https://img.shields.io/badge/arXiv-Paper-red.svg)](https://arxiv.org/abs/2502.04128)

**Update (2025-02-13):** Added [Llasa fine-tuning instructions](https://github.com/zhenye234/LLaSA_training/tree/main/finetune).


**Update (2025-02-07):** Our paper has been released!


**LLaSA: Scaling Train-Time and Inference-Time Compute for LLaMA-based Speech Synthesis**


- **Train from Scratch**: If you want to train the model from scratch, use the [LLaSA Training Repository](https://github.com/zhenye234/LLaSA_training).

- **Scale for Test-Time Computation**: If you want to experiment with scaling for test-time computation, use the [LLaSA Testing Repository](https://github.com/zhenye234/LLaSA_inference).

## Model Information
Our model, Llasa, is a text-to-speech (TTS) system that extends the text-based LLaMA language model (1B, 3B, and 8B variants) by incorporating speech tokens from the XCodec2 codebook, which contains 65,536 tokens. We trained Llasa on a dataset comprising 250,000 hours of Chinese-English speech data. The model can generate speech **either solely from input text or by utilizing a given speech prompt.**

The method is seamlessly compatible with the Llama framework, making TTS training similar to LLM training: audio is converted into single-codebook tokens and simply treated as a special language. This opens up the possibility of applying existing LLM methods for compression, acceleration, and fine-tuning to TTS.
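To make this "special language" view concrete, here is a minimal illustrative sketch (not the actual model code; the speech-token format `<|s_N|>` and the control tokens are the ones that appear in the usage examples below):

```python
# Illustrative sketch only: how the LLaMA vocabulary is extended for speech.
# XCodec2 encodes a waveform into integer codes in [0, 65536); each code maps
# to one extra token "<|s_N|>" in the tokenizer.
codebook_size = 65536
speech_token_vocab = [f"<|s_{i}|>" for i in range(codebook_size)]

# Control tokens frame the text and speech spans in the chat template:
control_tokens = [
    "<|TEXT_UNDERSTANDING_START|>", "<|TEXT_UNDERSTANDING_END|>",
    "<|SPEECH_GENERATION_START|>", "<|SPEECH_GENERATION_END|>",
]
```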



## How to use
Install [XCodec2](https://huggingface.co/HKUSTAudio/xcodec2) first; both examples below rely on it to convert between waveforms and speech tokens.

**1. Speech synthesis solely from input text**
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model

llasa_3b = 'HKUSTAudio/Llasa-3B'

# Load the Llasa-3B tokenizer and language model.
tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
model = AutoModelForCausalLM.from_pretrained(llasa_3b)
model.eval()
model.to('cuda')

# Load the XCodec2 codec, which decodes speech tokens back into a waveform.
codec_model_path = "HKUSTAudio/xcodec2"
Codec_model = XCodec2Model.from_pretrained(codec_model_path)
Codec_model.eval().cuda()

input_text = 'Dealing with family secrets is never easy. Yet, sometimes, omission is a form of protection, intending to safeguard some from the harsh truths. One day, I hope you understand the reasons behind my actions. Until then, Anna, please, bear with me.'
# input_text = '突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"'
def ids_to_speech_tokens(speech_ids):
    """Convert integer speech codes to '<|s_N|>' token strings."""
    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"<|s_{speech_id}|>")
    return speech_tokens_str

def extract_speech_ids(speech_tokens_str):
    """Convert '<|s_N|>' token strings back to integer speech codes."""
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids

# TTS start!
with torch.no_grad():
 
    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

    # Tokenize the text
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>"}
    ]

    input_ids = tokenizer.apply_chat_template(
        chat, 
        tokenize=True, 
        return_tensors='pt', 
        continue_final_message=True
    )
    input_ids = input_ids.to('cuda')
    speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

    # Generate the speech autoregressively
    outputs = model.generate(
        input_ids,
        max_length=2048,  # We trained our model with a max length of 2048
        eos_token_id=speech_end_id,
        do_sample=True,
        top_p=1,           # Adjusts the diversity of generated content
        temperature=0.8,   # Controls randomness in output
    )
    # Extract the generated speech tokens (dropping the final <|SPEECH_GENERATION_END|>)
    generated_ids = outputs[0][input_ids.shape[1]:-1]

    speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)   

    # Convert  token <|s_23456|> to int 23456 
    speech_tokens = extract_speech_ids(speech_tokens)

    speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

    # Decode the speech tokens to speech waveform
    gen_wav = Codec_model.decode_code(speech_tokens) 
 

sf.write("gen.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
```
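A note on the sampling settings: `max_length=2048` matches the maximum sequence length used during training, while `top_p` and `temperature` control sampling diversity and can be tuned if the output sounds unstable or monotonous.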

**2. Speech synthesis utilizing a given speech prompt**

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import soundfile as sf
from xcodec2.modeling_xcodec2 import XCodec2Model

llasa_3b = 'HKUSTAudio/Llasa-3B'

# Load the Llasa-3B tokenizer and language model.
tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
model = AutoModelForCausalLM.from_pretrained(llasa_3b)
model.eval()
model.to('cuda')

# Load the XCodec2 codec, which converts between waveforms and speech tokens.
codec_model_path = "HKUSTAudio/xcodec2"
Codec_model = XCodec2Model.from_pretrained(codec_model_path)
Codec_model.eval().cuda()
# Only 16 kHz speech is supported!
prompt_wav, sr = sf.read("太乙真人.wav")   # you can find this wav in the Files tab
# prompt_wav, sr = sf.read("Anna.wav")  # English prompt
prompt_wav = torch.from_numpy(prompt_wav).float().unsqueeze(0)
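# (Added sketch, not part of the original card.) If your prompt is not already
# 16 kHz, resample it before encoding, e.g. with torchaudio (one option among
# many):
#   import torchaudio
#   if sr != 16000:
#       prompt_wav = torchaudio.functional.resample(prompt_wav, sr, 16000)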

prompt_text ="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
#promt_text = "A chance to leave him alone, but... No. She just wanted to see him again. Anna, you don't know how it feels to lose a sister. Anna, I'm sorry, but your father asked me not to tell you anything."
target_text = '突然,身边一阵笑声。我看着他们,意气风发地挺直了胸膛,甩了甩那稍显肉感的双臂,轻笑道:"我身上的肉,是为了掩饰我爆棚的魅力,否则,岂不吓坏了你们呢?"'
#target_text = "Dealing with family secrets is never easy. Yet, sometimes, omission is a form of protection, intending to safeguard some from the harsh truths. One day, I hope you understand the reasons behind my actions. Until then, Anna, please, bear with me."
input_text = prompt_text   + target_text

def ids_to_speech_tokens(speech_ids):
    """Convert integer speech codes to '<|s_N|>' token strings."""
    speech_tokens_str = []
    for speech_id in speech_ids:
        speech_tokens_str.append(f"<|s_{speech_id}|>")
    return speech_tokens_str

def extract_speech_ids(speech_tokens_str):
    """Convert '<|s_N|>' token strings back to integer speech codes."""
    speech_ids = []
    for token_str in speech_tokens_str:
        if token_str.startswith('<|s_') and token_str.endswith('|>'):
            num_str = token_str[4:-2]
            num = int(num_str)
            speech_ids.append(num)
        else:
            print(f"Unexpected token: {token_str}")
    return speech_ids

# TTS start!
with torch.no_grad():
    # Encode the prompt wav into speech tokens
    vq_code_prompt = Codec_model.encode_code(input_waveform=prompt_wav)
    print("Prompt VQ code shape:", vq_code_prompt.shape)

    vq_code_prompt = vq_code_prompt[0, 0, :]
    # Convert int 12345 to token <|s_12345|>
    speech_ids_prefix = ids_to_speech_tokens(vq_code_prompt)

    formatted_text = f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"

    # Tokenize the text and the speech prefix
    chat = [
        {"role": "user", "content": "Convert the text to speech:" + formatted_text},
        {"role": "assistant", "content": "<|SPEECH_GENERATION_START|>" + ''.join(speech_ids_prefix)}
    ]

    input_ids = tokenizer.apply_chat_template(
        chat, 
        tokenize=True, 
        return_tensors='pt', 
        continue_final_message=True
    )
    input_ids = input_ids.to('cuda')
    speech_end_id = tokenizer.convert_tokens_to_ids('<|SPEECH_GENERATION_END|>')

    # Generate the speech autoregressively
    outputs = model.generate(
        input_ids,
        max_length=2048,  # We trained our model with a max length of 2048
        eos_token_id=speech_end_id,
        do_sample=True,
        top_p=1,
        temperature=0.8,
    )
    # Extract the speech tokens, re-including the prompt's speech prefix so the
    # decoded waveform covers both prompt and continuation (the end token is dropped)
    generated_ids = outputs[0][input_ids.shape[1] - len(speech_ids_prefix):-1]

    speech_tokens = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)   

    # Convert  token <|s_23456|> to int 23456 
    speech_tokens = extract_speech_ids(speech_tokens)

    speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

    # Decode the speech tokens to speech waveform
    gen_wav = Codec_model.decode_code(speech_tokens) 

    # If you only need the newly generated part (without the prompt), uncomment:
    # gen_wav = gen_wav[:, :, prompt_wav.shape[1]:]

sf.write("gen.wav", gen_wav[0, 0, :].cpu().numpy(), 16000)
```
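Note that the decoded waveform contains a re-synthesized copy of the prompt followed by the new speech; uncomment the slicing line near the end if you only want the generated continuation.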


## Disclaimer

Because of ethics and privacy concerns, this model is licensed under the CC BY-NC 4.0 license, which prohibits commercial use; detected violations will result in legal consequences.

This codebase must not be used for any illegal purpose in any country or region. Please refer to your local laws regarding the DMCA and other applicable regulations.