Timothy Allen 2023-12-13 20:22:39 +02:00
parent a871c9235c
commit 94025fc0c6
4 changed files with 22 additions and 7 deletions

.gitignore

@@ -1 +1,2 @@
data/
__pycache__

@@ -0,0 +1,5 @@
This is a multi-class, multi-label network that categorises text into one or more of ~160 categories, mostly relating to the African continent.
The training dataset is a proprietary dataset from allAfrica.com, consisting of stories that have been manually categorised according to AllAfrica's in-house categorisation scheme.
The trained model is freely available, as is the training and evaluation code.
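For context, a minimal sketch of the multi-label setup described above, assuming a PyTorch model whose final layer emits one logit per category; the layer sizes, vocabulary size, and 0.5 threshold are illustrative assumptions, not taken from this repository:

import torch
import torch.nn as nn

NUM_CATS = 160       # approximate category count from the description above
VOCAB_SIZE = 50_000  # illustrative vocabulary size, not from this repo

# Multi-label means one independent logit per category, so
# BCEWithLogitsLoss over a multi-hot target is the usual pairing.
model = nn.Sequential(
    nn.EmbeddingBag(VOCAB_SIZE, 64),  # bag-of-tokens text encoder
    nn.Linear(64, NUM_CATS),
)
criterion = nn.BCEWithLogitsLoss()

tokens = torch.randint(0, VOCAB_SIZE, (4, 128))  # 4 stories, 128 token ids each
target = torch.zeros(4, NUM_CATS)                # multi-hot label vectors
target[:, 0] = 1.0                               # e.g. every story tagged with category 0
loss = criterion(model(tokens), target)

# Thresholding each sigmoid output separately lets a story carry
# several categories at once.
predicted = torch.sigmoid(model(tokens)) > 0.5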


@@ -191,3 +191,5 @@ def main():
if __name__ == "__main__":
main()
# vim: set expandtab shiftwidth=2 softtabstop=2:


@@ -14,8 +14,7 @@ import random
import pandas as pd
import numpy as np
import itertools
#from pandarallel import pandarallel
-from tqdm import tqdm
+import tqdm
# torch
import torch
import torchdata.datapipes as dp
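The import switch above matters because it changes every call site: with the module-style import, both the progress-bar constructor and trange live under the tqdm namespace, which is why tqdm(...) later in this diff becomes tqdm.tqdm(...). A quick illustration:

# Module-style import: call sites use the tqdm.* namespace.
import tqdm

for _ in tqdm.tqdm(range(3)):   # progress bar over any iterable
    pass
for _ in tqdm.trange(3):        # shorthand for tqdm.tqdm(range(3))
    pass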
@@ -30,15 +29,15 @@ story_num = 64 # XXX None for all
# Hyperparameters
EPOCHS = 10 # epoch
#LR = 5 # learning rate
LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
BATCH_SIZE = 64 # batch size for training
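As a hedged illustration of how an initial learning rate like this is usually wired up (the stand-in model, SGD optimizer, and decay factor here are assumptions, not read from this file):

import torch
import torch.nn as nn

model = nn.Linear(64, 160)  # stand-in model, for illustration only
LR = 0.005                  # initial learning rate, as above
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

# One common policy is to decay the rate when validation accuracy
# plateaus; the epoch loop later in this file calls scheduler.step()
# under exactly that condition.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)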
def read_csv(input_csv, rows=None, verbose=0):
if verbose > 0:
with open(input_csv, 'r', encoding="utf-8") as f:
data = pd.concat(
-[chunk for chunk in tqdm(
+[chunk for chunk in tqdm.tqdm(
pd.read_csv(f,
encoding="utf-8",
quoting=csv.QUOTE_ALL,
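The call above is cut off by the hunk boundary; a self-contained sketch of the same chunked-read pattern follows, where the chunksize, nrows, and progress text are assumptions:

import csv
import pandas as pd
import tqdm

def read_csv(input_csv, rows=None):
    # pd.read_csv with chunksize yields DataFrame chunks; wrapping the
    # iterator in tqdm.tqdm shows per-chunk progress while reading.
    with open(input_csv, "r", encoding="utf-8") as f:
        return pd.concat(
            [chunk for chunk in tqdm.tqdm(
                pd.read_csv(f,
                            encoding="utf-8",
                            quoting=csv.QUOTE_ALL,
                            nrows=rows,
                            chunksize=1000),
                desc="reading CSV")])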
@@ -345,7 +344,6 @@ def hot_one_to_labels(vocab, tensor):
def train(dataloader, dataset, model, optimizer, criterion):
model.train()
total_acc, total_count = 0, 0
log_interval = 500
start_time = time.time()
@@ -377,6 +375,7 @@ def train(dataloader, dataset, model, optimizer, criterion):
[pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
return
total_acc += (output == cats).sum().item()
total_count += cats.size(0)
if idx % log_interval == 0 and idx > 0:
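A condensed sketch of the accumulate-and-report pattern in this hunk; the sigmoid thresholding and per-label counting are assumptions, added because raw logits never compare equal to 0/1 labels:

import time
import torch

def train_sketch(dataloader, model, optimizer, criterion, log_interval=500):
    # Accumulate label matches and report a running accuracy
    # every log_interval batches, then reset the counters.
    model.train()
    total_acc, total_count = 0, 0
    start_time = time.time()
    for idx, (text, cats) in enumerate(dataloader):
        optimizer.zero_grad()
        output = model(text)
        loss = criterion(output, cats)
        loss.backward()
        optimizer.step()
        # Binarise the outputs before comparing with multi-hot targets;
        # counting per label means dividing by cats.numel() rather than
        # the batch size cats.size(0).
        total_acc += ((torch.sigmoid(output) > 0.5) == cats.bool()).sum().item()
        total_count += cats.numel()
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print(f"| {idx:5d} batches "
                  f"| accuracy {total_acc / total_count:8.3f} "
                  f"| {elapsed:5.2f}s")
            total_acc, total_count = 0, 0
            start_time = time.time()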
@@ -556,8 +555,15 @@ def main():
total_accu = None
for epoch in range(1, EPOCHS + 1):
model.train()
epoch_start_time = time.time()
train(train_dataloader, train_dataset, model, optimizer, criterion)
with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
bar.set_description(f"Epoch {epoch}")
train(train_dataloader, train_dataset, model, optimizer, criterion)
bar.set_postfix(
# loss=float(loss),
# acc=float(acc)
)
accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
if total_accu is not None and total_accu > accu_val:
scheduler.step()
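For comparison, a tqdm bar is usually driven by the batches themselves rather than sized by BATCH_SIZE, so it advances once per batch and closes at the end of the epoch. A sketch under the assumption that the names defined elsewhere in this script (EPOCHS, model, optimizer, criterion, scheduler, train_dataloader, valid_dataloader, valid_dataset, evaluate) are in scope:

import tqdm

total_accu = None
for epoch in range(1, EPOCHS + 1):
    model.train()
    with tqdm.tqdm(train_dataloader, unit="batch") as bar:
        bar.set_description(f"Epoch {epoch}")
        for text, cats in bar:          # the bar advances once per batch
            optimizer.zero_grad()
            output = model(text)
            loss = criterion(output, cats)
            loss.backward()
            optimizer.step()
            bar.set_postfix(loss=float(loss))
    accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()                # decay LR when accuracy stops improving
    else:
        total_accu = accu_val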
@@ -581,3 +587,4 @@ def main():
if __name__ == "__main__":
main()
# vim: set expandtab shiftwidth=2 softtabstop=2: