Timothy Allen 2023-12-13 20:22:39 +02:00
parent 22df0a0ba0
commit 6864e43ce4
4 changed files with 22 additions and 7 deletions

.gitignore

@@ -1 +1,2 @@
data/
__pycache__


@@ -0,0 +1,5 @@
This is a multi-class, multi-label network that categorises text into one or more of ~160 categories, mostly relating to the African continent.
The training dataset is proprietary, consisting of news articles that have been manually categorised.
The trained model is freely available, as is the training and evaluation code; the dataset, unfortunately, is not.

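Note: since the README above describes multi-label classification over ~160 categories, a minimal sketch of the pattern it implies may help. This is an illustration only; the class name, embedding size, and category count below are assumptions, not this repository's actual model.

import torch
import torch.nn as nn

NUM_CATEGORIES = 160  # assumption, from the README's "~160 categories"

class MultiLabelClassifier(nn.Module):  # hypothetical name
  def __init__(self, vocab_size, embed_dim=64, num_labels=NUM_CATEGORIES):
    super().__init__()
    self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
    self.fc = nn.Linear(embed_dim, num_labels)

  def forward(self, text, offsets):
    # One logit per category, no softmax: labels are not mutually exclusive
    return self.fc(self.embedding(text, offsets))

# Multi-label training uses a per-label binary loss rather than cross-entropy:
criterion = nn.BCEWithLogitsLoss()
# At inference, each category is thresholded independently, e.g.:
# predicted = torch.sigmoid(logits) > 0.5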

@@ -191,3 +191,5 @@ def main():
if __name__ == "__main__":
  main()
# vim: set expandtab shiftwidth=2 softtabstop=2:


@@ -14,8 +14,7 @@ import random
import pandas as pd
import numpy as np
import itertools
#from pandarallel import pandarallel
from tqdm import tqdm
import tqdm
# torch
import torch
import torchdata.datapipes as dp
@@ -30,15 +29,15 @@ story_num = 64 # XXX None for all
# Hyperparameters
EPOCHS = 10 # epoch
#LR = 5 # learning rate
LR = 0.005 # initial learning rate; perhaps the most important hyperparameter. Too small a value can make training slow or leave it stuck, while too large a value can learn a sub-optimal set of weights too quickly or make training unstable. If you have time to tune only one hyperparameter, tune the learning rate
BATCH_SIZE = 64 # batch size for training
def read_csv(input_csv, rows=None, verbose=0):
  if verbose > 0:
    with open(input_csv, 'r', encoding="utf-8") as f:
      data = pd.concat(
        [chunk for chunk in tqdm(
        [chunk for chunk in tqdm.tqdm(
          pd.read_csv(f,
            encoding="utf-8",
            quoting=csv.QUOTE_ALL,
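Note: the commit switches from "from tqdm import tqdm" to "import tqdm", which is why the call site above becomes tqdm.tqdm(...). Written out in full, the pattern this hunk changes (chunked pd.read_csv wrapped in a progress bar) looks roughly like the sketch below; the function name, chunksize, and bar labels are illustrative assumptions.

import pandas as pd
import tqdm

def read_csv_sketch(input_csv, chunksize=1000):  # hypothetical helper
  # pd.read_csv with chunksize= yields an iterator of DataFrames;
  # wrapping that iterator in tqdm.tqdm advances the bar once per chunk.
  with open(input_csv, 'r', encoding="utf-8") as f:
    chunks = pd.read_csv(f, encoding="utf-8", chunksize=chunksize)
    return pd.concat(
      [chunk for chunk in tqdm.tqdm(chunks, desc="reading csv", unit="chunk")])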
@@ -345,7 +344,6 @@ def hot_one_to_labels(vocab, tensor):
def train(dataloader, dataset, model, optimizer, criterion):
  model.train()
  total_acc, total_count = 0, 0
  log_interval = 500
  start_time = time.time()
@@ -377,6 +375,7 @@ def train(dataloader, dataset, model, optimizer, criterion):
      [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
      return
    total_acc += (output == cats).sum().item()
    total_count += cats.size(0)
    if idx % log_interval == 0 and idx > 0:
@@ -556,8 +555,15 @@ def main():
  total_accu = None
  for epoch in range(1, EPOCHS + 1):
    model.train()
    epoch_start_time = time.time()
    train(train_dataloader, train_dataset, model, optimizer, criterion)
    with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
      bar.set_description(f"Epoch {epoch}")
      train(train_dataloader, train_dataset, model, optimizer, criterion)
      bar.set_postfix(
        # loss=float(loss),
        # acc=float(acc)
      )
    accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
    if total_accu is not None and total_accu > accu_val:
      scheduler.step()
@@ -581,3 +587,4 @@ def main():
if __name__ == "__main__":
  main()
# vim: set expandtab shiftwidth=2 softtabstop=2:
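Note: as committed, the bar created by tqdm.trange(BATCH_SIZE, ...) never advances: train() is not given the bar, and the set_postfix arguments are commented out. One way to wire an epoch bar is sketched below, with the bar wrapping the dataloader so it advances per batch. The train_step helper and its (loss, acc) return values are assumptions for illustration, not what this repository's train() currently does.

import tqdm

for epoch in range(1, EPOCHS + 1):
  model.train()
  # wrapping the dataloader makes the bar advance once per batch
  bar = tqdm.tqdm(train_dataloader, unit="batch", desc=f"Epoch {epoch}")
  for batch in bar:
    loss, acc = train_step(batch, model, optimizer, criterion)  # hypothetical helper
    bar.set_postfix(loss=float(loss), acc=float(acc))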