From ece9ede389d145c99bd7962814646de0ccab1007 Mon Sep 17 00:00:00 2001
From: alexcbb <alexchapin@hotmail.fr>
Date: Mon, 17 Jul 2023 16:01:12 +0200
Subject: [PATCH] Try to adapt code from SAM lightning

---
 osrt/encoder.py              | 493 +++++++++++------------------
 osrt/model.py                |   3 +
 osrt/sam/utils/onnx.py       |   2 +-
 osrt/sam/utils/transforms.py |  22 ++
 osrt/utils/losses.py         |  47 ++++
 osrt/utils/training.py       |  18 ++
 6 files changed, 244 insertions(+), 341 deletions(-)
 create mode 100644 osrt/utils/losses.py

diff --git a/osrt/encoder.py b/osrt/encoder.py
index 652e7d4..2388dd3 100644
--- a/osrt/encoder.py
+++ b/osrt/encoder.py
@@ -20,7 +20,7 @@ from osrt.sam import sam_model_registry
 from osrt.sam.image_encoder import ImageEncoderViT
 from osrt.sam.mask_decoder import MaskDecoder
 from osrt.sam.prompt_encoder import PromptEncoder
-from osrt.sam.utils.transforms import ResizeLongestSide
+from osrt.sam.utils.transforms import ResizeLongestSide, ResizeAndPad
 from osrt.sam.utils.amg import batched_mask_to_box, remove_small_regions, mask_to_rle_pytorch, area_from_rle
 
 from torch.nn.utils.rnn import pad_sequence
@@ -115,150 +115,30 @@ class OSRTEncoder(nn.Module):
 
 class FeatureMasking(nn.Module):
     def __init__(self,
-                 points_per_side=12,
-                 box_nms_thresh = 0.7,
-                 stability_score_thresh = 0.9,
-                 pred_iou_thresh=0.88,
-                 points_per_batch=64,
-                 min_mask_region_area=0,
-                 num_slots=32,
-                 slot_dim=1536,
-                 slot_iters=3,
-                 sam_model="default",
-                 sam_path="sam_vit_h_4b8939.pth",
-                 randomize_initial_slots=False):
+                 points_per_side: Optional[int] = 12,
+                 box_nms_thresh: float = 0.7,
+                 stability_score_thresh: float = 0.9,
+                 stability_score_offset: float = 1.0,
+                 pred_iou_thresh: float = 0.88,
+                 points_per_batch: int = 64,
+                 min_mask_region_area: int = 0,
+                 num_slots: int = 32,
+                 slot_dim: int = 1536,
+                 slot_iters: int = 3,
+                 sam_model="default",
+                 sam_path="sam_vit_h_4b8939.pth",
+                 randomize_initial_slots=False):
 
         super().__init__()
-        # We first initialize the automatic mask generator from SAM
-        # TODO : change the loading here !!!!
-        sam = sam_model_registry[sam_model](checkpoint=sam_path)
-
-        self.mask_generator = SamAutomaticMask(copy.deepcopy(sam.image_encoder),
-                                               copy.deepcopy(sam.prompt_encoder),
-                                               copy.deepcopy(sam.mask_decoder),
-                                               box_nms_thresh=box_nms_thresh,
-                                               stability_score_thresh = stability_score_thresh,
-                                               pred_iou_thresh=pred_iou_thresh,
-                                               points_per_side=points_per_side,
-                                               points_per_batch=points_per_batch,
-                                               min_mask_region_area=min_mask_region_area)
-        del sam
-
-        self.slot_attention = SlotAttention(num_slots, input_dim=self.mask_generator.token_dim, slot_dim=slot_dim, iters=slot_iters,
-                                            randomize_initial_slots=randomize_initial_slots)
-
-    def forward(self, images, original_size=None, camera_pos=None, rays=None, extract_masks=False):
-        """
-        Args:
-            images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
-            original_size: tuple(height, width) The original size of the image before transformation.
-            camera_pos: [batch_size, num_images, 3]
-            rays: [batch_size, num_images, height, width, 3]
-        Returns:
-            annotations: [batch_size, num_image] An array containing a dict for each image in each batch
-            with the following annotations :
-                segmentation
-                area
-                predicted_iou
-                point_coords
-                stability_score
-                embeddings (Optionnal)
-        """
-
-        # Generate images
-        masks = self.mask_generator(images, original_size, camera_pos, rays, extract_embeddings=True) # [B, N]
-
-        B, N = masks.shape
-        dim = masks[0][0]["embeddings"][0].shape[0]
-
-        # We infer each batch separetely as it handle a different number of slots
-        set_latents = None
-        # TODO : set the number of slots according to either we want min or max
-        #with torch.no_grad():
-        #num_slots = 100000
-        embedding_batch = []
-        masks_batch = []
-        for b in range(B):
-            latents_batch = torch.empty((0, dim), device=self.mask_generator.device)
-            for n in range(N):
-                embeds = masks[b][n]["embeddings"]
-                #num_slots = min(len(embeds), num_slots)
-                for embed in embeds:
-                    latents_batch = torch.cat((latents_batch, embed.unsqueeze(0)), 0)
-            masks_batch.append(torch.zeros(latents_batch.shape[:1]))
-            embedding_batch.append(latents_batch)
-        set_latents = pad_sequence(embedding_batch, batch_first=True, padding_value=0.0)
-        attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=1.0)
-
-        # [batch_size, num_inputs = num_mask_embed x num_im, dim]
-        #self.slot_attention.change_slots_number(num_slots)
-        slot_latents = self.slot_attention(set_latents, attention_mask)
-
-        if extract_masks:
-            return masks, slot_latents
-        else:
-            del masks
-            return slot_latents
-
-
-class SamAutomaticMask(nn.Module):
-    mask_threshold: float = 0.0
-
-    def __init__(
-        self,
-        image_encoder: ImageEncoderViT,
-        prompt_encoder: PromptEncoder,
-        mask_decoder: MaskDecoder,
-        pixel_mean: List[float] = [123.675, 116.28, 103.53],
-        pixel_std: List[float] = [58.395, 57.12, 57.375],
-        token_dim = 1280,
-        points_per_side: Optional[int] = 0,
-        points_per_batch: int = 64,
-        pred_iou_thresh: float = 0.88,
-        stability_score_thresh: float = 0.95,
-        stability_score_offset: float = 1.0,
-        box_nms_thresh: float = 0.7,
-        min_mask_region_area: int = 0,
-        pos_start_octave=0,
+        token_dim = 1280
+        pos_start_octave=0
         patch_size = 16
-    ) -> None:
-        """
-        This class adapts SAM implementation from original repository but adapting it to our needs :
-        - Training only the MaskDecoder
-        - Performing automatic Mask Discovery (combined with AutomaticMask from original repo)
-
-        SAM predicts object masks from an image and input prompts.
-        Arguments:
-          image_encoder (ImageEncoderViT): The backbone used to encode the
-            image into image embeddings that allow for efficient mask prediction.
-          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
-          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
-            and encoded prompts.
-          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
-          pixel_std (list(float)): Std values for normalizing pixels in the input image.
-        """
-        super().__init__()
-
-        # SAM part
-        self.image_encoder = image_encoder
-        self.prompt_encoder = prompt_encoder
-
-        # Freeze the image encoder and prompt encoder
-        """for param in self.image_encoder.parameters():
-            param.requires_grad = False
-        for param in self.prompt_encoder.parameters():
-            param.requires_grad = False"""
-
-        self.mask_decoder = mask_decoder
-        """for param in self.mask_decoder.parameters():
-            param.requires_grad = True"""
-        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
-        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
-
-        # Transform image to a square by putting it to the longest side
-        #self.resize = transforms.Resize(self.image_encoder.img_size, interpolation=transforms.InterpolationMode.BILINEAR)
-        self.transform = ResizeLongestSide(self.image_encoder.img_size)
+        # We first initialize the automatic mask generator from SAM
+        # TODO : change the loading here !!!!
+        self.mask_generator = sam_model_registry[sam_model](checkpoint=sam_path)
+        self.preprocess = ResizeAndPad(self.mask_generator.image_encoder.img_size)
+        self.resize = ResizeLongestSide(self.mask_generator.image_encoder.img_size)
 
         if points_per_side > 0:
             self.points_grid = create_points_grid(points_per_side)
@@ -271,7 +151,7 @@ class SamAutomaticMask(nn.Module):
         self.box_nms_thresh = box_nms_thresh
         self.min_mask_region_area = min_mask_region_area
 
-        self.patch_embed_dim = (self.image_encoder.img_size // patch_size)**2
+        self.patch_embed_dim = (self.mask_generator.image_encoder.img_size // patch_size)**2
         self.token_dim = token_dim
         self.tokenizer = nn.Sequential(
             nn.Linear(self.patch_embed_dim, 3072),
@@ -284,132 +164,128 @@ class SamAutomaticMask(nn.Module):
 
         # Space positional embedding
         self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave, ray_octaves=15)
+
+        self.slot_attention = SlotAttention(num_slots, input_dim=self.token_dim, slot_dim=slot_dim, iters=slot_iters,
+                                            randomize_initial_slots=randomize_initial_slots)
 
-    @property
-    def device(self) -> Any:
-        return self.pixel_mean.device
-
-    def forward(
-        self,
-        images,
-        orig_size,
-        camera_pos=None,
-        rays=None,
-        extract_embeddings = False):
+    def forward(self, images, camera_pos=None, rays=None, extract_masks=False):
         """
         Args:
-            images: [batch_size, num_images, 3, height, width]. Assume the first image is canonical.
-            original_size: tuple(height, width) The original size of the image before transformation.
+            images: [batch_size, num_images, height, width, 3]. Assume the first image is canonical.
             camera_pos: [batch_size, num_images, 3]
             rays: [batch_size, num_images, height, width, 3]
+            extract_masks: bool Whether to also return the extracted masks
         Returns:
-            annotations: [batch_size, num_image] An array containing a dict for each image in each batch
-            with the following annotations :
-                segmentation
-                area
-                predicted_iou
-                point_coords
-                stability_score
-                embeddings (Optionnal)
+            (masks, slot_latents) if extract_masks is True, otherwise slot_latents only
         """
+
+        ##################################
+        #    Extract masks with SAM      #
+        ##################################
+
+        ### Get image sizes
         B, N, H, W, C = images.shape
-        images = images.reshape(B*N, H, W, C) # [B x N, C, H, W]
+        orig_size = (H, W)
+        images = images.reshape(B*N, H, W, C) # [B x N, H, W, C]
+        im_size = self.resize.apply_image(images[0]).shape[-3:-1]
 
-        # Pre-process the images for the image encoder
-        input_images = torch.stack([self.preprocess(self.transform.apply_image(x)) for x in images])
-        im_size = self.transform.apply_image(images[0]).shape[-3:-1]
+        ### Pre-process images for the image encoder (Resize and Pad)
+        images = torch.stack([self.preprocess(x) for x in images])
 
+        ### Encode images
+        image_embeddings, embed_no_red = self.mask_generator.image_encoder(images, before_channel_reduc=True) # [B x N, C, H, W]
+
+        i = 0
+        masks = np.empty((B, N), dtype=object)
         points_scale = np.array(im_size)[None, ::-1]
         points_for_image = self.points_grid * points_scale
 
-        with torch.no_grad():
-            image_embeddings, embed_no_red = self.image_encoder(input_images, before_channel_reduc=True) # [B x N, C, H, W]
-
-        annotations = np.empty((B, N), dtype=object)
-        i = 0
+        ### Extract masks for each image in the batch
        for curr_embedding, curr_emb_no_red in zip(image_embeddings, embed_no_red):
-            mask_data = MaskData()
+            data = MaskData()
             for (points,) in batch_iterator(self.points_per_batch, points_for_image):
                 batch_data = self.process_batch(points, im_size, curr_embedding, orig_size)
-                mask_data.cat(batch_data)
+                data.cat(batch_data)
                 del batch_data
-            del curr_embedding
 
-            # Remove duplicates
+            ### Remove duplicates
             keep_by_nms = batched_nms(
-                mask_data["boxes"].float(),
-                mask_data["iou_preds"],
-                torch.zeros_like(mask_data["boxes"][:, 0]),  # categories
+                data["boxes"].float(),
+                data["iou_preds"],
+                torch.zeros_like(data["boxes"][:, 0]),  # categories
                 iou_threshold=self.box_nms_thresh,
             )
-            mask_data.filter(keep_by_nms)
+            data.filter(keep_by_nms)
 
-            mask_data["segmentations"] = mask_data["masks"]
+            data["segmentations"] = data["masks"]
 
-            # Extract mask embeddings
-            if extract_embeddings:
-                self.tokenize(mask_data, curr_emb_no_red, im_size, scale_box=1.5)
-                # TODO : vérifier le positional encoding 3D
-                #self.position_embeding_3d(mask_data["embeddings"], camera_pos[batch][num_im], rays[batch][num_im])
+            ### Extract mask embeddings
+            self.tokenize(data, curr_emb_no_red, im_size, scale_box=1.5)
+            # TODO : Add 3D positional encoding in the process
+            #self.position_embeding_3d(mask_data["embeddings"], camera_pos[batch][num_im], rays[batch][num_im])
 
-            mask_data.to_numpy()
+            data.to_numpy()
 
-            # Filter small disconnected regions and holes in masks NOT USED
+            ### Filter small disconnected regions and holes in masks (NOT USED)
             if self.min_mask_region_area > 0:
-                mask_data = self.postprocess_small_regions(
-                    mask_data,
+                data = self.postprocess_small_regions(
+                    data,
                     self.min_mask_region_area,
                     self.box_nms_thresh,
                 )
 
-            # TODO : have a more efficient way to store the data
-            # Write mask records
-            if extract_embeddings:
-                curr_ann = {
-                    "embeddings": mask_data["embeddings"],
-                    "segmentations": mask_data["segmentations"]
-                }
-            else :
-                curr_ann = {
-                    "segmentations": mask_data["segmentations"]
"segmentations": mask_data["segmentations"] - } + ### Write mask records + curr_ann = { + "embeddings": data["embeddings"], + "segmentations": data["segmentations"], + "iou_preds": data["iou_preds"] + } batch = math.floor((i / N)) num_im = i % N - annotations[batch][num_im] = curr_ann + masks[batch][num_im] = curr_ann i+=1 - return annotations # [B, N, 1] : dict containing diverse annotations such as segmentation, area or also embedding - def postprocess_masks( - self, - masks: torch.Tensor, - input_size: Tuple[int, ...], - original_size: Tuple[int, ...], - ) -> torch.Tensor: - """ - Remove padding and upscale masks to the original image size. + ### Get parameters of mask and embedding size + B, N = masks.shape + dim = masks[0][0]["embeddings"][0].shape[0] - Arguments: - masks (torch.Tensor): Batched masks from the mask_decoder, - in BxCxHxW format. - input_size (tuple(int, int)): The size of the image input to the - model, in (H, W) format. Used to remove padding. - original_size (tuple(int, int)): The original size of the image - before resizing for input to the model, in (H, W) format. + ################################## + # Extract mask embeddings # + ################################## - Returns: - (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) - is given by original_size. - """ - masks = F.interpolate( - masks, - (self.image_encoder.img_size, self.image_encoder.img_size), - mode="bilinear", - align_corners=True, - ) + # TODO : set the number of slots according to either we want min or max size (by default a static number of slots) + + ### Pad batches to have the same length + set_latents = None + embedding_batch = [] + masks_batch = [] + for b in range(B): + latents_batch = torch.empty((0, dim)) + for n in range(N): + embeds = masks[b][n]["embeddings"] + for embed in embeds: + latents_batch = torch.cat((latents_batch, embed.unsqueeze(0)), 0) + masks_batch.append(torch.zeros(latents_batch.shape[:1])) + embedding_batch.append(latents_batch) + ### Contains the embeddings values padded + set_latents = pad_sequence(embedding_batch, batch_first=True, padding_value=0.0) + ### Contains the masks to avoid to apply attention to not embedded values + attention_mask = pad_sequence(masks_batch, batch_first=True, padding_value=1.0) - masks = masks[..., : input_size[0], : input_size[1]] - masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=True) - return masks + # [batch_size, num_inputs = num_mask_embed x num_im, dim] + #self.slot_attention.change_slots_number(num_slots) + + ################################### + # Apply slot latent on embeddings # + ################################### + slot_latents = self.slot_attention(set_latents, attention_mask) + + if extract_masks: + return masks, slot_latents + else: + del masks + return slot_latents @staticmethod def postprocess_small_regions( @@ -460,23 +336,6 @@ class SamAutomaticMask(nn.Module): return mask_data - def preprocess(self, x: torch.Tensor) -> torch.Tensor: - """Normalize pixel values and pad to a square input.""" - # Rescale the image relative to the longest side - # TODO : apply this preprocess to the dataset before training - x = torch.as_tensor(x, device=self.device) - x = x.permute(2, 0, 1).contiguous()#[None, :, :, :] - - # Normalize colors - x = (x - self.pixel_mean) / self.pixel_std - - # Pad - h, w = x.shape[-2:] - padh = self.image_encoder.img_size - h - padw = self.image_encoder.img_size - w - x = F.pad(x, (0, padw, 0, padh)) - return x - def process_batch( self, points: np.ndarray, @@ -485,18 
@@ -485,18 +344,41 @@ class SamAutomaticMask(nn.Module):
         curr_orig_size
     ):
 
-        # Run model on this batch
-        in_points = torch.as_tensor(self.transform.apply_coords(points, im_size), device=self.device)
-        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+        ### Prepare the points of the batch
+        in_points = torch.as_tensor(self.resize.apply_coords(points, im_size))
+        in_labels = torch.ones(in_points.shape[0], dtype=torch.int)
 
-        masks, iou_preds, _ = self.predict_masks(
-            in_points[:, None, :],
-            in_labels[:, None],
+        point_coords = in_points[:, None, :]
+        point_labels = in_labels[:, None]
+
+        if point_coords is not None:
+            points = (point_coords, point_labels)
+        else:
+            points = None
+
+        ### Embed prompts
+        sparse_embeddings, dense_embeddings = self.mask_generator.prompt_encoder(
+            points=points,
+            boxes=None,
+            masks=None,
+        )
+
+        ### Predict masks
+        low_res_masks, iou_preds = self.mask_generator.mask_decoder(
+            image_embeddings=curr_embedding.unsqueeze(0),
+            image_pe=self.mask_generator.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=False,
+        )
+
+        ### Upscale the masks to the original image resolution
+        # TODO : verify result; the original code first resized back to the encoder input size before going to the original size (see the matching TODO in tokenize)
+        masks = F.interpolate(
+            low_res_masks,
             curr_orig_size,
-            im_size,
-            curr_embedding,
-            multimask_output=True,
-            return_logits=True
+            mode="bilinear",
+            align_corners=False,
         )
 
         data = MaskData(
@@ -506,102 +388,27 @@ class SamAutomaticMask(nn.Module):
         )
         del masks
 
-        # Filter by predicted IoU
+        ### Filter by predicted IoU
         if self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh
            data.filter(keep_mask)
 
-        # Calculate stability score
+        ### Calculate stability score
        data["stability_score"] = calculate_stability_score(
-            data["masks"], self.mask_threshold, self.stability_score_offset
+            data["masks"], self.mask_generator.mask_threshold, self.stability_score_offset
        )
        if self.stability_score_thresh > 0.0:
            keep_mask = data["stability_score"] >= self.stability_score_thresh
            data.filter(keep_mask)
 
-        # Threshold masks and calculate boxes
-        data["masks"] = data["masks"] > self.mask_threshold
+        ### Threshold masks and calculate boxes
+        data["masks"] = data["masks"] > self.mask_generator.mask_threshold
        data["boxes"] = batched_mask_to_box(data["masks"])
        data["rles"] = mask_to_rle_pytorch(data["masks"])
 
        return data
 
-    def predict_masks(
-        self,
-        point_coords: Optional[torch.Tensor],
-        point_labels: Optional[torch.Tensor],
-        curr_orig_size,
-        curr_input_size,
-        curr_embedding,
-        boxes: Optional[torch.Tensor] = None,
-        mask_input: Optional[torch.Tensor] = None,
-        multimask_output: bool = True,
-        return_logits: bool = False,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Predict masks for the given input prompts, using the current image.
-        Input prompts are batched torch tensors and are expected to already be
-        transformed to the input frame using ResizeLongestSide.
-
-        Arguments:
-          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
-            model. Each point is in (X,Y) in pixels.
-          point_labels (torch.Tensor or None): A BxN array of labels for the
-            point prompts. 1 indicates a foreground point and 0 indicates a
-            background point.
-          boxes (np.ndarray or None): A Bx4 array given a box prompt to the
-            model, in XYXY format.
-          mask_input (np.ndarray): A low resolution mask input to the model, typically
-            coming from a previous prediction iteration. Has form Bx1xHxW, where
-            for SAM, H=W=256. Masks returned by a previous iteration of the
-            predict method do not need further transformation.
-          multimask_output (bool): If true, the model will return three masks.
-            For ambiguous input prompts (such as a single click), this will often
-            produce better masks than a single prediction. If only a single
-            mask is needed, the model's predicted quality score can be used
-            to select the best mask. For non-ambiguous prompts, such as multiple
-            input prompts, multimask_output=False can give better results.
-          return_logits (bool): If true, returns un-thresholded masks logits
-            instead of a binary mask.
-
-        Returns:
-          (torch.Tensor): The output masks in BxCxHxW format, where C is the
-            number of masks, and (H, W) is the original image size.
-          (torch.Tensor): An array of shape BxC containing the model's
-            predictions for the quality of each mask.
-          (torch.Tensor): An array of shape BxCxHxW, where C is the number
-            of masks and H=W=256. These low res logits can be passed to
-            a subsequent iteration as mask input.
-        """
-        if point_coords is not None:
-            points = (point_coords, point_labels)
-        else:
-            points = None
-
-        # Embed prompts
-        #with torch.no_grad():
-        sparse_embeddings, dense_embeddings = self.prompt_encoder(
-            points=points,
-            boxes=boxes,
-            masks=mask_input,
-        )
-
-        # Predict masks
-        low_res_masks, iou_predictions = self.mask_decoder(
-            image_embeddings=curr_embedding.unsqueeze(0),
-            image_pe=self.prompt_encoder.get_dense_pe(),
-            sparse_prompt_embeddings=sparse_embeddings,
-            dense_prompt_embeddings=dense_embeddings,
-            multimask_output=multimask_output,
-        )
-
-        # Upscale the masks to the original image resolution
-        masks = self.postprocess_masks(low_res_masks, curr_input_size, curr_orig_size)
-
-        if not return_logits:
-            masks = masks > self.mask_threshold
-
-        return masks, iou_predictions.detach(), low_res_masks.detach()
-
     def tokenize(self, mask_data, image_embed, input_size, scale_box=1.5):
         """
         Predicts the embeddings from each mask given the global embedding and
@@ -622,7 +429,13 @@ class SamAutomaticMask(nn.Module):
         out_size = mean_embed.shape
 
         # Put the mean image embedding back to the input format
-        scaled_img_emb = self.postprocess_masks(mean_embed.unsqueeze(0).unsqueeze(0), input_size, (orig_H, orig_W)).squeeze()
+        # TODO : check results linked with the one in process_batch
+        scaled_img_emb = F.interpolate(
+            mean_embed.unsqueeze(0).unsqueeze(0),
+            (orig_H, orig_W),
+            mode="bilinear",
+            align_corners=False,
+        ).squeeze()
 
         mask_data["embeddings"] = []
 
@@ -646,7 +459,7 @@ class SamAutomaticMask(nn.Module):
             token_embed += pos_embedding
 
             # Apply mask to image embedding
-            mask_data["embeddings"].append(torch.from_numpy(token_embed).to(self.device)) # [token_dim]
+            mask_data["embeddings"].append(torch.from_numpy(token_embed)) # [token_dim]
 
     def position_embeding_3d(self, tokens, camera_pos, rays):
         """
@@ -663,4 +476,4 @@ class SamAutomaticMask(nn.Module):
 
         pos_rot_x, pos_rot_y, pos_rot_z = get_positional_embedding(rot_x, self.token_dim), get_positional_embedding(rot_y, self.token_dim), get_positional_embedding(rot_z, self.token_dim)
         pos_embed = (pos_rot_x + pos_rot_y + pos_rot_z + pos_x + pox_y + pos_z) // 6
-        tokens += pos_embed
\ No newline at end of file
+        tokens += pos_embed
diff --git a/osrt/model.py b/osrt/model.py
index 6fb24ec..b01d161 100644
--- a/osrt/model.py
+++ b/osrt/model.py
@@ -51,6 +51,9 @@ class OSRT(nn.Module):
             self.decoder = SlotMixerDecoder(**cfg['decoder_kwargs'])
         else:
             raise ValueError(f'Unknown decoder type: {decoder_type}')
+
+        self.encoder.train()
+        self.decoder.train()
 
 class LitOSRT(pl.LightningModule):
     def __init__(self, encoder:nn.Module, decoder: nn.Module, cfg: Dict, extract_masks:bool =False):
diff --git a/osrt/sam/utils/onnx.py b/osrt/sam/utils/onnx.py
index 3196bdf..396a373 100644
--- a/osrt/sam/utils/onnx.py
+++ b/osrt/sam/utils/onnx.py
@@ -10,7 +10,7 @@ from torch.nn import functional as F
 
 from typing import Tuple
 
-from ..modeling import Sam
+from .. import Sam
 
 from .amg import calculate_stability_score
 
diff --git a/osrt/sam/utils/transforms.py b/osrt/sam/utils/transforms.py
index e801bdf..87ee5c7 100644
--- a/osrt/sam/utils/transforms.py
+++ b/osrt/sam/utils/transforms.py
@@ -11,7 +11,29 @@ from torchvision.transforms.functional import resize, to_pil_image # type: igno
 
 from copy import deepcopy
 from typing import Tuple
+import torchvision.transforms as transforms
 
+class ResizeAndPad:
+    def __init__(self, target_size):
+        self.target_size = target_size
+        self.transform = ResizeLongestSide(target_size)
+        self.to_tensor = transforms.ToTensor()
+
+    def __call__(self, image):
+        # Resize image
+        image = self.transform.apply_image(image)
+        image = self.to_tensor(image)
+
+        # Pad image to form a square
+        _, h, w = image.shape
+        max_dim = max(w, h)
+        pad_w = (max_dim - w) // 2
+        pad_h = (max_dim - h) // 2
+
+        padding = (pad_w, pad_h, max_dim - w - pad_w, max_dim - h - pad_h)
+        image = transforms.Pad(padding)(image)
+
+        return image
 
 class ResizeLongestSide:
     """
diff --git a/osrt/utils/losses.py b/osrt/utils/losses.py
new file mode 100644
index 0000000..cd94527
--- /dev/null
+++ b/osrt/utils/losses.py
@@ -0,0 +1,47 @@
+"""
+Code extracted from : https://github.com/luca-medeiros/lightning-sam/blob/main/lightning_sam/losses.py
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+ALPHA = 0.8
+GAMMA = 2
+
+
+class FocalLoss(nn.Module):
+
+    def __init__(self, weight=None, size_average=True):
+        super().__init__()
+
+    def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
+        inputs = F.sigmoid(inputs)
+        inputs = torch.clamp(inputs, min=0, max=1)
+        #flatten label and prediction tensors
+        inputs = inputs.view(-1)
+        targets = targets.view(-1)
+
+        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
+        BCE_EXP = torch.exp(-BCE)
+        focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE
+
+        return focal_loss
+
+
+class DiceLoss(nn.Module):
+
+    def __init__(self, weight=None, size_average=True):
+        super().__init__()
+
+    def forward(self, inputs, targets, smooth=1):
+        inputs = F.sigmoid(inputs)
+        inputs = torch.clamp(inputs, min=0, max=1)
+        #flatten label and prediction tensors
+        inputs = inputs.view(-1)
+        targets = targets.view(-1)
+
+        intersection = (inputs * targets).sum()
+        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
+
+        return 1 - dice
\ No newline at end of file
diff --git a/osrt/utils/training.py b/osrt/utils/training.py
index 70392ee..48240ca 100644
--- a/osrt/utils/training.py
+++ b/osrt/utils/training.py
@@ -16,6 +16,24 @@ from collections import defaultdict
 import time
 import os
 
+class AverageMeter:
+    """Computes and stores the average and current value."""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
 def compute_loss(
         batch,
         batch_idx : int,
-- 
GitLab
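
Note (illustrative only, not part of the patch): a minimal sketch of how the pieces this patch adds (FocalLoss, DiceLoss, AverageMeter) could be wired together in a training step. The 20:1 focal-to-dice weighting follows the lightning-sam reference cited in losses.py and is an assumption here, not something the patch itself configures.

import torch

from osrt.utils.losses import FocalLoss, DiceLoss
from osrt.utils.training import AverageMeter

focal, dice = FocalLoss(), DiceLoss()
meter = AverageMeter()

# Stand-in logits/targets for a batch of predicted and ground-truth masks.
pred_masks = torch.randn(4, 1, 64, 64)               # un-thresholded mask logits
gt_masks = (torch.rand(4, 1, 64, 64) > 0.5).float()  # binary targets

# Assumed weighting, mirroring the lightning-sam reference implementation.
loss = 20. * focal(pred_masks, gt_masks) + dice(pred_masks, gt_masks)
meter.update(loss.item(), n=pred_masks.shape[0])
print(f"loss {meter.val:.4f} (running avg {meter.avg:.4f})")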