diff --git a/common_dataset.py b/common_dataset.py
index a03ce84b898faec8c654bd272787c25e23751dd8..f3b6ccc501acaee84b06f3bd00a0d6f1aaecc002 100644
--- a/common_dataset.py
+++ b/common_dataset.py
@@ -154,15 +154,15 @@ def load_data(path_train, path_val, path_test, batch_size, length, pad=False, co
     print('Loading data')
     data_train = pd.read_pickle(path_train)
     data_val = pd.read_pickle(path_val)
-    data_test = pd.read_pickle(path_test)
+    # data_test = pd.read_pickle(path_test)
     train = Common_Dataset(data_train, length, pad, convert, vocab)
-    test = Common_Dataset(data_val, length, pad, convert, vocab)
-    val = Common_Dataset(data_test, length, pad, convert, vocab)
+    val = Common_Dataset(data_val, length, pad, convert, vocab)
+    # val = Common_Dataset(data_test, length, pad, convert, vocab)
     train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
-    test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)
+    # test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)
     val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
 
-    return train_loader, val_loader, test_loader
+    return train_loader, val_loader
 
 if __name__ =='__main__' :
     irt_train = np.load('data/intensity/irt_train.npy')
diff --git a/dataloader.py b/dataloader.py
index 9bfeb25558f8d8c1baaaa59236bdaa3cca5c517d..c1368ce76f3bb91225f5731128a8ec99089d170f 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -108,13 +108,13 @@ class RT_Dataset(Dataset):
 def load_data(batch_size, data_sources, n_train=None, n_test=None, length=30):
     print('Loading data')
     train = RT_Dataset(n_train, data_sources[0], 'train', length)
-    test = RT_Dataset(n_test, data_sources[1], 'test', length)
+    # test = RT_Dataset(n_test, data_sources[1], 'test', length)
     val = RT_Dataset(n_test, data_sources[2], 'test', length)
     train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
-    test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)
+    # test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)
     val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
 
-    return train_loader, val_loader, test_loader
+    return train_loader, val_loader
 
 
 class H5ToStorage():
diff --git a/main_custom.py b/main_custom.py
index 7c7421b3fef9cc87dd40b42b8d0b12bc1af7673f..b7bba3e2cdd3f510761820dabc871ef18dbc9223 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -220,28 +220,28 @@ def main(args):
     print('Cuda : ', torch.cuda.is_available())
 
     if args.forward == 'both':
-        data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train,
+        data_train, data_val = common_dataset.load_data(path_train=args.dataset_train,
                                          path_val=args.dataset_val, path_test=args.dataset_test,
                                          batch_size=args.batch_size, length=args.seq_length,
                                          pad = False, convert=False, vocab='unmod')
     elif args.forward == 'rt':
-        data_train, data_val, data_test = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test],
+        data_train, data_val = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test],
                                          batch_size=args.batch_size, length=args.seq_length)
     elif args.forward == 'transfer':
-        data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'],
+        data_train, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'],
                                          batch_size=args.batch_size, length=args.seq_length)
-        _, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val,
+        _, data_val = common_dataset.load_data(path_train=args.dataset_val,
                                          path_val=args.dataset_val, path_test=args.dataset_test,
                                          batch_size=args.batch_size, length=args.seq_length,
                                          pad = True, convert=True, vocab='unmod')
     elif args.forward == 'reverse':
-        _, data_val, data_test = dataloader.load_data(data_sources=['database/data_train.csv',args.dataset_val,args.dataset_test],
+        _, data_val = dataloader.load_data(data_sources=['database/data_train.csv',args.dataset_val,args.dataset_test],
                                          batch_size=args.batch_size, length=args.seq_length)
-        data_train, _, _ = common_dataset.load_data(path_train=args.dataset_train,
+        data_train, _ = common_dataset.load_data(path_train=args.dataset_train,
                                          path_val=args.dataset_train, path_test=args.dataset_train,
                                          batch_size=args.batch_size, length=args.seq_length,
                                          pad = True, convert=True, vocab='unmod')
@@ -261,7 +261,7 @@ def main(args):
 
     print('\nModel initialised')
     run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train,
-        data_val=data_val, data_test=data_test, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(),
+        data_val=data_val, data_test=data_val, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(),
         criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle,
         wandb=args.wandb, forward=args.forward, output=args.output, file=args.file)
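
Note on the retained signatures: after this patch, both dataloader.load_data and common_dataset.load_data return a (train_loader, val_loader) pair instead of a triple, and main_custom.py reuses the validation loader wherever a test loader was previously expected (data_test=data_val in the run(...) call). Below is a minimal usage sketch of the new two-value API, not part of the patch itself; the file paths and batch size are illustrative placeholders, not files shipped with the repository.

    import dataloader
    import common_dataset

    # Two-value unpack replaces the old (train, val, test) triple.
    # data_sources[1] (the test slot) is still accepted but no longer read.
    train_rt, val_rt = dataloader.load_data(
        data_sources=['rt_train.csv', 'rt_val.csv', 'rt_val.csv'],  # placeholder paths
        batch_size=64, length=30)

    # path_test is kept only for signature compatibility; it is not loaded anymore.
    train_both, val_both = common_dataset.load_data(
        path_train='train.pkl', path_val='val.pkl', path_test='val.pkl',  # placeholder paths
        batch_size=64, length=30,
        pad=False, convert=False, vocab='unmod')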