import json
import numbers
import os
from collections import OrderedDict
from typing import Union

import cv2
import numpy as np
import PIL
import torch
import torch.nn.functional as F
from PIL import Image

from transformers import AutoConfig, AutoModel, PretrainedConfig, SiglipImageProcessor, SiglipVisionConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.models.siglip.modeling_siglip import SiglipVisionModel


def crop_clip(clip, min_h, min_w, h, w):
    if isinstance(clip[0], np.ndarray):
        cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]

    elif isinstance(clip[0], PIL.Image.Image):
        cropped = [
            img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
        ]
    else:
        raise TypeError('Expected numpy.ndarray or PIL.Image '
                        'but got list of {0}'.format(type(clip[0])))
    return cropped


class Normalize(object):
    """Normalize a clip with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.
    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, clip):
        """
        Args:
            clip (Tensor): Tensor clip of size (C, T, H, W) to be normalized.
        Returns:
            Tensor: Normalized Tensor clip.
        """
        return normalize(clip, self.mean, self.std)

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


class CenterCrop(object):
    """Extract center crop at the same location for a list of images
    Args:
    size (sequence or int): Desired output size for the
    crop in format (h, w)
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (size, size)

        self.size = size

    def __call__(self, clip):
        """
        Args:
        img (PIL.Image or numpy.ndarray): List of images to be cropped
        in format (h, w, c) in numpy.ndarray
        Returns:
        PIL.Image or numpy.ndarray: Cropped list of images
        """
        h, w = self.size
        if isinstance(clip[0], np.ndarray):
            im_h, im_w, im_c = clip[0].shape
        elif isinstance(clip[0], PIL.Image.Image):
            im_w, im_h = clip[0].size
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image '
                            'but got list of {0}'.format(type(clip[0])))
        if w > im_w or h > im_h:
            error_msg = (
                'Initial image size should be larger than '
                'cropped size but got cropped size ({w}, {h}) while '
                'initial image is ({im_w}, {im_h})'.format(
                    im_w=im_w, im_h=im_h, w=w, h=h))
            raise ValueError(error_msg)

        x1 = int(round((im_w - w) / 2.))
        y1 = int(round((im_h - h) / 2.))
        cropped = crop_clip(clip, y1, x1, h, w)

        return cropped


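# Resize every frame of a clip. When `size` is an int, the shorter side is scaled to
# `size` and the aspect ratio is preserved (see get_resize_sizes below).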
def resize_clip(clip, size, interpolation='bilinear'):
    if isinstance(clip[0], np.ndarray):
        if isinstance(size, numbers.Number):
            im_h, im_w, im_c = clip[0].shape
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[0], size[1]
        if interpolation == 'bilinear':
            np_inter = cv2.INTER_LINEAR
        else:
            np_inter = cv2.INTER_NEAREST
        scaled = [
            cv2.resize(img, size, interpolation=np_inter) for img in clip
        ]
    elif isinstance(clip[0], PIL.Image.Image):
        if isinstance(size, numbers.Number):
            im_w, im_h = clip[0].size
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]
        if interpolation == 'bilinear':
            pil_inter = PIL.Image.BILINEAR
        else:
            pil_inter = PIL.Image.NEAREST
        scaled = [img.resize(size, pil_inter) for img in clip]
    else:
        raise TypeError('Expected numpy.ndarray or PIL.Image '
                        'but got list of {0}'.format(type(clip[0])))
    return scaled


def _is_tensor_clip(clip):
    return torch.is_tensor(clip) and clip.ndimension() == 4


def get_resize_sizes(im_h, im_w, size):
    if im_w < im_h:
        ow = size
        oh = int(size * im_h / im_w)
    else:
        oh = size
        ow = int(size * im_w / im_h)
    return oh, ow


def normalize(clip, mean, std, inplace=False):
    if not _is_tensor_clip(clip):
        raise TypeError('input clip is not a 4D torch tensor.')

    if not inplace:
        clip = clip.clone()

    dtype = clip.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])

    return clip


class Resize(object):
    """Resizes a list of (H x W x C) numpy.ndarray to the final size
    The larger the original image is, the more times it takes to
    interpolate
    Args:
    interpolation (str): Can be one of 'nearest', 'bilinear'
    defaults to nearest
    size (tuple): (widht, height)
    """

    def __init__(self, size, interpolation='nearest'):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, clip):
        resized = resize_clip(
            clip, self.size, interpolation=self.interpolation)
        return resized


class Compose(object):
    """Composes several transforms
    Args:
    transforms (list of ``Transform`` objects): list of transforms
    to compose
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, clip):
        for t in self.transforms:
            clip = t(clip)
        return clip


def convert_img(img):
    """Converts (H, W, C) numpy.ndarray to (C, W, H) format"""
    if len(img.shape) == 3:
        img = img.transpose(2, 0, 1)
    if len(img.shape) == 2:
        img = np.expand_dims(img, 0)
    return img


class ClipToTensor(object):
    """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
    to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
    """

    def __init__(self, channel_nb=3, div_255=True, numpy=False):
        self.channel_nb = channel_nb
        self.div_255 = div_255
        self.numpy = numpy

    def __call__(self, clip):
        """
        Args: clip (list of numpy.ndarray): clip (list of images)
        to be converted to tensor.
        """
        # Retrieve shape
        if isinstance(clip[0], np.ndarray):
            h, w, ch = clip[0].shape
            assert ch == self.channel_nb, "Got {0} instead of {1} channels".format(ch, self.channel_nb)
        elif isinstance(clip[0], Image.Image):
            w, h = clip[0].size
        else:
            raise TypeError('Expected numpy.ndarray or PIL.Image '
                            'but got list of {0}'.format(type(clip[0])))

        np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])

        # Convert
        for img_idx, img in enumerate(clip):
            if isinstance(img, np.ndarray):
                pass
            elif isinstance(img, Image.Image):
                img = np.array(img, copy=False)
            else:
                raise TypeError('Expected numpy.ndarray or PIL.Image '
                                'but got list of {0}'.format(type(img)))
            img = convert_img(img)
            np_clip[:, img_idx, :, :] = img
        if self.numpy:
            if self.div_255:
                np_clip = np_clip / 255.0
            return np_clip

        else:
            tensor_clip = torch.from_numpy(np_clip)

            if not isinstance(tensor_clip, torch.FloatTensor):
                tensor_clip = tensor_clip.float()
            if self.div_255:
                tensor_clip = torch.div(tensor_clip, 255)
            return tensor_clip
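
# Illustrative usage of the clip transforms above (mirroring the pipeline built in
# InternVideoTower below); `frames` is assumed to be a list of H x W x 3 RGB frames:
#
#   transform = Compose([
#       Resize(224, interpolation='bilinear'),
#       CenterCrop(size=(224, 224)),
#       ClipToTensor(),
#       Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
#   ])
#   clip_tensor = transform(frames)  # FloatTensor of shape (3, T, 224, 224)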


class VisionTowerConfig(PretrainedConfig):
    model_type = "vision_tower"

    def __init__(self, vision_tower_name: str = None, **kwargs):
        super().__init__(**kwargs)
        self.vision_tower_name = vision_tower_name


class ProcessorWrapper:
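    """Light wrapper that exposes either a clip-transform pipeline or a HuggingFace image
    processor behind a single `preprocess` / `save_pretrained` interface, together with the
    target size metadata expected by the vision towers.
    """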
    def __init__(self, transform=None, processor=None, height=378, width=378, frames_per_clip=1,
                 image_mean=[0.48145466, 0.4578275, 0.40821073]):
        assert transform is not None or processor is not None, "ERROR: neither `transform` nor `processor` was given! You must define one of them."
        assert transform is None or processor is None, "ERROR: both `transform` and `processor` were given! You must define only one of them."
        self._size = {
            "height": height,
            "width": width,
            "frames_per_clip": frames_per_clip
        }
        self._transforms = transform
        self._processor = processor
        self.image_mean = image_mean

    @property
    def size(self):
        return self._size

    def preprocess(self, image, return_tensors='pt'):
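        """Run the image/clip through the configured transform or HF processor and return a
        dict containing a 'pixel_values' entry.
        """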
        output = {}
        if self._transforms is not None:
            output['pixel_values'] = [self._transforms(image)]
        else:
            output = self._processor(image, return_tensors=return_tensors)
        return output

    def save_pretrained(self, save_path):
        if self._transforms is not None:
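            # NOTE: `transform_to_dict` is assumed to be provided elsewhere in the codebase;
            # it should serialize the transform pipeline into a JSON-serializable dict.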
            transform_dict = transform_to_dict(self._transforms)
            transform_dict["image_processor_type"] = "transforms"
            with open(os.path.join(save_path, 'preprocessor_config.json'), 'w') as f:
                json.dump(transform_dict, f, indent=4)
        else:
            self._processor.save_pretrained(save_path)
        return


class VisionTower(PreTrainedModel):
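    """Base wrapper around a pretrained vision encoder.

    Subclasses are expected to set `self.vision_tower`, the token-grid dimensions
    `self.W`, `self.H`, `self.T`, and `self.hidden_size`.
    """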
    config_class = VisionTowerConfig

    def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: VisionTowerConfig = None):
        super().__init__(vision_config)
        self.vision_tower_name = model_name_or_path
        self.vision_config = vision_config
        self.select_layer = getattr(config, "mm_vision_select_layer", -2)
        self.select_feature = getattr(config, "mm_vision_select_feature", "patch")
        self.encode_batch_size = getattr(config, "encode_batch_size", 0) // 2
        self.num_encode_batch = getattr(config, "num_encode_batch", 0) // 2
        self.temporal_tubelet_size = getattr(vision_config, "tubelet_size", 1)

    def feature_select(self, image_features):
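        """Select the configured hidden-state layer and drop the CLS token when
        `select_feature == 'patch'` (keep everything for 'cls_patch')."""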
        if self.select_layer is not None:
            image_features = image_features.hidden_states[self.select_layer]
            
        if self.select_feature == "patch":
            image_features = image_features[:, 1:]
        elif self.select_feature == "cls_patch":
            image_features = image_features
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
            
        return image_features

    def vision_tower_forward(self, image):
        image_feature = self.vision_tower(image, output_hidden_states=True)
        return image_feature
    
    def _forward(self, images, out_T=1):
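        """Encode a list of single images or a batched tensor.

        Batched video inputs of shape (B, T, C, H, W) fed to an image encoder (self.T == 1)
        are temporally subsampled to `out_T` frames, flattened to (B*T, C, H, W) for the
        encoder, and reshaped back to (B, T, W, H, hidden_size) afterwards.
        """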
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_feature = self.vision_tower_forward(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_feature = self.feature_select(image_feature).to(image.dtype)
                image_feature = image_feature.reshape(image_feature.shape[0], self.W, self.H, self.hidden_size)
                image_features.append(image_feature)
        else:
            original_shape = images.shape
            if len(original_shape) == 5 and self.T == 1:
                # downsample temporally if needed, and reshape from (B, T, C, W, H) to (B*T, C, W, H).
                images = images[:, ::original_shape[1] // out_T, ...]
                original_shape = images.shape
                images = images.view(-1, *original_shape[2:])

            image_features = self.vision_tower_forward(images.to(device=self.device, dtype=self.dtype))
            image_features = self.feature_select(image_features).to(images.dtype)
            # Reshape back to (B, T, ...) if necessary
            if len(original_shape) == 5 and self.T == 1:
                # Assuming the feature dimension does not change, adapt the following line if it does
                new_shape = list(image_features.shape[:-2]) + [self.W, self.H, self.hidden_size]
                image_features = image_features.reshape(new_shape)
                feature_size = image_features.shape[1:]
                image_features = image_features.view(original_shape[0], original_shape[1], *feature_size)
                
            else:
                image_features = image_features.reshape(image_features.shape[0], self.T, self.W, self.H, self.hidden_size)
                
        return image_features
    
    def forward(self, images):
        return self._forward(images)

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


class InternVideoTower(VisionTower):
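    """Vision tower wrapping an InternVideo-style video encoder loaded via `trust_remote_code`."""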
    def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: PretrainedConfig = None):
        if vision_config is None:
            vision_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

        super().__init__(model_name_or_path, config, vision_config)
        self.vision_config = vision_config
        normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

        print('loading: ', model_name_or_path)
        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
        self.vision_tower = model.to(dtype=eval(config.model_dtype))

        transform = Compose([
            Resize(self.vision_config.img_size, interpolation='bilinear'),
            CenterCrop(size=(self.vision_config.img_size, self.vision_config.img_size)),
            ClipToTensor(),
            Normalize(mean=normalize[0], std=normalize[1])
        ])

        self.vision_processor = ProcessorWrapper(transform=transform,
                                                 height=self.vision_config.img_size,
                                                 width=self.vision_config.img_size,
                                                 frames_per_clip=self.vision_config.num_frames,
                                                 image_mean=normalize[0])

        self.W = self.H = vision_config.img_size // vision_config.patch_size
        self.T = self.vision_config.num_frames // self.vision_config.tubelet_size
        self.num_frames = self.vision_config.num_frames
        self.hidden_size = vision_config.d_model
        self.vision_select_layer = self.select_layer
        self.select_layer = None

    def vision_tower_forward(self, video):
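        """Adjust the clip's temporal length toward the encoder's expected `num_frames`
        (repeat frames when the clip is too short, subsample when too long), then run the
        encoder up to the selected layer."""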
        if video.shape[-3] < self.num_frames:
            video = video.repeat_interleave(self.num_frames, dim=-3)
        elif video.shape[-3] > self.num_frames:
            video = video[:, :, ::video.shape[-3] // self.num_frames, ...]

        video_feature = self.vision_tower(video.to(device=self.device, dtype=self.dtype),
                                          x_vis_return_idx=self.vision_select_layer, x_vis_only=True)
        
        return video_feature

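    # The remote-code InternVideo module may not expose a `.device` attribute, so the device
    # is inferred from its positional-embedding parameter instead (assumed rationale).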
    @property
    def device(self):
        return self.vision_tower.pos_embed.device


class SiglipVisionTower(VisionTower):
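    """Vision tower wrapping a SigLIP image encoder; operates on single frames (T = 1) and
    keeps all output tokens (`select_feature = 'cls_patch'`)."""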
    def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: PretrainedConfig = None):
        if vision_config is None:
            vision_config = SiglipVisionConfig.from_pretrained(model_name_or_path)

        super().__init__(model_name_or_path, config, vision_config)
        self.vision_config = vision_config
        self.vision_tower_name = model_name_or_path
        self.vision_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)

        print('loading: ', model_name_or_path)
        self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)

        self.hidden_size = self.vision_config.hidden_size
        self.W = self.H = self.vision_config.image_size // self.vision_config.patch_size
        self.T = 1
        self.select_feature = "cls_patch"


class ApolloVisionTower(PreTrainedModel):
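    """Combines several vision towers (e.g. an InternVideo video encoder and a SigLIP image
    encoder). Each tower's token grid is resampled to a shared (T, W, H) shape and the
    features are concatenated along the hidden dimension.
    """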
    def __init__(self, config, vision_tower_cfg):
        super().__init__(config, vision_tower_cfg)
        self.model_name_or_path = vision_tower_cfg._name_or_path
        self.vision_towers = vision_tower_cfg.vision_towers
        self._config = vision_tower_cfg

        for vision_tower_name in self.vision_towers:
            if 'internvideo' in vision_tower_name.lower():
                vision_tower = InternVideoTower(os.path.join(vision_tower_cfg._name_or_path, vision_tower_name), config)
            elif 'siglip' in vision_tower_name.lower():
                vision_tower = SiglipVisionTower(os.path.join(vision_tower_cfg._name_or_path, vision_tower_name),
                                                 config)
            else:
                raise ValueError(f'Unsupported vision tower: {vision_tower_name}')

            setattr(self, vision_tower_name, vision_tower)

        self.vision_processor = [getattr(self, vt).vision_processor for vt in self.vision_towers]
        self.num_vision_encoders = len(self.vision_towers)
        self.W = self.H = max([getattr(self, vt).W for vt in self.vision_towers])
        self.T = max([getattr(self, vt).T for vt in self.vision_towers])
        self.max_tubelet_size = max(
            [getattr(getattr(self, vt).vision_config, 'tubelet_size', 1) for vt in self.vision_towers])
        
        self._hidden_size = sum([getattr(self, vt).hidden_size for vt in self.vision_towers])
        self.token_output_shape = (self.T, self.W, self.H)
        self.config.num_vision_encoders = self.num_vision_encoders
        self.config.vision_towers = self.vision_towers
        self.config.token_output_shape = self.token_output_shape

    def forward(self, x):
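        """`x` is a list with one preprocessed input per vision tower, in `self.vision_towers`
        order. Each tower's features are interpolated to `self.token_output_shape` if needed,
        concatenated along the feature dimension, and flattened to (B, T*W*H, hidden_size).
        """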
        output_features = []
        for x_s, vision_tower_name in zip(x, self.vision_towers):
            vision_tower = getattr(self, vision_tower_name)
            features = vision_tower._forward(x_s, out_T=self.T)

            if len(features.shape) != len(self.token_output_shape) + 2:
                features = features.unsqueeze(1)

            if features.shape[-len(self.token_output_shape) - 1:-1] != self.token_output_shape:
                features = features.permute(0, 4, 1, 2, 3).contiguous()  # shape [B, D, T, W, H]
                features = F.interpolate(features.to(torch.float32), size=self.token_output_shape, mode='trilinear',
                                         align_corners=False).to(features.dtype)
                features = features.permute(0, 2, 3, 4, 1).contiguous()

            output_features.append(features)

        output_features = torch.cat(output_features, dim=-1)
        output_features = torch.flatten(output_features, start_dim=1, end_dim=-2)
        return output_features

    def save_pretrained(
            self,
            save_directory: Union[str, os.PathLike],
            state_dict=None,
            **kwargs,
    ):
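        """Save each sub-tower's encoder weights and processor into its own subdirectory of
        `save_directory`, then save the combined configuration."""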
        if state_dict is None:
            state_dict = self.state_dict()

        for vision_tower_name in self.vision_towers:
            vision_tower = getattr(self, vision_tower_name)
            vision_tower_state_dict = OrderedDict(
                {k.split(f"vision_tower.{vision_tower_name}.vision_tower.")[-1]: v for k, v in state_dict.items() if
                 vision_tower_name in k}
            )
            vision_tower.vision_tower.save_pretrained(os.path.join(save_directory, vision_tower_name),
                                                      state_dict=vision_tower_state_dict, **kwargs)
            vision_tower.vision_processor.save_pretrained(os.path.join(save_directory, vision_tower_name))

        config = self.config
        config.configs = {}
        config.save_pretrained(save_directory)

    @property
    def patch_size(self):
        return self._patch_size

    @property
    def image_size(self):
        return self._image_size

    @property
    def hidden_size(self):
        return self._hidden_size