farzadab committed
Commit 870be04 · verified · 1 parent: 54bfb27

Upload 5 files

Files changed (1):
  ultravox_config.py  +19 -7
ultravox_config.py CHANGED
@@ -32,6 +32,8 @@ class LossFunction(str, Enum):
 class LossConfig:
     loss_function: LossFunction = LossFunction.CrossEntropy
     kl_temperature: float = 2.0
+    # Number of tokens to ignore from the beginning of the sequence. Only used in LSM
+    initial_tokens_to_ignore: int = 0
 
     @property
     def requires_alt_fields(self):
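The new `initial_tokens_to_ignore` field only matters for the LSM loss, per its comment. A minimal usage sketch, assuming `LossConfig` is a dataclass (the field-with-default syntax suggests it) and that the file is importable as `ultravox_config`; both are assumptions, not confirmed by this diff:

```python
# Sketch only: the import path and the dataclass constructor are assumptions.
from ultravox_config import LossConfig, LossFunction

loss_config = LossConfig(
    loss_function=LossFunction.CrossEntropy,
    kl_temperature=2.0,
    # New field: skip this many tokens at the start of the sequence when
    # computing the loss (only consulted for the LSM loss).
    initial_tokens_to_ignore=4,
)
```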
@@ -47,7 +49,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
     documentation from [`PretrainedConfig`] for more information.
 
     Args:
-        audio_config (`Wav2Vec2Config`, *optional*):
+        audio_config (`WhisperConfig`, *optional*):
             Custom audio config or dict
         text_config (`Union[AutoConfig, dict]`, *optional*):
             The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
@@ -65,15 +67,17 @@ class UltravoxConfig(transformers.PretrainedConfig):
             The LoRA configuration for finetuning the text model.
         audio_model_lora_config (`LoraConfigSimplified`, *optional*):
             The LoRA configuration for finetuning the audio model.
+        audio_latency_block_size (`int`, *optional*, defaults to `None`):
+            The latency block size for simulating audio streaming.
 
 
     Example:
 
     ```python
-    >>> from transformers import UltravoxForConditionalGeneration, Wav2Vec2Config, UltravoxConfig, LlamaConfig
+    >>> from transformers import UltravoxModel, WhisperConfig, UltravoxConfig, LlamaConfig
 
     >>> # Initializing an audio encoder config
-    >>> audio_config = Wav2Vec2Config()
+    >>> audio_config = WhisperConfig()
 
     >>> # Initializing a Llama config
     >>> text_config = LlamaConfig()
@@ -82,13 +86,13 @@ class UltravoxConfig(transformers.PretrainedConfig):
     >>> configuration = UltravoxConfig(audio_config, text_config)
 
     >>> # Initializing a completely untrained model from the configuration
-    >>> model = UltravoxForConditionalGeneration(configuration)
+    >>> model = UltravoxModel(configuration)
 
     >>> # Accessing the model configuration
     >>> configuration = model.config
 
     >>> # Initialize a model from pretrained checkpoints and random projector weights
-    >>> config = UltravoxConfig(audio_model_id="facebook/wav2vec2-base-960h", text_model_id="meta-llama/Llama-2-7b-chat-hf")
+    >>> config = UltravoxConfig(audio_model_id="openai/whisper-tiny", text_model_id="meta-llama/Llama-2-7b-chat-hf")
     ```"""
 
     model_type = "ultravox"
@@ -105,8 +109,10 @@ class UltravoxConfig(transformers.PretrainedConfig):
         stack_factor: int = 8,
         norm_init: float = 0.4,
         projector_act: str = "swiglu",
+        projector_ln_mid: bool = False,  # defaults to False for compatibility with v0.4.1 and below
         text_model_lora_config: Optional[LoraConfigSimplified] = None,
         audio_model_lora_config: Optional[LoraConfigSimplified] = None,
+        audio_latency_block_size: Optional[int] = None,
         **kwargs,
     ):
         self.ignore_index = ignore_index
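Both new constructor arguments default to the old behavior, so configs saved before this commit load unchanged. A hedged sketch of setting them explicitly; the import path is an assumption and the values are illustrative only:

```python
from ultravox_config import UltravoxConfig  # assumed import path

config = UltravoxConfig(
    audio_model_id="openai/whisper-tiny",
    text_model_id="meta-llama/Llama-2-7b-chat-hf",
    # False keeps the projector LayerNorm placement of v0.4.1 and below,
    # per the inline comment above.
    projector_ln_mid=False,
    # Block size used to simulate audio streaming; None (the default)
    # disables the simulation.
    audio_latency_block_size=100,
)
```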
@@ -118,7 +124,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         self.stack_factor = stack_factor
         self.norm_init = norm_init
         self.projector_act = projector_act
-
+        self.projector_ln_mid = projector_ln_mid
         if text_model_id is not None:
             self.text_config: transformers.LlamaConfig = (
                 transformers.AutoConfig.from_pretrained(text_model_id)
@@ -136,7 +142,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
         else:
             audio_config = audio_config or {}
             self.audio_config = transformers.CONFIG_MAPPING[
-                audio_config.get("model_type", "wav2vec2")
+                audio_config.get("model_type", "whisper")
             ](**audio_config)
 
         self.text_model_lora_config = (
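With the default `model_type` now `"whisper"`, an `audio_config` dict that omits `model_type` resolves to a `WhisperConfig` through the same `transformers.CONFIG_MAPPING` lookup the hunk uses; a quick check:

```python
import transformers

audio_config = {}  # no explicit model_type
config_cls = transformers.CONFIG_MAPPING[audio_config.get("model_type", "whisper")]
print(config_cls.__name__)  # WhisperConfig
```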
@@ -149,6 +155,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
             if isinstance(audio_model_lora_config, dict)
             else dataclasses.asdict(audio_model_lora_config or LoraConfigSimplified())
         )
+        self.audio_latency_block_size = audio_latency_block_size
 
         self.vocab_size = self.text_config.vocab_size
 
@@ -162,7 +169,12 @@ class UltravoxConfig(transformers.PretrainedConfig):
162
  # remove text_config and audio_config if text_model_id and audio_model_id are present
163
  if self.text_model_id is not None:
164
  diff_dict.pop("text_config", None)
 
 
 
165
  if self.audio_model_id is not None:
166
  diff_dict.pop("audio_config", None)
 
 
167
 
168
  return diff_dict
 