Skip to content
Snippets Groups Projects
Commit ac9dbb39 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

Add SAM submodule

parent bf7455df
No related branches found
No related tags found
No related merge requests found
[submodule "segment-anything"]
path = segment-anything
url = git@github.com:facebookresearch/segment-anything.git
...@@ -93,7 +93,7 @@ class OSRTEncoder(nn.Module): ...@@ -93,7 +93,7 @@ class OSRTEncoder(nn.Module):
return slot_latents return slot_latents
class FeatureMasking(nn.Module): class FeatureMasking(nn.Module):
def __init__(self, pos_start_octave=0, num_slots=6, slot_dim=1536, slot_iters=1, sam_model="default", sam_path="sam_vit_h_4b8939.pth", def __init__(self, pos_start_octave=0, num_slots=6, num_conv_blocks=3, num_att_blocks=5, slot_dim=1536, slot_iters=1, sam_model="default", sam_path="sam_vit_h_4b8939.pth",
randomize_initial_slots=False): randomize_initial_slots=False):
super().__init__() super().__init__()
...@@ -128,7 +128,7 @@ class FeatureMasking(nn.Module): ...@@ -128,7 +128,7 @@ class FeatureMasking(nn.Module):
def forward(self, images, camera_pos, rays): def forward(self, images, camera_pos, rays):
masks = self.mask_generator.generate(image) masks = self.mask_generator.generate(images)
batch_size, num_images = images.shape[:2] batch_size, num_images = images.shape[:2]
x = images.flatten(0, 1) x = images.flatten(0, 1)
......
Subproject commit 6fdee8f2727f4506cfbbe553e23b895e27956588
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment