refining ClustersFactoryTest
This commit is contained in:
@@ -31,7 +31,7 @@ class ClustersFactory:
|
||||
preferenceMatrix[clusterIndexA] = np.logical_and(preferenceMatrix[clusterIndexA], preferenceMatrix[clusterIndexB])
|
||||
preferenceMatrix = np.delete(preferenceMatrix, clusterIndexB, axis = 0)
|
||||
|
||||
return clusters
|
||||
return clusters, preferenceMatrix
|
||||
|
||||
@staticmethod
|
||||
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
||||
@@ -46,3 +46,7 @@ class ClustersFactory:
|
||||
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
||||
union = np.count_nonzero(np.logical_or(setA, setB))
|
||||
return 1. * intersection / union
|
||||
|
||||
@staticmethod
|
||||
def _getLineIndexes(preferenceMatrix):
|
||||
return [list(lines).index(1) for lines in preferenceMatrix]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from skspatial.objects import Line
|
||||
from src.SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory
|
||||
from SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory
|
||||
|
||||
|
||||
class ClustersFactoryTest(unittest.TestCase):
|
||||
@@ -55,7 +55,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
||||
])
|
||||
|
||||
# When
|
||||
clusters = ClustersFactory.createClusters(preferenceMatrix)
|
||||
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
|
||||
|
||||
# Then
|
||||
np.testing.assert_array_equal(
|
||||
@@ -77,7 +77,7 @@ class ClustersFactoryTest(unittest.TestCase):
|
||||
])
|
||||
|
||||
# When
|
||||
clusters = ClustersFactory.createClusters(preferenceMatrix)
|
||||
clusters, _ = ClustersFactory.createClusters(preferenceMatrix)
|
||||
|
||||
# Then
|
||||
np.testing.assert_array_equal(
|
||||
@@ -87,3 +87,17 @@ class ClustersFactoryTest(unittest.TestCase):
|
||||
[2, 1, 0],
|
||||
[4, 3]
|
||||
]))
|
||||
|
||||
def test_getLineIndexes(self):
|
||||
# Given
|
||||
preferenceMatrix = np.array(
|
||||
[
|
||||
[0, 0, 1],
|
||||
[0, 1, 1]
|
||||
])
|
||||
|
||||
# When
|
||||
lineIndexes = ClustersFactory._getLineIndexes(preferenceMatrix)
|
||||
|
||||
# Then
|
||||
np.testing.assert_array_equal(lineIndexes, [2, 1])
|
||||
|
||||
Reference in New Issue
Block a user