diff --git a/src/HowBadIsMyBatch.ipynb b/src/HowBadIsMyBatch.ipynb index 29e3fc3637f..dcd5bb9f928 100644 --- a/src/HowBadIsMyBatch.ipynb +++ b/src/HowBadIsMyBatch.ipynb @@ -665,6 +665,7 @@ "outputs": [], "source": [ "from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n", + "from SymptomsCausedByVaccines.MultiLineFitting.SymptomCombinationsProvider import SymptomCombinationsProvider\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from skspatial.objects import Line\n" @@ -676,8 +677,8 @@ "metadata": {}, "outputs": [], "source": [ - "symptomX = 'Immunosuppression' # HIV test' # 'Immunosuppression'\n", - "symptomY = 'Pneumonia' # 'Infection' # 'Immunoglobulin therapy'" + "symptomX = 'Abdominal abscess' # HIV test' # 'Immunosuppression'\n", + "symptomY = 'Abdominal discomfort' # 'Infection' # 'Immunoglobulin therapy'" ] }, { @@ -691,6 +692,27 @@ "df" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "symptomCombinations = SymptomCombinationsProvider.generateSymptomCombinations(\n", + " prrByLotAndSymptom[prrByLotAndSymptom.columns[:500]],\n", + " dataFramePredicate = lambda df: 30 <= len(df) <= 35)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for symptomCombination in symptomCombinations:\n", + " print(list(symptomCombination.columns))" + ] + }, { "cell_type": "code", "execution_count": null, @@ -745,7 +767,7 @@ "metadata": {}, "outputs": [], "source": [ - "draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 3)" + "draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 2)" ] }, { @@ -754,7 +776,7 @@ "metadata": {}, "outputs": [], "source": [ - "clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)" + "clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.1)" ] }, { @@ -763,7 +785,7 @@ "metadata": {}, "outputs": [], "source": [ - "draw(points, clusters, lines, symptomX, symptomY, minClusterSize = 3)" + "draw(points, clusters, lines, symptomX, symptomY, minClusterSize = 5)" ] }, { diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/SymptomCombinationsProvider.py b/src/SymptomsCausedByVaccines/MultiLineFitting/SymptomCombinationsProvider.py new file mode 100644 index 00000000000..92c00b46b76 --- /dev/null +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/SymptomCombinationsProvider.py @@ -0,0 +1,22 @@ +from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs + +class SymptomCombinationsProvider: + + @staticmethod + def generateSymptomCombinations(prrByLotAndSymptom, dataFramePredicate): + symptomPairs = SymptomCombinationsProvider._generatePairs(prrByLotAndSymptom.columns) + symptomCombinations = (SymptomCombinationsProvider._generateSymptomCombination(prrByLotAndSymptom, symptomX, symptomY) for (symptomY, symptomX) in symptomPairs) + return SymptomCombinationsProvider._filter(symptomCombinations, dataFramePredicate) + + @staticmethod + def _generatePairs(symptoms): + return ((symptoms[i], symptoms[j]) for (i, j) in generatePairs(len(symptoms))) + + @staticmethod + def _generateSymptomCombination(prrByLotAndSymptom, symptomX, symptomY): + df = prrByLotAndSymptom[[symptomX, symptomY]] + return df[(df[symptomX] != 0) & (df[symptomY] != 0)] + + @staticmethod + def _filter(dataFrames, dataFramePredicate): + return (dataFrame for dataFrame in dataFrames if dataFramePredicate(dataFrame)) diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/SymptomCombinationsProviderTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/SymptomCombinationsProviderTest.py new file mode 100644 index 00000000000..fe291d0f40b --- /dev/null +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/SymptomCombinationsProviderTest.py @@ -0,0 +1,113 @@ +import unittest +from pandas.testing import assert_frame_equal +from TestHelper import TestHelper +import pandas as pd +from SymptomsCausedByVaccines.MultiLineFitting.SymptomCombinationsProvider import SymptomCombinationsProvider + +class SymptomCombinationsProviderTest(unittest.TestCase): + + def test_generateSymptomCombinations(self): + # Given + prrByLotAndSymptom = TestHelper.createDataFrame( + columns = ['SymptomA', 'SymptomB', 'SymptomC', 'SymptomD'], + data = [ [0.6, 1.5, 1.2, 0.0]], + index = pd.Index( + name = 'VAX_LOT', + data = [ + 'LOT-1' + ])) + + # When + symptomCombinations = list( + SymptomCombinationsProvider.generateSymptomCombinations( + prrByLotAndSymptom, + dataFramePredicate = lambda df: len(df) >= 1)) + + # Then + self.assertEqual(len(symptomCombinations), 3) + assert_frame_equal( + symptomCombinations[0], + TestHelper.createDataFrame( + columns = ['SymptomA', 'SymptomB'], + data = [ [0.6, 1.5]], + index = pd.Index( + name = 'VAX_LOT', + data = [ + 'LOT-1' + ]))) + assert_frame_equal( + symptomCombinations[1], + TestHelper.createDataFrame( + columns = ['SymptomA', 'SymptomC'], + data = [ [0.6, 1.2]], + index = pd.Index( + name = 'VAX_LOT', + data = [ + 'LOT-1' + ]))) + assert_frame_equal( + symptomCombinations[2], + TestHelper.createDataFrame( + columns = ['SymptomB', 'SymptomC'], + data = [ [1.5, 1.2]], + index = pd.Index( + name = 'VAX_LOT', + data = [ + 'LOT-1' + ]))) + + def test_generateSymptomCombinations_minSizeOfDataFrame_2(self): + # Given + prrByLotAndSymptom = TestHelper.createDataFrame( + columns = ['SymptomA', 'SymptomB'], + data = [ [0.6, 1.5], + [1.6, 2.5]], + index = pd.Index( + name = 'VAX_LOT', + data = [ + 'LOT-1', + 'LOT-2' + ])) + + # When + symptomCombinations = list( + SymptomCombinationsProvider.generateSymptomCombinations( + prrByLotAndSymptom, + dataFramePredicate = lambda df: len(df) >= 2)) + + # Then + self.assertEqual(len(symptomCombinations), 1) + assert_frame_equal( + symptomCombinations[0], + TestHelper.createDataFrame( + columns = ['SymptomA', 'SymptomB'], + data = [ [0.6, 1.5], + [1.6, 2.5]], + index = pd.Index( + name = 'VAX_LOT', + data = [ + 'LOT-1', + 'LOT-2' + ]))) + + def test_generateSymptomCombinations_minSizeOfDataFrame_3(self): + # Given + prrByLotAndSymptom = TestHelper.createDataFrame( + columns = ['SymptomA', 'SymptomB'], + data = [ [0.6, 1.5], + [1.6, 2.5]], + index = pd.Index( + name = 'VAX_LOT', + data = [ + 'LOT-1', + 'LOT-2' + ])) + + # When + symptomCombinations = list( + SymptomCombinationsProvider.generateSymptomCombinations( + prrByLotAndSymptom, + dataFramePredicate = lambda df: len(df) >= 3)) + + # Then + self.assertEqual(len(symptomCombinations), 0)