diff --git a/image_ref/analyse_diann_digestion.py b/image_ref/analyse_diann_digestion.py
index 94f8a4c9987161c66a8ff9087b9dbaa9a8d42df6..920a5861329c90f43355b443529a127ce4a542e3 100644
--- a/image_ref/analyse_diann_digestion.py
+++ b/image_ref/analyse_diann_digestion.py
@@ -11,13 +11,13 @@ def load_lib(path):
 
     return table
 
+if __name__ =='__main__':
+    df1 = load_lib('fasta/steigerwaltii variants/uniparc_proteome_UP000033376_2025_03_14.predicted.parquet')
+    df2 = load_lib('fasta/steigerwaltii variants/uniparc_proteome_UP000033499_2025_03_14.predicted.parquet')
 
-df1 = load_lib('fasta/steigerwaltii variants/uniparc_proteome_UP000033376_2025_03_14.predicted.parquet')
-df2 = load_lib('fasta/steigerwaltii variants/uniparc_proteome_UP000033499_2025_03_14.predicted.parquet')
+    set1 = set(df1['Stripped.Sequence'].to_list())
+    set2 = set(df2['Stripped.Sequence'].to_list())
 
-set1 = set(df1['Stripped.Sequence'].to_list())
-set2 = set(df2['Stripped.Sequence'].to_list())
-
-venn2((set1, set2), ('Group1', 'Group2'))
-plt.show()
-plt.savefig('fasta_similarity_diann.png')
\ No newline at end of file
+    venn2((set1, set2), ('Group1', 'Group2'))
+    plt.savefig('fasta_similarity_diann.png')
+    plt.show()
\ No newline at end of file
diff --git a/image_ref/main.py b/image_ref/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eda02c6aa42d5ac7503dfd78a558780575c5b4f
--- /dev/null
+++ b/image_ref/main.py
@@ -0,0 +1,4 @@
+#TODO REDO THE DATASET https://discuss.pytorch.org/t/upload-a-customize-data-set-for-multi-regression-task/43413?u=ptrblck
+"""1st method: load 1 image for 1 ref.
+2nd method: load 1 image and all the refs: fine for now, but check how this scales as the number of classes grows.
+3rd method: 2 separate datasets: more storage-efficient but harder to maintain."""
\ No newline at end of file
diff --git a/image_ref/utils.py b/image_ref/utils.py
index 727f5e17b0a939bb0ed6dd10983acb63ee628d37..900b6e2f0a7be0cb1c8879c701dcd452105d0b0e 100644
--- a/image_ref/utils.py
+++ b/image_ref/utils.py
@@ -202,7 +202,7 @@ def build_ref_image(path_fasta, possible_charge, ms1_end_mz, ms1_start_mz, bin_m
     return im
 
 
-def build_ref_image_from_diann(path_parqet, ms1_end_mz, ms1_start_mz, bin_mz, max_cycle, rt_pred):
+def build_ref_image_from_diann(path_parqet, ms1_end_mz, ms1_start_mz, bin_mz, max_cycle, min_rt=None, max_rt=None):
 
 
     df = load_lib(path_parqet)
@@ -212,8 +212,10 @@ def build_ref_image_from_diann(path_parqet, ms1_end_mz, ms1_start_mz, bin_mz, ma
     total_ms1_mz = ms1_end_mz - ms1_start_mz
     n_bin_ms1 = int(total_ms1_mz // bin_mz)
     im = np.zeros([max_cycle, n_bin_ms1])
-    max_rt = np.max(df_unique['RT'])
-    min_rt = np.min(df_unique['RT'])
+    if max_rt is None:
+        max_rt = np.max(df_unique['RT'])
+    if min_rt is None:
+        min_rt = np.min(df_unique['RT'])
     total_rt = max_rt - min_rt +1e-3
     for row in df_unique.iterrows() :
         if 900 > int(((row[1]['Precursor.Mz']-ms1_start_mz)/total_ms1_mz)*n_bin_ms1) >= 0:
@@ -230,8 +232,13 @@ if __name__ == '__main__':
     # mpimg.imsave('test_img.png', im)
 
     df = build_database_ref_peptide()
+    df_full = load_lib('fasta/full proteom/steigerwaltii variants/uniparc_proteome_UP000033376_2025_03_14.predicted.parquet')
+    min_rt = df_full['RT'].min()
+    max_rt = df_full['RT'].max()
     for spe in ['Proteus mirabilis','Klebsiella pneumoniae','Klebsiella oxytoca','Enterobacter hormaechei','Citrobacter freundii']:
-        df_spe = df[df['Specie']==spe]
-        with open(spe+'.fasta','w') as f:
-            for r in df_spe.iterrows():
-                f.write(r[1]['Sequence'])
+        im = build_ref_image_from_diann(
+            'fasta/optimal peptide set/'+spe+'.parquet', ms1_end_mz=1250,
+            ms1_start_mz=350, bin_mz=1, max_cycle=663, min_rt=min_rt, max_rt=max_rt)
+        plt.clf()
+        mpimg.imsave(spe+'.png', im)
+
diff --git a/models/model.py b/models/model.py
index cea9caa2a300f4832e708b618bdca4fff20aca10..8b03f7eea1f1edbde7df6d59025363b9137457ea 100644
--- a/models/model.py
+++ b/models/model.py
@@ -106,7 +106,7 @@ class ResNet(nn.Module):
 
     def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                  groups=1, width_per_group=64, replace_stride_with_dilation=None,
-                 norm_layer=None):
+                 norm_layer=None, in_channels=3):
         super(ResNet, self).__init__()
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
@@ -123,7 +123,7 @@ class ResNet(nn.Module):
                              "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
         self.groups = groups
         self.base_width = width_per_group
-        self.conv1 = nn.Conv2d(1, self.inplanes, kernel_size=7, stride=2, padding=3,
+        self.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=7, stride=2, padding=3,
                                bias=False)
         self.bn1 = norm_layer(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
@@ -266,19 +266,32 @@ class Classification_model(nn.Module):
         super().__init__(*args, **kwargs)
         self.n_class = n_class
         if model =='ResNet18':
-            self.im_encoder = resnet18(num_classes=self.n_class)
+            self.im_encoder = resnet18(num_classes=self.n_class, in_channels=1)
 
 
     def forward(self, input):
         return self.im_encoder(input)
 
+class Classification_model_contrastive(nn.Module):
+
+    def __init__(self, model, n_class, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.n_class = n_class
+        if model =='ResNet18':
+            self.im_encoder = resnet18(num_classes=self.n_class, in_channels=2)
+
+
+    def forward(self, input, ref):
+        input = torch.concat((input, ref), dim=1)
+        return self.im_encoder(input)
+
 class Classification_model_duo(nn.Module):
 
     def __init__(self, model, n_class, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.n_class = n_class
         if model =='ResNet18':
-            self.im_encoder = resnet18(num_classes=self.n_class)
+            self.im_encoder = resnet18(num_classes=self.n_class, in_channels=1)
 
         self.predictor = nn.Linear(in_features=self.n_class*2,out_features=self.n_class)