refactoring
This commit is contained in:
@@ -29,15 +29,14 @@ class MultiLineFitter:
|
||||
bestClusterIndexCombination = None
|
||||
keepClustering = False
|
||||
numClusters = preferenceMatrix.shape[0]
|
||||
for clusterIndexA in range(numClusters):
|
||||
for (clusterIndexA, clusterIndexB) in MultiLineFitter._getPairs(numClusters):
|
||||
preferenceSetA = preferenceMatrix[clusterIndexA]
|
||||
for clusterIndexB in range(clusterIndexA):
|
||||
preferenceSetB = preferenceMatrix[clusterIndexB]
|
||||
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
||||
if similarity > maxSimilarity:
|
||||
keepClustering = True
|
||||
maxSimilarity = similarity
|
||||
bestClusterIndexCombination = (clusterIndexA, clusterIndexB)
|
||||
preferenceSetB = preferenceMatrix[clusterIndexB]
|
||||
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
||||
if similarity > maxSimilarity:
|
||||
keepClustering = True
|
||||
maxSimilarity = similarity
|
||||
bestClusterIndexCombination = (clusterIndexA, clusterIndexB)
|
||||
|
||||
if keepClustering:
|
||||
(clusterIndexA, clusterIndexB) = bestClusterIndexCombination
|
||||
@@ -48,6 +47,12 @@ class MultiLineFitter:
|
||||
|
||||
return clusters, preferenceMatrix
|
||||
|
||||
@staticmethod
|
||||
def _getPairs(n):
|
||||
for i in range(n):
|
||||
for j in range(i):
|
||||
yield (i, j)
|
||||
|
||||
@staticmethod
|
||||
def _intersectionOverUnion(setA, setB):
|
||||
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
||||
|
||||
Reference in New Issue
Block a user