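"""Training script for 3D convolutional autoencoder (AE) and memory-augmented autoencoder (MemAE)
video anomaly detection models.

Example invocation (arguments as defined by the argparse options below):
    python Train.py --ModelName MemAE --dataset_type UCSDped2 --batch_size 16 --epochs 150
"""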
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
import numpy as np
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torch.nn.init as init
import torch.utils.data as data
import torch.utils.data.dataset as dataset
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
import torchvision.utils as v_utils
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
import cv2
import math
from collections import OrderedDict
import copy
import time
import data.utils as data_utils
import utils
from models import *
import models.loss as loss
from torchsummary import summary
import argparse
#from torch.utils.tensorboard import SummaryWriter
from tensorboardX import SummaryWriter
def main():
    torch.backends.cudnn.benchmark = True
    print("--------------PyTorch VERSION:", torch.__version__)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("..............device", device)
    parser = argparse.ArgumentParser(description="MemoryNormality")
    parser.add_argument('--gpus', nargs='+', type=str, help='gpus')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size for training')
    parser.add_argument('--val_batch_size', type=int, default=1, help='batch size for validation')
    parser.add_argument('--epochs', type=int, default=150, help='number of epochs for training')
    parser.add_argument('--val_epoch', type=int, default=1, help='evaluate the model every N epochs')
    parser.add_argument('--h', type=int, default=256, help='height of input images')
    parser.add_argument('--w', type=int, default=256, help='width of input images')
    parser.add_argument('--c', type=int, default=1, help='channel of input images')
    parser.add_argument('--lr', type=float, default=2e-4, help='initial learning rate')
    parser.add_argument('--t_length', type=int, default=16, help='length of the frame sequences')
    parser.add_argument('--ModelName', help='AE/MemAE', type=str, default='AE')
    parser.add_argument('--ModelSetting', help='Conv3D/Conv3DSpar', type=str, default='Conv3D')  # give the layer details later
    parser.add_argument('--MemDim', help='Memory Dimension', type=int, default=2000)
    parser.add_argument('--EntropyLossWeight', help='EntropyLossWeight', type=float, default=0.0002)
    parser.add_argument('--ShrinkThres', help='ShrinkThres', type=float, default=0.0025)
    parser.add_argument('--Suffix', help='Suffix', type=str, default='Non')
    parser.add_argument('--num_workers', type=int, default=16, help='number of workers for the train loader')
    parser.add_argument('--num_workers_test', type=int, default=1, help='number of workers for the test loader')
    parser.add_argument('--dataset_type', type=str, default='i_LIDS', help='type of dataset: i_LIDS, UCSDped2, avenue, Shanghai')
    parser.add_argument('--dataset_path', type=str, default='./dataset/', help='directory of data')
    parser.add_argument('--exp_dir', type=str, default='log', help='directory of log')
    parser.add_argument('--version', type=int, default=0, help='experiment version')
    args = parser.parse_args()
    torch.manual_seed(2020)
    torch.backends.cudnn.enabled = True  # make sure to use cudnn for computational performance
    def arrange_image(im_input):
        # Flatten the batch and time dimensions so a frame sequence can be logged as a grid of 2D images.
        im_input = np.transpose(im_input, (0, 2, 1, 3, 4))
        b, t, ch, h, w = im_input.shape
        im_input = np.reshape(im_input, [b * t, ch, h, w])
        return im_input
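    # Shape example (using the default arguments, for illustration only): a validation batch of shape
    # (1, 1, 16, 256, 256) = (B, C, T, H, W) is transposed to (B, T, C, H, W) and reshaped to
    # (16, 1, 256, 256), i.e. one 2D image per frame, the layout expected by add_images below.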
    train_folder, test_folder = data_utils.give_data_folder(args.dataset_type, args.dataset_path)
    print("The training path", train_folder)
    print("The testing path", test_folder)
    frame_trans = data_utils.give_frame_trans(args.dataset_type, [args.h, args.w])
    train_dataset = data_utils.DataLoader(train_folder, frame_trans, time_step=args.t_length - 1, num_pred=1)
    test_dataset = data_utils.DataLoader(test_folder, frame_trans, time_step=args.t_length - 1, num_pred=1)
    train_batch = data.DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=args.num_workers, drop_last=True, pin_memory=True)
    test_batch = data.DataLoader(test_dataset, batch_size=args.val_batch_size,
                                 shuffle=False, num_workers=args.num_workers_test, drop_last=True, pin_memory=True)
    print("Number of training batches", len(train_batch))
    print("Number of validation batches", len(test_batch))
    # Model setting
    if args.ModelName == 'AE':
        model = AutoEncoderCov3D(args.c)
    elif args.ModelName == 'MemAE':
        model = AutoEncoderCov3DMem(args.c, args.MemDim, shrink_thres=args.ShrinkThres)
    elif args.ModelName == 'AE_conv_stride_jrnl':
        model = AECov3Dstrdjrnl(args.c)
    elif args.ModelName == 'AE_conv_stride_jrnld':
        model = AECov3Dstrdjrnld(args.c)
    elif args.ModelName == 'AE_conv_jrnld':
        model = AECov3Djrnld(args.c)
    else:
        raise ValueError('Unknown ModelName: %s' % args.ModelName)
    model.apply(utils.weights_init)
    model = model.to(device)
    # model = nn.DataParallel(model)
    summary(model, (args.c, args.t_length, args.h, args.w))  # input size (C, T, H, W), batch dimension excluded
    for name, p in model.named_parameters():
        if not p.requires_grad:
            print("---------NO GRADIENT-----", name)
    parameter_list = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(parameter_list, lr=args.lr, eps=1e-7, weight_decay=0.0)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.5)  # version 2
    # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
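    # The MultiStepLR schedule halves the learning rate (gamma=0.5) once at epoch 50; the commented
    # CosineAnnealingLR line is an alternative that decays it smoothly over all epochs.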
    # Report the training process
    log_dir = os.path.join(args.exp_dir, args.dataset_type, 'lr_%.5f_entropyloss_%.5f_version_%d' % (
        args.lr, args.EntropyLossWeight, args.version))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    # orig_stdout = sys.stdout
    # f = open(os.path.join(log_dir, 'log.txt'), 'w')
    # sys.stdout = f
    for arg in vars(args):
        print(arg, getattr(args, arg))
    train_writer = SummaryWriter(log_dir=log_dir)
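    # Scalars and reconstructed frames are written to log_dir by tensorboardX; they can be inspected
    # with `tensorboard --logdir <log directory>` (assuming the standalone TensorBoard package is installed).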
    # # warmup
    # model.train()
    # with torch.no_grad():
    #     for batch_idx, frame in enumerate(train_batch):
    #         frame = frame.reshape([args.batch_size, args.t_length, args.c, args.h, args.w])
    #         frame = frame.permute(0, 2, 1, 3, 4)
    #         frame = frame.to(device)
    #         model_output = model(frame)
    # Training
    best_train_loss = 10000
    best_val_loss = 10000
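    # Training/validation loop: each epoch runs one optimisation pass over train_batch, then every
    # val_epoch epochs evaluates on test_batch, logging validation losses and a few
    # input/reconstruction pairs, and finally checkpoints the model weights.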
    for epoch in range(args.epochs):
        model.train()
        tr_re_loss, tr_mem_loss, tr_tot = 0.0, 0.0, 0.0
        ts_tot = 0.0
        progress_bar = tqdm(train_batch)
        for batch_idx, frame in enumerate(progress_bar):
            # [B, T, C, H, W] -> [B, C, T, H, W], the layout expected by the 3D convolutions
            frame = frame.reshape([args.batch_size, args.t_length, args.c, args.h, args.w])
            frame = frame.permute(0, 2, 1, 3, 4)
            frame = frame.to(device)
            optimizer.zero_grad()
            if args.ModelName == 'MemAE':
                model_output = model(frame)
                recons, attr = model_output['output'], model_output['att']
                re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                mem_loss = loss.get_memory_loss(attr)
                # total loss = reconstruction loss + weighted sparsity penalty on the memory attention
                tot_loss = re_loss + mem_loss * args.EntropyLossWeight
                tr_re_loss += re_loss.data.item()
                tr_mem_loss += mem_loss.data.item()
                tr_tot += tot_loss.data.item()
                tot_loss.backward()
                optimizer.step()
            elif args.ModelName in ('AE', 'AE_conv_stride_jrnl', 'AE_conv_stride_jrnld', 'AE_conv_jrnld'):
                recons = model(frame)
                re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                mem_loss = 0
                tot_loss = re_loss + mem_loss * args.EntropyLossWeight
                tr_re_loss += re_loss.data.item()
                tr_mem_loss += 0  # mem_loss.data.item()
                tr_tot += tot_loss.data.item()
                tot_loss.backward()
                optimizer.step()
        train_writer.add_scalar("model/train-recons-loss", tr_re_loss / len(train_batch), epoch)
        train_writer.add_scalar("model/train-memory-sparse", tr_mem_loss / len(train_batch), epoch)
        train_writer.add_scalar("model/train-total-loss", tr_tot / len(train_batch), epoch)
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        train_writer.add_scalar('learning_rate', current_lr, epoch)
        if epoch % args.val_epoch == 0:
            model.eval()
            re_loss_val, mem_loss_val = 0.0, 0.0
            with torch.no_grad():  # no gradients are needed during validation
                for batch_idx, frame in enumerate(test_batch):
                    frame = frame.reshape([args.val_batch_size, args.t_length, args.c, args.h, args.w])
                    frame = frame.permute(0, 2, 1, 3, 4)
                    frame = frame.to(device)
                    if args.ModelName == 'MemAE':
                        model_output = model(frame)
                        recons, attr = model_output['output'], model_output['att']
                        re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                        mem_loss = loss.get_memory_loss(attr)
                        re_loss_val += re_loss.data.item()
                        mem_loss_val += mem_loss.data.item()
                    elif args.ModelName in ('AE', 'AE_conv_stride_jrnl', 'AE_conv_stride_jrnld', 'AE_conv_jrnld'):
                        recons = model(frame)
                        re_loss = loss.get_reconstruction_loss(frame, recons, mean=0.5, std=0.5)
                        mem_loss = 0
                        re_loss_val += re_loss.data.item()
                        mem_loss_val += 0  # mem_loss.data.item()
                    # accumulate the per-batch total loss (re_loss_val / mem_loss_val keep running sums for logging)
                    tot_loss_val = re_loss.data.item() + (mem_loss.data.item() if args.ModelName == 'MemAE' else 0.0) * args.EntropyLossWeight
                    # print(tot_loss_val)
                    ts_tot += tot_loss_val
                    if ((batch_idx == 0) or (batch_idx == 10) or (batch_idx == len(test_batch) - 1) or (
                            batch_idx == int(len(test_batch) / 2)) or (batch_idx == int(len(test_batch) / 4))):
                        # undo the (mean=0.5, std=0.5) normalisation before logging images
                        _input_npy = frame.detach().cpu().numpy()
                        _input_npy = _input_npy * 0.5 + 0.5
                        _recons_npy = recons.detach().cpu().numpy()
                        _recons_npy = _recons_npy * 0.5 + 0.5  # [batch_size, ch, time, imh, imw]
                        train_writer.add_images("image/input_image", arrange_image(_input_npy), epoch)
                        train_writer.add_images("image/reconstruction", arrange_image(_recons_npy), epoch)
            train_writer.add_scalar("model/val-recons-loss", re_loss_val / len(test_batch), epoch)
            train_writer.add_scalar("model/val-memory-sparse", mem_loss_val / len(test_batch), epoch)
            print("epoch %d" % epoch, "total loss training %.4f validation %.4f" % (tr_tot / len(train_batch), ts_tot / len(test_batch)),
                  "recons loss training %.4f validation %.4f" % (tr_re_loss / len(train_batch), re_loss_val / len(test_batch)),
                  "memory sparsity training %.4f validation %.4f" % (tr_mem_loss / len(train_batch), mem_loss_val / len(test_batch)))
        if (epoch % 10 == 0 or epoch == args.epochs - 1) and args.dataset_type == 'i_LIDS':
            torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()}, log_dir + "/model-{:04d}.pt".format(epoch))
        if (tr_tot / len(train_batch) < best_train_loss) and (ts_tot / len(test_batch) <= best_val_loss):
            print("Best model is at epoch ", epoch)
            torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()}, log_dir + "/best_model.pt")
            best_train_loss = tr_tot / len(train_batch)
            best_val_loss = ts_tot / len(test_batch)
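    # To resume from a saved checkpoint later (a sketch, assuming the same model class and
    # optimizer are reconstructed first):
    #     ckpt = torch.load(os.path.join(log_dir, "best_model.pt"))
    #     model.load_state_dict(ckpt['model_state_dict'])
    #     optimizer.load_state_dict(ckpt['optimizer_state_dict'])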
    # sys.stdout = orig_stdout
    # f.close()
if __name__ == '__main__':
    main()