Commit ece9ede3 authored by Alexandre Chapin

Try to adapt code from SAM lightning

parent 03a2236d
@@ -20,7 +20,7 @@ from osrt.sam import sam_model_registry
from osrt.sam.image_encoder import ImageEncoderViT
from osrt.sam.mask_decoder import MaskDecoder
from osrt.sam.prompt_encoder import PromptEncoder
from osrt.sam.utils.transforms import ResizeLongestSide
from osrt.sam.utils.transforms import ResizeLongestSide, ResizeAndPad
from osrt.sam.utils.amg import batched_mask_to_box, remove_small_regions, mask_to_rle_pytorch, area_from_rle
from torch.nn.utils.rnn import pad_sequence
@@ -115,150 +115,30 @@ class OSRTEncoder(nn.Module):
class FeatureMasking(nn.Module):
def __init__(self,
points_per_side=12,
box_nms_thresh = 0.7,
stability_score_thresh = 0.9,
pred_iou_thresh=0.88,
points_per_batch=64,
min_mask_region_area=0,
num_slots=32,
slot_dim=1536,
slot_iters=3,
sam_model="default",
sam_path="sam_vit_h_4b8939.pth",
randomize_initial_slots=False):
points_per_side: Optional[int]=12,
box_nms_thresh: float = 0.7,
stability_score_thresh: float = 0.9,
stability_score_offset: float = 1.0,
pred_iou_thresh: float =0.88,
points_per_batch: int =64,
min_mask_region_area: int=0,
num_slots: int=32,
slot_dim: int=1536,
slot_iters: int=3,
sam_model="default",
sam_path="sam_vit_h_4b8939.pth",
randomize_initial_slots=False):
super().__init__()
# We first initialize the automatic mask generator from SAM
# TODO : change the loading here !!!!
sam = sam_model_registry[sam_model](checkpoint=sam_path)
self.mask_generator = SamAutomaticMask(copy.deepcopy(sam.image_encoder),
copy.deepcopy(sam.prompt_encoder),
copy.deepcopy(sam.mask_decoder),
box_nms_thresh=box_nms_thresh,
stability_score_thresh = stability_score_thresh,
pred_iou_thresh=pred_iou_thresh,
points_per_side=points_per_side,
points_per_batch=points_per_batch,
min_mask_region_area=min_mask_region_area)
del sam
self.slot_attention = SlotAttention(num_slots, input_dim=self.mask_generator.token_dim, slot_dim=slot_dim, iters=slot_iters,
randomize_initial_slots=randomize_initial_slots)
def forward(self, images, original_size=None, camera_pos=None, rays=None, extract_masks=False):
"""
Args:
images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
original_size: tuple(height, width) The original size of the image before transformation.
camera_pos: [batch_size, num_images, 3]
rays: [batch_size, num_images, height, width, 3]
Returns:
annotations: [batch_size, num_images] An array containing a dict for each image in each batch
with the following annotations:
segmentation
area
predicted_iou
point_coords
stability_score
embeddings (Optional)
"""
# Generate masks for every image
masks = self.mask_generator(images, original_size, camera_pos, rays, extract_embeddings=True) # [B, N]
B, N = masks.shape
dim = masks[0][0]["embeddings"][0].shape[0]
# We infer each batch separately as it handles a different number of slots
set_latents = None
# TODO : set the number of slots depending on whether we want the min or the max
#with torch.no_grad():
#num_slots = 100000
embedding_batch = []
masks_batch = []
for b in range(B):
latents_batch = torch.empty((0, dim), device=self.mask_generator.device)
for n in range(N):
embeds = masks[b][n]["embeddings"]
#num_slots = min(len(embeds), num_slots)
for embed in embeds:
latents_batch = torch.cat((latents_batch, embed.unsqueeze(0)), 0)
masks_batch.append(torch.zeros(latents_batch.shape[:1]))
embedding_batch.append(latents_batch)
set_latents = pad_sequence(embedding_batch, batch_first=True, padding_value=0.0)
attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=1.0)
# [batch_size, num_inputs = num_mask_embed x num_im, dim]
#self.slot_attention.change_slots_number(num_slots)
slot_latents = self.slot_attention(set_latents, attention_mask)
if extract_masks:
return masks, slot_latents
else:
del masks
return slot_latents
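# Illustrative sketch (not part of the diff): how pad_sequence builds the padded
# slot-attention input and the matching attention mask from a variable number of
# mask embeddings per batch element. Shapes and values below are made up.
import torch
from torch.nn.utils.rnn import pad_sequence

dim = 4
embedding_batch = [torch.randn(3, dim), torch.randn(5, dim)]   # 3 and 5 mask embeddings
masks_batch = [torch.zeros(3), torch.zeros(5)]                 # 0 = real token

set_latents = pad_sequence(embedding_batch, batch_first=True, padding_value=0.0)  # [2, 5, 4]
attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=1.0)   # [2, 5], 1 = padded position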
class SamAutomaticMask(nn.Module):
mask_threshold: float = 0.0
def __init__(
self,
image_encoder: ImageEncoderViT,
prompt_encoder: PromptEncoder,
mask_decoder: MaskDecoder,
pixel_mean: List[float] = [123.675, 116.28, 103.53],
pixel_std: List[float] = [58.395, 57.12, 57.375],
token_dim = 1280,
points_per_side: Optional[int] = 0,
points_per_batch: int = 64,
pred_iou_thresh: float = 0.88,
stability_score_thresh: float = 0.95,
stability_score_offset: float = 1.0,
box_nms_thresh: float = 0.7,
min_mask_region_area: int = 0,
pos_start_octave=0,
token_dim = 1280,
pos_start_octave=0,
patch_size = 16
) -> None:
"""
This class adapts the SAM implementation from the original repository to our needs:
- Training only the MaskDecoder
- Performing automatic Mask Discovery (combined with AutomaticMask from original repo)
SAM predicts object masks from an image and input prompts.
Arguments:
image_encoder (ImageEncoderViT): The backbone used to encode the
image into image embeddings that allow for efficient mask prediction.
prompt_encoder (PromptEncoder): Encodes various types of input prompts.
mask_decoder (MaskDecoder): Predicts masks from the image embeddings
and encoded prompts.
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
pixel_std (list(float)): Std values for normalizing pixels in the input image.
"""
super().__init__()
# SAM part
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
# Freeze the image encoder and prompt encoder
"""for param in self.image_encoder.parameters():
param.requires_grad = False
for param in self.prompt_encoder.parameters():
param.requires_grad = False"""
self.mask_decoder = mask_decoder
"""for param in self.mask_decoder.parameters():
param.requires_grad = True"""
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
# Resize the image so its longest side matches the encoder input size (padding to a square happens in preprocess)
#self.resize = transforms.Resize(self.image_encoder.img_size, interpolation=transforms.InterpolationMode.BILINEAR)
self.transform = ResizeLongestSide(self.image_encoder.img_size)
# We first initialize the automatic mask generator from SAM
# TODO : change the loading here !!!!
self.mask_generator = sam_model_registry[sam_model](checkpoint=sam_path)
self.preprocess = ResizeAndPad(self.mask_generator.image_encoder.img_size)
self.resize = ResizeLongestSide(self.mask_generator.image_encoder.img_size)
if points_per_side > 0:
self.points_grid = create_points_grid(points_per_side)
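# Illustrative sketch (assumption, not the repo's actual helper): create_points_grid is
# taken here to mirror SAM's build_point_grid, i.e. an (n*n, 2) array of evenly spaced
# (x, y) coordinates in [0, 1] that forward() later multiplies by the resized image size.
import numpy as np

def points_grid_sketch(n_per_side: int) -> np.ndarray:
    offset = 1.0 / (2 * n_per_side)                     # keep points off the image border
    coords = np.linspace(offset, 1.0 - offset, n_per_side)
    xx, yy = np.meshgrid(coords, coords)
    return np.stack([xx.ravel(), yy.ravel()], axis=-1)  # [n*n, 2]

grid = points_grid_sketch(12)                           # 144 candidate point prompts
points_for_image = grid * np.array([1024, 768])         # scale to a (W, H) image size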
@@ -271,7 +151,7 @@ class SamAutomaticMask(nn.Module):
self.box_nms_thresh = box_nms_thresh
self.min_mask_region_area = min_mask_region_area
self.patch_embed_dim = (self.image_encoder.img_size // patch_size)**2
self.patch_embed_dim = (self.mask_generator.image_encoder.img_size // patch_size)**2
self.token_dim = token_dim
self.tokenizer = nn.Sequential(
nn.Linear(self.patch_embed_dim, 3072),
@@ -284,132 +164,128 @@ class SamAutomaticMask(nn.Module):
# Space positional embedding
self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
ray_octaves=15)
self.slot_attention = SlotAttention(num_slots, input_dim=self.token_dim, slot_dim=slot_dim, iters=slot_iters,
randomize_initial_slots=randomize_initial_slots)
@property
def device(self) -> Any:
return self.pixel_mean.device
def forward(
self,
images,
orig_size,
camera_pos=None,
rays=None,
extract_embeddings = False):
def forward(self, images, camera_pos=None, rays=None, extract_masks=False):
"""
Args:
images: [batch_size, num_images, 3, height, width]. Assume the first image is canonical.
original_size: tuple(height, width) The original size of the image before transformation.
images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
camera_pos: [batch_size, num_images, 3]
rays: [batch_size, num_images, height, width, 3]
extract_masks: bool Whether to also return the extracted masks along with the slot latents
Returns:
annotations: [batch_size, num_images] An array containing a dict for each image in each batch
with the following annotations:
segmentation
area
predicted_iou
point_coords
stability_score
embeddings (Optional)
"""
##################################
# Extract masks with SAM #
##################################
### Get images size
B, N, H, W, C = images.shape
images = images.reshape(B*N, H, W, C) # [B x N, H, W, C]
orig_size = (H, W)
images = images.reshape(B*N, H, W, C) # [B x N, H, W, C]
im_size = self.resize(images[0]).shape[-3:-1]
# Pre-process the images for the image encoder
input_images = torch.stack([self.preprocess(self.transform.apply_image(x)) for x in images])
im_size = self.transform.apply_image(images[0]).shape[-3:-1]
### Pre-process images for the image encoder (Resize and Pad)
images = torch.stack([self.preprocess(x) for x in images])
### Encode images
image_embeddings, embed_no_red = self.mask_generator.image_encoder(images, before_channel_reduc=True) # [B x N, C, H, W]
i = 0
masks = np.empty((B, N), dtype=object)
points_scale = np.array(im_size)[None, ::-1]
points_for_image = self.points_grid * points_scale
with torch.no_grad():
image_embeddings, embed_no_red = self.image_encoder(input_images, before_channel_reduc=True) # [B x N, C, H, W]
annotations = np.empty((B, N), dtype=object)
i = 0
### Extract mask for each image in the batch
for curr_embedding, curr_emb_no_red in zip(image_embeddings, embed_no_red):
mask_data = MaskData()
data = MaskData()
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
batch_data = self.process_batch(points, im_size, curr_embedding, orig_size)
mask_data.cat(batch_data)
data.cat(batch_data)
del batch_data
del curr_embedding
# Remove duplicates
### Remove duplicates
keep_by_nms = batched_nms(
mask_data["boxes"].float(),
mask_data["iou_preds"],
torch.zeros_like(mask_data["boxes"][:, 0]), # categories
data["boxes"].float(),
data["iou_preds"],
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.box_nms_thresh,
)
mask_data.filter(keep_by_nms)
data.filter(keep_by_nms)
mask_data["segmentations"] = mask_data["masks"]
data["segmentations"] = data["masks"]
# Extract mask embeddings
if extract_embeddings:
self.tokenize(mask_data, curr_emb_no_red, im_size, scale_box=1.5)
# TODO : check the 3D positional encoding
#self.position_embeding_3d(mask_data["embeddings"], camera_pos[batch][num_im], rays[batch][num_im])
### Extract mask embeddings
self.tokenize(data, curr_emb_no_red, im_size, scale_box=1.5)
# TODO : Add 3D positional encoding in the process
#self.position_embeding_3d(mask_data["embeddings"], camera_pos[batch][num_im], rays[batch][num_im])
mask_data.to_numpy()
data.to_numpy()
# Filter small disconnected regions and holes in masks NOT USED
### Filter small disconnected regions and holes in masks NOT USED
if self.min_mask_region_area > 0:
mask_data = self.postprocess_small_regions(
mask_data,
data = self.postprocess_small_regions(
data,
self.min_mask_region_area,
self.box_nms_thresh,
)
# TODO : have a more efficient way to store the data
# Write mask records
if extract_embeddings:
curr_ann = {
"embeddings": mask_data["embeddings"],
"segmentations": mask_data["segmentations"]
}
else :
curr_ann = {
"segmentations": mask_data["segmentations"]
}
### Write mask records
curr_ann = {
"embeddings": data["embeddings"],
"segmentations": data["segmentations"],
"iou_preds": data["iou_preds"]
}
batch = math.floor((i / N))
num_im = i % N
annotations[batch][num_im] = curr_ann
masks[batch][num_im] = curr_ann
i+=1
return annotations # [B, N] : dicts containing annotations such as segmentation, area, or embeddings
def postprocess_masks(
self,
masks: torch.Tensor,
input_size: Tuple[int, ...],
original_size: Tuple[int, ...],
) -> torch.Tensor:
"""
Remove padding and upscale masks to the original image size.
### Get parameters of mask and embedding size
B, N = masks.shape
dim = masks[0][0]["embeddings"][0].shape[0]
Arguments:
masks (torch.Tensor): Batched masks from the mask_decoder,
in BxCxHxW format.
input_size (tuple(int, int)): The size of the image input to the
model, in (H, W) format. Used to remove padding.
original_size (tuple(int, int)): The original size of the image
before resizing for input to the model, in (H, W) format.
##################################
# Extract mask embeddings #
##################################
Returns:
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
is given by original_size.
"""
masks = F.interpolate(
masks,
(self.image_encoder.img_size, self.image_encoder.img_size),
mode="bilinear",
align_corners=True,
)
# TODO : set the number of slots depending on whether we want the min or max size (by default a static number of slots)
### Pad batches to have the same length
set_latents = None
embedding_batch = []
masks_batch = []
for b in range(B):
latents_batch = torch.empty((0, dim))
for n in range(N):
embeds = masks[b][n]["embeddings"]
for embed in embeds:
latents_batch = torch.cat((latents_batch, embed.unsqueeze(0)), 0)
masks_batch.append(torch.zeros(latents_batch.shape[:1]))
embedding_batch.append(latents_batch)
### Contains the embeddings values padded
set_latents = pad_sequence(embedding_batch, batch_first=True, padding_value=0.0)
### Contains the masks to avoid to apply attention to not embedded values
attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=1.0)
masks = masks[..., : input_size[0], : input_size[1]]
masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=True)
return masks
# [batch_size, num_inputs = num_mask_embed x num_im, dim]
#self.slot_attention.change_slots_number(num_slots)
###################################
# Apply slot latent on embeddings #
###################################
slot_latents = self.slot_attention(set_latents, attention_mask)
if extract_masks:
return masks, slot_latents
else:
del masks
return slot_latents
@staticmethod
def postprocess_small_regions(
@@ -460,23 +336,6 @@ class SamAutomaticMask(nn.Module):
return mask_data
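# Illustrative sketch of the small-region cleanup this method builds on, using
# remove_small_regions (already imported at the top of this file, same signature as
# in the original SAM repo): fill tiny holes first, then drop tiny islands.
import numpy as np
from osrt.sam.utils.amg import remove_small_regions

mask = np.zeros((64, 64), dtype=bool)
mask[8:40, 8:40] = True
mask[20:22, 20:22] = False                      # a 2x2 hole
mask[60:62, 60:62] = True                       # a 2x2 disconnected island

mask, changed = remove_small_regions(mask, area_thresh=16, mode="holes")
mask, changed = remove_small_regions(mask, area_thresh=16, mode="islands")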
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
"""Normalize pixel values and pad to a square input."""
# Rescale the image relative to the longest side
# TODO : apply this preprocess to the dataset before training
x = torch.as_tensor(x, device=self.device)
x = x.permute(2, 0, 1).contiguous()#[None, :, :, :]
# Normalize colors
x = (x - self.pixel_mean) / self.pixel_std
# Pad
h, w = x.shape[-2:]
padh = self.image_encoder.img_size - h
padw = self.image_encoder.img_size - w
x = F.pad(x, (0, padw, 0, padh))
return x
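# Minimal sketch of the normalize-and-pad preprocessing above on a dummy image that
# has already been resized by its longest side; 1024 is the ViT-H encoder input size
# assumed elsewhere in this commit.
import torch
import torch.nn.functional as F

img_size = 1024
pixel_mean = torch.tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std = torch.tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)

x = torch.randint(0, 256, (768, 1024, 3)).float()     # H x W x C
x = x.permute(2, 0, 1).contiguous()                   # -> C x H x W
x = (x - pixel_mean) / pixel_std
h, w = x.shape[-2:]
x = F.pad(x, (0, img_size - w, 0, img_size - h))      # pad right/bottom to a square
assert x.shape == (3, 1024, 1024)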
def process_batch(
self,
points: np.ndarray,
@@ -485,18 +344,41 @@ class SamAutomaticMask(nn.Module):
curr_orig_size
):
# Run model on this batch
in_points = torch.as_tensor(self.transform.apply_coords(points, im_size), device=self.device)
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
### Prepare the points of the batch
in_points = torch.as_tensor(self.transform.apply_coords(points, im_size))
in_labels = torch.ones(in_points.shape[0], dtype=torch.int)
masks, iou_preds, _ = self.predict_masks(
in_points[:, None, :],
in_labels[:, None],
point_coords = in_points[:, None, :]
point_labels = in_labels[:, None]
if point_coords is not None:
points = (point_coords, point_labels)
else:
points = None
### Embed prompts
sparse_embeddings, dense_embeddings = self.mask_generator.prompt_encoder(
points=points,
boxes=None,
masks=None,
)
### Predict masks
low_res_masks, iou_preds = self.mask_generator.mask_decoder(
image_embeddings=curr_embedding.unsqueeze(0),
image_pe=self.mask_generator.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=False,
)
### Upscale the masks to the original image resolution
# TODO : verify result; the original code first resized back to the encoder input size and then to the original size (linked with the TODO in tokenize)
masks = F.interpolate(
low_res_masks,
curr_orig_size,
im_size,
curr_embedding,
multimask_output=True,
return_logits=True
mode="bilinear",
align_corners=False,
)
data = MaskData(
@@ -506,102 +388,27 @@ class SamAutomaticMask(nn.Module):
)
del masks
# Filter by predicted IoU
### Filter by predicted IoU
if self.pred_iou_thresh > 0.0:
keep_mask = data["iou_preds"] > self.pred_iou_thresh
data.filter(keep_mask)
# Calculate stability score
### Calculate stability score
data["stability_score"] = calculate_stability_score(
data["masks"], self.mask_threshold, self.stability_score_offset
data["masks"], self.mask_generator.mask_threshold, self.stability_score_offset
)
if self.stability_score_thresh > 0.0:
keep_mask = data["stability_score"] >= self.stability_score_thresh
data.filter(keep_mask)
# Threshold masks and calculate boxes
data["masks"] = data["masks"] > self.mask_threshold
### Threshold masks and calculate boxes
data["masks"] = data["masks"] > self.mask_generator.mask_threshold
data["boxes"] = batched_mask_to_box(data["masks"])
data["rles"] = mask_to_rle_pytorch(data["masks"])
return data
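# Illustrative sketch of the filtering chain above on dummy mask logits, using the amg
# helpers imported at the top of this file; calculate_stability_score is assumed to
# live in the same module, and all thresholds below are made-up example values.
import torch
from osrt.sam.utils.amg import (batched_mask_to_box, mask_to_rle_pytorch,
                                calculate_stability_score)

logits = torch.randn(8, 256, 256)                         # 8 candidate mask logits
iou_preds = torch.rand(8)

keep = iou_preds > 0.88                                    # pred_iou_thresh
logits, iou_preds = logits[keep], iou_preds[keep]

stability = calculate_stability_score(logits, 0.0, 1.0)    # mask_threshold, offset
keep = stability >= 0.9                                    # stability_score_thresh
logits, iou_preds = logits[keep], iou_preds[keep]

masks = logits > 0.0                                       # binarize at mask_threshold
boxes = batched_mask_to_box(masks)                         # [K, 4] boxes in XYXY format
rles = mask_to_rle_pytorch(masks)                          # run-length encodings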
def predict_masks(
self,
point_coords: Optional[torch.Tensor],
point_labels: Optional[torch.Tensor],
curr_orig_size,
curr_input_size,
curr_embedding,
boxes: Optional[torch.Tensor] = None,
mask_input: Optional[torch.Tensor] = None,
multimask_output: bool = True,
return_logits: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks for the given input prompts, using the current image.
Input prompts are batched torch tensors and are expected to already be
transformed to the input frame using ResizeLongestSide.
Arguments:
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
model. Each point is in (X,Y) in pixels.
point_labels (torch.Tensor or None): A BxN array of labels for the
point prompts. 1 indicates a foreground point and 0 indicates a
background point.
boxes (np.ndarray or None): A Bx4 array given a box prompt to the
model, in XYXY format.
mask_input (np.ndarray): A low resolution mask input to the model, typically
coming from a previous prediction iteration. Has form Bx1xHxW, where
for SAM, H=W=256. Masks returned by a previous iteration of the
predict method do not need further transformation.
multimask_output (bool): If true, the model will return three masks.
For ambiguous input prompts (such as a single click), this will often
produce better masks than a single prediction. If only a single
mask is needed, the model's predicted quality score can be used
to select the best mask. For non-ambiguous prompts, such as multiple
input prompts, multimask_output=False can give better results.
return_logits (bool): If true, returns un-thresholded masks logits
instead of a binary mask.
Returns:
(torch.Tensor): The output masks in BxCxHxW format, where C is the
number of masks, and (H, W) is the original image size.
(torch.Tensor): An array of shape BxC containing the model's
predictions for the quality of each mask.
(torch.Tensor): An array of shape BxCxHxW, where C is the number
of masks and H=W=256. These low res logits can be passed to
a subsequent iteration as mask input.
"""
if point_coords is not None:
points = (point_coords, point_labels)
else:
points = None
# Embed prompts
#with torch.no_grad():
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=points,
boxes=boxes,
masks=mask_input,
)
# Predict masks
low_res_masks, iou_predictions = self.mask_decoder(
image_embeddings=curr_embedding.unsqueeze(0),
image_pe=self.prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
)
# Upscale the masks to the original image resolution
masks = self.postprocess_masks(low_res_masks, curr_input_size, curr_orig_size)
if not return_logits:
masks = masks > self.mask_threshold
return masks, iou_predictions.detach(), low_res_masks.detach()
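# Illustrative call of the predict_masks API documented above; `sam_mask` (an already
# built SamAutomaticMask) and `image_embedding` are hypothetical placeholders, and the
# sizes and single foreground point are made up.
import torch

point_coords = torch.tensor([[[320.0, 240.0]]], device=sam_mask.device)   # B x 1 x 2, (x, y) in pixels
point_labels = torch.ones((1, 1), dtype=torch.int, device=sam_mask.device)
masks, iou_predictions, low_res_masks = sam_mask.predict_masks(
    point_coords,
    point_labels,
    curr_orig_size=(480, 640),        # original image size (H, W)
    curr_input_size=(768, 1024),      # size after ResizeLongestSide (H, W)
    curr_embedding=image_embedding,   # [C, H', W'] embedding from the image encoder
    multimask_output=True,
)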
def tokenize(self, mask_data, image_embed, input_size, scale_box=1.5):
"""
Predicts the embeddings from each mask given the global embedding and
@@ -622,7 +429,13 @@ class SamAutomaticMask(nn.Module):
out_size = mean_embed.shape
# Put the mean image embedding back to the input format
scaled_img_emb = self.postprocess_masks(mean_embed.unsqueeze(0).unsqueeze(0), input_size, (orig_H, orig_W)).squeeze()
# TODO : check results linked with the one in process_batch
scaled_img_emb = F.interpolate(
mean_embed.unsqueeze(0).unsqueeze(0),
(orig_H, orig_W),
mode="bilinear",
align_corners=False,
).squeeze()
mask_data["embeddings"] = []
@@ -646,7 +459,7 @@ class SamAutomaticMask(nn.Module):
token_embed += pos_embedding
# Apply mask to image embedding
mask_data["embeddings"].append(torch.from_numpy(token_embed).to(self.device)) # [token_dim]
mask_data["embeddings"].append(torch.from_numpy(token_embed)) # [token_dim]
def position_embeding_3d(self, tokens, camera_pos, rays):
"""
@@ -663,4 +476,4 @@ class SamAutomaticMask(nn.Module):
pos_rot_x, pos_rot_y, pos_rot_z = get_positional_embedding(rot_x, self.token_dim), get_positional_embedding(rot_y, self.token_dim), get_positional_embedding(rot_z, self.token_dim)
pos_embed = (pos_rot_x + pos_rot_y + pos_rot_z + pos_x + pos_y + pos_z) / 6
tokens += pos_embed
\ No newline at end of file
tokens += pos_embed
@@ -51,6 +51,9 @@ class OSRT(nn.Module):
self.decoder = SlotMixerDecoder(**cfg['decoder_kwargs'])
else:
raise ValueError(f'Unknown decoder type: {decoder_type}')
self.encoder.train()
self.decoder.train()
class LitOSRT(pl.LightningModule):
def __init__(self, encoder:nn.Module, decoder: nn.Module, cfg: Dict, extract_masks:bool =False):
......
@@ -10,7 +10,7 @@ from torch.nn import functional as F
from typing import Tuple
from ..modeling import Sam
from .. import Sam
from .amg import calculate_stability_score
......
@@ -11,7 +11,29 @@ from torchvision.transforms.functional import resize, to_pil_image # type: ignore
from copy import deepcopy
from typing import Tuple
import torchvision.transforms as transforms
class ResizeAndPad:
def __init__(self, target_size):
self.target_size = target_size
self.transform = ResizeLongestSide(target_size)
self.to_tensor = transforms.ToTensor()
def __call__(self, image):
# Resize image
image = self.transform.apply_image(image)
image = self.to_tensor(image)
# Pad image to form a square
_, h, w = image.shape
max_dim = max(w, h)
pad_w = (max_dim - w) // 2
pad_h = (max_dim - h) // 2
padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h)
image = transforms.Pad(padding)(image)
return image
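# Illustrative usage of ResizeAndPad on a dummy H x W x C uint8 numpy image; 1024 is
# the SAM ViT-H encoder input size assumed elsewhere in this commit.
import numpy as np

transform = ResizeAndPad(target_size=1024)
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
out = transform(image)                  # torch.Tensor, C x H x W, values in [0, 1]
assert out.shape == (3, 1024, 1024)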
class ResizeLongestSide:
"""
......
"""
Code extracted from : https://github.com/luca-medeiros/lightning-sam/blob/main/lightning_sam/losses.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
ALPHA = 0.8
GAMMA = 2
class FocalLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super().__init__()
def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
inputs = F.sigmoid(inputs)
inputs = torch.clamp(inputs, min=0, max=1)
#flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)
BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
BCE_EXP = torch.exp(-BCE)
focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE
return focal_loss
class DiceLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super().__init__()
def forward(self, inputs, targets, smooth=1):
inputs = F.sigmoid(inputs)
inputs = torch.clamp(inputs, min=0, max=1)
#flatten label and prediction tensors
inputs = inputs.view(-1)
targets = targets.view(-1)
intersection = (inputs * targets).sum()
dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
return 1 - dice
\ No newline at end of file
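# Illustrative sketch of combining the two losses above, in the spirit of the
# lightning-sam reference linked at the top of this file; the 20:1 weighting is an
# assumption used for the example, not something fixed by this commit.
import torch

focal_loss = FocalLoss()
dice_loss = DiceLoss()

pred_logits = torch.randn(2, 1, 64, 64)                   # raw mask logits
gt_masks = (torch.rand(2, 1, 64, 64) > 0.5).float()

loss = 20.0 * focal_loss(pred_logits, gt_masks) + dice_loss(pred_logits, gt_masks)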
@@ -16,6 +16,24 @@ from collections import defaultdict
import time
import os
class AverageMeter:
"""Computes and stores the average and current value."""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
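# Minimal usage sketch of AverageMeter tracking a running loss over batches; the
# values are made up.
meter = AverageMeter()
for batch_loss, batch_size in [(0.9, 8), (0.7, 8), (0.5, 4)]:
    meter.update(batch_loss, n=batch_size)
print(meter.val, meter.avg)             # last value and size-weighted running average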
def compute_loss(
batch,
batch_idx : int,
......