adding SymptomCombinationsProviderTest

This commit is contained in:
frankknoll
2023-11-18 13:21:27 +01:00
parent b773e34f3d
commit 7081d9014b
3 changed files with 162 additions and 5 deletions

View File

@@ -665,6 +665,7 @@
"outputs": [],
"source": [
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
"from SymptomsCausedByVaccines.MultiLineFitting.SymptomCombinationsProvider import SymptomCombinationsProvider\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"from skspatial.objects import Line\n"
@@ -676,8 +677,8 @@
"metadata": {},
"outputs": [],
"source": [
"symptomX = 'Immunosuppression' # HIV test' # 'Immunosuppression'\n",
"symptomY = 'Pneumonia' # 'Infection' # 'Immunoglobulin therapy'"
"symptomX = 'Abdominal abscess' # HIV test' # 'Immunosuppression'\n",
"symptomY = 'Abdominal discomfort' # 'Infection' # 'Immunoglobulin therapy'"
]
},
{
@@ -691,6 +692,27 @@
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"symptomCombinations = SymptomCombinationsProvider.generateSymptomCombinations(\n",
" prrByLotAndSymptom[prrByLotAndSymptom.columns[:500]],\n",
" dataFramePredicate = lambda df: 30 <= len(df) <= 35)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for symptomCombination in symptomCombinations:\n",
" print(list(symptomCombination.columns))"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -745,7 +767,7 @@
"metadata": {},
"outputs": [],
"source": [
"draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 3)"
"draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 2)"
]
},
{
@@ -754,7 +776,7 @@
"metadata": {},
"outputs": [],
"source": [
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)"
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.1)"
]
},
{
@@ -763,7 +785,7 @@
"metadata": {},
"outputs": [],
"source": [
"draw(points, clusters, lines, symptomX, symptomY, minClusterSize = 3)"
"draw(points, clusters, lines, symptomX, symptomY, minClusterSize = 5)"
]
},
{

View File

@@ -0,0 +1,22 @@
from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs
class SymptomCombinationsProvider:
    """Produces two-column data frames, one per pair of columns of a source frame."""

    @staticmethod
    def generateSymptomCombinations(prrByLotAndSymptom, dataFramePredicate):
        """Lazily produce the two-symptom frames accepted by ``dataFramePredicate``.

        :param prrByLotAndSymptom: data frame whose columns are paired up
            (presumably one column per symptom and one row per lot -- inferred
            from the name, confirm against the caller).
        :param dataFramePredicate: callable taking a two-column frame and
            returning a truthy value for frames that should be kept.
        :return: generator of the accepted two-column frames.
        """
        columnPairs = SymptomCombinationsProvider._generatePairs(prrByLotAndSymptom.columns)
        # Every pair is unpacked as (y, x); the x column ends up first in the frame.
        candidateFrames = (
            SymptomCombinationsProvider._generateSymptomCombination(prrByLotAndSymptom, xName, yName)
            for yName, xName in columnPairs)
        return SymptomCombinationsProvider._filter(candidateFrames, dataFramePredicate)

    @staticmethod
    def _generatePairs(symptoms):
        """Map the index pairs from ``generatePairs(len(symptoms))`` to pairs of entries of ``symptoms``."""
        indexPairs = generatePairs(len(symptoms))
        return ((symptoms[first], symptoms[second]) for first, second in indexPairs)

    @staticmethod
    def _generateSymptomCombination(prrByLotAndSymptom, symptomX, symptomY):
        """Select columns ``symptomX`` and ``symptomY`` and keep the rows where neither value equals 0."""
        pairFrame = prrByLotAndSymptom[[symptomX, symptomY]]
        xIsNonZero = pairFrame[symptomX] != 0
        yIsNonZero = pairFrame[symptomY] != 0
        return pairFrame[xIsNonZero & yIsNonZero]

    @staticmethod
    def _filter(dataFrames, dataFramePredicate):
        """Lazily keep the frames for which ``dataFramePredicate`` returns a truthy value."""
        return (candidate for candidate in dataFrames if dataFramePredicate(candidate))

View File

@@ -0,0 +1,113 @@
import unittest
from pandas.testing import assert_frame_equal
from TestHelper import TestHelper
import pandas as pd
from SymptomsCausedByVaccines.MultiLineFitting.SymptomCombinationsProvider import SymptomCombinationsProvider
class SymptomCombinationsProviderTest(unittest.TestCase):
    """Unit tests for SymptomCombinationsProvider.generateSymptomCombinations."""

    @staticmethod
    def _createFrameByLot(columns, data, lots):
        """Build a frame through TestHelper whose index is named 'VAX_LOT' and holds ``lots``."""
        return TestHelper.createDataFrame(
            columns = columns,
            data = data,
            index = pd.Index(
                name = 'VAX_LOT',
                data = lots))

    def test_generateSymptomCombinations(self):
        # Given
        lots = ['LOT-1']
        prrByLotAndSymptom = self._createFrameByLot(
            columns = ['SymptomA', 'SymptomB', 'SymptomC', 'SymptomD'],
            data = [[0.6, 1.5, 1.2, 0.0]],
            lots = lots)

        # When
        symptomCombinations = list(
            SymptomCombinationsProvider.generateSymptomCombinations(
                prrByLotAndSymptom,
                dataFramePredicate = lambda df: len(df) >= 1))

        # Then
        self.assertEqual(len(symptomCombinations), 3)
        # Expected (columns, data) per combination, in the order they must be produced.
        expectedCombinations = [
            (['SymptomA', 'SymptomB'], [[0.6, 1.5]]),
            (['SymptomA', 'SymptomC'], [[0.6, 1.2]]),
            (['SymptomB', 'SymptomC'], [[1.5, 1.2]]),
        ]
        for actual, (columns, data) in zip(symptomCombinations, expectedCombinations):
            assert_frame_equal(
                actual,
                self._createFrameByLot(columns = columns, data = data, lots = lots))

    def test_generateSymptomCombinations_minSizeOfDataFrame_2(self):
        # Given
        columns = ['SymptomA', 'SymptomB']
        lots = ['LOT-1', 'LOT-2']
        prrByLotAndSymptom = self._createFrameByLot(
            columns = columns,
            data = [[0.6, 1.5],
                    [1.6, 2.5]],
            lots = lots)

        # When
        symptomCombinations = list(
            SymptomCombinationsProvider.generateSymptomCombinations(
                prrByLotAndSymptom,
                dataFramePredicate = lambda df: len(df) >= 2))

        # Then
        self.assertEqual(len(symptomCombinations), 1)
        expected = self._createFrameByLot(
            columns = columns,
            data = [[0.6, 1.5],
                    [1.6, 2.5]],
            lots = lots)
        assert_frame_equal(symptomCombinations[0], expected)

    def test_generateSymptomCombinations_minSizeOfDataFrame_3(self):
        # Given
        prrByLotAndSymptom = self._createFrameByLot(
            columns = ['SymptomA', 'SymptomB'],
            data = [[0.6, 1.5],
                    [1.6, 2.5]],
            lots = ['LOT-1', 'LOT-2'])

        # When
        symptomCombinations = list(
            SymptomCombinationsProvider.generateSymptomCombinations(
                prrByLotAndSymptom,
                dataFramePredicate = lambda df: len(df) >= 3))

        # Then
        # Only two rows exist, so a predicate demanding three rejects the single pair.
        self.assertEqual(len(symptomCombinations), 0)