refactoring
This commit is contained in:
@@ -1,17 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from skspatial.objects import Line
|
from skspatial.objects import Line
|
||||||
|
|
||||||
|
# implementation of "Robust Multiple Structures Estimation with J-linkage"
|
||||||
class ClustersFactory:
|
class ClustersFactory:
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def createPreferenceMatrix(points, lines, consensusThreshold):
|
|
||||||
preferenceMatrix = np.zeros([len(points), len(lines)], dtype = int)
|
|
||||||
for pointIndex, point in enumerate(points):
|
|
||||||
for lineIndex, line in enumerate(lines):
|
|
||||||
preferenceMatrix[pointIndex, lineIndex] = 1 if line.distance_point(point) <= consensusThreshold else 0
|
|
||||||
|
|
||||||
return preferenceMatrix
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def createClusters(preferenceMatrix):
|
def createClusters(preferenceMatrix):
|
||||||
keepClustering = True
|
keepClustering = True
|
||||||
@@ -19,7 +11,7 @@ class ClustersFactory:
|
|||||||
clusters = [[i] for i in range(numClusters)]
|
clusters = [[i] for i in range(numClusters)]
|
||||||
while keepClustering:
|
while keepClustering:
|
||||||
maxDistance = 0
|
maxDistance = 0
|
||||||
bestClusterIndexCombo = None
|
bestClusterIndexCombination = None
|
||||||
keepClustering = False
|
keepClustering = False
|
||||||
numClusters = preferenceMatrix.shape[0]
|
numClusters = preferenceMatrix.shape[0]
|
||||||
for clusterIndexA in range(numClusters):
|
for clusterIndexA in range(numClusters):
|
||||||
@@ -30,10 +22,10 @@ class ClustersFactory:
|
|||||||
if distance > maxDistance:
|
if distance > maxDistance:
|
||||||
keepClustering = True
|
keepClustering = True
|
||||||
maxDistance = distance
|
maxDistance = distance
|
||||||
bestClusterIndexCombo = (clusterIndexA, clusterIndexB)
|
bestClusterIndexCombination = (clusterIndexA, clusterIndexB)
|
||||||
|
|
||||||
if keepClustering:
|
if keepClustering:
|
||||||
(clusterIndexA, clusterIndexB) = bestClusterIndexCombo
|
(clusterIndexA, clusterIndexB) = bestClusterIndexCombination
|
||||||
clusters[clusterIndexA] += clusters[clusterIndexB]
|
clusters[clusterIndexA] += clusters[clusterIndexB]
|
||||||
clusters.pop(clusterIndexB)
|
clusters.pop(clusterIndexB)
|
||||||
preferenceMatrix[clusterIndexA] = np.logical_and(preferenceMatrix[clusterIndexA], preferenceMatrix[clusterIndexB])
|
preferenceMatrix[clusterIndexA] = np.logical_and(preferenceMatrix[clusterIndexA], preferenceMatrix[clusterIndexB])
|
||||||
@@ -41,6 +33,14 @@ class ClustersFactory:
|
|||||||
|
|
||||||
return clusters
|
return clusters
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
||||||
|
preferenceMatrix = np.zeros([len(points), len(lines)], dtype = int)
|
||||||
|
for pointIndex, point in enumerate(points):
|
||||||
|
for lineIndex, line in enumerate(lines):
|
||||||
|
preferenceMatrix[pointIndex, lineIndex] = 1 if line.distance_point(point) <= consensusThreshold else 0
|
||||||
|
return preferenceMatrix
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _intersectionOverUnion(setA, setB):
|
def _intersectionOverUnion(setA, setB):
|
||||||
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
|||||||
consensusThreshold = 4.0
|
consensusThreshold = 4.0
|
||||||
|
|
||||||
# When
|
# When
|
||||||
preferenceMatrix = ClustersFactory.createPreferenceMatrix(points, lines, consensusThreshold)
|
preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
@@ -31,7 +31,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
|||||||
consensusThreshold = 0.001
|
consensusThreshold = 0.001
|
||||||
|
|
||||||
# When
|
# When
|
||||||
preferenceMatrix = ClustersFactory.createPreferenceMatrix(points, lines, consensusThreshold)
|
preferenceMatrix = ClustersFactory._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
|
|||||||
Reference in New Issue
Block a user