From 0c199256ddfedfa6bcc5dc4f26d74dd5d2ee06f4 Mon Sep 17 00:00:00 2001
From: Timothy Allen <tim@treehouse.org.za>
Date: Sun, 31 Dec 2023 10:32:32 +0200
Subject: [PATCH] Add per-sample debug output and retune training

Return the original (untokenised) text from TextCategoriesDataset and pass
it through CollateBatch, so that train() and evaluate() can print each
sample's text and the batch loss alongside its expected, predicted and raw
category outputs.

While debugging, also:

* switch the optimizer from Adam (with explicit weight decay) to AdamW with
  its defaults, and comment out gradient clipping (see the notes below);
* comment out the post-validation block that called an extra optimizer.step()
  whenever validation accuracy dropped;
* retune the hyperparameters: epochs 6 -> 10, lr 0.0001 -> 0.00005,
  batch size 16 -> 32, weight_decay 1e-4 -> 1e-5;
* move torch.set_printoptions(precision=2) from train() into main();
* re-enable evaluation on the test set after training.
---
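A minimal, self-contained sketch of the collate contract after this change
(the token ids, category targets and headline strings below are made up; the
real TextCategoriesDataset and CollateBatch keep their existing tokenisation
and padding):

    import torch
    from torch import nn

    # toy samples shaped like what TextCategoriesDataset.__getitem__ now
    # returns: (token id tensor, multi-hot category targets, original text)
    batch = [
        (torch.tensor([4, 9, 2]),    torch.tensor([1., 0.]), "headline one"),
        (torch.tensor([7, 3, 8, 5]), torch.tensor([0., 1.]), "headline two"),
    ]

    def collate(batch, padding_value=0):
        texts, cats, origs = zip(*batch)
        # pad the token tensors to the longest sequence in the batch
        text_tensor = nn.utils.rnn.pad_sequence(texts, padding_value=padding_value)
        cats_tensor = torch.stack(cats)
        # the raw strings ride along untouched, for debug printing only
        return text_tensor, cats_tensor, origs

    text_tensor, cats_tensor, orig_text = collate(batch)
    print(orig_text)   # ('headline one', 'headline two')
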
 africat/categorise.py | 62 ++++++++++++++++++++++++++++++++++++--------------------------
 1 file changed, 36 insertions(+), 26 deletions(-)
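
Note (pre-existing behaviour, not changed here): train() and evaluate()
threshold the raw model output at 0.5, while the loss is nn.BCEWithLogitsLoss,
which treats that output as logits. If a 0.5 probability cut-off is intended,
the usual pattern is to apply a sigmoid first (equivalently, threshold the
logits at 0); a small sketch with made-up logits:

    import torch

    # made-up logits for a batch of 2 texts and 3 categories
    output = torch.tensor([[ 1.2, -0.3, 0.1],
                           [-2.0,  0.7, 0.0]])

    probs = torch.sigmoid(output)   # BCEWithLogitsLoss consumes raw logits,
                                    # so map them to probabilities first
    predictions = probs >= 0.5      # same decision boundary as (output >= 0.0)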

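Note: after the switch to AdamW, the weight_decay variable set in main() is no
longer passed to the optimizer, so AdamW's default decoupled weight decay of
0.01 applies. If the smaller value is intended, it can be passed explicitly;
a sketch with a stand-in model:

    import torch
    from torch import nn

    model = nn.Linear(8, 4)           # stand-in for the real RNN classifier
    lr, weight_decay = 0.00005, 1e-5  # values set in main() after this patch

    # AdamW decouples weight decay from the gradient update; omitting the
    # argument falls back to the 0.01 default
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
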
diff --git a/africat/categorise.py b/africat/categorise.py
index d20246e..200c94a 100755
--- a/africat/categorise.py
+++ b/africat/categorise.py
@@ -169,6 +169,7 @@ class TextCategoriesDataset(Dataset):
     return (
       self.textTransform()(text),
       cats.fillna(0).values.tolist(),
+      text,  # original (untokenised) text, passed through for debug output
     )
 
   def textTransform(self):
@@ -214,7 +215,7 @@ class CollateBatch:
-      batch: a list of tuples with (text, cats), each of which
-             is a list of tokens
+      batch: a list of (text, cats, orig_text) tuples, where
+             text is a list of tokens
     '''
-    batch_text, batch_cats = zip(*batch)
+    batch_text, batch_cats, batch_orig = zip(*batch)
 
     # Pad text to the longest
     text_tensor = nn.utils.rnn.pad_sequence(
@@ -236,6 +237,7 @@ class CollateBatch:
     return (
       text_tensor,
       cats_tensor,
+      batch_orig,  # original texts, forwarded unchanged for debug output
     )
 
 def tensor2cat(dataset, tensor):
@@ -267,28 +269,26 @@ def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
   total_acc, total_count = 0, 1 # XXX
   log_interval = 500
 
-  torch.set_printoptions(precision=2)
-
   model.train()
 
   batch = tqdm.tqdm(dataloader, unit="batch")
   for idx, data in enumerate(batch):
     batch.set_description(f"Train {epoch}.{idx}")
-    text, cats = data
+    text, cats, orig_text = data
     optimizer.zero_grad()
 
     output = model(text)
     #print("output", output)
     #print("output shape", output.shape)
 
+    optimizer.zero_grad()
     loss = criterion(input=output, target=cats)
+    optimizer.zero_grad()
     loss.backward()
-
-    nn.utils.clip_grad_norm_(model.parameters(), 0.1)
-
+    #nn.utils.clip_grad_norm_(model.parameters(), 0.1)
     optimizer.step()
 
-    print("train loss", loss)
+    #print("train loss", loss)
 
     ##predicted = np.round(output)
     ##total_acc += (predicted == cats).sum().item()
@@ -299,10 +299,12 @@ def train(dataloader, dataset, model, optimizer, criterion, epoch=0):
     predictions[output <  0.5] = False ## assign 0 label to those with less than 0.5
 
     batch.clear()
-    for target, out, pred in list(zip(cats, output, predictions)):
+    for target, out, pred, orig in list(zip(cats, output, predictions, orig_text)):
       expect  = tensor2cat(dataset, target)
       raw     = tensor2cat(dataset, out)
       predict = tensor2cat(dataset, pred)
+      print("Text:", orig)
+      print("Loss:", loss.item())
       print("Expected:  ", expect)
       print("Predicted: ", predict)
       print("Raw output:", raw)
@@ -333,7 +335,7 @@ def evaluate(dataloader, dataset, model, criterion, epoch=0):
     batch = tqdm.tqdm(dataloader, unit="batch")
     for idx, data in enumerate(batch):
       batch.set_description(f"Evaluate {epoch}.{idx}")
-      text, cats = data
+      text, cats, orig_text = data
 
       output = model(text)
       #print("eval predicted", output)
@@ -346,10 +348,12 @@ def evaluate(dataloader, dataset, model, criterion, epoch=0):
       predictions[output <  0.5] = False ## assign 0 label to those with less than 0.5
 
       batch.clear()
-      for target, out, pred in list(zip(cats, output, predictions)):
+      for target, out, pred, orig in list(zip(cats, output, predictions, orig_text)):
         expect  = tensor2cat(dataset, target)
         raw     = tensor2cat(dataset, out)
         predict = tensor2cat(dataset, pred)
+        print("Evaluate Text:", orig)
+        print("Evaluate Loss:", loss.item())
         print("Evaluate expected:  ", expect)
         print("Evaluate predicted: ", predict)
         print("Evaluate raw output:", raw)
@@ -465,17 +469,20 @@ def main():
   )
   print(f"Using {device} device")
 
+  torch.set_printoptions(precision=2)
+
   # Hyperparameters
-  #epochs = 10      # epoch
-  epochs = 6       # epoch
+  epochs = 10      # epoch
+  #epochs = 6       # epoch
   #epochs = 4       # epoch
   #lr = 5           # learning rate
   #lr = 0.5
   #lr = 0.05
   #lr = 0.005       # initial learning rate; too small may result in a long training process that could get stuck, whereas a value too large may result in learning a sub-optimal set of weights too fast or an unstable training process -- perhaps the most important hyperparameter. If you have time to tune only one hyperparameter, tune the learning rate
-  lr = 0.0001
+  lr = 0.00005
   #batch_size = 64  # batch size for training
-  batch_size = 16  # batch size for training
+  batch_size = 32  # batch size for training
+  #batch_size = 16  # batch size for training
   #batch_size = 8   # batch size for training
   #batch_size = 4   # batch size for training
 
@@ -485,8 +492,9 @@ def main():
   #hidden_size = 8 # hidden size of rnn module, should be tweaked manually
   mean_seq = True # use mean of rnn output
   #mean_seq = False # use mean of rnn output
-  weight_decay = 1e-4 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
   #weight_decay = 1e-3 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
+  #weight_decay = 1e-4 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
+  weight_decay = 1e-5 # helps the neural networks to learn smoother / simpler functions which most of the time generalizes better compared to spiky, noisy ones ; try 1e-3, 1e-4
 
   '''
   dataloader = DataLoader(dataset,
@@ -545,7 +553,9 @@ def main():
   # optimizer and loss
   criterion = nn.BCEWithLogitsLoss()
   #optimizer = torch.optim.SGD(model.parameters(), lr=lr)
-  optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
+  #optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
+  optimizer = torch.optim.AdamW(model.parameters(), lr=lr)  # weight_decay above is not passed; AdamW's default (0.01) applies
+
   if args.verbose:
     print(criterion)
     print(optimizer)
@@ -560,17 +570,17 @@ def main():
 
     accu_val = evaluate(valid_dataloader, valid_dataset, model, criterion, epoch)
 
-    if total_accu is not None and total_accu > accu_val:
-      optimizer.step()
-    else:
-      total_accu = accu_val
+    #if total_accu is not None and total_accu > accu_val:
+    #  optimizer.step()
+    #else:
+    #  total_accu = accu_val
     e.set_postfix({
       "accuracy": accu_val,
     })
 
-#  print("Checking the results of test dataset.")
-#  accu_test = evaluate(test_dataloader, test_dataset)
-#  print("test accuracy {:8.3f}".format(accu_test))
+  print("Checking the results of test dataset.")
+  accu_test = evaluate(test_dataloader, test_dataset, model, criterion)
+  print("test accuracy {:8.3f}".format(accu_test))
 
   if model_out is not None:
     torch.save(model.state_dict(), model_out)