""" EMOVASpeechTokenizer model """ |

import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers.modeling_utils import PreTrainedModel

try:
    from emova_speech_tokenizer.speech_utils import get_S2U_ckpt_config_path, load_config, VQCTCFinetuneModel, s2u_extract_unit_demo
    from emova_speech_tokenizer.speech_utils import get_U2S_config_checkpoint_file, load_U2S_config, SynthesizerTrn, synthesis
except ImportError as e:
    raise ImportError(
        'Dependencies of the EMOVA speech tokenizer are not installed properly. '
        'See https://github.com/emova-ollm/EMOVA_speech_tokenizer#installation for detailed instructions.'
    ) from e

from .configuration_emova_speech_tokenizer import EMOVASpeechTokenizerConfig


class EMOVASpeechTokenizer(PreTrainedModel):
    """Speech tokenizer pairing a speech-to-unit (S2U) encoder with a unit-to-speech (U2S) decoder."""

    config_class = EMOVASpeechTokenizerConfig
    base_model_prefix = "emova_speech_tokenizer"

    def __init__(self, config: EMOVASpeechTokenizerConfig):
        super().__init__(config)
        self.config = config

        # Speech-to-unit (S2U) encoder config; the pretraining checkpoint path is
        # cleared so no separate checkpoint is loaded at construction time.
        _, S2U_config_path = get_S2U_ckpt_config_path(config.s2u_unit_type)
        s2u_cfg = load_config(config=S2U_config_path)
        s2u_cfg.model.pretrain_chkpt_path = None

        # Unit-to-speech (U2S) decoder config.
        U2S_config_file, _ = get_U2S_config_checkpoint_file(config.u2s_unit_type)
        u2s_cfg = load_U2S_config(U2S_config_file)

        self.s2u_config = s2u_cfg.model
        self.u2s_config = u2s_cfg
        self.encoder = VQCTCFinetuneModel(s2u_cfg.model, trainer=None)
        self.decoder = SynthesizerTrn(
            u2s_cfg.num_symbols,
            u2s_cfg.data.filter_length // 2 + 1,  # number of spectrogram channels
            u2s_cfg.train.segment_size // u2s_cfg.data.hop_length,  # segment size in frames
            n_speakers=u2s_cfg.data.n_speakers,
            **u2s_cfg.model,
        )
        # One learnable style-centroid embedding per supported speaking style.
        self.style_embedding = nn.Embedding(config.u2s_num_styles, config.u2s_dim_styles)

    @property
    def device(self):
        return next(self.encoder.parameters()).device

    @property
    def dtype(self):
        return next(self.encoder.parameters()).dtype
|
    def encode(self, wav_file):
        """Encode a waveform file into a string of discrete speech units."""
        speech_unit = s2u_extract_unit_demo(self.encoder, wav_file, model_name='SPIRAL-FSQ-CTC', reduced=True)
        return speech_unit

    def decode(self, speech_unit, condition=None, output_wav_file='output.wav'):
        """Synthesize a waveform from a speech-unit string, optionally conditioned on a style."""
        # Strip the '<|speech_' / '|>' wrappers, leaving space-separated unit indices.
        content_unit = speech_unit.replace('<|speech_', '').replace('|>', ' ').strip()
        # Look up the style-centroid embedding when a style condition is given.
        style_centroid_embedding = None
        if condition:
            style_idx = torch.LongTensor([self.config.u2s_style2idx[condition]]).to(self.device)
            style_centroid_embedding = self.style_embedding(style_idx).unsqueeze(-1)
        audio = synthesis(content_unit, style_centroid_embedding, self.u2s_config, self.decoder, output_wav_file)
        return audio
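

if __name__ == "__main__":
    # Minimal usage sketch, not a verified recipe: the checkpoint path and the WAV
    # file below are placeholders, and any non-None `condition` must be a key of
    # `config.u2s_style2idx`.
    model = EMOVASpeechTokenizer.from_pretrained("path/to/emova_speech_tokenizer_checkpoint")
    units = model.encode("example.wav")  # discrete units, e.g. '<|speech_12|><|speech_87|>...'
    model.decode(units, condition=None, output_wav_file="reconstructed.wav")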