Commit a3f58aaf authored by Alexandre Chapin

Add visualisation script

parent 7cda846e
@@ -88,7 +88,7 @@ def draw_visualization_grid(columns, outfile, row_labels=None, name=None):
plt.savefig(f'{outfile}.png')
plt.close()
def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, save_file = False):
def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save="./", step=0, save_file=False):
fig, ax = plt.subplots(1, num_slots + 2, figsize=(15, 2))
image = image.squeeze(0)
recon_combined = recon_combined.squeeze(0)
@@ -99,19 +99,19 @@ def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, sa
recons = recons.cpu().detach().numpy()
masks = masks.cpu().detach().numpy()
# Extract data and put it on a plot
ax[0].imshow(image)
ax[0].set_title('Image')
ax[1].imshow(recon_combined)
ax[1].set_title('Recon.')
for i in range(6):
picture = recons[i] * masks[i] + (1 - masks[i])
ax[i + 2].imshow(picture)
ax[i + 2].set_title('Slot %s' % str(i + 1))
for i in range(len(ax)):
ax[i].grid(False)
ax[i].axis('off')
if not save_file:
ax[0].imshow(image)
ax[0].set_title('Image')
ax[1].imshow(recon_combined)
ax[1].set_title('Recon.')
for i in range(6):
picture = recons[i] * masks[i] + (1 - masks[i])
ax[i + 2].imshow(picture)
ax[i + 2].set_title('Slot %s' % str(i + 1))
for i in range(len(ax)):
ax[i].grid(False)
ax[i].axis('off')
plt.show()
else:
# TODO : save png in file
pass
plt.savefig(f'{folder_save}visualisation_{step}.png', bbox_inches='tight')
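For context, here is a minimal usage sketch of the updated signature. The tensor shapes are assumptions based on the usual Slot Attention autoencoder outputs and the squeeze calls visible in the hunk (batch dimension first, channels last), and the random data and output prefix are placeholders; `folder_save` is used as a plain string prefix, so it should end with a slash and point to an existing directory.

```python
import torch
from osrt.utils.visualize import visualize_slot_attention

# Placeholder tensors; real inputs come from the model's forward pass.
num_slots = 6
image = torch.rand(1, 128, 128, 3)              # [batch, H, W, 3], assumed layout
recon_combined = torch.rand(1, 128, 128, 3)     # [batch, H, W, 3], assumed layout
recons = torch.rand(1, num_slots, 128, 128, 3)  # per-slot reconstructions
masks = torch.rand(1, num_slots, 128, 128, 1)   # per-slot alpha masks

# With save_file=True the figure is written to
# f'{folder_save}visualisation_{step}.png' instead of being shown.
visualize_slot_attention(num_slots, image, recon_combined, recons, masks,
                         folder_save='./', step=100, save_file=True)
```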
@@ -7,6 +7,7 @@ import argparse
import yaml
from osrt.model import SlotAttentionAutoEncoder
from osrt import data
from osrt.utils.visualize import visualize_slot_attention
from torch.utils.data import DataLoader
import torch.nn.functional as F
@@ -64,16 +65,16 @@ def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resolution = (128, 128)
# Build dataset iterators, optimizers, and model.
"""data_iterator = data_utils.build_clevr_iterator(
batch_size, split="train", resolution=resolution, shuffle=True,
max_n_objects=6, get_properties=False, apply_crop=True)"""
train_dataset = data.get_dataset('train', cfg['data'])
train_loader = DataLoader(
train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
vis_dataset = data.get_dataset('test', cfg['data'])
vis_loader = DataLoader(
vis_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations).to(device)
num_params = sum(p.numel() for p in model.parameters())
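A sketch of the configuration entries this script reads: cfg['data'] is handed to data.get_dataset, and cfg['training'] supplies num_workers here plus print_every, checkpoint_every and visualize_every in the loop below. The key names match the code; the values and the contents of the data block are illustrative assumptions, not the repository's actual YAML.

```python
import yaml

# Minimal sketch of the assumed config layout; values are placeholders.
cfg = yaml.safe_load("""
data:
  dataset: clevr            # hypothetical key; whatever osrt.data.get_dataset expects
training:
  num_workers: 2
  print_every: 100
  checkpoint_every: 1000
  visualize_every: 1000
""")

assert cfg["training"]["checkpoint_every"] == 1000
```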
@@ -116,17 +117,24 @@ def main():
global_step += 1
# Log the training loss.
if not global_step % 100:
print("Step: %s, Loss: %.6f, Time: %s",
global_step, loss_value,
datetime.timedelta(seconds=time.time() - start))
# We save the checkpoints every 1000 iterations.
if not global_step % 1000:
if not global_step % cfg["training"]["print_every"]:
print(f"Step: {global_step}, Loss: {loss_value}, Time: {datetime.timedelta(seconds=time.time() - start)}")
# We save the checkpoints
if not global_step % cfg["training"]["checkpoint_every"]:
# Save the checkpoint of the model.
ckpt['global_step'] = global_step
torch.save(ckpt, args.ckpt + '/ckpt_' + str(global_step) + '.pth')
print("Saved checkpoint: %s", args.ckpt + '/ckpt.pth')
print(f"Saved checkpoint: {args.ckpt + '/ckpt_' + str(global_step) + '.pth'}")
# We visualize some test data
if not global_step % cfg["training"]["visualize_every"]:
image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
image = F.interpolate(image, size=128)
image = image.to(device)
recon_combined, recons, masks, slots = model(image)
visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=global_step, save_file=True)
if __name__ == "__main__":
main()
\ No newline at end of file
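Finally, a small sketch of resuming from a checkpoint written by the loop above. Only the 'global_step' key is visible in this diff, so the commented-out state-dict lines are assumptions about what the full ckpt dict may contain, and the file path is hypothetical.

```python
import torch

# Load a previously saved checkpoint dict (path is a placeholder).
ckpt = torch.load('./ckpt_1000.pth', map_location='cpu')
global_step = ckpt['global_step']
# model.load_state_dict(ckpt['model'])          # assumed key, if present
# optimizer.load_state_dict(ckpt['optimizer'])  # assumed key, if present
print(f"Resuming from step {global_step}")
```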