diff --git a/config_common.py b/config_common.py index 7bf9ab098f7c5f93fb7060c025066f5dc675c1a3..527364198df4c98c9164b256493d228642b3cc9b 100644 --- a/config_common.py +++ b/config_common.py @@ -23,6 +23,7 @@ def load_args(): parser.add_argument('--dataset_train', type=str, default='database/data_DIA_ISA_55_train.pkl') parser.add_argument('--dataset_val', type=str, default='database/data_DIA_ISA_55_test.pkl') parser.add_argument('--dataset_test', type=str, default='database/data_DIA_ISA_55_test.pkl') + parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction) parser.add_argument('--activation', type=str,default='relu') args = parser.parse_args() diff --git a/main_custom.py b/main_custom.py index 1f1d51e3b9487ee450321292b650762da4baf28f..fcc907494f8c2419cfd47376f0c5a9f4fac3a94c 100644 --- a/main_custom.py +++ b/main_custom.py @@ -174,7 +174,7 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt, - criterion_intensity, metric_rt, metric_intensity, forward, wandb=None): + criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv'): for e in range(1, epochs + 1): train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, wandb=wandb) @@ -183,7 +183,7 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, wandb=wandb) if e % save_inter == 0: save(model, 'model_common_' + str(e) + '.pt') - save_pred(model, data_val, forward, 'output/out.csv') + save_pred(model, data_val, forward, output) def main(args): @@ -221,7 +221,7 @@ def main(args): run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train, data_val=data_val, data_test=data_test, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(), criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle, - wandb=args.wandb, forward=args.forward) + wandb=args.wandb, forward=args.forward, output=args.output) if args.wandb is not None: wdb.finish()