diff --git a/config/config.py b/config/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8559a3585aaf49ed93e586a755309a4291052c7e
--- /dev/null
+++ b/config/config.py
@@ -0,0 +1,24 @@
+import argparse
+
+
+def load_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--epochs', type=int, default=100)
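+    # Epoch intervals: eval_inter gates evaluation in main.py; save_inter is intended for checkpointing.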
+    parser.add_argument('--save_inter', type=int, default=50)
+    parser.add_argument('--eval_inter', type=int, default=1)
+    parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--batch_size', type=int, default=2048)
+    parser.add_argument('--model', type=str, default='prosit_transformer')
+    parser.add_argument('--wandb', type=str, default=None)
+    parser.add_argument('--dataset_dir', type=str, default='data/processed_data/npy_image/data_training')
+    parser.add_argument('--output', type=str, default='output/out.csv')
+    parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction)
+    # Assumed additions so main.py's args resolve; defaults are placeholders.
+    parser.add_argument('--pretrain_path', type=str, default=None)
+    parser.add_argument('--save_path', type=str, default='output/model.pt')
+    parser.add_argument('--n_class', type=int, default=2)
+    args = parser.parse_args()
+
+    return args
diff --git a/dataset/dataset.py b/dataset/dataset.py
index b0d19939d869aa1d0aad126671689fa552a3f119..2f710c69587c0ba5eea41091101af7d867365ae6 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -1,29 +1,70 @@
 import torch
 import torchvision
+from torch.utils.data import DataLoader, random_split
+import torchvision.transforms as transforms
 
-from torch.utils.data import DataLoader
-import torchvision.transforms.functional as TF
-import random
-
-root = '../data/processed_data'
-dataset = torchvision.datasets.ImageFolder(root, transform=None)
-data_loader = DataLoader(
-    dataset,
-    batch_size=1,
-    shuffle=False,
-    num_workers=0,
-    collate_fn=None,
-    pin_memory=False,
- )
 
 class Threshold_noise:
-    """Rotate by one of the given angles."""
+    """Remove intensities under given threshold"""
 
-    def __init__(self, threshold):
+    def __init__(self, threshold=100.):
         self.threshold = threshold
 
     def __call__(self, x):
-        angle = random.choice(self.angles)
-        return torch.max(x,0)
+        # torch.max(x, 0) would reduce over dim 0; clamp gives the intended elementwise floor.
+        return torch.clamp(x - self.threshold, min=0.)
 
-rotation_transform = Threshold_noise(threshold=100)
\ No newline at end of file
+class Log_normalisation:
+    """Log normalisation of intensities"""
+
+    def __init__(self, eps=1e-5):
+        self.epsilon = eps
+
+    def __call__(self, x):
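+        # Scale to [0, 1]: log(1 + x + eps) divided by the per-image log(1 + max(x) + eps).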
+        return torch.log(x+1+self.epsilon)/torch.log(torch.max(x)+1+self.epsilon)
+
+class Random_shift_rt():
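+    # Placeholder, presumably for a random retention-time shift augmentation.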
+    pass
+
+
+def load_data(base_dir, batch_size, shuffle=True, transform=None):
+    # Default pipeline when the caller supplies no transform (original check was inverted).
+    if transform is None:
+        transform = transforms.Compose(
+            [transforms.ToTensor(),
+             transforms.Resize((224, 224)),
+             Log_normalisation(),
+             transforms.Normalize(0.5, 0.5)])
+    dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=transform)
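+    # Fixed seed keeps the 80/20 train/test split reproducible across runs.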
+    generator1 = torch.Generator().manual_seed(42)
+    data_train, data_test = random_split(dataset, [0.8, 0.2], generator=generator1)
+    data_loader_train = DataLoader(
+        dataset=data_train,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    data_loader_test = DataLoader(
+        dataset=data_test,
+        batch_size=batch_size,
+        shuffle=False,  # keep evaluation order deterministic
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    return data_loader_train, data_loader_test
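+
+
+if __name__ == '__main__':
+    # Smoke-test sketch: assumes base_dir points at an ImageFolder-style image
+    # tree (the path below mirrors the config.py default).
+    train_loader, test_loader = load_data(
+        'data/processed_data/npy_image/data_training', batch_size=4)
+    print(len(train_loader.dataset), len(test_loader.dataset))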
diff --git a/image_processing/build_dataset.py b/image_processing/build_dataset.py
index 2a127e0841ff7fef09c60301134ff6cb753da677..6f7c4f71cbd15ed21771b20a99dfcf8d8bde6ad6 100644
--- a/image_processing/build_dataset.py
+++ b/image_processing/build_dataset.py
@@ -72,6 +72,6 @@ def create_dataset():
             np.save(directory_path + "/" + name + '_' + analyse + '.npy', mat)
 
 
-
+# TODO: add train/val/test split
 if __name__ =='__main__' :
-    label = create_antibio_dataset()
\ No newline at end of file
+    create_dataset()
\ No newline at end of file
diff --git a/main.py b/main.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..9c300061d4c3bea00ed1b7734324b9ad446b2d9a 100644
--- a/main.py
+++ b/main.py
@@ -0,0 +1,80 @@
+from config.config import load_args
+from dataset.dataset import load_data
+import torch
+import torch.nn as nn
+from models.model import Classification_model
+import torch.optim as optim
+
+
+def train(model, data_train, optimizer, loss_function):
+    model.train()
+    losses = 0.
+    acc = 0.
+    for param in model.parameters():
+        param.requires_grad = True
+
+    for im, label in data_train:
+        # CrossEntropyLoss expects integer class indices, so labels stay long.
+        if torch.cuda.is_available():
+            im, label = im.cuda(), label.cuda()
+        pred_logits = model(im)
+        pred_class = torch.argmax(pred_logits, dim=1)
+        acc += (pred_class == label).sum().item()
+        loss = loss_function(pred_logits, label)
+        losses += loss.item()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    # Average the loss over batches and the accuracy over samples.
+    losses = losses / len(data_train)
+    acc = acc / len(data_train.dataset)
+    print('loss : ', losses, ' acc : ', acc)
+
+def test(model, data_test, loss_function):
+    # nn.Module has no .test(); eval() plus no_grad() is the evaluation mode.
+    model.eval()
+    losses = 0.
+    acc = 0.
+
+    with torch.no_grad():
+        for im, label in data_test:
+            if torch.cuda.is_available():
+                im, label = im.cuda(), label.cuda()
+            pred_logits = model(im)
+            pred_class = torch.argmax(pred_logits, dim=1)
+            acc += (pred_class == label).sum().item()
+            loss = loss_function(pred_logits, label)
+            losses += loss.item()
+    losses = losses / len(data_test)
+    acc = acc / len(data_test.dataset)
+    print('loss : ', losses, ' acc : ', acc)
+
+def run(args):
+    # n_class comes from config; the original never instantiated the model.
+    model = Classification_model(n_class=args.n_class)
+    if args.pretrain_path is not None:
+        load_model(model, args.pretrain_path)
+    if torch.cuda.is_available():
+        model = model.cuda()
+    data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size)
+    loss_function = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
+
+    for e in range(args.epochs):
+        train(model, data_train, optimizer, loss_function)
+
+        if e % args.eval_inter == 0:
+            test(model, data_test, loss_function)
+    save_model(model, args.save_path)
+
+def save_model(model, path):
+    torch.save(model.state_dict(), path)
+
+def load_model(model, path):
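+    # weights_only=True limits unpickling to tensors and primitive containers (PyTorch >= 1.13).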
+    model.load_state_dict(torch.load(path, weights_only=True))
+
+
+if __name__ == '__main__':
+    args = load_args()
+    run(args)
\ No newline at end of file
diff --git a/models/model.py b/models/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..78f0771ec5a21d3217ea969c6ec0b8fd533b4102
--- /dev/null
+++ b/models/model.py
@@ -0,0 +1,278 @@
+import torch
+import torch.nn as nn
+import torchvision
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+    # This variant is also known as ResNet V1.5 and improves accuracy according to
+    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x):
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+
+def _resnet(block, layers, **kwargs):
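+    # Constructor helper in torchvision's style; pretrained-weight loading is not included here.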
+    model = ResNet(block, layers, **kwargs)
+
+    return model
+
+
+def resnet18(num_classes=1000, **kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        num_classes (int): number of output classes.
+    """
+    return _resnet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet34(num_classes=1000, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        num_classes (int): number of output classes.
+    """
+    return _resnet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet50(num_classes=1000, **kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        num_classes (int): number of output classes.
+    """
+    return _resnet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet101(num_classes=1000, **kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        num_classes (int): number of output classes.
+    """
+    return _resnet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet152(num_classes=1000, **kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        num_classes (int): number of output classes.
+    """
+    return _resnet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes,
+                   **kwargs)
+
+class Classification_model(nn.Module):
+
+    def __init__(self, n_class):
+        # super().__init__() is required so nn.Module registers submodules.
+        super().__init__()
+        self.n_class = n_class
+        self.im_encoder = resnet18(num_classes=self.n_class)
+
+    def forward(self, input):
+        return self.im_encoder(input)
\ No newline at end of file