Timothy Allen 2023-12-13 20:22:39 +02:00
parent a871c9235c
commit 94025fc0c6
4 changed files with 22 additions and 7 deletions

.gitignore

@@ -1 +1,2 @@
 data/
+__pycache__


@@ -0,0 +1,5 @@
+This is a multi-class, multi-label network that categorises text into one or more of ~160 categories, mostly relating to the African continent.
+The training dataset is a proprietary dataset from allAfrica.com, consisting of stories that have been manually categorised according to AllAfrica's in-house categorisation scheme.
+The trained model is freely available, as is the training and evaluation code.
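
Since "multi-class, multi-label" means each story can carry several of the ~160 labels at once, a minimal sketch of such a setup in PyTorch may help; the class name, layer sizes, and vocabulary size below are illustrative assumptions, not the repository's actual model:

import torch
import torch.nn as nn

NUM_CATEGORIES = 160  # assumption: roughly the ~160 AllAfrica categories

class TextClassifier(nn.Module):
  def __init__(self, vocab_size, embed_dim=64, num_classes=NUM_CATEGORIES):
    super().__init__()
    self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
    self.fc = nn.Linear(embed_dim, num_classes)

  def forward(self, text, offsets):
    # Raw logits, one per category; no softmax, since labels are independent
    return self.fc(self.embedding(text, offsets))

model = TextClassifier(vocab_size=20000)
criterion = nn.BCEWithLogitsLoss()  # one sigmoid per label, not a softmax
logits = model(torch.tensor([1, 2, 3, 4]), torch.tensor([0, 2]))
labels = (torch.sigmoid(logits) > 0.5).int()  # a story can get several labels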


@@ -191,3 +191,5 @@ def main():
 
 if __name__ == "__main__":
   main()
+
+# vim: set expandtab shiftwidth=2 softtabstop=2:


@@ -14,8 +14,7 @@ import random
 import pandas as pd
 import numpy as np
 import itertools
-#from pandarallel import pandarallel
-from tqdm import tqdm
+import tqdm
 # torch
 import torch
 import torchdata.datapipes as dp
@@ -30,15 +29,15 @@ story_num = 64 # XXX None for all
 
 # Hyperparameters
 EPOCHS = 10 # epoch
 #LR = 5 # learning rate
 LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
 BATCH_SIZE = 64 # batch size for training
 
 def read_csv(input_csv, rows=None, verbose=0):
   if verbose > 0:
     with open(input_csv, 'r', encoding="utf-8") as f:
       data = pd.concat(
-        [chunk for chunk in tqdm(
+        [chunk for chunk in tqdm.tqdm(
           pd.read_csv(f,
             encoding="utf-8",
             quoting=csv.QUOTE_ALL,
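
The long comment on LR above is the key tuning advice in this file. As a concrete illustration, here is a hedged sketch of how an initial learning rate pairs with the decay step that main() later applies via scheduler.step(); the optimizer choice and gamma value are assumptions, not values taken from this commit:

import torch

net = torch.nn.Linear(64, 160)  # stand-in model for illustration
optimizer = torch.optim.SGD(net.parameters(), lr=0.005)  # the LR above
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

# Each scheduler.step() multiplies the learning rate by gamma, so training
# starts fast and becomes more conservative once validation accuracy stalls.
scheduler.step()
print(optimizer.param_groups[0]["lr"])  # 0.0005 after one decay step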
@@ -345,7 +344,6 @@ def hot_one_to_labels(vocab, tensor):
 
 def train(dataloader, dataset, model, optimizer, criterion):
-  model.train()
   total_acc, total_count = 0, 0
   log_interval = 500
   start_time = time.time()
@@ -377,6 +375,7 @@ def train(dataloader, dataset, model, optimizer, criterion):
       [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
       return
     total_acc += (output == cats).sum().item()
     total_count += cats.size(0)
     if idx % log_interval == 0 and idx > 0:
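
A hedged aside on the accuracy line above: if output holds raw logits, an elementwise == against 0/1 targets will rarely match, so multi-label accuracy is usually computed after thresholding the sigmoid. A sketch, with shapes assumed rather than taken from this code:

import torch

output = torch.randn(4, 160)                  # assumed: logits for 4 stories
cats = torch.randint(0, 2, (4, 160)).float()  # assumed: multi-hot targets

predicted = (torch.sigmoid(output) > 0.5).float()
correct = (predicted == cats).sum().item()  # per-label matches
accuracy = correct / cats.numel()           # denominator counts labels, not rows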
@@ -556,8 +555,15 @@ def main():
 
   total_accu = None
   for epoch in range(1, EPOCHS + 1):
+    model.train()
     epoch_start_time = time.time()
-    train(train_dataloader, train_dataset, model, optimizer, criterion)
+    with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
+      bar.set_description(f"Epoch {epoch}")
+      train(train_dataloader, train_dataset, model, optimizer, criterion)
+      bar.set_postfix(
+        # loss=float(loss),
+        # acc=float(acc)
+      )
     accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
     if total_accu is not None and total_accu > accu_val:
       scheduler.step()
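
The loss= and acc= arguments above are commented out because nothing in main()'s scope holds them, and tqdm.trange(BATCH_SIZE) counts to the batch size rather than the number of batches, so the bar is largely cosmetic. One hedged alternative, where the function signature and the unpacking of dataloader batches are assumptions rather than this repository's actual code, is to wrap the dataloader inside train() itself:

import tqdm

def train(dataloader, dataset, model, optimizer, criterion, epoch):
  model.train()
  bar = tqdm.tqdm(dataloader, unit="batch", desc=f"Epoch {epoch}")
  for text, offsets, cats in bar:  # assumed batch structure
    optimizer.zero_grad()
    output = model(text, offsets)
    loss = criterion(output, cats)
    loss.backward()
    optimizer.step()
    bar.set_postfix(loss=float(loss))  # live loss, updated every batch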
@@ -581,3 +587,4 @@ def main():
 
 if __name__ == "__main__":
   main()
+# vim: set expandtab shiftwidth=2 softtabstop=2: