From 1941337066b8f9539229bc3e0cce44ed29f9f646 Mon Sep 17 00:00:00 2001 From: frankknoll Date: Fri, 17 Nov 2023 10:10:18 +0100 Subject: [PATCH] starting LinesFactoryTest --- .../MultiLineFitting/LinesFactory.py | 30 +++++++++++++++++++ .../MultiLineFitting/LinesFactoryTest.py | 30 +++++++++++++++++++ .../MultiLineFitting/MultiLineFitter.py | 1 - 3 files changed, 60 insertions(+), 1 deletion(-) create mode 100644 src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py create mode 100644 src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py new file mode 100644 index 00000000000..0887ae6d034 --- /dev/null +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactory.py @@ -0,0 +1,30 @@ +from skspatial.objects import Line +from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter + + +class LinesFactory: + + @staticmethod + def createLines(points): + lines = [Line.from_points(pointA, pointB) for (pointA, pointB) in LinesFactory._getPairs(points)] + return LinesFactory._getUniqueLines(lines) + + @staticmethod + def _getPairs(points): + return ((points[i], points[j]) for (i, j) in MultiLineFitter._getPairs(len(points))) + + @staticmethod + def _getUniqueLines(lines): + uniqueLines = [] + for i in range(len(lines)): + line = lines[i] + if not LinesFactory._isLineCloseToAnyOtherLine(line, lines[i + 1:]): + uniqueLines.append(line) + return uniqueLines + + @staticmethod + def _isLineCloseToAnyOtherLine(line, otherLines): + for otherLine in otherLines: + if line.is_close(otherLine): + return True + return False diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py new file mode 100644 index 00000000000..d53ec86e419 --- /dev/null +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/LinesFactoryTest.py @@ -0,0 +1,30 @@ +import unittest +from skspatial.objects import Line +from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory + + +class LinesFactoryTest(unittest.TestCase): + + def test_createLines(self): + # Given + points = [(1, 0), (2, 0), (3, 0)] + + # When + lines = LinesFactory.createLines(points) + + # Then + self.assertEqual(len(lines), 1) + self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) + + def test_createLines2(self): + # Given + points = [(0, 0), (1, 0), (0, 1)] + + # When + lines = LinesFactory.createLines(points) + + # Then + self.assertEqual(len(lines), 3) + self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) + self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1]))) + self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1]))) diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py index 0e77c83cb34..3c5d0c5df06 100644 --- a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py @@ -1,5 +1,4 @@ import numpy as np -from skspatial.objects import Line # implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage class MultiLineFitter: