diff --git a/res_viz.py b/res_viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..4117e65f78ec6457f4ceb652cb357446cd3e9303
--- /dev/null
+++ b/res_viz.py
@@ -0,0 +1,76 @@
+import random
+import matplotlib
+import numpy as np
+import matplotlib.pyplot as plt
+from dlomix.data import RetentionTimeDataset
+epoch = 100
+number = 8
+BATCH_SIZE=64
+
+test_rtdata = RetentionTimeDataset(data_source='database/data_holdout.csv',
+                                       seq_length=30, batch_size=BATCH_SIZE, test=True)
+
+target = test_rtdata.get_split_targets(split="test")
+index = random.sample(range(target.size), number)
+target_plot = target[index]
+pred_1=[]
+pred_2=[]
+error_1=[]
+error_2=[]
+order_1=[]
+order_2=[]
+pred_plot_1=[]
+pred_plot_2=[]
+error_plot_1=[]
+error_plot_2=[]
+order_plot_1=[]
+order_plot_2=[]
+
+
+for i in range(epoch):
+    data_1 = np.load('results/pred_prosit_ori/mem_pred_'+str(i)+'.npy')
+    data_1_plot =data_1[index]
+    data_2 = np.load('results/pred_prosit_ori_2/mem_pred_' + str(i) + '.npy')
+    data_2_plot = data_2[index]
+    pred_1.append(data_1)
+    pred_2.append(data_2)
+    error_1.append(data_1-target)
+    error_2.append(data_2-target)
+    order_1.append(np.argsort(data_1-target))
+    order_2.append(np.argsort(data_2-target))
+    pred_plot_1.append(data_1_plot)
+    pred_plot_2.append(data_2_plot)
+    error_plot_1.append(data_1_plot-target_plot)
+    error_plot_2.append(data_2_plot-target_plot)
+    order_plot_1.append(np.argsort(data_1_plot-target_plot))
+    order_plot_2.append(np.argsort(data_2_plot-target_plot))
+
+pred_1=np.array(pred_1)
+pred_2=np.array(pred_2)
+error_1=np.array(error_1)
+error_2=np.array(error_2)
+order_1=np.array(order_1)
+order_2=np.array(order_2)
+pred_plot_1=np.array(pred_plot_1)
+pred_plot_2=np.array(pred_plot_2)
+error_plot_1=np.array(error_plot_1)
+error_plot_2=np.array(error_plot_2)
+order_plot_1=np.array(order_plot_1)
+order_plot_2=np.array(order_plot_2)
+
+def viz_error_1():
+    for i in range(number):
+        plt.plot(error_plot_1[:,i])
+    plt.show()
+
+def viz_error_2():
+    for i in range(number):
+        plt.plot(error_plot_2[:,i])
+    plt.show()
+
+def compare_error():
+    colors = list(matplotlib.colors.TABLEAU_COLORS)
+    for i in range(number):
+        plt.plot(error_plot_1[:,i], c = colors[i], ls =  'dashed')
+        plt.plot(error_plot_2[:, i], c = colors[i])
+    plt.show()
\ No newline at end of file