using only ascending lines for fitting
This commit is contained in:
@@ -651,6 +651,13 @@
|
|||||||
" htmlFile = os.path.normpath(webAppBaseDir + '/index.html'))"
|
" htmlFile = os.path.normpath(webAppBaseDir + '/index.html'))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Multi Line Fitting"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
@@ -669,8 +676,8 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"symptomX = 'HIV test' # 'Immunosuppression'\n",
|
"symptomX = 'Immunosuppression' # HIV test' # 'Immunosuppression'\n",
|
||||||
"symptomY = 'Immunoglobulin therapy' # 'Infection' # 'Immunoglobulin therapy'"
|
"symptomY = 'Pneumonia' # 'Infection' # 'Immunoglobulin therapy'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -702,7 +709,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
|
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
|
||||||
"\n",
|
"\n",
|
||||||
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)"
|
"clusters, lines = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -738,8 +745,10 @@
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -1,17 +1,32 @@
|
|||||||
from skspatial.objects import Line
|
from skspatial.objects import Line
|
||||||
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs
|
from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs
|
||||||
|
|
||||||
|
|
||||||
class LinesFactory:
|
class LinesFactory:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def createLines(points):
|
def createLines(points):
|
||||||
lines = [Line.from_points(pointA, pointB) for (pointA, pointB) in LinesFactory._getPairs(points)]
|
return LinesFactory._getUniqueLines(list(LinesFactory._generateAllLines(points)))
|
||||||
return LinesFactory._getUniqueLines(lines)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _getPairs(points):
|
def createAscendingLines(points):
|
||||||
return ((points[i], points[j]) for (i, j) in getPairs(len(points)))
|
return LinesFactory._getUniqueLines(list(LinesFactory._generateAllAscendingLines(points)))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generateAllAscendingLines(points):
|
||||||
|
return (line for line in LinesFactory._generateAllLines(points) if LinesFactory._isAscending(line.direction))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generateAllLines(points):
|
||||||
|
return (Line.from_points(pointA, pointB) for (pointA, pointB) in LinesFactory._generatePairs(points))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _isAscending(direction):
|
||||||
|
return (direction[0] >= 0 and direction[1] >= 0) or (direction[0] <= 0 and direction[1] <= 0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generatePairs(points):
|
||||||
|
return ((points[i], points[j]) for (i, j) in generatePairs(len(points)))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _getUniqueLines(lines):
|
def _getUniqueLines(lines):
|
||||||
|
|||||||
@@ -28,3 +28,15 @@ class LinesFactoryTest(unittest.TestCase):
|
|||||||
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
|
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
|
||||||
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
|
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
|
||||||
self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1])))
|
self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1])))
|
||||||
|
|
||||||
|
def test_createAscendingLines(self):
|
||||||
|
# Given
|
||||||
|
points = [(0, 0), (1, 0), (0, 1)]
|
||||||
|
|
||||||
|
# When
|
||||||
|
lines = LinesFactory.createAscendingLines(points)
|
||||||
|
|
||||||
|
# Then
|
||||||
|
self.assertEqual(len(lines), 2)
|
||||||
|
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
|
||||||
|
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory
|
from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory
|
||||||
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs
|
from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs
|
||||||
from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import CharacteristicFunctions
|
from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import CharacteristicFunctions
|
||||||
|
|
||||||
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
|
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
|
||||||
@@ -10,6 +10,10 @@ class MultiLineFitter:
|
|||||||
def fitPointsByLines(points, consensusThreshold):
|
def fitPointsByLines(points, consensusThreshold):
|
||||||
return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold)
|
return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fitPointsByAscendingLines(points, consensusThreshold):
|
||||||
|
return MultiLineFitter.fitLines(points, LinesFactory.createAscendingLines(points), consensusThreshold)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fitLines(points, lines, consensusThreshold):
|
def fitLines(points, lines, consensusThreshold):
|
||||||
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||||
@@ -48,7 +52,7 @@ class MultiLineFitter:
|
|||||||
bestClusterIndexCombination = None
|
bestClusterIndexCombination = None
|
||||||
keepClustering = False
|
keepClustering = False
|
||||||
numClusters = preferenceMatrix.shape[0]
|
numClusters = preferenceMatrix.shape[0]
|
||||||
for (clusterIndexA, clusterIndexB) in getPairs(numClusters):
|
for (clusterIndexA, clusterIndexB) in generatePairs(numClusters):
|
||||||
preferenceSetA = preferenceMatrix[clusterIndexA]
|
preferenceSetA = preferenceMatrix[clusterIndexA]
|
||||||
preferenceSetB = preferenceMatrix[clusterIndexB]
|
preferenceSetB = preferenceMatrix[clusterIndexB]
|
||||||
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
def getPairs(n):
|
def generatePairs(n):
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
for j in range(i):
|
for j in range(i):
|
||||||
yield (i, j)
|
yield (i, j)
|
||||||
|
|||||||
Reference in New Issue
Block a user