likunchang committed
Commit 7e078c9 · 1 Parent(s): 100feec
.gitignore ADDED
@@ -0,0 +1,11 @@
1
+ wandb
2
+ __pycache__
3
+ .vscode
4
+ notebooks
5
+ results
6
+ *.ipynb_checkpoints
7
+ eval_results
8
+ tests
9
+ .DS_Store
10
+ gradio.sh
11
+ debug*
app.py ADDED
@@ -0,0 +1,530 @@
1
+ import spaces
2
+ import gradio as gr
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ import random
7
+ import subprocess
8
+ subprocess.run(
9
+ "pip install flash-attn --no-build-isolation",
10
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
11
+ shell=True,
12
+ )
13
+
14
+ from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
15
+ from PIL import Image
16
+
17
+ from data.data_utils import add_special_tokens, pil_img2rgb
18
+ from data.transforms import ImageTransform
19
+ from inferencer import InterleaveInferencer
20
+ from modeling.autoencoder import load_ae
21
+ from modeling.bagel import (
22
+ BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
23
+ SiglipVisionConfig, SiglipVisionModel
24
+ )
25
+ from modeling.qwen2 import Qwen2Tokenizer
26
+
27
+ from huggingface_hub import snapshot_download
28
+
29
+ model_path = "/model"
30
+ repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
31
+ cache_dir = model_path + "/cache"
32
+
33
+ snapshot_download(cache_dir=cache_dir,
34
+ local_dir=model_path,
35
+ repo_id=repo_id,
36
+ local_dir_use_symlinks=False,
37
+ resume_download=True,
38
+ allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
39
+ )
40
+
41
+ # Model Initialization
42
+
43
+ llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
44
+ llm_config.qk_norm = True
45
+ llm_config.tie_word_embeddings = False
46
+ llm_config.layer_module = "Qwen2MoTDecoderLayer"
47
+
48
+ vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
49
+ vit_config.rope = False
50
+ vit_config.num_hidden_layers -= 1
51
+
52
+ vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
53
+
54
+ config = BagelConfig(
55
+ visual_gen=True,
56
+ visual_und=True,
57
+ llm_config=llm_config,
58
+ vit_config=vit_config,
59
+ vae_config=vae_config,
60
+ vit_max_num_patch_per_side=70,
61
+ connector_act='gelu_pytorch_tanh',
62
+ latent_patch_size=2,
63
+ max_latent_size=64,
64
+ )
65
+
66
+ with init_empty_weights():
67
+ language_model = Qwen2ForCausalLM(llm_config)
68
+ vit_model = SiglipVisionModel(vit_config)
69
+ model = Bagel(language_model, vit_model, config)
70
+ model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
71
+
72
+ tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
73
+ tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
74
+
75
+ vae_transform = ImageTransform(1024, 512, 16)
76
+ vit_transform = ImageTransform(980, 224, 14)
77
+
78
+ # Model Loading and Multi-GPU Inference Preparation
79
+ device_map = infer_auto_device_map(
80
+ model,
81
+ max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
82
+ no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
83
+ )
84
+
85
+ same_device_modules = [
86
+ 'language_model.model.embed_tokens',
87
+ 'time_embedder',
88
+ 'latent_pos_embed',
89
+ 'vae2llm',
90
+ 'llm2vae',
91
+ 'connector',
92
+ 'vit_pos_embed'
93
+ ]
94
+
95
+ if torch.cuda.device_count() == 1:
96
+ first_device = device_map.get(same_device_modules[0], "cuda:0")
97
+ for k in same_device_modules:
98
+ if k in device_map:
99
+ device_map[k] = first_device
100
+ else:
101
+ device_map[k] = "cuda:0"
102
+ else:
103
+ first_device = device_map.get(same_device_modules[0])
104
+ for k in same_device_modules:
105
+ if k in device_map:
106
+ device_map[k] = first_device
107
+
108
+ model = load_checkpoint_and_dispatch(
109
+ model,
110
+ checkpoint=os.path.join(model_path, "ema.safetensors"),
111
+ device_map=device_map,
112
+ offload_buffers=True,
113
+ offload_folder="offload",
114
+ dtype=torch.bfloat16,
115
+ force_hooks=True,
116
+ ).eval()
117
+
118
+
119
+ # Inferencer Preparation
120
+ inferencer = InterleaveInferencer(
121
+ model=model,
122
+ vae_model=vae_model,
123
+ tokenizer=tokenizer,
124
+ vae_transform=vae_transform,
125
+ vit_transform=vit_transform,
126
+ new_token_ids=new_token_ids,
127
+ )
128
+
129
+ def set_seed(seed):
130
+ """Set random seeds for reproducibility"""
131
+ if seed > 0:
132
+ random.seed(seed)
133
+ np.random.seed(seed)
134
+ torch.manual_seed(seed)
135
+ if torch.cuda.is_available():
136
+ torch.cuda.manual_seed(seed)
137
+ torch.cuda.manual_seed_all(seed)
138
+ torch.backends.cudnn.deterministic = True
139
+ torch.backends.cudnn.benchmark = False
140
+ return seed
141
+
142
+ # Text to Image function with thinking option and hyperparameters
143
+ @spaces.GPU(duration=90)
144
+ def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
145
+ timestep_shift=3.0, num_timesteps=50,
146
+ cfg_renorm_min=1.0, cfg_renorm_type="global",
147
+ max_think_token_n=1024, do_sample=False, text_temperature=0.3,
148
+ seed=0, image_ratio="1:1"):
149
+ # Set seed for reproducibility
150
+ set_seed(seed)
151
+
152
+ if image_ratio == "1:1":
153
+ image_shapes = (1024, 1024)
154
+ elif image_ratio == "4:3":
155
+ image_shapes = (768, 1024)
156
+ elif image_ratio == "3:4":
157
+ image_shapes = (1024, 768)
158
+ elif image_ratio == "16:9":
159
+ image_shapes = (576, 1024)
160
+ elif image_ratio == "9:16":
161
+ image_shapes = (1024, 576)
162
+
163
+ # Set hyperparameters
164
+ inference_hyper = dict(
165
+ max_think_token_n=max_think_token_n if show_thinking else 1024,
166
+ do_sample=do_sample if show_thinking else False,
167
+ text_temperature=text_temperature if show_thinking else 0.3,
168
+ cfg_text_scale=cfg_text_scale,
169
+ cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
170
+ timestep_shift=timestep_shift,
171
+ num_timesteps=num_timesteps,
172
+ cfg_renorm_min=cfg_renorm_min,
173
+ cfg_renorm_type=cfg_renorm_type,
174
+ image_shapes=image_shapes,
175
+ )
176
+
177
+ # Call inferencer with or without think parameter based on user choice
178
+ result = inferencer(text=prompt, think=show_thinking, **inference_hyper)
179
+ return result["image"], result.get("text", None)
180
+
181
+
182
+ # Image Understanding function with thinking option and hyperparameters
183
+ @spaces.GPU(duration=90)
184
+ def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
185
+ do_sample=False, text_temperature=0.3, max_new_tokens=512):
186
+ if image is None:
187
+ return "Please upload an image."
188
+
189
+ if isinstance(image, np.ndarray):
190
+ image = Image.fromarray(image)
191
+
192
+ image = pil_img2rgb(image)
193
+
194
+ # Set hyperparameters
195
+ inference_hyper = dict(
196
+ do_sample=do_sample,
197
+ text_temperature=text_temperature,
198
+ max_think_token_n=max_new_tokens, # Set max_length
199
+ )
200
+
201
+ # Use show_thinking parameter to control thinking process
202
+ result = inferencer(image=image, text=prompt, think=show_thinking,
203
+ understanding_output=True, **inference_hyper)
204
+ return result["text"]
205
+
206
+
207
+ # Image Editing function with thinking option and hyperparameters
208
+ @spaces.GPU(duration=90)
209
+ def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
210
+ cfg_img_scale=2.0, cfg_interval=0.0,
211
+ timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
212
+ cfg_renorm_type="text_channel", max_think_token_n=1024,
213
+ do_sample=False, text_temperature=0.3, seed=0):
214
+ # Set seed for reproducibility
215
+ set_seed(seed)
216
+
217
+ if image is None:
218
+ return "Please upload an image.", ""
219
+
220
+ if isinstance(image, np.ndarray):
221
+ image = Image.fromarray(image)
222
+
223
+ image = pil_img2rgb(image)
224
+
225
+ # Set hyperparameters
226
+ inference_hyper = dict(
227
+ max_think_token_n=max_think_token_n if show_thinking else 1024,
228
+ do_sample=do_sample if show_thinking else False,
229
+ text_temperature=text_temperature if show_thinking else 0.3,
230
+ cfg_text_scale=cfg_text_scale,
231
+ cfg_img_scale=cfg_img_scale,
232
+ cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
233
+ timestep_shift=timestep_shift,
234
+ num_timesteps=num_timesteps,
235
+ cfg_renorm_min=cfg_renorm_min,
236
+ cfg_renorm_type=cfg_renorm_type,
237
+ )
238
+
239
+ # Include thinking parameter based on user choice
240
+ result = inferencer(image=image, text=prompt, think=show_thinking, **inference_hyper)
241
+ return result["image"], result.get("text", "")
242
+
243
+
244
+ # Helper function to load example images
245
+ def load_example_image(image_path):
246
+ try:
247
+ return Image.open(image_path)
248
+ except Exception as e:
249
+ print(f"Error loading example image: {e}")
250
+ return None
251
+
252
+
253
+ # Gradio UI
254
+ with gr.Blocks() as demo:
255
+ gr.Markdown("""
256
+ <div>
257
+ <img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="380"/>
258
+ </div>
259
+ """)
260
+
261
+ with gr.Tab("📝 Text to Image"):
262
+ txt_input = gr.Textbox(
263
+ label="Prompt",
264
+ value="A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
265
+ )
266
+
267
+ with gr.Row():
268
+ show_thinking = gr.Checkbox(label="Thinking", value=False)
269
+
270
+ # Add hyperparameter controls in an accordion
271
+ with gr.Accordion("Inference Hyperparameters", open=False):
272
+ # Lay out the parameters two per row
273
+ with gr.Group():
274
+ with gr.Row():
275
+ seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1,
276
+ label="Seed", info="0 for random seed, positive for reproducible results")
277
+ image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
278
+ value="1:1", label="Image Ratio",
279
+ info="The longer size is fixed to 1024")
280
+
281
+ with gr.Row():
282
+ cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
283
+ label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0)")
284
+ cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
285
+ label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
286
+
287
+ with gr.Row():
288
+ cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
289
+ value="global", label="CFG Renorm Type",
290
+ info="If the generated image is blurry, use 'global'")
291
+ cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
292
+ label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
293
+
294
+ with gr.Row():
295
+ num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
296
+ label="Timesteps", info="Total denoising steps")
297
+ timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, interactive=True,
298
+ label="Timestep Shift", info="Higher values for layout, lower for details")
299
+
300
+ # Thinking parameters in a single row
301
+ thinking_params = gr.Group(visible=False)
302
+ with thinking_params:
303
+ with gr.Row():
304
+ do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
305
+ max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
306
+ label="Max Think Tokens", info="Maximum number of tokens for thinking")
307
+ text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
308
+ label="Temperature", info="Controls randomness in text generation")
309
+
310
+ thinking_output = gr.Textbox(label="Thinking Process", visible=False)
311
+ img_output = gr.Image(label="Generated Image")
312
+ gen_btn = gr.Button("Generate", variant="primary")
313
+
314
+ # Dynamically show/hide thinking process box and parameters
315
+ def update_thinking_visibility(show):
316
+ return gr.update(visible=show), gr.update(visible=show)
317
+
318
+ show_thinking.change(
319
+ fn=update_thinking_visibility,
320
+ inputs=[show_thinking],
321
+ outputs=[thinking_output, thinking_params]
322
+ )
323
+
324
+ # Process function based on thinking option and hyperparameters
325
+ def process_text_to_image(prompt, show_thinking, cfg_text_scale,
326
+ cfg_interval, timestep_shift,
327
+ num_timesteps, cfg_renorm_min, cfg_renorm_type,
328
+ max_think_token_n, do_sample, text_temperature, seed, image_ratio):
329
+ image, thinking = text_to_image(
330
+ prompt, show_thinking, cfg_text_scale, cfg_interval,
331
+ timestep_shift, num_timesteps,
332
+ cfg_renorm_min, cfg_renorm_type,
333
+ max_think_token_n, do_sample, text_temperature, seed, image_ratio
334
+ )
335
+ return image, thinking if thinking else ""
336
+
337
+ gr.on(
338
+ triggers=[gen_btn.click, txt_input.submit],
339
+ fn=process_text_to_image,
340
+ inputs=[
341
+ txt_input, show_thinking, cfg_text_scale,
342
+ cfg_interval, timestep_shift,
343
+ num_timesteps, cfg_renorm_min, cfg_renorm_type,
344
+ max_think_token_n, do_sample, text_temperature, seed, image_ratio
345
+ ],
346
+ outputs=[img_output, thinking_output]
347
+ )
348
+
349
+ with gr.Tab("🖌️ Image Edit"):
350
+ with gr.Row():
351
+ with gr.Column(scale=1):
352
+ edit_image_input = gr.Image(label="Input Image", value=load_example_image('test_images/women.jpg'))
353
+ edit_prompt = gr.Textbox(
354
+ label="Prompt",
355
+ value="She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes."
356
+ )
357
+
358
+ with gr.Column(scale=1):
359
+ edit_image_output = gr.Image(label="Result")
360
+ edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)
361
+
362
+ with gr.Row():
363
+ edit_show_thinking = gr.Checkbox(label="Thinking", value=False)
364
+
365
+ # Add hyperparameter controls in an accordion
366
+ with gr.Accordion("Inference Hyperparameters", open=False):
367
+ with gr.Group():
368
+ with gr.Row():
369
+ edit_seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, interactive=True,
370
+ label="Seed", info="0 for random seed, positive for reproducible results")
371
+ edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
372
+ label="CFG Text Scale", info="Controls how strongly the model follows the text prompt")
373
+
374
+ with gr.Row():
375
+ edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, interactive=True,
376
+ label="CFG Image Scale", info="Controls how much the model preserves input image details")
377
+ edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
378
+ label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
379
+
380
+ with gr.Row():
381
+ edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
382
+ value="text_channel", label="CFG Renorm Type",
383
+ info="If the generated image is blurry, use 'global'")
384
+ edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
385
+ label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
386
+
387
+ with gr.Row():
388
+ edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
389
+ label="Timesteps", info="Total denoising steps")
390
+ edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, interactive=True,
391
+ label="Timestep Shift", info="Higher values for layout, lower for details")
392
+
393
+
394
+ # Thinking parameters in a single row
395
+ edit_thinking_params = gr.Group(visible=False)
396
+ with edit_thinking_params:
397
+ with gr.Row():
398
+ edit_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
399
+ edit_max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
400
+ label="Max Think Tokens", info="Maximum number of tokens for thinking")
401
+ edit_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
402
+ label="Temperature", info="Controls randomness in text generation")
403
+
404
+ edit_btn = gr.Button("Submit", variant="primary")
405
+
406
+ # Dynamically show/hide thinking process box for editing
407
+ def update_edit_thinking_visibility(show):
408
+ return gr.update(visible=show), gr.update(visible=show)
409
+
410
+ edit_show_thinking.change(
411
+ fn=update_edit_thinking_visibility,
412
+ inputs=[edit_show_thinking],
413
+ outputs=[edit_thinking_output, edit_thinking_params]
414
+ )
415
+
416
+ # Process editing with thinking option and hyperparameters
417
+ def process_edit_image(image, prompt, show_thinking, cfg_text_scale,
418
+ cfg_img_scale, cfg_interval,
419
+ timestep_shift, num_timesteps, cfg_renorm_min,
420
+ cfg_renorm_type, max_think_token_n, do_sample,
421
+ text_temperature, seed):
422
+ edited_image, thinking = edit_image(
423
+ image, prompt, show_thinking, cfg_text_scale, cfg_img_scale,
424
+ cfg_interval, timestep_shift,
425
+ num_timesteps, cfg_renorm_min, cfg_renorm_type,
426
+ max_think_token_n, do_sample, text_temperature, seed
427
+ )
428
+
429
+ return edited_image, thinking if thinking else ""
430
+
431
+ gr.on(
432
+ triggers=[edit_btn.click, edit_prompt.submit],
433
+ fn=process_edit_image,
434
+ inputs=[
435
+ edit_image_input, edit_prompt, edit_show_thinking,
436
+ edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
437
+ edit_timestep_shift, edit_num_timesteps,
438
+ edit_cfg_renorm_min, edit_cfg_renorm_type,
439
+ edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed
440
+ ],
441
+ outputs=[edit_image_output, edit_thinking_output]
442
+ )
443
+
444
+ with gr.Tab("🖼️ Image Understanding"):
445
+ with gr.Row():
446
+ with gr.Column(scale=1):
447
+ img_input = gr.Image(label="Input Image", value=load_example_image('test_images/meme.jpg'))
448
+ understand_prompt = gr.Textbox(
449
+ label="Prompt",
450
+ value="Can someone explain what's funny about this meme??"
451
+ )
452
+
453
+ with gr.Column(scale=1):
454
+ txt_output = gr.Textbox(label="Result", lines=20)
455
+
456
+ with gr.Row():
457
+ understand_show_thinking = gr.Checkbox(label="Thinking", value=False)
458
+
459
+ # Add hyperparameter controls in an accordion
460
+ with gr.Accordion("Inference Hyperparameters", open=False):
461
+ with gr.Row():
462
+ understand_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
463
+ understand_text_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True,
464
+ label="Temperature", info="Controls randomness in text generation (0=deterministic, 1=creative)")
465
+ understand_max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=512, step=64, interactive=True,
466
+ label="Max New Tokens", info="Maximum length of generated text, including potential thinking")
467
+
468
+ img_understand_btn = gr.Button("Submit", variant="primary")
469
+
470
+ # Process understanding with thinking option and hyperparameters
471
+ def process_understanding(image, prompt, show_thinking, do_sample,
472
+ text_temperature, max_new_tokens):
473
+ result = image_understanding(
474
+ image, prompt, show_thinking, do_sample,
475
+ text_temperature, max_new_tokens
476
+ )
477
+ return result
478
+
479
+ gr.on(
480
+ triggers=[img_understand_btn.click, understand_prompt.submit],
481
+ fn=process_understanding,
482
+ inputs=[
483
+ img_input, understand_prompt, understand_show_thinking,
484
+ understand_do_sample, understand_text_temperature, understand_max_new_tokens
485
+ ],
486
+ outputs=txt_output
487
+ )
488
+
489
+ gr.Markdown("""
490
+ <div style="display: flex; justify-content: flex-start; flex-wrap: wrap; gap: 10px;">
491
+ <a href="https://bagel-ai.org/">
492
+ <img
493
+ src="https://img.shields.io/badge/BAGEL-Website-0A66C2?logo=safari&logoColor=white"
494
+ alt="BAGEL Website"
495
+ />
496
+ </a>
497
+ <a href="https://arxiv.org/abs/2505.14683">
498
+ <img
499
+ src="https://img.shields.io/badge/BAGEL-Paper-red?logo=arxiv&logoColor=red"
500
+ alt="BAGEL Paper on arXiv"
501
+ />
502
+ </a>
503
+ <a href="https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT">
504
+ <img
505
+ src="https://img.shields.io/badge/BAGEL-Hugging%20Face-orange?logo=huggingface&logoColor=yellow"
506
+ alt="BAGEL on Hugging Face"
507
+ />
508
+ </a>
509
+ <a href="https://demo.bagel-ai.org/">
510
+ <img
511
+ src="https://img.shields.io/badge/BAGEL-Demo-blue?logo=googleplay&logoColor=blue"
512
+ alt="BAGEL Demo"
513
+ />
514
+ </a>
515
+ <a href="https://discord.gg/Z836xxzy">
516
+ <img
517
+ src="https://img.shields.io/badge/BAGEL-Discord-5865F2?logo=discord&logoColor=purple"
518
+ alt="BAGEL Discord"
519
+ />
520
+ </a>
521
+ <a href="mailto:[email protected]">
522
+ <img
523
+ src="https://img.shields.io/badge/BAGEL-Email-D14836?logo=gmail&logoColor=red"
524
+ alt="BAGEL Email"
525
+ />
526
+ </a>
527
+ </div>
528
+ """)
529
+
530
+ demo.launch()
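
For readers following the app code above, this is a minimal sketch of how the same `inferencer` call that `text_to_image` wraps could be used outside Gradio. It assumes the model, transforms, and `inferencer` have been constructed exactly as in app.py; the prompt string and output filename are illustrative only.

# Sketch: direct call to the InterleaveInferencer built above (assumes app.py's setup has run).
hyper = dict(
    cfg_text_scale=4.0,
    cfg_interval=[0.4, 1.0],    # end of the CFG interval is fixed at 1.0
    timestep_shift=3.0,
    num_timesteps=50,
    cfg_renorm_min=0.0,
    cfg_renorm_type="global",
    image_shapes=(1024, 1024),  # (H, W) for a 1:1 image
)
result = inferencer(text="A watercolor bagel on a wooden table", think=False, **hyper)
result["image"].save("sample.png")   # PIL.Image
print(result.get("text"))            # None unless think=True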
data/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
data/data_utils.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ import math
6
+ import random
7
+ from PIL import Image
8
+
9
+ import torch
10
+ from torch.nn.attention.flex_attention import or_masks, and_masks
11
+
12
+
13
+ def create_sparse_mask(document_lens, split_lens, attn_modes, device):
14
+ def causal_mask(b, h, q_idx, kv_idx):
15
+ return q_idx >= kv_idx
16
+
17
+ def full_and_noise_mask(b, h, q_idx, kv_idx):
18
+ return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (full_and_noise_seq_id[q_idx] >= 0)
19
+
20
+ def remove_noise_mask(b, h, q_idx, kv_idx):
21
+ return (~((noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])))
22
+
23
+ def sample_mask(b, h, q_idx, kv_idx):
24
+ return document_id[q_idx] == document_id[kv_idx]
25
+
26
+ full_and_noise_tmp = []
27
+ noise_tmp = []
28
+
29
+ for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
30
+ value = i if model in ['full', 'noise'] else -1
31
+ full_and_noise_tmp.extend([value] * length)
32
+ value_noise = i if model == 'noise' else -1
33
+ noise_tmp.extend([value_noise] * length)
34
+
35
+ full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
36
+ noise_seq_id = torch.Tensor(noise_tmp).to(device)
37
+
38
+ document_id = torch.cat([torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]).to(device)
39
+
40
+ return and_masks(or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask)
41
+
42
+
43
+ def patchify(image, patch_size):
44
+ p = patch_size
45
+ c, h, w = image.shape
46
+ assert h % p == 0 and w % p == 0
47
+ image = image.reshape(c, h // p, p, w // p, p)
48
+ image = torch.einsum("chpwq->hwpqc", image)
49
+ image = image.reshape(-1, p**2 * c)
50
+ return image
51
+
52
+
53
+ def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
54
+ num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
55
+ coords_h = torch.arange(0, num_patches_h)
56
+ coords_w = torch.arange(0, num_patches_w)
57
+ pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
58
+ return pos_ids
59
+
60
+
61
+ def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
62
+ num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
63
+ boundaries = torch.arange(1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side)
64
+ fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
65
+ fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
66
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
67
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
68
+ pos_ids = (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten()
69
+ return pos_ids
70
+
71
+
72
+ def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
73
+ """
74
+ split_lens: A list of ints. Each int is the length of one split within a sample,
75
+ where each sample contains multiple splits with different attention modes.
76
+ attn_modes: the attention mode ('causal', 'full', or 'noise') used in each split.
77
+ """
78
+ sample_len = sum(split_lens)
79
+ attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)
80
+
81
+ csum = 0
82
+ for s, attn_mode in zip(split_lens, attn_modes):
83
+ assert attn_mode in ['causal', 'full', 'noise']
84
+ if attn_mode == "causal":
85
+ attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s), device=device).tril()
86
+ attention_mask[csum:csum + s, :csum] = 1
87
+ else:
88
+ attention_mask[csum:csum + s, csum:csum + s] = torch.ones((s, s))
89
+ attention_mask[csum:csum + s, :csum] = 1
90
+ csum += s
91
+
92
+ csum = 0
93
+ for s, attn_mode in zip(split_lens, attn_modes):
94
+ if attn_mode == "noise":
95
+ attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
96
+ attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
97
+ csum += s
98
+
99
+ attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
100
+ ~attention_mask, float("-inf")
101
+ )
102
+
103
+ return attention_mask
104
+
105
+
106
+ def split_integer_exp_decay(S, ng_sample_decay=1.0):
107
+ if ng_sample_decay == 1.0:
108
+ N = random.randint(1, S)
109
+ else:
110
+ base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
111
+ p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
112
+ N = random.choices(list(range(1, S + 1)), p, k=1)[0]
113
+ cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
114
+ result = [cumsum[i+1] - cumsum[i] for i in range(len(cumsum) - 1)]
115
+ return result, cumsum
116
+
117
+
118
+ def pil_img2rgb(image):
119
+ if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
120
+ image = image.convert("RGBA")
121
+ white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
122
+ white.paste(image, mask=image.split()[3])
123
+ image = white
124
+ else:
125
+ image = image.convert("RGB")
126
+
127
+ return image
128
+
129
+
130
+ def add_special_tokens(tokenizer):
131
+ all_special_tokens = []
132
+ for k, v in tokenizer.special_tokens_map.items():
133
+ if isinstance(v, str):
134
+ all_special_tokens.append(v)
135
+ elif isinstance(v, list):
136
+ all_special_tokens += v
137
+
138
+ new_tokens = []
139
+
140
+ if '<|im_start|>' not in all_special_tokens:
141
+ new_tokens.append('<|im_start|>')
142
+
143
+ if '<|im_end|>' not in all_special_tokens:
144
+ new_tokens.append('<|im_end|>')
145
+
146
+ if '<|vision_start|>' not in all_special_tokens:
147
+ new_tokens.append('<|vision_start|>')
148
+
149
+ if '<|vision_end|>' not in all_special_tokens:
150
+ new_tokens.append('<|vision_end|>')
151
+
152
+ num_new_tokens = tokenizer.add_tokens(new_tokens)
153
+ bos_token_id = tokenizer.convert_tokens_to_ids('<|im_start|>')
154
+ eos_token_id = tokenizer.convert_tokens_to_ids('<|im_end|>')
155
+ start_of_image = tokenizer.convert_tokens_to_ids('<|vision_start|>')
156
+ end_of_image = tokenizer.convert_tokens_to_ids('<|vision_end|>')
157
+
158
+ new_token_ids = dict(
159
+ bos_token_id=bos_token_id,
160
+ eos_token_id=eos_token_id,
161
+ start_of_image=start_of_image,
162
+ end_of_image=end_of_image,
163
+ )
164
+
165
+ return tokenizer, new_token_ids, num_new_tokens
166
+
167
+
168
+ def len2weight(x, loss_reduction='square'):
169
+ if x == 0:
170
+ return x
171
+ if loss_reduction == 'token':
172
+ return 1
173
+ if loss_reduction == 'sample':
174
+ return 1 / x
175
+ if loss_reduction == 'square':
176
+ return 1 / (x ** 0.5)
177
+ raise NotImplementedError(loss_reduction)
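
As a small illustration of `prepare_attention_mask_per_sample` above, the following sketch builds the additive mask for one sample made of a causal text split followed by a 'full' image split; the split lengths are arbitrary example values.

from data.data_utils import prepare_attention_mask_per_sample

# One sample: 3 causal text tokens followed by 2 image tokens with full attention.
mask = prepare_attention_mask_per_sample(
    split_lens=[3, 2],
    attn_modes=["causal", "full"],
    device="cpu",
)
assert mask.shape == (5, 5)
assert mask[0, 1].item() == float("-inf")   # causal: token 0 cannot attend to token 1
assert mask[3, 4].item() == 0.0             # full: image tokens attend to each other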
data/transforms.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import random
5
+ from PIL import Image
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+ from torchvision import transforms
11
+ from torchvision.transforms import functional as F
12
+ from torchvision.transforms import InterpolationMode
13
+
14
+
15
+ class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
16
+ """Resize the input image so that its longest side and shortest side are within a specified range,
17
+ ensuring that both sides are divisible by a specified stride.
18
+
19
+ Args:
20
+ max_size (int): Maximum size for the longest edge of the image.
21
+ min_size (int): Minimum size for the shortest edge of the image.
22
+ stride (int): Value by which the height and width of the image must be divisible.
23
+ max_pixels (int): Maximum pixels for the full image.
24
+ interpolation (InterpolationMode): Desired interpolation enum defined by
25
+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
26
+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
27
+ ``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
28
+ The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
29
+ antialias (bool, optional): Whether to apply antialiasing (default is True).
30
+ """
31
+
32
+ def __init__(
33
+ self,
34
+ max_size: int,
35
+ min_size: int,
36
+ stride: int,
37
+ max_pixels: int,
38
+ interpolation=InterpolationMode.BICUBIC,
39
+ antialias=True
40
+ ):
41
+ super().__init__()
42
+ self.max_size = max_size
43
+ self.min_size = min_size
44
+ self.stride = stride
45
+ self.max_pixels = max_pixels
46
+ self.interpolation = interpolation
47
+ self.antialias = antialias
48
+
49
+ def _make_divisible(self, value, stride):
50
+ """Ensure the value is divisible by the stride."""
51
+ return max(stride, int(round(value / stride) * stride))
52
+
53
+ def _apply_scale(self, width, height, scale):
54
+ new_width = round(width * scale)
55
+ new_height = round(height * scale)
56
+ new_width = self._make_divisible(new_width, self.stride)
57
+ new_height = self._make_divisible(new_height, self.stride)
58
+ return new_width, new_height
59
+
60
+ def forward(self, img, img_num=1):
61
+ """
62
+ Args:
63
+ img (PIL Image): Image to be resized.
64
+ img_num (int): Number of images sharing the pixel budget; used to scale down max_pixels per image.
65
+ Returns:
66
+ PIL Image or Tensor: Rescaled image with divisible dimensions.
67
+ """
68
+ if isinstance(img, torch.Tensor):
69
+ height, width = img.shape[-2:]
70
+ else:
71
+ width, height = img.size
72
+
73
+ scale = min(self.max_size / max(width, height), 1.0)
74
+ scale = max(scale, self.min_size / min(width, height))
75
+ new_width, new_height = self._apply_scale(width, height, scale)
76
+
77
+ # Ensure the number of pixels does not exceed max_pixels
78
+ if new_width * new_height > self.max_pixels / img_num:
79
+ scale = self.max_pixels / img_num / (new_width * new_height)
80
+ new_width, new_height = self._apply_scale(new_width, new_height, scale)
81
+
82
+ # Ensure longest edge does not exceed max_size
83
+ if max(new_width, new_height) > self.max_size:
84
+ scale = self.max_size / max(new_width, new_height)
85
+ new_width, new_height = self._apply_scale(new_width, new_height, scale)
86
+
87
+ return F.resize(img, (new_height, new_width), self.interpolation, antialias=self.antialias)
88
+
89
+
90
+ class ImageTransform:
91
+ def __init__(
92
+ self,
93
+ max_image_size,
94
+ min_image_size,
95
+ image_stride,
96
+ max_pixels=14*14*9*1024,
97
+ image_mean=[0.5, 0.5, 0.5],
98
+ image_std=[0.5, 0.5, 0.5]
99
+ ):
100
+ self.stride = image_stride
101
+
102
+ self.resize_transform = MaxLongEdgeMinShortEdgeResize(
103
+ max_size=max_image_size,
104
+ min_size=min_image_size,
105
+ stride=image_stride,
106
+ max_pixels=max_pixels,
107
+ )
108
+ self.to_tensor_transform = transforms.ToTensor()
109
+ self.normalize_transform = transforms.Normalize(mean=image_mean, std=image_std, inplace=True)
110
+
111
+ def __call__(self, img, img_num=1):
112
+ img = self.resize_transform(img, img_num=img_num)
113
+ img = self.to_tensor_transform(img)
114
+ img = self.normalize_transform(img)
115
+ return img
116
+
117
+
118
+ def decolorization(image):
119
+ gray_image = image.convert('L')
120
+ return Image.merge(image.mode, [gray_image] * 3) if image.mode in ('RGB', 'L') else gray_image
121
+
122
+
123
+ def downscale(image, scale_factor):
124
+ new_width = int(round(image.width * scale_factor))
125
+ new_height = int(round(image.height * scale_factor))
126
+ new_width = max(1, new_width)
127
+ new_height = max(1, new_height)
128
+ return image.resize((new_width, new_height), resample=Image.BICUBIC)
129
+
130
+
131
+ def crop(image, crop_factors):
132
+ target_h, target_w = crop_factors
133
+ img_w, img_h = image.size
134
+
135
+ if target_h > img_h or target_w > img_w:
136
+ raise ValueError("Crop size exceeds image dimensions")
137
+
138
+ x = random.randint(0, img_w - target_w)
139
+ y = random.randint(0, img_h - target_h)
140
+
141
+ return image.crop((x, y, x + target_w, y + target_h)), [[x, y], [x + target_w, y + target_h]]
142
+
143
+
144
+ def motion_blur_opencv(image, kernel_size=15, angle=0):
145
+ # Build a horizontal line kernel
146
+ kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
147
+ kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)
148
+
149
+ # Rotate the kernel to the requested angle
150
+ center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
151
+ M = cv2.getRotationMatrix2D(center, angle, 1)
152
+ rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))
153
+
154
+ # Normalize the kernel
155
+ rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1
156
+
157
+ img = np.array(image)
158
+ if img.ndim == 2:
159
+ blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
160
+ else:
161
+ # For color images, filter each channel independently
162
+ blurred = np.zeros_like(img)
163
+ for c in range(img.shape[2]):
164
+ blurred[..., c] = cv2.filter2D(img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
165
+
166
+ return Image.fromarray(blurred.astype(np.uint8))
167
+
168
+
169
+ def shuffle_patch(image, num_splits, gap_size=2):
170
+ """Split the image into patches (sizes need not divide evenly), shuffle them, and stitch them back together with gaps between patches."""
171
+ h_splits, w_splits = num_splits
172
+ img_w, img_h = image.size
173
+
174
+ base_patch_h = img_h // h_splits
175
+ patch_heights = [base_patch_h] * (h_splits - 1)
176
+ patch_heights.append(img_h - sum(patch_heights))
177
+
178
+ base_patch_w = img_w // w_splits
179
+ patch_widths = [base_patch_w] * (w_splits - 1)
180
+ patch_widths.append(img_w - sum(patch_widths))
181
+
182
+ patches = []
183
+ current_y = 0
184
+ for i in range(h_splits):
185
+ current_x = 0
186
+ patch_h = patch_heights[i]
187
+ for j in range(w_splits):
188
+ patch_w = patch_widths[j]
189
+ patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
190
+ patches.append(patch)
191
+ current_x += patch_w
192
+ current_y += patch_h
193
+
194
+ random.shuffle(patches)
195
+
196
+ total_width = sum(patch_widths) + (w_splits - 1) * gap_size
197
+ total_height = sum(patch_heights) + (h_splits - 1) * gap_size
198
+ new_image = Image.new(image.mode, (total_width, total_height), color=(255, 255, 255))
199
+
200
+ current_y = 0 # starting Y coordinate of the current row
201
+ patch_idx = 0 # index of the patch currently being placed
202
+ for i in range(h_splits):
203
+ current_x = 0 # starting X coordinate of the current column
204
+ patch_h = patch_heights[i] # height of patches in the current row
205
+ for j in range(w_splits):
206
+ # Take the next shuffled patch
207
+ patch = patches[patch_idx]
208
+ patch_w = patch_widths[j] # width of patches in the current column
209
+ # Paste the patch with its top-left corner at (current_x, current_y)
210
+ new_image.paste(patch, (current_x, current_y))
211
+ # Advance X: the next patch starts after the current patch width plus the gap
212
+ current_x += patch_w + gap_size
213
+ patch_idx += 1
214
+ # Advance Y: the next row starts after the current row height plus the gap
215
+ current_y += patch_h + gap_size
216
+
217
+ return new_image
218
+
219
+
220
+ def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
221
+ """
222
+ Split the image into patches and blank out a random subset, for inpainting-style tasks.
223
+
224
+ Args:
225
+ image: PIL.Image input image (RGB mode)
226
+ num_splits: (h_splits, w_splits) tuple giving the number of splits along the
227
+ vertical and horizontal directions
228
+ blank_ratio: float, fraction of patches to blank out (0-1)
229
+ blank_color: tuple, RGB color of the blanked regions (e.g., white (255, 255, 255))
230
+
231
+ Returns:
232
+ PIL.Image with the selected patches blanked out, reassembled at the original size
233
+ """
234
+ h_splits, w_splits = num_splits
235
+ img_w, img_h = image.size
236
+
237
+ base_patch_h = img_h // h_splits
238
+ patch_heights = [base_patch_h] * (h_splits - 1)
239
+ patch_heights.append(img_h - sum(patch_heights))
240
+
241
+ base_patch_w = img_w // w_splits
242
+ patch_widths = [base_patch_w] * (w_splits - 1)
243
+ patch_widths.append(img_w - sum(patch_widths))
244
+
245
+ patches = []
246
+ current_y = 0
247
+ for i in range(h_splits):
248
+ current_x = 0
249
+ patch_h = patch_heights[i]
250
+ for j in range(w_splits):
251
+ patch_w = patch_widths[j]
252
+ patch = image.crop((current_x, current_y, current_x + patch_w, current_y + patch_h))
253
+ patches.append(patch)
254
+ current_x += patch_w
255
+ current_y += patch_h
256
+
257
+ total_patches = h_splits * w_splits
258
+ num_blank = int(total_patches * blank_ratio)
259
+ num_blank = max(0, min(num_blank, total_patches))
260
+ blank_indices = random.sample(range(total_patches), num_blank)
261
+
262
+ processed_patches = []
263
+ for idx, patch in enumerate(patches):
264
+ if idx in blank_indices:
265
+ blank_patch = Image.new("RGB", patch.size, color=blank_color)
266
+ processed_patches.append(blank_patch)
267
+ else:
268
+ processed_patches.append(patch)
269
+
270
+ # Create the result image (same size as the input)
271
+ result_image = Image.new("RGB", (img_w, img_h))
272
+ current_y = 0
273
+ patch_idx = 0
274
+ for i in range(h_splits):
275
+ current_x = 0
276
+ patch_h = patch_heights[i]
277
+ for j in range(w_splits):
278
+ # Take the next processed patch
279
+ patch = processed_patches[patch_idx]
280
+ patch_w = patch_widths[j]
281
+ # Paste it back at its original position
282
+ result_image.paste(patch, (current_x, current_y))
283
+ current_x += patch_w
284
+ patch_idx += 1
285
+ current_y += patch_h
286
+
287
+ return result_image
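
A brief sketch of how `ImageTransform` above is meant to be used, mirroring the VAE settings in app.py; the input size is arbitrary and only illustrates the stride-divisible resize.

from PIL import Image
from data.transforms import ImageTransform

vae_transform = ImageTransform(1024, 512, 16)   # same settings as app.py
img = Image.new("RGB", (1333, 777))             # arbitrary input size
tensor = vae_transform(img)                     # resize -> to_tensor -> normalize
_, h, w = tensor.shape
assert h % 16 == 0 and w % 16 == 0              # both sides divisible by the stride
assert max(h, w) <= 1024                        # longest edge capped at max_image_size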
inferencer.py ADDED
@@ -0,0 +1,315 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from copy import deepcopy
5
+ from typing import List, Dict, Tuple, Optional, Union, Any
6
+ import matplotlib.pyplot as plt
7
+
8
+ from PIL import Image
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from torch.nn.attention.flex_attention import create_block_mask
13
+ from transformers.configuration_utils import PretrainedConfig
14
+ from transformers.modeling_utils import PreTrainedModel
15
+
16
+ from data.data_utils import pil_img2rgb
17
+ from modeling.bagel.qwen2_navit import NaiveCache
18
+
19
+
20
+
21
+ VLM_THINK_SYSTEM_PROMPT = '''You should first think about the reasoning process in the mind and then provide the user with the answer.
22
+ The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here'''
23
+
24
+ GEN_THINK_SYSTEM_PROMPT = '''You should first think about the planning process in the mind and then generate the image.
25
+ The planning process is enclosed within <think> </think> tags, i.e. <think> planning process here </think> image here'''
26
+
27
+
28
+ class InterleaveInferencer:
29
+ def __init__(self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids):
30
+ self.model = model
31
+ self.vae_model = vae_model
32
+ self.tokenizer = tokenizer
33
+ self.vae_transform = vae_transform
34
+ self.vit_transform = vit_transform
35
+ self.new_token_ids = new_token_ids
36
+
37
+ def init_gen_context(self):
38
+ gen_context = {
39
+ 'kv_lens': [0],
40
+ 'ropes': [0],
41
+ 'past_key_values': NaiveCache(self.model.config.llm_config.num_hidden_layers),
42
+ }
43
+ return gen_context
44
+
45
+ @torch.no_grad()
46
+ def update_context_text(self, text, gen_context):
47
+ # Used for interleaved data; currently only a single sample per inference is supported.
48
+
49
+ past_key_values = gen_context['past_key_values']
50
+ kv_lens = gen_context['kv_lens']
51
+ ropes = gen_context['ropes']
52
+ generation_input, kv_lens, ropes = self.model.prepare_prompts(
53
+ curr_kvlens=kv_lens,
54
+ curr_rope=ropes,
55
+ prompts=[text],
56
+ tokenizer=self.tokenizer,
57
+ new_token_ids=self.new_token_ids,
58
+ )
59
+
60
+ past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
61
+ gen_context['kv_lens'] = kv_lens
62
+ gen_context['ropes'] = ropes
63
+ gen_context['past_key_values'] = past_key_values
64
+
65
+ return gen_context
66
+
67
+ @torch.no_grad()
68
+ def update_context_image(self, image, gen_context, vae=True, vit=True):
69
+ # Used for interleaved data; currently only a single sample per inference is supported.
70
+
71
+ assert vae or vit
72
+ past_key_values = gen_context['past_key_values']
73
+ kv_lens = gen_context['kv_lens']
74
+ ropes = gen_context['ropes']
75
+
76
+ if vae:
77
+ ## update vae
78
+ generation_input, kv_lens, ropes = self.model.prepare_vae_images(
79
+ curr_kvlens=kv_lens,
80
+ curr_rope=ropes,
81
+ images=[image],
82
+ transforms=self.vae_transform,
83
+ new_token_ids=self.new_token_ids,
84
+ )
85
+ past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
86
+
87
+ if vit:
88
+ ## update vit
89
+ generation_input, kv_lens, ropes = self.model.prepare_vit_images(
90
+ curr_kvlens=kv_lens,
91
+ curr_rope=ropes,
92
+ images=[image],
93
+ transforms=self.vit_transform,
94
+ new_token_ids=self.new_token_ids,
95
+ )
96
+ past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
97
+
98
+ gen_context['kv_lens'] = kv_lens
99
+ gen_context['ropes'] = ropes
100
+ gen_context['past_key_values'] = past_key_values
101
+
102
+ return gen_context
103
+
104
+ @torch.no_grad()
105
+ def gen_image(
106
+ self,
107
+ image_shape,
108
+ gen_context,
109
+ cfg_text_scale=4.0,
110
+ cfg_img_scale=1.5,
111
+
112
+ cfg_text_precontext=None,
113
+ cfg_img_precontext=None,
114
+ cfg_interval=(0.4, 1.0),
115
+ cfg_renorm_min=0.0,
116
+ cfg_renorm_type="global",
117
+
118
+ num_timesteps=50,
119
+ timestep_shift=3.0
120
+ ):
121
+ # print(cfg_renorm_type)
122
+ past_key_values = gen_context['past_key_values']
123
+ kv_lens = gen_context['kv_lens']
124
+ ropes = gen_context['ropes']
125
+ generation_input = self.model.prepare_vae_latent(
126
+ curr_kvlens=kv_lens,
127
+ curr_rope=ropes,
128
+ image_sizes=[image_shape],
129
+ new_token_ids=self.new_token_ids,
130
+ )
131
+
132
+ # text cfg
133
+ cfg_text_past_key_values = cfg_text_precontext['past_key_values']
134
+ kv_lens_cfg = cfg_text_precontext['kv_lens']
135
+ ropes_cfg = cfg_text_precontext['ropes']
136
+ generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
137
+ curr_kvlens=kv_lens_cfg,
138
+ curr_rope=ropes_cfg,
139
+ image_sizes=[image_shape],
140
+ )
141
+
142
+ # img cfg
143
+ cfg_img_past_key_values = cfg_img_precontext['past_key_values']
144
+ kv_lens_cfg = cfg_img_precontext['kv_lens']
145
+ ropes_cfg = cfg_img_precontext['ropes']
146
+ generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
147
+ curr_kvlens=kv_lens_cfg,
148
+ curr_rope=ropes_cfg,
149
+ image_sizes=[image_shape],
150
+ )
151
+
152
+ unpacked_latent = self.model.generate_image(
153
+ past_key_values=past_key_values,
154
+ cfg_text_past_key_values=cfg_text_past_key_values,
155
+ cfg_img_past_key_values=cfg_img_past_key_values,
156
+ num_timesteps=num_timesteps,
157
+ cfg_text_scale=cfg_text_scale,
158
+ cfg_img_scale=cfg_img_scale,
159
+ cfg_interval=cfg_interval,
160
+ cfg_renorm_min=cfg_renorm_min,
161
+ cfg_renorm_type=cfg_renorm_type,
162
+ timestep_shift=timestep_shift,
163
+ **generation_input,
164
+ cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
165
+ cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
166
+ cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
167
+ cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
168
+ cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
169
+ cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
170
+ cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
171
+ cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
172
+ )
173
+
174
+ image = self.decode_image(unpacked_latent[0], image_shape)
175
+ return image
176
+
177
+
178
+ def decode_image(self, latent, image_shape):
179
+ H, W = image_shape
180
+ h, w = H // self.model.latent_downsample, W // self.model.latent_downsample
181
+
182
+ latent = latent.reshape(1, h, w, self.model.latent_patch_size, self.model.latent_patch_size, self.model.latent_channel)
183
+ latent = torch.einsum("nhwpqc->nchpwq", latent)
184
+ latent = latent.reshape(1, self.model.latent_channel, h * self.model.latent_patch_size, w * self.model.latent_patch_size)
185
+ image = self.vae_model.decode(latent)
186
+ image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
187
+ image = Image.fromarray((image).to(torch.uint8).cpu().numpy())
188
+
189
+ return image
190
+
191
+ @torch.no_grad()
192
+ def gen_text(self, gen_context, max_length: int = 500, do_sample: bool = True, temperature: float = 1.0):
193
+ gen_context = deepcopy(gen_context)
194
+ past_key_values = gen_context['past_key_values']
195
+ kv_lens = gen_context['kv_lens']
196
+ ropes = gen_context['ropes']
197
+
198
+ generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
199
+ unpacked_latent = self.model.generate_text(
200
+ past_key_values=past_key_values,
201
+ max_length=max_length,
202
+ do_sample=do_sample,
203
+ temperature=temperature,
204
+ end_token_id=self.new_token_ids['eos_token_id'],
205
+ **generation_input,
206
+ )
207
+ output = self.tokenizer.decode(unpacked_latent[:,0])
208
+ output = output.split('<|im_end|>')[0].split('<|im_start|>')[1]
209
+ return output
210
+
211
+ @torch.no_grad()
212
+ def interleave_inference(
213
+ self,
214
+ input_lists: List[Union[str, Image.Image]],
215
+ think=False,
216
+ understanding_output=False,
217
+
218
+ max_think_token_n=1000,
219
+ do_sample=False,
220
+ text_temperature=0.3,
221
+ cfg_text_scale=3.0,
222
+ cfg_img_scale=1.5,
223
+ cfg_interval=[0.4, 1.0],
224
+ timestep_shift=3.0,
225
+ num_timesteps=50,
226
+ cfg_renorm_min=0.0,
227
+ cfg_renorm_type="global",
228
+ image_shapes=(1024, 1024),
229
+ ) -> List[Union[str, Image.Image]]:
230
+
231
+ output_list = []
232
+ gen_context = self.init_gen_context()
233
+ cfg_text_context = deepcopy(gen_context)
234
+ cfg_img_context = deepcopy(gen_context)
235
+
236
+ with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
237
+ if think:
238
+ if understanding_output:
239
+ system_prompt = VLM_THINK_SYSTEM_PROMPT
240
+ else:
241
+ system_prompt = GEN_THINK_SYSTEM_PROMPT
242
+ gen_context = self.update_context_text(system_prompt, gen_context)
243
+ cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)
244
+
245
+ for input_term in input_lists:
246
+ if isinstance(input_term, str):
247
+ cfg_text_context = deepcopy(gen_context)
248
+ gen_context = self.update_context_text(input_term, gen_context)
249
+ cfg_img_context = self.update_context_text(input_term, cfg_img_context)
250
+
251
+ elif isinstance(input_term, Image.Image):
252
+ input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term))
253
+ gen_context = self.update_context_image(input_term, gen_context, vae=not understanding_output)
254
+
255
+ image_shapes = input_term.size[::-1]
256
+ cfg_text_context = deepcopy(gen_context)
257
+
258
+ else:
259
+ raise ValueError(f"Unsupported input type: {type(input_term)}")
260
+
261
+ if understanding_output:
262
+ gen_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n)
263
+ output_list.append(gen_text)
264
+
265
+ else:
266
+ if think:
267
+ gen_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n)
268
+ gen_context = self.update_context_text(gen_text, gen_context)
269
+ output_list.append(gen_text)
270
+
271
+ img = self.gen_image(
272
+ image_shapes,
273
+ gen_context,
274
+ cfg_text_precontext=cfg_text_context,
275
+ cfg_img_precontext=cfg_img_context,
276
+
277
+ cfg_text_scale=cfg_text_scale,
278
+ cfg_img_scale=cfg_img_scale,
279
+ cfg_interval=cfg_interval,
280
+ timestep_shift=timestep_shift,
281
+ num_timesteps=num_timesteps,
282
+ cfg_renorm_min=cfg_renorm_min,
283
+ cfg_renorm_type=cfg_renorm_type,
284
+ )
285
+
286
+ output_list.append(img)
287
+
288
+ return output_list
289
+
290
+ def __call__(
291
+ self,
292
+ image: Optional[Image.Image] = None,
293
+ text: Optional[str] = None,
294
+ **kargs
295
+ ) -> Dict[str, Any]:
296
+ output_dict = {'image': None, 'text': None}
297
+
298
+ if image is None and text is None:
299
+ print('Please provide at least one input: either an image or text.')
300
+ return output_dict
301
+
302
+ input_list = []
303
+ if image is not None:
304
+ input_list.append(image)
305
+ if text is not None:
306
+ input_list.append(text)
307
+
308
+ output_list = self.interleave_inference(input_list, **kargs)
309
+
310
+ for i in output_list:
311
+ if isinstance(i, Image.Image):
312
+ output_dict['image'] = i
313
+ elif isinstance(i, str):
314
+ output_dict['text'] = i
315
+ return output_dict
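
For reference, a minimal sketch of the `__call__` convention above when both an image and a text prompt are passed (the image is queued first, then the text, which is the image-editing path used by app.py). The file path and prompt are illustrative only, and `inferencer` is assumed to be constructed as in app.py.

from PIL import Image

source = Image.open("test_images/women.jpg")    # example image referenced by app.py
result = inferencer(
    image=source,
    text="She boards a modern subway, quietly reading a folded newspaper.",
    think=False,
    cfg_text_scale=4.0,
    cfg_img_scale=2.0,
    cfg_renorm_type="text_channel",
)
edited = result["image"]    # PIL.Image from gen_image; result["text"] stays None here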
modeling/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from . import bagel, qwen2, siglip, autoencoder
modeling/autoencoder.py ADDED
@@ -0,0 +1,361 @@
1
+ # Copyright (c) 2024 Black Forest Labs.
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates on 2025-05-20.
6
+ #
7
+ # Original file was released under Apache-2.0, with the full license text
8
+ # available at https://github.com/black-forest-labs/flux/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ from einops import rearrange
16
+ from torch import Tensor, nn
17
+ from huggingface_hub import hf_hub_download
18
+ from safetensors.torch import load_file as load_sft
19
+
20
+
21
+ @dataclass
22
+ class AutoEncoderParams:
23
+ resolution: int
24
+ in_channels: int
25
+ downsample: int
26
+ ch: int
27
+ out_ch: int
28
+ ch_mult: list[int]
29
+ num_res_blocks: int
30
+ z_channels: int
31
+ scale_factor: float
32
+ shift_factor: float
33
+
34
+
35
+ def swish(x: Tensor) -> Tensor:
36
+ return x * torch.sigmoid(x)
37
+
38
+
39
+ class AttnBlock(nn.Module):
40
+ def __init__(self, in_channels: int):
41
+ super().__init__()
42
+ self.in_channels = in_channels
43
+
44
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
45
+
46
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
47
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
48
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
49
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
50
+
51
+ def attention(self, h_: Tensor) -> Tensor:
52
+ h_ = self.norm(h_)
53
+ q = self.q(h_)
54
+ k = self.k(h_)
55
+ v = self.v(h_)
56
+
57
+ b, c, h, w = q.shape
58
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
59
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
60
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
61
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
62
+
63
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
64
+
65
+ def forward(self, x: Tensor) -> Tensor:
66
+ return x + self.proj_out(self.attention(x))
67
+
68
+
69
+ class ResnetBlock(nn.Module):
70
+ def __init__(self, in_channels: int, out_channels: int):
71
+ super().__init__()
72
+ self.in_channels = in_channels
73
+ out_channels = in_channels if out_channels is None else out_channels
74
+ self.out_channels = out_channels
75
+
76
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
77
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
78
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
79
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
80
+ if self.in_channels != self.out_channels:
81
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
82
+
83
+ def forward(self, x):
84
+ h = x
85
+ h = self.norm1(h)
86
+ h = swish(h)
87
+ h = self.conv1(h)
88
+
89
+ h = self.norm2(h)
90
+ h = swish(h)
91
+ h = self.conv2(h)
92
+
93
+ if self.in_channels != self.out_channels:
94
+ x = self.nin_shortcut(x)
95
+
96
+ return x + h
97
+
98
+
99
+ class Downsample(nn.Module):
100
+ def __init__(self, in_channels: int):
101
+ super().__init__()
102
+ # no asymmetric padding in torch conv, must do it ourselves
103
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
104
+
105
+ def forward(self, x: Tensor):
106
+ pad = (0, 1, 0, 1)
107
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
108
+ x = self.conv(x)
109
+ return x
110
+
111
+
112
+ class Upsample(nn.Module):
113
+ def __init__(self, in_channels: int):
114
+ super().__init__()
115
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
116
+
117
+ def forward(self, x: Tensor):
118
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
119
+ x = self.conv(x)
120
+ return x
121
+
122
+
123
+ class Encoder(nn.Module):
124
+ def __init__(
125
+ self,
126
+ resolution: int,
127
+ in_channels: int,
128
+ ch: int,
129
+ ch_mult: list[int],
130
+ num_res_blocks: int,
131
+ z_channels: int,
132
+ ):
133
+ super().__init__()
134
+ self.ch = ch
135
+ self.num_resolutions = len(ch_mult)
136
+ self.num_res_blocks = num_res_blocks
137
+ self.resolution = resolution
138
+ self.in_channels = in_channels
139
+ # downsampling
140
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
141
+
142
+ curr_res = resolution
143
+ in_ch_mult = (1,) + tuple(ch_mult)
144
+ self.in_ch_mult = in_ch_mult
145
+ self.down = nn.ModuleList()
146
+ block_in = self.ch
147
+ for i_level in range(self.num_resolutions):
148
+ block = nn.ModuleList()
149
+ attn = nn.ModuleList()
150
+ block_in = ch * in_ch_mult[i_level]
151
+ block_out = ch * ch_mult[i_level]
152
+ for _ in range(self.num_res_blocks):
153
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
154
+ block_in = block_out
155
+ down = nn.Module()
156
+ down.block = block
157
+ down.attn = attn
158
+ if i_level != self.num_resolutions - 1:
159
+ down.downsample = Downsample(block_in)
160
+ curr_res = curr_res // 2
161
+ self.down.append(down)
162
+
163
+ # middle
164
+ self.mid = nn.Module()
165
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
166
+ self.mid.attn_1 = AttnBlock(block_in)
167
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
168
+
169
+ # end
170
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
171
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
172
+
173
+ def forward(self, x: Tensor) -> Tensor:
174
+ # downsampling
175
+ hs = [self.conv_in(x)]
176
+ for i_level in range(self.num_resolutions):
177
+ for i_block in range(self.num_res_blocks):
178
+ h = self.down[i_level].block[i_block](hs[-1])
179
+ if len(self.down[i_level].attn) > 0:
180
+ h = self.down[i_level].attn[i_block](h)
181
+ hs.append(h)
182
+ if i_level != self.num_resolutions - 1:
183
+ hs.append(self.down[i_level].downsample(hs[-1]))
184
+
185
+ # middle
186
+ h = hs[-1]
187
+ h = self.mid.block_1(h)
188
+ h = self.mid.attn_1(h)
189
+ h = self.mid.block_2(h)
190
+ # end
191
+ h = self.norm_out(h)
192
+ h = swish(h)
193
+ h = self.conv_out(h)
194
+ return h
195
+
196
+
197
+ class Decoder(nn.Module):
198
+ def __init__(
199
+ self,
200
+ ch: int,
201
+ out_ch: int,
202
+ ch_mult: list[int],
203
+ num_res_blocks: int,
204
+ in_channels: int,
205
+ resolution: int,
206
+ z_channels: int,
207
+ ):
208
+ super().__init__()
209
+ self.ch = ch
210
+ self.num_resolutions = len(ch_mult)
211
+ self.num_res_blocks = num_res_blocks
212
+ self.resolution = resolution
213
+ self.in_channels = in_channels
214
+ self.ffactor = 2 ** (self.num_resolutions - 1)
215
+
216
+ # compute in_ch_mult, block_in and curr_res at lowest res
217
+ block_in = ch * ch_mult[self.num_resolutions - 1]
218
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
219
+ self.z_shape = (1, z_channels, curr_res, curr_res)
220
+
221
+ # z to block_in
222
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
223
+
224
+ # middle
225
+ self.mid = nn.Module()
226
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
227
+ self.mid.attn_1 = AttnBlock(block_in)
228
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
229
+
230
+ # upsampling
231
+ self.up = nn.ModuleList()
232
+ for i_level in reversed(range(self.num_resolutions)):
233
+ block = nn.ModuleList()
234
+ attn = nn.ModuleList()
235
+ block_out = ch * ch_mult[i_level]
236
+ for _ in range(self.num_res_blocks + 1):
237
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
238
+ block_in = block_out
239
+ up = nn.Module()
240
+ up.block = block
241
+ up.attn = attn
242
+ if i_level != 0:
243
+ up.upsample = Upsample(block_in)
244
+ curr_res = curr_res * 2
245
+ self.up.insert(0, up) # prepend to get consistent order
246
+
247
+ # end
248
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
249
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
250
+
251
+ def forward(self, z: Tensor) -> Tensor:
252
+ # z to block_in
253
+ h = self.conv_in(z)
254
+
255
+ # middle
256
+ h = self.mid.block_1(h)
257
+ h = self.mid.attn_1(h)
258
+ h = self.mid.block_2(h)
259
+
260
+ # upsampling
261
+ for i_level in reversed(range(self.num_resolutions)):
262
+ for i_block in range(self.num_res_blocks + 1):
263
+ h = self.up[i_level].block[i_block](h)
264
+ if len(self.up[i_level].attn) > 0:
265
+ h = self.up[i_level].attn[i_block](h)
266
+ if i_level != 0:
267
+ h = self.up[i_level].upsample(h)
268
+
269
+ # end
270
+ h = self.norm_out(h)
271
+ h = swish(h)
272
+ h = self.conv_out(h)
273
+ return h
274
+
275
+
276
+ class DiagonalGaussian(nn.Module):
277
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
278
+ super().__init__()
279
+ self.sample = sample
280
+ self.chunk_dim = chunk_dim
281
+
282
+ def forward(self, z: Tensor) -> Tensor:
283
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
284
+ if self.sample:
285
+ std = torch.exp(0.5 * logvar)
286
+ return mean + std * torch.randn_like(mean)
287
+ else:
288
+ return mean
289
+
290
+
291
+ class AutoEncoder(nn.Module):
292
+ def __init__(self, params: AutoEncoderParams):
293
+ super().__init__()
294
+ self.encoder = Encoder(
295
+ resolution=params.resolution,
296
+ in_channels=params.in_channels,
297
+ ch=params.ch,
298
+ ch_mult=params.ch_mult,
299
+ num_res_blocks=params.num_res_blocks,
300
+ z_channels=params.z_channels,
301
+ )
302
+ self.decoder = Decoder(
303
+ resolution=params.resolution,
304
+ in_channels=params.in_channels,
305
+ ch=params.ch,
306
+ out_ch=params.out_ch,
307
+ ch_mult=params.ch_mult,
308
+ num_res_blocks=params.num_res_blocks,
309
+ z_channels=params.z_channels,
310
+ )
311
+ self.reg = DiagonalGaussian()
312
+
313
+ self.scale_factor = params.scale_factor
314
+ self.shift_factor = params.shift_factor
315
+
316
+ def encode(self, x: Tensor) -> Tensor:
317
+ z = self.reg(self.encoder(x))
318
+ z = self.scale_factor * (z - self.shift_factor)
319
+ return z
320
+
321
+ def decode(self, z: Tensor) -> Tensor:
322
+ z = z / self.scale_factor + self.shift_factor
323
+ return self.decoder(z)
324
+
325
+ def forward(self, x: Tensor) -> Tensor:
326
+ return self.decode(self.encode(x))
327
+
328
+
329
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
330
+ if len(missing) > 0 and len(unexpected) > 0:
331
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
332
+ print("\n" + "-" * 79 + "\n")
333
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
334
+ elif len(missing) > 0:
335
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
336
+ elif len(unexpected) > 0:
337
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
338
+
339
+
340
+ def load_ae(local_path: str) -> tuple[AutoEncoder, AutoEncoderParams]:
341
+ ae_params = AutoEncoderParams(
342
+ resolution=256,
343
+ in_channels=3,
344
+ downsample=8,
345
+ ch=128,
346
+ out_ch=3,
347
+ ch_mult=[1, 2, 4, 4],
348
+ num_res_blocks=2,
349
+ z_channels=16,
350
+ scale_factor=0.3611,
351
+ shift_factor=0.1159,
352
+ )
353
+
354
+ # Loading the autoencoder
355
+ ae = AutoEncoder(ae_params)
356
+
357
+ if local_path is not None:
358
+ sd = load_sft(local_path)
359
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
360
+ print_load_warning(missing, unexpected)
361
+ return ae, ae_params
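
A quick usage sketch of the autoencoder hunk above (not part of the commit): it assumes the ae.safetensors weights that app.py downloads to /model, that this hunk is modeling/autoencoder.py, and an input whose spatial size is a multiple of the 8x downsample factor implied by ch_mult=[1, 2, 4, 4].

import torch
from modeling.autoencoder import load_ae

# load_ae returns both the module and its AutoEncoderParams
ae, ae_params = load_ae(local_path="/model/ae.safetensors")
ae.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 512, 512)   # stand-in for a normalized RGB image
    z = ae.encode(x)                  # (1, 16, 64, 64): z_channels=16, 512/8=64
    x_rec = ae.decode(z)              # back to (1, 3, 512, 512)

The scale_factor/shift_factor applied in encode() are undone in decode(), so the round trip only loses what the sampled DiagonalGaussian bottleneck discards.
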
modeling/bagel/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ from .bagel import BagelConfig, Bagel
6
+ from .qwen2_navit import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
7
+ from .siglip_navit import SiglipVisionConfig, SiglipVisionModel
8
+
9
+
10
+ __all__ = [
11
+ 'BagelConfig',
12
+ 'Bagel',
13
+ 'Qwen2Config',
14
+ 'Qwen2Model',
15
+ 'Qwen2ForCausalLM',
16
+ 'SiglipVisionConfig',
17
+ 'SiglipVisionModel',
18
+ ]
modeling/bagel/bagel.py ADDED
@@ -0,0 +1,1026 @@
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import copy
5
+ from typing import List, Tuple, Optional
6
+ import matplotlib.pyplot as plt
7
+
8
+ from PIL import Image
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from torch.nn.attention.flex_attention import create_block_mask
13
+ from transformers.configuration_utils import PretrainedConfig
14
+ from transformers.modeling_utils import PreTrainedModel
15
+
16
+ from data.data_utils import (
17
+ create_sparse_mask,
18
+ get_flattened_position_ids_extrapolate,
19
+ get_flattened_position_ids_interpolate,
20
+ patchify,
21
+ )
22
+ from .qwen2_navit import NaiveCache
23
+ from .modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
24
+
25
+
26
+ class BagelConfig(PretrainedConfig):
27
+ def __init__(
28
+ self,
29
+ visual_gen=True,
30
+ visual_und=True,
31
+ llm_config=None,
32
+ vit_config=None,
33
+ vae_config=None,
34
+ latent_patch_size=2,
35
+ max_latent_size=32,
36
+ vit_max_num_patch_per_side=70,
37
+ connector_act="gelu_pytorch_tanh",
38
+ interpolate_pos=False,
39
+ timestep_shift=1.0,
40
+ **kwargs
41
+ ):
42
+ super().__init__(**kwargs)
43
+ self.visual_gen = visual_gen
44
+ self.visual_und = visual_und
45
+ self.llm_config = llm_config
46
+ self.vit_config = vit_config
47
+ self.vae_config = vae_config
48
+ self.latent_patch_size = latent_patch_size
49
+ self.max_latent_size = max_latent_size
50
+ self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
51
+ self.connector_act = connector_act
52
+ self.interpolate_pos = interpolate_pos
53
+ self.timestep_shift = timestep_shift
54
+
55
+
56
+ class Bagel(PreTrainedModel):
57
+ config_class = BagelConfig
58
+ base_model_prefix = 'bagel'
59
+
60
+ def __init__(self, language_model, vit_model, config: BagelConfig):
61
+ super().__init__(config)
62
+ self.language_model = language_model
63
+ self.hidden_size = config.llm_config.hidden_size
64
+ self.use_moe = "Mo" in config.llm_config.layer_module
65
+ self.num_heads = config.llm_config.num_attention_heads
66
+
67
+ if config.visual_gen:
68
+ self.latent_patch_size = config.latent_patch_size
69
+ self.timestep_shift = config.timestep_shift
70
+ self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
71
+ self.max_latent_size = config.max_latent_size
72
+ self.latent_channel = config.vae_config.z_channels
73
+ self.patch_latent_dim = self.latent_patch_size ** 2 * self.latent_channel
74
+ self.time_embedder = TimestepEmbedder(self.hidden_size)
75
+ self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
76
+ self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
77
+ self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)
78
+
79
+ if config.visual_und:
80
+ self.vit_model = vit_model
81
+ self.vit_patch_size = config.vit_config.patch_size
82
+ self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
83
+ self.vit_hidden_size = config.vit_config.hidden_size
84
+ self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
85
+ self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)
86
+
87
+ if config.interpolate_pos:
88
+ self.get_flattened_position_ids = get_flattened_position_ids_interpolate
89
+ else:
90
+ self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
91
+
92
+ self.config = config
93
+ self._init_weights()
94
+
95
+ def _init_weights(self):
96
+ if self.config.visual_gen:
97
+ nn.init.constant_(self.llm2vae.weight, 0)
98
+ nn.init.constant_(self.llm2vae.bias, 0)
99
+
100
+ def forward(
101
+ self,
102
+ sequence_length: int,
103
+ packed_text_ids: torch.LongTensor,
104
+ packed_text_indexes: torch.LongTensor,
105
+ sample_lens: List[int],
106
+ packed_position_ids: torch.LongTensor,
107
+ nested_attention_masks: List[torch.Tensor] = None,
108
+ split_lens: List[int] = None,
109
+ attn_modes: List[str] = None,
110
+ # for visual understanding
111
+ ce_loss_indexes: Optional[torch.BoolTensor] = None,
112
+ packed_label_ids: Optional[torch.LongTensor] = None,
113
+ packed_vit_tokens: Optional[torch.Tensor] = None,
114
+ packed_vit_token_indexes: Optional[torch.LongTensor] = None,
115
+ packed_vit_position_ids: Optional[torch.LongTensor] = None,
116
+ vit_token_seqlens: Optional[torch.IntTensor] = None,
117
+ # for visual generation
118
+ padded_latent: Optional[torch.Tensor] = None,
119
+ patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
120
+ packed_latent_position_ids: Optional[torch.LongTensor] = None,
121
+ packed_vae_token_indexes: Optional[torch.LongTensor] = None,
122
+ packed_timesteps: Optional[torch.Tensor] = None,
123
+ mse_loss_indexes: Optional[torch.BoolTensor] = None,
124
+ ) -> torch.Tensor:
125
+ """
126
+ Args:
127
+ sequence_length: total length of the packed sequence.
128
+ packed_text_ids: 1-D int tensor, packed text token ids.
129
+ packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
130
+ sample_lens: A list of N ints, length of each sample in packed_sequence.
131
+ nested_attention_masks: A list of N 2-D float tensors, where 0.0 means attend and
132
+ -inf means ignore.
133
+ packed_position_ids: packed 1-D positions, an image has only one global position shared
134
+ by all latent tokens.
135
+
136
+ packed_vit_tokens: packed patchified image tokens for vit model.
137
+ packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
138
+ packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
139
+ vit_token_seqlens: 1-D int tensor, the number of tokens of each image for the vit model.
140
+ packed_label_ids: 1-D int tensor, packed label token ids.
141
+ ce_loss_indexes: 1-D bool tensor, where to compute ce loss.
142
+
143
+ padded_latent: padded latent from VAE encoder.
144
+ patchified_vae_latent_shapes: A list of (h, w) tuples, patchified latent shapes of each image.
145
+ packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
146
+ packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence.
147
+ packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
148
+ mse_loss_indexes: 1-D bool tensor, where to compute mse loss.
149
+ """
150
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
151
+ packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size))
152
+ packed_sequence[packed_text_indexes] = packed_text_embedding
153
+
154
+ if nested_attention_masks is None:
155
+ sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, packed_text_embedding.device)
156
+ seqlen = sum(sample_lens)
157
+ block_mask = create_block_mask(
158
+ sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen,
159
+ device=packed_text_embedding.device, BLOCK_SIZE=128, _compile=True
160
+ )
161
+ attention_mask = block_mask
162
+ else:
163
+ attention_mask = nested_attention_masks
164
+
165
+ if self.config.visual_und:
166
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
167
+ cu_seqlens = cu_seqlens.to(torch.int32)
168
+ max_seqlen = torch.max(vit_token_seqlens).item()
169
+ packed_vit_token_embed = self.vit_model(
170
+ packed_pixel_values=packed_vit_tokens,
171
+ packed_flattened_position_ids=packed_vit_position_ids,
172
+ cu_seqlens=cu_seqlens,
173
+ max_seqlen=max_seqlen,
174
+ )
175
+ packed_vit_token_embed = self.connector(packed_vit_token_embed)
176
+ vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
177
+ packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
178
+ packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
179
+
180
+ if self.config.visual_gen:
181
+ p = self.latent_patch_size
182
+ packed_latent = []
183
+ for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
184
+ latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
185
+ latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
186
+ packed_latent.append(latent)
187
+ packed_latent_clean = torch.cat(packed_latent, dim=0)
188
+
189
+ noise = torch.randn_like(packed_latent_clean)
190
+ packed_timesteps = torch.sigmoid(packed_timesteps)
191
+ packed_timesteps = self.timestep_shift * packed_timesteps / (1 + (self.timestep_shift - 1) * packed_timesteps)
192
+ packed_latent = (1 - packed_timesteps[:, None]) * packed_latent_clean + packed_timesteps[:, None] * noise
193
+ packed_timestep_embeds = self.time_embedder(packed_timesteps)
194
+ latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
195
+ packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb
196
+ packed_sequence[packed_vae_token_indexes] = packed_latent
197
+
198
+ extra_inputs = {}
199
+ if self.use_moe:
200
+ packed_und_token_indexes = packed_text_indexes
201
+ if packed_vit_token_indexes is not None:
202
+ packed_und_token_indexes = torch.cat([packed_text_indexes, packed_vit_token_indexes], dim=0)
203
+ extra_inputs.update(
204
+ packed_und_token_indexes=packed_und_token_indexes,
205
+ packed_gen_token_indexes=packed_vae_token_indexes,
206
+ )
207
+
208
+ last_hidden_state = self.language_model(
209
+ packed_sequence=packed_sequence,
210
+ sample_lens=sample_lens,
211
+ attention_mask=attention_mask,
212
+ packed_position_ids=packed_position_ids,
213
+ **extra_inputs,
214
+ )
215
+
216
+ mse = None
217
+ if self.config.visual_gen:
218
+ packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
219
+ target = noise - packed_latent_clean # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
220
+ has_mse = packed_timesteps > 0
221
+ mse = (packed_mse_preds - target[has_mse]) ** 2
222
+
223
+ ce = None
224
+ if ce_loss_indexes is not None:
225
+ packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes])
226
+ ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")
227
+
228
+ return dict(mse=mse, ce=ce)
229
+
230
+
231
+ def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
232
+ packed_text_ids = list()
233
+ packed_text_position_ids = list()
234
+ text_token_lens = list()
235
+ packed_text_indexes = list()
236
+ packed_key_value_indexes = list()
237
+
238
+ curr = 0
239
+ newlens, new_rope = list(), list()
240
+ for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope):
241
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
242
+ curr += curr_kvlen
243
+
244
+ text_ids = tokenizer.encode(prompt)
245
+ text_ids = [new_token_ids['bos_token_id']] + text_ids + [new_token_ids['eos_token_id']]
246
+ text_token_lens.append(len(text_ids))
247
+ packed_text_ids.extend(text_ids)
248
+ packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids)))
249
+ packed_text_indexes.extend(range(curr, curr + len(text_ids)))
250
+ newlens.append(curr_kvlen + len(text_ids))
251
+ new_rope.append(curr_position_id + len(text_ids))
252
+ curr += len(text_ids)
253
+
254
+ generation_input = {
255
+ "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
256
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
257
+ "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long),
258
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
259
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
260
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
261
+ }
262
+
263
+ return generation_input, newlens, new_rope
264
+
265
+ @torch.no_grad
266
+ def forward_cache_update_text(
267
+ self,
268
+ past_key_values: NaiveCache,
269
+ packed_text_ids: torch.IntTensor,
270
+ packed_text_position_ids: torch.LongTensor,
271
+ text_token_lens: torch.LongTensor,
272
+ packed_text_indexes: torch.LongTensor,
273
+ packed_key_value_indexes: torch.LongTensor,
274
+ key_values_lens: torch.IntTensor,
275
+ ):
276
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
277
+
278
+ extra_inputs = {}
279
+ if self.use_moe:
280
+ extra_inputs = {"mode": "und"}
281
+
282
+ output = self.language_model.forward_inference(
283
+ packed_query_sequence=packed_text_embedding,
284
+ query_lens=text_token_lens,
285
+ packed_query_position_ids=packed_text_position_ids,
286
+ packed_query_indexes=packed_text_indexes,
287
+ past_key_values=past_key_values,
288
+ packed_key_value_indexes=packed_key_value_indexes,
289
+ key_values_lens=key_values_lens,
290
+ update_past_key_values=True,
291
+ is_causal=True,
292
+ **extra_inputs,
293
+ )
294
+ past_key_values = output.past_key_values
295
+
296
+ return past_key_values
297
+
298
+ def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
299
+ packed_vit_token_indexes = list()
300
+ vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list()
301
+ packed_text_ids, packed_text_indexes = list(), list()
302
+ packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
303
+ packed_key_value_indexes = list()
304
+
305
+ _curr = curr = 0
306
+ newlens, new_rope = list(), list()
307
+ for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
308
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
309
+ curr += curr_kvlen
310
+
311
+ packed_text_ids.append(new_token_ids['start_of_image'])
312
+ packed_text_indexes.append(_curr)
313
+ packed_indexes.append(curr)
314
+ curr += 1
315
+ _curr += 1
316
+
317
+ image_tensor = transforms(image)
318
+ vit_position_ids = self.get_flattened_position_ids(
319
+ image_tensor.size(1), image_tensor.size(2),
320
+ self.vit_patch_size,
321
+ max_num_patches_per_side=self.vit_max_num_patch_per_side
322
+ )
323
+ vit_tokens = patchify(image_tensor, self.vit_patch_size)
324
+ packed_vit_tokens.append(vit_tokens)
325
+ num_img_tokens = vit_tokens.shape[0]
326
+ packed_vit_position_ids.append(vit_position_ids)
327
+ vit_token_seqlens.append(num_img_tokens)
328
+ packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
329
+ packed_indexes.extend(range(curr, curr + num_img_tokens))
330
+ curr += num_img_tokens
331
+ _curr += num_img_tokens
332
+
333
+ packed_text_ids.append(new_token_ids['end_of_image'])
334
+ packed_text_indexes.append(_curr)
335
+ packed_indexes.append(curr)
336
+ curr += 1
337
+ _curr += 1
338
+
339
+ packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
340
+ packed_seqlens.append(num_img_tokens + 2)
341
+ newlens.append(curr_kvlen + num_img_tokens + 2)
342
+ new_rope.append(curr_position_id + 1)
343
+
344
+ generation_input = {
345
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
346
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
347
+ "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
348
+ "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
349
+ "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
350
+ "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
351
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
352
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
353
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
354
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
355
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
356
+ }
357
+
358
+ return generation_input, newlens, new_rope
359
+
360
+ @torch.no_grad
361
+ def forward_cache_update_vit(
362
+ self,
363
+ past_key_values: NaiveCache,
364
+ packed_text_ids: torch.LongTensor,
365
+ packed_text_indexes: torch.LongTensor,
366
+ packed_vit_tokens: torch.Tensor,
367
+ packed_vit_token_indexes: torch.LongTensor,
368
+ packed_vit_position_ids: torch.LongTensor,
369
+ vit_token_seqlens: torch.IntTensor,
370
+ packed_position_ids: torch.LongTensor,
371
+ packed_seqlens: torch.IntTensor,
372
+ packed_indexes: torch.LongTensor,
373
+ packed_key_value_indexes: torch.LongTensor,
374
+ key_values_lens: torch.IntTensor,
375
+ ):
376
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
377
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
378
+ packed_sequence[packed_text_indexes] = packed_text_embedding
379
+
380
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
381
+ cu_seqlens = cu_seqlens.to(torch.int32)
382
+ max_seqlen = torch.max(vit_token_seqlens).item()
383
+ packed_vit_token_embed = self.vit_model(
384
+ packed_pixel_values=packed_vit_tokens,
385
+ packed_flattened_position_ids=packed_vit_position_ids,
386
+ cu_seqlens=cu_seqlens,
387
+ max_seqlen=max_seqlen,
388
+ )
389
+ packed_vit_token_embed = self.connector(packed_vit_token_embed)
390
+ pos_emb = self.vit_pos_embed(packed_vit_position_ids)
391
+ packed_vit_token_embed = packed_vit_token_embed + pos_emb
392
+ packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
393
+
394
+ extra_inputs = {}
395
+ if self.use_moe:
396
+ extra_inputs = {"mode": "und"}
397
+
398
+ output = self.language_model.forward_inference(
399
+ packed_query_sequence=packed_sequence,
400
+ query_lens=packed_seqlens,
401
+ packed_query_position_ids=packed_position_ids,
402
+ packed_query_indexes=packed_indexes,
403
+ past_key_values=past_key_values,
404
+ packed_key_value_indexes=packed_key_value_indexes,
405
+ key_values_lens=key_values_lens,
406
+ update_past_key_values=True,
407
+ is_causal=False,
408
+ **extra_inputs,
409
+ )
410
+ past_key_values = output.past_key_values
411
+
412
+ return past_key_values
413
+
414
+ def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
415
+ patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
416
+ packed_vae_token_indexes = list()
417
+ packed_text_ids, packed_text_indexes = list(), list()
418
+ packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
419
+ packed_key_value_indexes = list()
420
+
421
+ _curr = curr = 0
422
+ vae_image_tensors = list()
423
+ newlens, new_rope = list(), list()
424
+ for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
425
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
426
+ curr += curr_kvlen
427
+
428
+ packed_text_ids.append(new_token_ids['start_of_image'])
429
+ packed_text_indexes.append(_curr)
430
+ packed_indexes.append(curr)
431
+ curr += 1
432
+ _curr += 1
433
+
434
+ image_tensor = transforms(image)
435
+ vae_image_tensors.append(image_tensor)
436
+ vae_position_ids = self.get_flattened_position_ids(
437
+ image_tensor.size(1), image_tensor.size(2),
438
+ self.latent_downsample,
439
+ max_num_patches_per_side=self.max_latent_size
440
+ )
441
+ packed_vae_position_ids.append(vae_position_ids)
442
+ H, W = image_tensor.shape[1:]
443
+ h = H // self.latent_downsample
444
+ w = W // self.latent_downsample
445
+ patchified_vae_latent_shapes.append((h, w))
446
+
447
+ num_img_tokens = w * h
448
+ packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
449
+ packed_indexes.extend(range(curr, curr + num_img_tokens))
450
+ curr += num_img_tokens
451
+ _curr += num_img_tokens
452
+
453
+ packed_text_ids.append(new_token_ids['end_of_image'])
454
+ packed_text_indexes.append(_curr)
455
+ packed_indexes.append(curr)
456
+ curr += 1
457
+ _curr += 1
458
+
459
+ packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
460
+ packed_seqlens.append(num_img_tokens + 2)
461
+ newlens.append(curr_kvlen + num_img_tokens + 2)
462
+ new_rope.append(curr_position_id + 1)
463
+
464
+ image_sizes = [item.shape for item in vae_image_tensors]
465
+ max_image_size = [max(item) for item in list(zip(*image_sizes))]
466
+ padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
467
+ for i, image_tensor in enumerate(vae_image_tensors):
468
+ padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
469
+
470
+ generation_input = {
471
+ "padded_images": padded_images,
472
+ "patchified_vae_latent_shapes": patchified_vae_latent_shapes,
473
+ "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
474
+ "packed_timesteps": torch.tensor([timestep]),
475
+ "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
476
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
477
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
478
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
479
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
480
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
481
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
482
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
483
+ }
484
+
485
+ return generation_input, newlens, new_rope
486
+
487
+ @torch.no_grad
488
+ def forward_cache_update_vae(
489
+ self,
490
+ vae_model,
491
+ past_key_values: NaiveCache,
492
+ padded_images: torch.Tensor,
493
+ patchified_vae_latent_shapes: List,
494
+ packed_vae_position_ids: torch.LongTensor,
495
+ packed_timesteps: torch.Tensor,
496
+ packed_vae_token_indexes: torch.LongTensor,
497
+ packed_text_ids: torch.LongTensor,
498
+ packed_text_indexes: torch.LongTensor,
499
+ packed_position_ids: torch.LongTensor,
500
+ packed_seqlens: torch.IntTensor,
501
+ packed_indexes: torch.LongTensor,
502
+ key_values_lens: torch.IntTensor,
503
+ packed_key_value_indexes: torch.Tensor,
504
+ ):
505
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
506
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
507
+ packed_sequence[packed_text_indexes] = packed_text_embedding
508
+
509
+ padded_latent = vae_model.encode(padded_images)
510
+
511
+ p = self.latent_patch_size
512
+ packed_latent = list()
513
+ for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
514
+ latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
515
+ latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
516
+ packed_latent.append(latent)
517
+ packed_latent = torch.cat(packed_latent, dim=0)
518
+ packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
519
+ packed_timestep_embeds = self.time_embedder(packed_timesteps)
520
+ packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
521
+ packed_sequence[packed_vae_token_indexes] = packed_latent
522
+
523
+ extra_inputs = {}
524
+ if self.use_moe:
525
+ extra_inputs = {
526
+ "mode": "gen",
527
+ "packed_vae_token_indexes": packed_vae_token_indexes,
528
+ "packed_text_indexes": packed_text_indexes
529
+ }
530
+
531
+ output = self.language_model.forward_inference(
532
+ packed_query_sequence=packed_sequence,
533
+ query_lens=packed_seqlens,
534
+ packed_query_position_ids=packed_position_ids,
535
+ packed_query_indexes=packed_indexes,
536
+ past_key_values=past_key_values,
537
+ key_values_lens=key_values_lens,
538
+ packed_key_value_indexes=packed_key_value_indexes,
539
+ update_past_key_values=True,
540
+ is_causal=False,
541
+ **extra_inputs,
542
+ )
543
+ past_key_values = output.past_key_values
544
+
545
+ return past_key_values
546
+
547
+ def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
548
+ packed_text_ids, packed_text_indexes = list(), list()
549
+ packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list()
550
+ packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list()
551
+ packed_key_value_indexes = list()
552
+
553
+ query_curr = curr = 0
554
+ for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
555
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
556
+ curr += curr_kvlen
557
+
558
+ packed_text_ids.append(new_token_ids['start_of_image'])
559
+ packed_text_indexes.append(query_curr)
560
+ packed_indexes.append(curr)
561
+ curr += 1
562
+ query_curr += 1
563
+
564
+ vae_position_ids = self.get_flattened_position_ids(
565
+ H, W,
566
+ self.latent_downsample,
567
+ max_num_patches_per_side=self.max_latent_size
568
+ )
569
+ packed_vae_position_ids.append(vae_position_ids)
570
+
571
+ h, w = H // self.latent_downsample, W // self.latent_downsample
572
+ num_image_tokens = h * w
573
+ packed_init_noises.append(
574
+ torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size ** 2)
575
+ )
576
+ packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens))
577
+ packed_indexes.extend(range(curr, curr + num_image_tokens))
578
+ curr += num_image_tokens
579
+ query_curr += num_image_tokens
580
+
581
+ packed_text_ids.append(new_token_ids['end_of_image'])
582
+ packed_text_indexes.append(query_curr)
583
+ packed_indexes.append(curr)
584
+ curr += 1
585
+ query_curr += 1
586
+
587
+ packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
588
+ packed_seqlens.append(num_image_tokens + 2)
589
+
590
+ generation_input = {
591
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
592
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
593
+ "packed_init_noises": torch.cat(packed_init_noises, dim=0),
594
+ "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
595
+ "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
596
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
597
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
598
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
599
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
600
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
601
+ }
602
+
603
+ return generation_input
604
+
605
+ def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
606
+ packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list()
607
+
608
+ query_curr = curr = 0
609
+ for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
610
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
611
+ curr += curr_kvlen
612
+
613
+ packed_indexes.append(curr)
614
+ curr += 1
615
+ query_curr += 1
616
+
617
+ h, w = H // self.latent_downsample, W // self.latent_downsample
618
+ num_image_tokens = h * w
619
+ packed_indexes.extend(range(curr, curr + num_image_tokens))
620
+ curr += num_image_tokens
621
+ query_curr += num_image_tokens
622
+
623
+ packed_indexes.append(curr)
624
+ curr += 1
625
+ query_curr += 1
626
+
627
+ packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
628
+
629
+ generation_input = {
630
+ "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
631
+ "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
632
+ "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
633
+ "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
634
+ }
635
+
636
+ return generation_input
637
+
638
+ @torch.no_grad
639
+ def generate_image(
640
+ self,
641
+ packed_text_ids: torch.LongTensor,
642
+ packed_text_indexes: torch.LongTensor,
643
+ packed_init_noises: torch.Tensor,
644
+ packed_vae_position_ids: torch.LongTensor,
645
+ packed_vae_token_indexes: torch.LongTensor,
646
+ packed_seqlens: torch.IntTensor,
647
+ packed_position_ids: torch.LongTensor,
648
+ packed_indexes: torch.LongTensor,
649
+ past_key_values: NaiveCache,
650
+ key_values_lens: torch.IntTensor,
651
+ packed_key_value_indexes: torch.LongTensor,
652
+ num_timesteps: int = 24,
653
+ timestep_shift: float = 1.0,
654
+ cfg_renorm_min: float = 0.0,
655
+ cfg_renorm_type: str = "global",
656
+ cfg_interval: Tuple[float, float] = (0.0, 1.0),
657
+ # cfg_text
658
+ cfg_text_scale: float = 1.0,
659
+ cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
660
+ cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
661
+ cfg_text_past_key_values: Optional[NaiveCache] = None,
662
+ cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
663
+ cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
664
+ # cfg_img
665
+ cfg_img_scale: float = 1.0,
666
+ cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
667
+ cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
668
+ cfg_img_past_key_values: Optional[NaiveCache] = None,
669
+ cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
670
+ cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
671
+ cfg_type: str = "parallel",
672
+ ):
673
+ x_t = packed_init_noises
674
+
675
+ timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
676
+ timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
677
+ dts = timesteps[:-1] - timesteps[1:]
678
+ timesteps = timesteps[:-1]
679
+
680
+ for i, t in enumerate(timesteps):
681
+
682
+ timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
683
+ if t > cfg_interval[0] and t <= cfg_interval[1]:
684
+ cfg_text_scale_ = cfg_text_scale
685
+ cfg_img_scale_ = cfg_img_scale
686
+ else:
687
+ cfg_text_scale_ = 1.0
688
+ cfg_img_scale_ = 1.0
689
+ v_t = self._forward_flow(
690
+ x_t=x_t,
691
+ timestep=timestep,
692
+ packed_vae_token_indexes=packed_vae_token_indexes,
693
+ packed_vae_position_ids=packed_vae_position_ids,
694
+ packed_text_ids=packed_text_ids,
695
+ packed_text_indexes=packed_text_indexes,
696
+ packed_position_ids=packed_position_ids,
697
+ packed_indexes=packed_indexes,
698
+ packed_seqlens=packed_seqlens,
699
+ key_values_lens=key_values_lens,
700
+ past_key_values=past_key_values,
701
+ packed_key_value_indexes=packed_key_value_indexes,
702
+ cfg_renorm_min=cfg_renorm_min,
703
+ cfg_renorm_type=cfg_renorm_type,
704
+ # cfg_text
705
+ cfg_text_scale=cfg_text_scale_,
706
+ cfg_text_packed_position_ids=cfg_text_packed_position_ids,
707
+ cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
708
+ cfg_text_key_values_lens=cfg_text_key_values_lens,
709
+ cfg_text_past_key_values=cfg_text_past_key_values,
710
+ cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
711
+ # cfg_img
712
+ cfg_img_scale=cfg_img_scale_,
713
+ cfg_img_packed_position_ids=cfg_img_packed_position_ids,
714
+ cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
715
+ cfg_img_key_values_lens=cfg_img_key_values_lens,
716
+ cfg_img_past_key_values=cfg_img_past_key_values,
717
+ cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
718
+ cfg_type=cfg_type,
719
+ )
720
+
721
+ x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise
722
+
723
+ unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
724
+ return unpacked_latent
725
+
726
+ @torch.no_grad
727
+ def _forward_flow(
728
+ self,
729
+ x_t: torch.Tensor,
730
+ timestep: torch.Tensor,
731
+ packed_vae_token_indexes: torch.LongTensor,
732
+ packed_vae_position_ids: torch.LongTensor,
733
+ packed_text_ids: torch.LongTensor,
734
+ packed_text_indexes: torch.LongTensor,
735
+ packed_indexes: torch.LongTensor,
736
+ packed_position_ids: torch.LongTensor,
737
+ packed_seqlens: torch.IntTensor,
738
+ key_values_lens: torch.IntTensor,
739
+ past_key_values: NaiveCache,
740
+ packed_key_value_indexes: torch.LongTensor,
741
+ cfg_renorm_min: float = 0.0,
742
+ cfg_renorm_type: str = "global",
743
+ # cfg_text
744
+ cfg_text_scale: float = 1.0,
745
+ cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
746
+ cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
747
+ cfg_text_key_values_lens: Optional[torch.Tensor] = None,
748
+ cfg_text_past_key_values: Optional[NaiveCache] = None,
749
+ cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
750
+ # cfg_img
751
+ cfg_img_scale: float = 1.0,
752
+ cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
753
+ cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
754
+ cfg_img_key_values_lens: Optional[torch.Tensor] = None,
755
+ cfg_img_past_key_values: Optional[NaiveCache] = None,
756
+ cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
757
+ cfg_type: str = "parallel",
758
+ ):
759
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
760
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
761
+ packed_sequence[packed_text_indexes] = packed_text_embedding
762
+
763
+ assert timestep.unique().shape[0] == 1
764
+ packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
765
+ packed_timestep_embeds = self.time_embedder(timestep)
766
+ x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
767
+ packed_sequence[packed_vae_token_indexes] = x_t
768
+
769
+ extra_inputs = {}
770
+ if self.use_moe:
771
+ extra_inputs = {
772
+ "mode": "gen",
773
+ "packed_vae_token_indexes": packed_vae_token_indexes,
774
+ "packed_text_indexes": packed_text_indexes
775
+ }
776
+
777
+ output = self.language_model.forward_inference(
778
+ packed_query_sequence=packed_sequence,
779
+ query_lens=packed_seqlens,
780
+ packed_query_position_ids=packed_position_ids,
781
+ packed_query_indexes=packed_indexes,
782
+ past_key_values=past_key_values,
783
+ key_values_lens=key_values_lens,
784
+ packed_key_value_indexes=packed_key_value_indexes,
785
+ update_past_key_values=False,
786
+ is_causal=False,
787
+ **extra_inputs,
788
+ )
789
+ v_t = self.llm2vae(output.packed_query_sequence)
790
+ v_t = v_t[packed_vae_token_indexes]
791
+
792
+ if cfg_text_scale > 1.0:
793
+ cfg_text_output = self.language_model.forward_inference(
794
+ packed_query_sequence=packed_sequence,
795
+ query_lens=packed_seqlens,
796
+ packed_query_position_ids=cfg_text_packed_position_ids,
797
+ packed_query_indexes=cfg_text_packed_query_indexes,
798
+ past_key_values=cfg_text_past_key_values,
799
+ key_values_lens=cfg_text_key_values_lens,
800
+ packed_key_value_indexes=cfg_text_packed_key_value_indexes,
801
+ update_past_key_values=False,
802
+ is_causal=False,
803
+ **extra_inputs,
804
+ )
805
+ cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
806
+ cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]
807
+
808
+ if cfg_img_scale > 1.0:
809
+ cfg_img_output = self.language_model.forward_inference(
810
+ packed_query_sequence=packed_sequence,
811
+ query_lens=packed_seqlens,
812
+ packed_query_position_ids=cfg_img_packed_position_ids,
813
+ packed_query_indexes=cfg_img_packed_query_indexes,
814
+ past_key_values=cfg_img_past_key_values,
815
+ key_values_lens=cfg_img_key_values_lens,
816
+ packed_key_value_indexes=cfg_img_packed_key_value_indexes,
817
+ update_past_key_values=False,
818
+ is_causal=False,
819
+ **extra_inputs,
820
+ )
821
+ cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
822
+ cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]
823
+
824
+ if cfg_text_scale > 1.0:
825
+ if cfg_renorm_type == "text_channel":
826
+ v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
827
+ norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
828
+ norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
829
+ scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
830
+ v_t_text = v_t_text_ * scale
831
+ if cfg_img_scale > 1.0:
832
+ v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
833
+ else:
834
+ v_t = v_t_text
835
+ else:
836
+ v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
837
+
838
+ if cfg_img_scale > 1.0:
839
+ v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
840
+ else:
841
+ v_t_ = v_t_text_
842
+
843
+ # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
844
+ if cfg_renorm_type == "global":
845
+ norm_v_t = torch.norm(v_t)
846
+ norm_v_t_ = torch.norm(v_t_)
847
+ elif cfg_renorm_type == "channel":
848
+ norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
849
+ norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
850
+ else:
851
+ raise NotImplementedError(f"{cfg_renorm_type} is not supported")
852
+ scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
853
+ v_t = v_t_ * scale
854
+ else:
855
+ # No CFG
856
+ pass
857
+
858
+ return v_t
859
+
860
+ def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
861
+ packed_start_tokens, packed_key_value_indexes = list(), list()
862
+ packed_query_position_ids = list()
863
+
864
+ curr = 0
865
+ for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
866
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
867
+ packed_start_tokens.append(new_token_ids['bos_token_id'])
868
+ packed_query_position_ids.append(curr_position_id)
869
+ curr += curr_kvlen
870
+
871
+ generation_input = {
872
+ "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
873
+ "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long),
874
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
875
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
876
+ }
877
+
878
+ return generation_input
879
+
880
+ @torch.no_grad
881
+ def generate_text(
882
+ self,
883
+ past_key_values: NaiveCache,
884
+ packed_key_value_indexes: torch.LongTensor,
885
+ key_values_lens: torch.IntTensor,
886
+ packed_start_tokens: torch.LongTensor,
887
+ packed_query_position_ids: torch.LongTensor,
888
+ max_length: int,
889
+ do_sample: bool = False,
890
+ temperature: float = 1.0,
891
+ end_token_id: int = None,
892
+ ):
893
+ step = 0
894
+ generated_sequence = []
895
+ curr_tokens = packed_start_tokens
896
+ while step < max_length:
897
+ generated_sequence.append(curr_tokens)
898
+ packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
899
+ query_lens = torch.ones_like(curr_tokens)
900
+ packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
901
+ 0, len(key_values_lens),
902
+ device=key_values_lens.device,
903
+ dtype=key_values_lens.dtype
904
+ )
905
+
906
+ unpacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
907
+ for i in range(len(unpacked)):
908
+ unpacked[i] += i
909
+ packed_key_value_indexes = torch.cat(unpacked, dim=0)
910
+
911
+ extra_inputs = {}
912
+ if self.use_moe:
913
+ extra_inputs = {"mode": "und"}
914
+
915
+ output = self.language_model.forward_inference(
916
+ packed_query_sequence=packed_text_embedding,
917
+ query_lens=query_lens,
918
+ packed_query_position_ids=packed_query_position_ids,
919
+ packed_query_indexes=packed_query_indexes,
920
+ past_key_values=past_key_values,
921
+ key_values_lens=key_values_lens,
922
+ packed_key_value_indexes=packed_key_value_indexes,
923
+ update_past_key_values=True,
924
+ is_causal=True,
925
+ **extra_inputs,
926
+ )
927
+ past_key_values = output.past_key_values
928
+ packed_query_sequence = output.packed_query_sequence
929
+ pred_logits = self.language_model.lm_head(packed_query_sequence)
930
+
931
+ if do_sample:
932
+ probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
933
+ curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
934
+ else:
935
+ curr_tokens = torch.argmax(pred_logits, dim=-1)
936
+
937
+ unpacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
938
+ for i in range(len(unpacked)):
939
+ unpacked[i] = torch.cat(
940
+ [unpacked[i], torch.tensor([unpacked[i][-1] + 1], device=unpacked[i].device)], dim=0
941
+ )
942
+ packed_key_value_indexes = torch.cat(unpacked, dim=0)
943
+ key_values_lens = key_values_lens + 1
944
+ packed_query_position_ids = packed_query_position_ids + 1
945
+ step += 1
946
+
947
+ if end_token_id is not None and curr_tokens[0] == end_token_id: # only supports batch size 1
948
+ break
949
+
950
+ output_device = generated_sequence[0].device
951
+ return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)
952
+
953
+ # for evaluation
954
+ @torch.no_grad()
955
+ def chat(
956
+ self,
957
+ tokenizer,
958
+ new_token_ids,
959
+ image_transform,
960
+ images,
961
+ prompt,
962
+ max_length: int,
963
+ do_sample: bool = False,
964
+ temperature: float = 1.0,
965
+ ):
966
+ device = next(self.parameters()).device
967
+
968
+ if isinstance(new_token_ids, dict):
969
+ for k, v in new_token_ids.items():
970
+ if torch.is_tensor(v):
971
+ new_token_ids[k] = v.to(device)
972
+ elif torch.is_tensor(new_token_ids):
973
+ new_token_ids = new_token_ids.to(device)
974
+
975
+ # prefill
976
+ past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers)
977
+ newlens = [0]
978
+ new_rope = [0]
979
+
980
+ # add images
981
+ for image in images:
982
+ generation_input, newlens, new_rope = self.prepare_vit_images(
983
+ curr_kvlens=newlens,
984
+ curr_rope=new_rope,
985
+ images=[image],
986
+ transforms=image_transform,
987
+ new_token_ids=new_token_ids,
988
+ )
989
+ for k, v in generation_input.items():
990
+ if torch.is_tensor(v):
991
+ generation_input[k] = v.to(device)
992
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
993
+ past_key_values = self.forward_cache_update_vit(past_key_values, **generation_input)
994
+
995
+ # add text
996
+ generation_input, newlens, new_rope = self.prepare_prompts(
997
+ curr_kvlens=newlens,
998
+ curr_rope=new_rope,
999
+ prompts=[prompt],
1000
+ tokenizer=tokenizer,
1001
+ new_token_ids=new_token_ids,
1002
+ )
1003
+ for k, v in generation_input.items():
1004
+ if torch.is_tensor(v):
1005
+ generation_input[k] = v.to(device)
1006
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1007
+ past_key_values = self.forward_cache_update_text(past_key_values, **generation_input)
1008
+
1009
+ # decode
1010
+ generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids)
1011
+ for k, v in generation_input.items():
1012
+ if torch.is_tensor(v):
1013
+ generation_input[k] = v.to(device)
1014
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1015
+ generated_ids = self.generate_text(
1016
+ past_key_values=past_key_values,
1017
+ max_length=max_length,
1018
+ do_sample=do_sample,
1019
+ temperature=temperature,
1020
+ end_token_id=new_token_ids['eos_token_id'],
1021
+ **generation_input,
1022
+ )
1023
+ output = tokenizer.decode(generated_ids[:, 0])
1024
+ output = output.split('<|im_end|>')[0].split('<|im_start|>')[1]
1025
+
1026
+ return output
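
A small, self-contained sketch (not part of the commit) of the sampling loop that Bagel.generate_image implements: the shifted timestep grid and the Euler update x_t <- x_t - v_t * dt. A dummy velocity stands in for the text/image-conditioned transformer call in _forward_flow, and the token count and feature width below are illustrative assumptions.

import torch

def shifted_timesteps(num_timesteps: int = 24, timestep_shift: float = 3.0):
    # same schedule as generate_image: the shift reallocates steps along the trajectory
    t = torch.linspace(1, 0, num_timesteps)
    t = timestep_shift * t / (1 + (timestep_shift - 1) * t)
    return t[:-1], t[:-1] - t[1:]       # timesteps actually used and their step sizes

def dummy_velocity(x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # stand-in for llm2vae(forward_inference(...)) plus the CFG mixing/renorm logic
    return torch.zeros_like(x_t)

x_t = torch.randn(1024, 64)             # packed noise tokens; feature dim = p*p*z_channels = 2*2*16
timesteps, dts = shifted_timesteps()
for t, dt in zip(timesteps, dts):
    v_t = dummy_velocity(x_t, t)        # v_t is trained to predict x_1 - x_0 (data -> noise)
    x_t = x_t - v_t * dt                # so integration runs from t=1 (noise) toward t=0 (data)

Because the target velocity points from data to noise, the minus sign in the update is what moves the packed latent toward the data distribution; the resulting tokens are then unpacked per image and decoded by the VAE.
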
modeling/bagel/modeling_utils.py ADDED
@@ -0,0 +1,144 @@
1
+ # Copyright (c) 2022 Facebook, Inc. and its affiliates.
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: CC-BY-NC-4.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under CC BY-NC 4.0, with the full license text
8
+ # available at https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import math
13
+
14
+ import numpy as np
15
+ import torch
16
+ from torch import nn
17
+ from transformers.activations import ACT2FN
18
+
19
+ # --------------------------------------------------------
20
+ # 2D sine-cosine position embedding
21
+ # References:
22
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
23
+ # --------------------------------------------------------
24
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
25
+ grid_h = np.arange(grid_size, dtype=np.float32)
26
+ grid_w = np.arange(grid_size, dtype=np.float32)
27
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
28
+ grid = np.stack(grid, axis=0)
29
+
30
+ grid = grid.reshape([2, 1, grid_size, grid_size])
31
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
32
+ if cls_token and extra_tokens > 0:
33
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
34
+ return pos_embed
35
+
36
+
37
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
38
+ assert embed_dim % 2 == 0
39
+
40
+ # use half of dimensions to encode grid_h
41
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
42
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
43
+
44
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
45
+ return emb
46
+
47
+
48
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
49
+ """
50
+ embed_dim: output dimension for each position
51
+ pos: a list of positions to be encoded: size (M,)
52
+ out: (M, D)
53
+ """
54
+ assert embed_dim % 2 == 0
55
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
56
+ omega /= embed_dim / 2.
57
+ omega = 1. / 10000**omega # (D/2,)
58
+
59
+ pos = pos.reshape(-1) # (M,)
60
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
61
+
62
+ emb_sin = np.sin(out) # (M, D/2)
63
+ emb_cos = np.cos(out) # (M, D/2)
64
+
65
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
66
+ return emb
67
+
68
+
69
+ # --------------------------------------------------------
70
+ # TimestepEmbedder
71
+ # Reference:
72
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
73
+ # --------------------------------------------------------
74
+ class TimestepEmbedder(nn.Module):
75
+ """
76
+ Embeds scalar timesteps into vector representations.
77
+ """
78
+ def __init__(self, hidden_size, frequency_embedding_size=256):
79
+ super().__init__()
80
+ self.mlp = nn.Sequential(
81
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
82
+ nn.SiLU(),
83
+ nn.Linear(hidden_size, hidden_size, bias=True),
84
+ )
85
+ self.frequency_embedding_size = frequency_embedding_size
86
+
87
+ @staticmethod
88
+ def timestep_embedding(t, dim, max_period=10000):
89
+ """
90
+ Create sinusoidal timestep embeddings.
91
+ :param t: a 1-D Tensor of N indices, one per batch element.
92
+ These may be fractional.
93
+ :param dim: the dimension of the output.
94
+ :param max_period: controls the minimum frequency of the embeddings.
95
+ :return: an (N, D) Tensor of positional embeddings.
96
+ """
97
+ half = dim // 2
98
+ freqs = torch.exp(
99
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
100
+ ).to(device=t.device)
101
+ args = t[:, None].float() * freqs[None]
102
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
103
+ if dim % 2:
104
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
105
+ return embedding
106
+
107
+ def forward(self, t):
108
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
109
+ t_emb = self.mlp(t_freq)
110
+ return t_emb
111
+
112
+
113
+ class MLPconnector(nn.Module):
114
+ def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
115
+ super().__init__()
116
+ self.activation_fn = ACT2FN[hidden_act]
117
+ self.fc1 = nn.Linear(in_dim, out_dim)
118
+ self.fc2 = nn.Linear(out_dim, out_dim)
119
+
120
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
121
+ hidden_states = self.fc1(hidden_states)
122
+ hidden_states = self.activation_fn(hidden_states)
123
+ hidden_states = self.fc2(hidden_states)
124
+ return hidden_states
125
+
126
+
127
+ class PositionEmbedding(nn.Module):
128
+ def __init__(self, max_num_patch_per_side, hidden_size):
129
+ super().__init__()
130
+ self.max_num_patch_per_side = max_num_patch_per_side
131
+ self.hidden_size = hidden_size
132
+ self.pos_embed = nn.Parameter(
133
+ torch.zeros(max_num_patch_per_side ** 2, hidden_size),
134
+ requires_grad=False
135
+ )
136
+ self._init_weights()
137
+
138
+ def _init_weights(self):
139
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
140
+ pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
141
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float())
142
+
143
+ def forward(self, position_ids):
144
+ return self.pos_embed[position_ids]
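For orientation, a minimal shape check for the helpers defined in modeling_utils.py above. This is a sketch, not part of the commit, and it assumes the repository root is on PYTHONPATH so the module is importable as modeling.bagel.modeling_utils:

    import torch
    from modeling.bagel.modeling_utils import TimestepEmbedder, get_2d_sincos_pos_embed

    # One fixed sin-cos embedding per patch position on a 70x70 grid
    # (70 matches vit_max_num_patch_per_side in the BagelConfig).
    pos = get_2d_sincos_pos_embed(embed_dim=1024, grid_size=70)
    print(pos.shape)                        # (4900, 1024)

    # Scalar (possibly fractional) timesteps -> hidden_size vectors via the MLP.
    t_embedder = TimestepEmbedder(hidden_size=1024)
    print(t_embedder(torch.rand(4)).shape)  # torch.Size([4, 1024])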
modeling/bagel/qwen2_navit.py ADDED
@@ -0,0 +1,1157 @@
1
+ # Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under Apache-2.0, with the full license text
8
+ # available at https://github.com/huggingface/transformers/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+
13
+ from dataclasses import dataclass
14
+ from functools import partial
15
+ from typing import List, Optional, Tuple
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn.attention import SDPBackend, sdpa_kernel
20
+ from torch.nn.attention.flex_attention import flex_attention
21
+ from torch.nn.functional import scaled_dot_product_attention
22
+ from transformers.utils import ModelOutput
23
+
24
+ from flash_attn import flash_attn_varlen_func
25
+ from modeling.qwen2.modeling_qwen2 import (
26
+ Qwen2Attention,
27
+ Qwen2MLP,
28
+ Qwen2PreTrainedModel,
29
+ Qwen2RMSNorm,
30
+ Qwen2RotaryEmbedding,
31
+ apply_rotary_pos_emb,
32
+ )
33
+
34
+ from modeling.qwen2.configuration_qwen2 import Qwen2Config as _Qwen2Config
35
+
36
+
37
+ torch._dynamo.config.cache_size_limit = 512
38
+ torch._dynamo.config.accumulated_cache_size_limit = 4096
39
+ # flex_attention = torch.compile(flex_attention) # , dynamic=True, mode='max-autotune'
40
+ flex_attention = torch.compile(flex_attention)
41
+
42
+
43
+ class Qwen2Config(_Qwen2Config):
44
+ r"""
45
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
46
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
47
+ with the defaults will yield a similar configuration to that of
48
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
49
+
50
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
51
+ documentation from [`PretrainedConfig`] for more information.
52
+
53
+ Args:
54
+ vocab_size (`int`, *optional*, defaults to 151936):
55
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
56
+ `inputs_ids` passed when calling [`Qwen2Model`]
57
+ hidden_size (`int`, *optional*, defaults to 4096):
58
+ Dimension of the hidden representations.
59
+ intermediate_size (`int`, *optional*, defaults to 22016):
60
+ Dimension of the MLP representations.
61
+ num_hidden_layers (`int`, *optional*, defaults to 32):
62
+ Number of hidden layers in the Transformer encoder.
63
+ num_attention_heads (`int`, *optional*, defaults to 32):
64
+ Number of attention heads for each attention layer in the Transformer encoder.
65
+ num_key_value_heads (`int`, *optional*, defaults to 32):
66
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
67
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
68
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
69
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
70
+ by meanpooling all the original heads within that group. For more details checkout [this
71
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
72
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
73
+ The non-linear activation function (function or string) in the decoder.
74
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
75
+ The maximum sequence length that this model might ever be used with.
76
+ initializer_range (`float`, *optional*, defaults to 0.02):
77
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
78
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
79
+ The epsilon used by the rms normalization layers.
80
+ use_cache (`bool`, *optional*, defaults to `True`):
81
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
82
+ relevant if `config.is_decoder=True`.
83
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
84
+ Whether the model's input and output word embeddings should be tied.
85
+ rope_theta (`float`, *optional*, defaults to 10000.0):
86
+ The base period of the RoPE embeddings.
87
+ rope_scaling (`Dict`, *optional*):
88
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
89
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
90
+ accordingly.
91
+ Expected contents:
92
+ `rope_type` (`str`):
93
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
94
+ 'llama3'], with 'default' being the original RoPE implementation.
95
+ `factor` (`float`, *optional*):
96
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
97
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
98
+ original maximum pre-trained length.
99
+ `original_max_position_embeddings` (`int`, *optional*):
100
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
101
+ pretraining.
102
+ `attention_factor` (`float`, *optional*):
103
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
104
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
105
+ `factor` field to infer the suggested value.
106
+ `beta_fast` (`float`, *optional*):
107
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
108
+ ramp function. If unspecified, it defaults to 32.
109
+ `beta_slow` (`float`, *optional*):
110
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
111
+ ramp function. If unspecified, it defaults to 1.
112
+ `short_factor` (`List[float]`, *optional*):
113
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
114
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
115
+ size divided by the number of attention heads divided by 2
116
+ `long_factor` (`List[float]`, *optional*):
117
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
118
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
119
+ size divided by the number of attention heads divided by 2
120
+ `low_freq_factor` (`float`, *optional*):
121
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
122
+ `high_freq_factor` (`float`, *optional*):
123
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
124
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
125
+ Whether to use sliding window attention.
126
+ sliding_window (`int`, *optional*, defaults to 4096):
127
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
128
+ max_window_layers (`int`, *optional*, defaults to 28):
129
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
130
+ attention_dropout (`float`, *optional*, defaults to 0.0):
131
+ The dropout ratio for the attention probabilities.
132
+
133
+ ```python
134
+ >>> from transformers import Qwen2Model, Qwen2Config
135
+
136
+ >>> # Initializing a Qwen2 style configuration
137
+ >>> configuration = Qwen2Config()
138
+
139
+ >>> # Initializing a model from the Qwen2-7B style configuration
140
+ >>> model = Qwen2Model(configuration)
141
+
142
+ >>> # Accessing the model configuration
143
+ >>> configuration = model.config
144
+ ```"""
145
+
146
+ model_type = "qwen2"
147
+ keys_to_ignore_at_inference = ["past_key_values"]
148
+
149
+ def __init__(
150
+ self,
151
+ vocab_size=151936,
152
+ hidden_size=4096,
153
+ intermediate_size=22016,
154
+ num_hidden_layers=32,
155
+ num_attention_heads=32,
156
+ num_key_value_heads=32,
157
+ hidden_act="silu",
158
+ max_position_embeddings=32768,
159
+ initializer_range=0.02,
160
+ rms_norm_eps=1e-6,
161
+ use_cache=True,
162
+ tie_word_embeddings=False,
163
+ rope_theta=10000.0,
164
+ rope_scaling=None,
165
+ use_sliding_window=False,
166
+ sliding_window=4096,
167
+ max_window_layers=28,
168
+ attention_dropout=0.0,
169
+ is_causal=True,
170
+ _attn_implementation="flash_attention_2",
171
+ qk_norm=True,
172
+ layer_module="Qwen2DecoderLayer",
173
+ freeze_und=False,
174
+ **kwargs,
175
+ ):
176
+ super().__init__(
177
+ vocab_size=vocab_size,
178
+ hidden_size=hidden_size,
179
+ intermediate_size=intermediate_size,
180
+ num_hidden_layers=num_hidden_layers,
181
+ num_attention_heads=num_attention_heads,
182
+ num_key_value_heads=num_key_value_heads,
183
+ hidden_act=hidden_act,
184
+ max_position_embeddings=max_position_embeddings,
185
+ initializer_range=initializer_range,
186
+ rms_norm_eps=rms_norm_eps,
187
+ use_cache=use_cache,
188
+ tie_word_embeddings=tie_word_embeddings,
189
+ rope_theta=rope_theta,
190
+ rope_scaling=rope_scaling,
191
+ use_sliding_window=use_sliding_window,
192
+ sliding_window=sliding_window,
193
+ max_window_layers=max_window_layers,
194
+ attention_dropout=attention_dropout,
195
+ is_causal=is_causal,
196
+ _attn_implementation=_attn_implementation,
197
+ **kwargs,
198
+ )
199
+ self.qk_norm = qk_norm
200
+ self.layer_module = layer_module
201
+ self.freeze_und = freeze_und
202
+
203
+
204
+ class NaiveCache:
205
+ def __init__(self, num_layers):
206
+ self.key_cache = {k: None for k in range(num_layers)}
207
+ self.value_cache = {k: None for k in range(num_layers)}
208
+
209
+ @property
210
+ def num_layers(self):
211
+ return len(self.key_cache)
212
+
213
+ @property
214
+ def seq_lens(self):
215
+ if self.key_cache[0] is not None:
216
+ return self.key_cache[0].shape[0]
217
+ else:
218
+ return 0
219
+
220
+
221
+ @dataclass
222
+ class BaseNavitOutputWithPast(ModelOutput):
223
+ packed_query_sequence: torch.FloatTensor = None
224
+ past_key_values: Optional[NaiveCache] = None
225
+
226
+
227
+ def pad_sequence(tensor, pad_size):
228
+ H, L, D = tensor.shape
229
+ pad_tensor = tensor.new_zeros((H, pad_size, D))
230
+ return torch.cat([tensor, pad_tensor], dim=1)
231
+
232
+
233
+ class PackedAttention(Qwen2Attention):
234
+ def __init__(self, config, layer_idx: Optional[int] = None):
235
+ super().__init__(config, layer_idx)
236
+ if self.config.qk_norm:
237
+ self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
238
+ self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
239
+ else:
240
+ self.q_norm = nn.Identity()
241
+ self.k_norm = nn.Identity()
242
+
243
+ def forward(self, *args, **kwargs):
244
+ if self.training:
245
+ return self.forward_train(*args, **kwargs)
246
+ else:
247
+ return self.forward_inference(*args, **kwargs)
248
+
249
+ def forward_train(
250
+ self,
251
+ packed_sequence: torch.Tensor,
252
+ sample_lens: List[int],
253
+ attention_mask: List[torch.Tensor],
254
+ packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
255
+ ):
256
+ packed_query_states = self.q_proj(packed_sequence).view(-1, self.num_heads, self.head_dim)
257
+ packed_key_states = self.k_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
258
+ packed_value_states = self.v_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
259
+
260
+ packed_query_states = self.q_norm(packed_query_states)
261
+ packed_key_states = self.k_norm(packed_key_states)
262
+
263
+ packed_cos, packed_sin = packed_position_embeddings
264
+ packed_query_states, packed_key_states = apply_rotary_pos_emb(
265
+ packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
266
+ )
267
+
268
+ if isinstance(attention_mask, List):
269
+ packed_key_states = packed_key_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
270
+ packed_key_states = packed_key_states.reshape(-1, self.num_heads, self.head_dim)
271
+ packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
272
+ packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)
273
+
274
+ unpacked_query_states = packed_query_states.transpose(0, 1).split(sample_lens, dim=1)
275
+ unpacked_key_states = packed_key_states.transpose(0, 1).split(sample_lens, dim=1)
276
+ unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
277
+ unpacked_attn_output = []
278
+ for query_states, key_states, value_states, attention_mask_per_sample in zip(
279
+ unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
280
+ ):
281
+ with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
282
+ attn_output = scaled_dot_product_attention(
283
+ query_states.to(torch.bfloat16).unsqueeze(0),
284
+ key_states.to(torch.bfloat16).unsqueeze(0),
285
+ value_states.to(torch.bfloat16).unsqueeze(0),
286
+ attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
287
+ )
288
+ unpacked_attn_output.append(attn_output.squeeze(0))
289
+ packed_attn_output = torch.cat(unpacked_attn_output, dim=1)
290
+ else:
291
+ pad_size = sum(sample_lens) - packed_query_states.shape[0]
292
+ packed_query_states = pad_sequence(packed_query_states.permute(1, 0, 2), pad_size)
293
+ packed_key_states = pad_sequence(packed_key_states.permute(1, 0, 2), pad_size)
294
+ packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
295
+ packed_attn_output = flex_attention(
296
+ packed_query_states.unsqueeze(0),
297
+ packed_key_states.unsqueeze(0),
298
+ packed_value_states.unsqueeze(0),
299
+ enable_gqa=True,
300
+ block_mask=attention_mask,
301
+ )
302
+ end_index = packed_attn_output.shape[2] - pad_size
303
+ packed_attn_output = packed_attn_output[0, :, :end_index, :]
304
+
305
+ packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.hidden_size)
306
+ packed_attn_output = self.o_proj(packed_attn_output)
307
+
308
+ return packed_attn_output
309
+
310
+ def forward_inference(
311
+ self,
312
+ packed_query_sequence: torch.Tensor,
313
+ query_lens: torch.Tensor,
314
+ packed_query_position_embeddings: torch.Tensor,
315
+ packed_query_indexes: torch.Tensor,
316
+ past_key_values: Optional[NaiveCache] = None,
317
+ key_values_lens: Optional[torch.Tensor] = None,
318
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
319
+ update_past_key_values=True,
320
+ is_causal=True,
321
+ ):
322
+ packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
323
+ packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
324
+ packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
325
+
326
+ packed_query_states = self.q_norm(packed_query_states)
327
+ packed_key_states = self.k_norm(packed_key_states)
328
+
329
+ packed_cos, packed_sin = packed_query_position_embeddings
330
+ packed_query_states, packed_key_states = apply_rotary_pos_emb(
331
+ packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
332
+ )
333
+
334
+ packed_query_states = packed_query_states.to(torch.bfloat16)
335
+ packed_key_states = packed_key_states.to(torch.bfloat16)
336
+ packed_value_states = packed_value_states.to(torch.bfloat16)
337
+
338
+ if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
339
+ past_key_states = past_key_values.key_cache[self.layer_idx]
340
+ past_value_states = past_key_values.value_cache[self.layer_idx]
341
+
342
+ seqlens = sum(query_lens) + sum(key_values_lens)
343
+ merged_key_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
344
+ merged_value_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
345
+ merged_key_states[packed_query_indexes] = packed_key_states
346
+ merged_key_states[packed_key_value_indexes] = past_key_states
347
+ merged_value_states[packed_query_indexes] = packed_value_states
348
+ merged_value_states[packed_key_value_indexes] = past_value_states
349
+ key_values_lens = key_values_lens + query_lens
350
+ else:
351
+ merged_key_states = packed_key_states
352
+ merged_value_states = packed_value_states
353
+ key_values_lens = query_lens
354
+
355
+ cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
356
+ cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
357
+
358
+ packed_attn_output = flash_attn_varlen_func(
359
+ q=packed_query_states,
360
+ k=merged_key_states,
361
+ v=merged_value_states,
362
+ cu_seqlens_q=cu_seqlens_q.to(torch.int32),
363
+ cu_seqlens_k=cu_seqlens_k.to(torch.int32),
364
+ max_seqlen_q=max(query_lens).item(),
365
+ max_seqlen_k=max(key_values_lens).item(),
366
+ causal=is_causal,
367
+ )
368
+ packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
369
+ packed_attn_output = self.o_proj(packed_attn_output)
370
+
371
+ if update_past_key_values:
372
+ past_key_values.key_cache[self.layer_idx] = merged_key_states
373
+ past_key_values.value_cache[self.layer_idx] = merged_value_states
374
+
375
+ return packed_attn_output, past_key_values
376
+
377
+
378
+ class PackedAttentionMoT(Qwen2Attention):
379
+ def __init__(self, config, layer_idx: Optional[int] = None):
380
+ super().__init__(config, layer_idx)
381
+ if self.config.qk_norm:
382
+ self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
383
+ self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
384
+ self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
385
+ self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
386
+ else:
387
+ self.q_norm = nn.Identity()
388
+ self.k_norm = nn.Identity()
389
+ self.q_norm_moe_gen = nn.Identity()
390
+ self.k_norm_moe_gen = nn.Identity()
391
+
392
+ self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
393
+ self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
394
+ self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
395
+ self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
396
+
397
+ def forward(self, *args, **kwargs):
398
+ if self.training:
399
+ return self.forward_train(*args, **kwargs)
400
+ else:
401
+ return self.forward_inference(*args, **kwargs)
402
+
403
+ def forward_train(
404
+ self,
405
+ packed_sequence: torch.Tensor,
406
+ sample_lens: List[int],
407
+ attention_mask,
408
+ packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
409
+ packed_und_token_indexes: torch.LongTensor,
410
+ packed_gen_token_indexes: torch.LongTensor,
411
+ ):
412
+ packed_query_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_heads * self.head_dim))
413
+ packed_key_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))
414
+ packed_value_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_key_value_heads * self.head_dim))
415
+
416
+ packed_sequence_und = packed_sequence[packed_und_token_indexes]
417
+ packed_sequence_gen = packed_sequence[packed_gen_token_indexes]
418
+
419
+ packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
420
+ packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(packed_sequence_gen)
421
+
422
+ packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
423
+ packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(packed_sequence_gen)
424
+
425
+ packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
426
+ packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(packed_sequence_gen)
427
+
428
+ packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
429
+ packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
430
+ packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
431
+ if self.config.freeze_und:
432
+ packed_value_states[packed_und_token_indexes] = packed_value_states[packed_und_token_indexes].detach()
433
+
434
+ packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
435
+ packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)
436
+
437
+ packed_query_states_[packed_und_token_indexes] = self.q_norm(packed_query_states[packed_und_token_indexes])
438
+ if self.config.freeze_und:
439
+ packed_query_states_[packed_und_token_indexes] = packed_query_states_[packed_und_token_indexes].detach()
440
+ packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_gen_token_indexes])
441
+
442
+ packed_key_states_[packed_und_token_indexes] = self.k_norm(packed_key_states[packed_und_token_indexes])
443
+ if self.config.freeze_und:
444
+ packed_key_states_[packed_und_token_indexes] = packed_key_states_[packed_und_token_indexes].detach()
445
+ packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_gen_token_indexes])
446
+
447
+ packed_cos, packed_sin = packed_position_embeddings
448
+ packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
449
+ packed_query_states_, packed_key_states_, packed_cos, packed_sin, unsqueeze_dim=1
450
+ )
451
+
452
+ if isinstance(attention_mask, List):
453
+ packed_key_states_ = packed_key_states_[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
454
+ packed_key_states_ = packed_key_states_.reshape(-1, self.num_heads, self.head_dim)
455
+ packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
456
+ packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)
457
+
458
+ unpacked_query_states = packed_query_states_.transpose(0, 1).split(sample_lens, dim=1)
459
+ unpacked_key_states = packed_key_states_.transpose(0, 1).split(sample_lens, dim=1)
460
+ unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)
461
+ unpacked_attn_output = []
462
+ for query_states, key_states, value_states, attention_mask_per_sample in zip(
463
+ unpacked_query_states, unpacked_key_states, unpacked_value_states, attention_mask
464
+ ):
465
+ with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
466
+ attn_output = scaled_dot_product_attention(
467
+ query_states.to(torch.bfloat16).unsqueeze(0),
468
+ key_states.to(torch.bfloat16).unsqueeze(0),
469
+ value_states.to(torch.bfloat16).unsqueeze(0),
470
+ attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
471
+ )
472
+ unpacked_attn_output.append(attn_output.squeeze(0))
473
+ packed_attn_output = torch.cat(unpacked_attn_output, dim=1)
474
+ else:
475
+ pad_size = sum(sample_lens) - packed_query_states.shape[0]
476
+ packed_query_states_ = pad_sequence(packed_query_states_.permute(1, 0, 2), pad_size)
477
+ packed_key_states_ = pad_sequence(packed_key_states_.permute(1, 0, 2), pad_size)
478
+ packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
479
+ packed_attn_output = flex_attention(
480
+ packed_query_states_.unsqueeze(0), # 1, num_head, L, head_dim
481
+ packed_key_states_.unsqueeze(0),
482
+ packed_value_states.unsqueeze(0),
483
+ enable_gqa=True,
484
+ block_mask=attention_mask,
485
+ )
486
+ end_index = packed_attn_output.shape[2] - pad_size
487
+ packed_attn_output = packed_attn_output[0, :, :end_index, :]
488
+
489
+ packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim)
490
+ packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
491
+ packed_attn_output_[packed_und_token_indexes] = self.o_proj(packed_attn_output[packed_und_token_indexes])
492
+ packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_gen_token_indexes])
493
+
494
+ return packed_attn_output_
495
+
496
+ def forward_inference(
497
+ self,
498
+ packed_query_sequence: torch.Tensor,
499
+ query_lens: torch.Tensor,
500
+ packed_query_position_embeddings: torch.Tensor,
501
+ packed_query_indexes: torch.Tensor,
502
+ past_key_values: Optional[NaiveCache] = None,
503
+ key_values_lens: Optional[torch.Tensor] = None,
504
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
505
+ update_past_key_values=True,
506
+ is_causal=True,
507
+ mode="und",
508
+ packed_vae_token_indexes=None,
509
+ packed_text_indexes=None,
510
+ ):
511
+ if mode == 'und':
512
+ packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
513
+ packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
514
+ packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
515
+ packed_query_states = self.q_norm(packed_query_states)
516
+ packed_key_states = self.k_norm(packed_key_states)
517
+ elif mode == 'gen':
518
+ packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
519
+ packed_query_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_heads * self.head_dim))
520
+ packed_key_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))
521
+ packed_value_states = packed_query_sequence.new_zeros((packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim))
522
+
523
+ packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
524
+ packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
525
+
526
+ packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence)
527
+ packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence)
528
+
529
+ packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence)
530
+ packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence)
531
+
532
+ packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence)
533
+ packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence)
534
+
535
+ packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
536
+ packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
537
+ packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
538
+
539
+ packed_query_states = packed_query_states.to(torch.float32)
540
+ packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes])
541
+ packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(packed_query_states[packed_vae_token_indexes])
542
+
543
+ packed_key_states = packed_key_states.to(torch.float32)
544
+ packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes])
545
+ packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(packed_key_states[packed_vae_token_indexes])
546
+
547
+ packed_cos, packed_sin = packed_query_position_embeddings
548
+ packed_query_states, packed_key_states = apply_rotary_pos_emb(
549
+ packed_query_states, packed_key_states, packed_cos, packed_sin, unsqueeze_dim=1
550
+ )
551
+
552
+ packed_query_states = packed_query_states.to(torch.bfloat16)
553
+ packed_key_states = packed_key_states.to(torch.bfloat16)
554
+ packed_value_states = packed_value_states.to(torch.bfloat16)
555
+
556
+ if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
557
+ past_key_states = past_key_values.key_cache[self.layer_idx]
558
+ past_value_states = past_key_values.value_cache[self.layer_idx]
559
+
560
+ seqlens = sum(query_lens) + sum(key_values_lens)
561
+ merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
562
+ merged_value_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
563
+ merged_key_states[packed_query_indexes] = packed_key_states
564
+ merged_key_states[packed_key_value_indexes] = past_key_states
565
+ merged_value_states[packed_query_indexes] = packed_value_states
566
+ merged_value_states[packed_key_value_indexes] = past_value_states
567
+ key_values_lens = key_values_lens + query_lens
568
+ else:
569
+ merged_key_states = packed_key_states
570
+ merged_value_states = packed_value_states
571
+ key_values_lens = query_lens
572
+
573
+ cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
574
+ cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
575
+
576
+ packed_attn_output = flash_attn_varlen_func(
577
+ q=packed_query_states,
578
+ k=merged_key_states,
579
+ v=merged_value_states,
580
+ cu_seqlens_q=cu_seqlens_q.to(torch.int32),
581
+ cu_seqlens_k=cu_seqlens_k.to(torch.int32),
582
+ max_seqlen_q=max(query_lens).item(),
583
+ max_seqlen_k=max(key_values_lens).item(),
584
+ causal=is_causal,
585
+ )
586
+ packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
587
+ if mode == 'und':
588
+ packed_attn_output = self.o_proj(packed_attn_output)
589
+ elif mode == 'gen':
590
+ packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes])
591
+ packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(packed_attn_output[packed_vae_token_indexes])
592
+
593
+ if update_past_key_values:
594
+ past_key_values.key_cache[self.layer_idx] = merged_key_states
595
+ past_key_values.value_cache[self.layer_idx] = merged_value_states
596
+
597
+ return packed_attn_output, past_key_values
598
+
599
+
600
+ class Qwen2DecoderLayer(nn.Module):
601
+ def __init__(self, config, layer_idx: Optional[int] = None):
602
+ super().__init__()
603
+ self.hidden_size = config.hidden_size
604
+
605
+ self.self_attn = PackedAttention(config, layer_idx)
606
+
607
+ self.mlp = Qwen2MLP(config)
608
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
609
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
610
+
611
+ def forward(self, *args, **kwargs):
612
+ if self.training:
613
+ return self.forward_train(*args, **kwargs)
614
+ else:
615
+ return self.forward_inference(*args, **kwargs)
616
+
617
+ def forward_train(
618
+ self,
619
+ packed_sequence: torch.Tensor,
620
+ sample_lens: List[int],
621
+ attention_mask,
622
+ packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
623
+ ) -> torch.Tensor:
624
+
625
+ residual = packed_sequence
626
+ packed_sequence = self.input_layernorm(packed_sequence)
627
+
628
+ # Self Attention
629
+ packed_sequence = self.self_attn(
630
+ packed_sequence=packed_sequence,
631
+ sample_lens=sample_lens,
632
+ attention_mask=attention_mask,
633
+ packed_position_embeddings=packed_position_embeddings,
634
+ )
635
+ packed_sequence = residual + packed_sequence
636
+
637
+ # Fully Connected
638
+ residual = packed_sequence
639
+ packed_sequence = self.post_attention_layernorm(packed_sequence)
640
+ packed_sequence = self.mlp(packed_sequence)
641
+ packed_sequence = residual + packed_sequence
642
+
643
+ return packed_sequence
644
+
645
+ def forward_inference(
646
+ self,
647
+ packed_query_sequence: torch.Tensor,
648
+ query_lens: torch.Tensor,
649
+ packed_query_position_embeddings: torch.Tensor,
650
+ packed_query_indexes: torch.Tensor,
651
+ past_key_values: Optional[NaiveCache] = None,
652
+ key_values_lens: Optional[torch.Tensor] = None,
653
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
654
+ update_past_key_values=True,
655
+ is_causal=True,
656
+ ) -> BaseNavitOutputWithPast:
657
+
658
+ residual = packed_query_sequence
659
+ packed_query_sequence = self.input_layernorm(packed_query_sequence)
660
+
661
+ # Self Attention
662
+ packed_query_sequence, past_key_values = self.self_attn(
663
+ packed_query_sequence=packed_query_sequence,
664
+ query_lens=query_lens,
665
+ packed_query_position_embeddings=packed_query_position_embeddings,
666
+ packed_query_indexes=packed_query_indexes,
667
+ past_key_values=past_key_values,
668
+ key_values_lens=key_values_lens,
669
+ packed_key_value_indexes=packed_key_value_indexes,
670
+ update_past_key_values=update_past_key_values,
671
+ is_causal=is_causal,
672
+ )
673
+ packed_query_sequence = residual + packed_query_sequence
674
+
675
+ # Fully Connected
676
+ residual = packed_query_sequence
677
+ packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
678
+ packed_query_sequence = self.mlp(packed_query_sequence)
679
+ packed_query_sequence = residual + packed_query_sequence
680
+
681
+ return packed_query_sequence, past_key_values
682
+
683
+
684
+ class Qwen2MoTDecoderLayer(nn.Module):
685
+ def __init__(
686
+ self,
687
+ config,
688
+ layer_idx: Optional[int] = None,
689
+ attn_module: Optional[Qwen2Attention] = PackedAttentionMoT,
690
+ ):
691
+ super().__init__()
692
+ self.hidden_size = config.hidden_size
693
+ self.freeze_und = config.freeze_und
694
+
695
+ self.self_attn = attn_module(config, layer_idx)
696
+
697
+ self.mlp = Qwen2MLP(config)
698
+ self.mlp_moe_gen = Qwen2MLP(config)
699
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
700
+ self.input_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
701
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
702
+ self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
703
+
704
+ def forward(self, *args, **kwargs):
705
+ if self.training:
706
+ return self.forward_train(*args, **kwargs)
707
+ else:
708
+ return self.forward_inference(*args, **kwargs)
709
+
710
+ def forward_train(
711
+ self,
712
+ packed_sequence: torch.Tensor,
713
+ sample_lens: List[int],
714
+ attention_mask,
715
+ packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
716
+ packed_und_token_indexes: torch.LongTensor,
717
+ packed_gen_token_indexes: torch.LongTensor,
718
+ ) -> torch.Tensor:
719
+
720
+ residual = packed_sequence
721
+ packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
722
+ packed_sequence_[packed_und_token_indexes] = self.input_layernorm(packed_sequence[packed_und_token_indexes])
723
+ packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])
724
+
725
+ # Self Attention
726
+ packed_sequence_ = self.self_attn(
727
+ packed_sequence=packed_sequence_,
728
+ sample_lens=sample_lens,
729
+ attention_mask=attention_mask,
730
+ packed_position_embeddings=packed_position_embeddings,
731
+ packed_und_token_indexes=packed_und_token_indexes,
732
+ packed_gen_token_indexes=packed_gen_token_indexes,
733
+ )
734
+ if self.freeze_und:
735
+ packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
736
+ packed_sequence = residual + packed_sequence_
737
+
738
+ # Fully Connected
739
+ residual = packed_sequence
740
+ packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
741
+ packed_sequence_[packed_und_token_indexes] = self.mlp(
742
+ self.post_attention_layernorm(packed_sequence[packed_und_token_indexes])
743
+ )
744
+ if self.freeze_und:
745
+ packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
746
+
747
+ packed_sequence_[packed_gen_token_indexes] = self.mlp_moe_gen(
748
+ self.post_attention_layernorm_moe_gen(packed_sequence[packed_gen_token_indexes])
749
+ )
750
+ packed_sequence = residual + packed_sequence_
751
+
752
+ return packed_sequence
753
+
754
+ def forward_inference(
755
+ self,
756
+ packed_query_sequence: torch.Tensor,
757
+ query_lens: torch.Tensor,
758
+ packed_query_position_embeddings: torch.Tensor,
759
+ packed_query_indexes: torch.Tensor,
760
+ past_key_values: Optional[NaiveCache] = None,
761
+ key_values_lens: Optional[torch.Tensor] = None,
762
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
763
+ update_past_key_values=True,
764
+ is_causal=True,
765
+ mode="und",
766
+ packed_vae_token_indexes=None,
767
+ packed_text_indexes=None,
768
+ ) -> BaseNavitOutputWithPast:
769
+
770
+ residual = packed_query_sequence
771
+ if mode == "und":
772
+ packed_query_sequence = self.input_layernorm(packed_query_sequence)
773
+ elif mode == "gen":
774
+ packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
775
+ packed_query_sequence_[packed_text_indexes] = self.input_layernorm(packed_query_sequence[packed_text_indexes])
776
+ packed_query_sequence_[packed_vae_token_indexes] = self.input_layernorm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
777
+ packed_query_sequence = packed_query_sequence_
778
+
779
+ # Self Attention
780
+ packed_query_sequence, past_key_values = self.self_attn(
781
+ packed_query_sequence=packed_query_sequence,
782
+ query_lens=query_lens,
783
+ packed_query_position_embeddings=packed_query_position_embeddings,
784
+ packed_query_indexes=packed_query_indexes,
785
+ past_key_values=past_key_values,
786
+ key_values_lens=key_values_lens,
787
+ packed_key_value_indexes=packed_key_value_indexes,
788
+ update_past_key_values=update_past_key_values,
789
+ is_causal=is_causal,
790
+ mode=mode,
791
+ packed_vae_token_indexes=packed_vae_token_indexes,
792
+ packed_text_indexes=packed_text_indexes,
793
+ )
794
+ packed_query_sequence = residual + packed_query_sequence
795
+
796
+ # Fully Connected
797
+ residual = packed_query_sequence
798
+ if mode == "und":
799
+ packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
800
+ packed_query_sequence = self.mlp(packed_query_sequence)
801
+ elif mode == "gen":
802
+ packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
803
+ packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
804
+ packed_text_query_sequence = self.post_attention_layernorm(packed_text_query_sequence).to(torch.bfloat16)
805
+ packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(packed_vae_query_sequence).to(torch.bfloat16)
806
+
807
+ packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
808
+ packed_query_sequence_[packed_text_indexes] = self.mlp(packed_text_query_sequence)
809
+ packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_vae_query_sequence)
810
+ packed_query_sequence = packed_query_sequence_
811
+
812
+ packed_query_sequence = residual + packed_query_sequence
813
+ return packed_query_sequence, past_key_values
814
+
815
+
816
+ class Qwen2MoEDecoderLayer(nn.Module):
817
+ def __init__(self, config, layer_idx: Optional[int] = None):
818
+ super().__init__()
819
+ self.hidden_size = config.hidden_size
820
+
821
+ self.self_attn = PackedAttention(config, layer_idx)
822
+
823
+ self.mlp = Qwen2MLP(config)
824
+ self.mlp_moe_gen = Qwen2MLP(config)
825
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
826
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
827
+
828
+ def forward(self, *args, **kwargs):
829
+ if self.training:
830
+ return self.forward_train(*args, **kwargs)
831
+ else:
832
+ return self.forward_inference(*args, **kwargs)
833
+
834
+ def forward_train(
835
+ self,
836
+ packed_sequence: torch.Tensor,
837
+ sample_lens: List[int],
838
+ attention_mask,
839
+ packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
840
+ packed_und_token_indexes: torch.LongTensor,
841
+ packed_gen_token_indexes: torch.LongTensor,
842
+ ) -> torch.Tensor:
843
+
844
+ residual = packed_sequence
845
+ packed_sequence = self.input_layernorm(packed_sequence)
846
+
847
+ # Self Attention
848
+ packed_sequence = self.self_attn(
849
+ packed_sequence=packed_sequence,
850
+ sample_lens=sample_lens,
851
+ attention_mask=attention_mask,
852
+ packed_position_embeddings=packed_position_embeddings,
853
+ )
854
+ packed_sequence = residual + packed_sequence
855
+
856
+ # Fully Connected
857
+ residual = packed_sequence
858
+ packed_sequence = self.post_attention_layernorm(packed_sequence)
859
+
860
+ packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape)
861
+ packed_sequence_und = self.mlp(packed_sequence[packed_und_token_indexes])
862
+ packed_sequence_gen = self.mlp_moe_gen(packed_sequence[packed_gen_token_indexes])
863
+ packed_sequence_new[packed_und_token_indexes] = packed_sequence_und
864
+ packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen
865
+
866
+ packed_sequence = residual + packed_sequence_new
867
+
868
+ return packed_sequence
869
+
870
+ def forward_inference(
871
+ self,
872
+ packed_query_sequence: torch.Tensor,
873
+ query_lens: torch.Tensor,
874
+ packed_query_position_embeddings: torch.Tensor,
875
+ packed_query_indexes: torch.Tensor,
876
+ past_key_values: Optional[NaiveCache] = None,
877
+ key_values_lens: Optional[torch.Tensor] = None,
878
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
879
+ update_past_key_values=True,
880
+ is_causal=True,
881
+ mode="und",
882
+ packed_vae_token_indexes=None,
883
+ packed_text_indexes=None,
884
+ ) -> BaseNavitOutputWithPast:
885
+
886
+ residual = packed_query_sequence
887
+ packed_query_sequence = self.input_layernorm(packed_query_sequence)
888
+
889
+ # Self Attention
890
+ packed_query_sequence, past_key_values = self.self_attn(
891
+ packed_query_sequence=packed_query_sequence,
892
+ query_lens=query_lens,
893
+ packed_query_position_embeddings=packed_query_position_embeddings,
894
+ packed_query_indexes=packed_query_indexes,
895
+ past_key_values=past_key_values,
896
+ key_values_lens=key_values_lens,
897
+ packed_key_value_indexes=packed_key_value_indexes,
898
+ update_past_key_values=update_past_key_values,
899
+ is_causal=is_causal,
900
+ )
901
+ packed_query_sequence = residual + packed_query_sequence
902
+
903
+ # Fully Connected
904
+ residual = packed_query_sequence
905
+ packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
906
+ if mode == "und":
907
+ packed_query_sequence = self.mlp(packed_query_sequence)
908
+ elif mode == "gen":
909
+ packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(torch.bfloat16)
910
+ packed_query_sequence_[packed_text_indexes] = self.mlp(packed_query_sequence[packed_text_indexes])
911
+ packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(packed_query_sequence[packed_vae_token_indexes])
912
+ packed_query_sequence = packed_query_sequence_
913
+ packed_query_sequence = residual + packed_query_sequence
914
+
915
+ return packed_query_sequence, past_key_values
916
+
917
+
918
+ Decoder_layer_dict = {
919
+ "Qwen2DecoderLayer": Qwen2DecoderLayer,
920
+ "Qwen2MoEDecoderLayer": Qwen2MoEDecoderLayer,
921
+ "Qwen2MoTDecoderLayer": partial(Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT),
922
+ }
923
+
924
+
925
+ class Qwen2Model(Qwen2PreTrainedModel):
926
+ def __init__(self, config):
927
+ super().__init__(config)
928
+ self.padding_idx = config.pad_token_id
929
+ self.vocab_size = config.vocab_size
930
+ self.use_moe = 'Mo' in config.layer_module
931
+
932
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
933
+ layer_module = Decoder_layer_dict[config.layer_module]
934
+ self.layers = nn.ModuleList(
935
+ [layer_module(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
936
+ )
937
+
938
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
939
+ if self.use_moe:
940
+ self.norm_moe_gen = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
941
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
942
+
943
+ # Initialize weights and apply final processing
944
+ self.post_init()
945
+
946
+ def forward(self, *args, **kwargs):
947
+ if self.training:
948
+ return self.forward_train(*args, **kwargs)
949
+ else:
950
+ return self.forward_inference(*args, **kwargs)
951
+
952
+ def forward_train(
953
+ self,
954
+ packed_sequence: torch.Tensor,
955
+ sample_lens: List[int],
956
+ attention_mask,
957
+ packed_position_ids: torch.Tensor,
958
+ packed_und_token_indexes: Optional[torch.LongTensor] = None,
959
+ packed_gen_token_indexes: Optional[torch.LongTensor] = None,
960
+ ) -> torch.Tensor:
961
+
962
+ if self.config.freeze_und:
963
+ packed_sequence[packed_und_token_indexes] = packed_sequence[packed_und_token_indexes].detach()
964
+
965
+ # create position embeddings to be shared across the decoder layers
966
+ cos, sin = self.rotary_emb(packed_sequence, packed_position_ids.unsqueeze(0))
967
+ cos = cos.squeeze(0)
968
+ sin = sin.squeeze(0)
969
+ packed_position_embeddings = (cos, sin)
970
+
971
+ extra_inputs = {}
972
+ if self.use_moe:
973
+ assert packed_und_token_indexes is not None
974
+ if packed_gen_token_indexes is None:
975
+ packed_gen_token_indexes = packed_und_token_indexes.new_ones(size=[0])
976
+ extra_inputs.update(
977
+ packed_und_token_indexes=packed_und_token_indexes,
978
+ packed_gen_token_indexes=packed_gen_token_indexes,
979
+ )
980
+
981
+ for decoder_layer in self.layers:
982
+ packed_sequence = decoder_layer(
983
+ packed_sequence=packed_sequence,
984
+ sample_lens=sample_lens,
985
+ attention_mask=attention_mask,
986
+ packed_position_embeddings=packed_position_embeddings,
987
+ **extra_inputs
988
+ )
989
+
990
+ if self.use_moe:
991
+ packed_sequence_ = torch.zeros_like(packed_sequence)
992
+ packed_sequence_[packed_und_token_indexes] = self.norm(packed_sequence[packed_und_token_indexes])
993
+ if self.config.freeze_und:
994
+ packed_sequence_[packed_und_token_indexes] = packed_sequence_[packed_und_token_indexes].detach()
995
+ packed_sequence_[packed_gen_token_indexes] = self.norm_moe_gen(packed_sequence[packed_gen_token_indexes])
996
+ return packed_sequence_
997
+ else:
998
+ return self.norm(packed_sequence)
999
+
1000
+ def forward_inference(
1001
+ self,
1002
+ packed_query_sequence: torch.Tensor,
1003
+ query_lens: torch.Tensor,
1004
+ packed_query_position_ids: torch.Tensor,
1005
+ packed_query_indexes: torch.Tensor,
1006
+ past_key_values: Optional[NaiveCache] = None,
1007
+ key_values_lens: Optional[torch.Tensor] = None,
1008
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
1009
+ update_past_key_values=True,
1010
+ is_causal=True,
1011
+ mode="und",
1012
+ packed_vae_token_indexes=None,
1013
+ packed_text_indexes=None,
1014
+ ) -> BaseNavitOutputWithPast:
1015
+
1016
+ # create position embeddings to be shared across the decoder layers
1017
+ cos, sin = self.rotary_emb(packed_query_sequence, packed_query_position_ids.unsqueeze(0))
1018
+ cos = cos.squeeze(0)
1019
+ sin = sin.squeeze(0)
1020
+ packed_query_position_embeddings = (cos, sin)
1021
+
1022
+ extra_inputs = {}
1023
+ if self.use_moe:
1024
+ extra_inputs.update(mode=mode)
1025
+ if mode == 'gen':
1026
+ assert packed_vae_token_indexes is not None
1027
+ assert packed_text_indexes is not None
1028
+ extra_inputs.update(
1029
+ packed_vae_token_indexes=packed_vae_token_indexes,
1030
+ packed_text_indexes=packed_text_indexes,
1031
+ )
1032
+
1033
+ for decoder_layer in self.layers:
1034
+ packed_query_sequence, past_key_values = decoder_layer(
1035
+ packed_query_sequence=packed_query_sequence,
1036
+ query_lens=query_lens,
1037
+ packed_query_position_embeddings=packed_query_position_embeddings,
1038
+ packed_query_indexes=packed_query_indexes,
1039
+ past_key_values=past_key_values,
1040
+ key_values_lens=key_values_lens,
1041
+ packed_key_value_indexes=packed_key_value_indexes,
1042
+ update_past_key_values=update_past_key_values,
1043
+ is_causal=is_causal,
1044
+ **extra_inputs,
1045
+ )
1046
+
1047
+ if self.use_moe:
1048
+ if mode == "und":
1049
+ packed_query_sequence = self.norm(packed_query_sequence)
1050
+ elif mode == "gen":
1051
+ packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
1052
+ packed_query_sequence_[packed_text_indexes] = self.norm(packed_query_sequence[packed_text_indexes])
1053
+ packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen(packed_query_sequence[packed_vae_token_indexes])
1054
+ packed_query_sequence = packed_query_sequence_
1055
+ else:
1056
+ packed_query_sequence = self.norm(packed_query_sequence)
1057
+
1058
+ return BaseNavitOutputWithPast(
1059
+ packed_query_sequence=packed_query_sequence,
1060
+ past_key_values=past_key_values,
1061
+ )
1062
+
1063
+
1064
+ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
1065
+ _tied_weights_keys = ["lm_head.weight"]
1066
+
1067
+ def __init__(self, config):
1068
+ super().__init__(config)
1069
+ self.model = Qwen2Model(config)
1070
+ self.vocab_size = config.vocab_size
1071
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1072
+
1073
+ # Initialize weights and apply final processing
1074
+ self.post_init()
1075
+
1076
+ def init_moe(self):
1077
+ for name, param in self.named_parameters():
1078
+ if "moe_gen" in name:
1079
+ original_name = name.replace("_moe_gen", "")
1080
+ param.data.copy_(self.state_dict()[original_name].data)
1081
+
1082
+ def get_input_embeddings(self):
1083
+ return self.model.embed_tokens
1084
+
1085
+ def set_input_embeddings(self, value):
1086
+ self.model.embed_tokens = value
1087
+
1088
+ def get_output_embeddings(self):
1089
+ return self.lm_head
1090
+
1091
+ def set_output_embeddings(self, new_embeddings):
1092
+ self.lm_head = new_embeddings
1093
+
1094
+ def set_decoder(self, decoder):
1095
+ self.model = decoder
1096
+
1097
+ def get_decoder(self):
1098
+ return self.model
1099
+
1100
+ def forward(self, *args, **kwargs):
1101
+ if self.training:
1102
+ return self.forward_train(*args, **kwargs)
1103
+ else:
1104
+ return self.forward_inference(*args, **kwargs)
1105
+
1106
+ def forward_train(
1107
+ self,
1108
+ packed_sequence: torch.Tensor,
1109
+ sample_lens: List[int],
1110
+ attention_mask,
1111
+ packed_position_ids: torch.Tensor,
1112
+ packed_und_token_indexes: Optional[torch.LongTensor] = None,
1113
+ packed_gen_token_indexes: Optional[torch.LongTensor] = None,
1114
+ ) -> torch.Tensor:
1115
+
1116
+ outputs = self.model(
1117
+ packed_sequence=packed_sequence,
1118
+ sample_lens=sample_lens,
1119
+ packed_position_ids=packed_position_ids,
1120
+ attention_mask=attention_mask,
1121
+ packed_und_token_indexes=packed_und_token_indexes,
1122
+ packed_gen_token_indexes=packed_gen_token_indexes,
1123
+ )
1124
+ return outputs
1125
+
1126
+ def forward_inference(
1127
+ self,
1128
+ packed_query_sequence: torch.Tensor,
1129
+ query_lens: torch.Tensor,
1130
+ packed_query_position_ids: torch.Tensor,
1131
+ packed_query_indexes: torch.Tensor,
1132
+ past_key_values: Optional[NaiveCache] = None,
1133
+ key_values_lens: Optional[torch.Tensor] = None,
1134
+ packed_key_value_indexes: Optional[torch.Tensor] = None,
1135
+ update_past_key_values=True,
1136
+ is_causal=True,
1137
+ mode="und",
1138
+ packed_vae_token_indexes=None,
1139
+ packed_text_indexes=None,
1140
+ ) -> BaseNavitOutputWithPast:
1141
+
1142
+ outputs = self.model(
1143
+ packed_query_sequence=packed_query_sequence,
1144
+ query_lens=query_lens,
1145
+ packed_query_position_ids=packed_query_position_ids,
1146
+ packed_query_indexes=packed_query_indexes,
1147
+ past_key_values=past_key_values,
1148
+ key_values_lens=key_values_lens,
1149
+ packed_key_value_indexes=packed_key_value_indexes,
1150
+ update_past_key_values=update_past_key_values,
1151
+ is_causal=is_causal,
1152
+ mode=mode,
1153
+ packed_vae_token_indexes=packed_vae_token_indexes,
1154
+ packed_text_indexes=packed_text_indexes,
1155
+ )
1156
+
1157
+ return outputs
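The `init_moe` hook above seeds every generation-expert (`*_moe_gen`) parameter from its counterpart without the suffix. A minimal, self-contained sketch of that copy pattern, using a hypothetical `ToyMoT` module that is not part of this repo:

```python
# Sketch of the init_moe() pattern above on a hypothetical ToyMoT module:
# every parameter whose name contains "moe_gen" is initialised by copying
# the parameter whose name drops the "_moe_gen" suffix.
import torch
from torch import nn

class ToyMoT(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)            # shared / understanding branch
        self.norm_moe_gen = nn.LayerNorm(dim)    # generation-expert copy

    def init_moe(self):
        state = self.state_dict()
        for name, param in self.named_parameters():
            if "moe_gen" in name:
                param.data.copy_(state[name.replace("_moe_gen", "")].data)

m = ToyMoT()
nn.init.normal_(m.norm.weight)                   # make the two branches differ first
m.init_moe()
assert torch.equal(m.norm_moe_gen.weight, m.norm.weight)
assert torch.equal(m.norm_moe_gen.bias, m.norm.bias)
```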
modeling/bagel/siglip_navit.py ADDED
@@ -0,0 +1,402 @@
1
+ # Copyright (c) 2024 The HuggingFace Inc. team.
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under Apache-2.0, with the full license text
8
+ # available at https://github.com/huggingface/transformers/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from transformers.activations import ACT2FN
16
+ from modeling.siglip.configuration_siglip import SiglipVisionConfig as _SiglipVisionConfig
17
+ from modeling.siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel
18
+ from flash_attn import flash_attn_varlen_func
19
+
20
+
21
+ class SiglipVisionConfig(_SiglipVisionConfig):
22
+ r"""
23
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
24
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
25
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
26
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
27
+
28
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
29
+ documentation from [`PretrainedConfig`] for more information.
30
+
31
+ Args:
32
+ hidden_size (`int`, *optional*, defaults to 768):
33
+ Dimensionality of the encoder layers and the pooler layer.
34
+ intermediate_size (`int`, *optional*, defaults to 3072):
35
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
36
+ num_hidden_layers (`int`, *optional*, defaults to 12):
37
+ Number of hidden layers in the Transformer encoder.
38
+ num_attention_heads (`int`, *optional*, defaults to 12):
39
+ Number of attention heads for each attention layer in the Transformer encoder.
40
+ num_channels (`int`, *optional*, defaults to 3):
41
+ Number of channels in the input images.
42
+ image_size (`int`, *optional*, defaults to 224):
43
+ The size (resolution) of each image.
44
+ patch_size (`int`, *optional*, defaults to 16):
45
+ The size (resolution) of each patch.
46
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
47
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
48
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
49
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
50
+ The epsilon used by the layer normalization layers.
51
+ attention_dropout (`float`, *optional*, defaults to 0.0):
52
+ The dropout ratio for the attention probabilities.
53
+
54
+ Example:
55
+
56
+ ```python
57
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
58
+
59
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
60
+ >>> configuration = SiglipVisionConfig()
61
+
62
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
63
+ >>> model = SiglipVisionModel(configuration)
64
+
65
+ >>> # Accessing the model configuration
66
+ >>> configuration = model.config
67
+ ```"""
68
+
69
+ model_type = "siglip_vision_model"
70
+
71
+ def __init__(
72
+ self,
73
+ hidden_size=768,
74
+ intermediate_size=3072,
75
+ num_hidden_layers=12,
76
+ num_attention_heads=12,
77
+ num_channels=3,
78
+ image_size=224,
79
+ patch_size=16,
80
+ hidden_act="gelu_pytorch_tanh",
81
+ layer_norm_eps=1e-6,
82
+ attention_dropout=0.0,
83
+ rope=True,
84
+ **kwargs,
85
+ ):
86
+ super().__init__(
87
+ hidden_size=hidden_size,
88
+ intermediate_size=intermediate_size,
89
+ num_hidden_layers=num_hidden_layers,
90
+ num_attention_heads=num_attention_heads,
91
+ num_channels=num_channels,
92
+ image_size=image_size,
93
+ patch_size=patch_size,
94
+ hidden_act=hidden_act,
95
+ layer_norm_eps=layer_norm_eps,
96
+ attention_dropout=attention_dropout,
97
+ **kwargs)
98
+
99
+ self.rope = rope
100
+
101
+
102
+ class RotaryEmbedding2D(torch.nn.Module):
103
+ def __init__(self, dim, max_h, max_w, base=10000):
104
+ super().__init__()
105
+ freq = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
106
+ inv_freq = 1.0 / (base ** freq)
107
+
108
+ grid_h = torch.arange(0, max_h)
109
+ grid_h = grid_h.to(inv_freq.dtype)
110
+ grid_h = grid_h[:, None].repeat(1, max_w)
111
+
112
+ grid_w = torch.arange(0, max_w)
113
+ grid_w = grid_w.to(inv_freq.dtype)
114
+ grid_w = grid_w[None, :].repeat(max_h, 1)
115
+
116
+ cos_h, sin_h = self._forward_one_side(grid_h, inv_freq)
117
+ cos_w, sin_w = self._forward_one_side(grid_w, inv_freq)
118
+
119
+ self.register_buffer("cos_h", cos_h)
120
+ self.register_buffer("sin_h", sin_h)
121
+ self.register_buffer("cos_w", cos_w)
122
+ self.register_buffer("sin_w", sin_w)
123
+
124
+ def _forward_one_side(self, grid, inv_freq):
125
+ freqs = grid[..., None] * inv_freq[None, None, :]
126
+ emb = torch.cat((freqs, freqs), dim=-1).flatten(0, 1)
127
+ return emb.cos(), emb.sin()
128
+
129
+
130
+ def rotate_half(x):
131
+ x1 = x[..., : x.shape[-1] // 2]
132
+ x2 = x[..., x.shape[-1] // 2 :]
133
+ return torch.cat((-x2, x1), dim=-1)
134
+
135
+
136
+ def apply_rotary_pos_emb(q, k, cos, sin):
137
+ # unsqueeze due to the head dimension
138
+ cos = cos.unsqueeze(1)
139
+ sin = sin.unsqueeze(1)
140
+ q_embed = (q * cos) + (rotate_half(q) * sin)
141
+ k_embed = (k * cos) + (rotate_half(k) * sin)
142
+ return q_embed, k_embed
143
+
144
+
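A hedged usage sketch of the 2D rotary helpers above. It assumes the module (and its flash-attn dependency) imports cleanly; the positions and sizes are illustrative. The split of the head dimension into a height half and a width half mirrors what `SiglipFlashAttention2` does below.

```python
# Illustrative use of RotaryEmbedding2D / apply_rotary_pos_emb defined above:
# cos/sin tables are indexed by flattened row-major patch positions, and each
# half of the head dimension is rotated with the height and width frequencies.
import torch

dim_head, max_h, max_w = 64, 4, 4
rope = RotaryEmbedding2D(dim_head // 2, max_h, max_w)

pos = torch.tensor([0, 1, 5])              # flattened (row * max_w + col) positions
q = torch.randn(len(pos), 1, dim_head)     # (tokens, heads, head_dim)
k = torch.randn(len(pos), 1, dim_head)

qh, qw = q[..., :dim_head // 2], q[..., dim_head // 2:]
kh, kw = k[..., :dim_head // 2], k[..., dim_head // 2:]
qh, kh = apply_rotary_pos_emb(qh, kh, rope.cos_h[pos], rope.sin_h[pos])
qw, kw = apply_rotary_pos_emb(qw, kw, rope.cos_w[pos], rope.sin_w[pos])
q_rot = torch.cat([qh, qw], dim=-1)        # same shape as q
k_rot = torch.cat([kh, kw], dim=-1)
```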
145
+ class SiglipVisionEmbeddings(nn.Module):
146
+ def __init__(self, config: SiglipVisionConfig):
147
+ super().__init__()
148
+ self.config = config
149
+ self.embed_dim = config.hidden_size
150
+ self.image_size = config.image_size
151
+ self.patch_size = config.patch_size
152
+
153
+ self.patch_embedding = nn.Conv2d(
154
+ in_channels=config.num_channels,
155
+ out_channels=self.embed_dim,
156
+ kernel_size=self.patch_size,
157
+ stride=self.patch_size,
158
+ padding="valid",
159
+ )
160
+
161
+ self.num_patches_per_side = self.image_size // self.patch_size
162
+ self.num_patches = self.num_patches_per_side**2
163
+ self.num_positions = self.num_patches
164
+ if not config.rope:
165
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
166
+
167
+ def convert_conv2d_to_linear(self, config, meta=False):
168
+ if meta:
169
+ linear_patch_embedding = nn.Linear(
170
+ config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True, device='meta'
171
+ )
172
+ else:
173
+ linear_patch_embedding = nn.Linear(
174
+ config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True
175
+ )
176
+ W = self.patch_embedding.weight.permute(0, 2, 3, 1).reshape(
177
+ self.embed_dim, config.num_channels * self.patch_size ** 2
178
+ )
179
+ linear_patch_embedding.weight.data = W
180
+ linear_patch_embedding.bias.data = self.patch_embedding.bias.data
181
+ del self.patch_embedding
182
+ self.patch_embedding = linear_patch_embedding
183
+
184
+ def forward(
185
+ self,
186
+ packed_pixel_values: torch.FloatTensor,
187
+ packed_flattened_position_ids: torch.LongTensor
188
+ ) -> torch.Tensor:
189
+
190
+ patch_embeds = self.patch_embedding(packed_pixel_values)
191
+ if not self.config.rope:
192
+ embeddings = patch_embeds + self.position_embedding(packed_flattened_position_ids)
193
+ else:
194
+ embeddings = patch_embeds
195
+ return embeddings
196
+
197
+
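The `convert_conv2d_to_linear` helper above relies on a stride-`patch_size` Conv2d over non-overlapping patches being exactly a Linear layer over flattened patches, with patches flattened in `(H, W, C)` order to match the `permute(0, 2, 3, 1)` above. A standalone numerical sketch of that equivalence, with illustrative sizes:

```python
# Conv2d (stride = kernel = patch size) vs. Linear over flattened patches.
import torch
from torch import nn

C, P, D = 3, 16, 768
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
lin = nn.Linear(C * P * P, D)
lin.weight.data = conv.weight.permute(0, 2, 3, 1).reshape(D, C * P * P)
lin.bias.data = conv.bias.data

x = torch.randn(1, C, P, P)                          # a single patch
flat = x.permute(0, 2, 3, 1).reshape(1, C * P * P)   # (H, W, C) flattening order
assert torch.allclose(conv(x).reshape(1, D), lin(flat), atol=1e-4)
```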
198
+ class SiglipFlashAttention2(SiglipAttention):
199
+ def __init__(self, *args, **kwargs):
200
+ super().__init__(*args, **kwargs)
201
+
202
+ def forward(
203
+ self,
204
+ hidden_states: torch.Tensor,
205
+ cu_seqlens: torch.IntTensor,
206
+ max_seqlen: int,
207
+ cos_h: torch.Tensor = None,
208
+ sin_h: torch.Tensor = None,
209
+ cos_w: torch.Tensor = None,
210
+ sin_w: torch.Tensor = None,
211
+ **kwargs,
212
+ ) -> torch.Tensor:
213
+
214
+ total_q_len, _ = hidden_states.size()
215
+
216
+ query_states = self.q_proj(hidden_states)
217
+ key_states = self.k_proj(hidden_states)
218
+ value_states = self.v_proj(hidden_states)
219
+
220
+ query_states = query_states.view(total_q_len, self.num_heads, self.head_dim)
221
+ key_states = key_states.view(total_q_len, self.num_heads, self.head_dim)
222
+ value_states = value_states.view(total_q_len, self.num_heads, self.head_dim)
223
+
224
+ if self.config.rope:
225
+ qh, qw = query_states[:, :, :self.head_dim // 2], query_states[:, :, self.head_dim // 2:]
226
+ kh, kw = key_states[:, :, :self.head_dim // 2], key_states[:, :, self.head_dim // 2:]
227
+ qh, kh = apply_rotary_pos_emb(qh, kh, cos_h, sin_h)
228
+ qw, kw = apply_rotary_pos_emb(qw, kw, cos_w, sin_w)
229
+ query_states = torch.cat([qh, qw], dim=-1)
230
+ key_states = torch.cat([kh, kw], dim=-1)
231
+
232
+ attn_output = flash_attn_varlen_func(
233
+ query_states.to(torch.bfloat16),
234
+ key_states.to(torch.bfloat16),
235
+ value_states.to(torch.bfloat16),
236
+ cu_seqlens_q=cu_seqlens,
237
+ cu_seqlens_k=cu_seqlens,
238
+ max_seqlen_q=max_seqlen,
239
+ max_seqlen_k=max_seqlen,
240
+ causal=False,
241
+ )
242
+
243
+ attn_output = self.out_proj(attn_output.reshape(total_q_len, -1))
244
+ return attn_output
245
+
246
+
247
+ class SiglipMLP(nn.Module):
248
+ def __init__(self, config):
249
+ super().__init__()
250
+ self.config = config
251
+ self.activation_fn = ACT2FN[config.hidden_act]
252
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
253
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
254
+
255
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
256
+ hidden_states = self.fc1(hidden_states)
257
+ hidden_states = self.activation_fn(hidden_states)
258
+ hidden_states = self.fc2(hidden_states)
259
+ return hidden_states
260
+
261
+
262
+ class SiglipEncoderLayer(nn.Module):
263
+ def __init__(self, config: SiglipVisionConfig):
264
+ super().__init__()
265
+ self.embed_dim = config.hidden_size
266
+ self.self_attn = SiglipFlashAttention2(config)
267
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
268
+ self.mlp = SiglipMLP(config)
269
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
270
+
271
+ def forward(
272
+ self,
273
+ hidden_states: torch.Tensor,
274
+ cu_seqlens: torch.IntTensor,
275
+ max_seqlen: int,
276
+ cos_h: torch.Tensor = None,
277
+ sin_h: torch.Tensor = None,
278
+ cos_w: torch.Tensor = None,
279
+ sin_w: torch.Tensor = None
280
+ ) -> torch.Tensor:
281
+ residual = hidden_states
282
+
283
+ hidden_states = self.layer_norm1(hidden_states)
284
+ hidden_states = self.self_attn(
285
+ hidden_states=hidden_states,
286
+ cu_seqlens=cu_seqlens,
287
+ max_seqlen=max_seqlen,
288
+ cos_h=cos_h,
289
+ sin_h=sin_h,
290
+ cos_w=cos_w,
291
+ sin_w=sin_w
292
+ )
293
+ hidden_states = residual + hidden_states
294
+
295
+ residual = hidden_states
296
+ hidden_states = self.layer_norm2(hidden_states)
297
+ hidden_states = self.mlp(hidden_states)
298
+ hidden_states = residual + hidden_states
299
+
300
+ return hidden_states
301
+
302
+
303
+ class SiglipEncoder(nn.Module):
304
+ def __init__(self, config: SiglipVisionConfig):
305
+ super().__init__()
306
+ self.config = config
307
+ self.layers = nn.ModuleList(
308
+ [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
309
+ )
310
+
311
+ def forward(
312
+ self,
313
+ inputs_embeds: torch.Tensor,
314
+ cu_seqlens: torch.IntTensor,
315
+ max_seqlen: int,
316
+ cos_h: torch.Tensor = None,
317
+ sin_h: torch.Tensor = None,
318
+ cos_w: torch.Tensor = None,
319
+ sin_w: torch.Tensor = None,
320
+ ) -> torch.Tensor:
321
+
322
+ hidden_states = inputs_embeds
323
+ for encoder_layer in self.layers:
324
+ hidden_states = encoder_layer(hidden_states, cu_seqlens, max_seqlen,
325
+ cos_h=cos_h, sin_h=sin_h, cos_w=cos_w, sin_w=sin_w)
326
+
327
+ return hidden_states
328
+
329
+
330
+ class SiglipVisionTransformer(nn.Module):
331
+ def __init__(self, config: SiglipVisionConfig):
332
+ super().__init__()
333
+ self.config = config
334
+ embed_dim = config.hidden_size
335
+
336
+ self.embeddings = SiglipVisionEmbeddings(config)
337
+ if config.rope:
338
+ max_size = config.image_size // config.patch_size
339
+ dim_head = config.hidden_size // config.num_attention_heads
340
+ self.rope = RotaryEmbedding2D(dim_head // 2, max_size, max_size)
341
+
342
+ self.encoder = SiglipEncoder(config)
343
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
344
+
345
+ def forward(
346
+ self,
347
+ packed_pixel_values: torch.Tensor,
348
+ packed_flattened_position_ids: torch.LongTensor,
349
+ cu_seqlens: torch.IntTensor,
350
+ max_seqlen: int,
351
+ ) -> torch.Tensor:
352
+ hidden_states = self.embeddings(
353
+ packed_pixel_values=packed_pixel_values,
354
+ packed_flattened_position_ids=packed_flattened_position_ids
355
+ )
356
+
357
+ extra_inputs = {}
358
+ if self.config.rope:
359
+ extra_inputs.update(
360
+ cos_h = self.rope.cos_h[packed_flattened_position_ids],
361
+ sin_h = self.rope.sin_h[packed_flattened_position_ids],
362
+ cos_w = self.rope.cos_w[packed_flattened_position_ids],
363
+ sin_w = self.rope.sin_w[packed_flattened_position_ids]
364
+ )
365
+
366
+ last_hidden_state = self.encoder(
367
+ inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
368
+ **extra_inputs
369
+ )
370
+ last_hidden_state = self.post_layernorm(last_hidden_state)
371
+ return last_hidden_state
372
+
373
+
374
+ class SiglipVisionModel(SiglipPreTrainedModel):
375
+ config_class = SiglipVisionConfig
376
+ main_input_name = "packed_pixel_values"
377
+
378
+ def __init__(self, config: SiglipVisionConfig):
379
+ super().__init__(config)
380
+
381
+ self.vision_model = SiglipVisionTransformer(config)
382
+
383
+ # Initialize weights and apply final processing
384
+ self.post_init()
385
+
386
+ def get_input_embeddings(self) -> nn.Module:
387
+ return self.vision_model.embeddings.patch_embedding
388
+
389
+ def forward(
390
+ self,
391
+ packed_pixel_values: torch.Tensor,
392
+ packed_flattened_position_ids: torch.LongTensor,
393
+ cu_seqlens: torch.IntTensor,
394
+ max_seqlen: int,
395
+ ) -> torch.Tensor:
396
+
397
+ return self.vision_model(
398
+ packed_pixel_values=packed_pixel_values,
399
+ packed_flattened_position_ids=packed_flattened_position_ids,
400
+ cu_seqlens=cu_seqlens,
401
+ max_seqlen=max_seqlen,
402
+ )
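`SiglipVisionModel` above consumes a NaViT-style packed batch: patches from all images are concatenated into one flat sequence, `cu_seqlens` marks per-image boundaries in the flash-attn varlen convention, and `packed_flattened_position_ids` indexes a row-major grid for the RoPE tables. The repo's real preprocessing lives in its data pipeline; the sketch below only illustrates plausible shapes, and the grid side is an assumption that must match the model config.

```python
# Illustrative-only layout of the packed inputs expected by SiglipVisionModel above.
import torch
import torch.nn.functional as F

patch_size, channels, grid_side = 16, 3, 64      # grid_side = image_size // patch_size (assumed)
patch_dim = channels * patch_size ** 2

sizes = [(14, 14), (20, 10)]                     # (h_patches, w_patches) per image

# Patches flattened to (num_patches, C * P * P), all images concatenated.
packed_pixel_values = torch.cat(
    [torch.randn(h * w, patch_dim) for h, w in sizes], dim=0
)

# Row-major position of each patch inside the grid_side x grid_side RoPE grid.
packed_flattened_position_ids = torch.cat([
    (torch.arange(h)[:, None] * grid_side + torch.arange(w)[None, :]).flatten()
    for h, w in sizes
])

# flash-attn varlen convention: cumulative sequence lengths with a leading 0.
seqlens = torch.tensor([h * w for h, w in sizes], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())
```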
modeling/qwen2/__init__.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ from transformers.utils import (
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_tokenizers_available,
10
+ is_torch_available,
11
+ )
12
+
13
+
14
+ _import_structure = {
15
+ "configuration_qwen2": ["Qwen2Config"],
16
+ "tokenization_qwen2": ["Qwen2Tokenizer"],
17
+ }
18
+
19
+ try:
20
+ if not is_tokenizers_available():
21
+ raise OptionalDependencyNotAvailable()
22
+ except OptionalDependencyNotAvailable:
23
+ pass
24
+ else:
25
+ _import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]
26
+
27
+ try:
28
+ if not is_torch_available():
29
+ raise OptionalDependencyNotAvailable()
30
+ except OptionalDependencyNotAvailable:
31
+ pass
32
+ else:
33
+ _import_structure["modeling_qwen2"] = [
34
+ "Qwen2ForCausalLM",
35
+ "Qwen2Model",
36
+ "Qwen2PreTrainedModel",
37
+ ]
38
+
39
+
40
+ if TYPE_CHECKING:
41
+ from .configuration_qwen2 import Qwen2Config
42
+ from .tokenization_qwen2 import Qwen2Tokenizer
43
+
44
+ try:
45
+ if not is_tokenizers_available():
46
+ raise OptionalDependencyNotAvailable()
47
+ except OptionalDependencyNotAvailable:
48
+ pass
49
+ else:
50
+ from .tokenization_qwen2_fast import Qwen2TokenizerFast
51
+
52
+ try:
53
+ if not is_torch_available():
54
+ raise OptionalDependencyNotAvailable()
55
+ except OptionalDependencyNotAvailable:
56
+ pass
57
+ else:
58
+ from .modeling_qwen2 import (
59
+ Qwen2ForCausalLM,
60
+ Qwen2Model,
61
+ Qwen2PreTrainedModel,
62
+ )
63
+
64
+
65
+ else:
66
+ import sys
67
+
68
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
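With the `_LazyModule` indirection above, importing a symbol from `modeling.qwen2` only loads the submodule that defines it, so lightweight imports stay cheap:

```python
# Only configuration_qwen2 / tokenization_qwen2 are imported here; the heavier
# modeling_qwen2 module is loaded lazily, on first access to one of its symbols.
from modeling.qwen2 import Qwen2Config, Qwen2Tokenizer
# from modeling.qwen2 import Qwen2ForCausalLM   # this would pull in modeling_qwen2
```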
modeling/qwen2/configuration_qwen2.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Qwen2 model configuration"""
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.modeling_rope_utils import rope_config_validation
8
+ from transformers.utils import logging
9
+
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ class Qwen2Config(PretrainedConfig):
15
+ r"""
16
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
17
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
18
+ with the defaults will yield a similar configuration to that of
19
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
20
+
21
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
+ documentation from [`PretrainedConfig`] for more information.
23
+
24
+
25
+ Args:
26
+ vocab_size (`int`, *optional*, defaults to 151936):
27
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
28
+ `inputs_ids` passed when calling [`Qwen2Model`]
29
+ hidden_size (`int`, *optional*, defaults to 4096):
30
+ Dimension of the hidden representations.
31
+ intermediate_size (`int`, *optional*, defaults to 22016):
32
+ Dimension of the MLP representations.
33
+ num_hidden_layers (`int`, *optional*, defaults to 32):
34
+ Number of hidden layers in the Transformer encoder.
35
+ num_attention_heads (`int`, *optional*, defaults to 32):
36
+ Number of attention heads for each attention layer in the Transformer encoder.
37
+ num_key_value_heads (`int`, *optional*, defaults to 32):
38
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
39
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
40
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
41
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
42
+ by meanpooling all the original heads within that group. For more details, check out [this
43
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
45
+ The non-linear activation function (function or string) in the decoder.
46
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
47
+ The maximum sequence length that this model might ever be used with.
48
+ initializer_range (`float`, *optional*, defaults to 0.02):
49
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
50
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
51
+ The epsilon used by the rms normalization layers.
52
+ use_cache (`bool`, *optional*, defaults to `True`):
53
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
54
+ relevant if `config.is_decoder=True`.
55
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
56
+ Whether the model's input and output word embeddings should be tied.
57
+ rope_theta (`float`, *optional*, defaults to 10000.0):
58
+ The base period of the RoPE embeddings.
59
+ rope_scaling (`Dict`, *optional*):
60
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
61
+ and you expect the model to work on a longer `max_position_embeddings`, we recommend updating this value
62
+ accordingly.
63
+ Expected contents:
64
+ `rope_type` (`str`):
65
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
66
+ 'llama3'], with 'default' being the original RoPE implementation.
67
+ `factor` (`float`, *optional*):
68
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
69
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
70
+ original maximum pre-trained length.
71
+ `original_max_position_embeddings` (`int`, *optional*):
72
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
73
+ pretraining.
74
+ `attention_factor` (`float`, *optional*):
75
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
76
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
77
+ `factor` field to infer the suggested value.
78
+ `beta_fast` (`float`, *optional*):
79
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
80
+ ramp function. If unspecified, it defaults to 32.
81
+ `beta_slow` (`float`, *optional*):
82
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
83
+ ramp function. If unspecified, it defaults to 1.
84
+ `short_factor` (`List[float]`, *optional*):
85
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
86
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
87
+ size divided by the number of attention heads divided by 2
88
+ `long_factor` (`List[float]`, *optional*):
89
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (>
90
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
91
+ size divided by the number of attention heads divided by 2
92
+ `low_freq_factor` (`float`, *optional*):
93
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
94
+ `high_freq_factor` (`float`, *optional*):
95
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
96
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
97
+ Whether to use sliding window attention.
98
+ sliding_window (`int`, *optional*, defaults to 4096):
99
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
100
+ max_window_layers (`int`, *optional*, defaults to 28):
101
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
102
+ attention_dropout (`float`, *optional*, defaults to 0.0):
103
+ The dropout ratio for the attention probabilities.
104
+
105
+ ```python
106
+ >>> from transformers import Qwen2Model, Qwen2Config
107
+
108
+ >>> # Initializing a Qwen2 style configuration
109
+ >>> configuration = Qwen2Config()
110
+
111
+ >>> # Initializing a model from the Qwen2-7B style configuration
112
+ >>> model = Qwen2Model(configuration)
113
+
114
+ >>> # Accessing the model configuration
115
+ >>> configuration = model.config
116
+ ```"""
117
+
118
+ model_type = "qwen2"
119
+ keys_to_ignore_at_inference = ["past_key_values"]
120
+
121
+ def __init__(
122
+ self,
123
+ vocab_size=151936,
124
+ hidden_size=4096,
125
+ intermediate_size=22016,
126
+ num_hidden_layers=32,
127
+ num_attention_heads=32,
128
+ num_key_value_heads=32,
129
+ hidden_act="silu",
130
+ max_position_embeddings=32768,
131
+ initializer_range=0.02,
132
+ rms_norm_eps=1e-6,
133
+ use_cache=True,
134
+ tie_word_embeddings=False,
135
+ rope_theta=10000.0,
136
+ rope_scaling=None,
137
+ use_sliding_window=False,
138
+ sliding_window=4096,
139
+ max_window_layers=28,
140
+ attention_dropout=0.0,
141
+ is_causal=True,
142
+ _attn_implementation="flash_attention_2",
143
+ **kwargs,
144
+ ):
145
+ self.vocab_size = vocab_size
146
+ self.max_position_embeddings = max_position_embeddings
147
+ self.hidden_size = hidden_size
148
+ self.intermediate_size = intermediate_size
149
+ self.num_hidden_layers = num_hidden_layers
150
+ self.num_attention_heads = num_attention_heads
151
+ self.use_sliding_window = use_sliding_window
152
+ self.sliding_window = sliding_window if use_sliding_window else None
153
+ self.max_window_layers = max_window_layers
154
+
155
+ # for backward compatibility
156
+ if num_key_value_heads is None:
157
+ num_key_value_heads = num_attention_heads
158
+
159
+ self.num_key_value_heads = num_key_value_heads
160
+ self.hidden_act = hidden_act
161
+ self.initializer_range = initializer_range
162
+ self.rms_norm_eps = rms_norm_eps
163
+ self.use_cache = use_cache
164
+ self.rope_theta = rope_theta
165
+ self.rope_scaling = rope_scaling
166
+ self.attention_dropout = attention_dropout
167
+ self.is_causal = is_causal
168
+ self._attn_implementation = _attn_implementation
169
+
170
+ # Validate the correctness of rotary position embeddings parameters
171
+ # BC: if there is a 'type' field, move it to 'rope_type'.
172
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
173
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
174
+ rope_config_validation(self)
175
+
176
+ super().__init__(
177
+ tie_word_embeddings=tie_word_embeddings,
178
+ **kwargs,
179
+ )
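A hedged usage sketch of the configuration class above; the sizes are arbitrary, and the `rope_scaling` dict deliberately uses the legacy `"type"` key to exercise the backward-compatibility rename to `"rope_type"`:

```python
# Arbitrary small sizes, for illustration only.
from modeling.qwen2 import Qwen2Config

config = Qwen2Config(
    vocab_size=32000,
    hidden_size=1024,
    intermediate_size=2816,
    num_hidden_layers=4,
    num_attention_heads=16,
    num_key_value_heads=4,                        # GQA: 4 query heads share each KV head
    rope_scaling={"type": "linear", "factor": 2.0},
)
assert config.rope_scaling["rope_type"] == "linear"   # legacy "type" key was renamed
assert config._attn_implementation == "flash_attention_2"
assert config.sliding_window is None                  # off unless use_sliding_window=True
```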
modeling/qwen2/modeling_qwen2.py ADDED
@@ -0,0 +1,929 @@
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """PyTorch Qwen2 model."""
5
+
6
+ import math
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+
13
+ from transformers.activations import ACT2FN
14
+ from transformers.cache_utils import Cache, DynamicCache
15
+ from transformers.generation import GenerationMixin
16
+ from transformers.modeling_outputs import (
17
+ BaseModelOutputWithPast,
18
+ CausalLMOutputWithPast,
19
+ )
20
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
21
+ from transformers.modeling_utils import PreTrainedModel
22
+ from transformers.utils import (
23
+ add_start_docstrings,
24
+ add_start_docstrings_to_model_forward,
25
+ is_flash_attn_2_available,
26
+ is_flash_attn_greater_or_equal_2_10,
27
+ logging,
28
+ replace_return_docstrings,
29
+ )
30
+ from .configuration_qwen2 import Qwen2Config
31
+
32
+
33
+ if is_flash_attn_2_available():
34
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
41
+ _CONFIG_FOR_DOC = "Qwen2Config"
42
+
43
+
44
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
45
+ class Qwen2RMSNorm(nn.Module):
46
+ def __init__(self, hidden_size, eps=1e-6):
47
+ """
48
+ Qwen2RMSNorm is equivalent to T5LayerNorm
49
+ """
50
+ super().__init__()
51
+ self.weight = nn.Parameter(torch.ones(hidden_size))
52
+ self.variance_epsilon = eps
53
+
54
+ def forward(self, hidden_states):
55
+ input_dtype = hidden_states.dtype
56
+ hidden_states = hidden_states.to(torch.float32)
57
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
58
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
59
+ return self.weight * hidden_states.to(input_dtype)
60
+
61
+ def extra_repr(self):
62
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
63
+
64
+
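A small numerical sketch of the RMSNorm above: unlike LayerNorm it subtracts no mean; it only rescales by the root-mean-square over the hidden dimension and multiplies by a learned weight (initialised to ones). Assumes the module above is importable.

```python
# Reference computation matching Qwen2RMSNorm.forward above.
import torch

norm = Qwen2RMSNorm(hidden_size=16, eps=1e-6)
x = torch.randn(2, 3, 16)
ref = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
assert torch.allclose(norm(x), norm.weight * ref, atol=1e-5)
```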
65
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
66
+ class Qwen2RotaryEmbedding(nn.Module):
67
+ def __init__(
68
+ self,
69
+ dim=None,
70
+ max_position_embeddings=2048,
71
+ base=10000,
72
+ device=None,
73
+ scaling_factor=1.0,
74
+ rope_type="default",
75
+ config: Optional[Qwen2Config] = None,
76
+ ):
77
+ super().__init__()
78
+ # TODO (joao): remove the `if` below, only used for BC
79
+ self.rope_kwargs = {}
80
+ if config is None:
81
+ logger.warning_once(
82
+ "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
83
+ "`config` argument. All other arguments will be removed in v4.46"
84
+ )
85
+ self.rope_kwargs = {
86
+ "rope_type": rope_type,
87
+ "factor": scaling_factor,
88
+ "dim": dim,
89
+ "base": base,
90
+ "max_position_embeddings": max_position_embeddings,
91
+ }
92
+ self.rope_type = rope_type
93
+ self.max_seq_len_cached = max_position_embeddings
94
+ self.original_max_seq_len = max_position_embeddings
95
+ else:
96
+ # BC: "rope_type" was originally "type"
97
+ if config.rope_scaling is not None:
98
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
99
+ else:
100
+ self.rope_type = "default"
101
+ self.max_seq_len_cached = config.max_position_embeddings
102
+ self.original_max_seq_len = config.max_position_embeddings
103
+
104
+ self.config = config
105
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
106
+
107
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
108
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
109
+ self.original_inv_freq = self.inv_freq
110
+
111
+ def _dynamic_frequency_update(self, position_ids, device):
112
+ """
113
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
114
+ 1 - growing beyond the cached sequence length (allow scaling)
115
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
116
+ """
117
+ seq_len = torch.max(position_ids) + 1
118
+ if seq_len > self.max_seq_len_cached: # growth
119
+ inv_freq, self.attention_scaling = self.rope_init_fn(
120
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
121
+ )
122
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
123
+ self.max_seq_len_cached = seq_len
124
+
125
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
126
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
127
+ self.max_seq_len_cached = self.original_max_seq_len
128
+
129
+ @torch.no_grad()
130
+ def forward(self, x, position_ids):
131
+ if "dynamic" in self.rope_type:
132
+ self._dynamic_frequency_update(position_ids, device=x.device)
133
+
134
+ # Core RoPE block
135
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
136
+ position_ids_expanded = position_ids[:, None, :].float()
137
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
138
+ device_type = x.device.type
139
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
140
+ with torch.autocast(device_type=device_type, enabled=False):
141
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
142
+ emb = torch.cat((freqs, freqs), dim=-1)
143
+ cos = emb.cos()
144
+ sin = emb.sin()
145
+
146
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
147
+ cos = cos * self.attention_scaling
148
+ sin = sin * self.attention_scaling
149
+
150
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
151
+
152
+
153
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
154
+ def rotate_half(x):
155
+ """Rotates half the hidden dims of the input."""
156
+ x1 = x[..., : x.shape[-1] // 2]
157
+ x2 = x[..., x.shape[-1] // 2 :]
158
+ return torch.cat((-x2, x1), dim=-1)
159
+
160
+
161
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
162
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
163
+ """Applies Rotary Position Embedding to the query and key tensors.
164
+
165
+ Args:
166
+ q (`torch.Tensor`): The query tensor.
167
+ k (`torch.Tensor`): The key tensor.
168
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
169
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
170
+ position_ids (`torch.Tensor`, *optional*):
171
+ Deprecated and unused.
172
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
173
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
174
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
175
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
176
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
177
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
178
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
179
+ Returns:
180
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
181
+ """
182
+ cos = cos.unsqueeze(unsqueeze_dim)
183
+ sin = sin.unsqueeze(unsqueeze_dim)
184
+ q_embed = (q * cos) + (rotate_half(q) * sin)
185
+ k_embed = (k * cos) + (rotate_half(k) * sin)
186
+ return q_embed, k_embed
187
+
188
+
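The rotary helpers above give attention scores that depend only on the relative offset between query and key positions. A standalone numerical sketch, with the frequency table rebuilt locally so the check does not need a full config:

```python
# Relative-position property of RoPE as implemented by apply_rotary_pos_emb above:
# the q·k score is unchanged when both positions are shifted by the same amount.
import torch

dim, base = 64, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

def cos_sin(position):
    freqs = torch.tensor([position]).float()[:, None] * inv_freq[None, :]
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos()[None], emb.sin()[None]      # (batch=1, seq=1, dim)

q = torch.randn(1, 1, 1, dim)                    # (batch, heads, seq, head_dim)
k = torch.randn(1, 1, 1, dim)

def score(pos_q, pos_k):
    q_rot, _ = apply_rotary_pos_emb(q, q, *cos_sin(pos_q))
    k_rot, _ = apply_rotary_pos_emb(k, k, *cos_sin(pos_k))
    return (q_rot * k_rot).sum()

assert torch.allclose(score(3, 7), score(10, 14), atol=1e-4)   # same offset of 4
```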
189
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
190
+ class Qwen2MLP(nn.Module):
191
+ def __init__(self, config):
192
+ super().__init__()
193
+ self.hidden_size = config.hidden_size
194
+ self.intermediate_size = config.intermediate_size
195
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
196
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
197
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
198
+ self.act_fn = ACT2FN[config.hidden_act]
199
+
200
+ def forward(self, hidden_state):
201
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
202
+
203
+
204
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
205
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
206
+ """
207
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
208
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
209
+ """
210
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
211
+ if n_rep == 1:
212
+ return hidden_states
213
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
214
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
215
+
216
+
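As its docstring notes, `repeat_kv` above matches `torch.repeat_interleave` on the key/value head dimension; a quick sanity sketch:

```python
import torch

x = torch.randn(2, 4, 5, 8)    # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(x, 3), torch.repeat_interleave(x, repeats=3, dim=1))
assert repeat_kv(x, 1) is x    # n_rep == 1 is a no-op
```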
217
+ class Qwen2Attention(nn.Module):
218
+ """
219
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
220
+ and "Generating Long Sequences with Sparse Transformers".
221
+ """
222
+
223
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
224
+ super().__init__()
225
+ self.config = config
226
+ self.layer_idx = layer_idx
227
+ if layer_idx is None:
228
+ logger.warning_once(
229
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
230
+ "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
231
+ "when creating this class."
232
+ )
233
+
234
+ self.hidden_size = config.hidden_size
235
+ self.num_heads = config.num_attention_heads
236
+ self.head_dim = self.hidden_size // self.num_heads
237
+ self.num_key_value_heads = config.num_key_value_heads
238
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
239
+ self.max_position_embeddings = config.max_position_embeddings
240
+ self.rope_theta = config.rope_theta
241
+ self.is_causal = config.is_causal
242
+ self.attention_dropout = config.attention_dropout
243
+
244
+ if (self.head_dim * self.num_heads) != self.hidden_size:
245
+ raise ValueError(
246
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
247
+ f" and `num_heads`: {self.num_heads})."
248
+ )
249
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
250
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
251
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
252
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
253
+
254
+ def forward(
255
+ self,
256
+ hidden_states: torch.Tensor,
257
+ attention_mask: Optional[torch.Tensor] = None,
258
+ position_ids: Optional[torch.LongTensor] = None,
259
+ past_key_value: Optional[Cache] = None,
260
+ output_attentions: bool = False,
261
+ use_cache: bool = False,
262
+ cache_position: Optional[torch.LongTensor] = None,
263
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
264
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
265
+ bsz, q_len, _ = hidden_states.size()
266
+
267
+ query_states = self.q_proj(hidden_states)
268
+ key_states = self.k_proj(hidden_states)
269
+ value_states = self.v_proj(hidden_states)
270
+
271
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
272
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
273
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
274
+
275
+ if position_embeddings is None:
276
+ logger.warning_once(
277
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
278
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
279
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
280
+ "removed and `position_embeddings` will be mandatory."
281
+ )
282
+ cos, sin = self.rotary_emb(value_states, position_ids)
283
+ else:
284
+ cos, sin = position_embeddings
285
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
286
+
287
+ if past_key_value is not None:
288
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
289
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
290
+
291
+ # repeat k/v heads if n_kv_heads < n_heads
292
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
293
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
294
+
295
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
296
+ if attention_mask is not None: # no matter the length, we just slice it
297
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
298
+ attn_weights = attn_weights + causal_mask
299
+
300
+ # upcast attention to fp32
301
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
302
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
303
+ attn_output = torch.matmul(attn_weights, value_states)
304
+
305
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
306
+ raise ValueError(
307
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
308
+ f" {attn_output.size()}"
309
+ )
310
+
311
+ attn_output = attn_output.transpose(1, 2).contiguous()
312
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
313
+
314
+ attn_output = self.o_proj(attn_output)
315
+
316
+ if not output_attentions:
317
+ attn_weights = None
318
+
319
+ return attn_output, attn_weights, past_key_value
320
+
321
+
322
+ class Qwen2FlashAttention2(Qwen2Attention):
323
+ """
324
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
325
+ as the weights of the module stays untouched. The only required change would be on the forward pass
326
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
327
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
328
+ config.max_window_layers layers.
329
+ """
330
+
331
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
332
+ def __init__(self, *args, **kwargs):
333
+ super().__init__(*args, **kwargs)
334
+
335
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
336
+ # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
337
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
338
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
339
+
340
+ def forward(
341
+ self,
342
+ hidden_states: torch.Tensor,
343
+ attention_mask: Optional[torch.Tensor] = None,
344
+ position_ids: Optional[torch.LongTensor] = None,
345
+ past_key_value: Optional[Cache] = None,
346
+ output_attentions: bool = False,
347
+ use_cache: bool = False,
348
+ cache_position: Optional[torch.LongTensor] = None,
349
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
350
+ ):
351
+ bsz, q_len, _ = hidden_states.size()
352
+
353
+ query_states = self.q_proj(hidden_states)
354
+ key_states = self.k_proj(hidden_states)
355
+ value_states = self.v_proj(hidden_states)
356
+
357
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
358
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
359
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
360
+
361
+ if position_embeddings is None:
362
+ logger.warning_once(
363
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
364
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
365
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
366
+ "removed and `position_embeddings` will be mandatory."
367
+ )
368
+ cos, sin = self.rotary_emb(value_states, position_ids)
369
+ else:
370
+ cos, sin = position_embeddings
371
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
372
+
373
+ if past_key_value is not None:
374
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
375
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
376
+
377
+ # repeat k/v heads if n_kv_heads < n_heads
378
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
379
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
380
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
381
+
382
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
383
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
384
+ # cast them back in float16 just to be sure everything works as expected.
385
+ input_dtype = query_states.dtype
386
+ if input_dtype == torch.float32:
387
+ if torch.is_autocast_enabled():
388
+ target_dtype = torch.get_autocast_gpu_dtype()
389
+ # Handle the case where the model is quantized
390
+ elif hasattr(self.config, "_pre_quantization_dtype"):
391
+ target_dtype = self.config._pre_quantization_dtype
392
+ else:
393
+ target_dtype = self.q_proj.weight.dtype
394
+
395
+ logger.warning_once(
396
+ f"The input hidden states seem to be silently cast to float32; this might be related to"
397
+ f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
398
+ f" {target_dtype}."
399
+ )
400
+
401
+ query_states = query_states.to(target_dtype)
402
+ key_states = key_states.to(target_dtype)
403
+ value_states = value_states.to(target_dtype)
404
+
405
+ # Reshape to the expected shape for Flash Attention
406
+ query_states = query_states.transpose(1, 2)
407
+ key_states = key_states.transpose(1, 2)
408
+ value_states = value_states.transpose(1, 2)
409
+
410
+ if (
411
+ self.config.use_sliding_window
412
+ and getattr(self.config, "sliding_window", None) is not None
413
+ and self.layer_idx >= self.config.max_window_layers
414
+ ):
415
+ sliding_window = self.config.sliding_window
416
+ else:
417
+ sliding_window = None
418
+
419
+ attn_output = _flash_attention_forward(
420
+ query_states,
421
+ key_states,
422
+ value_states,
423
+ attention_mask,
424
+ q_len,
425
+ position_ids=position_ids,
426
+ dropout=dropout_rate,
427
+ sliding_window=sliding_window,
428
+ is_causal=self.is_causal,
429
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
430
+ )
431
+
432
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
433
+ attn_output = self.o_proj(attn_output)
434
+
435
+ if not output_attentions:
436
+ attn_weights = None
437
+
438
+ return attn_output, attn_weights, past_key_value
439
+
440
+
441
+ QWEN2_ATTENTION_CLASSES = {
442
+ "eager": Qwen2Attention,
443
+ "flash_attention_2": Qwen2FlashAttention2,
444
+ }
445
+
446
+
447
+ class Qwen2DecoderLayer(nn.Module):
448
+ def __init__(self, config: Qwen2Config, layer_idx: int):
449
+ super().__init__()
450
+ self.hidden_size = config.hidden_size
451
+
452
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
453
+ logger.warning_once(
454
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
455
+ "unexpected results may be encountered."
456
+ )
457
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
458
+
459
+ self.mlp = Qwen2MLP(config)
460
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
461
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
462
+
463
+ def forward(
464
+ self,
465
+ hidden_states: torch.Tensor,
466
+ attention_mask: Optional[torch.Tensor] = None,
467
+ position_ids: Optional[torch.LongTensor] = None,
468
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
469
+ output_attentions: Optional[bool] = False,
470
+ use_cache: Optional[bool] = False,
471
+ cache_position: Optional[torch.LongTensor] = None,
472
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
473
+ **kwargs,
474
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
475
+ """
476
+ Args:
477
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
478
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
479
+ `(batch, sequence_length)` where padding elements are indicated by 0.
480
+ output_attentions (`bool`, *optional*):
481
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
482
+ returned tensors for more detail.
483
+ use_cache (`bool`, *optional*):
484
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
485
+ (see `past_key_values`).
486
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
487
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
488
+ Indices depicting the position of the input sequence tokens in the sequence.
489
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
490
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
491
+ with `head_dim` being the embedding dimension of each attention head.
492
+ kwargs (`dict`, *optional*):
493
+ Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
494
+ into the model
495
+ """
496
+
497
+ residual = hidden_states
498
+
499
+ hidden_states = self.input_layernorm(hidden_states)
500
+
501
+ # Self Attention
502
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
503
+ hidden_states=hidden_states,
504
+ attention_mask=attention_mask,
505
+ position_ids=position_ids,
506
+ past_key_value=past_key_value,
507
+ output_attentions=output_attentions,
508
+ use_cache=use_cache,
509
+ cache_position=cache_position,
510
+ position_embeddings=position_embeddings,
511
+ )
512
+ hidden_states = residual + hidden_states
513
+
514
+ # Fully Connected
515
+ residual = hidden_states
516
+ hidden_states = self.post_attention_layernorm(hidden_states)
517
+ hidden_states = self.mlp(hidden_states)
518
+ hidden_states = residual + hidden_states
519
+
520
+ outputs = (hidden_states,)
521
+
522
+ if output_attentions:
523
+ outputs += (self_attn_weights,)
524
+
525
+ if use_cache:
526
+ outputs += (present_key_value,)
527
+
528
+ return outputs
529
+
530
+
531
+ QWEN2_START_DOCSTRING = r"""
532
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
533
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
534
+ etc.)
535
+
536
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
537
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
538
+ and behavior.
539
+
540
+ Parameters:
541
+ config ([`Qwen2Config`]):
542
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
543
+ load the weights associated with the model, only the configuration. Check out the
544
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
545
+ """
546
+
547
+
548
+ @add_start_docstrings(
549
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
550
+ QWEN2_START_DOCSTRING,
551
+ )
552
+ class Qwen2PreTrainedModel(PreTrainedModel):
553
+ config_class = Qwen2Config
554
+ base_model_prefix = "model"
555
+ supports_gradient_checkpointing = True
556
+ _no_split_modules = ["Qwen2DecoderLayer"]
557
+ _skip_keys_device_placement = "past_key_values"
558
+ _supports_flash_attn_2 = True
559
+ _supports_cache_class = True
560
+ _supports_quantized_cache = True
561
+ _supports_static_cache = True
562
+
563
+ def _init_weights(self, module):
564
+ std = self.config.initializer_range
565
+ if isinstance(module, nn.Linear):
566
+ module.weight.data.normal_(mean=0.0, std=std)
567
+ if module.bias is not None:
568
+ module.bias.data.zero_()
569
+ elif isinstance(module, nn.Embedding):
570
+ module.weight.data.normal_(mean=0.0, std=std)
571
+ if module.padding_idx is not None:
572
+ module.weight.data[module.padding_idx].zero_()
573
+
574
+
575
+ QWEN2_INPUTS_DOCSTRING = r"""
576
+ Args:
577
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
578
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
579
+ it.
580
+
581
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
582
+ [`PreTrainedTokenizer.__call__`] for details.
583
+
584
+ [What are input IDs?](../glossary#input-ids)
585
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
586
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
587
+
588
+ - 1 for tokens that are **not masked**,
589
+ - 0 for tokens that are **masked**.
590
+
591
+ [What are attention masks?](../glossary#attention-mask)
592
+
593
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
594
+ [`PreTrainedTokenizer.__call__`] for details.
595
+
596
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
597
+ `past_key_values`).
598
+
599
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
600
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
601
+ information on the default strategy.
602
+
603
+ - 1 indicates the head is **not masked**,
604
+ - 0 indicates the head is **masked**.
605
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
606
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
607
+ config.n_positions - 1]`.
608
+
609
+ [What are position IDs?](../glossary#position-ids)
610
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
611
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
612
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
613
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
614
+
615
+ Two formats are allowed:
616
+ - a [`~cache_utils.Cache`] instance, see our
617
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
618
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
619
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
620
+ cache format.
621
+
622
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
623
+ legacy cache format will be returned.
624
+
625
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
626
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
627
+ of shape `(batch_size, sequence_length)`.
628
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
629
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
630
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
631
+ model's internal embedding lookup matrix.
632
+ use_cache (`bool`, *optional*):
633
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
634
+ `past_key_values`).
635
+ output_attentions (`bool`, *optional*):
636
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
637
+ tensors for more detail.
638
+ output_hidden_states (`bool`, *optional*):
639
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
640
+ more detail.
641
+ return_dict (`bool`, *optional*):
642
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
643
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
644
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
645
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
646
+ the complete sequence length.
647
+ """
648
+
649
+
650
+ @add_start_docstrings(
651
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
652
+ QWEN2_START_DOCSTRING,
653
+ )
654
+ class Qwen2Model(Qwen2PreTrainedModel):
655
+ """
656
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
657
+
658
+ Args:
659
+ config: Qwen2Config
660
+ """
661
+
662
+ def __init__(self, config: Qwen2Config):
663
+ super().__init__(config)
664
+ self.padding_idx = config.pad_token_id
665
+ self.vocab_size = config.vocab_size
666
+
667
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
668
+ self.layers = nn.ModuleList(
669
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
670
+ )
671
+ self._attn_implementation = config._attn_implementation
672
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
673
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
674
+
675
+ self.gradient_checkpointing = False
676
+ # Initialize weights and apply final processing
677
+ self.post_init()
678
+
679
+ def get_input_embeddings(self):
680
+ return self.embed_tokens
681
+
682
+ def set_input_embeddings(self, value):
683
+ self.embed_tokens = value
684
+
685
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
686
+ def forward(
687
+ self,
688
+ input_ids: torch.LongTensor = None,
689
+ attention_mask: Optional[torch.Tensor] = None,
690
+ position_ids: Optional[torch.LongTensor] = None,
691
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
692
+ inputs_embeds: Optional[torch.FloatTensor] = None,
693
+ use_cache: Optional[bool] = None,
694
+ output_attentions: Optional[bool] = None,
695
+ output_hidden_states: Optional[bool] = None,
696
+ return_dict: Optional[bool] = None,
697
+ cache_position: Optional[torch.LongTensor] = None,
698
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
699
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
700
+ output_hidden_states = (
701
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
702
+ )
703
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
704
+
705
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
706
+
707
+ if (input_ids is None) ^ (inputs_embeds is not None):
708
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
709
+
710
+ if self.gradient_checkpointing and self.training:
711
+ if use_cache:
712
+ logger.warning_once(
713
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
714
+ )
715
+ use_cache = False
716
+
717
+ # kept for BC (non `Cache` `past_key_values` inputs)
718
+ return_legacy_cache = False
719
+ if use_cache and not isinstance(past_key_values, Cache):
720
+ return_legacy_cache = True
721
+ if past_key_values is None:
722
+ past_key_values = DynamicCache()
723
+ else:
724
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
725
+ logger.warning_once(
726
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
727
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
728
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
729
+ )
730
+
731
+ if inputs_embeds is None:
732
+ inputs_embeds = self.embed_tokens(input_ids)
733
+
734
+ if cache_position is None:
735
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
736
+ cache_position = torch.arange(
737
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
738
+ )
739
+ if position_ids is None:
740
+ position_ids = cache_position.unsqueeze(0)
741
+
742
+ if attention_mask is not None and 0.0 in attention_mask:
743
+ causal_mask = attention_mask
744
+ else:
745
+ causal_mask = None
746
+
747
+ hidden_states = inputs_embeds
748
+ # create position embeddings to be shared across the decoder layers
749
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
750
+
751
+ # decoder layers
752
+ all_hidden_states = () if output_hidden_states else None
753
+ all_self_attns = () if output_attentions else None
754
+ next_decoder_cache = None
755
+
756
+ for decoder_layer in self.layers:
757
+ if output_hidden_states:
758
+ all_hidden_states += (hidden_states,)
759
+
760
+ if self.gradient_checkpointing and self.training:
761
+ layer_outputs = self._gradient_checkpointing_func(
762
+ decoder_layer.__call__,
763
+ hidden_states,
764
+ causal_mask,
765
+ position_ids,
766
+ past_key_values,
767
+ output_attentions,
768
+ use_cache,
769
+ cache_position,
770
+ position_embeddings,
771
+ )
772
+ else:
773
+ layer_outputs = decoder_layer(
774
+ hidden_states,
775
+ attention_mask=causal_mask,
776
+ position_ids=position_ids,
777
+ past_key_value=past_key_values,
778
+ output_attentions=output_attentions,
779
+ use_cache=use_cache,
780
+ cache_position=cache_position,
781
+ position_embeddings=position_embeddings,
782
+ )
783
+
784
+ hidden_states = layer_outputs[0]
785
+
786
+ if use_cache:
787
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
788
+
789
+ if output_attentions:
790
+ all_self_attns += (layer_outputs[1],)
791
+
792
+ hidden_states = self.norm(hidden_states)
793
+
794
+ # add hidden states from the last decoder layer
795
+ if output_hidden_states:
796
+ all_hidden_states += (hidden_states,)
797
+
798
+ next_cache = next_decoder_cache if use_cache else None
799
+ if return_legacy_cache:
800
+ next_cache = next_cache.to_legacy_cache()
801
+
802
+ if not return_dict:
803
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
804
+ return BaseModelOutputWithPast(
805
+ last_hidden_state=hidden_states,
806
+ past_key_values=next_cache,
807
+ hidden_states=all_hidden_states,
808
+ attentions=all_self_attns,
809
+ )
810
+
811
+
812
+ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
813
+ _tied_weights_keys = ["lm_head.weight"]
814
+
815
+ def __init__(self, config):
816
+ super().__init__(config)
817
+ self.model = Qwen2Model(config)
818
+ self.vocab_size = config.vocab_size
819
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
820
+
821
+ # Initialize weights and apply final processing
822
+ self.post_init()
823
+
824
+ def get_input_embeddings(self):
825
+ return self.model.embed_tokens
826
+
827
+ def set_input_embeddings(self, value):
828
+ self.model.embed_tokens = value
829
+
830
+ def get_output_embeddings(self):
831
+ return self.lm_head
832
+
833
+ def set_output_embeddings(self, new_embeddings):
834
+ self.lm_head = new_embeddings
835
+
836
+ def set_decoder(self, decoder):
837
+ self.model = decoder
838
+
839
+ def get_decoder(self):
840
+ return self.model
841
+
842
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
843
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
844
+ def forward(
845
+ self,
846
+ input_ids: torch.LongTensor = None,
847
+ attention_mask: Optional[torch.Tensor] = None,
848
+ position_ids: Optional[torch.LongTensor] = None,
849
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
850
+ inputs_embeds: Optional[torch.FloatTensor] = None,
851
+ labels: Optional[torch.LongTensor] = None,
852
+ use_cache: Optional[bool] = None,
853
+ output_attentions: Optional[bool] = None,
854
+ output_hidden_states: Optional[bool] = None,
855
+ return_dict: Optional[bool] = None,
856
+ cache_position: Optional[torch.LongTensor] = None,
857
+ num_logits_to_keep: int = 0,
858
+ **loss_kwargs,
859
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
860
+ r"""
861
+ Args:
862
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
863
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
864
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
865
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
866
+
867
+ num_logits_to_keep (`int`, *optional*):
868
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
869
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
870
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
871
+
872
+ Returns:
873
+
874
+ Example:
875
+
876
+ ```python
877
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
878
+
879
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
880
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
881
+
882
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
883
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
884
+
885
+ >>> # Generate
886
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
887
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
888
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
889
+ ```"""
890
+
891
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
892
+ output_hidden_states = (
893
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
894
+ )
895
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
896
+
897
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
898
+ outputs = self.model(
899
+ input_ids=input_ids,
900
+ attention_mask=attention_mask,
901
+ position_ids=position_ids,
902
+ past_key_values=past_key_values,
903
+ inputs_embeds=inputs_embeds,
904
+ use_cache=use_cache,
905
+ output_attentions=output_attentions,
906
+ output_hidden_states=output_hidden_states,
907
+ return_dict=return_dict,
908
+ cache_position=cache_position,
909
+ )
910
+
911
+ hidden_states = outputs[0]
912
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
913
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
914
+
915
+ loss = None
916
+ if labels is not None:
917
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
918
+
919
+ if not return_dict:
920
+ output = (logits,) + outputs[1:]
921
+ return (loss,) + output if loss is not None else output
922
+
923
+ return CausalLMOutputWithPast(
924
+ loss=loss,
925
+ logits=logits,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ )
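
A minimal decoding sketch of the pattern the causal LM above implements: a `Cache` carries past keys/values so later steps only feed the new token, and `num_logits_to_keep` lets `lm_head` project only the positions that are needed. The snippet below is a sketch, not part of this repository; it assumes the upstream `transformers` Qwen2 classes (which the vendored file mirrors) and a recent `transformers` version that ships `DynamicCache`.

```python
# Minimal sketch: two decoding steps with a DynamicCache on a tiny random-weight model.
import torch
from transformers import Qwen2Config, Qwen2ForCausalLM
from transformers.cache_utils import DynamicCache

config = Qwen2Config(
    vocab_size=1024, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
)
model = Qwen2ForCausalLM(config).eval()  # random weights, only to show the call pattern

cache = DynamicCache()  # preferred over the deprecated tuple-of-tuples cache format
prompt = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    out = model(input_ids=prompt, past_key_values=cache, use_cache=True)
    next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    # Second step: feed only the new token; the prompt's keys/values come from the cache.
    out = model(input_ids=next_id, past_key_values=out.past_key_values, use_cache=True)

print(out.logits.shape)  # torch.Size([1, 1, 1024])
```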
modeling/qwen2/tokenization_qwen2.py ADDED
@@ -0,0 +1,328 @@
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tokenization classes for Qwen2."""
5
+
6
+ import json
7
+ import os
8
+ import unicodedata
9
+ from functools import lru_cache
10
+ from typing import Optional, Tuple
11
+
12
+ import regex as re
13
+
14
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
15
+ from transformers.utils import logging
16
+
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+ VOCAB_FILES_NAMES = {
21
+ "vocab_file": "vocab.json",
22
+ "merges_file": "merges.txt",
23
+ }
24
+
25
+
26
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
27
+
28
+ PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
29
+
30
+
31
+ @lru_cache()
32
+ # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
33
+ def bytes_to_unicode():
34
+ """
35
+ Returns a list of utf-8 bytes and a mapping to unicode strings. We specifically avoid mapping to whitespace/control
36
+ characters the bpe code barfs on.
37
+
38
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
39
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
40
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
41
+ tables between utf-8 bytes and unicode strings.
42
+ """
43
+ bs = (
44
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
45
+ )
46
+ cs = bs[:]
47
+ n = 0
48
+ for b in range(2**8):
49
+ if b not in bs:
50
+ bs.append(b)
51
+ cs.append(2**8 + n)
52
+ n += 1
53
+ cs = [chr(n) for n in cs]
54
+ return dict(zip(bs, cs))
55
+
56
+
57
+ # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
58
+ def get_pairs(word):
59
+ """
60
+ Return set of symbol pairs in a word.
61
+
62
+ Word is represented as tuple of symbols (symbols being variable-length strings).
63
+ """
64
+ pairs = set()
65
+ prev_char = word[0]
66
+ for char in word[1:]:
67
+ pairs.add((prev_char, char))
68
+ prev_char = char
69
+ return pairs
70
+
71
+
72
+ class Qwen2Tokenizer(PreTrainedTokenizer):
73
+ """
74
+ Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
75
+
76
+ As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens, so a word will
77
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
78
+
79
+ ```python
80
+ >>> from transformers import Qwen2Tokenizer
81
+
82
+ >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
83
+ >>> tokenizer("Hello world")["input_ids"]
84
+ [9707, 1879]
85
+
86
+ >>> tokenizer(" Hello world")["input_ids"]
87
+ [21927, 1879]
88
+ ```
89
+ This is expected.
90
+
91
+ You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
92
+
93
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
94
+ this superclass for more information regarding those methods.
95
+
96
+ Args:
97
+ vocab_file (`str`):
98
+ Path to the vocabulary file.
99
+ merges_file (`str`):
100
+ Path to the merges file.
101
+ errors (`str`, *optional*, defaults to `"replace"`):
102
+ Paradigm to follow when decoding bytes to UTF-8. See
103
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
104
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
105
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
106
+ token instead.
107
+ bos_token (`str`, *optional*):
108
+ The beginning of sequence token. Not applicable for this tokenizer.
109
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
110
+ The end of sequence token.
111
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
112
+ The token used for padding, for example when batching sequences of different lengths.
113
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
114
+ Whether or not the model should cleanup the spaces that were added when splitting the input text during the
115
+ tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
116
+ split_special_tokens (`bool`, *optional*, defaults to `False`):
117
+ Whether or not the special tokens should be split during the tokenization process. The default behavior is
118
+ to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") =
119
+ ['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
120
+ '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
121
+ """
122
+
123
+ vocab_files_names = VOCAB_FILES_NAMES
124
+ model_input_names = ["input_ids", "attention_mask"]
125
+
126
+ def __init__(
127
+ self,
128
+ vocab_file,
129
+ merges_file,
130
+ errors="replace",
131
+ unk_token="<|endoftext|>",
132
+ bos_token=None,
133
+ eos_token="<|endoftext|>",
134
+ pad_token="<|endoftext|>",
135
+ clean_up_tokenization_spaces=False,
136
+ split_special_tokens=False,
137
+ **kwargs,
138
+ ):
139
+ # Qwen vocab does not contain control tokens; added tokens need to be special
140
+ bos_token = (
141
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
142
+ if isinstance(bos_token, str)
143
+ else bos_token
144
+ )
145
+ eos_token = (
146
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
147
+ if isinstance(eos_token, str)
148
+ else eos_token
149
+ )
150
+ unk_token = (
151
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
152
+ if isinstance(unk_token, str)
153
+ else unk_token
154
+ )
155
+ pad_token = (
156
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
157
+ if isinstance(pad_token, str)
158
+ else pad_token
159
+ )
160
+
161
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
162
+ self.encoder = json.load(vocab_handle)
163
+ self.decoder = {v: k for k, v in self.encoder.items()}
164
+ self.errors = errors # how to handle errors in decoding
165
+ self.byte_encoder = bytes_to_unicode()
166
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
167
+ bpe_merges = []
168
+ with open(merges_file, encoding="utf-8") as merges_handle:
169
+ for i, line in enumerate(merges_handle):
170
+ line = line.strip()
171
+ if (i == 0 and line.startswith("#version:")) or not line:
172
+ continue
173
+ bpe_merges.append(tuple(line.split()))
174
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
175
+ # NOTE: the cache can grow without bound and will get really large for long running processes
176
+ # (esp. for texts of language that do not use space between word, e.g. Chinese); technically
177
+ # not a memory leak but appears as one.
178
+ # GPT2Tokenizer has the same problem, so let's be consistent.
179
+ self.cache = {}
180
+
181
+ self.pat = re.compile(PRETOKENIZE_REGEX)
182
+
183
+ if kwargs.get("add_prefix_space", False):
184
+ logger.warning_once(
185
+ f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
186
+ )
187
+
188
+ super().__init__(
189
+ errors=errors,
190
+ bos_token=bos_token,
191
+ eos_token=eos_token,
192
+ pad_token=pad_token,
193
+ unk_token=unk_token,
194
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
195
+ split_special_tokens=split_special_tokens,
196
+ **kwargs,
197
+ )
198
+
199
+ @property
200
+ def vocab_size(self) -> int:
201
+ return len(self.encoder)
202
+
203
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
204
+ def get_vocab(self):
205
+ return dict(self.encoder, **self.added_tokens_encoder)
206
+
207
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
208
+ def bpe(self, token):
209
+ if token in self.cache:
210
+ return self.cache[token]
211
+ word = tuple(token)
212
+ pairs = get_pairs(word)
213
+
214
+ if not pairs:
215
+ return token
216
+
217
+ while True:
218
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
219
+ if bigram not in self.bpe_ranks:
220
+ break
221
+ first, second = bigram
222
+ new_word = []
223
+ i = 0
224
+ while i < len(word):
225
+ try:
226
+ j = word.index(first, i)
227
+ except ValueError:
228
+ new_word.extend(word[i:])
229
+ break
230
+ else:
231
+ new_word.extend(word[i:j])
232
+ i = j
233
+
234
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
235
+ new_word.append(first + second)
236
+ i += 2
237
+ else:
238
+ new_word.append(word[i])
239
+ i += 1
240
+ new_word = tuple(new_word)
241
+ word = new_word
242
+ if len(word) == 1:
243
+ break
244
+ else:
245
+ pairs = get_pairs(word)
246
+ word = " ".join(word)
247
+ self.cache[token] = word
248
+ return word
249
+
250
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
251
+ def _tokenize(self, text):
252
+ """Tokenize a string."""
253
+ bpe_tokens = []
254
+ for token in re.findall(self.pat, text):
255
+ token = "".join(
256
+ self.byte_encoder[b] for b in token.encode("utf-8")
257
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
258
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
259
+ return bpe_tokens
260
+
261
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
262
+ def _convert_token_to_id(self, token):
263
+ """Converts a token (str) in an id using the vocab."""
264
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
265
+
266
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
267
+ def _convert_id_to_token(self, index):
268
+ """Converts an index (integer) in a token (str) using the vocab."""
269
+ return self.decoder.get(index)
270
+
271
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
272
+ def convert_tokens_to_string(self, tokens):
273
+ """Converts a sequence of tokens (string) in a single string."""
274
+ text = "".join(tokens)
275
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
276
+ return text
277
+
278
+ def decode(
279
+ self,
280
+ token_ids,
281
+ skip_special_tokens: bool = False,
282
+ clean_up_tokenization_spaces: Optional[bool] = False,
283
+ spaces_between_special_tokens: bool = False,
284
+ **kwargs,
285
+ ) -> str:
286
+ # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
287
+ # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
288
+ return super().decode(
289
+ token_ids,
290
+ skip_special_tokens=skip_special_tokens,
291
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
292
+ spaces_between_special_tokens=spaces_between_special_tokens,
293
+ **kwargs,
294
+ )
295
+
296
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
297
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
298
+ if not os.path.isdir(save_directory):
299
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
300
+ return
301
+ vocab_file = os.path.join(
302
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
303
+ )
304
+ merge_file = os.path.join(
305
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
306
+ )
307
+
308
+ with open(vocab_file, "w", encoding="utf-8") as f:
309
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
310
+
311
+ index = 0
312
+ with open(merge_file, "w", encoding="utf-8") as writer:
313
+ writer.write("#version: 0.2\n")
314
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
315
+ if index != token_index:
316
+ logger.warning(
317
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
318
+ " Please check that the tokenizer is not corrupted!"
319
+ )
320
+ index = token_index
321
+ writer.write(" ".join(bpe_tokens) + "\n")
322
+ index += 1
323
+
324
+ return vocab_file, merge_file
325
+
326
+ def prepare_for_tokenization(self, text, **kwargs):
327
+ text = unicodedata.normalize("NFC", text)
328
+ return (text, kwargs)
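
A self-contained illustration of the byte-level trick that `bytes_to_unicode` documents above: every possible byte value gets a printable stand-in character, so arbitrary UTF-8 text round-trips through the BPE vocabulary without `<unk>` tokens. The helper is reproduced below (same logic as in the file) purely so the sketch runs on its own.

```python
# Minimal sketch: reversible byte <-> unicode mapping used by the byte-level BPE.
def bytes_to_unicode():
    # Printable ranges that map to themselves...
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:  # ...and the remaining bytes get fresh code points above 255
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

text = "héllo 世界"
encoded = "".join(byte_encoder[b] for b in text.encode("utf-8"))
decoded = bytearray(byte_decoder[c] for c in encoded).decode("utf-8")
assert decoded == text
print(encoded)  # printable stand-ins for the raw UTF-8 bytes, space included
```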
modeling/qwen2/tokenization_qwen2_fast.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tokenization classes for Qwen2."""
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ from transformers.tokenization_utils import AddedToken
9
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
10
+ from transformers.utils import logging
11
+ from .tokenization_qwen2 import Qwen2Tokenizer
12
+
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+ VOCAB_FILES_NAMES = {
17
+ "vocab_file": "vocab.json",
18
+ "merges_file": "merges.txt",
19
+ "tokenizer_file": "tokenizer.json",
20
+ }
21
+
22
+
23
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
24
+
25
+
26
+ class Qwen2TokenizerFast(PreTrainedTokenizerFast):
27
+ """
28
+ Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
29
+ Byte-Pair-Encoding.
30
+
31
+ As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens, so a word will
32
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
33
+
34
+ ```python
35
+ >>> from transformers import Qwen2TokenizerFast
36
+
37
+ >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
38
+ >>> tokenizer("Hello world")["input_ids"]
39
+ [9707, 1879]
40
+
41
+ >>> tokenizer(" Hello world")["input_ids"]
42
+ [21927, 1879]
43
+ ```
44
+ This is expected.
45
+
46
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
47
+ refer to this superclass for more information regarding those methods.
48
+
49
+ Args:
50
+ vocab_file (`str`, *optional*):
51
+ Path to the vocabulary file.
52
+ merges_file (`str`, *optional*):
53
+ Path to the merges file.
54
+ tokenizer_file (`str`, *optional*):
55
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
56
+ contains everything needed to load the tokenizer.
57
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
58
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
59
+ token instead. Not applicable to this tokenizer.
60
+ bos_token (`str`, *optional*):
61
+ The beginning of sequence token. Not applicable for this tokenizer.
62
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
63
+ The end of sequence token.
64
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
65
+ The token used for padding, for example when batching sequences of different lengths.
66
+ """
67
+
68
+ vocab_files_names = VOCAB_FILES_NAMES
69
+ model_input_names = ["input_ids", "attention_mask"]
70
+ slow_tokenizer_class = Qwen2Tokenizer
71
+
72
+ def __init__(
73
+ self,
74
+ vocab_file=None,
75
+ merges_file=None,
76
+ tokenizer_file=None,
77
+ unk_token="<|endoftext|>",
78
+ bos_token=None,
79
+ eos_token="<|endoftext|>",
80
+ pad_token="<|endoftext|>",
81
+ **kwargs,
82
+ ):
83
+ # We need to at least pass vocab_file and merges_file to base class
84
+ # in case a slow tokenizer needs to be initialized; others can be
85
+ # configured through files.
86
+ # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
87
+
88
+ bos_token = (
89
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
90
+ if isinstance(bos_token, str)
91
+ else bos_token
92
+ )
93
+ eos_token = (
94
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
95
+ if isinstance(eos_token, str)
96
+ else eos_token
97
+ )
98
+ unk_token = (
99
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
100
+ if isinstance(unk_token, str)
101
+ else unk_token
102
+ )
103
+ pad_token = (
104
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
105
+ if isinstance(pad_token, str)
106
+ else pad_token
107
+ )
108
+
109
+ super().__init__(
110
+ vocab_file=vocab_file,
111
+ merges_file=merges_file,
112
+ tokenizer_file=tokenizer_file,
113
+ unk_token=unk_token,
114
+ bos_token=bos_token,
115
+ eos_token=eos_token,
116
+ pad_token=pad_token,
117
+ **kwargs,
118
+ )
119
+
120
+ # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
121
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
122
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
123
+ return tuple(files)
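
The fast tokenizer is intended as a drop-in replacement for the slow one above, producing identical ids. A minimal sketch of that equivalence, assuming the repository root is on `sys.path` so the `modeling.qwen2` modules resolve, network access, and that the hub id already used in the docstrings above ships both the slow files (`vocab.json`/`merges.txt`) and `tokenizer.json`.

```python
# Minimal sketch: slow and fast Qwen2 tokenizers should agree token-for-token.
from modeling.qwen2.tokenization_qwen2 import Qwen2Tokenizer
from modeling.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast

slow = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
fast = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")

text = " Hello world"  # the leading space changes the first token, as documented above
assert slow(text)["input_ids"] == fast(text)["input_ids"]
print(fast(text)["input_ids"])               # e.g. [21927, 1879]
print(fast.decode(fast(text)["input_ids"]))  # " Hello world"
```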
modeling/siglip/__init__.py ADDED
@@ -0,0 +1,98 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ from transformers.utils import (
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_sentencepiece_available,
10
+ is_torch_available,
11
+ is_vision_available,
12
+ )
13
+
14
+
15
+ _import_structure = {
16
+ "configuration_siglip": [
17
+ "SiglipConfig",
18
+ "SiglipTextConfig",
19
+ "SiglipVisionConfig",
20
+ ],
21
+ "processing_siglip": ["SiglipProcessor"],
22
+ }
23
+
24
+ try:
25
+ if not is_sentencepiece_available():
26
+ raise OptionalDependencyNotAvailable()
27
+ except OptionalDependencyNotAvailable:
28
+ pass
29
+ else:
30
+ _import_structure["tokenization_siglip"] = ["SiglipTokenizer"]
31
+
32
+
33
+ try:
34
+ if not is_vision_available():
35
+ raise OptionalDependencyNotAvailable()
36
+ except OptionalDependencyNotAvailable:
37
+ pass
38
+ else:
39
+ _import_structure["image_processing_siglip"] = ["SiglipImageProcessor"]
40
+
41
+ try:
42
+ if not is_torch_available():
43
+ raise OptionalDependencyNotAvailable()
44
+ except OptionalDependencyNotAvailable:
45
+ pass
46
+ else:
47
+ _import_structure["modeling_siglip"] = [
48
+ "SiglipModel",
49
+ "SiglipPreTrainedModel",
50
+ "SiglipTextModel",
51
+ "SiglipVisionModel",
52
+ "SiglipForImageClassification",
53
+ ]
54
+
55
+
56
+ if TYPE_CHECKING:
57
+ from .configuration_siglip import (
58
+ SiglipConfig,
59
+ SiglipTextConfig,
60
+ SiglipVisionConfig,
61
+ )
62
+ from .processing_siglip import SiglipProcessor
63
+
64
+ try:
65
+ if not is_sentencepiece_available():
66
+ raise OptionalDependencyNotAvailable()
67
+ except OptionalDependencyNotAvailable:
68
+ pass
69
+ else:
70
+ from .tokenization_siglip import SiglipTokenizer
71
+
72
+ try:
73
+ if not is_vision_available():
74
+ raise OptionalDependencyNotAvailable()
75
+ except OptionalDependencyNotAvailable:
76
+ pass
77
+ else:
78
+ from .image_processing_siglip import SiglipImageProcessor
79
+
80
+ try:
81
+ if not is_torch_available():
82
+ raise OptionalDependencyNotAvailable()
83
+ except OptionalDependencyNotAvailable:
84
+ pass
85
+ else:
86
+ from .modeling_siglip import (
87
+ SiglipForImageClassification,
88
+ SiglipModel,
89
+ SiglipPreTrainedModel,
90
+ SiglipTextModel,
91
+ SiglipVisionModel,
92
+ )
93
+
94
+
95
+ else:
96
+ import sys
97
+
98
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
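
`_LazyModule`, together with the `OptionalDependencyNotAvailable` guards, keeps `import modeling.siglip` cheap: heavy submodules are only imported when one of the registered names is first accessed. Below is a hand-rolled sketch of the same idea for a hypothetical package `__init__.py`, using PEP 562 module `__getattr__`; the names `configuration_foo`, `modeling_foo`, `FooConfig`, and `FooModel` are illustrative, not from this repository.

```python
# Minimal sketch: deferred submodule imports via PEP 562 (the idea behind _LazyModule).
# This would live in a package __init__.py.
import importlib

_LAZY_ATTRS = {
    "FooConfig": "configuration_foo",  # attribute name -> submodule that defines it
    "FooModel": "modeling_foo",
}

def __getattr__(name):
    if name in _LAZY_ATTRS:
        module = importlib.import_module(f".{_LAZY_ATTRS[name]}", __name__)
        return getattr(module, name)   # the submodule is imported only on first access
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

def __dir__():
    return sorted(list(globals()) + list(_LAZY_ATTRS))
```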
modeling/siglip/configuration_siglip.py ADDED
@@ -0,0 +1,287 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Siglip model configuration"""
5
+
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import logging
11
+
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+
16
+ class SiglipTextConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
19
+ Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
20
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
21
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
22
+
23
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
24
+ documentation from [`PretrainedConfig`] for more information.
25
+
26
+ Args:
27
+ vocab_size (`int`, *optional*, defaults to 32000):
28
+ Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
29
+ the `inputs_ids` passed when calling [`SiglipModel`].
30
+ hidden_size (`int`, *optional*, defaults to 768):
31
+ Dimensionality of the encoder layers and the pooler layer.
32
+ intermediate_size (`int`, *optional*, defaults to 3072):
33
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
34
+ num_hidden_layers (`int`, *optional*, defaults to 12):
35
+ Number of hidden layers in the Transformer encoder.
36
+ num_attention_heads (`int`, *optional*, defaults to 12):
37
+ Number of attention heads for each attention layer in the Transformer encoder.
38
+ max_position_embeddings (`int`, *optional*, defaults to 64):
39
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
40
+ just in case (e.g., 512 or 1024 or 2048).
41
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
42
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
43
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
44
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
45
+ The epsilon used by the layer normalization layers.
46
+ attention_dropout (`float`, *optional*, defaults to 0.0):
47
+ The dropout ratio for the attention probabilities.
48
+ pad_token_id (`int`, *optional*, defaults to 1):
49
+ The id of the padding token in the vocabulary.
50
+ bos_token_id (`int`, *optional*, defaults to 49406):
51
+ The id of the beginning-of-sequence token in the vocabulary.
52
+ eos_token_id (`int`, *optional*, defaults to 49407):
53
+ The id of the end-of-sequence token in the vocabulary.
54
+
55
+ Example:
56
+
57
+ ```python
58
+ >>> from transformers import SiglipTextConfig, SiglipTextModel
59
+
60
+ >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
61
+ >>> configuration = SiglipTextConfig()
62
+
63
+ >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
64
+ >>> model = SiglipTextModel(configuration)
65
+
66
+ >>> # Accessing the model configuration
67
+ >>> configuration = model.config
68
+ ```"""
69
+
70
+ model_type = "siglip_text_model"
71
+
72
+ def __init__(
73
+ self,
74
+ vocab_size=32000,
75
+ hidden_size=768,
76
+ intermediate_size=3072,
77
+ num_hidden_layers=12,
78
+ num_attention_heads=12,
79
+ max_position_embeddings=64,
80
+ hidden_act="gelu_pytorch_tanh",
81
+ layer_norm_eps=1e-6,
82
+ attention_dropout=0.0,
83
+ # This differs from `CLIPTokenizer`'s default and from openai/siglip
84
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
85
+ pad_token_id=1,
86
+ bos_token_id=49406,
87
+ eos_token_id=49407,
88
+ **kwargs,
89
+ ):
90
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
91
+
92
+ self.vocab_size = vocab_size
93
+ self.hidden_size = hidden_size
94
+ self.intermediate_size = intermediate_size
95
+ self.num_hidden_layers = num_hidden_layers
96
+ self.num_attention_heads = num_attention_heads
97
+ self.max_position_embeddings = max_position_embeddings
98
+ self.layer_norm_eps = layer_norm_eps
99
+ self.hidden_act = hidden_act
100
+ self.attention_dropout = attention_dropout
101
+
102
+ @classmethod
103
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
104
+ cls._set_token_in_kwargs(kwargs)
105
+
106
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
107
+
108
+ # get the text config dict if we are loading from SiglipConfig
109
+ if config_dict.get("model_type") == "siglip":
110
+ config_dict = config_dict["text_config"]
111
+
112
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
113
+ logger.warning(
114
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
115
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
116
+ )
117
+
118
+ return cls.from_dict(config_dict, **kwargs)
119
+
120
+
121
+ class SiglipVisionConfig(PretrainedConfig):
122
+ r"""
123
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
124
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
125
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
126
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
127
+
128
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
129
+ documentation from [`PretrainedConfig`] for more information.
130
+
131
+ Args:
132
+ hidden_size (`int`, *optional*, defaults to 768):
133
+ Dimensionality of the encoder layers and the pooler layer.
134
+ intermediate_size (`int`, *optional*, defaults to 3072):
135
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
136
+ num_hidden_layers (`int`, *optional*, defaults to 12):
137
+ Number of hidden layers in the Transformer encoder.
138
+ num_attention_heads (`int`, *optional*, defaults to 12):
139
+ Number of attention heads for each attention layer in the Transformer encoder.
140
+ num_channels (`int`, *optional*, defaults to 3):
141
+ Number of channels in the input images.
142
+ image_size (`int`, *optional*, defaults to 224):
143
+ The size (resolution) of each image.
144
+ patch_size (`int`, *optional*, defaults to 16):
145
+ The size (resolution) of each patch.
146
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
147
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
148
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
149
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
150
+ The epsilon used by the layer normalization layers.
151
+ attention_dropout (`float`, *optional*, defaults to 0.0):
152
+ The dropout ratio for the attention probabilities.
153
+
154
+ Example:
155
+
156
+ ```python
157
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
158
+
159
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
160
+ >>> configuration = SiglipVisionConfig()
161
+
162
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
163
+ >>> model = SiglipVisionModel(configuration)
164
+
165
+ >>> # Accessing the model configuration
166
+ >>> configuration = model.config
167
+ ```"""
168
+
169
+ model_type = "siglip_vision_model"
170
+
171
+ def __init__(
172
+ self,
173
+ hidden_size=768,
174
+ intermediate_size=3072,
175
+ num_hidden_layers=12,
176
+ num_attention_heads=12,
177
+ num_channels=3,
178
+ image_size=224,
179
+ patch_size=16,
180
+ hidden_act="gelu_pytorch_tanh",
181
+ layer_norm_eps=1e-6,
182
+ attention_dropout=0.0,
183
+ **kwargs,
184
+ ):
185
+ super().__init__(**kwargs)
186
+
187
+ self.hidden_size = hidden_size
188
+ self.intermediate_size = intermediate_size
189
+ self.num_hidden_layers = num_hidden_layers
190
+ self.num_attention_heads = num_attention_heads
191
+ self.num_channels = num_channels
192
+ self.patch_size = patch_size
193
+ self.image_size = image_size
194
+ self.attention_dropout = attention_dropout
195
+ self.layer_norm_eps = layer_norm_eps
196
+ self.hidden_act = hidden_act
197
+
198
+ @classmethod
199
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
200
+ cls._set_token_in_kwargs(kwargs)
201
+
202
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
203
+
204
+ # get the vision config dict if we are loading from SiglipConfig
205
+ if config_dict.get("model_type") == "siglip":
206
+ config_dict = config_dict["vision_config"]
207
+
208
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
209
+ logger.warning(
210
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
211
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
212
+ )
213
+
214
+ return cls.from_dict(config_dict, **kwargs)
215
+
216
+
217
+ class SiglipConfig(PretrainedConfig):
218
+ r"""
219
+ [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
220
+ instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
221
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
222
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
223
+
224
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
225
+ documentation from [`PretrainedConfig`] for more information.
226
+
227
+ Args:
228
+ text_config (`dict`, *optional*):
229
+ Dictionary of configuration options used to initialize [`SiglipTextConfig`].
230
+ vision_config (`dict`, *optional*):
231
+ Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
232
+ kwargs (*optional*):
233
+ Dictionary of keyword arguments.
234
+
235
+ Example:
236
+
237
+ ```python
238
+ >>> from transformers import SiglipConfig, SiglipModel
239
+
240
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
241
+ >>> configuration = SiglipConfig()
242
+
243
+ >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
244
+ >>> model = SiglipModel(configuration)
245
+
246
+ >>> # Accessing the model configuration
247
+ >>> configuration = model.config
248
+
249
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
250
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
251
+
252
+ >>> # Initializing a SiglipText and SiglipVision configuration
253
+ >>> config_text = SiglipTextConfig()
254
+ >>> config_vision = SiglipVisionConfig()
255
+
256
+ >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
257
+ ```"""
258
+
259
+ model_type = "siglip"
260
+
261
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
262
+ super().__init__(**kwargs)
263
+
264
+ if text_config is None:
265
+ text_config = {}
266
+ logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
267
+
268
+ if vision_config is None:
269
+ vision_config = {}
270
+ logger.info("`vision_config` is `None`. Initializing the `SiglipVisionConfig` with default values.")
271
+
272
+ self.text_config = SiglipTextConfig(**text_config)
273
+ self.vision_config = SiglipVisionConfig(**vision_config)
274
+
275
+ self.initializer_factor = 1.0
276
+
277
+ @classmethod
278
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
279
+ r"""
280
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
281
+ model configuration.
282
+
283
+ Returns:
284
+ [`SiglipConfig`]: An instance of a configuration object
285
+ """
286
+
287
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
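
The three config classes compose: `from_text_vision_configs` serializes the sub-configs and rebuilds them inside a joint `SiglipConfig`, which is how a single checkpoint carries both halves. A minimal sketch using the classes defined above, assuming the repository root is on `sys.path` so `modeling.siglip` resolves.

```python
# Minimal sketch: composing the vendored Siglip config classes.
from modeling.siglip.configuration_siglip import (
    SiglipConfig,
    SiglipTextConfig,
    SiglipVisionConfig,
)

text_cfg = SiglipTextConfig(hidden_size=256, num_hidden_layers=4, num_attention_heads=4)
vision_cfg = SiglipVisionConfig(
    hidden_size=256, num_hidden_layers=4, num_attention_heads=4,
    image_size=224, patch_size=16,
)

config = SiglipConfig.from_text_vision_configs(text_cfg, vision_cfg)
print(config.model_type)                # "siglip"
print(config.vision_config.patch_size)  # 16
print(config.text_config.hidden_size)   # 256
```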
modeling/siglip/convert_siglip_to_hf.py ADDED
@@ -0,0 +1,401 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Convert SigLIP checkpoints from the original repository.
5
+
6
+ URL: https://github.com/google-research/big_vision/tree/main
7
+ """
8
+
9
+ import argparse
10
+ import collections
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import requests
15
+ import torch
16
+ from huggingface_hub import hf_hub_download
17
+ from numpy import load
18
+ from PIL import Image
19
+
20
+ from transformers import SiglipConfig, SiglipImageProcessor, SiglipModel, SiglipProcessor, SiglipTokenizer
21
+ from transformers.utils import logging
22
+
23
+
24
+ logging.set_verbosity_info()
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ model_name_to_checkpoint = {
29
+ # base checkpoints
30
+ "siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz",
31
+ "siglip-base-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_256_60500360.npz",
32
+ "siglip-base-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_384_68578854.npz",
33
+ "siglip-base-patch16-512": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_512_68580893.npz",
34
+ # large checkpoints
35
+ "siglip-large-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_256_60552751.npz",
36
+ "siglip-large-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_384_63634585.npz",
37
+ # multilingual checkpoint
38
+ "siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz",
39
+ # so400m checkpoints
40
+ "siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz",
41
+ }
42
+
43
+ model_name_to_image_size = {
44
+ "siglip-base-patch16-224": 224,
45
+ "siglip-base-patch16-256": 256,
46
+ "siglip-base-patch16-384": 384,
47
+ "siglip-base-patch16-512": 512,
48
+ "siglip-large-patch16-256": 256,
49
+ "siglip-large-patch16-384": 384,
50
+ "siglip-base-patch16-256-i18n": 256,
51
+ "siglip-so400m-patch14-384": 384,
52
+ }
53
+
54
+
55
+ def get_siglip_config(model_name):
56
+ config = SiglipConfig()
57
+
58
+ vocab_size = 250000 if "i18n" in model_name else 32000
59
+ image_size = model_name_to_image_size[model_name]
60
+ patch_size = 16 if "patch16" in model_name else 14
61
+
62
+ # size of the architecture
63
+ config.vision_config.image_size = image_size
64
+ config.vision_config.patch_size = patch_size
65
+ config.text_config.vocab_size = vocab_size
66
+
67
+ if "base" in model_name:
68
+ pass
69
+ elif "large" in model_name:
70
+ config.text_config.hidden_size = 1024
71
+ config.text_config.intermediate_size = 4096
72
+ config.text_config.num_hidden_layers = 24
73
+ config.text_config.num_attention_heads = 16
74
+ config.vision_config.hidden_size = 1024
75
+ config.vision_config.intermediate_size = 4096
76
+ config.vision_config.num_hidden_layers = 24
77
+ config.vision_config.num_attention_heads = 16
78
+ elif "so400m" in model_name:
79
+ config.text_config.hidden_size = 1152
80
+ config.text_config.intermediate_size = 4304
81
+ config.text_config.num_hidden_layers = 27
82
+ config.text_config.num_attention_heads = 16
83
+ config.vision_config.hidden_size = 1152
84
+ config.vision_config.intermediate_size = 4304
85
+ config.vision_config.num_hidden_layers = 27
86
+ config.vision_config.num_attention_heads = 16
87
+ else:
88
+ raise ValueError("Model not supported")
89
+
90
+ return config
91
+
92
+
93
+ def create_rename_keys(config):
94
+ rename_keys = []
95
+ # fmt: off
96
+
97
+ # vision encoder
98
+
99
+ rename_keys.append(("params/img/embedding/kernel", "vision_model.embeddings.patch_embedding.weight"))
100
+ rename_keys.append(("params/img/embedding/bias", "vision_model.embeddings.patch_embedding.bias"))
101
+ rename_keys.append(("params/img/pos_embedding", "vision_model.embeddings.position_embedding.weight"))
102
+
103
+ for i in range(config.vision_config.num_hidden_layers):
104
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/scale", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
105
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
106
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/scale", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
107
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
108
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
109
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
110
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
111
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
112
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"vision_model.encoder.layers.{i}.self_attn.k_proj.weight"))
113
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"vision_model.encoder.layers.{i}.self_attn.k_proj.bias"))
114
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"vision_model.encoder.layers.{i}.self_attn.v_proj.weight"))
115
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"vision_model.encoder.layers.{i}.self_attn.v_proj.bias"))
116
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"vision_model.encoder.layers.{i}.self_attn.q_proj.weight"))
117
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"vision_model.encoder.layers.{i}.self_attn.q_proj.bias"))
118
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"vision_model.encoder.layers.{i}.self_attn.out_proj.weight"))
119
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"vision_model.encoder.layers.{i}.self_attn.out_proj.bias"))
120
+
121
+ rename_keys.append(("params/img/Transformer/encoder_norm/scale", "vision_model.post_layernorm.weight"))
122
+ rename_keys.append(("params/img/Transformer/encoder_norm/bias", "vision_model.post_layernorm.bias"))
123
+
124
+ rename_keys.append(("params/img/MAPHead_0/probe", "vision_model.head.probe"))
125
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.head.layernorm.weight"))
126
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.head.layernorm.bias"))
127
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.head.mlp.fc1.weight"))
128
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.head.mlp.fc1.bias"))
129
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.head.mlp.fc2.weight"))
130
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.head.mlp.fc2.bias"))
131
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.head.attention.out_proj.weight"))
132
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.head.attention.out_proj.bias"))
133
+
134
+ # text encoder
135
+
136
+ rename_keys.append(("params/txt/Embed_0/embedding", "text_model.embeddings.token_embedding.weight"))
137
+ rename_keys.append(("params/txt/pos_embedding", "text_model.embeddings.position_embedding.weight"))
138
+
139
+ for i in range(config.text_config.num_hidden_layers):
140
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/scale", f"text_model.encoder.layers.{i}.layer_norm1.weight"))
141
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/bias", f"text_model.encoder.layers.{i}.layer_norm1.bias"))
142
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/scale", f"text_model.encoder.layers.{i}.layer_norm2.weight"))
143
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/bias", f"text_model.encoder.layers.{i}.layer_norm2.bias"))
144
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"text_model.encoder.layers.{i}.mlp.fc1.weight"))
145
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"text_model.encoder.layers.{i}.mlp.fc1.bias"))
146
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"text_model.encoder.layers.{i}.mlp.fc2.weight"))
147
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"text_model.encoder.layers.{i}.mlp.fc2.bias"))
148
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"text_model.encoder.layers.{i}.self_attn.k_proj.weight"))
149
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"text_model.encoder.layers.{i}.self_attn.k_proj.bias"))
150
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"text_model.encoder.layers.{i}.self_attn.v_proj.weight"))
151
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"text_model.encoder.layers.{i}.self_attn.v_proj.bias"))
152
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"text_model.encoder.layers.{i}.self_attn.q_proj.weight"))
153
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"text_model.encoder.layers.{i}.self_attn.q_proj.bias"))
154
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"text_model.encoder.layers.{i}.self_attn.out_proj.weight"))
155
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"text_model.encoder.layers.{i}.self_attn.out_proj.bias"))
156
+
157
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.final_layer_norm.weight"))
158
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.final_layer_norm.bias"))
159
+ rename_keys.append(("params/txt/head/kernel", "text_model.head.weight"))
160
+ rename_keys.append(("params/txt/head/bias", "text_model.head.bias"))
161
+
162
+ # learned temperature and bias
163
+ rename_keys.append(("params/t", "logit_scale"))
164
+ rename_keys.append(("params/b", "logit_bias"))
165
+
166
+ # fmt: on
167
+ return rename_keys
168
+
169
+
170
+ def rename_key(dct, old, new, config):
171
+ val = dct.pop(old)
172
+
173
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new:
174
+ val = val.reshape(-1, config.vision_config.hidden_size)
175
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new:
176
+ val = val.reshape(-1, config.text_config.hidden_size)
177
+
178
+ if "patch_embedding.weight" in new:
179
+ val = val.transpose(3, 2, 0, 1)
180
+ elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new:
181
+ val = val.T
182
+
183
+ if "position_embedding" in new and "vision" in new:
184
+ val = val.reshape(-1, config.vision_config.hidden_size)
185
+ if "position_embedding" in new and "text" in new:
186
+ val = val.reshape(-1, config.text_config.hidden_size)
187
+
188
+ if new.endswith("bias"):
189
+ val = val.reshape(-1)
190
+
191
+ dct[new] = torch.from_numpy(val)
192
+
193
+
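For readers less familiar with the Flax layout, here is a minimal standalone sketch (not part of the commit) of the convention the `val = val.T` branch above relies on: Flax Dense kernels are stored as (in_features, out_features), while `nn.Linear.weight` is (out_features, in_features).

import numpy as np
import torch
import torch.nn as nn

in_features, out_features = 4, 3
flax_kernel = np.random.randn(in_features, out_features).astype(np.float32)  # Flax Dense: (in, out)

linear = nn.Linear(in_features, out_features, bias=False)
with torch.no_grad():
    # nn.Linear stores its weight as (out, in), hence the transpose on load.
    linear.weight.copy_(torch.from_numpy(flax_kernel.T))

x = torch.randn(1, in_features)
assert torch.allclose(linear(x), x @ torch.from_numpy(flax_kernel))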
194
+ def read_in_q_k_v_head(state_dict, config):
195
+ # read in individual input projection layers
196
+ key_proj_weight = (
197
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/kernel")
198
+ .reshape(-1, config.vision_config.hidden_size)
199
+ .T
200
+ )
201
+ key_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/bias").reshape(-1)
202
+ value_proj_weight = (
203
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/kernel")
204
+ .reshape(-1, config.vision_config.hidden_size)
205
+ .T
206
+ )
207
+ value_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/bias").reshape(-1)
208
+ query_proj_weight = (
209
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/kernel")
210
+ .reshape(-1, config.vision_config.hidden_size)
211
+ .T
212
+ )
213
+ query_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/bias").reshape(-1)
214
+
215
+ # next, add them to the state dict as a single matrix + vector
216
+ state_dict["vision_model.head.attention.in_proj_weight"] = torch.from_numpy(
217
+ np.concatenate([query_proj_weight, key_proj_weight, value_proj_weight], axis=0)
218
+ )
219
+ state_dict["vision_model.head.attention.in_proj_bias"] = torch.from_numpy(
220
+ np.concatenate([query_proj_bias, key_proj_bias, value_proj_bias], axis=0)
221
+ )
222
+
223
+
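As a shape sanity check for `read_in_q_k_v_head` (an illustrative sketch, not part of the commit): `nn.MultiheadAttention` packs the query, key and value projections into a single `in_proj_weight` of shape (3 * embed_dim, embed_dim) and a single `in_proj_bias` of shape (3 * embed_dim,), which is why the three head kernels are concatenated along the first axis above.

import torch.nn as nn

embed_dim = 16
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4, batch_first=True)
print(mha.in_proj_weight.shape)  # torch.Size([48, 16])
print(mha.in_proj_bias.shape)    # torch.Size([48])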
224
+ # We will verify our results on an image of cute cats
225
+ def prepare_img():
226
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
227
+ image = Image.open(requests.get(url, stream=True).raw)
228
+ return image
229
+
230
+
231
+ def flatten_nested_dict(params, parent_key="", sep="/"):
232
+ items = []
233
+
234
+ for k, v in params.items():
235
+ new_key = parent_key + sep + k if parent_key else k
236
+
237
+ if isinstance(v, collections.abc.MutableMapping):
238
+ items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
239
+ else:
240
+ items.append((new_key, v))
241
+ return dict(items)
242
+
243
+
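A quick illustration of what `flatten_nested_dict` produces (sketch, not part of the commit), using the function defined above:

params = {"params": {"img": {"embedding": {"kernel": 1, "bias": 2}}}}
print(flatten_nested_dict(params))
# {'params/img/embedding/kernel': 1, 'params/img/embedding/bias': 2}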
244
+ @torch.no_grad()
245
+ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False):
246
+ """
247
+ Copy/paste/tweak model's weights to our SigLIP structure.
248
+ """
249
+
250
+ # define default SigLIP configuration
251
+ config = get_siglip_config(model_name)
252
+
253
+ # get checkpoint
254
+ checkpoint = model_name_to_checkpoint[model_name]
255
+
256
+ # get vocab file
257
+ if "i18n" in model_name:
258
+ vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model"
259
+ else:
260
+ vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model"
261
+
262
+ # load original state dict
263
+ data = load(checkpoint)
264
+ state_dict = flatten_nested_dict(data)
265
+
266
+ # remove and rename some keys
267
+ rename_keys = create_rename_keys(config)
268
+ for src, dest in rename_keys:
269
+ rename_key(state_dict, src, dest, config)
270
+
271
+ # qkv matrices of attention pooling head need special treatment
272
+ read_in_q_k_v_head(state_dict, config)
273
+
274
+ # load HuggingFace model
275
+ model = SiglipModel(config).eval()
276
+ model.load_state_dict(state_dict)
277
+
278
+ # create processor
279
+ # important: make tokenizer not return attention_mask since original one doesn't require it
280
+ image_size = config.vision_config.image_size
281
+ size = {"height": image_size, "width": image_size}
282
+ image_processor = SiglipImageProcessor(size=size)
283
+ tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"])
284
+ processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)
285
+
286
+ # verify on dummy images and texts
287
+ url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
288
+ image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB")
289
+ url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
290
+ image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB")
291
+ texts = ["an apple", "a picture of an apple"]
292
+
293
+ inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt", padding="max_length")
294
+
295
+ # verify input_ids against original ones
296
+ if image_size == 224:
297
+ filename = "siglip_pixel_values.pt"
298
+ elif image_size == 256:
299
+ filename = "siglip_pixel_values_256.pt"
300
+ elif image_size == 384:
301
+ filename = "siglip_pixel_values_384.pt"
302
+ elif image_size == 512:
303
+ filename = "siglip_pixel_values_512.pt"
304
+ else:
305
+ raise ValueError("Image size not supported")
306
+
307
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename=filename, repo_type="dataset")
308
+ original_pixel_values = torch.load(filepath)
309
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset")
310
+ original_input_ids = torch.load(filepath)
311
+
312
+ if "i18n" not in model_name:
313
+ assert inputs.input_ids.tolist() == original_input_ids.tolist()
314
+
315
+ print("Mean of original pixel values:", original_pixel_values.mean())
316
+ print("Mean of new pixel values:", inputs.pixel_values.mean())
317
+
318
+ # note: we're testing with original pixel values here since we don't have exact pixel values
319
+ with torch.no_grad():
320
+ outputs = model(input_ids=inputs.input_ids, pixel_values=original_pixel_values)
321
+
322
+ # with torch.no_grad():
323
+ # outputs = model(input_ids=inputs.input_ids, pixel_values=inputs.pixel_values)
324
+
325
+ print(outputs.logits_per_image[:3, :3])
326
+
327
+ probs = torch.sigmoid(outputs.logits_per_image) # these are the probabilities
328
+ print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
329
+ print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'")
330
+
331
+ if verify_logits:
332
+ if model_name == "siglip-base-patch16-224":
333
+ expected_slice = torch.tensor(
334
+ [[-2.9621, -2.1672], [-0.2713, 0.2910]],
335
+ )
336
+ elif model_name == "siglip-base-patch16-256":
337
+ expected_slice = torch.tensor(
338
+ [[-3.1146, -1.9894], [-0.7312, 0.6387]],
339
+ )
340
+ elif model_name == "siglip-base-patch16-384":
341
+ expected_slice = torch.tensor(
342
+ [[-2.8098, -2.1891], [-0.4242, 0.4102]],
343
+ )
344
+ elif model_name == "siglip-base-patch16-512":
345
+ expected_slice = torch.tensor(
346
+ [[-2.7899, -2.2668], [-0.4295, -0.0735]],
347
+ )
348
+ elif model_name == "siglip-large-patch16-256":
349
+ expected_slice = torch.tensor(
350
+ [[-1.5827, -0.5801], [-0.9153, 0.1363]],
351
+ )
352
+ elif model_name == "siglip-large-patch16-384":
353
+ expected_slice = torch.tensor(
354
+ [[-2.1523, -0.2899], [-0.2959, 0.7884]],
355
+ )
356
+ elif model_name == "siglip-so400m-patch14-384":
357
+ expected_slice = torch.tensor([[-1.2441, -0.6649], [-0.7060, 0.7374]])
358
+ elif model_name == "siglip-base-patch16-256-i18n":
359
+ expected_slice = torch.tensor(
360
+ [[-0.9064, 0.1073], [-0.0299, 0.5304]],
361
+ )
362
+
363
+ assert torch.allclose(outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4)
364
+ print("Looks ok!")
365
+
366
+ if pytorch_dump_folder_path is not None:
367
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
368
+ print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
369
+ model.save_pretrained(pytorch_dump_folder_path)
370
+ print(f"Saving processor to {pytorch_dump_folder_path}")
371
+ processor.save_pretrained(pytorch_dump_folder_path)
372
+
373
+ if push_to_hub:
374
+ model.push_to_hub(f"nielsr/{model_name}")
375
+ processor.push_to_hub(f"nielsr/{model_name}")
376
+
377
+
378
+ if __name__ == "__main__":
379
+ parser = argparse.ArgumentParser()
380
+ # Required parameters
381
+ parser.add_argument(
382
+ "--model_name",
383
+ default="siglip-base-patch16-224",
384
+ type=str,
385
+ choices=model_name_to_checkpoint.keys(),
386
+ help="Name of the model you'd like to convert.",
387
+ )
388
+ parser.add_argument(
389
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
390
+ )
391
+ parser.add_argument(
392
+ "--verify_logits",
393
+ action="store_false",
394
+ help="Whether to verify logits against the original implementation.",
395
+ )
396
+ parser.add_argument(
397
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
398
+ )
399
+
400
+ args = parser.parse_args()
401
+ convert_siglip_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub)
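For reference, the converter can also be called directly from Python instead of through argparse; the sketch below is hypothetical (the module name `convert_siglip_to_hf` is assumed, and the hard-coded checkpoint and sentencepiece vocab paths inside `convert_siglip_checkpoint` must exist locally). Note that because `--verify_logits` is declared with `action="store_false"`, passing that flag on the command line actually disables verification.

# Hypothetical module name; adjust to wherever this script is saved.
from convert_siglip_to_hf import convert_siglip_checkpoint

convert_siglip_checkpoint(
    model_name="siglip-base-patch16-224",
    pytorch_dump_folder_path="./siglip-base-patch16-224",
    verify_logits=True,   # compare logits against the expected slices above
    push_to_hub=False,
)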
modeling/siglip/image_processing_siglip.py ADDED
@@ -0,0 +1,230 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Image processor class for SigLIP."""
5
+
6
+ from typing import Dict, List, Optional, Union
7
+
8
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
9
+ from transformers.image_transforms import (
10
+ convert_to_rgb,
11
+ resize,
12
+ to_channel_dimension_format,
13
+ )
14
+ from transformers.image_utils import (
15
+ IMAGENET_STANDARD_MEAN,
16
+ IMAGENET_STANDARD_STD,
17
+ ChannelDimension,
18
+ ImageInput,
19
+ PILImageResampling,
20
+ infer_channel_dimension_format,
21
+ is_scaled_image,
22
+ make_list_of_images,
23
+ to_numpy_array,
24
+ valid_images,
25
+ validate_preprocess_arguments,
26
+ )
27
+ from transformers.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ if is_vision_available():
34
+ import PIL
35
+
36
+
37
+ class SiglipImageProcessor(BaseImageProcessor):
38
+ r"""
39
+ Constructs a SigLIP image processor.
40
+
41
+ Args:
42
+ do_resize (`bool`, *optional*, defaults to `True`):
43
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
44
+ `do_resize` in the `preprocess` method.
45
+ size (`Dict[str, int]` *optional*, defaults to `{"height": 224, "width": 224}`):
46
+ Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
47
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
48
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
49
+ do_rescale (`bool`, *optional*, defaults to `True`):
50
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
51
+ the `preprocess` method.
52
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
53
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
54
+ method.
55
+ do_normalize (`bool`, *optional*, defaults to `True`):
56
+ Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
57
+ `do_normalize` in the `preprocess` method.
58
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
59
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
60
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
61
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
62
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
63
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
64
+ Can be overridden by the `image_std` parameter in the `preprocess` method.
65
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
66
+ Whether to convert the image to RGB.
67
+ """
68
+
69
+ model_input_names = ["pixel_values"]
70
+
71
+ def __init__(
72
+ self,
73
+ do_resize: bool = True,
74
+ size: Dict[str, int] = None,
75
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
76
+ do_rescale: bool = True,
77
+ rescale_factor: Union[int, float] = 1 / 255,
78
+ do_normalize: bool = True,
79
+ image_mean: Optional[Union[float, List[float]]] = None,
80
+ image_std: Optional[Union[float, List[float]]] = None,
81
+ do_convert_rgb: bool = None,
82
+ **kwargs,
83
+ ) -> None:
84
+ super().__init__(**kwargs)
85
+ size = size if size is not None else {"height": 224, "width": 224}
86
+ image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
87
+ image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
88
+
89
+ self.do_resize = do_resize
90
+ self.size = size
91
+ self.resample = resample
92
+ self.do_rescale = do_rescale
93
+ self.rescale_factor = rescale_factor
94
+ self.do_normalize = do_normalize
95
+ self.image_mean = image_mean
96
+ self.image_std = image_std
97
+ self.do_convert_rgb = do_convert_rgb
98
+
99
+ @filter_out_non_signature_kwargs()
100
+ def preprocess(
101
+ self,
102
+ images: ImageInput,
103
+ do_resize: bool = None,
104
+ size: Dict[str, int] = None,
105
+ resample: PILImageResampling = None,
106
+ do_rescale: bool = None,
107
+ rescale_factor: float = None,
108
+ do_normalize: bool = None,
109
+ image_mean: Optional[Union[float, List[float]]] = None,
110
+ image_std: Optional[Union[float, List[float]]] = None,
111
+ return_tensors: Optional[Union[str, TensorType]] = None,
112
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
113
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
114
+ do_convert_rgb: bool = None,
115
+ ) -> PIL.Image.Image:
116
+ """
117
+ Preprocess an image or batch of images.
118
+
119
+ Args:
120
+ images (`ImageInput`):
121
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
122
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
123
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
124
+ Whether to resize the image.
125
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
126
+ Size of the image after resizing.
127
+ resample (`int`, *optional*, defaults to `self.resample`):
128
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
129
+ has an effect if `do_resize` is set to `True`.
130
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
131
+ Whether to rescale the image.
132
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
133
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
134
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
135
+ Whether to normalize the image.
136
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
137
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
138
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
139
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
140
+ `True`.
141
+ return_tensors (`str` or `TensorType`, *optional*):
142
+ The type of tensors to return. Can be one of:
143
+ - Unset: Return a list of `np.ndarray`.
144
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
145
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
146
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
147
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
148
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
149
+ The channel dimension format for the output image. Can be one of:
150
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
151
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
152
+ - Unset: Use the channel dimension format of the input image.
153
+ input_data_format (`ChannelDimension` or `str`, *optional*):
154
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
155
+ from the input image. Can be one of:
156
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
157
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
158
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
159
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
160
+ Whether to convert the image to RGB.
161
+ """
162
+ do_resize = do_resize if do_resize is not None else self.do_resize
163
+ size = size if size is not None else self.size
164
+ size = get_size_dict(size, param_name="size", default_to_square=False)
165
+ resample = resample if resample is not None else self.resample
166
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
167
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
168
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
169
+ image_mean = image_mean if image_mean is not None else self.image_mean
170
+ image_std = image_std if image_std is not None else self.image_std
171
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
172
+
173
+ images = make_list_of_images(images)
174
+
175
+ if not valid_images(images):
176
+ raise ValueError(
177
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
178
+ "torch.Tensor, tf.Tensor or jax.ndarray."
179
+ )
180
+ validate_preprocess_arguments(
181
+ do_rescale=do_rescale,
182
+ rescale_factor=rescale_factor,
183
+ do_normalize=do_normalize,
184
+ image_mean=image_mean,
185
+ image_std=image_std,
186
+ do_resize=do_resize,
187
+ size=size,
188
+ resample=resample,
189
+ )
190
+ # All transformations expect numpy arrays.
191
+ images = [to_numpy_array(image) for image in images]
192
+
193
+ if do_convert_rgb:
194
+ images = [convert_to_rgb(image) for image in images]
195
+
196
+ if is_scaled_image(images[0]) and do_rescale:
197
+ logger.warning_once(
198
+ "It looks like you are trying to rescale already rescaled images. If the input"
199
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
200
+ )
201
+
202
+ if input_data_format is None:
203
+ # We assume that all images have the same channel dimension format.
204
+ input_data_format = infer_channel_dimension_format(images[0])
205
+
206
+ if do_resize:
207
+ height, width = size["height"], size["width"]
208
+ images = [
209
+ resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format)
210
+ for image in images
211
+ ]
212
+
213
+ if do_rescale:
214
+ images = [
215
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
216
+ for image in images
217
+ ]
218
+
219
+ if do_normalize:
220
+ images = [
221
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
222
+ for image in images
223
+ ]
224
+
225
+ images = [
226
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
227
+ ]
228
+
229
+ data = {"pixel_values": images}
230
+ return BatchFeature(data=data, tensor_type=return_tensors)
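A minimal usage sketch for the processor defined above (not part of the commit; assumes PyTorch is installed so `return_tensors="pt"` works):

import numpy as np
from PIL import Image

processor = SiglipImageProcessor()  # defaults: 224x224, rescale by 1/255, normalize to roughly [-1, 1]
image = Image.fromarray((np.random.rand(480, 640, 3) * 255).astype(np.uint8))
batch = processor.preprocess(image, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])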
modeling/siglip/modeling_siglip.py ADDED
@@ -0,0 +1,1557 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """PyTorch Siglip model."""
5
+
6
+ import math
7
+ import warnings
8
+ from dataclasses import dataclass
9
+ from typing import Any, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.utils.checkpoint
14
+ from torch import nn
15
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
16
+ from torch.nn.init import _calculate_fan_in_and_fan_out
17
+
18
+ from transformers.activations import ACT2FN
19
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
20
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
21
+ from transformers.modeling_utils import PreTrainedModel
22
+ from transformers.utils import (
23
+ ModelOutput,
24
+ add_start_docstrings,
25
+ add_start_docstrings_to_model_forward,
26
+ is_flash_attn_2_available,
27
+ is_flash_attn_greater_or_equal_2_10,
28
+ logging,
29
+ replace_return_docstrings,
30
+ torch_int,
31
+ )
32
+ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
33
+
34
+
35
+ if is_flash_attn_2_available():
36
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ # General docstring
42
+ _CONFIG_FOR_DOC = "SiglipConfig"
43
+ _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
44
+
45
+
46
+ def _trunc_normal_(tensor, mean, std, a, b):
47
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
48
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
49
+ def norm_cdf(x):
50
+ # Computes standard normal cumulative distribution function
51
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
52
+
53
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
54
+ warnings.warn(
55
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
56
+ "The distribution of values may be incorrect.",
57
+ stacklevel=2,
58
+ )
59
+
60
+ # Values are generated by using a truncated uniform distribution and
61
+ # then using the inverse CDF for the normal distribution.
62
+ # Get upper and lower cdf values
63
+ l = norm_cdf((a - mean) / std)
64
+ u = norm_cdf((b - mean) / std)
65
+
66
+ # Uniformly fill tensor with values from [l, u], then translate to
67
+ # [2l-1, 2u-1].
68
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
69
+
70
+ # Use inverse cdf transform for normal distribution to get truncated
71
+ # standard normal
72
+ tensor.erfinv_()
73
+
74
+ # Transform to proper mean, std
75
+ tensor.mul_(std * math.sqrt(2.0))
76
+ tensor.add_(mean)
77
+
78
+ # Clamp to ensure it's in the proper range
79
+ tensor.clamp_(min=a, max=b)
80
+
81
+
82
+ def trunc_normal_tf_(
83
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
84
+ ) -> torch.Tensor:
85
+ """Fills the input Tensor with values drawn from a truncated
86
+ normal distribution. The values are effectively drawn from the
87
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
88
+ with values outside :math:`[a, b]` redrawn until they are within
89
+ the bounds. The method used for generating the random values works
90
+ best when :math:`a \\leq \text{mean} \\leq b`.
91
+
92
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
93
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
94
+ and the result is subsequently scaled and shifted by the mean and std args.
95
+
96
+ Args:
97
+ tensor: an n-dimensional `torch.Tensor`
98
+ mean: the mean of the normal distribution
99
+ std: the standard deviation of the normal distribution
100
+ a: the minimum cutoff value
101
+ b: the maximum cutoff value
102
+ """
103
+ with torch.no_grad():
104
+ _trunc_normal_(tensor, 0, 1.0, a, b)
105
+ tensor.mul_(std).add_(mean)
106
+
107
+
108
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
109
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
110
+ if mode == "fan_in":
111
+ denom = fan_in
112
+ elif mode == "fan_out":
113
+ denom = fan_out
114
+ elif mode == "fan_avg":
115
+ denom = (fan_in + fan_out) / 2
116
+
117
+ variance = scale / denom
118
+
119
+ if distribution == "truncated_normal":
120
+ # constant is stddev of standard normal truncated to (-2, 2)
121
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
122
+ elif distribution == "normal":
123
+ with torch.no_grad():
124
+ tensor.normal_(std=math.sqrt(variance))
125
+ elif distribution == "uniform":
126
+ bound = math.sqrt(3 * variance)
127
+ with torch.no_grad():
128
+ tensor.uniform_(-bound, bound)
129
+ else:
130
+ raise ValueError(f"invalid distribution {distribution}")
131
+
132
+
133
+ def lecun_normal_(tensor):
134
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
135
+
136
+
137
+ def default_flax_embed_init(tensor):
138
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
139
+
140
+
141
+ @dataclass
142
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
143
+ class SiglipVisionModelOutput(ModelOutput):
144
+ """
145
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
146
+
147
+ Args:
148
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
149
+ The image embeddings obtained by applying the projection layer to the pooler_output.
150
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
151
+ Sequence of hidden-states at the output of the last layer of the model.
152
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
153
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
154
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
155
+
156
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
157
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
158
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
159
+ sequence_length)`.
160
+
161
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
162
+ heads.
163
+ """
164
+
165
+ image_embeds: Optional[torch.FloatTensor] = None
166
+ last_hidden_state: torch.FloatTensor = None
167
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
168
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
169
+
170
+
171
+ @dataclass
172
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
173
+ class SiglipTextModelOutput(ModelOutput):
174
+ """
175
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
176
+
177
+ Args:
178
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
179
+ The text embeddings obtained by applying the projection layer to the pooler_output.
180
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
181
+ Sequence of hidden-states at the output of the last layer of the model.
182
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
183
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
184
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
185
+
186
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
187
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
188
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
189
+ sequence_length)`.
190
+
191
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
192
+ heads.
193
+ """
194
+
195
+ text_embeds: Optional[torch.FloatTensor] = None
196
+ last_hidden_state: torch.FloatTensor = None
197
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
198
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
199
+
200
+
201
+ @dataclass
202
+ # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
203
+ class SiglipOutput(ModelOutput):
204
+ """
205
+ Args:
206
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
207
+ Contrastive loss for image-text similarity.
208
+ logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
209
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
210
+ similarity scores.
211
+ logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
212
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
213
+ similarity scores.
214
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
215
+ The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
216
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
217
+ The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
218
+ text_model_output (`BaseModelOutputWithPooling`):
219
+ The output of the [`SiglipTextModel`].
220
+ vision_model_output (`BaseModelOutputWithPooling`):
221
+ The output of the [`SiglipVisionModel`].
222
+ """
223
+
224
+ loss: Optional[torch.FloatTensor] = None
225
+ logits_per_image: torch.FloatTensor = None
226
+ logits_per_text: torch.FloatTensor = None
227
+ text_embeds: torch.FloatTensor = None
228
+ image_embeds: torch.FloatTensor = None
229
+ text_model_output: BaseModelOutputWithPooling = None
230
+ vision_model_output: BaseModelOutputWithPooling = None
231
+
232
+ def to_tuple(self) -> Tuple[Any]:
233
+ return tuple(
234
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
235
+ for k in self.keys()
236
+ )
237
+
238
+
239
+ class SiglipVisionEmbeddings(nn.Module):
240
+ def __init__(self, config: SiglipVisionConfig):
241
+ super().__init__()
242
+ self.config = config
243
+ self.embed_dim = config.hidden_size
244
+ self.image_size = config.image_size
245
+ self.patch_size = config.patch_size
246
+
247
+ self.patch_embedding = nn.Conv2d(
248
+ in_channels=config.num_channels,
249
+ out_channels=self.embed_dim,
250
+ kernel_size=self.patch_size,
251
+ stride=self.patch_size,
252
+ padding="valid",
253
+ )
254
+
255
+ self.num_patches = (self.image_size // self.patch_size) ** 2
256
+ self.num_positions = self.num_patches
257
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
258
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
259
+
260
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
261
+ """
262
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
263
+ images. This method is also adapted to support torch.jit tracing and no class embeddings.
264
+
265
+ Adapted from:
266
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
267
+ - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
268
+ """
269
+
270
+ num_patches = embeddings.shape[1]
271
+ num_positions = self.position_embedding.weight.shape[0]
272
+
273
+ # always interpolate when tracing to ensure the exported model works for dynamic input shapes
274
+ if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
275
+ return self.position_embedding(self.position_ids)
276
+
277
+ patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
278
+
279
+ dim = embeddings.shape[-1]
280
+
281
+ new_height = height // self.patch_size
282
+ new_width = width // self.patch_size
283
+
284
+ sqrt_num_positions = torch_int(num_positions**0.5)
285
+ patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
286
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
287
+
288
+ patch_pos_embed = nn.functional.interpolate(
289
+ patch_pos_embed,
290
+ size=(new_height, new_width),
291
+ mode="bicubic",
292
+ align_corners=False,
293
+ )
294
+
295
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
296
+ return patch_pos_embed
297
+
298
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
299
+ _, _, height, width = pixel_values.shape
300
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
301
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
302
+
303
+ if interpolate_pos_encoding:
304
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
305
+ else:
306
+ embeddings = embeddings + self.position_embedding(self.position_ids)
307
+ return embeddings
308
+
309
+
310
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
311
+ class SiglipTextEmbeddings(nn.Module):
312
+ def __init__(self, config: SiglipTextConfig):
313
+ super().__init__()
314
+ embed_dim = config.hidden_size
315
+
316
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
317
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
318
+
319
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
320
+ self.register_buffer(
321
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
322
+ )
323
+
324
+ def forward(
325
+ self,
326
+ input_ids: Optional[torch.LongTensor] = None,
327
+ position_ids: Optional[torch.LongTensor] = None,
328
+ inputs_embeds: Optional[torch.FloatTensor] = None,
329
+ ) -> torch.Tensor:
330
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
331
+
332
+ if position_ids is None:
333
+ position_ids = self.position_ids[:, :seq_length]
334
+
335
+ if inputs_embeds is None:
336
+ inputs_embeds = self.token_embedding(input_ids)
337
+
338
+ position_embeddings = self.position_embedding(position_ids)
339
+ embeddings = inputs_embeds + position_embeddings
340
+
341
+ return embeddings
342
+
343
+
344
+ class SiglipAttention(nn.Module):
345
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
346
+
347
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
348
+ def __init__(self, config):
349
+ super().__init__()
350
+ self.config = config
351
+ self.embed_dim = config.hidden_size
352
+ self.num_heads = config.num_attention_heads
353
+ self.head_dim = self.embed_dim // self.num_heads
354
+ if self.head_dim * self.num_heads != self.embed_dim:
355
+ raise ValueError(
356
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
357
+ f" {self.num_heads})."
358
+ )
359
+ self.scale = self.head_dim**-0.5
360
+ self.dropout = config.attention_dropout
361
+
362
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
363
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
364
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
365
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
366
+
367
+ def forward(
368
+ self,
369
+ hidden_states: torch.Tensor,
370
+ attention_mask: Optional[torch.Tensor] = None,
371
+ output_attentions: Optional[bool] = False,
372
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
373
+ """Input shape: Batch x Time x Channel"""
374
+
375
+ batch_size, q_len, _ = hidden_states.size()
376
+
377
+ query_states = self.q_proj(hidden_states)
378
+ key_states = self.k_proj(hidden_states)
379
+ value_states = self.v_proj(hidden_states)
380
+
381
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
382
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
383
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
384
+
385
+ k_v_seq_len = key_states.shape[-2]
386
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
387
+
388
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
389
+ raise ValueError(
390
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
391
+ f" {attn_weights.size()}"
392
+ )
393
+
394
+ if attention_mask is not None:
395
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
396
+ raise ValueError(
397
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
398
+ )
399
+ attn_weights = attn_weights + attention_mask
400
+
401
+ # upcast attention to fp32
402
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
403
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
404
+ attn_output = torch.matmul(attn_weights, value_states)
405
+
406
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
407
+ raise ValueError(
408
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
409
+ f" {attn_output.size()}"
410
+ )
411
+
412
+ attn_output = attn_output.transpose(1, 2).contiguous()
413
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
414
+
415
+ attn_output = self.out_proj(attn_output)
416
+
417
+ return attn_output, attn_weights
418
+
419
+
420
+ class SiglipFlashAttention2(SiglipAttention):
421
+ """
422
+ SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stays
423
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
424
+ flash attention and deal with padding tokens in case the input contains any of them.
425
+ """
426
+
427
+ is_causal = False
428
+
429
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
430
+ def __init__(self, *args, **kwargs):
431
+ super().__init__(*args, **kwargs)
432
+
433
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
434
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
435
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
436
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
437
+
438
+ # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
439
+ def forward(
440
+ self,
441
+ hidden_states: torch.Tensor,
442
+ attention_mask: Optional[torch.LongTensor] = None,
443
+ output_attentions: bool = False,
444
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
445
+ output_attentions = False
446
+
447
+ batch_size, q_len, _ = hidden_states.size()
448
+
449
+ query_states = self.q_proj(hidden_states)
450
+ key_states = self.k_proj(hidden_states)
451
+ value_states = self.v_proj(hidden_states)
452
+
453
+ # Flash attention requires the input to have the shape
454
+ # batch_size x seq_length x head_dim x hidden_dim
455
+ # therefore we just need to keep the original shape
456
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
457
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
458
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
459
+
460
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
461
+ # to be able to avoid many of these transpose/reshape/view.
462
+ query_states = query_states.transpose(1, 2)
463
+ key_states = key_states.transpose(1, 2)
464
+ value_states = value_states.transpose(1, 2)
465
+
466
+ dropout_rate = self.dropout if self.training else 0.0
467
+
468
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
469
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
470
+ # cast them back in the correct dtype just to be sure everything works as expected.
471
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
472
+ # in fp32.
473
+
474
+ input_dtype = query_states.dtype
475
+ if input_dtype == torch.float32:
476
+ if torch.is_autocast_enabled():
477
+ target_dtype = torch.get_autocast_gpu_dtype()
478
+ # Handle the case where the model is quantized
479
+ elif hasattr(self.config, "_pre_quantization_dtype"):
480
+ target_dtype = self.config._pre_quantization_dtype
481
+ else:
482
+ target_dtype = self.q_proj.weight.dtype
483
+
484
+ logger.warning_once(
485
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
486
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
487
+ f" {target_dtype}."
488
+ )
489
+
490
+ query_states = query_states.to(target_dtype)
491
+ key_states = key_states.to(target_dtype)
492
+ value_states = value_states.to(target_dtype)
493
+
494
+ attn_output = _flash_attention_forward(
495
+ query_states,
496
+ key_states,
497
+ value_states,
498
+ attention_mask,
499
+ q_len,
500
+ dropout=dropout_rate,
501
+ is_causal=self.is_causal,
502
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
503
+ )
504
+
505
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
506
+ attn_output = self.out_proj(attn_output)
507
+
508
+ if not output_attentions:
509
+ attn_weights = None
510
+
511
+ return attn_output, attn_weights
512
+
513
+
514
+ class SiglipSdpaAttention(SiglipAttention):
515
+ """
516
+ Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
517
+ `SiglipAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
518
+ SDPA API.
519
+ """
520
+
521
+ is_causal = False
522
+
523
+ # Adapted from SiglipAttention.forward and transformers.models.llama.modeling_llama.LlamaSdpaAttention.forward
524
+ def forward(
525
+ self,
526
+ hidden_states: torch.Tensor,
527
+ attention_mask: Optional[torch.Tensor] = None,
528
+ output_attentions: Optional[bool] = False,
529
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
530
+ if output_attentions:
531
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
532
+ logger.warning_once(
533
+ "SiglipModel is using SiglipSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
534
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
535
+ )
536
+ return super().forward(
537
+ hidden_states=hidden_states,
538
+ attention_mask=attention_mask,
539
+ output_attentions=output_attentions,
540
+ )
541
+
542
+ batch_size, q_len, _ = hidden_states.size()
543
+
544
+ query_states = self.q_proj(hidden_states)
545
+ key_states = self.k_proj(hidden_states)
546
+ value_states = self.v_proj(hidden_states)
547
+
548
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
549
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
550
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
551
+
552
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
553
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
554
+ if query_states.device.type == "cuda" and attention_mask is not None:
555
+ query_states = query_states.contiguous()
556
+ key_states = key_states.contiguous()
557
+ value_states = value_states.contiguous()
558
+
559
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
560
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
561
+ is_causal = True if self.is_causal and q_len > 1 else False
562
+
563
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
564
+ query_states,
565
+ key_states,
566
+ value_states,
567
+ attn_mask=attention_mask,
568
+ dropout_p=self.dropout if self.training else 0.0,
569
+ is_causal=is_causal,
570
+ )
571
+
572
+ attn_output = attn_output.transpose(1, 2).contiguous()
573
+ attn_output = attn_output.view(batch_size, q_len, self.embed_dim)
574
+
575
+ attn_output = self.out_proj(attn_output)
576
+
577
+ return attn_output, None
578
+
579
+
580
+ SIGLIP_ATTENTION_CLASSES = {
581
+ "eager": SiglipAttention,
582
+ "flash_attention_2": SiglipFlashAttention2,
583
+ "sdpa": SiglipSdpaAttention,
584
+ }
585
+
586
+
587
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
588
+ class SiglipMLP(nn.Module):
589
+ def __init__(self, config):
590
+ super().__init__()
591
+ self.config = config
592
+ self.activation_fn = ACT2FN[config.hidden_act]
593
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
594
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
595
+
596
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
597
+ hidden_states = self.fc1(hidden_states)
598
+ hidden_states = self.activation_fn(hidden_states)
599
+ hidden_states = self.fc2(hidden_states)
600
+ return hidden_states
601
+
602
+
603
+ class SiglipEncoderLayer(nn.Module):
604
+ def __init__(self, config: SiglipConfig):
605
+ super().__init__()
606
+ self.embed_dim = config.hidden_size
607
+ self.self_attn = SIGLIP_ATTENTION_CLASSES[config._attn_implementation](config=config)
608
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
609
+ self.mlp = SiglipMLP(config)
610
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
611
+
612
+ # Ignore copy
613
+ def forward(
614
+ self,
615
+ hidden_states: torch.Tensor,
616
+ attention_mask: torch.Tensor,
617
+ output_attentions: Optional[bool] = False,
618
+ ) -> Tuple[torch.FloatTensor]:
619
+ """
620
+ Args:
621
+ hidden_states (`torch.FloatTensor`):
622
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
623
+ attention_mask (`torch.FloatTensor`):
624
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
625
+ output_attentions (`bool`, *optional*, defaults to `False`):
626
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
627
+ returned tensors for more detail.
628
+ """
629
+ residual = hidden_states
630
+
631
+ hidden_states = self.layer_norm1(hidden_states)
632
+ hidden_states, attn_weights = self.self_attn(
633
+ hidden_states=hidden_states,
634
+ attention_mask=attention_mask,
635
+ output_attentions=output_attentions,
636
+ )
637
+ hidden_states = residual + hidden_states
638
+
639
+ residual = hidden_states
640
+ hidden_states = self.layer_norm2(hidden_states)
641
+ hidden_states = self.mlp(hidden_states)
642
+ hidden_states = residual + hidden_states
643
+
644
+ outputs = (hidden_states,)
645
+
646
+ if output_attentions:
647
+ outputs += (attn_weights,)
648
+
649
+ return outputs
650
+
651
+
652
+ class SiglipPreTrainedModel(PreTrainedModel):
653
+ """
654
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
655
+ models.
656
+ """
657
+
658
+ config_class = SiglipConfig
659
+ base_model_prefix = "siglip"
660
+ supports_gradient_checkpointing = True
661
+
662
+ _no_split_modules = [
663
+ "SiglipTextEmbeddings",
664
+ "SiglipEncoderLayer",
665
+ "SiglipVisionEmbeddings",
666
+ "SiglipEncoderLayer",
667
+ "SiglipMultiheadAttentionPoolingHead",
668
+ ]
669
+ _supports_flash_attn_2 = True
670
+ _supports_sdpa = True
671
+
672
+ def _init_weights(self, module):
673
+ """Initialize the weights"""
674
+ if isinstance(module, SiglipVisionEmbeddings):
675
+ width = (
676
+ self.config.vision_config.hidden_size
677
+ if isinstance(self.config, SiglipConfig)
678
+ else self.config.hidden_size
679
+ )
680
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
681
+ elif isinstance(module, nn.Embedding):
682
+ default_flax_embed_init(module.weight)
683
+ elif isinstance(module, SiglipAttention):
684
+ nn.init.xavier_uniform_(module.q_proj.weight)
685
+ nn.init.xavier_uniform_(module.k_proj.weight)
686
+ nn.init.xavier_uniform_(module.v_proj.weight)
687
+ nn.init.xavier_uniform_(module.out_proj.weight)
688
+ nn.init.zeros_(module.q_proj.bias)
689
+ nn.init.zeros_(module.k_proj.bias)
690
+ nn.init.zeros_(module.v_proj.bias)
691
+ nn.init.zeros_(module.out_proj.bias)
692
+ elif isinstance(module, SiglipMLP):
693
+ nn.init.xavier_uniform_(module.fc1.weight)
694
+ nn.init.xavier_uniform_(module.fc2.weight)
695
+ nn.init.normal_(module.fc1.bias, std=1e-6)
696
+ nn.init.normal_(module.fc2.bias, std=1e-6)
697
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
698
+ nn.init.xavier_uniform_(module.probe.data)
699
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
700
+ nn.init.zeros_(module.attention.in_proj_bias.data)
701
+ elif isinstance(module, SiglipModel):
702
+ logit_scale_init = torch.log(torch.tensor(1.0))
703
+ module.logit_scale.data.fill_(logit_scale_init)
704
+ module.logit_bias.data.zero_()
705
+ elif isinstance(module, SiglipForImageClassification):
706
+ nn.init.normal_(
707
+ module.classifier.weight,
708
+ std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
709
+ )
710
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
711
+ lecun_normal_(module.weight)
712
+ if module.bias is not None:
713
+ nn.init.zeros_(module.bias)
714
+ elif isinstance(module, nn.LayerNorm):
715
+ module.bias.data.zero_()
716
+ module.weight.data.fill_(1.0)
717
+
718
+
719
+ SIGLIP_START_DOCSTRING = r"""
720
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
721
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
722
+ etc.)
723
+
724
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
725
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
726
+ and behavior.
727
+
728
+ Parameters:
729
+ config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
730
+ Initializing with a config file does not load the weights associated with the model, only the
731
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
732
+ """
733
+
734
+ SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
735
+ Args:
736
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
737
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
738
+ it.
739
+
740
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
741
+ [`PreTrainedTokenizer.__call__`] for details.
742
+
743
+ [What are input IDs?](../glossary#input-ids)
744
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
745
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
746
+
747
+ - 1 for tokens that are **not masked**,
748
+ - 0 for tokens that are **masked**.
749
+
750
+ [What are attention masks?](../glossary#attention-mask)
751
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
752
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
753
+ config.max_position_embeddings - 1]`.
754
+
755
+ [What are position IDs?](../glossary#position-ids)
756
+ output_attentions (`bool`, *optional*):
757
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
758
+ tensors for more detail.
759
+ output_hidden_states (`bool`, *optional*):
760
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
761
+ more detail.
762
+ return_dict (`bool`, *optional*):
763
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
764
+ """
765
+
766
+ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
767
+ Args:
768
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
769
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
770
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
771
+ output_attentions (`bool`, *optional*):
772
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
773
+ tensors for more detail.
774
+ output_hidden_states (`bool`, *optional*):
775
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
776
+ more detail.
777
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
778
+ Whether to interpolate the pre-trained position encodings.
779
+ return_dict (`bool`, *optional*):
780
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
781
+ """
782
+
783
+ SIGLIP_INPUTS_DOCSTRING = r"""
784
+ Args:
785
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
786
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
787
+ it.
788
+
789
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
790
+ [`PreTrainedTokenizer.__call__`] for details.
791
+
792
+ [What are input IDs?](../glossary#input-ids)
793
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
794
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
795
+
796
+ - 1 for tokens that are **not masked**,
797
+ - 0 for tokens that are **masked**.
798
+
799
+ [What are attention masks?](../glossary#attention-mask)
800
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
801
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
802
+ config.max_position_embeddings - 1]`.
803
+
804
+ [What are position IDs?](../glossary#position-ids)
805
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
806
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
807
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
808
+ return_loss (`bool`, *optional*):
809
+ Whether or not to return the contrastive loss.
810
+ output_attentions (`bool`, *optional*):
811
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
812
+ tensors for more detail.
813
+ output_hidden_states (`bool`, *optional*):
814
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
815
+ more detail.
816
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
817
+ Whether to interpolate the pre-trained position encodings.
818
+ return_dict (`bool`, *optional*):
819
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
820
+ """
821
+
822
+
823
+ # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip
824
+ class SiglipEncoder(nn.Module):
825
+ """
826
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
827
+ [`SiglipEncoderLayer`].
828
+
829
+ Args:
830
+ config: SiglipConfig
831
+ """
832
+
833
+ def __init__(self, config: SiglipConfig):
834
+ super().__init__()
835
+ self.config = config
836
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
837
+ self.gradient_checkpointing = False
838
+
839
+ # Ignore copy
840
+ def forward(
841
+ self,
842
+ inputs_embeds,
843
+ attention_mask: Optional[torch.Tensor] = None,
844
+ output_attentions: Optional[bool] = None,
845
+ output_hidden_states: Optional[bool] = None,
846
+ return_dict: Optional[bool] = None,
847
+ ) -> Union[Tuple, BaseModelOutput]:
848
+ r"""
849
+ Args:
850
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
851
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
852
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
853
+ than the model's internal embedding lookup matrix.
854
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
855
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
856
+
857
+ - 1 for tokens that are **not masked**,
858
+ - 0 for tokens that are **masked**.
859
+
860
+ [What are attention masks?](../glossary#attention-mask)
861
+ output_attentions (`bool`, *optional*):
862
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
863
+ returned tensors for more detail.
864
+ output_hidden_states (`bool`, *optional*):
865
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
866
+ for more detail.
867
+ return_dict (`bool`, *optional*):
868
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
869
+ """
870
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
871
+ output_hidden_states = (
872
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
873
+ )
874
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
875
+
876
+ encoder_states = () if output_hidden_states else None
877
+ all_attentions = () if output_attentions else None
878
+
879
+ hidden_states = inputs_embeds
880
+ for encoder_layer in self.layers:
881
+ if output_hidden_states:
882
+ encoder_states = encoder_states + (hidden_states,)
883
+ if self.gradient_checkpointing and self.training:
884
+ layer_outputs = self._gradient_checkpointing_func(
885
+ encoder_layer.__call__,
886
+ hidden_states,
887
+ attention_mask,
888
+ output_attentions,
889
+ )
890
+ else:
891
+ layer_outputs = encoder_layer(
892
+ hidden_states,
893
+ attention_mask,
894
+ output_attentions=output_attentions,
895
+ )
896
+
897
+ hidden_states = layer_outputs[0]
898
+
899
+ if output_attentions:
900
+ all_attentions = all_attentions + (layer_outputs[1],)
901
+
902
+ if output_hidden_states:
903
+ encoder_states = encoder_states + (hidden_states,)
904
+
905
+ if not return_dict:
906
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
907
+ return BaseModelOutput(
908
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
909
+ )
910
+
911
+
912
+ class SiglipTextTransformer(nn.Module):
913
+ def __init__(self, config: SiglipTextConfig):
914
+ super().__init__()
915
+ self.config = config
916
+ embed_dim = config.hidden_size
917
+ self.embeddings = SiglipTextEmbeddings(config)
918
+ self.encoder = SiglipEncoder(config)
919
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
920
+
921
+ self.head = nn.Linear(embed_dim, embed_dim)
922
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
923
+
924
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
925
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
926
+ def forward(
927
+ self,
928
+ input_ids: Optional[torch.Tensor] = None,
929
+ attention_mask: Optional[torch.Tensor] = None,
930
+ position_ids: Optional[torch.Tensor] = None,
931
+ output_attentions: Optional[bool] = None,
932
+ output_hidden_states: Optional[bool] = None,
933
+ return_dict: Optional[bool] = None,
934
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
935
+ r"""
936
+ Returns:
937
+
938
+ """
939
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
940
+ output_hidden_states = (
941
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
942
+ )
943
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
944
+
945
+ if input_ids is None:
946
+ raise ValueError("You have to specify input_ids")
947
+
948
+ input_shape = input_ids.size()
949
+ input_ids = input_ids.view(-1, input_shape[-1])
950
+
951
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
952
+
953
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
954
+ # expand attention_mask
955
+ if attention_mask is not None and not self._use_flash_attention_2:
956
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
957
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
958
+
959
+ encoder_outputs = self.encoder(
960
+ inputs_embeds=hidden_states,
961
+ attention_mask=attention_mask,
962
+ output_attentions=output_attentions,
963
+ output_hidden_states=output_hidden_states,
964
+ return_dict=return_dict,
965
+ )
966
+
967
+ last_hidden_state = encoder_outputs[0]
968
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
969
+
970
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
971
+ pooled_output = last_hidden_state[:, -1, :]
972
+ pooled_output = self.head(pooled_output)
973
+
974
+ if not return_dict:
975
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
976
+
977
+ return BaseModelOutputWithPooling(
978
+ last_hidden_state=last_hidden_state,
979
+ pooler_output=pooled_output,
980
+ hidden_states=encoder_outputs.hidden_states,
981
+ attentions=encoder_outputs.attentions,
982
+ )
983
+
984
+
985
+ @add_start_docstrings(
986
+ """The text model from SigLIP without any head or projection on top.""",
987
+ SIGLIP_START_DOCSTRING,
988
+ )
989
+ class SiglipTextModel(SiglipPreTrainedModel):
990
+ config_class = SiglipTextConfig
991
+
992
+ def __init__(self, config: SiglipTextConfig):
993
+ super().__init__(config)
994
+ self.text_model = SiglipTextTransformer(config)
995
+ # Initialize weights and apply final processing
996
+ self.post_init()
997
+
998
+ def get_input_embeddings(self) -> nn.Module:
999
+ return self.text_model.embeddings.token_embedding
1000
+
1001
+ def set_input_embeddings(self, value):
1002
+ self.text_model.embeddings.token_embedding = value
1003
+
1004
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1005
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
1006
+ def forward(
1007
+ self,
1008
+ input_ids: Optional[torch.Tensor] = None,
1009
+ attention_mask: Optional[torch.Tensor] = None,
1010
+ position_ids: Optional[torch.Tensor] = None,
1011
+ output_attentions: Optional[bool] = None,
1012
+ output_hidden_states: Optional[bool] = None,
1013
+ return_dict: Optional[bool] = None,
1014
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1015
+ r"""
1016
+ Returns:
1017
+
1018
+ Examples:
1019
+
1020
+ ```python
1021
+ >>> from transformers import AutoTokenizer, SiglipTextModel
1022
+
1023
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
1024
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1025
+
1026
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1027
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1028
+
1029
+ >>> outputs = model(**inputs)
1030
+ >>> last_hidden_state = outputs.last_hidden_state
1031
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
1032
+ ```"""
1033
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1034
+
1035
+ return self.text_model(
1036
+ input_ids=input_ids,
1037
+ attention_mask=attention_mask,
1038
+ position_ids=position_ids,
1039
+ output_attentions=output_attentions,
1040
+ output_hidden_states=output_hidden_states,
1041
+ return_dict=return_dict,
1042
+ )
1043
+
1044
+
1045
+ class SiglipVisionTransformer(nn.Module):
1046
+ def __init__(self, config: SiglipVisionConfig):
1047
+ super().__init__()
1048
+ self.config = config
1049
+ embed_dim = config.hidden_size
1050
+
1051
+ self.embeddings = SiglipVisionEmbeddings(config)
1052
+ self.encoder = SiglipEncoder(config)
1053
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
1054
+ self.use_head = True if not hasattr(config, "vision_use_head") else config.vision_use_head
1055
+ if self.use_head:
1056
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
1057
+
1058
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1059
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
1060
+ def forward(
1061
+ self,
1062
+ pixel_values,
1063
+ output_attentions: Optional[bool] = None,
1064
+ output_hidden_states: Optional[bool] = None,
1065
+ return_dict: Optional[bool] = None,
1066
+ interpolate_pos_encoding: Optional[bool] = False,
1067
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1068
+ r"""
1069
+ Returns:
1070
+
1071
+ """
1072
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1073
+ output_hidden_states = (
1074
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1075
+ )
1076
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1077
+
1078
+ hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
1079
+
1080
+ encoder_outputs = self.encoder(
1081
+ inputs_embeds=hidden_states,
1082
+ output_attentions=output_attentions,
1083
+ output_hidden_states=output_hidden_states,
1084
+ return_dict=return_dict,
1085
+ )
1086
+
1087
+ last_hidden_state = encoder_outputs[0]
1088
+ last_hidden_state = self.post_layernorm(last_hidden_state)
1089
+
1090
+ pooler_output = self.head(last_hidden_state) if self.use_head else None
1091
+ if not return_dict:
1092
+ return (last_hidden_state, pooler_output) + encoder_outputs[1:]
1093
+
1094
+ return BaseModelOutputWithPooling(
1095
+ last_hidden_state=last_hidden_state,
1096
+ pooler_output=pooler_output,
1097
+ hidden_states=encoder_outputs.hidden_states,
1098
+ attentions=encoder_outputs.attentions,
1099
+ )
1100
+
1101
+
1102
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
1103
+ """Multihead Attention Pooling."""
1104
+
1105
+ def __init__(self, config: SiglipVisionConfig):
1106
+ super().__init__()
1107
+
1108
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
1109
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
1110
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1111
+ self.mlp = SiglipMLP(config)
1112
+
1113
+ def forward(self, hidden_state):
1114
+ batch_size = hidden_state.shape[0]
1115
+ probe = self.probe.repeat(batch_size, 1, 1)
1116
+
1117
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
1118
+
1119
+ residual = hidden_state
1120
+ hidden_state = self.layernorm(hidden_state)
1121
+ hidden_state = residual + self.mlp(hidden_state)
1122
+
1123
+ return hidden_state[:, 0]
1124
+
1125
+
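`SiglipMultiheadAttentionPoolingHead` pools the patch sequence by letting a single learned probe token attend over all patch tokens and keeping the probe's output. A self-contained sketch of that mechanism with made-up dimensions (the real head additionally applies a LayerNorm + MLP residual before taking `[:, 0]`):

```python
# Minimal sketch of probe-based attention pooling, mirroring
# SiglipMultiheadAttentionPoolingHead.forward; dimensions are illustrative only.
import torch
import torch.nn as nn

hidden_size, num_heads, batch, seq_len = 64, 4, 2, 196

probe = nn.Parameter(torch.randn(1, 1, hidden_size))      # one learned query token
attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

patch_tokens = torch.randn(batch, seq_len, hidden_size)   # stand-in for encoder output
queries = probe.repeat(batch, 1, 1)                       # (batch, 1, hidden_size)

# The probe attends over all patch tokens; the result is one pooled token per image.
pooled, _ = attn(queries, patch_tokens, patch_tokens)
print(pooled[:, 0].shape)                                  # torch.Size([2, 64])
```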
1126
+ @add_start_docstrings(
1127
+ """The vision model from SigLIP without any head or projection on top.""",
1128
+ SIGLIP_START_DOCSTRING,
1129
+ )
1130
+ class SiglipVisionModel(SiglipPreTrainedModel):
1131
+ config_class = SiglipVisionConfig
1132
+ main_input_name = "pixel_values"
1133
+
1134
+ def __init__(self, config: SiglipVisionConfig):
1135
+ super().__init__(config)
1136
+
1137
+ self.vision_model = SiglipVisionTransformer(config)
1138
+
1139
+ # Initialize weights and apply final processing
1140
+ self.post_init()
1141
+
1142
+ def get_input_embeddings(self) -> nn.Module:
1143
+ return self.vision_model.embeddings.patch_embedding
1144
+
1145
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1146
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
1147
+ def forward(
1148
+ self,
1149
+ pixel_values,
1150
+ output_attentions: Optional[bool] = None,
1151
+ output_hidden_states: Optional[bool] = None,
1152
+ return_dict: Optional[bool] = None,
1153
+ interpolate_pos_encoding: bool = False,
1154
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1155
+ r"""
1156
+ Returns:
1157
+
1158
+ Examples:
1159
+
1160
+ ```python
1161
+ >>> from PIL import Image
1162
+ >>> import requests
1163
+ >>> from transformers import AutoProcessor, SiglipVisionModel
1164
+
1165
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
1166
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1167
+
1168
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1169
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1170
+
1171
+ >>> inputs = processor(images=image, return_tensors="pt")
1172
+
1173
+ >>> outputs = model(**inputs)
1174
+ >>> last_hidden_state = outputs.last_hidden_state
1175
+ >>> pooled_output = outputs.pooler_output # pooled features
1176
+ ```"""
1177
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1178
+
1179
+ return self.vision_model(
1180
+ pixel_values=pixel_values,
1181
+ output_attentions=output_attentions,
1182
+ output_hidden_states=output_hidden_states,
1183
+ return_dict=return_dict,
1184
+ interpolate_pos_encoding=interpolate_pos_encoding,
1185
+ )
1186
+
1187
+
1188
+ @add_start_docstrings(SIGLIP_START_DOCSTRING)
1189
+ class SiglipModel(SiglipPreTrainedModel):
1190
+ config_class = SiglipConfig
1191
+
1192
+ def __init__(self, config: SiglipConfig):
1193
+ super().__init__(config)
1194
+
1195
+ if not isinstance(config.text_config, SiglipTextConfig):
1196
+ raise TypeError(
1197
+ "config.text_config is expected to be of type SiglipTextConfig but is of type"
1198
+ f" {type(config.text_config)}."
1199
+ )
1200
+
1201
+ if not isinstance(config.vision_config, SiglipVisionConfig):
1202
+ raise TypeError(
1203
+ "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
1204
+ f" {type(config.vision_config)}."
1205
+ )
1206
+
1207
+ text_config = config.text_config
1208
+ vision_config = config.vision_config
1209
+
1210
+ # First, initialize the text and vision models with proper attention implementation
1211
+ text_model = SiglipTextModel._from_config(text_config)
1212
+ vision_model = SiglipVisionModel._from_config(vision_config)
1213
+
1214
+ # Second, get the text and vision submodules (for backward compatibility)
1215
+ self.text_model = text_model.text_model
1216
+ self.vision_model = vision_model.vision_model
1217
+
1218
+ self.logit_scale = nn.Parameter(torch.randn(1))
1219
+ self.logit_bias = nn.Parameter(torch.randn(1))
1220
+
1221
+ # Initialize weights and apply final processing
1222
+ self.post_init()
1223
+
1224
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1225
+ def get_text_features(
1226
+ self,
1227
+ input_ids: Optional[torch.Tensor] = None,
1228
+ attention_mask: Optional[torch.Tensor] = None,
1229
+ position_ids: Optional[torch.Tensor] = None,
1230
+ output_attentions: Optional[bool] = None,
1231
+ output_hidden_states: Optional[bool] = None,
1232
+ return_dict: Optional[bool] = None,
1233
+ ) -> torch.FloatTensor:
1234
+ r"""
1235
+ Returns:
1236
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
1237
+ applying the projection layer to the pooled output of [`SiglipTextModel`].
1238
+
1239
+ Examples:
1240
+
1241
+ ```python
1242
+ >>> from transformers import AutoTokenizer, AutoModel
1243
+ >>> import torch
1244
+
1245
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1246
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1247
+
1248
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1249
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1250
+ >>> with torch.no_grad():
1251
+ ... text_features = model.get_text_features(**inputs)
1252
+ ```"""
1253
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1254
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1255
+ output_hidden_states = (
1256
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1257
+ )
1258
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1259
+
1260
+ text_outputs = self.text_model(
1261
+ input_ids=input_ids,
1262
+ attention_mask=attention_mask,
1263
+ position_ids=position_ids,
1264
+ output_attentions=output_attentions,
1265
+ output_hidden_states=output_hidden_states,
1266
+ return_dict=return_dict,
1267
+ )
1268
+
1269
+ pooled_output = text_outputs[1]
1270
+
1271
+ return pooled_output
1272
+
1273
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1274
+ def get_image_features(
1275
+ self,
1276
+ pixel_values: Optional[torch.FloatTensor] = None,
1277
+ output_attentions: Optional[bool] = None,
1278
+ output_hidden_states: Optional[bool] = None,
1279
+ return_dict: Optional[bool] = None,
1280
+ interpolate_pos_encoding: bool = False,
1281
+ ) -> torch.FloatTensor:
1282
+ r"""
1283
+ Returns:
1284
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
1285
+ applying the projection layer to the pooled output of [`SiglipVisionModel`].
1286
+
1287
+ Examples:
1288
+
1289
+ ```python
1290
+ >>> from PIL import Image
1291
+ >>> import requests
1292
+ >>> from transformers import AutoProcessor, AutoModel
1293
+ >>> import torch
1294
+
1295
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1296
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1297
+
1298
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1299
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1300
+
1301
+ >>> inputs = processor(images=image, return_tensors="pt")
1302
+
1303
+ >>> with torch.no_grad():
1304
+ ... image_features = model.get_image_features(**inputs)
1305
+ ```"""
1306
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1307
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1308
+ output_hidden_states = (
1309
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1310
+ )
1311
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1312
+
1313
+ vision_outputs = self.vision_model(
1314
+ pixel_values=pixel_values,
1315
+ output_attentions=output_attentions,
1316
+ output_hidden_states=output_hidden_states,
1317
+ return_dict=return_dict,
1318
+ interpolate_pos_encoding=interpolate_pos_encoding,
1319
+ )
1320
+
1321
+ pooled_output = vision_outputs[1]
1322
+
1323
+ return pooled_output
1324
+
1325
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1326
+ @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
1327
+ def forward(
1328
+ self,
1329
+ input_ids: Optional[torch.LongTensor] = None,
1330
+ pixel_values: Optional[torch.FloatTensor] = None,
1331
+ attention_mask: Optional[torch.Tensor] = None,
1332
+ position_ids: Optional[torch.LongTensor] = None,
1333
+ return_loss: Optional[bool] = None,
1334
+ output_attentions: Optional[bool] = None,
1335
+ output_hidden_states: Optional[bool] = None,
1336
+ return_dict: Optional[bool] = None,
1337
+ interpolate_pos_encoding: bool = False,
1338
+ ) -> Union[Tuple, SiglipOutput]:
1339
+ r"""
1340
+ Returns:
1341
+
1342
+ Examples:
1343
+
1344
+ ```python
1345
+ >>> from PIL import Image
1346
+ >>> import requests
1347
+ >>> from transformers import AutoProcessor, AutoModel
1348
+ >>> import torch
1349
+
1350
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1351
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1352
+
1353
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1354
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1355
+
1356
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
1357
+ >>> # important: we pass `padding="max_length"` since the model was trained with this
1358
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
1359
+
1360
+ >>> with torch.no_grad():
1361
+ ... outputs = model(**inputs)
1362
+
1363
+ >>> logits_per_image = outputs.logits_per_image
1364
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
1365
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
1366
+ 31.9% that image 0 is 'a photo of 2 cats'
1367
+ ```"""
1368
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1369
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1370
+ output_hidden_states = (
1371
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1372
+ )
1373
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1374
+
1375
+ vision_outputs = self.vision_model(
1376
+ pixel_values=pixel_values,
1377
+ output_attentions=output_attentions,
1378
+ output_hidden_states=output_hidden_states,
1379
+ return_dict=return_dict,
1380
+ interpolate_pos_encoding=interpolate_pos_encoding,
1381
+ )
1382
+
1383
+ text_outputs = self.text_model(
1384
+ input_ids=input_ids,
1385
+ attention_mask=attention_mask,
1386
+ position_ids=position_ids,
1387
+ output_attentions=output_attentions,
1388
+ output_hidden_states=output_hidden_states,
1389
+ return_dict=return_dict,
1390
+ )
1391
+
1392
+ image_embeds = vision_outputs[1]
1393
+ text_embeds = text_outputs[1]
1394
+
1395
+ # normalized features
1396
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1397
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1398
+
1399
+ # cosine similarity as logits
1400
+ logits_per_text = (
1401
+ torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device)) * self.logit_scale.exp()
1402
+ + self.logit_bias
1403
+ )
1404
+ logits_per_image = logits_per_text.t()
1405
+
1406
+ loss = None
1407
+ if return_loss:
1408
+ # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
1409
+ eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
1410
+ m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
1411
+ loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
1412
+ nll = -torch.sum(loglik, dim=-1)
1413
+ loss = nll.mean()
1414
+
1415
+ if not return_dict:
1416
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1417
+ return ((loss,) + output) if loss is not None else output
1418
+
1419
+ return SiglipOutput(
1420
+ loss=loss,
1421
+ logits_per_image=logits_per_image,
1422
+ logits_per_text=logits_per_text,
1423
+ text_embeds=text_embeds,
1424
+ image_embeds=image_embeds,
1425
+ text_model_output=text_outputs,
1426
+ vision_model_output=vision_outputs,
1427
+ )
1428
+
1429
+
1430
+ @add_start_docstrings(
1431
+ """
1432
+ SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
1433
+ the patch tokens) e.g. for ImageNet.
1434
+ """,
1435
+ SIGLIP_START_DOCSTRING,
1436
+ )
1437
+ class SiglipForImageClassification(SiglipPreTrainedModel):
1438
+ main_input_name = "pixel_values"
1439
+
1440
+ def __init__(self, config: SiglipConfig) -> None:
1441
+ super().__init__(config)
1442
+
1443
+ self.num_labels = config.num_labels
1444
+
1445
+ # Create the vision model with proper attention
1446
+ # and take only vision_model submodule (for backward compatibility)
1447
+ vision_model = SiglipVisionModel._from_config(config.vision_config)
1448
+ self.vision_model = vision_model.vision_model
1449
+
1450
+ # Classifier head
1451
+ self.classifier = (
1452
+ nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
1453
+ )
1454
+
1455
+ # Initialize weights and apply final processing
1456
+ self.post_init()
1457
+
1458
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1459
+ @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
1460
+ def forward(
1461
+ self,
1462
+ pixel_values: Optional[torch.Tensor] = None,
1463
+ labels: Optional[torch.Tensor] = None,
1464
+ output_attentions: Optional[bool] = None,
1465
+ output_hidden_states: Optional[bool] = None,
1466
+ return_dict: Optional[bool] = None,
1467
+ interpolate_pos_encoding: bool = False,
1468
+ ) -> Union[tuple, ImageClassifierOutput]:
1469
+ r"""
1470
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1471
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1472
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1473
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1474
+
1475
+ Returns:
1476
+
1477
+ Examples:
1478
+
1479
+ ```python
1480
+ >>> from transformers import AutoImageProcessor, SiglipForImageClassification
1481
+ >>> import torch
1482
+ >>> from PIL import Image
1483
+ >>> import requests
1484
+
1485
+ >>> torch.manual_seed(3) # doctest: +IGNORE_RESULT
1486
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1487
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1488
+
1489
+ >>> # note: we are loading a `SiglipModel` from the hub here,
1490
+ >>> # so the head will be randomly initialized, hence the predictions will be random if seed is not set above.
1491
+ >>> image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
1492
+ >>> model = SiglipForImageClassification.from_pretrained("google/siglip-base-patch16-224")
1493
+
1494
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1495
+ >>> outputs = model(**inputs)
1496
+ >>> logits = outputs.logits
1497
+ >>> # model predicts one of the two classes
1498
+ >>> predicted_class_idx = logits.argmax(-1).item()
1499
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
1500
+ Predicted class: LABEL_1
1501
+ ```"""
1502
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1503
+ output_hidden_states = (
1504
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1505
+ )
1506
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1507
+
1508
+ outputs = self.vision_model(
1509
+ pixel_values,
1510
+ output_attentions=output_attentions,
1511
+ output_hidden_states=output_hidden_states,
1512
+ return_dict=return_dict,
1513
+ interpolate_pos_encoding=interpolate_pos_encoding,
1514
+ )
1515
+
1516
+ sequence_output = outputs[0]
1517
+
1518
+ # average pool the patch tokens
1519
+ sequence_output = torch.mean(sequence_output, dim=1)
1520
+ # apply classifier
1521
+ logits = self.classifier(sequence_output)
1522
+
1523
+ loss = None
1524
+ if labels is not None:
1525
+ # move labels to correct device to enable model parallelism
1526
+ labels = labels.to(logits.device)
1527
+ if self.config.problem_type is None:
1528
+ if self.num_labels == 1:
1529
+ self.config.problem_type = "regression"
1530
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1531
+ self.config.problem_type = "single_label_classification"
1532
+ else:
1533
+ self.config.problem_type = "multi_label_classification"
1534
+
1535
+ if self.config.problem_type == "regression":
1536
+ loss_fct = MSELoss()
1537
+ if self.num_labels == 1:
1538
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1539
+ else:
1540
+ loss = loss_fct(logits, labels)
1541
+ elif self.config.problem_type == "single_label_classification":
1542
+ loss_fct = CrossEntropyLoss()
1543
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1544
+ elif self.config.problem_type == "multi_label_classification":
1545
+ loss_fct = BCEWithLogitsLoss()
1546
+ loss = loss_fct(logits, labels)
1547
+
1548
+ if not return_dict:
1549
+ output = (logits,) + outputs[2:]
1550
+ return ((loss,) + output) if loss is not None else output
1551
+
1552
+ return ImageClassifierOutput(
1553
+ loss=loss,
1554
+ logits=logits,
1555
+ hidden_states=outputs.hidden_states,
1556
+ attentions=outputs.attentions,
1557
+ )
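The `return_loss` branch in `SiglipModel.forward` implements the pairwise sigmoid loss adapted from big_vision: every text-image pair in the batch is scored independently, with the diagonal (matching) pairs as positives and all other pairs as negatives. A standalone sketch of that computation on dummy embeddings, with scale/bias values that only roughly echo SigLIP's reported initialization:

```python
# Standalone sketch of the SigLIP pairwise sigmoid loss, mirroring the block in
# SiglipModel.forward above; embeddings, scale and bias are dummy values.
import torch
import torch.nn.functional as F

batch, dim = 4, 8
text_embeds = F.normalize(torch.randn(batch, dim), dim=-1)
image_embeds = F.normalize(torch.randn(batch, dim), dim=-1)
logit_scale, logit_bias = torch.tensor(10.0).log(), torch.tensor(-10.0)

logits_per_text = text_embeds @ image_embeds.t() * logit_scale.exp() + logit_bias

# +1 on the diagonal (matching pairs), -1 everywhere else.
labels = -torch.ones_like(logits_per_text) + 2 * torch.eye(batch)
loss = -F.logsigmoid(labels * logits_per_text).sum(dim=-1).mean()
print(loss)
```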
modeling/siglip/processing_siglip.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ Image/Text processor class for SigLIP.
6
+ """
7
+
8
+ from typing import List, Optional, Union
9
+
10
+ from transformers.feature_extraction_utils import BatchFeature
11
+ from transformers.image_utils import ImageInput
12
+ from transformers.processing_utils import ProcessorMixin
13
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
14
+ from transformers.utils import TensorType
15
+
16
+
17
+ class SiglipProcessor(ProcessorMixin):
18
+ r"""
19
+ Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.
20
+
21
+ [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
22
+ [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.
23
+
24
+ Args:
25
+ image_processor ([`SiglipImageProcessor`]):
26
+ The image processor is a required input.
27
+ tokenizer ([`SiglipTokenizer`]):
28
+ The tokenizer is a required input.
29
+ """
30
+
31
+ attributes = ["image_processor", "tokenizer"]
32
+ image_processor_class = "SiglipImageProcessor"
33
+ tokenizer_class = "SiglipTokenizer"
34
+
35
+ def __init__(self, image_processor, tokenizer):
36
+ super().__init__(image_processor, tokenizer)
37
+
38
+ def __call__(
39
+ self,
40
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
41
+ images: ImageInput = None,
42
+ padding: Union[bool, str, PaddingStrategy] = False,
43
+ truncation: Union[bool, str, TruncationStrategy] = None,
44
+ max_length: int = None,
45
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
46
+ ) -> BatchFeature:
47
+ """
48
+ Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
49
+ and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode
50
+ the text. To prepare the image(s), this method forwards the `images` argument to
51
+ SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
52
+ of the above two methods for more information.
53
+
54
+ Args:
55
+ text (`str`, `List[str]`, `List[List[str]]`):
56
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
57
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
58
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
59
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
60
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
61
+ tensor. Both channels-first and channels-last formats are supported.
62
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
63
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
64
+ index) among:
65
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
66
+ sequence is provided).
67
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
68
+ acceptable input length for the model if that argument is not provided.
69
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
70
+ lengths).
71
+ max_length (`int`, *optional*):
72
+ Maximum length of the returned list and optionally padding length (see above).
73
+ truncation (`bool`, *optional*):
74
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
75
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
76
+ If set, will return tensors of a particular framework. Acceptable values are:
77
+
78
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
79
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
80
+ - `'np'`: Return NumPy `np.ndarray` objects.
81
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
82
+
83
+ Returns:
84
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
85
+
86
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
87
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
88
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
89
+ `None`).
90
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
91
+ """
92
+
93
+ if text is None and images is None:
94
+ raise ValueError("You have to specify either text or images. Both cannot be None.")
95
+
96
+ if text is not None:
97
+ encoding = self.tokenizer(
98
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
99
+ )
100
+
101
+ if images is not None:
102
+ image_features = self.image_processor(images, return_tensors=return_tensors)
103
+
104
+ if text is not None and images is not None:
105
+ encoding["pixel_values"] = image_features.pixel_values
106
+ return encoding
107
+ elif text is not None:
108
+ return encoding
109
+ else:
110
+ return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)
111
+
112
+ def decode(self, *args, **kwargs):
113
+ """
114
+ This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
115
+ the docstring of this method for more information.
116
+ """
117
+ return self.tokenizer.decode(*args, **kwargs)
118
+
119
+ def batch_decode(self, *args, **kwargs):
120
+ """
121
+ This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
122
+ refer to the docstring of this method for more information.
123
+ """
124
+ return self.tokenizer.batch_decode(*args, **kwargs)
125
+
126
+ @property
127
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
128
+ def model_input_names(self):
129
+ tokenizer_input_names = self.tokenizer.model_input_names
130
+ image_processor_input_names = self.image_processor.model_input_names
131
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
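The `model_input_names` property above merges the tokenizer's and image processor's input names, using `dict.fromkeys` as an order-preserving de-duplication. A tiny illustration with hypothetical name lists (the duplicate entry is contrived to show the behaviour):

```python
# Illustration of the ordered de-duplication idiom used in model_input_names;
# the two lists below are hypothetical stand-ins for the real components' names.
tokenizer_input_names = ["input_ids", "attention_mask"]
image_processor_input_names = ["pixel_values", "attention_mask"]

merged = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
print(merged)  # ['input_ids', 'attention_mask', 'pixel_values']
```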
modeling/siglip/tokenization_siglip.py ADDED
@@ -0,0 +1,364 @@
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tokenization class for SigLIP model."""
5
+
6
+ import os
7
+ import re
8
+ import string
9
+ import warnings
10
+ from shutil import copyfile
11
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
12
+
13
+ import sentencepiece as spm
14
+
15
+ from transformers.convert_slow_tokenizer import import_protobuf
16
+ from transformers.tokenization_utils import PreTrainedTokenizer
17
+ from transformers.tokenization_utils_base import AddedToken
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from transformers.tokenization_utils_base import TextInput
22
+ from transformers.utils import logging, requires_backends
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
28
+
29
+
30
+ SPIECE_UNDERLINE = "▁"
31
+
32
+
33
+ class SiglipTokenizer(PreTrainedTokenizer):
34
+ """
35
+ Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
36
+
37
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
38
+ this superclass for more information regarding those methods.
39
+
40
+ Args:
41
+ vocab_file (`str`):
42
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
43
+ contains the vocabulary necessary to instantiate a tokenizer.
44
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
45
+ The end of sequence token.
46
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
47
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
48
+ token instead.
49
+ pad_token (`str`, *optional*, defaults to `"</s>"`):
50
+ The token used for padding, for example when batching sequences of different lengths.
51
+ additional_special_tokens (`List[str]`, *optional*):
52
+ Additional special tokens used by the tokenizer.
53
+ sp_model_kwargs (`dict`, *optional*):
54
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
55
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
56
+ to set:
57
+
58
+ - `enable_sampling`: Enable subword regularization.
59
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
60
+
61
+ - `nbest_size = {0,1}`: No sampling is performed.
62
+ - `nbest_size > 1`: samples from the nbest_size results.
63
+ - `nbest_size < 0`: assuming that nbest_size is infinite, samples from all hypotheses (lattice)
64
+ using forward-filtering-and-backward-sampling algorithm.
65
+
66
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
67
+ BPE-dropout.
68
+ model_max_length (`int`, *optional*, defaults to 64):
69
+ The maximum length (in number of tokens) for model inputs.
70
+ do_lower_case (`bool`, *optional*, defaults to `True`):
71
+ Whether or not to lowercase the input when tokenizing.
72
+ """
73
+
74
+ vocab_files_names = VOCAB_FILES_NAMES
75
+ model_input_names = ["input_ids", "attention_mask"]
76
+
77
+ def __init__(
78
+ self,
79
+ vocab_file,
80
+ eos_token="</s>",
81
+ unk_token="<unk>",
82
+ pad_token="</s>",
83
+ additional_special_tokens=None,
84
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
85
+ model_max_length=64,
86
+ do_lower_case=True,
87
+ **kwargs,
88
+ ) -> None:
89
+ requires_backends(self, "protobuf")
90
+
91
+ pad_token = (
92
+ AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True)
93
+ if isinstance(pad_token, str)
94
+ else pad_token
95
+ )
96
+ unk_token = (
97
+ AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True)
98
+ if isinstance(unk_token, str)
99
+ else unk_token
100
+ )
101
+ eos_token = (
102
+ AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True)
103
+ if isinstance(eos_token, str)
104
+ else eos_token
105
+ )
106
+
107
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
108
+
109
+ self.do_lower_case = do_lower_case
110
+ self.vocab_file = vocab_file
111
+
112
+ self.sp_model = self.get_spm_processor()
113
+ self.vocab_file = vocab_file
114
+
115
+ super().__init__(
116
+ eos_token=eos_token,
117
+ unk_token=unk_token,
118
+ pad_token=pad_token,
119
+ additional_special_tokens=additional_special_tokens,
120
+ sp_model_kwargs=self.sp_model_kwargs,
121
+ model_max_length=model_max_length,
122
+ do_lower_case=do_lower_case,
123
+ **kwargs,
124
+ )
125
+
126
+ def get_spm_processor(self):
127
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
128
+ with open(self.vocab_file, "rb") as f:
129
+ sp_model = f.read()
130
+ model_pb2 = import_protobuf()
131
+ model = model_pb2.ModelProto.FromString(sp_model)
132
+ normalizer_spec = model_pb2.NormalizerSpec()
133
+ normalizer_spec.add_dummy_prefix = False
134
+ model.normalizer_spec.MergeFrom(normalizer_spec)
135
+ sp_model = model.SerializeToString()
136
+ tokenizer.LoadFromSerializedProto(sp_model)
137
+ return tokenizer
138
+
139
+ @property
140
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
141
+ def vocab_size(self):
142
+ return self.sp_model.get_piece_size()
143
+
144
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
145
+ def get_vocab(self):
146
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
147
+ vocab.update(self.added_tokens_encoder)
148
+ return vocab
149
+
150
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
151
+ def get_special_tokens_mask(
152
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
153
+ ) -> List[int]:
154
+ """
155
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
156
+ special tokens using the tokenizer `prepare_for_model` method.
157
+
158
+ Args:
159
+ token_ids_0 (`List[int]`):
160
+ List of IDs.
161
+ token_ids_1 (`List[int]`, *optional*):
162
+ Optional second list of IDs for sequence pairs.
163
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
164
+ Whether or not the token list is already formatted with special tokens for the model.
165
+
166
+ Returns:
167
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
168
+ """
169
+ if already_has_special_tokens:
170
+ return super().get_special_tokens_mask(
171
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
172
+ )
173
+
174
+ # normal case: some special tokens
175
+ if token_ids_1 is None:
176
+ return ([0] * len(token_ids_0)) + [1]
177
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
178
+
179
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
180
+ def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
181
+ """Do not add eos again if user already added it."""
182
+ if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
183
+ warnings.warn(
184
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
185
+ " eos tokens being added."
186
+ )
187
+ return token_ids
188
+ else:
189
+ return token_ids + [self.eos_token_id]
190
+
191
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
192
+ def create_token_type_ids_from_sequences(
193
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
194
+ ) -> List[int]:
195
+ """
196
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
197
+ use of token type ids, therefore a list of zeros is returned.
198
+
199
+ Args:
200
+ token_ids_0 (`List[int]`):
201
+ List of IDs.
202
+ token_ids_1 (`List[int]`, *optional*):
203
+ Optional second list of IDs for sequence pairs.
204
+
205
+ Returns:
206
+ `List[int]`: List of zeros.
207
+ """
208
+ eos = [self.eos_token_id]
209
+
210
+ if token_ids_1 is None:
211
+ return len(token_ids_0 + eos) * [0]
212
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
213
+
214
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
215
+ def build_inputs_with_special_tokens(
216
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
217
+ ) -> List[int]:
218
+ """
219
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
220
+ adding special tokens. A sequence has the following format:
221
+
222
+ - single sequence: `X </s>`
223
+ - pair of sequences: `A </s> B </s>`
224
+
225
+ Args:
226
+ token_ids_0 (`List[int]`):
227
+ List of IDs to which the special tokens will be added.
228
+ token_ids_1 (`List[int]`, *optional*):
229
+ Optional second list of IDs for sequence pairs.
230
+
231
+ Returns:
232
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
233
+ """
234
+ token_ids_0 = self._add_eos_if_not_present(token_ids_0)
235
+ if token_ids_1 is None:
236
+ return token_ids_0
237
+ else:
238
+ token_ids_1 = self._add_eos_if_not_present(token_ids_1)
239
+ return token_ids_0 + token_ids_1
240
+
241
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
242
+ def __getstate__(self):
243
+ state = self.__dict__.copy()
244
+ state["sp_model"] = None
245
+ return state
246
+
247
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__
248
+ def __setstate__(self, d):
249
+ self.__dict__ = d
250
+
251
+ # for backward compatibility
252
+ if not hasattr(self, "sp_model_kwargs"):
253
+ self.sp_model_kwargs = {}
254
+
255
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
256
+ self.sp_model.Load(self.vocab_file)
257
+
258
+ def remove_punctuation(self, text: str) -> str:
259
+ return text.translate(str.maketrans("", "", string.punctuation))
260
+
261
+ # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
262
+ def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
263
+ """Returns canonicalized `text` (punctuation removed).
264
+
265
+ Args:
266
+ text (`str`):
267
+ String to be canonicalized.
268
+ keep_punctuation_exact_string (`str`, *optional*):
269
+ If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}'
270
+ (but will still remove '{' and '}' that appear separately).
271
+ """
272
+ if keep_punctuation_exact_string:
273
+ text = keep_punctuation_exact_string.join(
274
+ self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string)
275
+ )
276
+ else:
277
+ text = self.remove_punctuation(text)
278
+ text = re.sub(r"\s+", " ", text)
279
+ text = text.strip()
280
+
281
+ return text
282
+
283
+ def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
284
+ """
285
+ Converts a string to a list of tokens.
286
+ """
287
+ tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
288
+
289
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
290
+ tokens = tokens[1:]
291
+ return tokens
292
+
293
+ @property
294
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
295
+ def unk_token_length(self):
296
+ return len(self.sp_model.encode(str(self.unk_token)))
297
+
298
+ def _tokenize(self, text, **kwargs):
299
+ """
300
+ Returns a tokenized string.
301
+
302
+ We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
303
+ SPIECE_UNDERLINE.
304
+
305
+ For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`.
306
+
307
+ Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
308
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
309
+ """
310
+ text = self.canonicalize_text(text, keep_punctuation_exact_string=None)
311
+ tokens = self.sp_model.encode(text, out_type=str)
312
+
313
+ # 1. Encode string + prefix ex: "<unk> Hey"
314
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
315
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
316
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
317
+
318
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
319
+ def _convert_token_to_id(self, token):
320
+ """Converts a token (str) in an id using the vocab."""
321
+ return self.sp_model.piece_to_id(token)
322
+
323
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
324
+ def _convert_id_to_token(self, index):
325
+ """Converts an index (integer) in a token (str) using the vocab."""
326
+ token = self.sp_model.IdToPiece(index)
327
+ return token
328
+
329
+ def convert_tokens_to_string(self, tokens):
330
+ """Converts a sequence of tokens (string) in a single string."""
331
+ current_sub_tokens = []
332
+ out_string = ""
333
+ prev_is_special = False
334
+ for token in tokens:
335
+ # make sure that special tokens are not decoded using sentencepiece model
336
+ if token in self.all_special_tokens:
337
+ if not prev_is_special:
338
+ out_string += " "
339
+ out_string += self.sp_model.decode(current_sub_tokens) + token
340
+ prev_is_special = True
341
+ current_sub_tokens = []
342
+ else:
343
+ current_sub_tokens.append(token)
344
+ prev_is_special = False
345
+ out_string += self.sp_model.decode(current_sub_tokens)
346
+ return out_string.strip()
347
+
348
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
349
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
350
+ if not os.path.isdir(save_directory):
351
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
352
+ return
353
+ out_vocab_file = os.path.join(
354
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
355
+ )
356
+
357
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
358
+ copyfile(self.vocab_file, out_vocab_file)
359
+ elif not os.path.isfile(self.vocab_file):
360
+ with open(out_vocab_file, "wb") as fi:
361
+ content_spiece_model = self.sp_model.serialized_model_proto()
362
+ fi.write(content_spiece_model)
363
+
364
+ return (out_vocab_file,)
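A minimal, self-contained sketch (not part of this commit) of the special-token layout the docstrings above describe: a single sequence becomes `X </s>`, a pair becomes `A </s> B </s>`, and the token type IDs are all zeros. The `EOS_ID` value and the helper names below are hypothetical stand-ins for illustration only; the real tokenizer resolves `</s>` from its SentencePiece vocab.

# --- illustration only, not part of the repository ---
from typing import List, Optional

EOS_ID = 1  # hypothetical id for </s>; the real id comes from the vocab

def append_eos(ids: List[int]) -> List[int]:
    # Mirrors _add_eos_if_not_present: only append </s> when it is missing.
    return list(ids) if ids and ids[-1] == EOS_ID else list(ids) + [EOS_ID]

def build_inputs(ids_0: List[int], ids_1: Optional[List[int]] = None) -> List[int]:
    # Single sequence -> `X </s>`; pair of sequences -> `A </s> B </s>`.
    return append_eos(ids_0) if ids_1 is None else append_eos(ids_0) + append_eos(ids_1)

def token_type_ids(ids_0: List[int], ids_1: Optional[List[int]] = None) -> List[int]:
    # The model does not use token types, so every position gets 0.
    return [0] * len(build_inputs(ids_0, ids_1))

assert build_inputs([5, 6, 7]) == [5, 6, 7, EOS_ID]
assert build_inputs([5, 6], [8, 9]) == [5, 6, EOS_ID, 8, 9, EOS_ID]
assert token_type_ids([5, 6], [8, 9]) == [0, 0, 0, 0, 0, 0]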
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ decord==0.6.0
+ einops==0.8.1
+ huggingface_hub==0.29.1
+ matplotlib==3.7.0
+ numpy==1.24.4
+ opencv_python==4.7.0.72
+ pyarrow==11.0.0
+ PyYAML==6.0.2
+ Requests==2.32.3
+ safetensors==0.4.5
+ scipy==1.10.1
+ sentencepiece==0.1.99
+ torch==2.5.1
+ torchvision==0.20.1
+ transformers==4.49.0
+ accelerate>=0.34.0
+ wandb
test_images/meme.jpg ADDED
test_images/octupusy.jpg ADDED
test_images/women.jpg ADDED