48 lines
1.4 KiB
Python
48 lines
1.4 KiB
Python
import torch
import torch.nn as nn
|
|
|
|
class RNN(nn.Module):
    """LSTM-based sequence classifier with a sigmoid output.

    Token ids are embedded, fed through an (optionally stacked, optionally
    bidirectional) LSTM as a packed sequence, and the final hidden state of
    the top LSTM layer is projected by a linear layer and passed through a
    sigmoid.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each embedding vector.
        hidden_dim: Size of the LSTM hidden state per direction.
        output_dim: Number of output units of the dense layer.
        n_layers: Number of stacked LSTM layers.
        bidirectional: Whether the LSTM runs in both directions.
        dropout: Dropout probability applied between stacked LSTM layers.
    """

    # define all the layers used in model
    def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int,
                 output_dim: int, n_layers: int, bidirectional: bool,
                 dropout: float) -> None:
        super().__init__()

        # embedding layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # lstm layer
        # Inter-layer dropout has no effect on a single-layer LSTM and only
        # triggers a UserWarning, so it is disabled in that case.
        self.lstm = nn.LSTM(embedding_dim,
                            hidden_dim,
                            num_layers=n_layers,
                            bidirectional=bidirectional,
                            dropout=dropout if n_layers > 1 else 0.0,
                            batch_first=True)

        # dense layer
        # The classifier input is one hidden state per direction.
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * num_directions, output_dim)

        # activation function
        self.act = nn.Sigmoid()

    def forward(self, text, text_lengths):
        """Return sigmoid-activated predictions for a padded batch.

        Args:
            text: Token ids, shape ``[batch size, sent len]``.
            text_lengths: Unpadded length of every sequence in the batch,
                as a 1-D tensor or a list of ints. Any order is accepted.

        Returns:
            Tensor of shape ``[batch size, output dim]`` with values in
            ``[0, 1]``.
        """
        # text = [batch size, sent len]
        embedded = self.embedding(text)
        # embedded = [batch size, sent len, emb dim]

        # pack_padded_sequence requires tensor lengths to live on the CPU,
        # even when the model and its inputs are on another device.
        if torch.is_tensor(text_lengths):
            text_lengths = text_lengths.cpu()

        # packed sequence; enforce_sorted=False lets callers pass batches
        # that are not pre-sorted by length (the LSTM restores batch order).
        packed_embedded = nn.utils.rnn.pack_padded_sequence(
            embedded, text_lengths, batch_first=True, enforce_sorted=False)

        _, (hidden, _) = self.lstm(packed_embedded)
        # hidden = [num layers * num directions, batch size, hid dim]
        # (batch_first does not apply to the hidden/cell states)

        if self.lstm.bidirectional:
            # concat the final forward and backward hidden state of the
            # top layer
            hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        else:
            # unidirectional: the top layer's final state is the last entry
            hidden = hidden[-1, :, :]
        # hidden = [batch size, hid dim * num directions]

        dense_outputs = self.fc(hidden)

        # Final activation function
        outputs = self.act(dense_outputs)

        return outputs
|