KingNish committed on
Commit
e6e1837
·
verified ·
1 Parent(s): 3530257

Update inferencer.py

Browse files
Files changed (1) hide show
  1. inferencer.py +10 -10
inferencer.py CHANGED
@@ -228,8 +228,8 @@ class InterleaveInferencer:
228
  image_shapes=(1024, 1024), # Default, can be overridden by actual input image
229
  ):
230
  gen_context = self.init_gen_context()
231
- cfg_text_context = self.init_gen_context()
232
- cfg_img_context = self.init_gen_context()
233
 
234
  current_image_shapes = image_shapes
235
 
@@ -243,15 +243,16 @@ class InterleaveInferencer:
243
 
244
  for input_term in input_lists:
245
  if isinstance(input_term, str):
 
246
  gen_context = self.update_context_text(input_term, gen_context)
247
- cfg_text_context = self.update_context_text(input_term, cfg_text_context)
248
  cfg_img_context = self.update_context_text(input_term, cfg_img_context)
 
249
  elif isinstance(input_term, Image.Image):
250
- current_image_shapes = input_term.size[::-1] # H, W
251
- use_vae_for_input_image = not understanding_output
252
- gen_context = self.update_context_image(input_term, gen_context, vae=use_vae_for_input_image, vit=True)
253
- cfg_text_context = self.update_context_image(input_term, cfg_text_context, vae=use_vae_for_input_image, vit=True)
254
- # cfg_img_context does not typically see input images
255
  else:
256
  raise ValueError(f"Unsupported input type: {type(input_term)}")
257
 
@@ -266,10 +267,9 @@ class InterleaveInferencer:
266
  full_thought_text = "".join(thought_text_parts)
267
  if full_thought_text: # Only update if thought was generated
268
  gen_context = self.update_context_text(full_thought_text, gen_context)
269
- cfg_text_context = self.update_context_text(full_thought_text, cfg_text_context)
270
 
271
  img = self.gen_image(
272
- image_shape=current_image_shapes,
273
  gen_context=gen_context,
274
  cfg_text_precontext=cfg_text_context,
275
  cfg_img_precontext=cfg_img_context,
 
228
  image_shapes=(1024, 1024), # Default, can be overridden by actual input image
229
  ):
230
  gen_context = self.init_gen_context()
231
+ cfg_text_context = deepcopy(gen_context)
232
+ cfg_img_context = deepcopy(gen_context)
233
 
234
  current_image_shapes = image_shapes
235
 
 
243
 
244
  for input_term in input_lists:
245
  if isinstance(input_term, str):
246
+ cfg_text_context = deepcopy(gen_context)
247
  gen_context = self.update_context_text(input_term, gen_context)
 
248
  cfg_img_context = self.update_context_text(input_term, cfg_img_context)
249
+
250
  elif isinstance(input_term, Image.Image):
251
+ input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term))
252
+ gen_context = self.update_context_image(input_term, gen_context, vae=not understanding_output)
253
+ image_shapes = input_term.size[::-1]
254
+ cfg_text_context = deepcopy(gen_context)
255
+
256
  else:
257
  raise ValueError(f"Unsupported input type: {type(input_term)}")
258
 
 
267
  full_thought_text = "".join(thought_text_parts)
268
  if full_thought_text: # Only update if thought was generated
269
  gen_context = self.update_context_text(full_thought_text, gen_context)
 
270
 
271
  img = self.gen_image(
272
+ image_shape=image_shapes,
273
  gen_context=gen_context,
274
  cfg_text_precontext=cfg_text_context,
275
  cfg_img_precontext=cfg_img_context,