Add extra debugging
This commit is contained in:
parent
54db72fd89
commit
f8eb91fddb
@ -169,6 +169,7 @@ class TextCategoriesDataset(Dataset):
|
||||
return (
|
||||
self.textTransform()(text),
|
||||
cats.fillna(0).values.tolist(),
|
||||
text,
|
||||
)
|
||||
|
||||
def textTransform(self):
|
||||
@ -214,7 +215,7 @@ class CollateBatch:
|
||||
batch: a list of tuples with (text, cats), each of which
|
||||
is a list of tokens
|
||||
'''
|
||||
batch_text, batch_cats = zip(*batch)
|
||||
batch_text, batch_cats, batch_orig = zip(*batch)
|
||||
|
||||
# Pad text to the longest
|
||||
text_tensor = nn.utils.rnn.pad_sequence(
|
||||
@ -236,6 +237,7 @@ class CollateBatch:
|
||||
return (
|
||||
text_tensor,
|
||||
cats_tensor,
|
||||
batch_orig,
|
||||
)
|
||||
|
||||
def tensor2cat(dataset, tensor):
|
||||
@ -267,28 +269,26 @@ def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
|
||||
total_acc, total_count = 0, 1 # XXX
|
||||
log_interval = 500
|
||||
|
||||
torch.set_printoptions(precision=2)
|
||||
|
||||
model.train()
|
||||
|
||||
batch = tqdm.tqdm(dataloader, unit="batch")
|
||||
for idx, data in enumerate(batch):
|
||||
batch.set_description(f"Train {epoch}.{idx}")
|
||||
text, cats = data
|
||||
text, cats, orig_text = data
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = model(text)
|
||||
#print("output", output)
|
||||
#print("output shape", output.shape)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(input=output, target=cats)
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
nn.utils.clip_grad_norm_(model.parameters(), 0.1)
|
||||
|
||||
#nn.utils.clip_grad_norm_(model.parameters(), 0.1)
|
||||
optimizer.step()
|
||||
|
||||
print("train loss", loss)
|
||||
#print("train loss", loss)
|
||||
|
||||
##predicted = np.round(output)
|
||||
##total_acc += (predicted == cats).sum().item()
|
||||
@ -299,10 +299,12 @@ def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
|
||||
predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
|
||||
|
||||
batch.clear()
|
||||
for target, out, pred in list(zip(cats, output, predictions)):
|
||||
for target, out, pred, orig in list(zip(cats, output, predictions, orig_text)):
|
||||
expect = tensor2cat(dataset, target)
|
||||
raw = tensor2cat(dataset, out)
|
||||
predict = tensor2cat(dataset, pred)
|
||||
print("Text:", orig)
|
||||
print("Loss:", loss.item())
|
||||
print("Expected: ", expect)
|
||||
print("Predicted: ", predict)
|
||||
print("Raw output:", raw)
|
||||
@ -333,7 +335,7 @@ def evaluate(dataloader, dataset, model, criterion, epoch=0):
|
||||
batch = tqdm.tqdm(dataloader, unit="batch")
|
||||
for idx, data in enumerate(batch):
|
||||
batch.set_description(f"Evaluate {epoch}.{idx}")
|
||||
text, cats = data
|
||||
text, cats, orig_text = data
|
||||
|
||||
output = model(text)
|
||||
#print("eval predicted", output)
|
||||
@ -346,10 +348,12 @@ def evaluate(dataloader, dataset, model, criterion, epoch=0):
|
||||
predictions[output < 0.5] = False ## assign 0 label to those with less than 0.5
|
||||
|
||||
batch.clear()
|
||||
for target, out, pred in list(zip(cats, output, predictions)):
|
||||
for target, out, pred, orig in list(zip(cats, output, predictions, orig_text)):
|
||||
expect = tensor2cat(dataset, target)
|
||||
raw = tensor2cat(dataset, out)
|
||||
predict = tensor2cat(dataset, pred)
|
||||
print("Evaluate Text:", orig)
|
||||
print("Evaluate Loss:", loss.item())
|
||||
print("Evaluate expected: ", expect)
|
||||
print("Evaluate predicted: ", predict)
|
||||
print("Evaluate raw output:", raw)
|
||||
@ -465,17 +469,20 @@ def main():
|
||||
)
|
||||
print(f"Using {device} device")
|
||||
|
||||
torch.set_printoptions(precision=2)
|
||||
|
||||
# Hyperparameters
|
||||
#epochs = 10 # epoch
|
||||
epochs = 6 # epoch
|
||||
epochs = 10 # epoch
|
||||
#epochs = 6 # epoch
|
||||
#epochs = 4 # epoch
|
||||
#lr = 5 # learning rate
|
||||
#lr = 0.5
|
||||
#lr = 0.05
|
||||
#lr = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
|
||||
lr = 0.0001
|
||||
lr = 0.00005
|
||||
#batch_size = 64 # batch size for training
|
||||
batch_size = 16 # batch size for training
|
||||
batch_size = 32 # batch size for training
|
||||
#batch_size = 16 # batch size for training
|
||||
#batch_size = 8 # batch size for training
|
||||
#batch_size = 4 # batch size for training
|
||||
|
||||
@ -485,8 +492,9 @@ def main():
|
||||
#hidden_size = 8 # hidden size of rnn module, should be tweaked manually
|
||||
mean_seq = True # use mean of rnn output
|
||||
#mean_seq = False # use mean of rnn output
|
||||
weight_decay = 1e-4 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
|
||||
#weight_decay = 1e-3 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
|
||||
#weight_decay = 1e-4 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
|
||||
weight_decay = 1e-5 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
|
||||
|
||||
'''
|
||||
dataloader = DataLoader(dataset,
|
||||
@ -545,7 +553,9 @@ def main():
|
||||
# optimizer and loss
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
#optimizer = torch.optim.SGD(model.parameters(), lr=lr)
|
||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
|
||||
#optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
|
||||
|
||||
if args.verbose:
|
||||
print(criterion)
|
||||
print(optimizer)
|
||||
@ -560,17 +570,17 @@ def main():
|
||||
|
||||
accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion, epoch)
|
||||
|
||||
if total_accu is not None and total_accu > accu_val:
|
||||
optimizer.step()
|
||||
else:
|
||||
total_accu = accu_val
|
||||
#if total_accu is not None and total_accu > accu_val:
|
||||
# optimizer.step()
|
||||
#else:
|
||||
# total_accu = accu_val
|
||||
e.set_postfix({
|
||||
"accuracy": accu_val,
|
||||
})
|
||||
|
||||
# print("Checking the results of test dataset.")
|
||||
# accu_test = evaluate(test_dataloader, test_dataset)
|
||||
# print("test accuracy {:8.3f}".format(accu_test))
|
||||
print("Checking the results of test dataset.")
|
||||
accu_test = evaluate(test_dataloader, test_dataset)
|
||||
print("test accuracy {:8.3f}".format(accu_test))
|
||||
|
||||
if model_out is not None:
|
||||
torch.save(model.state_dict(), model_out)
|
||||
|
Loading…
Reference in New Issue
Block a user