import re
from hat_splitter import HATSplitter as RustHATSplitter


class HATSplitter:
    """Wrapper around the Rust HATSplitter that adds special-token handling."""

    def __init__(self, special_token_dict: dict[str, int] | None = None, max_word_size: int = 128):
        self.hat_splitter = RustHATSplitter()
        self.max_word_size = max_word_size
        self.special_token_dict = special_token_dict or {}
        # Map each special token id back to the UTF-8 bytes of its text, for decoding.
        self.special_token_replace: dict[int, list[int]] = {
            token: list(text.encode("utf-8")) for text, token in self.special_token_dict.items()
        }
        # Pattern that splits text around special tokens; matches nothing if there are none.
        self.special_token_pattern = (
            re.compile(rf"({'|'.join(map(re.escape, special_token_dict.keys()))})")
            if special_token_dict
            else re.compile(r"(?!)")
        )

    def encode(self, text: str) -> list[list[int]]:
        """Split text into chunks of UTF-8 byte values; special tokens become single-id chunks."""
        chunks: list[list[int]] = []
        for str_chunk in self.special_token_pattern.split(text):
            if not str_chunk:
                continue
            if str_chunk in self.special_token_dict:
                # A special token is emitted as its own one-element chunk.
                chunks.append([self.special_token_dict[str_chunk]])
            else:
                # Ordinary text is split into pieces of at most max_word_size bytes.
                chunks.extend(
                    list(chunk)
                    for chunk in self.hat_splitter.split_with_limit(str_chunk, self.max_word_size)
                )
        return chunks

    def decode(self, token_ids: list[int], errors: str = "replace", skip_special_tokens: bool = False) -> str:
        """Decode a flat list of byte values and special token ids back into text."""
        assert isinstance(token_ids, list), "token_ids must be a list"
        assert all(isinstance(token_id, int) for token_id in token_ids), "token_ids must be a list of integers"

        new_token_ids: list[int]
        if skip_special_tokens:
            # Drop special tokens and keep only raw byte values.
            new_token_ids = [token_id for token_id in token_ids if token_id not in self.special_token_replace]
        else:
            # Expand each special token id into the UTF-8 bytes of its text form.
            new_token_ids = []
            for token_id in token_ids:
                if token_id in self.special_token_replace:
                    new_token_ids.extend(self.special_token_replace[token_id])
                else:
                    new_token_ids.append(token_id)

        return bytes(new_token_ids).decode("utf-8", errors=errors)
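

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module: the special
    # token string "<|eot|>" and id 256 are hypothetical, and running this
    # assumes the Rust-backed `hat_splitter` package is installed.
    splitter = HATSplitter(special_token_dict={"<|eot|>": 256})

    # encode() returns chunks of UTF-8 byte values, with the special token
    # appearing as its own single-id chunk, e.g. [..., [256]].
    chunks = splitter.encode("Hello world<|eot|>")
    print(chunks)

    # decode() takes a flat list of ids; special tokens are expanded back to
    # their text form unless skip_special_tokens=True.
    flat = [token_id for chunk in chunks for token_id in chunk]
    print(splitter.decode(flat))                             # expected: "Hello world<|eot|>"
    print(splitter.decode(flat, skip_special_tokens=True))   # expected: "Hello world"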