diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactoryTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactoryTest.py index 2f5e323f85b..64562d5faea 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactoryTest.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactoryTest.py @@ -24,6 +24,28 @@ class ClustersFactoryTest(unittest.TestCase): [0] ])) + def test_createPreferenceMatrix2(self): + # Given + points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)] + lines = [Line.from_points([0, 0], [1, 0]), Line.from_points([0, 0], [1, 1])] + consensusThreshold = 0.001 + + # When + preferenceMatrix = ClustersFactory.createPreferenceMatrix(points, lines, consensusThreshold) + + # Then + np.testing.assert_array_equal( + preferenceMatrix, + np.array( + [ + [1, 0], + [1, 0], + [1, 0], + [0, 1], + [0, 1], + [0, 1] + ])) + def test_createClusters(self): # Given preferenceMatrix = np.array( @@ -42,3 +64,26 @@ class ClustersFactoryTest(unittest.TestCase): [ [1, 0] ])) + + def test_createClusters2(self): + # Given + preferenceMatrix = np.array( + [ + [1, 1], + [1, 0], + [1, 0], + [0, 1], + [0, 1] + ]) + + # When + _, clusters = ClustersFactory.createClusters(preferenceMatrix) + + # Then + np.testing.assert_array_equal( + clusters, + np.array( + [ + [2, 1, 0], + [4, 3] + ]))