Skip to content
Snippets Groups Projects
Commit 3e525a66 authored by Schneider Leo
Browse files

datasets

parent 1de23a4a
No related branches found
No related tags found
No related merge requests found
...@@ -154,15 +154,15 @@ def load_data(path_train, path_val, path_test, batch_size, length, pad=False, co ...@@ -154,15 +154,15 @@ def load_data(path_train, path_val, path_test, batch_size, length, pad=False, co
print('Loading data') print('Loading data')
data_train = pd.read_pickle(path_train) data_train = pd.read_pickle(path_train)
data_val = pd.read_pickle(path_val) data_val = pd.read_pickle(path_val)
data_test = pd.read_pickle(path_test) # data_test = pd.read_pickle(path_test)
train = Common_Dataset(data_train, length, pad, convert, vocab) train = Common_Dataset(data_train, length, pad, convert, vocab)
test = Common_Dataset(data_val, length, pad, convert, vocab) val = Common_Dataset(data_val, length, pad, convert, vocab)
val = Common_Dataset(data_test, length, pad, convert, vocab) # val = Common_Dataset(data_test, length, pad, convert, vocab)
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True) train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=True) # test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
return train_loader, val_loader, test_loader return train_loader, val_loader
if __name__ =='__main__' : if __name__ =='__main__' :
irt_train = np.load('data/intensity/irt_train.npy') irt_train = np.load('data/intensity/irt_train.npy')
......
...@@ -108,13 +108,13 @@ class RT_Dataset(Dataset): ...@@ -108,13 +108,13 @@ class RT_Dataset(Dataset):
def load_data(batch_size, data_sources, n_train=None, n_test=None, length=30): def load_data(batch_size, data_sources, n_train=None, n_test=None, length=30):
print('Loading data') print('Loading data')
train = RT_Dataset(n_train, data_sources[0], 'train', length) train = RT_Dataset(n_train, data_sources[0], 'train', length)
test = RT_Dataset(n_test, data_sources[1], 'test', length) # test = RT_Dataset(n_test, data_sources[1], 'test', length)
val = RT_Dataset(n_test, data_sources[2], 'test', length) val = RT_Dataset(n_test, data_sources[2], 'test', length)
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True) train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=True) # test_loader = DataLoader(test, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
return train_loader, val_loader, test_loader return train_loader, val_loader
class H5ToStorage(): class H5ToStorage():
......
...@@ -220,28 +220,28 @@ def main(args): ...@@ -220,28 +220,28 @@ def main(args):
print('Cuda : ', torch.cuda.is_available()) print('Cuda : ', torch.cuda.is_available())
if args.forward == 'both': if args.forward == 'both':
data_train, data_val, data_test = common_dataset.load_data(path_train=args.dataset_train, data_train, data_val, _ = common_dataset.load_data(path_train=args.dataset_train,
path_val=args.dataset_val, path_val=args.dataset_val,
path_test=args.dataset_test, path_test=args.dataset_test,
batch_size=args.batch_size, length=args.seq_length, pad = False, convert=False, vocab='unmod') batch_size=args.batch_size, length=args.seq_length, pad = False, convert=False, vocab='unmod')
elif args.forward == 'rt': elif args.forward == 'rt':
data_train, data_val, data_test = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test], data_train, data_val = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test],
batch_size=args.batch_size, length=args.seq_length) batch_size=args.batch_size, length=args.seq_length)
elif args.forward == 'transfer': elif args.forward == 'transfer':
data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'], data_train, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'],
batch_size=args.batch_size, length=args.seq_length) batch_size=args.batch_size, length=args.seq_length)
_, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val, _, data_val = common_dataset.load_data(path_train=args.dataset_val,
path_val=args.dataset_val, path_val=args.dataset_val,
path_test=args.dataset_test, path_test=args.dataset_test,
batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod') batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod')
elif args.forward == 'reverse': elif args.forward == 'reverse':
_, data_val, data_test = dataloader.load_data(data_sources=['database/data_train.csv',args.dataset_val,args.dataset_test], _, data_val = dataloader.load_data(data_sources=['database/data_train.csv',args.dataset_val,args.dataset_test],
batch_size=args.batch_size, length=args.seq_length) batch_size=args.batch_size, length=args.seq_length)
data_train, _, _ = common_dataset.load_data(path_train=args.dataset_train, data_train, _ = common_dataset.load_data(path_train=args.dataset_train,
path_val=args.dataset_train, path_val=args.dataset_train,
path_test=args.dataset_train, path_test=args.dataset_train,
batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod') batch_size=args.batch_size, length=args.seq_length, pad = True, convert=True, vocab='unmod')
...@@ -261,7 +261,7 @@ def main(args): ...@@ -261,7 +261,7 @@ def main(args):
print('\nModel initialised') print('\nModel initialised')
run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train, run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train,
data_val=data_val, data_test=data_test, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(), data_val=data_val, data_test=data_val, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(),
criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle, criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle,
wandb=args.wandb, forward=args.forward, output=args.output, file=args.file) wandb=args.wandb, forward=args.forward, output=args.output, file=args.file)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment