lijialudew committed
Commit
6333913
1 Parent(s): 4989262

Update README.md

Files changed (1)
  1. README.md +341 -0
README.md CHANGED
@@ -41,6 +41,347 @@ We develop a fine-tuning recipe using the SpeechBrain toolkit available at
  <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
  If you wish to use the fairseq framework, the following code snippet can be used to load our pretrained model:
 
+ ```python
+ """This lobe enables the integration of fairseq pretrained wav2vec models.
+
+ Reference: https://arxiv.org/abs/2006.11477
+ Reference: https://arxiv.org/abs/1904.05862
+ FairSeq >= 1.0.0 needs to be installed: https://fairseq.readthedocs.io/en/latest/
+
+ Authors
+  * Titouan Parcollet 2021
+  * Salima Mdhaffar 2021
+ """
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from speechbrain.utils.data_utils import download_file
+
+ # We check if fairseq is installed.
+ try:
+     import fairseq
+ except ImportError:
+     MSG = "Please install Fairseq to use pretrained wav2vec\n"
+     MSG += "E.g. run: pip install fairseq"
+     raise ImportError(MSG)
+
+
+ class FairseqWav2Vec2(nn.Module):
+     """This lobe enables the integration of fairseq pretrained wav2vec2.0 models.
+
+     Source paper: https://arxiv.org/abs/2006.11477
+     FairSeq >= 1.0.0 needs to be installed:
+     https://fairseq.readthedocs.io/en/latest/
+
+     The model can be used as a fixed feature extractor or can be fine-tuned. It
+     will automatically download the model if a URL is given (e.g. from the
+     FairSeq GitHub repository).
+
+     Arguments
+     ---------
+     pretrained_path : str
+         Path of the pretrained wav2vec2 model. It can be a URL or a local path.
+     save_path : str
+         Path and filename of the downloaded model.
+     input_norm : bool (default: None)
+         If True, a layer_norm (affine) will be applied to the input waveform.
+         By default, it is extracted from the checkpoint of the downloaded model
+         in order to match the pretraining conditions. However, if this information
+         is not given in the checkpoint, it has to be given manually.
+     output_norm : bool (default: True)
+         If True, a layer_norm (affine) will be applied to the output obtained
+         from the wav2vec model.
+     freeze : bool (default: True)
+         If True, the model is frozen. If False, the model will be trained
+         alongside the rest of the pipeline.
+     pretrain : bool (default: True)
+         If True, the model is pretrained with the specified source.
+         If False, the randomly-initialized model is instantiated.
+     dropout : float (default: None)
+         If different from None (0.0 to 1.0), it will override the given fairseq
+         dropout rates. This is useful if the wav2vec2 model has been trained
+         without dropout and one wants to reactivate it for downstream task
+         fine-tuning (better performance observed).
+     encoder_dropout : float (default: None)
+         If different from None, it overrides the fairseq encoder_layerdrop rate.
+     output_all_hiddens : bool (default: False)
+         If True, the forward pass returns the CNN output (optional) and the
+         hidden states of all transformer layers stacked along a new first dimension.
+     tgt_layer : str, int or list (default: None)
+         Selects which representation to return: "CNN" for the convolutional
+         features, an int for a single transformer layer, or a list of layer
+         indices selected from the stacked hidden states.
+     include_CNN_layer : bool (default: True)
+         If True, the CNN output is included when aggregating all hidden layers.
+
+     Example
+     -------
+     >>> inputs = torch.rand([10, 600])
+     >>> model_url = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt"
+     >>> save_path = "models_checkpoints/wav2vec2.pt"
+     >>> model = FairseqWav2Vec2(model_url, save_path)
+     >>> outputs = model(inputs)
+     >>> outputs.shape
+     torch.Size([10, 100, 768])
+     """
+
+     def __init__(
+         self,
+         pretrained_path,
+         save_path,
+         input_norm=None,
+         output_norm=True,
+         freeze=True,
+         pretrain=True,
+         dropout=None,
+         encoder_dropout=None,
+         output_all_hiddens=False,
+         tgt_layer=None,
+         include_CNN_layer=True,
+     ):
+         super().__init__()
+
+         # Download the pretrained wav2vec2 model. It can be local or online.
+         download_file(pretrained_path, save_path)
+
+         # During pretraining dropout might be set to 0. However, we might want
+         # to apply dropout when fine-tuning on a downstream task. Hence we need
+         # to modify the fairseq cfg to activate dropout (if requested).
+         overrides = {}
+         if encoder_dropout is not None:
+             overrides = {
+                 "model": {
+                     "encoder_layerdrop": encoder_dropout,
+                 }
+             }
+         if not freeze:
+             if dropout is not None and encoder_dropout is not None:
+                 overrides = {
+                     "model": {
+                         "dropout": dropout,
+                         "encoder_layerdrop": encoder_dropout,
+                         "dropout_input": dropout,
+                         "attention_dropout": dropout,
+                     }
+                 }
+             elif dropout is not None:
+                 overrides = {
+                     "model": {
+                         "dropout": dropout,
+                         "dropout_input": dropout,
+                         "attention_dropout": dropout,
+                     }
+                 }
+         (
+             model,
+             cfg,
+             task,
+         ) = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+             [save_path], arg_overrides=overrides
+         )
+
+         # wav2vec pretrained models may need the input waveform to be normalized.
+         # Hence, we check if the model has been trained with or without it.
+         # If the information isn't contained in the checkpoint IT HAS TO BE GIVEN
+         # BY THE USER.
+         if input_norm is None:
+             if hasattr(cfg["task"], "normalize"):
+                 self.normalize = cfg["task"].normalize
+             elif hasattr(cfg, "normalize"):
+                 self.normalize = cfg.normalize
+             else:
+                 self.normalize = False
+         else:
+             self.normalize = input_norm
+
+         model = model[0]
+         self.model = model
+         self.freeze = freeze
+         self.output_norm = output_norm
+
+         if self.freeze:
+             self.model.eval()
+             # Freeze parameters
+             for param in model.parameters():
+                 param.requires_grad = False
+         else:
+             self.model.train()
+             for param in model.parameters():
+                 param.requires_grad = True
+
+         # Randomly initialized layers if pretrain is False
+         if not pretrain:
+             self.reset_layer(self.model)
+
+         # Following the fairseq implementation of downstream training,
+         # we remove some modules that are unnecessary.
+         self.remove_pretraining_modules()
+         self.output_all_hiddens = output_all_hiddens
+         self.tgt_layer = tgt_layer
+         self.include_CNN_layer = include_CNN_layer
+
+     def forward(self, wav):
+         """Takes an input waveform and returns its corresponding wav2vec encoding.
+
+         Arguments
+         ---------
+         wav : torch.Tensor (signal)
+             A batch of audio signals to transform to features.
+         """
+
+         # If we freeze, we simply remove all grads and features from the graph.
+         if self.freeze:
+             with torch.no_grad():
+                 return self.extract_features(wav).detach()
+
+         return self.extract_features(wav)
+
+     def extract_features(self, wav):
+         """Extracts the wav2vec embeddings."""
+         # We normalize the input signal if needed.
+         if self.normalize:
+             wav = F.layer_norm(wav, wav.shape)
+
+         # Extract wav2vec output
+         if self.tgt_layer == "CNN":  # initial embeddings from the conv feature extractor
+             out = self.model.extract_features(wav, padding_mask=None, mask=False)
+             out = self.model.post_extract_proj(out['features'])
+         elif isinstance(self.tgt_layer, int):  # hidden states of a single transformer layer
+             out = self.model.extract_features(
+                 wav, padding_mask=None, mask=False, layer=self.tgt_layer
+             )['x']
+         else:  # full forward pass through the transformer encoder
+             out = self.model.extract_features(
+                 wav, padding_mask=None, mask=False, layer=self.tgt_layer
+             )
+             if self.output_all_hiddens or isinstance(self.tgt_layer, list):
+                 out = self.aggregate_features(
+                     out, include_CNN_layer=self.include_CNN_layer
+                 )  # 13, B, T, D
+                 if isinstance(self.tgt_layer, list):
+                     out = out[self.tgt_layer]
+             else:
+                 out = out['x']
+
+         # We normalize the output if required
+         if self.output_norm:
+             out = F.layer_norm(out, out.shape)
+
+         return out
+
+     def aggregate_features(self, out, include_CNN_layer=True):
+         """Stacks the CNN output (optional) and all transformer layer outputs."""
+         features = []
+         if include_CNN_layer:
+             features = [self.model.post_extract_proj(out['features'])]
+         self.model.layerdrop = 0
+         for i in range(len(out['layer_results'])):
+             # layer_results holds (hidden, attn) tuples with time-first tensors
+             curr_feature = out['layer_results'][i][0].transpose(0, 1)
+             features.append(curr_feature)
+         features = torch.stack(features)
+         return features
+
+     def reset_layer(self, model):
+         """Reinitializes the parameters of the network."""
+         if hasattr(model, "reset_parameters"):
+             model.reset_parameters()
+         for child_layer in model.children():
+             if model != child_layer:
+                 self.reset_layer(child_layer)
+
+     def remove_pretraining_modules(self):
+         """Removes unneeded modules. Inspired by the same fairseq function."""
+
+         self.model.quantizer = None
+         self.model.project_q = None
+         self.model.target_glu = None
+         self.model.final_proj = None
+
+
+ class FairseqWav2Vec1(nn.Module):
+     """This lobe enables the integration of fairseq pretrained wav2vec1.0 models.
+
+     Arguments
+     ---------
+     pretrained_path : str
+         Path of the pretrained wav2vec1 model. It can be a URL or a local path.
+     save_path : str
+         Path and filename of the downloaded model.
+     output_norm : bool (default: True)
+         If True, a layer_norm (affine) will be applied to the output obtained
+         from the wav2vec model.
+     freeze : bool (default: True)
+         If True, the model is frozen. If False, the model will be trained
+         alongside the rest of the pipeline.
+     pretrain : bool (default: True)
+         If True, the model is pretrained with the specified source.
+         If False, the randomly-initialized model is instantiated.
+
+     Example
+     -------
+     >>> inputs = torch.rand([10, 600])
+     >>> model_url = ""
+     >>> save_path = "models_checkpoints/wav2vec.pt"
+     >>> model = FairseqWav2Vec1(model_url, save_path)
+     >>> outputs = model(inputs)
+     >>> outputs.shape
+     torch.Size([10, 100, 512])
+     """
+
+     def __init__(
+         self,
+         pretrained_path,
+         save_path,
+         output_norm=True,
+         freeze=True,
+         pretrain=True,
+     ):
+         super().__init__()
+         self.freeze = freeze
+         self.output_norm = output_norm
+
+         # Download the pretrained wav2vec1 model. It can be local or online.
+         download_file(pretrained_path, save_path)
+
+         # Load the downloaded local copy of the checkpoint.
+         (
+             model,
+             cfg,
+             task,
+         ) = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+             [save_path]
+         )
+
+         self.model = model
+         self.model = self.model[0]
+         if self.freeze:
+             model.eval()
+
+         # Randomly initialized layers if pretrain is False
+         if not pretrain:
+             self.reset_layer(self.model)
+
+     def forward(self, wav):
+         """Takes an input waveform and returns its corresponding wav2vec encoding.
+
+         Arguments
+         ---------
+         wav : torch.Tensor (signal)
+             A batch of audio signals to transform to features.
+         """
+
+         # If we freeze, we simply remove all grads and features from the graph.
+         if self.freeze:
+             with torch.no_grad():
+                 return self.extract_features(wav).detach()
+
+         return self.extract_features(wav)
+
+     def extract_features(self, wav):
+         """Extracts the wav2vec embeddings."""
+
+         out = self.model.feature_extractor(wav)
+         out = self.model.feature_aggregator(out).squeeze(0)
+         out = out.transpose(2, 1)
+
+         # We normalize the output if required
+         if self.output_norm:
+             out = F.layer_norm(out, out.shape)
+
+         return out
+
+     def reset_layer(self, model):
+         """Reinitializes the parameters of the network."""
+         if hasattr(model, "reset_parameters"):
+             model.reset_parameters()
+         for child_layer in model.children():
+             if model != child_layer:
+                 self.reset_layer(child_layer)
+ ```
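+
+ For example, assuming fairseq, SpeechBrain, and the classes above are available in your environment, a minimal usage sketch could look like the following. The checkpoint path is a placeholder, and the 768-dimensional outputs correspond to a base-sized wav2vec 2.0 model:
+
+ ```python
+ import torch
+
+ # Placeholder paths: point these at the fairseq checkpoint you want to load.
+ ckpt_path = "path/to/wav2vec2_checkpoint.pt"
+ save_path = "models_checkpoints/wav2vec2.pt"
+
+ # Frozen feature extractor returning the last transformer layer: [B, T, 768].
+ model = FairseqWav2Vec2(ckpt_path, save_path, freeze=True)
+ wavs = torch.rand([4, 16000])  # 4 waveforms of ~1 s at 16 kHz
+ feats = model(wavs)
+
+ # Stack the CNN output and all transformer layers: [num_layers, B, T, 768].
+ model_all = FairseqWav2Vec2(
+     ckpt_path, save_path, output_all_hiddens=True, include_CNN_layer=True
+ )
+ all_layers = model_all(wavs)
+ ```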
 
  # Evaluation