diff --git a/src/BatchCodeTableFactory.py b/src/BatchCodeTableFactory.py index ad6e54a0619..7510dc8c835 100644 --- a/src/BatchCodeTableFactory.py +++ b/src/BatchCodeTableFactory.py @@ -6,14 +6,7 @@ class BatchCodeTableFactory: def __init__(self, dataFrame: pd.DataFrame): self.dataFrame = dataFrame - self.companyColumnAdder = CompanyColumnAdder(dataFrame) - self.countryBatchCodeTable = SummationTableFactory.createSummationTable( - dataFrame.groupby( - [ - dataFrame['COUNTRY'], - dataFrame['VAX_LOT'] - ])) - + def createGlobalBatchCodeTable(self): return self._postProcess(SummationTableFactory.createSummationTable(self.dataFrame.groupby('VAX_LOT'))) @@ -21,7 +14,7 @@ class BatchCodeTableFactory: return self._postProcess(self._getBatchCodeTableByCountry(country)) def _postProcess(self, batchCodeTable): - batchCodeTable = self.companyColumnAdder.addCompanyColumn(batchCodeTable) + batchCodeTable = CompanyColumnAdder(self.dataFrame).addCompanyColumn(batchCodeTable) batchCodeTable = batchCodeTable[ [ 'Adverse Reaction Reports', @@ -35,10 +28,19 @@ class BatchCodeTableFactory: return batchCodeTable.sort_values(by = 'Severe reports', ascending = False) def _getBatchCodeTableByCountry(self, country): - if country in self.countryBatchCodeTable.index: - return self.countryBatchCodeTable.loc[country] + countryBatchCodeTable = self._getCountryBatchCodeTable() + if country in countryBatchCodeTable.index: + return countryBatchCodeTable.loc[country] else: - return self._getEmptyBatchCodeTable() + return self._getEmptyBatchCodeTable(countryBatchCodeTable) - def _getEmptyBatchCodeTable(self): - return self.countryBatchCodeTable[0:0].droplevel(0) + def _getCountryBatchCodeTable(self): + return SummationTableFactory.createSummationTable( + self.dataFrame.groupby( + [ + self.dataFrame['COUNTRY'], + self.dataFrame['VAX_LOT'] + ])) + + def _getEmptyBatchCodeTable(self, countryBatchCodeTable): + return countryBatchCodeTable[0:0].droplevel(0)