From 35bb0ae018ef1dacfffa8aa6f2182bb42e0b2426 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Wed, 28 Jun 2023 08:53:25 +0200
Subject: [PATCH] Change input image shape, raise points_per_side, and use 8-bit Adam

---
 automatic_mask_train.py | 3 +--
 osrt/encoder.py         | 2 +-
 osrt/trainer.py         | 1 +
 requirements.txt        | 1 +
 train.py                | 3 ++-
 5 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/automatic_mask_train.py b/automatic_mask_train.py
index 79a34e4..1417aaf 100644
--- a/automatic_mask_train.py
+++ b/automatic_mask_train.py
@@ -121,5 +121,4 @@ if __name__ == '__main__':
         show_anns(masks[0][0]) # show masks 
         show_points(new_points, plt.gca()) # show points
         plt.axis('off')
-        plt.show()
-
+        plt.show()
\ No newline at end of file
diff --git a/osrt/encoder.py b/osrt/encoder.py
index 16c21cf..699ed59 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -112,7 +112,7 @@ class OSRTEncoder(nn.Module):
 
 class FeatureMasking(nn.Module):
     def __init__(self, 
-                 points_per_side=8,
+                 points_per_side=12,
                  box_nms_thresh = 0.7,
                  stability_score_thresh = 0.9,
                  pred_iou_thresh=0.88,
diff --git a/osrt/trainer.py b/osrt/trainer.py
index be755a4..81b4e6e 100644
--- a/osrt/trainer.py
+++ b/osrt/trainer.py
@@ -73,6 +73,7 @@ class SRTTrainer:
         input_rays = data.get('input_rays').to(device)
         target_pixels = data.get('target_pixels').to(device)
 
+        input_images = input_images.permute(0, 2, 3, 1).unsqueeze(1)
         with torch.cuda.amp.autocast(): 
             z = self.model.encoder(input_images, input_camera_pos, input_rays)
 
diff --git a/requirements.txt b/requirements.txt
index a59c550..49db6f2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ imageio
 matplotlib
 tqdm
 opencv-python
+bitsandbytes
diff --git a/train.py b/train.py
index 8d16072..169dac7 100755
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@ import torch
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel
 import numpy as np
+import bitsandbytes as bnb
 
 import os
 import argparse
@@ -146,7 +147,7 @@ if __name__ == '__main__':
 
     # Intialize training
     params = [p for p in model.parameters() if p.requires_grad] # only keep trainable parameters
-    optimizer = optim.Adam(params, lr=lr_scheduler.get_cur_lr(0))
+    optimizer = bnb.optim.Adam8bit(params, lr=lr_scheduler.get_cur_lr(0)) # Switched from : optim.Adam(params, lr=lr_scheduler.get_cur_lr(0))
     trainer = SRTTrainer(model, optimizer, cfg, device, out_dir, train_dataset.render_kwargs)
     checkpoint = Checkpoint(out_dir, device=device, encoder=encoder_module,
                             decoder=decoder_module, optimizer=optimizer)
-- 
GitLab