refactoring

This commit is contained in:
frankknoll
2023-02-16 16:11:48 +01:00
parent 54ddc3e47f
commit e7463962c1
5 changed files with 18 additions and 36 deletions

View File

@@ -2,12 +2,6 @@ from SymptomHistogramByBatchcodeTableFactory import SymptomHistogramByBatchcodeT
from HistogramTable2DictTableConverter import HistogramTable2DictTableConverter from HistogramTable2DictTableConverter import HistogramTable2DictTableConverter
def createGlobalHistograms(symptomByBatchcodeTable):
    """Turn a symptom-by-batchcode table into the global dict-by-batchcode table.

    The work is delegated in two steps:
    1. SymptomHistogramByBatchcodeTableFactory builds the global symptom
       histogram table from ``symptomByBatchcodeTable``.
    2. HistogramTable2DictTableConverter converts that histogram table into
       the table that is returned.
    """
    return HistogramTable2DictTableConverter.convertGlobalHistogramTable2DictTable(
        SymptomHistogramByBatchcodeTableFactory.createGlobalSymptomHistogramByBatchcodeTable(
            symptomByBatchcodeTable))
def createHistograms(symptomByBatchcodeTable): def createHistograms(symptomByBatchcodeTable):
symptomHistogramByBatchcodeTable = SymptomHistogramByBatchcodeTableFactory.createSymptomHistogramByBatchcodeTable(symptomByBatchcodeTable) symptomHistogramByBatchcodeTable = SymptomHistogramByBatchcodeTableFactory.createSymptomHistogramByBatchcodeTable(symptomByBatchcodeTable)
dictByBatchcodeTable = HistogramTable2DictTableConverter.convertHistogramTable2DictTable(symptomHistogramByBatchcodeTable) dictByBatchcodeTable = HistogramTable2DictTableConverter.convertHistogramTable2DictTable(symptomHistogramByBatchcodeTable)

View File

@@ -9,13 +9,6 @@ class HistogramTable2DictTableConverter:
.reset_index(level = 'COUNTRY') .reset_index(level = 'COUNTRY')
[['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY']]) [['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY']])
@staticmethod
def convertGlobalHistogramTable2DictTable(globalSymptomHistogramByBatchcodeTable):
    """Collapse the 'SYMPTOM' index level of a global histogram table.

    The rows are grouped by every index level other than 'SYMPTOM', and each
    group is aggregated with ``_histogram_to_json``.
    """
    # All index levels except 'SYMPTOM' identify one group.
    batchcodeLevels = globalSymptomHistogramByBatchcodeTable.index.names.difference(['SYMPTOM'])
    histogramsByBatchcode = globalSymptomHistogramByBatchcodeTable.groupby(batchcodeLevels)
    return histogramsByBatchcode.agg(
        lambda histogramWithBatchcodes: HistogramTable2DictTableConverter._histogram_to_json(
            histogramWithBatchcodes,
            batchcodeLevels))
@staticmethod @staticmethod
def _histogram_to_json(histogram_with_vax_lots, vax_lot_columns): def _histogram_to_json(histogram_with_vax_lots, vax_lot_columns):
histogram = histogram_with_vax_lots.reset_index(level = vax_lot_columns, drop = True) histogram = histogram_with_vax_lots.reset_index(level = vax_lot_columns, drop = True)

View File

@@ -93,10 +93,10 @@ class HistogramTable2DictTableConverterTest(unittest.TestCase):
def test_convertGlobalHistogramTable2DictTable(self): def test_convertGlobalHistogramTable2DictTable(self):
# Given # Given
globalHistogramTable = TestHelper.createDataFrame( globalHistogramTable = TestHelper.createDataFrame(
columns = ['SYMPTOM_COUNT_BY_VAX_LOT'], columns = ['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY'],
data = [ [5], data = [ [5, 'Global'],
[1], [1, 'Global'],
[2]], [2, 'Global']],
index = pd.MultiIndex.from_tuples( index = pd.MultiIndex.from_tuples(
names = ['VAX_LOT1', 'SYMPTOM'], names = ['VAX_LOT1', 'SYMPTOM'],
tuples = [['1808982', 'Blood pressure orthostatic abnormal'], tuples = [['1808982', 'Blood pressure orthostatic abnormal'],
@@ -104,23 +104,25 @@ class HistogramTable2DictTableConverterTest(unittest.TestCase):
['EW0175', 'Chest discomfort']])) ['EW0175', 'Chest discomfort']]))
# When # When
dictTable = HistogramTable2DictTableConverter.convertGlobalHistogramTable2DictTable(globalHistogramTable) dictTable = HistogramTable2DictTableConverter.convertHistogramTable2DictTable(globalHistogramTable)
# Then # Then
assert_frame_equal( assert_frame_equal(
dictTable, dictTable,
TestHelper.createDataFrame( TestHelper.createDataFrame(
columns = ['SYMPTOM_COUNT_BY_VAX_LOT'], columns = ['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY'],
data = [ [ data = [ [
{ {
"Blood pressure orthostatic abnormal": 5, "Blood pressure orthostatic abnormal": 5,
"Chest discomfort": 1 "Chest discomfort": 1
} },
'Global'
], ],
[ [
{ {
"Chest discomfort": 2 "Chest discomfort": 2
} },
'Global'
]], ]],
index = pd.Index( index = pd.Index(
name = 'VAX_LOT1', name = 'VAX_LOT1',

View File

@@ -8,10 +8,3 @@ class SymptomHistogramByBatchcodeTableFactory:
.to_frame(name = 'SYMPTOM_COUNT_BY_VAX_LOT') .to_frame(name = 'SYMPTOM_COUNT_BY_VAX_LOT')
.reset_index(level = 'COUNTRY') .reset_index(level = 'COUNTRY')
[['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY']]) [['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY']])
@staticmethod
def createGlobalSymptomHistogramByBatchcodeTable(symptomByBatchcodeTable):
return (symptomByBatchcodeTable
.groupby(symptomByBatchcodeTable.index.names)
['SYMPTOM'].value_counts()
.to_frame(name = 'SYMPTOM_COUNT_BY_VAX_LOT'))

View File

@@ -67,24 +67,24 @@ class SymptomHistogramByBatchcodeTableFactoryTest(unittest.TestCase):
# Given # Given
symptomByBatchcodeTable = TestHelper.createDataFrame( symptomByBatchcodeTable = TestHelper.createDataFrame(
columns = ['SYMPTOM', 'COUNTRY'], columns = ['SYMPTOM', 'COUNTRY'],
data = [ ['Blood pressure orthostatic abnormal', 'Germany'], data = [ ['Blood pressure orthostatic abnormal', 'Global'],
['Blood pressure orthostatic abnormal', 'Germany'], ['Blood pressure orthostatic abnormal', 'Global'],
['Blood pressure orthostatic abnormal', 'Russian Federation'], ['Blood pressure orthostatic abnormal', 'Global'],
['Headache', 'Germany']], ['Headache', 'Global']],
index = pd.MultiIndex.from_tuples( index = pd.MultiIndex.from_tuples(
names = ['VAX_LOT1', 'VAX_LOT2'], names = ['VAX_LOT1', 'VAX_LOT2'],
tuples = [['1808982', 'EW0175']] * 4)) tuples = [['1808982', 'EW0175']] * 4))
# When # When
globalSymptomHistogramByBatchcodeTable = SymptomHistogramByBatchcodeTableFactory.createGlobalSymptomHistogramByBatchcodeTable(symptomByBatchcodeTable) globalSymptomHistogramByBatchcodeTable = SymptomHistogramByBatchcodeTableFactory.createSymptomHistogramByBatchcodeTable(symptomByBatchcodeTable)
# Then # Then
assert_frame_equal( assert_frame_equal(
globalSymptomHistogramByBatchcodeTable, globalSymptomHistogramByBatchcodeTable,
TestHelper.createDataFrame( TestHelper.createDataFrame(
columns = ['SYMPTOM_COUNT_BY_VAX_LOT'], columns = ['SYMPTOM_COUNT_BY_VAX_LOT', 'COUNTRY'],
data = [ [3], data = [ [3, 'Global'],
[1]], [1, 'Global']],
index = pd.MultiIndex.from_tuples( index = pd.MultiIndex.from_tuples(
names = ['VAX_LOT1', 'VAX_LOT2', 'SYMPTOM'], names = ['VAX_LOT1', 'VAX_LOT2', 'SYMPTOM'],
tuples = [['1808982', 'EW0175', 'Blood pressure orthostatic abnormal'], tuples = [['1808982', 'EW0175', 'Blood pressure orthostatic abnormal'],