From da6f0142e0dd7dfcac0393e9e3cd6d3fce353194 Mon Sep 17 00:00:00 2001 From: tim Date: Thu, 30 Nov 2023 01:53:49 +0200 Subject: [PATCH] First pass at imbibing a CSV of data and turning it into a dataset, and thence into a dataloader --- categorise.py | 212 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 212 insertions(+) create mode 100755 categorise.py diff --git a/categorise.py b/categorise.py new file mode 100755 index 0000000..2834bd6 --- /dev/null +++ b/categorise.py @@ -0,0 +1,212 @@ +#!/usr/bin/python + +import argparse +import os +import sys +import pprint +import re +import string +import warnings + +#data manupulation libs +import csv +import random +import pandas as pd +import numpy as np +#from pandarallel import pandarallel +from tqdm import tqdm + +#torch libs +import torch +import torchdata.datapipes as dp +import torchtext.transforms as T +from torchtext.vocab import build_vocab_from_iterator +from torch.utils.data import Dataset, DataLoader + +parser = argparse.ArgumentParser( + description='Classify text data according to categories', + add_help=True, +) +parser.add_argument('action', help='train or classify') +parser.add_argument('--input', '-i', required=True, help='path of CSV file containing dataset') +parser.add_argument('--output', '-o', help='path to trained model') +args = parser.parse_args() + +if args.action != 'train' and args.action != 'classify': + print("ERROR: train or classify data") + sys.exit(1) + +if args.action == 'classify' and s.path.isfile(model_storage) is None: + print("No model found for classification; running training instead") + args.action = 'train' + +if os.path.isfile(args.input) is False: + print(f"{args.input} is not a valid file") + sys.exit(1) + +#with open(args.input, 'r', encoding="utf-8") as f: +# data = pd.read_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL) + +with open(args.input, 'r', encoding="utf-8") as f: + data = pd.concat( + [chunk for chunk in tqdm( + pd.read_csv(f, + encoding="utf-8", + quoting=csv.QUOTE_ALL, + nrows=200, ## XXX + chunksize=100), + desc='Loading data' + )]) + +data.dropna(axis='index', inplace=True) + +#print(data) +#sys.exit(0) + +''' + Create a dataset that builds a tokenised vocabulary, + and then, as each row is accessed, transforms it into +''' +class TextCategoriesDataset(Dataset): + ''' Dataset of Text and Categories ''' + def __init__(self, df, text_column, cats_column, transform=None): + ''' + Arguments: + df (panda.Dataframe): csv content, loaded as dataframe + text_column (str): the name of the column containing the text + cats_column (str): the name of the column containing + semicolon-separated categories + transform (callable, optional): Optional transform to be + applied on a sample. + ''' + self.df = df + self.transform = transform + + self.texts = self.df[text_column] + self.cats = self.df[cats_column] + + # index-to-token dict + # : padding, used for padding the shorter sentences in a batch + # to match the length of longest sentence in the batch + # : start of sentence token + # : end of sentence token + # : unknown token: words which are not found in the vocab are + # replaced by this token + self.itos = {0: '', 1:'', 2:'', 3: ''} + # token-to-index dict + self.stoi = {k:j for j,k in self.itos.items()} + + # Create vocabularies upon initialisation + self.text_vocab = build_vocab_from_iterator( + [self.textTokens(text) for i, text in self.df[text_column].items()], + min_freq=2, + specials= self.itos.values(), + special_first=True + ) + self.text_vocab.set_default_index(self.text_vocab['']) + #print(self.text_vocab.get_itos()) + + self.cats_vocab = build_vocab_from_iterator( + [self.catTokens(cats) for i, cats in self.df[cats_column].items()], + min_freq=1, + specials= self.itos.values(), + special_first=True + ) + self.cats_vocab.set_default_index(self.cats_vocab['']) + #print(self.cats_vocab.get_itos()) + + def __len__(self): + return len(self.df) + + def __getitem__(self, idx): + if torch.is_tensor(idx): + idx = idx.tolist() + + # Get the raw data + text = self.texts[idx] + cats = self.cats[idx] + + if self.transform: + text, cats = self.transform(text, cats) + + # Numericalise by applying transforms + return ( + self.getTransform(self.text_vocab)(self.textTokens(text)), + self.getTransform(self.cats_vocab)(self.catTokens(cats)), + ) + + @staticmethod + def textTokens(text): + if isinstance(text, str): + return [word for word in text.split()] + + @staticmethod + def catTokens(cats): + if isinstance(cats, str): + return [cat for cat in cats.split(';')] + elif isinstance(cats, list): + return [cat for cat in cats] + + def getTransform(self, vocab): + ''' + Create transforms based on given vocabulary. The returned transform + is applied to a sequence of tokens. + ''' + return T.Sequential( + # converts the sentences to indices based on given vocabulary + T.VocabTransform(vocab=vocab), + # Add at beginning of each sentence. 1 because the index + # for in vocabulary is 1 as seen in previous section + T.AddToken(1, begin=True), + # Add at beginning of each sentence. 2 because the index + # for in vocabulary is 2 as seen in previous section + T.AddToken(2, begin=False) + ) + +dataset = TextCategoriesDataset(df=data, + text_column="content", + cats_column="categories", +) +#print(dataset[2]) +#for text, cat in dataset: +# print(text, cat) +#sys.exit(0) + + +''' + Now that we have a dataset, let's create dataloader, + which can batch, shuffle, and load the data in parallel +''' + +class Collate: + ''' + We need to pad shorter sentences in a batch to make all the sequences + in a batch of equal length. We can do this a collate_fn callback class, + which returns a tensor + ''' + def __init__(self, pad_idx): + self.pad_idx = pad_idx + + def __call__(self, batch): + # T.ToTensor(0) returns a transform that converts the sequence + # to a torch.tensor and also applies padding. + # pad_idx is passed to the constructor to specify the + # index of the "" token in the vocabulary. + return ( + T.ToTensor(self.pad_idx)(list(batch[0])), + T.ToTensor(self.pad_idx)(list(batch[1])), + ) + + +pad_idx = dataset.stoi[''] +dataloader = DataLoader(dataset, + batch_size=4, + shuffle=True, + num_workers=0, + collate_fn=Collate(pad_idx=pad_idx), +) +#for i_batch, sample_batched in enumerate(dataloader): +# print(i_batch, sample_batched[0], sample_batched[1]) +#sys.exit(0) + +