From d40116ba6f70601bb88a7b8f0e4c1f9458e13f3a Mon Sep 17 00:00:00 2001 From: frankknoll Date: Sat, 18 Nov 2023 11:05:35 +0100 Subject: [PATCH] using only ascending lines for fitting --- src/HowBadIsMyBatch.ipynb | 17 ++++++++++--- .../MultiLineFitting/LinesFactory.py | 25 +++++++++++++++---- .../MultiLineFitting/LinesFactoryTest.py | 12 +++++++++ .../MultiLineFitting/MultiLineFitter.py | 8 ++++-- .../MultiLineFitting/Utils.py | 2 +- 5 files changed, 52 insertions(+), 12 deletions(-) diff --git a/src/HowBadIsMyBatch.ipynb b/src/HowBadIsMyBatch.ipynb index 62e22a52129..c8e0233cad9 100644 --- a/src/HowBadIsMyBatch.ipynb +++ b/src/HowBadIsMyBatch.ipynb @@ -651,6 +651,13 @@ " htmlFile = os.path.normpath(webAppBaseDir + '/index.html'))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Multi Line Fitting" + ] + }, { "cell_type": "code", "execution_count": null, @@ -669,8 +676,8 @@ "metadata": {}, "outputs": [], "source": [ - "symptomX = 'HIV test' # 'Immunosuppression'\n", - "symptomY = 'Immunoglobulin therapy' # 'Infection' # 'Immunoglobulin therapy'" + "symptomX = 'Immunosuppression' # HIV test' # 'Immunosuppression'\n", + "symptomY = 'Pneumonia' # 'Infection' # 'Immunoglobulin therapy'" ] }, { @@ -702,7 +709,7 @@ "source": [ "from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n", "\n", - "clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)" + "clusters, lines = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)" ] }, { @@ -738,8 +745,10 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py index 47df7a77386..995bfb169f9 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py @@ -1,17 +1,32 @@ from skspatial.objects import Line -from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs +from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs class LinesFactory: @staticmethod def createLines(points): - lines = [Line.from_points(pointA, pointB) for (pointA, pointB) in LinesFactory._getPairs(points)] - return LinesFactory._getUniqueLines(lines) + return LinesFactory._getUniqueLines(list(LinesFactory._generateAllLines(points))) @staticmethod - def _getPairs(points): - return ((points[i], points[j]) for (i, j) in getPairs(len(points))) + def createAscendingLines(points): + return LinesFactory._getUniqueLines(list(LinesFactory._generateAllAscendingLines(points))) + + @staticmethod + def _generateAllAscendingLines(points): + return (line for line in LinesFactory._generateAllLines(points) if LinesFactory._isAscending(line.direction)) + + @staticmethod + def _generateAllLines(points): + return (Line.from_points(pointA, pointB) for (pointA, pointB) in LinesFactory._generatePairs(points)) + + @staticmethod + def _isAscending(direction): + return (direction[0] >= 0 and direction[1] >= 0) or (direction[0] <= 0 and direction[1] <= 0) + + @staticmethod + def _generatePairs(points): + return ((points[i], points[j]) for (i, j) in generatePairs(len(points))) @staticmethod def _getUniqueLines(lines): diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py index d53ec86e419..3d30c4694ab 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py @@ -28,3 +28,15 @@ class LinesFactoryTest(unittest.TestCase): self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1]))) self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1]))) + + def test_createAscendingLines(self): + # Given + points = [(0, 0), (1, 0), (0, 1)] + + # When + lines = LinesFactory.createAscendingLines(points) + + # Then + self.assertEqual(len(lines), 2) + self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) + self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1]))) diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py index 14bf0fba376..a92a3b1f812 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py @@ -1,6 +1,6 @@ import numpy as np from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory -from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs +from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import CharacteristicFunctions # implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage @@ -10,6 +10,10 @@ class MultiLineFitter: def fitPointsByLines(points, consensusThreshold): return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold) + @staticmethod + def fitPointsByAscendingLines(points, consensusThreshold): + return MultiLineFitter.fitLines(points, LinesFactory.createAscendingLines(points), consensusThreshold) + @staticmethod def fitLines(points, lines, consensusThreshold): preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold) @@ -48,7 +52,7 @@ class MultiLineFitter: bestClusterIndexCombination = None keepClustering = False numClusters = preferenceMatrix.shape[0] - for (clusterIndexA, clusterIndexB) in getPairs(numClusters): + for (clusterIndexA, clusterIndexB) in generatePairs(numClusters): preferenceSetA = preferenceMatrix[clusterIndexA] preferenceSetB = preferenceMatrix[clusterIndexB] similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB); diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py b/src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py index 3444571f5f9..7afd114e81d 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py @@ -1,4 +1,4 @@ -def getPairs(n): +def generatePairs(n): for i in range(n): for j in range(i): yield (i, j)