Skip to content
Snippets Groups Projects
Commit 279fbb46 authored by Devashish Lohani's avatar Devashish Lohani
Browse files

Upload New File

parent b45ba41c
No related branches found
No related tags found
No related merge requests found
from __future__ import absolute_import, print_function
from torchsummary import summary
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
from torch import nn
class AECov3Dstrdjrnl(nn.Module):
def __init__(self, chnum_in):
super(AECov3Dstrdjrnl, self).__init__()
self.chnum_in = chnum_in # input channel number is 1;
feature_num = 8 # inc: 12 , dec: 8, inc: 16
feature_num_2 = 16 # inc: 8 , dec: 16, inc: 8
self.encoder = nn.Sequential(
nn.Conv3d(self.chnum_in, feature_num_2, (5,3,3), stride=(2, 2, 2), padding=(2, 1, 1)),
nn.ReLU(inplace=True),
nn.Dropout(p=0.25),
nn.Conv3d(feature_num_2, feature_num, (5,3,3), stride=(2,2,2), padding=(2,1,1)),
nn.ReLU(inplace=True)
)
self.decoder = nn.Sequential(
nn.ConvTranspose3d(feature_num, feature_num, (5,3,3), stride=(2,2,2), padding=(2,1,1), output_padding=(1,1,1)),
nn.ReLU(inplace=True),
nn.ConvTranspose3d(feature_num, feature_num_2, (5, 3, 3), stride=(2, 2, 2), padding=(2, 1, 1), output_padding=(1, 1, 1)),
nn.ReLU(inplace=True),
nn.ConvTranspose3d(feature_num_2, self.chnum_in, (5,3,3), stride=(1,1,1), padding=(2,1,1), output_padding=(0,0,0)),
nn.Tanh()
)
'''
self.encoder = nn.Sequential(
nn.Conv3d(self.chnum_in, feature_num_2, (5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1)),
nn.ReLU(inplace=True),
nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)),
nn.Dropout(p=0.25, inplace=True),
nn.Conv3d(feature_num_2, feature_num, (5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1)),
nn.ReLU(inplace=True),
nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2))
)
self.decoder = nn.Sequential(
nn.Conv3d(feature_num, feature_num, (5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1)),
nn.ReLU(inplace=True),
nn.Upsample(scale_factor=(2, 2, 2)),
nn.Conv3d(feature_num, feature_num_2, (5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1)),
nn.ReLU(inplace=True),
nn.Upsample(scale_factor=(2, 2, 2)),
nn.Conv3d(feature_num_2, self.chnum_in, (5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1)),
nn.Tanh()
)'''
def forward(self, x):
f = self.encoder(x)
out = self.decoder(f)
# out = out[:, :, 5:20, 5:20]
return out
if __name__ == '__main__':
device = torch.device("cuda" if True else "cpu")
chnum_in_ = 1
model = AECov3Dstrdjrnl(chnum_in_)
model.to(device)
print(summary(model, (1, 8, 692, 560)))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment