refactoring
This commit is contained in:
@@ -33,34 +33,17 @@
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import tensorflow as tf\n",
|
||||
"from tensorflow import keras\n",
|
||||
"from tensorflow.keras import layers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "QB8QZJPg3MGI"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from captcha.GoogleDriveManager import GoogleDriveManager"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "C3bxU1US2blM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from captcha.CaptchaGenerator import CaptchaGenerator"
|
||||
"from captcha.GoogleDriveManager import GoogleDriveManager\n",
|
||||
"from captcha.CaptchaGenerator import CaptchaGenerator\n",
|
||||
"from captcha.CharNumConverter import CharNumConverter\n",
|
||||
"from captcha.DataSplitter import DataSplitter\n",
|
||||
"from captcha.DatasetFactory import DatasetFactory\n",
|
||||
"from captcha.ModelFactory import ModelFactory\n",
|
||||
"from captcha.PredictionsDecoder import PredictionsDecoder\n",
|
||||
"from captcha.ModelDAO import ModelDAO\n",
|
||||
"from captcha.CaptchaShape import CaptchaShape"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -78,63 +61,6 @@
|
||||
" return images, labels\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "sNJjugG83MGO"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from captcha.CharNumConverter import CharNumConverter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "qxs04OTR3MGP"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class DataSplitter:\n",
|
||||
"\n",
|
||||
" def __init__(self, x, y):\n",
|
||||
" (self.x_train, self.y_train), (x_valid_test, y_valid_test) = DataSplitter._splitData(np.array(x), np.array(y), train_size=0.7)\n",
|
||||
" (self.x_valid, self.y_valid), (self.x_test, self.y_test) = DataSplitter._splitData(x_valid_test, y_valid_test, train_size=0.5)\n",
|
||||
"\n",
|
||||
" def getTrain(self):\n",
|
||||
" return (self.x_train, self.y_train)\n",
|
||||
"\n",
|
||||
" def getValid(self):\n",
|
||||
" return (self.x_valid, self.y_valid)\n",
|
||||
"\n",
|
||||
" def getTest(self):\n",
|
||||
" return (self.x_test, self.y_test)\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def _splitData(x, y, train_size=0.9, shuffle=True):\n",
|
||||
" size = len(x)\n",
|
||||
" indices = np.arange(size)\n",
|
||||
" if shuffle:\n",
|
||||
" np.random.shuffle(indices)\n",
|
||||
" train_samples = int(size * train_size)\n",
|
||||
" x_train, y_train = x[indices[:train_samples]], y[indices[:train_samples]]\n",
|
||||
" x_test, y_test = x[indices[train_samples:]], y[indices[train_samples:]]\n",
|
||||
" return (x_train, y_train), (x_test, y_test)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "dAAACymS3MGR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from captcha.DatasetFactory import DatasetFactory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -209,17 +135,6 @@
|
||||
" ['green' if pred_text == orig_text else 'red' for (pred_text, orig_text) in zip(pred_texts, orig_texts)])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "st13jAjL3MGV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from captcha.ModelFactory import ModelFactory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -231,28 +146,6 @@
|
||||
" print(i, layer.name)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "B7GZlk2_3MGX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from captcha.PredictionsDecoder import PredictionsDecoder"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "8Oa7avYt3MGX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from captcha.ModelDAO import ModelDAO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -355,7 +248,6 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from captcha.CaptchaShape import CaptchaShape\n",
|
||||
"captchaShape = CaptchaShape()"
|
||||
]
|
||||
},
|
||||
|
||||
28
src/captcha/DataSplitter.py
Normal file
28
src/captcha/DataSplitter.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DataSplitter:
    """Split paired samples ``x`` / ``y`` into train, validation and test sets.

    The data is first split 70% / 30%; the 30% remainder is then split in
    half, giving roughly a 70 / 15 / 15 partition. Set sizes are truncated
    with ``int()``, so for small inputs the later sets absorb the remainder.

    Args:
        x: Sequence of samples; converted with ``np.array``.
        y: Sequence of labels, one per sample; converted with ``np.array``.
        seed: Optional seed for a private ``np.random.Generator``. When
            ``None`` (the default) the global NumPy random state is used.

    Raises:
        ValueError: If ``x`` and ``y`` do not have the same length.
    """

    def __init__(self, x, y, seed=None):
        # A single generator is shared by both splits so that one seed
        # determines the whole partition.
        rng = None if seed is None else np.random.default_rng(seed)
        (self.x_train, self.y_train), (x_valid_test, y_valid_test) = DataSplitter._splitData(
            np.array(x), np.array(y), train_size=0.7, rng=rng)
        (self.x_valid, self.y_valid), (self.x_test, self.y_test) = DataSplitter._splitData(
            x_valid_test, y_valid_test, train_size=0.5, rng=rng)

    def getTrain(self):
        """Return the training set as an ``(x, y)`` tuple."""
        return (self.x_train, self.y_train)

    def getValid(self):
        """Return the validation set as an ``(x, y)`` tuple."""
        return (self.x_valid, self.y_valid)

    def getTest(self):
        """Return the test set as an ``(x, y)`` tuple."""
        return (self.x_test, self.y_test)

    @staticmethod
    def _splitData(x, y, train_size=0.9, shuffle=True, rng=None):
        """Split index-aligned arrays ``x`` and ``y`` into two parts.

        Args:
            x: Array of samples supporting integer-array indexing.
            y: Array of labels with the same length as ``x``.
            train_size: Fraction of the samples placed in the first part.
            shuffle: Whether to permute the sample order before splitting.
            rng: Optional ``np.random.Generator`` used for shuffling; when
                ``None`` the global ``np.random`` state is used.

        Returns:
            ``((x_first, y_first), (x_second, y_second))``.

        Raises:
            ValueError: If ``x`` and ``y`` differ in length.
        """
        size = len(x)
        # Fail early with a clear message: a longer y would otherwise be
        # silently truncated and a shorter y would raise an IndexError below.
        if len(y) != size:
            raise ValueError(
                f"x and y must have the same length, got {size} and {len(y)}")
        indices = np.arange(size)
        if shuffle:
            if rng is None:
                np.random.shuffle(indices)
            else:
                rng.shuffle(indices)
        train_samples = int(size * train_size)
        first, second = indices[:train_samples], indices[train_samples:]
        return (x[first], y[first]), (x[second], y[second])
|
||||
Reference in New Issue
Block a user