Fix evaluation, as well as progress reporting.
This commit is contained in:
parent
94025fc0c6
commit
c9a9e24619
@ -25,13 +25,25 @@ from torch import nn
|
||||
|
||||
from models.rnn import RNN
|
||||
|
||||
story_num = 64 # XXX None for all
|
||||
all_categories = list()
|
||||
# XXX None for all stories
|
||||
#story_num = 128
|
||||
#story_num = 256
|
||||
story_num = 512
|
||||
#story_num = 1024
|
||||
#story_num = None
|
||||
|
||||
# Hyperparameters
|
||||
EPOCHS = 10 # epoch
|
||||
#EPOCHS = 2 # epoch
|
||||
#LR = 5 # learning rate
|
||||
LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
|
||||
BATCH_SIZE = 64 # batch size for training
|
||||
#LR = 0.5
|
||||
LR = 0.05
|
||||
#LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
|
||||
#BATCH_SIZE = 64 # batch size for training
|
||||
#BATCH_SIZE = 16 # batch size for training
|
||||
BATCH_SIZE = 8 # batch size for training
|
||||
#BATCH_SIZE = 4 # batch size for training
|
||||
|
||||
def read_csv(input_csv, rows=None, verbose=0):
|
||||
if verbose > 0:
|
||||
@ -136,7 +148,8 @@ class TextCategoriesDataset(Dataset):
|
||||
#print(self.text_vocab.get_itos())
|
||||
|
||||
self.cats_vocab = build_vocab_from_iterator(
|
||||
[self.catTokens(cats) for i, cats in self.df[cats_column].items()],
|
||||
#[self.catTokens(cats) for i, cats in self.df[cats_column].items()],
|
||||
[self.catTokens(all_categories)],
|
||||
min_freq=1,
|
||||
specials=['<unk>'],
|
||||
special_first=True
|
||||
@ -163,6 +176,10 @@ class TextCategoriesDataset(Dataset):
|
||||
if self.transform:
|
||||
text, cats = self.transform(text, cats)
|
||||
|
||||
#print(cats)
|
||||
#print(self.catTokens(cats))
|
||||
#print(self.getTransform(self.cats_vocab, "cats")(self.catTokens(cats)))
|
||||
|
||||
# Numericalise by applying transforms
|
||||
return (
|
||||
self.getTransform(self.text_vocab, "text")(self.textTokens(text)),
|
||||
@ -217,11 +234,12 @@ class CollateBatch:
|
||||
in a batch of equal length. We can do this a collate_fn callback class,
|
||||
which returns a tensor
|
||||
'''
|
||||
def __init__(self, pad_idx):
|
||||
def __init__(self, pad_idx, cats):
|
||||
'''
|
||||
pad_idx (int): the index of the "<pad>" token in the vocabulary.
|
||||
'''
|
||||
self.pad_idx = pad_idx
|
||||
self.cats = cats
|
||||
|
||||
def __call__(self, batch):
|
||||
'''
|
||||
@ -250,42 +268,30 @@ class CollateBatch:
|
||||
#)
|
||||
#cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
|
||||
|
||||
'''
|
||||
# Pad cats_tensor to all possible categories
|
||||
# TODO will this be necessary with larger training sets, that should
|
||||
# encompass all categories? Best to be safe...
|
||||
all_cats = list(set(itertools.chain(*batch_cats)))
|
||||
num_cats = len(all_cats)
|
||||
# if there's no 0, there was no <unk>, so increment to allow for it to be a possible category
|
||||
if 0 not in all_cats:
|
||||
num_cats += 1
|
||||
|
||||
# Convert cats to tensor
|
||||
#cats_tensor = nn.utils.rnn.pad_sequence(
|
||||
# [torch.LongTensor(s) for s in batch_cats],
|
||||
# batch_first=True, padding_value=self.pad_idx
|
||||
#)
|
||||
#cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
|
||||
|
||||
# Convert cats to tensor, alt version
|
||||
#cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).long()
|
||||
#cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
|
||||
#for idx, (c, clen) in enumerate(zip(batch_cats, cats_lengths)):
|
||||
# cats_tensor[idx, :clen] = torch.LongTensor(c)
|
||||
#print([torch.LongTensor(s) for s in batch_cats])
|
||||
#print(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]))
|
||||
#cats_tensor = nn.functional.one_hot(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]), num_cats)
|
||||
#cats_tensor = nn.functional.one_hot(torch.FloatTensor(batch_cats), num_cats)
|
||||
num_cats = len(all_categories)
|
||||
|
||||
# Convert cats to multi-label one-hot representation
|
||||
# This will be a target for CrossEntropyLoss(pred, target),
|
||||
# which takes FloatTensor pred and LongTensor target
|
||||
cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).float()
|
||||
cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
|
||||
for idx, cats in enumerate(batch_cats):
|
||||
#print("\nsample", idx, cats)
|
||||
for c in cats:
|
||||
#print(c)
|
||||
cats_tensor[idx][c] = 1
|
||||
#print(cats_tensor[idx])
|
||||
'''
|
||||
# Convert cats to multi-label one-hot representation
|
||||
# add one to all_categories to account for <unk>
|
||||
cats_tensor = torch.full((len(batch_cats), len(all_categories)+1), self.pad_idx).float()
|
||||
for idx, cats in enumerate(batch_cats):
|
||||
#print("\nsample", idx, cats)
|
||||
for c in cats:
|
||||
cats_tensor[idx][c] = 1
|
||||
#print(cats_tensor[idx])
|
||||
#sys.exit(0)
|
||||
|
||||
|
||||
'''
|
||||
# XXX why??
|
||||
@ -310,10 +316,10 @@ class CollateBatch:
|
||||
text_lengths,
|
||||
)
|
||||
|
||||
def labels_to_hot_one(label_vocab, labels, pad_idx: int):
|
||||
def cat2tensor(label_vocab, labels, pad_idx: int):
|
||||
all_labels = vocab.get_itos()
|
||||
num_labels = len(all_labels)
|
||||
# if there's no 0, there was no <unk>, so increment to allow for it to be a possible category
|
||||
# add <unk>
|
||||
if 0 not in all_labels:
|
||||
num_labels += 1
|
||||
|
||||
@ -326,40 +332,49 @@ def labels_to_hot_one(label_vocab, labels, pad_idx: int):
|
||||
#print(labels_tensor[idx])
|
||||
return labels_tensor
|
||||
|
||||
def hot_one_to_labels(vocab, tensor):
|
||||
def tensor2cat(vocab, tensor):
|
||||
all_cats = vocab.get_itos()
|
||||
if tensor.ndimension() == 2:
|
||||
all_labels = vocab.get_itos()
|
||||
batch = list()
|
||||
for result in tensor:
|
||||
chance = dict()
|
||||
for idx, pred in enumerate(result):
|
||||
if pred > 0: # XXX
|
||||
chance[all_labels[idx]] = pred.item()
|
||||
print(chance)
|
||||
chance[all_cats[idx]] = pred.item()
|
||||
#print(chance)
|
||||
batch.append(chance)
|
||||
return batch
|
||||
elif tensor.ndimension() == 1:
|
||||
chance = dict()
|
||||
for idx, pred in enumerate(tensor):
|
||||
if idx >= len(all_cats):
|
||||
print(f"Idx {idx} not in {len(all_cats)} categories")
|
||||
elif pred > 0: # XXX
|
||||
#print(idx, len(all_cats))
|
||||
chance[all_cats[idx]] = pred.item()
|
||||
#print(chance)
|
||||
return chance
|
||||
else:
|
||||
raise ValueError("Only tensors with 2 dimensions are supported")
|
||||
|
||||
return vocab.get_itos(cat)
|
||||
|
||||
|
||||
def train(dataloader, dataset, model, optimizer, criterion):
|
||||
def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
|
||||
total_acc, total_count = 0, 0
|
||||
log_interval = 500
|
||||
start_time = time.time()
|
||||
|
||||
torch.set_printoptions(precision=2)
|
||||
for idx, (text, cats, text_lengths) in enumerate(dataloader):
|
||||
|
||||
batch = tqdm.tqdm(dataloader, unit="batch")
|
||||
for idx, data in enumerate(batch):
|
||||
batch.set_description(f"Train {epoch}.{idx}")
|
||||
text, cats, text_lengths = data
|
||||
optimizer.zero_grad()
|
||||
|
||||
print("text_lengths shape", text_lengths.shape)
|
||||
print("input shape", text.shape)
|
||||
print("target", cats)
|
||||
print("target shape", cats.shape)
|
||||
|
||||
output = model(text, text_lengths)
|
||||
print("output", output)
|
||||
print("output shape", output.shape)
|
||||
#print("output", output)
|
||||
#print("output shape", output.shape)
|
||||
|
||||
loss = criterion(input=output, target=cats)
|
||||
loss.backward()
|
||||
@ -368,44 +383,84 @@ def train(dataloader, dataset, model, optimizer, criterion):
|
||||
|
||||
optimizer.step()
|
||||
|
||||
print(loss)
|
||||
print("expected", cats)
|
||||
[pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, cats)]
|
||||
print("predicted", output)
|
||||
[pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
|
||||
#print("train loss",loss)
|
||||
|
||||
return
|
||||
##predicted = np.round(output)
|
||||
##total_acc += (predicted == cats).sum().item()
|
||||
|
||||
total_acc += (output == cats).sum().item()
|
||||
predictions = torch.zeros(output.shape)
|
||||
predictions[output >= 0.25] = True
|
||||
#predictions[output >= 0.5] = True
|
||||
#predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
|
||||
|
||||
batch.clear()
|
||||
for target, out, pred in list(zip(cats, output, predictions)):
|
||||
expect = tensor2cat(dataset.cats_vocab, target)
|
||||
raw = tensor2cat(dataset.cats_vocab, out)
|
||||
predict = tensor2cat(dataset.cats_vocab, pred)
|
||||
print("Expected: ", expect)
|
||||
print("Predicted: ", predict)
|
||||
print("Raw output:", raw)
|
||||
print("\n")
|
||||
batch.refresh()
|
||||
|
||||
N, C = cats.shape
|
||||
#print("eq", (output == cats))
|
||||
#print("sum", (output == cats).sum())
|
||||
#print("accuracy", (output == cats).sum() / (N*C) * 100)
|
||||
accuracy = (output == cats).sum() / (N*C) * 100
|
||||
total_acc += accuracy
|
||||
#print("train accuracy", accuracy)
|
||||
#print("train total_acc", total_acc)
|
||||
total_count += cats.size(0)
|
||||
if idx % log_interval == 0 and idx > 0:
|
||||
elapsed = time.time() - start_time
|
||||
print(
|
||||
"| epoch {:3d} | {:5d}/{:5d} batches "
|
||||
"| accuracy {:8.3f}".format(
|
||||
epoch, idx, len(dataloader), total_acc / total_count
|
||||
)
|
||||
)
|
||||
batch.set_postfix({
|
||||
"accuracy": int(total_acc / total_count),
|
||||
})
|
||||
total_acc, total_count = 0, 0
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
def evaluate(dataloader, dataset, model, criterion):
|
||||
def evaluate(dataloader, dataset, model, criterion, epoch=0):
|
||||
model.eval()
|
||||
total_acc, total_count = 0, 0
|
||||
|
||||
with torch.no_grad():
|
||||
for idx, (text, cats, text_lengths) in enumerate(dataloader):
|
||||
predicted_label = model(text, text_lengths)
|
||||
print(predicted_label)
|
||||
loss = criterion(predicted_label, label)
|
||||
print(loss)
|
||||
print("expected labels:", label)
|
||||
print([dataset.cats_vocab.get_itos(i) for i in label])
|
||||
print("predicted labels:", predicted_label)
|
||||
print([dataset.cats_vocab.get_itos(i) for i in predicted_label])
|
||||
total_acc += (predicted_label.argmax(1) == label).sum().item()
|
||||
total_count += label.size(0)
|
||||
batch = tqdm.tqdm(dataloader, unit="batch")
|
||||
for idx, data in enumerate(batch):
|
||||
batch.set_description(f"Evaluate {epoch}.{idx}")
|
||||
text, cats, text_lengths = data
|
||||
|
||||
output = model(text, text_lengths)
|
||||
#print("eval predicted", output)
|
||||
|
||||
loss = criterion(output, cats)
|
||||
#print("eval loss", loss)
|
||||
|
||||
predictions = torch.zeros(output.shape)
|
||||
predictions[output >= 0.5] = True
|
||||
predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
|
||||
|
||||
batch.clear()
|
||||
for target, out, pred in list(zip(cats, output, predictions)):
|
||||
expect = tensor2cat(dataset.cats_vocab, target)
|
||||
raw = tensor2cat(dataset.cats_vocab, out)
|
||||
predict = tensor2cat(dataset.cats_vocab, pred)
|
||||
print("Evaluate expected: ", expect)
|
||||
print("Evaluate predicted: ", predict)
|
||||
print("Evaluate raw output:", raw)
|
||||
print("\n")
|
||||
batch.refresh()
|
||||
|
||||
##total_acc += (predicted_cats.argmax(1) == cats).sum().item()
|
||||
N, C = cats.shape
|
||||
accuracy = (predictions == cats).sum() / (N*C) * 100
|
||||
total_acc += accuracy
|
||||
#print("eval accuracy", accuracy)
|
||||
#print("eval total_acc", total_acc)
|
||||
total_count += cats.size(0)
|
||||
|
||||
batch.set_postfix({
|
||||
"accuracy": int(total_acc / total_count),
|
||||
})
|
||||
return total_acc / total_count
|
||||
|
||||
|
||||
@ -442,6 +497,18 @@ def main():
|
||||
sys.exit(1)
|
||||
|
||||
data = read_csv(input_csv=args.input, rows=story_num, verbose=args.verbose)
|
||||
|
||||
# create list of all categories
|
||||
global all_categories
|
||||
for cats in data.categories:
|
||||
for c in cats.split(";"):
|
||||
if c not in all_categories:
|
||||
all_categories.append(c)
|
||||
all_categories = sorted(all_categories)
|
||||
#print(all_categories)
|
||||
#print(len(all_categories))
|
||||
#sys.exit(0)
|
||||
|
||||
train_data, valid_data, = split_dataset(data, verbose=args.verbose)
|
||||
|
||||
'''
|
||||
@ -466,6 +533,7 @@ def main():
|
||||
)
|
||||
#for text, cat in enumerate(train_dataset):
|
||||
# print(text, cat)
|
||||
#print("-" * 20)
|
||||
#for text, cat in enumerate(valid_dataset):
|
||||
# print(text, cat)
|
||||
#sys.exit(0)
|
||||
@ -483,33 +551,30 @@ def main():
|
||||
'''
|
||||
dataloader = DataLoader(dataset,
|
||||
batch_size=4,
|
||||
drop_last=True,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
collate_fn=CollateBatch(pad_idx=dataset.stoi['<pad>']),
|
||||
collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
|
||||
)
|
||||
'''
|
||||
train_dataloader = DataLoader(train_dataset,
|
||||
batch_size=BATCH_SIZE,
|
||||
drop_last=True,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
|
||||
collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
|
||||
)
|
||||
valid_dataloader = DataLoader(valid_dataset,
|
||||
batch_size=BATCH_SIZE,
|
||||
drop_last=True,
|
||||
shuffle=True,
|
||||
num_workers=0,
|
||||
collate_fn=CollateBatch(pad_idx=valid_dataset.stoi['<pad>']),
|
||||
collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
|
||||
)
|
||||
#for i_batch, sample_batched in enumerate(dataloader):
|
||||
# print(i_batch, sample_batched[0], sample_batched[1])
|
||||
#for i_batch, sample_batched in enumerate(train_dataloader):
|
||||
# print(i_batch, sample_batched[0], sample_batched[1])
|
||||
#print(i_batch)
|
||||
#print("batch elements:")
|
||||
#for i in sample_batched:
|
||||
# print(i)
|
||||
# print(i.shape)
|
||||
# print("\n")
|
||||
#sys.exit(0)
|
||||
|
||||
input_size = len(train_dataset.text_vocab)
|
||||
@ -547,40 +612,32 @@ def main():
|
||||
|
||||
# optimizer and loss
|
||||
#optimizer = torch.optim.SGD(model.parameters(), lr=LR)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=weight_decay)
|
||||
if args.verbose:
|
||||
print(criterion)
|
||||
print(optimizer)
|
||||
|
||||
total_accu = None
|
||||
for epoch in range(1, EPOCHS + 1):
|
||||
#for epoch in range(1, EPOCHS + 1):
|
||||
e = tqdm.tqdm(range(1, EPOCHS + 1), unit="epoch")
|
||||
for epoch in e:
|
||||
e.set_description(f"Epoch {epoch}")
|
||||
model.train()
|
||||
epoch_start_time = time.time()
|
||||
with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
|
||||
bar.set_description(f"Epoch {epoch}")
|
||||
train(train_dataloader, train_dataset, model, optimizer, criterion)
|
||||
bar.set_postfix(
|
||||
# loss=float(loss),
|
||||
# acc=float(acc)
|
||||
)
|
||||
accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
|
||||
train(train_dataloader, train_dataset, model, optimizer, criterion, epoch)
|
||||
|
||||
accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion, epoch)
|
||||
if total_accu is not None and total_accu > accu_val:
|
||||
scheduler.step()
|
||||
optimizer.step()
|
||||
else:
|
||||
total_accu = accu_val
|
||||
print("-" * 59)
|
||||
print(
|
||||
"| end of epoch {:3d} | time: {:5.2f}s | "
|
||||
"valid accuracy {:8.3f} ".format(
|
||||
epoch, time.time() - epoch_start_time, accu_val
|
||||
)
|
||||
)
|
||||
print("-" * 59)
|
||||
e.set_postfix({
|
||||
"accuracy": accu_val.int().item(),
|
||||
})
|
||||
|
||||
print("Checking the results of test dataset.")
|
||||
accu_test = evaluate(test_dataloader, test_dataset)
|
||||
print("test accuracy {:8.3f}".format(accu_test))
|
||||
# print("Checking the results of test dataset.")
|
||||
# accu_test = evaluate(test_dataloader, test_dataset)
|
||||
# print("test accuracy {:8.3f}".format(accu_test))
|
||||
|
||||
return
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user