From 46f533746ee945b56ea5d5849efe7b4f5572394b Mon Sep 17 00:00:00 2001
From: tim
Date: Fri, 1 Dec 2023 23:02:05 +0200
Subject: [PATCH] Format for poetry and add debugging
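
Move argument parsing into main() in both scripts, split the XML parsing
and the dataframe scrubbing into functions, and replace the commented-out
debug prints with a --verbose/-v level: 1 prints progress and summary
counts, 2 adds per-file encoding fallback results, and higher levels dump
file contents.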
---
 africat/aa_create_dataset.py | 233 ++++++++++++++----------
 africat/categorise.py        | 344 ++++++++++++++++++++++-----------
 2 files changed, 375 insertions(+), 202 deletions(-)

diff --git a/africat/aa_create_dataset.py b/africat/aa_create_dataset.py
index b45e435..dcaf528 100755
--- a/africat/aa_create_dataset.py
+++ b/africat/aa_create_dataset.py
@@ -1,4 +1,10 @@
 #!/usr/bin/python
+'''
+  1. Load the XML files
+  2. Create the data structure
+  3. Preprocess the data: strip punctuation and digits, collapse whitespace,
+     and lowercase the text. This reduces the vocabulary (e.g. "Cat ~" becomes "cat").
+'''
 
 import argparse
 import os
@@ -8,31 +14,16 @@ import re
 import string
 from string import digits
 import warnings
-
 import html
 from xml.etree import ElementTree as ET
-
-#data manupulation libs
+# data manipulation libs
 import csv
 import pandas as pd
 from pandarallel import pandarallel
 
-parser = argparse.ArgumentParser(
-    description='Turn XML data files into a dataset for use with pytorch',
-    add_help=True,
-)
-parser.add_argument('--output', '-o', required=True, help='path of output CSV file')
-parser.add_argument('--input', '-i', required=True, help='path of input directory containing XML files')
-args = parser.parse_args()
-
-if os.path.isdir(args.input) is False:
-    print(f"{args.input} is not a directory or does not exist");
-    sys.exit(1)
-
-#1. Load XML file
-#2. Create structure
-#3. Preprocess the data to remove punctuations, digits, spaces and making the text lower.
-#.  This helps reduce the vocab of the data (as now, "Cat ~" is "cat")
+def write_csv(data, output):
+    with open(output, 'w', encoding="utf-8") as f:
+        data.to_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
 
 def insert_line_numbers(txt):
     return "\n".join([f"{n+1:03d} {line}" for n, line in enumerate(txt.split("\n"))])
@@ -48,99 +39,153 @@ def partial_unescape(s):
         parts[i] = html.unescape(parts[i])
     return "".join(parts)
 
-articles = list()
-#allCats = list()
+def parse_and_extract(input_dir, verbose):
+    articles = list()
+
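+    # Read each .aans file, trying ASCII first, then UTF-8, then ISO-8859-1,
+    # keeping per-encoding counts for the summary printed at the end.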
+    total, plain, utf8, iso88591, failed = 0, 0, 0, 0, 0
+    for root, dirs, files in os.walk(input_dir):
+        dirs.sort()
+        if verbose > 0:
+            print(root)
+        for file in sorted(files):
+            #if re.search(r'2022/10/09', root) and re.search(r'0028\.aans$', file):
+            if re.search(r'\.aans$', file):
+                xml_file = os.path.join(root, file)
+                total += 1
 
-total, plain, utf8, iso88591, failed = 0, 0, 0, 0, 0
-for root, dirs, files in os.walk(args.input):
-    dirs.sort()
-    print(root)
-    for file in sorted(files):
-        #if re.search('2022\/10\/09', root) and re.search('0028.aans$', file):
-        if re.search('.aans$', file):
-            xml_file = os.path.join(root, file)
-            total += 1
-            try:
-                with open(xml_file, 'r', encoding="ASCII") as f:
-                    content = f.read()
-                    #print(f"ASCII read succeeded in {xml_file}")
-                    plain += 1
-            except Exception as e:
-                #print(f"ASCII read failed, trying UTF-8 in {xml_file} : {e}")
-                try:
-                    with open(xml_file, 'r', encoding="UTF-8") as f:
-                        content = f.read()
-                        #print(f"UTF-8 read succeeded in {xml_file}")
-                        utf8 += 1
-                except Exception as e:
-                    #print(f"UTF-8 read failed, trying ISO-8859-1 in {xml_file} : {e}")
-                    try:
-                        with open(xml_file, 'r', encoding="ISO-8859-1") as f:
-                            content = f.read()
-                            #print(f"ISO-8859-1 read succeeded in {xml_file}")
-                            iso88591 += 1
-                    except Exception as e:
-                        print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {e}")
-                        print(content)
-                        failed += 1
-            content = partial_unescape(content)
-            content = local_clean(content)
-            #print(content)
+                try:
+                    with open(xml_file, 'r', encoding="ASCII") as f:
+                        content = f.read()
+                        if verbose > 1:
+                            print(f"ASCII read succeeded in {xml_file}")
+                        plain += 1
+                except Exception as e:
+                    if verbose > 1:
+                        print(f"ASCII read failed, trying UTF-8 in {xml_file} : {e}")
+                    try:
+                        with open(xml_file, 'r', encoding="UTF-8") as f:
+                            content = f.read()
+                            if verbose > 1:
+                                print(f"UTF-8 read succeeded in {xml_file}")
+                            utf8 += 1
+                    except Exception as e:
+                        if verbose > 1:
+                            print(f"UTF-8 read failed, trying ISO-8859-1 in {xml_file} : {e}")
+                        try:
+                            with open(xml_file, 'r', encoding="ISO-8859-1") as f:
+                                content = f.read()
+                                if verbose > 1:
+                                    print(f"ISO-8859-1 read succeeded in {xml_file}")
+                                iso88591 += 1
+                        except Exception as e:
+                            print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {e}")
+                            failed += 1
+                            # nothing was read, so there is no content to clean or parse
+                            continue
+                content = partial_unescape(content)
+                content = local_clean(content)
+                if verbose > 3:
+                    print(content)
 
-            key = re.sub('^.*\/(\d{4})\/(\d{2})\/(\d{2})\/(\d{4}).aans$', '\g<1>\g<2>\g<3>\g<4>', xml_file)
-            try:
-                doc = ET.fromstring(content)
-                entry = dict()
-                entry["key"] = key
-                cats = list()
-                for cat in doc.findall('category'):
-                    #if cat not in allCats:
-                    #    allCats.append(cat)
-                    cats.append(cat.text)
-                #entry["categories"] = cats
-                entry["categories"] = ";".join(cats)
-                text = list()
-                try:
-                    #text = "\n".join([p.text for p in doc.find('./body')])
-                    for p in doc.find('./body'):
-                        if p.text is not None:
-                            text.append(p.text)
-                    entry["content"] = "\n".join(text)
-                    articles.append(entry)
-                except Exception as e:
-                    print(f"{xml_file} : {e}")
-            except ET.ParseError as e:
-                print(insert_line_numbers(content))
-                print("Parse error in " + xml_file + " : ", e)
-                raise(SystemExit)
-print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(total, plain, utf8, iso88591, failed))
+                key = re.sub(r'^.*/(\d{4})/(\d{2})/(\d{2})/(\d{4})\.aans$', r'\g<1>\g<2>\g<3>\g<4>', xml_file)
+                try:
+                    doc = ET.fromstring(content)
+
+                    entry = dict()
+                    entry["key"] = key
+
+                    cats = list()
+                    for cat in doc.findall('./category'):
+                        cats.append(cat.text)
+                    #entry["categories"] = cats            # if you want a list
+                    entry["categories"] = ";".join(cats)   # if you want a string
+
+                    text = list()
+                    lang = ""
+                    try:
+                        for p in doc.find('./body'):
+                            if p.text is not None:
+                                text.append(p.text)
+                        lang = doc.find('./language').text
+                    except Exception as e:
+                        print(f"{xml_file} : {e}")
+                    if text and len(cats) > 1:
+                        entry["content"] = "\n".join(text)
+                        entry["language"] = lang
+                        articles.append(entry)
+                except ET.ParseError as e:
+                    if verbose > 1:
+                        print(insert_line_numbers(content))
+                    print("Parse error in " + xml_file + " : ", e)
+                    raise SystemExit
 
-#sys.exit(0)
+    if verbose > 0:
+        print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(total, plain, utf8, iso88591, failed))
 
-data = pd.DataFrame(articles)
-data.set_index("key", inplace=True)
-
-#print(data.categories)
-
-# Initialization
-pandarallel.initialize()
+    #sys.exit(0)
+    return articles
 
+def scrub_data(articles, verbose):
+    data = pd.DataFrame(articles)
+    data.set_index("key", inplace=True)
+
+    #if verbose > 2:
+    #    print(data.categories)
+
+    # Initialization
+    pandarallel.initialize()
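+
+    # Net effect of the scrubbing below, e.g. "Cat, 9 Lives!" -> "cat lives":
+    # lowercase, drop punctuation, drop digits, then squeeze whitespace.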
 
-# Lowercase everything
-data['content'] = data.content.parallel_apply(lambda x: x.lower())
-
-# Remove special characters
-exclude = set(string.punctuation) #set of all special chars
-data['content'] = data.content.parallel_apply(lambda x: ''.join(ch for ch in x if ch not in exclude))
-
-# Remove digits
-remove_digits = str.maketrans('','',digits)
-data['content'] = data.content.parallel_apply(lambda x: x.translate(remove_digits))
-
-# Remove extra spaces
-data['content']=data.content.parallel_apply(lambda x: x.strip())
-data['content']=data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
-
-with open(args.output, 'w', encoding="utf-8") as f:
-    data.to_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
+    # Lowercase everything
+    data['content'] = data.content.parallel_apply(lambda x: x.lower())
+
+    # Remove special characters
+    exclude = set(string.punctuation)  # set of all special chars
+    data['content'] = data.content.parallel_apply(lambda x: ''.join(ch for ch in x if ch not in exclude))
+
+    # Remove digits
+    remove_digits = str.maketrans('', '', digits)
+    data['content'] = data.content.parallel_apply(lambda x: x.translate(remove_digits))
+
+    # Remove extra spaces
+    data['content'] = data.content.parallel_apply(lambda x: x.strip())
+    data['content'] = data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
+
+    # TODO: lemmas? See spaCy
+
+    return data
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Turn XML data files into a dataset for use with pytorch',
+        add_help=True,
+    )
+    parser.add_argument('--output', '-o',
+        required=True,
+        help='path of output CSV file')
+    parser.add_argument('--input', '-i',
+        required=True,
+        help='path of input directory containing XML files')
+    parser.add_argument('--verbose', '-v',
+        type=int, nargs='?',
+        const=1,   # Default value if -v is supplied
+        default=0, # Default value if -v is not supplied
+        help='print debugging')
+    args = parser.parse_args()
+
+    if not os.path.isdir(args.input):
+        print(f"{args.input} is not a directory or does not exist")
+        sys.exit(1)
+
+    articles = parse_and_extract(args.input, args.verbose)
+
+    data = scrub_data(articles, args.verbose)
+
+    write_csv(data, args.output)
+
+    return
+
+if __name__ == "__main__":
+    main()
diff --git a/africat/categorise.py b/africat/categorise.py
index 537bc36..479d08b 100755
--- a/africat/categorise.py
+++ b/africat/categorise.py
@@ -6,80 +6,84 @@ import sys
 import pprint
 import re
 import string
+import time
 import warnings
-
-#data manupulation libs
+# data manipulation
 import csv
 import random
 import pandas as pd
 import numpy as np
 #from pandarallel import pandarallel
 from tqdm import tqdm
-
-#torch libs
+# torch
 import torch
 import torchdata.datapipes as dp
 import torchtext.transforms as T
 from torchtext.vocab import build_vocab_from_iterator
 from torch.utils.data import Dataset, DataLoader
+from torch import nn
 
-parser = argparse.ArgumentParser(
-    description='Classify text data according to categories',
-    add_help=True,
-)
-parser.add_argument('action', help='train or classify')
-parser.add_argument('--input', '-i', required=True, help='path of CSV file containing dataset')
-parser.add_argument('--output', '-o', help='path to trained model')
-args = parser.parse_args()
+story_num = 40  # XXX None for all
 
-if args.action != 'train' and args.action != 'classify':
-    print("ERROR: train or classify data")
-    sys.exit(1)
+# Hyperparameters
+EPOCHS = 10  # epoch
+LR = 5  # learning rate
+BATCH_SIZE = 64  # batch size for training
 
-if args.action == 'classify' and s.path.isfile(model_storage) is None:
-    print("No model found for classification; running training instead")
-    args.action = 'train'
-
-if os.path.isfile(args.input) is False:
-    print(f"{args.input} is not a valid file")
-    sys.exit(1)
-
-#with open(args.input, 'r', encoding="utf-8") as f:
-#    data = pd.read_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
-
-with open(args.input, 'r', encoding="utf-8") as f:
-    data = pd.concat(
-        [chunk for chunk in tqdm(
-            pd.read_csv(f,
-                encoding="utf-8",
-                quoting=csv.QUOTE_ALL,
-                nrows=200, ## XXX
-                chunksize=100),
-            desc='Loading data'
-        )])
+def read_csv(input_csv, rows=None, verbose=0):
+    if verbose > 0:
+        with open(input_csv, 'r', encoding="utf-8") as f:
+            data = pd.concat(
+                [chunk for chunk in tqdm(
+                    pd.read_csv(f,
+                        encoding="utf-8",
+                        quoting=csv.QUOTE_ALL,
+                        nrows=rows,
+                        chunksize=50,
+                    ),
+                    desc='Loading data'
+                )])
+    else:
+        with open(input_csv, 'r', encoding="utf-8") as f:
+            data = pd.read_csv(f,
+                encoding="utf-8",
+                quoting=csv.QUOTE_ALL,
+                nrows=rows,
+            )
 
-data.dropna(axis='index', inplace=True)
-
-#print(data)
-#sys.exit(0)
+    data.dropna(axis='index', inplace=True)
+    #print(data)
+    #sys.exit(0)
+    return data
 
 '''
   Create Training and Validation sets
'''
-# Create a list of ints till len of data
-data_idx = list(range(len(data)))
-np.random.shuffle(data_idx)
+def split_dataset(data, verbose=0):
+    # Create a list of ints till len of data
+    data_idx = list(range(len(data)))
+    np.random.shuffle(data_idx)
 
-# Get indexes for validation and train
-split_percent = 0.95
-num_train = int(len(data) * split_percent)
-valid_idx, train_idx = data_idx[num_train:], data_idx[:num_train]
-print("Length of train_data: {}".format(len(train_idx)))
-print("Length of valid_data: {}".format(len(valid_idx)))
+    # Get indexes for validation and train
+    split_percent = 0.05
+    num_valid = int(len(data) * split_percent)
+    #num_tests = int(len(data) * split_percent)
+    #train_idx = data_idx[num_valid:-num_tests]
+    train_idx = data_idx[num_valid:]
+    valid_idx = data_idx[:num_valid]
+    #tests_idx = data_idx[-num_tests:]
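+    # e.g. with story_num == 40 rows and split_percent == 0.05 this gives
+    # int(40 * 0.05) == 2 validation rows and 38 training rows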
+    if verbose > 0:
+        print("Length of train_data: {}".format(len(train_idx)))
+        print("Length of valid_data: {}".format(len(valid_idx)))
+        #print("Length of tests_data: {}".format(len(tests_idx)))
 
-# Create the training and validation sets, as dataframes
-train_data = data.iloc[train_idx].reset_index().drop('index', axis=1)
-valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
+    # Create the training and validation sets, as dataframes
+    train_data = data.iloc[train_idx].reset_index().drop('index', axis=1)
+    valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
+    #tests_data = data.iloc[tests_idx].reset_index().drop('index', axis=1)
+    #return (train_data, valid_data, tests_data)
+    return (train_data, valid_data)
 
 
 '''
@@ -88,21 +92,24 @@ valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
 '''
 class TextCategoriesDataset(Dataset):
     ''' Dataset of Text and Categories '''
-    def __init__(self, df, text_column, cats_column, transform=None):
+    def __init__(self, df, text_column, cats_column, lang_column, transform=None, verbose=0):
         '''
         Arguments:
             df (panda.Dataframe): csv content, loaded as dataframe
             text_column (str): the name of the column containing the text
             cats_column (str): the name of the column containing
                 semicolon-separated categories
+            lang_column (str): the name of the column containing the language
             transform (callable, optional): Optional transform to be applied
                 on a sample.
         '''
         self.df = df
         self.transform = transform
+        self.verbose = verbose
 
-        self.texts = self.df[text_column]
-        self.cats = self.df[cats_column]
+        self.text = self.df[text_column]
+        self.cats = self.df[cats_column]
+        self.lang = self.df[lang_column]
 
         # index-to-token dict
         # <pad>: padding, used for padding the shorter sentences in a batch
@@ -146,8 +153,9 @@ class TextCategoriesDataset(Dataset):
             idx = idx.tolist()
 
         # Get the raw data
-        text = self.texts[idx]
+        text = self.text[idx]
         cats = self.cats[idx]
+        lang = self.lang[idx]
 
         if self.transform:
             text, cats = self.transform(text, cats)
@@ -186,25 +194,6 @@ class TextCategoriesDataset(Dataset):
             T.AddToken(2, begin=False)
         )
 
-'''
-dataset = TextCategoriesDataset(df=data,
-    text_column="content",
-    cats_column="categories",
-)
-'''
-train_dataset = TextCategoriesDataset(df=train_data,
-    text_column="content",
-    cats_column="categories",
-)
-valid_dataset = TextCategoriesDataset(df=valid_data,
-    text_column="content",
-    cats_column="categories",
-)
-#print(dataset[2])
-#for text, cat in enumerate(valid_dataset):
-#    print(text, cat)
-#sys.exit(0)
-
 
 '''
   Now that we have a dataset, let's create dataloader,
@@ -232,44 +221,183 @@ class CollateBatch:
     )
 
-# Hyperparameters
-EPOCHS = 10 # epoch
-LR = 5 # learning rate
-BATCH_SIZE = 64 # batch size for training
+class TextClassificationModel(nn.Module):
+    def __init__(self, input_size, output_size, verbose):
+        super().__init__()
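+        # NOTE: placeholder; real layers (e.g. an nn.EmbeddingBag(input_size, emsize)
+        # feeding an nn.Linear(emsize, output_size), as in the torchtext text
+        # classification tutorial) still need to be added for training to learn anything.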
 
-# Get cpu, gpu or mps device for training.
-# Move tensor to the NVIDIA GPU if available
-device = (
-    "cuda" if torch.cuda.is_available()
-    else "xps" if hasattr(torch, "xpu") and torch.xpu.is_available()
-    else "mps" if torch.backends.mps.is_available()
-    else "cpu"
-)
-print(f"Using {device} device")
+    def forward(self, x):
+        return x
 
-'''
-dataloader = DataLoader(dataset,
-    batch_size=4,
-    shuffle=True,
-    num_workers=0,
-    collate_fn=CollateBatch(pad_idx=dataset.stoi['<pad>']),
-)
-'''
-train_dataloader = DataLoader(train_dataset,
-    batch_size=BATCH_SIZE,
-    shuffle=True,
-    num_workers=0,
-    collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
-)
-valid_dataloader = DataLoader(valid_dataset,
-    batch_size=BATCH_SIZE,
-    shuffle=True,
-    num_workers=0,
-    collate_fn=CollateBatch(pad_idx=valid_dataset.stoi['<pad>']),
-)
-#for i_batch, sample_batched in enumerate(dataloader):
-#    print(i_batch, sample_batched[0], sample_batched[1])
-#sys.exit(0)
+def train(dataloader, model, criterion, optimizer, epoch):
+    model.train()
+    total_acc, total_count = 0, 0
+    log_interval = 500
+    start_time = time.time()
+
+    for idx, (label, text) in enumerate(dataloader):
+        optimizer.zero_grad()
+        predicted_label = model(text)
+        loss = criterion(predicted_label, label)
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
+        optimizer.step()
+        total_acc += (predicted_label.argmax(1) == label).sum().item()
+        total_count += label.size(0)
+        if idx % log_interval == 0 and idx > 0:
+            elapsed = time.time() - start_time
+            print(
+                "| epoch {:3d} | {:5d}/{:5d} batches "
+                "| accuracy {:8.3f}".format(
+                    epoch, idx, len(dataloader), total_acc / total_count
+                )
+            )
+            total_acc, total_count = 0, 0
+            start_time = time.time()
+
+def evaluate(dataloader, model, criterion):
+    model.eval()
+    total_acc, total_count = 0, 0
+
+    with torch.no_grad():
+        for idx, (label, text) in enumerate(dataloader):
+            predicted_label = model(text)
+            loss = criterion(predicted_label, label)
+            total_acc += (predicted_label.argmax(1) == label).sum().item()
+            total_count += label.size(0)
+    return total_acc / total_count
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Classify text data according to categories',
+        add_help=True,
+    )
+    parser.add_argument('action',
+        help='train or classify')
+    parser.add_argument('--input', '-i',
+        required=True,
+        help='path of CSV file containing dataset')
+    parser.add_argument('--model', '-m',
+        #required=True, # XXX
+        help='path to training model')
+    parser.add_argument('--verbose', '-v',
+        type=int, nargs='?',
+        const=1,   # Default value if -v is supplied
+        default=0, # Default value if -v is not supplied
+        help='print debugging')
+    args = parser.parse_args()
+
+    if args.action != 'train' and args.action != 'classify':
+        print("ERROR: train or classify data")
+        sys.exit(1)
+
+    if args.action == 'classify' and (args.model is None or not os.path.isfile(args.model)):
+        print("No model found for classification; running training instead")
+        args.action = 'train'
+
+    if not os.path.isfile(args.input):
+        print(f"{args.input} is not a valid file")
+        sys.exit(1)
+
+    data = read_csv(input_csv=args.input, rows=story_num, verbose=args.verbose)
+    train_data, valid_data = split_dataset(data, verbose=args.verbose)
+
+    '''
+    dataset = TextCategoriesDataset(df=data,
+        text_column="content",
+        cats_column="categories",
+        lang_column="language",
+        verbose=args.verbose,
+    )
+    '''
+    train_dataset = TextCategoriesDataset(df=train_data,
+        text_column="content",
+        cats_column="categories",
+        lang_column="language",
+        verbose=args.verbose,
+    )
+    valid_dataset = TextCategoriesDataset(df=valid_data,
+        text_column="content",
+        cats_column="categories",
+        lang_column="language",
+        verbose=args.verbose,
+    )
+    #print(dataset[2])
+    #for text, cat in enumerate(valid_dataset):
+    #    print(text, cat)
+    #sys.exit(0)
+
+    # Get cpu, gpu or mps device for training.
+    # Move tensors to an accelerator if one is available
+    device = (
+        "cuda" if torch.cuda.is_available()
+        else "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available()
+        else "mps" if torch.backends.mps.is_available()
+        else "cpu"
+    )
+    print(f"Using {device} device")
+
+    '''
+    dataloader = DataLoader(dataset,
+        batch_size=4,
+        shuffle=True,
+        num_workers=0,
+        collate_fn=CollateBatch(pad_idx=dataset.stoi['<pad>']),
+    )
+    '''
+    train_dataloader = DataLoader(train_dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        num_workers=0,
+        collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
+    )
+    valid_dataloader = DataLoader(valid_dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        num_workers=0,
+        collate_fn=CollateBatch(pad_idx=valid_dataset.stoi['<pad>']),
+    )
+    #for i_batch, sample_batched in enumerate(dataloader):
+    #    print(i_batch, sample_batched[0], sample_batched[1])
+    #sys.exit(0)
+
+    num_class = len(set([cats for key, cats, text, lang in train_data.values]))
+    input_size = len(train_dataset.text_vocab)
+    output_size = len(train_dataset.cats_vocab)
+    emsize = 64
+    model = TextClassificationModel(input_size, output_size, args.verbose).to(device)
+
+    criterion = torch.nn.CrossEntropyLoss()
+    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
+    total_accu = None
+
+    for epoch in range(1, EPOCHS + 1):
+        epoch_start_time = time.time()
+        train(train_dataloader, model, criterion, optimizer, epoch)
+        accu_val = evaluate(valid_dataloader, model, criterion)
+        if total_accu is not None and total_accu > accu_val:
+            scheduler.step()
+        else:
+            total_accu = accu_val
+        print("-" * 59)
+        print(
+            "| end of epoch {:3d} | time: {:5.2f}s | "
+            "valid accuracy {:8.3f} ".format(
+                epoch, time.time() - epoch_start_time, accu_val
+            )
+        )
+        print("-" * 59)
+
+    # XXX there is no separate test split yet (see split_dataset), so re-check
+    # the validation set for now
+    print("Checking the results of the validation dataset.")
+    accu_test = evaluate(valid_dataloader, model, criterion)
+    print("validation accuracy {:8.3f}".format(accu_test))
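+
+    # TODO: once training produces a usable model, persist it for the
+    # 'classify' action, e.g. torch.save(model.state_dict(), args.model)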
+
+    return
+
+if __name__ == "__main__":
+    main()