Commit 2b07732f authored by Alexandre Chapin

Add CLEVR 2D dataset

parent 70224686
Showing 334 additions and 66 deletions
from osrt.data.core import get_dataset, worker_init_fn
from osrt.data.nmr import NMRDataset
from osrt.data.multishapenet import MultishapenetDataset
from osrt.data.obsurf import Clevr3dDataset
from osrt.data.obsurf import Clevr3dDataset, Clevr2dDataset
@@ -43,6 +43,8 @@ def get_dataset(mode, cfg, max_len=None, full_scale=False):
dataset = data.Clevr3dDataset(dataset_folder, mode, points_per_item=points_per_item,
shapenet=False, max_len=max_len, full_scale=full_scale,
canonical_view=canonical_view, **kwargs)
elif dataset_type == 'clevr2d':
dataset = data.Clevr2dDataset(dataset_folder, mode, **kwargs)
elif dataset_type == 'obsurf_msn':
dataset = data.Clevr3dDataset(dataset_folder, mode, points_per_item=points_per_item,
shapenet=True, max_len=max_len, full_scale=full_scale,
...
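For orientation, a minimal sketch of how the new branch is exercised. Only the 'dataset' and 'kwargs' keys appear in the configs further down; the 'path' key and folder name here are assumptions:

from osrt import data

cfg_data = {
    'dataset': 'clevr2d',    # routes get_dataset to the new Clevr2dDataset branch
    'path': 'data/clevr2d',  # hypothetical key/folder; Clevr2dDataset remaps it to the clevr3d images
    'kwargs': {},            # forwarded verbatim to the dataset constructor
}
train_set = data.get_dataset('train', cfg_data)
sample = train_set[0]
print(sample['input_image'].shape, sample['input_masks'].shape)  # expect (128, 128, 3) and (128, 128, 11)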
import numpy as np
import imageio
import yaml
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import torch
from PIL import Image
import os
from osrt.utils.nerf import get_camera_rays, get_extrinsic, transform_points
import torch.nn.functional as F
def downsample(x, num_steps=1):
@@ -14,6 +18,24 @@ def downsample(x, num_steps=1):
stride = 2**num_steps
return x[stride//2::stride, stride//2::stride]
def crop_center(image, crop_size=192):
height, width = image.shape[:2]
center_x = width // 2
center_y = height // 2
crop_size_half = crop_size // 2
# Calculate the top-left corner coordinates of the crop
crop_x1 = center_x - crop_size_half
crop_y1 = center_y - crop_size_half
# Calculate the bottom-right corner coordinates of the crop
crop_x2 = center_x + crop_size_half
crop_y2 = center_y + crop_size_half
# Crop the image
return image[crop_y1:crop_y2, crop_x1:crop_x2]
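A quick sanity check of the two helpers above, assuming the 240x320 CLEVR frames used throughout this file:

import numpy as np

frame = np.zeros((240, 320, 3), dtype=np.float32)  # CLEVR-sized RGB frame
print(crop_center(frame, crop_size=192).shape)     # (192, 192, 3), centred crop
print(downsample(frame, num_steps=1).shape)        # (120, 160, 3), stride-2 sampling offset to pixel centres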
class Clevr3dDataset(Dataset):
def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
@@ -93,6 +115,7 @@ class Clevr3dDataset(Dataset):
masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
input_image = downsample(imgs[view_idx], num_steps=self.downsample)
input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)
@@ -150,6 +173,9 @@ class Clevr3dDataset(Dataset):
target_pixels = all_pixels
target_masks = all_masks
print(f"Final input_image : {input_images.shape} and type {type(input_images)}")
print(f"Final input_masks : {input_masks.shape} and type {type(input_masks)}")
result = {
'input_images': input_images, # [1, 3, h, w]
'input_camera_pos': input_camera_pos, # [1, 3]
@@ -168,3 +194,75 @@ class Clevr3dDataset(Dataset):
return result
class Clevr2dDataset(Dataset):
def __init__(self, path, mode, max_objects=6):
""" Loads the dataset used in the ObSuRF paper.
They may be downloaded at: https://stelzner.github.io/obsurf
Args:
path (str): Path to dataset.
mode (str): 'train', 'val', or 'test'.
full_scale (bool): Return all available target points, instead of sampling.
max_objects (int): Load only scenes with at most this many objects.
shapenet (bool): Load ObSuRF's MultiShapeNet dataset, instead of CLEVR3D.
downsample (int): Downsample height and width of input image by a factor of 2**downsample
"""
self.path = path.replace("clevr2d", "clevr3d")  # the 2D set reads images straight from the CLEVR3D folder
self.mode = mode
self.max_objects = max_objects
self.max_num_entities = 11
self.start_idx, self.end_idx = {'train': (0, 70000),
'val': (70000, 75000),
'test': (85000, 100000)}[mode]
self.metadata = np.load(os.path.join(self.path, 'metadata.npz'))
self.metadata = {k: v for k, v in self.metadata.items()}
num_objs = (self.metadata['shape'][self.start_idx:self.end_idx] > 0).sum(1)
self.idxs = np.arange(self.start_idx, self.end_idx)[num_objs <= max_objects]
dataset_name = 'CLEVR'
print(f'Initialized {dataset_name} {mode} set, {len(self.idxs)} examples')
print(self.idxs)
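The indexing above keeps only scenes whose object count fits the cap; a toy illustration of the same pattern (the 'shape' array here is made up):

import numpy as np

shape = np.array([[1, 2, 0], [3, 1, 2], [0, 0, 0]])  # hypothetical 'shape' metadata, 0 = no object
start_idx, end_idx, max_objects = 0, 3, 2
num_objs = (shape[start_idx:end_idx] > 0).sum(1)     # -> [2, 3, 0]
idxs = np.arange(start_idx, end_idx)[num_objs <= max_objects]
print(idxs)                                          # [0 2]: scene 1 has too many objects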
def __len__(self):
return len(self.idxs)
def __getitem__(self, idx, noisy=True):
scene_idx = idx % len(self.idxs)
scene_idx = self.idxs[scene_idx]
img_path = os.path.join(self.path, 'images', f'img_{scene_idx}_0.png')
img = np.asarray(imageio.imread(img_path))
img = img[..., :3].astype(np.float32) / 255
input_image = crop_center(img, 192)
input_image = F.interpolate(torch.tensor(input_image).permute(2, 0, 1).unsqueeze(0), size=128)
input_image = input_image.squeeze(0).permute(1, 2, 0)
mask_path = os.path.join(self.path, 'masks', f'masks_{scene_idx}_0.png')
mask_idxs = imageio.imread(mask_path)
masks = np.zeros((240, 320, self.max_num_entities), dtype=np.uint8)
np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
metadata = {k: v[scene_idx] for (k, v) in self.metadata.items()}
input_masks = crop_center(torch.tensor(masks), 192)
input_masks = F.interpolate(input_masks.permute(2, 0, 1).unsqueeze(0), size=128)
input_masks = input_masks.squeeze(0).permute(1, 2, 0)
result = {
'input_image': input_image, # [h, w, 3]
'input_masks': input_masks, # [h, w, self.max_num_entities]
'sceneid': idx, # int
}
return result
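Both dataset classes one-hot the per-pixel instance ids through np.put_along_axis; a toy sketch of that encoding (2x2 image, 3 entities) for clarity:

import numpy as np

mask_idxs = np.array([[0, 1], [2, 1]])       # per-pixel entity id
masks = np.zeros((2, 2, 3), dtype=np.uint8)  # (h, w, max_num_entities)
np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
# masks[y, x, k] == 1 exactly when pixel (y, x) belongs to entity k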
@@ -134,7 +134,7 @@ def downsample(x, num_steps=1):
return x[stride//2::stride, stride//2::stride]
class YCBVideo(Dataset):
class YCBVideo3D(Dataset):
def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
max_len=None, full_scale=False, shapenet=False, downsample=None):
""" Loads the YCB-Video dataset that we have adapted.
@@ -274,4 +274,99 @@ class YCBVideo(Dataset):
return result
class YCBVideo2D(Dataset):
def __init__(self, path, mode, downsample=None):
""" Loads the YCB-Video dataset that we have adapted.
Args:
path (str): Path to dataset.
mode (str): 'train', 'val', or 'test'.
downsample (int): Downsample height and width of input image by a factor of 2**downsample
"""
self.path = path
self.mode = mode
self.downsample = downsample
self.max_num_entities = 21 # max number of objects in a scene
# TODO : set the right number here
dataset_name = 'YCB-Video'
print(f'Initialized {dataset_name} {mode} set')
def __len__(self):
return len(self.idxs)  # TODO: self.idxs is never assigned in __init__ yet
def __getitem__(self, idx):
# TODO: scene_idx and self.num_views are used below but never defined in this draft
imgs = [np.asarray(imageio.imread(
os.path.join(self.path, 'images', f'img_{scene_idx}_{v}.png')))
for v in range(self.num_views)]
imgs = [img[..., :3].astype(np.float32) / 255 for img in imgs]
mask_idxs = [imageio.imread(os.path.join(self.path, 'masks', f'masks_{scene_idx}_{v}.png'))
for v in range(self.num_views)]
masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
input_image = downsample(imgs[view_idx], num_steps=self.downsample)
input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)
all_rays = []
# TODO : find a way to get the camera poses
all_camera_pos = self.metadata['camera_pos'][:self.num_views].astype(np.float32)
all_camera_rot = self.metadata['camera_rot'][:self.num_views].astype(np.float32)
for i in range(self.num_views):
cur_rays = get_camera_rays(all_camera_pos[i], all_camera_rot[i], noisy=False) # TODO : adapt function
all_rays.append(cur_rays)
all_rays = np.stack(all_rays, 0).astype(np.float32)
input_camera_pos = all_camera_pos[view_idx]
if self.canonical:
track_point = np.zeros_like(input_camera_pos) # All cameras are pointed at the origin
canonical_extrinsic = get_extrinsic(input_camera_pos, track_point=track_point) # TODO : adapt function
canonical_extrinsic = canonical_extrinsic.astype(np.float32)
all_rays = transform_points(all_rays, canonical_extrinsic, translate=False) # TODO : adapt function
all_camera_pos = transform_points(all_camera_pos, canonical_extrinsic)
input_camera_pos = all_camera_pos[view_idx]
input_rays = all_rays[view_idx]
input_rays = downsample(input_rays, num_steps=self.downsample)
input_rays = np.expand_dims(input_rays, 0)
input_masks = masks[view_idx]
input_masks = downsample(input_masks, num_steps=self.downsample)
input_masks = np.expand_dims(input_masks, 0)
input_camera_pos = np.expand_dims(input_camera_pos, 0)
all_pixels = np.reshape(np.stack(imgs, 0), (self.num_views * 240 * 320, 3))
all_rays = np.reshape(all_rays, (self.num_views * 240 * 320, 3))
all_camera_pos = np.tile(np.expand_dims(all_camera_pos, 1), (1, 240 * 320, 1))
all_camera_pos = np.reshape(all_camera_pos, (self.num_views * 240 * 320, 3))
all_masks = np.reshape(masks, (self.num_views * 240 * 320, self.max_num_entities))
target_camera_pos = all_camera_pos
target_pixels = all_pixels
target_masks = all_masks
result = {
'input_images': input_images, # [1, 3, h, w]
'input_camera_pos': input_camera_pos, # [1, 3]
'input_rays': input_rays, # [1, h, w, 3]
'input_masks': input_masks, # [1, h, w, self.max_num_entities]
'target_pixels': target_pixels, # [p, 3]
'target_camera_pos': target_camera_pos, # [p, 3]
'target_masks': target_masks, # [p, self.max_num_entities]
'sceneid': idx, # int
}
if self.canonical:
result['transform'] = canonical_extrinsic # [3, 4] (optional)
return result
@@ -4,6 +4,7 @@ from torch import nn
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import numpy as np
@@ -11,7 +12,7 @@ from osrt.encoder import OSRTEncoder, ImprovedSRTEncoder, FeatureMasking
from osrt.decoder import SlotMixerDecoder, SpatialBroadcastDecoder, ImprovedSRTDecoder
from osrt.layers import SlotAttention, TransformerSlotAttention, Encoder, Decoder
import osrt.layers as layers
from osrt.utils.common import mse2psnr
from osrt.utils.common import mse2psnr, compute_adjusted_rand_index
import lightning as pl
@@ -69,6 +70,11 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
self.encoder = Encoder()
self.decoder = Decoder()
self.peak_it = cfg["training"]["warmup_it"]
self.peak_lr = 1e-4
self.decay_rate = 0.16
self.decay_it = cfg["training"]["decay_it"]
model_type = cfg['model']['model_type']
if model_type == 'sa':
self.slot_attention = SlotAttention(
@@ -87,24 +93,23 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
depth=self.num_iterations) # in a way, the depth of the transformer corresponds to the number of iterations in the original model
def forward(self, image):
x = self.encoder(image)
slots, attn_logits, attn_slotwise = self.slot_attention(x.flatten(start_dim = 1, end_dim = 2))
x = slots.reshape(-1, 1, 1, slots.shape[-1]).expand(-1, *self.decoder.decoder_initial_size, -1)
x = self.decoder(x)
x = F.interpolate(x, image.shape[-2:], mode='bilinear')
x = x.unflatten(0, (len(image), len(x) // len(image)))
recons, masks = x.split((3, 1), dim = 2)
masks = masks.softmax(dim = 1)
recon_combined = (recons * masks).sum(dim = 1)
return recon_combined, recons, masks, slots, attn_slotwise.unsqueeze(-2).unflatten(-1, x.shape[-2:]) if attn_slotwise is not None else None
return self.one_step(image)
def configure_optimizers(self) -> Any:
optimizer = optim.Adam(self.parameters(), lr=1e-3, eps=1e-08)
return optimizer
def lr_func(it):
if it < self.peak_it: # Warmup period
return self.peak_lr * (it / self.peak_it)
it_since_peak = it - self.peak_it
return self.peak_lr * (self.decay_rate ** (it_since_peak / self.decay_it))
optimizer = optim.Adam(self.parameters(), lr=1.0)  # base lr 1.0: lr_func already returns the absolute rate, and LambdaLR multiplies by the base lr
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)  # LambdaLR lives in torch.optim.lr_scheduler, not torch.optim
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'interval': 'step' # Update the scheduler at each step
}
}
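LambdaLR multiplies the optimizer's base lr by the factor returned from lr_func, hence the base lr of 1.0 above (a base lr of 0 would pin the product at zero). A standalone sketch of the resulting warmup/decay schedule, reusing the constants set in __init__:

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

peak_it, peak_lr, decay_rate, decay_it = 10000, 1e-4, 0.16, 100000

def lr_func(it):
    if it < peak_it:  # linear warmup to peak_lr
        return peak_lr * (it / peak_it)
    return peak_lr * decay_rate ** ((it - peak_it) / decay_it)  # exponential decay past the peak

opt = optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
sched = lr_scheduler.LambdaLR(opt, lr_lambda=lr_func)
for _ in range(3):
    opt.step()
    sched.step()
print(sched.get_last_lr())  # ~[3e-8] after 3 warmup steps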
def one_step(self, image):
x = self.encoder(image)
@@ -126,18 +131,19 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
def training_step(self, batch, batch_idx):
"""Perform a single training step."""
input_image = torch.squeeze(batch.get('input_images'), dim=1)
input_image = torch.squeeze(batch.get('input_images'), dim=1) # Delete dim 1 if only one view
input_image = F.interpolate(input_image, size=128)
# Get the prediction of the model and compute the loss.
preds = self.one_step(input_image)
recon_combined, recons, masks, slots, _ = preds
#input_image = input_image.permute(0, 2, 3, 1)
loss_value = self.criterion(recon_combined, input_image)
del recons, masks, slots # Unused.
fg_ari = compute_adjusted_rand_index(true_seg.transpose(1, 2)[:, 1:],
pred_seg.transpose(1, 2))  # TODO: true_seg / pred_seg are never defined here, and fg_ari is never logged
self.log('train_mse', loss_value, on_epoch=True)
return {'loss': loss_value}
def validation_step(self, batch, batch_idx):
@@ -148,13 +154,12 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
# Get the prediction of the model and compute the loss.
preds = self.one_step(input_image)
recon_combined, recons, masks, slots, _ = preds
#input_image = input_image.permute(0, 2, 3, 1)
loss_value = self.criterion(recon_combined, input_image)
del recons, masks, slots # Unused.
psnr = mse2psnr(loss_value)
self.log('val_mse', loss_value)
self.log('val_psnr', psnr)
self.print(f"Validation metrics, MSE: {loss_value} PSNR: {psnr}")
return {'loss': loss_value, 'val_psnr': psnr.item()}
def test_step(self, batch, batch_idx):
@@ -165,7 +170,6 @@ class LitSlotAttentionAutoEncoder(pl.LightningModule):
# Get the prediction of the model and compute the loss.
preds = self.one_step(input_image)
recon_combined, recons, masks, slots, _ = preds
#input_image = input_image.permute(0, 2, 3, 1)
loss_value = self.criterion(recon_combined, input_image)
del recons, masks, slots # Unused.
psnr = mse2psnr(loss_value)
...
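mse2psnr comes from osrt.utils.common and is not shown in this diff; for images scaled to [0, 1] the usual conversion is the following sketch (an assumption about the helper, not its actual source):

import torch

def mse2psnr(mse):
    # PSNR = -10 * log10(MSE); valid when the peak signal value is 1.0
    return -10.0 * torch.log10(mse)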
New binary files under outputs/:
outputs/visualisation_2.png (48.2 KiB), outputs/visualisation_3.png (48 KiB),
outputs/visualisation_4.png (50.8 KiB), outputs/visualisation_5.png (54.5 KiB),
outputs/visualisation_6.png (56.2 KiB), outputs/visualisation_7.png (64.6 KiB),
outputs/visualisation_10.png (66.4 KiB)
import yaml
from osrt import data
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
with open("runs/clevr/slot_att/config.yaml", 'r') as f:
cfg = yaml.load(f, Loader=yaml.CLoader)
train_dataset = data.get_dataset('train', cfg['data'])
train_loader = DataLoader(train_dataset, batch_size=2, num_workers=0, shuffle=True)
for val in train_loader:
fig, axes = plt.subplots(2, 2)
axes[0][0].imshow(val['input_image'][0])
axes[0][1].imshow(val['input_masks'][0][:, :, 0])
axes[1][0].imshow(val['input_image'][1])
axes[1][1].imshow(val['input_masks'][1][:, :, 0])
plt.show()
\ No newline at end of file
data:
dataset: clevr3d
num_points: 2000
kwargs:
downsample: 1
dataset: clevr2d
model:
encoder: osrt
encoder_kwargs:
pos_start_octave: -5
num_slots: 6
decoder: slot_mixer
decoder_kwargs:
pos_start_octave: -5
num_slots: 10
iters: 3
model_type: sa
training:
num_workers: 2
batch_size: 64
model_selection_metric: psnr
model_selection_mode: maximize
print_every: 10
visualize_every: 5000
num_gpus: 1
batch_size: 8
max_it: 333000000
warmup_it: 10000
decay_rate: 0.5
decay_it: 100000
validate_every: 5000
checkpoint_every: 1000
print_every: 10
visualize_every: 5000
backup_every: 25000
max_it: 333000000
decay_it: 4000000
lr_warmup: 5000
data:
dataset: clevr2d
model:
num_slots: 10
iters: 3
model_type: tsa
training:
num_workers: 2
num_gpus: 1
batch_size: 32
max_it: 333000000
warmup_it: 10000
decay_rate: 0.5
decay_it: 100000
lr_warmup: 5000
print_every: 10
visualize_every: 5000
validate_every: 5000
checkpoint_every: 1000
backup_every: 25000
@@ -12,4 +12,10 @@ training:
warmup_it: 10000
decay_rate: 0.5
decay_it: 100000
validate_every: 5000
checkpoint_every: 1000
print_every: 10
visualize_every: 5000
backup_every: 25000
lr_warmup: 5000
@@ -12,4 +12,10 @@ training:
warmup_it: 10000
decay_rate: 0.5
decay_it: 100000
lr_warmup: 5000
print_every: 10
visualize_every: 5000
validate_every: 5000
checkpoint_every: 1000
backup_every: 25000
import datetime
import time
import torch
import torch.nn as nn
import argparse
@@ -12,12 +10,18 @@ from osrt.utils.common import mse2psnr
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
import os
import lightning as pl
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint
import warnings
from lightning.pytorch.utilities.warnings import PossibleUserWarning
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
# Ignore all warnings that could be false positives : cf https://lightning.ai/docs/pytorch/stable/advanced/speed.html
warnings.filterwarnings("ignore", category=PossibleUserWarning)
def main():
# Arguments
@@ -42,21 +46,21 @@ def main():
num_slots = cfg["model"]["num_slots"]
num_iterations = cfg["model"]["iters"]
num_train_steps = cfg["training"]["max_it"]
warmup_steps = cfg["training"]["warmup_it"]
decay_rate = cfg["training"]["decay_rate"]
decay_steps = cfg["training"]["decay_it"]
resolution = (128, 128)
print(f"Number of CPU Cores : {os.cpu_count()}")
#### Create datasets
train_dataset = data.get_dataset('train', cfg['data'])
val_every = val_every // len(train_dataset)
train_loader = DataLoader(
train_dataset, batch_size=batch_size, num_workers=cfg["training"]["num_workers"],
shuffle=True, worker_init_fn=data.worker_init_fn)
train_dataset, batch_size=batch_size, num_workers=max(cfg["training"]["num_workers"] - 9, 0),  # clamp so small configs don't yield a negative worker count
shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
val_dataset = data.get_dataset('val', cfg['data'])
val_loader = DataLoader(
val_dataset, batch_size=batch_size, num_workers=1,
shuffle=True, worker_init_fn=data.worker_init_fn)
val_dataset, batch_size=batch_size, num_workers=8,
shuffle=True, worker_init_fn=data.worker_init_fn, pin_memory=True)
#### Create model
model = LitSlotAttentionAutoEncoder(resolution, num_slots, num_iterations, cfg=cfg)
@@ -65,26 +69,35 @@ def main():
checkpoint = torch.load(args.ckpt)
model.load_state_dict(checkpoint['state_dict'])
checkpoint_callback = ModelCheckpoint(
save_top_k=10,
monitor="val_psnr",
mode="max",
dirpath="./checkpoints" if cfg["model"]["model_type"] == "sa" else "./checkpoints_tsa",
filename="slot_att-clevr3d-{epoch:02d}-psnr{val_psnr:.2f}.pth",
dirpath="./checkpoints",
filename="ckpt-" + str(cfg["data"]["dataset"]) + "-" + str(cfg["model"]["model_type"]) +"-{epoch:02d}-psnr{val_psnr:.2f}",
save_weights_only=True, # don't save optimizer states nor lr-scheduler, ...
every_n_train_steps=cfg["training"]["checkpoint_every"]
)
trainer = pl.Trainer(accelerator="gpu", devices=num_gpus, profiler="simple",
default_root_dir="./logs", logger=WandbLogger(project="slot-att") if args.wandb else None,
strategy="ddp_find_unused_parameters_true" if num_gpus > 1 else "auto", callbacks=[checkpoint_callback],
log_every_n_steps=100, max_steps=num_train_steps)
early_stopping = EarlyStopping(monitor="val_psnr", mode="max")
trainer = pl.Trainer(accelerator="gpu",
devices=num_gpus,
profiler="simple",
default_root_dir="./logs",
logger=WandbLogger(project="slot-att") if args.wandb else None,
strategy="ddp" if num_gpus > 1 else "auto",
callbacks=[checkpoint_callback, early_stopping],
log_every_n_steps=100,
val_check_interval=cfg["training"]["validate_every"],
max_steps=num_train_steps,
enable_model_summary=True)
trainer.fit(model, train_loader, val_loader)
#### Create datasets
vis_dataset = data.get_dataset('train', cfg['data'])
vis_loader = DataLoader(
vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
vis_dataset, batch_size=1, num_workers=1,
shuffle=True, worker_init_fn=data.worker_init_fn)
device = model.device
@@ -100,21 +113,4 @@ def main():
visualize_slot_attention(num_slots, image, recon_combined, recons, masks, folder_save=args.ckpt, step=0, save_file=True)
if __name__ == "__main__":
main()
#print(f"[TRAIN] Epoch : {epoch} || Step: {global_step}, Loss: {total_loss}, Time: {datetime.timedelta(seconds=time.time() - start)}")
"""
if not epoch % cfg["training"]["checkpoint_every"]:
# Save the checkpoint of the model.
ckpt['global_step'] = global_step
ckpt['model_state_dict'] = model.state_dict()
torch.save(ckpt, args.ckpt + '/ckpt_' + str(global_step) + '.pth')
print(f"Saved checkpoint: {args.ckpt + '/ckpt_' + str(global_step) + '.pth'}")
# We visualize some test data
if not epoch % cfg["training"]["visualize_every"]:
"""
\ No newline at end of file
main()
\ No newline at end of file
@@ -26,7 +26,7 @@ def main():
parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')
parser.add_argument('--output', type=str, default="./outputs", help='Folder in which to save images')
parser.add_argument('--output', type=str, default="./outputs/", help='Folder in which to save images')
parser.add_argument('--step', type=int, default=0, help='Step of the model')
args = parser.parse_args()
@@ -44,7 +44,7 @@ def main():
resolution = (128, 128)
#### Create datasets
vis_dataset = data.get_dataset('train', cfg['data'])
vis_dataset = data.get_dataset('val', cfg['data'])
vis_loader = DataLoader(
vis_dataset, batch_size=1, num_workers=cfg["training"]["num_workers"],
shuffle=True, worker_init_fn=data.worker_init_fn)
...