diff --git a/res_viz.py b/res_viz.py new file mode 100644 index 0000000000000000000000000000000000000000..4117e65f78ec6457f4ceb652cb357446cd3e9303 --- /dev/null +++ b/res_viz.py @@ -0,0 +1,76 @@ +import random +import matplotlib +import numpy as np +import matplotlib.pyplot as plt +from dlomix.data import RetentionTimeDataset +epoch = 100 +number = 8 +BATCH_SIZE=64 + +test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv', + seq_length=30, batch_size=BATCH_SIZE, test=True) + +target = test_rtdata.get_split_targets(split="test") +index = random.sample(range(target.size), number) +target_plot = target[index] +pred_1=[] +pred_2=[] +error_1=[] +error_2=[] +order_1=[] +order_2=[] +pred_plot_1=[] +pred_plot_2=[] +error_plot_1=[] +error_plot_2=[] +order_plot_1=[] +order_plot_2=[] + + +for i in range(epoch): + data_1 = np.load('results/pred_prosit_ori/mem_pred_'+str(i)+'.npy') + data_1_plot =data_1[index] + data_2 = np.load('results/pred_prosit_ori_2/mem_pred_' + str(i) + '.npy') + data_2_plot = data_2[index] + pred_1.append(data_1) + pred_2.append(data_2) + error_1.append(data_1-target) + error_2.append(data_2-target) + order_1.append(np.argsort(data_1-target)) + order_2.append(np.argsort(data_2-target)) + pred_plot_1.append(data_1_plot) + pred_plot_2.append(data_2_plot) + error_plot_1.append(data_1_plot-target_plot) + error_plot_2.append(data_2_plot-target_plot) + order_plot_1.append(np.argsort(data_1_plot-target_plot)) + order_plot_2.append(np.argsort(data_2_plot-target_plot)) + +pred_1=np.array(pred_1) +pred_2=np.array(pred_2) +error_1=np.array(error_1) +error_2=np.array(error_2) +order_1=np.array(order_1) +order_2=np.array(order_2) +pred_plot_1=np.array(pred_plot_1) +pred_plot_2=np.array(pred_plot_2) +error_plot_1=np.array(error_plot_1) +error_plot_2=np.array(error_plot_2) +order_plot_1=np.array(order_plot_1) +order_plot_2=np.array(order_plot_2) + +def viz_error_1(): + for i in range(number): + plt.plot(error_plot_1[:,i]) + plt.show() + +def viz_error_2(): + for i in range(number): + plt.plot(error_plot_2[:,i]) + plt.show() + +def compare_error(): + colors = list(matplotlib.colors.TABLEAU_COLORS) + for i in range(number): + plt.plot(error_plot_1[:,i], c = colors[i], ls = 'dashed') + plt.plot(error_plot_2[:, i], c = colors[i]) + plt.show() \ No newline at end of file