#!/usr/bin/python
"""Turn a directory tree of ``.aans`` XML article files into a CSV dataset.

Pipeline:

1. Walk the input directory and read every ``*.aans`` file, trying ASCII,
   then UTF-8, then ISO-8859-1.
2. Parse each file as XML and build one record per article
   (``key``, ``categories``, ``content``).
3. Normalise the text: lower-case it and remove punctuation, digits and
   redundant spaces.  This reduces the vocabulary of the data
   (e.g. "Cat ~" becomes "cat").
4. Write the records to a fully quoted CSV file.
"""
import argparse
import os
import sys
import pprint
import re
import string
from string import digits
import warnings
import html
from typing import Dict, List, Optional, Tuple
from xml.etree import ElementTree as ET

# data manipulation libs
import csv
import pandas as pd

# Encodings tried in order for every file, paired with the name of the
# statistics counter that is incremented when that encoding succeeds.
ENCODINGS = (
    ("ASCII", "plain"),
    ("UTF-8", "utf8"),
    ("ISO-8859-1", "iso88591"),
)

# Control characters deleted by local_clean() before XML parsing.
_CONTROL_CHARS = re.compile(r'[\x0b\x1a-\x1f]')

# Entities that must stay escaped so the document remains well-formed XML.
# The whole entity, including the leading "&", is captured so that
# re.split() hands it back instead of discarding it.
_PROTECTED_ENTITY = re.compile(
    r'(&(?:lt;|#60;|#x3c;|gt;|#62;|#x3e;|amp;?|#38;|#x26;))'
)

# <anything>/YYYY/MM/DD/NNNN.aans -> the four digit groups.
_KEY_PATTERN = re.compile(r'^.*/(\d{4})/(\d{2})/(\d{2})/(\d{4})\.aans$')

# Deletes ASCII punctuation and ASCII digits in a single pass.
_STRIP_TABLE = str.maketrans('', '', string.punctuation + digits)


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser (``--output/-o`` and ``--input/-i``)."""
    parser = argparse.ArgumentParser(
        description='Turn XML data files into a dataset for use with pytorch',
        add_help=True,
    )
    parser.add_argument('--output', '-o', required=True,
                        help='path of output CSV file')
    parser.add_argument('--input', '-i', required=True,
                        help='path of input directory containing XML files')
    return parser


def insert_line_numbers(txt: str) -> str:
    """Prefix every line of *txt* with a zero-padded, 1-based line number."""
    return "\n".join(
        f"{n + 1:03d} {line}" for n, line in enumerate(txt.split("\n"))
    )


def local_clean(s: str) -> str:
    """Delete the control characters U+000B and U+001A..U+001F from *s*."""
    return _CONTROL_CHARS.sub('', s)


def partial_unescape(s: str) -> str:
    """Unescape HTML entities in *s* except those XML needs kept escaped.

    The listed spellings of ``<``, ``>`` and ``&`` are left as entities so
    the result can still be parsed as XML; everything else goes through
    ``html.unescape``.  A bare ``&amp`` (no semicolon) is completed to
    ``&amp;`` because XML requires the terminating semicolon.

    Other spellings that decode to markup characters (for example
    upper-case hex digits) are not protected.
    """
    parts = _PROTECTED_ENTITY.split(s)
    for i, part in enumerate(parts):
        if i % 2 == 0:
            # Even positions hold the text between protected entities.
            parts[i] = html.unescape(part)
        elif not part.endswith(';'):
            # Odd positions hold a protected entity; only "&amp" can lack
            # its semicolon.
            parts[i] = part + ';'
    return "".join(parts)


def article_key(xml_file: str) -> str:
    """Derive the record key from an article path.

    ``<anything>/YYYY/MM/DD/NNNN.aans`` becomes ``YYYYMMDDNNNN``.  A path
    that does not have this shape is returned with its separators
    normalised to ``/`` (unchanged where ``os.sep`` is already ``/``).
    """
    # Normalise separators so the pattern also matches where os.sep != "/".
    normalised = xml_file.replace(os.sep, '/')
    return _KEY_PATTERN.sub(r'\g<1>\g<2>\g<3>\g<4>', normalised)


def read_article(path: str) -> Tuple[str, str]:
    """Read *path* as text using the first encoding that decodes it.

    Returns:
        ``(text, counter_name)`` where *counter_name* is the statistics
        counter associated with the encoding that succeeded.

    Raises:
        OSError: the file could not be opened or read.
        UnicodeDecodeError: no encoding in ``ENCODINGS`` could decode it.
    """
    last_error = None
    for encoding, counter in ENCODINGS:
        try:
            with open(path, 'r', encoding=encoding) as handle:
                return handle.read(), counter
        except UnicodeDecodeError as exc:
            last_error = exc
    raise last_error


def parse_article(content: str, key: str) -> Optional[Dict[str, str]]:
    """Build one dataset record from the XML text of an article.

    Args:
        content: XML document text.
        key: value stored under ``"key"`` in the record.

    Returns:
        A dict with ``key``, ``categories`` (category texts joined with
        ``;``) and ``content`` (texts of the children of ``<body>`` joined
        with newlines), or ``None`` when the article has fewer than two
        categories.

    Raises:
        xml.etree.ElementTree.ParseError: *content* is not well-formed XML.
        ValueError: the document has no ``<body>`` child.
    """
    doc = ET.fromstring(content)
    # An empty <category/> has text None; joining it would raise TypeError.
    cats = [cat.text for cat in doc.findall('category')
            if cat.text is not None]
    body = doc.find('./body')
    if body is None:
        raise ValueError("missing <body> element")
    paragraphs = [p.text for p in body if p.text is not None]
    # NOTE(review): only articles with MORE than one category are kept;
    # confirm that single-category articles are meant to be dropped.
    if len(cats) <= 1:
        return None
    return {
        "key": key,
        "categories": ";".join(cats),
        "content": "\n".join(paragraphs),
    }


def collect_articles(input_dir: str) -> Tuple[List[Dict[str, str]], Dict[str, int]]:
    """Walk *input_dir* and parse every ``*.aans`` file found.

    Directories and files are visited in sorted order and each directory
    path is printed as it is entered.  Files that cannot be read are
    reported, counted as ``failed`` and skipped.  The first XML parse
    error prints the numbered document and terminates the program with
    exit status 1.

    Returns:
        ``(articles, stats)``: the list of records, and a dict with the
        counters ``total``, ``plain``, ``utf8``, ``iso88591``, ``failed``.
    """
    articles = []
    stats = dict.fromkeys(("total", "plain", "utf8", "iso88591", "failed"), 0)
    for root, dirs, files in os.walk(input_dir):
        dirs.sort()  # in-place sort makes os.walk descend in sorted order
        print(root)
        for name in sorted(files):
            if not name.endswith('.aans'):
                continue
            xml_file = os.path.join(root, name)
            stats["total"] += 1
            try:
                content, counter = read_article(xml_file)
            except (OSError, UnicodeError) as e:
                print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {e}")
                stats["failed"] += 1
                # Nothing was read for this file; skip it.
                continue
            stats[counter] += 1
            content = local_clean(partial_unescape(content))
            try:
                entry = parse_article(content, article_key(xml_file))
            except ET.ParseError as e:
                print(insert_line_numbers(content))
                print("Parse error in " + xml_file + " : ", e)
                sys.exit(1)
            except ValueError as e:
                print(f"{xml_file} : {e}")
                continue
            if entry is not None:
                articles.append(entry)
    return articles, stats


def normalise_text(text: str) -> str:
    """Lower-case *text* and remove punctuation, digits and extra spaces.

    Leading and trailing whitespace is stripped and runs of spaces are
    collapsed to a single space.
    """
    text = text.lower().translate(_STRIP_TABLE).strip()
    return re.sub(" +", " ", text)


def main(argv: Optional[List[str]] = None) -> None:
    """Run the conversion; *argv* defaults to the process arguments."""
    args = build_parser().parse_args(argv)

    if not os.path.isdir(args.input):
        print(f"{args.input} is not a directory or does not exist")
        sys.exit(1)

    # Imported here so the helpers above can be imported without
    # pandarallel, yet before the directory walk so that a missing
    # dependency is reported before any files are processed.
    from pandarallel import pandarallel

    articles, stats = collect_articles(args.input)
    print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(
        stats["total"], stats["plain"], stats["utf8"],
        stats["iso88591"], stats["failed"]))

    # Explicit columns keep set_index() working when no article matched.
    data = pd.DataFrame(articles, columns=["key", "categories", "content"])
    data.set_index("key", inplace=True)

    if not data.empty:
        pandarallel.initialize()
        data['content'] = data.content.parallel_apply(normalise_text)

    # Pass the path rather than an open text handle so that pandas opens
    # the file itself.
    data.to_csv(args.output, encoding="utf-8", quoting=csv.QUOTE_ALL)


if __name__ == "__main__":
    main()