diff --git a/src/TableByBatchcodeFilter.py b/src/TableByBatchcodeFilter.py index 37ae2511d33..a340c2f3953 100644 --- a/src/TableByBatchcodeFilter.py +++ b/src/TableByBatchcodeFilter.py @@ -1,9 +1,18 @@ +from functools import reduce + + class TableByBatchcodeFilter: @staticmethod def filterTableByBatchcode(batchcode, table): + batchcodeColumns = table.index.names table = table.reset_index() - filteredTable = table[ - (table['VAX_LOT1'] == batchcode) | - (table['VAX_LOT2'] == batchcode)] - return filteredTable.set_index(['VAX_LOT1', 'VAX_LOT2']) + filteredTable = table[TableByBatchcodeFilter._existsBatchcodeInAnyBatchcodeColumn(table, batchcodeColumns, batchcode)] + return filteredTable.set_index(batchcodeColumns) + + @staticmethod + def _existsBatchcodeInAnyBatchcodeColumn(table, batchcodeColumns, batchcode): + return reduce( + lambda accum, batchcodeColumn: accum | (table[batchcodeColumn] == batchcode), + batchcodeColumns, + [False] * len(table.index)) diff --git a/src/TableByBatchcodeFilterTest.py b/src/TableByBatchcodeFilterTest.py index 0df8a3d51cd..b601454daea 100644 --- a/src/TableByBatchcodeFilterTest.py +++ b/src/TableByBatchcodeFilterTest.py @@ -6,7 +6,7 @@ import pandas as pd class TableByBatchcodeFilterTest(unittest.TestCase): - def test_convertHistogramTable2JsonTable(self): + def test_convertHistogramTable2JsonTable_2_VAX_LOT_columns(self): # Given batchcode = '1808982' symptomHistogramByBatchcodeTable = TestHelper.createDataFrame( @@ -35,3 +35,31 @@ class TableByBatchcodeFilterTest(unittest.TestCase): tuples = [[batchcode, 'EW0175'], ['015M20A', batchcode]]))) + def test_convertHistogramTable2JsonTable_3_VAX_LOT_columns(self): + # Given + batchcode = '1808983' + symptomHistogramByBatchcodeTable = TestHelper.createDataFrame( + columns = ['SYMPTOM_COUNT_BY_VAX_LOT'], + data = [ ['{"Blood pressure orthostatic abnormal":5,"Chest discomfort":1}'], + ['{"Chest discomfort":2}'], + ['{"Chills":5}']], + index = pd.MultiIndex.from_tuples( + names = ['VAX_LOT1', 'VAX_LOT2', 'VAX_LOT3'], + tuples = [[batchcode, 'EW0175', None], + ['015M20A', None, batchcode], + ['015M20A', 'EW0175', 'dummy2']])) + + # When + filteredTable = TableByBatchcodeFilter.filterTableByBatchcode(batchcode, symptomHistogramByBatchcodeTable) + + # Then + assert_frame_equal( + filteredTable, + TestHelper.createDataFrame( + columns = ['SYMPTOM_COUNT_BY_VAX_LOT'], + data = [ ['{"Blood pressure orthostatic abnormal":5,"Chest discomfort":1}'], + ['{"Chest discomfort":2}']], + index = pd.MultiIndex.from_tuples( + names = ['VAX_LOT1', 'VAX_LOT2', 'VAX_LOT3'], + tuples = [[batchcode, 'EW0175', None], + ['015M20A', None, batchcode]])))