This commit is contained in:
Timothy Allen 2023-12-13 20:22:39 +02:00
parent a871c9235c
commit 94025fc0c6
4 changed files with 22 additions and 7 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
data/
__pycache__

View File

@ -0,0 +1,5 @@
This is a multi-class, multi-label network that categorises text into one of ~160 categories, mostly relating to the African continent.
The training dataset is a proprietry dataset from allAfrica.com, consisting of stories that have been manuially categorised according to AllAfrica's inhouse categorisation scheme.
The trained model is freely available, as is the training and evaluation code.

View File

@ -191,3 +191,5 @@ def main():
if __name__ == "__main__":
main()
# vim: set expandtab shiftwidth=2 softtabstop=2:

View File

@ -14,8 +14,7 @@ import random
import pandas as pd
import numpy as np
import itertools
#from pandarallel import pandarallel
from tqdm import tqdm
import tqdm
# torch
import torch
import torchdata.datapipes as dp
@ -38,7 +37,7 @@ def read_csv(input_csv, rows=None, verbose=0):
if verbose > 0:
with open(input_csv, 'r', encoding="utf-8") as f:
data = pd.concat(
[chunk for chunk in tqdm(
[chunk for chunk in tqdm.tqdm(
pd.read_csv(f,
encoding="utf-8",
quoting=csv.QUOTE_ALL,
@ -345,7 +344,6 @@ def hot_one_to_labels(vocab, tensor):
def train(dataloader, dataset, model, optimizer, criterion):
model.train()
total_acc, total_count = 0, 0
log_interval = 500
start_time = time.time()
@ -377,6 +375,7 @@ def train(dataloader, dataset, model, optimizer, criterion):
[pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
return
total_acc += (output == cats).sum().item()
total_count += cats.size(0)
if idx % log_interval == 0 and idx > 0:
@ -556,8 +555,15 @@ def main():
total_accu = None
for epoch in range(1, EPOCHS + 1):
model.train()
epoch_start_time = time.time()
with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
bar.set_description(f"Epoch {epoch}")
train(train_dataloader, train_dataset, model, optimizer, criterion)
bar.set_postfix(
# loss=float(loss),
# acc=float(acc)
)
accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
if total_accu is not None and total_accu > accu_val:
scheduler.step()
@ -581,3 +587,4 @@ def main():
if __name__ == "__main__":
main()
# vim: set expandtab shiftwidth=2 softtabstop=2: