From 6a9416eafc474ccca72a6f91d273af8801f231b6 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 24 Sep 2024 14:25:58 +0200
Subject: [PATCH] forward = transfer

---
 main_custom.py | 37 +++++++++++++++++++++++++++----------
 1 file changed, 27 insertions(+), 10 deletions(-)

diff --git a/main_custom.py b/main_custom.py
index 78b8bb0..fb68fa6 100644
--- a/main_custom.py
+++ b/main_custom.py
@@ -175,15 +175,28 @@ def eval(model, data_val, epoch, criterion_rt, criterion_intensity, metric_rt, m
 
 def run(epochs, eval_inter, save_inter, model, data_train, data_val, data_test, optimizer, criterion_rt,
         criterion_intensity, metric_rt, metric_intensity, forward, wandb=None, output='output/out.csv'):
-    for e in range(1, epochs + 1):
-        train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
-              wandb=wandb)
-        if e % eval_inter == 0:
-            eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
-                 wandb=wandb)
-        if e % save_inter == 0:
-            save(model, 'model_common_' + str(e) + '.pt')
-    save_pred(model, data_val, forward, output)
+
+    if forward =='transfer' :
+        for e in range(1, epochs + 1):
+            train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'rt',
+                  wandb=wandb)
+            if e % eval_inter == 0:
+                eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, 'both',
+                     wandb=wandb)
+            if e % save_inter == 0:
+                save(model, 'model_common_' + str(e) + '.pt')
+        save_pred(model, data_val, 'rt', output)
+
+    else :
+        for e in range(1, epochs + 1):
+            train(model, data_train, e, optimizer, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
+                  wandb=wandb)
+            if e % eval_inter == 0:
+                eval(model, data_val, e, criterion_rt, criterion_intensity, metric_rt, metric_intensity, forward,
+                     wandb=wandb)
+            if e % save_inter == 0:
+                save(model, 'model_common_' + str(e) + '.pt')
+        save_pred(model, data_val, forward, output)
 
 
 def main(args):
@@ -202,7 +215,11 @@ def main(args):
                                                                    path_test=args.dataset_test,
                                                                    batch_size=args.batch_size, length=25, pad = True, convert=True, vocab='iapuc')
     elif args.forward == 'rt':
-        data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.csv','database/data_holdout.csv'],
+        data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,args.dataset_val,args.dataset_test],
+                                                               batch_size=args.batch_size, length=25)
+
+    elif args.forward == 'transfer':
+        data_train, _, _ = dataloader.load_data(data_sources=[args.dataset_train,'database/data_holdout.cvs','database/data_holdout.cvs'],
                                                                batch_size=args.batch_size, length=25)
 
         _, data_val, data_test = common_dataset.load_data(path_train=args.dataset_val,
-- 
GitLab