diff --git a/runs/test/config.json b/runs/test/config.json
index 43c2e06f7e5d034a23d11d7bd00a43a2af6d57d5..de98f8e2c6df682615ec26ac0a483e29f9c30a2c 100644
--- a/runs/test/config.json
+++ b/runs/test/config.json
@@ -7,9 +7,12 @@
         }
     },
     "model":{
-        "encoder": "osrt",
+        "encoder": "sam",
         "encoder_kwargs": {
-            "pos_start_octave": -5,
+            "points_per_side": 12,
+            "box_nms_thresh": 0.7,
+            "stability_score_thresh": 0.9,
+            "pred_iou_thresh": 0.88,
             "num_slots": 6
         },
         "decoder": "slot_mixer",
@@ -19,7 +22,7 @@
     },
     "training":{
         "num_workers": 4,
-        "batch_size": 64,
+        "batch_size": 16,
         "num_gpu": 8,
         "model_selection_metric": "psnr",
         "model_selection_mode": "max",
diff --git a/train_lit.py b/train_lit.py
index 67321c7b35dcad92b53d69fe6a438a551e164b99..b290fbdddf49c59053cc3f080363fa101c3625d5 100644
--- a/train_lit.py
+++ b/train_lit.py
@@ -131,8 +131,7 @@ def train_sam(
     ### Encode input informations and extract masks
     if isinstance(model.encoder, FeatureMasking):
         input_images = input_images.permute(0, 1, 3, 4, 2) # from [b, k, c, h, w] to [b, k, h, w, c]
-        h, w, c = input_images[0][0].shape
-        masks_info, z = model.encoder(input_images,(h, w), input_camera_pos, input_rays, extract_masks=True)
+        masks_info, z = model.encoder(input_images, input_camera_pos, input_rays, extract_masks=True)
     else:
         z = model.encoder(input_images, input_camera_pos, input_rays)
 
@@ -245,7 +244,7 @@ def main(cfg) -> None:
     os.makedirs(cfg['training']['out_dir'], exist_ok=True)
 
     with fabric.device:
-        model = OSRT(cfg)
+        model = OSRT(cfg["model"])
 
     #########################
     ### Loading the dataset
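
The new `encoder_kwargs` in the config match parameters of `SamAutomaticMaskGenerator` from the segment-anything library, which the `"sam"` encoder presumably forwards to it. A minimal sketch of that mapping, assuming a `vit_b` backbone and a local checkpoint path (both hypothetical; the generator parameters themselves are the library's real API):

```python
# Minimal sketch, assuming the "sam" encoder passes encoder_kwargs (minus
# num_slots) through to segment-anything's SamAutomaticMaskGenerator.
# The backbone choice and checkpoint path are hypothetical.
import numpy as np
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

sam = sam_model_registry["vit_b"](checkpoint="checkpoints/sam_vit_b.pth")  # hypothetical path
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=12,          # 12x12 grid of point prompts (library default is 32)
    box_nms_thresh=0.7,          # NMS IoU threshold for de-duplicating predicted masks
    stability_score_thresh=0.9,  # discard masks with stability score below 0.9
    pred_iou_thresh=0.88,        # discard masks whose predicted IoU is below 0.88
)

# generate() takes one HWC uint8 RGB image and returns a list of mask records.
image = np.zeros((128, 128, 3), dtype=np.uint8)  # stand-in input
masks = mask_generator.generate(image)
```

Dropping `points_per_side` from the default 32 to 12 reduces the prompts per image from 1024 to 144, trading mask recall for speed and memory, which is consistent with the `batch_size` reduction from 64 to 16 in the same config.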
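
The `train_lit.py` hunk removes the explicit `(h, w)` argument from the `FeatureMasking` call. Since the preceding `permute` puts inputs in `[b, k, h, w, c]` layout, the encoder can presumably recover the spatial size from the tensor itself; a hypothetical sketch of the updated signature (the class body is an assumption, not the repo's code):

```python
import torch.nn as nn

class FeatureMasking(nn.Module):
    # Hypothetical forward: the spatial size is read off the input tensor,
    # so callers no longer pass (h, w) separately.
    def forward(self, images, camera_pos, rays, extract_masks=False):
        b, k, h, w, c = images.shape  # images arrive as [b, k, h, w, c]
        ...  # mask extraction / feature encoding elided
```

Likewise, `OSRT(cfg["model"])` suggests the constructor now expects only the `"model"` sub-config rather than the full config dict.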