diff --git a/Nonevisualisation_0.png b/Nonevisualisation_0.png
new file mode 100644
index 0000000000000000000000000000000000000000..59df42cecd44a6e04b3cc578d596a74d54a5cc10
Binary files /dev/null and b/Nonevisualisation_0.png differ
diff --git a/osrt/data/obsurf.py b/osrt/data/obsurf.py
index a4a4ebf9d869be9b73a51f470ce854dd0497738a..074ac76600a9597bb6f9ab1f00b208917479e969 100644
--- a/osrt/data/obsurf.py
+++ b/osrt/data/obsurf.py
@@ -173,8 +173,6 @@ class Clevr3dDataset(Dataset):
             target_pixels = all_pixels
             target_masks = all_masks
 
-        print(f"Final input_image : {input_images.shape} and type {type(input_images)}")
-        print(f"Final input_masks : {input_masks.shape} and type {type(input_masks)}")
         result = {
             'input_images': input_images,  # [1, 3, h, w]
@@ -212,6 +210,7 @@ class Clevr2dDataset(Dataset):
         self.max_objects = max_objects
         self.max_num_entities = 11
+        self.rescale = 128
 
         self.start_idx, self.end_idx = {'train': (0, 70000),
                                         'val': (70000, 75000),
@@ -241,9 +240,8 @@
         img = img[..., :3].astype(np.float32) / 255
         input_image = crop_center(img, 192)
-        input_image = F.interpolate(torch.tensor(input_image).permute(2, 0, 1).unsqueeze(0), size=128)
-        input_image = input_image.squeeze(0).permute(1, 2, 0)
-
+        input_image = F.interpolate(torch.tensor(input_image).permute(2, 0, 1).unsqueeze(0), size=self.rescale)
+        input_image = input_image.squeeze(0)  # kept channels-first: [3, h, w]
 
         mask_path = os.path.join(self.path, 'masks', f'masks_{scene_idx}_0.png')
         mask_idxs = imageio.imread(mask_path)
@@ -257,10 +255,12 @@
         input_masks = crop_center(torch.tensor(masks), 192)
-        input_masks = F.interpolate(input_masks.permute(2, 0, 1).unsqueeze(0), size=128)
+        input_masks = F.interpolate(input_masks.permute(2, 0, 1).unsqueeze(0), size=self.rescale)
         input_masks = input_masks.squeeze(0).permute(1, 2, 0)
+        target_masks = torch.reshape(input_masks, (self.rescale * self.rescale, self.max_num_entities))
 
         result = {
-            'input_image': input_image,    # [3, h, w]
+            'input_images': input_image,   # [3, h, w]
             'input_masks': input_masks,    # [h, w, self.max_num_entities]
+            'target_masks': target_masks,  # [h*w, self.max_num_entities]
             'sceneid': idx,                # int
         }
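Note on the Clevr2dDataset changes above: the image now stays channels-first, the resize target is the new self.rescale attribute, and the flattened target_masks entry is what the Lightning module consumes for FG-ARI. A minimal shape-check sketch with dummy tensors; crop_center is a hypothetical stand-in for the repo helper:

import torch
import torch.nn.functional as F

def crop_center(img, size):
    # Stand-in for the repo helper: center-crop an [h, w, c] tensor to [size, size, c].
    h, w = img.shape[:2]
    top, left = (h - size) // 2, (w - size) // 2
    return img[top:top + size, left:left + size]

rescale, max_num_entities = 128, 11
img = torch.rand(240, 320, 3)                    # dummy CLEVR frame, [h, w, 3]
masks = torch.rand(240, 320, max_num_entities)   # dummy per-entity masks

input_image = crop_center(img, 192).permute(2, 0, 1).unsqueeze(0)
input_image = F.interpolate(input_image, size=rescale).squeeze(0)

input_masks = crop_center(masks, 192).permute(2, 0, 1).unsqueeze(0)
input_masks = F.interpolate(input_masks, size=rescale).squeeze(0).permute(1, 2, 0)
target_masks = torch.reshape(input_masks, (rescale * rescale, max_num_entities))

print(input_image.shape)   # torch.Size([3, 128, 128])
print(input_masks.shape)   # torch.Size([128, 128, 11])
print(target_masks.shape)  # torch.Size([16384, 11])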
""" super().__init__() self.resolution = resolution self.num_slots = num_slots - self.num_iterations = num_iterations + self.num_iterations = cfg["model"]["iters"] self.criterion = nn.MSELoss() @@ -79,17 +78,17 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): if model_type == 'sa': self.slot_attention = SlotAttention( num_slots=self.num_slots, - input_dim=64, - slot_dim=64, - hidden_dim=128, + input_dim=cfg["model"]["input_dim"], + slot_dim=cfg["model"]["slot_dim"], + hidden_dim=cfg["model"]["hidden_dim"], iters=self.num_iterations) elif model_type == 'tsa': # We set the same number of inside parameters self.slot_attention = TransformerSlotAttention( num_slots=self.num_slots, - input_dim=64, - slot_dim=64, - hidden_dim=128, + input_dim=cfg["model"]["input_dim"], + slot_dim=cfg["model"]["slot_dim"], + hidden_dim=cfg["model"]["hidden_dim"], depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model def forward(self, image): @@ -102,7 +101,7 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): it_since_peak = it - self.peak_it return self.peak_lr * (self.decay_rate ** (it_since_peak / self.decay_it)) optimizer = optim.Adam(self.parameters(), lr=0) - scheduler = optim.LambdaLR(optimizer, lr_lambda=lr_func) + scheduler = LambdaLR(optimizer, lr_lambda=lr_func) return { 'optimizer': optimizer, 'lr_scheduler': { @@ -124,57 +123,80 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule): recons, masks = x.split((3, 1), dim = 2) masks = masks.softmax(dim = 1) - recon_combined = (recons * masks).sum(dim = 1) - + recon_combined = (recons * masks).sum(dim = 1) return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, attn_shape) if attn_slotwise is not None else None def training_step(self, batch, batch_idx): """Perform a single training step.""" input_image = torch.squeeze(batch.get('input_images'), dim=1) # Delete dim 1 if only one view - input_image = F.interpolate(input_image, size=128) + true_masks = torch.squeeze(batch.get('target_masks'), dim=1) # Get the prediction of the model and compute the loss. preds = self.one_step(input_image) recon_combined, recons, masks, slots, _ = preds loss_value = self.criterion(recon_combined, input_image) - del recons, masks, slots # Unused. + del recons, slots # Unused. + masks = masks.squeeze(2) + B, num_ob, H, W = masks.shape + masks = torch.reshape(masks, (B, num_ob, H*W)) + fg_ari = compute_adjusted_rand_index(true_masks.transpose(1, 2)[:, 1:], + masks).mean() + del masks + psnr = mse2psnr(loss_value) - fg_ari = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:], - pred_seg.transpose(1, 2)) self.log('train_mse', loss_value, on_epoch=True) - + self.log('train_fg_ari', fg_ari, on_epoch=True) + self.log('train_psnr', psnr, on_epoch=True) + + #self.print(f"Training metrics, MSE: {loss_value}, FG-ARI: {fg_ari}, PSNR: {psnr.item()}") return {'loss': loss_value} def validation_step(self, batch, batch_idx): """Perform a single eval step.""" input_image = torch.squeeze(batch.get('input_images'), dim=1) - input_image = F.interpolate(input_image, size=128) + true_masks = torch.squeeze(batch.get('target_masks'), dim=1) # Get the prediction of the model and compute the loss. preds = self.one_step(input_image) recon_combined, recons, masks, slots, _ = preds loss_value = self.criterion(recon_combined, input_image) - del recons, masks, slots # Unused. + del recons, slots # Unused. 
diff --git a/quick_test.py b/quick_test.py
index a3aeaa5f7229cedeed6c670def8191a4ba76ca49..d63417171cf34291ea8a66fbee23435b2a972d00 100644
--- a/quick_test.py
+++ b/quick_test.py
@@ -11,6 +11,7 @@ train_dataset = data.get_dataset('train', cfg['data'])
 train_loader = DataLoader(train_dataset, batch_size=2, num_workers=0,shuffle=True)
 
 for val in train_loader:
+    print(f"Shape masks {val['input_masks'].shape}")
     fig, axes = plt.subplots(2, 2)
-    axes[0][0].imshow(val['input_image'][0])
+    axes[0][0].imshow(val['input_images'][0].permute(1, 2, 0))  # key renamed by this change; image is now [3, h, w]
     axes[0][1].imshow(val['input_masks'][0][:, :, 0])
diff --git a/runs/clevr/slot_att/config.yaml b/runs/clevr/slot_att/config.yaml
index 38fecb2fc2259efd23d100888a35cecb3a23e50a..d53f13d186af8e2157e0ae1e0984ec58c8a9642b 100644
--- a/runs/clevr/slot_att/config.yaml
+++ b/runs/clevr/slot_att/config.yaml
@@ -4,18 +4,20 @@ model:
   num_slots: 10
   iters: 3
   model_type: sa
+  input_dim: 64
+  slot_dim: 64
+  hidden_dim: 128
 training:
   num_workers: 2
   num_gpus: 1
   batch_size: 8
   max_it: 333000000
   warmup_it: 10000
+  lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
+  visualize_every: 5000
   validate_every: 5000
   checkpoint_every: 1000
-  print_every: 10
-  visualize_every: 5000
   backup_every: 25000
-  lr_warmup: 5000
diff --git a/runs/clevr/slot_att/config_tsa.yaml b/runs/clevr/slot_att/config_tsa.yaml
index d6f40290c2f0dab960c4219647614be3f9f111a1..5c9592739a292f0557e5386040dc180203095b78 100644
--- a/runs/clevr/slot_att/config_tsa.yaml
+++ b/runs/clevr/slot_att/config_tsa.yaml
@@ -4,16 +4,18 @@ model:
   num_slots: 10
   iters: 3
   model_type: tsa
+  input_dim: 64
+  slot_dim: 64
+  hidden_dim: 128
 training:
   num_workers: 2
   num_gpus: 1
-  batch_size: 32
+  batch_size: 8
   max_it: 333000000
   warmup_it: 10000
+  lr_warmup: 5000
   decay_rate: 0.5
   decay_it: 100000
-  lr_warmup: 5000
-  print_every: 10
   visualize_every: 5000
   validate_every: 5000
   checkpoint_every: 1000
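Both configs above previously re-added an iters key that already exists in the model block; YAML mappings silently keep only the last occurrence, so the duplicate was dropped. A small sketch of the intended parse (inline YAML mirroring the model section of config.yaml):

import yaml

cfg = yaml.safe_load("""
model:
  num_slots: 10
  iters: 3
  model_type: sa
  input_dim: 64
  slot_dim: 64
  hidden_dim: 128
""")

# With a single 'iters' key, cfg["model"]["iters"] is the one source of truth for
# both SlotAttention iterations and TransformerSlotAttention depth.
print(cfg["model"]["input_dim"], cfg["model"]["slot_dim"],
      cfg["model"]["hidden_dim"], cfg["model"]["iters"])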
diff --git a/train_sa.py b/train_sa.py
index 86de00719beb8e2fb6b4db8600220c7aeba8acaf..c1e9daaac966da48fd4cd98ca77e7a3c06e316e3 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -46,24 +46,24 @@ def main():
     num_slots = cfg["model"]["num_slots"]
     num_iterations = cfg["model"]["iters"]
     num_train_steps = cfg["training"]["max_it"]
+    num_workers = cfg["training"]["num_workers"]
     resolution = (128, 128)
 
     print(f"Number of CPU Cores : {os.cpu_count()}")
     #### Create datasets
     train_dataset = data.get_dataset('train', cfg['data'])
-    val_every = val_every // len(train_dataset)
     train_loader = DataLoader(
-        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"]-9,
+        train_dataset, batch_size=batch_size, num_workers=num_workers - 9 if num_workers > 9 else 0,
         shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
 
     val_dataset = data.get_dataset('val', cfg['data'])
     val_loader = DataLoader(
-        val_dataset, batch_size=batch_size, num_workers=8,
+        val_dataset, batch_size=batch_size, num_workers=8 if num_workers > 9 else 0,
         shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
 
     #### Create model
-    model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
+    model = LitSlotAttentionAutoEncoder(resolution, num_slots, cfg=cfg)
 
     if args.ckpt:
         checkpoint = torch.load(args.ckpt)
@@ -73,7 +73,7 @@ def main():
         monitor="val_psnr",
         mode="max",
         dirpath="./checkpoints",
-        filename="ckpt-" + str(cfg["data"]["dataset"]) + "-" + str(cfg["model"]["model_type"]) + "-{epoch:02d}-psnr{val_psnr:.2f}",
+        filename="ckpt-" + str(cfg["data"]["dataset"]) + "-slots:" + str(cfg["model"]["num_slots"]) + "-" + str(cfg["model"]["model_type"]) + "-{epoch:02d}-psnr{val_psnr:.2f}",
         save_weights_only=True, # don't save optimizer states nor lr-scheduler, ...
         every_n_train_steps=cfg["training"]["checkpoint_every"]
     )
@@ -89,6 +89,7 @@ def main():
         callbacks=[checkpoint_callback, early_stopping],
         log_every_n_steps=100,
         val_check_interval=cfg["training"]["validate_every"],
+        check_val_every_n_epoch=None,
         max_steps=num_train_steps,
         enable_model_summary=True)
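For reference, a standalone sketch of the warmup-plus-decay schedule that configure_optimizers builds. decay_rate/decay_it come from the configs above; the peak_it/peak_lr names and the peak LR value are assumptions, not part of this diff. Note the base LR passed to Adam must be 1.0, because LambdaLR multiplies it by the lambda's output:

from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR

peak_it, peak_lr = 10000, 4e-4        # warmup length from the config; peak LR assumed
decay_rate, decay_it = 0.5, 100000    # decay_rate / decay_it from the config

def lr_func(it):
    if it < peak_it:                  # linear warmup up to the peak LR
        return peak_lr * it / peak_it
    it_since_peak = it - peak_it      # then exponential decay
    return peak_lr * decay_rate ** (it_since_peak / decay_it)

model = nn.Linear(4, 4)
optimizer = optim.Adam(model.parameters(), lr=1.0)  # base LR 1.0: LambdaLR scales it by lr_func
scheduler = LambdaLR(optimizer, lr_lambda=lr_func)

for it in (0, 5000, 10000, 110000):
    print(it, lr_func(it))  # 0.0, 2e-4, 4e-4, 2e-4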