diff --git a/africat/categorise.py b/africat/categorise.py
index 479d08b..231bfdc 100755
--- a/africat/categorise.py
+++ b/africat/categorise.py
@@ -13,6 +13,7 @@ import csv
 import random
 import pandas as pd
 import numpy as np
+import itertools
 #from pandarallel import pandarallel
 from tqdm import tqdm
 # torch
@@ -23,11 +24,14 @@ from torchtext.vocab import build_vocab_from_iterator
 from torch.utils.data import Dataset, DataLoader
 from torch import nn
 
-story_num = 40 # XXX None for all
+from models.rnn import RNN
+
+story_num = 64 # XXX None for all
 
 # Hyperparameters
 EPOCHS = 10 # epoch
-LR = 5 # learning rate
+#LR = 5 # learning rate
+LR = 0.005 # initial learning rate; arguably the most important hyperparameter: too small and training is slow or gets stuck, too large and training is unstable or settles on sub-optimal weights. If you can tune only one hyperparameter, tune this one
 BATCH_SIZE = 64 # batch size for training
 
 def read_csv(input_csv, rows=None, verbose=0):
@@ -126,7 +130,7 @@ class TextCategoriesDataset(Dataset):
         self.text_vocab = build_vocab_from_iterator(
             [self.textTokens(text) for i, text in self.df[text_column].items()],
             min_freq=2,
-            specials= self.itos.values(),
+            specials=self.itos.values(),
             special_first=True
         )
         self.text_vocab.set_default_index(self.text_vocab['<unk>'])
@@ -135,7 +139,7 @@ class TextCategoriesDataset(Dataset):
         self.cats_vocab = build_vocab_from_iterator(
             [self.catTokens(cats) for i, cats in self.df[cats_column].items()],
             min_freq=1,
-            specials= self.itos.values(),
+            specials=['<unk>'],
             special_first=True
         )
         self.cats_vocab.set_default_index(self.cats_vocab['<unk>'])
@@ -162,8 +166,8 @@ class TextCategoriesDataset(Dataset):
 
         # Numericalise by applying transforms
         return (
-            self.getTransform(self.text_vocab)(self.textTokens(text)),
-            self.getTransform(self.cats_vocab)(self.catTokens(cats)),
+            self.getTransform(self.text_vocab, "text")(self.textTokens(text)),
+            self.getTransform(self.cats_vocab, "cats")(self.catTokens(cats)),
         )
 
     @staticmethod
@@ -178,26 +182,32 @@ class TextCategoriesDataset(Dataset):
         elif isinstance(cats, list):
             return [cat for cat in cats]
 
-    def getTransform(self, vocab):
+    def getTransform(self, vocab, vType):
        '''
        Create transforms based on given vocabulary. The returned transform
        is applied to a sequence of tokens.
        '''
-        return T.Sequential(
-            # converts the sentences to indices based on given vocabulary
-            T.VocabTransform(vocab=vocab),
-            # Add <bos> at beginning of each sentence. 1 because the index
-            # for <bos> in vocabulary is 1 as seen in previous section
-            T.AddToken(1, begin=True),
-            # Add <eos> at beginning of each sentence. 2 because the index
-            # for <eos> in vocabulary is 2 as seen in previous section
-            T.AddToken(2, begin=False)
-        )
+        if vType == "text":
+            return T.Sequential(
+                # converts the sentences to indices based on given vocabulary
+                T.VocabTransform(vocab=vocab),
+                # Add <bos> at beginning of each sentence. 1 because the index
+                # for <bos> in vocabulary is 1 as seen in previous section
+                T.AddToken(self.text_vocab['<bos>'], begin=True),
+                # Add <eos> at end of each sentence. 2 because the index
+                # for <eos> in vocabulary is 2 as seen in previous section
+                T.AddToken(self.text_vocab['<eos>'], begin=False)
+            )
+        else:
+            return T.Sequential(
+                # converts the sentences to indices based on given vocabulary
+                T.VocabTransform(vocab=vocab),
+            )
 
 '''
-Now that we have a dataset, let's create dataloader,
-which can batch, shuffle, and load the data in parallel
+Now that we have a dataset, let's create a dataloader callback;
+the dataloader can batch, shuffle, and load the data in parallel
 '''
 class CollateBatch:
@@ -207,41 +217,105 @@ class CollateBatch:
     which returns a tensor
     '''
     def __init__(self, pad_idx):
+        '''
+        pad_idx (int): the index of the "<pad>" token in the vocabulary.
+        '''
         self.pad_idx = pad_idx
 
     def __call__(self, batch):
-        # T.ToTensor(0) returns a transform that converts the sequence
-        # to a torch.tensor and also applies padding.
-        #
-        # pad_idx is passed to the constructor to specify the index of
-        # the "<pad>" token in the vocabulary.
+        '''
+        batch: a list of tuples with (text, cats), each of which
+        is a list of tokens
+        '''
+        batch_text, batch_cats = zip(*batch)
+        #for i in range(len(batch)):
+        #    print(batch[i])
+        #max_text_len = len(max(batch_text, key=len))
+        #max_cats_len = len(max(batch_cats, key=len))
+
+        #text_tensor = T.ToTensor(self.pad_idx)(batch_text)
+        #cats_tensor = T.ToTensor(self.pad_idx)(batch_cats)
+
+        # Pad text to the longest
+        text_tensor = torch.nn.utils.rnn.pad_sequence(
+            [torch.LongTensor(s) for s in batch_text],
+            batch_first=True, padding_value=self.pad_idx
+        )
+        text_lengths = torch.tensor([len(s) for s in batch_text]) # unpadded lengths, needed for packing
+
+        #cats_tensor = torch.nn.utils.rnn.pad_sequence(
+        #    [torch.LongTensor(s) for s in batch_cats],
+        #    batch_first=True, padding_value=self.pad_idx
+        #)
+        #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
+
+        # Pad cats_tensor to all possible categories
+        # TODO will this be necessary with larger training sets, that should
+        # encompass all categories? Best to be safe...
+        all_cats = list(set(itertools.chain(*batch_cats)))
+        num_cats = len(all_cats)
+        # if there's no 0, there was no <unk>, so increment to allow for it to be a possible category
+        if 0 not in all_cats:
+            num_cats += 1
+        cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).long()
+        cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
+        for idx, (c, clen) in enumerate(zip(batch_cats, cats_lengths)):
+            cats_tensor[idx, :clen] = torch.LongTensor(c)
+
+        ## SORT YOUR TENSORS BY LENGTH!
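+        # pack_padded_sequence, which the model's forward pass relies on,
+        # expects the batch sorted by decreasing length unless it is called
+        # with enforce_sorted=False; sort here and apply the same permutation
+        # to the targets so text and categories stay aligned.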
+        text_lengths, perm_idx = text_lengths.sort(0, descending=True)
+        text_tensor = text_tensor[perm_idx]
+        cats_tensor = cats_tensor[perm_idx]
+
+        #print(text_tensor)
+        #print("text shape:", text_tensor.shape)
+        #print(cats_tensor)
+        #print("cats shape:", cats_tensor.shape)
+        #print(text_lengths)
+        #print("text_lengths shape:", text_lengths.shape)
+
+        #sys.exit(0)
+
         return (
-            T.ToTensor(self.pad_idx)(list(batch[0])),
-            T.ToTensor(self.pad_idx)(list(batch[1])),
+            text_tensor,
+            cats_tensor,
+            text_lengths,
         )
 
 
-class TextClassificationModel(nn.Module):
-    def __init__(self, input_size, output_size, verbose):
-        super().__init__()
-
-    def forward(self, x):
-        return x
-
-
-def train(dataloader):
+def train(dataloader, model, optimizer, criterion):
     model.train()
     total_acc, total_count = 0, 0
     log_interval = 500
     start_time = time.time()
 
-    for idx, (label, text) in enumerate(dataloader):
+    for idx, (text, cats, text_lengths) in enumerate(dataloader):
         optimizer.zero_grad()
-        predicted_label = model(text)
-        loss = criterion(predicted_label, label)
+
+        print("text_lengths shape", text_lengths.shape)
+        print("input shape", text.shape)
+        print("target", cats)
+        print("target shape", cats.shape)
+
+        output = model(text, text_lengths)
+        print("output", output)
+        print("output shape", output.shape)
+        # reshape output and target for cross entropy loss
+#        output = output.reshape(output.size(0)*output.size(1), -1) # (batch * seq_len x classes)
+#        cats = cats.reshape(-1) # (batch * seq_len), class index
+#        print("output", output)
+#        print("output shape", output.shape)
+#        print("target shape", cats.shape)
+#        print()
+
+        loss = criterion(input=output, target=cats)
         loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
+
         optimizer.step()
+
         total_acc += (predicted_label.argmax(1) == label).sum().item()
         total_count += label.size(0)
         if idx % log_interval == 0 and idx > 0:
@@ -256,7 +330,7 @@ def train(dataloader):
             start_time = time.time()
 
 
-def evaluate(dataloader):
+def evaluate(dataloader, model, criterion):
     model.eval()
     total_acc, total_count = 0, 0
 
@@ -324,7 +398,8 @@ def main():
         lang_column="language",
         verbose=args.verbose,
     )
-    #print(dataset[2])
+    #for text, cat in enumerate(train_dataset):
+    #    print(text, cat)
     #for text, cat in enumerate(valid_dataset):
     #    print(text, cat)
     #sys.exit(0)
@@ -361,24 +436,59 @@ def main():
     )
     #for i_batch, sample_batched in enumerate(dataloader):
     #    print(i_batch, sample_batched[0], sample_batched[1])
+    #for i_batch, sample_batched in enumerate(train_dataloader):
+        #print(i_batch, sample_batched[0], sample_batched[1])
+        #print(i_batch)
+        #print("batch elements:")
+        #for i in sample_batched:
+        #    print(i)
+        #    print(i.shape)
+        #    print("\n")
     #sys.exit(0)
 
-    num_class = len(set([cats for key, cats, text, lang in train_data.values]))
     input_size = len(train_dataset.text_vocab)
-    output_size = len(train_dataset.cats_vocab)
-    emsize = 64
-    model = TextClassificationModel(input_size, output_size, args.verbose).to(device)
+    output_size = len(train_dataset.cats_vocab) # every output item is the likelihood of a particular category
+    embed = torch.empty(input_size, len(train_dataset)) # tokens per sample x samples
+    embedding_size = embed.size(1) # was 64 (should be: samples)
+    num_layers = 2 # 2-3 layers should be enough for LSTM
+    hidden_size = 128 # hidden size of rnn module, should be tweaked manually
+    mean_seq = True # use mean of rnn output
+    weight_decay = 1e-4 # encourages the network to learn smoother/simpler functions, which usually generalise better than spiky, noisy ones; try 1e-3 or 1e-4
+
+    #for i in train_dataset.text_vocab.get_itos():
+    #    print(i)
+    print("input_size: ", input_size)
+    print("output_size:", output_size)
+    print("embed shape:", embed.shape)
+    print("embedding_size:", embedding_size, " (that is, number of samples)")
+
+    model = RNN(
+        #rnn_model='GRU',
+        rnn_model='LSTM',
+        vocab_size=input_size,
+        embed_size=embedding_size,
+        num_output=output_size,
+        use_last=(not mean_seq),
+        hidden_size=hidden_size,
+        embedding_tensor=embed,
+        num_layers=num_layers,
+        batch_first=True
+    )
+    print(model)
+
+    # optimizer and loss
+    #optimizer = torch.optim.SGD(model.parameters(), lr=LR)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=weight_decay)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1) # kept: scheduler.step() is still called below
+    print(criterion)
+    print(optimizer)
 
-    criterion = torch.nn.CrossEntropyLoss()
-    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
-    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
     total_accu = None
-
     for epoch in range(1, EPOCHS + 1):
         epoch_start_time = time.time()
-        train(train_dataloader)
-        accu_val = evaluate(valid_dataloader)
+        train(train_dataloader, model, optimizer, criterion)
+        accu_val = evaluate(valid_dataloader, model, criterion)
         if total_accu is not None and total_accu > accu_val:
             scheduler.step()
         else:
diff --git a/africat/models/classifier.py b/africat/models/classifier.py
new file mode 100644
index 0000000..99cb594
--- /dev/null
+++ b/africat/models/classifier.py
@@ -0,0 +1,48 @@
+import torch
+import torch.nn as nn
+
+class RNN(nn.Module):
+    #define all the layers used in model
+    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
+                 n_layers, bidirectional, dropout):
+        super().__init__()
+
+        #embedding layer
+        self.embedding = nn.Embedding(vocab_size, embedding_dim)
+
+        #lstm layer
+        self.lstm = nn.LSTM(embedding_dim,
+                            hidden_dim,
+                            num_layers=n_layers,
+                            bidirectional=bidirectional,
+                            dropout=dropout,
+                            batch_first=True)
+
+        #dense layer
+        self.fc = nn.Linear(hidden_dim * 2, output_dim)
+
+        #activation function
+        self.act = nn.Sigmoid()
+
+    def forward(self, text, text_lengths):
+        #text = [batch size, sent_length]
+        embedded = self.embedding(text)
+        #embedded = [batch size, sent_len, emb dim]
+
+        #packed sequence
+        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths, batch_first=True)
+
+        packed_output, (hidden, cell) = self.lstm(packed_embedded)
+        #hidden = [num layers * num directions, batch size, hid dim]
+        #cell = [num layers * num directions, batch size, hid dim]
+
+        #concat the final forward and backward hidden state
+        hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
+
+        #hidden = [batch size, hid dim * num directions]
+        dense_outputs = self.fc(hidden)
+
+        #Final activation function
+        outputs = self.act(dense_outputs)
+
+        return outputs
diff --git a/africat/models/multiclass.py b/africat/models/multiclass.py
new file mode 100644
index 0000000..14a21b2
--- /dev/null
+++ b/africat/models/multiclass.py
@@ -0,0 +1,14 @@
+import torch
+import torch.nn as nn
+
+class Multiclass(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.hidden = nn.Linear(4, 8)
+        self.act = nn.ReLU()
+        self.output = nn.Linear(8, 3)
+
+    def forward(self, x):
+        x = self.act(self.hidden(x))
+        x = self.output(x)
+        return x
diff --git a/africat/models/rnn.py b/africat/models/rnn.py
new file mode 100644
index 0000000..93567ff
--- /dev/null
+++ b/africat/models/rnn.py
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+
+class RNN(nn.Module):
+
+    def __init__(self, vocab_size, embed_size, num_output, rnn_model='LSTM', use_last=True, embedding_tensor=None,
+                 padding_index=0, hidden_size=64, num_layers=1, batch_first=True):
+        """
+
+        Args:
+            vocab_size: vocab size
+            embed_size: embedding size
+            num_output: number of output (classes)
+            rnn_model: LSTM or GRU
+            use_last: use the last hidden state (True) or the mean over time (False)
+            embedding_tensor: optional pre-built embedding weights (frozen if given)
+            padding_index: index of the padding token
+            hidden_size: hidden size of rnn module
+            num_layers: number of layers in rnn module
+            batch_first: batch first option
+        """
+
+        super(RNN, self).__init__()
+        self.use_last = use_last
+        # embedding
+        self.encoder = None
+        if torch.is_tensor(embedding_tensor):
+            self.encoder = nn.Embedding(vocab_size, embed_size, padding_idx=padding_index, _weight=embedding_tensor)
+            self.encoder.weight.requires_grad = False
+        else:
+            self.encoder = nn.Embedding(vocab_size, embed_size, padding_idx=padding_index)
+
+        self.drop_en = nn.Dropout(p=0.6)
+
+        # rnn module
+        if rnn_model == 'LSTM':
+            self.rnn = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, dropout=0.5,
+                               batch_first=True, bidirectional=True)
+        elif rnn_model == 'GRU':
+            self.rnn = nn.GRU(input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers, dropout=0.5,
+                              batch_first=True, bidirectional=True)
+        else:
+            raise LookupError('only LSTM and GRU are supported')
+
+
+        self.bn2 = nn.BatchNorm1d(hidden_size*2)
+        self.fc = nn.Linear(hidden_size*2, num_output)
+
+    def forward(self, x, seq_lengths):
+        '''
+        Args:
+            x: (batch, time_step, input_size)
+
+        Returns:
+            num_output size
+        '''
+
+        x_embed = self.encoder(x)
+        x_embed = self.drop_en(x_embed)
+        packed_input = pack_padded_sequence(x_embed, seq_lengths.cpu().numpy(), batch_first=True)
+
+        # r_out shape (batch, time_step, output_size)
+        # None is for initial hidden state
+        packed_output, ht = self.rnn(packed_input, None)
+        out_rnn, _ = pad_packed_sequence(packed_output, batch_first=True)
+
+        row_indices = torch.arange(0, x.size(0)).long()
+        col_indices = seq_lengths - 1
+        if next(self.parameters()).is_cuda:
+            row_indices = row_indices.cuda()
+            col_indices = col_indices.cuda()
+
+        if self.use_last:
+            last_tensor = out_rnn[row_indices, col_indices, :]
+        else:
+            # use mean
+            last_tensor = out_rnn[row_indices, :, :]
+            last_tensor = torch.mean(last_tensor, dim=1)
+
+        fc_input = self.bn2(last_tensor)
+        out = self.fc(fc_input)
+        return out
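
For reviewers: below is a quick shape check of the new RNN module, not part of the patch itself. It is a minimal sketch that assumes models/rnn.py is importable as models.rnn (as categorise.py does) and uses made-up sizes; the lengths are sorted longest-first because RNN.forward() calls pack_padded_sequence with its default enforce_sorted=True.

import torch
from models.rnn import RNN  # africat/models/rnn.py from this patch

vocab_size, embed_size, num_output = 100, 32, 5      # illustrative sizes only
model = RNN(vocab_size=vocab_size, embed_size=embed_size, num_output=num_output,
            rnn_model='LSTM', hidden_size=16, num_layers=2)
model.eval()  # run BatchNorm/Dropout in inference mode for a deterministic check

batch_size, max_len = 4, 12
seq_lengths = torch.tensor([12, 9, 7, 3])            # already sorted, longest first
x = torch.randint(1, vocab_size, (batch_size, max_len))  # random token ids
for i, length in enumerate(seq_lengths):
    x[i, length:] = 0                                # pad with padding_index=0

with torch.no_grad():
    out = model(x, seq_lengths)
print(out.shape)                                     # expected: torch.Size([4, 5])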