africat/africat/aa_create_dataset.py

#!/usr/bin/python
'''
1. Load the XML files
2. Create the data structure
3. Preprocess the data to remove punctuation, digits and extra spaces, and to lowercase the text.
   This helps reduce the vocabulary of the data (as now, "Cat ~" is "cat").
'''
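# Illustrative effect of the scrub_data() cleaning below (the headline is a
# made-up example): lowercasing, punctuation/digit removal and whitespace
# squeezing turn
#   "Cat ~ spotted near N1, 3 Oct!"
# into
#   "cat spotted near n oct"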
import argparse
import os
import sys
import pprint
import re
import string
from string import digits
import warnings
import html
from xml.etree import ElementTree as ET
# data manipulation libs
import csv
import pandas as pd
from pandarallel import pandarallel

def write_csv(data, output):
  with open(output, 'w', encoding="utf-8") as f:
    data.to_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)

def insert_line_numbers(txt):
  return "\n".join([f"{n+1:03d} {line}" for n, line in enumerate(txt.split("\n"))])

def local_clean(s):
  s = re.sub('[\x0b\x1a-\x1f]', '', s)
  return s

def partial_unescape(s):
  # Unescape HTML entities, but keep <, > and & escaped so the XML stays parseable.
  parts = re.split(r'(&(?:lt;|#60;|#x3c;|gt;|#62;|#x3e;|amp;?|#38;|#x26;))', s)
  for i, part in enumerate(parts):
    if i % 2 == 0:
      parts[i] = html.unescape(parts[i])
  return "".join(parts)
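# Illustrative behaviour of partial_unescape() (the input is a made-up
# fragment): XML-significant entities stay escaped while the others are decoded,
#   partial_unescape('&lt;b&gt;caf&eacute;&lt;/b&gt;')
#   -> '&lt;b&gt;café&lt;/b&gt;'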

def parse_and_extract(input_dir, verbose):
  articles = list()
  total, plain, utf8, iso88591, failed = 0, 0, 0, 0, 0
  for root, dirs, files in os.walk(input_dir):
    dirs.sort()
    if verbose > 0:
      print(root)
    for file in sorted(files):
      #if re.search('2022\/10\/09', root) and re.search('0028.aans$', file):
      if re.search(r'\.aans$', file):
        xml_file = os.path.join(root, file)
        total += 1
        try:
          with open(xml_file, 'r', encoding="ASCII") as f:
            content = f.read()
          if verbose > 1:
            print(f"ASCII read succeeded in {xml_file}")
          plain += 1
        except Exception as e:
          if verbose > 1:
            print(f"ASCII read failed, trying UTF-8 in {xml_file} : {e}")
          try:
            with open(xml_file, 'r', encoding="UTF-8") as f:
              content = f.read()
            if verbose > 1:
              print(f"UTF-8 read succeeded in {xml_file}")
            utf8 += 1
          except Exception as e:
            if verbose > 1:
              print(f"UTF-8 read failed, trying ISO-8859-1 in {xml_file} : {e}")
            try:
              with open(xml_file, 'r', encoding="ISO-8859-1") as f:
                content = f.read()
              if verbose > 1:
                print(f"ISO-8859-1 read succeeded in {xml_file}")
              iso88591 += 1
            except Exception as e:
              print(f"ASCII, UTF-8 and ISO-8859-1 reads all failed in {xml_file} : {e}")
              if verbose > 2:
                print(content)
              failed += 1
              continue  # nothing readable for this file, skip it
        content = partial_unescape(content)
        content = local_clean(content)
        if verbose > 3:
          print(content)
        # build a YYYYMMDDNNNN key from the .../YYYY/MM/DD/NNNN.aans path
        key = re.sub(r'^.*/(\d{4})/(\d{2})/(\d{2})/(\d{4})\.aans$', r'\g<1>\g<2>\g<3>\g<4>', xml_file)
        try:
          doc = ET.fromstring(content)
          entry = dict()
          entry["key"] = key
          cats = list()
          for cat in doc.findall('./category'):
            cats.append(cat.text)
          text = list()
          lang = ""
          try:
            for t in doc.find('./publisher_headline'):
              if t.text is not None:
                text.append(t.text)
            for p in doc.find('./body'):
              if p.text is not None:
                text.append(p.text)
            lang = doc.find('./language').text
          except Exception as e:
            print(f"{xml_file} : {e}")
          if len(text) > 0 and len(cats) >= 1:
            entry["language"] = lang
            entry["content"] = "\n".join(text)
            for cat in cats:
              entry[cat] = 1
            articles.append(entry)
          else:
            if len(cats) < 1:
              print(f"No article added for key {key} due to lack of categories")
            elif len(text) < 1:
              print(f"No article added for key {key} due to lack of text")
            else:
              print(f"No article added for key {key} due to unknown error")
        except ET.ParseError as e:
          if verbose > 1:
            print(insert_line_numbers(content))
          print("Parse error in " + xml_file + " : ", e)
          raise SystemExit
  if verbose > 0:
    print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(total, plain, utf8, iso88591, failed))
  #sys.exit(0)
  return articles
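# Shape of the records returned by parse_and_extract(), shown with made-up
# values (real category names depend on the corpus):
#   {'key': '202210090028', 'language': 'en', 'content': 'headline\nbody ...',
#    'Sport': 1, 'Politics': 1}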

def scrub_data(articles, verbose):
  data = pd.DataFrame(articles)
  data.set_index("key", inplace=True)
  #if verbose > 2:
  #  print(data.categories)
  # Initialization
  pandarallel.initialize()
  # Lowercase everything
  data['content'] = data.content.parallel_apply(lambda x: x.lower())
  # Remove special characters
  exclude = set(string.punctuation)  # set of all special chars
  data['content'] = data.content.parallel_apply(lambda x: ''.join(ch for ch in x if ch not in exclude))
  # Remove digits
  remove_digits = str.maketrans('', '', digits)
  data['content'] = data.content.parallel_apply(lambda x: x.translate(remove_digits))
  # Remove extra spaces
  data['content'] = data.content.parallel_apply(lambda x: x.strip())
  data['content'] = data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
  # Any remaining text processing can be done by the training/inference step
  # sort category columns: lowercase first (key, language, content), then title-cased categories
  data = data.reindex(columns=sorted(data.columns, key=lambda x: (x.casefold(), x.swapcase())))
  return data
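# Note on the resulting frame: besides 'language' and 'content' there is one
# column per category, holding 1 where the article carries that label and NaN
# where it does not; the 'key' index becomes the first CSV column on output.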

def main():
  parser = argparse.ArgumentParser(
    description='Turn XML data files into a dataset for use with pytorch',
    add_help=True,
  )
  parser.add_argument('--output', '-o',
                      required=True,
                      help='path of output CSV file')
  parser.add_argument('--input', '-i',
                      required=True,
                      help='path of input directory containing XML files')
  parser.add_argument('--verbose', '-v',
                      type=int, nargs='?',
                      const=1,    # Default value if -v is supplied
                      default=0,  # Default value if -v is not supplied
                      help='print debugging')
  args = parser.parse_args()
  if os.path.isdir(args.input) is False:
    print(f"{args.input} is not a directory or does not exist")
    sys.exit(1)
  articles = parse_and_extract(args.input, args.verbose)
  data = scrub_data(articles, args.verbose)
  #print(data)
  write_csv(data, args.output)
  return


if __name__ == "__main__":
  main()
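# Example invocation (paths are illustrative):
#   ./aa_create_dataset.py --input /data/aans --output dataset.csv -v 1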
# vim: set expandtab shiftwidth=2 softtabstop=2: