"""Processor for VILA-style multimodal models: turns image/text conversations into model inputs."""

import copy
import os
import os.path as osp
import warnings
from collections import defaultdict
from typing import List, Union

import torch
from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoProcessor, AutoTokenizer
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, VideoInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

from .constants import DEFAULT_IMAGE_TOKEN, MEDIA_TOKENS
from .media import Image, Video, extract_media
from .mm_utils import process_image, process_images
from .tokenizer_utils import tokenize_conversation

logger = logging.get_logger(__name__)


def fetch_image_url_or_fpath(url_or_fpath):
    """Resolve an image reference to a local file path.

    Args:
        url_or_fpath: An ``http://`` / ``https://`` URL, a ``file://`` URI, or an existing local path.

    Returns:
        str: Path of a local file. URLs are downloaded into a fresh directory created with
        ``tempfile.mkdtemp``; that directory is not removed here.

    Raises:
        requests.HTTPError: If the download returns an error status.
        AssertionError: If a ``file://`` target does not exist, or a local path is not a regular file.
        ValueError: If the reference matches none of the supported forms.
    """
    # Match the full scheme so that local paths merely starting with "http" are not treated as URLs.
    if url_or_fpath.startswith(("http://", "https://")):
        import tempfile
        from urllib.parse import urlparse

        import requests

        # Name the file after the URL *path* (dropping any query string); fall back to a fixed
        # name when the path has no basename (e.g. it ends with "/").
        fname = osp.basename(urlparse(url_or_fpath).path) or "image"
        temp_file = osp.join(tempfile.mkdtemp(), fname)
        # The context manager releases the streamed connection even if writing fails.
        with requests.get(url_or_fpath, stream=True) as response:
            response.raise_for_status()
            with open(temp_file, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return temp_file
    elif url_or_fpath.startswith("file://"):
        # Strip only the leading scheme; str.replace would also remove later occurrences.
        fpath = url_or_fpath[len("file://") :]
        assert osp.exists(fpath), f"File {fpath} does not exist"
        return fpath
    elif osp.exists(url_or_fpath):
        assert osp.isfile(url_or_fpath), f"File {url_or_fpath} is not a file"
        return url_or_fpath
    else:
        raise ValueError(f"Unsupported image path: {url_or_fpath}")


def _pad_fn(input_ids_list, padding_value=0, target_len=None, padding_side="left"):
    """Pad 2-D id tensors to a common length and concatenate them along dim 0.

    Args:
        input_ids_list: Non-empty sequence of tensors shaped ``(batch_size, seq_len)``; seq_len may differ.
        padding_value: Fill value for the padded positions.
        target_len: Optional final length; must be >= the longest seq_len.
        padding_side: ``"right"`` appends padding; any other value prepends it (left padding).

    Returns:
        torch.Tensor: Shape ``(sum of batch sizes, max_len)``, where max_len is ``target_len`` if given.

    Raises:
        ValueError: If ``input_ids_list`` is empty.
        AssertionError: If ``target_len`` is shorter than the longest input.
    """
    if len(input_ids_list) == 0:
        raise ValueError("input_ids_list must contain at least one tensor")
    max_len = max(ids.shape[1] for ids in input_ids_list)
    if target_len is not None:
        assert target_len >= max_len, "target_len must be greater than or equal to max_len"
        max_len = target_len

    padded = []
    for input_ids in input_ids_list:
        pad_len = max_len - input_ids.shape[1]
        # Build the pad explicitly: slicing a ones_like(input_ids) copy cannot produce
        # more than seq_len columns, which broke whenever pad_len > seq_len.
        pad_tensor = input_ids.new_full((input_ids.shape[0], pad_len), padding_value)
        if padding_side == "right":
            input_ids = torch.cat((input_ids, pad_tensor), dim=1)
        else:
            input_ids = torch.cat((pad_tensor, input_ids), dim=1)
        padded.append(input_ids)
    return torch.cat(padded, dim=0)


# Backward-compatible alias. Note that a double-underscore name cannot be called from inside a
# class body (it gets name-mangled to `_ClassName__pad_fn`), so class code must use `_pad_fn`.
__pad_fn = _pad_fn


class VILAProcessorKwargs(ProcessingKwargs, total=False):
    # Text defaults: no padding unless the caller asks for it.
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
    }


class VILAProcessor(ProcessorMixin):
    """Convert VILA-style conversations into model inputs (token ids plus preprocessed media).

    The image processor, tokenizer and model config are stored directly on the instance rather
    than registered through ``ProcessorMixin.attributes``, which is intentionally left empty.
    """

    attributes = []
    valid_kwargs = []

    def __init__(self, image_processor=None, tokenizer=None, chat_template=None, config=None, **kwargs):
        self.image_token = MEDIA_TOKENS["image"]
        self.video_token = MEDIA_TOKENS["video"]
        self.config = config
        self.image_processor = image_processor
        self.tokenizer = tokenizer

        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a processor from a local checkpoint directory or a Hugging Face Hub repo id.

        Non-directory inputs are downloaded with ``huggingface_hub.snapshot_download``. The image
        processor is loaded from ``<path>/vision_tower``, the tokenizer from ``<path>/llm`` and the
        config from ``<path>``, all with ``trust_remote_code=True``. Extra ``kwargs`` are ignored.
        """
        if not osp.isdir(pretrained_model_name_or_path):
            logger.info(f"pretrained_model_name_or_path {pretrained_model_name_or_path} is not a directory, downloading")
            from huggingface_hub import snapshot_download

            pretrained_model_name_or_path = snapshot_download(pretrained_model_name_or_path)

        image_processor = AutoImageProcessor.from_pretrained(
            osp.join(pretrained_model_name_or_path, "vision_tower"), trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(osp.join(pretrained_model_name_or_path, "llm"), trust_remote_code=True)
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
        return cls(image_processor=image_processor, tokenizer=tokenizer, config=config)

    def __repr__(self):
        return (
            f"VILAProcessor(image_processor={self.image_processor}, tokenizer={self.tokenizer}, config={self.config})"
        )

    def __call__(
        self,
        conversation,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> BatchFeature:
        """Process a batch of VILA-style conversations.

        Args:
            conversation: A list of conversations; each one is processed by ``__single_call__``.
                Wrap a single conversation in a list.
            images: Not supported here (a warning is emitted); media must be embedded in the conversation.
            text, videos, kwargs: Forwarded to ``__single_call__``, which does not use them.

        Returns:
            BatchFeature with:
                - ``input_ids``: left-padded with ``tokenizer.pad_token_id`` and stacked along dim 0.
                - ``media``: per-type lists concatenated across conversations.
                - ``media_config``: per-type dicts; list values (e.g. ``block_sizes``) are concatenated
                  in conversation order so they stay aligned with ``media``.
        """
        if images is not None:
            warnings.warn("images is not supported in __call__")

        input_ids = []
        media = defaultdict(list)
        media_config = defaultdict(dict)
        for conv in conversation:
            feat = self.__single_call__(conv, images, text, videos, **kwargs)
            input_ids.append(feat.input_ids)
            for name in feat.media:
                media[name] += feat.media[name]
            for name, cfg in feat.media_config.items():
                for key, value in cfg.items():
                    existing = media_config[name].get(key)
                    if isinstance(existing, list) and isinstance(value, list):
                        # Concatenate instead of overwriting, otherwise earlier conversations'
                        # entries (e.g. block_sizes) would be lost in a batch.
                        existing.extend(value)
                    else:
                        media_config[name][key] = list(value) if isinstance(value, list) else value

        return BatchFeature(
            data={
                "input_ids": _pad_fn(
                    input_ids,
                    padding_value=self.tokenizer.pad_token_id,
                    padding_side="left",
                ),
                "media": media,
                "media_config": media_config,
            }
        )

    def __single_call__(
        self,
        conversation,
        images: ImageInput = None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        videos: VideoInput = None,
        **kwargs: Unpack[VILAProcessorKwargs],
    ) -> BatchFeature:
        """Process one VILA-style conversation (a list of turns).

        The conversation is deep-copied, its media extracted with ``extract_media`` and preprocessed
        to half precision. The result is tokenized with a generation prompt. ``images``, ``text``,
        ``videos`` and ``kwargs`` are not used as inputs.

        Side effects: may set ``self.config.image_processor`` and may convert a comma-separated
        ``self.config.s2_scales`` string into a list of ints.

        Returns:
            BatchFeature with:
                - ``input_ids``: shape ``(1, seq_len)``, moved to CUDA via ``.cuda()``.
                - ``media``: dict mapping media type to a list of tensors.
                - ``media_config``: dict mapping media type to its config (``block_sizes`` for dynamic_s2).

        Raises:
            ValueError: For media types other than ``image`` and ``video``.
        """
        # TODO: should be merged with llava_arch.py/generate_content()
        # TODO (extract and preprocess should be done together, as the preprocess of image and video can be
        # different, i.e. when dynamic res is used)
        conversation = copy.deepcopy(conversation)
        media = extract_media(conversation, self.config)

        media_config = defaultdict(dict)
        for name in media:
            if name == "image":
                if len(media["image"]) == 1 and self.config.image_aspect_ratio in ["dynamic", "dynamic_s2"]:
                    # NOTE(review): presumably process_image reads the image processor from the config.
                    self.config.image_processor = self.image_processor
                    if self.config.image_aspect_ratio == "dynamic":
                        images = process_image(media["image"][0], self.config, None, enable_dynamic_res=True).half()
                        # Repeat the image token once per entry along dim 0 of the processed image
                        # (presumably one per tile — confirm against process_image).
                        conversation[0]["value"] = conversation[0]["value"].replace(
                            DEFAULT_IMAGE_TOKEN, f"{DEFAULT_IMAGE_TOKEN}\n" * images.shape[0]
                        )
                    else:
                        if isinstance(self.config.s2_scales, str):
                            self.config.s2_scales = list(map(int, self.config.s2_scales.split(",")))
                        images, block_sizes = process_image(
                            media["image"][0], self.config, None, enable_dynamic_s2=True
                        )
                        images = images.half()
                        media_config[name]["block_sizes"] = [block_sizes]
                else:
                    images = process_images(media["image"], self.image_processor, self.config).half()
                media[name] = list(images)
            elif name == "video":
                media[name] = [
                    process_images(frames, self.image_processor, self.config).half() for frames in media[name]
                ]
            else:
                raise ValueError(f"Unsupported media type: {name}")

        input_ids = tokenize_conversation(conversation, self.tokenizer, add_generation_prompt=True).cuda().unsqueeze(0)
        return BatchFeature(data={"input_ids": input_ids, "media": media, "media_config": media_config})

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the wrapped tokenizer's [`~PreTrainedTokenizer.batch_decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to the wrapped tokenizer's [`~PreTrainedTokenizer.decode`].
        Please refer to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def post_process_image_text_to_text(self, generated_outputs):
        """
        Post-process the output of the model to decode the text.

        Args:
            generated_outputs (`torch.Tensor` or `np.ndarray`):
                The output of the model `generate` function. The output is expected to be a tensor of shape
                `(batch_size, sequence_length)` or `(sequence_length,)`.

        Returns:
            `List[str]`: The decoded text, with special tokens skipped.
        """
        return self.tokenizer.batch_decode(
            generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

    @property
    def model_input_names(self):
        """Tokenizer input names followed by image-processor input names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    def apply_chat_template(self, conversation, add_generation_prompt=True, **kwargs):
        """Convert GPT-style messages into a VILA-style conversation (no tokenization happens here).

        Args:
            conversation: List of ``{"role": ..., "content": [...]}`` messages. User content items may be
                ``{"type": "image", "path" | "image": ...}`` or ``{"type": "text", "text": ...}``;
                assistant content items must be text.
            add_generation_prompt, kwargs: Currently unused.

        Returns:
            list: Turns of the form ``{"from": "human" | "gpt", "value": [...]}``. Images are resolved
            with ``fetch_image_url_or_fpath`` and wrapped in ``media.Image``. Roles other than
            ``user``/``assistant`` yield a turn with empty ``from`` and ``value``.

        Raises:
            ValueError: For an image item without ``path``/``image``, or an unsupported user content type.
            AssertionError: For non-text assistant content.
        """
        vila_conv = []

        for chat in conversation:
            vila_chat = {"from": "", "value": []}
            if chat["role"] == "user":
                # user allows to input image and text
                vila_chat["from"] = "human"
                for content in chat["content"]:
                    if content["type"] == "image":
                        if "path" in content:
                            # VILA style
                            vila_chat["value"].append(Image(fetch_image_url_or_fpath(content["path"])))
                        elif "image" in content:
                            # Qwen style
                            vila_chat["value"].append(Image(fetch_image_url_or_fpath(content["image"])))
                        else:
                            raise ValueError(
                                f"Unsupported content type `image`: {content}, `image` or `path` is required"
                            )
                    elif content["type"] == "text":
                        vila_chat["value"].append(content["text"])
                    # NOTE(ligeng): video supports are needed here
                    else:
                        raise ValueError(f"Unsupported content type: {content['type']}")
            elif chat["role"] == "assistant":
                vila_chat["from"] = "gpt"
                for content in chat["content"]:
                    assert content["type"] == "text", f"Unsupported content type: {content['type']}"
                    vila_chat["value"].append(content["text"])
            vila_conv.append(vila_chat)

        return vila_conv


if __name__ == "__main__":
    # Manual smoke test: needs the model checkpoint, the demo image and a CUDA device.
    import PIL.Image

    # gpt style: user, assistant
    # vila style: human, gpt
    gpt_conv = [
        {
            "role": "user",
            "content": [
                {"type": "image", "path": "demo_images/demo_img_1.png"},
                {"type": "text", "text": "Describe this image."},
            ],
        }
    ]

    llavaconv = [
        {
            "from": "human",
            "value": [
                PIL.Image.open("demo_images/demo_img_1.png"),
                "Describe this image.",
            ],
        }
    ]

    model_path = "NVILA-Lite-2B-hf-preview"
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")

    processor = VILAProcessor(
        config=model.config,
        image_processor=model.vision_tower.image_processor,
        tokenizer=model.tokenizer,
    )

    # __call__ expects a batch of conversations, so wrap the single conversation in a list.
    inputs = processor(conversation=[llavaconv], padding=True, return_tensors="pt")
    print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.media["image"]])
    print("vila conv pass")

    # apply_chat_template only converts GPT-style messages to VILA style; the processor tokenizes.
    vila_conv = processor.apply_chat_template(conversation=gpt_conv)
    inputs = processor(conversation=[vila_conv], padding=True, return_tensors="pt")
    print(inputs.keys(), inputs.input_ids.shape, [_.shape for _ in inputs.media["image"]])
    print("gpt conv pass")

    output_ids = model.generate(
        input_ids=inputs.input_ids,
        media={
            "image": inputs.media["image"],
        },
        media_config={"image": inputs.media_config["image"]},
        generation_config=model.generation_config,
        max_new_tokens=100,
    )
    print(output_ids)