"""Train and evaluate dlomix retention-time predictors.

Provides a standard ``model.fit`` pipeline (``main``), a custom training loop
that logs per-step gradients to wandb (``track_train`` / ``main_track``), and
commented-out scaffolding for a TensorBoard-based manual loop.
"""

import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow
import wandb
from sklearn.metrics import r2_score

from dlomix.data import RetentionTimeDataset
from dlomix.eval import TimeDeltaMetric
from dlomix.models import PrositRetentionTimePredictor, RetentionTimePredictor
from dlomix.reports import RetentionTimeReport


def save_reg(pred, true, name):
    """Scatter predicted vs. true retention times with a linear fit and save it.

    Args:
        pred: 1-D array of predicted retention times (plotted on x).
        true: 1-D array of ground-truth retention times (plotted on y).
        name: output path for the figure (matplotlib infers the format).
    """
    coef = np.polyfit(pred, true, 1)
    poly1d_fn = np.poly1d(coef)
    # BUG FIX: sklearn's r2_score signature is (y_true, y_pred); the original
    # passed (pred, true), which yields a wrong R² whenever the fit is poor.
    r2 = round(r2_score(true, pred), 4)
    plt.plot(pred, true, 'y,', pred, poly1d_fn(pred), '--k')
    # NOTE(review): the label position is hard-coded in data coordinates and
    # assumes retention times span roughly 0-150 — confirm for other datasets.
    plt.text(120, 20, 'R² = ' + str(r2), fontsize=12)
    plt.savefig(name)
    plt.clf()


def _setup_wandb():
    """Start the offline wandb run shared by the training entry points."""
    # SECURITY FIX: the original committed a real wandb API key in source.
    # Offline mode needs no key; honor one from the environment if present.
    os.environ.setdefault("WANDB_API_KEY", "")
    os.environ["WANDB_MODE"] = "offline"
    os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
    wandb.init(project="Prosit ori full dataset", dir='./wandb_run',
               name='prosit ori')


def track_train(model, epoch, test_rtdata, rtdata):
    """Custom training loop that logs per-step gradients to wandb.

    After every epoch a regression plot of holdout predictions vs. targets is
    written to ``fig/unstability/``.

    Args:
        model: keras retention-time model (trained in place).
        epoch: number of epochs to run.
        test_rtdata: RetentionTimeDataset exposing a "test" split.
        rtdata: RetentionTimeDataset exposing a "train" split.
    """
    test_targets = test_rtdata.get_split_targets(split="test")
    loss = tensorflow.keras.losses.MeanSquaredError()
    optimizer = tensorflow.keras.optimizers.Adam()
    _setup_wandb()
    for e in range(epoch):
        for step, (X_batch, y_batch) in enumerate(rtdata.train_data):
            with tensorflow.GradientTape() as tape:
                predictions = model(X_batch, training=True)
                # Keras losses take (y_true, y_pred); the original swapped
                # them (harmless only because MSE is symmetric).
                l = loss(y_batch, predictions)
            grads = tape.gradient(l, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))
            # NOTE(review): logging full gradient tensors every step is very
            # heavy; consider logging gradient norms instead.
            wandb.log({'grads': grads})
        predictions = model.predict(test_rtdata.test_data)
        save_reg(predictions.flatten(), test_targets,
                 'fig/unstability/reg_epoch_' + str(e))
    wandb.finish()


def train_step(model, optimizer, x_train, y_train, step):
    """Run one gradient step and log per-variable gradient histograms.

    Relies on the module-level ``loss_object``, ``train_loss`` and
    ``train_accuracy`` objects (currently commented out in the ``__main__``
    block) — raises NameError unless they are defined.
    """
    with tensorflow.GradientTape() as tape:
        # Trainable variables are watched automatically by GradientTape; the
        # original called tape.watch() AFTER the forward pass, which had no
        # effect and was removed.
        predictions = model(x_train, training=True)
        loss = loss_object(y_train, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # BUG FIX: the original loop variable shadowed the ``grads`` list itself.
    for weights, grad in zip(model.trainable_weights, grads):
        tensorflow.summary.histogram(
            weights.name.replace(':', '_') + '_grads', data=grad, step=step)
    train_loss(loss)
    train_accuracy(y_train, predictions)


def test_step(model, x_test, y_test):
    """Evaluate one batch, accumulating into the module-level test metrics.

    Relies on module-level ``loss_object``, ``test_loss`` and
    ``test_accuracy`` (currently commented out in ``__main__``).
    """
    predictions = model(x_test)
    loss = loss_object(y_test, predictions)
    test_loss(loss)
    test_accuracy(y_test, predictions)


def main():
    """Train the Prosit predictor with ``model.fit`` and evaluate on holdout."""
    BATCH_SIZE = 256
    rtdata = RetentionTimeDataset(data_source='database/data_train.csv',
                                  seq_length=30, batch_size=BATCH_SIZE,
                                  val_ratio=0.2, test=False)
    test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                       seq_length=30, batch_size=32,
                                       test=True)
    model = PrositRetentionTimePredictor(seq_length=30)
    model.compile(optimizer='adam', loss='mse',
                  metrics=['mean_absolute_error', TimeDeltaMetric()])
    _setup_wandb()
    history = model.fit(rtdata.train_data, validation_data=rtdata.val_data,
                        epochs=100)
    wandb.finish()
    # The original rebuilt the holdout dataset a second time here; reuse the
    # one constructed above.
    predictions = model.predict(test_rtdata.test_data)
    test_targets = test_rtdata.get_split_targets(split="test")
    report = RetentionTimeReport(output_path="./output", history=history)
    # NOTE(review): predictions/test_targets/report are currently unused by
    # the rest of the script — presumably left for interactive inspection.


def main_track():
    """Train the plain RetentionTimePredictor with the gradient-tracking loop."""
    BATCH_SIZE = 256
    rtdata = RetentionTimeDataset(data_source='database/data_train.csv',
                                  seq_length=30, batch_size=BATCH_SIZE,
                                  val_ratio=0.2, test=False)
    test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                       seq_length=30, batch_size=BATCH_SIZE,
                                       test=True)
    model = RetentionTimePredictor(seq_length=30)
    track_train(model, 100, test_rtdata, rtdata)


if __name__ == '__main__':
    # Scaffolding for the manual TensorBoard loop, kept commented out as in
    # the original.  Uncomment together with the model / train_step lines
    # below to use it.
    # loss_object = tensorflow.keras.losses.MeanSquaredError()
    # optimizer = tensorflow.keras.optimizers.Adam()
    # train_loss = tensorflow.keras.metrics.Mean('train_loss', dtype=tensorflow.float32)
    # train_accuracy = tensorflow.keras.metrics.MeanAbsoluteError('train_accuracy')
    # test_loss = tensorflow.keras.metrics.Mean('test_loss', dtype=tensorflow.float32)
    # test_accuracy = tensorflow.keras.metrics.MeanAbsoluteError('test_accuracy')
    #
    # current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    # train_log_dir = 'logs/gradient_tape/' + current_time + '/train'
    # test_log_dir = 'logs/gradient_tape/' + current_time + '/test'
    # train_summary_writer = tensorflow.summary.create_file_writer(train_log_dir)
    # test_summary_writer = tensorflow.summary.create_file_writer(test_log_dir)

    BATCH_SIZE = 256
    rtdata = RetentionTimeDataset(data_source='database/data_train.csv',
                                  seq_length=30, batch_size=BATCH_SIZE,
                                  val_ratio=0.2, test=False)
    test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
                                       seq_length=30, batch_size=32,
                                       test=True)
    test_targets = test_rtdata.get_split_targets(split="test")
    # model = RetentionTimePredictor(seq_length=30)

    # BUG FIX: EPOCHS was only defined inside a comment, so the loop below
    # raised NameError at runtime.
    EPOCHS = 5
    for epoch in range(EPOCHS):
        for (x_train, y_train) in rtdata.train_data:
            print(x_train)
            break
            # train_step(model, optimizer, x_train, y_train, epoch)
        # with train_summary_writer.as_default():
        #     tensorflow.summary.scalar('loss', train_loss.result(), step=epoch)
        #     tensorflow.summary.scalar('accuracy', train_accuracy.result(), step=epoch)
        #
        # for (x_test, y_test) in test_rtdata.test_data:
        #     test_step(model, x_test, y_test)
        # with test_summary_writer.as_default():
        #     tensorflow.summary.scalar('loss', test_loss.result(), step=epoch)
        #     tensorflow.summary.scalar('accuracy', test_accuracy.result(), step=epoch)
        #
        # template = 'Epoch {}, Loss: {}, Absolute Error: {}, Test Loss: {}, Test Absolute Error: {}'
        # print(template.format(epoch + 1,
        #                       train_loss.result(),
        #                       train_accuracy.result(),
        #                       test_loss.result(),
        #                       test_accuracy.result()))
        #
        # # Reset metrics every epoch
        # train_loss.reset_states()
        # test_loss.reset_states()
        # train_accuracy.reset_states()
        # test_accuracy.reset_states()