diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactory.py b/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactory.py index d524ba443ec..ddb633a20ef 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactory.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactory.py @@ -31,7 +31,7 @@ class ClustersFactory: preferenceMatrix[clusterIndexA] = np.logical_and(preferenceMatrix[clusterIndexA], preferenceMatrix[clusterIndexB]) preferenceMatrix = np.delete(preferenceMatrix, clusterIndexB, axis = 0) - return clusters + return clusters, preferenceMatrix @staticmethod def _createPreferenceMatrix(points, lines, consensusThreshold): @@ -46,3 +46,7 @@ class ClustersFactory: intersection = np.count_nonzero(np.logical_and(setA, setB)) union = np.count_nonzero(np.logical_or(setA, setB)) return 1. * intersection / union + + @staticmethod + def _getLineIndexes(preferenceMatrix): + return [list(lines).index(1) for lines in preferenceMatrix] diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactoryTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactoryTest.py index 51d2cc7cf8d..d6780902c78 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactoryTest.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/ClustersFactoryTest.py @@ -1,7 +1,7 @@ import unittest import numpy as np from skspatial.objects import Line -from src.SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory +from SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory class ClustersFactoryTest(unittest.TestCase): @@ -55,7 +55,7 @@ class ClustersFactoryTest(unittest.TestCase): ]) # When - clusters = ClustersFactory.createClusters(preferenceMatrix) + clusters, _ = ClustersFactory.createClusters(preferenceMatrix) # Then np.testing.assert_array_equal( @@ -77,7 +77,7 @@ class ClustersFactoryTest(unittest.TestCase): ]) # When - clusters = ClustersFactory.createClusters(preferenceMatrix) + clusters, _ = ClustersFactory.createClusters(preferenceMatrix) # Then np.testing.assert_array_equal( @@ -87,3 +87,17 @@ class ClustersFactoryTest(unittest.TestCase): [2, 1, 0], [4, 3] ])) + + def test_getLineIndexes(self): + # Given + preferenceMatrix = np.array( + [ + [0, 0, 1], + [0, 1, 1] + ]) + + # When + lineIndexes = ClustersFactory._getLineIndexes(preferenceMatrix) + + # Then + np.testing.assert_array_equal(lineIndexes, [2, 1])