#!/usr/bin/python
"""Turn a tree of ``*.aans`` XML article files into a CSV dataset for pytorch.

Pipeline:
  1. Walk the input directory and load each ``*.aans`` XML file, trying
     ASCII, then UTF-8, then ISO-8859-1.
  2. Partially unescape HTML entities (keeping the XML-significant ones
     escaped so the document still parses) and strip control characters.
  3. Extract a date-based key, the category list and the body text.
  4. Preprocess the text (lowercase; remove punctuation, digits and extra
     spaces) to reduce the vocabulary (so "Cat ~" becomes "cat").
  5. Write the result as a fully-quoted CSV file.
"""

import argparse
import os
import sys
import pprint
import re
import string
from string import digits
import warnings

import html
from xml.etree import ElementTree as ET

# data manipulation libs
import csv
import pandas as pd

# Captures the *entire* entity (including the '&') so rejoining preserves it.
# The original pattern r'&(lt;|...)' captured only the name, silently dropping
# the '&' on re-join and corrupting "&lt;" into "lt;".
_KEEP_ESCAPED_RE = re.compile(
    r'(&(?:lt;|#60;|#x3c;|gt;|#62;|#x3e;|amp;?|#38;|#x26;))'
)

# Stray control characters that are illegal in XML 1.0 text.
_CTRL_CHARS_RE = re.compile('[\x0b\x1a-\x1f]')

# .../YYYY/MM/DD/NNNN.aans  ->  YYYYMMDDNNNN
_KEY_RE = re.compile(r'^.*/(\d{4})/(\d{2})/(\d{2})/(\d{4})\.aans$')


def insert_line_numbers(txt):
    """Return *txt* with a zero-padded 3-digit line number prefixed to each line."""
    return "\n".join(f"{n + 1:03d} {line}" for n, line in enumerate(txt.split("\n")))


def local_clean(s):
    """Strip control characters that would make the XML parser choke."""
    return _CTRL_CHARS_RE.sub('', s)


def partial_unescape(s):
    """Unescape HTML entities except the XML-significant ones.

    ``&lt;``, ``&gt;``, ``&amp;`` (and their numeric forms) are kept escaped
    so the text remains parseable by ElementTree; everything else
    (e.g. ``&#233;``) is unescaped.
    """
    parts = _KEEP_ESCAPED_RE.split(s)
    # Even indices are the text between entities; odd indices are the
    # preserved entities themselves.
    for i in range(0, len(parts), 2):
        parts[i] = html.unescape(parts[i])
    return "".join(parts)


def _read_file(xml_file, stats):
    """Read *xml_file* trying ASCII, UTF-8 then ISO-8859-1.

    Increments the matching counter in *stats* and returns the decoded text,
    or returns None (counting a failure) when every encoding fails.
    """
    last_error = None
    for encoding, counter in (("ASCII", "plain"),
                              ("UTF-8", "utf8"),
                              ("ISO-8859-1", "iso88591")):
        try:
            with open(xml_file, 'r', encoding=encoding) as f:
                text = f.read()
            stats[counter] += 1
            return text
        except Exception as e:  # decode errors mostly; try the next encoding
            last_error = e
    print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {last_error}")
    stats['failed'] += 1
    return None


def _parse_article(xml_file, content):
    """Parse one cleaned XML document into an article dict, or None to skip it.

    Aborts the whole run (SystemExit) on an XML parse error, dumping the
    offending document with line numbers first.
    """
    try:
        doc = ET.fromstring(content)
    except ET.ParseError as e:
        print(insert_line_numbers(content))
        print("Parse error in " + xml_file + " : ", e)
        raise SystemExit

    entry = {"key": _KEY_RE.sub(r'\g<1>\g<2>\g<3>\g<4>', xml_file)}
    # cat.text can be None for an empty element; coerce so join cannot crash.
    cats = [cat.text or '' for cat in doc.findall('category')]
    entry["categories"] = ";".join(cats)
    try:
        # doc.find('./body') may be None; the TypeError is caught below,
        # matching the original best-effort behavior.
        paragraphs = [p.text for p in doc.find('./body') if p.text is not None]
        # NOTE(review): only articles with two or more categories are kept —
        # confirm '> 1' (rather than '>= 1') is intentional.
        if len(cats) > 1:
            entry["content"] = "\n".join(paragraphs)
            return entry
    except Exception as e:
        print(f"{xml_file} : {e}")
    return None


def _collect_articles(input_dir):
    """Walk *input_dir* and return (articles, stats) for every ``*.aans`` file."""
    articles = []
    stats = {'total': 0, 'plain': 0, 'utf8': 0, 'iso88591': 0, 'failed': 0}
    for root, dirs, files in os.walk(input_dir):
        dirs.sort()  # deterministic traversal order
        print(root)
        for name in sorted(files):
            if not name.endswith('.aans'):
                continue
            xml_file = os.path.join(root, name)
            stats['total'] += 1
            content = _read_file(xml_file, stats)
            if content is None:
                # Previously the loop fell through and processed stale (or
                # undefined) content from the prior iteration; skip instead.
                continue
            content = local_clean(partial_unescape(content))
            entry = _parse_article(xml_file, content)
            if entry is not None:
                articles.append(entry)
    return articles, stats


def _preprocess(data):
    """Lowercase content and strip punctuation/digits/extra spaces, in place."""
    # Third-party; imported lazily so the module stays importable without it.
    from pandarallel import pandarallel

    pandarallel.initialize()

    # Lowercase everything
    data['content'] = data.content.parallel_apply(lambda x: x.lower())

    # Remove special characters
    exclude = set(string.punctuation)  # set of all special chars
    data['content'] = data.content.parallel_apply(
        lambda x: ''.join(ch for ch in x if ch not in exclude))

    # Remove digits
    remove_digits = str.maketrans('', '', digits)
    data['content'] = data.content.parallel_apply(lambda x: x.translate(remove_digits))

    # Remove extra spaces
    data['content'] = data.content.parallel_apply(lambda x: x.strip())
    data['content'] = data.content.parallel_apply(lambda x: re.sub(" +", " ", x))


def main():
    parser = argparse.ArgumentParser(
        description='Turn XML data files into a dataset for use with pytorch',
        add_help=True,
    )
    parser.add_argument('--output', '-o', required=True,
                        help='path of output CSV file')
    parser.add_argument('--input', '-i', required=True,
                        help='path of input directory containing XML files')
    args = parser.parse_args()

    if not os.path.isdir(args.input):
        print(f"{args.input} is not a directory or does not exist")
        sys.exit(1)

    articles, stats = _collect_articles(args.input)
    print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(
        stats['total'], stats['plain'], stats['utf8'],
        stats['iso88591'], stats['failed']))

    if not articles:
        # set_index would otherwise fail on an empty frame with a cryptic error.
        print("no articles found; nothing to write")
        sys.exit(1)

    data = pd.DataFrame(articles)
    data.set_index("key", inplace=True)

    _preprocess(data)

    with open(args.output, 'w', encoding="utf-8") as f:
        data.to_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)


if __name__ == "__main__":
    main()