refining ClustersFactoryTest

This commit is contained in:
frankknoll
2023-11-16 14:46:59 +01:00
parent 4c7da48bf9
commit 487ff3eff0
4 changed files with 99 additions and 40 deletions

View File

@@ -0,0 +1,55 @@
import numpy as np
from skspatial.objects import Line
class ClustersFactory:
    """Build point/line preference matrices and cluster points by shared preferences."""

    @staticmethod
    def createPreferenceMatrix(points, lines, consensusThreshold):
        """Return a 0/1 matrix telling which points lie close to which lines.

        Args:
            points: sequence of points; each one is passed unchanged to
                ``line.distance_point``.
            lines: sequence of objects exposing ``distance_point(point)``
                (presumably ``skspatial.objects.Line`` -- verify against caller).
            consensusThreshold: largest distance (inclusive) at which a point
                still counts as belonging to a line.

        Returns:
            An integer array of shape ``(len(points), len(lines))`` whose entry
            ``[p, l]`` is 1 when point ``p`` is within the threshold of line
            ``l`` and 0 otherwise.
        """
        preferenceMatrix = np.zeros([len(points), len(lines)], dtype=int)
        for pointIndex, point in enumerate(points):
            for lineIndex, line in enumerate(lines):
                isInlier = line.distance_point(point) <= consensusThreshold
                preferenceMatrix[pointIndex, lineIndex] = 1 if isInlier else 0
        return preferenceMatrix

    @staticmethod
    def createClusters(preferenceMatrix):
        """Greedily merge the rows whose preference sets overlap the most.

        Every row starts as its own cluster.  In each step the pair of rows
        with the highest intersection/union ratio of non-zero entries is
        merged; the merged row keeps only the entries set in both rows.
        Merging stops once no pair of rows has a ratio above zero.

        Args:
            preferenceMatrix: 2-D array with one row per point and one column
                per candidate line.  The argument is not modified.

        Returns:
            A tuple ``(mergedMatrix, clusters)``: ``mergedMatrix`` has one row
            per remaining cluster, and ``clusters`` is a list of lists holding
            the original row indices that ended up in each cluster.
        """
        # Work on a private copy: merged rows are written back in place below,
        # which would otherwise overwrite rows of the caller's array.
        preferenceMatrix = np.array(preferenceMatrix, copy=True)
        clusters = [[i] for i in range(preferenceMatrix.shape[0])]
        cluster_step = 0

        while True:
            best_similarity = 0
            best_combo = None
            for i in range(preferenceMatrix.shape[0]):
                for j in range(i):
                    set_a = preferenceMatrix[i]
                    set_b = preferenceMatrix[j]
                    intersection = np.count_nonzero(np.logical_and(set_a, set_b))
                    union = np.count_nonzero(np.logical_or(set_a, set_b))
                    # The tiny floor avoids a division by zero for two empty rows.
                    similarity = 1. * intersection / np.maximum(union, 1e-8)
                    # Strict comparison: on ties the first pair found wins.
                    if similarity > best_similarity:
                        best_similarity = similarity
                        best_combo = (i, j)

            if best_combo is None:
                # No pair shares any preference any more.
                break

            # absorbed < kept always holds because the inner loop uses j < i.
            kept, absorbed = best_combo
            clusters[kept] += clusters[absorbed]
            clusters.pop(absorbed)
            preferenceMatrix[kept] = np.logical_and(
                preferenceMatrix[kept], preferenceMatrix[absorbed])
            preferenceMatrix = np.delete(preferenceMatrix, absorbed, axis=0)
            cluster_step += 1

        print("clustering finished after %d steps" % cluster_step)
        return preferenceMatrix, clusters

View File

@@ -0,0 +1,44 @@
import unittest
import numpy as np
from skspatial.objects import Line
from src.SymptomsCausedByVaccines.MultiLineFitting.ClustersFactory import ClustersFactory
class ClustersFactoryTest(unittest.TestCase):
    """Unit tests for ClustersFactory."""

    def test_createPreferenceMatrix(self):
        """A point within the threshold of the line gets 1, a point beyond it gets 0."""
        # Given
        xAxis = Line.from_points([0, 0], [100, 0])
        nearPoint = (1, 3)
        farPoint = (10, 20)
        expectedMatrix = np.array([[1], [0]])

        # When
        actualMatrix = ClustersFactory.createPreferenceMatrix(
            [nearPoint, farPoint], [xAxis], 4.0)

        # Then
        np.testing.assert_array_equal(actualMatrix, expectedMatrix)

    def test_createClusters(self):
        """Two points preferring the same single line end up in one cluster."""
        # Given
        samePreferences = np.array([[1], [1]])
        expectedClusters = np.array([[1, 0]])

        # When
        result = ClustersFactory.createClusters(samePreferences)
        actualClusters = result[1]

        # Then
        np.testing.assert_array_equal(actualClusters, expectedClusters)

View File

@@ -1,14 +0,0 @@
import numpy as np
from skspatial.objects import Line
class PreferenceMatrixFactory:
    """Factory for the 0/1 matrix relating points to candidate lines."""

    @staticmethod
    def createPreferenceMatrix(points, lines, consensusThreshold):
        """Mark, for every point/line pair, whether the point is close to the line.

        Args:
            points: sequence of points handed unchanged to ``line.distance_point``.
            lines: sequence of objects exposing ``distance_point(point)``.
            consensusThreshold: inclusive upper bound on the distance for a
                point to count as belonging to a line.

        Returns:
            Integer array of shape ``(len(points), len(lines))`` holding 1
            where the distance is at most ``consensusThreshold`` and 0 elsewhere.
        """
        shape = (len(points), len(lines))
        matrix = np.zeros(shape, dtype=int)
        for row, point in enumerate(points):
            for column, line in enumerate(lines):
                # Cells start out as 0, so only the close pairs need writing.
                if line.distance_point(point) <= consensusThreshold:
                    matrix[row, column] = 1
        return matrix

View File

@@ -1,26 +0,0 @@
import unittest
import numpy as np
from numpy.testing import assert_array_equal
from skspatial.objects import Line
from SymptomsCausedByVaccines.MultiLineFitting.PreferenceMatrixFactory import PreferenceMatrixFactory
class PreferenceMatrixFactoryTest(unittest.TestCase):
    """Unit tests for PreferenceMatrixFactory."""

    def test_createPreferenceMatrix(self):
        """A point within the threshold of the line gets 1, a point beyond it gets 0."""
        # Given
        xAxis = Line.from_points([0, 0], [100, 0])
        nearPoint = (1, 3)
        farPoint = (10, 20)
        expectedMatrix = np.array([[1], [0]])

        # When
        actualMatrix = PreferenceMatrixFactory.createPreferenceMatrix(
            [nearPoint, farPoint], [xAxis], 4.0)

        # Then
        assert_array_equal(actualMatrix, expectedMatrix)