From 31319bab0c0f9fbd65ec75bfbbd37c156cd2e678 Mon Sep 17 00:00:00 2001
From: tim <tim@treehouse.org.za>
Date: Thu, 30 Nov 2023 02:00:56 +0200
Subject: [PATCH] Add optional split of the data into training and validation sets

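The split is added commented out, next to the existing single-dataset
path: the row indices are shuffled, 10% are held out for validation,
and separate TextCategoriesDataset / DataLoader pairs are built for
the two parts. Uncommenting the new blocks (and commenting out the
originals) switches the script over.

A rough sketch of how the two loaders would be consumed once enabled
(the epoch loop and num_epochs are illustrative, not part of this
patch; batches unpack into two tensors, as the existing Collate class
returns a pair):

    for epoch in range(num_epochs):
        for texts, cats in train_dataloader:
            ...  # forward/backward pass on training batches
        for texts, cats in valid_dataloader:
            ...  # evaluation only, no gradient updates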
---
 categorise.py | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 53 insertions(+), 3 deletions(-)

diff --git a/categorise.py b/categorise.py
index 2834bd6..6428a86 100755
--- a/categorise.py
+++ b/categorise.py
@@ -63,6 +63,32 @@ data.dropna(axis='index', inplace=True)
 #print(data)
 #sys.exit(0)
 
+'''
+#######################################################
+#               Create Training and Validation sets
+#######################################################
+
+# create a list of all row indices, then shuffle it
+data_idx = list(range(len(data)))
+np.random.shuffle(data_idx)
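+# note: assumes numpy is already imported as np earlier in this script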
+
+# split the shuffled indices into validation and training parts
+val_frac = 0.1 # fraction of data in the validation set
+val_split_idx = int(len(data)*val_frac) # index at which to split (10% of data)
+val_idx, train_idx = data_idx[:val_split_idx], data_idx[val_split_idx:]
+print('len of train: ', len(train_idx))
+print('len of val: ', len(val_idx))
+
+# create the training and validation sets as dataframes
+train_data = data.iloc[train_idx].reset_index(drop=True)
+valid_data = data.iloc[val_idx].reset_index(drop=True)
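+# for a reproducible split, seed numpy's RNG before the shuffle above (e.g. np.random.seed(0))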
+
+# Next, we create PyTorch Datasets and DataLoaders for these dataframes
+'''
+
+
 '''
   Create a dataset that builds a tokenised vocabulary,
   and then, as each row is accessed, transforms it into
@@ -167,6 +193,17 @@ dataset = TextCategoriesDataset(df=data,
   text_column="content",
   cats_column="categories",
 )
+'''
+train_dataset = TextCategoriesDataset(df=train_data,
+  text_column="content",
+  cats_column="categories",
+)
+valid_dataset = TextCategoriesDataset(df=valid_data,
+  text_column="content",
+  cats_column="categories",
+)
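+# caution: each TextCategoriesDataset builds its own vocabulary from the df it is given, so the train and validation vocabs may differ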
+'''
 #print(dataset[2])
 #for text, cat in dataset:
 #  print(text, cat)
@@ -197,14 +234,27 @@ class Collate:
       T.ToTensor(self.pad_idx)(list(batch[1])),
     )
 
-
-pad_idx = dataset.stoi['<pad>']
 dataloader = DataLoader(dataset,
   batch_size=4,
   shuffle=True,
   num_workers=0,
-  collate_fn=Collate(pad_idx=pad_idx),
+  collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
 )
+'''
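+# note: pad_idx is taken from the full dataset's vocab for both loaders; this assumes '<pad>' maps to the same index in each split's own vocab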
+train_dataloader = DataLoader(train_dataset,
+  batch_size=4,
+  shuffle=True,
+  num_workers=0,
+  collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
+)
+valid_dataloader = DataLoader(valid_dataset,
+  batch_size=4,
+  shuffle=True,
+  num_workers=0,
+  collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
+)
+'''
 #for i_batch, sample_batched in enumerate(dataloader):
 #  print(i_batch, sample_batched[0], sample_batched[1])
 #sys.exit(0)