From 4c99210f1db67f382aa36646c8ace8279eb20eb7 Mon Sep 17 00:00:00 2001 From: Schneider Leo <leo.schneider@etu.ec-lyon.fr> Date: Tue, 13 Feb 2024 11:38:56 +0100 Subject: [PATCH] code cleaning --- .idea/misc.xml | 2 +- main.py | 2 ++ scheduler.py | 20 ++++++++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 scheduler.py diff --git a/.idea/misc.xml b/.idea/misc.xml index 1b5f6f7..791742a 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ <component name="Black"> <option name="sdkName" value="Python 3.9 (LC-MS-RT-prediction)" /> </component> - <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" /> + <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" /> </project> \ No newline at end of file diff --git a/main.py b/main.py index bff9cd6..97cecf3 100644 --- a/main.py +++ b/main.py @@ -38,6 +38,8 @@ from model import (RT_pred_model_self_attention, Intensity_pred_model_multi_head # return score, delta_95 + + def train_rt(model, data_train, epoch, optimizer, criterion, metric, wandb=None): losses = 0. dist_acc = 0. diff --git a/scheduler.py b/scheduler.py new file mode 100644 index 0000000..941e84b --- /dev/null +++ b/scheduler.py @@ -0,0 +1,20 @@ +import numpy as np +from torch import optim + + +class CosineWarmupScheduler(optim.lr_scheduler.LRScheduler): + + def __init__(self, optimizer, warmup, max_iters): + self.warmup = warmup + self.max_num_iters = max_iters + super().__init__(optimizer) + + def get_lr(self): + lr_factor = self.get_lr_factor(epoch=self.last_epoch) + return [base_lr * lr_factor for base_lr in self.base_lrs] + + def get_lr_factor(self, epoch): + lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters)) + if epoch <= self.warmup: + lr_factor *= epoch * 1.0 / self.warmup + return lr_factor \ No newline at end of file -- GitLab