diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py index d2c149d7921..0e77c83cb34 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py @@ -29,15 +29,14 @@ class MultiLineFitter: bestClusterIndexCombination = None keepClustering = False numClusters = preferenceMatrix.shape[0] - for clusterIndexA in range(numClusters): + for (clusterIndexA, clusterIndexB) in MultiLineFitter._getPairs(numClusters): preferenceSetA = preferenceMatrix[clusterIndexA] - for clusterIndexB in range(clusterIndexA): - preferenceSetB = preferenceMatrix[clusterIndexB] - similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB); - if similarity > maxSimilarity: - keepClustering = True - maxSimilarity = similarity - bestClusterIndexCombination = (clusterIndexA, clusterIndexB) + preferenceSetB = preferenceMatrix[clusterIndexB] + similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB); + if similarity > maxSimilarity: + keepClustering = True + maxSimilarity = similarity + bestClusterIndexCombination = (clusterIndexA, clusterIndexB) if keepClustering: (clusterIndexA, clusterIndexB) = bestClusterIndexCombination @@ -48,6 +47,12 @@ class MultiLineFitter: return clusters, preferenceMatrix + @staticmethod + def _getPairs(n): + for i in range(n): + for j in range(i): + yield (i, j) + @staticmethod def _intersectionOverUnion(setA, setB): intersection = np.count_nonzero(np.logical_and(setA, setB))