Skip to content
Snippets Groups Projects
Commit 2b5ffe1c authored by Guillaume Duret's avatar Guillaume Duret
Browse files

add path Fruits

parent 487d028a
No related branches found
No related tags found
No related merge requests found
......@@ -64,6 +64,15 @@ def getMasterList(basePath): # returns list with image, mask, and label filenam
return [[a, b, c] for a, b, c in zip(imageList, maskList, labelList)]
def getMasterList_Fruits(basePath): # returns list with image, mask, and label filenames
imageList = sorted(os.listdir(basePath + '/RGB_resized/'))
maskList = sorted(os.listdir(basePath + '/Instance_Mask_resized/'))
labelList = sorted(os.listdir(basePath + '/FPS_resized/'))
if len(imageList) != len(maskList) or len(imageList) != len(labelList):
raise Exception("image, mask, and label list lengths do not match.")
return [[a, b, c] for a, b, c in zip(imageList, maskList, labelList)]
def classTrainingGenerator(model, batchSize, masterList=None, height=480, width=640, augmentation=True,
**unused): # take input image, resize and store as rgb, create mask training data
basePath = os.path.dirname(os.path.realpath(__file__)) + '/LINEMOD/' + model
......@@ -284,6 +293,29 @@ def getDataSplit(genNew=False, split=.8, modelClass='cat'):
splitDict = pickle.load(f)
return splitDict["trainData"], splitDict["validData"]
def getDataSplit_Fruits(genNew=False, split=.8, modelClass='cat'):
# access training data, get jpeg, mask, label filenames split into training / validation sets
if genNew: # create split
basePathTraining = f'/home/gduret/Documents/guimod/Generated_Worlds_Training/{modelClass}' # os.path.dirname(os.path.realpath(__file__)) + '/LINEMOD/' + modelClass
basePathEvaluating = f'/home/gduret/Documents/guimod/Generated_Worlds_Evaluating/{modelClass}'
masterList_Training = getMasterList_Fruits(basePathTraining)
masterList_Evalution = getMasterList_Fruits(basePathEvaluating)
random.shuffle(masterList_Training)
random.shuffle(masterList_Evalution)
#splitPoint = round(len(masterList) * split)
splitDict = {"trainData": masterList_Training, "validData": masterList_Evalution}
with open("{0}_trainSplit".format(modelClass), 'wb') as f:
pickle.dump(splitDict, f)
else: # load saved split
with open("{0}_trainSplit".format(modelClass), 'rb') as f:
splitDict = pickle.load(f)
return splitDict["trainData"], splitDict["validData"]
def genAltLabels(p3dOld, p3dNew, matrix=np.array([[572.4114, 0., 325.2611], [0., 573.57043, 242.04899], [0., 0., 1.]]),
method=cv2.SOLVEPNP_ITERATIVE, modelClass='cat', height=480, width=640,
......
......@@ -391,6 +391,93 @@ def trainModel(modelStruct, modelGen, modelClass='cat', batchSize=2, optimizer=t
return model
def trainModel_Fruits(modelStruct, modelGen, modelClass='cat', batchSize=2, optimizer=tf.keras.optimizers.Adam,
learning_rate=0.01, losses=None, metrics=None, saveModel=True, modelName='stvNet_weights',
epochs=1, loss_weights=None, outVectors=False, outClasses=False, dataSplit=True, altLabels=True,
augmentation=True): # train and save model weights
if metrics is None:
metrics = ['accuracy']
if not (outVectors or outClasses):
print("At least one of outVectors or outClasses must be set to True.")
return
model = modelStruct(outVectors=outVectors, outClasses=outClasses, modelName=modelName)
model.summary()
model.compile(optimizer=optimizer(learning_rate=learning_rate), loss=losses, metrics=metrics,
loss_weights=loss_weights)
trainData, validData = None, None
if dataSplit: # if using datasplit, otherwise all available data is used
trainData, validData = data.getDataSplit_Fruits(modelClass=modelClass)
logger = tf.keras.callbacks.CSVLogger("models/history/" + modelName + "_" + modelClass + "_history.csv",
append=True)
# evalLogger = tf.keras.callbacks.CSVLogger("models/history/" + modelName + "_" + modelClass + "_eval_history.csv", append = True)
history, valHistory = [], []
if type(losses) is dict:
outKeys = list(losses.keys())
if len(outKeys) == 2: # combined output
for i in range(epochs):
print("Epoch {0} of {1}".format(i + 1, epochs))
hist = model.fit(modelGen(modelClass, batchSize, masterList=trainData, out0=outKeys[0], out1=outKeys[1],
altLabels=altLabels, augmentation=augmentation),
steps_per_epoch=math.ceil(len(trainData) / batchSize), max_queue_size=2,
callbacks=[logger])
history.append(hist.history)
if dataSplit:
print("Validation:")
valHist = model.evaluate(
modelGen(modelClass, batchSize, masterList=validData, out0=outKeys[0], out1=outKeys[1],
altLabels=altLabels, augmentation=False), steps=math.ceil(len(validData) / batchSize),
max_queue_size=2)
valHistory.append(valHist)
else:
raise Exception("Probably shouldn't be here ever..")
else:
for i in range(epochs):
print("Epoch {0} of {1}".format(i + 1, epochs))
hist = model.fit(
modelGen(modelClass, batchSize, masterList=trainData, altLabels=altLabels, augmentation=augmentation),
steps_per_epoch=math.ceil(len(trainData) / batchSize), max_queue_size=2, callbacks=[logger])
history.append(hist.history)
if dataSplit:
print("Validation:")
valHist = model.evaluate(
modelGen(modelClass, batchSize, masterList=validData, altLabels=altLabels, augmentation=False),
steps=math.ceil(len(validData) / batchSize), max_queue_size=2)
valHistory.append(valHist)
historyLog = {"struct": modelStruct.__name__,
"class": modelClass,
"optimizer": optimizer,
"lr": learning_rate,
"losses": losses,
"name": modelName,
"epochs": epochs,
"history": history,
"evalHistory": valHistory,
"timestamp": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
}
if saveModel:
model.save_weights(os.path.dirname(os.path.realpath(__file__)) + '/models/' + modelName + '_' + modelClass)
model.save(os.path.dirname(os.path.realpath(__file__)) + '/models/' + modelName + '_' + modelClass)
if not os.path.exists("models/history/" + modelName + '_trainHistory'):
with open("models/history/" + modelName + '_' + modelClass + '_trainHistory',
'wb') as f: # create model history
pickle.dump([], f)
with open("models/history/" + modelName + '_' + modelClass + '_trainHistory', 'rb') as f: # loading old history
histories = pickle.load(f)
histories.append(historyLog)
with open("models/history/" + modelName + '_' + modelClass + '_trainHistory',
'wb') as f: # saving the history of the model
pickle.dump(histories, f)
return model
def trainModels(modelSets, shutDown=False):
for modelSet in modelSets:
print("Training {0}".format(modelSet.name))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment