# Tests for SapnousTokenizer: vocabulary and special-token handling,
# vision/image/video prompt preparation, batch encoding, save/load
# round-trips, and AutoTokenizer registration.

import json
import shutil
import unittest
from pathlib import Path

import torch
from transformers import AutoTokenizer

from .tokenization_sapnous import SapnousTokenizer


class TestSapnousTokenizer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Build a throwaway vocab/merges pair on disk so the tokenizer can be
        # constructed the same way a saved checkpoint would be loaded.
        cls.temp_dir = Path('test_tokenizer_files')
        cls.temp_dir.mkdir(exist_ok=True)

        # Minimal vocabulary: special tokens first, then a few regular words.
        cls.vocab_file = cls.temp_dir / 'vocab.json'
        cls.vocab = {
            '<|endoftext|>': 0,
            '<|startoftext|>': 1,
            '<|pad|>': 2,
            '<|vision_start|>': 3,
            '<|vision_end|>': 4,
            '<|image|>': 5,
            '<|video|>': 6,
            'hello': 7,
            'world': 8,
            'test': 9,
        }
        with cls.vocab_file.open('w', encoding='utf-8') as f:
            json.dump(cls.vocab, f)

        # BPE merge rules covering the characters used in the test strings.
        cls.merges_file = cls.temp_dir / 'merges.txt'
        merges_content = "#version: 0.2\nh e\ne l\nl l\no w\nw o\no r\nr l\nl d"
        cls.merges_file.write_text(merges_content)

        cls.tokenizer = SapnousTokenizer(
            str(cls.vocab_file),
            str(cls.merges_file),
        )

    @classmethod
    def tearDownClass(cls):
        # Remove the temporary tokenizer files created in setUpClass.
        shutil.rmtree(cls.temp_dir)

    def test_tokenizer_initialization(self):
        # The vocabulary is loaded verbatim from the JSON file.
        self.assertEqual(self.tokenizer.vocab_size, len(self.vocab))
        self.assertEqual(self.tokenizer.get_vocab(), self.vocab)

        # Default special-token mapping.
        self.assertEqual(self.tokenizer.unk_token, '<|endoftext|>')
        self.assertEqual(self.tokenizer.bos_token, '<|startoftext|>')
        self.assertEqual(self.tokenizer.eos_token, '<|endoftext|>')
        self.assertEqual(self.tokenizer.pad_token, '<|pad|>')

    def test_tokenization(self):
        text = "hello world test"

        # tokenize() returns a list of string tokens.
        tokens = self.tokenizer.tokenize(text)
        self.assertIsInstance(tokens, list)
        self.assertTrue(all(isinstance(token, str) for token in tokens))

        # Without special tokens, the three words map to three ids.
        input_ids = self.tokenizer.encode(text, add_special_tokens=False)
        self.assertIsInstance(input_ids, list)
        self.assertEqual(len(input_ids), 3)

        # Decoding round-trips back to the original text.
        decoded_text = self.tokenizer.decode(input_ids)
        self.assertEqual(decoded_text.strip(), text)

    def test_special_tokens_handling(self):
        text = "hello world"

        # With special tokens, the sequence is wrapped in BOS/EOS ids.
        tokens_with_special = self.tokenizer.encode(text, add_special_tokens=True)
        self.assertEqual(tokens_with_special[0], self.tokenizer.bos_token_id)
        self.assertEqual(tokens_with_special[-1], self.tokenizer.eos_token_id)

        # Without special tokens, neither boundary id should appear.
        tokens_without_special = self.tokenizer.encode(text, add_special_tokens=False)
        self.assertNotEqual(tokens_without_special[0], self.tokenizer.bos_token_id)
        self.assertNotEqual(tokens_without_special[-1], self.tokenizer.eos_token_id)

    def test_vision_tokens(self):
        text = "This is an image description"

        # prepare_for_vision wraps the text in vision boundary tokens.
        vision_text = self.tokenizer.prepare_for_vision(text)
        self.assertTrue(vision_text.startswith('<|vision_start|>'))
        self.assertTrue(vision_text.endswith('<|vision_end|>'))

        # prepare_for_image / prepare_for_video prefix the modality token.
        image_text = self.tokenizer.prepare_for_image(text)
        self.assertTrue(image_text.startswith('<|image|>'))

        video_text = self.tokenizer.prepare_for_video(text)
        self.assertTrue(video_text.startswith('<|video|>'))

    def test_batch_encoding(self):
        texts = ["hello world", "test hello"]
        batch_encoding = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

        # The batched call returns padded tensors with one row per input text.
        self.assertIsInstance(batch_encoding["input_ids"], torch.Tensor)
        self.assertIsInstance(batch_encoding["attention_mask"], torch.Tensor)
        self.assertEqual(batch_encoding["input_ids"].shape[0], len(texts))
        self.assertEqual(batch_encoding["attention_mask"].shape[0], len(texts))

    def test_save_and_load(self):
        save_dir = Path('test_save_tokenizer')
        save_dir.mkdir(exist_ok=True)

        try:
            # save_vocabulary should write the vocab/merges files to disk.
            vocab_files = self.tokenizer.save_vocabulary(str(save_dir))
            self.assertTrue(all(Path(f).exists() for f in vocab_files))

            # A tokenizer rebuilt from the saved files has the same vocab...
            loaded_tokenizer = SapnousTokenizer(*vocab_files)
            self.assertEqual(loaded_tokenizer.get_vocab(), self.tokenizer.get_vocab())

            # ...and produces identical encodings.
            text = "hello world test"
            original_encoding = self.tokenizer.encode(text)
            loaded_encoding = loaded_tokenizer.encode(text)
            self.assertEqual(original_encoding, loaded_encoding)
        finally:
            shutil.rmtree(save_dir)

    def test_auto_tokenizer_registration(self):
        # The tokenizer should resolve through AutoTokenizer once the
        # "sapnous" model type is registered.
        config = {
            "model_type": "sapnous",
            "vocab_file": str(self.vocab_file),
            "merges_file": str(self.merges_file),
        }

        tokenizer = AutoTokenizer.from_pretrained(str(self.temp_dir), **config)
        self.assertIsInstance(tokenizer, SapnousTokenizer)


if __name__ == '__main__':
    unittest.main()