refactoring
This commit is contained in:
@@ -29,15 +29,14 @@ class MultiLineFitter:
|
|||||||
bestClusterIndexCombination = None
|
bestClusterIndexCombination = None
|
||||||
keepClustering = False
|
keepClustering = False
|
||||||
numClusters = preferenceMatrix.shape[0]
|
numClusters = preferenceMatrix.shape[0]
|
||||||
for clusterIndexA in range(numClusters):
|
for (clusterIndexA, clusterIndexB) in MultiLineFitter._getPairs(numClusters):
|
||||||
preferenceSetA = preferenceMatrix[clusterIndexA]
|
preferenceSetA = preferenceMatrix[clusterIndexA]
|
||||||
for clusterIndexB in range(clusterIndexA):
|
preferenceSetB = preferenceMatrix[clusterIndexB]
|
||||||
preferenceSetB = preferenceMatrix[clusterIndexB]
|
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
||||||
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
if similarity > maxSimilarity:
|
||||||
if similarity > maxSimilarity:
|
keepClustering = True
|
||||||
keepClustering = True
|
maxSimilarity = similarity
|
||||||
maxSimilarity = similarity
|
bestClusterIndexCombination = (clusterIndexA, clusterIndexB)
|
||||||
bestClusterIndexCombination = (clusterIndexA, clusterIndexB)
|
|
||||||
|
|
||||||
if keepClustering:
|
if keepClustering:
|
||||||
(clusterIndexA, clusterIndexB) = bestClusterIndexCombination
|
(clusterIndexA, clusterIndexB) = bestClusterIndexCombination
|
||||||
@@ -48,6 +47,12 @@ class MultiLineFitter:
|
|||||||
|
|
||||||
return clusters, preferenceMatrix
|
return clusters, preferenceMatrix
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _getPairs(n):
|
||||||
|
for i in range(n):
|
||||||
|
for j in range(i):
|
||||||
|
yield (i, j)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _intersectionOverUnion(setA, setB):
|
def _intersectionOverUnion(setA, setB):
|
||||||
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
||||||
|
|||||||
Reference in New Issue
Block a user