diff --git a/Train.py b/Train.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dba228921df62203922016621d56343398bc86d
--- /dev/null
+++ b/Train.py
@@ -0,0 +1,249 @@
+import os
+os.environ["CUDA_VISIBLE_DEVICES"]="1"
+import numpy as np
+import sys
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchvision
+import torch.nn.init as init
+import torch.utils.data as data
+import torch.utils.data.dataset as dataset
+import torchvision.datasets as dset
+import torchvision.transforms as transforms
+from torch.autograd import Variable
+import torchvision.utils as v_utils
+import matplotlib.pyplot as plt
+from tqdm.autonotebook import tqdm
+
+import cv2
+import math
+from collections import OrderedDict
+import copy
+import time
+
+import data.utils as data_utils
+
+import utils
+from models import *
+import models.loss as loss
+from torchsummary import summary
+import argparse
+#from torch.utils.tensorboard import SummaryWriter
+from tensorboardX import SummaryWriter
+
def main():
    """Train a 3D-convolutional (optionally memory-augmented) autoencoder for
    video anomaly detection.

    All configuration comes from command-line flags.  Metrics and sample
    reconstructions are written to TensorBoard under
    ``<exp_dir>/<dataset_type>/lr_..._entropyloss_..._version_...``; model
    checkpoints (``model-XXXX.pt`` and ``best_model.pt``) are saved there too.
    """
    torch.backends.cudnn.benchmark = True
    print("--------------PyTorch VERSION:", torch.__version__)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("..............device", device)

    parser = argparse.ArgumentParser(description="MemoryNormality")
    parser.add_argument('--gpus', nargs='+', type=str, help='gpus')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size for training')
    parser.add_argument('--val_batch_size', type=int, default=1, help='batch size for validation')
    parser.add_argument('--epochs', type=int, default=150, help='number of epochs for training')
    parser.add_argument('--val_epoch', type=int, default=1, help='evaluate the model every %d epoch')
    parser.add_argument('--h', type=int, default=256, help='height of input images')
    parser.add_argument('--w', type=int, default=256, help='width of input images')
    parser.add_argument('--c', type=int, default=1, help='channel of input images')
    parser.add_argument('--lr', type=float, default=2e-4, help='initial learning rate')
    parser.add_argument('--t_length', type=int, default=16, help='length of the frame sequences')
    parser.add_argument('--ModelName', help='AE/MemAE', type=str, default='AE')
    parser.add_argument('--ModelSetting', help='Conv3D/Conv3DSpar', type=str, default='Conv3D')  # give the layer details later
    parser.add_argument('--MemDim', help='Memory Dimention', type=int, default=2000)
    parser.add_argument('--EntropyLossWeight', help='EntropyLossWeight', type=float, default=0.0002)
    parser.add_argument('--ShrinkThres', help='ShrinkThres', type=float, default=0.0025)
    parser.add_argument('--Suffix', help='Suffix', type=str, default='Non')
    parser.add_argument('--num_workers', type=int, default=16, help='number of workers for the train loader')
    parser.add_argument('--num_workers_test', type=int, default=1, help='number of workers for the test loader')
    parser.add_argument('--dataset_type', type=str, default='i_LIDS', help='type of dataset: UCSDped2, avenue, Shanghai')
    parser.add_argument('--dataset_path', type=str, default='./dataset/', help='directory of data')
    parser.add_argument('--exp_dir', type=str, default='log', help='directory of log')
    parser.add_argument('--version', type=int, default=0, help='experiment version')

    args = parser.parse_args()

    torch.manual_seed(2020)  # fixed seed so runs are repeatable

    torch.backends.cudnn.enabled = True  # make sure to use cudnn for computational performance

    # Model names whose forward() returns a plain reconstruction tensor
    # (no memory module / attention output).
    plain_ae_names = ('AE', 'AE_conv_stride_jrnl', 'AE_conv_stride_jrnld', 'AE_conv_jrnld')

    def arrange_image(im_input):
        """Flatten a [batch, ch, time, h, w] array to [batch*time, ch, h, w]
        so every frame of a clip is logged as a separate image."""
        im_input = np.transpose(im_input, (0, 2, 1, 3, 4))
        b, t, ch, h, w = im_input.shape
        return np.reshape(im_input, [b * t, ch, h, w])

    train_folder, test_folder = data_utils.give_data_folder(args.dataset_type, args.dataset_path)

    print("The training path", train_folder)
    print("The testing path", test_folder)

    frame_trans = data_utils.give_frame_trans(args.dataset_type, [args.h, args.w])

    # time_step + num_pred frames per sample == args.t_length frames total.
    train_dataset = data_utils.DataLoader(train_folder, frame_trans, time_step=args.t_length - 1, num_pred=1)
    test_dataset = data_utils.DataLoader(test_folder, frame_trans, time_step=args.t_length - 1, num_pred=1)

    train_batch = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=args.num_workers, drop_last=True, pin_memory=True)
    test_batch = data.DataLoader(test_dataset, batch_size=args.val_batch_size,
                                 shuffle=False, num_workers=args.num_workers_test, drop_last=True, pin_memory=True)

    print("Training data shape", len(train_batch))
    print("Validation data shape", len(test_batch))

    # Model setting
    if args.ModelName == 'AE':
        model = AutoEncoderCov3D(args.c)
    elif args.ModelName == 'MemAE':
        model = AutoEncoderCov3DMem(args.c, args.MemDim, shrink_thres=args.ShrinkThres)
    elif args.ModelName == 'AE_conv_stride_jrnl':
        model = AECov3Dstrdjrnl(args.c)
    elif args.ModelName == 'AE_conv_stride_jrnld':
        model = AECov3Dstrdjrnld(args.c)
    elif args.ModelName == 'AE_conv_jrnld':
        model = AECov3Djrnld(args.c)
    else:
        # BUG FIX: the original assigned model = [] and only printed a warning,
        # which crashed later on model.apply() with a confusing AttributeError.
        raise ValueError('Unknown ModelName: %s' % args.ModelName)

    model.apply(utils.weights_init)
    model = model.to(device)
    # model = nn.DataParallel(model)
    # BUG FIX: use the configured channel count instead of a hard-coded 1.
    summary(model, (args.c, args.t_length, args.w, args.h))

    for name, p in model.named_parameters():
        if not p.requires_grad:
            print("---------NO GRADIENT-----", name)

    parameter_list = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(parameter_list, lr=args.lr, eps=1e-7, weight_decay=0.0)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.5)  # version 2

    # Report the training process
    log_dir = os.path.join(args.exp_dir, args.dataset_type, 'lr_%.5f_entropyloss_%.5f_version_%d' % (
        args.lr, args.EntropyLossWeight, args.version))
    os.makedirs(log_dir, exist_ok=True)

    for arg in vars(args):
        print(arg, getattr(args, arg))

    train_writer = SummaryWriter(log_dir=log_dir)

    # Training
    best_train_loss = 10000
    best_val_loss = 10000
    for epoch in range(args.epochs):
        model.train()
        tr_re_loss, tr_mem_loss, tr_tot = 0.0, 0.0, 0.0
        ts_tot = 0.0

        # Iterating the tqdm object advances the bar by itself; the extra
        # progress_bar.update() call in the original double-counted every step.
        for batch_idx, frame in enumerate(tqdm(train_batch)):
            frame = frame.reshape([args.batch_size, args.t_length, args.c, args.h, args.w])
            frame = frame.permute(0, 2, 1, 3, 4)  # -> [B, C, T, H, W]
            frame = frame.to(device)
            optimizer.zero_grad()

            if args.ModelName == 'MemAE':
                model_output = model(frame)
                recons, attr = model_output['output'], model_output['att']
                re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                mem_loss = loss.get_memory_loss(attr)
                tot_loss = re_loss + mem_loss * args.EntropyLossWeight
                tr_mem_loss += mem_loss.item()
            elif args.ModelName in plain_ae_names:
                recons = model(frame)
                re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                tot_loss = re_loss  # no memory-sparsity term for plain AEs
            tr_re_loss += re_loss.item()
            tr_tot += tot_loss.item()
            tot_loss.backward()
            optimizer.step()

        train_writer.add_scalar("model/train-recons-loss", tr_re_loss / len(train_batch), epoch)
        train_writer.add_scalar("model/train-memory-sparse", tr_mem_loss / len(train_batch), epoch)
        train_writer.add_scalar("model/train-total-loss", tr_tot / len(train_batch), epoch)
        scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']
        train_writer.add_scalar('learning_rate', current_lr, epoch)

        if epoch % args.val_epoch == 0:
            model.eval()
            re_loss_val, mem_loss_val = 0.0, 0.0
            # BUG FIX: run validation under no_grad so the forward passes do
            # not build autograd graphs (the original kept them alive).
            with torch.no_grad():
                for batch_idx, frame in enumerate(test_batch):
                    frame = frame.reshape([args.val_batch_size, args.t_length, args.c, args.h, args.w])
                    frame = frame.permute(0, 2, 1, 3, 4)
                    frame = frame.to(device)
                    if args.ModelName == 'MemAE':
                        model_output = model(frame)
                        recons, attr = model_output['output'], model_output['att']
                        re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                        batch_mem_loss = loss.get_memory_loss(attr).item()
                    else:
                        recons = model(frame)
                        re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                        batch_mem_loss = 0.0
                    re_loss_val += re_loss.item()
                    mem_loss_val += batch_mem_loss
                    # BUG FIX: accumulate the *per-batch* total loss.  The
                    # original summed the running totals (inflating ts_tot
                    # quadratically) and never updated ts_tot for MemAE, so
                    # best-model selection compared a meaningless value.
                    ts_tot += re_loss.item() + batch_mem_loss * args.EntropyLossWeight

                    # Log a handful of input/reconstruction pairs spread
                    # across the validation set.
                    if batch_idx in (0, 10, len(test_batch) - 1,
                                     int(len(test_batch) / 2), int(len(test_batch) / 4)):
                        _input_npy = frame.detach().cpu().numpy() * 0.5 + 0.5    # un-normalize
                        _recons_npy = recons.detach().cpu().numpy() * 0.5 + 0.5  # [batch_size, ch, time, imh, imw]
                        train_writer.add_images("image/input_image", arrange_image(_input_npy), epoch)
                        train_writer.add_images("image/reconstruction", arrange_image(_recons_npy), epoch)
            train_writer.add_scalar("model/val-recons-loss", re_loss_val / len(test_batch), epoch)
            train_writer.add_scalar("model/val-memory-sparse", mem_loss_val / len(test_batch), epoch)
            print("epoch %d" % epoch,
                  "total loss training %.4f validation %.4f" % (tr_tot / len(train_batch), ts_tot / len(test_batch)),
                  "recons loss training %.4f validation %.4f" % (tr_re_loss / len(train_batch), re_loss_val / len(test_batch)),
                  "memory sparsity training %.4f validation %.4f" % (tr_mem_loss / len(train_batch), mem_loss_val / len(test_batch)))
            if (epoch % 10 == 0 or epoch == args.epochs - 1) and args.dataset_type == 'i_LIDS':
                torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict()},
                           os.path.join(log_dir, "model-{:04d}.pt".format(epoch)))
            if (tr_tot / len(train_batch) < best_train_loss) and (ts_tot / len(test_batch) <= best_val_loss):
                print("Best model is at epoch ", epoch)
                torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict()},
                           os.path.join(log_dir, "best_model.pt"))
                best_train_loss = tr_tot / len(train_batch)
                best_val_loss = ts_tot / len(test_batch)
+
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()