Skip to content
Snippets Groups Projects
Commit 35bb0ae0 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Change input image shape

parent 8dd11727
No related branches found
No related tags found
No related merge requests found
......@@ -121,5 +121,4 @@ if __name__ == '__main__':
show_anns(masks[0][0]) # show masks
show_points(new_points, plt.gca()) # show points
plt.axis('off')
plt.show()
plt.show()
\ No newline at end of file
......@@ -112,7 +112,7 @@ class OSRTEncoder(nn.Module):
class FeatureMasking(nn.Module):
def __init__(self,
points_per_side=8,
points_per_side=12,
box_nms_thresh = 0.7,
stability_score_thresh = 0.9,
pred_iou_thresh=0.88,
......
......@@ -73,6 +73,7 @@ class SRTTrainer:
input_rays = data.get('input_rays').to(device)
target_pixels = data.get('target_pixels').to(device)
input_images = input_images.permute(0, 2, 3, 1).unsqueeze(1)
with torch.cuda.amp.autocast():
z = self.model.encoder(input_images, input_camera_pos, input_rays)
......
......@@ -7,3 +7,4 @@ imageio
matplotlib
tqdm
opencv-python
bitsandbytes
......@@ -2,6 +2,7 @@ import torch
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel
import numpy as np
import bitsandbytes as bnb
import os
import argparse
......@@ -146,7 +147,7 @@ if __name__ == '__main__':
# Intialize training
params = [p for p in model.parameters() if p.requires_grad] # only keep trainable parameters
optimizer = optim.Adam(params, lr=lr_scheduler.get_cur_lr(0))
optimizer = bnb.optim.Adam8bit(params, lr=lr_scheduler.get_cur_lr(0)) # Switched from : optim.Adam(params, lr=lr_scheduler.get_cur_lr(0))
trainer = SRTTrainer(model, optimizer, cfg, device, out_dir, train_dataset.render_kwargs)
checkpoint = Checkpoint(out_dir, device=device, encoder=encoder_module,
decoder=decoder_module, optimizer=optimizer)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment