africat/categorise.py

#!/usr/bin/python
import argparse
import os
import sys
import pprint
import re
import string
import warnings
#data manipulation libs
import csv
import random
import pandas as pd
import numpy as np
#from pandarallel import pandarallel
from tqdm import tqdm
#torch libs
import torch
import torchdata.datapipes as dp
import torchtext.transforms as T
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import Dataset, DataLoader
parser = argparse.ArgumentParser(
    description='Classify text data according to categories',
    add_help=True,
)
parser.add_argument('action', help='train or classify')
parser.add_argument('--input', '-i', required=True, help='path of CSV file containing dataset')
parser.add_argument('--output', '-o', help='path to trained model')
args = parser.parse_args()
if args.action not in ('train', 'classify'):
    print("ERROR: action must be 'train' or 'classify'")
    sys.exit(1)
# fall back to training when no trained model exists at the --output path
if args.action == 'classify' and not (args.output and os.path.isfile(args.output)):
    print("No model found for classification; running training instead")
    args.action = 'train'
if not os.path.isfile(args.input):
    print(f"{args.input} is not a valid file")
    sys.exit(1)
#with open(args.input, 'r', encoding="utf-8") as f:
# data = pd.read_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
with open(args.input, 'r', encoding="utf-8") as f:
    data = pd.concat(
        [chunk for chunk in tqdm(
            pd.read_csv(f,
                        encoding="utf-8",
                        quoting=csv.QUOTE_ALL,
                        nrows=200, ## XXX
                        chunksize=100),
            desc='Loading data'
        )])
data.dropna(axis='index', inplace=True)
#print(data)
#sys.exit(0)
'''
#######################################################
# Create Training and Validation sets
#######################################################
# create a list of ints till len of data
data_idx = list(range(len(data)))
np.random.shuffle(data_idx)
# get indexes for validation and train
val_frac = 0.1 # percentage of data in validation set
val_split_idx = int(len(data)*val_frac) # index on which to split (10% of data)
val_idx, train_idx = data_idx[:val_split_idx], data_idx[val_split_idx:]
print('len of train: ', len(train_idx))
print('len of val: ', len(val_idx))
# create the training and validation sets, as dataframes
train_data = data.iloc[train_idx].reset_index().drop('index',axis=1)
valid_data = data.iloc[val_idx].reset_index().drop('index',axis=1)
# Next, we create Pytorch Datasets and Dataloaders for these dataframes
'''
'''
Create a dataset that builds a tokenised vocabulary,
and then, as each row is accessed, transforms it into
sequences of numericalised tokens
'''
class TextCategoriesDataset(Dataset):
    ''' Dataset of Text and Categories '''
    def __init__(self, df, text_column, cats_column, transform=None):
        '''
        Arguments:
            df (pandas.DataFrame): csv content, loaded as dataframe
            text_column (str): the name of the column containing the text
            cats_column (str): the name of the column containing
                semicolon-separated categories
            transform (callable, optional): Optional transform to be
                applied on a sample.
        '''
        self.df = df
        self.transform = transform
        self.texts = self.df[text_column]
        self.cats = self.df[cats_column]

        # index-to-token dict
        # <pad> : padding, used for padding the shorter sentences in a batch
        #         to match the length of the longest sentence in the batch
        # <sos> : start of sentence token
        # <eos> : end of sentence token
        # <unk> : unknown token: words which are not found in the vocab are
        #         replaced by this token
        self.itos = {0: '<pad>', 1: '<sos>', 2: '<eos>', 3: '<unk>'}
        # token-to-index dict
        self.stoi = {k: j for j, k in self.itos.items()}

        # Create vocabularies upon initialisation
        self.text_vocab = build_vocab_from_iterator(
            [self.textTokens(text) for i, text in self.df[text_column].items()],
            min_freq=2,
            specials=self.itos.values(),
            special_first=True,
        )
        self.text_vocab.set_default_index(self.text_vocab['<unk>'])
        #print(self.text_vocab.get_itos())

        self.cats_vocab = build_vocab_from_iterator(
            [self.catTokens(cats) for i, cats in self.df[cats_column].items()],
            min_freq=1,
            specials=self.itos.values(),
            special_first=True,
        )
        self.cats_vocab.set_default_index(self.cats_vocab['<unk>'])
        #print(self.cats_vocab.get_itos())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Get the raw data
        text = self.texts[idx]
        cats = self.cats[idx]

        if self.transform:
            text, cats = self.transform(text, cats)

        # Numericalise by applying transforms
        return (
            self.getTransform(self.text_vocab)(self.textTokens(text)),
            self.getTransform(self.cats_vocab)(self.catTokens(cats)),
        )

    @staticmethod
    def textTokens(text):
        if isinstance(text, str):
            return [word for word in text.split()]

    @staticmethod
    def catTokens(cats):
        if isinstance(cats, str):
            return [cat for cat in cats.split(';')]
        elif isinstance(cats, list):
            return [cat for cat in cats]

    def getTransform(self, vocab):
        '''
        Create transforms based on given vocabulary. The returned transform
        is applied to a sequence of tokens.
        '''
        return T.Sequential(
            # convert the tokens to indices based on the given vocabulary
            T.VocabTransform(vocab=vocab),
            # Add <sos> at the beginning of each sequence. 1 because the
            # index for <sos> in the vocabulary is 1, as seen above
            T.AddToken(1, begin=True),
            # Add <eos> at the end of each sequence. 2 because the index
            # for <eos> in the vocabulary is 2, as seen above
            T.AddToken(2, begin=False)
        )
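# A rough illustration of what getTransform() does to a tokenised sentence.
# The indices below are illustrative only; real values depend on the
# vocabulary built from the input CSV (e.g. assuming 'hello' -> 15 and
# 'world' -> 32):
#   getTransform(text_vocab)(['hello', 'world'])  ->  [1, 15, 32, 2]
# i.e. tokens are numericalised, then wrapped in <sos> (1) and <eos> (2).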
dataset = TextCategoriesDataset(df=data,
                                text_column="content",
                                cats_column="categories",
                                )
'''
train_dataset = TextCategoriesDataset(df=train_data,
                                      text_column="content",
                                      cats_column="categories",
                                      )
valid_dataset = TextCategoriesDataset(df=valid_data,
                                      text_column="content",
                                      cats_column="categories",
                                      )
'''
#print(dataset[2])
#for text, cat in dataset:
# print(text, cat)
#sys.exit(0)
'''
Now that we have a dataset, let's create a dataloader,
which can batch, shuffle, and load the data in parallel
'''
class Collate:
    '''
    We need to pad the shorter sentences in a batch so that all sequences
    in the batch have equal length. We can do this with a collate_fn
    callback class, which returns a tensor
    '''
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        # The dataloader passes a list of (text, cats) samples; separate
        # them into a batch of texts and a batch of category sequences
        texts, cats = zip(*batch)
        # T.ToTensor(pad_idx) returns a transform that converts the
        # sequences to a torch.tensor and also applies padding.
        # pad_idx is passed to the constructor to specify the
        # index of the "<pad>" token in the vocabulary.
        return (
            T.ToTensor(self.pad_idx)(list(texts)),
            T.ToTensor(self.pad_idx)(list(cats)),
        )
dataloader = DataLoader(dataset,
                        batch_size=4,
                        shuffle=True,
                        num_workers=0,
                        collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
                        )
'''
train_dataloader = DataLoader(train_dataset,
                              batch_size=4,
                              shuffle=True,
                              num_workers=0,
                              collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
                              )
valid_dataloader = DataLoader(valid_dataset,
                              batch_size=4,
                              shuffle=True,
                              num_workers=0,
                              collate_fn=Collate(pad_idx=dataset.stoi['<pad>']),
                              )
'''
#for i_batch, sample_batched in enumerate(dataloader):
# print(i_batch, sample_batched[0], sample_batched[1])
#sys.exit(0)
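# For reference, each batch yielded by the dataloader above should be a pair
# of padded LongTensors (illustrative shapes only, assuming batch_size=4):
#   texts: shape [4, longest_text_in_batch], padded with 0 (<pad>)
#   cats:  shape [4, longest_cats_in_batch], padded with 0 (<pad>)
#for texts, cats in dataloader:
#    print(texts.shape, cats.shape)
#    break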