diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py new file mode 100644 index 00000000000..0887ae6d034 --- /dev/null +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py @@ -0,0 +1,30 @@ +from skspatial.objects import Line +from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter + + +class LinesFactory: + + @staticmethod + def createLines(points): + lines = [Line.from_points(pointA, pointB) for (pointA, pointB) in LinesFactory._getPairs(points)] + return LinesFactory._getUniqueLines(lines) + + @staticmethod + def _getPairs(points): + return ((points[i], points[j]) for (i, j) in MultiLineFitter._getPairs(len(points))) + + @staticmethod + def _getUniqueLines(lines): + uniqueLines = [] + for i in range(len(lines)): + line = lines[i] + if not LinesFactory._isLineCloseToAnyOtherLine(line, lines[i + 1:]): + uniqueLines.append(line) + return uniqueLines + + @staticmethod + def _isLineCloseToAnyOtherLine(line, otherLines): + for otherLine in otherLines: + if line.is_close(otherLine): + return True + return False diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py new file mode 100644 index 00000000000..d53ec86e419 --- /dev/null +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py @@ -0,0 +1,30 @@ +import unittest +from skspatial.objects import Line +from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory + + +class LinesFactoryTest(unittest.TestCase): + + def test_createLines(self): + # Given + points = [(1, 0), (2, 0), (3, 0)] + + # When + lines = LinesFactory.createLines(points) + + # Then + self.assertEqual(len(lines), 1) + self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) + + def test_createLines2(self): + # Given + points = [(0, 0), (1, 0), (0, 1)] + + # When + lines = LinesFactory.createLines(points) + + # Then + self.assertEqual(len(lines), 3) + self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) + self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1]))) + self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1]))) diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py index 0e77c83cb34..3c5d0c5df06 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py @@ -1,5 +1,4 @@ import numpy as np -from skspatial.objects import Line # implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage class MultiLineFitter: