class ClustersFactory:
    """Build point/line preference matrices and merge points into clusters."""

    @staticmethod
    def createPreferenceMatrix(points, lines, consensusThreshold):
        """Return a 0/1 matrix telling which points lie close to which lines.

        Args:
            points: Sequence of points, each accepted by ``line.distance_point``.
            lines: Sequence of objects exposing ``distance_point(point)``.
            consensusThreshold: Largest distance (inclusive) at which a point
                still counts as belonging to a line.

        Returns:
            Integer array of shape ``(len(points), len(lines))`` whose entry
            ``[p, l]`` is 1 when point ``p`` is within the threshold of line
            ``l`` and 0 otherwise.
        """
        preferenceMatrix = np.zeros([len(points), len(lines)], dtype=int)
        for pointIndex, point in enumerate(points):
            for lineIndex, line in enumerate(lines):
                if line.distance_point(point) <= consensusThreshold:
                    preferenceMatrix[pointIndex, lineIndex] = 1

        return preferenceMatrix

    @staticmethod
    def createClusters(preferenceMatrix):
        """Greedily merge the rows with the most similar preference sets.

        Every row starts as its own cluster. On each step the pair of rows
        with the highest intersection-over-union (Jaccard similarity) of
        their non-zero entries is merged: the member lists are concatenated
        and the surviving row becomes the element-wise AND of both rows.
        Merging stops once no pair has a similarity greater than zero.

        Args:
            preferenceMatrix: 2-D array-like with one row per item. It is
                copied first, so the caller's data is never modified.

        Returns:
            Tuple ``(mergedMatrix, clusters)``: ``mergedMatrix`` holds one
            row per remaining cluster and ``clusters`` is a list of lists of
            original row indices, in matching order.
        """
        # Work on a private copy: merging writes rows in place, which would
        # otherwise leak into the caller's array on the first merge.
        preferenceMatrix = np.array(preferenceMatrix)
        clusters = [[i] for i in range(preferenceMatrix.shape[0])]
        cluster_step = 0

        while True:
            best_similarity = 0
            best_combo = None

            for i in range(preferenceMatrix.shape[0]):
                for j in range(i):
                    set_a = preferenceMatrix[i]
                    set_b = preferenceMatrix[j]
                    intersection = np.count_nonzero(np.logical_and(set_a, set_b))
                    union = np.count_nonzero(np.logical_or(set_a, set_b))
                    # The floor on the denominator avoids 0/0 for two empty rows.
                    similarity = intersection / np.maximum(union, 1e-8)

                    # Strict comparison: ties keep the first pair found.
                    if similarity > best_similarity:
                        best_similarity = similarity
                        best_combo = (i, j)

            if best_combo is None:
                break

            # keep > drop always holds (j < i), so popping `drop` afterwards
            # cannot invalidate the already-updated `keep` entry.
            keep, drop = best_combo
            clusters[keep] += clusters[drop]
            clusters.pop(drop)
            preferenceMatrix[keep] = np.logical_and(
                preferenceMatrix[keep], preferenceMatrix[drop]
            )
            preferenceMatrix = np.delete(preferenceMatrix, drop, axis=0)
            cluster_step += 1

        print("clustering finished after %d steps" % cluster_step)

        return preferenceMatrix, clusters
class ClustersFactoryTest(unittest.TestCase):
    """Unit tests for ``ClustersFactory``."""

    def test_createPreferenceMatrix(self):
        """The first point is within the threshold of the line, the second is not."""
        # Given
        xAxis = Line.from_points([0, 0], [100, 0])
        samplePoints = [(1, 3), (10, 20)]

        # When
        actual = ClustersFactory.createPreferenceMatrix(samplePoints, [xAxis], 4.0)

        # Then
        expected = np.array([[1], [0]])
        np.testing.assert_array_equal(actual, expected)

    def test_createClusters(self):
        """Two rows preferring the same single line end up in one cluster."""
        # Given
        bothPreferSameLine = np.array([[1], [1]])

        # When
        _, actualClusters = ClustersFactory.createClusters(bothPreferSameLine)

        # Then
        expectedClusters = np.array([[1, 0]])
        np.testing.assert_array_equal(actualClusters, expectedClusters)


class PreferenceMatrixFactory:
    """Builds the point/line preference matrix (removed by this diff)."""

    @staticmethod
    def createPreferenceMatrix(points, lines, consensusThreshold):
        """Return a 0/1 integer matrix of shape ``(len(points), len(lines))``.

        Entry ``[p, l]`` is 1 when ``lines[l].distance_point(points[p])`` is
        at most ``consensusThreshold`` and 0 otherwise.
        """
        matrix = np.zeros([len(points), len(lines)], dtype=int)
        for row, candidate in enumerate(points):
            for col, model in enumerate(lines):
                isInlier = model.distance_point(candidate) <= consensusThreshold
                matrix[row, col] = 1 if isInlier else 0

        return matrix
class PreferenceMatrixFactoryTest(unittest.TestCase):
    """Unit tests for ``PreferenceMatrixFactory`` (removed by this diff)."""

    def test_createPreferenceMatrix(self):
        """The first point is within the threshold of the line, the second is not."""
        # Given
        xAxis = Line.from_points([0, 0], [100, 0])
        samplePoints = [(1, 3), (10, 20)]

        # When
        actual = PreferenceMatrixFactory.createPreferenceMatrix(samplePoints, [xAxis], 4.0)

        # Then
        expected = np.array([[1], [0]])
        assert_array_equal(actual, expected)