diff --git a/main_custom.py b/main_custom.py index ff5a2b19719281784814e51c1c8211275eca6ee4..7985488949223dca9e125d87214fe18e006d36ee 100644 --- a/main_custom.py +++ b/main_custom.py @@ -200,7 +200,6 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, def main(args): - print('file', args.file) if args.wandb is not None: os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd' os.environ["WANDB_MODE"] = "offline" @@ -282,7 +281,7 @@ def save_pred(model, data_val, forward, output_path, file = False): pred_rt, pred_int, seqs, charges, true_rt, true_int, file_list = [], [], [], [], [], [], [] if file: data_val.dataset.set_file_mode(True) - for seq, charge, rt, intensity, file in data_val: + for seq, charge, rt, intensity, files in data_val: rt, intensity = rt.float(), intensity.float() if torch.cuda.is_available(): seq, charge, rt, intensity = seq.cuda(), charge.cuda(), rt.cuda(), intensity.cuda() @@ -293,7 +292,7 @@ def save_pred(model, data_val, forward, output_path, file = False): charges.extend(charge.data.cpu().tolist()) true_rt.extend(rt.data.cpu().tolist()) true_int.extend(intensity.data.cpu().tolist()) - file_list.extend([file]) + file_list.extend([files]) else : for seq, charge, rt, intensity in data_val: rt, intensity = rt.float(), intensity.float()