Timothy Allen 2023-12-13 20:22:39 +02:00
parent 22df0a0ba0
commit 6864e43ce4
4 changed files with 22 additions and 7 deletions

.gitignore vendored (1 addition)

@@ -1 +1,2 @@
 data/
+__pycache__

@@ -0,0 +1,5 @@
+This is a multi-class, multi-label network that categorises text into one or more of ~160 categories, mostly relating to the African continent.
+The training dataset is proprietary, consisting of news articles that have been manually categorised.
+The trained model is freely available, as is the training and evaluation code; the dataset, unfortunately, is not.
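
The README hunk above describes a multi-label setup: each article can carry several of the ~160 categories at once. As a rough illustration only (none of the names below come from this repo), a PyTorch multi-label classifier typically pairs a multi-hot target with nn.BCEWithLogitsLoss, which scores each category as an independent binary decision:

# Hypothetical sketch, not this repo's code: multi-hot targets with a
# per-category binary loss. NUM_CATS and TextClassifier are invented names.
import torch
import torch.nn as nn

NUM_CATS = 160  # approximate category count from the README

class TextClassifier(nn.Module):
  def __init__(self, vocab_size, embed_dim, num_cats):
    super().__init__()
    self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
    self.fc = nn.Linear(embed_dim, num_cats)

  def forward(self, text, offsets):
    return self.fc(self.embedding(text, offsets))

model = TextClassifier(vocab_size=20000, embed_dim=64, num_cats=NUM_CATS)
criterion = nn.BCEWithLogitsLoss()

tokens = torch.tensor([1, 2, 3, 4])   # two documents, flattened
offsets = torch.tensor([0, 2])        # document boundaries for EmbeddingBag
targets = torch.zeros(2, NUM_CATS)
targets[0, 3] = targets[0, 17] = 1.0  # first document carries two labels
loss = criterion(model(tokens, offsets), targets)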

@@ -191,3 +191,5 @@ def main():
 if __name__ == "__main__":
   main()
+
+# vim: set expandtab shiftwidth=2 softtabstop=2:

@@ -14,8 +14,7 @@ import random
 import pandas as pd
 import numpy as np
 import itertools
-#from pandarallel import pandarallel
-from tqdm import tqdm
+import tqdm
 # torch
 import torch
 import torchdata.datapipes as dp
@@ -30,15 +29,15 @@ story_num = 64 # XXX None for all
 # Hyperparameters
 EPOCHS = 10 # epoch
 #LR = 5 # learning rate
 LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
 BATCH_SIZE = 64 # batch size for training

 def read_csv(input_csv, rows=None, verbose=0):
   if verbose > 0:
     with open(input_csv, 'r', encoding="utf-8") as f:
       data = pd.concat(
-        [chunk for chunk in tqdm(
+        [chunk for chunk in tqdm.tqdm(
           pd.read_csv(f,
             encoding="utf-8",
             quoting=csv.QUOTE_ALL,
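
The hunk above switches the import style from "from tqdm import tqdm" to "import tqdm", so the wrapper call becomes tqdm.tqdm(...). For readers unfamiliar with the pattern being wrapped, here is a self-contained sketch of reading a CSV in chunks behind a progress bar; the file name and chunk size are placeholders, not values from this repo:

# Standalone sketch of the chunked-read pattern; "input.csv" and
# chunksize=1000 are invented for illustration.
import csv
import pandas as pd
import tqdm

with open("input.csv", "r", encoding="utf-8") as f:
  # With chunksize set, read_csv returns an iterator of DataFrames, so
  # tqdm.tqdm can tick once per chunk instead of blocking until EOF.
  data = pd.concat(
    chunk for chunk in tqdm.tqdm(
      pd.read_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL, chunksize=1000),
      desc="reading CSV",
    )
  )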
@@ -345,7 +344,6 @@ def hot_one_to_labels(vocab, tensor):
 def train(dataloader, dataset, model, optimizer, criterion):
-  model.train()
   total_acc, total_count = 0, 0
   log_interval = 500
   start_time = time.time()
@@ -377,6 +375,7 @@ def train(dataloader, dataset, model, optimizer, criterion):
       [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
       return
+
     total_acc += (output == cats).sum().item()
     total_count += cats.size(0)
     if idx % log_interval == 0 and idx > 0:
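
A side note on the context above: if output and cats are multi-hot matrices, (output == cats).sum() counts agreement per category cell, while total_count accumulates only the number of rows, so the running ratio grows with the number of categories. The two usual conventions are per-label accuracy and exact-match accuracy; a small illustration with invented values:

# Illustration only: per-label vs. exact-match accuracy for multi-hot tensors.
import torch

output = torch.tensor([[1., 0., 1.],
                       [0., 1., 0.]])  # predicted multi-hot labels
cats = torch.tensor([[1., 0., 0.],
                     [0., 1., 0.]])    # true multi-hot labels

agreements = (output == cats).sum().item()                           # 5 of 6 cells
per_label_acc = agreements / cats.numel()                            # ~0.83
exact_match_acc = (output == cats).all(dim=1).float().mean().item()  # 0.5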
@@ -556,8 +555,15 @@ def main():
   total_accu = None
   for epoch in range(1, EPOCHS + 1):
+    model.train()
     epoch_start_time = time.time()
-    train(train_dataloader, train_dataset, model, optimizer, criterion)
+    with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
+      bar.set_description(f"Epoch {epoch}")
+      train(train_dataloader, train_dataset, model, optimizer, criterion)
+      bar.set_postfix(
+        # loss=float(loss),
+        # acc=float(acc)
+      )
     accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
     if total_accu is not None and total_accu > accu_val:
       scheduler.step()
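
As committed, the bar above counts to BATCH_SIZE rather than the number of batches, and the set_postfix call has its loss/acc arguments commented out, so the bar does not yet track real progress. One possible wiring, sketched with invented stand-ins (the model, data, and LR-decay policy below are placeholders, not this repo's code), is to iterate the bar over the dataloader and update the postfix inside the batch loop:

# Hypothetical sketch, not the repo's code: one bar per epoch, updated per batch.
import torch
import torch.nn as nn
import tqdm
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(8, 160)  # stand-in classifier
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)  # initial LR from the file
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
dataset = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256, 160)).float())
train_dataloader = DataLoader(dataset, batch_size=64)

for epoch in range(1, 3):
  model.train()  # per epoch, matching where this commit moved the call
  with tqdm.tqdm(train_dataloader, unit="batch", desc=f"Epoch {epoch}") as bar:
    for features, cats in bar:
      optimizer.zero_grad()
      loss = criterion(model(features), cats)
      loss.backward()
      optimizer.step()
      bar.set_postfix(loss=float(loss))  # live value for the bar
  scheduler.step()  # the file steps its scheduler only when accuracy plateaus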
@@ -581,3 +587,4 @@ def main():
 if __name__ == "__main__":
   main()
+# vim: set expandtab shiftwidth=2 softtabstop=2: