diff --git a/osrt/encoder.py b/osrt/encoder.py
index 2c49b92820d75229d028b9ce8c354e24891e6735..a121d0f98ce89680a9f40c1fc0459a4779653f55 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -16,12 +16,12 @@ from osrt.layers import RayEncoder, Transformer, SlotAttention
 from osrt.utils.common import batch_iterator, MaskData, calculate_stability_score, get_indices_sorted_pos, get_positional_embedding, create_points_grid, rotation_matrix_to_euler_angles
 from osrt.utils.nerf import get_extrinsic
 
-from segment_anything import sam_model_registry
-from segment_anything.modeling.image_encoder import ImageEncoderViT
-from segment_anything.modeling.mask_decoder import MaskDecoder
-from segment_anything.modeling.prompt_encoder import PromptEncoder
-from segment_anything.utils.transforms import ResizeLongestSide
-from segment_anything.utils.amg import batched_mask_to_box, remove_small_regions, mask_to_rle_pytorch, area_from_rle
+from osrt.sam import sam_model_registry
+from osrt.sam.image_encoder import ImageEncoderViT
+from osrt.sam.mask_decoder import MaskDecoder
+from osrt.sam.prompt_encoder import PromptEncoder
+from osrt.sam.utils.transforms import ResizeLongestSide
+from osrt.sam.utils.amg import batched_mask_to_box, remove_small_regions, mask_to_rle_pytorch, area_from_rle
 
 from torch.nn.utils.rnn import pad_sequence
 
diff --git a/osrt/sam/__init__.py b/osrt/sam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..93e458e4b70b5bf24898fc926d5d1cabe793f88c
--- /dev/null
+++ b/osrt/sam/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .sam import Sam
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+from .transformer import TwoWayTransformer
+from .build_sam import (
+    build_sam,
+    build_sam_vit_h,
+    build_sam_vit_l,
+    build_sam_vit_b,
+    sam_model_registry,
+)
\ No newline at end of file
diff --git a/osrt/sam/build_sam.py b/osrt/sam/build_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..f23f46d54397a156b83462ffabc9febfcb6d80ae
--- /dev/null
+++ b/osrt/sam/build_sam.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from functools import partial
+
+from . import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+
+
+def build_sam_vit_h(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=1280,
+        encoder_depth=32,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[7, 15, 23, 31],
+        checkpoint=checkpoint,
+    )
+
+
+build_sam = build_sam_vit_h
+
+
+def build_sam_vit_l(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=1024,
+        encoder_depth=24,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[5, 11, 17, 23],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam_vit_b(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=768,
+        encoder_depth=12,
+        encoder_num_heads=12,
+        encoder_global_attn_indexes=[2, 5, 8, 11],
+        checkpoint=checkpoint,
+    )
+
+
+
+sam_model_registry = {
+    "default": build_sam_vit_h,
+    "vit_h": build_sam_vit_h,
+    "vit_l": build_sam_vit_l,
+    "vit_b": build_sam_vit_b,
+}
+
+
+def _build_sam(
+    encoder_embed_dim,
+    encoder_depth,
+    encoder_num_heads,
+    encoder_global_attn_indexes,
+    checkpoint=None,
+):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    sam = Sam(
+        image_encoder=ImageEncoderViT(
+            depth=encoder_depth,
+            embed_dim=encoder_embed_dim,
+            img_size=image_size,
+            mlp_ratio=4,
+            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+            num_heads=encoder_num_heads,
+            patch_size=vit_patch_size,
+            qkv_bias=True,
+            use_rel_pos=True,
+            global_attn_indexes=encoder_global_attn_indexes,
+            window_size=14,
+            out_chans=prompt_embed_dim,
+        ),
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+    sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)
+        sam.load_state_dict(state_dict)
+    return sam
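+
+
+# Usage sketch (addition for illustration, not upstream SAM code): building a model from
+# the registry. Run as `python -m osrt.sam.build_sam`; the checkpoint argument is left as
+# None here, a real path would load pretrained weights instead.
+if __name__ == "__main__":
+    # "vit_b" is the smallest variant and is enough to sanity-check the wiring.
+    sam = sam_model_registry["vit_b"](checkpoint=None)
+    n_params = sum(p.numel() for p in sam.parameters())
+    print(f"Built SAM ViT-B with {n_params / 1e6:.1f}M parameters")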
diff --git a/osrt/sam/common.py b/osrt/sam/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96
--- /dev/null
+++ b/osrt/sam/common.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from typing import Type
+
+
+class MLPBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        mlp_dim: int,
+        act: Type[nn.Module] = nn.GELU,
+    ) -> None:
+        super().__init__()
+        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+        self.act = act()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.lin2(self.act(self.lin1(x)))
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
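+
+
+# Quick shape sanity check (illustrative sketch, not part of the upstream SAM code):
+# LayerNorm2d normalizes over the channel dimension of NCHW tensors, while MLPBlock
+# operates on the last dimension of token tensors.
+if __name__ == "__main__":
+    x_nchw = torch.randn(2, 8, 16, 16)
+    print(LayerNorm2d(8)(x_nchw).shape)  # torch.Size([2, 8, 16, 16])
+    x_tokens = torch.randn(2, 10, 32)
+    print(MLPBlock(embedding_dim=32, mlp_dim=64)(x_tokens).shape)  # torch.Size([2, 10, 32])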
diff --git a/osrt/sam/image_encoder.py b/osrt/sam/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..528e4e010214eb59192ec4ca99fb44e440fac6d5
--- /dev/null
+++ b/osrt/sam/image_encoder.py
@@ -0,0 +1,396 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Type, Union
+
+from .common import LayerNorm2d, MLPBlock
+
+
+# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
+class ImageEncoderViT(nn.Module):
+    def __init__(
+        self,
+        img_size: int = 1024,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        out_chans: int = 256,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_abs_pos: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        global_attn_indexes: Tuple[int, ...] = (),
+    ) -> None:
+        """
+        Args:
+            img_size (int): Input image size.
+            patch_size (int): Patch size.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+            depth (int): Depth of ViT.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_abs_pos (bool): If True, use absolute positional embeddings.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks.
+            global_attn_indexes (list): Indexes for blocks using global attention.
+        """
+        super().__init__()
+        self.img_size = img_size
+
+        self.patch_embed = PatchEmbed(
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+
+        self.pos_embed: Optional[nn.Parameter] = None
+        if use_abs_pos:
+            # Initialize absolute positional embedding with pretrain image size.
+            self.pos_embed = nn.Parameter(
+                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+            )
+
+        self.blocks = nn.ModuleList()
+        for i in range(depth):
+            block = Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                use_rel_pos=use_rel_pos,
+                rel_pos_zero_init=rel_pos_zero_init,
+                window_size=window_size if i not in global_attn_indexes else 0,
+                input_size=(img_size // patch_size, img_size // patch_size),
+            )
+            self.blocks.append(block)
+
+        self.neck = nn.Sequential(
+            nn.Conv2d(
+                embed_dim,
+                out_chans,
+                kernel_size=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+            nn.Conv2d(
+                out_chans,
+                out_chans,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+        )
+
+    def forward(
+        self, x: torch.Tensor, before_channel_reduc: bool = False
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        x = self.patch_embed(x)
+
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        embed_no_reduc = x
+        x = self.neck(x.permute(0, 3, 1, 2))
+
+        # Optionally also return the embedding before channel reduction,
+        # cf. https://github.com/facebookresearch/segment-anything/issues/283
+        if before_channel_reduc:
+            return x, embed_no_reduc
+
+        return x
+
+
+class Block(nn.Module):
+    """Transformer blocks with support of window attention and residual propagation blocks"""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks. If it equals 0, then
+                use global attention.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+        """
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            use_rel_pos=use_rel_pos,
+            rel_pos_zero_init=rel_pos_zero_init,
+            input_size=input_size if window_size == 0 else (window_size, window_size),
+        )
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+        self.window_size = window_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+
+        x = self.attn(x)
+
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
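+        # Free cached GPU memory after window (un)partitioning; this call is an addition
+        # relative to the upstream SAM implementation and is effectively a no-op on CPU.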
+        torch.cuda.empty_cache()
+
+        x = shortcut + x
+        x = x + self.mlp(self.norm2(x))
+
+        return x
+
+
+class Attention(nn.Module):
+    """Multi-head Attention block with relative position embeddings."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads.
+            qkv_bias (bool):  If True, add a learnable bias to query, key, value.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+        """
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+        self.use_rel_pos = use_rel_pos
+        if self.use_rel_pos:
+            assert (
+                input_size is not None
+            ), "Input size must be provided if using relative positional encoding."
+            # initialize relative positional embeddings
+            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, H, W, _ = x.shape
+        # qkv with shape (3, B, nHead, H * W, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        # q, k, v with shape (B * nHead, H * W, C)
+        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+        x = (q * self.scale) @ k.transpose(-2, -1)
+        if self.use_rel_pos:
+            x = add_decomposed_rel_pos(x, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+        x = x.softmax(dim=-1)
+        x = (x @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+        x = self.proj(x)
+
+        return x
+
+
+def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 
+    return x, (Hp, Wp)
+
+
+def window_unpartition(
+    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+    """
+    Window unpartition into original sequences and removing padding.
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :].contiguous()
+    return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+    """
+    Get relative positional embeddings according to the relative positions of
+        query and key sizes.
+    Args:
+        q_size (int): size of query q.
+        k_size (int): size of key k.
+        rel_pos (Tensor): relative position embeddings (L, C).
+
+    Returns:
+        Extracted positional embeddings according to relative positions.
+    """
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    # Interpolate rel pos if needed.
+    if rel_pos.shape[0] != max_rel_dist:
+        # Interpolate rel pos.
+        rel_pos = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+            size=max_rel_dist,
+            mode="linear",
+        )
+        rel_pos = rel_pos.reshape(-1, max_rel_dist).permute(1, 0)
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+    attn: torch.Tensor,
+    q: torch.Tensor,
+    rel_pos_h: torch.Tensor,
+    rel_pos_w: torch.Tensor,
+    q_size: Tuple[int, int],
+    k_size: Tuple[int, int],
+) -> torch.Tensor:
+    """
+    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950
+    Args:
+        attn (Tensor): attention map.
+        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+    Returns:
+        attn (Tensor): attention map with added relative positional embeddings.
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = q.shape
+    r_q = q.reshape(B, q_h, q_w, dim)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+
+    attn = (
+        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+    ).view(B, q_h * q_w, k_h * k_w)
+
+    return attn
+
+
+class PatchEmbed(nn.Module):
+    """
+    Image to Patch Embedding.
+    """
+
+    def __init__(
+        self,
+        kernel_size: Tuple[int, int] = (16, 16),
+        stride: Tuple[int, int] = (16, 16),
+        padding: Tuple[int, int] = (0, 0),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ) -> None:
+        """
+        Args:
+            kernel_size (Tuple): kernel size of the projection layer.
+            stride (Tuple): stride of the projection layer.
+            padding (Tuple): padding size of the projection layer.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+        """
+        super().__init__()
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        # B C H W -> B H W C
+        x = x.permute(0, 2, 3, 1)
+        return x
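+
+
+# Small self-test sketch (addition for illustration, not upstream SAM code):
+# window_partition/window_unpartition should round-trip, and PatchEmbed turns NCHW
+# images into NHWC token grids. Sizes below are arbitrary.
+if __name__ == "__main__":
+    tokens = torch.randn(1, 10, 10, 32)                         # B x H x W x C
+    windows, pad_hw = window_partition(tokens, window_size=4)   # pads 10 -> 12 per side
+    restored = window_unpartition(windows, 4, pad_hw, (10, 10))
+    assert torch.equal(tokens, restored)
+    patches = PatchEmbed(embed_dim=64)(torch.randn(1, 3, 64, 64))
+    print(windows.shape, patches.shape)  # (9, 4, 4, 32) and (1, 4, 4, 64)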
diff --git a/osrt/sam/mask_decoder.py b/osrt/sam/mask_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..347c5c5977a65c36553c487416a281e0e5694356
--- /dev/null
+++ b/osrt/sam/mask_decoder.py
@@ -0,0 +1,181 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import List, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class MaskDecoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer_dim: int,
+        transformer: nn.Module,
+        num_multimask_outputs: int = 3,
+        activation: Type[nn.Module] = nn.GELU,
+        iou_head_depth: int = 3,
+        iou_head_hidden_dim: int = 256,
+    ) -> None:
+        """
+        Predicts masks given an image and prompt embeddings, using a
+        transformer architecture.
+
+        Arguments:
+          transformer_dim (int): the channel dimension of the transformer
+          transformer (nn.Module): the transformer used to predict masks
+          num_multimask_outputs (int): the number of masks to predict
+            when disambiguating masks
+          activation (nn.Module): the type of activation to use when
+            upscaling masks
+          iou_head_depth (int): the depth of the MLP used to predict
+            mask quality
+          iou_head_hidden_dim (int): the hidden dimension of the MLP
+            used to predict mask quality
+        """
+        super().__init__()
+        self.transformer_dim = transformer_dim
+        self.transformer = transformer
+
+        self.num_multimask_outputs = num_multimask_outputs
+
+        self.iou_token = nn.Embedding(1, transformer_dim)
+        self.num_mask_tokens = num_multimask_outputs + 1
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+        self.output_upscaling = nn.Sequential(
+            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+            LayerNorm2d(transformer_dim // 4),
+            activation(),
+            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+            activation(),
+        )
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [
+                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+                for i in range(self.num_mask_tokens)
+            ]
+        )
+
+        self.iou_prediction_head = MLP(
+            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
+        )
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Arguments:
+          image_embeddings (torch.Tensor): the embeddings from the image encoder
+          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+          multimask_output (bool): Whether to return multiple masks or a single
+            mask.
+
+        Returns:
+          torch.Tensor: batched predicted masks
+          torch.Tensor: batched predictions of mask quality
+        """
+        masks, iou_pred = self.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+        )
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            mask_slice = slice(1, None)
+        else:
+            mask_slice = slice(0, 1)
+        masks = masks[:, mask_slice, :, :]
+        iou_pred = iou_pred[:, mask_slice]
+
+        # Prepare output
+        return masks, iou_pred
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts masks. See 'forward' for more details."""
+        # Concatenate output tokens
+        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+        # Expand per-image data in the batch direction to be per-mask, unless the
+        # embeddings already match the token batch size
+        if image_embeddings.shape[0] != tokens.shape[0]:
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        else:
+            src = image_embeddings
+        src = src + dense_prompt_embeddings
+
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[:, 0, :]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w)
+        upscaled_embedding = self.output_upscaling(src)
+        hyper_in_list: List[torch.Tensor] = []
+        for i in range(self.num_mask_tokens):
+            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding.shape
+        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        return masks, iou_pred
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        sigmoid_output: bool = False,
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+        self.sigmoid_output = sigmoid_output
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        if self.sigmoid_output:
+            x = F.sigmoid(x)
+        return x
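+
+
+# Usage sketch (addition for illustration, not upstream SAM code); run as
+# `python -m osrt.sam.mask_decoder`. Shapes mirror the default SAM configuration:
+# 256-dim embeddings over a 64x64 image-embedding grid.
+if __name__ == "__main__":
+    from .transformer import TwoWayTransformer
+
+    decoder = MaskDecoder(
+        transformer_dim=256,
+        transformer=TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8),
+    )
+    masks, iou_pred = decoder(
+        image_embeddings=torch.randn(1, 256, 64, 64),
+        image_pe=torch.randn(1, 256, 64, 64),
+        sparse_prompt_embeddings=torch.randn(1, 2, 256),   # e.g. one point plus padding token
+        dense_prompt_embeddings=torch.randn(1, 256, 64, 64),
+        multimask_output=True,
+    )
+    print(masks.shape, iou_pred.shape)  # (1, 3, 256, 256) and (1, 3)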
diff --git a/osrt/sam/prompt_encoder.py b/osrt/sam/prompt_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3143f4f8e02ddd7ca8587b40ff5d47c3a6b7ef3
--- /dev/null
+++ b/osrt/sam/prompt_encoder.py
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch import nn
+
+from typing import Any, Optional, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class PromptEncoder(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        image_embedding_size: Tuple[int, int],
+        input_image_size: Tuple[int, int],
+        mask_in_chans: int,
+        activation: Type[nn.Module] = nn.GELU,
+    ) -> None:
+        """
+        Encodes prompts for input to SAM's mask decoder.
+
+        Arguments:
+          embed_dim (int): The prompts' embedding dimension
+          image_embedding_size (tuple(int, int)): The spatial size of the
+            image embedding, as (H, W).
+          input_image_size (tuple(int, int)): The padded size of the image as input
+            to the image encoder, as (H, W).
+          mask_in_chans (int): The number of hidden channels used for
+            encoding input masks.
+          activation (nn.Module): The activation to use when encoding
+            input masks.
+        """
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.input_image_size = input_image_size
+        self.image_embedding_size = image_embedding_size
+        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
+        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
+        self.point_embeddings = nn.ModuleList(point_embeddings)
+        self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
+        self.mask_downscaling = nn.Sequential(
+            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+            LayerNorm2d(mask_in_chans // 4),
+            activation(),
+            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+            LayerNorm2d(mask_in_chans),
+            activation(),
+            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+        )
+        self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+    def get_dense_pe(self) -> torch.Tensor:
+        """
+        Returns the positional encoding used to encode point prompts,
+        applied to a dense set of points the shape of the image encoding.
+
+        Returns:
+          torch.Tensor: Positional encoding with shape
+            1x(embed_dim)x(embedding_h)x(embedding_w)
+        """
+        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+    def _embed_points(
+        self,
+        points: torch.Tensor,
+        labels: torch.Tensor,
+        pad: bool,
+    ) -> torch.Tensor:
+        """Embeds point prompts."""
+        points = points + 0.5  # Shift to center of pixel
+        if pad:
+            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+            points = torch.cat([points, padding_point], dim=1)
+            labels = torch.cat([labels, padding_label], dim=1)
+        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
+        point_embedding[labels == -1] = 0.0
+        point_embedding[labels == -1] += self.not_a_point_embed.weight
+        point_embedding[labels == 0] += self.point_embeddings[0].weight
+        point_embedding[labels == 1] += self.point_embeddings[1].weight
+        return point_embedding
+
+    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+        """Embeds box prompts."""
+        boxes = boxes + 0.5  # Shift to center of pixel
+        coords = boxes.reshape(-1, 2, 2)
+        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
+        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+        return corner_embedding
+
+    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+        """Embeds mask inputs."""
+        mask_embedding = self.mask_downscaling(masks)
+        return mask_embedding
+
+    def _get_batch_size(
+        self,
+        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        boxes: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor],
+    ) -> int:
+        """
+        Gets the batch size of the output given the batch size of the input prompts.
+        """
+        if points is not None:
+            return points[0].shape[0]
+        elif boxes is not None:
+            return boxes.shape[0]
+        elif masks is not None:
+            return masks.shape[0]
+        else:
+            return 1
+
+    def _get_device(self) -> torch.device:
+        return self.point_embeddings[0].weight.device
+
+    def forward(
+        self,
+        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        boxes: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Embeds different types of prompts, returning both sparse and dense
+        embeddings.
+
+        Arguments:
+          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+            and labels to embed.
+          boxes (torch.Tensor or none): boxes to embed
+          masks (torch.Tensor or none): masks to embed
+
+        Returns:
+          torch.Tensor: sparse embeddings for the points and boxes, with shape
+            BxNx(embed_dim), where N is determined by the number of input points
+            and boxes.
+          torch.Tensor: dense embeddings for the masks, in the shape
+            Bx(embed_dim)x(embed_H)x(embed_W)
+        """
+        bs = self._get_batch_size(points, boxes, masks)
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+        if points is not None:
+            coords, labels = points
+            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+        if boxes is not None:
+            box_embeddings = self._embed_boxes(boxes)
+            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+        if masks is not None:
+            dense_embeddings = self._embed_masks(masks)
+        else:
+            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+            )
+
+        return sparse_embeddings, dense_embeddings
+
+
+class PositionEmbeddingRandom(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
+    """
+
+    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+        super().__init__()
+        if scale is None or scale <= 0.0:
+            scale = 1.0
+        self.register_buffer(
+            "positional_encoding_gaussian_matrix",
+            scale * torch.randn((2, num_pos_feats)),
+        )
+
+    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+        """Positionally encode points that are normalized to [0,1]."""
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coords = 2 * coords - 1
+        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = 2 * np.pi * coords
+        # outputs d_1 x ... x d_n x C shape
+        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+        """Generate positional encoding for a grid of the specified size."""
+        h, w = size
+        device: Any = self.positional_encoding_gaussian_matrix.device
+        grid = torch.ones((h, w), device=device, dtype=torch.float32)
+        y_embed = grid.cumsum(dim=0) - 0.5
+        x_embed = grid.cumsum(dim=1) - 0.5
+        y_embed = y_embed / h
+        x_embed = x_embed / w
+
+        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+        return pe.permute(2, 0, 1)  # C x H x W
+
+    def forward_with_coords(
+        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+    ) -> torch.Tensor:
+        """Positionally encode points that are not normalized to [0,1]."""
+        coords = coords_input.clone()
+        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+        return self._pe_encoding(coords.to(torch.float))  # B x N x C
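+
+
+# Usage sketch (addition for illustration, not upstream SAM code); run as
+# `python -m osrt.sam.prompt_encoder`. Encodes a single foreground point into sparse
+# embeddings plus the default "no mask" dense embedding.
+if __name__ == "__main__":
+    encoder = PromptEncoder(
+        embed_dim=256,
+        image_embedding_size=(64, 64),
+        input_image_size=(1024, 1024),
+        mask_in_chans=16,
+    )
+    coords = torch.tensor([[[512.0, 512.0]]])  # B x N x 2, in input-image pixel coordinates
+    labels = torch.tensor([[1]])               # B x N, 1 = foreground, 0 = background
+    sparse, dense = encoder(points=(coords, labels), boxes=None, masks=None)
+    print(sparse.shape, dense.shape)  # (1, 2, 256) and (1, 256, 64, 64)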
diff --git a/osrt/sam/sam.py b/osrt/sam/sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..8074cff6b40addc6b66f7ab4962218eef20da13c
--- /dev/null
+++ b/osrt/sam/sam.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import Any, Dict, List, Tuple
+
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+
+
+class Sam(nn.Module):
+    mask_threshold: float = 0.0
+    image_format: str = "RGB"
+
+    def __init__(
+        self,
+        image_encoder: ImageEncoderViT,
+        prompt_encoder: PromptEncoder,
+        mask_decoder: MaskDecoder,
+        pixel_mean: List[float] = [123.675, 116.28, 103.53],
+        pixel_std: List[float] = [58.395, 57.12, 57.375],
+    ) -> None:
+        """
+        SAM predicts object masks from an image and input prompts.
+
+        Arguments:
+          image_encoder (ImageEncoderViT): The backbone used to encode the
+            image into image embeddings that allow for efficient mask prediction.
+          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
+            and encoded prompts.
+          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
+          pixel_std (list(float)): Std values for normalizing pixels in the input image.
+        """
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.prompt_encoder = prompt_encoder
+        self.mask_decoder = mask_decoder
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+    @property
+    def device(self) -> Any:
+        return self.pixel_mean.device
+
+    @torch.no_grad()
+    def forward(
+        self,
+        batched_input: List[Dict[str, Any]],
+        multimask_output: bool,
+    ) -> List[Dict[str, torch.Tensor]]:
+        """
+        Predicts masks end-to-end from provided images and prompts.
+        If prompts are not known in advance, using SamPredictor is
+        recommended over calling the model directly.
+
+        Arguments:
+          batched_input (list(dict)): A list over input images, each a
+            dictionary with the following keys. A prompt key can be
+            excluded if it is not present.
+              'image': The image as a torch tensor in 3xHxW format,
+                already transformed for input to the model.
+              'original_size': (tuple(int, int)) The original size of
+                the image before transformation, as (H, W).
+              'point_coords': (torch.Tensor) Batched point prompts for
+                this image, with shape BxNx2. Already transformed to the
+                input frame of the model.
+              'point_labels': (torch.Tensor) Batched labels for point prompts,
+                with shape BxN.
+              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
+                Already transformed to the input frame of the model.
+              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
+                in the form Bx1xHxW.
+          multimask_output (bool): Whether the model should predict multiple
+            disambiguating masks, or return a single mask.
+
+        Returns:
+          (list(dict)): A list over input images, where each element is
+            a dictionary with the following keys.
+              'masks': (torch.Tensor) Batched binary mask predictions,
+                with shape BxCxHxW, where B is the number of input prompts,
+                C is determined by multimask_output, and (H, W) is the
+                original size of the image.
+              'iou_predictions': (torch.Tensor) The model's predictions
+                of mask quality, in shape BxC.
+              'low_res_logits': (torch.Tensor) Low resolution logits with
+                shape BxCxHxW, where H=W=256. Can be passed as mask input
+                to subsequent iterations of prediction.
+        """
+        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
+        image_embeddings = self.image_encoder(input_images)
+
+        outputs = []
+        for image_record, curr_embedding in zip(batched_input, image_embeddings):
+            if "point_coords" in image_record:
+                points = (image_record["point_coords"], image_record["point_labels"])
+            else:
+                points = None
+            sparse_embeddings, dense_embeddings = self.prompt_encoder(
+                points=points,
+                boxes=image_record.get("boxes", None),
+                masks=image_record.get("mask_inputs", None),
+            )
+            low_res_masks, iou_predictions = self.mask_decoder(
+                image_embeddings=curr_embedding.unsqueeze(0),
+                image_pe=self.prompt_encoder.get_dense_pe(),
+                sparse_prompt_embeddings=sparse_embeddings,
+                dense_prompt_embeddings=dense_embeddings,
+                multimask_output=multimask_output,
+            )
+            masks = self.postprocess_masks(
+                low_res_masks,
+                input_size=image_record["image"].shape[-2:],
+                original_size=image_record["original_size"],
+            )
+            masks = masks > self.mask_threshold
+            outputs.append(
+                {
+                    "masks": masks,
+                    "iou_predictions": iou_predictions,
+                    "low_res_logits": low_res_masks,
+                }
+            )
+        return outputs
+
+    def postprocess_masks(
+        self,
+        masks: torch.Tensor,
+        input_size: Tuple[int, ...],
+        original_size: Tuple[int, ...],
+    ) -> torch.Tensor:
+        """
+        Remove padding and upscale masks to the original image size.
+
+        Arguments:
+          masks (torch.Tensor): Batched masks from the mask_decoder,
+            in BxCxHxW format.
+          input_size (tuple(int, int)): The size of the image input to the
+            model, in (H, W) format. Used to remove padding.
+          original_size (tuple(int, int)): The original size of the image
+            before resizing for input to the model, in (H, W) format.
+
+        Returns:
+          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
+            is given by original_size.
+        """
+        masks = F.interpolate(
+            masks,
+            (self.image_encoder.img_size, self.image_encoder.img_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+        masks = masks[..., : input_size[0], : input_size[1]]
+        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
+        return masks
+
+    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+        """Normalize pixel values and pad to a square input."""
+        # Normalize colors
+        x = (x - self.pixel_mean) / self.pixel_std
+
+        # Pad
+        h, w = x.shape[-2:]
+        padh = self.image_encoder.img_size - h
+        padw = self.image_encoder.img_size - w
+        x = F.pad(x, (0, padw, 0, padh))
+        return x
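+
+
+# Input-format sketch (addition for illustration, not upstream SAM code): the structure
+# Sam.forward expects for one image with a single point prompt. Values are random
+# placeholders; a real pipeline would first transform the image and prompts to the
+# model's input frame (e.g. with ResizeLongestSide from osrt.sam.utils.transforms).
+if __name__ == "__main__":
+    batched_input = [
+        {
+            "image": torch.randint(0, 255, (3, 1024, 1024)).float(),  # 3xHxW, model input frame
+            "original_size": (768, 1024),                             # (H, W) before resizing
+            "point_coords": torch.tensor([[[512.0, 512.0]]]),         # BxNx2
+            "point_labels": torch.tensor([[1]]),                      # BxN, 1 = foreground
+        }
+    ]
+    # With a constructed model: outputs = sam(batched_input, multimask_output=True)
+    print(list(batched_input[0].keys()))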
diff --git a/osrt/sam/transformer.py b/osrt/sam/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..28fafea52288603fea275f3a100790471825c34a
--- /dev/null
+++ b/osrt/sam/transformer.py
@@ -0,0 +1,240 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor, nn
+
+import math
+from typing import Tuple, Type
+
+from .common import MLPBlock
+
+
+class TwoWayTransformer(nn.Module):
+    def __init__(
+        self,
+        depth: int,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+    ) -> None:
+        """
+        A transformer decoder that attends to an input image using
+        queries whose positional embedding is supplied.
+
+        Args:
+          depth (int): number of layers in the transformer
+          embedding_dim (int): the channel dimension for the input embeddings
+          num_heads (int): the number of heads for multihead attention. Must
+            divide embedding_dim
+          mlp_dim (int): the channel dimension internal to the MLP block
+          activation (nn.Module): the activation to use in the MLP block
+        """
+        super().__init__()
+        self.depth = depth
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        self.mlp_dim = mlp_dim
+        self.layers = nn.ModuleList()
+
+        for i in range(depth):
+            self.layers.append(
+                TwoWayAttentionBlock(
+                    embedding_dim=embedding_dim,
+                    num_heads=num_heads,
+                    mlp_dim=mlp_dim,
+                    activation=activation,
+                    attention_downsample_rate=attention_downsample_rate,
+                    skip_first_layer_pe=(i == 0),
+                )
+            )
+
+        self.final_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+    def forward(
+        self,
+        image_embedding: Tensor,
+        image_pe: Tensor,
+        point_embedding: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+          image_embedding (torch.Tensor): image to attend to. Should be shape
+            B x embedding_dim x h x w for any h and w.
+          image_pe (torch.Tensor): the positional encoding to add to the image. Must
+            have the same shape as image_embedding.
+          point_embedding (torch.Tensor): the embedding to add to the query points.
+            Must have shape B x N_points x embedding_dim for any N_points.
+
+        Returns:
+          torch.Tensor: the processed point_embedding
+          torch.Tensor: the processed image_embedding
+        """
+        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+        bs, c, h, w = image_embedding.shape
+        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+        image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+        # Prepare queries
+        queries = point_embedding
+        keys = image_embedding
+
+        # Apply transformer blocks and final layernorm
+        for layer in self.layers:
+            queries, keys = layer(
+                queries=queries,
+                keys=keys,
+                query_pe=point_embedding,
+                key_pe=image_pe,
+            )
+
+        # Apply the final attention layer from the points to the image
+        q = queries + point_embedding
+        k = keys + image_pe
+        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm_final_attn(queries)
+
+        return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int = 2048,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+        skip_first_layer_pe: bool = False,
+    ) -> None:
+        """
+        A transformer block with four layers: (1) self-attention of sparse
+        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+        block on sparse inputs, and (4) cross attention of dense inputs to sparse
+        inputs.
+
+        Arguments:
+          embedding_dim (int): the channel dimension of the embeddings
+          num_heads (int): the number of heads in the attention layers
+          mlp_dim (int): the hidden dimension of the mlp block
+          activation (nn.Module): the activation of the mlp block
+          skip_first_layer_pe (bool): skip the PE on the first layer
+        """
+        super().__init__()
+        self.self_attn = Attention(embedding_dim, num_heads)
+        self.norm1 = nn.LayerNorm(embedding_dim)
+
+        self.cross_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+        self.norm3 = nn.LayerNorm(embedding_dim)
+
+        self.norm4 = nn.LayerNorm(embedding_dim)
+        self.cross_attn_image_to_token = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+
+        self.skip_first_layer_pe = skip_first_layer_pe
+
+    def forward(
+        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        # Self attention block
+        if self.skip_first_layer_pe:
+            queries = self.self_attn(q=queries, k=queries, v=queries)
+        else:
+            q = queries + query_pe
+            attn_out = self.self_attn(q=q, k=q, v=queries)
+            queries = queries + attn_out
+        queries = self.norm1(queries)
+
+        # Cross attention block, tokens attending to image embedding
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm2(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.norm3(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+        keys = keys + attn_out
+        keys = self.norm4(keys)
+
+        return queries, keys
+
+
+class Attention(nn.Module):
+    """
+    An attention layer that allows for downscaling the size of the embedding
+    after projection to queries, keys, and values.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        downsample_rate: int = 1,
+    ) -> None:
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.internal_dim = embedding_dim // downsample_rate
+        self.num_heads = num_heads
+        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
+
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+        b, n, c = x.shape
+        x = x.reshape(b, n, num_heads, c // num_heads)
+        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+
+    def _recombine_heads(self, x: Tensor) -> Tensor:
+        b, n_heads, n_tokens, c_per_head = x.shape
+        x = x.transpose(1, 2)
+        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        # Attention
+        _, _, _, c_per_head = q.shape
+        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
+        attn = attn / math.sqrt(c_per_head)
+        attn = torch.softmax(attn, dim=-1)
+
+        # Get output
+        out = attn @ v
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
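+
+
+# Shape sketch (addition for illustration, not upstream SAM code): the two-way transformer
+# refines query/point tokens against flattened image tokens and returns both streams.
+if __name__ == "__main__":
+    t = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048)
+    image_embedding = torch.randn(1, 256, 64, 64)
+    image_pe = torch.randn(1, 256, 64, 64)
+    point_embedding = torch.randn(1, 7, 256)
+    queries, keys = t(image_embedding, image_pe, point_embedding)
+    print(queries.shape, keys.shape)  # (1, 7, 256) and (1, 4096, 256)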
diff --git a/osrt/sam/utils/__init__.py b/osrt/sam/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/osrt/sam/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/osrt/sam/utils/amg.py b/osrt/sam/utils/amg.py
new file mode 100644
index 0000000000000000000000000000000000000000..b114d37a3f7acae1b2cef8f6724f48b5950c4cbe
--- /dev/null
+++ b/osrt/sam/utils/amg.py
@@ -0,0 +1,347 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+
+class MaskData:
+    """
+    A structure for storing masks and their related data in batched format.
+    Implements basic filtering and concatenation.
+    """
+
+    def __init__(self, **kwargs) -> None:
+        for v in kwargs.values():
+            assert isinstance(
+                v, (list, np.ndarray, torch.Tensor)
+            ), "MaskData only supports list, numpy arrays, and torch tensors."
+        self._stats = dict(**kwargs)
+
+    def __setitem__(self, key: str, item: Any) -> None:
+        assert isinstance(
+            item, (list, np.ndarray, torch.Tensor)
+        ), "MaskData only supports list, numpy arrays, and torch tensors."
+        self._stats[key] = item
+
+    def __delitem__(self, key: str) -> None:
+        del self._stats[key]
+
+    def __getitem__(self, key: str) -> Any:
+        return self._stats[key]
+
+    def items(self) -> ItemsView[str, Any]:
+        return self._stats.items()
+
+    def filter(self, keep: torch.Tensor) -> None:
+        for k, v in self._stats.items():
+            if v is None:
+                self._stats[k] = None
+            elif isinstance(v, torch.Tensor):
+                self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+            elif isinstance(v, np.ndarray):
+                self._stats[k] = v[keep.detach().cpu().numpy()]
+            elif isinstance(v, list) and keep.dtype == torch.bool:
+                self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+            elif isinstance(v, list):
+                self._stats[k] = [v[i] for i in keep]
+            else:
+                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+    def cat(self, new_stats: "MaskData") -> None:
+        for k, v in new_stats.items():
+            if k not in self._stats or self._stats[k] is None:
+                self._stats[k] = deepcopy(v)
+            elif isinstance(v, torch.Tensor):
+                self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+            elif isinstance(v, np.ndarray):
+                self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+            elif isinstance(v, list):
+                self._stats[k] = self._stats[k] + deepcopy(v)
+            else:
+                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+    def to_numpy(self) -> None:
+        for k, v in self._stats.items():
+            if isinstance(v, torch.Tensor):
+                self._stats[k] = v.detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+    """Filter masks at the edge of a crop, but not at the edge of the original image."""
+    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+    boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+    return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+    box_xywh = deepcopy(box_xyxy)
+    box_xywh[2] = box_xywh[2] - box_xywh[0]
+    box_xywh[3] = box_xywh[3] - box_xywh[1]
+    return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+    assert len(args) > 0 and all(
+        len(a) == len(args[0]) for a in args
+    ), "Batched iteration must have inputs of all the same size."
+    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+    for b in range(n_batches):
+        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+    """
+    Encodes masks to an uncompressed RLE, in the format expected by
+    pycoco tools.
+    """
+    # Put in fortran order and flatten h,w
+    b, h, w = tensor.shape
+    tensor = tensor.permute(0, 2, 1).flatten(1)
+
+    # Compute change indices
+    diff = tensor[:, 1:] ^ tensor[:, :-1]
+    change_indices = diff.nonzero()
+
+    # Encode run length
+    out = []
+    for i in range(b):
+        cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+        cur_idxs = torch.cat(
+            [
+                torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+                cur_idxs + 1,
+                torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+            ]
+        )
+        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+        counts = [] if tensor[i, 0] == 0 else [0]
+        counts.extend(btw_idxs.detach().cpu().tolist())
+        out.append({"size": [h, w], "counts": counts})
+    return out
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+    """Compute a binary mask from an uncompressed RLE."""
+    h, w = rle["size"]
+    mask = np.empty(h * w, dtype=bool)
+    idx = 0
+    parity = False
+    for count in rle["counts"]:
+        mask[idx : idx + count] = parity
+        idx += count
+        parity ^= True
+    mask = mask.reshape(w, h)
+    return mask.transpose()  # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+    return sum(rle["counts"][1::2])
+
+
+def calculate_stability_score(
+    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+    """
+    Computes the stability score for a batch of masks. The stability
+    score is the IoU between the binary masks obtained by thresholding
+    the predicted mask logits at high and low values.
+    """
+    # One mask is always contained inside the other.
+    # Save memory by preventing unnecessary cast to torch.int64
+    intersections = (
+        (masks > (mask_threshold + threshold_offset))
+        .sum(-1, dtype=torch.int16)
+        .sum(-1, dtype=torch.int32)
+    )
+    unions = (
+        (masks > (mask_threshold - threshold_offset))
+        .sum(-1, dtype=torch.int16)
+        .sum(-1, dtype=torch.int32)
+    )
+    return intersections / unions
+
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+    offset = 1 / (2 * n_per_side)
+    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+    return points
+
+
+def build_all_layer_point_grids(
+    n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+    """Generates point grids for all crop layers."""
+    points_by_layer = []
+    for i in range(n_layers + 1):
+        n_points = int(n_per_side / (scale_per_layer**i))
+        points_by_layer.append(build_point_grid(n_points))
+    return points_by_layer
+
+
+def generate_crop_boxes(
+    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+    """
+    Generates a list of crop boxes of different sizes. Each layer
+    has (2**i)**2 boxes for the ith layer.
+    """
+    crop_boxes, layer_idxs = [], []
+    im_h, im_w = im_size
+    short_side = min(im_h, im_w)
+
+    # Original image
+    crop_boxes.append([0, 0, im_w, im_h])
+    layer_idxs.append(0)
+
+    def crop_len(orig_len, n_crops, overlap):
+        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+    for i_layer in range(n_layers):
+        n_crops_per_side = 2 ** (i_layer + 1)
+        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+        crop_w = crop_len(im_w, n_crops_per_side, overlap)
+        crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+        # Crops in XYXY format, clipped to the image bounds
+        for x0, y0 in product(crop_box_x0, crop_box_y0):
+            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+            crop_boxes.append(box)
+            layer_idxs.append(i_layer + 1)
+
+    return crop_boxes, layer_idxs
+
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+    # Check if boxes has a channel dimension
+    if len(boxes.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0]], device=points.device)
+    # Check if points has a channel dimension
+    if len(points.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return points + offset
+
+
+def uncrop_masks(
+    masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+    x0, y0, x1, y1 = crop_box
+    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+        return masks
+    # Coordinate transform masks
+    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+    pad = (x0, pad_x - x0, y0, pad_y - y0)
+    return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+    mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+    """
+    Removes small disconnected regions and holes in a mask. Returns the
+    mask and an indicator of whether the mask has been modified.
+    """
+    import cv2  # type: ignore
+
+    assert mode in ["holes", "islands"]
+    correct_holes = mode == "holes"
+    working_mask = (correct_holes ^ mask).astype(np.uint8)
+    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+    sizes = stats[:, -1][1:]  # Row 0 is background label
+    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+    if len(small_regions) == 0:
+        return mask, False
+    fill_labels = [0] + small_regions
+    if not correct_holes:
+        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+        # If every region is below threshold, keep largest
+        if len(fill_labels) == 0:
+            fill_labels = [int(np.argmax(sizes)) + 1]
+    mask = np.isin(regions, fill_labels)
+    return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+    from pycocotools import mask as mask_utils  # type: ignore
+
+    h, w = uncompressed_rle["size"]
+    rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+    rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
+    return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates boxes in XYXY format around masks. Returns [0,0,0,0] for
+    an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+    """
+    # torch.max below raises an error on empty inputs, just skip in this case
+    if torch.numel(masks) == 0:
+        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+    # Normalize shape to CxHxW
+    shape = masks.shape
+    h, w = shape[-2:]
+    if len(shape) > 2:
+        masks = masks.flatten(0, -3)
+    else:
+        masks = masks.unsqueeze(0)
+
+    # Get top and bottom edges
+    in_height, _ = torch.max(masks, dim=-1)
+    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+    in_height_coords = in_height_coords + h * (~in_height)
+    top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+    # Get left and right edges
+    in_width, _ = torch.max(masks, dim=-2)
+    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+    right_edges, _ = torch.max(in_width_coords, dim=-1)
+    in_width_coords = in_width_coords + w * (~in_width)
+    left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+    # If the mask is empty the right edge will be to the left of the left edge.
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+    out = out * (~empty_filter).unsqueeze(-1)
+
+    # Return to original shape
+    if len(shape) > 2:
+        out = out.reshape(*shape[:-2], 4)
+    else:
+        out = out[0]
+
+    return out
diff --git a/osrt/sam/utils/onnx.py b/osrt/sam/utils/onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..3196bdf4b782e6eeb3da4ad66ef3c7b1741535fe
--- /dev/null
+++ b/osrt/sam/utils/onnx.py
@@ -0,0 +1,144 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from typing import Tuple
+
+from ..sam import Sam
+from .amg import calculate_stability_score
+
+
+class SamOnnxModel(nn.Module):
+    """
+    This model should not be called directly, but is used in ONNX export.
+    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
+    with some functions modified to enable model tracing. Also supports extra
+    options controlling what information is returned in the output. See the
+    ONNX export script for details.
+    """
+
+    def __init__(
+        self,
+        model: Sam,
+        return_single_mask: bool,
+        use_stability_score: bool = False,
+        return_extra_metrics: bool = False,
+    ) -> None:
+        super().__init__()
+        self.mask_decoder = model.mask_decoder
+        self.model = model
+        self.img_size = model.image_encoder.img_size
+        self.return_single_mask = return_single_mask
+        self.use_stability_score = use_stability_score
+        self.stability_score_offset = 1.0
+        self.return_extra_metrics = return_extra_metrics
+
+    @staticmethod
+    def resize_longest_image_size(
+        input_image_size: torch.Tensor, longest_side: int
+    ) -> torch.Tensor:
+        input_image_size = input_image_size.to(torch.float32)
+        scale = longest_side / torch.max(input_image_size)
+        transformed_size = scale * input_image_size
+        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
+        return transformed_size
+
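+    # Point label convention (as used by SAM prompts): 1 = foreground point,
+    # 0 = background point, 2/3 = box corners, -1 = padding mapped to
+    # not_a_point_embed below.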
+    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
+        point_coords = point_coords + 0.5
+        point_coords = point_coords / self.img_size
+        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
+        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+        point_embedding = point_embedding * (point_labels != -1)
+        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
+            point_labels == -1
+        )
+
+        for i in range(self.model.prompt_encoder.num_point_embeddings):
+            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
+                i
+            ].weight * (point_labels == i)
+
+        return point_embedding
+
+    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
+        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
+        mask_embedding = mask_embedding + (
+            1 - has_mask_input
+        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+        return mask_embedding
+
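+    # Illustrative walk-through: with img_size=1024 and orig_im_size=(480, 640),
+    # masks are upsampled to 1024x1024, cropped to the pre-padding size (768, 1024),
+    # and finally resized to the original 480x640 resolution.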
+    def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
+        masks = F.interpolate(
+            masks,
+            size=(self.img_size, self.img_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
+        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore
+
+        orig_im_size = orig_im_size.to(torch.int64)
+        h, w = orig_im_size[0], orig_im_size[1]
+        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
+        return masks
+
+    def select_masks(
+        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Determine if we should return the multiclick mask or not from the number of points.
+        # The reweighting is used to avoid control flow.
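+        # With one point, (num_points - 2.5) is negative, so mask token 0 (SAM's
+        # single, unambiguous output) is penalized and a multimask output wins; with
+        # three or more points the sign flips and token 0 is preferred.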
+        score_reweight = torch.tensor(
+            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
+        ).to(iou_preds.device)
+        score = iou_preds + (num_points - 2.5) * score_reweight
+        best_idx = torch.argmax(score, dim=1)
+        masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1)
+        iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1)
+
+        return masks, iou_preds
+
+    @torch.no_grad()
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        point_coords: torch.Tensor,
+        point_labels: torch.Tensor,
+        mask_input: torch.Tensor,
+        has_mask_input: torch.Tensor,
+        orig_im_size: torch.Tensor,
+    ):
+        sparse_embedding = self._embed_points(point_coords, point_labels)
+        dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
+        masks, scores = self.model.mask_decoder.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=self.model.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embedding,
+            dense_prompt_embeddings=dense_embedding,
+        )
+
+        if self.use_stability_score:
+            scores = calculate_stability_score(
+                masks, self.model.mask_threshold, self.stability_score_offset
+            )
+
+        if self.return_single_mask:
+            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
+
+        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
+
+        if self.return_extra_metrics:
+            stability_scores = calculate_stability_score(
+                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
+            )
+            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
+            return upscaled_masks, scores, stability_scores, areas, masks
+
+        return upscaled_masks, scores, masks
diff --git a/osrt/sam/utils/transforms.py b/osrt/sam/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..e801bdf1a342f509c0a32d92b032ad369f0e2908
--- /dev/null
+++ b/osrt/sam/utils/transforms.py
@@ -0,0 +1,105 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
+
+from copy import deepcopy
+from typing import Tuple
+
+
+class ResizeLongestSide:
+    """
+    Resizes images so that their longest side equals 'target_length', and provides
+    methods for resizing coordinates and boxes to match. Both numpy arrays and
+    batched torch tensors are supported.
+    """
+
+    def __init__(self, target_length: int) -> None:
+        self.target_length = target_length
+
+    def apply_image(self, image: np.ndarray) -> np.ndarray:
+        """
+        Expects a numpy array with shape HxWxC in uint8 format; an HxWxC torch
+        tensor is also accepted and converted through PIL.
+        """
+        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+        # TODO: resize tensors directly instead of converting through PIL
+        if torch.is_tensor(image):
+            return np.array(resize(to_pil_image(image.permute(2, 0, 1)), target_size))
+        return np.array(resize(to_pil_image(image), target_size))
+
+    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+        """
+        Expects a numpy array of length 2 in the final dimension. Requires the
+        original image size in (H, W) format.
+        """
+        old_h, old_w = original_size
+        new_h, new_w = self.get_preprocess_shape(
+            original_size[0], original_size[1], self.target_length
+        )
+        coords = deepcopy(coords).astype(float)
+        coords[..., 0] = coords[..., 0] * (new_w / old_w)
+        coords[..., 1] = coords[..., 1] * (new_h / old_h)
+        return coords
+
+    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+        """
+        Expects a numpy array shape Bx4. Requires the original image size
+        in (H, W) format.
+        """
+        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+        return boxes.reshape(-1, 4)
+
+    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Expects batched images with shape BxCxHxW and float format. This
+        transformation may not exactly match apply_image. apply_image is
+        the transformation expected by the model.
+        """
+        # Expects an image in BCHW format. May not exactly match apply_image.
+        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
+        return F.interpolate(
+            image, target_size, mode="bilinear", align_corners=False, antialias=True
+        )
+
+    def apply_coords_torch(
+        self, coords: torch.Tensor, original_size: Tuple[int, ...]
+    ) -> torch.Tensor:
+        """
+        Expects a torch tensor with length 2 in the last dimension. Requires the
+        original image size in (H, W) format.
+        """
+        old_h, old_w = original_size
+        new_h, new_w = self.get_preprocess_shape(
+            original_size[0], original_size[1], self.target_length
+        )
+        coords = deepcopy(coords).to(torch.float)
+        coords[..., 0] = coords[..., 0] * (new_w / old_w)
+        coords[..., 1] = coords[..., 1] * (new_h / old_h)
+        return coords
+
+    def apply_boxes_torch(
+        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+    ) -> torch.Tensor:
+        """
+        Expects a torch tensor with shape Bx4. Requires the original image
+        size in (H, W) format.
+        """
+        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+        return boxes.reshape(-1, 4)
+
+    @staticmethod
+    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+        """
+        Compute the output size given input size and target long side length.
+        """
+        scale = long_side_length * 1.0 / max(oldh, oldw)
+        newh, neww = oldh * scale, oldw * scale
+        neww = int(neww + 0.5)
+        newh = int(newh + 0.5)
+        return (newh, neww)
diff --git a/osrt/utils/visualize.py b/osrt/utils/visualize.py
index 2bdf364d79ca3359905adc2a8699462491db0265..53c59485edf7584ef3ef822b94903d7bd9f49f20 100644
--- a/osrt/utils/visualize.py
+++ b/osrt/utils/visualize.py
@@ -29,7 +29,7 @@ def get_clustering_colors(num_colors):
     colors = [(0., 0., 0.)]
     for i in range(num_colors):
         colors.append(hsv_to_rgb(i / num_colors, 0.45, 0.8))
-    colors = np.array(colors.cpu())
+    colors = np.array(colors)
     return colors
 
 
diff --git a/runs/clevr3d/osrt/config_ds.json b/runs/clevr3d/osrt/config_ds.json
index 9e96b5ef35f5de375fccb63cb221a0fe6bf3c9b4..f25876b8556371265b538fef5604b6f121498842 100644
--- a/runs/clevr3d/osrt/config_ds.json
+++ b/runs/clevr3d/osrt/config_ds.json
@@ -31,17 +31,23 @@
         "max_it": 333000000,
         "decay_it": 4000000,
         "lr_warmup": 5000,
-        "precision": "32"
+        "precision": "16-mixed"
     },
     "strategy": {
         "deepspeed": {
+            "fp16": {
+                "enabled": true,
+                "auto_cast": true
+            },
             "zero_optimization": {
-                "stage": 2,  
-                "contiguous_gradients": true, 
+                "stage": 3,  
                 "overlap_comm": true,  
                 "allgather_bucket_size": 2e8,  
                 "reduce_bucket_size": 2e8
             },
+            "activation_checkpointing": {
+                "partition_activations": true
+            },
             "zero_allow_untested_optimizer": true
         }
     }
diff --git a/segment-anything b/segment-anything
deleted file mode 160000
index f7b29ba9df1496489af8c71a4bdabed7e8b017b1..0000000000000000000000000000000000000000
--- a/segment-anything
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit f7b29ba9df1496489af8c71a4bdabed7e8b017b1