# coding=utf-8
# Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import shutil
import unittest
from pathlib import Path

import torch
from transformers import AutoTokenizer

from .tokenization_sapnous import SapnousTokenizer

class TestSapnousTokenizer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Create temporary vocab and merges files for testing
        cls.temp_dir = Path('test_tokenizer_files')
        cls.temp_dir.mkdir(exist_ok=True)

        # Create a simple test vocabulary
        cls.vocab_file = cls.temp_dir / 'vocab.json'
        cls.vocab = {
            '<|endoftext|>': 0,
            '<|startoftext|>': 1,
            '<|pad|>': 2,
            '<|vision_start|>': 3,
            '<|vision_end|>': 4,
            '<|image|>': 5,
            '<|video|>': 6,
            'hello': 7,
            'world': 8,
            'test': 9,
        }
        with cls.vocab_file.open('w', encoding='utf-8') as f:
            json.dump(cls.vocab, f)

        # Create test merges file
        cls.merges_file = cls.temp_dir / 'merges.txt'
        merges_content = "#version: 0.2\nh e\ne l\nl l\no w\nw o\no r\nr l\nl d"
        cls.merges_file.write_text(merges_content)

        # Initialize tokenizer
        cls.tokenizer = SapnousTokenizer(
            str(cls.vocab_file),
            str(cls.merges_file),
        )

    @classmethod
    def tearDownClass(cls):
        # Clean up temporary files
        shutil.rmtree(cls.temp_dir)
    def test_tokenizer_initialization(self):
        self.assertEqual(self.tokenizer.vocab_size, len(self.vocab))
        self.assertEqual(self.tokenizer.get_vocab(), self.vocab)

        # Test special tokens
        self.assertEqual(self.tokenizer.unk_token, '<|endoftext|>')
        self.assertEqual(self.tokenizer.bos_token, '<|startoftext|>')
        self.assertEqual(self.tokenizer.eos_token, '<|endoftext|>')
        self.assertEqual(self.tokenizer.pad_token, '<|pad|>')
    def test_tokenization(self):
        text = "hello world test"
        tokens = self.tokenizer.tokenize(text)
        self.assertIsInstance(tokens, list)
        self.assertTrue(all(isinstance(token, str) for token in tokens))

        # Test encoding
        input_ids = self.tokenizer.encode(text, add_special_tokens=False)
        self.assertIsInstance(input_ids, list)
        self.assertEqual(len(input_ids), 3)  # 'hello', 'world', 'test'

        # Test decoding
        decoded_text = self.tokenizer.decode(input_ids)
        self.assertEqual(decoded_text.strip(), text)
    def test_special_tokens_handling(self):
        text = "hello world"

        # Test with special tokens
        tokens_with_special = self.tokenizer.encode(text, add_special_tokens=True)
        self.assertEqual(tokens_with_special[0], self.tokenizer.bos_token_id)
        self.assertEqual(tokens_with_special[-1], self.tokenizer.eos_token_id)

        # Test without special tokens
        tokens_without_special = self.tokenizer.encode(text, add_special_tokens=False)
        self.assertNotEqual(tokens_without_special[0], self.tokenizer.bos_token_id)
        self.assertNotEqual(tokens_without_special[-1], self.tokenizer.eos_token_id)
    def test_vision_tokens(self):
        # Test vision-specific token methods
        text = "This is an image description"

        vision_text = self.tokenizer.prepare_for_vision(text)
        self.assertTrue(vision_text.startswith('<|vision_start|>'))
        self.assertTrue(vision_text.endswith('<|vision_end|>'))

        image_text = self.tokenizer.prepare_for_image(text)
        self.assertTrue(image_text.startswith('<|image|>'))

        video_text = self.tokenizer.prepare_for_video(text)
        self.assertTrue(video_text.startswith('<|video|>'))
    def test_batch_encoding(self):
        texts = ["hello world", "test hello"]
        batch_encoding = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

        self.assertIsInstance(batch_encoding["input_ids"], torch.Tensor)
        self.assertIsInstance(batch_encoding["attention_mask"], torch.Tensor)
        self.assertEqual(batch_encoding["input_ids"].shape[0], len(texts))
        self.assertEqual(batch_encoding["attention_mask"].shape[0], len(texts))
    def test_save_and_load(self):
        # Test saving vocabulary
        save_dir = Path('test_save_tokenizer')
        save_dir.mkdir(exist_ok=True)
        try:
            vocab_files = self.tokenizer.save_vocabulary(str(save_dir))
            self.assertTrue(all(Path(f).exists() for f in vocab_files))

            # Test loading saved vocabulary
            loaded_tokenizer = SapnousTokenizer(*vocab_files)
            self.assertEqual(loaded_tokenizer.get_vocab(), self.tokenizer.get_vocab())

            # Test encoding/decoding with loaded tokenizer
            text = "hello world test"
            original_encoding = self.tokenizer.encode(text)
            loaded_encoding = loaded_tokenizer.encode(text)
            self.assertEqual(original_encoding, loaded_encoding)
        finally:
            # Clean up
            shutil.rmtree(save_dir)
    def test_auto_tokenizer_registration(self):
        # Test if the tokenizer can be loaded using AutoTokenizer
        config = {
            "model_type": "sapnous",
            "vocab_file": str(self.vocab_file),
            "merges_file": str(self.merges_file),
        }
        tokenizer = AutoTokenizer.from_pretrained(str(self.temp_dir), **config)
        self.assertIsInstance(tokenizer, SapnousTokenizer)

if __name__ == '__main__':
    unittest.main()