From 2b5ffe1c6571d67864b0d0f0c36046182387fe68 Mon Sep 17 00:00:00 2001 From: Gduret <guillaume.duret@ec-lyon.fr> Date: Fri, 3 Mar 2023 00:27:10 +0100 Subject: [PATCH] add path Fruits --- data.py | 32 ++++++++++++++++++++ models.py | 87 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/data.py b/data.py index ff4a828..2406026 100644 --- a/data.py +++ b/data.py @@ -64,6 +64,15 @@ def getMasterList(basePath): # returns list with image, mask, and label filenam return [[a, b, c] for a, b, c in zip(imageList, maskList, labelList)] +def getMasterList_Fruits(basePath): # returns list with image, mask, and label filenames + imageList = sorted(os.listdir(basePath + '/RGB_resized/')) + maskList = sorted(os.listdir(basePath + '/Instance_Mask_resized/')) + labelList = sorted(os.listdir(basePath + '/FPS_resized/')) + if len(imageList) != len(maskList) or len(imageList) != len(labelList): + raise Exception("image, mask, and label list lengths do not match.") + + return [[a, b, c] for a, b, c in zip(imageList, maskList, labelList)] + def classTrainingGenerator(model, batchSize, masterList=None, height=480, width=640, augmentation=True, **unused): # take input image, resize and store as rgb, create mask training data basePath = os.path.dirname(os.path.realpath(__file__)) + '/LINEMOD/' + model @@ -284,6 +293,29 @@ def getDataSplit(genNew=False, split=.8, modelClass='cat'): splitDict = pickle.load(f) return splitDict["trainData"], splitDict["validData"] +def getDataSplit_Fruits(genNew=False, split=.8, modelClass='cat'): + # access training data, get jpeg, mask, label filenames split into training / validation sets + if genNew: # create split + basePathTraining = f'/home/gduret/Documents/guimod/Generated_Worlds_Training/{modelClass}' # os.path.dirname(os.path.realpath(__file__)) + '/LINEMOD/' + modelClass + basePathEvaluating = f'/home/gduret/Documents/guimod/Generated_Worlds_Evaluating/{modelClass}' + masterList_Training = getMasterList_Fruits(basePathTraining) + masterList_Evalution = getMasterList_Fruits(basePathEvaluating) + random.shuffle(masterList_Training) + random.shuffle(masterList_Evalution) + + #splitPoint = round(len(masterList) * split) + + splitDict = {"trainData": masterList_Training, "validData": masterList_Evalution} + + with open("{0}_trainSplit".format(modelClass), 'wb') as f: + pickle.dump(splitDict, f) + + else: # load saved split + with open("{0}_trainSplit".format(modelClass), 'rb') as f: + splitDict = pickle.load(f) + return splitDict["trainData"], splitDict["validData"] + + def genAltLabels(p3dOld, p3dNew, matrix=np.array([[572.4114, 0., 325.2611], [0., 573.57043, 242.04899], [0., 0., 1.]]), method=cv2.SOLVEPNP_ITERATIVE, modelClass='cat', height=480, width=640, diff --git a/models.py b/models.py index 6c9e9bf..4934cbd 100644 --- a/models.py +++ b/models.py @@ -391,6 +391,93 @@ def trainModel(modelStruct, modelGen, modelClass='cat', batchSize=2, optimizer=t return model + +def trainModel_Fruits(modelStruct, modelGen, modelClass='cat', batchSize=2, optimizer=tf.keras.optimizers.Adam, + learning_rate=0.01, losses=None, metrics=None, saveModel=True, modelName='stvNet_weights', + epochs=1, loss_weights=None, outVectors=False, outClasses=False, dataSplit=True, altLabels=True, + augmentation=True): # train and save model weights + if metrics is None: + metrics = ['accuracy'] + if not (outVectors or outClasses): + print("At least one of outVectors or outClasses must be set to True.") + return + model = modelStruct(outVectors=outVectors, outClasses=outClasses, modelName=modelName) + model.summary() + model.compile(optimizer=optimizer(learning_rate=learning_rate), loss=losses, metrics=metrics, + loss_weights=loss_weights) + + trainData, validData = None, None + if dataSplit: # if using datasplit, otherwise all available data is used + trainData, validData = data.getDataSplit_Fruits(modelClass=modelClass) + + logger = tf.keras.callbacks.CSVLogger("models/history/" + modelName + "_" + modelClass + "_history.csv", + append=True) + # evalLogger = tf.keras.callbacks.CSVLogger("models/history/" + modelName + "_" + modelClass + "_eval_history.csv", append = True) + + history, valHistory = [], [] + + if type(losses) is dict: + outKeys = list(losses.keys()) + if len(outKeys) == 2: # combined output + for i in range(epochs): + print("Epoch {0} of {1}".format(i + 1, epochs)) + hist = model.fit(modelGen(modelClass, batchSize, masterList=trainData, out0=outKeys[0], out1=outKeys[1], + altLabels=altLabels, augmentation=augmentation), + steps_per_epoch=math.ceil(len(trainData) / batchSize), max_queue_size=2, + callbacks=[logger]) + history.append(hist.history) + if dataSplit: + print("Validation:") + valHist = model.evaluate( + modelGen(modelClass, batchSize, masterList=validData, out0=outKeys[0], out1=outKeys[1], + altLabels=altLabels, augmentation=False), steps=math.ceil(len(validData) / batchSize), + max_queue_size=2) + valHistory.append(valHist) + else: + raise Exception("Probably shouldn't be here ever..") + else: + for i in range(epochs): + print("Epoch {0} of {1}".format(i + 1, epochs)) + hist = model.fit( + modelGen(modelClass, batchSize, masterList=trainData, altLabels=altLabels, augmentation=augmentation), + steps_per_epoch=math.ceil(len(trainData) / batchSize), max_queue_size=2, callbacks=[logger]) + history.append(hist.history) + if dataSplit: + print("Validation:") + valHist = model.evaluate( + modelGen(modelClass, batchSize, masterList=validData, altLabels=altLabels, augmentation=False), + steps=math.ceil(len(validData) / batchSize), max_queue_size=2) + valHistory.append(valHist) + + historyLog = {"struct": modelStruct.__name__, + "class": modelClass, + "optimizer": optimizer, + "lr": learning_rate, + "losses": losses, + "name": modelName, + "epochs": epochs, + "history": history, + "evalHistory": valHistory, + "timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"), + } + + if saveModel: + model.save_weights(os.path.dirname(os.path.realpath(__file__)) + '/models/' + modelName + '_' + modelClass) + model.save(os.path.dirname(os.path.realpath(__file__)) + '/models/' + modelName + '_' + modelClass) + if not os.path.exists("models/history/" + modelName + '_trainHistory'): + with open("models/history/" + modelName + '_' + modelClass + '_trainHistory', + 'wb') as f: # create model history + pickle.dump([], f) + with open("models/history/" + modelName + '_' + modelClass + '_trainHistory', 'rb') as f: # loading old history + histories = pickle.load(f) + histories.append(historyLog) + with open("models/history/" + modelName + '_' + modelClass + '_trainHistory', + 'wb') as f: # saving the history of the model + pickle.dump(histories, f) + + return model + + def trainModels(modelSets, shutDown=False): for modelSet in modelSets: print("Training {0}".format(modelSet.name)) -- GitLab