diff --git a/src/HowBadIsMyBatch.ipynb b/src/HowBadIsMyBatch.ipynb index dcd5bb9f928..7e3cfb75e28 100644 --- a/src/HowBadIsMyBatch.ipynb +++ b/src/HowBadIsMyBatch.ipynb @@ -666,9 +666,7 @@ "source": [ "from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n", "from SymptomsCausedByVaccines.MultiLineFitting.SymptomCombinationsProvider import SymptomCombinationsProvider\n", - "import numpy as np\n", - "from matplotlib import pyplot as plt\n", - "from skspatial.objects import Line\n" + "from matplotlib import pyplot as plt\n" ] }, { @@ -677,8 +675,8 @@ "metadata": {}, "outputs": [], "source": [ - "symptomX = 'Abdominal abscess' # HIV test' # 'Immunosuppression'\n", - "symptomY = 'Abdominal discomfort' # 'Infection' # 'Immunoglobulin therapy'" + "# symptomX = 'Abdominal discomfort' # HIV test' # 'Immunosuppression'\n", + "# symptomY = 'Abdominal distension' # 'Infection' # 'Immunoglobulin therapy'" ] }, { @@ -687,9 +685,20 @@ "metadata": {}, "outputs": [], "source": [ - "df = prrByLotAndSymptom[[symptomX, symptomY]]\n", - "df = df[(df[symptomX] != 0) & (df[symptomY] != 0)]\n", - "df" + "# df = prrByLotAndSymptom[[symptomX, symptomY]]\n", + "# df = df[(df[symptomX] != 0) & (df[symptomY] != 0)]\n", + "# df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# retain only those columns of prrByLotAndSymptom that have more than 400 PRRs != 0\n", + "# prrByLotAndSymptom2 = prrByLotAndSymptom.loc[:, (prrByLotAndSymptom != 0).sum() >= 400]\n", + "# prrByLotAndSymptom2" ] }, { @@ -699,8 +708,8 @@ "outputs": [], "source": [ "symptomCombinations = SymptomCombinationsProvider.generateSymptomCombinations(\n", - " prrByLotAndSymptom[prrByLotAndSymptom.columns[:500]],\n", - " dataFramePredicate = lambda df: 30 <= len(df) <= 35)" + " prrByLotAndSymptom,\n", + " dataFramePredicate = lambda df: 40 <= len(df) <= 50)" ] }, { @@ -709,8 +718,19 @@ "metadata": {}, "outputs": [], "source": [ - "for symptomCombination in symptomCombinations:\n", - " print(list(symptomCombination.columns))" + "from SymptomsCausedByVaccines.MultiLineFitting.Utils import take\n", + "\n", + "df = take(symptomCombinations, 1)[0]\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "symptomX, symptomY = df.columns" ] }, { @@ -758,7 +778,10 @@ "metadata": {}, "outputs": [], "source": [ - "clustersAscending, linesAscending = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)" + "clustersAscending, linesAscending = MultiLineFitter.fitPointsByAscendingLines(\n", + " points,\n", + " consensusThreshold = 0.01,\n", + " maxNumLines = None)" ] }, { @@ -767,7 +790,7 @@ "metadata": {}, "outputs": [], "source": [ - "draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 2)" + "draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 5)" ] }, { @@ -776,7 +799,10 @@ "metadata": {}, "outputs": [], "source": [ - "clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.1)" + "clusters, lines = MultiLineFitter.fitPointsByLines(\n", + " points,\n", + " consensusThreshold = 0.01,\n", + " maxNumLines = None)" ] }, { diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py index 995bfb169f9..bac9a97902e 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py @@ -1,16 +1,22 @@ from skspatial.objects import Line -from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs +from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs, take class LinesFactory: @staticmethod - def createLines(points): - return LinesFactory._getUniqueLines(list(LinesFactory._generateAllLines(points))) + def createLines(points, maxNumLines = None): + return LinesFactory._getUniqueLines( + take( + LinesFactory._generateAllLines(points), + maxNumLines)) @staticmethod - def createAscendingLines(points): - return LinesFactory._getUniqueLines(list(LinesFactory._generateAllAscendingLines(points))) + def createAscendingLines(points, maxNumLines = None): + return LinesFactory._getUniqueLines( + take( + LinesFactory._generateAllAscendingLines(points), + maxNumLines)) @staticmethod def _generateAllAscendingLines(points): diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py index 3d30c4694ab..c8e1f6b6344 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py @@ -16,6 +16,7 @@ class LinesFactoryTest(unittest.TestCase): self.assertEqual(len(lines), 1) self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) + def test_createLines2(self): # Given points = [(0, 0), (1, 0), (0, 1)] @@ -29,6 +30,20 @@ class LinesFactoryTest(unittest.TestCase): self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1]))) self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1]))) + + def test_createLines_maxNumLines(self): + # Given + points = [(0, 0), (1, 0), (0, 1)] + + # When + lines = LinesFactory.createLines(points, maxNumLines = 2) + + # Then + self.assertEqual(len(lines), 2) + self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) + self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1]))) + + def test_createAscendingLines(self): # Given points = [(0, 0), (1, 0), (0, 1)] @@ -40,3 +55,14 @@ class LinesFactoryTest(unittest.TestCase): self.assertEqual(len(lines), 2) self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1]))) + + def test_createAscendingLines_maxNumLines(self): + # Given + points = [(0, 0), (1, 0), (0, 1)] + + # When + lines = LinesFactory.createAscendingLines(points, maxNumLines = 1) + + # Then + self.assertEqual(len(lines), 1) + self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py index a92a3b1f812..b28bfa1a1f4 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py @@ -7,12 +7,18 @@ from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import Ch class MultiLineFitter: @staticmethod - def fitPointsByLines(points, consensusThreshold): - return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold) + def fitPointsByLines(points, consensusThreshold, maxNumLines = None): + return MultiLineFitter.fitLines( + points, + LinesFactory.createLines(points, maxNumLines), + consensusThreshold) @staticmethod - def fitPointsByAscendingLines(points, consensusThreshold): - return MultiLineFitter.fitLines(points, LinesFactory.createAscendingLines(points), consensusThreshold) + def fitPointsByAscendingLines(points, consensusThreshold, maxNumLines = None): + return MultiLineFitter.fitLines( + points, + LinesFactory.createAscendingLines(points, maxNumLines), + consensusThreshold) @staticmethod def fitLines(points, lines, consensusThreshold): @@ -72,7 +78,7 @@ class MultiLineFitter: def _intersectionOverUnion(setA, setB): intersection = np.count_nonzero(np.logical_and(setA, setB)) union = np.count_nonzero(np.logical_or(setA, setB)) - return 1. * intersection / union + return 1. * intersection / union if intersection > 0.0 else 0 @staticmethod def _getLines(lines, preferenceMatrix): @@ -80,7 +86,15 @@ class MultiLineFitter: @staticmethod def _getLineIndexes(preferenceMatrix): - return [list(lines).index(1) for lines in preferenceMatrix] + lineIndexes = (MultiLineFitter._index(lines, 1) for lines in preferenceMatrix) + return [lineIndex for lineIndex in lineIndexes if lineIndex is not None] + + @staticmethod + def _index(xs, x): + try: + return list(xs).index(x) + except ValueError: + return None @staticmethod def _getClusterPoints(points, clusters): diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py b/src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py index 7afd114e81d..3d76dfb4742 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py @@ -1,4 +1,9 @@ +import itertools + def generatePairs(n): for i in range(n): for j in range(i): yield (i, j) + +def take(iterable, numElements): + return list(itertools.islice(iterable, numElements)) if numElements is not None else list(iterable)