diff --git a/config_common.py b/config_common.py index 527364198df4c98c9164b256493d228642b3cc9b..09dea8f600b833a05a64f4dd0faad4fc2a6b6027 100644 --- a/config_common.py +++ b/config_common.py @@ -26,6 +26,7 @@ def load_args(): parser.add_argument('--output', type=str, default='output/out.csv') parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction) parser.add_argument('--activation', type=str,default='relu') + parser.add_argument('--file', action=argparse.BooleanOptionalAction) args = parser.parse_args() return args diff --git a/main_custom.py b/main_custom.py index e54c69efd5349982013a2e757f48dd27348e8a4a..b0866f7bdb1c1051dbf188a9999d50fb23ef3f2f 100644 --- a/main_custom.py +++ b/main_custom.py @@ -175,8 +175,7 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt, - criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv'): - + criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv', file=False): if forward =='transfer' : for e in range(1, epochs + 1): train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'rt', @@ -197,7 +196,7 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, wandb=wandb) if e % save_inter == 0: save(model, 'model_common_' + str(e) + '.pt') - save_pred(model, data_val, forward, output) + save_pred(model, data_val, forward, output, file=file) def main(args): @@ -244,7 +243,7 @@ def main(args): run(epochs=args.epochs, eval_inter=args.eval_inter, save_inter=args.save_inter, model=model, data_train=data_train, data_val=data_val, data_test=data_test, optimizer=optimizer, criterion_rt=torch.nn.MSELoss(), criterion_intensity=masked_cos_sim, metric_rt=distance, metric_intensity=masked_spectral_angle, - wandb=args.wandb, forward=args.forward, output=args.output) + wandb=args.wandb, forward=args.forward, output=args.output, file=args.file) if args.wandb is not None: wdb.finish()