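"""Byte-level BPE tokenizer for the Sapnous architecture.

Implements a GPT-2 style byte-pair-encoding tokenizer with additional special
tokens for vision, image, and video inputs, and registers it with
``transformers.AutoTokenizer``.
"""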
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import regex as re

from transformers import AutoTokenizer
from transformers.tokenization_utils import PreTrainedTokenizer

BYTES_TO_UNICODE_REGEX = re.compile(r"'([^']+)':\s*([0-9]+)")


def bytes_to_unicode():
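    """Return the GPT-2 style mapping from byte values (0-255) to unicode characters.

    Bytes that already correspond to visible characters map to themselves; the
    rest (control bytes, whitespace, etc.) are shifted up past 255 so that every
    byte has a printable, reversible representation (e.g. the space byte 32 maps
    to "Ġ").
    """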
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
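    """Return the set of adjacent symbol pairs in a word.

    ``word`` is a tuple of symbols (single characters at first, merged BPE units
    later), e.g. ("l", "o", "w") -> {("l", "o"), ("o", "w")}.
    """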
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


class SapnousTokenizer(PreTrainedTokenizer):
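    """Slow (pure Python) byte-level BPE tokenizer with vision-related special tokens.

    Loads a GPT-2 style ``vocab.json`` (token -> id) mapping and an optional
    ``merges.txt`` file, and exposes helpers for wrapping text in
    vision/image/video markers.
    """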
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str,
        merges_file: Optional[str] = None,
        unk_token: str = "<|endoftext|>",
        bos_token: str = "<|startoftext|>",
        eos_token: str = "<|endoftext|>",
        pad_token: str = "<|pad|>",
        vision_start_token: str = "<|vision_start|>",
        vision_end_token: str = "<|vision_end|>",
        image_token: str = "<|image|>",
        video_token: str = "<|video|>",
        add_prefix_space: bool = False,
        **kwargs,
    ):
        self.vocab_file = vocab_file
        self.merges_file = merges_file
        self.add_prefix_space = add_prefix_space

        # Keep the vision-related markers as attributes so the prepare_for_*
        # helpers below can reference them directly.
        self.vision_start_token = vision_start_token
        self.vision_end_token = vision_end_token
        self.image_token = image_token
        self.video_token = video_token

        self.special_tokens = {
            "unk_token": unk_token,
            "bos_token": bos_token,
            "eos_token": eos_token,
            "pad_token": pad_token,
            "vision_start_token": vision_start_token,
            "vision_end_token": vision_end_token,
            "image_token": image_token,
            "video_token": video_token,
        }

        with Path(vocab_file).open(encoding="utf-8") as f:
            self.encoder = json.load(f)
        self.decoder = {v: k for k, v in self.encoder.items()}

        if merges_file:
            with Path(merges_file).open(encoding="utf-8") as f:
                # The first line of a GPT-2 style merges file is a "#version" header.
                bpe_merges = f.read().strip().split('\n')[1:]
            bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
            self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
        else:
            self.bpe_ranks = {}

        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?\d+| ?[^\s\w\d]+|\s+(?!\S)|\s+""")

        # The vocabulary must be in place before the base-class initializer runs,
        # because PreTrainedTokenizer resolves special tokens through
        # _convert_token_to_id during construction.
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    def bpe(self, token: str) -> str:
        """Apply the learned BPE merges to a single pre-tokenized token."""
        if token in self.special_tokens.values():
            return token

        word = tuple(token)
        pairs = get_pairs(word)

        if not pairs:
            return token

        while True:
            # Pick the adjacent pair with the lowest merge rank; stop once no
            # remaining pair appears in the merge table.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    # Guard against `first` being the final symbol, which would
                    # otherwise raise an IndexError on word[j + 1].
                    if j < len(word) - 1 and word[j + 1] == second:
                        new_word.append(first + second)
                        i = j + 2
                    else:
                        new_word.append(word[j])
                        i = j + 1
                except ValueError:
                    new_word.extend(word[i:])
                    break
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        return ' '.join(word)

    def _tokenize(self, text: str) -> List[str]:
        if self.add_prefix_space:
            text = ' ' + text

        bpe_tokens = []
        for token in re.findall(self.pat, text):
            # Map the token's UTF-8 bytes through the byte-to-unicode table so
            # the BPE vocabulary only ever sees printable characters.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.bpe(token).split(' '))
        return bpe_tokens

    def _convert_token_to_id(self, token: str) -> int:
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index: int) -> str:
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        text = ''.join(tokens)
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
        return text

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        if not filename_prefix:
            filename_prefix = ""

        vocab_file = Path(save_directory) / f"{filename_prefix}vocab.json"
        merge_file = Path(save_directory) / f"{filename_prefix}merges.txt"

        with vocab_file.open('w', encoding='utf-8') as f:
            json.dump(self.encoder, f, ensure_ascii=False)

        if self.merges_file:
            with merge_file.open('w', encoding='utf-8') as f:
                # Write the version header that __init__ skips when reloading,
                # so a save/load round trip does not drop the first merge.
                f.write("#version: 0.2\n")
                for merge in self.bpe_ranks:
                    f.write(f"{merge[0]} {merge[1]}\n")
            return str(vocab_file), str(merge_file)

        return (str(vocab_file),)

    def prepare_for_vision(self, text: str) -> str:
        """Prepare text for vision tasks by adding special tokens."""
        return f"{self.vision_start_token}{text}{self.vision_end_token}"

    def prepare_for_image(self, text: str) -> str:
        """Prepare text for image tasks."""
        return f"{self.image_token}{text}"

    def prepare_for_video(self, text: str) -> str:
        """Prepare text for video tasks."""
        return f"{self.video_token}{text}"

    @property
    def vocab_size(self) -> int:
        return len(self.encoder)

    def get_vocab(self) -> Dict[str, int]:
        # Include tokens added after initialization alongside the base vocabulary.
        vocab = dict(self.encoder)
        vocab.update(self.added_tokens_encoder)
        return vocab


# `AutoTokenizer.register` expects the model's *config class* as its first argument
# rather than a model-type string. The import below is an assumption about the
# surrounding repo layout (a `configuration_sapnous` module defining
# `SapnousConfig`); adjust it to match the actual configuration module.
try:
    from .configuration_sapnous import SapnousConfig  # assumed module/class name

    AutoTokenizer.register(SapnousConfig, slow_tokenizer_class=SapnousTokenizer)
except ImportError:
    pass
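
# Example usage (illustrative only; the file paths below are placeholders, not
# files shipped with this module):
#
#     tokenizer = SapnousTokenizer("vocab.json", "merges.txt")
#     ids = tokenizer("a photo of a cat")["input_ids"]
#     text = tokenizer.decode(ids)
#     vision_prompt = tokenizer.prepare_for_vision("Describe the scene.")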