diff --git a/src/HistogramTable2DictTableConverter.py b/src/HistogramTable2DictTableConverter.py index d4752e3a86a..017b04c82fb 100644 --- a/src/HistogramTable2DictTableConverter.py +++ b/src/HistogramTable2DictTableConverter.py @@ -11,5 +11,5 @@ class HistogramTable2DictTableConverter: @staticmethod def _histogram_to_json(histogram_with_vax_lots, vax_lot_columns): - histogram = histogram_with_vax_lots.reset_index(level = vax_lot_columns, drop=True) + histogram = histogram_with_vax_lots.reset_index(level = vax_lot_columns, drop = True) return histogram.to_dict() diff --git a/src/SymptomHistogramByBatchcodeTableFactory.py b/src/SymptomHistogramByBatchcodeTableFactory.py index b485c81ea46..4f701179b88 100644 --- a/src/SymptomHistogramByBatchcodeTableFactory.py +++ b/src/SymptomHistogramByBatchcodeTableFactory.py @@ -2,9 +2,9 @@ class SymptomHistogramByBatchcodeTableFactory: @staticmethod def createSymptomHistogramByBatchcodeTable(symptomByBatchcodeTable): - return ( - symptomByBatchcodeTable - .groupby(symptomByBatchcodeTable.index.names) - ['SYMPTOM'].value_counts() - .to_frame(name = 'SYMPTOM_COUNT_BY_VAX_LOT') - ) + return (symptomByBatchcodeTable + .groupby(symptomByBatchcodeTable.index.names + ['COUNTRY']) + ['SYMPTOM'].value_counts() + .to_frame(name = 'SYMPTOM_COUNT_BY_VAX_LOT') + .reset_index(level = 'COUNTRY') + [['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY']]) diff --git a/src/SymptomHistogramByBatchcodeTableFactoryTest.py b/src/SymptomHistogramByBatchcodeTableFactoryTest.py index 8d518a69bcc..62590611e73 100644 --- a/src/SymptomHistogramByBatchcodeTableFactoryTest.py +++ b/src/SymptomHistogramByBatchcodeTableFactoryTest.py @@ -9,10 +9,10 @@ class SymptomHistogramByBatchcodeTableFactoryTest(unittest.TestCase): def test_createSymptomHistogramByBatchcodeTable(self): # Given symptomByBatchcodeTable = TestHelper.createDataFrame( - columns = ['SYMPTOM'], - data = [ ['Blood pressure orthostatic abnormal'], - ['Blood pressure orthostatic abnormal'], - ['Blood pressure orthostatic abnormal']], + columns = ['SYMPTOM', 'COUNTRY'], + data = [ ['Blood pressure orthostatic abnormal', 'Germany'], + ['Blood pressure orthostatic abnormal', 'Germany'], + ['Blood pressure orthostatic abnormal', 'Germany']], index = pd.Index( name = 'VAX_LOT1', data = ['EW0175', @@ -26,9 +26,9 @@ class SymptomHistogramByBatchcodeTableFactoryTest(unittest.TestCase): assert_frame_equal( symptomHistogramByBatchcodeTable, TestHelper.createDataFrame( - columns = ['SYMPTOM_COUNT_BY_VAX_LOT'], - data = [ [1], - [2]], + columns = ['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY'], + data = [ [1, 'Germany'], + [2, 'Germany']], index = pd.MultiIndex.from_tuples( names = ['VAX_LOT1', 'SYMPTOM'], tuples = [['1808982', 'Blood pressure orthostatic abnormal'], @@ -37,13 +37,14 @@ class SymptomHistogramByBatchcodeTableFactoryTest(unittest.TestCase): def test_createSymptomHistogramByBatchcodeTable_two_VAX_LOTs_Index(self): # Given symptomByBatchcodeTable = TestHelper.createDataFrame( - columns = ['SYMPTOM'], - data = [ ['Blood pressure orthostatic abnormal'], - ['Blood pressure orthostatic abnormal'], - ['Headache']], + columns = ['SYMPTOM', 'COUNTRY'], + data = [ ['Blood pressure orthostatic abnormal', 'Germany'], + ['Blood pressure orthostatic abnormal', 'Germany'], + ['Blood pressure orthostatic abnormal', 'Russian Federation'], + ['Headache', 'Germany']], index = pd.MultiIndex.from_tuples( names = ['VAX_LOT1', 'VAX_LOT2'], - tuples = [['1808982', 'EW0175']] * 3)) + tuples = [['1808982', 'EW0175']] * 4)) # When symptomHistogramByBatchcodeTable = SymptomHistogramByBatchcodeTableFactory.createSymptomHistogramByBatchcodeTable(symptomByBatchcodeTable) @@ -52,10 +53,12 @@ class SymptomHistogramByBatchcodeTableFactoryTest(unittest.TestCase): assert_frame_equal( symptomHistogramByBatchcodeTable, TestHelper.createDataFrame( - columns = ['SYMPTOM_COUNT_BY_VAX_LOT'], - data = [ [2], - [1]], + columns = ['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY'], + data = [ [2, 'Germany'], + [1, 'Germany'], + [1, 'Russian Federation']], index = pd.MultiIndex.from_tuples( names = ['VAX_LOT1', 'VAX_LOT2', 'SYMPTOM'], tuples = [['1808982', 'EW0175', 'Blood pressure orthostatic abnormal'], - ['1808982', 'EW0175', 'Headache']]))) + ['1808982', 'EW0175', 'Headache'], + ['1808982', 'EW0175', 'Blood pressure orthostatic abnormal']])))