From c6b789878d54263481b98e872621bc2bc78c76b7 Mon Sep 17 00:00:00 2001 From: lschneider <leo.schneider@univ-lyon1.fr> Date: Mon, 29 Jan 2024 12:38:00 +0100 Subject: [PATCH] fpdf_local --- .gitignore | 1 + prosit_RT_ori.py | 47 ----------------------------------------------- requirements.txt | 5 +++-- 3 files changed, 4 insertions(+), 49 deletions(-) delete mode 100644 prosit_RT_ori.py diff --git a/.gitignore b/.gitignore index a583428..1f8b308 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ /test.py /database/ +/wandb_run/ diff --git a/prosit_RT_ori.py b/prosit_RT_ori.py deleted file mode 100644 index 63f207c..0000000 --- a/prosit_RT_ori.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -import numpy as np -import pandas as pd - -from dlomix import constants, data, eval, layers, models, pipelines, reports, utils -from dlomix.models import RetentionTimePredictor -from dlomix.data import RetentionTimeDataset -from dlomix.eval import TimeDeltaMetric -from dlomix.reports import RetentionTimeReport - -batch_size = 1024 - -TRAIN_DATAPATH = '/database/data.csv' - -data = pd.read_csv('/database/data.csv') -data_train = data[data.state == 'train' or data.state == 'validation'] -data_holdout = data[data.state == 'holdout'] - -data_train.to_csv('/database/data_train.csv') -data_holdout.to_csv('/database/data_holdout.csv') - -rtdata = RetentionTimeDataset(data_source='/database/data_train.csv', - seq_length=30, batch_size=batch_size, val_ratio=0.2, test=False) - -model = RetentionTimePredictor(seq_length=30) - -model.compile(optimizer='adam', - loss='mse', - metrics=['mean_absolute_error', TimeDeltaMetric()]) - -history = model.fit(rtdata.train_data, - validation_data=rtdata.val_data, - epochs=20) - -test_rtdata = RetentionTimeDataset(data_source='/database/data_holdout.csv', - seq_length=30, batch_size=1024, test=True) - -predictions = model.predict(test_rtdata.test_data) - -# we use ravel from numpy to flatten the array (since it comes out as an array of arrays) -predictions = predictions.ravel() - -test_targets = test_rtdata.get_split_targets(split="test") - -report = RetentionTimeReport(output_path="./output", history=history) -report.calculate_r2(test_targets, predictions) -report.generate_report(test_targets, predictions) diff --git a/requirements.txt b/requirements.txt index 2b26b72..cb93452 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ torch~=2.1.2 h5py~=3.10.0 -pandas~=2.1.4 +pandas~=2.2.0 numpy~=1.26.2 -matplotlib~=3.8.2 \ No newline at end of file +matplotlib~=3.8.2 +wandb~=0.16.2 \ No newline at end of file -- GitLab