# RabiaSufian's picture
# Update model.py
# 331b272 verified
import torch.nn as nn
import torch
class LSTMClassifier(nn.Module):
    """Classify a sequence from the final hidden state of an LSTM.

    The input is run through an (optionally stacked / bidirectional) LSTM and
    the final hidden state of the top layer is projected to ``num_classes``
    raw scores by a single linear layer. No softmax is applied.

    Args:
        input_size: Number of features per time step.
        hidden_size: Hidden-state width of each LSTM direction.
        num_layers: Number of stacked LSTM layers.
        bidirectional: If true, run the LSTM in both directions and
            concatenate the two final hidden states of the top layer.
        dropout: Dropout probability between stacked LSTM layers. Ignored
            (forced to 0.0) when ``num_layers`` is 1.
        num_classes: Number of output scores.
    """

    def __init__(self, input_size=1, hidden_size=64, num_layers=1,
                 bidirectional=True, dropout=0.0, num_classes=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers and warns
            # when a non-zero value is given for a single-layer network.
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=bidirectional,
        )

        # A bidirectional LSTM yields one final hidden state per direction,
        # so the classifier head sees twice as many features.
        direction_factor = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_size * direction_factor, num_classes)

    def forward(self, x):
        """Return class scores for ``x``.

        Args:
            x: Input accepted by ``nn.LSTM`` with ``batch_first=True``:
                ``(batch, seq_len, input_size)``, or unbatched
                ``(seq_len, input_size)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``, or ``(num_classes,)``
            for unbatched input.
        """
        _, (hn, _) = self.lstm(x)

        if self.bidirectional:
            # The last two entries of hn are the top layer's forward and
            # backward final states. Concatenate on the feature (last) axis
            # so both batched and unbatched inputs are handled.
            combined = torch.cat((hn[-2], hn[-1]), dim=-1)
        else:
            combined = hn[-1]

        return self.fc(combined)