refactoring

This commit is contained in:
frankknoll
2023-11-16 15:47:36 +01:00
parent 2caae0e198
commit c2f900504a
2 changed files with 14 additions and 14 deletions

View File

@@ -1,17 +1,9 @@
import numpy as np import numpy as np
from skspatial.objects import Line from skspatial.objects import Line
# implementation of "Robust Multiple Structures Estimation with J-linkage"
class ClustersFactory: class ClustersFactory:
@staticmethod
def createPreferenceMatrix(points, lines, consensusThreshold):
preferenceMatrix = np.zeros([len(points), len(lines)], dtype = int)
for pointIndex, point in enumerate(points):
for lineIndex, line in enumerate(lines):
preferenceMatrix[pointIndex, lineIndex] = 1 if line.distance_point(point) <= consensusThreshold else 0
return preferenceMatrix
@staticmethod @staticmethod
def createClusters(preferenceMatrix): def createClusters(preferenceMatrix):
keepClustering = True keepClustering = True
@@ -19,7 +11,7 @@ class ClustersFactory:
clusters = [[i] for i in range(numClusters)] clusters = [[i] for i in range(numClusters)]
while keepClustering: while keepClustering:
maxDistance = 0 maxDistance = 0
bestClusterIndexCombo = None bestClusterIndexCombination = None
keepClustering = False keepClustering = False
numClusters = preferenceMatrix.shape[0] numClusters = preferenceMatrix.shape[0]
for clusterIndexA in range(numClusters): for clusterIndexA in range(numClusters):
@@ -30,10 +22,10 @@ class ClustersFactory:
if distance > maxDistance: if distance > maxDistance:
keepClustering = True keepClustering = True
maxDistance = distance maxDistance = distance
bestClusterIndexCombo = (clusterIndexA, clusterIndexB) bestClusterIndexCombination = (clusterIndexA, clusterIndexB)
if keepClustering: if keepClustering:
(clusterIndexA, clusterIndexB) = bestClusterIndexCombo (clusterIndexA, clusterIndexB) = bestClusterIndexCombination
clusters[clusterIndexA] += clusters[clusterIndexB] clusters[clusterIndexA] += clusters[clusterIndexB]
clusters.pop(clusterIndexB) clusters.pop(clusterIndexB)
preferenceMatrix[clusterIndexA] = np.logical_and(preferenceMatrix[clusterIndexA], preferenceMatrix[clusterIndexB]) preferenceMatrix[clusterIndexA] = np.logical_and(preferenceMatrix[clusterIndexA], preferenceMatrix[clusterIndexB])
@@ -41,6 +33,14 @@ class ClustersFactory:
return clusters return clusters
@staticmethod
def _createPreferenceMatrix(points, lines, consensusThreshold):
preferenceMatrix = np.zeros([len(points), len(lines)], dtype = int)
for pointIndex, point in enumerate(points):
for lineIndex, line in enumerate(lines):
preferenceMatrix[pointIndex, lineIndex] = 1 if line.distance_point(point) <= consensusThreshold else 0
return preferenceMatrix
@staticmethod @staticmethod
def _intersectionOverUnion(setA, setB): def _intersectionOverUnion(setA, setB):
intersection = np.count_nonzero(np.logical_and(setA, setB)) intersection = np.count_nonzero(np.logical_and(setA, setB))

View File

@@ -13,7 +13,7 @@ class ClustersFactoryTest(unittest.TestCase):
consensusThreshold = 4.0 consensusThreshold = 4.0
# When # When
preferenceMatrix = ClustersFactory.createPreferenceMatrix(points, lines, consensusThreshold) preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
# Then # Then
np.testing.assert_array_equal( np.testing.assert_array_equal(
@@ -31,7 +31,7 @@ class ClustersFactoryTest(unittest.TestCase):
consensusThreshold = 0.001 consensusThreshold = 0.001
# When # When
preferenceMatrix = ClustersFactory.createPreferenceMatrix(points, lines, consensusThreshold) preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
# Then # Then
np.testing.assert_array_equal( np.testing.assert_array_equal(