Skip to content
Snippets Groups Projects
Commit e666be99 authored by Alexandre Chapin's avatar Alexandre Chapin :race_car:
Browse files

To device

parent 95d16f1d
No related branches found
No related tags found
No related merge requests found
...@@ -21,7 +21,6 @@ def compute_focal_loss(inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1): ...@@ -21,7 +21,6 @@ def compute_focal_loss(inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE
return focal_loss return focal_loss
def compute_dice_loss(inputs, targets, smooth=1): def compute_dice_loss(inputs, targets, smooth=1):
inputs = F.sigmoid(inputs) inputs = F.sigmoid(inputs)
...@@ -34,7 +33,6 @@ def compute_dice_loss(inputs, targets, smooth=1): ...@@ -34,7 +33,6 @@ def compute_dice_loss(inputs, targets, smooth=1):
dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth) dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
return 1 - dice return 1 - dice
def compute_ari(true_mask, pred_mask): def compute_ari(true_mask, pred_mask):
""" """
Computes the adjusted rand index (ARI) of a given image segmentation, ignoring the background. Computes the adjusted rand index (ARI) of a given image segmentation, ignoring the background.
...@@ -111,7 +109,6 @@ def compute_ari(true_mask, pred_mask): ...@@ -111,7 +109,6 @@ def compute_ari(true_mask, pred_mask):
return torch.where(both_single_cluster, torch.ones_like(ari), ari) return torch.where(both_single_cluster, torch.ones_like(ari), ari)
def precision_recall(segmentation_gt: torch.Tensor, segmentation_pred: torch.Tensor, mode: str, adjusted: bool): def precision_recall(segmentation_gt: torch.Tensor, segmentation_pred: torch.Tensor, mode: str, adjusted: bool):
""" Compute the (Adjusted) Rand Precision/Recall. """ Compute the (Adjusted) Rand Precision/Recall.
Implementation obtained from paper : Sensitivity of Slot-Based Object-Centric Models to their Number of Slots Implementation obtained from paper : Sensitivity of Slot-Based Object-Centric Models to their Number of Slots
......
...@@ -260,6 +260,9 @@ def main(cfg) -> None: ...@@ -260,6 +260,9 @@ def main(cfg) -> None:
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers) test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)
train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader) train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader)
train_loader = fabric.to_device(train_loader)
val_loader = fabric.to_device(val_loader)
test_loader = fabric.to_device(test_loader)
vis_loader_val = DataLoader(val_dataset, batch_size=12, num_workers=num_workers) vis_loader_val = DataLoader(val_dataset, batch_size=12, num_workers=num_workers)
data_vis_val = next(iter(vis_loader_val)) # Validation set data for visualization data_vis_val = next(iter(vis_loader_val)) # Validation set data for visualization
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment