import re

from hat_splitter import HATSplitter as RustHATSplitter


class HATSplitter:
    """Split text into chunks of ints while keeping special tokens atomic.

    Ordinary text is delegated to the Rust ``HATSplitter`` via
    ``split_with_limit``. Each occurrence of a configured special-token string
    is emitted as a one-element chunk holding that token's integer id.
    """

    def __init__(self, special_token_dict: dict[str, int] | None = None, max_word_size: int = 128):
        """Create a splitter.

        Args:
            special_token_dict: Mapping from special-token text to its integer id.
                ``None`` (the default) means "no special tokens".
            max_word_size: Limit forwarded to ``split_with_limit`` when
                splitting ordinary text.
        """
        self.hat_splitter = RustHATSplitter()
        self.max_word_size = max_word_size
        # Normalise None to an empty mapping. Otherwise the comprehension below
        # and the membership test in encode() would fail on None.
        self.special_token_dict: dict[str, int] = special_token_dict if special_token_dict is not None else {}
        # Reverse lookup used by decode(): token id -> UTF-8 byte values of its text.
        self.special_token_replace: dict[int, list[int]] = {
            token: list(text.encode("utf-8")) for text, token in self.special_token_dict.items()
        }
        # Regex alternation picks the first alternative that matches, not the
        # longest one. List longer tokens first so a token that is a prefix of
        # another cannot shadow it. Empty strings are excluded because an empty
        # alternative would match at every position.
        alternatives = sorted((text for text in self.special_token_dict if text), key=len, reverse=True)
        self.special_token_pattern = (
            re.compile(f"({'|'.join(map(re.escape, alternatives))})")
            if alternatives
            else re.compile(r"(?!)")  # never matches, so split() yields the whole text
        )

    def encode(self, text: str) -> list[list[int]]:
        """Split ``text`` into chunks.

        Args:
            text: The string to split.

        Returns:
            A list of chunks in input order. A special token becomes
            ``[token_id]``. Every other chunk is ``list(piece)`` for each piece
            returned by ``split_with_limit(segment, self.max_word_size)``
            (pieces are presumably bytes, giving byte values -- TODO confirm
            against the hat_splitter API).
        """
        chunks: list[list[int]] = []
        # The capturing group makes re.split keep the matched special tokens in
        # its output, interleaved with the ordinary-text segments between them.
        for segment in self.special_token_pattern.split(text):
            if not segment:
                continue
            if segment in self.special_token_dict:
                chunks.append([self.special_token_dict[segment]])
            else:
                chunks.extend(
                    list(piece) for piece in self.hat_splitter.split_with_limit(segment, self.max_word_size)
                )
        return chunks

    def decode(self, token_ids: list[int], errors: str = "replace", skip_special_tokens: bool = False) -> str:
        """Convert byte values and special-token ids back into a string.

        Args:
            token_ids: Byte values interleaved with special-token ids.
            errors: Error handler passed to ``bytes.decode`` for invalid UTF-8.
            skip_special_tokens: If true, drop special-token ids. Otherwise,
                expand each one to the UTF-8 bytes of its text.

        Returns:
            The UTF-8-decoded string.

        Raises:
            AssertionError: If ``token_ids`` is not a list of ints. This check
                only runs when assertions are enabled.
            ValueError: If a non-special id lies outside 0-255 (raised by ``bytes``).
        """
        assert isinstance(token_ids, list), "token_ids must be a list"
        assert all(isinstance(token_id, int) for token_id in token_ids), "token_ids must be a list of integers"
        replacements = self.special_token_replace
        byte_values: list[int]
        if skip_special_tokens:
            byte_values = [token_id for token_id in token_ids if token_id not in replacements]
        else:
            byte_values = []
            for token_id in token_ids:
                if token_id in replacements:
                    byte_values.extend(replacements[token_id])
                else:
                    byte_values.append(token_id)
        return bytes(byte_values).decode("utf-8", errors=errors)