diff --git a/src/HowBadIsMyBatch.ipynb b/src/HowBadIsMyBatch.ipynb index afb474d4c9d..5679f2f797d 100644 --- a/src/HowBadIsMyBatch.ipynb +++ b/src/HowBadIsMyBatch.ipynb @@ -410,17 +410,18 @@ "source": [ "class InternationalLotTableFactory:\n", " \n", - " @staticmethod\n", - " def createInternationalLotTable(dataFrame):\n", - " dataFrame = DataFrameFilter().filterByCovid19(dataFrame)\n", - " internationalLotTable = InternationalLotTableFactory._createInternationalLotTable(dataFrame)\n", + " def __init__(self, dataFrame : pd.DataFrame):\n", + " self.dataFrame = DataFrameFilter().filterByCovid19(dataFrame)\n", + " self.countryBatchCodeTable = None\n", + "\n", + " def createInternationalLotTable(self):\n", + " internationalLotTable = self._createInternationalLotTable()\n", " return internationalLotTable.sort_values(by = 'Severe reports', ascending = False)\n", "\n", - " @staticmethod\n", - " def createBatchCodeTableByCountry(dataFrame : pd.DataFrame, country):\n", - " dataFrame = DataFrameFilter().filterByCovid19(dataFrame)\n", - " batchCodeTable = InternationalLotTableFactory._createBatchCodeTableByCountry(dataFrame, country)\n", - " batchCodeTable = CompanyColumnAdder.addCompanyColumn(batchCodeTable, CompanyColumnAdder.createCompanyByBatchCodeTable(dataFrame))\n", + " # FK-TODO: move this and dependent methods to another class\n", + " def createBatchCodeTableByCountry(self, country):\n", + " batchCodeTable = self._createBatchCodeTableByCountry(country)\n", + " batchCodeTable = CompanyColumnAdder.addCompanyColumn(batchCodeTable, CompanyColumnAdder.createCompanyByBatchCodeTable(self.dataFrame))\n", " batchCodeTable = batchCodeTable[\n", " [\n", " 'Adverse Reaction Reports',\n", @@ -433,18 +434,31 @@ " ]]\n", " return batchCodeTable.sort_values(by = 'Severe reports', ascending = False)\n", "\n", - " @staticmethod\n", - " def _createInternationalLotTable(dataFrame):\n", + " def _createInternationalLotTable(self):\n", " countryColumnName = 'Country'\n", - " dataFrame = CountryColumnAdder.addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n", + " dataFrame = CountryColumnAdder.addCountryColumn(self.dataFrame, countryColumnName = countryColumnName)\n", " return SummationTableFactory.createSummationTableHavingSevereReportsColumn(dataFrame.groupby(dataFrame[countryColumnName]))\n", "\n", - " @staticmethod\n", - " def _createBatchCodeTableByCountry(dataFrame : pd.DataFrame, country):\n", + " def _createBatchCodeTableByCountry(self, country):\n", + " if self.countryBatchCodeTable is None:\n", + " self.countryBatchCodeTable = self._getCountryBatchCodeTable()\n", + " return self._getCountry(self.countryBatchCodeTable, country)\n", + "\n", + " def _getCountryBatchCodeTable(self):\n", " countryColumnName = 'Country'\n", - " dataFrame = CountryColumnAdder.addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n", - " dataFrame = DataFrameFilter().filterByCountry(dataFrame, country = country, countryColumnName = countryColumnName)\n", - " return SummationTableFactory.createSummationTableHavingSevereReportsColumn(dataFrame.groupby('VAX_LOT'))\n" + " dataFrame = CountryColumnAdder.addCountryColumn(self.dataFrame, countryColumnName = countryColumnName)\n", + " return SummationTableFactory.createSummationTableHavingSevereReportsColumn(\n", + " dataFrame.groupby(\n", + " [\n", + " dataFrame[countryColumnName],\n", + " dataFrame['VAX_LOT']\n", + " ]))\n", + "\n", + " def _getCountry(self, countryBatchCodeTable, country):\n", + " return countryBatchCodeTable.loc[country] if country in countryBatchCodeTable.index else self._getEmptyBatchCodeTable(countryBatchCodeTable)\n", + " \n", + " def _getEmptyBatchCodeTable(self, countryBatchCodeTable):\n", + " return countryBatchCodeTable[0:0].droplevel(0)\n" ] }, { @@ -1109,9 +1123,10 @@ " \"0815\",\n", " \"0816\"])\n", " dataFrame = SevereColumnAdder.addSevereColumn(dataFrame)\n", + " internationalLotTableFactory = InternationalLotTableFactory(dataFrame)\n", " \n", " # When\n", - " internationalLotTable = InternationalLotTableFactory.createInternationalLotTable(dataFrame)\n", + " internationalLotTable = internationalLotTableFactory.createInternationalLotTable()\n", "\n", " # Then\n", " assert_frame_equal(\n", @@ -1133,20 +1148,21 @@ " def test_createBatchCodeTableByCountry(self):\n", " # Given\n", " dataFrame = TestHelper.createDataFrame(\n", - " columns = ['DIED', 'L_THREAT', 'DISABLE', 'VAX_TYPE', 'VAX_MANU', 'VAX_LOT', 'VAX_DOSE_SERIES', 'SPLTTYPE', 'HOSPITAL', 'ER_VISIT'],\n", - " data = [ [1, 0, 0, 'COVID19', 'MODERNA', '016M20A', '2', 'GBPFIZER INC2020486806', 0, 0],\n", - " [0, 0, 0, 'COVID19', 'MODERNA', '030L20A', '1', 'FRMODERNATX, INC.MOD20224', 0, 0],\n", - " [1, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0],\n", - " [0, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0]],\n", + " columns = ['DIED', 'L_THREAT', 'DISABLE', 'VAX_TYPE', 'VAX_MANU', 'VAX_LOT', 'VAX_DOSE_SERIES', 'SPLTTYPE', 'HOSPITAL', 'ER_VISIT'],\n", + " data = [ [1, 0, 0, 'COVID19', 'PFIZER\\BIONTECH', '016M20A', '2', 'GBPFIZER INC2020486806', 0, 0],\n", + " [0, 0, 0, 'COVID19', 'MODERNA', '030L20A', '1', 'FRMODERNATX, INC.MOD20224', 0, 0],\n", + " [1, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0],\n", + " [0, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0]],\n", " index = [\n", " \"1048786\",\n", " \"1048786\",\n", " \"4711\",\n", " \"0815\"])\n", " dataFrame = SevereColumnAdder.addSevereColumn(dataFrame)\n", + " internationalLotTableFactory = InternationalLotTableFactory(dataFrame)\n", " \n", " # When\n", - " batchCodeTable = InternationalLotTableFactory.createBatchCodeTableByCountry(dataFrame, 'France')\n", + " batchCodeTable = internationalLotTableFactory.createBatchCodeTableByCountry('France')\n", "\n", " # Then\n", " assert_frame_equal(\n", @@ -1161,6 +1177,34 @@ " '030L20A'\n", " ],\n", " name = 'VAX_LOT')),\n", + " check_dtype = False)\n", + "\n", + " def test_createBatchCodeTableByNonExistingCountry(self):\n", + " # Given\n", + " dataFrame = TestHelper.createDataFrame(\n", + " columns = ['DIED', 'L_THREAT', 'DISABLE', 'VAX_TYPE', 'VAX_MANU', 'VAX_LOT', 'VAX_DOSE_SERIES', 'SPLTTYPE', 'HOSPITAL', 'ER_VISIT'],\n", + " data = [ [1, 0, 0, 'COVID19', 'PFIZER\\BIONTECH', '016M20A', '2', 'GBPFIZER INC2020486806', 0, 0],\n", + " [0, 0, 0, 'COVID19', 'MODERNA', '030L20A', '1', 'FRMODERNATX, INC.MOD20224', 0, 0],\n", + " [1, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0],\n", + " [0, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0]],\n", + " index = [\n", + " \"1048786\",\n", + " \"1048786\",\n", + " \"4711\",\n", + " \"0815\"])\n", + " dataFrame = SevereColumnAdder.addSevereColumn(dataFrame)\n", + " internationalLotTableFactory = InternationalLotTableFactory(dataFrame)\n", + " \n", + " # When\n", + " batchCodeTable = internationalLotTableFactory.createBatchCodeTableByCountry('non existing country')\n", + "\n", + " # Then\n", + " assert_frame_equal(\n", + " batchCodeTable,\n", + " TestHelper.createDataFrame(\n", + " columns = ['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Severe reports', 'Lethality'],\n", + " data = [ ],\n", + " index = pd.Index([], name = 'VAX_LOT')),\n", " check_dtype = False)\n" ] }, @@ -1366,7 +1410,7 @@ "metadata": {}, "outputs": [], "source": [ - "internationalLotTable = InternationalLotTableFactory.createInternationalLotTable(nonDomesticVaers)" + "internationalLotTable = InternationalLotTableFactory(nonDomesticVaers).createInternationalLotTable()" ] }, { @@ -1388,8 +1432,8 @@ "metadata": {}, "outputs": [], "source": [ - "def createAndSaveAndDisplayBatchCodeTableByCountry(nonDomesticVaers, country, minADRsForLethality = None):\n", - " batchCodeTable = InternationalLotTableFactory.createBatchCodeTableByCountry(nonDomesticVaers, country)\n", + "def createAndSaveAndDisplayBatchCodeTableByCountry(internationalLotTableFactory, country, minADRsForLethality = None):\n", + " batchCodeTable = internationalLotTableFactory.createBatchCodeTableByCountry(country)\n", " batchCodeTable.index.set_names(\"Batch\", inplace = True)\n", " if minADRsForLethality is not None:\n", " batchCodeTable.loc[batchCodeTable['Adverse Reaction Reports'] < minADRsForLethality, ['Severe reports', 'Lethality']] = [np.nan, np.nan]\n", @@ -1397,8 +1441,9 @@ " display(country + \":\", batchCodeTable)\n", "\n", "def createAndSaveAndDisplayBatchCodeTablesByCountry(nonDomesticVaers, countries, minADRsForLethality = None):\n", + " internationalLotTableFactory = InternationalLotTableFactory(nonDomesticVaers)\n", " for country in countries:\n", - " createAndSaveAndDisplayBatchCodeTableByCountry(nonDomesticVaers, country, minADRsForLethality)" + " createAndSaveAndDisplayBatchCodeTableByCountry(internationalLotTableFactory, country, minADRsForLethality)" ] }, {