Format for poetry and add debugging
This commit is contained in:
parent
2039b017eb
commit
46f533746e
@ -1,4 +1,10 @@
|
|||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
|
'''
|
||||||
|
1. Load XML file
|
||||||
|
2. Create structure
|
||||||
|
3. Preprocess the data to remove punctuations, digits, spaces and making the text lower.
|
||||||
|
This helps reduce the vocab of the data (as now, "Cat ~" is "cat")
|
||||||
|
'''
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
@ -8,31 +14,16 @@ import re
|
|||||||
import string
|
import string
|
||||||
from string import digits
|
from string import digits
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import html
|
import html
|
||||||
from xml.etree import ElementTree as ET
|
from xml.etree import ElementTree as ET
|
||||||
|
|
||||||
# data manupulation libs
|
# data manupulation libs
|
||||||
import csv
|
import csv
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pandarallel import pandarallel
|
from pandarallel import pandarallel
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
def write_csv(data, output):
|
||||||
description='Turn XML data files into a dataset for use with pytorch',
|
with open(output, 'w', encoding="utf-8") as f:
|
||||||
add_help=True,
|
data.to_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
|
||||||
)
|
|
||||||
parser.add_argument('--output', '-o', required=True, help='path of output CSV file')
|
|
||||||
parser.add_argument('--input', '-i', required=True, help='path of input directory containing XML files')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if os.path.isdir(args.input) is False:
|
|
||||||
print(f"{args.input} is not a directory or does not exist");
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
#1. Load XML file
|
|
||||||
#2. Create structure
|
|
||||||
#3. Preprocess the data to remove punctuations, digits, spaces and making the text lower.
|
|
||||||
#. This helps reduce the vocab of the data (as now, "Cat ~" is "cat")
|
|
||||||
|
|
||||||
def insert_line_numbers(txt):
|
def insert_line_numbers(txt):
|
||||||
return "\n".join([f"{n+1:03d} {line}" for n, line in enumerate(txt.split("\n"))])
|
return "\n".join([f"{n+1:03d} {line}" for n, line in enumerate(txt.split("\n"))])
|
||||||
@ -48,80 +39,100 @@ def partial_unescape(s):
|
|||||||
parts[i] = html.unescape(parts[i])
|
parts[i] = html.unescape(parts[i])
|
||||||
return "".join(parts)
|
return "".join(parts)
|
||||||
|
|
||||||
|
def parse_and_extract(input_dir, verbose):
|
||||||
articles = list()
|
articles = list()
|
||||||
#allCats = list()
|
|
||||||
|
|
||||||
total, plain, utf8, iso88591, failed = 0, 0, 0, 0, 0
|
total, plain, utf8, iso88591, failed = 0, 0, 0, 0, 0
|
||||||
for root, dirs, files in os.walk(args.input):
|
for root, dirs, files in os.walk(input_dir):
|
||||||
dirs.sort()
|
dirs.sort()
|
||||||
|
if verbose > 0:
|
||||||
print(root)
|
print(root)
|
||||||
for file in sorted(files):
|
for file in sorted(files):
|
||||||
#if re.search('2022\/10\/09', root) and re.search('0028.aans$', file):
|
#if re.search('2022\/10\/09', root) and re.search('0028.aans$', file):
|
||||||
if re.search('.aans$', file):
|
if re.search('.aans$', file):
|
||||||
xml_file = os.path.join(root, file)
|
xml_file = os.path.join(root, file)
|
||||||
total += 1
|
total += 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(xml_file, 'r', encoding="ASCII") as f:
|
with open(xml_file, 'r', encoding="ASCII") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
#print(f"ASCII read succeeded in {xml_file}")
|
if verbose > 1:
|
||||||
|
print(f"ASCII read succeeded in {xml_file}")
|
||||||
plain += 1
|
plain += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
#print(f"ASCII read failed, trying UTF-8 in {xml_file} : {e}")
|
if verbose > 1:
|
||||||
|
print(f"ASCII read failed, trying UTF-8 in {xml_file} : {e}")
|
||||||
try:
|
try:
|
||||||
with open(xml_file, 'r', encoding="UTF-8") as f:
|
with open(xml_file, 'r', encoding="UTF-8") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
#print(f"UTF-8 read succeeded in {xml_file}")
|
if verbose > 1:
|
||||||
|
print(f"UTF-8 read succeeded in {xml_file}")
|
||||||
utf8 += 1
|
utf8 += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
#print(f"UTF-8 read failed, trying ISO-8859-1 in {xml_file} : {e}")
|
if verbose > 1:
|
||||||
|
print(f"UTF-8 read failed, trying ISO-8859-1 in {xml_file} : {e}")
|
||||||
try:
|
try:
|
||||||
with open(xml_file, 'r', encoding="ISO-8859-1") as f:
|
with open(xml_file, 'r', encoding="ISO-8859-1") as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
#print(f"ISO-8859-1 read succeeded in {xml_file}")
|
if verbose > 1:
|
||||||
|
print(f"ISO-8859-1 read succeeded in {xml_file}")
|
||||||
iso88591 += 1
|
iso88591 += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {e}")
|
print(f"UTF-8 and ISO-8859-1 read failed in {xml_file} : {e}")
|
||||||
|
if verbose > 2:
|
||||||
print(content)
|
print(content)
|
||||||
failed += 1
|
failed += 1
|
||||||
content = partial_unescape(content)
|
content = partial_unescape(content)
|
||||||
content = local_clean(content)
|
content = local_clean(content)
|
||||||
#print(content)
|
if verbose > 3:
|
||||||
|
print(content)
|
||||||
|
|
||||||
key = re.sub('^.*\/(\d{4})\/(\d{2})\/(\d{2})\/(\d{4}).aans$', '\g<1>\g<2>\g<3>\g<4>', xml_file)
|
key = re.sub('^.*\/(\d{4})\/(\d{2})\/(\d{2})\/(\d{4}).aans$', '\g<1>\g<2>\g<3>\g<4>', xml_file)
|
||||||
try:
|
try:
|
||||||
doc = ET.fromstring(content)
|
doc = ET.fromstring(content)
|
||||||
|
|
||||||
entry = dict()
|
entry = dict()
|
||||||
entry["key"] = key
|
entry["key"] = key
|
||||||
|
|
||||||
cats = list()
|
cats = list()
|
||||||
for cat in doc.findall('category'):
|
for cat in doc.findall('./category'):
|
||||||
#if cat not in allCats:
|
|
||||||
# allCats.append(cat)
|
|
||||||
cats.append(cat.text)
|
cats.append(cat.text)
|
||||||
#entry["categories"] = cats
|
#entry["categories"] = cats # if you want a list
|
||||||
entry["categories"] = ";".join(cats)
|
entry["categories"] = ";".join(cats) # if you want a string
|
||||||
|
|
||||||
text = list()
|
text = list()
|
||||||
|
lang = ""
|
||||||
try:
|
try:
|
||||||
#text = "\n".join([p.text for p in doc.find('./body')])
|
|
||||||
for p in doc.find('./body'):
|
for p in doc.find('./body'):
|
||||||
if p.text is not None:
|
if p.text is not None:
|
||||||
text.append(p.text)
|
text.append(p.text)
|
||||||
if text is not None and len(cats) > 1:
|
lang = doc.find('./language').text
|
||||||
entry["content"] = "\n".join(text)
|
|
||||||
articles.append(entry)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{xml_file} : {e}")
|
print(f"{xml_file} : {e}")
|
||||||
|
|
||||||
|
if text is not None and len(cats) > 1:
|
||||||
|
entry["content"] = "\n".join(text)
|
||||||
|
entry["language"] = lang
|
||||||
|
articles.append(entry)
|
||||||
|
|
||||||
except ET.ParseError as e:
|
except ET.ParseError as e:
|
||||||
|
if verbose > 1:
|
||||||
print(insert_line_numbers(content))
|
print(insert_line_numbers(content))
|
||||||
print("Parse error in " + xml_file + " : ", e)
|
print("Parse error in " + xml_file + " : ", e)
|
||||||
raise(SystemExit)
|
raise(SystemExit)
|
||||||
|
|
||||||
|
if verbose > 0:
|
||||||
print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(total, plain, utf8, iso88591, failed))
|
print("total: {: 7d}\nplain: {: 7d}\nutf8: {: 7d}\niso88591: {: 7d}\nfailed: {: 7d}\n".format(total, plain, utf8, iso88591, failed))
|
||||||
|
|
||||||
#sys.exit(0)
|
#sys.exit(0)
|
||||||
|
return articles
|
||||||
|
|
||||||
|
|
||||||
|
def scrub_data(articles, verbose):
|
||||||
data = pd.DataFrame(articles)
|
data = pd.DataFrame(articles)
|
||||||
data.set_index("key", inplace=True)
|
data.set_index("key", inplace=True)
|
||||||
|
|
||||||
|
#if verbose > 2:
|
||||||
# print(data.categories)
|
# print(data.categories)
|
||||||
|
|
||||||
# Initialization
|
# Initialization
|
||||||
@ -142,5 +153,39 @@ data['content'] = data.content.parallel_apply(lambda x: x.translate(remove_digit
|
|||||||
data['content']=data.content.parallel_apply(lambda x: x.strip())
|
data['content']=data.content.parallel_apply(lambda x: x.strip())
|
||||||
data['content']=data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
|
data['content']=data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
|
||||||
|
|
||||||
with open(args.output, 'w', encoding="utf-8") as f:
|
# TODO: lemmas? See spaCy
|
||||||
data.to_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Turn XML data files into a dataset for use with pytorch',
|
||||||
|
add_help=True,
|
||||||
|
)
|
||||||
|
parser.add_argument('--output', '-o',
|
||||||
|
required=True,
|
||||||
|
help='path of output CSV file')
|
||||||
|
parser.add_argument('--input', '-i',
|
||||||
|
required=True,
|
||||||
|
help='path of input directory containing XML files')
|
||||||
|
parser.add_argument('--verbose', '-v',
|
||||||
|
type=int, nargs='?',
|
||||||
|
const=1, # Default value if -v is supplied
|
||||||
|
default=0, # Default value if -v is not supplied
|
||||||
|
help='print debugging')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if os.path.isdir(args.input) is False:
|
||||||
|
print(f"{args.input} is not a directory or does not exist");
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
articles = parse_and_extract(args.input, args.verbose)
|
||||||
|
|
||||||
|
data = scrub_data(articles, args.verbose)
|
||||||
|
|
||||||
|
write_csv(data, args.output)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
@ -6,80 +6,84 @@ import sys
|
|||||||
import pprint
|
import pprint
|
||||||
import re
|
import re
|
||||||
import string
|
import string
|
||||||
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
# data manupulation
|
||||||
#data manupulation libs
|
|
||||||
import csv
|
import csv
|
||||||
import random
|
import random
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import numpy as np
|
import numpy as np
|
||||||
#from pandarallel import pandarallel
|
#from pandarallel import pandarallel
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
# torch
|
||||||
#torch libs
|
|
||||||
import torch
|
import torch
|
||||||
import torchdata.datapipes as dp
|
import torchdata.datapipes as dp
|
||||||
import torchtext.transforms as T
|
import torchtext.transforms as T
|
||||||
from torchtext.vocab import build_vocab_from_iterator
|
from torchtext.vocab import build_vocab_from_iterator
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
story_num = 40 # XXX None for all
|
||||||
description='Classify text data according to categories',
|
|
||||||
add_help=True,
|
|
||||||
)
|
|
||||||
parser.add_argument('action', help='train or classify')
|
|
||||||
parser.add_argument('--input', '-i', required=True, help='path of CSV file containing dataset')
|
|
||||||
parser.add_argument('--output', '-o', help='path to trained model')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if args.action != 'train' and args.action != 'classify':
|
# Hyperparameters
|
||||||
print("ERROR: train or classify data")
|
EPOCHS = 10 # epoch
|
||||||
sys.exit(1)
|
LR = 5 # learning rate
|
||||||
|
BATCH_SIZE = 64 # batch size for training
|
||||||
|
|
||||||
if args.action == 'classify' and s.path.isfile(model_storage) is None:
|
def read_csv(input_csv, rows=None, verbose=0):
|
||||||
print("No model found for classification; running training instead")
|
if verbose > 0:
|
||||||
args.action = 'train'
|
with open(input_csv, 'r', encoding="utf-8") as f:
|
||||||
|
|
||||||
if os.path.isfile(args.input) is False:
|
|
||||||
print(f"{args.input} is not a valid file")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
#with open(args.input, 'r', encoding="utf-8") as f:
|
|
||||||
# data = pd.read_csv(f, encoding="utf-8", quoting=csv.QUOTE_ALL)
|
|
||||||
|
|
||||||
with open(args.input, 'r', encoding="utf-8") as f:
|
|
||||||
data = pd.concat(
|
data = pd.concat(
|
||||||
[chunk for chunk in tqdm(
|
[chunk for chunk in tqdm(
|
||||||
pd.read_csv(f,
|
pd.read_csv(f,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
quoting=csv.QUOTE_ALL,
|
quoting=csv.QUOTE_ALL,
|
||||||
nrows=200, ## XXX
|
nrows=rows,
|
||||||
chunksize=100),
|
chunksize=50,
|
||||||
|
),
|
||||||
desc='Loading data'
|
desc='Loading data'
|
||||||
)])
|
)])
|
||||||
|
else:
|
||||||
|
with open(input_csv, 'r', encoding="utf-8") as f:
|
||||||
|
data = pd.read_csv(f,
|
||||||
|
encoding="utf-8",
|
||||||
|
quoting=csv.QUOTE_ALL,
|
||||||
|
nrows=rows,
|
||||||
|
)
|
||||||
|
|
||||||
data.dropna(axis='index', inplace=True)
|
data.dropna(axis='index', inplace=True)
|
||||||
|
|
||||||
#print(data)
|
#print(data)
|
||||||
#sys.exit(0)
|
#sys.exit(0)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Create Training and Validation sets
|
Create Training and Validation sets
|
||||||
'''
|
'''
|
||||||
|
def split_dataset(data, verbose=0):
|
||||||
# Create a list of ints till len of data
|
# Create a list of ints till len of data
|
||||||
data_idx = list(range(len(data)))
|
data_idx = list(range(len(data)))
|
||||||
np.random.shuffle(data_idx)
|
np.random.shuffle(data_idx)
|
||||||
|
|
||||||
# Get indexes for validation and train
|
# Get indexes for validation and train
|
||||||
split_percent = 0.95
|
split_percent = 0.05
|
||||||
num_train = int(len(data) * split_percent)
|
num_valid = int(len(data) * split_percent)
|
||||||
valid_idx, train_idx = data_idx[num_train:], data_idx[:num_train]
|
#num_tests = int(len(data) * split_percent)
|
||||||
|
#train_idx = data_idx[num_valid:-num_tests]
|
||||||
|
train_idx = data_idx[num_valid:]
|
||||||
|
valid_idx = data_idx[:num_valid]
|
||||||
|
#tests_idx = data_idx[-num_tests:]
|
||||||
|
if verbose > 0:
|
||||||
print("Length of train_data: {}".format(len(train_idx)))
|
print("Length of train_data: {}".format(len(train_idx)))
|
||||||
print("Length of valid_data: {}".format(len(valid_idx)))
|
print("Length of valid_data: {}".format(len(valid_idx)))
|
||||||
|
#print("Length of tests_data: {}".format(len(tests_idx)))
|
||||||
|
|
||||||
# Create the training and validation sets, as dataframes
|
# Create the training and validation sets, as dataframes
|
||||||
train_data = data.iloc[train_idx].reset_index().drop('index', axis=1)
|
train_data = data.iloc[train_idx].reset_index().drop('index', axis=1)
|
||||||
valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
|
valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
|
||||||
|
#tests_data = data.iloc[tests_idx].reset_index().drop('index', axis=1)
|
||||||
|
#return(train_data, valid_data, tests_data)
|
||||||
|
return(train_data, valid_data)
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
@ -88,21 +92,24 @@ valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
|
|||||||
'''
|
'''
|
||||||
class TextCategoriesDataset(Dataset):
|
class TextCategoriesDataset(Dataset):
|
||||||
''' Dataset of Text and Categories '''
|
''' Dataset of Text and Categories '''
|
||||||
def __init__(self, df, text_column, cats_column, transform=None):
|
def __init__(self, df, text_column, cats_column, lang_column, transform=None, verbose=0):
|
||||||
'''
|
'''
|
||||||
Arguments:
|
Arguments:
|
||||||
df (panda.Dataframe): csv content, loaded as dataframe
|
df (panda.Dataframe): csv content, loaded as dataframe
|
||||||
text_column (str): the name of the column containing the text
|
text_column (str): the name of the column containing the text
|
||||||
cats_column (str): the name of the column containing
|
cats_column (str): the name of the column containing
|
||||||
semicolon-separated categories
|
semicolon-separated categories
|
||||||
|
text_column (str): the name of the column containing the language
|
||||||
transform (callable, optional): Optional transform to be
|
transform (callable, optional): Optional transform to be
|
||||||
applied on a sample.
|
applied on a sample.
|
||||||
'''
|
'''
|
||||||
self.df = df
|
self.df = df
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
self.texts = self.df[text_column]
|
self.text = self.df[text_column]
|
||||||
self.cats = self.df[cats_column]
|
self.cats = self.df[cats_column]
|
||||||
|
self.lang = self.df[lang_column]
|
||||||
|
|
||||||
# index-to-token dict
|
# index-to-token dict
|
||||||
# <pad> : padding, used for padding the shorter sentences in a batch
|
# <pad> : padding, used for padding the shorter sentences in a batch
|
||||||
@ -146,8 +153,9 @@ class TextCategoriesDataset(Dataset):
|
|||||||
idx = idx.tolist()
|
idx = idx.tolist()
|
||||||
|
|
||||||
# Get the raw data
|
# Get the raw data
|
||||||
text = self.texts[idx]
|
text = self.text[idx]
|
||||||
cats = self.cats[idx]
|
cats = self.cats[idx]
|
||||||
|
lang = self.lang[idx]
|
||||||
|
|
||||||
if self.transform:
|
if self.transform:
|
||||||
text, cats = self.transform(text, cats)
|
text, cats = self.transform(text, cats)
|
||||||
@ -186,25 +194,6 @@ class TextCategoriesDataset(Dataset):
|
|||||||
T.AddToken(2, begin=False)
|
T.AddToken(2, begin=False)
|
||||||
)
|
)
|
||||||
|
|
||||||
'''
|
|
||||||
dataset = TextCategoriesDataset(df=data,
|
|
||||||
text_column="content",
|
|
||||||
cats_column="categories",
|
|
||||||
)
|
|
||||||
'''
|
|
||||||
train_dataset = TextCategoriesDataset(df=train_data,
|
|
||||||
text_column="content",
|
|
||||||
cats_column="categories",
|
|
||||||
)
|
|
||||||
valid_dataset = TextCategoriesDataset(df=valid_data,
|
|
||||||
text_column="content",
|
|
||||||
cats_column="categories",
|
|
||||||
)
|
|
||||||
#print(dataset[2])
|
|
||||||
#for text, cat in enumerate(valid_dataset):
|
|
||||||
# print(text, cat)
|
|
||||||
#sys.exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Now that we have a dataset, let's create dataloader,
|
Now that we have a dataset, let's create dataloader,
|
||||||
@ -232,10 +221,113 @@ class CollateBatch:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Hyperparameters
|
class TextClassificationModel(nn.Module):
|
||||||
EPOCHS = 10 # epoch
|
def __init__(self, input_size, output_size, verbose):
|
||||||
LR = 5 # learning rate
|
super().__init__()
|
||||||
BATCH_SIZE = 64 # batch size for training
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def train(dataloader):
|
||||||
|
model.train()
|
||||||
|
total_acc, total_count = 0, 0
|
||||||
|
log_interval = 500
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for idx, (label, text) in enumerate(dataloader):
|
||||||
|
optimizer.zero_grad()
|
||||||
|
predicted_label = model(text)
|
||||||
|
loss = criterion(predicted_label, label)
|
||||||
|
loss.backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
|
||||||
|
optimizer.step()
|
||||||
|
total_acc += (predicted_label.argmax(1) == label).sum().item()
|
||||||
|
total_count += label.size(0)
|
||||||
|
if idx % log_interval == 0 and idx > 0:
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(
|
||||||
|
"| epoch {:3d} | {:5d}/{:5d} batches "
|
||||||
|
"| accuracy {:8.3f}".format(
|
||||||
|
epoch, idx, len(dataloader), total_acc / total_count
|
||||||
|
)
|
||||||
|
)
|
||||||
|
total_acc, total_count = 0, 0
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(dataloader):
|
||||||
|
model.eval()
|
||||||
|
total_acc, total_count = 0, 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for idx, (label, text) in enumerate(dataloader):
|
||||||
|
predicted_label = model(text)
|
||||||
|
loss = criterion(predicted_label, label)
|
||||||
|
total_acc += (predicted_label.argmax(1) == label).sum().item()
|
||||||
|
total_count += label.size(0)
|
||||||
|
return total_acc / total_count
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Classify text data according to categories',
|
||||||
|
add_help=True,
|
||||||
|
)
|
||||||
|
parser.add_argument('action',
|
||||||
|
help='train or classify')
|
||||||
|
parser.add_argument('--input', '-i',
|
||||||
|
required=True,
|
||||||
|
help='path of CSV file containing dataset')
|
||||||
|
parser.add_argument('--model', '-m',
|
||||||
|
#required=True, # XXX
|
||||||
|
help='path to training model')
|
||||||
|
parser.add_argument('--verbose', '-v',
|
||||||
|
type=int, nargs='?',
|
||||||
|
const=1, # Default value if -v is supplied
|
||||||
|
default=0, # Default value if -v is not supplied
|
||||||
|
help='print debugging')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.action != 'train' and args.action != 'classify':
|
||||||
|
print("ERROR: train or classify data")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if args.action == 'classify' and s.path.isfile(model_storage) is None:
|
||||||
|
print("No model found for classification; running training instead")
|
||||||
|
args.action = 'train'
|
||||||
|
|
||||||
|
if os.path.isfile(args.input) is False:
|
||||||
|
print(f"{args.input} is not a valid file")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
data = read_csv(input_csv=args.input, rows=story_num, verbose=args.verbose)
|
||||||
|
train_data, valid_data, = split_dataset(data, verbose=args.verbose)
|
||||||
|
|
||||||
|
'''
|
||||||
|
dataset = TextCategoriesDataset(df=data,
|
||||||
|
text_column="content",
|
||||||
|
cats_column="categories",
|
||||||
|
lang_column="language",
|
||||||
|
verbose=args.verbose,
|
||||||
|
)
|
||||||
|
'''
|
||||||
|
train_dataset = TextCategoriesDataset(df=train_data,
|
||||||
|
text_column="content",
|
||||||
|
cats_column="categories",
|
||||||
|
lang_column="language",
|
||||||
|
verbose=args.verbose,
|
||||||
|
)
|
||||||
|
valid_dataset = TextCategoriesDataset(df=valid_data,
|
||||||
|
text_column="content",
|
||||||
|
cats_column="categories",
|
||||||
|
lang_column="language",
|
||||||
|
verbose=args.verbose,
|
||||||
|
)
|
||||||
|
#print(dataset[2])
|
||||||
|
#for text, cat in enumerate(valid_dataset):
|
||||||
|
# print(text, cat)
|
||||||
|
#sys.exit(0)
|
||||||
|
|
||||||
# Get cpu, gpu or mps device for training.
|
# Get cpu, gpu or mps device for training.
|
||||||
# Move tensor to the NVIDIA GPU if available
|
# Move tensor to the NVIDIA GPU if available
|
||||||
@ -247,7 +339,6 @@ device = (
|
|||||||
)
|
)
|
||||||
print(f"Using {device} device")
|
print(f"Using {device} device")
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
dataloader = DataLoader(dataset,
|
dataloader = DataLoader(dataset,
|
||||||
batch_size=4,
|
batch_size=4,
|
||||||
@ -272,4 +363,41 @@ valid_dataloader = DataLoader(valid_dataset,
|
|||||||
# print(i_batch, sample_batched[0], sample_batched[1])
|
# print(i_batch, sample_batched[0], sample_batched[1])
|
||||||
#sys.exit(0)
|
#sys.exit(0)
|
||||||
|
|
||||||
|
num_class = len(set([cats for key, cats, text, lang in train_data.values]))
|
||||||
|
input_size = len(train_dataset.text_vocab)
|
||||||
|
output_size = len(train_dataset.cats_vocab)
|
||||||
|
emsize = 64
|
||||||
|
model = TextClassificationModel(input_size, output_size, args.verbose).to(device)
|
||||||
|
|
||||||
|
|
||||||
|
criterion = torch.nn.CrossEntropyLoss()
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
|
||||||
|
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
|
||||||
|
total_accu = None
|
||||||
|
|
||||||
|
for epoch in range(1, EPOCHS + 1):
|
||||||
|
epoch_start_time = time.time()
|
||||||
|
train(train_dataloader)
|
||||||
|
accu_val = evaluate(valid_dataloader)
|
||||||
|
if total_accu is not None and total_accu > accu_val:
|
||||||
|
scheduler.step()
|
||||||
|
else:
|
||||||
|
total_accu = accu_val
|
||||||
|
print("-" * 59)
|
||||||
|
print(
|
||||||
|
"| end of epoch {:3d} | time: {:5.2f}s | "
|
||||||
|
"valid accuracy {:8.3f} ".format(
|
||||||
|
epoch, time.time() - epoch_start_time, accu_val
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print("-" * 59)
|
||||||
|
|
||||||
|
print("Checking the results of test dataset.")
|
||||||
|
accu_test = evaluate(test_dataloader)
|
||||||
|
print("test accuracy {:8.3f}".format(accu_test))
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user