refining ClustersFactoryTest

This commit is contained in:
frankknoll
2023-11-16 17:36:07 +01:00
parent c2f900504a
commit 8231453ae2
2 changed files with 22 additions and 4 deletions

View File

@@ -31,7 +31,7 @@ class ClustersFactory:
preferenceMatrix[clusterIndexA] = np.logical_and(preferenceMatrix[clusterIndexA], preferenceMatrix[clusterIndexB])
preferenceMatrix = np.delete(preferenceMatrix, clusterIndexB, axis = 0)
return clusters
return clusters, preferenceMatrix
@staticmethod
def _createPreferenceMatrix(points, lines, consensusThreshold):
@@ -46,3 +46,7 @@ class ClustersFactory:
intersection = np.count_nonzero(np.logical_and(setA, setB))
union = np.count_nonzero(np.logical_or(setA, setB))
return 1. * intersection / union
@staticmethod
def _getLineIndexes(preferenceMatrix):
return [list(lines).index(1) for lines in preferenceMatrix]

View File

@@ -1,7 +1,7 @@
import unittest
import numpy as np
from skspatial.objects import Line
from src.SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory
from SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory
class ClustersFactoryTest(unittest.TestCase):
@@ -55,7 +55,7 @@ class ClustersFactoryTest(unittest.TestCase):
])
# When
clusters = ClustersFactory.createClusters(preferenceMatrix)
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
# Then
np.testing.assert_array_equal(
@@ -77,7 +77,7 @@ class ClustersFactoryTest(unittest.TestCase):
])
# When
clusters = ClustersFactory.createClusters(preferenceMatrix)
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
# Then
np.testing.assert_array_equal(
@@ -87,3 +87,17 @@ class ClustersFactoryTest(unittest.TestCase):
[2, 1, 0],
[4, 3]
]))
def test_getLineIndexes(self):
# Given
preferenceMatrix = np.array(
[
[0, 0, 1],
[0, 1, 1]
])
# When
lineIndexes = ClustersFactory._getLineIndexes(preferenceMatrix)
# Then
np.testing.assert_array_equal(lineIndexes, [2, 1])