From 2b5ffe1c6571d67864b0d0f0c36046182387fe68 Mon Sep 17 00:00:00 2001
From: Gduret <guillaume.duret@ec-lyon.fr>
Date: Fri, 3 Mar 2023 00:27:10 +0100
Subject: [PATCH] Add Fruits dataset paths and training helpers

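Add Fruits-dataset variants of the existing data and training helpers:
getMasterList_Fruits and getDataSplit_Fruits in data.py, and
trainModel_Fruits in models.py. Unlike the LINEMOD pipeline, the
training and validation sets are not split from a single directory:
they are read from the separate Generated_Worlds_Training and
Generated_Worlds_Evaluating directories (currently machine-specific
paths).

A minimal usage sketch (hypothetical: 'pear', models.stvNet, and
fruitsGenerator are placeholders, not part of this patch):

    import data, models

    # build and cache the Fruits train/valid split
    data.getDataSplit_Fruits(genNew=True, modelClass='pear')

    # train with any generator accepting (modelClass, batchSize,
    # masterList=..., altLabels=..., augmentation=...)
    models.trainModel_Fruits(models.stvNet, fruitsGenerator,
                             modelClass='pear', losses='mse',
                             outClasses=True, epochs=1)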
---
 data.py   | 33 +++++++++++++++++++++++++++++++++
 models.py | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 121 insertions(+)

diff --git a/data.py b/data.py
index ff4a828..2406026 100644
--- a/data.py
+++ b/data.py
@@ -64,6 +64,16 @@ def getMasterList(basePath):  # returns list with image, mask, and label filenam
     return [[a, b, c] for a, b, c in zip(imageList, maskList, labelList)]
 
 
+def getMasterList_Fruits(basePath):  # returns list with image, mask, and label filenames
+    imageList = sorted(os.listdir(basePath + '/RGB_resized/'))
+    maskList = sorted(os.listdir(basePath + '/Instance_Mask_resized/'))
+    labelList = sorted(os.listdir(basePath + '/FPS_resized/'))
+    if len(imageList) != len(maskList) or len(imageList) != len(labelList):
+        raise Exception("image, mask, and label list lengths do not match.")
+
+    return [[a, b, c] for a, b, c in zip(imageList, maskList, labelList)]
+
+
 def classTrainingGenerator(model, batchSize, masterList=None, height=480, width=640, augmentation=True,
                            **unused):  # take input image, resize and store as rgb, create mask training data
     basePath = os.path.dirname(os.path.realpath(__file__)) + '/LINEMOD/' + model
@@ -284,6 +294,29 @@ def getDataSplit(genNew=False, split=.8, modelClass='cat'):
             splitDict = pickle.load(f)
     return splitDict["trainData"], splitDict["validData"]
 
+
+def getDataSplit_Fruits(genNew=False, split=.8, modelClass='cat'):
+    # access Fruits data: image, mask, and label filenames for the training and validation sets
+    if genNew:  # create split
+        basePathTraining = f'/home/gduret/Documents/guimod/Generated_Worlds_Training/{modelClass}'  # os.path.dirname(os.path.realpath(__file__)) + '/LINEMOD/' + modelClass
+        basePathEvaluating = f'/home/gduret/Documents/guimod/Generated_Worlds_Evaluating/{modelClass}'
+        masterList_Training = getMasterList_Fruits(basePathTraining)
+    masterList_Evaluation = getMasterList_Fruits(basePathEvaluating)
+    random.shuffle(masterList_Training)
+    random.shuffle(masterList_Evaluation)
+
+        # note: 'split' is unused here; train/valid data come from the separate directories above
+
+    splitDict = {"trainData": masterList_Training, "validData": masterList_Evaluation}
+
+        with open("{0}_trainSplit".format(modelClass), 'wb') as f:
+            pickle.dump(splitDict, f)
+
+    else:  # load saved split
+        with open("{0}_trainSplit".format(modelClass), 'rb') as f:
+            splitDict = pickle.load(f)
+    return splitDict["trainData"], splitDict["validData"]
+
 
 def genAltLabels(p3dOld, p3dNew, matrix=np.array([[572.4114, 0., 325.2611], [0., 573.57043, 242.04899], [0., 0., 1.]]),
                  method=cv2.SOLVEPNP_ITERATIVE, modelClass='cat', height=480, width=640,
diff --git a/models.py b/models.py
index 6c9e9bf..4934cbd 100644
--- a/models.py
+++ b/models.py
@@ -391,6 +391,94 @@ def trainModel(modelStruct, modelGen, modelClass='cat', batchSize=2, optimizer=t
     return model
 
 
+def trainModel_Fruits(modelStruct, modelGen, modelClass='cat', batchSize=2, optimizer=tf.keras.optimizers.Adam,
+                      learning_rate=0.01, losses=None, metrics=None, saveModel=True, modelName='stvNet_weights',
+                      epochs=1, loss_weights=None, outVectors=False, outClasses=False, dataSplit=True, altLabels=True,
+                      augmentation=True):  # train and save model weights
+    if metrics is None:
+        metrics = ['accuracy']
+    if not (outVectors or outClasses):
+        print("At least one of outVectors or outClasses must be set to True.")
+        return
+    model = modelStruct(outVectors=outVectors, outClasses=outClasses, modelName=modelName)
+    model.summary()
+    model.compile(optimizer=optimizer(learning_rate=learning_rate), loss=losses, metrics=metrics,
+                  loss_weights=loss_weights)
+
+    trainData, validData = None, None
+    if dataSplit:  # if using datasplit, otherwise all available data is used
+        trainData, validData = data.getDataSplit_Fruits(modelClass=modelClass)
+
+    logger = tf.keras.callbacks.CSVLogger("models/history/" + modelName + "_" + modelClass + "_history.csv",
+                                          append=True)
+    # evalLogger = tf.keras.callbacks.CSVLogger("models/history/" + modelName + "_" + modelClass + "_eval_history.csv", append = True)
+
+    history, valHistory = [], []
+
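+    # a dict of losses implies one loss per named model output (combined vector and class heads)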
+    if isinstance(losses, dict):
+        outKeys = list(losses.keys())
+        if len(outKeys) == 2:  # combined output
+            for i in range(epochs):
+                print("Epoch {0} of {1}".format(i + 1, epochs))
+                hist = model.fit(modelGen(modelClass, batchSize, masterList=trainData, out0=outKeys[0], out1=outKeys[1],
+                                          altLabels=altLabels, augmentation=augmentation),
+                                 steps_per_epoch=math.ceil(len(trainData) / batchSize), max_queue_size=2,
+                                 callbacks=[logger])
+                history.append(hist.history)
+                if dataSplit:
+                    print("Validation:")
+                    valHist = model.evaluate(
+                        modelGen(modelClass, batchSize, masterList=validData, out0=outKeys[0], out1=outKeys[1],
+                                 altLabels=altLabels, augmentation=False), steps=math.ceil(len(validData) / batchSize),
+                        max_queue_size=2)
+                    valHistory.append(valHist)
+        else:
+            raise Exception("Unexpected loss configuration: expected exactly two output keys in 'losses'.")
+    else:
+        for i in range(epochs):
+            print("Epoch {0} of {1}".format(i + 1, epochs))
+            hist = model.fit(
+                modelGen(modelClass, batchSize, masterList=trainData, altLabels=altLabels, augmentation=augmentation),
+                steps_per_epoch=math.ceil(len(trainData) / batchSize), max_queue_size=2, callbacks=[logger])
+            history.append(hist.history)
+            if dataSplit:
+                print("Validation:")
+                valHist = model.evaluate(
+                    modelGen(modelClass, batchSize, masterList=validData, altLabels=altLabels, augmentation=False),
+                    steps=math.ceil(len(validData) / batchSize), max_queue_size=2)
+                valHistory.append(valHist)
+
+    historyLog = {"struct": modelStruct.__name__,
+                  "class": modelClass,
+                  "optimizer": optimizer,
+                  "lr": learning_rate,
+                  "losses": losses,
+                  "name": modelName,
+                  "epochs": epochs,
+                  "history": history,
+                  "evalHistory": valHistory,
+                  "timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
+                  }
+
+    if saveModel:
+        model.save_weights(os.path.dirname(os.path.realpath(__file__)) + '/models/' + modelName + '_' + modelClass)
+        model.save(os.path.dirname(os.path.realpath(__file__)) + '/models/' + modelName + '_' + modelClass)
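+        # seed an empty history file on first save for this model and class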
+        if not os.path.exists("models/history/" + modelName + '_' + modelClass + '_trainHistory'):
+            with open("models/history/" + modelName + '_' + modelClass + '_trainHistory',
+                      'wb') as f:  # create model history
+                pickle.dump([], f)
+        with open("models/history/" + modelName + '_' + modelClass + '_trainHistory', 'rb') as f:  # loading old history
+            histories = pickle.load(f)
+        histories.append(historyLog)
+        with open("models/history/" + modelName + '_' + modelClass + '_trainHistory',
+                  'wb') as f:  # saving the history of the model
+            pickle.dump(histories, f)
+
+    return model
+
+
 def trainModels(modelSets, shutDown=False):
     for modelSet in modelSets:
         print("Training {0}".format(modelSet.name))
-- 
GitLab