This commit is contained in:
Timothy Allen 2023-12-13 20:22:39 +02:00
parent a871c9235c
commit 94025fc0c6
4 changed files with 22 additions and 7 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
data/ data/
__pycache__

View File

@ -0,0 +1,5 @@
This is a multi-class, multi-label network that categorises text into one or more of ~160 categories, mostly relating to the African continent.
The training dataset is a proprietary dataset from AllAfrica.com, consisting of stories that have been manually categorised according to AllAfrica's in-house categorisation scheme.
The trained model is freely available, as is the training and evaluation code.

View File

@ -191,3 +191,5 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
# vim: set expandtab shiftwidth=2 softtabstop=2:

View File

@ -14,8 +14,7 @@ import random
import pandas as pd import pandas as pd
import numpy as np import numpy as np
import itertools import itertools
#from pandarallel import pandarallel import tqdm
from tqdm import tqdm
# torch # torch
import torch import torch
import torchdata.datapipes as dp import torchdata.datapipes as dp
@ -38,7 +37,7 @@ def read_csv(input_csv, rows=None, verbose=0):
if verbose > 0: if verbose > 0:
with open(input_csv, 'r', encoding="utf-8") as f: with open(input_csv, 'r', encoding="utf-8") as f:
data = pd.concat( data = pd.concat(
[chunk for chunk in tqdm( [chunk for chunk in tqdm.tqdm(
pd.read_csv(f, pd.read_csv(f,
encoding="utf-8", encoding="utf-8",
quoting=csv.QUOTE_ALL, quoting=csv.QUOTE_ALL,
@ -345,7 +344,6 @@ def hot_one_to_labels(vocab, tensor):
def train(dataloader, dataset, model, optimizer, criterion): def train(dataloader, dataset, model, optimizer, criterion):
model.train()
total_acc, total_count = 0, 0 total_acc, total_count = 0, 0
log_interval = 500 log_interval = 500
start_time = time.time() start_time = time.time()
@ -377,6 +375,7 @@ def train(dataloader, dataset, model, optimizer, criterion):
[pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)] [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
return return
total_acc += (output == cats).sum().item() total_acc += (output == cats).sum().item()
total_count += cats.size(0) total_count += cats.size(0)
if idx % log_interval == 0 and idx > 0: if idx % log_interval == 0 and idx > 0:
@ -556,8 +555,15 @@ def main():
total_accu = None total_accu = None
for epoch in range(1, EPOCHS + 1): for epoch in range(1, EPOCHS + 1):
model.train()
epoch_start_time = time.time() epoch_start_time = time.time()
with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
bar.set_description(f"Epoch {epoch}")
train(train_dataloader, train_dataset, model, optimizer, criterion) train(train_dataloader, train_dataset, model, optimizer, criterion)
bar.set_postfix(
# loss=float(loss),
# acc=float(acc)
)
accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion) accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
if total_accu is not None and total_accu > accu_val: if total_accu is not None and total_accu > accu_val:
scheduler.step() scheduler.step()
@ -581,3 +587,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
# vim: set expandtab shiftwidth=2 softtabstop=2: