From 6a9416eafc474ccca72a6f91d273af8801f231b6 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 24 Sep 2024 14:25:58 +0200 Subject: [PATCH] forward = transfer --- main_custom.py | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/main_custom.py b/main_custom.py index 78b8bb0..fb68fa6 100644 --- a/main_custom.py +++ b/main_custom.py @@ -175,15 +175,28 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv'): - for e in range(1, epochs + 1): - train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, - wandb=wandb) - if e % eval_inter == 0: - eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, - wandb=wandb) - if e % save_inter == 0: - save(model, 'model_common_' + str(e) + '.pt') - save_pred(model, data_val, forward, output) + + if forward =='transfer' : + for e in range(1, epochs + 1): + train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'rt', + wandb=wandb) + if e % eval_inter == 0: + eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'both', + wandb=wandb) + if e % save_inter == 0: + save(model, 'model_common_' + str(e) + '.pt') + save_pred(model, data_val, 'rt', output) + + else : + for e in range(1, epochs + 1): + train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, + wandb=wandb) + if e % eval_inter == 0: + eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward, + wandb=wandb) + if e % save_inter == 0: + save(model, 'model_common_' + str(e) + '.pt') + save_pred(model, data_val, forward, output) def main(args): @@ -202,7 +215,11 @@ def main(args): path_test=args.dataset_test, batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='iapuc') elif args.forward == 'rt': - data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'], + data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test], + batch_size=args.batch_size, length=25) + + elif args.forward == 'transfer': + data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.cvs','database/data_holdout.cvs'], batch_size=args.batch_size, length=25) _, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val, -- GitLab