diff --git a/HowBadIsMyBatch.ipynb b/HowBadIsMyBatch.ipynb index 8423bf44a55..6ba38b4c618 100644 --- a/HowBadIsMyBatch.ipynb +++ b/HowBadIsMyBatch.ipynb @@ -300,12 +300,40 @@ { "cell_type": "code", "execution_count": null, - "id": "09e6b511", + "id": "c40bd0f0", "metadata": {}, "outputs": [], "source": [ "import pycountry\n", "\n", + "class CountryColumnAdder:\n", + " \n", + " @staticmethod\n", + " def addCountryColumn(dataFrame, countryColumnName):\n", + " dataFrame[countryColumnName] = dataFrame.apply(\n", + " lambda row:\n", + " CountryColumnAdder._getCountryNameOfSplttypeOrDefault(\n", + " splttype = row['SPLTTYPE'],\n", + " default = 'Unknown Country'),\n", + " axis = 'columns')\n", + " return dataFrame.astype({countryColumnName: \"string\"})\n", + "\n", + " @staticmethod\n", + " def _getCountryNameOfSplttypeOrDefault(splttype, default):\n", + " if not isinstance(splttype, str):\n", + " return default\n", + " \n", + " country = pycountry.countries.get(alpha_2 = splttype[:2])\n", + " return country.name if country is not None else default" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09e6b511", + "metadata": {}, + "outputs": [], + "source": [ "class InternationalLotTableFactory:\n", " \n", " @staticmethod\n", @@ -323,33 +351,15 @@ " @staticmethod\n", " def _createInternationalLotTable(dataFrame):\n", " countryColumnName = 'Country'\n", - " dataFrame = InternationalLotTableFactory._addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n", + " dataFrame = CountryColumnAdder.addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n", " return DoseTableFactory._createDoseTable(dataFrame.groupby(dataFrame[countryColumnName]))\n", "\n", " @staticmethod\n", " def _createBatchCodeTableByCountry(dataFrame : pd.DataFrame, country):\n", " countryColumnName = 'Country'\n", - " dataFrame = InternationalLotTableFactory._addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n", + " dataFrame = CountryColumnAdder.addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n", " dataFrame = DataFrameFilter().filterByCountry(dataFrame, country = country, countryColumnName = countryColumnName)\n", - " return DoseTableFactory._createDoseTable(dataFrame.groupby('VAX_LOT'))\n", - "\n", - " @staticmethod\n", - " def _addCountryColumn(dataFrame, countryColumnName):\n", - " dataFrame[countryColumnName] = dataFrame.apply(\n", - " lambda row:\n", - " InternationalLotTableFactory._getCountryNameOfSplttypeOrDefault(\n", - " splttype = row['SPLTTYPE'],\n", - " default = 'Unknown Country'),\n", - " axis = 'columns')\n", - " return dataFrame.astype({countryColumnName: \"string\"})\n", - "\n", - " @staticmethod\n", - " def _getCountryNameOfSplttypeOrDefault(splttype, default):\n", - " if not isinstance(splttype, str):\n", - " return default\n", - " \n", - " country = pycountry.countries.get(alpha_2 = splttype[:2])\n", - " return country.name if country is not None else default\n" + " return DoseTableFactory._createDoseTable(dataFrame.groupby('VAX_LOT'))\n" ] }, {