From 36492ae88b14497fd299e26a2ecd98b56245b45f Mon Sep 17 00:00:00 2001 From: frankknoll Date: Fri, 17 Nov 2023 21:13:27 +0100 Subject: [PATCH] refining MultiLineFitterTest --- src/HowBadIsMyBatch.ipynb | 6 +++--- .../CharacteristicFunctions.py | 11 +++++++++++ .../MultiLineFitting/MultiLineFitter.py | 19 ++++++++++++++++--- .../MultiLineFitting/MultiLineFitterTest.py | 14 +++++++------- 4 files changed, 37 insertions(+), 13 deletions(-) create mode 100644 src/SymptomsCausedByVaccines/MultiLineFitting/CharacteristicFunctions.py diff --git a/src/HowBadIsMyBatch.ipynb b/src/HowBadIsMyBatch.ipynb index 30b8b89c495..62e22a52129 100644 --- a/src/HowBadIsMyBatch.ipynb +++ b/src/HowBadIsMyBatch.ipynb @@ -669,8 +669,8 @@ "metadata": {}, "outputs": [], "source": [ - "symptomX = 'Immunosuppression'\n", - "symptomY = 'Infection' # 'Immunoglobulin therapy'" + "symptomX = 'HIV test' # 'Immunosuppression'\n", + "symptomY = 'Immunoglobulin therapy' # 'Infection' # 'Immunoglobulin therapy'" ] }, { @@ -723,7 +723,7 @@ " s = 100,\n", " label = \"Dots\")\n", "for cluster, line in zip(clusters, lines):\n", - " if(len(cluster) > 2):\n", + " if len(cluster) >= 3:\n", " coords = line.transform_points(cluster)\n", " magnitude = line.direction.norm()\n", " line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n", diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/CharacteristicFunctions.py b/src/SymptomsCausedByVaccines/MultiLineFitting/CharacteristicFunctions.py new file mode 100644 index 00000000000..bb454f7f14e --- /dev/null +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/CharacteristicFunctions.py @@ -0,0 +1,11 @@ +import numpy as np + +class CharacteristicFunctions: + + @staticmethod + def apply(characteristicFunction, elements): + return np.array(elements)[CharacteristicFunctions._getIndexes(characteristicFunction)] + + @staticmethod + def _getIndexes(characteristicFunction): + return [index for (index, value) in enumerate(characteristicFunction) if value == 1] diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py index 5e76b668030..14bf0fba376 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py @@ -1,6 +1,7 @@ import numpy as np from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs +from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import CharacteristicFunctions # implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage class MultiLineFitter: @@ -12,10 +13,22 @@ class MultiLineFitter: @staticmethod def fitLines(points, lines, consensusThreshold): preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold) - clusters, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix) + _, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix) + fittedLines = MultiLineFitter._getLines(lines, preferenceMatrix4Clusters) return ( - MultiLineFitter._getClusterPoints(points, clusters), - MultiLineFitter._getLines(lines, preferenceMatrix4Clusters)) + MultiLineFitter._getFittedPointsList(points, fittedLines, consensusThreshold), + fittedLines) + + @staticmethod + def _getFittedPointsList(points, lines, consensusThreshold): + return MultiLineFitter._getPointsList( + points, + MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)) + + @staticmethod + def _getPointsList(points, preferenceMatrix): + characteristicFunctionsOfConsensusSets = np.transpose(preferenceMatrix) + return [CharacteristicFunctions.apply(characteristicFunctionOfConsensusSet, points) for characteristicFunctionOfConsensusSet in characteristicFunctionsOfConsensusSets] @staticmethod def _createPreferenceMatrix(points, lines, consensusThreshold): diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py index 15dda8d0a98..0270b9a889e 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py @@ -111,10 +111,10 @@ class MultiLineFitterTest(unittest.TestCase): def test_fitLines(self): # Given - points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)] + points = [(0, 0), (1, 0), (2, 0), (1, 1), (2, 2)] line1 = Line.from_points([0, 0], [1, 0]) line2 = Line.from_points([0, 0], [1, 1]) - line3 = Line.from_points([0, 0], [0, 1]) + line3 = Line.from_points([-10, 0], [-10, 1]) # When clusters, fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2, line3], consensusThreshold = 0.001) @@ -129,13 +129,13 @@ class MultiLineFitterTest(unittest.TestCase): np.testing.assert_array_equal( clusters, [ - [(1, 0), (2, 0), (3, 0)], - [(1, 1), (2, 2), (3, 3)] + [(0, 0), (1, 0), (2, 0)], + [(0, 0), (1, 1), (2, 2)] ]) def test_fitPointsByLines(self): # Given - points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)] + points = [(0, 0), (1, 0), (2, 0), (1, 1), (2, 2)] # When clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001) @@ -147,6 +147,6 @@ class MultiLineFitterTest(unittest.TestCase): np.testing.assert_array_equal( clusters, [ - [(1, 0), (2, 0), (3, 0)], - [(1, 1), (2, 2), (3, 3)] + [(0, 0), (1, 0), (2, 0)], + [(0, 0), (1, 1), (2, 2)] ])