Spaces: ByteDance-Seed / Running on Zero

likunchang committed
Commit 8f54436 · 1 Parent(s): 468111e

Files changed (1):
  1. app.py +245 -52
app.py CHANGED
@@ -1,30 +1,263 @@
+import spaces
 import gradio as gr
+import numpy as np
+import os
+import torch
+import random
+import subprocess
+subprocess.run(
+    "pip install flash-attn --no-build-isolation",
+    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+    shell=True,
+)
+
+from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
 from PIL import Image
 
+from data.data_utils import add_special_tokens, pil_img2rgb
+from data.transforms import ImageTransform
+from inferencer import InterleaveInferencer
+from modeling.autoencoder import load_ae
+from modeling.bagel import (
+    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
+    SiglipVisionConfig, SiglipVisionModel
+)
+from modeling.qwen2 import Qwen2Tokenizer
+
+from huggingface_hub import snapshot_download
+
+save_dir = "./model_weights"
+repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
+cache_dir = save_dir + "/cache"
+
+snapshot_download(
+    cache_dir=cache_dir,
+    local_dir=save_dir,
+    repo_id=repo_id,
+    local_dir_use_symlinks=False,
+    resume_download=True,
+    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
+)
+
+# Model Initialization
+model_path = save_dir
+
+llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
+llm_config.qk_norm = True
+llm_config.tie_word_embeddings = False
+llm_config.layer_module = "Qwen2MoTDecoderLayer"
+
+vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
+vit_config.rope = False
+vit_config.num_hidden_layers -= 1
+
+vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
+
+config = BagelConfig(
+    visual_gen=True,
+    visual_und=True,
+    llm_config=llm_config,
+    vit_config=vit_config,
+    vae_config=vae_config,
+    vit_max_num_patch_per_side=70,
+    connector_act='gelu_pytorch_tanh',
+    latent_patch_size=2,
+    max_latent_size=64,
+)
+
+with init_empty_weights():
+    language_model = Qwen2ForCausalLM(llm_config)
+    vit_model = SiglipVisionModel(vit_config)
+    model = Bagel(language_model, vit_model, config)
+    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
+
+tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
+tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
+
+vae_transform = ImageTransform(1024, 512, 16)
+vit_transform = ImageTransform(980, 224, 14)
+
+# Model Loading and Multi-GPU Inference Preparing
+device_map = infer_auto_device_map(
+    model,
+    max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
+    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
+)
+
+same_device_modules = [
+    'language_model.model.embed_tokens',
+    'time_embedder',
+    'latent_pos_embed',
+    'vae2llm',
+    'llm2vae',
+    'connector',
+    'vit_pos_embed'
+]
+
+if torch.cuda.device_count() == 1:
+    first_device = device_map.get(same_device_modules[0], "cuda:0")
+    for k in same_device_modules:
+        if k in device_map:
+            device_map[k] = first_device
+        else:
+            device_map[k] = "cuda:0"
+else:
+    first_device = device_map.get(same_device_modules[0])
+    for k in same_device_modules:
+        if k in device_map:
+            device_map[k] = first_device
+
+model = load_checkpoint_and_dispatch(
+    model,
+    checkpoint=os.path.join(model_path, "ema.safetensors"),
+    device_map=device_map,
+    offload_buffers=True,
+    offload_folder="offload",
+    dtype=torch.bfloat16,
+    force_hooks=True,
+).eval()
+
+
+# Inferencer Preparing
+inferencer = InterleaveInferencer(
+    model=model,
+    vae_model=vae_model,
+    tokenizer=tokenizer,
+    vae_transform=vae_transform,
+    vit_transform=vit_transform,
+    new_token_ids=new_token_ids,
+)
+
+def set_seed(seed):
+    """Set random seeds for reproducibility"""
+    if seed > 0:
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
+            torch.cuda.manual_seed_all(seed)
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+    return seed
+
 # Text to Image function with thinking option and hyperparameters
+@spaces.GPU(duration=90)
 def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
                   timestep_shift=3.0, num_timesteps=50,
                   cfg_renorm_min=1.0, cfg_renorm_type="global",
                   max_think_token_n=1024, do_sample=False, text_temperature=0.3,
                   seed=0, image_ratio="1:1"):
+    # Set seed for reproducibility
+    set_seed(seed)
+
+    if image_ratio == "1:1":
+        image_shapes = (1024, 1024)
+    elif image_ratio == "4:3":
+        image_shapes = (768, 1024)
+    elif image_ratio == "3:4":
+        image_shapes = (1024, 768)
+    elif image_ratio == "16:9":
+        image_shapes = (576, 1024)
+    elif image_ratio == "9:16":
+        image_shapes = (1024, 576)
+
+    # Set hyperparameters
+    inference_hyper = dict(
+        max_think_token_n=max_think_token_n if show_thinking else 1024,
+        do_sample=do_sample if show_thinking else False,
+        text_temperature=text_temperature if show_thinking else 0.3,
+        cfg_text_scale=cfg_text_scale,
+        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
+        timestep_shift=timestep_shift,
+        num_timesteps=num_timesteps,
+        cfg_renorm_min=cfg_renorm_min,
+        cfg_renorm_type=cfg_renorm_type,
+        image_shapes=image_shapes,
+    )
 
-    yield None, None
+    result = {"text": "", "image": None}
+    # Call inferencer with or without think parameter based on user choice
+    for i in inferencer(text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
+        if type(i) == str:
+            result["text"] += i
+        else:
+            result["image"] = i
+
+    yield result["image"], result.get("text", None)
 
 
 # Image Understanding function with thinking option and hyperparameters
+@spaces.GPU(duration=90)
 def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
                         do_sample=False, text_temperature=0.3, max_new_tokens=512):
-    yield None
+    if image is None:
+        return "Please upload an image."
+
+    if isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    image = pil_img2rgb(image)
+
+    # Set hyperparameters
+    inference_hyper = dict(
+        do_sample=do_sample,
+        text_temperature=text_temperature,
+        max_think_token_n=max_new_tokens,  # Set max_length
+    )
+
+    result = {"text": "", "image": None}
+    # Use show_thinking parameter to control thinking process
+    for i in inferencer(image=image, text=prompt, think=show_thinking,
+                        understanding_output=True, **inference_hyper):
+        if type(i) == str:
+            result["text"] += i
+        else:
+            result["image"] = i
+    yield result["text"]
 
 
 # Image Editing function with thinking option and hyperparameters
+@spaces.GPU(duration=90)
 def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
                cfg_img_scale=2.0, cfg_interval=0.0,
                timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
                cfg_renorm_type="text_channel", max_think_token_n=1024,
                do_sample=False, text_temperature=0.3, seed=0):
-
-    yield (image, image), None
+    # Set seed for reproducibility
+    set_seed(seed)
+
+    if image is None:
+        return "Please upload an image.", ""
+
+    if isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    image = pil_img2rgb(image)
+
+    # Set hyperparameters
+    inference_hyper = dict(
+        max_think_token_n=max_think_token_n if show_thinking else 1024,
+        do_sample=do_sample if show_thinking else False,
+        text_temperature=text_temperature if show_thinking else 0.3,
+        cfg_text_scale=cfg_text_scale,
+        cfg_img_scale=cfg_img_scale,
+        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
+        timestep_shift=timestep_shift,
+        num_timesteps=num_timesteps,
+        cfg_renorm_min=cfg_renorm_min,
+        cfg_renorm_type=cfg_renorm_type,
+    )
+
+    # Include thinking parameter based on user choice
+    result = {"text": "", "image": None}
+    for i in inferencer(image=image, text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
+        if type(i) == str:
+            result["text"] += i
+        else:
+            result["image"] = i
+
+    yield result["image"], result.get("text", "")
 
 # Helper function to load example images
 def load_example_image(image_path):
@@ -34,13 +267,10 @@ def load_example_image(image_path):
         print(f"Error loading example image: {e}")
         return None
 
+
 # Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("""
-    <div>
-      <img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="380"/>
-    </div>
-    """)
+    gr.Markdown("# 🥯 [BAGEL](https://bagel-ai.org/)")
 
     with gr.Tab("📝 Text to Image"):
         txt_input = gr.Textbox(
@@ -127,7 +357,7 @@ with gr.Blocks() as demo:
                 )
 
             with gr.Column(scale=1):
-                edit_image_output = gr.ImageSlider(label="Result")
+                edit_image_output = gr.Image(label="Result")
                 edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)
 
         with gr.Row():
@@ -233,45 +463,8 @@ with gr.Blocks() as demo:
            outputs=txt_output
        )
 
-    gr.Markdown("""
-    <div style="display: flex; justify-content: flex-start; flex-wrap: wrap; gap: 10px;">
-      <a href="https://bagel-ai.org/">
-        <img
-          src="https://img.shields.io/badge/BAGEL-Website-0A66C2?logo=safari&logoColor=white"
-          alt="BAGEL Website"
-        />
-      </a>
-      <a href="https://arxiv.org/abs/2505.14683">
-        <img
-          src="https://img.shields.io/badge/BAGEL-Paper-red?logo=arxiv&logoColor=red"
-          alt="BAGEL Paper on arXiv"
-        />
-      </a>
-      <a href="https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT">
-        <img
-          src="https://img.shields.io/badge/BAGEL-Hugging%20Face-orange?logo=huggingface&logoColor=yellow"
-          alt="BAGEL on Hugging Face"
-        />
-      </a>
-      <a href="https://demo.bagel-ai.org/">
-        <img
-          src="https://img.shields.io/badge/BAGEL-Demo-blue?logo=googleplay&logoColor=blue"
-          alt="BAGEL Demo"
-        />
-      </a>
-      <a href="https://discord.gg/Z836xxzy">
-        <img
-          src="https://img.shields.io/badge/BAGEL-Discord-5865F2?logo=discord&logoColor=purple"
-          alt="BAGEL Discord"
-        />
-      </a>
-      <a href="mailto:[email protected]">
-        <img
-          src="https://img.shields.io/badge/BAGEL-Email-D14836?logo=gmail&logoColor=red"
-          alt="BAGEL Email"
-        />
-      </a>
-    </div>
-    """)
-
-    demo.launch()
+    gr.Markdown(
+        "🌐[Website](https://bagel-ai.org/)&nbsp;&nbsp;📄[Report](https://arxiv.org/abs/2505.14683)&nbsp;&nbsp;🤗[Model](https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT)&nbsp;&nbsp;🚀[Demo](https://demo.bagel-ai.org/)&nbsp;&nbsp;💬[Discord](https://discord.gg/Z836xxzy)&nbsp;&nbsp;📧[Contact](mailto:[email protected])"
+    )
+
+demo.launch(share=True)