refining ClustersFactoryTest
This commit is contained in:
@@ -31,7 +31,7 @@ class ClustersFactory:
|
|||||||
preferenceMatrix[clusterIndexA] = np.logical_and(preferenceMatrix[clusterIndexA], preferenceMatrix[clusterIndexB])
|
preferenceMatrix[clusterIndexA] = np.logical_and(preferenceMatrix[clusterIndexA], preferenceMatrix[clusterIndexB])
|
||||||
preferenceMatrix = np.delete(preferenceMatrix, clusterIndexB, axis = 0)
|
preferenceMatrix = np.delete(preferenceMatrix, clusterIndexB, axis = 0)
|
||||||
|
|
||||||
return clusters
|
return clusters, preferenceMatrix
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
||||||
@@ -46,3 +46,7 @@ class ClustersFactory:
|
|||||||
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
||||||
union = np.count_nonzero(np.logical_or(setA, setB))
|
union = np.count_nonzero(np.logical_or(setA, setB))
|
||||||
return 1. * intersection / union
|
return 1. * intersection / union
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _getLineIndexes(preferenceMatrix):
|
||||||
|
return [list(lines).index(1) for lines in preferenceMatrix]
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from skspatial.objects import Line
|
from skspatial.objects import Line
|
||||||
from src.SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory
|
from SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory
|
||||||
|
|
||||||
|
|
||||||
class ClustersFactoryTest(unittest.TestCase):
|
class ClustersFactoryTest(unittest.TestCase):
|
||||||
@@ -55,7 +55,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# When
|
# When
|
||||||
clusters = ClustersFactory.createClusters(preferenceMatrix)
|
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
@@ -77,7 +77,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# When
|
# When
|
||||||
clusters = ClustersFactory.createClusters(preferenceMatrix)
|
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
@@ -87,3 +87,17 @@ class ClustersFactoryTest(unittest.TestCase):
|
|||||||
[2, 1, 0],
|
[2, 1, 0],
|
||||||
[4, 3]
|
[4, 3]
|
||||||
]))
|
]))
|
||||||
|
|
||||||
|
def test_getLineIndexes(self):
|
||||||
|
# Given
|
||||||
|
preferenceMatrix = np.array(
|
||||||
|
[
|
||||||
|
[0, 0, 1],
|
||||||
|
[0, 1, 1]
|
||||||
|
])
|
||||||
|
|
||||||
|
# When
|
||||||
|
lineIndexes = ClustersFactory._getLineIndexes(preferenceMatrix)
|
||||||
|
|
||||||
|
# Then
|
||||||
|
np.testing.assert_array_equal(lineIndexes, [2, 1])
|
||||||
|
|||||||
Reference in New Issue
Block a user