diff --git a/.gitignore b/.gitignore index a583428504c4889f199f611952e86b12a08406ef..1f8b3087abce686b440c12c0d2ba6860fe047a28 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ /test.py /database/ +/wandb_run/ diff --git a/prosit_RT_ori.py b/prosit_RT_ori.py deleted file mode 100644 index 63f207ca02ae5758bc71be464c58dac18de1dba1..0000000000000000000000000000000000000000 --- a/prosit_RT_ori.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -import numpy as np -import pandas as pd - -from dlomix import constants, data, eval, layers, models, pipelines, reports, utils -from dlomix.models import RetentionTimePredictor -from dlomix.data import RetentionTimeDataset -from dlomix.eval import TimeDeltaMetric -from dlomix.reports import RetentionTimeReport - -batch_size = 1024 - -TRAIN_DATAPATH = '/database/data.csv' - -data = pd.read_csv('/database/data.csv') -data_train = data[data.state == 'train' or data.state == 'validation'] -data_holdout = data[data.state == 'holdout'] - -data_train.to_csv('/database/data_train.csv') -data_holdout.to_csv('/database/data_holdout.csv') - -rtdata = RetentionTimeDataset(data_source='/database/data_train.csv', - seq_length=30, batch_size=batch_size, val_ratio=0.2, test=False) - -model = RetentionTimePredictor(seq_length=30) - -model.compile(optimizer='adam', - loss='mse', - metrics=['mean_absolute_error', TimeDeltaMetric()]) - -history = model.fit(rtdata.train_data, - validation_data=rtdata.val_data, - epochs=20) - -test_rtdata = RetentionTimeDataset(data_source='/database/data_holdout.csv', - seq_length=30, batch_size=1024, test=True) - -predictions = model.predict(test_rtdata.test_data) - -# we use ravel from numpy to flatten the array (since it comes out as an array of arrays) -predictions = predictions.ravel() - -test_targets = test_rtdata.get_split_targets(split="test") - -report = RetentionTimeReport(output_path="./output", history=history) -report.calculate_r2(test_targets, predictions) -report.generate_report(test_targets, predictions) diff --git a/requirements.txt b/requirements.txt index 2b26b72c70f1a09cdb85653106ecd0123af782d5..cb9345290e9b9ff90993376325c65f4436755e15 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ torch~=2.1.2 h5py~=3.10.0 -pandas~=2.1.4 +pandas~=2.2.0 numpy~=1.26.2 -matplotlib~=3.8.2 \ No newline at end of file +matplotlib~=3.8.2 +wandb~=0.16.2 \ No newline at end of file