File size: 3,782 Bytes
5838aa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# coding=utf-8
# Copyright 2025-present, the HuggingFace Inc. Team and AIRAS Inc. Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from .modeling_sapnous import SapnousT1ForCausalLM
from .configuration_sapnous import SapnousT1Config

class TestSapnousModel(unittest.TestCase):
    """Smoke tests for ``SapnousT1ForCausalLM`` built from a ``SapnousT1Config``.

    One randomly initialised model is built in ``setUpClass`` and shared by
    every test, so individual tests should not leave it in a modified state.
    """

    # Geometry of the random images fed to the vision tests.
    # NOTE(review): PATCH_SIZE = 16 and the extra "+1" (CLS) token used below
    # are assumptions carried over from the original expectations — confirm
    # against the vision encoder in modeling_sapnous.
    IMAGE_SIZE = 224
    PATCH_SIZE = 16

    @classmethod
    def setUpClass(cls):
        """Build the shared config and model once for the whole test case."""
        # Seed the global RNG so the random weights and random inputs are
        # reproducible; otherwise a failure cannot be replayed.
        torch.manual_seed(0)
        cls.config = SapnousT1Config(
            vocab_size=32000,
            hidden_size=768,
            num_hidden_layers=12,
            num_attention_heads=12,
            intermediate_size=3072,
        )
        cls.model = SapnousT1ForCausalLM(cls.config)
        # Disable dropout and other train-only behaviour so forward passes
        # are deterministic for a given input.
        cls.model.eval()

    def test_model_forward(self):
        """A text-only forward pass returns logits of shape (batch, seq, vocab)."""
        input_ids = torch.randint(0, self.config.vocab_size, (1, 10))
        with torch.no_grad():  # no autograd graph needed for shape checks
            outputs = self.model(input_ids)

        self.assertIsNotNone(outputs)
        self.assertTrue(hasattr(outputs, "logits"))
        self.assertEqual(tuple(outputs.logits.shape), (1, 10, self.config.vocab_size))

    def test_weight_tying(self):
        """``tie_weights`` makes the LM head share the input-embedding weight."""
        self.model.tie_weights()
        head_weight = self.model.lm_head.weight
        embed_weight = self.model.model.embeddings.weight
        self.assertTrue(torch.equal(head_weight, embed_weight))
        # Equal values alone would also hold for an untied copy; tied weights
        # must share the same underlying storage.
        self.assertEqual(head_weight.data_ptr(), embed_weight.data_ptr())

    def test_auto_model_registration(self):
        """``AutoModelForCausalLM`` resolves the config to ``SapnousT1ForCausalLM``.

        NOTE(review): this relies on the Auto* registration being performed
        elsewhere (not visible in this file) before the test runs.
        """
        model = AutoModelForCausalLM.from_config(self.config)
        self.assertIsInstance(model, SapnousT1ForCausalLM)

    def test_vision_embeddings(self):
        """Image input is prepended/merged, extending the output sequence length."""
        batch_size = 1
        text_len = 10
        pixel_values = torch.randn(batch_size, 3, self.IMAGE_SIZE, self.IMAGE_SIZE)
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, text_len))

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, pixel_values=pixel_values)
        self.assertIsNotNone(outputs)
        self.assertTrue(hasattr(outputs, "logits"))

        # text tokens + one token per image patch + (assumed) one CLS token
        num_patches = (self.IMAGE_SIZE // self.PATCH_SIZE) ** 2
        expected_seq_length = text_len + num_patches + 1
        self.assertEqual(
            tuple(outputs.logits.shape),
            (batch_size, expected_seq_length, self.config.vocab_size),
        )

    def test_attention_mask(self):
        """A padding-style attention mask does not change the logits' shape."""
        batch_size = 2
        seq_length = 15
        input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_length))
        # Integer 0/1 mask, matching what HF tokenizers produce.
        attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
        attention_mask[:, -5:] = 0  # mask out the last 5 tokens

        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        self.assertIsNotNone(outputs)
        self.assertEqual(
            tuple(outputs.logits.shape), (batch_size, seq_length, self.config.vocab_size)
        )

    def test_generation_with_vision(self):
        """Greedy generation with an image returns a (batch, <= max_length) id tensor."""
        max_length = 20
        pixel_values = torch.randn(1, 3, self.IMAGE_SIZE, self.IMAGE_SIZE)
        input_ids = torch.randint(0, self.config.vocab_size, (1, 5))

        outputs = self.model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            max_length=max_length,
            num_beams=1,
        )

        self.assertIsInstance(outputs, torch.Tensor)
        self.assertEqual(outputs.dim(), 2)
        self.assertEqual(outputs.size(0), input_ids.size(0))
        self.assertLessEqual(outputs.size(1), max_length)

if __name__ == "__main__":
    # Allow running this module directly; discovers and runs every TestCase here.
    unittest.main()