KingNish committed on
Commit 12a0dd9 · verified · 1 Parent(s): 3ee3ce9
app.py CHANGED
@@ -1,505 +1,508 @@
1
- import gradio as gr
2
- import numpy as np
3
- import os
4
- import torch
5
- import random
6
-
7
- from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
8
- from PIL import Image
9
-
10
- from data.data_utils import add_special_tokens, pil_img2rgb
11
- from data.transforms import ImageTransform
12
- from inferencer import InterleaveInferencer
13
- from modeling.autoencoder import load_ae
14
- from modeling.bagel.qwen2_navit import NaiveCache
15
- from modeling.bagel import (
16
- BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
17
- SiglipVisionConfig, SiglipVisionModel
18
- )
19
- from modeling.qwen2 import Qwen2Tokenizer
20
-
21
-
22
- # Model Initialization
23
- model_path = "/path/to/BAGEL-7B-MoT/weights" #Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT
24
-
25
- llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
26
- llm_config.qk_norm = True
27
- llm_config.tie_word_embeddings = False
28
- llm_config.layer_module = "Qwen2MoTDecoderLayer"
29
-
30
- vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
31
- vit_config.rope = False
32
- vit_config.num_hidden_layers -= 1
33
-
34
- vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
35
-
36
- config = BagelConfig(
37
- visual_gen=True,
38
- visual_und=True,
39
- llm_config=llm_config,
40
- vit_config=vit_config,
41
- vae_config=vae_config,
42
- vit_max_num_patch_per_side=70,
43
- connector_act='gelu_pytorch_tanh',
44
- latent_patch_size=2,
45
- max_latent_size=64,
46
- )
47
-
48
- with init_empty_weights():
49
- language_model = Qwen2ForCausalLM(llm_config)
50
- vit_model = SiglipVisionModel(vit_config)
51
- model = Bagel(language_model, vit_model, config)
52
- model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
53
-
54
- tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
55
- tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
56
-
57
- vae_transform = ImageTransform(1024, 512, 16)
58
- vit_transform = ImageTransform(980, 224, 14)
59
-
60
- # Model Loading and Multi GPU Infernece Preparing
61
- device_map = infer_auto_device_map(
62
- model,
63
- max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
64
- no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
65
- )
66
-
67
- same_device_modules = [
68
- 'language_model.model.embed_tokens',
69
- 'time_embedder',
70
- 'latent_pos_embed',
71
- 'vae2llm',
72
- 'llm2vae',
73
- 'connector',
74
- 'vit_pos_embed'
75
- ]
76
-
77
- if torch.cuda.device_count() == 1:
78
- first_device = device_map.get(same_device_modules[0], "cuda:0")
79
- for k in same_device_modules:
80
- if k in device_map:
81
- device_map[k] = first_device
82
- else:
83
- device_map[k] = "cuda:0"
84
- else:
85
- first_device = device_map.get(same_device_modules[0])
86
- for k in same_device_modules:
87
- if k in device_map:
88
- device_map[k] = first_device
89
-
90
- model = load_checkpoint_and_dispatch(
91
- model,
92
- checkpoint=os.path.join(model_path, "ema.safetensors"),
93
- device_map=device_map,
94
- offload_buffers=True,
95
- dtype=torch.bfloat16,
96
- force_hooks=True,
97
- ).eval()
98
-
99
-
100
- # Inferencer Preparing
101
- inferencer = InterleaveInferencer(
102
- model=model,
103
- vae_model=vae_model,
104
- tokenizer=tokenizer,
105
- vae_transform=vae_transform,
106
- vit_transform=vit_transform,
107
- new_token_ids=new_token_ids,
108
- )
109
-
110
- def set_seed(seed):
111
- """Set random seeds for reproducibility"""
112
- if seed > 0:
113
- random.seed(seed)
114
- np.random.seed(seed)
115
- torch.manual_seed(seed)
116
- if torch.cuda.is_available():
117
- torch.cuda.manual_seed(seed)
118
- torch.cuda.manual_seed_all(seed)
119
- torch.backends.cudnn.deterministic = True
120
- torch.backends.cudnn.benchmark = False
121
- return seed
122
-
123
- # Text to Image function with thinking option and hyperparameters
124
- def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
125
- timestep_shift=3.0, num_timesteps=50,
126
- cfg_renorm_min=1.0, cfg_renorm_type="global",
127
- max_think_token_n=1024, do_sample=False, text_temperature=0.3,
128
- seed=0, image_ratio="1:1"):
129
- # Set seed for reproducibility
130
- set_seed(seed)
131
-
132
- if image_ratio == "1:1":
133
- image_shapes = (1024, 1024)
134
- elif image_ratio == "4:3":
135
- image_shapes = (768, 1024)
136
- elif image_ratio == "3:4":
137
- image_shapes = (1024, 768)
138
- elif image_ratio == "16:9":
139
- image_shapes = (576, 1024)
140
- elif image_ratio == "9:16":
141
- image_shapes = (1024, 576)
142
-
143
- # Set hyperparameters
144
- inference_hyper = dict(
145
- max_think_token_n=max_think_token_n if show_thinking else 1024,
146
- do_sample=do_sample if show_thinking else False,
147
- text_temperature=text_temperature if show_thinking else 0.3,
148
- cfg_text_scale=cfg_text_scale,
149
- cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
150
- timestep_shift=timestep_shift,
151
- num_timesteps=num_timesteps,
152
- cfg_renorm_min=cfg_renorm_min,
153
- cfg_renorm_type=cfg_renorm_type,
154
- image_shapes=image_shapes,
155
- )
156
-
157
- # Call inferencer with or without think parameter based on user choice
158
- result = inferencer(text=prompt, think=show_thinking, **inference_hyper)
159
- return result["image"], result.get("text", None)
160
-
161
-
162
- # Image Understanding function with thinking option and hyperparameters
163
- def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
164
- do_sample=False, text_temperature=0.3, max_new_tokens=512):
165
- if image is None:
166
- return "Please upload an image."
167
-
168
- if isinstance(image, np.ndarray):
169
- image = Image.fromarray(image)
170
-
171
- image = pil_img2rgb(image)
172
-
173
- # Set hyperparameters
174
- inference_hyper = dict(
175
- do_sample=do_sample,
176
- text_temperature=text_temperature,
177
- max_think_token_n=max_new_tokens, # Set max_length
178
- )
179
-
180
- # Use show_thinking parameter to control thinking process
181
- result = inferencer(image=image, text=prompt, think=show_thinking,
182
- understanding_output=True, **inference_hyper)
183
- return result["text"]
184
-
185
-
186
- # Image Editing function with thinking option and hyperparameters
187
- def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
188
- cfg_img_scale=2.0, cfg_interval=0.0,
189
- timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
190
- cfg_renorm_type="text_channel", max_think_token_n=1024,
191
- do_sample=False, text_temperature=0.3, seed=0):
192
- # Set seed for reproducibility
193
- set_seed(seed)
194
-
195
- if image is None:
196
- return "Please upload an image.", ""
197
-
198
- if isinstance(image, np.ndarray):
199
- image = Image.fromarray(image)
200
-
201
- image = pil_img2rgb(image)
202
-
203
- # Set hyperparameters
204
- inference_hyper = dict(
205
- max_think_token_n=max_think_token_n if show_thinking else 1024,
206
- do_sample=do_sample if show_thinking else False,
207
- text_temperature=text_temperature if show_thinking else 0.3,
208
- cfg_text_scale=cfg_text_scale,
209
- cfg_img_scale=cfg_img_scale,
210
- cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
211
- timestep_shift=timestep_shift,
212
- num_timesteps=num_timesteps,
213
- cfg_renorm_min=cfg_renorm_min,
214
- cfg_renorm_type=cfg_renorm_type,
215
- )
216
-
217
- # Include thinking parameter based on user choice
218
- result = inferencer(image=image, text=prompt, think=show_thinking, **inference_hyper)
219
- return result["image"], result.get("text", "")
220
-
221
-
222
- # Helper function to load example images
223
- def load_example_image(image_path):
224
- try:
225
- return Image.open(image_path)
226
- except Exception as e:
227
- print(f"Error loading example image: {e}")
228
- return None
229
-
230
-
231
- # Gradio UI
232
- with gr.Blocks() as demo:
233
- gr.Markdown("""
234
- <div>
235
- <img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="380"/>
236
- </div>
237
- """)
238
-
239
- with gr.Tab("📝 Text to Image"):
240
- txt_input = gr.Textbox(
241
- label="Prompt",
242
- value="A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
243
- )
244
-
245
- with gr.Row():
246
- show_thinking = gr.Checkbox(label="Thinking", value=False)
247
-
248
- # Add hyperparameter controls in an accordion
249
- with gr.Accordion("Inference Hyperparameters", open=False):
250
- # 参数一排两个布局
251
- with gr.Group():
252
- with gr.Row():
253
- seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1,
254
- label="Seed", info="0 for random seed, positive for reproducible results")
255
- image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
256
- value="1:1", label="Image Ratio",
257
- info="The longer size is fixed to 1024")
258
-
259
- with gr.Row():
260
- cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
261
- label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0)")
262
- cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
263
- label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
264
-
265
- with gr.Row():
266
- cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
267
- value="global", label="CFG Renorm Type",
268
- info="If the genrated image is blurry, use 'global'")
269
- cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
270
- label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
271
-
272
- with gr.Row():
273
- num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
274
- label="Timesteps", info="Total denoising steps")
275
- timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, interactive=True,
276
- label="Timestep Shift", info="Higher values for layout, lower for details")
277
-
278
- # Thinking parameters in a single row
279
- thinking_params = gr.Group(visible=False)
280
- with thinking_params:
281
- with gr.Row():
282
- do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
283
- max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
284
- label="Max Think Tokens", info="Maximum number of tokens for thinking")
285
- text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
286
- label="Temperature", info="Controls randomness in text generation")
287
-
288
- thinking_output = gr.Textbox(label="Thinking Process", visible=False)
289
- img_output = gr.Image(label="Generated Image")
290
- gen_btn = gr.Button("Generate")
291
-
292
- # Dynamically show/hide thinking process box and parameters
293
- def update_thinking_visibility(show):
294
- return gr.update(visible=show), gr.update(visible=show)
295
-
296
- show_thinking.change(
297
- fn=update_thinking_visibility,
298
- inputs=[show_thinking],
299
- outputs=[thinking_output, thinking_params]
300
- )
301
-
302
- # Process function based on thinking option and hyperparameters
303
- def process_text_to_image(prompt, show_thinking, cfg_text_scale,
304
- cfg_interval, timestep_shift,
305
- num_timesteps, cfg_renorm_min, cfg_renorm_type,
306
- max_think_token_n, do_sample, text_temperature, seed, image_ratio):
307
- image, thinking = text_to_image(
308
- prompt, show_thinking, cfg_text_scale, cfg_interval,
309
- timestep_shift, num_timesteps,
310
- cfg_renorm_min, cfg_renorm_type,
311
- max_think_token_n, do_sample, text_temperature, seed, image_ratio
312
- )
313
- return image, thinking if thinking else ""
314
-
315
- gen_btn.click(
316
- fn=process_text_to_image,
317
- inputs=[
318
- txt_input, show_thinking, cfg_text_scale,
319
- cfg_interval, timestep_shift,
320
- num_timesteps, cfg_renorm_min, cfg_renorm_type,
321
- max_think_token_n, do_sample, text_temperature, seed, image_ratio
322
- ],
323
- outputs=[img_output, thinking_output]
324
- )
325
-
326
- with gr.Tab("🖌️ Image Edit"):
327
- with gr.Row():
328
- with gr.Column(scale=1):
329
- edit_image_input = gr.Image(label="Input Image", value=load_example_image('test_images/women.jpg'))
330
- edit_prompt = gr.Textbox(
331
- label="Prompt",
332
- value="She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes."
333
- )
334
-
335
- with gr.Column(scale=1):
336
- edit_image_output = gr.Image(label="Result")
337
- edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)
338
-
339
- with gr.Row():
340
- edit_show_thinking = gr.Checkbox(label="Thinking", value=False)
341
-
342
- # Add hyperparameter controls in an accordion
343
- with gr.Accordion("Inference Hyperparameters", open=False):
344
- with gr.Group():
345
- with gr.Row():
346
- edit_seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, interactive=True,
347
- label="Seed", info="0 for random seed, positive for reproducible results")
348
- edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
349
- label="CFG Text Scale", info="Controls how strongly the model follows the text prompt")
350
-
351
- with gr.Row():
352
- edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, interactive=True,
353
- label="CFG Image Scale", info="Controls how much the model preserves input image details")
354
- edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
355
- label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
356
-
357
- with gr.Row():
358
- edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
359
- value="text_channel", label="CFG Renorm Type",
360
- info="If the genrated image is blurry, use 'global")
361
- edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
362
- label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
363
-
364
- with gr.Row():
365
- edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
366
- label="Timesteps", info="Total denoising steps")
367
- edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, interactive=True,
368
- label="Timestep Shift", info="Higher values for layout, lower for details")
369
-
370
-
371
- # Thinking parameters in a single row
372
- edit_thinking_params = gr.Group(visible=False)
373
- with edit_thinking_params:
374
- with gr.Row():
375
- edit_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
376
- edit_max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
377
- label="Max Think Tokens", info="Maximum number of tokens for thinking")
378
- edit_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
379
- label="Temperature", info="Controls randomness in text generation")
380
-
381
- edit_btn = gr.Button("Submit")
382
-
383
- # Dynamically show/hide thinking process box for editing
384
- def update_edit_thinking_visibility(show):
385
- return gr.update(visible=show), gr.update(visible=show)
386
-
387
- edit_show_thinking.change(
388
- fn=update_edit_thinking_visibility,
389
- inputs=[edit_show_thinking],
390
- outputs=[edit_thinking_output, edit_thinking_params]
391
- )
392
-
393
- # Process editing with thinking option and hyperparameters
394
- def process_edit_image(image, prompt, show_thinking, cfg_text_scale,
395
- cfg_img_scale, cfg_interval,
396
- timestep_shift, num_timesteps, cfg_renorm_min,
397
- cfg_renorm_type, max_think_token_n, do_sample,
398
- text_temperature, seed):
399
- edited_image, thinking = edit_image(
400
- image, prompt, show_thinking, cfg_text_scale, cfg_img_scale,
401
- cfg_interval, timestep_shift,
402
- num_timesteps, cfg_renorm_min, cfg_renorm_type,
403
- max_think_token_n, do_sample, text_temperature, seed
404
- )
405
-
406
- return edited_image, thinking if thinking else ""
407
-
408
- edit_btn.click(
409
- fn=process_edit_image,
410
- inputs=[
411
- edit_image_input, edit_prompt, edit_show_thinking,
412
- edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
413
- edit_timestep_shift, edit_num_timesteps,
414
- edit_cfg_renorm_min, edit_cfg_renorm_type,
415
- edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed
416
- ],
417
- outputs=[edit_image_output, edit_thinking_output]
418
- )
419
-
420
- with gr.Tab("🖼️ Image Understanding"):
421
- with gr.Row():
422
- with gr.Column(scale=1):
423
- img_input = gr.Image(label="Input Image", value=load_example_image('test_images/meme.jpg'))
424
- understand_prompt = gr.Textbox(
425
- label="Prompt",
426
- value="Can someone explain what's funny about this meme??"
427
- )
428
-
429
- with gr.Column(scale=1):
430
- txt_output = gr.Textbox(label="Result", lines=20)
431
-
432
- with gr.Row():
433
- understand_show_thinking = gr.Checkbox(label="Thinking", value=False)
434
-
435
- # Add hyperparameter controls in an accordion
436
- with gr.Accordion("Inference Hyperparameters", open=False):
437
- with gr.Row():
438
- understand_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
439
- understand_text_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True,
440
- label="Temperature", info="Controls randomness in text generation (0=deterministic, 1=creative)")
441
- understand_max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=512, step=64, interactive=True,
442
- label="Max New Tokens", info="Maximum length of generated text, including potential thinking")
443
-
444
- img_understand_btn = gr.Button("Submit")
445
-
446
- # Process understanding with thinking option and hyperparameters
447
- def process_understanding(image, prompt, show_thinking, do_sample,
448
- text_temperature, max_new_tokens):
449
- result = image_understanding(
450
- image, prompt, show_thinking, do_sample,
451
- text_temperature, max_new_tokens
452
- )
453
- return result
454
-
455
- img_understand_btn.click(
456
- fn=process_understanding,
457
- inputs=[
458
- img_input, understand_prompt, understand_show_thinking,
459
- understand_do_sample, understand_text_temperature, understand_max_new_tokens
460
- ],
461
- outputs=txt_output
462
- )
463
-
464
- gr.Markdown("""
465
- <div style="display: flex; justify-content: flex-start; flex-wrap: wrap; gap: 10px;">
466
- <a href="https://bagel-ai.org/">
467
- <img
468
- src="https://img.shields.io/badge/BAGEL-Website-0A66C2?logo=safari&logoColor=white"
469
- alt="BAGEL Website"
470
- />
471
- </a>
472
- <a href="https://arxiv.org/abs/2505.14683">
473
- <img
474
- src="https://img.shields.io/badge/BAGEL-Paper-red?logo=arxiv&logoColor=red"
475
- alt="BAGEL Paper on arXiv"
476
- />
477
- </a>
478
- <a href="https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT">
479
- <img
480
- src="https://img.shields.io/badge/BAGEL-Hugging%20Face-orange?logo=huggingface&logoColor=yellow"
481
- alt="BAGEL on Hugging Face"
482
- />
483
- </a>
484
- <a href="https://demo.bagel-ai.org/">
485
- <img
486
- src="https://img.shields.io/badge/BAGEL-Demo-blue?logo=googleplay&logoColor=blue"
487
- alt="BAGEL Demo"
488
- />
489
- </a>
490
- <a href="https://discord.gg/Z836xxzy">
491
- <img
492
- src="https://img.shields.io/badge/BAGEL-Discord-5865F2?logo=discord&logoColor=purple"
493
- alt="BAGEL Discord"
494
- />
495
- </a>
496
- <a href="mailto:[email protected]">
497
- <img
498
- src="https://img.shields.io/badge/BAGEL-Email-D14836?logo=gmail&logoColor=red"
499
- alt="BAGEL Email"
500
- />
501
- </a>
502
- </div>
503
- """)
504
-
505
- demo.launch(share=True)
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ import random
7
+ import subprocess
8
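+ # Install flash-attn at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the CUDA extension so the install works on the Space.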
+ subprocess.run(
9
+ "pip install flash-attn --no-build-isolation",
10
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
11
+ shell=True,
12
+ )
13
+
14
+ from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
15
+ from PIL import Image
16
+
17
+ from data.data_utils import add_special_tokens, pil_img2rgb
18
+ from data.transforms import ImageTransform
19
+ from inferencer import InterleaveInferencer
20
+ from modeling.autoencoder import load_ae
21
+ from modeling.bagel.qwen2_navit import NaiveCache
22
+ from modeling.bagel import (
23
+ BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
24
+ SiglipVisionConfig, SiglipVisionModel
25
+ )
26
+ from modeling.qwen2 import Qwen2Tokenizer
27
+
28
+ from huggingface_hub import snapshot_download
29
+
30
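+ # Download the BAGEL-7B-MoT checkpoint from the Hugging Face Hub into ./model before building the model.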
+ save_dir = "./model"
31
+ repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
32
+ cache_dir = save_dir + "/cache"
33
+
34
+ snapshot_download(cache_dir=cache_dir,
35
+ local_dir=save_dir,
36
+ repo_id=repo_id,
37
+ local_dir_use_symlinks=False,
38
+ resume_download=True,
39
+ allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
40
+ )
41
+
42
+ # Model Initialization
43
+ model_path = "./model" #Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT
44
+
45
+ llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
46
+ llm_config.qk_norm = True
47
+ llm_config.tie_word_embeddings = False
48
+ llm_config.layer_module = "Qwen2MoTDecoderLayer"
49
+
50
+ vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
51
+ vit_config.rope = False
52
+ vit_config.num_hidden_layers -= 1
53
+
54
+ vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
55
+
56
+ config = BagelConfig(
57
+ visual_gen=True,
58
+ visual_und=True,
59
+ llm_config=llm_config,
60
+ vit_config=vit_config,
61
+ vae_config=vae_config,
62
+ vit_max_num_patch_per_side=70,
63
+ connector_act='gelu_pytorch_tanh',
64
+ latent_patch_size=2,
65
+ max_latent_size=64,
66
+ )
67
+
68
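+ # Build the model skeleton on the meta device first; the real weights are loaded and dispatched below.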
+ with init_empty_weights():
69
+ language_model = Qwen2ForCausalLM(llm_config)
70
+ vit_model = SiglipVisionModel(vit_config)
71
+ model = Bagel(language_model, vit_model, config)
72
+ model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
73
+
74
+ tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
75
+ tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
76
+
77
+ vae_transform = ImageTransform(1024, 512, 16)
78
+ vit_transform = ImageTransform(980, 224, 14)
79
+
80
+ # Model Loading and Multi-GPU Inference Preparation
81
+ device_map = infer_auto_device_map(
82
+ model,
83
+ max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
84
+ no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
85
+ )
86
+
87
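+ # Modules that must share a device with the token embeddings so their tensors are not split across GPUs.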
+ same_device_modules = [
88
+ 'language_model.model.embed_tokens',
89
+ 'time_embedder',
90
+ 'latent_pos_embed',
91
+ 'vae2llm',
92
+ 'llm2vae',
93
+ 'connector',
94
+ 'vit_pos_embed'
95
+ ]
96
+
97
+ if torch.cuda.device_count() == 1:
98
+ first_device = device_map.get(same_device_modules[0], "cuda:0")
99
+ for k in same_device_modules:
100
+ if k in device_map:
101
+ device_map[k] = first_device
102
+ else:
103
+ device_map[k] = "cuda:0"
104
+ else:
105
+ first_device = device_map.get(same_device_modules[0])
106
+ for k in same_device_modules:
107
+ if k in device_map:
108
+ device_map[k] = first_device
109
+
110
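+ # Load the EMA checkpoint and dispatch the weights across the available devices in bfloat16.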
+ model = load_checkpoint_and_dispatch(
111
+ model,
112
+ checkpoint=os.path.join(model_path, "ema.safetensors"),
113
+ device_map=device_map,
114
+ offload_buffers=True,
115
+ dtype=torch.bfloat16,
116
+ force_hooks=True,
117
+ ).eval()
118
+
119
+
120
+ # Inferencer Preparation
121
+ inferencer = InterleaveInferencer(
122
+ model=model,
123
+ vae_model=vae_model,
124
+ tokenizer=tokenizer,
125
+ vae_transform=vae_transform,
126
+ vit_transform=vit_transform,
127
+ new_token_ids=new_token_ids,
128
+ )
129
+
130
+ def set_seed(seed):
131
+ """Set random seeds for reproducibility"""
132
+ if seed > 0:
133
+ random.seed(seed)
134
+ np.random.seed(seed)
135
+ torch.manual_seed(seed)
136
+ if torch.cuda.is_available():
137
+ torch.cuda.manual_seed(seed)
138
+ torch.cuda.manual_seed_all(seed)
139
+ torch.backends.cudnn.deterministic = True
140
+ torch.backends.cudnn.benchmark = False
141
+ return seed
142
+
143
+ # Text to Image function with thinking option and hyperparameters
144
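+ # ZeroGPU: request a GPU allocation of up to 90 seconds for each call.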
+ @spaces.GPU(duration=90)
145
+ def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
146
+ timestep_shift=3.0, num_timesteps=50,
147
+ cfg_renorm_min=1.0, cfg_renorm_type="global",
148
+ max_think_token_n=1024, do_sample=False, text_temperature=0.3,
149
+ seed=0, image_ratio="1:1"):
150
+ # Set seed for reproducibility
151
+ set_seed(seed)
152
+
153
+ if image_ratio == "1:1":
154
+ image_shapes = (1024, 1024)
155
+ elif image_ratio == "4:3":
156
+ image_shapes = (768, 1024)
157
+ elif image_ratio == "3:4":
158
+ image_shapes = (1024, 768)
159
+ elif image_ratio == "16:9":
160
+ image_shapes = (576, 1024)
161
+ elif image_ratio == "9:16":
162
+ image_shapes = (1024, 576)
163
+
164
+ # Set hyperparameters
165
+ inference_hyper = dict(
166
+ max_think_token_n=max_think_token_n if show_thinking else 1024,
167
+ do_sample=do_sample if show_thinking else False,
168
+ text_temperature=text_temperature if show_thinking else 0.3,
169
+ cfg_text_scale=cfg_text_scale,
170
+ cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
171
+ timestep_shift=timestep_shift,
172
+ num_timesteps=num_timesteps,
173
+ cfg_renorm_min=cfg_renorm_min,
174
+ cfg_renorm_type=cfg_renorm_type,
175
+ image_shapes=image_shapes,
176
+ )
177
+
178
+ result = {"text": "", "image": None}
179
+
180
+ # Call inferencer with or without think parameter based on user choice
181
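+ # The inferencer streams its output: intermediate text chunks arrive as str, the generated image as a PIL.Image.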
+ for i in inferencer(text=prompt, think=show_thinking, **inference_hyper):
182
+ if type(i) == str:
183
+ result["text"] += i
184
+ elif type(i) == Image.Image:
185
+ result["image"] = i
186
+
187
+ yield result["image"], result.get("text", None)
188
+
189
+
190
+ # Image Understanding function with thinking option and hyperparameters
191
+ @spaces.GPU(duration=90)
192
+ def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
193
+ do_sample=False, text_temperature=0.3, max_new_tokens=512):
194
+ if image is None:
195
+ return "Please upload an image."
196
+
197
+ if isinstance(image, np.ndarray):
198
+ image = Image.fromarray(image)
199
+
200
+ image = pil_img2rgb(image)
201
+
202
+ # Set hyperparameters
203
+ inference_hyper = dict(
204
+ do_sample=do_sample,
205
+ text_temperature=text_temperature,
206
+ max_think_token_n=max_new_tokens, # Set max_length
207
+ )
208
+
209
+ result = {"text": "", "image": None}
210
+ # Use show_thinking parameter to control thinking process
211
+ for i in inferencer(image=image, text=prompt, think=show_thinking,
212
+ understanding_output=True, **inference_hyper):
213
+ if type(i) == str:
214
+ result["text"] += i
215
+ elif type(i) == Image.Image:
216
+ result["image"] = i
217
+ yield result["text"]
218
+
219
+
220
+ # Image Editing function with thinking option and hyperparameters
221
+ @spaces.GPU(duration=90)
222
+ def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
223
+ cfg_img_scale=2.0, cfg_interval=0.0,
224
+ timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
225
+ cfg_renorm_type="text_channel", max_think_token_n=1024,
226
+ do_sample=False, text_temperature=0.3, seed=0):
227
+ # Set seed for reproducibility
228
+ set_seed(seed)
229
+
230
+ if image is None:
231
+ return "Please upload an image.", ""
232
+
233
+ if isinstance(image, np.ndarray):
234
+ image = Image.fromarray(image)
235
+
236
+ image = pil_img2rgb(image)
237
+
238
+ # Set hyperparameters
239
+ inference_hyper = dict(
240
+ max_think_token_n=max_think_token_n if show_thinking else 1024,
241
+ do_sample=do_sample if show_thinking else False,
242
+ text_temperature=text_temperature if show_thinking else 0.3,
243
+ cfg_text_scale=cfg_text_scale,
244
+ cfg_img_scale=cfg_img_scale,
245
+ cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
246
+ timestep_shift=timestep_shift,
247
+ num_timesteps=num_timesteps,
248
+ cfg_renorm_min=cfg_renorm_min,
249
+ cfg_renorm_type=cfg_renorm_type,
250
+ )
251
+
252
+ # Include thinking parameter based on user choice
253
+ result = {"text": "", "image": None}
254
+ for i in inferencer(image=image, text=prompt, think=show_thinking, **inference_hyper):
255
+ if type(i) == str:
256
+ result["text"] += i
257
+ elif type(i) == Image.Image:
258
+ result["image"] = i
259
+
260
+ yield result["image"], result.get("text", "")
261
+
262
+ # Helper function to load example images
263
+ def load_example_image(image_path):
264
+ try:
265
+ return Image.open(image_path)
266
+ except Exception as e:
267
+ print(f"Error loading example image: {e}")
268
+ return None
269
+
270
+
271
+ # Gradio UI
272
+ with gr.Blocks() as demo:
273
+ gr.Markdown("""
274
+ <div>
275
+ <img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="380"/>
276
+ </div>
277
+ """)
278
+
279
+ with gr.Tab("📝 Text to Image"):
280
+ txt_input = gr.Textbox(
281
+ label="Prompt",
282
+ value="A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
283
+ )
284
+
285
+ with gr.Row():
286
+ show_thinking = gr.Checkbox(label="Thinking", value=False)
287
+
288
+ # Add hyperparameter controls in an accordion
289
+ with gr.Accordion("Inference Hyperparameters", open=False):
290
+ # Layout: two parameters per row
291
+ with gr.Group():
292
+ with gr.Row():
293
+ seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1,
294
+ label="Seed", info="0 for random seed, positive for reproducible results")
295
+ image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
296
+ value="1:1", label="Image Ratio",
297
+ info="The longer size is fixed to 1024")
298
+
299
+ with gr.Row():
300
+ cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
301
+ label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0)")
302
+ cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
303
+ label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
304
+
305
+ with gr.Row():
306
+ cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
307
+ value="global", label="CFG Renorm Type",
308
+ info="If the genrated image is blurry, use 'global'")
309
+ cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
310
+ label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
311
+
312
+ with gr.Row():
313
+ num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
314
+ label="Timesteps", info="Total denoising steps")
315
+ timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, interactive=True,
316
+ label="Timestep Shift", info="Higher values for layout, lower for details")
317
+
318
+ # Thinking parameters in a single row
319
+ thinking_params = gr.Group(visible=False)
320
+ with thinking_params:
321
+ with gr.Row():
322
+ do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
323
+ max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
324
+ label="Max Think Tokens", info="Maximum number of tokens for thinking")
325
+ text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
326
+ label="Temperature", info="Controls randomness in text generation")
327
+
328
+ thinking_output = gr.Textbox(label="Thinking Process", visible=False)
329
+ img_output = gr.Image(label="Generated Image")
330
+ gen_btn = gr.Button("Generate")
331
+
332
+ # Dynamically show/hide thinking process box and parameters
333
+ def update_thinking_visibility(show):
334
+ return gr.update(visible=show), gr.update(visible=show)
335
+
336
+ show_thinking.change(
337
+ fn=update_thinking_visibility,
338
+ inputs=[show_thinking],
339
+ outputs=[thinking_output, thinking_params]
340
+ )
341
+
342
+ gen_btn.click(
343
+ fn=text_to_image,
344
+ inputs=[
345
+ txt_input, show_thinking, cfg_text_scale,
346
+ cfg_interval, timestep_shift,
347
+ num_timesteps, cfg_renorm_min, cfg_renorm_type,
348
+ max_think_token_n, do_sample, text_temperature, seed, image_ratio
349
+ ],
350
+ outputs=[img_output, thinking_output]
351
+ )
352
+
353
+ with gr.Tab("🖌️ Image Edit"):
354
+ with gr.Row():
355
+ with gr.Column(scale=1):
356
+ edit_image_input = gr.Image(label="Input Image", value=load_example_image('test_images/women.jpg'))
357
+ edit_prompt = gr.Textbox(
358
+ label="Prompt",
359
+ value="She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes."
360
+ )
361
+
362
+ with gr.Column(scale=1):
363
+ edit_image_output = gr.Image(label="Result")
364
+ edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)
365
+
366
+ with gr.Row():
367
+ edit_show_thinking = gr.Checkbox(label="Thinking", value=False)
368
+
369
+ # Add hyperparameter controls in an accordion
370
+ with gr.Accordion("Inference Hyperparameters", open=False):
371
+ with gr.Group():
372
+ with gr.Row():
373
+ edit_seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, interactive=True,
374
+ label="Seed", info="0 for random seed, positive for reproducible results")
375
+ edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
376
+ label="CFG Text Scale", info="Controls how strongly the model follows the text prompt")
377
+
378
+ with gr.Row():
379
+ edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, interactive=True,
380
+ label="CFG Image Scale", info="Controls how much the model preserves input image details")
381
+ edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
382
+ label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
383
+
384
+ with gr.Row():
385
+ edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
386
+ value="text_channel", label="CFG Renorm Type",
387
+ info="If the genrated image is blurry, use 'global")
388
+ edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
389
+ label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
390
+
391
+ with gr.Row():
392
+ edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
393
+ label="Timesteps", info="Total denoising steps")
394
+ edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, interactive=True,
395
+ label="Timestep Shift", info="Higher values for layout, lower for details")
396
+
397
+
398
+ # Thinking parameters in a single row
399
+ edit_thinking_params = gr.Group(visible=False)
400
+ with edit_thinking_params:
401
+ with gr.Row():
402
+ edit_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
403
+ edit_max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
404
+ label="Max Think Tokens", info="Maximum number of tokens for thinking")
405
+ edit_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
406
+ label="Temperature", info="Controls randomness in text generation")
407
+
408
+ edit_btn = gr.Button("Submit")
409
+
410
+ # Dynamically show/hide thinking process box for editing
411
+ def update_edit_thinking_visibility(show):
412
+ return gr.update(visible=show), gr.update(visible=show)
413
+
414
+ edit_show_thinking.change(
415
+ fn=update_edit_thinking_visibility,
416
+ inputs=[edit_show_thinking],
417
+ outputs=[edit_thinking_output, edit_thinking_params]
418
+ )
419
+
420
+ edit_btn.click(
421
+ fn=edit_image,
422
+ inputs=[
423
+ edit_image_input, edit_prompt, edit_show_thinking,
424
+ edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
425
+ edit_timestep_shift, edit_num_timesteps,
426
+ edit_cfg_renorm_min, edit_cfg_renorm_type,
427
+ edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed
428
+ ],
429
+ outputs=[edit_image_output, edit_thinking_output]
430
+ )
431
+
432
+ with gr.Tab("🖼️ Image Understanding"):
433
+ with gr.Row():
434
+ with gr.Column(scale=1):
435
+ img_input = gr.Image(label="Input Image", value=load_example_image('test_images/meme.jpg'))
436
+ understand_prompt = gr.Textbox(
437
+ label="Prompt",
438
+ value="Can someone explain what's funny about this meme??"
439
+ )
440
+
441
+ with gr.Column(scale=1):
442
+ txt_output = gr.Textbox(label="Result", lines=20)
443
+
444
+ with gr.Row():
445
+ understand_show_thinking = gr.Checkbox(label="Thinking", value=False)
446
+
447
+ # Add hyperparameter controls in an accordion
448
+ with gr.Accordion("Inference Hyperparameters", open=False):
449
+ with gr.Row():
450
+ understand_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
451
+ understand_text_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True,
452
+ label="Temperature", info="Controls randomness in text generation (0=deterministic, 1=creative)")
453
+ understand_max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=512, step=64, interactive=True,
454
+ label="Max New Tokens", info="Maximum length of generated text, including potential thinking")
455
+
456
+ img_understand_btn = gr.Button("Submit")
457
+
458
+ img_understand_btn.click(
459
+ fn=image_understanding,
460
+ inputs=[
461
+ img_input, understand_prompt, understand_show_thinking,
462
+ understand_do_sample, understand_text_temperature, understand_max_new_tokens
463
+ ],
464
+ outputs=txt_output
465
+ )
466
+
467
+ gr.Markdown("""
468
+ <div style="display: flex; justify-content: flex-start; flex-wrap: wrap; gap: 10px;">
469
+ <a href="https://bagel-ai.org/">
470
+ <img
471
+ src="https://img.shields.io/badge/BAGEL-Website-0A66C2?logo=safari&logoColor=white"
472
+ alt="BAGEL Website"
473
+ />
474
+ </a>
475
+ <a href="https://arxiv.org/abs/2505.14683">
476
+ <img
477
+ src="https://img.shields.io/badge/BAGEL-Paper-red?logo=arxiv&logoColor=red"
478
+ alt="BAGEL Paper on arXiv"
479
+ />
480
+ </a>
481
+ <a href="https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT">
482
+ <img
483
+ src="https://img.shields.io/badge/BAGEL-Hugging%20Face-orange?logo=huggingface&logoColor=yellow"
484
+ alt="BAGEL on Hugging Face"
485
+ />
486
+ </a>
487
+ <a href="https://demo.bagel-ai.org/">
488
+ <img
489
+ src="https://img.shields.io/badge/BAGEL-Demo-blue?logo=googleplay&logoColor=blue"
490
+ alt="BAGEL Demo"
491
+ />
492
+ </a>
493
+ <a href="https://discord.gg/Z836xxzy">
494
+ <img
495
+ src="https://img.shields.io/badge/BAGEL-Discord-5865F2?logo=discord&logoColor=purple"
496
+ alt="BAGEL Discord"
497
+ />
498
+ </a>
499
+ <a href="mailto:[email protected]">
500
+ <img
501
+ src="https://img.shields.io/badge/BAGEL-Email-D14836?logo=gmail&logoColor=red"
502
+ alt="BAGEL Email"
503
+ />
504
+ </a>
505
+ </div>
506
+ """)
507
+
508
+ demo.launch()
modeling/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
  from . import bagel, qwen2, siglip, autoencoder
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
  from . import bagel, qwen2, siglip, autoencoder
modeling/autoencoder.py CHANGED
@@ -1,361 +1,361 @@
1
- # Copyright (c) 2024 Black Forest Labs.
2
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
- #
7
- # Original file was released under Apache-2.0, with the full license text
8
- # available at https://github.com/black-forest-labs/flux/blob/main/LICENSE.
9
- #
10
- # This modified file is released under the same license.
11
-
12
- from dataclasses import dataclass
13
-
14
- import torch
15
- from einops import rearrange
16
- from torch import Tensor, nn
17
- from huggingface_hub import hf_hub_download
18
- from safetensors.torch import load_file as load_sft
19
-
20
-
21
- @dataclass
22
- class AutoEncoderParams:
23
- resolution: int
24
- in_channels: int
25
- downsample: int
26
- ch: int
27
- out_ch: int
28
- ch_mult: list[int]
29
- num_res_blocks: int
30
- z_channels: int
31
- scale_factor: float
32
- shift_factor: float
33
-
34
-
35
- def swish(x: Tensor) -> Tensor:
36
- return x * torch.sigmoid(x)
37
-
38
-
39
- class AttnBlock(nn.Module):
40
- def __init__(self, in_channels: int):
41
- super().__init__()
42
- self.in_channels = in_channels
43
-
44
- self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
45
-
46
- self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
47
- self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
48
- self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
49
- self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
50
-
51
- def attention(self, h_: Tensor) -> Tensor:
52
- h_ = self.norm(h_)
53
- q = self.q(h_)
54
- k = self.k(h_)
55
- v = self.v(h_)
56
-
57
- b, c, h, w = q.shape
58
- q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
59
- k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
60
- v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
61
- h_ = nn.functional.scaled_dot_product_attention(q, k, v)
62
-
63
- return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
64
-
65
- def forward(self, x: Tensor) -> Tensor:
66
- return x + self.proj_out(self.attention(x))
67
-
68
-
69
- class ResnetBlock(nn.Module):
70
- def __init__(self, in_channels: int, out_channels: int):
71
- super().__init__()
72
- self.in_channels = in_channels
73
- out_channels = in_channels if out_channels is None else out_channels
74
- self.out_channels = out_channels
75
-
76
- self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
77
- self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
78
- self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
79
- self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
80
- if self.in_channels != self.out_channels:
81
- self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
82
-
83
- def forward(self, x):
84
- h = x
85
- h = self.norm1(h)
86
- h = swish(h)
87
- h = self.conv1(h)
88
-
89
- h = self.norm2(h)
90
- h = swish(h)
91
- h = self.conv2(h)
92
-
93
- if self.in_channels != self.out_channels:
94
- x = self.nin_shortcut(x)
95
-
96
- return x + h
97
-
98
-
99
- class Downsample(nn.Module):
100
- def __init__(self, in_channels: int):
101
- super().__init__()
102
- # no asymmetric padding in torch conv, must do it ourselves
103
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
104
-
105
- def forward(self, x: Tensor):
106
- pad = (0, 1, 0, 1)
107
- x = nn.functional.pad(x, pad, mode="constant", value=0)
108
- x = self.conv(x)
109
- return x
110
-
111
-
112
- class Upsample(nn.Module):
113
- def __init__(self, in_channels: int):
114
- super().__init__()
115
- self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
116
-
117
- def forward(self, x: Tensor):
118
- x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
119
- x = self.conv(x)
120
- return x
121
-
122
-
123
- class Encoder(nn.Module):
124
- def __init__(
125
- self,
126
- resolution: int,
127
- in_channels: int,
128
- ch: int,
129
- ch_mult: list[int],
130
- num_res_blocks: int,
131
- z_channels: int,
132
- ):
133
- super().__init__()
134
- self.ch = ch
135
- self.num_resolutions = len(ch_mult)
136
- self.num_res_blocks = num_res_blocks
137
- self.resolution = resolution
138
- self.in_channels = in_channels
139
- # downsampling
140
- self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
141
-
142
- curr_res = resolution
143
- in_ch_mult = (1,) + tuple(ch_mult)
144
- self.in_ch_mult = in_ch_mult
145
- self.down = nn.ModuleList()
146
- block_in = self.ch
147
- for i_level in range(self.num_resolutions):
148
- block = nn.ModuleList()
149
- attn = nn.ModuleList()
150
- block_in = ch * in_ch_mult[i_level]
151
- block_out = ch * ch_mult[i_level]
152
- for _ in range(self.num_res_blocks):
153
- block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
154
- block_in = block_out
155
- down = nn.Module()
156
- down.block = block
157
- down.attn = attn
158
- if i_level != self.num_resolutions - 1:
159
- down.downsample = Downsample(block_in)
160
- curr_res = curr_res // 2
161
- self.down.append(down)
162
-
163
- # middle
164
- self.mid = nn.Module()
165
- self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
166
- self.mid.attn_1 = AttnBlock(block_in)
167
- self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
168
-
169
- # end
170
- self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
171
- self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
172
-
173
- def forward(self, x: Tensor) -> Tensor:
174
- # downsampling
175
- hs = [self.conv_in(x)]
176
- for i_level in range(self.num_resolutions):
177
- for i_block in range(self.num_res_blocks):
178
- h = self.down[i_level].block[i_block](hs[-1])
179
- if len(self.down[i_level].attn) > 0:
180
- h = self.down[i_level].attn[i_block](h)
181
- hs.append(h)
182
- if i_level != self.num_resolutions - 1:
183
- hs.append(self.down[i_level].downsample(hs[-1]))
184
-
185
- # middle
186
- h = hs[-1]
187
- h = self.mid.block_1(h)
188
- h = self.mid.attn_1(h)
189
- h = self.mid.block_2(h)
190
- # end
191
- h = self.norm_out(h)
192
- h = swish(h)
193
- h = self.conv_out(h)
194
- return h
195
-
196
-
197
- class Decoder(nn.Module):
198
- def __init__(
199
- self,
200
- ch: int,
201
- out_ch: int,
202
- ch_mult: list[int],
203
- num_res_blocks: int,
204
- in_channels: int,
205
- resolution: int,
206
- z_channels: int,
207
- ):
208
- super().__init__()
209
- self.ch = ch
210
- self.num_resolutions = len(ch_mult)
211
- self.num_res_blocks = num_res_blocks
212
- self.resolution = resolution
213
- self.in_channels = in_channels
214
- self.ffactor = 2 ** (self.num_resolutions - 1)
215
-
216
- # compute in_ch_mult, block_in and curr_res at lowest res
217
- block_in = ch * ch_mult[self.num_resolutions - 1]
218
- curr_res = resolution // 2 ** (self.num_resolutions - 1)
219
- self.z_shape = (1, z_channels, curr_res, curr_res)
220
-
221
- # z to block_in
222
- self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
223
-
224
- # middle
225
- self.mid = nn.Module()
226
- self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
227
- self.mid.attn_1 = AttnBlock(block_in)
228
- self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
229
-
230
- # upsampling
231
- self.up = nn.ModuleList()
232
- for i_level in reversed(range(self.num_resolutions)):
233
- block = nn.ModuleList()
234
- attn = nn.ModuleList()
235
- block_out = ch * ch_mult[i_level]
236
- for _ in range(self.num_res_blocks + 1):
237
- block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
238
- block_in = block_out
239
- up = nn.Module()
240
- up.block = block
241
- up.attn = attn
242
- if i_level != 0:
243
- up.upsample = Upsample(block_in)
244
- curr_res = curr_res * 2
245
- self.up.insert(0, up) # prepend to get consistent order
246
-
247
- # end
248
- self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
249
- self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
250
-
251
- def forward(self, z: Tensor) -> Tensor:
252
- # z to block_in
253
- h = self.conv_in(z)
254
-
255
- # middle
256
- h = self.mid.block_1(h)
257
- h = self.mid.attn_1(h)
258
- h = self.mid.block_2(h)
259
-
260
- # upsampling
261
- for i_level in reversed(range(self.num_resolutions)):
262
- for i_block in range(self.num_res_blocks + 1):
263
- h = self.up[i_level].block[i_block](h)
264
- if len(self.up[i_level].attn) > 0:
265
- h = self.up[i_level].attn[i_block](h)
266
- if i_level != 0:
267
- h = self.up[i_level].upsample(h)
268
-
269
- # end
270
- h = self.norm_out(h)
271
- h = swish(h)
272
- h = self.conv_out(h)
273
- return h
274
-
275
-
276
- class DiagonalGaussian(nn.Module):
277
- def __init__(self, sample: bool = True, chunk_dim: int = 1):
278
- super().__init__()
279
- self.sample = sample
280
- self.chunk_dim = chunk_dim
281
-
282
- def forward(self, z: Tensor) -> Tensor:
283
- mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
284
- if self.sample:
285
- std = torch.exp(0.5 * logvar)
286
- return mean + std * torch.randn_like(mean)
287
- else:
288
- return mean
289
-
290
-
291
- class AutoEncoder(nn.Module):
292
- def __init__(self, params: AutoEncoderParams):
293
- super().__init__()
294
- self.encoder = Encoder(
295
- resolution=params.resolution,
296
- in_channels=params.in_channels,
297
- ch=params.ch,
298
- ch_mult=params.ch_mult,
299
- num_res_blocks=params.num_res_blocks,
300
- z_channels=params.z_channels,
301
- )
302
- self.decoder = Decoder(
303
- resolution=params.resolution,
304
- in_channels=params.in_channels,
305
- ch=params.ch,
306
- out_ch=params.out_ch,
307
- ch_mult=params.ch_mult,
308
- num_res_blocks=params.num_res_blocks,
309
- z_channels=params.z_channels,
310
- )
311
- self.reg = DiagonalGaussian()
312
-
313
- self.scale_factor = params.scale_factor
314
- self.shift_factor = params.shift_factor
315
-
316
- def encode(self, x: Tensor) -> Tensor:
317
- z = self.reg(self.encoder(x))
318
- z = self.scale_factor * (z - self.shift_factor)
319
- return z
320
-
321
- def decode(self, z: Tensor) -> Tensor:
322
- z = z / self.scale_factor + self.shift_factor
323
- return self.decoder(z)
324
-
325
- def forward(self, x: Tensor) -> Tensor:
326
- return self.decode(self.encode(x))
327
-
328
-
329
- def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
330
- if len(missing) > 0 and len(unexpected) > 0:
331
- print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
332
- print("\n" + "-" * 79 + "\n")
333
- print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
334
- elif len(missing) > 0:
335
- print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
336
- elif len(unexpected) > 0:
337
- print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
338
-
339
-
340
- def load_ae(local_path: str) -> AutoEncoder:
341
- ae_params = AutoEncoderParams(
342
- resolution=256,
343
- in_channels=3,
344
- downsample=8,
345
- ch=128,
346
- out_ch=3,
347
- ch_mult=[1, 2, 4, 4],
348
- num_res_blocks=2,
349
- z_channels=16,
350
- scale_factor=0.3611,
351
- shift_factor=0.1159,
352
- )
353
-
354
- # Loading the autoencoder
355
- ae = AutoEncoder(ae_params)
356
-
357
- if local_path is not None:
358
- sd = load_sft(local_path)
359
- missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
360
- print_load_warning(missing, unexpected)
361
- return ae, ae_params
 
1
+ # Copyright (c) 2024 Black Forest Labs.
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under Apache-2.0, with the full license text
8
+ # available at https://github.com/black-forest-labs/flux/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ from einops import rearrange
16
+ from torch import Tensor, nn
17
+ from huggingface_hub import hf_hub_download
18
+ from safetensors.torch import load_file as load_sft
19
+
20
+
21
+ @dataclass
22
+ class AutoEncoderParams:
23
+ resolution: int
24
+ in_channels: int
25
+ downsample: int
26
+ ch: int
27
+ out_ch: int
28
+ ch_mult: list[int]
29
+ num_res_blocks: int
30
+ z_channels: int
31
+ scale_factor: float
32
+ shift_factor: float
33
+
34
+
35
+ def swish(x: Tensor) -> Tensor:
36
+ return x * torch.sigmoid(x)
37
+
38
+
39
+ class AttnBlock(nn.Module):
40
+ def __init__(self, in_channels: int):
41
+ super().__init__()
42
+ self.in_channels = in_channels
43
+
44
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
45
+
46
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
47
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
48
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
49
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
50
+
51
+ def attention(self, h_: Tensor) -> Tensor:
52
+ h_ = self.norm(h_)
53
+ q = self.q(h_)
54
+ k = self.k(h_)
55
+ v = self.v(h_)
56
+
57
+ b, c, h, w = q.shape
58
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
59
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
60
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
61
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
62
+
63
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
64
+
65
+ def forward(self, x: Tensor) -> Tensor:
66
+ return x + self.proj_out(self.attention(x))
67
+
68
+
69
+ class ResnetBlock(nn.Module):
70
+ def __init__(self, in_channels: int, out_channels: int):
71
+ super().__init__()
72
+ self.in_channels = in_channels
73
+ out_channels = in_channels if out_channels is None else out_channels
74
+ self.out_channels = out_channels
75
+
76
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
77
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
78
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
79
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
80
+ if self.in_channels != self.out_channels:
81
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
82
+
83
+ def forward(self, x):
84
+ h = x
85
+ h = self.norm1(h)
86
+ h = swish(h)
87
+ h = self.conv1(h)
88
+
89
+ h = self.norm2(h)
90
+ h = swish(h)
91
+ h = self.conv2(h)
92
+
93
+ if self.in_channels != self.out_channels:
94
+ x = self.nin_shortcut(x)
95
+
96
+ return x + h
97
+
98
+
99
+ class Downsample(nn.Module):
100
+ def __init__(self, in_channels: int):
101
+ super().__init__()
102
+ # no asymmetric padding in torch conv, must do it ourselves
103
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
104
+
105
+ def forward(self, x: Tensor):
106
+ pad = (0, 1, 0, 1)
107
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
108
+ x = self.conv(x)
109
+ return x
110
+
111
+
112
+ class Upsample(nn.Module):
113
+ def __init__(self, in_channels: int):
114
+ super().__init__()
115
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
116
+
117
+ def forward(self, x: Tensor):
118
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
119
+ x = self.conv(x)
120
+ return x
121
+
122
+
123
+ class Encoder(nn.Module):
124
+ def __init__(
125
+ self,
126
+ resolution: int,
127
+ in_channels: int,
128
+ ch: int,
129
+ ch_mult: list[int],
130
+ num_res_blocks: int,
131
+ z_channels: int,
132
+ ):
133
+ super().__init__()
134
+ self.ch = ch
135
+ self.num_resolutions = len(ch_mult)
136
+ self.num_res_blocks = num_res_blocks
137
+ self.resolution = resolution
138
+ self.in_channels = in_channels
139
+ # downsampling
140
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
141
+
142
+ curr_res = resolution
143
+ in_ch_mult = (1,) + tuple(ch_mult)
144
+ self.in_ch_mult = in_ch_mult
145
+ self.down = nn.ModuleList()
146
+ block_in = self.ch
147
+ for i_level in range(self.num_resolutions):
148
+ block = nn.ModuleList()
149
+ attn = nn.ModuleList()
150
+ block_in = ch * in_ch_mult[i_level]
151
+ block_out = ch * ch_mult[i_level]
152
+ for _ in range(self.num_res_blocks):
153
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
154
+ block_in = block_out
155
+ down = nn.Module()
156
+ down.block = block
157
+ down.attn = attn
158
+ if i_level != self.num_resolutions - 1:
159
+ down.downsample = Downsample(block_in)
160
+ curr_res = curr_res // 2
161
+ self.down.append(down)
162
+
163
+ # middle
164
+ self.mid = nn.Module()
165
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
166
+ self.mid.attn_1 = AttnBlock(block_in)
167
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
168
+
169
+ # end
170
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
171
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
172
+
173
+ def forward(self, x: Tensor) -> Tensor:
174
+ # downsampling
175
+ hs = [self.conv_in(x)]
176
+ for i_level in range(self.num_resolutions):
177
+ for i_block in range(self.num_res_blocks):
178
+ h = self.down[i_level].block[i_block](hs[-1])
179
+ if len(self.down[i_level].attn) > 0:
180
+ h = self.down[i_level].attn[i_block](h)
181
+ hs.append(h)
182
+ if i_level != self.num_resolutions - 1:
183
+ hs.append(self.down[i_level].downsample(hs[-1]))
184
+
185
+ # middle
186
+ h = hs[-1]
187
+ h = self.mid.block_1(h)
188
+ h = self.mid.attn_1(h)
189
+ h = self.mid.block_2(h)
190
+ # end
191
+ h = self.norm_out(h)
192
+ h = swish(h)
193
+ h = self.conv_out(h)
194
+ return h
195
+
196
+
197
+ class Decoder(nn.Module):
198
+ def __init__(
199
+ self,
200
+ ch: int,
201
+ out_ch: int,
202
+ ch_mult: list[int],
203
+ num_res_blocks: int,
204
+ in_channels: int,
205
+ resolution: int,
206
+ z_channels: int,
207
+ ):
208
+ super().__init__()
209
+ self.ch = ch
210
+ self.num_resolutions = len(ch_mult)
211
+ self.num_res_blocks = num_res_blocks
212
+ self.resolution = resolution
213
+ self.in_channels = in_channels
214
+ self.ffactor = 2 ** (self.num_resolutions - 1)
215
+
216
+ # compute in_ch_mult, block_in and curr_res at lowest res
217
+ block_in = ch * ch_mult[self.num_resolutions - 1]
218
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
219
+ self.z_shape = (1, z_channels, curr_res, curr_res)
220
+
221
+ # z to block_in
222
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
223
+
224
+ # middle
225
+ self.mid = nn.Module()
226
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
227
+ self.mid.attn_1 = AttnBlock(block_in)
228
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
229
+
230
+ # upsampling
231
+ self.up = nn.ModuleList()
232
+ for i_level in reversed(range(self.num_resolutions)):
233
+ block = nn.ModuleList()
234
+ attn = nn.ModuleList()
235
+ block_out = ch * ch_mult[i_level]
236
+ for _ in range(self.num_res_blocks + 1):
237
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
238
+ block_in = block_out
239
+ up = nn.Module()
240
+ up.block = block
241
+ up.attn = attn
242
+ if i_level != 0:
243
+ up.upsample = Upsample(block_in)
244
+ curr_res = curr_res * 2
245
+ self.up.insert(0, up) # prepend to get consistent order
246
+
247
+ # end
248
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
249
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
250
+
251
+ def forward(self, z: Tensor) -> Tensor:
252
+ # z to block_in
253
+ h = self.conv_in(z)
254
+
255
+ # middle
256
+ h = self.mid.block_1(h)
257
+ h = self.mid.attn_1(h)
258
+ h = self.mid.block_2(h)
259
+
260
+ # upsampling
261
+ for i_level in reversed(range(self.num_resolutions)):
262
+ for i_block in range(self.num_res_blocks + 1):
263
+ h = self.up[i_level].block[i_block](h)
264
+ if len(self.up[i_level].attn) > 0:
265
+ h = self.up[i_level].attn[i_block](h)
266
+ if i_level != 0:
267
+ h = self.up[i_level].upsample(h)
268
+
269
+ # end
270
+ h = self.norm_out(h)
271
+ h = swish(h)
272
+ h = self.conv_out(h)
273
+ return h
274
+
275
+
276
+ class DiagonalGaussian(nn.Module):
277
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
278
+ super().__init__()
279
+ self.sample = sample
280
+ self.chunk_dim = chunk_dim
281
+
282
+ def forward(self, z: Tensor) -> Tensor:
283
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
284
+ if self.sample:
285
+ std = torch.exp(0.5 * logvar)
286
+ return mean + std * torch.randn_like(mean)
287
+ else:
288
+ return mean
289
+
290
+
291
+ class AutoEncoder(nn.Module):
292
+ def __init__(self, params: AutoEncoderParams):
293
+ super().__init__()
294
+ self.encoder = Encoder(
295
+ resolution=params.resolution,
296
+ in_channels=params.in_channels,
297
+ ch=params.ch,
298
+ ch_mult=params.ch_mult,
299
+ num_res_blocks=params.num_res_blocks,
300
+ z_channels=params.z_channels,
301
+ )
302
+ self.decoder = Decoder(
303
+ resolution=params.resolution,
304
+ in_channels=params.in_channels,
305
+ ch=params.ch,
306
+ out_ch=params.out_ch,
307
+ ch_mult=params.ch_mult,
308
+ num_res_blocks=params.num_res_blocks,
309
+ z_channels=params.z_channels,
310
+ )
311
+ self.reg = DiagonalGaussian()
312
+
313
+ self.scale_factor = params.scale_factor
314
+ self.shift_factor = params.shift_factor
315
+
316
+ def encode(self, x: Tensor) -> Tensor:
317
+ z = self.reg(self.encoder(x))
318
+ z = self.scale_factor * (z - self.shift_factor)
319
+ return z
320
+
321
+ def decode(self, z: Tensor) -> Tensor:
322
+ z = z / self.scale_factor + self.shift_factor
323
+ return self.decoder(z)
324
+
325
+ def forward(self, x: Tensor) -> Tensor:
326
+ return self.decode(self.encode(x))
327
+
328
+
329
+ def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
330
+ if len(missing) > 0 and len(unexpected) > 0:
331
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
332
+ print("\n" + "-" * 79 + "\n")
333
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
334
+ elif len(missing) > 0:
335
+ print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
336
+ elif len(unexpected) > 0:
337
+ print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
338
+
339
+
340
+ def load_ae(local_path: str) -> AutoEncoder:
341
+ ae_params = AutoEncoderParams(
342
+ resolution=256,
343
+ in_channels=3,
344
+ downsample=8,
345
+ ch=128,
346
+ out_ch=3,
347
+ ch_mult=[1, 2, 4, 4],
348
+ num_res_blocks=2,
349
+ z_channels=16,
350
+ scale_factor=0.3611,
351
+ shift_factor=0.1159,
352
+ )
353
+
354
+ # Loading the autoencoder
355
+ ae = AutoEncoder(ae_params)
356
+
357
+ if local_path is not None:
358
+ sd = load_sft(local_path)
359
+ missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
360
+ print_load_warning(missing, unexpected)
361
+ return ae, ae_params
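
The AutoEncoder added above is self-contained, so its round trip is easy to sanity-check. The snippet below is an illustrative sketch, not part of this commit: the checkpoint path is a placeholder and the dummy input shape simply matches the configured resolution of 256.

import torch

# Hypothetical smoke test for the AutoEncoder / load_ae defined above (sketch only).
ae, ae_params = load_ae(local_path="/path/to/ae.safetensors")  # placeholder path
ae = ae.eval()

x = torch.randn(1, 3, 256, 256)  # (B, C, H, W) at the configured resolution
with torch.no_grad():
    z = ae.encode(x)      # sampled, scaled and shifted latent; roughly (1, 16, 32, 32) here
    x_rec = ae.decode(z)  # reconstruction back to (1, 3, 256, 256)
print(z.shape, x_rec.shape)
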
modeling/bagel/__init__.py CHANGED
@@ -1,18 +1,18 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
-
5
- from .bagel import BagelConfig, Bagel
6
- from .qwen2_navit import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
7
- from .siglip_navit import SiglipVisionConfig, SiglipVisionModel
8
-
9
-
10
- __all__ = [
11
- 'BagelConfig',
12
- 'Bagel',
13
- 'Qwen2Config',
14
- 'Qwen2Model',
15
- 'Qwen2ForCausalLM',
16
- 'SiglipVisionConfig',
17
- 'SiglipVisionModel',
18
- ]
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ from .bagel import BagelConfig, Bagel
6
+ from .qwen2_navit import Qwen2Config, Qwen2Model, Qwen2ForCausalLM
7
+ from .siglip_navit import SiglipVisionConfig, SiglipVisionModel
8
+
9
+
10
+ __all__ = [
11
+ 'BagelConfig',
12
+ 'Bagel',
13
+ 'Qwen2Config',
14
+ 'Qwen2Model',
15
+ 'Qwen2ForCausalLM',
16
+ 'SiglipVisionConfig',
17
+ 'SiglipVisionModel',
18
+ ]
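
The bagel.py diff below trains image generation with a flow-matching style objective (see Bagel.forward): noisy latents are a linear blend of clean latents and Gaussian noise under a shifted timestep, and the regression target is the constant velocity noise - clean. The sketch below is illustrative only; the tensor shapes are made up.

import torch

def shift_timesteps(t: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
    # Same warp used in Bagel.forward: shift * t / (1 + (shift - 1) * t)
    return shift * t / (1 + (shift - 1) * t)

clean = torch.randn(10, 64)      # stand-in for packed clean VAE latents
noise = torch.randn_like(clean)
t = shift_timesteps(torch.sigmoid(torch.randn(10)))

x_t = (1 - t[:, None]) * clean + t[:, None] * noise  # noisy latents fed to the LLM
target_v = noise - clean                             # v_t = dx_t/dt, pointing from data to noise
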
modeling/bagel/bagel.py CHANGED
@@ -1,1026 +1,1026 @@
1
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- import copy
5
- from typing import List, Tuple, Optional
6
- import matplotlib.pyplot as plt
7
-
8
- from PIL import Image
9
- import torch
10
- import torch.nn.functional as F
11
- from torch import nn
12
- from torch.nn.attention.flex_attention import create_block_mask
13
- from transformers.configuration_utils import PretrainedConfig
14
- from transformers.modeling_utils import PreTrainedModel
15
-
16
- from data.data_utils import (
17
- create_sparse_mask,
18
- get_flattened_position_ids_extrapolate,
19
- get_flattened_position_ids_interpolate,
20
- patchify,
21
- )
22
- from .qwen2_navit import NaiveCache
23
- from .modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
24
-
25
-
26
- class BagelConfig(PretrainedConfig):
27
- def __init__(
28
- self,
29
- visual_gen=True,
30
- visual_und=True,
31
- llm_config=None,
32
- vit_config=None,
33
- vae_config=None,
34
- latent_patch_size=2,
35
- max_latent_size=32,
36
- vit_max_num_patch_per_side=70,
37
- connector_act="gelu_pytorch_tanh",
38
- interpolate_pos=False,
39
- timestep_shift=1.0,
40
- **kwargs
41
- ):
42
- super().__init__(**kwargs)
43
- self.visual_gen = visual_gen
44
- self.visual_und = visual_und
45
- self.llm_config = llm_config
46
- self.vit_config = vit_config
47
- self.vae_config = vae_config
48
- self.latent_patch_size = latent_patch_size
49
- self.max_latent_size = max_latent_size
50
- self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
51
- self.connector_act = connector_act
52
- self.interpolate_pos = interpolate_pos
53
- self.timestep_shift = timestep_shift
54
-
55
-
56
- class Bagel(PreTrainedModel):
57
- config_class = BagelConfig
58
- base_model_prefix = 'bagel'
59
-
60
- def __init__(self, language_model, vit_model, config: BagelConfig):
61
- super().__init__(config)
62
- self.language_model = language_model
63
- self.hidden_size = config.llm_config.hidden_size
64
- self.use_moe = "Mo" in config.llm_config.layer_module
65
- self.num_heads = config.llm_config.num_attention_heads
66
-
67
- if config.visual_gen:
68
- self.latent_patch_size = config.latent_patch_size
69
- self.timestep_shift = config.timestep_shift
70
- self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
71
- self.max_latent_size = config.max_latent_size
72
- self.latent_channel = config.vae_config.z_channels
73
- self.patch_latent_dim = self.latent_patch_size ** 2 * self.latent_channel
74
- self.time_embedder = TimestepEmbedder(self.hidden_size)
75
- self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
76
- self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
77
- self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)
78
-
79
- if config.visual_und:
80
- self.vit_model = vit_model
81
- self.vit_patch_size = config.vit_config.patch_size
82
- self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
83
- self.vit_hidden_size = config.vit_config.hidden_size
84
- self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
85
- self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)
86
-
87
- if config.interpolate_pos:
88
- self.get_flattened_position_ids = get_flattened_position_ids_interpolate
89
- else:
90
- self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
91
-
92
- self.config = config
93
- self._init_weights()
94
-
95
- def _init_weights(self):
96
- if self.config.visual_gen:
97
- nn.init.constant_(self.llm2vae.weight, 0)
98
- nn.init.constant_(self.llm2vae.bias, 0)
99
-
100
- def forward(
101
- self,
102
- sequence_length: int,
103
- packed_text_ids: torch.LongTensor,
104
- packed_text_indexes: torch.LongTensor,
105
- sample_lens: List[int],
106
- packed_position_ids: torch.LongTensor,
107
- nested_attention_masks: List[torch.Tensor] = None,
108
- split_lens: List[int] = None,
109
- attn_modes: List[str] = None,
110
- # for visual understanding
111
- ce_loss_indexes: Optional[torch.BoolTensor] = None,
112
- packed_label_ids: Optional[torch.LongTensor] = None,
113
- packed_vit_tokens: Optional[torch.Tensor] = None,
114
- packed_vit_token_indexes: Optional[torch.LongTensor] = None,
115
- packed_vit_position_ids: Optional[torch.LongTensor] = None,
116
- vit_token_seqlens: Optional[torch.IntTensor] = None,
117
- # for visual generation
118
- padded_latent: Optional[torch.Tensor] = None,
119
- patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
120
- packed_latent_position_ids: Optional[torch.LongTensor] = None,
121
- packed_vae_token_indexes: Optional[torch.LongTensor] = None,
122
- packed_timesteps: Optional[torch.LongTensor] = None,
123
- mse_loss_indexes: Optional[torch.BoolTensor] = None,
124
- ) -> torch.Tensor:
125
- """
126
- Args:
127
- sequence_length: length of sequence.
128
- packed_text_ids: 1-D int tensor, packed text token ids.
129
- packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
130
- sample_lens: A list of N ints, length of each sample in packed_sequence.
131
- nested_attention_masks: A list of N 2-D float tensors, where 0.0 means attention and
132
- -inf means ignore.
133
- packed_position_ids: packed 1-D positions, an image has only one global position shared
134
- by all latent tokens.
135
-
136
- packed_vit_tokens: packed patchified image tokens for vit model.
137
- packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
138
- packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
139
- vit_token_seqlens: 1-D int tensor, the number of vit tokens for each image.
140
- packed_label_ids: 1-D int tensor, packed label token ids.
141
- ce_loss_indexes: 1-D bool tensor, where to compute ce loss.
142
-
143
- padded_latent: padded latent from VAE encoder.
144
- patchified_vae_latent_shapes: A list of (h, w) tuples, patchified latent shapes of each image.
145
- packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
146
- packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence.
147
- packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
148
- mse_loss_indexes: 1-D bool tensor, where to compute mse loss.
149
- """
150
- packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
151
- packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size))
152
- packed_sequence[packed_text_indexes] = packed_text_embedding
153
-
154
- if nested_attention_masks is None:
155
- sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, packed_text_embedding.device)
156
- seqlen = sum(sample_lens)
157
- block_mask = create_block_mask(
158
- sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen,
159
- device=packed_text_embedding.device, BLOCK_SIZE=128, _compile=True
160
- )
161
- attention_mask = block_mask
162
- else:
163
- attention_mask = nested_attention_masks
164
-
165
- if self.config.visual_und:
166
- cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
167
- cu_seqlens = cu_seqlens.to(torch.int32)
168
- max_seqlen = torch.max(vit_token_seqlens).item()
169
- packed_vit_token_embed = self.vit_model(
170
- packed_pixel_values=packed_vit_tokens,
171
- packed_flattened_position_ids=packed_vit_position_ids,
172
- cu_seqlens=cu_seqlens,
173
- max_seqlen=max_seqlen,
174
- )
175
- packed_vit_token_embed = self.connector(packed_vit_token_embed)
176
- vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
177
- packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
178
- packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
179
-
180
- if self.config.visual_gen:
181
- p = self.latent_patch_size
182
- packed_latent = []
183
- for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
184
- latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
185
- latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
186
- packed_latent.append(latent)
187
- packed_latent_clean = torch.cat(packed_latent, dim=0)
188
-
189
- noise = torch.randn_like(packed_latent_clean)
190
- packed_timesteps = torch.sigmoid(packed_timesteps)
191
- packed_timesteps = self.timestep_shift * packed_timesteps / (1 + (self.timestep_shift - 1) * packed_timesteps)
192
- packed_latent = (1 - packed_timesteps[:, None]) * packed_latent_clean + packed_timesteps[:, None] * noise
193
- packed_timestep_embeds = self.time_embedder(packed_timesteps)
194
- latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
195
- packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb
196
- packed_sequence[packed_vae_token_indexes] = packed_latent
197
-
198
- extra_inputs = {}
199
- if self.use_moe:
200
- packed_und_token_indexes = packed_text_indexes
201
- if packed_vit_token_indexes is not None:
202
- packed_und_token_indexes=torch.cat([packed_text_indexes, packed_vit_token_indexes], dim=0)
203
- extra_inputs.update(
204
- packed_und_token_indexes=packed_und_token_indexes,
205
- packed_gen_token_indexes=packed_vae_token_indexes,
206
- )
207
-
208
- last_hidden_state = self.language_model(
209
- packed_sequence=packed_sequence,
210
- sample_lens=sample_lens,
211
- attention_mask=attention_mask,
212
- packed_position_ids=packed_position_ids,
213
- **extra_inputs,
214
- )
215
-
216
- mse = None
217
- if self.config.visual_gen:
218
- packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
219
- target = noise - packed_latent_clean # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
220
- has_mse = packed_timesteps > 0
221
- mse = (packed_mse_preds - target[has_mse]) ** 2
222
-
223
- ce = None
224
- if ce_loss_indexes is not None:
225
- packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes])
226
- ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")
227
-
228
- return dict(mse=mse, ce=ce)
229
-
230
-
231
- def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
232
- packed_text_ids = list()
233
- packed_text_position_ids = list()
234
- text_token_lens = list()
235
- packed_text_indexes = list()
236
- packed_key_value_indexes = list()
237
-
238
- curr = 0
239
- newlens, new_rope = list(), list()
240
- for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope):
241
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
242
- curr += curr_kvlen
243
-
244
- text_ids = tokenizer.encode(prompt)
245
- text_ids = [new_token_ids['bos_token_id']] + text_ids + [new_token_ids['eos_token_id']]
246
- text_token_lens.append(len(text_ids))
247
- packed_text_ids.extend(text_ids)
248
- packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids)))
249
- packed_text_indexes.extend(range(curr, curr + len(text_ids)))
250
- newlens.append(curr_kvlen + len(text_ids))
251
- new_rope.append(curr_position_id + len(text_ids))
252
- curr += len(text_ids)
253
-
254
- generation_input = {
255
- "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
256
- "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
257
- "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long),
258
- "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
259
- "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
260
- "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
261
- }
262
-
263
- return generation_input, newlens, new_rope
264
-
265
- @torch.no_grad
266
- def forward_cache_update_text(
267
- self,
268
- past_key_values: NaiveCache,
269
- packed_text_ids: torch.IntTensor,
270
- packed_text_position_ids: torch.LongTensor,
271
- text_token_lens: torch.LongTensor,
272
- packed_text_indexes: torch.LongTensor,
273
- packed_key_value_indexes: torch.LongTensor,
274
- key_values_lens: torch.IntTensor,
275
- ):
276
- packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
277
-
278
- extra_inputs = {}
279
- if self.use_moe:
280
- extra_inputs = {"mode": "und"}
281
-
282
- output = self.language_model.forward_inference(
283
- packed_query_sequence=packed_text_embedding,
284
- query_lens=text_token_lens,
285
- packed_query_position_ids=packed_text_position_ids,
286
- packed_query_indexes=packed_text_indexes,
287
- past_key_values=past_key_values,
288
- packed_key_value_indexes=packed_key_value_indexes,
289
- key_values_lens=key_values_lens,
290
- update_past_key_values=True,
291
- is_causal=True,
292
- **extra_inputs,
293
- )
294
- past_key_values = output.past_key_values
295
-
296
- return past_key_values
297
-
298
- def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
299
- packed_vit_token_indexes = list()
300
- vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list()
301
- packed_text_ids, packed_text_indexes = list(), list()
302
- packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
303
- packed_key_value_indexes = list()
304
-
305
- _curr = curr = 0
306
- newlens, new_rope = list(), list()
307
- for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
308
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
309
- curr += curr_kvlen
310
-
311
- packed_text_ids.append(new_token_ids['start_of_image'])
312
- packed_text_indexes.append(_curr)
313
- packed_indexes.append(curr)
314
- curr += 1
315
- _curr += 1
316
-
317
- image_tensor = transforms(image)
318
- vit_position_ids = self.get_flattened_position_ids(
319
- image_tensor.size(1), image_tensor.size(2),
320
- self.vit_patch_size,
321
- max_num_patches_per_side=self.vit_max_num_patch_per_side
322
- )
323
- vit_tokens = patchify(image_tensor, self.vit_patch_size)
324
- packed_vit_tokens.append(vit_tokens)
325
- num_img_tokens = vit_tokens.shape[0]
326
- packed_vit_position_ids.append(vit_position_ids)
327
- vit_token_seqlens.append(num_img_tokens)
328
- packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
329
- packed_indexes.extend(range(curr, curr + num_img_tokens))
330
- curr += num_img_tokens
331
- _curr += num_img_tokens
332
-
333
- packed_text_ids.append(new_token_ids['end_of_image'])
334
- packed_text_indexes.append(_curr)
335
- packed_indexes.append(curr)
336
- curr += 1
337
- _curr += 1
338
-
339
- packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
340
- packed_seqlens.append(num_img_tokens + 2)
341
- newlens.append(curr_kvlen + num_img_tokens + 2)
342
- new_rope.append(curr_position_id + 1)
343
-
344
- generation_input = {
345
- "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
346
- "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
347
- "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
348
- "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
349
- "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
350
- "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
351
- "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
352
- "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
353
- "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
354
- "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
355
- "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
356
- }
357
-
358
- return generation_input, newlens, new_rope
359
-
360
- @torch.no_grad
361
- def forward_cache_update_vit(
362
- self,
363
- past_key_values: NaiveCache,
364
- packed_text_ids: torch.LongTensor,
365
- packed_text_indexes: torch.LongTensor,
366
- packed_vit_tokens: torch.Tensor,
367
- packed_vit_token_indexes: torch.LongTensor,
368
- packed_vit_position_ids: torch.LongTensor,
369
- vit_token_seqlens: torch.IntTensor,
370
- packed_position_ids: torch.LongTensor,
371
- packed_seqlens: torch.IntTensor,
372
- packed_indexes: torch.LongTensor,
373
- packed_key_value_indexes: torch.LongTensor,
374
- key_values_lens: torch.IntTensor,
375
- ):
376
- packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
377
- packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
378
- packed_sequence[packed_text_indexes] = packed_text_embedding
379
-
380
- cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
381
- cu_seqlens = cu_seqlens.to(torch.int32)
382
- max_seqlen = torch.max(vit_token_seqlens).item()
383
- packed_vit_token_embed = self.vit_model(
384
- packed_pixel_values=packed_vit_tokens,
385
- packed_flattened_position_ids=packed_vit_position_ids,
386
- cu_seqlens=cu_seqlens,
387
- max_seqlen=max_seqlen,
388
- )
389
- packed_vit_token_embed = self.connector(packed_vit_token_embed)
390
- pos_emb = self.vit_pos_embed(packed_vit_position_ids)
391
- packed_vit_token_embed = packed_vit_token_embed + pos_emb
392
- packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
393
-
394
- extra_inputs = {}
395
- if self.use_moe:
396
- extra_inputs = {"mode": "und"}
397
-
398
- output = self.language_model.forward_inference(
399
- packed_query_sequence=packed_sequence,
400
- query_lens=packed_seqlens,
401
- packed_query_position_ids=packed_position_ids,
402
- packed_query_indexes=packed_indexes,
403
- past_key_values=past_key_values,
404
- packed_key_value_indexes=packed_key_value_indexes,
405
- key_values_lens=key_values_lens,
406
- update_past_key_values=True,
407
- is_causal=False,
408
- **extra_inputs,
409
- )
410
- past_key_values = output.past_key_values
411
-
412
- return past_key_values
413
-
414
- def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
415
- patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
416
- packed_vae_token_indexes = list()
417
- packed_text_ids, packed_text_indexes = list(), list()
418
- packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
419
- packed_key_value_indexes = list()
420
-
421
- _curr = curr = 0
422
- vae_image_tensors = list()
423
- newlens, new_rope = list(), list()
424
- for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
425
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
426
- curr += curr_kvlen
427
-
428
- packed_text_ids.append(new_token_ids['start_of_image'])
429
- packed_text_indexes.append(_curr)
430
- packed_indexes.append(curr)
431
- curr += 1
432
- _curr += 1
433
-
434
- image_tensor = transforms(image)
435
- vae_image_tensors.append(image_tensor)
436
- vae_posiiton_ids = self.get_flattened_position_ids(
437
- image_tensor.size(1), image_tensor.size(2),
438
- self.latent_downsample,
439
- max_num_patches_per_side=self.max_latent_size
440
- )
441
- packed_vae_position_ids.append(vae_posiiton_ids)
442
- H, W = image_tensor.shape[1:]
443
- h = H // self.latent_downsample
444
- w = W // self.latent_downsample
445
- patchified_vae_latent_shapes.append((h, w))
446
-
447
- num_img_tokens = w * h
448
- packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
449
- packed_indexes.extend(range(curr, curr + num_img_tokens))
450
- curr += num_img_tokens
451
- _curr += num_img_tokens
452
-
453
- packed_text_ids.append(new_token_ids['end_of_image'])
454
- packed_text_indexes.append(_curr)
455
- packed_indexes.append(curr)
456
- curr += 1
457
- _curr += 1
458
-
459
- packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
460
- packed_seqlens.append(num_img_tokens + 2)
461
- newlens.append(curr_kvlen + num_img_tokens + 2)
462
- new_rope.append(curr_position_id + 1)
463
-
464
- image_sizes = [item.shape for item in vae_image_tensors]
465
- max_image_size = [max(item) for item in list(zip(*image_sizes))]
466
- padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
467
- for i, image_tensor in enumerate(vae_image_tensors):
468
- padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
469
-
470
- generation_input = {
471
- "padded_images": padded_images,
472
- "patchified_vae_latent_shapes": patchified_vae_latent_shapes,
473
- "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
474
- "packed_timesteps": torch.tensor([timestep]),
475
- "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
476
- "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
477
- "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
478
- "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
479
- "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
480
- "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
481
- "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
482
- "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
483
- }
484
-
485
- return generation_input, newlens, new_rope
486
-
487
- @torch.no_grad
488
- def forward_cache_update_vae(
489
- self,
490
- vae_model,
491
- past_key_values: NaiveCache,
492
- padded_images: torch.Tensor,
493
- patchified_vae_latent_shapes: List,
494
- packed_vae_position_ids: torch.LongTensor,
495
- packed_timesteps: torch.Tensor,
496
- packed_vae_token_indexes: torch.LongTensor,
497
- packed_text_ids: torch.LongTensor,
498
- packed_text_indexes: torch.LongTensor,
499
- packed_position_ids: torch.LongTensor,
500
- packed_seqlens: torch.IntTensor,
501
- packed_indexes: torch.LongTensor,
502
- key_values_lens: torch.IntTensor,
503
- packed_key_value_indexes: torch.Tensor,
504
- ):
505
- packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
506
- packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
507
- packed_sequence[packed_text_indexes] = packed_text_embedding
508
-
509
- padded_latent = vae_model.encode(padded_images)
510
-
511
- p = self.latent_patch_size
512
- packed_latent = list()
513
- for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
514
- latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
515
- latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
516
- packed_latent.append(latent)
517
- packed_latent = torch.cat(packed_latent, dim=0)
518
- packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
519
- packed_timestep_embeds = self.time_embedder(packed_timesteps)
520
- packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
521
- packed_sequence[packed_vae_token_indexes] = packed_latent
522
-
523
- extra_inputs = {}
524
- if self.use_moe:
525
- extra_inputs = {
526
- "mode": "gen",
527
- "packed_vae_token_indexes": packed_vae_token_indexes,
528
- "packed_text_indexes": packed_text_indexes
529
- }
530
-
531
- output = self.language_model.forward_inference(
532
- packed_query_sequence=packed_sequence,
533
- query_lens=packed_seqlens,
534
- packed_query_position_ids=packed_position_ids,
535
- packed_query_indexes=packed_indexes,
536
- past_key_values=past_key_values,
537
- key_values_lens=key_values_lens,
538
- packed_key_value_indexes=packed_key_value_indexes,
539
- update_past_key_values=True,
540
- is_causal=False,
541
- **extra_inputs,
542
- )
543
- past_key_values = output.past_key_values
544
-
545
- return past_key_values
546
-
547
- def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
548
- packed_text_ids, packed_text_indexes = list(), list()
549
- packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list()
550
- packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list()
551
- packed_key_value_indexes = list()
552
-
553
- query_curr = curr = 0
554
- for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
555
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
556
- curr += curr_kvlen
557
-
558
- packed_text_ids.append(new_token_ids['start_of_image'])
559
- packed_text_indexes.append(query_curr)
560
- packed_indexes.append(curr)
561
- curr += 1
562
- query_curr += 1
563
-
564
- vae_posiiton_ids = self.get_flattened_position_ids(
565
- H, W,
566
- self.latent_downsample,
567
- max_num_patches_per_side=self.max_latent_size
568
- )
569
- packed_vae_position_ids.append(vae_posiiton_ids)
570
-
571
- h, w = H // self.latent_downsample, W // self.latent_downsample
572
- num_image_tokens = h * w
573
- packed_init_noises.append(
574
- torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size ** 2)
575
- )
576
- packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens))
577
- packed_indexes.extend(range(curr, curr + num_image_tokens))
578
- curr += num_image_tokens
579
- query_curr += num_image_tokens
580
-
581
- packed_text_ids.append(new_token_ids['end_of_image'])
582
- packed_text_indexes.append(query_curr)
583
- packed_indexes.append(curr)
584
- curr += 1
585
- query_curr += 1
586
-
587
- packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
588
- packed_seqlens.append(num_image_tokens + 2)
589
-
590
- generation_input = {
591
- "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
592
- "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
593
- "packed_init_noises": torch.cat(packed_init_noises, dim=0),
594
- "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
595
- "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
596
- "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
597
- "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
598
- "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
599
- "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
600
- "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
601
- }
602
-
603
- return generation_input
604
-
605
- def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
606
- packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list()
607
-
608
- query_curr = curr = 0
609
- for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
610
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
611
- curr += curr_kvlen
612
-
613
- packed_indexes.append(curr)
614
- curr += 1
615
- query_curr += 1
616
-
617
- h, w = H // self.latent_downsample, W // self.latent_downsample
618
- num_image_tokens = h * w
619
- packed_indexes.extend(range(curr, curr + num_image_tokens))
620
- curr += num_image_tokens
621
- query_curr += num_image_tokens
622
-
623
- packed_indexes.append(curr)
624
- curr += 1
625
- query_curr += 1
626
-
627
- packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
628
-
629
- generation_input = {
630
- "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
631
- "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
632
- "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
633
- "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
634
- }
635
-
636
- return generation_input
637
-
638
- @torch.no_grad
639
- def generate_image(
640
- self,
641
- packed_text_ids: torch.LongTensor,
642
- packed_text_indexes: torch.LongTensor,
643
- packed_init_noises: torch.Tensor,
644
- packed_vae_position_ids: torch.LongTensor,
645
- packed_vae_token_indexes: torch.LongTensor,
646
- packed_seqlens: torch.IntTensor,
647
- packed_position_ids: torch.LongTensor,
648
- packed_indexes: torch.LongTensor,
649
- past_key_values: NaiveCache,
650
- key_values_lens: torch.IntTensor,
651
- packed_key_value_indexes: torch.LongTensor,
652
- num_timesteps: int = 24,
653
- timestep_shift: float = 1.0,
654
- cfg_renorm_min: float = 0.0,
655
- cfg_renorm_type: str = "global",
656
- cfg_interval: Optional[Tuple[float, float]] = [0, 1],
657
- # cfg_text
658
- cfg_text_scale: float = 1.0,
659
- cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
660
- cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
661
- cfg_text_past_key_values: Optional[NaiveCache] = None,
662
- cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
663
- cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
664
- # cfg_img
665
- cfg_img_scale: float = 1.0,
666
- cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
667
- cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
668
- cfg_img_past_key_values: Optional[NaiveCache] = None,
669
- cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
670
- cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
671
- cfg_type: str = "parallel",
672
- ):
673
- x_t = packed_init_noises
674
-
675
- timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
676
- timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
677
- dts = timesteps[:-1] - timesteps[1:]
678
- timesteps = timesteps[:-1]
679
-
680
- for i, t in enumerate(timesteps):
681
-
682
- timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
683
- if t > cfg_interval[0] and t <= cfg_interval[1]:
684
- cfg_text_scale_ = cfg_text_scale
685
- cfg_img_scale_ = cfg_img_scale
686
- else:
687
- cfg_text_scale_ = 1.0
688
- cfg_img_scale_ = 1.0
689
- v_t = self._forward_flow(
690
- x_t=x_t,
691
- timestep=timestep,
692
- packed_vae_token_indexes=packed_vae_token_indexes,
693
- packed_vae_position_ids=packed_vae_position_ids,
694
- packed_text_ids=packed_text_ids,
695
- packed_text_indexes=packed_text_indexes,
696
- packed_position_ids=packed_position_ids,
697
- packed_indexes=packed_indexes,
698
- packed_seqlens=packed_seqlens,
699
- key_values_lens=key_values_lens,
700
- past_key_values=past_key_values,
701
- packed_key_value_indexes=packed_key_value_indexes,
702
- cfg_renorm_min=cfg_renorm_min,
703
- cfg_renorm_type=cfg_renorm_type,
704
- # cfg_text
705
- cfg_text_scale=cfg_text_scale_,
706
- cfg_text_packed_position_ids=cfg_text_packed_position_ids,
707
- cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
708
- cfg_text_key_values_lens=cfg_text_key_values_lens,
709
- cfg_text_past_key_values=cfg_text_past_key_values,
710
- cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
711
- # cfg_img
712
- cfg_img_scale=cfg_img_scale_,
713
- cfg_img_packed_position_ids=cfg_img_packed_position_ids,
714
- cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
715
- cfg_img_key_values_lens=cfg_img_key_values_lens,
716
- cfg_img_past_key_values=cfg_img_past_key_values,
717
- cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
718
- cfg_type=cfg_type,
719
- )
720
-
721
- x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise
722
-
723
- unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
724
- return unpacked_latent
725
-
726
- @torch.no_grad
727
- def _forward_flow(
728
- self,
729
- x_t: torch.Tensor,
730
- timestep: torch.LongTensor,
731
- packed_vae_token_indexes: torch.LongTensor,
732
- packed_vae_position_ids: torch.LongTensor,
733
- packed_text_ids: torch.LongTensor,
734
- packed_text_indexes: torch.LongTensor,
735
- packed_indexes: torch.LongTensor,
736
- packed_position_ids: torch.LongTensor,
737
- packed_seqlens: torch.IntTensor,
738
- key_values_lens: torch.IntTensor,
739
- past_key_values: NaiveCache,
740
- packed_key_value_indexes: torch.LongTensor,
741
- cfg_renorm_min: float = 0.0,
742
- cfg_renorm_type: str = "global",
743
- # cfg_text
744
- cfg_text_scale: float = 1.0,
745
- cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
746
- cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
747
- cfg_text_key_values_lens: Optional[torch.Tensor] = None,
748
- cfg_text_past_key_values: Optional[NaiveCache] = None,
749
- cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
750
- # cfg_img
751
- cfg_img_scale: float = 1.0,
752
- cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
753
- cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
754
- cfg_img_key_values_lens: Optional[torch.Tensor] = None,
755
- cfg_img_past_key_values: Optional[NaiveCache] = None,
756
- cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
757
- cfg_type: str = "parallel",
758
- ):
759
- packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
760
- packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
761
- packed_sequence[packed_text_indexes] = packed_text_embedding
762
-
763
- assert timestep.unique().shape[0] == 1
764
- packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
765
- packed_timestep_embeds = self.time_embedder(timestep)
766
- x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
767
- packed_sequence[packed_vae_token_indexes] = x_t
768
-
769
- extra_inputs = {}
770
- if self.use_moe:
771
- extra_inputs = {
772
- "mode": "gen",
773
- "packed_vae_token_indexes": packed_vae_token_indexes,
774
- "packed_text_indexes": packed_text_indexes
775
- }
776
-
777
- output = self.language_model.forward_inference(
778
- packed_query_sequence=packed_sequence,
779
- query_lens=packed_seqlens,
780
- packed_query_position_ids=packed_position_ids,
781
- packed_query_indexes=packed_indexes,
782
- past_key_values=past_key_values,
783
- key_values_lens=key_values_lens,
784
- packed_key_value_indexes=packed_key_value_indexes,
785
- update_past_key_values=False,
786
- is_causal=False,
787
- **extra_inputs,
788
- )
789
- v_t = self.llm2vae(output.packed_query_sequence)
790
- v_t = v_t[packed_vae_token_indexes]
791
-
792
- if cfg_text_scale > 1.0:
793
- cfg_text_output = self.language_model.forward_inference(
794
- packed_query_sequence=packed_sequence,
795
- query_lens=packed_seqlens,
796
- packed_query_position_ids=cfg_text_packed_position_ids,
797
- packed_query_indexes=cfg_text_packed_query_indexes,
798
- past_key_values=cfg_text_past_key_values,
799
- key_values_lens=cfg_text_key_values_lens,
800
- packed_key_value_indexes=cfg_text_packed_key_value_indexes,
801
- update_past_key_values=False,
802
- is_causal=False,
803
- **extra_inputs,
804
- )
805
- cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
806
- cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]
807
-
808
- if cfg_img_scale > 1.0:
809
- cfg_img_output = self.language_model.forward_inference(
810
- packed_query_sequence=packed_sequence,
811
- query_lens=packed_seqlens,
812
- packed_query_position_ids=cfg_img_packed_position_ids,
813
- packed_query_indexes=cfg_img_packed_query_indexes,
814
- past_key_values=cfg_img_past_key_values,
815
- key_values_lens=cfg_img_key_values_lens,
816
- packed_key_value_indexes=cfg_img_packed_key_value_indexes,
817
- update_past_key_values=False,
818
- is_causal=False,
819
- **extra_inputs,
820
- )
821
- cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
822
- cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]
823
-
824
- if cfg_text_scale > 1.0:
825
- if cfg_renorm_type == "text_channel":
826
- v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
827
- norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
828
- norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
829
- scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
830
- v_t_text = v_t_text_ * scale
831
- if cfg_img_scale > 1.0:
832
- v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
833
- else:
834
- v_t = v_t_text
835
- else:
836
- v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
837
-
838
- if cfg_img_scale > 1.0:
839
- v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
840
- else:
841
- v_t_ = v_t_text_
842
-
843
- # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
844
- if cfg_renorm_type == "global":
845
- norm_v_t = torch.norm(v_t)
846
- norm_v_t_ = torch.norm(v_t_)
847
- elif cfg_renorm_type == "channel":
848
- norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
849
- norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
850
- else:
851
- raise NotImplementedError(f"{cfg_renorm_type} is not supported")
852
- scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
853
- v_t = v_t_ * scale
854
- else:
855
- # No CFG
856
- pass
857
-
858
- return v_t
859
-
860
- def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
861
- packed_start_tokens, packed_key_value_indexes = list(), list()
862
- packed_query_position_ids = list()
863
-
864
- curr = 0
865
- for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
866
- packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
867
- packed_start_tokens.append(new_token_ids['bos_token_id'])
868
- packed_query_position_ids.append(curr_position_id)
869
- curr += curr_kvlen
870
-
871
- generation_input = {
872
- "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
873
- "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long),
874
- "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
875
- "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
876
- }
877
-
878
- return generation_input
879
-
880
- @torch.no_grad
881
- def generate_text(
882
- self,
883
- past_key_values: NaiveCache,
884
- packed_key_value_indexes: torch.LongTensor,
885
- key_values_lens: torch.IntTensor,
886
- packed_start_tokens: torch.LongTensor,
887
- packed_query_position_ids: torch.LongTensor,
888
- max_length: int,
889
- do_sample: bool = False,
890
- temperature: float = 1.0,
891
- end_token_id: int = None,
892
- ):
893
- step = 0
894
- generated_sequence = []
895
- curr_tokens = packed_start_tokens
896
- while step < max_length:
897
- generated_sequence.append(curr_tokens)
898
- packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
899
- query_lens = torch.ones_like(curr_tokens)
900
- packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
901
- 0, len(key_values_lens),
902
- device=key_values_lens.device,
903
- dtype=key_values_lens.dtype
904
- )
905
-
906
- uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
907
- for i in range(len(uppacked)):
908
- uppacked[i] += i
909
- packed_key_value_indexes = torch.cat(uppacked, dim=0)
910
-
911
- extra_inputs = {}
912
- if self.use_moe:
913
- extra_inputs = {"mode": "und"}
914
-
915
- output = self.language_model.forward_inference(
916
- packed_query_sequence=packed_text_embedding,
917
- query_lens=query_lens,
918
- packed_query_position_ids=packed_query_position_ids,
919
- packed_query_indexes=packed_query_indexes,
920
- past_key_values=past_key_values,
921
- key_values_lens=key_values_lens,
922
- packed_key_value_indexes=packed_key_value_indexes,
923
- update_past_key_values=True,
924
- is_causal=True,
925
- **extra_inputs,
926
- )
927
- past_key_values = output.past_key_values
928
- packed_query_sequence = output.packed_query_sequence
929
- pred_logits = self.language_model.lm_head(packed_query_sequence)
930
-
931
- if do_sample:
932
- probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
933
- curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
934
- else:
935
- curr_tokens = torch.argmax(pred_logits, dim=-1)
936
-
937
- uppacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
938
- for i in range(len(uppacked)):
939
- uppacked[i] = torch.cat(
940
- [uppacked[i], torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device)], dim=0
941
- )
942
- packed_key_value_indexes = torch.cat(uppacked, dim=0)
943
- key_values_lens = key_values_lens + 1
944
- packed_query_position_ids = packed_query_position_ids + 1
945
- step += 1
946
-
947
- if end_token_id is not None and curr_tokens[0] == end_token_id: # only support batch=1
948
- break
949
-
950
- output_device = generated_sequence[0].device
951
- return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)
952
-
953
- # for evaluation
954
- @torch.no_grad()
955
- def chat(
956
- self,
957
- tokenizer,
958
- new_token_ids,
959
- image_transform,
960
- images,
961
- prompt,
962
- max_length: int,
963
- do_sample: bool = False,
964
- temperature: float = 1.0,
965
- ):
966
- device = next(self.parameters()).device
967
-
968
- if isinstance(new_token_ids, dict):
969
- for k, v in new_token_ids.items():
970
- if torch.is_tensor(v):
971
- new_token_ids[k] = v.to(device)
972
- elif torch.is_tensor(new_token_ids):
973
- new_token_ids = new_token_ids.to(device)
974
-
975
- # prefill
976
- past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers)
977
- newlens = [0]
978
- new_rope = [0]
979
-
980
- # add images
981
- for image in images:
982
- generation_input, newlens, new_rope = self.prepare_vit_images(
983
- curr_kvlens=newlens,
984
- curr_rope=new_rope,
985
- images=[image],
986
- transforms=image_transform,
987
- new_token_ids=new_token_ids,
988
- )
989
- for k, v in generation_input.items():
990
- if torch.is_tensor(v):
991
- generation_input[k] = v.to(device)
992
- with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
993
- past_key_values = self.forward_cache_update_vit(past_key_values, **generation_input)
994
-
995
- # add text
996
- generation_input, newlens, new_rope = self.prepare_prompts(
997
- curr_kvlens=newlens,
998
- curr_rope=new_rope,
999
- prompts=[prompt],
1000
- tokenizer=tokenizer,
1001
- new_token_ids=new_token_ids,
1002
- )
1003
- for k, v in generation_input.items():
1004
- if torch.is_tensor(v):
1005
- generation_input[k] = v.to(device)
1006
- with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1007
- past_key_values = self.forward_cache_update_text(past_key_values, **generation_input)
1008
-
1009
- # decode
1010
- generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids)
1011
- for k, v in generation_input.items():
1012
- if torch.is_tensor(v):
1013
- generation_input[k] = v.to(device)
1014
- with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1015
- unpacked_latent = self.generate_text(
1016
- past_key_values=past_key_values,
1017
- max_length=max_length,
1018
- do_sample=do_sample,
1019
- temperature=temperature,
1020
- end_token_id=new_token_ids['eos_token_id'],
1021
- **generation_input,
1022
- )
1023
- output = tokenizer.decode(unpacked_latent[:,0])
1024
- output = output.split('<|im_end|>')[0].split('<|im_start|>')[1]
1025
-
1026
- return output
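
For reference, the classifier-free guidance path in _forward_flow above first mixes the conditional and unconditional velocity predictions and then renormalizes the result. A minimal sketch of the "global" renorm branch (illustrative only, not code from this commit):

import torch

def guide_and_renorm(v_cond: torch.Tensor, v_uncond: torch.Tensor,
                     scale: float, renorm_min: float = 0.0) -> torch.Tensor:
    # v' = v_uncond + scale * (v_cond - v_uncond), then rescale so that the guided
    # velocity's global norm does not exceed that of the conditional prediction.
    v_guided = v_uncond + scale * (v_cond - v_uncond)
    s = (torch.norm(v_cond) / (torch.norm(v_guided) + 1e-8)).clamp(min=renorm_min, max=1.0)
    return v_guided * s
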
 
1
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import copy
5
+ from typing import List, Tuple, Optional
6
+ import matplotlib.pyplot as plt
7
+
8
+ from PIL import Image
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import nn
12
+ from torch.nn.attention.flex_attention import create_block_mask
13
+ from transformers.configuration_utils import PretrainedConfig
14
+ from transformers.modeling_utils import PreTrainedModel
15
+
16
+ from data.data_utils import (
17
+ create_sparse_mask,
18
+ get_flattened_position_ids_extrapolate,
19
+ get_flattened_position_ids_interpolate,
20
+ patchify,
21
+ )
22
+ from .qwen2_navit import NaiveCache
23
+ from .modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
24
+
25
+
26
+ class BagelConfig(PretrainedConfig):
27
+ def __init__(
28
+ self,
29
+ visual_gen=True,
30
+ visual_und=True,
31
+ llm_config=None,
32
+ vit_config=None,
33
+ vae_config=None,
34
+ latent_patch_size=2,
35
+ max_latent_size=32,
36
+ vit_max_num_patch_per_side=70,
37
+ connector_act="gelu_pytorch_tanh",
38
+ interpolate_pos=False,
39
+ timestep_shift=1.0,
40
+ **kwargs
41
+ ):
42
+ super().__init__(**kwargs)
43
+ self.visual_gen = visual_gen
44
+ self.visual_und = visual_und
45
+ self.llm_config = llm_config
46
+ self.vit_config = vit_config
47
+ self.vae_config = vae_config
48
+ self.latent_patch_size = latent_patch_size
49
+ self.max_latent_size = max_latent_size
50
+ self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
51
+ self.connector_act = connector_act
52
+ self.interpolate_pos = interpolate_pos
53
+ self.timestep_shift = timestep_shift
54
+
55
+
56
+ class Bagel(PreTrainedModel):
57
+ config_class = BagelConfig
58
+ base_model_prefix = 'bagel'
59
+
60
+ def __init__(self, language_model, vit_model, config: BagelConfig):
61
+ super().__init__(config)
62
+ self.language_model = language_model
63
+ self.hidden_size = config.llm_config.hidden_size
64
+ self.use_moe = "Mo" in config.llm_config.layer_module
65
+ self.num_heads = config.llm_config.num_attention_heads
66
+
67
+ if config.visual_gen:
68
+ self.latent_patch_size = config.latent_patch_size
69
+ self.timestep_shift = config.timestep_shift
70
+ self.latent_downsample = config.vae_config.downsample * config.latent_patch_size
71
+ self.max_latent_size = config.max_latent_size
72
+ self.latent_channel = config.vae_config.z_channels
73
+ self.patch_latent_dim = self.latent_patch_size ** 2 * self.latent_channel
74
+ self.time_embedder = TimestepEmbedder(self.hidden_size)
75
+ self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
76
+ self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
77
+ self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)
78
+
79
+ if config.visual_und:
80
+ self.vit_model = vit_model
81
+ self.vit_patch_size = config.vit_config.patch_size
82
+ self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
83
+ self.vit_hidden_size = config.vit_config.hidden_size
84
+ self.connector = MLPconnector(self.vit_hidden_size, self.hidden_size, config.connector_act)
85
+ self.vit_pos_embed = PositionEmbedding(self.vit_max_num_patch_per_side, self.hidden_size)
86
+
87
+ if config.interpolate_pos:
88
+ self.get_flattened_position_ids = get_flattened_position_ids_interpolate
89
+ else:
90
+ self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
91
+
92
+ self.config = config
93
+ self._init_weights()
94
+
95
+ def _init_weights(self):
96
+ if self.config.visual_gen:
97
+ nn.init.constant_(self.llm2vae.weight, 0)
98
+ nn.init.constant_(self.llm2vae.bias, 0)
99
+
100
+ def forward(
101
+ self,
102
+ sequence_length: int,
103
+ packed_text_ids: torch.LongTensor,
104
+ packed_text_indexes: torch.LongTensor,
105
+ sample_lens: List[int],
106
+ packed_position_ids: torch.LongTensor,
107
+ nested_attention_masks: List[torch.Tensor] = None,
108
+ split_lens: List[int] = None,
109
+ attn_modes: List[str] = None,
110
+ # for visual understanding
111
+ ce_loss_indexes: Optional[torch.BoolTensor] = None,
112
+ packed_label_ids: Optional[torch.LongTensor] = None,
113
+ packed_vit_tokens: Optional[torch.Tensor] = None,
114
+ packed_vit_token_indexes: Optional[torch.LongTensor] = None,
115
+ packed_vit_position_ids: Optional[torch.LongTensor] = None,
116
+ vit_token_seqlens: Optional[torch.IntTensor] = None,
117
+ # for visual generation
118
+ padded_latent: Optional[torch.Tensor] = None,
119
+ patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
120
+ packed_latent_position_ids: Optional[torch.LongTensor] = None,
121
+ packed_vae_token_indexes: Optional[torch.LongTensor] = None,
122
+ packed_timesteps: Optional[torch.LongTensor] = None,
123
+ mse_loss_indexes: Optional[torch.BoolTensor] = None,
124
+ ) -> torch.Tensor:
125
+ """
126
+ Args:
127
+ sequence_length: length of sequence.
128
+ packed_text_ids: 1-D int tensor, packed text token ids.
129
+ packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
130
+ sample_lens: A list of N ints, length of each sample in packed_sequence.
131
+ nested_attention_masks: A list of N 2-D float tensors, where 0.0 means attention and
132
+ -inf means ignore.
133
+ packed_position_ids: packed 1-D positions, an image has only one global position shared
134
+ by all latent tokens.
135
+
136
+ packed_vit_tokens: packed patchified image tokens for vit model.
137
+ packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
138
+ packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
139
+ vit_token_seqlens: 1-D int tensor, the number of vit tokens for each image.
140
+ packed_label_ids: 1-D int tensor, packed label token ids.
141
+ ce_loss_indexes: 1-D bool tensor, where to compute ce loss.
142
+
143
+ padded_latent: padded latent from VAE encoder.
144
+ patchified_vae_latent_shapes: A list of (h, w) tuples, patchified latent shapes of each image.
145
+ packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
146
+ packed_vae_token_indexes: 1-D int tensor, packed vae token indexes in sequence.
147
+ packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
148
+ mse_loss_indexes: 1-D bool tensor, where to compute mse loss.
149
+ """
150
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
151
+ packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size))
152
+ packed_sequence[packed_text_indexes] = packed_text_embedding
153
+
154
+ if nested_attention_masks is None:
155
+ sparse_mask = create_sparse_mask(sample_lens, split_lens, attn_modes, packed_text_embedding.device)
156
+ seqlen = sum(sample_lens)
157
+ block_mask = create_block_mask(
158
+ sparse_mask, B=1, H=self.num_heads, Q_LEN=seqlen, KV_LEN=seqlen,
159
+ device=packed_text_embedding.device, BLOCK_SIZE=128, _compile=True
160
+ )
161
+ attention_mask = block_mask
162
+ else:
163
+ attention_mask = nested_attention_masks
164
+
165
+ if self.config.visual_und:
166
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
167
+ cu_seqlens = cu_seqlens.to(torch.int32)
168
+ max_seqlen = torch.max(vit_token_seqlens).item()
169
+ packed_vit_token_embed = self.vit_model(
170
+ packed_pixel_values=packed_vit_tokens,
171
+ packed_flattened_position_ids=packed_vit_position_ids,
172
+ cu_seqlens=cu_seqlens,
173
+ max_seqlen=max_seqlen,
174
+ )
175
+ packed_vit_token_embed = self.connector(packed_vit_token_embed)
176
+ vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
177
+ packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
178
+ packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
179
+
180
+ if self.config.visual_gen:
181
+ p = self.latent_patch_size
182
+ packed_latent = []
183
+ for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
184
+ latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
185
+ latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
186
+ packed_latent.append(latent)
187
+ packed_latent_clean = torch.cat(packed_latent, dim=0)
188
+
189
+ noise = torch.randn_like(packed_latent_clean)
190
+ packed_timesteps = torch.sigmoid(packed_timesteps)
191
+ packed_timesteps = self.timestep_shift * packed_timesteps / (1 + (self.timestep_shift - 1) * packed_timesteps)
192
+ packed_latent = (1 - packed_timesteps[:, None]) * packed_latent_clean + packed_timesteps[:, None] * noise
193
+ packed_timestep_embeds = self.time_embedder(packed_timesteps)
194
+ latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
195
+ packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb
196
+ packed_sequence[packed_vae_token_indexes] = packed_latent
197
+
198
+ extra_inputs = {}
199
+ if self.use_moe:
200
+ packed_und_token_indexes = packed_text_indexes
201
+ if packed_vit_token_indexes is not None:
202
+ packed_und_token_indexes = torch.cat([packed_text_indexes, packed_vit_token_indexes], dim=0)
203
+ extra_inputs.update(
204
+ packed_und_token_indexes=packed_und_token_indexes,
205
+ packed_gen_token_indexes=packed_vae_token_indexes,
206
+ )
207
+
208
+ last_hidden_state = self.language_model(
209
+ packed_sequence=packed_sequence,
210
+ sample_lens=sample_lens,
211
+ attention_mask=attention_mask,
212
+ packed_position_ids=packed_position_ids,
213
+ **extra_inputs,
214
+ )
215
+
216
+ mse = None
217
+ if self.config.visual_gen:
218
+ packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
219
+ target = noise - packed_latent_clean # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
220
+ has_mse = packed_timesteps > 0
221
+ mse = (packed_mse_preds - target[has_mse]) ** 2
222
+
223
+ ce = None
224
+ if ce_loss_indexes is not None:
225
+ packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes])
226
+ ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")
227
+
228
+ return dict(mse=mse, ce=ce)
229
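For reference, the visual_gen branch above implements rectified-flow noising: raw timesteps are squashed through a sigmoid, warped by timestep_shift, and the packed clean latent is interpolated toward Gaussian noise; the regression target is noise minus clean latent (the velocity pointing from data to noise). A minimal, self-contained sketch with made-up shapes and a made-up shift value:

import torch

def noise_packed_latent(latent_clean, raw_timesteps, timestep_shift=1.0):
    # Toy re-implementation of the noising step in the visual_gen branch above.
    t = torch.sigmoid(raw_timesteps)                            # map raw values into (0, 1)
    t = timestep_shift * t / (1 + (timestep_shift - 1) * t)     # warp the schedule
    noise = torch.randn_like(latent_clean)
    x_t = (1 - t[:, None]) * latent_clean + t[:, None] * noise  # interpolate data -> noise
    target = noise - latent_clean                               # v_t = x_1 - x_0
    return x_t, target, t

x_t, target, t = noise_packed_latent(torch.randn(6, 64), torch.randn(6), timestep_shift=4.0)
print(x_t.shape, target.shape, t.shape)  # torch.Size([6, 64]) torch.Size([6, 64]) torch.Size([6])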
+
230
+
231
+ def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
232
+ packed_text_ids = list()
233
+ packed_text_position_ids = list()
234
+ text_token_lens = list()
235
+ packed_text_indexes = list()
236
+ packed_key_value_indexes = list()
237
+
238
+ curr = 0
239
+ newlens, new_rope = list(), list()
240
+ for prompt, curr_kvlen, curr_position_id in zip(prompts, curr_kvlens, curr_rope):
241
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
242
+ curr += curr_kvlen
243
+
244
+ text_ids = tokenizer.encode(prompt)
245
+ text_ids = [new_token_ids['bos_token_id']] + text_ids + [new_token_ids['eos_token_id']]
246
+ text_token_lens.append(len(text_ids))
247
+ packed_text_ids.extend(text_ids)
248
+ packed_text_position_ids.extend(range(curr_position_id, curr_position_id + len(text_ids)))
249
+ packed_text_indexes.extend(range(curr, curr + len(text_ids)))
250
+ newlens.append(curr_kvlen + len(text_ids))
251
+ new_rope.append(curr_position_id + len(text_ids))
252
+ curr += len(text_ids)
253
+
254
+ generation_input = {
255
+ "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
256
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
257
+ "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long),
258
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
259
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
260
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
261
+ }
262
+
263
+ return generation_input, newlens, new_rope
264
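A small pure-Python sketch of the bookkeeping prepare_prompts does above: each prompt is wrapped in BOS/EOS, its tokens are appended after that sample's existing KV cache slots, and the per-sample cache lengths and rope positions are advanced. The toy encode function and token ids below are stand-ins, not the repo's tokenizer.

def pack_prompts(prompts, curr_kvlens, curr_rope, encode, bos_id, eos_id):
    # Mirrors the index bookkeeping of prepare_prompts with plain lists.
    packed_ids, packed_pos, packed_idx, kv_idx, lens = [], [], [], [], []
    newlens, new_rope, curr = [], [], 0
    for prompt, kvlen, pos in zip(prompts, curr_kvlens, curr_rope):
        kv_idx.extend(range(curr, curr + kvlen))         # slots already held by the KV cache
        curr += kvlen
        ids = [bos_id] + encode(prompt) + [eos_id]
        lens.append(len(ids))
        packed_ids.extend(ids)
        packed_pos.extend(range(pos, pos + len(ids)))    # rope advances one step per text token
        packed_idx.extend(range(curr, curr + len(ids)))  # new slots appended after the cache
        newlens.append(kvlen + len(ids))
        new_rope.append(pos + len(ids))
        curr += len(ids)
    return packed_ids, packed_pos, packed_idx, kv_idx, lens, newlens, new_rope

out = pack_prompts(["hi", "ok"], [4, 7], [4, 7], lambda s: [ord(c) for c in s], 1, 2)
print(out[0])           # [1, 104, 105, 2, 1, 111, 107, 2]
print(out[5], out[6])   # [8, 11] [8, 11]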
+
265
+ @torch.no_grad
266
+ def forward_cache_update_text(
267
+ self,
268
+ past_key_values: NaiveCache,
269
+ packed_text_ids: torch.IntTensor,
270
+ packed_text_position_ids: torch.LongTensor,
271
+ text_token_lens: torch.LongTensor,
272
+ packed_text_indexes: torch.LongTensor,
273
+ packed_key_value_indexes: torch.LongTensor,
274
+ key_values_lens: torch.IntTensor,
275
+ ):
276
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
277
+
278
+ extra_inputs = {}
279
+ if self.use_moe:
280
+ extra_inputs = {"mode": "und"}
281
+
282
+ output = self.language_model.forward_inference(
283
+ packed_query_sequence=packed_text_embedding,
284
+ query_lens=text_token_lens,
285
+ packed_query_position_ids=packed_text_position_ids,
286
+ packed_query_indexes=packed_text_indexes,
287
+ past_key_values=past_key_values,
288
+ packed_key_value_indexes=packed_key_value_indexes,
289
+ key_values_lens=key_values_lens,
290
+ update_past_key_values=True,
291
+ is_causal=True,
292
+ **extra_inputs,
293
+ )
294
+ past_key_values = output.past_key_values
295
+
296
+ return past_key_values
297
+
298
+ def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
299
+ packed_vit_token_indexes = list()
300
+ vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = list(), list(), list()
301
+ packed_text_ids, packed_text_indexes = list(), list()
302
+ packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
303
+ packed_key_value_indexes = list()
304
+
305
+ _curr = curr = 0
306
+ newlens, new_rope = list(), list()
307
+ for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
308
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
309
+ curr += curr_kvlen
310
+
311
+ packed_text_ids.append(new_token_ids['start_of_image'])
312
+ packed_text_indexes.append(_curr)
313
+ packed_indexes.append(curr)
314
+ curr += 1
315
+ _curr += 1
316
+
317
+ image_tensor = transforms(image)
318
+ vit_position_ids = self.get_flattened_position_ids(
319
+ image_tensor.size(1), image_tensor.size(2),
320
+ self.vit_patch_size,
321
+ max_num_patches_per_side=self.vit_max_num_patch_per_side
322
+ )
323
+ vit_tokens = patchify(image_tensor, self.vit_patch_size)
324
+ packed_vit_tokens.append(vit_tokens)
325
+ num_img_tokens = vit_tokens.shape[0]
326
+ packed_vit_position_ids.append(vit_position_ids)
327
+ vit_token_seqlens.append(num_img_tokens)
328
+ packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
329
+ packed_indexes.extend(range(curr, curr + num_img_tokens))
330
+ curr += num_img_tokens
331
+ _curr += num_img_tokens
332
+
333
+ packed_text_ids.append(new_token_ids['end_of_image'])
334
+ packed_text_indexes.append(_curr)
335
+ packed_indexes.append(curr)
336
+ curr += 1
337
+ _curr += 1
338
+
339
+ packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
340
+ packed_seqlens.append(num_img_tokens + 2)
341
+ newlens.append(curr_kvlen + num_img_tokens + 2)
342
+ new_rope.append(curr_position_id + 1)
343
+
344
+ generation_input = {
345
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
346
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
347
+ "vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
348
+ "packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
349
+ "packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
350
+ "packed_vit_token_indexes": torch.tensor(packed_vit_token_indexes, dtype=torch.long),
351
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
352
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
353
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
354
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
355
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
356
+ }
357
+
358
+ return generation_input, newlens, new_rope
359
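A plausible patchify for the ViT path above, assuming the same (num_patches, C * p * p) layout as the latent packing elsewhere in this file; the repo's own patchify helper lives in another module and may differ in detail. Image size and patch size below are toy values.

import torch

def patchify(image, patch_size):
    # (C, H, W) -> (H//p * W//p, C * p * p), channels last within each patch.
    c, h, w = image.shape
    p = patch_size
    x = image.reshape(c, h // p, p, w // p, p)
    x = torch.einsum("chpwq->hwpqc", x)
    return x.reshape(-1, p * p * c)

img = torch.randn(3, 28, 42)   # toy image, divisible by a patch size of 14
tokens = patchify(img, 14)
print(tokens.shape)            # torch.Size([6, 588]); in the packed sequence the image then
                               # occupies start_of_image + 6 patch tokens + end_of_image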
+
360
+ @torch.no_grad
361
+ def forward_cache_update_vit(
362
+ self,
363
+ past_key_values: NaiveCache,
364
+ packed_text_ids: torch.LongTensor,
365
+ packed_text_indexes: torch.LongTensor,
366
+ packed_vit_tokens: torch.Tensor,
367
+ packed_vit_token_indexes: torch.LongTensor,
368
+ packed_vit_position_ids: torch.LongTensor,
369
+ vit_token_seqlens: torch.IntTensor,
370
+ packed_position_ids: torch.LongTensor,
371
+ packed_seqlens: torch.IntTensor,
372
+ packed_indexes: torch.LongTensor,
373
+ packed_key_value_indexes: torch.LongTensor,
374
+ key_values_lens: torch.IntTensor,
375
+ ):
376
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
377
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
378
+ packed_sequence[packed_text_indexes] = packed_text_embedding
379
+
380
+ cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
381
+ cu_seqlens = cu_seqlens.to(torch.int32)
382
+ max_seqlen = torch.max(vit_token_seqlens).item()
383
+ packed_vit_token_embed = self.vit_model(
384
+ packed_pixel_values=packed_vit_tokens,
385
+ packed_flattened_position_ids=packed_vit_position_ids,
386
+ cu_seqlens=cu_seqlens,
387
+ max_seqlen=max_seqlen,
388
+ )
389
+ packed_vit_token_embed = self.connector(packed_vit_token_embed)
390
+ pos_emb = self.vit_pos_embed(packed_vit_position_ids)
391
+ packed_vit_token_embed = packed_vit_token_embed + pos_emb
392
+ packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
393
+
394
+ extra_inputs = {}
395
+ if self.use_moe:
396
+ extra_inputs = {"mode": "und"}
397
+
398
+ output = self.language_model.forward_inference(
399
+ packed_query_sequence=packed_sequence,
400
+ query_lens=packed_seqlens,
401
+ packed_query_position_ids=packed_position_ids,
402
+ packed_query_indexes=packed_indexes,
403
+ past_key_values=past_key_values,
404
+ packed_key_value_indexes=packed_key_value_indexes,
405
+ key_values_lens=key_values_lens,
406
+ update_past_key_values=True,
407
+ is_causal=False,
408
+ **extra_inputs,
409
+ )
410
+ past_key_values = output.past_key_values
411
+
412
+ return past_key_values
413
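The varlen attention over packed ViT tokens needs cumulative sequence boundaries rather than padded batches; a short sketch of the cu_seqlens / max_seqlen construction used above, with made-up token counts:

import torch
import torch.nn.functional as F

vit_token_seqlens = torch.tensor([6, 4, 9], dtype=torch.int32)  # tokens per image (toy)
cu_seqlens = F.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0)).to(torch.int32)
max_seqlen = int(vit_token_seqlens.max())
print(cu_seqlens.tolist(), max_seqlen)  # [0, 6, 10, 19] 9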
+
414
+ def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
415
+ patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
416
+ packed_vae_token_indexes = list()
417
+ packed_text_ids, packed_text_indexes = list(), list()
418
+ packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
419
+ packed_key_value_indexes = list()
420
+
421
+ _curr = curr = 0
422
+ vae_image_tensors = list()
423
+ newlens, new_rope = list(), list()
424
+ for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
425
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
426
+ curr += curr_kvlen
427
+
428
+ packed_text_ids.append(new_token_ids['start_of_image'])
429
+ packed_text_indexes.append(_curr)
430
+ packed_indexes.append(curr)
431
+ curr += 1
432
+ _curr += 1
433
+
434
+ image_tensor = transforms(image)
435
+ vae_image_tensors.append(image_tensor)
436
+ vae_position_ids = self.get_flattened_position_ids(
437
+ image_tensor.size(1), image_tensor.size(2),
438
+ self.latent_downsample,
439
+ max_num_patches_per_side=self.max_latent_size
440
+ )
441
+ packed_vae_position_ids.append(vae_position_ids)
442
+ H, W = image_tensor.shape[1:]
443
+ h = H // self.latent_downsample
444
+ w = W // self.latent_downsample
445
+ patchified_vae_latent_shapes.append((h, w))
446
+
447
+ num_img_tokens = w * h
448
+ packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
449
+ packed_indexes.extend(range(curr, curr + num_img_tokens))
450
+ curr += num_img_tokens
451
+ _curr += num_img_tokens
452
+
453
+ packed_text_ids.append(new_token_ids['end_of_image'])
454
+ packed_text_indexes.append(_curr)
455
+ packed_indexes.append(curr)
456
+ curr += 1
457
+ _curr += 1
458
+
459
+ packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
460
+ packed_seqlens.append(num_img_tokens + 2)
461
+ newlens.append(curr_kvlen + num_img_tokens + 2)
462
+ new_rope.append(curr_position_id + 1)
463
+
464
+ image_sizes = [item.shape for item in vae_image_tensors]
465
+ max_image_size = [max(item) for item in list(zip(*image_sizes))]
466
+ padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
467
+ for i, image_tensor in enumerate(vae_image_tensors):
468
+ padded_images[i, :, :image_tensor.shape[1], :image_tensor.shape[2]] = image_tensor
469
+
470
+ generation_input = {
471
+ "padded_images": padded_images,
472
+ "patchified_vae_latent_shapes": patchified_vae_latent_shapes,
473
+ "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
474
+ "packed_timesteps": torch.tensor([timestep]),
475
+ "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
476
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
477
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
478
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
479
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
480
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
481
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
482
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
483
+ }
484
+
485
+ return generation_input, newlens, new_rope
486
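How many sequence slots an image costs on the VAE side follows directly from the code above: the latent grid is H // latent_downsample by W // latent_downsample, plus the two image delimiter tokens. The latent_downsample of 16 below is an assumption for illustration (VAE downsample times latent_patch_size); the real value comes from the loaded config.

def vae_token_count(H, W, latent_downsample=16):
    # One latent token per grid cell, plus start_of_image / end_of_image.
    h, w = H // latent_downsample, W // latent_downsample
    return h * w + 2

print(vae_token_count(1024, 1024))  # 4098
print(vae_token_count(512, 768))    # 1538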
+
487
+ @torch.no_grad
488
+ def forward_cache_update_vae(
489
+ self,
490
+ vae_model,
491
+ past_key_values: NaiveCache,
492
+ padded_images: torch.Tensor,
493
+ patchified_vae_latent_shapes: List,
494
+ packed_vae_position_ids: torch.LongTensor,
495
+ packed_timesteps: torch.Tensor,
496
+ packed_vae_token_indexes: torch.LongTensor,
497
+ packed_text_ids: torch.LongTensor,
498
+ packed_text_indexes: torch.LongTensor,
499
+ packed_position_ids: torch.LongTensor,
500
+ packed_seqlens: torch.IntTensor,
501
+ packed_indexes: torch.LongTensor,
502
+ key_values_lens: torch.IntTensor,
503
+ packed_key_value_indexes: torch.Tensor,
504
+ ):
505
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
506
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
507
+ packed_sequence[packed_text_indexes] = packed_text_embedding
508
+
509
+ padded_latent = vae_model.encode(padded_images)
510
+
511
+ p = self.latent_patch_size
512
+ packed_latent = list()
513
+ for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
514
+ latent = latent[:, :h * p, :w * p].reshape(self.latent_channel, h, p, w, p)
515
+ latent = torch.einsum("chpwq->hwpqc", latent).reshape(-1, p * p * self.latent_channel)
516
+ packed_latent.append(latent)
517
+ packed_latent = torch.cat(packed_latent, dim=0)
518
+ packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
519
+ packed_timestep_embeds = self.time_embedder(packed_timesteps)
520
+ packed_latent = self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
521
+ packed_sequence[packed_vae_token_indexes] = packed_latent
522
+
523
+ extra_inputs = {}
524
+ if self.use_moe:
525
+ extra_inputs = {
526
+ "mode": "gen",
527
+ "packed_vae_token_indexes": packed_vae_token_indexes,
528
+ "packed_text_indexes": packed_text_indexes
529
+ }
530
+
531
+ output = self.language_model.forward_inference(
532
+ packed_query_sequence=packed_sequence,
533
+ query_lens=packed_seqlens,
534
+ packed_query_position_ids=packed_position_ids,
535
+ packed_query_indexes=packed_indexes,
536
+ past_key_values=past_key_values,
537
+ key_values_lens=key_values_lens,
538
+ packed_key_value_indexes=packed_key_value_indexes,
539
+ update_past_key_values=True,
540
+ is_causal=False,
541
+ **extra_inputs,
542
+ )
543
+ past_key_values = output.past_key_values
544
+
545
+ return past_key_values
546
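The crop-and-patchify einsum above is exactly invertible, which is what lets generated patch rows be folded back into a latent grid before VAE decoding. A self-contained round-trip check with toy sizes (z_channels=16, latent_patch_size p=2, a 4 x 6 latent grid):

import torch

C, p, h, w = 16, 2, 4, 6
latent = torch.randn(C, h * p, w * p)

# pack: (C, h*p, w*p) -> (h*w, p*p*C)
patches = torch.einsum("chpwq->hwpqc", latent.reshape(C, h, p, w, p)).reshape(-1, p * p * C)
# unpack: (h*w, p*p*C) -> (C, h*p, w*p)
restored = torch.einsum("hwpqc->chpwq", patches.reshape(h, w, p, p, C)).reshape(C, h * p, w * p)

print(patches.shape, torch.allclose(latent, restored))  # torch.Size([24, 64]) True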
+
547
+ def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
548
+ packed_text_ids, packed_text_indexes = list(), list()
549
+ packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = list(), list(), list()
550
+ packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list()
551
+ packed_key_value_indexes = list()
552
+
553
+ query_curr = curr = 0
554
+ for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
555
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
556
+ curr += curr_kvlen
557
+
558
+ packed_text_ids.append(new_token_ids['start_of_image'])
559
+ packed_text_indexes.append(query_curr)
560
+ packed_indexes.append(curr)
561
+ curr += 1
562
+ query_curr += 1
563
+
564
+ vae_position_ids = self.get_flattened_position_ids(
565
+ H, W,
566
+ self.latent_downsample,
567
+ max_num_patches_per_side=self.max_latent_size
568
+ )
569
+ packed_vae_position_ids.append(vae_position_ids)
570
+
571
+ h, w = H // self.latent_downsample, W // self.latent_downsample
572
+ num_image_tokens = h * w
573
+ packed_init_noises.append(
574
+ torch.randn(num_image_tokens, self.latent_channel * self.latent_patch_size ** 2)
575
+ )
576
+ packed_vae_token_indexes.extend(range(query_curr, query_curr + num_image_tokens))
577
+ packed_indexes.extend(range(curr, curr + num_image_tokens))
578
+ curr += num_image_tokens
579
+ query_curr += num_image_tokens
580
+
581
+ packed_text_ids.append(new_token_ids['end_of_image'])
582
+ packed_text_indexes.append(query_curr)
583
+ packed_indexes.append(curr)
584
+ curr += 1
585
+ query_curr += 1
586
+
587
+ packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
588
+ packed_seqlens.append(num_image_tokens + 2)
589
+
590
+ generation_input = {
591
+ "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
592
+ "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
593
+ "packed_init_noises": torch.cat(packed_init_noises, dim=0),
594
+ "packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
595
+ "packed_vae_token_indexes": torch.tensor(packed_vae_token_indexes, dtype=torch.long),
596
+ "packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
597
+ "packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
598
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
599
+ "packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
600
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
601
+ }
602
+
603
+ return generation_input
604
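Two details of the generation-side packing above are worth spelling out: the initial noise has one row per latent grid cell with latent_channel * latent_patch_size**2 features, and every token of an image shares one rope position (packed_position_ids repeats curr_position_id for all num_image_tokens + 2 slots). A toy shape check; the channel/patch/downsample values below are assumptions for illustration, not config values.

import torch

latent_channel, p, latent_downsample = 16, 2, 16   # assumed values
H, W = 512, 512
h, w = H // latent_downsample, W // latent_downsample
init_noise = torch.randn(h * w, latent_channel * p ** 2)
print(init_noise.shape)  # torch.Size([1024, 64]): 32 x 32 grid cells, 64 features each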
+
605
+ def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
606
+ packed_position_ids, packed_indexes, packed_key_value_indexes = list(), list(), list()
607
+
608
+ query_curr = curr = 0
609
+ for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
610
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
611
+ curr += curr_kvlen
612
+
613
+ packed_indexes.append(curr)
614
+ curr += 1
615
+ query_curr += 1
616
+
617
+ h, w = H // self.latent_downsample, W // self.latent_downsample
618
+ num_image_tokens = h * w
619
+ packed_indexes.extend(range(curr, curr + num_image_tokens))
620
+ curr += num_image_tokens
621
+ query_curr += num_image_tokens
622
+
623
+ packed_indexes.append(curr)
624
+ curr += 1
625
+ query_curr += 1
626
+
627
+ packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
628
+
629
+ generation_input = {
630
+ "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
631
+ "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
632
+ "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
633
+ "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
634
+ }
635
+
636
+ return generation_input
637
+
638
+ @torch.no_grad
639
+ def generate_image(
640
+ self,
641
+ packed_text_ids: torch.LongTensor,
642
+ packed_text_indexes: torch.LongTensor,
643
+ packed_init_noises: torch.Tensor,
644
+ packed_vae_position_ids: torch.LongTensor,
645
+ packed_vae_token_indexes: torch.LongTensor,
646
+ packed_seqlens: torch.IntTensor,
647
+ packed_position_ids: torch.LongTensor,
648
+ packed_indexes: torch.LongTensor,
649
+ past_key_values: NaiveCache,
650
+ key_values_lens: torch.IntTensor,
651
+ packed_key_value_indexes: torch.LongTensor,
652
+ num_timesteps: int = 24,
653
+ timestep_shift: float = 1.0,
654
+ cfg_renorm_min: float = 0.0,
655
+ cfg_renorm_type: str = "global",
656
+ cfg_interval: Optional[Tuple[float, float]] = (0.0, 1.0),
657
+ # cfg_text
658
+ cfg_text_scale: float = 1.0,
659
+ cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
660
+ cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
661
+ cfg_text_past_key_values: Optional[NaiveCache] = None,
662
+ cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
663
+ cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
664
+ # cfg_img
665
+ cfg_img_scale: float = 1.0,
666
+ cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
667
+ cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
668
+ cfg_img_past_key_values: Optional[NaiveCache] = None,
669
+ cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
670
+ cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
671
+ cfg_type: str = "parallel",
672
+ ):
673
+ x_t = packed_init_noises
674
+
675
+ timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
676
+ timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
677
+ dts = timesteps[:-1] - timesteps[1:]
678
+ timesteps = timesteps[:-1]
679
+
680
+ for i, t in enumerate(timesteps):
681
+
682
+ timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
683
+ if t > cfg_interval[0] and t <= cfg_interval[1]:
684
+ cfg_text_scale_ = cfg_text_scale
685
+ cfg_img_scale_ = cfg_img_scale
686
+ else:
687
+ cfg_text_scale_ = 1.0
688
+ cfg_img_scale_ = 1.0
689
+ v_t = self._forward_flow(
690
+ x_t=x_t,
691
+ timestep=timestep,
692
+ packed_vae_token_indexes=packed_vae_token_indexes,
693
+ packed_vae_position_ids=packed_vae_position_ids,
694
+ packed_text_ids=packed_text_ids,
695
+ packed_text_indexes=packed_text_indexes,
696
+ packed_position_ids=packed_position_ids,
697
+ packed_indexes=packed_indexes,
698
+ packed_seqlens=packed_seqlens,
699
+ key_values_lens=key_values_lens,
700
+ past_key_values=past_key_values,
701
+ packed_key_value_indexes=packed_key_value_indexes,
702
+ cfg_renorm_min=cfg_renorm_min,
703
+ cfg_renorm_type=cfg_renorm_type,
704
+ # cfg_text
705
+ cfg_text_scale=cfg_text_scale_,
706
+ cfg_text_packed_position_ids=cfg_text_packed_position_ids,
707
+ cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
708
+ cfg_text_key_values_lens=cfg_text_key_values_lens,
709
+ cfg_text_past_key_values=cfg_text_past_key_values,
710
+ cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
711
+ # cfg_img
712
+ cfg_img_scale=cfg_img_scale_,
713
+ cfg_img_packed_position_ids=cfg_img_packed_position_ids,
714
+ cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
715
+ cfg_img_key_values_lens=cfg_img_key_values_lens,
716
+ cfg_img_past_key_values=cfg_img_past_key_values,
717
+ cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
718
+ cfg_type=cfg_type,
719
+ )
720
+
721
+ x_t = x_t - v_t.to(x_t.device) * dts[i] # velocity pointing from data to noise
722
+
723
+ unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
724
+ return unpacked_latent
725
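generate_image above integrates the flow ODE with plain Euler steps on a shifted grid running from 1 (noise) to 0 (data), updating x_t <- x_t - v_t * dt. A model-free toy check, where the stand-in velocity is the exact straight-line velocity, so the integrator recovers the data endpoint:

import torch

def integrate(x1, x0, num_timesteps=24, timestep_shift=3.0):
    timesteps = torch.linspace(1, 0, num_timesteps)
    timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
    dts = timesteps[:-1] - timesteps[1:]
    x_t = x1.clone()
    for dt in dts:
        v_t = x1 - x0          # stand-in velocity: exact for a straight path, not the model
        x_t = x_t - v_t * dt
    return x_t

x0, x1 = torch.randn(4, 8), torch.randn(4, 8)   # "data" and "noise" with toy shapes
print(torch.allclose(integrate(x1, x0), x0, atol=1e-5))  # True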
+
726
+ @torch.no_grad
727
+ def _forward_flow(
728
+ self,
729
+ x_t: torch.Tensor,
730
+ timestep: torch.LongTensor,
731
+ packed_vae_token_indexes: torch.LongTensor,
732
+ packed_vae_position_ids: torch.LongTensor,
733
+ packed_text_ids: torch.LongTensor,
734
+ packed_text_indexes: torch.LongTensor,
735
+ packed_indexes: torch.LongTensor,
736
+ packed_position_ids: torch.LongTensor,
737
+ packed_seqlens: torch.IntTensor,
738
+ key_values_lens: torch.IntTensor,
739
+ past_key_values: NaiveCache,
740
+ packed_key_value_indexes: torch.LongTensor,
741
+ cfg_renorm_min: float = 0.0,
742
+ cfg_renorm_type: str = "global",
743
+ # cfg_text
744
+ cfg_text_scale: float = 1.0,
745
+ cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
746
+ cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
747
+ cfg_text_key_values_lens: Optional[torch.Tensor] = None,
748
+ cfg_text_past_key_values: Optional[NaiveCache] = None,
749
+ cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
750
+ # cfg_img
751
+ cfg_img_scale: float = 1.0,
752
+ cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
753
+ cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
754
+ cfg_img_key_values_lens: Optional[torch.Tensor] = None,
755
+ cfg_img_past_key_values: Optional[NaiveCache] = None,
756
+ cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
757
+ cfg_type: str = "parallel",
758
+ ):
759
+ packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
760
+ packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
761
+ packed_sequence[packed_text_indexes] = packed_text_embedding
762
+
763
+ assert timestep.unique().shape[0] == 1
764
+ packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
765
+ packed_timestep_embeds = self.time_embedder(timestep)
766
+ x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
767
+ packed_sequence[packed_vae_token_indexes] = x_t
768
+
769
+ extra_inputs = {}
770
+ if self.use_moe:
771
+ extra_inputs = {
772
+ "mode": "gen",
773
+ "packed_vae_token_indexes": packed_vae_token_indexes,
774
+ "packed_text_indexes": packed_text_indexes
775
+ }
776
+
777
+ output = self.language_model.forward_inference(
778
+ packed_query_sequence=packed_sequence,
779
+ query_lens=packed_seqlens,
780
+ packed_query_position_ids=packed_position_ids,
781
+ packed_query_indexes=packed_indexes,
782
+ past_key_values=past_key_values,
783
+ key_values_lens=key_values_lens,
784
+ packed_key_value_indexes=packed_key_value_indexes,
785
+ update_past_key_values=False,
786
+ is_causal=False,
787
+ **extra_inputs,
788
+ )
789
+ v_t = self.llm2vae(output.packed_query_sequence)
790
+ v_t = v_t[packed_vae_token_indexes]
791
+
792
+ if cfg_text_scale > 1.0:
793
+ cfg_text_output = self.language_model.forward_inference(
794
+ packed_query_sequence=packed_sequence,
795
+ query_lens=packed_seqlens,
796
+ packed_query_position_ids=cfg_text_packed_position_ids,
797
+ packed_query_indexes=cfg_text_packed_query_indexes,
798
+ past_key_values=cfg_text_past_key_values,
799
+ key_values_lens=cfg_text_key_values_lens,
800
+ packed_key_value_indexes=cfg_text_packed_key_value_indexes,
801
+ update_past_key_values=False,
802
+ is_causal=False,
803
+ **extra_inputs,
804
+ )
805
+ cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
806
+ cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]
807
+
808
+ if cfg_img_scale > 1.0:
809
+ cfg_img_output = self.language_model.forward_inference(
810
+ packed_query_sequence=packed_sequence,
811
+ query_lens=packed_seqlens,
812
+ packed_query_position_ids=cfg_img_packed_position_ids,
813
+ packed_query_indexes=cfg_img_packed_query_indexes,
814
+ past_key_values=cfg_img_past_key_values,
815
+ key_values_lens=cfg_img_key_values_lens,
816
+ packed_key_value_indexes=cfg_img_packed_key_value_indexes,
817
+ update_past_key_values=False,
818
+ is_causal=False,
819
+ **extra_inputs,
820
+ )
821
+ cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
822
+ cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]
823
+
824
+ if cfg_text_scale > 1.0:
825
+ if cfg_renorm_type == "text_channel":
826
+ v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
827
+ norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
828
+ norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
829
+ scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
830
+ v_t_text = v_t_text_ * scale
831
+ if cfg_img_scale > 1.0:
832
+ v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
833
+ else:
834
+ v_t = v_t_text
835
+ else:
836
+ v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
837
+
838
+ if cfg_img_scale > 1.0:
839
+ v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
840
+ else:
841
+ v_t_ = v_t_text_
842
+
843
+ # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
844
+ if cfg_renorm_type == "global":
845
+ norm_v_t = torch.norm(v_t)
846
+ norm_v_t_ = torch.norm(v_t_)
847
+ elif cfg_renorm_type == "channel":
848
+ norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
849
+ norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
850
+ else:
851
+ raise NotImplementedError(f"{cfg_renorm_type} is not supported")
852
+ scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
853
+ v_t = v_t_ * scale
854
+ else:
855
+ # No CFG
856
+ pass
857
+
858
+ return v_t
859
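A compact sketch of the text-only CFG combination with the "global" renorm used above: the conditional and text-dropped velocities are extrapolated by cfg_text_scale, then the result's norm is clamped back so it never exceeds the conditional velocity's norm. Tensors and scales here are toy values.

import torch

def cfg_global_renorm(v_cond, v_text_uncond, cfg_text_scale=4.0, cfg_renorm_min=0.0):
    v = v_text_uncond + cfg_text_scale * (v_cond - v_text_uncond)
    scale = (torch.norm(v_cond) / (torch.norm(v) + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
    return v * scale

v_cond, v_text_uncond = torch.randn(10, 64), torch.randn(10, 64)
v = cfg_global_renorm(v_cond, v_text_uncond)
print(bool(torch.norm(v) <= torch.norm(v_cond) + 1e-4))  # True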
+
860
+ def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
861
+ packed_start_tokens, packed_key_value_indexes = list(), list()
862
+ packed_query_position_ids = list()
863
+
864
+ curr = 0
865
+ for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
866
+ packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
867
+ packed_start_tokens.append(new_token_ids['bos_token_id'])
868
+ packed_query_position_ids.append(curr_position_id)
869
+ curr += curr_kvlen
870
+
871
+ generation_input = {
872
+ "packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
873
+ "packed_query_position_ids": torch.tensor(packed_query_position_ids, dtype=torch.long),
874
+ "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
875
+ "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
876
+ }
877
+
878
+ return generation_input
879
+
880
+ @torch.no_grad
881
+ def generate_text(
882
+ self,
883
+ past_key_values: NaiveCache,
884
+ packed_key_value_indexes: torch.LongTensor,
885
+ key_values_lens: torch.IntTensor,
886
+ packed_start_tokens: torch.LongTensor,
887
+ packed_query_position_ids: torch.LongTensor,
888
+ max_length: int,
889
+ do_sample: bool = False,
890
+ temperature: float = 1.0,
891
+ end_token_id: int = None,
892
+ ):
893
+ step = 0
894
+ # generated_sequence = [] # Removed for streaming
895
+ curr_tokens = packed_start_tokens
896
+ while step < max_length:
897
+ # generated_sequence.append(curr_tokens) # Removed for streaming
898
+ yield curr_tokens # Yield current tokens
899
+
900
+ packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
901
+ query_lens = torch.ones_like(curr_tokens)
902
+ packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
903
+ 0, len(key_values_lens),
904
+ device=key_values_lens.device,
905
+ dtype=key_values_lens.dtype
906
+ )
907
+
908
+ unpacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
909
+ for i in range(len(unpacked)):
910
+ unpacked[i] += i
911
+ packed_key_value_indexes = torch.cat(unpacked, dim=0)
912
+
913
+ extra_inputs = {}
914
+ if self.use_moe:
915
+ extra_inputs = {"mode": "und"}
916
+
917
+ output = self.language_model.forward_inference(
918
+ packed_query_sequence=packed_text_embedding,
919
+ query_lens=query_lens,
920
+ packed_query_position_ids=packed_query_position_ids,
921
+ packed_query_indexes=packed_query_indexes,
922
+ past_key_values=past_key_values,
923
+ key_values_lens=key_values_lens,
924
+ packed_key_value_indexes=packed_key_value_indexes,
925
+ update_past_key_values=True,
926
+ is_causal=True,
927
+ **extra_inputs,
928
+ )
929
+ past_key_values = output.past_key_values
930
+ packed_query_sequence = output.packed_query_sequence
931
+ pred_logits = self.language_model.lm_head(packed_query_sequence)
932
+
933
+ if do_sample:
934
+ probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
935
+ curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
936
+ else:
937
+ curr_tokens = torch.argmax(pred_logits, dim=-1)
938
+
939
+ unpacked = list(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
940
+ for i in range(len(unpacked)):
941
+ unpacked[i] = torch.cat(
942
+ [unpacked[i], torch.tensor([unpacked[i][-1] + 1], device=unpacked[i].device)], dim=0
943
+ )
944
+ packed_key_value_indexes = torch.cat(unpacked, dim=0)
945
+ key_values_lens = key_values_lens + 1
946
+ packed_query_position_ids = packed_query_position_ids + 1
947
+ step += 1
948
+
949
+ if end_token_id is not None and curr_tokens[0] == end_token_id: # only support batch=1
950
+ break
951
+
952
+ # output_device = generated_sequence[0].device # Removed for streaming
953
+ # return torch.stack([i.to(output_device) for i in generated_sequence], dim=0) # Removed for streaming
954
+
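The commented-out lines above mark the change from a blocking generate_text that returned the full token stack to a generator that yields each token as it is sampled. A toy illustration of the difference for a streaming UI (strings stand in for sampled tokens):

def generate_blocking(steps):
    out = []
    for tok in steps:
        out.append(tok)
    return out                 # caller sees nothing until the loop finishes

def generate_streaming(steps):
    for tok in steps:
        yield tok              # caller can render each piece as soon as it exists

tokens = ["Hel", "lo", ",", " world", "!"]
print(generate_blocking(tokens))
for piece in generate_streaming(tokens):
    print(piece, end="", flush=True)
print()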
955
+ # for evaluation
956
+ @torch.no_grad()
957
+ def chat(
958
+ self,
959
+ tokenizer,
960
+ new_token_ids,
961
+ image_transform,
962
+ images,
963
+ prompt,
964
+ max_length: int,
965
+ do_sample: bool = False,
966
+ temperature: float = 1.0,
967
+ ):
968
+ device = next(self.parameters()).device
969
+
970
+ if isinstance(new_token_ids, dict):
971
+ for k, v in new_token_ids.items():
972
+ if torch.is_tensor(v):
973
+ new_token_ids[k] = v.to(device)
974
+ elif torch.is_tensor(new_token_ids):
975
+ new_token_ids = new_token_ids.to(device)
976
+
977
+ # prefill
978
+ past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers)
979
+ newlens = [0]
980
+ new_rope = [0]
981
+
982
+ # add images
983
+ for image in images:
984
+ generation_input, newlens, new_rope = self.prepare_vit_images(
985
+ curr_kvlens=newlens,
986
+ curr_rope=new_rope,
987
+ images=[image],
988
+ transforms=image_transform,
989
+ new_token_ids=new_token_ids,
990
+ )
991
+ for k, v in generation_input.items():
992
+ if torch.is_tensor(v):
993
+ generation_input[k] = v.to(device)
994
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
995
+ past_key_values = self.forward_cache_update_vit(past_key_values, **generation_input)
996
+
997
+ # add text
998
+ generation_input, newlens, new_rope = self.prepare_prompts(
999
+ curr_kvlens=newlens,
1000
+ curr_rope=new_rope,
1001
+ prompts=[prompt],
1002
+ tokenizer=tokenizer,
1003
+ new_token_ids=new_token_ids,
1004
+ )
1005
+ for k, v in generation_input.items():
1006
+ if torch.is_tensor(v):
1007
+ generation_input[k] = v.to(device)
1008
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1009
+ past_key_values = self.forward_cache_update_text(past_key_values, **generation_input)
1010
+
1011
+ # decode
1012
+ generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids)
1013
+ for k, v in generation_input.items():
1014
+ if torch.is_tensor(v):
1015
+ generation_input[k] = v.to(device)
1016
+ with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
1017
+ for curr_tokens in self.generate_text(
1018
+ past_key_values=past_key_values,
1019
+ max_length=max_length,
1020
+ do_sample=do_sample,
1021
+ temperature=temperature,
1022
+ end_token_id=new_token_ids['eos_token_id'],
1023
+ **generation_input,
1024
+ ):
1025
+ output = tokenizer.decode(curr_tokens) # each yield is a 1-D tensor holding one token id (batch size 1)
1026
+ yield output
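chat() repeats the same move-tensors-to-device loop three times; a small helper sketch of that pattern (the helper name and the example dict are illustrative, not part of the repo):

import torch

def to_device(generation_input, device):
    # Move tensor values onto the target device; leave lists and other entries untouched.
    return {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in generation_input.items()}

example = {"packed_text_ids": torch.tensor([1, 2, 3]), "sample_lens": [3]}
moved = to_device(example, torch.device("cpu"))
print({k: type(v).__name__ for k, v in moved.items()})  # {'packed_text_ids': 'Tensor', 'sample_lens': 'list'}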
modeling/bagel/modeling_utils.py CHANGED
@@ -1,144 +1,144 @@
1
- # Copyright (c) 2022 Facebook, Inc. and its affiliates.
2
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
- # SPDX-License-Identifier: CC BY-NC 4.0
4
- #
5
- # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
- #
7
- # Original file was released under CC BY-NC 4.0, with the full license text
8
- # available at https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt.
9
- #
10
- # This modified file is released under the same license.
11
-
12
- import math
13
-
14
- import numpy as np
15
- import torch
16
- from torch import nn
17
- from transformers.activations import ACT2FN
18
-
19
- # --------------------------------------------------------
20
- # 2D sine-cosine position embedding
21
- # References:
22
- # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
23
- # --------------------------------------------------------
24
- def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
25
- grid_h = np.arange(grid_size, dtype=np.float32)
26
- grid_w = np.arange(grid_size, dtype=np.float32)
27
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
28
- grid = np.stack(grid, axis=0)
29
-
30
- grid = grid.reshape([2, 1, grid_size, grid_size])
31
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
32
- if cls_token and extra_tokens > 0:
33
- pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
34
- return pos_embed
35
-
36
-
37
- def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
38
- assert embed_dim % 2 == 0
39
-
40
- # use half of dimensions to encode grid_h
41
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
42
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
43
-
44
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
45
- return emb
46
-
47
-
48
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
49
- """
50
- embed_dim: output dimension for each position
51
- pos: a list of positions to be encoded: size (M,)
52
- out: (M, D)
53
- """
54
- assert embed_dim % 2 == 0
55
- omega = np.arange(embed_dim // 2, dtype=np.float64)
56
- omega /= embed_dim / 2.
57
- omega = 1. / 10000**omega # (D/2,)
58
-
59
- pos = pos.reshape(-1) # (M,)
60
- out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
61
-
62
- emb_sin = np.sin(out) # (M, D/2)
63
- emb_cos = np.cos(out) # (M, D/2)
64
-
65
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
66
- return emb
67
-
68
-
69
- # --------------------------------------------------------
70
- # TimestepEmbedder
71
- # Reference:
72
- # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
73
- # --------------------------------------------------------
74
- class TimestepEmbedder(nn.Module):
75
- """
76
- Embeds scalar timesteps into vector representations.
77
- """
78
- def __init__(self, hidden_size, frequency_embedding_size=256):
79
- super().__init__()
80
- self.mlp = nn.Sequential(
81
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
82
- nn.SiLU(),
83
- nn.Linear(hidden_size, hidden_size, bias=True),
84
- )
85
- self.frequency_embedding_size = frequency_embedding_size
86
-
87
- @staticmethod
88
- def timestep_embedding(t, dim, max_period=10000):
89
- """
90
- Create sinusoidal timestep embeddings.
91
- :param t: a 1-D Tensor of N indices, one per batch element.
92
- These may be fractional.
93
- :param dim: the dimension of the output.
94
- :param max_period: controls the minimum frequency of the embeddings.
95
- :return: an (N, D) Tensor of positional embeddings.
96
- """
97
- half = dim // 2
98
- freqs = torch.exp(
99
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
100
- ).to(device=t.device)
101
- args = t[:, None].float() * freqs[None]
102
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
103
- if dim % 2:
104
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
105
- return embedding
106
-
107
- def forward(self, t):
108
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
109
- t_emb = self.mlp(t_freq)
110
- return t_emb
111
-
112
-
113
- class MLPconnector(nn.Module):
114
- def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
115
- super().__init__()
116
- self.activation_fn = ACT2FN[hidden_act]
117
- self.fc1 = nn.Linear(in_dim, out_dim)
118
- self.fc2 = nn.Linear(out_dim, out_dim)
119
-
120
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
121
- hidden_states = self.fc1(hidden_states)
122
- hidden_states = self.activation_fn(hidden_states)
123
- hidden_states = self.fc2(hidden_states)
124
- return hidden_states
125
-
126
-
127
- class PositionEmbedding(nn.Module):
128
- def __init__(self, max_num_patch_per_side, hidden_size):
129
- super().__init__()
130
- self.max_num_patch_per_side = max_num_patch_per_side
131
- self.hidden_size = hidden_size
132
- self.pos_embed = nn.Parameter(
133
- torch.zeros(max_num_patch_per_side ** 2, hidden_size),
134
- requires_grad=False
135
- )
136
- self._init_weights()
137
-
138
- def _init_weights(self):
139
- # Initialize (and freeze) pos_embed by sin-cos embedding:
140
- pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
141
- self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float())
142
-
143
- def forward(self, position_ids):
144
  return self.pos_embed[position_ids]
 
1
+ # Copyright (c) 2022 Facebook, Inc. and its affiliates.
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: CC BY-NC 4.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under CC BY-NC 4.0, with the full license text
8
+ # available at https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import math
13
+
14
+ import numpy as np
15
+ import torch
16
+ from torch import nn
17
+ from transformers.activations import ACT2FN
18
+
19
+ # --------------------------------------------------------
20
+ # 2D sine-cosine position embedding
21
+ # References:
22
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
23
+ # --------------------------------------------------------
24
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
25
+ grid_h = np.arange(grid_size, dtype=np.float32)
26
+ grid_w = np.arange(grid_size, dtype=np.float32)
27
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
28
+ grid = np.stack(grid, axis=0)
29
+
30
+ grid = grid.reshape([2, 1, grid_size, grid_size])
31
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
32
+ if cls_token and extra_tokens > 0:
33
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
34
+ return pos_embed
35
+
36
+
37
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
38
+ assert embed_dim % 2 == 0
39
+
40
+ # use half of dimensions to encode grid_h
41
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
42
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
43
+
44
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
45
+ return emb
46
+
47
+
48
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
49
+ """
50
+ embed_dim: output dimension for each position
51
+ pos: a list of positions to be encoded: size (M,)
52
+ out: (M, D)
53
+ """
54
+ assert embed_dim % 2 == 0
55
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
56
+ omega /= embed_dim / 2.
57
+ omega = 1. / 10000**omega # (D/2,)
58
+
59
+ pos = pos.reshape(-1) # (M,)
60
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
61
+
62
+ emb_sin = np.sin(out) # (M, D/2)
63
+ emb_cos = np.cos(out) # (M, D/2)
64
+
65
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
66
+ return emb
67
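A self-contained shape check of the sin-cos embedding built by the helpers above (re-implemented inline so the snippet runs on its own; the embedding width is a toy value):

import numpy as np

def sincos_1d(embed_dim, pos):
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0))
    out = np.einsum("m,d->md", pos.reshape(-1), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

emb = sincos_1d(8, np.arange(5, dtype=np.float64))
print(emb.shape)  # (5, 8): half sine features, half cosine features per position
# The 2-D variant concatenates one such embedding for the row index and one for the
# column index, giving (grid_size ** 2, embed_dim) entries for a square patch grid.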
+
68
+
69
+ # --------------------------------------------------------
70
+ # TimestepEmbedder
71
+ # Reference:
72
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
73
+ # --------------------------------------------------------
74
+ class TimestepEmbedder(nn.Module):
75
+ """
76
+ Embeds scalar timesteps into vector representations.
77
+ """
78
+ def __init__(self, hidden_size, frequency_embedding_size=256):
79
+ super().__init__()
80
+ self.mlp = nn.Sequential(
81
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
82
+ nn.SiLU(),
83
+ nn.Linear(hidden_size, hidden_size, bias=True),
84
+ )
85
+ self.frequency_embedding_size = frequency_embedding_size
86
+
87
+ @staticmethod
88
+ def timestep_embedding(t, dim, max_period=10000):
89
+ """
90
+ Create sinusoidal timestep embeddings.
91
+ :param t: a 1-D Tensor of N indices, one per batch element.
92
+ These may be fractional.
93
+ :param dim: the dimension of the output.
94
+ :param max_period: controls the minimum frequency of the embeddings.
95
+ :return: an (N, D) Tensor of positional embeddings.
96
+ """
97
+ half = dim // 2
98
+ freqs = torch.exp(
99
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
100
+ ).to(device=t.device)
101
+ args = t[:, None].float() * freqs[None]
102
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
103
+ if dim % 2:
104
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
105
+ return embedding
106
+
107
+ def forward(self, t):
108
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
109
+ t_emb = self.mlp(t_freq)
110
+ return t_emb
111
+
112
+
113
+ class MLPconnector(nn.Module):
114
+ def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
115
+ super().__init__()
116
+ self.activation_fn = ACT2FN[hidden_act]
117
+ self.fc1 = nn.Linear(in_dim, out_dim)
118
+ self.fc2 = nn.Linear(out_dim, out_dim)
119
+
120
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
121
+ hidden_states = self.fc1(hidden_states)
122
+ hidden_states = self.activation_fn(hidden_states)
123
+ hidden_states = self.fc2(hidden_states)
124
+ return hidden_states
125
+
126
+
127
+ class PositionEmbedding(nn.Module):
128
+ def __init__(self, max_num_patch_per_side, hidden_size):
129
+ super().__init__()
130
+ self.max_num_patch_per_side = max_num_patch_per_side
131
+ self.hidden_size = hidden_size
132
+ self.pos_embed = nn.Parameter(
133
+ torch.zeros(max_num_patch_per_side ** 2, hidden_size),
134
+ requires_grad=False
135
+ )
136
+ self._init_weights()
137
+
138
+ def _init_weights(self):
139
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
140
+ pos_embed = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
141
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float())
142
+
143
+ def forward(self, position_ids):
144
  return self.pos_embed[position_ids]
modeling/bagel/qwen2_navit.py CHANGED
The diff for this file is too large to render. See raw diff
 
modeling/bagel/siglip_navit.py CHANGED
@@ -1,402 +1,402 @@
1
- # Copyright (c) 2024 The HuggingFace Inc. team.
2
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
- # SPDX-License-Identifier: Apache-2.0
4
- #
5
- # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
- #
7
- # Original file was released under Apache-2.0, with the full license text
8
- # available at https://github.com/huggingface/transformers/blob/main/LICENSE.
9
- #
10
- # This modified file is released under the same license.
11
-
12
- import torch
13
- from torch import nn
14
-
15
- from transformers.activations import ACT2FN
16
- from modeling.siglip.configuration_siglip import SiglipVisionConfig as _SiglipVisionConfig
17
- from modeling.siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel
18
- from flash_attn import flash_attn_varlen_func
19
-
20
-
21
- class SiglipVisionConfig(_SiglipVisionConfig):
22
- r"""
23
- This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
24
- Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
25
- configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
26
- [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
27
-
28
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
29
- documentation from [`PretrainedConfig`] for more information.
30
-
31
- Args:
32
- hidden_size (`int`, *optional*, defaults to 768):
33
- Dimensionality of the encoder layers and the pooler layer.
34
- intermediate_size (`int`, *optional*, defaults to 3072):
35
- Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
36
- num_hidden_layers (`int`, *optional*, defaults to 12):
37
- Number of hidden layers in the Transformer encoder.
38
- num_attention_heads (`int`, *optional*, defaults to 12):
39
- Number of attention heads for each attention layer in the Transformer encoder.
40
- num_channels (`int`, *optional*, defaults to 3):
41
- Number of channels in the input images.
42
- image_size (`int`, *optional*, defaults to 224):
43
- The size (resolution) of each image.
44
- patch_size (`int`, *optional*, defaults to 16):
45
- The size (resolution) of each patch.
46
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
47
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
48
- `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
49
- layer_norm_eps (`float`, *optional*, defaults to 1e-06):
50
- The epsilon used by the layer normalization layers.
51
- attention_dropout (`float`, *optional*, defaults to 0.0):
52
- The dropout ratio for the attention probabilities.
53
-
54
- Example:
55
-
56
- ```python
57
- >>> from transformers import SiglipVisionConfig, SiglipVisionModel
58
-
59
- >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
60
- >>> configuration = SiglipVisionConfig()
61
-
62
- >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
63
- >>> model = SiglipVisionModel(configuration)
64
-
65
- >>> # Accessing the model configuration
66
- >>> configuration = model.config
67
- ```"""
68
-
69
- model_type = "siglip_vision_model"
70
-
71
- def __init__(
72
- self,
73
- hidden_size=768,
74
- intermediate_size=3072,
75
- num_hidden_layers=12,
76
- num_attention_heads=12,
77
- num_channels=3,
78
- image_size=224,
79
- patch_size=16,
80
- hidden_act="gelu_pytorch_tanh",
81
- layer_norm_eps=1e-6,
82
- attention_dropout=0.0,
83
- rope=True,
84
- **kwargs,
85
- ):
86
- super().__init__(
87
- hidden_size=hidden_size,
88
- intermediate_size=intermediate_size,
89
- num_hidden_layers=num_hidden_layers,
90
- num_attention_heads=num_attention_heads,
91
- num_channels=num_channels,
92
- image_size=image_size,
93
- patch_size=patch_size,
94
- hidden_act=hidden_act,
95
- layer_norm_eps=layer_norm_eps,
96
- attention_dropout=attention_dropout,
97
- **kwargs)
98
-
99
- self.rope = rope
100
-
101
-
102
- class RotaryEmbedding2D(torch.nn.Module):
103
- def __init__(self, dim, max_h, max_w, base=10000):
104
- super().__init__()
105
- freq = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
106
- inv_freq = 1.0 / (base ** freq)
107
-
108
- grid_h = torch.arange(0, max_h)
109
- grid_h = grid_h.to(inv_freq.dtype)
110
- grid_h = grid_h[:, None].repeat(1, max_w)
111
-
112
- grid_w = torch.arange(0, max_w)
113
- grid_w = grid_w.to(inv_freq.dtype)
114
- grid_w = grid_w[None, :].repeat(max_h, 1)
115
-
116
- cos_h, sin_h = self._forward_one_side(grid_h, inv_freq)
117
- cos_w, sin_w = self._forward_one_side(grid_w, inv_freq)
118
-
119
- self.register_buffer("cos_h", cos_h)
120
- self.register_buffer("sin_h", sin_h)
121
- self.register_buffer("cos_w", cos_w)
122
- self.register_buffer("sin_w", sin_w)
123
-
124
- def _forward_one_side(self, grid, inv_freq):
125
- freqs = grid[..., None] * inv_freq[None, None, :]
126
- emb = torch.cat((freqs, freqs), dim=-1).flatten(0, 1)
127
- return emb.cos(), emb.sin()
128
-
129
-
130
- def rotate_half(x):
131
- x1 = x[..., : x.shape[-1] // 2]
132
- x2 = x[..., x.shape[-1] // 2 :]
133
- return torch.cat((-x2, x1), dim=-1)
134
-
135
-
136
- def apply_rotary_pos_emb(q, k, cos, sin):
137
- # unsqueeze due to the head dimension
138
- cos = cos.unsqueeze(1)
139
- sin = sin.unsqueeze(1)
140
- q_embed = (q * cos) + (rotate_half(q) * sin)
141
- k_embed = (k * cos) + (rotate_half(k) * sin)
142
- return q_embed, k_embed
143
-
144
-
145
- class SiglipVisionEmbeddings(nn.Module):
146
- def __init__(self, config: SiglipVisionConfig):
147
- super().__init__()
148
- self.config = config
149
- self.embed_dim = config.hidden_size
150
- self.image_size = config.image_size
151
- self.patch_size = config.patch_size
152
-
153
- self.patch_embedding = nn.Conv2d(
154
- in_channels=config.num_channels,
155
- out_channels=self.embed_dim,
156
- kernel_size=self.patch_size,
157
- stride=self.patch_size,
158
- padding="valid",
159
- )
160
-
161
- self.num_patches_per_side = self.image_size // self.patch_size
162
- self.num_patches = self.num_patches_per_side**2
163
- self.num_positions = self.num_patches
164
- if not config.rope:
165
- self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
166
-
167
- def convert_conv2d_to_linear(self, config, meta=False):
168
- if meta:
169
- linear_patch_embedding = nn.Linear(
170
- config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True, device='meta'
171
- )
172
- else:
173
- linear_patch_embedding = nn.Linear(
174
- config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True
175
- )
176
- W = self.patch_embedding.weight.permute(0, 2, 3, 1).reshape(
177
- self.embed_dim, config.num_channels * self.patch_size ** 2
178
- )
179
- linear_patch_embedding.weight.data = W
180
- linear_patch_embedding.bias.data = self.patch_embedding.bias.data
181
- del self.patch_embedding
182
- self.patch_embedding = linear_patch_embedding
183
-
184
- def forward(
185
- self,
186
- packed_pixel_values: torch.FloatTensor,
187
- packed_flattened_position_ids: torch.LongTensor
188
- ) -> torch.Tensor:
189
-
190
- patch_embeds = self.patch_embedding(packed_pixel_values)
191
- if not self.config.rope:
192
- embeddings = patch_embeds + self.position_embedding(packed_flattened_position_ids)
193
- else:
194
- embeddings = patch_embeds
195
- return embeddings
196
-
197
-
198
- class SiglipFlashAttention2(SiglipAttention):
199
- def __init__(self, *args, **kwargs):
200
- super().__init__(*args, **kwargs)
201
-
202
- def forward(
203
- self,
204
- hidden_states: torch.Tensor,
205
- cu_seqlens: torch.IntTensor,
206
- max_seqlen: int,
207
- cos_h: torch.Tensor = None,
208
- sin_h: torch.Tensor = None,
209
- cos_w: torch.Tensor = None,
210
- sin_w: torch.Tensor = None,
211
- **kwargs,
212
- ) -> torch.Tensor:
213
-
214
- total_q_len, _ = hidden_states.size()
215
-
216
- query_states = self.q_proj(hidden_states)
217
- key_states = self.k_proj(hidden_states)
218
- value_states = self.v_proj(hidden_states)
219
-
220
- query_states = query_states.view(total_q_len, self.num_heads, self.head_dim)
221
- key_states = key_states.view(total_q_len, self.num_heads, self.head_dim)
222
- value_states = value_states.view(total_q_len, self.num_heads, self.head_dim)
223
-
224
- if self.config.rope:
225
- qh, qw = query_states[:, :, :self.head_dim // 2], query_states[:, :, self.head_dim // 2:]
226
- kh, kw = key_states[:, :, :self.head_dim // 2], key_states[:, :, self.head_dim // 2:]
227
- qh, kh = apply_rotary_pos_emb(qh, kh, cos_h, sin_h)
228
- qw, kw = apply_rotary_pos_emb(qw, kw, cos_w, sin_w)
229
- query_states = torch.cat([qh, qw], dim=-1)
230
- key_states = torch.cat([kh, kw], dim=-1)
231
-
232
- attn_output = flash_attn_varlen_func(
233
- query_states.to(torch.bfloat16),
234
- key_states.to(torch.bfloat16),
235
- value_states.to(torch.bfloat16),
236
- cu_seqlens_q=cu_seqlens,
237
- cu_seqlens_k=cu_seqlens,
238
- max_seqlen_q=max_seqlen,
239
- max_seqlen_k=max_seqlen,
240
- causal=False,
241
- )
242
-
243
- attn_output = self.out_proj(attn_output.reshape(total_q_len, -1))
244
- return attn_output
245
-
246
-
247
- class SiglipMLP(nn.Module):
248
- def __init__(self, config):
249
- super().__init__()
250
- self.config = config
251
- self.activation_fn = ACT2FN[config.hidden_act]
252
- self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
253
- self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
254
-
255
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
256
- hidden_states = self.fc1(hidden_states)
257
- hidden_states = self.activation_fn(hidden_states)
258
- hidden_states = self.fc2(hidden_states)
259
- return hidden_states
260
-
261
-
262
- class SiglipEncoderLayer(nn.Module):
263
- def __init__(self, config: SiglipVisionConfig):
264
- super().__init__()
265
- self.embed_dim = config.hidden_size
266
- self.self_attn = SiglipFlashAttention2(config)
267
- self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
268
- self.mlp = SiglipMLP(config)
269
- self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
270
-
271
- def forward(
272
- self,
273
- hidden_states: torch.Tensor,
274
- cu_seqlens: torch.IntTensor,
275
- max_seqlen: int,
276
- cos_h: torch.Tensor = None,
277
- sin_h: torch.Tensor = None,
278
- cos_w: torch.Tensor = None,
279
- sin_w: torch.Tensor = None
280
- ) -> torch.Tensor:
281
- residual = hidden_states
282
-
283
- hidden_states = self.layer_norm1(hidden_states)
284
- hidden_states = self.self_attn(
285
- hidden_states=hidden_states,
286
- cu_seqlens=cu_seqlens,
287
- max_seqlen=max_seqlen,
288
- cos_h=cos_h,
289
- sin_h=sin_h,
290
- cos_w=cos_w,
291
- sin_w=sin_w
292
- )
293
- hidden_states = residual + hidden_states
294
-
295
- residual = hidden_states
296
- hidden_states = self.layer_norm2(hidden_states)
297
- hidden_states = self.mlp(hidden_states)
298
- hidden_states = residual + hidden_states
299
-
300
- return hidden_states
301
-
302
-
303
- class SiglipEncoder(nn.Module):
304
- def __init__(self, config: SiglipVisionConfig):
305
- super().__init__()
306
- self.config = config
307
- self.layers = nn.ModuleList(
308
- [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
309
- )
310
-
311
- def forward(
312
- self,
313
- inputs_embeds: torch.Tensor,
314
- cu_seqlens: torch.IntTensor,
315
- max_seqlen: int,
316
- cos_h: torch.Tensor = None,
317
- sin_h: torch.Tensor = None,
318
- cos_w: torch.Tensor = None,
319
- sin_w: torch.Tensor = None,
320
- ) -> torch.Tensor:
321
-
322
- hidden_states = inputs_embeds
323
- for encoder_layer in self.layers:
324
- hidden_states = encoder_layer(hidden_states, cu_seqlens, max_seqlen,
325
- cos_h=cos_h, sin_h=sin_h, cos_w=cos_w, sin_w=sin_w)
326
-
327
- return hidden_states
328
-
329
-
330
- class SiglipVisionTransformer(nn.Module):
331
- def __init__(self, config: SiglipVisionConfig):
332
- super().__init__()
333
- self.config = config
334
- embed_dim = config.hidden_size
335
-
336
- self.embeddings = SiglipVisionEmbeddings(config)
337
- if config.rope:
338
- max_size = config.image_size // config.patch_size
339
- dim_head = config.hidden_size // config.num_attention_heads
340
- self.rope = RotaryEmbedding2D(dim_head // 2, max_size, max_size)
341
-
342
- self.encoder = SiglipEncoder(config)
343
- self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
344
-
345
- def forward(
346
- self,
347
- packed_pixel_values: torch.Tensor,
348
- packed_flattened_position_ids: torch.LongTensor,
349
- cu_seqlens: torch.IntTensor,
350
- max_seqlen: int,
351
- ) -> torch.Tensor:
352
- hidden_states = self.embeddings(
353
- packed_pixel_values=packed_pixel_values,
354
- packed_flattened_position_ids=packed_flattened_position_ids
355
- )
356
-
357
- extra_inputs = {}
358
- if self.config.rope:
359
- extra_inputs.update(
360
- cos_h = self.rope.cos_h[packed_flattened_position_ids],
361
- sin_h = self.rope.sin_h[packed_flattened_position_ids],
362
- cos_w = self.rope.cos_w[packed_flattened_position_ids],
363
- sin_w = self.rope.sin_w[packed_flattened_position_ids]
364
- )
365
-
366
- last_hidden_state = self.encoder(
367
- inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
368
- **extra_inputs
369
- )
370
- last_hidden_state = self.post_layernorm(last_hidden_state)
371
- return last_hidden_state
372
-
373
-
374
- class SiglipVisionModel(SiglipPreTrainedModel):
375
- config_class = SiglipVisionConfig
376
- main_input_name = "packed_pixel_values"
377
-
378
- def __init__(self, config: SiglipVisionConfig):
379
- super().__init__(config)
380
-
381
- self.vision_model = SiglipVisionTransformer(config)
382
-
383
- # Initialize weights and apply final processing
384
- self.post_init()
385
-
386
- def get_input_embeddings(self) -> nn.Module:
387
- return self.vision_model.embeddings.patch_embedding
388
-
389
- def forward(
390
- self,
391
- packed_pixel_values: torch.Tensor,
392
- packed_flattened_position_ids: torch.LongTensor,
393
- cu_seqlens: torch.IntTensor,
394
- max_seqlen: int,
395
- ) -> torch.Tensor:
396
-
397
- return self.vision_model(
398
- packed_pixel_values=packed_pixel_values,
399
- packed_flattened_position_ids=packed_flattened_position_ids,
400
- cu_seqlens=cu_seqlens,
401
- max_seqlen=max_seqlen,
402
- )
 
1
+ # Copyright (c) 2024 The HuggingFace Inc. team.
2
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
6
+ #
7
+ # Original file was released under Apache-2.0, with the full license text
8
+ # available at https://github.com/huggingface/transformers/blob/main/LICENSE.
9
+ #
10
+ # This modified file is released under the same license.
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from transformers.activations import ACT2FN
16
+ from modeling.siglip.configuration_siglip import SiglipVisionConfig as _SiglipVisionConfig
17
+ from modeling.siglip.modeling_siglip import SiglipAttention, SiglipPreTrainedModel
18
+ from flash_attn import flash_attn_varlen_func
19
+
20
+
21
+ class SiglipVisionConfig(_SiglipVisionConfig):
22
+ r"""
23
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
24
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
25
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
26
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
27
+
28
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
29
+ documentation from [`PretrainedConfig`] for more information.
30
+
31
+ Args:
32
+ hidden_size (`int`, *optional*, defaults to 768):
33
+ Dimensionality of the encoder layers and the pooler layer.
34
+ intermediate_size (`int`, *optional*, defaults to 3072):
35
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
36
+ num_hidden_layers (`int`, *optional*, defaults to 12):
37
+ Number of hidden layers in the Transformer encoder.
38
+ num_attention_heads (`int`, *optional*, defaults to 12):
39
+ Number of attention heads for each attention layer in the Transformer encoder.
40
+ num_channels (`int`, *optional*, defaults to 3):
41
+ Number of channels in the input images.
42
+ image_size (`int`, *optional*, defaults to 224):
43
+ The size (resolution) of each image.
44
+ patch_size (`int`, *optional*, defaults to 16):
45
+ The size (resolution) of each patch.
46
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
47
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
48
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
49
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
50
+ The epsilon used by the layer normalization layers.
51
+ attention_dropout (`float`, *optional*, defaults to 0.0):
52
+ The dropout ratio for the attention probabilities.
53
+
54
+ Example:
55
+
56
+ ```python
57
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
58
+
59
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
60
+ >>> configuration = SiglipVisionConfig()
61
+
62
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
63
+ >>> model = SiglipVisionModel(configuration)
64
+
65
+ >>> # Accessing the model configuration
66
+ >>> configuration = model.config
67
+ ```"""
68
+
69
+ model_type = "siglip_vision_model"
70
+
71
+ def __init__(
72
+ self,
73
+ hidden_size=768,
74
+ intermediate_size=3072,
75
+ num_hidden_layers=12,
76
+ num_attention_heads=12,
77
+ num_channels=3,
78
+ image_size=224,
79
+ patch_size=16,
80
+ hidden_act="gelu_pytorch_tanh",
81
+ layer_norm_eps=1e-6,
82
+ attention_dropout=0.0,
83
+ rope=True,
84
+ **kwargs,
85
+ ):
86
+ super().__init__(
87
+ hidden_size=hidden_size,
88
+ intermediate_size=intermediate_size,
89
+ num_hidden_layers=num_hidden_layers,
90
+ num_attention_heads=num_attention_heads,
91
+ num_channels=num_channels,
92
+ image_size=image_size,
93
+ patch_size=patch_size,
94
+ hidden_act=hidden_act,
95
+ layer_norm_eps=layer_norm_eps,
96
+ attention_dropout=attention_dropout,
97
+ **kwargs)
98
+
99
+ self.rope = rope
100
+
101
+
102
+ class RotaryEmbedding2D(torch.nn.Module):
103
+ def __init__(self, dim, max_h, max_w, base=10000):
104
+ super().__init__()
105
+ freq = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
106
+ inv_freq = 1.0 / (base ** freq)
107
+
108
+ grid_h = torch.arange(0, max_h)
109
+ grid_h = grid_h.to(inv_freq.dtype)
110
+ grid_h = grid_h[:, None].repeat(1, max_w)
111
+
112
+ grid_w = torch.arange(0, max_w)
113
+ grid_w = grid_w.to(inv_freq.dtype)
114
+ grid_w = grid_w[None, :].repeat(max_h, 1)
115
+
116
+ cos_h, sin_h = self._forward_one_side(grid_h, inv_freq)
117
+ cos_w, sin_w = self._forward_one_side(grid_w, inv_freq)
118
+
119
+ self.register_buffer("cos_h", cos_h)
120
+ self.register_buffer("sin_h", sin_h)
121
+ self.register_buffer("cos_w", cos_w)
122
+ self.register_buffer("sin_w", sin_w)
123
+
124
+ def _forward_one_side(self, grid, inv_freq):
125
+ freqs = grid[..., None] * inv_freq[None, None, :]
126
+ emb = torch.cat((freqs, freqs), dim=-1).flatten(0, 1)
127
+ return emb.cos(), emb.sin()
128
+
129
+
130
+ def rotate_half(x):
131
+ x1 = x[..., : x.shape[-1] // 2]
132
+ x2 = x[..., x.shape[-1] // 2 :]
133
+ return torch.cat((-x2, x1), dim=-1)
134
+
135
+
136
+ def apply_rotary_pos_emb(q, k, cos, sin):
137
+ # unsqueeze so cos/sin broadcast over the head dimension
138
+ cos = cos.unsqueeze(1)
139
+ sin = sin.unsqueeze(1)
140
+ q_embed = (q * cos) + (rotate_half(q) * sin)
141
+ k_embed = (k * cos) + (rotate_half(k) * sin)
142
+ return q_embed, k_embed
143
+
144
+
145
+ class SiglipVisionEmbeddings(nn.Module):
146
+ def __init__(self, config: SiglipVisionConfig):
147
+ super().__init__()
148
+ self.config = config
149
+ self.embed_dim = config.hidden_size
150
+ self.image_size = config.image_size
151
+ self.patch_size = config.patch_size
152
+
153
+ self.patch_embedding = nn.Conv2d(
154
+ in_channels=config.num_channels,
155
+ out_channels=self.embed_dim,
156
+ kernel_size=self.patch_size,
157
+ stride=self.patch_size,
158
+ padding="valid",
159
+ )
160
+
161
+ self.num_patches_per_side = self.image_size // self.patch_size
162
+ self.num_patches = self.num_patches_per_side**2
163
+ self.num_positions = self.num_patches
164
+ if not config.rope:
165
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
166
+
167
+ def convert_conv2d_to_linear(self, config, meta=False):
168
+ if meta:
169
+ linear_patch_embedding = nn.Linear(
170
+ config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True, device='meta'
171
+ )
172
+ else:
173
+ linear_patch_embedding = nn.Linear(
174
+ config.num_channels * self.patch_size ** 2, self.embed_dim, bias=True
175
+ )
176
+ W = self.patch_embedding.weight.permute(0, 2, 3, 1).reshape(
177
+ self.embed_dim, config.num_channels * self.patch_size ** 2
178
+ )
179
+ linear_patch_embedding.weight.data = W
180
+ linear_patch_embedding.bias.data = self.patch_embedding.bias.data
181
+ del self.patch_embedding
182
+ self.patch_embedding = linear_patch_embedding
183
+
184
+ def forward(
185
+ self,
186
+ packed_pixel_values: torch.FloatTensor,
187
+ packed_flattened_position_ids: torch.LongTensor
188
+ ) -> torch.Tensor:
189
+
190
+ patch_embeds = self.patch_embedding(packed_pixel_values)
191
+ if not self.config.rope:
192
+ embeddings = patch_embeds + self.position_embedding(packed_flattened_position_ids)
193
+ else:
194
+ embeddings = patch_embeds
195
+ return embeddings
196
+
197
+
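The `rope` switch added to `SiglipVisionConfig` above is consumed by `SiglipVisionEmbeddings`: with `rope=False` a learned position table is created and added to the patch embeddings, while with `rope=True` (the default here) positions are injected later through the 2D rotary path. A minimal sketch, not part of the commit, with illustrative sizes and assuming the classes defined in this file are in scope:

```python
# Sketch: the `rope` flag toggles between a learned absolute position table
# (rope=False) and the 2D rotary path handled later in the attention layers.
cfg_learned = SiglipVisionConfig(image_size=224, patch_size=16, rope=False)
cfg_rope = SiglipVisionConfig(image_size=224, patch_size=16)   # rope=True by default

emb_learned = SiglipVisionEmbeddings(cfg_learned)   # creates .position_embedding
emb_rope = SiglipVisionEmbeddings(cfg_rope)         # no learned table at all

print(hasattr(emb_learned, "position_embedding"))   # True
print(hasattr(emb_rope, "position_embedding"))      # False
```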
198
+ class SiglipFlashAttention2(SiglipAttention):
199
+ def __init__(self, *args, **kwargs):
200
+ super().__init__(*args, **kwargs)
201
+
202
+ def forward(
203
+ self,
204
+ hidden_states: torch.Tensor,
205
+ cu_seqlens: torch.IntTensor,
206
+ max_seqlen: int,
207
+ cos_h: torch.Tensor = None,
208
+ sin_h: torch.Tensor = None,
209
+ cos_w: torch.Tensor = None,
210
+ sin_w: torch.Tensor = None,
211
+ **kwargs,
212
+ ) -> torch.Tensor:
213
+
214
+ total_q_len, _ = hidden_states.size()
215
+
216
+ query_states = self.q_proj(hidden_states)
217
+ key_states = self.k_proj(hidden_states)
218
+ value_states = self.v_proj(hidden_states)
219
+
220
+ query_states = query_states.view(total_q_len, self.num_heads, self.head_dim)
221
+ key_states = key_states.view(total_q_len, self.num_heads, self.head_dim)
222
+ value_states = value_states.view(total_q_len, self.num_heads, self.head_dim)
223
+
224
+ if self.config.rope:
225
+ qh, qw = query_states[:, :, :self.head_dim // 2], query_states[:, :, self.head_dim // 2:]
226
+ kh, kw = key_states[:, :, :self.head_dim // 2], key_states[:, :, self.head_dim // 2:]
227
+ qh, kh = apply_rotary_pos_emb(qh, kh, cos_h, sin_h)
228
+ qw, kw = apply_rotary_pos_emb(qw, kw, cos_w, sin_w)
229
+ query_states = torch.cat([qh, qw], dim=-1)
230
+ key_states = torch.cat([kh, kw], dim=-1)
231
+
232
+ attn_output = flash_attn_varlen_func(
233
+ query_states.to(torch.bfloat16),
234
+ key_states.to(torch.bfloat16),
235
+ value_states.to(torch.bfloat16),
236
+ cu_seqlens_q=cu_seqlens,
237
+ cu_seqlens_k=cu_seqlens,
238
+ max_seqlen_q=max_seqlen,
239
+ max_seqlen_k=max_seqlen,
240
+ causal=False,
241
+ )
242
+
243
+ attn_output = self.out_proj(attn_output.reshape(total_q_len, -1))
244
+ return attn_output
245
+
246
+
247
+ class SiglipMLP(nn.Module):
248
+ def __init__(self, config):
249
+ super().__init__()
250
+ self.config = config
251
+ self.activation_fn = ACT2FN[config.hidden_act]
252
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
253
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
254
+
255
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
256
+ hidden_states = self.fc1(hidden_states)
257
+ hidden_states = self.activation_fn(hidden_states)
258
+ hidden_states = self.fc2(hidden_states)
259
+ return hidden_states
260
+
261
+
262
+ class SiglipEncoderLayer(nn.Module):
263
+ def __init__(self, config: SiglipVisionConfig):
264
+ super().__init__()
265
+ self.embed_dim = config.hidden_size
266
+ self.self_attn = SiglipFlashAttention2(config)
267
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
268
+ self.mlp = SiglipMLP(config)
269
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
270
+
271
+ def forward(
272
+ self,
273
+ hidden_states: torch.Tensor,
274
+ cu_seqlens: torch.IntTensor,
275
+ max_seqlen: int,
276
+ cos_h: torch.Tensor = None,
277
+ sin_h: torch.Tensor = None,
278
+ cos_w: torch.Tensor = None,
279
+ sin_w: torch.Tensor = None
280
+ ) -> torch.Tensor:
281
+ residual = hidden_states
282
+
283
+ hidden_states = self.layer_norm1(hidden_states)
284
+ hidden_states = self.self_attn(
285
+ hidden_states=hidden_states,
286
+ cu_seqlens=cu_seqlens,
287
+ max_seqlen=max_seqlen,
288
+ cos_h=cos_h,
289
+ sin_h=sin_h,
290
+ cos_w=cos_w,
291
+ sin_w=sin_w
292
+ )
293
+ hidden_states = residual + hidden_states
294
+
295
+ residual = hidden_states
296
+ hidden_states = self.layer_norm2(hidden_states)
297
+ hidden_states = self.mlp(hidden_states)
298
+ hidden_states = residual + hidden_states
299
+
300
+ return hidden_states
301
+
302
+
303
+ class SiglipEncoder(nn.Module):
304
+ def __init__(self, config: SiglipVisionConfig):
305
+ super().__init__()
306
+ self.config = config
307
+ self.layers = nn.ModuleList(
308
+ [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
309
+ )
310
+
311
+ def forward(
312
+ self,
313
+ inputs_embeds: torch.Tensor,
314
+ cu_seqlens: torch.IntTensor,
315
+ max_seqlen: int,
316
+ cos_h: torch.Tensor = None,
317
+ sin_h: torch.Tensor = None,
318
+ cos_w: torch.Tensor = None,
319
+ sin_w: torch.Tensor = None,
320
+ ) -> torch.Tensor:
321
+
322
+ hidden_states = inputs_embeds
323
+ for encoder_layer in self.layers:
324
+ hidden_states = encoder_layer(hidden_states, cu_seqlens, max_seqlen,
325
+ cos_h=cos_h, sin_h=sin_h, cos_w=cos_w, sin_w=sin_w)
326
+
327
+ return hidden_states
328
+
329
+
330
+ class SiglipVisionTransformer(nn.Module):
331
+ def __init__(self, config: SiglipVisionConfig):
332
+ super().__init__()
333
+ self.config = config
334
+ embed_dim = config.hidden_size
335
+
336
+ self.embeddings = SiglipVisionEmbeddings(config)
337
+ if config.rope:
338
+ max_size = config.image_size // config.patch_size
339
+ dim_head = config.hidden_size // config.num_attention_heads
340
+ self.rope = RotaryEmbedding2D(dim_head // 2, max_size, max_size)
341
+
342
+ self.encoder = SiglipEncoder(config)
343
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
344
+
345
+ def forward(
346
+ self,
347
+ packed_pixel_values: torch.Tensor,
348
+ packed_flattened_position_ids: torch.LongTensor,
349
+ cu_seqlens: torch.IntTensor,
350
+ max_seqlen: int,
351
+ ) -> torch.Tensor:
352
+ hidden_states = self.embeddings(
353
+ packed_pixel_values=packed_pixel_values,
354
+ packed_flattened_position_ids=packed_flattened_position_ids
355
+ )
356
+
357
+ extra_inputs = {}
358
+ if self.config.rope:
359
+ extra_inputs.update(
360
+ cos_h = self.rope.cos_h[packed_flattened_position_ids],
361
+ sin_h = self.rope.sin_h[packed_flattened_position_ids],
362
+ cos_w = self.rope.cos_w[packed_flattened_position_ids],
363
+ sin_w = self.rope.sin_w[packed_flattened_position_ids]
364
+ )
365
+
366
+ last_hidden_state = self.encoder(
367
+ inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
368
+ **extra_inputs
369
+ )
370
+ last_hidden_state = self.post_layernorm(last_hidden_state)
371
+ return last_hidden_state
372
+
373
+
374
+ class SiglipVisionModel(SiglipPreTrainedModel):
375
+ config_class = SiglipVisionConfig
376
+ main_input_name = "packed_pixel_values"
377
+
378
+ def __init__(self, config: SiglipVisionConfig):
379
+ super().__init__(config)
380
+
381
+ self.vision_model = SiglipVisionTransformer(config)
382
+
383
+ # Initialize weights and apply final processing
384
+ self.post_init()
385
+
386
+ def get_input_embeddings(self) -> nn.Module:
387
+ return self.vision_model.embeddings.patch_embedding
388
+
389
+ def forward(
390
+ self,
391
+ packed_pixel_values: torch.Tensor,
392
+ packed_flattened_position_ids: torch.LongTensor,
393
+ cu_seqlens: torch.IntTensor,
394
+ max_seqlen: int,
395
+ ) -> torch.Tensor:
396
+
397
+ return self.vision_model(
398
+ packed_pixel_values=packed_pixel_values,
399
+ packed_flattened_position_ids=packed_flattened_position_ids,
400
+ cu_seqlens=cu_seqlens,
401
+ max_seqlen=max_seqlen,
402
+ )
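To make the 2D rotary machinery above concrete, here is a minimal sketch, not part of the commit, of how the `RotaryEmbedding2D` buffers are gathered with flattened patch positions and applied to the height/width halves of each head, mirroring what `SiglipFlashAttention2.forward` does. All sizes are illustrative, and the functions defined in this file are assumed to be in scope:

```python
import torch

# Illustrative sizes: 4 heads of dim 64, a 70x70 patch grid, 10 packed patches.
head_dim, num_heads, max_side, total_tokens = 64, 4, 70, 10
rope = RotaryEmbedding2D(head_dim // 2, max_side, max_side)

pos_ids = torch.randint(0, max_side * max_side, (total_tokens,))
q = torch.randn(total_tokens, num_heads, head_dim)
k = torch.randn(total_tokens, num_heads, head_dim)

# Gather per-token rows from the (max_h * max_w, head_dim // 2) cos/sin tables.
cos_h, sin_h = rope.cos_h[pos_ids], rope.sin_h[pos_ids]
cos_w, sin_w = rope.cos_w[pos_ids], rope.sin_w[pos_ids]

# First half of head_dim rotates with the row index, second half with the column.
qh, qw = q[..., : head_dim // 2], q[..., head_dim // 2 :]
kh, kw = k[..., : head_dim // 2], k[..., head_dim // 2 :]
qh, kh = apply_rotary_pos_emb(qh, kh, cos_h, sin_h)
qw, kw = apply_rotary_pos_emb(qw, kw, cos_w, sin_w)
q_rot = torch.cat([qh, qw], dim=-1)
k_rot = torch.cat([kh, kw], dim=-1)
print(q_rot.shape, k_rot.shape)   # torch.Size([10, 4, 64]) torch.Size([10, 4, 64])
```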
modeling/qwen2/__init__.py CHANGED
@@ -1,68 +1,68 @@
1
- # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from typing import TYPE_CHECKING
5
-
6
- from transformers.utils import (
7
- OptionalDependencyNotAvailable,
8
- _LazyModule,
9
- is_tokenizers_available,
10
- is_torch_available,
11
- )
12
-
13
-
14
- _import_structure = {
15
- "configuration_qwen2": ["Qwen2Config"],
16
- "tokenization_qwen2": ["Qwen2Tokenizer"],
17
- }
18
-
19
- try:
20
- if not is_tokenizers_available():
21
- raise OptionalDependencyNotAvailable()
22
- except OptionalDependencyNotAvailable:
23
- pass
24
- else:
25
- _import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]
26
-
27
- try:
28
- if not is_torch_available():
29
- raise OptionalDependencyNotAvailable()
30
- except OptionalDependencyNotAvailable:
31
- pass
32
- else:
33
- _import_structure["modeling_qwen2"] = [
34
- "Qwen2ForCausalLM",
35
- "Qwen2Model",
36
- "Qwen2PreTrainedModel",
37
- ]
38
-
39
-
40
- if TYPE_CHECKING:
41
- from .configuration_qwen2 import Qwen2Config
42
- from .tokenization_qwen2 import Qwen2Tokenizer
43
-
44
- try:
45
- if not is_tokenizers_available():
46
- raise OptionalDependencyNotAvailable()
47
- except OptionalDependencyNotAvailable:
48
- pass
49
- else:
50
- from .tokenization_qwen2_fast import Qwen2TokenizerFast
51
-
52
- try:
53
- if not is_torch_available():
54
- raise OptionalDependencyNotAvailable()
55
- except OptionalDependencyNotAvailable:
56
- pass
57
- else:
58
- from .modeling_qwen2 import (
59
- Qwen2ForCausalLM,
60
- Qwen2Model,
61
- Qwen2PreTrainedModel,
62
- )
63
-
64
-
65
- else:
66
- import sys
67
-
68
- sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
 
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ from transformers.utils import (
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_tokenizers_available,
10
+ is_torch_available,
11
+ )
12
+
13
+
14
+ _import_structure = {
15
+ "configuration_qwen2": ["Qwen2Config"],
16
+ "tokenization_qwen2": ["Qwen2Tokenizer"],
17
+ }
18
+
19
+ try:
20
+ if not is_tokenizers_available():
21
+ raise OptionalDependencyNotAvailable()
22
+ except OptionalDependencyNotAvailable:
23
+ pass
24
+ else:
25
+ _import_structure["tokenization_qwen2_fast"] = ["Qwen2TokenizerFast"]
26
+
27
+ try:
28
+ if not is_torch_available():
29
+ raise OptionalDependencyNotAvailable()
30
+ except OptionalDependencyNotAvailable:
31
+ pass
32
+ else:
33
+ _import_structure["modeling_qwen2"] = [
34
+ "Qwen2ForCausalLM",
35
+ "Qwen2Model",
36
+ "Qwen2PreTrainedModel",
37
+ ]
38
+
39
+
40
+ if TYPE_CHECKING:
41
+ from .configuration_qwen2 import Qwen2Config
42
+ from .tokenization_qwen2 import Qwen2Tokenizer
43
+
44
+ try:
45
+ if not is_tokenizers_available():
46
+ raise OptionalDependencyNotAvailable()
47
+ except OptionalDependencyNotAvailable:
48
+ pass
49
+ else:
50
+ from .tokenization_qwen2_fast import Qwen2TokenizerFast
51
+
52
+ try:
53
+ if not is_torch_available():
54
+ raise OptionalDependencyNotAvailable()
55
+ except OptionalDependencyNotAvailable:
56
+ pass
57
+ else:
58
+ from .modeling_qwen2 import (
59
+ Qwen2ForCausalLM,
60
+ Qwen2Model,
61
+ Qwen2PreTrainedModel,
62
+ )
63
+
64
+
65
+ else:
66
+ import sys
67
+
68
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
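Because the package above only registers a `_LazyModule`, the heavy, torch-dependent submodules are imported on first attribute access rather than at package import time. A small usage sketch, not part of the commit:

```python
# Sketch: attribute access on the lazy module pulls in only the submodule
# that defines the requested name (see _import_structure above).
import modeling.qwen2 as qwen2          # cheap: no submodule is imported yet

config = qwen2.Qwen2Config()            # imports configuration_qwen2 on first use
tokenizer_cls = qwen2.Qwen2Tokenizer    # imports tokenization_qwen2
model_cls = qwen2.Qwen2ForCausalLM      # imports the torch-dependent modeling_qwen2
print(config.model_type)                # "qwen2"
```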
modeling/qwen2/configuration_qwen2.py CHANGED
@@ -1,179 +1,179 @@
1
- # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """Qwen2 model configuration"""
5
-
6
- from transformers.configuration_utils import PretrainedConfig
7
- from transformers.modeling_rope_utils import rope_config_validation
8
- from transformers.utils import logging
9
-
10
-
11
- logger = logging.get_logger(__name__)
12
-
13
-
14
- class Qwen2Config(PretrainedConfig):
15
- r"""
16
- This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
17
- Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
18
- with the defaults will yield a similar configuration to that of
19
- Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
20
-
21
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
- documentation from [`PretrainedConfig`] for more information.
23
-
24
-
25
- Args:
26
- vocab_size (`int`, *optional*, defaults to 151936):
27
- Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
28
- `inputs_ids` passed when calling [`Qwen2Model`]
29
- hidden_size (`int`, *optional*, defaults to 4096):
30
- Dimension of the hidden representations.
31
- intermediate_size (`int`, *optional*, defaults to 22016):
32
- Dimension of the MLP representations.
33
- num_hidden_layers (`int`, *optional*, defaults to 32):
34
- Number of hidden layers in the Transformer encoder.
35
- num_attention_heads (`int`, *optional*, defaults to 32):
36
- Number of attention heads for each attention layer in the Transformer encoder.
37
- num_key_value_heads (`int`, *optional*, defaults to 32):
38
- This is the number of key_value heads that should be used to implement Grouped Query Attention. If
39
- `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
40
- `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
41
- converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
42
- by meanpooling all the original heads within that group. For more details checkout [this
43
- paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
44
- hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
45
- The non-linear activation function (function or string) in the decoder.
46
- max_position_embeddings (`int`, *optional*, defaults to 32768):
47
- The maximum sequence length that this model might ever be used with.
48
- initializer_range (`float`, *optional*, defaults to 0.02):
49
- The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
50
- rms_norm_eps (`float`, *optional*, defaults to 1e-06):
51
- The epsilon used by the rms normalization layers.
52
- use_cache (`bool`, *optional*, defaults to `True`):
53
- Whether or not the model should return the last key/values attentions (not used by all models). Only
54
- relevant if `config.is_decoder=True`.
55
- tie_word_embeddings (`bool`, *optional*, defaults to `False`):
56
- Whether the model's input and output word embeddings should be tied.
57
- rope_theta (`float`, *optional*, defaults to 10000.0):
58
- The base period of the RoPE embeddings.
59
- rope_scaling (`Dict`, *optional*):
60
- Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
61
- and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
62
- accordingly.
63
- Expected contents:
64
- `rope_type` (`str`):
65
- The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
66
- 'llama3'], with 'default' being the original RoPE implementation.
67
- `factor` (`float`, *optional*):
68
- Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
69
- most scaling types, a `factor` of x will enable the model to handle sequences of length x *
70
- original maximum pre-trained length.
71
- `original_max_position_embeddings` (`int`, *optional*):
72
- Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
73
- pretraining.
74
- `attention_factor` (`float`, *optional*):
75
- Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
76
- computation. If unspecified, it defaults to value recommended by the implementation, using the
77
- `factor` field to infer the suggested value.
78
- `beta_fast` (`float`, *optional*):
79
- Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
80
- ramp function. If unspecified, it defaults to 32.
81
- `beta_slow` (`float`, *optional*):
82
- Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
83
- ramp function. If unspecified, it defaults to 1.
84
- `short_factor` (`List[float]`, *optional*):
85
- Only used with 'longrope'. The scaling factor to be applied to short contexts (<
86
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
87
- size divided by the number of attention heads divided by 2
88
- `long_factor` (`List[float]`, *optional*):
89
- Only used with 'longrope'. The scaling factor to be applied to long contexts (<
90
- `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
91
- size divided by the number of attention heads divided by 2
92
- `low_freq_factor` (`float`, *optional*):
93
- Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
94
- `high_freq_factor` (`float`, *optional*):
95
- Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
96
- use_sliding_window (`bool`, *optional*, defaults to `False`):
97
- Whether to use sliding window attention.
98
- sliding_window (`int`, *optional*, defaults to 4096):
99
- Sliding window attention (SWA) window size. If not specified, will default to `4096`.
100
- max_window_layers (`int`, *optional*, defaults to 28):
101
- The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
102
- attention_dropout (`float`, *optional*, defaults to 0.0):
103
- The dropout ratio for the attention probabilities.
104
-
105
- ```python
106
- >>> from transformers import Qwen2Model, Qwen2Config
107
-
108
- >>> # Initializing a Qwen2 style configuration
109
- >>> configuration = Qwen2Config()
110
-
111
- >>> # Initializing a model from the Qwen2-7B style configuration
112
- >>> model = Qwen2Model(configuration)
113
-
114
- >>> # Accessing the model configuration
115
- >>> configuration = model.config
116
- ```"""
117
-
118
- model_type = "qwen2"
119
- keys_to_ignore_at_inference = ["past_key_values"]
120
-
121
- def __init__(
122
- self,
123
- vocab_size=151936,
124
- hidden_size=4096,
125
- intermediate_size=22016,
126
- num_hidden_layers=32,
127
- num_attention_heads=32,
128
- num_key_value_heads=32,
129
- hidden_act="silu",
130
- max_position_embeddings=32768,
131
- initializer_range=0.02,
132
- rms_norm_eps=1e-6,
133
- use_cache=True,
134
- tie_word_embeddings=False,
135
- rope_theta=10000.0,
136
- rope_scaling=None,
137
- use_sliding_window=False,
138
- sliding_window=4096,
139
- max_window_layers=28,
140
- attention_dropout=0.0,
141
- is_causal=True,
142
- _attn_implementation="flash_attention_2",
143
- **kwargs,
144
- ):
145
- self.vocab_size = vocab_size
146
- self.max_position_embeddings = max_position_embeddings
147
- self.hidden_size = hidden_size
148
- self.intermediate_size = intermediate_size
149
- self.num_hidden_layers = num_hidden_layers
150
- self.num_attention_heads = num_attention_heads
151
- self.use_sliding_window = use_sliding_window
152
- self.sliding_window = sliding_window if use_sliding_window else None
153
- self.max_window_layers = max_window_layers
154
-
155
- # for backward compatibility
156
- if num_key_value_heads is None:
157
- num_key_value_heads = num_attention_heads
158
-
159
- self.num_key_value_heads = num_key_value_heads
160
- self.hidden_act = hidden_act
161
- self.initializer_range = initializer_range
162
- self.rms_norm_eps = rms_norm_eps
163
- self.use_cache = use_cache
164
- self.rope_theta = rope_theta
165
- self.rope_scaling = rope_scaling
166
- self.attention_dropout = attention_dropout
167
- self.is_causal = is_causal
168
- self._attn_implementation = _attn_implementation
169
-
170
- # Validate the correctness of rotary position embeddings parameters
171
- # BC: if there is a 'type' field, move it to 'rope_type'.
172
- if self.rope_scaling is not None and "type" in self.rope_scaling:
173
- self.rope_scaling["rope_type"] = self.rope_scaling["type"]
174
- rope_config_validation(self)
175
-
176
- super().__init__(
177
- tie_word_embeddings=tie_word_embeddings,
178
- **kwargs,
179
- )
 
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Qwen2 model configuration"""
5
+
6
+ from transformers.configuration_utils import PretrainedConfig
7
+ from transformers.modeling_rope_utils import rope_config_validation
8
+ from transformers.utils import logging
9
+
10
+
11
+ logger = logging.get_logger(__name__)
12
+
13
+
14
+ class Qwen2Config(PretrainedConfig):
15
+ r"""
16
+ This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
17
+ Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
18
+ with the defaults will yield a similar configuration to that of
19
+ Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
20
+
21
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
+ documentation from [`PretrainedConfig`] for more information.
23
+
24
+
25
+ Args:
26
+ vocab_size (`int`, *optional*, defaults to 151936):
27
+ Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
28
+ `inputs_ids` passed when calling [`Qwen2Model`]
29
+ hidden_size (`int`, *optional*, defaults to 4096):
30
+ Dimension of the hidden representations.
31
+ intermediate_size (`int`, *optional*, defaults to 22016):
32
+ Dimension of the MLP representations.
33
+ num_hidden_layers (`int`, *optional*, defaults to 32):
34
+ Number of hidden layers in the Transformer encoder.
35
+ num_attention_heads (`int`, *optional*, defaults to 32):
36
+ Number of attention heads for each attention layer in the Transformer encoder.
37
+ num_key_value_heads (`int`, *optional*, defaults to 32):
38
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
39
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
40
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
41
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
42
+ by meanpooling all the original heads within that group. For more details, check out [this
43
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
45
+ The non-linear activation function (function or string) in the decoder.
46
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
47
+ The maximum sequence length that this model might ever be used with.
48
+ initializer_range (`float`, *optional*, defaults to 0.02):
49
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
50
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
51
+ The epsilon used by the rms normalization layers.
52
+ use_cache (`bool`, *optional*, defaults to `True`):
53
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
54
+ relevant if `config.is_decoder=True`.
55
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
56
+ Whether the model's input and output word embeddings should be tied.
57
+ rope_theta (`float`, *optional*, defaults to 10000.0):
58
+ The base period of the RoPE embeddings.
59
+ rope_scaling (`Dict`, *optional*):
60
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
61
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
62
+ accordingly.
63
+ Expected contents:
64
+ `rope_type` (`str`):
65
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
66
+ 'llama3'], with 'default' being the original RoPE implementation.
67
+ `factor` (`float`, *optional*):
68
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
69
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
70
+ original maximum pre-trained length.
71
+ `original_max_position_embeddings` (`int`, *optional*):
72
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
73
+ pretraining.
74
+ `attention_factor` (`float`, *optional*):
75
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
76
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
77
+ `factor` field to infer the suggested value.
78
+ `beta_fast` (`float`, *optional*):
79
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
80
+ ramp function. If unspecified, it defaults to 32.
81
+ `beta_slow` (`float`, *optional*):
82
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
83
+ ramp function. If unspecified, it defaults to 1.
84
+ `short_factor` (`List[float]`, *optional*):
85
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
86
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
87
+ size divided by the number of attention heads divided by 2
88
+ `long_factor` (`List[float]`, *optional*):
89
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
90
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
91
+ size divided by the number of attention heads divided by 2
92
+ `low_freq_factor` (`float`, *optional*):
93
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
94
+ `high_freq_factor` (`float`, *optional*):
95
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
96
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
97
+ Whether to use sliding window attention.
98
+ sliding_window (`int`, *optional*, defaults to 4096):
99
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
100
+ max_window_layers (`int`, *optional*, defaults to 28):
101
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
102
+ attention_dropout (`float`, *optional*, defaults to 0.0):
103
+ The dropout ratio for the attention probabilities.
104
+
105
+ ```python
106
+ >>> from transformers import Qwen2Model, Qwen2Config
107
+
108
+ >>> # Initializing a Qwen2 style configuration
109
+ >>> configuration = Qwen2Config()
110
+
111
+ >>> # Initializing a model from the Qwen2-7B style configuration
112
+ >>> model = Qwen2Model(configuration)
113
+
114
+ >>> # Accessing the model configuration
115
+ >>> configuration = model.config
116
+ ```"""
117
+
118
+ model_type = "qwen2"
119
+ keys_to_ignore_at_inference = ["past_key_values"]
120
+
121
+ def __init__(
122
+ self,
123
+ vocab_size=151936,
124
+ hidden_size=4096,
125
+ intermediate_size=22016,
126
+ num_hidden_layers=32,
127
+ num_attention_heads=32,
128
+ num_key_value_heads=32,
129
+ hidden_act="silu",
130
+ max_position_embeddings=32768,
131
+ initializer_range=0.02,
132
+ rms_norm_eps=1e-6,
133
+ use_cache=True,
134
+ tie_word_embeddings=False,
135
+ rope_theta=10000.0,
136
+ rope_scaling=None,
137
+ use_sliding_window=False,
138
+ sliding_window=4096,
139
+ max_window_layers=28,
140
+ attention_dropout=0.0,
141
+ is_causal=True,
142
+ _attn_implementation="flash_attention_2",
143
+ **kwargs,
144
+ ):
145
+ self.vocab_size = vocab_size
146
+ self.max_position_embeddings = max_position_embeddings
147
+ self.hidden_size = hidden_size
148
+ self.intermediate_size = intermediate_size
149
+ self.num_hidden_layers = num_hidden_layers
150
+ self.num_attention_heads = num_attention_heads
151
+ self.use_sliding_window = use_sliding_window
152
+ self.sliding_window = sliding_window if use_sliding_window else None
153
+ self.max_window_layers = max_window_layers
154
+
155
+ # for backward compatibility
156
+ if num_key_value_heads is None:
157
+ num_key_value_heads = num_attention_heads
158
+
159
+ self.num_key_value_heads = num_key_value_heads
160
+ self.hidden_act = hidden_act
161
+ self.initializer_range = initializer_range
162
+ self.rms_norm_eps = rms_norm_eps
163
+ self.use_cache = use_cache
164
+ self.rope_theta = rope_theta
165
+ self.rope_scaling = rope_scaling
166
+ self.attention_dropout = attention_dropout
167
+ self.is_causal = is_causal
168
+ self._attn_implementation = _attn_implementation
169
+
170
+ # Validate the correctness of rotary position embeddings parameters
171
+ # BC: if there is a 'type' field, move it to 'rope_type'.
172
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
173
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
174
+ rope_config_validation(self)
175
+
176
+ super().__init__(
177
+ tie_word_embeddings=tie_word_embeddings,
178
+ **kwargs,
179
+ )
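Two behaviours of the constructor above are easy to miss: `sliding_window` is only kept when `use_sliding_window=True`, and a legacy `"type"` key inside `rope_scaling` is migrated to `"rope_type"` before validation. A hedged sketch, not part of the commit:

```python
# Sketch: defaults plus a legacy-style rope_scaling dict.
config = Qwen2Config(
    use_sliding_window=False,                      # sliding_window is then stored as None
    rope_scaling={"type": "yarn", "factor": 2.0},  # old-style key
)
print(config.sliding_window)                       # None
print(config.rope_scaling["rope_type"])            # "yarn" (migrated from "type")
print(config.model_type)                           # "qwen2"
```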
modeling/qwen2/modeling_qwen2.py CHANGED
@@ -1,929 +1,929 @@
1
- # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """PyTorch Qwen2 model."""
5
-
6
- import math
7
- from typing import List, Optional, Tuple, Union
8
-
9
- import torch
10
- import torch.utils.checkpoint
11
- from torch import nn
12
-
13
- from transformers.activations import ACT2FN
14
- from transformers.cache_utils import Cache, DynamicCache
15
- from transformers.generation import GenerationMixin
16
- from transformers.modeling_outputs import (
17
- BaseModelOutputWithPast,
18
- CausalLMOutputWithPast,
19
- )
20
- from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
21
- from transformers.modeling_utils import PreTrainedModel
22
- from transformers.utils import (
23
- add_start_docstrings,
24
- add_start_docstrings_to_model_forward,
25
- is_flash_attn_2_available,
26
- is_flash_attn_greater_or_equal_2_10,
27
- logging,
28
- replace_return_docstrings,
29
- )
30
- from .configuration_qwen2 import Qwen2Config
31
-
32
-
33
- if is_flash_attn_2_available():
34
- from transformers.modeling_flash_attention_utils import _flash_attention_forward
35
-
36
-
37
- logger = logging.get_logger(__name__)
38
-
39
-
40
- _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
41
- _CONFIG_FOR_DOC = "Qwen2Config"
42
-
43
-
44
- # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
45
- class Qwen2RMSNorm(nn.Module):
46
- def __init__(self, hidden_size, eps=1e-6):
47
- """
48
- Qwen2RMSNorm is equivalent to T5LayerNorm
49
- """
50
- super().__init__()
51
- self.weight = nn.Parameter(torch.ones(hidden_size))
52
- self.variance_epsilon = eps
53
-
54
- def forward(self, hidden_states):
55
- input_dtype = hidden_states.dtype
56
- hidden_states = hidden_states.to(torch.float32)
57
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
58
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
59
- return self.weight * hidden_states.to(input_dtype)
60
-
61
- def extra_repr(self):
62
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
63
-
64
-
65
- # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
66
- class Qwen2RotaryEmbedding(nn.Module):
67
- def __init__(
68
- self,
69
- dim=None,
70
- max_position_embeddings=2048,
71
- base=10000,
72
- device=None,
73
- scaling_factor=1.0,
74
- rope_type="default",
75
- config: Optional[Qwen2Config] = None,
76
- ):
77
- super().__init__()
78
- # TODO (joao): remove the `if` below, only used for BC
79
- self.rope_kwargs = {}
80
- if config is None:
81
- logger.warning_once(
82
- "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
83
- "`config` argument. All other arguments will be removed in v4.46"
84
- )
85
- self.rope_kwargs = {
86
- "rope_type": rope_type,
87
- "factor": scaling_factor,
88
- "dim": dim,
89
- "base": base,
90
- "max_position_embeddings": max_position_embeddings,
91
- }
92
- self.rope_type = rope_type
93
- self.max_seq_len_cached = max_position_embeddings
94
- self.original_max_seq_len = max_position_embeddings
95
- else:
96
- # BC: "rope_type" was originally "type"
97
- if config.rope_scaling is not None:
98
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
99
- else:
100
- self.rope_type = "default"
101
- self.max_seq_len_cached = config.max_position_embeddings
102
- self.original_max_seq_len = config.max_position_embeddings
103
-
104
- self.config = config
105
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
106
-
107
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
108
- self.register_buffer("inv_freq", inv_freq, persistent=False)
109
- self.original_inv_freq = self.inv_freq
110
-
111
- def _dynamic_frequency_update(self, position_ids, device):
112
- """
113
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
114
- 1 - growing beyond the cached sequence length (allow scaling)
115
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
116
- """
117
- seq_len = torch.max(position_ids) + 1
118
- if seq_len > self.max_seq_len_cached: # growth
119
- inv_freq, self.attention_scaling = self.rope_init_fn(
120
- self.config, device, seq_len=seq_len, **self.rope_kwargs
121
- )
122
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
123
- self.max_seq_len_cached = seq_len
124
-
125
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
126
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
127
- self.max_seq_len_cached = self.original_max_seq_len
128
-
129
- @torch.no_grad()
130
- def forward(self, x, position_ids):
131
- if "dynamic" in self.rope_type:
132
- self._dynamic_frequency_update(position_ids, device=x.device)
133
-
134
- # Core RoPE block
135
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
136
- position_ids_expanded = position_ids[:, None, :].float()
137
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
138
- device_type = x.device.type
139
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
140
- with torch.autocast(device_type=device_type, enabled=False):
141
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
142
- emb = torch.cat((freqs, freqs), dim=-1)
143
- cos = emb.cos()
144
- sin = emb.sin()
145
-
146
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
147
- cos = cos * self.attention_scaling
148
- sin = sin * self.attention_scaling
149
-
150
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
151
-
152
-
153
- # Copied from transformers.models.llama.modeling_llama.rotate_half
154
- def rotate_half(x):
155
- """Rotates half the hidden dims of the input."""
156
- x1 = x[..., : x.shape[-1] // 2]
157
- x2 = x[..., x.shape[-1] // 2 :]
158
- return torch.cat((-x2, x1), dim=-1)
159
-
160
-
161
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
162
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
163
- """Applies Rotary Position Embedding to the query and key tensors.
164
-
165
- Args:
166
- q (`torch.Tensor`): The query tensor.
167
- k (`torch.Tensor`): The key tensor.
168
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
169
- sin (`torch.Tensor`): The sine part of the rotary embedding.
170
- position_ids (`torch.Tensor`, *optional*):
171
- Deprecated and unused.
172
- unsqueeze_dim (`int`, *optional*, defaults to 1):
173
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
174
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
175
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
176
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
177
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
178
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
179
- Returns:
180
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
181
- """
182
- cos = cos.unsqueeze(unsqueeze_dim)
183
- sin = sin.unsqueeze(unsqueeze_dim)
184
- q_embed = (q * cos) + (rotate_half(q) * sin)
185
- k_embed = (k * cos) + (rotate_half(k) * sin)
186
- return q_embed, k_embed
187
-
188
-
189
- # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
190
- class Qwen2MLP(nn.Module):
191
- def __init__(self, config):
192
- super().__init__()
193
- self.hidden_size = config.hidden_size
194
- self.intermediate_size = config.intermediate_size
195
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
196
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
197
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
198
- self.act_fn = ACT2FN[config.hidden_act]
199
-
200
- def forward(self, hidden_state):
201
- return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
202
-
203
-
204
- # Copied from transformers.models.llama.modeling_llama.repeat_kv
205
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
206
- """
207
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
208
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
209
- """
210
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
211
- if n_rep == 1:
212
- return hidden_states
213
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
214
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
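As the docstring above notes, `repeat_kv` is the reshape-based equivalent of `torch.repeat_interleave` over the head dimension, used to expand grouped key/value heads to the full number of attention heads. A tiny shape check, not part of the commit and with illustrative sizes:

```python
import torch

kv = torch.randn(2, 4, 10, 64)          # (batch, num_kv_heads, seq, head_dim)
out = repeat_kv(kv, n_rep=2)            # -> (batch, 8, seq, head_dim)
print(out.shape)                        # torch.Size([2, 8, 10, 64])
assert torch.equal(out, torch.repeat_interleave(kv, repeats=2, dim=1))
```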
215
-
216
-
217
- class Qwen2Attention(nn.Module):
218
- """
219
- Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
220
- and "Generating Long Sequences with Sparse Transformers".
221
- """
222
-
223
- def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
224
- super().__init__()
225
- self.config = config
226
- self.layer_idx = layer_idx
227
- if layer_idx is None:
228
- logger.warning_once(
229
- f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
230
- "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
231
- "when creating this class."
232
- )
233
-
234
- self.hidden_size = config.hidden_size
235
- self.num_heads = config.num_attention_heads
236
- self.head_dim = self.hidden_size // self.num_heads
237
- self.num_key_value_heads = config.num_key_value_heads
238
- self.num_key_value_groups = self.num_heads // self.num_key_value_heads
239
- self.max_position_embeddings = config.max_position_embeddings
240
- self.rope_theta = config.rope_theta
241
- self.is_causal = config.is_causal
242
- self.attention_dropout = config.attention_dropout
243
-
244
- if (self.head_dim * self.num_heads) != self.hidden_size:
245
- raise ValueError(
246
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
247
- f" and `num_heads`: {self.num_heads})."
248
- )
249
- self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
250
- self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
251
- self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
252
- self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
253
-
254
- def forward(
255
- self,
256
- hidden_states: torch.Tensor,
257
- attention_mask: Optional[torch.Tensor] = None,
258
- position_ids: Optional[torch.LongTensor] = None,
259
- past_key_value: Optional[Cache] = None,
260
- output_attentions: bool = False,
261
- use_cache: bool = False,
262
- cache_position: Optional[torch.LongTensor] = None,
263
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
264
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
265
- bsz, q_len, _ = hidden_states.size()
266
-
267
- query_states = self.q_proj(hidden_states)
268
- key_states = self.k_proj(hidden_states)
269
- value_states = self.v_proj(hidden_states)
270
-
271
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
272
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
273
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
274
-
275
- if position_embeddings is None:
276
- logger.warning_once(
277
- "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
278
- "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
279
- "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
280
- "removed and `position_embeddings` will be mandatory."
281
- )
282
- cos, sin = self.rotary_emb(value_states, position_ids)
283
- else:
284
- cos, sin = position_embeddings
285
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
286
-
287
- if past_key_value is not None:
288
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
289
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
290
-
291
- # repeat k/v heads if n_kv_heads < n_heads
292
- key_states = repeat_kv(key_states, self.num_key_value_groups)
293
- value_states = repeat_kv(value_states, self.num_key_value_groups)
294
-
295
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
296
- if attention_mask is not None: # no matter the length, we just slice it
297
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
298
- attn_weights = attn_weights + causal_mask
299
-
300
- # upcast attention to fp32
301
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
302
- attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
303
- attn_output = torch.matmul(attn_weights, value_states)
304
-
305
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
306
- raise ValueError(
307
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
308
- f" {attn_output.size()}"
309
- )
310
-
311
- attn_output = attn_output.transpose(1, 2).contiguous()
312
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
313
-
314
- attn_output = self.o_proj(attn_output)
315
-
316
- if not output_attentions:
317
- attn_weights = None
318
-
319
- return attn_output, attn_weights, past_key_value
320
-
321
-
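The eager forward above is plain scaled dot-product attention with grouped-query key/value heads broadcast up to the query head count before the matmuls. A self-contained sketch of that math, for orientation only; the head counts below are assumed, Qwen2-7B-like values, and repeat_interleave stands in for the repeat_kv helper used by the real module:

import math
import torch

def eager_gqa_attention(q, k, v, n_rep, causal_mask=None):
    # q: (bsz, num_heads, q_len, head_dim); k, v: (bsz, num_kv_heads, kv_len, head_dim)
    k = k.repeat_interleave(n_rep, dim=1)   # same effect as repeat_kv
    v = v.repeat_interleave(n_rep, dim=1)
    scores = q @ k.transpose(2, 3) / math.sqrt(q.shape[-1])
    if causal_mask is not None:             # additive mask with -inf on disallowed positions
        scores = scores + causal_mask
    probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
    return probs @ v                        # (bsz, num_heads, q_len, head_dim)

q = torch.randn(1, 28, 8, 128)              # assumed: 28 query heads, head_dim 128
k = v = torch.randn(1, 4, 8, 128)           # assumed: 4 KV heads, so n_rep = 7
assert eager_gqa_attention(q, k, v, n_rep=7).shape == (1, 28, 8, 128)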
322
- class Qwen2FlashAttention2(Qwen2Attention):
323
- """
324
- Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
325
-     as the weights of the module stay untouched. The only required change would be on the forward pass
326
- where it needs to correctly call the public API of flash attention and deal with padding tokens
327
- in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
328
- config.max_window_layers layers.
329
- """
330
-
331
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
332
- def __init__(self, *args, **kwargs):
333
- super().__init__(*args, **kwargs)
334
-
335
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
336
-         # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
337
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
338
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
339
-
340
- def forward(
341
- self,
342
- hidden_states: torch.Tensor,
343
- attention_mask: Optional[torch.Tensor] = None,
344
- position_ids: Optional[torch.LongTensor] = None,
345
- past_key_value: Optional[Cache] = None,
346
- output_attentions: bool = False,
347
- use_cache: bool = False,
348
- cache_position: Optional[torch.LongTensor] = None,
349
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
350
- ):
351
- bsz, q_len, _ = hidden_states.size()
352
-
353
- query_states = self.q_proj(hidden_states)
354
- key_states = self.k_proj(hidden_states)
355
- value_states = self.v_proj(hidden_states)
356
-
357
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
358
- key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
359
- value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
360
-
361
- if position_embeddings is None:
362
- logger.warning_once(
363
- "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
364
- "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
365
- "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
366
- "removed and `position_embeddings` will be mandatory."
367
- )
368
- cos, sin = self.rotary_emb(value_states, position_ids)
369
- else:
370
- cos, sin = position_embeddings
371
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
372
-
373
- if past_key_value is not None:
374
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
375
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
376
-
377
- # repeat k/v heads if n_kv_heads < n_heads
378
- key_states = repeat_kv(key_states, self.num_key_value_groups)
379
- value_states = repeat_kv(value_states, self.num_key_value_groups)
380
- dropout_rate = 0.0 if not self.training else self.attention_dropout
381
-
382
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
383
- # therefore the input hidden states gets silently casted in float32. Hence, we need
384
- # cast them back in float16 just to be sure everything works as expected.
385
- input_dtype = query_states.dtype
386
- if input_dtype == torch.float32:
387
- if torch.is_autocast_enabled():
388
- target_dtype = torch.get_autocast_gpu_dtype()
389
- # Handle the case where the model is quantized
390
- elif hasattr(self.config, "_pre_quantization_dtype"):
391
- target_dtype = self.config._pre_quantization_dtype
392
- else:
393
- target_dtype = self.q_proj.weight.dtype
394
-
395
- logger.warning_once(
396
-                 f"The input hidden states seem to have been silently cast to float32; this might be related to"
397
-                 f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
398
- f" {target_dtype}."
399
- )
400
-
401
- query_states = query_states.to(target_dtype)
402
- key_states = key_states.to(target_dtype)
403
- value_states = value_states.to(target_dtype)
404
-
405
-         # Reshape to the expected shape for Flash Attention
406
- query_states = query_states.transpose(1, 2)
407
- key_states = key_states.transpose(1, 2)
408
- value_states = value_states.transpose(1, 2)
409
-
410
- if (
411
- self.config.use_sliding_window
412
- and getattr(self.config, "sliding_window", None) is not None
413
- and self.layer_idx >= self.config.max_window_layers
414
- ):
415
- sliding_window = self.config.sliding_window
416
- else:
417
- sliding_window = None
418
-
419
- attn_output = _flash_attention_forward(
420
- query_states,
421
- key_states,
422
- value_states,
423
- attention_mask,
424
- q_len,
425
- position_ids=position_ids,
426
- dropout=dropout_rate,
427
- sliding_window=sliding_window,
428
- is_causal=self.is_causal,
429
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
430
- )
431
-
432
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
433
- attn_output = self.o_proj(attn_output)
434
-
435
- if not output_attentions:
436
- attn_weights = None
437
-
438
- return attn_output, attn_weights, past_key_value
439
-
440
-
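One piece of logic specific to this flash path is the sliding-window selection: SWA is passed to _flash_attention_forward only when use_sliding_window is on, a sliding_window size exists, and the layer index is at or beyond max_window_layers. A small sketch of just that rule; the config values are assumptions, not the model's actual settings:

from types import SimpleNamespace

def pick_sliding_window(config, layer_idx):
    # Mirrors the condition used in the forward above.
    if (
        config.use_sliding_window
        and getattr(config, "sliding_window", None) is not None
        and layer_idx >= config.max_window_layers
    ):
        return config.sliding_window
    return None

cfg = SimpleNamespace(use_sliding_window=True, sliding_window=4096, max_window_layers=20)  # assumed values
print(pick_sliding_window(cfg, layer_idx=5))    # None: full attention for this layer
print(pick_sliding_window(cfg, layer_idx=24))   # 4096: sliding-window attention for this layer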
441
- QWEN2_ATTENTION_CLASSES = {
442
- "eager": Qwen2Attention,
443
- "flash_attention_2": Qwen2FlashAttention2,
444
- }
445
-
446
-
447
- class Qwen2DecoderLayer(nn.Module):
448
- def __init__(self, config: Qwen2Config, layer_idx: int):
449
- super().__init__()
450
- self.hidden_size = config.hidden_size
451
-
452
- if config.sliding_window and config._attn_implementation != "flash_attention_2":
453
- logger.warning_once(
454
- f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
455
- "unexpected results may be encountered."
456
- )
457
- self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
458
-
459
- self.mlp = Qwen2MLP(config)
460
- self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
461
- self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
462
-
463
- def forward(
464
- self,
465
- hidden_states: torch.Tensor,
466
- attention_mask: Optional[torch.Tensor] = None,
467
- position_ids: Optional[torch.LongTensor] = None,
468
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
469
- output_attentions: Optional[bool] = False,
470
- use_cache: Optional[bool] = False,
471
- cache_position: Optional[torch.LongTensor] = None,
472
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
473
- **kwargs,
474
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
475
- """
476
- Args:
477
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
478
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
479
- `(batch, sequence_length)` where padding elements are indicated by 0.
480
- output_attentions (`bool`, *optional*):
481
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
482
- returned tensors for more detail.
483
- use_cache (`bool`, *optional*):
484
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
485
- (see `past_key_values`).
486
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
487
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
488
- Indices depicting the position of the input sequence tokens in the sequence.
489
- position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
490
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
491
- with `head_dim` being the embedding dimension of each attention head.
492
- kwargs (`dict`, *optional*):
493
-                 Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
494
- into the model
495
- """
496
-
497
- residual = hidden_states
498
-
499
- hidden_states = self.input_layernorm(hidden_states)
500
-
501
- # Self Attention
502
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
503
- hidden_states=hidden_states,
504
- attention_mask=attention_mask,
505
- position_ids=position_ids,
506
- past_key_value=past_key_value,
507
- output_attentions=output_attentions,
508
- use_cache=use_cache,
509
- cache_position=cache_position,
510
- position_embeddings=position_embeddings,
511
- )
512
- hidden_states = residual + hidden_states
513
-
514
- # Fully Connected
515
- residual = hidden_states
516
- hidden_states = self.post_attention_layernorm(hidden_states)
517
- hidden_states = self.mlp(hidden_states)
518
- hidden_states = residual + hidden_states
519
-
520
- outputs = (hidden_states,)
521
-
522
- if output_attentions:
523
- outputs += (self_attn_weights,)
524
-
525
- if use_cache:
526
- outputs += (present_key_value,)
527
-
528
- return outputs
529
-
530
-
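Structurally, the decoder layer is a standard pre-norm residual block: normalize, attend, add the residual back, then normalize, run the MLP, and add the residual again. A stripped-down sketch of that wiring; nn.Identity stands in for the real attention and MLP submodules, and LayerNorm stands in for RMSNorm so the sketch stays dependency-free:

import torch
from torch import nn

class PreNormBlock(nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)            # the real layer uses Qwen2RMSNorm
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)   # the real layer uses Qwen2RMSNorm
        self.self_attn = nn.Identity()                              # placeholder for Qwen2Attention
        self.mlp = nn.Identity()                                    # placeholder for Qwen2MLP

    def forward(self, hidden_states):
        hidden_states = hidden_states + self.self_attn(self.input_layernorm(hidden_states))
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states

x = torch.randn(2, 5, 64)
assert PreNormBlock()(x).shape == x.shape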
531
- QWEN2_START_DOCSTRING = r"""
532
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
533
-     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
534
- etc.)
535
-
536
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
537
-     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
538
- and behavior.
539
-
540
- Parameters:
541
- config ([`Qwen2Config`]):
542
- Model configuration class with all the parameters of the model. Initializing with a config file does not
543
- load the weights associated with the model, only the configuration. Check out the
544
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
545
- """
546
-
547
-
548
- @add_start_docstrings(
549
- "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
550
- QWEN2_START_DOCSTRING,
551
- )
552
- class Qwen2PreTrainedModel(PreTrainedModel):
553
- config_class = Qwen2Config
554
- base_model_prefix = "model"
555
- supports_gradient_checkpointing = True
556
- _no_split_modules = ["Qwen2DecoderLayer"]
557
- _skip_keys_device_placement = "past_key_values"
558
- _supports_flash_attn_2 = True
559
- _supports_cache_class = True
560
- _supports_quantized_cache = True
561
- _supports_static_cache = True
562
-
563
- def _init_weights(self, module):
564
- std = self.config.initializer_range
565
- if isinstance(module, nn.Linear):
566
- module.weight.data.normal_(mean=0.0, std=std)
567
- if module.bias is not None:
568
- module.bias.data.zero_()
569
- elif isinstance(module, nn.Embedding):
570
- module.weight.data.normal_(mean=0.0, std=std)
571
- if module.padding_idx is not None:
572
- module.weight.data[module.padding_idx].zero_()
573
-
574
-
575
- QWEN2_INPUTS_DOCSTRING = r"""
576
- Args:
577
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
578
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
579
- it.
580
-
581
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
582
- [`PreTrainedTokenizer.__call__`] for details.
583
-
584
- [What are input IDs?](../glossary#input-ids)
585
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
586
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
587
-
588
- - 1 for tokens that are **not masked**,
589
- - 0 for tokens that are **masked**.
590
-
591
- [What are attention masks?](../glossary#attention-mask)
592
-
593
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
594
- [`PreTrainedTokenizer.__call__`] for details.
595
-
596
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
597
- `past_key_values`).
598
-
599
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
600
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
601
- information on the default strategy.
602
-
603
- - 1 indicates the head is **not masked**,
604
- - 0 indicates the head is **masked**.
605
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
606
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
607
- config.n_positions - 1]`.
608
-
609
- [What are position IDs?](../glossary#position-ids)
610
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
611
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
612
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
613
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
614
-
615
- Two formats are allowed:
616
- - a [`~cache_utils.Cache`] instance, see our
617
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
618
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
619
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
620
- cache format.
621
-
622
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
623
- legacy cache format will be returned.
624
-
625
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
626
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
627
- of shape `(batch_size, sequence_length)`.
628
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
629
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
630
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
631
- model's internal embedding lookup matrix.
632
- use_cache (`bool`, *optional*):
633
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
634
- `past_key_values`).
635
- output_attentions (`bool`, *optional*):
636
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
637
- tensors for more detail.
638
- output_hidden_states (`bool`, *optional*):
639
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
640
- more detail.
641
- return_dict (`bool`, *optional*):
642
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
643
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
644
-             Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
645
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
646
- the complete sequence length.
647
- """
648
-
649
-
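As a concrete illustration of the attention_mask convention described above, a padded batch from a tokenizer produces exactly this 0/1 tensor. A hedged example; the checkpoint name is the one used as _CHECKPOINT_FOR_DOC elsewhere in this file, and any Qwen2 tokenizer would do:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")   # assumed checkpoint
if tok.pad_token is None:                              # make sure padding is possible
    tok.pad_token = tok.eos_token

batch = tok(["hello", "a much longer example sentence"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)
print(batch["attention_mask"])   # 1 = real token, 0 = padding the model should not attend to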
650
- @add_start_docstrings(
651
- "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
652
- QWEN2_START_DOCSTRING,
653
- )
654
- class Qwen2Model(Qwen2PreTrainedModel):
655
- """
656
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
657
-
658
- Args:
659
- config: Qwen2Config
660
- """
661
-
662
- def __init__(self, config: Qwen2Config):
663
- super().__init__(config)
664
- self.padding_idx = config.pad_token_id
665
- self.vocab_size = config.vocab_size
666
-
667
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
668
- self.layers = nn.ModuleList(
669
- [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
670
- )
671
- self._attn_implementation = config._attn_implementation
672
- self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
673
- self.rotary_emb = Qwen2RotaryEmbedding(config=config)
674
-
675
- self.gradient_checkpointing = False
676
- # Initialize weights and apply final processing
677
- self.post_init()
678
-
679
- def get_input_embeddings(self):
680
- return self.embed_tokens
681
-
682
- def set_input_embeddings(self, value):
683
- self.embed_tokens = value
684
-
685
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
686
- def forward(
687
- self,
688
- input_ids: torch.LongTensor = None,
689
- attention_mask: Optional[torch.Tensor] = None,
690
- position_ids: Optional[torch.LongTensor] = None,
691
- past_key_values: Optional[List[torch.FloatTensor]] = None,
692
- inputs_embeds: Optional[torch.FloatTensor] = None,
693
- use_cache: Optional[bool] = None,
694
- output_attentions: Optional[bool] = None,
695
- output_hidden_states: Optional[bool] = None,
696
- return_dict: Optional[bool] = None,
697
- cache_position: Optional[torch.LongTensor] = None,
698
- ) -> Union[Tuple, BaseModelOutputWithPast]:
699
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
700
- output_hidden_states = (
701
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
702
- )
703
- use_cache = use_cache if use_cache is not None else self.config.use_cache
704
-
705
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
706
-
707
- if (input_ids is None) ^ (inputs_embeds is not None):
708
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
709
-
710
- if self.gradient_checkpointing and self.training:
711
- if use_cache:
712
- logger.warning_once(
713
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
714
- )
715
- use_cache = False
716
-
717
- # kept for BC (non `Cache` `past_key_values` inputs)
718
- return_legacy_cache = False
719
- if use_cache and not isinstance(past_key_values, Cache):
720
- return_legacy_cache = True
721
- if past_key_values is None:
722
- past_key_values = DynamicCache()
723
- else:
724
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
725
- logger.warning_once(
726
- "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
727
- "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
728
- "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
729
- )
730
-
731
- if inputs_embeds is None:
732
- inputs_embeds = self.embed_tokens(input_ids)
733
-
734
- if cache_position is None:
735
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
736
- cache_position = torch.arange(
737
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
738
- )
739
- if position_ids is None:
740
- position_ids = cache_position.unsqueeze(0)
741
-
742
- if attention_mask is not None and 0.0 in attention_mask:
743
- causal_mask = attention_mask
744
- else:
745
- causal_mask = None
746
-
747
- hidden_states = inputs_embeds
748
- # create position embeddings to be shared across the decoder layers
749
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
750
-
751
- # decoder layers
752
- all_hidden_states = () if output_hidden_states else None
753
- all_self_attns = () if output_attentions else None
754
- next_decoder_cache = None
755
-
756
- for decoder_layer in self.layers:
757
- if output_hidden_states:
758
- all_hidden_states += (hidden_states,)
759
-
760
- if self.gradient_checkpointing and self.training:
761
- layer_outputs = self._gradient_checkpointing_func(
762
- decoder_layer.__call__,
763
- hidden_states,
764
- causal_mask,
765
- position_ids,
766
- past_key_values,
767
- output_attentions,
768
- use_cache,
769
- cache_position,
770
- position_embeddings,
771
- )
772
- else:
773
- layer_outputs = decoder_layer(
774
- hidden_states,
775
- attention_mask=causal_mask,
776
- position_ids=position_ids,
777
- past_key_value=past_key_values,
778
- output_attentions=output_attentions,
779
- use_cache=use_cache,
780
- cache_position=cache_position,
781
- position_embeddings=position_embeddings,
782
- )
783
-
784
- hidden_states = layer_outputs[0]
785
-
786
- if use_cache:
787
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
788
-
789
- if output_attentions:
790
- all_self_attns += (layer_outputs[1],)
791
-
792
- hidden_states = self.norm(hidden_states)
793
-
794
- # add hidden states from the last decoder layer
795
- if output_hidden_states:
796
- all_hidden_states += (hidden_states,)
797
-
798
- next_cache = next_decoder_cache if use_cache else None
799
- if return_legacy_cache:
800
- next_cache = next_cache.to_legacy_cache()
801
-
802
- if not return_dict:
803
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
804
- return BaseModelOutputWithPast(
805
- last_hidden_state=hidden_states,
806
- past_key_values=next_cache,
807
- hidden_states=all_hidden_states,
808
- attentions=all_self_attns,
809
- )
810
-
811
-
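The cache handling in the forward above still accepts the legacy tuple-of-tuples format, converts it to a DynamicCache on the way in, and converts back on the way out when return_legacy_cache is set. A small round-trip sketch with toy shapes (the sizes are arbitrary):

import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each (bsz, num_kv_heads, seq_len, head_dim).
legacy = ((torch.randn(1, 4, 6, 128), torch.randn(1, 4, 6, 128)),)

cache = DynamicCache.from_legacy_cache(legacy)   # what the forward does on the way in
print(cache.get_seq_length())                    # 6

roundtrip = cache.to_legacy_cache()              # what return_legacy_cache triggers on the way out
print(roundtrip[0][0].shape)                     # torch.Size([1, 4, 6, 128])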
812
- class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
813
- _tied_weights_keys = ["lm_head.weight"]
814
-
815
- def __init__(self, config):
816
- super().__init__(config)
817
- self.model = Qwen2Model(config)
818
- self.vocab_size = config.vocab_size
819
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
820
-
821
- # Initialize weights and apply final processing
822
- self.post_init()
823
-
824
- def get_input_embeddings(self):
825
- return self.model.embed_tokens
826
-
827
- def set_input_embeddings(self, value):
828
- self.model.embed_tokens = value
829
-
830
- def get_output_embeddings(self):
831
- return self.lm_head
832
-
833
- def set_output_embeddings(self, new_embeddings):
834
- self.lm_head = new_embeddings
835
-
836
- def set_decoder(self, decoder):
837
- self.model = decoder
838
-
839
- def get_decoder(self):
840
- return self.model
841
-
842
- @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
843
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
844
- def forward(
845
- self,
846
- input_ids: torch.LongTensor = None,
847
- attention_mask: Optional[torch.Tensor] = None,
848
- position_ids: Optional[torch.LongTensor] = None,
849
- past_key_values: Optional[List[torch.FloatTensor]] = None,
850
- inputs_embeds: Optional[torch.FloatTensor] = None,
851
- labels: Optional[torch.LongTensor] = None,
852
- use_cache: Optional[bool] = None,
853
- output_attentions: Optional[bool] = None,
854
- output_hidden_states: Optional[bool] = None,
855
- return_dict: Optional[bool] = None,
856
- cache_position: Optional[torch.LongTensor] = None,
857
- num_logits_to_keep: int = 0,
858
- **loss_kwargs,
859
- ) -> Union[Tuple, CausalLMOutputWithPast]:
860
- r"""
861
- Args:
862
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
863
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
864
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
865
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
866
-
867
- num_logits_to_keep (`int`, *optional*):
868
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
869
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
870
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
871
-
872
- Returns:
873
-
874
- Example:
875
-
876
- ```python
877
- >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
878
-
879
- >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
880
- >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
881
-
882
- >>> prompt = "Hey, are you conscious? Can you talk to me?"
883
- >>> inputs = tokenizer(prompt, return_tensors="pt")
884
-
885
- >>> # Generate
886
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
887
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
888
- "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
889
- ```"""
890
-
891
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
892
- output_hidden_states = (
893
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
894
- )
895
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
896
-
897
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
898
- outputs = self.model(
899
- input_ids=input_ids,
900
- attention_mask=attention_mask,
901
- position_ids=position_ids,
902
- past_key_values=past_key_values,
903
- inputs_embeds=inputs_embeds,
904
- use_cache=use_cache,
905
- output_attentions=output_attentions,
906
- output_hidden_states=output_hidden_states,
907
- return_dict=return_dict,
908
- cache_position=cache_position,
909
- )
910
-
911
- hidden_states = outputs[0]
912
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
913
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
914
-
915
- loss = None
916
- if labels is not None:
917
- loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
918
-
919
- if not return_dict:
920
- output = (logits,) + outputs[1:]
921
- return (loss,) + output if loss is not None else output
922
-
923
- return CausalLMOutputWithPast(
924
- loss=loss,
925
- logits=logits,
926
- past_key_values=outputs.past_key_values,
927
- hidden_states=outputs.hidden_states,
928
- attentions=outputs.attentions,
929
- )
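The num_logits_to_keep slice above is what keeps generation cheap: only the last positions are projected through lm_head, and 0 means "keep everything" because hidden_states[:, -0:, :] is the whole tensor. A toy illustration of the saving; the sizes are deliberately small, assumed stand-ins for the real hidden and vocabulary sizes:

import torch
from torch import nn

hidden_size, vocab_size, seq_len = 1024, 32000, 512      # assumed toy sizes
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = torch.randn(1, seq_len, hidden_size)

full_logits = lm_head(hidden_states)                      # (1, 512, 32000): every position
last_logits = lm_head(hidden_states[:, -1:, :])           # (1, 1, 32000): all generation needs

assert torch.allclose(full_logits[:, -1:, :], last_logits)
print(full_logits.numel() // last_logits.numel())         # 512x fewer logits to materialize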
 
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """PyTorch Qwen2 model."""
5
+
6
+ import math
7
+ from typing import List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.utils.checkpoint
11
+ from torch import nn
12
+
13
+ from transformers.activations import ACT2FN
14
+ from transformers.cache_utils import Cache, DynamicCache
15
+ from transformers.generation import GenerationMixin
16
+ from transformers.modeling_outputs import (
17
+ BaseModelOutputWithPast,
18
+ CausalLMOutputWithPast,
19
+ )
20
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
21
+ from transformers.modeling_utils import PreTrainedModel
22
+ from transformers.utils import (
23
+ add_start_docstrings,
24
+ add_start_docstrings_to_model_forward,
25
+ is_flash_attn_2_available,
26
+ is_flash_attn_greater_or_equal_2_10,
27
+ logging,
28
+ replace_return_docstrings,
29
+ )
30
+ from .configuration_qwen2 import Qwen2Config
31
+
32
+
33
+ if is_flash_attn_2_available():
34
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ _CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
41
+ _CONFIG_FOR_DOC = "Qwen2Config"
42
+
43
+
44
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Qwen2
45
+ class Qwen2RMSNorm(nn.Module):
46
+ def __init__(self, hidden_size, eps=1e-6):
47
+ """
48
+ Qwen2RMSNorm is equivalent to T5LayerNorm
49
+ """
50
+ super().__init__()
51
+ self.weight = nn.Parameter(torch.ones(hidden_size))
52
+ self.variance_epsilon = eps
53
+
54
+ def forward(self, hidden_states):
55
+ input_dtype = hidden_states.dtype
56
+ hidden_states = hidden_states.to(torch.float32)
57
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
58
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
59
+ return self.weight * hidden_states.to(input_dtype)
60
+
61
+ def extra_repr(self):
62
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
63
+
64
+
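Qwen2RMSNorm normalizes by the root mean square of the features (no mean subtraction, no bias), with the statistic computed in float32, then applies a learned per-channel scale. A quick numerical check of that definition, assuming the class above is in scope; the hidden size is arbitrary:

import torch

norm = Qwen2RMSNorm(8, eps=1e-6)
x = torch.randn(2, 3, 8)

# Hand-rolled RMSNorm: x / sqrt(mean(x^2) + eps), then the learned elementwise scale.
rms = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)
manual = norm.weight * (x.float() * rms).to(x.dtype)

assert torch.allclose(norm(x), manual, atol=1e-6)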
65
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
66
+ class Qwen2RotaryEmbedding(nn.Module):
67
+ def __init__(
68
+ self,
69
+ dim=None,
70
+ max_position_embeddings=2048,
71
+ base=10000,
72
+ device=None,
73
+ scaling_factor=1.0,
74
+ rope_type="default",
75
+ config: Optional[Qwen2Config] = None,
76
+ ):
77
+ super().__init__()
78
+ # TODO (joao): remove the `if` below, only used for BC
79
+ self.rope_kwargs = {}
80
+ if config is None:
81
+ logger.warning_once(
82
+ "`Qwen2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
83
+ "`config` argument. All other arguments will be removed in v4.46"
84
+ )
85
+ self.rope_kwargs = {
86
+ "rope_type": rope_type,
87
+ "factor": scaling_factor,
88
+ "dim": dim,
89
+ "base": base,
90
+ "max_position_embeddings": max_position_embeddings,
91
+ }
92
+ self.rope_type = rope_type
93
+ self.max_seq_len_cached = max_position_embeddings
94
+ self.original_max_seq_len = max_position_embeddings
95
+ else:
96
+ # BC: "rope_type" was originally "type"
97
+ if config.rope_scaling is not None:
98
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
99
+ else:
100
+ self.rope_type = "default"
101
+ self.max_seq_len_cached = config.max_position_embeddings
102
+ self.original_max_seq_len = config.max_position_embeddings
103
+
104
+ self.config = config
105
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
106
+
107
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
108
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
109
+ self.original_inv_freq = self.inv_freq
110
+
111
+ def _dynamic_frequency_update(self, position_ids, device):
112
+ """
113
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
114
+ 1 - growing beyond the cached sequence length (allow scaling)
115
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
116
+ """
117
+ seq_len = torch.max(position_ids) + 1
118
+ if seq_len > self.max_seq_len_cached: # growth
119
+ inv_freq, self.attention_scaling = self.rope_init_fn(
120
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
121
+ )
122
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
123
+ self.max_seq_len_cached = seq_len
124
+
125
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
126
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
127
+ self.max_seq_len_cached = self.original_max_seq_len
128
+
129
+ @torch.no_grad()
130
+ def forward(self, x, position_ids):
131
+ if "dynamic" in self.rope_type:
132
+ self._dynamic_frequency_update(position_ids, device=x.device)
133
+
134
+ # Core RoPE block
135
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
136
+ position_ids_expanded = position_ids[:, None, :].float()
137
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
138
+ device_type = x.device.type
139
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
140
+ with torch.autocast(device_type=device_type, enabled=False):
141
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
142
+ emb = torch.cat((freqs, freqs), dim=-1)
143
+ cos = emb.cos()
144
+ sin = emb.sin()
145
+
146
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
147
+ cos = cos * self.attention_scaling
148
+ sin = sin * self.attention_scaling
149
+
150
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
151
+
152
+
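For the default rope type, the forward above reduces to an outer product of the positions with the inverse frequencies, duplicated across the two halves of the head dimension and passed through cos and sin. A hedged reconstruction of that core; head_dim and base are assumed values, and the module additionally handles dynamic scaling, autocast, and dtype details:

import torch

head_dim, base = 128, 10000.0             # assumed; the real values come from the model config
position_ids = torch.arange(6)[None, :]   # (batch=1, seq_len=6)

inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim // 2,)
freqs = position_ids[..., None].float() * inv_freq                            # (1, 6, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                                       # (1, 6, head_dim)
cos, sin = emb.cos(), emb.sin()
print(cos.shape)   # torch.Size([1, 6, 128]), the shape forward() returns for this batch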
153
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
154
+ def rotate_half(x):
155
+ """Rotates half the hidden dims of the input."""
156
+ x1 = x[..., : x.shape[-1] // 2]
157
+ x2 = x[..., x.shape[-1] // 2 :]
158
+ return torch.cat((-x2, x1), dim=-1)
159
+
160
+
161
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
162
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
163
+ """Applies Rotary Position Embedding to the query and key tensors.
164
+
165
+ Args:
166
+ q (`torch.Tensor`): The query tensor.
167
+ k (`torch.Tensor`): The key tensor.
168
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
169
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
170
+ position_ids (`torch.Tensor`, *optional*):
171
+ Deprecated and unused.
172
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
173
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
174
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
175
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
176
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
177
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
178
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
179
+ Returns:
180
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
181
+ """
182
+ cos = cos.unsqueeze(unsqueeze_dim)
183
+ sin = sin.unsqueeze(unsqueeze_dim)
184
+ q_embed = (q * cos) + (rotate_half(q) * sin)
185
+ k_embed = (k * cos) + (rotate_half(k) * sin)
186
+ return q_embed, k_embed
187
+
188
+
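apply_rotary_pos_emb rotates each (x1, x2) pair of half-dimensions by a position-dependent angle, so it changes directions but not vector lengths. A short exercise with the (bsz, heads, seq, head_dim) layout the attention modules use and the default unsqueeze_dim=1 broadcast; the shapes are arbitrary and the functions above are assumed to be in scope:

import torch

bsz, heads, seq_len, head_dim = 1, 2, 6, 128
q = torch.randn(bsz, heads, seq_len, head_dim)
k = torch.randn(bsz, heads, seq_len, head_dim)

# cos/sin as the rotary embedding produces them: (bsz, seq_len, head_dim), halves duplicated.
angles = torch.randn(bsz, seq_len, head_dim // 2)
cos = torch.cat((angles, angles), dim=-1).cos()
sin = torch.cat((angles, angles), dim=-1).sin()

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)   # unsqueeze_dim=1 broadcasts over heads

# A rotation preserves the per-position, per-head vector norm.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)
assert torch.allclose(k.norm(dim=-1), k_rot.norm(dim=-1), atol=1e-4)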
189
+ # Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
190
+ class Qwen2MLP(nn.Module):
191
+ def __init__(self, config):
192
+ super().__init__()
193
+ self.hidden_size = config.hidden_size
194
+ self.intermediate_size = config.intermediate_size
195
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
196
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
197
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
198
+ self.act_fn = ACT2FN[config.hidden_act]
199
+
200
+ def forward(self, hidden_state):
201
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
202
+
203
+
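The MLP is the gated (SwiGLU-style) variant: the gate projection goes through the activation and is multiplied elementwise with the up projection before being projected back down. A minimal usage sketch with an ad-hoc config object; the sizes and the silu activation are assumptions, and the Qwen2MLP class above is assumed to be in scope:

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(hidden_size=64, intermediate_size=256, hidden_act="silu")  # assumed toy config
mlp = Qwen2MLP(cfg)

x = torch.randn(2, 5, 64)
out = mlp(x)

# The same computation written out explicitly.
manual = mlp.down_proj(torch.nn.functional.silu(mlp.gate_proj(x)) * mlp.up_proj(x))
assert out.shape == (2, 5, 64) and torch.allclose(out, manual)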
204
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
205
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
206
+ """
207
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
208
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
209
+ """
210
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
211
+ if n_rep == 1:
212
+ return hidden_states
213
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
214
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
215
+
216
+
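As the docstring says, repeat_kv is torch.repeat_interleave on the head axis, written with expand and reshape so grouped-query attention can reuse each KV head for several query heads. A quick equivalence check; the shapes are arbitrary and the function above is assumed to be in scope:

import torch

kv = torch.randn(2, 4, 6, 128)   # (batch, num_kv_heads, seq_len, head_dim)
n_rep = 7                        # e.g. 28 query heads sharing 4 KV heads

expanded = repeat_kv(kv, n_rep)
print(expanded.shape)            # torch.Size([2, 28, 6, 128])
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=n_rep, dim=1))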
217
+ class Qwen2Attention(nn.Module):
218
+ """
219
+     Multi-headed attention from the 'Attention Is All You Need' paper, modified to use sliding window attention as in
220
+     Longformer and "Generating Long Sequences with Sparse Transformers".
221
+ """
222
+
223
+ def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
224
+ super().__init__()
225
+ self.config = config
226
+ self.layer_idx = layer_idx
227
+ if layer_idx is None:
228
+ logger.warning_once(
229
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
230
+                 "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
231
+ "when creating this class."
232
+ )
233
+
234
+ self.hidden_size = config.hidden_size
235
+ self.num_heads = config.num_attention_heads
236
+ self.head_dim = self.hidden_size // self.num_heads
237
+ self.num_key_value_heads = config.num_key_value_heads
238
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
239
+ self.max_position_embeddings = config.max_position_embeddings
240
+ self.rope_theta = config.rope_theta
241
+ self.is_causal = config.is_causal
242
+ self.attention_dropout = config.attention_dropout
243
+
244
+ if (self.head_dim * self.num_heads) != self.hidden_size:
245
+ raise ValueError(
246
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
247
+ f" and `num_heads`: {self.num_heads})."
248
+ )
249
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
250
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
251
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
252
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
253
+
254
+ def forward(
255
+ self,
256
+ hidden_states: torch.Tensor,
257
+ attention_mask: Optional[torch.Tensor] = None,
258
+ position_ids: Optional[torch.LongTensor] = None,
259
+ past_key_value: Optional[Cache] = None,
260
+ output_attentions: bool = False,
261
+ use_cache: bool = False,
262
+ cache_position: Optional[torch.LongTensor] = None,
263
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
264
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
265
+ bsz, q_len, _ = hidden_states.size()
266
+
267
+ query_states = self.q_proj(hidden_states)
268
+ key_states = self.k_proj(hidden_states)
269
+ value_states = self.v_proj(hidden_states)
270
+
271
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
272
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
273
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
274
+
275
+ if position_embeddings is None:
276
+ logger.warning_once(
277
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
278
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
279
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
280
+ "removed and `position_embeddings` will be mandatory."
281
+ )
282
+ cos, sin = self.rotary_emb(value_states, position_ids)
283
+ else:
284
+ cos, sin = position_embeddings
285
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
286
+
287
+ if past_key_value is not None:
288
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
289
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
290
+
291
+ # repeat k/v heads if n_kv_heads < n_heads
292
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
293
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
294
+
295
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
296
+ if attention_mask is not None: # no matter the length, we just slice it
297
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
298
+ attn_weights = attn_weights + causal_mask
299
+
300
+ # upcast attention to fp32
301
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
302
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
303
+ attn_output = torch.matmul(attn_weights, value_states)
304
+
305
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
306
+ raise ValueError(
307
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
308
+ f" {attn_output.size()}"
309
+ )
310
+
311
+ attn_output = attn_output.transpose(1, 2).contiguous()
312
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
313
+
314
+ attn_output = self.o_proj(attn_output)
315
+
316
+ if not output_attentions:
317
+ attn_weights = None
318
+
319
+ return attn_output, attn_weights, past_key_value
320
+
321
+
322
+ class Qwen2FlashAttention2(Qwen2Attention):
323
+ """
324
+ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
325
+     as the weights of the module stay untouched. The only required change would be on the forward pass
326
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
327
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
328
+ config.max_window_layers layers.
329
+ """
330
+
331
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
332
+ def __init__(self, *args, **kwargs):
333
+ super().__init__(*args, **kwargs)
334
+
335
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
336
+         # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
337
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
338
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
339
+
340
+ def forward(
341
+ self,
342
+ hidden_states: torch.Tensor,
343
+ attention_mask: Optional[torch.Tensor] = None,
344
+ position_ids: Optional[torch.LongTensor] = None,
345
+ past_key_value: Optional[Cache] = None,
346
+ output_attentions: bool = False,
347
+ use_cache: bool = False,
348
+ cache_position: Optional[torch.LongTensor] = None,
349
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
350
+ ):
351
+ bsz, q_len, _ = hidden_states.size()
352
+
353
+ query_states = self.q_proj(hidden_states)
354
+ key_states = self.k_proj(hidden_states)
355
+ value_states = self.v_proj(hidden_states)
356
+
357
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
358
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
359
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
360
+
361
+ if position_embeddings is None:
362
+ logger.warning_once(
363
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
364
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
365
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
366
+ "removed and `position_embeddings` will be mandatory."
367
+ )
368
+ cos, sin = self.rotary_emb(value_states, position_ids)
369
+ else:
370
+ cos, sin = position_embeddings
371
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
372
+
373
+ if past_key_value is not None:
374
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
375
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
376
+
377
+ # repeat k/v heads if n_kv_heads < n_heads
378
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
379
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
380
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
381
+
382
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
383
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
384
+ # cast them back in float16 just to be sure everything works as expected.
385
+ input_dtype = query_states.dtype
386
+ if input_dtype == torch.float32:
387
+ if torch.is_autocast_enabled():
388
+ target_dtype = torch.get_autocast_gpu_dtype()
389
+ # Handle the case where the model is quantized
390
+ elif hasattr(self.config, "_pre_quantization_dtype"):
391
+ target_dtype = self.config._pre_quantization_dtype
392
+ else:
393
+ target_dtype = self.q_proj.weight.dtype
394
+
395
+ logger.warning_once(
396
+                 f"The input hidden states seem to have been silently cast to float32; this might be related to"
397
+                 f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
398
+ f" {target_dtype}."
399
+ )
400
+
401
+ query_states = query_states.to(target_dtype)
402
+ key_states = key_states.to(target_dtype)
403
+ value_states = value_states.to(target_dtype)
404
+
405
+         # Reshape to the expected shape for Flash Attention
406
+ query_states = query_states.transpose(1, 2)
407
+ key_states = key_states.transpose(1, 2)
408
+ value_states = value_states.transpose(1, 2)
409
+
410
+ if (
411
+ self.config.use_sliding_window
412
+ and getattr(self.config, "sliding_window", None) is not None
413
+ and self.layer_idx >= self.config.max_window_layers
414
+ ):
415
+ sliding_window = self.config.sliding_window
416
+ else:
417
+ sliding_window = None
418
+
419
+ attn_output = _flash_attention_forward(
420
+ query_states,
421
+ key_states,
422
+ value_states,
423
+ attention_mask,
424
+ q_len,
425
+ position_ids=position_ids,
426
+ dropout=dropout_rate,
427
+ sliding_window=sliding_window,
428
+ is_causal=self.is_causal,
429
+ use_top_left_mask=self._flash_attn_uses_top_left_mask,
430
+ )
431
+
432
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
433
+ attn_output = self.o_proj(attn_output)
434
+
435
+ if not output_attentions:
436
+ attn_weights = None
437
+
438
+ return attn_output, attn_weights, past_key_value
439
+
440
+
441
+ QWEN2_ATTENTION_CLASSES = {
442
+ "eager": Qwen2Attention,
443
+ "flash_attention_2": Qwen2FlashAttention2,
444
+ }
445
+
446
+
447
+ class Qwen2DecoderLayer(nn.Module):
448
+ def __init__(self, config: Qwen2Config, layer_idx: int):
449
+ super().__init__()
450
+ self.hidden_size = config.hidden_size
451
+
452
+ if config.sliding_window and config._attn_implementation != "flash_attention_2":
453
+ logger.warning_once(
454
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
455
+ "unexpected results may be encountered."
456
+ )
457
+ self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
458
+
459
+ self.mlp = Qwen2MLP(config)
460
+ self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
461
+ self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
462
+
463
+ def forward(
464
+ self,
465
+ hidden_states: torch.Tensor,
466
+ attention_mask: Optional[torch.Tensor] = None,
467
+ position_ids: Optional[torch.LongTensor] = None,
468
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
469
+ output_attentions: Optional[bool] = False,
470
+ use_cache: Optional[bool] = False,
471
+ cache_position: Optional[torch.LongTensor] = None,
472
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
473
+ **kwargs,
474
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
475
+ """
476
+ Args:
477
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
478
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
479
+ `(batch, sequence_length)` where padding elements are indicated by 0.
480
+ output_attentions (`bool`, *optional*):
481
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
482
+ returned tensors for more detail.
483
+ use_cache (`bool`, *optional*):
484
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
485
+ (see `past_key_values`).
486
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
487
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
488
+ Indices depicting the position of the input sequence tokens in the sequence.
489
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
490
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
491
+ with `head_dim` being the embedding dimension of each attention head.
492
+ kwargs (`dict`, *optional*):
493
+                 Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
494
+ into the model
495
+ """
496
+
497
+ residual = hidden_states
498
+
499
+ hidden_states = self.input_layernorm(hidden_states)
500
+
501
+ # Self Attention
502
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
503
+ hidden_states=hidden_states,
504
+ attention_mask=attention_mask,
505
+ position_ids=position_ids,
506
+ past_key_value=past_key_value,
507
+ output_attentions=output_attentions,
508
+ use_cache=use_cache,
509
+ cache_position=cache_position,
510
+ position_embeddings=position_embeddings,
511
+ )
512
+ hidden_states = residual + hidden_states
513
+
514
+ # Fully Connected
515
+ residual = hidden_states
516
+ hidden_states = self.post_attention_layernorm(hidden_states)
517
+ hidden_states = self.mlp(hidden_states)
518
+ hidden_states = residual + hidden_states
519
+
520
+ outputs = (hidden_states,)
521
+
522
+ if output_attentions:
523
+ outputs += (self_attn_weights,)
524
+
525
+ if use_cache:
526
+ outputs += (present_key_value,)
527
+
528
+ return outputs
529
+
530
+
531
+ QWEN2_START_DOCSTRING = r"""
532
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
533
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
534
+ etc.)
535
+
536
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
537
+     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
538
+ and behavior.
539
+
540
+ Parameters:
541
+ config ([`Qwen2Config`]):
542
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
543
+ load the weights associated with the model, only the configuration. Check out the
544
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
545
+ """
546
+
547
+
548
+ @add_start_docstrings(
549
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
550
+ QWEN2_START_DOCSTRING,
551
+ )
552
+ class Qwen2PreTrainedModel(PreTrainedModel):
553
+ config_class = Qwen2Config
554
+ base_model_prefix = "model"
555
+ supports_gradient_checkpointing = True
556
+ _no_split_modules = ["Qwen2DecoderLayer"]
557
+ _skip_keys_device_placement = "past_key_values"
558
+ _supports_flash_attn_2 = True
559
+ _supports_cache_class = True
560
+ _supports_quantized_cache = True
561
+ _supports_static_cache = True
562
+
563
+ def _init_weights(self, module):
564
+ std = self.config.initializer_range
565
+ if isinstance(module, nn.Linear):
566
+ module.weight.data.normal_(mean=0.0, std=std)
567
+ if module.bias is not None:
568
+ module.bias.data.zero_()
569
+ elif isinstance(module, nn.Embedding):
570
+ module.weight.data.normal_(mean=0.0, std=std)
571
+ if module.padding_idx is not None:
572
+ module.weight.data[module.padding_idx].zero_()
573
+
574
+
575
+ QWEN2_INPUTS_DOCSTRING = r"""
576
+ Args:
577
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
578
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
579
+ it.
580
+
581
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
582
+ [`PreTrainedTokenizer.__call__`] for details.
583
+
584
+ [What are input IDs?](../glossary#input-ids)
585
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
586
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
587
+
588
+ - 1 for tokens that are **not masked**,
589
+ - 0 for tokens that are **masked**.
590
+
591
+ [What are attention masks?](../glossary#attention-mask)
592
+
593
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
594
+ [`PreTrainedTokenizer.__call__`] for details.
595
+
596
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
597
+ `past_key_values`).
598
+
599
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
600
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
601
+ information on the default strategy.
602
+
603
+ - 1 indicates the head is **not masked**,
604
+ - 0 indicates the head is **masked**.
605
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
606
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
607
+ config.n_positions - 1]`.
608
+
609
+ [What are position IDs?](../glossary#position-ids)
610
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
611
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
612
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
613
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
614
+
615
+ Two formats are allowed:
616
+ - a [`~cache_utils.Cache`] instance, see our
617
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
618
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
619
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
620
+ cache format.
621
+
622
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
623
+ legacy cache format will be returned.
624
+
625
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
626
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
627
+ of shape `(batch_size, sequence_length)`.
628
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
629
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
630
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
631
+ model's internal embedding lookup matrix.
632
+ use_cache (`bool`, *optional*):
633
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
634
+ `past_key_values`).
635
+ output_attentions (`bool`, *optional*):
636
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
637
+ tensors for more detail.
638
+ output_hidden_states (`bool`, *optional*):
639
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
640
+ more detail.
641
+ return_dict (`bool`, *optional*):
642
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
643
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
644
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
645
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
646
+ the complete sequence length.
647
+ """
648
+
649
+
650
+ @add_start_docstrings(
651
+ "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
652
+ QWEN2_START_DOCSTRING,
653
+ )
654
+ class Qwen2Model(Qwen2PreTrainedModel):
655
+ """
656
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
657
+
658
+ Args:
659
+ config: Qwen2Config
660
+ """
661
+
662
+ def __init__(self, config: Qwen2Config):
663
+ super().__init__(config)
664
+ self.padding_idx = config.pad_token_id
665
+ self.vocab_size = config.vocab_size
666
+
667
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
668
+ self.layers = nn.ModuleList(
669
+ [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
670
+ )
671
+ self._attn_implementation = config._attn_implementation
672
+ self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
673
+ self.rotary_emb = Qwen2RotaryEmbedding(config=config)
674
+
675
+ self.gradient_checkpointing = False
676
+ # Initialize weights and apply final processing
677
+ self.post_init()
678
+
679
+ def get_input_embeddings(self):
680
+ return self.embed_tokens
681
+
682
+ def set_input_embeddings(self, value):
683
+ self.embed_tokens = value
684
+
685
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
686
+ def forward(
687
+ self,
688
+ input_ids: torch.LongTensor = None,
689
+ attention_mask: Optional[torch.Tensor] = None,
690
+ position_ids: Optional[torch.LongTensor] = None,
691
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
692
+ inputs_embeds: Optional[torch.FloatTensor] = None,
693
+ use_cache: Optional[bool] = None,
694
+ output_attentions: Optional[bool] = None,
695
+ output_hidden_states: Optional[bool] = None,
696
+ return_dict: Optional[bool] = None,
697
+ cache_position: Optional[torch.LongTensor] = None,
698
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
699
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
700
+ output_hidden_states = (
701
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
702
+ )
703
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
704
+
705
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
706
+
707
+ if (input_ids is None) ^ (inputs_embeds is not None):
708
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
709
+
710
+ if self.gradient_checkpointing and self.training:
711
+ if use_cache:
712
+ logger.warning_once(
713
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
714
+ )
715
+ use_cache = False
716
+
717
+ # kept for BC (non `Cache` `past_key_values` inputs)
718
+ return_legacy_cache = False
719
+ if use_cache and not isinstance(past_key_values, Cache):
720
+ return_legacy_cache = True
721
+ if past_key_values is None:
722
+ past_key_values = DynamicCache()
723
+ else:
724
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
725
+ logger.warning_once(
726
+ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
727
+ "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
728
+ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
729
+ )
730
+
731
+ if inputs_embeds is None:
732
+ inputs_embeds = self.embed_tokens(input_ids)
733
+
734
+ if cache_position is None:
735
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
736
+ cache_position = torch.arange(
737
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
738
+ )
739
+ if position_ids is None:
740
+ position_ids = cache_position.unsqueeze(0)
741
+
742
+ if attention_mask is not None and 0.0 in attention_mask:
743
+ causal_mask = attention_mask
744
+ else:
745
+ causal_mask = None
746
+
747
+ hidden_states = inputs_embeds
748
+ # create position embeddings to be shared across the decoder layers
749
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
750
+
751
+ # decoder layers
752
+ all_hidden_states = () if output_hidden_states else None
753
+ all_self_attns = () if output_attentions else None
754
+ next_decoder_cache = None
755
+
756
+ for decoder_layer in self.layers:
757
+ if output_hidden_states:
758
+ all_hidden_states += (hidden_states,)
759
+
760
+ if self.gradient_checkpointing and self.training:
761
+ layer_outputs = self._gradient_checkpointing_func(
762
+ decoder_layer.__call__,
763
+ hidden_states,
764
+ causal_mask,
765
+ position_ids,
766
+ past_key_values,
767
+ output_attentions,
768
+ use_cache,
769
+ cache_position,
770
+ position_embeddings,
771
+ )
772
+ else:
773
+ layer_outputs = decoder_layer(
774
+ hidden_states,
775
+ attention_mask=causal_mask,
776
+ position_ids=position_ids,
777
+ past_key_value=past_key_values,
778
+ output_attentions=output_attentions,
779
+ use_cache=use_cache,
780
+ cache_position=cache_position,
781
+ position_embeddings=position_embeddings,
782
+ )
783
+
784
+ hidden_states = layer_outputs[0]
785
+
786
+ if use_cache:
787
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
788
+
789
+ if output_attentions:
790
+ all_self_attns += (layer_outputs[1],)
791
+
792
+ hidden_states = self.norm(hidden_states)
793
+
794
+ # add hidden states from the last decoder layer
795
+ if output_hidden_states:
796
+ all_hidden_states += (hidden_states,)
797
+
798
+ next_cache = next_decoder_cache if use_cache else None
799
+ if return_legacy_cache:
800
+ next_cache = next_cache.to_legacy_cache()
801
+
802
+ if not return_dict:
803
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
804
+ return BaseModelOutputWithPast(
805
+ last_hidden_state=hidden_states,
806
+ past_key_values=next_cache,
807
+ hidden_states=all_hidden_states,
808
+ attentions=all_self_attns,
809
+ )
810
+
811
+
812
+ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
813
+ _tied_weights_keys = ["lm_head.weight"]
814
+
815
+ def __init__(self, config):
816
+ super().__init__(config)
817
+ self.model = Qwen2Model(config)
818
+ self.vocab_size = config.vocab_size
819
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
820
+
821
+ # Initialize weights and apply final processing
822
+ self.post_init()
823
+
824
+ def get_input_embeddings(self):
825
+ return self.model.embed_tokens
826
+
827
+ def set_input_embeddings(self, value):
828
+ self.model.embed_tokens = value
829
+
830
+ def get_output_embeddings(self):
831
+ return self.lm_head
832
+
833
+ def set_output_embeddings(self, new_embeddings):
834
+ self.lm_head = new_embeddings
835
+
836
+ def set_decoder(self, decoder):
837
+ self.model = decoder
838
+
839
+ def get_decoder(self):
840
+ return self.model
841
+
842
+ @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
843
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
844
+ def forward(
845
+ self,
846
+ input_ids: torch.LongTensor = None,
847
+ attention_mask: Optional[torch.Tensor] = None,
848
+ position_ids: Optional[torch.LongTensor] = None,
849
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
850
+ inputs_embeds: Optional[torch.FloatTensor] = None,
851
+ labels: Optional[torch.LongTensor] = None,
852
+ use_cache: Optional[bool] = None,
853
+ output_attentions: Optional[bool] = None,
854
+ output_hidden_states: Optional[bool] = None,
855
+ return_dict: Optional[bool] = None,
856
+ cache_position: Optional[torch.LongTensor] = None,
857
+ num_logits_to_keep: int = 0,
858
+ **loss_kwargs,
859
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
860
+ r"""
861
+ Args:
862
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
863
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
864
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
865
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
866
+
867
+ num_logits_to_keep (`int`, *optional*):
868
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
869
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
870
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
871
+
872
+ Returns:
873
+
874
+ Example:
875
+
876
+ ```python
877
+ >>> from transformers import AutoTokenizer, Qwen2ForCausalLM
878
+
879
+ >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
880
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
881
+
882
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
883
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
884
+
885
+ >>> # Generate
886
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
887
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
888
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
889
+ ```"""
890
+
891
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
892
+ output_hidden_states = (
893
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
894
+ )
895
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
896
+
897
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
898
+ outputs = self.model(
899
+ input_ids=input_ids,
900
+ attention_mask=attention_mask,
901
+ position_ids=position_ids,
902
+ past_key_values=past_key_values,
903
+ inputs_embeds=inputs_embeds,
904
+ use_cache=use_cache,
905
+ output_attentions=output_attentions,
906
+ output_hidden_states=output_hidden_states,
907
+ return_dict=return_dict,
908
+ cache_position=cache_position,
909
+ )
910
+
911
+ hidden_states = outputs[0]
912
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
913
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
914
+
915
+ loss = None
916
+ if labels is not None:
917
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
918
+
919
+ if not return_dict:
920
+ output = (logits,) + outputs[1:]
921
+ return (loss,) + output if loss is not None else output
922
+
923
+ return CausalLMOutputWithPast(
924
+ loss=loss,
925
+ logits=logits,
926
+ past_key_values=outputs.past_key_values,
927
+ hidden_states=outputs.hidden_states,
928
+ attentions=outputs.attentions,
929
+ )
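The cache handling in `Qwen2Model.forward` and `Qwen2ForCausalLM.forward` above accepts either a `Cache` object or the legacy tuple-of-tuples format. Below is a minimal sketch of driving the model with an explicit `DynamicCache` (an editorial illustration, not part of this commit; the tiny config values are arbitrary assumptions chosen so the snippet runs on CPU without downloading weights):

```python
# Editorial sketch: exercise the `past_key_values` path documented above with a tiny random model.
import torch
from transformers import Qwen2Config, Qwen2ForCausalLM
from transformers.cache_utils import DynamicCache

config = Qwen2Config(
    vocab_size=128, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
    max_position_embeddings=64,
)
model = Qwen2ForCausalLM(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 8))
with torch.no_grad():
    cache = DynamicCache()  # new-style cache, avoids the legacy-format deprecation warning
    out = model(input_ids, past_key_values=cache, use_cache=True)
    next_token = out.logits[:, -1:].argmax(dim=-1)
    # With the cache populated, only the newly sampled token needs to be fed back in.
    out = model(next_token, past_key_values=out.past_key_values, use_cache=True)

print(out.logits.shape)  # torch.Size([1, 1, 128]): logits only for the new position
```

Passing a legacy tuple instead still works, but it triggers the warning emitted in `Qwen2Model.forward` and is converted to a `DynamicCache` internally.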
modeling/qwen2/tokenization_qwen2.py CHANGED
@@ -1,328 +1,328 @@
1
- # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """Tokenization classes for Qwen2."""
5
-
6
- import json
7
- import os
8
- import unicodedata
9
- from functools import lru_cache
10
- from typing import Optional, Tuple
11
-
12
- import regex as re
13
-
14
- from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
15
- from transformers.utils import logging
16
-
17
-
18
- logger = logging.get_logger(__name__)
19
-
20
- VOCAB_FILES_NAMES = {
21
- "vocab_file": "vocab.json",
22
- "merges_file": "merges.txt",
23
- }
24
-
25
-
26
- MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
27
-
28
- PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
29
-
30
-
31
- @lru_cache()
32
- # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
33
- def bytes_to_unicode():
34
- """
35
- Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
36
- characters the bpe code barfs on.
37
-
38
- The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
39
- if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
40
- decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
41
- tables between utf-8 bytes and unicode strings.
42
- """
43
- bs = (
44
- list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
45
- )
46
- cs = bs[:]
47
- n = 0
48
- for b in range(2**8):
49
- if b not in bs:
50
- bs.append(b)
51
- cs.append(2**8 + n)
52
- n += 1
53
- cs = [chr(n) for n in cs]
54
- return dict(zip(bs, cs))
55
-
56
-
57
- # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
58
- def get_pairs(word):
59
- """
60
- Return set of symbol pairs in a word.
61
-
62
- Word is represented as tuple of symbols (symbols being variable-length strings).
63
- """
64
- pairs = set()
65
- prev_char = word[0]
66
- for char in word[1:]:
67
- pairs.add((prev_char, char))
68
- prev_char = char
69
- return pairs
70
-
71
-
72
- class Qwen2Tokenizer(PreTrainedTokenizer):
73
- """
74
- Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
75
-
76
- As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
77
- be encoded differently whether it is at the beginning of the sentence (without space) or not:
78
-
79
- ```python
80
- >>> from transformers import Qwen2Tokenizer
81
-
82
- >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
83
- >>> tokenizer("Hello world")["input_ids"]
84
- [9707, 1879]
85
-
86
- >>> tokenizer(" Hello world")["input_ids"]
87
- [21927, 1879]
88
- ```
89
- This is expected.
90
-
91
- You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
92
-
93
- This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
94
- this superclass for more information regarding those methods.
95
-
96
- Args:
97
- vocab_file (`str`):
98
- Path to the vocabulary file.
99
- merges_file (`str`):
100
- Path to the merges file.
101
- errors (`str`, *optional*, defaults to `"replace"`):
102
- Paradigm to follow when decoding bytes to UTF-8. See
103
- [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
104
- unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
105
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
106
- token instead.
107
- bos_token (`str`, *optional*):
108
- The beginning of sequence token. Not applicable for this tokenizer.
109
- eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
110
- The end of sequence token.
111
- pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
112
- The token used for padding, for example when batching sequences of different lengths.
113
- clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
114
- Whether or not the model should cleanup the spaces that were added when splitting the input text during the
115
- tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
116
- split_special_tokens (`bool`, *optional*, defaults to `False`):
117
- Whether or not the special tokens should be split during the tokenization process. The default behavior is
118
- to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") =
119
- ['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
120
- '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
121
- """
122
-
123
- vocab_files_names = VOCAB_FILES_NAMES
124
- model_input_names = ["input_ids", "attention_mask"]
125
-
126
- def __init__(
127
- self,
128
- vocab_file,
129
- merges_file,
130
- errors="replace",
131
- unk_token="<|endoftext|>",
132
- bos_token=None,
133
- eos_token="<|endoftext|>",
134
- pad_token="<|endoftext|>",
135
- clean_up_tokenization_spaces=False,
136
- split_special_tokens=False,
137
- **kwargs,
138
- ):
139
- # Qwen vocab does not contain control tokens; added tokens need to be special
140
- bos_token = (
141
- AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
142
- if isinstance(bos_token, str)
143
- else bos_token
144
- )
145
- eos_token = (
146
- AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
147
- if isinstance(eos_token, str)
148
- else eos_token
149
- )
150
- unk_token = (
151
- AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
152
- if isinstance(unk_token, str)
153
- else unk_token
154
- )
155
- pad_token = (
156
- AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
157
- if isinstance(pad_token, str)
158
- else pad_token
159
- )
160
-
161
- with open(vocab_file, encoding="utf-8") as vocab_handle:
162
- self.encoder = json.load(vocab_handle)
163
- self.decoder = {v: k for k, v in self.encoder.items()}
164
- self.errors = errors # how to handle errors in decoding
165
- self.byte_encoder = bytes_to_unicode()
166
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
167
- bpe_merges = []
168
- with open(merges_file, encoding="utf-8") as merges_handle:
169
- for i, line in enumerate(merges_handle):
170
- line = line.strip()
171
- if (i == 0 and line.startswith("#version:")) or not line:
172
- continue
173
- bpe_merges.append(tuple(line.split()))
174
- self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
175
- # NOTE: the cache can grow without bound and will get really large for long running processes
176
- # (esp. for texts in languages that do not use spaces between words, e.g. Chinese); technically
177
- # not a memory leak but appears as one.
178
- # GPT2Tokenizer has the same problem, so let's be consistent.
179
- self.cache = {}
180
-
181
- self.pat = re.compile(PRETOKENIZE_REGEX)
182
-
183
- if kwargs.get("add_prefix_space", False):
184
- logger.warning_once(
185
- f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
186
- )
187
-
188
- super().__init__(
189
- errors=errors,
190
- bos_token=bos_token,
191
- eos_token=eos_token,
192
- pad_token=pad_token,
193
- unk_token=unk_token,
194
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
195
- split_special_tokens=split_special_tokens,
196
- **kwargs,
197
- )
198
-
199
- @property
200
- def vocab_size(self) -> int:
201
- return len(self.encoder)
202
-
203
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
204
- def get_vocab(self):
205
- return dict(self.encoder, **self.added_tokens_encoder)
206
-
207
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
208
- def bpe(self, token):
209
- if token in self.cache:
210
- return self.cache[token]
211
- word = tuple(token)
212
- pairs = get_pairs(word)
213
-
214
- if not pairs:
215
- return token
216
-
217
- while True:
218
- bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
219
- if bigram not in self.bpe_ranks:
220
- break
221
- first, second = bigram
222
- new_word = []
223
- i = 0
224
- while i < len(word):
225
- try:
226
- j = word.index(first, i)
227
- except ValueError:
228
- new_word.extend(word[i:])
229
- break
230
- else:
231
- new_word.extend(word[i:j])
232
- i = j
233
-
234
- if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
235
- new_word.append(first + second)
236
- i += 2
237
- else:
238
- new_word.append(word[i])
239
- i += 1
240
- new_word = tuple(new_word)
241
- word = new_word
242
- if len(word) == 1:
243
- break
244
- else:
245
- pairs = get_pairs(word)
246
- word = " ".join(word)
247
- self.cache[token] = word
248
- return word
249
-
250
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
251
- def _tokenize(self, text):
252
- """Tokenize a string."""
253
- bpe_tokens = []
254
- for token in re.findall(self.pat, text):
255
- token = "".join(
256
- self.byte_encoder[b] for b in token.encode("utf-8")
257
- ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
258
- bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
259
- return bpe_tokens
260
-
261
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
262
- def _convert_token_to_id(self, token):
263
- """Converts a token (str) in an id using the vocab."""
264
- return self.encoder.get(token, self.encoder.get(self.unk_token))
265
-
266
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
267
- def _convert_id_to_token(self, index):
268
- """Converts an index (integer) in a token (str) using the vocab."""
269
- return self.decoder.get(index)
270
-
271
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
272
- def convert_tokens_to_string(self, tokens):
273
- """Converts a sequence of tokens (string) in a single string."""
274
- text = "".join(tokens)
275
- text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
276
- return text
277
-
278
- def decode(
279
- self,
280
- token_ids,
281
- skip_special_tokens: bool = False,
282
- clean_up_tokenization_spaces: Optional[bool] = False,
283
- spaces_between_special_tokens: bool = False,
284
- **kwargs,
285
- ) -> str:
286
- # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
287
- # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
288
- return super().decode(
289
- token_ids,
290
- skip_special_tokens=skip_special_tokens,
291
- clean_up_tokenization_spaces=clean_up_tokenization_spaces,
292
- spaces_between_special_tokens=spaces_between_special_tokens,
293
- **kwargs,
294
- )
295
-
296
- # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
297
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
298
- if not os.path.isdir(save_directory):
299
- logger.error(f"Vocabulary path ({save_directory}) should be a directory")
300
- return
301
- vocab_file = os.path.join(
302
- save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
303
- )
304
- merge_file = os.path.join(
305
- save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
306
- )
307
-
308
- with open(vocab_file, "w", encoding="utf-8") as f:
309
- f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
310
-
311
- index = 0
312
- with open(merge_file, "w", encoding="utf-8") as writer:
313
- writer.write("#version: 0.2\n")
314
- for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
315
- if index != token_index:
316
- logger.warning(
317
- f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
318
- " Please check that the tokenizer is not corrupted!"
319
- )
320
- index = token_index
321
- writer.write(" ".join(bpe_tokens) + "\n")
322
- index += 1
323
-
324
- return vocab_file, merge_file
325
-
326
- def prepare_for_tokenization(self, text, **kwargs):
327
- text = unicodedata.normalize("NFC", text)
328
- return (text, kwargs)
 
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tokenization classes for Qwen2."""
5
+
6
+ import json
7
+ import os
8
+ import unicodedata
9
+ from functools import lru_cache
10
+ from typing import Optional, Tuple
11
+
12
+ import regex as re
13
+
14
+ from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
15
+ from transformers.utils import logging
16
+
17
+
18
+ logger = logging.get_logger(__name__)
19
+
20
+ VOCAB_FILES_NAMES = {
21
+ "vocab_file": "vocab.json",
22
+ "merges_file": "merges.txt",
23
+ }
24
+
25
+
26
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
27
+
28
+ PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
29
+
30
+
31
+ @lru_cache()
32
+ # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
33
+ def bytes_to_unicode():
34
+ """
35
+ Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
36
+ characters the bpe code barfs on.
37
+
38
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
39
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
40
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
41
+ tables between utf-8 bytes and unicode strings.
42
+ """
43
+ bs = (
44
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
45
+ )
46
+ cs = bs[:]
47
+ n = 0
48
+ for b in range(2**8):
49
+ if b not in bs:
50
+ bs.append(b)
51
+ cs.append(2**8 + n)
52
+ n += 1
53
+ cs = [chr(n) for n in cs]
54
+ return dict(zip(bs, cs))
55
+
56
+
57
+ # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
58
+ def get_pairs(word):
59
+ """
60
+ Return set of symbol pairs in a word.
61
+
62
+ Word is represented as tuple of symbols (symbols being variable-length strings).
63
+ """
64
+ pairs = set()
65
+ prev_char = word[0]
66
+ for char in word[1:]:
67
+ pairs.add((prev_char, char))
68
+ prev_char = char
69
+ return pairs
70
+
71
+
72
+ class Qwen2Tokenizer(PreTrainedTokenizer):
73
+ """
74
+ Construct a Qwen2 tokenizer. Based on byte-level Byte-Pair-Encoding.
75
+
76
+ As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
77
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
78
+
79
+ ```python
80
+ >>> from transformers import Qwen2Tokenizer
81
+
82
+ >>> tokenizer = Qwen2Tokenizer.from_pretrained("Qwen/Qwen-tokenizer")
83
+ >>> tokenizer("Hello world")["input_ids"]
84
+ [9707, 1879]
85
+
86
+ >>> tokenizer(" Hello world")["input_ids"]
87
+ [21927, 1879]
88
+ ```
89
+ This is expected.
90
+
91
+ You should not use GPT2Tokenizer instead, because of the different pretokenization rules.
92
+
93
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
94
+ this superclass for more information regarding those methods.
95
+
96
+ Args:
97
+ vocab_file (`str`):
98
+ Path to the vocabulary file.
99
+ merges_file (`str`):
100
+ Path to the merges file.
101
+ errors (`str`, *optional*, defaults to `"replace"`):
102
+ Paradigm to follow when decoding bytes to UTF-8. See
103
+ [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
104
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
105
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
106
+ token instead.
107
+ bos_token (`str`, *optional*):
108
+ The beginning of sequence token. Not applicable for this tokenizer.
109
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
110
+ The end of sequence token.
111
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
112
+ The token used for padding, for example when batching sequences of different lengths.
113
+ clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
114
+ Whether or not the model should cleanup the spaces that were added when splitting the input text during the
115
+ tokenization process. Not applicable to this tokenizer, since tokenization does not add spaces.
116
+ split_special_tokens (`bool`, *optional*, defaults to `False`):
117
+ Whether or not the special tokens should be split during the tokenization process. The default behavior is
118
+ to not split special tokens. This means that if `<|endoftext|>` is the `eos_token`, then `tokenizer.tokenize("<|endoftext|>") =
119
+ ['<|endoftext|>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<|endoftext|>")` will give `['<',
120
+ '|', 'endo', 'ft', 'ext', '|', '>']`. This argument is only supported for `slow` tokenizers for the moment.
121
+ """
122
+
123
+ vocab_files_names = VOCAB_FILES_NAMES
124
+ model_input_names = ["input_ids", "attention_mask"]
125
+
126
+ def __init__(
127
+ self,
128
+ vocab_file,
129
+ merges_file,
130
+ errors="replace",
131
+ unk_token="<|endoftext|>",
132
+ bos_token=None,
133
+ eos_token="<|endoftext|>",
134
+ pad_token="<|endoftext|>",
135
+ clean_up_tokenization_spaces=False,
136
+ split_special_tokens=False,
137
+ **kwargs,
138
+ ):
139
+ # Qwen vocab does not contain control tokens; added tokens need to be special
140
+ bos_token = (
141
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
142
+ if isinstance(bos_token, str)
143
+ else bos_token
144
+ )
145
+ eos_token = (
146
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
147
+ if isinstance(eos_token, str)
148
+ else eos_token
149
+ )
150
+ unk_token = (
151
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
152
+ if isinstance(unk_token, str)
153
+ else unk_token
154
+ )
155
+ pad_token = (
156
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
157
+ if isinstance(pad_token, str)
158
+ else pad_token
159
+ )
160
+
161
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
162
+ self.encoder = json.load(vocab_handle)
163
+ self.decoder = {v: k for k, v in self.encoder.items()}
164
+ self.errors = errors # how to handle errors in decoding
165
+ self.byte_encoder = bytes_to_unicode()
166
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
167
+ bpe_merges = []
168
+ with open(merges_file, encoding="utf-8") as merges_handle:
169
+ for i, line in enumerate(merges_handle):
170
+ line = line.strip()
171
+ if (i == 0 and line.startswith("#version:")) or not line:
172
+ continue
173
+ bpe_merges.append(tuple(line.split()))
174
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
175
+ # NOTE: the cache can grow without bound and will get really large for long running processes
176
+ # (esp. for texts in languages that do not use spaces between words, e.g. Chinese); technically
177
+ # not a memory leak but appears as one.
178
+ # GPT2Tokenizer has the same problem, so let's be consistent.
179
+ self.cache = {}
180
+
181
+ self.pat = re.compile(PRETOKENIZE_REGEX)
182
+
183
+ if kwargs.get("add_prefix_space", False):
184
+ logger.warning_once(
185
+ f"{self.__class__.__name} does not support `add_prefix_space`, setting it to True has no effect."
186
+ )
187
+
188
+ super().__init__(
189
+ errors=errors,
190
+ bos_token=bos_token,
191
+ eos_token=eos_token,
192
+ pad_token=pad_token,
193
+ unk_token=unk_token,
194
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
195
+ split_special_tokens=split_special_tokens,
196
+ **kwargs,
197
+ )
198
+
199
+ @property
200
+ def vocab_size(self) -> int:
201
+ return len(self.encoder)
202
+
203
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab
204
+ def get_vocab(self):
205
+ return dict(self.encoder, **self.added_tokens_encoder)
206
+
207
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe
208
+ def bpe(self, token):
209
+ if token in self.cache:
210
+ return self.cache[token]
211
+ word = tuple(token)
212
+ pairs = get_pairs(word)
213
+
214
+ if not pairs:
215
+ return token
216
+
217
+ while True:
218
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
219
+ if bigram not in self.bpe_ranks:
220
+ break
221
+ first, second = bigram
222
+ new_word = []
223
+ i = 0
224
+ while i < len(word):
225
+ try:
226
+ j = word.index(first, i)
227
+ except ValueError:
228
+ new_word.extend(word[i:])
229
+ break
230
+ else:
231
+ new_word.extend(word[i:j])
232
+ i = j
233
+
234
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
235
+ new_word.append(first + second)
236
+ i += 2
237
+ else:
238
+ new_word.append(word[i])
239
+ i += 1
240
+ new_word = tuple(new_word)
241
+ word = new_word
242
+ if len(word) == 1:
243
+ break
244
+ else:
245
+ pairs = get_pairs(word)
246
+ word = " ".join(word)
247
+ self.cache[token] = word
248
+ return word
249
+
250
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize
251
+ def _tokenize(self, text):
252
+ """Tokenize a string."""
253
+ bpe_tokens = []
254
+ for token in re.findall(self.pat, text):
255
+ token = "".join(
256
+ self.byte_encoder[b] for b in token.encode("utf-8")
257
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
258
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
259
+ return bpe_tokens
260
+
261
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id
262
+ def _convert_token_to_id(self, token):
263
+ """Converts a token (str) in an id using the vocab."""
264
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
265
+
266
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token
267
+ def _convert_id_to_token(self, index):
268
+ """Converts an index (integer) in a token (str) using the vocab."""
269
+ return self.decoder.get(index)
270
+
271
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string
272
+ def convert_tokens_to_string(self, tokens):
273
+ """Converts a sequence of tokens (string) in a single string."""
274
+ text = "".join(tokens)
275
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
276
+ return text
277
+
278
+ def decode(
279
+ self,
280
+ token_ids,
281
+ skip_special_tokens: bool = False,
282
+ clean_up_tokenization_spaces: Optional[bool] = False,
283
+ spaces_between_special_tokens: bool = False,
284
+ **kwargs,
285
+ ) -> str:
286
+ # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers
287
+ # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer
288
+ return super().decode(
289
+ token_ids,
290
+ skip_special_tokens=skip_special_tokens,
291
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
292
+ spaces_between_special_tokens=spaces_between_special_tokens,
293
+ **kwargs,
294
+ )
295
+
296
+ # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.save_vocabulary
297
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
298
+ if not os.path.isdir(save_directory):
299
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
300
+ return
301
+ vocab_file = os.path.join(
302
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
303
+ )
304
+ merge_file = os.path.join(
305
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
306
+ )
307
+
308
+ with open(vocab_file, "w", encoding="utf-8") as f:
309
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
310
+
311
+ index = 0
312
+ with open(merge_file, "w", encoding="utf-8") as writer:
313
+ writer.write("#version: 0.2\n")
314
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
315
+ if index != token_index:
316
+ logger.warning(
317
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
318
+ " Please check that the tokenizer is not corrupted!"
319
+ )
320
+ index = token_index
321
+ writer.write(" ".join(bpe_tokens) + "\n")
322
+ index += 1
323
+
324
+ return vocab_file, merge_file
325
+
326
+ def prepare_for_tokenization(self, text, **kwargs):
327
+ text = unicodedata.normalize("NFC", text)
328
+ return (text, kwargs)
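The slow tokenizer above works in two stages: `PRETOKENIZE_REGEX` splits raw text into chunks, then each chunk is mapped through `bytes_to_unicode()` and merged by `bpe()`. Here is a standalone sketch of the first stage (an editorial illustration, not part of this commit; only the `regex` package is needed and the sample sentence is arbitrary):

```python
# Editorial sketch: what the pretokenization regex used by Qwen2Tokenizer does to raw text.
import regex as re

PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
pat = re.compile(PRETOKENIZE_REGEX)

chunks = re.findall(pat, "Qwen2 doesn't split words, it splits chunks!")
print(chunks)
# Roughly: contractions, runs of letters (with their leading space attached), single digits,
# and punctuation each become their own chunk; BPE merges then happen inside each chunk only.
```

Because leading spaces stay attached to the following word, `"Hello world"` and `" Hello world"` produce different first tokens, which is exactly the behavior shown in the class docstring.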
modeling/qwen2/tokenization_qwen2_fast.py CHANGED
@@ -1,123 +1,123 @@
1
- # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """Tokenization classes for Qwen2."""
5
-
6
- from typing import Optional, Tuple
7
-
8
- from transformers.tokenization_utils import AddedToken
9
- from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
10
- from transformers.utils import logging
11
- from .tokenization_qwen2 import Qwen2Tokenizer
12
-
13
-
14
- logger = logging.get_logger(__name__)
15
-
16
- VOCAB_FILES_NAMES = {
17
- "vocab_file": "vocab.json",
18
- "merges_file": "merges.txt",
19
- "tokenizer_file": "tokenizer.json",
20
- }
21
-
22
-
23
- MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
24
-
25
-
26
- class Qwen2TokenizerFast(PreTrainedTokenizerFast):
27
- """
28
- Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
29
- Byte-Pair-Encoding.
30
-
31
- As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
32
- be encoded differently whether it is at the beginning of the sentence (without space) or not:
33
-
34
- ```python
35
- >>> from transformers import Qwen2TokenizerFast
36
-
37
- >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
38
- >>> tokenizer("Hello world")["input_ids"]
39
- [9707, 1879]
40
-
41
- >>> tokenizer(" Hello world")["input_ids"]
42
- [21927, 1879]
43
- ```
44
- This is expected.
45
-
46
- This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
47
- refer to this superclass for more information regarding those methods.
48
-
49
- Args:
50
- vocab_file (`str`, *optional*):
51
- Path to the vocabulary file.
52
- merges_file (`str`, *optional*):
53
- Path to the merges file.
54
- tokenizer_file (`str`, *optional*):
55
- Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
56
- contains everything needed to load the tokenizer.
57
- unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
58
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
59
- token instead. Not applicable to this tokenizer.
60
- bos_token (`str`, *optional*):
61
- The beginning of sequence token. Not applicable for this tokenizer.
62
- eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
63
- The end of sequence token.
64
- pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
65
- The token used for padding, for example when batching sequences of different lengths.
66
- """
67
-
68
- vocab_files_names = VOCAB_FILES_NAMES
69
- model_input_names = ["input_ids", "attention_mask"]
70
- slow_tokenizer_class = Qwen2Tokenizer
71
-
72
- def __init__(
73
- self,
74
- vocab_file=None,
75
- merges_file=None,
76
- tokenizer_file=None,
77
- unk_token="<|endoftext|>",
78
- bos_token=None,
79
- eos_token="<|endoftext|>",
80
- pad_token="<|endoftext|>",
81
- **kwargs,
82
- ):
83
- # We need to at least pass vocab_file and merges_file to base class
84
- # in case a slow tokenizer needs to be initialized; other can be
85
- # configured through files.
86
- # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
87
-
88
- bos_token = (
89
- AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
90
- if isinstance(bos_token, str)
91
- else bos_token
92
- )
93
- eos_token = (
94
- AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
95
- if isinstance(eos_token, str)
96
- else eos_token
97
- )
98
- unk_token = (
99
- AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
100
- if isinstance(unk_token, str)
101
- else unk_token
102
- )
103
- pad_token = (
104
- AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
105
- if isinstance(pad_token, str)
106
- else pad_token
107
- )
108
-
109
- super().__init__(
110
- vocab_file=vocab_file,
111
- merges_file=merges_file,
112
- tokenizer_file=tokenizer_file,
113
- unk_token=unk_token,
114
- bos_token=bos_token,
115
- eos_token=eos_token,
116
- pad_token=pad_token,
117
- **kwargs,
118
- )
119
-
120
- # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
121
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
122
- files = self._tokenizer.model.save(save_directory, name=filename_prefix)
123
- return tuple(files)
 
1
+ # Copyright 2024 The Qwen Team and The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tokenization classes for Qwen2."""
5
+
6
+ from typing import Optional, Tuple
7
+
8
+ from transformers.tokenization_utils import AddedToken
9
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
10
+ from transformers.utils import logging
11
+ from .tokenization_qwen2 import Qwen2Tokenizer
12
+
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+ VOCAB_FILES_NAMES = {
17
+ "vocab_file": "vocab.json",
18
+ "merges_file": "merges.txt",
19
+ "tokenizer_file": "tokenizer.json",
20
+ }
21
+
22
+
23
+ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768}
24
+
25
+
26
+ class Qwen2TokenizerFast(PreTrainedTokenizerFast):
27
+ """
28
+ Construct a "fast" Qwen2 tokenizer (backed by HuggingFace's *tokenizers* library). Based on byte-level
29
+ Byte-Pair-Encoding.
30
+
31
+ As with GPT2Tokenizer, this tokenizer has been trained to treat spaces like parts of the tokens so a word will
32
+ be encoded differently whether it is at the beginning of the sentence (without space) or not:
33
+
34
+ ```python
35
+ >>> from transformers import Qwen2TokenizerFast
36
+
37
+ >>> tokenizer = Qwen2TokenizerFast.from_pretrained("Qwen/Qwen-tokenizer")
38
+ >>> tokenizer("Hello world")["input_ids"]
39
+ [9707, 1879]
40
+
41
+ >>> tokenizer(" Hello world")["input_ids"]
42
+ [21927, 1879]
43
+ ```
44
+ This is expected.
45
+
46
+ This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
47
+ refer to this superclass for more information regarding those methods.
48
+
49
+ Args:
50
+ vocab_file (`str`, *optional*):
51
+ Path to the vocabulary file.
52
+ merges_file (`str`, *optional*):
53
+ Path to the merges file.
54
+ tokenizer_file (`str`, *optional*):
55
+ Path to [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
56
+ contains everything needed to load the tokenizer.
57
+ unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
58
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
59
+ token instead. Not applicable to this tokenizer.
60
+ bos_token (`str`, *optional*):
61
+ The beginning of sequence token. Not applicable for this tokenizer.
62
+ eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
63
+ The end of sequence token.
64
+ pad_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
65
+ The token used for padding, for example when batching sequences of different lengths.
66
+ """
67
+
68
+ vocab_files_names = VOCAB_FILES_NAMES
69
+ model_input_names = ["input_ids", "attention_mask"]
70
+ slow_tokenizer_class = Qwen2Tokenizer
71
+
72
+ def __init__(
73
+ self,
74
+ vocab_file=None,
75
+ merges_file=None,
76
+ tokenizer_file=None,
77
+ unk_token="<|endoftext|>",
78
+ bos_token=None,
79
+ eos_token="<|endoftext|>",
80
+ pad_token="<|endoftext|>",
81
+ **kwargs,
82
+ ):
83
+ # We need to at least pass vocab_file and merges_file to base class
84
+ # in case a slow tokenizer needs to be initialized; other can be
85
+ # configured through files.
86
+ # following GPT2TokenizerFast, also adding unk_token, bos_token, and eos_token
87
+
88
+ bos_token = (
89
+ AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False)
90
+ if isinstance(bos_token, str)
91
+ else bos_token
92
+ )
93
+ eos_token = (
94
+ AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False)
95
+ if isinstance(eos_token, str)
96
+ else eos_token
97
+ )
98
+ unk_token = (
99
+ AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False)
100
+ if isinstance(unk_token, str)
101
+ else unk_token
102
+ )
103
+ pad_token = (
104
+ AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False)
105
+ if isinstance(pad_token, str)
106
+ else pad_token
107
+ )
108
+
109
+ super().__init__(
110
+ vocab_file=vocab_file,
111
+ merges_file=merges_file,
112
+ tokenizer_file=tokenizer_file,
113
+ unk_token=unk_token,
114
+ bos_token=bos_token,
115
+ eos_token=eos_token,
116
+ pad_token=pad_token,
117
+ **kwargs,
118
+ )
119
+
120
+ # Copied from transformers.models.gpt2.tokenization_gpt2_fast.GPT2TokenizerFast.save_vocabulary
121
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
122
+ files = self._tokenizer.model.save(save_directory, name=filename_prefix)
123
+ return tuple(files)
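Both tokenizer classes repeat the same wrapping of plain strings into `AddedToken` objects so that `<|endoftext|>` is registered as a special token that is never normalized or split. A small sketch of that shared pattern (an editorial illustration; the helper name is hypothetical, not part of the codebase):

```python
# Editorial sketch: the AddedToken wrapping used for bos/eos/unk/pad in both tokenizer classes.
from transformers.tokenization_utils import AddedToken


def as_special_token(token):
    """Wrap a plain string as a non-normalized special AddedToken; pass other values through."""
    if isinstance(token, str):
        return AddedToken(token, lstrip=False, rstrip=False, special=True, normalized=False)
    return token


eos = as_special_token("<|endoftext|>")
bos = as_special_token(None)    # non-strings (e.g. None for bos_token) are returned unchanged
print(type(eos).__name__, bos)  # AddedToken None
```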
modeling/siglip/__init__.py CHANGED
@@ -1,98 +1,98 @@
1
- # Copyright 2024 The HuggingFace Inc. team.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- from typing import TYPE_CHECKING
5
-
6
- from transformers.utils import (
7
- OptionalDependencyNotAvailable,
8
- _LazyModule,
9
- is_sentencepiece_available,
10
- is_torch_available,
11
- is_vision_available,
12
- )
13
-
14
-
15
- _import_structure = {
16
- "configuration_siglip": [
17
- "SiglipConfig",
18
- "SiglipTextConfig",
19
- "SiglipVisionConfig",
20
- ],
21
- "processing_siglip": ["SiglipProcessor"],
22
- }
23
-
24
- try:
25
- if not is_sentencepiece_available():
26
- raise OptionalDependencyNotAvailable()
27
- except OptionalDependencyNotAvailable:
28
- pass
29
- else:
30
- _import_structure["tokenization_siglip"] = ["SiglipTokenizer"]
31
-
32
-
33
- try:
34
- if not is_vision_available():
35
- raise OptionalDependencyNotAvailable()
36
- except OptionalDependencyNotAvailable:
37
- pass
38
- else:
39
- _import_structure["image_processing_siglip"] = ["SiglipImageProcessor"]
40
-
41
- try:
42
- if not is_torch_available():
43
- raise OptionalDependencyNotAvailable()
44
- except OptionalDependencyNotAvailable:
45
- pass
46
- else:
47
- _import_structure["modeling_siglip"] = [
48
- "SiglipModel",
49
- "SiglipPreTrainedModel",
50
- "SiglipTextModel",
51
- "SiglipVisionModel",
52
- "SiglipForImageClassification",
53
- ]
54
-
55
-
56
- if TYPE_CHECKING:
57
- from .configuration_siglip import (
58
- SiglipConfig,
59
- SiglipTextConfig,
60
- SiglipVisionConfig,
61
- )
62
- from .processing_siglip import SiglipProcessor
63
-
64
- try:
65
- if not is_sentencepiece_available():
66
- raise OptionalDependencyNotAvailable()
67
- except OptionalDependencyNotAvailable:
68
- pass
69
- else:
70
- from .tokenization_siglip import SiglipTokenizer
71
-
72
- try:
73
- if not is_vision_available():
74
- raise OptionalDependencyNotAvailable()
75
- except OptionalDependencyNotAvailable:
76
- pass
77
- else:
78
- from .image_processing_siglip import SiglipImageProcessor
79
-
80
- try:
81
- if not is_torch_available():
82
- raise OptionalDependencyNotAvailable()
83
- except OptionalDependencyNotAvailable:
84
- pass
85
- else:
86
- from .modeling_siglip import (
87
- SiglipForImageClassification,
88
- SiglipModel,
89
- SiglipPreTrainedModel,
90
- SiglipTextModel,
91
- SiglipVisionModel,
92
- )
93
-
94
-
95
- else:
96
- import sys
97
-
98
- sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
 
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import TYPE_CHECKING
5
+
6
+ from transformers.utils import (
7
+ OptionalDependencyNotAvailable,
8
+ _LazyModule,
9
+ is_sentencepiece_available,
10
+ is_torch_available,
11
+ is_vision_available,
12
+ )
13
+
14
+
15
+ _import_structure = {
16
+ "configuration_siglip": [
17
+ "SiglipConfig",
18
+ "SiglipTextConfig",
19
+ "SiglipVisionConfig",
20
+ ],
21
+ "processing_siglip": ["SiglipProcessor"],
22
+ }
23
+
24
+ try:
25
+ if not is_sentencepiece_available():
26
+ raise OptionalDependencyNotAvailable()
27
+ except OptionalDependencyNotAvailable:
28
+ pass
29
+ else:
30
+ _import_structure["tokenization_siglip"] = ["SiglipTokenizer"]
31
+
32
+
33
+ try:
34
+ if not is_vision_available():
35
+ raise OptionalDependencyNotAvailable()
36
+ except OptionalDependencyNotAvailable:
37
+ pass
38
+ else:
39
+ _import_structure["image_processing_siglip"] = ["SiglipImageProcessor"]
40
+
41
+ try:
42
+ if not is_torch_available():
43
+ raise OptionalDependencyNotAvailable()
44
+ except OptionalDependencyNotAvailable:
45
+ pass
46
+ else:
47
+ _import_structure["modeling_siglip"] = [
48
+ "SiglipModel",
49
+ "SiglipPreTrainedModel",
50
+ "SiglipTextModel",
51
+ "SiglipVisionModel",
52
+ "SiglipForImageClassification",
53
+ ]
54
+
55
+
56
+ if TYPE_CHECKING:
57
+ from .configuration_siglip import (
58
+ SiglipConfig,
59
+ SiglipTextConfig,
60
+ SiglipVisionConfig,
61
+ )
62
+ from .processing_siglip import SiglipProcessor
63
+
64
+ try:
65
+ if not is_sentencepiece_available():
66
+ raise OptionalDependencyNotAvailable()
67
+ except OptionalDependencyNotAvailable:
68
+ pass
69
+ else:
70
+ from .tokenization_siglip import SiglipTokenizer
71
+
72
+ try:
73
+ if not is_vision_available():
74
+ raise OptionalDependencyNotAvailable()
75
+ except OptionalDependencyNotAvailable:
76
+ pass
77
+ else:
78
+ from .image_processing_siglip import SiglipImageProcessor
79
+
80
+ try:
81
+ if not is_torch_available():
82
+ raise OptionalDependencyNotAvailable()
83
+ except OptionalDependencyNotAvailable:
84
+ pass
85
+ else:
86
+ from .modeling_siglip import (
87
+ SiglipForImageClassification,
88
+ SiglipModel,
89
+ SiglipPreTrainedModel,
90
+ SiglipTextModel,
91
+ SiglipVisionModel,
92
+ )
93
+
94
+
95
+ else:
96
+ import sys
97
+
98
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
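The `__init__.py` above keeps `import modeling.siglip` cheap by registering a `_LazyModule`: torch- and vision-dependent submodules are only imported when one of their symbols is first accessed. A simplified stand-in for that pattern (an editorial illustration, not the transformers `_LazyModule` implementation, which also handles `__all__`, pickling, and more):

```python
# Editorial sketch: core idea behind the lazy-module registration done above.
import importlib
import types


class LazyModule(types.ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # Flatten {"submodule": ["Name", ...]} into {"Name": "submodule"} for O(1) lookups.
        self._attr_to_submodule = {
            attr: sub for sub, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        sub = self._attr_to_submodule.get(attr)
        if sub is None:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        value = getattr(importlib.import_module(f"{self.__name__}.{sub}"), attr)
        setattr(self, attr, value)  # cache so later accesses skip __getattr__
        return value
```

Replacing `sys.modules[__name__]` with such an object keeps the package import itself light; `SiglipVisionModel` only pulls in `modeling_siglip` (and therefore torch) the first time it is requested.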
modeling/siglip/configuration_siglip.py CHANGED
@@ -1,287 +1,287 @@
1
- # Copyright 2024 The HuggingFace Inc. team.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """Siglip model configuration"""
5
-
6
- import os
7
- from typing import Union
8
-
9
- from transformers.configuration_utils import PretrainedConfig
10
- from transformers.utils import logging
11
-
12
-
13
- logger = logging.get_logger(__name__)
14
-
15
-
16
- class SiglipTextConfig(PretrainedConfig):
17
- r"""
18
- This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
19
- Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
20
- configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
21
- [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
22
-
23
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
24
- documentation from [`PretrainedConfig`] for more information.
25
-
26
- Args:
27
- vocab_size (`int`, *optional*, defaults to 32000):
28
- Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
29
- the `inputs_ids` passed when calling [`SiglipModel`].
30
- hidden_size (`int`, *optional*, defaults to 768):
31
- Dimensionality of the encoder layers and the pooler layer.
32
- intermediate_size (`int`, *optional*, defaults to 3072):
33
- Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
34
- num_hidden_layers (`int`, *optional*, defaults to 12):
35
- Number of hidden layers in the Transformer encoder.
36
- num_attention_heads (`int`, *optional*, defaults to 12):
37
- Number of attention heads for each attention layer in the Transformer encoder.
38
- max_position_embeddings (`int`, *optional*, defaults to 64):
39
- The maximum sequence length that this model might ever be used with. Typically set this to something large
40
- just in case (e.g., 512 or 1024 or 2048).
41
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
42
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
43
- `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
44
- layer_norm_eps (`float`, *optional*, defaults to 1e-06):
45
- The epsilon used by the layer normalization layers.
46
- attention_dropout (`float`, *optional*, defaults to 0.0):
47
- The dropout ratio for the attention probabilities.
48
- pad_token_id (`int`, *optional*, defaults to 1):
49
- The id of the padding token in the vocabulary.
50
- bos_token_id (`int`, *optional*, defaults to 49406):
51
- The id of the beginning-of-sequence token in the vocabulary.
52
- eos_token_id (`int`, *optional*, defaults to 49407):
53
- The id of the end-of-sequence token in the vocabulary.
54
-
55
- Example:
56
-
57
- ```python
58
- >>> from transformers import SiglipTextConfig, SiglipTextModel
59
-
60
- >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
61
- >>> configuration = SiglipTextConfig()
62
-
63
- >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
64
- >>> model = SiglipTextModel(configuration)
65
-
66
- >>> # Accessing the model configuration
67
- >>> configuration = model.config
68
- ```"""
69
-
70
- model_type = "siglip_text_model"
71
-
72
- def __init__(
73
- self,
74
- vocab_size=32000,
75
- hidden_size=768,
76
- intermediate_size=3072,
77
- num_hidden_layers=12,
78
- num_attention_heads=12,
79
- max_position_embeddings=64,
80
- hidden_act="gelu_pytorch_tanh",
81
- layer_norm_eps=1e-6,
82
- attention_dropout=0.0,
83
- # This differs from `CLIPTokenizer`'s default and from openai/siglip
84
- # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
85
- pad_token_id=1,
86
- bos_token_id=49406,
87
- eos_token_id=49407,
88
- **kwargs,
89
- ):
90
- super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
91
-
92
- self.vocab_size = vocab_size
93
- self.hidden_size = hidden_size
94
- self.intermediate_size = intermediate_size
95
- self.num_hidden_layers = num_hidden_layers
96
- self.num_attention_heads = num_attention_heads
97
- self.max_position_embeddings = max_position_embeddings
98
- self.layer_norm_eps = layer_norm_eps
99
- self.hidden_act = hidden_act
100
- self.attention_dropout = attention_dropout
101
-
102
- @classmethod
103
- def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
104
- cls._set_token_in_kwargs(kwargs)
105
-
106
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
107
-
108
- # get the text config dict if we are loading from SiglipConfig
109
- if config_dict.get("model_type") == "siglip":
110
- config_dict = config_dict["text_config"]
111
-
112
- if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
113
- logger.warning(
114
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
115
- f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
116
- )
117
-
118
- return cls.from_dict(config_dict, **kwargs)
119
-
120
-
121
- class SiglipVisionConfig(PretrainedConfig):
122
- r"""
123
- This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
124
- Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
125
- configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
126
- [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
127
-
128
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
129
- documentation from [`PretrainedConfig`] for more information.
130
-
131
- Args:
132
- hidden_size (`int`, *optional*, defaults to 768):
133
- Dimensionality of the encoder layers and the pooler layer.
134
- intermediate_size (`int`, *optional*, defaults to 3072):
135
- Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
136
- num_hidden_layers (`int`, *optional*, defaults to 12):
137
- Number of hidden layers in the Transformer encoder.
138
- num_attention_heads (`int`, *optional*, defaults to 12):
139
- Number of attention heads for each attention layer in the Transformer encoder.
140
- num_channels (`int`, *optional*, defaults to 3):
141
- Number of channels in the input images.
142
- image_size (`int`, *optional*, defaults to 224):
143
- The size (resolution) of each image.
144
- patch_size (`int`, *optional*, defaults to 16):
145
- The size (resolution) of each patch.
146
- hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
147
- The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
148
- `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
149
- layer_norm_eps (`float`, *optional*, defaults to 1e-06):
150
- The epsilon used by the layer normalization layers.
151
- attention_dropout (`float`, *optional*, defaults to 0.0):
152
- The dropout ratio for the attention probabilities.
153
-
154
- Example:
155
-
156
- ```python
157
- >>> from transformers import SiglipVisionConfig, SiglipVisionModel
158
-
159
- >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
160
- >>> configuration = SiglipVisionConfig()
161
-
162
- >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
163
- >>> model = SiglipVisionModel(configuration)
164
-
165
- >>> # Accessing the model configuration
166
- >>> configuration = model.config
167
- ```"""
168
-
169
- model_type = "siglip_vision_model"
170
-
171
- def __init__(
172
- self,
173
- hidden_size=768,
174
- intermediate_size=3072,
175
- num_hidden_layers=12,
176
- num_attention_heads=12,
177
- num_channels=3,
178
- image_size=224,
179
- patch_size=16,
180
- hidden_act="gelu_pytorch_tanh",
181
- layer_norm_eps=1e-6,
182
- attention_dropout=0.0,
183
- **kwargs,
184
- ):
185
- super().__init__(**kwargs)
186
-
187
- self.hidden_size = hidden_size
188
- self.intermediate_size = intermediate_size
189
- self.num_hidden_layers = num_hidden_layers
190
- self.num_attention_heads = num_attention_heads
191
- self.num_channels = num_channels
192
- self.patch_size = patch_size
193
- self.image_size = image_size
194
- self.attention_dropout = attention_dropout
195
- self.layer_norm_eps = layer_norm_eps
196
- self.hidden_act = hidden_act
197
-
198
- @classmethod
199
- def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
200
- cls._set_token_in_kwargs(kwargs)
201
-
202
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
203
-
204
- # get the vision config dict if we are loading from SiglipConfig
205
- if config_dict.get("model_type") == "siglip":
206
- config_dict = config_dict["vision_config"]
207
-
208
- if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
209
- logger.warning(
210
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
211
- f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
212
- )
213
-
214
- return cls.from_dict(config_dict, **kwargs)
215
-
216
-
217
- class SiglipConfig(PretrainedConfig):
218
- r"""
219
- [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
220
- instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
221
- Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
222
- [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
223
-
224
- Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
225
- documentation from [`PretrainedConfig`] for more information.
226
-
227
- Args:
228
- text_config (`dict`, *optional*):
229
- Dictionary of configuration options used to initialize [`SiglipTextConfig`].
230
- vision_config (`dict`, *optional*):
231
- Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
232
- kwargs (*optional*):
233
- Dictionary of keyword arguments.
234
-
235
- Example:
236
-
237
- ```python
238
- >>> from transformers import SiglipConfig, SiglipModel
239
-
240
- >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
241
- >>> configuration = SiglipConfig()
242
-
243
- >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
244
- >>> model = SiglipModel(configuration)
245
-
246
- >>> # Accessing the model configuration
247
- >>> configuration = model.config
248
-
249
- >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
250
- >>> from transformers import SiglipTextConfig, SiglipVisionConfig
251
-
252
- >>> # Initializing a SiglipText and SiglipVision configuration
253
- >>> config_text = SiglipTextConfig()
254
- >>> config_vision = SiglipVisionConfig()
255
-
256
- >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
257
- ```"""
258
-
259
- model_type = "siglip"
260
-
261
- def __init__(self, text_config=None, vision_config=None, **kwargs):
262
- super().__init__(**kwargs)
263
-
264
- if text_config is None:
265
- text_config = {}
266
- logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
267
-
268
- if vision_config is None:
269
- vision_config = {}
270
- logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
271
-
272
- self.text_config = SiglipTextConfig(**text_config)
273
- self.vision_config = SiglipVisionConfig(**vision_config)
274
-
275
- self.initializer_factor = 1.0
276
-
277
- @classmethod
278
- def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
279
- r"""
280
- Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
281
- model configuration.
282
-
283
- Returns:
284
- [`SiglipConfig`]: An instance of a configuration object
285
- """
286
-
287
- return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
 
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Siglip model configuration"""
5
+
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import logging
11
+
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+
16
+ class SiglipTextConfig(PretrainedConfig):
17
+ r"""
18
+ This is the configuration class to store the configuration of a [`SiglipTextModel`]. It is used to instantiate a
19
+ Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a
20
+ configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip
21
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
22
+
23
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
24
+ documentation from [`PretrainedConfig`] for more information.
25
+
26
+ Args:
27
+ vocab_size (`int`, *optional*, defaults to 32000):
28
+ Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by
29
+ the `input_ids` passed when calling [`SiglipModel`].
30
+ hidden_size (`int`, *optional*, defaults to 768):
31
+ Dimensionality of the encoder layers and the pooler layer.
32
+ intermediate_size (`int`, *optional*, defaults to 3072):
33
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
34
+ num_hidden_layers (`int`, *optional*, defaults to 12):
35
+ Number of hidden layers in the Transformer encoder.
36
+ num_attention_heads (`int`, *optional*, defaults to 12):
37
+ Number of attention heads for each attention layer in the Transformer encoder.
38
+ max_position_embeddings (`int`, *optional*, defaults to 64):
39
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
40
+ just in case (e.g., 512 or 1024 or 2048).
41
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
42
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
43
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
44
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
45
+ The epsilon used by the layer normalization layers.
46
+ attention_dropout (`float`, *optional*, defaults to 0.0):
47
+ The dropout ratio for the attention probabilities.
48
+ pad_token_id (`int`, *optional*, defaults to 1):
49
+ The id of the padding token in the vocabulary.
50
+ bos_token_id (`int`, *optional*, defaults to 49406):
51
+ The id of the beginning-of-sequence token in the vocabulary.
52
+ eos_token_id (`int`, *optional*, defaults to 49407):
53
+ The id of the end-of-sequence token in the vocabulary.
54
+
55
+ Example:
56
+
57
+ ```python
58
+ >>> from transformers import SiglipTextConfig, SiglipTextModel
59
+
60
+ >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration
61
+ >>> configuration = SiglipTextConfig()
62
+
63
+ >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration
64
+ >>> model = SiglipTextModel(configuration)
65
+
66
+ >>> # Accessing the model configuration
67
+ >>> configuration = model.config
68
+ ```"""
69
+
70
+ model_type = "siglip_text_model"
71
+
72
+ def __init__(
73
+ self,
74
+ vocab_size=32000,
75
+ hidden_size=768,
76
+ intermediate_size=3072,
77
+ num_hidden_layers=12,
78
+ num_attention_heads=12,
79
+ max_position_embeddings=64,
80
+ hidden_act="gelu_pytorch_tanh",
81
+ layer_norm_eps=1e-6,
82
+ attention_dropout=0.0,
83
+ # This differs from `CLIPTokenizer`'s default and from openai/siglip
84
+ # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
85
+ pad_token_id=1,
86
+ bos_token_id=49406,
87
+ eos_token_id=49407,
88
+ **kwargs,
89
+ ):
90
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
91
+
92
+ self.vocab_size = vocab_size
93
+ self.hidden_size = hidden_size
94
+ self.intermediate_size = intermediate_size
95
+ self.num_hidden_layers = num_hidden_layers
96
+ self.num_attention_heads = num_attention_heads
97
+ self.max_position_embeddings = max_position_embeddings
98
+ self.layer_norm_eps = layer_norm_eps
99
+ self.hidden_act = hidden_act
100
+ self.attention_dropout = attention_dropout
101
+
102
+ @classmethod
103
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
104
+ cls._set_token_in_kwargs(kwargs)
105
+
106
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
107
+
108
+ # get the text config dict if we are loading from SiglipConfig
109
+ if config_dict.get("model_type") == "siglip":
110
+ config_dict = config_dict["text_config"]
111
+
112
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
113
+ logger.warning(
114
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
115
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
116
+ )
117
+
118
+ return cls.from_dict(config_dict, **kwargs)
119
+
120
+
121
+ class SiglipVisionConfig(PretrainedConfig):
122
+ r"""
123
+ This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a
124
+ Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a
125
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip
126
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
127
+
128
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
129
+ documentation from [`PretrainedConfig`] for more information.
130
+
131
+ Args:
132
+ hidden_size (`int`, *optional*, defaults to 768):
133
+ Dimensionality of the encoder layers and the pooler layer.
134
+ intermediate_size (`int`, *optional*, defaults to 3072):
135
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
136
+ num_hidden_layers (`int`, *optional*, defaults to 12):
137
+ Number of hidden layers in the Transformer encoder.
138
+ num_attention_heads (`int`, *optional*, defaults to 12):
139
+ Number of attention heads for each attention layer in the Transformer encoder.
140
+ num_channels (`int`, *optional*, defaults to 3):
141
+ Number of channels in the input images.
142
+ image_size (`int`, *optional*, defaults to 224):
143
+ The size (resolution) of each image.
144
+ patch_size (`int`, *optional*, defaults to 16):
145
+ The size (resolution) of each patch.
146
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
147
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
148
+ `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported.
149
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
150
+ The epsilon used by the layer normalization layers.
151
+ attention_dropout (`float`, *optional*, defaults to 0.0):
152
+ The dropout ratio for the attention probabilities.
153
+
154
+ Example:
155
+
156
+ ```python
157
+ >>> from transformers import SiglipVisionConfig, SiglipVisionModel
158
+
159
+ >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
160
+ >>> configuration = SiglipVisionConfig()
161
+
162
+ >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
163
+ >>> model = SiglipVisionModel(configuration)
164
+
165
+ >>> # Accessing the model configuration
166
+ >>> configuration = model.config
167
+ ```"""
168
+
169
+ model_type = "siglip_vision_model"
170
+
171
+ def __init__(
172
+ self,
173
+ hidden_size=768,
174
+ intermediate_size=3072,
175
+ num_hidden_layers=12,
176
+ num_attention_heads=12,
177
+ num_channels=3,
178
+ image_size=224,
179
+ patch_size=16,
180
+ hidden_act="gelu_pytorch_tanh",
181
+ layer_norm_eps=1e-6,
182
+ attention_dropout=0.0,
183
+ **kwargs,
184
+ ):
185
+ super().__init__(**kwargs)
186
+
187
+ self.hidden_size = hidden_size
188
+ self.intermediate_size = intermediate_size
189
+ self.num_hidden_layers = num_hidden_layers
190
+ self.num_attention_heads = num_attention_heads
191
+ self.num_channels = num_channels
192
+ self.patch_size = patch_size
193
+ self.image_size = image_size
194
+ self.attention_dropout = attention_dropout
195
+ self.layer_norm_eps = layer_norm_eps
196
+ self.hidden_act = hidden_act
197
+
198
+ @classmethod
199
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
200
+ cls._set_token_in_kwargs(kwargs)
201
+
202
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
203
+
204
+ # get the vision config dict if we are loading from SiglipConfig
205
+ if config_dict.get("model_type") == "siglip":
206
+ config_dict = config_dict["vision_config"]
207
+
208
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
209
+ logger.warning(
210
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
211
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
212
+ )
213
+
214
+ return cls.from_dict(config_dict, **kwargs)
215
+
216
+
217
+ class SiglipConfig(PretrainedConfig):
218
+ r"""
219
+ [`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. It is used to
220
+ instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs.
221
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip
222
+ [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture.
223
+
224
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
225
+ documentation from [`PretrainedConfig`] for more information.
226
+
227
+ Args:
228
+ text_config (`dict`, *optional*):
229
+ Dictionary of configuration options used to initialize [`SiglipTextConfig`].
230
+ vision_config (`dict`, *optional*):
231
+ Dictionary of configuration options used to initialize [`SiglipVisionConfig`].
232
+ kwargs (*optional*):
233
+ Dictionary of keyword arguments.
234
+
235
+ Example:
236
+
237
+ ```python
238
+ >>> from transformers import SiglipConfig, SiglipModel
239
+
240
+ >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration
241
+ >>> configuration = SiglipConfig()
242
+
243
+ >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration
244
+ >>> model = SiglipModel(configuration)
245
+
246
+ >>> # Accessing the model configuration
247
+ >>> configuration = model.config
248
+
249
+ >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig
250
+ >>> from transformers import SiglipTextConfig, SiglipVisionConfig
251
+
252
+ >>> # Initializing a SiglipText and SiglipVision configuration
253
+ >>> config_text = SiglipTextConfig()
254
+ >>> config_vision = SiglipVisionConfig()
255
+
256
+ >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision)
257
+ ```"""
258
+
259
+ model_type = "siglip"
260
+
261
+ def __init__(self, text_config=None, vision_config=None, **kwargs):
262
+ super().__init__(**kwargs)
263
+
264
+ if text_config is None:
265
+ text_config = {}
266
+ logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.")
267
+
268
+ if vision_config is None:
269
+ vision_config = {}
270
+ logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.")
271
+
272
+ self.text_config = SiglipTextConfig(**text_config)
273
+ self.vision_config = SiglipVisionConfig(**vision_config)
274
+
275
+ self.initializer_factor = 1.0
276
+
277
+ @classmethod
278
+ def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs):
279
+ r"""
280
+ Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision
281
+ model configuration.
282
+
283
+ Returns:
284
+ [`SiglipConfig`]: An instance of a configuration object
285
+ """
286
+
287
+ return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
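The two `from_pretrained` overrides above share one extra behavior worth noting: when pointed at a combined `"siglip"` config, each class extracts its own `text_config` or `vision_config` sub-dictionary before instantiating. A minimal sketch of that round trip, assuming the Siglip classes are importable from a recent `transformers` release and using a throwaway local directory (both the import path and the directory name are illustrative, not part of this repository):

```python
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

# Compose a combined config from explicit sub-configs, then persist it as config.json.
combined = SiglipConfig.from_text_vision_configs(
    SiglipTextConfig(vocab_size=32000),
    SiglipVisionConfig(image_size=224, patch_size=16),
)
combined.save_pretrained("./siglip-config-demo")  # hypothetical scratch directory

# Each sub-config class detects model_type == "siglip" in the saved dict and pulls out
# its own section, so the same path yields two different, correctly-typed configs.
vision = SiglipVisionConfig.from_pretrained("./siglip-config-demo")
text = SiglipTextConfig.from_pretrained("./siglip-config-demo")
print(vision.model_type, vision.image_size)  # siglip_vision_model 224
print(text.model_type, text.vocab_size)      # siglip_text_model 32000
```

A mismatched `model_type` only triggers the `logger.warning` branch rather than an error, so printing `model_type` after loading is the quickest sanity check that the right section was extracted.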
modeling/siglip/convert_siglip_to_hf.py CHANGED
@@ -1,401 +1,401 @@
1
- # Copyright 2024 The HuggingFace Inc. team.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """Convert SigLIP checkpoints from the original repository.
5
-
6
- URL: https://github.com/google-research/big_vision/tree/main
7
- """
8
-
9
- import argparse
10
- import collections
11
- from pathlib import Path
12
-
13
- import numpy as np
14
- import requests
15
- import torch
16
- from huggingface_hub import hf_hub_download
17
- from numpy import load
18
- from PIL import Image
19
-
20
- from transformers import SiglipConfig, SiglipImageProcessor, SiglipModel, SiglipProcessor, SiglipTokenizer
21
- from transformers.utils import logging
22
-
23
-
24
- logging.set_verbosity_info()
25
- logger = logging.get_logger(__name__)
26
-
27
-
28
- model_name_to_checkpoint = {
29
- # base checkpoints
30
- "siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz",
31
- "siglip-base-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_256_60500360.npz",
32
- "siglip-base-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_384_68578854.npz",
33
- "siglip-base-patch16-512": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_512_68580893.npz",
34
- # large checkpoints
35
- "siglip-large-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_256_60552751.npz",
36
- "siglip-large-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_384_63634585.npz",
37
- # multilingual checkpoint
38
- "siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz",
39
- # so400m checkpoints
40
- "siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz",
41
- }
42
-
43
- model_name_to_image_size = {
44
- "siglip-base-patch16-224": 224,
45
- "siglip-base-patch16-256": 256,
46
- "siglip-base-patch16-384": 384,
47
- "siglip-base-patch16-512": 512,
48
- "siglip-large-patch16-256": 256,
49
- "siglip-large-patch16-384": 384,
50
- "siglip-base-patch16-256-i18n": 256,
51
- "siglip-so400m-patch14-384": 384,
52
- }
53
-
54
-
55
- def get_siglip_config(model_name):
56
- config = SiglipConfig()
57
-
58
- vocab_size = 250000 if "i18n" in model_name else 32000
59
- image_size = model_name_to_image_size[model_name]
60
- patch_size = 16 if "patch16" in model_name else 14
61
-
62
- # size of the architecture
63
- config.vision_config.image_size = image_size
64
- config.vision_config.patch_size = patch_size
65
- config.text_config.vocab_size = vocab_size
66
-
67
- if "base" in model_name:
68
- pass
69
- elif "large" in model_name:
70
- config.text_config.hidden_size = 1024
71
- config.text_config.intermediate_size = 4096
72
- config.text_config.num_hidden_layers = 24
73
- config.text_config.num_attention_heads = 16
74
- config.vision_config.hidden_size = 1024
75
- config.vision_config.intermediate_size = 4096
76
- config.vision_config.num_hidden_layers = 24
77
- config.vision_config.num_attention_heads = 16
78
- elif "so400m" in model_name:
79
- config.text_config.hidden_size = 1152
80
- config.text_config.intermediate_size = 4304
81
- config.text_config.num_hidden_layers = 27
82
- config.text_config.num_attention_heads = 16
83
- config.vision_config.hidden_size = 1152
84
- config.vision_config.intermediate_size = 4304
85
- config.vision_config.num_hidden_layers = 27
86
- config.vision_config.num_attention_heads = 16
87
- else:
88
- raise ValueError("Model not supported")
89
-
90
- return config
91
-
92
-
93
- def create_rename_keys(config):
94
- rename_keys = []
95
- # fmt: off
96
-
97
- # vision encoder
98
-
99
- rename_keys.append(("params/img/embedding/kernel", "vision_model.embeddings.patch_embedding.weight"))
100
- rename_keys.append(("params/img/embedding/bias", "vision_model.embeddings.patch_embedding.bias"))
101
- rename_keys.append(("params/img/pos_embedding", "vision_model.embeddings.position_embedding.weight"))
102
-
103
- for i in range(config.vision_config.num_hidden_layers):
104
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/scale", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
105
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
106
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/scale", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
107
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
108
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
109
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
110
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
111
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
112
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"vision_model.encoder.layers.{i}.self_attn.k_proj.weight"))
113
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"vision_model.encoder.layers.{i}.self_attn.k_proj.bias"))
114
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"vision_model.encoder.layers.{i}.self_attn.v_proj.weight"))
115
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"vision_model.encoder.layers.{i}.self_attn.v_proj.bias"))
116
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"vision_model.encoder.layers.{i}.self_attn.q_proj.weight"))
117
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"vision_model.encoder.layers.{i}.self_attn.q_proj.bias"))
118
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"vision_model.encoder.layers.{i}.self_attn.out_proj.weight"))
119
- rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"vision_model.encoder.layers.{i}.self_attn.out_proj.bias"))
120
-
121
- rename_keys.append(("params/img/Transformer/encoder_norm/scale", "vision_model.post_layernorm.weight"))
122
- rename_keys.append(("params/img/Transformer/encoder_norm/bias", "vision_model.post_layernorm.bias"))
123
-
124
- rename_keys.append(("params/img/MAPHead_0/probe", "vision_model.head.probe"))
125
- rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.head.layernorm.weight"))
126
- rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.head.layernorm.bias"))
127
- rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.head.mlp.fc1.weight"))
128
- rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.head.mlp.fc1.bias"))
129
- rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.head.mlp.fc2.weight"))
130
- rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.head.mlp.fc2.bias"))
131
- rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.head.attention.out_proj.weight"))
132
- rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.head.attention.out_proj.bias"))
133
-
134
- # text encoder
135
-
136
- rename_keys.append(("params/txt/Embed_0/embedding", "text_model.embeddings.token_embedding.weight"))
137
- rename_keys.append(("params/txt/pos_embedding", "text_model.embeddings.position_embedding.weight"))
138
-
139
- for i in range(config.text_config.num_hidden_layers):
140
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/scale", f"text_model.encoder.layers.{i}.layer_norm1.weight"))
141
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/bias", f"text_model.encoder.layers.{i}.layer_norm1.bias"))
142
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/scale", f"text_model.encoder.layers.{i}.layer_norm2.weight"))
143
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/bias", f"text_model.encoder.layers.{i}.layer_norm2.bias"))
144
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"text_model.encoder.layers.{i}.mlp.fc1.weight"))
145
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"text_model.encoder.layers.{i}.mlp.fc1.bias"))
146
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"text_model.encoder.layers.{i}.mlp.fc2.weight"))
147
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"text_model.encoder.layers.{i}.mlp.fc2.bias"))
148
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"text_model.encoder.layers.{i}.self_attn.k_proj.weight"))
149
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"text_model.encoder.layers.{i}.self_attn.k_proj.bias"))
150
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"text_model.encoder.layers.{i}.self_attn.v_proj.weight"))
151
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"text_model.encoder.layers.{i}.self_attn.v_proj.bias"))
152
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"text_model.encoder.layers.{i}.self_attn.q_proj.weight"))
153
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"text_model.encoder.layers.{i}.self_attn.q_proj.bias"))
154
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"text_model.encoder.layers.{i}.self_attn.out_proj.weight"))
155
- rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"text_model.encoder.layers.{i}.self_attn.out_proj.bias"))
156
-
157
- rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.final_layer_norm.weight"))
158
- rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.final_layer_norm.bias"))
159
- rename_keys.append(("params/txt/head/kernel", "text_model.head.weight"))
160
- rename_keys.append(("params/txt/head/bias", "text_model.head.bias"))
161
-
162
- # learned temperature and bias
163
- rename_keys.append(("params/t", "logit_scale"))
164
- rename_keys.append(("params/b", "logit_bias"))
165
-
166
- # fmt: on
167
- return rename_keys
168
-
169
-
170
- def rename_key(dct, old, new, config):
171
- val = dct.pop(old)
172
-
173
- if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new:
174
- val = val.reshape(-1, config.vision_config.hidden_size)
175
- if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new:
176
- val = val.reshape(-1, config.text_config.hidden_size)
177
-
178
- if "patch_embedding.weight" in new:
179
- val = val.transpose(3, 2, 0, 1)
180
- elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new:
181
- val = val.T
182
-
183
- if "position_embedding" in new and "vision" in new:
184
- val = val.reshape(-1, config.vision_config.hidden_size)
185
- if "position_embedding" in new and "text" in new:
186
- val = val.reshape(-1, config.text_config.hidden_size)
187
-
188
- if new.endswith("bias"):
189
- val = val.reshape(-1)
190
-
191
- dct[new] = torch.from_numpy(val)
192
-
193
-
194
- def read_in_q_k_v_head(state_dict, config):
195
- # read in individual input projection layers
196
- key_proj_weight = (
197
- state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/kernel")
198
- .reshape(-1, config.vision_config.hidden_size)
199
- .T
200
- )
201
- key_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/bias").reshape(-1)
202
- value_proj_weight = (
203
- state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/kernel")
204
- .reshape(-1, config.vision_config.hidden_size)
205
- .T
206
- )
207
- value_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/bias").reshape(-1)
208
- query_proj_weight = (
209
- state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/kernel")
210
- .reshape(-1, config.vision_config.hidden_size)
211
- .T
212
- )
213
- query_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/bias").reshape(-1)
214
-
215
- # next, add them to the state dict as a single matrix + vector
216
- state_dict["vision_model.head.attention.in_proj_weight"] = torch.from_numpy(
217
- np.concatenate([query_proj_weight, key_proj_weight, value_proj_weight], axis=0)
218
- )
219
- state_dict["vision_model.head.attention.in_proj_bias"] = torch.from_numpy(
220
- np.concatenate([query_proj_bias, key_proj_bias, value_proj_bias], axis=0)
221
- )
222
-
223
-
224
- # We will verify our results on an image of cute cats
225
- def prepare_img():
226
- url = "http://images.cocodataset.org/val2017/000000039769.jpg"
227
- image = Image.open(requests.get(url, stream=True).raw)
228
- return image
229
-
230
-
231
- def flatten_nested_dict(params, parent_key="", sep="/"):
232
- items = []
233
-
234
- for k, v in params.items():
235
- new_key = parent_key + sep + k if parent_key else k
236
-
237
- if isinstance(v, collections.abc.MutableMapping):
238
- items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
239
- else:
240
- items.append((new_key, v))
241
- return dict(items)
242
-
243
-
244
- @torch.no_grad()
245
- def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False):
246
- """
247
- Copy/paste/tweak model's weights to our SigLIP structure.
248
- """
249
-
250
- # define default SigLIP configuration
251
- config = get_siglip_config(model_name)
252
-
253
- # get checkpoint
254
- checkpoint = model_name_to_checkpoint[model_name]
255
-
256
- # get vocab file
257
- if "i18n" in model_name:
258
- vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model"
259
- else:
260
- vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model"
261
-
262
- # load original state dict
263
- data = load(checkpoint)
264
- state_dict = flatten_nested_dict(data)
265
-
266
- # remove and rename some keys
267
- rename_keys = create_rename_keys(config)
268
- for src, dest in rename_keys:
269
- rename_key(state_dict, src, dest, config)
270
-
271
- # qkv matrices of attention pooling head need special treatment
272
- read_in_q_k_v_head(state_dict, config)
273
-
274
- # load HuggingFace model
275
- model = SiglipModel(config).eval()
276
- model.load_state_dict(state_dict)
277
-
278
- # create processor
279
- # important: make tokenizer not return attention_mask since original one doesn't require it
280
- image_size = config.vision_config.image_size
281
- size = {"height": image_size, "width": image_size}
282
- image_processor = SiglipImageProcessor(size=size)
283
- tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"])
284
- processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)
285
-
286
- # verify on dummy images and texts
287
- url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
288
- image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB")
289
- url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
290
- image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB")
291
- texts = ["an apple", "a picture of an apple"]
292
-
293
- inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt", padding="max_length")
294
-
295
- # verify input_ids against original ones
296
- if image_size == 224:
297
- filename = "siglip_pixel_values.pt"
298
- elif image_size == 256:
299
- filename = "siglip_pixel_values_256.pt"
300
- elif image_size == 384:
301
- filename = "siglip_pixel_values_384.pt"
302
- elif image_size == 512:
303
- filename = "siglip_pixel_values_512.pt"
304
- else:
305
- raise ValueError("Image size not supported")
306
-
307
- filepath = hf_hub_download(repo_id="nielsr/test-image", filename=filename, repo_type="dataset")
308
- original_pixel_values = torch.load(filepath)
309
- filepath = hf_hub_download(repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset")
310
- original_input_ids = torch.load(filepath)
311
-
312
- if "i18n" not in model_name:
313
- assert inputs.input_ids.tolist() == original_input_ids.tolist()
314
-
315
- print("Mean of original pixel values:", original_pixel_values.mean())
316
- print("Mean of new pixel values:", inputs.pixel_values.mean())
317
-
318
- # note: we're testing with original pixel values here since we don't have exact pixel values
319
- with torch.no_grad():
320
- outputs = model(input_ids=inputs.input_ids, pixel_values=original_pixel_values)
321
-
322
- # with torch.no_grad():
323
- # outputs = model(input_ids=inputs.input_ids, pixel_values=inputs.pixel_values)
324
-
325
- print(outputs.logits_per_image[:3, :3])
326
-
327
- probs = torch.sigmoid(outputs.logits_per_image) # these are the probabilities
328
- print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
329
- print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'")
330
-
331
- if verify_logits:
332
- if model_name == "siglip-base-patch16-224":
333
- expected_slice = torch.tensor(
334
- [[-2.9621, -2.1672], [-0.2713, 0.2910]],
335
- )
336
- elif model_name == "siglip-base-patch16-256":
337
- expected_slice = torch.tensor(
338
- [[-3.1146, -1.9894], [-0.7312, 0.6387]],
339
- )
340
- elif model_name == "siglip-base-patch16-384":
341
- expected_slice = torch.tensor(
342
- [[-2.8098, -2.1891], [-0.4242, 0.4102]],
343
- )
344
- elif model_name == "siglip-base-patch16-512":
345
- expected_slice = torch.tensor(
346
- [[-2.7899, -2.2668], [-0.4295, -0.0735]],
347
- )
348
- elif model_name == "siglip-large-patch16-256":
349
- expected_slice = torch.tensor(
350
- [[-1.5827, -0.5801], [-0.9153, 0.1363]],
351
- )
352
- elif model_name == "siglip-large-patch16-384":
353
- expected_slice = torch.tensor(
354
- [[-2.1523, -0.2899], [-0.2959, 0.7884]],
355
- )
356
- elif model_name == "siglip-so400m-patch14-384":
357
- expected_slice = torch.tensor([[-1.2441, -0.6649], [-0.7060, 0.7374]])
358
- elif model_name == "siglip-base-patch16-256-i18n":
359
- expected_slice = torch.tensor(
360
- [[-0.9064, 0.1073], [-0.0299, 0.5304]],
361
- )
362
-
363
- assert torch.allclose(outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4)
364
- print("Looks ok!")
365
-
366
- if pytorch_dump_folder_path is not None:
367
- Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
368
- print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
369
- model.save_pretrained(pytorch_dump_folder_path)
370
- print(f"Saving processor to {pytorch_dump_folder_path}")
371
- processor.save_pretrained(pytorch_dump_folder_path)
372
-
373
- if push_to_hub:
374
- model.push_to_hub(f"nielsr/{model_name}")
375
- processor.push_to_hub(f"nielsr/{model_name}")
376
-
377
-
378
- if __name__ == "__main__":
379
- parser = argparse.ArgumentParser()
380
- # Required parameters
381
- parser.add_argument(
382
- "--model_name",
383
- default="siglip-base-patch16-224",
384
- type=str,
385
- choices=model_name_to_checkpoint.keys(),
386
- help="Name of the model you'd like to convert.",
387
- )
388
- parser.add_argument(
389
- "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
390
- )
391
- parser.add_argument(
392
- "--verify_logits",
393
- action="store_false",
394
- help="Whether to verify logits against the original implementation.",
395
- )
396
- parser.add_argument(
397
- "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
398
- )
399
-
400
- args = parser.parse_args()
401
- convert_siglip_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub)
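Before the re-added copy of the script below, a note on driving it: `convert_siglip_checkpoint` can also be called directly instead of going through `argparse`. A hedged sketch, assuming the file is importable as `convert_siglip_to_hf` and that the hard-coded `.npz` checkpoint and sentencepiece vocab paths in `model_name_to_checkpoint` actually exist on the local machine:

```python
# Hypothetical programmatic invocation of the converter defined above.
from convert_siglip_to_hf import convert_siglip_checkpoint  # assumes the script is on PYTHONPATH

convert_siglip_checkpoint(
    model_name="siglip-base-patch16-224",                  # key into model_name_to_checkpoint
    pytorch_dump_folder_path="./siglip-base-patch16-224",  # where save_pretrained() writes
    verify_logits=True,                                     # compare against the expected logit slices
    push_to_hub=False,                                      # keep the dry run local
)
```

The CLI wrapper does the same thing; note that `--verify_logits` is declared with `action="store_false"`, so verification is on by default and passing the flag actually disables it.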
 
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Convert SigLIP checkpoints from the original repository.
5
+
6
+ URL: https://github.com/google-research/big_vision/tree/main
7
+ """
8
+
9
+ import argparse
10
+ import collections
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import requests
15
+ import torch
16
+ from huggingface_hub import hf_hub_download
17
+ from numpy import load
18
+ from PIL import Image
19
+
20
+ from transformers import SiglipConfig, SiglipImageProcessor, SiglipModel, SiglipProcessor, SiglipTokenizer
21
+ from transformers.utils import logging
22
+
23
+
24
+ logging.set_verbosity_info()
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ model_name_to_checkpoint = {
29
+ # base checkpoints
30
+ "siglip-base-patch16-224": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_224_63724782.npz",
31
+ "siglip-base-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_256_60500360.npz",
32
+ "siglip-base-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_384_68578854.npz",
33
+ "siglip-base-patch16-512": "/Users/nielsrogge/Documents/SigLIP/webli_en_b16_512_68580893.npz",
34
+ # large checkpoints
35
+ "siglip-large-patch16-256": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_256_60552751.npz",
36
+ "siglip-large-patch16-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_l16_384_63634585.npz",
37
+ # multilingual checkpoint
38
+ "siglip-base-patch16-256-i18n": "/Users/nielsrogge/Documents/SigLIP/webli_i18n_b16_256_66117334.npz",
39
+ # so400m checkpoints
40
+ "siglip-so400m-patch14-384": "/Users/nielsrogge/Documents/SigLIP/webli_en_so400m_384_58765454.npz",
41
+ }
42
+
43
+ model_name_to_image_size = {
44
+ "siglip-base-patch16-224": 224,
45
+ "siglip-base-patch16-256": 256,
46
+ "siglip-base-patch16-384": 384,
47
+ "siglip-base-patch16-512": 512,
48
+ "siglip-large-patch16-256": 256,
49
+ "siglip-large-patch16-384": 384,
50
+ "siglip-base-patch16-256-i18n": 256,
51
+ "siglip-so400m-patch14-384": 384,
52
+ }
53
+
54
+
55
+ def get_siglip_config(model_name):
56
+ config = SiglipConfig()
57
+
58
+ vocab_size = 250000 if "i18n" in model_name else 32000
59
+ image_size = model_name_to_image_size[model_name]
60
+ patch_size = 16 if "patch16" in model_name else 14
61
+
62
+ # size of the architecture
63
+ config.vision_config.image_size = image_size
64
+ config.vision_config.patch_size = patch_size
65
+ config.text_config.vocab_size = vocab_size
66
+
67
+ if "base" in model_name:
68
+ pass
69
+ elif "large" in model_name:
70
+ config.text_config.hidden_size = 1024
71
+ config.text_config.intermediate_size = 4096
72
+ config.text_config.num_hidden_layers = 24
73
+ config.text_config.num_attention_heads = 16
74
+ config.vision_config.hidden_size = 1024
75
+ config.vision_config.intermediate_size = 4096
76
+ config.vision_config.num_hidden_layers = 24
77
+ config.vision_config.num_attention_heads = 16
78
+ elif "so400m" in model_name:
79
+ config.text_config.hidden_size = 1152
80
+ config.text_config.intermediate_size = 4304
81
+ config.text_config.num_hidden_layers = 27
82
+ config.text_config.num_attention_heads = 16
83
+ config.vision_config.hidden_size = 1152
84
+ config.vision_config.intermediate_size = 4304
85
+ config.vision_config.num_hidden_layers = 27
86
+ config.vision_config.num_attention_heads = 16
87
+ else:
88
+ raise ValueError("Model not supported")
89
+
90
+ return config
91
+
92
+
93
+ def create_rename_keys(config):
94
+ rename_keys = []
95
+ # fmt: off
96
+
97
+ # vision encoder
98
+
99
+ rename_keys.append(("params/img/embedding/kernel", "vision_model.embeddings.patch_embedding.weight"))
100
+ rename_keys.append(("params/img/embedding/bias", "vision_model.embeddings.patch_embedding.bias"))
101
+ rename_keys.append(("params/img/pos_embedding", "vision_model.embeddings.position_embedding.weight"))
102
+
103
+ for i in range(config.vision_config.num_hidden_layers):
104
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/scale", f"vision_model.encoder.layers.{i}.layer_norm1.weight"))
105
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_0/bias", f"vision_model.encoder.layers.{i}.layer_norm1.bias"))
106
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/scale", f"vision_model.encoder.layers.{i}.layer_norm2.weight"))
107
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/LayerNorm_1/bias", f"vision_model.encoder.layers.{i}.layer_norm2.bias"))
108
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"vision_model.encoder.layers.{i}.mlp.fc1.weight"))
109
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"vision_model.encoder.layers.{i}.mlp.fc1.bias"))
110
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"vision_model.encoder.layers.{i}.mlp.fc2.weight"))
111
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"vision_model.encoder.layers.{i}.mlp.fc2.bias"))
112
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"vision_model.encoder.layers.{i}.self_attn.k_proj.weight"))
113
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"vision_model.encoder.layers.{i}.self_attn.k_proj.bias"))
114
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"vision_model.encoder.layers.{i}.self_attn.v_proj.weight"))
115
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"vision_model.encoder.layers.{i}.self_attn.v_proj.bias"))
116
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"vision_model.encoder.layers.{i}.self_attn.q_proj.weight"))
117
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"vision_model.encoder.layers.{i}.self_attn.q_proj.bias"))
118
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"vision_model.encoder.layers.{i}.self_attn.out_proj.weight"))
119
+ rename_keys.append((f"params/img/Transformer/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"vision_model.encoder.layers.{i}.self_attn.out_proj.bias"))
120
+
121
+ rename_keys.append(("params/img/Transformer/encoder_norm/scale", "vision_model.post_layernorm.weight"))
122
+ rename_keys.append(("params/img/Transformer/encoder_norm/bias", "vision_model.post_layernorm.bias"))
123
+
124
+ rename_keys.append(("params/img/MAPHead_0/probe", "vision_model.head.probe"))
125
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/scale", "vision_model.head.layernorm.weight"))
126
+ rename_keys.append(("params/img/MAPHead_0/LayerNorm_0/bias", "vision_model.head.layernorm.bias"))
127
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/kernel", "vision_model.head.mlp.fc1.weight"))
128
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_0/bias", "vision_model.head.mlp.fc1.bias"))
129
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/kernel", "vision_model.head.mlp.fc2.weight"))
130
+ rename_keys.append(("params/img/MAPHead_0/MlpBlock_0/Dense_1/bias", "vision_model.head.mlp.fc2.bias"))
131
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/kernel", "vision_model.head.attention.out_proj.weight"))
132
+ rename_keys.append(("params/img/MAPHead_0/MultiHeadDotProductAttention_0/out/bias", "vision_model.head.attention.out_proj.bias"))
133
+
134
+ # text encoder
135
+
136
+ rename_keys.append(("params/txt/Embed_0/embedding", "text_model.embeddings.token_embedding.weight"))
137
+ rename_keys.append(("params/txt/pos_embedding", "text_model.embeddings.position_embedding.weight"))
138
+
139
+ for i in range(config.text_config.num_hidden_layers):
140
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/scale", f"text_model.encoder.layers.{i}.layer_norm1.weight"))
141
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_0/bias", f"text_model.encoder.layers.{i}.layer_norm1.bias"))
142
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/scale", f"text_model.encoder.layers.{i}.layer_norm2.weight"))
143
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/LayerNorm_1/bias", f"text_model.encoder.layers.{i}.layer_norm2.bias"))
144
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/kernel", f"text_model.encoder.layers.{i}.mlp.fc1.weight"))
145
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_0/bias", f"text_model.encoder.layers.{i}.mlp.fc1.bias"))
146
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/kernel", f"text_model.encoder.layers.{i}.mlp.fc2.weight"))
147
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MlpBlock_0/Dense_1/bias", f"text_model.encoder.layers.{i}.mlp.fc2.bias"))
148
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/kernel", f"text_model.encoder.layers.{i}.self_attn.k_proj.weight"))
149
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/key/bias", f"text_model.encoder.layers.{i}.self_attn.k_proj.bias"))
150
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/kernel", f"text_model.encoder.layers.{i}.self_attn.v_proj.weight"))
151
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/value/bias", f"text_model.encoder.layers.{i}.self_attn.v_proj.bias"))
152
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/kernel", f"text_model.encoder.layers.{i}.self_attn.q_proj.weight"))
153
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/query/bias", f"text_model.encoder.layers.{i}.self_attn.q_proj.bias"))
154
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/kernel", f"text_model.encoder.layers.{i}.self_attn.out_proj.weight"))
155
+ rename_keys.append((f"params/txt/Encoder_0/encoderblock_{i}/MultiHeadDotProductAttention_0/out/bias", f"text_model.encoder.layers.{i}.self_attn.out_proj.bias"))
156
+
157
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/scale", "text_model.final_layer_norm.weight"))
158
+ rename_keys.append(("params/txt/Encoder_0/encoder_norm/bias", "text_model.final_layer_norm.bias"))
159
+ rename_keys.append(("params/txt/head/kernel", "text_model.head.weight"))
160
+ rename_keys.append(("params/txt/head/bias", "text_model.head.bias"))
161
+
162
+ # learned temperature and bias
163
+ rename_keys.append(("params/t", "logit_scale"))
164
+ rename_keys.append(("params/b", "logit_bias"))
165
+
166
+ # fmt: on
167
+ return rename_keys
168
+
169
+
170
+ def rename_key(dct, old, new, config):
171
+ val = dct.pop(old)
172
+
173
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "vision" in new:
174
+ val = val.reshape(-1, config.vision_config.hidden_size)
175
+ if ("out_proj" in new or "v_proj" in new or "k_proj" in new or "q_proj" in new) and "text" in new:
176
+ val = val.reshape(-1, config.text_config.hidden_size)
177
+
178
+ if "patch_embedding.weight" in new:
179
+ val = val.transpose(3, 2, 0, 1)
180
+ elif new.endswith("weight") and "position_embedding" not in new and "token_embedding" not in new:
181
+ val = val.T
182
+
183
+ if "position_embedding" in new and "vision" in new:
184
+ val = val.reshape(-1, config.vision_config.hidden_size)
185
+ if "position_embedding" in new and "text" in new:
186
+ val = val.reshape(-1, config.text_config.hidden_size)
187
+
188
+ if new.endswith("bias"):
189
+ val = val.reshape(-1)
190
+
191
+ dct[new] = torch.from_numpy(val)
192
+
193
+
194
+ def read_in_q_k_v_head(state_dict, config):
195
+ # read in individual input projection layers
196
+ key_proj_weight = (
197
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/kernel")
198
+ .reshape(-1, config.vision_config.hidden_size)
199
+ .T
200
+ )
201
+ key_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/key/bias").reshape(-1)
202
+ value_proj_weight = (
203
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/kernel")
204
+ .reshape(-1, config.vision_config.hidden_size)
205
+ .T
206
+ )
207
+ value_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/value/bias").reshape(-1)
208
+ query_proj_weight = (
209
+ state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/kernel")
210
+ .reshape(-1, config.vision_config.hidden_size)
211
+ .T
212
+ )
213
+ query_proj_bias = state_dict.pop("params/img/MAPHead_0/MultiHeadDotProductAttention_0/query/bias").reshape(-1)
214
+
215
+ # next, add them to the state dict as a single matrix + vector
216
+ state_dict["vision_model.head.attention.in_proj_weight"] = torch.from_numpy(
217
+ np.concatenate([query_proj_weight, key_proj_weight, value_proj_weight], axis=0)
218
+ )
219
+ state_dict["vision_model.head.attention.in_proj_bias"] = torch.from_numpy(
220
+ np.concatenate([query_proj_bias, key_proj_bias, value_proj_bias], axis=0)
221
+ )
222
+
223
+
224
+ # We will verify our results on an image of cute cats
225
+ def prepare_img():
226
+ url = "http://images.cocodataset.org/val2017/000000039769.jpg"
227
+ image = Image.open(requests.get(url, stream=True).raw)
228
+ return image
229
+
230
+
231
+ def flatten_nested_dict(params, parent_key="", sep="/"):
232
+ items = []
233
+
234
+ for k, v in params.items():
235
+ new_key = parent_key + sep + k if parent_key else k
236
+
237
+ if isinstance(v, collections.abc.MutableMapping):
238
+ items.extend(flatten_nested_dict(v, new_key, sep=sep).items())
239
+ else:
240
+ items.append((new_key, v))
241
+ return dict(items)
242
+
243
+
244
+ @torch.no_grad()
245
+ def convert_siglip_checkpoint(model_name, pytorch_dump_folder_path, verify_logits=True, push_to_hub=False):
246
+ """
247
+ Copy/paste/tweak model's weights to our SigLIP structure.
248
+ """
249
+
250
+ # define default SigLIP configuration
251
+ config = get_siglip_config(model_name)
252
+
253
+ # get checkpoint
254
+ checkpoint = model_name_to_checkpoint[model_name]
255
+
256
+ # get vocab file
257
+ if "i18n" in model_name:
258
+ vocab_file = "/Users/nielsrogge/Documents/SigLIP/multilingual_vocab/sentencepiece.model"
259
+ else:
260
+ vocab_file = "/Users/nielsrogge/Documents/SigLIP/english_vocab/sentencepiece.model"
261
+
262
+ # load original state dict
263
+ data = load(checkpoint)
264
+ state_dict = flatten_nested_dict(data)
265
+
266
+ # remove and rename some keys
267
+ rename_keys = create_rename_keys(config)
268
+ for src, dest in rename_keys:
269
+ rename_key(state_dict, src, dest, config)
270
+
271
+ # qkv matrices of attention pooling head need special treatment
272
+ read_in_q_k_v_head(state_dict, config)
273
+
274
+ # load HuggingFace model
275
+ model = SiglipModel(config).eval()
276
+ model.load_state_dict(state_dict)
277
+
278
+ # create processor
279
+ # important: make tokenizer not return attention_mask since original one doesn't require it
280
+ image_size = config.vision_config.image_size
281
+ size = {"height": image_size, "width": image_size}
282
+ image_processor = SiglipImageProcessor(size=size)
283
+ tokenizer = SiglipTokenizer(vocab_file=vocab_file, model_input_names=["input_ids"])
284
+ processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)
285
+
286
+ # verify on dummy images and texts
287
+ url_1 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-ipod.jpg"
288
+ image_1 = Image.open(requests.get(url_1, stream=True).raw).convert("RGB")
289
+ url_2 = "https://cdn.openai.com/multimodal-neurons/assets/apple/apple-blank.jpg"
290
+ image_2 = Image.open(requests.get(url_2, stream=True).raw).convert("RGB")
291
+ texts = ["an apple", "a picture of an apple"]
292
+
293
+ inputs = processor(images=[image_1, image_2], text=texts, return_tensors="pt", padding="max_length")
294
+
295
+ # verify input_ids against original ones
296
+ if image_size == 224:
297
+ filename = "siglip_pixel_values.pt"
298
+ elif image_size == 256:
299
+ filename = "siglip_pixel_values_256.pt"
300
+ elif image_size == 384:
301
+ filename = "siglip_pixel_values_384.pt"
302
+ elif image_size == 512:
303
+ filename = "siglip_pixel_values_512.pt"
304
+ else:
305
+ raise ValueError("Image size not supported")
306
+
307
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename=filename, repo_type="dataset")
308
+ original_pixel_values = torch.load(filepath)
309
+ filepath = hf_hub_download(repo_id="nielsr/test-image", filename="siglip_input_ids.pt", repo_type="dataset")
310
+ original_input_ids = torch.load(filepath)
311
+
312
+ if "i18n" not in model_name:
313
+ assert inputs.input_ids.tolist() == original_input_ids.tolist()
314
+
315
+ print("Mean of original pixel values:", original_pixel_values.mean())
316
+ print("Mean of new pixel values:", inputs.pixel_values.mean())
317
+
318
+ # note: we're testing with original pixel values here since we don't have exact pixel values
319
+ with torch.no_grad():
320
+ outputs = model(input_ids=inputs.input_ids, pixel_values=original_pixel_values)
321
+
322
+ # with torch.no_grad():
323
+ # outputs = model(input_ids=inputs.input_ids, pixel_values=inputs.pixel_values)
324
+
325
+ print(outputs.logits_per_image[:3, :3])
326
+
327
+ probs = torch.sigmoid(outputs.logits_per_image) # these are the probabilities
328
+ print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
329
+ print(f"{probs[0][1]:.1%} that image 0 is '{texts[1]}'")
330
+
331
+ if verify_logits:
332
+ if model_name == "siglip-base-patch16-224":
333
+ expected_slice = torch.tensor(
334
+ [[-2.9621, -2.1672], [-0.2713, 0.2910]],
335
+ )
336
+ elif model_name == "siglip-base-patch16-256":
337
+ expected_slice = torch.tensor(
338
+ [[-3.1146, -1.9894], [-0.7312, 0.6387]],
339
+ )
340
+ elif model_name == "siglip-base-patch16-384":
341
+ expected_slice = torch.tensor(
342
+ [[-2.8098, -2.1891], [-0.4242, 0.4102]],
343
+ )
344
+ elif model_name == "siglip-base-patch16-512":
345
+ expected_slice = torch.tensor(
346
+ [[-2.7899, -2.2668], [-0.4295, -0.0735]],
347
+ )
348
+ elif model_name == "siglip-large-patch16-256":
349
+ expected_slice = torch.tensor(
350
+ [[-1.5827, -0.5801], [-0.9153, 0.1363]],
351
+ )
352
+ elif model_name == "siglip-large-patch16-384":
353
+ expected_slice = torch.tensor(
354
+ [[-2.1523, -0.2899], [-0.2959, 0.7884]],
355
+ )
356
+ elif model_name == "siglip-so400m-patch14-384":
357
+ expected_slice = torch.tensor([[-1.2441, -0.6649], [-0.7060, 0.7374]])
358
+ elif model_name == "siglip-base-patch16-256-i18n":
359
+ expected_slice = torch.tensor(
360
+ [[-0.9064, 0.1073], [-0.0299, 0.5304]],
361
+ )
362
+
363
+ assert torch.allclose(outputs.logits_per_image[:3, :3], expected_slice, atol=1e-4)
364
+ print("Looks ok!")
365
+
366
+ if pytorch_dump_folder_path is not None:
367
+ Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
368
+ print(f"Saving model {model_name} to {pytorch_dump_folder_path}")
369
+ model.save_pretrained(pytorch_dump_folder_path)
370
+ print(f"Saving processor to {pytorch_dump_folder_path}")
371
+ processor.save_pretrained(pytorch_dump_folder_path)
372
+
373
+ if push_to_hub:
374
+ model.push_to_hub(f"nielsr/{model_name}")
375
+ processor.push_to_hub(f"nielsr/{model_name}")
376
+
377
+
378
+ if __name__ == "__main__":
379
+ parser = argparse.ArgumentParser()
380
+ # Required parameters
381
+ parser.add_argument(
382
+ "--model_name",
383
+ default="siglip-base-patch16-224",
384
+ type=str,
385
+ choices=model_name_to_checkpoint.keys(),
386
+ help="Name of the model you'd like to convert.",
387
+ )
388
+ parser.add_argument(
389
+ "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
390
+ )
391
+ parser.add_argument(
392
+ "--verify_logits",
393
+ action="store_false",
394
+ help="Whether to verify logits against the original implementation.",
395
+ )
396
+ parser.add_argument(
397
+ "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
398
+ )
399
+
400
+ args = parser.parse_args()
401
+ convert_siglip_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.verify_logits, args.push_to_hub)
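For context, a minimal smoke test of a converted checkpoint might look like the sketch below. Everything outside the conversion code itself is an assumption: the script filename (convert_siglip_to_hf.py), the output folder (./siglip-hf), the COCO test image, and the use of the transformers SiglipModel/SiglipProcessor classes (the local modeling/siglip copies mirror them).

# Assumed invocation of the script above (filename and paths are placeholders):
#   python convert_siglip_to_hf.py --model_name siglip-base-patch16-224 \
#       --pytorch_dump_folder_path ./siglip-hf
import requests
import torch
from PIL import Image
from transformers import SiglipModel, SiglipProcessor

model = SiglipModel.from_pretrained("./siglip-hf").eval()
processor = SiglipProcessor.from_pretrained("./siglip-hf")

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
texts = ["a photo of two cats", "a photo of a dog"]
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# SigLIP scores each image-text pair independently, so probabilities come from
# a sigmoid over the logits rather than a softmax across the candidate texts.
probs = torch.sigmoid(outputs.logits_per_image)
print(probs)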
modeling/siglip/image_processing_siglip.py CHANGED
@@ -1,230 +1,230 @@
1
- # Copyright 2024 The HuggingFace Inc. team.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """Image processor class for SigLIP."""
5
-
6
- from typing import Dict, List, Optional, Union
7
-
8
- from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
9
- from transformers.image_transforms import (
10
- convert_to_rgb,
11
- resize,
12
- to_channel_dimension_format,
13
- )
14
- from transformers.image_utils import (
15
- IMAGENET_STANDARD_MEAN,
16
- IMAGENET_STANDARD_STD,
17
- ChannelDimension,
18
- ImageInput,
19
- PILImageResampling,
20
- infer_channel_dimension_format,
21
- is_scaled_image,
22
- make_list_of_images,
23
- to_numpy_array,
24
- valid_images,
25
- validate_preprocess_arguments,
26
- )
27
- from transformers.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
28
-
29
-
30
- logger = logging.get_logger(__name__)
31
-
32
-
33
- if is_vision_available():
34
- import PIL
35
-
36
-
37
- class SiglipImageProcessor(BaseImageProcessor):
38
- r"""
39
- Constructs a SigLIP image processor.
40
-
41
- Args:
42
- do_resize (`bool`, *optional*, defaults to `True`):
43
- Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
44
- `do_resize` in the `preprocess` method.
45
- size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
46
- Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
47
- resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
48
- Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
49
- do_rescale (`bool`, *optional*, defaults to `True`):
50
- Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
51
- the `preprocess` method.
52
- rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
53
- Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
54
- method.
55
- do_normalize (`bool`, *optional*, defaults to `True`):
56
- Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
57
- `do_normalize` in the `preprocess` method.
58
- image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
59
- Mean to use if normalizing the image. This is a float or list of floats the length of the number of
60
- channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
61
- image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
62
- Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
63
- number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
65
- do_convert_rgb (`bool`, *optional*, defaults to `True`):
66
- Whether to convert the image to RGB.
67
- """
68
-
69
- model_input_names = ["pixel_values"]
70
-
71
- def __init__(
72
- self,
73
- do_resize: bool = True,
74
- size: Dict[str, int] = None,
75
- resample: PILImageResampling = PILImageResampling.BICUBIC,
76
- do_rescale: bool = True,
77
- rescale_factor: Union[int, float] = 1 / 255,
78
- do_normalize: bool = True,
79
- image_mean: Optional[Union[float, List[float]]] = None,
80
- image_std: Optional[Union[float, List[float]]] = None,
81
- do_convert_rgb: bool = None,
82
- **kwargs,
83
- ) -> None:
84
- super().__init__(**kwargs)
85
- size = size if size is not None else {"height": 224, "width": 224}
86
- image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
87
- image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
88
-
89
- self.do_resize = do_resize
90
- self.size = size
91
- self.resample = resample
92
- self.do_rescale = do_rescale
93
- self.rescale_factor = rescale_factor
94
- self.do_normalize = do_normalize
95
- self.image_mean = image_mean
96
- self.image_std = image_std
97
- self.do_convert_rgb = do_convert_rgb
98
-
99
- @filter_out_non_signature_kwargs()
100
- def preprocess(
101
- self,
102
- images: ImageInput,
103
- do_resize: bool = None,
104
- size: Dict[str, int] = None,
105
- resample: PILImageResampling = None,
106
- do_rescale: bool = None,
107
- rescale_factor: float = None,
108
- do_normalize: bool = None,
109
- image_mean: Optional[Union[float, List[float]]] = None,
110
- image_std: Optional[Union[float, List[float]]] = None,
111
- return_tensors: Optional[Union[str, TensorType]] = None,
112
- data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
113
- input_data_format: Optional[Union[str, ChannelDimension]] = None,
114
- do_convert_rgb: bool = None,
115
- ) -> PIL.Image.Image:
116
- """
117
- Preprocess an image or batch of images.
118
-
119
- Args:
120
- images (`ImageInput`):
121
- Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
122
- passing in images with pixel values between 0 and 1, set `do_rescale=False`.
123
- do_resize (`bool`, *optional*, defaults to `self.do_resize`):
124
- Whether to resize the image.
125
- size (`Dict[str, int]`, *optional*, defaults to `self.size`):
126
- Size of the image after resizing.
127
- resample (`int`, *optional*, defaults to `self.resample`):
128
- Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
129
- has an effect if `do_resize` is set to `True`.
130
- do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
131
- Whether to rescale the image.
132
- rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
133
- Rescale factor to rescale the image by if `do_rescale` is set to `True`.
134
- do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
135
- Whether to normalize the image.
136
- image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
137
- Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
138
- image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
139
- Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
140
- `True`.
141
- return_tensors (`str` or `TensorType`, *optional*):
142
- The type of tensors to return. Can be one of:
143
- - Unset: Return a list of `np.ndarray`.
144
- - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
145
- - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
146
- - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
147
- - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
148
- data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
149
- The channel dimension format for the output image. Can be one of:
150
- - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
151
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
152
- - Unset: Use the channel dimension format of the input image.
153
- input_data_format (`ChannelDimension` or `str`, *optional*):
154
- The channel dimension format for the input image. If unset, the channel dimension format is inferred
155
- from the input image. Can be one of:
156
- - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
157
- - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
158
- - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
159
- do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
160
- Whether to convert the image to RGB.
161
- """
162
- do_resize = do_resize if do_resize is not None else self.do_resize
163
- size = size if size is not None else self.size
164
- size = get_size_dict(size, param_name="size", default_to_square=False)
165
- resample = resample if resample is not None else self.resample
166
- do_rescale = do_rescale if do_rescale is not None else self.do_rescale
167
- rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
168
- do_normalize = do_normalize if do_normalize is not None else self.do_normalize
169
- image_mean = image_mean if image_mean is not None else self.image_mean
170
- image_std = image_std if image_std is not None else self.image_std
171
- do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
172
-
173
- images = make_list_of_images(images)
174
-
175
- if not valid_images(images):
176
- raise ValueError(
177
- "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
178
- "torch.Tensor, tf.Tensor or jax.ndarray."
179
- )
180
- validate_preprocess_arguments(
181
- do_rescale=do_rescale,
182
- rescale_factor=rescale_factor,
183
- do_normalize=do_normalize,
184
- image_mean=image_mean,
185
- image_std=image_std,
186
- do_resize=do_resize,
187
- size=size,
188
- resample=resample,
189
- )
190
- # All transformations expect numpy arrays.
191
- images = [to_numpy_array(image) for image in images]
192
-
193
- if do_convert_rgb:
194
- images = [convert_to_rgb(image) for image in images]
195
-
196
- if is_scaled_image(images[0]) and do_rescale:
197
- logger.warning_once(
198
- "It looks like you are trying to rescale already rescaled images. If the input"
199
- " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
200
- )
201
-
202
- if input_data_format is None:
203
- # We assume that all images have the same channel dimension format.
204
- input_data_format = infer_channel_dimension_format(images[0])
205
-
206
- if do_resize:
207
- height, width = size["height"], size["width"]
208
- images = [
209
- resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format)
210
- for image in images
211
- ]
212
-
213
- if do_rescale:
214
- images = [
215
- self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
216
- for image in images
217
- ]
218
-
219
- if do_normalize:
220
- images = [
221
- self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
222
- for image in images
223
- ]
224
-
225
- images = [
226
- to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
227
- ]
228
-
229
- data = {"pixel_values": images}
230
- return BatchFeature(data=data, tensor_type=return_tensors)
 
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Image processor class for SigLIP."""
5
+
6
+ from typing import Dict, List, Optional, Union
7
+
8
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
9
+ from transformers.image_transforms import (
10
+ convert_to_rgb,
11
+ resize,
12
+ to_channel_dimension_format,
13
+ )
14
+ from transformers.image_utils import (
15
+ IMAGENET_STANDARD_MEAN,
16
+ IMAGENET_STANDARD_STD,
17
+ ChannelDimension,
18
+ ImageInput,
19
+ PILImageResampling,
20
+ infer_channel_dimension_format,
21
+ is_scaled_image,
22
+ make_list_of_images,
23
+ to_numpy_array,
24
+ valid_images,
25
+ validate_preprocess_arguments,
26
+ )
27
+ from transformers.utils import TensorType, filter_out_non_signature_kwargs, is_vision_available, logging
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ if is_vision_available():
34
+ import PIL
35
+
36
+
37
+ class SiglipImageProcessor(BaseImageProcessor):
38
+ r"""
39
+ Constructs a SigLIP image processor.
40
+
41
+ Args:
42
+ do_resize (`bool`, *optional*, defaults to `True`):
43
+ Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by
44
+ `do_resize` in the `preprocess` method.
45
+ size (`Dict[str, int]`, *optional*, defaults to `{"height": 224, "width": 224}`):
46
+ Size of the image after resizing. Can be overridden by `size` in the `preprocess` method.
47
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
48
+ Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
49
+ do_rescale (`bool`, *optional*, defaults to `True`):
50
+ Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
51
+ the `preprocess` method.
52
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
53
+ Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
54
+ method.
55
+ do_normalize (`bool`, *optional*, defaults to `True`):
56
+ Whether to normalize the image by the specified mean and standard deviation. Can be overridden by
57
+ `do_normalize` in the `preprocess` method.
58
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
59
+ Mean to use if normalizing the image. This is a float or list of floats the length of the number of
60
+ channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method.
61
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
62
+ Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
63
+ number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
65
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
66
+ Whether to convert the image to RGB.
67
+ """
68
+
69
+ model_input_names = ["pixel_values"]
70
+
71
+ def __init__(
72
+ self,
73
+ do_resize: bool = True,
74
+ size: Dict[str, int] = None,
75
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
76
+ do_rescale: bool = True,
77
+ rescale_factor: Union[int, float] = 1 / 255,
78
+ do_normalize: bool = True,
79
+ image_mean: Optional[Union[float, List[float]]] = None,
80
+ image_std: Optional[Union[float, List[float]]] = None,
81
+ do_convert_rgb: bool = None,
82
+ **kwargs,
83
+ ) -> None:
84
+ super().__init__(**kwargs)
85
+ size = size if size is not None else {"height": 224, "width": 224}
86
+ image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
87
+ image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
88
+
89
+ self.do_resize = do_resize
90
+ self.size = size
91
+ self.resample = resample
92
+ self.do_rescale = do_rescale
93
+ self.rescale_factor = rescale_factor
94
+ self.do_normalize = do_normalize
95
+ self.image_mean = image_mean
96
+ self.image_std = image_std
97
+ self.do_convert_rgb = do_convert_rgb
98
+
99
+ @filter_out_non_signature_kwargs()
100
+ def preprocess(
101
+ self,
102
+ images: ImageInput,
103
+ do_resize: bool = None,
104
+ size: Dict[str, int] = None,
105
+ resample: PILImageResampling = None,
106
+ do_rescale: bool = None,
107
+ rescale_factor: float = None,
108
+ do_normalize: bool = None,
109
+ image_mean: Optional[Union[float, List[float]]] = None,
110
+ image_std: Optional[Union[float, List[float]]] = None,
111
+ return_tensors: Optional[Union[str, TensorType]] = None,
112
+ data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
113
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
114
+ do_convert_rgb: bool = None,
115
+ ) -> PIL.Image.Image:
116
+ """
117
+ Preprocess an image or batch of images.
118
+
119
+ Args:
120
+ images (`ImageInput`):
121
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
122
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
123
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
124
+ Whether to resize the image.
125
+ size (`Dict[str, int]`, *optional*, defaults to `self.size`):
126
+ Size of the image after resizing.
127
+ resample (`int`, *optional*, defaults to `self.resample`):
128
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
129
+ has an effect if `do_resize` is set to `True`.
130
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
131
+ Whether to rescale the image.
132
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
133
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
134
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
135
+ Whether to normalize the image.
136
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
137
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
138
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
139
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to
140
+ `True`.
141
+ return_tensors (`str` or `TensorType`, *optional*):
142
+ The type of tensors to return. Can be one of:
143
+ - Unset: Return a list of `np.ndarray`.
144
+ - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
145
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
146
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
147
+ - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
148
+ data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
149
+ The channel dimension format for the output image. Can be one of:
150
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
151
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
152
+ - Unset: Use the channel dimension format of the input image.
153
+ input_data_format (`ChannelDimension` or `str`, *optional*):
154
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
155
+ from the input image. Can be one of:
156
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
157
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
158
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
159
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
160
+ Whether to convert the image to RGB.
161
+ """
162
+ do_resize = do_resize if do_resize is not None else self.do_resize
163
+ size = size if size is not None else self.size
164
+ size = get_size_dict(size, param_name="size", default_to_square=False)
165
+ resample = resample if resample is not None else self.resample
166
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
167
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
168
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
169
+ image_mean = image_mean if image_mean is not None else self.image_mean
170
+ image_std = image_std if image_std is not None else self.image_std
171
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
172
+
173
+ images = make_list_of_images(images)
174
+
175
+ if not valid_images(images):
176
+ raise ValueError(
177
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
178
+ "torch.Tensor, tf.Tensor or jax.ndarray."
179
+ )
180
+ validate_preprocess_arguments(
181
+ do_rescale=do_rescale,
182
+ rescale_factor=rescale_factor,
183
+ do_normalize=do_normalize,
184
+ image_mean=image_mean,
185
+ image_std=image_std,
186
+ do_resize=do_resize,
187
+ size=size,
188
+ resample=resample,
189
+ )
190
+ # All transformations expect numpy arrays.
191
+ images = [to_numpy_array(image) for image in images]
192
+
193
+ if do_convert_rgb:
194
+ images = [convert_to_rgb(image) for image in images]
195
+
196
+ if is_scaled_image(images[0]) and do_rescale:
197
+ logger.warning_once(
198
+ "It looks like you are trying to rescale already rescaled images. If the input"
199
+ " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
200
+ )
201
+
202
+ if input_data_format is None:
203
+ # We assume that all images have the same channel dimension format.
204
+ input_data_format = infer_channel_dimension_format(images[0])
205
+
206
+ if do_resize:
207
+ height, width = size["height"], size["width"]
208
+ images = [
209
+ resize(image=image, size=(height, width), resample=resample, input_data_format=input_data_format)
210
+ for image in images
211
+ ]
212
+
213
+ if do_rescale:
214
+ images = [
215
+ self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
216
+ for image in images
217
+ ]
218
+
219
+ if do_normalize:
220
+ images = [
221
+ self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
222
+ for image in images
223
+ ]
224
+
225
+ images = [
226
+ to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
227
+ ]
228
+
229
+ data = {"pixel_values": images}
230
+ return BatchFeature(data=data, tensor_type=return_tensors)
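A minimal usage sketch of the image processor defined above; the random input image and the explicit size argument are stand-ins for real data and simply restate the class defaults.

import numpy as np
from PIL import Image
from modeling.siglip.image_processing_siglip import SiglipImageProcessor

image_processor = SiglipImageProcessor(size={"height": 224, "width": 224})
image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))

# preprocess runs: resize -> rescale by 1/255 -> normalize with mean/std 0.5 -> channels-first
batch = image_processor(images=image, return_tensors="pt")
print(batch["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])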
modeling/siglip/modeling_siglip.py CHANGED
The diff for this file is too large to render. See raw diff
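Although the modeling diff is not rendered above, the attention-pooling conversion earlier in this commit (read_in_q_k_v_head) packs the separate query/key/value kernels into vision_model.head.attention.in_proj_weight and in_proj_bias. Those parameter names belong to torch.nn.MultiheadAttention, which the pooling head appears to use; the sketch below only illustrates the expected shapes (the 768 hidden size and 12 heads are assumed base-model values).

import torch.nn as nn

hidden_size, num_heads = 768, 12  # assumed base-sized vision tower
attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

# query, key and value projections live in one stacked (3 * hidden, hidden) matrix,
# which is why the converter concatenates them before assigning.
print(attn.in_proj_weight.shape)  # torch.Size([2304, 768])
print(attn.in_proj_bias.shape)    # torch.Size([2304])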
 
modeling/siglip/processing_siglip.py CHANGED
@@ -1,131 +1,131 @@
1
- # Copyright 2024 The HuggingFace Inc. team.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """
5
- Image/Text processor class for SigLIP.
6
- """
7
-
8
- from typing import List, Optional, Union
9
-
10
- from transformers.feature_extraction_utils import BatchFeature
11
- from transformers.image_utils import ImageInput
12
- from transformers.processing_utils import ProcessorMixin
13
- from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
14
- from transformers.utils import TensorType
15
-
16
-
17
- class SiglipProcessor(ProcessorMixin):
18
- r"""
19
- Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.
20
-
21
- [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
22
- [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.
23
-
24
- Args:
25
- image_processor ([`SiglipImageProcessor`]):
26
- The image processor is a required input.
27
- tokenizer ([`SiglipTokenizer`]):
28
- The tokenizer is a required input.
29
- """
30
-
31
- attributes = ["image_processor", "tokenizer"]
32
- image_processor_class = "SiglipImageProcessor"
33
- tokenizer_class = "SiglipTokenizer"
34
-
35
- def __init__(self, image_processor, tokenizer):
36
- super().__init__(image_processor, tokenizer)
37
-
38
- def __call__(
39
- self,
40
- text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
41
- images: ImageInput = None,
42
- padding: Union[bool, str, PaddingStrategy] = False,
43
- truncation: Union[bool, str, TruncationStrategy] = None,
44
- max_length: int = None,
45
- return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
46
- ) -> BatchFeature:
47
- """
48
- Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
49
- and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode
50
- the text. To prepare the image(s), this method forwards the `images` argument to
51
- SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
52
- of the above two methods for more information.
53
-
54
- Args:
55
- text (`str`, `List[str]`, `List[List[str]]`):
56
- The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
57
- (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
58
- `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
59
- images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
60
- The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
61
- tensor. Both channels-first and channels-last formats are supported.
62
- padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
63
- Select a strategy to pad the returned sequences (according to the model's padding side and padding
64
- index) among:
65
- - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
66
- sequence is provided).
67
- - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
68
- acceptable input length for the model if that argument is not provided.
69
- - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
70
- lengths).
71
- max_length (`int`, *optional*):
72
- Maximum length of the returned list and optionally padding length (see above).
73
- truncation (`bool`, *optional*):
74
- Activates truncation to cut input sequences longer than `max_length` to `max_length`.
75
- return_tensors (`str` or [`~utils.TensorType`], *optional*):
76
- If set, will return tensors of a particular framework. Acceptable values are:
77
-
78
- - `'tf'`: Return TensorFlow `tf.constant` objects.
79
- - `'pt'`: Return PyTorch `torch.Tensor` objects.
80
- - `'np'`: Return NumPy `np.ndarray` objects.
81
- - `'jax'`: Return JAX `jnp.ndarray` objects.
82
-
83
- Returns:
84
- [`BatchFeature`]: A [`BatchFeature`] with the following fields:
85
-
86
- - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
87
- - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
88
- `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
89
- `None`).
90
- - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
91
- """
92
-
93
- if text is None and images is None:
94
- raise ValueError("You have to specify either text or images. Both cannot be none.")
95
-
96
- if text is not None:
97
- encoding = self.tokenizer(
98
- text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
99
- )
100
-
101
- if images is not None:
102
- image_features = self.image_processor(images, return_tensors=return_tensors)
103
-
104
- if text is not None and images is not None:
105
- encoding["pixel_values"] = image_features.pixel_values
106
- return encoding
107
- elif text is not None:
108
- return encoding
109
- else:
110
- return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)
111
-
112
- def decode(self, *args, **kwargs):
113
- """
114
- This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
115
- the docstring of this method for more information.
116
- """
117
- return self.tokenizer.decode(*args, **kwargs)
118
-
119
- def batch_decode(self, *args, **kwargs):
120
- """
121
- This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
122
- refer to the docstring of this method for more information.
123
- """
124
- return self.tokenizer.batch_decode(*args, **kwargs)
125
-
126
- @property
127
- # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
128
- def model_input_names(self):
129
- tokenizer_input_names = self.tokenizer.model_input_names
130
- image_processor_input_names = self.image_processor.model_input_names
131
- return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
 
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """
5
+ Image/Text processor class for SigLIP.
6
+ """
7
+
8
+ from typing import List, Optional, Union
9
+
10
+ from transformers.feature_extraction_utils import BatchFeature
11
+ from transformers.image_utils import ImageInput
12
+ from transformers.processing_utils import ProcessorMixin
13
+ from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
14
+ from transformers.utils import TensorType
15
+
16
+
17
+ class SiglipProcessor(ProcessorMixin):
18
+ r"""
19
+ Constructs a Siglip processor which wraps a Siglip image processor and a Siglip tokenizer into a single processor.
20
+
21
+ [`SiglipProcessor`] offers all the functionalities of [`SiglipImageProcessor`] and [`SiglipTokenizer`]. See the
22
+ [`~SiglipProcessor.__call__`] and [`~SiglipProcessor.decode`] for more information.
23
+
24
+ Args:
25
+ image_processor ([`SiglipImageProcessor`]):
26
+ The image processor is a required input.
27
+ tokenizer ([`SiglipTokenizer`]):
28
+ The tokenizer is a required input.
29
+ """
30
+
31
+ attributes = ["image_processor", "tokenizer"]
32
+ image_processor_class = "SiglipImageProcessor"
33
+ tokenizer_class = "SiglipTokenizer"
34
+
35
+ def __init__(self, image_processor, tokenizer):
36
+ super().__init__(image_processor, tokenizer)
37
+
38
+ def __call__(
39
+ self,
40
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
41
+ images: ImageInput = None,
42
+ padding: Union[bool, str, PaddingStrategy] = False,
43
+ truncation: Union[bool, str, TruncationStrategy] = None,
44
+ max_length: int = None,
45
+ return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
46
+ ) -> BatchFeature:
47
+ """
48
+ Main method to prepare for the model one or several sequence(s) and image(s). This method forwards the `text`
49
+ and `kwargs` arguments to SiglipTokenizer's [`~SiglipTokenizer.__call__`] if `text` is not `None` to encode
50
+ the text. To prepare the image(s), this method forwards the `images` argument to
51
+ SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
52
+ of the above two methods for more information.
53
+
54
+ Args:
55
+ text (`str`, `List[str]`, `List[List[str]]`):
56
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
57
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
58
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
59
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
60
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
61
+ tensor. Both channels-first and channels-last formats are supported.
62
+ padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
63
+ Select a strategy to pad the returned sequences (according to the model's padding side and padding
64
+ index) among:
65
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
66
+ sequence is provided).
67
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
68
+ acceptable input length for the model if that argument is not provided.
69
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
70
+ lengths).
71
+ max_length (`int`, *optional*):
72
+ Maximum length of the returned list and optionally padding length (see above).
73
+ truncation (`bool`, *optional*):
74
+ Activates truncation to cut input sequences longer than `max_length` to `max_length`.
75
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
76
+ If set, will return tensors of a particular framework. Acceptable values are:
77
+
78
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
79
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
80
+ - `'np'`: Return NumPy `np.ndarray` objects.
81
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
82
+
83
+ Returns:
84
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
85
+
86
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
87
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
88
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
89
+ `None`).
90
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
91
+ """
92
+
93
+ if text is None and images is None:
94
+ raise ValueError("You have to specify either text or images. Both cannot be none.")
95
+
96
+ if text is not None:
97
+ encoding = self.tokenizer(
98
+ text, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length
99
+ )
100
+
101
+ if images is not None:
102
+ image_features = self.image_processor(images, return_tensors=return_tensors)
103
+
104
+ if text is not None and images is not None:
105
+ encoding["pixel_values"] = image_features.pixel_values
106
+ return encoding
107
+ elif text is not None:
108
+ return encoding
109
+ else:
110
+ return BatchFeature(data=dict(**image_features), tensor_type=return_tensors)
111
+
112
+ def decode(self, *args, **kwargs):
113
+ """
114
+ This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
115
+ the docstring of this method for more information.
116
+ """
117
+ return self.tokenizer.decode(*args, **kwargs)
118
+
119
+ def batch_decode(self, *args, **kwargs):
120
+ """
121
+ This method forwards all its arguments to SiglipTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
122
+ refer to the docstring of this method for more information.
123
+ """
124
+ return self.tokenizer.batch_decode(*args, **kwargs)
125
+
126
+ @property
127
+ # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->Siglip, T5->Siglip
128
+ def model_input_names(self):
129
+ tokenizer_input_names = self.tokenizer.model_input_names
130
+ image_processor_input_names = self.image_processor.model_input_names
131
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
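For illustration, here is a sketch of building the combined processor from its two parts, mirroring how the conversion script above constructs it; the sentencepiece path is a placeholder for any SigLIP vocabulary file.

import numpy as np
from PIL import Image
from modeling.siglip.image_processing_siglip import SiglipImageProcessor
from modeling.siglip.processing_siglip import SiglipProcessor
from modeling.siglip.tokenization_siglip import SiglipTokenizer

tokenizer = SiglipTokenizer(vocab_file="./spiece.model", model_input_names=["input_ids"])  # placeholder path
image_processor = SiglipImageProcessor(size={"height": 224, "width": 224})
processor = SiglipProcessor(image_processor=image_processor, tokenizer=tokenizer)

image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
inputs = processor(
    text=["an apple", "a picture of an apple"],
    images=image,
    padding="max_length",
    return_tensors="pt",
)
# With model_input_names=["input_ids"] the tokenizer emits no attention_mask,
# matching how the conversion script builds the processor.
print(sorted(inputs.keys()))  # ['input_ids', 'pixel_values']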
modeling/siglip/tokenization_siglip.py CHANGED
@@ -1,364 +1,364 @@
1
- # Copyright 2024 The HuggingFace Inc. team.
2
- # SPDX-License-Identifier: Apache-2.0
3
-
4
- """Tokenization class for SigLIP model."""
5
-
6
- import os
7
- import re
8
- import string
9
- import warnings
10
- from shutil import copyfile
11
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
12
-
13
- import sentencepiece as spm
14
-
15
- from transformers.convert_slow_tokenizer import import_protobuf
16
- from transformers.tokenization_utils import PreTrainedTokenizer
17
- from transformers.tokenization_utils_base import AddedToken
18
-
19
-
20
- if TYPE_CHECKING:
21
- from transformers.tokenization_utils_base import TextInput
22
- from transformers.utils import logging, requires_backends
23
-
24
-
25
- logger = logging.get_logger(__name__)
26
-
27
- VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
28
-
29
-
30
- SPIECE_UNDERLINE = "▁"
31
-
32
-
33
- class SiglipTokenizer(PreTrainedTokenizer):
34
- """
35
- Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
36
-
37
- This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
38
- this superclass for more information regarding those methods.
39
-
40
- Args:
41
- vocab_file (`str`):
42
- [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
43
- contains the vocabulary necessary to instantiate a tokenizer.
44
- eos_token (`str`, *optional*, defaults to `"</s>"`):
45
- The end of sequence token.
46
- unk_token (`str`, *optional*, defaults to `"<unk>"`):
47
- The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
48
- token instead.
49
- pad_token (`str`, *optional*, defaults to `"</s>"`):
50
- The token used for padding, for example when batching sequences of different lengths.
51
- additional_special_tokens (`List[str]`, *optional*):
52
- Additional special tokens used by the tokenizer.
53
- sp_model_kwargs (`dict`, *optional*):
54
- Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
55
- SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
56
- to set:
57
-
58
- - `enable_sampling`: Enable subword regularization.
59
- - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
60
-
61
- - `nbest_size = {0,1}`: No sampling is performed.
62
- - `nbest_size > 1`: samples from the nbest_size results.
63
- - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
64
- using forward-filtering-and-backward-sampling algorithm.
65
-
66
- - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
67
- BPE-dropout.
68
- model_max_length (`int`, *optional*, defaults to 64):
69
- The maximum length (in number of tokens) for model inputs.
70
- do_lower_case (`bool`, *optional*, defaults to `True`):
71
- Whether or not to lowercase the input when tokenizing.
72
- """
73
-
74
- vocab_files_names = VOCAB_FILES_NAMES
75
- model_input_names = ["input_ids", "attention_mask"]
76
-
77
- def __init__(
78
- self,
79
- vocab_file,
80
- eos_token="</s>",
81
- unk_token="<unk>",
82
- pad_token="</s>",
83
- additional_special_tokens=None,
84
- sp_model_kwargs: Optional[Dict[str, Any]] = None,
85
- model_max_length=64,
86
- do_lower_case=True,
87
- **kwargs,
88
- ) -> None:
89
- requires_backends(self, "protobuf")
90
-
91
- pad_token = (
92
- AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True)
93
- if isinstance(pad_token, str)
94
- else pad_token
95
- )
96
- unk_token = (
97
- AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True)
98
- if isinstance(unk_token, str)
99
- else unk_token
100
- )
101
- eos_token = (
102
- AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True)
103
- if isinstance(eos_token, str)
104
- else eos_token
105
- )
106
-
107
- self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
108
-
109
- self.do_lower_case = do_lower_case
110
- self.vocab_file = vocab_file
111
-
112
- self.sp_model = self.get_spm_processor()
113
- self.vocab_file = vocab_file
114
-
115
- super().__init__(
116
- eos_token=eos_token,
117
- unk_token=unk_token,
118
- pad_token=pad_token,
119
- additional_special_tokens=additional_special_tokens,
120
- sp_model_kwargs=self.sp_model_kwargs,
121
- model_max_length=model_max_length,
122
- do_lower_case=do_lower_case,
123
- **kwargs,
124
- )
125
-
126
- def get_spm_processor(self):
127
- tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
128
- with open(self.vocab_file, "rb") as f:
129
- sp_model = f.read()
130
- model_pb2 = import_protobuf()
131
- model = model_pb2.ModelProto.FromString(sp_model)
132
- normalizer_spec = model_pb2.NormalizerSpec()
133
- normalizer_spec.add_dummy_prefix = False
134
- model.normalizer_spec.MergeFrom(normalizer_spec)
135
- sp_model = model.SerializeToString()
136
- tokenizer.LoadFromSerializedProto(sp_model)
137
- return tokenizer
138
-
139
- @property
140
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
141
- def vocab_size(self):
142
- return self.sp_model.get_piece_size()
143
-
144
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
145
- def get_vocab(self):
146
- vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
147
- vocab.update(self.added_tokens_encoder)
148
- return vocab
149
-
150
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
151
- def get_special_tokens_mask(
152
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
153
- ) -> List[int]:
154
- """
155
- Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
156
- special tokens using the tokenizer `prepare_for_model` method.
157
-
158
- Args:
159
- token_ids_0 (`List[int]`):
160
- List of IDs.
161
- token_ids_1 (`List[int]`, *optional*):
162
- Optional second list of IDs for sequence pairs.
163
- already_has_special_tokens (`bool`, *optional*, defaults to `False`):
164
- Whether or not the token list is already formatted with special tokens for the model.
165
-
166
- Returns:
167
- `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
168
- """
169
- if already_has_special_tokens:
170
- return super().get_special_tokens_mask(
171
- token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
172
- )
173
-
174
- # normal case: some special tokens
175
- if token_ids_1 is None:
176
- return ([0] * len(token_ids_0)) + [1]
177
- return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
178
-
179
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
180
- def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
181
- """Do not add eos again if user already added it."""
182
- if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
183
- warnings.warn(
184
- f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
185
- " eos tokens being added."
186
- )
187
- return token_ids
188
- else:
189
- return token_ids + [self.eos_token_id]
190
-
191
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
192
- def create_token_type_ids_from_sequences(
193
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
194
- ) -> List[int]:
195
- """
196
- Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
197
- use of token type ids, therefore a list of zeros is returned.
198
-
199
- Args:
200
- token_ids_0 (`List[int]`):
201
- List of IDs.
202
- token_ids_1 (`List[int]`, *optional*):
203
- Optional second list of IDs for sequence pairs.
204
-
205
- Returns:
206
- `List[int]`: List of zeros.
207
- """
208
- eos = [self.eos_token_id]
209
-
210
- if token_ids_1 is None:
211
- return len(token_ids_0 + eos) * [0]
212
- return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
213
-
214
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
215
- def build_inputs_with_special_tokens(
216
- self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
217
- ) -> List[int]:
218
- """
219
- Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
220
- adding special tokens. A sequence has the following format:
221
-
222
- - single sequence: `X </s>`
223
- - pair of sequences: `A </s> B </s>`
224
-
225
- Args:
226
- token_ids_0 (`List[int]`):
227
- List of IDs to which the special tokens will be added.
228
- token_ids_1 (`List[int]`, *optional*):
229
- Optional second list of IDs for sequence pairs.
230
-
231
- Returns:
232
- `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
233
- """
234
- token_ids_0 = self._add_eos_if_not_present(token_ids_0)
235
- if token_ids_1 is None:
236
- return token_ids_0
237
- else:
238
- token_ids_1 = self._add_eos_if_not_present(token_ids_1)
239
- return token_ids_0 + token_ids_1
240
-
241
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
242
- def __getstate__(self):
243
- state = self.__dict__.copy()
244
- state["sp_model"] = None
245
- return state
246
-
247
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__
248
- def __setstate__(self, d):
249
- self.__dict__ = d
250
-
251
- # for backward compatibility
252
- if not hasattr(self, "sp_model_kwargs"):
253
- self.sp_model_kwargs = {}
254
-
255
- self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
256
- self.sp_model.Load(self.vocab_file)
257
-
258
- def remove_punctuation(self, text: str) -> str:
259
- return text.translate(str.maketrans("", "", string.punctuation))
260
-
261
- # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
262
- def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
263
- """Returns canonicalized `text` (puncuation removed).
264
-
265
- Args:
266
- text (`str`):
267
- String to be canonicalized.
268
- keep_punctuation_exact_string (`str`, *optional*):
269
- If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}'
270
- (but will still remove '{' and '}' that appear separately).
271
- """
272
- if keep_punctuation_exact_string:
273
- text = keep_punctuation_exact_string.join(
274
- self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string)
275
- )
276
- else:
277
- text = self.remove_punctuation(text)
278
- text = re.sub(r"\s+", " ", text)
279
- text = text.strip()
280
-
281
- return text
282
-
283
- def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
284
- """
285
- Converts a string to a list of tokens.
286
- """
287
- tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
288
-
289
- if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
290
- tokens = tokens[1:]
291
- return tokens
292
-
293
- @property
294
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
295
- def unk_token_length(self):
296
- return len(self.sp_model.encode(str(self.unk_token)))
297
-
298
- def _tokenize(self, text, **kwargs):
299
- """
300
- Returns a tokenized string.
301
-
302
- We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any
303
- SPIECE_UNDERLINE.
304
-
305
- For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`.
306
-
307
- Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
308
- `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
309
- """
310
- text = self.canonicalize_text(text, keep_punctuation_exact_string=None)
311
- tokens = self.sp_model.encode(text, out_type=str)
312
-
313
- # 1. Encode string + prefix ex: "<unk> Hey"
314
- tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
315
- # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
316
- return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
317
-
318
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
319
- def _convert_token_to_id(self, token):
320
- """Converts a token (str) in an id using the vocab."""
321
- return self.sp_model.piece_to_id(token)
322
-
323
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
324
- def _convert_id_to_token(self, index):
325
- """Converts an index (integer) in a token (str) using the vocab."""
326
- token = self.sp_model.IdToPiece(index)
327
- return token
328
-
329
- def convert_tokens_to_string(self, tokens):
330
- """Converts a sequence of tokens (string) in a single string."""
331
- current_sub_tokens = []
332
- out_string = ""
333
- prev_is_special = False
334
- for token in tokens:
335
- # make sure that special tokens are not decoded using sentencepiece model
336
- if token in self.all_special_tokens:
337
- if not prev_is_special:
338
- out_string += " "
339
- out_string += self.sp_model.decode(current_sub_tokens) + token
340
- prev_is_special = True
341
- current_sub_tokens = []
342
- else:
343
- current_sub_tokens.append(token)
344
- prev_is_special = False
345
- out_string += self.sp_model.decode(current_sub_tokens)
346
- return out_string.strip()
347
-
348
- # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
349
- def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
350
- if not os.path.isdir(save_directory):
351
- logger.error(f"Vocabulary path ({save_directory}) should be a directory")
352
- return
353
- out_vocab_file = os.path.join(
354
- save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
355
- )
356
-
357
- if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
358
- copyfile(self.vocab_file, out_vocab_file)
359
- elif not os.path.isfile(self.vocab_file):
360
- with open(out_vocab_file, "wb") as fi:
361
- content_spiece_model = self.sp_model.serialized_model_proto()
362
- fi.write(content_spiece_model)
363
-
364
- return (out_vocab_file,)
 
1
+ # Copyright 2024 The HuggingFace Inc. team.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Tokenization class for SigLIP model."""
5
+
6
+ import os
7
+ import re
8
+ import string
9
+ import warnings
10
+ from shutil import copyfile
11
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
12
+
13
+ import sentencepiece as spm
14
+
15
+ from transformers.convert_slow_tokenizer import import_protobuf
16
+ from transformers.tokenization_utils import PreTrainedTokenizer
17
+ from transformers.tokenization_utils_base import AddedToken
18
+
19
+
20
+ if TYPE_CHECKING:
21
+ from transformers.tokenization_utils_base import TextInput
22
+ from transformers.utils import logging, requires_backends
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+ VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
28
+
29
+
30
+ SPIECE_UNDERLINE = "▁"
31
+
32
+
33
+ class SiglipTokenizer(PreTrainedTokenizer):
34
+ """
35
+ Construct a Siglip tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).
36
+
37
+ This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
38
+ this superclass for more information regarding those methods.
39
+
40
+ Args:
41
+ vocab_file (`str`):
42
+ [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
43
+ contains the vocabulary necessary to instantiate a tokenizer.
44
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
45
+ The end of sequence token.
46
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
47
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
48
+ token instead.
49
+ pad_token (`str`, *optional*, defaults to `"</s>"`):
50
+ The token used for padding, for example when batching sequences of different lengths.
51
+ additional_special_tokens (`List[str]`, *optional*):
52
+ Additional special tokens used by the tokenizer.
53
+ sp_model_kwargs (`dict`, *optional*):
54
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
55
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
56
+ to set:
57
+
58
+ - `enable_sampling`: Enable subword regularization.
59
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
60
+
61
+ - `nbest_size = {0,1}`: No sampling is performed.
62
+ - `nbest_size > 1`: samples from the nbest_size results.
63
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
64
+ using forward-filtering-and-backward-sampling algorithm.
65
+
66
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
67
+ BPE-dropout.
68
+ model_max_length (`int`, *optional*, defaults to 64):
69
+ The maximum length (in number of tokens) for model inputs.
70
+ do_lower_case (`bool`, *optional*, defaults to `True`):
71
+ Whether or not to lowercase the input when tokenizing.
72
+ """
73
+
74
+ vocab_files_names = VOCAB_FILES_NAMES
75
+ model_input_names = ["input_ids", "attention_mask"]
76
+
77
+ def __init__(
78
+ self,
79
+ vocab_file,
80
+ eos_token="</s>",
81
+ unk_token="<unk>",
82
+ pad_token="</s>",
83
+ additional_special_tokens=None,
84
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
85
+ model_max_length=64,
86
+ do_lower_case=True,
87
+ **kwargs,
88
+ ) -> None:
89
+ requires_backends(self, "protobuf")
90
+
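+ # The special tokens below are wrapped as AddedToken with lstrip/rstrip=True so that whitespace adjacent
+ # to these markers is absorbed when they are matched, rather than surviving as stray "▁" pieces.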
91
+ pad_token = (
92
+ AddedToken(pad_token, rstrip=True, lstrip=True, normalized=False, special=True)
93
+ if isinstance(pad_token, str)
94
+ else pad_token
95
+ )
96
+ unk_token = (
97
+ AddedToken(unk_token, rstrip=True, lstrip=True, normalized=False, special=True)
98
+ if isinstance(unk_token, str)
99
+ else unk_token
100
+ )
101
+ eos_token = (
102
+ AddedToken(eos_token, rstrip=True, lstrip=True, normalized=False, special=True)
103
+ if isinstance(eos_token, str)
104
+ else eos_token
105
+ )
106
+
107
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
108
+
109
+ self.do_lower_case = do_lower_case
110
+ self.vocab_file = vocab_file
111
+
112
+ self.sp_model = self.get_spm_processor()
114
+
115
+ super().__init__(
116
+ eos_token=eos_token,
117
+ unk_token=unk_token,
118
+ pad_token=pad_token,
119
+ additional_special_tokens=additional_special_tokens,
120
+ sp_model_kwargs=self.sp_model_kwargs,
121
+ model_max_length=model_max_length,
122
+ do_lower_case=do_lower_case,
123
+ **kwargs,
124
+ )
125
+
126
+ def get_spm_processor(self):
127
+ tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs)
128
+ with open(self.vocab_file, "rb") as f:
129
+ sp_model = f.read()
130
+ model_pb2 = import_protobuf()
131
+ model = model_pb2.ModelProto.FromString(sp_model)
132
+ normalizer_spec = model_pb2.NormalizerSpec()
133
+ normalizer_spec.add_dummy_prefix = False
134
+ model.normalizer_spec.MergeFrom(normalizer_spec)
135
+ sp_model = model.SerializeToString()
136
+ tokenizer.LoadFromSerializedProto(sp_model)
137
+ return tokenizer
138
+
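+ # The serialized model proto is patched above to disable `add_dummy_prefix` before loading, so the
+ # processor never inserts an automatic leading "▁"; `_tokenize` below compensates with the unk-prefix trick.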
139
+ @property
140
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.vocab_size
141
+ def vocab_size(self):
142
+ return self.sp_model.get_piece_size()
143
+
144
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_vocab
145
+ def get_vocab(self):
146
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
147
+ vocab.update(self.added_tokens_encoder)
148
+ return vocab
149
+
150
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_special_tokens_mask
151
+ def get_special_tokens_mask(
152
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
153
+ ) -> List[int]:
154
+ """
155
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
156
+ special tokens using the tokenizer `prepare_for_model` method.
157
+
158
+ Args:
159
+ token_ids_0 (`List[int]`):
160
+ List of IDs.
161
+ token_ids_1 (`List[int]`, *optional*):
162
+ Optional second list of IDs for sequence pairs.
163
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
164
+ Whether or not the token list is already formatted with special tokens for the model.
165
+
166
+ Returns:
167
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
168
+ """
169
+ if already_has_special_tokens:
170
+ return super().get_special_tokens_mask(
171
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
172
+ )
173
+
174
+ # normal case: some special tokens
175
+ if token_ids_1 is None:
176
+ return ([0] * len(token_ids_0)) + [1]
177
+ return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
178
+
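+ # Illustrative sketch (hypothetical ids): get_special_tokens_mask([5, 6, 7]) -> [0, 0, 0, 1], where the
+ # trailing 1 marks the EOS appended by build_inputs_with_special_tokens.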
179
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._add_eos_if_not_present
180
+ def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
181
+ """Do not add eos again if user already added it."""
182
+ if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
183
+ warnings.warn(
184
+ f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
185
+ " eos tokens being added."
186
+ )
187
+ return token_ids
188
+ else:
189
+ return token_ids + [self.eos_token_id]
190
+
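+ # Sketch, assuming a hypothetical eos_token_id of 1: [5, 6] -> [5, 6, 1], while [5, 6, 1] is returned
+ # unchanged (with a warning) because the EOS is already present.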
191
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.create_token_type_ids_from_sequences
192
+ def create_token_type_ids_from_sequences(
193
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
194
+ ) -> List[int]:
195
+ """
196
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. SigLIP, like T5,
197
+ does not make use of token type ids, therefore a list of zeros is returned.
198
+
199
+ Args:
200
+ token_ids_0 (`List[int]`):
201
+ List of IDs.
202
+ token_ids_1 (`List[int]`, *optional*):
203
+ Optional second list of IDs for sequence pairs.
204
+
205
+ Returns:
206
+ `List[int]`: List of zeros.
207
+ """
208
+ eos = [self.eos_token_id]
209
+
210
+ if token_ids_1 is None:
211
+ return len(token_ids_0 + eos) * [0]
212
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
213
+
214
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.build_inputs_with_special_tokens
215
+ def build_inputs_with_special_tokens(
216
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
217
+ ) -> List[int]:
218
+ """
219
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
220
+ adding special tokens. A sequence has the following format:
221
+
222
+ - single sequence: `X </s>`
223
+ - pair of sequences: `A </s> B </s>`
224
+
225
+ Args:
226
+ token_ids_0 (`List[int]`):
227
+ List of IDs to which the special tokens will be added.
228
+ token_ids_1 (`List[int]`, *optional*):
229
+ Optional second list of IDs for sequence pairs.
230
+
231
+ Returns:
232
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
233
+ """
234
+ token_ids_0 = self._add_eos_if_not_present(token_ids_0)
235
+ if token_ids_1 is None:
236
+ return token_ids_0
237
+ else:
238
+ token_ids_1 = self._add_eos_if_not_present(token_ids_1)
239
+ return token_ids_0 + token_ids_1
240
+
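+ # Sketch (hypothetical ids, eos_token_id = 1): build_inputs_with_special_tokens([5, 6]) -> [5, 6, 1];
+ # build_inputs_with_special_tokens([5, 6], [7, 8]) -> [5, 6, 1, 7, 8, 1].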
241
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__getstate__
242
+ def __getstate__(self):
243
+ state = self.__dict__.copy()
244
+ state["sp_model"] = None
245
+ return state
246
+
247
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.__setstate__
248
+ def __setstate__(self, d):
249
+ self.__dict__ = d
250
+
251
+ # for backward compatibility
252
+ if not hasattr(self, "sp_model_kwargs"):
253
+ self.sp_model_kwargs = {}
254
+
255
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
256
+ self.sp_model.Load(self.vocab_file)
257
+
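+ # The SentencePiece processor is not picklable, so __getstate__ drops it and __setstate__ reloads it from
+ # `vocab_file`. Note the reload uses `Load` directly, so the `add_dummy_prefix` patch from
+ # `get_spm_processor` is not re-applied after unpickling.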
258
+ def remove_punctuation(self, text: str) -> str:
259
+ return text.translate(str.maketrans("", "", string.punctuation))
260
+
261
+ # source: https://github.com/google-research/big_vision/blob/3b8e5ab6ad4f96e32b32826f9e1b8fd277914f9c/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94
262
+ def canonicalize_text(self, text, *, keep_punctuation_exact_string=None):
263
+ """Returns canonicalized `text` (puncuation removed).
264
+
265
+ Args:
266
+ text (`str`):
267
+ String to be canonicalized.
268
+ keep_punctuation_exact_string (`str`, *optional*):
269
+ If provided, then this exact string is kept. For example providing '{}' will keep any occurrences of '{}'
270
+ (but will still remove '{' and '}' that appear separately).
271
+ """
272
+ if keep_punctuation_exact_string:
273
+ text = keep_punctuation_exact_string.join(
274
+ self.remove_punctuation(part) for part in text.split(keep_punctuation_exact_string)
275
+ )
276
+ else:
277
+ text = self.remove_punctuation(text)
278
+ text = re.sub(r"\s+", " ", text)
279
+ text = text.strip()
280
+
281
+ return text
282
+
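+ # Sketch: canonicalize_text("A photo,  of a cat!!") -> "A photo of a cat" (punctuation stripped, whitespace
+ # collapsed, case untouched); canonicalize_text("a {} b, c", keep_punctuation_exact_string="{}") -> "a {} b c".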
283
+ def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]:
284
+ """
285
+ Converts a string to a list of tokens.
286
+ """
287
+ tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs)
288
+
289
+ if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens:
290
+ tokens = tokens[1:]
291
+ return tokens
292
+
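+ # Prepending SPIECE_UNDERLINE before delegating to the parent tokenizer keeps the leading word boundary;
+ # the resulting lone "▁" is dropped again when it lands directly in front of a special token.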
293
+ @property
294
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.unk_token_length
295
+ def unk_token_length(self):
296
+ return len(self.sp_model.encode(str(self.unk_token)))
297
+
298
+ def _tokenize(self, text, **kwargs):
299
+ """
300
+ Returns a tokenized string.
301
+
302
+ We deactivated the `add_dummy_prefix` option, so the SentencePiece internals will always strip any
303
+ SPIECE_UNDERLINE.
304
+
305
+ For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give `['H', 'e', 'y']` instead of `['▁He', 'y']`.
306
+
307
+ Thus we always encode `f"{unk_token}text"` and strip the `unk_token`. Here is an example with `unk_token = "<unk>"` and `unk_token_length = 4`.
308
+ `self.tokenizer.sp_model.encode("<unk> Hey", out_type = str)[4:]`.
309
+ """
310
+ text = self.canonicalize_text(text, keep_punctuation_exact_string=None)
312
+
313
+ # 1. Encode string + prefix ex: "<unk> Hey"
314
+ tokens = self.sp_model.encode(self.unk_token + text, out_type=str)
315
+ # 2. Remove self.unk_token from ['<','unk','>', '▁Hey']
316
+ return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens
317
+
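+ # The unk-prefix trick above: encoding f"{unk_token}{text}" and slicing off the first `unk_token_length`
+ # pieces preserves the leading "▁" of `text` that a direct encode would strip (see the docstring example).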
318
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_token_to_id
319
+ def _convert_token_to_id(self, token):
320
+ """Converts a token (str) in an id using the vocab."""
321
+ return self.sp_model.piece_to_id(token)
322
+
323
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._convert_id_to_token
324
+ def _convert_id_to_token(self, index):
325
+ """Converts an index (integer) in a token (str) using the vocab."""
326
+ token = self.sp_model.IdToPiece(index)
327
+ return token
328
+
329
+ def convert_tokens_to_string(self, tokens):
330
+ """Converts a sequence of tokens (string) in a single string."""
331
+ current_sub_tokens = []
332
+ out_string = ""
333
+ prev_is_special = False
334
+ for token in tokens:
335
+ # make sure that special tokens are not decoded using sentencepiece model
336
+ if token in self.all_special_tokens:
337
+ if not prev_is_special:
338
+ out_string += " "
339
+ out_string += self.sp_model.decode(current_sub_tokens) + token
340
+ prev_is_special = True
341
+ current_sub_tokens = []
342
+ else:
343
+ current_sub_tokens.append(token)
344
+ prev_is_special = False
345
+ out_string += self.sp_model.decode(current_sub_tokens)
346
+ return out_string.strip()
347
+
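+ # Runs of ordinary pieces are decoded with the SentencePiece model, while special tokens are spliced back
+ # verbatim, e.g. ['▁a', '▁photo', '</s>'] -> roughly "a photo</s>".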
348
+ # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.save_vocabulary
349
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
350
+ if not os.path.isdir(save_directory):
351
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
352
+ return
353
+ out_vocab_file = os.path.join(
354
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
355
+ )
356
+
357
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
358
+ copyfile(self.vocab_file, out_vocab_file)
359
+ elif not os.path.isfile(self.vocab_file):
360
+ with open(out_vocab_file, "wb") as fi:
361
+ content_spiece_model = self.sp_model.serialized_model_proto()
362
+ fi.write(content_spiece_model)
363
+
364
+ return (out_vocab_file,)
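+
+ # Illustrative usage sketch (the path and inputs are hypothetical, not part of this file):
+ #   tokenizer = SiglipTokenizer(vocab_file="./spiece.model")
+ #   input_ids = tokenizer("A photo, of a CAT!")["input_ids"]
+ # Because canonicalize_text strips punctuation and collapses whitespace, "A photo of a CAT" should yield
+ # the same ids; build_inputs_with_special_tokens appends the EOS id.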