# seq_exp.py
import multiprocessing
import os
import time
import numpy as np
import pandas as pd
import csv
import matplotlib.pyplot as plt
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
from keras.models import load_model
# import tensorflow.data
import h5py
import glob
from sklearn.metrics import average_precision_score
import matplotlib

matplotlib.use('Agg')
import sys

sys.path.insert(0, './animation')
# from plot_video_animation_3D import *

from img_exp import ImgExp
from h5py_init import init_videos, flip_windowed_arr
# from models import *
from util import generate_vid_keys, generate_test_vid_names, agg_window, get_output, gather_auc_avg_per_tol, \
    animate_fall_detect_Spresent, join_mean_std, truncate , plot_RE_hist, animate_fall_detect_Spresent_o

from data_management import init_windowed_arr, create_windowed_arr, generate_training_windows, make_windows_h5_file, get_windowed_data

# Base directory under which the 'npData/' and 'H5Data/' folders are looked up.
root_drive = '.'


class SeqExp(ImgExp):
    """
    An autoencoder experiment based on sequences (windows) of images.

    Inherits get_thresh, save_exp and variable initialization from ImgExp.

    Attributes:
        int win_len: window length, i.e. the number of contiguous frames
            forming one sample
    """

    def __init__(self, model=None, model_name=None, misc_save_info=None,
                 batch_size=32, model_type=None, callbacks_list=None,
                 pre_load=None, initial_epoch=0, epochs=1, dset='Thermal',
                 win_len=8, hor_flip=False, img_width=64, img_height=64):
        """
        Forwards the common experiment settings to ImgExp and stores win_len.

        Params:
            win_len: number of contiguous frames per window (sample)
            misc_save_info, callbacks_list: accepted but neither stored nor
                forwarded to ImgExp by this constructor
            all other arguments are passed unchanged to ImgExp.__init__
        """

        ImgExp.__init__(self, model=model, img_width=img_width,
                        img_height=img_height, model_name=model_name,
                        batch_size=batch_size, model_type=model_type,
                        pre_load=pre_load, initial_epoch=initial_epoch,
                        epochs=epochs, hor_flip=hor_flip, dset=dset)

        self.win_len = win_len

    def set_train_data(self, raw=False, mmap_mode=None):  # TODO init windows from h5py if no npData found
        """
        Loads or initializes windowed train data, and sets self.train_data accordingly.

        Params:
            raw: currently unused by this method
            mmap_mode: forwarded to np.load when the cached .npy file exists

        If self.hor_flip is set, the flipped-by-window data (see
        init_flipped_by_win) is appended to self.train_data along axis 0.
        """

        to_load = root_drive + '/npData/{}/ADL_data-proc-win_{}.npy'.format(self.dset, self.win_len)

        if os.path.isfile(to_load):
            print('npData found, loading..')
            self.train_data = np.load(to_load, mmap_mode=mmap_mode)
        else:
            print('npData not found, initializing..')
            self.train_data = init_windowed_arr(dset=self.dset, ADL_only=True,
                                                win_len=self.win_len,
                                                img_width=self.img_width,
                                                img_height=self.img_height)

        if self.hor_flip:
            # Resolve against root_drive like every other data path in this class.
            to_load_flip = root_drive + '/npData/hor_flip-by_window/{}'.format(
                os.path.basename(to_load))
            data_flip = self.init_flipped_by_win(to_load_flip)

            self.train_data = np.concatenate((self.train_data, data_flip), axis=0)

    def get_h5_file(self, raw=False, mmap_mode=None):  # TODO init windows from h5py if no npData found
        """
        Returns the path of the h5 file for self.dset at the current image size.

        If no file exists at that path, init_videos is called first to
        initialize it. An existing file is left untouched.

        Params:
            raw: forwarded to init_videos
            mmap_mode: currently unused by this method

        Returns:
            str: path of the h5 file
        """

        master_path = root_drive + '/H5Data/Data_set-{}-imgdim{}x{}.h5'.format(
            self.dset, self.img_width, self.img_height)

        if not os.path.isfile(master_path):
            print('initializing h5py..')
            init_videos(img_width=self.img_width, img_height=self.img_height, raw=raw, dset=self.dset)

        return master_path

    def train(self, sample_weight=None):
        """
        Trains a sequential autoencoder on windowed data (sequences of
        contiguous frames are reconstructed).

        Fits self.model on self.train_data (input and target are the same
        array), logging to './logs/<dset>/' and checkpointing to
        './Checkpoints/<dset>/', then calls self.save_exp().

        Params:
            sample_weight: forwarded to model.fit
        """

        model_name = self.model_name
        base = './Checkpoints/{}'.format(self.dset)
        base_logs = './logs/{}'.format(self.dset)

        # makedirs also creates the missing parent ('./Checkpoints', './logs'),
        # which a plain mkdir would fail on during a first run.
        os.makedirs(base, exist_ok=True)
        os.makedirs(base_logs, exist_ok=True)

        checkpoint_path = base + '/' + model_name + '-' + '{epoch:03d}-{loss:.3f}.hdf5'
        checkpointer = ModelCheckpoint(filepath=checkpoint_path, period=100, verbose=1)
        # Print the pattern actually handed to ModelCheckpoint.
        print(checkpoint_path)

        timestamp = time.time()
        csv_logger = CSVLogger(base_logs + '/' + model_name + '-' + 'training-' + str(timestamp) + '.log')
        callbacks_list = [csv_logger, checkpointer]
        print(callbacks_list)
        print(csv_logger)

        self.model.fit(self.train_data, self.train_data, epochs=self.epochs,
                       batch_size=self.batch_size, verbose=2,
                       callbacks=callbacks_list, sample_weight=sample_weight)
        self.save_exp()

    def init_flipped_by_win(self, to_load_flip):
        """
        Returns the flipped counterpart of self.train_data.

        Params:
            to_load_flip: path of a cached .npy file with the flipped windows

        Returns:
            ndarray: the cached array reshaped to
            (n, train_data.shape[1], train_data.shape[2], train_data.shape[3], 1)
            when the file exists, otherwise the result of
            flip_windowed_arr(self.train_data). This method itself does not
            write the newly created array to to_load_flip.
        """

        if not os.path.isfile(to_load_flip):
            print('creating flipped by window data..')
            return flip_windowed_arr(self.train_data)

        data_flip = np.load(to_load_flip)
        return data_flip.reshape(len(data_flip), self.train_data.shape[1],
                                 self.train_data.shape[2],
                                 self.train_data.shape[3], 1)

    def _window_recon_error(self, windows):
        """
        Reconstructs windowed data with self.model and computes per-frame MSE.

        Params:
            ndarray windows: windowed data; must be reshapeable to
                (len(windows), win_len, img_height * img_width)

        Returns:
            tuple (RE, recons_seq, error):
                RE: mean squared error, one row per window and one column
                    per frame of the window
                recons_seq: model.predict output, in its original shape
                error: windows minus reconstruction, of shape
                    (len(windows), win_len, img_height * img_width)
        """

        recons_seq = self.model.predict(windows)

        n_pixels = self.img_height * self.img_width
        flat_windows = windows.reshape(len(windows), self.win_len, n_pixels)
        flat_recons = recons_seq.reshape(len(recons_seq), self.win_len, n_pixels)

        error = flat_windows - flat_recons
        RE = np.mean(np.power(error, 2), axis=2)

        return RE, recons_seq, error

    def get_MSE(self, test_data, agg_type='r_sigma'):
        """
        MSE for sequential data (video), aggregated over windows.

        Params:
            ndarray test_data: data used to test the model (reconstruct).
                If test_data.shape[1] differs from win_len the data is
                treated as un-windowed frames: it is reshaped to
                (samples, img_width, img_height, 1) and windowed with
                stride 1.
            agg_type: how to aggregate windowed scores (passed to agg_window)

        Returns:
            tuple (RE, recons_seq, error):
                RE: per-window reconstruction errors aggregated by
                    agg_window(RE, agg_type)
                recons_seq: reconstructed windows as returned by the model
                error: test windows minus reconstruction, flattened to
                    (windows, win_len, img_height * img_width)
        """

        win_len = self.win_len

        if test_data.shape[1] != win_len:  # Not windowed
            test_data = test_data.reshape(len(test_data), self.img_width, self.img_height, 1)
            test_data = create_windowed_arr(test_data, 1, win_len)

        RE, recons_seq, error = self._window_recon_error(test_data)

        return agg_window(RE, agg_type), recons_seq, error

    def get_MSE_all_agg(self, test_data):
        """
        Gets MSE for all aggregate types 'r_sigma', 'r_mu', 'in_std', 'in_mean', 'r'.

        Params:
            ndarray test_data: windowed data used to test the model
                (reconstruct); must be reshapeable to
                (samples(windows), window length, img_height * img_width)

        Returns:
            tuple (RE_dict, recons_seq, error):
                RE_dict: dictionary with keys 'r_sigma', 'r_mu', 'in_std',
                    'in_mean', 'r', each mapped to agg_window(RE, key)
                recons_seq: reconstructed windows as returned by the model
                error: test windows minus reconstruction, flattened to
                    (windows, win_len, img_height * img_width)
        """

        RE, recons_seq, error = self._window_recon_error(test_data)

        agg_type_list = ['r_sigma', 'r_mu', 'in_std', 'in_mean', 'r']
        # Get the various per frame RE scores
        RE_dict = {agg_type: agg_window(RE, agg_type) for agg_type in agg_type_list}

        return RE_dict, recons_seq, error

    def _save_animation(self, model_name, vid_name, test_data, recons_seq,
                        error_seq, scores, RE_type, threshold):
        """
        Saves the fall-detection animation of one video to
        './Animation/<dset>/<model_name>/<vid_name>/<RE_type>.mp4'.
        """

        ani_dir = './Animation/{}/{}/{}'.format(self.dset, model_name, vid_name)
        os.makedirs(ani_dir, exist_ok=True)
        print('saving animation to {}'.format(ani_dir))

        # The middle frame of every window represents that window.
        mid_frame = self.win_len // 2
        animate_fall_detect_Spresent(testfall=test_data,
                                     recons=recons_seq[:, mid_frame, :],
                                     error_map_seq=error_seq[:, mid_frame, :],
                                     scores=scores,
                                     to_save=ani_dir + '/{}.mp4'.format(RE_type),
                                     win_len=self.win_len,  # follow the experiment's window length
                                     legend_label=RE_type,
                                     threshold=threshold)

    def test(self, eval_type='per_video', RE_type='r', animate=False, indicative_threshold_animation=False):
        """
        Gets AUC ROC/PR for the test videos using the RE_type scoring scheme.

        Scores are written to
        './AEComparisons/scores/<dset>/<model_name>-<RE_type>_<eval_type>.csv',
        where model_name is derived from the basename of self.pre_load.
        Assumes self.model has been initialized and self.pre_load is set.

        Params:
            eval_type: 'per_video' writes one AUROC/AUPR row per video
                followed by an 'Average (Std)' summary; 'all_videos' pools
                labels and scores of every video and writes a single row
            RE_type: aggregation type passed to get_MSE / get_output
            animate: if truthy (per_video only), save an animation per video
            indicative_threshold_animation: if truthy, the ROC threshold
                returned by get_output is passed to the animation,
                otherwise 0

        Raises:
            ValueError: if eval_type is neither 'per_video' nor 'all_videos'
        """

        # Validate before any file is created or any video is processed.
        if eval_type not in ('per_video', 'all_videos'):
            raise ValueError(
                'eval_type wrong: {!r}; possible options: all_videos or per_video'.format(eval_type))

        dset, img_width, img_height = self.dset, self.img_width, self.img_height
        win_len = self.win_len
        stride = 1
        model_name = os.path.basename(self.pre_load).split('.')[0]
        print(model_name)

        # h5 file path; the file is initialized first if it does not exist
        path = self.get_h5_file(raw=False)

        base = './AEComparisons/scores/{}/'.format(dset)
        os.makedirs(base, exist_ok=True)
        save_path = base + '{}_{}.csv'.format(model_name + "-" + RE_type, eval_type)

        # e.g. AUROC:r(all_videos), AUPR:r(per_video)
        fnames = ['AUROC' + ':' + RE_type + '(' + eval_type + ')',
                  'AUPR' + ':' + RE_type + '(' + eval_type + ')']

        # Both files are closed on every exit path, including exceptions.
        # newline='' is what the csv module requires of files it writes to.
        with h5py.File(path, 'r') as hf, open(save_path, 'w', newline='') as f:
            data_dict = hf['{}/Processed/Split_by_video'.format(dset)]
            vid_dir_keys = generate_test_vid_names(data_dict, dset)

            num_vids = len(vid_dir_keys)
            print('num_vids', num_vids)

            writer = csv.DictWriter(f, fieldnames=fnames)
            writer.writeheader()

            if num_vids == 0:
                # Nothing to score; leave a header-only CSV rather than
                # failing on empty aggregates further down.
                print('no test videos found, nothing to evaluate')
                return

            per_video = eval_type == 'per_video'
            if per_video:
                ROC_mat = np.ones((num_vids, 1))
                PR_mat = np.ones((num_vids, 1))

            labels_total_l = []
            preds_total = []

            for vid_index, Fall_name in enumerate(vid_dir_keys):

                # NOTE(review): hard stop of the whole process at this video;
                # looks like a debugging leftover -- confirm before removing.
                if Fall_name == 'Intru65':
                    sys.exit(8)
                print(Fall_name)
                start_time = time.time()

                # Get frames and labels of one video folder
                vid_total = data_dict[Fall_name]['Data'][:]
                labels_total = data_dict[Fall_name]['Labels'][:]
                test_data = vid_total.reshape(len(vid_total), img_width, img_height, 1)
                # Create array of frame windows
                test_data_windowed = create_windowed_arr(test_data, stride, win_len)

                RE_pred, recons_seq, error_seq = self.get_MSE(test_data_windowed, agg_type=RE_type)

                if not per_video:
                    labels_total_l.extend(labels_total)
                    preds_total.extend(RE_pred)
                    continue

                # Get AUROC, confusion matrix, geometric mean score & AUPR scores
                auc_roc, conf_mat, g_mean, auc_pr, roc_thres, pr_thres = get_output(
                    labels=labels_total,
                    predictions=RE_pred, get_thres=True,
                    data_option=RE_type,
                    to_plot=False, dset=dset,
                    model_name=model_name,
                    dir_name=Fall_name)
                ROC_mat[vid_index, 0] = auc_roc
                PR_mat[vid_index, 0] = auc_pr

                writer.writerow({fnames[0]: truncate(auc_roc, 3), fnames[1]: truncate(auc_pr, 3)})

                print("AUROC", auc_roc)
                print("AUPR", auc_pr)

                elapsed = time.time() - start_time
                print("Without Animation Time %.2f s or %.2f mins" % (elapsed, elapsed / 60))

                if animate:
                    # by default no indicative threshold
                    threshold = roc_thres if indicative_threshold_animation else 0
                    self._save_animation(model_name, Fall_name, test_data,
                                         recons_seq, error_seq, RE_pred,
                                         RE_type, threshold)

                    elapsed = time.time() - start_time
                    print("With Animation Time %.2f s or %.2f mins" % (elapsed, elapsed / 60))

            if not per_video:
                # NOTE(review): dir_name receives the last video processed;
                # confirm get_output only needs it for plotting (off here).
                auc_roc, conf_mat, g_mean, auc_pr, roc_thres, pr_thres = get_output(
                    labels=np.asarray(labels_total_l),
                    predictions=np.asarray(preds_total), get_thres=True,
                    data_option=RE_type,
                    to_plot=False, dset=dset,
                    model_name=model_name,
                    dir_name=Fall_name)
                writer.writerow({fnames[0]: truncate(auc_roc, 2), fnames[1]: truncate(auc_pr, 2)})
            else:
                print('ROC_mat.shape', ROC_mat.shape)
                AUROC_avg = np.mean(ROC_mat, axis=0)
                AUROC_std = np.std(ROC_mat, axis=0)
                AUROC_avg_std = join_mean_std(AUROC_avg, AUROC_std)

                # Same for PR values
                AUPR_avg = np.mean(PR_mat, axis=0)
                AUPR_std = np.std(PR_mat, axis=0)
                AUPR_avg_std = join_mean_std(AUPR_avg, AUPR_std)

                writer.writerow({fnames[0]: 'Average (Std)'})
                writer.writerow({fnames[0]: AUROC_avg_std[0],
                                 fnames[1]: AUPR_avg_std[0]})