From a3f58aafa513475e99d7e562b17a8da3eafb47db Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Fri, 21 Jul 2023 15:08:43 +0200
Subject: [PATCH] Add visualisation script

---
 osrt/utils/visualize.py | 28 ++++++++++++++--------------
 train_sa.py             | 34 +++++++++++++++++++++-------------
 2 files changed, 35 insertions(+), 27 deletions(-)

diff --git a/osrt/utils/visualize.py b/osrt/utils/visualize.py
index af51290..677ad6d 100644
--- a/osrt/utils/visualize.py
+++ b/osrt/utils/visualize.py
@@ -88,7 +88,7 @@ def draw_visualization_grid(columns, outfile, row_labels=None, name=None):
     plt.savefig(f'{outfile}.png')
     plt.close()
 
-def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, save_file = False):
+def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save="./", step= 0, save_file = False):
     fig, ax = plt.subplots(1, num_slots + 2, figsize=(15, 2))
     image = image.squeeze(0)
     recon_combined = recon_combined.squeeze(0)
@@ -99,19 +99,19 @@ def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, sa
     recons = recons.cpu().detach().numpy()
     masks = masks.cpu().detach().numpy()
 
+    # Extract data and put it on a plot
+    ax[0].imshow(image)
+    ax[0].set_title('Image')
+    ax[1].imshow(recon_combined)
+    ax[1].set_title('Recon.')
+    for i in range(6):
+        picture = recons[i] * masks[i] + (1 - masks[i])
+        ax[i + 2].imshow(picture)
+        ax[i + 2].set_title('Slot %s' % str(i + 1))
+    for i in range(len(ax)):
+        ax[i].grid(False)
+        ax[i].axis('off')
     if not save_file:
-        ax[0].imshow(image)
-        ax[0].set_title('Image')
-        ax[1].imshow(recon_combined)
-        ax[1].set_title('Recon.')
-        for i in range(6):
-            picture = recons[i] * masks[i] + (1 - masks[i])
-            ax[i + 2].imshow(picture)
-            ax[i + 2].set_title('Slot %s' % str(i + 1))
-        for i in range(len(ax)):
-            ax[i].grid(False)
-            ax[i].axis('off')
         plt.show()
     else:
-        # TODO : save png in file
-        pass
+        plt.savefig(f'{folder_save}visualisation_{step}.png', bbox_inches='tight')
diff --git a/train_sa.py b/train_sa.py
index 35aa44d..173fb14 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -7,6 +7,7 @@ import argparse
 import yaml
 from osrt.model import SlotAttentionAutoEncoder
 from osrt import data
+from osrt.utils.visualize import visualize_slot_attention
 
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
@@ -64,16 +65,16 @@ def main():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     resolution = (128, 128)
-
-    # Build dataset iterators, optimizers, and model.
-    """data_iterator = data_utils.build_clevr_iterator(
-        batch_size, split="train", resolution=resolution, shuffle=True,
-        max_n_objects=6, get_properties=False, apply_crop=True)"""
     
     train_dataset = data.get_dataset('train', cfg['data'])
     train_loader = DataLoader(
         train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
         shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
+    
+    vis_dataset = data.get_dataset('test', cfg['data'])
+    vis_loader = DataLoader(
+        vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
+        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
 
     model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations).to(device)
     num_params = sum(p.numel() for p in model.parameters())
@@ -116,17 +117,24 @@ def main():
         global_step += 1
 
         # Log the training loss.
-        if not global_step % 100:
-            print("Step: %s, Loss: %.6f, Time: %s",
-                         global_step, loss_value,
-                         datetime.timedelta(seconds=time.time() - start))
-
-        # We save the checkpoints every 1000 iterations.
-        if not global_step % 1000:
+        if not global_step % cfg["training"]["print_every"]:
+            print(f"Step: {global_step}, Loss: {loss_value}, Time: {datetime.timedelta(seconds=time.time() - start)}")
+                         
+        # We save the checkpoints
+        if not global_step % cfg["training"]["checkpoint_every"]:
             # Save the checkpoint of the model.
             ckpt['global_step'] = global_step
             torch.save(ckpt, args.ckpt + '/ckpt.pth')
-            print("Saved checkpoint: %s", args.ckpt + '/ckpt.pth')
+            print(f"Saved checkpoint: {args.ckpt + '/ckpt_' + str(global_step) + '.pth'}")
+
+        # We visualize some test data
+        if not global_step % cfg["training"]["visualize_every"]:
+            image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
+            image = F.interpolate(image, size=128)
+            image = image.to(device)
+            recon_combined, recons, masks, slots = model(image)
+            visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=global_step, save_file=True)
+
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
-- 
GitLab