diff --git a/Nonevisualisation_0.png b/Nonevisualisation_0.png
deleted file mode 100644
index 59df42cecd44a6e04b3cc578d596a74d54a5cc10..0000000000000000000000000000000000000000
Binary files a/Nonevisualisation_0.png and /dev/null differ
diff --git a/evaluate_sa.py b/evaluate_sa.py
index e5e06eb7ec4990daa29ba21aead49de35b00f270..bdd4e629011f9513dab13dadd02daa04e06ece59 100644
--- a/evaluate_sa.py
+++ b/evaluate_sa.py
@@ -19,7 +19,6 @@ def main():
         description='Train a 3D scene representation model.'
     )
     parser.add_argument('config', type=str, help="Where to save the checkpoints.")
-    parser.add_argument('--wandb', action='store_true', help='Log run to Weights and Biases.')
     parser.add_argument('--seed', type=int, default=0, help='Random seed.')
     parser.add_argument('--ckpt', type=str, default=".", help='Model checkpoint path')

diff --git a/osrt/data/__init__.py b/osrt/data/__init__.py
index 96d64f2bae157df08fdac791a9e16baf5eedc92a..84a842d13a91f31f58c3e6e93526afa784f28f4b 100644
--- a/osrt/data/__init__.py
+++ b/osrt/data/__init__.py
@@ -2,4 +2,5 @@ from osrt.data.core import get_dataset, worker_init_fn
 from osrt.data.nmr import NMRDataset
 from osrt.data.multishapenet import MultishapenetDataset
 from osrt.data.obsurf import Clevr3dDataset, Clevr2dDataset
+from osrt.data.ycbv import YCBVideo2D

diff --git a/osrt/data/core.py b/osrt/data/core.py
index 3b8886ebe8c87662e814f8908a085cd560fae565..5a8bb82fca7244cf1490205e516dbd653a7412cd 100644
--- a/osrt/data/core.py
+++ b/osrt/data/core.py
@@ -24,6 +24,7 @@ def get_dataset(mode, cfg, max_len=None, full_scale=False):

     if 'kwargs' in cfg:
         kwargs = cfg['kwargs']
+        print(kwargs)
     else:
         kwargs = dict()

@@ -45,6 +46,8 @@ def get_dataset(mode, cfg, max_len=None, full_scale=False):
                                       canonical_view=canonical_view, **kwargs)
     elif dataset_type == 'clevr2d':
         dataset = data.Clevr2dDataset(dataset_folder, mode, **kwargs)
+    elif dataset_type == 'ycb2d':
+        dataset = data.YCBVideo2D(dataset_folder, mode, **kwargs)
     elif dataset_type == 'obsurf_msn':
         dataset = data.Clevr3dDataset(dataset_folder, mode, points_per_item=points_per_item,
                                       shapenet=True, max_len=max_len, full_scale=full_scale,
@@ -62,3 +65,28 @@ def worker_init_fn(worker_id):
     base_seed = int.from_bytes(random_data, byteorder="big")
     np.random.seed(base_seed + worker_id)

+
+def downsample(x, num_steps=1):
+    if num_steps is None or num_steps < 1:
+        return x
+    stride = 2**num_steps
+    return x[stride//2::stride, stride//2::stride]
+
+def crop_center(image, crop_size=192):
+    height, width = image.shape[:2]
+
+    center_x = width // 2
+    center_y = height // 2
+
+    crop_size_half = crop_size // 2
+
+    # Calculate the top-left corner coordinates of the crop
+    crop_x1 = center_x - crop_size_half
+    crop_y1 = center_y - crop_size_half
+
+    # Calculate the bottom-right corner coordinates of the crop
+    crop_x2 = center_x + crop_size_half
+    crop_y2 = center_y + crop_size_half
+
+    # Crop the image
+    return image[crop_y1:crop_y2, crop_x1:crop_x2]
\ No newline at end of file
diff --git a/osrt/data/obsurf.py b/osrt/data/obsurf.py
index 074ac76600a9597bb6f9ab1f00b208917479e969..bb43841cb8ad27013d7a6e5ae4deafb7c4ca5183 100644
--- a/osrt/data/obsurf.py
+++ b/osrt/data/obsurf.py
@@ -10,32 +10,7 @@ import os
 from osrt.utils.nerf import get_camera_rays, get_extrinsic, transform_points

 import torch.nn.functional as F
-
-
-def downsample(x, num_steps=1):
-    if num_steps is None or num_steps < 1:
-        return x
-    stride = 2**num_steps
-    return x[stride//2::stride, stride//2::stride]
-
-def crop_center(image, crop_size=192):
-    height, width = image.shape[:2]
-
-    center_x = width // 2
-    center_y = height // 2
-
-    crop_size_half = crop_size // 2
-
-    # Calculate the top-left corner coordinates of the crop
-    crop_x1 = center_x - crop_size_half
-    crop_y1 = center_y - crop_size_half
-
-    # Calculate the bottom-right corner coordinates of the crop
-    crop_x2 = center_x + crop_size_half
-    crop_y2 = center_y + crop_size_half
-
-    # Crop the image
-    return image[crop_y1:crop_y2, crop_x1:crop_x2]
+from .core import downsample, crop_center

 class Clevr3dDataset(Dataset):
     def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
@@ -250,8 +225,6 @@ class Clevr2dDataset(Dataset):

         np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)

-        metadata = {k: v[scene_idx] for (k, v) in self.metadata.items()}
-
         input_masks = crop_center(torch.tensor(masks), 192)
         input_masks = F.interpolate(input_masks.permute(2, 0, 1).unsqueeze(0), size=128)
         input_masks = input_masks.squeeze(0).permute(1, 2, 0)
diff --git a/osrt/data/ycb-video.py b/osrt/data/ycb-video.py
deleted file mode 100644
index 615a13f132464ee50c9ca2f324439e79c3e52320..0000000000000000000000000000000000000000
--- a/osrt/data/ycb-video.py
+++ /dev/null
@@ -1,372 +0,0 @@
-import numpy as np
-import imageio
-import yaml
-from torch.utils.data import Dataset
-
-import os
-
-def get_extrinsic(camera_pos, rays=None, track_point=None, fourxfour=True):
-    """ Returns extrinsic matrix mapping world to camera coordinates.
-    Args:
-        camera_pos (np array [3]): Camera position.
-        track_point (np array [3]): Point on which the camera is fixated.
-        rays (np array [h, w, 3]): Rays eminating from the camera. Used to determine track_point
-            if it's not given.
-        fourxfour (bool): If true, a 4x4 matrix for homogeneous 3D coordinates is returned.
-            Otherwise, a 3x4 matrix is returned.
-    Returns:
-        extrinsic camera matrix (np array [4, 4] or [3, 4])
-    """
-    if track_point is None:
-        h, w, _ = rays.shape
-        if h % 2 == 0:
-            center_rays = rays[h//2 - 1:h//2 + 1]
-        else:
-            center_rays = rays[h//2:h//2+1]
-
-        if w % 2 == 0:
-            center_rays = rays[:, w//2 - 1:w//2 + 1]
-        else:
-            center_rays = rays[:, w//2:w//2+1]
-
-        camera_z = center_rays.mean((0, 1))
-    else:
-        camera_z = track_point - camera_pos
-
-    camera_z = camera_z / np.linalg.norm(camera_z, axis=-1, keepdims=True)
-
-    # We assume that (a) the z-axis is vertical, and that
-    # (b) the camera's horizontal, the x-axis, is orthogonal to the vertical, i.e.,
-    # the camera is in a level position.
-    vertical = np.array((0., 0., 1.))
-
-    camera_x = np.cross(camera_z, vertical)
-    camera_x = camera_x / np.linalg.norm(camera_x, axis=-1, keepdims=True)
-    camera_y = np.cross(camera_z, camera_x)
-
-    camera_matrix = np.stack((camera_x, camera_y, camera_z), -2)
-    translation = -np.einsum('...ij,...j->...i', camera_matrix, camera_pos)
-    camera_matrix = np.concatenate((camera_matrix, np.expand_dims(translation, -1)), -1)
-
-    if fourxfour:
-        filler = np.array([[0., 0., 0., 1.]])
-        camera_matrix = np.concatenate((camera_matrix, filler), 0)
-    return camera_matrix
-
-
-def transform_points(points, transform, translate=True):
-    """ Apply linear transform to a np array of points.
-    Args:
-        points (np array [..., 3]): Points to transform.
-        transform (np array [3, 4] or [4, 4]): Linear map.
-        translate (bool): If false, do not apply translation component of transform.
-    Returns:
-        transformed points (np array [..., 3])
-    """
-    # Append ones or zeros to get homogenous coordinates
-    if translate:
-        constant_term = np.ones_like(points[..., :1])
-    else:
-        constant_term = np.zeros_like(points[..., :1])
-    points = np.concatenate((points, constant_term), axis=-1)
-
-    points = np.einsum('nm,...m->...n', transform, points)
-    return points[..., :3]
-
-def get_camera_rays(c_pos, c_rot, width=640, height=480, focal_length=0.035, sensor_width=0.032,
-                    vertical=None):
-    if vertical is None:
-        vertical = np.array((0., 0., 1.))
-
-    c_dir = c_rot
-
-    img_plane_center = c_pos + c_dir * focal_length
-
-    # The horizontal axis of the camera sensor is horizontal (z=0) and orthogonal to the view axis
-    img_plane_horizontal = np.cross(c_dir, vertical)
-    img_plane_horizontal = img_plane_horizontal / np.linalg.norm(img_plane_horizontal)
-
-    # The vertical axis is orthogonal to both the view axis and the horizontal axis
-    img_plane_vertical = np.cross(c_dir, img_plane_horizontal)
-    img_plane_vertical = img_plane_vertical / np.linalg.norm(img_plane_vertical)
-
-    # Double check that everything is orthogonal
-    def is_small(x, atol=1e-7):
-        return abs(x) < atol
-
-    assert(is_small(np.dot(img_plane_vertical, img_plane_horizontal)))
-    assert(is_small(np.dot(img_plane_vertical, c_dir)))
-    assert(is_small(np.dot(c_dir, img_plane_horizontal)))
-
-    # Sensor height is implied by sensor width and aspect ratio
-    sensor_height = (sensor_width / width) * height
-
-    # Compute pixel boundaries
-    horizontal_offsets = np.linspace(-1, 1, width+1) * sensor_width / 2
-    vertical_offsets = np.linspace(-1, 1, height+1) * sensor_height / 2
-
-    # Compute pixel centers
-    horizontal_offsets = (horizontal_offsets[:-1] + horizontal_offsets[1:]) / 2
-    vertical_offsets = (vertical_offsets[:-1] + vertical_offsets[1:]) / 2
-
-    horizontal_offsets = np.repeat(np.reshape(horizontal_offsets, (1, width)), height, 0)
-    vertical_offsets = np.repeat(np.reshape(vertical_offsets, (height, 1)), width, 1)
-
-
-    horizontal_offsets = (np.reshape(horizontal_offsets, (height, width, 1)) *
-                          np.reshape(img_plane_horizontal, (1, 1, 3)))
-    vertical_offsets = (np.reshape(vertical_offsets, (height, width, 1)) *
-                        np.reshape(img_plane_vertical, (1, 1, 3)))
-
-    image_plane = horizontal_offsets + vertical_offsets
-
-    image_plane = image_plane + np.reshape(img_plane_center, (1, 1, 3))
-    c_pos_exp = np.reshape(c_pos, (1, 1, 3))
-    rays = image_plane - c_pos_exp
-    ray_norms = np.linalg.norm(rays, axis=2, keepdims=True)
-    rays = rays / ray_norms
-    return rays.astype(np.float32)
-
-def downsample(x, num_steps=1):
-    if num_steps is None or num_steps < 1:
-        return x
-    stride = 2**num_steps
-    return x[stride//2::stride, stride//2::stride]
-
-
-class YCBVideo3D(Dataset):
-    def __init__(self, path, mode, max_views=None, points_per_item=2048, canonical_view=True,
-                 max_len=None, full_scale=False, shapenet=False, downsample=None):
-        """ Loads the YCB-Video dataset that we have adapted.
-
-        Args:
-            path (str): Path to dataset.
-            mode (str): 'train', 'val', or 'test'.
-            points_per_item (int): Number of target points per scene.
-            max_len (int): Limit to the number of entries in the dataset.
-            canonical_view (bool): Return data in canonical camera coordinates (like in SRT), as opposed
-                to world coordinates.
-            full_scale (bool): Return all available target points, instead of sampling.
-            downsample (int): Downsample height and width of input image by a factor of 2**downsample
-        """
-        self.path = path
-        self.mode = mode
-        self.points_per_item = points_per_item
-        self.max_len = max_len
-        self.canonical = canonical_view
-        self.full_scale = full_scale
-        self.shapenet = shapenet
-        self.downsample = downsample
-
-        self.max_num_entities = 21 # max number of objects in a scene
-        self.num_views = 3 # TODO : set this number for each scene
-
-        self.start_idx, self.end_idx = {'train': (0, 70000),
-                                        'val': (70000, 75000),
-                                        'test': (85000, 100000)}[mode]
-
-        self.metadata = np.load(os.path.join(path, 'metadata.npz'))
-        self.metadata = {k: v for k, v in self.metadata.items()}
-
-        self.idxs = np.arange(self.start_idx, self.end_idx)
-
-        dataset_name = 'YCB-Video'
-        print(f'Initialized {dataset_name} {mode} set, {len(self.idxs)} examples')
-        print(self.idxs)
-
-        self.render_kwargs = {
-            'min_dist': 0.035,
-            'max_dist': 35.}
-
-    def __len__(self):
-        if self.max_len is not None:
-            return self.max_len
-        return len(self.idxs) * self.num_views
-
-    def __getitem__(self, idx):
-        scene_idx = idx % len(self.idxs)
-        view_idx = idx // len(self.idxs)
-
-        scene_idx = self.idxs[scene_idx]
-
-        imgs = [np.asarray(imageio.imread(
-            os.path.join(self.path, 'images', f'img_{scene_idx}_{v}.png')))
-            for v in range(self.num_views)]
-
-        imgs = [img[..., :3].astype(np.float32) / 255 for img in imgs]
-
-        mask_idxs = [imageio.imread(os.path.join(self.path, 'masks', f'masks_{scene_idx}_{v}.png'))
-                     for v in range(self.num_views)]
-        masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
-        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
-
-        input_image = downsample(imgs[view_idx], num_steps=self.downsample)
-        input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)
-
-        all_rays = []
-        # TODO : find a way to get the camera poses
-        all_camera_pos = self.metadata['camera_pos'][:self.num_views].astype(np.float32)
-        all_camera_rot= self.metadata['camera_rot'][:self.num_views].astype(np.float32)
-        for i in range(self.num_views):
-            cur_rays = get_camera_rays(all_camera_pos[i], all_camera_rot[i], noisy=False) # TODO : adapt function
-            all_rays.append(cur_rays)
-        all_rays = np.stack(all_rays, 0).astype(np.float32)
-
-        input_camera_pos = all_camera_pos[view_idx]
-
-        if self.canonical:
-            track_point = np.zeros_like(input_camera_pos) # All cameras are pointed at the origin
-            canonical_extrinsic = get_extrinsic(input_camera_pos, track_point=track_point) # TODO : adapt function
-            canonical_extrinsic = canonical_extrinsic.astype(np.float32)
-            all_rays = transform_points(all_rays, canonical_extrinsic, translate=False) # TODO : adapt function
-            all_camera_pos = transform_points(all_camera_pos, canonical_extrinsic)
-            input_camera_pos = all_camera_pos[view_idx]
-
-        input_rays = all_rays[view_idx]
-        input_rays = downsample(input_rays, num_steps=self.downsample)
-        input_rays = np.expand_dims(input_rays, 0)
-
-        input_masks = masks[view_idx]
-        input_masks = downsample(input_masks, num_steps=self.downsample)
-        input_masks = np.expand_dims(input_masks, 0)
-
-        input_camera_pos = np.expand_dims(input_camera_pos, 0)
-
-        all_pixels = np.reshape(np.stack(imgs, 0), (self.num_views * 240 * 320, 3))
-        all_rays = np.reshape(all_rays, (self.num_views * 240 * 320, 3))
-        all_camera_pos = np.tile(np.expand_dims(all_camera_pos, 1), (1, 240 * 320, 1))
-        all_camera_pos = np.reshape(all_camera_pos, (self.num_views * 240 * 320, 3))
-        all_masks = np.reshape(masks, (self.num_views * 240 * 320, self.max_num_entities))
-
-        num_points = all_rays.shape[0]
-
-        if not self.full_scale:
-            # If we have fewer points than we want, sample with replacement
-            replace = num_points < self.points_per_item
-            sampled_idxs = np.random.choice(np.arange(num_points),
-                                            size=(self.points_per_item,),
-                                            replace=replace)
-
-            target_rays = all_rays[sampled_idxs]
-            target_camera_pos = all_camera_pos[sampled_idxs]
-            target_pixels = all_pixels[sampled_idxs]
-            target_masks = all_masks[sampled_idxs]
-        else:
-            target_rays = all_rays
-            target_camera_pos = all_camera_pos
-            target_pixels = all_pixels
-            target_masks = all_masks
-
-        result = {
-            'input_images': input_images, # [1, 3, h, w]
-            'input_camera_pos': input_camera_pos, # [1, 3]
-            'input_rays': input_rays, # [1, h, w, 3]
-            'input_masks': input_masks, # [1, h, w, self.max_num_entities]
-            'target_pixels': target_pixels, # [p, 3]
-            'target_camera_pos': target_camera_pos, # [p, 3]
-            'target_rays': target_rays, # [p, 3]
-            'target_masks': target_masks, # [p, self.max_num_entities]
-            'sceneid': idx, # int
-        }
-
-        if self.canonical:
-            result['transform'] = canonical_extrinsic # [3, 4] (optional)
-
-        return result
-
-class YCBVideo2D(Dataset):
-    def __init__(self, path, mode, downsample=None):
-        """ Loads the YCB-Video dataset that we have adapted.
-
-        Args:
-            path (str): Path to dataset.
-            mode (str): 'train', 'val', or 'test'.
-            downsample (int): Downsample height and width of input image by a factor of 2**downsample
-        """
-        self.path = path
-        self.mode = mode
-        self.downsample = downsample
-
-        self.max_num_entities = 21 # max number of objects in a scene
-
-        # TODO : set the right number here
-
-        dataset_name = 'YCB-Video'
-        print(f'Initialized {dataset_name} {mode} set')
-
-
-    def __len__(self):
-        return len(self.idxs)
-
-    def __getitem__(self, idx):
-
-        imgs = [np.asarray(imageio.imread(
-            os.path.join(self.path, 'images', f'img_{scene_idx}_{v}.png')))
-            for v in range(self.num_views)]
-
-        imgs = [img[..., :3].astype(np.float32) / 255 for img in imgs]
-
-        mask_idxs = [imageio.imread(os.path.join(self.path, 'masks', f'masks_{scene_idx}_{v}.png'))
-                     for v in range(self.num_views)]
-        masks = np.zeros((self.num_views, 240, 320, self.max_num_entities), dtype=np.uint8)
-        np.put_along_axis(masks, np.expand_dims(mask_idxs, -1), 1, axis=-1)
-
-        input_image = downsample(imgs[view_idx], num_steps=self.downsample)
-        input_images = np.expand_dims(np.transpose(input_image, (2, 0, 1)), 0)
-
-        all_rays = []
-        # TODO : find a way to get the camera poses
-        all_camera_pos = self.metadata['camera_pos'][:self.num_views].astype(np.float32)
-        all_camera_rot= self.metadata['camera_rot'][:self.num_views].astype(np.float32)
-        for i in range(self.num_views):
-            cur_rays = get_camera_rays(all_camera_pos[i], all_camera_rot[i], noisy=False) # TODO : adapt function
-            all_rays.append(cur_rays)
-        all_rays = np.stack(all_rays, 0).astype(np.float32)
-
-        input_camera_pos = all_camera_pos[view_idx]
-
-        if self.canonical:
-            track_point = np.zeros_like(input_camera_pos) # All cameras are pointed at the origin
-            canonical_extrinsic = get_extrinsic(input_camera_pos, track_point=track_point) # TODO : adapt function
-            canonical_extrinsic = canonical_extrinsic.astype(np.float32)
-            all_rays = transform_points(all_rays, canonical_extrinsic, translate=False) # TODO : adapt function
-            all_camera_pos = transform_points(all_camera_pos, canonical_extrinsic)
-            input_camera_pos = all_camera_pos[view_idx]
-
-        input_rays = all_rays[view_idx]
-        input_rays = downsample(input_rays, num_steps=self.downsample)
-        input_rays = np.expand_dims(input_rays, 0)
-
-        input_masks = masks[view_idx]
-        input_masks = downsample(input_masks, num_steps=self.downsample)
-        input_masks = np.expand_dims(input_masks, 0)
-
-        input_camera_pos = np.expand_dims(input_camera_pos, 0)
-
-        all_pixels = np.reshape(np.stack(imgs, 0), (self.num_views * 240 * 320, 3))
-        all_rays = np.reshape(all_rays, (self.num_views * 240 * 320, 3))
-        all_camera_pos = np.tile(np.expand_dims(all_camera_pos, 1), (1, 240 * 320, 1))
-        all_camera_pos = np.reshape(all_camera_pos, (self.num_views * 240 * 320, 3))
-        all_masks = np.reshape(masks, (self.num_views * 240 * 320, self.max_num_entities))
-
-        target_camera_pos = all_camera_pos
-        target_pixels = all_pixels
-        target_masks = all_masks
-
-        result = {
-            'input_images': input_images, # [1, 3, h, w]
-            'input_camera_pos': input_camera_pos, # [1, 3]
-            'input_rays': input_rays, # [1, h, w, 3]
-            'input_masks': input_masks, # [1, h, w, self.max_num_entities]
-            'target_pixels': target_pixels, # [p, 3]
-            'target_camera_pos': target_camera_pos, # [p, 3]
-            'target_masks': target_masks, # [p, self.max_num_entities]
-            'sceneid': idx, # int
-        }
-
-        if self.canonical:
-            result['transform'] = canonical_extrinsic # [3, 4] (optional)
-
-        return result
-
-
diff --git a/osrt/utils/visualization_utils.py b/osrt/utils/visualization_utils.py
index 6e36fafe5039d8abf7a76931366649be8a6c5599..f86b365a1f4d113094089bc56b8006c7ebc45370 100644
--- a/osrt/utils/visualization_utils.py
+++ b/osrt/utils/visualization_utils.py
@@ -114,4 +114,5 @@ def visualize_slot_attention(num_slots, image, recon_combined, recons, masks, fo
     if not save_file:
         plt.show()
     else:
+        folder_save = folder_save if folder_save is not None else "./"
         plt.savefig(f'{folder_save}visualisation_{step}.png', bbox_inches='tight')
diff --git a/quick_test.py b/quick_test.py
index d63417171cf34291ea8a66fbee23435b2a972d00..7d8cf5c16b74bffe57194df2bd48448e1ec389e2 100644
--- a/quick_test.py
+++ b/quick_test.py
@@ -3,18 +3,19 @@ from osrt import data
 from torch.utils.data import DataLoader
 import matplotlib.pyplot as plt

-with open("runs/clevr/slot_att/config.yaml", 'r') as f:
+with open("runs/ycb/slot_att/config.yaml", 'r') as f:
     cfg = yaml.load(f, Loader=yaml.CLoader)
-
+cfg['data']['path'] = "/home/achapin/Documents/Datasets/ycbv_small"
 train_dataset = data.get_dataset('train', cfg['data'])
 train_loader = DataLoader(train_dataset, batch_size=2, num_workers=0,shuffle=True)

+
 for val in train_loader:
     print(f"Shape masks {val['input_masks'].shape}")
     fig, axes = plt.subplots(2, 2)
-    axes[0][0].imshow(val['input_image'][0])
+    axes[0][0].imshow(val['input_images'][0].permute(1, 2, 0))
     axes[0][1].imshow(val['input_masks'][0][:, :, 0])
-    axes[1][0].imshow(val['input_image'][1])
+    axes[1][0].imshow(val['input_images'][1].permute(1, 2, 0))
     axes[1][1].imshow(val['input_masks'][1][:, :, 0])
     plt.show()
\ No newline at end of file
diff --git a/runs/ycb/slot_att/config.yaml b/runs/ycb/slot_att/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ebc5132f030cfbcb7720e5226fdd97a272414880
--- /dev/null
+++ b/runs/ycb/slot_att/config.yaml
@@ -0,0 +1,23 @@
+data:
+  dataset: ycb2d
+model:
+  num_slots: 10
+  iters: 3
+  model_type: sa
+  input_dim: 64
+  slot_dim: 64
+  hidden_dim: 128
+training:
+  num_workers: 2
+  num_gpus: 1
+  batch_size: 8
+  max_it: 333000000
+  warmup_it: 10000
+  lr_warmup: 5000
+  decay_rate: 0.5
+  decay_it: 100000
+  visualize_every: 5000
+  validate_every: 5000
+  checkpoint_every: 1000
+  backup_every: 25000
+
diff --git a/runs/ycb/slot_att/config_tsa.yaml b/runs/ycb/slot_att/config_tsa.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0e2a05a7a6a25fc626b602c57bb6983049b7b7d8
--- /dev/null
+++ b/runs/ycb/slot_att/config_tsa.yaml
@@ -0,0 +1,23 @@
+data:
+  dataset: ycb2d
+model:
+  num_slots: 10
+  iters: 3
+  model_type: tsa
+  input_dim: 64
+  slot_dim: 64
+  hidden_dim: 128
+training:
+  num_workers: 2
+  num_gpus: 1
+  batch_size: 8
+  max_it: 333000000
+  warmup_it: 10000
+  lr_warmup: 5000
+  decay_rate: 0.5
+  decay_it: 100000
+  visualize_every: 5000
+  validate_every: 5000
+  checkpoint_every: 1000
+  backup_every: 25000
+
diff --git a/train_sa.py b/train_sa.py
index c1e9daaac966da48fd4cd98ca77e7a3c06e316e3..5963cf580f6422d0809ccbe2f465bc14e7be7161 100644
--- a/train_sa.py
+++ b/train_sa.py
@@ -44,7 +44,6 @@ def main():
     batch_size = cfg["training"]["batch_size"]
     num_gpus = cfg["training"]["num_gpus"]
     num_slots = cfg["model"]["num_slots"]
-    num_iterations = cfg["model"]["iters"]
     num_train_steps = cfg["training"]["max_it"]
     num_workers = cfg["training"]["num_workers"]
     resolution = (128, 128)
@@ -84,7 +83,7 @@ def main():
         devices=num_gpus,
         profiler="simple",
         default_root_dir="./logs",
-        logger=WandbLogger(project="slot-att") if args.wandb else None,
+        logger=WandbLogger(project="slot-att", offline=True) if args.wandb else None,
         strategy="ddp" if num_gpus > 1 else "auto",
         callbacks=[checkpoint_callback, early_stopping],
         log_every_n_steps=100,