First working model

parent fe7870e9d4
commit 723c6d4378
@@ -198,11 +198,13 @@ class TextCategoriesDataset(Dataset):
                 # for <eos> in vocabulary is 2 as seen in previous section
                 T.AddToken(self.text_vocab['<eos>'], begin=False)
             )
-        else:
+        elif vType == "cats":
             return T.Sequential(
                 # converts the sentences to indices based on given vocabulary
                 T.VocabTransform(vocab=vocab),
             )
+        else:
+            raise Exception('wrong transformation type')


     '''
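For reference, a minimal standalone sketch of the transform pipeline this hunk extends, assuming torchtext >= 0.12; the tokens and toy vocabulary are illustrative, not from this repo. With specials ordered `["<unk>", "<bos>", "<eos>"]`, `<eos>` lands at index 2, which matches the comment in the hunk.

```python
import torchtext.transforms as T
from torchtext.vocab import build_vocab_from_iterator

# hypothetical toy vocabulary; specials are inserted first, so
# <unk>=0, <bos>=1, <eos>=2, then the corpus tokens
vocab = build_vocab_from_iterator(
    [["hello", "world"]], specials=["<unk>", "<bos>", "<eos>"]
)
vocab.set_default_index(vocab["<unk>"])

text_transform = T.Sequential(
    T.VocabTransform(vocab=vocab),           # tokens -> indices
    T.AddToken(vocab["<bos>"], begin=True),  # prepend <bos>
    T.AddToken(vocab["<eos>"], begin=False), # append <eos>
)

print(text_transform(["hello", "world"]))  # [1, 3, 4, 2]
```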
@@ -237,7 +239,7 @@ class CollateBatch:
         #cats_tensor = T.ToTensor(self.pad_idx)(batch_cats)

         # Pad text to the longest
-        text_tensor = torch.nn.utils.rnn.pad_sequence(
+        text_tensor = nn.utils.rnn.pad_sequence(
             [torch.LongTensor(s) for s in batch_text],
             batch_first=True, padding_value=self.pad_idx
         )
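The collate path leans on `pad_sequence`; a minimal sketch of what it does to a ragged batch (toy indices, hypothetical pad index):

```python
import torch
import torch.nn as nn

batch_text = [[4, 5, 6], [7, 8], [9]]  # variable-length index lists
pad_idx = 1                            # hypothetical <pad> index

# right-pad every sample to the longest sequence in the batch
text_tensor = nn.utils.rnn.pad_sequence(
    [torch.LongTensor(s) for s in batch_text],
    batch_first=True, padding_value=pad_idx,
)
print(text_tensor)
# tensor([[4, 5, 6],
#         [7, 8, 1],
#         [9, 1, 1]])
```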
@@ -257,18 +259,44 @@ class CollateBatch:
         # if there's no 0, there was no <unk>, so increment to allow for it to be a possible category
         if 0 not in all_cats:
             num_cats += 1
-        cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).long()
-        cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
-        for idx, (c, clen) in enumerate(zip(batch_cats, cats_lengths)):
-            cats_tensor[idx, :clen] = torch.LongTensor(c)

+        # Convert cats to tensor
+        #cats_tensor = nn.utils.rnn.pad_sequence(
+        #    [torch.LongTensor(s) for s in batch_cats],
+        #    batch_first=True, padding_value=self.pad_idx
+        #)
+        #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
+
+        # Convert cats to tensor, alt version
+        #cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).long()
+        #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
+        #for idx, (c, clen) in enumerate(zip(batch_cats, cats_lengths)):
+        #    cats_tensor[idx, :clen] = torch.LongTensor(c)
+        #print([torch.LongTensor(s) for s in batch_cats])
+        #print(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]))
+        #cats_tensor = nn.functional.one_hot(torch.LongTensor([torch.LongTensor(s) for s in batch_cats]), num_cats)
+        #cats_tensor = nn.functional.one_hot(torch.FloatTensor(batch_cats), num_cats)
+
+        # Convert cats to multi-label one-hot representation
+        # This will be a target for CrossEntropyLoss(pred, target),
+        # which takes FloatTensor pred and LongTensor target
+        cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).float()
+        cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
+        for idx, cats in enumerate(batch_cats):
+            #print("\nsample", idx, cats)
+            for c in cats:
+                cats_tensor[idx][c] = 1
+            #print(cats_tensor[idx])
+
+        '''
         # XXX why??
         ## SORT YOUR TENSORS BY LENGTH!
         text_lengths, perm_idx = text_lengths.sort(0, descending=True)
         text_tensor = text_tensor[perm_idx]
         cats_tensor = cats_tensor[perm_idx]
+        '''

-        #print(text_tensor)
+        #print("text", text_tensor)
         #print("text shape:", text_tensor.shape)
         #print(cats_tensor)
         #print("cats shape:", cats_tensor.shape)
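A standalone sketch of the multi-hot target construction this hunk settles on, with toy category indices. Two hedges on the comment in the hunk: `nn.CrossEntropyLoss` only accepts floating-point "class probability" targets from PyTorch 1.10 on (and expects each row to be a distribution), so for genuinely multi-label targets `nn.BCEWithLogitsLoss` is the more usual pairing, shown here:

```python
import torch
import torch.nn as nn

batch_cats = [[2, 5], [0], [1, 3, 4]]  # category indices per sample (toy)
num_cats = 6

# multi-hot encode: one row per sample, 1.0 at each assigned category
cats_tensor = torch.zeros(len(batch_cats), num_cats)
for idx, cats in enumerate(batch_cats):
    for c in cats:
        cats_tensor[idx][c] = 1

logits = torch.randn(len(batch_cats), num_cats)  # stand-in model output
loss = nn.BCEWithLogitsLoss()(logits, cats_tensor)
print(cats_tensor)
print(loss.item())
```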
@@ -283,13 +311,46 @@ class CollateBatch:
             text_lengths,
         )


-def train(dataloader, model, optimizer, criterion):
+def labels_to_hot_one(label_vocab, labels, pad_idx: int):
+    all_labels = label_vocab.get_itos()
+    num_labels = len(all_labels)
+    # if there's no 0, there was no <unk>, so increment to allow for it to be a possible category
+    if 0 not in all_labels:
+        num_labels += 1
+
+    labels_tensor = torch.full((len(labels), num_labels), pad_idx).float()
+    labels_lengths = torch.LongTensor(list(map(len, labels)))
+    for idx, sample_labels in enumerate(labels):
+        #print("\nsample", idx, sample_labels)
+        for l in sample_labels:
+            labels_tensor[idx][l] = 1
+        #print(labels_tensor[idx])
+    return labels_tensor
+
+
+def hot_one_to_labels(vocab, tensor):
+    if tensor.ndimension() == 2:
+        all_labels = vocab.get_itos()
+        batch = list()
+        for result in tensor:
+            chance = dict()
+            for idx, pred in enumerate(result):
+                if pred > 0:  # XXX
+                    chance[all_labels[idx]] = pred.item()
+            print(chance)
+            batch.append(chance)
+        return batch
+    else:
+        raise ValueError("Only tensors with 2 dimensions are supported")
+
+
+def train(dataloader, dataset, model, optimizer, criterion):
     model.train()
     total_acc, total_count = 0, 0
     log_interval = 500
     start_time = time.time()

+    torch.set_printoptions(precision=2)
     for idx, (text, cats, text_lengths) in enumerate(dataloader):
         optimizer.zero_grad()

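The two helpers added here are rough inverses: one multi-hot encodes label lists, the other reads a row back into per-label scores, keeping anything above the 0 threshold just as the code above does. A toy round trip, where `ToyVocab` is a hypothetical stand-in exposing only the `get_itos()` call the helpers use:

```python
import torch

class ToyVocab:
    """Hypothetical stand-in for a torchtext Vocab; only get_itos() is used."""
    def __init__(self, itos):
        self._itos = itos
    def get_itos(self):
        return self._itos

vocab = ToyVocab(["<unk>", "sports", "politics", "tech"])

# encode: sample 0 is {sports, tech}, sample 1 is {politics}
labels = [[1, 3], [2]]
tensor = torch.zeros(len(labels), len(vocab.get_itos()))
for i, sample_labels in enumerate(labels):
    for l in sample_labels:
        tensor[i][l] = 1

# decode: keep every entry above the 0 threshold, mapped back to label names
for row in tensor:
    print({vocab.get_itos()[i]: p.item() for i, p in enumerate(row) if p > 0})
# {'sports': 1.0, 'tech': 1.0}
# {'politics': 1.0}
```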
@@ -301,23 +362,23 @@ def train(dataloader, model, optimizer, criterion):
         output = model(text, text_lengths)
         print("output", output)
         print("output shape", output.shape)
-        # reshape output and target for cross entropy loss
-        # output = output.reshape(output.size(0)*output.size(1), -1)  # (batch * seq_len x classes)
-        # cats = cats.reshape(-1)  # (batch * seq_len), class index
-        # print("output", output)
-        # print("output shape", output.shape)
-        # print("target shape", cats.shape)
-        # print()

         loss = criterion(input=output, target=cats)
         loss.backward()

-        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
+        nn.utils.clip_grad_norm_(model.parameters(), 0.1)

         optimizer.step()

-        total_acc += (predicted_label.argmax(1) == label).sum().item()
-        total_count += label.size(0)
+        print(loss)
+        print("expected", cats)
+        [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, cats)]
+        print("predicted", output)
+        [pprint.pprint(x) for x in hot_one_to_labels(dataset.cats_vocab, output)]
+
+        return
+        total_acc += (output == cats).sum().item()
+        total_count += cats.size(0)
         if idx % log_interval == 0 and idx > 0:
             elapsed = time.time() - start_time
             print(
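On the `clip_grad_norm_` call kept here: it rescales the whole gradient vector in place whenever its global L2 norm exceeds the cap, which guards RNN training against exploding gradients. A minimal sketch; the 0.1 cap mirrors the code above and the linear layer is just a stand-in for the model:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the RNN
model(torch.randn(8, 4)).sum().backward()

# rescale gradients so their global L2 norm is at most 0.1;
# returns the norm as measured before clipping
total_norm = nn.utils.clip_grad_norm_(model.parameters(), 0.1)
print(total_norm)
```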
@@ -330,14 +391,20 @@ def train(dataloader, model, optimizer, criterion):
             start_time = time.time()


-def evaluate(dataloader, model, criterion):
+def evaluate(dataloader, dataset, model, criterion):
     model.eval()
     total_acc, total_count = 0, 0

     with torch.no_grad():
-        for idx, (label, text) in enumerate(dataloader):
-            predicted_label = model(text)
+        for idx, (text, cats, text_lengths) in enumerate(dataloader):
+            predicted_label = model(text, text_lengths)
+            print(predicted_label)
             loss = criterion(predicted_label, label)
+            print(loss)
+            print("expected labels:", label)
+            print([dataset.cats_vocab.get_itos(i) for i in label])
+            print("predicted labels:", predicted_label)
+            print([dataset.cats_vocab.get_itos(i) for i in predicted_label])
             total_acc += (predicted_label.argmax(1) == label).sum().item()
             total_count += label.size(0)
     return total_acc / total_count
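Note that the retained accuracy line still computes single-label accuracy (`argmax(1) == label`), which does not fit the new multi-hot targets. One possible multi-label replacement, not from this commit, is the exact-match ratio: sigmoid the logits, threshold, and count samples where every label is right. Values below are toy:

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5],
                       [-0.3, 1.2, -2.0]])       # stand-in model output
cats = torch.tensor([[1.0, 0.0, 1.0],
                     [0.0, 1.0, 0.0]])           # multi-hot targets

# threshold per-label probabilities, then require all labels to match
pred = (torch.sigmoid(logits) > 0.5).float()
exact_match = (pred == cats).all(dim=1).float().mean().item()
print(exact_match)  # fraction of samples with every label correct
```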
@@ -456,12 +523,13 @@ def main():
     mean_seq = True  # use mean of rnn output
     weight_decay = 1e-4  # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4

-    #for i in train_dataset.text_vocab.get_itos():
-    #    print(i)
-    print("input_size: ", input_size)
-    print("output_size:", output_size)
-    print("embed shape:", embed.shape)
-    print("embedding_size:", embedding_size, " (that is, number of samples)")
+    if args.verbose:
+        #for i in train_dataset.text_vocab.get_itos():
+        #    print(i)
+        print("input_size: ", input_size)
+        print("output_size:", output_size)
+        print("embed shape:", embed.shape)
+        print("embedding_size:", embedding_size, " (that is, number of samples)")

     model = RNN(
         #rnn_model='GRU',
@@ -475,20 +543,22 @@ def main():
         num_layers=num_layers,
         batch_first=True
     )
-    print(model)
+    if args.verbose:
+        print(model)

     # optimizer and loss
     #optimizer = torch.optim.SGD(model.parameters(), lr=LR)
     criterion = nn.CrossEntropyLoss()
     optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=weight_decay)
-    print(criterion)
-    print(optimizer)
+    if args.verbose:
+        print(criterion)
+        print(optimizer)

     total_accu = None
     for epoch in range(1, EPOCHS + 1):
         epoch_start_time = time.time()
-        train(train_dataloader, model, optimizer, criterion)
-        accu_val = evaluate(valid_dataloader, model, criterion)
+        train(train_dataloader, train_dataset, model, optimizer, criterion)
+        accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion)
         if total_accu is not None and total_accu > accu_val:
             scheduler.step()
         else:
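The epoch loop keeps the pattern of stepping the scheduler only when validation accuracy regresses. A condensed sketch of that control flow; the `StepLR` is a hypothetical stand-in for whatever `scheduler` is defined elsewhere in `main()`, and the per-epoch accuracies are fake:

```python
import torch

model = torch.nn.Linear(4, 2)  # stand-in
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

total_accu = None
for accu_val in [0.50, 0.62, 0.58]:  # fake validation accuracy per epoch
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()  # accuracy dropped: decay the learning rate
    else:
        total_accu = accu_val
    print(optimizer.param_groups[0]["lr"])  # 1.0, 1.0, 0.1
```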
@@ -503,7 +573,7 @@ def main():
     print("-" * 59)

     print("Checking the results of test dataset.")
-    accu_test = evaluate(test_dataloader)
+    accu_test = evaluate(test_dataloader, test_dataset)
     print("test accuracy {:8.3f}".format(accu_test))

     return