Antoni Bigata committed
Commit 13f5b7a · 1 Parent(s): 892dd0f

add checkpoint download

Files changed (3):
  1. app.py +15 -5
  2. interpolation.yaml +152 -0
  3. keyframe.yaml +154 -0
app.py CHANGED
@@ -25,6 +25,16 @@ from inference_functions import (
 from wordle_game import WordleGame
 import torch.cuda.amp as amp  # Import amp for mixed precision
 
+from huggingface_hub import snapshot_download
+
+# Define the repository ID
+repo_id = "toninio19/keysync"
+
+# Download the entire repository
+repo_path = snapshot_download(repo_id=repo_id)
+
+print(f"Repository downloaded to: {repo_path}")
+
 
 # Set default tensor type to float16 for faster computation
 if torch.cuda.is_available():
@@ -136,7 +146,7 @@ def load_all_models():
         model_size="Base+",
         feed_as_frames=False,
         merge_type="None",
-        model_path="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/WavLM-Base+.pt",
+        model_path=os.path.join(repo_path, "checkpoints/WavLM-Base+.pt"),
     ).cuda()
 
     wavlm_model = wavlm_model.half()  # Convert to half precision
@@ -148,12 +158,12 @@ def load_all_models():
 
     landmarks_extractor = LandmarksExtractor()
     keyframe_model = load_model(
-        config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/keyframe.yaml",
-        ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/keyframe_dub.pt",
+        config="keyframe.yaml",
+        ckpt=os.path.join(repo_path, "checkpoints/keyframe_dub.pt"),
     )
     interpolation_model = load_model(
-        config="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/scripts/sampling/configs/interpolation.yaml",
-        ckpt="/vol/paramonos2/projects/antoni/code/Personal/code_prep/keysync/pretrained_models/checkpoints/interpolation_dub.pt",
+        config="interpolation.yaml",
+        ckpt=os.path.join(repo_path, "checkpoints/interpolation_dub.pt"),
     )
     keyframe_model.en_and_decode_n_samples_a_time = 2
     interpolation_model.en_and_decode_n_samples_a_time = 2
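Note (not part of this commit): snapshot_download pulls and caches the entire toninio19/keysync repository at Space startup. If only the weights are needed, a minimal sketch of a narrower call is below; it assumes the checkpoints sit under checkpoints/ in that repo, as the os.path.join calls above suggest, and that the installed huggingface_hub version supports allow_patterns.

# Sketch only: fetch just the *.pt checkpoints instead of the whole repository.
from huggingface_hub import snapshot_download

repo_path = snapshot_download(
    repo_id="toninio19/keysync",
    allow_patterns=["checkpoints/*.pt"],  # assumption: weights live under checkpoints/
)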
interpolation.yaml ADDED
@@ -0,0 +1,152 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    scale_factor: 0.18215
+    disable_first_stage_autocast: True
+    ckpt_path:
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DenoiserDub
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
+
+    network_wrapper:
+      target: sgm.modules.diffusionmodules.wrappers.InterpolationWrapper
+      params:
+        im_size: [512, 512] # USER: adapt this to your dataset
+        n_channels: 4
+        starting_mask_method: zeros
+        add_mask: True
+
+    network_config:
+      target: sgm.modules.diffusionmodules.video_model.VideoUNet
+      params:
+        adm_in_channels: 0
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 9
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2, 1]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        spatial_transformer_attn_type: softmax-xformers
+        extra_ff_mix_layer: True
+        use_spatial_context: True
+        merge_strategy: learned_with_images
+        video_kernel_size: [3, 1, 1]
+        fine_tuning_method: null
+        audio_cond_method: both_keyframes
+        additional_audio_frames: 0
+        audio_dim: 1024
+        unfreeze_blocks: ["input"]
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+
+        - input_key: cond_frames
+          is_trainable: False
+          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
+          params:
+            disable_encoder_autocast: True
+            n_cond_frames: 2
+            n_copies: 1
+            is_ae: True
+            encoder_config:
+              target: sgm.models.autoencoder.AutoencoderKLModeOnly
+              params:
+                embed_dim: 4
+                monitor: val/rec_loss
+                ddconfig:
+                  attn_type: vanilla-xformers
+                  double_z: True
+                  z_channels: 4
+                  resolution: 256
+                  in_channels: 3
+                  out_ch: 3
+                  ch: 128
+                  ch_mult: [1, 2, 4, 4]
+                  num_res_blocks: 2
+                  attn_resolutions: []
+                  dropout: 0.0
+                lossconfig:
+                  target: torch.nn.Identity
+
+        - input_key: gt # allows to use the ground truth as a condition
+          is_trainable: False
+          target: sgm.modules.encoders.modules.IdentityEncoder
+          params:
+            cond_type: gt
+
+        - input_key: audio_emb
+          is_trainable: True
+          target: sgm.modules.encoders.modules.WhisperAudioEmbedder
+          params:
+            merge_method: mean
+            linear_dim: 1024
+            cond_type: crossattn
+            audio_dim: 768
+
+        - input_key: masks
+          is_trainable: False
+          target: sgm.modules.encoders.modules.IdentityEncoder
+          params:
+            cond_type: masks
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencodingEngine
+      params:
+        loss_config:
+          target: torch.nn.Identity
+        regularizer_config:
+          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+        encoder_config:
+          target: sgm.modules.diffusionmodules.model.Encoder
+          params:
+            attn_type: vanilla
+            double_z: True
+            z_channels: 4
+            resolution: 256
+            in_channels: 3
+            out_ch: 3
+            ch: 128
+            ch_mult: [1, 2, 4, 4]
+            num_res_blocks: 2
+            attn_resolutions: []
+            dropout: 0.0
+        decoder_config:
+          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
+          params:
+            attn_type: vanilla
+            double_z: True
+            z_channels: 4
+            resolution: 256
+            in_channels: 3
+            out_ch: 3
+            ch: 128
+            ch_mult: [1, 2, 4, 4]
+            num_res_blocks: 2
+            attn_resolutions: []
+            dropout: 0.0
+            video_kernel_size: [3, 1, 1]
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 10
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.AYSDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.AudioRefMultiCondGuider
+          params:
+            audio_ratio: 5.0
+            ref_ratio: 2.0
+
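For quick reference, this interpolation config and the keyframe.yaml added below share the same backbone and differ mainly in the network wrapper (InterpolationWrapper vs. DubbingWrapper), the UNet in_channels (9 vs. 8), and the condition dropout (ucg_rate) that only the keyframe conditioners use. An optional sketch to list every differing field, assuming both YAML files are in the working directory and omegaconf is installed:

# Sketch: recursively report fields that differ between the two sampling configs.
from omegaconf import OmegaConf

interp = OmegaConf.to_container(OmegaConf.load("interpolation.yaml"))
keyf = OmegaConf.to_container(OmegaConf.load("keyframe.yaml"))

def report_diffs(a, b, path="model"):
    # Walk nested dicts in parallel; lists and scalars are compared wholesale.
    if isinstance(a, dict) and isinstance(b, dict):
        for key in sorted(set(a) | set(b)):
            report_diffs(a.get(key), b.get(key), f"{path}.{key}")
    elif a != b:
        print(f"{path}: {a!r} vs {b!r}")

report_diffs(interp["model"], keyf["model"])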
keyframe.yaml ADDED
@@ -0,0 +1,154 @@
+model:
+  target: sgm.models.diffusion.DiffusionEngine
+  params:
+    input_key: latents
+    scale_factor: 0.18215
+    disable_first_stage_autocast: True
+    ckpt_path:
+
+    denoiser_config:
+      target: sgm.modules.diffusionmodules.denoiser.DenoiserDub
+      params:
+        scaling_config:
+          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
+
+    network_wrapper:
+      target: sgm.modules.diffusionmodules.wrappers.DubbingWrapper
+      params:
+        mask_input: True
+
+    network_config:
+      target: sgm.modules.diffusionmodules.video_model.VideoUNet
+      params:
+        adm_in_channels: 0
+        num_classes: sequential
+        use_checkpoint: True
+        in_channels: 8
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [4, 2, 1]
+        num_res_blocks: 2
+        channel_mult: [1, 2, 4, 4]
+        num_head_channels: 64
+        use_linear_in_transformer: True
+        transformer_depth: 1
+        context_dim: 1024
+        spatial_transformer_attn_type: softmax-xformers
+        extra_ff_mix_layer: True
+        use_spatial_context: True
+        merge_strategy: learned_with_images
+        video_kernel_size: [3, 1, 1]
+        fine_tuning_method: null
+        audio_cond_method: both_keyframes
+        additional_audio_frames: 0
+        audio_dim: 1024
+        unfreeze_blocks: [] # Because we changed the input block
+
+
+    conditioner_config:
+      target: sgm.modules.GeneralConditioner
+      params:
+        emb_models:
+
+        - input_key: cond_frames
+          is_trainable: False
+          ucg_rate: 0.1
+          target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
+          params:
+            disable_encoder_autocast: True
+            n_cond_frames: 1
+            n_copies: 1
+            is_ae: True
+            load_encoder: False
+            encoder_config:
+              target: sgm.models.autoencoder.AutoencoderKLModeOnly
+              params:
+                embed_dim: 4
+                monitor: val/rec_loss
+                ddconfig:
+                  attn_type: vanilla-xformers
+                  double_z: True
+                  z_channels: 4
+                  resolution: 256
+                  in_channels: 3
+                  out_ch: 3
+                  ch: 128
+                  ch_mult: [1, 2, 4, 4]
+                  num_res_blocks: 2
+                  attn_resolutions: []
+                  dropout: 0.0
+                lossconfig:
+                  target: torch.nn.Identity
+
+        - input_key: gt # allows to use the ground truth as a condition
+          is_trainable: False
+          target: sgm.modules.encoders.modules.IdentityEncoder
+          params:
+            cond_type: gt
+
+        - input_key: audio_emb
+          is_trainable: True
+          ucg_rate: 0.2
+          target: sgm.modules.encoders.modules.WhisperAudioEmbedder
+          params:
+            merge_method: mean
+            linear_dim: 1024
+            cond_type: crossattn
+            audio_dim: 768
+
+
+        - input_key: masks
+          is_trainable: False
+          target: sgm.modules.encoders.modules.IdentityEncoder
+          params:
+            cond_type: masks
+
+    first_stage_config:
+      target: sgm.models.autoencoder.AutoencodingEngine
+      params:
+        loss_config:
+          target: torch.nn.Identity
+        regularizer_config:
+          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
+        encoder_config:
+          target: sgm.modules.diffusionmodules.model.Encoder
+          params:
+            attn_type: vanilla
+            double_z: True
+            z_channels: 4
+            resolution: 256
+            in_channels: 3
+            out_ch: 3
+            ch: 128
+            ch_mult: [1, 2, 4, 4]
+            num_res_blocks: 2
+            attn_resolutions: []
+            dropout: 0.0
+        decoder_config:
+          target: sgm.modules.autoencoding.temporal_ae.VideoDecoder
+          params:
+            attn_type: vanilla
+            double_z: True
+            z_channels: 4
+            resolution: 256
+            in_channels: 3
+            out_ch: 3
+            ch: 128
+            ch_mult: [1, 2, 4, 4]
+            num_res_blocks: 2
+            attn_resolutions: []
+            dropout: 0.0
+            video_kernel_size: [3, 1, 1]
+
+    sampler_config:
+      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
+      params:
+        num_steps: 10
+        discretization_config:
+          target: sgm.modules.diffusionmodules.discretizer.AYSDiscretization
+
+        guider_config:
+          target: sgm.modules.diffusionmodules.guiders.AudioRefMultiCondGuider
+          params:
+            audio_ratio: 5.0
+            ref_ratio: 2.0
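The load_model() called in app.py is not shown in this diff. As a rough sketch of how an sgm-style config such as keyframe.yaml is usually turned into a model, assuming omegaconf and the generative-models-style sgm.util.instantiate_from_config that the targets above point to are importable, and that the checkpoint stores a state_dict:

# Sketch only: the actual load_model() used by the app may differ.
import torch
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config  # assumption: sgm package on PYTHONPATH

config = OmegaConf.load("keyframe.yaml")
model = instantiate_from_config(config.model)

ckpt = torch.load("checkpoints/keyframe_dub.pt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
model.load_state_dict(state_dict, strict=False)
model = model.cuda().eval()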