refactoring

This commit is contained in:
frankknoll
2023-03-15 17:14:24 +01:00
parent 6330b7b724
commit a9e7bf4833
7 changed files with 68 additions and 82 deletions

View File

@@ -3,9 +3,8 @@ import tensorflow as tf
class DatasetFactory:
def __init__(self, img_height, img_width, char_to_num, batch_size):
self.img_height = img_height
self.img_width = img_width
def __init__(self, captchaShape, char_to_num, batch_size):
self.captchaShape = captchaShape
self.char_to_num = char_to_num
self.batch_size = batch_size
@@ -18,7 +17,7 @@ class DatasetFactory:
def _encode_single_sample(self, img_path, label):
img = tf.io.read_file(img_path)
img = tf.io.decode_jpeg(img, channels=3)
img = tf.image.resize(img, [self.img_height, self.img_width])
img = tf.image.resize(img, [self.captchaShape.height, self.captchaShape.width])
# Map the characters in label to numbers
label = self.char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
# Return a dict as our model is expecting two inputs