Switch to SentencePiece for tokenisation and Roberta for the model
This commit is contained in: parent 910e0c9d24, commit 54db72fd89
@@ -2,10 +2,10 @@
 
 import argparse
 import os
-import sys
-import pprint
 import re
+import pprint
 import string
+import sys
 import time
 import warnings
 # data manupulation
@@ -22,14 +22,22 @@ import torchtext.transforms as T
 import torchtext.vocab as vocab
 from torch import nn
 from torch.utils.data import Dataset, DataLoader
+from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER
+
+# Check for TPU availability in notebook environment
+tpu_available = os.environ.get('COLAB_TPU_ADDR') is not None
+
+if tpu_available:
+    import torch_xla
+    import torch_xla_py.xla_model as xm
 
 xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
 xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"
 
 # XXX None for all stories
-story_num = 128
+#story_num = 128
 #story_num = 256
-#story_num = 512
+story_num = 512
 #story_num = 1024
 #story_num = 4096
 #story_num = None
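
Note on the TPU block above: torch_xla_py.xla_model is the import path shipped with the early Colab TPU wheels; current torch_xla releases expose the same helpers as torch_xla.core.xla_model. A minimal sketch of the equivalent modern check (not part of this commit):

import os

tpu_available = os.environ.get('COLAB_TPU_ADDR') is not None

if tpu_available:
    import torch_xla.core.xla_model as xm  # modern module path
    device = xm.xla_device()  # a torch.device pointing at the TPU core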
@@ -115,8 +123,12 @@ class TextCategoriesDataset(Dataset):
         self.lang = self.df[lang_column]
         self.text = self.df[text_column]
         self.cats = self.df.iloc[:, first_cats_column:].sort_index(axis="columns")
 
         self.cats_vocab = self.cats.columns
+
+        self.text_length = self.text.str.len().max()
+        self.num_cats = len(self.cats_vocab)
+
         # index-to-token dict
         # <pad> : padding, used for padding the shorter sentences in a batch
         # to match the length of longest sentence in the batch
@@ -145,8 +157,9 @@ class TextCategoriesDataset(Dataset):
         cats = self.cats.iloc[idx]
 
         #print(self.textTransform()(text))
-        #print(cats)
-        #print(cats.fillna(0).values)
+        #print(type(cats.fillna(0).values.tolist()))
+        #print(cats.fillna(0).values.tolist())
+        #sys.exit(0)
 
         if self.transform:
             text, cats = self.transform(text, cats)
@@ -155,7 +168,7 @@ class TextCategoriesDataset(Dataset):
         # NaN to zeros and stripping the index
         return (
             self.textTransform()(text),
-            cats.fillna(0).values,
+            cats.fillna(0).values.tolist(),
         )
 
     def textTransform(self):
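
The switch from .values to .values.tolist() hands the collate step a plain list of floats instead of a NumPy array, so a float tensor suitable for a multi-label loss such as nn.BCEWithLogitsLoss can be built directly. A small sketch of the conversion, with made-up category names:

import pandas as pd
import torch

row = pd.Series({"cat_a": 1.0, "cat_b": None, "cat_c": 0.5})
cats = row.fillna(0).values.tolist()  # [1.0, 0.0, 0.5]
target = torch.tensor(cats, dtype=torch.float32)
print(target)  # tensor([1.0000, 0.0000, 0.5000])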
@@ -167,6 +180,8 @@ class TextCategoriesDataset(Dataset):
             # converts the sentences to indices based on given vocabulary using SentencePiece
             T.SentencePieceTokenizer(xlmr_spm_model_path),
             T.VocabTransform(torch.hub.load_state_dict_from_url(xlmr_vocab_path)),
+            #T.Truncate(self.text_length - 2), # XXX
+            T.Truncate(256 - 3), # XXX
             # Add <sos> at beginning of each sentence. 1 because the index
             # for <sos> in vocabulary is 1 as seen in previous section
             T.AddToken(self.stoi['<sos>'], begin=True),
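
For context, the transform this hunk edits follows the standard torchtext XLM-R recipe: SentencePiece tokenisation, vocabulary lookup, truncation, then begin/end tokens, with the truncation budget reserving slots for the tokens added afterwards (the commit reserves three; the dataset's own stoi supplies its token ids). A standalone sketch, assuming instead the stock XLM-R special-token ids (<s> = 0, </s> = 2) and a 256-token budget:

import torch
import torchtext.transforms as T

xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"

text_transform = T.Sequential(
    T.SentencePieceTokenizer(xlmr_spm_model_path),  # text -> subword pieces
    T.VocabTransform(torch.hub.load_state_dict_from_url(xlmr_vocab_path)),  # pieces -> ids
    T.Truncate(256 - 2),  # leave room for the two special tokens below
    T.AddToken(0, begin=True),   # <s>
    T.AddToken(2, begin=False),  # </s>
)
print(text_transform("A short example sentence."))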
@@ -221,7 +236,6 @@ class CollateBatch:
         return (
             text_tensor,
             cats_tensor,
-            text_lengths,
         )
 
 def tensor2cat(dataset, tensor):
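
Dropping text_lengths makes sense here: the RNN needed per-sample lengths for packing, whereas the XLM-R encoder consumes the padded (batch, seq) id tensor directly. A minimal length-free collate along these lines (a sketch, assuming pad index 1 as in the stock XLM-R vocabulary):

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_batch(batch, pad_idx=1):
    # batch is a list of (token_ids, category_values) pairs from the dataset
    texts, cats = zip(*batch)
    text_tensor = pad_sequence(
        [torch.tensor(t, dtype=torch.long) for t in texts],
        batch_first=True, padding_value=pad_idx,
    )
    cats_tensor = torch.tensor(cats, dtype=torch.float32)
    return text_tensor, cats_tensor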
@@ -233,6 +247,7 @@ def tensor2cat(dataset, tensor):
             for idx, pred in enumerate(result):
                 if pred > 0: # XXX
                     chance[cats[idx]] = pred.item()
+            chance = dict(sorted(chance.items(), key=lambda x : x[1], reverse=True))
             batch.append(chance)
         return batch
     elif tensor.ndimension() == 1:
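
The added dict(sorted(...)) lines order each result so the most probable categories print first; dicts preserve insertion order in Python 3.7+. For example:

chance = {"news": 0.12, "sport": 0.91, "tech": 0.55}
chance = dict(sorted(chance.items(), key=lambda x: x[1], reverse=True))
print(chance)  # {'sport': 0.91, 'tech': 0.55, 'news': 0.12}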
@@ -242,24 +257,27 @@ def tensor2cat(dataset, tensor):
                 print(f"Idx {idx} not in {len(cats)} categories")
             elif pred > 0: # XXX
                 chance[cats[idx]] = pred.item()
+        chance = dict(sorted(chance.items(), key=lambda x : x[1], reverse=True))
         return chance
     else:
         raise ValueError("Only tensors with 1 dimension or batches with 2 dimensions are supported")
 
 
 def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
-    total_acc, total_count = 0, 0
+    total_acc, total_count = 0, 1 # XXX
     log_interval = 500
 
     torch.set_printoptions(precision=2)
 
+    model.train()
+
     batch = tqdm.tqdm(dataloader, unit="batch")
     for idx, data in enumerate(batch):
         batch.set_description(f"Train {epoch}.{idx}")
-        text, cats, text_lengths = data
+        text, cats = data
         optimizer.zero_grad()
 
-        output = model(text, text_lengths)
+        output = model(text)
         #print("output", output)
         #print("output shape", output.shape)
 
@@ -282,9 +300,9 @@ def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
 
                 batch.clear()
                 for target, out, pred in list(zip(cats, output, predictions)):
-                    expect = tensor2cat(dataset.cats_vocab, target)
-                    raw = tensor2cat(dataset.cats_vocab, out)
-                    predict = tensor2cat(dataset.cats_vocab, pred)
+                    expect = tensor2cat(dataset, target)
+                    raw = tensor2cat(dataset, out)
+                    predict = tensor2cat(dataset, pred)
                     print("Expected:  ", expect)
                     print("Predicted: ", predict)
                     print("Raw output:", raw)
@@ -307,16 +325,17 @@ def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
 
 
 def evaluate(dataloader, dataset, model, criterion, epoch=0):
+    total_acc, total_count = 0, 1 # XXX
+
     model.eval()
-    total_acc, total_count = 0, 0
 
     with torch.no_grad():
         batch = tqdm.tqdm(dataloader, unit="batch")
         for idx, data in enumerate(batch):
             batch.set_description(f"Evaluate {epoch}.{idx}")
-            text, cats, text_lengths = data
+            text, cats = data
 
-            output = model(text, text_lengths)
+            output = model(text)
             #print("eval predicted", output)
 
             loss = criterion(output, cats)
@@ -328,9 +347,9 @@ def evaluate(dataloader, dataset, model, criterion, epoch=0):
 
                 batch.clear()
                 for target, out, pred in list(zip(cats, output, predictions)):
-                    expect = tensor2cat(dataset.cats_vocab, target)
-                    raw = tensor2cat(dataset.cats_vocab, out)
-                    predict = tensor2cat(dataset.cats_vocab, pred)
+                    expect = tensor2cat(dataset, target)
+                    raw = tensor2cat(dataset, out)
+                    predict = tensor2cat(dataset, pred)
                     print("Evaluate expected:  ", expect)
                     print("Evaluate predicted: ", predict)
                     print("Evaluate raw output:", raw)
@@ -374,7 +393,10 @@ def main():
                         help='path of CSV file containing dataset')
     parser.add_argument('--model', '-m',
                         #required=True, # XXX
-                        help='path to training model')
+                        help='path to load training model')
+    parser.add_argument('--out', '-o',
+                        #required=True, # XXX
+                        help='path to save training model')
     parser.add_argument('--verbose', '-v',
                         type=int, nargs='?',
                         const=1, # Default value if -v is supplied
@@ -386,7 +408,10 @@ def main():
         print("ERROR: train or classify data")
         sys.exit(1)
 
-    if args.action == 'classify' and s.path.isfile(model_storage) is None:
+    model_in = args.model
+    model_out = args.out
+
+    if args.action == 'classify' and (model_in is None or os.path.isfile(model_in) is None):
         print("No model found for classification; running training instead")
         args.action = 'train'
 
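
One catch in the new guard: os.path.isfile() returns a bool, never None, so os.path.isfile(model_in) is None is always False and the fallback only fires when model_in itself is None. The presumably intended check, as a sketch:

import os
from typing import Optional

def resolve_action(action: str, model_in: Optional[str]) -> str:
    # Fall back to training when no usable model file was supplied.
    if action == 'classify' and (model_in is None or not os.path.isfile(model_in)):
        print("No model found for classification; running training instead")
        return 'train'
    return action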
@@ -423,29 +448,34 @@ def main():
     #print("-" * 20)
     #for text, cat in enumerate(valid_dataset):
     #    print(text, cat)
-    #print(tensor2cat(train_dataset, torch.tensor([0, 0, 0, 1., 0.9])))
+    #print(tensor2cat(train_dataset, torch.tensor([0, 0, 0, 1., 0.9, 1, 0.5, .6])))
     #sys.exit(0)
 
+    # Make everything a bit more reproducible
+    seed_everything(111)
+
     # Get cpu, gpu or mps device for training.
     # Move tensor to the NVIDIA GPU if available
     device = (
-        "cuda" if torch.cuda.is_available()
-        else "xps" if hasattr(torch, "xpu") and torch.xpu.is_available()
-        else "mps" if torch.backends.mps.is_available()
+        xm.xla_device() if tpu_available # google
+        else "cuda" if torch.cuda.is_available() # nvidia
+        else "xps" if hasattr(torch, "xpu") and torch.xpu.is_available() # intel
+        else "mps" if torch.backends.mps.is_available() # mac
         else "cpu"
     )
     print(f"Using {device} device")
 
     # Hyperparameters
     #epochs = 10 # epoch
-    epochs = 4 # epoch
+    epochs = 6 # epoch
+    #epochs = 4 # epoch
     #lr = 5 # learning rate
     #lr = 0.5
     #lr = 0.05
     #lr = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
     lr = 0.0001
-    batch_size = 64 # batch size for training
-    #batch_size = 16 # batch size for training
+    #batch_size = 64 # batch size for training
+    batch_size = 16 # batch size for training
     #batch_size = 8 # batch size for training
     #batch_size = 4 # batch size for training
 
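
seed_everything is not defined in this diff; a common implementation (mirroring the PyTorch Lightning helper of the same name) looks roughly like the sketch below. Separately, note that PyTorch's Intel GPU backend is spelled "xpu", so the "xps" string in the device chain would not name a valid device if that branch were ever taken.

import os
import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    # Seed every RNG the training loop might touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)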
@@ -460,10 +490,10 @@ def main():
 
     '''
     dataloader = DataLoader(dataset,
-        batch_size=4,
+        batch_size=batch_size,
         drop_last=True,
         shuffle=True,
-        num_workers=0,
+        num_workers=4,
         collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
     )
     '''
@@ -471,48 +501,44 @@ def main():
         batch_size=batch_size,
         drop_last=True,
         shuffle=True,
-        num_workers=0,
+        num_workers=4,
         collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
     )
     valid_dataloader = DataLoader(valid_dataset,
         batch_size=batch_size,
         drop_last=True,
         shuffle=True,
-        num_workers=0,
+        num_workers=4,
         collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
     )
     #for i_batch, sample_batched in enumerate(dataloader):
     #    print(i_batch, sample_batched[0], sample_batched[1])
-    for i_batch, sample_batched in enumerate(train_dataloader):
-        print(i_batch, sample_batched[0], sample_batched[1])
-    sys.exit(0)
+    #for i_batch, sample_batched in enumerate(train_dataloader):
+    #    print(i_batch, sample_batched[0], sample_batched[1])
+    #sys.exit(0)
 
-    input_size = len(train_dataset.text_vocab)
-    output_size = len(train_dataset.cats_vocab) # every output item is the likelihood of a particular category
-    embed = torch.empty(input_size, len(train_dataset)) # tokens per sample x samples
-    embedding_size = embed.size(1) # was 64 (should be: samples)
+    #input_size = len(train_dataset.text_vocab)
+    #output_size = len(train_dataset.cats_vocab) # every output item is the likelihood of a particular category
+    #embed = torch.empty(input_size, len(train_dataset)) # tokens per sample x samples
+    #embedding_size = embed.size(1) # was 64 (should be: samples)
+    #input_size = train_dataset.text_length
+    input_size = 768
+    output_size = train_dataset.num_cats
 
     if args.verbose:
         #for i in train_dataset.text_vocab.get_itos():
         #    print(i)
         print("input_size: ", input_size)
         print("output_size:", output_size)
-        print("embed shape:", embed.shape)
-        print("embedding_size:", embedding_size, " (that is, number of samples)")
+        #print("embed shape:", embed.shape)
+        #print("embedding_size:", embedding_size, " (that is, number of samples)")
 
+    classifier_head = RobertaClassificationHead(num_classes=output_size, input_dim=input_size)
+    model = XLMR_BASE_ENCODER.get_model(head=classifier_head)
+    if model_in is not None and os.path.isfile(model_in):
+        model.load_state_dict(torch.load(model_in))
+    model.to(device)
 
-    model = RNN(
-        #rnn_model='GRU',
-        rnn_model='LSTM',
-        vocab_size=input_size,
-        embed_size=embedding_size,
-        num_output=output_size,
-        use_last=(not mean_seq),
-        hidden_size=hidden_size,
-        embedding_tensor=embed,
-        num_layers=num_layers,
-        batch_first=True
-    )
     if args.verbose:
         print(model)
 
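
The new model path in isolation: XLM-R base produces 768-dimensional representations (hence input_size = 768), and RobertaClassificationHead maps them to one logit per category. A shape-check sketch with an assumed category count:

import torch
from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER

num_cats = 8  # assumed category count, for illustration only

head = RobertaClassificationHead(num_classes=num_cats, input_dim=768)
model = XLMR_BASE_ENCODER.get_model(head=head)

ids = torch.randint(0, 250_000, (2, 64))  # fake batch: 2 sequences of 64 token ids
logits = model(ids)
print(logits.shape)  # torch.Size([2, 8])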
@@ -530,11 +556,6 @@ def main():
     for epoch in e:
         e.set_description(f"Epoch {epoch}")
 
-        train_dataset.to(device)
-        valid_dataset.to(device)
-        model.to(device)
-
-        model.train()
         train(train_dataloader, train_dataset, model, optimizer, criterion, epoch)
 
         accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion, epoch)
@@ -544,13 +565,16 @@ def main():
         else:
             total_accu = accu_val
         e.set_postfix({
-            "accuracy": accu_val.int().item(),
+            "accuracy": accu_val,
         })
 
     # print("Checking the results of test dataset.")
     # accu_test = evaluate(test_dataloader, test_dataset)
     # print("test accuracy {:8.3f}".format(accu_test))
 
+    if model_out is not None:
+        torch.save(model.state_dict(), model_out)
+
     return
 
 if __name__ == "__main__":
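
The new --out/--model pair gives the script a save/load round trip. A self-contained sketch of the pattern, with a stand-in linear model instead of the XLM-R classifier; adding map_location (which the commit's load omits) keeps a GPU- or TPU-trained checkpoint loadable on CPU-only hosts:

import torch
from torch import nn

model = nn.Linear(768, 8)  # stand-in for the XLM-R classifier

torch.save(model.state_dict(), "model.pt")  # what --out does after training
state = torch.load("model.pt", map_location="cpu")  # what --model does before classifying
model.load_state_dict(state)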