refactoring

This commit is contained in:
frankknoll
2023-03-16 14:23:23 +01:00
parent 7dfc6f373e
commit f1ad511850
3 changed files with 93 additions and 36 deletions

View File

@@ -23,6 +23,61 @@
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.argv = sys.argv[:1]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def isInColab():\n",
" try:\n",
" import colab\n",
" return True\n",
" except:\n",
" return False"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"inColab = isInColab()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if inColab:\n",
" branch = 'read-captcha'\n",
" !git clone https://github.com/KnollFrank/HowBadIsMyBatch.git\n",
" !cd HowBadIsMyBatch; git checkout $branch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if inColab:\n",
" sys.path.insert(0, '/content/HowBadIsMyBatch/src')"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -46,6 +101,36 @@
"from captcha.CaptchaShape import CaptchaShape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"class GoogleDriveManager:\n",
" \n",
" _googleDriveFolder = Path('/content/gdrive')\n",
" _baseFolder = _googleDriveFolder / 'MyDrive/CAPTCHA/models/'\n",
"\n",
" @staticmethod\n",
" def mount():\n",
" from google.colab import drive\n",
" drive.mount(str(GoogleDriveManager._googleDriveFolder))\n",
"\n",
" @staticmethod\n",
" def uploadFolderToGoogleDrive(folder):\n",
" !zip -r {folder}.zip {folder}/\n",
" !cp {folder}.zip {GoogleDriveManager._baseFolder}\n",
"\n",
" @staticmethod\n",
" def downloadFolderFromGoogleDrive(folder):\n",
" !cp {GoogleDriveManager._baseFolder}/{folder}.zip .\n",
" !rm -rf {folder}\n",
" !unzip {folder}.zip\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -173,17 +258,6 @@
"## Preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NZrKXF6P3MGY"
},
"outputs": [],
"source": [
"inColab = 'google.colab' in str(get_ipython())"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -196,19 +270,6 @@
" GoogleDriveManager.mount()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "S_4hl4S4BmZK"
},
"outputs": [],
"source": [
"if inColab:\n",
" !cp {GoogleDriveManager._baseFolder}/captchas.zip .\n",
" !unzip captchas.zip"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -217,7 +278,7 @@
},
"outputs": [],
"source": [
"modelDAO = ModelDAO(inColab)\n",
"modelDAO = ModelDAO()\n",
"charNumConverter = CharNumConverter(CaptchaGenerator.characters)\n",
"predictionsDecoder = PredictionsDecoder(CaptchaGenerator.captchaLength, charNumConverter.num_to_char)\n",
"captchaShape = CaptchaShape()\n",
@@ -334,7 +395,9 @@
},
"outputs": [],
"source": [
"modelDAO.saveModel(model)"
"modelDAO.saveModel(model)\n",
"if inColab:\n",
" GoogleDriveManager.uploadFolderToGoogleDrive(model.name)"
]
},
{
@@ -525,7 +588,9 @@
"metadata": {},
"outputs": [],
"source": [
"modelDAO.saveModel(model)"
"modelDAO.saveModel(model)\n",
"if inColab:\n",
" GoogleDriveManager.uploadFolderToGoogleDrive(model.name)"
]
}
],

View File

@@ -28,4 +28,4 @@ class CaptchaReader:
return PredictionsDecoder(CaptchaGenerator.captchaLength, CharNumConverter(CaptchaGenerator.characters).num_to_char).decode_batch_predictions(preds)
def _createPredictionModel(self):
return ModelFactory.createPredictionModel(ModelDAO(inColab=False).loadModel(self.modelFilepath))
return ModelFactory.createPredictionModel(ModelDAO().loadModel(self.modelFilepath))

View File

@@ -1,20 +1,12 @@
from tensorflow import keras
from captcha.GoogleDriveManager import GoogleDriveManager
import shutil
class ModelDAO:
def __init__(self, inColab):
self.inColab = inColab
def saveModel(self, model):
shutil.rmtree(model.name, ignore_errors = True)
model.save(model.name)
if self.inColab:
GoogleDriveManager.uploadFolderToGoogleDrive(model.name)
def loadModel(self, modelFilepath):
if self.inColab:
GoogleDriveManager.downloadFolderFromGoogleDrive(modelFilepath)
return keras.models.load_model(modelFilepath)