Get model working (basically)
This commit is contained in:
parent
5dd850d1cb
commit
fe7870e9d4
@ -13,6 +13,7 @@ import csv
|
|||||||
import random
|
import random
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import itertools
|
||||||
#from pandarallel import pandarallel
|
#from pandarallel import pandarallel
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
# torch
|
# torch
|
||||||
@ -23,11 +24,14 @@ from torchtext.vocab import build_vocab_from_iterator
|
|||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
story_num = 40 # XXX None for all
|
from models.rnn import RNN
|
||||||
|
|
||||||
|
story_num = 64 # XXX None for all
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
EPOCHS = 10 # epoch
|
EPOCHS = 10 # epoch
|
||||||
LR = 5 # learning rate
|
#LR = 5 # learning rate
|
||||||
|
LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
|
||||||
BATCH_SIZE = 64 # batch size for training
|
BATCH_SIZE = 64 # batch size for training
|
||||||
|
|
||||||
def read_csv(input_csv, rows=None, verbose=0):
|
def read_csv(input_csv, rows=None, verbose=0):
|
||||||
@ -135,7 +139,7 @@ class TextCategoriesDataset(Dataset):
|
|||||||
self.cats_vocab = build_vocab_from_iterator(
|
self.cats_vocab = build_vocab_from_iterator(
|
||||||
[self.catTokens(cats) for i, cats in self.df[cats_column].items()],
|
[self.catTokens(cats) for i, cats in self.df[cats_column].items()],
|
||||||
min_freq=1,
|
min_freq=1,
|
||||||
specials= self.itos.values(),
|
specials=['<unk>'],
|
||||||
special_first=True
|
special_first=True
|
||||||
)
|
)
|
||||||
self.cats_vocab.set_default_index(self.cats_vocab['<unk>'])
|
self.cats_vocab.set_default_index(self.cats_vocab['<unk>'])
|
||||||
@ -162,8 +166,8 @@ class TextCategoriesDataset(Dataset):
|
|||||||
|
|
||||||
# Numericalise by applying transforms
|
# Numericalise by applying transforms
|
||||||
return (
|
return (
|
||||||
self.getTransform(self.text_vocab)(self.textTokens(text)),
|
self.getTransform(self.text_vocab, "text")(self.textTokens(text)),
|
||||||
self.getTransform(self.cats_vocab)(self.catTokens(cats)),
|
self.getTransform(self.cats_vocab, "cats")(self.catTokens(cats)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -178,26 +182,32 @@ class TextCategoriesDataset(Dataset):
|
|||||||
elif isinstance(cats, list):
|
elif isinstance(cats, list):
|
||||||
return [cat for cat in cats]
|
return [cat for cat in cats]
|
||||||
|
|
||||||
def getTransform(self, vocab):
|
def getTransform(self, vocab, vType):
|
||||||
'''
|
'''
|
||||||
Create transforms based on given vocabulary. The returned transform
|
Create transforms based on given vocabulary. The returned transform
|
||||||
is applied to a sequence of tokens.
|
is applied to a sequence of tokens.
|
||||||
'''
|
'''
|
||||||
|
if vType == "text":
|
||||||
return T.Sequential(
|
return T.Sequential(
|
||||||
# converts the sentences to indices based on given vocabulary
|
# converts the sentences to indices based on given vocabulary
|
||||||
T.VocabTransform(vocab=vocab),
|
T.VocabTransform(vocab=vocab),
|
||||||
# Add <sos> at beginning of each sentence. 1 because the index
|
# Add <sos> at beginning of each sentence. 1 because the index
|
||||||
# for <sos> in vocabulary is 1 as seen in previous section
|
# for <sos> in vocabulary is 1 as seen in previous section
|
||||||
T.AddToken(1, begin=True),
|
T.AddToken(self.text_vocab['<sos>'], begin=True),
|
||||||
# Add <eos> at beginning of each sentence. 2 because the index
|
# Add <eos> at end of each sentence. 2 because the index
|
||||||
# for <eos> in vocabulary is 2 as seen in previous section
|
# for <eos> in vocabulary is 2 as seen in previous section
|
||||||
T.AddToken(2, begin=False)
|
T.AddToken(self.text_vocab['<eos>'], begin=False)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return T.Sequential(
|
||||||
|
# converts the sentences to indices based on given vocabulary
|
||||||
|
T.VocabTransform(vocab=vocab),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Now that we have a dataset, let's create dataloader,
|
Now that we have a dataset, let's create a dataloader callback;
|
||||||
which can batch, shuffle, and load the data in parallel
|
the dataloader can batch, shuffle, and load the data in parallel
|
||||||
'''
|
'''
|
||||||
|
|
||||||
class CollateBatch:
|
class CollateBatch:
|
||||||
@ -207,41 +217,105 @@ class CollateBatch:
|
|||||||
which returns a tensor
|
which returns a tensor
|
||||||
'''
|
'''
|
||||||
def __init__(self, pad_idx):
|
def __init__(self, pad_idx):
|
||||||
|
'''
|
||||||
|
pad_idx (int): the index of the "<pad>" token in the vocabulary.
|
||||||
|
'''
|
||||||
self.pad_idx = pad_idx
|
self.pad_idx = pad_idx
|
||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
# T.ToTensor(0) returns a transform that converts the sequence
|
'''
|
||||||
# to a torch.tensor and also applies padding.
|
batch: a list of tuples with (text, cats), each of which
|
||||||
#
|
is a list of tokens
|
||||||
# pad_idx is passed to the constructor to specify the index of
|
'''
|
||||||
# the "<pad>" token in the vocabulary.
|
batch_text, batch_cats = zip(*batch)
|
||||||
|
#for i in range(len(batch)):
|
||||||
|
# print(batch[i])
|
||||||
|
#max_text_len = len(max(batch_text, key=len))
|
||||||
|
#max_cats_len = len(max(batch_cats, key=len))
|
||||||
|
|
||||||
|
#text_tensor = T.ToTensor(self.pad_idx)(batch_text)
|
||||||
|
#cats_tensor = T.ToTensor(self.pad_idx)(batch_cats)
|
||||||
|
|
||||||
|
# Pad text to the longest
|
||||||
|
text_tensor = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
[torch.LongTensor(s) for s in batch_text],
|
||||||
|
batch_first=True, padding_value=self.pad_idx
|
||||||
|
)
|
||||||
|
text_lengths = torch.tensor([t.shape[0] for t in text_tensor])
|
||||||
|
|
||||||
|
#cats_tensor = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
# [torch.LongTensor(s) for s in batch_cats],
|
||||||
|
# batch_first=True, padding_value=self.pad_idx
|
||||||
|
#)
|
||||||
|
#cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
|
||||||
|
|
||||||
|
# Pad cats_tensor to all possible categories
|
||||||
|
# TODO will this be necessary with larger training sets, that should
|
||||||
|
# encompass all categories? Best to be safe...
|
||||||
|
all_cats = list(set(itertools.chain(*batch_cats)))
|
||||||
|
num_cats = len(all_cats)
|
||||||
|
# if there's no 0, there was no <unk>, so increment to allow for it to be a possible category
|
||||||
|
if 0 not in all_cats:
|
||||||
|
num_cats += 1
|
||||||
|
cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).long()
|
||||||
|
cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
|
||||||
|
for idx, (c, clen) in enumerate(zip(batch_cats, cats_lengths)):
|
||||||
|
cats_tensor[idx, :clen] = torch.LongTensor(c)
|
||||||
|
|
||||||
|
# XXX why??
|
||||||
|
## SORT YOUR TENSORS BY LENGTH!
|
||||||
|
text_lengths, perm_idx = text_lengths.sort(0, descending=True)
|
||||||
|
text_tensor = text_tensor[perm_idx]
|
||||||
|
cats_tensor = cats_tensor[perm_idx]
|
||||||
|
|
||||||
|
#print(text_tensor)
|
||||||
|
#print("text shape:", text_tensor.shape)
|
||||||
|
#print(cats_tensor)
|
||||||
|
#print("cats shape:", cats_tensor.shape)
|
||||||
|
#print(text_lengths)
|
||||||
|
#print("text_lengths shape:", text_lengths.shape)
|
||||||
|
|
||||||
|
#sys.exit(0)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
T.ToTensor(self.pad_idx)(list(batch[0])),
|
text_tensor,
|
||||||
T.ToTensor(self.pad_idx)(list(batch[1])),
|
cats_tensor,
|
||||||
|
text_lengths,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TextClassificationModel(nn.Module):
|
def train(dataloader, model, optimizer, criterion):
|
||||||
def __init__(self, input_size, output_size, verbose):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def train(dataloader):
|
|
||||||
model.train()
|
model.train()
|
||||||
total_acc, total_count = 0, 0
|
total_acc, total_count = 0, 0
|
||||||
log_interval = 500
|
log_interval = 500
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
for idx, (label, text) in enumerate(dataloader):
|
for idx, (text, cats, text_lengths) in enumerate(dataloader):
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
predicted_label = model(text)
|
|
||||||
loss = criterion(predicted_label, label)
|
print("text_lengths shape", text_lengths.shape)
|
||||||
|
print("input shape", text.shape)
|
||||||
|
print("target", cats)
|
||||||
|
print("target shape", cats.shape)
|
||||||
|
|
||||||
|
output = model(text, text_lengths)
|
||||||
|
print("output", output)
|
||||||
|
print("output shape", output.shape)
|
||||||
|
# reshape output and target for cross entropy loss
|
||||||
|
# output = output.reshape(output.size(0)*output.size(1), -1) # (batch * seq_len x classes)
|
||||||
|
# cats = cats.reshape(-1) # (batch * seq_len), class index
|
||||||
|
# print("output", output)
|
||||||
|
# print("output shape", output.shape)
|
||||||
|
# print("target shape", cats.shape)
|
||||||
|
# print()
|
||||||
|
|
||||||
|
loss = criterion(input=output, target=cats)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
total_acc += (predicted_label.argmax(1) == label).sum().item()
|
total_acc += (predicted_label.argmax(1) == label).sum().item()
|
||||||
total_count += label.size(0)
|
total_count += label.size(0)
|
||||||
if idx % log_interval == 0 and idx > 0:
|
if idx % log_interval == 0 and idx > 0:
|
||||||
@ -256,7 +330,7 @@ def train(dataloader):
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
def evaluate(dataloader):
|
def evaluate(dataloader, model, criterion):
|
||||||
model.eval()
|
model.eval()
|
||||||
total_acc, total_count = 0, 0
|
total_acc, total_count = 0, 0
|
||||||
|
|
||||||
@ -324,7 +398,8 @@ def main():
|
|||||||
lang_column="language",
|
lang_column="language",
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
)
|
)
|
||||||
#print(dataset[2])
|
#for text, cat in enumerate(train_dataset):
|
||||||
|
# print(text, cat)
|
||||||
#for text, cat in enumerate(valid_dataset):
|
#for text, cat in enumerate(valid_dataset):
|
||||||
# print(text, cat)
|
# print(text, cat)
|
||||||
#sys.exit(0)
|
#sys.exit(0)
|
||||||
@ -361,24 +436,59 @@ def main():
|
|||||||
)
|
)
|
||||||
#for i_batch, sample_batched in enumerate(dataloader):
|
#for i_batch, sample_batched in enumerate(dataloader):
|
||||||
# print(i_batch, sample_batched[0], sample_batched[1])
|
# print(i_batch, sample_batched[0], sample_batched[1])
|
||||||
|
#for i_batch, sample_batched in enumerate(train_dataloader):
|
||||||
|
#print(i_batch, sample_batched[0], sample_batched[1])
|
||||||
|
#print(i_batch)
|
||||||
|
#print("batch elements:")
|
||||||
|
#for i in sample_batched:
|
||||||
|
# print(i)
|
||||||
|
# print(i.shape)
|
||||||
|
# print("\n")
|
||||||
#sys.exit(0)
|
#sys.exit(0)
|
||||||
|
|
||||||
num_class = len(set([cats for key, cats, text, lang in train_data.values]))
|
|
||||||
input_size = len(train_dataset.text_vocab)
|
input_size = len(train_dataset.text_vocab)
|
||||||
output_size = len(train_dataset.cats_vocab)
|
output_size = len(train_dataset.cats_vocab) # every output item is the likelihood of a particular category
|
||||||
emsize = 64
|
|
||||||
model = TextClassificationModel(input_size, output_size, args.verbose).to(device)
|
|
||||||
|
|
||||||
|
embed = torch.empty(input_size, len(train_dataset)) # tokens per sample x samples
|
||||||
|
embedding_size = embed.size(1) # was 64 (should be: samples)
|
||||||
|
num_layers = 2 # 2-3 layers should be enough for LTSM
|
||||||
|
hidden_size = 128 # hidden size of rnn module, should be tweaked manually
|
||||||
|
mean_seq = True # use mean of rnn output
|
||||||
|
weight_decay = 1e-4 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
|
||||||
|
|
||||||
|
#for i in train_dataset.text_vocab.get_itos():
|
||||||
|
# print(i)
|
||||||
|
print("input_size: ", input_size)
|
||||||
|
print("output_size:", output_size)
|
||||||
|
print("embed shape:", embed.shape)
|
||||||
|
print("embedding_size:", embedding_size, " (that is, number of samples)")
|
||||||
|
|
||||||
|
model = RNN(
|
||||||
|
#rnn_model='GRU',
|
||||||
|
rnn_model='LSTM',
|
||||||
|
vocab_size=input_size,
|
||||||
|
embed_size=embedding_size,
|
||||||
|
num_output=output_size,
|
||||||
|
use_last=(not mean_seq),
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
embedding_tensor=embed,
|
||||||
|
num_layers=num_layers,
|
||||||
|
batch_first=True
|
||||||
|
)
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
# optimizer and loss
|
||||||
|
#optimizer = torch.optim.SGD(model.parameters(), lr=LR)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=weight_decay)
|
||||||
|
print(criterion)
|
||||||
|
print(optimizer)
|
||||||
|
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
|
||||||
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
|
|
||||||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
|
|
||||||
total_accu = None
|
total_accu = None
|
||||||
|
|
||||||
for epoch in range(1, EPOCHS + 1):
|
for epoch in range(1, EPOCHS + 1):
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
train(train_dataloader)
|
train(train_dataloader, model, optimizer, criterion)
|
||||||
accu_val = evaluate(valid_dataloader)
|
accu_val = evaluate(valid_dataloader, model, criterion)
|
||||||
if total_accu is not None and total_accu > accu_val:
|
if total_accu is not None and total_accu > accu_val:
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
else:
|
else:
|
||||||
|
47
africat/models/classifier.py
Normal file
47
africat/models/classifier.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class RNN(nn.Module):
|
||||||
|
#define all the layers used in model
|
||||||
|
def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
|
||||||
|
n_layers, bidirectional, dropout):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
#embedding layer
|
||||||
|
self.embedding = nn.Embedding(vocab_size, embedding_dim)
|
||||||
|
|
||||||
|
#lstm layer
|
||||||
|
self.lstm = nn.LSTM(embedding_dim,
|
||||||
|
hidden_dim,
|
||||||
|
num_layers=n_layers,
|
||||||
|
bidirectional=bidirectional,
|
||||||
|
dropout=dropout,
|
||||||
|
batch_first=True)
|
||||||
|
|
||||||
|
#dense layer
|
||||||
|
self.fc = nn.Linear(hidden_dim * 2, output_dim)
|
||||||
|
|
||||||
|
#activation function
|
||||||
|
self.act = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, text, text_lengths):
|
||||||
|
#text = [batch size,sent_length]
|
||||||
|
embedded = self.embedding(text)
|
||||||
|
#embedded = [batch size, sent_len, emb dim]
|
||||||
|
|
||||||
|
#packed sequence
|
||||||
|
packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths, batch_first=True)
|
||||||
|
|
||||||
|
packed_output, (hidden, cell) = self.lstm(packed_embedded)
|
||||||
|
#hidden = [batch size, num layers * num directions,hid dim]
|
||||||
|
#cell = [batch size, num layers * num directions,hid dim]
|
||||||
|
|
||||||
|
#concat the final forward and backward hidden state
|
||||||
|
hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)
|
||||||
|
|
||||||
|
#hidden = [batch size, hid dim * num directions]
|
||||||
|
dense_outputs=self.fc(hidden)
|
||||||
|
|
||||||
|
#Final activation function
|
||||||
|
outputs=self.act(dense_outputs)
|
||||||
|
|
||||||
|
return outputs
|
14
africat/models/multiclass.py
Normal file
14
africat/models/multiclass.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class Multiclass(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden = nn.Linear(4, 8)
|
||||||
|
self.act = nn.ReLU()
|
||||||
|
self.output = nn.Linear(8, 3)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.act(self.hidden(x))
|
||||||
|
x = self.output(x)
|
||||||
|
return x
|
84
africat/models/rnn.py
Normal file
84
africat/models/rnn.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
||||||
|
|
||||||
|
|
||||||
|
class RNN(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, vocab_size, embed_size, num_output, rnn_model='LSTM', use_last=True, embedding_tensor=None,
|
||||||
|
padding_index=0, hidden_size=64, num_layers=1, batch_first=True):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size: vocab size
|
||||||
|
embed_size: embedding size
|
||||||
|
num_output: number of output (classes)
|
||||||
|
rnn_model: LSTM or GRU
|
||||||
|
use_last: bool
|
||||||
|
embedding_tensor:
|
||||||
|
padding_index:
|
||||||
|
hidden_size: hidden size of rnn module
|
||||||
|
num_layers: number of layers in rnn module
|
||||||
|
batch_first: batch first option
|
||||||
|
"""
|
||||||
|
|
||||||
|
super(RNN, self).__init__()
|
||||||
|
self.use_last = use_last
|
||||||
|
# embedding
|
||||||
|
self.encoder = None
|
||||||
|
if torch.is_tensor(embedding_tensor):
|
||||||
|
self.encoder = nn.Embedding(vocab_size, embed_size, padding_idx=padding_index, _weight=embedding_tensor)
|
||||||
|
self.encoder.weight.requires_grad = False
|
||||||
|
else:
|
||||||
|
self.encoder = nn.Embedding(vocab_size, embed_size, padding_idx=padding_index)
|
||||||
|
|
||||||
|
self.drop_en = nn.Dropout(p=0.6)
|
||||||
|
|
||||||
|
# rnn module
|
||||||
|
if rnn_model == 'LSTM':
|
||||||
|
self.rnn = nn.LSTM( input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, dropout=0.5,
|
||||||
|
batch_first=True, bidirectional=True)
|
||||||
|
elif rnn_model == 'GRU':
|
||||||
|
self.rnn = nn.GRU( input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, dropout=0.5,
|
||||||
|
batch_first=True, bidirectional=True)
|
||||||
|
else:
|
||||||
|
raise LookupError(' only support LSTM and GRU')
|
||||||
|
|
||||||
|
|
||||||
|
self.bn2 = nn.BatchNorm1d(hidden_size*2)
|
||||||
|
self.fc = nn.Linear(hidden_size*2, num_output)
|
||||||
|
|
||||||
|
def forward(self, x, seq_lengths):
|
||||||
|
'''
|
||||||
|
Args:
|
||||||
|
x: (batch, time_step, input_size)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
num_output size
|
||||||
|
'''
|
||||||
|
|
||||||
|
x_embed = self.encoder(x)
|
||||||
|
x_embed = self.drop_en(x_embed)
|
||||||
|
packed_input = pack_padded_sequence(x_embed, seq_lengths.cpu().numpy(),batch_first=True)
|
||||||
|
|
||||||
|
# r_out shape (batch, time_step, output_size)
|
||||||
|
# None is for initial hidden state
|
||||||
|
packed_output, ht = self.rnn(packed_input, None)
|
||||||
|
out_rnn, _ = pad_packed_sequence(packed_output, batch_first=True)
|
||||||
|
|
||||||
|
row_indices = torch.arange(0, x.size(0)).long()
|
||||||
|
col_indices = seq_lengths - 1
|
||||||
|
if next(self.parameters()).is_cuda:
|
||||||
|
row_indices = row_indices.cuda()
|
||||||
|
col_indices = col_indices.cuda()
|
||||||
|
|
||||||
|
if self.use_last:
|
||||||
|
last_tensor=out_rnn[row_indices, col_indices, :]
|
||||||
|
else:
|
||||||
|
# use mean
|
||||||
|
last_tensor = out_rnn[row_indices, :, :]
|
||||||
|
last_tensor = torch.mean(last_tensor, dim=1)
|
||||||
|
|
||||||
|
fc_input = self.bn2(last_tensor)
|
||||||
|
out = self.fc(fc_input)
|
||||||
|
return out
|
Loading…
Reference in New Issue
Block a user