diff --git a/africat/categorise.py b/africat/categorise.py
index ce4952e..8e1dd97 100755
--- a/africat/categorise.py
+++ b/africat/categorise.py
@@ -25,13 +25,25 @@
 from torch import nn
 from models.rnn import RNN
 
-story_num = 64 # XXX None for all
+all_categories = list()
+# XXX None for all stories
+#story_num = 128
+#story_num = 256
+story_num = 512
+#story_num = 1024
+#story_num = None
 
 # Hyperparameters
 EPOCHS = 10 # epoch
+#EPOCHS = 2 # epoch
 #LR = 5 # learning rate
-LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
-BATCH_SIZE = 64 # batch size for training
+#LR = 0.5
+LR = 0.05
+#LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
+#BATCH_SIZE = 64 # batch size for training
+#BATCH_SIZE = 16 # batch size for training
+BATCH_SIZE = 8 # batch size for training
+#BATCH_SIZE = 4 # batch size for training
 
 def read_csv(input_csv, rows=None, verbose=0):
     if verbose > 0:
@@ -136,7 +148,8 @@ class TextCategoriesDataset(Dataset):
         #print(self.text_vocab.get_itos())
 
         self.cats_vocab = build_vocab_from_iterator(
-            [self.catTokens(cats) for i, cats in self.df[cats_column].items()],
+            #[self.catTokens(cats) for i, cats in self.df[cats_column].items()],
+            [self.catTokens(all_categories)],
            min_freq=1,
            specials=['<pad>'],
            special_first=True
@@ -163,6 +176,10 @@
         if self.transform:
             text, cats = self.transform(text, cats)
 
+        #print(cats)
+        #print(self.catTokens(cats))
+        #print(self.getTransform(self.cats_vocab, "cats")(self.catTokens(cats)))
+
         # Numericalise by applying transforms
         return (
             self.getTransform(self.text_vocab, "text")(self.textTokens(text)),
@@ -217,11 +234,12 @@ class CollateBatch:
     in a batch of equal length.
-    We can do this a collate_fn callback class, which returns a tensor
+    We can do this with a collate_fn callback class, which returns a tensor
     '''
-    def __init__(self, pad_idx):
+    def __init__(self, pad_idx, cats):
         '''
         pad_idx (int): the index of the "<pad>" token in the vocabulary.
         '''
         self.pad_idx = pad_idx
+        self.cats = cats
 
     def __call__(self, batch):
         '''
@@ -250,42 +268,30 @@
         #)
         #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
 
+        '''
         # Pad cats_tensor to all possible categories
-        # TODO will this be necessary with larger training sets, that should
-        # encompass all categories? Best to be safe...
-        all_cats = list(set(itertools.chain(*batch_cats)))
-        num_cats = len(all_cats)
-        # if there's no 0, there was no <pad>, so increment to allow for it to be a possible category
-        if 0 not in all_cats:
-            num_cats += 1
-
-        # Convert cats to tensor
-        #cats_tensor = nn.utils.rnn.pad_sequence(
-        #    [torch.LongTensor(s) for s in batch_cats],
-        #    batch_first=True, padding_value=self.pad_idx
-        #)
-        #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
-
-        # Convert cats to tensor, alt version
-        #cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).long()
-        #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
-        #for idx, (c, clen) in enumerate(zip(batch_cats, cats_lengths)):
-        #    cats_tensor[idx, :clen] = torch.LongTensor(c)
-        #print([torch.LongTensor(s) for s in batch_cats])
-        #print(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]))
-        #cats_tensor = nn.functional.one_hot(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]), num_cats)
-        #cats_tensor = nn.functional.one_hot(torch.FloatTensor(batch_cats), num_cats)
+        num_cats = len(all_categories)
 
         # Convert cats to multi-label one-hot representation
-        # This will be a target for CrossEntropyLoss(pred, target),
-        # which takes FloatTensor pred and LongTensor target
         cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).float()
         cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
         for idx, cats in enumerate(batch_cats):
-            #print("\nsample", idx, cats)
-            for c in cats:
-                cats_tensor[idx][c] = 1
-            #print(cats_tensor[idx])
+            #print("\nsample", idx, cats)
+            for c in cats:
+                #print(c)
+                cats_tensor[idx][c] = 1
+            #print(cats_tensor[idx])
+        '''
+        # Convert cats to multi-label one-hot representation
+        # add one to all_categories to account for <pad>
+        cats_tensor = torch.full((len(batch_cats), len(all_categories)+1), self.pad_idx).float()
+        for idx, cats in enumerate(batch_cats):
+            #print("\nsample", idx, cats)
+            for c in cats:
+                cats_tensor[idx][c] = 1
+            #print(cats_tensor[idx])
+        #sys.exit(0)
+
         '''   # XXX why??
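
# Aside: a minimal, self-contained sketch of what the collate step above
# computes, assuming a batch of (token_ids, category_ids) pairs and a
# num_cats that already includes the index-0 <pad> slot. The names
# collate_sketch, pad_idx and num_cats are illustrative, not from this file.
import torch
from torch import nn

def collate_sketch(batch, pad_idx=0, num_cats=5):
    texts, cats = zip(*batch)
    # pad the variable-length token sequences to the longest in the batch
    text_tensor = nn.utils.rnn.pad_sequence(
        [torch.LongTensor(t) for t in texts],
        batch_first=True, padding_value=pad_idx,
    )
    text_lengths = torch.LongTensor([len(t) for t in texts])
    # multi-hot float targets: one row per sample, 1.0 at each category index
    cats_tensor = torch.zeros(len(batch), num_cats)
    for row, cs in enumerate(cats):
        cats_tensor[row, list(cs)] = 1.0
    return text_tensor, cats_tensor, text_lengths

# collate_sketch([([5, 2, 9], [1, 3]), ([7, 4], [2])]) gives text of shape
# (2, 3), targets [[0,1,0,1,0], [0,0,1,0,0]], and lengths [3, 2].
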
@@ -310,10 +316,10 @@ class CollateBatch:
         text_lengths,
     )
 
-def labels_to_hot_one(label_vocab, labels, pad_idx: int):
-    all_labels = vocab.get_itos()
+def cat2tensor(label_vocab, labels, pad_idx: int):
+    all_labels = label_vocab.get_itos()
     num_labels = len(all_labels)
-    # if there's no 0, there was no <pad>, so increment to allow for it to be a possible category
+    # add <pad> if it is missing, so index 0 is still a possible category
     if 0 not in all_labels:
         num_labels += 1
 
@@ -326,40 +332,48 @@
     #print(labels_tensor[idx])
     return labels_tensor
 
-def hot_one_to_labels(vocab, tensor):
+def tensor2cat(vocab, tensor):
+    all_cats = vocab.get_itos()
     if tensor.ndimension() == 2:
-        all_labels = vocab.get_itos()
         batch = list()
         for result in tensor:
             chance = dict()
             for idx, pred in enumerate(result):
                 if pred > 0: # XXX
-                    chance[all_labels[idx]] = pred.item()
-            print(chance)
+                    chance[all_cats[idx]] = pred.item()
+            #print(chance)
             batch.append(chance)
         return batch
+    elif tensor.ndimension() == 1:
+        chance = dict()
+        for idx, pred in enumerate(tensor):
+            if idx >= len(all_cats):
+                print(f"Idx {idx} not in {len(all_cats)} categories")
+            elif pred > 0: # XXX
+                #print(idx, len(all_cats))
+                chance[all_cats[idx]] = pred.item()
+        #print(chance)
+        return chance
     else:
-        raise ValueError("Only tensors with 2 dimensions are supported")
+        raise ValueError("Only tensors with 1 or 2 dimensions are supported")
 
-def train(dataloader, dataset, model, optimizer, criterion):
+def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
     total_acc, total_count = 0, 0
     log_interval = 500
-    start_time = time.time()
     torch.set_printoptions(precision=2)
 
-    for idx, (text, cats, text_lengths) in enumerate(dataloader):
+
+    batch = tqdm.tqdm(dataloader, unit="batch")
+    for idx, data in enumerate(batch):
+        batch.set_description(f"Train {epoch}.{idx}")
+        text, cats, text_lengths = data
+
         optimizer.zero_grad()
-        print("text_lengths shape", text_lengths.shape)
-        print("input shape", text.shape)
-        print("target", cats)
-        print("target shape", cats.shape)
-
         output = model(text, text_lengths)
-        print("output", output)
-        print("output shape", output.shape)
+        #print("output", output)
+        #print("output shape", output.shape)
 
         loss = criterion(input=output, target=cats)
         loss.backward()
@@ -368,45 +383,85 @@
         optimizer.step()
 
-        print(loss)
-        print("expected", cats)
-        [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, cats)]
-        print("predicted", output)
-        [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
+        #print("train loss",loss)
 
-        return
+        ##predicted = np.round(output)
+        ##total_acc += (predicted == cats).sum().item()
 
-        total_acc += (output == cats).sum().item()
+        predictions = torch.zeros(output.shape)
+        predictions[output >= 0.25] = True
+        #predictions[output >= 0.5] = True
+        #predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
+
+        batch.clear()
+        for target, out, pred in list(zip(cats, output, predictions)):
+            expect = tensor2cat(dataset.cats_vocab, target)
+            raw = tensor2cat(dataset.cats_vocab, out)
+            predict = tensor2cat(dataset.cats_vocab, pred)
+            print("Expected:  ", expect)
+            print("Predicted: ", predict)
+            print("Raw output:", raw)
+            print("\n")
+        batch.refresh()
+
+        N, C = cats.shape
+        #print("eq", (predictions == cats))
+        #print("sum", (predictions == cats).sum())
+        #print("accuracy", (predictions == cats).sum() / (N*C) * 100)
+        accuracy = (predictions == cats).sum() / (N*C) * 100
+        total_acc += accuracy
+        #print("train accuracy", accuracy)
+        #print("train total_acc", total_acc)
        total_count += cats.size(0)
 
-        if idx % log_interval == 0 and idx > 0:
-            elapsed = time.time() - start_time
-            print(
-                "| epoch {:3d} | {:5d}/{:5d} batches "
-                "| accuracy {:8.3f}".format(
-                    epoch, idx, len(dataloader), total_acc / total_count
-                )
-            )
-            total_acc, total_count = 0, 0
-            start_time = time.time()
+        batch.set_postfix({
+            "accuracy": int(total_acc / total_count),
+        })
+        total_acc, total_count = 0, 0
 
-def evaluate(dataloader, dataset, model, criterion):
+def evaluate(dataloader, dataset, model, criterion, epoch=0):
     model.eval()
     total_acc, total_count = 0, 0
 
     with torch.no_grad():
-        for idx, (text, cats, text_lengths) in enumerate(dataloader):
-            predicted_label = model(text, text_lengths)
-            print(predicted_label)
-            loss = criterion(predicted_label, label)
-            print(loss)
-            print("expected labels:", label)
-            print([dataset.cats_vocab.get_itos(i) for i in label])
-            print("predicted labels:", predicted_label)
-            print([dataset.cats_vocab.get_itos(i) for i in predicted_label])
-            total_acc += (predicted_label.argmax(1) == label).sum().item()
-            total_count += label.size(0)
-    return total_acc / total_count
+        batch = tqdm.tqdm(dataloader, unit="batch")
+        for idx, data in enumerate(batch):
+            batch.set_description(f"Evaluate {epoch}.{idx}")
+            text, cats, text_lengths = data
+
+            output = model(text, text_lengths)
+            #print("eval predicted", output)
+
+            loss = criterion(output, cats)
+            #print("eval loss", loss)
+
+            predictions = torch.zeros(output.shape)
+            predictions[output >= 0.5] = True
+            predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
+
+            batch.clear()
+            for target, out, pred in list(zip(cats, output, predictions)):
+                expect = tensor2cat(dataset.cats_vocab, target)
+                raw = tensor2cat(dataset.cats_vocab, out)
+                predict = tensor2cat(dataset.cats_vocab, pred)
+                print("Evaluate expected:  ", expect)
+                print("Evaluate predicted: ", predict)
+                print("Evaluate raw output:", raw)
+                print("\n")
+            batch.refresh()
+
+            ##total_acc += (predicted_cats.argmax(1) == cats).sum().item()
+            N, C = cats.shape
+            accuracy = (predictions == cats).sum() / (N*C) * 100
+            total_acc += accuracy
+            #print("eval accuracy", accuracy)
+            #print("eval total_acc", total_acc)
+            total_count += cats.size(0)
+
+            batch.set_postfix({
+                "accuracy": int(total_acc / total_count),
+            })
    return total_acc / total_count
 
 def main():
@@ -442,6 +497,18 @@
         sys.exit(1)
 
     data = read_csv(input_csv=args.input, rows=story_num, verbose=args.verbose)
+
+    # create list of all categories
+    global all_categories
+    for cats in data.categories:
+        for c in cats.split(";"):
+            if c not in all_categories:
+                all_categories.append(c)
+    all_categories = sorted(all_categories)
+    #print(all_categories)
+    #print(len(all_categories))
+    #sys.exit(0)
+
     train_data, valid_data, = split_dataset(data, verbose=args.verbose)
 
     '''
@@ -466,6 +533,7 @@
     )
     #for text, cat in enumerate(train_dataset):
     #    print(text, cat)
+    #print("-" * 20)
     #for text, cat in enumerate(valid_dataset):
     #    print(text, cat)
     #sys.exit(0)
@@ -483,33 +551,30 @@
     '''
     dataloader = DataLoader(dataset,
         batch_size=4,
+        drop_last=True,
         shuffle=True,
         num_workers=0,
-        collate_fn=CollateBatch(pad_idx=dataset.stoi['<pad>']),
+        collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
     )
     '''
     train_dataloader = DataLoader(train_dataset,
         batch_size=BATCH_SIZE,
+        drop_last=True,
         shuffle=True,
         num_workers=0,
-        collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
+        collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
    )
    valid_dataloader = DataLoader(valid_dataset,
        batch_size=BATCH_SIZE,
+        drop_last=True,
        shuffle=True,
        num_workers=0,
-        collate_fn=CollateBatch(pad_idx=valid_dataset.stoi['<pad>']),
+        collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
    )
 
    #for i_batch, sample_batched in enumerate(dataloader):
    #    print(i_batch, sample_batched[0], sample_batched[1])
 
    #for i_batch, sample_batched in enumerate(train_dataloader):
-    #print(i_batch, sample_batched[0], sample_batched[1])
-    #print(i_batch)
-    #print("batch elements:")
-    #for i in sample_batched:
-    #    print(i)
-    #    print(i.shape)
-    #    print("\n")
+    #    print(i_batch, sample_batched[0], sample_batched[1])
    #sys.exit(0)
 
    input_size = len(train_dataset.text_vocab)
@@ -547,40 +612,32 @@
     # optimizer and loss
     #optimizer = torch.optim.SGD(model.parameters(), lr=LR)
-    criterion = nn.CrossEntropyLoss()
+    criterion = nn.BCEWithLogitsLoss()
     optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=weight_decay)
     if args.verbose:
         print(criterion)
         print(optimizer)
 
     total_accu = None
-    for epoch in range(1, EPOCHS + 1):
+    #for epoch in range(1, EPOCHS + 1):
+    e = tqdm.tqdm(range(1, EPOCHS + 1), unit="epoch")
+    for epoch in e:
+        e.set_description(f"Epoch {epoch}")
         model.train()
-        epoch_start_time = time.time()
-        with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
-            bar.set_description(f"Epoch {epoch}")
-            train(train_dataloader, train_dataset, model, optimizer, criterion)
-            bar.set_postfix(
-                # loss=float(loss),
-                # acc=float(acc)
-            )
-        accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
+        train(train_dataloader, train_dataset, model, optimizer, criterion, epoch)
+
+        accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion, epoch)
         if total_accu is not None and total_accu > accu_val:
-            scheduler.step()
+            optimizer.step()
         else:
             total_accu = accu_val
-        print("-" * 59)
-        print(
-            "| end of epoch {:3d} | time: {:5.2f}s | "
-            "valid accuracy {:8.3f} ".format(
-                epoch, time.time() - epoch_start_time, accu_val
-            )
-        )
-        print("-" * 59)
+        e.set_postfix({
+            "accuracy": accu_val.int().item(),
+        })
 
-    print("Checking the results of test dataset.")
-    accu_test = evaluate(test_dataloader, test_dataset)
-    print("test accuracy {:8.3f}".format(accu_test))
+#    print("Checking the results of test dataset.")
+#    accu_test = evaluate(test_dataloader, test_dataset)
+#    print("test accuracy {:8.3f}".format(accu_test))
 
     return
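
# Aside: tensor2cat() above, reduced to a standalone sketch. Each row of a
# target/prediction tensor maps back to {category name: score} through the
# vocab's index-to-string table; itos and scores here are made-up values,
# not taken from this file.
import torch

itos = ["<pad>", "politics", "sport", "weather"]   # index -> category name
scores = torch.tensor([0.0, 0.9, 0.0, 0.4])        # one sample, the 1-D case

chance = {itos[i]: p.item() for i, p in enumerate(scores) if p > 0}
# chance is approximately {"politics": 0.9, "weather": 0.4}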
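
# Aside: why the criterion switched from CrossEntropyLoss to BCEWithLogitsLoss.
# CrossEntropyLoss assumes exactly one true class per sample, while a story
# here can carry several categories at once; BCEWithLogitsLoss applies a
# per-category sigmoid plus binary cross-entropy, matching the multi-hot float
# targets built in CollateBatch. A small sketch with illustrative shapes:
import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()
logits = torch.randn(8, 20)     # (BATCH_SIZE, number of categories)
targets = torch.zeros(8, 20)
targets[0, [1, 3]] = 1.0        # sample 0 belongs to categories 1 and 3
loss = criterion(logits, targets)

# Note that the model outputs stay logits under this loss: to cut off at a
# probability p, compare torch.sigmoid(output) >= p. Thresholding the raw
# output at 0.25 or 0.5, as train()/evaluate() do above, shifts the effective
# cut-off (a logit of 0.5 corresponds to a probability of about 0.62).
probs = torch.sigmoid(logits)
predictions = (probs >= 0.5).float()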