convaiinnovations committed (verified)
Commit b90e0c2 · 1 Parent(s): d7c92d2

Upload 2 files

Files changed (2)
  1. convaicausallm_model.py +179 -0
  2. hindi_embeddings.py +730 -0
convaicausallm_model.py ADDED
@@ -0,0 +1,179 @@
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
from typing import Optional, Tuple

class ConvaiCausalLMConfig(PretrainedConfig):
    model_type = "convaicausallm"

    def __init__(
        self,
        vocab_size=16000,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=16,
        num_key_value_heads=4,
        intermediate_size=3072,
        hidden_act="silu",
        max_position_embeddings=512,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings

class GroupedQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads

        # For MQA/GQA support
        self.num_key_value_groups = self.num_heads // self.num_kv_heads

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)

        # Create causal mask for attention
        max_positions = config.max_position_embeddings
        self.register_buffer(
            "causal_mask",
            torch.triu(torch.ones(max_positions, max_positions) * -1e9, diagonal=1)
        )

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_len, _ = hidden_states.size()

        # Project queries, keys, values
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # Reshape for attention computation
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)     # [b, n_heads, seq, head_dim]
        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)  # [b, n_kv_heads, seq, head_dim]
        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)  # [b, n_kv_heads, seq, head_dim]

        # Handle Multi-Query Attention / Grouped-Query Attention
        if self.num_key_value_groups > 1:
            # Repeat k, v for each query in the group
            k = k.repeat_interleave(self.num_key_value_groups, dim=1)  # [b, n_heads, seq, head_dim]
            v = v.repeat_interleave(self.num_key_value_groups, dim=1)  # [b, n_heads, seq, head_dim]

        # Compute attention scores: [batch, n_heads, seq_len, seq_len]
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)

        # Apply causal mask - only attend to previous tokens
        causal_mask = self.causal_mask[:seq_len, :seq_len]
        attn_scores = attn_scores + causal_mask

        # Apply attention mask if provided
        if attention_mask is not None:
            # attention_mask: [batch, 1, 1, seq_len]
            attn_scores = attn_scores + attention_mask

        # Normalize the attention scores to probabilities
        attn_probs = torch.softmax(attn_scores, dim=-1)

        # Apply attention to values
        context = torch.matmul(attn_probs, v)  # [b, n_heads, seq, head_dim]

        # Reshape back to [batch_size, seq_len, hidden_size]
        context = context.transpose(1, 2).contiguous()
        context = context.view(batch_size, seq_len, -1)

        # Final projection
        output = self.o_proj(context)

        return output

class ConvaiCausalLM(PreTrainedModel):
    config_class = ConvaiCausalLMConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                "self_attn": GroupedQueryAttention(config),
                "mlp": nn.Sequential(
                    nn.Linear(config.hidden_size, config.intermediate_size),
                    nn.SiLU(),
                    nn.Linear(config.intermediate_size, config.hidden_size)
                ),
                "input_layernorm": nn.LayerNorm(config.hidden_size),
                "post_attention_layernorm": nn.LayerNorm(config.hidden_size)
            }) for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _prepare_attention_mask(self, attention_mask, input_shape, device):
        # Prepare masks for attention
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)

        # Make broadcastable shape: [batch, 1, 1, seq_len]
        extended_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Convert to additive mask (0 for valid, -10000 for masked)
        extended_mask = (1.0 - extended_mask) * -10000.0

        return extended_mask

    def forward(self, input_ids, attention_mask=None):
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Prepare attention mask
        if attention_mask is not None:
            attention_mask = self._prepare_attention_mask(
                attention_mask, (batch_size, seq_len), device
            )

        # Get embeddings
        hidden_states = self.embed_tokens(input_ids)

        # Apply each layer
        for layer in self.layers:
            residual = hidden_states

            # First norm and attention
            hidden_states = layer["input_layernorm"](hidden_states)
            hidden_states = layer["self_attn"](hidden_states, attention_mask)
            hidden_states = residual + hidden_states

            # Second norm and MLP
            residual = hidden_states
            hidden_states = layer["post_attention_layernorm"](hidden_states)
            hidden_states = layer["mlp"](hidden_states)
            hidden_states = residual + hidden_states

        # Final norm
        hidden_states = self.norm(hidden_states)

        # Compute logits
        logits = self.lm_head(hidden_states)

        return logits
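
For reference, a minimal usage sketch (not part of the commit) showing how the class above can be exercised. It assumes convaicausallm_model.py is importable from the working directory; the small config values are illustrative only, chosen so the forward pass runs quickly on CPU.

import torch
from convaicausallm_model import ConvaiCausalLM, ConvaiCausalLMConfig

# Illustrative config only: deliberately tiny sizes for a quick smoke test
config = ConvaiCausalLMConfig(
    vocab_size=16000,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    intermediate_size=128,
)
model = ConvaiCausalLM(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 10))  # [batch, seq_len]
attention_mask = torch.ones_like(input_ids)               # 1 = real token, 0 = padding
with torch.no_grad():
    logits = model(input_ids, attention_mask)             # [1, 10, vocab_size]
print(logits.shape)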
hindi_embeddings.py ADDED
@@ -0,0 +1,730 @@
import os
import torch
import json
import numpy as np
from torch import nn
from torch.nn import functional as F
import sentencepiece as spm
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# Tokenizer wrapper class
class SentencePieceTokenizerWrapper:
    def __init__(self, sp_model_path):
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(sp_model_path)
        self.vocab_size = self.sp_model.GetPieceSize()

        # Special token IDs from tokenizer training
        self.pad_token_id = 0
        self.bos_token_id = 1
        self.eos_token_id = 2
        self.unk_token_id = 3

        # Set special tokens
        self.pad_token = "<pad>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token = "<unk>"
        self.mask_token = "<mask>"

    def __call__(self, text, padding=False, truncation=False, max_length=None, return_tensors=None):
        # Handle both string and list inputs
        if isinstance(text, str):
            # Encode a single string
            ids = self.sp_model.EncodeAsIds(text)

            # Handle truncation
            if truncation and max_length and len(ids) > max_length:
                ids = ids[:max_length]

            attention_mask = [1] * len(ids)

            # Handle padding
            if padding and max_length:
                padding_length = max(0, max_length - len(ids))
                ids = ids + [self.pad_token_id] * padding_length
                attention_mask = attention_mask + [0] * padding_length

            result = {
                'input_ids': ids,
                'attention_mask': attention_mask
            }

            # Convert to tensors if requested
            if return_tensors == 'pt':
                result = {k: torch.tensor([v]) for k, v in result.items()}

            return result

        # Process a batch of texts
        batch_encoded = [self.sp_model.EncodeAsIds(t) for t in text]

        # Apply truncation if needed
        if truncation and max_length:
            batch_encoded = [ids[:max_length] for ids in batch_encoded]

        # Create attention masks
        batch_attention_mask = [[1] * len(ids) for ids in batch_encoded]

        # Apply padding if needed
        if padding:
            if max_length:
                max_len = max_length
            else:
                max_len = max(len(ids) for ids in batch_encoded)

            # Pad sequences to max_len
            batch_encoded = [ids + [self.pad_token_id] * (max_len - len(ids)) for ids in batch_encoded]
            batch_attention_mask = [mask + [0] * (max_len - len(mask)) for mask in batch_attention_mask]

        result = {
            'input_ids': batch_encoded,
            'attention_mask': batch_attention_mask
        }

        # Convert to tensors if requested
        if return_tensors == 'pt':
            result = {k: torch.tensor(v) for k, v in result.items()}

        return result

# Model architecture components
class MultiHeadAttention(nn.Module):
    """Multi-headed attention mechanism"""
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config["num_attention_heads"]
        self.attention_head_size = config["hidden_size"] // config["num_attention_heads"]
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Query, Key, Value projections
        self.query = nn.Linear(config["hidden_size"], self.all_head_size)
        self.key = nn.Linear(config["hidden_size"], self.all_head_size)
        self.value = nn.Linear(config["hidden_size"], self.all_head_size)

        # Output projection
        self.output = nn.Sequential(
            nn.Linear(self.all_head_size, config["hidden_size"]),
            nn.Dropout(config["attention_probs_dropout_prob"])
        )

        # Simplified relative position bias
        self.max_position_embeddings = config["max_position_embeddings"]
        self.relative_attention_bias = nn.Embedding(
            2 * config["max_position_embeddings"] - 1,
            config["num_attention_heads"]
        )

    def transpose_for_scores(self, x):
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        batch_size, seq_length = hidden_states.size()[:2]

        # Project inputs to queries, keys, and values
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Take the dot product between query and key to get the raw attention scores
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # Generate relative position matrix
        position_ids = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device)
        relative_position = position_ids.unsqueeze(1) - position_ids.unsqueeze(0)  # [seq_len, seq_len]
        # Shift values to be >= 0
        relative_position = relative_position + self.max_position_embeddings - 1
        # Ensure indices are within bounds
        relative_position = torch.clamp(relative_position, 0, 2 * self.max_position_embeddings - 2)

        # Get relative position embeddings
        rel_attn_bias = self.relative_attention_bias(relative_position)  # [seq_len, seq_len, num_heads]

        # Reshape to add to attention heads [1, num_heads, seq_len, seq_len]
        rel_attn_bias = rel_attn_bias.permute(2, 0, 1).unsqueeze(0)

        # Add to attention scores - now dimensions will match
        attention_scores = attention_scores + rel_attn_bias

        # Scale attention scores
        attention_scores = attention_scores / (self.attention_head_size ** 0.5)

        # Apply attention mask
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities
        attention_probs = F.softmax(attention_scores, dim=-1)

        # Apply dropout
        attention_probs = F.dropout(attention_probs, p=0.1, training=self.training)

        # Apply attention to values
        context_layer = torch.matmul(attention_probs, value_layer)

        # Reshape back to [batch_size, seq_length, hidden_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_shape)

        # Final output projection
        output = self.output(context_layer)

        return output

class EnhancedTransformerLayer(nn.Module):
    """Advanced transformer layer with pre-layer norm and enhanced attention"""
    def __init__(self, config):
        super().__init__()
        self.attention_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])
        self.attention = MultiHeadAttention(config)

        self.ffn_pre_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])

        # Feed-forward network
        self.ffn = nn.Sequential(
            nn.Linear(config["hidden_size"], config["intermediate_size"]),
            nn.GELU(),
            nn.Dropout(config["hidden_dropout_prob"]),
            nn.Linear(config["intermediate_size"], config["hidden_size"]),
            nn.Dropout(config["hidden_dropout_prob"])
        )

    def forward(self, hidden_states, attention_mask=None):
        # Pre-layer norm for attention
        attn_norm_hidden = self.attention_pre_norm(hidden_states)

        # Self-attention
        attention_output = self.attention(attn_norm_hidden, attention_mask)

        # Residual connection
        hidden_states = hidden_states + attention_output

        # Pre-layer norm for feed-forward
        ffn_norm_hidden = self.ffn_pre_norm(hidden_states)

        # Feed-forward
        ffn_output = self.ffn(ffn_norm_hidden)

        # Residual connection
        hidden_states = hidden_states + ffn_output

        return hidden_states

class AdvancedTransformerModel(nn.Module):
    """Advanced Transformer model for inference"""

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Embeddings
        self.word_embeddings = nn.Embedding(
            config["vocab_size"],
            config["hidden_size"],
            padding_idx=config["pad_token_id"]
        )

        # Position embeddings
        self.position_embeddings = nn.Embedding(config["max_position_embeddings"], config["hidden_size"])

        # Embedding dropout
        self.embedding_dropout = nn.Dropout(config["hidden_dropout_prob"])

        # Transformer layers
        self.layers = nn.ModuleList([
            EnhancedTransformerLayer(config) for _ in range(config["num_hidden_layers"])
        ])

        # Final layer norm
        self.final_layer_norm = nn.LayerNorm(config["hidden_size"], eps=config["layer_norm_eps"])

    def forward(self, input_ids, attention_mask=None):
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape

        # Get position ids
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Get embeddings
        word_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)

        # Sum embeddings
        embeddings = word_embeds + position_embeds

        # Apply dropout
        embeddings = self.embedding_dropout(embeddings)

        # Default attention mask
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=input_ids.device)

        # Extended attention mask for transformer layers (1 for tokens to attend to, 0 for masked tokens)
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Apply transformer layers
        hidden_states = embeddings
        for layer in self.layers:
            hidden_states = layer(hidden_states, extended_attention_mask)

        # Final layer norm
        hidden_states = self.final_layer_norm(hidden_states)

        return hidden_states

class AdvancedPooling(nn.Module):
    """Advanced pooling module supporting multiple pooling strategies"""
    def __init__(self, config):
        super().__init__()
        self.pooling_mode = config["pooling_mode"]  # 'mean', 'max', 'cls', 'attention'
        self.hidden_size = config["hidden_size"]

        # For attention pooling
        if self.pooling_mode == 'attention':
            self.attention_weights = nn.Linear(config["hidden_size"], 1)

        # For weighted pooling
        elif self.pooling_mode == 'weighted':
            self.weight_layer = nn.Linear(config["hidden_size"], 1)

    def forward(self, token_embeddings, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(token_embeddings[:, :, 0])

        mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()

        if self.pooling_mode == 'cls':
            # Use [CLS] token (first token)
            pooled = token_embeddings[:, 0]

        elif self.pooling_mode == 'max':
            # Max pooling
            token_embeddings = token_embeddings.clone()
            # Set padding tokens to large negative value to exclude them from max
            token_embeddings[mask_expanded == 0] = -1e9
            pooled = torch.max(token_embeddings, dim=1)[0]

        elif self.pooling_mode == 'attention':
            # Attention pooling
            weights = self.attention_weights(token_embeddings).squeeze(-1)
            # Mask out padding tokens
            weights = weights.masked_fill(attention_mask == 0, -1e9)
            weights = F.softmax(weights, dim=1).unsqueeze(-1)
            pooled = torch.sum(token_embeddings * weights, dim=1)

        elif self.pooling_mode == 'weighted':
            # Weighted average pooling
            weights = torch.sigmoid(self.weight_layer(token_embeddings)).squeeze(-1)
            # Apply mask
            weights = weights * attention_mask
            # Normalize weights
            sum_weights = torch.sum(weights, dim=1, keepdim=True)
            sum_weights = torch.clamp(sum_weights, min=1e-9)
            weights = weights / sum_weights
            # Apply weights
            pooled = torch.sum(token_embeddings * weights.unsqueeze(-1), dim=1)

        else:  # Default to mean pooling
            # Mean pooling
            sum_embeddings = torch.sum(token_embeddings * mask_expanded, dim=1)
            sum_mask = torch.clamp(mask_expanded.sum(1), min=1e-9)
            pooled = sum_embeddings / sum_mask

        # L2 normalize
        pooled = F.normalize(pooled, p=2, dim=1)

        return pooled

class SentenceEmbeddingModel(nn.Module):
    """Complete sentence embedding model for inference"""
    def __init__(self, config):
        super(SentenceEmbeddingModel, self).__init__()
        self.config = config

        # Create transformer model
        self.transformer = AdvancedTransformerModel(config)

        # Create pooling module
        self.pooling = AdvancedPooling(config)

        # Build projection module if needed
        if "projection_dim" in config and config["projection_dim"] > 0:
            self.use_projection = True
            self.projection = nn.Sequential(
                nn.Linear(config["hidden_size"], config["hidden_size"]),
                nn.GELU(),
                nn.Linear(config["hidden_size"], config["projection_dim"]),
                nn.LayerNorm(config["projection_dim"], eps=config["layer_norm_eps"])
            )
        else:
            self.use_projection = False

    def forward(self, input_ids, attention_mask=None):
        # Get token embeddings from transformer
        token_embeddings = self.transformer(input_ids, attention_mask)

        # Pool token embeddings
        pooled_output = self.pooling(token_embeddings, attention_mask)

        # Apply projection if enabled
        if self.use_projection:
            pooled_output = self.projection(pooled_output)
            pooled_output = F.normalize(pooled_output, p=2, dim=1)

        return pooled_output

class HindiEmbedder:
    def __init__(self, model_path="/home/ubuntu/output/hindi-embeddings-custom-tokenizer/final"):
        """
        Initialize the Hindi sentence embedder.

        Args:
            model_path: Path to the model directory
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Load tokenizer - look for it in the model directory
        tokenizer_path = os.path.join(model_path, "tokenizer.model")

        if not os.path.exists(tokenizer_path):
            raise FileNotFoundError(f"Could not find tokenizer at {tokenizer_path}")

        self.tokenizer = SentencePieceTokenizerWrapper(tokenizer_path)
        print(f"Loaded tokenizer from {tokenizer_path} with vocabulary size: {self.tokenizer.vocab_size}")

        # Load model config
        config_path = os.path.join(model_path, "config.json")
        with open(config_path, "r") as f:
            self.config = json.load(f)
        print(f"Loaded model config with hidden_size={self.config['hidden_size']}")

        # Load model
        model_pt_path = os.path.join(model_path, "embedding_model.pt")

        try:
            # Support both PyTorch 2.6+ and older versions
            try:
                checkpoint = torch.load(model_pt_path, map_location=self.device, weights_only=False)
                print("Loaded model using PyTorch 2.6+ style loading")
            except TypeError:
                checkpoint = torch.load(model_pt_path, map_location=self.device)
                print("Loaded model using older PyTorch style loading")

            # Create model
            self.model = SentenceEmbeddingModel(self.config)

            # Load state dict
            if "model_state_dict" in checkpoint:
                state_dict = checkpoint["model_state_dict"]
            else:
                state_dict = checkpoint

            missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)
            print(f"Loaded model with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys")

            # Move to device
            self.model.to(self.device)
            self.model.eval()
            print("Model loaded successfully and placed in evaluation mode")

        except Exception as e:
            print(f"Error loading model: {e}")
            raise RuntimeError(f"Failed to load the model: {e}")

    def encode(self, sentences, batch_size=32, normalize=True):
        """
        Encode sentences to embeddings.

        Args:
            sentences: A string or list of strings to encode
            batch_size: Batch size for encoding
            normalize: Whether to normalize the embeddings

        Returns:
            Numpy array of embeddings
        """
        # Handle single sentence
        if isinstance(sentences, str):
            sentences = [sentences]

        all_embeddings = []

        # Process in batches
        with torch.no_grad():
            for i in range(0, len(sentences), batch_size):
                batch = sentences[i:i+batch_size]

                # Tokenize
                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=self.config.get("max_position_embeddings", 128),
                    return_tensors="pt"
                )

                # Move to device
                input_ids = inputs["input_ids"].to(self.device)
                attention_mask = inputs["attention_mask"].to(self.device)

                # Get embeddings
                embeddings = self.model(input_ids, attention_mask)

                # Move to CPU and convert to numpy
                all_embeddings.append(embeddings.cpu().numpy())

        # Concatenate all embeddings
        all_embeddings = np.vstack(all_embeddings)

        # Normalize if requested
        if normalize:
            all_embeddings = all_embeddings / np.linalg.norm(all_embeddings, axis=1, keepdims=True)

        return all_embeddings

    def compute_similarity(self, texts1, texts2=None):
        """
        Compute cosine similarity between texts.

        Args:
            texts1: First set of texts
            texts2: Second set of texts. If None, compute similarity matrix within texts1.

        Returns:
            Similarity scores
        """
        # Convert single strings to lists for consistent handling
        if isinstance(texts1, str):
            texts1 = [texts1]

        if texts2 is not None and isinstance(texts2, str):
            texts2 = [texts2]

        embeddings1 = self.encode(texts1)

        if texts2 is None:
            # Compute similarity matrix within texts1
            similarities = cosine_similarity(embeddings1)
            return similarities
        else:
            # Compute similarity between texts1 and texts2
            embeddings2 = self.encode(texts2)

            if len(texts1) == len(texts2):
                # Compute pairwise similarity when the number of texts match
                similarities = np.array([
                    cosine_similarity([e1], [e2])[0][0]
                    for e1, e2 in zip(embeddings1, embeddings2)
                ])

                # If there's just one pair, return a scalar
                if len(similarities) == 1:
                    return similarities[0]
                return similarities
            else:
                # Return full similarity matrix
                return cosine_similarity(embeddings1, embeddings2)

    def search(self, query, documents, top_k=5):
        """
        Search for similar documents to a query.

        Args:
            query: The query text
            documents: List of documents to search
            top_k: Number of top results to return

        Returns:
            List of dictionaries with document and score
        """
        # Get embeddings
        query_embedding = self.encode([query])[0]
        document_embeddings = self.encode(documents)

        # Compute similarities
        similarities = np.dot(document_embeddings, query_embedding)

        # Get top indices
        top_indices = np.argsort(similarities)[-top_k:][::-1]

        # Return results
        results = []
        for idx in top_indices:
            results.append({
                "document": documents[idx],
                "score": float(similarities[idx])
            })

        return results

    def evaluate_similarity_samples(self):
        """Evaluate the model on some standard similarity examples for Hindi"""
        test_pairs = [
            (
                "मुझे हिंदी में पढ़ना बहुत पसंद है।",
                "मैं हिंदी किताबें बहुत पसंद करता हूँ।"
            ),
            (
                "आज मौसम बहुत अच्छा है।",
                "आज बारिश हो रही है।"
            ),
            (
                "भारत एक विशाल देश है।",
                "भारत में कई भाषाएँ बोली जाती हैं।"
            ),
            (
                "कंप्यूटर विज्ञान एक रोचक विषय है।",
                "मैं कंप्यूटर साइंस का छात्र हूँ।"
            ),
            (
                "मैं रोज सुबह योग करता हूँ।",
                "स्वस्थ रहने के लिए व्यायाम जरूरी है।"
            ),
            # Add contrasting pairs to test discrimination
            (
                "मुझे हिंदी में पढ़ना बहुत पसंद है।",
                "क्रिकेट भारत में सबसे लोकप्रिय खेल है।"
            ),
            (
                "आज मौसम बहुत अच्छा है।",
                "भारतीय व्यंजन दुनिया भर में मशहूर हैं।"
            ),
            (
                "कंप्यूटर विज्ञान एक रोचक विषय है।",
                "हिमालय दुनिया का सबसे ऊंचा पर्वत है।"
            )
        ]

        print("Evaluating model on standard similarity samples:")
        for i, (text1, text2) in enumerate(test_pairs):
            # compute_similarity returns a scalar for a single sentence pair
            similarity = self.compute_similarity(text1, text2)
            print(f"\nPair {i+1}:")
            print(f"  Sentence 1: {text1}")
            print(f"  Sentence 2: {text2}")
            print(f"  Similarity: {similarity:.4f}")

        return

    def visualize_embeddings(self, sentences, labels=None, output_path="hindi_embeddings_visualization.png"):
        """
        Create a t-SNE visualization of the embeddings.

        Args:
            sentences: List of sentences to visualize
            labels: Optional list of labels for the points
            output_path: Path to save the visualization

        Returns:
            Path to the saved visualization
        """
        # Encode sentences
        embeddings = self.encode(sentences)

        # Apply t-SNE
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings)-1))
        reduced_embeddings = tsne.fit_transform(embeddings)

        # Create plot
        plt.figure(figsize=(12, 10))

        # Plot points
        scatter = plt.scatter(
            reduced_embeddings[:, 0],
            reduced_embeddings[:, 1],
            c=range(len(reduced_embeddings)),
            cmap='viridis',
            alpha=0.8,
            s=100
        )

        # Add labels if provided
        if labels:
            for i, label in enumerate(labels):
                plt.annotate(
                    label,
                    (reduced_embeddings[i, 0], reduced_embeddings[i, 1]),
                    fontsize=10,
                    alpha=0.7
                )

        plt.title("t-SNE Visualization of Hindi Sentence Embeddings", fontsize=16)
        plt.xlabel("Dimension 1", fontsize=12)
        plt.ylabel("Dimension 2", fontsize=12)
        plt.colorbar(scatter, label="Sentence Index")
        plt.grid(alpha=0.3)

        # Save the figure
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close()

        print(f"Visualization saved to {output_path}")
        return output_path

def main():
    # Create embedder
    embedder = HindiEmbedder()

    # Run sample evaluation
    embedder.evaluate_similarity_samples()

    # Example of semantic search
    print("\nSemantic Search Example:")
    query = "भारत की संस्कृति"
    documents = [
        "भारतीय संस्कृति दुनिया की सबसे प्राचीन संस्कृतियों में से एक है।",
        "भारत की आबादी 1.3 अरब से अधिक है।",
        "हिमालय पर्वत श्रृंखला भारत के उत्तर में स्थित है।",
        "भारतीय व्यंजन में मसालों का प्रयोग किया जाता है।",
        "भारत में 22 आधिकारिक भाषाएँ हैं।",
        "संस्कृति लोगों के रहन-सहन का तरीका है।",
        "भारत के विभिन्न राज्यों की अपनी अलग संस्कृति है।",
        "रामायण और महाभारत भारतीय संस्कृति के महत्वपूर्ण हिस्से हैं।",
    ]

    results = embedder.search(query, documents)

    print(f"Query: {query}")
    print("Top results:")
    for i, result in enumerate(results):
        print(f"{i+1}. Score: {result['score']:.4f}")
        print(f"   {result['document']}")

    # Create visualization example
    print("\nCreating embedding visualization...")
    visualization_sentences = [
        "मुझे हिंदी में पढ़ना बहुत पसंद है।",
        "मैं हिंदी किताबें बहुत पसंद करता हूँ।",
        "आज मौसम बहुत अच्छा है।",
        "आज बारिश हो रही है।",
        "भारत एक विशाल देश है।",
        "भारत में कई भाषाएँ बोली जाती हैं।",
        "कंप्यूटर विज्ञान एक रोचक विषय है।",
        "मैं कंप्यूटर साइंस का छात्र हूँ।",
        "क्रिकेट भारत में सबसे लोकप्रिय खेल है।",
        "भारतीय व्यंजन दुनिया भर में मशहूर हैं।",
        "हिमालय दुनिया का सबसे ऊंचा पर्वत है।",
        "गंगा भारत की सबसे पवित्र नदी है।",
        "दिल्ली भारत की राजधानी है।",
        "मुंबई भारत का आर्थिक केंद्र है।",
        "तमिल, तेलुगु, कन्नड़ और मलयालम दक्षिण भारत की प्रमुख भाषाएँ हैं।"
    ]

    labels = ["पढ़ना", "किताबें", "मौसम", "बारिश", "भारत", "भाषाएँ", "कंप्यूटर",
              "छात्र", "क्रिकेट", "व्यंजन", "हिमालय", "गंगा", "दिल्ली", "मुंबई", "भाषाएँ"]

    embedder.visualize_embeddings(visualization_sentences, labels)

if __name__ == "__main__":
    main()
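
The main() above already demonstrates semantic search and visualization. As a further minimal sketch (not part of the commit), the embedder can also be used directly for encoding and pairwise similarity; the model_path below is a hypothetical local directory that would need to contain tokenizer.model, config.json, and embedding_model.pt.

from hindi_embeddings import HindiEmbedder

# Hypothetical path; the class otherwise defaults to the absolute path hard-coded above
embedder = HindiEmbedder(model_path="./hindi-embeddings/final")

sentences = [
    "मुझे हिंदी में पढ़ना बहुत पसंद है।",
    "मैं हिंदी किताबें बहुत पसंद करता हूँ।",
]
embeddings = embedder.encode(sentences)  # numpy array, one row per sentence
score = embedder.compute_similarity(sentences[0], sentences[1])  # scalar for a single pair
print(embeddings.shape, float(score))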