Commit ece9ede3 authored by Alexandre Chapin

Try to adapt code from SAM lightning

parent 03a2236d
@@ -51,6 +51,9 @@ class OSRT(nn.Module):
            self.decoder = SlotMixerDecoder(**cfg['decoder_kwargs'])
        else:
            raise ValueError(f'Unknown decoder type: {decoder_type}')

        self.encoder.train()
        self.decoder.train()


class LitOSRT(pl.LightningModule):
    def __init__(self, encoder: nn.Module, decoder: nn.Module, cfg: Dict, extract_masks: bool = False):
......
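The body of LitOSRT is collapsed above. As a point of reference only, here is a minimal sketch (not part of the commit) of how a Lightning wrapper along these lines could tie the OSRT encoder/decoder to the FocalLoss and DiceLoss added later in this diff; the import path, batch layout, call signatures, loss weighting, and optimizer settings are all assumptions made for illustration.

import pytorch_lightning as pl
import torch
import torch.nn as nn
from typing import Dict
from osrt.losses import FocalLoss, DiceLoss   # losses file added in this commit; path is an assumption

class LitOSRTSketch(pl.LightningModule):
    """Hypothetical wrapper; only the constructor signature mirrors the diff."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module, cfg: Dict, extract_masks: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.cfg = cfg
        self.extract_masks = extract_masks
        self.focal_loss = FocalLoss()
        self.dice_loss = DiceLoss()

    def training_step(self, batch, batch_idx):
        images, rays, target_masks = batch           # assumed batch layout
        slots = self.encoder(images)                 # assumed call signatures
        pred_masks = self.decoder(slots, rays)
        # Illustrative weighting of the two segmentation losses.
        loss = 20. * self.focal_loss(pred_masks, target_masks) + self.dice_loss(pred_masks, target_masks)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.cfg.get('lr', 1e-4))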
@@ -10,7 +10,7 @@ from torch.nn import functional as F
from typing import Tuple
-from ..modeling import Sam
+from .. import Sam
from .amg import calculate_stability_score
......
@@ -11,7 +11,29 @@ from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
from copy import deepcopy
from typing import Tuple
import torchvision.transforms as transforms
class ResizeAndPad:
    def __init__(self, target_size):
        self.target_size = target_size
        self.transform = ResizeLongestSide(target_size)
        self.to_tensor = transforms.ToTensor()

    def __call__(self, image):
        # Resize image
        image = self.transform.apply_image(image)
        image = self.to_tensor(image)

        # Pad image to form a square
        _, h, w = image.shape
        max_dim = max(w, h)
        pad_w = (max_dim - w) // 2
        pad_h = (max_dim - h) // 2
        padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h)
        image = transforms.Pad(padding)(image)

        return image
class ResizeLongestSide:
"""
......
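A brief usage sketch for the ResizeAndPad transform added above (not part of the commit); it assumes the input is an H x W x C uint8 numpy array, which is what ResizeLongestSide.apply_image operates on, and the SAM-style target size of 1024.

import numpy as np

resize_and_pad = ResizeAndPad(target_size=1024)
dummy_image = np.zeros((480, 640, 3), dtype=np.uint8)   # H x W x C uint8 input
padded = resize_and_pad(dummy_image)
# The longest side is resized to 1024 (640x480 -> 1024x768), the result is
# converted to a CHW tensor, then padded symmetrically to a square:
print(padded.shape)   # torch.Size([3, 1024, 1024])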
"""
Code extracted from : https://github.com/luca-medeiros/lightning-sam/blob/main/lightning_sam/losses.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
ALPHA = 0.8
GAMMA = 2
class FocalLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super().__init__()
def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
inputs = F.sigmoid(inputs)
inputs = torch.clamp(inputs, min=0, max=1)
#flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)
BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
BCE_EXP = torch.exp(-BCE)
focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE
return focal_loss
class DiceLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super().__init__()
def forward(self, inputs, targets, smooth=1):
inputs = F.sigmoid(inputs)
inputs = torch.clamp(inputs, min=0, max=1)
#flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)
intersection = (inputs * targets).sum()
dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
return 1 - dice
\ No newline at end of file
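A quick sanity-check sketch for the two losses above (not part of the commit); shapes and values are illustrative only.

import torch

focal = FocalLoss()
dice = DiceLoss()
logits = torch.randn(2, 1, 64, 64)                        # raw mask predictions (pre-sigmoid)
targets = torch.randint(0, 2, (2, 1, 64, 64)).float()     # binary ground-truth masks
print(focal(logits, targets).item(), dice(logits, targets).item())
# In lightning-sam these terms are combined as a weighted sum; the exact
# weights are training hyperparameters rather than part of the loss classes.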
@@ -16,6 +16,24 @@ from collections import defaultdict
import time
import os
class AverageMeter:
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def compute_loss(
        batch,
        batch_idx: int,
......
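AverageMeter is a standard running-average tracker for per-batch metrics; a short usage sketch (not part of the commit):

loss_meter = AverageMeter()
for batch_loss in [0.9, 0.7, 0.5]:        # dummy per-batch losses
    loss_meter.update(batch_loss, n=1)    # n = number of samples in the batch
print(loss_meter.val, loss_meter.avg)     # last value 0.5, running average 0.7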