likunchang committed on
Commit c6b9a17 · 1 Parent(s): 717846a
Files changed (1)
  1. app.py +3 -238
app.py CHANGED
@@ -1,264 +1,30 @@
-import spaces
 import gradio as gr
-import numpy as np
-import os
-import torch
-import random
-import subprocess
-subprocess.run(
-    "pip install flash-attn --no-build-isolation",
-    env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-    shell=True,
-)
-
-from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
 from PIL import Image
 
-from data.data_utils import add_special_tokens, pil_img2rgb
-from data.transforms import ImageTransform
-from inferencer import InterleaveInferencer
-from modeling.autoencoder import load_ae
-from modeling.bagel import (
-    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM,
-    SiglipVisionConfig, SiglipVisionModel
-)
-from modeling.qwen2 import Qwen2Tokenizer
-
-from huggingface_hub import snapshot_download
-
-save_dir = "./model_weights"
-repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
-cache_dir = save_dir + "/cache"
-
-snapshot_download(
-    cache_dir=cache_dir,
-    local_dir=save_dir,
-    repo_id=repo_id,
-    local_dir_use_symlinks=False,
-    resume_download=True,
-    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
-)
-
-# Model Initialization
-model_path = save_dir
-
-llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
-llm_config.qk_norm = True
-llm_config.tie_word_embeddings = False
-llm_config.layer_module = "Qwen2MoTDecoderLayer"
-
-vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_path, "vit_config.json"))
-vit_config.rope = False
-vit_config.num_hidden_layers -= 1
-
-vae_model, vae_config = load_ae(local_path=os.path.join(model_path, "ae.safetensors"))
-
-config = BagelConfig(
-    visual_gen=True,
-    visual_und=True,
-    llm_config=llm_config,
-    vit_config=vit_config,
-    vae_config=vae_config,
-    vit_max_num_patch_per_side=70,
-    connector_act='gelu_pytorch_tanh',
-    latent_patch_size=2,
-    max_latent_size=64,
-)
-
-with init_empty_weights():
-    language_model = Qwen2ForCausalLM(llm_config)
-    vit_model = SiglipVisionModel(vit_config)
-    model = Bagel(language_model, vit_model, config)
-    model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config, meta=True)
-
-tokenizer = Qwen2Tokenizer.from_pretrained(model_path)
-tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)
-
-vae_transform = ImageTransform(1024, 512, 16)
-vit_transform = ImageTransform(980, 224, 14)
-
-# Model Loading and Multi-GPU Inference Preparation
-device_map = infer_auto_device_map(
-    model,
-    max_memory={i: "80GiB" for i in range(torch.cuda.device_count())},
-    no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"],
-)
-
-same_device_modules = [
-    'language_model.model.embed_tokens',
-    'time_embedder',
-    'latent_pos_embed',
-    'vae2llm',
-    'llm2vae',
-    'connector',
-    'vit_pos_embed'
-]
-
-if torch.cuda.device_count() == 1:
-    first_device = device_map.get(same_device_modules[0], "cuda:0")
-    for k in same_device_modules:
-        if k in device_map:
-            device_map[k] = first_device
-        else:
-            device_map[k] = "cuda:0"
-else:
-    first_device = device_map.get(same_device_modules[0])
-    for k in same_device_modules:
-        if k in device_map:
-            device_map[k] = first_device
-
-model = load_checkpoint_and_dispatch(
-    model,
-    checkpoint=os.path.join(model_path, "ema.safetensors"),
-    device_map=device_map,
-    offload_buffers=True,
-    offload_folder="offload",
-    dtype=torch.bfloat16,
-    force_hooks=True,
-).eval()
-
-
-# Inferencer Preparation
-inferencer = InterleaveInferencer(
-    model=model,
-    vae_model=vae_model,
-    tokenizer=tokenizer,
-    vae_transform=vae_transform,
-    vit_transform=vit_transform,
-    new_token_ids=new_token_ids,
-)
-
-def set_seed(seed):
-    """Set random seeds for reproducibility"""
-    if seed > 0:
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.manual_seed(seed)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed(seed)
-            torch.cuda.manual_seed_all(seed)
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-    return seed
-
 # Text to Image function with thinking option and hyperparameters
-@spaces.GPU(duration=90)
 def text_to_image(prompt, show_thinking=False, cfg_text_scale=4.0, cfg_interval=0.4,
                   timestep_shift=3.0, num_timesteps=50,
                   cfg_renorm_min=1.0, cfg_renorm_type="global",
                   max_think_token_n=1024, do_sample=False, text_temperature=0.3,
                   seed=0, image_ratio="1:1"):
-    # Set seed for reproducibility
-    set_seed(seed)
-
-    if image_ratio == "1:1":
-        image_shapes = (1024, 1024)
-    elif image_ratio == "4:3":
-        image_shapes = (768, 1024)
-    elif image_ratio == "3:4":
-        image_shapes = (1024, 768)
-    elif image_ratio == "16:9":
-        image_shapes = (576, 1024)
-    elif image_ratio == "9:16":
-        image_shapes = (1024, 576)
-
-    # Set hyperparameters
-    inference_hyper = dict(
-        max_think_token_n=max_think_token_n if show_thinking else 1024,
-        do_sample=do_sample if show_thinking else False,
-        text_temperature=text_temperature if show_thinking else 0.3,
-        cfg_text_scale=cfg_text_scale,
-        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
-        timestep_shift=timestep_shift,
-        num_timesteps=num_timesteps,
-        cfg_renorm_min=cfg_renorm_min,
-        cfg_renorm_type=cfg_renorm_type,
-        image_shapes=image_shapes,
-    )
 
-    result = {"text": "", "image": None}
-    # Call inferencer with or without think parameter based on user choice
-    for i in inferencer(text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
-        print(type(i))
-        if type(i) == str:
-            result["text"] += i
-        else:
-            result["image"] = i
-
-    yield result["image"], result.get("text", None)
+    yield None, None
 
 
 # Image Understanding function with thinking option and hyperparameters
-@spaces.GPU(duration=90)
 def image_understanding(image: Image.Image, prompt: str, show_thinking=False,
                         do_sample=False, text_temperature=0.3, max_new_tokens=512):
-    if image is None:
-        return "Please upload an image."
-
-    if isinstance(image, np.ndarray):
-        image = Image.fromarray(image)
-
-    image = pil_img2rgb(image)
-
-    # Set hyperparameters
-    inference_hyper = dict(
-        do_sample=do_sample,
-        text_temperature=text_temperature,
-        max_think_token_n=max_new_tokens,  # Set max_length
-    )
-
-    result = {"text": "", "image": None}
-    # Use show_thinking parameter to control thinking process
-    for i in inferencer(image=image, text=prompt, think=show_thinking,
-                        understanding_output=True, **inference_hyper):
-        if type(i) == str:
-            result["text"] += i
-        else:
-            result["image"] = i
-    yield result["text"]
+    yield None
 
 
 # Image Editing function with thinking option and hyperparameters
-@spaces.GPU(duration=90)
 def edit_image(image: Image.Image, prompt: str, show_thinking=False, cfg_text_scale=4.0,
                cfg_img_scale=2.0, cfg_interval=0.0,
                timestep_shift=3.0, num_timesteps=50, cfg_renorm_min=1.0,
                cfg_renorm_type="text_channel", max_think_token_n=1024,
                do_sample=False, text_temperature=0.3, seed=0):
-    # Set seed for reproducibility
-    set_seed(seed)
-
-    if image is None:
-        return "Please upload an image.", ""
 
-    if isinstance(image, np.ndarray):
-        image = Image.fromarray(image)
-
-    image = pil_img2rgb(image)
-
-    # Set hyperparameters
-    inference_hyper = dict(
-        max_think_token_n=max_think_token_n if show_thinking else 1024,
-        do_sample=do_sample if show_thinking else False,
-        text_temperature=text_temperature if show_thinking else 0.3,
-        cfg_text_scale=cfg_text_scale,
-        cfg_img_scale=cfg_img_scale,
-        cfg_interval=[cfg_interval, 1.0],  # End fixed at 1.0
-        timestep_shift=timestep_shift,
-        num_timesteps=num_timesteps,
-        cfg_renorm_min=cfg_renorm_min,
-        cfg_renorm_type=cfg_renorm_type,
-    )
-
-    # Include thinking parameter based on user choice
-    result = {"text": "", "image": None}
-    for i in inferencer(image=image, text=prompt, think=show_thinking, understanding_output=False, **inference_hyper):
-        if type(i) == str:
-            result["text"] += i
-        else:
-            result["image"] = i
-
-    yield result["image"], result.get("text", "")
+    yield None, None
 
 # Helper function to load example images
 def load_example_image(image_path):
@@ -268,7 +34,6 @@ def load_example_image(image_path):
         print(f"Error loading example image: {e}")
     return None
 
-
 # Gradio UI
 with gr.Blocks() as demo:
     gr.Markdown("""
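
Note on the three `+` lines: each stub remains a generator that yields one value per output component (`yield None, None` for the two-output functions, `yield None` for image_understanding), so the unchanged Gradio event wiring further down in app.py can keep streaming from them. Below is a minimal sketch of that generator-to-outputs contract; the component names and the abbreviated signature are illustrative assumptions, not code from this commit.

```python
import gradio as gr

# Stub with the same generator contract as the committed text_to_image
# (signature abbreviated here; the committed stub keeps every hyperparameter).
def text_to_image(prompt, show_thinking=False, **hyper):
    # One yielded value per output component: (image, thinking_text).
    # A bare `return` would break the two-output pairing the UI streams.
    yield None, None

with gr.Blocks() as demo:
    prompt_in = gr.Textbox(label="Prompt")            # assumed component name
    image_out = gr.Image(label="Generated image")     # assumed component name
    think_out = gr.Textbox(label="Thinking process")  # assumed component name
    gen_btn = gr.Button("Generate")
    # Gradio streams each yielded tuple to [image_out, think_out] in order.
    gen_btn.click(text_to_image, inputs=prompt_in, outputs=[image_out, think_out])

if __name__ == "__main__":
    demo.launch()
```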