diff --git a/src/HowBadIsMyBatch.ipynb b/src/HowBadIsMyBatch.ipynb index 5679f2f797d..971b9b18fda 100644 --- a/src/HowBadIsMyBatch.ipynb +++ b/src/HowBadIsMyBatch.ipynb @@ -412,13 +412,34 @@ " \n", " def __init__(self, dataFrame : pd.DataFrame):\n", " self.dataFrame = DataFrameFilter().filterByCovid19(dataFrame)\n", - " self.countryBatchCodeTable = None\n", + " self.batchCodeTableByCountryFactory = BatchCodeTableByCountryFactory(dataFrame)\n", "\n", " def createInternationalLotTable(self):\n", " internationalLotTable = self._createInternationalLotTable()\n", " return internationalLotTable.sort_values(by = 'Severe reports', ascending = False)\n", "\n", - " # FK-TODO: move this and dependent methods to another class\n", + " def createBatchCodeTableByCountry(self, country):\n", + " return self.batchCodeTableByCountryFactory.createBatchCodeTableByCountry(country)\n", + "\n", + " def _createInternationalLotTable(self):\n", + " countryColumnName = 'Country'\n", + " dataFrame = CountryColumnAdder.addCountryColumn(self.dataFrame, countryColumnName = countryColumnName)\n", + " return SummationTableFactory.createSummationTableHavingSevereReportsColumn(dataFrame.groupby(dataFrame[countryColumnName]))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71456a79", + "metadata": {}, + "outputs": [], + "source": [ + "class BatchCodeTableByCountryFactory:\n", + "\n", + " def __init__(self, dataFrame : pd.DataFrame):\n", + " self.dataFrame = DataFrameFilter().filterByCovid19(dataFrame)\n", + " self.countryBatchCodeTable = None\n", + "\n", " def createBatchCodeTableByCountry(self, country):\n", " batchCodeTable = self._createBatchCodeTableByCountry(country)\n", " batchCodeTable = CompanyColumnAdder.addCompanyColumn(batchCodeTable, CompanyColumnAdder.createCompanyByBatchCodeTable(self.dataFrame))\n", @@ -434,11 +455,6 @@ " ]]\n", " return batchCodeTable.sort_values(by = 'Severe reports', ascending = False)\n", "\n", - " def _createInternationalLotTable(self):\n", - " countryColumnName = 'Country'\n", - " dataFrame = CountryColumnAdder.addCountryColumn(self.dataFrame, countryColumnName = countryColumnName)\n", - " return SummationTableFactory.createSummationTableHavingSevereReportsColumn(dataFrame.groupby(dataFrame[countryColumnName]))\n", - "\n", " def _createBatchCodeTableByCountry(self, country):\n", " if self.countryBatchCodeTable is None:\n", " self.countryBatchCodeTable = self._getCountryBatchCodeTable()\n", @@ -458,7 +474,9 @@ " return countryBatchCodeTable.loc[country] if country in countryBatchCodeTable.index else self._getEmptyBatchCodeTable(countryBatchCodeTable)\n", " \n", " def _getEmptyBatchCodeTable(self, countryBatchCodeTable):\n", - " return countryBatchCodeTable[0:0].droplevel(0)\n" + " return countryBatchCodeTable[0:0].droplevel(0)\n", + "\n", + " " ] }, {