# coding=utf-8
# Copyright 2024 The EMOVA team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" EMOVASpeechTokenizer model """
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
try:
    from emova_speech_tokenizer.speech_utils import get_S2U_ckpt_config_path, load_config, VQCTCFinetuneModel, s2u_extract_unit_demo
    from emova_speech_tokenizer.speech_utils import get_U2S_config_checkpoint_file, load_U2S_config, SynthesizerTrn, synthesis
except ImportError as e:
    raise ImportError(
        'Dependencies of the EMOVA speech tokenizer are not installed properly. '
        'See https://github.com/emova-ollm/EMOVA_speech_tokenizer#installation for detailed instructions.'
    ) from e
from .configuration_emova_speech_tokenizer import EMOVASpeechTokenizerConfig


class EMOVASpeechTokenizer(PreTrainedModel):
    config_class = EMOVASpeechTokenizerConfig
    base_model_prefix = "emova_speech_tokenizer"

    def __init__(self, config: EMOVASpeechTokenizerConfig):
        super().__init__(config)
        self.config = config

        # s2u encoder configs
        _, S2U_config_path = get_S2U_ckpt_config_path(config.s2u_unit_type)
        s2u_cfg = load_config(config=S2U_config_path)
        s2u_cfg.model.pretrain_chkpt_path = None

        # u2s decoder configs
        U2S_config_file, _ = get_U2S_config_checkpoint_file(config.u2s_unit_type)
        u2s_cfg = load_U2S_config(U2S_config_file)

        # construct models
        self.s2u_config = s2u_cfg.model
        self.u2s_config = u2s_cfg
        self.encoder = VQCTCFinetuneModel(s2u_cfg.model, trainer=None)
        self.decoder = SynthesizerTrn(
            u2s_cfg.num_symbols,
            u2s_cfg.data.filter_length // 2 + 1,
            u2s_cfg.train.segment_size // u2s_cfg.data.hop_length,
            n_speakers=u2s_cfg.data.n_speakers,
            **u2s_cfg.model,
        )
        self.style_embedding = nn.Embedding(config.u2s_num_styles, config.u2s_dim_styles)

    @property
    def device(self):
        return next(self.encoder.parameters()).device

    @property
    def dtype(self):
        return next(self.encoder.parameters()).dtype

    def encode(self, wav_file):
        """Speech-to-unit (S2U): extract discrete speech units from an input waveform file."""
        speech_unit = s2u_extract_unit_demo(self.encoder, wav_file, model_name='SPIRAL-FSQ-CTC', reduced=True)
        return speech_unit

    def decode(self, speech_unit, condition=None, output_wav_file='output.wav'):
        """Unit-to-speech (U2S): synthesize a waveform from speech units, optionally conditioned on a style label."""
        content_unit = speech_unit.replace('<|speech_', '').replace('|>', ' ').strip()
        style_centroid_embedding = (
            self.style_embedding(torch.LongTensor([self.config.u2s_style2idx[condition]]).to(self.device)).unsqueeze(-1)
            if condition else None
        )
        audio = synthesis(content_unit, style_centroid_embedding, self.u2s_config, self.decoder, output_wav_file)
        return audio