From 4c7da48bf9e2e7cf2e509ab888d49275254f0ae1 Mon Sep 17 00:00:00 2001 From: frankknoll Date: Thu, 16 Nov 2023 12:01:56 +0100 Subject: [PATCH] starting PreferenceMatrixFactoryTest --- .gitignore | 1 + environment.yml | 3 ++- .../PreferenceMatrixFactory.py | 14 ++++++++++ .../PreferenceMatrixFactoryTest.py | 26 +++++++++++++++++++ .../MultiLineFitting/__init__.py | 0 5 files changed, 43 insertions(+), 1 deletion(-) create mode 100644 src/SymptomsCausedByVaccines/MultiLineFitting/PreferenceMatrixFactory.py create mode 100644 src/SymptomsCausedByVaccines/MultiLineFitting/PreferenceMatrixFactoryTest.py create mode 100644 src/SymptomsCausedByVaccines/MultiLineFitting/__init__.py diff --git a/.gitignore b/.gitignore index 0a12bc1c9c5..1cbb1611fda 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ google-chrome-stable_current_amd64* src/captcha/__pycache__ src/GoogleAnalytics/__pycache__ src/SymptomsCausedByVaccines/__pycache__ +src/SymptomsCausedByVaccines/MultiLineFitting/__pycache__ diff --git a/environment.yml b/environment.yml index 746cb38373f..46c247952bb 100644 --- a/environment.yml +++ b/environment.yml @@ -1,13 +1,14 @@ name: howbadismybatch-venv channels: - defaults - # - conda-forge + - conda-forge dependencies: - python=3.9 - ipykernel - numpy - pandas - scikit-learn + - scikit-spatial - urllib3 - requests - bs4 diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/PreferenceMatrixFactory.py b/src/SymptomsCausedByVaccines/MultiLineFitting/PreferenceMatrixFactory.py new file mode 100644 index 00000000000..7bf93eb1ab1 --- /dev/null +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/PreferenceMatrixFactory.py @@ -0,0 +1,14 @@ +import numpy as np +from skspatial.objects import Line + +class PreferenceMatrixFactory: + + @staticmethod + def createPreferenceMatrix(points, lines, consensusThreshold): + preferenceMatrix = np.zeros([len(points), len(lines)], dtype = int) + for pointIndex, point in enumerate(points): + for lineIndex, line in enumerate(lines): + preferenceMatrix[pointIndex, lineIndex] = 1 if line.distance_point(point) <= consensusThreshold else 0 + + return preferenceMatrix + diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/PreferenceMatrixFactoryTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/PreferenceMatrixFactoryTest.py new file mode 100644 index 00000000000..1df52a5dccc --- /dev/null +++ b/src/SymptomsCausedByVaccines/MultiLineFitting/PreferenceMatrixFactoryTest.py @@ -0,0 +1,26 @@ +import unittest +import numpy as np +from numpy.testing import assert_array_equal +from skspatial.objects import Line +from SymptomsCausedByVaccines.MultiLineFitting.PreferenceMatrixFactory import PreferenceMatrixFactory + + +class PreferenceMatrixFactoryTest(unittest.TestCase): + + def test_createPreferenceMatrix(self): + # Given + points = [(1, 3), (10, 20)] + lines = [Line.from_points([0, 0], [100, 0])] + consensusThreshold = 4.0 + + # When + preferenceMatrix = PreferenceMatrixFactory.createPreferenceMatrix(points, lines, consensusThreshold) + + # Then + assert_array_equal( + preferenceMatrix, + np.array( + [ + [1], + [0] + ])) diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/__init__.py b/src/SymptomsCausedByVaccines/MultiLineFitting/__init__.py new file mode 100644 index 00000000000..e69de29bb2d