Add possible split between training and validation data

Timothy Allen 2023-11-30 02:00:56 +02:00
parent da6f0142e0
commit 235c58f3c5
1 changed file with 49 additions and 3 deletions


@@ -63,6 +63,30 @@ data.dropna(axis='index', inplace=True)
#print(data)
#sys.exit(0)
'''
#######################################################
# Create Training and Validation sets
#######################################################
# create a list of row indices from 0 to len(data) - 1
data_idx = list(range(len(data)))
np.random.shuffle(data_idx)
# get indices for the validation and training sets
val_frac = 0.1 # fraction of the data held out for validation
val_split_idx = int(len(data)*val_frac) # index at which to split (10% of data)
val_idx, train_idx = data_idx[:val_split_idx], data_idx[val_split_idx:]
print('len of train: ', len(train_idx))
print('len of val: ', len(val_idx))
# create the training and validation sets, as dataframes
train_data = data.iloc[train_idx].reset_index(drop=True)
valid_data = data.iloc[val_idx].reset_index(drop=True)
# Next, we create PyTorch Datasets and DataLoaders for these DataFrames
'''
'''
Create a dataset that builds a tokenised vocabulary,
and then, as each row is accessed, transforms it into
@@ -167,6 +191,16 @@ dataset = TextCategoriesDataset(df=data,
                                text_column="content",
                                cats_column="categories",
                                )
'''
train_dataset = TextCategoriesDataset(df=train_data,
                                      text_column="content",
                                      cats_column="categories",
                                      )
valid_dataset = TextCategoriesDataset(df=valid_data,
                                      text_column="content",
                                      cats_column="categories",
                                      )
'''
#print(dataset[2])
#for text, cat in dataset:
# print(text, cat)
@@ -197,14 +231,26 @@ class Collate:
            T.ToTensor(self.pad_idx)(list(batch[1])),
        )
pad_idx = dataset.stoi['<pad>']
dataloader = DataLoader(dataset,
                        batch_size=4,
                        shuffle=True,
                        num_workers=0,
                        collate_fn=Collate(pad_idx=pad_idx),
                        collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
                        )
'''
train_dataloader = DataLoader(train_dataset,
                              batch_size=4,
                              shuffle=True,
                              num_workers=0,
                              collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
                              )
valid_dataloader = DataLoader(valid_dataset,
                              batch_size=4,
                              shuffle=True,
                              num_workers=0,
                              collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
                              )
'''
#for i_batch, sample_batched in enumerate(dataloader):
# print(i_batch, sample_batched[0], sample_batched[1])
#sys.exit(0)
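
For reference, here is a minimal standalone sketch of the index-based split this commit introduces (added commented out between ''' markers). A toy DataFrame stands in for the file's data variable (an assumption, for illustration only); the file's TextCategoriesDataset and Collate classes are not reproduced:

import numpy as np
import pandas as pd

# Toy stand-in for the file's `data` DataFrame (illustration only)
data = pd.DataFrame({
    "content": [f"document {i}" for i in range(100)],
    "categories": np.random.choice(["a", "b"], size=100),
})

# Shuffle the row indices, then reserve the first 10% for validation
data_idx = list(range(len(data)))
np.random.shuffle(data_idx)
val_frac = 0.1
val_split_idx = int(len(data) * val_frac)
val_idx, train_idx = data_idx[:val_split_idx], data_idx[val_split_idx:]

# Materialise the splits as DataFrames with clean 0..n-1 indices
train_data = data.iloc[train_idx].reset_index(drop=True)
valid_data = data.iloc[val_idx].reset_index(drop=True)

print('len of train: ', len(train_data))  # 90
print('len of val: ', len(valid_data))    # 10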
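
One detail worth flagging: because train_dataset and valid_dataset are each constructed from their own DataFrame, each would build its own vocabulary, while both new dataloaders still take pad_idx from the original full-data dataset. An alternative worth considering (an assumption, not what the commit does) is to split at the Dataset level with PyTorch's random_split, which keeps a single shared vocabulary; note it returns Subset objects rather than DataFrames:

import torch
from torch.utils.data import random_split

# Hypothetical alternative: build one dataset (and thus one vocabulary)
# over the full DataFrame, then split 90/10 at the Dataset level.
# Fractional lengths are accepted in recent PyTorch releases (>= 1.13).
generator = torch.Generator().manual_seed(42)
train_dataset, valid_dataset = random_split(dataset, [0.9, 0.1], generator=generator)
pad_idx = dataset.stoi['<pad>']  # the vocabulary lives on the parent dataset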