Merge master to main
Browse files- .gitattributes +1 -0
- README.md +67 -3
- config.json +249 -0
- conversation_mlcd_seg.py +579 -0
- merges.txt +0 -0
- mlcd_seg.py +1044 -0
- model-00001-of-00004.safetensors +3 -0
- model-00002-of-00004.safetensors +3 -0
- model-00003-of-00004.safetensors +3 -0
- model-00004-of-00004.safetensors +3 -0
- model.safetensors.index.json +0 -0
- sam.py +1324 -0
- special_tokens_map.json +32 -0
- tokenizer.json +3 -0
- tokenizer_config.json +218 -0
- transform.py +113 -0
- vision_projector.py +57 -0
- vision_resampler.py +23 -0
- vision_tower.py +123 -0
- vocab.json +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
@@ -1,3 +1,67 @@
|
|
1 |
-
---
|
2 |
-
license: apache-2.0
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
base_model:
|
4 |
+
- DeepGlint-AI/MLCD-Embodied-7B
|
5 |
+
---
|
6 |
+
[](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcocog?p=multi-label-cluster-discrimination-for-visual)
|
7 |
+
[](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco-5?p=multi-label-cluster-discrimination-for-visual)
|
8 |
+
[](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco-3?p=multi-label-cluster-discrimination-for-visual)
|
9 |
+
[](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcocog-1?p=multi-label-cluster-discrimination-for-visual)
|
10 |
+
[](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco-8?p=multi-label-cluster-discrimination-for-visual)
|
11 |
+
[](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco-4?p=multi-label-cluster-discrimination-for-visual)
|
12 |
+
[](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco-9?p=multi-label-cluster-discrimination-for-visual)
|
13 |
+
[](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco?p=multi-label-cluster-discrimination-for-visual)
|
14 |
+
[](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco?p=multi-label-cluster-discrimination-for-visual)
|
15 |
+
|
16 |
+
|
17 |
+
## RefCOCO Segmentation Evaluation:
|
18 |
+
|
19 |
+
| Dataset | Split | MLCD-seg-7B | EVF-SAM | GLaMM | VisionLLM v2| LISA |
|
20 |
+
| :-- | :-: | :-: | :-: | :-: | :-: | :-: |
|
21 |
+
| RefCOCO | val | **83.6** | 82.4 | 79.5 | 79.2 | 74.9 |
|
22 |
+
| RefCOCO | testA | **85.3** | 84.2 | 83.2 | 82.3 | 79.1 |
|
23 |
+
| RefCOCO | testB | **81.5** | 80.2 | 76.9 | 77.0 | 72.3 |
|
24 |
+
| RefCOCO+ | val | **79.4** | 76.5 | 72.6 | 68.9 | 65.1 |
|
25 |
+
| RefCOCO+ | testA | **82.9** | 80.0 | 78.7 | 75.8 | 70.8 |
|
26 |
+
| RefCOCO+ | testB | **75.6** | 71.9 | 64.6 | 61.8 | 58.1 |
|
27 |
+
| RefCOCOg | val | **79.7** | 78.2 | 74.2 | 73.3 | 67.9 |
|
28 |
+
| RefCOCOg | test | **80.5** | 78.3 | 74.9 | 74.8 | 70.6 |
|
29 |
+
|
30 |
+
|
31 |
+
## Evaluation
|
32 |
+
|
33 |
+
```python
|
34 |
+
model_path = "DeepGlint-AI/MLCD-Seg" # or use your local path
|
35 |
+
mlcd_seg = AutoModel.from_pretrained(
|
36 |
+
model_path,
|
37 |
+
torch_dtype=torch.float16,
|
38 |
+
trust_remote_code=True
|
39 |
+
).cuda()
|
40 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
|
41 |
+
# Assuming you have an image named test.jpg
|
42 |
+
seg_img = Image.open("test.jpg").convert('RGB')
|
43 |
+
seg_prompt = "The <image> provides an overview of the picture.\nCould you provide a segmentation mask for the right giraffe in this image?"
|
44 |
+
pred_mask = model.predict_forward(seg_img, seg_prompt, tokenizer, force_seg=False)
|
45 |
+
```
|
46 |
+
|
47 |
+
## Tips for updating this repo in the future
|
48 |
+
|
49 |
+
|
50 |
+
Huggingface uses cache management module code, so manual clearing of cache is required after repo update
|
51 |
+
|
52 |
+
|
53 |
+
```bash
|
54 |
+
cd ~/.cache/huggingface/modules/transformers_modules
|
55 |
+
rm mlcd_seg.py vision_projector.py vision_resampler.py vision_tower.py sam.py conversation_mlcd_seg.py
|
56 |
+
```
|
57 |
+
|
58 |
+
|
59 |
+
## Citations
|
60 |
+
```
|
61 |
+
@misc{mlcdseg_wukun,
|
62 |
+
author = {Wu, Kun and Xie, Yin and Zhou, Xinyu and An, Xiang, and Deng, Jiankang, and Jie, Yu},
|
63 |
+
title = {MLCD-Seg},
|
64 |
+
year = {2025},
|
65 |
+
url = {https://github.com/deepglint/unicom/tree/main/downstream},
|
66 |
+
}
|
67 |
+
```
|
config.json
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "DeepGlint-AI/MLCD-Embodied-7B",
|
3 |
+
"add_faster_video": false,
|
4 |
+
"add_time_instruction": false,
|
5 |
+
"architectures": [
|
6 |
+
"MLCDSegForCausalLM"
|
7 |
+
],
|
8 |
+
"auto_map": {
|
9 |
+
"AutoConfig": "mlcd_seg.MLCDSegConfig",
|
10 |
+
"AutoModel": "mlcd_seg.MLCDSegForCausalLM",
|
11 |
+
"AutoModelForCausalLM": "mlcd_seg.MLCDSegForCausalLM"
|
12 |
+
},
|
13 |
+
"attn_implementation": "flash_attention_2",
|
14 |
+
"attention_dropout": 0.0,
|
15 |
+
"bos_token_id": 151643,
|
16 |
+
"eos_token_id": 151645,
|
17 |
+
"faster_token_stride": 10,
|
18 |
+
"force_sample": false,
|
19 |
+
"hidden_act": "silu",
|
20 |
+
"hidden_size": 3584,
|
21 |
+
"image_aspect_ratio": "anyres",
|
22 |
+
"image_crop_resolution": null,
|
23 |
+
"image_grid_pinpoints": [
|
24 |
+
[
|
25 |
+
336,
|
26 |
+
336
|
27 |
+
],
|
28 |
+
[
|
29 |
+
336,
|
30 |
+
672
|
31 |
+
],
|
32 |
+
[
|
33 |
+
336,
|
34 |
+
1008
|
35 |
+
],
|
36 |
+
[
|
37 |
+
336,
|
38 |
+
1344
|
39 |
+
],
|
40 |
+
[
|
41 |
+
336,
|
42 |
+
1680
|
43 |
+
],
|
44 |
+
[
|
45 |
+
336,
|
46 |
+
2016
|
47 |
+
],
|
48 |
+
[
|
49 |
+
672,
|
50 |
+
336
|
51 |
+
],
|
52 |
+
[
|
53 |
+
672,
|
54 |
+
672
|
55 |
+
],
|
56 |
+
[
|
57 |
+
672,
|
58 |
+
1008
|
59 |
+
],
|
60 |
+
[
|
61 |
+
672,
|
62 |
+
1344
|
63 |
+
],
|
64 |
+
[
|
65 |
+
672,
|
66 |
+
1680
|
67 |
+
],
|
68 |
+
[
|
69 |
+
672,
|
70 |
+
2016
|
71 |
+
],
|
72 |
+
[
|
73 |
+
1008,
|
74 |
+
336
|
75 |
+
],
|
76 |
+
[
|
77 |
+
1008,
|
78 |
+
672
|
79 |
+
],
|
80 |
+
[
|
81 |
+
1008,
|
82 |
+
1008
|
83 |
+
],
|
84 |
+
[
|
85 |
+
1008,
|
86 |
+
1344
|
87 |
+
],
|
88 |
+
[
|
89 |
+
1008,
|
90 |
+
1680
|
91 |
+
],
|
92 |
+
[
|
93 |
+
1008,
|
94 |
+
2016
|
95 |
+
],
|
96 |
+
[
|
97 |
+
1344,
|
98 |
+
336
|
99 |
+
],
|
100 |
+
[
|
101 |
+
1344,
|
102 |
+
672
|
103 |
+
],
|
104 |
+
[
|
105 |
+
1344,
|
106 |
+
1008
|
107 |
+
],
|
108 |
+
[
|
109 |
+
1344,
|
110 |
+
1344
|
111 |
+
],
|
112 |
+
[
|
113 |
+
1344,
|
114 |
+
1680
|
115 |
+
],
|
116 |
+
[
|
117 |
+
1344,
|
118 |
+
2016
|
119 |
+
],
|
120 |
+
[
|
121 |
+
1680,
|
122 |
+
336
|
123 |
+
],
|
124 |
+
[
|
125 |
+
1680,
|
126 |
+
672
|
127 |
+
],
|
128 |
+
[
|
129 |
+
1680,
|
130 |
+
1008
|
131 |
+
],
|
132 |
+
[
|
133 |
+
1680,
|
134 |
+
1344
|
135 |
+
],
|
136 |
+
[
|
137 |
+
1680,
|
138 |
+
1680
|
139 |
+
],
|
140 |
+
[
|
141 |
+
1680,
|
142 |
+
2016
|
143 |
+
],
|
144 |
+
[
|
145 |
+
2016,
|
146 |
+
336
|
147 |
+
],
|
148 |
+
[
|
149 |
+
2016,
|
150 |
+
672
|
151 |
+
],
|
152 |
+
[
|
153 |
+
2016,
|
154 |
+
1008
|
155 |
+
],
|
156 |
+
[
|
157 |
+
2016,
|
158 |
+
1344
|
159 |
+
],
|
160 |
+
[
|
161 |
+
2016,
|
162 |
+
1680
|
163 |
+
],
|
164 |
+
[
|
165 |
+
2016,
|
166 |
+
2016
|
167 |
+
]
|
168 |
+
],
|
169 |
+
"image_split_resolution": null,
|
170 |
+
"initializer_range": 0.02,
|
171 |
+
"intermediate_size": 18944,
|
172 |
+
"max_position_embeddings": 32768,
|
173 |
+
"max_window_layers": 28,
|
174 |
+
"mm_hidden_size": 1024,
|
175 |
+
"mm_newline_position": "grid",
|
176 |
+
"mm_patch_merge_type": "spatial_unpad",
|
177 |
+
"mm_projector_lr": null,
|
178 |
+
"mm_projector_type": "mlp2x_gelu",
|
179 |
+
"mm_resampler_type": null,
|
180 |
+
"mm_spatial_pool_mode": "bilinear",
|
181 |
+
"mm_spatial_pool_stride": null,
|
182 |
+
"mm_tunable_parts": "mm_vision_tower,mm_mlp_adapter,mm_language_model,sam",
|
183 |
+
"mm_use_im_patch_token": false,
|
184 |
+
"mm_use_im_start_end": false,
|
185 |
+
"mm_vision_select_feature": "patch",
|
186 |
+
"mm_vision_select_layer": -2,
|
187 |
+
"mm_vision_tower_lr": 2e-06,
|
188 |
+
"vision_tower_config": {
|
189 |
+
"_name_or_path": "",
|
190 |
+
"architectures": [
|
191 |
+
"CLIPVisionModel"
|
192 |
+
],
|
193 |
+
"attention_dropout": 0.0,
|
194 |
+
"hidden_act": "quick_gelu",
|
195 |
+
"hidden_size": 1024,
|
196 |
+
"image_size": 336,
|
197 |
+
"initializer_factor": 1.0,
|
198 |
+
"initializer_range": 0.02,
|
199 |
+
"intermediate_size": 4096,
|
200 |
+
"layer_norm_eps": 1e-05,
|
201 |
+
"model_type": "clip_vision_model",
|
202 |
+
"num_attention_heads": 16,
|
203 |
+
"num_channels": 3,
|
204 |
+
"num_hidden_layers": 24,
|
205 |
+
"patch_size": 14,
|
206 |
+
"projection_dim": 1024,
|
207 |
+
"torch_dtype": "float32",
|
208 |
+
"transformers_version": "4.44.0"
|
209 |
+
},
|
210 |
+
"vision_tower_processor": {
|
211 |
+
"crop_size": 336,
|
212 |
+
"do_center_crop": true,
|
213 |
+
"do_normalize": true,
|
214 |
+
"do_resize": true,
|
215 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
216 |
+
"image_mean": [
|
217 |
+
0.48145466,
|
218 |
+
0.4578275,
|
219 |
+
0.40821073
|
220 |
+
],
|
221 |
+
"image_std": [
|
222 |
+
0.26862954,
|
223 |
+
0.26130258,
|
224 |
+
0.27577711
|
225 |
+
],
|
226 |
+
"resample": 3,
|
227 |
+
"size": 336
|
228 |
+
},
|
229 |
+
"model_type": "qwen2",
|
230 |
+
"num_attention_heads": 28,
|
231 |
+
"num_hidden_layers": 28,
|
232 |
+
"num_key_value_heads": 4,
|
233 |
+
"pos_skipping_range": 4096,
|
234 |
+
"rms_norm_eps": 1e-06,
|
235 |
+
"rope_scaling": null,
|
236 |
+
"rope_theta": 1000000.0,
|
237 |
+
"sliding_window": null,
|
238 |
+
"tie_word_embeddings": false,
|
239 |
+
"tokenizer_model_max_length": 32768,
|
240 |
+
"tokenizer_padding_side": "right",
|
241 |
+
"torch_dtype": "bfloat16",
|
242 |
+
"transformers_version": "4.47.0",
|
243 |
+
"use_cache": true,
|
244 |
+
"use_mm_proj": true,
|
245 |
+
"use_pos_skipping": false,
|
246 |
+
"use_sliding_window": false,
|
247 |
+
"vision_tower_pretrained": null,
|
248 |
+
"vocab_size": 151666
|
249 |
+
}
|
conversation_mlcd_seg.py
ADDED
@@ -0,0 +1,579 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
File separation from LLAVA project
|
3 |
+
https://github.com/haotian-liu/LLaVA
|
4 |
+
'''
|
5 |
+
|
6 |
+
|
7 |
+
import dataclasses
|
8 |
+
from enum import auto, Enum
|
9 |
+
from typing import List, Any, Union
|
10 |
+
import re
|
11 |
+
import base64
|
12 |
+
from io import BytesIO
|
13 |
+
from PIL import Image
|
14 |
+
from transformers import AutoTokenizer
|
15 |
+
|
16 |
+
|
17 |
+
class SeparatorStyle(Enum):
|
18 |
+
"""Different separator style."""
|
19 |
+
|
20 |
+
SINGLE = auto()
|
21 |
+
TWO = auto()
|
22 |
+
MPT = auto()
|
23 |
+
PLAIN = auto()
|
24 |
+
CHATML = auto()
|
25 |
+
LLAMA_2 = auto()
|
26 |
+
LLAMA_3 = auto()
|
27 |
+
QWEN = auto()
|
28 |
+
GEMMA = auto()
|
29 |
+
|
30 |
+
|
31 |
+
@dataclasses.dataclass
|
32 |
+
class Conversation:
|
33 |
+
"""A class that keeps all conversation history."""
|
34 |
+
|
35 |
+
system: str
|
36 |
+
roles: List[str]
|
37 |
+
messages: List[List[str]]
|
38 |
+
offset: int
|
39 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
40 |
+
sep: str = "###"
|
41 |
+
sep2: str = None
|
42 |
+
version: str = "Unknown"
|
43 |
+
|
44 |
+
tokenizer_id: str = ""
|
45 |
+
tokenizer: Any = None
|
46 |
+
# Stop criteria (the default one is EOS token)
|
47 |
+
stop_str: Union[str, List[str]] = None
|
48 |
+
# Stops generation if meeting any token in this list
|
49 |
+
stop_token_ids: List[int] = None
|
50 |
+
|
51 |
+
skip_next: bool = False
|
52 |
+
|
53 |
+
def get_prompt(self):
|
54 |
+
messages = self.messages
|
55 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
56 |
+
messages = self.messages.copy()
|
57 |
+
init_role, init_msg = messages[0].copy()
|
58 |
+
init_msg = init_msg[0]
|
59 |
+
if "mmtag" in self.version:
|
60 |
+
init_msg = init_msg.replace("<image>", "").strip()
|
61 |
+
messages[0] = (init_role, init_msg)
|
62 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
63 |
+
messages.insert(1, (self.roles[1], "Received."))
|
64 |
+
elif not init_msg.startswith("<image>"):
|
65 |
+
init_msg = init_msg.replace("<image>", "").strip()
|
66 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
67 |
+
else:
|
68 |
+
messages[0] = (init_role, init_msg)
|
69 |
+
|
70 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
71 |
+
ret = self.system + self.sep
|
72 |
+
for role, message in messages:
|
73 |
+
if message:
|
74 |
+
if type(message) is tuple:
|
75 |
+
message, _, _ = message
|
76 |
+
ret += role + ": " + message + self.sep
|
77 |
+
else:
|
78 |
+
ret += role + ":"
|
79 |
+
|
80 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
81 |
+
seps = [self.sep, self.sep2]
|
82 |
+
ret = self.system + seps[0]
|
83 |
+
for i, (role, message) in enumerate(messages):
|
84 |
+
if message:
|
85 |
+
if type(message) is tuple:
|
86 |
+
message, _, _ = message
|
87 |
+
ret += role + ": " + message + seps[i % 2]
|
88 |
+
else:
|
89 |
+
ret += role + ":"
|
90 |
+
|
91 |
+
elif self.sep_style == SeparatorStyle.CHATML:
|
92 |
+
ret = "" if self.system == "" else self.system + self.sep + "\n"
|
93 |
+
for role, message in messages:
|
94 |
+
if message:
|
95 |
+
if type(message) is tuple:
|
96 |
+
message, images, _ = message
|
97 |
+
message = "<image>" * len(images) + message
|
98 |
+
ret += role + "\n" + message + self.sep + "\n"
|
99 |
+
else:
|
100 |
+
ret += role + "\n"
|
101 |
+
return ret
|
102 |
+
|
103 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
104 |
+
if self.tokenizer is None:
|
105 |
+
raise ValueError("Llama 3 tokenizer is not available. Make sure you have the necessary permissions.")
|
106 |
+
chat_template_messages = [{"role": "system", "content": self.system}]
|
107 |
+
for role, message in messages:
|
108 |
+
if message:
|
109 |
+
if type(message) is tuple:
|
110 |
+
message, images = message
|
111 |
+
message = "<image>" * len(images) + message
|
112 |
+
chat_template_messages.append({"role": role, "content": message})
|
113 |
+
return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
|
114 |
+
|
115 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
116 |
+
ret = self.system + self.sep
|
117 |
+
for role, message in messages:
|
118 |
+
if message:
|
119 |
+
if type(message) is tuple:
|
120 |
+
message, _, _ = message
|
121 |
+
ret += role + message + self.sep
|
122 |
+
else:
|
123 |
+
ret += role
|
124 |
+
|
125 |
+
elif self.sep_style == SeparatorStyle.GEMMA:
|
126 |
+
ret = ""
|
127 |
+
for i, (role, message) in enumerate(messages):
|
128 |
+
assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
|
129 |
+
if message:
|
130 |
+
if type(message) is tuple:
|
131 |
+
message, _, _ = message
|
132 |
+
ret += role + message + self.sep
|
133 |
+
else:
|
134 |
+
ret += role
|
135 |
+
|
136 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2:
|
137 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
|
138 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
139 |
+
ret = ""
|
140 |
+
|
141 |
+
for i, (role, message) in enumerate(messages):
|
142 |
+
if i == 0:
|
143 |
+
assert message, "first message should not be none"
|
144 |
+
assert role == self.roles[0], "first message should come from user"
|
145 |
+
if message:
|
146 |
+
if type(message) is tuple:
|
147 |
+
message, _, _ = message
|
148 |
+
if i == 0:
|
149 |
+
message = wrap_sys(self.system) + message
|
150 |
+
if i % 2 == 0:
|
151 |
+
message = wrap_inst(message)
|
152 |
+
ret += self.sep + message
|
153 |
+
else:
|
154 |
+
ret += " " + message + " " + self.sep2
|
155 |
+
else:
|
156 |
+
ret += ""
|
157 |
+
ret = ret.lstrip(self.sep)
|
158 |
+
|
159 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
160 |
+
seps = [self.sep, self.sep2]
|
161 |
+
ret = self.system
|
162 |
+
for i, (role, message) in enumerate(messages):
|
163 |
+
if message:
|
164 |
+
if type(message) is tuple:
|
165 |
+
message, _, _ = message
|
166 |
+
ret += message + seps[i % 2]
|
167 |
+
else:
|
168 |
+
ret += ""
|
169 |
+
else:
|
170 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
171 |
+
|
172 |
+
return ret
|
173 |
+
|
174 |
+
def append_message(self, role, message):
|
175 |
+
self.messages.append([role, message])
|
176 |
+
|
177 |
+
def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
|
178 |
+
if image_process_mode == "Pad":
|
179 |
+
|
180 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
181 |
+
width, height = pil_img.size
|
182 |
+
if width == height:
|
183 |
+
return pil_img
|
184 |
+
elif width > height:
|
185 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
186 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
187 |
+
return result
|
188 |
+
else:
|
189 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
190 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
191 |
+
return result
|
192 |
+
|
193 |
+
image = expand2square(image)
|
194 |
+
elif image_process_mode in ["Default", "Crop"]:
|
195 |
+
pass
|
196 |
+
elif image_process_mode == "Resize":
|
197 |
+
image = image.resize((336, 336))
|
198 |
+
else:
|
199 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
200 |
+
|
201 |
+
if type(image) is not Image.Image:
|
202 |
+
image = Image.open(image).convert("RGB")
|
203 |
+
|
204 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
205 |
+
aspect_ratio = max_hw / min_hw
|
206 |
+
max_len, min_len = 672, 448
|
207 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
208 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
209 |
+
W, H = image.size
|
210 |
+
if H > W:
|
211 |
+
H, W = longest_edge, shortest_edge
|
212 |
+
else:
|
213 |
+
H, W = shortest_edge, longest_edge
|
214 |
+
image = image.resize((W, H))
|
215 |
+
if return_pil:
|
216 |
+
return image
|
217 |
+
else:
|
218 |
+
buffered = BytesIO()
|
219 |
+
image.save(buffered, format=image_format)
|
220 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
221 |
+
return img_b64_str
|
222 |
+
|
223 |
+
def get_images(self, return_pil=False, return_path=False):
|
224 |
+
images = []
|
225 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
226 |
+
if i % 2 == 0:
|
227 |
+
if type(msg) is tuple:
|
228 |
+
msg, image, image_process_mode = msg
|
229 |
+
if type(image) != list:
|
230 |
+
image = [image]
|
231 |
+
for img in image:
|
232 |
+
if not return_path and self.is_image_file(img):
|
233 |
+
img = self.process_image(img, image_process_mode, return_pil=return_pil)
|
234 |
+
else:
|
235 |
+
images.append(img)
|
236 |
+
return images
|
237 |
+
|
238 |
+
def is_image_file(self, filename):
|
239 |
+
image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"]
|
240 |
+
return any(filename.lower().endswith(ext) for ext in image_extensions)
|
241 |
+
|
242 |
+
def is_video_file(self, filename):
|
243 |
+
video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".mpeg", ".mpg"]
|
244 |
+
return any(filename.lower().endswith(ext) for ext in video_extensions)
|
245 |
+
|
246 |
+
def to_gradio_chatbot(self):
|
247 |
+
ret = []
|
248 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
249 |
+
if i % 2 == 0:
|
250 |
+
if type(msg) is tuple:
|
251 |
+
msg, image, image_process_mode = msg
|
252 |
+
if type(image) != list:
|
253 |
+
image = [image]
|
254 |
+
if len(image) == 1:
|
255 |
+
msg = "<image>\n" + msg.replace("<image>", "").strip()
|
256 |
+
else:
|
257 |
+
msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
|
258 |
+
|
259 |
+
img_str_list = []
|
260 |
+
for img in image:
|
261 |
+
if self.is_image_file(img):
|
262 |
+
img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
|
263 |
+
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" style="max-width: 256px; max-height: 256px; width: auto; height: auto; object-fit: contain;"/>'
|
264 |
+
img_str_list.append(img_str)
|
265 |
+
elif self.is_video_file(img):
|
266 |
+
ret.append(((img,), None))
|
267 |
+
|
268 |
+
msg = msg.strip()
|
269 |
+
img_place_holder = ""
|
270 |
+
for img_str in img_str_list:
|
271 |
+
img_place_holder += f"{img_str}\n\n"
|
272 |
+
|
273 |
+
if len(img_str_list) > 0:
|
274 |
+
msg = f"{img_place_holder}\n\n{msg}"
|
275 |
+
|
276 |
+
if len(msg) > 0:
|
277 |
+
ret.append([msg, None])
|
278 |
+
else:
|
279 |
+
ret.append([msg, None])
|
280 |
+
else:
|
281 |
+
ret[-1][-1] = msg
|
282 |
+
return ret
|
283 |
+
|
284 |
+
def copy(self):
|
285 |
+
return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
|
286 |
+
|
287 |
+
def dict(self):
|
288 |
+
if len(self.get_images()) > 0:
|
289 |
+
return {
|
290 |
+
"system": self.system,
|
291 |
+
"roles": self.roles,
|
292 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
293 |
+
"offset": self.offset,
|
294 |
+
"sep": self.sep,
|
295 |
+
"sep2": self.sep2,
|
296 |
+
}
|
297 |
+
return {
|
298 |
+
"system": self.system,
|
299 |
+
"roles": self.roles,
|
300 |
+
"messages": self.messages,
|
301 |
+
"offset": self.offset,
|
302 |
+
"sep": self.sep,
|
303 |
+
"sep2": self.sep2,
|
304 |
+
}
|
305 |
+
|
306 |
+
|
307 |
+
conv_vicuna_v0 = Conversation(
|
308 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
309 |
+
roles=("Human", "Assistant"),
|
310 |
+
messages=[
|
311 |
+
["Human", "What are the key differences between renewable and non-renewable energy sources?"],
|
312 |
+
[
|
313 |
+
"Assistant",
|
314 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
315 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
316 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
317 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
318 |
+
"renewable and non-renewable energy sources:\n"
|
319 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
320 |
+
"energy sources are finite and will eventually run out.\n"
|
321 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
322 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
323 |
+
"and other negative effects.\n"
|
324 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
325 |
+
"have lower operational costs than non-renewable sources.\n"
|
326 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
327 |
+
"locations than non-renewable sources.\n"
|
328 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
329 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
330 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
331 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
|
332 |
+
],
|
333 |
+
],
|
334 |
+
offset=2,
|
335 |
+
sep_style=SeparatorStyle.SINGLE,
|
336 |
+
sep="###",
|
337 |
+
)
|
338 |
+
|
339 |
+
conv_vicuna_v1 = Conversation(
|
340 |
+
system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
341 |
+
roles=("USER", "ASSISTANT"),
|
342 |
+
version="v1",
|
343 |
+
messages=[],
|
344 |
+
offset=0,
|
345 |
+
sep_style=SeparatorStyle.TWO,
|
346 |
+
sep=" ",
|
347 |
+
sep2="</s>",
|
348 |
+
)
|
349 |
+
|
350 |
+
conv_llama_2 = Conversation(
|
351 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
352 |
+
|
353 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
354 |
+
roles=("USER", "ASSISTANT"),
|
355 |
+
version="llama_v2",
|
356 |
+
messages=[],
|
357 |
+
offset=0,
|
358 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
359 |
+
sep="<s>",
|
360 |
+
sep2="</s>",
|
361 |
+
)
|
362 |
+
|
363 |
+
conv_llava_llama_2 = Conversation(
|
364 |
+
system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
365 |
+
roles=("USER", "ASSISTANT"),
|
366 |
+
version="llama_v2",
|
367 |
+
messages=[],
|
368 |
+
offset=0,
|
369 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
370 |
+
sep="<s>",
|
371 |
+
sep2="</s>",
|
372 |
+
)
|
373 |
+
|
374 |
+
def safe_load_tokenizer(tokenizer_id):
|
375 |
+
try:
|
376 |
+
return AutoTokenizer.from_pretrained(tokenizer_id)
|
377 |
+
except Exception:
|
378 |
+
return None
|
379 |
+
|
380 |
+
conv_llava_llama_3 = Conversation(
|
381 |
+
system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
|
382 |
+
roles=("user", "assistant"),
|
383 |
+
version="llama_v3",
|
384 |
+
messages=[],
|
385 |
+
offset=0,
|
386 |
+
sep="<|eot_id|>",
|
387 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
388 |
+
tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
|
389 |
+
tokenizer=safe_load_tokenizer("meta-llama/Meta-Llama-3-8B-Instruct"),
|
390 |
+
stop_token_ids=[128009],
|
391 |
+
)
|
392 |
+
|
393 |
+
conv_mistral_instruct = Conversation(
|
394 |
+
system="",
|
395 |
+
roles=("USER", "ASSISTANT"),
|
396 |
+
version="llama_v2",
|
397 |
+
messages=[],
|
398 |
+
offset=0,
|
399 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
400 |
+
sep="",
|
401 |
+
sep2="</s>",
|
402 |
+
)
|
403 |
+
|
404 |
+
conv_llava_llama_2_simple = Conversation(
|
405 |
+
system="Answer the questions about the visual content that the user provides.",
|
406 |
+
roles=("USER", "ASSISTANT"),
|
407 |
+
version="llama_v2",
|
408 |
+
messages=[],
|
409 |
+
offset=0,
|
410 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
411 |
+
sep="<s>",
|
412 |
+
sep2="</s>",
|
413 |
+
)
|
414 |
+
|
415 |
+
conv_llava_llama_2_mmtag = Conversation(
|
416 |
+
system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: <Image>visual content</Image>.",
|
417 |
+
roles=("USER", "ASSISTANT"),
|
418 |
+
version="llama_v2_mmtag",
|
419 |
+
messages=[],
|
420 |
+
offset=0,
|
421 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
422 |
+
sep="<s>",
|
423 |
+
sep2="</s>",
|
424 |
+
)
|
425 |
+
|
426 |
+
conv_mpt = Conversation(
|
427 |
+
system="""<|im_start|>system
|
428 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
429 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
430 |
+
version="mpt",
|
431 |
+
messages=[],
|
432 |
+
offset=0,
|
433 |
+
sep_style=SeparatorStyle.MPT,
|
434 |
+
sep="<|im_end|>",
|
435 |
+
)
|
436 |
+
|
437 |
+
conv_qwen = Conversation(
|
438 |
+
system="""<|im_start|>system
|
439 |
+
You are a helpful assistant.""",
|
440 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
441 |
+
version="qwen",
|
442 |
+
messages=[],
|
443 |
+
offset=0,
|
444 |
+
sep_style=SeparatorStyle.CHATML,
|
445 |
+
sep="<|im_end|>",
|
446 |
+
)
|
447 |
+
|
448 |
+
conv_gemma_instruct = Conversation(system="", roles=("<start_of_turn>user\n", "<start_of_turn>model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="<end_of_turn>\n")
|
449 |
+
|
450 |
+
conv_llava_plain = Conversation(
|
451 |
+
system="",
|
452 |
+
roles=("", ""),
|
453 |
+
messages=[],
|
454 |
+
offset=0,
|
455 |
+
sep_style=SeparatorStyle.PLAIN,
|
456 |
+
sep="\n",
|
457 |
+
)
|
458 |
+
|
459 |
+
conv_llava_v0 = Conversation(
|
460 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
461 |
+
roles=("Human", "Assistant"),
|
462 |
+
messages=[],
|
463 |
+
offset=0,
|
464 |
+
sep_style=SeparatorStyle.SINGLE,
|
465 |
+
sep="###",
|
466 |
+
)
|
467 |
+
|
468 |
+
conv_llava_v0_mmtag = Conversation(
|
469 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
470 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
471 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
472 |
+
roles=("Human", "Assistant"),
|
473 |
+
messages=[],
|
474 |
+
offset=0,
|
475 |
+
sep_style=SeparatorStyle.SINGLE,
|
476 |
+
sep="###",
|
477 |
+
version="v0_mmtag",
|
478 |
+
)
|
479 |
+
|
480 |
+
conv_llava_v1 = Conversation(
|
481 |
+
system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
482 |
+
roles=("USER", "ASSISTANT"),
|
483 |
+
version="v1",
|
484 |
+
messages=[],
|
485 |
+
offset=0,
|
486 |
+
sep_style=SeparatorStyle.TWO,
|
487 |
+
sep=" ",
|
488 |
+
sep2="</s>",
|
489 |
+
)
|
490 |
+
|
491 |
+
conv_llava_v1_mmtag = Conversation(
|
492 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
493 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
494 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
495 |
+
roles=("USER", "ASSISTANT"),
|
496 |
+
messages=[],
|
497 |
+
offset=0,
|
498 |
+
sep_style=SeparatorStyle.TWO,
|
499 |
+
sep=" ",
|
500 |
+
sep2="</s>",
|
501 |
+
version="v1_mmtag",
|
502 |
+
)
|
503 |
+
|
504 |
+
conv_mistral_orca = Conversation(
|
505 |
+
system="""<|im_start|>system
|
506 |
+
You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
|
507 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
508 |
+
version="mpt",
|
509 |
+
messages=[],
|
510 |
+
offset=0,
|
511 |
+
sep_style=SeparatorStyle.MPT,
|
512 |
+
sep="<|im_end|>",
|
513 |
+
)
|
514 |
+
|
515 |
+
conv_mistral_zephyr = Conversation(
|
516 |
+
system="""<|system|>
|
517 |
+
You are a helpful AI assistant.""",
|
518 |
+
roles=("<|user|>\n", "<|assistant|>\n"),
|
519 |
+
version="mpt",
|
520 |
+
messages=[],
|
521 |
+
offset=0,
|
522 |
+
sep_style=SeparatorStyle.MPT,
|
523 |
+
sep="</s>",
|
524 |
+
)
|
525 |
+
|
526 |
+
conv_mistral_direct = Conversation(
|
527 |
+
system="""<|im_start|>system
|
528 |
+
Answer the questions.""",
|
529 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
530 |
+
version="mpt",
|
531 |
+
messages=[],
|
532 |
+
offset=0,
|
533 |
+
sep_style=SeparatorStyle.MPT,
|
534 |
+
sep="<|im_end|>",
|
535 |
+
)
|
536 |
+
|
537 |
+
conv_chatml_direct = Conversation(
|
538 |
+
system="""<|im_start|>system
|
539 |
+
Answer the questions.""",
|
540 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
541 |
+
version="mpt",
|
542 |
+
messages=[],
|
543 |
+
offset=0,
|
544 |
+
sep_style=SeparatorStyle.MPT,
|
545 |
+
sep="<|im_end|>",
|
546 |
+
)
|
547 |
+
|
548 |
+
default_conversation = conv_vicuna_v0
|
549 |
+
conv_templates = {
|
550 |
+
"default": conv_vicuna_v0,
|
551 |
+
"v0": conv_vicuna_v0,
|
552 |
+
"v1": conv_vicuna_v1,
|
553 |
+
"vicuna_v1": conv_vicuna_v1,
|
554 |
+
"llama_2": conv_llama_2,
|
555 |
+
"mistral_instruct": conv_mistral_instruct,
|
556 |
+
"mistral_orca": conv_mistral_orca,
|
557 |
+
"mistral_zephyr": conv_mistral_zephyr,
|
558 |
+
"mistral_direct": conv_mistral_direct,
|
559 |
+
"plain": conv_llava_plain,
|
560 |
+
"v0_plain": conv_llava_plain,
|
561 |
+
"chatml_direct": conv_chatml_direct,
|
562 |
+
"llava_v0": conv_llava_v0,
|
563 |
+
"llava_v0_mmtag": conv_llava_v0_mmtag,
|
564 |
+
"llava_v1": conv_llava_v1,
|
565 |
+
"llava_v1_mmtag": conv_llava_v1_mmtag,
|
566 |
+
"llava_llama_2": conv_llava_llama_2,
|
567 |
+
"llava_llama_3": conv_llava_llama_3,
|
568 |
+
"llava_llama_2_simple": conv_llava_llama_2_simple,
|
569 |
+
"llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
|
570 |
+
"llava_mistral_instruct": conv_mistral_instruct,
|
571 |
+
"mpt": conv_mpt,
|
572 |
+
"qwen_1_5": conv_qwen,
|
573 |
+
"qwen_2": conv_qwen,
|
574 |
+
"gemma_instruct": conv_gemma_instruct,
|
575 |
+
}
|
576 |
+
|
577 |
+
|
578 |
+
if __name__ == "__main__":
|
579 |
+
print(default_conversation.get_prompt())
|
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
mlcd_seg.py
ADDED
@@ -0,0 +1,1044 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
File modification from LLAVA project @DeepGlintAI 2025
|
3 |
+
https://github.com/haotian-liu/LLaVA
|
4 |
+
|
5 |
+
origin copyright:
|
6 |
+
|
7 |
+
Copyright 2023 Haotian Liu
|
8 |
+
|
9 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
10 |
+
you may not use this file except in compliance with the License.
|
11 |
+
You may obtain a copy of the License at
|
12 |
+
|
13 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
14 |
+
|
15 |
+
Unless required by applicable law or agreed to in writing, software
|
16 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
17 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
18 |
+
See the License for the specific language governing permissions and
|
19 |
+
limitations under the License.
|
20 |
+
'''
|
21 |
+
|
22 |
+
|
23 |
+
from abc import ABC, abstractmethod
|
24 |
+
|
25 |
+
import math
|
26 |
+
import random
|
27 |
+
import ast
|
28 |
+
import re
|
29 |
+
import json
|
30 |
+
import numpy as np
|
31 |
+
import torch
|
32 |
+
import torch.nn as nn
|
33 |
+
import torch.nn.functional as F
|
34 |
+
from pathlib import Path
|
35 |
+
from dataclasses import dataclass
|
36 |
+
from PIL import Image
|
37 |
+
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer
|
38 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
39 |
+
from transformers.generation.utils import GenerateOutput
|
40 |
+
from safetensors.torch import load_file as safetensors_load
|
41 |
+
from .vision_tower import build_vision_tower
|
42 |
+
from .vision_resampler import build_vision_resampler
|
43 |
+
from .vision_projector import build_vision_projector
|
44 |
+
from .sam import build_sam_vit_h, text2sam_projection_layer
|
45 |
+
from .conversation_mlcd_seg import default_conversation
|
46 |
+
from .transform import ResizeLongestSide
|
47 |
+
from typing import Optional, Any, List, Tuple, Union, Dict
|
48 |
+
|
49 |
+
IGNORE_INDEX = -100
|
50 |
+
IMAGE_TOKEN_INDEX = -200
|
51 |
+
DEFAULT_SEG_TOKEN = "[SEG]"
|
52 |
+
|
53 |
+
IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
|
54 |
+
IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
|
55 |
+
IMG_SIZE = 1024
|
56 |
+
|
57 |
+
def select_best_resolution(original_size, possible_resolutions):
|
58 |
+
"""
|
59 |
+
Selects the best resolution from a list of possible resolutions based on the original size.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
original_size (tuple): The original size of the image in the format (width, height).
|
63 |
+
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
tuple: The best fit resolution in the format (width, height).
|
67 |
+
"""
|
68 |
+
original_width, original_height = original_size
|
69 |
+
best_fit = None
|
70 |
+
max_effective_resolution = 0
|
71 |
+
min_wasted_resolution = float("inf")
|
72 |
+
|
73 |
+
for width, height in possible_resolutions:
|
74 |
+
# Calculate the downscaled size to keep the aspect ratio
|
75 |
+
scale = min(width / original_width, height / original_height)
|
76 |
+
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
|
77 |
+
|
78 |
+
# Calculate effective and wasted resolutions
|
79 |
+
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
|
80 |
+
wasted_resolution = (width * height) - effective_resolution
|
81 |
+
|
82 |
+
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
|
83 |
+
max_effective_resolution = effective_resolution
|
84 |
+
min_wasted_resolution = wasted_resolution
|
85 |
+
best_fit = (width, height)
|
86 |
+
|
87 |
+
return best_fit
|
88 |
+
|
89 |
+
|
90 |
+
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
91 |
+
"""
|
92 |
+
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
image_size (tuple): The size of the input image in the format (width, height).
|
96 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
97 |
+
patch_size (int): The size of each image patch.
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
tuple: The shape of the image patch grid in the format (width, height).
|
101 |
+
"""
|
102 |
+
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
|
103 |
+
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
|
104 |
+
# Use regex to extract the range from the input string
|
105 |
+
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
|
106 |
+
range_start = tuple(map(int, matches[0]))
|
107 |
+
range_end = tuple(map(int, matches[-1]))
|
108 |
+
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
|
109 |
+
grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
|
110 |
+
# Multiply all elements by patch_size
|
111 |
+
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
|
112 |
+
if type(grid_pinpoints) is list:
|
113 |
+
possible_resolutions = grid_pinpoints
|
114 |
+
else:
|
115 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
116 |
+
width, height = select_best_resolution(image_size, possible_resolutions)
|
117 |
+
return width // patch_size, height // patch_size
|
118 |
+
|
119 |
+
|
120 |
+
class MLCDSegMetaModel:
|
121 |
+
|
122 |
+
def __init__(self, config):
|
123 |
+
super(MLCDSegMetaModel, self).__init__(config)
|
124 |
+
|
125 |
+
if hasattr(config, "vision_tower_config"):
|
126 |
+
vision_tower_weight, sam_weight, projector_weight, text2sam_projection_weight = self.dispatch_weight(config)
|
127 |
+
delay_load = getattr(config, "delay_load", False)
|
128 |
+
self.vision_tower = build_vision_tower(config, delay_load=delay_load)
|
129 |
+
self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
|
130 |
+
self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
|
131 |
+
self.vision_tower.vision_tower.load_state_dict(vision_tower_weight)
|
132 |
+
self.mm_projector.load_state_dict(projector_weight)
|
133 |
+
self.sam = build_sam_vit_h()
|
134 |
+
self.sam.load_state_dict(sam_weight)
|
135 |
+
self.text2sam_projection = text2sam_projection_layer(config)
|
136 |
+
self.text2sam_projection.load_state_dict(text2sam_projection_weight)
|
137 |
+
|
138 |
+
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
|
139 |
+
self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
|
140 |
+
|
141 |
+
def dispatch_weight(self, config):
|
142 |
+
safetensors_set = set()
|
143 |
+
index_file = Path(getattr(config, "name_or_path", "./")) / "model.safetensors.index.json"
|
144 |
+
with open(index_file, "r") as safetensors_index:
|
145 |
+
safetensors_map = json.loads(safetensors_index.read())
|
146 |
+
for key, value in safetensors_map["weight_map"].items():
|
147 |
+
if key.startswith("model.vision_tower.vision_tower") or key.startswith("model.sam") or \
|
148 |
+
key.startswith("model.mm_projector") or key.startswith("model.text2sam_projection"):
|
149 |
+
safetensors_set.add(value)
|
150 |
+
vision_tower_weight = {}
|
151 |
+
sam_weight = {}
|
152 |
+
projector_weight = {}
|
153 |
+
text2sam_projection_weight = {}
|
154 |
+
for safetensors_file in safetensors_set:
|
155 |
+
temp_load = safetensors_load(safetensors_file)
|
156 |
+
for key, value in temp_load.items():
|
157 |
+
if key.startswith("model.sam."):
|
158 |
+
sam_weight[key.replace("model.sam.", "")] = value
|
159 |
+
if key.startswith("model.vision_tower.vision_tower."):
|
160 |
+
vision_tower_weight[key.replace("model.vision_tower.vision_tower.", "")] = value
|
161 |
+
if key.startswith("model.mm_projector."):
|
162 |
+
projector_weight[key.replace("model.mm_projector.", "")] = value
|
163 |
+
if key.startswith("model.text2sam_projection."):
|
164 |
+
text2sam_projection_weight[key.replace("model.text2sam_projection.", "")] = value
|
165 |
+
return vision_tower_weight, sam_weight, projector_weight, text2sam_projection_weight
|
166 |
+
|
167 |
+
def get_vision_tower(self):
|
168 |
+
vision_tower = getattr(self, "vision_tower", None)
|
169 |
+
if type(vision_tower) is list:
|
170 |
+
vision_tower = vision_tower[0]
|
171 |
+
return vision_tower
|
172 |
+
|
173 |
+
def initialize_vision_modules(self, model_args, fsdp=None):
|
174 |
+
vision_tower = model_args.vision_tower
|
175 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
176 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
177 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
178 |
+
mm_patch_merge_type = model_args.mm_patch_merge_type
|
179 |
+
|
180 |
+
self.config.mm_vision_tower = vision_tower
|
181 |
+
self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
|
182 |
+
|
183 |
+
if self.get_vision_tower() is None:
|
184 |
+
vision_tower = build_vision_tower(model_args)
|
185 |
+
vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
|
186 |
+
for k, v in vision_resampler.config.items():
|
187 |
+
setattr(self.config, k, v)
|
188 |
+
|
189 |
+
if fsdp is not None and len(fsdp) > 0:
|
190 |
+
self.vision_tower = [vision_tower]
|
191 |
+
self.vision_resampler = [vision_resampler]
|
192 |
+
else:
|
193 |
+
self.vision_tower = vision_tower
|
194 |
+
self.vision_resampler = vision_resampler
|
195 |
+
else:
|
196 |
+
if fsdp is not None and len(fsdp) > 0:
|
197 |
+
vision_resampler = self.vision_resampler[0]
|
198 |
+
vision_tower = self.vision_tower[0]
|
199 |
+
else:
|
200 |
+
vision_resampler = self.vision_resampler
|
201 |
+
vision_tower = self.vision_tower
|
202 |
+
vision_tower.load_model()
|
203 |
+
|
204 |
+
# In case it is frozen by LoRA
|
205 |
+
for p in self.vision_resampler.parameters():
|
206 |
+
p.requires_grad = True
|
207 |
+
|
208 |
+
self.config.use_mm_proj = True
|
209 |
+
self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
|
210 |
+
self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
|
211 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
212 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
213 |
+
self.config.mm_patch_merge_type = mm_patch_merge_type
|
214 |
+
|
215 |
+
for key in vars(model_args):
|
216 |
+
if key.startswith('sam_'):
|
217 |
+
setattr(self.config, key, getattr(model_args, key))
|
218 |
+
|
219 |
+
if not hasattr(self.config, 'add_faster_video'):
|
220 |
+
if model_args.add_faster_video:
|
221 |
+
embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
|
222 |
+
self.faster_token = nn.Parameter(
|
223 |
+
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
|
224 |
+
)
|
225 |
+
|
226 |
+
if getattr(self, "mm_projector", None) is None:
|
227 |
+
self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
|
228 |
+
|
229 |
+
if "unpad" in mm_patch_merge_type:
|
230 |
+
embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
|
231 |
+
self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
|
232 |
+
|
233 |
+
if getattr(self.config, 'sam_path', None) is not None:
|
234 |
+
self.sam = build_sam_vit_h(self.config.sam_path)
|
235 |
+
self.text2sam_projection = text2sam_projection_layer(self.config)
|
236 |
+
else:
|
237 |
+
if getattr(self.config, 'sam_path', None) is not None and self.config.sam_path !="":
|
238 |
+
self.sam = build_sam_vit_h(self.config.sam_path)
|
239 |
+
self.text2sam_projection = text2sam_projection_layer(self.config)
|
240 |
+
# In case it is frozen by LoRA
|
241 |
+
for p in self.mm_projector.parameters():
|
242 |
+
p.requires_grad = True
|
243 |
+
|
244 |
+
if pretrain_mm_mlp_adapter is not None:
|
245 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
|
246 |
+
|
247 |
+
def get_w(weights, keyword):
|
248 |
+
return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
|
249 |
+
|
250 |
+
incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
|
251 |
+
incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
|
252 |
+
|
253 |
+
|
254 |
+
def unpad_image(tensor, original_size):
|
255 |
+
"""
|
256 |
+
Unpads a PyTorch tensor of a padded and resized image.
|
257 |
+
|
258 |
+
Args:
|
259 |
+
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
|
260 |
+
original_size (tuple): The original size of the image (height, width).
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
torch.Tensor: The unpadded image tensor.
|
264 |
+
"""
|
265 |
+
original_width, original_height = original_size
|
266 |
+
current_height, current_width = tensor.shape[1:]
|
267 |
+
|
268 |
+
# Compute aspect ratios
|
269 |
+
original_aspect_ratio = original_width / original_height
|
270 |
+
current_aspect_ratio = current_width / current_height
|
271 |
+
|
272 |
+
# Determine padding size and direction
|
273 |
+
if original_aspect_ratio > current_aspect_ratio:
|
274 |
+
# Padding was added to the height
|
275 |
+
scale_factor = current_width / original_width
|
276 |
+
new_height = int(original_height * scale_factor)
|
277 |
+
padding = (current_height - new_height) // 2
|
278 |
+
unpadded_tensor = tensor[:, padding : current_height - padding, :]
|
279 |
+
else:
|
280 |
+
# Padding was added to the width
|
281 |
+
scale_factor = current_height / original_height
|
282 |
+
new_width = int(original_width * scale_factor)
|
283 |
+
padding = (current_width - new_width) // 2
|
284 |
+
unpadded_tensor = tensor[:, :, padding : current_width - padding]
|
285 |
+
|
286 |
+
return unpadded_tensor
|
287 |
+
|
288 |
+
|
289 |
+
def resize_and_pad_image(image, target_resolution):
|
290 |
+
"""
|
291 |
+
Resize and pad an image to a target resolution while maintaining aspect ratio.
|
292 |
+
|
293 |
+
Args:
|
294 |
+
image (PIL.Image.Image): The input image.
|
295 |
+
target_resolution (tuple): The target resolution (width, height) of the image.
|
296 |
+
|
297 |
+
Returns:
|
298 |
+
PIL.Image.Image: The resized and padded image.
|
299 |
+
"""
|
300 |
+
original_width, original_height = image.size
|
301 |
+
target_width, target_height = target_resolution
|
302 |
+
|
303 |
+
# Determine which dimension (width or height) to fill
|
304 |
+
scale_w = target_width / original_width
|
305 |
+
scale_h = target_height / original_height
|
306 |
+
|
307 |
+
if scale_w < scale_h:
|
308 |
+
# Width will be filled completely
|
309 |
+
new_width = target_width
|
310 |
+
new_height = min(math.ceil(original_height * scale_w), target_height)
|
311 |
+
else:
|
312 |
+
# Height will be filled completely
|
313 |
+
new_height = target_height
|
314 |
+
new_width = min(math.ceil(original_width * scale_h), target_width)
|
315 |
+
|
316 |
+
# Resize the image
|
317 |
+
resized_image = image.resize((new_width, new_height))
|
318 |
+
|
319 |
+
# Create a new image with the target size and paste the resized image onto it
|
320 |
+
new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
|
321 |
+
paste_x = (target_width - new_width) // 2
|
322 |
+
paste_y = (target_height - new_height) // 2
|
323 |
+
new_image.paste(resized_image, (paste_x, paste_y))
|
324 |
+
|
325 |
+
return new_image
|
326 |
+
|
327 |
+
|
328 |
+
def divide_to_patches(image, patch_size):
|
329 |
+
"""
|
330 |
+
Divides an image into patches of a specified size.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
image (PIL.Image.Image): The input image.
|
334 |
+
patch_size (int): The size of each patch.
|
335 |
+
|
336 |
+
Returns:
|
337 |
+
list: A list of PIL.Image.Image objects representing the patches.
|
338 |
+
"""
|
339 |
+
patches = []
|
340 |
+
width, height = image.size
|
341 |
+
for i in range(0, height, patch_size):
|
342 |
+
for j in range(0, width, patch_size):
|
343 |
+
box = (j, i, j + patch_size, i + patch_size)
|
344 |
+
patch = image.crop(box)
|
345 |
+
patches.append(patch)
|
346 |
+
|
347 |
+
return patches
|
348 |
+
|
349 |
+
|
350 |
+
def process_anyres_image(image, processor, grid_pinpoints):
|
351 |
+
"""
|
352 |
+
Process an image with variable resolutions.
|
353 |
+
|
354 |
+
Args:
|
355 |
+
image (PIL.Image.Image): The input image to be processed.
|
356 |
+
processor: The image processor object.
|
357 |
+
grid_pinpoints (str): A string representation of a list of possible resolutions.
|
358 |
+
|
359 |
+
Returns:
|
360 |
+
torch.Tensor: A tensor containing the processed image patches.
|
361 |
+
"""
|
362 |
+
# Convert grid_pinpoints from string to list
|
363 |
+
if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
|
364 |
+
try:
|
365 |
+
patch_size = processor.size[0]
|
366 |
+
except Exception as e:
|
367 |
+
patch_size = processor.size["shortest_edge"]
|
368 |
+
assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
|
369 |
+
# Use regex to extract the range from the input string
|
370 |
+
matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
|
371 |
+
range_start = tuple(map(int, matches[0]))
|
372 |
+
range_end = tuple(map(int, matches[-1]))
|
373 |
+
# Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
|
374 |
+
grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
|
375 |
+
# Multiply all elements by patch_size
|
376 |
+
grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
|
377 |
+
|
378 |
+
if type(grid_pinpoints) is list:
|
379 |
+
possible_resolutions = grid_pinpoints
|
380 |
+
else:
|
381 |
+
possible_resolutions = ast.literal_eval(grid_pinpoints)
|
382 |
+
best_resolution = select_best_resolution(image.size, possible_resolutions)
|
383 |
+
image_padded = resize_and_pad_image(image, best_resolution)
|
384 |
+
|
385 |
+
patches = divide_to_patches(image_padded, processor.crop_size["height"])
|
386 |
+
|
387 |
+
# FIXME: this seems to be a bug that it resizes instead of pad.
|
388 |
+
# but to keep it consistent with previous, i will keep it as it is
|
389 |
+
# TODO: uncomment below to ablate with the padding
|
390 |
+
if isinstance(processor.size, dict):
|
391 |
+
shortest_edge = processor.size["shortest_edge"]
|
392 |
+
else:
|
393 |
+
shortest_edge = min(processor.size)
|
394 |
+
image_original_resize = image.resize((shortest_edge, shortest_edge))
|
395 |
+
# image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
|
396 |
+
# image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
|
397 |
+
|
398 |
+
image_patches = [image_original_resize] + patches
|
399 |
+
image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
|
400 |
+
return torch.stack(image_patches, dim=0)
|
401 |
+
|
402 |
+
|
403 |
+
class MLCDSegMetaForCausalLM(ABC):
|
404 |
+
|
405 |
+
@abstractmethod
|
406 |
+
def get_model(self):
|
407 |
+
pass
|
408 |
+
|
409 |
+
def get_vision_tower(self):
|
410 |
+
return self.get_model().get_vision_tower()
|
411 |
+
|
412 |
+
def get_2dPool(self, image_feature, stride=2):
|
413 |
+
height = width = self.get_vision_tower().num_patches_per_side
|
414 |
+
num_frames, num_tokens, num_dim = image_feature.shape
|
415 |
+
image_feature = image_feature.view(num_frames, height, width, -1)
|
416 |
+
image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
|
417 |
+
# image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
|
418 |
+
if self.config.mm_spatial_pool_mode == "average":
|
419 |
+
image_feature = nn.functional.avg_pool2d(image_feature, stride)
|
420 |
+
elif self.config.mm_spatial_pool_mode == "max":
|
421 |
+
image_feature = nn.functional.max_pool2d(image_feature, stride)
|
422 |
+
elif self.config.mm_spatial_pool_mode == "bilinear":
|
423 |
+
height, width = image_feature.shape[2:]
|
424 |
+
scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
|
425 |
+
image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
|
426 |
+
|
427 |
+
else:
|
428 |
+
raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
|
429 |
+
image_feature = image_feature.permute(0, 2, 3, 1)
|
430 |
+
image_feature = image_feature.view(num_frames, -1, num_dim)
|
431 |
+
return image_feature
|
432 |
+
|
433 |
+
def encode_images(self, images):
|
434 |
+
image_features = self.get_model().get_vision_tower()(images)
|
435 |
+
# image_features = self.get_model().vision_resampler(image_features, images=images)
|
436 |
+
image_features = self.get_model().mm_projector(image_features)
|
437 |
+
return image_features
|
438 |
+
|
439 |
+
def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
|
440 |
+
videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
|
441 |
+
per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096)
|
442 |
+
all_videos_or_images_features = []
|
443 |
+
all_faster_video_features = []
|
444 |
+
cur_mm_spatial_pool_stride = self.config.mm_spatial_pool_stride
|
445 |
+
|
446 |
+
for idx, feat in enumerate(per_videos_or_images_features):
|
447 |
+
|
448 |
+
feat = self.get_model().mm_projector(feat)
|
449 |
+
faster_video_feature = 0
|
450 |
+
slower_img_feat = 0
|
451 |
+
if idx in video_idx_in_batch and cur_mm_spatial_pool_stride > 1:
|
452 |
+
slower_img_feat = self.get_2dPool(feat,cur_mm_spatial_pool_stride)
|
453 |
+
if self.config.add_faster_video:
|
454 |
+
cur_mm_spatial_pool_stride = cur_mm_spatial_pool_stride * 2
|
455 |
+
faster_video_feature = self.get_2dPool(feat,cur_mm_spatial_pool_stride)
|
456 |
+
if slower_img_feat != 0:
|
457 |
+
all_videos_or_images_features.append(slower_img_feat)
|
458 |
+
else:
|
459 |
+
all_videos_or_images_features.append(feat)
|
460 |
+
all_faster_video_features.append(faster_video_feature)
|
461 |
+
return all_videos_or_images_features,all_faster_video_features
|
462 |
+
|
463 |
+
def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
|
464 |
+
vision_tower = self.get_vision_tower()
|
465 |
+
# rank_print(modalities)
|
466 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
467 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
468 |
+
|
469 |
+
if isinstance(modalities, str):
|
470 |
+
modalities = [modalities]
|
471 |
+
|
472 |
+
# import pdb; pdb.set_trace()
|
473 |
+
if type(images) is list or images.ndim == 5:
|
474 |
+
if type(images) is list:
|
475 |
+
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
|
476 |
+
images_list = []
|
477 |
+
for image in images:
|
478 |
+
if image.ndim == 4:
|
479 |
+
images_list.append(image)
|
480 |
+
else:
|
481 |
+
images_list.append(image.unsqueeze(0))
|
482 |
+
concat_images = torch.cat([image for image in images_list], dim=0)
|
483 |
+
split_sizes = [image.shape[0] for image in images_list]
|
484 |
+
encoded_image_features = self.encode_images(concat_images)
|
485 |
+
# image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
|
486 |
+
|
487 |
+
# This is a list, each element is [num_images, patch * patch, dim]
|
488 |
+
# rank_print(f"Concat images : {concat_images.shape}")
|
489 |
+
encoded_image_features = torch.split(encoded_image_features, split_sizes)
|
490 |
+
image_features = []
|
491 |
+
for idx, image_feat in enumerate(encoded_image_features):
|
492 |
+
image_features.append(image_feat)
|
493 |
+
# image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
|
494 |
+
# rank_print(f"Encoded image feats : {[x.shape for x in image_features]}")
|
495 |
+
# image_features = torch.split(image_features, split_sizes, dim=0)
|
496 |
+
mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
|
497 |
+
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
|
498 |
+
mm_newline_position = getattr(self.config, "mm_newline_position", "one_token")
|
499 |
+
|
500 |
+
if mm_patch_merge_type == "flat":
|
501 |
+
image_features = [x.flatten(0, 1) for x in image_features]
|
502 |
+
|
503 |
+
elif mm_patch_merge_type.startswith("spatial"):
|
504 |
+
new_image_features = []
|
505 |
+
for image_idx, image_feature in enumerate(image_features):
|
506 |
+
# FIXME: now assume the image is square, and split to 2x2 patches
|
507 |
+
# num_patches = h * w, where h = w = sqrt(num_patches)
|
508 |
+
# currently image_feature is a tensor of shape (4, num_patches, hidden_size)
|
509 |
+
# we want to first unflatten it to (2, 2, h, w, hidden_size)
|
510 |
+
# rank0_print("At least we are reaching here")
|
511 |
+
# import pdb; pdb.set_trace()
|
512 |
+
if image_feature.shape[0] > 1: # multi patches and multi images operations
|
513 |
+
# rank0_print("Single-images")
|
514 |
+
base_image_feature = image_feature[0]
|
515 |
+
image_feature = image_feature[1:]
|
516 |
+
height = width = self.get_vision_tower().num_patches_per_side
|
517 |
+
assert height * width == base_image_feature.shape[0]
|
518 |
+
|
519 |
+
if "anyres_max" in image_aspect_ratio:
|
520 |
+
matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
|
521 |
+
if matched_anyres_max_num_patches:
|
522 |
+
max_num_patches = int(matched_anyres_max_num_patches.group(1))
|
523 |
+
|
524 |
+
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
|
525 |
+
if hasattr(self.get_vision_tower(), "image_size"):
|
526 |
+
vision_tower_image_size = self.get_vision_tower().image_size
|
527 |
+
else:
|
528 |
+
raise ValueError("vision_tower_image_size is not found in the vision tower.")
|
529 |
+
try:
|
530 |
+
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
|
531 |
+
except Exception as e:
|
532 |
+
num_patch_width, num_patch_height = 2, 2
|
533 |
+
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
|
534 |
+
else:
|
535 |
+
image_feature = image_feature.view(2, 2, height, width, -1)
|
536 |
+
|
537 |
+
if "maxpool2x2" in mm_patch_merge_type:
|
538 |
+
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
539 |
+
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
540 |
+
image_feature = nn.functional.max_pool2d(image_feature, 2)
|
541 |
+
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
542 |
+
elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
|
543 |
+
unit = image_feature.shape[2]
|
544 |
+
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
545 |
+
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
546 |
+
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
547 |
+
c, h, w = image_feature.shape
|
548 |
+
times = math.sqrt(h * w / (max_num_patches * unit**2))
|
549 |
+
if times > 1.1:
|
550 |
+
image_feature = image_feature[None]
|
551 |
+
image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
|
552 |
+
image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
|
553 |
+
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
554 |
+
elif "unpad" in mm_patch_merge_type:
|
555 |
+
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
|
556 |
+
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
|
557 |
+
image_feature = unpad_image(image_feature, image_sizes[image_idx])
|
558 |
+
image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
|
559 |
+
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
|
560 |
+
else:
|
561 |
+
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
|
562 |
+
image_feature = image_feature.flatten(0, 3)
|
563 |
+
if "nobase" in mm_patch_merge_type:
|
564 |
+
pass
|
565 |
+
else:
|
566 |
+
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
|
567 |
+
new_image_features.append(image_feature)
|
568 |
+
else: # single image operations
|
569 |
+
image_feature = image_feature[0]
|
570 |
+
if "unpad" in mm_patch_merge_type:
|
571 |
+
image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
|
572 |
+
|
573 |
+
new_image_features.append(image_feature)
|
574 |
+
image_features = new_image_features
|
575 |
+
else:
|
576 |
+
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
|
577 |
+
else:
|
578 |
+
image_features = self.encode_images(images)
|
579 |
+
|
580 |
+
# TODO: image start / end is not implemented here to support pretraining.
|
581 |
+
if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
|
582 |
+
raise NotImplementedError
|
583 |
+
# rank_print(f"Total images : {len(image_features)}")
|
584 |
+
|
585 |
+
# Let's just add dummy tensors if they do not exist,
|
586 |
+
# it is a headache to deal with None all the time.
|
587 |
+
# But it is not ideal, and if you have a better idea,
|
588 |
+
# please open an issue / submit a PR, thanks.
|
589 |
+
_labels = labels
|
590 |
+
_position_ids = position_ids
|
591 |
+
_attention_mask = attention_mask
|
592 |
+
if attention_mask is None:
|
593 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
594 |
+
else:
|
595 |
+
attention_mask = attention_mask.bool()
|
596 |
+
if position_ids is None:
|
597 |
+
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
|
598 |
+
if labels is None:
|
599 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
600 |
+
|
601 |
+
# remove the padding using attention_mask -- FIXME
|
602 |
+
_input_ids = input_ids
|
603 |
+
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
|
604 |
+
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
|
605 |
+
old_attention_mask = attention_mask.clone().detach()
|
606 |
+
|
607 |
+
new_input_embeds = []
|
608 |
+
new_labels = []
|
609 |
+
cur_image_idx = 0
|
610 |
+
img_token_num = [0 for _ in range(len(input_ids))]
|
611 |
+
num_images_batch = []
|
612 |
+
# rank_print("Inserting Images embedding")
|
613 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
614 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
615 |
+
num_images_batch.append(num_images)
|
616 |
+
# rank0_print(num_images)
|
617 |
+
if num_images == 0:
|
618 |
+
cur_image_features = image_features[cur_image_idx]
|
619 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
620 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
|
621 |
+
new_input_embeds.append(cur_input_embeds)
|
622 |
+
new_labels.append(labels[batch_idx])
|
623 |
+
cur_image_idx += 1
|
624 |
+
continue
|
625 |
+
|
626 |
+
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
|
627 |
+
cur_input_ids_noim = []
|
628 |
+
cur_labels = labels[batch_idx]
|
629 |
+
cur_labels_noim = []
|
630 |
+
for i in range(len(image_token_indices) - 1):
|
631 |
+
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
|
632 |
+
cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
|
633 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
634 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
635 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
636 |
+
cur_new_input_embeds = []
|
637 |
+
cur_new_labels = []
|
638 |
+
|
639 |
+
for i in range(num_images + 1):
|
640 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
641 |
+
cur_new_labels.append(cur_labels_noim[i])
|
642 |
+
if i < num_images:
|
643 |
+
try:
|
644 |
+
cur_image_features = image_features[cur_image_idx]
|
645 |
+
except IndexError:
|
646 |
+
cur_image_features = image_features[cur_image_idx - 1]
|
647 |
+
img_token_num[batch_idx] += image_features[cur_image_idx].shape[0]
|
648 |
+
cur_image_idx += 1
|
649 |
+
cur_new_input_embeds.append(cur_image_features)
|
650 |
+
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
|
651 |
+
|
652 |
+
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
|
653 |
+
|
654 |
+
# import pdb; pdb.set_trace()
|
655 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
656 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
657 |
+
|
658 |
+
new_input_embeds.append(cur_new_input_embeds)
|
659 |
+
new_labels.append(cur_new_labels)
|
660 |
+
|
661 |
+
# Truncate sequences to max length as image embeddings can make the sequence longer
|
662 |
+
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
|
663 |
+
# rank_print("Finishing Inserting")
|
664 |
+
|
665 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
|
666 |
+
new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
|
667 |
+
# TODO: Hard code for control loss spike
|
668 |
+
# if tokenizer_model_max_length is not None:
|
669 |
+
# new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
|
670 |
+
# new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
|
671 |
+
|
672 |
+
# Combine them
|
673 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
674 |
+
batch_size = len(new_input_embeds)
|
675 |
+
|
676 |
+
new_input_embeds_padded = []
|
677 |
+
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
|
678 |
+
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
|
679 |
+
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
|
680 |
+
# rank0_print("Prepare pos id")
|
681 |
+
|
682 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
683 |
+
cur_len = cur_new_embed.shape[0]
|
684 |
+
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
|
685 |
+
new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
|
686 |
+
if cur_len > 0:
|
687 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
688 |
+
attention_mask[i, -cur_len:] = True
|
689 |
+
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
690 |
+
else:
|
691 |
+
new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
|
692 |
+
if cur_len > 0:
|
693 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
694 |
+
attention_mask[i, :cur_len] = True
|
695 |
+
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
|
696 |
+
|
697 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
698 |
+
# rank0_print("tokenizer padding")
|
699 |
+
|
700 |
+
if _labels is None:
|
701 |
+
new_labels = None
|
702 |
+
else:
|
703 |
+
new_labels = new_labels_padded
|
704 |
+
|
705 |
+
if _attention_mask is None:
|
706 |
+
attention_mask = None
|
707 |
+
else:
|
708 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
709 |
+
|
710 |
+
if _position_ids is None:
|
711 |
+
position_ids = None
|
712 |
+
if getattr(self.config, "use_pos_skipping", False) and self.training:
|
713 |
+
position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
|
714 |
+
split_position = random.randint(0, new_input_embeds.size(1))
|
715 |
+
left_add = random.randint(0, self.config.pos_skipping_range)
|
716 |
+
right_add = random.randint(left_add, self.config.pos_skipping_range)
|
717 |
+
position_ids[:, :split_position] += left_add
|
718 |
+
position_ids[:, split_position:] += right_add
|
719 |
+
# import pdb; pdb.set_trace()
|
720 |
+
# rank0_print("Finish preparing")
|
721 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, old_attention_mask, img_token_num, num_images_batch
|
722 |
+
|
723 |
+
|
724 |
+
class MLCDSegConfig(Qwen2Config):
|
725 |
+
model_type = "mlcd_seg"
|
726 |
+
|
727 |
+
|
728 |
+
class MLCDSegModel(MLCDSegMetaModel, Qwen2Model):
|
729 |
+
config_class = MLCDSegConfig
|
730 |
+
|
731 |
+
def __init__(self, config: Qwen2Config):
|
732 |
+
super(MLCDSegModel, self).__init__(config)
|
733 |
+
|
734 |
+
@dataclass
|
735 |
+
class MLCDSegOutputWithPast(CausalLMOutputWithPast):
|
736 |
+
labels: Optional[torch.FloatTensor] = None
|
737 |
+
|
738 |
+
class MLCDSegForCausalLM(Qwen2ForCausalLM, MLCDSegMetaForCausalLM):
|
739 |
+
config_class = MLCDSegConfig
|
740 |
+
|
741 |
+
def __init__(self, config):
|
742 |
+
# super(Qwen2ForCausalLM, self).__init__(config)
|
743 |
+
Qwen2ForCausalLM.__init__(self, config)
|
744 |
+
config.model_type = "mlcd_seg_clm"
|
745 |
+
config.rope_scaling = None
|
746 |
+
|
747 |
+
self.model = MLCDSegModel(config)
|
748 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
749 |
+
# Initialize weights and apply final processing
|
750 |
+
self.post_init()
|
751 |
+
self.sam_transform = ResizeLongestSide(IMG_SIZE)
|
752 |
+
|
753 |
+
def get_model(self):
|
754 |
+
return self.model
|
755 |
+
|
756 |
+
def forward(
|
757 |
+
self,
|
758 |
+
input_ids: torch.LongTensor = None,
|
759 |
+
attention_mask: Optional[torch.Tensor] = None,
|
760 |
+
position_ids: Optional[torch.LongTensor] = None,
|
761 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
762 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
763 |
+
labels: Optional[torch.LongTensor] = None,
|
764 |
+
use_cache: Optional[bool] = None,
|
765 |
+
output_attentions: Optional[bool] = None,
|
766 |
+
output_hidden_states: Optional[bool] = None,
|
767 |
+
images: Optional[torch.FloatTensor] = None,
|
768 |
+
image_sizes: Optional[List[List[int]]] = None,
|
769 |
+
return_dict: Optional[bool] = None,
|
770 |
+
modalities: Optional[List[str]] = ["image"],
|
771 |
+
dpo_forward: Optional[bool] = False,
|
772 |
+
cache_position=None,
|
773 |
+
grounding_enc_imgs: Optional[List[torch.FloatTensor]] = None,
|
774 |
+
image_sam_resizes: Optional[List[torch.FloatTensor]] = None,
|
775 |
+
original_sizes: Optional[List[torch.FloatTensor]] = None,
|
776 |
+
masks_list: Optional[List[List[torch.FloatTensor]]] = None,
|
777 |
+
infer: bool = False,
|
778 |
+
force_seg: bool = True
|
779 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
780 |
+
input_ids_ = input_ids
|
781 |
+
if inputs_embeds is None:
|
782 |
+
(
|
783 |
+
input_ids,
|
784 |
+
position_ids,
|
785 |
+
attention_mask,
|
786 |
+
past_key_values,
|
787 |
+
inputs_embeds,
|
788 |
+
labels,
|
789 |
+
old_attention_mask,
|
790 |
+
img_token_num,
|
791 |
+
num_images_batch
|
792 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
793 |
+
input_ids,
|
794 |
+
position_ids,
|
795 |
+
attention_mask,
|
796 |
+
past_key_values,
|
797 |
+
labels,
|
798 |
+
images,
|
799 |
+
modalities,
|
800 |
+
image_sizes
|
801 |
+
)
|
802 |
+
|
803 |
+
if dpo_forward:
|
804 |
+
outputs = self.model(
|
805 |
+
input_ids=input_ids,
|
806 |
+
attention_mask=attention_mask,
|
807 |
+
position_ids=position_ids,
|
808 |
+
past_key_values=past_key_values,
|
809 |
+
inputs_embeds=inputs_embeds,
|
810 |
+
use_cache=use_cache,
|
811 |
+
output_attentions=output_attentions,
|
812 |
+
output_hidden_states=output_hidden_states,
|
813 |
+
return_dict=return_dict,
|
814 |
+
)
|
815 |
+
|
816 |
+
hidden_states = outputs[0]
|
817 |
+
logits = self.lm_head(hidden_states)
|
818 |
+
return logits, labels
|
819 |
+
|
820 |
+
else:
|
821 |
+
output = super().forward(
|
822 |
+
input_ids=input_ids,
|
823 |
+
attention_mask=attention_mask,
|
824 |
+
position_ids=position_ids,
|
825 |
+
past_key_values=past_key_values,
|
826 |
+
inputs_embeds=inputs_embeds,
|
827 |
+
labels=labels,
|
828 |
+
use_cache=use_cache,
|
829 |
+
output_attentions=output_attentions,
|
830 |
+
output_hidden_states=True,
|
831 |
+
return_dict=return_dict,
|
832 |
+
cache_position=cache_position
|
833 |
+
)
|
834 |
+
sam_image_embeddings = self.get_grounding_encoder_embs(grounding_enc_imgs)
|
835 |
+
if force_seg:
|
836 |
+
seg_token_mask = self.create_seg_token_mask(input_ids_, old_attention_mask, img_token_num, num_images_batch)
|
837 |
+
else:
|
838 |
+
# should be raise NotImplementedError
|
839 |
+
seg_token_mask = self.create_seg_token_mask(input_ids_, old_attention_mask, img_token_num, num_images_batch)
|
840 |
+
seg_text_embeds_batch = self.process_hidden_states(output["hidden_states"], seg_token_mask)
|
841 |
+
pred_masks_batch = self.generate_and_postprocess_masks(seg_text_embeds_batch, sam_image_embeddings, num_images_batch, image_sam_resizes, original_sizes)
|
842 |
+
if infer:
|
843 |
+
return {"output":output, "pred_masks":pred_masks_batch}
|
844 |
+
return MLCDSegOutputWithPast(**output)
|
845 |
+
|
846 |
+
@torch.no_grad()
|
847 |
+
def generate(
|
848 |
+
self,
|
849 |
+
inputs: Optional[torch.Tensor] = None,
|
850 |
+
images: Optional[torch.Tensor] = None,
|
851 |
+
image_sizes: Optional[torch.Tensor] = None,
|
852 |
+
modalities: Optional[List[str]] = ["image"],
|
853 |
+
**kwargs,
|
854 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
855 |
+
position_ids = kwargs.pop("position_ids", None)
|
856 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
857 |
+
if "inputs_embeds" in kwargs:
|
858 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
859 |
+
|
860 |
+
if images is not None:
|
861 |
+
(inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
|
862 |
+
else:
|
863 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
864 |
+
|
865 |
+
return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
|
866 |
+
|
867 |
+
def generate_and_postprocess_masks(self, seg_text_embeds_batch, sam_image_embeddings, num_images_batch, image_sam_resizes, original_sizes):
|
868 |
+
assert len(seg_text_embeds_batch) == len(num_images_batch)
|
869 |
+
|
870 |
+
pred_masks_batch = [] # list()
|
871 |
+
for batch_i, seg_text_embeds in enumerate(seg_text_embeds_batch):
|
872 |
+
num_img = max(1, num_images_batch[batch_i])
|
873 |
+
|
874 |
+
pred_mask_ = torch.empty((0, original_sizes[batch_i][0], original_sizes[batch_i][1]), device=seg_text_embeds.device)
|
875 |
+
for img_i in range(num_img):
|
876 |
+
sparse_embeddings, dense_embeddings = self.model.sam.prompt_encoder(
|
877 |
+
points=None, boxes=None, masks=None, text_embeds=seg_text_embeds.unsqueeze(1)[img_i::num_img,:,:]
|
878 |
+
)
|
879 |
+
sparse_embeddings = sparse_embeddings.to(seg_text_embeds.dtype)
|
880 |
+
|
881 |
+
low_res_masks, _ = self.model.sam.mask_decoder(
|
882 |
+
image_embeddings=sam_image_embeddings[batch_i][img_i].unsqueeze(0),
|
883 |
+
image_pe=self.model.sam.prompt_encoder.get_dense_pe(),
|
884 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
885 |
+
dense_prompt_embeddings=dense_embeddings,
|
886 |
+
multimask_output=False, )
|
887 |
+
pred_mask = self.model.sam.postprocess_masks(
|
888 |
+
low_res_masks, input_size=image_sam_resizes[batch_i][img_i], original_size=original_sizes[batch_i],)
|
889 |
+
pred_mask_ = torch.cat([pred_mask_, pred_mask[:,0]], dim=0)
|
890 |
+
pred_masks_batch.append(pred_mask_)
|
891 |
+
return pred_masks_batch
|
892 |
+
|
893 |
+
def process_hidden_states(self, output_hidden_states, seg_token_mask):
|
894 |
+
hidden_states_ = [self.model.text2sam_projection(output_hidden_states[-1])]
|
895 |
+
hidden_states_ = torch.stack(hidden_states_, dim=-1).sum(dim=-1)
|
896 |
+
seg_text_embeds_batch = []
|
897 |
+
for i, hidden_state_ in enumerate(hidden_states_):
|
898 |
+
# assert hidden_state_.shape[0] == seg_token_mask.shape[1], f"hidden:{hidden_state_.shape}, segtoken:{seg_token_mask.shape}"
|
899 |
+
# seg_text_embeds_batch.append(hidden_state_[seg_token_mask[i]])
|
900 |
+
seg_text_embeds_batch.append(hidden_state_[seg_token_mask[i][:hidden_state_.shape[0]]])
|
901 |
+
return seg_text_embeds_batch
|
902 |
+
|
903 |
+
def create_seg_token_mask(self, input_ids, attention_mask, img_token_num, num_images_batch):
|
904 |
+
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
|
905 |
+
max_len = 0
|
906 |
+
for i, _ in enumerate(input_ids):
|
907 |
+
max_len = max(max_len, len(input_ids[i]) + img_token_num[i] - num_images_batch[i])
|
908 |
+
|
909 |
+
seg_token_mask = []
|
910 |
+
for i, _ in enumerate(input_ids):
|
911 |
+
mask = input_ids[i][num_images_batch[i]:] == self.seg_token_idx
|
912 |
+
seg_token_mask.append(
|
913 |
+
torch.cat(
|
914 |
+
[torch.zeros((1, img_token_num[i])).bool().cuda(), mask.unsqueeze(0), torch.zeros((1, max_len-(len(input_ids[i]) + img_token_num[i] - num_images_batch[i]))).bool().cuda()], dim=1
|
915 |
+
)
|
916 |
+
)
|
917 |
+
return torch.cat(seg_token_mask, dim=0)
|
918 |
+
|
919 |
+
def get_grounding_encoder_embs(self, batch_images: torch.FloatTensor):
|
920 |
+
# with torch.no_grad():
|
921 |
+
batch_feats = []
|
922 |
+
for images in batch_images:
|
923 |
+
batch_feats.append(torch.cat([self._encode_single_image(img) for img in images], dim=0))
|
924 |
+
return batch_feats
|
925 |
+
|
926 |
+
def _encode_single_image(self, image):
|
927 |
+
# torch.cuda.empty_cache()
|
928 |
+
return self.model.sam.image_encoder(image.unsqueeze(0))
|
929 |
+
|
930 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
931 |
+
images = kwargs.pop("images", None)
|
932 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
933 |
+
inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
|
934 |
+
if images is not None:
|
935 |
+
inputs["images"] = images
|
936 |
+
if image_sizes is not None:
|
937 |
+
inputs["image_sizes"] = image_sizes
|
938 |
+
return inputs
|
939 |
+
|
940 |
+
def process_prompt(self, text, tokenizer: PreTrainedTokenizer, force_seg=True) -> Dict:
|
941 |
+
conv = default_conversation.copy()
|
942 |
+
BEGIN_SIGNAL = "### "
|
943 |
+
END_SIGNAL = "\n"
|
944 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
945 |
+
# Apply prompt templates
|
946 |
+
sys_prompt = default_conversation.system + "\n\n"
|
947 |
+
full_prompt = sys_prompt + BEGIN_SIGNAL + roles["human"] + ": " + text + END_SIGNAL
|
948 |
+
if force_seg:
|
949 |
+
full_prompt += BEGIN_SIGNAL + roles["gpt"] + ": It is [SEG]." + END_SIGNAL
|
950 |
+
full_prompt += BEGIN_SIGNAL
|
951 |
+
input_ids = torch.stack([gen_image_token(full_prompt, tokenizer, return_tensors='pt')], dim=0)
|
952 |
+
return dict(
|
953 |
+
input_ids=input_ids,
|
954 |
+
labels=None,
|
955 |
+
)
|
956 |
+
|
957 |
+
def process_images(self, images, image_processor, model_cfg):
|
958 |
+
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
|
959 |
+
new_images = []
|
960 |
+
if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
|
961 |
+
for image in images:
|
962 |
+
image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
|
963 |
+
new_images.append(image)
|
964 |
+
else:
|
965 |
+
return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
|
966 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
967 |
+
new_images = torch.stack(new_images, dim=0)
|
968 |
+
return new_images
|
969 |
+
|
970 |
+
def predict_forward(self, image, prompt, tokenizer, force_seg=True):
|
971 |
+
self.seg_token_idx = tokenizer(DEFAULT_SEG_TOKEN, add_special_tokens=False).input_ids[0]
|
972 |
+
image_np = np.array(image)
|
973 |
+
image_sizes = [image.size]
|
974 |
+
input_ids = self.process_prompt(prompt, tokenizer, force_seg)["input_ids"].to(self.device) # 这里需要设置对应的device
|
975 |
+
image_processor = self.get_vision_tower().image_processor
|
976 |
+
image_tensors = self.process_images([image], image_processor, self.config)
|
977 |
+
image_np_resize = self.sam_transform.apply_image(image_np)
|
978 |
+
original_size_list = [image_np.shape[:2]]
|
979 |
+
image_sam_resize_list = [image_np_resize.shape[:2]]
|
980 |
+
grounding_enc_img_list = [grounding_enc_processor(torch.from_numpy(image_np_resize).permute(2, 0, 1).contiguous()).to(dtype=self.dtype, device=self.device, non_blocking=True)]
|
981 |
+
collect_size = list(set(original_size_list))
|
982 |
+
if len(collect_size) == 0:
|
983 |
+
mask_h, mask_w = 336, 336
|
984 |
+
elif len(collect_size) == 1:
|
985 |
+
mask_h, mask_w = collect_size[0]
|
986 |
+
else:
|
987 |
+
areas = [h*w for (h, w) in collect_size]
|
988 |
+
mask_h, mask_w = collect_size[areas.index(max(areas))]
|
989 |
+
if isinstance(image_tensors, list):
|
990 |
+
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", None)
|
991 |
+
if image_aspect_ratio=="anyres_mul" or image_aspect_ratio=="anyres":
|
992 |
+
image_tensors = [[x_.to(dtype=self.dtype, device=self.device, non_blocking=True)for x_ in image_tensors]]
|
993 |
+
else:
|
994 |
+
image_tensors = [[x_.unsqueeze(dim=0).to(dtype=self.dtype, device=self.device, non_blocking=True) for x_ in image_tensors]]
|
995 |
+
else:
|
996 |
+
image_tensors = image_tensors.to(dtype=self.dtype, device='cuda', non_blocking=True)
|
997 |
+
|
998 |
+
with torch.inference_mode():
|
999 |
+
net_out = self.forward(
|
1000 |
+
input_ids=input_ids,
|
1001 |
+
output_hidden_states=True,
|
1002 |
+
images=image_tensors,
|
1003 |
+
image_sizes=image_sizes,
|
1004 |
+
grounding_enc_imgs=[torch.stack(grounding_enc_img_list, dim=0)],
|
1005 |
+
image_sam_resizes=[image_sam_resize_list],
|
1006 |
+
original_sizes=[(mask_h, mask_w)],
|
1007 |
+
infer=True,
|
1008 |
+
force_seg=force_seg
|
1009 |
+
)
|
1010 |
+
pred_mask = net_out["pred_masks"][0]
|
1011 |
+
return pred_mask
|
1012 |
+
|
1013 |
+
|
1014 |
+
def gen_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
|
1015 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
|
1016 |
+
|
1017 |
+
def insert_separator(X, sep):
|
1018 |
+
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
1019 |
+
|
1020 |
+
input_ids = []
|
1021 |
+
offset = 0
|
1022 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
1023 |
+
offset = 1
|
1024 |
+
input_ids.append(prompt_chunks[0][0])
|
1025 |
+
|
1026 |
+
for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
|
1027 |
+
input_ids.extend(x[offset:])
|
1028 |
+
|
1029 |
+
if return_tensors is not None:
|
1030 |
+
if return_tensors == "pt":
|
1031 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
1032 |
+
raise ValueError(f"Unsupported tensor type: {return_tensors}")
|
1033 |
+
return input_ids
|
1034 |
+
|
1035 |
+
|
1036 |
+
def grounding_enc_processor(x: torch.Tensor) -> torch.Tensor:
|
1037 |
+
x = (x - IMG_MEAN) / IMG_STD
|
1038 |
+
h, w = x.shape[-2:]
|
1039 |
+
x = F.pad(x, (0, IMG_SIZE - w, 0, IMG_SIZE - h))
|
1040 |
+
return x
|
1041 |
+
|
1042 |
+
|
1043 |
+
AutoConfig.register("mlcd_seg", MLCDSegConfig)
|
1044 |
+
AutoModelForCausalLM.register(MLCDSegConfig, MLCDSegForCausalLM)
|
model-00001-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:26445d9fa45cd664abd8fff7c7bbcad04f55ed7e4fffa4a1626ce28cd9db8a67
|
3 |
+
size 4874815168
|
model-00002-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7243468b3338748241548e0d162d45dd5cdca63b73a4db86060f36053e23f065
|
3 |
+
size 4932751008
|
model-00003-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:84b408e9cc3986e0f122644d3ba397e5a92ac492fdebe6dbbe7a5bd9a7fc5955
|
3 |
+
size 4996578680
|
model-00004-of-00004.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a74f572810afc8de9c3c31c8f278fd276c5ece132356b9f29fc945c09e88ff6a
|
3 |
+
size 2345774016
|
model.safetensors.index.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
sam.py
ADDED
@@ -0,0 +1,1324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Merge files from META Sam project @DeepGlint 2025
|
3 |
+
https://github.com/facebookresearch/segment-anything
|
4 |
+
'''
|
5 |
+
|
6 |
+
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from torch import Tensor
|
13 |
+
from typing import List, Dict, Any, Tuple, Type, Optional
|
14 |
+
from functools import partial
|
15 |
+
|
16 |
+
|
17 |
+
def text2sam_projection_layer(config):
|
18 |
+
in_dim, out_dim = config.hidden_size, 256
|
19 |
+
modules = [nn.Linear(in_dim, out_dim)]
|
20 |
+
for _ in range(1, 2):
|
21 |
+
modules.append(nn.GELU())
|
22 |
+
modules.append(nn.Linear(out_dim, out_dim))
|
23 |
+
return nn.Sequential(*modules)
|
24 |
+
|
25 |
+
|
26 |
+
def build_sam_vit_h():
|
27 |
+
return _build_sam(
|
28 |
+
encoder_embed_dim=1280,
|
29 |
+
encoder_depth=32,
|
30 |
+
encoder_num_heads=16,
|
31 |
+
encoder_global_attn_indexes=[7, 15, 23, 31],
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def _build_sam(
|
36 |
+
encoder_embed_dim,
|
37 |
+
encoder_depth,
|
38 |
+
encoder_num_heads,
|
39 |
+
encoder_global_attn_indexes,
|
40 |
+
):
|
41 |
+
prompt_embed_dim = 256
|
42 |
+
image_size = 1024
|
43 |
+
vit_patch_size = 16
|
44 |
+
image_embedding_size = image_size // vit_patch_size
|
45 |
+
sam = Sam(
|
46 |
+
image_encoder=ImageEncoderViT(
|
47 |
+
depth=encoder_depth,
|
48 |
+
embed_dim=encoder_embed_dim,
|
49 |
+
img_size=image_size,
|
50 |
+
mlp_ratio=4,
|
51 |
+
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
52 |
+
num_heads=encoder_num_heads,
|
53 |
+
patch_size=vit_patch_size,
|
54 |
+
qkv_bias=True,
|
55 |
+
use_rel_pos=True,
|
56 |
+
global_attn_indexes=encoder_global_attn_indexes,
|
57 |
+
window_size=14,
|
58 |
+
out_chans=prompt_embed_dim,
|
59 |
+
),
|
60 |
+
prompt_encoder=PromptEncoder(
|
61 |
+
embed_dim=prompt_embed_dim,
|
62 |
+
image_embedding_size=(image_embedding_size, image_embedding_size),
|
63 |
+
input_image_size=(image_size, image_size),
|
64 |
+
mask_in_chans=16,
|
65 |
+
),
|
66 |
+
mask_decoder=MaskDecoder(
|
67 |
+
num_multimask_outputs=3,
|
68 |
+
transformer=TwoWayTransformer(
|
69 |
+
depth=2,
|
70 |
+
embedding_dim=prompt_embed_dim,
|
71 |
+
mlp_dim=2048,
|
72 |
+
num_heads=8,
|
73 |
+
),
|
74 |
+
transformer_dim=prompt_embed_dim,
|
75 |
+
iou_head_depth=3,
|
76 |
+
iou_head_hidden_dim=256,
|
77 |
+
),
|
78 |
+
pixel_mean=[123.675, 116.28, 103.53],
|
79 |
+
pixel_std=[58.395, 57.12, 57.375],
|
80 |
+
)
|
81 |
+
sam.eval()
|
82 |
+
return sam
|
83 |
+
|
84 |
+
|
85 |
+
def window_partition(
|
86 |
+
x: torch.Tensor, window_size: int
|
87 |
+
) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
88 |
+
"""
|
89 |
+
Partition into non-overlapping windows with padding if needed.
|
90 |
+
Args:
|
91 |
+
x (tensor): input tokens with [B, H, W, C].
|
92 |
+
window_size (int): window size.
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
windows: windows after partition with [B * num_windows, window_size, window_size, C].
|
96 |
+
(Hp, Wp): padded height and width before partition
|
97 |
+
"""
|
98 |
+
B, H, W, C = x.shape
|
99 |
+
|
100 |
+
pad_h = (window_size - H % window_size) % window_size
|
101 |
+
pad_w = (window_size - W % window_size) % window_size
|
102 |
+
if pad_h > 0 or pad_w > 0:
|
103 |
+
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
104 |
+
Hp, Wp = H + pad_h, W + pad_w
|
105 |
+
|
106 |
+
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
107 |
+
windows = (
|
108 |
+
x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
109 |
+
)
|
110 |
+
return windows, (Hp, Wp)
|
111 |
+
|
112 |
+
|
113 |
+
def window_unpartition(
|
114 |
+
windows: torch.Tensor,
|
115 |
+
window_size: int,
|
116 |
+
pad_hw: Tuple[int, int],
|
117 |
+
hw: Tuple[int, int],
|
118 |
+
) -> torch.Tensor:
|
119 |
+
"""
|
120 |
+
Window unpartition into original sequences and removing padding.
|
121 |
+
Args:
|
122 |
+
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
|
123 |
+
window_size (int): window size.
|
124 |
+
pad_hw (Tuple): padded height and width (Hp, Wp).
|
125 |
+
hw (Tuple): original height and width (H, W) before padding.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
x: unpartitioned sequences with [B, H, W, C].
|
129 |
+
"""
|
130 |
+
Hp, Wp = pad_hw
|
131 |
+
H, W = hw
|
132 |
+
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
133 |
+
x = windows.view(
|
134 |
+
B, Hp // window_size, Wp // window_size, window_size, window_size, -1
|
135 |
+
)
|
136 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
137 |
+
|
138 |
+
if Hp > H or Wp > W:
|
139 |
+
x = x[:, :H, :W, :].contiguous()
|
140 |
+
return x
|
141 |
+
|
142 |
+
|
143 |
+
class CommonMLP(nn.Module):
|
144 |
+
def __init__(
|
145 |
+
self,
|
146 |
+
embedding_dim: int,
|
147 |
+
mlp_dim: int,
|
148 |
+
act: Type[nn.Module] = nn.GELU,
|
149 |
+
) -> None:
|
150 |
+
super().__init__()
|
151 |
+
self.lin1 = nn.Linear(embedding_dim, mlp_dim)
|
152 |
+
self.lin2 = nn.Linear(mlp_dim, embedding_dim)
|
153 |
+
self.act = act()
|
154 |
+
|
155 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
156 |
+
return self.lin2(self.act(self.lin1(x)))
|
157 |
+
|
158 |
+
|
159 |
+
class LayerNorm2d(nn.Module):
|
160 |
+
def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
|
161 |
+
super().__init__()
|
162 |
+
self.weight = nn.Parameter(torch.ones(num_channels))
|
163 |
+
self.bias = nn.Parameter(torch.zeros(num_channels))
|
164 |
+
self.eps = eps
|
165 |
+
|
166 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
167 |
+
u = x.mean(1, keepdim=True)
|
168 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
169 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
170 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
171 |
+
return x
|
172 |
+
|
173 |
+
|
174 |
+
class Attention(nn.Module):
|
175 |
+
"""
|
176 |
+
An attention layer that allows for downscaling the size of the embedding
|
177 |
+
after projection to queries, keys, and values.
|
178 |
+
"""
|
179 |
+
|
180 |
+
def __init__(
|
181 |
+
self,
|
182 |
+
embedding_dim: int,
|
183 |
+
num_heads: int,
|
184 |
+
downsample_rate: int = 1,
|
185 |
+
) -> None:
|
186 |
+
super().__init__()
|
187 |
+
self.embedding_dim = embedding_dim
|
188 |
+
self.internal_dim = embedding_dim // downsample_rate
|
189 |
+
self.num_heads = num_heads
|
190 |
+
assert (
|
191 |
+
self.internal_dim % num_heads == 0
|
192 |
+
), "num_heads must divide embedding_dim."
|
193 |
+
|
194 |
+
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
195 |
+
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
|
196 |
+
self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
|
197 |
+
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
198 |
+
|
199 |
+
def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
|
200 |
+
b, n, c = x.shape
|
201 |
+
x = x.reshape(b, n, num_heads, c // num_heads)
|
202 |
+
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
203 |
+
|
204 |
+
def _recombine_heads(self, x: Tensor) -> Tensor:
|
205 |
+
b, n_heads, n_tokens, c_per_head = x.shape
|
206 |
+
x = x.transpose(1, 2)
|
207 |
+
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
208 |
+
|
209 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
210 |
+
# Input projections
|
211 |
+
q = self.q_proj(q)
|
212 |
+
k = self.k_proj(k)
|
213 |
+
v = self.v_proj(v)
|
214 |
+
|
215 |
+
# Separate into heads
|
216 |
+
q = self._separate_heads(q, self.num_heads)
|
217 |
+
k = self._separate_heads(k, self.num_heads)
|
218 |
+
v = self._separate_heads(v, self.num_heads)
|
219 |
+
|
220 |
+
# Attention
|
221 |
+
_, _, _, c_per_head = q.shape
|
222 |
+
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
|
223 |
+
attn = attn / math.sqrt(c_per_head)
|
224 |
+
attn = torch.softmax(attn, dim=-1)
|
225 |
+
|
226 |
+
# Get output
|
227 |
+
out = attn @ v
|
228 |
+
out = self._recombine_heads(out)
|
229 |
+
out = self.out_proj(out)
|
230 |
+
|
231 |
+
return out
|
232 |
+
|
233 |
+
|
234 |
+
class TwoWayTransformer(nn.Module):
|
235 |
+
def __init__(
|
236 |
+
self,
|
237 |
+
depth: int,
|
238 |
+
embedding_dim: int,
|
239 |
+
num_heads: int,
|
240 |
+
mlp_dim: int,
|
241 |
+
activation: Type[nn.Module] = nn.ReLU,
|
242 |
+
attention_downsample_rate: int = 2,
|
243 |
+
) -> None:
|
244 |
+
"""
|
245 |
+
A transformer decoder that attends to an input image using
|
246 |
+
queries whose positional embedding is supplied.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
depth (int): number of layers in the transformer
|
250 |
+
embedding_dim (int): the channel dimension for the input embeddings
|
251 |
+
num_heads (int): the number of heads for multihead attention. Must
|
252 |
+
divide embedding_dim
|
253 |
+
mlp_dim (int): the channel dimension internal to the MLP block
|
254 |
+
activation (nn.Module): the activation to use in the MLP block
|
255 |
+
"""
|
256 |
+
super().__init__()
|
257 |
+
self.depth = depth
|
258 |
+
self.embedding_dim = embedding_dim
|
259 |
+
self.num_heads = num_heads
|
260 |
+
self.mlp_dim = mlp_dim
|
261 |
+
self.layers = nn.ModuleList()
|
262 |
+
|
263 |
+
for i in range(depth):
|
264 |
+
self.layers.append(
|
265 |
+
TwoWayAttentionBlock(
|
266 |
+
embedding_dim=embedding_dim,
|
267 |
+
num_heads=num_heads,
|
268 |
+
mlp_dim=mlp_dim,
|
269 |
+
activation=activation,
|
270 |
+
attention_downsample_rate=attention_downsample_rate,
|
271 |
+
skip_first_layer_pe=(i == 0),
|
272 |
+
)
|
273 |
+
)
|
274 |
+
|
275 |
+
self.final_attn_token_to_image = Attention(
|
276 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
277 |
+
)
|
278 |
+
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
279 |
+
|
280 |
+
def forward(
|
281 |
+
self,
|
282 |
+
image_embedding: Tensor,
|
283 |
+
image_pe: Tensor,
|
284 |
+
point_embedding: Tensor,
|
285 |
+
) -> Tuple[Tensor, Tensor]:
|
286 |
+
"""
|
287 |
+
Args:
|
288 |
+
image_embedding (torch.Tensor): image to attend to. Should be shape
|
289 |
+
B x embedding_dim x h x w for any h and w.
|
290 |
+
image_pe (torch.Tensor): the positional encoding to add to the image. Must
|
291 |
+
have the same shape as image_embedding.
|
292 |
+
point_embedding (torch.Tensor): the embedding to add to the query points.
|
293 |
+
Must have shape B x N_points x embedding_dim for any N_points.
|
294 |
+
|
295 |
+
Returns:
|
296 |
+
torch.Tensor: the processed point_embedding
|
297 |
+
torch.Tensor: the processed image_embedding
|
298 |
+
"""
|
299 |
+
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
300 |
+
bs, c, h, w = image_embedding.shape
|
301 |
+
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
|
302 |
+
image_pe = image_pe.flatten(2).permute(0, 2, 1)
|
303 |
+
|
304 |
+
# Prepare queries
|
305 |
+
queries = point_embedding
|
306 |
+
keys = image_embedding
|
307 |
+
|
308 |
+
# Apply transformer blocks and final layernorm
|
309 |
+
for layer in self.layers:
|
310 |
+
queries, keys = layer(
|
311 |
+
queries=queries,
|
312 |
+
keys=keys,
|
313 |
+
query_pe=point_embedding,
|
314 |
+
key_pe=image_pe,
|
315 |
+
)
|
316 |
+
|
317 |
+
# Apply the final attention layer from the points to the image
|
318 |
+
q = queries + point_embedding
|
319 |
+
k = keys + image_pe
|
320 |
+
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
|
321 |
+
queries = queries + attn_out
|
322 |
+
queries = self.norm_final_attn(queries)
|
323 |
+
|
324 |
+
return queries, keys
|
325 |
+
|
326 |
+
|
327 |
+
class TwoWayAttentionBlock(nn.Module):
|
328 |
+
def __init__(
|
329 |
+
self,
|
330 |
+
embedding_dim: int,
|
331 |
+
num_heads: int,
|
332 |
+
mlp_dim: int = 2048,
|
333 |
+
activation: Type[nn.Module] = nn.ReLU,
|
334 |
+
attention_downsample_rate: int = 2,
|
335 |
+
skip_first_layer_pe: bool = False,
|
336 |
+
) -> None:
|
337 |
+
"""
|
338 |
+
A transformer block with four layers: (1) self-attention of sparse
|
339 |
+
inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
|
340 |
+
block on sparse inputs, and (4) cross attention of dense inputs to sparse
|
341 |
+
inputs.
|
342 |
+
|
343 |
+
Arguments:
|
344 |
+
embedding_dim (int): the channel dimension of the embeddings
|
345 |
+
num_heads (int): the number of heads in the attention layers
|
346 |
+
mlp_dim (int): the hidden dimension of the mlp block
|
347 |
+
activation (nn.Module): the activation of the mlp block
|
348 |
+
skip_first_layer_pe (bool): skip the PE on the first layer
|
349 |
+
"""
|
350 |
+
super().__init__()
|
351 |
+
self.self_attn = Attention(embedding_dim, num_heads)
|
352 |
+
self.norm1 = nn.LayerNorm(embedding_dim)
|
353 |
+
|
354 |
+
self.cross_attn_token_to_image = Attention(
|
355 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
356 |
+
)
|
357 |
+
self.norm2 = nn.LayerNorm(embedding_dim)
|
358 |
+
|
359 |
+
self.mlp = CommonMLP(embedding_dim, mlp_dim, activation)
|
360 |
+
self.norm3 = nn.LayerNorm(embedding_dim)
|
361 |
+
|
362 |
+
self.norm4 = nn.LayerNorm(embedding_dim)
|
363 |
+
self.cross_attn_image_to_token = Attention(
|
364 |
+
embedding_dim, num_heads, downsample_rate=attention_downsample_rate
|
365 |
+
)
|
366 |
+
|
367 |
+
self.skip_first_layer_pe = skip_first_layer_pe
|
368 |
+
|
369 |
+
def forward(
|
370 |
+
self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
|
371 |
+
) -> Tuple[Tensor, Tensor]:
|
372 |
+
# Self attention block
|
373 |
+
if self.skip_first_layer_pe:
|
374 |
+
queries = self.self_attn(q=queries, k=queries, v=queries)
|
375 |
+
else:
|
376 |
+
q = queries + query_pe
|
377 |
+
attn_out = self.self_attn(q=q, k=q, v=queries)
|
378 |
+
queries = queries + attn_out
|
379 |
+
queries = self.norm1(queries)
|
380 |
+
|
381 |
+
# Cross attention block, tokens attending to image embedding
|
382 |
+
q = queries + query_pe
|
383 |
+
k = keys + key_pe
|
384 |
+
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
|
385 |
+
queries = queries + attn_out
|
386 |
+
queries = self.norm2(queries)
|
387 |
+
|
388 |
+
# MLP block
|
389 |
+
mlp_out = self.mlp(queries)
|
390 |
+
queries = queries + mlp_out
|
391 |
+
queries = self.norm3(queries)
|
392 |
+
|
393 |
+
# Cross attention block, image embedding attending to tokens
|
394 |
+
q = queries + query_pe
|
395 |
+
k = keys + key_pe
|
396 |
+
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
|
397 |
+
keys = keys + attn_out
|
398 |
+
keys = self.norm4(keys)
|
399 |
+
|
400 |
+
return queries, keys
|
401 |
+
|
402 |
+
|
403 |
+
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
404 |
+
"""
|
405 |
+
Get relative positional embeddings according to the relative positions of
|
406 |
+
query and key sizes.
|
407 |
+
Args:
|
408 |
+
q_size (int): size of query q.
|
409 |
+
k_size (int): size of key k.
|
410 |
+
rel_pos (Tensor): relative position embeddings (L, C).
|
411 |
+
|
412 |
+
Returns:
|
413 |
+
Extracted positional embeddings according to relative positions.
|
414 |
+
"""
|
415 |
+
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
416 |
+
# Interpolate rel pos if needed.
|
417 |
+
if rel_pos.shape[0] != max_rel_dist:
|
418 |
+
# Interpolate rel pos.
|
419 |
+
rel_pos_resized = F.interpolate(
|
420 |
+
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
421 |
+
size=max_rel_dist,
|
422 |
+
mode="linear",
|
423 |
+
)
|
424 |
+
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
425 |
+
else:
|
426 |
+
rel_pos_resized = rel_pos
|
427 |
+
|
428 |
+
# Scale the coords with short length if shapes for q and k are different.
|
429 |
+
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
430 |
+
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
431 |
+
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
432 |
+
|
433 |
+
return rel_pos_resized[relative_coords.long()]
|
434 |
+
|
435 |
+
|
436 |
+
def add_decomposed_rel_pos(
|
437 |
+
attn: torch.Tensor,
|
438 |
+
q: torch.Tensor,
|
439 |
+
rel_pos_h: torch.Tensor,
|
440 |
+
rel_pos_w: torch.Tensor,
|
441 |
+
q_size: Tuple[int, int],
|
442 |
+
k_size: Tuple[int, int],
|
443 |
+
) -> torch.Tensor:
|
444 |
+
"""
|
445 |
+
Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
|
446 |
+
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
|
447 |
+
Args:
|
448 |
+
attn (Tensor): attention map.
|
449 |
+
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
|
450 |
+
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
|
451 |
+
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
|
452 |
+
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
|
453 |
+
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
|
454 |
+
|
455 |
+
Returns:
|
456 |
+
attn (Tensor): attention map with added relative positional embeddings.
|
457 |
+
"""
|
458 |
+
q_h, q_w = q_size
|
459 |
+
k_h, k_w = k_size
|
460 |
+
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
461 |
+
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
462 |
+
|
463 |
+
B, _, dim = q.shape
|
464 |
+
r_q = q.reshape(B, q_h, q_w, dim)
|
465 |
+
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
466 |
+
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
467 |
+
|
468 |
+
attn = (
|
469 |
+
attn.view(B, q_h, q_w, k_h, k_w)
|
470 |
+
+ rel_h[:, :, :, :, None]
|
471 |
+
+ rel_w[:, :, :, None, :]
|
472 |
+
).view(B, q_h * q_w, k_h * k_w)
|
473 |
+
|
474 |
+
return attn
|
475 |
+
|
476 |
+
|
477 |
+
class ImageEncoderViTAttention(nn.Module):
|
478 |
+
"""Multi-head Attention block with relative position embeddings."""
|
479 |
+
|
480 |
+
def __init__(
|
481 |
+
self,
|
482 |
+
dim: int,
|
483 |
+
num_heads: int = 8,
|
484 |
+
qkv_bias: bool = True,
|
485 |
+
use_rel_pos: bool = False,
|
486 |
+
rel_pos_zero_init: bool = True,
|
487 |
+
input_size: Optional[Tuple[int, int]] = None,
|
488 |
+
) -> None:
|
489 |
+
"""
|
490 |
+
Args:
|
491 |
+
dim (int): Number of input channels.
|
492 |
+
num_heads (int): Number of attention heads.
|
493 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
494 |
+
rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
495 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
496 |
+
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
497 |
+
positional parameter size.
|
498 |
+
"""
|
499 |
+
super().__init__()
|
500 |
+
self.num_heads = num_heads
|
501 |
+
head_dim = dim // num_heads
|
502 |
+
self.scale = head_dim**-0.5
|
503 |
+
|
504 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
505 |
+
self.proj = nn.Linear(dim, dim)
|
506 |
+
|
507 |
+
self.use_rel_pos = use_rel_pos
|
508 |
+
if self.use_rel_pos:
|
509 |
+
assert (
|
510 |
+
input_size is not None
|
511 |
+
), "Input size must be provided if using relative positional encoding."
|
512 |
+
# initialize relative positional embeddings
|
513 |
+
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
|
514 |
+
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
|
515 |
+
|
516 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
517 |
+
B, H, W, _ = x.shape
|
518 |
+
# qkv with shape (3, B, nHead, H * W, C)
|
519 |
+
qkv = (
|
520 |
+
self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
521 |
+
)
|
522 |
+
# q, k, v with shape (B * nHead, H * W, C)
|
523 |
+
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
|
524 |
+
|
525 |
+
attn = (q * self.scale) @ k.transpose(-2, -1)
|
526 |
+
|
527 |
+
if self.use_rel_pos:
|
528 |
+
attn = add_decomposed_rel_pos(
|
529 |
+
attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
|
530 |
+
)
|
531 |
+
|
532 |
+
attn = attn.softmax(dim=-1)
|
533 |
+
x = (
|
534 |
+
(attn @ v)
|
535 |
+
.view(B, self.num_heads, H, W, -1)
|
536 |
+
.permute(0, 2, 3, 1, 4)
|
537 |
+
.reshape(B, H, W, -1)
|
538 |
+
)
|
539 |
+
x = self.proj(x)
|
540 |
+
|
541 |
+
return x
|
542 |
+
|
543 |
+
|
544 |
+
class PatchEmbed(nn.Module):
|
545 |
+
"""
|
546 |
+
Image to Patch Embedding.
|
547 |
+
"""
|
548 |
+
|
549 |
+
def __init__(
|
550 |
+
self,
|
551 |
+
kernel_size: Tuple[int, int] = (16, 16),
|
552 |
+
stride: Tuple[int, int] = (16, 16),
|
553 |
+
padding: Tuple[int, int] = (0, 0),
|
554 |
+
in_chans: int = 3,
|
555 |
+
embed_dim: int = 768,
|
556 |
+
) -> None:
|
557 |
+
"""
|
558 |
+
Args:
|
559 |
+
kernel_size (Tuple): kernel size of the projection layer.
|
560 |
+
stride (Tuple): stride of the projection layer.
|
561 |
+
padding (Tuple): padding size of the projection layer.
|
562 |
+
in_chans (int): Number of input image channels.
|
563 |
+
embed_dim (int): Patch embedding dimension.
|
564 |
+
"""
|
565 |
+
super().__init__()
|
566 |
+
|
567 |
+
self.proj = nn.Conv2d(
|
568 |
+
in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
|
569 |
+
)
|
570 |
+
|
571 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
572 |
+
x = self.proj(x)
|
573 |
+
# B C H W -> B H W C
|
574 |
+
x = x.permute(0, 2, 3, 1)
|
575 |
+
return x
|
576 |
+
|
577 |
+
|
578 |
+
class ImageEncoderViTBlock(nn.Module):
|
579 |
+
"""Transformer blocks with support of window attention and residual propagation blocks"""
|
580 |
+
|
581 |
+
def __init__(
|
582 |
+
self,
|
583 |
+
dim: int,
|
584 |
+
num_heads: int,
|
585 |
+
mlp_ratio: float = 4.0,
|
586 |
+
qkv_bias: bool = True,
|
587 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
588 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
589 |
+
use_rel_pos: bool = False,
|
590 |
+
rel_pos_zero_init: bool = True,
|
591 |
+
window_size: int = 0,
|
592 |
+
input_size: Optional[Tuple[int, int]] = None,
|
593 |
+
) -> None:
|
594 |
+
"""
|
595 |
+
Args:
|
596 |
+
dim (int): Number of input channels.
|
597 |
+
num_heads (int): Number of attention heads in each ViT block.
|
598 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
599 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
600 |
+
norm_layer (nn.Module): Normalization layer.
|
601 |
+
act_layer (nn.Module): Activation layer.
|
602 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
603 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
604 |
+
window_size (int): Window size for window attention blocks. If it equals 0, then
|
605 |
+
use global attention.
|
606 |
+
input_size (tuple(int, int) or None): Input resolution for calculating the relative
|
607 |
+
positional parameter size.
|
608 |
+
"""
|
609 |
+
super().__init__()
|
610 |
+
self.norm1 = norm_layer(dim)
|
611 |
+
self.attn = ImageEncoderViTAttention(
|
612 |
+
dim,
|
613 |
+
num_heads=num_heads,
|
614 |
+
qkv_bias=qkv_bias,
|
615 |
+
use_rel_pos=use_rel_pos,
|
616 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
617 |
+
input_size=input_size if window_size == 0 else (window_size, window_size),
|
618 |
+
)
|
619 |
+
|
620 |
+
self.norm2 = norm_layer(dim)
|
621 |
+
self.mlp = CommonMLP(
|
622 |
+
embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
|
623 |
+
)
|
624 |
+
|
625 |
+
self.window_size = window_size
|
626 |
+
|
627 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
628 |
+
shortcut = x
|
629 |
+
x = self.norm1(x)
|
630 |
+
# Window partition
|
631 |
+
if self.window_size > 0:
|
632 |
+
H, W = x.shape[1], x.shape[2]
|
633 |
+
x, pad_hw = window_partition(x, self.window_size)
|
634 |
+
|
635 |
+
x = self.attn(x)
|
636 |
+
# Reverse window partition
|
637 |
+
if self.window_size > 0:
|
638 |
+
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
|
639 |
+
|
640 |
+
x = shortcut + x
|
641 |
+
x = x + self.mlp(self.norm2(x))
|
642 |
+
|
643 |
+
return x
|
644 |
+
|
645 |
+
|
646 |
+
class ImageEncoderViT(nn.Module):
|
647 |
+
def __init__(
|
648 |
+
self,
|
649 |
+
img_size: int = 1024,
|
650 |
+
patch_size: int = 16,
|
651 |
+
in_chans: int = 3,
|
652 |
+
embed_dim: int = 768,
|
653 |
+
depth: int = 12,
|
654 |
+
num_heads: int = 12,
|
655 |
+
mlp_ratio: float = 4.0,
|
656 |
+
out_chans: int = 256,
|
657 |
+
qkv_bias: bool = True,
|
658 |
+
norm_layer: Type[nn.Module] = nn.LayerNorm,
|
659 |
+
act_layer: Type[nn.Module] = nn.GELU,
|
660 |
+
use_abs_pos: bool = True,
|
661 |
+
use_rel_pos: bool = False,
|
662 |
+
rel_pos_zero_init: bool = True,
|
663 |
+
window_size: int = 0,
|
664 |
+
global_attn_indexes: Tuple[int, ...] = (),
|
665 |
+
) -> None:
|
666 |
+
"""
|
667 |
+
Args:
|
668 |
+
img_size (int): Input image size.
|
669 |
+
patch_size (int): Patch size.
|
670 |
+
in_chans (int): Number of input image channels.
|
671 |
+
embed_dim (int): Patch embedding dimension.
|
672 |
+
depth (int): Depth of ViT.
|
673 |
+
num_heads (int): Number of attention heads in each ViT block.
|
674 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
675 |
+
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
676 |
+
norm_layer (nn.Module): Normalization layer.
|
677 |
+
act_layer (nn.Module): Activation layer.
|
678 |
+
use_abs_pos (bool): If True, use absolute positional embeddings.
|
679 |
+
use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
|
680 |
+
rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
|
681 |
+
window_size (int): Window size for window attention blocks.
|
682 |
+
global_attn_indexes (list): Indexes for blocks using global attention.
|
683 |
+
"""
|
684 |
+
super().__init__()
|
685 |
+
self.img_size = img_size
|
686 |
+
self.embed_dim = embed_dim
|
687 |
+
self.out_chans = out_chans
|
688 |
+
|
689 |
+
self.patch_embed = PatchEmbed(
|
690 |
+
kernel_size=(patch_size, patch_size),
|
691 |
+
stride=(patch_size, patch_size),
|
692 |
+
in_chans=in_chans,
|
693 |
+
embed_dim=embed_dim,
|
694 |
+
)
|
695 |
+
|
696 |
+
self.pos_embed: Optional[nn.Parameter] = None
|
697 |
+
if use_abs_pos:
|
698 |
+
# Initialize absolute positional embedding with pretrain image size.
|
699 |
+
self.pos_embed = nn.Parameter(
|
700 |
+
torch.zeros(
|
701 |
+
1, img_size // patch_size, img_size // patch_size, embed_dim
|
702 |
+
)
|
703 |
+
)
|
704 |
+
|
705 |
+
self.blocks = nn.ModuleList()
|
706 |
+
for i in range(depth):
|
707 |
+
block = ImageEncoderViTBlock(
|
708 |
+
dim=embed_dim,
|
709 |
+
num_heads=num_heads,
|
710 |
+
mlp_ratio=mlp_ratio,
|
711 |
+
qkv_bias=qkv_bias,
|
712 |
+
norm_layer=norm_layer,
|
713 |
+
act_layer=act_layer,
|
714 |
+
use_rel_pos=use_rel_pos,
|
715 |
+
rel_pos_zero_init=rel_pos_zero_init,
|
716 |
+
window_size=window_size if i not in global_attn_indexes else 0,
|
717 |
+
input_size=(img_size // patch_size, img_size // patch_size),
|
718 |
+
)
|
719 |
+
self.blocks.append(block)
|
720 |
+
|
721 |
+
self.neck = nn.Sequential(
|
722 |
+
nn.Conv2d(
|
723 |
+
embed_dim,
|
724 |
+
out_chans,
|
725 |
+
kernel_size=1,
|
726 |
+
bias=False,
|
727 |
+
),
|
728 |
+
LayerNorm2d(out_chans),
|
729 |
+
nn.Conv2d(
|
730 |
+
out_chans,
|
731 |
+
out_chans,
|
732 |
+
kernel_size=3,
|
733 |
+
padding=1,
|
734 |
+
bias=False,
|
735 |
+
),
|
736 |
+
LayerNorm2d(out_chans),
|
737 |
+
)
|
738 |
+
|
739 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
740 |
+
x = self.patch_embed(x)
|
741 |
+
if self.pos_embed is not None:
|
742 |
+
x = x + self.pos_embed
|
743 |
+
|
744 |
+
for blk in self.blocks:
|
745 |
+
x = blk(x)
|
746 |
+
|
747 |
+
dtype = x.dtype
|
748 |
+
if dtype == torch.float16: # prevent overflow
|
749 |
+
with torch.autocast(device_type="cuda", dtype=torch.float32):
|
750 |
+
x = self.neck(x.permute(0, 3, 1, 2))
|
751 |
+
x = x.to(dtype)
|
752 |
+
else:
|
753 |
+
x = self.neck(x.permute(0, 3, 1, 2))
|
754 |
+
return x
|
755 |
+
|
756 |
+
|
757 |
+
class PositionEmbeddingRandom(nn.Module):
|
758 |
+
"""
|
759 |
+
Positional encoding using random spatial frequencies.
|
760 |
+
"""
|
761 |
+
|
762 |
+
def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
|
763 |
+
super().__init__()
|
764 |
+
if scale is None or scale <= 0.0:
|
765 |
+
scale = 1.0
|
766 |
+
self.register_buffer(
|
767 |
+
"positional_encoding_gaussian_matrix",
|
768 |
+
scale * torch.randn((2, num_pos_feats)),
|
769 |
+
)
|
770 |
+
|
771 |
+
def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
|
772 |
+
"""Positionally encode points that are normalized to [0,1]."""
|
773 |
+
# assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
|
774 |
+
coords = 2 * coords - 1
|
775 |
+
|
776 |
+
if coords.dtype != self.positional_encoding_gaussian_matrix.dtype:
|
777 |
+
coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)
|
778 |
+
|
779 |
+
coords = coords @ self.positional_encoding_gaussian_matrix
|
780 |
+
coords = 2 * np.pi * coords
|
781 |
+
# outputs d_1 x ... x d_n x C shape
|
782 |
+
return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
|
783 |
+
|
784 |
+
def forward(self, size: Tuple[int, int]) -> torch.Tensor:
|
785 |
+
"""Generate positional encoding for a grid of the specified size."""
|
786 |
+
h, w = size
|
787 |
+
device: Any = self.positional_encoding_gaussian_matrix.device
|
788 |
+
grid = torch.ones(
|
789 |
+
(h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype
|
790 |
+
)
|
791 |
+
y_embed = grid.cumsum(dim=0) - 0.5
|
792 |
+
x_embed = grid.cumsum(dim=1) - 0.5
|
793 |
+
y_embed = y_embed / h
|
794 |
+
x_embed = x_embed / w
|
795 |
+
|
796 |
+
pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
|
797 |
+
return pe.permute(2, 0, 1) # C x H x W
|
798 |
+
|
799 |
+
def forward_with_coords(
|
800 |
+
self, coords_input: torch.Tensor, image_size: Tuple[int, int]
|
801 |
+
) -> torch.Tensor:
|
802 |
+
"""Positionally encode points that are not normalized to [0,1]."""
|
803 |
+
coords = coords_input.clone()
|
804 |
+
coords[:, :, 0] = coords[:, :, 0] / image_size[1]
|
805 |
+
coords[:, :, 1] = coords[:, :, 1] / image_size[0]
|
806 |
+
return self._pe_encoding(coords.to(torch.float)) # B x N x C
|
807 |
+
|
808 |
+
|
809 |
+
class PromptEncoder(nn.Module):
|
810 |
+
def __init__(
|
811 |
+
self,
|
812 |
+
embed_dim: int,
|
813 |
+
image_embedding_size: Tuple[int, int],
|
814 |
+
input_image_size: Tuple[int, int],
|
815 |
+
mask_in_chans: int,
|
816 |
+
activation: Type[nn.Module] = nn.GELU,
|
817 |
+
) -> None:
|
818 |
+
"""
|
819 |
+
Encodes prompts for input to SAM's mask decoder.
|
820 |
+
|
821 |
+
Arguments:
|
822 |
+
embed_dim (int): The prompts' embedding dimension
|
823 |
+
image_embedding_size (tuple(int, int)): The spatial size of the
|
824 |
+
image embedding, as (H, W).
|
825 |
+
input_image_size (int): The padded size of the image as input
|
826 |
+
to the image encoder, as (H, W).
|
827 |
+
mask_in_chans (int): The number of hidden channels used for
|
828 |
+
encoding input masks.
|
829 |
+
activation (nn.Module): The activation to use when encoding
|
830 |
+
input masks.
|
831 |
+
"""
|
832 |
+
super().__init__()
|
833 |
+
self.embed_dim = embed_dim
|
834 |
+
self.input_image_size = input_image_size
|
835 |
+
self.image_embedding_size = image_embedding_size
|
836 |
+
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
837 |
+
|
838 |
+
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
839 |
+
point_embeddings = [
|
840 |
+
nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
|
841 |
+
]
|
842 |
+
self.point_embeddings = nn.ModuleList(point_embeddings)
|
843 |
+
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
844 |
+
|
845 |
+
self.mask_input_size = (
|
846 |
+
4 * image_embedding_size[0],
|
847 |
+
4 * image_embedding_size[1],
|
848 |
+
)
|
849 |
+
self.mask_downscaling = nn.Sequential(
|
850 |
+
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
|
851 |
+
LayerNorm2d(mask_in_chans // 4),
|
852 |
+
activation(),
|
853 |
+
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
|
854 |
+
LayerNorm2d(mask_in_chans),
|
855 |
+
activation(),
|
856 |
+
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
857 |
+
)
|
858 |
+
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
859 |
+
|
860 |
+
def get_dense_pe(self) -> torch.Tensor:
|
861 |
+
"""
|
862 |
+
Returns the positional encoding used to encode point prompts,
|
863 |
+
applied to a dense set of points the shape of the image encoding.
|
864 |
+
|
865 |
+
Returns:
|
866 |
+
torch.Tensor: Positional encoding with shape
|
867 |
+
1x(embed_dim)x(embedding_h)x(embedding_w)
|
868 |
+
"""
|
869 |
+
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
870 |
+
|
871 |
+
def _embed_points(
|
872 |
+
self,
|
873 |
+
points: torch.Tensor,
|
874 |
+
labels: torch.Tensor,
|
875 |
+
pad: bool,
|
876 |
+
) -> torch.Tensor:
|
877 |
+
"""Embeds point prompts."""
|
878 |
+
points = points + 0.5 # Shift to center of pixel
|
879 |
+
if pad:
|
880 |
+
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
|
881 |
+
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
|
882 |
+
points = torch.cat([points, padding_point], dim=1)
|
883 |
+
labels = torch.cat([labels, padding_label], dim=1)
|
884 |
+
point_embedding = self.pe_layer.forward_with_coords(
|
885 |
+
points, self.input_image_size
|
886 |
+
)
|
887 |
+
point_embedding[labels == -1] = 0.0
|
888 |
+
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
889 |
+
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
890 |
+
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
891 |
+
return point_embedding
|
892 |
+
|
893 |
+
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
894 |
+
"""Embeds box prompts."""
|
895 |
+
boxes = boxes + 0.5 # Shift to center of pixel
|
896 |
+
coords = boxes.reshape(-1, 2, 2)
|
897 |
+
corner_embedding = self.pe_layer.forward_with_coords(
|
898 |
+
coords, self.input_image_size
|
899 |
+
)
|
900 |
+
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
901 |
+
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
902 |
+
return corner_embedding
|
903 |
+
|
904 |
+
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
905 |
+
"""Embeds mask inputs."""
|
906 |
+
mask_embedding = self.mask_downscaling(masks)
|
907 |
+
return mask_embedding
|
908 |
+
|
909 |
+
def _get_batch_size(
|
910 |
+
self,
|
911 |
+
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
912 |
+
boxes: Optional[torch.Tensor],
|
913 |
+
masks: Optional[torch.Tensor],
|
914 |
+
text_embeds: Optional[torch.Tensor],
|
915 |
+
) -> int:
|
916 |
+
"""
|
917 |
+
Gets the batch size of the output given the batch size of the input prompts.
|
918 |
+
"""
|
919 |
+
if points is not None:
|
920 |
+
return points[0].shape[0]
|
921 |
+
elif boxes is not None:
|
922 |
+
return boxes.shape[0]
|
923 |
+
elif masks is not None:
|
924 |
+
return masks.shape[0]
|
925 |
+
elif text_embeds is not None:
|
926 |
+
return text_embeds.shape[0]
|
927 |
+
else:
|
928 |
+
return 1
|
929 |
+
|
930 |
+
def _get_device(self) -> torch.device:
|
931 |
+
return self.point_embeddings[0].weight.device
|
932 |
+
|
933 |
+
def forward(
|
934 |
+
self,
|
935 |
+
points: Optional[Tuple[torch.Tensor, torch.Tensor]],
|
936 |
+
boxes: Optional[torch.Tensor],
|
937 |
+
masks: Optional[torch.Tensor],
|
938 |
+
text_embeds: Optional[torch.Tensor],
|
939 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
940 |
+
"""
|
941 |
+
Embeds different types of prompts, returning both sparse and dense
|
942 |
+
embeddings.
|
943 |
+
|
944 |
+
Arguments:
|
945 |
+
points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
|
946 |
+
and labels to embed.
|
947 |
+
boxes (torch.Tensor or none): boxes to embed
|
948 |
+
masks (torch.Tensor or none): masks to embed
|
949 |
+
|
950 |
+
Returns:
|
951 |
+
torch.Tensor: sparse embeddings for the points and boxes, with shape
|
952 |
+
BxNx(embed_dim), where N is determined by the number of input points
|
953 |
+
and boxes.
|
954 |
+
torch.Tensor: dense embeddings for the masks, in the shape
|
955 |
+
Bx(embed_dim)x(embed_H)x(embed_W)
|
956 |
+
"""
|
957 |
+
bs = self._get_batch_size(points, boxes, masks, text_embeds)
|
958 |
+
sparse_embeddings = torch.empty(
|
959 |
+
(bs, 0, self.embed_dim), device=self._get_device()
|
960 |
+
)
|
961 |
+
if points is not None:
|
962 |
+
coords, labels = points
|
963 |
+
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
964 |
+
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
965 |
+
if boxes is not None:
|
966 |
+
box_embeddings = self._embed_boxes(boxes)
|
967 |
+
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
|
968 |
+
|
969 |
+
if text_embeds is not None:
|
970 |
+
sparse_embeddings = torch.cat([sparse_embeddings, text_embeds], dim=1)
|
971 |
+
|
972 |
+
if masks is not None:
|
973 |
+
dense_embeddings = self._embed_masks(masks)
|
974 |
+
else:
|
975 |
+
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
976 |
+
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
977 |
+
)
|
978 |
+
|
979 |
+
return sparse_embeddings, dense_embeddings
|
980 |
+
|
981 |
+
|
982 |
+
class MaskDecoderMLP(nn.Module):
|
983 |
+
def __init__(
|
984 |
+
self,
|
985 |
+
input_dim: int,
|
986 |
+
hidden_dim: int,
|
987 |
+
output_dim: int,
|
988 |
+
num_layers: int,
|
989 |
+
sigmoid_output: bool = False,
|
990 |
+
) -> None:
|
991 |
+
super().__init__()
|
992 |
+
self.num_layers = num_layers
|
993 |
+
h = [hidden_dim] * (num_layers - 1)
|
994 |
+
self.layers = nn.ModuleList(
|
995 |
+
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
|
996 |
+
)
|
997 |
+
self.sigmoid_output = sigmoid_output
|
998 |
+
|
999 |
+
def forward(self, x):
|
1000 |
+
for i, layer in enumerate(self.layers):
|
1001 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
1002 |
+
if self.sigmoid_output:
|
1003 |
+
x = F.sigmoid(x)
|
1004 |
+
return x
|
1005 |
+
|
1006 |
+
|
1007 |
+
class MaskDecoder(nn.Module):
|
1008 |
+
def __init__(
|
1009 |
+
self,
|
1010 |
+
*,
|
1011 |
+
transformer_dim: int,
|
1012 |
+
transformer: nn.Module,
|
1013 |
+
num_multimask_outputs: int = 3,
|
1014 |
+
activation: Type[nn.Module] = nn.GELU,
|
1015 |
+
iou_head_depth: int = 3,
|
1016 |
+
iou_head_hidden_dim: int = 256,
|
1017 |
+
) -> None:
|
1018 |
+
"""
|
1019 |
+
Predicts masks given an image and prompt embeddings, using a
|
1020 |
+
transformer architecture.
|
1021 |
+
|
1022 |
+
Arguments:
|
1023 |
+
transformer_dim (int): the channel dimension of the transformer
|
1024 |
+
transformer (nn.Module): the transformer used to predict masks
|
1025 |
+
num_multimask_outputs (int): the number of masks to predict
|
1026 |
+
when disambiguating masks
|
1027 |
+
activation (nn.Module): the type of activation to use when
|
1028 |
+
upscaling masks
|
1029 |
+
iou_head_depth (int): the depth of the MLP used to predict
|
1030 |
+
mask quality
|
1031 |
+
iou_head_hidden_dim (int): the hidden dimension of the MLP
|
1032 |
+
used to predict mask quality
|
1033 |
+
"""
|
1034 |
+
super().__init__()
|
1035 |
+
self.transformer_dim = transformer_dim
|
1036 |
+
self.transformer = transformer
|
1037 |
+
|
1038 |
+
self.num_multimask_outputs = num_multimask_outputs
|
1039 |
+
|
1040 |
+
self.iou_token = nn.Embedding(1, transformer_dim)
|
1041 |
+
self.num_mask_tokens = num_multimask_outputs + 1
|
1042 |
+
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
1043 |
+
|
1044 |
+
self.output_upscaling = nn.Sequential(
|
1045 |
+
nn.ConvTranspose2d(
|
1046 |
+
transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
|
1047 |
+
),
|
1048 |
+
LayerNorm2d(transformer_dim // 4),
|
1049 |
+
activation(),
|
1050 |
+
nn.ConvTranspose2d(
|
1051 |
+
transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
|
1052 |
+
),
|
1053 |
+
activation(),
|
1054 |
+
)
|
1055 |
+
self.output_hypernetworks_mlps = nn.ModuleList(
|
1056 |
+
[
|
1057 |
+
MaskDecoderMLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
|
1058 |
+
for i in range(self.num_mask_tokens)
|
1059 |
+
]
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
self.iou_prediction_head = MaskDecoderMLP(
|
1063 |
+
transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
|
1064 |
+
)
|
1065 |
+
|
1066 |
+
def forward(
|
1067 |
+
self,
|
1068 |
+
image_embeddings: torch.Tensor,
|
1069 |
+
image_pe: torch.Tensor,
|
1070 |
+
sparse_prompt_embeddings: torch.Tensor,
|
1071 |
+
dense_prompt_embeddings: torch.Tensor,
|
1072 |
+
multimask_output: bool,
|
1073 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1074 |
+
"""
|
1075 |
+
Predict masks given image and prompt embeddings.
|
1076 |
+
|
1077 |
+
Arguments:
|
1078 |
+
image_embeddings (torch.Tensor): the embeddings from the image encoder
|
1079 |
+
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
|
1080 |
+
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
|
1081 |
+
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
|
1082 |
+
multimask_output (bool): Whether to return multiple masks or a single
|
1083 |
+
mask.
|
1084 |
+
|
1085 |
+
Returns:
|
1086 |
+
torch.Tensor: batched predicted masks
|
1087 |
+
torch.Tensor: batched predictions of mask quality
|
1088 |
+
"""
|
1089 |
+
masks, iou_pred = self.predict_masks(
|
1090 |
+
image_embeddings=image_embeddings,
|
1091 |
+
image_pe=image_pe,
|
1092 |
+
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
1093 |
+
dense_prompt_embeddings=dense_prompt_embeddings,
|
1094 |
+
)
|
1095 |
+
|
1096 |
+
# Select the correct mask or masks for output
|
1097 |
+
if multimask_output:
|
1098 |
+
mask_slice = slice(1, None)
|
1099 |
+
else:
|
1100 |
+
mask_slice = slice(0, 1)
|
1101 |
+
masks = masks[:, mask_slice, :, :]
|
1102 |
+
iou_pred = iou_pred[:, mask_slice]
|
1103 |
+
|
1104 |
+
# Prepare output
|
1105 |
+
return masks, iou_pred
|
1106 |
+
|
1107 |
+
def predict_masks(
|
1108 |
+
self,
|
1109 |
+
image_embeddings: torch.Tensor,
|
1110 |
+
image_pe: torch.Tensor,
|
1111 |
+
sparse_prompt_embeddings: torch.Tensor,
|
1112 |
+
dense_prompt_embeddings: torch.Tensor,
|
1113 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
1114 |
+
"""Predicts masks. See 'forward' for more details."""
|
1115 |
+
# Concatenate output tokens
|
1116 |
+
output_tokens = torch.cat(
|
1117 |
+
[self.iou_token.weight, self.mask_tokens.weight], dim=0
|
1118 |
+
)
|
1119 |
+
output_tokens = output_tokens.unsqueeze(0).expand(
|
1120 |
+
sparse_prompt_embeddings.size(0), -1, -1
|
1121 |
+
)
|
1122 |
+
|
1123 |
+
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
1124 |
+
|
1125 |
+
# image_embeddings: [1, C, H, W], tokens: [B, N, C]
|
1126 |
+
# dense_prompt_embeddings: [B, C, H, W]
|
1127 |
+
# Expand per-image data in batch direction to be per-mask
|
1128 |
+
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
1129 |
+
src = src + dense_prompt_embeddings
|
1130 |
+
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
1131 |
+
b, c, h, w = src.shape
|
1132 |
+
|
1133 |
+
# Run the transformer
|
1134 |
+
hs, src = self.transformer(src, pos_src, tokens)
|
1135 |
+
iou_token_out = hs[:, 0, :]
|
1136 |
+
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
|
1137 |
+
|
1138 |
+
# Upscale mask embeddings and predict masks using the mask tokens
|
1139 |
+
src = src.transpose(1, 2).view(b, c, h, w)
|
1140 |
+
upscaled_embedding = self.output_upscaling(src)
|
1141 |
+
hyper_in_list: List[torch.Tensor] = []
|
1142 |
+
for i in range(self.num_mask_tokens):
|
1143 |
+
hyper_in_list.append(
|
1144 |
+
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
|
1145 |
+
)
|
1146 |
+
hyper_in = torch.stack(hyper_in_list, dim=1)
|
1147 |
+
b, c, h, w = upscaled_embedding.shape
|
1148 |
+
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(
|
1149 |
+
b, self.num_mask_tokens, h, w
|
1150 |
+
)
|
1151 |
+
|
1152 |
+
# Generate mask quality predictions
|
1153 |
+
iou_pred = self.iou_prediction_head(iou_token_out)
|
1154 |
+
|
1155 |
+
return masks, iou_pred
|
1156 |
+
|
1157 |
+
|
1158 |
+
class Sam(nn.Module):
|
1159 |
+
mask_threshold: float = 0.0
|
1160 |
+
image_format: str = "RGB"
|
1161 |
+
|
1162 |
+
def __init__(
|
1163 |
+
self,
|
1164 |
+
image_encoder: ImageEncoderViT,
|
1165 |
+
prompt_encoder: PromptEncoder,
|
1166 |
+
mask_decoder: MaskDecoder,
|
1167 |
+
pixel_mean: List[float] = [123.675, 116.28, 103.53],
|
1168 |
+
pixel_std: List[float] = [58.395, 57.12, 57.375],
|
1169 |
+
) -> None:
|
1170 |
+
"""
|
1171 |
+
SAM predicts object masks from an image and input prompts.
|
1172 |
+
|
1173 |
+
Arguments:
|
1174 |
+
image_encoder (ImageEncoderViT): The backbone used to encode the
|
1175 |
+
image into image embeddings that allow for efficient mask prediction.
|
1176 |
+
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
|
1177 |
+
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
|
1178 |
+
and encoded prompts.
|
1179 |
+
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
|
1180 |
+
pixel_std (list(float)): Std values for normalizing pixels in the input image.
|
1181 |
+
"""
|
1182 |
+
super().__init__()
|
1183 |
+
self.image_encoder = image_encoder
|
1184 |
+
self.prompt_encoder = prompt_encoder
|
1185 |
+
self.mask_decoder = mask_decoder
|
1186 |
+
self.register_buffer(
|
1187 |
+
"pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
|
1188 |
+
)
|
1189 |
+
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
1190 |
+
|
1191 |
+
@property
|
1192 |
+
def device(self) -> Any:
|
1193 |
+
return self.pixel_mean.device
|
1194 |
+
|
1195 |
+
@torch.no_grad()
|
1196 |
+
def forward(
|
1197 |
+
self,
|
1198 |
+
batched_input: List[Dict[str, Any]],
|
1199 |
+
multimask_output: bool,
|
1200 |
+
) -> List[Dict[str, torch.Tensor]]:
|
1201 |
+
"""
|
1202 |
+
Predicts masks end-to-end from provided images and prompts.
|
1203 |
+
If prompts are not known in advance, using SamPredictor is
|
1204 |
+
recommended over calling the model directly.
|
1205 |
+
|
1206 |
+
Arguments:
|
1207 |
+
batched_input (list(dict)): A list over input images, each a
|
1208 |
+
dictionary with the following keys. A prompt key can be
|
1209 |
+
excluded if it is not present.
|
1210 |
+
'image': The image as a torch tensor in 3xHxW format,
|
1211 |
+
already transformed for input to the model.
|
1212 |
+
'original_size': (tuple(int, int)) The original size of
|
1213 |
+
the image before transformation, as (H, W).
|
1214 |
+
'point_coords': (torch.Tensor) Batched point prompts for
|
1215 |
+
this image, with shape BxNx2. Already transformed to the
|
1216 |
+
input frame of the model.
|
1217 |
+
'point_labels': (torch.Tensor) Batched labels for point prompts,
|
1218 |
+
with shape BxN.
|
1219 |
+
'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
|
1220 |
+
Already transformed to the input frame of the model.
|
1221 |
+
'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
|
1222 |
+
in the form Bx1xHxW.
|
1223 |
+
multimask_output (bool): Whether the model should predict multiple
|
1224 |
+
disambiguating masks, or return a single mask.
|
1225 |
+
|
1226 |
+
Returns:
|
1227 |
+
(list(dict)): A list over input images, where each element is
|
1228 |
+
as dictionary with the following keys.
|
1229 |
+
'masks': (torch.Tensor) Batched binary mask predictions,
|
1230 |
+
with shape BxCxHxW, where B is the number of input prompts,
|
1231 |
+
C is determined by multimask_output, and (H, W) is the
|
1232 |
+
original size of the image.
|
1233 |
+
'iou_predictions': (torch.Tensor) The model's predictions
|
1234 |
+
of mask quality, in shape BxC.
|
1235 |
+
'low_res_logits': (torch.Tensor) Low resolution logits with
|
1236 |
+
shape BxCxHxW, where H=W=256. Can be passed as mask input
|
1237 |
+
to subsequent iterations of prediction.
|
1238 |
+
"""
|
1239 |
+
input_images = torch.stack(
|
1240 |
+
[self.preprocess(x["image"]) for x in batched_input], dim=0
|
1241 |
+
)
|
1242 |
+
image_embeddings = self.image_encoder(input_images)
|
1243 |
+
|
1244 |
+
outputs = []
|
1245 |
+
for image_record, curr_embedding in zip(batched_input, image_embeddings):
|
1246 |
+
if "point_coords" in image_record:
|
1247 |
+
points = (image_record["point_coords"], image_record["point_labels"])
|
1248 |
+
else:
|
1249 |
+
points = None
|
1250 |
+
sparse_embeddings, dense_embeddings = self.prompt_encoder(
|
1251 |
+
points=points,
|
1252 |
+
boxes=image_record.get("boxes", None),
|
1253 |
+
masks=image_record.get("mask_inputs", None),
|
1254 |
+
)
|
1255 |
+
low_res_masks, iou_predictions = self.mask_decoder(
|
1256 |
+
image_embeddings=curr_embedding.unsqueeze(0),
|
1257 |
+
image_pe=self.prompt_encoder.get_dense_pe(),
|
1258 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
1259 |
+
dense_prompt_embeddings=dense_embeddings,
|
1260 |
+
multimask_output=multimask_output,
|
1261 |
+
)
|
1262 |
+
masks = self.postprocess_masks(
|
1263 |
+
low_res_masks,
|
1264 |
+
input_size=image_record["image"].shape[-2:],
|
1265 |
+
original_size=image_record["original_size"],
|
1266 |
+
)
|
1267 |
+
masks = masks > self.mask_threshold
|
1268 |
+
outputs.append(
|
1269 |
+
{
|
1270 |
+
"masks": masks,
|
1271 |
+
"iou_predictions": iou_predictions,
|
1272 |
+
"low_res_logits": low_res_masks,
|
1273 |
+
}
|
1274 |
+
)
|
1275 |
+
return outputs
|
1276 |
+
|
1277 |
+
def postprocess_masks(
|
1278 |
+
self,
|
1279 |
+
masks: torch.Tensor,
|
1280 |
+
input_size: Tuple[int, ...],
|
1281 |
+
original_size: Tuple[int, ...],
|
1282 |
+
) -> torch.Tensor:
|
1283 |
+
"""
|
1284 |
+
Remove padding and upscale masks to the original image size.
|
1285 |
+
|
1286 |
+
Arguments:
|
1287 |
+
masks (torch.Tensor): Batched masks from the mask_decoder,
|
1288 |
+
in BxCxHxW format.
|
1289 |
+
input_size (tuple(int, int)): The size of the image input to the
|
1290 |
+
model, in (H, W) format. Used to remove padding.
|
1291 |
+
original_size (tuple(int, int)): The original size of the image
|
1292 |
+
before resizing for input to the model, in (H, W) format.
|
1293 |
+
|
1294 |
+
Returns:
|
1295 |
+
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
|
1296 |
+
is given by original_size.
|
1297 |
+
"""
|
1298 |
+
|
1299 |
+
dtype = masks.dtype
|
1300 |
+
|
1301 |
+
masks = F.interpolate(
|
1302 |
+
masks.float(),
|
1303 |
+
(self.image_encoder.img_size, self.image_encoder.img_size),
|
1304 |
+
mode="bilinear",
|
1305 |
+
align_corners=False,
|
1306 |
+
)
|
1307 |
+
# masks = masks.to(dtype)
|
1308 |
+
masks = masks[..., : input_size[0], : input_size[1]]
|
1309 |
+
masks = F.interpolate(
|
1310 |
+
masks, original_size, mode="bilinear", align_corners=False
|
1311 |
+
)
|
1312 |
+
return masks
|
1313 |
+
|
1314 |
+
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
|
1315 |
+
"""Normalize pixel values and pad to a square input."""
|
1316 |
+
# Normalize colors
|
1317 |
+
x = (x - self.pixel_mean) / self.pixel_std
|
1318 |
+
|
1319 |
+
# Pad
|
1320 |
+
h, w = x.shape[-2:]
|
1321 |
+
padh = self.image_encoder.img_size - h
|
1322 |
+
padw = self.image_encoder.img_size - w
|
1323 |
+
x = F.pad(x, (0, padw, 0, padh))
|
1324 |
+
return x
|
special_tokens_map.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<|im_start|>",
|
4 |
+
"<|im_end|>",
|
5 |
+
"<|object_ref_start|>",
|
6 |
+
"<|object_ref_end|>",
|
7 |
+
"<|box_start|>",
|
8 |
+
"<|box_end|>",
|
9 |
+
"<|quad_start|>",
|
10 |
+
"<|quad_end|>",
|
11 |
+
"<|vision_start|>",
|
12 |
+
"<|vision_end|>",
|
13 |
+
"<|vision_pad|>",
|
14 |
+
"<|image_pad|>",
|
15 |
+
"<|video_pad|>",
|
16 |
+
"[SEG]"
|
17 |
+
],
|
18 |
+
"eos_token": {
|
19 |
+
"content": "<|im_end|>",
|
20 |
+
"lstrip": false,
|
21 |
+
"normalized": false,
|
22 |
+
"rstrip": false,
|
23 |
+
"single_word": false
|
24 |
+
},
|
25 |
+
"pad_token": {
|
26 |
+
"content": "<|endoftext|>",
|
27 |
+
"lstrip": false,
|
28 |
+
"normalized": false,
|
29 |
+
"rstrip": false,
|
30 |
+
"single_word": false
|
31 |
+
}
|
32 |
+
}
|
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b376a39f73c663abe179f7ae1cae86da5dd2b689bcce504348d4fb00e6a64240
|
3 |
+
size 11422078
|
tokenizer_config.json
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": false,
|
3 |
+
"add_prefix_space": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"151643": {
|
6 |
+
"content": "<|endoftext|>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"151644": {
|
14 |
+
"content": "<|im_start|>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"151645": {
|
22 |
+
"content": "<|im_end|>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
},
|
29 |
+
"151646": {
|
30 |
+
"content": "<|object_ref_start|>",
|
31 |
+
"lstrip": false,
|
32 |
+
"normalized": false,
|
33 |
+
"rstrip": false,
|
34 |
+
"single_word": false,
|
35 |
+
"special": true
|
36 |
+
},
|
37 |
+
"151647": {
|
38 |
+
"content": "<|object_ref_end|>",
|
39 |
+
"lstrip": false,
|
40 |
+
"normalized": false,
|
41 |
+
"rstrip": false,
|
42 |
+
"single_word": false,
|
43 |
+
"special": true
|
44 |
+
},
|
45 |
+
"151648": {
|
46 |
+
"content": "<|box_start|>",
|
47 |
+
"lstrip": false,
|
48 |
+
"normalized": false,
|
49 |
+
"rstrip": false,
|
50 |
+
"single_word": false,
|
51 |
+
"special": true
|
52 |
+
},
|
53 |
+
"151649": {
|
54 |
+
"content": "<|box_end|>",
|
55 |
+
"lstrip": false,
|
56 |
+
"normalized": false,
|
57 |
+
"rstrip": false,
|
58 |
+
"single_word": false,
|
59 |
+
"special": true
|
60 |
+
},
|
61 |
+
"151650": {
|
62 |
+
"content": "<|quad_start|>",
|
63 |
+
"lstrip": false,
|
64 |
+
"normalized": false,
|
65 |
+
"rstrip": false,
|
66 |
+
"single_word": false,
|
67 |
+
"special": true
|
68 |
+
},
|
69 |
+
"151651": {
|
70 |
+
"content": "<|quad_end|>",
|
71 |
+
"lstrip": false,
|
72 |
+
"normalized": false,
|
73 |
+
"rstrip": false,
|
74 |
+
"single_word": false,
|
75 |
+
"special": true
|
76 |
+
},
|
77 |
+
"151652": {
|
78 |
+
"content": "<|vision_start|>",
|
79 |
+
"lstrip": false,
|
80 |
+
"normalized": false,
|
81 |
+
"rstrip": false,
|
82 |
+
"single_word": false,
|
83 |
+
"special": true
|
84 |
+
},
|
85 |
+
"151653": {
|
86 |
+
"content": "<|vision_end|>",
|
87 |
+
"lstrip": false,
|
88 |
+
"normalized": false,
|
89 |
+
"rstrip": false,
|
90 |
+
"single_word": false,
|
91 |
+
"special": true
|
92 |
+
},
|
93 |
+
"151654": {
|
94 |
+
"content": "<|vision_pad|>",
|
95 |
+
"lstrip": false,
|
96 |
+
"normalized": false,
|
97 |
+
"rstrip": false,
|
98 |
+
"single_word": false,
|
99 |
+
"special": true
|
100 |
+
},
|
101 |
+
"151655": {
|
102 |
+
"content": "<|image_pad|>",
|
103 |
+
"lstrip": false,
|
104 |
+
"normalized": false,
|
105 |
+
"rstrip": false,
|
106 |
+
"single_word": false,
|
107 |
+
"special": true
|
108 |
+
},
|
109 |
+
"151656": {
|
110 |
+
"content": "<|video_pad|>",
|
111 |
+
"lstrip": false,
|
112 |
+
"normalized": false,
|
113 |
+
"rstrip": false,
|
114 |
+
"single_word": false,
|
115 |
+
"special": true
|
116 |
+
},
|
117 |
+
"151657": {
|
118 |
+
"content": "<tool_call>",
|
119 |
+
"lstrip": false,
|
120 |
+
"normalized": false,
|
121 |
+
"rstrip": false,
|
122 |
+
"single_word": false,
|
123 |
+
"special": false
|
124 |
+
},
|
125 |
+
"151658": {
|
126 |
+
"content": "</tool_call>",
|
127 |
+
"lstrip": false,
|
128 |
+
"normalized": false,
|
129 |
+
"rstrip": false,
|
130 |
+
"single_word": false,
|
131 |
+
"special": false
|
132 |
+
},
|
133 |
+
"151659": {
|
134 |
+
"content": "<|fim_prefix|>",
|
135 |
+
"lstrip": false,
|
136 |
+
"normalized": false,
|
137 |
+
"rstrip": false,
|
138 |
+
"single_word": false,
|
139 |
+
"special": false
|
140 |
+
},
|
141 |
+
"151660": {
|
142 |
+
"content": "<|fim_middle|>",
|
143 |
+
"lstrip": false,
|
144 |
+
"normalized": false,
|
145 |
+
"rstrip": false,
|
146 |
+
"single_word": false,
|
147 |
+
"special": false
|
148 |
+
},
|
149 |
+
"151661": {
|
150 |
+
"content": "<|fim_suffix|>",
|
151 |
+
"lstrip": false,
|
152 |
+
"normalized": false,
|
153 |
+
"rstrip": false,
|
154 |
+
"single_word": false,
|
155 |
+
"special": false
|
156 |
+
},
|
157 |
+
"151662": {
|
158 |
+
"content": "<|fim_pad|>",
|
159 |
+
"lstrip": false,
|
160 |
+
"normalized": false,
|
161 |
+
"rstrip": false,
|
162 |
+
"single_word": false,
|
163 |
+
"special": false
|
164 |
+
},
|
165 |
+
"151663": {
|
166 |
+
"content": "<|repo_name|>",
|
167 |
+
"lstrip": false,
|
168 |
+
"normalized": false,
|
169 |
+
"rstrip": false,
|
170 |
+
"single_word": false,
|
171 |
+
"special": false
|
172 |
+
},
|
173 |
+
"151664": {
|
174 |
+
"content": "<|file_sep|>",
|
175 |
+
"lstrip": false,
|
176 |
+
"normalized": false,
|
177 |
+
"rstrip": false,
|
178 |
+
"single_word": false,
|
179 |
+
"special": false
|
180 |
+
},
|
181 |
+
"151665": {
|
182 |
+
"content": "[SEG]",
|
183 |
+
"lstrip": false,
|
184 |
+
"normalized": false,
|
185 |
+
"rstrip": false,
|
186 |
+
"single_word": false,
|
187 |
+
"special": true
|
188 |
+
}
|
189 |
+
},
|
190 |
+
"additional_special_tokens": [
|
191 |
+
"<|im_start|>",
|
192 |
+
"<|im_end|>",
|
193 |
+
"<|object_ref_start|>",
|
194 |
+
"<|object_ref_end|>",
|
195 |
+
"<|box_start|>",
|
196 |
+
"<|box_end|>",
|
197 |
+
"<|quad_start|>",
|
198 |
+
"<|quad_end|>",
|
199 |
+
"<|vision_start|>",
|
200 |
+
"<|vision_end|>",
|
201 |
+
"<|vision_pad|>",
|
202 |
+
"<|image_pad|>",
|
203 |
+
"<|video_pad|>",
|
204 |
+
"[SEG]"
|
205 |
+
],
|
206 |
+
"bos_token": null,
|
207 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
|
208 |
+
"clean_up_tokenization_spaces": false,
|
209 |
+
"eos_token": "<|im_end|>",
|
210 |
+
"errors": "replace",
|
211 |
+
"extra_special_tokens": {},
|
212 |
+
"model_max_length": 32768,
|
213 |
+
"pad_token": "<|endoftext|>",
|
214 |
+
"padding_side": "right",
|
215 |
+
"split_special_tokens": false,
|
216 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
217 |
+
"unk_token": null
|
218 |
+
}
|
transform.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
File separation from META Sam project @DeepGlint 2025
|
3 |
+
https://github.com/facebookresearch/segment-anything
|
4 |
+
'''
|
5 |
+
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from typing import Tuple
|
11 |
+
from copy import deepcopy
|
12 |
+
from torchvision.transforms.functional import resize # type: ignore
|
13 |
+
from torchvision.transforms.functional import to_pil_image
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
class ResizeLongestSide:
|
18 |
+
"""
|
19 |
+
Resizes images to the longest side 'target_length', as well as provides
|
20 |
+
methods for resizing coordinates and boxes. Provides methods for
|
21 |
+
transforming both numpy array and batched torch tensors.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, target_length: int) -> None:
|
25 |
+
self.target_length = target_length
|
26 |
+
|
27 |
+
def apply_image(self, image: np.ndarray) -> np.ndarray:
|
28 |
+
"""
|
29 |
+
Expects a numpy array with shape HxWxC in uint8 format.
|
30 |
+
"""
|
31 |
+
target_size = self.get_preprocess_shape(
|
32 |
+
image.shape[0], image.shape[1], self.target_length
|
33 |
+
)
|
34 |
+
return np.array(resize(to_pil_image(image), target_size))
|
35 |
+
|
36 |
+
def apply_coords(
|
37 |
+
self, coords: np.ndarray, original_size: Tuple[int, ...]
|
38 |
+
) -> np.ndarray:
|
39 |
+
"""
|
40 |
+
Expects a numpy array of length 2 in the final dimension. Requires the
|
41 |
+
original image size in (H, W) format.
|
42 |
+
"""
|
43 |
+
old_h, old_w = original_size
|
44 |
+
new_h, new_w = self.get_preprocess_shape(
|
45 |
+
original_size[0], original_size[1], self.target_length
|
46 |
+
)
|
47 |
+
coords = deepcopy(coords).astype(float)
|
48 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
49 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
50 |
+
return coords
|
51 |
+
|
52 |
+
def apply_boxes(
|
53 |
+
self, boxes: np.ndarray, original_size: Tuple[int, ...]
|
54 |
+
) -> np.ndarray:
|
55 |
+
"""
|
56 |
+
Expects a numpy array shape Bx4. Requires the original image size
|
57 |
+
in (H, W) format.
|
58 |
+
"""
|
59 |
+
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
|
60 |
+
return boxes.reshape(-1, 4)
|
61 |
+
|
62 |
+
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
|
63 |
+
"""
|
64 |
+
Expects batched images with shape BxCxHxW and float format. This
|
65 |
+
transformation may not exactly match apply_image. apply_image is
|
66 |
+
the transformation expected by the model.
|
67 |
+
"""
|
68 |
+
# Expects an image in BCHW format. May not exactly match apply_image.
|
69 |
+
target_size = self.get_preprocess_shape(
|
70 |
+
image.shape[0], image.shape[1], self.target_length
|
71 |
+
)
|
72 |
+
return F.interpolate(
|
73 |
+
image, target_size, mode="bilinear", align_corners=False, antialias=True
|
74 |
+
)
|
75 |
+
|
76 |
+
def apply_coords_torch(
|
77 |
+
self, coords: torch.Tensor, original_size: Tuple[int, ...]
|
78 |
+
) -> torch.Tensor:
|
79 |
+
"""
|
80 |
+
Expects a torch tensor with length 2 in the last dimension. Requires the
|
81 |
+
original image size in (H, W) format.
|
82 |
+
"""
|
83 |
+
old_h, old_w = original_size
|
84 |
+
new_h, new_w = self.get_preprocess_shape(
|
85 |
+
original_size[0], original_size[1], self.target_length
|
86 |
+
)
|
87 |
+
coords = deepcopy(coords).to(torch.float)
|
88 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
89 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
90 |
+
return coords
|
91 |
+
|
92 |
+
def apply_boxes_torch(
|
93 |
+
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
|
94 |
+
) -> torch.Tensor:
|
95 |
+
"""
|
96 |
+
Expects a torch tensor with shape Bx4. Requires the original image
|
97 |
+
size in (H, W) format.
|
98 |
+
"""
|
99 |
+
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
|
100 |
+
return boxes.reshape(-1, 4)
|
101 |
+
|
102 |
+
@staticmethod
|
103 |
+
def get_preprocess_shape(
|
104 |
+
oldh: int, oldw: int, long_side_length: int
|
105 |
+
) -> Tuple[int, int]:
|
106 |
+
"""
|
107 |
+
Compute the output size given input size and target long side length.
|
108 |
+
"""
|
109 |
+
scale = long_side_length * 1.0 / max(oldh, oldw)
|
110 |
+
newh, neww = oldh * scale, oldw * scale
|
111 |
+
neww = int(neww + 0.5)
|
112 |
+
newh = int(newh + 0.5)
|
113 |
+
return (newh, neww)
|
vision_projector.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
CopyRight @DeepGlint 2025
|
3 |
+
'''
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
import torch.nn as nn
|
8 |
+
import re
|
9 |
+
|
10 |
+
class IdentityMap(nn.Module):
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__()
|
13 |
+
|
14 |
+
def forward(self, x, *args, **kwargs):
|
15 |
+
return x
|
16 |
+
|
17 |
+
@property
|
18 |
+
def config(self):
|
19 |
+
return {"mm_projector_type": "identity"}
|
20 |
+
|
21 |
+
|
22 |
+
class SimpleResBlock(nn.Module):
|
23 |
+
def __init__(self, channels):
|
24 |
+
super().__init__()
|
25 |
+
self.pre_norm = nn.LayerNorm(channels)
|
26 |
+
|
27 |
+
self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
x = self.pre_norm(x)
|
31 |
+
return x + self.proj(x)
|
32 |
+
|
33 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
34 |
+
projector_type = getattr(config, "mm_projector_type", "linear")
|
35 |
+
|
36 |
+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
|
37 |
+
if mlp_gelu_match:
|
38 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
39 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
40 |
+
for _ in range(1, mlp_depth):
|
41 |
+
modules.append(nn.GELU())
|
42 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
43 |
+
return nn.Sequential(*modules)
|
44 |
+
|
45 |
+
mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
|
46 |
+
if mlp_gelu_resnet_match:
|
47 |
+
mlp_depth = int(mlp_gelu_resnet_match.group(1))
|
48 |
+
res_depth = int(mlp_gelu_resnet_match.group(2))
|
49 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
50 |
+
for _ in range(1, mlp_depth):
|
51 |
+
modules.append(nn.GELU())
|
52 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
53 |
+
for _ in range(res_depth):
|
54 |
+
modules.append(SimpleResBlock(config.hidden_size))
|
55 |
+
return nn.Sequential(*modules)
|
56 |
+
|
57 |
+
raise ValueError(f"Unknown projector type: {projector_type}")
|
vision_resampler.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
CopyRight @DeepGlint 2025
|
3 |
+
'''
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
def build_vision_resampler(model_args, delay_load=False, **kwargs):
|
11 |
+
return IdentityMap()
|
12 |
+
|
13 |
+
|
14 |
+
class IdentityMap(torch.nn.Module):
|
15 |
+
def __init__(self):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
def forward(self, x, *args, **kwargs):
|
19 |
+
return x
|
20 |
+
|
21 |
+
@property
|
22 |
+
def config(self):
|
23 |
+
return {"mm_resampler_type": None}
|
vision_tower.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
CopyRight @DeepGlint 2025
|
3 |
+
'''
|
4 |
+
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
9 |
+
|
10 |
+
|
11 |
+
def build_vision_tower(model_cfg, **kwargs):
|
12 |
+
vision_tower = getattr(model_cfg, "vision_tower_config", getattr(model_cfg, "vision_tower", None))
|
13 |
+
return CLIPVisionTower(vision_tower, args=model_cfg, **kwargs)
|
14 |
+
|
15 |
+
|
16 |
+
class CLIPVisionTower(nn.Module):
|
17 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
18 |
+
super().__init__()
|
19 |
+
|
20 |
+
self.is_loaded = False
|
21 |
+
|
22 |
+
self.vision_tower_cfg = vision_tower
|
23 |
+
self.vision_tower_processor = args.vision_tower_processor
|
24 |
+
self.select_layer = args.mm_vision_select_layer
|
25 |
+
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
|
26 |
+
|
27 |
+
if not delay_load:
|
28 |
+
self.init_model()
|
29 |
+
elif getattr(args, "unfreeze_mm_vision_tower", False):
|
30 |
+
# TODO: better detector is needed.
|
31 |
+
self.init_model()
|
32 |
+
elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
|
33 |
+
self.init_model()
|
34 |
+
else:
|
35 |
+
raise RuntimeError("Not support now, please check config.json or contact us")
|
36 |
+
|
37 |
+
def init_model(self, device_map=None):
|
38 |
+
if self.is_loaded:
|
39 |
+
return
|
40 |
+
vision_tower_config = CLIPVisionConfig().from_dict(self.vision_tower_cfg)
|
41 |
+
self.image_processor = CLIPImageProcessor(**self.vision_tower_processor)
|
42 |
+
self.vision_tower = CLIPVisionModel(config=vision_tower_config)
|
43 |
+
self.vision_tower.requires_grad_(False)
|
44 |
+
|
45 |
+
self.is_loaded = True
|
46 |
+
|
47 |
+
def feature_select(self, image_forward_outs):
|
48 |
+
select_feature_type = self.select_feature
|
49 |
+
|
50 |
+
if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
|
51 |
+
select_every_k_layer = len(image_forward_outs.hidden_states) // 4
|
52 |
+
image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
|
53 |
+
select_feature_type = select_feature_type.replace("slicefour_", "")
|
54 |
+
elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
|
55 |
+
select_layers = [-2, -5, -8, -11, 6]
|
56 |
+
image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
|
57 |
+
select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
|
58 |
+
else:
|
59 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
60 |
+
|
61 |
+
if select_feature_type == "patch":
|
62 |
+
image_features = image_features[:, 1:]
|
63 |
+
elif select_feature_type == "cls_patch":
|
64 |
+
image_features = image_features
|
65 |
+
else:
|
66 |
+
raise ValueError(f"Unexpected select feature: {select_feature_type}")
|
67 |
+
return image_features
|
68 |
+
|
69 |
+
def forward(self, images):
|
70 |
+
if type(images) is list:
|
71 |
+
image_features = []
|
72 |
+
for image in images:
|
73 |
+
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
|
74 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
75 |
+
image_features.append(image_feature)
|
76 |
+
else:
|
77 |
+
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
|
78 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
79 |
+
|
80 |
+
return image_features
|
81 |
+
|
82 |
+
@property
|
83 |
+
def dummy_feature(self):
|
84 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
85 |
+
|
86 |
+
@property
|
87 |
+
def dtype(self):
|
88 |
+
return self.vision_tower.dtype
|
89 |
+
|
90 |
+
@property
|
91 |
+
def device(self):
|
92 |
+
return self.vision_tower.device
|
93 |
+
|
94 |
+
@property
|
95 |
+
def config(self):
|
96 |
+
if self.is_loaded:
|
97 |
+
return self.vision_tower.config
|
98 |
+
else:
|
99 |
+
return self.cfg_only
|
100 |
+
|
101 |
+
@property
|
102 |
+
def hidden_size(self):
|
103 |
+
_hidden_size = self.config.hidden_size
|
104 |
+
if "slicefour" in self.select_feature:
|
105 |
+
_hidden_size *= 4
|
106 |
+
if "slice_m25811_f6" in self.select_feature:
|
107 |
+
_hidden_size *= 5
|
108 |
+
return _hidden_size
|
109 |
+
|
110 |
+
@property
|
111 |
+
def num_patches_per_side(self):
|
112 |
+
return self.config.image_size // self.config.patch_size
|
113 |
+
|
114 |
+
@property
|
115 |
+
def num_patches(self):
|
116 |
+
_num_patches = (self.config.image_size // self.config.patch_size) ** 2
|
117 |
+
if "cls_patch" in self.select_feature:
|
118 |
+
_num_patches += 1
|
119 |
+
return _num_patches
|
120 |
+
|
121 |
+
@property
|
122 |
+
def image_size(self):
|
123 |
+
return self.config.image_size
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|