From 0d7a97661244462c48e12582c5cabdcc212465aa Mon Sep 17 00:00:00 2001 From: frankknoll Date: Fri, 4 Feb 2022 01:25:07 +0100 Subject: [PATCH] refactoring --- HowBadIsMyBatch.ipynb | 111 ++++++++++++++++++++++-------------------- 1 file changed, 59 insertions(+), 52 deletions(-) diff --git a/HowBadIsMyBatch.ipynb b/HowBadIsMyBatch.ipynb index d5f4b236ed6..8a0ab1f2352 100644 --- a/HowBadIsMyBatch.ipynb +++ b/HowBadIsMyBatch.ipynb @@ -40,7 +40,6 @@ " 'VAERSDATA':\n", " self._read_csv(\n", " folder + year + \"VAERSDATA.csv\",\n", - " # FK-TODO: use Column enum\n", " ['VAERS_ID', 'DIED', 'L_THREAT', 'DISABLE', 'HOSPITAL', 'ER_VISIT']),\n", " 'VAERSVAX':\n", " self._read_csv(\n", @@ -129,55 +128,56 @@ " \n", " def __init__(self, dataFrame : pd.DataFrame):\n", " self.dataFrame = dataFrame \n", + " self._convertColumnOfDataFrameToNumeric('DIED')\n", + " self._convertColumnOfDataFrameToNumeric('L_THREAT')\n", + " self._convertColumnOfDataFrameToNumeric('DISABLE')\n", + " self._convertColumnOfDataFrameToNumeric('HOSPITAL')\n", + " self._convertColumnOfDataFrameToNumeric('ER_VISIT')\n", "\n", " def createBatchCodeTable(self):\n", - " return self._asDataFrame(\n", + " batchCodeTable = self.dataFrame.groupby('VAX_LOT').agg(\n", " {\n", - " 'ADRs': self._getADRs(),\n", - " 'DEATHS': self._getDEATHS(),\n", - " 'DISABILITIES': self._getDISABILITIES(),\n", - " 'LIFE THREATENING ILLNESSES': self._getLIFE_THREATENING_ILLNESSES()\n", + " 'DIED': ['sum', 'size'],\n", + " 'L_THREAT': 'sum',\n", + " 'DISABLE': 'sum'\n", " })\n", + " self._flattenColumns(batchCodeTable)\n", + " batchCodeTable = batchCodeTable.rename(\n", + " columns =\n", + " {\n", + " \"DIED_size\": \"ADRs\",\n", + " \"DIED_sum\": \"DEATHS\",\n", + " \"L_THREAT_sum\": \"LIFE THREATENING ILLNESSES\",\n", + " \"DISABLE_sum\": \"DISABILITIES\"\n", + " })[['ADRs', 'DEATHS', 'DISABILITIES', 'LIFE THREATENING ILLNESSES']]\n", + " return batchCodeTable.sort_values(by = 'ADRs', ascending = False)\n", "\n", " # create table from https://www.howbadismybatch.com/combined.html\n", " def createSevereEffectsBatchCodeTable(self):\n", + " batchCodeTable = self.dataFrame.groupby('VAX_LOT').agg(\n", + " {\n", + " 'DIED': ['sum', 'size'],\n", + " 'L_THREAT': 'sum',\n", + " 'DISABLE': 'sum',\n", + " 'HOSPITAL': 'sum',\n", + " 'ER_VISIT': 'sum'\n", + " })\n", + " self._flattenColumns(batchCodeTable)\n", + " batchCodeTable = batchCodeTable.rename(\n", + " columns =\n", + " {\n", + " \"DIED_size\": \"ADRs\",\n", + " \"DIED_sum\": \"DEATHS\",\n", + " \"L_THREAT_sum\": \"LIFE THREATENING ILLNESSES\",\n", + " \"DISABLE_sum\": \"DISABILITIES\",\n", + " 'HOSPITAL_sum': 'HOSPITALISATIONS',\n", + " 'ER_VISIT_sum': 'EMERGENCY ROOM OR DOCTOR VISITS'\n", + " })[['ADRs', 'DEATHS', 'DISABILITIES', 'LIFE THREATENING ILLNESSES', 'HOSPITALISATIONS', 'EMERGENCY ROOM OR DOCTOR VISITS']]\n", + " batchCodeTable = batchCodeTable.sort_values(by = 'ADRs', ascending = False)\n", " return self._addCompanyColumn(\n", - " self._asDataFrame(\n", - " {\n", - " 'ADRs': self._getADRs(),\n", - " 'DEATHS': self._getDEATHS(),\n", - " 'DISABILITIES': self._getDISABILITIES(),\n", - " 'LIFE THREATENING ILLNESSES': self._getLIFE_THREATENING_ILLNESSES(),\n", - " 'HOSPITALISATIONS': self._getHOSPITALISATIONS(),\n", - " 'EMERGENCY ROOM OR DOCTOR VISITS': self._getER_VISITs()\n", - " }),\n", + " batchCodeTable,\n", " self._createCompanyByBatchCodeTable())\n", "\n", - " def _getADRs(self):\n", - " return self.dataFrame['VAX_LOT'].value_counts()\n", - "\n", - " def _getDEATHS(self):\n", - " return self._countValues('DIED')\n", - "\n", - " def _getDISABILITIES(self):\n", - " return self._countValues('DISABLE')\n", - "\n", - " def _getLIFE_THREATENING_ILLNESSES(self):\n", - " return self._countValues('L_THREAT')\n", - "\n", - " def _getHOSPITALISATIONS(self):\n", - " return self._countValues('HOSPITAL')\n", - "\n", - " def _getER_VISITs(self):\n", - " return self._countValues('ER_VISIT')\n", - "\n", - " def _countValues(self, column):\n", - " return self.dataFrame[self.dataFrame[column] == 'Y']['VAX_LOT'].value_counts()\n", - "\n", - " def _asDataFrame(self, dict):\n", - " dataFrame = pd.concat(dict, axis = 'columns')\n", - " dataFrame.index.name = 'VAX_LOT'\n", - " return dataFrame.replace(to_replace = np.nan, value = 0)\n", "\n", " def _addCompanyColumn(self, batchCodeTable, companyByBatchCodeTable):\n", " return pd.merge(\n", @@ -196,6 +196,13 @@ " manufacturerByBatchCodeTable = manufacturerByBatchCodeTable.drop_duplicates(subset = ['VAX_LOT'])\n", " return manufacturerByBatchCodeTable.set_index('VAX_LOT')\n", "\n", + " def _convertColumnOfDataFrameToNumeric(self, column):\n", + " self.dataFrame[column] = np.where(self.dataFrame[column] == 'Y', 1, 0)\n", + "\n", + " def _flattenColumns(self, batchCodeTable):\n", + " batchCodeTable.columns = [\"_\".join(a) for a in batchCodeTable.columns.to_flat_index()]\n", + "\n", + "\n", "class BatchCodeTableFactory:\n", "\n", " @staticmethod\n", @@ -468,14 +475,14 @@ " batchCodeTableExpected = pd.DataFrame(\n", " data = {\n", " 'ADRs': [1, 1],\n", - " 'DEATHS': [1, 0],\n", - " 'DISABILITIES': [0, 1],\n", - " 'LIFE THREATENING ILLNESSES': [1, 0],\n", - " 'HOSPITALISATIONS': [1, 0],\n", + " 'DEATHS': [0, 1],\n", + " 'DISABILITIES': [1, 0],\n", + " 'LIFE THREATENING ILLNESSES': [0, 1],\n", + " 'HOSPITALISATIONS': [0, 1],\n", " 'EMERGENCY ROOM OR DOCTOR VISITS': [1, 1],\n", - " 'COMPANY': ['MODERNA', 'PFIZER\\BIONTECH']\n", + " 'COMPANY': ['PFIZER\\BIONTECH', 'MODERNA']\n", " },\n", - " index = pd.Index(['037K20A', '025L20A'], name='VAX_LOT'))\n", + " index = pd.Index(['025L20A', '037K20A'], name = 'VAX_LOT'))\n", " assert_frame_equal(batchCodeTable, batchCodeTableExpected, check_dtype = False)\n", "\n", " def test_createBatchCodeTable2(self):\n", @@ -483,9 +490,9 @@ " [\n", " {\n", " 'VAERSDATA': self.createDataFrame(\n", - " columns = ['DIED', 'L_THREAT', 'DISABLE'],\n", - " data = [ ['Y', np.NaN, np.NaN],\n", - " [np.NaN, np.NaN, 'Y']],\n", + " columns = ['DIED', 'L_THREAT', 'DISABLE', 'HOSPITAL', 'ER_VISIT'],\n", + " data = [ ['Y', np.NaN, np.NaN, np.NaN, np.NaN],\n", + " [np.NaN, np.NaN, 'Y', np.NaN, np.NaN]],\n", " index = [\n", " \"0916600\",\n", " \"0916601\"]),\n", @@ -500,9 +507,9 @@ " },\n", " {\n", " 'VAERSDATA': self.createDataFrame(\n", - " columns = ['DIED', 'L_THREAT', 'DISABLE'],\n", - " data = [ [np.NaN, np.NaN, np.NaN],\n", - " [np.NaN, np.NaN, 'Y']],\n", + " columns = ['DIED', 'L_THREAT', 'DISABLE', 'HOSPITAL', 'ER_VISIT'],\n", + " data = [ [np.NaN, np.NaN, np.NaN, np.NaN, np.NaN],\n", + " [np.NaN, np.NaN, 'Y', np.NaN, np.NaN]],\n", " index = [\n", " \"1996873\",\n", " \"1996874\"]),\n",