# ---- src/JensenShannonDistanceColumnAdder.py ----
from scipy.spatial import distance


class JensenShannonDistanceColumnAdder:
    """Adds a Jensen-Shannon distance column to a bar chart description table."""

    @staticmethod
    def addJensenShannonDistanceColumn(barChartDescriptionTable):
        """Add the column 'JENSEN_SHANNON_DISTANCE' to the given table.

        Every cell of the column 'BAR_CHART_DESCRIPTION' is expected to be a
        mapping holding the two equally long sequences
        'Adverse Reaction Reports guessed' and 'Adverse Reaction Reports known'.
        For each row the Jensen-Shannon distance (base 2) between these two
        sequences is stored in the new column.

        The table is modified in place and is also returned, so the call can
        be used in a fluent style.

        An empty table gets an empty float column.  (Row-wise
        ``DataFrame.apply`` returns a DataFrame instead of a Series for an
        empty table, which either cannot be assigned to a single column or
        yields a non-float column, hence the Series-based implementation.)

        Args:
            barChartDescriptionTable: pandas DataFrame having a column
                'BAR_CHART_DESCRIPTION'.

        Returns:
            The same DataFrame object, extended by the float column
            'JENSEN_SHANNON_DISTANCE'.
        """
        barChartDescriptionTable['JENSEN_SHANNON_DISTANCE'] = (
            barChartDescriptionTable['BAR_CHART_DESCRIPTION']
            .map(JensenShannonDistanceColumnAdder._getJensenShannonDistance)
            # an empty Series is mapped to dtype object, so force a float column
            .astype(float))
        return barChartDescriptionTable

    @staticmethod
    def _getJensenShannonDistance(barChartDescription):
        """Return the base 2 Jensen-Shannon distance between the guessed and the known reports."""
        return distance.jensenshannon(
            barChartDescription['Adverse Reaction Reports guessed'],
            barChartDescription['Adverse Reaction Reports known'],
            base=2.0)


# ---- src/JensenShannonDistanceColumnAdderTest.py ----
import unittest
import pandas as pd
from pandas.testing import assert_frame_equal
from TestHelper import TestHelper
from JensenShannonDistanceColumnAdder import JensenShannonDistanceColumnAdder
from scipy.spatial import distance

class JensenShannonDistanceColumnAdderTest(unittest.TestCase):
    """Unit tests for JensenShannonDistanceColumnAdder."""

    def test_addJensenShannonDistanceColumn(self):
        # Given
        barChartDescriptionTable = TestHelper.createDataFrame(
            columns = ['BAR_CHART_DESCRIPTION'],
            data = [
                [
                    {
                        'countries': ['Germany', 'Hungary'],
                        'Adverse Reaction Reports guessed': [10, 15],
                        'Adverse Reaction Reports known': [20, 30]
                    }
                ]
            ],
            index = pd.Index(
                [
                    '!D0181',
                ],
                name = 'VAX_LOT'))

        # When
        barChartDescriptionTableWithJensenShannonDistanceColumn = JensenShannonDistanceColumnAdder.addJensenShannonDistanceColumn(barChartDescriptionTable)

        # Then
        assert_frame_equal(
            barChartDescriptionTableWithJensenShannonDistanceColumn,
            TestHelper.createDataFrame(
                columns = ['BAR_CHART_DESCRIPTION', 'JENSEN_SHANNON_DISTANCE'],
                data = [
                    [
                        {
                            'countries': ['Germany', 'Hungary'],
                            'Adverse Reaction Reports guessed': [10, 15],
                            'Adverse Reaction Reports known': [20, 30]
                        },
                        distance.jensenshannon([10, 15], [20, 30], base = 2.0)
                    ]
                ],
                index = pd.Index(
                    [
                        '!D0181',
                    ],
                    name = 'VAX_LOT')),
            check_dtype = True)

    def test_addJensenShannonDistanceColumn_emptyTable(self):
        # Given: an empty table with more than one column
        index = pd.Index([], name = 'VAX_LOT')
        barChartDescriptionTable = pd.DataFrame(
            {
                'BAR_CHART_DESCRIPTION': pd.Series([], dtype = object),
                'OTHER_COLUMN': pd.Series([], dtype = object)
            },
            index = index)

        # When
        barChartDescriptionTableWithJensenShannonDistanceColumn = JensenShannonDistanceColumnAdder.addJensenShannonDistanceColumn(barChartDescriptionTable)

        # Then
        assert_frame_equal(
            barChartDescriptionTableWithJensenShannonDistanceColumn,
            pd.DataFrame(
                {
                    'BAR_CHART_DESCRIPTION': pd.Series([], dtype = object),
                    'OTHER_COLUMN': pd.Series([], dtype = object),
                    'JENSEN_SHANNON_DISTANCE': pd.Series([], dtype = float)
                },
                index = index),
            check_dtype = True)