refactoring

This commit is contained in:
frankknoll
2023-11-17 08:16:01 +01:00
parent 62ce64308c
commit 2b81552d7a

View File

@@ -29,15 +29,14 @@ class MultiLineFitter:
bestClusterIndexCombination = None bestClusterIndexCombination = None
keepClustering = False keepClustering = False
numClusters = preferenceMatrix.shape[0] numClusters = preferenceMatrix.shape[0]
for clusterIndexA in range(numClusters): for (clusterIndexA, clusterIndexB) in MultiLineFitter._getPairs(numClusters):
preferenceSetA = preferenceMatrix[clusterIndexA] preferenceSetA = preferenceMatrix[clusterIndexA]
for clusterIndexB in range(clusterIndexA): preferenceSetB = preferenceMatrix[clusterIndexB]
preferenceSetB = preferenceMatrix[clusterIndexB] similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB); if similarity > maxSimilarity:
if similarity > maxSimilarity: keepClustering = True
keepClustering = True maxSimilarity = similarity
maxSimilarity = similarity bestClusterIndexCombination = (clusterIndexA, clusterIndexB)
bestClusterIndexCombination = (clusterIndexA, clusterIndexB)
if keepClustering: if keepClustering:
(clusterIndexA, clusterIndexB) = bestClusterIndexCombination (clusterIndexA, clusterIndexB) = bestClusterIndexCombination
@@ -48,6 +47,12 @@ class MultiLineFitter:
return clusters, preferenceMatrix return clusters, preferenceMatrix
@staticmethod
def _getPairs(n):
for i in range(n):
for j in range(i):
yield (i, j)
@staticmethod @staticmethod
def _intersectionOverUnion(setA, setB): def _intersectionOverUnion(setA, setB):
intersection = np.count_nonzero(np.logical_and(setA, setB)) intersection = np.count_nonzero(np.logical_and(setA, setB))