diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactory.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py similarity index 82% rename from src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactory.py rename to src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py index ddb633a20ef..78a9877b6fc 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactory.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py @@ -2,8 +2,15 @@ import numpy as np from skspatial.objects import Line # implementation of "Robust Multiple Structures Estimation with J-linkage" -class ClustersFactory: +class MultiLineFitter: + @staticmethod + def fitLines(points, lines, consensusThreshold): + preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold) + _, preferenceMatrix4Clusters = MultiLineFitter.createClusters(preferenceMatrix) + lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix4Clusters) + return [lines[lineIndex] for lineIndex in lineIndexes] + @staticmethod def createClusters(preferenceMatrix): keepClustering = True @@ -18,7 +25,7 @@ class ClustersFactory: preferenceSetA = preferenceMatrix[clusterIndexA] for clusterIndexB in range(clusterIndexA): preferenceSetB = preferenceMatrix[clusterIndexB] - distance = ClustersFactory._intersectionOverUnion(preferenceSetA, preferenceSetB); + distance = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB); if distance > maxDistance: keepClustering = True maxDistance = distance diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactoryTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py similarity index 72% rename from src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactoryTest.py rename to src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py index d6780902c78..c8704a7065b 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactoryTest.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py @@ -1,10 +1,10 @@ import unittest import numpy as np from skspatial.objects import Line -from SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory +from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter -class ClustersFactoryTest(unittest.TestCase): +class MultiLineFitterTest(unittest.TestCase): def test_createPreferenceMatrix(self): # Given @@ -13,7 +13,7 @@ class ClustersFactoryTest(unittest.TestCase): consensusThreshold = 4.0 # When - preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold) + preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold) # Then np.testing.assert_array_equal( @@ -31,7 +31,7 @@ class ClustersFactoryTest(unittest.TestCase): consensusThreshold = 0.001 # When - preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold) + preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold) # Then np.testing.assert_array_equal( @@ -55,7 +55,7 @@ class ClustersFactoryTest(unittest.TestCase): ]) # When - clusters, _ = ClustersFactory.createClusters(preferenceMatrix) + clusters, _ = MultiLineFitter.createClusters(preferenceMatrix) # Then np.testing.assert_array_equal( @@ -77,7 +77,7 @@ class ClustersFactoryTest(unittest.TestCase): ]) # When - clusters, _ = ClustersFactory.createClusters(preferenceMatrix) + clusters, _ = MultiLineFitter.createClusters(preferenceMatrix) # Then np.testing.assert_array_equal( @@ -97,7 +97,19 @@ class ClustersFactoryTest(unittest.TestCase): ]) # When - lineIndexes = ClustersFactory._getLineIndexes(preferenceMatrix) + lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix) # Then np.testing.assert_array_equal(lineIndexes, [2, 1]) + + def test_fitLines(self): + # Given + points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)] + line1 = Line.from_points([0, 0], [1, 0]) + line2 = Line.from_points([0, 0], [1, 1]) + + # When + fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2], consensusThreshold = 0.001) + + # Then + np.testing.assert_array_equal(fittedLines, [line1, line2])