diff --git a/osrt/data/__init__.py b/osrt/data/__init__.py
index 428145ed318953710fabcc7f2fc677d64ea20356..96d64f2bae157df08fdac791a9e16baf5eedc92a 100644
--- a/osrt/data/__init__.py
+++ b/osrt/data/__init__.py
@@ -1,5 +1,5 @@
 from osrt.data.core import get_dataset, worker_init_fn
 from osrt.data.nmr import NMRDataset
 from osrt.data.multishapenet import MultishapenetDataset
-from osrt.data.obsurf import Clevr3dDataset
+from osrt.data.obsurf import Clevr3dDataset, Clevr2dDataset
diff --git a/osrt/data/core.py b/osrt/data/core.py
index 81327f04901a8384fcdba2be9b4da7b09e8c1c0e..3b8886ebe8c87662e814f8908a085cd560fae565 100644
--- a/osrt/data/core.py
+++ b/osrt/data/core.py
@@ -43,6 +43,8 @@ def get_dataset(mode, cfg, max_len=None, full_scale=False):
         dataset = data.Clevr3dDataset(dataset_folder, mode, points_per_item=points_per_item,
                                       shapenet=False, max_len=max_len, full_scale=full_scale,
                                       canonical_view=canonical_view, **kwargs)
+    elif dataset_type == 'clevr2d':
+        dataset = data.Clevr2dDataset(dataset_folder, mode, **kwargs)
     elif dataset_type == 'obsurf_msn':
         dataset = data.Clevr3dDataset(dataset_folder, mode, points_per_item=points_per_item,
                                       shapenet=True, max_len=max_len, full_scale=full_scale,
diff --git a/osrt/data/obsurf.py b/osrt/data/obsurf.py
index 4a93013db5236f76161fdaf4dc6147609ba67294..a4a4ebf9d869be9b73a51f470ce854dd0497738a 100644
--- a/osrt/data/obsurf.py
+++ b/osrt/data/obsurf.py
@@ -1,11 +1,15 @@
 import numpy as np
 import imageio
-import yaml
+
 from torch.utils.data import Dataset
+import torchvision.transforms as transforms
+import torch
+from PIL import Image
 import os
 
 from osrt.utils.nerf import get_camera_rays, get_extrinsic, transform_points
+import torch.nn.functional as F
 
 
 def downsample(x, num_steps=1):
@@ -14,6 +18,24 @@ def downsample(x, num_steps=1):
     stride = 2**num_steps
     return x[stride//2::stride, stride//2::stride]
 
+def crop_center(image, crop_size=192):
+    height, width = image.shape[:2]
+
+    center_x = width // 2
+    center_y = height // 2
+
+    crop_size_half = crop_size // 2
+
+    # Calculate the top-left corner coordinates of the crop
+    crop_x1 = center_x - crop_size_half
+    crop_y1 = center_y - crop_size_half
+
+    # Calculate the bottom-right corner coordinates of the crop
+    crop_x2 = center_x + crop_size_half
+    crop_y2 = center_y + crop_size_half
+
+    # Crop the image
+    return image[crop_y1:crop_y2, crop_x1:crop_x2]
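+# Sanity example: on the 240x320 frames used in this file, crop_center(img, 192)
+# keeps the central square, i.e. rows 24:216 and columns 64:256 -> (192, 192, C).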
 
 class Clevr3dDataset(Dataset):
     def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
@@ -93,6 +115,7 @@ class Clevr3dDataset(Dataset):
         masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
         np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
+
         input_image = downsample(imgs[view_idx], num_steps=self.downsample)
         input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)
 
@@ -168,3 +194,75 @@ class Clevr3dDataset(Dataset):
 
         return result
 
+class Clevr2dDataset(Dataset):
+    def __init__(self, path, mode, max_objects=6):
+        """ Loads single views of the CLEVR dataset used in the ObSuRF paper.
+
+        It may be downloaded at: https://stelzner.github.io/obsurf
+
+        Args:
+            path (str): Path to dataset.
+            mode (str): 'train', 'val', or 'test'.
+            max_objects (int): Load only scenes with at most this many objects.
+        """
+        self.path = path.replace("clevr2d", "clevr3d")  # the images are shared with the CLEVR3D data
+        self.mode = mode
+        self.max_objects = max_objects
+
+        self.max_num_entities = 11
+
+        self.start_idx, self.end_idx = {'train': (0, 70000),
+                                        'val': (70000, 75000),
+                                        'test': (85000, 100000)}[mode]
+
+        self.metadata = np.load(os.path.join(self.path, 'metadata.npz'))
+        self.metadata = {k: v for k, v in self.metadata.items()}
+
+        num_objs = (self.metadata['shape'][self.start_idx:self.end_idx] > 0).sum(1)
+        self.idxs = np.arange(self.start_idx, self.end_idx)[num_objs <= max_objects]
+
+        dataset_name = 'CLEVR'
+        print(f'Initialized {dataset_name} {mode} set, {len(self.idxs)} examples')
+
+    def __len__(self):
+        return len(self.idxs)
+
+    def __getitem__(self, idx):
+        scene_idx = idx % len(self.idxs)
+        scene_idx = self.idxs[scene_idx]
+
+        img_path = os.path.join(self.path, 'images', f'img_{scene_idx}_0.png')
+        img = np.asarray(imageio.imread(img_path))
+        img = img[..., :3].astype(np.float32) / 255
+
+        input_image = crop_center(img, 192)
+        input_image = F.interpolate(torch.tensor(input_image).permute(2, 0, 1).unsqueeze(0), size=128)
+        input_image = input_image.squeeze(0).permute(1, 2, 0)
+
+        mask_path = os.path.join(self.path, 'masks', f'masks_{scene_idx}_0.png')
+        mask_idxs = imageio.imread(mask_path)
+
+        masks = np.zeros((240, 320, self.max_num_entities), dtype=np.uint8)
+        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
+
+        # float() is required here: F.interpolate does not accept uint8 inputs
+        input_masks = crop_center(torch.tensor(masks).float(), 192)
+        input_masks = F.interpolate(input_masks.permute(2, 0, 1).unsqueeze(0), size=128)
+        input_masks = input_masks.squeeze(0).permute(1, 2, 0)
+
+        result = {
+            'input_image': input_image,  # [h, w, 3]
+            'input_masks': input_masks,  # [h, w, self.max_num_entities]
+            'sceneid': idx,              # int
+        }
+
+        return result
+
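+# Minimal usage sketch (the data path is hypothetical):
+#   ds = Clevr2dDataset('data/obsurf/clevr2d', 'train')
+#   item = ds[0]
+#   item['input_image'].shape  # torch.Size([128, 128, 3])
+#   item['input_masks'].shape  # torch.Size([128, 128, 11])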
diff --git a/osrt/data/ycb-video.py b/osrt/data/ycb-video.py
index aa1a1b88e8f6f4ccde51e213969cfaec3018efaf..615a13f132464ee50c9ca2f324439e79c3e52320 100644
--- a/osrt/data/ycb-video.py
+++ b/osrt/data/ycb-video.py
@@ -134,7 +134,7 @@ def downsample(x, num_steps=1):
     return x[stride//2::stride, stride//2::stride]
 
 
-class YCBVideo(Dataset):
+class YCBVideo3D(Dataset):
     def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
                  max_len=None, full_scale=False, shapenet=False, downsample=None):
         """ Loads the YCB-Video dataset that we have adapted.
@@ -274,4 +274,99 @@
 
         return result
 
+class YCBVideo2D(Dataset):
+    def __init__(self, path, mode, downsample=None):
+        """ Loads single views of the YCB-Video dataset that we have adapted.
+
+        Args:
+            path (str): Path to dataset.
+            mode (str): 'train', 'val', or 'test'.
+            downsample (int): Downsample height and width of input image by a factor of 2**downsample
+        """
+        self.path = path
+        self.mode = mode
+        self.downsample = downsample
+
+        self.max_num_entities = 21  # max number of objects in a scene
+
+        # TODO : set the right number here
+
+        dataset_name = 'YCB-Video'
+        print(f'Initialized {dataset_name} {mode} set')
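+        # NOTE (work in progress): __getitem__ below was copied from YCBVideo3D
+        # and still references self.idxs, self.num_views, self.metadata and
+        # self.canonical, as well as scene_idx / view_idx, none of which are set
+        # yet; they must be ported from YCBVideo3D before this class is usable.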
+
+    def __len__(self):
+        return len(self.idxs)
+
+    def __getitem__(self, idx):
+        imgs = [np.asarray(imageio.imread(
+            os.path.join(self.path, 'images', f'img_{scene_idx}_{v}.png')))
+            for v in range(self.num_views)]
+
+        imgs = [img[..., :3].astype(np.float32) / 255 for img in imgs]
+
+        mask_idxs = [imageio.imread(os.path.join(self.path, 'masks', f'masks_{scene_idx}_{v}.png'))
+                     for v in range(self.num_views)]
+        masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
+        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
+
+        input_image = downsample(imgs[view_idx], num_steps=self.downsample)
+        input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)
+
+        all_rays = []
+        # TODO : find a way to get the camera poses
+        all_camera_pos = self.metadata['camera_pos'][:self.num_views].astype(np.float32)
+        all_camera_rot = self.metadata['camera_rot'][:self.num_views].astype(np.float32)
+        for i in range(self.num_views):
+            cur_rays = get_camera_rays(all_camera_pos[i], all_camera_rot[i], noisy=False)  # TODO : adapt function
+            all_rays.append(cur_rays)
+        all_rays = np.stack(all_rays, 0).astype(np.float32)
+
+        input_camera_pos = all_camera_pos[view_idx]
+
+        if self.canonical:
+            track_point = np.zeros_like(input_camera_pos)  # All cameras are pointed at the origin
+            canonical_extrinsic = get_extrinsic(input_camera_pos, track_point=track_point)  # TODO : adapt function
+            canonical_extrinsic = canonical_extrinsic.astype(np.float32)
+            all_rays = transform_points(all_rays, canonical_extrinsic, translate=False)  # TODO : adapt function
+            all_camera_pos = transform_points(all_camera_pos, canonical_extrinsic)
+            input_camera_pos = all_camera_pos[view_idx]
+
+        input_rays = all_rays[view_idx]
+        input_rays = downsample(input_rays, num_steps=self.downsample)
+        input_rays = np.expand_dims(input_rays, 0)
+
+        input_masks = masks[view_idx]
+        input_masks = downsample(input_masks, num_steps=self.downsample)
+        input_masks = np.expand_dims(input_masks, 0)
+
+        input_camera_pos = np.expand_dims(input_camera_pos, 0)
+
+        all_pixels = np.reshape(np.stack(imgs, 0), (self.num_views * 240 * 320, 3))
+        all_rays = np.reshape(all_rays, (self.num_views * 240 * 320, 3))
+        all_camera_pos = np.tile(np.expand_dims(all_camera_pos, 1), (1, 240 * 320, 1))
+        all_camera_pos = np.reshape(all_camera_pos, (self.num_views * 240 * 320, 3))
+        all_masks = np.reshape(masks, (self.num_views * 240 * 320, self.max_num_entities))
+
+        target_camera_pos = all_camera_pos
+        target_pixels = all_pixels
+        target_masks = all_masks
+
+        result = {
+            'input_images': input_images,           # [1, 3, h, w]
+            'input_camera_pos': input_camera_pos,   # [1, 3]
+            'input_rays': input_rays,               # [1, h, w, 3]
+            'input_masks': input_masks,             # [1, h, w, self.max_num_entities]
+            'target_pixels': target_pixels,         # [p, 3]
+            'target_camera_pos': target_camera_pos, # [p, 3]
+            'target_masks': target_masks,           # [p, self.max_num_entities]
+            'sceneid': idx,                         # int
+        }
+
+        if self.canonical:
+            result['transform'] = canonical_extrinsic  # [3, 4] (optional)
+
+        return result
+
diff --git a/osrt/model.py b/osrt/model.py
index 24842c9d926bda12ab66cfb52bbe0d99c2fcd0e7..b1ede9301008f67bfe361b9d842d71ec0222eb88 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -4,6 +4,7 @@ from torch import nn
 import torch
 import torch.nn.functional as F
 import torch.optim as optim
+import torch.optim.lr_scheduler as lr_scheduler
 
 import numpy as np
 
@@ -11,7 +12,7 @@ from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
 from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
 from osrt.layers import SlotAttention, TransformerSlotAttention, Encoder, Decoder
 import osrt.layers as layers
-from osrt.utils.common import mse2psnr
+from osrt.utils.common import mse2psnr, compute_adjusted_rand_index
 
 import lightning as pl
 
@@ -69,6 +70,11 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
         self.encoder = Encoder()
         self.decoder = Decoder()
 
+        self.peak_it = cfg["training"]["warmup_it"]
+        self.peak_lr = 1e-4
+        self.decay_rate = 0.16
+        self.decay_it = cfg["training"]["decay_it"]
+
         model_type = cfg['model']['model_type']
         if model_type == 'sa':
             self.slot_attention = SlotAttention(
@@ -87,24 +93,23 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
                 depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model
 
     def forward(self, image):
-        x = self.encoder(image)
-        slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
-        x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
-
-        x = self.decoder(x)
-        x = F.interpolate(x, image.shape[-2:], mode='bilinear')
-
-        x = x.unflatten(0, (len(image), len(x) // len(image)))
-
-        recons, masks = x.split((3, 1), dim = 2)
-        masks = masks.softmax(dim = 1)
-        recon_combined = (recons * masks).sum(dim = 1)
-
-        return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
+        return self.one_step(image)
 
     def configure_optimizers(self) -> Any:
-        optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
-        return optimizer
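+        # Linear warmup to peak_lr followed by exponential decay; for example,
+        # with warmup_it=10000 and decay_it=100000 (cf. the configs): it=0 -> lr 0,
+        # it=10000 -> 1e-4 (peak), it=110000 -> 1e-4 * 0.16 ~= 1.6e-5.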
+        def lr_func(it):
+            if it < self.peak_it:  # Warmup period
+                return self.peak_lr * (it / self.peak_it)
+            it_since_peak = it - self.peak_it
+            return self.peak_lr * (self.decay_rate ** (it_since_peak / self.decay_it))
+        # Base lr is 1.0 so that lr_func returns the absolute learning rate
+        # (LambdaLR multiplies the base lr by the lambda's return value).
+        optimizer = optim.Adam(self.parameters(), lr=1.0)
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
+        return {
+            'optimizer': optimizer,
+            'lr_scheduler': {
+                'scheduler': scheduler,
+                'interval': 'step'  # Update the scheduler at each step
+            }
+        }
 
     def one_step(self, image):
         x = self.encoder(image)
@@ -126,18 +131,19 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
     def training_step(self, batch, batch_idx):
         """Perform a single training step."""
-        input_image = torch.squeeze(batch.get('input_images'), dim=1)
+        input_image = torch.squeeze(batch.get('input_images'), dim=1)  # Drop dim 1 if there is only one view
+        # TODO: Clevr2dDataset returns a single channels-last 'input_image' instead;
+        # it would need to be permuted to [B, 3, h, w] before this step works on it.
         input_image = F.interpolate(input_image, size=128)
 
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
-        #input_image = input_image.permute(0, 2, 3, 1)
         loss_value = self.criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
+        # TODO: derive true_seg from the batch masks and pred_seg from the predicted
+        # masks (before the del above); both names are still undefined here.
+        # fg_ari = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
+        #                                      pred_seg.transpose(1, 2))
         self.log('train_mse', loss_value, on_epoch=True)
 
         return {'loss': loss_value}
 
     def validation_step(self, batch, batch_idx):
@@ -148,13 +154,12 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
-        #input_image = input_image.permute(0, 2, 3, 1)
         loss_value = self.criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
         psnr = mse2psnr(loss_value)
         self.log('val_mse', loss_value)
-        self.log('val_psnr', psnr)
-
+        self.log('val_psnr', psnr)  # key monitored by EarlyStopping and used in checkpoint filenames
+        self.print(f"Validation metrics, MSE: {loss_value} PSNR: {psnr}")
         return {'loss': loss_value, 'val_psnr': psnr.item()}
 
     def test_step(self, batch, batch_idx):
@@ -165,7 +170,6 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
 
         # Get the prediction of the model and compute the loss.
         preds = self.one_step(input_image)
         recon_combined, recons, masks, slots, _ = preds
-        #input_image = input_image.permute(0, 2, 3, 1)
         loss_value = self.criterion(recon_combined, input_image)
         del recons, masks, slots  # Unused.
         psnr = mse2psnr(loss_value)
diff --git a/outputs/visualisation_10.png b/outputs/visualisation_10.png
new file mode 100644
index 0000000000000000000000000000000000000000..b0c2093c157c26767aa27566db0e917f9b4d9184
Binary files /dev/null and b/outputs/visualisation_10.png differ
diff --git a/outputs/visualisation_2.png b/outputs/visualisation_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..97fa47ad3757362be19f32b08b647a3e2da265dd
Binary files /dev/null and b/outputs/visualisation_2.png differ
diff --git a/outputs/visualisation_3.png b/outputs/visualisation_3.png
new file mode 100644
index 0000000000000000000000000000000000000000..c5cf616373e92e744c64439f228bf0ea79ce6a0c
Binary files /dev/null and b/outputs/visualisation_3.png differ
diff --git a/outputs/visualisation_4.png b/outputs/visualisation_4.png
new file mode 100644
index 0000000000000000000000000000000000000000..94be3e3e76c2bc7184ba032ddc690506402f08b8
Binary files /dev/null and b/outputs/visualisation_4.png differ
diff --git a/outputs/visualisation_5.png b/outputs/visualisation_5.png
new file mode 100644
index 0000000000000000000000000000000000000000..a3f565c9df06d067e0c4ba19cfa19854e92c6636
Binary files /dev/null and b/outputs/visualisation_5.png differ
diff --git a/outputs/visualisation_6.png b/outputs/visualisation_6.png
new file mode 100644
index 0000000000000000000000000000000000000000..aa4283d62f1cefcc89d59aa2253ab3363c7b9f96
Binary files /dev/null and b/outputs/visualisation_6.png differ
diff --git a/outputs/visualisation_7.png b/outputs/visualisation_7.png
new file mode 100644
index 0000000000000000000000000000000000000000..8947de651e61d0ad5559c3cd120febd0c7a495d2
Binary files /dev/null and b/outputs/visualisation_7.png differ
diff --git a/quick_test.py b/quick_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3aeaa5f7229cedeed6c670def8191a4ba76ca49
--- /dev/null
+++ b/quick_test.py
@@ -0,0 +1,19 @@
+import yaml
+from osrt import data
+from torch.utils.data import DataLoader
+import matplotlib.pyplot as plt
+
+with open("runs/clevr/slot_att/config.yaml", 'r') as f:
+    cfg = yaml.load(f, Loader=yaml.CLoader)
+
+
+train_dataset = data.get_dataset('train', cfg['data'])
+train_loader = DataLoader(train_dataset, batch_size=2, num_workers=0, shuffle=True)
+
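+# Expected per-batch shapes, following Clevr2dDataset above:
+# val['input_image'] -> [2, 128, 128, 3], val['input_masks'] -> [2, 128, 128, 11]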
+for val in train_loader:
+    fig, axes = plt.subplots(2, 2)
+    axes[0][0].imshow(val['input_image'][0])
+    axes[0][1].imshow(val['input_masks'][0][:, :, 0])
+    axes[1][0].imshow(val['input_image'][1])
+    axes[1][1].imshow(val['input_masks'][1][:, :, 0])
+    plt.show()
\ No newline at end of file
diff --git a/runs/clevr/slot_att/config.yaml b/runs/clevr/slot_att/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..38fecb2fc2259efd23d100888a35cecb3a23e50a
--- /dev/null
+++ b/runs/clevr/slot_att/config.yaml
@@ -0,0 +1,21 @@
+data:
+  dataset: clevr2d
+model:
+  num_slots: 10
+  iters: 3
+  model_type: sa
+training:
+  num_workers: 2
+  num_gpus: 1
+  batch_size: 8
+  max_it: 333000000
+  warmup_it: 10000
+  decay_rate: 0.5
+  decay_it: 100000
+  validate_every: 5000
+  checkpoint_every: 1000
+  print_every: 10
+  visualize_every: 5000
+  backup_every: 25000
+  lr_warmup: 5000
+
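+# Note: the *_every / *_it values are counted in training steps; validate_every
+# feeds Trainer.val_check_interval and checkpoint_every feeds
+# ModelCheckpoint.every_n_train_steps (see train_sa.py).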
diff --git a/runs/clevr/slot_att/config_tsa.yaml b/runs/clevr/slot_att/config_tsa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d6f40290c2f0dab960c4219647614be3f9f111a1
--- /dev/null
+++ b/runs/clevr/slot_att/config_tsa.yaml
@@ -0,0 +1,21 @@
+data:
+  dataset: clevr2d
+model:
+  num_slots: 10
+  iters: 3
+  model_type: tsa
+training:
+  num_workers: 2
+  num_gpus: 1
+  batch_size: 32
+  max_it: 333000000
+  warmup_it: 10000
+  decay_rate: 0.5
+  decay_it: 100000
+  lr_warmup: 5000
+  print_every: 10
+  visualize_every: 5000
+  validate_every: 5000
+  checkpoint_every: 1000
+  backup_every: 25000
+
diff --git a/runs/clevr3d/osrt/config.yaml b/runs/clevr3d/osrt/config.yaml
deleted file mode 100644
index f11b301b7615bc2ce72702e67d8d7ab619c4f17a..0000000000000000000000000000000000000000
--- a/runs/clevr3d/osrt/config.yaml
+++ /dev/null
@@ -1,28 +0,0 @@
-
-data:
-  dataset: clevr3d
-  num_points: 2000
-  kwargs:
-    downsample: 1
-model:
-  encoder: osrt
-  encoder_kwargs:
-    pos_start_octave: -5
-  num_slots: 6
-  decoder: slot_mixer
-  decoder_kwargs:
-    pos_start_octave: -5
-training:
-  num_workers: 2
-  batch_size: 64
-  model_selection_metric: psnr
-  model_selection_mode: maximize
-  print_every: 10
-  visualize_every: 5000
-  validate_every: 5000
-  checkpoint_every: 1000
-  backup_every: 25000
-  max_it: 333000000
-  decay_it: 4000000
-  lr_warmup: 5000
-
diff --git a/runs/clevr3d/slot_att/config.yaml b/runs/clevr3d/slot_att/config.yaml
index 82c0f8e78d858da06e976697291a7189012fba51..4dceb22aa530ae0ffb6dfe5859ced635bf73c29c 100644
--- a/runs/clevr3d/slot_att/config.yaml
+++ b/runs/clevr3d/slot_att/config.yaml
@@ -12,4 +12,10 @@ training:
   warmup_it: 10000
   decay_rate: 0.5
   decay_it: 100000
+  validate_every: 5000
+  checkpoint_every: 1000
+  print_every: 10
+  visualize_every: 5000
+  backup_every: 25000
+  lr_warmup: 5000
 
diff --git a/runs/clevr3d/slot_att/config_tsa.yaml b/runs/clevr3d/slot_att/config_tsa.yaml
index f03fc97062dd583a0884019890f8fbd0778280b1..6255ab06539a5a7f63af67fc6ebd2cafaaf922e6 100644
--- a/runs/clevr3d/slot_att/config_tsa.yaml
+++ b/runs/clevr3d/slot_att/config_tsa.yaml
@@ -12,4 +12,10 @@ training:
   warmup_it: 10000
   decay_rate: 0.5
   decay_it: 100000
+  lr_warmup: 5000
+  print_every: 10
+  visualize_every: 5000
+  validate_every: 5000
+  checkpoint_every: 1000
+  backup_every: 25000
 
diff --git a/train_sa.py b/train_sa.py
index 79379b06054071619dc1fddbb3df246f51744b8c..86de00719beb8e2fb6b4db8600220c7aeba8acaf 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -1,5 +1,3 @@
-import datetime
-import time
 import torch
 import torch.nn as nn
 import argparse
@@ -12,12 +10,18 @@ from osrt.utils.common import mse2psnr
 
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
-from tqdm import tqdm
+import os
 
 import lightning as pl
 from lightning.pytorch.loggers.wandb import WandbLogger
 from lightning.pytorch.callbacks import ModelCheckpoint
+import warnings
+from lightning.pytorch.utilities.warnings import PossibleUserWarning
+from lightning.pytorch.callbacks.early_stopping import EarlyStopping
+
+# Ignore warnings that could be false positives, cf. https://lightning.ai/docs/pytorch/stable/advanced/speed.html
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
 
 def main():
     # Arguments
@@ -42,21 +46,21 @@ def main():
     num_slots = cfg["model"]["num_slots"]
     num_iterations = cfg["model"]["iters"]
     num_train_steps = cfg["training"]["max_it"]
-    warmup_steps = cfg["training"]["warmup_it"]
-    decay_rate = cfg["training"]["decay_rate"]
-    decay_steps = cfg["training"]["decay_it"]
     resolution = (128, 128)
 
+    print(f"Number of CPU Cores : {os.cpu_count()}")
+
     #### Create datasets
     train_dataset = data.get_dataset('train', cfg['data'])
     train_loader = DataLoader(
-        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
-        shuffle=True, worker_init_fn=data.worker_init_fn)
+        train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
+        shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
 
     val_dataset = data.get_dataset('val', cfg['data'])
     val_loader = DataLoader(
-        val_dataset, batch_size=batch_size, num_workers=1,
-        shuffle=True, worker_init_fn=data.worker_init_fn)
+        val_dataset, batch_size=batch_size, num_workers=8,
+        shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
 
     #### Create model
     model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
@@ -65,26 +69,35 @@ def main():
         checkpoint = torch.load(args.ckpt)
         model.load_state_dict(checkpoint['state_dict'])
 
     checkpoint_callback = ModelCheckpoint(
-        save_top_k=10, monitor="val_psnr", mode="max",
-        dirpath="./checkpoints" if cfg["model"]["model_type"] == "sa" else "./checkpoints_tsa",
-        filename="slot_att-clevr3d-{epoch:02d}-psnr{val_psnr:.2f}.pth",
+        dirpath="./checkpoints",
+        filename="ckpt-" + str(cfg["data"]["dataset"]) + "-" + str(cfg["model"]["model_type"]) + "-{epoch:02d}-psnr{val_psnr:.2f}",
+        save_weights_only=True,  # don't save optimizer states nor lr-scheduler, ...
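+        # Checkpoints are now saved on a fixed step cadence instead of the old
+        # monitor-based top-k saving; the psnr{val_psnr:.2f} filename pattern
+        # still requires val_psnr to be logged in validation_step (see osrt/model.py).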
+        every_n_train_steps=cfg["training"]["checkpoint_every"]
     )
 
-    trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple",
-                         default_root_dir="./logs", logger=WandbLogger(project="slot-att") if args.wandb else None,
-                         strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto", callbacks=[checkpoint_callback],
-                         log_every_n_steps=100, max_steps=num_train_steps)
+    early_stopping = EarlyStopping(monitor="val_psnr", mode="max")
+
+    trainer = pl.Trainer(accelerator="gpu",
+                         devices=num_gpus,
+                         profiler="simple",
+                         default_root_dir="./logs",
+                         logger=WandbLogger(project="slot-att") if args.wandb else None,
+                         strategy="ddp" if num_gpus > 1 else "auto",
+                         callbacks=[checkpoint_callback, early_stopping],
+                         log_every_n_steps=100,
+                         val_check_interval=cfg["training"]["validate_every"],
+                         max_steps=num_train_steps,
+                         enable_model_summary=True)
 
     trainer.fit(model, train_loader, val_loader)
 
     #### Create datasets
     vis_dataset = data.get_dataset('train', cfg['data'])
     vis_loader = DataLoader(
-        vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
+        vis_dataset, batch_size=1, num_workers=1,
         shuffle=True, worker_init_fn=data.worker_init_fn)
 
     device = model.device
@@ -100,21 +113,4 @@ def main():
         visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=0, save_file=True)
 
 if __name__ == "__main__":
-    main()
-
-
-#print(f"[TRAIN] Epoch : {epoch} || Step: {global_step}, Loss: {total_loss}, Time: {datetime.timedelta(seconds=time.time() - start)}")
-
-"""
-
-if not epoch % cfg["training"]["checkpoint_every"]:
-    # Save the checkpoint of the model.
-    ckpt['global_step'] = global_step
-    ckpt['model_state_dict'] = model.state_dict()
-    torch.save(ckpt, args.ckpt + '/ckpt_' + str(global_step) + '.pth')
-    print(f"Saved checkpoint: {args.ckpt + '/ckpt_' + str(global_step) + '.pth'}")
-
-    # We visualize some test data
-    if not epoch % cfg["training"]["visualize_every"]:
-
-"""
\ No newline at end of file
+    main()
\ No newline at end of file
diff --git a/visualize_sa.py b/visualize_sa.py
index 2e4031d46c69a38c0a82a1451ac90f6f7b52be18..b9e78d372f5cdacbefad2190120a5a99db0c4945 100644
--- a/visualize_sa.py
+++ b/visualize_sa.py
@@ -26,7 +26,7 @@ def main():
     parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
     parser.add_argument('--seed', type=int, default=0, help='Random seed.')
     parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
-    parser.add_argument('--output', type=str, default="./outputs", help='Folder in which to save images')
+    parser.add_argument('--output', type=str, default="./outputs/", help='Folder in which to save images')
     parser.add_argument('--step', type=int, default=".", help='Step of the model')
 
     args = parser.parse_args()
@@ -44,7 +44,7 @@ def main():
     resolution = (128, 128)
 
     #### Create datasets
-    vis_dataset = data.get_dataset('train', cfg['data'])
+    vis_dataset = data.get_dataset('val', cfg['data'])
     vis_loader = DataLoader(
         vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
         shuffle=True, worker_init_fn=data.worker_init_fn)