refining MultiLineFitterTest
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
from skspatial.objects import Line
|
from skspatial.objects import Line
|
||||||
from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter
|
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs
|
||||||
|
|
||||||
|
|
||||||
class LinesFactory:
|
class LinesFactory:
|
||||||
@@ -11,7 +11,7 @@ class LinesFactory:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _getPairs(points):
|
def _getPairs(points):
|
||||||
return ((points[i], points[j]) for (i, j) in MultiLineFitter._getPairs(len(points)))
|
return ((points[i], points[j]) for (i, j) in getPairs(len(points)))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _getUniqueLines(lines):
|
def _getUniqueLines(lines):
|
||||||
|
|||||||
@@ -1,8 +1,14 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory
|
||||||
|
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs
|
||||||
|
|
||||||
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
|
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
|
||||||
class MultiLineFitter:
|
class MultiLineFitter:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fitPointsByLines(points, consensusThreshold):
|
||||||
|
return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fitLines(points, lines, consensusThreshold):
|
def fitLines(points, lines, consensusThreshold):
|
||||||
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||||
@@ -28,7 +34,7 @@ class MultiLineFitter:
|
|||||||
bestClusterIndexCombination = None
|
bestClusterIndexCombination = None
|
||||||
keepClustering = False
|
keepClustering = False
|
||||||
numClusters = preferenceMatrix.shape[0]
|
numClusters = preferenceMatrix.shape[0]
|
||||||
for (clusterIndexA, clusterIndexB) in MultiLineFitter._getPairs(numClusters):
|
for (clusterIndexA, clusterIndexB) in getPairs(numClusters):
|
||||||
preferenceSetA = preferenceMatrix[clusterIndexA]
|
preferenceSetA = preferenceMatrix[clusterIndexA]
|
||||||
preferenceSetB = preferenceMatrix[clusterIndexB]
|
preferenceSetB = preferenceMatrix[clusterIndexB]
|
||||||
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
|
||||||
@@ -46,12 +52,6 @@ class MultiLineFitter:
|
|||||||
|
|
||||||
return clusters, preferenceMatrix
|
return clusters, preferenceMatrix
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _getPairs(n):
|
|
||||||
for i in range(n):
|
|
||||||
for j in range(i):
|
|
||||||
yield (i, j)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _intersectionOverUnion(setA, setB):
|
def _intersectionOverUnion(setA, setB):
|
||||||
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
||||||
|
|||||||
@@ -115,5 +115,14 @@ class MultiLineFitterTest(unittest.TestCase):
|
|||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(fittedLines, [line1, line2])
|
np.testing.assert_array_equal(fittedLines, [line1, line2])
|
||||||
|
|
||||||
#FK-TODO: erzeuge LinesFactory.createLines(points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)])
|
def test_fitPointsByLines(self):
|
||||||
# Diese Funktion soll alle Linien erzeugen, die jeweils zwei verschiedene Punkte aus points verbinden.
|
# Given
|
||||||
|
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
|
||||||
|
|
||||||
|
# When
|
||||||
|
lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)
|
||||||
|
|
||||||
|
# Then
|
||||||
|
self.assertEqual(len(lines), 2)
|
||||||
|
self.assertTrue(lines[0].is_close(Line.from_points([0, 0], [1, 0])))
|
||||||
|
self.assertTrue(lines[1].is_close(Line.from_points([0, 0], [1, 1])))
|
||||||
|
|||||||
4
src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py
Normal file
4
src/SymptomsCausedByVaccines/MultiLineFitting/Utils.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
def getPairs(n):
|
||||||
|
for i in range(n):
|
||||||
|
for j in range(i):
|
||||||
|
yield (i, j)
|
||||||
Reference in New Issue
Block a user