diff --git a/main_ray_tune.py b/main_ray_tune.py
index d939a1b85b99ec23e4f231470c349925bd6629b9..8215f01284d48b70023dadfdf9dc3023719d111c 100644
--- a/main_ray_tune.py
+++ b/main_ray_tune.py
@@ -57,12 +57,12 @@ def train_model(config, args):
         optimizer.load_state_dict(optimizer_state)
 
     if args.forward == 'both':
-        data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train,
+        data_train, data_val= common_dataset.load_data(path_train=args.dataset_train,
                                                                    path_val=args.dataset_val,
                                                                    path_test=args.dataset_test,
                                                                    batch_size=int(config["batch_size"]), length=25)
     else:
-        data_train, data_val, data_test = dataloader.load_data(data_source=args.dataset_train,
+        data_train, data_val = dataloader.load_data(data_source=args.dataset_train,
                                                                batch_size=int(config["batch_size"]), length=25)
 
     for epoch in range(100):  # loop over the dataset multiple times
diff --git a/prosit_rt_ori.py b/prosit_rt_ori.py
new file mode 100644
index 0000000000000000000000000000000000000000..950c7c5f7bd6e3ca8889f5de097b26c35b9784c9
--- /dev/null
+++ b/prosit_rt_ori.py
@@ -0,0 +1,77 @@
+import os
+import sys
+
+import pandas as pd
+import tensorflow as tf
+
+from dlomix.data import RetentionTimeDataset
+from dlomix.eval import TimeDeltaMetric
+from dlomix.models import PrositRetentionTimePredictor
+from dlomix.reports import RetentionTimeReport
+
+# sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
+
+# consider the use-case for starting from a saved model
+
+model = PrositRetentionTimePredictor(seq_length=30)
+
+optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
+
+
+d = RetentionTimeDataset(
+    data_format="csv",
+    data_source='database/data_train.csv',
+    test_data_source='database/data_holdout.csv',
+    sequence_column="sequence",
+    label_column="irt",
+    max_seq_len=30,
+    batch_size=512,
+    val_ratio=0.2,
+)
+
+print(d)
+
+test_targets = d["test"]["irt"]
+test_sequences = d["test"]["sequence"]
+
+model.compile(
+    optimizer=optimizer, loss="mse", metrics=["mean_absolute_error", TimeDeltaMetric()]
+)
+
+weights_file = ".weights.h5"
+checkpoint = tf.keras.callbacks.ModelCheckpoint(
+    weights_file, save_best_only=True, save_weights_only=True
+)
+decay = tf.keras.callbacks.ReduceLROnPlateau(
+    monitor="val_loss", factor=0.1, patience=10, verbose=1, min_lr=0
+)
+early_stop = tf.keras.callbacks.EarlyStopping(patience=20)
+callbacks = [checkpoint, early_stop, decay]
+
+
+history = model.fit(
+    d.tensor_train_data,
+    epochs=25,
+    validation_data=d.tensor_val_data,
+    callbacks=callbacks,
+)
+
+predictions = model.predict(test_sequences)
+predictions = predictions.ravel()
+
+print(test_sequences[:5])
+print(test_targets[:5])
+print(predictions[:5])
+
+
+report = RetentionTimeReport(output_path="./output", history=history)
+
+print("R2: ", report.calculate_r2(test_targets, predictions))
+
+pd.DataFrame(
+    {
+        "sequence": d["test"]["_parsed_sequence"],
+        "irt": test_targets,
+        "predicted_irt": predictions,
+    }
+).to_csv("./predictions_prosit_fullrun.csv", index=False)
\ No newline at end of file