africat/aa_create_dataset.py

#!/usr/bin/python
import argparse
import os
import sys
import pprint
import re
import string
from string import digits
import warnings
import html
from xml.etree import ElementTree as ET
# data manipulation libs
import csv
import pandas as pd
from pandarallel import pandarallel
parser = argparse.ArgumentParser(
    description='Turn XML data files into a dataset for use with pytorch',
    add_help=True,
)
parser.add_argument('--output', '-o', required=True, help='path of output CSV file')
parser.add_argument('--input', '-i', required=True, help='path of input directory containing XML files')
args = parser.parse_args()
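# Example invocation (illustrative paths):
#   ./aa_create_dataset.py --input articles_xml/ --output articles.csv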
if not os.path.isdir(args.input):
    print(f"{args.input} is not a directory or does not exist")
    sys.exit(1)
#1. Load the XML files
#2. Create the data structure
#3. Preprocess the data: remove punctuation, digits and extra spaces, and lowercase the text.
#   This helps reduce the vocabulary of the data (e.g. "Cat ~" becomes "cat")
def insert_line_numbers(txt):
    return "\n".join([f"{n+1:03d} {line}" for n, line in enumerate(txt.split("\n"))])
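# e.g. insert_line_numbers("a\nb") -> "001 a\n002 b"; used when dumping files that fail to parse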
def local_clean(s):
    # Strip stray control characters that are not legal in XML 1.0
    s = re.sub('[\x0b\x1a-\x1f]', '', s)
    return s
def partial_unescape(s):
    # Unescape HTML entities, except the markup-significant ones
    # (&lt;, &gt;, &amp; and their numeric forms), which must stay escaped
    # so the document remains well-formed XML.  The whole entity is captured
    # so that re.split() keeps it intact in the odd-numbered slots.
    parts = re.split(r'(&(?:lt;|#60;|#x3c;|gt;|#62;|#x3e;|amp;?|#38;|#x26;))', s)
    for i, part in enumerate(parts):
        if i % 2 == 0:
            parts[i] = html.unescape(parts[i])
    return "".join(parts)
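# Illustrative example (made-up string): partial_unescape("Tom &amp; Jerry &eacute;clair &lt;b&gt;")
# returns "Tom &amp; Jerry éclair &lt;b&gt;" -- &eacute; is resolved, while &amp;, &lt; and &gt;
# stay escaped so that ET.fromstring() below still sees well-formed markup.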
articles = list()
#allCats = list()
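# Counters for how many files could be read as ASCII, UTF-8, ISO-8859-1, or not at all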
total, plain, utf8, iso88591, failed = 0, 0, 0, 0, 0
for root, dirs, files in os.walk(args.input):
    dirs.sort()
    print(root)
    for file in sorted(files):
        #if re.search('2022\/10\/09', root) and re.search('0028.aans$', file):
        if re.search(r'\.aans$', file):
            xml_file = os.path.join(root, file)
            total += 1
            try:
                with open(xml_file, 'r', encoding="ASCII") as f:
                    content = f.read()
                #print(f"ASCII read succeeded in {xml_file}")
                plain += 1
            except Exception as e:
                #print(f"ASCII read failed, trying UTF-8 in {xml_file} : {e}")
                try:
                    with open(xml_file, 'r', encoding="UTF-8") as f:
                        content = f.read()
                    #print(f"UTF-8 read succeeded in {xml_file}")
                    utf8 += 1
                except Exception as e:
                    #print(f"UTF-8 read failed, trying ISO-8859-1 in {xml_file} : {e}")
                    try:
                        with open(xml_file, 'r', encoding="ISO-8859-1") as f:
                            content = f.read()
                        #print(f"ISO-8859-1 read succeeded in {xml_file}")
                        iso88591 += 1
                    except Exception as e:
                        print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {e}")
                        print(content)
                        failed += 1
            content = partial_unescape(content)
            content = local_clean(content)
            #print(content)
            # Build the article key from the path: .../YYYY/MM/DD/NNNN.aans -> YYYYMMDDNNNN
            key = re.sub(r'^.*/(\d{4})/(\d{2})/(\d{2})/(\d{4})\.aans$', r'\g<1>\g<2>\g<3>\g<4>', xml_file)
            try:
                doc = ET.fromstring(content)
                entry = dict()
                entry["key"] = key
                cats = list()
                for cat in doc.findall('category'):
                    #if cat not in allCats:
                    #    allCats.append(cat)
                    cats.append(cat.text)
                #entry["categories"] = cats
                entry["categories"] = ";".join(cats)
                text = list()
                try:
                    #text = "\n".join([p.text for p in doc.find('./body')])
                    for p in doc.find('./body'):
                        if p.text is not None:
                            text.append(p.text)
                    # Keep only articles that have body text and more than one category
                    if text and len(cats) > 1:
                        entry["content"] = "\n".join(text)
                        articles.append(entry)
                except Exception as e:
                    print(f"{xml_file} : {e}")
            except ET.ParseError as e:
                print(insert_line_numbers(content))
                print("Parse error in " + xml_file + " : ", e)
                raise SystemExit
print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(total, plain, utf8, iso88591, failed))
#sys.exit(0)
data = pd.DataFrame(articles)
data.set_index("key", inplace=True)
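# "key" (YYYYMMDDNNNN) becomes the DataFrame index, so it is written as the first CSV column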
#print(data.categories)
# Initialization
pandarallel.initialize()
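# After initialize(), .parallel_apply() below is pandarallel's drop-in, multi-process version of .apply()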
# Lowercase everything
data['content'] = data.content.parallel_apply(lambda x: x.lower())
# Remove special characters
exclude = set(string.punctuation) #set of all special chars
data['content'] = data.content.parallel_apply(lambda x: ''.join(ch for ch in x if ch not in exclude))
# Remove digits
remove_digits = str.maketrans('','',digits)
data['content'] = data.content.parallel_apply(lambda x: x.translate(remove_digits))
# Remove extra spaces
data['content'] = data.content.parallel_apply(lambda x: x.strip())
data['content'] = data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
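# Worked example of the chain above: "Cat ~ 123!" ends up as "cat"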
with open(args.output, 'w', encoding="utf-8") as f:
    data.to_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
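# Each CSV row is one article: key as the index, categories joined with ";" and the cleaned content, all fields quoted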