From a3f58aafa513475e99d7e562b17a8da3eafb47db Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Fri, 21 Jul 2023 15:08:43 +0200
Subject: [PATCH] Add visualisation script

---
 osrt/utils/visualize.py | 28 ++++++++++++++--------------
 train_sa.py             | 34 +++++++++++++++++++++-------------
 2 files changed, 35 insertions(+), 27 deletions(-)

diff --git a/osrt/utils/visualize.py b/osrt/utils/visualize.py
index af51290..677ad6d 100644
--- a/osrt/utils/visualize.py
+++ b/osrt/utils/visualize.py
@@ -88,7 +88,7 @@ def draw_visualization_grid(columns, outfile, row_labels=None, name=None):
     plt.savefig(f'{outfile}.png')
     plt.close()
 
-def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, save_file = False):
+def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save="./", step= 0, save_file = False):
     fig, ax = plt.subplots(1, num_slots + 2, figsize=(15, 2))
     image = image.squeeze(0)
     recon_combined = recon_combined.squeeze(0)
@@ -99,19 +99,19 @@ def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, sa
     recons = recons.cpu().detach().numpy()
     masks = masks.cpu().detach().numpy()
 
+    # Extract data and put it on a plot
+    ax[0].imshow(image)
+    ax[0].set_title('Image')
+    ax[1].imshow(recon_combined)
+    ax[1].set_title('Recon.')
+    for i in range(6):
+        picture = recons[i] * masks[i] + (1 - masks[i])
+        ax[i + 2].imshow(picture)
+        ax[i + 2].set_title('Slot %s' % str(i + 1))
+    for i in range(len(ax)):
+        ax[i].grid(False)
+        ax[i].axis('off')
     if not save_file:
-        ax[0].imshow(image)
-        ax[0].set_title('Image')
-        ax[1].imshow(recon_combined)
-        ax[1].set_title('Recon.')
-        for i in range(6):
-            picture = recons[i] * masks[i] + (1 - masks[i])
-            ax[i + 2].imshow(picture)
-            ax[i + 2].set_title('Slot %s' % str(i + 1))
-        for i in range(len(ax)):
-            ax[i].grid(False)
-            ax[i].axis('off')
         plt.show()
     else:
-        # TODO : save png in file
-        pass
+        plt.savefig(f'{folder_save}visualisation_{step}.png', bbox_inches='tight')
diff --git a/train_sa.py b/train_sa.py
index 35aa44d..173fb14 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -7,6 +7,7 @@ import argparse
 import yaml
 from osrt.model import SlotAttentionAutoEncoder
 from osrt import data
+from osrt.utils.visualize import visualize_slot_attention
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
 
@@ -64,16 +65,16 @@ def main():
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     resolution = (128, 128)
-
-    # Build dataset iterators, optimizers, and model.
-    """data_iterator = data_utils.build_clevr_iterator(
-        batch_size, split="train", resolution=resolution, shuffle=True,
-        max_n_objects=6, get_properties=False, apply_crop=True)"""
 
     train_dataset = data.get_dataset('train', cfg['data'])
     train_loader = DataLoader(
         train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
         shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
+
+    vis_dataset = data.get_dataset('test', cfg['data'])
+    vis_loader = DataLoader(
+        vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
+        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
 
     model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations).to(device)
     num_params = sum(p.numel() for p in model.parameters())
@@ -116,17 +117,24 @@
             global_step += 1
 
             # Log the training loss.
-            if not global_step % 100:
-                print("Step: %s, Loss: %.6f, Time: %s",
-                      global_step, loss_value,
-                      datetime.timedelta(seconds=time.time() - start))
-
-            # We save the checkpoints every 1000 iterations.
-            if not global_step % 1000:
+            if not global_step % cfg["training"]["print_every"]:
+                print(f"Step: {global_step}, Loss: {loss_value}, Time: {datetime.timedelta(seconds=time.time() - start)}")
+
+            # We save the checkpoints
+            if not global_step % cfg["training"]["checkpoint_every"]:
                 # Save the checkpoint of the model.
                 ckpt['global_step'] = global_step
                 torch.save(ckpt, args.ckpt + '/ckpt.pth')
-                print("Saved checkpoint: %s", args.ckpt + '/ckpt.pth')
+                print(f"Saved checkpoint: {args.ckpt + '/ckpt_' + str(global_step) + '.pth'}")
+
+            # We visualize some test data
+            if not global_step % cfg["training"]["visualize_every"]:
+                image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
+                image = F.interpolate(image, size=128)
+                image = image.to(device)
+                recon_combined, recons, masks, slots = model(image)
+                visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=global_step, save_file=True)
+
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
--
GitLab
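
Note for anyone applying this patch: the patched train_sa.py reads its logging, checkpointing and visualisation intervals from the `training` section of the YAML config instead of the previous hard-coded values. Below is a minimal sketch of the keys it now expects; the key names come straight from the diff, while the values are only illustrative (100 and 1000 mirror the intervals the patch removes, and the `visualize_every` value is an assumption).

# Hypothetical config sketch, not part of the patch: the keys train_sa.py
# now looks up under cfg["training"]. Values are examples only.
example_cfg = {
    "training": {
        "num_workers": 2,          # already required by the existing DataLoader calls
        "print_every": 100,        # replaces the hard-coded `global_step % 100` logging interval
        "checkpoint_every": 1000,  # replaces the hard-coded `global_step % 1000` checkpoint interval
        "visualize_every": 1000,   # new key: how often a slot visualisation PNG is saved (assumed value)
    }
}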