refining MultiLineFitterTest
This commit is contained in:
@@ -2,7 +2,14 @@ import numpy as np
|
||||
from skspatial.objects import Line
|
||||
|
||||
# implementation of "Robust Multiple Structures Estimation with J-linkage"
|
||||
class ClustersFactory:
|
||||
class MultiLineFitter:
|
||||
|
||||
@staticmethod
|
||||
def fitLines(points, lines, consensusThreshold):
|
||||
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||
_, preferenceMatrix4Clusters = MultiLineFitter.createClusters(preferenceMatrix)
|
||||
lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix4Clusters)
|
||||
return [lines[lineIndex] for lineIndex in lineIndexes]
|
||||
|
||||
@staticmethod
|
||||
def createClusters(preferenceMatrix):
|
||||
@@ -18,7 +25,7 @@ class ClustersFactory:
|
||||
preferenceSetA = preferenceMatrix[clusterIndexA]
|
||||
for clusterIndexB in range(clusterIndexA):
|
||||
preferenceSetB = preferenceMatrix[clusterIndexB]
|
||||
distance = ClustersFactory._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
||||
distance = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
||||
if distance > maxDistance:
|
||||
keepClustering = True
|
||||
maxDistance = distance
|
||||
@@ -1,10 +1,10 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from skspatial.objects import Line
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter
|
||||
|
||||
|
||||
class ClustersFactoryTest(unittest.TestCase):
|
||||
class MultiLineFitterTest(unittest.TestCase):
|
||||
|
||||
def test_createPreferenceMatrix(self):
|
||||
# Given
|
||||
@@ -13,7 +13,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
||||
consensusThreshold = 4.0
|
||||
|
||||
# When
|
||||
preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||
|
||||
# Then
|
||||
np.testing.assert_array_equal(
|
||||
@@ -31,7 +31,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
||||
consensusThreshold = 0.001
|
||||
|
||||
# When
|
||||
preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||
|
||||
# Then
|
||||
np.testing.assert_array_equal(
|
||||
@@ -55,7 +55,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
||||
])
|
||||
|
||||
# When
|
||||
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
|
||||
clusters, _ = MultiLineFitter.createClusters(preferenceMatrix)
|
||||
|
||||
# Then
|
||||
np.testing.assert_array_equal(
|
||||
@@ -77,7 +77,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
||||
])
|
||||
|
||||
# When
|
||||
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
|
||||
clusters, _ = MultiLineFitter.createClusters(preferenceMatrix)
|
||||
|
||||
# Then
|
||||
np.testing.assert_array_equal(
|
||||
@@ -97,7 +97,19 @@ class ClustersFactoryTest(unittest.TestCase):
|
||||
])
|
||||
|
||||
# When
|
||||
lineIndexes = ClustersFactory._getLineIndexes(preferenceMatrix)
|
||||
lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix)
|
||||
|
||||
# Then
|
||||
np.testing.assert_array_equal(lineIndexes, [2, 1])
|
||||
|
||||
def test_fitLines(self):
|
||||
# Given
|
||||
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
|
||||
line1 = Line.from_points([0, 0], [1, 0])
|
||||
line2 = Line.from_points([0, 0], [1, 1])
|
||||
|
||||
# When
|
||||
fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2], consensusThreshold = 0.001)
|
||||
|
||||
# Then
|
||||
np.testing.assert_array_equal(fittedLines, [line1, line2])
|
||||
Reference in New Issue
Block a user