import spaces
import gradio as gr
import torch
import tempfile
import os
from vae_wrapper import VaeWrapper, encode_video_chunk
from landmarks_extractor import LandmarksExtractor
import decord
from utils import (
    get_raw_audio,
    save_audio_video,
    calculate_splits,
    instantiate_from_config,
    create_pipeline_inputs,
)
from transformers import HubertModel
from einops import rearrange
import numpy as np
from WavLM import WavLM_wrapper
from omegaconf import OmegaConf
from inference_functions import (
    sample_keyframes,
    sample_interpolation,
)
from wordle_game import WordleGame
import torch.cuda.amp as amp  # Import amp for mixed precision
from huggingface_hub import snapshot_download

# Define the repository ID
repo_id = "toninio19/keysync"

# Download the entire repository
repo_path = snapshot_download(repo_id=repo_id)

print(f"Repository downloaded to: {repo_path}")

if torch.cuda.is_available():
    # torch.set_default_tensor_type(torch.cuda.FloatTensor)
    # Enable TF32 precision for better performance on Ampere+ GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Cache for video and audio processing (keyed by input file path)
cache = {
    "video": {
        "path": None,
        "embedding": None,
        "frames": None,
        "landmarks": None,
    },
    "audio": {
        "path": None,
        "raw_audio": None,
        "hubert_embedding": None,
        "wavlm_embedding": None,
    },
}

# Mixed precision scaler (only needed for training; unused at inference)
scaler = amp.GradScaler()


def load_model(
    config: str,
    device: str = "cuda",
    ckpt: str = None,
):
    """
    Load a model from configuration.

    Args:
        config: Path to the model configuration file
        device: Device to load the model on
        ckpt: Optional checkpoint path

    Returns:
        The instantiated model in eval mode (FP16 on CUDA, except the first-stage VAE)
    """
    config = OmegaConf.load(config)
    config["model"]["params"]["input_key"] = "latents"

    if ckpt is not None:
        config.model.params.ckpt_path = ckpt

    with torch.device(device):
        model = instantiate_from_config(config.model).to(device).eval()

        # Convert model to half precision, keeping the first-stage VAE in FP32
        if torch.cuda.is_available():
            model = model.half()
            model.first_stage_model = model.first_stage_model.float()
            print("Converted model to FP16 precision")

        # Compile model for faster inference
        if torch.cuda.is_available():
            try:
                model = torch.compile(model)
                print("Successfully compiled model with torch.compile()")
            except Exception as e:
                print(f"Warning: Failed to compile model: {e}")

    return model


# Default media paths
DEFAULT_VIDEO_PATH = os.path.join(
    os.path.dirname(__file__), "assets", "sample_video.mp4"
)
DEFAULT_AUDIO_PATH = os.path.join(
    os.path.dirname(__file__), "assets", "sample_audio.wav"
)
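# The heavy models are created lazily inside process_video() (see the global None
# placeholders further down) rather than at import time, presumably so that CUDA
# initialization happens inside the @spaces.GPU-decorated call on Spaces. The block
# below is the original eager loader, kept commented out for reference.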
# @spaces.GPU(duration=60)
# def load_all_models():
#     global \
#         keyframe_model, \
#         interpolation_model, \
#         vae_model, \
#         hubert_model, \
#         wavlm_model, \
#         landmarks_extractor
#     vae_model = VaeWrapper("video")
#     vae_model = vae_model.half()  # Convert to half precision
#     try:
#         vae_model = torch.compile(vae_model)
#         print("Successfully compiled vae_model in FP16")
#     except Exception as e:
#         print(f"Warning: Failed to compile vae_model: {e}")

#     hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
#     hubert_model = hubert_model.half()  # Convert to half precision
#     try:
#         hubert_model = torch.compile(hubert_model)
#         print("Successfully compiled hubert_model in FP16")
#     except Exception as e:
#         print(f"Warning: Failed to compile hubert_model: {e}")

#     wavlm_model = WavLM_wrapper(
#         model_size="Base+",
#         feed_as_frames=False,
#         merge_type="None",
#         model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
#     ).cuda()
#     wavlm_model = wavlm_model.half()  # Convert to half precision
#     try:
#         wavlm_model = torch.compile(wavlm_model)
#         print("Successfully compiled wavlm_model in FP16")
#     except Exception as e:
#         print(f"Warning: Failed to compile wavlm_model: {e}")

#     landmarks_extractor = LandmarksExtractor()
#     keyframe_model = load_model(
#         config="keyframe.yaml",
#         ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
#     )
#     interpolation_model = load_model(
#         config="interpolation.yaml",
#         ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
#     )
#     keyframe_model.en_and_decode_n_samples_a_time = 2
#     interpolation_model.en_and_decode_n_samples_a_time = 2
#     return (
#         keyframe_model,
#         interpolation_model,
#         vae_model,
#         hubert_model,
#         wavlm_model,
#         landmarks_extractor,
#     )


# (
#     keyframe_model,
#     interpolation_model,
#     vae_model,
#     hubert_model,
#     wavlm_model,
#     landmarks_extractor,
# ) = load_all_models()

keyframe_model = None
interpolation_model = None
vae_model = None
hubert_model = None
wavlm_model = None
landmarks_extractor = None


@spaces.GPU(duration=60)
@torch.no_grad()
def compute_video_embedding(video_reader, min_len, vae_model):
    """Compute VAE latent embeddings for the video frames"""
    total_frames = min_len

    encoded = []
    video_frames = []
    chunk_size = 16
    resolution = 512

    # Create a progress bar for Gradio
    progress = gr.Progress()

    # Calculate total chunks for progress tracking
    total_chunks = (total_frames + chunk_size - 1) // chunk_size

    for i, start_idx in enumerate(range(0, total_frames, chunk_size)):
        # Update progress bar
        progress(i / total_chunks, desc="Processing video chunks")

        end_idx = min(start_idx + chunk_size, total_frames)
        video_chunk = video_reader.get_batch(range(start_idx, end_idx))

        # Interpolate video chunk to the target resolution
        video_chunk = rearrange(video_chunk, "f h w c -> f c h w")
        video_chunk = torch.nn.functional.interpolate(
            video_chunk,
            size=(resolution, resolution),
            mode="bilinear",
            align_corners=False,
        )
        video_chunk = rearrange(video_chunk, "f c h w -> f h w c")

        video_frames.append(video_chunk)

        # Convert chunk to FP16 if using CUDA
        if torch.cuda.is_available():
            video_chunk = video_chunk.half()

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            encoded.append(encode_video_chunk(vae_model, video_chunk, resolution))

    encoded = torch.cat(encoded, dim=0)
    video_frames = torch.cat(video_frames, dim=0)
    video_frames = rearrange(video_frames, "f h w c -> f c h w")
    torch.cuda.empty_cache()
    return encoded, video_frames


@spaces.GPU(duration=60)
@torch.no_grad()
def compute_hubert_embedding(raw_audio, hubert_model):
    """Compute HuBERT embeddings from the raw audio waveform"""
    print(f"Computing audio embedding from {raw_audio.shape}")

    audio = (
        (raw_audio - raw_audio.mean()) / torch.sqrt(raw_audio.var() + 1e-7)
    ).unsqueeze(0)
    chunks = 16000 * 20

    # Create a progress bar for Gradio
    progress = gr.Progress()

    # Get audio embeddings
    audio_embeddings = []
    splits = list(calculate_splits(audio, chunks))
    total_splits = len(splits)

    for i, chunk in enumerate(splits):
        # Update progress bar
        progress(i / total_splits, desc="Processing audio chunks")

        # Convert audio chunk to half precision
        if torch.cuda.is_available():
            chunk_cuda = chunk.cuda().half()
        else:
            chunk_cuda = chunk.cuda()

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            hidden_states = hubert_model(chunk_cuda)[0]

        audio_embeddings.append(hidden_states)

    audio_embeddings = torch.cat(audio_embeddings, dim=1)
    # audio_embeddings = self.model.wav2vec2(rearrange(audio_frames, "f s -> () (f s)"))[0]
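    # HuBERT produces one feature vector per 20 ms of 16 kHz audio (50 per second),
    # while the video is assumed to run at 25 fps, so features are grouped in pairs
    # (d=2) to give one (2, C) audio block per video frame. The zero-padding below
    # makes the feature count even so the pairing rearrange works.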
    if audio_embeddings.shape[1] % 2 != 0:
        audio_embeddings = torch.cat(
            [audio_embeddings, torch.zeros_like(audio_embeddings[:, :1])], dim=1
        )
    audio_embeddings = rearrange(audio_embeddings, "() (f d) c -> f d c", d=2)
    torch.cuda.empty_cache()
    return audio_embeddings


@spaces.GPU(duration=60)
@torch.no_grad()
def compute_wavlm_embedding(raw_audio, wavlm_model):
    """Compute WavLM embeddings from the raw audio (640 samples per video frame)"""
    audio = rearrange(raw_audio, "(f s) -> f s", s=640)

    if audio.shape[0] % 2 != 0:
        audio = torch.cat([audio, torch.zeros(1, 640)], dim=0)
    chunks = 500

    # Create a progress bar for Gradio
    progress = gr.Progress()

    # Get audio embeddings
    audio_embeddings = []
    splits = list(calculate_splits(audio, chunks))
    total_splits = len(splits)

    for i, chunk in enumerate(splits):
        # Update progress bar
        progress(i / total_splits, desc="Processing audio chunks")

        # Convert chunk to half precision
        if torch.cuda.is_available():
            chunk_cuda = chunk.unsqueeze(0).cuda().half()
        else:
            chunk_cuda = chunk.unsqueeze(0).cuda()

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            wavlm_hidden_states = wavlm_model(chunk_cuda).squeeze(0)

        audio_embeddings.append(wavlm_hidden_states)

    audio_embeddings = torch.cat(audio_embeddings, dim=0)
    torch.cuda.empty_cache()
    return audio_embeddings


@torch.no_grad()
def extract_video_landmarks(video_frames, landmarks_extractor):
    """Extract facial landmarks from video frames"""
    # Create a progress bar for Gradio
    progress = gr.Progress()

    landmarks = []
    batch_size = 10

    for i in range(0, len(video_frames), batch_size):
        # Update progress bar
        progress(i / len(video_frames), desc="Extracting facial landmarks")

        batch = video_frames[i : i + batch_size].cpu().float()
        batch_landmarks = landmarks_extractor.extract_landmarks(batch)
        landmarks.extend(batch_landmarks)

    torch.cuda.empty_cache()

    # Convert landmarks to a list of numpy arrays with consistent shape
    processed_landmarks = []
    expected_shape = (68, 2)  # Common shape for facial landmarks

    # Process each landmark to ensure consistent shape
    last_valid_landmark = None
    for i, lm in enumerate(landmarks):
        if lm is not None and isinstance(lm, np.ndarray) and lm.shape == expected_shape:
            processed_landmarks.append(lm)
            last_valid_landmark = lm
        else:
            # Print information about inconsistent landmarks
            if lm is None:
                print(f"Warning: Landmark at index {i} is None")
            elif not isinstance(lm, np.ndarray):
                print(
                    f"Warning: Landmark at index {i} is not a numpy array, type: {type(lm)}"
                )
            elif lm.shape != expected_shape:
                print(
                    f"Warning: Landmark at index {i} has shape {lm.shape}, expected {expected_shape}"
                )

            # Replace invalid landmarks with the closest valid landmark if available
            if last_valid_landmark is not None:
                processed_landmarks.append(last_valid_landmark.copy())
            else:
                # If no valid landmark has been seen yet, look ahead for a valid one
                found_future_valid = False
                for future_lm in landmarks[i + 1 :]:
                    if (
                        future_lm is not None
                        and isinstance(future_lm, np.ndarray)
                        and future_lm.shape == expected_shape
                    ):
                        processed_landmarks.append(future_lm.copy())
                        found_future_valid = True
                        break

                # If no valid landmark found in the future, use zeros
                if not found_future_valid:
                    processed_landmarks.append(np.zeros(expected_shape))

    return np.array(processed_landmarks)
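# Generation runs in two stages over short chunks of the sequence: the keyframe model
# first produces sparse keyframe latents conditioned on the audio, and the
# interpolation model then fills in the frames between consecutive keyframes.
# Consecutive chunks share a boundary keyframe so the stitched video stays continuous.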
@spaces.GPU(duration=180)
@torch.no_grad()
def sample(
    audio_list,
    gt_keyframes,
    masks_keyframes,
    to_remove,
    test_keyframes_list,
    num_frames,
    device,
    emb,
    force_uc_zero_embeddings,
    n_batch_keyframes,
    n_batch,
    test_interpolation_list,
    audio_interpolation_list,
    masks_interpolation,
    gt_interpolation,
    model_keyframes,
    model,
):
    # Create a progress bar for Gradio
    progress = gr.Progress()

    condition = torch.zeros(1, 3, 512, 512).to(device)
    if torch.cuda.is_available():
        condition = condition.half()

    audio_list = rearrange(audio_list, "(b t) c d -> b t c d", t=num_frames)
    gt_keyframes = rearrange(gt_keyframes, "(b t) c h w -> b t c h w", t=num_frames)
    # Rearrange masks_keyframes and save locally
    masks_keyframes = rearrange(
        masks_keyframes, "(b t) c h w -> b t c h w", t=num_frames
    )

    # Convert to_remove into chunks of num_frames
    to_remove_chunks = [
        to_remove[i : i + num_frames] for i in range(0, len(to_remove), num_frames)
    ]
    test_keyframes_list = [
        test_keyframes_list[i : i + num_frames]
        for i in range(0, len(test_keyframes_list), num_frames)
    ]

    audio_cond = audio_list
    if emb is not None:
        embeddings = emb.unsqueeze(0).to(device)
        if torch.cuda.is_available():
            embeddings = embeddings.half()
    else:
        embeddings = None

    # One batch of keyframes is approximately 7 seconds
    chunk_size = 2
    complete_video = []
    start_idx = 0
    last_frame_z = None
    last_frame_x = None
    last_keyframe_idx = None
    last_to_remove = None

    total_chunks = (len(audio_cond) + chunk_size - 1) // chunk_size

    for chunk_idx, chunk_start in enumerate(range(0, len(audio_cond), chunk_size)):
        # Update progress bar
        progress(chunk_idx / total_chunks, desc="Generating video")

        # Clear GPU cache between chunks
        torch.cuda.empty_cache()

        chunk_end = min(chunk_start + chunk_size, len(audio_cond))

        chunk_audio_cond = audio_cond[chunk_start:chunk_end].cuda()
        if torch.cuda.is_available():
            chunk_audio_cond = chunk_audio_cond.half()

        chunk_gt_keyframes = gt_keyframes[chunk_start:chunk_end].cuda()
        chunk_masks = masks_keyframes[chunk_start:chunk_end].cuda()
        if torch.cuda.is_available():
            chunk_gt_keyframes = chunk_gt_keyframes.half()
            chunk_masks = chunk_masks.half()

        test_keyframes_list_unwrapped = [
            elem
            for sublist in test_keyframes_list[chunk_start:chunk_end]
            for elem in sublist
        ]
        to_remove_chunks_unwrapped = [
            elem
            for sublist in to_remove_chunks[chunk_start:chunk_end]
            for elem in sublist
        ]

        if last_keyframe_idx is not None:
            test_keyframes_list_unwrapped = [
                last_keyframe_idx
            ] + test_keyframes_list_unwrapped
            to_remove_chunks_unwrapped = [last_to_remove] + to_remove_chunks_unwrapped

        last_keyframe_idx = test_keyframes_list_unwrapped[-1]
        last_to_remove = to_remove_chunks_unwrapped[-1]
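        # Map this chunk's keyframes to the corresponding span of interpolation
        # segments: each entry of test_interpolation_list appears to pair two
        # keyframe indices, so the span runs from the segment starting at the
        # chunk's first keyframe to the segment ending at its last keyframe.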
        # Find the first non-None keyframe in the chunk
        first_keyframe = next(
            (kf for kf in test_keyframes_list_unwrapped if kf is not None), None
        )

        # Find the last non-None keyframe in the chunk
        last_keyframe = next(
            (kf for kf in reversed(test_keyframes_list_unwrapped) if kf is not None),
            None,
        )

        start_idx = next(
            (
                idx
                for idx, comb in enumerate(test_interpolation_list)
                if comb[0] == first_keyframe
            ),
            None,
        )
        end_idx = next(
            (
                idx
                for idx, comb in enumerate(reversed(test_interpolation_list))
                if comb[1] == last_keyframe
            ),
            None,
        )

        if start_idx is not None and end_idx is not None:
            end_idx = (
                len(test_interpolation_list) - 1 - end_idx
            )  # Adjust for reversed enumeration
            end_idx += 1

        if start_idx is None:
            break

        if end_idx < start_idx:
            end_idx = len(audio_interpolation_list)

        audio_interpolation_list_chunk = audio_interpolation_list[start_idx:end_idx]
        chunk_masks_interpolation = masks_interpolation[start_idx:end_idx]
        gt_interpolation_chunks = gt_interpolation[start_idx:end_idx]

        if torch.cuda.is_available():
            audio_interpolation_list_chunk = [
                chunk.half() for chunk in audio_interpolation_list_chunk
            ]
            chunk_masks_interpolation = [
                chunk.half() for chunk in chunk_masks_interpolation
            ]
            gt_interpolation_chunks = [
                chunk.half() for chunk in gt_interpolation_chunks
            ]

        progress(chunk_idx / total_chunks, desc="Generating keyframes")

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            samples_z = sample_keyframes(
                model_keyframes,
                chunk_audio_cond,
                chunk_gt_keyframes,
                chunk_masks,
                condition.cuda(),
                num_frames,
                24,
                0.0,
                device,
                embeddings.cuda() if embeddings is not None else None,
                force_uc_zero_embeddings,
                n_batch_keyframes,
                0,
                1.0,
                None,
                gt_as_cond=False,
            )

        if last_frame_x is not None:
            # samples_x = torch.cat([last_frame_x.unsqueeze(0), samples_x], axis=0)
            samples_z = torch.cat([last_frame_z.unsqueeze(0), samples_z], axis=0)

        # last_frame_x = samples_x[-1]
        last_frame_z = samples_z[-1]

        progress(chunk_idx / total_chunks, desc="Interpolating frames")

        # Always use autocast for FP16 computation
        with amp.autocast(enabled=True):
            vid = sample_interpolation(
                model,
                samples_z,
                # samples_x,
                audio_interpolation_list_chunk,
                gt_interpolation_chunks,
                chunk_masks_interpolation,
                condition.cuda(),
                num_frames,
                device,
                1,
                24,
                0.0,
                force_uc_zero_embeddings,
                n_batch,
                chunk_size,
                1.0,
                None,
                cut_audio=False,
                to_remove=to_remove_chunks_unwrapped,
            )

        if chunk_start == 0:
            complete_video = vid
        else:
            complete_video = np.concatenate([complete_video[:-1], vid], axis=0)

    return complete_video
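# process_video() is the end-to-end pipeline behind the "Generate Synchronized Video"
# button: load the models on first use, reuse cached computations when the same file
# is submitted again, compute VAE / HuBERT / WavLM embeddings and facial landmarks,
# build the pipeline inputs, run the two-stage sampler above, and finally mux the
# generated frames with the input audio.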
@spaces.GPU(duration=180)
@torch.no_grad()
def process_video(video_input, audio_input):
    """Main processing function to generate synchronized video"""
    # Display a message to the user about the processing time
    gr.Info("Processing video. This may take a while...", duration=10)
    gr.Info(
        "If you're tired of waiting, try playing the Wordle game in the other tab!",
        duration=10,
    )

    max_num_seconds = 6

    global \
        vae_model, \
        hubert_model, \
        wavlm_model, \
        landmarks_extractor, \
        keyframe_model, \
        interpolation_model

    if vae_model is None:
        vae_model = VaeWrapper("video")
        vae_model = vae_model.half()  # Convert to half precision
        try:
            vae_model = torch.compile(vae_model)
            print("Successfully compiled vae_model in FP16")
        except Exception as e:
            print(f"Warning: Failed to compile vae_model: {e}")

    if hubert_model is None:
        hubert_model = HubertModel.from_pretrained("facebook/hubert-base-ls960").cuda()
        hubert_model = hubert_model.half()  # Convert to half precision
        try:
            hubert_model = torch.compile(hubert_model)
            print("Successfully compiled hubert_model in FP16")
        except Exception as e:
            print(f"Warning: Failed to compile hubert_model: {e}")

    if wavlm_model is None:
        wavlm_model = WavLM_wrapper(
            model_size="Base+",
            feed_as_frames=False,
            merge_type="None",
            model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
        ).cuda()
        wavlm_model = wavlm_model.half()  # Convert to half precision
        try:
            wavlm_model = torch.compile(wavlm_model)
            print("Successfully compiled wavlm_model in FP16")
        except Exception as e:
            print(f"Warning: Failed to compile wavlm_model: {e}")

    if landmarks_extractor is None:
        landmarks_extractor = LandmarksExtractor()

    if keyframe_model is None:
        keyframe_model = load_model(
            config="keyframe.yaml",
            ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
        )

    if interpolation_model is None:
        interpolation_model = load_model(
            config="interpolation.yaml",
            ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
        )

    keyframe_model.en_and_decode_n_samples_a_time = 2
    interpolation_model.en_and_decode_n_samples_a_time = 2

    # Use default media if none provided
    if video_input is None:
        video_input = DEFAULT_VIDEO_PATH
        print(f"Using default video: {DEFAULT_VIDEO_PATH}")
    if audio_input is None:
        audio_input = DEFAULT_AUDIO_PATH
        print(f"Using default audio: {DEFAULT_AUDIO_PATH}")

    # try:
    # Calculate hashes for cache keys
    video_path_hash = video_input
    audio_path_hash = audio_input

    # Check if we need to recompute video embeddings
    video_cache_hit = cache["video"]["path"] == video_path_hash
    audio_cache_hit = cache["audio"]["path"] == audio_path_hash

    if video_cache_hit and audio_cache_hit:
        print("Using cached video and audio computations")
        # Make copies of cached data to avoid modifying cache
        video_embedding = cache["video"]["embedding"].clone()
        video_frames = cache["video"]["frames"].clone()
        video_landmarks = cache["video"]["landmarks"].copy()
        raw_audio = cache["audio"]["raw_audio"].clone()
        raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
        hubert_embedding = cache["audio"]["hubert_embedding"].clone()
        wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()

        # Ensure all data is truncated to the same length if needed
        min_len = min(
            len(video_frames),
            len(raw_audio),
            len(hubert_embedding),
            len(wavlm_embedding),
        )
        video_frames = video_frames[:min_len]
        video_embedding = video_embedding[:min_len]
        video_landmarks = video_landmarks[:min_len]
        raw_audio = raw_audio[:min_len]
        hubert_embedding = hubert_embedding[:min_len]
        wavlm_embedding = wavlm_embedding[:min_len]
        raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
    else:
        # Process video if needed
        if not video_cache_hit:
            print("Computing video embeddings and landmarks")
            video_reader = decord.VideoReader(video_input)
            decord.bridge.set_bridge("torch")

            if not audio_cache_hit:
                # Need to process audio to determine min_len
                raw_audio = get_raw_audio(audio_input, 16000)
                if len(raw_audio) == 0 or len(video_reader) == 0:
                    raise ValueError("Empty audio or video input")

                min_len = min(len(raw_audio), len(video_reader))

                # Store full audio in cache
                cache["audio"]["path"] = audio_path_hash
                cache["audio"]["raw_audio"] = raw_audio.clone()

                # Create truncated copy for processing
                raw_audio = raw_audio[:min_len]
                raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
            else:
                # Use cached audio - make a copy
                if cache["audio"]["raw_audio"] is None:
                    raise ValueError("Cached audio is None")
                raw_audio = cache["audio"]["raw_audio"].clone()
                if len(raw_audio) == 0 or len(video_reader) == 0:
                    raise ValueError("Empty cached audio or video input")

                min_len = min(len(raw_audio), len(video_reader))

                # Create truncated copy for processing
                raw_audio = raw_audio[:min_len]
                raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")

            # Compute video embeddings and landmarks - store full version in cache
            video_embedding, video_frames = compute_video_embedding(
                video_reader, len(video_reader), vae_model
            )
            video_landmarks = extract_video_landmarks(video_frames, landmarks_extractor)

            # Update video cache with full versions
            cache["video"]["path"] = video_path_hash
            cache["video"]["embedding"] = video_embedding
            cache["video"]["frames"] = video_frames
            cache["video"]["landmarks"] = video_landmarks

            # Create truncated copies for processing
            video_embedding = video_embedding[:min_len]
            video_frames = video_frames[:min_len]
            video_landmarks = video_landmarks[:min_len]
        else:
            # Use cached video data - make copies
            print("Using cached video computations")
            if (
                cache["video"]["embedding"] is None
                or cache["video"]["frames"] is None
                or cache["video"]["landmarks"] is None
            ):
                raise ValueError("One or more video cache entries are None")

            if not audio_cache_hit:
                # New audio with cached video
                raw_audio = get_raw_audio(audio_input, 16000)
                if len(raw_audio) == 0:
                    raise ValueError("Empty audio input")

                # Store full audio in cache
                cache["audio"]["path"] = audio_path_hash
                cache["audio"]["raw_audio"] = raw_audio.clone()

                # Make copies of video data
                video_embedding = cache["video"]["embedding"].clone()
                video_frames = cache["video"]["frames"].clone()
                video_landmarks = cache["video"]["landmarks"].copy()

                # Determine truncation length and create truncated copies
                min_len = min(len(raw_audio), len(video_frames))
                raw_audio = raw_audio[:min_len]
                raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")
                video_frames = video_frames[:min_len]
                video_embedding = video_embedding[:min_len]
                video_landmarks = video_landmarks[:min_len]
            else:
                # Both video and audio are cached - should not reach here
                # as it's handled in the first if statement
                pass

        # Process audio if needed
        if not audio_cache_hit:
            print("Computing audio embeddings")
            # Compute audio embeddings with the truncated audio
            hubert_embedding = compute_hubert_embedding(raw_audio_reshape, hubert_model)
            wavlm_embedding = compute_wavlm_embedding(raw_audio_reshape, wavlm_model)

            # Update audio cache with full embeddings
            # Note: raw_audio was already cached above
            cache["audio"]["hubert_embedding"] = hubert_embedding.clone()
            cache["audio"]["wavlm_embedding"] = wavlm_embedding.clone()
        else:
            # Use cached audio data - make copies
            if (
                cache["audio"]["hubert_embedding"] is None
                or cache["audio"]["wavlm_embedding"] is None
            ):
                raise ValueError("One or more audio embedding cache entries are None")
            hubert_embedding = cache["audio"]["hubert_embedding"].clone()
            wavlm_embedding = cache["audio"]["wavlm_embedding"].clone()

            # Make sure embeddings match the truncated video length if needed
            if "min_len" in locals() and (
                min_len < len(hubert_embedding) or min_len < len(wavlm_embedding)
            ):
                hubert_embedding = hubert_embedding[:min_len]
                wavlm_embedding = wavlm_embedding[:min_len]
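    # The public demo caps the output length: 6 seconds at the assumed 25 fps equals
    # 150 frames, matching the limitation described in the UI text at the bottom.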
    # Apply max_num_seconds limit if specified
    if max_num_seconds > 0:
        # Convert seconds to frames (assuming 25 fps)
        max_frames = int(max_num_seconds * 25)

        # Truncate all data to max_frames
        video_embedding = video_embedding[:max_frames]
        video_frames = video_frames[:max_frames]
        video_landmarks = video_landmarks[:max_frames]
        hubert_embedding = hubert_embedding[:max_frames]
        wavlm_embedding = wavlm_embedding[:max_frames]
        raw_audio = raw_audio[:max_frames]
        raw_audio_reshape = rearrange(raw_audio, "f s -> (f s)")

    # Validate shapes before proceeding
    assert video_embedding.shape[0] == hubert_embedding.shape[0], (
        f"Video embedding length ({video_embedding.shape[0]}) doesn't match Hubert embedding length ({hubert_embedding.shape[0]})"
    )
    assert video_embedding.shape[0] == wavlm_embedding.shape[0], (
        f"Video embedding length ({video_embedding.shape[0]}) doesn't match WavLM embedding length ({wavlm_embedding.shape[0]})"
    )
    assert video_embedding.shape[0] == video_landmarks.shape[0], (
        f"Video embedding length ({video_embedding.shape[0]}) doesn't match landmarks length ({video_landmarks.shape[0]})"
    )

    print(f"Hubert embedding shape: {hubert_embedding.shape}")
    print(f"WavLM embedding shape: {wavlm_embedding.shape}")
    print(f"Video embedding shape: {video_embedding.shape}")
    print(f"Video landmarks shape: {video_landmarks.shape}")

    # Create pipeline inputs for models
    (
        interpolation_chunks,
        keyframe_chunks,
        audio_interpolation_chunks,
        audio_keyframe_chunks,
        emb_cond,
        masks_keyframe_chunks,
        masks_interpolation_chunks,
        to_remove,
        audio_interpolation_idx,
        audio_keyframe_idx,
    ) = create_pipeline_inputs(
        hubert_embedding,
        wavlm_embedding,
        14,
        video_embedding,
        video_landmarks,
        overlap=1,
        add_zero_flag=True,
        mask_arms=None,
        nose_index=28,
    )
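    # The positional arguments below follow the sample() signature defined above:
    # num_frames=14, device="cuda", emb=emb_cond, force_uc_zero_embeddings=[],
    # n_batch_keyframes=3 and n_batch=3, followed by the interpolation inputs and
    # the two models (keyframe and interpolation).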
    complete_video = sample(
        audio_keyframe_chunks,
        keyframe_chunks,
        masks_keyframe_chunks,
        to_remove,
        audio_keyframe_idx,
        14,
        "cuda",
        emb_cond,
        [],
        3,
        3,
        audio_interpolation_idx,
        audio_interpolation_chunks,
        masks_interpolation_chunks,
        interpolation_chunks,
        keyframe_model,
        interpolation_model,
    )

    complete_audio = rearrange(raw_audio[: complete_video.shape[0]], "f s -> () (f s)")

    # Convert frames to video and combine with audio
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_video:
        output_path = temp_video.name

    print("Saving video to", output_path)
    save_audio_video(complete_video, audio=complete_audio, save_path=output_path)
    torch.cuda.empty_cache()

    return output_path
    # except Exception as e:
    #     raise e
    #     print(f"Error processing video: {str(e)}")
    #     return None


# def get_max_duration(video_input, audio_input):
#     """Get the maximum duration in seconds for the slider"""
#     try:
#         # Default to 60 seconds if files don't exist
#         if video_input is None or not os.path.exists(video_input):
#             video_input = DEFAULT_VIDEO_PATH
#         if audio_input is None or not os.path.exists(audio_input):
#             audio_input = DEFAULT_AUDIO_PATH

#         # Get video duration
#         video_reader = decord.VideoReader(video_input)
#         video_duration = len(video_reader) / video_reader.get_avg_fps()

#         # Get audio duration
#         raw_audio = get_raw_audio(audio_input, 16000)
#         audio_duration = len(raw_audio) / 25  # Assuming 25 fps

#         # Return the minimum of the two durations
#         return min(video_duration, audio_duration)
#     except Exception as e:
#         print(f"Error getting max duration: {str(e)}")
#         return 60  # Default to 60 seconds


def new_game_click(state):
    """Handle the 'New Game' button click."""
    message = state.new_game()
    feedback_history = state.get_feedback_history()
    return state, feedback_history, message


def submit_guess_click(guess, state):
    """Handle the 'Submit Guess' button click."""
    message = state.submit_guess(guess)
    feedback_history = state.get_feedback_history()
    return state, feedback_history, message


# Create Gradio interface
with gr.Blocks(
    title="KeySync: A Robust Approach for Leakage-free Lip Synchronization in High Resolution"
) as demo:
    gr.Markdown(
        "# KeySync: A Robust Approach for Leakage-free Lip Synchronization in High Resolution"
    )
    gr.Markdown(
        "Upload a video and an audio track to generate a new video with the same visuals, lip-synced to the new audio."
    )
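    # Two tabs: the main lip-sync demo, plus a small Wordle game to pass the time
    # while a video is being generated (referenced by the gr.Info message shown
    # during processing).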
    with gr.Tabs():
        with gr.TabItem("Video Synchronization"):
            with gr.Row():
                with gr.Column():
                    video_input = gr.Video(
                        label="Input Video",
                        value=DEFAULT_VIDEO_PATH
                        if os.path.exists(DEFAULT_VIDEO_PATH)
                        else None,
                        width=512,
                        height=512,
                    )
                    audio_input = gr.Audio(
                        label="Input Audio",
                        type="filepath",
                        value=DEFAULT_AUDIO_PATH
                        if os.path.exists(DEFAULT_AUDIO_PATH)
                        else None,
                    )
                    # max_duration = gr.State(value=60)  # Default max duration
                    # max_seconds_slider = gr.Slider(
                    #     minimum=0,
                    #     maximum=60,  # Will be updated dynamically
                    #     value=0,
                    #     step=1,
                    #     label="Max Duration (seconds, 0 = full length)",
                    #     info="Limit the processing duration (0 means use full length)",
                    # )
                    process_button = gr.Button("Generate Synchronized Video")
                with gr.Column():
                    video_output = gr.Video(label="Output Video", width=512, height=512)

            # # Update slider max value when inputs change
            # def update_slider_max(video, audio):
            #     max_dur = get_max_duration(video, audio)
            #     return {"maximum": max_dur, "__type__": "update"}

            # video_input.change(
            #     update_slider_max, [video_input, audio_input], [max_seconds_slider]
            # )
            # audio_input.change(
            #     update_slider_max, [video_input, audio_input], [max_seconds_slider]
            # )

            # Generate the synchronized video when the button is clicked
            process_button.click(
                fn=process_video,
                inputs=[video_input, audio_input],
                outputs=video_output,
            )

        with gr.TabItem("Wordle Game"):
            state = gr.State(WordleGame())  # Persist the WordleGame instance
            guess_input = gr.Textbox(label="Your guess (5 letters)", max_length=5)
            submit_btn = gr.Button("Submit Guess")
            new_game_btn = gr.Button("New Game")
            feedback_display = gr.HTML(label="Guesses")
            message_display = gr.Textbox(
                label="Message", interactive=False, value="Click 'New Game' to start."
            )

            # Connect the 'New Game' button
            new_game_btn.click(
                fn=new_game_click,
                inputs=[state],
                outputs=[state, feedback_display, message_display],
            )

            # Connect the 'Submit Guess' button
            submit_btn.click(
                fn=submit_guess_click,
                inputs=[guess_input, state],
                outputs=[state, feedback_display, message_display],
            )

    gr.Markdown("## How it works")
    gr.Markdown("""
    1. The system extracts embeddings and landmarks from the input video
    2. Audio embeddings are computed from the input audio
    3. A keyframe model generates key visual frames
    4. An interpolation model creates a smooth video between keyframes
    5. The final video is rendered with the new audio
    """)
    gr.Markdown("""
    ## Limitations
    Due to GPU restrictions on Hugging Face Spaces, the demo is limited to processing videos of at most 6 seconds in length.
    For longer videos or better performance, we recommend using the inference scripts provided in this repository (https://github.com/antonibigata/keysync) to run KeySync locally on your own hardware.
    """)

if __name__ == "__main__":
    demo.launch()