From 85d9b6479ac01fe49fdcc8b7d4e4d113a3bf4171 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Mon, 31 Mar 2025 13:18:21 +0200
Subject: [PATCH] add: contrastive model and dataloader

---
 config/config.py       |  21 +++
 dataset/dataset_ref.py | 186 ++++++++++++++++++++++++++
 image_ref/main.py      | 282 ++++++++++++++++++++++++++++++++++++++-
 image_ref/model.py     | 294 +++++++++++++++++++++++++++++++++++++++++
 image_ref/utils.py     |   1 +
 models/model.py        |  13 +-
 requirements.txt       |   5 +-
 7 files changed, 788 insertions(+), 14 deletions(-)
 create mode 100644 dataset/dataset_ref.py
 create mode 100644 image_ref/model.py

diff --git a/config/config.py b/config/config.py
index fca846d..4975373 100644
--- a/config/config.py
+++ b/config/config.py
@@ -19,3 +19,24 @@ def load_args():
     args = parser.parse_args()
 
     return args
+
+def load_args_contrastive():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--epoches', type=int, default=3)
+    parser.add_argument('--save_inter', type=int, default=50)
+    parser.add_argument('--eval_inter', type=int, default=1)
+    parser.add_argument('--noise_threshold', type=int, default=0)
+    parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--batch_size', type=int, default=64)
+    parser.add_argument('--model', type=str, default='ResNet18')
+    parser.add_argument('--model_type', type=str, default='duo')
+    parser.add_argument('--dataset_dir', type=str, default='../data/processed_data/npy_image/data_training')
+    parser.add_argument('--dataset_ref_dir', type=str, default='image_ref/img_ref')
+    parser.add_argument('--output', type=str, default='../output/out_contrastive.csv')
+    parser.add_argument('--save_path', type=str, default='../output/best_model_contrastive.pt')
+    parser.add_argument('--pretrain_path', type=str, default=None)
+    args = parser.parse_args()
+
+    return args
+
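
A minimal sketch (not part of the patch) of how the new entry point is meant to be driven; the flag names come from load_args_contrastive() above, and the values shown are just its defaults:

    # e.g.:  python image_ref/main.py --model_type duo --epoches 10 --lr 0.001
    from config.config import load_args_contrastive

    args = load_args_contrastive()
    print(args.model, args.model_type, args.batch_size)  # ResNet18 duo 64 when no flags are given
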
diff --git a/dataset/dataset_ref.py b/dataset/dataset_ref.py
new file mode 100644
index 0000000..b77f9c7
--- /dev/null
+++ b/dataset/dataset_ref.py
@@ -0,0 +1,186 @@
+import numpy as np
+import torch
+import torchvision
+import torchvision.transforms as transforms
+import torch.utils.data as data
+from PIL import Image
+import os
+import os.path
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from pathlib import Path
+from collections import OrderedDict
+from sklearn.model_selection import train_test_split
+
+IMG_EXTENSIONS = ".npy"
+
+class Threshold_noise:
+    """Set intensities below the given threshold to zero"""
+
+    def __init__(self, threshold=100.):
+        self.threshold = threshold
+
+    def __call__(self, x):
+        return torch.where((x <= self.threshold), 0., x)
+
+class Log_normalisation:
+    """Log normalisation of intensities"""
+
+    def __init__(self, eps=1e-5):
+        self.epsilon = eps
+
+    def __call__(self, x):
+        return torch.log(x + 1 + self.epsilon) / torch.log(torch.max(x) + 1 + self.epsilon)
+
+class Random_shift_rt():
+    # placeholder for a future random retention-time shift augmentation
+    pass
+
+
+def default_loader(path):
+    return Image.open(path).convert('RGB')
+
+def npy_loader(path):
+    sample = torch.from_numpy(np.load(path))
+    sample = sample.unsqueeze(0)
+    return sample
+
+def remove_aer_ana(l):
+    """Strip the _AER/_ANA suffix and deduplicate, keeping order"""
+    l = map(lambda x: x.split('_')[0], l)
+    return list(OrderedDict.fromkeys(l))
+
+def make_dataset_custom(
+    directory: Union[str, Path],
+    class_to_idx: Optional[Dict[str, int]] = None,
+    extensions: Optional[Union[str, Tuple[str, ...]]] = IMG_EXTENSIONS,
+    is_valid_file: Optional[Callable[[str], bool]] = None,
+    allow_empty: bool = False,
+) -> List[Tuple[str, str, int]]:
+    """Generates a list of samples of the form (path_to_aer_sample, path_to_ana_sample, class_index).
+
+    See :class:`DatasetFolder` for details.
+
+    Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
+    by default.
+    """
+    directory = os.path.expanduser(directory)
+
+    if class_to_idx is None:
+        _, class_to_idx = torchvision.datasets.folder.find_classes(directory)
+    elif not class_to_idx:
+        raise ValueError("'class_to_idx' must have at least one entry to collect any samples.")
+
+    both_none = extensions is None and is_valid_file is None
+    both_something = extensions is not None and is_valid_file is not None
+    if both_none or both_something:
+        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
+
+    if extensions is not None:
+
+        def is_valid_file(x: str) -> bool:
+            return torchvision.datasets.folder.has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]
+
+    is_valid_file = cast(Callable[[str], bool], is_valid_file)
+
+    instances = []
+    available_classes = set()
+    for target_class in sorted(class_to_idx.keys()):
+        class_index = class_to_idx[target_class]
+        target_dir = os.path.join(directory, target_class)
+        if not os.path.isdir(target_dir):
+            continue
+        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
+            fnames_base = remove_aer_ana(fnames)
+            for fname in sorted(fnames_base):
+                fname_ana = fname + '_ANA.npy'
+                fname_aer = fname + '_AER.npy'
+                path_ana = os.path.join(root, fname_ana)
+                path_aer = os.path.join(root, fname_aer)
+                if is_valid_file(path_ana) and is_valid_file(path_aer) and os.path.isfile(path_ana) and os.path.isfile(path_aer):
+                    item = path_aer, path_ana, class_index
+                    instances.append(item)
+
+                    if target_class not in available_classes:
+                        available_classes.add(target_class)
+
+    empty_classes = set(class_to_idx.keys()) - available_classes
+    if empty_classes and not allow_empty:
+        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
+        if extensions is not None:
+            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
+        raise FileNotFoundError(msg)
+
+    return instances
+
+
+class ImageFolderDuo(data.Dataset):
+    def __init__(self, root, transform=None, target_transform=None,
+                 flist_reader=make_dataset_custom, loader=npy_loader, ref_dir=None):
+        self.root = root
+        self.imlist = flist_reader(root)
+        self.transform = transform
+        self.target_transform = target_transform
+        self.loader = loader
+        self.classes = torchvision.datasets.folder.find_classes(root)[0]
+        self.ref_dir = ref_dir
+
+    def __getitem__(self, index):
+        impathAER, impathANA, target = self.imlist[index]
+        imgAER = self.loader(impathAER)
+        imgANA = self.loader(impathANA)
+        # draw a reference class uniformly (np.random.randint excludes the high bound)
+        label_ref = np.random.randint(0, len(self.classes))
+        class_ref = self.classes[label_ref]
+        path_ref = os.path.join(self.ref_dir, class_ref + '.npy')
+        if self.transform is not None:
+            imgAER = self.transform(imgAER)
+            imgANA = self.transform(imgANA)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+        # contrastive target: 0 if the sample belongs to the drawn reference class, 1 otherwise
+        target = 0 if target == label_ref else 1
+        img_ref = self.loader(path_ref)
+        return imgAER, imgANA, img_ref, target
+
+    def __len__(self):
+        return len(self.imlist)
+
+def load_data_duo(base_dir, batch_size, shuffle=True, noise_threshold=0, ref_dir=None):
+    train_transform = transforms.Compose(
+        [transforms.Resize((224, 224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default train transform')
+
+    val_transform = transforms.Compose(
+        [transforms.Resize((224, 224)),
+         Threshold_noise(noise_threshold),
+         Log_normalisation(),
+         transforms.Normalize(0.5, 0.5)])
+    print('Default val transform')
+
+    train_dataset = ImageFolderDuo(root=base_dir, transform=train_transform, ref_dir=ref_dir)
+    val_dataset = ImageFolderDuo(root=base_dir, transform=val_transform, ref_dir=ref_dir)
+    generator1 = torch.Generator().manual_seed(42)
+    indices = torch.randperm(len(train_dataset), generator=generator1)
+    val_size = len(train_dataset) // 5
+    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])
+    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])
+
+    data_loader_train = data.DataLoader(
+        dataset=train_dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    data_loader_test = data.DataLoader(
+        dataset=val_dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    return data_loader_train, data_loader_test
+
+def load_data():
+    raise NotImplementedError('single-image dataset not implemented yet')
\ No newline at end of file
diff --git a/image_ref/main.py b/image_ref/main.py
index 0eda02c..b659004 100644
--- a/image_ref/main.py
+++ b/image_ref/main.py
@@ -1,4 +1,284 @@
 #TODO REDO THE DATASET https://discuss.pytorch.org/t/upload-a-customize-data-set-for-multi-regression-task/43413?u=ptrblck
 """Method 1: load one image for one reference.
 Method 2: load one image together with all the references: fine for now, but it is unclear how this scales as the number of classes grows.
-Method 3: two separate datasets: more storage-efficient but not easy to maintain."""
\ No newline at end of file
+Method 3: two separate datasets: more storage-efficient but not easy to maintain."""
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from config.config import load_args_contrastive
+from dataset.dataset_ref import load_data, load_data_duo
+import torch
+import torch.nn as nn
+from image_ref.model import Classification_model_contrastive, Classification_model_duo_contrastive
+import torch.optim as optim
+from sklearn.metrics import confusion_matrix
+import seaborn as sn
+import pandas as pd
+
+
+
+def train(model, data_train, optimizer, loss_function, epoch):
+    model.train()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = True
+
+    for im, im_ref, label in data_train:
+        label = label.long()
+        if torch.cuda.is_available():
+            im, im_ref, label = im.cuda(), im_ref.cuda(), label.cuda()
+        pred_logits = model(im, im_ref)
+        pred_class = torch.argmax(pred_logits, dim=1)
+        acc += (pred_class == label).sum().item()
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    losses = losses / len(data_train.dataset)
+    acc = acc / len(data_train.dataset)
+    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
+    return losses, acc
+
+def test(model, data_test, loss_function, epoch):
+    model.eval()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = False
+
+    for im, im_ref, label in data_test:
+        label = label.long()
+        if torch.cuda.is_available():
+            im, im_ref, label = im.cuda(), im_ref.cuda(), label.cuda()
+        pred_logits = model(im, im_ref)
+        pred_class = torch.argmax(pred_logits, dim=1)
+        acc += (pred_class == label).sum().item()
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
+    losses = losses / len(data_test.dataset)
+    acc = acc / len(data_test.dataset)
+    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
+    return losses, acc
+
+def run(args):
+    # load data
+    data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size)
+    # load model
+    model = Classification_model_contrastive(model=args.model, n_class=2)
+    # load weights
+    if args.pretrain_path is not None:
+        load_model(model, args.pretrain_path)
+    # move parameters to GPU
+    if torch.cuda.is_available():
+        model = model.cuda()
+    # init accumulators
+    best_acc = 0
+    train_acc = []
+    train_loss = []
+    val_acc = []
+    val_loss = []
+    # init training
+    loss_function = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
+    # train
+    for e in range(args.epoches):
+        loss, acc = train(model, data_train, optimizer, loss_function, e)
+        train_loss.append(loss)
+        train_acc.append(acc)
+        if e % args.eval_inter == 0:
+            loss, acc = test(model, data_test, loss_function, e)
+            val_loss.append(loss)
+            val_acc.append(acc)
+            if acc > best_acc:
+                save_model(model, args.save_path)
+                best_acc = acc
+    # plot and save training figs (save before show: show() clears the current figure)
+    plt.plot(train_acc, label='train acc')
+    plt.plot(val_acc, label='val acc')
+    plt.legend()
+    plt.ylim(0, 1.05)
+    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold, args.lr, args.model, args.model_type))
+    plt.show()
+
+    # load and evaluate best model
+    load_model(model, args.save_path)
+    make_prediction(model, data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold, args.lr, args.model, args.model_type))
+
+
+def make_prediction(model, data, f_name):
+    y_pred = []
+    y_true = []
+
+    # iterate over test data
+    for im, im_ref, label in data:
+        label = label.long()
+        if torch.cuda.is_available():
+            im, im_ref = im.cuda(), im_ref.cuda()
+        output = model(im, im_ref)
+
+        output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
+        y_pred.extend(output)
+
+        label = label.data.cpu().numpy()
+        y_true.extend(label)  # save ground truth
+
+    # contrastive target convention: 0 = matches the reference class, 1 = does not
+    classes = ['match', 'mismatch']
+
+    # build confusion matrix (row-normalised)
+    cf_matrix = confusion_matrix(y_true, y_pred)
+    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
+                         columns=[i for i in classes])
+    plt.figure(figsize=(12, 7))
+    sn.heatmap(df_cm, annot=cf_matrix)
+    plt.savefig(f_name)
+
+
+def train_duo(model, data_train, optimizer, loss_function, epoch):
+    model.train()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = True
+
+    # the duo dataset yields (AER image, ANA image, reference image, binary target)
+    for imaer, imana, img_ref, label in data_train:
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            img_ref = img_ref.cuda()
+            label = label.cuda()
+        pred_logits = model(imaer, imana, img_ref)
+        pred_class = torch.argmax(pred_logits, dim=1)
+        acc += (pred_class == label).sum().item()
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    losses = losses / len(data_train.dataset)
+    acc = acc / len(data_train.dataset)
+    print('Train epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
+    return losses, acc
+
+def test_duo(model, data_test, loss_function, epoch):
+    model.eval()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = False
+
+    for imaer, imana, img_ref, label in data_test:
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            img_ref = img_ref.cuda()
+            label = label.cuda()
+        pred_logits = model(imaer, imana, img_ref)
+        pred_class = torch.argmax(pred_logits, dim=1)
+        acc += (pred_class == label).sum().item()
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
+    losses = losses / len(data_test.dataset)
+    acc = acc / len(data_test.dataset)
+    print('Test epoch {}, loss : {:.3f} acc : {:.3f}'.format(epoch, losses, acc))
+    return losses, acc
+
+def run_duo(args):
+    # load data
+    data_train, data_test = load_data_duo(base_dir=args.dataset_dir, batch_size=args.batch_size,
+                                          noise_threshold=args.noise_threshold, ref_dir=args.dataset_ref_dir)
+    # load model (binary output: does the pair match the reference class or not?)
+    model = Classification_model_duo_contrastive(model=args.model, n_class=2)
+    model.double()
+    # load weights
+    if args.pretrain_path is not None:
+        load_model(model, args.pretrain_path)
+    # move parameters to GPU
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    # init accumulators
+    best_acc = 0
+    train_acc = []
+    train_loss = []
+    val_acc = []
+    val_loss = []
+    # init training
+    loss_function = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
+    # train model
+    for e in range(args.epoches):
+        loss, acc = train_duo(model, data_train, optimizer, loss_function, e)
+        train_loss.append(loss)
+        train_acc.append(acc)
+        if e % args.eval_inter == 0:
+            loss, acc = test_duo(model, data_test, loss_function, e)
+            val_loss.append(loss)
+            val_acc.append(acc)
+            if acc > best_acc:
+                save_model(model, args.save_path)
+                best_acc = acc
+    # plot and save training figs (save before show: show() clears the current figure)
+    plt.plot(train_acc, label='train acc')
+    plt.plot(val_acc, label='val acc')
+    plt.legend()
+    plt.ylim(0, 1.05)
+    plt.savefig('output/training_plot_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold, args.lr, args.model, args.model_type))
+    plt.show()
+
+    # load and evaluate best model
+    load_model(model, args.save_path)
+    make_prediction_duo(model, data_test, 'output/confusion_matrix_noise_{}_lr_{}_model_{}_{}.png'.format(args.noise_threshold, args.lr, args.model, args.model_type))
+
+
+def make_prediction_duo(model, data, f_name):
+    y_pred = []
+    y_true = []
+    # iterate over test data
+    for imaer, imana, img_ref, label in data:
+        label = label.long()
+        if torch.cuda.is_available():
+            imaer = imaer.cuda()
+            imana = imana.cuda()
+            img_ref = img_ref.cuda()
+            label = label.cuda()
+        output = model(imaer, imana, img_ref)
+
+        output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
+        y_pred.extend(output)
+
+        label = label.data.cpu().numpy()
+        y_true.extend(label)  # save ground truth
+
+    # contrastive target convention: 0 = matches the reference class, 1 = does not
+    classes = ['match', 'mismatch']
+    # build confusion matrix (row-normalised)
+    print(len(y_true), len(y_pred))
+    cf_matrix = confusion_matrix(y_true, y_pred)
+    df_cm = pd.DataFrame(cf_matrix / np.sum(cf_matrix, axis=1)[:, None], index=[i for i in classes],
+                         columns=[i for i in classes])
+    print('Saving Confusion Matrix')
+    plt.figure(figsize=(14, 9))
+    sn.heatmap(df_cm, annot=cf_matrix)
+    plt.savefig(f_name)
+
+
+def save_model(model, path):
+    print('Model saved')
+    torch.save(model.state_dict(), path)
+
+def load_model(model, path):
+    model.load_state_dict(torch.load(path, weights_only=True))
+
+
+
+if __name__ == '__main__':
+    args = load_args_contrastive()
+    if args.model_type == 'duo':
+        run_duo(args)
+    else:
+        run(args)
\ No newline at end of file
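
Training can also be launched programmatically rather than through the CLI (a sketch under the same assumptions as above: the default paths exist and an output/ directory is present for the plots):

    from config.config import load_args_contrastive
    from image_ref.main import run_duo

    if __name__ == '__main__':
        args = load_args_contrastive()  # --model_type defaults to 'duo'
        run_duo(args)  # trains, checkpoints the best model, saves the curves and confusion matrix
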
diff --git a/image_ref/model.py b/image_ref/model.py
new file mode 100644
index 0000000..0374d1e
--- /dev/null
+++ b/image_ref/model.py
@@ -0,0 +1,294 @@
+import torch
+import torch.nn as nn
+import torchvision
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2)
+    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
+    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
+    # This variant is also known as ResNet V1.5 and improves accuracy according to
+    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None, in_channels=3):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x):
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+
+def _resnet(block, layers, **kwargs):
+    model = ResNet(block, layers, **kwargs)
+
+    return model
+
+
+def resnet18(num_classes=1000, **kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    """
+    return _resnet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs)
+
+
+def resnet34(num_classes=1000, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    """
+    return _resnet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, **kwargs)
+
+
+def resnet50(num_classes=1000, **kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    """
+    return _resnet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs)
+
+
+def resnet101(num_classes=1000, **kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    """
+    return _resnet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, **kwargs)
+
+
+def resnet152(num_classes=1000, **kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+    """
+    return _resnet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, **kwargs)
+
+
+class Classification_model_contrastive(nn.Module):
+
+    def __init__(self, model, n_class, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.n_class = n_class
+        if model == 'ResNet18':
+            self.im_encoder = resnet18(num_classes=self.n_class, in_channels=2)
+
+    def forward(self, input, ref):
+        # stack the sample and its reference on the channel axis: (B, 2, H, W)
+        input = torch.cat((input, ref), dim=1)
+        return self.im_encoder(input)
+
+class Classification_model_duo_contrastive(nn.Module):
+
+    def __init__(self, model, n_class, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.n_class = n_class
+        if model == 'ResNet18':
+            self.im_encoder = resnet18(num_classes=self.n_class, in_channels=2)
+
+        self.predictor = nn.Linear(in_features=self.n_class * 2, out_features=self.n_class)
+
+    def forward(self, input_aer, input_ana, input_ref):
+        # stack each view with the shared reference on the channel axis: (B, 2, H, W)
+        input_ana = torch.cat((input_ana, input_ref), dim=1)
+        input_aer = torch.cat((input_aer, input_ref), dim=1)
+        out_aer = self.im_encoder(input_aer)
+        out_ana = self.im_encoder(input_ana)
+        out = torch.cat([out_aer, out_ana], dim=1)
+        return self.predictor(out)
\ No newline at end of file
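
A shape check for the two-branch contrastive model (a sketch; random tensors stand in for real spectra, and n_class=2 matches the match/mismatch target produced by the dataloader):

    import torch
    from image_ref.model import Classification_model_duo_contrastive

    model = Classification_model_duo_contrastive(model='ResNet18', n_class=2)
    aer = torch.randn(2, 1, 224, 224)  # _AER view
    ana = torch.randn(2, 1, 224, 224)  # _ANA view
    ref = torch.randn(2, 1, 224, 224)  # class reference image
    out = model(aer, ana, ref)  # each view is stacked with the reference into a 2-channel input
    print(out.shape)  # torch.Size([2, 2]) -> match / mismatch logits
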
diff --git a/image_ref/utils.py b/image_ref/utils.py
index 900b6e2..8182a39 100644
--- a/image_ref/utils.py
+++ b/image_ref/utils.py
@@ -241,4 +241,5 @@ if __name__ == '__main__':
                                 ms1_start_mz=350, bin_mz=1, max_cycle=663, min_rt=min_rt, max_rt=max_rt)
         plt.clf()
         mpimg.imsave(spe+'.png', im)
+        np.save(spe+'.npy', im)
 
diff --git a/models/model.py b/models/model.py
index b59076c..f0a3d83 100644
--- a/models/model.py
+++ b/models/model.py
@@ -272,18 +272,7 @@ class Classification_model(nn.Module):
     def forward(self, input):
         return self.im_encoder(input)
 
-class Classification_model_contrastive(nn.Module):
-
-    def __init__(self, model, n_class, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.n_class = n_class
-        if model =='ResNet18':
-            self.im_encoder = resnet18(num_classes=self.n_class, in_channels=2)
-
-
-    def forward(self, input, ref):
-        input = torch.concat(input,ref,dim=2)
-        return self.im_encoder(input)
 
 class Classification_model_duo(nn.Module):
 
@@ -296,7 +285,7 @@ class Classification_model_duo(nn.Module):
         self.predictor = nn.Linear(in_features=self.n_class*2,out_features=self.n_class)
 
 
-    def forward(self, input_aer, input_ana, input_ref):
+    def forward(self, input_aer, input_ana):
         out_aer = self.im_encoder(input_aer)
         out_ana = self.im_encoder(input_ana)
         out = torch.concat([out_aer,out_ana],dim=1)
diff --git a/requirements.txt b/requirements.txt
index 340cebc..b528c6c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,7 @@ openpyxl
 torch~=2.6.0
 torchvision~=0.21.0
 pillow~=11.1.0
-seaborn~=0.13.2
\ No newline at end of file
+seaborn~=0.13.2
+scikit-learn~=1.6.1
+fastapy~=1.0.5
+pyarrow~=19.0.1
\ No newline at end of file
-- 
GitLab