"""Tokenization classes for xTrimoPGLM."""

import os
from typing import List, Optional, Union, Dict, Any
from torch import TensorType
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding

# Used as ``vocab_files_names`` on the tokenizer class below: the key matches the
# ``vocab_file`` constructor argument and the value is the on-disk filename.
VOCAB_FILES_NAMES = dict(vocab_file="tokenizer.model")


def load_vocab_file(vocab_file: str) -> List[str]:
    """
    Read a vocabulary file and return its tokens in file order.

    The file is plain text with one token per line. Leading and trailing
    whitespace is stripped from every line; blank lines are kept as empty
    strings so that line positions are preserved.

    Args:
        vocab_file: Path to the vocabulary file.

    Returns:
        The tokens, one entry per line of the file (empty list for an empty file).
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(vocab_file, "r", encoding="utf-8") as f:
        lines = f.read().splitlines()
    return [line.strip() for line in lines]


class xTrimoPGLMTokenizer(PreTrainedTokenizer):
    """
    Constructs a xTrimoPGLM tokenizer.

    The vocabulary is a plain-text file with one token per line; a token's id
    is its 0-based line index. Every vocabulary entry is registered as a
    no-split token, so the base class segments text on vocabulary entries
    before ``_tokenize`` splits any remaining text on whitespace.

    Args:
        vocab_file: Path to the vocabulary file (one token per line).
        unk_token: Token returned for unknown ids; its id is used for unknown tokens.
        pad_token: Padding token.
        mask_token: Mask token.
        eos_token: End-of-sequence token appended by
            ``build_inputs_with_special_tokens``.
        model_max_length: Maximum sequence length passed to the base class.
        additional_special_tokens: Extra special tokens. When None, defaults to
            ``['<pad>', '<mask>', '<gmask>', '<smask>', '<eod>', '<sop>',
            '<eop>', '<eos>', '<unk>']``.
        **kwargs: Forwarded to ``PreTrainedTokenizer.__init__``.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
        self,
        vocab_file: str,
        unk_token: str = "<unk>",
        pad_token: str = "<pad>",
        mask_token: str = "<mask>",
        eos_token: str = "<eos>",
        model_max_length: int = 2048,
        additional_special_tokens: Optional[List[str]] = None,
        **kwargs,
    ):
        # The vocabulary lookups are built before super().__init__ because the
        # base-class constructor may query get_vocab()/vocab_size while it
        # registers special tokens.
        self.all_tokens = load_vocab_file(vocab_file)
        self._id_to_token = dict(enumerate(self.all_tokens))
        self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}

        if additional_special_tokens is None:
            # A fresh list per instance, so no state is shared between tokenizers.
            additional_special_tokens = ['<pad>', '<mask>', '<gmask>', '<smask>', '<eod>', '<sop>', '<eop>', '<eos>', '<unk>']

        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            mask_token=mask_token,
            eos_token=eos_token,
            model_max_length=model_max_length,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

        # Treat every vocabulary entry as a no-split token so the base class's
        # trie segments raw text on vocabulary entries before _tokenize runs.
        self.unique_no_split_tokens = self.all_tokens
        self._update_trie(self.unique_no_split_tokens)

    def _convert_id_to_token(self, index: int) -> str:
        """Return the token for ``index``, or ``unk_token`` if the id is unknown."""
        return self._id_to_token.get(index, self.unk_token)

    def _convert_token_to_id(self, token: str) -> int:
        """Return the id of ``token``, falling back to the id of ``unk_token``."""
        return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        """Split ``text`` on whitespace (applied to text left after no-split segmentation)."""
        return text.split()

    def get_vocab(self) -> dict:
        """
        Return the base vocabulary merged with ``added_tokens_encoder``.

        Entries from ``added_tokens_encoder`` override base entries with the
        same token string.
        """
        base_vocab = self._token_to_id.copy()
        base_vocab.update(self.added_tokens_encoder)
        return base_vocab

    def token_to_id(self, token: str) -> int:
        """Public alias of ``_convert_token_to_id``."""
        return self._convert_token_to_id(token)

    def id_to_token(self, index: int) -> str:
        """Public alias of ``_convert_id_to_token``."""
        return self._convert_id_to_token(index)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Append the EOS token id to one sequence, or after each of a pair.

        Args:
            token_ids_0: Ids of the first sequence.
            token_ids_1: Optional ids of a second sequence.

        Returns:
            ``token_ids_0 + [eos]`` for a single sequence, or
            ``token_ids_0 + [eos] + token_ids_1 + [eos]`` for a pair. If no EOS
            token is configured, a single ``token_ids_0`` is returned unchanged.

        Raises:
            ValueError: If a pair is given but no EOS token is configured.
        """
        eos_id = self.eos_token_id
        if eos_id is None:
            if token_ids_1 is None:
                return token_ids_0
            raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
        sep = [eos_id]
        if token_ids_1 is None:
            return token_ids_0 + sep
        # Multiple inputs always have an EOS token
        return token_ids_0 + sep + token_ids_1 + sep

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        """
        Write the base vocabulary to ``save_directory``, one token per line.

        Args:
            save_directory: Existing directory to write into.
            filename_prefix: Optional prefix, joined to the filename with ``-``.

        Returns:
            A one-element tuple holding the path of the written file.
        """
        # Use the same filename the class advertises for loading, so a saved
        # tokenizer can be found again by from_pretrained.
        filename = self.vocab_files_names["vocab_file"]
        if filename_prefix:
            filename = f"{filename_prefix}-{filename}"
        vocab_file = os.path.join(save_directory, filename)
        # Explicit UTF-8 so the file does not depend on the platform locale.
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write("\n".join(self.all_tokens))
        return (vocab_file,)

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the base vocabulary file (added tokens excluded)."""
        return len(self.all_tokens)

    def apply_chat_template(
        self,
        query,
        add_generation_prompt: bool = True,
        tokenize: bool = True,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_dict: bool = False,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        add_special_tokens: bool = True,
        **kwargs,
    ) -> Union[str, List[int], List[str], List[List[int]], BatchEncoding]:
        """
        Prefix each query with the generation prompt and optionally tokenize it.

        Args:
            query: A single string or an iterable of strings. A single string is
                treated as a batch of one.
            add_generation_prompt: If True, prepend ``"<gmask><sop><eos>"`` to
                every query.
            tokenize: If False, return the prompt strings instead of ids.
            padding: Forwarded to ``batch_encode_plus``.
            truncation: Forwarded to ``batch_encode_plus``.
            max_length: Forwarded to ``batch_encode_plus``.
            return_tensors: Forwarded to ``batch_encode_plus``.
            return_dict: If True (and ``tokenize``), return the whole
                ``BatchEncoding``; otherwise only its ``"input_ids"``.
            tokenizer_kwargs: Accepted for signature compatibility; ignored.
            add_special_tokens: Accepted for signature compatibility; ignored.
                ``batch_encode_plus`` is always called with
                ``add_special_tokens=False``.
            **kwargs: Accepted for signature compatibility; ignored.

        Returns:
            The list of prompt strings when ``tokenize`` is False, otherwise the
            ``BatchEncoding`` (``return_dict=True``) or its ``"input_ids"``.

        Raises:
            TypeError: If ``add_generation_prompt`` is True and a query item is
                not a string.
        """
        generation_prompt = "<gmask><sop><eos>"
        if isinstance(query, str):
            query = [query]

        if add_generation_prompt:
            prompt_query = []
            for each in query:
                # Explicit check instead of `assert`, which is stripped under -O.
                if not isinstance(each, str):
                    raise TypeError(
                        f"apply_chat_template expects str queries, got {type(each).__name__}"
                    )
                prompt_query.append(generation_prompt + each)
        else:
            prompt_query = query

        if not tokenize:
            return prompt_query

        # NOTE(review): every item is a plain string, so is_split_into_words
        # presumably has no effect here; kept as-is for compatibility.
        output = self.batch_encode_plus(
            prompt_query,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
            is_split_into_words=True,
            add_special_tokens=False,
        )
        return output if return_dict else output["input_ids"]