diff --git a/common_dataset.py b/common_dataset.py
index 19f35c97b153cdbd3a426cda491d13c1026e6d99..bab5944e439f64aedf8194928f505b2c2a0a4c10 100644
--- a/common_dataset.py
+++ b/common_dataset.py
@@ -137,9 +137,8 @@ class Common_Dataset(Dataset):
         rt = self.data['Retention time'][index]
         intensity = self.data['Spectra'][index]
         charge = self.data['Charge'][index]
-        file = self.data['file'][index]
-
-        if self.file_mode :
+        if self.file_mode:
+            file = self.data['file'][index]
             return torch.tensor(seq), torch.tensor(charge), torch.tensor(rt).float(), torch.tensor(intensity), torch.tensor(file)
         else :
             return torch.tensor(seq), torch.tensor(charge), torch.tensor(rt).float(), torch.tensor(intensity)
@@ -164,3 +163,23 @@ def load_data(path_train, path_val, path_test, batch_size, length, pad=False, co
     val_loader = DataLoader(val, batch_size=batch_size, shuffle=True)
 
     return train_loader, val_loader, test_loader
+
+
+# One-off conversion: merge the training .npy arrays into a single pickled DataFrame.
+# NOTE(review): these statements sit at module level, so they appear to run (and reload
+# every array) whenever common_dataset is imported - confirm this is intended.
+irt_train = np.load('data/intensity/collision_irt_train.npy')
+seq_train = np.load('data/intensity/sequence_train.npy')
+charge_train = np.load('data/intensity/precursor_charge_train.npy')
+spectra_train = np.load('data/intensity/intensity_train.npy')
+
+# irt_holdout = np.load('data/intensity/collision_irt_holdout.npy')
+# seq_holdout = np.load('data/intensity/sequence_holdout.npy')
+# charge_holdout = np.load('data/intensity/precursor_charge_holdout.npy')
+# spectra_holdout = np.load('data/intensity/intensity_holdout.npy')
+
+# Index by the actual row count rather than a hard-coded size, so the frame still builds if the arrays change.
+dataset_train = pd.DataFrame({'Sequence':list(seq_train), 'Retention time':list(irt_train), 'Charge':list(charge_train), 'Spectra' : list(spectra_train)},index=range(len(seq_train)))
+dataset_train.to_pickle('database/data_prosit_merged_train.pkl')
+# dataset_test = pd.DataFrame({'Sequence':list(seq_holdout), 'Retention time':list(irt_holdout), 'Charge':list(charge_holdout), 'Spectra' : list(spectra_holdout)},index=range(len(seq_holdout)))
+# dataset_test.to_pickle('database/data_prosit_merged_holdout.pkl')
\ No newline at end of file
diff --git a/main.py b/main.py
index 97cecf39b7b11529db2af847707808a58a76f273..23c1a3d2b6562e57a7babb22488bb5f96c69cca3 100644
--- a/main.py
+++ b/main.py
@@ -340,9 +340,9 @@ def main_int(args):
                      'data/intensity/collision_energy_train.npy',
                      'data/intensity/precursor_charge_train.npy')
 
-    sources_test = ('data/intensity/sequence_test.npy',
-                    'data/intensity/intensity_test.npy',
-                    'data/intensity/collision_energy_test.npy',
-                    'data/intensity/precursor_charge_test.npy')
+    sources_test = ('data/intensity/sequence_holdout.npy',
+                    'data/intensity/intensity_holdout.npy',
+                    'data/intensity/collision_energy_holdout.npy',
+                    'data/intensity/precursor_charge_holdout.npy')
 
     data_train = load_intensity_from_files(sources_train[0], sources_train[1], sources_train[2], sources_train[3], args.batch_size)