# keysync-demo / utils.py
import torchvision
from einops import rearrange
import numpy as np
import math
import torchaudio
import torch
import importlib
from data_utils import create_masks_from_landmarks_box
import torch.nn.functional as F
def save_audio_video(
video,
audio=None,
frame_rate=25,
sample_rate=16000,
save_path="temp.mp4",
):
"""Save audio and video to a single file.
video: (t, c, h, w)
audio: (channels t)
"""
save_path = str(save_path)
if isinstance(video, torch.Tensor):
video = video.cpu().numpy()
video_tensor = rearrange(video, "t c h w -> t h w c").astype(np.uint8)
print("video_tensor shape", video_tensor.shape)
print("audio shape", audio.shape)
if audio is not None:
# Assuming audio is a tensor of shape (channels, samples)
audio_tensor = audio
torchvision.io.write_video(
save_path,
video_tensor,
fps=frame_rate,
audio_array=audio_tensor,
audio_fps=sample_rate,
video_codec="h264", # Specify a codec to address the error
audio_codec="aac",
)
else:
torchvision.io.write_video(
save_path,
video_tensor,
fps=frame_rate,
video_codec="h264", # Specify a codec to address the error
audio_codec="aac",
)
return save_path
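# Usage sketch for save_audio_video (illustrative values only): a 2-second, 25-fps uint8 clip
# with 16 kHz mono audio.
# video = torch.randint(0, 256, (50, 3, 256, 256), dtype=torch.uint8)  # (t, c, h, w)
# audio = torch.randn(1, 32000)  # (channels, samples)
# save_audio_video(video, audio=audio, frame_rate=25, sample_rate=16000, save_path="demo.mp4")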
def trim_pad_audio(audio, sr, max_len_sec=None, max_len_raw=None):
len_file = audio.shape[-1]
if max_len_sec or max_len_raw:
max_len = max_len_raw if max_len_raw is not None else int(max_len_sec * sr)
        if len_file < int(max_len):
            # Pad the waveform with trailing zeros up to max_len samples
            extended_wav = torch.nn.functional.pad(
                audio, (0, int(max_len) - len_file), "constant"
            )
        else:
            extended_wav = audio[:, : int(max_len)]
    else:
        extended_wav = audio
    return extended_wav
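# Usage sketch for trim_pad_audio (illustrative): pad or trim a (1, samples) waveform to exactly
# 2 seconds at 16 kHz.
# wav = torch.randn(1, 20000)
# wav = trim_pad_audio(wav, sr=16000, max_len_sec=2)  # -> shape (1, 32000)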
def get_raw_audio(audio_path, audio_rate, fps=25):
audio, sr = torchaudio.load(audio_path, channels_first=True)
if audio.shape[0] > 1:
audio = audio.mean(0, keepdim=True)
audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=audio_rate)[0]
samples_per_frame = math.ceil(audio_rate / fps)
n_frames = audio.shape[-1] / samples_per_frame
if not n_frames.is_integer():
audio = trim_pad_audio(
audio, audio_rate, max_len_raw=math.ceil(n_frames) * samples_per_frame
)
audio = rearrange(audio, "(f s) -> f s", s=samples_per_frame)
return audio
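# Usage sketch for get_raw_audio (illustrative): load a wav and group samples per 25-fps video frame.
# raw = get_raw_audio("speech.wav", audio_rate=16000, fps=25)
# -> shape (num_frames, 640), since ceil(16000 / 25) == 640 samples per frame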
def calculate_splits(tensor, min_last_size):
    # Length along the split dimension (dim=1)
    total_size = tensor.size(1)
# If total size is less than the minimum size for the last split, return the tensor as a single split
if total_size <= min_last_size:
return [tensor]
# Calculate number of splits and size of each split
num_splits = (total_size - min_last_size) // min_last_size + 1
base_size = (total_size - min_last_size) // num_splits
# Create split sizes list
split_sizes = [base_size] * (num_splits - 1)
split_sizes.append(
total_size - sum(split_sizes)
) # Ensure the last split has at least min_last_size
# Adjust sizes to ensure they sum exactly to total_size
sum_sizes = sum(split_sizes)
while sum_sizes != total_size:
for i in range(num_splits):
if sum_sizes < total_size:
split_sizes[i] += 1
sum_sizes += 1
if sum_sizes >= total_size:
break
# Split the tensor
splits = torch.split(tensor, split_sizes, dim=1)
return splits
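# Usage sketch for calculate_splits (illustrative): split a (1, 10) tensor so every chunk,
# including the last one, has at least min_last_size columns (here: widths 3 and 7).
# chunks = calculate_splits(torch.arange(10).unsqueeze(0), min_last_size=4)
# print([c.shape[1] for c in chunks])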
def make_into_multiple_of(x, multiple, dim=0):
"""Make the torch tensor into a multiple of the given number."""
if x.shape[dim] % multiple != 0:
x = torch.cat(
[
x,
                torch.zeros(
                    *x.shape[:dim],
                    multiple - (x.shape[dim] % multiple),
                    *x.shape[dim + 1 :],
                    dtype=x.dtype,
                    device=x.device,
                ),
],
dim=dim,
)
return x
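# Usage sketch for make_into_multiple_of (illustrative): pad a length-10 tensor up to the next
# multiple of 4 along dim 0.
# y = make_into_multiple_of(torch.ones(10, 3), multiple=4, dim=0)  # -> shape (12, 3), last 2 rows zero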
def default(value, default_value):
return default_value if value is None else value
def instantiate_from_config(config):
if not "target" in config:
if config == "__is_first_stage__":
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
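# Usage sketch for instantiate_from_config (illustrative config using a standard torch class):
# config = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
# layer = instantiate_from_config(config)  # equivalent to torch.nn.Linear(in_features=8, out_features=4)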
def get_obj_from_str(string, reload=False, invalidate_cache=True):
module, cls = string.rsplit(".", 1)
if invalidate_cache:
importlib.invalidate_caches()
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
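# Usage sketch for get_obj_from_str: resolve a dotted import path to the object it names.
# Linear = get_obj_from_str("torch.nn.Linear")
# layer = Linear(8, 4)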
def load_landmarks(
landmarks: np.ndarray,
original_size,
target_size=(64, 64),
nose_index=28,
):
"""
Load and process facial landmarks to create masks.
Args:
landmarks: Facial landmarks array
original_size: Original size of the video frames
index: Index for non-dub mode
target_size: Target size for the output mask
is_dub: Whether this is for dubbing mode
what_mask: Type of mask to create ("full", "box", "heart", "mouth")
nose_index: Index of the nose landmark
Returns:
Processed landmarks mask
"""
expand_box = 0.0
if len(landmarks.shape) == 2:
landmarks = landmarks[None, ...]
mask = create_masks_from_landmarks_box(
landmarks,
(original_size[0], original_size[1]),
box_expand=expand_box,
nose_index=nose_index,
)
mask = F.interpolate(mask.unsqueeze(1).float(), size=target_size, mode="nearest")
return mask
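# Usage sketch for load_landmarks (illustrative shapes): 68-point landmarks for one 512x512 frame,
# interpolated down to a 64x64 mask.
# landmarks = np.random.rand(68, 2) * 512
# mask = load_landmarks(landmarks, original_size=(512, 512), target_size=(64, 64))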
def create_pipeline_inputs(
audio: torch.Tensor,
audio_interpolation: torch.Tensor,
num_frames: int,
video_emb: torch.Tensor,
landmarks: np.ndarray,
overlap: int = 1,
add_zero_flag: bool = False,
    mask_arms: np.ndarray = None,
nose_index: int = 28,
):
"""
Create inputs for the keyframe generation and interpolation pipeline.
Args:
video: Input video tensor
audio: Audio embeddings for keyframe generation
audio_interpolation: Audio embeddings for interpolation
num_frames: Number of frames per segment
video_emb: Optional video embeddings
landmarks: Facial landmarks for mask generation
overlap: Number of frames to overlap between segments
add_zero_flag: Whether to add zero flag every num_frames
what_mask: Type of mask to generate ("box" or other options)
mask_arms: Optional mask for arms region
nose_index: Index of the nose landmark point
Returns:
Tuple containing all necessary inputs for the pipeline
"""
audio_interpolation_chunks = []
audio_image_preds = []
gt_chunks = []
gt_keyframes_chunks = []
# Adjustment for overlap to ensure segments are created properly
step = num_frames - overlap
# Ensure there's at least one step forward on each iteration
if step < 1:
step = 1
audio_image_preds_idx = []
audio_interp_preds_idx = []
masks_chunks = []
masks_interpolation_chunks = []
for i in range(0, audio.shape[0] - num_frames + 1, step):
try:
audio[i + num_frames - 1]
except IndexError:
break # Last chunk is smaller than num_frames
segment_end = i + num_frames
gt_chunks.append(video_emb[i:segment_end])
masks = load_landmarks(
landmarks[i:segment_end],
(512, 512),
target_size=(64, 64),
nose_index=nose_index,
)
if mask_arms is not None:
masks = np.logical_and(
masks, np.logical_not(mask_arms[i:segment_end, None, ...])
)
masks_interpolation_chunks.append(masks)
if i not in audio_image_preds_idx:
audio_image_preds.append(audio[i])
masks_chunks.append(masks[0])
gt_keyframes_chunks.append(video_emb[i])
audio_image_preds_idx.append(i)
if segment_end - 1 not in audio_image_preds_idx:
audio_image_preds_idx.append(segment_end - 1)
audio_image_preds.append(audio[segment_end - 1])
masks_chunks.append(masks[-1])
gt_keyframes_chunks.append(video_emb[segment_end - 1])
audio_interpolation_chunks.append(audio_interpolation[i:segment_end])
audio_interp_preds_idx.append([i, segment_end - 1])
    # If the flag is on, insert the first element every num_frames audio elements
if add_zero_flag:
first_element = audio_image_preds[0]
len_audio_image_preds = (
len(audio_image_preds) + (len(audio_image_preds) + 1) % num_frames
)
for i in range(0, len_audio_image_preds, num_frames):
audio_image_preds.insert(i, first_element)
audio_image_preds_idx.insert(i, None)
masks_chunks.insert(i, masks_chunks[0])
gt_keyframes_chunks.insert(i, gt_keyframes_chunks[0])
to_remove = [idx is None for idx in audio_image_preds_idx]
    audio_image_preds_idx_clone = list(audio_image_preds_idx)
if add_zero_flag:
# Remove the added elements from the list
audio_image_preds_idx = [
sample for i, sample in zip(to_remove, audio_image_preds_idx) if not i
]
interpolation_cond_list = []
for i in range(0, len(audio_image_preds_idx) - 1, overlap if overlap > 0 else 2):
interpolation_cond_list.append(
[audio_image_preds_idx[i], audio_image_preds_idx[i + 1]]
)
# Since we generate num_frames at a time, we need to ensure that the last chunk is of size num_frames
# Calculate the number of frames needed to make audio_image_preds a multiple of num_frames
frames_needed = (num_frames - (len(audio_image_preds) % num_frames)) % num_frames
    # Pad the end of audio_image_preds by repeating the last element
audio_image_preds = audio_image_preds + [audio_image_preds[-1]] * frames_needed
masks_chunks = masks_chunks + [masks_chunks[-1]] * frames_needed
gt_keyframes_chunks = (
gt_keyframes_chunks + [gt_keyframes_chunks[-1]] * frames_needed
)
to_remove = to_remove + [True] * frames_needed
audio_image_preds_idx_clone = (
audio_image_preds_idx_clone + [audio_image_preds_idx_clone[-1]] * frames_needed
)
    print(
        f"Appended {frames_needed} frames to make audio_image_preds a multiple of {num_frames}"
    )
# random_cond_idx = np.random.randint(0, len(video_emb))
random_cond_idx = 0
assert len(to_remove) == len(audio_image_preds), (
"to_remove and audio_image_preds must have the same length"
)
return (
gt_chunks,
gt_keyframes_chunks,
audio_interpolation_chunks,
audio_image_preds,
video_emb[random_cond_idx],
masks_chunks,
masks_interpolation_chunks,
to_remove,
audio_interp_preds_idx,
audio_image_preds_idx_clone,
)
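# Usage sketch for create_pipeline_inputs (illustrative shapes; embedding dims are placeholders):
# audio_emb = torch.randn(100, 768)          # per-frame audio embeddings for keyframe generation
# audio_interp_emb = torch.randn(100, 768)   # per-frame audio embeddings for interpolation
# video_emb = torch.randn(100, 4, 64, 64)    # per-frame video latents
# lmks = np.random.rand(100, 68, 2) * 512    # per-frame facial landmarks
# outputs = create_pipeline_inputs(
#     audio_emb, audio_interp_emb, num_frames=14, video_emb=video_emb, landmarks=lmks
# )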