From c6b789878d54263481b98e872621bc2bc78c76b7 Mon Sep 17 00:00:00 2001
From: lschneider <leo.schneider@univ-lyon1.fr>
Date: Mon, 29 Jan 2024 12:38:00 +0100
Subject: [PATCH] fpdf_local

---
 .gitignore       |  1 +
 prosit_RT_ori.py | 47 -----------------------------------------------
 requirements.txt |  5 +++--
 3 files changed, 4 insertions(+), 49 deletions(-)
 delete mode 100644 prosit_RT_ori.py

diff --git a/.gitignore b/.gitignore
index a583428..1f8b308 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@
 /test.py
 /database/
 
+/wandb_run/
diff --git a/prosit_RT_ori.py b/prosit_RT_ori.py
deleted file mode 100644
index 63f207c..0000000
--- a/prosit_RT_ori.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import os
-import numpy as np
-import pandas as pd
-
-from dlomix import constants, data, eval, layers, models, pipelines, reports, utils
-from dlomix.models import RetentionTimePredictor
-from dlomix.data import RetentionTimeDataset
-from dlomix.eval import TimeDeltaMetric
-from dlomix.reports import RetentionTimeReport
-
-batch_size = 1024
-
-TRAIN_DATAPATH = '/database/data.csv'
-
-data = pd.read_csv('/database/data.csv')
-data_train = data[data.state == 'train' or data.state == 'validation']
-data_holdout = data[data.state == 'holdout']
-
-data_train.to_csv('/database/data_train.csv')
-data_holdout.to_csv('/database/data_holdout.csv')
-
-rtdata = RetentionTimeDataset(data_source='/database/data_train.csv',
-                              seq_length=30, batch_size=batch_size, val_ratio=0.2, test=False)
-
-model = RetentionTimePredictor(seq_length=30)
-
-model.compile(optimizer='adam',
-              loss='mse',
-              metrics=['mean_absolute_error', TimeDeltaMetric()])
-
-history = model.fit(rtdata.train_data,
-                    validation_data=rtdata.val_data,
-                    epochs=20)
-
-test_rtdata = RetentionTimeDataset(data_source='/database/data_holdout.csv',
-                              seq_length=30, batch_size=1024, test=True)
-
-predictions = model.predict(test_rtdata.test_data)
-
-# we use ravel from numpy to flatten the array (since it comes out as an array of arrays)
-predictions = predictions.ravel()
-
-test_targets = test_rtdata.get_split_targets(split="test")
-
-report = RetentionTimeReport(output_path="./output", history=history)
-report.calculate_r2(test_targets, predictions)
-report.generate_report(test_targets, predictions)
diff --git a/requirements.txt b/requirements.txt
index 2b26b72..cb93452 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,6 @@
 torch~=2.1.2
 h5py~=3.10.0
-pandas~=2.1.4
+pandas~=2.2.0
 numpy~=1.26.2
-matplotlib~=3.8.2
\ No newline at end of file
+matplotlib~=3.8.2
+wandb~=0.16.2
\ No newline at end of file
-- 
GitLab