diff --git a/common_dataset.py b/common_dataset.py
index a03ce84b898faec8c654bd272787c25e23751dd8..f3b6ccc501acaee84b06f3bd00a0d6f1aaecc002 100644
--- a/common_dataset.py
+++ b/common_dataset.py
@@ -154,15 +154,15 @@ def load_data(path_train, path_val, path_test, batch_size, length, pad=False, co
     print('Loading data')
     data_train = pd.read_pickle(path_train)
     data_val = pd.read_pickle(path_val)
-    data_test = pd.read_pickle(path_test)
+    # data_test = pd.read_pickle(path_test)
     train = Common_Dataset(data_train, length, pad, convert, vocab)
-    test = Common_Dataset(data_val, length, pad, convert, vocab)
-    val = Common_Dataset(data_test, length, pad, convert, vocab)
+    val = Common_Dataset(data_val, length, pad, convert, vocab)
+    # test = Common_Dataset(data_test, length, pad, convert, vocab)
     train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
-    test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)
+    # test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)
     val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
 
-    return train_loader, val_loader, test_loader
+    return train_loader, val_loader  # test loader is no longer built or returned
 
 if __name__ =='__main__' :
     irt_train = np.load('data/intensity/irt_train.npy')
diff --git a/dataloader.py b/dataloader.py
index 9bfeb25558f8d8c1baaaa59236bdaa3cca5c517d..c1368ce76f3bb91225f5731128a8ec99089d170f 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -108,13 +108,13 @@ class RT_Dataset(Dataset):
 def load_data(batch_size, data_sources, n_train=None, n_test=None, length=30):
     print('Loading data')
     train = RT_Dataset(n_train, data_sources[0], 'train', length)
-    test = RT_Dataset(n_test, data_sources[1], 'test', length)
+    # test = RT_Dataset(n_test, data_sources[1], 'test', length)
     val = RT_Dataset(n_test, data_sources[2], 'test', length)
     train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
-    test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)
+    # test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)
     val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
 
-    return train_loader, val_loader, test_loader
+    return train_loader, val_loader  # test loader is no longer built or returned
 
 
 class H5ToStorage():
diff --git a/main_custom.py b/main_custom.py
index 7c7421b3fef9cc87dd40b42b8d0b12bc1af7673f..b7bba3e2cdd3f510761820dabc871ef18dbc9223 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -220,28 +220,28 @@ def main(args):
     print('Cuda : ', torch.cuda.is_available())
 
     if args.forward == 'both':
-        data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train,
+        data_train, data_val = common_dataset.load_data(path_train=args.dataset_train,
                                                                    path_val=args.dataset_val,
                                                                    path_test=args.dataset_test,
                                                                    batch_size=args.batch_size, length=args.seq_length, pad = False, convert=False, vocab='unmod')
     elif args.forward == 'rt':
-        data_train, data_val, data_test = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test],
+        data_train, data_val = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test],
                                                                batch_size=args.batch_size, length=args.seq_length)
 
     elif args.forward == 'transfer':
-        data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'],
+        data_train, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'],
                                                                batch_size=args.batch_size, length=args.seq_length)
 
-        _, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val,
+        _, data_val = common_dataset.load_data(path_train=args.dataset_val,
                                                                    path_val=args.dataset_val,
                                                                    path_test=args.dataset_test,
                                                                    batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod')
 
     elif args.forward == 'reverse':
-        _, data_val, data_test = dataloader.load_data(data_sources=['database/data_train.csv',args.dataset_val,args.dataset_test],
+        _, data_val = dataloader.load_data(data_sources=['database/data_train.csv',args.dataset_val,args.dataset_test],
                                                                batch_size=args.batch_size, length=args.seq_length)
 
-        data_train, _, _ = common_dataset.load_data(path_train=args.dataset_train,
+        data_train, _ = common_dataset.load_data(path_train=args.dataset_train,
                                                                    path_val=args.dataset_train,
                                                                    path_test=args.dataset_train,
                                                                    batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod')
@@ -261,7 +261,7 @@ def main(args):
     print('\nModel initialised')
 
     run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train,
-        data_val=data_val, data_test=data_test, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(),
+        data_val=data_val, data_test=data_val, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(),  # validation loader doubles as the test set
         criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle,
         wandb=args.wandb, forward=args.forward, output=args.output, file=args.file)
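
Note (reviewer sketch, not part of the patch): after this change both load_data
variants return (train_loader, val_loader) instead of three loaders, and
main_custom.py reuses the validation loader as data_test. A minimal sketch of the
new calling convention, assuming the modules are importable — every path and
hyperparameter below is an illustrative placeholder, not a value from the repo:

    # Both load_data variants now return a 2-tuple; there is no test loader.
    import common_dataset
    import dataloader

    data_train, data_val = common_dataset.load_data(
        path_train='data/train.pkl',   # placeholder path
        path_val='data/val.pkl',       # placeholder path
        path_test='data/test.pkl',     # still accepted, but no longer read
        batch_size=64, length=30, pad=False, convert=False, vocab='unmod')

    data_train, data_val = dataloader.load_data(
        batch_size=64,
        data_sources=['train.csv', 'test.csv', 'val.csv'],  # index 1 (test) is unused
        length=30)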