diff --git a/config/config.py b/config/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8559a3585aaf49ed93e586a755309a4291052c7e
--- /dev/null
+++ b/config/config.py
@@ -0,0 +1,28 @@
+import argparse
+
+
+def load_args():
+    """Parse the command-line options of the training script.
+
+    Returns:
+        argparse.Namespace: the parsed options; none of them is required.
+    """
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--epochs', type=int, default=100)
+    parser.add_argument('--save_inter', type=int, default=50)
+    parser.add_argument('--eval_inter', type=int, default=1)
+    parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--batch_size', type=int, default=2048)
+    parser.add_argument('--model', type=str, default='prosit_transformer')
+    parser.add_argument('--wandb', type=str, default=None)
+    parser.add_argument('--dataset_dir', type=str, default='data/processed_data/npy_image/data_training')
+    parser.add_argument('--output', type=str, default='output/out.csv')
+    parser.add_argument('--norm_first', action=argparse.BooleanOptionalAction)
+    # Checkpoint paths read by run() in main.py: weights to start from
+    # (optional) and where the trained weights are written.
+    parser.add_argument('--pretrain_path', type=str, default=None)
+    parser.add_argument('--save_path', type=str, default='output/model.pt')
+    args = parser.parse_args()
+
+    return args
diff --git a/dataset/dataset.py b/dataset/dataset.py
index b0d19939d869aa1d0aad126671689fa552a3f119..2f710c69587c0ba5eea41091101af7d867365ae6 100644
--- a/dataset/dataset.py
+++ b/dataset/dataset.py
@@ -1,29 +1,76 @@
 import torch
 import torchvision
+from torch.utils.data import DataLoader, random_split
+import torchvision.transforms as transforms
 
-from torch.utils.data import DataLoader
-import torchvision.transforms.functional as TF
-import random
-
-root = '../data/processed_data'
-dataset = torchvision.datasets.ImageFolder(root, transform=None)
-data_loader = DataLoader(
-    dataset,
-    batch_size=1,
-    shuffle=False,
-    num_workers=0,
-    collate_fn=None,
-    pin_memory=False,
-    )
 
 class Threshold_noise:
-    """Rotate by one of the given angles."""
+    """Subtract the threshold from the intensities and clip the result at zero."""
 
-    def __init__(self, threshold):
+    def __init__(self, threshold=100.):
         self.threshold = threshold
 
     def __call__(self, x):
-        angle = random.choice(self.angles)
-        return torch.max(x,0)
+        # torch.max(x, 0) reduces over dim 0 and returns (values, indices);
+        # torch.clamp applies the intended element-wise floor at zero.
+        return torch.clamp(x - self.threshold, min=0)
 
-rotation_transform = Threshold_noise(threshold=100)
\ No newline at end of file
+class Log_normalisation:
+    """Log-scale the intensities and divide by the log of their maximum."""
+
+    def __init__(self, eps=1e-5):
+        self.epsilon = eps
+
+    def __call__(self, x):
+        return torch.log(x+1+self.epsilon)/torch.log(torch.max(x)+1+self.epsilon)
+
+class Random_shift_rt():
+    pass
+
+
+def load_data(base_dir, batch_size, shuffle=True, transform=None):
+    """Build the train and test loaders from an image-folder dataset.
+
+    Args:
+        base_dir: root directory handed to ``torchvision.datasets.ImageFolder``.
+        batch_size: batch size of both loaders.
+        shuffle: whether the training loader reshuffles at every epoch.
+        transform: per-image transform; ``None`` selects the default pipeline
+            (tensor conversion, 224x224 resize, log normalisation, Normalize(0.5, 0.5)).
+
+    Returns:
+        tuple: ``(data_loader_train, data_loader_test)`` over a seeded 80/20 split.
+    """
+    # Fall back to the default pipeline only when the caller supplies none;
+    # without any transform ImageFolder yields PIL images, which the default
+    # collate function of DataLoader cannot batch.
+    if transform is None:
+        transform = transforms.Compose(
+            [transforms.ToTensor(),
+             transforms.Resize((224,224)),
+             Log_normalisation(),
+             transforms.Normalize(0.5, 0.5)])
+    dataset = torchvision.datasets.ImageFolder(root=base_dir, transform=transform)
+    generator1 = torch.Generator().manual_seed(42)
+    data_train, data_test = random_split(dataset, [0.8, 0.2], generator=generator1)
+    data_loader_train = DataLoader(
+        dataset=data_train,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    # Evaluation metrics do not depend on sample order, so the test loader is
+    # never shuffled.
+    data_loader_test = DataLoader(
+        dataset=data_test,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=0,
+        collate_fn=None,
+        pin_memory=False,
+    )
+
+    return data_loader_train, data_loader_test
\ No newline at end of file
diff --git a/image_processing/build_dataset.py b/image_processing/build_dataset.py
index 2a127e0841ff7fef09c60301134ff6cb753da677..6f7c4f71cbd15ed21771b20a99dfcf8d8bde6ad6 100644
--- a/image_processing/build_dataset.py
+++ b/image_processing/build_dataset.py
@@ -72,6 +72,6 @@ def create_dataset():
         np.save(directory_path + "/" + name + '_' + analyse + '.npy', mat)
 
 
-
+#TODO : train val test split
 if __name__ =='__main__' :
-    label = create_antibio_dataset()
\ No newline at end of file
+    create_dataset()
\ No newline at end of file
diff --git a/main.py b/main.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..9c300061d4c3bea00ed1b7734324b9ad446b2d9a 100644
--- a/main.py
+++ b/main.py
@@ -0,0 +1,132 @@
+"""Train the image classifier and evaluate it on the held-out split."""
+import os
+
+from config.config import load_args
+from dataset.dataset import load_data
+import torch
+import torch.nn as nn
+from models.model import Classification_model
+import torch.optim as optim
+
+
+def train(model, data_train, optimizer, loss_function):
+    """Run one optimisation epoch and print the mean loss and accuracy.
+
+    Args:
+        model: network to optimise; batches are moved to its device.
+        data_train: loader yielding ``(images, class_indices)`` batches.
+        optimizer: optimiser built over ``model.parameters()``.
+        loss_function: criterion called as ``loss_function(logits, labels)``;
+            assumed to average over the batch (the ``CrossEntropyLoss`` default).
+
+    Returns:
+        tuple: ``(mean_loss, accuracy)`` over the samples seen.
+    """
+    model.train()
+    device = next(model.parameters()).device
+    losses = 0.
+    correct = 0
+    n_samples = 0
+
+    for im, label in data_train:
+        # CrossEntropyLoss expects integer class indices, not float targets.
+        im, label = im.to(device), label.to(device).long()
+        pred_logits = model(im)
+        loss = loss_function(pred_logits, label)
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        batch = label.size(0)
+        # Weight by batch size so a shorter last batch does not skew the mean.
+        losses += loss.item() * batch
+        correct += (pred_logits.argmax(dim=1) == label).sum().item()
+        n_samples += batch
+
+    losses = losses / max(n_samples, 1)
+    acc = correct / max(n_samples, 1)
+    print('loss : ', losses, ' acc : ', acc)
+    return losses, acc
+
+
+def test(model, data_train, loss_function):
+    """Evaluate the model without updating it and print loss and accuracy.
+
+    Args:
+        model: network to evaluate; batches are moved to its device.
+        data_train: loader yielding the evaluation batches (name kept for
+            compatibility with existing callers).
+        loss_function: criterion called as ``loss_function(logits, labels)``;
+            assumed to average over the batch (the ``CrossEntropyLoss`` default).
+
+    Returns:
+        tuple: ``(mean_loss, accuracy)`` over the samples seen.
+    """
+    model.eval()
+    device = next(model.parameters()).device
+    losses = 0.
+    correct = 0
+    n_samples = 0
+
+    # no_grad() skips graph construction without flipping requires_grad on the
+    # parameters, so the next training epoch needs no reset.
+    with torch.no_grad():
+        for im, label in data_train:
+            im, label = im.to(device), label.to(device).long()
+            pred_logits = model(im)
+            batch = label.size(0)
+            losses += loss_function(pred_logits, label).item() * batch
+            correct += (pred_logits.argmax(dim=1) == label).sum().item()
+            n_samples += batch
+
+    losses = losses / max(n_samples, 1)
+    acc = correct / max(n_samples, 1)
+    print('loss : ', losses, ' acc : ', acc)
+    return losses, acc
+
+
+def run(args):
+    """Train for ``args.epochs`` epochs, evaluating every ``args.eval_inter``.
+
+    Args:
+        args: namespace returned by ``config.config.load_args``.
+    """
+    data_train, data_test = load_data(base_dir=args.dataset_dir, batch_size=args.batch_size)
+    # random_split wraps the ImageFolder in a Subset, so the class list sits
+    # one level down.
+    n_class = len(data_train.dataset.dataset.classes)
+
+    model = Classification_model(n_class=n_class)
+    if args.pretrain_path is not None:
+        load_model(model, args.pretrain_path)
+    # train() and test() move every batch to the device the model lives on.
+    model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
+
+    loss_function = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
+
+    for e in range(args.epochs):
+        train(model, data_train, optimizer, loss_function)
+
+        if e % args.eval_inter == 0:
+            test(model, data_test, loss_function)
+    save_model(model, args.save_path)
+
+
+def save_model(model, path):
+    """Write the model's ``state_dict`` to ``path``, creating its directory."""
+    directory = os.path.dirname(path)
+    if directory:
+        os.makedirs(directory, exist_ok=True)
+    torch.save(model.state_dict(), path)
+
+
+def load_model(model, path):
+    """Load a ``state_dict`` saved by ``save_model`` into ``model`` in place."""
+    # map_location lets a checkpoint written on a GPU be restored on a CPU host.
+    model.load_state_dict(torch.load(path, map_location='cpu', weights_only=True))
+
+
+if __name__ == '__main__':
+    args = load_args()
+    run(args)
\ No newline at end of file
diff --git a/models/model.py b/models/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..78f0771ec5a21d3217ea969c6ec0b8fd533b4102
--- /dev/null
+++ b/models/model.py
@@ -0,0 +1,294 @@
+import torch
+import torch.nn as nn
+import torchvision
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    """Two 3x3 convolutions with an identity (or downsampled) shortcut."""
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    """1x1 -> 3x3 -> 1x1 convolutions with a shortcut; outputs ``4 * planes`` channels."""
+    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+    # This variant is also known as ResNet V1.5 and improves accuracy according to
+    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
+                 base_width=64, dilation=1, norm_layer=None):
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+    """Residual network: convolutional stem, four residual stages, global pooling, linear head."""
+
+    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
+                 norm_layer=None):
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x):
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+    def forward(self, x):
+        return self._forward_impl(x)
+
+
+def _resnet(block, layers, **kwargs):
+    """Build a ``ResNet`` from a block type and the number of blocks per stage."""
+    model = ResNet(block, layers, **kwargs)
+
+    return model
+
+
+def resnet18(num_classes=1000,**kwargs):
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        num_classes (int): number of outputs of the final linear layer
+        **kwargs: forwarded to the ``ResNet`` constructor
+    """
+    return _resnet(BasicBlock, [2, 2, 2, 2],num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet34(num_classes=1000, **kwargs):
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        num_classes (int): number of outputs of the final linear layer
+        **kwargs: forwarded to the ``ResNet`` constructor
+    """
+    return _resnet(BasicBlock, [3, 4, 6, 3],num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet50(num_classes=1000,**kwargs):
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        num_classes (int): number of outputs of the final linear layer
+        **kwargs: forwarded to the ``ResNet`` constructor
+    """
+    return _resnet(Bottleneck, [3, 4, 6, 3],num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet101(num_classes=1000,**kwargs):
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        num_classes (int): number of outputs of the final linear layer
+        **kwargs: forwarded to the ``ResNet`` constructor
+    """
+    return _resnet(Bottleneck, [3, 4, 23, 3],num_classes=num_classes,
+                   **kwargs)
+
+
+
+def resnet152(num_classes=1000,**kwargs):
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
+
+    Args:
+        num_classes (int): number of outputs of the final linear layer
+        **kwargs: forwarded to the ``ResNet`` constructor
+    """
+    return _resnet(Bottleneck, [3, 8, 36, 3],num_classes=num_classes,
+                   **kwargs)
+
+class Classification_model(nn.Module):
+    """Image classifier backed by a ResNet-18 encoder.
+
+    Args:
+        n_class (int): number of output classes (size of the logit vector).
+    """
+
+    def __init__(self,n_class):
+        # nn.Module must be initialised before sub-modules are assigned,
+        # otherwise setting self.im_encoder raises AttributeError.
+        super(Classification_model, self).__init__()
+        self.n_class = n_class
+        self.im_encoder = resnet18(num_classes=self.n_class)
+
+
+    def forward(self, input):
+        """Return the ``n_class`` logits produced by the encoder for ``input``."""
+        return self.im_encoder(input)
\ No newline at end of file