diff --git a/.gitignore b/.gitignore
index 1f8b3087abce686b440c12c0d2ba6860fe047a28..749ec53c95489a50a42060d36817030ff03cb25a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,5 @@
 /fig/
 /venv/
 /dataset/
-/test.py
 /database/
-
 /wandb_run/
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..59f12795546e16b50f8a4ca49a1035e993619daf
--- /dev/null
+++ b/test.py
@@ -0,0 +1,132 @@
+import os
+import wandb as wdb
+import torch.nn as nn
+import torch.optim as optim
+import torch
+from dataloader import load_data, load_split_intensity, Intentsity_Dataset, load_intensity_from_files
+from model import RT_pred_model, Intensity_pred_model_multi_head, RT_pred_model_self_attention
+from config import load_args
+
+
+def train(model, data_train, epoch, optimizer, criterion, cuda=False):
+    losses = 0.
+    distance = 0.
+    for data1, data2, data3, target in data_train:
+        if torch.cuda.is_available():
+            data1, data2, data3, target = data1.cuda(), data2.cuda(), data3.cuda(), target.cuda()
+        pred_rt = model.forward(data1, data2, data3)
+        target.float()
+        loss = criterion(pred_rt, target)
+        dist = torch.mean(torch.abs(pred_rt - target))
+        distance += dist.item()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        losses += loss.item()
+
+    # wdb.log({"train loss": losses / len(data_train), "train mean distance": distance / len(data_train)})
+
+    print('epoch : ', epoch, ',train losses : ', losses / len(data_train), " ,mean distance : ",
+          distance / len(data_train))
+
+
+def eval(model, data_test, epoch, criterion=nn.MSELoss(reduction='mean'), cuda=False):
+    losses = 0.
+    distance = 0.
+    for data, target in data_test:
+        if torch.cuda.is_available():
+            data, target = data.cuda(), target.cuda()
+        pred_rt = model(data)
+        loss = criterion(pred_rt, target)
+        losses += loss.item()
+        dist = torch.mean(torch.abs(pred_rt - target))
+        distance += dist.item()
+    # wdb.log({"eval loss": losses / len(data_test), "eval mean distance": distance / len(data_test)})
+    print('epoch : ', epoch, ',eval losses : ', losses / len(data_test), " ,eval mean distance: :",
+          distance / len(data_test))
+
+
+def save(model, optimizer, epoch, checkpoint_name):
+    print('\nModel Saving...')
+    model_state_dict = model.state_dict()
+    os.makedirs('checkpoints', exist_ok=True)
+    torch.save({
+        'model_state_dict': model_state_dict,
+        'global_epoch': epoch,
+        'optimizer_state_dict': optimizer.state_dict(),
+    }, os.path.join('checkpoints', checkpoint_name))
+
+
+def run(epochs, eval_inter, save_inter, model, data_train, data_test, optimizer, criterion=nn.MSELoss(reduction='mean'),
+        cuda=False):
+    for e in range(1, epochs + 1):
+        train(model, data_train, e, optimizer, criterion, cuda=cuda)
+        if e % eval_inter == 0:
+            eval(model, data_test, e, cuda=cuda)
+        if e % save_inter == 0:
+            save(model, optimizer, epochs, 'model_self_attention_' + str(e) + '.pt')
+
+
+def main(args):
+    os.environ["WANDB_API_KEY"] = 'b4a27ac6b6145e1a5d0ee7f9e2e8c20bd101dccd'
+    os.environ["WANDB_MODE"] = "offline"
+
+    config = {
+        "model": "RT prediction GRU/selfAtt+ GRU",
+        "learning_rate": args.lr,
+        "batch_size": args.batch_size,
+    }
+
+    # wdb.init(project="RT prediction", dir='wandb_run')
+
+    print('Cuda : ', torch.cuda.is_available())
+    data_train, data_test = load_data(args.batch_size, args.n_train, args.n_test, data_source='database/data.csv')
+    print('\nData loaded')
+    model = RT_pred_model_self_attention()
+    if torch.cuda.is_available():
+        model = model.cuda()
+    optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    print('\nModel initialised')
+    run(args.epochs, args.eval_inter, args.save_inter, model, data_train, data_test, optimizer=optimizer, cuda=True)
+    # wdb.finish()
+
+
+def main2(args):
+    print(torch.cuda.is_available())
+    # sources_train = ('data/intensity/sequence_train.npy',
+    #                  'data/intensity/intensity_train.npy',
+    #                  'data/intensity/collision_energy_train.npy',
+    #                  'data/intensity/precursor_charge_train.npy')
+    #
+    # sources_test = ('data/intensity/sequence_test.npy',
+    #                 'data/intensity/intensity_test.npy',
+    #                 'data/intensity/collision_energy_test.npy',
+    #                 'data/intensity/precursor_charge_test.npy')
+    #
+    # data_train = load_intensity_from_files(sources_train[0], sources_train[1], sources_train[2], sources_train[3])
+    # data_test = load_intensity_from_files(sources_test[0], sources_test[1], sources_test[2], sources_test[3])
+
+    sources = ('data/intensity/sequence_header.npy',
+               'data/intensity/intensity_header.npy',
+               'data/intensity/collision_energy_header.npy',
+               'data/intensity/precursor_charge_header.npy')
+
+    data_train, data_test, data_validation = load_split_intensity(sources, (0.5, 0.25, 0.25))
+    train = Intentsity_Dataset(data_train)
+    test = Intentsity_Dataset(data_test)
+
+
+
+    print('\nData loaded')
+    model = Intensity_pred_model_multi_head()
+    if torch.cuda.is_available():
+        model = model.cuda()
+    optimizer = optim.Adam(model.parameters(), lr=0.001)
+    print('\nModel initialised')
+    run(args.epochs, args.eval_inter, args.save_inter, model, train, test, optimizer=optimizer, cuda=True)
+
+
+if __name__ == "__main__":
+    args = load_args()
+    print(args)
+    main2(args)