refactoring

This commit is contained in:
frankknoll
2022-02-08 13:15:25 +01:00
parent 1f1bdd9293
commit b620452eaa

View File

@@ -300,12 +300,40 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "09e6b511", "id": "c40bd0f0",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import pycountry\n", "import pycountry\n",
"\n", "\n",
"class CountryColumnAdder:\n",
" \n",
" @staticmethod\n",
" def addCountryColumn(dataFrame, countryColumnName):\n",
" dataFrame[countryColumnName] = dataFrame.apply(\n",
" lambda row:\n",
" CountryColumnAdder._getCountryNameOfSplttypeOrDefault(\n",
" splttype = row['SPLTTYPE'],\n",
" default = 'Unknown Country'),\n",
" axis = 'columns')\n",
" return dataFrame.astype({countryColumnName: \"string\"})\n",
"\n",
" @staticmethod\n",
" def _getCountryNameOfSplttypeOrDefault(splttype, default):\n",
" if not isinstance(splttype, str):\n",
" return default\n",
" \n",
" country = pycountry.countries.get(alpha_2 = splttype[:2])\n",
" return country.name if country is not None else default"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "09e6b511",
"metadata": {},
"outputs": [],
"source": [
"class InternationalLotTableFactory:\n", "class InternationalLotTableFactory:\n",
" \n", " \n",
" @staticmethod\n", " @staticmethod\n",
@@ -323,33 +351,15 @@
" @staticmethod\n", " @staticmethod\n",
" def _createInternationalLotTable(dataFrame):\n", " def _createInternationalLotTable(dataFrame):\n",
" countryColumnName = 'Country'\n", " countryColumnName = 'Country'\n",
" dataFrame = InternationalLotTableFactory._addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n", " dataFrame = CountryColumnAdder.addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n",
" return DoseTableFactory._createDoseTable(dataFrame.groupby(dataFrame[countryColumnName]))\n", " return DoseTableFactory._createDoseTable(dataFrame.groupby(dataFrame[countryColumnName]))\n",
"\n", "\n",
" @staticmethod\n", " @staticmethod\n",
" def _createBatchCodeTableByCountry(dataFrame : pd.DataFrame, country):\n", " def _createBatchCodeTableByCountry(dataFrame : pd.DataFrame, country):\n",
" countryColumnName = 'Country'\n", " countryColumnName = 'Country'\n",
" dataFrame = InternationalLotTableFactory._addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n", " dataFrame = CountryColumnAdder.addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n",
" dataFrame = DataFrameFilter().filterByCountry(dataFrame, country = country, countryColumnName = countryColumnName)\n", " dataFrame = DataFrameFilter().filterByCountry(dataFrame, country = country, countryColumnName = countryColumnName)\n",
" return DoseTableFactory._createDoseTable(dataFrame.groupby('VAX_LOT'))\n", " return DoseTableFactory._createDoseTable(dataFrame.groupby('VAX_LOT'))\n"
"\n",
" @staticmethod\n",
" def _addCountryColumn(dataFrame, countryColumnName):\n",
" dataFrame[countryColumnName] = dataFrame.apply(\n",
" lambda row:\n",
" InternationalLotTableFactory._getCountryNameOfSplttypeOrDefault(\n",
" splttype = row['SPLTTYPE'],\n",
" default = 'Unknown Country'),\n",
" axis = 'columns')\n",
" return dataFrame.astype({countryColumnName: \"string\"})\n",
"\n",
" @staticmethod\n",
" def _getCountryNameOfSplttypeOrDefault(splttype, default):\n",
" if not isinstance(splttype, str):\n",
" return default\n",
" \n",
" country = pycountry.countries.get(alpha_2 = splttype[:2])\n",
" return country.name if country is not None else default\n"
] ]
}, },
{ {