using only ascending lines for fitting

This commit is contained in:
frankknoll
2023-11-18 11:05:35 +01:00
parent 36492ae88b
commit d40116ba6f
5 changed files with 52 additions and 12 deletions

View File

@@ -651,6 +651,13 @@
" htmlFile = os.path.normpath(webAppBaseDir + '/index.html'))" " htmlFile = os.path.normpath(webAppBaseDir + '/index.html'))"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multi Line Fitting"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@@ -669,8 +676,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"symptomX = 'HIV test' # 'Immunosuppression'\n", "symptomX = 'Immunosuppression' # HIV test' # 'Immunosuppression'\n",
"symptomY = 'Immunoglobulin therapy' # 'Infection' # 'Immunoglobulin therapy'" "symptomY = 'Pneumonia' # 'Infection' # 'Immunoglobulin therapy'"
] ]
}, },
{ {
@@ -702,7 +709,7 @@
"source": [ "source": [
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n", "from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
"\n", "\n",
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)" "clusters, lines = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)"
] ]
}, },
{ {
@@ -738,8 +745,10 @@
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "code",
"execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [],
"source": [] "source": []
} }
], ],

View File

@@ -1,17 +1,32 @@
from skspatial.objects import Line from skspatial.objects import Line
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs
class LinesFactory: class LinesFactory:
@staticmethod @staticmethod
def createLines(points): def createLines(points):
lines = [Line.from_points(pointA, pointB) for (pointA, pointB) in LinesFactory._getPairs(points)] return LinesFactory._getUniqueLines(list(LinesFactory._generateAllLines(points)))
return LinesFactory._getUniqueLines(lines)
@staticmethod @staticmethod
def _getPairs(points): def createAscendingLines(points):
return ((points[i], points[j]) for (i, j) in getPairs(len(points))) return LinesFactory._getUniqueLines(list(LinesFactory._generateAllAscendingLines(points)))
@staticmethod
def _generateAllAscendingLines(points):
return (line for line in LinesFactory._generateAllLines(points) if LinesFactory._isAscending(line.direction))
@staticmethod
def _generateAllLines(points):
return (Line.from_points(pointA, pointB) for (pointA, pointB) in LinesFactory._generatePairs(points))
@staticmethod
def _isAscending(direction):
return (direction[0] >= 0 and direction[1] >= 0) or (direction[0] <= 0 and direction[1] <= 0)
@staticmethod
def _generatePairs(points):
return ((points[i], points[j]) for (i, j) in generatePairs(len(points)))
@staticmethod @staticmethod
def _getUniqueLines(lines): def _getUniqueLines(lines):

View File

@@ -28,3 +28,15 @@ class LinesFactoryTest(unittest.TestCase):
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1]))) self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1]))) self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1])))
def test_createAscendingLines(self):
# Given
points = [(0, 0), (1, 0), (0, 1)]
# When
lines = LinesFactory.createAscendingLines(points)
# Then
self.assertEqual(len(lines), 2)
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))

View File

@@ -1,6 +1,6 @@
import numpy as np import numpy as np
from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs
from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import CharacteristicFunctions from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import CharacteristicFunctions
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage # implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
@@ -10,6 +10,10 @@ class MultiLineFitter:
def fitPointsByLines(points, consensusThreshold): def fitPointsByLines(points, consensusThreshold):
return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold) return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold)
@staticmethod
def fitPointsByAscendingLines(points, consensusThreshold):
return MultiLineFitter.fitLines(points, LinesFactory.createAscendingLines(points), consensusThreshold)
@staticmethod @staticmethod
def fitLines(points, lines, consensusThreshold): def fitLines(points, lines, consensusThreshold):
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold) preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
@@ -48,7 +52,7 @@ class MultiLineFitter:
bestClusterIndexCombination = None bestClusterIndexCombination = None
keepClustering = False keepClustering = False
numClusters = preferenceMatrix.shape[0] numClusters = preferenceMatrix.shape[0]
for (clusterIndexA, clusterIndexB) in getPairs(numClusters): for (clusterIndexA, clusterIndexB) in generatePairs(numClusters):
preferenceSetA = preferenceMatrix[clusterIndexA] preferenceSetA = preferenceMatrix[clusterIndexA]
preferenceSetB = preferenceMatrix[clusterIndexB] preferenceSetB = preferenceMatrix[clusterIndexB]
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB); similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);

View File

@@ -1,4 +1,4 @@
def getPairs(n): def generatePairs(n):
for i in range(n): for i in range(n):
for j in range(i): for j in range(i):
yield (i, j) yield (i, j)