Cleanup, and device-aware training
parent c9a9e24619
commit 61d32c5286

.gitignore (vendored)
@@ -1,2 +1,2 @@
-data/
+data/*
 __pycache__
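A note on the `.gitignore` change: `data/` ignores the directory itself, so git never descends into it and nothing inside can be tracked, while `data/*` ignores the directory's contents but still allows individual files to be re-included with a negation pattern (for example a hypothetical `!data/.gitkeep`); that extra flexibility is presumably the reason for the switch.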
@@ -1,4 +1,4 @@
-This is a multi-class, multi-label network that categorises text into one of ~160 categories, mostly relating to the African continent.
+This is a multi-class, multi-label NLP network that categorises text into one of ~160 categories, mostly relating to the African continent.
 
 The training dataset is a proprietary dataset from allAfrica.com, consisting of stories that have been manually categorised according to AllAfrica's in-house categorisation scheme.
 
@@ -19,9 +19,9 @@ import tqdm
 import torch
 import torchdata.datapipes as dp
 import torchtext.transforms as T
-from torchtext.vocab import build_vocab_from_iterator
-from torch.utils.data import Dataset, DataLoader
 from torch import nn
+from torch.utils.data import Dataset, DataLoader
+from torchtext.vocab import build_vocab_from_iterator
 
 from models.rnn import RNN
 
@@ -29,22 +29,11 @@ all_categories = list()
 # XXX None for all stories
 #story_num = 128
 #story_num = 256
-story_num = 512
+#story_num = 512
 #story_num = 1024
+story_num = 4096
 #story_num = None
 
-# Hyperparameters
-EPOCHS = 10 # epoch
-#EPOCHS = 2 # epoch
-#LR = 5 # learning rate
-#LR = 0.5
-LR = 0.05
-#LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
-#BATCH_SIZE = 64 # batch size for training
-#BATCH_SIZE = 16 # batch size for training
-BATCH_SIZE = 8 # batch size for training
-#BATCH_SIZE = 4 # batch size for training
-
 def read_csv(input_csv, rows=None, verbose=0):
     if verbose > 0:
     with open(input_csv, 'r', encoding="utf-8") as f:
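For reference, a minimal sketch of what a row-limited reader with the signature `read_csv(input_csv, rows=None, verbose=0)` can look like. The real implementation is not shown in this hunk, so the use of the standard csv module, the return type, and the column layout are assumptions.

import csv

def read_csv(input_csv, rows=None, verbose=0):
    """Read up to `rows` records from a CSV file; a sketch only."""
    data = []
    with open(input_csv, 'r', encoding="utf-8") as f:
        reader = csv.reader(f)
        for i, row in enumerate(reader):
            if rows is not None and i >= rows:
                break  # honour the story_num limit (e.g. story_num = 4096)
            data.append(row)
    if verbose > 0:
        print(f"Read {len(data)} rows from {input_csv}")
    return data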
@@ -349,9 +338,9 @@ def tensor2cat(vocab, tensor):
         for idx, pred in enumerate(tensor):
             if idx >= len(all_cats):
                 print(f"Idx {idx} not in {len(all_cats)} categories")
-            elif pred > 0: # XXX
+            #elif pred > 0: # XXX
                 #print(idx, len(all_cats))
                 chance[all_cats[idx]] = pred.item()
                 #print(chance)
         return chance
     else:
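The hunk above only disables the `elif pred > 0` guard, so every score is now recorded. For context, a hedged sketch of what a tensor2cat-style helper does: it maps each index of a prediction vector back to a category name via the vocabulary's index-to-string table. The `get_itos()` call matches the torchtext Vocab API, but the exact structure of the real function is an assumption.

def tensor2cat(vocab, tensor):
    """Map a 1-D prediction tensor to {category_name: score}; a sketch only."""
    all_cats = vocab.get_itos()  # index -> category name, in vocab order
    chance = {}
    for idx, pred in enumerate(tensor):
        if idx >= len(all_cats):
            print(f"Idx {idx} not in {len(all_cats)} categories")
        else:
            chance[all_cats[idx]] = pred.item()  # keep every score; thresholding happens elsewhere
    return chance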
@@ -383,15 +372,15 @@ def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
 
         optimizer.step()
 
-        #print("train loss",loss)
+        print("train loss", loss)
 
         ##predicted = np.round(output)
         ##total_acc += (predicted == cats).sum().item()
 
         predictions = torch.zeros(output.shape)
-        predictions[output >= 0.25] = True
-        #predictions[output >= 0.5] = True
-        #predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
+        #predictions[output >= 0.25] = True
+        predictions[output >= 0.5] = True
+        predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
 
         batch.clear()
         for target, out, pred in list(zip(cats, output, predictions)):
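One thing worth noting about the new 0.5 cut-off: `nn.BCEWithLogitsLoss` (used further down) expects raw logits, so if `output` has not been passed through a sigmoid, thresholding it at 0.5 keeps only classes whose probability is above roughly 0.62. A minimal sketch of the usual multi-label pattern, assuming `output` holds raw logits:

import torch

probs = torch.sigmoid(output)          # logits -> probabilities, matching what BCEWithLogitsLoss optimises
predictions = (probs >= 0.5).float()   # 1.0 where the class is predicted, 0.0 otherwise
# equivalently, thresholding the logits directly at 0.0 gives the same labels:
# predictions = (output >= 0.0).float()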
@@ -548,6 +537,28 @@ def main():
     )
     print(f"Using {device} device")
 
+    # Hyperparameters
+    #epochs = 10 # epoch
+    epochs = 4 # epoch
+    #lr = 5 # learning rate
+    #lr = 0.5
+    #lr = 0.05
+    #lr = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
+    lr = 0.0001
+    batch_size = 64 # batch size for training
+    #batch_size = 16 # batch size for training
+    #batch_size = 8 # batch size for training
+    #batch_size = 4 # batch size for training
+
+    #num_layers = 2 # 2-3 layers should be enough for LTSM
+    num_layers = 3 # 2-3 layers should be enough for LTSM
+    hidden_size = 128 # hidden size of rnn module, should be tweaked manually
+    #hidden_size = 8 # hidden size of rnn module, should be tweaked manually
+    mean_seq = True # use mean of rnn output
+    #mean_seq = False # use mean of rnn output
+    weight_decay = 1e-4 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
+    #weight_decay = 1e-3 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
+
     '''
     dataloader = DataLoader(dataset,
                             batch_size=4,
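Moving the hyperparameters from module scope into main() keeps them next to the code that uses them. One optional step further (a suggestion only, not what the commit does) is to group them in a single structure so the whole configuration can be printed or logged at once; the values below are the ones the commit selects.

from dataclasses import dataclass

@dataclass
class Hyperparams:
    epochs: int = 4
    lr: float = 0.0001
    batch_size: int = 64
    num_layers: int = 3         # 2-3 layers are usually enough for an LSTM
    hidden_size: int = 128
    mean_seq: bool = True       # use the mean of the RNN outputs over the sequence
    weight_decay: float = 1e-4  # L2-style regularisation; try 1e-3 or 1e-4

hp = Hyperparams()
print(hp)  # one line shows the whole training configuration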
@@ -558,14 +569,14 @@ def main():
     )
     '''
     train_dataloader = DataLoader(train_dataset,
-                                  batch_size=BATCH_SIZE,
+                                  batch_size=batch_size,
                                   drop_last=True,
                                   shuffle=True,
                                   num_workers=0,
                                   collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
                                   )
     valid_dataloader = DataLoader(valid_dataset,
-                                  batch_size=BATCH_SIZE,
+                                  batch_size=batch_size,
                                   drop_last=True,
                                   shuffle=True,
                                   num_workers=0,
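`CollateBatch` is defined elsewhere in the script, so its internals are not visible in this diff. The sketch below shows the usual shape of such a collate callable for variable-length text: pad every token-id sequence in the batch to the longest one using `pad_idx`, and stack the multi-hot category targets. The per-item layout (token_ids, multi_hot_cats) is an assumption.

import torch
from torch.nn.utils.rnn import pad_sequence

class CollateBatch:
    """Pad variable-length sequences and stack targets; a sketch only."""
    def __init__(self, cats, pad_idx):
        self.cats = cats        # category string -> index mapping
        self.pad_idx = pad_idx  # vocabulary index of '<pad>'

    def __call__(self, batch):
        texts, targets = zip(*batch)  # assumes each item is (token_ids, multi_hot_cats)
        texts = pad_sequence(list(texts), batch_first=True, padding_value=self.pad_idx)
        targets = torch.stack(list(targets))
        return texts, targets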
@@ -582,10 +593,6 @@ def main():
 
     embed = torch.empty(input_size, len(train_dataset)) # tokens per sample x samples
     embedding_size = embed.size(1) # was 64 (should be: samples)
-    num_layers = 2 # 2-3 layers should be enough for LTSM
-    hidden_size = 128 # hidden size of rnn module, should be tweaked manually
-    mean_seq = True # use mean of rnn output
-    weight_decay = 1e-4 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
 
     if args.verbose:
         #for i in train_dataset.text_vocab.get_itos():
@@ -611,22 +618,28 @@ def main():
         print(model)
 
     # optimizer and loss
-    #optimizer = torch.optim.SGD(model.parameters(), lr=LR)
     criterion = nn.BCEWithLogitsLoss()
-    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=weight_decay)
+    #optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
     if args.verbose:
         print(criterion)
         print(optimizer)
 
     total_accu = None
-    #for epoch in range(1, EPOCHS + 1):
-    e = tqdm.tqdm(range(1, EPOCHS + 1), unit="epoch")
+    #for epoch in range(1, epochs + 1):
+    e = tqdm.tqdm(range(1, epochs + 1), unit="epoch")
     for epoch in e:
         e.set_description(f"Epoch {epoch}")
 
+        train_dataset.to(device)
+        valid_dataset.to(device)
+        model.to(device)
+
         model.train()
         train(train_dataloader, train_dataset, model, optimizer, criterion, epoch)
 
         accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion, epoch)
 
         if total_accu is not None and total_accu > accu_val:
             optimizer.step()
         else:
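On the device handling: `model.to(device)` moves the parameters in place, but a plain `Dataset` has no `.to()` method unless the class defines one, and moving a whole dataset once per epoch only works if it fits in device memory. A hedged sketch of the more common pattern (pick the device once, move the model once, move each batch inside the loop), assuming the model accepts the padded batch directly and that the script's device string comes from the standard availability checks:

import torch

device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

model.to(device)  # once, before the epoch loop

for epoch in range(1, epochs + 1):
    model.train()
    for text, cats in train_dataloader:
        text, cats = text.to(device), cats.to(device)  # per-batch transfer
        optimizer.zero_grad()
        output = model(text)                 # assumes forward() takes the padded batch
        loss = criterion(output, cats)       # BCEWithLogitsLoss on raw logits
        loss.backward()
        optimizer.step()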