From 85d9b6479ac01fe49fdcc8b7d4e4d113a3bf4171 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 31 Mar 2025 13:18:21 +0200
Subject: [PATCH] add : model contrastive and dataloader

---
 config/config.py       |  21 +++
 dataset/dataset_ref.py | 186 ++++++++++++++++++++++++++
 image_ref/main.py      | 282 ++++++++++++++++++++++++++++++++++++++-
 image_ref/model.py     | 294 +++++++++++++++++++++++++++++++++++++++++
 image_ref/utils.py     |   1 +
 models/model.py        |  13 +-
 requirements.txt       |   5 +-
 7 files changed, 788 insertions(+), 14 deletions(-)
 create mode 100644 dataset/dataset_ref.py
 create mode 100644 image_ref/model.py

diff --git a/config/config.py b/config/config.py
index fca846d..4975373 100644
--- a/config/config.py
+++ b/config/config.py
@@ -19,3 +19,24 @@ def load_args():
     args = parser.parse_args()
 
     return args
+
+def load_args_contrastive():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--epoches', type=int, default=3)
+    parser.add_argument('--save_inter', type=int, default=50)
+    parser.add_argument('--eval_inter', type=int, default=1)
+    parser.add_argument('--noise_threshold', type=int, default=0)
+    parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--batch_size', type=int, default=64)
+    parser.add_argument('--model', type=str, default='ResNet18')
+    parser.add_argument('--model_type', type=str, default='duo')
+    parser.add_argument('--dataset_dir', type=str, default='../data/processed_data/npy_image/data_training')
+    parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
+    parser.add_argument('--output', type=str, default='../output/out_contrastive.csv')
+    parser.add_argument('--save_path', type=str, default='../output/best_model_constrastive.pt')
+    parser.add_argument('--pretrain_path', type=str, default=None)
+    args = parser.parse_args()
+
+    return args
+
diff --git a/dataset/dataset_ref.py b/dataset/dataset_ref.py
new file mode 100644
index 0000000..b77f9c7
--- /dev/null
+++ b/dataset/dataset_ref.py
@@ -0,0 +1,186 @@
+import numpy as np
+import torch
+import torchvision
+import torchvision.transforms as transforms
+import torch.utils.data as data
+from PIL import Image
+import os
+import os.path
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from pathlib import Path
+from collections import OrderedDict
+from sklearn.model_selection import train_test_split
+IMG_EXTENSIONS = ".npy"
+
+class Threshold_noise:
+    """Remove intensities under given threshold"""
+
+    def __init__(self, threshold=100.):
+        self.threshold = threshold
+
+    def __call__(self, x):
+        return torch.where((x <= self.threshold), 0.,x)
+
+class Log_normalisation:
+    """Log normalisation of intensities"""
+
+    def __init__(self, eps=1e-5):
+        self.epsilon = eps
+
+    def __call__(self, x):
+        return torch.log(x+1+self.epsilon)/torch.log(torch.max(x)+1+self.epsilon)
+
+class Random_shift_rt():
+    pass
+
+
+def default_loader(path):
+    return Image.open(path).convert('RGB')
+
+def npy_loader(path):
+    sample = torch.from_numpy(np.load(path))
+    sample = sample.unsqueeze(0)
+    return sample
+
+def remove_aer_ana(l):
+    l = map(lambda x : x.split('_')[0],l)
+    return list(OrderedDict.fromkeys(l))
+
+def make_dataset_custom(
+    directory: Union[str, Path],
+    class_to_idx: Optional[Dict[str, int]] = None,
+    extensions: Optional[Union[str, Tuple[str, ...]]] = IMG_EXTENSIONS,
+    is_valid_file: Optional[Callable[[str], bool]] = None,
+    allow_empty: bool = False,
+) -> List[Tuple[str, str, int]]:
+    """Generates a list of samples of a form (path_to_sample, class).
+
+    See :class:`DatasetFolder` for details.
+
+    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
+    by default.
+    """
+    directory = os.path.expanduser(directory)
+
+    if class_to_idx is None:
+        _, class_to_idx = torchvision.datasets.folder.find_classes(directory)
+    elif not class_to_idx:
+        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
+
+    both_none = extensions is None and is_valid_file is None
+    both_something = extensions is not None and is_valid_file is not None
+    if both_none or both_something:
+        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
+
+    if extensions is not None:
+
+        def is_valid_file(x: str) -> bool:
+            return torchvision.datasets.folder.has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]
+
+    is_valid_file = cast(Callable[[str], bool], is_valid_file)
+
+    instances = []
+    available_classes = set()
+    for target_class in sorted(class_to_idx.keys()):
+        class_index = class_to_idx[target_class]
+        target_dir = os.path.join(directory, target_class)
+        if not os.path.isdir(target_dir):
+            continue
+        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
+            fnames_base = remove_aer_ana(fnames)
+            for fname in sorted(fnames_base):
+                fname_ana = fname+'_ANA.npy'
+                fname_aer = fname + '_AER.npy'
+                path_ana = os.path.join(root, fname_ana)
+                path_aer = os.path.join(root, fname_aer)
+                if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer):
+                    item = path_aer, path_ana, class_index
+                    instances.append(item)
+
+                    if target_class not in available_classes:
+                        available_classes.add(target_class)
+
+    empty_classes = set(class_to_idx.keys()) - available_classes
+    if empty_classes and not allow_empty:
+        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
+        if extensions is not None:
+            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
+        raise FileNotFoundError(msg)
+
+    return instances
+
+
+class ImageFolderDuo(data.Dataset):
+    def __init__(self, root, transform=None, target_transform=None,
+                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir = None):
+        self.root = root
+        self.imlist = flist_reader(root)
+        self.transform = transform
+        self.target_transform = target_transform
+        self.loader = loader
+        self.classes = torchvision.datasets.folder.find_classes(root)[0]
+        self.ref_dir = ref_dir
+
+    def __getitem__(self, index):
+        impathAER, impathANA, target = self.imlist[index]
+        imgAER = self.loader(impathAER)
+        imgANA = self.loader(impathANA)
+        label_ref = np.random.randint(0,len(self.classes)-1)
+        class_ref = self.classes[label_ref]
+        path_ref = self.ref_dir + class_ref +'.npy'
+        if self.transform is not None:
+            imgAER = self.transform(imgAER)
+            imgANA = self.transform(imgANA)
+        if self.target_transform is not None:
+            target = 0 if self.target_transform(target) == label_ref else 1
+        img_ref = self.loader(path_ref)
+        return imgAER, imgANA, img_ref, target
+
+    def __len__(self):
+        return len(self.imlist)
+
+def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0, ref_dir = None):
+    train_transform = transforms.Compose(
+        [transforms.Resize((224, 224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default train transform')
+
+    val_transform = transforms.Compose(
+        [transforms.Resize((224, 224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default val transform')
+
+    train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform, ref_dir = ref_dir)
+    val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform, ref_dir = ref_dir)
+    generator1 = torch.Generator().manual_seed(42)
+    indices = torch.randperm(len(train_dataset), generator=generator1)
+    val_size = len(train_dataset) // 5
+    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
+    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
+
+    data_loader_train = data.DataLoader(
+        dataset=train_dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    data_loader_test = data.DataLoader(
+        dataset=val_dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    return data_loader_train, data_loader_test
+
+def load_data():
+    raise 'Not implemented'
\ No newline at end of file
diff --git a/image_ref/main.py b/image_ref/main.py
index 0eda02c..b659004 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -1,4 +1,284 @@
 #TODO REFAIRE UN DATASET https://discuss.pytorch.org/t/upload-a-customize-data-set-for-multi-regression-task/43413?u=ptrblck
 """1er methode load 1 image pour 1 ref
 2eme methode : load 1 image et toutes les refs : ok pour l'instant mais a voir comment est ce que cela scale avec l'augmentation du nb de classes
-3eme methods 2 datasets différents : plus efficace en stockage mais pas facil a maintenir"""
\ No newline at end of file
+3eme methods 2 datasets différents : plus efficace en stockage mais pas facil a maintenir"""
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from config.config import load_args
+from dataset.dataset_ref import load_data, load_data_duo
+import torch
+import torch.nn as nn
+from image_ref.model import Classification_model_contrastive, Classification_model_duo_contrastive
+import torch.optim as optim
+from sklearn.metrics import confusion_matrix
+import seaborn as sn
+import pandas as pd
+
+
+
+def train(model, data_train, optimizer, loss_function, epoch):
+    model.train()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = True
+
+    for im, label in data_train:
+        label = label.long()
+        if torch.cuda.is_available():
+            im, label = im.cuda(), label.cuda()
+        pred_logits = model.forward(im)
+        pred_class = torch.argmax(pred_logits,dim=1)
+        acc += (pred_class==label).sum().item()
+        loss = loss_function(pred_logits,label)
+        losses += loss.item()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    losses = losses/len(data_train.dataset)
+    acc = acc/len(data_train.dataset)
+    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
+    return losses, acc
+
+def test(model, data_test, loss_function, epoch):
+    model.eval()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = False
+
+    for im, label in data_test:
+        label = label.long()
+        if torch.cuda.is_available():
+            im, label = im.cuda(), label.cuda()
+        pred_logits = model.forward(im)
+        pred_class = torch.argmax(pred_logits,dim=1)
+        acc += (pred_class==label).sum().item()
+        loss = loss_function(pred_logits,label)
+        losses += loss.item()
+    losses = losses/len(data_test.dataset)
+    acc = acc/len(data_test.dataset)
+    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
+    return losses,acc
+
+def run(args):
+    #load data
+    data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size)
+    #load model
+    model = Classification_model_contrastive(model = args.model, n_class=2,
+                                             ref_dir = args.dataset_ref_dir)
+    #load weights
+    if args.pretrain_path is not None :
+        load_model(model,args.pretrain_path)
+    #move parameters to GPU
+    if torch.cuda.is_available():
+        model = model.cuda()
+    #init accumulator
+    best_acc = 0
+    train_acc=[]
+    train_loss=[]
+    val_acc=[]
+    val_loss=[]
+    #init training
+    loss_function = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+    #traing
+    for e in range(args.epoches):
+        loss, acc = train(model,data_train,optimizer,loss_function,e)
+        train_loss.append(loss)
+        train_acc.append(acc)
+        if e%args.eval_inter==0 :
+            loss, acc = test(model,data_test,loss_function,e)
+            val_loss.append(loss)
+            val_acc.append(acc)
+            if acc > best_acc :
+                save_model(model,args.save_path)
+                best_acc = acc
+    #plot and save training figs
+    plt.plot(train_acc)
+    plt.plot(val_acc)
+    plt.plot(train_acc)
+    plt.plot(train_acc)
+    plt.ylim(0, 1.05)
+    plt.show()
+    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
+
+    #load and evaluated best model
+    load_model(model, args.save_path)
+    make_prediction(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
+
+
+def make_prediction(model, data, f_name):
+    y_pred = []
+    y_true = []
+
+    # iterate over test data
+    for im, label in data:
+        label = label.long()
+        if torch.cuda.is_available():
+            im = im.cuda()
+        output = model(im)
+
+        output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
+        y_pred.extend(output)
+
+        label = label.data.cpu().numpy()
+        y_true.extend(label)  # Save Truth
+    # constant for classes
+
+    classes = data.dataset.dataset.classes
+
+    # Build confusion matrix
+    cf_matrix = confusion_matrix(y_true, y_pred)
+    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
+                         columns=[i for i in classes])
+    plt.figure(figsize=(12, 7))
+    sn.heatmap(df_cm, annot=cf_matrix)
+    plt.savefig(f_name)
+
+
+def train_duo(model, data_train, optimizer, loss_function, epoch):
+    model.train()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = True
+
+    for imaer,imana, label in data_train:
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            label = label.cuda()
+        pred_logits = model.forward(imaer,imana)
+        pred_class = torch.argmax(pred_logits,dim=1)
+        acc += (pred_class==label).sum().item()
+        loss = loss_function(pred_logits,label)
+        losses += loss.item()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    losses = losses/len(data_train.dataset)
+    acc = acc/len(data_train.dataset)
+    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
+    return losses, acc
+
+def test_duo(model, data_test, loss_function, epoch):
+    model.eval()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = False
+
+    for imaer,imana, label in data_test:
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            label = label.cuda()
+        pred_logits = model.forward(imaer,imana)
+        pred_class = torch.argmax(pred_logits,dim=1)
+        acc += (pred_class==label).sum().item()
+        loss = loss_function(pred_logits,label)
+        losses += loss.item()
+    losses = losses/len(data_test.dataset)
+    acc = acc/len(data_test.dataset)
+    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch,losses,acc))
+    return losses,acc
+
+def run_duo(args):
+    #load data
+    data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size, )
+    #load model
+    model = Classification_model_duo_contrastive(model = args.model, n_class=len(data_train.dataset.dataset.classes))
+    model.double()
+    #load weight
+    if args.pretrain_path is not None :
+        load_model(model,args.pretrain_path)
+    #move parameters to GPU
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    #init accumulators
+    best_acc = 0
+    train_acc=[]
+    train_loss=[]
+    val_acc=[]
+    val_loss=[]
+    #init training
+    loss_function = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
+    #train model
+    for e in range(args.epoches):
+        loss, acc = train_duo(model,data_train,optimizer,loss_function,e)
+        train_loss.append(loss)
+        train_acc.append(acc)
+        if e%args.eval_inter==0 :
+            loss, acc = test_duo(model,data_test,loss_function,e)
+            val_loss.append(loss)
+            val_acc.append(acc)
+            if acc > best_acc :
+                save_model(model,args.save_path)
+                best_acc = acc
+    # plot and save training figs
+    plt.plot(train_acc)
+    plt.plot(val_acc)
+    plt.plot(train_acc)
+    plt.plot(train_acc)
+    plt.ylim(0, 1.05)
+    plt.show()
+
+    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
+    #load and evaluate best model
+    load_model(model, args.save_path)
+    make_prediction_duo(model,data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold,args.lr,args.model,args.model_type))
+
+
+def make_prediction_duo(model, data, f_name):
+    y_pred = []
+    y_true = []
+    # iterate over test data
+    for imaer,imana, label in data:
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            label = label.cuda()
+        output = model(imaer,imana)
+
+        output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
+        y_pred.extend(output)
+
+        label = label.data.cpu().numpy()
+        y_true.extend(label)  # Save Truth
+    # constant for classes
+
+    classes = data.dataset.dataset.classes
+    # Build confusion matrix
+    print(len(y_true),len(y_pred))
+    cf_matrix = confusion_matrix(y_true, y_pred)
+    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
+                         columns=[i for i in classes])
+    print('Saving Confusion Matrix')
+    plt.figure(figsize=(14, 9))
+    sn.heatmap(df_cm, annot=cf_matrix)
+    plt.savefig(f_name)
+
+
+def save_model(model, path):
+    print('Model saved')
+    torch.save(model.state_dict(), path)
+
+def load_model(model, path):
+    model.load_state_dict(torch.load(path, weights_only=True))
+
+
+
+if __name__ == '__main__':
+    args = load_args()
+    if args.model_type=='duo':
+        run_duo(args)
+    else :
+        run(args)
\ No newline at end of file
diff --git a/image_ref/model.py b/image_ref/model.py
new file mode 100644
index 0000000..0374d1e
--- /dev/null
+++ b/image_ref/model.py
@@ -0,0 +1,294 @@
+import torch
+import torch.nn as nn
+import torchvision
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+    # This variant is also known as ResNet V1.5 and improves accuracy according to
+    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None, in_channels=3):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x):
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+
+def _resnet(block, layers, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+
+    return model
+
+
+def resnet18(num_classes=1000,**kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+    """
+    return _resnet(BasicBlock, [2, 2, 2, 2],num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet34(num_classes=1000, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+    """
+    return _resnet( BasicBlock, [3, 4, 6, 3],num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet50(num_classes=1000,**kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+    """
+    return _resnet(Bottleneck, [3, 4, 6, 3],num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet101(num_classes=1000,**kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+    """
+    return _resnet(Bottleneck, [3, 4, 23, 3],num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet152(num_classes=1000,**kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+    """
+    return _resnet(Bottleneck, [3, 8, 36, 3],num_classes=num_classes,
+                   **kwargs)
+
+
+class Classification_model_contrastive(nn.Module):
+
+    def __init__(self, model, n_class, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.n_class = n_class
+        if model =='ResNet18':
+            self.im_encoder = resnet18(num_classes=self.n_class, in_channels=2)
+
+
+    def forward(self, input, ref):
+        input = torch.concat(input,ref,dim=2)
+        return self.im_encoder(input)
+
+class Classification_model_duo_contrastive(nn.Module):
+
+    def __init__(self, model, n_class, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.n_class = n_class
+        if model =='ResNet18':
+            self.im_encoder = resnet18(num_classes=self.n_class, in_channels=2)
+
+        self.predictor = nn.Linear(in_features=self.n_class*2,out_features=self.n_class)
+
+
+    def forward(self, input_aer, input_ana, input_ref):
+        input_ana = torch.concat(input_ana, input_ref, dim=2)
+        input_aer = torch.concat(input_aer, input_ref, dim=2)
+        out_aer =  self.im_encoder(input_aer)
+        out_ana = self.im_encoder(input_ana)
+        out = torch.concat([out_aer,out_ana],dim=1)
+        return self.predictor(out)
\ No newline at end of file
diff --git a/image_ref/utils.py b/image_ref/utils.py
index 900b6e2..8182a39 100644
--- a/image_ref/utils.py
+++ b/image_ref/utils.py
@@ -241,4 +241,5 @@ if __name__ == '__main__':
             ms1_start_mz=350, bin_mz=1, max_cycle=663, min_rt=min_rt, max_rt=max_rt)
         plt.clf()
         mpimg.imsave(spe+'.png', im)
+        np.save(spe+'.npy', im)
 
diff --git a/models/model.py b/models/model.py
index b59076c..f0a3d83 100644
--- a/models/model.py
+++ b/models/model.py
@@ -272,18 +272,7 @@ class Classification_model(nn.Module):
     def forward(self, input):
         return self.im_encoder(input)
 
-class Classification_model_contrastive(nn.Module):
 
-    def __init__(self, model, n_class, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.n_class = n_class
-        if model =='ResNet18':
-            self.im_encoder = resnet18(num_classes=self.n_class, in_channels=2)
-
-
-    def forward(self, input, ref):
-        input = torch.concat(input,ref,dim=2)
-        return self.im_encoder(input)
 
 class Classification_model_duo(nn.Module):
 
@@ -296,7 +285,7 @@ class Classification_model_duo(nn.Module):
         self.predictor = nn.Linear(in_features=self.n_class*2,out_features=self.n_class)
 
 
-    def forward(self, input_aer, input_ana, input_ref):
+    def forward(self, input_aer, input_ana):
         out_aer =  self.im_encoder(input_aer)
         out_ana = self.im_encoder(input_ana)
         out = torch.concat([out_aer,out_ana],dim=1)
diff --git a/requirements.txt b/requirements.txt
index 340cebc..b528c6c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,7 @@ openpyxl
 torch~=2.6.0
 torchvision~=0.21.0
 pillow~=11.1.0
-seaborn~=0.13.2
\ No newline at end of file
+seaborn~=0.13.2
+scikit-learn~=1.6.1
+fastapy~=1.0.5
+pyarrow~=19.0.1
\ No newline at end of file
-- 
GitLab