|
from typing import Callable

import json

import os

import time

import numpy as np

import onnxruntime as ort

from rknnlite.api import RKNNLite
|
|
|
class HParams: |
|
def __init__(self, **kwargs): |
|
for k, v in kwargs.items(): |
|
            if isinstance(v, dict):

                v = HParams(**v)
|
self[k] = v |
|
|
|
def keys(self): |
|
return self.__dict__.keys() |
|
|
|
def items(self): |
|
return self.__dict__.items() |
|
|
|
def values(self): |
|
return self.__dict__.values() |
|
|
|
def __len__(self): |
|
return len(self.__dict__) |
|
|
|
def __getitem__(self, key): |
|
return getattr(self, key) |
|
|
|
def __setitem__(self, key, value): |
|
return setattr(self, key, value) |
|
|
|
def __contains__(self, key): |
|
return key in self.__dict__ |
|
|
|
def __repr__(self): |
|
return self.__dict__.__repr__() |
|
|
|
@staticmethod |
|
def load_from_file(file_path:str): |
|
if not os.path.exists(file_path): |
|
raise FileNotFoundError(f"Can not found the configuration file \"{file_path}\"") |
|
with open(file_path, "r", encoding="utf-8") as f: |
|
hps = json.load(f) |
|
return HParams(**hps) |
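
# Illustration (hypothetical values): HParams wraps nested dicts so entries are
# reachable both as attributes and as items, e.g.
#   hps = HParams(**{"data": {"sampling_rate": 22050}})
#   hps.data.sampling_rate        # -> 22050
#   hps["data"]["sampling_rate"]  # -> 22050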
|
|
|
class BaseClassForOnnxInfer(): |
|
@staticmethod |
|
def create_onnx_infer(infer_factor:Callable, onnx_model_path:str, providers:list, session_options:ort.SessionOptions = None, onnx_params:dict = None): |
|
if not os.path.exists(onnx_model_path): |
|
raise FileNotFoundError(f"Can not found the onnx model file \"{onnx_model_path}\"") |
|
session = ort.InferenceSession(onnx_model_path, sess_options=BaseClassForOnnxInfer.adjust_onnx_session_options(session_options), providers=providers, **(onnx_params or {})) |
|
fn = infer_factor(session) |
|
        fn.__session = session  # keep the session alive on the callable (name-mangled to _BaseClassForOnnxInfer__session)
|
return fn |
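
    # Sketch of the factory pattern used by create_onnx_infer (hypothetical
    # model path): infer_factor receives the live session and returns the
    # callable that is actually used for inference, e.g.
    #   def factory(session):
    #       return lambda **kw: session.run(None, kw)[0]
    #   infer = BaseClassForOnnxInfer.create_onnx_infer(
    #       factory, "model.onnx", ["CPUExecutionProvider"])
    #   out = infer(input=some_float32_array)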
|
|
|
@staticmethod |
|
def get_def_onnx_session_options(): |
|
session_options = ort.SessionOptions() |
|
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL |
|
return session_options |
|
|
|
@staticmethod |
|
def adjust_onnx_session_options(session_options:ort.SessionOptions = None): |
|
return session_options or BaseClassForOnnxInfer.get_def_onnx_session_options() |
|
|
|
class OpenVoiceToneClone_ONNXRKNN(BaseClassForOnnxInfer): |
|
|
|
PreferredProviders = ['CPUExecutionProvider'] |
|
|
|
def __init__(self, model_path, execution_provider:str = None, verbose:bool = False, onnx_session_options:ort.SessionOptions = None, onnx_params:dict = None, target_length:int = 1024): |
|
''' |
|
Create the instance of the tone cloner |
|
|
|
Args: |
|
            model_path (str): The path of the folder that contains the model files

            execution_provider (str): The execution provider for onnxruntime, e.g. CPUExecutionProvider or CUDAExecutionProvider; short forms such as CPU or CUDA are also accepted. If None, the constructor chooses the best available provider automatically

            verbose (bool): Set to True to print more detailed information while working
|
onnx_session_options (onnxruntime.SessionOptions): The custom options for onnx session |
|
onnx_params (dict): Other parameters you want to pass to the onnxruntime.InferenceSession constructor |
|
target_length (int): The target length for padding/truncating spectrogram, defaults to 1024 |
|
|
|
Returns: |
|
            OpenVoiceToneClone_ONNXRKNN: The instance of the tone cloner
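
        Example (illustrative; assumes a model folder containing configuration.json,
        tone_color_extract_model.onnx and tone_clone_model.rknn):
            cloner = OpenVoiceToneClone_ONNXRKNN("./model_dir", verbose=True)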
|
''' |
|
self.__verbose = verbose |
|
self.__target_length = target_length |
|
|
|
if verbose: |
|
print("Loading the configuration...") |
|
config_path = os.path.join(model_path, "configuration.json") |
|
self.__hparams = HParams.load_from_file(config_path) |
|
|
|
execution_provider = f"{execution_provider}ExecutionProvider" if (execution_provider is not None) and (not execution_provider.endswith("ExecutionProvider")) else execution_provider |
|
available_providers = ort.get_available_providers() |
|
|
|
self.__execution_providers = ['CPUExecutionProvider'] |
|
if verbose: |
|
print("Creating onnx session for tone color extractor...") |
|
def se_infer_factor(session): |
|
return lambda **kwargs: session.run(None, kwargs)[0] |
|
self.__se_infer = self.create_onnx_infer(se_infer_factor, os.path.join(model_path, "tone_color_extract_model.onnx"), self.__execution_providers, onnx_session_options, onnx_params) |
|
|
|
if verbose: |
|
print("Creating RKNNLite session for tone clone ...") |
|
|
|
self.__tc_rknn = RKNNLite(verbose=verbose) |
|
|
|
ret = self.__tc_rknn.load_rknn(os.path.join(model_path, "tone_clone_model.rknn")) |
|
if ret != 0: |
|
raise RuntimeError("Failed to load RKNN model") |
|
|
|
ret = self.__tc_rknn.init_runtime() |
|
if ret != 0: |
|
raise RuntimeError("Failed to init RKNN runtime") |
|
|
|
def __del__(self): |
|
"""释放RKNN资源""" |
|
if hasattr(self, '_OpenVoiceToneClone_ONNXRKNN__tc_rknn'): |
|
self.__tc_rknn.release() |
|
|
|
    hann_window = {}  # class-level cache of Hann windows, keyed by "<dtype>-<win_size>"
|
|
|
def __spectrogram_numpy(self, y, n_fft, sampling_rate, hop_size, win_size, onesided=True): |
|
if self.__verbose: |
|
if np.min(y) < -1.1: |
|
print("min value is ", np.min(y)) |
|
if np.max(y) > 1.1: |
|
print("max value is ", np.max(y)) |
|
|
|
|
|
y = np.pad( |
|
y, |
|
int((n_fft - hop_size) / 2), |
|
mode="reflect", |
|
) |
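
        # Reflect-pad both ends by (n_fft - hop_size) / 2 so the framing loop
        # below produces centered frames covering the whole signal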
|
|
|
|
|
win_key = f"{str(y.dtype)}-{win_size}" |
|
        if win_key not in OpenVoiceToneClone_ONNXRKNN.hann_window:
|
OpenVoiceToneClone_ONNXRKNN.hann_window[win_key] = np.hanning(win_size + 1)[:-1].astype(y.dtype) |
|
window = OpenVoiceToneClone_ONNXRKNN.hann_window[win_key] |
|
|
|
|
|
y_len = y.shape[0] |
|
win_len = window.shape[0] |
|
        # Number of complete frames that fit: floor((y_len - win_len) / hop_size) + 1
        count = int((y_len - win_len) // hop_size) + 1
|
        spec = np.empty((count, win_len // 2 + 1 if onesided else win_len, 2))
|
start = 0 |
|
end = start + win_len |
|
idx = 0 |
|
while end <= y_len: |
|
segment = y[start:end] |
|
frame = segment * window |
|
step_result = np.fft.rfft(frame) if onesided else np.fft.fft(frame) |
|
spec[idx] = np.column_stack((step_result.real, step_result.imag)) |
|
start = start + hop_size |
|
end = start + win_len |
|
idx += 1 |
|
|
|
|
|
        # Magnitude spectrum: sqrt(real^2 + imag^2), with a small epsilon for numerical stability
        spec = np.sqrt(np.sum(np.square(spec), axis=-1) + 1e-6)
|
|
|
return np.array([spec], dtype=np.float32) |
|
|
|
    def extract_tone_color(self, audio:np.ndarray):
|
''' |
|
Extract the tone color from an audio |
|
|
|
Args: |
|
            audio (numpy.ndarray): The audio data
|
|
|
Returns: |
|
            numpy.ndarray: The tone color vector
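
        Example (illustrative):
            color = cloner.extract_tone_color(mono_audio)   # color.shape == (1, 256, 1)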
|
''' |
|
hps = self.__hparams |
|
y = self.to_mono(audio.astype(np.float32)) |
|
spec = self.__spectrogram_numpy(y, hps.data.filter_length, |
|
hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, |
|
) |
|
|
|
if self.__verbose: |
|
print("spec shape", spec.shape) |
|
return self.__se_infer(input=spec).reshape(1,256,1) |
|
|
|
def mix_tone_color(self, colors:list): |
|
''' |
|
        Mix multiple tone colors into a single one
|
|
|
Args: |
|
            colors (list[numpy.ndarray]): The list of tone colors to mix. Each element should be a result of extract_tone_color.
|
|
|
Returns: |
|
            numpy.ndarray: The mixed tone color vector
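
        Example (illustrative):
            mixed = cloner.mix_tone_color([color_a, color_b])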
|
''' |
|
return np.stack(colors).mean(axis=0) |
|
|
|
    def tone_clone(self, audio:np.ndarray, target_tone_color:np.ndarray, tau=0.3):
|
''' |
|
Clone the tone |
|
|
|
Args: |
|
            audio (numpy.ndarray): The audio data whose tone will be changed
|
            target_tone_color (numpy.ndarray): The tone color to clone. It should be a result of extract_tone_color or mix_tone_color.
|
            tau (float): Temperature parameter forwarded to the conversion model; controls the strength of the conversion. Defaults to 0.3
|
|
|
Returns: |
|
            numpy.ndarray: The resulting audio data
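
        Example (illustrative):
            converted = cloner.tone_clone(src_audio, target_color, tau=0.3)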
|
''' |
|
assert (target_tone_color.shape == (1,256,1)), "The target tone color must be an array with shape (1,256,1)" |
|
hps = self.__hparams |
|
src = self.to_mono(audio.astype(np.float32)) |
|
src = self.__spectrogram_numpy(src, hps.data.filter_length, |
|
hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, |
|
) |
|
src_tone = self.__se_infer(input=src).reshape(1,256,1) |
|
|
|
        src = np.transpose(src, (0, 2, 1))  # (batch, frames, bins) -> (batch, bins, frames)
|
|
|
original_length = src.shape[2] |
|
|
|
|
|
if original_length > self.__target_length: |
|
if self.__verbose: |
|
print(f"Input length {original_length} exceeds target length {self.__target_length}, truncating...") |
|
src = src[:, :, :self.__target_length] |
|
elif original_length < self.__target_length: |
|
if self.__verbose: |
|
print(f"Input length {original_length} is less than target length {self.__target_length}, padding...") |
|
pad_width = ((0, 0), (0, 0), (0, self.__target_length - original_length)) |
|
src = np.pad(src, pad_width, mode='constant', constant_values=0) |
|
|
|
        src_length = np.array([self.__target_length], dtype=np.int64)  # the fixed-shape RKNN graph always sees the padded length
|
|
|
if self.__verbose: |
|
print("src shape", src.shape) |
|
print("src_length shape", src_length.shape) |
|
print("src_tone shape", src_tone.shape) |
|
print("target_tone_color shape", target_tone_color.shape) |
|
print("tau", tau) |
|
|
|
|
|
inputs = [ |
|
src, |
|
src_length, |
|
src_tone, |
|
target_tone_color, |
|
np.array([tau], dtype=np.float32) |
|
] |
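
        # The input order must match the order the RKNN model was exported with:
        # (source spectrogram, source length, source tone, target tone, tau)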
|
|
|
|
|
outputs = self.__tc_rknn.inference(inputs=inputs) |
|
res = outputs[0][0, 0] |
|
|
|
        # The fixed-shape model turns 1024 spectrogram frames into 262144 samples,
        # i.e. 256 output samples per input frame (presumably the hop length)
        generated_multiplier = 262144 / 1024
|
|
|
if original_length < self.__target_length: |
|
res = res[:int(original_length * generated_multiplier)] |
|
|
|
if self.__verbose: |
|
print("res shape", res.shape) |
|
return res |
|
|
|
    def to_mono(self, audio:np.ndarray):
|
''' |
|
        Convert the audio to mono
|
|
|
Args: |
|
            audio (numpy.ndarray): The source audio
|
|
|
Returns: |
|
            numpy.ndarray: The mono audio data
|
''' |
|
        return np.mean(audio, axis=1) if audio.ndim > 1 else audio
|
|
|
    def resample(self, audio:np.ndarray, original_rate:int):
|
''' |
|
        Resample the audio to the model's sample rate.
|
|
|
Args: |
|
            audio (numpy.ndarray): The source audio to resample.
|
original_rate (int): The original sample rate of the source audio |
|
|
|
Returns: |
|
            numpy.ndarray: The resampled audio data
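
        Example (illustrative; 44100 is the source file's rate):
            resampled = cloner.resample(raw_audio, 44100)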
|
''' |
|
audio = self.to_mono(audio) |
|
target_rate = self.__hparams.data.sampling_rate |
|
duration = audio.shape[0] / original_rate |
|
target_length = int(duration * target_rate) |
|
time_original = np.linspace(0, duration, num=audio.shape[0]) |
|
time_target = np.linspace(0, duration, num=target_length) |
|
        resampled_data = np.interp(time_target, time_original, audio)  # simple linear interpolation
|
return resampled_data |
|
|
|
@property |
|
def sample_rate(self): |
|
''' |
|
The sample rate of the tone cloning result |
|
''' |
|
return self.__hparams.data.sampling_rate |
|
|
|
|
|
tc = OpenVoiceToneClone_ONNXRKNN(".",verbose=True) |
|
import soundfile |
|
|
|
tgt = soundfile.read("target.wav", dtype='float32') |
|
tgt = tc.resample(tgt[0], tgt[1]) |
|
|
|
|
|
start_time = time.time() |
|
tgt_tone_color = tc.extract_tone_color(tgt) |
|
extract_time = time.time() - start_time |
|
print(f"提取音色特征耗时: {extract_time:.2f}秒") |
|
|
|
src = soundfile.read("src2.wav", dtype='float32') |
|
src = tc.resample(src[0], src[1]) |
|
|
|
|
|
start_time = time.time() |
|
result = tc.tone_clone(src, tgt_tone_color) |
|
clone_time = time.time() - start_time |
|
print(f"克隆音色耗时: {clone_time:.2f}秒") |
|
|
|
soundfile.write("result.wav", result, tc.sample_rate) |
|
|