Clean up some minor issues (like iterating over the DataSet) & simplify
parent 235c58f3c5
commit 701c28353d
@@ -24,7 +24,7 @@ from torchtext.vocab import build_vocab_from_iterator
 from torch.utils.data import Dataset, DataLoader
 
 parser = argparse.ArgumentParser(
     description='Classify text data according to categories',
     add_help=True,
 )
 parser.add_argument('action', help='train or classify')
@@ -64,27 +64,22 @@ data.dropna(axis='index', inplace=True)
 #sys.exit(0)
 
 '''
-#######################################################
-# Create Training and Validation sets
-#######################################################
-
-# create a list of ints till len of data
+Create Training and Validation sets
+'''
+# Create a list of ints till len of data
+
 data_idx = list(range(len(data)))
 np.random.shuffle(data_idx)
 
-# get indexes for validation and train
-val_frac = 0.1 # precentage of data in validation set
-val_split_idx = int(len(data)*val_frac) # index on which to split (10% of data)
-val_idx, train_idx = data_idx[:val_split_idx], data_idx[val_split_idx:]
-print('len of train: ', len(train_idx))
-print('len of val: ', len(val_idx))
+# Get indexes for validation and train
+split_percent = 0.95
+num_train = int(len(data) * split_percent)
+valid_idx, train_idx = data_idx[num_train:], data_idx[:num_train]
+print("Length of train_data: {}".format(len(train_idx)))
+print("Length of valid_data: {}".format(len(valid_idx)))
 
-# create the training and validation sets, as dataframes
-train_data = data.iloc[train_idx].reset_index().drop('index',axis=1)
-valid_data = data.iloc[val_idx].reset_index().drop('index',axis=1)
+# Create the training and validation sets, as dataframes
+train_data = data.iloc[train_idx].reset_index().drop('index', axis=1)
+valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
 
-# Next, we create Pytorch Datasets and Dataloaders for these dataframes
-'''
-
-
 '''
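For reference, the new split logic in this hunk reads as the following minimal sketch. The toy DataFrame is a placeholder, not data from the repository, and reset_index(drop=True) is used as an equivalent of .reset_index().drop('index', axis=1).

import numpy as np
import pandas as pd

# Placeholder frame standing in for the loaded `data`
data = pd.DataFrame({"content": ["a", "b", "c", "d", "e"],
                     "categories": ["x", "y", "x", "y", "x"]})

# Shuffle the row positions once, then slice: first 95% train, remainder validation
data_idx = list(range(len(data)))
np.random.shuffle(data_idx)
split_percent = 0.95
num_train = int(len(data) * split_percent)
train_idx, valid_idx = data_idx[:num_train], data_idx[num_train:]

train_data = data.iloc[train_idx].reset_index(drop=True)
valid_data = data.iloc[valid_idx].reset_index(drop=True)
print(len(train_data), len(valid_data))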
@@ -118,7 +113,7 @@ class TextCategoriesDataset(Dataset):
         # replaced by this token
         self.itos = {0: '<pad>', 1:'<sos>', 2:'<eos>', 3: '<unk>'}
         # token-to-index dict
-        self.stoi = {k:j for j,k in self.itos.items()}
+        self.stoi = {k:j for j, k in self.itos.items()}
 
         # Create vocabularies upon initialisation
         self.text_vocab = build_vocab_from_iterator(
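The itos/stoi dicts above pin the special tokens to indices 0-3. A hedged, self-contained sketch of how a vocabulary is typically built with build_vocab_from_iterator so that those specials keep these indices; the tokenizer and sample texts here are assumptions, not repository code.

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')
texts = ["a tiny example sentence", "another tiny example"]

def yield_tokens(rows):
    for row in rows:
        yield tokenizer(row)

# special_first=True keeps <pad>/<sos>/<eos>/<unk> at indices 0-3,
# matching the itos mapping above
vocab = build_vocab_from_iterator(yield_tokens(texts),
                                  specials=['<pad>', '<sos>', '<eos>', '<unk>'],
                                  special_first=True)
vocab.set_default_index(vocab['<unk>'])  # unknown tokens fall back to <unk>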
@@ -143,6 +138,10 @@ class TextCategoriesDataset(Dataset):
         return len(self.df)
 
     def __getitem__(self, idx):
+        # Enable use as a plain iterator
+        if idx not in self.df.index:
+            raise(StopIteration)
+
         if torch.is_tensor(idx):
             idx = idx.tolist()
 
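The added guard is what makes the commented-out "for text, cat in dataset" style of iteration work: with no __iter__ defined, Python falls back to calling __getitem__ with 0, 1, 2, ... until an exception ends the loop. A minimal, hypothetical Dataset showing the same pattern:

import torch
from torch.utils.data import Dataset

class ToyDataset(Dataset):
    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        # Analogous to the guard added above: end plain iteration past the last row
        if idx >= len(self.items):
            raise StopIteration
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.items[idx]

for x in ToyDataset([10, 20, 30]):
    print(x)  # 10, 20, 30, then the loop stops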
@@ -187,6 +186,7 @@ class TextCategoriesDataset(Dataset):
             T.AddToken(2, begin=False)
         )
 
+'''
 dataset = TextCategoriesDataset(df=data,
                                 text_column="content",
                                 cats_column="categories",
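The T.AddToken(2, begin=False) line above is the tail of a torchtext transform pipeline. A self-contained sketch of such a pipeline follows; the tiny vocabulary is hypothetical, and indices 1 and 2 correspond to <sos> and <eos> as in the itos mapping.

from collections import OrderedDict
import torchtext.transforms as T
from torchtext.vocab import vocab as make_vocab

# Tiny stand-in vocabulary; specials occupy indices 0-3 as in self.itos
v = make_vocab(OrderedDict([("hello", 1), ("world", 1)]),
               specials=['<pad>', '<sos>', '<eos>', '<unk>'])
v.set_default_index(v['<unk>'])

text_transform = T.Sequential(
    T.VocabTransform(v),         # tokens -> vocabulary indices
    T.AddToken(1, begin=True),   # prepend <sos>
    T.AddToken(2, begin=False),  # append <eos>
)
print(text_transform(["hello", "world"]))  # [1, 4, 5, 2]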
@@ -200,9 +200,8 @@ valid_dataset = TextCategoriesDataset(df=valid_data,
                                       text_column="content",
                                       cats_column="categories",
                                       )
-'''
 #print(dataset[2])
-#for text, cat in dataset:
+#for text, cat in enumerate(valid_dataset):
 # print(text, cat)
 #sys.exit(0)
 
@@ -212,7 +211,7 @@ valid_dataset = TextCategoriesDataset(df=valid_data,
 which can batch, shuffle, and load the data in parallel
 '''
 
-class Collate:
+class CollateBatch:
     '''
     We need to pad shorter sentences in a batch to make all the sequences
     in a batch of equal length. We can do this a collate_fn callback class,
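The docstring above motivates the collate_fn; the behaviour it relies on (shown in the next hunk) is that torchtext's ToTensor transform pads a batch of variable-length id sequences to a common length. A quick illustration with made-up id sequences:

import torchtext.transforms as T

batch_ids = [[1, 4, 5, 2], [1, 4, 2]]   # sequences of unequal length
padded = T.ToTensor(0)(batch_ids)       # 0 assumed to be the <pad> index
print(padded)
# tensor([[1, 4, 5, 2],
#         [1, 4, 2, 0]])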
@@ -220,37 +219,55 @@ class Collate:
     '''
     def __init__(self, pad_idx):
         self.pad_idx = pad_idx
 
     def __call__(self, batch):
         # T.ToTensor(0) returns a transform that converts the sequence
         # to a torch.tensor and also applies padding.
-        # pad_idx is passed to the constructor to specify the
-        # index of the "<pad>" token in the vocabulary.
+        #
+        # pad_idx is passed to the constructor to specify the index of
+        # the "<pad>" token in the vocabulary.
         return (
             T.ToTensor(self.pad_idx)(list(batch[0])),
             T.ToTensor(self.pad_idx)(list(batch[1])),
         )
 
+
+# Hyperparameters
+EPOCHS = 10 # epoch
+LR = 5 # learning rate
+BATCH_SIZE = 64 # batch size for training
+
+# Get cpu, gpu or mps device for training.
+# Move tensor to the NVIDIA GPU if available
+device = (
+    "cuda" if torch.cuda.is_available()
+    else "xps" if hasattr(torch, "xpu") and torch.xpu.is_available()
+    else "mps" if torch.backends.mps.is_available()
+    else "cpu"
+)
+print(f"Using {device} device")
+
+
+'''
 dataloader = DataLoader(dataset,
                         batch_size=4,
                         shuffle=True,
                         num_workers=0,
-                        collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
+                        collate_fn=CollateBatch(pad_idx=dataset.stoi['<pad>']),
                         )
 '''
 train_dataloader = DataLoader(train_dataset,
-                              batch_size=4,
+                              batch_size=BATCH_SIZE,
                               shuffle=True,
                               num_workers=0,
-                              collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
+                              collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
                               )
 valid_dataloader = DataLoader(valid_dataset,
-                              batch_size=4,
+                              batch_size=BATCH_SIZE,
                               shuffle=True,
                               num_workers=0,
-                              collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
+                              collate_fn=CollateBatch(pad_idx=valid_dataset.stoi['<pad>']),
                               )
-'''
 #for i_batch, sample_batched in enumerate(dataloader):
 # print(i_batch, sample_batched[0], sample_batched[1])
 #sys.exit(0)
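The device-selection chain added above can also be written as a small reusable helper. This is a hedged sketch, not part of the commit; note that PyTorch's Intel GPU backend uses the device string "xpu", while the literal "xps" in the diff would not resolve to a device.

import torch

def pick_device() -> str:
    # Same precedence as the expression above: CUDA, then Intel XPU, then Apple MPS, then CPU
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

device = pick_device()
print(f"Using {device} device")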