# OpenVoice-RKNN2 / export_onnx.py
import os

import torch
import torch.nn as nn
from openvoice.api import ToneColorConverter

# run relative to this file's directory so the relative checkpoint/output paths resolve
os.chdir(os.path.dirname(os.path.abspath(__file__)))

class ToneColorExtractWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio):
        # audio: [1, source_audio_len, 513] mel spectrogram, already in the
        # [batch, time, bins] layout that ref_enc expects (no transpose needed here)
        audio = audio.contiguous()
        # extract the tone color (speaker) embedding
        g = self.model.ref_enc(audio)
        # g = g.unsqueeze(-1)  # optionally expand the last dim to [1, 256, 1]
        return g

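
# Sketch (an assumption, mirroring ToneColorConverter.extract_se in the OpenVoice
# sources): how the [1, source_audio_len, 513] input above could be produced from a
# waveform. `hps` would be converter.hps; the 513 bins come from filter_length=1024
# (1024 // 2 + 1) in the v2 converter config.
def compute_mel_input(wav, hps):
    """wav: 1-D float32 waveform tensor -> [1, T, 513] spectrogram for ref_enc."""
    from openvoice.mel_processing import spectrogram_torch
    spec = spectrogram_torch(
        wav.unsqueeze(0),  # [1, num_samples]
        hps.data.filter_length,
        hps.data.sampling_rate,
        hps.data.hop_length,
        hps.data.win_length,
        center=False,
    )  # [1, 513, T]
    return spec.transpose(1, 2).contiguous()  # [1, T, 513]
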
class ToneCloneWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, audio, audio_lengths, src_tone, dest_tone, tau):
        # make sure the tensors are contiguous
        audio = audio.contiguous()
        src_tone = src_tone.contiguous()
        dest_tone = dest_tone.contiguous()
        # voice conversion: re-render `audio` with the target speaker's tone color
        o_hat, _, _ = self.model.voice_conversion(
            audio,
            audio_lengths,
            sid_src=src_tone,
            sid_tgt=dest_tone,
            tau=tau[0]
        )
        return o_hat

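
# Sketch (an assumption, not part of the original script): run both wrappers once in
# plain PyTorch with the same dummy shapes used below, to confirm the graphs trace
# cleanly before spending time on the ONNX export.
def sanity_check_wrappers(converter, source_len=1024, target_len=1024):
    extract = ToneColorExtractWrapper(converter.model).eval()
    clone = ToneCloneWrapper(converter.model).eval()
    with torch.no_grad():
        g = extract(torch.randn(1, source_len, 513))
        print("tone embedding:", g.shape)  # [1, 256] for the v2 converter
        o = clone(
            torch.randn(1, 513, target_len),
            torch.LongTensor([target_len]),
            torch.randn(1, 256, 1),
            torch.randn(1, 256, 1),
            torch.FloatTensor([0.3]),
        )
        print("converted audio:", o.shape)  # [1, 1, num_samples]
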
def export_models(ckpt_path, output_dir, target_audio_lens, source_audio_lens):
    """
    Export the tone color extraction and cloning models to ONNX.
    Args:
        ckpt_path: path to the checkpoint directory (config.json + checkpoint.pth)
        output_dir: output directory
        target_audio_lens: list of target (to-be-converted) audio lengths, in spectrogram frames
        source_audio_lens: list of source (reference) audio lengths, in spectrogram frames
    """
    # load the converter model
    device = "cpu"
    converter = ToneColorConverter(f'{ckpt_path}/config.json', device=device)
    converter.load_ckpt(f'{ckpt_path}/checkpoint.pth')
    # create the output directory
    os.makedirs(output_dir, exist_ok=True)
    # export the tone color extraction model
    extract_wrapper = ToneColorExtractWrapper(converter.model)
    extract_wrapper.eval()
    for source_len in source_audio_lens:
        dummy_input = torch.randn(1, source_len, 513).contiguous()
        # note: the file name is fixed, so with more than one source length each
        # iteration overwrites the previous file; the dynamic axis declared below
        # already covers other lengths at runtime
        output_path = f"{output_dir}/tone_color_extract_model.onnx"
        torch.onnx.export(
            extract_wrapper,
            dummy_input,
            output_path,
            input_names=['input'],
            output_names=['tone_embedding'],
            dynamic_axes={
                'input': {1: 'source_audio_len'},
            },
            opset_version=11,
            do_constant_folding=True,
            verbose=True
        )
        print(f"Exported tone extract model to {output_path}")
    # export the tone cloning (voice conversion) model
    clone_wrapper = ToneCloneWrapper(converter.model)
    clone_wrapper.eval()
    for target_len in target_audio_lens:
        dummy_inputs = (
            torch.randn(1, 513, target_len).contiguous(),  # audio
            torch.LongTensor([target_len]),                # audio_lengths
            torch.randn(1, 256, 1).contiguous(),           # src_tone
            torch.randn(1, 256, 1).contiguous(),           # dest_tone
            torch.FloatTensor([0.3])                       # tau
        )
        # as above, a single fixed file name: the last target length wins
        output_path = f"{output_dir}/tone_clone_model.onnx"
        torch.onnx.export(
            clone_wrapper,
            dummy_inputs,
            output_path,
            input_names=['audio', 'audio_length', 'src_tone', 'dest_tone', 'tau'],
            output_names=['converted_audio'],
            dynamic_axes={
                'audio': {2: 'target_audio_len'},
            },
            opset_version=17,
            do_constant_folding=True,
            verbose=True
        )
        print(f"Exported tone clone model to {output_path}")

if __name__ == "__main__":
    # example usage
    TARGET_AUDIO_LENS = [1024]  # set the target lengths as needed
    SOURCE_AUDIO_LENS = [1024]  # set the source lengths as needed
    export_models(
        ckpt_path="checkpoints_v2/converter",
        output_dir="onnx_models",
        target_audio_lens=TARGET_AUDIO_LENS,
        source_audio_lens=SOURCE_AUDIO_LENS
    )
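
# Possible next step on this repo's target hardware (a sketch assuming rknn-toolkit2
# and an RK3588 board; not part of this script). RKNN generally requires static
# shapes, so the dynamic axes declared above get fixed at conversion time:
#
#   from rknn.api import RKNN
#   rknn = RKNN()
#   rknn.config(target_platform="rk3588")
#   rknn.load_onnx(model="onnx_models/tone_clone_model.onnx")
#   rknn.build(do_quantization=False)
#   rknn.export_rknn("onnx_models/tone_clone_model.rknn")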