refining BatchCodeTableFactoryTest

This commit is contained in:
frankknoll
2023-03-30 16:19:49 +02:00
parent 80b7448aed
commit 9914ac95ef
3 changed files with 86 additions and 14 deletions

View File

@@ -14,14 +14,16 @@ class BatchCodeTableFactory:
dataFrame['VAX_LOT']
]))
def createGlobalBatchCodeTable(self):
return self._postProcess(SummationTableFactory.createSummationTable(self.dataFrame.groupby('VAX_LOT')))
def createGlobalBatchCodeTable(self, countriesAsList = False):
    """Build the batch code table over all rows of ``self.dataFrame``.

    The rows are grouped by the 'VAX_LOT' column, summed up by
    ``SummationTableFactory.createSummationTable`` and then handed to
    ``self._postProcess``.

    countriesAsList is forwarded unchanged to ``self._postProcess``; when it
    is falsy the 'Countries' entries are joined there into one
    ', '-separated string, otherwise they are left as they come out of the
    summation table.
    """
    groupedByBatchCode = self.dataFrame.groupby('VAX_LOT')
    summationTable = SummationTableFactory.createSummationTable(groupedByBatchCode)
    return self._postProcess(summationTable, countriesAsList)
def createBatchCodeTableByCountry(self, country):
return self._postProcess(self._getBatchCodeTableByCountry(country))
def createBatchCodeTableByCountry(self, country, countriesAsList = False):
    """Build the batch code table for a single country.

    The table is obtained from ``self._getBatchCodeTableByCountry(country)``
    and then handed to ``self._postProcess``.

    countriesAsList is forwarded unchanged to ``self._postProcess``; when it
    is falsy the 'Countries' entries are joined there into one
    ', '-separated string.
    """
    batchCodeTableOfCountry = self._getBatchCodeTableByCountry(country)
    return self._postProcess(batchCodeTableOfCountry, countriesAsList)
def _postProcess(self, batchCodeTable):
def _postProcess(self, batchCodeTable, countriesAsList):
batchCodeTable = self.companyColumnAdder.addCompanyColumn(batchCodeTable)
if not countriesAsList:
batchCodeTable['Countries'] = batchCodeTable['Countries'].apply(', '.join)
batchCodeTable = batchCodeTable[
[
'Adverse Reaction Reports',

View File

@@ -39,7 +39,41 @@ class BatchCodeTableFactoryTest(unittest.TestCase):
'030L20A'
],
name = 'VAX_LOT')),
check_dtype = False)
check_dtype = True)
def test_createBatchCodeTableByCountry_countriesAsList(self):
    """With countriesAsList = True the 'Countries' cells stay lists.

    Requests the table for 'France' and expects one row per French batch
    code, each carrying ['France'] instead of a joined string.
    """
    # Given
    dataFrame = TestHelper.createDataFrame(
        columns = ['DIED', 'L_THREAT', 'DISABLE', 'VAX_TYPE', 'VAX_MANU', 'VAX_LOT', 'VAX_DOSE_SERIES', 'SPLTTYPE', 'HOSPITAL', 'ER_VISIT', 'COUNTRY'],
        # The backslash is escaped explicitly: a bare '\B' is an invalid escape
        # sequence that newer Python versions warn about; the value is unchanged.
        data = [  [1, 0, 0, 'COVID19', 'PFIZER\\BIONTECH', '016M20A', '2', 'GBPFIZER INC2020486806', 0, 0, 'United Kingdom'],
                  [0, 0, 0, 'COVID19', 'MODERNA', '030L20A', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'France'],
                  [1, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'France'],
                  [0, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'France']],
        index = [
            "1048786",
            "1048786",
            "4711",
            "0815"])
    dataFrame = SevereColumnAdder.addSevereColumn(dataFrame)
    batchCodeTableFactory = BatchCodeTableFactory(dataFrame)

    # When
    batchCodeTable = batchCodeTableFactory.createBatchCodeTableByCountry('France', countriesAsList = True)

    # Then
    # Same column list for selecting the actual columns and building the expectation.
    checkedColumns = ['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality']
    assert_frame_equal(
        batchCodeTable[checkedColumns],
        TestHelper.createDataFrame(
            columns = checkedColumns,
            data = [  [2, 1, 2, 2, 'MODERNA', ['France'], 2/2 * 100, 1/2 * 100],
                      [1, 0, 0, 0, 'MODERNA', ['France'], 0/1 * 100, 0/1 * 100]],
            index = pd.Index(
                [
                    '030L20B',
                    '030L20A'
                ],
                name = 'VAX_LOT')),
        check_dtype = True)
def test_createGlobalBatchCodeTable(self):
# Given
@@ -64,10 +98,10 @@ class BatchCodeTableFactoryTest(unittest.TestCase):
assert_frame_equal(
batchCodeTable[['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality']],
TestHelper.createDataFrame(
columns = ['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality'],
data = [ [1, 1, 0, 0, 'PFIZER\BIONTECH', 'United Kingdom', 1/1 * 100, 1/1 * 100],
[2, 1, 2, 2, 'MODERNA', 'France, United Kingdom', 2/2 * 100, 1/2 * 100],
[1, 0, 0, 0, 'MODERNA', 'France', 0/1 * 100, 0/1 * 100]],
columns = ['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality'],
data = [ [1, 1, 0, 0, 'PFIZER\BIONTECH', 'United Kingdom', 1/1 * 100, 1/1 * 100],
[2, 1, 2, 2, 'MODERNA', 'France, United Kingdom', 2/2 * 100, 1/2 * 100],
[1, 0, 0, 0, 'MODERNA', 'France', 0/1 * 100, 0/1 * 100]],
index = pd.Index(
[
'016M20A',
@@ -75,7 +109,43 @@ class BatchCodeTableFactoryTest(unittest.TestCase):
'030L20A'
],
name = 'VAX_LOT')),
check_dtype = False)
check_dtype = True)
def test_createGlobalBatchCodeTable_countriesAsList(self):
    """With countriesAsList = True the 'Countries' cells stay lists.

    Batch code '030L20B' is reported from two countries and is expected to
    carry both of them as a sorted list.
    """
    # Given
    dataFrame = TestHelper.createDataFrame(
        columns = ['DIED', 'L_THREAT', 'DISABLE', 'VAX_TYPE', 'VAX_MANU', 'VAX_LOT', 'VAX_DOSE_SERIES', 'SPLTTYPE', 'HOSPITAL', 'ER_VISIT', 'COUNTRY'],
        # The backslash is escaped explicitly: a bare '\B' is an invalid escape
        # sequence that newer Python versions warn about; the value is unchanged.
        data = [  [1, 0, 0, 'COVID19', 'PFIZER\\BIONTECH', '016M20A', '2', 'GBPFIZER INC2020486806', 0, 0, 'United Kingdom'],
                  [0, 0, 0, 'COVID19', 'MODERNA', '030L20A', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'France'],
                  [1, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'France'],
                  [0, 1, 1, 'COVID19', 'MODERNA', '030L20B', '1', 'FRMODERNATX, INC.MOD20224', 0, 0, 'United Kingdom']],
        index = [
            "1048786",
            "1048786",
            "4711",
            "0815"])
    dataFrame = SevereColumnAdder.addSevereColumn(dataFrame)
    batchCodeTableFactory = BatchCodeTableFactory(dataFrame)

    # When
    batchCodeTable = batchCodeTableFactory.createGlobalBatchCodeTable(countriesAsList = True)

    # Then
    # Same column list for selecting the actual columns and building the expectation.
    checkedColumns = ['Adverse Reaction Reports', 'Deaths', 'Disabilities', 'Life Threatening Illnesses', 'Company', 'Countries', 'Severe reports', 'Lethality']
    assert_frame_equal(
        batchCodeTable[checkedColumns],
        TestHelper.createDataFrame(
            columns = checkedColumns,
            data = [  [1, 1, 0, 0, 'PFIZER\\BIONTECH', ['United Kingdom'], 1/1 * 100, 1/1 * 100],
                      [2, 1, 2, 2, 'MODERNA', ['France', 'United Kingdom'], 2/2 * 100, 1/2 * 100],
                      [1, 0, 0, 0, 'MODERNA', ['France'], 0/1 * 100, 0/1 * 100]],
            index = pd.Index(
                [
                    '016M20A',
                    '030L20B',
                    '030L20A'
                ],
                name = 'VAX_LOT')),
        check_dtype = True)
def test_createBatchCodeTableByNonExistingCountry(self):
# Given

View File

@@ -11,7 +11,7 @@ class SummationTableFactory:
'Life Threatening Illnesses': pd.NamedAgg(column = 'L_THREAT', aggfunc = 'sum'),
'Disabilities': pd.NamedAgg(column = 'DISABLE', aggfunc = 'sum'),
'Severities': pd.NamedAgg(column = 'SEVERE', aggfunc = 'sum'),
'Countries': pd.NamedAgg(column = 'COUNTRY', aggfunc = SummationTableFactory.countries2str)
'Countries': pd.NamedAgg(column = 'COUNTRY', aggfunc = SummationTableFactory.sortCountries)
})
summationTable['Severe reports'] = summationTable['Severities'] / summationTable['Adverse Reaction Reports'] * 100
summationTable['Lethality'] = summationTable['Deaths'] / summationTable['Adverse Reaction Reports'] * 100
@@ -27,5 +27,5 @@ class SummationTableFactory:
]]
@staticmethod
def countries2str(countries):
return ', '.join(sorted(set(countries)))
def sortCountries(countries):
    """Return the distinct values of countries as an ascending sorted list."""
    distinctCountries = list(set(countries))
    distinctCountries.sort()
    return distinctCountries