refining MultiLineFitterTest
This commit is contained in:
@@ -669,8 +669,8 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"symptomX = 'Immunosuppression'\n",
|
"symptomX = 'HIV test' # 'Immunosuppression'\n",
|
||||||
"symptomY = 'Infection' # 'Immunoglobulin therapy'"
|
"symptomY = 'Immunoglobulin therapy' # 'Infection' # 'Immunoglobulin therapy'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -723,7 +723,7 @@
|
|||||||
" s = 100,\n",
|
" s = 100,\n",
|
||||||
" label = \"Dots\")\n",
|
" label = \"Dots\")\n",
|
||||||
"for cluster, line in zip(clusters, lines):\n",
|
"for cluster, line in zip(clusters, lines):\n",
|
||||||
" if(len(cluster) > 2):\n",
|
" if len(cluster) >= 3:\n",
|
||||||
" coords = line.transform_points(cluster)\n",
|
" coords = line.transform_points(cluster)\n",
|
||||||
" magnitude = line.direction.norm()\n",
|
" magnitude = line.direction.norm()\n",
|
||||||
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",
|
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",
|
||||||
|
|||||||
@@ -0,0 +1,11 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class CharacteristicFunctions:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply(characteristicFunction, elements):
|
||||||
|
return np.array(elements)[CharacteristicFunctions._getIndexes(characteristicFunction)]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _getIndexes(characteristicFunction):
|
||||||
|
return [index for (index, value) in enumerate(characteristicFunction) if value == 1]
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory
|
from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory
|
||||||
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs
|
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs
|
||||||
|
from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import CharacteristicFunctions
|
||||||
|
|
||||||
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
|
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
|
||||||
class MultiLineFitter:
|
class MultiLineFitter:
|
||||||
@@ -12,10 +13,22 @@ class MultiLineFitter:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def fitLines(points, lines, consensusThreshold):
|
def fitLines(points, lines, consensusThreshold):
|
||||||
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||||
clusters, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
|
_, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
|
||||||
|
fittedLines = MultiLineFitter._getLines(lines, preferenceMatrix4Clusters)
|
||||||
return (
|
return (
|
||||||
MultiLineFitter._getClusterPoints(points, clusters),
|
MultiLineFitter._getFittedPointsList(points, fittedLines, consensusThreshold),
|
||||||
MultiLineFitter._getLines(lines, preferenceMatrix4Clusters))
|
fittedLines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _getFittedPointsList(points, lines, consensusThreshold):
|
||||||
|
return MultiLineFitter._getPointsList(
|
||||||
|
points,
|
||||||
|
MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _getPointsList(points, preferenceMatrix):
|
||||||
|
characteristicFunctionsOfConsensusSets = np.transpose(preferenceMatrix)
|
||||||
|
return [CharacteristicFunctions.apply(characteristicFunctionOfConsensusSet, points) for characteristicFunctionOfConsensusSet in characteristicFunctionsOfConsensusSets]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
||||||
|
|||||||
@@ -111,10 +111,10 @@ class MultiLineFitterTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_fitLines(self):
|
def test_fitLines(self):
|
||||||
# Given
|
# Given
|
||||||
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
|
points = [(0, 0), (1, 0), (2, 0), (1, 1), (2, 2)]
|
||||||
line1 = Line.from_points([0, 0], [1, 0])
|
line1 = Line.from_points([0, 0], [1, 0])
|
||||||
line2 = Line.from_points([0, 0], [1, 1])
|
line2 = Line.from_points([0, 0], [1, 1])
|
||||||
line3 = Line.from_points([0, 0], [0, 1])
|
line3 = Line.from_points([-10, 0], [-10, 1])
|
||||||
|
|
||||||
# When
|
# When
|
||||||
clusters, fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2, line3], consensusThreshold = 0.001)
|
clusters, fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2, line3], consensusThreshold = 0.001)
|
||||||
@@ -129,13 +129,13 @@ class MultiLineFitterTest(unittest.TestCase):
|
|||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
clusters,
|
clusters,
|
||||||
[
|
[
|
||||||
[(1, 0), (2, 0), (3, 0)],
|
[(0, 0), (1, 0), (2, 0)],
|
||||||
[(1, 1), (2, 2), (3, 3)]
|
[(0, 0), (1, 1), (2, 2)]
|
||||||
])
|
])
|
||||||
|
|
||||||
def test_fitPointsByLines(self):
|
def test_fitPointsByLines(self):
|
||||||
# Given
|
# Given
|
||||||
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
|
points = [(0, 0), (1, 0), (2, 0), (1, 1), (2, 2)]
|
||||||
|
|
||||||
# When
|
# When
|
||||||
clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)
|
clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)
|
||||||
@@ -147,6 +147,6 @@ class MultiLineFitterTest(unittest.TestCase):
|
|||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
clusters,
|
clusters,
|
||||||
[
|
[
|
||||||
[(1, 0), (2, 0), (3, 0)],
|
[(0, 0), (1, 0), (2, 0)],
|
||||||
[(1, 1), (2, 2), (3, 3)]
|
[(0, 0), (1, 1), (2, 2)]
|
||||||
])
|
])
|
||||||
|
|||||||
Reference in New Issue
Block a user