From 94025fc0c68491476e75db1a7012ffa755c4ae1b Mon Sep 17 00:00:00 2001
From: Timothy Allen
Date: Wed, 13 Dec 2023 20:22:39 +0200
Subject: [PATCH] Metadata

---
 .gitignore                   |  1 +
 README.md                    |  5 +++++
 africat/aa_create_dataset.py |  2 ++
 africat/categorise.py        | 21 ++++++++++++++-------
 4 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 8fce603..773b1fc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 data/
+__pycache__
diff --git a/README.md b/README.md
index e69de29..d4c11c8 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,5 @@
+This is a multi-class, multi-label network that categorises text into one or more of ~160 categories, mostly relating to the African continent.
+
+The training dataset is a proprietary dataset from allAfrica.com, consisting of stories that have been manually categorised according to AllAfrica's in-house categorisation scheme.
+
+The trained model is freely available, as is the training and evaluation code.
diff --git a/africat/aa_create_dataset.py b/africat/aa_create_dataset.py
index 6bf7ee2..a93045b 100755
--- a/africat/aa_create_dataset.py
+++ b/africat/aa_create_dataset.py
@@ -191,3 +191,5 @@ def main():
 
 if __name__ == "__main__":
   main()
+
+# vim: set expandtab shiftwidth=2 softtabstop=2:
diff --git a/africat/categorise.py b/africat/categorise.py
index 3f7adfc..ce4952e 100755
--- a/africat/categorise.py
+++ b/africat/categorise.py
@@ -14,8 +14,7 @@ import random
 import pandas as pd
 import numpy as np
 import itertools
-#from pandarallel import pandarallel
-from tqdm import tqdm
+import tqdm
 # torch
 import torch
 import torchdata.datapipes as dp
@@ -30,15 +29,15 @@ story_num = 64 # XXX None for all
 
 # Hyperparameters
 EPOCHS = 10 # epoch
-#LR = 5 # learning rate
-LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
+#LR = 5 # learning rate
+LR = 0.005 # initial learning rate; too small can make training slow or leave it stuck, too large can make it unstable or settle on sub-optimal weights -- arguably the most important hyperparameter. If you can tune only one, tune this
 BATCH_SIZE = 64 # batch size for training
 
 def read_csv(input_csv, rows=None, verbose=0):
   if verbose > 0:
     with open(input_csv, 'r', encoding="utf-8") as f:
       data = pd.concat(
-        [chunk for chunk in tqdm(
+        [chunk for chunk in tqdm.tqdm(
           pd.read_csv(f,
             encoding="utf-8",
             quoting=csv.QUOTE_ALL,
@@ -345,7 +344,6 @@ def hot_one_to_labels(vocab, tensor):
 
 
 def train(dataloader, dataset, model, optimizer, criterion):
-  model.train()
   total_acc, total_count = 0, 0
   log_interval = 500
   start_time = time.time()
@@ -377,6 +375,7 @@ def train(dataloader, dataset, model, optimizer, criterion):
         [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
         return
 
+
     total_acc += (output == cats).sum().item()
     total_count += cats.size(0)
     if idx % log_interval == 0 and idx > 0:
@@ -556,8 +555,15 @@ def main():
 
   total_accu = None
   for epoch in range(1, EPOCHS + 1):
+    model.train()
     epoch_start_time = time.time()
-    train(train_dataloader, train_dataset, model, optimizer, criterion)
+    with tqdm.tqdm(train_dataloader, unit="batch", mininterval=0) as bar:
+      bar.set_description(f"Epoch {epoch}")
+      train(bar, train_dataset, model, optimizer, criterion)
+      bar.set_postfix(
+        # loss=float(loss),
+        # acc=float(acc)
+      )
     accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
     if total_accu is not None and total_accu > accu_val:
       scheduler.step()
@@ -581,3 +587,4 @@ def main():
 
 if __name__ == "__main__":
   main()
+# vim: set expandtab shiftwidth=2 softtabstop=2:
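
A note on the multi-label setup described in the README: each story can carry several categories at once, so targets and outputs are multi-hot vectors over the category vocabulary rather than a single class index. A toy sketch of that representation follows; the four-category vocabulary and the 0.5 threshold are illustrative assumptions, not AllAfrica's actual ~160-category scheme or the internals of hot_one_to_labels().

import torch

# Illustrative four-category vocabulary; the real scheme has ~160 categories.
vocab = ["agriculture", "energy", "health", "sport"]

# A story tagged with two categories becomes a multi-hot target vector.
target = torch.tensor([1.0, 0.0, 1.0, 0.0])

# Sigmoid outputs from the network for the same story.
output = torch.tensor([0.91, 0.12, 0.73, 0.05])

# Threshold to recover predicted labels; 0.5 is a common default.
predicted = [cat for cat, p in zip(vocab, output) if p > 0.5]
print(predicted)  # ['agriculture', 'health']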
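
On the learning-rate comment in categorise.py: the initial LR is hand-tuned, and main() additionally decays it by calling scheduler.step() whenever validation accuracy regresses. A self-contained sketch of that pattern, assuming a StepLR scheduler with gamma=0.1 and an else branch that records the best accuracy so far; the scheduler construction, the else branch, and the accuracy numbers are not shown in the patch and are illustrative only.

import torch

# Stand-in model and optimizer; the real ones are built in categorise.py's main().
model = torch.nn.Linear(8, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)  # LR from the patch
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

total_accu = None
for accu_val in [0.42, 0.55, 0.53, 0.58]:  # fake per-epoch validation accuracies
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()  # accuracy regressed: cut the learning rate tenfold
    else:
        total_accu = accu_val
    print(optimizer.param_groups[0]["lr"])  # 0.005, 0.005, 0.0005, 0.0005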
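
On the progress bar in main(): wrapping the training DataLoader in tqdm.tqdm() advances the bar once per batch as train() iterates it, and set_postfix() can surface live metrics once train() exposes them, which is what the commented-out loss/acc placeholders anticipate. A runnable toy of the same wiring, with a sleep and a fake loss standing in for the real forward/backward pass:

import time
import tqdm

fake_batches = range(100)  # stands in for train_dataloader

with tqdm.tqdm(fake_batches, unit="batch", mininterval=0) as bar:
    bar.set_description("Epoch 1")
    running_loss, seen = 0.0, 0
    for step in bar:
        time.sleep(0.01)                  # stands in for forward/backward/step
        running_loss += 1.0 / (step + 1)  # stands in for loss.item()
        seen += 1
        # Live metrics in the bar -- the role of the commented-out
        # set_postfix(loss=..., acc=...) placeholders above.
        bar.set_postfix(loss=running_loss / seen)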