refactoring

This commit is contained in:
frankknoll
2023-08-27 15:59:23 +02:00
parent 03c981d927
commit af45d6d4fd
2 changed files with 68 additions and 4 deletions

View File

@@ -18,12 +18,14 @@ class BarChartDescriptionTables:
@staticmethod @staticmethod
def hasMinSizeOfGuessedHistogram(barChartDescription, minSizeOfGuessedHistogram): def hasMinSizeOfGuessedHistogram(barChartDescription, minSizeOfGuessedHistogram):
sizeOfGuessedHistogram = sum(barChartDescription['Adverse Reaction Reports guessed']) sizeOfGuessedHistogram = sum(
barChartDescription['Adverse Reaction Reports guessed'])
return sizeOfGuessedHistogram >= minSizeOfGuessedHistogram return sizeOfGuessedHistogram >= minSizeOfGuessedHistogram
@staticmethod @staticmethod
def hasMinSizeOfKnownHistogram(barChartDescription, minSizeOfKnownHistogram): def hasMinSizeOfKnownHistogram(barChartDescription, minSizeOfKnownHistogram):
sizeOfKnownHistogram = sum(barChartDescription['Adverse Reaction Reports known']) sizeOfKnownHistogram = sum(
barChartDescription['Adverse Reaction Reports known'])
return sizeOfKnownHistogram >= minSizeOfKnownHistogram return sizeOfKnownHistogram >= minSizeOfKnownHistogram
@staticmethod @staticmethod
@@ -34,6 +36,13 @@ class BarChartDescriptionTables:
@staticmethod @staticmethod
def isGuessedGreaterThanKnown(barChartDescription): def isGuessedGreaterThanKnown(barChartDescription):
sizeOfGuessedHistogram = sum(barChartDescription['Adverse Reaction Reports guessed']) sizeOfGuessedHistogram = sum(
sizeOfKnownHistogram = sum(barChartDescription['Adverse Reaction Reports known']) barChartDescription['Adverse Reaction Reports guessed'])
sizeOfKnownHistogram = sum(
barChartDescription['Adverse Reaction Reports known'])
return sizeOfGuessedHistogram >= sizeOfKnownHistogram return sizeOfGuessedHistogram >= sizeOfKnownHistogram
@staticmethod
def containsCountry(barChartDescription, country):
COUNTRIES = [country.upper() for country in barChartDescription['countries']]
return country.upper() in COUNTRIES

View File

@@ -321,3 +321,58 @@ class BarChartDescriptionTablesTest(unittest.TestCase):
], ],
name='VAX_LOT')), name='VAX_LOT')),
check_dtype=True) check_dtype=True)
def test_filterContainsCountry(self):
# Given
barChartDescriptionTable = TestHelper.createDataFrame(
columns=['BAR_CHART_DESCRIPTION'],
data=[
[
{
'countries': ['Germany', 'Hungary'],
'Adverse Reaction Reports guessed': [25, 20],
'Adverse Reaction Reports known': [20, 30],
'Jensen-Shannon distance': 0.4711
}
],
[
{
'countries': ['Germany', 'America'],
'Adverse Reaction Reports guessed': [25, 20],
'Adverse Reaction Reports known': [250, 200],
'Jensen-Shannon distance': 0.815
}
]],
index=pd.Index(
[
'!D0181',
'some batch code'
],
name='VAX_LOT'))
# When
barChartDescriptionTableResult = BarChartDescriptionTables.filter(
barChartDescriptionTable,
lambda barChartDescription: BarChartDescriptionTables.containsCountry(barChartDescription, 'America'))
# Then
assert_frame_equal(
barChartDescriptionTableResult,
TestHelper.createDataFrame(
columns=['BAR_CHART_DESCRIPTION'],
data=[
[
{
'countries': ['Germany', 'America'],
'Adverse Reaction Reports guessed': [25, 20],
'Adverse Reaction Reports known': [250, 200],
'Jensen-Shannon distance': 0.815
}
]
],
index=pd.Index(
[
'some batch code'
],
name='VAX_LOT')),
check_dtype=True)