From d33ad3895e7e87e47b88486589a4bb227b5868c7 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Mon, 19 May 2025 16:30:04 +0200 Subject: [PATCH] add : barlow model --- barlow_twin_like/config.py | 28 +++ barlow_twin_like/dataset_barlow.py | 304 +++++++++++++++++++++++++++ barlow_twin_like/main.py | 193 +++++++++++++++++ barlow_twin_like/model.py | 325 +++++++++++++++++++++++++++++ requirements.txt | 7 +- 5 files changed, 856 insertions(+), 1 deletion(-) create mode 100644 barlow_twin_like/config.py diff --git a/barlow_twin_like/config.py b/barlow_twin_like/config.py new file mode 100644 index 00000000..2f38fd18 --- /dev/null +++ b/barlow_twin_like/config.py @@ -0,0 +1,28 @@ +import argparse + + +def load_args_barlow(): + parser = argparse.ArgumentParser() + + parser.add_argument('--epoches', type=int, default=100) + parser.add_argument('--classification_epoches', type=int, default=10) + parser.add_argument('--eval_inter', type=int, default=1) + parser.add_argument('--test_inter', type=int, default=10) + parser.add_argument('--lr', type=float, default=0.001) + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--opti', type=str, default='adam') + parser.add_argument('--model', type=str, default='ResNet18') + parser.add_argument('--projector', type=str, default='1024-512-256-128') + parser.add_argument('--sampler', type=str, default=None) + parser.add_argument('--dataset_train_dir', type=str, default='data/processed_data_wiff_clean_005_10000_apex/npy_image/train_data') + parser.add_argument('--dataset_val_dir', type=str, default='data/processed_data_wiff_clean_005_10000_apex/npy_image/test_data') + parser.add_argument('--dataset_test_dir', type=str, default=None) + parser.add_argument('--base_out', type=str, default='output/best_model_base_ray') + parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref') + parser.add_argument('--output', type=str, default='output/out_barlow.csv') + parser.add_argument('--save_path', type=str, default='output/best_model_barlow.pt') + parser.add_argument('--pretrain_path', type=str, default=None) + parser.add_argument('--wandb', type=str, default='wandb_run') + args = parser.parse_args() + + return args \ No newline at end of file diff --git a/barlow_twin_like/dataset_barlow.py b/barlow_twin_like/dataset_barlow.py index e69de29b..274ecdf0 100644 --- a/barlow_twin_like/dataset_barlow.py +++ b/barlow_twin_like/dataset_barlow.py @@ -0,0 +1,304 @@ +import random + +import numpy as np +import torch +import torchvision +import torchvision.transforms as transforms +import torch.utils.data as data +from PIL import Image +import os +import os.path +from typing import Callable, cast, Dict, List, Optional, Tuple, Union +from pathlib import Path +from collections import OrderedDict + +from torch.utils.data import WeightedRandomSampler + +IMG_EXTENSIONS = ".npy" + +def npy_loader(path): + sample = torch.from_numpy(np.load(path)) + sample = sample.unsqueeze(0) + return sample + +def remove_aer_ana(l): + l = map(lambda x : x.split('_')[0],l) + return list(OrderedDict.fromkeys(l)) + +def make_dataset_custom( + directory: Union[str, Path], + class_to_idx: Optional[Dict[str, int]] = None, + extensions: Optional[Union[str, Tuple[str, ...]]] = IMG_EXTENSIONS, + is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, +) -> List[Tuple[str, str, int]]: + """Generates a list of samples of a form (path_to_sample, class). + + See :class:`DatasetFolder` for details. + + Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function + by default. + """ + directory = os.path.expanduser(directory) + + if class_to_idx is None: + _, class_to_idx = torchvision.datasets.folder.find_classes(directory) + elif not class_to_idx: + raise ValueError("'class_to_index' must have at least one entry to collect any samples.") + + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + if both_none or both_something: + raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") + + if extensions is not None: + + def is_valid_file(x: str) -> bool: + return torchvision.datasets.folder.has_file_allowed_extension(x, extensions) # type: ignore[arg-type] + + is_valid_file = cast(Callable[[str], bool], is_valid_file) + + instances = [] + available_classes = set() + for target_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[target_class] + target_dir = os.path.join(directory, target_class) + if not os.path.isdir(target_dir): + continue + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + for fname in sorted(fnames): + path = os.path.join(root, fname) + if is_valid_file(path) and os.path.isfile(path): + item = path, class_index + instances.append(item) + + if target_class not in available_classes: + available_classes.add(target_class) + + empty_classes = set(class_to_idx.keys()) - available_classes + if empty_classes and not allow_empty: + msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " + if extensions is not None: + msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" + raise FileNotFoundError(msg) + + return instances + + +def make_dataset_base( + directory: Union[str, Path], + class_to_idx: Optional[Dict[str, int]] = None, + extensions: Optional[Union[str, Tuple[str, ...]]] = IMG_EXTENSIONS, + is_valid_file: Optional[Callable[[str], bool]] = None, + allow_empty: bool = False, +) -> List[Tuple[str, str, int]]: + """Generates a list of samples of a form (path_to_sample, class). + + See :class:`DatasetFolder` for details. + + Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function + by default. + """ + directory = os.path.expanduser(directory) + + if class_to_idx is None: + _, class_to_idx = torchvision.datasets.folder.find_classes(directory) + elif not class_to_idx: + raise ValueError("'class_to_index' must have at least one entry to collect any samples.") + + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + if both_none or both_something: + raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") + + if extensions is not None: + + def is_valid_file(x: str) -> bool: + return torchvision.datasets.folder.has_file_allowed_extension(x, extensions) # type: ignore[arg-type] + + is_valid_file = cast(Callable[[str], bool], is_valid_file) + + instances = [] + available_classes = set() + for target_class in sorted(class_to_idx.keys()): + class_index = class_to_idx[target_class] + target_dir = os.path.join(directory, target_class) + if not os.path.isdir(target_dir): + continue + for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): + fnames_base = remove_aer_ana(fnames) + for fname in sorted(fnames_base): + fname_ana = fname+'_ANA.npy' + fname_aer = fname + '_AER.npy' + path_ana = os.path.join(root, fname_ana) + path_aer = os.path.join(root, fname_aer) + if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer): + item = path_aer, path_ana, class_index + instances.append(item) + + if target_class not in available_classes: + available_classes.add(target_class) + + empty_classes = set(class_to_idx.keys()) - available_classes + if empty_classes and not allow_empty: + msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " + if extensions is not None: + msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" + raise FileNotFoundError(msg) + + return instances + + +class ImageFolder(data.Dataset): + def __init__(self, root, transform=None, target_transform=None, + flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None, ref_transform=None): + self.root = root + self.imlist = flist_reader(root) + self.transform = transform + self.target_transform = target_transform + self.ref_transform = ref_transform + self.loader = loader + self.classes = torchvision.datasets.folder.find_classes(root)[0] + self.ref_dir = ref_dir + + + def __getitem__(self, index): + impath, target = self.imlist[index] + img = self.loader(impath) + class_ref = self.classes[target] + path_ref = self.ref_dir +'/'+ class_ref + '.npy' + img_ref = self.loader(path_ref) + if self.transform is not None: + img = self.transform(img) + img_ref = self.ref_transform(img_ref) + return img, img_ref + + def __len__(self): + return len(self.imlist) + +class ImageFolderDuo(data.Dataset): + def __init__(self, root, transform=None, target_transform=None, + flist_reader=make_dataset_base, loader=npy_loader, ref_transform=None): + self.root = root + self.imlist = flist_reader(root) + self.transform = transform + self.target_transform = target_transform + self.ref_transform = ref_transform + self.loader = loader + self.classes = torchvision.datasets.folder.find_classes(root)[0] + + def __getitem__(self, index): + impathAER, impathANA, target = self.imlist[index] + imgAER = self.loader(impathAER) + imgANA = self.loader(impathANA) + if self.transform is not None: + imgAER = self.transform(imgAER) + imgANA = self.transform(imgANA) + return imgAER, imgANA, target + + def __len__(self): + return len(self.imlist) + +def load_data_duo(base_dir_train, base_dir_val, base_dir_test, batch_size, shuffle=True,ref_dir = None, sampler=None): + + print('Default val transform') + + train_dataset = ImageFolder(root=base_dir_train, ref_dir = ref_dir) + val_dataset = ImageFolder(root=base_dir_val, ref_dir = ref_dir) + + train_dataset_classifier = ImageFolderDuo(root=base_dir_train) + val_dataset_classifier = ImageFolderDuo(root=base_dir_val) + + if base_dir_test is not None : + test_dataset = ImageFolder(root=base_dir_test, ref_dir=ref_dir) + + test_dataset_classifier = ImageFolderDuo(root=base_dir_test) + + + if sampler =='balanced' : + y_train_label = np.array([i for (_,i)in train_dataset.imlist]) + class_sample_count = np.array([len(np.where(y_train_label == t)[0]) for t in np.unique(y_train_label)]) + weight = 1. / class_sample_count + samples_weight = np.array([weight[t] for t in y_train_label]) + + samples_weight = torch.from_numpy(samples_weight) + sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight)) + + data_loader_train = data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + sampler=sampler, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) + + data_loader_train_classifier = data.DataLoader( + dataset=train_dataset_classifier, + batch_size=batch_size, + sampler=sampler, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) + else : + data_loader_train = data.DataLoader( + dataset=train_dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) + + data_loader_train_classifier = data.DataLoader( + dataset=train_dataset_classifier, + batch_size=batch_size, + shuffle=shuffle, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) + + data_loader_val = data.DataLoader( + dataset=val_dataset, + batch_size=1, + shuffle=shuffle, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) + + data_val_classifier = data.DataLoader( + dataset=val_dataset_classifier, + batch_size=batch_size, + shuffle=shuffle, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) + + if base_dir_test is not None : + data_loader_test = data.DataLoader( + dataset=test_dataset, + batch_size=1, + shuffle=shuffle, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) + + data_test_classifier = data.DataLoader( + dataset=test_dataset_classifier, + batch_size=batch_size, + shuffle=shuffle, + num_workers=0, + collate_fn=None, + pin_memory=False, + ) + else : + data_loader_test = None + + data_test_classifier = None + + return data_loader_train, data_loader_val, data_loader_test, data_loader_train_classifier, data_val_classifier, data_test_classifier diff --git a/barlow_twin_like/main.py b/barlow_twin_like/main.py index e69de29b..43cc536b 100644 --- a/barlow_twin_like/main.py +++ b/barlow_twin_like/main.py @@ -0,0 +1,193 @@ +import os + +import numpy as np +import torch +import wandb as wdb +from torch import optim, nn + +from model import BarlowTwins, BaseClassifier +from dataset_barlow import load_data_duo +from config import load_args_barlow + +def save_model(model, path): + print('Model saved') + torch.save(model.state_dict(), path) + + +def train_representation(model, data_train, optimizer, epoch, wandb): + model.train() + losses = 0. + for param in model.parameters(): + param.requires_grad = True + + for img, img_ref in data_train: + img = img.float() + img_ref = img_ref.float() + if torch.cuda.is_available(): + img = img.cuda() + img_ref = img_ref.cuda() + loss = model.forward(img, img_ref) + losses += loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses = losses / len(data_train.dataset) + print('Train epoch {}, loss : {:.3f}'.format(epoch, losses)) + + if wandb is not None: + wdb.log({"train loss": losses, 'train epoch': epoch}) + + return losses + + +def test_representation(model, data_val, epoch, wandb): + model.eval() + losses = 0. + for param in model.parameters(): + param.requires_grad = False + + for img, img_ref in data_val: + img = img.float() + img_ref = img_ref.float() + if torch.cuda.is_available(): + img = img.cuda() + img_ref = img_ref.cuda() + loss = model.forward(img, img_ref) + losses += loss.item() + losses = losses / len(data_val.dataset) + print('Val epoch {}, loss : {:.3f}'.format(epoch, losses)) + + if wandb is not None: + wdb.log({"train loss": losses, 'train epoch': epoch}) + + return losses + + +def train_classification(model, classifier, data_train, optimizer, epoch, wandb): + classifier.train() + losses = 0. + acc = 0. + loss_function = nn.CrossEntropyLoss() + for param in classifier.parameters(): + param.requires_grad = True + + for img, label in data_train: + img = img.float() + label = label.long() + if torch.cuda.is_available(): + img = img.cuda() + label = label.cuda() + representation = model(img) + pred_logits = classifier(representation) + pred_class = torch.argmax(pred_logits, dim=1) + acc += (pred_class == label).sum().item() + loss = loss_function(pred_logits, label) + losses += loss.item() + optimizer.zero_grad() + loss.backward() + optimizer.step() + losses = losses / len(data_train.dataset) + acc = acc / len(data_train.dataset) + print('Train classifier epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc)) + + if wandb is not None: + wdb.log({"train classifier loss": losses, 'train classifier epoch': epoch, "train classifier accuracy": acc}) + + return losses, acc + + +def test_classification(model, classifier, data_val, epoch, wandb): + classifier.train() + losses = 0. + acc = 0. + loss_function = nn.CrossEntropyLoss() + for param in classifier.parameters(): + param.requires_grad = False + + for img, label in data_val: + img = img.float() + label = label.long() + if torch.cuda.is_available(): + img = img.cuda() + label = label.cuda() + representation = model(img) + pred_logits = classifier(representation) + pred_class = torch.argmax(pred_logits, dim=1) + acc += (pred_class == label).sum().item() + loss = loss_function(pred_logits, label) + losses += loss.item() + losses = losses / len(data_val.dataset) + acc = acc / len(data_val.dataset) + print('Val classifier epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc)) + + if wandb is not None: + wdb.log({"val classifier loss": losses, 'val classifier epoch': epoch, "val classifier accuracy": acc}) + + return losses, acc + + +def run(): + args = load_args_barlow() + # wandb init + if args.wandb is not None: + os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd' + + os.environ["WANDB_MODE"] = "offline" + os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run") + + wdb.init(project="barlow_classification", dir='./wandb_run', + config={'hparam/optimizer': args.opti, 'hparam/sampler': args.sampler, + 'hparam/lr': args.lr, 'hparam/data_train': args.dataset_train_dir, + 'hparam/data_val': args.dataset_train_dir, + 'hparam/model': args.model, + 'hparam/ref_dir': args.dataset_ref_dir}) + + print('Wandb initialised') + # load data + data_train, data_val, data_test, data_train_classifier, data_val_classifier, data_test_classifier = ( + load_data_duo(base_dir_train=args.dataset_train_dir, + base_dir_val=args.dataset_val_dir, + base_dir_test=args.dataset_test_dir, + batch_size=args.batch_size, + ref_dir=args.dataset_ref_dir, + sampler=args.sampler)) + + # load model + model = BarlowTwins(args) + classifier = BaseClassifier(args) + model.float() + classifier.float() + # load weight + if args.pretrain_path is not None: + print('Model weight loaded') + model.load_state_dict(torch.load(args.pretrain_path, weights_only=True)) + # move parameters to GPU + if torch.cuda.is_available(): + print('Model loaded on GPU') + model = model.cuda() + classifier = classifier.cuda() + + if args.opti == 'adam': + optimizer = optim.Adam(model.parameters(), lr=args.lr) + else: + optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9) + + best_loss = np.inf + for e in args.epoches: + loss = train_representation(model, data_train, optimizer, e, args.wandb) + if e % args.eval_inter == 0: + loss = test_representation(model, data_val, e, args.wandb) + if loss < best_loss: + save_model(model, args.save_path) + best_loss = loss + + model.load_state_dict((torch.load(args.save_path, weights_only=True))) #load best model + for param in model.parameters(): # freezing representations before classifier training + param.requires_grad = False + + for e in args.classification_epoches: + train_classification(model, classifier, data_train_classifier, optimizer, e, args.wandb) + test_classification() + +if __name__ == '__main__': + run() \ No newline at end of file diff --git a/barlow_twin_like/model.py b/barlow_twin_like/model.py index e69de29b..35cb5bb1 100644 --- a/barlow_twin_like/model.py +++ b/barlow_twin_like/model.py @@ -0,0 +1,325 @@ +import torch +import torch.nn as nn +import torchvision + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, in_channels=3): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.fc(x) + + return x + + def forward(self, x): + return self._forward_impl(x) + + +def _resnet(block, layers, **kwargs): + model = ResNet(block, layers, **kwargs) + + return model + + +def resnet18(num_classes=1000,**kwargs): + r"""ResNet-18 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ + + Args: + """ + return _resnet(BasicBlock, [2, 2, 2, 2],num_classes=num_classes, + **kwargs) + + + +def resnet34(num_classes=1000, **kwargs): + r"""ResNet-34 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ + + Args: + """ + return _resnet( BasicBlock, [3, 4, 6, 3],num_classes=num_classes, + **kwargs) + + + +def resnet50(num_classes=1000,**kwargs): + r"""ResNet-50 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ + + Args: + """ + return _resnet(Bottleneck, [3, 4, 6, 3],num_classes=num_classes, + **kwargs) + + + +def resnet101(num_classes=1000,**kwargs): + r"""ResNet-101 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ + + Args: + """ + return _resnet(Bottleneck, [3, 4, 23, 3],num_classes=num_classes, + **kwargs) + + + +def resnet152(num_classes=1000,**kwargs): + r"""ResNet-152 model from + `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ + + Args: + """ + return _resnet(Bottleneck, [3, 8, 36, 3],num_classes=num_classes, + **kwargs) + +def off_diagonal(x): + # return a flattened view of the off-diagonal elements of a square matrix + n, m = x.shape + assert n == m + return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() + + +class BarlowTwins(nn.Module): + def __init__(self, args): + super().__init__() + if args.model =='ResNet18': + self.backbone = resnet18(zero_init_residual=True,in_channels=1) + if args.model =='ResNet34': + self.backbone = resnet34(zero_init_residual=True, in_channels=1) + if args.model =='ResNet50': + self.backbone = resnet50(zero_init_residual=True,in_channels=1) + self.args = args + self.backbone.fc = nn.Identity() #remove final fc layer + + # projector + sizes = [2048] + list(map(int, args.projector.split('-'))) + layers = [] + for i in range(len(sizes) - 2): + layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False)) + layers.append(nn.BatchNorm1d(sizes[i + 1])) + layers.append(nn.ReLU(inplace=True)) + layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False)) + self.projector = nn.Sequential(*layers) + + # normalization layer for the representations z1 and z2 + self.bn = nn.BatchNorm1d(sizes[-1], affine=False) + + def forward(self, y1, y2): + z1 = self.projector(self.backbone(y1)) + z2 = self.projector(self.backbone(y2)) + + # empirical cross-correlation matrix + c = self.bn(z1).T @ self.bn(z2) + + # sum the cross-correlation matrix between all gpus + c.div_(self.args.batch_size) + torch.distributed.all_reduce(c) + + on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() + off_diag = off_diagonal(c).pow_(2).sum() + loss = on_diag + self.args.lambd * off_diag + return loss + + def compute_representation(self,y1): + out = self.backbone(y1) + return out + +class BaseClassifier(nn.Module): + def __init__(self, args,n_classes): + super().__init__() + self.classifier = nn.Sequential( + nn.Linear(list(map(int, args.projector.split('-')))[-1]*2,n_classes) + ) + + def forward(self, y1, y2): + input = torch.concat([y1, y2],dim=1) + out = self.classifier(input) + return out diff --git a/requirements.txt b/requirements.txt index b528c6cb..826fb23a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,9 @@ pillow~=11.1.0 seaborn~=0.13.2 scikit-learn~=1.6.1 fastapy~=1.0.5 -pyarrow~=19.0.1 \ No newline at end of file +pyarrow~=19.0.1 +wandb~=0.19.9 +opencv-python~=4.11.0.86 +ray~=2.44.1 +setuptools~=68.2.0 +pythonnet~=3.0.5 \ No newline at end of file -- GitLab