From 91ea116b5ac8df8382d550fb84d6733690ca11e1 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 22 Oct 2024 13:54:37 +0200
Subject: [PATCH] datasets

---
 main_custom.py       | 4 ++--
 prosit_data_merge.py | 8 ++++++--
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/main_custom.py b/main_custom.py
index fec2d98..a0f6489 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -200,8 +200,8 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test,
 
     else :
         for e in range(1, epochs + 1):
-            # train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
-            #       wandb=wandb)
+            train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
+                  wandb=wandb)
             if e % eval_inter == 0:
                 eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
                      wandb=wandb)
diff --git a/prosit_data_merge.py b/prosit_data_merge.py
index 444200c..1211691 100644
--- a/prosit_data_merge.py
+++ b/prosit_data_merge.py
@@ -148,7 +148,11 @@ def alphabetical_to_numerical(seq):
 
 df = pd.read_pickle('database/data_prosit_merged_holdout.pkl')
 
-df = df.head(100)
+print(len(df))
+print(df.head())
+df = df[df['Retention time']!=[0,0,0,0,0,0]]
+print(len(df))
+print(df.head())
 
-df.to_csv('database/data_head.csv')
+df.to_pickle('database/data_prosit_merged_holdout_2.pkl')
 
-- 
GitLab