refactoring
This commit is contained in:
@@ -7,21 +7,19 @@ from captcha.DatasetFactory import DatasetFactory
|
||||
import numpy as np
|
||||
from tensorflow import keras
|
||||
|
||||
# FK-TODO: DRY with captcha.ipynb
|
||||
img_width = 241
|
||||
img_height = 62
|
||||
|
||||
class CaptchaReader:
|
||||
|
||||
def __init__(self, modelFilepath):
|
||||
def __init__(self, modelFilepath, captchaShape):
|
||||
self.modelFilepath = modelFilepath
|
||||
self.captchaShape = captchaShape
|
||||
|
||||
def getTextInCaptchaImage(self, captchaImageFile):
|
||||
# FK-TODO: refactor
|
||||
modelDAO = ModelDAO(inColab = False)
|
||||
model = modelDAO.loadModel(self.modelFilepath)
|
||||
prediction_model = ModelFactory.createPredictionModel(model)
|
||||
charNumConverter = CharNumConverter(CaptchaGenerator.characters)
|
||||
datasetFactory = DatasetFactory(img_height, img_width, charNumConverter.char_to_num, batch_size = 64)
|
||||
datasetFactory = DatasetFactory(self.captchaShape,charNumConverter.char_to_num, batch_size = 64)
|
||||
batchImages = self._asSingleSampleBatch(datasetFactory._encode_single_sample(captchaImageFile, 'dummy')['image'])
|
||||
preds = prediction_model.predict(batchImages)
|
||||
predictionsDecoder = PredictionsDecoder(CaptchaGenerator.captchaLength, charNumConverter.num_to_char)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import unittest
|
||||
from captcha.CaptchaReader import CaptchaReader
|
||||
from captcha.CaptchaShape import CaptchaShape
|
||||
import os
|
||||
|
||||
class CaptchaReaderTest(unittest.TestCase):
|
||||
@@ -10,7 +11,9 @@ class CaptchaReaderTest(unittest.TestCase):
|
||||
def test_getTextInCaptchaImage(self):
|
||||
# Given
|
||||
textInCaptchaImage = '1Ad47a'
|
||||
captchaReader = CaptchaReader(modelFilepath = f'{self.working_directory}/MobileNetV3Small')
|
||||
captchaReader = CaptchaReader(
|
||||
modelFilepath = f'{self.working_directory}/MobileNetV3Small',
|
||||
captchaShape = CaptchaShape())
|
||||
|
||||
# When
|
||||
textInCaptchaImageActual = captchaReader.getTextInCaptchaImage(f'{self.working_directory}/captchas/VAERS/{textInCaptchaImage}.jpeg')
|
||||
|
||||
5
src/captcha/CaptchaShape.py
Normal file
5
src/captcha/CaptchaShape.py
Normal file
@@ -0,0 +1,5 @@
|
||||
class CaptchaShape:
    """Pixel dimensions of a captcha image.

    Instances expose two plain attributes, ``width`` and ``height``, which
    other components read to resize images and to size the model input.

    Args:
        width: Image width in pixels. Defaults to 241.
        height: Image height in pixels. Defaults to 62.
    """

    def __init__(self, width: int = 241, height: int = 62) -> None:
        # The defaults reproduce the previously hard-coded values, so
        # ``CaptchaShape()`` keeps behaving exactly as before.
        self.width = width
        self.height = height

    def __repr__(self) -> str:
        return f"{type(self).__name__}(width={self.width!r}, height={self.height!r})"
|
||||
@@ -3,9 +3,8 @@ import tensorflow as tf
|
||||
|
||||
class DatasetFactory:
|
||||
|
||||
def __init__(self, img_height, img_width, char_to_num, batch_size):
|
||||
self.img_height = img_height
|
||||
self.img_width = img_width
|
||||
def __init__(self, captchaShape, char_to_num, batch_size):
|
||||
self.captchaShape = captchaShape
|
||||
self.char_to_num = char_to_num
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -18,7 +17,7 @@ class DatasetFactory:
|
||||
def _encode_single_sample(self, img_path, label):
|
||||
img = tf.io.read_file(img_path)
|
||||
img = tf.io.decode_jpeg(img, channels=3)
|
||||
img = tf.image.resize(img, [self.img_height, self.img_width])
|
||||
img = tf.image.resize(img, [self.captchaShape.height, self.captchaShape.width])
|
||||
# Map the characters in label to numbers
|
||||
label = self.char_to_num(tf.strings.unicode_split(label, input_encoding="UTF-8"))
|
||||
# Return a dict as our model is expecting two inputs
|
||||
|
||||
@@ -9,9 +9,8 @@ class ModelFactory:
|
||||
predictionModelInputLayerName = "image"
|
||||
predictionModelOutputLayerName = "dense2"
|
||||
|
||||
def __init__(self, img_height, img_width, char_to_num):
|
||||
self.img_height = img_height
|
||||
self.img_width = img_width
|
||||
def __init__(self, captchaShape, char_to_num):
|
||||
self.captchaShape = captchaShape
|
||||
self.char_to_num = char_to_num
|
||||
|
||||
# see https://www.tensorflow.org/api_docs/python/tf/keras/applications/resnet/ResNet101
|
||||
@@ -52,9 +51,9 @@ class ModelFactory:
|
||||
def _createModel(self, baseModelFactory, preprocess_input, name):
|
||||
# Inputs to the model
|
||||
input_image = layers.Input(
|
||||
shape=(self.img_height, self.img_width, 3),
|
||||
name=ModelFactory.predictionModelInputLayerName,
|
||||
dtype="float32")
|
||||
shape = (self.captchaShape.height, self.captchaShape.width, 3),
|
||||
name = ModelFactory.predictionModelInputLayerName,
|
||||
dtype = "float32")
|
||||
labels = layers.Input(name="label", shape=(None,), dtype="float32")
|
||||
|
||||
image = preprocess_input(input_image)
|
||||
|
||||
@@ -1,737 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "UNKC5YSEIS_d"
|
||||
},
|
||||
"source": [
|
||||
"# Captchas\n",
|
||||
"\n",
|
||||
"**see:** https://keras.io/examples/vision/captcha_ocr/<br>\n",
|
||||
"**original:** https://colab.research.google.com/drive/1Olw2KMHfPlnGaYuzffl2zb6D1etlBGZf?usp=sharing<br>\n",
|
||||
"**View Github version in Colab:** <a href=\"https://colab.research.google.com/github/KnollFrank/2captcha-worker-assistant-server/blob/master/captcha_ocr_trainAndSaveModel_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a><br>\n",
|
||||
"**paper:** Simple and Easy: Transfer Learning-Based Attacks to Text CAPTCHA<br>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "wRUsVuIiIS_s"
|
||||
},
|
||||
"source": [
|
||||
"## Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"id": "zZSwQragIS_v"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2023-03-15 10:46:02.303787: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 FMA\n",
|
||||
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
||||
"/home/frankknoll/.local/lib/python3.9/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.5\n",
|
||||
" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"import tensorflow as tf\n",
|
||||
"from tensorflow import keras\n",
|
||||
"from tensorflow.keras import layers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"id": "QB8QZJPg3MGI"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from GoogleDriveManager import GoogleDriveManager"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"id": "C3bxU1US2blM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from CaptchaGenerator import CaptchaGenerator"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"id": "0DZfMrbe3MGN"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def getImagesAndLabels(dataDir):\n",
|
||||
" fileSuffix = \".jpeg\"\n",
|
||||
" images = sorted(list(map(str, list(dataDir.glob(\"*\" + fileSuffix)))))\n",
|
||||
" labels = [image.split(os.path.sep)[-1].split(fileSuffix)[0] for image in images]\n",
|
||||
" return images, labels\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"id": "sNJjugG83MGO"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from CharNumConverter import CharNumConverter"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"id": "qxs04OTR3MGP"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class DataSplitter:\n",
|
||||
"\n",
|
||||
" def __init__(self, x, y):\n",
|
||||
" (self.x_train, self.y_train), (x_valid_test, y_valid_test) = DataSplitter._splitData(np.array(x), np.array(y), train_size=0.7)\n",
|
||||
" (self.x_valid, self.y_valid), (self.x_test, self.y_test) = DataSplitter._splitData(x_valid_test, y_valid_test, train_size=0.5)\n",
|
||||
"\n",
|
||||
" def getTrain(self):\n",
|
||||
" return (self.x_train, self.y_train)\n",
|
||||
"\n",
|
||||
" def getValid(self):\n",
|
||||
" return (self.x_valid, self.y_valid)\n",
|
||||
"\n",
|
||||
" def getTest(self):\n",
|
||||
" return (self.x_test, self.y_test)\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def _splitData(x, y, train_size=0.9, shuffle=True):\n",
|
||||
" size = len(x)\n",
|
||||
" indices = np.arange(size)\n",
|
||||
" if shuffle:\n",
|
||||
" np.random.shuffle(indices)\n",
|
||||
" train_samples = int(size * train_size)\n",
|
||||
" x_train, y_train = x[indices[:train_samples]], y[indices[:train_samples]]\n",
|
||||
" x_test, y_test = x[indices[train_samples:]], y[indices[train_samples:]]\n",
|
||||
" return (x_train, y_train), (x_test, y_test)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"id": "dAAACymS3MGR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from DatasetFactory import DatasetFactory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"id": "kdL9_t03Mf3t"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def getTrainValidationTestDatasets(dataDir, datasetFactory):\n",
|
||||
" images, labels = getImagesAndLabels(dataDir)\n",
|
||||
" print(\"Number of images found:\", len(images))\n",
|
||||
" print(\"Characters:\", CaptchaGenerator.characters)\n",
|
||||
"\n",
|
||||
" dataSplitter = DataSplitter(images, labels)\n",
|
||||
" \n",
|
||||
" return (\n",
|
||||
" datasetFactory.createDataset(*dataSplitter.getTrain()),\n",
|
||||
" datasetFactory.createDataset(*dataSplitter.getValid()),\n",
|
||||
" datasetFactory.createDataset(*dataSplitter.getTest())\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"id": "FqVSEuZp3MGT"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import math\n",
|
||||
"\n",
|
||||
"def displayImagesInGrid(numGridCols, images, titles, titleColors):\n",
|
||||
" assert len(images) == len(titles) == len(titleColors)\n",
|
||||
" images = [image.numpy().astype(np.uint8) for image in images]\n",
|
||||
" numGridRows = math.ceil(len(images) / numGridCols)\n",
|
||||
" _, axs = plt.subplots(numGridRows, numGridCols, figsize=(15, 5))\n",
|
||||
" for row in range(numGridRows):\n",
|
||||
" for col in range(numGridCols):\n",
|
||||
" ax = axs[row, col]\n",
|
||||
" ax.axis(\"off\")\n",
|
||||
" i = row * numGridCols + col\n",
|
||||
" if(i < len(images)):\n",
|
||||
" ax.imshow(images[i])\n",
|
||||
" ax.set_title(titles[i], color=titleColors[i])\n",
|
||||
" plt.show()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"id": "apkeCHhP3MGU"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def display16Predictions(model, dataset, predictionsDecoder):\n",
|
||||
" for batch in dataset.take(1):\n",
|
||||
" numPredictions2Display = 16\n",
|
||||
" batch_images = batch[\"image\"][:numPredictions2Display]\n",
|
||||
" batch_labels = batch[\"label\"][:numPredictions2Display]\n",
|
||||
"\n",
|
||||
" preds = model.predict(batch_images)\n",
|
||||
" pred_texts = predictionsDecoder.decode_batch_predictions(preds)\n",
|
||||
" orig_texts = predictionsDecoder.asStrings(batch_labels)\n",
|
||||
"\n",
|
||||
" displayImagesInGrid(\n",
|
||||
" 4,\n",
|
||||
" batch_images,\n",
|
||||
" [f\"Prediction/Truth: {pred_text}/{orig_text}\" for (pred_text, orig_text) in zip(pred_texts, orig_texts)],\n",
|
||||
" ['green' if pred_text == orig_text else 'red' for (pred_text, orig_text) in zip(pred_texts, orig_texts)])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"id": "st13jAjL3MGV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ModelFactory import ModelFactory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def printLayers(model):\n",
|
||||
" for i, layer in enumerate(model.layers):\n",
|
||||
" print(i, layer.name)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"id": "B7GZlk2_3MGX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from PredictionsDecoder import PredictionsDecoder"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"id": "8Oa7avYt3MGX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from ModelDAO import ModelDAO"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"id": "S3X_SslH3MGY"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# FK-TODO: remove the getAccuracy() method. Implement a custom metric instead, see https://stackoverflow.com/questions/37657260/how-to-implement-custom-metric-in-keras or https://keras.io/api/metrics/#custom-metrics\n",
|
||||
"def getAccuracy(dataset, prediction_model, ctc_decode):\n",
|
||||
" accuracy = tf.keras.metrics.Accuracy()\n",
|
||||
"\n",
|
||||
" for batch in dataset:\n",
|
||||
" accuracy.update_state(batch[\"label\"], ctc_decode(prediction_model.predict(batch[\"image\"], verbose=0)))\n",
|
||||
"\n",
|
||||
" return accuracy.result().numpy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "94755hrNMf3w"
|
||||
},
|
||||
"source": [
|
||||
"## Preparation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"id": "NZrKXF6P3MGY"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inColab = 'google.colab' in str(get_ipython())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"id": "7EsmTaF03MGZ"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if inColab:\n",
|
||||
" GoogleDriveManager.mount()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"id": "S_4hl4S4BmZK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if inColab:\n",
|
||||
" !cp {GoogleDriveManager._baseFolder}/captchas.zip .\n",
|
||||
" !unzip captchas.zip"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"id": "WmUghcQaMf3y"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"modelDAO = ModelDAO(inColab)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"id": "cpxO7yGAMf3z"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2023-03-15 10:41:54.085280: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 FMA\n",
|
||||
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
||||
"2023-03-15 10:41:54.089954: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"charNumConverter = CharNumConverter(CaptchaGenerator.characters)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"id": "tVb5nDFTMf3z"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"predictionsDecoder = PredictionsDecoder(CaptchaGenerator.captchaLength, charNumConverter.num_to_char)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"id": "t1wzlHQ-Mf3z"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"(img_width, img_height) = (241, 62)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"id": "s35OUslsMf30"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"datasetFactory = DatasetFactory(img_height, img_width, charNumConverter.char_to_num, batch_size = 64)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "lsLuSi7h3MGZ"
|
||||
},
|
||||
"source": [
|
||||
"## Create And Train Base Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"id": "oRcemcbG3MGa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if inColab:\n",
|
||||
" !sudo apt install ttf-mscorefonts-installer\n",
|
||||
" !sudo fc-cache -f\n",
|
||||
" !fc-match Arial"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"id": "P7myCt7e2h6A"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# \"We generate 200,000 images for base model pre-training\"\n",
|
||||
"captchaGenerator = CaptchaGenerator(\n",
|
||||
" numCaptchas = 50, # 50, # 200000,\n",
|
||||
" dataDir = Path(\"captchas/generated/VAERS/\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {
|
||||
"id": "j9apYsyI3MGb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"captchaGenerator.createAndSaveCaptchas()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "AgN4skCkMf31"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset, validation_dataset, test_dataset = getTrainValidationTestDatasets(captchaGenerator.dataDir, datasetFactory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "RcgWHXVSNsa7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for batch in train_dataset.take(1):\n",
|
||||
" numImages2Display = 16\n",
|
||||
" images = batch[\"image\"][:numImages2Display]\n",
|
||||
" labels = batch[\"label\"][:numImages2Display]\n",
|
||||
" displayImagesInGrid(4, images, predictionsDecoder.asStrings(labels), ['black'] * len(labels))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "V8ELN-qJ3MGe"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"modelFactory = ModelFactory(img_height, img_width, charNumConverter.char_to_num)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "zDoFYKM2hdEW"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = modelFactory.createMobileNetV3Small()\n",
|
||||
"model.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ltXYrpjIITAb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# \"the success rates became stable after the base-model training epochs exceeded 20\"\n",
|
||||
"history = model.fit(\n",
|
||||
" train_dataset,\n",
|
||||
" validation_data=validation_dataset,\n",
|
||||
" epochs=20)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fPG-Yl1SJfF7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"modelDAO.saveModel(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "NnNHMtIGITAe"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prediction_model = ModelFactory.createPredictionModel(model)\n",
|
||||
"prediction_model.summary()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "YW651ztD8sKI"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"display16Predictions(prediction_model, test_dataset, predictionsDecoder)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "V5gqMBIwBmZU"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"getAccuracy(test_dataset, prediction_model, predictionsDecoder.ctc_decode)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "UYxiYTH9BmZU"
|
||||
},
|
||||
"source": [
|
||||
"## Transfer learning"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "WV8IS4KrBmZU"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# \"we collected 1,500 real CAPTCHAs from the websites. Note that only 500 of them are used for fine-tuning, and another 1,000 are applied to calculate the test accuracy\"\n",
|
||||
"# FK-TODO: load the pre-trained model and train it with 500 real-world samples from the folder captchas/VAERS/; the remaining 540 (although according to the quote above there should be 1,000) then serve as the test data.\n",
|
||||
"# see https://keras.io/guides/transfer_learning/\n",
|
||||
"# see https://www.tensorflow.org/tutorials/images/transfer_learning\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"modelName, numTrainableLayers = 'MobileNetV3Small', 104\n",
|
||||
"# modelName, numTrainableLayers = 'ResNet101', 348"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "D7ogEQmB3MGj"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = modelDAO.loadModel(modelName)\n",
|
||||
"model.summary(show_trainable=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "gbPigogKNFrD"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# printLayers(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "59quw8o3Mf34"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.trainable = True\n",
|
||||
"for layer in model.layers[:numTrainableLayers]:\n",
|
||||
" layer.trainable = False"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "acGczax3Mf34"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.summary(show_trainable=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "q7_MjUO0BmZV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset, validation_dataset, test_dataset = getTrainValidationTestDatasets(Path(\"captchas/VAERS/\"), datasetFactory)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "dZsCpibkBmZX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# \"The model is optimized by a stochastic gradient descent (SGD) strategy with an initial learning rate of 0.004, weight decay of 0.00004 and momentum of 0.9.\"\n",
|
||||
"from tensorflow.keras.optimizers import SGD\n",
|
||||
"# model.compile(optimizer=SGD(learning_rate=0.0001, momentum=0.9))\n",
|
||||
"model.compile(optimizer='adam')\n",
|
||||
"\n",
|
||||
"# \"Therefore, in our experiments, we chose 1 epoch for the fine-tuning stage.\"\n",
|
||||
"history = model.fit(\n",
|
||||
" train_dataset,\n",
|
||||
" validation_data=validation_dataset,\n",
|
||||
" epochs=20)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "TRbJigbH3MGl"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prediction_model = ModelFactory.createPredictionModel(model)\n",
|
||||
"prediction_model.summary()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rPszfhJ4BmZX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"getAccuracy(test_dataset, prediction_model, predictionsDecoder.ctc_decode)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"id": "hfmRY1qC7aVV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"display16Predictions(prediction_model, test_dataset, predictionsDecoder)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"modelDAO.saveModel(model)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"name": "captcha.ipynb",
|
||||
"private_outputs": true,
|
||||
"provenance": []
|
||||
},
|
||||
"gpuClass": "standard",
|
||||
"kernelspec": {
|
||||
"display_name": "howbadismybatch-venv-kernel",
|
||||
"language": "python",
|
||||
"name": "howbadismybatch-venv-kernel"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
Reference in New Issue
Block a user