using only ascending lines for fitting
This commit is contained in:
@@ -651,6 +651,13 @@
|
||||
" htmlFile = os.path.normpath(webAppBaseDir + '/index.html'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Multi Line Fitting"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -669,8 +676,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"symptomX = 'HIV test' # 'Immunosuppression'\n",
|
||||
"symptomY = 'Immunoglobulin therapy' # 'Infection' # 'Immunoglobulin therapy'"
|
||||
"symptomX = 'Immunosuppression' # HIV test' # 'Immunosuppression'\n",
|
||||
"symptomY = 'Pneumonia' # 'Infection' # 'Immunoglobulin therapy'"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -702,7 +709,7 @@
|
||||
"source": [
|
||||
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
|
||||
"\n",
|
||||
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)"
|
||||
"clusters, lines = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -738,8 +745,10 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
|
||||
@@ -1,17 +1,32 @@
|
||||
from skspatial.objects import Line
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs
|
||||
|
||||
|
||||
class LinesFactory:
|
||||
|
||||
@staticmethod
|
||||
def createLines(points):
|
||||
lines = [Line.from_points(pointA, pointB) for (pointA, pointB) in LinesFactory._getPairs(points)]
|
||||
return LinesFactory._getUniqueLines(lines)
|
||||
return LinesFactory._getUniqueLines(list(LinesFactory._generateAllLines(points)))
|
||||
|
||||
@staticmethod
|
||||
def _getPairs(points):
|
||||
return ((points[i], points[j]) for (i, j) in getPairs(len(points)))
|
||||
def createAscendingLines(points):
|
||||
return LinesFactory._getUniqueLines(list(LinesFactory._generateAllAscendingLines(points)))
|
||||
|
||||
@staticmethod
|
||||
def _generateAllAscendingLines(points):
|
||||
return (line for line in LinesFactory._generateAllLines(points) if LinesFactory._isAscending(line.direction))
|
||||
|
||||
@staticmethod
|
||||
def _generateAllLines(points):
|
||||
return (Line.from_points(pointA, pointB) for (pointA, pointB) in LinesFactory._generatePairs(points))
|
||||
|
||||
@staticmethod
|
||||
def _isAscending(direction):
|
||||
return (direction[0] >= 0 and direction[1] >= 0) or (direction[0] <= 0 and direction[1] <= 0)
|
||||
|
||||
@staticmethod
|
||||
def _generatePairs(points):
|
||||
return ((points[i], points[j]) for (i, j) in generatePairs(len(points)))
|
||||
|
||||
@staticmethod
|
||||
def _getUniqueLines(lines):
|
||||
|
||||
@@ -28,3 +28,15 @@ class LinesFactoryTest(unittest.TestCase):
|
||||
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
|
||||
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
|
||||
self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1])))
|
||||
|
||||
def test_createAscendingLines(self):
|
||||
# Given
|
||||
points = [(0, 0), (1, 0), (0, 1)]
|
||||
|
||||
# When
|
||||
lines = LinesFactory.createAscendingLines(points)
|
||||
|
||||
# Then
|
||||
self.assertEqual(len(lines), 2)
|
||||
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
|
||||
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import CharacteristicFunctions
|
||||
|
||||
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
|
||||
@@ -10,6 +10,10 @@ class MultiLineFitter:
|
||||
def fitPointsByLines(points, consensusThreshold):
|
||||
return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold)
|
||||
|
||||
@staticmethod
|
||||
def fitPointsByAscendingLines(points, consensusThreshold):
|
||||
return MultiLineFitter.fitLines(points, LinesFactory.createAscendingLines(points), consensusThreshold)
|
||||
|
||||
@staticmethod
|
||||
def fitLines(points, lines, consensusThreshold):
|
||||
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||
@@ -48,7 +52,7 @@ class MultiLineFitter:
|
||||
bestClusterIndexCombination = None
|
||||
keepClustering = False
|
||||
numClusters = preferenceMatrix.shape[0]
|
||||
for (clusterIndexA, clusterIndexB) in getPairs(numClusters):
|
||||
for (clusterIndexA, clusterIndexB) in generatePairs(numClusters):
|
||||
preferenceSetA = preferenceMatrix[clusterIndexA]
|
||||
preferenceSetB = preferenceMatrix[clusterIndexB]
|
||||
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
def getPairs(n):
|
||||
def generatePairs(n):
|
||||
for i in range(n):
|
||||
for j in range(i):
|
||||
yield (i, j)
|
||||
|
||||
Reference in New Issue
Block a user