Skip to content
Snippets Groups Projects
Commit 7addac19 authored by Schneider Leo's avatar Schneider Leo
Browse files

intensity model

parent c9c09157
No related branches found
No related tags found
No related merge requests found
......@@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="Python 3.9 (LC-MS-RT-prediction)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (LC-MS-RT-prediction)" project-jdk-type="Python SDK" />
</project>
\ No newline at end of file
......@@ -3,4 +3,5 @@ h5py~=3.10.0
pandas~=2.2.0
numpy~=1.26.2
matplotlib~=3.8.2
wandb~=0.16.2
\ No newline at end of file
wandb~=0.16.2
torchmetrics~=1.3.0.post0
\ No newline at end of file
......@@ -6,6 +6,7 @@ import torch
from dataloader import load_data, load_split_intensity, Intentsity_Dataset, load_intensity_from_files
from model import RT_pred_model, Intensity_pred_model_multi_head, RT_pred_model_self_attention
from config import load_args
from loss import masked_cos_sim, masked_pearson_correlation_distance
def train(model, data_train, epoch, optimizer, criterion, cuda=False):
......@@ -30,7 +31,7 @@ def train(model, data_train, epoch, optimizer, criterion, cuda=False):
distance / len(data_train))
def eval(model, data_test, epoch, criterion=nn.MSELoss(reduction='mean'), cuda=False):
def eval(model, data_test, epoch, criterion=masked_cos_sim, cuda=False):
losses = 0.
distance = 0.
for data, target in data_test:
......@@ -57,7 +58,7 @@ def save(model, optimizer, epoch, checkpoint_name):
}, os.path.join('checkpoints', checkpoint_name))
def run(epochs, eval_inter, save_inter, model, data_train, data_test, optimizer, criterion=nn.MSELoss(reduction='mean'),
def run(epochs, eval_inter, save_inter, model, data_train, data_test, optimizer, criterion=masked_cos_sim,
cuda=False):
for e in range(1, epochs + 1):
train(model, data_train, e, optimizer, criterion, cuda=cuda)
......@@ -68,30 +69,12 @@ def run(epochs, eval_inter, save_inter, model, data_train, data_test, optimizer,
def main(args):
os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
os.environ["WANDB_MODE"] = "offline"
os.environ["WANDB_DIR"] = os.path.abspath("./wandb_run")
config = {
"model": "RT prediction GRU/selfAtt+ GRU",
"learning_rate": args.lr,
"batch_size": args.batch_size,
}
# wdb.init(project="RT prediction", dir='wandb_run')
print('Cuda : ', torch.cuda.is_available())
data_train, data_test = load_data(args.batch_size, args.n_train, args.n_test, data_source='database/data.csv')
print('\nData loaded')
model = RT_pred_model_self_attention()
if torch.cuda.is_available():
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
print('\nModel initialised')
run(args.epochs, args.eval_inter, args.save_inter, model, data_train, data_test, optimizer=optimizer, cuda=True)
# wdb.finish()
def main2(args):
wdb.init(project="RT prediction", dir='./wandb_run')
print(torch.cuda.is_available())
# sources_train = ('data/intensity/sequence_train.npy',
# 'data/intensity/intensity_train.npy',
......@@ -111,9 +94,7 @@ def main2(args):
'data/intensity/collision_energy_header.npy',
'data/intensity/precursor_charge_header.npy')
data_train, data_test, data_validation = load_split_intensity(sources, (0.5, 0.25, 0.25))
train = Intentsity_Dataset(data_train)
test = Intentsity_Dataset(data_test)
data_train, data_test, data_validation = load_split_intensity(sources,32, (0.5, 0.25, 0.25))
......@@ -123,10 +104,11 @@ def main2(args):
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print('\nModel initialised')
run(args.epochs, args.eval_inter, args.save_inter, model, train, test, optimizer=optimizer, cuda=True)
run(args.epochs, args.eval_inter, args.save_inter, model, data_train, data_test, optimizer=optimizer, cuda=True)
wdb.finish()
if __name__ == "__main__":
args = load_args()
print(args)
main2(args)
main(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment