import os

import torch
import torch.nn as nn

from openvoice.api import ToneColorConverter

# Run from this script's directory so the relative checkpoint and
# output paths below resolve no matter where the script is launched.
os.chdir(os.path.dirname(os.path.abspath(__file__)))


class ToneColorExtractWrapper(nn.Module):
    """Wraps the converter's reference encoder in a plain forward pass
    so it can be traced by torch.onnx.export."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio):
        # audio: (1, source_len, 513) linear-spectrogram frames.
        audio = audio.contiguous()
        # ref_enc produces the tone color (speaker) embedding.
        g = self.model.ref_enc(audio)
        return g
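

# A minimal sketch (not part of the export flow) of how a (1, T, 513) input
# for the extract model could be built from a waveform. The 513 bins imply
# n_fft = 1024 (1024 // 2 + 1); the hop and window sizes here are assumptions,
# and the authoritative values live in the converter's config.json.
def waveform_to_spec(wav, n_fft=1024, hop_length=256):
    # wav: (1, num_samples) float32 waveform in [-1, 1]
    spec = torch.stft(
        wav,
        n_fft=n_fft,
        hop_length=hop_length,
        window=torch.hann_window(n_fft),
        return_complex=True,
    )
    spec = spec.abs()            # magnitude spectrogram, (1, 513, T)
    return spec.transpose(1, 2)  # (1, T, 513), the layout ref_enc expects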


class ToneCloneWrapper(nn.Module):
    """Wraps voice_conversion so every input is a plain tensor, which keeps
    the signature ONNX-export friendly."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio, audio_lengths, src_tone, dest_tone, tau):
        # audio: (1, 513, target_len) spectrogram; tones: (1, 256, 1) embeddings.
        audio = audio.contiguous()
        src_tone = src_tone.contiguous()
        dest_tone = dest_tone.contiguous()
        # voice_conversion also returns masks/latents; only the converted
        # waveform is kept.
        o_hat, _, _ = self.model.voice_conversion(
            audio,
            audio_lengths,
            sid_src=src_tone,
            sid_tgt=dest_tone,
            tau=tau[0]
        )
        return o_hat
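

# End-to-end sketch of how the two wrappers chain at inference time, assuming
# ref_enc returns a (1, 256) embedding as in the default config: extract tone
# embeddings from two reference spectrograms, then convert the source audio.
def clone_voice(converter_model, src_spec, ref_spec_src, ref_spec_tgt, tau=0.3):
    # ref_spec_src / ref_spec_tgt: (1, T, 513); src_spec: (1, 513, T).
    extract = ToneColorExtractWrapper(converter_model).eval()
    clone = ToneCloneWrapper(converter_model).eval()
    with torch.no_grad():
        src_tone = extract(ref_spec_src).unsqueeze(-1)   # (1, 256) -> (1, 256, 1)
        dest_tone = extract(ref_spec_tgt).unsqueeze(-1)
        lengths = torch.LongTensor([src_spec.size(2)])   # frames in (1, 513, T)
        return clone(src_spec, lengths, src_tone, dest_tone,
                     torch.FloatTensor([tau]))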


def export_models(ckpt_path, output_dir, target_audio_lens, source_audio_lens):
    """
    Export the tone color extraction and cloning models to ONNX format.

    Args:
        ckpt_path: path to the converter checkpoint directory
        output_dir: directory the .onnx files are written to
        target_audio_lens: target audio lengths (in spectrogram frames) to export
        source_audio_lens: source audio lengths (in spectrogram frames) to export
    """
    # Export runs on CPU; the resulting ONNX graph is device-agnostic.
    device = "cpu"
    converter = ToneColorConverter(f'{ckpt_path}/config.json', device=device)
    converter.load_ckpt(f'{ckpt_path}/checkpoint.pth')

    os.makedirs(output_dir, exist_ok=True)

    # Tone color extraction model: spectrogram in, speaker embedding out.
    extract_wrapper = ToneColorExtractWrapper(converter.model)
    extract_wrapper.eval()

    for source_len in source_audio_lens:
        dummy_input = torch.randn(1, source_len, 513).contiguous()
        # Include the length in the filename so exports for different
        # lengths do not overwrite each other.
        output_path = f"{output_dir}/tone_color_extract_model_{source_len}.onnx"

        torch.onnx.export(
            extract_wrapper,
            dummy_input,
            output_path,
            input_names=['input'],
            output_names=['tone_embedding'],
            dynamic_axes={
                'input': {1: 'source_audio_len'},
            },
            opset_version=11,
            do_constant_folding=True,
            verbose=True
        )
        print(f"Exported tone extract model to {output_path}")

    # Tone clone (voice conversion) model.
    clone_wrapper = ToneCloneWrapper(converter.model)
    clone_wrapper.eval()

    for target_len in target_audio_lens:
        dummy_inputs = (
            torch.randn(1, 513, target_len).contiguous(),  # source spectrogram
            torch.LongTensor([target_len]),                # its frame count
            torch.randn(1, 256, 1).contiguous(),           # source tone embedding
            torch.randn(1, 256, 1).contiguous(),           # target tone embedding
            torch.FloatTensor([0.3])                       # sampling temperature tau
        )

        # Include the length in the filename so exports for different
        # lengths do not overwrite each other.
        output_path = f"{output_dir}/tone_clone_model_{target_len}.onnx"

        torch.onnx.export(
            clone_wrapper,
            dummy_inputs,
            output_path,
            input_names=['audio', 'audio_length', 'src_tone', 'dest_tone', 'tau'],
            output_names=['converted_audio'],
            dynamic_axes={
                'audio': {2: 'target_audio_len'},
            },
            opset_version=17,
            do_constant_folding=True,
            verbose=True
        )
        print(f"Exported tone clone model to {output_path}")
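

# A minimal smoke test for the exported extract model, assuming onnxruntime
# is installed (imported locally because it is optional). The helper name and
# the random input are illustrative; feed a real spectrogram for a meaningful
# check.
def verify_extract_onnx(onnx_path, source_len=1024):
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    dummy = np.random.randn(1, source_len, 513).astype(np.float32)
    (embedding,) = session.run(['tone_embedding'], {'input': dummy})
    # ref_enc yields a (1, 256) embedding for the default config; the clone
    # model expects (1, 256, 1), so add a trailing axis before reusing it.
    print(f"tone_embedding shape: {embedding.shape}")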


if __name__ == "__main__":
    TARGET_AUDIO_LENS = [1024]
    SOURCE_AUDIO_LENS = [1024]

    export_models(
        ckpt_path="checkpoints_v2/converter",
        output_dir="onnx_models",
        target_audio_lens=TARGET_AUDIO_LENS,
        source_audio_lens=SOURCE_AUDIO_LENS
    )