refining LinesFactoryTest
This commit is contained in:
@@ -666,9 +666,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
|
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
|
||||||
"from SymptomsCausedByVaccines.MultiLineFitting.SymptomCombinationsProvider import SymptomCombinationsProvider\n",
|
"from SymptomsCausedByVaccines.MultiLineFitting.SymptomCombinationsProvider import SymptomCombinationsProvider\n",
|
||||||
"import numpy as np\n",
|
"from matplotlib import pyplot as plt\n"
|
||||||
"from matplotlib import pyplot as plt\n",
|
|
||||||
"from skspatial.objects import Line\n"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -677,8 +675,8 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"symptomX = 'Abdominal abscess' # HIV test' # 'Immunosuppression'\n",
|
"# symptomX = 'Abdominal discomfort' # HIV test' # 'Immunosuppression'\n",
|
||||||
"symptomY = 'Abdominal discomfort' # 'Infection' # 'Immunoglobulin therapy'"
|
"# symptomY = 'Abdominal distension' # 'Infection' # 'Immunoglobulin therapy'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -687,9 +685,20 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"df = prrByLotAndSymptom[[symptomX, symptomY]]\n",
|
"# df = prrByLotAndSymptom[[symptomX, symptomY]]\n",
|
||||||
"df = df[(df[symptomX] != 0) & (df[symptomY] != 0)]\n",
|
"# df = df[(df[symptomX] != 0) & (df[symptomY] != 0)]\n",
|
||||||
"df"
|
"# df"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# retain only those columns of prrByLotAndSymptom that have at least 400 non-zero PRRs\n",
|
||||||
|
"# prrByLotAndSymptom2 = prrByLotAndSymptom.loc[:, (prrByLotAndSymptom != 0).sum() >= 400]\n",
|
||||||
|
"# prrByLotAndSymptom2"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -699,8 +708,8 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"symptomCombinations = SymptomCombinationsProvider.generateSymptomCombinations(\n",
|
"symptomCombinations = SymptomCombinationsProvider.generateSymptomCombinations(\n",
|
||||||
" prrByLotAndSymptom[prrByLotAndSymptom.columns[:500]],\n",
|
" prrByLotAndSymptom,\n",
|
||||||
" dataFramePredicate = lambda df: 30 <= len(df) <= 35)"
|
" dataFramePredicate = lambda df: 40 <= len(df) <= 50)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -709,8 +718,19 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"for symptomCombination in symptomCombinations:\n",
|
"from SymptomsCausedByVaccines.MultiLineFitting.Utils import take\n",
|
||||||
" print(list(symptomCombination.columns))"
|
"\n",
|
||||||
|
"df = take(symptomCombinations, 1)[0]\n",
|
||||||
|
"df"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"symptomX, symptomY = df.columns"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -758,7 +778,10 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"clustersAscending, linesAscending = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)"
|
"clustersAscending, linesAscending = MultiLineFitter.fitPointsByAscendingLines(\n",
|
||||||
|
" points,\n",
|
||||||
|
" consensusThreshold = 0.01,\n",
|
||||||
|
" maxNumLines = None)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -767,7 +790,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 2)"
|
"draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 5)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -776,7 +799,10 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.1)"
|
"clusters, lines = MultiLineFitter.fitPointsByLines(\n",
|
||||||
|
" points,\n",
|
||||||
|
" consensusThreshold = 0.01,\n",
|
||||||
|
" maxNumLines = None)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1,16 +1,22 @@
|
|||||||
from skspatial.objects import Line
|
from skspatial.objects import Line
|
||||||
from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs
|
from SymptomsCausedByVaccines.MultiLineFitting.Utils import generatePairs, take
|
||||||
|
|
||||||
|
|
||||||
class LinesFactory:
|
class LinesFactory:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def createLines(points):
|
def createLines(points, maxNumLines = None):
|
||||||
return LinesFactory._getUniqueLines(list(LinesFactory._generateAllLines(points)))
|
return LinesFactory._getUniqueLines(
|
||||||
|
take(
|
||||||
|
LinesFactory._generateAllLines(points),
|
||||||
|
maxNumLines))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def createAscendingLines(points):
|
def createAscendingLines(points, maxNumLines = None):
|
||||||
return LinesFactory._getUniqueLines(list(LinesFactory._generateAllAscendingLines(points)))
|
return LinesFactory._getUniqueLines(
|
||||||
|
take(
|
||||||
|
LinesFactory._generateAllAscendingLines(points),
|
||||||
|
maxNumLines))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _generateAllAscendingLines(points):
|
def _generateAllAscendingLines(points):
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ class LinesFactoryTest(unittest.TestCase):
|
|||||||
self.assertEqual(len(lines), 1)
|
self.assertEqual(len(lines), 1)
|
||||||
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
|
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
|
||||||
|
|
||||||
|
|
||||||
def test_createLines2(self):
|
def test_createLines2(self):
|
||||||
# Given
|
# Given
|
||||||
points = [(0, 0), (1, 0), (0, 1)]
|
points = [(0, 0), (1, 0), (0, 1)]
|
||||||
@@ -29,6 +30,20 @@ class LinesFactoryTest(unittest.TestCase):
|
|||||||
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
|
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
|
||||||
self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1])))
|
self.assertTrue(lines[2].is_close(Line(point = [0, 1], direction = [1, -1])))
|
||||||
|
|
||||||
|
|
||||||
|
def test_createLines_maxNumLines(self):
|
||||||
|
# Given
|
||||||
|
points = [(0, 0), (1, 0), (0, 1)]
|
||||||
|
|
||||||
|
# When
|
||||||
|
lines = LinesFactory.createLines(points, maxNumLines = 2)
|
||||||
|
|
||||||
|
# Then
|
||||||
|
self.assertEqual(len(lines), 2)
|
||||||
|
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
|
||||||
|
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
|
||||||
|
|
||||||
|
|
||||||
def test_createAscendingLines(self):
|
def test_createAscendingLines(self):
|
||||||
# Given
|
# Given
|
||||||
points = [(0, 0), (1, 0), (0, 1)]
|
points = [(0, 0), (1, 0), (0, 1)]
|
||||||
@@ -40,3 +55,14 @@ class LinesFactoryTest(unittest.TestCase):
|
|||||||
self.assertEqual(len(lines), 2)
|
self.assertEqual(len(lines), 2)
|
||||||
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
|
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
|
||||||
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
|
self.assertTrue(lines[1].is_close(Line(point = [0, 0], direction = [0, 1])))
|
||||||
|
|
||||||
|
def test_createAscendingLines_maxNumLines(self):
|
||||||
|
# Given
|
||||||
|
points = [(0, 0), (1, 0), (0, 1)]
|
||||||
|
|
||||||
|
# When
|
||||||
|
lines = LinesFactory.createAscendingLines(points, maxNumLines = 1)
|
||||||
|
|
||||||
|
# Then
|
||||||
|
self.assertEqual(len(lines), 1)
|
||||||
|
self.assertTrue(lines[0].is_close(Line(point = [0, 0], direction = [1, 0])))
|
||||||
|
|||||||
@@ -7,12 +7,18 @@ from SymptomsCausedByVaccines.MultiLineFitting.CharacteristicFunctions import Ch
|
|||||||
class MultiLineFitter:
|
class MultiLineFitter:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fitPointsByLines(points, consensusThreshold):
|
def fitPointsByLines(points, consensusThreshold, maxNumLines = None):
|
||||||
return MultiLineFitter.fitLines(points, LinesFactory.createLines(points), consensusThreshold)
|
return MultiLineFitter.fitLines(
|
||||||
|
points,
|
||||||
|
LinesFactory.createLines(points, maxNumLines),
|
||||||
|
consensusThreshold)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fitPointsByAscendingLines(points, consensusThreshold):
|
def fitPointsByAscendingLines(points, consensusThreshold, maxNumLines = None):
|
||||||
return MultiLineFitter.fitLines(points, LinesFactory.createAscendingLines(points), consensusThreshold)
|
return MultiLineFitter.fitLines(
|
||||||
|
points,
|
||||||
|
LinesFactory.createAscendingLines(points, maxNumLines),
|
||||||
|
consensusThreshold)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def fitLines(points, lines, consensusThreshold):
|
def fitLines(points, lines, consensusThreshold):
|
||||||
@@ -72,7 +78,7 @@ class MultiLineFitter:
|
|||||||
def _intersectionOverUnion(setA, setB):
|
def _intersectionOverUnion(setA, setB):
|
||||||
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
intersection = np.count_nonzero(np.logical_and(setA, setB))
|
||||||
union = np.count_nonzero(np.logical_or(setA, setB))
|
union = np.count_nonzero(np.logical_or(setA, setB))
|
||||||
return 1. * intersection / union
|
return 1. * intersection / union if intersection > 0.0 else 0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _getLines(lines, preferenceMatrix):
|
def _getLines(lines, preferenceMatrix):
|
||||||
@@ -80,7 +86,15 @@ class MultiLineFitter:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _getLineIndexes(preferenceMatrix):
|
def _getLineIndexes(preferenceMatrix):
|
||||||
return [list(lines).index(1) for lines in preferenceMatrix]
|
lineIndexes = (MultiLineFitter._index(lines, 1) for lines in preferenceMatrix)
|
||||||
|
return [lineIndex for lineIndex in lineIndexes if lineIndex is not None]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _index(xs, x):
|
||||||
|
try:
|
||||||
|
return list(xs).index(x)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _getClusterPoints(points, clusters):
|
def _getClusterPoints(points, clusters):
|
||||||
|
|||||||
@@ -1,4 +1,9 @@
|
|||||||
|
import itertools
|
||||||
|
|
||||||
def generatePairs(n):
|
def generatePairs(n):
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
for j in range(i):
|
for j in range(i):
|
||||||
yield (i, j)
|
yield (i, j)
|
||||||
|
|
||||||
|
def take(iterable, numElements):
|
||||||
|
return list(itertools.islice(iterable, numElements)) if numElements is not None else list(iterable)
|
||||||
|
|||||||
Reference in New Issue
Block a user