refining MultiLineFitterTest

This commit is contained in:
frankknoll
2023-11-16 17:50:03 +01:00
parent 8231453ae2
commit fe7c2b1c88
2 changed files with 28 additions and 9 deletions

View File

@@ -2,8 +2,15 @@ import numpy as np
from skspatial.objects import Line
# implementation of "Robust Multiple Structures Estimation with J-linkage"
class ClustersFactory:
class MultiLineFitter:
@staticmethod
def fitLines(points, lines, consensusThreshold):
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
_, preferenceMatrix4Clusters = MultiLineFitter.createClusters(preferenceMatrix)
lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix4Clusters)
return [lines[lineIndex] for lineIndex in lineIndexes]
@staticmethod
def createClusters(preferenceMatrix):
keepClustering = True
@@ -18,7 +25,7 @@ class ClustersFactory:
preferenceSetA = preferenceMatrix[clusterIndexA]
for clusterIndexB in range(clusterIndexA):
preferenceSetB = preferenceMatrix[clusterIndexB]
distance = ClustersFactory._intersectionOverUnion(preferenceSetA, preferenceSetB);
distance = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
if distance > maxDistance:
keepClustering = True
maxDistance = distance

View File

@@ -1,10 +1,10 @@
import unittest
import numpy as np
from skspatial.objects import Line
from SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory
from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter
class ClustersFactoryTest(unittest.TestCase):
class MultiLineFitterTest(unittest.TestCase):
def test_createPreferenceMatrix(self):
# Given
@@ -13,7 +13,7 @@ class ClustersFactoryTest(unittest.TestCase):
consensusThreshold = 4.0
# When
preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
# Then
np.testing.assert_array_equal(
@@ -31,7 +31,7 @@ class ClustersFactoryTest(unittest.TestCase):
consensusThreshold = 0.001
# When
preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
# Then
np.testing.assert_array_equal(
@@ -55,7 +55,7 @@ class ClustersFactoryTest(unittest.TestCase):
])
# When
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
clusters, _ = MultiLineFitter.createClusters(preferenceMatrix)
# Then
np.testing.assert_array_equal(
@@ -77,7 +77,7 @@ class ClustersFactoryTest(unittest.TestCase):
])
# When
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
clusters, _ = MultiLineFitter.createClusters(preferenceMatrix)
# Then
np.testing.assert_array_equal(
@@ -97,7 +97,19 @@ class ClustersFactoryTest(unittest.TestCase):
])
# When
lineIndexes = ClustersFactory._getLineIndexes(preferenceMatrix)
lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix)
# Then
np.testing.assert_array_equal(lineIndexes, [2, 1])
def test_fitLines(self):
# Given
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
line1 = Line.from_points([0, 0], [1, 0])
line2 = Line.from_points([0, 0], [1, 1])
# When
fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2], consensusThreshold = 0.001)
# Then
np.testing.assert_array_equal(fittedLines, [line1, line2])