diff --git a/classes.py b/classes.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2d74f6a65ecbc0dc9e7763de85888190257d17 --- /dev/null +++ b/classes.py @@ -0,0 +1,52 @@ +import random + + +class modelSet: + + def __init__(self, modelName, modelClass='cat'): + self.name = modelName + self.modelClass = modelClass + + +class modelWrapper: # allows user to use single-net and double-net model interchangeably + + def __init__(self, models, altLabels): + if type(models) is dict: + if len(models.keys()) == 2: + self.classModel = models['classModel'] + self.vecModel = models['vecModel'] + self.combined = False + else: + raise Exception("Probably shouldn't be here ever..") + else: + self.combModel = models + self.combined = True + + self.altLabels = altLabels + + def genPredict(self, input): + if self.combined: + pred = self.combModel.predict(input) + return pred[0], pred[1] + else: + return self.vecModel.predict(input), self.classModel.predict(input) + + +class modelDictVal: + + def __init__(self, structure, generator, losses, outVectors, outClasses, epochs=3, lr=0.01, metrics=None, + outVecName=None, outClassName=None, altLabels=True, augmentation=True): + if metrics is None: + metrics = ['accuracy'] + self.structure = structure + self.generator = generator + self.losses = losses + self.outVectors = outVectors + self.outClasses = outClasses + self.epochs = epochs + self.metrics = metrics + self.lr = lr + self.outVecName = outVecName + self.outClassName = outClassName + self.altLabels = altLabels + self.augmentation = augmentation