Metadata
This commit is contained in:
parent a871c9235c
commit 94025fc0c6

.gitignore (vendored)
@@ -1 +1,2 @@
 data/
+__pycache__

@@ -0,0 +1,5 @@
+This is a multi-class, multi-label network that categorises text into one of ~160 categories, mostly relating to the African continent.
+
+The training dataset is a proprietary dataset from allAfrica.com, consisting of stories that have been manually categorised according to AllAfrica's in-house categorisation scheme.
+
+The trained model is freely available, as is the training and evaluation code.
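
The README describes a multi-class, multi-label setup. A minimal PyTorch sketch of that kind of model follows; the ~160 class count comes from the README, while the layer choices, names, and sizes are illustrative assumptions, not the repository's actual code:

  import torch.nn as nn

  class MultiLabelClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, num_classes=160):
      super().__init__()
      # EmbeddingBag mean-pools token embeddings over each story
      self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
      self.fc = nn.Linear(embed_dim, num_classes)  # one logit per category

    def forward(self, text, offsets):
      return self.fc(self.embedding(text, offsets))

  # Multi-label means each category is an independent yes/no decision, so the
  # loss is sigmoid-based rather than softmax-based; BCEWithLogitsLoss applies
  # the sigmoid internally.
  criterion = nn.BCEWithLogitsLoss()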

@@ -191,3 +191,5 @@ def main():
 
 if __name__ == "__main__":
   main()
+
+# vim: set expandtab shiftwidth=2 softtabstop=2:

@@ -14,8 +14,7 @@ import random
 import pandas as pd
 import numpy as np
 import itertools
-#from pandarallel import pandarallel
-from tqdm import tqdm
+import tqdm
 # torch
 import torch
 import torchdata.datapipes as dp
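
Switching from "from tqdm import tqdm" to "import tqdm" means every call site must be qualified, which the later hunks in this commit do. The two spellings, sketched side by side (illustrative only):

  import tqdm

  # Module-style import: qualify each call.
  for _ in tqdm.tqdm(range(3)):   # tqdm.tqdm wraps any iterable in a progress bar
    pass

  for _ in tqdm.trange(3):        # trange(n) is shorthand for tqdm(range(n))
    pass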

@@ -30,15 +29,15 @@ story_num = 64 # XXX None for all
 
 # Hyperparameters
 EPOCHS = 10 # epoch
-#LR = 5 # learning rate
-LR = 0.005 # initial learning rate; too small a value may result in a long training process that could get stuck, whereas too large a value may result in learning a sub-optimal set of weights too fast or in an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
+#LR = 5 # learning rate
+LR = 0.005 # initial learning rate; too small a value may result in a long training process that could get stuck, whereas too large a value may result in learning a sub-optimal set of weights too fast or in an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
 BATCH_SIZE = 64 # batch size for training
 
 def read_csv(input_csv, rows=None, verbose=0):
   if verbose > 0:
     with open(input_csv, 'r', encoding="utf-8") as f:
       data = pd.concat(
-        [chunk for chunk in tqdm(
+        [chunk for chunk in tqdm.tqdm(
           pd.read_csv(f,
             encoding="utf-8",
             quoting=csv.QUOTE_ALL,
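
This hunk wraps the chunk iterator in tqdm.tqdm so large CSVs show progress while loading. A self-contained sketch of the same pattern; the chunksize value is an illustrative assumption, since the repository's value is not visible in this hunk:

  import csv
  import pandas as pd
  import tqdm

  def read_csv_with_progress(input_csv, chunksize=1000):
    with open(input_csv, 'r', encoding="utf-8") as f:
      # chunksize makes read_csv return an iterator of DataFrames rather than
      # one frame; tqdm.tqdm ticks once per chunk as pd.concat consumes it.
      reader = pd.read_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL,
                           chunksize=chunksize)
      return pd.concat(tqdm.tqdm(reader, unit="chunk"))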

@@ -345,7 +344,6 @@ def hot_one_to_labels(vocab, tensor):
 
-
 def train(dataloader, dataset, model, optimizer, criterion):
   model.train()
   total_acc, total_count = 0, 0
   log_interval = 500
   start_time = time.time()

@@ -377,6 +375,7 @@ def train(dataloader, dataset, model, optimizer, criterion):
       [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
 
+      return
 
     total_acc += (output == cats).sum().item()
     total_count += cats.size(0)
     if idx % log_interval == 0 and idx > 0:
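
The counters in this hunk compare output against cats element-wise, so they count matching label slots rather than fully correct stories. A tiny worked example with illustrative tensors (not the repository's data):

  import torch

  output = torch.tensor([[1, 0, 1], [0, 1, 0]])  # thresholded predictions
  cats   = torch.tensor([[1, 0, 0], [0, 1, 0]])  # 0/1 ground-truth labels

  total_acc = (output == cats).sum().item()  # 5 of the 6 label slots match
  total_count = cats.size(0)                 # 2 stories in the batch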

@@ -556,8 +555,15 @@ def main():
 
   total_accu = None
   for epoch in range(1, EPOCHS + 1):
+    model.train()
     epoch_start_time = time.time()
-    train(train_dataloader, train_dataset, model, optimizer, criterion)
+    with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
+      bar.set_description(f"Epoch {epoch}")
+      train(train_dataloader, train_dataset, model, optimizer, criterion)
+      bar.set_postfix(
+        # loss=float(loss),
+        # acc=float(acc)
+      )
     accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
     if total_accu is not None and total_accu > accu_val:
       scheduler.step()
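
The committed bar is sized by BATCH_SIZE and its set_postfix values are left commented out. A common alternative, sketched here as an assumption rather than the repository's code, iterates the bar over the dataloader itself so the loss can be posted per batch; the batch structure and model signature below are illustrative:

  import tqdm

  def train_one_epoch(dataloader, model, optimizer, criterion, epoch):
    model.train()
    with tqdm.tqdm(dataloader, unit="batch") as bar:
      bar.set_description(f"Epoch {epoch}")
      for text, cats in bar:              # assumed batch structure
        optimizer.zero_grad()
        output = model(text)              # assumed model signature
        loss = criterion(output, cats)
        loss.backward()
        optimizer.step()
        bar.set_postfix(loss=float(loss))  # live loss readout on the bar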

@@ -581,3 +587,4 @@ def main():
 if __name__ == "__main__":
   main()
 
+# vim: set expandtab shiftwidth=2 softtabstop=2: