refactoring
This commit is contained in:
@@ -29,9 +29,8 @@ class MultiLineFitter:
|
|||||||
bestClusterIndexCombination = None
|
bestClusterIndexCombination = None
|
||||||
keepClustering = False
|
keepClustering = False
|
||||||
numClusters = preferenceMatrix.shape[0]
|
numClusters = preferenceMatrix.shape[0]
|
||||||
for clusterIndexA in range(numClusters):
|
for (clusterIndexA, clusterIndexB) in MultiLineFitter._getPairs(numClusters):
|
||||||
preferenceSetA = preferenceMatrix[clusterIndexA]
|
preferenceSetA = preferenceMatrix[clusterIndexA]
|
||||||
for clusterIndexB in range(clusterIndexA):
|
|
||||||
preferenceSetB = preferenceMatrix[clusterIndexB]
|
preferenceSetB = preferenceMatrix[clusterIndexB]
|
||||||
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
||||||
if similarity > maxSimilarity:
|
if similarity > maxSimilarity:
|
||||||
@@ -48,6 +47,12 @@ class MultiLineFitter:
|
|||||||
|
|
||||||
return clusters, preferenceMatrix
|
return clusters, preferenceMatrix
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _getPairs(n):
|
||||||
|
for i in range(n):
|
||||||
|
for j in range(i):
|
||||||
|
yield (i, j)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _intersectionOverUnion(setA, setB):
|
def _intersectionOverUnion(setA, setB):
|
||||||
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
||||||
|
|||||||
Reference in New Issue
Block a user