Metadata
This commit is contained in:
parent
a871c9235c
commit
94025fc0c6
1
.gitignore
vendored
1
.gitignore
vendored
@ -1 +1,2 @@
|
|||||||
data/
|
data/
|
||||||
|
__pycache__
|
||||||
|
@ -0,0 +1,5 @@
|
|||||||
|
This is a multi-class, multi-label network that categorises text into one or more of ~160 categories, mostly relating to the African continent.
|
||||||
|
|
||||||
|
The training dataset is a proprietary dataset from allAfrica.com, consisting of stories that have been manually categorised according to AllAfrica's in-house categorisation scheme.
|
||||||
|
|
||||||
|
The trained model is freely available, as is the training and evaluation code.
|
@ -191,3 +191,5 @@ def main():
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
# vim: set expandtab shiftwidth=2 softtabstop=2:
|
||||||
|
@ -14,8 +14,7 @@ import random
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import itertools
|
import itertools
|
||||||
#from pandarallel import pandarallel
|
import tqdm
|
||||||
from tqdm import tqdm
|
|
||||||
# torch
|
# torch
|
||||||
import torch
|
import torch
|
||||||
import torchdata.datapipes as dp
|
import torchdata.datapipes as dp
|
||||||
@ -30,15 +29,15 @@ story_num = 64 # XXX None for all
|
|||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
EPOCHS = 10 # epoch
|
EPOCHS = 10 # epoch
|
||||||
#LR = 5 # learning rate
|
#LR = 5 # learning rate
|
||||||
LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
|
LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
|
||||||
BATCH_SIZE = 64 # batch size for training
|
BATCH_SIZE = 64 # batch size for training
|
||||||
|
|
||||||
def read_csv(input_csv, rows=None, verbose=0):
|
def read_csv(input_csv, rows=None, verbose=0):
|
||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
with open(input_csv, 'r', encoding="utf-8") as f:
|
with open(input_csv, 'r', encoding="utf-8") as f:
|
||||||
data = pd.concat(
|
data = pd.concat(
|
||||||
[chunk for chunk in tqdm(
|
[chunk for chunk in tqdm.tqdm(
|
||||||
pd.read_csv(f,
|
pd.read_csv(f,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
quoting=csv.QUOTE_ALL,
|
quoting=csv.QUOTE_ALL,
|
||||||
@ -345,7 +344,6 @@ def hot_one_to_labels(vocab, tensor):
|
|||||||
|
|
||||||
|
|
||||||
def train(dataloader, dataset, model, optimizer, criterion):
|
def train(dataloader, dataset, model, optimizer, criterion):
|
||||||
model.train()
|
|
||||||
total_acc, total_count = 0, 0
|
total_acc, total_count = 0, 0
|
||||||
log_interval = 500
|
log_interval = 500
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@ -377,6 +375,7 @@ def train(dataloader, dataset, model, optimizer, criterion):
|
|||||||
[pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
|
[pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
total_acc += (output == cats).sum().item()
|
total_acc += (output == cats).sum().item()
|
||||||
total_count += cats.size(0)
|
total_count += cats.size(0)
|
||||||
if idx % log_interval == 0 and idx > 0:
|
if idx % log_interval == 0 and idx > 0:
|
||||||
@ -556,8 +555,15 @@ def main():
|
|||||||
|
|
||||||
total_accu = None
|
total_accu = None
|
||||||
for epoch in range(1, EPOCHS + 1):
|
for epoch in range(1, EPOCHS + 1):
|
||||||
|
model.train()
|
||||||
epoch_start_time = time.time()
|
epoch_start_time = time.time()
|
||||||
train(train_dataloader, train_dataset, model, optimizer, criterion)
|
with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
|
||||||
|
bar.set_description(f"Epoch {epoch}")
|
||||||
|
train(train_dataloader, train_dataset, model, optimizer, criterion)
|
||||||
|
bar.set_postfix(
|
||||||
|
# loss=float(loss),
|
||||||
|
# acc=float(acc)
|
||||||
|
)
|
||||||
accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
|
accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
|
||||||
if total_accu is not None and total_accu > accu_val:
|
if total_accu is not None and total_accu > accu_val:
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
@ -581,3 +587,4 @@ def main():
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
# vim: set expandtab shiftwidth=2 softtabstop=2:
|
||||||
|
Loading…
Reference in New Issue
Block a user