adding SymptomCombinationsProviderTest

This commit is contained in:
frankknoll
2023-11-18 13:21:27 +01:00
parent b773e34f3d
commit 7081d9014b
3 changed files with 162 additions and 5 deletions

View File

@@ -665,6 +665,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n", "from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
"from SymptomsCausedByVaccines.MultiLineFitting.SymptomCombinationsProvider import SymptomCombinationsProvider\n",
"import numpy as np\n", "import numpy as np\n",
"from matplotlib import pyplot as plt\n", "from matplotlib import pyplot as plt\n",
"from skspatial.objects import Line\n" "from skspatial.objects import Line\n"
@@ -676,8 +677,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"symptomX = 'Immunosuppression' # HIV test' # 'Immunosuppression'\n", "symptomX = 'Abdominal abscess' # HIV test' # 'Immunosuppression'\n",
"symptomY = 'Pneumonia' # 'Infection' # 'Immunoglobulin therapy'" "symptomY = 'Abdominal discomfort' # 'Infection' # 'Immunoglobulin therapy'"
] ]
}, },
{ {
@@ -691,6 +692,27 @@
"df" "df"
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"symptomCombinations = SymptomCombinationsProvider.generateSymptomCombinations(\n",
" prrByLotAndSymptom[prrByLotAndSymptom.columns[:500]],\n",
" dataFramePredicate = lambda df: 30 <= len(df) <= 35)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for symptomCombination in symptomCombinations:\n",
" print(list(symptomCombination.columns))"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@@ -745,7 +767,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 3)" "draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 2)"
] ]
}, },
{ {
@@ -754,7 +776,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)" "clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.1)"
] ]
}, },
{ {
@@ -763,7 +785,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"draw(points, clusters, lines, symptomX, symptomY, minClusterSize = 3)" "draw(points, clusters, lines, symptomX, symptomY, minClusterSize = 5)"
] ]
}, },
{ {

View File

@@ -0,0 +1,22 @@
from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs
class SymptomCombinationsProvider:
@staticmethod
def generateSymptomCombinations(prrByLotAndSymptom, dataFramePredicate):
symptomPairs = SymptomCombinationsProvider._generatePairs(prrByLotAndSymptom.columns)
symptomCombinations = (SymptomCombinationsProvider._generateSymptomCombination(prrByLotAndSymptom, symptomX, symptomY) for (symptomY, symptomX) in symptomPairs)
return SymptomCombinationsProvider._filter(symptomCombinations, dataFramePredicate)
@staticmethod
def _generatePairs(symptoms):
return ((symptoms[i], symptoms[j]) for (i, j) in generatePairs(len(symptoms)))
@staticmethod
def _generateSymptomCombination(prrByLotAndSymptom, symptomX, symptomY):
df = prrByLotAndSymptom[[symptomX, symptomY]]
return df[(df[symptomX] != 0) & (df[symptomY] != 0)]
@staticmethod
def _filter(dataFrames, dataFramePredicate):
return (dataFrame for dataFrame in dataFrames if dataFramePredicate(dataFrame))

View File

@@ -0,0 +1,113 @@
import unittest
from pandas.testing import assert_frame_equal
from TestHelper import TestHelper
import pandas as pd
from SymptomsCausedByVaccines.MultiLineFitting.SymptomCombinationsProvider import SymptomCombinationsProvider
class SymptomCombinationsProviderTest(unittest.TestCase):
def test_generateSymptomCombinations(self):
# Given
prrByLotAndSymptom = TestHelper.createDataFrame(
columns = ['SymptomA', 'SymptomB', 'SymptomC', 'SymptomD'],
data = [ [0.6, 1.5, 1.2, 0.0]],
index = pd.Index(
name = 'VAX_LOT',
data = [
'LOT-1'
]))
# When
symptomCombinations = list(
SymptomCombinationsProvider.generateSymptomCombinations(
prrByLotAndSymptom,
dataFramePredicate = lambda df: len(df) >= 1))
# Then
self.assertEqual(len(symptomCombinations), 3)
assert_frame_equal(
symptomCombinations[0],
TestHelper.createDataFrame(
columns = ['SymptomA', 'SymptomB'],
data = [ [0.6, 1.5]],
index = pd.Index(
name = 'VAX_LOT',
data = [
'LOT-1'
])))
assert_frame_equal(
symptomCombinations[1],
TestHelper.createDataFrame(
columns = ['SymptomA', 'SymptomC'],
data = [ [0.6, 1.2]],
index = pd.Index(
name = 'VAX_LOT',
data = [
'LOT-1'
])))
assert_frame_equal(
symptomCombinations[2],
TestHelper.createDataFrame(
columns = ['SymptomB', 'SymptomC'],
data = [ [1.5, 1.2]],
index = pd.Index(
name = 'VAX_LOT',
data = [
'LOT-1'
])))
def test_generateSymptomCombinations_minSizeOfDataFrame_2(self):
# Given
prrByLotAndSymptom = TestHelper.createDataFrame(
columns = ['SymptomA', 'SymptomB'],
data = [ [0.6, 1.5],
[1.6, 2.5]],
index = pd.Index(
name = 'VAX_LOT',
data = [
'LOT-1',
'LOT-2'
]))
# When
symptomCombinations = list(
SymptomCombinationsProvider.generateSymptomCombinations(
prrByLotAndSymptom,
dataFramePredicate = lambda df: len(df) >= 2))
# Then
self.assertEqual(len(symptomCombinations), 1)
assert_frame_equal(
symptomCombinations[0],
TestHelper.createDataFrame(
columns = ['SymptomA', 'SymptomB'],
data = [ [0.6, 1.5],
[1.6, 2.5]],
index = pd.Index(
name = 'VAX_LOT',
data = [
'LOT-1',
'LOT-2'
])))
def test_generateSymptomCombinations_minSizeOfDataFrame_3(self):
# Given
prrByLotAndSymptom = TestHelper.createDataFrame(
columns = ['SymptomA', 'SymptomB'],
data = [ [0.6, 1.5],
[1.6, 2.5]],
index = pd.Index(
name = 'VAX_LOT',
data = [
'LOT-1',
'LOT-2'
]))
# When
symptomCombinations = list(
SymptomCombinationsProvider.generateSymptomCombinations(
prrByLotAndSymptom,
dataFramePredicate = lambda df: len(df) >= 3))
# Then
self.assertEqual(len(symptomCombinations), 0)