KingNish committed
Commit 0e5cadd · verified · 1 Parent(s): 12a0dd9

Update app.py

Files changed (1)
  1. app.py +507 -507
app.py CHANGED
@@ -1,508 +1,508 @@
import spaces
import gradio as gr
import numpy as np
import os
import torch
import random
import subprocess
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)

from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
from PIL import Image

from data.data_utils import add_special_tokens, pil_img2rgb
from data.transforms import ImageTransform
from inferencer import InterleaveInferencer
from modeling.autoencoder import load_ae
from modeling.bagel.qwen2_navit import NaiveCache
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
    SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer

from huggingface_hub import snapshot_download

save_dir = "./model"
repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
cache_dir = save_dir + "/cache"

snapshot_download(cache_dir=cache_dir,
                  local_dir=save_dir,
                  repo_id=repo_id,
                  local_dir_use_symlinks=False,
                  resume_download=True,
                  allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
                  )

# Model Initialization
model_path = "./model" #Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT

llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
llm_config.qk_norm = True
llm_config.tie_word_embeddings = False
llm_config.layer_module = "Qwen2MoTDecoderLayer"

vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
vit_config.rope = False
vit_config.num_hidden_layers -= 1

vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))

config = BagelConfig(
    visual_gen=True,
    visual_und=True,
    llm_config=llm_config,
    vit_config=vit_config,
    vae_config=vae_config,
    vit_max_num_patch_per_side=70,
    connector_act='gelu_pytorch_tanh',
    latent_patch_size=2,
    max_latent_size=64,
)

with init_empty_weights():
    language_model = Qwen2ForCausalLM(llm_config)
    vit_model = SiglipVisionModel(vit_config)
    model = Bagel(language_model, vit_model, config)
    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)

tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

vae_transform = ImageTransform(1024, 512, 16)
vit_transform = ImageTransform(980, 224, 14)

# Model Loading and Multi GPU Infernece Preparing
device_map = infer_auto_device_map(
    model,
    max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
)

same_device_modules = [
    'language_model.model.embed_tokens',
    'time_embedder',
    'latent_pos_embed',
    'vae2llm',
    'llm2vae',
    'connector',
    'vit_pos_embed'
]

if torch.cuda.device_count() == 1:
    first_device = device_map.get(same_device_modules[0], "cuda:0")
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device
        else:
            device_map[k] = "cuda:0"
else:
    first_device = device_map.get(same_device_modules[0])
    for k in same_device_modules:
        if k in device_map:
            device_map[k] = first_device

model = load_checkpoint_and_dispatch(
    model,
    checkpoint=os.path.join(model_path, "ema.safetensors"),
    device_map=device_map,
    offload_buffers=True,
    dtype=torch.bfloat16,
    force_hooks=True,
).eval()


# Inferencer Preparing
inferencer = InterleaveInferencer(
    model=model,
    vae_model=vae_model,
    tokenizer=tokenizer,
    vae_transform=vae_transform,
    vit_transform=vit_transform,
    new_token_ids=new_token_ids,
)

def set_seed(seed):
    """Set random seeds for reproducibility"""
    if seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    return seed

# Text to Image function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
                  timestep_shift=3.0, num_timesteps=50,
                  cfg_renorm_min=1.0, cfg_renorm_type="global",
                  max_think_token_n=1024, do_sample=False, text_temperature=0.3,
                  seed=0, image_ratio="1:1"):
    # Set seed for reproducibility
    set_seed(seed)

    if image_ratio == "1:1":
        image_shapes = (1024, 1024)
    elif image_ratio == "4:3":
        image_shapes = (768, 1024)
    elif image_ratio == "3:4":
        image_shapes = (1024, 768)
    elif image_ratio == "16:9":
        image_shapes = (576, 1024)
    elif image_ratio == "9:16":
        image_shapes = (1024, 576)

    # Set hyperparameters
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
-       text_temperature=text_temperature if show_thinking else 0.3,
+       temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
        image_shapes=image_shapes,
    )

    result = {}

    # Call inferencer with or without think parameter based on user choice
    for i in inferencer(text=prompt, think=show_thinking, **inference_hyper):
        if type(i) == str:
            result["text"] += i
        elif type(i) == Image.Image:
            result["image"] = i

    yield result["image"], result.get("text", None)


# Image Understanding function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
                        do_sample=False, text_temperature=0.3, max_new_tokens=512):
    if image is None:
        return "Please upload an image."

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image = pil_img2rgb(image)

    # Set hyperparameters
    inference_hyper = dict(
        do_sample=do_sample,
-       text_temperature=text_temperature,
+       temperature=text_temperature,
        max_think_token_n=max_new_tokens, # Set max_length
    )

    result = {}
    # Use show_thinking parameter to control thinking process
    for i in inferencer(image=image, text=prompt, think=show_thinking,
                        understanding_output=True, **inference_hyper):
        if type(i) == str:
            result["text"] += i
        elif type(i) == Image.Image:
            result["image"] = i
    yield result["text"]


# Image Editing function with thinking option and hyperparameters
@spaces.GPU(duration=90)
def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
               cfg_img_scale=2.0, cfg_interval=0.0,
               timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
               cfg_renorm_type="text_channel", max_think_token_n=1024,
               do_sample=False, text_temperature=0.3, seed=0):
    # Set seed for reproducibility
    set_seed(seed)

    if image is None:
        return "Please upload an image.", ""

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    image = pil_img2rgb(image)

    # Set hyperparameters
    inference_hyper = dict(
        max_think_token_n=max_think_token_n if show_thinking else 1024,
        do_sample=do_sample if show_thinking else False,
-       text_temperature=text_temperature if show_thinking else 0.3,
+       temperature=text_temperature if show_thinking else 0.3,
        cfg_text_scale=cfg_text_scale,
        cfg_img_scale=cfg_img_scale,
        cfg_interval=[cfg_interval, 1.0], # End fixed at 1.0
        timestep_shift=timestep_shift,
        num_timesteps=num_timesteps,
        cfg_renorm_min=cfg_renorm_min,
        cfg_renorm_type=cfg_renorm_type,
    )

    # Include thinking parameter based on user choice
    result = {}
    for i in inferencer(image=image, text=prompt, think=show_thinking, **inference_hyper):
        if type(i) == str:
            result["text"] += i
        elif type(i) == Image.Image:
            result["image"] = i

    yield result["image"], result.get("text", "")

# Helper function to load example images
def load_example_image(image_path):
    try:
        return Image.open(image_path)
    except Exception as e:
        print(f"Error loading example image: {e}")
        return None


# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("""
<div>
  <img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="380"/>
</div>
""")

    with gr.Tab("📝 Text to Image"):
        txt_input = gr.Textbox(
            label="Prompt",
            value="A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
        )

        with gr.Row():
            show_thinking = gr.Checkbox(label="Thinking", value=False)

        # Add hyperparameter controls in an accordion
        with gr.Accordion("Inference Hyperparameters", open=False):
            # Layout: two parameters per row
            with gr.Group():
                with gr.Row():
                    seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1,
                                     label="Seed", info="0 for random seed, positive for reproducible results")
                    image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
                                              value="1:1", label="Image Ratio",
                                              info="The longer size is fixed to 1024")

                with gr.Row():
                    cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
                                               label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0)")
                    cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
                                             label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")

                with gr.Row():
                    cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
                                                  value="global", label="CFG Renorm Type",
                                                  info="If the genrated image is blurry, use 'global'")
                    cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                               label="CFG Renorm Min", info="1.0 disables CFG-Renorm")

                with gr.Row():
                    num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
                                              label="Timesteps", info="Total denoising steps")
                    timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, interactive=True,
                                               label="Timestep Shift", info="Higher values for layout, lower for details")

            # Thinking parameters in a single row
            thinking_params = gr.Group(visible=False)
            with thinking_params:
                with gr.Row():
                    do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
                    max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
                                                  label="Max Think Tokens", info="Maximum number of tokens for thinking")
                    text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
                                                 label="Temperature", info="Controls randomness in text generation")

        thinking_output = gr.Textbox(label="Thinking Process", visible=False)
        img_output = gr.Image(label="Generated Image")
        gen_btn = gr.Button("Generate")

        # Dynamically show/hide thinking process box and parameters
        def update_thinking_visibility(show):
            return gr.update(visible=show), gr.update(visible=show)

        show_thinking.change(
            fn=update_thinking_visibility,
            inputs=[show_thinking],
            outputs=[thinking_output, thinking_params]
        )

        gen_btn.click(
            fn=text_to_image,
            inputs=[
                txt_input, show_thinking, cfg_text_scale,
                cfg_interval, timestep_shift,
                num_timesteps, cfg_renorm_min, cfg_renorm_type,
                max_think_token_n, do_sample, text_temperature, seed, image_ratio
            ],
            outputs=[img_output, thinking_output]
        )

    with gr.Tab("🖌️ Image Edit"):
        with gr.Row():
            with gr.Column(scale=1):
                edit_image_input = gr.Image(label="Input Image", value=load_example_image('test_images/women.jpg'))
                edit_prompt = gr.Textbox(
                    label="Prompt",
                    value="She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes."
                )

            with gr.Column(scale=1):
                edit_image_output = gr.Image(label="Result")
                edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)

        with gr.Row():
            edit_show_thinking = gr.Checkbox(label="Thinking", value=False)

        # Add hyperparameter controls in an accordion
        with gr.Accordion("Inference Hyperparameters", open=False):
            with gr.Group():
                with gr.Row():
                    edit_seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, interactive=True,
                                          label="Seed", info="0 for random seed, positive for reproducible results")
                    edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
                                                    label="CFG Text Scale", info="Controls how strongly the model follows the text prompt")

                with gr.Row():
                    edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, interactive=True,
                                                   label="CFG Image Scale", info="Controls how much the model preserves input image details")
                    edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                                  label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")

                with gr.Row():
                    edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
                                                       value="text_channel", label="CFG Renorm Type",
                                                       info="If the genrated image is blurry, use 'global")
                    edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                                    label="CFG Renorm Min", info="1.0 disables CFG-Renorm")

                with gr.Row():
                    edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
                                                   label="Timesteps", info="Total denoising steps")
                    edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, interactive=True,
                                                    label="Timestep Shift", info="Higher values for layout, lower for details")


            # Thinking parameters in a single row
            edit_thinking_params = gr.Group(visible=False)
            with edit_thinking_params:
                with gr.Row():
                    edit_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
                    edit_max_think_token_n = gr.Slider(minimum=64, maximum=4006, value=1024, step=64, interactive=True,
                                                       label="Max Think Tokens", info="Maximum number of tokens for thinking")
                    edit_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
                                                      label="Temperature", info="Controls randomness in text generation")

        edit_btn = gr.Button("Submit")

        # Dynamically show/hide thinking process box for editing
        def update_edit_thinking_visibility(show):
            return gr.update(visible=show), gr.update(visible=show)

        edit_show_thinking.change(
            fn=update_edit_thinking_visibility,
            inputs=[edit_show_thinking],
            outputs=[edit_thinking_output, edit_thinking_params]
        )

        edit_btn.click(
            fn=edit_image,
            inputs=[
                edit_image_input, edit_prompt, edit_show_thinking,
                edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
                edit_timestep_shift, edit_num_timesteps,
                edit_cfg_renorm_min, edit_cfg_renorm_type,
                edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed
            ],
            outputs=[edit_image_output, edit_thinking_output]
        )

    with gr.Tab("🖼️ Image Understanding"):
        with gr.Row():
            with gr.Column(scale=1):
                img_input = gr.Image(label="Input Image", value=load_example_image('test_images/meme.jpg'))
                understand_prompt = gr.Textbox(
                    label="Prompt",
                    value="Can someone explain what's funny about this meme??"
                )

            with gr.Column(scale=1):
                txt_output = gr.Textbox(label="Result", lines=20)

        with gr.Row():
            understand_show_thinking = gr.Checkbox(label="Thinking", value=False)

        # Add hyperparameter controls in an accordion
        with gr.Accordion("Inference Hyperparameters", open=False):
            with gr.Row():
                understand_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
                understand_text_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True,
                                                        label="Temperature", info="Controls randomness in text generation (0=deterministic, 1=creative)")
                understand_max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=512, step=64, interactive=True,
                                                      label="Max New Tokens", info="Maximum length of generated text, including potential thinking")

        img_understand_btn = gr.Button("Submit")

        img_understand_btn.click(
            fn=image_understanding,
            inputs=[
                img_input, understand_prompt, understand_show_thinking,
                understand_do_sample, understand_text_temperature, understand_max_new_tokens
            ],
            outputs=txt_output
        )

    gr.Markdown("""
<div style="display: flex; justify-content: flex-start; flex-wrap: wrap; gap: 10px;">
  <a href="https://bagel-ai.org/">
    <img
      src="https://img.shields.io/badge/BAGEL-Website-0A66C2?logo=safari&logoColor=white"
      alt="BAGEL Website"
    />
  </a>
  <a href="https://arxiv.org/abs/2505.14683">
    <img
      src="https://img.shields.io/badge/BAGEL-Paper-red?logo=arxiv&logoColor=red"
      alt="BAGEL Paper on arXiv"
    />
  </a>
  <a href="https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT">
    <img
      src="https://img.shields.io/badge/BAGEL-Hugging%20Face-orange?logo=huggingface&logoColor=yellow"
      alt="BAGEL on Hugging Face"
    />
  </a>
  <a href="https://demo.bagel-ai.org/">
    <img
      src="https://img.shields.io/badge/BAGEL-Demo-blue?logo=googleplay&logoColor=blue"
      alt="BAGEL Demo"
    />
  </a>
  <a href="https://discord.gg/Z836xxzy">
    <img
      src="https://img.shields.io/badge/BAGEL-Discord-5865F2?logo=discord&logoColor=purple"
      alt="BAGEL Discord"
    />
  </a>
  <a href="mailto:[email protected]">
    <img
      src="https://img.shields.io/badge/BAGEL-Email-D14836?logo=gmail&logoColor=red"
      alt="BAGEL Email"
    />
  </a>
</div>
""")

demo.launch()