diff --git a/africat/categorise.py b/africat/categorise.py
index 231bfdc..3f7adfc 100755
--- a/africat/categorise.py
+++ b/africat/categorise.py
@@ -198,11 +198,13 @@ class TextCategoriesDataset(Dataset):
                 # for in vocabulary is 2 as seen in previous section
                 T.AddToken(self.text_vocab[''], begin=False)
             )
-        else:
+        elif vType == "cats":
             return T.Sequential(
                 # converts the sentences to indices based on given vocabulary
                 T.VocabTransform(vocab=vocab),
             )
+        else:
+            raise Exception('wrong transformation type')
 
 
     '''
@@ -237,7 +239,7 @@ class CollateBatch:
         #cats_tensor = T.ToTensor(self.pad_idx)(batch_cats)
 
         # Pad text to the longest
-        text_tensor = torch.nn.utils.rnn.pad_sequence(
+        text_tensor = nn.utils.rnn.pad_sequence(
             [torch.LongTensor(s) for s in batch_text],
             batch_first=True, padding_value=self.pad_idx
         )
@@ -257,18 +259,44 @@ class CollateBatch:
         # if there's no 0, there was no , so increment to allow for it to be a possible category
         if 0 not in all_cats:
             num_cats += 1
 
-        cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).long()
-        cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
-        for idx, (c, clen) in enumerate(zip(batch_cats, cats_lengths)):
-            cats_tensor[idx, :clen] = torch.LongTensor(c)
+        # Convert cats to tensor
+        #cats_tensor = nn.utils.rnn.pad_sequence(
+        #    [torch.LongTensor(s) for s in batch_cats],
+        #    batch_first=True, padding_value=self.pad_idx
+        #)
+        #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
+
+        # Convert cats to tensor, alt version
+        #cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).long()
+        #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
+        #for idx, (c, clen) in enumerate(zip(batch_cats, cats_lengths)):
+        #    cats_tensor[idx, :clen] = torch.LongTensor(c)
+        #print([torch.LongTensor(s) for s in batch_cats])
+        #print(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]))
+        #cats_tensor = nn.functional.one_hot(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]), num_cats)
+        #cats_tensor = nn.functional.one_hot(torch.FloatTensor(batch_cats), num_cats)
+
+        # Convert cats to multi-label one-hot representation
+        # This will be a target for CrossEntropyLoss(pred, target),
+        # which takes FloatTensor pred and LongTensor target
+        cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).float()
+        cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
+        for idx, cats in enumerate(batch_cats):
+            #print("\nsample", idx, cats)
+            for c in cats:
+                cats_tensor[idx][c] = 1
+            #print(cats_tensor[idx])
+
+        '''
         # XXX why??
         ## SORT YOUR TENSORS BY LENGTH!
         text_lengths, perm_idx = text_lengths.sort(0, descending=True)
         text_tensor = text_tensor[perm_idx]
         cats_tensor = cats_tensor[perm_idx]
+        '''
 
-        #print(text_tensor)
+        #print("text", text_tensor)
         #print("text shape:", text_tensor.shape)
         #print(cats_tensor)
         #print("cats shape:", cats_tensor.shape)
@@ -283,13 +311,46 @@ class CollateBatch:
             text_lengths,
         )
 
 
+def labels_to_hot_one(label_vocab, labels, pad_idx: int):
+    all_labels = label_vocab.get_itos()
+    num_labels = len(all_labels)
+    # if there's no 0, there was no , so increment to allow for it to be a possible category
+    if 0 not in all_labels:
+        num_labels += 1
-def train(dataloader, model, optimizer, criterion):
+    labels_tensor = torch.full((len(labels), num_labels), pad_idx).float()
+    labels_lengths = torch.LongTensor(list(map(len, labels)))
+    for idx, sample_labels in enumerate(labels):
+        #print("\nsample", idx, sample_labels)
+        for l in sample_labels:
+            labels_tensor[idx][l] = 1
+        #print(labels_tensor[idx])
+    return labels_tensor
+
+def hot_one_to_labels(vocab, tensor):
+    if tensor.ndimension() == 2:
+        all_labels = vocab.get_itos()
+        batch = list()
+        for result in tensor:
+            chance = dict()
+            for idx, pred in enumerate(result):
+                if pred > 0:  # XXX
+                    chance[all_labels[idx]] = pred.item()
+            print(chance)
+            batch.append(chance)
+        return batch
+    else:
+        raise ValueError("Only tensors with 2 dimensions are supported")
+
+
+
+def train(dataloader, dataset, model, optimizer, criterion):
     model.train()
     total_acc, total_count = 0, 0
     log_interval = 500
     start_time = time.time()
+    torch.set_printoptions(precision=2)
 
     for idx, (text, cats, text_lengths) in enumerate(dataloader):
         optimizer.zero_grad()
@@ -301,23 +362,23 @@ def train(dataloader, model, optimizer, criterion):
         output = model(text, text_lengths)
         print("output", output)
         print("output shape", output.shape)
-        # reshape output and target for cross entropy loss
-#        output = output.reshape(output.size(0)*output.size(1), -1) # (batch * seq_len x classes)
-#        cats = cats.reshape(-1) # (batch * seq_len), class index
-#        print("output", output)
-#        print("output shape", output.shape)
-#        print("target shape", cats.shape)
-#        print()
 
         loss = criterion(input=output, target=cats)
         loss.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
+        nn.utils.clip_grad_norm_(model.parameters(), 0.1)
         optimizer.step()
 
-        total_acc += (predicted_label.argmax(1) == label).sum().item()
-        total_count += label.size(0)
+        print(loss)
+        print("expected", cats)
+        [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, cats)]
+        print("predicted", output)
+        [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
+
+        return
+        total_acc += (output == cats).sum().item()
+        total_count += cats.size(0)
 
         if idx % log_interval == 0 and idx > 0:
             elapsed = time.time() - start_time
             print(
@@ -330,14 +391,20 @@ def train(dataloader, model, optimizer, criterion):
             start_time = time.time()
 
 
-def evaluate(dataloader, model, criterion):
+def evaluate(dataloader, dataset, model, criterion):
     model.eval()
    total_acc, total_count = 0, 0
 
     with torch.no_grad():
-        for idx, (label, text) in enumerate(dataloader):
-            predicted_label = model(text)
+        for idx, (text, cats, text_lengths) in enumerate(dataloader):
+            predicted_label = model(text, text_lengths)
+            print(predicted_label)
             loss = criterion(predicted_label, label)
+            print(loss)
+            print("expected labels:", cats)
+            print([dataset.cats_vocab.get_itos(i) for i in cats])
+            print("predicted labels:", predicted_label)
+            print([dataset.cats_vocab.get_itos(i) for i in predicted_label])
             total_acc += (predicted_label.argmax(1) == label).sum().item()
             total_count += label.size(0)
     return total_acc / total_count
@@ -456,12 +523,13 @@ def main():
     mean_seq = True  # use mean of rnn output
     weight_decay = 1e-4  # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
 
-    #for i in train_dataset.text_vocab.get_itos():
-    #    print(i)
-    print("input_size: ", input_size)
-    print("output_size:", output_size)
-    print("embed shape:", embed.shape)
-    print("embedding_size:", embedding_size, " (that is, number of samples)")
+    if args.verbose:
+        #for i in train_dataset.text_vocab.get_itos():
+        #    print(i)
+        print("input_size: ", input_size)
+        print("output_size:", output_size)
+        print("embed shape:", embed.shape)
+        print("embedding_size:", embedding_size, " (that is, number of samples)")
 
     model = RNN(
         #rnn_model='GRU',
@@ -475,20 +543,22 @@ def main():
         num_layers=num_layers,
         batch_first=True
     )
-    print(model)
+    if args.verbose:
+        print(model)
 
     # optimizer and loss
     #optimizer = torch.optim.SGD(model.parameters(), lr=LR)
     criterion = nn.CrossEntropyLoss()
     optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=weight_decay)
-    print(criterion)
-    print(optimizer)
+    if args.verbose:
+        print(criterion)
+        print(optimizer)
 
     total_accu = None
     for epoch in range(1, EPOCHS + 1):
         epoch_start_time = time.time()
-        train(train_dataloader, model, optimizer, criterion)
-        accu_val = evaluate(valid_dataloader, model, criterion)
+        train(train_dataloader, train_dataset, model, optimizer, criterion)
+        accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
         if total_accu is not None and total_accu > accu_val:
             scheduler.step()
         else:
@@ -503,7 +573,7 @@ def main():
 
     print("-" * 59)
     print("Checking the results of test dataset.")
-    accu_test = evaluate(test_dataloader)
+    accu_test = evaluate(test_dataloader, test_dataset, model, criterion)
     print("test accuracy {:8.3f}".format(accu_test))
 
     return
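
Note (not part of the patch): a minimal standalone sketch of the multi-hot target construction that the CollateBatch change above introduces, assuming each sample's categories arrive as a plain list of integer indices and that the pad index is 0. The helper name to_multi_hot and the example batch are illustrative only, not code from categorise.py.

    # Sketch of building a multi-label (multi-hot) float target per sample,
    # mirroring the cats_tensor loop added to CollateBatch.__call__ above.
    import torch

    def to_multi_hot(batch_cats, num_cats, pad_idx=0):
        # One row per sample, filled with pad_idx, then set each present category to 1.
        target = torch.full((len(batch_cats), num_cats), pad_idx).float()
        for row, cats in enumerate(batch_cats):
            for c in cats:
                target[row][c] = 1
        return target

    if __name__ == "__main__":
        batch = [[1, 3], [2], [0, 4]]
        print(to_multi_hot(batch, num_cats=5))
        # tensor([[0., 1., 0., 1., 0.],
        #         [0., 0., 1., 0., 0.],
        #         [1., 0., 0., 0., 1.]])

Each row is a float vector with a 1 at every category present in the sample, so its shape matches the per-class scores the model emits and can be handed directly to the loss as the target.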