Convert to a multi-hot index in the CSV, to simplify our DataSets and DataLoaders
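
Each article's categories used to be written to the CSV as a single semicolon-joined string; they are now written as one 0/1 column per category, so the Dataset can hand back a ready-made multi-hot row and the DataLoader's collate step only has to stack those rows into a float tensor. A rough sketch of the layout change and of how a row is consumed downstream (the category names and file name here are hypothetical; the real ones come from the feeds' category elements):

  # Before: key,language,content,categories
  #         ab12,en,"Some text...","Politics;Sport"
  # After:  key,language,content,Politics,Sport,Weather
  #         ab12,en,"Some text...",1,1,
  import pandas as pd
  import torch

  df = pd.read_csv("articles.csv", index_col=0)           # hypothetical file name
  cats = df.iloc[0, df.columns.get_loc("content") + 1:]   # every column after "content" is a category
  multi_hot = torch.tensor(cats.fillna(0).values.astype("float32"))
  # e.g. tensor([1., 1., 0.]) -- categories absent from this article were NaN in the CSV and become 0
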
@@ -96,11 +96,7 @@ def parse_and_extract(input_dir, verbose):
 
           cats = list()
           for cat in doc.findall('./category'):
-            # TODO check against a list of current categories,
-            # and strip any non-current categories
             cats.append(cat.text)
-          #entry["categories"] = cats          # if you want a list
-          entry["categories"] = ";".join(cats) # if you want a string
 
           text = list()
           lang = ""
@@ -115,10 +111,19 @@ def parse_and_extract(input_dir, verbose):
           except Exception as e:
             print(f"{xml_file} : {e}")
 
-          if text is not None and len(cats) > 1:
-            entry["content"] = "\n".join(text)
+          if text is not None and len(cats) >= 1:
             entry["language"] = lang
+            entry["content"] = "\n".join(text)
+            for cat in cats:
+              entry[cat] = 1
             articles.append(entry)
+          else:
+            if len(cats) < 1:
+              print(f"No article added for key {key} due to lack of categories")
+            elif text is None:
+              print(f"No article added for key {key} due to lack of text")
+            else:
+              print(f"No article added for key {key} due to unknown error")
 
         except ET.ParseError as e:
           if verbose > 1:
@@ -158,7 +163,10 @@ def scrub_data(articles, verbose):
   data['content'] = data.content.parallel_apply(lambda x: x.strip())
   data['content'] = data.content.parallel_apply(lambda x: re.sub(" +", " ", x))
 
-  # TODO: lemmas? See spaCy
+  # Any remaining text processing can be done by training/inference step
 
+  # sort category columns: lowercase first (key, language, content), then title-cased categories
+  data.reindex(columns=sorted(data.columns, key=lambda x: (x.casefold(), x.swapcase())))
+
   return data
 
@@ -19,19 +19,19 @@ import tqdm
 import torch
 import torchdata.datapipes as dp
 import torchtext.transforms as T
+import torchtext.vocab as vocab
 from torch import nn
 from torch.utils.data import Dataset, DataLoader
-from torchtext.vocab import build_vocab_from_iterator
 
-from models.rnn import RNN
+xlmr_vocab_path = r"https://download.pytorch.org/models/text/xlmr.vocab.pt"
+xlmr_spm_model_path = r"https://download.pytorch.org/models/text/xlmr.sentencepiece.bpe.model"
 
-all_categories = list()
 # XXX None for all stories
-#story_num = 128
+story_num = 128
 #story_num = 256
 #story_num = 512
 #story_num = 1024
-story_num = 4096
+#story_num = 4096
 #story_num = None
 
 def read_csv(input_csv, rows=None, verbose=0):
@@ -42,6 +42,7 @@ def read_csv(input_csv, rows=None, verbose=0):
           pd.read_csv(f,
             encoding="utf-8",
             quoting=csv.QUOTE_ALL,
+            index_col=0,
             nrows=rows,
             chunksize=50,
           ),
@@ -52,10 +53,10 @@ def read_csv(input_csv, rows=None, verbose=0):
       data = pd.read_csv(f,
         encoding="utf-8",
         quoting=csv.QUOTE_ALL,
+        index_col=0,
         nrows=rows,
       )
 
-  data.dropna(axis='index', inplace=True)
   #print(data)
   #sys.exit(0)
   return data
@@ -83,9 +84,9 @@ def split_dataset(data, verbose=0):
     #print("Length of tests_data: {}".format(len(tests_idx)))
 
   # Create the training and validation sets, as dataframes
-  train_data = data.iloc[train_idx].reset_index().drop('index', axis=1)
-  valid_data = data.iloc[valid_idx].reset_index().drop('index', axis=1)
-  #tests_data = data.iloc[tests_idx].reset_index().drop('index', axis=1)
+  train_data = data.iloc[train_idx].reset_index()
+  valid_data = data.iloc[valid_idx].reset_index()
+  #tests_data = data.iloc[tests_idx].reset_index()
   #return(train_data, valid_data, tests_data)
   return(train_data, valid_data)
 
@@ -96,24 +97,25 @@ def split_dataset(data, verbose=0):
 '''
 class TextCategoriesDataset(Dataset):
   ''' Dataset of Text and Categories '''
-  def __init__(self, df, text_column, cats_column, lang_column, transform=None, verbose=0):
+  def __init__(self, df, lang_column, text_column, first_cats_column=0, transform=None, verbose=0):
     '''
     Arguments:
       df (panda.Dataframe): csv content, loaded as dataframe
+      lang_column (str): the name of the column containing the language
       text_column (str): the name of the column containing the text
-      cats_column (str): the name of the column containing
-        semicolon-separated categories
-      text_column (str): the name of the column containing the language
-      transform (callable, optional): Optional transform to be
-        applied on a sample.
+      first_cats_column (int): the index of the first column containing
+        a category
+      transform (callable, optional): Optional transform to be applied
+        on a sample.
     '''
     self.df = df
     self.transform = transform
     self.verbose = verbose
 
-    self.text = self.df[text_column]
-    self.cats = self.df[cats_column]
     self.lang = self.df[lang_column]
+    self.text = self.df[text_column]
+    self.cats = self.df.iloc[:, first_cats_column:].sort_index(axis="columns")
+    self.cats_vocab = self.cats.columns
 
     # index-to-token dict
     # <pad> : padding, used for padding the shorter sentences in a batch
@@ -126,26 +128,6 @@ class TextCategoriesDataset(Dataset):
     # token-to-index dict
     self.stoi = {k:j for j, k in self.itos.items()}
 
-    # Create vocabularies upon initialisation
-    self.text_vocab = build_vocab_from_iterator(
-      [self.textTokens(text) for i, text in self.df[text_column].items()],
-      min_freq=2,
-      specials=self.itos.values(),
-      special_first=True
-    )
-    self.text_vocab.set_default_index(self.text_vocab['<unk>'])
-    #print(self.text_vocab.get_itos())
-
-    self.cats_vocab = build_vocab_from_iterator(
-      #[self.catTokens(cats) for i, cats in self.df[cats_column].items()],
-      [self.catTokens(all_categories)],
-      min_freq=1,
-      specials=['<unk>'],
-      special_first=True
-    )
-    self.cats_vocab.set_default_index(self.cats_vocab['<unk>'])
-    #print(self.cats_vocab.get_itos())
-
   def __len__(self):
     return len(self.df)
 
@@ -158,58 +140,41 @@ class TextCategoriesDataset(Dataset):
       idx = idx.tolist()
 
     # Get the raw data
-    text = self.text[idx]
-    cats = self.cats[idx]
     lang = self.lang[idx]
+    text = self.text[idx]
+    cats = self.cats.iloc[idx]
 
+    #print(self.textTransform()(text))
+    #print(cats)
+    #print(cats.fillna(0).values)
+
     if self.transform:
       text, cats = self.transform(text, cats)
 
-    #print(cats)
-    #print(self.catTokens(cats))
-    #print(self.getTransform(self.cats_vocab, "cats")(self.catTokens(cats)))
-
-    # Numericalise by applying transforms
+    # Numericalise text by applying transforms, and cats by converting
+    # NaN to zeros and stripping the index
     return (
-      self.getTransform(self.text_vocab, "text")(self.textTokens(text)),
-      self.getTransform(self.cats_vocab, "cats")(self.catTokens(cats)),
+      self.textTransform()(text),
+      cats.fillna(0).values,
     )
 
-  @staticmethod
-  def textTokens(text):
-    if isinstance(text, str):
-      return [word for word in text.split()]
-
-  @staticmethod
-  def catTokens(cats):
-    if isinstance(cats, str):
-      return [cat for cat in cats.split(';')]
-    elif isinstance(cats, list):
-      return [cat for cat in cats]
-
-  def getTransform(self, vocab, vType):
+  def textTransform(self):
     '''
     Create transforms based on given vocabulary. The returned transform
    is applied to a sequence of tokens.
     '''
-    if vType == "text":
     return T.Sequential(
-        # converts the sentences to indices based on given vocabulary
-        T.VocabTransform(vocab=vocab),
+      # converts the sentences to indices based on given vocabulary using SentencePiece
+      T.SentencePieceTokenizer(xlmr_spm_model_path),
+      T.VocabTransform(torch.hub.load_state_dict_from_url(xlmr_vocab_path)),
      # Add <sos> at beginning of each sentence. 1 because the index
      # for <sos> in vocabulary is 1 as seen in previous section
-        T.AddToken(self.text_vocab['<sos>'], begin=True),
+      T.AddToken(self.stoi['<sos>'], begin=True),
      # Add <eos> at end of each sentence. 2 because the index
      # for <eos> in vocabulary is 2 as seen in previous section
-        T.AddToken(self.text_vocab['<eos>'], begin=False)
+      T.AddToken(self.stoi['<eos>'], begin=False)
     )
-    elif vType == "cats":
-      return T.Sequential(
-        # converts the sentences to indices based on given vocabulary
-        T.VocabTransform(vocab=vocab),
-      )
-    else:
-      raise Exception('wrong transformation type')
+
 
 
 '''
@@ -223,12 +188,11 @@ class CollateBatch:
   in a batch of equal length. We can do this a collate_fn callback class,
   which returns a tensor
   '''
-  def __init__(self, pad_idx, cats):
+  def __init__(self, pad_idx):
     '''
       pad_idx (int):  the index of the "<pad>" token in the vocabulary.
     '''
     self.pad_idx = pad_idx
-    self.cats = cats
 
   def __call__(self, batch):
     '''
@@ -236,13 +200,6 @@ class CollateBatch:
              is a list of tokens
     '''
     batch_text, batch_cats = zip(*batch)
-    #for i in range(len(batch)):
-    #  print(batch[i])
-    #max_text_len = len(max(batch_text, key=len))
-    #max_cats_len = len(max(batch_cats, key=len))
-
-    #text_tensor = T.ToTensor(self.pad_idx)(batch_text)
-    #cats_tensor = T.ToTensor(self.pad_idx)(batch_cats)
 
     # Pad text to the longest
     text_tensor = nn.utils.rnn.pad_sequence(
@@ -251,44 +208,7 @@ class CollateBatch:
     )
     text_lengths = torch.tensor([t.shape[0] for t in text_tensor])
 
-    #cats_tensor = torch.nn.utils.rnn.pad_sequence(
-    #  [torch.LongTensor(s) for s in batch_cats],
-    #  batch_first=True, padding_value=self.pad_idx
-    #)
-    #cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
-
-    '''
-    # Pad cats_tensor to all possible categories
-    num_cats = len(all_categories)
-
-    # Convert cats to multi-label one-hot representation
-    cats_tensor = torch.full((len(batch_cats), num_cats), self.pad_idx).float()
-    cats_lengths = torch.LongTensor(list(map(len, batch_cats)))
-    for idx, cats in enumerate(batch_cats):
-      #print("\nsample", idx, cats)
-      for c in cats:
-        #print(c)
-        cats_tensor[idx][c] = 1
-      #print(cats_tensor[idx])
-    '''
-    # Convert cats to multi-label one-hot representation
-    # add one to all_categories to account for <unk>
-    cats_tensor = torch.full((len(batch_cats), len(all_categories)+1), self.pad_idx).float()
-    for idx, cats in enumerate(batch_cats):
-      #print("\nsample", idx, cats)
-      for c in cats:
-        cats_tensor[idx][c] = 1
-      #print(cats_tensor[idx])
-    #sys.exit(0)
-
-
-    '''
-    # XXX why??
-    ## SORT YOUR TENSORS BY LENGTH!
-    text_lengths, perm_idx = text_lengths.sort(0, descending=True)
-    text_tensor = text_tensor[perm_idx]
-    cats_tensor = cats_tensor[perm_idx]
-    '''
+    cats_tensor = torch.tensor(batch_cats, dtype=torch.float32)
 
     #print("text", text_tensor)
     #print("text shape:", text_tensor.shape)
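
With every sample now returning an equal-length multi-hot vector (one slot per category column), the collate step above no longer needs the global all_categories list or the manual one-hot loop; stacking the rows is enough. A minimal sketch of what that conversion amounts to, with made-up values:

  import torch

  batch_cats = [
    [1., 0., 1.],  # sample 1: categories 0 and 2
    [0., 1., 0.],  # sample 2: category 1
  ]
  cats_tensor = torch.tensor(batch_cats, dtype=torch.float32)
  # cats_tensor.shape == torch.Size([2, 3])  (batch size x number of category columns)
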
@@ -296,7 +216,6 @@ class CollateBatch:
     #print("cats shape:", cats_tensor.shape)
     #print(text_lengths)
     #print("text_lengths shape:", text_lengths.shape)
-
     #sys.exit(0)
 
     return (
@@ -305,48 +224,27 @@ class CollateBatch:
       text_lengths,
     )
 
-def cat2tensor(label_vocab, labels, pad_idx: int):
-  all_labels = vocab.get_itos()
-  num_labels = len(all_labels)
-  # add <unk>
-  if 0 not in all_labels:
-    num_labels += 1
-
-  labels_tensor = torch.full((len(labels), num_labels), pad_idx).float()
-  labels_lengths = torch.LongTensor(list(map(len, labels)))
-  for idx, labels in enumerate(labels):
-      #print("\nsample", idx, labels)
-      for l in labels:
-        labels_tensor[idx][l] = 1
-      #print(labels_tensor[idx])
-  return labels_tensor
-
-def tensor2cat(vocab, tensor):
-  all_cats = vocab.get_itos()
+def tensor2cat(dataset, tensor):
+  cats = dataset.cats_vocab
   if tensor.ndimension() == 2:
     batch = list()
     for result in tensor:
       chance = dict()
       for idx, pred in enumerate(result):
         if pred > 0: # XXX
-          chance[all_cats[idx]] = pred.item()
-      #print(chance)
+          chance[cats[idx]] = pred.item()
       batch.append(chance)
     return batch
   elif tensor.ndimension() == 1:
     chance = dict()
     for idx, pred in enumerate(tensor):
-      if idx >= len(all_cats):
-        print(f"Idx {idx} not in {len(all_cats)} categories")
-      #elif pred > 0: # XXX
-        #print(idx, len(all_cats))
-      chance[all_cats[idx]] = pred.item()
-    #print(chance)
+      if idx >= len(cats):
+        print(f"Idx {idx} not in {len(cats)} categories")
+      elif pred > 0: # XXX
+        chance[cats[idx]] = pred.item()
     return chance
   else:
-   raise ValueError("Only tensors with 2 dimensions are supported")
-
-  return vocab.get_itos(cat)
+   raise ValueError("Only tensors with 1 dimension or batches with 2 dimensions are supported")
 
 
 def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
@@ -452,6 +350,17 @@ def evaluate(dataloader, dataset, model, criterion, epoch=0):
       })
     return total_acc / total_count
 
+# TODO seeding:
+def seed_everything(seed=42):
+  random.seed(seed)
+  os.environ['PYTHONHASHSEED'] = str(seed)
+  np.random.seed(seed)
+  torch.manual_seed(seed)
+  torch.cuda.manual_seed(seed)
+  torch.cuda.manual_seed_all(seed)
+  # Some cudnn methods can be random even after fixing the seed
+  # unless you tell it to be deterministic
+  torch.backends.cudnn.deterministic = True
 
 def main():
   parser = argparse.ArgumentParser(
@@ -487,37 +396,26 @@ def main():
 
   data = read_csv(input_csv=args.input, rows=story_num, verbose=args.verbose)
 
-  # create list of all categories
-  global all_categories
-  for cats in data.categories:
-    for c in cats.split(";"):
-      if c not in all_categories:
-        all_categories.append(c)
-  all_categories = sorted(all_categories)
-  #print(all_categories)
-  #print(len(all_categories))
-  #sys.exit(0)
-
   train_data, valid_data, = split_dataset(data, verbose=args.verbose)
 
   '''
   dataset = TextCategoriesDataset(df=data,
-    text_column="content",
-    cats_column="categories",
     lang_column="language",
+    text_column="content",
+    first_cats_column=data.columns.get_loc("content")+1,
     verbose=args.verbose,
   )
   '''
   train_dataset = TextCategoriesDataset(df=train_data,
-    text_column="content",
-    cats_column="categories",
     lang_column="language",
+    text_column="content",
+    first_cats_column=train_data.columns.get_loc("content")+1,
     verbose=args.verbose,
   )
   valid_dataset = TextCategoriesDataset(df=valid_data,
-    text_column="content",
-    cats_column="categories",
     lang_column="language",
+    text_column="content",
+    first_cats_column=valid_data.columns.get_loc("content")+1,
     verbose=args.verbose,
   )
   #for text, cat in enumerate(train_dataset):
@@ -525,6 +423,7 @@ def main():
   #print("-" * 20)
   #for text, cat in enumerate(valid_dataset):
   #  print(text, cat)
+  #print(tensor2cat(train_dataset, torch.tensor([0, 0, 0, 1., 0.9])))
   #sys.exit(0)
 
   # Get cpu, gpu or mps device for training.
@@ -565,7 +464,7 @@ def main():
     drop_last=True,
     shuffle=True,
     num_workers=0,
-    collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
+    collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
   )
   '''
   train_dataloader = DataLoader(train_dataset,
@@ -573,20 +472,20 @@ def main():
     drop_last=True,
     shuffle=True,
     num_workers=0,
-    collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
+    collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
   )
   valid_dataloader = DataLoader(valid_dataset,
     batch_size=batch_size,
     drop_last=True,
     shuffle=True,
     num_workers=0,
-    collate_fn=CollateBatch(cats=train_dataset.cats_vocab.get_stoi(), pad_idx=train_dataset.stoi['<pad>']),
+    collate_fn=CollateBatch(pad_idx=train_dataset.stoi['<pad>']),
   )
   #for i_batch, sample_batched in enumerate(dataloader):
   #  print(i_batch, sample_batched[0], sample_batched[1])
-  #for i_batch, sample_batched in enumerate(train_dataloader):
-  #  print(i_batch, sample_batched[0], sample_batched[1])
-  #sys.exit(0)
+  for i_batch, sample_batched in enumerate(train_dataloader):
+    print(i_batch, sample_batched[0], sample_batched[1])
+  sys.exit(0)
 
   input_size = len(train_dataset.text_vocab)
   output_size = len(train_dataset.cats_vocab) # every output item is the likelihood of a particular category