From 4c99210f1db67f382aa36646c8ace8279eb20eb7 Mon Sep 17 00:00:00 2001
From: Schneider Leo <leo.schneider@etu.ec-lyon.fr>
Date: Tue, 13 Feb 2024 11:38:56 +0100
Subject: [PATCH] code cleaning

---
 .idea/misc.xml |  2 +-
 main.py        |  2 ++
 scheduler.py   | 20 ++++++++++++++++++++
 3 files changed, 23 insertions(+), 1 deletion(-)
 create mode 100644 scheduler.py

diff --git a/.idea/misc.xml b/.idea/misc.xml
index 1b5f6f7..791742a 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,5 +3,5 @@
   <component name="Black">
     <option name="sdkName" value="Python 3.9 (LC-MS-RT-prediction)" />
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" />
 </project>
\ No newline at end of file
diff --git a/main.py b/main.py
index bff9cd6..97cecf3 100644
--- a/main.py
+++ b/main.py
@@ -38,6 +38,8 @@ from model import (RT_pred_model_self_attention, Intensity_pred_model_multi_head
 #     return score, delta_95
 
 
+
+
 def train_rt(model, data_train, epoch, optimizer, criterion, metric, wandb=None):
     losses = 0.
     dist_acc = 0.
diff --git a/scheduler.py b/scheduler.py
new file mode 100644
index 0000000..941e84b
--- /dev/null
+++ b/scheduler.py
@@ -0,0 +1,20 @@
+import numpy as np
+from torch import optim
+
+
+class CosineWarmupScheduler(optim.lr_scheduler.LRScheduler):
+
+    def __init__(self, optimizer, warmup, max_iters):
+        self.warmup = warmup
+        self.max_num_iters = max_iters
+        super().__init__(optimizer)
+
+    def get_lr(self):
+        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
+        return [base_lr * lr_factor for base_lr in self.base_lrs]
+
+    def get_lr_factor(self, epoch):
+        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
+        if epoch <= self.warmup:
+            lr_factor *= epoch * 1.0 / self.warmup
+        return lr_factor
\ No newline at end of file
-- 
GitLab