# Train.py
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import numpy as np
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch.nn.init as init
import torch.utils.data as data
import torch.utils.data.dataset as dataset
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.utils as v_utils
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm

import cv2
import math
from collections import OrderedDict
import copy
import time

import data.utils as data_utils

import utils
from models import *
import models.loss as loss
from torchsummary import summary
import argparse
#from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter

def _arrange_image(im_input):
    """Flatten a 5-D clip array into a stack of frames for image logging.

    Axes 1 and 2 are swapped and the first two resulting axes are merged, so
    an input of shape ``(b, ch, t, h, w)`` becomes ``(b * t, ch, h, w)``.
    """
    im_input = np.transpose(im_input, (0, 2, 1, 3, 4))
    b, t, ch, h, w = im_input.shape
    return np.reshape(im_input, [b * t, ch, h, w])


def _build_model(args):
    """Instantiate the network selected by ``args.ModelName``.

    Args:
        args: Parsed command-line namespace. Reads ``ModelName`` and ``c``,
            plus ``MemDim`` and ``ShrinkThres`` for ``MemAE``.

    Returns:
        The newly constructed model (weights not yet initialised, not yet
        moved to a device).

    Raises:
        ValueError: If ``args.ModelName`` is not a supported name. Failing
            here gives a clear message instead of an AttributeError later.
    """
    name = args.ModelName
    if name == 'AE':
        return AutoEncoderCov3D(args.c)
    if name == 'MemAE':
        return AutoEncoderCov3DMem(args.c, args.MemDim, shrink_thres=args.ShrinkThres)
    if name == 'AE_conv_stride_jrnl':
        return AECov3Dstrdjrnl(args.c)
    if name == 'AE_conv_stride_jrnld':
        return AECov3Dstrdjrnld(args.c)
    if name == 'AE_conv_jrnld':
        return AECov3Djrnld(args.c)
    raise ValueError(
        "Wrong Name. Unknown --ModelName %r; expected one of AE, MemAE, "
        "AE_conv_stride_jrnl, AE_conv_stride_jrnld, AE_conv_jrnld" % (name,))


def _forward_losses(model, frame, model_name, entropy_loss_weight):
    """Run one forward pass on ``frame`` and compute its losses.

    Args:
        model: The network to evaluate.
        frame: Input batch, already on the model's device.
        model_name: Value of ``--ModelName``; ``'MemAE'`` selects the branch
            that also returns attention weights and a memory loss.
        entropy_loss_weight: Multiplier applied to the memory loss.

    Returns:
        Tuple ``(recons, re_loss, mem_loss_value, tot_loss)`` where
        ``re_loss`` and ``tot_loss`` are tensors and ``mem_loss_value`` is a
        Python float (``0.0`` for models without a memory module).
    """
    if model_name == 'MemAE':
        model_output = model(frame)
        recons, attr = model_output['output'], model_output['att']
        re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
        mem_loss = loss.get_memory_loss(attr)
        tot_loss = re_loss + mem_loss * entropy_loss_weight
        return recons, re_loss, mem_loss.item(), tot_loss

    recons = model(frame)
    re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
    # Without a memory term the total loss is just the reconstruction loss.
    return recons, re_loss, 0.0, re_loss


def main():
    """Train an autoencoder on frame sequences and validate it periodically.

    Parses the command-line options, builds the train/test loaders and the
    selected model, then runs ``--epochs`` epochs of Adam optimisation.
    Per-epoch losses, the learning rate and sample reconstructions are written
    to a ``SummaryWriter`` under the experiment directory. Checkpoints are
    saved every 10 epochs (only for the ``i_LIDS`` dataset type) and whenever
    the training loss improves while the validation loss does not get worse.

    Raises:
        ValueError: If ``--ModelName`` is unknown, or if either data loader
            yields no batches.
    """
    torch.backends.cudnn.benchmark = True
    print("--------------PyTorch VERSION:", torch.__version__)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("..............device", device)

    parser = argparse.ArgumentParser(description="MemoryNormality")
    # NOTE(review): --gpus is parsed but never read inside this function.
    parser.add_argument('--gpus', nargs='+', type=str, help='gpus')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size for training')
    parser.add_argument('--val_batch_size', type=int, default=1, help='batch size for validation')
    parser.add_argument('--epochs', type=int, default=150, help='number of epochs for training')
    # argparse %-formats help strings, so a literal percent sign must be
    # escaped; a bare "%d" makes --help raise TypeError.
    parser.add_argument('--val_epoch', type=int, default=1, help='evaluate the model every %%d epoch')
    parser.add_argument('--h', type=int, default=256, help='height of input images')
    parser.add_argument('--w', type=int, default=256, help='width of input images')
    parser.add_argument('--c', type=int, default=1, help='channel of input images')
    parser.add_argument('--lr', type=float, default=2e-4, help='initial learning rate')
    parser.add_argument('--t_length', type=int, default=16, help='length of the frame sequences')
    parser.add_argument('--ModelName', help='AE/MemAE', type=str, default='AE')
    parser.add_argument('--ModelSetting', help='Conv3D/Conv3DSpar', type=str, default='Conv3D')
    parser.add_argument('--MemDim', help='Memory Dimention', type=int, default=2000)
    parser.add_argument('--EntropyLossWeight', help='EntropyLossWeight', type=float, default=0.0002)
    parser.add_argument('--ShrinkThres', help='ShrinkThres', type=float, default=0.0025)
    parser.add_argument('--Suffix', help='Suffix', type=str, default='Non')
    parser.add_argument('--num_workers', type=int, default=16, help='number of workers for the train loader')
    parser.add_argument('--num_workers_test', type=int, default=1, help='number of workers for the test loader')
    parser.add_argument('--dataset_type', type=str, default='i_LIDS', help='type of dataset: UCSDped2, avenue, Shanghai')
    parser.add_argument('--dataset_path', type=str, default='./dataset/', help='directory of data')
    parser.add_argument('--exp_dir', type=str, default='log', help='directory of log')
    parser.add_argument('--version', type=int, default=0, help='experiment version')

    args = parser.parse_args()

    torch.manual_seed(2020)

    torch.backends.cudnn.enabled = True  # make sure to use cudnn for computational performance

    train_folder, test_folder = data_utils.give_data_folder(args.dataset_type, args.dataset_path)

    print("The training path", train_folder)
    print("The testing path", test_folder)

    frame_trans = data_utils.give_frame_trans(args.dataset_type, [args.h, args.w])

    train_dataset = data_utils.DataLoader(train_folder, frame_trans, time_step=args.t_length - 1, num_pred=1)
    test_dataset = data_utils.DataLoader(test_folder, frame_trans, time_step=args.t_length - 1, num_pred=1)

    train_batch = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=args.num_workers, drop_last=True, pin_memory=True)
    test_batch = data.DataLoader(test_dataset, batch_size=args.val_batch_size,
                                 shuffle=False, num_workers=args.num_workers_test, drop_last=True, pin_memory=True)

    n_train = len(train_batch)
    n_val = len(test_batch)
    print("Training data shape", n_train)
    print("Validation data shape", n_val)
    # Every per-epoch average divides by these counts; fail early with a clear
    # message rather than with a ZeroDivisionError after the first epoch.
    if n_train == 0 or n_val == 0:
        raise ValueError(
            "Empty data loader (train batches: %d, validation batches: %d); "
            "check --dataset_path and the batch sizes" % (n_train, n_val))

    # Model setting
    model = _build_model(args)
    model.apply(utils.weights_init)
    model = model.to(device)
    # Same per-sample layout the model is fed below: (channels, time, height, width).
    summary(model, (args.c, args.t_length, args.h, args.w))

    for name, p in model.named_parameters():
        if not p.requires_grad:
            print("---------NO GRADIENT-----", name)

    parameter_list = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(parameter_list, lr=args.lr, eps=1e-7, weight_decay=0.0)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.5)  # version 2

    # Report the training process
    log_dir = os.path.join(args.exp_dir, args.dataset_type, 'lr_%.5f_entropyloss_%.5f_version_%d' % (
        args.lr, args.EntropyLossWeight, args.version))
    os.makedirs(log_dir, exist_ok=True)

    for arg in vars(args):
        print(arg, getattr(args, arg))

    # Validation batches whose input/reconstruction are logged as images.
    snapshot_batches = {0, 10, n_val - 1, n_val // 2, n_val // 4}

    train_writer = SummaryWriter(log_dir=log_dir)
    try:
        # Training
        best_train_loss = 10000
        best_val_loss = 10000
        for epoch in range(args.epochs):
            model.train()
            tr_re_loss, tr_mem_loss, tr_tot = 0.0, 0.0, 0.0
            ts_tot = 0.0

            # Iterating the tqdm wrapper already advances the bar; no manual
            # update() call is needed (it would count every batch twice).
            for frame in tqdm(train_batch):
                frame = frame.reshape([args.batch_size, args.t_length, args.c, args.h, args.w])
                frame = frame.permute(0, 2, 1, 3, 4)
                frame = frame.to(device)
                optimizer.zero_grad()

                _, re_loss, mem_value, tot_loss = _forward_losses(
                    model, frame, args.ModelName, args.EntropyLossWeight)
                tr_re_loss += re_loss.item()
                tr_mem_loss += mem_value
                tr_tot += tot_loss.item()
                tot_loss.backward()
                optimizer.step()

            train_writer.add_scalar("model/train-recons-loss", tr_re_loss / n_train, epoch)
            train_writer.add_scalar("model/train-memory-sparse", tr_mem_loss / n_train, epoch)
            train_writer.add_scalar("model/train-total-loss", tr_tot / n_train, epoch)
            scheduler.step()

            current_lr = optimizer.param_groups[0]['lr']
            train_writer.add_scalar('learning_rate', current_lr, epoch)

            if epoch % args.val_epoch == 0:
                model.eval()
                re_loss_val, mem_loss_val = 0.0, 0.0
                # No gradients are needed for validation.
                with torch.no_grad():
                    for batch_idx, frame in enumerate(test_batch):
                        frame = frame.reshape([args.val_batch_size, args.t_length, args.c, args.h, args.w])
                        frame = frame.permute(0, 2, 1, 3, 4)
                        frame = frame.to(device)

                        recons, re_loss, mem_value, tot_loss = _forward_losses(
                            model, frame, args.ModelName, args.EntropyLossWeight)
                        re_loss_val += re_loss.item()
                        mem_loss_val += mem_value
                        # Accumulate the per-batch total for every model type
                        # so the best-model check sees a real validation loss.
                        ts_tot += tot_loss.item()

                        if batch_idx in snapshot_batches:
                            # Undo the mean=0.5 / std=0.5 scaling used by the loss calls.
                            _input_npy = frame.detach().cpu().numpy() * 0.5 + 0.5
                            _recons_npy = recons.detach().cpu().numpy() * 0.5 + 0.5  # [batch_size, ch, time, imh, imw]
                            train_writer.add_images("image/input_image", _arrange_image(_input_npy), epoch)
                            train_writer.add_images("image/reconstruction", _arrange_image(_recons_npy), epoch)

                train_writer.add_scalar("model/val-recons-loss", re_loss_val / n_val, epoch)
                train_writer.add_scalar("model/val-memory-sparse", mem_loss_val / n_val, epoch)
                print("epoch %d" % epoch, "total loss training %.4f validation %.4f" % (tr_tot / n_train, ts_tot / n_val),
                      "recons loss training %.4f validation %.4f" % (tr_re_loss / n_train, re_loss_val / n_val),
                      "memory sparsity training %.4f validation %.4f" % (tr_mem_loss / n_train, mem_loss_val / n_val))

                if (epoch % 10 == 0 or epoch == args.epochs - 1) and args.dataset_type == 'i_LIDS':
                    torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict()},
                               os.path.join(log_dir, "model-{:04d}.pt".format(epoch)))
                if (tr_tot / n_train < best_train_loss) and (ts_tot / n_val <= best_val_loss):
                    print("Best model is at epoch ", epoch)
                    torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict()},
                               os.path.join(log_dir, "best_model.pt"))
                    best_train_loss = tr_tot / n_train
                    best_val_loss = ts_tot / n_val
    finally:
        # Flush and release the event file even if training is interrupted.
        train_writer.close()


# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()