Commit e275fe64 authored by Guillaume Duret

make path flexible for pvnet

parent 7e8ff8dc
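This change removes the dataset locations that were hard-coded relative to the source file (os.path.dirname(os.path.realpath(__file__)) + '/Generated_Worlds_/...') and instead threads a base path plus training and evaluation folder names through the data generators, the training entry points in models.py, and the launch script. A minimal invocation sketch follows; the /data/Generated_Worlds_ location is illustrative, while the folder names mirror the previously hard-coded Generated_Worlds_Training and Generated_Worlds_Evaluating layout:

python models.py -cls_name kiwi1 \
    --path_base /data/Generated_Worlds_ \
    --training_folder Generated_Worlds_Training \
    --evaluation_folder Generated_Worlds_Evaluating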
@@ -54,10 +54,10 @@ def getAllValData(modelClass='cat'): # retrieves random image and label set fro
     return image_ls, labels_ls, mask_ls
-def getAllValDataFruits(modelClass='cat'): # retrieves random image and label set from specified dataset
-    trainData, validData = getDataSplit_Fruits(modelClass=modelClass)
-    basePath = os.path.dirname(
-        os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + modelClass
+def getAllValDataFruits(base_path, training_folder, evaluation_folder, modelClass='cat'): # retrieves random image and label set from specified dataset
+    trainData, validData = getDataSplit_Fruits(base_path, training_folder, evaluation_folder, modelClass=modelClass)
+    basePath = f'{base_path}/{evaluation_folder}/{modelClass}' #os.path.dirname(
+    #os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + modelClass
     image_ls = []
     labels_ls = []
     mask_ls = []
@@ -94,12 +94,15 @@ def getMasterList_Fruits(basePath): # returns list with image, mask, and label
     return [[a, b, c] for a, b, c in zip(imageList, maskList, labelList)]
-def classTrainingGenerator(model, batchSize, masterList=None, height=480, width=640, augmentation=True, val=False,
+def classTrainingGenerator(base_path, training_folder, evaluation_folder, model, batchSize, masterList=None, height=480, width=640, augmentation=True, val=False,
                            **unused): # take input image, resize and store as rgb, create mask training data
-    basePath = os.path.dirname(os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Training/' + model
+    #basePath = os.path.dirname(os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Training/' + model
+    basePath = f'{base_path}/{training_folder}/{model}'
     if val:
-        basePath = os.path.dirname(
-            os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + model
+        basePath = f'{base_path}/{evaluation_folder}/{model}' # os.path.dirname(
+        # os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + model
+    print("basePath : ", basePath)
     if masterList is None:
         masterList = getMasterList_Fruits(basePath)
@@ -137,14 +140,15 @@ def classTrainingGenerator(model, batchSize, masterList=None, height=480, width=
         yield np.array(xBatch), np.array(yClassBatch)
-def coordsTrainingGenerator(model, batchSize, masterList=None, val=False, height=480, width=640, augmentation=True,
+def coordsTrainingGenerator(base_path, training_folder, evaluation_folder, model, batchSize, masterList=None, val=False, height=480, width=640, augmentation=True,
                             altLabels=True): # takes input image and generates unit vector training data
     print(f"-------- {batchSize}----------")
-    basePath = os.path.dirname(os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Training/' + model
+    #basePath = os.path.dirname(os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Training/' + model
+    basePath = f'{base_path}/{training_folder}/{model}'
     if val:
-        basePath = os.path.dirname(
-            os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + model
+        basePath = f'{base_path}/{evaluation_folder}/{model}' # os.path.dirname(
+        #os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + model
     if masterList == None:
         masterList = getMasterList_Fruits(basePath)
@@ -187,13 +191,14 @@ def coordsTrainingGenerator(model, batchSize, masterList=None, val=False, height
         yield np.array(xBatch), np.array(yCoordBatch)
-def combinedTrainingGenerator(model, batchSize, masterList=None, val=False, height=480, width=640, out0='activation_9',
+def combinedTrainingGenerator(base_path, training_folder, evaluation_folder, model, batchSize, masterList=None, val=False, height=480, width=640, out0='activation_9',
                              out1='activation_10', augmentation=True,
                              altLabels=True): # take input image, resize and store as rgb, create training data
-    basePath = os.path.dirname(os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Training/' + model
+    #basePath = os.path.dirname(os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Training/' + model
+    basePath = f'{base_path}/{training_folder}/{model}'
     if val:
-        basePath = os.path.dirname(
-            os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + model
+        basePath = f'{base_path}/{evaluation_folder}/{model}' #os.path.dirname(
+        #os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + model
     if masterList is None:
         masterList = getMasterList_Fruits(basePath)
@@ -327,12 +332,14 @@ def getDataSplit(genNew=False, split=.8, modelClass='cat'):
     return splitDict["trainData"], splitDict["validData"]
-def getDataSplit_Fruits(genNew=True, split=.8, modelClass='cat'):
+def getDataSplit_Fruits(base_path, training_folder, evaluation_folder, genNew=True, split=.8, modelClass='cat'):
     # access training data, get jpeg, mask, label filenames split into training / validation sets
     if genNew: # create split
-        basePathTraining = os.path.dirname(os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Training/' + modelClass
-        basePathEvaluating = os.path.dirname(os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + modelClass
+        #basePathTraining = os.path.dirname(os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Training/' + modelClass
+        basePathTraining = f'{base_path}/{training_folder}/{modelClass}'
+        #basePathEvaluating = os.path.dirname(os.path.realpath(__file__)) + '/Generated_Worlds_/Generated_Worlds_Evaluating/' + modelClass
+        basePathEvaluating = f'{base_path}/{evaluation_folder}/{modelClass}'
        masterList_Training = getMasterList_Fruits(basePathTraining)
        masterList_Evalution = getMasterList_Fruits(basePathEvaluating)
        random.shuffle(masterList_Training)
@@ -307,7 +307,7 @@ def uNet(inputShape=(480, 640, 3), outVectors=True, outClasses=True,
     return tf.keras.Model(inputs=[xIn], outputs=outputs, name=modelName)
-def trainModel_Fruits(modelStruct, modelGen, modelClass='cat', batchSize=2, optimizer=tf.keras.optimizers.Adam,
+def trainModel_Fruits(path_base, training_folder, evaluation_folder, modelStruct, modelGen, modelClass='cat', batchSize=2, optimizer=tf.keras.optimizers.Adam,
                       learning_rate=0.01, losses=None, metrics=None, saveModel=True, modelName='stvNet_weights',
                       epochs=1, loss_weights=None, outVectors=False, outClasses=False, dataSplit=True, altLabels=True,
                       augmentation=True): # train and save model weights
@@ -323,7 +323,7 @@ def trainModel_Fruits(modelStruct, modelGen, modelClass='cat', batchSize=2, opti
     trainData, validData = None, None
     if dataSplit: # if using datasplit, otherwise all available data is used
-        trainData, validData = data.getDataSplit_Fruits(modelClass=modelClass)
+        trainData, validData = data.getDataSplit_Fruits(path_base, training_folder, evaluation_folder, modelClass=modelClass)
     logger = tf.keras.callbacks.CSVLogger("models/history/" + modelName + "_" + modelClass + "_history.csv",
                                           append=True)
@@ -356,13 +356,13 @@ def trainModel_Fruits(modelStruct, modelGen, modelClass='cat', batchSize=2, opti
     for i in range(epochs):
         print("Epoch {0} of {1}".format(i + 1, epochs))
         hist = model.fit(
-            modelGen(modelClass, batchSize, masterList=trainData, altLabels=altLabels, augmentation=augmentation),
+            modelGen(path_base, training_folder, evaluation_folder, modelClass, batchSize, masterList=trainData, altLabels=altLabels, augmentation=augmentation),
             steps_per_epoch=math.ceil(len(trainData)), max_queue_size=2, callbacks=[logger])
         history.append(hist.history)
         if dataSplit:
             print("########## Validation: ############2")
             valHist = model.evaluate(
-                modelGen(modelClass, batchSize, val=True, masterList=validData, altLabels=altLabels,
+                modelGen(path_base, training_folder, evaluation_folder, modelClass, batchSize, val=True, masterList=validData, altLabels=altLabels,
                          augmentation=False),
                 steps=math.ceil(len(validData)), max_queue_size=2)
             valHistory.append(valHist)
@@ -412,7 +412,7 @@ def trainModel(modelStruct, modelGen, modelClass='cat', batchSize=2, optimizer=t
     trainData, validData = None, None
     if dataSplit: # if using datasplit, otherwise all available data is used
-        trainData, validData = data.getDataSplit_Fruits(modelClass=modelClass)
+        trainData, validData = data.getDataSplit(modelClass=modelClass)
     logger = tf.keras.callbacks.CSVLogger("models/history/" + modelName + "_" + modelClass + "_history.csv",
                                           append=True)
@@ -482,11 +482,11 @@ def trainModel(modelStruct, modelGen, modelClass='cat', batchSize=2, optimizer=t
     return model
-def trainModels(modelSets, shutDown=False):
+def trainModels(path_base, training_folder, evaluation_folder, modelSets, shutDown=False):
     for modelSet in modelSets:
         print("Training {0}".format(modelSet.name))
         model = modelsDict[modelSet.name]
-        trainModel_Fruits(model.structure, model.generator, modelClass=modelSet.modelClass, epochs=model.epochs,
+        trainModel_Fruits(path_base, training_folder, evaluation_folder, model.structure, model.generator, modelClass=modelSet.modelClass, epochs=model.epochs,
                           losses=model.losses, modelName=modelSet.name, outClasses=model.outClasses,
                           outVectors=model.outVectors, learning_rate=model.lr, metrics=model.metrics,
                           altLabels=model.altLabels, augmentation=model.augmentation)
@@ -686,12 +686,18 @@ if __name__ == "__main__":
     ap.add_argument("-cls_name", "--class_name", type=str,
                     default='kiwi1',
                     help="[kiwi1, pear2, banana1, orange, peach1]")
+    ap.add_argument("--path_base", type=str, required=True)
+    ap.add_argument("--training_folder", type=str, required=True)
+    ap.add_argument("--evaluation_folder", type=str, required=True)
     args = vars(ap.parse_args())
     class_name = args["class_name"]
+    path_base = args["path_base"]
+    training_folder = args["training_folder"]
+    evaluation_folder = args["evaluation_folder"]
     modelSets = [modelSet('uNet_classes', class_name), modelSet('stvNet_new_coords', class_name)]
-    trainModels(modelSets)
+    trainModels(path_base, training_folder, evaluation_folder, modelSets)
     evaluateModels(modelSets)
     loadHistories(modelSets)
     plotHistories(modelSets)
@@ -23,10 +23,10 @@ set -x
 cat_target=$1
 echo $cat_target
+path_base=$2
+training_folder=$3
+evaluation_folder=$4
 conda activate tf
-python models.py -cls_name $cat_target
+python models.py -cls_name $cat_target --path_base $path_base --training_folder $training_folder --evaluation_folder $evaluation_folder
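The wrapper script above now expects the same three values as extra positional arguments. A usage sketch, where run_training.sh stands in for the script's actual filename (not shown in this diff) and the paths reuse the illustrative values from the commit description above:

bash run_training.sh kiwi1 /data/Generated_Worlds_ Generated_Worlds_Training Generated_Worlds_Evaluating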