refining MultiLineFitterTest

This commit is contained in:
frankknoll
2023-11-17 21:13:27 +01:00
parent ee524ef036
commit 36492ae88b
4 changed files with 37 additions and 13 deletions

View File

@@ -669,8 +669,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"symptomX = 'Immunosuppression'\n", "symptomX = 'HIV test' # 'Immunosuppression'\n",
"symptomY = 'Infection' # 'Immunoglobulin therapy'" "symptomY = 'Immunoglobulin therapy' # 'Infection' # 'Immunoglobulin therapy'"
] ]
}, },
{ {
@@ -723,7 +723,7 @@
" s = 100,\n", " s = 100,\n",
" label = \"Dots\")\n", " label = \"Dots\")\n",
"for cluster, line in zip(clusters, lines):\n", "for cluster, line in zip(clusters, lines):\n",
" if(len(cluster) > 2):\n", " if len(cluster) >= 3:\n",
" coords = line.transform_points(cluster)\n", " coords = line.transform_points(cluster)\n",
" magnitude = line.direction.norm()\n", " magnitude = line.direction.norm()\n",
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n", " line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",

View File

@@ -0,0 +1,11 @@
import numpy as np
class CharacteristicFunctions:
@staticmethod
def apply(characteristicFunction, elements):
return np.array(elements)[CharacteristicFunctions._getIndexes(characteristicFunction)]
@staticmethod
def _getIndexes(characteristicFunction):
return [index for (index, value) in enumerate(characteristicFunction) if value == 1]

View File

@@ -1,6 +1,7 @@
import numpy as np import numpy as np
from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs
from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import CharacteristicFunctions
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage # implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
class MultiLineFitter: class MultiLineFitter:
@@ -12,10 +13,22 @@ class MultiLineFitter:
@staticmethod @staticmethod
def fitLines(points, lines, consensusThreshold): def fitLines(points, lines, consensusThreshold):
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold) preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
clusters, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix) _, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
fittedLines = MultiLineFitter._getLines(lines, preferenceMatrix4Clusters)
return ( return (
MultiLineFitter._getClusterPoints(points, clusters), MultiLineFitter._getFittedPointsList(points, fittedLines, consensusThreshold),
MultiLineFitter._getLines(lines, preferenceMatrix4Clusters)) fittedLines)
@staticmethod
def _getFittedPointsList(points, lines, consensusThreshold):
return MultiLineFitter._getPointsList(
points,
MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold))
@staticmethod
def _getPointsList(points, preferenceMatrix):
characteristicFunctionsOfConsensusSets = np.transpose(preferenceMatrix)
return [CharacteristicFunctions.apply(characteristicFunctionOfConsensusSet, points) for characteristicFunctionOfConsensusSet in characteristicFunctionsOfConsensusSets]
@staticmethod @staticmethod
def _createPreferenceMatrix(points, lines, consensusThreshold): def _createPreferenceMatrix(points, lines, consensusThreshold):

View File

@@ -111,10 +111,10 @@ class MultiLineFitterTest(unittest.TestCase):
def test_fitLines(self): def test_fitLines(self):
# Given # Given
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)] points = [(0, 0), (1, 0), (2, 0), (1, 1), (2, 2)]
line1 = Line.from_points([0, 0], [1, 0]) line1 = Line.from_points([0, 0], [1, 0])
line2 = Line.from_points([0, 0], [1, 1]) line2 = Line.from_points([0, 0], [1, 1])
line3 = Line.from_points([0, 0], [0, 1]) line3 = Line.from_points([-10, 0], [-10, 1])
# When # When
clusters, fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2, line3], consensusThreshold = 0.001) clusters, fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2, line3], consensusThreshold = 0.001)
@@ -129,13 +129,13 @@ class MultiLineFitterTest(unittest.TestCase):
np.testing.assert_array_equal( np.testing.assert_array_equal(
clusters, clusters,
[ [
[(1, 0), (2, 0), (3, 0)], [(0, 0), (1, 0), (2, 0)],
[(1, 1), (2, 2), (3, 3)] [(0, 0), (1, 1), (2, 2)]
]) ])
def test_fitPointsByLines(self): def test_fitPointsByLines(self):
# Given # Given
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)] points = [(0, 0), (1, 0), (2, 0), (1, 1), (2, 2)]
# When # When
clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001) clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)
@@ -147,6 +147,6 @@ class MultiLineFitterTest(unittest.TestCase):
np.testing.assert_array_equal( np.testing.assert_array_equal(
clusters, clusters,
[ [
[(1, 0), (2, 0), (3, 0)], [(0, 0), (1, 0), (2, 0)],
[(1, 1), (2, 2), (3, 3)] [(0, 0), (1, 1), (2, 2)]
]) ])