File size: 3,390 Bytes
02e65c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e113534
02e65c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# coding=utf-8
# Copyright 2024 The EMOVA team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" EMOVASpeechTokenizer model """

import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel

try:
    from emova_speech_tokenizer.speech_utils import get_S2U_ckpt_config_path, load_config, VQCTCFinetuneModel, s2u_extract_unit_demo
    from emova_speech_tokenizer.speech_utils import get_U2S_config_checkpoint_file, load_U2S_config, SynthesizerTrn, synthesis
except:
    raise ImportError('Dependencies of emova speech tokenizer are not installed properly. Check https://github.com/emova-ollm/EMOVA_speech_tokenizer#installation for detailed instructions.')

from .configuration_emova_speech_tokenizer import EMOVASpeechTokenizerConfig

class EMOVASpeechTokenizer(PreTrainedModel):
    """Speech <-> discrete-unit tokenizer.

    Wraps two sub-models:
      * ``encoder``: an S2U (speech-to-unit) ``VQCTCFinetuneModel`` used by
        :meth:`encode` to turn a wav file into speech units.
      * ``decoder``: a U2S (unit-to-speech) ``SynthesizerTrn`` used by
        :meth:`decode` to synthesize audio from speech units, optionally
        conditioned on a learned style embedding.
    """

    config_class = EMOVASpeechTokenizerConfig
    base_model_prefix = "emova_speech_tokenizer"

    def __init__(self, config: EMOVASpeechTokenizerConfig):
        """Build the encoder, decoder and style embedding from ``config``.

        Args:
            config: model config. This reads ``s2u_unit_type``, ``u2s_unit_type``,
                ``u2s_num_styles`` and ``u2s_dim_styles``.
        """
        super().__init__(config)
        self.config = config

        # S2U encoder config, resolved by unit type from the external package.
        _, s2u_config_path = get_S2U_ckpt_config_path(config.s2u_unit_type)
        s2u_cfg = load_config(config=s2u_config_path)
        # Clear the pretrain checkpoint path before constructing the encoder.
        # Presumably this stops construction from loading separate pretrain
        # weights, because weights come from this model's checkpoint. Confirm.
        s2u_cfg.model.pretrain_chkpt_path = None

        # U2S decoder config, resolved by unit type.
        u2s_config_file, _ = get_U2S_config_checkpoint_file(config.u2s_unit_type)
        u2s_cfg = load_U2S_config(u2s_config_file)

        # Keep the resolved configs on the instance. decode() passes u2s_config
        # to synthesis().
        self.s2u_config = s2u_cfg.model
        self.u2s_config = u2s_cfg
        self.encoder = VQCTCFinetuneModel(s2u_cfg.model, trainer=None)
        self.decoder = SynthesizerTrn(
            u2s_cfg.num_symbols,
            # One-sided STFT bin count for an FFT of size filter_length.
            u2s_cfg.data.filter_length // 2 + 1,
            # Training segment length expressed in hops (frames).
            u2s_cfg.train.segment_size // u2s_cfg.data.hop_length,
            n_speakers=u2s_cfg.data.n_speakers,
            **u2s_cfg.model
        )
        # One learned vector per style. decode() looks these up through
        # config.u2s_style2idx.
        self.style_embedding = nn.Embedding(config.u2s_num_styles, config.u2s_dim_styles)

    @property
    def device(self):
        """Device of the encoder's first parameter."""
        return next(self.encoder.parameters()).device

    @property
    def dtype(self):
        """Dtype of the encoder's first parameter."""
        return next(self.encoder.parameters()).dtype

    def encode(self, wav_file):
        """Extract discrete speech units from ``wav_file`` with the S2U encoder.

        Delegates to ``s2u_extract_unit_demo`` with the 'SPIRAL-FSQ-CTC' model
        name and ``reduced=True``. The return value comes straight from that
        helper. It is presumably a string in the ``<|speech_N|>`` format that
        :meth:`decode` parses (confirm against emova_speech_tokenizer).
        """
        speech_unit = s2u_extract_unit_demo(self.encoder, wav_file, model_name='SPIRAL-FSQ-CTC', reduced=True)
        return speech_unit

    def decode(self, speech_unit, condition=None, output_wav_file='output.wav'):
        """Synthesize audio from speech-unit tokens with the U2S decoder.

        Args:
            speech_unit: string of tokens such as ``'<|speech_12|><|speech_34|>'``.
                It is converted to the space-separated form ``'12 34'`` before
                synthesis.
            condition: optional style name, used as a key into
                ``config.u2s_style2idx``. A falsy value (None, '') means no
                style conditioning.
            output_wav_file: forwarded unchanged to ``synthesis``.

        Returns:
            Whatever ``synthesis`` returns (the synthesized audio).

        Raises:
            KeyError: if ``condition`` is not a key of ``config.u2s_style2idx``.
        """
        # '<|speech_12|><|speech_34|>' -> '12 34'
        content_unit = speech_unit.replace('<|speech_', '').replace('|>', ' ').strip()

        style_centroid_embedding = None
        if condition:
            style2idx = self.config.u2s_style2idx
            try:
                style_idx = style2idx[condition]
            except KeyError as err:
                raise KeyError(
                    f"Unknown style condition {condition!r}; available styles: {list(style2idx)}"
                ) from err
            # Create the index directly on the model's device. This avoids the
            # legacy LongTensor constructor plus a CPU -> device copy.
            style_ids = torch.tensor([style_idx], dtype=torch.long, device=self.device)
            # (1, u2s_dim_styles) -> (1, u2s_dim_styles, 1)
            style_centroid_embedding = self.style_embedding(style_ids).unsqueeze(-1)

        audio = synthesis(content_unit, style_centroid_embedding, self.u2s_config, self.decoder, output_wav_file)
        return audio