From a14c1cd2179607713c055070f43b5c2a7f05e525 Mon Sep 17 00:00:00 2001 From: frankknoll Date: Fri, 17 Nov 2023 10:28:17 +0100 Subject: [PATCH] refining MultiLineFitterTest --- .../MultiLineFitting/LinesFactory.py | 4 ++-- .../MultiLineFitting/MultiLineFitter.py | 14 +++++++------- .../MultiLineFitting/MultiLineFitterTest.py | 13 +++++++++++-- .../MultiLineFitting/Utils.py | 4 ++++ 4 files changed, 24 insertions(+), 11 deletions(-) create mode 100644 src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py index 0887ae6d034..47df7a77386 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py @@ -1,5 +1,5 @@ from skspatial.objects import Line -from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter +from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs class LinesFactory: @@ -11,7 +11,7 @@ class LinesFactory: @staticmethod def _getPairs(points): - return ((points[i], points[j]) for (i, j) in MultiLineFitter._getPairs(len(points))) + return ((points[i], points[j]) for (i, j) in getPairs(len(points))) @staticmethod def _getUniqueLines(lines): diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py index 3c5d0c5df06..b4e7e0cc22a 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py @@ -1,8 +1,14 @@ import numpy as np +from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory +from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs # implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage class MultiLineFitter: + @staticmethod + def fitPointsByLines(points, consensusThreshold): + return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold) + @staticmethod def fitLines(points, lines, consensusThreshold): preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold) @@ -28,7 +34,7 @@ class MultiLineFitter: bestClusterIndexCombination = None keepClustering = False numClusters = preferenceMatrix.shape[0] - for (clusterIndexA, clusterIndexB) in MultiLineFitter._getPairs(numClusters): + for (clusterIndexA, clusterIndexB) in getPairs(numClusters): preferenceSetA = preferenceMatrix[clusterIndexA] preferenceSetB = preferenceMatrix[clusterIndexB] similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB); @@ -46,12 +52,6 @@ class MultiLineFitter: return clusters, preferenceMatrix - @staticmethod - def _getPairs(n): - for i in range(n): - for j in range(i): - yield (i, j) - @staticmethod def _intersectionOverUnion(setA, setB): intersection = np.count_nonzero(np.logical_and(setA, setB)) diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py index 9d8704c304b..081a920a3b0 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py @@ -115,5 +115,14 @@ class MultiLineFitterTest(unittest.TestCase): # Then np.testing.assert_array_equal(fittedLines, [line1, line2]) -#FK-TODO: erzeuge LinesFactory.createLines(points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]) -# Diese Funktion soll alle Linien erzeugen, die jeweils zwei verschiedene Punkte aus points verbinden. \ No newline at end of file + def test_fitPointsByLines(self): + # Given + points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)] + + # When + lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001) + + # Then + self.assertEqual(len(lines), 2) + self.assertTrue(lines[0].is_close(Line.from_points([0, 0], [1, 0]))) + self.assertTrue(lines[1].is_close(Line.from_points([0, 0], [1, 1]))) diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py b/src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py new file mode 100644 index 00000000000..3444571f5f9 --- /dev/null +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py @@ -0,0 +1,4 @@ +def getPairs(n): + for i in range(n): + for j in range(i): + yield (i, j)