farzadab committed
Commit 779bcda · verified · 1 Parent(s): 6eadb3e

Upload 4 files

Files changed (3):
  1. processor_config.json +1 -1
  2. ultravox_model.py +48 -46
  3. ultravox_processing.py +172 -91
processor_config.json CHANGED
@@ -5,7 +5,7 @@
   "auto_map": {
     "AutoProcessor": "ultravox_processing.UltravoxProcessor"
   },
-  "encoder_ds_factor": 320,
+  "encoder_ds_factor": 2,
   "processor_class": "UltravoxProcessor",
   "stack_factor": 8
 }
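For intuition, the new value lines up with the reworked token-length formula in ultravox_processing.py below: `encoder_ds_factor` now means the encoder's frame-level downsampling (Whisper halves the mel-frame count), not the old raw-samples-per-frame ratio of 320, and `audio_lens` is counted in mel frames. A minimal arithmetic sketch, assuming 16 kHz input and Whisper's 160-sample hop length:

```python
import math

# Token count for a hypothetical 10 s clip under both conventions.
# Assumes 16 kHz audio and Whisper's hop_length of 160 samples per mel frame.
sample_rate, hop_length, stack_factor = 16_000, 160, 8
num_samples = 10 * sample_rate              # 160_000 raw samples
num_mel_frames = num_samples // hop_length  # 1_000 mel frames

# Old formula: encoder_ds_factor = 320 applied to raw sample counts.
old_tokens = math.ceil(round(num_samples / 320) / stack_factor)

# New formula: encoder_ds_factor = 2 applied to mel-frame lengths (audio_lens).
new_tokens = math.ceil(num_mel_frames / (2 * stack_factor))

print(old_tokens, new_tokens)  # 63 63 -- same token budget, different bookkeeping
```

Both conventions land on the same number of placeholder tokens; only the unit in which `audio_lens` is expressed changes.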
ultravox_model.py CHANGED
@@ -1,6 +1,6 @@
 import logging
 import re
-from typing import Any, Dict, Optional, Set, Tuple, Union
+from typing import Any, Dict, Generator, Optional, Set, Tuple, Union
 
 import peft
 import torch
@@ -10,6 +10,7 @@ import transformers
 import transformers.activations
 import transformers.modeling_outputs
 import transformers.models
+from transformers.generation.utils import GenerationMixin
 from transformers.models.whisper import modeling_whisper as whisper
 
 # We must use relative import in this directory to allow uploading to HF Hub
@@ -19,7 +20,7 @@ from .ultravox_config import LossFunction
 from .ultravox_config import UltravoxConfig
 
 
-class UltravoxModel(transformers.LlamaPreTrainedModel):
+class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     """
     The Ultravox model which consists of an audio encoder and a language model.
 
@@ -57,10 +58,8 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
 
         # Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
         # FSDP throws an error if some of the layer types are not found in the model.
-        # This would be something like ["LlamaDecoderLayer", "WhisperEncoderLayer"]
-        self._no_split_modules = (self.language_model._no_split_modules or []) + (
-            self.audio_tower._no_split_modules or []
-        )
+        # This would be something like ["LlamaDecoderLayer"] as we don't split audio encoder layers.
+        self._no_split_modules = self.language_model._no_split_modules
 
         self.loss_config = LossConfig()
         self.post_init()
@@ -147,6 +146,24 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         )
         return {"loss": kl_loss}
 
+    def _audio_iter(
+        self, audio_batch_size: torch.Tensor
+    ) -> Generator[Tuple[int, int], None, None]:
+        """
+        Iterate over the audio batch size and yield the batch index and audio index of each audio item.
+
+        Args:
+            audio_batch_size: A tensor of shape (B,) where B is the batch size.
+
+        Returns:
+            A generator that yields a tuple of (batch index, audio index) for each audio item.
+        """
+        audio_index = 0
+        for i_b, batch_count in enumerate(audio_batch_size):
+            for _ in range(batch_count):
+                yield i_b, audio_index
+                audio_index += 1
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -188,23 +205,22 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         # B x T -> B x T x D
         inputs_embeds = self.get_input_embeddings().forward(input_ids)
 
-        if audio_values is not None:
+        if audio_values is not None and len(audio_values) > 0:
             assert (
                 audio_token_start_idx is not None
                 and audio_token_len is not None
+                and audio_lens is not None
                 and audio_batch_size is not None
-            ), "audio_token_start_idx and audio_token_len and audio_batch_size must be provided if audio_values are provided."
+            ), "audio_token_start_idx/audio_token_len/audio_lens must be provided if audio_values are provided."
             assert (
                 len(audio_token_start_idx)
                 == len(audio_token_len)
-                == len(audio_batch_size)
-            ), "audio_token_start_idx and audio_token_len and audio_batch_size must have the same batch size."
-            assert (
-                audio_lens is not None
-            ), "audio_lens must be provided if audio_values are provided"
-            assert len(audio_lens) == len(
-                audio_values
-            ), "audio_lens must have the same batch size as audio_values."
+                == len(audio_lens)
+                == len(audio_values)
+            ), "audio_token_start_idx/audio_token_len/audio_lens/audio_values must have the same batch size."
+            assert len(audio_batch_size) == len(
+                inputs_embeds
+            ), "audio_batch_size and inputs_embeds must have the same batch size."
 
             # B x A/3200 x (D=max-audio-length-in-batch)
             audio_tower_output = self.audio_tower.forward(
@@ -215,24 +231,11 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
             audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
 
             # combine audio and text embeddings
-            # audio_embeds is (B_a X T X D)
-            # inputs_embeds is (B_i X T X D)
-            # B_a >= B_i because B_a includes all audio chunks.
-            # B_i == audio_token_start_idx.shape[0] == audio_token_len.shape[0] == audio_batch_size.shape[0]
-            audio_ind = 0
-            for i, (start, length, batch_size) in enumerate(
-                zip(audio_token_start_idx, audio_token_len, audio_batch_size)
-            ):
-                # audio_embeds is [B1 x T1 x D_hidden, B2 x T2 x D_hidden, ...]
-                # audio.shape (T1 + T2 + ..., D_hidden)
-                audio = torch.cat(
-                    [audio_embeds[k] for k in range(audio_ind, audio_ind + batch_size)],
-                    dim=0,
-                )
-                length = min(length, audio.shape[1])
-                inputs_embeds[i, start : start + length] = audio[:length]
-
-                audio_ind += batch_size
+            for i_b, i_a in self._audio_iter(audio_batch_size):
+                start_idx = audio_token_start_idx[i_a]
+                token_len = audio_token_len[i_a]
+                item_embedding = audio_embeds[i_a][:token_len]
+                inputs_embeds[i_b][start_idx : start_idx + token_len] = item_embedding
 
         lm_output = self.language_model.forward(
             inputs_embeds=inputs_embeds,
@@ -424,13 +427,17 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
         if state_dict is None:
             state_dict = super().state_dict()
 
-        named_params = dict(self.named_parameters())
+        trainable_params = {k for k, v in self.named_parameters() if v.requires_grad}
+        # normalize the keys to match the original model
+        # Example: audio_tower.base_model.model.layers.0._fsdp_wrapped_module.self_attn.k_proj.lora_B.default.weight
+        trainable_params = {
+            k.replace("_fsdp_wrapped_module.", "") for k in trainable_params
+        }
 
         state_dict = {
             k: v
             for k, v in state_dict.items()
-            if k in self.keep_params
-            or (k in named_params and named_params[k].requires_grad)
+            if k in self.keep_params or k in trainable_params
         }
 
         return state_dict
@@ -476,7 +483,7 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
 
 # TODO: refactor common parts to a shared module
 def is_cache_empty(
-    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]]
+    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]],
 ) -> bool:
     """
     Check if the cache is empty.
@@ -512,12 +519,8 @@ def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
 
 class StackAudioFrames(nn.Module):
     """
-    Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`.
-
-    The number of output frames will be `ceil(T / stack_factor) + 1` where `T` is the number of input frames.
-    NOTE: the extra +1 is intentional: in case the number of audio tokens are over-estimated by the processor,
-    we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings.
-    In most cases this extra padding will get removed in the model's forward function so it has no effect.
+    Stack the audio embedding frames to reduce the sequence length by a factor
+    of `stack_factor`.
     """
 
     def __init__(self, stack_factor: int = 8):
@@ -527,7 +530,7 @@ class StackAudioFrames(nn.Module):
     def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
         B, T, C = audio_embeds.shape
         T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
-        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor))
+        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
         B, T, C = audio_embeds.shape
         audio_embeds = audio_embeds.view(
             B, T // self.stack_factor, C * self.stack_factor
@@ -700,7 +703,6 @@ class ModifiedWhisperEncoder(
             attention_mask = self.get_extended_attention_mask(
                 attention_mask,
                 None,
-                device=hidden_states.device,
                 dtype=hidden_states.dtype,
            )
 
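To make the new merge path concrete, here is a small standalone sketch (toy shapes, not taken from the repository) of how `audio_batch_size` routes projected audio chunks back to their text rows; `audio_iter` below mirrors the `_audio_iter` helper added in this commit:

```python
import torch

def audio_iter(audio_batch_size: torch.Tensor):
    # Mirrors UltravoxModel._audio_iter: yields (text row index, audio chunk index).
    audio_index = 0
    for i_b, batch_count in enumerate(audio_batch_size):
        for _ in range(int(batch_count)):
            yield i_b, audio_index
            audio_index += 1

# Toy batch: 2 text rows; row 0 owns 2 audio chunks, row 1 owns 1.
D = 4
inputs_embeds = torch.zeros(2, 20, D)            # (text batch, text length, hidden)
audio_embeds = torch.randn(3, 6, D)              # (total chunks, audio frames, hidden)
audio_batch_size = torch.tensor([2, 1])          # chunks per text row, sums to 3
audio_token_start_idx = torch.tensor([3, 9, 5])  # one entry per chunk
audio_token_len = torch.tensor([6, 4, 6])        # one entry per chunk

for i_b, i_a in audio_iter(audio_batch_size):
    start_idx = audio_token_start_idx[i_a]
    token_len = audio_token_len[i_a]
    inputs_embeds[i_b][start_idx : start_idx + token_len] = audio_embeds[i_a][:token_len]
```

Each chunk keeps its own start index and token length, so a long clip split by the processor simply writes several consecutive chunks into the same text row.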
ultravox_processing.py CHANGED
@@ -1,5 +1,5 @@
 import dataclasses
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -15,8 +15,13 @@ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
     include_alt_fields: bool = False
 
     def __call__(self, features, *args, **kwargs):
-        audio_values = [f.pop("audio_values", None) for f in features]
-        audio_lens = [f.pop("audio_lens", None) for f in features]
+        audio_values = [x for f in features for x in f.pop("audio_values", [])]
+        audio_lens = [x for f in features for x in f.pop("audio_lens", [])]
+        audio_token_len = [x for f in features for x in f.pop("audio_token_len", [])]
+        audio_token_start_idx = [
+            x for f in features for x in f.pop("audio_token_start_idx", [])
+        ]
+
         if self.include_alt_fields:
             # these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method
             alt_features = [
@@ -35,10 +40,14 @@ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
             batch["alt_attention_mask"] = alt_batch["attention_mask"]
             batch["alt_labels"] = alt_batch["labels"]
 
+        batch["audio_token_start_idx"] = torch.stack(audio_token_start_idx)
+        batch["audio_lens"] = torch.stack(audio_lens)
+        batch["audio_token_len"] = torch.stack(audio_token_len)
+
         # Pad the last dimension of all audio_values to the same length, with 0s on the right.
-        if audio_values and audio_values[0] is not None:
+        if audio_values:
             max_len = max([x.shape[-1] for x in audio_values])
-            batch["audio_values"] = torch.cat(
+            batch["audio_values"] = torch.stack(
                 [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
             )
             if self.tokenizer.padding_side == "left":
@@ -46,11 +55,12 @@ class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
                     [f["input_ids"].shape[-1] for f in features]
                 )
                 displacement = batch["input_ids"].shape[-1] - input_ids_lens
+                displacement = displacement.repeat_interleave(
+                    batch["audio_batch_size"].squeeze(-1)
+                )
                 batch["audio_token_start_idx"] += displacement.to(
                     batch["audio_token_start_idx"].device
                 )
-        # batch["audio_lens"].shape = (B,)
-        batch["audio_lens"] = torch.cat(audio_lens)
         return batch
 
 
@@ -64,11 +74,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
     """
 
     attributes = ["audio_processor", "tokenizer"]
-    audio_processor_class = (
-        "Wav2Vec2Processor",
-        "SeamlessM4TFeatureExtractor",
-        "WhisperProcessor",
-    )
+    audio_processor_class = ("WhisperProcessor",)
    tokenizer_class = (
        "PreTrainedTokenizer",
        "PreTrainedTokenizerFast",
@@ -82,7 +88,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         audio_processor=None,
         tokenizer=None,
         audio_padding: str = "longest",
-        encoder_ds_factor: int = 320,
+        encoder_ds_factor: int = 2,
         stack_factor: int = 8,
         audio_placeholder: str = "<|audio|>",
         # Defaults to whisper encoder context size
@@ -93,8 +99,8 @@ class UltravoxProcessor(transformers.ProcessorMixin):
             audio_processor: The audio processor for the audio encoder.
             tokenizer: The tokenizer for the language model.
             audio_padding: The padding strategy for the audio encoder.
-            encoder_ds_factor: The downsample factor of the audio encoder.
             stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
+            encoder_ds_factor: The downsampling factor of the audio encoder.
             audio_placeholder: The placeholder for the audio in the text.
             audio_context_size: The maximum number of frames that the audio encoder can handle.
         """
@@ -102,11 +108,12 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         self.encoder_ds_factor = encoder_ds_factor
         self.stack_factor = stack_factor
         self.audio_placeholder = audio_placeholder
-        self.audio_token_replacement = tokenizer.eos_token
         self.audio_context_size = audio_context_size
         assert (
-            self.audio_token_replacement is not None
+            tokenizer.eos_token is not None
         ), "The tokenizer has no EOS token. Cannot recover."
+        self.vocab = tokenizer.get_vocab()
+        self.audio_token_replacement = tokenizer.eos_token
         if tokenizer.pad_token_id is None:
             tokenizer.pad_token_id = tokenizer.eos_token_id
 
@@ -120,7 +127,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
         audio_processor = transformers.AutoProcessor.from_pretrained(
             config.audio_model_id
             or config.audio_config._name_or_path
-            or "facebook/wav2vec2-base-960h"
+            or "openai/whisper-tiny"
         )
 
         tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -135,65 +142,100 @@
             stack_factor=config.stack_factor,
         )
 
-    def _chunk_and_pad_audio(self, audio_values: torch.Tensor) -> Dict[str, Any]:
+    def _chunk_and_pad_audio(
+        self,
+        audio_values: torch.Tensor,
+        audio_lens: torch.Tensor,
+        include_audio_num_chunks: bool = False,
+    ) -> Dict[str, Any]:
         """
-        Processes the audio tensor by chunking it according to the audio_context_size,
+        Processes the audio batch by chunking any items in the batch according to the audio_context_size,
         padding the last chunk if needed, and returns a dictionary with updated audio data.
 
         Args:
             audio_values (torch.Tensor): A tensor of audio values (e.g., in B, D, T format).
+            audio_lens (torch.Tensor): A tensor of audio lengths.
 
         Returns:
            Dict[str, Any]: Dictionary with the following keys:
                - "audio_values": The concatenated audio tensor after chunking and padding.
-                - "audio_lens": List of lengths (as torch.Tensor) for each chunk.
-                - "audio_batch_size": A list with one integer representing the number of chunks.
+                - "audio_lens": Tensor of lengths for each chunk.
+                - "audio_is_continuation": Tensor of booleans indicating if the chunk is a continuation of the previous chunk.
+                - "audio_batch_size": A Tensor with one integer representing the number of chunks.
+
         """
-        result: Dict[str, Any] = {}
-        if self.audio_context_size and audio_values.shape[-1] > self.audio_context_size:
-            audio_chunks = list(
-                torch.split(audio_values, self.audio_context_size, dim=-1)
+        chunked_audio_values: List[torch.Tensor] = []
+        chunked_audio_lens: List[int] = []
+        is_continuation_list: List[bool] = []
+        num_chunks: List[int] = []
+        context_size = self.audio_context_size or audio_values.shape[-1]
+
+        for i in range(audio_values.shape[0]):  # iterate over the batch
+            num_chunks.append(int(np.ceil(audio_lens[i] / context_size)))
+            for offset in range(0, audio_lens[i], context_size):
+                is_continuation = offset > 0
+                chunk = audio_values[i, :, offset : offset + context_size]
+                if is_continuation and chunk.shape[-1] < context_size:
+                    # N.B. We only need to pad continuation chunks. If none of the samples require chunking, the
+                    # batch might not (need to) be padded all the way to the audio_context_size, in which case
+                    # we've already included the padding above. On the other hand, if we have any continuation
+                    # chunks we know that the batch needs to be padded to audio_context_size because that's what
+                    # we're slicing to.
+                    chunk = F.pad(chunk, (0, context_size - chunk.shape[-1]))
+                chunked_audio_values.append(chunk)
+                chunked_audio_lens.append(
+                    min(int(audio_lens[i].item()) - offset, context_size)
+                )
+                is_continuation_list.append(is_continuation)
+
+        data = {
+            "audio_values": torch.stack(chunked_audio_values, dim=0),
+            "audio_lens": torch.tensor(
+                chunked_audio_lens, dtype=torch.int64, device=audio_values.device
+            ),
+            "audio_is_continuation": torch.tensor(
+                is_continuation_list, dtype=torch.bool, device=audio_values.device
+            ),
+            "audio_batch_size": torch.tensor(
+                [len(chunked_audio_values)], device=audio_values.device
+            ),
+        }
+        if include_audio_num_chunks:
+            data["audio_num_chunks"] = torch.tensor(
+                num_chunks, dtype=torch.int64, device=audio_values.device
            )
-            valid_lengths = [chunk.shape[-1] for chunk in audio_chunks]
-            result = {
-                "audio_lens": [torch.as_tensor(length) for length in valid_lengths]
-            }
-            # Pad the last chunk to the full context length if needed.
-            last_chunk = audio_chunks[-1]
-            pad_size = self.audio_context_size - last_chunk.shape[-1]
-            if pad_size > 0:
-                audio_chunks[-1] = F.pad(last_chunk, (0, pad_size))
-        else:
-            audio_chunks = [audio_values]
-            result = {"audio_lens": [torch.as_tensor(audio_values.shape[-1])]}
-        result["audio_values"] = torch.cat(audio_chunks)
-        result["audio_batch_size"] = [result["audio_values"].shape[0]]
-        return result
+        return data
 
     def __call__(
         self,
         text: Optional[str] = None,
         audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        audios: Optional[
+            Union[
+                List[Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
+            ]
+        ] = None,
         sampling_rate: Optional[int] = None,
         return_tensors: Optional[
             Union[str, transformers.TensorType]
         ] = transformers.TensorType.PYTORCH,
+        include_audio_num_chunks: bool = False,
        **kwargs,
    ) -> transformers.BatchFeature:
        """
        Main method to prepare for the model one text sequence and audio. This method forwards the `text`
        and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
-        audio processor's [`~Wav2Vec2Processor.__call__`] if `audio` is not `None`. Please refer to the docstring
+        audio processor's [`~WhisperProcessor.__call__`] if `audio` is not `None`. Please refer to the docstring
        of the above two methods for more information.

        Args:
            text (`str`, `List[str]`):
                The sequence to be encoded. Sequence can be a string or (pretokenized string).
            audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The audio to be prepared. Audio can be NumPy array or PyTorch tensor. In case of a
-                NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the
-                sample length of the audio.
+                The audio to be prepared. Audio can be a single-channel (1-dimensional) NumPy array or PyTorch tensor.
+            audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                A list or two dimensional array of audio to be prepared.
            sampling_rate (`int`, *optional*, defaults to 16000):
                Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
                you are doing.
@@ -217,66 +259,105 @@
                Returned when `audio` is not `None`.
            - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
        """
-        # TODO: Add support for multiple audio and text inputs.
-        data: Dict[str, Any] = {}
-        audio_embed_frames = 0
-        if audio is not None and len(audio) > 0:
-            audio_len = audio.shape[-1]
-            # It's guaranteed that the number of frames is less than or equal to this amount.
-            # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
-            # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
-            nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
-            audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
-            data["audio_token_len"] = [audio_embed_frames]
+        # TODO: Add support for multiple text inputs.
+        if audio is not None and audios is not None:
+            raise ValueError("Only one of `audio` or `audios` should be provided.")
+        elif audio is not None:
+            audios = audio if isinstance(audio, list) or audio.ndim == 2 else [audio]
+        elif audios is None:
+            audios = []
+
+        data = {}
+        audio_is_continuation = []
+        if len(audios) > 0:
+            audios = [x.numpy() if isinstance(x, torch.Tensor) else x for x in audios]
+
+            # Pad out each audio to at least 2 hops (the minimum required by the processor).
+            hop_length = self.audio_processor.feature_extractor.hop_length
+            audios = [
+                (
+                    np.pad(x, (0, 2 * hop_length - len(x)), mode="constant")
+                    if len(x) < 2 * hop_length
+                    else x
+                )
+                for x in audios
+            ]
 
             # Main audio processing. The processor is model-specific.
-            x = self.audio_processor(
-                audio,
+            x: transformers.BatchFeature = self.audio_processor(
+                audios,
                sampling_rate=sampling_rate,
                padding="longest",
+                pad_to_multiple_of=hop_length,  # The attention mask effectively gets padded to the hop length, so pad the audio to be consistent.
+                truncation=False,
                return_attention_mask=True,
                **kwargs,
            )

-            if "input_features" in x:
-                audio_values = x.input_features
-            else:
-                audio_values = x.input_values
+            data.update(
+                self._chunk_and_pad_audio(
+                    audio_values=torch.as_tensor(
+                        x.input_features if "input_features" in x else x.input_values
+                    ),
+                    audio_lens=torch.as_tensor(x.attention_mask).sum(-1),
+                    include_audio_num_chunks=include_audio_num_chunks,
+                )
+            )
 
-            audio_values = torch.tensor(audio_values)
-            chunk_and_pad_results = self._chunk_and_pad_audio(audio_values)
-            data["audio_values"] = chunk_and_pad_results["audio_values"]
-            data["audio_lens"] = chunk_and_pad_results["audio_lens"]
-            data["audio_batch_size"] = chunk_and_pad_results["audio_batch_size"]
+            audio_is_continuation = data.pop("audio_is_continuation")
+            data["audio_token_len"] = torch.ceil(
+                data["audio_lens"] / (self.encoder_ds_factor * self.stack_factor)
+            ).to(dtype=torch.int)
 
        if text is not None:
-            assert isinstance(
-                text, str
-            ), "Text must be a string. Batch mode not supported yet."
-            if self.audio_placeholder in text:
-                if "audio_token_len" not in data:
-                    raise ValueError(
-                        f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
-                    )
-
-                start_idx = len(
-                    self.tokenizer.encode(
-                        text[: text.index(self.audio_placeholder)],
-                        add_special_tokens=False,
-                    )
-                )
-                data["audio_token_start_idx"] = [start_idx]
-
-                # Replace the audio placeholder with the audio token.
-                # e.g. "Transcribe\n<|audio|>" -> "Transcribe\n</s></s></s></s></s></s></s></s>"
-                # where the number of </s> is the number of audio frames.
-                text = text.replace(
-                    self.audio_placeholder,
-                    self.audio_token_replacement * audio_embed_frames,
-                )
+            if not isinstance(text, str):
+                raise ValueError("Text must be a string. Batch mode not supported yet.")
 
            # Special tokens like BOS should already have been added by the caller.
-            data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))
+            tokenized_parts = self.tokenizer(
+                text.split(
+                    "<|audio|>"  # The placeholder isn't part of the vocabulary, so split the text around it.
+                ),
+                add_special_tokens=False,
+                **kwargs,
+            )
+
+            audio_token_start_idx = []
+            placeholder_index = -1
+            split_input_ids = tokenized_parts["input_ids"]
+            input_ids: List[int] = []
+
+            audio_token_replacement_token_id = self.vocab[self.audio_token_replacement]
+
+            for i, token_len in enumerate(data.get("audio_token_len", [])):
+                if not audio_is_continuation[i]:
+                    placeholder_index += 1
+                    if placeholder_index >= len(split_input_ids):
+                        raise ValueError(
+                            f"Text contains too few audio placeholders. (Expected {len(audios)} placeholders)"
+                        )
+
+                    input_ids.extend(split_input_ids[placeholder_index])
+
+                audio_token_start_idx.append(len(input_ids))
+
+                input_ids.extend([audio_token_replacement_token_id] * token_len)
+
+            # Include any tokens after the last audio.
+            placeholder_index += 1
+            if placeholder_index != len(split_input_ids) - 1:
+                raise ValueError(
+                    f"Text contains too many audio placeholders. (Expected {len(audios)} placeholders)"
+                )
+            input_ids.extend(split_input_ids[placeholder_index])
+
+            if "audio_token_len" in data:
+                data["audio_token_start_idx"] = torch.as_tensor(audio_token_start_idx)
+
+            data["input_ids"] = [input_ids]
+            data["attention_mask"] = [[1] * len(input_ids)]
+
+            # Ensure that there are no audio placeholders after the last audio.
 
        return transformers.BatchFeature(data=data, tensor_type=return_tensors)