refactoring
This commit is contained in:
@@ -9,9 +9,8 @@ class ModelFactory:
|
||||
predictionModelInputLayerName = "image"
|
||||
predictionModelOutputLayerName = "dense2"
|
||||
|
||||
def __init__(self, img_height, img_width, char_to_num):
|
||||
self.img_height = img_height
|
||||
self.img_width = img_width
|
||||
def __init__(self, captchaShape, char_to_num):
|
||||
self.captchaShape = captchaShape
|
||||
self.char_to_num = char_to_num
|
||||
|
||||
# see https://www.tensorflow.org/api_docs/python/tf/keras/applications/resnet/ResNet101
|
||||
@@ -52,9 +51,9 @@ class ModelFactory:
|
||||
def _createModel(self, baseModelFactory, preprocess_input, name):
|
||||
# Inputs to the model
|
||||
input_image = layers.Input(
|
||||
shape=(self.img_height, self.img_width, 3),
|
||||
name=ModelFactory.predictionModelInputLayerName,
|
||||
dtype="float32")
|
||||
shape = (self.captchaShape.height, self.captchaShape.width, 3),
|
||||
name = ModelFactory.predictionModelInputLayerName,
|
||||
dtype = "float32")
|
||||
labels = layers.Input(name="label", shape=(None,), dtype="float32")
|
||||
|
||||
image = preprocess_input(input_image)
|
||||
|
||||
Reference in New Issue
Block a user