Fix evaluation, as well as progress reporting.

This commit is contained in:
Timothy Allen 2023-12-19 09:26:27 +02:00
parent 94025fc0c6
commit c9a9e24619

View File

@ -25,13 +25,25 @@ from torch import nn
from models.rnn import RNN from models.rnn import RNN
story_num = 64 # XXX None for all all_categories = list()
# XXX None for all stories
#story_num = 128
#story_num = 256
story_num = 512
#story_num = 1024
#story_num = None
# Hyperparameters # Hyperparameters
EPOCHS = 10 # epoch EPOCHS = 10 # epoch
#EPOCHS = 2 # epoch
#LR = 5 # learning rate #LR = 5 # learning rate
LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate #LR = 0.5
BATCH_SIZE = 64 # batch size for training LR = 0.05
#LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
#BATCH_SIZE = 64 # batch size for training
#BATCH_SIZE = 16 # batch size for training
BATCH_SIZE = 8 # batch size for training
#BATCH_SIZE = 4 # batch size for training
def read_csv(input_csv, rows=None, verbose=0): def read_csv(input_csv, rows=None, verbose=0):
if verbose > 0: if verbose > 0:
@ -136,7 +148,8 @@ class TextCategoriesDataset(Dataset):
#print(self.text_vocab.get_itos()) #print(self.text_vocab.get_itos())
self.cats_vocab = build_vocab_from_iterator( self.cats_vocab = build_vocab_from_iterator(
[self.catTokens(cats) for i, cats in self.df[cats_column].items()], #[self.catTokens(cats) for i, cats in self.df[cats_column].items()],
[self.catTokens(all_categories)],
min_freq=1, min_freq=1,
specials=['<unk>'], specials=['<unk>'],
special_first=True special_first=True
@ -163,6 +176,10 @@ class TextCategoriesDataset(Dataset):
if self.transform: if self.transform:
text, cats = self.transform(text, cats) text, cats = self.transform(text, cats)
#print(cats)
#print(self.catTokens(cats))
#print(self.getTransform(self.cats_vocab, "cats")(self.catTokens(cats)))
# Numericalise by applying transforms # Numericalise by applying transforms
return ( return (
self.getTransform(self.text_vocab, "text")(self.textTokens(text)), self.getTransform(self.text_vocab, "text")(self.textTokens(text)),
@ -217,11 +234,12 @@ class CollateBatch:
in a batch of equal length. We can do this a collate_fn callback class, in a batch of equal length. We can do this a collate_fn callback class,
which returns a tensor which returns a tensor
''' '''
def __init__(self, pad_idx): def __init__(self, pad_idx, cats):
''' '''
pad_idx (int): the index of the "<pad>" token in the vocabulary. pad_idx (int): the index of the "<pad>" token in the vocabulary.
''' '''
self.pad_idx = pad_idx self.pad_idx = pad_idx
self.cats = cats
def __call__(self, batch): def __call__(self, batch):
''' '''
@ -250,42 +268,30 @@ class CollateBatch:
#) #)
#cats_lengths = torch.LongTensor(list(map(len, batch_cats))) #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
'''
# Pad cats_tensor to all possible categories # Pad cats_tensor to all possible categories
# TODO will this be necessary with larger training sets, that should num_cats = len(all_categories)
# encompass all categories? Best to be safe...
all_cats = list(set(itertools.chain(*batch_cats)))
num_cats = len(all_cats)
# if there's no 0, there was no <unk>, so increment to allow for it to be a possible category
if 0 not in all_cats:
num_cats += 1
# Convert cats to tensor
#cats_tensor = nn.utils.rnn.pad_sequence(
# [torch.LongTensor(s) for s in batch_cats],
# batch_first=True, padding_value=self.pad_idx
#)
#cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
# Convert cats to tensor, alt version
#cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).long()
#cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
#for idx, (c, clen) in enumerate(zip(batch_cats, cats_lengths)):
# cats_tensor[idx, :clen] = torch.LongTensor(c)
#print([torch.LongTensor(s) for s in batch_cats])
#print(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]))
#cats_tensor = nn.functional.one_hot(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]), num_cats)
#cats_tensor = nn.functional.one_hot(torch.FloatTensor(batch_cats), num_cats)
# Convert cats to multi-label one-hot representation # Convert cats to multi-label one-hot representation
# This will be a target for CrossEntropyLoss(pred, target),
# which takes FloatTensor pred and LongTensor target
cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).float() cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).float()
cats_lengths = torch.LongTensor(list(map(len, batch_cats))) cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
for idx, cats in enumerate(batch_cats): for idx, cats in enumerate(batch_cats):
#print("\nsample", idx, cats) #print("\nsample", idx, cats)
for c in cats: for c in cats:
#print(c)
cats_tensor[idx][c] = 1 cats_tensor[idx][c] = 1
#print(cats_tensor[idx]) #print(cats_tensor[idx])
'''
# Convert cats to multi-label one-hot representation
# add one to all_categories to account for <unk>
cats_tensor = torch.full((len(batch_cats), len(all_categories)+1), self.pad_idx).float()
for idx, cats in enumerate(batch_cats):
#print("\nsample", idx, cats)
for c in cats:
cats_tensor[idx][c] = 1
#print(cats_tensor[idx])
#sys.exit(0)
''' '''
# XXX why?? # XXX why??
@ -310,10 +316,10 @@ class CollateBatch:
text_lengths, text_lengths,
) )
def labels_to_hot_one(label_vocab, labels, pad_idx: int): def cat2tensor(label_vocab, labels, pad_idx: int):
all_labels = vocab.get_itos() all_labels = vocab.get_itos()
num_labels = len(all_labels) num_labels = len(all_labels)
# if there's no 0, there was no <unk>, so increment to allow for it to be a possible category # add <unk>
if 0 not in all_labels: if 0 not in all_labels:
num_labels += 1 num_labels += 1
@ -326,40 +332,49 @@ def labels_to_hot_one(label_vocab, labels, pad_idx: int):
#print(labels_tensor[idx]) #print(labels_tensor[idx])
return labels_tensor return labels_tensor
def hot_one_to_labels(vocab, tensor): def tensor2cat(vocab, tensor):
all_cats = vocab.get_itos()
if tensor.ndimension() == 2: if tensor.ndimension() == 2:
all_labels = vocab.get_itos()
batch = list() batch = list()
for result in tensor: for result in tensor:
chance = dict() chance = dict()
for idx, pred in enumerate(result): for idx, pred in enumerate(result):
if pred > 0: # XXX if pred > 0: # XXX
chance[all_labels[idx]] = pred.item() chance[all_cats[idx]] = pred.item()
print(chance) #print(chance)
batch.append(chance) batch.append(chance)
return batch return batch
elif tensor.ndimension() == 1:
chance = dict()
for idx, pred in enumerate(tensor):
if idx >= len(all_cats):
print(f"Idx {idx} not in {len(all_cats)} categories")
elif pred > 0: # XXX
#print(idx, len(all_cats))
chance[all_cats[idx]] = pred.item()
#print(chance)
return chance
else: else:
raise ValueError("Only tensors with 2 dimensions are supported") raise ValueError("Only tensors with 2 dimensions are supported")
return vocab.get_itos(cat)
def train(dataloader, dataset, model, optimizer, criterion): def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
total_acc, total_count = 0, 0 total_acc, total_count = 0, 0
log_interval = 500 log_interval = 500
start_time = time.time()
torch.set_printoptions(precision=2) torch.set_printoptions(precision=2)
for idx, (text, cats, text_lengths) in enumerate(dataloader):
batch = tqdm.tqdm(dataloader, unit="batch")
for idx, data in enumerate(batch):
batch.set_description(f"Train {epoch}.{idx}")
text, cats, text_lengths = data
optimizer.zero_grad() optimizer.zero_grad()
print("text_lengths shape", text_lengths.shape)
print("input shape", text.shape)
print("target", cats)
print("target shape", cats.shape)
output = model(text, text_lengths) output = model(text, text_lengths)
print("output", output) #print("output", output)
print("output shape", output.shape) #print("output shape", output.shape)
loss = criterion(input=output, target=cats) loss = criterion(input=output, target=cats)
loss.backward() loss.backward()
@ -368,44 +383,84 @@ def train(dataloader, dataset, model, optimizer, criterion):
optimizer.step() optimizer.step()
print(loss) #print("train loss",loss)
print("expected", cats)
[pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, cats)]
print("predicted", output)
[pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
return ##predicted = np.round(output)
##total_acc += (predicted == cats).sum().item()
total_acc += (output == cats).sum().item() predictions = torch.zeros(output.shape)
predictions[output >= 0.25] = True
#predictions[output >= 0.5] = True
#predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
batch.clear()
for target, out, pred in list(zip(cats, output, predictions)):
expect = tensor2cat(dataset.cats_vocab, target)
raw = tensor2cat(dataset.cats_vocab, out)
predict = tensor2cat(dataset.cats_vocab, pred)
print("Expected: ", expect)
print("Predicted: ", predict)
print("Raw output:", raw)
print("\n")
batch.refresh()
N, C = cats.shape
#print("eq", (output == cats))
#print("sum", (output == cats).sum())
#print("accuracy", (output == cats).sum() / (N*C) * 100)
accuracy = (output == cats).sum() / (N*C) * 100
total_acc += accuracy
#print("train accuracy", accuracy)
#print("train total_acc", total_acc)
total_count += cats.size(0) total_count += cats.size(0)
if idx % log_interval == 0 and idx > 0: batch.set_postfix({
elapsed = time.time() - start_time "accuracy": int(total_acc / total_count),
print( })
"| epoch {:3d} | {:5d}/{:5d} batches "
"| accuracy {:8.3f}".format(
epoch, idx, len(dataloader), total_acc / total_count
)
)
total_acc, total_count = 0, 0 total_acc, total_count = 0, 0
start_time = time.time()
def evaluate(dataloader, dataset, model, criterion): def evaluate(dataloader, dataset, model, criterion, epoch=0):
model.eval() model.eval()
total_acc, total_count = 0, 0 total_acc, total_count = 0, 0
with torch.no_grad(): with torch.no_grad():
for idx, (text, cats, text_lengths) in enumerate(dataloader): batch = tqdm.tqdm(dataloader, unit="batch")
predicted_label = model(text, text_lengths) for idx, data in enumerate(batch):
print(predicted_label) batch.set_description(f"Evaluate {epoch}.{idx}")
loss = criterion(predicted_label, label) text, cats, text_lengths = data
print(loss)
print("expected labels:", label) output = model(text, text_lengths)
print([dataset.cats_vocab.get_itos(i) for i in label]) #print("eval predicted", output)
print("predicted labels:", predicted_label)
print([dataset.cats_vocab.get_itos(i) for i in predicted_label]) loss = criterion(output, cats)
total_acc += (predicted_label.argmax(1) == label).sum().item() #print("eval loss", loss)
total_count += label.size(0)
predictions = torch.zeros(output.shape)
predictions[output >= 0.5] = True
predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
batch.clear()
for target, out, pred in list(zip(cats, output, predictions)):
expect = tensor2cat(dataset.cats_vocab, target)
raw = tensor2cat(dataset.cats_vocab, out)
predict = tensor2cat(dataset.cats_vocab, pred)
print("Evaluate expected: ", expect)
print("Evaluate predicted: ", predict)
print("Evaluate raw output:", raw)
print("\n")
batch.refresh()
##total_acc += (predicted_cats.argmax(1) == cats).sum().item()
N, C = cats.shape
accuracy = (predictions == cats).sum() / (N*C) * 100
total_acc += accuracy
#print("eval accuracy", accuracy)
#print("eval total_acc", total_acc)
total_count += cats.size(0)
batch.set_postfix({
"accuracy": int(total_acc / total_count),
})
return total_acc / total_count return total_acc / total_count
@ -442,6 +497,18 @@ def main():
sys.exit(1) sys.exit(1)
data = read_csv(input_csv=args.input, rows=story_num, verbose=args.verbose) data = read_csv(input_csv=args.input, rows=story_num, verbose=args.verbose)
# create list of all categories
global all_categories
for cats in data.categories:
for c in cats.split(";"):
if c not in all_categories:
all_categories.append(c)
all_categories = sorted(all_categories)
#print(all_categories)
#print(len(all_categories))
#sys.exit(0)
train_data, valid_data, = split_dataset(data, verbose=args.verbose) train_data, valid_data, = split_dataset(data, verbose=args.verbose)
''' '''
@ -466,6 +533,7 @@ def main():
) )
#for text, cat in enumerate(train_dataset): #for text, cat in enumerate(train_dataset):
# print(text, cat) # print(text, cat)
#print("-" * 20)
#for text, cat in enumerate(valid_dataset): #for text, cat in enumerate(valid_dataset):
# print(text, cat) # print(text, cat)
#sys.exit(0) #sys.exit(0)
@ -483,33 +551,30 @@ def main():
''' '''
dataloader = DataLoader(dataset, dataloader = DataLoader(dataset,
batch_size=4, batch_size=4,
drop_last=True,
shuffle=True, shuffle=True,
num_workers=0, num_workers=0,
collate_fn=CollateBatch(pad_idx=dataset.stoi['<pad>']), collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
) )
''' '''
train_dataloader = DataLoader(train_dataset, train_dataloader = DataLoader(train_dataset,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
drop_last=True,
shuffle=True, shuffle=True,
num_workers=0, num_workers=0,
collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']), collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
) )
valid_dataloader = DataLoader(valid_dataset, valid_dataloader = DataLoader(valid_dataset,
batch_size=BATCH_SIZE, batch_size=BATCH_SIZE,
drop_last=True,
shuffle=True, shuffle=True,
num_workers=0, num_workers=0,
collate_fn=CollateBatch(pad_idx=valid_dataset.stoi['<pad>']), collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
) )
#for i_batch, sample_batched in enumerate(dataloader): #for i_batch, sample_batched in enumerate(dataloader):
# print(i_batch, sample_batched[0], sample_batched[1]) # print(i_batch, sample_batched[0], sample_batched[1])
#for i_batch, sample_batched in enumerate(train_dataloader): #for i_batch, sample_batched in enumerate(train_dataloader):
#print(i_batch, sample_batched[0], sample_batched[1]) # print(i_batch, sample_batched[0], sample_batched[1])
#print(i_batch)
#print("batch elements:")
#for i in sample_batched:
# print(i)
# print(i.shape)
# print("\n")
#sys.exit(0) #sys.exit(0)
input_size = len(train_dataset.text_vocab) input_size = len(train_dataset.text_vocab)
@ -547,40 +612,32 @@ def main():
# optimizer and loss # optimizer and loss
#optimizer = torch.optim.SGD(model.parameters(), lr=LR) #optimizer = torch.optim.SGD(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss() criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=weight_decay) optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=weight_decay)
if args.verbose: if args.verbose:
print(criterion) print(criterion)
print(optimizer) print(optimizer)
total_accu = None total_accu = None
for epoch in range(1, EPOCHS + 1): #for epoch in range(1, EPOCHS + 1):
e = tqdm.tqdm(range(1, EPOCHS + 1), unit="epoch")
for epoch in e:
e.set_description(f"Epoch {epoch}")
model.train() model.train()
epoch_start_time = time.time() train(train_dataloader, train_dataset, model, optimizer, criterion, epoch)
with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
bar.set_description(f"Epoch {epoch}") accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion, epoch)
train(train_dataloader, train_dataset, model, optimizer, criterion)
bar.set_postfix(
# loss=float(loss),
# acc=float(acc)
)
accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
if total_accu is not None and total_accu > accu_val: if total_accu is not None and total_accu > accu_val:
scheduler.step() optimizer.step()
else: else:
total_accu = accu_val total_accu = accu_val
print("-" * 59) e.set_postfix({
print( "accuracy": accu_val.int().item(),
"| end of epoch {:3d} | time: {:5.2f}s | " })
"valid accuracy {:8.3f} ".format(
epoch, time.time() - epoch_start_time, accu_val
)
)
print("-" * 59)
print("Checking the results of test dataset.") # print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader, test_dataset) # accu_test = evaluate(test_dataloader, test_dataset)
print("test accuracy {:8.3f}".format(accu_test)) # print("test accuracy {:8.3f}".format(accu_test))
return return