From 8dc0b495ca1970fe6abda3ba4ec9e7f3c49abd68 Mon Sep 17 00:00:00 2001 From: maali <mahmoud-ahmed.ali@liris.cnrs.fr> Date: Mon, 27 Mar 2023 16:39:33 +0200 Subject: [PATCH] Add getAllValDataFruits function --- data.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/data.py b/data.py index e23af56..ae0dbd1 100644 --- a/data.py +++ b/data.py @@ -54,6 +54,26 @@ def getAllValData(modelClass='cat'): # retrieves random image and label set fro return image_ls, labels_ls, mask_ls +def getAllValDataFruits(modelClass='cat'): # retrieves random image and label set from specified dataset + trainData, validData = getDataSplit_Fruits(modelClass=modelClass) + basePath = os.path.dirname( + os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + modelClass + image_ls = [] + labels_ls = [] + mask_ls = [] + choice_ls = [] + for choice in validData: + with open(basePath + '/FPS_resized/' + choice[2]) as f: + labels = f.readline().split(' ')[1:19] + labels_ls.append(labels) + image = filePathToArray(basePath + '/RGB_resized/' + choice[0]) + mask = filePathToArray(basePath + '/Instance_Mask_resized/' + choice[1]) + image_ls.append(image) + mask_ls.append(mask) + choice_ls.append(choice[0]) + return image_ls, labels_ls, mask_ls, choice_ls + + def getMasterList(basePath): # returns list with image, mask, and label filenames imageList = sorted(os.listdir(basePath + '/rgb/')) maskList = sorted(os.listdir(basePath + '/mask/')) -- GitLab