From 5379644a89b4ec864069340aeb0386a5dbe593fe Mon Sep 17 00:00:00 2001 From: frankknoll Date: Wed, 15 Mar 2023 17:28:11 +0100 Subject: [PATCH] refactoring --- src/captcha.ipynb | 126 +++--------------------------------- src/captcha/DataSplitter.py | 28 ++++++++ 2 files changed, 37 insertions(+), 117 deletions(-) create mode 100644 src/captcha/DataSplitter.py diff --git a/src/captcha.ipynb b/src/captcha.ipynb index c6ea0f2921a..591152fd8bc 100644 --- a/src/captcha.ipynb +++ b/src/captcha.ipynb @@ -33,34 +33,17 @@ "source": [ "import os\n", "import numpy as np\n", - "\n", "from pathlib import Path\n", - "\n", "import tensorflow as tf\n", - "from tensorflow import keras\n", - "from tensorflow.keras import layers" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "QB8QZJPg3MGI" - }, - "outputs": [], - "source": [ - "from captcha.GoogleDriveManager import GoogleDriveManager" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "C3bxU1US2blM" - }, - "outputs": [], - "source": [ - "from captcha.CaptchaGenerator import CaptchaGenerator" + "from captcha.GoogleDriveManager import GoogleDriveManager\n", + "from captcha.CaptchaGenerator import CaptchaGenerator\n", + "from captcha.CharNumConverter import CharNumConverter\n", + "from captcha.DataSplitter import DataSplitter\n", + "from captcha.DatasetFactory import DatasetFactory\n", + "from captcha.ModelFactory import ModelFactory\n", + "from captcha.PredictionsDecoder import PredictionsDecoder\n", + "from captcha.ModelDAO import ModelDAO\n", + "from captcha.CaptchaShape import CaptchaShape" ] }, { @@ -78,63 +61,6 @@ " return images, labels\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sNJjugG83MGO" - }, - "outputs": [], - "source": [ - "from captcha.CharNumConverter import CharNumConverter" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "qxs04OTR3MGP" - }, - "outputs": [], - "source": [ - "class DataSplitter:\n", - "\n", - " def __init__(self, x, y):\n", - " (self.x_train, self.y_train), (x_valid_test, y_valid_test) = DataSplitter._splitData(np.array(x), np.array(y), train_size=0.7)\n", - " (self.x_valid, self.y_valid), (self.x_test, self.y_test) = DataSplitter._splitData(x_valid_test, y_valid_test, train_size=0.5)\n", - "\n", - " def getTrain(self):\n", - " return (self.x_train, self.y_train)\n", - "\n", - " def getValid(self):\n", - " return (self.x_valid, self.y_valid)\n", - "\n", - " def getTest(self):\n", - " return (self.x_test, self.y_test)\n", - "\n", - " @staticmethod\n", - " def _splitData(x, y, train_size=0.9, shuffle=True):\n", - " size = len(x)\n", - " indices = np.arange(size)\n", - " if shuffle:\n", - " np.random.shuffle(indices)\n", - " train_samples = int(size * train_size)\n", - " x_train, y_train = x[indices[:train_samples]], y[indices[:train_samples]]\n", - " x_test, y_test = x[indices[train_samples:]], y[indices[train_samples:]]\n", - " return (x_train, y_train), (x_test, y_test)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "dAAACymS3MGR" - }, - "outputs": [], - "source": [ - "from captcha.DatasetFactory import DatasetFactory" - ] - }, { "cell_type": "code", "execution_count": null, @@ -209,17 +135,6 @@ " ['green' if pred_text == orig_text else 'red' for (pred_text, orig_text) in zip(pred_texts, orig_texts)])" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "st13jAjL3MGV" - }, - "outputs": [], - "source": [ - "from captcha.ModelFactory import ModelFactory" - ] - }, { "cell_type": "code", "execution_count": null, @@ -231,28 +146,6 @@ " print(i, layer.name)\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "B7GZlk2_3MGX" - }, - "outputs": [], - "source": [ - "from captcha.PredictionsDecoder import PredictionsDecoder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8Oa7avYt3MGX" - }, - "outputs": [], - "source": [ - "from captcha.ModelDAO import ModelDAO" - ] - }, { "cell_type": "code", "execution_count": null, @@ -355,7 +248,6 @@ "metadata": {}, "outputs": [], "source": [ - "from captcha.CaptchaShape import CaptchaShape\n", "captchaShape = CaptchaShape()" ] }, diff --git a/src/captcha/DataSplitter.py b/src/captcha/DataSplitter.py new file mode 100644 index 00000000000..d4f619ea982 --- /dev/null +++ b/src/captcha/DataSplitter.py @@ -0,0 +1,28 @@ +import numpy as np + + +class DataSplitter: + + def __init__(self, x, y): + (self.x_train, self.y_train), (x_valid_test, y_valid_test) = DataSplitter._splitData(np.array(x), np.array(y), train_size=0.7) + (self.x_valid, self.y_valid), (self.x_test, self.y_test) = DataSplitter._splitData(x_valid_test, y_valid_test, train_size=0.5) + + def getTrain(self): + return (self.x_train, self.y_train) + + def getValid(self): + return (self.x_valid, self.y_valid) + + def getTest(self): + return (self.x_test, self.y_test) + + @staticmethod + def _splitData(x, y, train_size=0.9, shuffle=True): + size = len(x) + indices = np.arange(size) + if shuffle: + np.random.shuffle(indices) + train_samples = int(size * train_size) + x_train, y_train = x[indices[:train_samples]], y[indices[:train_samples]] + x_test, y_test = x[indices[train_samples:]], y[indices[train_samples:]] + return (x_train, y_train), (x_test, y_test)