From d7e52bdf49af058efe7410906ddfdb9f2f6d3e7c Mon Sep 17 00:00:00 2001 From: frankknoll Date: Tue, 7 Feb 2023 11:17:47 +0100 Subject: [PATCH] refining CountryColumnAdderTest --- src/CountryColumnAdder.py | 29 +++++++++----- src/CountryColumnAdderTest.py | 71 ++++++++++++++++++++++++++++++----- src/VaersReader.py | 2 +- 3 files changed, 82 insertions(+), 20 deletions(-) diff --git a/src/CountryColumnAdder.py b/src/CountryColumnAdder.py index c650ecf1662..d7e8f76dcfb 100644 --- a/src/CountryColumnAdder.py +++ b/src/CountryColumnAdder.py @@ -1,24 +1,35 @@ import pycountry +import pandas as pd class CountryColumnAdder: - @staticmethod - def addCountryColumn(dataFrame): - dataFrame['COUNTRY'] = CountryColumnAdder._splttype2Country(dataFrame['SPLTTYPE']) - return dataFrame + def __init__(self, dataFrame_SPLTTYPE_By_VAERS_ID): + self.dataFrame_COUNTRY_By_VAERS_ID = self._create_dataFrame_COUNTRY_By_VAERS_ID(dataFrame_SPLTTYPE_By_VAERS_ID) + + def addCountryColumn(self, dataFrame): + return pd.merge( + dataFrame, + self.dataFrame_COUNTRY_By_VAERS_ID, + how = 'left', + left_index = True, + right_index = True) - @staticmethod - def _splttype2Country(splttypeSeries): + def _create_dataFrame_COUNTRY_By_VAERS_ID(self, dataFrame_SPLTTYPE_By_VAERS_ID): + dataFrame_COUNTRY_By_VAERS_ID = dataFrame_SPLTTYPE_By_VAERS_ID[['SPLTTYPE']].copy() + dataFrame_COUNTRY_By_VAERS_ID['COUNTRY'] = self._splttype2Country(dataFrame_COUNTRY_By_VAERS_ID['SPLTTYPE']) + dataFrame_COUNTRY_By_VAERS_ID = dataFrame_COUNTRY_By_VAERS_ID.drop(columns = ['SPLTTYPE']) + return dataFrame_COUNTRY_By_VAERS_ID + + def _splttype2Country(self, splttypeSeries): return (splttypeSeries .apply( lambda splttype: - CountryColumnAdder._getCountryNameOfSplttypeOrDefault( + self._getCountryNameOfSplttypeOrDefault( splttype = splttype, default = 'Unknown Country')) .astype("string")) - @staticmethod - def _getCountryNameOfSplttypeOrDefault(splttype, default): + def _getCountryNameOfSplttypeOrDefault(self, splttype, default): if not isinstance(splttype, str): return default diff --git a/src/CountryColumnAdderTest.py b/src/CountryColumnAdderTest.py index 38f23955539..0f18e4c6a30 100644 --- a/src/CountryColumnAdderTest.py +++ b/src/CountryColumnAdderTest.py @@ -14,13 +14,16 @@ class CountryColumnAdderTest(unittest.TestCase): data = [ ['GBPFIZER INC2020486806'], ['FRMODERNATX, INC.MOD20224'], ['dummy']], - index = [ - "4711", - "0815", - "123"]) - + index = pd.Index( + name = 'VAERS_ID', + data = [ + "4711", + "0815", + "123"])) + countryColumnAdder = CountryColumnAdder(dataFrame) + # When - dataFrameWithCountryColumn = CountryColumnAdder.addCountryColumn(dataFrame) + dataFrameWithCountryColumn = countryColumnAdder.addCountryColumn(dataFrame) # Then assert_frame_equal( @@ -30,8 +33,56 @@ class CountryColumnAdderTest(unittest.TestCase): data = [ ['GBPFIZER INC2020486806', 'United Kingdom'], ['FRMODERNATX, INC.MOD20224', 'France'], ['dummy', 'Unknown Country']], - index = [ - "4711", - "0815", - "123"], + index = pd.Index( + name = 'VAERS_ID', + data = [ + "4711", + "0815", + "123"]), + dtypes = {'COUNTRY': 'string'})) + + + def test_addCountryColumn2(self): + # Given + countryColumnAdder = CountryColumnAdder( + TestHelper.createDataFrame( + columns = ['SPLTTYPE'], + data = [ ['GBPFIZER INC2020486806'], + ['FRMODERNATX, INC.MOD20224'], + ['dummy']], + index = pd.Index( + name = 'VAERS_ID', + data = [ + 2547744, + 2547730, + 2540815]))) + dataFrame = TestHelper.createDataFrame( + columns = ['VAX_LOT'], + data = [ ['1808982'], + ['EW0175'], + ['EW0176']], + index = pd.Index( + name = 'VAERS_ID', + data = [ + 2547730, + 2547730, + 2547744])) + + # When + dataFrameWithCountryColumn = countryColumnAdder.addCountryColumn(dataFrame) + + # Then + assert_frame_equal( + dataFrameWithCountryColumn, + TestHelper.createDataFrame( + columns = ['VAX_LOT', 'COUNTRY'], + data = [ ['1808982', 'France'], + ['EW0175', 'France'], + ['EW0176', 'United Kingdom']], + index = pd.Index( + name = 'VAERS_ID', + data = [ + 2547730, + 2547730, + 2547744]), dtypes = {'COUNTRY': 'string'})) diff --git a/src/VaersReader.py b/src/VaersReader.py index 8db1d380c75..94e71bcd4ed 100644 --- a/src/VaersReader.py +++ b/src/VaersReader.py @@ -16,7 +16,7 @@ def getVaersForYears(years): def getNonDomesticVaers(): return _getVaers( [_getVaersDescrReader().readNonDomesticVaersDescr()], - CountryColumnAdder.addCountryColumn) + addCountryColumn = lambda dataFrame: CountryColumnAdder(dataFrame).addCountryColumn(dataFrame)) def _getVaersDescrReader(): return VaersDescrReader(dataDir = "VAERS")