diff --git a/.idea/misc.xml b/.idea/misc.xml index 1b5f6f736536803396a9f042bfdc1981e5153ac8..791742aac19004e98ffc7e23b7f561586967d922 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ <component name="Black"> <option name="sdkName" value="Python 3.9 (LC-MS-RT-prediction)" /> </component> - <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" /> + <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" /> </project> \ No newline at end of file diff --git a/main.py b/main.py index bff9cd6e14c75e1d8f1ae76bed652686956a71f9..97cecf39b7b11529db2af847707808a58a76f273 100644 --- a/main.py +++ b/main.py @@ -38,6 +38,8 @@ from model import (RT_pred_model_self_attention, Intensity_pred_model_multi_head # return score, delta_95 + + def train_rt(model, data_train, epoch, optimizer, criterion, metric, wandb=None): losses = 0. dist_acc = 0. diff --git a/scheduler.py b/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..941e84b43090fe273b31deeeae2317030b8f6b29 --- /dev/null +++ b/scheduler.py @@ -0,0 +1,20 @@ +import numpy as np +from torch import optim + + +class CosineWarmupScheduler(optim.lr_scheduler.LRScheduler): + + def __init__(self, optimizer, warmup, max_iters): + self.warmup = warmup + self.max_num_iters = max_iters + super().__init__(optimizer) + + def get_lr(self): + lr_factor = self.get_lr_factor(epoch=self.last_epoch) + return [base_lr * lr_factor for base_lr in self.base_lrs] + + def get_lr_factor(self, epoch): + lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters)) + if epoch <= self.warmup: + lr_factor *= epoch * 1.0 / self.warmup + return lr_factor \ No newline at end of file