# Video-to-Audio-and-Piano/src/inference_v2p.py
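# Usage (inferred from the argv parsing below; argument names are descriptive,
# not official):
#   python inference_v2p.py <ckpt> <drop_prompt:0|1> <test_scp> <start> <end> <out_dir>
# Run without arguments to fall back to the hard-coded defaults further down.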
import sys

# Note: sys.argv[6] is read below, so 7 entries (script name + 6 args) are required.
if len(sys.argv) >= 7:
    ckpt = sys.argv[1]
    drop_prompt = bool(int(sys.argv[2]))
    test_scp = sys.argv[3]
    start = int(sys.argv[4])
    end = int(sys.argv[5])
    step = 1
    out_dir = sys.argv[6]
    print("inference", ckpt, drop_prompt, test_scp, start, end, out_dir)
else:
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more/98500.pt"
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more/190000.pt"
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more/315000.pt"
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more_more/60000.pt"
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more_more_piano5/4_2_8000.pt"
    ckpt = "./ckpts/piano5_4_2_8000.pt"
    #ckpt = "/ailab-train/speech/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more_more_piano6/dpo_100.pt"
    drop_prompt = False
    #test_scp = "/ailab-train/speech/zhanghaomin/scps/VGGSound/test.scp"
    #test_scp = "/ailab-train/speech/zhanghaomin/scps/instruments/test.scp"
    #test_scp = "/ailab-train/speech/zhanghaomin/scps/instruments/piano_2h/test.scp"
    test_scp = "./tests/piano_2h_test.scp"
    #test_scp = "/ailab-train/speech/zhanghaomin/scps/instruments/piano_20h/v2a_giant_piano2/test.scp"
    start = 0
    end = 2
    step = 1
    #out_dir = "./outputs_vgg/"
    out_dir = "./outputs_piano/"
    #out_dir = "./outputs2t_20h_dpo/"
import os
import json
import traceback
from datetime import datetime

import numpy as np
import torch
from torch.utils.data import DataLoader
from moviepy.editor import VideoFileClip, AudioFileClip

from e2_tts_pytorch.e2_tts_crossatt3 import E2TTS, DurationPredictor, EncodecWrapper
from e2_tts_pytorch.trainer_multigpus_alldatas3 import Text2AudioDataset
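# Classifier-free-guidance drop probabilities. The sentinel values act as hard
# switches at inference time: a probability > 1.0 means the corresponding
# condition is always dropped, a probability < 0.0 means it is always kept.
# The same values also gate module construction below via the `< 1.0` checks.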
audiocond_drop_prob = 1.1
#audiocond_drop_prob = 0.3
#cond_proj_in_bias = True
#cond_drop_prob = 1.1
cond_drop_prob = -0.1
prompt_drop_prob = -0.1
#prompt_drop_prob = 1.1
video_text = True
def main():
    #duration_predictor = DurationPredictor(
    #    transformer = dict(
    #        dim = 512,
    #        depth = 6,
    #    )
    #)
    duration_predictor = None
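    # With video_text=True the model conditions on CLIP-ViT video features
    # (if_clip_encoder) rather than embedding text directly (if_embed_text is
    # False below), while the caption prompt enters through cross-attention.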
    e2tts = E2TTS(
        duration_predictor = duration_predictor,
        transformer = dict(
            #depth = 12,
            #dim = 512,
            #heads = 8,
            #dim_head = 64,
            depth = 12,
            dim = 1024,
            dim_text = 1280,
            heads = 16,
            dim_head = 64,
            if_text_modules = (cond_drop_prob < 1.0),
            if_cross_attn = (prompt_drop_prob < 1.0),
            if_audio_conv = True,
            if_text_conv = True,
        ),
        #tokenizer = 'char_utf8',
        tokenizer = 'phoneme_zh',
        audiocond_drop_prob = audiocond_drop_prob,
        cond_drop_prob = cond_drop_prob,
        prompt_drop_prob = prompt_drop_prob,
        frac_lengths_mask = (0.7, 1.0),
        #audiocond_snr = None,
        #audiocond_snr = (5.0, 10.0),
        if_cond_proj_in = (audiocond_drop_prob < 1.0),
        #cond_proj_in_bias = cond_proj_in_bias,
        if_embed_text = (cond_drop_prob < 1.0) and (not video_text),
        if_text_encoder2 = (prompt_drop_prob < 1.0),
        if_clip_encoder = video_text,
        video_encoder = "clip_vit",
        pretrained_vocos_path = 'facebook/encodec_24khz',
        num_channels = 128,
        sampling_rate = 24000,
    )
    e2tts = e2tts.to("cuda")

    #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec/3000.pt", map_location="cpu")
    #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more/500.pt", map_location="cpu")
    #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more/98500.pt", map_location="cpu")
    #checkpoint = torch.load("/ckptstorage/zhanghaomin/e2/e2_tts_experiment_v2a_encodec_more_more_more/190000.pt", map_location="cpu")
    checkpoint = torch.load(ckpt, map_location="cpu")
    #for key in list(checkpoint['model_state_dict'].keys()):
    #    if key.startswith('mel_spec.'):
    #        del checkpoint['model_state_dict'][key]
    #    if key.startswith('transformer.text_registers'):
    #        del checkpoint['model_state_dict'][key]
    e2tts.load_state_dict(checkpoint['model_state_dict'], strict=False)

    # Swap in a frozen Encodec decoder as the vocoder.
    e2tts.vocos = EncodecWrapper("facebook/encodec_24khz")
    for param in e2tts.vocos.parameters():
        param.requires_grad = False
    e2tts.vocos.eval()
    e2tts.vocos.to("cuda")
    #dataset = HFDataset(load_dataset("parquet", data_files={"test": "/ckptstorage/zhanghaomin/tts/GLOBE/data/test-*.parquet"})["test"])
    #sample = dataset[1]
    #mel_spec_raw = sample["mel_spec"].unsqueeze(0)
    #mel_spec = rearrange(mel_spec_raw, 'b d n -> b n d')
    #print(mel_spec.shape, sample["text"])

    #out_dir = "/user-fs/zhanghaomin/v2a_generated/v2a_190000_tests/"
    #out_dir = "/user-fs/zhanghaomin/v2a_generated/tv2a_98500_clips/"
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # Optional batch whitelist; None processes every batch.
    #bs = list(range(10)) + [14,16]
    bs = None
    # Per-dataset score thresholds (as used in training). Entries at or below
    # -9000 mean "no filtering" and are rewritten to -inf; the string-valued
    # "audioset" entry is skipped.
    SCORE_THRESHOLD_TRAIN = '{"/zhanghaomin/datas/audiocaps": -9999.0, "/radiostorage/WavCaps": -9999.0, "/radiostorage/AudioGroup": 9999.0, "/ckptstorage/zhanghaomin/audioset": -9999.0, "/ckptstorage/zhanghaomin/BBCSoundEffects": 9999.0, "/ckptstorage/zhanghaomin/CLAP_freesound": 9999.0, "/zhanghaomin/datas/musiccap": -9999.0, "/ckptstorage/zhanghaomin/TangoPromptBank": -9999.0, "audioset": "af-audioset", "/ckptstorage/zhanghaomin/audiosetsl": 9999.0, "/ckptstorage/zhanghaomin/giantsoundeffects": -9999.0}'  # /root/datasets/ /radiostorage/
    SCORE_THRESHOLD_TRAIN = json.loads(SCORE_THRESHOLD_TRAIN)
    for key in SCORE_THRESHOLD_TRAIN:
        if key == "audioset":
            continue
        if SCORE_THRESHOLD_TRAIN[key] <= -9000.0:
            SCORE_THRESHOLD_TRAIN[key] = -np.inf
    print("SCORE_THRESHOLD_TRAIN", SCORE_THRESHOLD_TRAIN)
    stft = EncodecWrapper("facebook/encodec_24khz")
    eval_dataset = Text2AudioDataset(None, "val_instruments", None, None, None, -1, -1, stft, 0, True, SCORE_THRESHOLD_TRAIN, "/zhanghaomin/codes2/audiocaption/msclapcap_v1.list", -1.0, 1, 1, [drop_prompt], None, 0, vgg_test=[test_scp, start, end, step], video_encoder="clip_vit")
    #eval_dataset = Text2AudioDataset(None, "val_vggsound", None, None, None, -1, -1, stft, 0, True, SCORE_THRESHOLD_TRAIN, "/zhanghaomin/codes2/audiocaption/msclapcap_v1.list", -1.0, 1, 1, [drop_prompt], None, 0, vgg_test=[test_scp, start, end, step], video_encoder="clip_vit")
    eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=1, collate_fn=eval_dataset.collate_fn, num_workers=1, drop_last=False, pin_memory=True)
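    # Main loop: for each clip in the scp slice, sample audio conditioned on
    # the video (frames / MIDI) and save it as a wav alongside the video copy.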
    i = 0
    for b, batch in enumerate(eval_dataloader):
        if (bs is not None) and (b not in bs):
            continue
        #text, mel_spec, _, mel_lengths = batch
        text, mel_spec, video_paths, mel_lengths, video_drop_prompt, audio_drop_prompt, frames, midis = batch
        print(mel_spec.shape, mel_lengths, text, video_paths, video_drop_prompt, audio_drop_prompt, frames.shape if frames is not None and not isinstance(frames, float) else frames, midis.shape if midis is not None else midis, midis.sum() if midis is not None else midis)

        # batch_size is 1, so select the single sample (i == 0).
        text = text[i:i+1]
        mel_spec = mel_spec[i:i+1, 0:mel_lengths[i], :]
        mel_lengths = mel_lengths[i:i+1]
        video_paths = video_paths[i:i+1]
        video_path = out_dir + video_paths[0].replace("/", "__")
        audio_path = video_path.replace(".mp4", ".wav")
        name = video_paths[0].rsplit("/", 1)[1].rsplit(".", 1)[0]

        num = 1
        l = mel_lengths[0]
        # The audio condition is random noise here; it is dropped anyway
        # because audiocond_drop_prob > 1.0.
        #cond = mel_spec.repeat(num, 1, 1)
        cond = torch.randn(num, l, e2tts.num_channels)
        duration = torch.tensor([l]*num, dtype=torch.int32)
        lens = torch.tensor([l]*num, dtype=torch.int32)

        print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "start")
        #e2tts.sample(text=[""]*num, duration=duration.to("cuda"), lens=lens.to("cuda"), cond=cond.to("cuda"), save_to_filename="test.wav", steps=16, cfg_strength=3.0, remove_parallel_component=False, sway_sampling=True)
        e2tts.sample(text=None, duration=duration.to("cuda"), lens=lens.to("cuda"), cond=cond.to("cuda"), save_to_filename=audio_path, steps=64, prompt=text*num, video_drop_prompt=video_drop_prompt, audio_drop_prompt=audio_drop_prompt, cfg_strength=2.0, remove_parallel_component=False, sway_sampling=True, video_paths=video_paths, frames=(frames if frames is None or isinstance(frames, float) else frames.to("cuda")), midis=(midis if midis is None else midis.to("cuda")))
        print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "sample")
        #one_audio = e2tts.vocos.decode(mel_spec_raw.to("cuda"))
        #one_audio = e2tts.vocos.decode(cond.transpose(-1,-2).to("cuda"))
        #print(datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3], "vocoder")
        #torchaudio.save("ref.wav", one_audio.detach().cpu(), sample_rate = e2tts.sampling_rate)
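        # Mux the generated wav into a copy of the source video, trimming
        # whichever stream is longer so the durations match.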
        try:
            os.system("cp \"" + video_paths[0] + "\" \"" + video_path + "\"")
            video = VideoFileClip(video_path)
            audio = AudioFileClip(audio_path)
            print("duration", video.duration, audio.duration)
            if video.duration >= audio.duration:
                video = video.subclip(0, audio.duration)
            else:
                audio = audio.subclip(0, video.duration)
            final_video = video.set_audio(audio)
            final_video.write_videofile(video_path.replace(".mp4", ".v2a.mp4"), codec="libx264", audio_codec="aac")
            print("\"" + video_path.replace(".mp4", ".v2a.mp4") + "\"")
        except Exception:
            print("Exception write_videofile:", video_path.replace(".mp4", ".v2a.mp4"))
            traceback.print_exc()
        # Disabled: export paired ground-truth / generated audio for evaluation.
        if False:
            if not os.path.exists(out_dir+"groundtruth/"):
                os.makedirs(out_dir+"groundtruth/")
            if not os.path.exists(out_dir+"generated/"):
                os.makedirs(out_dir+"generated/")
            duration_gt = video.duration
            duration_gr = final_video.duration
            duration = min(duration_gt, duration_gr)
            audio_gt = video.audio.subclip(0, duration)
            audio_gr = final_video.audio.subclip(0, duration)
            audio_gt.write_audiofile(out_dir+"groundtruth/"+name+".wav", fps=24000)
            audio_gr.write_audiofile(out_dir+"generated/"+name+".wav", fps=24000)
if __name__ == "__main__":
    main()