diff --git a/src/HistogramTable2DictTableConverter.py b/src/HistogramTable2DictTableConverter.py index 017b04c82fb..4d1db1d8b1b 100644 --- a/src/HistogramTable2DictTableConverter.py +++ b/src/HistogramTable2DictTableConverter.py @@ -3,11 +3,11 @@ class HistogramTable2DictTableConverter: @staticmethod def convertHistogramTable2DictTable(symptomHistogramByBatchcodeTable): vax_lot_columns = symptomHistogramByBatchcodeTable.index.names.difference(['SYMPTOM']) - return ( - symptomHistogramByBatchcodeTable - .groupby(vax_lot_columns) + return (symptomHistogramByBatchcodeTable + .groupby(vax_lot_columns + ['COUNTRY']) .agg(lambda histogram_with_vax_lots: HistogramTable2DictTableConverter._histogram_to_json(histogram_with_vax_lots, vax_lot_columns)) - ) + .reset_index(level = 'COUNTRY') + [['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY']]) @staticmethod def _histogram_to_json(histogram_with_vax_lots, vax_lot_columns): diff --git a/src/HistogramTable2DictTableConverterTest.py b/src/HistogramTable2DictTableConverterTest.py index f024ee48a5e..819895d63b8 100644 --- a/src/HistogramTable2DictTableConverterTest.py +++ b/src/HistogramTable2DictTableConverterTest.py @@ -9,10 +9,10 @@ class HistogramTable2DictTableConverterTest(unittest.TestCase): def test_convertHistogramTable2DictTable(self): # Given histogramTable = TestHelper.createDataFrame( - columns = ['SYMPTOM_COUNT_BY_VAX_LOT'], - data = [ [5], - [1], - [2]], + columns = ['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY'], + data = [ [5, 'Germany'], + [1, 'Germany'], + [2, 'Russian Federation']], index = pd.MultiIndex.from_tuples( names = ['VAX_LOT1', 'SYMPTOM'], tuples = [['1808982', 'Blood pressure orthostatic abnormal'], @@ -26,17 +26,19 @@ class HistogramTable2DictTableConverterTest(unittest.TestCase): assert_frame_equal( dictTable, TestHelper.createDataFrame( - columns = ['SYMPTOM_COUNT_BY_VAX_LOT'], + columns = ['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY'], data = [ [ { "Blood pressure orthostatic abnormal": 5, "Chest discomfort": 1 - } + }, + 'Germany' ], [ { "Chest discomfort": 2 - } + }, + 'Russian Federation' ]], index = pd.Index( name = 'VAX_LOT1',