diff --git a/src/BatchCodeTableFactory.py b/src/BatchCodeTableFactory.py index 1e3f2c2966b..8813fed2e00 100644 --- a/src/BatchCodeTableFactory.py +++ b/src/BatchCodeTableFactory.py @@ -14,14 +14,16 @@ class BatchCodeTableFactory: dataFrame['VAX_LOT'] ])) - def createGlobalBatchCodeTable(self): - return self._postProcess(SummationTableFactory.createSummationTable(self.dataFrame.groupby('VAX_LOT'))) + def createGlobalBatchCodeTable(self, countriesAsList = False): + return self._postProcess(SummationTableFactory.createSummationTable(self.dataFrame.groupby('VAX_LOT')), countriesAsList) - def createBatchCodeTableByCountry(self, country): - return self._postProcess(self._getBatchCodeTableByCountry(country)) + def createBatchCodeTableByCountry(self, country, countriesAsList = False): + return self._postProcess(self._getBatchCodeTableByCountry(country), countriesAsList) - def _postProcess(self, batchCodeTable): + def _postProcess(self, batchCodeTable, countriesAsList): batchCodeTable = self.companyColumnAdder.addCompanyColumn(batchCodeTable) + if not countriesAsList: + batchCodeTable['Countries'] = batchCodeTable['Countries'].apply(', '.join) batchCodeTable = batchCodeTable[ [ 'Adverse Reaction Reports', diff --git a/src/BatchCodeTableFactoryTest.py b/src/BatchCodeTableFactoryTest.py index 5762b48205c..4a9a6a7ef53 100644 --- a/src/BatchCodeTableFactoryTest.py +++ b/src/BatchCodeTableFactoryTest.py @@ -39,7 +39,41 @@ class BatchCodeTableFactoryTest(unittest.TestCase): '030L20A' ], name = 'VAX_LOT')), - check_dtype = False) + check_dtype = True) + + def test_createBatchCodeTableByCountry_countriesAsList(self): + # Given + dataFrame = TestHelper.createDataFrame( + columns = ['DIED', 'L_THREAT', 'DISABLE', 'VAX_TYPE', 'VAX_MANU', 'VAX_LOT', 'VAX_DOSE_SERIES', 'SPLTTYPE', 'HOSPITAL', 'ER_VISIT', 'COUNTRY'], + data = [ [1, 0, 0, 'COVID19', 'PFIZER\BIONTECH', '016M20A', '2', 'GBPFIZER INC2020486806', 0, 0, 'United Kingdom'], + [0, 0, 0, 'COVID19', 'MODERNA', '030L20A', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'France'], + [1, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'France'], + [0, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'France']], + index = [ + "1048786", + "1048786", + "4711", + "0815"]) + dataFrame = SevereColumnAdder.addSevereColumn(dataFrame) + batchCodeTableFactory = BatchCodeTableFactory(dataFrame) + + # When + batchCodeTable = batchCodeTableFactory.createBatchCodeTableByCountry('France', countriesAsList = True) + + # Then + assert_frame_equal( + batchCodeTable[['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality']], + TestHelper.createDataFrame( + columns = ['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality'], + data = [ [2, 1, 2, 2, 'MODERNA', ['France'], 2/2 * 100, 1/2 * 100], + [1, 0, 0, 0, 'MODERNA', ['France'], 0/1 * 100, 0/1 * 100]], + index = pd.Index( + [ + '030L20B', + '030L20A' + ], + name = 'VAX_LOT')), + check_dtype = True) def test_createGlobalBatchCodeTable(self): # Given @@ -64,10 +98,10 @@ class BatchCodeTableFactoryTest(unittest.TestCase): assert_frame_equal( batchCodeTable[['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality']], TestHelper.createDataFrame( - columns = ['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality'], - data = [ [1, 1, 0, 0, 'PFIZER\BIONTECH', 'United Kingdom', 1/1 * 100, 1/1 * 100], - [2, 1, 2, 2, 'MODERNA', 'France, United Kingdom', 2/2 * 100, 1/2 * 100], - [1, 0, 0, 0, 'MODERNA', 'France', 0/1 * 100, 0/1 * 100]], + columns = ['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality'], + data = [ [1, 1, 0, 0, 'PFIZER\BIONTECH', 'United Kingdom', 1/1 * 100, 1/1 * 100], + [2, 1, 2, 2, 'MODERNA', 'France, United Kingdom', 2/2 * 100, 1/2 * 100], + [1, 0, 0, 0, 'MODERNA', 'France', 0/1 * 100, 0/1 * 100]], index = pd.Index( [ '016M20A', @@ -75,7 +109,43 @@ class BatchCodeTableFactoryTest(unittest.TestCase): '030L20A' ], name = 'VAX_LOT')), - check_dtype = False) + check_dtype = True) + + def test_createGlobalBatchCodeTable_countriesAsList(self): + # Given + dataFrame = TestHelper.createDataFrame( + columns = ['DIED', 'L_THREAT', 'DISABLE', 'VAX_TYPE', 'VAX_MANU', 'VAX_LOT', 'VAX_DOSE_SERIES', 'SPLTTYPE', 'HOSPITAL', 'ER_VISIT', 'COUNTRY'], + data = [ [1, 0, 0, 'COVID19', 'PFIZER\BIONTECH', '016M20A', '2', 'GBPFIZER INC2020486806', 0, 0, 'United Kingdom'], + [0, 0, 0, 'COVID19', 'MODERNA', '030L20A', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'France'], + [1, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'France'], + [0, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'United Kingdom']], + index = [ + "1048786", + "1048786", + "4711", + "0815"]) + dataFrame = SevereColumnAdder.addSevereColumn(dataFrame) + batchCodeTableFactory = BatchCodeTableFactory(dataFrame) + + # When + batchCodeTable = batchCodeTableFactory.createGlobalBatchCodeTable(countriesAsList = True) + + # Then + assert_frame_equal( + batchCodeTable[['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality']], + TestHelper.createDataFrame( + columns = ['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality'], + data = [ [1, 1, 0, 0, 'PFIZER\BIONTECH', ['United Kingdom'], 1/1 * 100, 1/1 * 100], + [2, 1, 2, 2, 'MODERNA', ['France', 'United Kingdom'], 2/2 * 100, 1/2 * 100], + [1, 0, 0, 0, 'MODERNA', ['France'], 0/1 * 100, 0/1 * 100]], + index = pd.Index( + [ + '016M20A', + '030L20B', + '030L20A' + ], + name = 'VAX_LOT')), + check_dtype = True) def test_createBatchCodeTableByNonExistingCountry(self): # Given diff --git a/src/SummationTableFactory.py b/src/SummationTableFactory.py index 727b86b3bb6..9963ead0f10 100644 --- a/src/SummationTableFactory.py +++ b/src/SummationTableFactory.py @@ -11,7 +11,7 @@ class SummationTableFactory: 'Life Threatening Illnesses': pd.NamedAgg(column = 'L_THREAT', aggfunc = 'sum'), 'Disabilities': pd.NamedAgg(column = 'DISABLE', aggfunc = 'sum'), 'Severities': pd.NamedAgg(column = 'SEVERE', aggfunc = 'sum'), - 'Countries': pd.NamedAgg(column = 'COUNTRY', aggfunc = SummationTableFactory.countries2str) + 'Countries': pd.NamedAgg(column = 'COUNTRY', aggfunc = SummationTableFactory.sortCountries) }) summationTable['Severe reports'] = summationTable['Severities'] / summationTable['Adverse Reaction Reports'] * 100 summationTable['Lethality'] = summationTable['Deaths'] / summationTable['Adverse Reaction Reports'] * 100 @@ -27,5 +27,5 @@ class SummationTableFactory: ]] @staticmethod - def countries2str(countries): - return ', '.join(sorted(set(countries))) \ No newline at end of file + def sortCountries(countries): + return sorted(set(countries)) \ No newline at end of file