Commit 367cc7f (1 parent: 305557f)

Change the call interface and adjust the program execution logic

Files changed: mlcd_seg.py (+4 -85)
mlcd_seg.py
CHANGED
@@ -38,6 +38,7 @@ from PIL import Image
 from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.generation.utils import GenerateOutput
+from transformers.utils import cached_file
 from safetensors.torch import load_file as safetensors_load
 from .vision_tower import build_vision_tower
 from .vision_resampler import build_vision_resampler
@@ -141,10 +142,8 @@ class MLCDSegMetaModel:
 
     def dispatch_weight(self, config):
         safetensors_set = set()
-
-        index_file =
-        if not index_file.exists():
-            os.getenv("")
+        repo = getattr(config, "_name_or_path", "'DeepGlint-AI/MLCD-Seg'")
+        index_file = cached_file(repo, "model.safetensors.index.json")
         with open(index_file, "r") as safetensors_index:
             safetensors_map = json.loads(safetensors_index.read())
             for key, value in safetensors_map["weight_map"].items():
@@ -156,7 +155,7 @@ class MLCDSegMetaModel:
         projector_weight = {}
         text2sam_projection_weight = {}
         for safetensors_file in safetensors_set:
-            temp_load = safetensors_load(
+            temp_load = safetensors_load(cached_file(repo, safetensors_file))
             for key, value in temp_load.items():
                 if key.startswith("model.sam."):
                     sam_weight[key.replace("model.sam.", "")] = value
@@ -174,86 +173,6 @@ class MLCDSegMetaModel:
         vision_tower = vision_tower[0]
         return vision_tower
 
-    # def initialize_vision_modules(self, model_args, fsdp=None):
-    #     vision_tower = model_args.vision_tower
-    #     mm_vision_select_layer = model_args.mm_vision_select_layer
-    #     mm_vision_select_feature = model_args.mm_vision_select_feature
-    #     pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
-    #     mm_patch_merge_type = model_args.mm_patch_merge_type
-
-    #     self.config.mm_vision_tower = vision_tower
-    #     self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
-
-    #     if self.get_vision_tower() is None:
-    #         vision_tower = build_vision_tower(model_args)
-    #         vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
-    #         for k, v in vision_resampler.config.items():
-    #             setattr(self.config, k, v)
-
-    #         if fsdp is not None and len(fsdp) > 0:
-    #             self.vision_tower = [vision_tower]
-    #             self.vision_resampler = [vision_resampler]
-    #         else:
-    #             self.vision_tower = vision_tower
-    #             self.vision_resampler = vision_resampler
-    #     else:
-    #         if fsdp is not None and len(fsdp) > 0:
-    #             vision_resampler = self.vision_resampler[0]
-    #             vision_tower = self.vision_tower[0]
-    #         else:
-    #             vision_resampler = self.vision_resampler
-    #             vision_tower = self.vision_tower
-    #         vision_tower.load_model()
-
-    #         # In case it is frozen by LoRA
-    #         for p in self.vision_resampler.parameters():
-    #             p.requires_grad = True
-
-    #     self.config.use_mm_proj = True
-    #     self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
-    #     self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
-    #     self.config.mm_vision_select_layer = mm_vision_select_layer
-    #     self.config.mm_vision_select_feature = mm_vision_select_feature
-    #     self.config.mm_patch_merge_type = mm_patch_merge_type
-
-    #     for key in vars(model_args):
-    #         if key.startswith('sam_'):
-    #             setattr(self.config, key, getattr(model_args, key))
-
-    #     if not hasattr(self.config, 'add_faster_video'):
-    #         if model_args.add_faster_video:
-    #             embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
-    #             self.faster_token = nn.Parameter(
-    #                 torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
-    #             )
-
-    #     if getattr(self, "mm_projector", None) is None:
-    #         self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
-
-    #         if "unpad" in mm_patch_merge_type:
-    #             embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
-    #             self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
-
-    #         if getattr(self.config, 'sam_path', None) is not None:
-    #             self.sam = build_sam_vit_h(self.config.sam_path)
-    #             self.text2sam_projection = text2sam_projection_layer(self.config)
-    #     else:
-    #         if getattr(self.config, 'sam_path', None) is not None and self.config.sam_path !="":
-    #             self.sam = build_sam_vit_h(self.config.sam_path)
-    #             self.text2sam_projection = text2sam_projection_layer(self.config)
-    #         # In case it is frozen by LoRA
-    #         for p in self.mm_projector.parameters():
-    #             p.requires_grad = True
-
-    #     if pretrain_mm_mlp_adapter is not None:
-    #         mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
-
-    #         def get_w(weights, keyword):
-    #             return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
-
-    #         incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
-    #         incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
-
 
 def unpad_image(tensor, original_size):
     """
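For reference, the core of the new execution logic is that both the safetensors index and every shard listed in it are now resolved through transformers' cached_file helper, so dispatch_weight works whether the config's _name_or_path points at a local checkout or at a hub repo id. Below is a minimal, self-contained sketch of that flow; the function name load_split_weights and the non-SAM prefixes ("model.mm_projector.", "model.text2sam_projection.") are illustrative assumptions, since only the "model.sam." split is visible in this hunk.

# Sketch of the post-commit weight-dispatch flow (not the module's exact code).
import json

from safetensors.torch import load_file as safetensors_load
from transformers.utils import cached_file


def load_split_weights(repo_or_path):
    # cached_file accepts either a hub repo id or a local directory and
    # returns the resolved local path of the requested file.
    index_file = cached_file(repo_or_path, "model.safetensors.index.json")
    with open(index_file, "r") as f:
        weight_map = json.load(f)["weight_map"]

    # The index maps each parameter name to the shard that stores it;
    # collect the distinct shard files once.
    shard_files = set(weight_map.values())

    sam_weight = {}
    projector_weight = {}
    text2sam_projection_weight = {}
    for shard in shard_files:
        tensors = safetensors_load(cached_file(repo_or_path, shard))
        for key, value in tensors.items():
            if key.startswith("model.sam."):
                sam_weight[key.replace("model.sam.", "")] = value
            elif key.startswith("model.mm_projector."):  # assumed prefix
                projector_weight[key.replace("model.mm_projector.", "")] = value
            elif key.startswith("model.text2sam_projection."):  # assumed prefix
                text2sam_projection_weight[key.replace("model.text2sam_projection.", "")] = value
    return sam_weight, projector_weight, text2sam_projection_weight

Compared with the removed logic, which expected a pre-existing local index file (index_file.exists()), the per-prefix splitting of the state dict is unchanged; only file resolution moves to cached_file, which also pulls missing shards into the local Hugging Face cache on demand.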