diff --git a/categorise.py b/categorise.py index 6428a86..537bc36 100755 --- a/categorise.py +++ b/categorise.py @@ -24,7 +24,7 @@ from torchtext.vocab import build_vocab_from_iterator from torch.utils.data import Dataset, DataLoader parser = argparse.ArgumentParser( - description='Classify text data according to categories', + description='Classify text data according to categories', add_help=True, ) parser.add_argument('action', help='train or classify') @@ -64,27 +64,22 @@ data.dropna(axis='index', inplace=True) #sys.exit(0) ''' -####################################################### -# Create Training and Validation sets -####################################################### - -# create a list of ints till len of data + Create Training and Validation sets +''' +# Create a list of ints till len of data data_idx = list(range(len(data))) np.random.shuffle(data_idx) -# get indexes for validation and train -val_frac = 0.1 # precentage of data in validation set -val_split_idx = int(len(data)*val_frac) # index on which to split (10% of data) -val_idx, train_idx = data_idx[:val_split_idx], data_idx[val_split_idx:] -print('len of train: ', len(train_idx)) -print('len of val: ', len(val_idx)) +# Get indexes for validation and train +split_percent = 0.95 +num_train = int(len(data) * split_percent) +valid_idx, train_idx = data_idx[num_train:], data_idx[:num_train] +print("Length of train_data: {}".format(len(train_idx))) +print("Length of valid_data: {}".format(len(valid_idx))) -# create the training and validation sets, as dataframes -train_data = data.iloc[train_idx].reset_index().drop('index',axis=1) -valid_data = data.iloc[val_idx].reset_index().drop('index',axis=1) - -# Next, we create Pytorch Datasets and Dataloaders for these dataframes -''' +# Create the training and validation sets, as dataframes +train_data = data.iloc[train_idx].reset_index().drop('index', axis=1) +valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1) ''' @@ -118,7 +113,7 @@ 
class TextCategoriesDataset(Dataset): # replaced by this token self.itos = {0: '', 1:'', 2:'', 3: ''} # token-to-index dict - self.stoi = {k:j for j,k in self.itos.items()} + self.stoi = {k:j for j, k in self.itos.items()} # Create vocabularies upon initialisation self.text_vocab = build_vocab_from_iterator( @@ -143,6 +138,10 @@ class TextCategoriesDataset(Dataset): return len(self.df) def __getitem__(self, idx): + # Enable use as a plain iterator + if idx not in self.df.index: + raise IndexError(idx) + if torch.is_tensor(idx): idx = idx.tolist() @@ -187,6 +186,7 @@ class TextCategoriesDataset(Dataset): T.AddToken(2, begin=False) ) +''' dataset = TextCategoriesDataset(df=data, text_column="content", cats_column="categories", @@ -200,9 +200,8 @@ valid_dataset = TextCategoriesDataset(df=valid_data, text_column="content", cats_column="categories", ) -''' #print(dataset[2]) -#for text, cat in dataset: +#for text, cat in enumerate(valid_dataset): # print(text, cat) #sys.exit(0) @@ -212,7 +211,7 @@ valid_dataset = TextCategoriesDataset(df=valid_data, which can batch, shuffle, and load the data in parallel ''' -class Collate: +class CollateBatch: ''' We need to pad shorter sentences in a batch to make all the sequences in a batch of equal length. We can do this a collate_fn callback class, @@ -220,37 +219,55 @@ class Collate: ''' def __init__(self, pad_idx): self.pad_idx = pad_idx - + def __call__(self, batch): # T.ToTensor(0) returns a transform that converts the sequence # to a torch.tensor and also applies padding. - # pad_idx is passed to the constructor to specify the - # index of the "" token in the vocabulary. + # + # pad_idx is passed to the constructor to specify the index of + # the "" token in the vocabulary. return ( T.ToTensor(self.pad_idx)(list(batch[0])), T.ToTensor(self.pad_idx)(list(batch[1])), ) + +# Hyperparameters +EPOCHS = 10 # epoch +LR = 5 # learning rate +BATCH_SIZE = 64 # batch size for training + +# Get cpu, gpu or mps device for training. 
+# Prefer CUDA, then Intel XPU, then Apple MPS; fall back to CPU +device = ( + "cuda" if torch.cuda.is_available() + else "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() + else "mps" if torch.backends.mps.is_available() + else "cpu" +) +print(f"Using {device} device") + + +''' dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0, - collate_fn=Collate(pad_idx=dataset.stoi['']), + collate_fn=CollateBatch(pad_idx=dataset.stoi['']), ) ''' train_dataloader = DataLoader(train_dataset, - batch_size=4, + batch_size=BATCH_SIZE, shuffle=True, num_workers=0, - collate_fn=Collate(pad_idx=dataset.stoi['']), + collate_fn=CollateBatch(pad_idx=train_dataset.stoi['']), ) valid_dataloader = DataLoader(valid_dataset, - batch_size=4, + batch_size=BATCH_SIZE, shuffle=True, num_workers=0, - collate_fn=Collate(pad_idx=dataset.stoi['']), + collate_fn=CollateBatch(pad_idx=valid_dataset.stoi['']), ) -''' #for i_batch, sample_batched in enumerate(dataloader): # print(i_batch, sample_batched[0], sample_batched[1]) #sys.exit(0)