Skip to content
Snippets Groups Projects
Commit 6e4a0502 authored by Devashish Lohani's avatar Devashish Lohani
Browse files

Upload New File

parent b178ec76
No related branches found
No related tags found
No related merge requests found
import os
#os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
from torch import nn
import sys
import numpy as np
import time
from torchvision import transforms
from torch.utils.data import DataLoader
import utils
from models import *
import data.utils as data_utils
import argparse
from tqdm import tqdm
import utils.eval as eval_utils
from torchsummary import summary
from natsort import natsorted
from collections import OrderedDict
from utils import tensor2numpy, window_zoned_mse, frame_zoned_mse
#import antialiased_cnns
# from network import *
# from network.models import SaliencyNetwork
def main():
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# C = 10 # example feature channel size
# blurpool = antialiased_cnns.BlurPool3D(C, stride=2, stride_tuple=(1,2,2)) # BlurPool layer; use to downsample a feature map
# ex_tens = torch.Tensor(1, C, 8, 128, 128)
# print(blurpool(ex_tens).shape) # 1xCx64x64 tensor
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser(description="Memorizing_Normality")
parser.add_argument('--h', type=int, default=256, help='height of input images')
parser.add_argument('--w', type=int, default=256, help='width of input images')
parser.add_argument('--t_length', type=int, default=16, help='length of the frame sequences')
parser.add_argument('--test_batch_size', type=int, default=1, help='batch size for testing')
parser.add_argument('--num_workers_test', type=int, default=1, help='number of workers for the test loader')
parser.add_argument('--ModelName', help='AE/MemAE/AE_max_unpool', type=str, default='AE')
parser.add_argument('--ImgChnNum', help='image channel', type=int, default=1)
parser.add_argument('--dataset_type', type=str, default="UCSDped2")
parser.add_argument("--dataset_path", type=str, default='./dataset/')
parser.add_argument('--MemDim', help='Memory Dimention', type=int, default=2000)
parser.add_argument("--version", type=int, default=1)
parser.add_argument("--ckpt_step", type=int, default=39)
parser.add_argument("--EntropyLossWeight", type=float, default=0.0002)
parser.add_argument('--ShrinkThres', help='ShrinkThres', type=float, default=0.0025)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--exp_dir", type=str, default='/ckpt/')
parser.add_argument("--used_dataparallel", type=bool, default=False)
parser.add_argument('--backbone_name', type=str, default='resnet152_csn_ir', help='directory of data')
parser.add_argument('--backbone_pre', type=str, default='pre_trained/irCSN_152_ft_kinetics_from_ig65m_f126851907.pth', help='directory of data')
parser.add_argument('--backbone_freeze', type=str, default='True', help='directory of data')
parser.add_argument('--model_freeze', type=str, default='True', help='directory of data')
args = parser.parse_args()
ModelName = args.ModelName
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cuda")
print("..............device", device)
height, width = args.h, args.w
chnum_in_ = args.ImgChnNum
num_frame = args.t_length
batch_size = args.test_batch_size
test_gt_dir = args.dataset_path + args.dataset_type # TODO i added
if "ckpt/" in args.exp_dir:
model_dir = "ckpt/%s/" % args.dataset_type
else:
model_dir = args.exp_dir + '%s/lr_%.5f_entropyloss_%.5f_version_%d/' % (args.dataset_type,
args.lr,
args.EntropyLossWeight, args.version)
# orig_stdout = sys.stdout
# f = open(os.path.join(model_dir, 'output_%s_%d.txt' % ("original_1.00", args.ckpt_step)),'w')
# sys.stdout= f
if args.dataset_type == "Avenue":
data_dir = args.dataset_path + "Avenue/frames/testing/"
elif "UCSD" in args.dataset_type:
data_dir = args.dataset_path + "%s/Test_jpg/" % args.dataset_type
elif args.dataset_type == "i_LIDS" or args.dataset_type == "Shanghai":
data_dir = args.dataset_path + args.dataset_type + "/testing/frames/"
data_dir = "/media/dev_liris/DATA/data_share/Datasets/i_LIDS/testing/frames/" #only for i_LIDS
data_dir = [data_dir + vf for vf in natsorted(os.listdir(data_dir))]
else:
print("The dataset is not available..........")
pass
if args.dataset_type == "Shanghai":
test_gt_dir = test_gt_dir + "/testing/test_frame_mask/"
test_gt_dir = [test_gt_dir + vf for vf in natsorted(os.listdir(test_gt_dir))]
ckpt_dir = model_dir + "model-{:04d}.pt".format(args.ckpt_step)
print("Using Checkpoint ",ckpt_dir)
#ckpt_dir = "C:/Users/Administrator/Documents/3dc-seg/ckp/bmvc_final.pth"
gt_file = "ckpt/%s_gt.npy" % (args.dataset_type) #TODO i did
# gt_file = args.dataset_path + "%s/gt_label.npy" % args.dataset_type
re_file_path = model_dir + "recons_error_original_1.0_%d.npy" % args.ckpt_step
if not os.path.isfile(gt_file):
print("Attention!! No gt file exists..")
if (args.dataset_type=='Shanghai'):
print("Creating GT npy for {} dataset".format(args.dataset_type))
all_test_labels = []
for test_gt_file in test_gt_dir:
test_file_label = np.load(test_gt_file, allow_pickle=True)
all_test_labels.append(test_file_label)
all_test_labels = np.array(all_test_labels)
np.save(gt_file, all_test_labels)
# if os.path.isfile(gt_file) and os.path.isfile(re_file_path):
# print("Evaluating: using gt and pred files...")
# recons_error = np.load(re_file_path)
# eval_utils.eval_video2(gt_file, recons_error, args.dataset_type)
# exit()
if(chnum_in_==1):
norm_mean = [0.5]
norm_std = [0.5]
elif(chnum_in_==3):
norm_mean = (0.5, 0.5, 0.5)
norm_std = (0.5, 0.5, 0.5)
frame_trans = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
transforms.Resize([height, width]),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)
])
if args.dataset_type == "i_LIDS":
frame_trans = transforms.Compose([
transforms.Grayscale(num_output_channels=1),
transforms.Equalize(),
transforms.Resize([height, width]),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)
])
unorm_trans = utils.UnNormalize(mean=norm_mean, std=norm_std)
print("------Data folder", data_dir)
print("------Model folder", model_dir)
print("------Restored ckpt", ckpt_dir)
data_loader = data_utils.DataLoader(data_dir, frame_trans, time_step=num_frame-1, num_pred=1)
video_data_loader = DataLoader(data_loader, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=args.num_workers_test)
chnum_in_ = 1
mem_dim_in = args.MemDim
sparse_shrink_thres = args.ShrinkThres
if (args.ModelName == 'AE'):
model = AutoEncoderCov3D(chnum_in_)
# elif (args.ModelName == 'AECov3Dref'):
# model = SaliencyNetwork(chnum_in_, args.t_length, args.backbone_name, args.backbone_pre, args.backbone_freeze, args.model_freeze)
# elif (opt.ModelName == 'AE_max_unpool'):
# model = AECov3DMaxUnpool(chnum_in_)
# elif (opt.ModelName == 'AE_avg_unpool'):
# model = AECov3DAvgUnpool(chnum_in_)
elif (args.ModelName == 'AECov3DAvgMem'):
model = AECov3DAvgMem(chnum_in_, mem_dim_in, shrink_thres=sparse_shrink_thres)
elif (args.ModelName == 'AEhab'):
model = AECov3DHab(chnum_in_)
elif (args.ModelName == 'AECov3DBlur'):
model = AECov3DBlur(chnum_in_)
elif (args.ModelName == 'AE_conv_stride_jrnl'):
model = AECov3Dstrdjrnl(chnum_in_)
elif (args.ModelName == 'AE_conv_stride_jrnld'):
model = AECov3Dstrdjrnld(chnum_in_)
elif (args.ModelName == 'AE_conv_jrnl'):
model = AECov3Djrnl(chnum_in_)
elif (args.ModelName == 'AE_conv_jrnld'):
model = AECov3Djrnld(chnum_in_)
elif(args.ModelName=='AECov3DMaxUnMem'):
model = AECov3DMaxUnpoolMem(chnum_in_, mem_dim_in, shrink_thres=sparse_shrink_thres)
elif(args.ModelName == 'MemSC'):
model = AECov3DMemSC(chnum_in_, mem_dim_in, shrink_thres=sparse_shrink_thres)
elif(args.ModelName=='MemAE'):
model = AutoEncoderCov3DMem(chnum_in_, mem_dim_in, shrink_thres=sparse_shrink_thres)
else:
model = []
print('Wrong Name.')
#model = nn.DataParallel(model) #for i_LIDS
model_para = torch.load(ckpt_dir)
# if(args.used_dataparallel):
# ##==If data parallelisation, i.e., multi-gpus were used for training!!!====
# new_state_dict = OrderedDict()
# for k, v in model_para.items():
# name = k[7:] # remove `module.`
# new_state_dict[name] = v
# model_para = new_state_dict
# del new_state_dict
# )#, strict= False) #strict: to ignore if keys are missing
if args.dataset_type == "i_LIDS":
model.load_state_dict(model_para)#['model_state_dict'])#['model_state_dict'])#, strict= False) #strict: to ignore if keys are missing
else:
model.load_state_dict(model_para['model_state_dict'])
#model = nn.DataParallel(model)
model.requires_grad_(False)
model.to(device)
model.eval()
summary(model, (1, args.t_length, args.w, args.h))
img_crop_size = 0
#recon_error_list = [None] * len(video_data_loader)
recon_error_list = []
psnr_error_list = []
recon_error_list_l = [] #for last frame only instead of window
psnr_error_list_l = []
progress_bar = tqdm(video_data_loader)
for batch_idx, frames in enumerate(progress_bar):
progress_bar.update()
#frames = frames.reshape([batch_size, num_frame, chnum_in_, height, width])
frames = frames.reshape([frames.shape[0], num_frame, chnum_in_, height, width])
frames = frames.permute(0, 2, 1, 3, 4)
frames = frames.to(device)
if (ModelName == 'AE' or args.ModelName == 'AE_conv_stride_jrnl'or args.ModelName == 'AECov3DBlur'
or args.ModelName == 'AECov3Dref' or args.ModelName == 'AE_conv_stride_jrnld'
or args.ModelName == 'AE_conv_jrnld' or args.ModelName == 'AE_conv_jrnl'):
recon_frames = model(frames)
###### calculate reconstruction error (MSE)
unnormed_i = unorm_trans(frames.data)
unnormed_r = unorm_trans(recon_frames.data)
num_frames_batch = unnormed_i.shape[0]
#re_frames = []
unnormed_i = tensor2numpy(unnormed_i)
unnormed_r = tensor2numpy(unnormed_r)
for i in range(num_frames_batch):
input_np = unnormed_i[i, :, :, :, :]
input_np = np.transpose(input_np , (1, 0, 2, 3))
recon_np = unnormed_r[i, :, :, :, :]
recon_np = np.transpose(recon_np, (1, 0, 2, 3))
# input_np = utils.vframes2imgs(unnormed_i, step=1, batch_idx=i)
# recon_np = utils.vframes2imgs(unnormed_r, step=1, batch_idx=i)
r = utils.crop_image(recon_np, img_crop_size) - utils.crop_image(input_np, img_crop_size)
#r_l = utils.crop_image(recon_np[-1,:], img_crop_size) - utils.crop_image(input_np[-1,:], img_crop_size)
recon_error = np.mean(r ** 2) # **0.5 #TODO I did for i_LIDS
#recon_error_l = np.mean(r_l ** 2) # **0.5 #TODO I did for i_LIDS
if args.dataset_type == "i_LIDS":
#pts = np.array([[6, 357], [314, 169], [416, 42], [686, 42], [686, 556], [6, 556], [6, 357]], dtype=np.int32) # view 1
pts = np.array([[0, 0], [0, 560], [692, 560], [692, 400], [300, 0], [0, 0]], dtype=np.int32) #view 2
recon_error = window_zoned_mse(r, pts)
recon_error_l = frame_zoned_mse(r, pts)
#re_frames.append(recon_error)
recon_error_list.append(recon_error)
recon_error_list_l.append(recon_error_l)
#max_p = max(utils.crop_image(recon_np, img_crop_size))
#psnr_error = utils.psnr(recon_error, recon_np.max())
psnr_error = utils.psnr(recon_error)
psnr_error_list.append(psnr_error)
psnr_error_list_l.append(utils.psnr(recon_error_l))
elif (ModelName == 'MemAE' or ModelName=='MemSC' or ModelName=='AECov3DMaxUnMem'
or args.ModelName == 'AECov3DAvgMem' or args.ModelName=='AEhab' ):
recon_res = model(frames)
recon_frames = recon_res['output']
recon_np = utils.vframes2imgs(unorm_trans(recon_frames.data), step=1, batch_idx=0)
input_np = utils.vframes2imgs(unorm_trans(frames.data), step=1, batch_idx=0)
r = utils.crop_image(recon_np, img_crop_size) - utils.crop_image(input_np, img_crop_size)
sp_error_map = sum(r ** 2)**0.5
recon_error = np.mean(sp_error_map.flatten())
recon_error_list.append(recon_error)
else:
recon_error = -1
print('Wrong ModelName.')
#recon_error_list.extend(re_frames)
# recon_error_list.append(recon_error)
# recon_error_list[batch_idx] = recon_error
# recon_error_list = [v for j in recon_error_list for v in j]
print("The length of the reconstruction error is ", len(recon_error_list))
print("The length of the testing images is", len(data_loader))
print("............start to checking the anomaly detection auc score...................")
print("............use ckpt dir at step %d" % args.ckpt_step)
#eval_utils.eval_video2(gt_file, recon_error_list, args.dataset_type)
# sys.stdout = orig_stdout
# f.close()
save_path = model_dir + "recons_error_original_1.0_%d" % args.ckpt_step
np.save(save_path, recon_error_list)
save_path = model_dir + "psnr_error_original_1.0_%d" % args.ckpt_step
np.save(save_path, psnr_error_list)
save_path = model_dir + "recons_error_l_%d" % args.ckpt_step
np.save(save_path, recon_error_list_l)
save_path = model_dir + "psnr_error_l_%d" % args.ckpt_step
np.save(save_path, psnr_error_list_l)
print('done')
if __name__ == '__main__':
main()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment