Fix evaluation, as well as progress reporting.

parent 94025fc0c6, commit c9a9e24619
@@ -25,13 +25,25 @@ from torch import nn
 
 from models.rnn import RNN
 
-story_num = 64 # XXX None for all
+all_categories = list()
+# XXX None for all stories
+#story_num = 128
+#story_num = 256
+story_num = 512
+#story_num = 1024
+#story_num = None
 
 # Hyperparameters
 EPOCHS = 10 # epoch
+#EPOCHS = 2 # epoch
 #LR = 5 # learning rate
-LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
-BATCH_SIZE = 64 # batch size for training
+#LR = 0.5
+LR = 0.05
+#LR = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
+#BATCH_SIZE = 64 # batch size for training
+#BATCH_SIZE = 16 # batch size for training
+BATCH_SIZE = 8 # batch size for training
+#BATCH_SIZE = 4 # batch size for training
 
 def read_csv(input_csv, rows=None, verbose=0):
     if verbose > 0:
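Note on the hunk above: the commit follows the learning-rate comment's own advice, trying several values (0.5, 0.05, 0.005) and keeping LR = 0.05. Such a coarse sweep can be scripted instead of hand-edited; the sketch below is illustrative only, and build_model/run_epoch are hypothetical stand-ins for this repo's model construction and one training pass:

    import torch

    def lr_sweep(build_model, run_epoch, rates=(0.5, 0.05, 0.005)):
        # Train one epoch per candidate rate and keep the best-scoring one.
        losses = {}
        for lr in rates:
            model = build_model()                                   # hypothetical helper
            optimizer = torch.optim.Adam(model.parameters(), lr=lr)
            losses[lr] = run_epoch(model, optimizer)                # returns mean loss
        return min(losses, key=losses.get)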
@@ -136,7 +148,8 @@ class TextCategoriesDataset(Dataset):
         #print(self.text_vocab.get_itos())
 
         self.cats_vocab = build_vocab_from_iterator(
-            [self.catTokens(cats) for i, cats in self.df[cats_column].items()],
+            #[self.catTokens(cats) for i, cats in self.df[cats_column].items()],
+            [self.catTokens(all_categories)],
             min_freq=1,
             specials=['<unk>'],
             special_first=True
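Note on the hunk above: building cats_vocab from the global all_categories list, rather than from whichever categories happen to occur in a given split, keeps category indices identical across the training and validation datasets. A minimal self-contained illustration of the same torchtext call, with example category names standing in for the real ones:

    from torchtext.vocab import build_vocab_from_iterator

    all_categories = ["business", "politics", "sport"]   # example values only
    cats_vocab = build_vocab_from_iterator(
        [all_categories],            # a single token list covering every category
        min_freq=1,
        specials=['<unk>'],
        special_first=True,
    )
    print(cats_vocab.get_stoi())     # e.g. {'<unk>': 0, 'business': 1, 'politics': 2, 'sport': 3}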
@@ -163,6 +176,10 @@ class TextCategoriesDataset(Dataset):
         if self.transform:
             text, cats = self.transform(text, cats)
 
+        #print(cats)
+        #print(self.catTokens(cats))
+        #print(self.getTransform(self.cats_vocab, "cats")(self.catTokens(cats)))
+
         # Numericalise by applying transforms
         return (
             self.getTransform(self.text_vocab, "text")(self.textTokens(text)),
@@ -217,11 +234,12 @@ class CollateBatch:
     in a batch of equal length. We can do this a collate_fn callback class,
     which returns a tensor
     '''
-    def __init__(self, pad_idx):
+    def __init__(self, pad_idx, cats):
         '''
         pad_idx (int): the index of the "<pad>" token in the vocabulary.
         '''
         self.pad_idx = pad_idx
+        self.cats = cats
 
     def __call__(self, batch):
         '''
@@ -250,42 +268,30 @@ class CollateBatch:
         #)
         #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
 
+        '''
         # Pad cats_tensor to all possible categories
-        # TODO will this be necessary with larger training sets, that should
-        # encompass all categories? Best to be safe...
-        all_cats = list(set(itertools.chain(*batch_cats)))
-        num_cats = len(all_cats)
-        # if there's no 0, there was no <unk>, so increment to allow for it to be a possible category
-        if 0 not in all_cats:
-            num_cats += 1
-
-        # Convert cats to tensor
-        #cats_tensor = nn.utils.rnn.pad_sequence(
-        #    [torch.LongTensor(s) for s in batch_cats],
-        #    batch_first=True, padding_value=self.pad_idx
-        #)
-        #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
-
-        # Convert cats to tensor, alt version
-        #cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).long()
-        #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
-        #for idx, (c, clen) in enumerate(zip(batch_cats, cats_lengths)):
-        #    cats_tensor[idx, :clen] = torch.LongTensor(c)
-        #print([torch.LongTensor(s) for s in batch_cats])
-        #print(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]))
-        #cats_tensor = nn.functional.one_hot(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]), num_cats)
-        #cats_tensor = nn.functional.one_hot(torch.FloatTensor(batch_cats), num_cats)
+        num_cats = len(all_categories)
 
         # Convert cats to multi-label one-hot representation
-        # This will be a target for CrossEntropyLoss(pred, target),
-        # which takes FloatTensor pred and LongTensor target
         cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).float()
         cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
         for idx, cats in enumerate(batch_cats):
             #print("\nsample", idx, cats)
             for c in cats:
+                #print(c)
                 cats_tensor[idx][c] = 1
                 #print(cats_tensor[idx])
+        '''
+        # Convert cats to multi-label one-hot representation
+        # add one to all_categories to account for <unk>
+        cats_tensor = torch.full((len(batch_cats), len(all_categories)+1), self.pad_idx).float()
+        for idx, cats in enumerate(batch_cats):
+            #print("\nsample", idx, cats)
+            for c in cats:
+                cats_tensor[idx][c] = 1
+                #print(cats_tensor[idx])
+        #sys.exit(0)
 
         '''
         # XXX why??
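Note on the hunk above: the replacement target builder allocates len(all_categories)+1 columns so index 0 stays free for '<unk>', then sets one column per category index; BCEWithLogitsLoss (see the final hunk) consumes exactly this float 0/1 matrix. A standalone sketch with invented sample data, using zeros where the commit uses the pad index:

    import torch

    all_categories = ["a", "b", "c"]      # example values only
    batch_cats = [[1, 3], [2]]            # per-sample category indices; 0 is <unk>

    cats_tensor = torch.zeros(len(batch_cats), len(all_categories) + 1)
    for idx, cats in enumerate(batch_cats):
        cats_tensor[idx, torch.tensor(cats)] = 1.0    # vectorised row update
    print(cats_tensor)
    # tensor([[0., 1., 0., 1.],
    #         [0., 0., 1., 0.]])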
@@ -310,10 +316,10 @@ class CollateBatch:
             text_lengths,
         )
 
-def labels_to_hot_one(label_vocab, labels, pad_idx: int):
+def cat2tensor(label_vocab, labels, pad_idx: int):
     all_labels = vocab.get_itos()
     num_labels = len(all_labels)
-    # if there's no 0, there was no <unk>, so increment to allow for it to be a possible category
+    # add <unk>
    if 0 not in all_labels:
        num_labels += 1
 
@@ -326,40 +332,49 @@ def labels_to_hot_one(label_vocab, labels, pad_idx: int):
         #print(labels_tensor[idx])
     return labels_tensor
 
-def hot_one_to_labels(vocab, tensor):
+def tensor2cat(vocab, tensor):
+    all_cats = vocab.get_itos()
     if tensor.ndimension() == 2:
-        all_labels = vocab.get_itos()
         batch = list()
         for result in tensor:
             chance = dict()
             for idx, pred in enumerate(result):
                 if pred > 0: # XXX
-                    chance[all_labels[idx]] = pred.item()
-            print(chance)
+                    chance[all_cats[idx]] = pred.item()
+            #print(chance)
             batch.append(chance)
         return batch
+    elif tensor.ndimension() == 1:
+        chance = dict()
+        for idx, pred in enumerate(tensor):
+            if idx >= len(all_cats):
+                print(f"Idx {idx} not in {len(all_cats)} categories")
+            elif pred > 0: # XXX
+                #print(idx, len(all_cats))
+                chance[all_cats[idx]] = pred.item()
+        #print(chance)
+        return chance
     else:
         raise ValueError("Only tensors with 2 dimensions are supported")
 
+    return vocab.get_itos(cat)
+
 
-def train(dataloader, dataset, model, optimizer, criterion):
+def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
     total_acc, total_count = 0, 0
     log_interval = 500
-    start_time = time.time()
 
     torch.set_printoptions(precision=2)
-    for idx, (text, cats, text_lengths) in enumerate(dataloader):
+    batch = tqdm.tqdm(dataloader, unit="batch")
+    for idx, data in enumerate(batch):
+        batch.set_description(f"Train {epoch}.{idx}")
+        text, cats, text_lengths = data
         optimizer.zero_grad()
 
-        print("text_lengths shape", text_lengths.shape)
-        print("input shape", text.shape)
-        print("target", cats)
-        print("target shape", cats.shape)
-
         output = model(text, text_lengths)
-        print("output", output)
-        print("output shape", output.shape)
+        #print("output", output)
+        #print("output shape", output.shape)
 
         loss = criterion(input=output, target=cats)
         loss.backward()
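Note on the hunk above: the new progress reporting comes from wrapping the DataLoader in tqdm, then labelling the bar per batch; set_postfix (used further down) appends live metrics after the bar. The general pattern, reduced to a runnable sketch with a dummy iterable standing in for the DataLoader:

    import tqdm

    dataloader = range(100)                   # stand-in for a DataLoader
    batch = tqdm.tqdm(dataloader, unit="batch")
    for idx, data in enumerate(batch):
        batch.set_description(f"Train 0.{idx}")      # epoch.batch label
        batch.set_postfix({"accuracy": idx})         # live metric readout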
@@ -368,45 +383,85 @@ def train(dataloader, dataset, model, optimizer, criterion):
 
         optimizer.step()
 
-        print(loss)
-        print("expected", cats)
-        [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, cats)]
-        print("predicted", output)
-        [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
+        #print("train loss",loss)
 
-        return
-        total_acc += (output == cats).sum().item()
+        ##predicted = np.round(output)
+        ##total_acc += (predicted == cats).sum().item()
+
+        predictions = torch.zeros(output.shape)
+        predictions[output >= 0.25] = True
+        #predictions[output >= 0.5] = True
+        #predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
+
+        batch.clear()
+        for target, out, pred in list(zip(cats, output, predictions)):
+            expect = tensor2cat(dataset.cats_vocab, target)
+            raw = tensor2cat(dataset.cats_vocab, out)
+            predict = tensor2cat(dataset.cats_vocab, pred)
+            print("Expected: ", expect)
+            print("Predicted: ", predict)
+            print("Raw output:", raw)
+            print("\n")
+        batch.refresh()
+
+        N, C = cats.shape
+        #print("eq", (output == cats))
+        #print("sum", (output == cats).sum())
+        #print("accuracy", (output == cats).sum() / (N*C) * 100)
+        accuracy = (output == cats).sum() / (N*C) * 100
+        total_acc += accuracy
+        #print("train accuracy", accuracy)
+        #print("train total_acc", total_acc)
         total_count += cats.size(0)
-        if idx % log_interval == 0 and idx > 0:
-            elapsed = time.time() - start_time
-            print(
-                "| epoch {:3d} | {:5d}/{:5d} batches "
-                "| accuracy {:8.3f}".format(
-                    epoch, idx, len(dataloader), total_acc / total_count
-                )
-            )
-            total_acc, total_count = 0, 0
-            start_time = time.time()
+        batch.set_postfix({
+            "accuracy": int(total_acc / total_count),
+        })
+        total_acc, total_count = 0, 0
 
 
-def evaluate(dataloader, dataset, model, criterion):
+def evaluate(dataloader, dataset, model, criterion, epoch=0):
     model.eval()
     total_acc, total_count = 0, 0
 
     with torch.no_grad():
-        for idx, (text, cats, text_lengths) in enumerate(dataloader):
-            predicted_label = model(text, text_lengths)
-            print(predicted_label)
-            loss = criterion(predicted_label, label)
-            print(loss)
-            print("expected labels:", label)
-            print([dataset.cats_vocab.get_itos(i) for i in label])
-            print("predicted labels:", predicted_label)
-            print([dataset.cats_vocab.get_itos(i) for i in predicted_label])
-            total_acc += (predicted_label.argmax(1) == label).sum().item()
-            total_count += label.size(0)
-    return total_acc / total_count
+        batch = tqdm.tqdm(dataloader, unit="batch")
+        for idx, data in enumerate(batch):
+            batch.set_description(f"Evaluate {epoch}.{idx}")
+            text, cats, text_lengths = data
+
+            output = model(text, text_lengths)
+            #print("eval predicted", output)
+
+            loss = criterion(output, cats)
+            #print("eval loss", loss)
+
+            predictions = torch.zeros(output.shape)
+            predictions[output >= 0.5] = True
+            predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
+
+            batch.clear()
+            for target, out, pred in list(zip(cats, output, predictions)):
+                expect = tensor2cat(dataset.cats_vocab, target)
+                raw = tensor2cat(dataset.cats_vocab, out)
+                predict = tensor2cat(dataset.cats_vocab, pred)
+                print("Evaluate expected: ", expect)
+                print("Evaluate predicted: ", predict)
+                print("Evaluate raw output:", raw)
+                print("\n")
+            batch.refresh()
+
+            ##total_acc += (predicted_cats.argmax(1) == cats).sum().item()
+            N, C = cats.shape
+            accuracy = (predictions == cats).sum() / (N*C) * 100
+            total_acc += accuracy
+            #print("eval accuracy", accuracy)
+            #print("eval total_acc", total_acc)
+            total_count += cats.size(0)
+
+            batch.set_postfix({
+                "accuracy": int(total_acc / total_count),
+            })
+    return total_acc / total_count
 
 
 def main():
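Note on the two hunks above: with the criterion now BCEWithLogitsLoss (final hunk), output holds raw logits and the sigmoid is applied inside the loss. Thresholding the logits at 0.25 or 0.5 therefore corresponds to probability cut-offs of roughly 0.56 and 0.62, and the train-side accuracy term (output == cats) compares unbounded logits against 0/1 targets, so it will almost never count a match; evaluate() already compares the thresholded predictions instead. One conventional variant, shown only as a sketch rather than what the commit does, thresholds the probabilities:

    import torch

    output = torch.tensor([[-1.2, 0.3, 2.0]])   # raw logits from the model
    probs = torch.sigmoid(output)               # what BCEWithLogitsLoss sees internally
    predictions = (probs >= 0.5).float()        # same as thresholding logits at 0.0
    print(probs)         # tensor([[0.2315, 0.5744, 0.8808]])
    print(predictions)   # tensor([[0., 1., 1.]])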
@@ -442,6 +497,18 @@ def main():
         sys.exit(1)
 
     data = read_csv(input_csv=args.input, rows=story_num, verbose=args.verbose)
+
+    # create list of all categories
+    global all_categories
+    for cats in data.categories:
+        for c in cats.split(";"):
+            if c not in all_categories:
+                all_categories.append(c)
+    all_categories = sorted(all_categories)
+    #print(all_categories)
+    #print(len(all_categories))
+    #sys.exit(0)
+
     train_data, valid_data, = split_dataset(data, verbose=args.verbose)
 
     '''
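Note on the hunk above: the global category list is accumulated by splitting each row's semicolon-separated labels, then sorted. Since read_csv already returns a pandas frame, an equivalent one-liner sketch (assuming data.categories is a pandas Series of strings):

    all_categories = sorted(data.categories.str.split(";").explode().unique())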
@@ -466,6 +533,7 @@ def main():
     )
     #for text, cat in enumerate(train_dataset):
     #    print(text, cat)
+    #print("-" * 20)
     #for text, cat in enumerate(valid_dataset):
     #    print(text, cat)
     #sys.exit(0)
@@ -483,33 +551,30 @@ def main():
     '''
     dataloader = DataLoader(dataset,
         batch_size=4,
+        drop_last=True,
         shuffle=True,
         num_workers=0,
-        collate_fn=CollateBatch(pad_idx=dataset.stoi['<pad>']),
+        collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
         )
     '''
     train_dataloader = DataLoader(train_dataset,
         batch_size=BATCH_SIZE,
+        drop_last=True,
         shuffle=True,
         num_workers=0,
-        collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
+        collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
         )
     valid_dataloader = DataLoader(valid_dataset,
         batch_size=BATCH_SIZE,
+        drop_last=True,
         shuffle=True,
         num_workers=0,
-        collate_fn=CollateBatch(pad_idx=valid_dataset.stoi['<pad>']),
+        collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
         )
     #for i_batch, sample_batched in enumerate(dataloader):
     #    print(i_batch, sample_batched[0], sample_batched[1])
     #for i_batch, sample_batched in enumerate(train_dataloader):
-    #print(i_batch, sample_batched[0], sample_batched[1])
-    #print(i_batch)
-    #print("batch elements:")
-    #for i in sample_batched:
-    #    print(i)
-    #    print(i.shape)
-    #    print("\n")
+    #    print(i_batch, sample_batched[0], sample_batched[1])
     #sys.exit(0)
 
     input_size = len(train_dataset.text_vocab)
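Note on the hunk above: drop_last=True discards a final batch smaller than BATCH_SIZE, so every batch that reaches CollateBatch and the N, C = cats.shape accuracy arithmetic has a uniform size. A minimal illustration:

    from torch.utils.data import DataLoader

    loader = DataLoader(list(range(10)), batch_size=4, drop_last=True)
    print([b.tolist() for b in loader])   # [[0, 1, 2, 3], [4, 5, 6, 7]]; the 2-item remainder is dropped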
@@ -547,40 +612,32 @@ def main():
 
     # optimizer and loss
     #optimizer = torch.optim.SGD(model.parameters(), lr=LR)
-    criterion = nn.CrossEntropyLoss()
+    criterion = nn.BCEWithLogitsLoss()
     optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=weight_decay)
     if args.verbose:
         print(criterion)
         print(optimizer)
 
     total_accu = None
-    for epoch in range(1, EPOCHS + 1):
+    #for epoch in range(1, EPOCHS + 1):
+    e = tqdm.tqdm(range(1, EPOCHS + 1), unit="epoch")
+    for epoch in e:
+        e.set_description(f"Epoch {epoch}")
         model.train()
-        epoch_start_time = time.time()
-        with tqdm.trange(BATCH_SIZE, unit="batch", mininterval=0) as bar:
-            bar.set_description(f"Epoch {epoch}")
-            train(train_dataloader, train_dataset, model, optimizer, criterion)
-            bar.set_postfix(
-                # loss=float(loss),
-                # acc=float(acc)
-            )
-        accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
+        train(train_dataloader, train_dataset, model, optimizer, criterion, epoch)
+
+        accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion, epoch)
+
         if total_accu is not None and total_accu > accu_val:
-            scheduler.step()
+            optimizer.step()
         else:
             total_accu = accu_val
-        print("-" * 59)
-        print(
-            "| end of epoch {:3d} | time: {:5.2f}s | "
-            "valid accuracy {:8.3f} ".format(
-                epoch, time.time() - epoch_start_time, accu_val
-            )
-        )
-        print("-" * 59)
+        e.set_postfix({
+            "accuracy": accu_val.int().item(),
+        })
 
-    print("Checking the results of test dataset.")
-    accu_test = evaluate(test_dataloader, test_dataset)
-    print("test accuracy {:8.3f}".format(accu_test))
+    # print("Checking the results of test dataset.")
+    # accu_test = evaluate(test_dataloader, test_dataset)
+    # print("test accuracy {:8.3f}".format(accu_test))
 
     return
 
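Note on the final hunk: nn.CrossEntropyLoss expects one class index per sample, so it fits single-label classification, while nn.BCEWithLogitsLoss scores every category column as an independent yes/no decision against the multi-label 0/1 targets that CollateBatch now builds, and it folds the sigmoid into the loss for numerical stability. A side-by-side sketch with invented shapes:

    import torch
    from torch import nn

    logits = torch.randn(2, 4)    # batch of 2 samples, 4 categories

    # Single-label: target is one class index per sample.
    ce = nn.CrossEntropyLoss()(logits, torch.tensor([3, 1]))

    # Multi-label: target is an independent 0/1 per category.
    target = torch.tensor([[0., 1., 0., 1.],
                           [1., 0., 0., 0.]])
    bce = nn.BCEWithLogitsLoss()(logits, target)
    print(ce.item(), bce.item())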