MultiLineFitterTest

This commit is contained in:
frankknoll
2023-11-17 18:16:09 +01:00
parent c8849301fe
commit 2d43c31f95
3 changed files with 251 additions and 9 deletions

View File

@@ -12,9 +12,10 @@ class MultiLineFitter:
@staticmethod
def fitLines(points, lines, consensusThreshold):
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
_, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix4Clusters)
return np.array(lines)[lineIndexes]
clusters, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
return (
MultiLineFitter._getClusterPoints(points, clusters),
MultiLineFitter._getLines(lines, preferenceMatrix4Clusters))
@staticmethod
def _createPreferenceMatrix(points, lines, consensusThreshold):
@@ -56,6 +57,15 @@ class MultiLineFitter:
union = np.count_nonzero(np.logical_or(setA, setB))
return 1. * intersection / union
@staticmethod
def _getLines(lines, preferenceMatrix):
return np.array(lines)[MultiLineFitter._getLineIndexes(preferenceMatrix)]
@staticmethod
def _getLineIndexes(preferenceMatrix):
return [list(lines).index(1) for lines in preferenceMatrix]
@staticmethod
def _getClusterPoints(points, clusters):
sortedClusters = [sorted(cluster) for cluster in clusters]
return [list(np.array(points)[cluster]) for cluster in sortedClusters]