killTheHostage committed
Commit b74e18c · 0 parents

Update a convenient way to use this model through Huggingface
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .vscode/
+ test/
+ test.sh
README.md ADDED
@@ -0,0 +1,67 @@
+ ---
+ license: apache-2.0
+ base_model:
+ - DeepGlint-AI/MLCD-Embodied-7B
+ ---
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-label-cluster-discrimination-for-visual/referring-expression-segmentation-on-refcocog)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcocog?p=multi-label-cluster-discrimination-for-visual)
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-label-cluster-discrimination-for-visual/referring-expression-segmentation-on-refcoco-5)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco-5?p=multi-label-cluster-discrimination-for-visual)
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-label-cluster-discrimination-for-visual/referring-expression-segmentation-on-refcoco-3)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco-3?p=multi-label-cluster-discrimination-for-visual)
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-label-cluster-discrimination-for-visual/referring-expression-segmentation-on-refcocog-1)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcocog-1?p=multi-label-cluster-discrimination-for-visual)
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-label-cluster-discrimination-for-visual/referring-expression-segmentation-on-refcoco-8)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco-8?p=multi-label-cluster-discrimination-for-visual)
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-label-cluster-discrimination-for-visual/referring-expression-segmentation-on-refcoco-4)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco-4?p=multi-label-cluster-discrimination-for-visual)
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-label-cluster-discrimination-for-visual/referring-expression-segmentation-on-refcoco-9)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco-9?p=multi-label-cluster-discrimination-for-visual)
+ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-label-cluster-discrimination-for-visual/referring-expression-segmentation-on-refcoco)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refcoco?p=multi-label-cluster-discrimination-for-visual)
+
+
+ ## RefCOCO Segmentation Evaluation
+
+ | Dataset | Split | MLCD-seg-7B | EVF-SAM | GLaMM | VisionLLM v2 | LISA |
+ | :-- | :-: | :-: | :-: | :-: | :-: | :-: |
+ | RefCOCO | val | **83.6** | 82.4 | 79.5 | 79.2 | 74.9 |
+ | RefCOCO | testA | **85.3** | 84.2 | 83.2 | 82.3 | 79.1 |
+ | RefCOCO | testB | **81.5** | 80.2 | 76.9 | 77.0 | 72.3 |
+ | RefCOCO+ | val | **79.4** | 76.5 | 72.6 | 68.9 | 65.1 |
+ | RefCOCO+ | testA | **82.9** | 80.0 | 78.7 | 75.8 | 70.8 |
+ | RefCOCO+ | testB | **75.6** | 71.9 | 64.6 | 61.8 | 58.1 |
+ | RefCOCOg | val | **79.7** | 78.2 | 74.2 | 73.3 | 67.9 |
+ | RefCOCOg | test | **80.5** | 78.3 | 74.9 | 74.8 | 70.6 |
+
+
+ ## Evaluation
+
+ ```python
+ import torch
+ from PIL import Image
+ from transformers import AutoModel, AutoTokenizer
+
+ model_path = "DeepGlint-AI/MLCD-Seg"  # or use your local path
+ mlcd_seg = AutoModel.from_pretrained(
+     model_path,
+     torch_dtype=torch.float16,
+     trust_remote_code=True
+ ).cuda()
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+ # Assuming you have an image named test.jpg
+ seg_img = Image.open("test.jpg").convert('RGB')
+ seg_prompt = "The <image> provides an overview of the picture.\nCould you provide a segmentation mask for the right giraffe in this image?"
+ pred_mask = mlcd_seg.predict_forward(seg_img, seg_prompt, tokenizer, force_seg=False)
+ ```
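+
+ The return type of `predict_forward` is not documented in this card, so the snippet below is only a sketch: it assumes `pred_mask` comes back as a single-channel mask tensor (per-pixel scores or booleans) at the input resolution, and saves it as a binary PNG.
+
+ ```python
+ import numpy as np
+ from PIL import Image
+
+ # Assumption: pred_mask is a torch.Tensor shaped (H, W) or (1, H, W).
+ mask = pred_mask.detach().float().cpu().squeeze().numpy()
+ binary = (mask > 0).astype(np.uint8) * 255   # threshold; use 0.5 if the model returns probabilities
+ Image.fromarray(binary).save("right_giraffe_mask.png")
+ ```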
+
+ ## Tips for updating this repo in the future
+
+ Hugging Face caches the remote-code modules of this repo locally, so the cached files must be cleared manually after the repo is updated:
+
+ ```bash
+ cd ~/.cache/huggingface/modules/transformers_modules
+ rm mlcd_seg.py vision_projector.py vision_resampler.py vision_tower.py sam.py conversation_mlcd_seg.py
+ ```
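+
+ Depending on your transformers version, the cached modules may instead sit in a repo-specific subdirectory of `transformers_modules`; the matching pattern below is an assumption, so check what actually exists on your machine. A small Python alternative that removes any matching cached entry:
+
+ ```python
+ import shutil
+ from pathlib import Path
+
+ # Assumption: cached remote-code entries for this repo contain "MLCD" in their name.
+ modules_dir = Path.home() / ".cache" / "huggingface" / "modules" / "transformers_modules"
+ for cached in modules_dir.glob("*MLCD*"):
+     if cached.is_dir():
+         shutil.rmtree(cached)   # drop the stale cached module package
+     else:
+         cached.unlink()         # or a stray cached file
+ ```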
+
+
+ ## Citations
+ ```
+ @misc{mlcdseg_wukun,
+   author = {Wu, Kun and Xie, Yin and Zhou, Xinyu and An, Xiang and Deng, Jiankang and Jie, Yu},
+   title = {MLCD-Seg},
+   year = {2025},
+   url = {https://github.com/deepglint/unicom/tree/main/downstream},
+ }
+ ```
config.json ADDED
@@ -0,0 +1,249 @@
1
+ {
2
+ "_name_or_path": "DeepGlint-AI/MLCD-Embodied-7B",
3
+ "add_faster_video": false,
4
+ "add_time_instruction": false,
5
+ "architectures": [
6
+ "MLCDSegForCausalLM"
7
+ ],
8
+ "auto_map": {
9
+ "AutoConfig": "mlcd_seg.MLCDSegConfig",
10
+ "AutoModel": "mlcd_seg.MLCDSegForCausalLM",
11
+ "AutoModelForCausalLM": "mlcd_seg.MLCDSegForCausalLM"
12
+ },
13
+ "attn_implementation": "flash_attention_2",
14
+ "attention_dropout": 0.0,
15
+ "bos_token_id": 151643,
16
+ "eos_token_id": 151645,
17
+ "faster_token_stride": 10,
18
+ "force_sample": false,
19
+ "hidden_act": "silu",
20
+ "hidden_size": 3584,
21
+ "image_aspect_ratio": "anyres",
22
+ "image_crop_resolution": null,
23
+ "image_grid_pinpoints": [
24
+ [
25
+ 336,
26
+ 336
27
+ ],
28
+ [
29
+ 336,
30
+ 672
31
+ ],
32
+ [
33
+ 336,
34
+ 1008
35
+ ],
36
+ [
37
+ 336,
38
+ 1344
39
+ ],
40
+ [
41
+ 336,
42
+ 1680
43
+ ],
44
+ [
45
+ 336,
46
+ 2016
47
+ ],
48
+ [
49
+ 672,
50
+ 336
51
+ ],
52
+ [
53
+ 672,
54
+ 672
55
+ ],
56
+ [
57
+ 672,
58
+ 1008
59
+ ],
60
+ [
61
+ 672,
62
+ 1344
63
+ ],
64
+ [
65
+ 672,
66
+ 1680
67
+ ],
68
+ [
69
+ 672,
70
+ 2016
71
+ ],
72
+ [
73
+ 1008,
74
+ 336
75
+ ],
76
+ [
77
+ 1008,
78
+ 672
79
+ ],
80
+ [
81
+ 1008,
82
+ 1008
83
+ ],
84
+ [
85
+ 1008,
86
+ 1344
87
+ ],
88
+ [
89
+ 1008,
90
+ 1680
91
+ ],
92
+ [
93
+ 1008,
94
+ 2016
95
+ ],
96
+ [
97
+ 1344,
98
+ 336
99
+ ],
100
+ [
101
+ 1344,
102
+ 672
103
+ ],
104
+ [
105
+ 1344,
106
+ 1008
107
+ ],
108
+ [
109
+ 1344,
110
+ 1344
111
+ ],
112
+ [
113
+ 1344,
114
+ 1680
115
+ ],
116
+ [
117
+ 1344,
118
+ 2016
119
+ ],
120
+ [
121
+ 1680,
122
+ 336
123
+ ],
124
+ [
125
+ 1680,
126
+ 672
127
+ ],
128
+ [
129
+ 1680,
130
+ 1008
131
+ ],
132
+ [
133
+ 1680,
134
+ 1344
135
+ ],
136
+ [
137
+ 1680,
138
+ 1680
139
+ ],
140
+ [
141
+ 1680,
142
+ 2016
143
+ ],
144
+ [
145
+ 2016,
146
+ 336
147
+ ],
148
+ [
149
+ 2016,
150
+ 672
151
+ ],
152
+ [
153
+ 2016,
154
+ 1008
155
+ ],
156
+ [
157
+ 2016,
158
+ 1344
159
+ ],
160
+ [
161
+ 2016,
162
+ 1680
163
+ ],
164
+ [
165
+ 2016,
166
+ 2016
167
+ ]
168
+ ],
169
+ "image_split_resolution": null,
170
+ "initializer_range": 0.02,
171
+ "intermediate_size": 18944,
172
+ "max_position_embeddings": 32768,
173
+ "max_window_layers": 28,
174
+ "mm_hidden_size": 1024,
175
+ "mm_newline_position": "grid",
176
+ "mm_patch_merge_type": "spatial_unpad",
177
+ "mm_projector_lr": null,
178
+ "mm_projector_type": "mlp2x_gelu",
179
+ "mm_resampler_type": null,
180
+ "mm_spatial_pool_mode": "bilinear",
181
+ "mm_spatial_pool_stride": null,
182
+ "mm_tunable_parts": "mm_vision_tower,mm_mlp_adapter,mm_language_model,sam",
183
+ "mm_use_im_patch_token": false,
184
+ "mm_use_im_start_end": false,
185
+ "mm_vision_select_feature": "patch",
186
+ "mm_vision_select_layer": -2,
187
+ "mm_vision_tower_lr": 2e-06,
188
+ "vision_tower_config": {
189
+ "_name_or_path": "",
190
+ "architectures": [
191
+ "CLIPVisionModel"
192
+ ],
193
+ "attention_dropout": 0.0,
194
+ "hidden_act": "quick_gelu",
195
+ "hidden_size": 1024,
196
+ "image_size": 336,
197
+ "initializer_factor": 1.0,
198
+ "initializer_range": 0.02,
199
+ "intermediate_size": 4096,
200
+ "layer_norm_eps": 1e-05,
201
+ "model_type": "clip_vision_model",
202
+ "num_attention_heads": 16,
203
+ "num_channels": 3,
204
+ "num_hidden_layers": 24,
205
+ "patch_size": 14,
206
+ "projection_dim": 1024,
207
+ "torch_dtype": "float32",
208
+ "transformers_version": "4.44.0"
209
+ },
210
+ "vision_tower_processor": {
211
+ "crop_size": 336,
212
+ "do_center_crop": true,
213
+ "do_normalize": true,
214
+ "do_resize": true,
215
+ "feature_extractor_type": "CLIPFeatureExtractor",
216
+ "image_mean": [
217
+ 0.48145466,
218
+ 0.4578275,
219
+ 0.40821073
220
+ ],
221
+ "image_std": [
222
+ 0.26862954,
223
+ 0.26130258,
224
+ 0.27577711
225
+ ],
226
+ "resample": 3,
227
+ "size": 336
228
+ },
229
+ "model_type": "qwen2",
230
+ "num_attention_heads": 28,
231
+ "num_hidden_layers": 28,
232
+ "num_key_value_heads": 4,
233
+ "pos_skipping_range": 4096,
234
+ "rms_norm_eps": 1e-06,
235
+ "rope_scaling": null,
236
+ "rope_theta": 1000000.0,
237
+ "sliding_window": null,
238
+ "tie_word_embeddings": false,
239
+ "tokenizer_model_max_length": 32768,
240
+ "tokenizer_padding_side": "right",
241
+ "torch_dtype": "bfloat16",
242
+ "transformers_version": "4.47.0",
243
+ "use_cache": true,
244
+ "use_mm_proj": true,
245
+ "use_pos_skipping": false,
246
+ "use_sliding_window": false,
247
+ "vision_tower_pretrained": null,
248
+ "vocab_size": 151666
249
+ }
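A quick sanity check on the vision settings above, assuming standard ViT patching: the CLIP tower uses 336-px inputs with 14-px patches, so each 336×336 tile yields 24 × 24 = 576 patch tokens, and `image_grid_pinpoints` enumerates every anyres tiling from 336×336 up to 2016×2016 in 336-px steps (36 layouts).

```python
# Derived only from the values in config.json above.
image_size, patch_size = 336, 14
patches_per_side = image_size // patch_size        # 24
tokens_per_tile = patches_per_side ** 2            # 576 tokens per 336x336 tile
grid_pinpoints = [[w, h] for w in range(336, 2017, 336) for h in range(336, 2017, 336)]
print(tokens_per_tile, len(grid_pinpoints))        # 576 36
```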
conversation_mlcd_seg.py ADDED
@@ -0,0 +1,579 @@
1
+ '''
2
+ File separation from LLAVA project
3
+ https://github.com/haotian-liu/LLaVA
4
+ '''
5
+
6
+
7
+ import dataclasses
8
+ from enum import auto, Enum
9
+ from typing import List, Any, Union
10
+ import re
11
+ import base64
12
+ from io import BytesIO
13
+ from PIL import Image
14
+ from transformers import AutoTokenizer
15
+
16
+
17
+ class SeparatorStyle(Enum):
18
+ """Different separator style."""
19
+
20
+ SINGLE = auto()
21
+ TWO = auto()
22
+ MPT = auto()
23
+ PLAIN = auto()
24
+ CHATML = auto()
25
+ LLAMA_2 = auto()
26
+ LLAMA_3 = auto()
27
+ QWEN = auto()
28
+ GEMMA = auto()
29
+
30
+
31
+ @dataclasses.dataclass
32
+ class Conversation:
33
+ """A class that keeps all conversation history."""
34
+
35
+ system: str
36
+ roles: List[str]
37
+ messages: List[List[str]]
38
+ offset: int
39
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
40
+ sep: str = "###"
41
+ sep2: str = None
42
+ version: str = "Unknown"
43
+
44
+ tokenizer_id: str = ""
45
+ tokenizer: Any = None
46
+ # Stop criteria (the default one is EOS token)
47
+ stop_str: Union[str, List[str]] = None
48
+ # Stops generation if meeting any token in this list
49
+ stop_token_ids: List[int] = None
50
+
51
+ skip_next: bool = False
52
+
53
+ def get_prompt(self):
54
+ messages = self.messages
55
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
56
+ messages = self.messages.copy()
57
+ init_role, init_msg = messages[0].copy()
58
+ init_msg = init_msg[0]
59
+ if "mmtag" in self.version:
60
+ init_msg = init_msg.replace("<image>", "").strip()
61
+ messages[0] = (init_role, init_msg)
62
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
63
+ messages.insert(1, (self.roles[1], "Received."))
64
+ elif not init_msg.startswith("<image>"):
65
+ init_msg = init_msg.replace("<image>", "").strip()
66
+ messages[0] = (init_role, "<image>\n" + init_msg)
67
+ else:
68
+ messages[0] = (init_role, init_msg)
69
+
70
+ if self.sep_style == SeparatorStyle.SINGLE:
71
+ ret = self.system + self.sep
72
+ for role, message in messages:
73
+ if message:
74
+ if type(message) is tuple:
75
+ message, _, _ = message
76
+ ret += role + ": " + message + self.sep
77
+ else:
78
+ ret += role + ":"
79
+
80
+ elif self.sep_style == SeparatorStyle.TWO:
81
+ seps = [self.sep, self.sep2]
82
+ ret = self.system + seps[0]
83
+ for i, (role, message) in enumerate(messages):
84
+ if message:
85
+ if type(message) is tuple:
86
+ message, _, _ = message
87
+ ret += role + ": " + message + seps[i % 2]
88
+ else:
89
+ ret += role + ":"
90
+
91
+ elif self.sep_style == SeparatorStyle.CHATML:
92
+ ret = "" if self.system == "" else self.system + self.sep + "\n"
93
+ for role, message in messages:
94
+ if message:
95
+ if type(message) is tuple:
96
+ message, images, _ = message
97
+ message = "<image>" * len(images) + message
98
+ ret += role + "\n" + message + self.sep + "\n"
99
+ else:
100
+ ret += role + "\n"
101
+ return ret
102
+
103
+ elif self.sep_style == SeparatorStyle.LLAMA_3:
104
+ if self.tokenizer is None:
105
+ raise ValueError("Llama 3 tokenizer is not available. Make sure you have the necessary permissions.")
106
+ chat_template_messages = [{"role": "system", "content": self.system}]
107
+ for role, message in messages:
108
+ if message:
109
+ if type(message) is tuple:
110
+ message, images = message
111
+ message = "<image>" * len(images) + message
112
+ chat_template_messages.append({"role": role, "content": message})
113
+ return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
114
+
115
+ elif self.sep_style == SeparatorStyle.MPT:
116
+ ret = self.system + self.sep
117
+ for role, message in messages:
118
+ if message:
119
+ if type(message) is tuple:
120
+ message, _, _ = message
121
+ ret += role + message + self.sep
122
+ else:
123
+ ret += role
124
+
125
+ elif self.sep_style == SeparatorStyle.GEMMA:
126
+ ret = ""
127
+ for i, (role, message) in enumerate(messages):
128
+ assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
129
+ if message:
130
+ if type(message) is tuple:
131
+ message, _, _ = message
132
+ ret += role + message + self.sep
133
+ else:
134
+ ret += role
135
+
136
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
137
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
138
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
139
+ ret = ""
140
+
141
+ for i, (role, message) in enumerate(messages):
142
+ if i == 0:
143
+ assert message, "first message should not be none"
144
+ assert role == self.roles[0], "first message should come from user"
145
+ if message:
146
+ if type(message) is tuple:
147
+ message, _, _ = message
148
+ if i == 0:
149
+ message = wrap_sys(self.system) + message
150
+ if i % 2 == 0:
151
+ message = wrap_inst(message)
152
+ ret += self.sep + message
153
+ else:
154
+ ret += " " + message + " " + self.sep2
155
+ else:
156
+ ret += ""
157
+ ret = ret.lstrip(self.sep)
158
+
159
+ elif self.sep_style == SeparatorStyle.PLAIN:
160
+ seps = [self.sep, self.sep2]
161
+ ret = self.system
162
+ for i, (role, message) in enumerate(messages):
163
+ if message:
164
+ if type(message) is tuple:
165
+ message, _, _ = message
166
+ ret += message + seps[i % 2]
167
+ else:
168
+ ret += ""
169
+ else:
170
+ raise ValueError(f"Invalid style: {self.sep_style}")
171
+
172
+ return ret
173
+
174
+ def append_message(self, role, message):
175
+ self.messages.append([role, message])
176
+
177
+ def process_image(self, image, image_process_mode, return_pil=False, image_format="PNG"):
178
+ if image_process_mode == "Pad":
179
+
180
+ def expand2square(pil_img, background_color=(122, 116, 104)):
181
+ width, height = pil_img.size
182
+ if width == height:
183
+ return pil_img
184
+ elif width > height:
185
+ result = Image.new(pil_img.mode, (width, width), background_color)
186
+ result.paste(pil_img, (0, (width - height) // 2))
187
+ return result
188
+ else:
189
+ result = Image.new(pil_img.mode, (height, height), background_color)
190
+ result.paste(pil_img, ((height - width) // 2, 0))
191
+ return result
192
+
193
+ image = expand2square(image)
194
+ elif image_process_mode in ["Default", "Crop"]:
195
+ pass
196
+ elif image_process_mode == "Resize":
197
+ image = image.resize((336, 336))
198
+ else:
199
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
200
+
201
+ if type(image) is not Image.Image:
202
+ image = Image.open(image).convert("RGB")
203
+
204
+ max_hw, min_hw = max(image.size), min(image.size)
205
+ aspect_ratio = max_hw / min_hw
206
+ max_len, min_len = 672, 448
207
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
208
+ longest_edge = int(shortest_edge * aspect_ratio)
209
+ W, H = image.size
210
+ if H > W:
211
+ H, W = longest_edge, shortest_edge
212
+ else:
213
+ H, W = shortest_edge, longest_edge
214
+ image = image.resize((W, H))
215
+ if return_pil:
216
+ return image
217
+ else:
218
+ buffered = BytesIO()
219
+ image.save(buffered, format=image_format)
220
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
221
+ return img_b64_str
222
+
223
+ def get_images(self, return_pil=False, return_path=False):
224
+ images = []
225
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
226
+ if i % 2 == 0:
227
+ if type(msg) is tuple:
228
+ msg, image, image_process_mode = msg
229
+ if type(image) != list:
230
+ image = [image]
231
+ for img in image:
232
+ if not return_path and self.is_image_file(img):
233
+ img = self.process_image(img, image_process_mode, return_pil=return_pil)
234
+ else:
235
+ images.append(img)
236
+ return images
237
+
238
+ def is_image_file(self, filename):
239
+ image_extensions = [".png", ".jpg", ".jpeg", ".gif", ".bmp", ".tiff", ".webp"]
240
+ return any(filename.lower().endswith(ext) for ext in image_extensions)
241
+
242
+ def is_video_file(self, filename):
243
+ video_extensions = [".mp4", ".mov", ".avi", ".mkv", ".wmv", ".flv", ".mpeg", ".mpg"]
244
+ return any(filename.lower().endswith(ext) for ext in video_extensions)
245
+
246
+ def to_gradio_chatbot(self):
247
+ ret = []
248
+ for i, (role, msg) in enumerate(self.messages[self.offset :]):
249
+ if i % 2 == 0:
250
+ if type(msg) is tuple:
251
+ msg, image, image_process_mode = msg
252
+ if type(image) != list:
253
+ image = [image]
254
+ if len(image) == 1:
255
+ msg = "<image>\n" + msg.replace("<image>", "").strip()
256
+ else:
257
+ msg = re.sub(r"(<image>)\n(?=<image>)", r"\1 ", msg)
258
+
259
+ img_str_list = []
260
+ for img in image:
261
+ if self.is_image_file(img):
262
+ img_b64_str = self.process_image(img, "Default", return_pil=False, image_format="JPEG")
263
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" style="max-width: 256px; max-height: 256px; width: auto; height: auto; object-fit: contain;"/>'
264
+ img_str_list.append(img_str)
265
+ elif self.is_video_file(img):
266
+ ret.append(((img,), None))
267
+
268
+ msg = msg.strip()
269
+ img_place_holder = ""
270
+ for img_str in img_str_list:
271
+ img_place_holder += f"{img_str}\n\n"
272
+
273
+ if len(img_str_list) > 0:
274
+ msg = f"{img_place_holder}\n\n{msg}"
275
+
276
+ if len(msg) > 0:
277
+ ret.append([msg, None])
278
+ else:
279
+ ret.append([msg, None])
280
+ else:
281
+ ret[-1][-1] = msg
282
+ return ret
283
+
284
+ def copy(self):
285
+ return Conversation(system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, version=self.version)
286
+
287
+ def dict(self):
288
+ if len(self.get_images()) > 0:
289
+ return {
290
+ "system": self.system,
291
+ "roles": self.roles,
292
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
293
+ "offset": self.offset,
294
+ "sep": self.sep,
295
+ "sep2": self.sep2,
296
+ }
297
+ return {
298
+ "system": self.system,
299
+ "roles": self.roles,
300
+ "messages": self.messages,
301
+ "offset": self.offset,
302
+ "sep": self.sep,
303
+ "sep2": self.sep2,
304
+ }
305
+
306
+
307
+ conv_vicuna_v0 = Conversation(
308
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
309
+ roles=("Human", "Assistant"),
310
+ messages=[
311
+ ["Human", "What are the key differences between renewable and non-renewable energy sources?"],
312
+ [
313
+ "Assistant",
314
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
315
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
316
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
317
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
318
+ "renewable and non-renewable energy sources:\n"
319
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
320
+ "energy sources are finite and will eventually run out.\n"
321
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
322
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
323
+ "and other negative effects.\n"
324
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
325
+ "have lower operational costs than non-renewable sources.\n"
326
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
327
+ "locations than non-renewable sources.\n"
328
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
329
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
330
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
331
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
332
+ ],
333
+ ],
334
+ offset=2,
335
+ sep_style=SeparatorStyle.SINGLE,
336
+ sep="###",
337
+ )
338
+
339
+ conv_vicuna_v1 = Conversation(
340
+ system="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.",
341
+ roles=("USER", "ASSISTANT"),
342
+ version="v1",
343
+ messages=[],
344
+ offset=0,
345
+ sep_style=SeparatorStyle.TWO,
346
+ sep=" ",
347
+ sep2="</s>",
348
+ )
349
+
350
+ conv_llama_2 = Conversation(
351
+ system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
352
+
353
+ If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
354
+ roles=("USER", "ASSISTANT"),
355
+ version="llama_v2",
356
+ messages=[],
357
+ offset=0,
358
+ sep_style=SeparatorStyle.LLAMA_2,
359
+ sep="<s>",
360
+ sep2="</s>",
361
+ )
362
+
363
+ conv_llava_llama_2 = Conversation(
364
+ system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
365
+ roles=("USER", "ASSISTANT"),
366
+ version="llama_v2",
367
+ messages=[],
368
+ offset=0,
369
+ sep_style=SeparatorStyle.LLAMA_2,
370
+ sep="<s>",
371
+ sep2="</s>",
372
+ )
373
+
374
+ def safe_load_tokenizer(tokenizer_id):
375
+ try:
376
+ return AutoTokenizer.from_pretrained(tokenizer_id)
377
+ except Exception:
378
+ return None
379
+
380
+ conv_llava_llama_3 = Conversation(
381
+ system="You are a helpful language and vision assistant. " "You are able to understand the visual content that the user provides, " "and assist the user with a variety of tasks using natural language.",
382
+ roles=("user", "assistant"),
383
+ version="llama_v3",
384
+ messages=[],
385
+ offset=0,
386
+ sep="<|eot_id|>",
387
+ sep_style=SeparatorStyle.LLAMA_3,
388
+ tokenizer_id="meta-llama/Meta-Llama-3-8B-Instruct",
389
+ tokenizer=safe_load_tokenizer("meta-llama/Meta-Llama-3-8B-Instruct"),
390
+ stop_token_ids=[128009],
391
+ )
392
+
393
+ conv_mistral_instruct = Conversation(
394
+ system="",
395
+ roles=("USER", "ASSISTANT"),
396
+ version="llama_v2",
397
+ messages=[],
398
+ offset=0,
399
+ sep_style=SeparatorStyle.LLAMA_2,
400
+ sep="",
401
+ sep2="</s>",
402
+ )
403
+
404
+ conv_llava_llama_2_simple = Conversation(
405
+ system="Answer the questions about the visual content that the user provides.",
406
+ roles=("USER", "ASSISTANT"),
407
+ version="llama_v2",
408
+ messages=[],
409
+ offset=0,
410
+ sep_style=SeparatorStyle.LLAMA_2,
411
+ sep="<s>",
412
+ sep2="</s>",
413
+ )
414
+
415
+ conv_llava_llama_2_mmtag = Conversation(
416
+ system="Answer the questions about the visual content that the user provides." "The visual content will be provided with the following format: <Image>visual content</Image>.",
417
+ roles=("USER", "ASSISTANT"),
418
+ version="llama_v2_mmtag",
419
+ messages=[],
420
+ offset=0,
421
+ sep_style=SeparatorStyle.LLAMA_2,
422
+ sep="<s>",
423
+ sep2="</s>",
424
+ )
425
+
426
+ conv_mpt = Conversation(
427
+ system="""<|im_start|>system
428
+ A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
429
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
430
+ version="mpt",
431
+ messages=[],
432
+ offset=0,
433
+ sep_style=SeparatorStyle.MPT,
434
+ sep="<|im_end|>",
435
+ )
436
+
437
+ conv_qwen = Conversation(
438
+ system="""<|im_start|>system
439
+ You are a helpful assistant.""",
440
+ roles=("<|im_start|>user", "<|im_start|>assistant"),
441
+ version="qwen",
442
+ messages=[],
443
+ offset=0,
444
+ sep_style=SeparatorStyle.CHATML,
445
+ sep="<|im_end|>",
446
+ )
447
+
448
+ conv_gemma_instruct = Conversation(system="", roles=("<start_of_turn>user\n", "<start_of_turn>model\n"), version="gemma", messages=[], offset=0, sep_style=SeparatorStyle.GEMMA, sep="<end_of_turn>\n")
449
+
450
+ conv_llava_plain = Conversation(
451
+ system="",
452
+ roles=("", ""),
453
+ messages=[],
454
+ offset=0,
455
+ sep_style=SeparatorStyle.PLAIN,
456
+ sep="\n",
457
+ )
458
+
459
+ conv_llava_v0 = Conversation(
460
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
461
+ roles=("Human", "Assistant"),
462
+ messages=[],
463
+ offset=0,
464
+ sep_style=SeparatorStyle.SINGLE,
465
+ sep="###",
466
+ )
467
+
468
+ conv_llava_v0_mmtag = Conversation(
469
+ system="A chat between a curious user and an artificial intelligence assistant. "
470
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
471
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
472
+ roles=("Human", "Assistant"),
473
+ messages=[],
474
+ offset=0,
475
+ sep_style=SeparatorStyle.SINGLE,
476
+ sep="###",
477
+ version="v0_mmtag",
478
+ )
479
+
480
+ conv_llava_v1 = Conversation(
481
+ system="A chat between a curious human and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the human's questions.",
482
+ roles=("USER", "ASSISTANT"),
483
+ version="v1",
484
+ messages=[],
485
+ offset=0,
486
+ sep_style=SeparatorStyle.TWO,
487
+ sep=" ",
488
+ sep2="</s>",
489
+ )
490
+
491
+ conv_llava_v1_mmtag = Conversation(
492
+ system="A chat between a curious user and an artificial intelligence assistant. "
493
+ "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
494
+ "The visual content will be provided with the following format: <Image>visual content</Image>.",
495
+ roles=("USER", "ASSISTANT"),
496
+ messages=[],
497
+ offset=0,
498
+ sep_style=SeparatorStyle.TWO,
499
+ sep=" ",
500
+ sep2="</s>",
501
+ version="v1_mmtag",
502
+ )
503
+
504
+ conv_mistral_orca = Conversation(
505
+ system="""<|im_start|>system
506
+ You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!""",
507
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
508
+ version="mpt",
509
+ messages=[],
510
+ offset=0,
511
+ sep_style=SeparatorStyle.MPT,
512
+ sep="<|im_end|>",
513
+ )
514
+
515
+ conv_mistral_zephyr = Conversation(
516
+ system="""<|system|>
517
+ You are a helpful AI assistant.""",
518
+ roles=("<|user|>\n", "<|assistant|>\n"),
519
+ version="mpt",
520
+ messages=[],
521
+ offset=0,
522
+ sep_style=SeparatorStyle.MPT,
523
+ sep="</s>",
524
+ )
525
+
526
+ conv_mistral_direct = Conversation(
527
+ system="""<|im_start|>system
528
+ Answer the questions.""",
529
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
530
+ version="mpt",
531
+ messages=[],
532
+ offset=0,
533
+ sep_style=SeparatorStyle.MPT,
534
+ sep="<|im_end|>",
535
+ )
536
+
537
+ conv_chatml_direct = Conversation(
538
+ system="""<|im_start|>system
539
+ Answer the questions.""",
540
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
541
+ version="mpt",
542
+ messages=[],
543
+ offset=0,
544
+ sep_style=SeparatorStyle.MPT,
545
+ sep="<|im_end|>",
546
+ )
547
+
548
+ default_conversation = conv_vicuna_v0
549
+ conv_templates = {
550
+ "default": conv_vicuna_v0,
551
+ "v0": conv_vicuna_v0,
552
+ "v1": conv_vicuna_v1,
553
+ "vicuna_v1": conv_vicuna_v1,
554
+ "llama_2": conv_llama_2,
555
+ "mistral_instruct": conv_mistral_instruct,
556
+ "mistral_orca": conv_mistral_orca,
557
+ "mistral_zephyr": conv_mistral_zephyr,
558
+ "mistral_direct": conv_mistral_direct,
559
+ "plain": conv_llava_plain,
560
+ "v0_plain": conv_llava_plain,
561
+ "chatml_direct": conv_chatml_direct,
562
+ "llava_v0": conv_llava_v0,
563
+ "llava_v0_mmtag": conv_llava_v0_mmtag,
564
+ "llava_v1": conv_llava_v1,
565
+ "llava_v1_mmtag": conv_llava_v1_mmtag,
566
+ "llava_llama_2": conv_llava_llama_2,
567
+ "llava_llama_3": conv_llava_llama_3,
568
+ "llava_llama_2_simple": conv_llava_llama_2_simple,
569
+ "llava_llama_2_mmtag": conv_llava_llama_2_mmtag,
570
+ "llava_mistral_instruct": conv_mistral_instruct,
571
+ "mpt": conv_mpt,
572
+ "qwen_1_5": conv_qwen,
573
+ "qwen_2": conv_qwen,
574
+ "gemma_instruct": conv_gemma_instruct,
575
+ }
576
+
577
+
578
+ if __name__ == "__main__":
579
+ print(default_conversation.get_prompt())
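The `conv_templates` registry defined in this file is what assembles chat prompts; as a rough, self-contained illustration (the import path is an assumption and depends on how the remote-code module is loaded), the `qwen_2` ChatML template can be used like this:

```python
# Sketch only: assumes conversation_mlcd_seg.py is importable as a local module.
from conversation_mlcd_seg import conv_templates

conv = conv_templates["qwen_2"].copy()
conv.append_message(conv.roles[0], "<image>\nCould you provide a segmentation mask for the right giraffe?")
conv.append_message(conv.roles[1], None)   # leave the assistant turn open for generation
print(conv.get_prompt())                   # ChatML prompt ending with "<|im_start|>assistant\n"
```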
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
mlcd_seg.py ADDED
@@ -0,0 +1,1044 @@
1
+ '''
2
+ File modification from LLAVA project @DeepGlintAI 2025
3
+ https://github.com/haotian-liu/LLaVA
4
+
5
+ origin copyright:
6
+
7
+ Copyright 2023 Haotian Liu
8
+
9
+ Licensed under the Apache License, Version 2.0 (the "License");
10
+ you may not use this file except in compliance with the License.
11
+ You may obtain a copy of the License at
12
+
13
+ http://www.apache.org/licenses/LICENSE-2.0
14
+
15
+ Unless required by applicable law or agreed to in writing, software
16
+ distributed under the License is distributed on an "AS IS" BASIS,
17
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ See the License for the specific language governing permissions and
19
+ limitations under the License.
20
+ '''
21
+
22
+
23
+ from abc import ABC, abstractmethod
24
+
25
+ import math
26
+ import random
27
+ import ast
28
+ import re
29
+ import json
30
+ import numpy as np
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+ from pathlib import Path
35
+ from dataclasses import dataclass
36
+ from PIL import Image
37
+ from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer
38
+ from transformers.modeling_outputs import CausalLMOutputWithPast
39
+ from transformers.generation.utils import GenerateOutput
40
+ from safetensors.torch import load_file as safetensors_load
41
+ from .vision_tower import build_vision_tower
42
+ from .vision_resampler import build_vision_resampler
43
+ from .vision_projector import build_vision_projector
44
+ from .sam import build_sam_vit_h, text2sam_projection_layer
45
+ from .conversation_mlcd_seg import default_conversation
46
+ from .transform import ResizeLongestSide
47
+ from typing import Optional, Any, List, Tuple, Union, Dict
48
+
49
+ IGNORE_INDEX = -100
50
+ IMAGE_TOKEN_INDEX = -200
51
+ DEFAULT_SEG_TOKEN = "[SEG]"
52
+
53
+ IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
54
+ IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
55
+ IMG_SIZE = 1024
56
+
57
+ def select_best_resolution(original_size, possible_resolutions):
58
+ """
59
+ Selects the best resolution from a list of possible resolutions based on the original size.
60
+
61
+ Args:
62
+ original_size (tuple): The original size of the image in the format (width, height).
63
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
64
+
65
+ Returns:
66
+ tuple: The best fit resolution in the format (width, height).
67
+ """
68
+ original_width, original_height = original_size
69
+ best_fit = None
70
+ max_effective_resolution = 0
71
+ min_wasted_resolution = float("inf")
72
+
73
+ for width, height in possible_resolutions:
74
+ # Calculate the downscaled size to keep the aspect ratio
75
+ scale = min(width / original_width, height / original_height)
76
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
77
+
78
+ # Calculate effective and wasted resolutions
79
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
80
+ wasted_resolution = (width * height) - effective_resolution
81
+
82
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
83
+ max_effective_resolution = effective_resolution
84
+ min_wasted_resolution = wasted_resolution
85
+ best_fit = (width, height)
86
+
87
+ return best_fit
88
+
89
+
90
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
91
+ """
92
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
93
+
94
+ Args:
95
+ image_size (tuple): The size of the input image in the format (width, height).
96
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
97
+ patch_size (int): The size of each image patch.
98
+
99
+ Returns:
100
+ tuple: The shape of the image patch grid in the format (width, height).
101
+ """
102
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
103
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
104
+ # Use regex to extract the range from the input string
105
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
106
+ range_start = tuple(map(int, matches[0]))
107
+ range_end = tuple(map(int, matches[-1]))
108
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
109
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
110
+ # Multiply all elements by patch_size
111
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
112
+ if type(grid_pinpoints) is list:
113
+ possible_resolutions = grid_pinpoints
114
+ else:
115
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
116
+ width, height = select_best_resolution(image_size, possible_resolutions)
117
+ return width // patch_size, height // patch_size
118
+
119
+
120
+ class MLCDSegMetaModel:
121
+
122
+ def __init__(self, config):
123
+ super(MLCDSegMetaModel, self).__init__(config)
124
+
125
+ if hasattr(config, "vision_tower_config"):
126
+ vision_tower_weight, sam_weight, projector_weight, text2sam_projection_weight = self.dispatch_weight(config)
127
+ delay_load = getattr(config, "delay_load", False)
128
+ self.vision_tower = build_vision_tower(config, delay_load=delay_load)
129
+ self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
130
+ self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
131
+ self.vision_tower.vision_tower.load_state_dict(vision_tower_weight)
132
+ self.mm_projector.load_state_dict(projector_weight)
133
+ self.sam = build_sam_vit_h()
134
+ self.sam.load_state_dict(sam_weight)
135
+ self.text2sam_projection = text2sam_projection_layer(config)
136
+ self.text2sam_projection.load_state_dict(text2sam_projection_weight)
137
+
138
+ if "unpad" in getattr(config, "mm_patch_merge_type", ""):
139
+ self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))
140
+
141
+ def dispatch_weight(self, config):
142
+ safetensors_set = set()
143
+ index_file = Path(getattr(config, "name_or_path", "./")) / "model.safetensors.index.json"
144
+ with open(index_file, "r") as safetensors_index:
145
+ safetensors_map = json.loads(safetensors_index.read())
146
+ for key, value in safetensors_map["weight_map"].items():
147
+ if key.startswith("model.vision_tower.vision_tower") or key.startswith("model.sam") or \
148
+ key.startswith("model.mm_projector") or key.startswith("model.text2sam_projection"):
149
+ safetensors_set.add(value)
150
+ vision_tower_weight = {}
151
+ sam_weight = {}
152
+ projector_weight = {}
153
+ text2sam_projection_weight = {}
154
+ for safetensors_file in safetensors_set:
155
+ temp_load = safetensors_load(safetensors_file)
156
+ for key, value in temp_load.items():
157
+ if key.startswith("model.sam."):
158
+ sam_weight[key.replace("model.sam.", "")] = value
159
+ if key.startswith("model.vision_tower.vision_tower."):
160
+ vision_tower_weight[key.replace("model.vision_tower.vision_tower.", "")] = value
161
+ if key.startswith("model.mm_projector."):
162
+ projector_weight[key.replace("model.mm_projector.", "")] = value
163
+ if key.startswith("model.text2sam_projection."):
164
+ text2sam_projection_weight[key.replace("model.text2sam_projection.", "")] = value
165
+ return vision_tower_weight, sam_weight, projector_weight, text2sam_projection_weight
166
+
167
+ def get_vision_tower(self):
168
+ vision_tower = getattr(self, "vision_tower", None)
169
+ if type(vision_tower) is list:
170
+ vision_tower = vision_tower[0]
171
+ return vision_tower
172
+
173
+ def initialize_vision_modules(self, model_args, fsdp=None):
174
+ vision_tower = model_args.vision_tower
175
+ mm_vision_select_layer = model_args.mm_vision_select_layer
176
+ mm_vision_select_feature = model_args.mm_vision_select_feature
177
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
178
+ mm_patch_merge_type = model_args.mm_patch_merge_type
179
+
180
+ self.config.mm_vision_tower = vision_tower
181
+ self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")
182
+
183
+ if self.get_vision_tower() is None:
184
+ vision_tower = build_vision_tower(model_args)
185
+ vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
186
+ for k, v in vision_resampler.config.items():
187
+ setattr(self.config, k, v)
188
+
189
+ if fsdp is not None and len(fsdp) > 0:
190
+ self.vision_tower = [vision_tower]
191
+ self.vision_resampler = [vision_resampler]
192
+ else:
193
+ self.vision_tower = vision_tower
194
+ self.vision_resampler = vision_resampler
195
+ else:
196
+ if fsdp is not None and len(fsdp) > 0:
197
+ vision_resampler = self.vision_resampler[0]
198
+ vision_tower = self.vision_tower[0]
199
+ else:
200
+ vision_resampler = self.vision_resampler
201
+ vision_tower = self.vision_tower
202
+ vision_tower.load_model()
203
+
204
+ # In case it is frozen by LoRA
205
+ for p in self.vision_resampler.parameters():
206
+ p.requires_grad = True
207
+
208
+ self.config.use_mm_proj = True
209
+ self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
210
+ self.config.mm_hidden_size = getattr(vision_resampler, "hidden_size", vision_tower.hidden_size)
211
+ self.config.mm_vision_select_layer = mm_vision_select_layer
212
+ self.config.mm_vision_select_feature = mm_vision_select_feature
213
+ self.config.mm_patch_merge_type = mm_patch_merge_type
214
+
215
+ for key in vars(model_args):
216
+ if key.startswith('sam_'):
217
+ setattr(self.config, key, getattr(model_args, key))
218
+
219
+ if not hasattr(self.config, 'add_faster_video'):
220
+ if model_args.add_faster_video:
221
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
222
+ self.faster_token = nn.Parameter(
223
+ torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
224
+ )
225
+
226
+ if getattr(self, "mm_projector", None) is None:
227
+ self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
228
+
229
+ if "unpad" in mm_patch_merge_type:
230
+ embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
231
+ self.image_newline = nn.Parameter(torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std)
232
+
233
+ if getattr(self.config, 'sam_path', None) is not None:
234
+ self.sam = build_sam_vit_h(self.config.sam_path)
235
+ self.text2sam_projection = text2sam_projection_layer(self.config)
236
+ else:
237
+ if getattr(self.config, 'sam_path', None) is not None and self.config.sam_path !="":
238
+ self.sam = build_sam_vit_h(self.config.sam_path)
239
+ self.text2sam_projection = text2sam_projection_layer(self.config)
240
+ # In case it is frozen by LoRA
241
+ for p in self.mm_projector.parameters():
242
+ p.requires_grad = True
243
+
244
+ if pretrain_mm_mlp_adapter is not None:
245
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
246
+
247
+ def get_w(weights, keyword):
248
+ return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
249
+
250
+ incompatible_keys = self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
251
+ incompatible_keys = self.vision_resampler.load_state_dict(get_w(mm_projector_weights, "vision_resampler"), strict=False)
252
+
253
+
254
+ def unpad_image(tensor, original_size):
255
+ """
256
+ Unpads a PyTorch tensor of a padded and resized image.
257
+
258
+ Args:
259
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
260
+ original_size (tuple): The original size of the image (height, width).
261
+
262
+ Returns:
263
+ torch.Tensor: The unpadded image tensor.
264
+ """
265
+ original_width, original_height = original_size
266
+ current_height, current_width = tensor.shape[1:]
267
+
268
+ # Compute aspect ratios
269
+ original_aspect_ratio = original_width / original_height
270
+ current_aspect_ratio = current_width / current_height
271
+
272
+ # Determine padding size and direction
273
+ if original_aspect_ratio > current_aspect_ratio:
274
+ # Padding was added to the height
275
+ scale_factor = current_width / original_width
276
+ new_height = int(original_height * scale_factor)
277
+ padding = (current_height - new_height) // 2
278
+ unpadded_tensor = tensor[:, padding : current_height - padding, :]
279
+ else:
280
+ # Padding was added to the width
281
+ scale_factor = current_height / original_height
282
+ new_width = int(original_width * scale_factor)
283
+ padding = (current_width - new_width) // 2
284
+ unpadded_tensor = tensor[:, :, padding : current_width - padding]
285
+
286
+ return unpadded_tensor
287
+
288
+
289
+ def resize_and_pad_image(image, target_resolution):
290
+ """
291
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
292
+
293
+ Args:
294
+ image (PIL.Image.Image): The input image.
295
+ target_resolution (tuple): The target resolution (width, height) of the image.
296
+
297
+ Returns:
298
+ PIL.Image.Image: The resized and padded image.
299
+ """
300
+ original_width, original_height = image.size
301
+ target_width, target_height = target_resolution
302
+
303
+ # Determine which dimension (width or height) to fill
304
+ scale_w = target_width / original_width
305
+ scale_h = target_height / original_height
306
+
307
+ if scale_w < scale_h:
308
+ # Width will be filled completely
309
+ new_width = target_width
310
+ new_height = min(math.ceil(original_height * scale_w), target_height)
311
+ else:
312
+ # Height will be filled completely
313
+ new_height = target_height
314
+ new_width = min(math.ceil(original_width * scale_h), target_width)
315
+
316
+ # Resize the image
317
+ resized_image = image.resize((new_width, new_height))
318
+
319
+ # Create a new image with the target size and paste the resized image onto it
320
+ new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
321
+ paste_x = (target_width - new_width) // 2
322
+ paste_y = (target_height - new_height) // 2
323
+ new_image.paste(resized_image, (paste_x, paste_y))
324
+
325
+ return new_image
326
+
327
+
328
+ def divide_to_patches(image, patch_size):
329
+ """
330
+ Divides an image into patches of a specified size.
331
+
332
+ Args:
333
+ image (PIL.Image.Image): The input image.
334
+ patch_size (int): The size of each patch.
335
+
336
+ Returns:
337
+ list: A list of PIL.Image.Image objects representing the patches.
338
+ """
339
+ patches = []
340
+ width, height = image.size
341
+ for i in range(0, height, patch_size):
342
+ for j in range(0, width, patch_size):
343
+ box = (j, i, j + patch_size, i + patch_size)
344
+ patch = image.crop(box)
345
+ patches.append(patch)
346
+
347
+ return patches
348
+
349
+
350
+ def process_anyres_image(image, processor, grid_pinpoints):
351
+ """
352
+ Process an image with variable resolutions.
353
+
354
+ Args:
355
+ image (PIL.Image.Image): The input image to be processed.
356
+ processor: The image processor object.
357
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
358
+
359
+ Returns:
360
+ torch.Tensor: A tensor containing the processed image patches.
361
+ """
362
+ # Convert grid_pinpoints from string to list
363
+ if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
364
+ try:
365
+ patch_size = processor.size[0]
366
+ except Exception as e:
367
+ patch_size = processor.size["shortest_edge"]
368
+ assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
369
+ # Use regex to extract the range from the input string
370
+ matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
371
+ range_start = tuple(map(int, matches[0]))
372
+ range_end = tuple(map(int, matches[-1]))
373
+ # Generate a matrix of tuples from (range_start[0], range_start[1]) to (range_end[0], range_end[1])
374
+ grid_pinpoints = [(i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)]
375
+ # Multiply all elements by patch_size
376
+ grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
377
+
378
+ if type(grid_pinpoints) is list:
379
+ possible_resolutions = grid_pinpoints
380
+ else:
381
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
382
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
383
+ image_padded = resize_and_pad_image(image, best_resolution)
384
+
385
+ patches = divide_to_patches(image_padded, processor.crop_size["height"])
386
+
387
+ # FIXME: this seems to be a bug that it resizes instead of pad.
388
+ # but to keep it consistent with previous, i will keep it as it is
389
+ # TODO: uncomment below to ablate with the padding
390
+ if isinstance(processor.size, dict):
391
+ shortest_edge = processor.size["shortest_edge"]
392
+ else:
393
+ shortest_edge = min(processor.size)
394
+ image_original_resize = image.resize((shortest_edge, shortest_edge))
395
+ # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
396
+ # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
397
+
398
+ image_patches = [image_original_resize] + patches
399
+ image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
400
+ return torch.stack(image_patches, dim=0)
401
+
402
+
403
+ class MLCDSegMetaForCausalLM(ABC):
404
+
405
+ @abstractmethod
406
+ def get_model(self):
407
+ pass
408
+
409
+ def get_vision_tower(self):
410
+ return self.get_model().get_vision_tower()
411
+
412
+ def get_2dPool(self, image_feature, stride=2):
413
+ height = width = self.get_vision_tower().num_patches_per_side
414
+ num_frames, num_tokens, num_dim = image_feature.shape
415
+ image_feature = image_feature.view(num_frames, height, width, -1)
416
+ image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
417
+ # image_feature = nn.functional.max_pool2d(image_feature, self.config.mm_spatial_pool_stride)
418
+ if self.config.mm_spatial_pool_mode == "average":
419
+ image_feature = nn.functional.avg_pool2d(image_feature, stride)
420
+ elif self.config.mm_spatial_pool_mode == "max":
421
+ image_feature = nn.functional.max_pool2d(image_feature, stride)
422
+ elif self.config.mm_spatial_pool_mode == "bilinear":
423
+ height, width = image_feature.shape[2:]
424
+ scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
425
+ image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
426
+
427
+ else:
428
+ raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
429
+ image_feature = image_feature.permute(0, 2, 3, 1)
430
+ image_feature = image_feature.view(num_frames, -1, num_dim)
431
+ return image_feature
432
+
433
+ def encode_images(self, images):
434
+ image_features = self.get_model().get_vision_tower()(images)
435
+ # image_features = self.get_model().vision_resampler(image_features, images=images)
436
+ image_features = self.get_model().mm_projector(image_features)
437
+ return image_features
438
+
439
+ def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
440
+ videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
441
+ per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0) # tuple, (dim_1, 576, 4096)
442
+ all_videos_or_images_features = []
443
+ all_faster_video_features = []
444
+ cur_mm_spatial_pool_stride = self.config.mm_spatial_pool_stride
445
+
446
+ for idx, feat in enumerate(per_videos_or_images_features):
447
+
448
+ feat = self.get_model().mm_projector(feat)
449
+ faster_video_feature = 0
450
+ slower_img_feat = 0
451
+ if idx in video_idx_in_batch and cur_mm_spatial_pool_stride > 1:
452
+ slower_img_feat = self.get_2dPool(feat,cur_mm_spatial_pool_stride)
453
+ if self.config.add_faster_video:
454
+ cur_mm_spatial_pool_stride = cur_mm_spatial_pool_stride * 2
455
+ faster_video_feature = self.get_2dPool(feat,cur_mm_spatial_pool_stride)
456
+ if slower_img_feat != 0:
457
+ all_videos_or_images_features.append(slower_img_feat)
458
+ else:
459
+ all_videos_or_images_features.append(feat)
460
+ all_faster_video_features.append(faster_video_feature)
461
+ return all_videos_or_images_features,all_faster_video_features
462
+
463
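+ # Encodes the images, merges patch features according to mm_patch_merge_type, and splices
+ # them into the token embedding sequence at every IMAGE_TOKEN_INDEX placeholder, rebuilding
+ # labels, attention_mask and position_ids to match the expanded sequence length.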
+ def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
464
+ vision_tower = self.get_vision_tower()
465
+ # rank_print(modalities)
466
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
467
+ return input_ids, position_ids, attention_mask, past_key_values, None, labels
468
+
469
+ if isinstance(modalities, str):
470
+ modalities = [modalities]
471
+
472
+ # import pdb; pdb.set_trace()
473
+ if type(images) is list or images.ndim == 5:
474
+ if type(images) is list:
475
+ images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
476
+ images_list = []
477
+ for image in images:
478
+ if image.ndim == 4:
479
+ images_list.append(image)
480
+ else:
481
+ images_list.append(image.unsqueeze(0))
482
+ concat_images = torch.cat([image for image in images_list], dim=0)
483
+ split_sizes = [image.shape[0] for image in images_list]
484
+ encoded_image_features = self.encode_images(concat_images)
485
+ # image_features,all_faster_video_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
486
+
487
+ # This is a list, each element is [num_images, patch * patch, dim]
488
+ # rank_print(f"Concat images : {concat_images.shape}")
489
+ encoded_image_features = torch.split(encoded_image_features, split_sizes)
490
+ image_features = []
491
+ for idx, image_feat in enumerate(encoded_image_features):
492
+ image_features.append(image_feat)
493
+ # image_features = self.encode_multimodals(concat_images, video_idx_in_batch, split_sizes)
494
+ # rank_print(f"Encoded image feats : {[x.shape for x in image_features]}")
495
+ # image_features = torch.split(image_features, split_sizes, dim=0)
496
+ mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
497
+ image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
498
+ mm_newline_position = getattr(self.config, "mm_newline_position", "one_token")
499
+
500
+ if mm_patch_merge_type == "flat":
501
+ image_features = [x.flatten(0, 1) for x in image_features]
502
+
503
+ elif mm_patch_merge_type.startswith("spatial"):
504
+ new_image_features = []
505
+ for image_idx, image_feature in enumerate(image_features):
506
+ # FIXME: now assume the image is square, and split to 2x2 patches
507
+ # num_patches = h * w, where h = w = sqrt(num_patches)
508
+ # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
509
+ # we want to first unflatten it to (2, 2, h, w, hidden_size)
510
+ # rank0_print("At least we are reaching here")
511
+ # import pdb; pdb.set_trace()
512
+ if image_feature.shape[0] > 1: # multi-patch and multi-image operations
513
+ # rank0_print("Single-images")
514
+ base_image_feature = image_feature[0]
515
+ image_feature = image_feature[1:]
516
+ height = width = self.get_vision_tower().num_patches_per_side
517
+ assert height * width == base_image_feature.shape[0]
518
+
519
+ if "anyres_max" in image_aspect_ratio:
520
+ matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
521
+ if matched_anyres_max_num_patches:
522
+ max_num_patches = int(matched_anyres_max_num_patches.group(1))
523
+
524
+ if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
525
+ if hasattr(self.get_vision_tower(), "image_size"):
526
+ vision_tower_image_size = self.get_vision_tower().image_size
527
+ else:
528
+ raise ValueError("vision_tower_image_size is not found in the vision tower.")
529
+ try:
530
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
531
+ except Exception as e:
532
+ num_patch_width, num_patch_height = 2, 2
533
+ image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
534
+ else:
535
+ image_feature = image_feature.view(2, 2, height, width, -1)
536
+
537
+ if "maxpool2x2" in mm_patch_merge_type:
538
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
539
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
540
+ image_feature = nn.functional.max_pool2d(image_feature, 2)
541
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
542
+ elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
543
+ unit = image_feature.shape[2]
544
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
545
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
546
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
547
+ c, h, w = image_feature.shape
548
+ times = math.sqrt(h * w / (max_num_patches * unit**2))
549
+ if times > 1.1:
550
+ image_feature = image_feature[None]
551
+ image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
552
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
553
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
554
+ elif "unpad" in mm_patch_merge_type:
555
+ image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
556
+ image_feature = image_feature.flatten(1, 2).flatten(2, 3)
557
+ image_feature = unpad_image(image_feature, image_sizes[image_idx])
558
+ image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
559
+ image_feature = image_feature.flatten(1, 2).transpose(0, 1)
560
+ else:
561
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
562
+ image_feature = image_feature.flatten(0, 3)
563
+ if "nobase" in mm_patch_merge_type:
564
+ pass
565
+ else:
566
+ image_feature = torch.cat((base_image_feature, image_feature), dim=0)
567
+ new_image_features.append(image_feature)
568
+ else: # single image operations
569
+ image_feature = image_feature[0]
570
+ if "unpad" in mm_patch_merge_type:
571
+ image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)
572
+
573
+ new_image_features.append(image_feature)
574
+ image_features = new_image_features
575
+ else:
576
+ raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
577
+ else:
578
+ image_features = self.encode_images(images)
579
+
580
+ # TODO: image start / end is not implemented here to support pretraining.
581
+ if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
582
+ raise NotImplementedError
583
+ # rank_print(f"Total images : {len(image_features)}")
584
+
585
+ # Let's just add dummy tensors if they do not exist,
586
+ # it is a headache to deal with None all the time.
587
+ # But it is not ideal, and if you have a better idea,
588
+ # please open an issue / submit a PR, thanks.
589
+ _labels = labels
590
+ _position_ids = position_ids
591
+ _attention_mask = attention_mask
592
+ if attention_mask is None:
593
+ attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
594
+ else:
595
+ attention_mask = attention_mask.bool()
596
+ if position_ids is None:
597
+ position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
598
+ if labels is None:
599
+ labels = torch.full_like(input_ids, IGNORE_INDEX)
600
+
601
+ # remove the padding using attention_mask -- FIXME
602
+ _input_ids = input_ids
603
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
604
+ labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
605
+ old_attention_mask = attention_mask.clone().detach()
606
+
607
+ new_input_embeds = []
608
+ new_labels = []
609
+ cur_image_idx = 0
610
+ img_token_num = [0 for _ in range(len(input_ids))]
611
+ num_images_batch = []
612
+ # rank_print("Inserting Images embedding")
613
+ for batch_idx, cur_input_ids in enumerate(input_ids):
614
+ num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
615
+ num_images_batch.append(num_images)
616
+ # rank0_print(num_images)
617
+ if num_images == 0:
618
+ cur_image_features = image_features[cur_image_idx]
619
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
620
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
621
+ new_input_embeds.append(cur_input_embeds)
622
+ new_labels.append(labels[batch_idx])
623
+ cur_image_idx += 1
624
+ continue
625
+
626
+ image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
627
+ cur_input_ids_noim = []
628
+ cur_labels = labels[batch_idx]
629
+ cur_labels_noim = []
630
+ for i in range(len(image_token_indices) - 1):
631
+ cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
632
+ cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
633
+ split_sizes = [x.shape[0] for x in cur_labels_noim]
634
+ cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
635
+ cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
636
+ cur_new_input_embeds = []
637
+ cur_new_labels = []
638
+
639
+ for i in range(num_images + 1):
640
+ cur_new_input_embeds.append(cur_input_embeds_no_im[i])
641
+ cur_new_labels.append(cur_labels_noim[i])
642
+ if i < num_images:
643
+ try:
644
+ cur_image_features = image_features[cur_image_idx]
645
+ except IndexError:
646
+ cur_image_features = image_features[cur_image_idx - 1]
647
+ img_token_num[batch_idx] += image_features[cur_image_idx].shape[0]
648
+ cur_image_idx += 1
649
+ cur_new_input_embeds.append(cur_image_features)
650
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
651
+
652
+ cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
653
+
654
+ # import pdb; pdb.set_trace()
655
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds)
656
+ cur_new_labels = torch.cat(cur_new_labels)
657
+
658
+ new_input_embeds.append(cur_new_input_embeds)
659
+ new_labels.append(cur_new_labels)
660
+
661
+ # Truncate sequences to max length as image embeddings can make the sequence longer
662
+ tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
663
+ # rank_print("Finishing Inserting")
664
+
665
+ new_input_embeds = [x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
666
+ new_labels = [x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
667
+ # TODO: Hard code for control loss spike
668
+ # if tokenizer_model_max_length is not None:
669
+ # new_input_embeds = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_input_embeds, modalities)]
670
+ # new_labels = [x[:4096] if modality != "video" else x[:tokenizer_model_max_length] for x, modality in zip(new_labels, modalities)]
671
+
672
+ # Combine them
673
+ max_len = max(x.shape[0] for x in new_input_embeds)
674
+ batch_size = len(new_input_embeds)
675
+
676
+ new_input_embeds_padded = []
677
+ new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
678
+ attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
679
+ position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
680
+ # rank0_print("Prepare pos id")
681
+
682
+ for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
683
+ cur_len = cur_new_embed.shape[0]
684
+ if getattr(self.config, "tokenizer_padding_side", "right") == "left":
685
+ new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
686
+ if cur_len > 0:
687
+ new_labels_padded[i, -cur_len:] = cur_new_labels
688
+ attention_mask[i, -cur_len:] = True
689
+ position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
690
+ else:
691
+ new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
692
+ if cur_len > 0:
693
+ new_labels_padded[i, :cur_len] = cur_new_labels
694
+ attention_mask[i, :cur_len] = True
695
+ position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
696
+
697
+ new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
698
+ # rank0_print("tokenizer padding")
699
+
700
+ if _labels is None:
701
+ new_labels = None
702
+ else:
703
+ new_labels = new_labels_padded
704
+
705
+ if _attention_mask is None:
706
+ attention_mask = None
707
+ else:
708
+ attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
709
+
710
+ if _position_ids is None:
711
+ position_ids = None
712
+ if getattr(self.config, "use_pos_skipping", False) and self.training:
713
+ position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
714
+ split_position = random.randint(0, new_input_embeds.size(1))
715
+ left_add = random.randint(0, self.config.pos_skipping_range)
716
+ right_add = random.randint(left_add, self.config.pos_skipping_range)
717
+ position_ids[:, :split_position] += left_add
718
+ position_ids[:, split_position:] += right_add
719
+ # import pdb; pdb.set_trace()
720
+ # rank0_print("Finish preparing")
721
+ return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, old_attention_mask, img_token_num, num_images_batch
722
+
723
+
724
+ class MLCDSegConfig(Qwen2Config):
725
+ model_type = "mlcd_seg"
726
+
727
+
728
+ class MLCDSegModel(MLCDSegMetaModel, Qwen2Model):
729
+ config_class = MLCDSegConfig
730
+
731
+ def __init__(self, config: Qwen2Config):
732
+ super(MLCDSegModel, self).__init__(config)
733
+
734
+ @dataclass
735
+ class MLCDSegOutputWithPast(CausalLMOutputWithPast):
736
+ labels: Optional[torch.FloatTensor] = None
737
+
738
+ class MLCDSegForCausalLM(Qwen2ForCausalLM, MLCDSegMetaForCausalLM):
739
+ config_class = MLCDSegConfig
740
+
741
+ def __init__(self, config):
742
+ # super(Qwen2ForCausalLM, self).__init__(config)
743
+ Qwen2ForCausalLM.__init__(self, config)
744
+ config.model_type = "mlcd_seg_clm"
745
+ config.rope_scaling = None
746
+
747
+ self.model = MLCDSegModel(config)
748
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
749
+ # Initialize weights and apply final processing
750
+ self.post_init()
751
+ self.sam_transform = ResizeLongestSide(IMG_SIZE)
752
+
753
+ def get_model(self):
754
+ return self.model
755
+
756
+ def forward(
757
+ self,
758
+ input_ids: torch.LongTensor = None,
759
+ attention_mask: Optional[torch.Tensor] = None,
760
+ position_ids: Optional[torch.LongTensor] = None,
761
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
762
+ inputs_embeds: Optional[torch.FloatTensor] = None,
763
+ labels: Optional[torch.LongTensor] = None,
764
+ use_cache: Optional[bool] = None,
765
+ output_attentions: Optional[bool] = None,
766
+ output_hidden_states: Optional[bool] = None,
767
+ images: Optional[torch.FloatTensor] = None,
768
+ image_sizes: Optional[List[List[int]]] = None,
769
+ return_dict: Optional[bool] = None,
770
+ modalities: Optional[List[str]] = ["image"],
771
+ dpo_forward: Optional[bool] = False,
772
+ cache_position=None,
773
+ grounding_enc_imgs: Optional[List[torch.FloatTensor]] = None,
774
+ image_sam_resizes: Optional[List[torch.FloatTensor]] = None,
775
+ original_sizes: Optional[List[torch.FloatTensor]] = None,
776
+ masks_list: Optional[List[List[torch.FloatTensor]]] = None,
777
+ infer: bool = False,
778
+ force_seg: bool = True
779
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
780
+ input_ids_ = input_ids
781
+ if inputs_embeds is None:
782
+ (
783
+ input_ids,
784
+ position_ids,
785
+ attention_mask,
786
+ past_key_values,
787
+ inputs_embeds,
788
+ labels,
789
+ old_attention_mask,
790
+ img_token_num,
791
+ num_images_batch
792
+ ) = self.prepare_inputs_labels_for_multimodal(
793
+ input_ids,
794
+ position_ids,
795
+ attention_mask,
796
+ past_key_values,
797
+ labels,
798
+ images,
799
+ modalities,
800
+ image_sizes
801
+ )
802
+
803
+ if dpo_forward:
804
+ outputs = self.model(
805
+ input_ids=input_ids,
806
+ attention_mask=attention_mask,
807
+ position_ids=position_ids,
808
+ past_key_values=past_key_values,
809
+ inputs_embeds=inputs_embeds,
810
+ use_cache=use_cache,
811
+ output_attentions=output_attentions,
812
+ output_hidden_states=output_hidden_states,
813
+ return_dict=return_dict,
814
+ )
815
+
816
+ hidden_states = outputs[0]
817
+ logits = self.lm_head(hidden_states)
818
+ return logits, labels
819
+
820
+ else:
821
+ output = super().forward(
822
+ input_ids=input_ids,
823
+ attention_mask=attention_mask,
824
+ position_ids=position_ids,
825
+ past_key_values=past_key_values,
826
+ inputs_embeds=inputs_embeds,
827
+ labels=labels,
828
+ use_cache=use_cache,
829
+ output_attentions=output_attentions,
830
+ output_hidden_states=True,
831
+ return_dict=return_dict,
832
+ cache_position=cache_position
833
+ )
834
+ sam_image_embeddings = self.get_grounding_encoder_embs(grounding_enc_imgs)
835
+ if force_seg:
836
+ seg_token_mask = self.create_seg_token_mask(input_ids_, old_attention_mask, img_token_num, num_images_batch)
837
+ else:
838
+ # TODO: this branch should raise NotImplementedError instead of duplicating the force_seg path
839
+ seg_token_mask = self.create_seg_token_mask(input_ids_, old_attention_mask, img_token_num, num_images_batch)
840
+ seg_text_embeds_batch = self.process_hidden_states(output["hidden_states"], seg_token_mask)
841
+ pred_masks_batch = self.generate_and_postprocess_masks(seg_text_embeds_batch, sam_image_embeddings, num_images_batch, image_sam_resizes, original_sizes)
842
+ if infer:
843
+ return {"output":output, "pred_masks":pred_masks_batch}
844
+ return MLCDSegOutputWithPast(**output)
845
+
846
+ @torch.no_grad()
847
+ def generate(
848
+ self,
849
+ inputs: Optional[torch.Tensor] = None,
850
+ images: Optional[torch.Tensor] = None,
851
+ image_sizes: Optional[torch.Tensor] = None,
852
+ modalities: Optional[List[str]] = ["image"],
853
+ **kwargs,
854
+ ) -> Union[GenerateOutput, torch.LongTensor]:
855
+ position_ids = kwargs.pop("position_ids", None)
856
+ attention_mask = kwargs.pop("attention_mask", None)
857
+ if "inputs_embeds" in kwargs:
858
+ raise NotImplementedError("`inputs_embeds` is not supported")
859
+
860
+ if images is not None:
861
+ (inputs, position_ids, attention_mask, _, inputs_embeds, _) = self.prepare_inputs_labels_for_multimodal(inputs, position_ids, attention_mask, None, None, images, modalities, image_sizes=image_sizes)
862
+ else:
863
+ inputs_embeds = self.get_model().embed_tokens(inputs)
864
+
865
+ return super().generate(position_ids=position_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs)
866
+
867
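+ # For each [SEG] text embedding, runs SAM's prompt encoder and mask decoder on the matching
+ # image embedding, then upscales the low-resolution masks back to the original image size.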
+ def generate_and_postprocess_masks(self, seg_text_embeds_batch, sam_image_embeddings, num_images_batch, image_sam_resizes, original_sizes):
868
+ assert len(seg_text_embeds_batch) == len(num_images_batch)
869
+
870
+ pred_masks_batch = [] # list()
871
+ for batch_i, seg_text_embeds in enumerate(seg_text_embeds_batch):
872
+ num_img = max(1, num_images_batch[batch_i])
873
+
874
+ pred_mask_ = torch.empty((0, original_sizes[batch_i][0], original_sizes[batch_i][1]), device=seg_text_embeds.device)
875
+ for img_i in range(num_img):
876
+ sparse_embeddings, dense_embeddings = self.model.sam.prompt_encoder(
877
+ points=None, boxes=None, masks=None, text_embeds=seg_text_embeds.unsqueeze(1)[img_i::num_img,:,:]
878
+ )
879
+ sparse_embeddings = sparse_embeddings.to(seg_text_embeds.dtype)
880
+
881
+ low_res_masks, _ = self.model.sam.mask_decoder(
882
+ image_embeddings=sam_image_embeddings[batch_i][img_i].unsqueeze(0),
883
+ image_pe=self.model.sam.prompt_encoder.get_dense_pe(),
884
+ sparse_prompt_embeddings=sparse_embeddings,
885
+ dense_prompt_embeddings=dense_embeddings,
886
+ multimask_output=False, )
887
+ pred_mask = self.model.sam.postprocess_masks(
888
+ low_res_masks, input_size=image_sam_resizes[batch_i][img_i], original_size=original_sizes[batch_i],)
889
+ pred_mask_ = torch.cat([pred_mask_, pred_mask[:,0]], dim=0)
890
+ pred_masks_batch.append(pred_mask_)
891
+ return pred_masks_batch
892
+
893
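+ # Projects the last-layer hidden states into SAM's prompt space via text2sam_projection and
+ # keeps only the positions selected by seg_token_mask.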
+ def process_hidden_states(self, output_hidden_states, seg_token_mask):
894
+ hidden_states_ = [self.model.text2sam_projection(output_hidden_states[-1])]
895
+ hidden_states_ = torch.stack(hidden_states_, dim=-1).sum(dim=-1)
896
+ seg_text_embeds_batch = []
897
+ for i, hidden_state_ in enumerate(hidden_states_):
898
+ # assert hidden_state_.shape[0] == seg_token_mask.shape[1], f"hidden:{hidden_state_.shape}, segtoken:{seg_token_mask.shape}"
899
+ # seg_text_embeds_batch.append(hidden_state_[seg_token_mask[i]])
900
+ seg_text_embeds_batch.append(hidden_state_[seg_token_mask[i][:hidden_state_.shape[0]]])
901
+ return seg_text_embeds_batch
902
+
903
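+ # Builds a boolean mask marking [SEG] token positions in the expanded sequence, offset by the
+ # number of inserted image tokens and padded to the longest sequence in the batch.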
+ def create_seg_token_mask(self, input_ids, attention_mask, img_token_num, num_images_batch):
904
+ input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
905
+ max_len = 0
906
+ for i, _ in enumerate(input_ids):
907
+ max_len = max(max_len, len(input_ids[i]) + img_token_num[i] - num_images_batch[i])
908
+
909
+ seg_token_mask = []
910
+ for i, _ in enumerate(input_ids):
911
+ mask = input_ids[i][num_images_batch[i]:] == self.seg_token_idx
912
+ seg_token_mask.append(
913
+ torch.cat(
914
+ [torch.zeros((1, img_token_num[i])).bool().cuda(), mask.unsqueeze(0), torch.zeros((1, max_len-(len(input_ids[i]) + img_token_num[i] - num_images_batch[i]))).bool().cuda()], dim=1
915
+ )
916
+ )
917
+ return torch.cat(seg_token_mask, dim=0)
918
+
919
+ def get_grounding_encoder_embs(self, batch_images: torch.FloatTensor):
920
+ # with torch.no_grad():
921
+ batch_feats = []
922
+ for images in batch_images:
923
+ batch_feats.append(torch.cat([self._encode_single_image(img) for img in images], dim=0))
924
+ return batch_feats
925
+
926
+ def _encode_single_image(self, image):
927
+ # torch.cuda.empty_cache()
928
+ return self.model.sam.image_encoder(image.unsqueeze(0))
929
+
930
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
931
+ images = kwargs.pop("images", None)
932
+ image_sizes = kwargs.pop("image_sizes", None)
933
+ inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
934
+ if images is not None:
935
+ inputs["images"] = images
936
+ if image_sizes is not None:
937
+ inputs["image_sizes"] = image_sizes
938
+ return inputs
939
+
940
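+ # Wraps the user text in the default conversation template; with force_seg, a fixed
+ # "It is [SEG]." answer is appended so the sequence always contains a [SEG] token.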
+ def process_prompt(self, text, tokenizer: PreTrainedTokenizer, force_seg=True) -> Dict:
941
+ conv = default_conversation.copy()
942
+ BEGIN_SIGNAL = "### "
943
+ END_SIGNAL = "\n"
944
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
945
+ # Apply prompt templates
946
+ sys_prompt = default_conversation.system + "\n\n"
947
+ full_prompt = sys_prompt + BEGIN_SIGNAL + roles["human"] + ": " + text + END_SIGNAL
948
+ if force_seg:
949
+ full_prompt += BEGIN_SIGNAL + roles["gpt"] + ": It is [SEG]." + END_SIGNAL
950
+ full_prompt += BEGIN_SIGNAL
951
+ input_ids = torch.stack([gen_image_token(full_prompt, tokenizer, return_tensors='pt')], dim=0)
952
+ return dict(
953
+ input_ids=input_ids,
954
+ labels=None,
955
+ )
956
+
957
+ def process_images(self, images, image_processor, model_cfg):
958
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
959
+ new_images = []
960
+ if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
961
+ for image in images:
962
+ image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
963
+ new_images.append(image)
964
+ else:
965
+ return image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
966
+ if all(x.shape == new_images[0].shape for x in new_images):
967
+ new_images = torch.stack(new_images, dim=0)
968
+ return new_images
969
+
970
+ def predict_forward(self, image, prompt, tokenizer, force_seg=True):
971
+ self.seg_token_idx = tokenizer(DEFAULT_SEG_TOKEN, add_special_tokens=False).input_ids[0]
972
+ image_np = np.array(image)
973
+ image_sizes = [image.size]
974
+ input_ids = self.process_prompt(prompt, tokenizer, force_seg)["input_ids"].to(self.device) # the corresponding device needs to be set here
975
+ image_processor = self.get_vision_tower().image_processor
976
+ image_tensors = self.process_images([image], image_processor, self.config)
977
+ image_np_resize = self.sam_transform.apply_image(image_np)
978
+ original_size_list = [image_np.shape[:2]]
979
+ image_sam_resize_list = [image_np_resize.shape[:2]]
980
+ grounding_enc_img_list = [grounding_enc_processor(torch.from_numpy(image_np_resize).permute(2, 0, 1).contiguous()).to(dtype=self.dtype, device=self.device, non_blocking=True)]
981
+ collect_size = list(set(original_size_list))
982
+ if len(collect_size) == 0:
983
+ mask_h, mask_w = 336, 336
984
+ elif len(collect_size) == 1:
985
+ mask_h, mask_w = collect_size[0]
986
+ else:
987
+ areas = [h*w for (h, w) in collect_size]
988
+ mask_h, mask_w = collect_size[areas.index(max(areas))]
989
+ if isinstance(image_tensors, list):
990
+ image_aspect_ratio = getattr(self.config, "image_aspect_ratio", None)
991
+ if image_aspect_ratio=="anyres_mul" or image_aspect_ratio=="anyres":
992
+ image_tensors = [[x_.to(dtype=self.dtype, device=self.device, non_blocking=True) for x_ in image_tensors]]
993
+ else:
994
+ image_tensors = [[x_.unsqueeze(dim=0).to(dtype=self.dtype, device=self.device, non_blocking=True) for x_ in image_tensors]]
995
+ else:
996
+ image_tensors = image_tensors.to(dtype=self.dtype, device='cuda', non_blocking=True)
997
+
998
+ with torch.inference_mode():
999
+ net_out = self.forward(
1000
+ input_ids=input_ids,
1001
+ output_hidden_states=True,
1002
+ images=image_tensors,
1003
+ image_sizes=image_sizes,
1004
+ grounding_enc_imgs=[torch.stack(grounding_enc_img_list, dim=0)],
1005
+ image_sam_resizes=[image_sam_resize_list],
1006
+ original_sizes=[(mask_h, mask_w)],
1007
+ infer=True,
1008
+ force_seg=force_seg
1009
+ )
1010
+ pred_mask = net_out["pred_masks"][0]
1011
+ return pred_mask
1012
+
1013
+
1014
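+ # Tokenizes the prompt in chunks split on "<image>" and inserts IMAGE_TOKEN_INDEX between the
+ # chunks, keeping a leading BOS token when present.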
+ def gen_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
1015
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
1016
+
1017
+ def insert_separator(X, sep):
1018
+ return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
1019
+
1020
+ input_ids = []
1021
+ offset = 0
1022
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
1023
+ offset = 1
1024
+ input_ids.append(prompt_chunks[0][0])
1025
+
1026
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
1027
+ input_ids.extend(x[offset:])
1028
+
1029
+ if return_tensors is not None:
1030
+ if return_tensors == "pt":
1031
+ return torch.tensor(input_ids, dtype=torch.long)
1032
+ raise ValueError(f"Unsupported tensor type: {return_tensors}")
1033
+ return input_ids
1034
+
1035
+
1036
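+ # Normalizes the image with the SAM pixel statistics (IMG_MEAN / IMG_STD) and zero-pads it on
+ # the right and bottom to an IMG_SIZE x IMG_SIZE square.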
+ def grounding_enc_processor(x: torch.Tensor) -> torch.Tensor:
1037
+ x = (x - IMG_MEAN) / IMG_STD
1038
+ h, w = x.shape[-2:]
1039
+ x = F.pad(x, (0, IMG_SIZE - w, 0, IMG_SIZE - h))
1040
+ return x
1041
+
1042
+
1043
+ AutoConfig.register("mlcd_seg", MLCDSegConfig)
1044
+ AutoModelForCausalLM.register(MLCDSegConfig, MLCDSegForCausalLM)
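
For reference, here is a minimal usage sketch of the classes registered above. It assumes the checkpoint's config.json exposes them via auto_map so that trust_remote_code loading picks up this module; the repository id, the "<image>" prompt format, and the zero threshold on the returned logits are illustrative assumptions, while predict_forward and its force_seg flag are taken from the code in this commit.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "DeepGlint-AI/MLCD-Seg-7B"  # assumed repository id; replace with the actual repo
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,  # loads the custom MLCDSegForCausalLM defined above
).cuda().eval()

image = Image.open("example.jpg").convert("RGB")
prompt = "<image>\nPlease segment the main object in the image."  # assumed prompt format

# predict_forward (defined above) returns a (num_seg, H, W) tensor of mask logits
pred_mask = model.predict_forward(image, prompt, tokenizer, force_seg=True)
binary_mask = (pred_mask[0] > 0).cpu().numpy()

The masks are upscaled to the original image resolution by SAM's postprocess_masks, so thresholding the logits at zero yields a binary mask aligned with the input image.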
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26445d9fa45cd664abd8fff7c7bbcad04f55ed7e4fffa4a1626ce28cd9db8a67
3
+ size 4874815168
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7243468b3338748241548e0d162d45dd5cdca63b73a4db86060f36053e23f065
3
+ size 4932751008
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:84b408e9cc3986e0f122644d3ba397e5a92ac492fdebe6dbbe7a5bd9a7fc5955
3
+ size 4996578680
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a74f572810afc8de9c3c31c8f278fd276c5ece132356b9f29fc945c09e88ff6a
3
+ size 2345774016
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
sam.py ADDED
@@ -0,0 +1,1324 @@
1
+ '''
2
+ Merge files from META Sam project @DeepGlint 2025
3
+ https://github.com/facebookresearch/segment-anything
4
+ '''
5
+
6
+
7
+ import math
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from torch import Tensor
13
+ from typing import List, Dict, Any, Tuple, Type, Optional
14
+ from functools import partial
15
+
16
+
17
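+ # Two-layer MLP (Linear -> GELU -> Linear) mapping LLM hidden states of size config.hidden_size
+ # into SAM's 256-dimensional prompt embedding space.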
+ def text2sam_projection_layer(config):
18
+ in_dim, out_dim = config.hidden_size, 256
19
+ modules = [nn.Linear(in_dim, out_dim)]
20
+ for _ in range(1, 2):
21
+ modules.append(nn.GELU())
22
+ modules.append(nn.Linear(out_dim, out_dim))
23
+ return nn.Sequential(*modules)
24
+
25
+
26
+ def build_sam_vit_h():
27
+ return _build_sam(
28
+ encoder_embed_dim=1280,
29
+ encoder_depth=32,
30
+ encoder_num_heads=16,
31
+ encoder_global_attn_indexes=[7, 15, 23, 31],
32
+ )
33
+
34
+
35
+ def _build_sam(
36
+ encoder_embed_dim,
37
+ encoder_depth,
38
+ encoder_num_heads,
39
+ encoder_global_attn_indexes,
40
+ ):
41
+ prompt_embed_dim = 256
42
+ image_size = 1024
43
+ vit_patch_size = 16
44
+ image_embedding_size = image_size // vit_patch_size
45
+ sam = Sam(
46
+ image_encoder=ImageEncoderViT(
47
+ depth=encoder_depth,
48
+ embed_dim=encoder_embed_dim,
49
+ img_size=image_size,
50
+ mlp_ratio=4,
51
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
52
+ num_heads=encoder_num_heads,
53
+ patch_size=vit_patch_size,
54
+ qkv_bias=True,
55
+ use_rel_pos=True,
56
+ global_attn_indexes=encoder_global_attn_indexes,
57
+ window_size=14,
58
+ out_chans=prompt_embed_dim,
59
+ ),
60
+ prompt_encoder=PromptEncoder(
61
+ embed_dim=prompt_embed_dim,
62
+ image_embedding_size=(image_embedding_size, image_embedding_size),
63
+ input_image_size=(image_size, image_size),
64
+ mask_in_chans=16,
65
+ ),
66
+ mask_decoder=MaskDecoder(
67
+ num_multimask_outputs=3,
68
+ transformer=TwoWayTransformer(
69
+ depth=2,
70
+ embedding_dim=prompt_embed_dim,
71
+ mlp_dim=2048,
72
+ num_heads=8,
73
+ ),
74
+ transformer_dim=prompt_embed_dim,
75
+ iou_head_depth=3,
76
+ iou_head_hidden_dim=256,
77
+ ),
78
+ pixel_mean=[123.675, 116.28, 103.53],
79
+ pixel_std=[58.395, 57.12, 57.375],
80
+ )
81
+ sam.eval()
82
+ return sam
83
+
84
+
85
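+ # Example: with window_size=14, a (1, 64, 64, C) input is padded to (1, 70, 70, C) and
+ # partitioned into (25, 14, 14, C) windows.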
+ def window_partition(
86
+ x: torch.Tensor, window_size: int
87
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
88
+ """
89
+ Partition into non-overlapping windows with padding if needed.
90
+ Args:
91
+ x (tensor): input tokens with [B, H, W, C].
92
+ window_size (int): window size.
93
+
94
+ Returns:
95
+ windows: windows after partition with [B * num_windows, window_size, window_size, C].
96
+ (Hp, Wp): padded height and width before partition
97
+ """
98
+ B, H, W, C = x.shape
99
+
100
+ pad_h = (window_size - H % window_size) % window_size
101
+ pad_w = (window_size - W % window_size) % window_size
102
+ if pad_h > 0 or pad_w > 0:
103
+ x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
104
+ Hp, Wp = H + pad_h, W + pad_w
105
+
106
+ x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
107
+ windows = (
108
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
109
+ )
110
+ return windows, (Hp, Wp)
111
+
112
+
113
+ def window_unpartition(
114
+ windows: torch.Tensor,
115
+ window_size: int,
116
+ pad_hw: Tuple[int, int],
117
+ hw: Tuple[int, int],
118
+ ) -> torch.Tensor:
119
+ """
120
+ Window unpartition into original sequences and remove padding.
121
+ Args:
122
+ windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
123
+ window_size (int): window size.
124
+ pad_hw (Tuple): padded height and width (Hp, Wp).
125
+ hw (Tuple): original height and width (H, W) before padding.
126
+
127
+ Returns:
128
+ x: unpartitioned sequences with [B, H, W, C].
129
+ """
130
+ Hp, Wp = pad_hw
131
+ H, W = hw
132
+ B = windows.shape[0] // (Hp * Wp // window_size // window_size)
133
+ x = windows.view(
134
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
135
+ )
136
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
137
+
138
+ if Hp > H or Wp > W:
139
+ x = x[:, :H, :W, :].contiguous()
140
+ return x
141
+
142
+
143
+ class CommonMLP(nn.Module):
144
+ def __init__(
145
+ self,
146
+ embedding_dim: int,
147
+ mlp_dim: int,
148
+ act: Type[nn.Module] = nn.GELU,
149
+ ) -> None:
150
+ super().__init__()
151
+ self.lin1 = nn.Linear(embedding_dim, mlp_dim)
152
+ self.lin2 = nn.Linear(mlp_dim, embedding_dim)
153
+ self.act = act()
154
+
155
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
156
+ return self.lin2(self.act(self.lin1(x)))
157
+
158
+
159
+ class LayerNorm2d(nn.Module):
160
+ def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
161
+ super().__init__()
162
+ self.weight = nn.Parameter(torch.ones(num_channels))
163
+ self.bias = nn.Parameter(torch.zeros(num_channels))
164
+ self.eps = eps
165
+
166
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
167
+ u = x.mean(1, keepdim=True)
168
+ s = (x - u).pow(2).mean(1, keepdim=True)
169
+ x = (x - u) / torch.sqrt(s + self.eps)
170
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
171
+ return x
172
+
173
+
174
+ class Attention(nn.Module):
175
+ """
176
+ An attention layer that allows for downscaling the size of the embedding
177
+ after projection to queries, keys, and values.
178
+ """
179
+
180
+ def __init__(
181
+ self,
182
+ embedding_dim: int,
183
+ num_heads: int,
184
+ downsample_rate: int = 1,
185
+ ) -> None:
186
+ super().__init__()
187
+ self.embedding_dim = embedding_dim
188
+ self.internal_dim = embedding_dim // downsample_rate
189
+ self.num_heads = num_heads
190
+ assert (
191
+ self.internal_dim % num_heads == 0
192
+ ), "num_heads must divide embedding_dim."
193
+
194
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
195
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
196
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
197
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
198
+
199
+ def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
200
+ b, n, c = x.shape
201
+ x = x.reshape(b, n, num_heads, c // num_heads)
202
+ return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
203
+
204
+ def _recombine_heads(self, x: Tensor) -> Tensor:
205
+ b, n_heads, n_tokens, c_per_head = x.shape
206
+ x = x.transpose(1, 2)
207
+ return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
208
+
209
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
210
+ # Input projections
211
+ q = self.q_proj(q)
212
+ k = self.k_proj(k)
213
+ v = self.v_proj(v)
214
+
215
+ # Separate into heads
216
+ q = self._separate_heads(q, self.num_heads)
217
+ k = self._separate_heads(k, self.num_heads)
218
+ v = self._separate_heads(v, self.num_heads)
219
+
220
+ # Attention
221
+ _, _, _, c_per_head = q.shape
222
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
223
+ attn = attn / math.sqrt(c_per_head)
224
+ attn = torch.softmax(attn, dim=-1)
225
+
226
+ # Get output
227
+ out = attn @ v
228
+ out = self._recombine_heads(out)
229
+ out = self.out_proj(out)
230
+
231
+ return out
232
+
233
+
234
+ class TwoWayTransformer(nn.Module):
235
+ def __init__(
236
+ self,
237
+ depth: int,
238
+ embedding_dim: int,
239
+ num_heads: int,
240
+ mlp_dim: int,
241
+ activation: Type[nn.Module] = nn.ReLU,
242
+ attention_downsample_rate: int = 2,
243
+ ) -> None:
244
+ """
245
+ A transformer decoder that attends to an input image using
246
+ queries whose positional embedding is supplied.
247
+
248
+ Args:
249
+ depth (int): number of layers in the transformer
250
+ embedding_dim (int): the channel dimension for the input embeddings
251
+ num_heads (int): the number of heads for multihead attention. Must
252
+ divide embedding_dim
253
+ mlp_dim (int): the channel dimension internal to the MLP block
254
+ activation (nn.Module): the activation to use in the MLP block
255
+ """
256
+ super().__init__()
257
+ self.depth = depth
258
+ self.embedding_dim = embedding_dim
259
+ self.num_heads = num_heads
260
+ self.mlp_dim = mlp_dim
261
+ self.layers = nn.ModuleList()
262
+
263
+ for i in range(depth):
264
+ self.layers.append(
265
+ TwoWayAttentionBlock(
266
+ embedding_dim=embedding_dim,
267
+ num_heads=num_heads,
268
+ mlp_dim=mlp_dim,
269
+ activation=activation,
270
+ attention_downsample_rate=attention_downsample_rate,
271
+ skip_first_layer_pe=(i == 0),
272
+ )
273
+ )
274
+
275
+ self.final_attn_token_to_image = Attention(
276
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
277
+ )
278
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
279
+
280
+ def forward(
281
+ self,
282
+ image_embedding: Tensor,
283
+ image_pe: Tensor,
284
+ point_embedding: Tensor,
285
+ ) -> Tuple[Tensor, Tensor]:
286
+ """
287
+ Args:
288
+ image_embedding (torch.Tensor): image to attend to. Should be shape
289
+ B x embedding_dim x h x w for any h and w.
290
+ image_pe (torch.Tensor): the positional encoding to add to the image. Must
291
+ have the same shape as image_embedding.
292
+ point_embedding (torch.Tensor): the embedding to add to the query points.
293
+ Must have shape B x N_points x embedding_dim for any N_points.
294
+
295
+ Returns:
296
+ torch.Tensor: the processed point_embedding
297
+ torch.Tensor: the processed image_embedding
298
+ """
299
+ # BxCxHxW -> BxHWxC == B x N_image_tokens x C
300
+ bs, c, h, w = image_embedding.shape
301
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
302
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
303
+
304
+ # Prepare queries
305
+ queries = point_embedding
306
+ keys = image_embedding
307
+
308
+ # Apply transformer blocks and final layernorm
309
+ for layer in self.layers:
310
+ queries, keys = layer(
311
+ queries=queries,
312
+ keys=keys,
313
+ query_pe=point_embedding,
314
+ key_pe=image_pe,
315
+ )
316
+
317
+ # Apply the final attention layer from the points to the image
318
+ q = queries + point_embedding
319
+ k = keys + image_pe
320
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
321
+ queries = queries + attn_out
322
+ queries = self.norm_final_attn(queries)
323
+
324
+ return queries, keys
325
+
326
+
327
+ class TwoWayAttentionBlock(nn.Module):
328
+ def __init__(
329
+ self,
330
+ embedding_dim: int,
331
+ num_heads: int,
332
+ mlp_dim: int = 2048,
333
+ activation: Type[nn.Module] = nn.ReLU,
334
+ attention_downsample_rate: int = 2,
335
+ skip_first_layer_pe: bool = False,
336
+ ) -> None:
337
+ """
338
+ A transformer block with four layers: (1) self-attention of sparse
339
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
340
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
341
+ inputs.
342
+
343
+ Arguments:
344
+ embedding_dim (int): the channel dimension of the embeddings
345
+ num_heads (int): the number of heads in the attention layers
346
+ mlp_dim (int): the hidden dimension of the mlp block
347
+ activation (nn.Module): the activation of the mlp block
348
+ skip_first_layer_pe (bool): skip the PE on the first layer
349
+ """
350
+ super().__init__()
351
+ self.self_attn = Attention(embedding_dim, num_heads)
352
+ self.norm1 = nn.LayerNorm(embedding_dim)
353
+
354
+ self.cross_attn_token_to_image = Attention(
355
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
356
+ )
357
+ self.norm2 = nn.LayerNorm(embedding_dim)
358
+
359
+ self.mlp = CommonMLP(embedding_dim, mlp_dim, activation)
360
+ self.norm3 = nn.LayerNorm(embedding_dim)
361
+
362
+ self.norm4 = nn.LayerNorm(embedding_dim)
363
+ self.cross_attn_image_to_token = Attention(
364
+ embedding_dim, num_heads, downsample_rate=attention_downsample_rate
365
+ )
366
+
367
+ self.skip_first_layer_pe = skip_first_layer_pe
368
+
369
+ def forward(
370
+ self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
371
+ ) -> Tuple[Tensor, Tensor]:
372
+ # Self attention block
373
+ if self.skip_first_layer_pe:
374
+ queries = self.self_attn(q=queries, k=queries, v=queries)
375
+ else:
376
+ q = queries + query_pe
377
+ attn_out = self.self_attn(q=q, k=q, v=queries)
378
+ queries = queries + attn_out
379
+ queries = self.norm1(queries)
380
+
381
+ # Cross attention block, tokens attending to image embedding
382
+ q = queries + query_pe
383
+ k = keys + key_pe
384
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
385
+ queries = queries + attn_out
386
+ queries = self.norm2(queries)
387
+
388
+ # MLP block
389
+ mlp_out = self.mlp(queries)
390
+ queries = queries + mlp_out
391
+ queries = self.norm3(queries)
392
+
393
+ # Cross attention block, image embedding attending to tokens
394
+ q = queries + query_pe
395
+ k = keys + key_pe
396
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
397
+ keys = keys + attn_out
398
+ keys = self.norm4(keys)
399
+
400
+ return queries, keys
401
+
402
+
403
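+ # Selects the (q_size, k_size) table of relative positional embeddings, interpolating rel_pos
+ # along its first axis when its length differs from 2 * max(q_size, k_size) - 1.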
+ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
404
+ """
405
+ Get relative positional embeddings according to the relative positions of
406
+ query and key sizes.
407
+ Args:
408
+ q_size (int): size of query q.
409
+ k_size (int): size of key k.
410
+ rel_pos (Tensor): relative position embeddings (L, C).
411
+
412
+ Returns:
413
+ Extracted positional embeddings according to relative positions.
414
+ """
415
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
416
+ # Interpolate rel pos if needed.
417
+ if rel_pos.shape[0] != max_rel_dist:
418
+ # Interpolate rel pos.
419
+ rel_pos_resized = F.interpolate(
420
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
421
+ size=max_rel_dist,
422
+ mode="linear",
423
+ )
424
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
425
+ else:
426
+ rel_pos_resized = rel_pos
427
+
428
+ # Scale the coords with short length if shapes for q and k are different.
429
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
430
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
431
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
432
+
433
+ return rel_pos_resized[relative_coords.long()]
434
+
435
+
436
+ def add_decomposed_rel_pos(
437
+ attn: torch.Tensor,
438
+ q: torch.Tensor,
439
+ rel_pos_h: torch.Tensor,
440
+ rel_pos_w: torch.Tensor,
441
+ q_size: Tuple[int, int],
442
+ k_size: Tuple[int, int],
443
+ ) -> torch.Tensor:
444
+ """
445
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
446
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
447
+ Args:
448
+ attn (Tensor): attention map.
449
+ q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
450
+ rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
451
+ rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
452
+ q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
453
+ k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
454
+
455
+ Returns:
456
+ attn (Tensor): attention map with added relative positional embeddings.
457
+ """
458
+ q_h, q_w = q_size
459
+ k_h, k_w = k_size
460
+ Rh = get_rel_pos(q_h, k_h, rel_pos_h)
461
+ Rw = get_rel_pos(q_w, k_w, rel_pos_w)
462
+
463
+ B, _, dim = q.shape
464
+ r_q = q.reshape(B, q_h, q_w, dim)
465
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
466
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
467
+
468
+ attn = (
469
+ attn.view(B, q_h, q_w, k_h, k_w)
470
+ + rel_h[:, :, :, :, None]
471
+ + rel_w[:, :, :, None, :]
472
+ ).view(B, q_h * q_w, k_h * k_w)
473
+
474
+ return attn
475
+
476
+
477
+ class ImageEncoderViTAttention(nn.Module):
478
+ """Multi-head Attention block with relative position embeddings."""
479
+
480
+ def __init__(
481
+ self,
482
+ dim: int,
483
+ num_heads: int = 8,
484
+ qkv_bias: bool = True,
485
+ use_rel_pos: bool = False,
486
+ rel_pos_zero_init: bool = True,
487
+ input_size: Optional[Tuple[int, int]] = None,
488
+ ) -> None:
489
+ """
490
+ Args:
491
+ dim (int): Number of input channels.
492
+ num_heads (int): Number of attention heads.
493
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
494
+ rel_pos (bool): If True, add relative positional embeddings to the attention map.
495
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
496
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
497
+ positional parameter size.
498
+ """
499
+ super().__init__()
500
+ self.num_heads = num_heads
501
+ head_dim = dim // num_heads
502
+ self.scale = head_dim**-0.5
503
+
504
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
505
+ self.proj = nn.Linear(dim, dim)
506
+
507
+ self.use_rel_pos = use_rel_pos
508
+ if self.use_rel_pos:
509
+ assert (
510
+ input_size is not None
511
+ ), "Input size must be provided if using relative positional encoding."
512
+ # initialize relative positional embeddings
513
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
514
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
515
+
516
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
517
+ B, H, W, _ = x.shape
518
+ # qkv with shape (3, B, nHead, H * W, C)
519
+ qkv = (
520
+ self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
521
+ )
522
+ # q, k, v with shape (B * nHead, H * W, C)
523
+ q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
524
+
525
+ attn = (q * self.scale) @ k.transpose(-2, -1)
526
+
527
+ if self.use_rel_pos:
528
+ attn = add_decomposed_rel_pos(
529
+ attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
530
+ )
531
+
532
+ attn = attn.softmax(dim=-1)
533
+ x = (
534
+ (attn @ v)
535
+ .view(B, self.num_heads, H, W, -1)
536
+ .permute(0, 2, 3, 1, 4)
537
+ .reshape(B, H, W, -1)
538
+ )
539
+ x = self.proj(x)
540
+
541
+ return x
542
+
543
+
544
+ class PatchEmbed(nn.Module):
545
+ """
546
+ Image to Patch Embedding.
547
+ """
548
+
549
+ def __init__(
550
+ self,
551
+ kernel_size: Tuple[int, int] = (16, 16),
552
+ stride: Tuple[int, int] = (16, 16),
553
+ padding: Tuple[int, int] = (0, 0),
554
+ in_chans: int = 3,
555
+ embed_dim: int = 768,
556
+ ) -> None:
557
+ """
558
+ Args:
559
+ kernel_size (Tuple): kernel size of the projection layer.
560
+ stride (Tuple): stride of the projection layer.
561
+ padding (Tuple): padding size of the projection layer.
562
+ in_chans (int): Number of input image channels.
563
+ embed_dim (int): Patch embedding dimension.
564
+ """
565
+ super().__init__()
566
+
567
+ self.proj = nn.Conv2d(
568
+ in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
569
+ )
570
+
571
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
572
+ x = self.proj(x)
573
+ # B C H W -> B H W C
574
+ x = x.permute(0, 2, 3, 1)
575
+ return x
576
+
577
+
578
+ class ImageEncoderViTBlock(nn.Module):
579
+ """Transformer blocks with support of window attention and residual propagation blocks"""
580
+
581
+ def __init__(
582
+ self,
583
+ dim: int,
584
+ num_heads: int,
585
+ mlp_ratio: float = 4.0,
586
+ qkv_bias: bool = True,
587
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
588
+ act_layer: Type[nn.Module] = nn.GELU,
589
+ use_rel_pos: bool = False,
590
+ rel_pos_zero_init: bool = True,
591
+ window_size: int = 0,
592
+ input_size: Optional[Tuple[int, int]] = None,
593
+ ) -> None:
594
+ """
595
+ Args:
596
+ dim (int): Number of input channels.
597
+ num_heads (int): Number of attention heads in each ViT block.
598
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
599
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
600
+ norm_layer (nn.Module): Normalization layer.
601
+ act_layer (nn.Module): Activation layer.
602
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
603
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
604
+ window_size (int): Window size for window attention blocks. If it equals 0, then
605
+ use global attention.
606
+ input_size (tuple(int, int) or None): Input resolution for calculating the relative
607
+ positional parameter size.
608
+ """
609
+ super().__init__()
610
+ self.norm1 = norm_layer(dim)
611
+ self.attn = ImageEncoderViTAttention(
612
+ dim,
613
+ num_heads=num_heads,
614
+ qkv_bias=qkv_bias,
615
+ use_rel_pos=use_rel_pos,
616
+ rel_pos_zero_init=rel_pos_zero_init,
617
+ input_size=input_size if window_size == 0 else (window_size, window_size),
618
+ )
619
+
620
+ self.norm2 = norm_layer(dim)
621
+ self.mlp = CommonMLP(
622
+ embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
623
+ )
624
+
625
+ self.window_size = window_size
626
+
627
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
628
+ shortcut = x
629
+ x = self.norm1(x)
630
+ # Window partition
631
+ if self.window_size > 0:
632
+ H, W = x.shape[1], x.shape[2]
633
+ x, pad_hw = window_partition(x, self.window_size)
634
+
635
+ x = self.attn(x)
636
+ # Reverse window partition
637
+ if self.window_size > 0:
638
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
639
+
640
+ x = shortcut + x
641
+ x = x + self.mlp(self.norm2(x))
642
+
643
+ return x
644
+
645
+
646
+ class ImageEncoderViT(nn.Module):
647
+ def __init__(
648
+ self,
649
+ img_size: int = 1024,
650
+ patch_size: int = 16,
651
+ in_chans: int = 3,
652
+ embed_dim: int = 768,
653
+ depth: int = 12,
654
+ num_heads: int = 12,
655
+ mlp_ratio: float = 4.0,
656
+ out_chans: int = 256,
657
+ qkv_bias: bool = True,
658
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
659
+ act_layer: Type[nn.Module] = nn.GELU,
660
+ use_abs_pos: bool = True,
661
+ use_rel_pos: bool = False,
662
+ rel_pos_zero_init: bool = True,
663
+ window_size: int = 0,
664
+ global_attn_indexes: Tuple[int, ...] = (),
665
+ ) -> None:
666
+ """
667
+ Args:
668
+ img_size (int): Input image size.
669
+ patch_size (int): Patch size.
670
+ in_chans (int): Number of input image channels.
671
+ embed_dim (int): Patch embedding dimension.
672
+ depth (int): Depth of ViT.
673
+ num_heads (int): Number of attention heads in each ViT block.
674
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
675
+ qkv_bias (bool): If True, add a learnable bias to query, key, value.
676
+ norm_layer (nn.Module): Normalization layer.
677
+ act_layer (nn.Module): Activation layer.
678
+ use_abs_pos (bool): If True, use absolute positional embeddings.
679
+ use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
680
+ rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
681
+ window_size (int): Window size for window attention blocks.
682
+ global_attn_indexes (list): Indexes for blocks using global attention.
683
+ """
684
+ super().__init__()
685
+ self.img_size = img_size
686
+ self.embed_dim = embed_dim
687
+ self.out_chans = out_chans
688
+
689
+ self.patch_embed = PatchEmbed(
690
+ kernel_size=(patch_size, patch_size),
691
+ stride=(patch_size, patch_size),
692
+ in_chans=in_chans,
693
+ embed_dim=embed_dim,
694
+ )
695
+
696
+ self.pos_embed: Optional[nn.Parameter] = None
697
+ if use_abs_pos:
698
+ # Initialize absolute positional embedding with pretrain image size.
699
+ self.pos_embed = nn.Parameter(
700
+ torch.zeros(
701
+ 1, img_size // patch_size, img_size // patch_size, embed_dim
702
+ )
703
+ )
704
+
705
+ self.blocks = nn.ModuleList()
706
+ for i in range(depth):
707
+ block = ImageEncoderViTBlock(
708
+ dim=embed_dim,
709
+ num_heads=num_heads,
710
+ mlp_ratio=mlp_ratio,
711
+ qkv_bias=qkv_bias,
712
+ norm_layer=norm_layer,
713
+ act_layer=act_layer,
714
+ use_rel_pos=use_rel_pos,
715
+ rel_pos_zero_init=rel_pos_zero_init,
716
+ window_size=window_size if i not in global_attn_indexes else 0,
717
+ input_size=(img_size // patch_size, img_size // patch_size),
718
+ )
719
+ self.blocks.append(block)
720
+
721
+ self.neck = nn.Sequential(
722
+ nn.Conv2d(
723
+ embed_dim,
724
+ out_chans,
725
+ kernel_size=1,
726
+ bias=False,
727
+ ),
728
+ LayerNorm2d(out_chans),
729
+ nn.Conv2d(
730
+ out_chans,
731
+ out_chans,
732
+ kernel_size=3,
733
+ padding=1,
734
+ bias=False,
735
+ ),
736
+ LayerNorm2d(out_chans),
737
+ )
738
+
739
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
740
+ x = self.patch_embed(x)
741
+ if self.pos_embed is not None:
742
+ x = x + self.pos_embed
743
+
744
+ for blk in self.blocks:
745
+ x = blk(x)
746
+
747
+ dtype = x.dtype
748
+ if dtype == torch.float16: # prevent overflow
749
+ with torch.autocast(device_type="cuda", dtype=torch.float32):
750
+ x = self.neck(x.permute(0, 3, 1, 2))
751
+ x = x.to(dtype)
752
+ else:
753
+ x = self.neck(x.permute(0, 3, 1, 2))
754
+ return x
755
+
756
+
757
+ class PositionEmbeddingRandom(nn.Module):
758
+ """
759
+ Positional encoding using random spatial frequencies.
760
+ """
761
+
762
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
763
+ super().__init__()
764
+ if scale is None or scale <= 0.0:
765
+ scale = 1.0
766
+ self.register_buffer(
767
+ "positional_encoding_gaussian_matrix",
768
+ scale * torch.randn((2, num_pos_feats)),
769
+ )
770
+
771
+ def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
772
+ """Positionally encode points that are normalized to [0,1]."""
773
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
774
+ coords = 2 * coords - 1
775
+
776
+ if coords.dtype != self.positional_encoding_gaussian_matrix.dtype:
777
+ coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)
778
+
779
+ coords = coords @ self.positional_encoding_gaussian_matrix
780
+ coords = 2 * np.pi * coords
781
+ # outputs d_1 x ... x d_n x C shape
782
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
783
+
784
+ def forward(self, size: Tuple[int, int]) -> torch.Tensor:
785
+ """Generate positional encoding for a grid of the specified size."""
786
+ h, w = size
787
+ device: Any = self.positional_encoding_gaussian_matrix.device
788
+ grid = torch.ones(
789
+ (h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype
790
+ )
791
+ y_embed = grid.cumsum(dim=0) - 0.5
792
+ x_embed = grid.cumsum(dim=1) - 0.5
793
+ y_embed = y_embed / h
794
+ x_embed = x_embed / w
795
+
796
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
797
+ return pe.permute(2, 0, 1) # C x H x W
798
+
799
+ def forward_with_coords(
800
+ self, coords_input: torch.Tensor, image_size: Tuple[int, int]
801
+ ) -> torch.Tensor:
802
+ """Positionally encode points that are not normalized to [0,1]."""
803
+ coords = coords_input.clone()
804
+ coords[:, :, 0] = coords[:, :, 0] / image_size[1]
805
+ coords[:, :, 1] = coords[:, :, 1] / image_size[0]
806
+ return self._pe_encoding(coords.to(torch.float)) # B x N x C
807
+
808
+
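A shape-only sketch of the random-Fourier positional encoding above (the grid size and image size below are arbitrary assumptions):

import torch

pe_layer = PositionEmbeddingRandom(num_pos_feats=128)          # encoding dim C = 2 * 128 = 256
dense_pe = pe_layer((64, 64))                                  # C x H x W, one vector per grid cell
points = torch.rand(1, 3, 2) * 1024                            # three (x, y) points in a 1024 x 1024 image
point_pe = pe_layer.forward_with_coords(points, (1024, 1024))  # B x N x C
print(dense_pe.shape, point_pe.shape)                          # (256, 64, 64) and (1, 3, 256)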
809
+ class PromptEncoder(nn.Module):
810
+ def __init__(
811
+ self,
812
+ embed_dim: int,
813
+ image_embedding_size: Tuple[int, int],
814
+ input_image_size: Tuple[int, int],
815
+ mask_in_chans: int,
816
+ activation: Type[nn.Module] = nn.GELU,
817
+ ) -> None:
818
+ """
819
+ Encodes prompts for input to SAM's mask decoder.
820
+
821
+ Arguments:
822
+ embed_dim (int): The prompts' embedding dimension
823
+ image_embedding_size (tuple(int, int)): The spatial size of the
824
+ image embedding, as (H, W).
825
+ input_image_size (tuple(int, int)): The padded size of the image as input
826
+ to the image encoder, as (H, W).
827
+ mask_in_chans (int): The number of hidden channels used for
828
+ encoding input masks.
829
+ activation (nn.Module): The activation to use when encoding
830
+ input masks.
831
+ """
832
+ super().__init__()
833
+ self.embed_dim = embed_dim
834
+ self.input_image_size = input_image_size
835
+ self.image_embedding_size = image_embedding_size
836
+ self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
837
+
838
+ self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
839
+ point_embeddings = [
840
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
841
+ ]
842
+ self.point_embeddings = nn.ModuleList(point_embeddings)
843
+ self.not_a_point_embed = nn.Embedding(1, embed_dim)
844
+
845
+ self.mask_input_size = (
846
+ 4 * image_embedding_size[0],
847
+ 4 * image_embedding_size[1],
848
+ )
849
+ self.mask_downscaling = nn.Sequential(
850
+ nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
851
+ LayerNorm2d(mask_in_chans // 4),
852
+ activation(),
853
+ nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
854
+ LayerNorm2d(mask_in_chans),
855
+ activation(),
856
+ nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
857
+ )
858
+ self.no_mask_embed = nn.Embedding(1, embed_dim)
859
+
860
+ def get_dense_pe(self) -> torch.Tensor:
861
+ """
862
+ Returns the positional encoding used to encode point prompts,
863
+ applied to a dense set of points the shape of the image encoding.
864
+
865
+ Returns:
866
+ torch.Tensor: Positional encoding with shape
867
+ 1x(embed_dim)x(embedding_h)x(embedding_w)
868
+ """
869
+ return self.pe_layer(self.image_embedding_size).unsqueeze(0)
870
+
871
+ def _embed_points(
872
+ self,
873
+ points: torch.Tensor,
874
+ labels: torch.Tensor,
875
+ pad: bool,
876
+ ) -> torch.Tensor:
877
+ """Embeds point prompts."""
878
+ points = points + 0.5 # Shift to center of pixel
879
+ if pad:
880
+ padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
881
+ padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
882
+ points = torch.cat([points, padding_point], dim=1)
883
+ labels = torch.cat([labels, padding_label], dim=1)
884
+ point_embedding = self.pe_layer.forward_with_coords(
885
+ points, self.input_image_size
886
+ )
887
+ point_embedding[labels == -1] = 0.0
888
+ point_embedding[labels == -1] += self.not_a_point_embed.weight
889
+ point_embedding[labels == 0] += self.point_embeddings[0].weight
890
+ point_embedding[labels == 1] += self.point_embeddings[1].weight
891
+ return point_embedding
892
+
893
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
894
+ """Embeds box prompts."""
895
+ boxes = boxes + 0.5 # Shift to center of pixel
896
+ coords = boxes.reshape(-1, 2, 2)
897
+ corner_embedding = self.pe_layer.forward_with_coords(
898
+ coords, self.input_image_size
899
+ )
900
+ corner_embedding[:, 0, :] += self.point_embeddings[2].weight
901
+ corner_embedding[:, 1, :] += self.point_embeddings[3].weight
902
+ return corner_embedding
903
+
904
+ def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
905
+ """Embeds mask inputs."""
906
+ mask_embedding = self.mask_downscaling(masks)
907
+ return mask_embedding
908
+
909
+ def _get_batch_size(
910
+ self,
911
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
912
+ boxes: Optional[torch.Tensor],
913
+ masks: Optional[torch.Tensor],
914
+ text_embeds: Optional[torch.Tensor],
915
+ ) -> int:
916
+ """
917
+ Gets the batch size of the output given the batch size of the input prompts.
918
+ """
919
+ if points is not None:
920
+ return points[0].shape[0]
921
+ elif boxes is not None:
922
+ return boxes.shape[0]
923
+ elif masks is not None:
924
+ return masks.shape[0]
925
+ elif text_embeds is not None:
926
+ return text_embeds.shape[0]
927
+ else:
928
+ return 1
929
+
930
+ def _get_device(self) -> torch.device:
931
+ return self.point_embeddings[0].weight.device
932
+
933
+ def forward(
934
+ self,
935
+ points: Optional[Tuple[torch.Tensor, torch.Tensor]],
936
+ boxes: Optional[torch.Tensor],
937
+ masks: Optional[torch.Tensor],
938
+ text_embeds: Optional[torch.Tensor],
939
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
940
+ """
941
+ Embeds different types of prompts, returning both sparse and dense
942
+ embeddings.
943
+
944
+ Arguments:
945
+ points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
946
+ and labels to embed.
947
+ boxes (torch.Tensor or none): boxes to embed
948
+ masks (torch.Tensor or none): masks to embed
+ text_embeds (torch.Tensor or none): precomputed text embeddings to append to the sparse prompts
949
+
950
+ Returns:
951
+ torch.Tensor: sparse embeddings for the points and boxes, with shape
952
+ BxNx(embed_dim), where N is determined by the number of input points
953
+ and boxes.
954
+ torch.Tensor: dense embeddings for the masks, in the shape
955
+ Bx(embed_dim)x(embed_H)x(embed_W)
956
+ """
957
+ bs = self._get_batch_size(points, boxes, masks, text_embeds)
958
+ sparse_embeddings = torch.empty(
959
+ (bs, 0, self.embed_dim), device=self._get_device()
960
+ )
961
+ if points is not None:
962
+ coords, labels = points
963
+ point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
964
+ sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
965
+ if boxes is not None:
966
+ box_embeddings = self._embed_boxes(boxes)
967
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
968
+
969
+ if text_embeds is not None:
970
+ sparse_embeddings = torch.cat([sparse_embeddings, text_embeds], dim=1)
971
+
972
+ if masks is not None:
973
+ dense_embeddings = self._embed_masks(masks)
974
+ else:
975
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
976
+ bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
977
+ )
978
+
979
+ return sparse_embeddings, dense_embeddings
980
+
981
+
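A sketch of running the prompt encoder above with a single box prompt; the 1024-pixel input frame and 64 x 64 embedding grid are assumptions, not values read from this checkpoint:

import torch

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),   # H, W of the image embedding
    input_image_size=(1024, 1024),   # padded input frame seen by the image encoder
    mask_in_chans=16,
)
boxes = torch.tensor([[100.0, 150.0, 400.0, 500.0]])   # B x 4, XYXY in the input frame
sparse, dense = prompt_encoder(points=None, boxes=boxes, masks=None, text_embeds=None)
print(sparse.shape, dense.shape)   # (1, 2, 256) box corners and (1, 256, 64, 64) "no mask" embedding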
982
+ class MaskDecoderMLP(nn.Module):
983
+ def __init__(
984
+ self,
985
+ input_dim: int,
986
+ hidden_dim: int,
987
+ output_dim: int,
988
+ num_layers: int,
989
+ sigmoid_output: bool = False,
990
+ ) -> None:
991
+ super().__init__()
992
+ self.num_layers = num_layers
993
+ h = [hidden_dim] * (num_layers - 1)
994
+ self.layers = nn.ModuleList(
995
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
996
+ )
997
+ self.sigmoid_output = sigmoid_output
998
+
999
+ def forward(self, x):
1000
+ for i, layer in enumerate(self.layers):
1001
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
1002
+ if self.sigmoid_output:
1003
+ x = torch.sigmoid(x)  # torch.sigmoid; F.sigmoid is deprecated
1004
+ return x
1005
+
1006
+
1007
+ class MaskDecoder(nn.Module):
1008
+ def __init__(
1009
+ self,
1010
+ *,
1011
+ transformer_dim: int,
1012
+ transformer: nn.Module,
1013
+ num_multimask_outputs: int = 3,
1014
+ activation: Type[nn.Module] = nn.GELU,
1015
+ iou_head_depth: int = 3,
1016
+ iou_head_hidden_dim: int = 256,
1017
+ ) -> None:
1018
+ """
1019
+ Predicts masks given an image and prompt embeddings, using a
1020
+ transformer architecture.
1021
+
1022
+ Arguments:
1023
+ transformer_dim (int): the channel dimension of the transformer
1024
+ transformer (nn.Module): the transformer used to predict masks
1025
+ num_multimask_outputs (int): the number of masks to predict
1026
+ when disambiguating masks
1027
+ activation (nn.Module): the type of activation to use when
1028
+ upscaling masks
1029
+ iou_head_depth (int): the depth of the MLP used to predict
1030
+ mask quality
1031
+ iou_head_hidden_dim (int): the hidden dimension of the MLP
1032
+ used to predict mask quality
1033
+ """
1034
+ super().__init__()
1035
+ self.transformer_dim = transformer_dim
1036
+ self.transformer = transformer
1037
+
1038
+ self.num_multimask_outputs = num_multimask_outputs
1039
+
1040
+ self.iou_token = nn.Embedding(1, transformer_dim)
1041
+ self.num_mask_tokens = num_multimask_outputs + 1
1042
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
1043
+
1044
+ self.output_upscaling = nn.Sequential(
1045
+ nn.ConvTranspose2d(
1046
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
1047
+ ),
1048
+ LayerNorm2d(transformer_dim // 4),
1049
+ activation(),
1050
+ nn.ConvTranspose2d(
1051
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
1052
+ ),
1053
+ activation(),
1054
+ )
1055
+ self.output_hypernetworks_mlps = nn.ModuleList(
1056
+ [
1057
+ MaskDecoderMLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
1058
+ for i in range(self.num_mask_tokens)
1059
+ ]
1060
+ )
1061
+
1062
+ self.iou_prediction_head = MaskDecoderMLP(
1063
+ transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
1064
+ )
1065
+
1066
+ def forward(
1067
+ self,
1068
+ image_embeddings: torch.Tensor,
1069
+ image_pe: torch.Tensor,
1070
+ sparse_prompt_embeddings: torch.Tensor,
1071
+ dense_prompt_embeddings: torch.Tensor,
1072
+ multimask_output: bool,
1073
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1074
+ """
1075
+ Predict masks given image and prompt embeddings.
1076
+
1077
+ Arguments:
1078
+ image_embeddings (torch.Tensor): the embeddings from the image encoder
1079
+ image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
1080
+ sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
1081
+ dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
1082
+ multimask_output (bool): Whether to return multiple masks or a single
1083
+ mask.
1084
+
1085
+ Returns:
1086
+ torch.Tensor: batched predicted masks
1087
+ torch.Tensor: batched predictions of mask quality
1088
+ """
1089
+ masks, iou_pred = self.predict_masks(
1090
+ image_embeddings=image_embeddings,
1091
+ image_pe=image_pe,
1092
+ sparse_prompt_embeddings=sparse_prompt_embeddings,
1093
+ dense_prompt_embeddings=dense_prompt_embeddings,
1094
+ )
1095
+
1096
+ # Select the correct mask or masks for output
1097
+ if multimask_output:
1098
+ mask_slice = slice(1, None)
1099
+ else:
1100
+ mask_slice = slice(0, 1)
1101
+ masks = masks[:, mask_slice, :, :]
1102
+ iou_pred = iou_pred[:, mask_slice]
1103
+
1104
+ # Prepare output
1105
+ return masks, iou_pred
1106
+
1107
+ def predict_masks(
1108
+ self,
1109
+ image_embeddings: torch.Tensor,
1110
+ image_pe: torch.Tensor,
1111
+ sparse_prompt_embeddings: torch.Tensor,
1112
+ dense_prompt_embeddings: torch.Tensor,
1113
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1114
+ """Predicts masks. See 'forward' for more details."""
1115
+ # Concatenate output tokens
1116
+ output_tokens = torch.cat(
1117
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
1118
+ )
1119
+ output_tokens = output_tokens.unsqueeze(0).expand(
1120
+ sparse_prompt_embeddings.size(0), -1, -1
1121
+ )
1122
+
1123
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
1124
+
1125
+ # image_embeddings: [1, C, H, W], tokens: [B, N, C]
1126
+ # dense_prompt_embeddings: [B, C, H, W]
1127
+ # Expand per-image data in batch direction to be per-mask
1128
+ src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
1129
+ src = src + dense_prompt_embeddings
1130
+ pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
1131
+ b, c, h, w = src.shape
1132
+
1133
+ # Run the transformer
1134
+ hs, src = self.transformer(src, pos_src, tokens)
1135
+ iou_token_out = hs[:, 0, :]
1136
+ mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
1137
+
1138
+ # Upscale mask embeddings and predict masks using the mask tokens
1139
+ src = src.transpose(1, 2).view(b, c, h, w)
1140
+ upscaled_embedding = self.output_upscaling(src)
1141
+ hyper_in_list: List[torch.Tensor] = []
1142
+ for i in range(self.num_mask_tokens):
1143
+ hyper_in_list.append(
1144
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
1145
+ )
1146
+ hyper_in = torch.stack(hyper_in_list, dim=1)
1147
+ b, c, h, w = upscaled_embedding.shape
1148
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(
1149
+ b, self.num_mask_tokens, h, w
1150
+ )
1151
+
1152
+ # Generate mask quality predictions
1153
+ iou_pred = self.iou_prediction_head(iou_token_out)
1154
+
1155
+ return masks, iou_pred
1156
+
1157
+
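A shape-level sketch of the decoder above. The real two-way transformer is defined elsewhere in this file; the stub below only mimics its (src, pos, tokens) -> (tokens, flattened src) interface so the example stays self-contained:

import torch
import torch.nn as nn

class _StubTransformer(nn.Module):
    # Placeholder with the interface MaskDecoder.predict_masks expects; not the real two-way transformer.
    def forward(self, src, pos_src, tokens):
        return tokens, src.flatten(2).permute(0, 2, 1)

decoder = MaskDecoder(transformer_dim=256, transformer=_StubTransformer())
image_embeddings = torch.randn(1, 256, 64, 64)
image_pe = torch.randn(1, 256, 64, 64)
sparse_prompts = torch.randn(1, 2, 256)    # e.g. two box-corner embeddings
dense_prompts = torch.randn(1, 256, 64, 64)
masks, iou_pred = decoder(image_embeddings, image_pe, sparse_prompts, dense_prompts, multimask_output=True)
print(masks.shape, iou_pred.shape)         # (1, 3, 256, 256) low-res masks and (1, 3) IoU scores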
1158
+ class Sam(nn.Module):
1159
+ mask_threshold: float = 0.0
1160
+ image_format: str = "RGB"
1161
+
1162
+ def __init__(
1163
+ self,
1164
+ image_encoder: ImageEncoderViT,
1165
+ prompt_encoder: PromptEncoder,
1166
+ mask_decoder: MaskDecoder,
1167
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
1168
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
1169
+ ) -> None:
1170
+ """
1171
+ SAM predicts object masks from an image and input prompts.
1172
+
1173
+ Arguments:
1174
+ image_encoder (ImageEncoderViT): The backbone used to encode the
1175
+ image into image embeddings that allow for efficient mask prediction.
1176
+ prompt_encoder (PromptEncoder): Encodes various types of input prompts.
1177
+ mask_decoder (MaskDecoder): Predicts masks from the image embeddings
1178
+ and encoded prompts.
1179
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
1180
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
1181
+ """
1182
+ super().__init__()
1183
+ self.image_encoder = image_encoder
1184
+ self.prompt_encoder = prompt_encoder
1185
+ self.mask_decoder = mask_decoder
1186
+ self.register_buffer(
1187
+ "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
1188
+ )
1189
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
1190
+
1191
+ @property
1192
+ def device(self) -> Any:
1193
+ return self.pixel_mean.device
1194
+
1195
+ @torch.no_grad()
1196
+ def forward(
1197
+ self,
1198
+ batched_input: List[Dict[str, Any]],
1199
+ multimask_output: bool,
1200
+ ) -> List[Dict[str, torch.Tensor]]:
1201
+ """
1202
+ Predicts masks end-to-end from provided images and prompts.
1203
+ If prompts are not known in advance, using SamPredictor is
1204
+ recommended over calling the model directly.
1205
+
1206
+ Arguments:
1207
+ batched_input (list(dict)): A list over input images, each a
1208
+ dictionary with the following keys. A prompt key can be
1209
+ excluded if it is not present.
1210
+ 'image': The image as a torch tensor in 3xHxW format,
1211
+ already transformed for input to the model.
1212
+ 'original_size': (tuple(int, int)) The original size of
1213
+ the image before transformation, as (H, W).
1214
+ 'point_coords': (torch.Tensor) Batched point prompts for
1215
+ this image, with shape BxNx2. Already transformed to the
1216
+ input frame of the model.
1217
+ 'point_labels': (torch.Tensor) Batched labels for point prompts,
1218
+ with shape BxN.
1219
+ 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
1220
+ Already transformed to the input frame of the model.
1221
+ 'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
1222
+ in the form Bx1xHxW.
1223
+ multimask_output (bool): Whether the model should predict multiple
1224
+ disambiguating masks, or return a single mask.
1225
+
1226
+ Returns:
1227
+ (list(dict)): A list over input images, where each element is
1228
+ a dictionary with the following keys.
1229
+ 'masks': (torch.Tensor) Batched binary mask predictions,
1230
+ with shape BxCxHxW, where B is the number of input prompts,
1231
+ C is determined by multimask_output, and (H, W) is the
1232
+ original size of the image.
1233
+ 'iou_predictions': (torch.Tensor) The model's predictions
1234
+ of mask quality, in shape BxC.
1235
+ 'low_res_logits': (torch.Tensor) Low resolution logits with
1236
+ shape BxCxHxW, where H=W=256. Can be passed as mask input
1237
+ to subsequent iterations of prediction.
1238
+ """
1239
+ input_images = torch.stack(
1240
+ [self.preprocess(x["image"]) for x in batched_input], dim=0
1241
+ )
1242
+ image_embeddings = self.image_encoder(input_images)
1243
+
1244
+ outputs = []
1245
+ for image_record, curr_embedding in zip(batched_input, image_embeddings):
1246
+ if "point_coords" in image_record:
1247
+ points = (image_record["point_coords"], image_record["point_labels"])
1248
+ else:
1249
+ points = None
1250
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
1251
+ points=points,
1252
+ boxes=image_record.get("boxes", None),
1253
+ masks=image_record.get("mask_inputs", None),
+ text_embeds=None,  # PromptEncoder.forward requires this argument; no text prompt is used in this path
1254
+ )
1255
+ low_res_masks, iou_predictions = self.mask_decoder(
1256
+ image_embeddings=curr_embedding.unsqueeze(0),
1257
+ image_pe=self.prompt_encoder.get_dense_pe(),
1258
+ sparse_prompt_embeddings=sparse_embeddings,
1259
+ dense_prompt_embeddings=dense_embeddings,
1260
+ multimask_output=multimask_output,
1261
+ )
1262
+ masks = self.postprocess_masks(
1263
+ low_res_masks,
1264
+ input_size=image_record["image"].shape[-2:],
1265
+ original_size=image_record["original_size"],
1266
+ )
1267
+ masks = masks > self.mask_threshold
1268
+ outputs.append(
1269
+ {
1270
+ "masks": masks,
1271
+ "iou_predictions": iou_predictions,
1272
+ "low_res_logits": low_res_masks,
1273
+ }
1274
+ )
1275
+ return outputs
1276
+
1277
+ def postprocess_masks(
1278
+ self,
1279
+ masks: torch.Tensor,
1280
+ input_size: Tuple[int, ...],
1281
+ original_size: Tuple[int, ...],
1282
+ ) -> torch.Tensor:
1283
+ """
1284
+ Remove padding and upscale masks to the original image size.
1285
+
1286
+ Arguments:
1287
+ masks (torch.Tensor): Batched masks from the mask_decoder,
1288
+ in BxCxHxW format.
1289
+ input_size (tuple(int, int)): The size of the image input to the
1290
+ model, in (H, W) format. Used to remove padding.
1291
+ original_size (tuple(int, int)): The original size of the image
1292
+ before resizing for input to the model, in (H, W) format.
1293
+
1294
+ Returns:
1295
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
1296
+ is given by original_size.
1297
+ """
1298
+
1299
+ dtype = masks.dtype
1300
+
1301
+ masks = F.interpolate(
1302
+ masks.float(),
1303
+ (self.image_encoder.img_size, self.image_encoder.img_size),
1304
+ mode="bilinear",
1305
+ align_corners=False,
1306
+ )
1307
+ # masks = masks.to(dtype)
1308
+ masks = masks[..., : input_size[0], : input_size[1]]
1309
+ masks = F.interpolate(
1310
+ masks, original_size, mode="bilinear", align_corners=False
1311
+ )
1312
+ return masks
1313
+
1314
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
1315
+ """Normalize pixel values and pad to a square input."""
1316
+ # Normalize colors
1317
+ x = (x - self.pixel_mean) / self.pixel_std
1318
+
1319
+ # Pad
1320
+ h, w = x.shape[-2:]
1321
+ padh = self.image_encoder.img_size - h
1322
+ padw = self.image_encoder.img_size - w
1323
+ x = F.pad(x, (0, padw, 0, padh))
1324
+ return x
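Putting the pieces together, a toy end-to-end sketch of the Sam wrapper above. Every hyperparameter here is shrunk for a quick CPU smoke test and the transformer is the same kind of stub as before; the shipped checkpoint wires these components together with its own config values:

import torch
import torch.nn as nn

class _StubTransformer(nn.Module):
    # Placeholder for the real two-way transformer; same call signature only.
    def forward(self, src, pos_src, tokens):
        return tokens, src.flatten(2).permute(0, 2, 1)

sam = Sam(
    image_encoder=ImageEncoderViT(img_size=64, patch_size=16, embed_dim=32, depth=1, num_heads=2, out_chans=256),
    prompt_encoder=PromptEncoder(embed_dim=256, image_embedding_size=(4, 4), input_image_size=(64, 64), mask_in_chans=16),
    mask_decoder=MaskDecoder(transformer_dim=256, transformer=_StubTransformer()),
)
batched_input = [{
    "image": torch.randint(0, 255, (3, 64, 48)).float(),   # 3 x H x W, already resized into the 64-pixel input frame
    "original_size": (96, 72),                              # (H, W) before resizing
    "boxes": torch.tensor([[4.0, 4.0, 40.0, 40.0]]),        # B x 4 box prompts in the input frame
}]
outputs = sam(batched_input, multimask_output=True)
print(outputs[0]["masks"].shape)                            # (1, 3, 96, 72) boolean masks at the original size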
special_tokens_map.json ADDED
@@ -0,0 +1,32 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>",
16
+ "[SEG]"
17
+ ],
18
+ "eos_token": {
19
+ "content": "<|im_end|>",
20
+ "lstrip": false,
21
+ "normalized": false,
22
+ "rstrip": false,
23
+ "single_word": false
24
+ },
25
+ "pad_token": {
26
+ "content": "<|endoftext|>",
27
+ "lstrip": false,
28
+ "normalized": false,
29
+ "rstrip": false,
30
+ "single_word": false
31
+ }
32
+ }
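The map above adds a [SEG] token on top of the Qwen2 chat special tokens; a quick check after loading the tokenizer (the repo id below is a placeholder for wherever this commit is published):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("DeepGlint-AI/MLCD-Seg")   # placeholder repo id
seg_token_id = tokenizer.convert_tokens_to_ids("[SEG]")
print(seg_token_id)          # 151665, per added_tokens_decoder in tokenizer_config.json
print(tokenizer.eos_token)   # <|im_end|>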
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b376a39f73c663abe179f7ae1cae86da5dd2b689bcce504348d4fb00e6a64240
3
+ size 11422078
tokenizer_config.json ADDED
@@ -0,0 +1,218 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "[SEG]",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": true
188
+ }
189
+ },
190
+ "additional_special_tokens": [
191
+ "<|im_start|>",
192
+ "<|im_end|>",
193
+ "<|object_ref_start|>",
194
+ "<|object_ref_end|>",
195
+ "<|box_start|>",
196
+ "<|box_end|>",
197
+ "<|quad_start|>",
198
+ "<|quad_end|>",
199
+ "<|vision_start|>",
200
+ "<|vision_end|>",
201
+ "<|vision_pad|>",
202
+ "<|image_pad|>",
203
+ "<|video_pad|>",
204
+ "[SEG]"
205
+ ],
206
+ "bos_token": null,
207
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
208
+ "clean_up_tokenization_spaces": false,
209
+ "eos_token": "<|im_end|>",
210
+ "errors": "replace",
211
+ "extra_special_tokens": {},
212
+ "model_max_length": 32768,
213
+ "pad_token": "<|endoftext|>",
214
+ "padding_side": "right",
215
+ "split_special_tokens": false,
216
+ "tokenizer_class": "Qwen2Tokenizer",
217
+ "unk_token": null
218
+ }
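The config above ships a Qwen2-style chat template, so prompts can be rendered with apply_chat_template; a short sketch (same placeholder repo id as before):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("DeepGlint-AI/MLCD-Seg")   # placeholder repo id
messages = [{"role": "user", "content": "Please segment the red cup in the image."}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)   # <|im_start|>system ... <|im_start|>user ... <|im_start|>assistant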
transform.py ADDED
@@ -0,0 +1,113 @@
1
+ '''
2
+ File separated from Meta's Segment Anything (SAM) project @DeepGlint 2025
3
+ https://github.com/facebookresearch/segment-anything
4
+ '''
5
+
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from typing import Tuple
11
+ from copy import deepcopy
12
+ from torchvision.transforms.functional import resize # type: ignore
13
+ from torchvision.transforms.functional import to_pil_image
14
+
15
+
16
+
17
+ class ResizeLongestSide:
18
+ """
19
+ Resizes images so that the longest side equals 'target_length', and provides
20
+ methods for resizing coordinates and boxes, for both numpy arrays and
21
+ batched torch tensors.
22
+ """
23
+
24
+ def __init__(self, target_length: int) -> None:
25
+ self.target_length = target_length
26
+
27
+ def apply_image(self, image: np.ndarray) -> np.ndarray:
28
+ """
29
+ Expects a numpy array with shape HxWxC in uint8 format.
30
+ """
31
+ target_size = self.get_preprocess_shape(
32
+ image.shape[0], image.shape[1], self.target_length
33
+ )
34
+ return np.array(resize(to_pil_image(image), target_size))
35
+
36
+ def apply_coords(
37
+ self, coords: np.ndarray, original_size: Tuple[int, ...]
38
+ ) -> np.ndarray:
39
+ """
40
+ Expects a numpy array of length 2 in the final dimension. Requires the
41
+ original image size in (H, W) format.
42
+ """
43
+ old_h, old_w = original_size
44
+ new_h, new_w = self.get_preprocess_shape(
45
+ original_size[0], original_size[1], self.target_length
46
+ )
47
+ coords = deepcopy(coords).astype(float)
48
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
49
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
50
+ return coords
51
+
52
+ def apply_boxes(
53
+ self, boxes: np.ndarray, original_size: Tuple[int, ...]
54
+ ) -> np.ndarray:
55
+ """
56
+ Expects a numpy array shape Bx4. Requires the original image size
57
+ in (H, W) format.
58
+ """
59
+ boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
60
+ return boxes.reshape(-1, 4)
61
+
62
+ def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
63
+ """
64
+ Expects batched images with shape BxCxHxW and float format. This
65
+ transformation may not exactly match apply_image. apply_image is
66
+ the transformation expected by the model.
67
+ """
68
+ # Expects an image in BCHW format. May not exactly match apply_image.
69
+ target_size = self.get_preprocess_shape(
70
+ image.shape[2], image.shape[3], self.target_length  # H, W of a BxCxHxW batch
71
+ )
72
+ return F.interpolate(
73
+ image, target_size, mode="bilinear", align_corners=False, antialias=True
74
+ )
75
+
76
+ def apply_coords_torch(
77
+ self, coords: torch.Tensor, original_size: Tuple[int, ...]
78
+ ) -> torch.Tensor:
79
+ """
80
+ Expects a torch tensor with length 2 in the last dimension. Requires the
81
+ original image size in (H, W) format.
82
+ """
83
+ old_h, old_w = original_size
84
+ new_h, new_w = self.get_preprocess_shape(
85
+ original_size[0], original_size[1], self.target_length
86
+ )
87
+ coords = deepcopy(coords).to(torch.float)
88
+ coords[..., 0] = coords[..., 0] * (new_w / old_w)
89
+ coords[..., 1] = coords[..., 1] * (new_h / old_h)
90
+ return coords
91
+
92
+ def apply_boxes_torch(
93
+ self, boxes: torch.Tensor, original_size: Tuple[int, ...]
94
+ ) -> torch.Tensor:
95
+ """
96
+ Expects a torch tensor with shape Bx4. Requires the original image
97
+ size in (H, W) format.
98
+ """
99
+ boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
100
+ return boxes.reshape(-1, 4)
101
+
102
+ @staticmethod
103
+ def get_preprocess_shape(
104
+ oldh: int, oldw: int, long_side_length: int
105
+ ) -> Tuple[int, int]:
106
+ """
107
+ Compute the output size given input size and target long side length.
108
+ """
109
+ scale = long_side_length * 1.0 / max(oldh, oldw)
110
+ newh, neww = oldh * scale, oldw * scale
111
+ neww = int(neww + 0.5)
112
+ newh = int(newh + 0.5)
113
+ return (newh, neww)
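A small sketch of the resizing arithmetic above: the longest side is scaled to target_length and coordinates are scaled by the same factors (the sizes below are arbitrary):

import numpy as np

transform = ResizeLongestSide(target_length=1024)
print(ResizeLongestSide.get_preprocess_shape(1500, 1000, 1024))   # (1024, 683)

image = np.zeros((1500, 1000, 3), dtype=np.uint8)                 # H x W x C uint8
resized = transform.apply_image(image)
print(resized.shape)                                              # (1024, 683, 3)

boxes = np.array([[100.0, 200.0, 900.0, 1400.0]])                 # XYXY in the original frame
print(transform.apply_boxes(boxes, original_size=(1500, 1000)))   # scaled into the 1024-pixel frame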
vision_projector.py ADDED
@@ -0,0 +1,57 @@
1
+ '''
2
+ Copyright @DeepGlint 2025
3
+ '''
4
+
5
+
6
+
7
+ import torch.nn as nn
8
+ import re
9
+
10
+ class IdentityMap(nn.Module):
11
+ def __init__(self):
12
+ super().__init__()
13
+
14
+ def forward(self, x, *args, **kwargs):
15
+ return x
16
+
17
+ @property
18
+ def config(self):
19
+ return {"mm_projector_type": "identity"}
20
+
21
+
22
+ class SimpleResBlock(nn.Module):
23
+ def __init__(self, channels):
24
+ super().__init__()
25
+ self.pre_norm = nn.LayerNorm(channels)
26
+
27
+ self.proj = nn.Sequential(nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels))
28
+
29
+ def forward(self, x):
30
+ x = self.pre_norm(x)
31
+ return x + self.proj(x)
32
+
33
+ def build_vision_projector(config, delay_load=False, **kwargs):
34
+ projector_type = getattr(config, "mm_projector_type", "linear")
35
+
+ # Handle the default "linear" type explicitly so it does not fall through to the error below.
+ if projector_type == "linear":
+ return nn.Linear(config.mm_hidden_size, config.hidden_size)
+
36
+ mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
37
+ if mlp_gelu_match:
38
+ mlp_depth = int(mlp_gelu_match.group(1))
39
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
40
+ for _ in range(1, mlp_depth):
41
+ modules.append(nn.GELU())
42
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
43
+ return nn.Sequential(*modules)
44
+
45
+ mlp_gelu_resnet_match = re.match(r"^mlp(\d+)x_res(\d+)x_gelu$", projector_type)
46
+ if mlp_gelu_resnet_match:
47
+ mlp_depth = int(mlp_gelu_resnet_match.group(1))
48
+ res_depth = int(mlp_gelu_resnet_match.group(2))
49
+ modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
50
+ for _ in range(1, mlp_depth):
51
+ modules.append(nn.GELU())
52
+ modules.append(nn.Linear(config.hidden_size, config.hidden_size))
53
+ for _ in range(res_depth):
54
+ modules.append(SimpleResBlock(config.hidden_size))
55
+ return nn.Sequential(*modules)
56
+
+ # IdentityMap defined above handles the "identity" projector type.
+ if projector_type == "identity":
+ return IdentityMap()
+
57
+ raise ValueError(f"Unknown projector type: {projector_type}")
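A hedged sketch of what the builder above returns for a typical "mlp2x_gelu" setting; the hidden sizes are placeholders, not this model's config values:

import torch
from types import SimpleNamespace

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=3584)  # placeholder sizes
projector = build_vision_projector(cfg)
print(projector)                                      # Linear -> GELU -> Linear
print(projector(torch.randn(2, 576, 1024)).shape)     # torch.Size([2, 576, 3584])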
vision_resampler.py ADDED
@@ -0,0 +1,23 @@
1
+ '''
2
+ Copyright @DeepGlint 2025
3
+ '''
4
+
5
+
6
+ import torch
7
+
8
+
9
+
10
+ def build_vision_resampler(model_args, delay_load=False, **kwargs):
11
+ return IdentityMap()
12
+
13
+
14
+ class IdentityMap(torch.nn.Module):
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def forward(self, x, *args, **kwargs):
19
+ return x
20
+
21
+ @property
22
+ def config(self):
23
+ return {"mm_resampler_type": None}
vision_tower.py ADDED
@@ -0,0 +1,123 @@
1
+ '''
2
+ Copyright @DeepGlint 2025
3
+ '''
4
+
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
9
+
10
+
11
+ def build_vision_tower(model_cfg, **kwargs):
12
+ vision_tower = getattr(model_cfg, "vision_tower_config", getattr(model_cfg, "vision_tower", None))
13
+ return CLIPVisionTower(vision_tower, args=model_cfg, **kwargs)
14
+
15
+
16
+ class CLIPVisionTower(nn.Module):
17
+ def __init__(self, vision_tower, args, delay_load=False):
18
+ super().__init__()
19
+
20
+ self.is_loaded = False
21
+
22
+ self.vision_tower_cfg = vision_tower
23
+ self.vision_tower_processor = args.vision_tower_processor
24
+ self.select_layer = args.mm_vision_select_layer
25
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
26
+
27
+ if not delay_load:
28
+ self.init_model()
29
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
30
+ # TODO: a more robust check for when to load the tower is needed.
31
+ self.init_model()
32
+ elif hasattr(args, "mm_tunable_parts") and "mm_vision_tower" in args.mm_tunable_parts:
33
+ self.init_model()
34
+ else:
35
+ raise RuntimeError("Not support now, please check config.json or contact us")
36
+
37
+ def init_model(self, device_map=None):
38
+ if self.is_loaded:
39
+ return
40
+ vision_tower_config = CLIPVisionConfig().from_dict(self.vision_tower_cfg)
41
+ self.image_processor = CLIPImageProcessor(**self.vision_tower_processor)
42
+ self.vision_tower = CLIPVisionModel(config=vision_tower_config)
43
+ self.vision_tower.requires_grad_(False)
44
+
45
+ self.is_loaded = True
46
+
47
+ def feature_select(self, image_forward_outs):
48
+ select_feature_type = self.select_feature
49
+
50
+ if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
51
+ select_every_k_layer = len(image_forward_outs.hidden_states) // 4
52
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
53
+ select_feature_type = select_feature_type.replace("slicefour_", "")
54
+ elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
55
+ select_layers = [-2, -5, -8, -11, 6]
56
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
57
+ select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
58
+ else:
59
+ image_features = image_forward_outs.hidden_states[self.select_layer]
60
+
61
+ if select_feature_type == "patch":
62
+ image_features = image_features[:, 1:]
63
+ elif select_feature_type == "cls_patch":
64
+ image_features = image_features
65
+ else:
66
+ raise ValueError(f"Unexpected select feature: {select_feature_type}")
67
+ return image_features
68
+
69
+ def forward(self, images):
70
+ if type(images) is list:
71
+ image_features = []
72
+ for image in images:
73
+ image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
74
+ image_feature = self.feature_select(image_forward_out).to(image.dtype)
75
+ image_features.append(image_feature)
76
+ else:
77
+ image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
78
+ image_features = self.feature_select(image_forward_outs).to(images.dtype)
79
+
80
+ return image_features
81
+
82
+ @property
83
+ def dummy_feature(self):
84
+ return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
85
+
86
+ @property
87
+ def dtype(self):
88
+ return self.vision_tower.dtype
89
+
90
+ @property
91
+ def device(self):
92
+ return self.vision_tower.device
93
+
94
+ @property
95
+ def config(self):
96
+ if self.is_loaded:
97
+ return self.vision_tower.config
98
+ else:
99
+ return self.cfg_only
100
+
101
+ @property
102
+ def hidden_size(self):
103
+ _hidden_size = self.config.hidden_size
104
+ if "slicefour" in self.select_feature:
105
+ _hidden_size *= 4
106
+ if "slice_m25811_f6" in self.select_feature:
107
+ _hidden_size *= 5
108
+ return _hidden_size
109
+
110
+ @property
111
+ def num_patches_per_side(self):
112
+ return self.config.image_size // self.config.patch_size
113
+
114
+ @property
115
+ def num_patches(self):
116
+ _num_patches = (self.config.image_size // self.config.patch_size) ** 2
117
+ if "cls_patch" in self.select_feature:
118
+ _num_patches += 1
119
+ return _num_patches
120
+
121
+ @property
122
+ def image_size(self):
123
+ return self.config.image_size
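A hedged smoke test of the tower wrapper above, using the defaults of CLIPVisionConfig and CLIPImageProcessor as stand-ins; the real checkpoint passes its own vision_tower_config and vision_tower_processor dicts through config.json:

import torch
from types import SimpleNamespace

args = SimpleNamespace(
    vision_tower_config={},        # placeholder: CLIPVisionConfig defaults (224 px, patch 32, hidden 768)
    vision_tower_processor={},     # placeholder: CLIPImageProcessor defaults
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(args)                 # randomly initialised weights, no checkpoint download
features = tower(torch.randn(2, 3, 224, 224))    # B x 3 x H x W pixel values
print(features.shape, tower.num_patches)         # torch.Size([2, 49, 768]) 49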
vocab.json ADDED
The diff for this file is too large to render. See raw diff