Format for poetry and add debugging

Timothy Allen 2023-12-01 23:02:05 +02:00
parent 2039b017eb
commit 46f533746e
2 changed files with 375 additions and 202 deletions

View File

@@ -1,4 +1,10 @@
#!/usr/bin/python
'''
1. Load XML file
2. Create structure
3. Preprocess the data: remove punctuation, digits and extra spaces, and lowercase the text.
This helps reduce the vocabulary of the data (e.g. "Cat ~" becomes "cat")
'''
import argparse
import os
@@ -8,31 +14,16 @@ import re
import string
from string import digits
import warnings
import html
from xml.etree import ElementTree as ET
#data manupulation libs
# data manipulation libs
import csv
import pandas as pd
from pandarallel import pandarallel
parser = argparse.ArgumentParser(
description='Turn XML data files into a dataset for use with pytorch',
add_help=True,
)
parser.add_argument('--output', '-o', required=True, help='path of output CSV file')
parser.add_argument('--input', '-i', required=True, help='path of input directory containing XML files')
args = parser.parse_args()
if os.path.isdir(args.input) is False:
print(f"{args.input} is not a directory or does not exist");
sys.exit(1)
#1. Load XML file
#2. Create structure
#3. Preprocess the data to remove punctuations, digits, spaces and making the text lower.
#. This helps reduce the vocab of the data (as now, "Cat ~" is "cat")
def write_csv(data, output):
with open(output, 'w', encoding="utf-8") as f:
data.to_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
def insert_line_numbers(txt):
return "\n".join([f"{n+1:03d} {line}" for n, line in enumerate(txt.split("\n"))])
@@ -48,99 +39,153 @@ def partial_unescape(s):
parts[i] = html.unescape(parts[i])
return "".join(parts)
articles = list()
#allCats = list()
def parse_and_extract(input_dir, verbose):
articles = list()
total, plain, utf8, iso88591, failed = 0, 0, 0, 0, 0
for root, dirs, files in os.walk(args.input):
total, plain, utf8, iso88591, failed = 0, 0, 0, 0, 0
for root, dirs, files in os.walk(input_dir):
dirs.sort()
if verbose > 0:
print(root)
for file in sorted(files):
#if re.search('2022\/10\/09', root) and re.search('0028.aans$', file):
if re.search(r'\.aans$', file):
xml_file = os.path.join(root, file)
total += 1
try:
with open(xml_file, 'r', encoding="ASCII") as f:
content = f.read()
#print(f"ASCII read succeeded in {xml_file}")
if verbose > 1:
print(f"ASCII read succeeded in {xml_file}")
plain += 1
except Exception as e:
#print(f"ASCII read failed, trying UTF-8 in {xml_file} : {e}")
if verbose > 1:
print(f"ASCII read failed, trying UTF-8 in {xml_file} : {e}")
try:
with open(xml_file, 'r', encoding="UTF-8") as f:
content = f.read()
#print(f"UTF-8 read succeeded in {xml_file}")
if verbose > 1:
print(f"UTF-8 read succeeded in {xml_file}")
utf8 += 1
except Exception as e:
#print(f"UTF-8 read failed, trying ISO-8859-1 in {xml_file} : {e}")
if verbose > 1:
print(f"UTF-8 read failed, trying ISO-8859-1 in {xml_file} : {e}")
try:
with open(xml_file, 'r', encoding="ISO-8859-1") as f:
content = f.read()
#print(f"ISO-8859-1 read succeeded in {xml_file}")
if verbose > 1:
print(f"ISO-8859-1 read succeeded in {xml_file}")
iso88591 += 1
except Exception as e:
print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {e}")
if verbose > 2:
print(content)
failed += 1
content = partial_unescape(content)
content = local_clean(content)
#print(content)
if verbose > 3:
print(content)
key = re.sub(r'^.*/(\d{4})/(\d{2})/(\d{2})/(\d{4})\.aans$', r'\g<1>\g<2>\g<3>\g<4>', xml_file)
try:
doc = ET.fromstring(content)
entry = dict()
entry["key"] = key
cats = list()
for cat in doc.findall('category'):
#if cat not in allCats:
# allCats.append(cat)
for cat in doc.findall('./category'):
cats.append(cat.text)
#entry["categories"] = cats
entry["categories"] = ";".join(cats)
#entry["categories"] = cats # if you want a list
entry["categories"] = ";".join(cats) # if you want a string
text = list()
lang = ""
try:
#text = "\n".join([p.text for p in doc.find('./body')])
for p in doc.find('./body'):
if p.text is not None:
text.append(p.text)
if text is not None and len(cats) > 1:
entry["content"] = "\n".join(text)
articles.append(entry)
lang = doc.find('./language').text
except Exception as e:
print(f"{xml_file} : {e}")
if text is not None and len(cats) > 1:
entry["content"] = "\n".join(text)
entry["language"] = lang
articles.append(entry)
except ET.ParseError as e:
if verbose > 1:
print(insert_line_numbers(content))
print("Parse error in " + xml_file + " : ", e)
raise(SystemExit)
print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(total, plain, utf8, iso88591, failed))
if verbose > 0:
print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(total, plain, utf8, iso88591, failed))
#sys.exit(0)
#sys.exit(0)
return articles
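# Illustrative shape of one entry returned by parse_and_extract() (values are
# made up, but the keys match what is built above):
#   {"key": "202210090028",
#    "categories": "news;politics",
#    "content": "first paragraph\nsecond paragraph",
#    "language": "en"}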
data = pd.DataFrame(articles)
data.set_index("key", inplace=True)
#print(data.categories)
def scrub_data(articles, verbose):
data = pd.DataFrame(articles)
data.set_index("key", inplace=True)
# Initialization
pandarallel.initialize()
#if verbose > 2:
# print(data.categories)
# Lowercase everything
data['content'] = data.content.parallel_apply(lambda x: x.lower())
# Initialization
pandarallel.initialize()
# Remove special characters
exclude = set(string.punctuation) #set of all special chars
data['content'] = data.content.parallel_apply(lambda x: ''.join(ch for ch in x if ch not in exclude))
# Lowercase everything
data['content'] = data.content.parallel_apply(lambda x: x.lower())
# Remove digits
remove_digits = str.maketrans('','',digits)
data['content'] = data.content.parallel_apply(lambda x: x.translate(remove_digits))
# Remove special characters
exclude = set(string.punctuation) #set of all special chars
data['content'] = data.content.parallel_apply(lambda x: ''.join(ch for ch in x if ch not in exclude))
# Remove extra spaces
data['content']=data.content.parallel_apply(lambda x: x.strip())
data['content']=data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
# Remove digits
remove_digits = str.maketrans('','',digits)
data['content'] = data.content.parallel_apply(lambda x: x.translate(remove_digits))
with open(args.output, 'w', encoding="utf-8") as f:
data.to_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
# Remove extra spaces
data['content'] = data.content.parallel_apply(lambda x: x.strip())
data['content'] = data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
# TODO: lemmas? See spaCy
return data
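# Worked example of the cleaning above (input string is illustrative):
#   'The "Cat ~" sat, on 3 mats!  '
# after lowercasing, punctuation stripping, digit removal and whitespace
# collapsing comes out as:
#   'the cat sat on mats'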
def main():
parser = argparse.ArgumentParser(
description='Turn XML data files into a dataset for use with pytorch',
add_help=True,
)
parser.add_argument('--output', '-o',
required=True,
help='path of output CSV file')
parser.add_argument('--input', '-i',
required=True,
help='path of input directory containing XML files')
parser.add_argument('--verbose', '-v',
type=int, nargs='?',
const=1, # Default value if -v is supplied
default=0, # Default value if -v is not supplied
help='print debugging')
args = parser.parse_args()
if not os.path.isdir(args.input):
print(f"{args.input} is not a directory or does not exist")
sys.exit(1)
articles = parse_and_extract(args.input, args.verbose)
data = scrub_data(articles, args.verbose)
write_csv(data, args.output)
return
if __name__ == "__main__":
main()
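# Example invocation (the script and path names here are placeholders):
#   python make_dataset.py -i /path/to/xml-tree -o dataset.csv -v 2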

View File

@@ -6,80 +6,84 @@ import sys
import pprint
import re
import string
import time
import warnings
#data manupulation libs
# data manipulation
import csv
import random
import pandas as pd
import numpy as np
#from pandarallel import pandarallel
from tqdm import tqdm
#torch libs
# torch
import torch
import torchdata.datapipes as dp
import torchtext.transforms as T
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import Dataset, DataLoader
from torch import nn
parser = argparse.ArgumentParser(
description='Classify text data according to categories',
add_help=True,
)
parser.add_argument('action', help='train or classify')
parser.add_argument('--input', '-i', required=True, help='path of CSV file containing dataset')
parser.add_argument('--output', '-o', help='path to trained model')
args = parser.parse_args()
story_num = 40 # XXX None for all
if args.action != 'train' and args.action != 'classify':
print("ERROR: train or classify data")
sys.exit(1)
# Hyperparameters
EPOCHS = 10 # epoch
LR = 5 # learning rate
BATCH_SIZE = 64 # batch size for training
if args.action == 'classify' and s.path.isfile(model_storage) is None:
print("No model found for classification; running training instead")
args.action = 'train'
if os.path.isfile(args.input) is False:
print(f"{args.input} is not a valid file")
sys.exit(1)
#with open(args.input, 'r', encoding="utf-8") as f:
# data = pd.read_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
with open(args.input, 'r', encoding="utf-8") as f:
def read_csv(input_csv, rows=None, verbose=0):
if verbose > 0:
with open(input_csv, 'r', encoding="utf-8") as f:
data = pd.concat(
[chunk for chunk in tqdm(
pd.read_csv(f,
encoding="utf-8",
quoting=csv.QUOTE_ALL,
nrows=200, ## XXX
chunksize=100),
nrows=rows,
chunksize=50,
),
desc='Loading data'
)])
else:
with open(input_csv, 'r', encoding="utf-8") as f:
data = pd.read_csv(f,
encoding="utf-8",
quoting=csv.QUOTE_ALL,
nrows=rows,
)
data.dropna(axis='index', inplace=True)
data.dropna(axis='index', inplace=True)
#print(data)
#sys.exit(0)
return data
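# Note: the chunked tqdm path and the plain read_csv path load the same data;
# the chunks (size 50) only exist to drive the progress bar. `rows` caps how
# many stories are read (None reads everything).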
#print(data)
#sys.exit(0)
'''
Create Training and Validation sets
'''
# Create a list of ints till len of data
data_idx = list(range(len(data)))
np.random.shuffle(data_idx)
def split_dataset(data, verbose=0):
# Create a list of ints till len of data
data_idx = list(range(len(data)))
np.random.shuffle(data_idx)
# Get indexes for validation and train
split_percent = 0.95
num_train = int(len(data) * split_percent)
valid_idx, train_idx = data_idx[num_train:], data_idx[:num_train]
print("Length of train_data: {}".format(len(train_idx)))
print("Length of valid_data: {}".format(len(valid_idx)))
# Get indexes for validation and train
split_percent = 0.05
num_valid = int(len(data) * split_percent)
#num_tests = int(len(data) * split_percent)
#train_idx = data_idx[num_valid:-num_tests]
train_idx = data_idx[num_valid:]
valid_idx = data_idx[:num_valid]
#tests_idx = data_idx[-num_tests:]
if verbose > 0:
print("Length of train_data: {}".format(len(train_idx)))
print("Length of valid_data: {}".format(len(valid_idx)))
#print("Length of tests_data: {}".format(len(tests_idx)))
# Create the training and validation sets, as dataframes
train_data = data.iloc[train_idx].reset_index().drop('index', axis=1)
valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
# Create the training and validation sets, as dataframes
train_data = data.iloc[train_idx].reset_index().drop('index', axis=1)
valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
#tests_data = data.iloc[tests_idx].reset_index().drop('index', axis=1)
#return(train_data, valid_data, tests_data)
return(train_data, valid_data)
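# Rough numbers, assuming the 0.05 validation split above: with 1,000 rows,
# split_dataset() returns about 950 training rows and 50 validation rows.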
'''
@@ -88,21 +92,24 @@ valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
'''
class TextCategoriesDataset(Dataset):
''' Dataset of Text and Categories '''
def __init__(self, df, text_column, cats_column, transform=None):
def __init__(self, df, text_column, cats_column, lang_column, transform=None, verbose=0):
'''
Arguments:
df (panda.Dataframe): csv content, loaded as dataframe
text_column (str): the name of the column containing the text
cats_column (str): the name of the column containing
semicolon-separated categories
lang_column (str): the name of the column containing the language
transform (callable, optional): Optional transform to be
applied on a sample.
'''
self.df = df
self.transform = transform
self.verbose = verbose
self.texts = self.df[text_column]
self.text = self.df[text_column]
self.cats = self.df[cats_column]
self.lang = self.df[lang_column]
# index-to-token dict
# <pad> : padding, used for padding the shorter sentences in a batch
@@ -146,8 +153,9 @@ class TextCategoriesDataset(Dataset):
idx = idx.tolist()
# Get the raw data
text = self.texts[idx]
text = self.text[idx]
cats = self.cats[idx]
lang = self.lang[idx]
if self.transform:
text, cats = self.transform(text, cats)
@@ -186,25 +194,6 @@ class TextCategoriesDataset(Dataset):
T.AddToken(2, begin=False)
)
'''
dataset = TextCategoriesDataset(df=data,
text_column="content",
cats_column="categories",
)
'''
train_dataset = TextCategoriesDataset(df=train_data,
text_column="content",
cats_column="categories",
)
valid_dataset = TextCategoriesDataset(df=valid_data,
text_column="content",
cats_column="categories",
)
#print(dataset[2])
#for text, cat in enumerate(valid_dataset):
# print(text, cat)
#sys.exit(0)
'''
Now that we have a dataset, let's create dataloader,
@@ -232,44 +221,183 @@ class CollateBatch:
)
# Hyperparameters
EPOCHS = 10 # epoch
LR = 5 # learning rate
BATCH_SIZE = 64 # batch size for training
class TextClassificationModel(nn.Module):
def __init__(self, input_size, output_size, verbose):
super().__init__()
# Get cpu, gpu or mps device for training.
# Move tensor to the NVIDIA GPU if available
device = (
def forward(self, x):
return x
def train(dataloader, model, optimizer, criterion, epoch):
model.train()
total_acc, total_count = 0, 0
log_interval = 500
start_time = time.time()
for idx, (label, text) in enumerate(dataloader):
optimizer.zero_grad()
predicted_label = model(text)
loss = criterion(predicted_label, label)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
optimizer.step()
total_acc += (predicted_label.argmax(1) == label).sum().item()
total_count += label.size(0)
if idx % log_interval == 0 and idx > 0:
elapsed = time.time() - start_time
print(
"| epoch {:3d} | {:5d}/{:5d} batches "
"| accuracy {:8.3f}".format(
epoch, idx, len(dataloader), total_acc / total_count
)
)
total_acc, total_count = 0, 0
start_time = time.time()
def evaluate(dataloader, model, criterion):
model.eval()
total_acc, total_count = 0, 0
with torch.no_grad():
for idx, (label, text) in enumerate(dataloader):
predicted_label = model(text)
loss = criterion(predicted_label, label)
total_acc += (predicted_label.argmax(1) == label).sum().item()
total_count += label.size(0)
return total_acc / total_count
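# evaluate() returns plain accuracy: correct argmax predictions divided by the
# total number of labels seen, accumulated over every batch in the dataloader.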
def main():
parser = argparse.ArgumentParser(
description='Classify text data according to categories',
add_help=True,
)
parser.add_argument('action',
help='train or classify')
parser.add_argument('--input', '-i',
required=True,
help='path of CSV file containing dataset')
parser.add_argument('--model', '-m',
#required=True, # XXX
help='path to training model')
parser.add_argument('--verbose', '-v',
type=int, nargs='?',
const=1, # Default value if -v is supplied
default=0, # Default value if -v is not supplied
help='print debugging')
args = parser.parse_args()
if args.action != 'train' and args.action != 'classify':
print("ERROR: train or classify data")
sys.exit(1)
if args.action == 'classify' and (args.model is None or not os.path.isfile(args.model)):
print("No model found for classification; running training instead")
args.action = 'train'
if not os.path.isfile(args.input):
print(f"{args.input} is not a valid file")
sys.exit(1)
data = read_csv(input_csv=args.input, rows=story_num, verbose=args.verbose)
train_data, valid_data = split_dataset(data, verbose=args.verbose)
'''
dataset = TextCategoriesDataset(df=data,
text_column="content",
cats_column="categories",
lang_column="language",
verbose=args.verbose,
)
'''
train_dataset = TextCategoriesDataset(df=train_data,
text_column="content",
cats_column="categories",
lang_column="language",
verbose=args.verbose,
)
valid_dataset = TextCategoriesDataset(df=valid_data,
text_column="content",
cats_column="categories",
lang_column="language",
verbose=args.verbose,
)
#print(dataset[2])
#for text, cat in enumerate(valid_dataset):
# print(text, cat)
#sys.exit(0)
# Get cpu, gpu or mps device for training.
# Move tensor to the NVIDIA GPU if available
device = (
"cuda" if torch.cuda.is_available()
else "xps" if hasattr(torch, "xpu") and torch.xpu.is_available()
else "mps" if torch.backends.mps.is_available()
else "cpu"
)
print(f"Using {device} device")
)
print(f"Using {device} device")
'''
dataloader = DataLoader(dataset,
'''
dataloader = DataLoader(dataset,
batch_size=4,
shuffle=True,
num_workers=0,
collate_fn=CollateBatch(pad_idx=dataset.stoi['<pad>']),
)
'''
train_dataloader = DataLoader(train_dataset,
)
'''
train_dataloader = DataLoader(train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0,
collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
)
valid_dataloader = DataLoader(valid_dataset,
)
valid_dataloader = DataLoader(valid_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=0,
collate_fn=CollateBatch(pad_idx=valid_dataset.stoi['<pad>']),
)
#for i_batch, sample_batched in enumerate(dataloader):
# print(i_batch, sample_batched[0], sample_batched[1])
#sys.exit(0)
)
#for i_batch, sample_batched in enumerate(dataloader):
# print(i_batch, sample_batched[0], sample_batched[1])
#sys.exit(0)
num_class = len(set([cats for key, cats, text, lang in train_data.values]))
input_size = len(train_dataset.text_vocab)
output_size = len(train_dataset.cats_vocab)
emsize = 64
model = TextClassificationModel(input_size, output_size, args.verbose).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
for epoch in range(1, EPOCHS + 1):
epoch_start_time = time.time()
train(train_dataloader, model, optimizer, criterion, epoch)
accu_val = evaluate(valid_dataloader, model, criterion)
if total_accu is not None and total_accu > accu_val:
scheduler.step()
else:
total_accu = accu_val
print("-" * 59)
print(
"| end of epoch {:3d} | time: {:5.2f}s | "
"valid accuracy {:8.3f} ".format(
epoch, time.time() - epoch_start_time, accu_val
)
)
print("-" * 59)
print("Checking the results of test dataset.")
accu_test = evaluate(test_dataloader)
print("test accuracy {:8.3f}".format(accu_test))
return
if __name__ == "__main__":
main()
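# Example invocation (the script name is a placeholder):
#   python classify.py train -i dataset.csv -m model.pt -v 1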