#!/usr/bin/python
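"""Turn a directory tree of *.aans XML files into a single CSV dataset.

Each file is read (trying ASCII, then UTF-8, then ISO-8859-1), lightly cleaned,
parsed with ElementTree, and reduced to one record: a key derived from the file
path, the ";"-joined <category> texts, and the concatenated paragraph text of
<body>. The content is then lowercased and stripped of punctuation, digits and
extra whitespace before being written out with pandas, for use with pytorch.
"""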
import argparse
import os
import sys
import pprint
import re
import string
from string import digits
import warnings

import html
from xml.etree import ElementTree as ET

# data manipulation libs
import csv
import pandas as pd
from pandarallel import pandarallel

parser = argparse.ArgumentParser(
    description='Turn XML data files into a dataset for use with pytorch',
    add_help=True,
)
parser.add_argument('--output', '-o', required=True, help='path of output CSV file')
parser.add_argument('--input', '-i', required=True, help='path of input directory containing XML files')
args = parser.parse_args()
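# Example invocation (script and path names are illustrative):
#   python make_dataset.py --input /path/to/xml/tree --output dataset.csv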
if not os.path.isdir(args.input):
    print(f"{args.input} is not a directory or does not exist")
    sys.exit(1)

# 1. Load each XML file
# 2. Create the record structure (key, categories, content)
# 3. Preprocess the text to remove punctuation, digits and extra spaces, and make it lowercase.
#    This helps reduce the vocab of the data (as now, "Cat ~" is "cat")
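# Prefix each line of txt with a zero-padded three-digit line number,
# e.g. insert_line_numbers("a\nb") -> "001 a\n002 b"
# (used below to dump unparseable XML with line references).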
def insert_line_numbers(txt):
    return "\n".join([f"{n+1:03d} {line}" for n, line in enumerate(txt.split("\n"))])

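# Strip control characters (\x0b, \x1a-\x1f) that are not allowed in XML 1.0
# and would otherwise make ElementTree reject a document as not well-formed.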
def local_clean(s):
    s = re.sub('[\x0b\x1a-\x1f]', '', s)
    return s

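# Unescape HTML entities everywhere except the markup-significant ones
# (&lt; &gt; &amp; and their numeric forms), so the markup itself stays intact,
# e.g. partial_unescape("&eacute;t&eacute; &lt;b&gt;") -> "été &lt;b&gt;"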
def partial_unescape(s):
    # keep the '&' inside the captured group so these entities survive the re-join intact
    parts = re.split(r'(&(?:lt;|#60;|#x3c;|gt;|#62;|#x3e;|amp;?|#38;|#x26;))', s)
    for i, part in enumerate(parts):
        if i % 2 == 0:
            parts[i] = html.unescape(parts[i])
    return "".join(parts)

articles = list()
#allCats = list()

total, plain, utf8, iso88591, failed = 0, 0, 0, 0, 0
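# Walk the input tree in a stable order; each *.aans file is read as ASCII first,
# then UTF-8, then ISO-8859-1, counting which encoding finally worked.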
for root, dirs, files in os.walk(args.input):
    dirs.sort()
    print(root)
    for file in sorted(files):
        #if re.search('2022\/10\/09', root) and re.search('0028.aans$', file):
        if re.search(r'\.aans$', file):
            xml_file = os.path.join(root, file)
            total += 1
            try:
                with open(xml_file, 'r', encoding="ASCII") as f:
                    content = f.read()
                #print(f"ASCII read succeeded in {xml_file}")
                plain += 1
            except Exception as e:
                #print(f"ASCII read failed, trying UTF-8 in {xml_file} : {e}")
                try:
                    with open(xml_file, 'r', encoding="UTF-8") as f:
                        content = f.read()
                    #print(f"UTF-8 read succeeded in {xml_file}")
                    utf8 += 1
                except Exception as e:
                    #print(f"UTF-8 read failed, trying ISO-8859-1 in {xml_file} : {e}")
                    try:
                        with open(xml_file, 'r', encoding="ISO-8859-1") as f:
                            content = f.read()
                        #print(f"ISO-8859-1 read succeeded in {xml_file}")
                        iso88591 += 1
                    except Exception as e:
                        print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {e}")
                        #print(content)  # 'content' would be stale or undefined at this point
                        failed += 1
                        continue  # nothing could be read for this file, so skip it
            content = partial_unescape(content)
            content = local_clean(content)
            #print(content)

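            # Build the record key from the date and sequence number in the file path,
            # e.g. .../2022/10/09/0028.aans -> 202210090028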
            key = re.sub(r'^.*/(\d{4})/(\d{2})/(\d{2})/(\d{4})\.aans$', r'\g<1>\g<2>\g<3>\g<4>', xml_file)
            try:
                doc = ET.fromstring(content)
                entry = dict()
                entry["key"] = key
                cats = list()
                for cat in doc.findall('category'):
                    #if cat not in allCats:
                    #    allCats.append(cat)
                    cats.append(cat.text)
                #entry["categories"] = cats
                entry["categories"] = ";".join(cats)
                text = list()
                try:
                    #text = "\n".join([p.text for p in doc.find('./body')])
                    for p in doc.find('./body'):
                        if p.text is not None:
                            text.append(p.text)
                    if len(text) > 0 and len(cats) > 1:  # keep only articles with body text and more than one category
                        entry["content"] = "\n".join(text)
                        articles.append(entry)
                except Exception as e:
                    print(f"{xml_file} : {e}")
            except ET.ParseError as e:
                print(insert_line_numbers(content))
                print("Parse error in " + xml_file + " : ", e)
                raise SystemExit

print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(total, plain, utf8, iso88591, failed))
|
||
|
|
||
|
#sys.exit(0)
|
||
|
|
||
|
data = pd.DataFrame(articles)
data.set_index("key", inplace=True)

#print(data.categories)

# Initialization
pandarallel.initialize()
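# pandarallel's parallel_apply mirrors Series.apply but spreads the work across CPU cores.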
# Lowercase everything
data['content'] = data.content.parallel_apply(lambda x: x.lower())

# Remove special characters
exclude = set(string.punctuation) #set of all special chars
data['content'] = data.content.parallel_apply(lambda x: ''.join(ch for ch in x if ch not in exclude))

# Remove digits
remove_digits = str.maketrans('', '', digits)
data['content'] = data.content.parallel_apply(lambda x: x.translate(remove_digits))

# Remove extra spaces
data['content'] = data.content.parallel_apply(lambda x: x.strip())
data['content'] = data.content.parallel_apply(lambda x: re.sub(" +", " ", x))

with open(args.output, 'w', encoding="utf-8", newline='') as f:  # newline='' so to_csv controls line endings
    data.to_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
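# The resulting CSV has one row per article: the key as index, the ";"-joined
# categories and the normalized content; QUOTE_ALL keeps embedded newlines safe inside quoted fields.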