import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # pin to GPU 1 before torch initializes CUDA
import numpy as np
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch.nn.init as init
import torch.utils.data as data
import torch.utils.data.dataset as dataset
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.utils as v_utils
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm

import cv2
import math
from collections import OrderedDict
import copy
import time

import data.utils as data_utils

import utils
from models import *
import models.loss as loss
from torchsummary import summary
import argparse
#from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter


def main():
    """Train a 3D-conv (Memory-)AutoEncoder for video anomaly detection.

    Reads hyper-parameters from the command line, builds train/test frame-cube
    loaders, trains for ``--epochs`` epochs, periodically validates, and logs
    scalar losses plus sample reconstructions to TensorBoard.  Checkpoints go
    to ``<exp_dir>/<dataset_type>/lr_*_entropyloss_*_version_*``.
    """
    torch.backends.cudnn.benchmark = True
    print("--------------PyTorch VERSION:", torch.__version__)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("..............device", device)

    parser = argparse.ArgumentParser(description="MemoryNormality")
    parser.add_argument('--gpus', nargs='+', type=str, help='gpus')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size for training')
    parser.add_argument('--val_batch_size', type=int, default=1, help='batch size for validation')
    parser.add_argument('--epochs', type=int, default=150, help='number of epochs for training')
    parser.add_argument('--val_epoch', type=int, default=1, help='evaluate the model every %d epoch')
    parser.add_argument('--h', type=int, default=256, help='height of input images')
    parser.add_argument('--w', type=int, default=256, help='width of input images')
    parser.add_argument('--c', type=int, default=1, help='channel of input images')
    parser.add_argument('--lr', type=float, default=2e-4, help='initial learning rate')
    parser.add_argument('--t_length', type=int, default=16, help='length of the frame sequences')
    parser.add_argument('--ModelName', help='AE/MemAE', type=str, default='AE')
    parser.add_argument('--ModelSetting', help='Conv3D/Conv3DSpar', type=str, default='Conv3D')  # give the layer details later
    parser.add_argument('--MemDim', help='Memory Dimention', type=int, default=2000)
    parser.add_argument('--EntropyLossWeight', help='EntropyLossWeight', type=float, default=0.0002)
    parser.add_argument('--ShrinkThres', help='ShrinkThres', type=float, default=0.0025)
    parser.add_argument('--Suffix', help='Suffix', type=str, default='Non')
    parser.add_argument('--num_workers', type=int, default=16, help='number of workers for the train loader')
    parser.add_argument('--num_workers_test', type=int, default=1, help='number of workers for the test loader')
    parser.add_argument('--dataset_type', type=str, default='i_LIDS', help='type of dataset: UCSDped2, avenue, Shanghai')
    parser.add_argument('--dataset_path', type=str, default='./dataset/', help='directory of data')
    parser.add_argument('--exp_dir', type=str, default='log', help='directory of log')
    parser.add_argument('--version', type=int, default=0, help='experiment version')

    args = parser.parse_args()

    torch.manual_seed(2020)
    torch.backends.cudnn.enabled = True  # make sure to use cudnn for computational performance

    def arrange_image(im_input):
        """Flatten a [b, ch, t, h, w] numpy batch into [b*t, ch, h, w] for add_images."""
        im_input = np.transpose(im_input, (0, 2, 1, 3, 4))
        b, t, ch, h, w = im_input.shape
        return np.reshape(im_input, [b * t, ch, h, w])

    train_folder, test_folder = data_utils.give_data_folder(args.dataset_type, args.dataset_path)
    print("The training path", train_folder)
    print("The testing path", test_folder)

    frame_trans = data_utils.give_frame_trans(args.dataset_type, [args.h, args.w])

    # time_step + num_pred == t_length frames per sample cube
    train_dataset = data_utils.DataLoader(train_folder, frame_trans, time_step=args.t_length - 1, num_pred=1)
    test_dataset = data_utils.DataLoader(test_folder, frame_trans, time_step=args.t_length - 1, num_pred=1)

    train_batch = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=args.num_workers, drop_last=True, pin_memory=True)
    test_batch = data.DataLoader(test_dataset, batch_size=args.val_batch_size,
                                 shuffle=False, num_workers=args.num_workers_test, drop_last=True, pin_memory=True)

    print("Training data shape", len(train_batch))
    print("Validation data shape", len(test_batch))

    # Model setting
    if args.ModelName == 'AE':
        model = AutoEncoderCov3D(args.c)
    elif args.ModelName == 'MemAE':
        model = AutoEncoderCov3DMem(args.c, args.MemDim, shrink_thres=args.ShrinkThres)
    elif args.ModelName == 'AE_conv_stride_jrnl':
        model = AECov3Dstrdjrnl(args.c)
    elif args.ModelName == 'AE_conv_stride_jrnld':
        model = AECov3Dstrdjrnld(args.c)
    elif args.ModelName == 'AE_conv_jrnld':
        model = AECov3Djrnld(args.c)
    else:
        # BUGFIX: the original set model = [] and kept going, which crashed
        # later with an opaque AttributeError on .apply(); fail fast instead.
        raise ValueError("Wrong --ModelName: %s" % args.ModelName)

    model.apply(utils.weights_init)
    model = model.to(device)
    # model = nn.DataParallel(model)
    # BUGFIX: channel count was hard-coded to 1; use args.c (default is still 1).
    summary(model, (args.c, args.t_length, args.w, args.h))

    for name, p in model.named_parameters():
        if not p.requires_grad:
            print("---------NO GRADIENT-----", name)

    parameter_list = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(parameter_list, lr=args.lr, eps=1e-7, weight_decay=0.0)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.5)  # version 2

    # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Report the training process
    log_dir = os.path.join(args.exp_dir, args.dataset_type, 'lr_%.5f_entropyloss_%.5f_version_%d' % (
        args.lr, args.EntropyLossWeight, args.version))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    # orig_stdout = sys.stdout
    # f = open(os.path.join(log_dir, 'log.txt'), 'w')
    # sys.stdout = f

    for arg in vars(args):
        print(arg, getattr(args, arg))

    train_writer = SummaryWriter(log_dir=log_dir)

    # model variants that return a plain reconstruction tensor (no memory module)
    plain_ae_names = ('AE', 'AE_conv_stride_jrnl', 'AE_conv_stride_jrnld', 'AE_conv_jrnld')

    # Training
    best_train_loss = float('inf')
    best_val_loss = float('inf')
    for epoch in range(args.epochs):
        model.train()
        tr_re_loss, tr_mem_loss, tr_tot = 0.0, 0.0, 0.0

        # BUGFIX: iterating a tqdm object already advances it; the original
        # also called progress_bar.update(), double-counting every step.
        for frame in tqdm(train_batch):
            frame = frame.reshape([args.batch_size, args.t_length, args.c, args.h, args.w])
            frame = frame.permute(0, 2, 1, 3, 4)  # -> [b, c, t, h, w]
            frame = frame.to(device)
            optimizer.zero_grad()

            if args.ModelName == 'MemAE':
                model_output = model(frame)
                recons, attr = model_output['output'], model_output['att']
                re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                mem_loss = loss.get_memory_loss(attr)
                tot_loss = re_loss + mem_loss * args.EntropyLossWeight
                tr_mem_loss += mem_loss.item()
            else:  # plain AE variants: reconstruction loss only
                assert args.ModelName in plain_ae_names
                recons = model(frame)
                re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                tot_loss = re_loss
            tr_re_loss += re_loss.item()
            tr_tot += tot_loss.item()
            tot_loss.backward()
            optimizer.step()

        train_writer.add_scalar("model/train-recons-loss", tr_re_loss / len(train_batch), epoch)
        train_writer.add_scalar("model/train-memory-sparse", tr_mem_loss / len(train_batch), epoch)
        train_writer.add_scalar("model/train-total-loss", tr_tot / len(train_batch), epoch)
        scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']
        train_writer.add_scalar('learning_rate', current_lr, epoch)

        if epoch % args.val_epoch == 0:
            model.eval()
            re_loss_val, mem_loss_val, ts_tot = 0.0, 0.0, 0.0
            # BUGFIX: evaluation must not build autograd graphs.
            with torch.no_grad():
                for batch_idx, frame in enumerate(test_batch):
                    frame = frame.reshape([args.val_batch_size, args.t_length, args.c, args.h, args.w])
                    frame = frame.permute(0, 2, 1, 3, 4)
                    frame = frame.to(device)
                    if args.ModelName == 'MemAE':
                        model_output = model(frame)
                        recons, attr = model_output['output'], model_output['att']
                        re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                        mem_loss = loss.get_memory_loss(attr)
                        mem_loss_val += mem_loss.item()
                        batch_tot = re_loss.item() + mem_loss.item() * args.EntropyLossWeight
                    else:
                        recons = model(frame)
                        re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                        batch_tot = re_loss.item()
                    re_loss_val += re_loss.item()
                    # BUGFIX: accumulate the *per-batch* total loss; the original
                    # re-added the running sums every batch, so ts_tot grew
                    # quadratically and corrupted the best-model selection.
                    ts_tot += batch_tot

                    # Snapshot a few fixed batches as image grids for TensorBoard.
                    snapshot_ids = {0, 10, len(test_batch) - 1, len(test_batch) // 2, len(test_batch) // 4}
                    if batch_idx in snapshot_ids:
                        _input_npy = frame.detach().cpu().numpy() * 0.5 + 0.5  # de-normalize
                        _recons_npy = recons.detach().cpu().numpy() * 0.5 + 0.5  # [b, ch, t, h, w]
                        train_writer.add_images("image/input_image", arrange_image(_input_npy), epoch)
                        train_writer.add_images("image/reconstruction", arrange_image(_recons_npy), epoch)

            train_writer.add_scalar("model/val-recons-loss", re_loss_val / len(test_batch), epoch)
            train_writer.add_scalar("model/val-memory-sparse", mem_loss_val / len(test_batch), epoch)
            print("epoch %d" % epoch,
                  "total loss training %.4f validation %.4f" % (tr_tot / len(train_batch), ts_tot / len(test_batch)),
                  "recons loss training %.4f validation %.4f" % (tr_re_loss / len(train_batch), re_loss_val / len(test_batch)),
                  "memory sparsity training %.4f validation %.4f" % (tr_mem_loss / len(train_batch), mem_loss_val / len(test_batch)))
            if (epoch % 10 == 0 or epoch == args.epochs - 1) and args.dataset_type == 'i_LIDS':
                torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict()},
                           os.path.join(log_dir, "model-{:04d}.pt".format(epoch)))
            # "Best" = train loss strictly improved AND val loss did not regress.
            if (tr_tot / len(train_batch) < best_train_loss) and (ts_tot / len(test_batch) <= best_val_loss):
                print("Best model is at epoch ", epoch)
                torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict()},
                           os.path.join(log_dir, "best_model.pt"))
                best_train_loss = tr_tot / len(train_batch)
                best_val_loss = ts_tot / len(test_batch)

    # sys.stdout = orig_stdout
    # f.close()


if __name__ == '__main__':
    main()