refining MultiLineFitterTest

This commit is contained in:
frankknoll
2023-11-17 10:28:17 +01:00
parent 1941337066
commit a14c1cd217
4 changed files with 24 additions and 11 deletions

View File

@@ -1,8 +1,14 @@
import numpy as np
from SymptomsCausedByVaccines.MultiLineFitting.LinesFactory import LinesFactory
from SymptomsCausedByVaccines.MultiLineFitting.Utils import getPairs
# implementation of "Robust Multiple Structures Estimation with J-linkage" adapted from https://github.com/fkluger/vp-linkage
class MultiLineFitter:
@staticmethod
def fitPointsByLines(points, consensusThreshold):
return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold)
@staticmethod
def fitLines(points, lines, consensusThreshold):
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
@@ -28,7 +34,7 @@ class MultiLineFitter:
bestClusterIndexCombination = None
keepClustering = False
numClusters = preferenceMatrix.shape[0]
for (clusterIndexA, clusterIndexB) in MultiLineFitter._getPairs(numClusters):
for (clusterIndexA, clusterIndexB) in getPairs(numClusters):
preferenceSetA = preferenceMatrix[clusterIndexA]
preferenceSetB = preferenceMatrix[clusterIndexB]
similarity = MultiLineFitter._intersectionOverUnion(preferenceSetA, preferenceSetB);
@@ -46,12 +52,6 @@ class MultiLineFitter:
return clusters, preferenceMatrix
@staticmethod
def _getPairs(n):
for i in range(n):
for j in range(i):
yield (i, j)
@staticmethod
def _intersectionOverUnion(setA, setB):
intersection = np.count_nonzero(np.logical_and(setA, setB))