refactoring

This commit is contained in:
frankknoll
2023-11-16 15:27:16 +01:00
parent 457b2c6dd7
commit 44e383734c
2 changed files with 28 additions and 36 deletions

View File

@@ -14,42 +14,34 @@ class ClustersFactory:
@staticmethod @staticmethod
def createClusters(preferenceMatrix): def createClusters(preferenceMatrix):
keep_clustering = True keepClustering = True
cluster_step = 0 numClusters = preferenceMatrix.shape[0]
clusters = [[i] for i in range(numClusters)]
num_clusters = preferenceMatrix.shape[0] while keepClustering:
clusters = [[i] for i in range(num_clusters)] maxDistance = 0
bestCombo = None
while keep_clustering: keepClustering = False
smallest_distance = 0 numClusters = preferenceMatrix.shape[0]
best_combo = None for i in range(numClusters):
keep_clustering = False set_a = preferenceMatrix[i]
num_clusters = preferenceMatrix.shape[0]
for i in range(num_clusters):
for j in range(i): for j in range(i):
set_a = preferenceMatrix[i]
set_b = preferenceMatrix[j] set_b = preferenceMatrix[j]
intersection = np.count_nonzero(np.logical_and(set_a, set_b)) distance = ClustersFactory._intersectionOverUnion(set_a, set_b);
union = np.count_nonzero(np.logical_or(set_a, set_b)) if distance > maxDistance:
distance = 1.*intersection/np.maximum(union, 1e-8) keepClustering = True
maxDistance = distance
bestCombo = (i, j)
if distance > smallest_distance: if keepClustering:
keep_clustering = True clusters[bestCombo[0]] += clusters[bestCombo[1]]
smallest_distance = distance clusters.pop(bestCombo[1])
best_combo = (i,j) preferenceMatrix[bestCombo[0]] = np.logical_and(preferenceMatrix[bestCombo[0]], preferenceMatrix[bestCombo[1]])
preferenceMatrix = np.delete(preferenceMatrix, bestCombo[1], axis = 0)
if keep_clustering: return clusters
clusters[best_combo[0]] += clusters[best_combo[1]]
clusters.pop(best_combo[1])
set_a = preferenceMatrix[best_combo[0]]
set_b = preferenceMatrix[best_combo[1]]
merged_set = np.logical_and(set_a, set_b)
preferenceMatrix[best_combo[0]] = merged_set
preferenceMatrix = np.delete(preferenceMatrix, best_combo[1], axis=0)
cluster_step += 1
print("clustering finished after %d steps" % cluster_step) @staticmethod
def _intersectionOverUnion(set_a, set_b):
return preferenceMatrix, clusters intersection = np.count_nonzero(np.logical_and(set_a, set_b))
union = np.count_nonzero(np.logical_or(set_a, set_b))
return 1. * intersection / union

View File

@@ -55,7 +55,7 @@ class ClustersFactoryTest(unittest.TestCase):
]) ])
# When # When
_, clusters = ClustersFactory.createClusters(preferenceMatrix) clusters = ClustersFactory.createClusters(preferenceMatrix)
# Then # Then
np.testing.assert_array_equal( np.testing.assert_array_equal(
@@ -77,7 +77,7 @@ class ClustersFactoryTest(unittest.TestCase):
]) ])
# When # When
_, clusters = ClustersFactory.createClusters(preferenceMatrix) clusters = ClustersFactory.createClusters(preferenceMatrix)
# Then # Then
np.testing.assert_array_equal( np.testing.assert_array_equal(