Skip to content
Snippets Groups Projects
Commit 99f0dc51 authored by Schneider Leo's avatar Schneider Leo
Browse files

main intensity

parent 7addac19
No related branches found
No related tags found
No related merge requests found
import os
import wandb as wdb
import torch.nn as nn
import torch.optim as optim
import torch
from dataloader import load_data, load_split_intensity, Intentsity_Dataset, load_intensity_from_files
from model import RT_pred_model, Intensity_pred_model_multi_head, RT_pred_model_self_attention
from config import load_args
from loss import masked_cos_sim, masked_pearson_correlation_distance
def train(model, data_train, epoch, optimizer, criterion, cuda=False):
losses = 0.
distance = 0.
for data1, data2, data3, target in data_train:
if torch.cuda.is_available():
data1, data2, data3, target = data1.cuda(), data2.cuda(), data3.cuda(), target.cuda()
pred_rt = model.forward(data1, data2, data3)
target.float()
loss = criterion(pred_rt, target)
dist = torch.mean(torch.abs(pred_rt - target))
distance += dist.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses += loss.item()
# wdb.log({"train loss": losses / len(data_train), "train mean distance": distance / len(data_train)})
print('epoch : ', epoch, ',train losses : ', losses / len(data_train), " ,mean distance : ",
distance / len(data_train))
def eval(model, data_test, epoch, criterion=masked_cos_sim, cuda=False):
losses = 0.
distance = 0.
for data, target in data_test:
if torch.cuda.is_available():
data, target = data.cuda(), target.cuda()
pred_rt = model(data)
loss = criterion(pred_rt, target)
losses += loss.item()
dist = torch.mean(torch.abs(pred_rt - target))
distance += dist.item()
# wdb.log({"eval loss": losses / len(data_test), "eval mean distance": distance / len(data_test)})
print('epoch : ', epoch, ',eval losses : ', losses / len(data_test), " ,eval mean distance: :",
distance / len(data_test))
def save(model, optimizer, epoch, checkpoint_name):
print('\nModel Saving...')
model_state_dict = model.state_dict()
os.makedirs('checkpoints', exist_ok=True)
torch.save({
'model_state_dict': model_state_dict,
'global_epoch': epoch,
'optimizer_state_dict': optimizer.state_dict(),
}, os.path.join('checkpoints', checkpoint_name))
def run(epochs, eval_inter, save_inter, model, data_train, data_test, optimizer, criterion=masked_cos_sim,
cuda=False):
for e in range(1, epochs + 1):
train(model, data_train, e, optimizer, criterion, cuda=cuda)
if e % eval_inter == 0:
eval(model, data_test, e, cuda=cuda)
if e % save_inter == 0:
save(model, optimizer, epochs, 'model_self_attention_' + str(e) + '.pt')
def main(args):
os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
wdb.init(project="RT prediction", dir='./wandb_run')
print(torch.cuda.is_available())
sources_train = ('data/intensity/sequence_train.npy',
'data/intensity/intensity_train.npy',
'data/intensity/collision_energy_train.npy',
'data/intensity/precursor_charge_train.npy')
sources_test = ('data/intensity/sequence_test.npy',
'data/intensity/intensity_test.npy',
'data/intensity/collision_energy_test.npy',
'data/intensity/precursor_charge_test.npy')
data_train = load_intensity_from_files(sources_train[0], sources_train[1], sources_train[2], sources_train[3], args.batch_size)
data_test = load_intensity_from_files(sources_test[0], sources_test[1], sources_test[2], sources_test[3], args.batch_size)
print('\nData loaded')
model = Intensity_pred_model_multi_head()
if torch.cuda.is_available():
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print('\nModel initialised')
run(args.epochs, args.eval_inter, args.save_inter, model, data_train, data_test, optimizer=optimizer, cuda=True)
wdb.finish()
if __name__ == "__main__":
args = load_args()
print(args)
main(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment