diff --git a/automatic_mask_train.py b/automatic_mask_train.py
index 79a34e4817bad435ec7650edcc9a498742558adf..1417aafdfcd8c7dca3439bb772f8f9876d66c944 100644
--- a/automatic_mask_train.py
+++ b/automatic_mask_train.py
@@ -121,5 +121,4 @@ if __name__ == '__main__':
         show_anns(masks[0][0]) # show masks 
         show_points(new_points, plt.gca()) # show points
         plt.axis('off')
-        plt.show()
-
+        plt.show()
\ No newline at end of file
diff --git a/osrt/encoder.py b/osrt/encoder.py
index 16c21cfeaaaebce863175535da0cc398ce6271c4..699ed59157a62e5a83302642262a73c6bba2922e 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -112,7 +112,7 @@ class OSRTEncoder(nn.Module):
 
 class FeatureMasking(nn.Module):
     def __init__(self, 
-                 points_per_side=8,
+                 points_per_side=12,
                  box_nms_thresh = 0.7,
                  stability_score_thresh = 0.9,
                  pred_iou_thresh=0.88,
diff --git a/osrt/trainer.py b/osrt/trainer.py
index be755a4cce4cd86fa61d89e65c59efd22ba188da..81b4e6e0d0ff35e5f6a14b51c6901fc7cf0eda13 100644
--- a/osrt/trainer.py
+++ b/osrt/trainer.py
@@ -73,6 +73,7 @@ class SRTTrainer:
         input_rays = data.get('input_rays').to(device)
         target_pixels = data.get('target_pixels').to(device)
 
+        input_images = input_images.permute(0, 2, 3, 1).unsqueeze(1)
         with torch.cuda.amp.autocast(): 
             z = self.model.encoder(input_images, input_camera_pos, input_rays)
 
diff --git a/requirements.txt b/requirements.txt
index a59c5502564c3914fe4a15d27c587560d25cd325..49db6f2a909ace3c07f8fe9dfae438dfe3fe7957 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ imageio
 matplotlib
 tqdm
 opencv-python
+bitsandbytes
diff --git a/train.py b/train.py
index 8d16072e107e635076a1f009bf7a37d0191c2097..169dac78490ab8c4508b88ed104053c86ff21fdc 100755
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@ import torch
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel
 import numpy as np
+import bitsandbytes as bnb
 
 import os
 import argparse
@@ -146,7 +147,7 @@ if __name__ == '__main__':
 
     # Intialize training
     params = [p for p in model.parameters() if p.requires_grad] # only keep trainable parameters
-    optimizer = optim.Adam(params, lr=lr_scheduler.get_cur_lr(0))
+    optimizer = bnb.optim.Adam8bit(params, lr=lr_scheduler.get_cur_lr(0)) # Switched from : optim.Adam(params, lr=lr_scheduler.get_cur_lr(0))
     trainer = SRTTrainer(model, optimizer, cfg, device, out_dir, train_dataset.render_kwargs)
     checkpoint = Checkpoint(out_dir, device=device, encoder=encoder_module,
                             decoder=decoder_module, optimizer=optimizer)