refactoring
This commit is contained in:
@@ -1,18 +1,26 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from skspatial.objects import Line
|
from skspatial.objects import Line
|
||||||
|
|
||||||
# implementation of "Robust Multiple Structures Estimation with J-linkage"
|
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
|
||||||
class MultiLineFitter:
|
class MultiLineFitter:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fitLines(points, lines, consensusThreshold):
|
def fitLines(points, lines, consensusThreshold):
|
||||||
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||||
_, preferenceMatrix4Clusters = MultiLineFitter.createClusters(preferenceMatrix)
|
_, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
|
||||||
lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix4Clusters)
|
lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix4Clusters)
|
||||||
return [lines[lineIndex] for lineIndex in lineIndexes]
|
return [lines[lineIndex] for lineIndex in lineIndexes]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def createClusters(preferenceMatrix):
|
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
||||||
|
preferenceMatrix = np.zeros([len(points), len(lines)], dtype = int)
|
||||||
|
for pointIndex, point in enumerate(points):
|
||||||
|
for lineIndex, line in enumerate(lines):
|
||||||
|
preferenceMatrix[pointIndex, lineIndex] = 1 if line.distance_point(point) <= consensusThreshold else 0
|
||||||
|
return preferenceMatrix
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _createClusters(preferenceMatrix):
|
||||||
keepClustering = True
|
keepClustering = True
|
||||||
numClusters = preferenceMatrix.shape[0]
|
numClusters = preferenceMatrix.shape[0]
|
||||||
clusters = [[i] for i in range(numClusters)]
|
clusters = [[i] for i in range(numClusters)]
|
||||||
@@ -40,14 +48,6 @@ class MultiLineFitter:
|
|||||||
|
|
||||||
return clusters, preferenceMatrix
|
return clusters, preferenceMatrix
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
|
||||||
preferenceMatrix = np.zeros([len(points), len(lines)], dtype = int)
|
|
||||||
for pointIndex, point in enumerate(points):
|
|
||||||
for lineIndex, line in enumerate(lines):
|
|
||||||
preferenceMatrix[pointIndex, lineIndex] = 1 if line.distance_point(point) <= consensusThreshold else 0
|
|
||||||
return preferenceMatrix
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _intersectionOverUnion(setA, setB):
|
def _intersectionOverUnion(setA, setB):
|
||||||
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class MultiLineFitterTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# When
|
# When
|
||||||
clusters, _ = MultiLineFitter.createClusters(preferenceMatrix)
|
clusters, _ = MultiLineFitter._createClusters(preferenceMatrix)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
@@ -77,7 +77,7 @@ class MultiLineFitterTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# When
|
# When
|
||||||
clusters, _ = MultiLineFitter.createClusters(preferenceMatrix)
|
clusters, _ = MultiLineFitter._createClusters(preferenceMatrix)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
@@ -107,9 +107,13 @@ class MultiLineFitterTest(unittest.TestCase):
|
|||||||
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
|
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
|
||||||
line1 = Line.from_points([0, 0], [1, 0])
|
line1 = Line.from_points([0, 0], [1, 0])
|
||||||
line2 = Line.from_points([0, 0], [1, 1])
|
line2 = Line.from_points([0, 0], [1, 1])
|
||||||
|
line3 = Line.from_points([0, 0], [0, 1])
|
||||||
|
|
||||||
# When
|
# When
|
||||||
fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2], consensusThreshold = 0.001)
|
fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2, line3], consensusThreshold = 0.001)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(fittedLines, [line1, line2])
|
np.testing.assert_array_equal(fittedLines, [line1, line2])
|
||||||
|
|
||||||
|
#FK-TODO: erzeuge LinesFactory.createLines(points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)])
|
||||||
|
# Diese Funktion soll alle Linien erzeugen, die jeweils zwei verschiedene Punkte aus points verbinden.
|
||||||
Reference in New Issue
Block a user