refining MultiLineFitterTest
This commit is contained in:
@@ -2,8 +2,15 @@ import numpy as np
|
|||||||
from skspatial.objects import Line
|
from skspatial.objects import Line
|
||||||
|
|
||||||
# implementation of "Robust Multiple Structures Estimation with J-linkage"
|
# implementation of "Robust Multiple Structures Estimation with J-linkage"
|
||||||
class ClustersFactory:
|
class MultiLineFitter:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fitLines(points, lines, consensusThreshold):
|
||||||
|
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||||
|
_, preferenceMatrix4Clusters = MultiLineFitter.createClusters(preferenceMatrix)
|
||||||
|
lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix4Clusters)
|
||||||
|
return [lines[lineIndex] for lineIndex in lineIndexes]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def createClusters(preferenceMatrix):
|
def createClusters(preferenceMatrix):
|
||||||
keepClustering = True
|
keepClustering = True
|
||||||
@@ -18,7 +25,7 @@ class ClustersFactory:
|
|||||||
preferenceSetA = preferenceMatrix[clusterIndexA]
|
preferenceSetA = preferenceMatrix[clusterIndexA]
|
||||||
for clusterIndexB in range(clusterIndexA):
|
for clusterIndexB in range(clusterIndexA):
|
||||||
preferenceSetB = preferenceMatrix[clusterIndexB]
|
preferenceSetB = preferenceMatrix[clusterIndexB]
|
||||||
distance = ClustersFactory._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
distance = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
||||||
if distance > maxDistance:
|
if distance > maxDistance:
|
||||||
keepClustering = True
|
keepClustering = True
|
||||||
maxDistance = distance
|
maxDistance = distance
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from skspatial.objects import Line
|
from skspatial.objects import Line
|
||||||
from SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory
|
from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter
|
||||||
|
|
||||||
|
|
||||||
class ClustersFactoryTest(unittest.TestCase):
|
class MultiLineFitterTest(unittest.TestCase):
|
||||||
|
|
||||||
def test_createPreferenceMatrix(self):
|
def test_createPreferenceMatrix(self):
|
||||||
# Given
|
# Given
|
||||||
@@ -13,7 +13,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
|||||||
consensusThreshold = 4.0
|
consensusThreshold = 4.0
|
||||||
|
|
||||||
# When
|
# When
|
||||||
preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
|
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
@@ -31,7 +31,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
|||||||
consensusThreshold = 0.001
|
consensusThreshold = 0.001
|
||||||
|
|
||||||
# When
|
# When
|
||||||
preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
|
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
@@ -55,7 +55,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# When
|
# When
|
||||||
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
|
clusters, _ = MultiLineFitter.createClusters(preferenceMatrix)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
@@ -77,7 +77,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# When
|
# When
|
||||||
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
|
clusters, _ = MultiLineFitter.createClusters(preferenceMatrix)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
@@ -97,7 +97,19 @@ class ClustersFactoryTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# When
|
# When
|
||||||
lineIndexes = ClustersFactory._getLineIndexes(preferenceMatrix)
|
lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(lineIndexes, [2, 1])
|
np.testing.assert_array_equal(lineIndexes, [2, 1])
|
||||||
|
|
||||||
|
def test_fitLines(self):
|
||||||
|
# Given
|
||||||
|
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
|
||||||
|
line1 = Line.from_points([0, 0], [1, 0])
|
||||||
|
line2 = Line.from_points([0, 0], [1, 1])
|
||||||
|
|
||||||
|
# When
|
||||||
|
fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2], consensusThreshold = 0.001)
|
||||||
|
|
||||||
|
# Then
|
||||||
|
np.testing.assert_array_equal(fittedLines, [line1, line2])
|
||||||
Reference in New Issue
Block a user