KingNish committed on
Commit 56bc2e4 · verified · 1 Parent(s): 71c41c2

Update app.py

Files changed (1):
  1. app.py (+363 -362)

app.py CHANGED
@@ -13,7 +13,6 @@ subprocess.run(
 
  from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
  from PIL import Image
- import uuid
 
  from data.data_utils import add_special_tokens, pil_img2rgb
  from data.transforms import ImageTransform
@@ -32,19 +31,16 @@ save_dir = "./model"
  repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
  cache_dir = save_dir + "/cache"
 
- if not os.path.exists(os.path.join(save_dir, "ema.safetensors")):
-     print(f"Downloading model from {repo_id} to {save_dir}")
-     snapshot_download(cache_dir=cache_dir,
-                       local_dir=save_dir,
-                       repo_id=repo_id,
-                       local_dir_use_symlinks=False,
-                       resume_download=True,
-                       allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
-                       )
- else:
-     print(f"Model found at {save_dir}")
 
- model_path = "./model"
 
  llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
  llm_config.qk_norm = True
@@ -60,7 +56,7 @@ vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetens
  config = BagelConfig(
      visual_gen=True,
      visual_und=True,
-     llm_config=llm_config,
      vit_config=vit_config,
      vae_config=vae_config,
      vit_max_num_patch_per_side=70,
@@ -81,6 +77,7 @@ tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
  vae_transform = ImageTransform(1024, 512, 16)
  vit_transform = ImageTransform(980, 224, 14)
 
  device_map = infer_auto_device_map(
      model,
      max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
@@ -100,20 +97,16 @@ same_device_modules = [
  if torch.cuda.device_count() == 1:
      first_device = device_map.get(same_device_modules[0], "cuda:0")
      for k in same_device_modules:
-         device_map[k] = first_device
  else:
-     # Ensure all same_device_modules are on the same device if they exist in device_map
-     # Find the device for the first module in the list that is actually in the device_map
-     first_assigned_device = None
-     for k_module in same_device_modules:
-         if k_module in device_map:
-             first_assigned_device = device_map[k_module]
-             break
-     if first_assigned_device is not None:
-         for k_module in same_device_modules:
-             if k_module in device_map:  # Only assign if the module is part of the device_map
-                 device_map[k_module] = first_assigned_device
- 
  model = load_checkpoint_and_dispatch(
      model,
      checkpoint=os.path.join(model_path, "ema.safetensors"),
@@ -123,6 +116,8 @@ model = load_checkpoint_and_dispatch(
      force_hooks=True,
  ).eval()
 
  inferencer = InterleaveInferencer(
      model=model,
      vae_model=vae_model,
@@ -133,7 +128,8 @@ inferencer = InterleaveInferencer(
  )
 
  def set_seed(seed):
-     if seed is not None and seed > 0:
          random.seed(seed)
          np.random.seed(seed)
          torch.manual_seed(seed)
@@ -144,384 +140,389 @@ def set_seed(seed):
          torch.backends.cudnn.benchmark = False
      return seed
 
- # --- Backend Functions (Adapted from original app.py) ---
- @spaces.GPU(duration=90)
- def call_text_to_image(prompt, show_thinking, cfg_text_scale, cfg_interval,
-                        timestep_shift, num_timesteps, cfg_renorm_min, cfg_renorm_type,
-                        max_think_token_n, do_sample, text_temperature, seed, image_ratio):
      set_seed(seed)
-     image_shapes = (1024, 1024)
-     if image_ratio == "4:3": image_shapes = (768, 1024)
-     elif image_ratio == "3:4": image_shapes = (1024, 768)
-     elif image_ratio == "16:9": image_shapes = (576, 1024)
-     elif image_ratio == "9:16": image_shapes = (1024, 576)
 
      inference_hyper = dict(
          max_think_token_n=max_think_token_n if show_thinking else 1024,
          do_sample=do_sample if show_thinking else False,
          text_temperature=text_temperature if show_thinking else 0.3,
          cfg_text_scale=cfg_text_scale,
-         cfg_interval=[cfg_interval, 1.0],
          timestep_shift=timestep_shift,
          num_timesteps=num_timesteps,
          cfg_renorm_min=cfg_renorm_min,
          cfg_renorm_type=cfg_renorm_type,
          image_shapes=image_shapes,
      )
      result = inferencer(text=prompt, think=show_thinking, **inference_hyper)
-     return result.get("image", None), result.get("text", None)  # text is thinking
 
- @spaces.GPU(duration=90)
- def call_image_understanding(image, prompt, show_thinking, do_sample, text_temperature, max_new_tokens, seed):
-     set_seed(seed)
-     if image is None: return "Please upload an image.", None
-     if isinstance(image, np.ndarray): image = Image.fromarray(image)
      image = pil_img2rgb(image)
 
      inference_hyper = dict(
          do_sample=do_sample,
          text_temperature=text_temperature,
-         max_think_token_n=max_new_tokens,
      )
-     result = inferencer(image=image, text=prompt, think=show_thinking, understanding_output=True, **inference_hyper)
-     return result.get("text", None), None  # Main output is text, thinking is part of it if show_thinking=True
- 
- @spaces.GPU(duration=90)
- def call_edit_image(image, prompt, show_thinking, cfg_text_scale, cfg_img_scale, cfg_interval,
-                     timestep_shift, num_timesteps, cfg_renorm_min, cfg_renorm_type,
-                     max_think_token_n, do_sample, text_temperature, seed):
      set_seed(seed)
-     if image is None: return "Please upload an image.", None, None
-     if isinstance(image, np.ndarray): image = Image.fromarray(image)
-     image = pil_img2rgb(image)
 
      inference_hyper = dict(
          max_think_token_n=max_think_token_n if show_thinking else 1024,
          do_sample=do_sample if show_thinking else False,
          text_temperature=text_temperature if show_thinking else 0.3,
          cfg_text_scale=cfg_text_scale,
          cfg_img_scale=cfg_img_scale,
-         cfg_interval=[cfg_interval, 1.0],
          timestep_shift=timestep_shift,
          num_timesteps=num_timesteps,
          cfg_renorm_min=cfg_renorm_min,
          cfg_renorm_type=cfg_renorm_type,
      )
-     result = inferencer(image=image, text=prompt, think=show_thinking, **inference_hyper)
-     return result.get("image", None), result.get("text", None)  # text is thinking
- 
- # --- Gradio UI ---
- 
- DEFAULT_WELCOME_MESSAGE = {
-     "role": "assistant",
-     "content": "Hello! I am BAGEL, your multimodal assistant. How can I help you today? Select a mode and enter your prompt.",
-     "key": "welcome"
- }
- 
- class GradioApp:
-     def __init__(self):
-         self.current_conversation_id = None
-         self.conversation_contexts = {}
-         self.conversations_list = []  # For the sidebar
- 
-     def _get_current_history(self):
-         if self.current_conversation_id and self.current_conversation_id in self.conversation_contexts:
-             return self.conversation_contexts[self.current_conversation_id]["history"]
-         return []
- 
-     def _get_current_settings(self):
-         if self.current_conversation_id and self.current_conversation_id in self.conversation_contexts:
-             return self.conversation_contexts[self.current_conversation_id].get("settings", {})
-         return {}
 
-     def _update_conversation_list_ui(self):
-         return gr.update(choices=[(c['label'], c['key']) for c in self.conversations_list], value=self.current_conversation_id)
- 
-     def add_message(self, text_input, image_input, mode,
-                     # TTI params
-                     tti_show_thinking, tti_cfg_text_scale, tti_cfg_interval, tti_timestep_shift, tti_num_timesteps, tti_cfg_renorm_min, tti_cfg_renorm_type, tti_max_think_token_n, tti_do_sample, tti_text_temperature, tti_seed, tti_image_ratio,
-                     # Edit params
-                     edit_show_thinking, edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval, edit_timestep_shift, edit_num_timesteps, edit_cfg_renorm_min, edit_cfg_renorm_type, edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed,
-                     # Understand params
-                     und_show_thinking, und_do_sample, und_text_temperature, und_max_new_tokens, und_seed
-                     ):
-         if not text_input and not (mode in ["Image Edit", "Image Understanding"] and image_input):
-             gr.Warning("Please enter a prompt or upload an image for Edit/Understanding modes.")
-             # Need to yield original state for all outputs if we return early
-             # This part is tricky with dynamic outputs, might need a dummy update for all
-             # For simplicity, let's assume user always provides some input
-             # A better way is to disable submit button if input is invalid
-             return self._get_current_history(), gr.update(value=None), gr.update(value=None)  # chatbot, text_input, image_input
- 
-         if not self.current_conversation_id:
-             self.new_chat_session(text_input[:30] if text_input else "New Chat")  # Create a new chat if none exists
- 
-         history = self._get_current_history()
- 
-         # Store settings for this turn
-         # This is simplified; best-gradio-ui.py stores settings per conversation
-         current_turn_settings = {
-             "mode": mode,
-             # Store PIL image directly if needed, or handle path carefully
-             "image_input": image_input,  # Now storing the PIL image or None
-             # TTI
-             "tti_show_thinking": tti_show_thinking, "tti_cfg_text_scale": tti_cfg_text_scale, "tti_cfg_interval": tti_cfg_interval, "tti_timestep_shift": tti_timestep_shift, "tti_num_timesteps": tti_num_timesteps, "tti_cfg_renorm_min": tti_cfg_renorm_min, "tti_cfg_renorm_type": tti_cfg_renorm_type, "tti_max_think_token_n": tti_max_think_token_n, "tti_do_sample": tti_do_sample, "tti_text_temperature": tti_text_temperature, "tti_seed": tti_seed, "tti_image_ratio": tti_image_ratio,
-             # Edit
-             "edit_show_thinking": edit_show_thinking, "edit_cfg_text_scale": edit_cfg_text_scale, "edit_cfg_img_scale": edit_cfg_img_scale, "edit_cfg_interval": edit_cfg_interval, "edit_timestep_shift": edit_timestep_shift, "edit_num_timesteps": edit_num_timesteps, "edit_cfg_renorm_min": edit_cfg_renorm_min, "edit_cfg_renorm_type": edit_cfg_renorm_type, "edit_max_think_token_n": edit_max_think_token_n, "edit_do_sample": edit_do_sample, "edit_text_temperature": edit_text_temperature, "edit_seed": edit_seed,
-             # Understand
-             "und_show_thinking": und_show_thinking, "und_do_sample": und_do_sample, "und_text_temperature": und_text_temperature, "und_max_new_tokens": und_max_new_tokens, "und_seed": und_seed
-         }
-         self.conversation_contexts[self.current_conversation_id]["settings"] = current_turn_settings
- 
-         user_content_list = []
-         if text_input:
-             user_content_list.append({"type": "text", "text": text_input})
-         if image_input and mode in ["Image Edit", "Image Understanding"]:
-             # For 'messages' format, images are typically handled by passing them as part of a list of content dicts.
-             # Gradio's Chatbot with type='messages' can render PIL Images or file paths directly in the 'content' list.
-             user_content_list.append({"type": "image", "image": image_input})  # Assuming image_input is PIL
- 
-         # Construct the user message for history
-         # If only text, content can be a string. If mixed, it's a list of dicts.
-         user_message_for_history = {
-             "role": "user",
-             "content": text_input if not image_input else user_content_list,
-             "key": str(uuid.uuid4())
-         }
-         if not text_input and image_input:
-             user_message_for_history["content"] = user_content_list
-         elif not user_content_list:
-             # Handle case where there's no input at all, though prior checks should prevent this.
-             gr.Warning("No input provided.")
-             return self._get_current_history(), gr.update(value=None), gr.update(value=None)
- 
-         history.append(user_message_for_history)
-         history.append({"role": "assistant", "content": "Processing...", "key": str(uuid.uuid4())})
- 
-         yield history, gr.update(value=None), gr.update(value=None)  # chatbot, text_input, image_input (clear inputs)
 
-         # Call backend
-         try:
-             output_image = None
-             output_text = None
-             thinking_text = None
- 
-             # image_input is already a PIL image from the gr.Image component with type="pil"
-             pil_image_input = image_input
- 
-             if mode == "Text to Image":
-                 output_image, thinking_text = call_text_to_image(text_input, tti_show_thinking, tti_cfg_text_scale, tti_cfg_interval, tti_timestep_shift, tti_num_timesteps, tti_cfg_renorm_min, tti_cfg_renorm_type, tti_max_think_token_n, tti_do_sample, tti_text_temperature, tti_seed, tti_image_ratio)
-             elif mode == "Image Edit":
-                 if not pil_image_input:
-                     output_text = "Error: Image required for Image Edit mode."
-                 else:
-                     output_image, thinking_text = call_edit_image(pil_image_input, text_input, edit_show_thinking, edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval, edit_timestep_shift, edit_num_timesteps, edit_cfg_renorm_min, edit_cfg_renorm_type, edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed)
-             elif mode == "Image Understanding":
-                 if not pil_image_input:
-                     output_text = "Error: Image required for Image Understanding mode."
-                 else:
-                     output_text, _ = call_image_understanding(pil_image_input, text_input, und_show_thinking, und_do_sample, und_text_temperature, und_max_new_tokens, und_seed)
-                     # For VLM, the main output is text, thinking might be part of it or not separately returned
-                     # depending on `inferencer`'s behavior with `understanding_output=True`
-                     if und_show_thinking and output_text and "Thinking:" in output_text:  # crude check
-                         parts = output_text.split("Thinking:", 1)
-                         if len(parts) > 1:
-                             thinking_text = "Thinking:" + parts[1].split("\nAnswer:")[0] if "\nAnswer:" in parts[1] else parts[1]
-                             output_text = parts[0].strip() + ("\nAnswer:" + output_text.split("\nAnswer:")[1] if "\nAnswer:" in output_text else "")
-                         else:
-                             thinking_text = None  # Or handle as part of main output_text
- 
-             bot_response_content = []
-             if thinking_text:
-                 # For 'messages' type, each part of the content is a dict in a list
-                 bot_response_content.append({"type": "text", "text": f"**Thinking Process:**\n{thinking_text}"})
-             if output_text:
-                 bot_response_content.append({"type": "text", "text": output_text})
-             if output_image:  # output_image should be a PIL Image
-                 bot_response_content.append({"type": "image", "image": output_image})
- 
-             if not bot_response_content:
-                 bot_response_content.append({"type": "text", "text": "(No output generated)"})
- 
-             # Update the last message (which was "Processing...")
-             history[-1]["content"] = bot_response_content[0]["text"] if len(bot_response_content) == 1 and bot_response_content[0]["type"] == "text" else bot_response_content
- 
-         except Exception as e:
-             print(f"Error during processing: {e}")
-             history[-1]["content"] = [{"type": "text", "text": f"Error: {str(e)}"}]
-             history[-1]["loading"] = False
-             raise gr.Error(f"Processing Error: {str(e)}")
- 
-         yield history, gr.update(value=None), gr.update(value=None)
- 
-     def new_chat_session(self, label="New Chat"):
-         session_id = str(uuid.uuid4())
-         self.current_conversation_id = session_id
-         self.conversation_contexts[session_id] = {
-             "history": [DEFAULT_WELCOME_MESSAGE.copy()],
-             "settings": {}  # Initialize with default settings if any
-         }
-         # Ensure label is unique if needed, or just use the provided one
-         # For simplicity, we allow duplicate labels for now.
-         new_conv_entry = {"label": label if label else f"Chat {len(self.conversations_list) + 1}", "key": session_id}
-         self.conversations_list.insert(0, new_conv_entry)  # Add to top
-         return self._get_current_history(), self._update_conversation_list_ui()
- 
-     def change_chat_session(self, session_id):
-         if session_id and session_id in self.conversation_contexts:
-             self.current_conversation_id = session_id
-             # Potentially update hyperparameter UI elements based on loaded session_settings
-             # For now, just load history
-             return self._get_current_history()
-         return self._get_current_history()  # No change or invalid ID
- 
-     def clear_history(self):
-         if self.current_conversation_id:
-             self.conversation_contexts[self.current_conversation_id]["history"] = [DEFAULT_WELCOME_MESSAGE.copy()]
-             # Also clear current inputs if desired
-             return self._get_current_history(), gr.update(value=None), gr.update(value=None)
-         return [], gr.update(value=None), gr.update(value=None)
- 
-     def build_ui(self):
-         with gr.Blocks(theme=gr.themes.Soft()) as demo:
-             gr.Markdown("""
              <div>
              <img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="380"/>
-             <h1>Unified BAGEL Chat Interface</h1>
              </div>
              """)
- 
-             with gr.Row():
-                 with gr.Column(scale=1):
-                     gr.Markdown("### Conversations")
-                     conversation_selector = gr.Radio(
-                         label="Select Chat",
-                         choices=[],
-                         type="value"
-                     )
-                     new_chat_btn = gr.Button("➕ New Chat")
- 
-                     gr.Markdown("### Operation Mode")
-                     mode_selector = gr.Radio(
-                         label="Select Mode",
-                         choices=["Text to Image", "Image Edit", "Image Understanding"],
-                         value="Text to Image",
-                         interactive=True
-                     )
 
-                     # --- Hyperparameter Accordions ---
-                     # Visibility will be controlled by mode_selector
-                     with gr.Accordion("Text to Image Settings", open=True, visible=True) as tti_accordion:
-                         tti_show_thinking_cb = gr.Checkbox(label="Show Thinking Process", value=False, interactive=True)
-                         tti_seed_slider = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, label="Seed (0 for random)", interactive=True)
-                         tti_image_ratio_dd = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"], value="1:1", label="Image Ratio", interactive=True)
-                         tti_cfg_text_scale_slider = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, label="CFG Text Scale", interactive=True)
-                         tti_cfg_interval_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1, label="CFG Interval Start", interactive=True)
-                         tti_cfg_renorm_type_dd = gr.Dropdown(choices=["global", "local", "text_channel"], value="global", label="CFG Renorm Type", interactive=True)
-                         tti_cfg_renorm_min_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, label="CFG Renorm Min", interactive=True)
-                         tti_num_timesteps_slider = gr.Slider(minimum=10, maximum=100, value=50, step=5, label="Timesteps", interactive=True)
-                         tti_timestep_shift_slider = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, label="Timestep Shift", interactive=True)
-                         with gr.Group(visible=False) as tti_thinking_params_group:
-                             tti_do_sample_cb = gr.Checkbox(label="Sampling (for thinking)", value=False, interactive=True)
-                             tti_max_think_token_slider = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Think Tokens", interactive=True)
-                             tti_text_temp_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, label="Temperature (for thinking)", interactive=True)
-                         tti_show_thinking_cb.change(lambda x: gr.update(visible=x), inputs=[tti_show_thinking_cb], outputs=[tti_thinking_params_group])
- 
-                     with gr.Accordion("Image Edit Settings", open=False, visible=False) as edit_accordion:
-                         edit_show_thinking_cb = gr.Checkbox(label="Show Thinking Process", value=False, interactive=True)
-                         edit_seed_slider = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, label="Seed (0 for random)", interactive=True)
-                         edit_cfg_text_scale_slider = gr.Slider(1.0, 8.0, value=4.0, step=0.1, label="CFG Text Scale", interactive=True)
-                         edit_cfg_img_scale_slider = gr.Slider(1.0, 4.0, value=2.0, step=0.1, label="CFG Image Scale", interactive=True)
-                         edit_cfg_interval_slider = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="CFG Interval Start", interactive=True)
-                         edit_cfg_renorm_type_dd = gr.Dropdown(["global", "local", "text_channel"], value="text_channel", label="CFG Renorm Type", interactive=True)
-                         edit_cfg_renorm_min_slider = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="CFG Renorm Min", interactive=True)
-                         edit_num_timesteps_slider = gr.Slider(10, 100, value=50, step=5, label="Timesteps", interactive=True)
-                         edit_timestep_shift_slider = gr.Slider(1.0, 10.0, value=3.0, step=0.5, label="Timestep Shift", interactive=True)
-                         with gr.Group(visible=False) as edit_thinking_params_group:
-                             edit_do_sample_cb = gr.Checkbox(label="Sampling (for thinking)", value=False, interactive=True)
-                             edit_max_think_token_slider = gr.Slider(64, 4096, value=1024, step=64, label="Max Think Tokens", interactive=True)
-                             edit_text_temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (for thinking)", interactive=True)
-                         edit_show_thinking_cb.change(lambda x: gr.update(visible=x), inputs=[edit_show_thinking_cb], outputs=[edit_thinking_params_group])
- 
-                     with gr.Accordion("Image Understanding Settings", open=False, visible=False) as und_accordion:
-                         und_show_thinking_cb = gr.Checkbox(label="Show Thinking Process (if applicable)", value=False, interactive=True)
-                         und_seed_slider = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, label="Seed (0 for random)", interactive=True)
-                         und_do_sample_cb = gr.Checkbox(label="Sampling", value=False, interactive=True)
-                         und_text_temp_slider = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature", interactive=True)
-                         und_max_new_tokens_slider = gr.Slider(32, 2048, value=512, step=32, label="Max New Tokens", interactive=True)
- 
-                     # Logic to show/hide accordions based on mode
-                     def update_accordion_visibility(mode):
-                         return (
-                             gr.update(visible=mode == "Text to Image"),
-                             gr.update(visible=mode == "Image Edit"),
-                             gr.update(visible=mode == "Image Understanding")
-                         )
-                     mode_selector.change(update_accordion_visibility, inputs=[mode_selector], outputs=[tti_accordion, edit_accordion, und_accordion])
- 
-                 with gr.Column(scale=3):
-                     chatbot_ui = gr.Chatbot(label="BAGEL Chat", value=[DEFAULT_WELCOME_MESSAGE.copy()], bubble_full_width=False, height=600)
-                     with gr.Row():
-                         image_upload_ui = gr.Image(type="pil", label="Upload Image (for Edit/Understand)", sources=['upload'], visible=False, interactive=True)
                      with gr.Row():
-                         text_input_ui = gr.Textbox(label="Enter your prompt here...", lines=3, scale=7, interactive=True)
-                         submit_btn = gr.Button("Send", variant="primary", scale=1)
-                         clear_btn = gr.Button("Clear Chat", scale=1)
- 
-             # Show/hide image upload based on mode
-             def update_image_upload_visibility(mode):
-                 return gr.update(visible=mode in ["Image Edit", "Image Understanding"])
-             mode_selector.change(update_image_upload_visibility, inputs=[mode_selector], outputs=[image_upload_ui])
- 
-             # Initial state setup
-             demo.load(lambda: self.new_chat_session("Welcome Chat"), outputs=[chatbot_ui, conversation_selector])
- 
-             # Event handlers
-             new_chat_btn.click(
-                 self.new_chat_session,
-                 inputs=None,
-                 outputs=[chatbot_ui, conversation_selector]
-             )
-             conversation_selector.change(
-                 self.change_chat_session,
-                 inputs=[conversation_selector],
-                 outputs=[chatbot_ui]
              )
- 
-             submit_btn.click(
-                 self.add_message,
-                 inputs=[
-                     text_input_ui, image_upload_ui, mode_selector,
-                     # TTI
-                     tti_show_thinking_cb, tti_cfg_text_scale_slider, tti_cfg_interval_slider, tti_timestep_shift_slider, tti_num_timesteps_slider, tti_cfg_renorm_min_slider, tti_cfg_renorm_type_dd, tti_max_think_token_slider, tti_do_sample_cb, tti_text_temp_slider, tti_seed_slider, tti_image_ratio_dd,
-                     # Edit
-                     edit_show_thinking_cb, edit_cfg_text_scale_slider, edit_cfg_img_scale_slider, edit_cfg_interval_slider, edit_timestep_shift_slider, edit_num_timesteps_slider, edit_cfg_renorm_min_slider, edit_cfg_renorm_type_dd, edit_max_think_token_slider, edit_do_sample_cb, edit_text_temp_slider, edit_seed_slider,
-                     # Understand
-                     und_show_thinking_cb, und_do_sample_cb, und_text_temp_slider, und_max_new_tokens_slider, und_seed_slider
-                 ],
-                 outputs=[chatbot_ui, text_input_ui, image_upload_ui]
              )
-             text_input_ui.submit(
-                 self.add_message,
-                 inputs=[
-                     text_input_ui, image_upload_ui, mode_selector,
-                     # TTI
-                     tti_show_thinking_cb, tti_cfg_text_scale_slider, tti_cfg_interval_slider, tti_timestep_shift_slider, tti_num_timesteps_slider, tti_cfg_renorm_min_slider, tti_cfg_renorm_type_dd, tti_max_think_token_slider, tti_do_sample_cb, tti_text_temp_slider, tti_seed_slider, tti_image_ratio_dd,
-                     # Edit
-                     edit_show_thinking_cb, edit_cfg_text_scale_slider, edit_cfg_img_scale_slider, edit_cfg_interval_slider, edit_timestep_shift_slider, edit_num_timesteps_slider, edit_cfg_renorm_min_slider, edit_cfg_renorm_type_dd, edit_max_think_token_slider, edit_do_sample_cb, edit_text_temp_slider, edit_seed_slider,
-                     # Understand
-                     und_show_thinking_cb, und_do_sample_cb, und_text_temp_slider, und_max_new_tokens_slider, und_seed_slider
-                 ],
-                 outputs=[chatbot_ui, text_input_ui, image_upload_ui]
              )
- 
-             clear_btn.click(self.clear_history, inputs=None, outputs=[chatbot_ui, text_input_ui, image_upload_ui])
 
-         return demo
 
- # Main execution
- if __name__ == "__main__":
-     app_instance = GradioApp()
-     demo_ui = app_instance.build_ui()
-     demo_ui.queue().launch(share=True, debug=True)  # Set share=True if you need a public link
 
  from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
  from PIL import Image
 
  from data.data_utils import add_special_tokens, pil_img2rgb
  from data.transforms import ImageTransform
 
  repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
  cache_dir = save_dir + "/cache"
 
+ snapshot_download(cache_dir=cache_dir,
+                   local_dir=save_dir,
+                   repo_id=repo_id,
+                   local_dir_use_symlinks=False,
+                   resume_download=True,
+                   allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
+                   )
+ 
+ # Model Initialization
+ model_path = "./model"  # Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT
 
  llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
  llm_config.qk_norm = True
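
A quick sanity check after the download (an illustrative sketch, not part of the commit; it assumes the `save_dir` defined earlier in app.py and the `ema.safetensors` checkpoint that the dispatch code further down expects):

    import os
    ckpt = os.path.join(save_dir, "ema.safetensors")  # checkpoint loaded later by load_checkpoint_and_dispatch
    assert os.path.exists(ckpt), f"expected checkpoint at {ckpt}"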
 
  config = BagelConfig(
      visual_gen=True,
      visual_und=True,
+     llm_config=llm_config,
      vit_config=vit_config,
      vae_config=vae_config,
      vit_max_num_patch_per_side=70,
 
  vae_transform = ImageTransform(1024, 512, 16)
  vit_transform = ImageTransform(980, 224, 14)
 
+ # Model Loading and Multi-GPU Inference Preparation
  device_map = infer_auto_device_map(
      model,
      max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
 
  if torch.cuda.device_count() == 1:
      first_device = device_map.get(same_device_modules[0], "cuda:0")
      for k in same_device_modules:
+         if k in device_map:
+             device_map[k] = first_device
+         else:
+             device_map[k] = "cuda:0"
  else:
+     first_device = device_map.get(same_device_modules[0])
+     for k in same_device_modules:
+         if k in device_map:
+             device_map[k] = first_device
 
  model = load_checkpoint_and_dispatch(
      model,
      checkpoint=os.path.join(model_path, "ema.safetensors"),
 
      force_hooks=True,
  ).eval()
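
To confirm that the tied modules really land on one device, a minimal debugging sketch (illustrative only; it assumes the `same_device_modules` list and the `device_map` built above):

    for name in same_device_modules:
        # Modules absent from the map were auto-assigned by accelerate
        print(name, "->", device_map.get(name, "<auto-assigned>"))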
 
+ # Inferencer Preparation
  inferencer = InterleaveInferencer(
      model=model,
      vae_model=vae_model,
 
  )
 
  def set_seed(seed):
+     """Set random seeds for reproducibility"""
+     if seed > 0:
          random.seed(seed)
          np.random.seed(seed)
          torch.manual_seed(seed)
 
          torch.backends.cudnn.benchmark = False
      return seed
 
+ # Text to Image function with thinking option and hyperparameters
+ def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
+                   timestep_shift=3.0, num_timesteps=50,
+                   cfg_renorm_min=1.0, cfg_renorm_type="global",
+                   max_think_token_n=1024, do_sample=False, text_temperature=0.3,
+                   seed=0, image_ratio="1:1"):
+     # Set seed for reproducibility
      set_seed(seed)
 
+     if image_ratio == "1:1":
+         image_shapes = (1024, 1024)
+     elif image_ratio == "4:3":
+         image_shapes = (768, 1024)
+     elif image_ratio == "3:4":
+         image_shapes = (1024, 768)
+     elif image_ratio == "16:9":
+         image_shapes = (576, 1024)
+     elif image_ratio == "9:16":
+         image_shapes = (1024, 576)
+ 
+     # Set hyperparameters
      inference_hyper = dict(
          max_think_token_n=max_think_token_n if show_thinking else 1024,
          do_sample=do_sample if show_thinking else False,
          text_temperature=text_temperature if show_thinking else 0.3,
          cfg_text_scale=cfg_text_scale,
+         cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
          timestep_shift=timestep_shift,
          num_timesteps=num_timesteps,
          cfg_renorm_min=cfg_renorm_min,
          cfg_renorm_type=cfg_renorm_type,
          image_shapes=image_shapes,
      )
+ 
+     # Call inferencer with or without think parameter based on user choice
      result = inferencer(text=prompt, think=show_thinking, **inference_hyper)
+     return result["image"], result.get("text", None)
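
A hypothetical smoke test for `text_to_image` (assumes the checkpoint loaded and a GPU is available; the prompt and filename are illustrative, not from the commit):

    image, thinking = text_to_image("a lighthouse at dawn, oil painting", seed=42)
    image.save("lighthouse.png")  # `thinking` stays None unless show_thinking=True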
 
+ # Image Understanding function with thinking option and hyperparameters
+ def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
+                         do_sample=False, text_temperature=0.3, max_new_tokens=512):
+     if image is None:
+         return "Please upload an image."
+ 
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+ 
      image = pil_img2rgb(image)
 
+     # Set hyperparameters
      inference_hyper = dict(
          do_sample=do_sample,
          text_temperature=text_temperature,
+         max_think_token_n=max_new_tokens,  # Set max_length
      )
+ 
+     # Use show_thinking parameter to control thinking process
+     result = inferencer(image=image, text=prompt, think=show_thinking,
+                         understanding_output=True, **inference_hyper)
+     return result["text"]
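
A hypothetical usage sketch for `image_understanding` (the file path matches the example image the UI below references; the prompt is illustrative):

    answer = image_understanding(Image.open("test_images/meme.jpg"), "Describe this image.")
    print(answer)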
 
+ # Image Editing function with thinking option and hyperparameters
+ def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
+                cfg_img_scale=2.0, cfg_interval=0.0,
+                timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
+                cfg_renorm_type="text_channel", max_think_token_n=1024,
+                do_sample=False, text_temperature=0.3, seed=0):
+     # Set seed for reproducibility
      set_seed(seed)
+ 
+     if image is None:
+         return "Please upload an image.", ""
+ 
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+ 
+     image = pil_img2rgb(image)
+ 
+     # Set hyperparameters
      inference_hyper = dict(
          max_think_token_n=max_think_token_n if show_thinking else 1024,
          do_sample=do_sample if show_thinking else False,
          text_temperature=text_temperature if show_thinking else 0.3,
          cfg_text_scale=cfg_text_scale,
          cfg_img_scale=cfg_img_scale,
+         cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
          timestep_shift=timestep_shift,
          num_timesteps=num_timesteps,
          cfg_renorm_min=cfg_renorm_min,
          cfg_renorm_type=cfg_renorm_type,
      )
 
+     # Include thinking parameter based on user choice
+     result = inferencer(image=image, text=prompt, think=show_thinking, **inference_hyper)
+     return result["image"], result.get("text", "")
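
A hypothetical usage sketch for `edit_image` (the input file matches the UI's default example image; the prompt is illustrative):

    source = Image.open("test_images/women.jpg")
    edited, thinking = edit_image(source, "Give her a red hat.", seed=7)
    edited.save("edited.png")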
 
+ # Helper function to load example images
+ def load_example_image(image_path):
+     try:
+         return Image.open(image_path)
+     except Exception as e:
+         print(f"Error loading example image: {e}")
+         return None
+ 
+ # Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("""
  <div>
  <img src="https://lf3-static.bytednsdoc.com/obj/eden-cn/nuhojubrps/banner.png" alt="BAGEL" width="380"/>
  </div>
  """)
+ 
+     with gr.Tab("📝 Text to Image"):
+         txt_input = gr.Textbox(
+             label="Prompt",
+             value="A female cosplayer portraying an ethereal fairy or elf, wearing a flowing dress made of delicate fabrics in soft, mystical colors like emerald green and silver. She has pointed ears, a gentle, enchanting expression, and her outfit is adorned with sparkling jewels and intricate patterns. The background is a magical forest with glowing plants, mystical creatures, and a serene atmosphere."
+         )
+ 
+         with gr.Row():
+             show_thinking = gr.Checkbox(label="Thinking", value=False)
+ 
+         # Add hyperparameter controls in an accordion
+         with gr.Accordion("Inference Hyperparameters", open=False):
+             # Layout: two parameters per row
+             with gr.Group():
+                 with gr.Row():
+                     seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1,
+                                      label="Seed", info="0 for random seed, positive for reproducible results")
+                     image_ratio = gr.Dropdown(choices=["1:1", "4:3", "3:4", "16:9", "9:16"],
+                                               value="1:1", label="Image Ratio",
+                                               info="The longer size is fixed to 1024")
+ 
+                 with gr.Row():
+                     cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
+                                                label="CFG Text Scale", info="Controls how strongly the model follows the text prompt (4.0-8.0)")
+                     cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
+                                              label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
+ 
+                 with gr.Row():
+                     cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
+                                                   value="global", label="CFG Renorm Type",
+                                                   info="If the generated image is blurry, use 'global'")
+                     cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
+                                                label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
+ 
+                 with gr.Row():
+                     num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
+                                               label="Timesteps", info="Total denoising steps")
+                     timestep_shift = gr.Slider(minimum=1.0, maximum=5.0, value=3.0, step=0.5, interactive=True,
+                                                label="Timestep Shift", info="Higher values for layout, lower for details")
+ 
+             # Thinking parameters in a single row
+             thinking_params = gr.Group(visible=False)
+             with thinking_params:
                  with gr.Row():
+                     do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
+                     max_think_token_n = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, interactive=True,
+                                                   label="Max Think Tokens", info="Maximum number of tokens for thinking")
+                     text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
+                                                  label="Temperature", info="Controls randomness in text generation")
+ 
+         thinking_output = gr.Textbox(label="Thinking Process", visible=False)
+         img_output = gr.Image(label="Generated Image")
+         gen_btn = gr.Button("Generate")
+ 
+         # Dynamically show/hide thinking process box and parameters
+         def update_thinking_visibility(show):
+             return gr.update(visible=show), gr.update(visible=show)
+ 
+         show_thinking.change(
+             fn=update_thinking_visibility,
+             inputs=[show_thinking],
+             outputs=[thinking_output, thinking_params]
+         )
+ 
+         # Process function based on thinking option and hyperparameters
+         @spaces.GPU(duration=90)
+         def process_text_to_image(prompt, show_thinking, cfg_text_scale,
+                                   cfg_interval, timestep_shift,
+                                   num_timesteps, cfg_renorm_min, cfg_renorm_type,
+                                   max_think_token_n, do_sample, text_temperature, seed, image_ratio):
+             image, thinking = text_to_image(
+                 prompt, show_thinking, cfg_text_scale, cfg_interval,
+                 timestep_shift, num_timesteps,
+                 cfg_renorm_min, cfg_renorm_type,
+                 max_think_token_n, do_sample, text_temperature, seed, image_ratio
              )
+             return image, thinking if thinking else ""
+ 
+         gen_btn.click(
+             fn=process_text_to_image,
+             inputs=[
+                 txt_input, show_thinking, cfg_text_scale,
+                 cfg_interval, timestep_shift,
+                 num_timesteps, cfg_renorm_min, cfg_renorm_type,
+                 max_think_token_n, do_sample, text_temperature, seed, image_ratio
+             ],
+             outputs=[img_output, thinking_output]
+         )
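
The show/hide wiring above follows a generic Gradio pattern; a self-contained sketch of just that pattern (names are illustrative, not from app.py):

    import gradio as gr

    with gr.Blocks() as toy:
        toggle = gr.Checkbox(label="Thinking", value=False)
        box = gr.Textbox(label="Thinking Process", visible=False)
        # Flip the textbox's visibility whenever the checkbox changes
        toggle.change(lambda show: gr.update(visible=show), inputs=[toggle], outputs=[box])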
+ 
+     with gr.Tab("🖌️ Image Edit"):
+         with gr.Row():
+             with gr.Column(scale=1):
+                 edit_image_input = gr.Image(label="Input Image", value=load_example_image('test_images/women.jpg'))
+                 edit_prompt = gr.Textbox(
+                     label="Prompt",
+                     value="She boards a modern subway, quietly reading a folded newspaper, wearing the same clothes."
+                 )
+ 
+             with gr.Column(scale=1):
+                 edit_image_output = gr.Image(label="Result")
+                 edit_thinking_output = gr.Textbox(label="Thinking Process", visible=False)
+ 
+         with gr.Row():
+             edit_show_thinking = gr.Checkbox(label="Thinking", value=False)
+ 
+         # Add hyperparameter controls in an accordion
+         with gr.Accordion("Inference Hyperparameters", open=False):
+             with gr.Group():
+                 with gr.Row():
+                     edit_seed = gr.Slider(minimum=0, maximum=1000000, value=0, step=1, interactive=True,
+                                           label="Seed", info="0 for random seed, positive for reproducible results")
+                     edit_cfg_text_scale = gr.Slider(minimum=1.0, maximum=8.0, value=4.0, step=0.1, interactive=True,
+                                                     label="CFG Text Scale", info="Controls how strongly the model follows the text prompt")
+ 
+                 with gr.Row():
+                     edit_cfg_img_scale = gr.Slider(minimum=1.0, maximum=4.0, value=2.0, step=0.1, interactive=True,
+                                                    label="CFG Image Scale", info="Controls how much the model preserves input image details")
+                     edit_cfg_interval = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
+                                                   label="CFG Interval", info="Start of CFG application interval (end is fixed at 1.0)")
+ 
+                 with gr.Row():
+                     edit_cfg_renorm_type = gr.Dropdown(choices=["global", "local", "text_channel"],
+                                                        value="text_channel", label="CFG Renorm Type",
+                                                        info="If the generated image is blurry, use 'global'")
+                     edit_cfg_renorm_min = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
+                                                     label="CFG Renorm Min", info="1.0 disables CFG-Renorm")
+ 
+                 with gr.Row():
+                     edit_num_timesteps = gr.Slider(minimum=10, maximum=100, value=50, step=5, interactive=True,
+                                                    label="Timesteps", info="Total denoising steps")
+                     edit_timestep_shift = gr.Slider(minimum=1.0, maximum=10.0, value=3.0, step=0.5, interactive=True,
+                                                     label="Timestep Shift", info="Higher values for layout, lower for details")
+ 
+             # Thinking parameters in a single row
+             edit_thinking_params = gr.Group(visible=False)
+             with edit_thinking_params:
+                 with gr.Row():
+                     edit_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
+                     edit_max_think_token_n = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, interactive=True,
+                                                        label="Max Think Tokens", info="Maximum number of tokens for thinking")
+                     edit_text_temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.3, step=0.1, interactive=True,
+                                                       label="Temperature", info="Controls randomness in text generation")
+ 
+         edit_btn = gr.Button("Submit")
+ 
+         # Dynamically show/hide thinking process box for editing
+         def update_edit_thinking_visibility(show):
+             return gr.update(visible=show), gr.update(visible=show)
+ 
+         edit_show_thinking.change(
+             fn=update_edit_thinking_visibility,
+             inputs=[edit_show_thinking],
+             outputs=[edit_thinking_output, edit_thinking_params]
+         )
+ 
+         # Process editing with thinking option and hyperparameters
+         @spaces.GPU(duration=90)
+         def process_edit_image(image, prompt, show_thinking, cfg_text_scale,
+                                cfg_img_scale, cfg_interval,
+                                timestep_shift, num_timesteps, cfg_renorm_min,
+                                cfg_renorm_type, max_think_token_n, do_sample,
+                                text_temperature, seed):
+             edited_image, thinking = edit_image(
+                 image, prompt, show_thinking, cfg_text_scale, cfg_img_scale,
+                 cfg_interval, timestep_shift,
+                 num_timesteps, cfg_renorm_min, cfg_renorm_type,
+                 max_think_token_n, do_sample, text_temperature, seed
              )
+ 
+             return edited_image, thinking if thinking else ""
+ 
+         edit_btn.click(
+             fn=process_edit_image,
+             inputs=[
+                 edit_image_input, edit_prompt, edit_show_thinking,
+                 edit_cfg_text_scale, edit_cfg_img_scale, edit_cfg_interval,
+                 edit_timestep_shift, edit_num_timesteps,
+                 edit_cfg_renorm_min, edit_cfg_renorm_type,
+                 edit_max_think_token_n, edit_do_sample, edit_text_temperature, edit_seed
+             ],
+             outputs=[edit_image_output, edit_thinking_output]
+         )
+ 
+     with gr.Tab("🖼️ Image Understanding"):
+         with gr.Row():
+             with gr.Column(scale=1):
+                 img_input = gr.Image(label="Input Image", value=load_example_image('test_images/meme.jpg'))
+                 understand_prompt = gr.Textbox(
+                     label="Prompt",
+                     value="Can someone explain what's funny about this meme??"
+                 )
+ 
+             with gr.Column(scale=1):
+                 txt_output = gr.Textbox(label="Result", lines=20)
+ 
+         with gr.Row():
+             understand_show_thinking = gr.Checkbox(label="Thinking", value=False)
+ 
+         # Add hyperparameter controls in an accordion
+         with gr.Accordion("Inference Hyperparameters", open=False):
+             with gr.Row():
+                 understand_do_sample = gr.Checkbox(label="Sampling", value=False, info="Enable sampling for text generation")
+                 understand_text_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.05, interactive=True,
+                                                         label="Temperature", info="Controls randomness in text generation (0=deterministic, 1=creative)")
+                 understand_max_new_tokens = gr.Slider(minimum=64, maximum=4096, value=512, step=64, interactive=True,
+                                                       label="Max New Tokens", info="Maximum length of generated text, including potential thinking")
+ 
+         img_understand_btn = gr.Button("Submit")
+ 
+         # Process understanding with thinking option and hyperparameters
+         @spaces.GPU(duration=90)
+         def process_understanding(image, prompt, show_thinking, do_sample,
+                                   text_temperature, max_new_tokens):
+             result = image_understanding(
+                 image, prompt, show_thinking, do_sample,
+                 text_temperature, max_new_tokens
              )
+             return result
+ 
+         img_understand_btn.click(
+             fn=process_understanding,
+             inputs=[
+                 img_input, understand_prompt, understand_show_thinking,
+                 understand_do_sample, understand_text_temperature, understand_max_new_tokens
+             ],
+             outputs=txt_output
+         )
+ 
+     gr.Markdown("""
+ <div style="display: flex; justify-content: flex-start; flex-wrap: wrap; gap: 10px;">
+   <a href="https://bagel-ai.org/">
+     <img src="https://img.shields.io/badge/BAGEL-Website-0A66C2?logo=safari&logoColor=white" alt="BAGEL Website" />
+   </a>
+   <a href="https://arxiv.org/abs/2505.14683">
+     <img src="https://img.shields.io/badge/BAGEL-Paper-red?logo=arxiv&logoColor=red" alt="BAGEL Paper on arXiv" />
+   </a>
+   <a href="https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT">
+     <img src="https://img.shields.io/badge/BAGEL-Hugging%20Face-orange?logo=huggingface&logoColor=yellow" alt="BAGEL on Hugging Face" />
+   </a>
+   <a href="https://demo.bagel-ai.org/">
+     <img src="https://img.shields.io/badge/BAGEL-Demo-blue?logo=googleplay&logoColor=blue" alt="BAGEL Demo" />
+   </a>
+   <a href="https://discord.gg/Z836xxzy">
+     <img src="https://img.shields.io/badge/BAGEL-Discord-5865F2?logo=discord&logoColor=purple" alt="BAGEL Discord" />
+   </a>
+   <a href="mailto:[email protected]">
+     <img src="https://img.shields.io/badge/BAGEL-Email-D14836?logo=gmail&logoColor=red" alt="BAGEL Email" />
+   </a>
+ </div>
+ """)
+ 
+ demo.launch()
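
The previous revision launched with request queuing and a public link; if that behavior is wanted again, a one-line sketch replacing the plain call above:

    demo.queue().launch(share=True, debug=True)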