diff --git a/automatic_mask_train.py b/automatic_mask_train.py
index 79a34e4817bad435ec7650edcc9a498742558adf..1417aafdfcd8c7dca3439bb772f8f9876d66c944 100644
--- a/automatic_mask_train.py
+++ b/automatic_mask_train.py
@@ -121,5 +121,4 @@ if __name__ == '__main__':
     show_anns(masks[0][0]) # show masks
     show_points(new_points, plt.gca()) # show points
     plt.axis('off')
-    plt.show()
-
+    plt.show()
\ No newline at end of file
diff --git a/osrt/encoder.py b/osrt/encoder.py
index 16c21cfeaaaebce863175535da0cc398ce6271c4..699ed59157a62e5a83302642262a73c6bba2922e 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -112,7 +112,7 @@ class OSRTEncoder(nn.Module):
 
 class FeatureMasking(nn.Module):
     def __init__(self,
-                 points_per_side=8,
+                 points_per_side=12,
                  box_nms_thresh = 0.7,
                  stability_score_thresh = 0.9,
                  pred_iou_thresh=0.88,
diff --git a/osrt/trainer.py b/osrt/trainer.py
index be755a4cce4cd86fa61d89e65c59efd22ba188da..81b4e6e0d0ff35e5f6a14b51c6901fc7cf0eda13 100644
--- a/osrt/trainer.py
+++ b/osrt/trainer.py
@@ -73,6 +73,7 @@ class SRTTrainer:
         input_rays = data.get('input_rays').to(device)
         target_pixels = data.get('target_pixels').to(device)
 
+        input_images = input_images.permute(0, 2, 3, 1).unsqueeze(1)
         with torch.cuda.amp.autocast():
             z = self.model.encoder(input_images, input_camera_pos, input_rays)
 
diff --git a/requirements.txt b/requirements.txt
index a59c5502564c3914fe4a15d27c587560d25cd325..49db6f2a909ace3c07f8fe9dfae438dfe3fe7957 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ imageio
 matplotlib
 tqdm
 opencv-python
+bitsandbytes
diff --git a/train.py b/train.py
index 8d16072e107e635076a1f009bf7a37d0191c2097..169dac78490ab8c4508b88ed104053c86ff21fdc 100755
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@ import torch
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel
 import numpy as np
+import bitsandbytes as bnb
 
 import os
 import argparse
@@ -146,7 +147,7 @@ if __name__ == '__main__':
 
     # Intialize training
     params = [p for p in model.parameters() if p.requires_grad] # only keep trainable parameters
-    optimizer = optim.Adam(params, lr=lr_scheduler.get_cur_lr(0))
+    optimizer = bnb.optim.Adam8bit(params, lr=lr_scheduler.get_cur_lr(0)) # Switched from : optim.Adam(params, lr=lr_scheduler.get_cur_lr(0))
 
     trainer = SRTTrainer(model, optimizer, cfg, device, out_dir, train_dataset.render_kwargs)
     checkpoint = Checkpoint(out_dir, device=device, encoder=encoder_module, decoder=decoder_module, optimizer=optimizer)
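
Note on the optimizer swap in train.py: the change replaces `torch.optim.Adam` with bitsandbytes' 8-bit Adam, which stores optimizer state in 8 bits to reduce GPU memory use. Below is a minimal, self-contained sketch of the drop-in usage; the toy `nn.Linear` model, tensor shapes, and learning rate are illustrative assumptions, not values from this repo — only the optimizer construction mirrors the diff.

```python
# Minimal sketch of the torch.optim.Adam -> bnb.optim.Adam8bit swap made in train.py.
# The model, shapes, and lr here are illustrative assumptions. Requires a CUDA device
# and bitsandbytes installed.
import torch
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(128, 64).cuda()
params = [p for p in model.parameters() if p.requires_grad]  # only keep trainable parameters

# 8-bit Adam keeps the exp_avg / exp_avg_sq state in blockwise 8-bit form, cutting
# optimizer-state memory relative to the 32-bit state of torch.optim.Adam, while the
# parameter update itself is still computed in full precision.
optimizer = bnb.optim.Adam8bit(params, lr=1e-4)

x = torch.randn(32, 128, device="cuda")
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```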