refactoring

This commit is contained in:
frankknoll
2023-08-23 21:28:40 +02:00
parent af19dc710f
commit 3e1c59a697
3 changed files with 27 additions and 23 deletions

View File

@@ -0,0 +1,19 @@
from scipy.spatial import distance
class JensenShannonDistance2BarChartDescriptionColumnAdder:
@staticmethod
def addJensenShannonDistance2BarChartDescriptionColumn(barChartDescriptionTable):
barChartDescriptionTable['BAR_CHART_DESCRIPTION'] = (
barChartDescriptionTable.apply(
lambda barChartDescription:
{
**barChartDescription['BAR_CHART_DESCRIPTION'],
'Jensen-Shannon distance': distance.jensenshannon(
barChartDescription['BAR_CHART_DESCRIPTION']['Adverse Reaction Reports guessed'],
barChartDescription['BAR_CHART_DESCRIPTION']['Adverse Reaction Reports known'],
base=2.0)
},
axis='columns'))
return barChartDescriptionTable

View File

@@ -2,12 +2,12 @@ import unittest
import pandas as pd import pandas as pd
from pandas.testing import assert_frame_equal from pandas.testing import assert_frame_equal
from TestHelper import TestHelper from TestHelper import TestHelper
from JensenShannonDistanceColumnAdder import JensenShannonDistanceColumnAdder from JensenShannonDistance2BarChartDescriptionColumnAdder import JensenShannonDistance2BarChartDescriptionColumnAdder
from scipy.spatial import distance from scipy.spatial import distance
class JensenShannonDistanceColumnAdderTest(unittest.TestCase): class JensenShannonDistance2BarChartDescriptionColumnAdderTest(unittest.TestCase):
def test_addJensenShannonDistanceColumn(self): def test_addJensenShannonDistance2BarChartDescriptionColumn(self):
# Given # Given
barChartDescriptionTable = TestHelper.createDataFrame( barChartDescriptionTable = TestHelper.createDataFrame(
columns = ['BAR_CHART_DESCRIPTION'], columns = ['BAR_CHART_DESCRIPTION'],
@@ -27,21 +27,21 @@ class JensenShannonDistanceColumnAdderTest(unittest.TestCase):
name = 'VAX_LOT')) name = 'VAX_LOT'))
# When # When
barChartDescriptionTableWithJensenShannonDistanceColumn = JensenShannonDistanceColumnAdder.addJensenShannonDistanceColumn(barChartDescriptionTable) barChartDescriptionTableWithJensenShannonDistanceColumn = JensenShannonDistance2BarChartDescriptionColumnAdder.addJensenShannonDistance2BarChartDescriptionColumn(barChartDescriptionTable)
# Then # Then
assert_frame_equal( assert_frame_equal(
barChartDescriptionTableWithJensenShannonDistanceColumn, barChartDescriptionTableWithJensenShannonDistanceColumn,
TestHelper.createDataFrame( TestHelper.createDataFrame(
columns = ['BAR_CHART_DESCRIPTION', 'JENSEN_SHANNON_DISTANCE'], columns = ['BAR_CHART_DESCRIPTION'],
data = [ data = [
[ [
{ {
'countries': ['Germany', 'Hungary'], 'countries': ['Germany', 'Hungary'],
'Adverse Reaction Reports guessed': [10, 15], 'Adverse Reaction Reports guessed': [10, 15],
'Adverse Reaction Reports known': [20, 30] 'Adverse Reaction Reports known': [20, 30],
}, 'Jensen-Shannon distance': distance.jensenshannon([10, 15], [20, 30], base = 2.0)
distance.jensenshannon([10, 15], [20, 30], base = 2.0) }
] ]
], ],
index = pd.Index( index = pd.Index(

View File

@@ -1,15 +0,0 @@
from scipy.spatial import distance
class JensenShannonDistanceColumnAdder:
@staticmethod
def addJensenShannonDistanceColumn(barChartDescriptionTable):
barChartDescriptionTable['JENSEN_SHANNON_DISTANCE'] = (
barChartDescriptionTable.apply(
lambda barChartDescription: distance.jensenshannon(
barChartDescription['BAR_CHART_DESCRIPTION']['Adverse Reaction Reports guessed'],
barChartDescription['BAR_CHART_DESCRIPTION']['Adverse Reaction Reports known'],
base=2.0),
axis = 'columns'))
return barChartDescriptionTable