From 00da43dd5d91e9d2397663d2c10f27e90467dc9c Mon Sep 17 00:00:00 2001 From: frankknoll Date: Fri, 20 Jan 2023 01:07:17 +0100 Subject: [PATCH] refining SymptomsByBatchcodesTableFactoryTest --- src/SymptomsByBatchcodesTableFactory.py | 12 ++++++++++-- src/SymptomsByBatchcodesTableFactoryTest.py | 7 ++++--- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/SymptomsByBatchcodesTableFactory.py b/src/SymptomsByBatchcodesTableFactory.py index 657a80942bd..43df6049b36 100644 --- a/src/SymptomsByBatchcodesTableFactory.py +++ b/src/SymptomsByBatchcodesTableFactory.py @@ -5,12 +5,20 @@ class SymptomsByBatchcodesTableFactory: @staticmethod def createSymptomsByBatchcodesTable(VAERSVAX, VAERSSYMPTOMS): - index_columns = ['VAX_LOT1', 'VAX_LOT2'] + index_columns = SymptomsByBatchcodesTableFactory._getIndexColumns(VAERSVAX) return pd.merge( SymptomsByBatchcodesTableFactory._get_VAERSVAX_WITH_VAX_LOTS(VAERSVAX, index_columns), SymptomsByBatchcodesTableFactory._getSymptomsTable(VAERSSYMPTOMS), on = 'VAERS_ID').set_index(index_columns)[['SYMPTOMS']] + @staticmethod + def _getIndexColumns(VAERSVAX): + return [f"VAX_LOT{num}" for num in range(1, SymptomsByBatchcodesTableFactory._getMaxNumShots(VAERSVAX) + 1)] + + @staticmethod + def _getMaxNumShots(VAERSVAX): + return VAERSVAX.index.value_counts().iloc[0] + @staticmethod def _get_VAERSVAX_WITH_VAX_LOTS(VAERSVAX, index_columns): return pd.concat( @@ -21,7 +29,7 @@ class SymptomsByBatchcodesTableFactory: def _getVaxLotsTable(VAERSVAX, index_columns): VAX_LOT_LIST_Table = VAERSVAX.groupby("VAERS_ID").agg(VAX_LOT_LIST = pd.NamedAgg(column = 'VAX_LOT', aggfunc = list)) return pd.DataFrame( - [fill(VAX_LOTS, 2, str(np.nan)) for VAX_LOTS in VAX_LOT_LIST_Table['VAX_LOT_LIST'].tolist()], + [fill(VAX_LOTS, len(index_columns), str(np.nan)) for VAX_LOTS in VAX_LOT_LIST_Table['VAX_LOT_LIST'].tolist()], columns = index_columns, index = VAX_LOT_LIST_Table.index) diff --git a/src/SymptomsByBatchcodesTableFactoryTest.py b/src/SymptomsByBatchcodesTableFactoryTest.py index c18528e4766..1b0219be5f4 100644 --- a/src/SymptomsByBatchcodesTableFactoryTest.py +++ b/src/SymptomsByBatchcodesTableFactoryTest.py @@ -92,9 +92,10 @@ class SymptomsByBatchcodesTableFactoryTest(unittest.TestCase): columns = ['SYMPTOMS'], data = [ ['Blood pressure orthostatic abnormal'], ['Blood pressure orthostatic abnormal']], - index = pd.MultiIndex.from_tuples( - names = ['VAX_LOT1', 'VAX_LOT2'], - tuples = [['EW0175', str(np.nan)]] * 2)), + index = pd.Index( + name = 'VAX_LOT1', + data = ['EW0175', + 'EW0175'])), check_dtype = False) def test_createSymptomsByBatchcodesTable_two_patients_distinct_symptoms(self):