#!/usr/bin/python

import argparse
import os
import sys
import pprint
import re
import string
import warnings

#data manipulation libs
import csv
import random
import pandas as pd
import numpy as np
#from pandarallel import pandarallel
from tqdm import tqdm

#torch libs
import torch
import torchdata.datapipes as dp
import torchtext.transforms as T
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import Dataset, DataLoader

parser = argparse.ArgumentParser(
    description='Classify text data according to categories',
    add_help=True,
)
parser.add_argument('action', help='train or classify')
parser.add_argument('--input', '-i', required=True, help='path of CSV file containing dataset')
parser.add_argument('--output', '-o', help='path to trained model')
args = parser.parse_args()
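
# Example invocation (illustrative only; the script and file names are
# hypothetical, and the CSV is expected to have "content" and "categories"
# columns, as used below):
#   python classify.py train --input dataset.csv --output model.pt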

if args.action not in ('train', 'classify'):
    print("ERROR: action must be 'train' or 'classify'")
    sys.exit(1)

# If asked to classify but no trained model exists, fall back to training.
# (Assumes the trained model path is the one given via --output.)
if args.action == 'classify' and (args.output is None or not os.path.isfile(args.output)):
    print("No model found for classification; running training instead")
    args.action = 'train'

if not os.path.isfile(args.input):
    print(f"{args.input} is not a valid file")
    sys.exit(1)

#with open(args.input, 'r', encoding="utf-8") as f:
#    data = pd.read_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)

with open(args.input, 'r', encoding="utf-8") as f:
    data = pd.concat(
        [chunk for chunk in tqdm(
            pd.read_csv(f,
                        encoding="utf-8",
                        quoting=csv.QUOTE_ALL,
                        nrows=200, ## XXX
                        chunksize=100),
            desc='Loading data'
        )])

data.dropna(axis='index', inplace=True)

#print(data)
#sys.exit(0)

'''
Create Training and Validation sets
'''

# Create a list of indices covering the whole dataset, and shuffle it
data_idx = list(range(len(data)))
np.random.shuffle(data_idx)

# Get the indices for the validation and train splits
split_percent = 0.95
num_train = int(len(data) * split_percent)
valid_idx, train_idx = data_idx[num_train:], data_idx[:num_train]
print("Length of train_data: {}".format(len(train_idx)))
print("Length of valid_data: {}".format(len(valid_idx)))

# Create the training and validation sets, as dataframes
train_data = data.iloc[train_idx].reset_index(drop=True)
valid_data = data.iloc[valid_idx].reset_index(drop=True)

'''
Create a dataset that builds a tokenised vocabulary,
and then, as each row is accessed, transforms it into
numericalised sequences of token indices
'''

class TextCategoriesDataset(Dataset):
    ''' Dataset of Text and Categories '''
    def __init__(self, df, text_column, cats_column, transform=None):
        '''
        Arguments:
            df (pandas.DataFrame): csv content, loaded as dataframe
            text_column (str): the name of the column containing the text
            cats_column (str): the name of the column containing
                semicolon-separated categories
            transform (callable, optional): Optional transform to be
                applied on a sample.
        '''
        self.df = df
        self.transform = transform

        self.texts = self.df[text_column]
        self.cats = self.df[cats_column]

        # index-to-token dict
        # <pad> : padding, used for padding the shorter sentences in a batch
        #         to match the length of the longest sentence in the batch
        # <sos> : start of sentence token
        # <eos> : end of sentence token
        # <unk> : unknown token: words which are not found in the vocab are
        #         replaced by this token
        self.itos = {0: '<pad>', 1: '<sos>', 2: '<eos>', 3: '<unk>'}
        # token-to-index dict
        self.stoi = {k: j for j, k in self.itos.items()}

        # Create vocabularies upon initialisation
        self.text_vocab = build_vocab_from_iterator(
            [self.textTokens(text) for i, text in self.df[text_column].items()],
            min_freq=2,
            specials=self.itos.values(),
            special_first=True
        )
        self.text_vocab.set_default_index(self.text_vocab['<unk>'])
        #print(self.text_vocab.get_itos())

        self.cats_vocab = build_vocab_from_iterator(
            [self.catTokens(cats) for i, cats in self.df[cats_column].items()],
            min_freq=1,
            specials=self.itos.values(),
            special_first=True
        )
        self.cats_vocab.set_default_index(self.cats_vocab['<unk>'])
        #print(self.cats_vocab.get_itos())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Enable use as a plain iterator
        if idx not in self.df.index:
            raise StopIteration

        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Get the raw data
        text = self.texts[idx]
        cats = self.cats[idx]

        if self.transform:
            text, cats = self.transform(text, cats)

        # Numericalise by applying transforms
        return (
            self.getTransform(self.text_vocab)(self.textTokens(text)),
            self.getTransform(self.cats_vocab)(self.catTokens(cats)),
        )

    @staticmethod
    def textTokens(text):
        if isinstance(text, str):
            return text.split()

    @staticmethod
    def catTokens(cats):
        if isinstance(cats, str):
            return cats.split(';')
        elif isinstance(cats, list):
            return list(cats)

    def getTransform(self, vocab):
        '''
        Create transforms based on given vocabulary. The returned transform
        is applied to a sequence of tokens.
        '''
        return T.Sequential(
            # Convert the sentences to indices based on the given vocabulary
            T.VocabTransform(vocab=vocab),
            # Add <sos> at the beginning of each sentence. 1 because the
            # index for <sos> in the vocabulary is 1, as set in self.itos
            T.AddToken(1, begin=True),
            # Add <eos> at the end of each sentence. 2 because the index
            # for <eos> in the vocabulary is 2, as set in self.itos
            T.AddToken(2, begin=False)
        )
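
# Illustrative sketch (hypothetical vocabulary indices): the transform returned
# by getTransform() turns a token list such as ['hello', 'world'] into
# something like [1, 57, 892, 2], i.e. the vocabulary indices of the tokens
# wrapped in the <sos> (1) and <eos> (2) markers.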

'''
dataset = TextCategoriesDataset(df=data,
                                text_column="content",
                                cats_column="categories",
                                )
'''
train_dataset = TextCategoriesDataset(df=train_data,
                                      text_column="content",
                                      cats_column="categories",
                                      )
valid_dataset = TextCategoriesDataset(df=valid_data,
                                      text_column="content",
                                      cats_column="categories",
                                      )

#print(dataset[2])
#for i, (text, cat) in enumerate(valid_dataset):
#    print(text, cat)
#sys.exit(0)
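
# Illustrative sketch (hypothetical index values): each dataset item is a pair
# of numericalised sequences, one for the text and one for its categories,
# each wrapped in the <sos>/<eos> markers (indices 1 and 2), e.g.
#   train_dataset[0]  ->  ([1, 41, 907, 13, 2], [1, 5, 2])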

'''
Now that we have datasets, let's create dataloaders,
which can batch, shuffle, and load the data in parallel
'''

class CollateBatch:
    '''
    We need to pad shorter sentences in a batch to make all the sequences
    in a batch of equal length. We can do this with a collate_fn callback
    class, which returns a tensor
    '''
    def __init__(self, pad_idx):
        # pad_idx is the index of the "<pad>" token in the vocabulary.
        self.pad_idx = pad_idx

    def __call__(self, batch):
        # Separate the text and category index sequences of each sample
        texts, cats = zip(*batch)
        # T.ToTensor(pad_idx) returns a transform that converts the sequences
        # to a torch.Tensor and pads the shorter ones with pad_idx.
        return (
            T.ToTensor(self.pad_idx)(list(texts)),
            T.ToTensor(self.pad_idx)(list(cats)),
        )
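
# Illustrative sketch (hypothetical values): for a batch of two samples whose
# text sequences are [1, 5, 7, 2] and [1, 9, 2], T.ToTensor(pad_idx=0) pads
# the shorter sequence with the <pad> index, giving
#   tensor([[1, 5, 7, 2],
#           [1, 9, 2, 0]])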

# Hyperparameters
EPOCHS = 10  # epoch
LR = 5  # learning rate
BATCH_SIZE = 64  # batch size for training

# Get cpu, gpu or mps device for training:
# prefer CUDA, then Intel XPU, then Apple MPS, falling back to CPU
device = (
    "cuda" if torch.cuda.is_available()
    else "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")
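
# Optional sanity check (illustrative only): tensors can be moved to the
# selected device before use, e.g.
#   x = torch.zeros(2, 2).to(device)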

'''
dataloader = DataLoader(dataset,
                        batch_size=4,
                        shuffle=True,
                        num_workers=0,
                        collate_fn=CollateBatch(pad_idx=dataset.stoi['<pad>']),
                        )
'''
train_dataloader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=0,
                              collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
                              )
valid_dataloader = DataLoader(valid_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=0,
                              collate_fn=CollateBatch(pad_idx=valid_dataset.stoi['<pad>']),
                              )

#for i_batch, sample_batched in enumerate(train_dataloader):
#    print(i_batch, sample_batched[0], sample_batched[1])
#sys.exit(0)
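
# A minimal usage sketch (assumes the dataloaders above): each batch is a pair
# of padded index tensors of shape (batch_size, longest_sequence_in_batch).
#texts, cats = next(iter(train_dataloader))
#print(texts.shape, cats.shape)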