Convert a bunch of XML files into a CSV dataset
This commit is contained in:
parent
fcee47be08
commit
f60aeb0afe
146
aa_create_dataset.py
Executable file
146
aa_create_dataset.py
Executable file
@ -0,0 +1,146 @@
|
||||
#!/usr/bin/python
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import pprint
|
||||
import re
|
||||
import string
|
||||
from string import digits
|
||||
import warnings
|
||||
|
||||
import html
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
#data manupulation libs
|
||||
import csv
|
||||
import pandas as pd
|
||||
from pandarallel import pandarallel
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Turn XML data files into a dataset for use with pytorch',
|
||||
add_help=True,
|
||||
)
|
||||
parser.add_argument('--output', '-o', required=True, help='path of output CSV file')
|
||||
parser.add_argument('--input', '-i', required=True, help='path of input directory containing XML files')
|
||||
args = parser.parse_args()
|
||||
|
||||
if os.path.isdir(args.input) is False:
|
||||
print(f"{args.input} is not a directory or does not exist");
|
||||
sys.exit(1)
|
||||
|
||||
#1. Load XML file
|
||||
#2. Create structure
|
||||
#3. Preprocess the data to remove punctuations, digits, spaces and making the text lower.
|
||||
#. This helps reduce the vocab of the data (as now, "Cat ~" is "cat")
|
||||
|
||||
def insert_line_numbers(txt):
|
||||
return "\n".join([f"{n+1:03d} {line}" for n, line in enumerate(txt.split("\n"))])
|
||||
|
||||
def local_clean(s):
|
||||
s = re.sub('[\x0b\x1a-\x1f]', '', s)
|
||||
return s
|
||||
|
||||
def partial_unescape(s):
|
||||
parts = re.split(r'&(lt;|#60;|#x3c;|gt;|#62;|#x3e;|amp;?|#38;|#x26;)', s)
|
||||
for i, part in enumerate(parts):
|
||||
if i % 2 == 0:
|
||||
parts[i] = html.unescape(parts[i])
|
||||
return "".join(parts)
|
||||
|
||||
articles = list()
|
||||
#allCats = list()
|
||||
|
||||
total, plain, utf8, iso88591, failed = 0, 0, 0, 0, 0
|
||||
for root, dirs, files in os.walk(args.input):
|
||||
dirs.sort()
|
||||
print(root)
|
||||
for file in sorted(files):
|
||||
#if re.search('2022\/10\/09', root) and re.search('0028.aans$', file):
|
||||
if re.search('.aans$', file):
|
||||
xml_file = os.path.join(root, file)
|
||||
total += 1
|
||||
try:
|
||||
with open(xml_file, 'r', encoding="ASCII") as f:
|
||||
content = f.read()
|
||||
#print(f"ASCII read succeeded in {xml_file}")
|
||||
plain += 1
|
||||
except Exception as e:
|
||||
#print(f"ASCII read failed, trying UTF-8 in {xml_file} : {e}")
|
||||
try:
|
||||
with open(xml_file, 'r', encoding="UTF-8") as f:
|
||||
content = f.read()
|
||||
#print(f"UTF-8 read succeeded in {xml_file}")
|
||||
utf8 += 1
|
||||
except Exception as e:
|
||||
#print(f"UTF-8 read failed, trying ISO-8859-1 in {xml_file} : {e}")
|
||||
try:
|
||||
with open(xml_file, 'r', encoding="ISO-8859-1") as f:
|
||||
content = f.read()
|
||||
#print(f"ISO-8859-1 read succeeded in {xml_file}")
|
||||
iso88591 += 1
|
||||
except Exception as e:
|
||||
print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {e}")
|
||||
print(content)
|
||||
failed += 1
|
||||
content = partial_unescape(content)
|
||||
content = local_clean(content)
|
||||
#print(content)
|
||||
|
||||
key = re.sub('^.*\/(\d{4})\/(\d{2})\/(\d{2})\/(\d{4}).aans$', '\g<1>\g<2>\g<3>\g<4>', xml_file)
|
||||
try:
|
||||
doc = ET.fromstring(content)
|
||||
entry = dict()
|
||||
entry["key"] = key
|
||||
cats = list()
|
||||
for cat in doc.findall('category'):
|
||||
#if cat not in allCats:
|
||||
# allCats.append(cat)
|
||||
cats.append(cat.text)
|
||||
#entry["categories"] = cats
|
||||
entry["categories"] = ";".join(cats)
|
||||
text = list()
|
||||
try:
|
||||
#text = "\n".join([p.text for p in doc.find('./body')])
|
||||
for p in doc.find('./body'):
|
||||
if p.text is not None:
|
||||
text.append(p.text)
|
||||
if text is not None and len(cats) > 1:
|
||||
entry["content"] = "\n".join(text)
|
||||
articles.append(entry)
|
||||
except Exception as e:
|
||||
print(f"{xml_file} : {e}")
|
||||
except ET.ParseError as e:
|
||||
print(insert_line_numbers(content))
|
||||
print("Parse error in " + xml_file + " : ", e)
|
||||
raise(SystemExit)
|
||||
|
||||
print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(total, plain, utf8, iso88591, failed))
|
||||
|
||||
#sys.exit(0)
|
||||
|
||||
data = pd.DataFrame(articles)
|
||||
data.set_index("key", inplace=True)
|
||||
|
||||
#print(data.categories)
|
||||
|
||||
# Initialization
|
||||
pandarallel.initialize()
|
||||
|
||||
# Lowercase everything
|
||||
data['content'] = data.content.parallel_apply(lambda x: x.lower())
|
||||
|
||||
# Remove special characters
|
||||
exclude = set(string.punctuation) #set of all special chars
|
||||
data['content'] = data.content.parallel_apply(lambda x: ''.join(ch for ch in x if ch not in exclude))
|
||||
|
||||
# Remove digits
|
||||
remove_digits = str.maketrans('','',digits)
|
||||
data['content'] = data.content.parallel_apply(lambda x: x.translate(remove_digits))
|
||||
|
||||
# Remove extra spaces
|
||||
data['content']=data.content.parallel_apply(lambda x: x.strip())
|
||||
data['content']=data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
|
||||
|
||||
with open(args.output, 'w', encoding="utf-8") as f:
|
||||
data.to_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
|
Loading…
Reference in New Issue
Block a user