Cleanup, and device-aware training
parent c9a9e24619
commit 61d32c5286
.gitignore (vendored): 2 changes
@@ -1,2 +1,2 @@
-data/
+data/*
 __pycache__
@@ -1,4 +1,4 @@
-This is a multi-class, multi-label network that categorises text into one of ~160 categories, mostly relating to the African continent.
+This is a multi-class, multi-label NLP network that categorises text into one of ~160 categories, mostly relating to the African continent.
 
 The training dataset is a proprietary dataset from allAfrica.com, consisting of stories that have been manually categorised according to AllAfrica's in-house categorisation scheme.
 
@@ -19,9 +19,9 @@ import tqdm
 import torch
 import torchdata.datapipes as dp
 import torchtext.transforms as T
-from torchtext.vocab import build_vocab_from_iterator
-from torch.utils.data import Dataset, DataLoader
 from torch import nn
+from torch.utils.data import Dataset, DataLoader
+from torchtext.vocab import build_vocab_from_iterator
 
 from models.rnn import RNN
 
@@ -29,22 +29,11 @@ all_categories = list()
 # XXX None for all stories
 #story_num = 128
 #story_num = 256
-story_num = 512
+#story_num = 512
 #story_num = 1024
+story_num = 4096
 #story_num = None
-
-# Hyperparameters
-EPOCHS = 10 # epoch
-#EPOCHS = 2 # epoch
-#LR = 5 # learning rate
-#LR = 0.5
-LR = 0.05
-#LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
-#BATCH_SIZE = 64 # batch size for training
-#BATCH_SIZE = 16 # batch size for training
-BATCH_SIZE = 8 # batch size for training
-#BATCH_SIZE = 4 # batch size for training
 
 def read_csv(input_csv, rows=None, verbose=0):
     if verbose > 0:
         with open(input_csv, 'r', encoding="utf-8") as f:
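Note: read_csv is the script's own helper and its body is not shown in this hunk. The sketch below is a hypothetical minimal reader that matches the signature def read_csv(input_csv, rows=None, verbose=0), purely to illustrate how a row limit of this kind is typically honoured; the return format is an assumption, not taken from the patch.

    import csv

    # Hypothetical sketch only; not the project's actual read_csv.
    def read_csv(input_csv, rows=None, verbose=0):
        stories = []
        with open(input_csv, 'r', encoding="utf-8") as f:
            reader = csv.reader(f)
            for i, row in enumerate(reader):
                if rows is not None and i >= rows:
                    break  # honour the story limit (cf. story_num above)
                stories.append(row)
        if verbose > 0:
            print(f"Read {len(stories)} rows from {input_csv}")
        return stories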
@@ -349,7 +338,7 @@ def tensor2cat(vocab, tensor):
     for idx, pred in enumerate(tensor):
         if idx >= len(all_cats):
             print(f"Idx {idx} not in {len(all_cats)} categories")
-        elif pred > 0: # XXX
+        #elif pred > 0: # XXX
             #print(idx, len(all_cats))
             chance[all_cats[idx]] = pred.item()
     #print(chance)
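Note: for readers following the tensor2cat change, the sketch below is one hedged reading of what a function of this shape does, assuming all_cats is the category vocabulary's index-to-string list (torchtext's Vocab.get_itos()). The names and the greater-than-zero cut-off come from the hunk; everything else is an assumption, and the sketch keeps the elif guard that the patch comments out.

    # Illustrative sketch, not the script's implementation.
    def tensor2cat_sketch(vocab, tensor):
        all_cats = vocab.get_itos()  # torchtext Vocab: index -> category name
        chance = {}
        for idx, pred in enumerate(tensor):
            if idx >= len(all_cats):
                print(f"Idx {idx} not in {len(all_cats)} categories")
            elif pred > 0:
                chance[all_cats[idx]] = pred.item()
        return chance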
@@ -383,15 +372,15 @@ def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
 
         optimizer.step()
 
-        #print("train loss",loss)
+        print("train loss", loss)
 
         ##predicted = np.round(output)
        ##total_acc += (predicted == cats).sum().item()
 
         predictions = torch.zeros(output.shape)
-        predictions[output >= 0.25] = True
-        #predictions[output >= 0.5] = True
-        #predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
+        #predictions[output >= 0.25] = True
+        predictions[output >= 0.5] = True
+        predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
 
         batch.clear()
         for target, out, pred in list(zip(cats, output, predictions)):
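Note: the 0.5 cut-off above is applied to output directly. Since the loss configured later in this commit is nn.BCEWithLogitsLoss, output is presumably raw logits; the sketch below shows the common variant that passes logits through a sigmoid first, so the threshold is a true probability. This is an illustration of the technique, not the patch's code, and the function name is invented.

    import torch

    # Illustrative multi-label thresholding on logits (assumes BCEWithLogitsLoss).
    def logits_to_predictions(output: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        probs = torch.sigmoid(output)         # map logits to probabilities in [0, 1]
        return (probs >= threshold).float()   # 1.0 where a category is predicted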
@@ -548,6 +537,28 @@ def main():
     )
     print(f"Using {device} device")
+
+    # Hyperparameters
+    #epochs = 10 # epoch
+    epochs = 4 # epoch
+    #lr = 5 # learning rate
+    #lr = 0.5
+    #lr = 0.05
+    #lr = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
+    lr = 0.0001
+    batch_size = 64 # batch size for training
+    #batch_size = 16 # batch size for training
+    #batch_size = 8 # batch size for training
+    #batch_size = 4 # batch size for training
+
+    #num_layers = 2 # 2-3 layers should be enough for LSTM
+    num_layers = 3 # 2-3 layers should be enough for LSTM
+    hidden_size = 128 # hidden size of rnn module, should be tweaked manually
+    #hidden_size = 8 # hidden size of rnn module, should be tweaked manually
+    mean_seq = True # use mean of rnn output
+    #mean_seq = False # use mean of rnn output
+    weight_decay = 1e-4 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
+    #weight_decay = 1e-3 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
 
     '''
     dataloader = DataLoader(dataset,
                             batch_size=4,
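Note: the construction of device sits just above this hunk and is not part of the diff. The sketch below is a hypothetical selection consistent with the "Using {device} device" message and the commit's device-aware aim; the exact expression used by the script may differ.

    import torch

    # Hypothetical device pick; the script's actual expression is not shown in the hunk.
    device = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using {device} device")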
@@ -558,14 +569,14 @@ def main():
                             )
     '''
     train_dataloader = DataLoader(train_dataset,
-                                  batch_size=BATCH_SIZE,
+                                  batch_size=batch_size,
                                   drop_last=True,
                                   shuffle=True,
                                   num_workers=0,
                                   collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
                                   )
     valid_dataloader = DataLoader(valid_dataset,
-                                  batch_size=BATCH_SIZE,
+                                  batch_size=batch_size,
                                   drop_last=True,
                                   shuffle=True,
                                   num_workers=0,
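Note: CollateBatch is defined elsewhere in the script and is not shown here. As a rough guide to what a collate callable taking a pad_idx usually does, below is a hypothetical stand-in that pads variable-length token-id sequences and stacks multi-hot label vectors; the item layout (token_ids, labels) and the class name are assumptions.

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Hypothetical stand-in for CollateBatch, for illustration only.
    class PadCollate:
        def __init__(self, pad_idx: int):
            self.pad_idx = pad_idx

        def __call__(self, batch):
            texts, cats = zip(*batch)  # assumes items are (token_ids, multi_hot_labels)
            texts = pad_sequence(texts, batch_first=True, padding_value=self.pad_idx)
            return texts, torch.stack(cats)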
@@ -582,10 +593,6 @@ def main():
 
     embed = torch.empty(input_size, len(train_dataset)) # tokens per sample x samples
     embedding_size = embed.size(1) # was 64 (should be: samples)
-    num_layers = 2 # 2-3 layers should be enough for LSTM
-    hidden_size = 128 # hidden size of rnn module, should be tweaked manually
-    mean_seq = True # use mean of rnn output
-    weight_decay = 1e-4 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
 
     if args.verbose:
         #for i in train_dataset.text_vocab.get_itos():
@@ -611,22 +618,28 @@ def main():
         print(model)
 
     # optimizer and loss
-    #optimizer = torch.optim.SGD(model.parameters(), lr=LR)
     criterion = nn.BCEWithLogitsLoss()
-    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=weight_decay)
+    #optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
     if args.verbose:
         print(criterion)
         print(optimizer)
 
     total_accu = None
-    #for epoch in range(1, EPOCHS + 1):
-    e = tqdm.tqdm(range(1, EPOCHS + 1), unit="epoch")
+    #for epoch in range(1, epochs + 1):
+    e = tqdm.tqdm(range(1, epochs + 1), unit="epoch")
     for epoch in e:
         e.set_description(f"Epoch {epoch}")
+
+        train_dataset.to(device)
+        valid_dataset.to(device)
+        model.to(device)
+
         model.train()
         train(train_dataloader, train_dataset, model, optimizer, criterion, epoch)
 
         accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion, epoch)
+
         if total_accu is not None and total_accu > accu_val:
             optimizer.step()
         else:
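Note: the hunk above moves the datasets and the model to device each epoch before calling train(). The sketch below is a minimal, self-contained illustration of the underlying requirement, namely that the model and every tensor in a batch live on the same device before the forward pass; the tiny linear model and random data are placeholders, not the script's RNN or dataset.

    import torch
    from torch import nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Linear(10, 3).to(device)                      # placeholder for the RNN model
    criterion = nn.BCEWithLogitsLoss()

    x = torch.randn(8, 10, device=device)                    # placeholder batch of features
    y = torch.randint(0, 2, (8, 3), device=device).float()   # placeholder multi-hot labels
    loss = criterion(model(x), y)                            # both operands on the same device
    loss.backward()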