'''
File modification from LLAVA project @DeepGlintAI 2025
https://github.com/haotian-liu/LLaVA

origin copyright:

Copyright 2023 Haotian Liu

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''

from abc import ABC, abstractmethod

import math
import random
import ast
import re
import json
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from dataclasses import dataclass
from PIL import Image
from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, AutoConfig, AutoModelForCausalLM, PreTrainedTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from transformers.utils import cached_file
from safetensors.torch import load_file as safetensors_load
from .vision_tower import build_vision_tower
from .vision_resampler import build_vision_resampler
from .vision_projector import build_vision_projector
from .sam import build_sam_vit_h, text2sam_projection_layer
from .conversation_mlcd_seg import conv_templates, default_conversation
from .transform import ResizeLongestSide
from typing import Optional, Any, List, Tuple, Union, Dict

IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_SEG_TOKEN = "[SEG]"

# Per-channel mean / std on a 0-255 pixel scale, shaped (3, 1, 1) for broadcasting
# over CHW images. Not used in this part of the file; presumably SAM input
# normalisation — confirm against the caller.
IMG_MEAN = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
IMG_STD = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
IMG_SIZE = 1024


def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    A candidate wins if, after aspect-preserving scaling, it keeps more of the
    original pixels; ties are broken by the smaller amount of wasted (padding) area.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float("inf")

    for width, height in possible_resolutions:
        # Calculate the downscaled size to keep the aspect ratio
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)

        # Calculate effective and wasted resolutions
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


def _expand_grid_pinpoints(grid_pinpoints, patch_size):
    """Expand a grid-range string such as ``"(1x1),...,(6x6)"`` into pixel resolutions.

    Only the first and last ``(AxB)`` groups of the string are used: every
    ``(i, j)`` with ``i`` in ``[A_first, A_last]`` and ``j`` in
    ``[B_first, B_last]`` is emitted, each dimension multiplied by ``patch_size``.

    Args:
        grid_pinpoints (str): String containing ``(AxB)`` groups.
        patch_size (int): Side length of one patch in pixels.

    Returns:
        list[list[int]]: ``[width, height]`` candidate resolutions in pixels.
    """
    assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
    matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
    range_start = tuple(map(int, matches[0]))
    range_end = tuple(map(int, matches[-1]))
    return [
        [i * patch_size, j * patch_size]
        for i in range(range_start[0], range_end[0] + 1)
        for j in range(range_start[1], range_end[1] + 1)
    ]


def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str | list): Either a list of [width, height] resolutions, a
            string literal of such a list, or a range string like "(1x1),...,(6x6)".
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        grid_pinpoints = _expand_grid_pinpoints(grid_pinpoints, patch_size)
    if isinstance(grid_pinpoints, list):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size


class MLCDSegMetaModel:
    """Mixin that attaches the vision tower, resampler, mm projector, SAM and the
    text->SAM projection to the language backbone, loading their weights eagerly
    from the checkpoint's safetensors shards."""

    def __init__(self, config):
        super(MLCDSegMetaModel, self).__init__(config)

        if hasattr(config, "vision_tower_config"):
            vision_tower_weight, sam_weight, projector_weight, text2sam_projection_weight = self.dispatch_weight(config)
            delay_load = getattr(config, "delay_load", False)
            self.vision_tower = build_vision_tower(config, delay_load=delay_load)
            self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
            self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)
            self.vision_tower.vision_tower.load_state_dict(vision_tower_weight)
            self.mm_projector.load_state_dict(projector_weight)
            self.sam = build_sam_vit_h()
            self.sam.load_state_dict(sam_weight)
            self.text2sam_projection = text2sam_projection_layer(config)
            self.text2sam_projection.load_state_dict(text2sam_projection_weight)

            if "unpad" in getattr(config, "mm_patch_merge_type", ""):
                # Learned separator appended after each row of unpadded patch features.
                self.image_newline = nn.Parameter(torch.empty(config.hidden_size, dtype=self.dtype))

    def dispatch_weight(self, config):
        """Load the sub-module weights (vision tower, SAM, projector, text2sam) from the checkpoint.

        Reads ``model.safetensors.index.json`` from the repo named by
        ``config._name_or_path`` (falling back to ``DeepGlint-AI/MLCD-Seg``),
        loads only the shards that contain sub-module weights, and strips the
        sub-module prefix from each key.

        Returns:
            tuple[dict, dict, dict, dict]: state dicts for the vision tower, SAM,
            mm projector and text2sam projection, in that order.
        """
        # An empty ``_name_or_path`` is not a usable repo id either, so fall back on any falsy value.
        repo = getattr(config, "_name_or_path", None) or "DeepGlint-AI/MLCD-Seg"
        state_dicts = {
            "model.vision_tower.vision_tower.": {},
            "model.sam.": {},
            "model.mm_projector.": {},
            "model.text2sam_projection.": {},
        }
        wanted_prefixes = tuple(state_dicts)

        index_file = cached_file(repo, "model.safetensors.index.json")
        with open(index_file, "r", encoding="utf-8") as safetensors_index:
            safetensors_map = json.load(safetensors_index)

        # Only the shards that actually hold sub-module tensors need to be read.
        shard_files = {
            shard for key, shard in safetensors_map["weight_map"].items()
            if key.startswith(wanted_prefixes)
        }
        for shard_file in sorted(shard_files):
            shard = safetensors_load(cached_file(repo, shard_file))
            for key, value in shard.items():
                for prefix, state_dict in state_dicts.items():
                    if key.startswith(prefix):
                        state_dict[key[len(prefix):]] = value

        return (
            state_dicts["model.vision_tower.vision_tower."],
            state_dicts["model.sam."],
            state_dicts["model.mm_projector."],
            state_dicts["model.text2sam_projection."],
        )

    def get_vision_tower(self):
        """Return the vision tower, unwrapping it if it is stored inside a list."""
        vision_tower = getattr(self, "vision_tower", None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower


def unpad_image(tensor, original_size):
    """
    Unpads a PyTorch tensor of a padded and resized image.

    Args:
        tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
        original_size (tuple): The original size of the image (width, height).

    Returns:
        torch.Tensor: The unpadded image tensor.
    """
    original_width, original_height = original_size
    current_height, current_width = tensor.shape[1:]

    # Compute aspect ratios
    original_aspect_ratio = original_width / original_height
    current_aspect_ratio = current_width / current_height

    # Determine padding size and direction
    if original_aspect_ratio > current_aspect_ratio:
        # Padding was added to the height
        scale_factor = current_width / original_width
        new_height = int(original_height * scale_factor)
        padding = (current_height - new_height) // 2
        unpadded_tensor = tensor[:, padding : current_height - padding, :]
    else:
        # Padding was added to the width
        scale_factor = current_height / original_height
        new_width = int(original_width * scale_factor)
        padding = (current_width - new_width) // 2
        unpadded_tensor = tensor[:, :, padding : current_width - padding]

    return unpadded_tensor


def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    The resized image is centred on a black RGB canvas of the target size.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    # Determine which dimension (width or height) to fill
    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        # Width will be filled completely
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        # Height will be filled completely
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    # Create a new image with the target size and paste the resized image onto it
    new_image = Image.new("RGB", (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size, row by row.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    The output stacks a square-resized copy of the whole image first, followed by
    the patches of the image resized/padded to the best-fitting grid resolution.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str | list): Candidate resolutions (list, list literal, or "(AxB),...,(CxD)" range string).

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    # Convert grid_pinpoints from string to list
    if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
        try:
            patch_size = processor.size[0]
        except (KeyError, TypeError, IndexError):
            # Dict-like ``size`` (e.g. {"shortest_edge": ...}) cannot be indexed with 0.
            patch_size = processor.size["shortest_edge"]
        grid_pinpoints = _expand_grid_pinpoints(grid_pinpoints, patch_size)

    if isinstance(grid_pinpoints, list):
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size["height"])

    # FIXME: this seems to be a bug that it resizes instead of pad.
    # but to keep it consistent with previous, i will keep it as it is
    # TODO: uncomment below to ablate with the padding
    if isinstance(processor.size, dict):
        shortest_edge = processor.size["shortest_edge"]
    else:
        shortest_edge = min(processor.size)
    image_original_resize = image.resize((shortest_edge, shortest_edge))
    # image_padded_square = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
    # image_original_resize = image_padded_square.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors="pt")["pixel_values"][0] for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)


class MLCDSegMetaForCausalLM(ABC):
    """Mixin providing image encoding and image-token splicing for a causal LM.

    Subclasses supply ``get_model()``; the returned model is expected to expose
    ``get_vision_tower()``, ``mm_projector`` and ``embed_tokens``.
    """

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        """Return the vision tower of the wrapped model."""
        return self.get_model().get_vision_tower()

    def get_2dPool(self, image_feature, stride=2):
        """Spatially downsample per-frame patch features.

        ``image_feature`` is viewed as (num_frames, side, side, dim) with
        ``side = vision_tower.num_patches_per_side``, pooled by ``stride`` using
        ``config.mm_spatial_pool_mode`` ("average", "max" or "bilinear"), and
        flattened back to (num_frames, tokens, dim).

        Raises:
            ValueError: for an unknown ``mm_spatial_pool_mode``.
        """
        height = width = self.get_vision_tower().num_patches_per_side
        num_frames, num_tokens, num_dim = image_feature.shape
        image_feature = image_feature.view(num_frames, height, width, -1)
        image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
        if self.config.mm_spatial_pool_mode == "average":
            image_feature = nn.functional.avg_pool2d(image_feature, stride)
        elif self.config.mm_spatial_pool_mode == "max":
            image_feature = nn.functional.max_pool2d(image_feature, stride)
        elif self.config.mm_spatial_pool_mode == "bilinear":
            height, width = image_feature.shape[2:]
            scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
            image_feature = nn.functional.interpolate(image_feature, size=scaled_shape, mode='bilinear')
        else:
            raise ValueError(f"Unexpected mm_spatial_pool_mode: {self.config.mm_spatial_pool_mode}")
        image_feature = image_feature.permute(0, 2, 3, 1)
        image_feature = image_feature.view(num_frames, -1, num_dim)
        return image_feature

    def encode_images(self, images):
        """Run images through the vision tower and the multimodal projector."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def encode_multimodals(self, videos_or_images, video_idx_in_batch, split_sizes=None):
        """Encode a concatenated batch of videos/images and split it back per sample.

        Samples whose index is in ``video_idx_in_batch`` are spatially pooled with
        ``config.mm_spatial_pool_stride`` (when > 1); if ``config.add_faster_video``
        is set, an extra feature pooled with twice that stride is produced too.

        Args:
            videos_or_images: Concatenated frames/images for the whole batch.
            video_idx_in_batch: Indices of the samples that are videos.
            split_sizes: Number of frames per sample; required by ``torch.split``.

        Returns:
            tuple[list, list]: per-sample (possibly pooled) features, and per-sample
            "faster" features (``0`` where none was computed).
        """
        videos_or_images_features = self.get_model().get_vision_tower()(videos_or_images)
        per_videos_or_images_features = torch.split(videos_or_images_features, split_sizes, dim=0)
        all_videos_or_images_features = []
        all_faster_video_features = []
        cur_mm_spatial_pool_stride = self.config.mm_spatial_pool_stride

        for idx, feat in enumerate(per_videos_or_images_features):
            feat = self.get_model().mm_projector(feat)
            faster_video_feature = 0
            # ``None`` sentinel: comparing a tensor with ``!= 0`` is ambiguous and raises.
            slower_img_feat = None
            if idx in video_idx_in_batch and cur_mm_spatial_pool_stride > 1:
                slower_img_feat = self.get_2dPool(feat, cur_mm_spatial_pool_stride)
                if self.config.add_faster_video:
                    # Use a per-sample stride; it must not compound across videos.
                    faster_video_feature = self.get_2dPool(feat, cur_mm_spatial_pool_stride * 2)
            all_videos_or_images_features.append(feat if slower_img_feat is None else slower_img_feat)
            all_faster_video_features.append(faster_video_feature)
        return all_videos_or_images_features, all_faster_video_features

    def prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, modalities=["image"], image_sizes=None):
        """Splice encoded image features into the token embeddings at ``IMAGE_TOKEN_INDEX`` positions.

        Args:
            input_ids: Token ids, possibly padded; image placeholders are ``IMAGE_TOKEN_INDEX``.
            position_ids, attention_mask, labels: Optional tensors aligned with ``input_ids``.
            past_key_values: Passed through unchanged.
            images: A tensor (4-D or 5-D) or a list of per-sample image tensors.
            modalities: Per-sample modality names; accepted for API compatibility.
            image_sizes: Per-sample original image sizes (width, height), used by anyres merging.

        Returns:
            If there is no vision tower, no images, or ``input_ids.shape[1] == 1``:
            the 6-tuple ``(input_ids, position_ids, attention_mask, past_key_values, None, labels)``.
            Otherwise the 9-tuple ``(None, position_ids, attention_mask, past_key_values,
            inputs_embeds, labels, old_attention_mask, img_token_num, num_images_batch)``,
            where ``img_token_num[i]`` is the number of image-embedding rows inserted
            for sample ``i`` and ``num_images_batch[i]`` its number of image placeholders.
        """
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if isinstance(modalities, str):
            modalities = [modalities]

        if isinstance(images, list) or images.ndim == 5:
            if isinstance(images, list):
                images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]

            images_list = []
            for image in images:
                if image.ndim == 4:
                    images_list.append(image)
                else:
                    images_list.append(image.unsqueeze(0))

            concat_images = torch.cat(images_list, dim=0)
            split_sizes = [image.shape[0] for image in images_list]
            encoded_image_features = self.encode_images(concat_images)
            # One entry per sample, each [num_images_or_patches, tokens, dim].
            image_features = list(torch.split(encoded_image_features, split_sizes))

            mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
            image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")

            if mm_patch_merge_type == "flat":
                image_features = [x.flatten(0, 1) for x in image_features]

            elif mm_patch_merge_type.startswith("spatial"):
                new_image_features = []
                for image_idx, image_feature in enumerate(image_features):
                    # FIXME: now assume the image is square, and split to 2x2 patches
                    # num_patches = h * w, where h = w = sqrt(num_patches)
                    # currently image_feature is a tensor of shape (4, num_patches, hidden_size)
                    # we want to first unflatten it to (2, 2, h, w, hidden_size)
                    if image_feature.shape[0] > 1:  # multi patches and multi images operations
                        base_image_feature = image_feature[0]
                        image_feature = image_feature[1:]
                        height = width = self.get_vision_tower().num_patches_per_side
                        assert height * width == base_image_feature.shape[0]

                        if "anyres_max" in image_aspect_ratio:
                            matched_anyres_max_num_patches = re.match(r"anyres_max_(\d+)", image_aspect_ratio)
                            if matched_anyres_max_num_patches:
                                max_num_patches = int(matched_anyres_max_num_patches.group(1))

                        if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
                            if hasattr(self.get_vision_tower(), "image_size"):
                                vision_tower_image_size = self.get_vision_tower().image_size
                            else:
                                raise ValueError("vision_tower_image_size is not found in the vision tower.")
                            try:
                                num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, vision_tower_image_size)
                            except Exception:
                                # Deliberate best-effort fallback to a 2x2 grid.
                                num_patch_width, num_patch_height = 2, 2
                            image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
                        else:
                            image_feature = image_feature.view(2, 2, height, width, -1)

                        if "maxpool2x2" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = nn.functional.max_pool2d(image_feature, 2)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches:
                            unit = image_feature.shape[2]
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            c, h, w = image_feature.shape
                            # Downscale when the unpadded grid exceeds the max patch budget by >10%.
                            times = math.sqrt(h * w / (max_num_patches * unit**2))
                            if times > 1.1:
                                image_feature = image_feature[None]
                                image_feature = nn.functional.interpolate(image_feature, [int(h // times), int(w // times)], mode="bilinear")[0]
                            image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        elif "unpad" in mm_patch_merge_type:
                            image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
                            image_feature = image_feature.flatten(1, 2).flatten(2, 3)
                            image_feature = unpad_image(image_feature, image_sizes[image_idx])
                            image_feature = torch.cat((image_feature, self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)), dim=-1)
                            image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                        else:
                            image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
                            image_feature = image_feature.flatten(0, 3)
                        if "nobase" not in mm_patch_merge_type:
                            image_feature = torch.cat((base_image_feature, image_feature), dim=0)
                        new_image_features.append(image_feature)
                    else:  # single image operations
                        image_feature = image_feature[0]
                        if "unpad" in mm_patch_merge_type:
                            image_feature = torch.cat((image_feature, self.model.image_newline[None]), dim=0)

                        new_image_features.append(image_feature)
                image_features = new_image_features
            else:
                raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
        else:
            image_features = self.encode_images(images)

        # TODO: image start / end is not implemented here to support pretraining.
        if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(self.config, "mm_use_im_start_end", False):
            raise NotImplementedError

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- FIXME
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
        old_attention_mask = attention_mask.clone().detach()

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        img_token_num = [0 for _ in range(len(input_ids))]
        num_images_batch = []
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            num_images_batch.append(num_images)
            if num_images == 0:
                # Concatenate an empty slice so the image features still take part in the graph.
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    try:
                        cur_image_features = image_features[cur_image_idx]
                    except IndexError:
                        # More placeholders than encoded images: reuse the previous image's features.
                        cur_image_features = image_features[cur_image_idx - 1]
                    # Count the rows actually inserted (re-indexing here would raise again).
                    img_token_num[batch_idx] += cur_image_features.shape[0]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer.
        # Do not zip with ``modalities`` here: its default has a single entry and zip
        # would silently drop every sample after the first.
        tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
        new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, "tokenizer_padding_side", "right") == "left":
                new_input_embeds_padded.append(torch.cat((torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), cur_new_embed), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        new_labels = None if _labels is None else new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None
        if getattr(self.config, "use_pos_skipping", False) and self.training:
            # Randomly shift position ids on either side of a split point (training-time augmentation).
            position_ids = torch.arange(new_input_embeds.size(1), device=new_input_embeds.device).unsqueeze(0).to(new_input_embeds.device)
            split_position = random.randint(0, new_input_embeds.size(1))
            left_add = random.randint(0, self.config.pos_skipping_range)
            right_add = random.randint(left_add, self.config.pos_skipping_range)
            position_ids[:, :split_position] += left_add
            position_ids[:, split_position:] += right_add
        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels, old_attention_mask, img_token_num, num_images_batch


class MLCDSegConfig(Qwen2Config):
    """Qwen2 configuration registered under the ``mlcd_seg`` model type."""
    model_type = "mlcd_seg"


class MLCDSegModel(MLCDSegMetaModel, Qwen2Model):
    """Qwen2 backbone extended with the MLCD-Seg vision / SAM sub-modules."""
    config_class = MLCDSegConfig

    def __init__(self, config: Qwen2Config):
        super(MLCDSegModel, self).__init__(config)


@dataclass
class MLCDSegOutputWithPast(CausalLMOutputWithPast):
    """``CausalLMOutputWithPast`` with an additional ``labels`` field."""
    labels: Optional[torch.FloatTensor] = None


class MLCDSegForCausalLM(Qwen2ForCausalLM, MLCDSegMetaForCausalLM):
    """Causal LM head over ``MLCDSegModel`` with text generation and SAM-based mask prediction."""
    config_class = MLCDSegConfig

    def __init__(self, config):
        # super(Qwen2ForCausalLM, self).__init__(config)
        Qwen2ForCausalLM.__init__(self, config)
        config.model_type = "mlcd_seg_clm"
        config.rope_scaling = None

        # Replace the plain Qwen2 backbone built above with the multimodal one.
        self.model = MLCDSegModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()
        self.sam_transform = ResizeLongestSide(IMG_SIZE)

    def get_model(self):
        """Return the multimodal backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        modalities: Optional[List[str]] = ["image"],
        dpo_forward: Optional[bool] = False,
        cache_position=None,
        grounding_enc_imgs: Optional[List[torch.FloatTensor]] = None,
        image_sam_resizes: Optional[List[torch.FloatTensor]] = None,
        original_sizes: Optional[List[torch.FloatTensor]] = None,
        masks_list: Optional[List[List[torch.FloatTensor]]] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Language-model forward pass; images are spliced in when ``inputs_embeds`` is not given.

        Hidden states are always requested from the backbone. The segmentation
        arguments are accepted for signature compatibility and are unused here.
        """
        if inputs_embeds is None:
            # The preparation step returns 9 values when images are spliced in and
            # 6 on its early-return path; only the first 6 are needed here.
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                modalities,
                image_sizes,
            )[:6]

        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        return MLCDSegOutputWithPast(**output)

    def seg_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        modalities: Optional[List[str]] = ["image"],
        dpo_forward: Optional[bool] = False,
        cache_position=None,
        grounding_enc_imgs: Optional[List[torch.FloatTensor]] = None,
        image_sam_resizes: Optional[List[torch.FloatTensor]] = None,
        original_sizes: Optional[List[torch.FloatTensor]] = None,
        masks_list: Optional[List[List[torch.FloatTensor]]] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the LM, collect [SEG] hidden states and decode them into masks with SAM.

        Returns ``(logits, labels)`` when ``dpo_forward`` is set, otherwise a list
        with one mask tensor per batch element. The [SEG] bookkeeping relies on
        the 9-tuple from ``prepare_inputs_labels_for_multimodal``, so images must be
        provided and ``inputs_embeds`` must be ``None``.
        """
        input_ids_ = input_ids
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                old_attention_mask,
                img_token_num,
                num_images_batch,
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                modalities,
                image_sizes,
            )

        if dpo_forward:
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
            return logits, labels

        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        sam_image_embeddings = self.get_grounding_encoder_embs(grounding_enc_imgs)
        seg_token_mask = self.create_seg_token_mask(input_ids_, old_attention_mask, img_token_num, num_images_batch)
        seg_text_embeds_batch = self.process_hidden_states(output["hidden_states"], seg_token_mask)
        pred_masks_batch = self.generate_and_postprocess_masks(seg_text_embeds_batch, sam_image_embeddings, num_images_batch, image_sam_resizes, original_sizes)
        return pred_masks_batch

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = ["image"],
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Generate text conditioned on images by passing spliced ``inputs_embeds`` to HF generate.

        ``max_length`` defaults to 4096 but may be overridden through ``kwargs``.

        Returns:
            torch.LongTensor: the generated sequences.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is passed in ``kwargs``.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        (
            inputs,
            position_ids,
            attention_mask,
            _,
            inputs_embeds,
            _,
            old_attention_mask,
            img_token_num,
            num_images_batch,
        ) = self.prepare_inputs_labels_for_multimodal(
            inputs,
            position_ids,
            attention_mask,
            None,
            None,
            images,
            image_sizes=image_sizes,
        )

        # Keep the historical default without clashing with a caller-supplied max_length.
        kwargs.setdefault("max_length", 4096)
        llm_out = super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_hidden_states=True,
            return_dict_in_generate=True,
            **kwargs,
        )
        return llm_out.sequences

    def generate_and_postprocess_masks(self, seg_text_embeds_batch, sam_image_embeddings, num_images_batch, image_sam_resizes, original_sizes):
        """Decode [SEG] embeddings into masks with SAM's prompt encoder and mask decoder.

        For batch element ``b`` with ``n = max(1, num_images_batch[b])`` images, the
        [SEG] embeddings are assigned to images round-robin (``img_i::n``). Each
        image's low-resolution masks are post-processed to ``original_sizes[b]``
        and concatenated along dim 0.

        Returns:
            list[torch.Tensor]: one tensor per batch element, shaped
            (num_masks, original_sizes[b][0], original_sizes[b][1]).
        """
        assert len(seg_text_embeds_batch) == len(num_images_batch)

        pred_masks_batch = []
        for batch_i, seg_text_embeds in enumerate(seg_text_embeds_batch):
            num_img = max(1, num_images_batch[batch_i])

            pred_mask_ = torch.empty((0, original_sizes[batch_i][0], original_sizes[batch_i][1]), device=seg_text_embeds.device)
            for img_i in range(num_img):
                sparse_embeddings, dense_embeddings = self.model.sam.prompt_encoder(
                    points=None,
                    boxes=None,
                    masks=None,
                    text_embeds=seg_text_embeds.unsqueeze(1)[img_i::num_img, :, :],
                )
                sparse_embeddings = sparse_embeddings.to(seg_text_embeds.dtype)

                low_res_masks, _ = self.model.sam.mask_decoder(
                    image_embeddings=sam_image_embeddings[batch_i][img_i].unsqueeze(0),
                    image_pe=self.model.sam.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                )
                pred_mask = self.model.sam.postprocess_masks(
                    low_res_masks,
                    input_size=image_sam_resizes[batch_i][img_i],
                    original_size=original_sizes[batch_i],
                )
                pred_mask_ = torch.cat([pred_mask_, pred_mask[:, 0]], dim=0)
            pred_masks_batch.append(pred_mask_)
        return pred_masks_batch

    def process_hidden_states(self, output_hidden_states, seg_token_mask):
        """Project the last hidden layer into SAM's prompt space and gather the [SEG] positions.

        Returns:
            list[torch.Tensor]: per batch element, the projected rows where
            ``seg_token_mask`` is true (mask truncated to the hidden-state length).
        """
        hidden_states_ = self.model.text2sam_projection(output_hidden_states[-1])
        seg_text_embeds_batch = []
        for i, hidden_state_ in enumerate(hidden_states_):
            seg_text_embeds_batch.append(hidden_state_[seg_token_mask[i][:hidden_state_.shape[0]]])
        return seg_text_embeds_batch

    def create_seg_token_mask(self, input_ids, attention_mask, img_token_num, num_images_batch):
        """Build a boolean mask marking [SEG] positions in the image-spliced sequence.

        For each sample the un-padded ``input_ids`` are taken, the first
        ``num_images_batch[i]`` tokens are dropped and the rest compared with
        ``self.seg_token_idx`` (set outside this chunk). ``img_token_num[i]`` False
        entries are prepended for the inserted image rows, and False padding is
        appended up to the longest sample. NOTE(review): this alignment assumes the
        image placeholders are the first tokens of each prompt — confirm.

        Returns:
            torch.Tensor: bool tensor of shape (batch, max_len).
        """
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        max_len = 0
        for i, _ in enumerate(input_ids):
            max_len = max(max_len, len(input_ids[i]) + img_token_num[i] - num_images_batch[i])

        seg_token_mask = []
        for i, _ in enumerate(input_ids):
            mask = input_ids[i][num_images_batch[i]:] == self.seg_token_idx
            seg_token_mask.append(
                torch.cat(
                    [
                        torch.zeros((1, img_token_num[i])).bool().to(device=self.device),
                        mask.unsqueeze(0),
                        torch.zeros((1, max_len - (len(input_ids[i]) + img_token_num[i] - num_images_batch[i]))).bool().to(device=self.device),
                    ],
                    dim=1,
                )
            )
        return torch.cat(seg_token_mask, dim=0)

    def get_grounding_encoder_embs(self, batch_images: torch.FloatTensor):
        """Encode every grounding image with SAM's image encoder; returns one tensor per batch element."""
        batch_feats = []
        for images in batch_images:
            batch_feats.append(torch.cat([self._encode_single_image(img) for img in images], dim=0))
        return batch_feats

    def _encode_single_image(self, image):
        """Run SAM's image encoder on a single image (a batch dim is added)."""
        return self.model.sam.image_encoder(image.unsqueeze(0))

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Forward ``images`` / ``image_sizes`` from ``kwargs`` into the per-step model inputs."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs)
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs

    def process_prompt(self, text, tokenizer: PreTrainedTokenizer, stage="gen") -> Dict:
        if stage.lower() not in ["gen", "seg"]:
            stage = "seg"
        if stage.lower() == "gen":
            conv = conv_templates['qwen_2'].copy()
            conv.append_message(conv.roles[0], text)
            conv.append_message(conv.roles[1], None)
            full_prompt = conv.get_prompt()
            input_ids = torch.stack([gen_image_token(full_prompt, tokenizer, return_tensors='pt')], dim=0)
            return dict(
                input_ids=input_ids,
                labels=None,
            )
        else:
            conv = default_conversation.copy()
            BEGIN_SIGNAL = "### "
END_SIGNAL = "\n" roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # Apply prompt templates sys_prompt = default_conversation.system + "\n\n" + "The provides an overview of the picture.\n" full_prompt = sys_prompt + BEGIN_SIGNAL + roles["human"] + ": " + text + END_SIGNAL full_prompt += BEGIN_SIGNAL + roles["gpt"] + ": It is [SEG]." + END_SIGNAL full_prompt += BEGIN_SIGNAL input_ids = torch.stack([gen_image_token(full_prompt, tokenizer, return_tensors='pt')], dim=0) return dict( input_ids=input_ids, labels=None, ) def process_images(self, images, image_processor, model_cfg): image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) new_images = [] if image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio: for image in images: image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) new_images.append(image) else: return image_processor.preprocess(images, return_tensors="pt")["pixel_values"] if all(x.shape == new_images[0].shape for x in new_images): new_images = torch.stack(new_images, dim=0) return new_images def seg(self, image, prompt, tokenizer, force_seg = False): self.seg_token_idx = tokenizer(DEFAULT_SEG_TOKEN, add_special_tokens=False).input_ids[0] image_np = np.array(image) image_sizes = [image.size] input_ids = self.process_prompt(prompt, tokenizer, "gen")["input_ids"].to(self.device) image_processor = self.get_vision_tower().image_processor image_tensors = self.process_images([image], image_processor, self.config) image_np_resize = self.sam_transform.apply_image(image_np) original_size_list = [image_np.shape[:2]] image_sam_resize_list = [image_np_resize.shape[:2]] grounding_enc_img_list = [grounding_enc_processor(torch.from_numpy(image_np_resize).permute(2, 0, 1).contiguous()).to(dtype=self.dtype, device=self.device, non_blocking=True)] collect_size = list(set(original_size_list)) if len(collect_size) == 0: mask_h, mask_w = 336, 336 elif len(collect_size) == 1: mask_h, mask_w = collect_size[0] 
else: areas = [h*w for (h, w) in collect_size] mask_h, mask_w = collect_size[areas.index(max(areas))] if isinstance(image_tensors, list): image_aspect_ratio = getattr(self.config, "image_aspect_ratio", None) if image_aspect_ratio=="anyres_mul" or image_aspect_ratio=="anyres": image_tensors = [[x_.to(dtype=self.dtype, device=self.device, non_blocking=True)for x_ in image_tensors]] else: image_tensors = [[x_.unsqueeze(dim=0).to(dtype=self.dtype, device=self.device, non_blocking=True) for x_ in image_tensors]] else: image_tensors = image_tensors.to(dtype=self.dtype, device='cuda', non_blocking=True) if not force_seg: attention_mask = torch.ones(input_ids.shape).bool().to(device=self.device) with torch.inference_mode(): llm_gen = self.generate( inputs=input_ids, attention_mask=attention_mask, images=image_tensors, image_sizes=image_sizes, grounding_enc_imgs=[torch.stack(grounding_enc_img_list, dim=0)], image_sam_resizes=[image_sam_resize_list], original_sizes=[(mask_h, mask_w)], pad_token_id=tokenizer.eos_token_id ) seg_flag = llm_gen == self.seg_token_idx seg_flag = torch.sum(seg_flag.int()).item() if seg_flag > 0: force_seg = True if force_seg: input_ids = self.process_prompt(prompt, tokenizer, "seg")["input_ids"].to(self.device) with torch.inference_mode(): net_out = self.seg_forward( input_ids=input_ids, output_hidden_states=True, images=image_tensors, image_sizes=image_sizes, grounding_enc_imgs=[torch.stack(grounding_enc_img_list, dim=0)], image_sam_resizes=[image_sam_resize_list], original_sizes=[(mask_h, mask_w)], ) pred_mask = net_out[0] mask_tensor = (pred_mask > 0).int() return mask_tensor else: return torch.zeros([0] + list(image_np.shape[:2]), device=self.device) def gen_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("")] def insert_separator(X, sep): return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] input_ids = [] offset 
= 0 if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: offset = 1 input_ids.append(prompt_chunks[0][0]) for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): input_ids.extend(x[offset:]) if return_tensors is not None: if return_tensors == "pt": return torch.tensor(input_ids, dtype=torch.long) raise ValueError(f"Unsupported tensor type: {return_tensors}") return input_ids def grounding_enc_processor(x: torch.Tensor) -> torch.Tensor: x = (x - IMG_MEAN) / IMG_STD h, w = x.shape[-2:] x = F.pad(x, (0, IMG_SIZE - w, 0, IMG_SIZE - h)) return x AutoConfig.register("mlcd_seg", MLCDSegConfig) AutoModelForCausalLM.register(MLCDSegConfig, MLCDSegForCausalLM)