Skip to content
Snippets Groups Projects
Commit 9b4bc38c authored by Mahmoud Ahmed Ali's avatar Mahmoud Ahmed Ali
Browse files

Add classes file

parent b8fa59e3
No related branches found
No related tags found
No related merge requests found
import random
class modelSet:
def __init__(self, modelName, modelClass='cat'):
self.name = modelName
self.modelClass = modelClass
class modelWrapper: # allows user to use single-net and double-net model interchangeably
def __init__(self, models, altLabels):
if type(models) is dict:
if len(models.keys()) == 2:
self.classModel = models['classModel']
self.vecModel = models['vecModel']
self.combined = False
else:
raise Exception("Probably shouldn't be here ever..")
else:
self.combModel = models
self.combined = True
self.altLabels = altLabels
def genPredict(self, input):
if self.combined:
pred = self.combModel.predict(input)
return pred[0], pred[1]
else:
return self.vecModel.predict(input), self.classModel.predict(input)
class modelDictVal:
def __init__(self, structure, generator, losses, outVectors, outClasses, epochs=3, lr=0.01, metrics=None,
outVecName=None, outClassName=None, altLabels=True, augmentation=True):
if metrics is None:
metrics = ['accuracy']
self.structure = structure
self.generator = generator
self.losses = losses
self.outVectors = outVectors
self.outClasses = outClasses
self.epochs = epochs
self.metrics = metrics
self.lr = lr
self.outVecName = outVecName
self.outClassName = outClassName
self.altLabels = altLabels
self.augmentation = augmentation
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment