Skip to content
Snippets Groups Projects
Commit 0f20c38c authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Change encoding part of the model

parent 4d10dcf9
No related branches found
No related tags found
No related merge requests found
......@@ -7,9 +7,12 @@
}
},
"model":{
"encoder": "osrt",
"encoder": "sam",
"encoder_kwargs": {
"pos_start_octave": -5,
"points_per_side": 12,
"box_nms_thresh": 0.7,
"stability_score_thresh": 0.9,
"pred_iou_thresh": 0.88,
"num_slots": 6
},
"decoder": "slot_mixer",
......@@ -19,7 +22,7 @@
},
"training":{
"num_workers": 4,
"batch_size": 64,
"batch_size": 16,
"num_gpu": 8,
"model_selection_metric": "psnr",
"model_selection_mode": "max",
......
......@@ -131,8 +131,7 @@ def train_sam(
### Encode input informations and extract masks
if isinstance(model.encoder, FeatureMasking):
input_images = input_images.permute(0, 1, 3, 4, 2) # from [b, k, c, h, w] to [b, k, h, w, c]
h, w, c = input_images[0][0].shape
masks_info, z = model.encoder(input_images,(h, w), input_camera_pos, input_rays, extract_masks=True)
masks_info, z = model.encoder(input_images, input_camera_pos, input_rays, extract_masks=True)
else:
z = model.encoder(input_images, input_camera_pos, input_rays)
......@@ -245,7 +244,7 @@ def main(cfg) -> None:
os.makedirs(cfg['training']['out_dir'], exist_ok=True)
with fabric.device:
model = OSRT(cfg)
model = OSRT(cfg["model"])
#########################
### Loading the dataset
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment