diff --git a/models/ae_3dconv_stride_jrnl.py b/models/ae_3dconv_stride_jrnl.py new file mode 100644 index 0000000000000000000000000000000000000000..a374ccc488c4e0cebca2589cd16ce74de4dce8b7 --- /dev/null +++ b/models/ae_3dconv_stride_jrnl.py @@ -0,0 +1,69 @@ +from __future__ import absolute_import, print_function +from torchsummary import summary +import os + +#os.environ["CUDA_VISIBLE_DEVICES"] = "1" +import torch +from torch import nn + + +class AECov3Dstrdjrnl(nn.Module): + def __init__(self, chnum_in): + super(AECov3Dstrdjrnl, self).__init__() + self.chnum_in = chnum_in # input channel number is 1; + feature_num = 8 # inc: 12 , dec: 8, inc: 16 + feature_num_2 = 16 # inc: 8 , dec: 16, inc: 8 + + self.encoder = nn.Sequential( + nn.Conv3d(self.chnum_in, feature_num_2, (5,3,3), stride=(2, 2, 2), padding=(2, 1, 1)), + nn.ReLU(inplace=True), + nn.Dropout(p=0.25), + nn.Conv3d(feature_num_2, feature_num, (5,3,3), stride=(2,2,2), padding=(2,1,1)), + nn.ReLU(inplace=True) + ) + self.decoder = nn.Sequential( + nn.ConvTranspose3d(feature_num, feature_num, (5,3,3), stride=(2,2,2), padding=(2,1,1), output_padding=(1,1,1)), + nn.ReLU(inplace=True), + nn.ConvTranspose3d(feature_num, feature_num_2, (5, 3, 3), stride=(2, 2, 2), padding=(2, 1, 1), output_padding=(1, 1, 1)), + nn.ReLU(inplace=True), + nn.ConvTranspose3d(feature_num_2, self.chnum_in, (5,3,3), stride=(1,1,1), padding=(2,1,1), output_padding=(0,0,0)), + nn.Tanh() + ) + ''' + self.encoder = nn.Sequential( + nn.Conv3d(self.chnum_in, feature_num_2, (5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1)), + nn.ReLU(inplace=True), + nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)), + nn.Dropout(p=0.25, inplace=True), + + nn.Conv3d(feature_num_2, feature_num, (5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1)), + nn.ReLU(inplace=True), + nn.MaxPool3d((2, 2, 2), stride=(2, 2, 2)) + ) + self.decoder = nn.Sequential( + nn.Conv3d(feature_num, feature_num, (5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1)), + nn.ReLU(inplace=True), + nn.Upsample(scale_factor=(2, 2, 2)), + + nn.Conv3d(feature_num, feature_num_2, (5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1)), + nn.ReLU(inplace=True), + nn.Upsample(scale_factor=(2, 2, 2)), + + nn.Conv3d(feature_num_2, self.chnum_in, (5, 3, 3), stride=(1, 1, 1), padding=(2, 1, 1)), + nn.Tanh() + )''' + + def forward(self, x): + f = self.encoder(x) + out = self.decoder(f) + # out = out[:, :, 5:20, 5:20] + return out + + +if __name__ == '__main__': + device = torch.device("cuda" if True else "cpu") + chnum_in_ = 1 + model = AECov3Dstrdjrnl(chnum_in_) + model.to(device) + + print(summary(model, (1, 8, 692, 560)))