From 8dc0b495ca1970fe6abda3ba4ec9e7f3c49abd68 Mon Sep 17 00:00:00 2001
From: maali <mahmoud-ahmed.ali@liris.cnrs.fr>
Date: Mon, 27 Mar 2023 16:39:33 +0200
Subject: [PATCH] Add getAllValDataFruits function

---
 data.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/data.py b/data.py
index e23af56..ae0dbd1 100644
--- a/data.py
+++ b/data.py
@@ -54,6 +54,26 @@ def getAllValData(modelClass='cat'):  # retrieves random image and label set fro
     return image_ls, labels_ls, mask_ls
 
 
+def getAllValDataFruits(modelClass='cat'):  # retrieves random image and label set from specified dataset
+    trainData, validData = getDataSplit_Fruits(modelClass=modelClass)
+    basePath = os.path.dirname(
+        os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + modelClass
+    image_ls = []
+    labels_ls = []
+    mask_ls = []
+    choice_ls = []
+    for choice in validData:
+        with open(basePath + '/FPS_resized/' + choice[2]) as f:
+            labels = f.readline().split(' ')[1:19]
+            labels_ls.append(labels)
+        image = filePathToArray(basePath + '/RGB_resized/' + choice[0])
+        mask = filePathToArray(basePath + '/Instance_Mask_resized/' + choice[1])
+        image_ls.append(image)
+        mask_ls.append(mask)
+        choice_ls.append(choice[0])
+    return image_ls, labels_ls, mask_ls, choice_ls
+
+
 def getMasterList(basePath):  # returns list with image, mask, and label filenames
     imageList = sorted(os.listdir(basePath + '/rgb/'))
     maskList = sorted(os.listdir(basePath + '/mask/'))
-- 
GitLab