# Source: Sapnous-VR-6B / test_modeling_sapnous.py
# Author: Atah Alam
# Commit: 5838aa1 ("Updated py files"), 3.78 kB
# coding=utf-8
# Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from .modeling_sapnous import SapnousT1ForCausalLM
from .configuration_sapnous import SapnousT1Config
class TestSapnousModel(unittest.TestCase):
    """Smoke tests for ``SapnousT1ForCausalLM`` text, vision and generation paths.

    One model instance is built per class (construction is the expensive step)
    and shared by every test. It is switched to eval mode so that any
    training-only layers (e.g. dropout) do not make outputs nondeterministic,
    and each test re-seeds the global RNG so its random inputs do not depend on
    which tests ran before it.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared config and model once for all tests in this class."""
        # NOTE(review): this config is not small (12 layers x 768 hidden, 32k
        # vocab), so every run allocates a large model. A tinier config would be
        # faster if the vision tower does not constrain hidden_size — TODO confirm.
        cls.config = SapnousT1Config(
            vocab_size=32000,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
        )
        cls.model = SapnousT1ForCausalLM(cls.config)
        # eval() turns off training-mode behavior so forward passes are
        # deterministic for a given input.
        cls.model.eval()

    def setUp(self):
        """Re-seed the global torch RNG so each test's random inputs are reproducible."""
        torch.manual_seed(0)

    def test_model_forward(self):
        """A text-only forward pass returns logits of shape (batch, seq_len, vocab)."""
        input_ids = torch.randint(0, self.config.vocab_size, (1, 10))
        # No gradients are needed for a shape check; skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(input_ids)
        self.assertIsNotNone(outputs)
        self.assertTrue(hasattr(outputs, "logits"))
        self.assertEqual(outputs.logits.shape, (1, 10, self.config.vocab_size))

    def test_weight_tying(self):
        """After ``tie_weights()`` the LM head and the input embedding hold equal weights."""
        # tie_weights() mutates the shared class-level model; it presumably only
        # changes parameter sharing (not output shapes), so the other tests should
        # be unaffected — TODO confirm if SapnousT1 overrides tie_weights.
        self.model.tie_weights()
        # NOTE(review): assumes the decoder's token embedding lives at
        # `model.model.embeddings` — verify against modeling_sapnous.
        self.assertTrue(
            torch.equal(self.model.lm_head.weight, self.model.model.embeddings.weight)
        )

    def test_auto_model_registration(self):
        """``AutoModelForCausalLM.from_config`` resolves to ``SapnousT1ForCausalLM``."""
        # This module does not register the config/model with the Auto classes
        # itself; the registration is assumed to happen on import elsewhere —
        # TODO confirm where.
        model = AutoModelForCausalLM.from_config(self.config)
        self.assertIsInstance(model, SapnousT1ForCausalLM)

    def test_vision_embeddings(self):
        """Passing ``pixel_values`` lengthens the sequence by the number of image tokens."""
        batch_size = 1
        text_len = 10
        image_size = 224
        # Assumed ViT-style encoder: 16x16 patches plus one CLS token — TODO
        # confirm against the vision settings of SapnousT1Config.
        patch_size = 16
        pixel_values = torch.randn(batch_size, 3, image_size, image_size)
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, text_len))
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, pixel_values=pixel_values)
        self.assertIsNotNone(outputs)
        self.assertTrue(hasattr(outputs, "logits"))
        num_patches = (image_size // patch_size) ** 2
        expected_seq_length = text_len + num_patches + 1  # + CLS token
        self.assertEqual(
            outputs.logits.shape,
            (batch_size, expected_seq_length, self.config.vocab_size),
        )

    def test_attention_mask(self):
        """A batch with a partially zeroed attention mask still yields full-length logits."""
        batch_size = 2
        seq_length = 15
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_length))
        # Integer mask, like the int64 masks tokenizers emit with return_tensors="pt",
        # rather than a float mask the real input pipeline never produces.
        attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
        attention_mask[:, -5:] = 0  # mask out the last 5 tokens
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        self.assertIsNotNone(outputs)
        self.assertEqual(
            outputs.logits.shape, (batch_size, seq_length, self.config.vocab_size)
        )

    def test_generation_with_vision(self):
        """``generate`` accepts ``pixel_values`` and never exceeds ``max_length`` tokens."""
        max_length = 20
        pixel_values = torch.randn(1, 3, 224, 224)
        input_ids = torch.randint(0, self.config.vocab_size, (1, 5))
        outputs = self.model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            max_length=max_length,
            num_beams=1,
        )
        self.assertIsInstance(outputs, torch.Tensor)
        self.assertEqual(outputs.dim(), 2)
        # assertLessEqual reports both values on failure, unlike assertTrue(a <= b).
        self.assertLessEqual(outputs.size(1), max_length)
# Run the suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()