Convert to a multi-hot index in the CSV, to simplify our DataSets and DataLoaders
This commit is contained in:
parent
bedf82d8a1
commit
910e0c9d24
@ -96,11 +96,7 @@ def parse_and_extract(input_dir, verbose):
|
|||||||
|
|
||||||
cats = list()
|
cats = list()
|
||||||
for cat in doc.findall('./category'):
|
for cat in doc.findall('./category'):
|
||||||
# TODO check against a list of current categories,
|
|
||||||
# and strip any non-current categories
|
|
||||||
cats.append(cat.text)
|
cats.append(cat.text)
|
||||||
#entry["categories"] = cats # if you want a list
|
|
||||||
entry["categories"] = ";".join(cats) # if you want a string
|
|
||||||
|
|
||||||
text = list()
|
text = list()
|
||||||
lang = ""
|
lang = ""
|
||||||
@ -115,10 +111,19 @@ def parse_and_extract(input_dir, verbose):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"{xml_file} : {e}")
|
print(f"{xml_file} : {e}")
|
||||||
|
|
||||||
if text is not None and len(cats) > 1:
|
if text is not None and len(cats) >= 1:
|
||||||
entry["content"] = "\n".join(text)
|
|
||||||
entry["language"] = lang
|
entry["language"] = lang
|
||||||
|
entry["content"] = "\n".join(text)
|
||||||
|
for cat in cats:
|
||||||
|
entry[cat] = 1
|
||||||
articles.append(entry)
|
articles.append(entry)
|
||||||
|
else:
|
||||||
|
if len(cats) < 1:
|
||||||
|
print(f"No article added for key {key} due to lack of categories")
|
||||||
|
elif text is None:
|
||||||
|
print(f"No article added for key {key} due to lack of text")
|
||||||
|
else:
|
||||||
|
print(f"No article added for key {key} due to unknown error")
|
||||||
|
|
||||||
except ET.ParseError as e:
|
except ET.ParseError as e:
|
||||||
if verbose > 1:
|
if verbose > 1:
|
||||||
@ -158,7 +163,10 @@ def scrub_data(articles, verbose):
|
|||||||
data['content'] = data.content.parallel_apply(lambda x: x.strip())
|
data['content'] = data.content.parallel_apply(lambda x: x.strip())
|
||||||
data['content'] = data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
|
data['content'] = data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
|
||||||
|
|
||||||
# TODO: lemmas? See spaCy
|
# Any remaining text processing can be done by training/inference step
|
||||||
|
|
||||||
|
# sort category columns: lowercase first (key, language, content), then title-cased categories
|
||||||
|
data.reindex(columns=sorted(data.columns, key=lambda x: (x.casefold(), x.swapcase())))
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -19,19 +19,19 @@ import tqdm
|
|||||||
import torch
|
import torch
|
||||||
import torchdata.datapipes as dp
|
import torchdata.datapipes as dp
|
||||||
import torchtext.transforms as T
|
import torchtext.transforms as T
|
||||||
|
import torchtext.vocab as vocab
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
from torchtext.vocab import build_vocab_from_iterator
|
|
||||||
|
|
||||||
from models.rnn import RNN
|
xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
|
||||||
|
xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"
|
||||||
|
|
||||||
all_categories = list()
|
|
||||||
# XXX None for all stories
|
# XXX None for all stories
|
||||||
#story_num = 128
|
story_num = 128
|
||||||
#story_num = 256
|
#story_num = 256
|
||||||
#story_num = 512
|
#story_num = 512
|
||||||
#story_num = 1024
|
#story_num = 1024
|
||||||
story_num = 4096
|
#story_num = 4096
|
||||||
#story_num = None
|
#story_num = None
|
||||||
|
|
||||||
def read_csv(input_csv, rows=None, verbose=0):
|
def read_csv(input_csv, rows=None, verbose=0):
|
||||||
@ -42,6 +42,7 @@ def read_csv(input_csv, rows=None, verbose=0):
|
|||||||
pd.read_csv(f,
|
pd.read_csv(f,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
quoting=csv.QUOTE_ALL,
|
quoting=csv.QUOTE_ALL,
|
||||||
|
index_col=0,
|
||||||
nrows=rows,
|
nrows=rows,
|
||||||
chunksize=50,
|
chunksize=50,
|
||||||
),
|
),
|
||||||
@ -52,10 +53,10 @@ def read_csv(input_csv, rows=None, verbose=0):
|
|||||||
data = pd.read_csv(f,
|
data = pd.read_csv(f,
|
||||||
encoding="utf-8",
|
encoding="utf-8",
|
||||||
quoting=csv.QUOTE_ALL,
|
quoting=csv.QUOTE_ALL,
|
||||||
|
index_col=0,
|
||||||
nrows=rows,
|
nrows=rows,
|
||||||
)
|
)
|
||||||
|
|
||||||
data.dropna(axis='index', inplace=True)
|
|
||||||
#print(data)
|
#print(data)
|
||||||
#sys.exit(0)
|
#sys.exit(0)
|
||||||
return data
|
return data
|
||||||
@ -83,9 +84,9 @@ def split_dataset(data, verbose=0):
|
|||||||
#print("Length of tests_data: {}".format(len(tests_idx)))
|
#print("Length of tests_data: {}".format(len(tests_idx)))
|
||||||
|
|
||||||
# Create the training and validation sets, as dataframes
|
# Create the training and validation sets, as dataframes
|
||||||
train_data = data.iloc[train_idx].reset_index().drop('index', axis=1)
|
train_data = data.iloc[train_idx].reset_index()
|
||||||
valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
|
valid_data = data.iloc[valid_idx].reset_index()
|
||||||
#tests_data = data.iloc[tests_idx].reset_index().drop('index', axis=1)
|
#tests_data = data.iloc[tests_idx].reset_index()
|
||||||
#return(train_data, valid_data, tests_data)
|
#return(train_data, valid_data, tests_data)
|
||||||
return(train_data, valid_data)
|
return(train_data, valid_data)
|
||||||
|
|
||||||
@ -96,24 +97,25 @@ def split_dataset(data, verbose=0):
|
|||||||
'''
|
'''
|
||||||
class TextCategoriesDataset(Dataset):
|
class TextCategoriesDataset(Dataset):
|
||||||
''' Dataset of Text and Categories '''
|
''' Dataset of Text and Categories '''
|
||||||
def __init__(self, df, text_column, cats_column, lang_column, transform=None, verbose=0):
|
def __init__(self, df, lang_column, text_column, first_cats_column=0, transform=None, verbose=0):
|
||||||
'''
|
'''
|
||||||
Arguments:
|
Arguments:
|
||||||
df (panda.Dataframe): csv content, loaded as dataframe
|
df (panda.Dataframe): csv content, loaded as dataframe
|
||||||
|
lang_column (str): the name of the column containing the language
|
||||||
text_column (str): the name of the column containing the text
|
text_column (str): the name of the column containing the text
|
||||||
cats_column (str): the name of the column containing
|
first_cats_column (int): the index of the first column containing
|
||||||
semicolon-separated categories
|
a category
|
||||||
text_column (str): the name of the column containing the language
|
transform (callable, optional): Optional transform to be applied
|
||||||
transform (callable, optional): Optional transform to be
|
on a sample.
|
||||||
applied on a sample.
|
|
||||||
'''
|
'''
|
||||||
self.df = df
|
self.df = df
|
||||||
self.transform = transform
|
self.transform = transform
|
||||||
self.verbose = verbose
|
self.verbose = verbose
|
||||||
|
|
||||||
self.text = self.df[text_column]
|
|
||||||
self.cats = self.df[cats_column]
|
|
||||||
self.lang = self.df[lang_column]
|
self.lang = self.df[lang_column]
|
||||||
|
self.text = self.df[text_column]
|
||||||
|
self.cats = self.df.iloc[:, first_cats_column:].sort_index(axis="columns")
|
||||||
|
self.cats_vocab = self.cats.columns
|
||||||
|
|
||||||
# index-to-token dict
|
# index-to-token dict
|
||||||
# <pad> : padding, used for padding the shorter sentences in a batch
|
# <pad> : padding, used for padding the shorter sentences in a batch
|
||||||
@ -126,26 +128,6 @@ class TextCategoriesDataset(Dataset):
|
|||||||
# token-to-index dict
|
# token-to-index dict
|
||||||
self.stoi = {k:j for j, k in self.itos.items()}
|
self.stoi = {k:j for j, k in self.itos.items()}
|
||||||
|
|
||||||
# Create vocabularies upon initialisation
|
|
||||||
self.text_vocab = build_vocab_from_iterator(
|
|
||||||
[self.textTokens(text) for i, text in self.df[text_column].items()],
|
|
||||||
min_freq=2,
|
|
||||||
specials=self.itos.values(),
|
|
||||||
special_first=True
|
|
||||||
)
|
|
||||||
self.text_vocab.set_default_index(self.text_vocab['<unk>'])
|
|
||||||
#print(self.text_vocab.get_itos())
|
|
||||||
|
|
||||||
self.cats_vocab = build_vocab_from_iterator(
|
|
||||||
#[self.catTokens(cats) for i, cats in self.df[cats_column].items()],
|
|
||||||
[self.catTokens(all_categories)],
|
|
||||||
min_freq=1,
|
|
||||||
specials=['<unk>'],
|
|
||||||
special_first=True
|
|
||||||
)
|
|
||||||
self.cats_vocab.set_default_index(self.cats_vocab['<unk>'])
|
|
||||||
#print(self.cats_vocab.get_itos())
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.df)
|
return len(self.df)
|
||||||
|
|
||||||
@ -158,58 +140,41 @@ class TextCategoriesDataset(Dataset):
|
|||||||
idx = idx.tolist()
|
idx = idx.tolist()
|
||||||
|
|
||||||
# Get the raw data
|
# Get the raw data
|
||||||
text = self.text[idx]
|
|
||||||
cats = self.cats[idx]
|
|
||||||
lang = self.lang[idx]
|
lang = self.lang[idx]
|
||||||
|
text = self.text[idx]
|
||||||
|
cats = self.cats.iloc[idx]
|
||||||
|
|
||||||
|
#print(self.textTransform()(text))
|
||||||
|
#print(cats)
|
||||||
|
#print(cats.fillna(0).values)
|
||||||
|
|
||||||
if self.transform:
|
if self.transform:
|
||||||
text, cats = self.transform(text, cats)
|
text, cats = self.transform(text, cats)
|
||||||
|
|
||||||
#print(cats)
|
# Numericalise text by applying transforms, and cats by converting
|
||||||
#print(self.catTokens(cats))
|
# NaN to zeros and stripping the index
|
||||||
#print(self.getTransform(self.cats_vocab, "cats")(self.catTokens(cats)))
|
|
||||||
|
|
||||||
# Numericalise by applying transforms
|
|
||||||
return (
|
return (
|
||||||
self.getTransform(self.text_vocab, "text")(self.textTokens(text)),
|
self.textTransform()(text),
|
||||||
self.getTransform(self.cats_vocab, "cats")(self.catTokens(cats)),
|
cats.fillna(0).values,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
def textTransform(self):
|
||||||
def textTokens(text):
|
|
||||||
if isinstance(text, str):
|
|
||||||
return [word for word in text.split()]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def catTokens(cats):
|
|
||||||
if isinstance(cats, str):
|
|
||||||
return [cat for cat in cats.split(';')]
|
|
||||||
elif isinstance(cats, list):
|
|
||||||
return [cat for cat in cats]
|
|
||||||
|
|
||||||
def getTransform(self, vocab, vType):
|
|
||||||
'''
|
'''
|
||||||
Create transforms based on given vocabulary. The returned transform
|
Create transforms based on given vocabulary. The returned transform
|
||||||
is applied to a sequence of tokens.
|
is applied to a sequence of tokens.
|
||||||
'''
|
'''
|
||||||
if vType == "text":
|
return T.Sequential(
|
||||||
return T.Sequential(
|
# converts the sentences to indices based on given vocabulary using SentencePiece
|
||||||
# converts the sentences to indices based on given vocabulary
|
T.SentencePieceTokenizer(xlmr_spm_model_path),
|
||||||
T.VocabTransform(vocab=vocab),
|
T.VocabTransform(torch.hub.load_state_dict_from_url(xlmr_vocab_path)),
|
||||||
# Add <sos> at beginning of each sentence. 1 because the index
|
# Add <sos> at beginning of each sentence. 1 because the index
|
||||||
# for <sos> in vocabulary is 1 as seen in previous section
|
# for <sos> in vocabulary is 1 as seen in previous section
|
||||||
T.AddToken(self.text_vocab['<sos>'], begin=True),
|
T.AddToken(self.stoi['<sos>'], begin=True),
|
||||||
# Add <eos> at end of each sentence. 2 because the index
|
# Add <eos> at end of each sentence. 2 because the index
|
||||||
# for <eos> in vocabulary is 2 as seen in previous section
|
# for <eos> in vocabulary is 2 as seen in previous section
|
||||||
T.AddToken(self.text_vocab['<eos>'], begin=False)
|
T.AddToken(self.stoi['<eos>'], begin=False)
|
||||||
)
|
)
|
||||||
elif vType == "cats":
|
|
||||||
return T.Sequential(
|
|
||||||
# converts the sentences to indices based on given vocabulary
|
|
||||||
T.VocabTransform(vocab=vocab),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise Exception('wrong transformation type')
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
@ -223,12 +188,11 @@ class CollateBatch:
|
|||||||
in a batch of equal length. We can do this a collate_fn callback class,
|
in a batch of equal length. We can do this a collate_fn callback class,
|
||||||
which returns a tensor
|
which returns a tensor
|
||||||
'''
|
'''
|
||||||
def __init__(self, pad_idx, cats):
|
def __init__(self, pad_idx):
|
||||||
'''
|
'''
|
||||||
pad_idx (int): the index of the "<pad>" token in the vocabulary.
|
pad_idx (int): the index of the "<pad>" token in the vocabulary.
|
||||||
'''
|
'''
|
||||||
self.pad_idx = pad_idx
|
self.pad_idx = pad_idx
|
||||||
self.cats = cats
|
|
||||||
|
|
||||||
def __call__(self, batch):
|
def __call__(self, batch):
|
||||||
'''
|
'''
|
||||||
@ -236,13 +200,6 @@ class CollateBatch:
|
|||||||
is a list of tokens
|
is a list of tokens
|
||||||
'''
|
'''
|
||||||
batch_text, batch_cats = zip(*batch)
|
batch_text, batch_cats = zip(*batch)
|
||||||
#for i in range(len(batch)):
|
|
||||||
# print(batch[i])
|
|
||||||
#max_text_len = len(max(batch_text, key=len))
|
|
||||||
#max_cats_len = len(max(batch_cats, key=len))
|
|
||||||
|
|
||||||
#text_tensor = T.ToTensor(self.pad_idx)(batch_text)
|
|
||||||
#cats_tensor = T.ToTensor(self.pad_idx)(batch_cats)
|
|
||||||
|
|
||||||
# Pad text to the longest
|
# Pad text to the longest
|
||||||
text_tensor = nn.utils.rnn.pad_sequence(
|
text_tensor = nn.utils.rnn.pad_sequence(
|
||||||
@ -251,44 +208,7 @@ class CollateBatch:
|
|||||||
)
|
)
|
||||||
text_lengths = torch.tensor([t.shape[0] for t in text_tensor])
|
text_lengths = torch.tensor([t.shape[0] for t in text_tensor])
|
||||||
|
|
||||||
#cats_tensor = torch.nn.utils.rnn.pad_sequence(
|
cats_tensor = torch.tensor(batch_cats, dtype=torch.float32)
|
||||||
# [torch.LongTensor(s) for s in batch_cats],
|
|
||||||
# batch_first=True, padding_value=self.pad_idx
|
|
||||||
#)
|
|
||||||
#cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
|
|
||||||
|
|
||||||
'''
|
|
||||||
# Pad cats_tensor to all possible categories
|
|
||||||
num_cats = len(all_categories)
|
|
||||||
|
|
||||||
# Convert cats to multi-label one-hot representation
|
|
||||||
cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).float()
|
|
||||||
cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
|
|
||||||
for idx, cats in enumerate(batch_cats):
|
|
||||||
#print("\nsample", idx, cats)
|
|
||||||
for c in cats:
|
|
||||||
#print(c)
|
|
||||||
cats_tensor[idx][c] = 1
|
|
||||||
#print(cats_tensor[idx])
|
|
||||||
'''
|
|
||||||
# Convert cats to multi-label one-hot representation
|
|
||||||
# add one to all_categories to account for <unk>
|
|
||||||
cats_tensor = torch.full((len(batch_cats), len(all_categories)+1), self.pad_idx).float()
|
|
||||||
for idx, cats in enumerate(batch_cats):
|
|
||||||
#print("\nsample", idx, cats)
|
|
||||||
for c in cats:
|
|
||||||
cats_tensor[idx][c] = 1
|
|
||||||
#print(cats_tensor[idx])
|
|
||||||
#sys.exit(0)
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
# XXX why??
|
|
||||||
## SORT YOUR TENSORS BY LENGTH!
|
|
||||||
text_lengths, perm_idx = text_lengths.sort(0, descending=True)
|
|
||||||
text_tensor = text_tensor[perm_idx]
|
|
||||||
cats_tensor = cats_tensor[perm_idx]
|
|
||||||
'''
|
|
||||||
|
|
||||||
#print("text", text_tensor)
|
#print("text", text_tensor)
|
||||||
#print("text shape:", text_tensor.shape)
|
#print("text shape:", text_tensor.shape)
|
||||||
@ -296,7 +216,6 @@ class CollateBatch:
|
|||||||
#print("cats shape:", cats_tensor.shape)
|
#print("cats shape:", cats_tensor.shape)
|
||||||
#print(text_lengths)
|
#print(text_lengths)
|
||||||
#print("text_lengths shape:", text_lengths.shape)
|
#print("text_lengths shape:", text_lengths.shape)
|
||||||
|
|
||||||
#sys.exit(0)
|
#sys.exit(0)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -305,48 +224,27 @@ class CollateBatch:
|
|||||||
text_lengths,
|
text_lengths,
|
||||||
)
|
)
|
||||||
|
|
||||||
def cat2tensor(label_vocab, labels, pad_idx: int):
|
def tensor2cat(dataset, tensor):
|
||||||
all_labels = vocab.get_itos()
|
cats = dataset.cats_vocab
|
||||||
num_labels = len(all_labels)
|
|
||||||
# add <unk>
|
|
||||||
if 0 not in all_labels:
|
|
||||||
num_labels += 1
|
|
||||||
|
|
||||||
labels_tensor = torch.full((len(labels), num_labels), pad_idx).float()
|
|
||||||
labels_lengths = torch.LongTensor(list(map(len, labels)))
|
|
||||||
for idx, labels in enumerate(labels):
|
|
||||||
#print("\nsample", idx, labels)
|
|
||||||
for l in labels:
|
|
||||||
labels_tensor[idx][l] = 1
|
|
||||||
#print(labels_tensor[idx])
|
|
||||||
return labels_tensor
|
|
||||||
|
|
||||||
def tensor2cat(vocab, tensor):
|
|
||||||
all_cats = vocab.get_itos()
|
|
||||||
if tensor.ndimension() == 2:
|
if tensor.ndimension() == 2:
|
||||||
batch = list()
|
batch = list()
|
||||||
for result in tensor:
|
for result in tensor:
|
||||||
chance = dict()
|
chance = dict()
|
||||||
for idx, pred in enumerate(result):
|
for idx, pred in enumerate(result):
|
||||||
if pred > 0: # XXX
|
if pred > 0: # XXX
|
||||||
chance[all_cats[idx]] = pred.item()
|
chance[cats[idx]] = pred.item()
|
||||||
#print(chance)
|
|
||||||
batch.append(chance)
|
batch.append(chance)
|
||||||
return batch
|
return batch
|
||||||
elif tensor.ndimension() == 1:
|
elif tensor.ndimension() == 1:
|
||||||
chance = dict()
|
chance = dict()
|
||||||
for idx, pred in enumerate(tensor):
|
for idx, pred in enumerate(tensor):
|
||||||
if idx >= len(all_cats):
|
if idx >= len(cats):
|
||||||
print(f"Idx {idx} not in {len(all_cats)} categories")
|
print(f"Idx {idx} not in {len(cats)} categories")
|
||||||
#elif pred > 0: # XXX
|
elif pred > 0: # XXX
|
||||||
#print(idx, len(all_cats))
|
chance[cats[idx]] = pred.item()
|
||||||
chance[all_cats[idx]] = pred.item()
|
|
||||||
#print(chance)
|
|
||||||
return chance
|
return chance
|
||||||
else:
|
else:
|
||||||
raise ValueError("Only tensors with 2 dimensions are supported")
|
raise ValueError("Only tensors with 1 dimension or batches with 2 dimensions are supported")
|
||||||
|
|
||||||
return vocab.get_itos(cat)
|
|
||||||
|
|
||||||
|
|
||||||
def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
|
def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
|
||||||
@ -452,6 +350,17 @@ def evaluate(dataloader, dataset, model, criterion, epoch=0):
|
|||||||
})
|
})
|
||||||
return total_acc / total_count
|
return total_acc / total_count
|
||||||
|
|
||||||
|
# TODO seeding:
|
||||||
|
def seed_everything(seed=42):
|
||||||
|
random.seed(seed)
|
||||||
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
# Some cudnn methods can be random even after fixing the seed
|
||||||
|
# unless you tell it to be deterministic
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -487,37 +396,26 @@ def main():
|
|||||||
|
|
||||||
data = read_csv(input_csv=args.input, rows=story_num, verbose=args.verbose)
|
data = read_csv(input_csv=args.input, rows=story_num, verbose=args.verbose)
|
||||||
|
|
||||||
# create list of all categories
|
|
||||||
global all_categories
|
|
||||||
for cats in data.categories:
|
|
||||||
for c in cats.split(";"):
|
|
||||||
if c not in all_categories:
|
|
||||||
all_categories.append(c)
|
|
||||||
all_categories = sorted(all_categories)
|
|
||||||
#print(all_categories)
|
|
||||||
#print(len(all_categories))
|
|
||||||
#sys.exit(0)
|
|
||||||
|
|
||||||
train_data, valid_data, = split_dataset(data, verbose=args.verbose)
|
train_data, valid_data, = split_dataset(data, verbose=args.verbose)
|
||||||
|
|
||||||
'''
|
'''
|
||||||
dataset = TextCategoriesDataset(df=data,
|
dataset = TextCategoriesDataset(df=data,
|
||||||
text_column="content",
|
|
||||||
cats_column="categories",
|
|
||||||
lang_column="language",
|
lang_column="language",
|
||||||
|
text_column="content",
|
||||||
|
first_cats_column=data.columns.get_loc("content")+1,
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
)
|
)
|
||||||
'''
|
'''
|
||||||
train_dataset = TextCategoriesDataset(df=train_data,
|
train_dataset = TextCategoriesDataset(df=train_data,
|
||||||
text_column="content",
|
|
||||||
cats_column="categories",
|
|
||||||
lang_column="language",
|
lang_column="language",
|
||||||
|
text_column="content",
|
||||||
|
first_cats_column=train_data.columns.get_loc("content")+1,
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
)
|
)
|
||||||
valid_dataset = TextCategoriesDataset(df=valid_data,
|
valid_dataset = TextCategoriesDataset(df=valid_data,
|
||||||
text_column="content",
|
|
||||||
cats_column="categories",
|
|
||||||
lang_column="language",
|
lang_column="language",
|
||||||
|
text_column="content",
|
||||||
|
first_cats_column=valid_data.columns.get_loc("content")+1,
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
)
|
)
|
||||||
#for text, cat in enumerate(train_dataset):
|
#for text, cat in enumerate(train_dataset):
|
||||||
@ -525,6 +423,7 @@ def main():
|
|||||||
#print("-" * 20)
|
#print("-" * 20)
|
||||||
#for text, cat in enumerate(valid_dataset):
|
#for text, cat in enumerate(valid_dataset):
|
||||||
# print(text, cat)
|
# print(text, cat)
|
||||||
|
#print(tensor2cat(train_dataset, torch.tensor([0, 0, 0, 1., 0.9])))
|
||||||
#sys.exit(0)
|
#sys.exit(0)
|
||||||
|
|
||||||
# Get cpu, gpu or mps device for training.
|
# Get cpu, gpu or mps device for training.
|
||||||
@ -565,7 +464,7 @@ def main():
|
|||||||
drop_last=True,
|
drop_last=True,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=0,
|
num_workers=0,
|
||||||
collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
|
collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
|
||||||
)
|
)
|
||||||
'''
|
'''
|
||||||
train_dataloader = DataLoader(train_dataset,
|
train_dataloader = DataLoader(train_dataset,
|
||||||
@ -573,20 +472,20 @@ def main():
|
|||||||
drop_last=True,
|
drop_last=True,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=0,
|
num_workers=0,
|
||||||
collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
|
collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
|
||||||
)
|
)
|
||||||
valid_dataloader = DataLoader(valid_dataset,
|
valid_dataloader = DataLoader(valid_dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=0,
|
num_workers=0,
|
||||||
collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
|
collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
|
||||||
)
|
)
|
||||||
#for i_batch, sample_batched in enumerate(dataloader):
|
#for i_batch, sample_batched in enumerate(dataloader):
|
||||||
# print(i_batch, sample_batched[0], sample_batched[1])
|
# print(i_batch, sample_batched[0], sample_batched[1])
|
||||||
#for i_batch, sample_batched in enumerate(train_dataloader):
|
for i_batch, sample_batched in enumerate(train_dataloader):
|
||||||
# print(i_batch, sample_batched[0], sample_batched[1])
|
print(i_batch, sample_batched[0], sample_batched[1])
|
||||||
#sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
input_size = len(train_dataset.text_vocab)
|
input_size = len(train_dataset.text_vocab)
|
||||||
output_size = len(train_dataset.cats_vocab) # every output item is the likelihood of a particular category
|
output_size = len(train_dataset.cats_vocab) # every output item is the likelihood of a particular category
|
||||||
|
Loading…
Reference in New Issue
Block a user