Switch to SentencePiece for tokenisation and Roberta for the model

Timothy Allen 2023-12-30 15:19:52 +02:00
parent 910e0c9d24
commit 54db72fd89


@@ -2,10 +2,10 @@
 import argparse
 import os
-import sys
-import pprint
 import re
+import pprint
 import string
+import sys
 import time
 import warnings
 # data manupulation
@@ -22,14 +22,22 @@ import torchtext.transforms as T
 import torchtext.vocab as vocab
 from torch import nn
 from torch.utils.data import Dataset, DataLoader
+from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER
+
+# Check for TPU availability in notebook environment
+tpu_available = os.environ.get('COLAB_TPU_ADDR') is not None
+if tpu_available:
+    import torch_xla
+    import torch_xla_py.xla_model as xm
 
 xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
 xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"
 
 # XXX None for all stories
-story_num = 128
+#story_num = 128
 #story_num = 256
-#story_num = 512
+story_num = 512
 #story_num = 1024
 #story_num = 4096
 #story_num = None
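
Note: the two URLs kept as context above are the pretrained XLM-R assets that the rest of this commit relies on. A minimal standalone sketch (not part of the commit) of what they provide: the SentencePiece model splits raw text into subword pieces, and the serialised vocab maps those pieces to XLM-R token ids.

    import torch
    import torchtext.transforms as T

    # Same asset URLs as defined in the file above.
    spm = T.SentencePieceTokenizer(
        "https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model")
    vocab = T.VocabTransform(torch.hub.load_state_dict_from_url(
        "https://download.pytorch.org/models/text/xlmr.vocab.pt"))

    pieces = spm("A short example sentence.")  # list of subword strings
    ids = vocab(pieces)                        # matching vocabulary indices
    print(pieces)
    print(ids)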
@@ -115,8 +123,12 @@ class TextCategoriesDataset(Dataset):
         self.lang = self.df[lang_column]
         self.text = self.df[text_column]
         self.cats = self.df.iloc[:, first_cats_column:].sort_index(axis="columns")
         self.cats_vocab = self.cats.columns
+        self.text_length = self.text.str.len().max()
+        self.num_cats = len(self.cats_vocab)
 
         # index-to-token dict
         # <pad> : padding, used for padding the shorter sentences in a batch
         # to match the length of longest sentence in the batch
@@ -145,8 +157,9 @@ class TextCategoriesDataset(Dataset):
         cats = self.cats.iloc[idx]
 
         #print(self.textTransform()(text))
-        #print(cats)
-        #print(cats.fillna(0).values)
+        #print(type(cats.fillna(0).values.tolist()))
+        #print(cats.fillna(0).values.tolist())
+        #sys.exit(0)
 
         if self.transform:
             text, cats = self.transform(text, cats)
@@ -155,7 +168,7 @@ class TextCategoriesDataset(Dataset):
         # NaN to zeros and stripping the index
         return (
             self.textTransform()(text),
-            cats.fillna(0).values,
+            cats.fillna(0).values.tolist(),
         )
 
     def textTransform(self):
@@ -167,6 +180,8 @@ class TextCategoriesDataset(Dataset):
             # converts the sentences to indices based on given vocabulary using SentencePiece
             T.SentencePieceTokenizer(xlmr_spm_model_path),
             T.VocabTransform(torch.hub.load_state_dict_from_url(xlmr_vocab_path)),
+            #T.Truncate(self.text_length - 2), # XXX
+            T.Truncate(256 - 3), # XXX
             # Add <sos> at beginning of each sentence. 1 because the index
             # for <sos> in vocabulary is 1 as seen in previous section
             T.AddToken(self.stoi['<sos>'], begin=True),
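
Note: `self.text_length` (added earlier in this commit) is the maximum character count of a story, not a token count, which is presumably why the commented-out `T.Truncate(self.text_length - 2)` gave way to a hard-coded `T.Truncate(256 - 3)`: the cap has to be expressed in subword tokens and leave room for the special tokens appended afterwards. A hedged sketch of the equivalent torchtext pipeline, assuming XLM-R's conventional special-token ids (<s> = 0, </s> = 2) instead of the dataset's own `stoi` mapping used here:

    import torchtext.transforms as T
    from torch.hub import load_state_dict_from_url

    max_seq_len = 256
    text_transform = T.Sequential(
        T.SentencePieceTokenizer(
            "https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"),
        T.VocabTransform(load_state_dict_from_url(
            "https://download.pytorch.org/models/text/xlmr.vocab.pt")),
        T.Truncate(max_seq_len - 2),       # leave room for the two special tokens
        T.AddToken(token=0, begin=True),   # <s>
        T.AddToken(token=2, begin=False),  # </s>
    )

    print(text_transform("Subword tokenisation works for any of XLM-R's languages."))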
@@ -221,7 +236,6 @@ class CollateBatch:
         return (
             text_tensor,
             cats_tensor,
-            text_lengths,
         )
 
 def tensor2cat(dataset, tensor):
@@ -233,6 +247,7 @@ def tensor2cat(dataset, tensor):
             for idx, pred in enumerate(result):
                 if pred > 0: # XXX
                     chance[cats[idx]] = pred.item()
+            chance = dict(sorted(chance.items(), key=lambda x : x[1], reverse=True))
             batch.append(chance)
         return batch
     elif tensor.ndimension() == 1:
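
Note: the added `dict(sorted(...))` simply reorders each per-story result so the most confident categories come first. A toy example with made-up category names and scores:

    chance = {"politics": 0.12, "sport": 0.91, "health": 0.47}
    chance = dict(sorted(chance.items(), key=lambda x: x[1], reverse=True))
    print(chance)  # {'sport': 0.91, 'health': 0.47, 'politics': 0.12}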
@@ -242,24 +257,27 @@ def tensor2cat(dataset, tensor):
                 print(f"Idx {idx} not in {len(cats)} categories")
             elif pred > 0: # XXX
                 chance[cats[idx]] = pred.item()
+        chance = dict(sorted(chance.items(), key=lambda x : x[1], reverse=True))
         return chance
     else:
         raise ValueError("Only tensors with 1 dimension or batches with 2 dimensions are supported")
 
 def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
-    total_acc, total_count = 0, 0
+    total_acc, total_count = 0, 1 # XXX
     log_interval = 500
     torch.set_printoptions(precision=2)
 
+    model.train()
+
     batch = tqdm.tqdm(dataloader, unit="batch")
     for idx, data in enumerate(batch):
         batch.set_description(f"Train {epoch}.{idx}")
-        text, cats, text_lengths = data
+        text, cats = data
 
         optimizer.zero_grad()
-        output = model(text, text_lengths)
+        output = model(text)
         #print("output", output)
         #print("output shape", output.shape)
@@ -282,9 +300,9 @@ def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
             batch.clear()
             for target, out, pred in list(zip(cats, output, predictions)):
-                expect = tensor2cat(dataset.cats_vocab, target)
-                raw = tensor2cat(dataset.cats_vocab, out)
-                predict = tensor2cat(dataset.cats_vocab, pred)
+                expect = tensor2cat(dataset, target)
+                raw = tensor2cat(dataset, out)
+                predict = tensor2cat(dataset, pred)
                 print("Expected: ", expect)
                 print("Predicted: ", predict)
                 print("Raw output:", raw)
@@ -307,16 +325,17 @@ def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
 
 def evaluate(dataloader, dataset, model, criterion, epoch=0):
+    total_acc, total_count = 0, 1 # XXX
     model.eval()
-    total_acc, total_count = 0, 0
 
     with torch.no_grad():
         batch = tqdm.tqdm(dataloader, unit="batch")
         for idx, data in enumerate(batch):
             batch.set_description(f"Evaluate {epoch}.{idx}")
-            text, cats, text_lengths = data
+            text, cats = data
 
-            output = model(text, text_lengths)
+            output = model(text)
             #print("eval predicted", output)
 
             loss = criterion(output, cats)
@@ -328,9 +347,9 @@ def evaluate(dataloader, dataset, model, criterion, epoch=0):
                 batch.clear()
                 for target, out, pred in list(zip(cats, output, predictions)):
-                    expect = tensor2cat(dataset.cats_vocab, target)
-                    raw = tensor2cat(dataset.cats_vocab, out)
-                    predict = tensor2cat(dataset.cats_vocab, pred)
+                    expect = tensor2cat(dataset, target)
+                    raw = tensor2cat(dataset, out)
+                    predict = tensor2cat(dataset, pred)
                     print("Evaluate expected: ", expect)
                     print("Evaluate predicted: ", predict)
                     print("Evaluate raw output:", raw)
@@ -374,7 +393,10 @@ def main():
                         help='path of CSV file containing dataset')
     parser.add_argument('--model', '-m',
                         #required=True, # XXX
-                        help='path to training model')
+                        help='path to load training model')
+    parser.add_argument('--out', '-o',
+                        #required=True, # XXX
+                        help='path to save training model')
     parser.add_argument('--verbose', '-v',
                         type=int, nargs='?',
                         const=1, # Default value if -v is supplied
@@ -386,7 +408,10 @@ def main():
         print("ERROR: train or classify data")
         sys.exit(1)
 
-    if args.action == 'classify' and s.path.isfile(model_storage) is None:
+    model_in = args.model
+    model_out = args.out
+
+    if args.action == 'classify' and (model_in is None or os.path.isfile(model_in) is None):
         print("No model found for classification; running training instead")
         args.action = 'train'
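
Note: `os.path.isfile()` returns a bool, never `None`, so the `is None` comparison in the new condition can never be true; as written, the fallback to training only triggers when `--model` is omitted entirely. A small sketch of the check as presumably intended (`model_in` is a hypothetical path):

    import os

    model_in = "does-not-exist.pt"             # hypothetical path
    print(os.path.isfile(model_in) is None)    # False even though the file is missing
    print(not os.path.isfile(model_in))        # True: the usual absence test

    if model_in is None or not os.path.isfile(model_in):
        print("No model found for classification; running training instead")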
@@ -423,29 +448,34 @@ def main():
     #print("-" * 20)
     #for text, cat in enumerate(valid_dataset):
     #    print(text, cat)
-    #print(tensor2cat(train_dataset, torch.tensor([0, 0, 0, 1., 0.9])))
+    #print(tensor2cat(train_dataset, torch.tensor([0, 0, 0, 1., 0.9, 1, 0.5, .6])))
     #sys.exit(0)
 
+    # Make everything a bit more reproducible
+    seed_everything(111)
+
     # Get cpu, gpu or mps device for training.
     # Move tensor to the NVIDIA GPU if available
     device = (
-        "cuda" if torch.cuda.is_available()
-        else "xps" if hasattr(torch, "xpu") and torch.xpu.is_available()
-        else "mps" if torch.backends.mps.is_available()
+        xm.xla_device() if tpu_available # google
+        else "cuda" if torch.cuda.is_available() # nvidia
+        else "xps" if hasattr(torch, "xpu") and torch.xpu.is_available() # intel
+        else "mps" if torch.backends.mps.is_available() # mac
         else "cpu"
     )
     print(f"Using {device} device")
 
     # Hyperparameters
     #epochs = 10 # epoch
-    epochs = 4 # epoch
+    epochs = 6 # epoch
+    #epochs = 4 # epoch
     #lr = 5 # learning rate
     #lr = 0.5
     #lr = 0.05
     #lr = 0.005 # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
     lr = 0.0001
-    batch_size = 64 # batch size for training
-    #batch_size = 16 # batch size for training
+    #batch_size = 64 # batch size for training
+    batch_size = 16 # batch size for training
     #batch_size = 8 # batch size for training
     #batch_size = 4 # batch size for training
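
Note: `seed_everything` is called here but not defined in this hunk; if it is not provided elsewhere in the script, a typical implementation (an assumption, shown only for completeness) seeds the usual RNG sources:

    import os
    import random

    import numpy as np
    import torch

    def seed_everything(seed: int) -> None:
        """Seed the common RNG sources so runs are (mostly) repeatable."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        os.environ["PYTHONHASHSEED"] = str(seed)

    seed_everything(111)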
@@ -460,10 +490,10 @@ def main():
     '''
     dataloader = DataLoader(dataset,
-            batch_size=4,
+            batch_size=batch_size,
             drop_last=True,
             shuffle=True,
-            num_workers=0,
+            num_workers=4,
             collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
             )
     '''
@@ -471,48 +501,44 @@ def main():
             batch_size=batch_size,
             drop_last=True,
             shuffle=True,
-            num_workers=0,
+            num_workers=4,
             collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
             )
     valid_dataloader = DataLoader(valid_dataset,
             batch_size=batch_size,
             drop_last=True,
             shuffle=True,
-            num_workers=0,
+            num_workers=4,
             collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
             )
 
     #for i_batch, sample_batched in enumerate(dataloader):
     #    print(i_batch, sample_batched[0], sample_batched[1])
-    for i_batch, sample_batched in enumerate(train_dataloader):
-        print(i_batch, sample_batched[0], sample_batched[1])
-    sys.exit(0)
+    #for i_batch, sample_batched in enumerate(train_dataloader):
+    #    print(i_batch, sample_batched[0], sample_batched[1])
+    #sys.exit(0)
 
-    input_size = len(train_dataset.text_vocab)
-    output_size = len(train_dataset.cats_vocab) # every output item is the likelihood of a particular category
-    embed = torch.empty(input_size, len(train_dataset)) # tokens per sample x samples
-    embedding_size = embed.size(1) # was 64 (should be: samples)
+    #input_size = len(train_dataset.text_vocab)
+    #output_size = len(train_dataset.cats_vocab) # every output item is the likelihood of a particular category
+    #embed = torch.empty(input_size, len(train_dataset)) # tokens per sample x samples
+    #embedding_size = embed.size(1) # was 64 (should be: samples)
+    #input_size = train_dataset.text_length
+    input_size = 768
+    output_size = train_dataset.num_cats
 
     if args.verbose:
         #for i in train_dataset.text_vocab.get_itos():
         #    print(i)
         print("input_size: ", input_size)
         print("output_size:", output_size)
-        print("embed shape:", embed.shape)
-        print("embedding_size:", embedding_size, " (that is, number of samples)")
+        #print("embed shape:", embed.shape)
+        #print("embedding_size:", embedding_size, " (that is, number of samples)")
 
-    model = RNN(
-        #rnn_model='GRU',
-        rnn_model='LSTM',
-        vocab_size=input_size,
-        embed_size=embedding_size,
-        num_output=output_size,
-        use_last=(not mean_seq),
-        hidden_size=hidden_size,
-        embedding_tensor=embed,
-        num_layers=num_layers,
-        batch_first=True
-    )
+    classifier_head = RobertaClassificationHead(num_classes=output_size, input_dim=input_size)
+    model = XLMR_BASE_ENCODER.get_model(head=classifier_head)
+    if model_in is not None and os.path.isfile(model_in):
+        model.load_state_dict(torch.load(model_in))
+    model.to(device)
 
     if args.verbose:
         print(model)
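
Note: `input_size = 768` is the hidden size of the XLM-R base encoder, so `RobertaClassificationHead(num_classes=output_size, input_dim=input_size)` projects the encoder's sentence representation down to one logit per category. A hedged sketch of the construction in isolation, with a dummy forward pass to show the output shape; freezing the encoder is an optional variation, not something this commit does:

    import torch
    from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER

    num_cats = 5                              # stands in for train_dataset.num_cats
    head = RobertaClassificationHead(num_classes=num_cats, input_dim=768)
    model = XLMR_BASE_ENCODER.get_model(head=head)

    # Optional variation: fine-tune only the new head and keep the encoder fixed.
    # model = XLMR_BASE_ENCODER.get_model(head=head, freeze_encoder=True)

    tokens = torch.randint(3, 1000, (2, 16))  # [batch, seq_len] dummy token ids
    print(model(tokens).shape)                # torch.Size([2, 5])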
@@ -530,11 +556,6 @@ def main():
     for epoch in e:
         e.set_description(f"Epoch {epoch}")
 
-        train_dataset.to(device)
-        valid_dataset.to(device)
-        model.to(device)
-
-        model.train()
         train(train_dataloader, train_dataset, model, optimizer, criterion, epoch)
         accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion, epoch)
@@ -544,13 +565,16 @@ def main():
         else:
             total_accu = accu_val
         e.set_postfix({
-            "accuracy": accu_val.int().item(),
+            "accuracy": accu_val,
         })
 
     # print("Checking the results of test dataset.")
     # accu_test = evaluate(test_dataloader, test_dataset)
     # print("test accuracy {:8.3f}".format(accu_test))
 
+    if model_out is not None:
+        torch.save(model.state_dict(), model_out)
+
     return
 
 if __name__ == "__main__":
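
Note: with the new `--out` flag the fine-tuned weights are written with `torch.save(model.state_dict(), ...)` after training, and `--model` restores them into a freshly built model before classification. A small sketch of that round trip (paths are hypothetical):

    import torch
    from torchtext.models import RobertaClassificationHead, XLMR_BASE_ENCODER

    def build_model(num_cats: int):
        head = RobertaClassificationHead(num_classes=num_cats, input_dim=768)
        return XLMR_BASE_ENCODER.get_model(head=head)

    model = build_model(num_cats=5)
    torch.save(model.state_dict(), "model_out.pt")        # what --out does

    restored = build_model(num_cats=5)
    restored.load_state_dict(torch.load("model_out.pt"))  # what --model does
    restored.eval()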