refactoring

This commit is contained in:
frankknoll
2022-02-07 21:29:03 +01:00
parent c3b459d5f7
commit be4641d3ff

View File

@@ -258,19 +258,19 @@
"metadata": {},
"outputs": [],
"source": [
"class DoseAnalysis:\n",
"class DoseTableFactory:\n",
" \n",
" @staticmethod\n",
" def getDoseTable(dataFrame):\n",
" def createDoseTable(dataFrame):\n",
" dataFrame = DataFrameFilter().filterByCovid19(dataFrame)\n",
" return DoseAnalysis._getDoseTable(\n",
" return DoseTableFactory._createDoseTable(\n",
" dataFrame.groupby(\n",
" dataFrame['VAX_DOSE_SERIES'].rename('Dose')))\n",
"\n",
" @staticmethod\n",
" def getDoseByMonthTable(dataFrame):\n",
" def createDoseByMonthTable(dataFrame):\n",
" dataFrame = DataFrameFilter().filterByCovid19(dataFrame)\n",
" return DoseAnalysis._getDoseTable(\n",
" return DoseTableFactory._createDoseTable(\n",
" dataFrame.groupby(\n",
" [\n",
" dataFrame['RECVDATE'].dt.year.rename('Year'),\n",
@@ -279,7 +279,7 @@
" ]))\n",
"\n",
" @staticmethod\n",
" def _getDoseTable(dataFrame):\n",
" def _createDoseTable(dataFrame):\n",
" doseTable = SummationTableFactory.createSummationTable(\n",
" dataFrame,\n",
" columnNameMappingsDict = {\n",
@@ -302,25 +302,25 @@
"source": [
"import pycountry\n",
"\n",
"class InternationalLotAnalysis:\n",
"class InternationalLotTableFactory:\n",
" \n",
" @staticmethod\n",
" def getInternationalLotTable(dataFrame):\n",
" def createInternationalLotTable(dataFrame):\n",
" dataFrame = DataFrameFilter().filterByCovid19(dataFrame)\n",
" internationalLotTable = InternationalLotAnalysis._getInternationalLotTable(dataFrame)\n",
" internationalLotTable = InternationalLotTableFactory._createInternationalLotTable(dataFrame)\n",
" return internationalLotTable.sort_values(by = 'Severe reports (%)', ascending = False)\n",
"\n",
" @staticmethod\n",
" def _getInternationalLotTable(dataFrame):\n",
" def _createInternationalLotTable(dataFrame):\n",
" countryColumnName = 'Country'\n",
" InternationalLotAnalysis._addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n",
" return DoseAnalysis._getDoseTable(dataFrame.groupby(dataFrame[countryColumnName]))\n",
" InternationalLotTableFactory._addCountryColumn(dataFrame, countryColumnName = countryColumnName)\n",
" return DoseTableFactory._createDoseTable(dataFrame.groupby(dataFrame[countryColumnName]))\n",
"\n",
" @staticmethod\n",
" def _addCountryColumn(dataFrame, countryColumnName):\n",
" dataFrame[countryColumnName] = dataFrame.apply(\n",
" lambda row:\n",
" InternationalLotAnalysis._getCountryNameOfSplttypeOrDefault(\n",
" InternationalLotTableFactory._getCountryNameOfSplttypeOrDefault(\n",
" splttype = row['SPLTTYPE'],\n",
" default = 'Unknown Country'),\n",
" axis = 'columns')\n",
@@ -698,9 +698,9 @@
"source": [
"from pandas.testing import assert_frame_equal\n",
"\n",
"class DoseAnalysisTest(unittest.TestCase):\n",
"class DoseTableFactoryTest(unittest.TestCase):\n",
"\n",
" def test_getDoseTable(self):\n",
" def test_createDoseTable(self):\n",
" # Given\n",
" dataFrame = TestHelper.createDataFrame(\n",
" columns = ['DIED', 'L_THREAT', 'DISABLE', 'VAX_TYPE', 'VAX_MANU', 'VAX_LOT', 'VAX_DOSE_SERIES', 'HOSPITAL', 'ER_VISIT'],\n",
@@ -714,7 +714,7 @@
" dtypes = {'VAX_DOSE_SERIES': \"string\"})\n",
" \n",
" # When\n",
" doseTable = DoseAnalysis.getDoseTable(dataFrame)\n",
" doseTable = DoseTableFactory.createDoseTable(dataFrame)\n",
"\n",
" # Then\n",
" assert_frame_equal(\n",
@@ -729,7 +729,7 @@
" },\n",
" index = pd.Index(['1', '2'], dtype = \"string\", name = 'Dose')))\n",
" \n",
" def test_getDoseByMonthTable(self):\n",
" def test_createDoseByMonthTable(self):\n",
" # Given\n",
" parseDate = lambda dateStr: pd.to_datetime(dateStr, format = \"%m/%d/%Y\")\n",
" dataFrame = TestHelper.createDataFrame(\n",
@@ -744,7 +744,7 @@
" dtypes = {'VAX_DOSE_SERIES': \"string\"})\n",
" \n",
" # When\n",
" doseByMonthTable = DoseAnalysis.getDoseByMonthTable(dataFrame)\n",
" doseByMonthTable = DoseTableFactory.createDoseByMonthTable(dataFrame)\n",
"\n",
" # Then\n",
" assert_frame_equal(\n",
@@ -775,9 +775,9 @@
"source": [
"from pandas.testing import assert_frame_equal\n",
"\n",
"class InternationalLotAnalysisTest(unittest.TestCase):\n",
"class InternationalLotTableFactoryTest(unittest.TestCase):\n",
"\n",
" def test_getInternationalLotTable(self):\n",
" def test_createInternationalLotTable(self):\n",
" # Given\n",
" dataFrame = TestHelper.createDataFrame(\n",
" columns = ['DIED', 'L_THREAT', 'DISABLE', 'VAX_TYPE', 'VAX_MANU', 'VAX_LOT', 'VAX_DOSE_SERIES', 'SPLTTYPE', 'HOSPITAL', 'ER_VISIT'],\n",
@@ -794,7 +794,7 @@
" \"0816\"])\n",
" \n",
" # When\n",
" internationalLotTable = InternationalLotAnalysis.getInternationalLotTable(dataFrame)\n",
" internationalLotTable = InternationalLotTableFactory.createInternationalLotTable(dataFrame)\n",
"\n",
" # Then\n",
" assert_frame_equal(\n",
@@ -947,17 +947,17 @@
"source": [
"# https://www.howbadismybatch.com/firstsecond.html\n",
"\n",
"def getDoseTable():\n",
"def createDoseTable():\n",
" vaersDescrs = VaersDescrReader(dataDir = \"VAERS\").readAllVaersDescrs()\n",
" dataFrame = VaersDescr2DataFrameConverter.createDataFrameFromDescrs(vaersDescrs)\n",
" DataFrameNormalizer.normalize(dataFrame)\n",
" return DoseAnalysis.getDoseTable(dataFrame)\n",
" return DoseTableFactory.createDoseTable(dataFrame)\n",
"\n",
"def getDoseByMonthTable():\n",
"def createDoseByMonthTable():\n",
" vaersDescrs = VaersDescrReader(dataDir = \"VAERS\").readAllVaersDescrs()\n",
" dataFrame = VaersDescr2DataFrameConverter.createDataFrameFromDescrs(vaersDescrs)\n",
" DataFrameNormalizer.normalize(dataFrame)\n",
" return DoseAnalysis.getDoseByMonthTable(dataFrame)"
" return DoseTableFactory.createDoseByMonthTable(dataFrame)"
]
},
{
@@ -967,7 +967,7 @@
"metadata": {},
"outputs": [],
"source": [
"getDoseTable()"
"createDoseTable()"
]
},
{
@@ -977,7 +977,7 @@
"metadata": {},
"outputs": [],
"source": [
"doseByMonthTable = getDoseByMonthTable()\n",
"doseByMonthTable = createDoseByMonthTable()\n",
"doseByMonthTable.to_excel('results/doseByMonthTable.xlsx')\n",
"doseByMonthTable"
]
@@ -999,11 +999,11 @@
"source": [
"# https://www.howbadismybatch.com/international.html\n",
"\n",
"def getInternationalLotTable():\n",
"def createInternationalLotTable():\n",
" vaersDescr = VaersDescrReader(dataDir = 'VAERS').readNonDomesticVaersDescr()\n",
" dataFrame = VaersDescr2DataFrameConverter.createDataFrameFromDescr(vaersDescr)\n",
" DataFrameNormalizer.normalize(dataFrame)\n",
" return InternationalLotAnalysis.getInternationalLotTable(dataFrame)\n"
" return InternationalLotTableFactory.createInternationalLotTable(dataFrame)\n"
]
},
{
@@ -1013,7 +1013,7 @@
"metadata": {},
"outputs": [],
"source": [
"internationalLotTable = getInternationalLotTable()"
"internationalLotTable = createInternationalLotTable()"
]
},
{