diff --git a/.visualisation_1639.png b/.visualisation_1639.png
index fece79284e2d145e76f53cc37c395ef07a98cebf..4a022d7efa8d6d435d123562b5d9dcbd0337532a 100644
Binary files a/.visualisation_1639.png and b/.visualisation_1639.png differ
diff --git a/compile_video.py b/compile_video.py
index 8b6de1aa9846bd1266a1c82f3e5cecd5ffd4dfaf..f6b5de30cb7f55ed44404e3747e6a001736fae3b 100644
--- a/compile_video.py
+++ b/compile_video.py
@@ -8,7 +8,7 @@ from tqdm import tqdm
 import argparse, os, subprocess
 from os.path import join
 
-from osrt.utils.visualize import setup_axis, background_image
+from osrt.utils.visualization_utils import setup_axis, background_image
 
 
 def compile_video_plot(path, frames=False, num_frames=1000000000):
diff --git a/eval_sa.py b/eval_sa.py
deleted file mode 100644
index 858462dc755c04c79e1a287b98734f28e606d21d..0000000000000000000000000000000000000000
--- a/eval_sa.py
+++ /dev/null
@@ -1,55 +0,0 @@
-from osrt import data
-from osrt.model import SlotAttentionAutoEncoder
-import torch
-import matplotlib.pyplot as plt
-from PIL import Image as Image
-import argparse
-import yaml
-from torch.utils.data import DataLoader
-import torch.nn.functional as F
-from osrt.utils.visualize import visualize_slot_attention
-
-
-if __name__ == "__main__":
-    # Arguments
-    parser = argparse.ArgumentParser(
-        description='Train a 3D scene representation model.'
-    )
-    parser.add_argument('config', type=str, help="Where to save the checkpoints.")
-    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
-    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
-    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
-
-    args = parser.parse_args()
-    with open(args.config, 'r') as f:
-        cfg = yaml.load(f, Loader=yaml.CLoader)
-
-    # Hyperparameters.
-    seed = 0
-    batch_size = 1
-    num_slots = 7
-    num_iterations = 3
-    resolution = (128, 128)
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    model = SlotAttentionAutoEncoder(resolution, num_slots, num_iterations)
-    model = torch.load('./ckpt.pth')['network']
-    print(model)
-    model.eval()
-
-
-
-    eval_dataset = data.get_dataset('train', cfg['data'])
-    eval_loader = DataLoader(
-        eval_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"], pin_memory=True,
-        shuffle=True, worker_init_fn=data.worker_init_fn, persistent_workers=True)
-
-    model = model.to(device)
-
-    image = torch.squeeze(next(iter(eval_loader)).get('input_images').to(device), dim=1)
-    image = F.interpolate(image, size=128)
-    image = image.to(device)
-    recon_combined, recons, masks, slots = model(image)
-
-    visualize_slot_attention(num_slots, image, recon_combined, recons, masks)
-
-
diff --git a/evaluate_sa.py b/evaluate_sa.py
new file mode 100644
index 0000000000000000000000000000000000000000..af75b4bb89159700ede32361e378238a24693ac0
--- /dev/null
+++ b/evaluate_sa.py
@@ -0,0 +1,61 @@
+
+import argparse
+import yaml
+
+from osrt.model import LitSlotAttentionAutoEncoder
+from osrt import data
+
+from torch.utils.data import DataLoader
+
+import lightning as pl
+from lightning.pytorch.loggers.wandb import WandbLogger
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+import torch
+
+def main():
+    # Arguments
+    parser = argparse.ArgumentParser(
+        description='Train a 3D scene representation model.'
+    )
+    parser.add_argument('config', type=str, help="Where to save the checkpoints.")
+    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
+    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
+    parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
+
+    args = parser.parse_args()
+    with open(args.config, 'r') as f:
+        cfg = yaml.load(f, Loader=yaml.CLoader)
+
+    ### Set random seed.
+    pl.seed_everything(42, workers=True)
+
+    ### Hyperparameters of the model.
+    batch_size = cfg["training"]["batch_size"]
+    num_gpus = cfg["training"]["num_gpus"]
+    num_slots = cfg["model"]["num_slots"]
+    num_iterations = cfg["model"]["iters"]
+    resolution = (128, 128)
+
+    #### Create datasets
+    test_dataset = data.get_dataset('val', cfg['data'])
+    test_dataloader = DataLoader(
+        test_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
+
+    #### Create model
+    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
+    checkpoint = torch.load(args.ckpt)
+
+    model.load_state_dict(checkpoint['state_dict'])
+    model.eval()
+
+    trainer = pl.Trainer(accelerator="gpu", devices=num_gpus,
+                         strategy="auto")
+
+    trainer.validate(model, dataloaders=test_dataloader)
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/lightning_logs/version_0/events.out.tfevents.1690289036.achapin-Precision-5570.74650.0 b/lightning_logs/version_0/events.out.tfevents.1690289036.achapin-Precision-5570.74650.0
new file mode 100644
index 0000000000000000000000000000000000000000..ddc38a00a26221598c3c339d50d8e4418047e23a
Binary files /dev/null and b/lightning_logs/version_0/events.out.tfevents.1690289036.achapin-Precision-5570.74650.0 differ
diff --git a/lightning_logs/version_0/hparams.yaml b/lightning_logs/version_0/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93
--- /dev/null
+++ b/lightning_logs/version_0/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/lightning_logs/version_1/events.out.tfevents.1690289388.achapin-Precision-5570.76271.0 b/lightning_logs/version_1/events.out.tfevents.1690289388.achapin-Precision-5570.76271.0
new file mode 100644
index 0000000000000000000000000000000000000000..98f994d92a84a7072f246483250046a883e74f30
Binary files /dev/null and b/lightning_logs/version_1/events.out.tfevents.1690289388.achapin-Precision-5570.76271.0 differ
diff --git a/lightning_logs/version_1/hparams.yaml b/lightning_logs/version_1/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93
--- /dev/null
+++ b/lightning_logs/version_1/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/osrt/model.py b/osrt/model.py
index 71511e635ca2dc8f9b8695bbab3a9bc2314cfa16..4b04c5b0ee3fd8d523554e9bcdec6e1a7460a002 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -115,7 +115,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         x = self.encoder_pos(x)
         x = self.mlp(self.layer_norm(x))
 
-        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2), slots = slots)
+        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
         x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder_initial_size, -1)
         x = self.decoder_pos(x)
         x = self.decoder_cnn(x.movedim(-1, 1))
@@ -131,6 +131,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:])
 
     def configure_optimizers(self) -> Any:
+        print(self.parameters())
         optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
         return optimizer
 
@@ -187,3 +188,20 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
         return {'loss': loss_value, 'val_psnr': psnr.item()}
 
+    def test_step(self, batch, batch_idx):
+        """Perform a single eval step."""
+        input_image = torch.squeeze(batch.get('input_images'), dim=1)
+        input_image = F.interpolate(input_image, size=128)
+
+        # Get the prediction of the model and compute the loss.
+        preds = self.one_step(input_image)
+        recon_combined, recons, masks, slots, _ = preds
+        #input_image = input_image.permute(0, 2, 3, 1)
+        loss_value = self.criterion(recon_combined, input_image)
+        del recons, masks, slots  # Unused.
+        psnr = mse2psnr(loss_value)
+        self.log('test_loss', loss_value)
+        self.log('test_psnr', psnr)
+
+        return {'loss': loss_value, 'test_psnr': psnr.item()}
+        
\ No newline at end of file
diff --git a/osrt/trainer.py b/osrt/trainer.py
index 01eb17dc2ee770a4ed623811bafa5479a24708cf..6afa1745b4ba75b4f78add713305a47843b4e0d9 100644
--- a/osrt/trainer.py
+++ b/osrt/trainer.py
@@ -3,7 +3,7 @@ import torch.distributed as dist
 import numpy as np
 from tqdm import tqdm
 
-import osrt.utils.visualize as vis
+import osrt.utils.visualization_utils as vis
 from osrt.utils.common import mse2psnr, reduce_dict, gather_all, compute_adjusted_rand_index
 from osrt.utils import nerf
 from osrt.utils.common import get_rank, get_world_size
diff --git a/osrt/utils/visualize.py b/osrt/utils/visualization_utils.py
similarity index 95%
rename from osrt/utils/visualize.py
rename to osrt/utils/visualization_utils.py
index 677ad6db5179949b08c11b5a74c8d5be4aaa85d4..93f8c4e7013dacc84d2d9ba86d77d518a4e8c30a 100644
--- a/osrt/utils/visualize.py
+++ b/osrt/utils/visualization_utils.py
@@ -95,18 +95,19 @@ def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, fo
     recons = recons.squeeze(0)
     masks = masks.squeeze(0)
     image = image.permute(1,2,0).cpu().numpy()
-    recon_combined = recon_combined.cpu().detach().numpy()
+    recon_combined = recon_combined.permute(1,2,0).cpu().detach().numpy()
     recons = recons.cpu().detach().numpy()
     masks = masks.cpu().detach().numpy()
 
     # Extract data and put it on a plot
     ax[0].imshow(image)
     ax[0].set_title('Image')
-    ax[1].imshow(recon_combined)
+    print(image)
+    ax[1].imshow((recon_combined * 255).astype(np.uint8))
     ax[1].set_title('Recon.')
     for i in range(6):
         picture = recons[i] * masks[i] + (1 - masks[i])
-        ax[i + 2].imshow(picture)
+        ax[i + 2].imshow(picture.transpose(1,2,0))
         ax[i + 2].set_title('Slot %s' % str(i + 1))
     for i in range(len(ax)):
         ax[i].grid(False)
diff --git a/render.py b/render.py
index 355553f656c26943763d2b7f4b1bac877266dc3a..b4edb0bfb3167e992f7010c3f813939b4b023575 100644
--- a/render.py
+++ b/render.py
@@ -9,7 +9,7 @@ from tqdm import tqdm
 
 from osrt.data import get_dataset
 from osrt.checkpoint import Checkpoint
-from osrt.utils.visualize import visualize_2d_cluster, get_clustering_colors
+from osrt.utils.visualization_utils import visualize_2d_cluster, get_clustering_colors
 from osrt.utils.nerf import rotate_around_z_axis_torch, get_camera_rays, transform_points_torch, get_extrinsic_torch
 from osrt.model import OSRT
 from osrt.trainer import SRTTrainer
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index 1164186692f69964292e15c2bc6bd2ea9ed6a024..161aaf628a68e847d514e2e78bdf5867ceb2167d 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -6,8 +6,8 @@ model:
   model_type: sa
 training:
   num_workers: 2
-  num_gpus: 8
-  batch_size: 32
+  num_gpus: 1
+  batch_size: 64
   max_it: 333000000
   warmup_it: 10000
   decay_rate: 0.5
diff --git a/train_sa.py b/train_sa.py
index 25f5c124da5a746f89a7aaf736d3c2ce42825306..ad03192032f1e171721ee3f801dffa6d1400182d 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -7,7 +7,8 @@ import yaml
 
 from osrt.model import LitSlotAttentionAutoEncoder
 from osrt import data
-from osrt.utils.visualize import visualize_slot_attention
+from osrt.utils.visualization_utils import visualize_slot_attention
+from osrt.utils.common import mse2psnr
 
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
@@ -70,10 +71,30 @@ def main():
 
     trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple", default_root_dir="./logs",
                          logger=WandbLogger(project="slot-att") if args.wandb else None,
-                         strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "default", callbacks=[checkpoint_callback],
+                         strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto", callbacks=[checkpoint_callback],
                          log_every_n_steps=100, max_steps=num_train_steps)
     trainer.fit(model, train_loader, val_loader)
+
+
+
+    #### Create datasets
+    vis_dataset = data.get_dataset('train', cfg['data'])
+    vis_loader = DataLoader(
+        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn)
+
+    device = model.device
+
+    #### Visualize image
+    image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
+    image = F.interpolate(image, size=128)
+    image = image.to(device)
+    recon_combined, recons, masks, slots, _ = model(image)
+    loss = nn.MSELoss()
+    loss_value = loss(recon_combined, image)
+    psnr = mse2psnr(loss_value)
+    visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=0, save_file=True)
 
 
 if __name__ == "__main__":
     main()
diff --git a/visualise.py b/visualize_sa.py
similarity index 63%
rename from visualise.py
rename to visualize_sa.py
index 19d311b6cd7703d72366b3447465518d659c3775..9153006d27592d4d0961fc60f246f29d60505ef6 100644
--- a/visualise.py
+++ b/visualize_sa.py
@@ -8,7 +8,7 @@ import yaml
 
 from osrt.model import LitSlotAttentionAutoEncoder
 from osrt import data
-from osrt.utils.visualize import visualize_slot_attention
+from osrt.utils.visualization_utils import visualize_slot_attention
 from osrt.utils.common import mse2psnr
 
 from torch.utils.data import DataLoader
@@ -37,47 +37,33 @@ def main():
 
     ### Hyperparameters of the model.
     num_slots = cfg["model"]["num_slots"]
     num_iterations = cfg["model"]["iters"]
-    base_learning_rate = 0.0004
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     resolution = (128, 128)
 
     #### Create datasets
-
-    vis_dataset = data.get_dataset('test', cfg['data'])
+    vis_dataset = data.get_dataset('train', cfg['data'])
     vis_loader = DataLoader(
         vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
         shuffle=True, worker_init_fn=data.worker_init_fn)
 
     #### Create model
-    model = LitSlotAttentionAutoEncoder(resolution, 10, num_iterations, cfg=cfg)
-    num_params = sum(p.numel() for p in model.parameters())
-
-    print('Number of parameters:')
-    print(f'Model slot attention: {num_params}')
-
-    optimizer = optim.Adam(model.parameters(), lr=base_learning_rate, eps=1e-08)
+    model = LitSlotAttentionAutoEncoder(resolution, 6, num_iterations, cfg=cfg).to(device)
+    checkpoint = torch.load(args.ckpt)
+
+    model.load_state_dict(checkpoint['state_dict'])
 
-    ckpt = {
-        'network': model,
-        'optimizer': optimizer,
-        'global_step': 1639
-    }
-    #ckpt_manager = torch.save(ckpt, args.ckpt + '/ckpt.pth')
-    """ckpt = torch.load('~/ckpt.pth')
-    model = ckpt['network']"""
-    model.load_state_dict(torch.load('/home/achapin/ckpt_1639.pth')["model_state_dict"])
+    model.eval()
+
     #### Visualize image
     image = torch.squeeze(next(iter(vis_loader)).get('input_images').to(device), dim=1)
     image = F.interpolate(image, size=128)
     image = image.to(device)
-    recon_combined, recons, masks, slots = model(image)
+    recon_combined, recons, masks, _, _ = model(image)
     loss = nn.MSELoss()
-    input_image = image.permute(0, 2, 3, 1)
-    loss_value = loss(recon_combined, input_image)
+    loss_value = loss(recon_combined, image)
     psnr = mse2psnr(loss_value)
-    print(f"MSE value : {loss_value} VS PSNR {psnr}")
-    visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=1639, save_file=True)
+    visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=0, save_file=True)
 
 if __name__ == "__main__":
     main()
\ No newline at end of file
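
A minimal usage sketch for consumers of these changes, mirroring visualize_sa.py above: it loads a Lightning checkpoint produced by the ModelCheckpoint callback in train_sa.py and visualizes slot reconstructions through the renamed visualization module. The config path and checkpoint path are placeholders, not files introduced by this change.

# Minimal sketch (placeholders for config/checkpoint paths), mirroring visualize_sa.py.
import torch
import torch.nn.functional as F
import yaml
from torch.utils.data import DataLoader

from osrt import data
from osrt.model import LitSlotAttentionAutoEncoder
from osrt.utils.visualization_utils import visualize_slot_attention  # renamed module

with open('runs/clevr3d/slot_att/config.yaml', 'r') as f:
    cfg = yaml.load(f, Loader=yaml.CLoader)

num_slots = cfg["model"]["num_slots"]
model = LitSlotAttentionAutoEncoder((128, 128), num_slots, cfg["model"]["iters"], cfg=cfg)
model.load_state_dict(torch.load('path/to/checkpoint.ckpt')['state_dict'])
model.eval()

# Single-sample loader, as in visualize_sa.py.
loader = DataLoader(data.get_dataset('val', cfg['data']), batch_size=1)
image = torch.squeeze(next(iter(loader)).get('input_images'), dim=1)
image = F.interpolate(image, size=128)

with torch.no_grad():
    # The forward pass now returns five values; the slots and attention maps are unused here.
    recon_combined, recons, masks, _, _ = model(image)
visualize_slot_attention(num_slots, image, recon_combined, recons, masks,
                         folder_save='.', step=0, save_file=True)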