refactoring

This commit is contained in:
frankknoll
2023-11-17 08:16:01 +01:00
parent 62ce64308c
commit 2b81552d7a

View File

@@ -29,9 +29,8 @@ class MultiLineFitter:
bestClusterIndexCombination = None
keepClustering = False
numClusters = preferenceMatrix.shape[0]
for clusterIndexA in range(numClusters):
for (clusterIndexA, clusterIndexB) in MultiLineFitter._getPairs(numClusters):
preferenceSetA = preferenceMatrix[clusterIndexA]
for clusterIndexB in range(clusterIndexA):
preferenceSetB = preferenceMatrix[clusterIndexB]
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
if similarity > maxSimilarity:
@@ -48,6 +47,12 @@ class MultiLineFitter:
return clusters, preferenceMatrix
@staticmethod
def _getPairs(n):
for i in range(n):
for j in range(i):
yield (i, j)
@staticmethod
def _intersectionOverUnion(setA, setB):
intersection = np.count_nonzero(np.logical_and(setA, setB))