diff --git a/categorise.py b/categorise.py
index 2834bd6..6428a86 100755
--- a/categorise.py
+++ b/categorise.py
@@ -63,6 +63,30 @@
 data.dropna(axis='index', inplace=True)
 #print(data)
 #sys.exit(0)
+'''
+#######################################################
+# Create Training and Validation sets
+#######################################################
+
+# create a list of ints till len of data
+data_idx = list(range(len(data)))
+np.random.shuffle(data_idx)
+
+# get indexes for validation and train
+val_frac = 0.1   # fraction of data in the validation set
+val_split_idx = int(len(data)*val_frac)   # index on which to split (10% of data)
+val_idx, train_idx = data_idx[:val_split_idx], data_idx[val_split_idx:]
+print('len of train: ', len(train_idx))
+print('len of val: ', len(val_idx))
+
+# create the training and validation sets, as dataframes
+train_data = data.iloc[train_idx].reset_index().drop('index', axis=1)
+valid_data = data.iloc[val_idx].reset_index().drop('index', axis=1)
+
+# Next, we create PyTorch Datasets and DataLoaders for these dataframes
+'''
+
+
 '''
 Create a dataset that builds a tokenised vocabulary,
 and then, as each row is accessed, transforms it into
@@ -167,6 +191,16 @@ dataset = TextCategoriesDataset(df=data,
     text_column="content",
     cats_column="categories",
 )
+'''
+train_dataset = TextCategoriesDataset(df=train_data,
+    text_column="content",
+    cats_column="categories",
+)
+valid_dataset = TextCategoriesDataset(df=valid_data,
+    text_column="content",
+    cats_column="categories",
+)
+'''
 #print(dataset[2])
 #for text, cat in dataset:
 #    print(text, cat)
@@ -197,14 +231,26 @@ class Collate:
             T.ToTensor(self.pad_idx)(list(batch[1])),
         )
 
-
-pad_idx = dataset.stoi['<pad>']
 dataloader = DataLoader(dataset,
     batch_size=4,
     shuffle=True,
     num_workers=0,
-    collate_fn=Collate(pad_idx=pad_idx),
+    collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
 )
+'''
+train_dataloader = DataLoader(train_dataset,
+    batch_size=4,
+    shuffle=True,
+    num_workers=0,
+    collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
+)
+valid_dataloader = DataLoader(valid_dataset,
+    batch_size=4,
+    shuffle=True,
+    num_workers=0,
+    collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
+)
+'''
 #for i_batch, sample_batched in enumerate(dataloader):
 #    print(i_batch, sample_batched[0], sample_batched[1])
 #sys.exit(0)
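Once the quoted-out blocks above are enabled (and the single `dataset`/`dataloader` pair retired), the split can be sanity-checked in the same style as the existing commented-out debug loop. This is only a sketch, assuming `train_dataloader` and `valid_dataloader` from the block above exist at module level in categorise.py:

```python
# Quick smoke test for the train/validation loaders (sketch; assumes the
# commented-out blocks in categorise.py have been uncommented, so
# train_dataloader and valid_dataloader are defined).
for name, loader in [('train', train_dataloader), ('valid', valid_dataloader)]:
    texts, cats = next(iter(loader))       # one padded batch from Collate
    print(name, texts.shape, cats.shape)   # batch dimension should be 4
```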