killTheHostage committed
Commit 305557f · 1 Parent(s): 363b5fe

Change the call interface and adjust the program execution logic

Files changed (3)
  1. README.md +16 -11
  2. config.json +1 -1
  3. mlcd_seg.py +221 -126
README.md CHANGED
@@ -30,6 +30,7 @@ base_model:
30
 
31
  ## Evaluation
32
 
 
33
  ```python
34
  model_path = "DeepGlint-AI/MLCD-Seg" # or use your local path
35
  mlcd_seg = AutoModel.from_pretrained(
@@ -40,19 +41,23 @@ mlcd_seg = AutoModel.from_pretrained(
40
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
41
  # Assuming you have an image named test.jpg
42
  seg_img = Image.open("test.jpg").convert('RGB')
43
- seg_prompt = "The <image> provides an overview of the picture.\nCould you provide a segmentation mask for the right giraffe in this image?"
44
- pred_mask = model.predict_forward(seg_img, seg_prompt, tokenizer, force_seg=False)
45
  ```
46
 
47
- ## Tips for updating this repo in the future
48
-
49
-
50
- Huggingface uses cache management module code, so manual clearing of cache is required after repo update
51
-
52
-
53
- ```bash
54
- cd ~/.cache/huggingface/modules/transformers_modules
55
- rm mlcd_seg.py vision_projector.py vision_resampler.py vision_tower.py sam.py conversation_mlcd_seg.py
56
  ```
57
 
58
 
 
30
 
31
  ## Evaluation
32
 
33
+ If you just want to use this model, refer to the sample below:
34
  ```python
35
  model_path = "DeepGlint-AI/MLCD-Seg" # or use your local path
36
  mlcd_seg = AutoModel.from_pretrained(
 
41
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
42
  # Assuming you have an image named test.jpg
43
  seg_img = Image.open("test.jpg").convert('RGB')
44
+ seg_prompt = "Could you provide a segmentation mask for the right giraffe in this image?"
45
+ pred_mask = mlcd_seg.seg(seg_img, seg_prompt, tokenizer, force_seg=False)
46
  ```
47
 
48
+ If you want to evaluate this model on a benchmark dataset (e.g. RefCOCO), use the following method instead:
49
+ ```python
50
+ model_path = "DeepGlint-AI/MLCD-Seg" # or use your local path
51
+ mlcd_seg = AutoModel.from_pretrained(
52
+ model_path,
53
+ torch_dtype=torch.float16,
54
+ trust_remote_code=True
55
+ ).cuda()
56
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
57
+ # Assuming you have an image named test.jpg
58
+ seg_img = Image.open("test.jpg").convert('RGB')
59
+ seg_prompt = "Could you provide a segmentation mask for the right giraffe in this image?"
60
+ pred_mask = mlcd_seg.seg(seg_img, seg_prompt, tokenizer, force_seg=True)
61
  ```
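For convenience, here is a self-contained version of the two snippets above with the imports they assume (a minimal sketch; it presumes `torch`, `transformers`, and `Pillow` are installed and that a `test.jpg` exists next to the script):

```python
# Self-contained sketch of the README usage above, with the assumed imports spelled out.
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

model_path = "DeepGlint-AI/MLCD-Seg"  # or use your local path
mlcd_seg = AutoModel.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    trust_remote_code=True,
).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

seg_img = Image.open("test.jpg").convert("RGB")
seg_prompt = "Could you provide a segmentation mask for the right giraffe in this image?"

# force_seg=False lets the model decide whether the prompt calls for segmentation
# (an empty mask is returned if it never emits the [SEG] token); force_seg=True
# always runs the segmentation branch, which benchmark evaluation (e.g. RefCOCO) expects.
pred_mask = mlcd_seg.seg(seg_img, seg_prompt, tokenizer, force_seg=False)
```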
62
 
63
 
config.json CHANGED
@@ -1,5 +1,5 @@
1
  {
2
- "_name_or_path": "DeepGlint-AI/MLCD-Embodied-7B",
3
  "add_faster_video": false,
4
  "add_time_instruction": false,
5
  "architectures": [
 
1
  {
2
+ "_name_or_path": "DeepGlint-AI/MLCD-Seg",
3
  "add_faster_video": false,
4
  "add_time_instruction": false,
5
  "architectures": [
mlcd_seg.py CHANGED
@@ -27,6 +27,7 @@ import random
27
  import ast
28
  import re
29
  import json
 
30
  import numpy as np
31
  import torch
32
  import torch.nn as nn
@@ -42,7 +43,7 @@ from .vision_tower import build_vision_tower
42
  from .vision_resampler import build_vision_resampler
43
  from .vision_projector import build_vision_projector
44
  from .sam import build_sam_vit_h, text2sam_projection_layer
45
- from .conversation_mlcd_seg import default_conversation
46
  from .transform import ResizeLongestSide
47
  from typing import Optional, Any, List, Tuple, Union, Dict
48
 
@@ -140,7 +141,10 @@ class MLCDSegMetaModel:
140
 
141
  def dispatch_weight(self, config):
142
  safetensors_set = set()
143
- index_file = Path(getattr(config, "name_or_path", "./")) / "model.safetensors.index.json"
144
  with open(index_file, "r") as safetensors_index:
145
  safetensors_map = json.loads(safetensors_index.read())
146
  for key, value in safetensors_map["weight_map"].items():
@@ -152,7 +156,7 @@ class MLCDSegMetaModel:
152
  projector_weight = {}
153
  text2sam_projection_weight = {}
154
  for safetensors_file in safetensors_set:
155
- temp_load = safetensors_load(safetensors_file)
156
  for key, value in temp_load.items():
157
  if key.startswith("model.sam."):
158
  sam_weight[key.replace("model.sam.", "")] = value
@@ -170,85 +174,85 @@ class MLCDSegMetaModel:
170
  vision_tower = vision_tower[0]
171
  return vision_tower
172
 
173
- def initialize_vision_modules(self, model_args, fsdp=None):
174
- vision_tower = model_args.vision_tower
175
- mm_vision_select_layer = model_args.mm_vision_select_layer
176
- mm_vision_select_feature = model_args.mm_vision_select_feature
177
- pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
178
- mm_patch_merge_type = model_args.mm_patch_merge_type
179
-
180
- self.config.mm_vision_tower = vision_tower
181
- self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
182
-
183
- if self.get_vision_tower() is None:
184
- vision_tower = build_vision_tower(model_args)
185
- vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
186
- for k, v in vision_resampler.config.items():
187
- setattr(self.config, k, v)
188
-
189
- if fsdp is not None and len(fsdp) > 0:
190
- self.vision_tower = [vision_tower]
191
- self.vision_resampler = [vision_resampler]
192
- else:
193
- self.vision_tower = vision_tower
194
- self.vision_resampler = vision_resampler
195
- else:
196
- if fsdp is not None and len(fsdp) > 0:
197
- vision_resampler = self.vision_resampler[0]
198
- vision_tower = self.vision_tower[0]
199
- else:
200
- vision_resampler = self.vision_resampler
201
- vision_tower = self.vision_tower
202
- vision_tower.load_model()
203
-
204
- # In case it is frozen by LoRA
205
- for p in self.vision_resampler.parameters():
206
- p.requires_grad = True
207
-
208
- self.config.use_mm_proj = True
209
- self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
210
- self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
211
- self.config.mm_vision_select_layer = mm_vision_select_layer
212
- self.config.mm_vision_select_feature = mm_vision_select_feature
213
- self.config.mm_patch_merge_type = mm_patch_merge_type
214
 
215
- for key in vars(model_args):
216
- if key.startswith('sam_'):
217
- setattr(self.config, key, getattr(model_args, key))
218
 
219
- if not hasattr(self.config, 'add_faster_video'):
220
- if model_args.add_faster_video:
221
- embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
222
- self.faster_token = nn.Parameter(
223
- torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
224
- )
225
-
226
- if getattr(self, "mm_projector", None) is None:
227
- self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
228
-
229
- if "unpad" in mm_patch_merge_type:
230
- embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
231
- self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
232
 
233
- if getattr(self.config, 'sam_path', None) is not None:
234
- self.sam = build_sam_vit_h(self.config.sam_path)
235
- self.text2sam_projection = text2sam_projection_layer(self.config)
236
- else:
237
- if getattr(self.config, 'sam_path', None) is not None and self.config.sam_path !="":
238
- self.sam = build_sam_vit_h(self.config.sam_path)
239
- self.text2sam_projection = text2sam_projection_layer(self.config)
240
- # In case it is frozen by LoRA
241
- for p in self.mm_projector.parameters():
242
- p.requires_grad = True
243
 
244
- if pretrain_mm_mlp_adapter is not None:
245
- mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
246
 
247
- def get_w(weights, keyword):
248
- return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
249
 
250
- incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
251
- incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
252
 
253
 
254
  def unpad_image(tensor, original_size):
@@ -774,8 +778,61 @@ class MLCDSegForCausalLM(Qwen2ForCausalLM, MLCDSegMetaForCausalLM):
774
  image_sam_resizes: Optional[List[torch.FloatTensor]] = None,
775
  original_sizes: Optional[List[torch.FloatTensor]] = None,
776
  masks_list: Optional[List[List[torch.FloatTensor]]] = None,
777
- infer: bool = False,
778
- force_seg: bool = True
779
  ) -> Union[Tuple, CausalLMOutputWithPast]:
780
  input_ids_ = input_ids
781
  if inputs_embeds is None:
@@ -832,16 +889,10 @@ class MLCDSegForCausalLM(Qwen2ForCausalLM, MLCDSegMetaForCausalLM):
832
  cache_position=cache_position
833
  )
834
  sam_image_embeddings = self.get_grounding_encoder_embs(grounding_enc_imgs)
835
- if force_seg:
836
- seg_token_mask = self.create_seg_token_mask(input_ids_, old_attention_mask, img_token_num, num_images_batch)
837
- else:
838
- # should be raise NotImplementedError
839
- seg_token_mask = self.create_seg_token_mask(input_ids_, old_attention_mask, img_token_num, num_images_batch)
840
  seg_text_embeds_batch = self.process_hidden_states(output["hidden_states"], seg_token_mask)
841
  pred_masks_batch = self.generate_and_postprocess_masks(seg_text_embeds_batch, sam_image_embeddings, num_images_batch, image_sam_resizes, original_sizes)
842
- if infer:
843
- return {"output":output, "pred_masks":pred_masks_batch}
844
- return MLCDSegOutputWithPast(**output)
845
 
846
  @torch.no_grad()
847
  def generate(
@@ -856,13 +907,29 @@ class MLCDSegForCausalLM(Qwen2ForCausalLM, MLCDSegMetaForCausalLM):
856
  attention_mask = kwargs.pop("attention_mask", None)
857
  if "inputs_embeds" in kwargs:
858
  raise NotImplementedError("`inputs_embeds` is not supported")
859
 
860
- if images is not None:
861
- (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
862
- else:
863
- inputs_embeds = self.get_model().embed_tokens(inputs)
864
-
865
- return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
866
 
867
  def generate_and_postprocess_masks(self, seg_text_embeds_batch, sam_image_embeddings, num_images_batch, image_sam_resizes, original_sizes):
868
  assert len(seg_text_embeds_batch) == len(num_images_batch)
@@ -911,20 +978,18 @@ class MLCDSegForCausalLM(Qwen2ForCausalLM, MLCDSegMetaForCausalLM):
911
  mask = input_ids[i][num_images_batch[i]:] == self.seg_token_idx
912
  seg_token_mask.append(
913
  torch.cat(
914
- [torch.zeros((1, img_token_num[i])).bool().cuda(), mask.unsqueeze(0), torch.zeros((1, max_len-(len(input_ids[i]) + img_token_num[i] - num_images_batch[i]))).bool().cuda()], dim=1
915
  )
916
  )
917
  return torch.cat(seg_token_mask, dim=0)
918
 
919
  def get_grounding_encoder_embs(self, batch_images: torch.FloatTensor):
920
- # with torch.no_grad():
921
  batch_feats = []
922
  for images in batch_images:
923
  batch_feats.append(torch.cat([self._encode_single_image(img) for img in images], dim=0))
924
  return batch_feats
925
 
926
  def _encode_single_image(self, image):
927
- # torch.cuda.empty_cache()
928
  return self.model.sam.image_encoder(image.unsqueeze(0))
929
 
930
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
@@ -937,22 +1002,34 @@ class MLCDSegForCausalLM(Qwen2ForCausalLM, MLCDSegMetaForCausalLM):
937
  inputs["image_sizes"] = image_sizes
938
  return inputs
939
 
940
- def process_prompt(self, text, tokenizer: PreTrainedTokenizer, force_seg=True) -> Dict:
941
- conv = default_conversation.copy()
942
- BEGIN_SIGNAL = "### "
943
- END_SIGNAL = "\n"
944
- roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
945
- # Apply prompt templates
946
- sys_prompt = default_conversation.system + "\n\n"
947
- full_prompt = sys_prompt + BEGIN_SIGNAL + roles["human"] + ": " + text + END_SIGNAL
948
- if force_seg:
949
  full_prompt += BEGIN_SIGNAL + roles["gpt"] + ": It is [SEG]." + END_SIGNAL
950
- full_prompt += BEGIN_SIGNAL
951
- input_ids = torch.stack([gen_image_token(full_prompt, tokenizer, return_tensors='pt')], dim=0)
952
- return dict(
953
- input_ids=input_ids,
954
- labels=None,
955
- )
956
 
957
  def process_images(self, images, image_processor, model_cfg):
958
  image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
@@ -967,11 +1044,11 @@ class MLCDSegForCausalLM(Qwen2ForCausalLM, MLCDSegMetaForCausalLM):
967
  new_images = torch.stack(new_images, dim=0)
968
  return new_images
969
 
970
- def predict_forward(self, image, prompt, tokenizer, force_seg=True):
971
  self.seg_token_idx = tokenizer(DEFAULT_SEG_TOKEN, add_special_tokens=False).input_ids[0]
972
  image_np = np.array(image)
973
  image_sizes = [image.size]
974
- input_ids = self.process_prompt(prompt, tokenizer, force_seg)["input_ids"].to(self.device) # the corresponding device needs to be set here
975
  image_processor = self.get_vision_tower().image_processor
976
  image_tensors = self.process_images([image], image_processor, self.config)
977
  image_np_resize = self.sam_transform.apply_image(image_np)
@@ -994,21 +1071,39 @@ class MLCDSegForCausalLM(Qwen2ForCausalLM, MLCDSegMetaForCausalLM):
994
  image_tensors = [[x_.unsqueeze(dim=0).to(dtype=self.dtype, device=self.device, non_blocking=True) for x_ in image_tensors]]
995
  else:
996
  image_tensors = image_tensors.to(dtype=self.dtype, device='cuda', non_blocking=True)
997
-
998
- with torch.inference_mode():
999
- net_out = self.forward(
1000
- input_ids=input_ids,
1001
- output_hidden_states=True,
1002
- images=image_tensors,
1003
- image_sizes=image_sizes,
1004
- grounding_enc_imgs=[torch.stack(grounding_enc_img_list, dim=0)],
1005
- image_sam_resizes=[image_sam_resize_list],
1006
- original_sizes=[(mask_h, mask_w)],
1007
- infer=True,
1008
- force_seg=force_seg
1009
- )
1010
- pred_mask = net_out["pred_masks"][0]
1011
  return pred_mask
 
 
1012
 
1013
 
1014
  def gen_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
 
27
  import ast
28
  import re
29
  import json
30
+ import os
31
  import numpy as np
32
  import torch
33
  import torch.nn as nn
 
43
  from .vision_resampler import build_vision_resampler
44
  from .vision_projector import build_vision_projector
45
  from .sam import build_sam_vit_h, text2sam_projection_layer
46
+ from .conversation_mlcd_seg import conv_templates, default_conversation
47
  from .transform import ResizeLongestSide
48
  from typing import Optional, Any, List, Tuple, Union, Dict
49
 
 
141
 
142
  def dispatch_weight(self, config):
143
  safetensors_set = set()
144
+ index_folder = Path(getattr(config, "_name_or_path", "./"))
145
+ index_file = index_folder / "model.safetensors.index.json"
146
+ if not index_file.exists():
147
+ os.getenv("")
148
  with open(index_file, "r") as safetensors_index:
149
  safetensors_map = json.loads(safetensors_index.read())
150
  for key, value in safetensors_map["weight_map"].items():
 
156
  projector_weight = {}
157
  text2sam_projection_weight = {}
158
  for safetensors_file in safetensors_set:
159
+ temp_load = safetensors_load(index_folder / safetensors_file)
160
  for key, value in temp_load.items():
161
  if key.startswith("model.sam."):
162
  sam_weight[key.replace("model.sam.", "")] = value
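As an aside on the sharded-checkpoint handling above: `dispatch_weight` resolves the snapshot folder from `config._name_or_path`, reads `model.safetensors.index.json`, and uses its `weight_map` to locate the shard that holds each parameter. A minimal sketch of inspecting such an index (the local path below is hypothetical):

```python
# Sketch: group parameter names by shard file using the safetensors index.
import json
from collections import defaultdict
from pathlib import Path

index_folder = Path("./MLCD-Seg")  # hypothetical local snapshot directory
weight_map = json.loads((index_folder / "model.safetensors.index.json").read_text())["weight_map"]

shards = defaultdict(list)
for param_name, shard_file in weight_map.items():
    shards[shard_file].append(param_name)

for shard_file, names in shards.items():
    print(shard_file, "->", len(names), "tensors")
```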
 
174
  vision_tower = vision_tower[0]
175
  return vision_tower
176
 
177
+ # def initialize_vision_modules(self, model_args, fsdp=None):
178
+ # vision_tower = model_args.vision_tower
179
+ # mm_vision_select_layer = model_args.mm_vision_select_layer
180
+ # mm_vision_select_feature = model_args.mm_vision_select_feature
181
+ # pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
182
+ # mm_patch_merge_type = model_args.mm_patch_merge_type
183
+
184
+ # self.config.mm_vision_tower = vision_tower
185
+ # self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
186
+
187
+ # if self.get_vision_tower() is None:
188
+ # vision_tower = build_vision_tower(model_args)
189
+ # vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
190
+ # for k, v in vision_resampler.config.items():
191
+ # setattr(self.config, k, v)
192
+
193
+ # if fsdp is not None and len(fsdp) > 0:
194
+ # self.vision_tower = [vision_tower]
195
+ # self.vision_resampler = [vision_resampler]
196
+ # else:
197
+ # self.vision_tower = vision_tower
198
+ # self.vision_resampler = vision_resampler
199
+ # else:
200
+ # if fsdp is not None and len(fsdp) > 0:
201
+ # vision_resampler = self.vision_resampler[0]
202
+ # vision_tower = self.vision_tower[0]
203
+ # else:
204
+ # vision_resampler = self.vision_resampler
205
+ # vision_tower = self.vision_tower
206
+ # vision_tower.load_model()
207
+
208
+ # # In case it is frozen by LoRA
209
+ # for p in self.vision_resampler.parameters():
210
+ # p.requires_grad = True
211
+
212
+ # self.config.use_mm_proj = True
213
+ # self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
214
+ # self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
215
+ # self.config.mm_vision_select_layer = mm_vision_select_layer
216
+ # self.config.mm_vision_select_feature = mm_vision_select_feature
217
+ # self.config.mm_patch_merge_type = mm_patch_merge_type
218
 
219
+ # for key in vars(model_args):
220
+ # if key.startswith('sam_'):
221
+ # setattr(self.config, key, getattr(model_args, key))
222
 
223
+ # if not hasattr(self.config, 'add_faster_video'):
224
+ # if model_args.add_faster_video:
225
+ # embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
226
+ # self.faster_token = nn.Parameter(
227
+ # torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
228
+ # )
229
+
230
+ # if getattr(self, "mm_projector", None) is None:
231
+ # self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
232
+
233
+ # if "unpad" in mm_patch_merge_type:
234
+ # embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
235
+ # self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
236
 
237
+ # if getattr(self.config, 'sam_path', None) is not None:
238
+ # self.sam = build_sam_vit_h(self.config.sam_path)
239
+ # self.text2sam_projection = text2sam_projection_layer(self.config)
240
+ # else:
241
+ # if getattr(self.config, 'sam_path', None) is not None and self.config.sam_path !="":
242
+ # self.sam = build_sam_vit_h(self.config.sam_path)
243
+ # self.text2sam_projection = text2sam_projection_layer(self.config)
244
+ # # In case it is frozen by LoRA
245
+ # for p in self.mm_projector.parameters():
246
+ # p.requires_grad = True
247
 
248
+ # if pretrain_mm_mlp_adapter is not None:
249
+ # mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
250
 
251
+ # def get_w(weights, keyword):
252
+ # return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
253
 
254
+ # incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
255
+ # incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
256
 
257
 
258
  def unpad_image(tensor, original_size):
 
778
  image_sam_resizes: Optional[List[torch.FloatTensor]] = None,
779
  original_sizes: Optional[List[torch.FloatTensor]] = None,
780
  masks_list: Optional[List[List[torch.FloatTensor]]] = None,
781
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
782
+ if inputs_embeds is None:
783
+ (
784
+ input_ids,
785
+ position_ids,
786
+ attention_mask,
787
+ past_key_values,
788
+ inputs_embeds,
789
+ labels
790
+ ) = self.prepare_inputs_labels_for_multimodal(
791
+ input_ids,
792
+ position_ids,
793
+ attention_mask,
794
+ past_key_values,
795
+ labels,
796
+ images,
797
+ modalities,
798
+ image_sizes
799
+ )
800
+ output = super().forward(
801
+ input_ids=input_ids,
802
+ attention_mask=attention_mask,
803
+ position_ids=position_ids,
804
+ past_key_values=past_key_values,
805
+ inputs_embeds=inputs_embeds,
806
+ labels=labels,
807
+ use_cache=use_cache,
808
+ output_attentions=output_attentions,
809
+ output_hidden_states=True,
810
+ return_dict=return_dict,
811
+ cache_position=cache_position
812
+ )
813
+ return MLCDSegOutputWithPast(**output)
814
+
815
+ def seg_forward(
816
+ self,
817
+ input_ids: torch.LongTensor = None,
818
+ attention_mask: Optional[torch.Tensor] = None,
819
+ position_ids: Optional[torch.LongTensor] = None,
820
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
821
+ inputs_embeds: Optional[torch.FloatTensor] = None,
822
+ labels: Optional[torch.LongTensor] = None,
823
+ use_cache: Optional[bool] = None,
824
+ output_attentions: Optional[bool] = None,
825
+ output_hidden_states: Optional[bool] = None,
826
+ images: Optional[torch.FloatTensor] = None,
827
+ image_sizes: Optional[List[List[int]]] = None,
828
+ return_dict: Optional[bool] = None,
829
+ modalities: Optional[List[str]] = ["image"],
830
+ dpo_forward: Optional[bool] = False,
831
+ cache_position=None,
832
+ grounding_enc_imgs: Optional[List[torch.FloatTensor]] = None,
833
+ image_sam_resizes: Optional[List[torch.FloatTensor]] = None,
834
+ original_sizes: Optional[List[torch.FloatTensor]] = None,
835
+ masks_list: Optional[List[List[torch.FloatTensor]]] = None,
836
  ) -> Union[Tuple, CausalLMOutputWithPast]:
837
  input_ids_ = input_ids
838
  if inputs_embeds is None:
 
889
  cache_position=cache_position
890
  )
891
  sam_image_embeddings = self.get_grounding_encoder_embs(grounding_enc_imgs)
892
+ seg_token_mask = self.create_seg_token_mask(input_ids_, old_attention_mask, img_token_num, num_images_batch)
 
 
 
 
893
  seg_text_embeds_batch = self.process_hidden_states(output["hidden_states"], seg_token_mask)
894
  pred_masks_batch = self.generate_and_postprocess_masks(seg_text_embeds_batch, sam_image_embeddings, num_images_batch, image_sam_resizes, original_sizes)
895
+ return pred_masks_batch
 
 
896
 
897
  @torch.no_grad()
898
  def generate(
 
907
  attention_mask = kwargs.pop("attention_mask", None)
908
  if "inputs_embeds" in kwargs:
909
  raise NotImplementedError("`inputs_embeds` is not supported")
910
+ (
911
+ inputs,
912
+ position_ids,
913
+ attention_mask,
914
+ _,
915
+ inputs_embeds,
916
+ _,
917
+ old_attention_mask,
918
+ img_token_num,
919
+ num_images_batch
920
+ ) = self.prepare_inputs_labels_for_multimodal(
921
+ inputs,
922
+ position_ids,
923
+ attention_mask,
924
+ None,
925
+ None,
926
+ images,
927
+ image_sizes=image_sizes,
928
+ # batch_pboxes=all_pboxes
929
+ )
930
+ llm_out = super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, output_hidden_states=True, return_dict_in_generate=True, max_length=4096, **kwargs)
931
+ return llm_out.sequences
932
 
933
 
934
  def generate_and_postprocess_masks(self, seg_text_embeds_batch, sam_image_embeddings, num_images_batch, image_sam_resizes, original_sizes):
935
  assert len(seg_text_embeds_batch) == len(num_images_batch)
 
978
  mask = input_ids[i][num_images_batch[i]:] == self.seg_token_idx
979
  seg_token_mask.append(
980
  torch.cat(
981
+ [torch.zeros((1, img_token_num[i])).bool().to(device=self.device), mask.unsqueeze(0), torch.zeros((1, max_len-(len(input_ids[i]) + img_token_num[i] - num_images_batch[i]))).bool().to(device=self.device)], dim=1
982
  )
983
  )
984
  return torch.cat(seg_token_mask, dim=0)
985
 
986
  def get_grounding_encoder_embs(self, batch_images: torch.FloatTensor):
 
987
  batch_feats = []
988
  for images in batch_images:
989
  batch_feats.append(torch.cat([self._encode_single_image(img) for img in images], dim=0))
990
  return batch_feats
991
 
992
  def _encode_single_image(self, image):
 
993
  return self.model.sam.image_encoder(image.unsqueeze(0))
994
 
995
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
 
1002
  inputs["image_sizes"] = image_sizes
1003
  return inputs
1004
 
1005
+ def process_prompt(self, text, tokenizer: PreTrainedTokenizer, stage="gen") -> Dict:
1006
+ if stage.lower() not in ["gen", "seg"]:
1007
+ stage = "seg"
1008
+ if stage.lower() == "gen":
1009
+ conv = conv_templates['qwen_2'].copy()
1010
+ conv.append_message(conv.roles[0], text)
1011
+ conv.append_message(conv.roles[1], None)
1012
+ full_prompt = conv.get_prompt()
1013
+ input_ids = torch.stack([gen_image_token(full_prompt, tokenizer, return_tensors='pt')], dim=0)
1014
+ return dict(
1015
+ input_ids=input_ids,
1016
+ labels=None,
1017
+ )
1018
+ else:
1019
+ conv = default_conversation.copy()
1020
+ BEGIN_SIGNAL = "### "
1021
+ END_SIGNAL = "\n"
1022
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
1023
+ # Apply prompt templates
1024
+ sys_prompt = default_conversation.system + "\n\n" + "The <image> provides an overview of the picture.\n"
1025
+ full_prompt = sys_prompt + BEGIN_SIGNAL + roles["human"] + ": " + text + END_SIGNAL
1026
  full_prompt += BEGIN_SIGNAL + roles["gpt"] + ": It is [SEG]." + END_SIGNAL
1027
+ full_prompt += BEGIN_SIGNAL
1028
+ input_ids = torch.stack([gen_image_token(full_prompt, tokenizer, return_tensors='pt')], dim=0)
1029
+ return dict(
1030
+ input_ids=input_ids,
1031
+ labels=None,
1032
+ )
1033
 
1034
  def process_images(self, images, image_processor, model_cfg):
1035
  image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
 
1044
  new_images = torch.stack(new_images, dim=0)
1045
  return new_images
1046
 
1047
+ def seg(self, image, prompt, tokenizer, force_seg = False):
1048
  self.seg_token_idx = tokenizer(DEFAULT_SEG_TOKEN, add_special_tokens=False).input_ids[0]
1049
  image_np = np.array(image)
1050
  image_sizes = [image.size]
1051
+ input_ids = self.process_prompt(prompt, tokenizer, "gen")["input_ids"].to(self.device)
1052
  image_processor = self.get_vision_tower().image_processor
1053
  image_tensors = self.process_images([image], image_processor, self.config)
1054
  image_np_resize = self.sam_transform.apply_image(image_np)
 
1071
  image_tensors = [[x_.unsqueeze(dim=0).to(dtype=self.dtype, device=self.device, non_blocking=True) for x_ in image_tensors]]
1072
  else:
1073
  image_tensors = image_tensors.to(dtype=self.dtype, device='cuda', non_blocking=True)
1074
+ if not force_seg:
1075
+ attention_mask = torch.ones(input_ids.shape).bool().to(device=self.device)
1076
+ with torch.inference_mode():
1077
+ llm_gen = self.generate(
1078
+ inputs=input_ids,
1079
+ attention_mask=attention_mask,
1080
+ images=image_tensors,
1081
+ image_sizes=image_sizes,
1082
+ grounding_enc_imgs=[torch.stack(grounding_enc_img_list, dim=0)],
1083
+ image_sam_resizes=[image_sam_resize_list],
1084
+ original_sizes=[(mask_h, mask_w)],
1085
+ pad_token_id=tokenizer.eos_token_id
1086
+ )
1087
+ seg_flag = llm_gen == self.seg_token_idx
1088
+ seg_flag = torch.sum(seg_flag.int()).item()
1089
+ if seg_flag > 0:
1090
+ force_seg = True
1091
+ if force_seg:
1092
+ input_ids = self.process_prompt(prompt, tokenizer, "seg")["input_ids"].to(self.device)
1093
+ with torch.inference_mode():
1094
+ net_out = self.seg_forward(
1095
+ input_ids=input_ids,
1096
+ output_hidden_states=True,
1097
+ images=image_tensors,
1098
+ image_sizes=image_sizes,
1099
+ grounding_enc_imgs=[torch.stack(grounding_enc_img_list, dim=0)],
1100
+ image_sam_resizes=[image_sam_resize_list],
1101
+ original_sizes=[(mask_h, mask_w)],
1102
+ )
1103
+ pred_mask = net_out[0]
1104
  return pred_mask
1105
+ else:
1106
+ return torch.zeros([0] + list(image_np.shape[:2]), device=self.device)
1107
 
1108
 
1109
  def gen_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
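Taken together, this commit replaces `predict_forward` with a `seg` interface: with `force_seg=False` the model first generates text and only produces a mask when it emits the `[SEG]` token, otherwise returning an empty `(0, H, W)` tensor; with `force_seg=True` the segmentation branch always runs. A hypothetical follow-up sketch (not part of this commit) that assumes `pred_mask` is the tensor returned by `mlcd_seg.seg(...)` and that mask logits are thresholded at zero:

```python
# Hypothetical post-processing: binarize and save the first predicted mask.
import numpy as np
from PIL import Image

if pred_mask.shape[0] == 0:
    print("No [SEG] token was generated; nothing to save.")
else:
    # squeeze() guards against a possible extra channel dimension in the mask tensor
    mask = (pred_mask[0] > 0).squeeze().cpu().numpy().astype(np.uint8) * 255
    Image.fromarray(mask).save("test_mask.png")
```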