From 35bb0ae018ef1dacfffa8aa6f2182bb42e0b2426 Mon Sep 17 00:00:00 2001 From: alexcbb <alexchapin@hotmail.fr> Date: Wed, 28 Jun 2023 08:53:25 +0200 Subject: [PATCH] Change input image shape --- automatic_mask_train.py | 3 +-- osrt/encoder.py | 2 +- osrt/trainer.py | 1 + requirements.txt | 1 + train.py | 3 ++- 5 files changed, 6 insertions(+), 4 deletions(-) diff --git a/automatic_mask_train.py b/automatic_mask_train.py index 79a34e4..1417aaf 100644 --- a/automatic_mask_train.py +++ b/automatic_mask_train.py @@ -121,5 +121,4 @@ if __name__ == '__main__': show_anns(masks[0][0]) # show masks show_points(new_points, plt.gca()) # show points plt.axis('off') - plt.show() - + plt.show() \ No newline at end of file diff --git a/osrt/encoder.py b/osrt/encoder.py index 16c21cf..699ed59 100644 --- a/osrt/encoder.py +++ b/osrt/encoder.py @@ -112,7 +112,7 @@ class OSRTEncoder(nn.Module): class FeatureMasking(nn.Module): def __init__(self, - points_per_side=8, + points_per_side=12, box_nms_thresh = 0.7, stability_score_thresh = 0.9, pred_iou_thresh=0.88, diff --git a/osrt/trainer.py b/osrt/trainer.py index be755a4..81b4e6e 100644 --- a/osrt/trainer.py +++ b/osrt/trainer.py @@ -73,6 +73,7 @@ class SRTTrainer: input_rays = data.get('input_rays').to(device) target_pixels = data.get('target_pixels').to(device) + input_images = input_images.permute(0, 2, 3, 1).unsqueeze(1) with torch.cuda.amp.autocast(): z = self.model.encoder(input_images, input_camera_pos, input_rays) diff --git a/requirements.txt b/requirements.txt index a59c550..49db6f2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ imageio matplotlib tqdm opencv-python +bitsandbytes diff --git a/train.py b/train.py index 8d16072..169dac7 100755 --- a/train.py +++ b/train.py @@ -2,6 +2,7 @@ import torch import torch.optim as optim from torch.nn.parallel import DistributedDataParallel import numpy as np +import bitsandbytes as bnb import os import argparse @@ -146,7 +147,7 @@ if __name__ == '__main__': # Intialize training params = [p for p in model.parameters() if p.requires_grad] # only keep trainable parameters - optimizer = optim.Adam(params, lr=lr_scheduler.get_cur_lr(0)) + optimizer = bnb.optim.Adam8bit(params, lr=lr_scheduler.get_cur_lr(0)) # Switched from : optim.Adam(params, lr=lr_scheduler.get_cur_lr(0)) trainer = SRTTrainer(model, optimizer, cfg, device, out_dir, train_dataset.render_kwargs) checkpoint = Checkpoint(out_dir, device=device, encoder=encoder_module, decoder=decoder_module, optimizer=optimizer) -- GitLab