refactoring
This commit is contained in:
@@ -1,16 +1,8 @@
|
||||
import numpy as np
|
||||
from skspatial.objects import Line
|
||||
|
||||
# implementation of "Robust Multiple Structures Estimation with J-linkage"
|
||||
class ClustersFactory:
|
||||
|
||||
@staticmethod
|
||||
def createPreferenceMatrix(points, lines, consensusThreshold):
|
||||
preferenceMatrix = np.zeros([len(points), len(lines)], dtype = int)
|
||||
for pointIndex, point in enumerate(points):
|
||||
for lineIndex, line in enumerate(lines):
|
||||
preferenceMatrix[pointIndex, lineIndex] = 1 if line.distance_point(point) <= consensusThreshold else 0
|
||||
|
||||
return preferenceMatrix
|
||||
|
||||
@staticmethod
|
||||
def createClusters(preferenceMatrix):
|
||||
@@ -19,7 +11,7 @@ class ClustersFactory:
|
||||
clusters = [[i] for i in range(numClusters)]
|
||||
while keepClustering:
|
||||
maxDistance = 0
|
||||
bestClusterIndexCombo = None
|
||||
bestClusterIndexCombination = None
|
||||
keepClustering = False
|
||||
numClusters = preferenceMatrix.shape[0]
|
||||
for clusterIndexA in range(numClusters):
|
||||
@@ -30,10 +22,10 @@ class ClustersFactory:
|
||||
if distance > maxDistance:
|
||||
keepClustering = True
|
||||
maxDistance = distance
|
||||
bestClusterIndexCombo = (clusterIndexA, clusterIndexB)
|
||||
bestClusterIndexCombination = (clusterIndexA, clusterIndexB)
|
||||
|
||||
if keepClustering:
|
||||
(clusterIndexA, clusterIndexB) = bestClusterIndexCombo
|
||||
(clusterIndexA, clusterIndexB) = bestClusterIndexCombination
|
||||
clusters[clusterIndexA] += clusters[clusterIndexB]
|
||||
clusters.pop(clusterIndexB)
|
||||
preferenceMatrix[clusterIndexA] = np.logical_and(preferenceMatrix[clusterIndexA], preferenceMatrix[clusterIndexB])
|
||||
@@ -41,6 +33,14 @@ class ClustersFactory:
|
||||
|
||||
return clusters
|
||||
|
||||
@staticmethod
|
||||
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
||||
preferenceMatrix = np.zeros([len(points), len(lines)], dtype = int)
|
||||
for pointIndex, point in enumerate(points):
|
||||
for lineIndex, line in enumerate(lines):
|
||||
preferenceMatrix[pointIndex, lineIndex] = 1 if line.distance_point(point) <= consensusThreshold else 0
|
||||
return preferenceMatrix
|
||||
|
||||
@staticmethod
|
||||
def _intersectionOverUnion(setA, setB):
|
||||
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
||||
|
||||
@@ -13,7 +13,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
||||
consensusThreshold = 4.0
|
||||
|
||||
# When
|
||||
preferenceMatrix = ClustersFactory.createPreferenceMatrix(points, lines, consensusThreshold)
|
||||
preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||
|
||||
# Then
|
||||
np.testing.assert_array_equal(
|
||||
@@ -31,7 +31,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
||||
consensusThreshold = 0.001
|
||||
|
||||
# When
|
||||
preferenceMatrix = ClustersFactory.createPreferenceMatrix(points, lines, consensusThreshold)
|
||||
preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||
|
||||
# Then
|
||||
np.testing.assert_array_equal(
|
||||
|
||||
Reference in New Issue
Block a user