import unittest

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from .modeling_sapnous import SapnousT1ForCausalLM
from .configuration_sapnous import SapnousT1Config


class TestSapnousModel(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Build a small configuration and model once for the whole test class to keep the tests fast.
        cls.config = SapnousT1Config(
            vocab_size=32000,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
        )
        cls.model = SapnousT1ForCausalLM(cls.config)

    def test_model_forward(self):
        input_ids = torch.randint(0, self.config.vocab_size, (1, 10))
        outputs = self.model(input_ids)

        self.assertIsNotNone(outputs)
        self.assertTrue(hasattr(outputs, 'logits'))
        self.assertEqual(outputs.logits.shape, (1, 10, self.config.vocab_size))

    def test_weight_tying(self):
        # Tying should make the LM head share its weights with the input embeddings.
        self.model.tie_weights()
        self.assertTrue(torch.equal(self.model.lm_head.weight, self.model.model.embeddings.weight))

    def test_auto_model_registration(self):
        # The config should be registered with the Auto* factories so that
        # AutoModelForCausalLM resolves to SapnousT1ForCausalLM.
        model = AutoModelForCausalLM.from_config(self.config)
        self.assertIsInstance(model, SapnousT1ForCausalLM)

    def test_vision_embeddings(self):
        batch_size = 1
        pixel_values = torch.randn(batch_size, 3, 224, 224)
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, 10))

        outputs = self.model(input_ids=input_ids, pixel_values=pixel_values)
        self.assertIsNotNone(outputs)
        self.assertTrue(hasattr(outputs, 'logits'))

        # 10 text tokens + one token per 16x16 image patch + 1 extra (class) token.
        expected_seq_length = 10 + (224 // 16) ** 2 + 1
        self.assertEqual(outputs.logits.shape, (batch_size, expected_seq_length, self.config.vocab_size))

    def test_attention_mask(self):
        batch_size = 2
        seq_length = 15
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_length))
        attention_mask = torch.ones(batch_size, seq_length)
        attention_mask[:, -5:] = 0  # mask out the last five positions

        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        self.assertIsNotNone(outputs)
        self.assertEqual(outputs.logits.shape, (batch_size, seq_length, self.config.vocab_size))

    def test_generation_with_vision(self):
        pixel_values = torch.randn(1, 3, 224, 224)
        input_ids = torch.randint(0, self.config.vocab_size, (1, 5))

        outputs = self.model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            max_length=20,
            num_beams=1,
        )

        self.assertIsInstance(outputs, torch.Tensor)
        self.assertEqual(outputs.dim(), 2)
        self.assertTrue(outputs.size(1) <= 20)  # generated sequence must respect max_length

if __name__ == '__main__':
    unittest.main()