From e180bbf292edf65350e71438a4de711b0f02839e Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 15 Oct 2024 11:43:48 +0200
Subject: [PATCH] Merge intensity data into a pickled DataFrame; test on holdout files

---
 common_dataset.py | 21 ++++++++++++++++++---
 main.py           |  8 ++++----
 2 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/common_dataset.py b/common_dataset.py
index 19f35c9..bab5944 100644
--- a/common_dataset.py
+++ b/common_dataset.py
@@ -137,9 +137,8 @@ class Common_Dataset(Dataset):
         rt = self.data['Retention time'][index]
         intensity = self.data['Spectra'][index]
         charge = self.data['Charge'][index]
-        file = self.data['file'][index]
-
-        if self.file_mode :
+        if self.file_mode:
+            file = self.data['file'][index]
             return torch.tensor(seq), torch.tensor(charge), torch.tensor(rt).float(), torch.tensor(intensity),  torch.tensor(file)
         else :
             return torch.tensor(seq), torch.tensor(charge), torch.tensor(rt).float(), torch.tensor(intensity)
@@ -164,3 +163,19 @@ def load_data(path_train, path_val, path_test, batch_size, length, pad=False, co
     val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
 
     return train_loader, val_loader, test_loader
+
+
+# One-off export (runs at module import): merge the intensity arrays into one pickled DataFrame.
+irt_train = np.load('data/intensity/collision_irt_train.npy')
+seq_train = np.load('data/intensity/sequence_train.npy')
+charge_train = np.load('data/intensity/precursor_charge_train.npy')
+spectra_train = np.load('data/intensity/intensity_train.npy')
+# irt_holdout = np.load('data/intensity/collision_irt_holdout.npy')
+# seq_holdout = np.load('data/intensity/sequence_holdout.npy')
+# charge_holdout = np.load('data/intensity/precursor_charge_holdout.npy')
+# spectra_holdout = np.load('data/intensity/intensity_holdout.npy')
+# Use the default 0..n-1 index so the row count is never hard-coded; the holdout export stays disabled.
+dataset_train = pd.DataFrame({'Sequence':list(seq_train), 'Retention time':list(irt_train), 'Charge':list(charge_train), 'Spectra' : list(spectra_train)})
+dataset_train.to_pickle('database/data_prosit_merged_train.pkl')
+# dataset_test = pd.DataFrame({'Sequence':list(seq_holdout), 'Retention time':list(irt_holdout), 'Charge':list(charge_holdout), 'Spectra' : list(spectra_holdout)})
+# dataset_test.to_pickle('database/data_prosit_merged_holdout.pkl')
\ No newline at end of file
diff --git a/main.py b/main.py
index 97cecf3..23c1a3d 100644
--- a/main.py
+++ b/main.py
@@ -340,10 +340,10 @@ def main_int(args):
                      'data/intensity/collision_energy_train.npy',
                      'data/intensity/precursor_charge_train.npy')
 
-    sources_test = ('data/intensity/sequence_test.npy',
-                    'data/intensity/intensity_test.npy',
-                    'data/intensity/collision_energy_test.npy',
-                    'data/intensity/precursor_charge_test.npy')
+    sources_test = ('data/intensity/sequence_holdout.npy',
+                    'data/intensity/intensity_holdout.npy',
+                    'data/intensity/collision_energy_holdout.npy',
+                    'data/intensity/precursor_charge_holdout.npy')
 
     data_train = load_intensity_from_files(sources_train[0], sources_train[1], sources_train[2], sources_train[3],
                                            args.batch_size)
-- 
GitLab