killTheHostage committed
Commit 367cc7f · Parent: 305557f

Change the call interface and adjust the program execution logic

Files changed (1):
1. mlcd_seg.py (+4, -85)
mlcd_seg.py CHANGED
@@ -38,6 +38,7 @@ from PIL import Image
 from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.generation.utils import GenerateOutput
+from transformers.utils import cached_file
 from safetensors.torch import load_file as safetensors_load
 from .vision_tower import build_vision_tower
 from .vision_resampler import build_vision_resampler
@@ -141,10 +142,8 @@ class MLCDSegMetaModel:
 
     def dispatch_weight(self, config):
         safetensors_set = set()
-        index_folder = Path(getattr(config, "_name_or_path", "./"))
-        index_file = index_folder / "model.safetensors.index.json"
-        if not index_file.exists():
-            os.getenv("")
+        repo = getattr(config, "_name_or_path", "'DeepGlint-AI/MLCD-Seg'")
+        index_file = cached_file(repo, "model.safetensors.index.json")
         with open(index_file, "r") as safetensors_index:
             safetensors_map = json.loads(safetensors_index.read())
         for key, value in safetensors_map["weight_map"].items():
@@ -156,7 +155,7 @@ class MLCDSegMetaModel:
         projector_weight = {}
         text2sam_projection_weight = {}
         for safetensors_file in safetensors_set:
-            temp_load = safetensors_load(index_folder / safetensors_file)
+            temp_load = safetensors_load(cached_file(repo, safetensors_file))
             for key, value in temp_load.items():
                 if key.startswith("model.sam."):
                     sam_weight[key.replace("model.sam.", "")] = value
@@ -174,86 +173,6 @@ class MLCDSegMetaModel:
         vision_tower = vision_tower[0]
         return vision_tower
 
-    # def initialize_vision_modules(self, model_args, fsdp=None):
-    #     vision_tower = model_args.vision_tower
-    #     mm_vision_select_layer = model_args.mm_vision_select_layer
-    #     mm_vision_select_feature = model_args.mm_vision_select_feature
-    #     pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
-    #     mm_patch_merge_type = model_args.mm_patch_merge_type
-
-    #     self.config.mm_vision_tower = vision_tower
-    #     self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
-
-    #     if self.get_vision_tower() is None:
-    #         vision_tower = build_vision_tower(model_args)
-    #         vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
-    #         for k, v in vision_resampler.config.items():
-    #             setattr(self.config, k, v)
-
-    #         if fsdp is not None and len(fsdp) > 0:
-    #             self.vision_tower = [vision_tower]
-    #             self.vision_resampler = [vision_resampler]
-    #         else:
-    #             self.vision_tower = vision_tower
-    #             self.vision_resampler = vision_resampler
-    #     else:
-    #         if fsdp is not None and len(fsdp) > 0:
-    #             vision_resampler = self.vision_resampler[0]
-    #             vision_tower = self.vision_tower[0]
-    #         else:
-    #             vision_resampler = self.vision_resampler
-    #             vision_tower = self.vision_tower
-    #         vision_tower.load_model()
-
-    #         # In case it is frozen by LoRA
-    #         for p in self.vision_resampler.parameters():
-    #             p.requires_grad = True
-
-    #     self.config.use_mm_proj = True
-    #     self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
-    #     self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
-    #     self.config.mm_vision_select_layer = mm_vision_select_layer
-    #     self.config.mm_vision_select_feature = mm_vision_select_feature
-    #     self.config.mm_patch_merge_type = mm_patch_merge_type
-
-    #     for key in vars(model_args):
-    #         if key.startswith('sam_'):
-    #             setattr(self.config, key, getattr(model_args, key))
-
-    #     if not hasattr(self.config, 'add_faster_video'):
-    #         if model_args.add_faster_video:
-    #             embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
-    #             self.faster_token = nn.Parameter(
-    #                 torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
-    #             )
-
-    #     if getattr(self, "mm_projector", None) is None:
-    #         self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
-
-    #         if "unpad" in mm_patch_merge_type:
-    #             embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
-    #             self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
-
-    #         if getattr(self.config, 'sam_path', None) is not None:
-    #             self.sam = build_sam_vit_h(self.config.sam_path)
-    #             self.text2sam_projection = text2sam_projection_layer(self.config)
-    #     else:
-    #         if getattr(self.config, 'sam_path', None) is not None and self.config.sam_path !="":
-    #             self.sam = build_sam_vit_h(self.config.sam_path)
-    #             self.text2sam_projection = text2sam_projection_layer(self.config)
-    #         # In case it is frozen by LoRA
-    #         for p in self.mm_projector.parameters():
-    #             p.requires_grad = True
-
-    #     if pretrain_mm_mlp_adapter is not None:
-    #         mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
-
-    #         def get_w(weights, keyword):
-    #             return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
-
-    #         incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
-    #         incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
-
 
 def unpad_image(tensor, original_size):
     """