refactoring

This commit is contained in:
frankknoll
2023-03-15 17:28:11 +01:00
parent a9e7bf4833
commit 5379644a89
2 changed files with 37 additions and 117 deletions

View File

@@ -33,34 +33,17 @@
"source": [
"import os\n",
"import numpy as np\n",
"\n",
"from pathlib import Path\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "QB8QZJPg3MGI"
},
"outputs": [],
"source": [
"from captcha.GoogleDriveManager import GoogleDriveManager"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "C3bxU1US2blM"
},
"outputs": [],
"source": [
"from captcha.CaptchaGenerator import CaptchaGenerator"
"from captcha.GoogleDriveManager import GoogleDriveManager\n",
"from captcha.CaptchaGenerator import CaptchaGenerator\n",
"from captcha.CharNumConverter import CharNumConverter\n",
"from captcha.DataSplitter import DataSplitter\n",
"from captcha.DatasetFactory import DatasetFactory\n",
"from captcha.ModelFactory import ModelFactory\n",
"from captcha.PredictionsDecoder import PredictionsDecoder\n",
"from captcha.ModelDAO import ModelDAO\n",
"from captcha.CaptchaShape import CaptchaShape"
]
},
{
@@ -78,63 +61,6 @@
" return images, labels\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sNJjugG83MGO"
},
"outputs": [],
"source": [
"from captcha.CharNumConverter import CharNumConverter"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qxs04OTR3MGP"
},
"outputs": [],
"source": [
"class DataSplitter:\n",
"\n",
" def __init__(self, x, y):\n",
" (self.x_train, self.y_train), (x_valid_test, y_valid_test) = DataSplitter._splitData(np.array(x), np.array(y), train_size=0.7)\n",
" (self.x_valid, self.y_valid), (self.x_test, self.y_test) = DataSplitter._splitData(x_valid_test, y_valid_test, train_size=0.5)\n",
"\n",
" def getTrain(self):\n",
" return (self.x_train, self.y_train)\n",
"\n",
" def getValid(self):\n",
" return (self.x_valid, self.y_valid)\n",
"\n",
" def getTest(self):\n",
" return (self.x_test, self.y_test)\n",
"\n",
" @staticmethod\n",
" def _splitData(x, y, train_size=0.9, shuffle=True):\n",
" size = len(x)\n",
" indices = np.arange(size)\n",
" if shuffle:\n",
" np.random.shuffle(indices)\n",
" train_samples = int(size * train_size)\n",
" x_train, y_train = x[indices[:train_samples]], y[indices[:train_samples]]\n",
" x_test, y_test = x[indices[train_samples:]], y[indices[train_samples:]]\n",
" return (x_train, y_train), (x_test, y_test)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dAAACymS3MGR"
},
"outputs": [],
"source": [
"from captcha.DatasetFactory import DatasetFactory"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -209,17 +135,6 @@
" ['green' if pred_text == orig_text else 'red' for (pred_text, orig_text) in zip(pred_texts, orig_texts)])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "st13jAjL3MGV"
},
"outputs": [],
"source": [
"from captcha.ModelFactory import ModelFactory"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -231,28 +146,6 @@
" print(i, layer.name)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "B7GZlk2_3MGX"
},
"outputs": [],
"source": [
"from captcha.PredictionsDecoder import PredictionsDecoder"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8Oa7avYt3MGX"
},
"outputs": [],
"source": [
"from captcha.ModelDAO import ModelDAO"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -355,7 +248,6 @@
"metadata": {},
"outputs": [],
"source": [
"from captcha.CaptchaShape import CaptchaShape\n",
"captchaShape = CaptchaShape()"
]
},

View File

@@ -0,0 +1,28 @@
import numpy as np
class DataSplitter:
def __init__(self, x, y):
(self.x_train, self.y_train), (x_valid_test, y_valid_test) = DataSplitter._splitData(np.array(x), np.array(y), train_size=0.7)
(self.x_valid, self.y_valid), (self.x_test, self.y_test) = DataSplitter._splitData(x_valid_test, y_valid_test, train_size=0.5)
def getTrain(self):
return (self.x_train, self.y_train)
def getValid(self):
return (self.x_valid, self.y_valid)
def getTest(self):
return (self.x_test, self.y_test)
@staticmethod
def _splitData(x, y, train_size=0.9, shuffle=True):
size = len(x)
indices = np.arange(size)
if shuffle:
np.random.shuffle(indices)
train_samples = int(size * train_size)
x_train, y_train = x[indices[:train_samples]], y[indices[:train_samples]]
x_test, y_test = x[indices[train_samples:]], y[indices[train_samples:]]
return (x_train, y_train), (x_test, y_test)