KingNish committed
Commit c2cc633 · verified · 1 Parent(s): c40e1ba

Update inferencer.py

Files changed (1)
  1. inferencer.py +48 -50
inferencer.py CHANGED
@@ -233,57 +233,55 @@ class InterleaveInferencer:
 
         current_image_shapes = image_shapes
 
-        # Use torch.cuda.amp.autocast if available, otherwise a simple context manager
-        # For simplicity, assuming it's handled externally or not strictly needed for this snippet
-        # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
-
-        if think:
-            system_prompt = VLM_THINK_SYSTEM_PROMPT if understanding_output else GEN_THINK_SYSTEM_PROMPT
-            gen_context = self.update_context_text(system_prompt, gen_context)
-            cfg_text_context = self.update_context_text(system_prompt, cfg_text_context)
-            cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)
-
-        for input_term in input_lists:
-            if isinstance(input_term, str):
-                gen_context = self.update_context_text(input_term, gen_context)
-                cfg_text_context = self.update_context_text(input_term, cfg_text_context)
-                cfg_img_context = self.update_context_text(input_term, cfg_img_context)
-            elif isinstance(input_term, Image.Image):
-                current_image_shapes = input_term.size[::-1] # H, W
-                use_vae_for_input_image = not understanding_output
-                gen_context = self.update_context_image(input_term, gen_context, vae=use_vae_for_input_image, vit=True)
-                cfg_text_context = self.update_context_image(input_term, cfg_text_context, vae=use_vae_for_input_image, vit=True)
-                # cfg_img_context does not typically see input images
-            else:
-                raise ValueError(f"Unsupported input type: {type(input_term)}")
-
-        if understanding_output: # Generate text
-            yield from self.gen_text(gen_context, max_length=max_think_token_n, do_sample=do_sample, temperature=temperature)
-        else: # Generate image
-            if think:
-                thought_text_parts = []
-                for part in self.gen_text(gen_context, max_length=max_think_token_n, do_sample=do_sample, temperature=temperature):
-                    yield part # Stream the thought
-                    thought_text_parts.append(part)
-                full_thought_text = "".join(thought_text_parts)
-                if full_thought_text: # Only update if thought was generated
-                    gen_context = self.update_context_text(full_thought_text, gen_context)
-                    cfg_text_context = self.update_context_text(full_thought_text, cfg_text_context)
-
-            img = self.gen_image(
-                image_shape=current_image_shapes,
-                gen_context=gen_context,
-                cfg_text_precontext=cfg_text_context,
-                cfg_img_precontext=cfg_img_context,
-                cfg_text_scale=cfg_text_scale,
-                cfg_img_scale=cfg_img_scale,
-                cfg_interval=cfg_interval,
-                timestep_shift=timestep_shift,
-                num_timesteps=num_timesteps,
-                cfg_renorm_min=cfg_renorm_min,
-                cfg_renorm_type=cfg_renorm_type,
-            )
-            yield img
+        with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
+
+            if think:
+                system_prompt = VLM_THINK_SYSTEM_PROMPT if understanding_output else GEN_THINK_SYSTEM_PROMPT
+                gen_context = self.update_context_text(system_prompt, gen_context)
+                cfg_text_context = self.update_context_text(system_prompt, cfg_text_context)
+                cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)
+
+            for input_term in input_lists:
+                if isinstance(input_term, str):
+                    gen_context = self.update_context_text(input_term, gen_context)
+                    cfg_text_context = self.update_context_text(input_term, cfg_text_context)
+                    cfg_img_context = self.update_context_text(input_term, cfg_img_context)
+                elif isinstance(input_term, Image.Image):
+                    current_image_shapes = input_term.size[::-1] # H, W
+                    use_vae_for_input_image = not understanding_output
+                    gen_context = self.update_context_image(input_term, gen_context, vae=use_vae_for_input_image, vit=True)
+                    cfg_text_context = self.update_context_image(input_term, cfg_text_context, vae=use_vae_for_input_image, vit=True)
+                    # cfg_img_context does not typically see input images
+                else:
+                    raise ValueError(f"Unsupported input type: {type(input_term)}")
+
+            if understanding_output: # Generate text
+                yield from self.gen_text(gen_context, max_length=max_think_token_n, do_sample=do_sample, temperature=temperature)
+            else: # Generate image
+                if think:
+                    thought_text_parts = []
+                    for part in self.gen_text(gen_context, max_length=max_think_token_n, do_sample=do_sample, temperature=temperature):
+                        yield part # Stream the thought
+                        thought_text_parts.append(part)
+                    full_thought_text = "".join(thought_text_parts)
+                    if full_thought_text: # Only update if thought was generated
+                        gen_context = self.update_context_text(full_thought_text, gen_context)
+                        cfg_text_context = self.update_context_text(full_thought_text, cfg_text_context)
+
+                img = self.gen_image(
+                    image_shape=current_image_shapes,
+                    gen_context=gen_context,
+                    cfg_text_precontext=cfg_text_context,
+                    cfg_img_precontext=cfg_img_context,
+                    cfg_text_scale=cfg_text_scale,
+                    cfg_img_scale=cfg_img_scale,
+                    cfg_interval=cfg_interval,
+                    timestep_shift=timestep_shift,
+                    num_timesteps=num_timesteps,
+                    cfg_renorm_min=cfg_renorm_min,
+                    cfg_renorm_type=cfg_renorm_type,
+                )
+                yield img
 
     def __call__(
         self,
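
For reference, a minimal standalone sketch (not part of this repository) of what the torch.autocast context manager added above does: inside the region, CUDA matmul-style ops run in bfloat16 while the float32 parameters are left untouched. The helper and variable names below (demo_autocast, linear, x) are illustrative only.

# Minimal sketch, assuming a CUDA-capable GPU and a recent PyTorch build;
# nothing here is taken from inferencer.py beyond the autocast call itself.
import torch

def demo_autocast():
    if not torch.cuda.is_available():
        print("CUDA not available; skipping autocast demo.")
        return
    linear = torch.nn.Linear(8, 8).cuda()    # weights stay float32
    x = torch.randn(4, 8, device="cuda")     # float32 input
    # Same context manager the commit wraps generation in:
    with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        y = linear(x)                        # matmul executes in bfloat16
    print(y.dtype)               # torch.bfloat16
    print(linear.weight.dtype)   # torch.float32 (parameters are not cast in place)

if __name__ == "__main__":
    demo_autocast()

Unlike float16, bfloat16 autocast needs no gradient scaling, and for inference-only code like this generator it mainly reduces activation memory and speeds up matmuls on hardware with native bfloat16 support.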