Sapnous-VR-6B / test_tokenization_sapnous.py
# coding=utf-8
# Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import shutil
import unittest
from pathlib import Path

import torch
from transformers import AutoTokenizer

from .tokenization_sapnous import SapnousTokenizer


class TestSapnousTokenizer(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Create temporary vocab and merges files for testing
        cls.temp_dir = Path('test_tokenizer_files')
        cls.temp_dir.mkdir(exist_ok=True)

        # Create a simple test vocabulary
        cls.vocab_file = cls.temp_dir / 'vocab.json'
        cls.vocab = {
            '<|endoftext|>': 0,
            '<|startoftext|>': 1,
            '<|pad|>': 2,
            '<|vision_start|>': 3,
            '<|vision_end|>': 4,
            '<|image|>': 5,
            '<|video|>': 6,
            'hello': 7,
            'world': 8,
            'test': 9,
        }
        with cls.vocab_file.open('w', encoding='utf-8') as f:
            json.dump(cls.vocab, f)

        # Create test merges file
        cls.merges_file = cls.temp_dir / 'merges.txt'
        merges_content = "#version: 0.2\nh e\ne l\nl l\no w\nw o\no r\nr l\nl d"
        cls.merges_file.write_text(merges_content)

        # Initialize tokenizer
        cls.tokenizer = SapnousTokenizer(
            str(cls.vocab_file),
            str(cls.merges_file),
        )

    @classmethod
    def tearDownClass(cls):
        # Clean up temporary files
        shutil.rmtree(cls.temp_dir)

    def test_tokenizer_initialization(self):
        self.assertEqual(self.tokenizer.vocab_size, len(self.vocab))
        self.assertEqual(self.tokenizer.get_vocab(), self.vocab)

        # Test special tokens
        self.assertEqual(self.tokenizer.unk_token, '<|endoftext|>')
        self.assertEqual(self.tokenizer.bos_token, '<|startoftext|>')
        self.assertEqual(self.tokenizer.eos_token, '<|endoftext|>')
        self.assertEqual(self.tokenizer.pad_token, '<|pad|>')

    def test_tokenization(self):
        text = "hello world test"
        tokens = self.tokenizer.tokenize(text)
        self.assertIsInstance(tokens, list)
        self.assertTrue(all(isinstance(token, str) for token in tokens))

        # Test encoding
        input_ids = self.tokenizer.encode(text, add_special_tokens=False)
        self.assertIsInstance(input_ids, list)
        self.assertEqual(len(input_ids), 3)  # 'hello', 'world', 'test'

        # Test decoding
        decoded_text = self.tokenizer.decode(input_ids)
        self.assertEqual(decoded_text.strip(), text)
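
    # Added sketch, not in the original suite: a token/id round-trip check. It assumes
    # SapnousTokenizer exposes the standard convert_tokens_to_ids/convert_ids_to_tokens
    # methods of a PreTrainedTokenizer subclass; drop it if the interface differs.
    def test_token_id_roundtrip(self):
        tokens = ['hello', 'world', 'test']
        ids = self.tokenizer.convert_tokens_to_ids(tokens)
        # Known vocabulary entries should map back to the ids defined in setUpClass
        self.assertEqual(ids, [self.vocab[t] for t in tokens])
        self.assertEqual(self.tokenizer.convert_ids_to_tokens(ids), tokens)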

    def test_special_tokens_handling(self):
        text = "hello world"

        # Test with special tokens
        tokens_with_special = self.tokenizer.encode(text, add_special_tokens=True)
        self.assertEqual(tokens_with_special[0], self.tokenizer.bos_token_id)
        self.assertEqual(tokens_with_special[-1], self.tokenizer.eos_token_id)

        # Test without special tokens
        tokens_without_special = self.tokenizer.encode(text, add_special_tokens=False)
        self.assertNotEqual(tokens_without_special[0], self.tokenizer.bos_token_id)
        self.assertNotEqual(tokens_without_special[-1], self.tokenizer.eos_token_id)
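
    # Added sketch: special tokens should disappear when decoding with
    # skip_special_tokens=True. This assumes the default transformers decode
    # behaviour and has not been verified against SapnousTokenizer itself.
    def test_decode_skip_special_tokens(self):
        text = "hello world"
        ids_with_special = self.tokenizer.encode(text, add_special_tokens=True)
        decoded = self.tokenizer.decode(ids_with_special, skip_special_tokens=True)
        self.assertNotIn(self.tokenizer.bos_token, decoded)
        self.assertNotIn(self.tokenizer.eos_token, decoded)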

    def test_vision_tokens(self):
        # Test vision-specific token methods
        text = "This is an image description"

        vision_text = self.tokenizer.prepare_for_vision(text)
        self.assertTrue(vision_text.startswith('<|vision_start|>'))
        self.assertTrue(vision_text.endswith('<|vision_end|>'))

        image_text = self.tokenizer.prepare_for_image(text)
        self.assertTrue(image_text.startswith('<|image|>'))

        video_text = self.tokenizer.prepare_for_video(text)
        self.assertTrue(video_text.startswith('<|video|>'))
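
    # Added sketch: the vision markers are plain entries in the test vocabulary above,
    # so they should resolve to their fixed ids; convert_tokens_to_ids is assumed to be
    # the standard PreTrainedTokenizer method.
    def test_vision_token_ids(self):
        for token in ('<|vision_start|>', '<|vision_end|>', '<|image|>', '<|video|>'):
            self.assertEqual(self.tokenizer.convert_tokens_to_ids(token), self.vocab[token])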

    def test_batch_encoding(self):
        texts = ["hello world", "test hello"]
        batch_encoding = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

        self.assertIsInstance(batch_encoding["input_ids"], torch.Tensor)
        self.assertIsInstance(batch_encoding["attention_mask"], torch.Tensor)
        self.assertEqual(batch_encoding["input_ids"].shape[0], len(texts))
        self.assertEqual(batch_encoding["attention_mask"].shape[0], len(texts))

    def test_save_and_load(self):
        # Test saving vocabulary
        save_dir = Path('test_save_tokenizer')
        save_dir.mkdir(exist_ok=True)
        try:
            vocab_files = self.tokenizer.save_vocabulary(str(save_dir))
            self.assertTrue(all(Path(f).exists() for f in vocab_files))

            # Test loading saved vocabulary
            loaded_tokenizer = SapnousTokenizer(*vocab_files)
            self.assertEqual(loaded_tokenizer.get_vocab(), self.tokenizer.get_vocab())

            # Test encoding/decoding with loaded tokenizer
            text = "hello world test"
            original_encoding = self.tokenizer.encode(text)
            loaded_encoding = loaded_tokenizer.encode(text)
            self.assertEqual(original_encoding, loaded_encoding)
        finally:
            # Clean up
            shutil.rmtree(save_dir)

    def test_auto_tokenizer_registration(self):
        # Test if the tokenizer can be loaded using AutoTokenizer
        config = {
            "model_type": "sapnous",
            "vocab_file": str(self.vocab_file),
            "merges_file": str(self.merges_file),
        }
        tokenizer = AutoTokenizer.from_pretrained(str(self.temp_dir), **config)
        self.assertIsInstance(tokenizer, SapnousTokenizer)
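
    # Added sketch: out-of-vocabulary fallback. The test vocabulary maps unk_token to
    # '<|endoftext|>', so its id should resolve accordingly; how unseen strings are
    # tokenized is implementation-specific and deliberately not asserted here.
    def test_unknown_token_id(self):
        unk_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.unk_token)
        self.assertEqual(unk_id, self.vocab['<|endoftext|>'])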


if __name__ == '__main__':
    unittest.main()