refining MultiLineFitterTest
This commit is contained in:
@@ -669,8 +669,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"symptomX = 'Immunosuppression'\n",
|
||||
"symptomY = 'Infection' # 'Immunoglobulin therapy'"
|
||||
"symptomX = 'HIV test' # 'Immunosuppression'\n",
|
||||
"symptomY = 'Immunoglobulin therapy' # 'Infection' # 'Immunoglobulin therapy'"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -723,7 +723,7 @@
|
||||
" s = 100,\n",
|
||||
" label = \"Dots\")\n",
|
||||
"for cluster, line in zip(clusters, lines):\n",
|
||||
" if(len(cluster) > 2):\n",
|
||||
" if len(cluster) >= 3:\n",
|
||||
" coords = line.transform_points(cluster)\n",
|
||||
" magnitude = line.direction.norm()\n",
|
||||
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
import numpy as np
|
||||
|
||||
class CharacteristicFunctions:
|
||||
|
||||
@staticmethod
|
||||
def apply(characteristicFunction, elements):
|
||||
return np.array(elements)[CharacteristicFunctions._getIndexes(characteristicFunction)]
|
||||
|
||||
@staticmethod
|
||||
def _getIndexes(characteristicFunction):
|
||||
return [index for (index, value) in enumerate(characteristicFunction) if value == 1]
|
||||
@@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import CharacteristicFunctions
|
||||
|
||||
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
|
||||
class MultiLineFitter:
|
||||
@@ -12,10 +13,22 @@ class MultiLineFitter:
|
||||
@staticmethod
|
||||
def fitLines(points, lines, consensusThreshold):
|
||||
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||
clusters, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
|
||||
_, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
|
||||
fittedLines = MultiLineFitter._getLines(lines, preferenceMatrix4Clusters)
|
||||
return (
|
||||
MultiLineFitter._getClusterPoints(points, clusters),
|
||||
MultiLineFitter._getLines(lines, preferenceMatrix4Clusters))
|
||||
MultiLineFitter._getFittedPointsList(points, fittedLines, consensusThreshold),
|
||||
fittedLines)
|
||||
|
||||
@staticmethod
|
||||
def _getFittedPointsList(points, lines, consensusThreshold):
|
||||
return MultiLineFitter._getPointsList(
|
||||
points,
|
||||
MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold))
|
||||
|
||||
@staticmethod
|
||||
def _getPointsList(points, preferenceMatrix):
|
||||
characteristicFunctionsOfConsensusSets = np.transpose(preferenceMatrix)
|
||||
return [CharacteristicFunctions.apply(characteristicFunctionOfConsensusSet, points) for characteristicFunctionOfConsensusSet in characteristicFunctionsOfConsensusSets]
|
||||
|
||||
@staticmethod
|
||||
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
||||
|
||||
@@ -111,10 +111,10 @@ class MultiLineFitterTest(unittest.TestCase):
|
||||
|
||||
def test_fitLines(self):
|
||||
# Given
|
||||
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
|
||||
points = [(0, 0), (1, 0), (2, 0), (1, 1), (2, 2)]
|
||||
line1 = Line.from_points([0, 0], [1, 0])
|
||||
line2 = Line.from_points([0, 0], [1, 1])
|
||||
line3 = Line.from_points([0, 0], [0, 1])
|
||||
line3 = Line.from_points([-10, 0], [-10, 1])
|
||||
|
||||
# When
|
||||
clusters, fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2, line3], consensusThreshold = 0.001)
|
||||
@@ -129,13 +129,13 @@ class MultiLineFitterTest(unittest.TestCase):
|
||||
np.testing.assert_array_equal(
|
||||
clusters,
|
||||
[
|
||||
[(1, 0), (2, 0), (3, 0)],
|
||||
[(1, 1), (2, 2), (3, 3)]
|
||||
[(0, 0), (1, 0), (2, 0)],
|
||||
[(0, 0), (1, 1), (2, 2)]
|
||||
])
|
||||
|
||||
def test_fitPointsByLines(self):
|
||||
# Given
|
||||
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
|
||||
points = [(0, 0), (1, 0), (2, 0), (1, 1), (2, 2)]
|
||||
|
||||
# When
|
||||
clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)
|
||||
@@ -147,6 +147,6 @@ class MultiLineFitterTest(unittest.TestCase):
|
||||
np.testing.assert_array_equal(
|
||||
clusters,
|
||||
[
|
||||
[(1, 0), (2, 0), (3, 0)],
|
||||
[(1, 1), (2, 2), (3, 3)]
|
||||
[(0, 0), (1, 0), (2, 0)],
|
||||
[(0, 0), (1, 1), (2, 2)]
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user