"""Train and evaluate a dlomix retention-time predictor.

Steps performed by this script:

1. Read the full table from ``TRAIN_DATAPATH`` and split it on its ``state``
   column into a train/validation part and a holdout part, each written to
   its own CSV file.
2. Fit a ``RetentionTimePredictor`` on the train/validation part.
3. Predict on the holdout part and hand targets and predictions to
   ``RetentionTimeReport.calculate_r2``.

The run is tracked with Weights & Biases in offline mode.
"""
import os

import dlomix
import numpy as np
import pandas as pd
import wandb as wdb
# NOTE(review): this import binds the names ``data`` and ``eval`` (the latter
# shadows the builtin). It is kept unchanged; the script avoids rebinding them.
from dlomix import constants, data, eval, layers, models, pipelines, reports, utils
from dlomix.data import RetentionTimeDataset
from dlomix.eval import TimeDeltaMetric
from dlomix.models import RetentionTimePredictor
from dlomix.reports import RetentionTimeReport

batch_size = 1024
SEQ_LENGTH = 30

# SECURITY: the W&B API key must never be committed to source control.
# Provide it through the WANDB_API_KEY environment variable (or `wandb login`)
# when the offline run is synced; any key that was previously hard-coded here
# should be revoked.
os.environ["WANDB_MODE"] = "offline"

config = {
    "model": "RT prediction GRU/selfAtt+ GRU",
    "batch_size": batch_size,
}

# Pass the config so the run metadata is actually recorded with the run.
wdb.init(project="RT prediction", config=config)

# Full input table (all splits); the name is kept for backward compatibility.
TRAIN_DATAPATH = '/database/data.csv'
TRAIN_SPLIT_PATH = '/database/data_train.csv'
HOLDOUT_SPLIT_PATH = '/database/data_holdout.csv'

# Named `full_table` rather than `data` so the imported `dlomix.data` module
# is not overwritten.
full_table = pd.read_csv(TRAIN_DATAPATH)

# Element-wise membership test: combining the two comparisons with Python's
# `or` raises "ValueError: The truth value of a Series is ambiguous".
data_train = full_table[full_table["state"].isin(["train", "validation"])]
data_holdout = full_table[full_table["state"] == "holdout"]

# index=False keeps the DataFrame index from being written as an extra,
# unnamed column in the split files.
data_train.to_csv(TRAIN_SPLIT_PATH, index=False)
data_holdout.to_csv(HOLDOUT_SPLIT_PATH, index=False)

rtdata = RetentionTimeDataset(
    data_source=TRAIN_SPLIT_PATH,
    seq_length=SEQ_LENGTH,
    batch_size=batch_size,
    val_ratio=0.2,
    test=False,
)

model = RetentionTimePredictor(seq_length=SEQ_LENGTH)

model.compile(
    optimizer='adam',
    loss='mse',
    metrics=['mean_absolute_error', TimeDeltaMetric()],
)

history = model.fit(
    rtdata.train_data,
    validation_data=rtdata.val_data,
    epochs=20,
)

test_rtdata = RetentionTimeDataset(
    data_source=HOLDOUT_SPLIT_PATH,
    seq_length=SEQ_LENGTH,
    batch_size=batch_size,
    test=True,
)

predictions = model.predict(test_rtdata.test_data)

# we use ravel from numpy to flatten the array (since it comes out as an
# array of arrays)
predictions = predictions.ravel()

test_targets = test_rtdata.get_split_targets(split="test")

report = RetentionTimeReport(output_path="./output", history=history)
report.calculate_r2(test_targets, predictions)

wdb.finish()