147 lines
4.6 KiB
Executable File
147 lines
4.6 KiB
Executable File
import argparse
import os
import sys
import pprint
import re
import string
from string import digits
import warnings
import html
from xml.etree import ElementTree as ET
# data manipulation libs
import csv
import pandas as pd
from pandarallel import pandarallel
# Command-line interface: the input directory and the output CSV path are both
# mandatory.
parser = argparse.ArgumentParser(
    description='Turn XML data files into a dataset for use with pytorch',
)
parser.add_argument('--output', '-o', required=True, help='path of output CSV file')
parser.add_argument('--input', '-i', required=True, help='path of input directory containing XML files')
args = parser.parse_args()

# Stop right away when the input is unusable: os.walk() yields nothing for a
# missing path, so the problem would otherwise only show up later as a
# failure on an empty dataset.
if not os.path.isdir(args.input):
    print(f"{args.input} is not a directory or does not exist")
    sys.exit(1)
# 1. Load the XML files.
# 2. Create the data structure.
# 3. Preprocess the text: lowercase it and remove punctuation, digits and extra spaces.
#    This helps reduce the vocabulary of the data (e.g. "Cat ~" becomes "cat").
def insert_line_numbers(txt: str) -> str:
    """Return *txt* with every line prefixed by its 1-based line number.

    Lines are split on the newline character only. Each number is zero-padded
    to at least three digits and followed by a single space.
    """
    return "\n".join(
        f"{number:03d} {line}"
        for number, line in enumerate(txt.split("\n"), start=1)
    )
# Vertical tab (0x0B) plus the control characters 0x1A through 0x1F.
_LOCAL_CLEAN_RE = re.compile(r'[\x0b\x1a-\x1f]')


def local_clean(s: str) -> str:
    """Return *s* with the vertical tab and the 0x1A-0x1F control characters removed.

    Every other character, including tab, line feed and carriage return, is
    kept as is.

    NOTE(review): presumably these characters are dropped because they are not
    allowed in XML 1.0 and would make the parser reject the document - confirm.
    """
    return _LOCAL_CLEAN_RE.sub('', s)
# References to the three XML-significant characters (<, >, &) in named,
# decimal and hexadecimal form. The whole reference, including its leading
# "&", sits inside the capturing group so that re.split() hands it back
# unchanged; hexadecimal forms are accepted in either letter case.
_XML_SPECIAL_REFERENCE_RE = re.compile(
    r'(&(?:lt;|#60;|#[xX]3[cC];|gt;|#62;|#[xX]3[eE];|amp;?|#38;|#[xX]26;))'
)


def partial_unescape(s: str) -> str:
    """Unescape HTML character references in *s*, except those for <, > and &.

    References to ``<``, ``>`` and ``&`` are left exactly as written so that
    the result can still be parsed as XML; every other reference that
    ``html.unescape`` understands is replaced by the character it denotes.

    Args:
        s: Text that may contain HTML character references.

    Returns:
        The text with all but the protected references resolved.
    """
    parts = _XML_SPECIAL_REFERENCE_RE.split(s)
    # re.split() with one capturing group alternates between plain text (even
    # indices) and the protected references (odd indices); only the plain text
    # may be unescaped.
    parts[::2] = [html.unescape(chunk) for chunk in parts[::2]]
    return "".join(parts)
# Encodings tried in order for every file, paired with the label under which a
# successful read is tallied. ISO-8859-1 assigns a character to every byte, so
# it acts as the catch-all at the end of the chain.
_ENCODINGS = (("plain", "ASCII"), ("utf8", "UTF-8"), ("iso88591", "ISO-8859-1"))

# Matches paths ending in four-, two-, two- and four-digit components, e.g.
# ".../2022/10/09/0028.aans"; either path separator is accepted.
_KEY_RE = re.compile(r'^.*[\\/](\d{4})[\\/](\d{2})[\\/](\d{2})[\\/](\d{4})\.aans$')


def _read_with_fallback(path):
    """Return ``(text, label)`` for *path*, using the first encoding that decodes it.

    Raises:
        OSError: the file cannot be opened or read.
        UnicodeError: none of the encodings in ``_ENCODINGS`` fits.
    """
    for label, encoding in _ENCODINGS:
        try:
            with open(path, 'r', encoding=encoding) as f:
                return f.read(), label
        except UnicodeDecodeError:
            continue
    raise UnicodeError(f"no supported encoding could decode {path}")


def _parse_article(xml_file, content):
    """Turn the raw text of one XML file into a dataset record.

    Args:
        xml_file: Path of the file; the record key is derived from it.
        content: Decoded text of the file.

    Returns:
        A dict with the keys ``key``, ``categories`` (joined with ``;``) and
        ``content`` (paragraph texts joined with newlines), or ``None`` when
        the article does not have more than one category.

    Raises:
        xml.etree.ElementTree.ParseError: *content* is not well-formed XML.
        ValueError: the document has no ``body`` element.
    """
    doc = ET.fromstring(local_clean(partial_unescape(content)))

    # When the path does not have the expected shape the substitution is a
    # no-op and the full path serves as the key.
    key = _KEY_RE.sub(r'\g<1>\g<2>\g<3>\g<4>', xml_file)

    cats = [cat.text for cat in doc.findall('category') if cat.text is not None]

    body = doc.find('./body')
    if body is None:
        raise ValueError("missing <body> element")
    text = [p.text for p in body if p.text is not None]

    # NOTE(review): the original condition keeps only articles with MORE than
    # one category (`len(cats) > 1`) - confirm that `>= 1` was not intended.
    if len(cats) <= 1:
        return None
    return {"key": key, "categories": ";".join(cats), "content": "\n".join(text)}


articles = list()
total, failed = 0, 0
read_counts = {label: 0 for label, _ in _ENCODINGS}

for root, _dirs, files in os.walk(args.input):
    for file in sorted(files):
        if not file.endswith('.aans'):
            continue
        xml_file = os.path.join(root, file)
        total += 1

        try:
            content, label = _read_with_fallback(xml_file)
        except (OSError, UnicodeError) as e:
            print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {e}")
            failed += 1
            # Without content there is nothing to parse; moving on also keeps
            # the previous file's text from being processed a second time.
            continue
        read_counts[label] += 1

        try:
            entry = _parse_article(xml_file, content)
        except ET.ParseError as e:
            print("Parse error in " + xml_file + " : ", e)
            continue
        except Exception as e:
            # Per-file boundary: one malformed article is reported and
            # skipped instead of aborting the whole run.
            print(f"{xml_file} : {e}")
            continue

        if entry is not None:
            articles.append(entry)

plain, utf8, iso88591 = (read_counts[label] for label, _ in _ENCODINGS)
print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(total, plain, utf8, iso88591, failed))
# Deletes every ASCII punctuation character and every digit in a single pass.
_STRIP_TABLE = str.maketrans('', '', string.punctuation + string.digits)


def _normalise_text(text):
    """Lowercase *text*, drop punctuation and digits, and tidy up the spaces.

    Leading and trailing whitespace is stripped and every run of spaces is
    collapsed into one space; newlines inside the text are kept.
    """
    stripped = text.lower().translate(_STRIP_TABLE).strip()
    return re.sub(" +", " ", stripped)


# An empty list would produce a frame without a "key" column and make
# set_index() fail with a KeyError, so report the situation explicitly.
if not articles:
    print("no articles were collected; nothing to write")
    sys.exit(1)

data = pd.DataFrame(articles)
data.set_index("key", inplace=True)

# Initialization: pandarallel only adds parallel_apply() to pandas objects
# once initialize() has been called.
pandarallel.initialize()

# Lowercase everything, remove special characters and digits, remove extra
# spaces - done in one parallel pass instead of one pass per step.
data['content'] = data.content.parallel_apply(_normalise_text)

# Handing the path to pandas lets it open the file with the requested encoding
# and the newline handling the CSV writer expects.
data.to_csv(args.output, encoding="utf-8", quoting=csv.QUOTE_ALL)