Commit 4c99210f authored by Schneider Leo

code cleaning

parent 61d770f6
@@ -3,5 +3,5 @@
   <component name="Black">
     <option name="sdkName" value="Python 3.9 (LC-MS-RT-prediction)" />
   </component>
-  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" />
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" />
 </project>
\ No newline at end of file
@@ -38,6 +38,8 @@ from model import (RT_pred_model_self_attention, Intensity_pred_model_multi_head
 # return score, delta_95
 def train_rt(model, data_train, epoch, optimizer, criterion, metric, wandb=None):
     losses = 0.
     dist_acc = 0.
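The hunk above only shows the opening lines of train_rt. For orientation, here is a minimal sketch of how a per-epoch loop with this signature is typically wired up; everything past the two accumulators (the batch loop, the loss and metric accumulation, the optional wandb logging, the names of the batch tensors) is an assumption for illustration, not the body from this commit.

def train_rt_sketch(model, data_train, epoch, optimizer, criterion, metric, wandb=None):
    # Hypothetical reconstruction; only `losses` and `dist_acc` appear in the diff.
    losses = 0.
    dist_acc = 0.
    model.train()
    for seq, rt in data_train:                 # assumed (sequence, retention-time) batches
        optimizer.zero_grad()
        pred = model(seq)
        loss = criterion(pred, rt)
        loss.backward()
        optimizer.step()
        losses += loss.item()
        dist_acc += metric(pred, rt).item()    # assumed distance-style metric
    losses /= len(data_train)
    dist_acc /= len(data_train)
    if wandb is not None:
        wandb.log({'train_loss': losses, 'train_dist': dist_acc, 'epoch': epoch})
    return losses, dist_acc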
import numpy as np
from torch import optim


class CosineWarmupScheduler(optim.lr_scheduler.LRScheduler):
    """Learning-rate schedule with a linear warmup followed by cosine decay."""

    def __init__(self, optimizer, warmup, max_iters):
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        # Scale every base learning rate by the current warmup/cosine factor.
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        # Cosine decay from 1.0 down to 0.0 over max_num_iters steps,
        # additionally scaled linearly from 0 to 1 during the warmup phase.
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        if epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        return lr_factor
\ No newline at end of file
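Since the new class subclasses torch's LRScheduler, it drops into a standard optimizer loop and is stepped once per iteration. The sketch below shows one way it might be used; the model, optimizer, and iteration counts are placeholders, not values taken from this repository.

import torch
from torch import nn, optim

model = nn.Linear(10, 1)                                   # placeholder model
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# 100 warmup iterations, cosine decay over 2000 iterations in total
scheduler = CosineWarmupScheduler(optimizer, warmup=100, max_iters=2000)

for it in range(2000):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).pow(2).mean()         # dummy forward pass
    loss.backward()
    optimizer.step()
    scheduler.step()                                       # advances the warmup + cosine schedule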