refining LinesFactoryTest

This commit is contained in:
frankknoll
2023-11-19 16:39:29 +01:00
parent 7081d9014b
commit 95da66087d
5 changed files with 103 additions and 26 deletions

View File

@@ -666,9 +666,7 @@
"source": [ "source": [
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n", "from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
"from SymptomsCausedByVaccines.MultiLineFitting.SymptomCombinationsProvider import SymptomCombinationsProvider\n", "from SymptomsCausedByVaccines.MultiLineFitting.SymptomCombinationsProvider import SymptomCombinationsProvider\n",
"import numpy as np\n", "from matplotlib import pyplot as plt\n"
"from matplotlib import pyplot as plt\n",
"from skspatial.objects import Line\n"
] ]
}, },
{ {
@@ -677,8 +675,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"symptomX = 'Abdominal abscess' # HIV test' # 'Immunosuppression'\n", "# symptomX = 'Abdominal discomfort' # HIV test' # 'Immunosuppression'\n",
"symptomY = 'Abdominal discomfort' # 'Infection' # 'Immunoglobulin therapy'" "# symptomY = 'Abdominal distension' # 'Infection' # 'Immunoglobulin therapy'"
] ]
}, },
{ {
@@ -687,9 +685,20 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"df = prrByLotAndSymptom[[symptomX, symptomY]]\n", "# df = prrByLotAndSymptom[[symptomX, symptomY]]\n",
"df = df[(df[symptomX] != 0) & (df[symptomY] != 0)]\n", "# df = df[(df[symptomX] != 0) & (df[symptomY] != 0)]\n",
"df" "# df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# retain only those columns of prrByLotAndSymptom that have more than 400 PRRs != 0\n",
"# prrByLotAndSymptom2 = prrByLotAndSymptom.loc[:, (prrByLotAndSymptom != 0).sum() >= 400]\n",
"# prrByLotAndSymptom2"
] ]
}, },
{ {
@@ -699,8 +708,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"symptomCombinations = SymptomCombinationsProvider.generateSymptomCombinations(\n", "symptomCombinations = SymptomCombinationsProvider.generateSymptomCombinations(\n",
" prrByLotAndSymptom[prrByLotAndSymptom.columns[:500]],\n", " prrByLotAndSymptom,\n",
" dataFramePredicate = lambda df: 30 <= len(df) <= 35)" " dataFramePredicate = lambda df: 40 <= len(df) <= 50)"
] ]
}, },
{ {
@@ -709,8 +718,19 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"for symptomCombination in symptomCombinations:\n", "from SymptomsCausedByVaccines.MultiLineFitting.Utils import take\n",
" print(list(symptomCombination.columns))" "\n",
"df = take(symptomCombinations, 1)[0]\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"symptomX, symptomY = df.columns"
] ]
}, },
{ {
@@ -758,7 +778,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"clustersAscending, linesAscending = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)" "clustersAscending, linesAscending = MultiLineFitter.fitPointsByAscendingLines(\n",
" points,\n",
" consensusThreshold = 0.01,\n",
" maxNumLines = None)"
] ]
}, },
{ {
@@ -767,7 +790,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 2)" "draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 5)"
] ]
}, },
{ {
@@ -776,7 +799,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.1)" "clusters, lines = MultiLineFitter.fitPointsByLines(\n",
" points,\n",
" consensusThreshold = 0.01,\n",
" maxNumLines = None)"
] ]
}, },
{ {

View File

@@ -1,16 +1,22 @@
from skspatial.objects import Line from skspatial.objects import Line
from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs, take
class LinesFactory: class LinesFactory:
@staticmethod @staticmethod
def createLines(points): def createLines(points, maxNumLines = None):
return LinesFactory._getUniqueLines(list(LinesFactory._generateAllLines(points))) return LinesFactory._getUniqueLines(
take(
LinesFactory._generateAllLines(points),
maxNumLines))
@staticmethod @staticmethod
def createAscendingLines(points): def createAscendingLines(points, maxNumLines = None):
return LinesFactory._getUniqueLines(list(LinesFactory._generateAllAscendingLines(points))) return LinesFactory._getUniqueLines(
take(
LinesFactory._generateAllAscendingLines(points),
maxNumLines))
@staticmethod @staticmethod
def _generateAllAscendingLines(points): def _generateAllAscendingLines(points):

View File

@@ -16,6 +16,7 @@ class LinesFactoryTest(unittest.TestCase):
self.assertEqual(len(lines), 1) self.assertEqual(len(lines), 1)
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
def test_createLines2(self): def test_createLines2(self):
# Given # Given
points = [(0, 0), (1, 0), (0, 1)] points = [(0, 0), (1, 0), (0, 1)]
@@ -29,6 +30,20 @@ class LinesFactoryTest(unittest.TestCase):
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1]))) self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1]))) self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1])))
def test_createLines_maxNumLines(self):
    """createLines returns only the first maxNumLines lines."""
    # Given
    points = [(0, 0), (1, 0), (0, 1)]
    expectedLines = [
        Line(point = [0, 0], direction = [1, 0]),
        Line(point = [0, 0], direction = [0, 1]),
    ]

    # When
    lines = LinesFactory.createLines(points, maxNumLines = 2)

    # Then
    self.assertEqual(len(lines), 2)
    for line, expectedLine in zip(lines, expectedLines):
        self.assertTrue(line.is_close(expectedLine))
def test_createAscendingLines(self): def test_createAscendingLines(self):
# Given # Given
points = [(0, 0), (1, 0), (0, 1)] points = [(0, 0), (1, 0), (0, 1)]
@@ -40,3 +55,14 @@ class LinesFactoryTest(unittest.TestCase):
self.assertEqual(len(lines), 2) self.assertEqual(len(lines), 2)
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0]))) self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1]))) self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
def test_createAscendingLines_maxNumLines(self):
    """createAscendingLines returns only the first maxNumLines lines."""
    # Given
    points = [(0, 0), (1, 0), (0, 1)]
    expectedLine = Line(point = [0, 0], direction = [1, 0])

    # When
    lines = LinesFactory.createAscendingLines(points, maxNumLines = 1)

    # Then
    self.assertEqual(len(lines), 1)
    firstLine = lines[0]
    self.assertTrue(firstLine.is_close(expectedLine))

View File

@@ -7,12 +7,18 @@ from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import Ch
class MultiLineFitter: class MultiLineFitter:
@staticmethod @staticmethod
def fitPointsByLines(points, consensusThreshold): def fitPointsByLines(points, consensusThreshold, maxNumLines = None):
return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold) return MultiLineFitter.fitLines(
points,
LinesFactory.createLines(points, maxNumLines),
consensusThreshold)
@staticmethod @staticmethod
def fitPointsByAscendingLines(points, consensusThreshold): def fitPointsByAscendingLines(points, consensusThreshold, maxNumLines = None):
return MultiLineFitter.fitLines(points, LinesFactory.createAscendingLines(points), consensusThreshold) return MultiLineFitter.fitLines(
points,
LinesFactory.createAscendingLines(points, maxNumLines),
consensusThreshold)
@staticmethod @staticmethod
def fitLines(points, lines, consensusThreshold): def fitLines(points, lines, consensusThreshold):
@@ -72,7 +78,7 @@ class MultiLineFitter:
def _intersectionOverUnion(setA, setB): def _intersectionOverUnion(setA, setB):
intersection = np.count_nonzero(np.logical_and(setA, setB)) intersection = np.count_nonzero(np.logical_and(setA, setB))
union = np.count_nonzero(np.logical_or(setA, setB)) union = np.count_nonzero(np.logical_or(setA, setB))
return 1. * intersection / union return 1. * intersection / union if intersection > 0.0 else 0
@staticmethod @staticmethod
def _getLines(lines, preferenceMatrix): def _getLines(lines, preferenceMatrix):
@@ -80,7 +86,15 @@ class MultiLineFitter:
@staticmethod @staticmethod
def _getLineIndexes(preferenceMatrix): def _getLineIndexes(preferenceMatrix):
return [list(lines).index(1) for lines in preferenceMatrix] lineIndexes = (MultiLineFitter._index(lines, 1) for lines in preferenceMatrix)
return [lineIndex for lineIndex in lineIndexes if lineIndex is not None]
@staticmethod
def _index(xs, x):
try:
return list(xs).index(x)
except ValueError:
return None
@staticmethod @staticmethod
def _getClusterPoints(points, clusters): def _getClusterPoints(points, clusters):

View File

@@ -1,4 +1,9 @@
import itertools
def generatePairs(n): def generatePairs(n):
for i in range(n): for i in range(n):
for j in range(i): for j in range(i):
yield (i, j) yield (i, j)
def take(iterable, numElements):
    """Return the leading elements of iterable as a list.

    At most numElements elements are consumed and returned; when
    numElements is None the whole iterable is materialized.
    """
    if numElements is None:
        return list(iterable)
    firstElements = itertools.islice(iterable, numElements)
    return list(firstElements)