KingNish committed
Commit 3ee3ce9 · verified · 1 Parent(s): 18e3525
Files changed (2):
  1. app.py +2 -22
  2. inferencer.py +312 -315
app.py CHANGED
@@ -1,15 +1,8 @@
- import spaces
  import gradio as gr
  import numpy as np
  import os
  import torch
  import random
- import subprocess
- subprocess.run(
-     "pip install flash-attn --no-build-isolation",
-     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-     shell=True,
- )
 
  from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
  from PIL import Image
@@ -24,23 +17,10 @@ from modeling.bagel import (
      SiglipVisionConfig, SiglipVisionModel
  )
  from modeling.qwen2 import Qwen2Tokenizer
- from huggingface_hub import snapshot_download
-
- save_dir = "./model"
- repo_id = "ByteDance-Seed/BAGEL-7B-MoT"
- cache_dir = save_dir + "/cache"
-
- snapshot_download(cache_dir=cache_dir,
-                   local_dir=save_dir,
-                   repo_id=repo_id,
-                   local_dir_use_symlinks=False,
-                   resume_download=True,
-                   allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
-                   )
 
 
  # Model Initialization
- model_path = "./model"
+ model_path = "/path/to/BAGEL-7B-MoT/weights" #Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT
 
  llm_config = Qwen2Config.from_json_file(os.path.join(model_path, "llm_config.json"))
  llm_config.qk_norm = True
@@ -522,4 +502,4 @@ with gr.Blocks() as demo:
      </div>
      """)
 
- demo.launch()
+ demo.launch(share=True)
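
Since the commit drops the in-app snapshot_download call in favor of a hard-coded local path, the BAGEL-7B-MoT weights now have to be fetched once before launching the app. A minimal sketch of that one-off step, reusing the repo id and file patterns from the removed code; the local directory name is an arbitrary example and not part of the commit:

# hypothetical one-off download script; not part of this commit
from huggingface_hub import snapshot_download

weights_dir = snapshot_download(
    repo_id="ByteDance-Seed/BAGEL-7B-MoT",
    local_dir="./BAGEL-7B-MoT-weights",  # example target; any writable directory works
    allow_patterns=["*.json", "*.safetensors", "*.bin", "*.py", "*.md", "*.txt"],
)
print(weights_dir)  # point model_path in app.py at this directory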
inferencer.py CHANGED
@@ -1,315 +1,312 @@
- # Copyright 2025 Bytedance Ltd. and/or its affiliates.
- # SPDX-License-Identifier: Apache-2.0
-
- from copy import deepcopy
- from typing import List, Dict, Tuple, Optional, Union, Any
- import matplotlib.pyplot as plt
-
- from PIL import Image
- import torch
- import torch.nn.functional as F
- from torch import nn
- from torch.nn.attention.flex_attention import create_block_mask
- from transformers.configuration_utils import PretrainedConfig
- from transformers.modeling_utils import PreTrainedModel
-
- from data.data_utils import pil_img2rgb
- from modeling.bagel.qwen2_navit import NaiveCache
-
-
-
- VLM_THINK_SYSTEM_PROMPT = '''You should first think about the reasoning process in the mind and then provide the user with the answer.
- The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here'''
-
- GEN_THINK_SYSTEM_PROMPT = '''You should first think about the planning process in the mind and then generate the image.
- The planning process is enclosed within <think> </think> tags, i.e. <think> planning process here </think> image here'''
-
-
- class InterleaveInferencer:
-     def __init__(self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids):
-         self.model = model
-         self.vae_model = vae_model
-         self.tokenizer = tokenizer
-         self.vae_transform = vae_transform
-         self.vit_transform = vit_transform
-         self.new_token_ids = new_token_ids
-
-     def init_gen_context(self):
-         gen_context = {
-             'kv_lens': [0],
-             'ropes': [0],
-             'past_key_values': NaiveCache(self.model.config.llm_config.num_hidden_layers),
-         }
-         return gen_context
-
-     @torch.no_grad()
-     def update_context_text(self, text, gen_context):
-         # used for interleave data, currently only support 1 data inference,
-
-         past_key_values = gen_context['past_key_values']
-         kv_lens = gen_context['kv_lens']
-         ropes = gen_context['ropes']
-         generation_input, kv_lens, ropes = self.model.prepare_prompts(
-             curr_kvlens=kv_lens,
-             curr_rope=ropes,
-             prompts=[text],
-             tokenizer=self.tokenizer,
-             new_token_ids=self.new_token_ids,
-         )
-
-         past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
-         gen_context['kv_lens'] = kv_lens
-         gen_context['ropes'] = ropes
-         gen_context['past_key_values'] = past_key_values
-
-         return gen_context
-
-     @torch.no_grad()
-     def update_context_image(self, image, gen_context, vae=True, vit=True):
-         # used for interleave data, currently only support 1 data inference,
-
-         assert vae or vit
-         past_key_values = gen_context['past_key_values']
-         kv_lens = gen_context['kv_lens']
-         ropes = gen_context['ropes']
-
-         if vae:
-             ## update vae
-             generation_input, kv_lens, ropes = self.model.prepare_vae_images(
-                 curr_kvlens=kv_lens,
-                 curr_rope=ropes,
-                 images=[image],
-                 transforms=self.vae_transform,
-                 new_token_ids=self.new_token_ids,
-             )
-             past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
-
-         if vit:
-             ## update vit
-             generation_input, kv_lens, ropes = self.model.prepare_vit_images(
-                 curr_kvlens=kv_lens,
-                 curr_rope=ropes,
-                 images=[image],
-                 transforms=self.vit_transform,
-                 new_token_ids=self.new_token_ids,
-             )
-             past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
-
-         gen_context['kv_lens'] = kv_lens
-         gen_context['ropes'] = ropes
-         gen_context['past_key_values'] = past_key_values
-
-         return gen_context
-
-     @torch.no_grad()
-     def gen_image(
-         self,
-         image_shape,
-         gen_context,
-         cfg_text_scale=4.0,
-         cfg_img_scale=1.5,
-
-         cfg_text_precontext=None,
-         cfg_img_precontext=None,
-         cfg_interval=(0.4, 1.0),
-         cfg_renorm_min=0.0,
-         cfg_renorm_type="global",
-
-         num_timesteps=50,
-         timestep_shift=3.0
-     ):
-         # print(cfg_renorm_type)
-         past_key_values = gen_context['past_key_values']
-         kv_lens = gen_context['kv_lens']
-         ropes = gen_context['ropes']
-         generation_input = self.model.prepare_vae_latent(
-             curr_kvlens=kv_lens,
-             curr_rope=ropes,
-             image_sizes=[image_shape],
-             new_token_ids=self.new_token_ids,
-         )
-
-         # text cfg
-         cfg_text_past_key_values = cfg_text_precontext['past_key_values']
-         kv_lens_cfg = cfg_text_precontext['kv_lens']
-         ropes_cfg = cfg_text_precontext['ropes']
-         generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
-             curr_kvlens=kv_lens_cfg,
-             curr_rope=ropes_cfg,
-             image_sizes=[image_shape],
-         )
-
-         # img cfg
-         cfg_img_past_key_values = cfg_img_precontext['past_key_values']
-         kv_lens_cfg = cfg_img_precontext['kv_lens']
-         ropes_cfg = cfg_img_precontext['ropes']
-         generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
-             curr_kvlens=kv_lens_cfg,
-             curr_rope=ropes_cfg,
-             image_sizes=[image_shape],
-         )
-
-         unpacked_latent = self.model.generate_image(
-             past_key_values=past_key_values,
-             cfg_text_past_key_values=cfg_text_past_key_values,
-             cfg_img_past_key_values=cfg_img_past_key_values,
-             num_timesteps=num_timesteps,
-             cfg_text_scale=cfg_text_scale,
-             cfg_img_scale=cfg_img_scale,
-             cfg_interval=cfg_interval,
-             cfg_renorm_min=cfg_renorm_min,
-             cfg_renorm_type=cfg_renorm_type,
-             timestep_shift=timestep_shift,
-             **generation_input,
-             cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
-             cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
-             cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
-             cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
-             cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
-             cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
-             cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
-             cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
-         )
-
-         image = self.decode_image(unpacked_latent[0], image_shape)
-         return image
-
-
-     def decode_image(self, latent, image_shape):
-         H, W = image_shape
-         h, w = H // self.model.latent_downsample, W // self.model.latent_downsample
-
-         latent = latent.reshape(1, h, w, self.model.latent_patch_size, self.model.latent_patch_size, self.model.latent_channel)
-         latent = torch.einsum("nhwpqc->nchpwq", latent)
-         latent = latent.reshape(1, self.model.latent_channel, h * self.model.latent_patch_size, w * self.model.latent_patch_size)
-         image = self.vae_model.decode(latent)
-         image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
-         image = Image.fromarray((image).to(torch.uint8).cpu().numpy())
-
-         return image
-
-     @torch.no_grad()
-     def gen_text(self, gen_context, max_length: int = 500, do_sample: bool = True, temperature: float = 1.0):
-         gen_context = deepcopy(gen_context)
-         past_key_values = gen_context['past_key_values']
-         kv_lens = gen_context['kv_lens']
-         ropes = gen_context['ropes']
-
-         generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
-         unpacked_latent = self.model.generate_text(
-             past_key_values=past_key_values,
-             max_length=max_length,
-             do_sample=do_sample,
-             temperature=temperature,
-             end_token_id=self.new_token_ids['eos_token_id'],
-             **generation_input,
-         )
-         output = self.tokenizer.decode(unpacked_latent[:,0])
-         output = output.split('<|im_end|>')[0].split('<|im_start|>')[1]
-         return output
-
-     @torch.no_grad()
-     def interleave_inference(
-         self,
-         input_lists: List[Union[str, Image.Image]],
-         think=False,
-         understanding_output=False,
-
-         max_think_token_n=1000,
-         do_sample=False,
-         text_temperature=0.3,
-         cfg_text_scale=3.0,
-         cfg_img_scale=1.5,
-         cfg_interval=[0.4, 1.0],
-         timestep_shift=3.0,
-         num_timesteps=50,
-         cfg_renorm_min=0.0,
-         cfg_renorm_type="global",
-         image_shapes=(1024, 1024),
-     ) -> List[Union[str, Image.Image]]:
-
-         output_list = []
-         gen_context = self.init_gen_context()
-         cfg_text_context = deepcopy(gen_context)
-         cfg_img_context = deepcopy(gen_context)
-
-         with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
-             if think:
-                 if understanding_output:
-                     system_prompt = VLM_THINK_SYSTEM_PROMPT
-                 else:
-                     system_prompt = GEN_THINK_SYSTEM_PROMPT
-                 gen_context = self.update_context_text(system_prompt, gen_context)
-                 cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)
-
-             for input_term in input_lists:
-                 if isinstance(input_term, str):
-                     cfg_text_context = deepcopy(gen_context)
-                     gen_context = self.update_context_text(input_term, gen_context)
-                     cfg_img_context = self.update_context_text(input_term, cfg_img_context)
-
-                 elif isinstance(input_term, Image.Image):
-                     input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term))
-                     gen_context = self.update_context_image(input_term, gen_context, vae=not understanding_output)
-
-                     image_shapes = input_term.size[::-1]
-                     cfg_text_context = deepcopy(gen_context)
-
-                 else:
-                     raise ValueError(f"Unsupported input type: {type(input_term)}")
-
-             if understanding_output:
-                 gen_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n)
-                 output_list.append(gen_text)
-
-             else:
-                 if think:
-                     gen_text = self.gen_text(gen_context, do_sample=do_sample, temperature=text_temperature, max_length=max_think_token_n)
-                     gen_context = self.update_context_text(gen_text, gen_context)
-                     output_list.append(gen_text)
-
-                 img = self.gen_image(
-                     image_shapes,
-                     gen_context,
-                     cfg_text_precontext=cfg_text_context,
-                     cfg_img_precontext=cfg_img_context,
-
-                     cfg_text_scale=cfg_text_scale,
-                     cfg_img_scale=cfg_img_scale,
-                     cfg_interval=cfg_interval,
-                     timestep_shift=timestep_shift,
-                     num_timesteps=num_timesteps,
-                     cfg_renorm_min=cfg_renorm_min,
-                     cfg_renorm_type=cfg_renorm_type,
-                 )
-
-                 output_list.append(img)
-
-         return output_list
-
-     def __call__(
-         self,
-         image: Optional[Image.Image] = None,
-         text: Optional[str] = None,
-         **kargs
-     ) -> Dict[str, Any]:
-         output_dict = {'image': None, 'text': None}
-
-         if image is None and text is None:
-             print('Please provide at least one input: either an image or text.')
-             return output_dict
-
-         input_list = []
-         if image is not None:
-             input_list.append(image)
-         if text is not None:
-             input_list.append(text)
-
-         output_list = self.interleave_inference(input_list, **kargs)
-
-         for i in output_list:
-             if isinstance(i, Image.Image):
-                 output_dict['image'] = i
-             elif isinstance(i, str):
-                 output_dict['text'] = i
-         return output_dict
+ # Copyright 2025 Bytedance Ltd. and/or its affiliates.
+ # SPDX-License-Identifier: Apache-2.0
+
+ from copy import deepcopy
+ from typing import List, Dict, Tuple, Optional, Union, Any
+ import matplotlib.pyplot as plt
+
+ from PIL import Image
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.nn.attention.flex_attention import create_block_mask
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_utils import PreTrainedModel
+
+ from data.data_utils import pil_img2rgb
+ from modeling.bagel.qwen2_navit import NaiveCache
+
+
+
+ VLM_THINK_SYSTEM_PROMPT = '''You should first think about the reasoning process in the mind and then provide the user with the answer.
+ The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here'''
+
+ GEN_THINK_SYSTEM_PROMPT = '''You should first think about the planning process in the mind and then generate the image.
+ The planning process is enclosed within <think> </think> tags, i.e. <think> planning process here </think> image here'''
+
+
+ class InterleaveInferencer:
+     def __init__(self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids):
+         self.model = model
+         self.vae_model = vae_model
+         self.tokenizer = tokenizer
+         self.vae_transform = vae_transform
+         self.vit_transform = vit_transform
+         self.new_token_ids = new_token_ids
+
+     def init_gen_context(self):
+         gen_context = {
+             'kv_lens': [0],
+             'ropes': [0],
+             'past_key_values': NaiveCache(self.model.config.llm_config.num_hidden_layers),
+         }
+         return gen_context
+
+     @torch.no_grad()
+     def update_context_text(self, text, gen_context):
+         # used for interleave data, currently only support 1 data inference,
+
+         past_key_values = gen_context['past_key_values']
+         kv_lens = gen_context['kv_lens']
+         ropes = gen_context['ropes']
+         generation_input, kv_lens, ropes = self.model.prepare_prompts(
+             curr_kvlens=kv_lens,
+             curr_rope=ropes,
+             prompts=[text],
+             tokenizer=self.tokenizer,
+             new_token_ids=self.new_token_ids,
+         )
+
+         past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
+         gen_context['kv_lens'] = kv_lens
+         gen_context['ropes'] = ropes
+         gen_context['past_key_values'] = past_key_values
+
+         return gen_context
+
+     @torch.no_grad()
+     def update_context_image(self, image, gen_context, vae=True, vit=True):
+         # used for interleave data, currently only support 1 data inference,
+
+         assert vae or vit
+         past_key_values = gen_context['past_key_values']
+         kv_lens = gen_context['kv_lens']
+         ropes = gen_context['ropes']
+
+         if vae:
+             ## update vae
+             generation_input, kv_lens, ropes = self.model.prepare_vae_images(
+                 curr_kvlens=kv_lens,
+                 curr_rope=ropes,
+                 images=[image],
+                 transforms=self.vae_transform,
+                 new_token_ids=self.new_token_ids,
+             )
+             past_key_values = self.model.forward_cache_update_vae(self.vae_model, past_key_values, **generation_input)
+
+         if vit:
+             ## update vit
+             generation_input, kv_lens, ropes = self.model.prepare_vit_images(
+                 curr_kvlens=kv_lens,
+                 curr_rope=ropes,
+                 images=[image],
+                 transforms=self.vit_transform,
+                 new_token_ids=self.new_token_ids,
+             )
+             past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)
+
+         gen_context['kv_lens'] = kv_lens
+         gen_context['ropes'] = ropes
+         gen_context['past_key_values'] = past_key_values
+
+         return gen_context
+
+     @torch.no_grad()
+     def gen_image(
+         self,
+         image_shape,
+         gen_context,
+         cfg_text_scale=4.0,
+         cfg_img_scale=1.5,
+
+         cfg_text_precontext=None,
+         cfg_img_precontext=None,
+         cfg_interval=(0.4, 1.0),
+         cfg_renorm_min=0.0,
+         cfg_renorm_type="global",
+
+         num_timesteps=50,
+         timestep_shift=3.0
+     ):
+         # print(cfg_renorm_type)
+         past_key_values = gen_context['past_key_values']
+         kv_lens = gen_context['kv_lens']
+         ropes = gen_context['ropes']
+         generation_input = self.model.prepare_vae_latent(
+             curr_kvlens=kv_lens,
+             curr_rope=ropes,
+             image_sizes=[image_shape],
+             new_token_ids=self.new_token_ids,
+         )
+
+         # text cfg
+         cfg_text_past_key_values = cfg_text_precontext['past_key_values']
+         kv_lens_cfg = cfg_text_precontext['kv_lens']
+         ropes_cfg = cfg_text_precontext['ropes']
+         generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
+             curr_kvlens=kv_lens_cfg,
+             curr_rope=ropes_cfg,
+             image_sizes=[image_shape],
+         )
+
+         # img cfg
+         cfg_img_past_key_values = cfg_img_precontext['past_key_values']
+         kv_lens_cfg = cfg_img_precontext['kv_lens']
+         ropes_cfg = cfg_img_precontext['ropes']
+         generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
+             curr_kvlens=kv_lens_cfg,
+             curr_rope=ropes_cfg,
+             image_sizes=[image_shape],
+         )
+
+         unpacked_latent = self.model.generate_image(
+             past_key_values=past_key_values,
+             cfg_text_past_key_values=cfg_text_past_key_values,
+             cfg_img_past_key_values=cfg_img_past_key_values,
+             num_timesteps=num_timesteps,
+             cfg_text_scale=cfg_text_scale,
+             cfg_img_scale=cfg_img_scale,
+             cfg_interval=cfg_interval,
+             cfg_renorm_min=cfg_renorm_min,
+             cfg_renorm_type=cfg_renorm_type,
+             timestep_shift=timestep_shift,
+             **generation_input,
+             cfg_text_packed_position_ids=generation_input_cfg_text['cfg_packed_position_ids'],
+             cfg_text_packed_query_indexes=generation_input_cfg_text['cfg_packed_query_indexes'],
+             cfg_text_key_values_lens=generation_input_cfg_text['cfg_key_values_lens'],
+             cfg_text_packed_key_value_indexes=generation_input_cfg_text['cfg_packed_key_value_indexes'],
+             cfg_img_packed_position_ids=generation_input_cfg_img['cfg_packed_position_ids'],
+             cfg_img_packed_query_indexes=generation_input_cfg_img['cfg_packed_query_indexes'],
+             cfg_img_key_values_lens=generation_input_cfg_img['cfg_key_values_lens'],
+             cfg_img_packed_key_value_indexes=generation_input_cfg_img['cfg_packed_key_value_indexes'],
+         )
+
+         image = self.decode_image(unpacked_latent[0], image_shape)
+         return image
+
+
+     def decode_image(self, latent, image_shape):
+         H, W = image_shape
+         h, w = H // self.model.latent_downsample, W // self.model.latent_downsample
+
+         latent = latent.reshape(1, h, w, self.model.latent_patch_size, self.model.latent_patch_size, self.model.latent_channel)
+         latent = torch.einsum("nhwpqc->nchpwq", latent)
+         latent = latent.reshape(1, self.model.latent_channel, h * self.model.latent_patch_size, w * self.model.latent_patch_size)
+         image = self.vae_model.decode(latent)
+         image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
+         image = Image.fromarray((image).to(torch.uint8).cpu().numpy())
+
+         return image
+
+     @torch.no_grad()
+     def gen_text(self, gen_context, max_length: int = 500, do_sample: bool = True, temperature: float = 1.0):
+         gen_context = deepcopy(gen_context)
+         past_key_values = gen_context['past_key_values']
+         kv_lens = gen_context['kv_lens']
+         ropes = gen_context['ropes']
+
+         generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
+         for unpacked_latent in self.model.generate_text(
+             past_key_values=past_key_values,
+             max_length=max_length,
+             do_sample=do_sample,
+             temperature=temperature,
+             end_token_id=self.new_token_ids['eos_token_id'],
+             **generation_input,
+         ):
+             output = self.tokenizer.decode(unpacked_latent[:,0])
+             # output = output.split('<|im_end|>')[0].split('<|im_start|>')[1]
+             yield output
+
+     @torch.no_grad()
+     def interleave_inference(
+         self,
+         input_lists: List[Union[str, Image.Image]],
+         think=False,
+         understanding_output=False,
+         max_think_token_n=1000,
+         do_sample=False, # for gen_text
+         temperature=0.3, # for gen_text
+         # gen_image kargs
+         cfg_text_scale=3.0,
+         cfg_img_scale=1.5,
+         cfg_interval=[0.4, 1.0],
+         timestep_shift=3.0,
+         num_timesteps=50,
+         cfg_renorm_min=0.0,
+         cfg_renorm_type="global",
+         image_shapes=(1024, 1024), # Default, can be overridden by actual input image
+     ):
+         gen_context = self.init_gen_context()
+         cfg_text_context = self.init_gen_context()
+         cfg_img_context = self.init_gen_context()
+
+         current_image_shapes = image_shapes
+
+         # Use torch.cuda.amp.autocast if available, otherwise a simple context manager
+         # For simplicity, assuming it's handled externally or not strictly needed for this snippet
+         # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+
+         if think:
+             system_prompt = VLM_THINK_SYSTEM_PROMPT if understanding_output else GEN_THINK_SYSTEM_PROMPT
+             gen_context = self.update_context_text(system_prompt, gen_context)
+             cfg_text_context = self.update_context_text(system_prompt, cfg_text_context)
+             cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)
+
+         for input_term in input_lists:
+             if isinstance(input_term, str):
+                 gen_context = self.update_context_text(input_term, gen_context)
+                 cfg_text_context = self.update_context_text(input_term, cfg_text_context)
+                 cfg_img_context = self.update_context_text(input_term, cfg_img_context)
+             elif isinstance(input_term, Image.Image):
+                 current_image_shapes = input_term.size[::-1] # H, W
+                 use_vae_for_input_image = not understanding_output
+                 gen_context = self.update_context_image(input_term, gen_context, vae=use_vae_for_input_image, vit=True)
+                 cfg_text_context = self.update_context_image(input_term, cfg_text_context, vae=use_vae_for_input_image, vit=True)
+                 # cfg_img_context does not typically see input images
+             else:
+                 raise ValueError(f"Unsupported input type: {type(input_term)}")
+
+         if understanding_output: # Generate text
+             yield from self.gen_text(gen_context, max_length=max_think_token_n, do_sample=do_sample, temperature=temperature)
+         else: # Generate image
+             if think:
+                 thought_text_parts = []
+                 for part in self.gen_text(gen_context, max_length=max_think_token_n, do_sample=do_sample, temperature=temperature):
+                     yield part # Stream the thought
+                     thought_text_parts.append(part)
+                 full_thought_text = "".join(thought_text_parts)
+                 if full_thought_text: # Only update if thought was generated
+                     gen_context = self.update_context_text(full_thought_text, gen_context)
+                     cfg_text_context = self.update_context_text(full_thought_text, cfg_text_context)
+
+             img = self.gen_image(
+                 image_shape=current_image_shapes,
+                 gen_context=gen_context,
+                 cfg_text_precontext=cfg_text_context,
+                 cfg_img_precontext=cfg_img_context,
+                 cfg_text_scale=cfg_text_scale,
+                 cfg_img_scale=cfg_img_scale,
+                 cfg_interval=cfg_interval,
+                 timestep_shift=timestep_shift,
+                 num_timesteps=num_timesteps,
+                 cfg_renorm_min=cfg_renorm_min,
+                 cfg_renorm_type=cfg_renorm_type,
+             )
+             yield img
+
+     def __call__(
+         self,
+         image: Optional[Image.Image] = None,
+         text: Optional[str] = None,
+         **kargs
+     ) -> Any:
+         input_list = []
+         if image is not None:
+             input_list.append(image)
+         if text is not None:
+             input_list.append(text)
+
+         if not input_list and not kargs.get('force_empty_input', False): # allow forcing for special cases if needed
+             return
+
+         # Intelligent setting of 'understanding_output' if not provided by caller
+         # This helps app.py's simpler calls like inferencer(text=...) to correctly produce text.
+         if 'understanding_output' not in kargs:
+             if text is not None and image is None: # Primarily text input
+                 kargs['understanding_output'] = True
+             elif image is not None and text is None: # Primarily image input, assume image-to-text (captioning/VQA)
+                 kargs['understanding_output'] = True
+             # If both text and image, or neither, rely on caller or default (False for image gen)
+
+         yield from self.interleave_inference(input_list, **kargs)
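
The practical effect of the inferencer.py rewrite is that gen_text, interleave_inference, and __call__ become generators: callers iterate over streamed text chunks (and, for image generation, a final PIL image) instead of unpacking a {'image': ..., 'text': ...} dict. A hedged caller-side sketch of the new contract; the prompt string and variable names are illustrative only, and `inferencer` is an InterleaveInferencer built as in app.py:

# hypothetical usage sketch, not part of the commit
from PIL import Image

# Before this commit:
#   out = inferencer(text="a red bicycle", think=True)   # returned a dict
#   caption, image = out["text"], out["image"]

# After this commit: __call__ yields intermediate text and, last, the image.
streamed_text, final_image = [], None
for chunk in inferencer(text="a red bicycle", think=True, understanding_output=False):
    if isinstance(chunk, Image.Image):
        final_image = chunk            # generated image
    else:
        streamed_text.append(chunk)    # partial "<think>" text as it streams

Passing understanding_output=False explicitly matters for text-to-image calls, because the new __call__ defaults text-only inputs to understanding (text) mode when the flag is omitted.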