Commit d2f1bdc6 authored by Schneider Leo

datasets

parent 99827cc8
@@ -57,12 +57,12 @@ def train_model(config, args):
         optimizer.load_state_dict(optimizer_state)
     if args.forward == 'both':
-        data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train,
+        data_train, data_val = common_dataset.load_data(path_train=args.dataset_train,
                                                                     path_val=args.dataset_val,
                                                                     path_test=args.dataset_test,
                                                                     batch_size=int(config["batch_size"]), length=25)
     else:
-        data_train, data_val, data_test = dataloader.load_data(data_source=args.dataset_train,
+        data_train, data_val = dataloader.load_data(data_source=args.dataset_train,
                                                                 batch_size=int(config["batch_size"]), length=25)
     for epoch in range(100):  # loop over the dataset multiple times
...
import os
import sys
import pandas as pd
import tensorflow as tf
from dlomix.data import RetentionTimeDataset
from dlomix.eval import TimeDeltaMetric
from dlomix.models import PrositRetentionTimePredictor
from dlomix.reports import RetentionTimeReport
# sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
# consider the use-case for starting from a saved model
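# Prosit-style retention time predictor for peptide sequences of up to 30 residues.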
model = PrositRetentionTimePredictor(seq_length=30)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
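# Build the dataset from the training CSV and the held-out test CSV; val_ratio=0.2 splits 20% of the training data off for validation.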
d = RetentionTimeDataset(
data_format="csv",
data_source='database/data_train.csv',
test_data_source='database/data_holdout.csv',
sequence_column="sequence",
label_column="irt",
max_seq_len=30,
batch_size=512,
val_ratio=0.2,
)
print(d)
test_targets = d["test"]["irt"]
test_sequences = d["test"]["sequence"]
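# Compile with MSE loss, tracking mean absolute error and dlomix's TimeDeltaMetric.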
model.compile(
optimizer=optimizer, loss="mse", metrics=["mean_absolute_error", TimeDeltaMetric()]
)
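# Callbacks: checkpoint only the best weights, divide the learning rate by 10 when val_loss plateaus for 10 epochs, and stop early after 20 epochs without improvement.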
weights_file = ".weights.h5"
checkpoint = tf.keras.callbacks.ModelCheckpoint(
weights_file, save_best_only=True, save_weights_only=True
)
decay = tf.keras.callbacks.ReduceLROnPlateau(
monitor="val_loss", factor=0.1, patience=10, verbose=1, min_lr=0
)
early_stop = tf.keras.callbacks.EarlyStopping(patience=20)
callbacks = [checkpoint, early_stop, decay]
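# Train for up to 25 epochs, validating on the held-out validation split after each epoch.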
history = model.fit(
d.tensor_train_data,
epochs=25,
validation_data=d.tensor_val_data,
callbacks=callbacks,
)
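# Predict retention times for the test sequences and flatten the output to a 1-D array.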
predictions = model.predict(test_sequences)
predictions = predictions.ravel()
print(test_sequences[:5])
print(test_targets[:5])
print(predictions[:5])
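# Build a report from the training history and compute R2 between measured and predicted iRT on the test set.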
report = RetentionTimeReport(output_path="./output", history=history)
print("R2: ", report.calculate_r2(test_targets, predictions))
pd.DataFrame(
{
"sequence": d["test"]["_parsed_sequence"],
"irt": test_targets,
"predicted_irt": predictions,
}
).to_csv("./predictions_prosit_fullrun.csv", index=False)