Skip to content
Snippets Groups Projects
Commit 2fc33702 authored by Léo Schneider's avatar Léo Schneider Committed by Schneider Leo
Browse files

merging main

parent 99f0dc51
No related branches found
No related tags found
No related merge requests found
......@@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="Python 3.9 (LC-MS-RT-prediction)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
import os
import wandb as wdb
import torch.nn as nn
import torch.optim as optim
import torch
from dataloader import load_data, load_split_intensity, Intentsity_Dataset, load_intensity_from_files
from model import RT_pred_model, Intensity_pred_model_multi_head, RT_pred_model_self_attention
from config import load_args
from loss import masked_cos_sim, masked_pearson_correlation_distance
def train(model, data_train, epoch, optimizer, criterion, cuda=False):
    """Run one training epoch over ``data_train`` and print the averaged metrics.

    Each batch is a 4-tuple ``(data1, data2, data3, target)`` — presumably
    (sequence, collision energy, precursor charge, intensity) given the callers
    in this file; TODO confirm against the dataloader.

    :param model: network whose forward takes the three data tensors.
    :param data_train: iterable of ``(data1, data2, data3, target)`` batches.
    :param epoch: current epoch number (used only for logging).
    :param optimizer: optimizer stepping ``model``'s parameters.
    :param criterion: loss function ``criterion(pred, target)``.
    :param cuda: unused; device placement follows ``torch.cuda.is_available()``.
    """
    losses = 0.
    distance = 0.
    for data1, data2, data3, target in data_train:
        if torch.cuda.is_available():
            data1, data2, data3, target = data1.cuda(), data2.cuda(), data3.cuda(), target.cuda()
        # Call the module itself (not .forward) so hooks are honoured.
        pred_rt = model(data1, data2, data3)
        # BUG FIX: the original `target.float()` discarded its result — the
        # cast must be assigned back to take effect.
        target = target.float()
        loss = criterion(pred_rt, target)
        # Mean absolute error, tracked outside the autograd graph (metric only).
        with torch.no_grad():
            distance += torch.mean(torch.abs(pred_rt - target)).item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses += loss.item()
    # wdb.log({"train loss": losses / len(data_train), "train mean distance": distance / len(data_train)})
    print('epoch : ', epoch, ',train losses : ', losses / len(data_train), " ,mean distance : ",
          distance / len(data_train))
def eval(model, data_test, epoch, criterion=masked_cos_sim, cuda=False):
    """Evaluate ``model`` on ``data_test`` and print averaged loss and MAE.

    NOTE(review): the name shadows the builtin ``eval`` and the model is called
    with a single input here, while ``train`` passes three — confirm the test
    dataloader/model interface matches.

    :param model: network evaluated as ``model(data)``.
    :param data_test: iterable of ``(data, target)`` batches.
    :param epoch: current epoch number (logging only).
    :param criterion: loss function, default ``masked_cos_sim``.
    :param cuda: unused; device placement follows ``torch.cuda.is_available()``.
    """
    losses = 0.
    distance = 0.
    # FIX: evaluation needs no gradients — avoid building autograd graphs.
    with torch.no_grad():
        for data, target in data_test:
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()
            pred_rt = model(data)
            loss = criterion(pred_rt, target)
            losses += loss.item()
            dist = torch.mean(torch.abs(pred_rt - target))
            distance += dist.item()
    # wdb.log({"eval loss": losses / len(data_test), "eval mean distance": distance / len(data_test)})
    print('epoch : ', epoch, ',eval losses : ', losses / len(data_test), " ,eval mean distance: :",
          distance / len(data_test))
def save(model, optimizer, epoch, checkpoint_name):
    """Persist model/optimizer state plus the epoch counter to ``checkpoints/``.

    :param model: module whose ``state_dict`` is stored.
    :param optimizer: optimizer whose ``state_dict`` is stored.
    :param epoch: value stored under the ``global_epoch`` key.
    :param checkpoint_name: file name inside the ``checkpoints`` directory.
    """
    print('\nModel Saving...')
    os.makedirs('checkpoints', exist_ok=True)
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'global_epoch': epoch,
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, os.path.join('checkpoints', checkpoint_name))
def run(epochs, eval_inter, save_inter, model, data_train, data_test, optimizer, criterion=masked_cos_sim,
        cuda=False):
    """Train for ``epochs`` epochs, evaluating every ``eval_inter`` epochs and
    checkpointing every ``save_inter`` epochs.

    :param epochs: total number of epochs (1-based loop).
    :param eval_inter: evaluation interval in epochs.
    :param save_inter: checkpoint interval in epochs.
    :param model: model passed through to :func:`train` / :func:`eval` / :func:`save`.
    :param data_train: training batches.
    :param data_test: evaluation batches.
    :param optimizer: optimizer passed to :func:`train` and :func:`save`.
    :param criterion: training loss, default ``masked_cos_sim``.
    :param cuda: forwarded to :func:`train` / :func:`eval` (currently unused there).
    """
    for e in range(1, epochs + 1):
        train(model, data_train, e, optimizer, criterion, cuda=cuda)
        if e % eval_inter == 0:
            eval(model, data_test, e, cuda=cuda)
        if e % save_inter == 0:
            # BUG FIX: the original saved `epochs` (the total) as the epoch,
            # so every checkpoint recorded global_epoch == epochs. Save the
            # current epoch `e` instead, matching the file-name suffix.
            save(model, optimizer, e, 'model_self_attention_' + str(e) + '.pt')
def main(args):
    """Entry point: configure offline W&B logging, load the intensity dataset,
    build the multi-head intensity model and run the training loop.

    :param args: parsed CLI arguments; reads ``batch_size``, ``epochs``,
        ``eval_inter`` and ``save_inter``.
    """
    # SECURITY(review): hardcoded W&B API key committed to source — this
    # credential is leaked and should be rotated and read from the
    # environment instead. Left unchanged here to preserve behavior.
    os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
    # Offline mode: runs are written locally under ./wandb_run, not uploaded.
    os.environ["WANDB_MODE"] = "offline"
    os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
    wdb.init(project="RT prediction", dir='./wandb_run')
    print(torch.cuda.is_available())
    # Pre-split train/test .npy files: (sequence, intensity, collision energy,
    # precursor charge) quadruples.
    sources_train = ('data/intensity/sequence_train.npy',
                     'data/intensity/intensity_train.npy',
                     'data/intensity/collision_energy_train.npy',
                     'data/intensity/precursor_charge_train.npy')
    sources_test = ('data/intensity/sequence_test.npy',
                    'data/intensity/intensity_test.npy',
                    'data/intensity/collision_energy_test.npy',
                    'data/intensity/precursor_charge_test.npy')
    data_train = load_intensity_from_files(sources_train[0], sources_train[1], sources_train[2], sources_train[3], args.batch_size)
    data_test = load_intensity_from_files(sources_test[0], sources_test[1], sources_test[2], sources_test[3], args.batch_size)
    print('\nData loaded')
    model = Intensity_pred_model_multi_head()
    if torch.cuda.is_available():
        model = model.cuda()
    # NOTE(review): learning rate is hardcoded; consider exposing it via args.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    print('\nModel initialised')
    run(args.epochs, args.eval_inter, args.save_inter, model, data_train, data_test, optimizer=optimizer, cuda=True)
    wdb.finish()
if __name__ == "__main__":
    # Parse CLI arguments, echo them for the run log, then launch training.
    cli_args = load_args()
    print(cli_args)
    main(cli_args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment