#!/usr/bin/python

import argparse
import os
import sys
import pprint
import re
import string
from string import digits
import warnings

import html
from xml.etree import ElementTree as ET

# data manipulation libs
import csv
import pandas as pd
from pandarallel import pandarallel

# Command-line interface: both the output CSV path and the input directory
# are mandatory options.
parser = argparse.ArgumentParser(
  add_help=True,
  description='Turn XML data files into a dataset for use with pytorch',
)
parser.add_argument(
  '--output', '-o',
  help='path of output CSV file',
  required=True,
)
parser.add_argument(
  '--input', '-i',
  help='path of input directory containing XML files',
  required=True,
)
args = parser.parse_args()

# Stop right away (exit status 1) when the input path is not an existing directory.
if not os.path.isdir(args.input):
  print(f"{args.input} is not a directory or does not exist")
  sys.exit(1)

#1. Load XML file
#2. Create structure
#3. Preprocess the data: remove punctuation and digits, collapse extra spaces and lowercase the text.
#   This helps reduce the vocabulary of the data (e.g. "Cat ~" becomes "cat")

def insert_line_numbers(txt):
  """Return *txt* with every line prefixed by its 1-based line number.

  The number is zero-padded to at least three digits and followed by a single
  space. Lines are delimited by the line-feed character only, so an empty
  string yields one numbered (empty) line.
  """
  numbered = []
  for number, line in enumerate(txt.split("\n"), start=1):
    numbered.append(f"{number:03d} {line}")
  return "\n".join(numbered)

# Deletion table for local_clean(): vertical tab (0x0b) plus the control
# characters 0x1a through 0x1f, each mapped to None so translate() drops them.
_LOCAL_CLEAN_TABLE = dict.fromkeys([0x0b, *range(0x1a, 0x20)])


def local_clean(s):
  """Return *s* without vertical tabs (0x0b) and control characters 0x1a-0x1f.

  All other characters, including tab, line feed and carriage return, are
  left untouched.
  """
  return s.translate(_LOCAL_CLEAN_TABLE)

def partial_unescape(s):
  """Unescape HTML entities in *s* except the ones that are reserved in XML.

  References to ``<``, ``>`` and ``&`` (named, decimal or lower-case hex form)
  are kept escaped so the result can still be fed to an XML parser; every
  other entity (``&eacute;``, ``&nbsp;``, ...) is replaced by its character
  via ``html.unescape``.

  Args:
    s: text that may contain HTML character references.

  Returns:
    The text with all non-reserved references decoded. A bare ``&amp``
    (no trailing semicolon) is normalised to ``&amp;`` so the protected
    reference stays well-formed XML.
  """
  # The whole reference, including the leading "&", sits inside the capture
  # group: re.split() only returns captured text, so leaving the "&" outside
  # the group would silently drop it and turn "&lt;" into "lt;".
  parts = re.split(r'(&(?:lt;|#60;|#x3c;|gt;|#62;|#x3e;|amp;?|#38;|#x26;))', s)
  for i, part in enumerate(parts):
    if i % 2 == 0:
      # Even slots hold the text between protected references.
      parts[i] = html.unescape(part)
    elif part == '&amp':
      # Odd slots hold protected references; complete the semicolon-less form.
      parts[i] = '&amp;'
  return "".join(parts)

# Encodings tried in order for every input file, paired with the statistics
# bucket a successful read is counted under. ASCII is tried first only so
# that pure-ASCII files are reported separately from real UTF-8 files.
_ENCODINGS = (("plain", "ASCII"), ("utf8", "UTF-8"), ("iso88591", "ISO-8859-1"))

# Path layout <...>/YYYY/MM/DD/NNNN.aans is turned into the key "YYYYMMDDNNNN".
_KEY_PATTERN = re.compile(r'^.*/(\d{4})/(\d{2})/(\d{2})/(\d{4})\.aans$')


def _read_with_fallback(path):
  """Read *path* as text, trying each encoding in ``_ENCODINGS`` in turn.

  Returns:
    ``(bucket, content, None)`` for the first encoding that works, or
    ``(None, None, error)`` with the last exception when every attempt failed.
  """
  last_error = None
  for bucket, encoding in _ENCODINGS:
    try:
      with open(path, 'r', encoding=encoding) as f:
        return bucket, f.read(), None
    except (UnicodeError, OSError) as e:
      last_error = e
  return None, None, last_error


# One dict per accepted article, with the keys "key", "categories", "content".
articles = list()

total, failed = 0, 0
read_counts = {bucket: 0 for bucket, _ in _ENCODINGS}
for root, dirs, files in os.walk(args.input):
  # Sorting in place makes os.walk descend into sub-directories in a stable order.
  dirs.sort()
  print(root)
  for file in sorted(files):
    if not file.endswith('.aans'):
      continue
    xml_file = os.path.join(root, file)
    total += 1

    bucket, content, error = _read_with_fallback(xml_file)
    if content is None:
      # Nothing was read, so there is nothing to print or parse: count the
      # failure and move on instead of re-using the previous file's content.
      print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {error}")
      failed += 1
      continue
    read_counts[bucket] += 1

    content = local_clean(partial_unescape(content))

    # A path that does not follow the expected layout is left unchanged and
    # used as the key as-is.
    key = _KEY_PATTERN.sub(r'\g<1>\g<2>\g<3>\g<4>', xml_file)
    try:
      doc = ET.fromstring(content)
    except ET.ParseError as e:
      print(insert_line_numbers(content))
      print("Parse error in " + xml_file + " : ", e)
      # A malformed file aborts the whole run; report it with a failing status.
      sys.exit(1)

    # Empty <category/> elements have no text and are ignored.
    cats = [cat.text for cat in doc.findall('category') if cat.text is not None]

    body = doc.find('./body')
    if body is None:
      print(f"{xml_file} : no <body> element found")
      continue
    text = [p.text for p in body if p.text is not None]

    # NOTE(review): only articles with at least two categories are kept, as in
    # the original condition -- confirm that single-category articles are
    # meant to be dropped.
    if len(cats) > 1:
      articles.append({
        "key": key,
        "categories": ";".join(cats),
        "content": "\n".join(text),
      })

plain, utf8, iso88591 = (read_counts[bucket] for bucket, _ in _ENCODINGS)
print("total:    {: 7d}\nplain:    {: 7d}\nutf8:     {: 7d}\niso88591: {: 7d}\nfailed:   {: 7d}\n".format(total, plain, utf8, iso88591, failed))

#sys.exit(0)

# Translation table that deletes every ASCII punctuation character and digit.
_STRIP_TABLE = str.maketrans('', '', string.punctuation + digits)


def _normalize_text(text):
  """Lowercase *text*, drop punctuation and digits, trim and collapse spaces."""
  text = text.lower().translate(_STRIP_TABLE).strip()
  return re.sub(" +", " ", text)


# Naming the columns explicitly keeps set_index() working (and the CSV header
# intact) even when no article was collected.
data = pd.DataFrame(articles, columns=["key", "categories", "content"])
data.set_index("key", inplace=True)

if not data.empty:
  # Initialization
  pandarallel.initialize()

  # All cleaning steps run in a single parallel pass over the content column.
  data['content'] = data.content.parallel_apply(_normalize_text)

# Handing the path to pandas lets it open the file itself with the newline
# handling the csv writer needs (no doubled line endings on Windows).
data.to_csv(args.output, encoding="utf-8", quoting=csv.QUOTE_ALL)