MultiLineFitterTest
This commit is contained in:
@@ -651,6 +651,214 @@
|
|||||||
" htmlFile = os.path.normpath(webAppBaseDir + '/index.html'))"
|
" htmlFile = os.path.normpath(webAppBaseDir + '/index.html'))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# https://scikit-learn.org/stable/auto_examples/linear_model/plot_ransac.html\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from matplotlib import pyplot as plt\n",
|
||||||
|
"from sklearn import linear_model\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"symptomX = 'Immunosuppression'\n",
|
||||||
|
"symptomY = 'Infection' # 'Immunoglobulin therapy'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"df = prrByLotAndSymptom[[symptomX, symptomY]]\n",
|
||||||
|
"df = df[(df[symptomX] != 0) & (df[symptomY] != 0)]\n",
|
||||||
|
"df"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"points = [(x, y) for [x, y] in df.values]\n",
|
||||||
|
"points"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
|
||||||
|
"\n",
|
||||||
|
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"clusters"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"from skspatial.objects import Line\n",
|
||||||
|
"\n",
|
||||||
|
"_, ax = plt.subplots()\n",
|
||||||
|
"for line in lines:\n",
|
||||||
|
" line.plot_2d(ax, label = \"line\")\n",
|
||||||
|
"for cluster in clusters:\n",
|
||||||
|
" plt.scatter(\n",
|
||||||
|
" [x for (x, _) in cluster],\n",
|
||||||
|
" [y for (_, y) in cluster],\n",
|
||||||
|
" marker = \".\",\n",
|
||||||
|
" s = 100,\n",
|
||||||
|
" label = \"Cluster\")\n",
|
||||||
|
"# plt.scatter(\n",
|
||||||
|
"# [x for (x, _) in points],\n",
|
||||||
|
"# [y for (_, y) in points],\n",
|
||||||
|
"# color = \"blue\",\n",
|
||||||
|
"# marker = \".\",\n",
|
||||||
|
"# s = 100,\n",
|
||||||
|
"# label = \"Dots\")\n",
|
||||||
|
"plt.xlabel(symptomX)\n",
|
||||||
|
"plt.ylabel(symptomY)\n",
|
||||||
|
"plt.legend(loc=\"lower right\")\n",
|
||||||
|
"plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Fit line using all data\n",
|
||||||
|
"lr = linear_model.LinearRegression()\n",
|
||||||
|
"lr.fit(X, y)\n",
|
||||||
|
"\n",
|
||||||
|
"# Robustly fit linear model with RANSAC algorithm\n",
|
||||||
|
"ransac = linear_model.RANSACRegressor(random_state = 0)\n",
|
||||||
|
"ransac.fit(X, y)\n",
|
||||||
|
"inlier_mask = ransac.inlier_mask_\n",
|
||||||
|
"outlier_mask = np.logical_not(inlier_mask)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"X2 = X[outlier_mask]\n",
|
||||||
|
"y2 = y[outlier_mask]\n",
|
||||||
|
"ransac2 = linear_model.RANSACRegressor(random_state = 0)\n",
|
||||||
|
"ransac2.fit(X2, y2)\n",
|
||||||
|
"inlier_mask2 = ransac2.inlier_mask_\n",
|
||||||
|
"outlier_mask2 = np.logical_not(inlier_mask2)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"X3 = X2[outlier_mask2]\n",
|
||||||
|
"y3 = y2[outlier_mask2]\n",
|
||||||
|
"ransac3 = linear_model.RANSACRegressor(random_state = 0)\n",
|
||||||
|
"ransac3.fit(X3, y3)\n",
|
||||||
|
"inlier_mask3 = ransac3.inlier_mask_\n",
|
||||||
|
"outlier_mask3 = np.logical_not(inlier_mask3)\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"X_ransac_list = [(X, ransac), (X2, ransac2), (X3, ransac3)]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from matplotlib.pyplot import figure\n",
|
||||||
|
"figure(figsize=(8, 6), dpi=80)\n",
|
||||||
|
"\n",
|
||||||
|
"def plotRANSACResult(X, y, X_ransac_list):\n",
|
||||||
|
" figure(figsize=(8, 6), dpi=80)\n",
|
||||||
|
" plt.scatter(X, y, color=\"yellowgreen\", marker=\".\", label=\"Dots\")\n",
|
||||||
|
" for (X, ransac) in X_ransac_list:\n",
|
||||||
|
" line_X = np.arange(X.min(), X.max())[:, np.newaxis]\n",
|
||||||
|
" line_y_ransac = ransac.predict(line_X)\n",
|
||||||
|
" plt.plot(\n",
|
||||||
|
" line_X,\n",
|
||||||
|
" line_y_ransac,\n",
|
||||||
|
" color = \"cornflowerblue\",\n",
|
||||||
|
" linewidth = 2,\n",
|
||||||
|
" label = \"RANSAC regressor\")\n",
|
||||||
|
" plt.legend(loc=\"lower right\")\n",
|
||||||
|
" plt.xlabel(symptomX)\n",
|
||||||
|
" plt.ylabel(symptomY)\n",
|
||||||
|
" plt.show()\n",
|
||||||
|
"\n",
|
||||||
|
"# Predict data of estimated models\n",
|
||||||
|
"line_X = np.arange(X.min(), X.max())[:, np.newaxis]\n",
|
||||||
|
"line_y = lr.predict(line_X)\n",
|
||||||
|
"line_y_ransac = ransac.predict(line_X)\n",
|
||||||
|
"\n",
|
||||||
|
"# Compare estimated coefficients\n",
|
||||||
|
"print(\"Estimated coefficients (true, linear regression, RANSAC):\")\n",
|
||||||
|
"print(lr.coef_, ransac.estimator_.coef_)\n",
|
||||||
|
"\n",
|
||||||
|
"lw = 2\n",
|
||||||
|
"plt.scatter(X, y, color=\"blue\", marker=\".\", label=\"Inliers\")\n",
|
||||||
|
"#plt.plot(line_X, line_y, color=\"navy\", linewidth=lw, label=\"Linear regressor\")\n",
|
||||||
|
"#plt.plot(\n",
|
||||||
|
"# line_X,\n",
|
||||||
|
"# line_y_ransac,\n",
|
||||||
|
"# color=\"cornflowerblue\",\n",
|
||||||
|
"# linewidth=lw,\n",
|
||||||
|
"# label=\"RANSAC regressor\")\n",
|
||||||
|
"plt.legend(loc=\"lower right\")\n",
|
||||||
|
"plt.xlabel(symptomX)\n",
|
||||||
|
"plt.ylabel(symptomY)\n",
|
||||||
|
"plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"plotRANSACResult(X, y, X_ransac_list)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -659,9 +867,9 @@
|
|||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "howbadismybatch-venv",
|
"display_name": "howbadismybatch-venv-kernel",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "howbadismybatch-venv-kernel"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
|
|||||||
@@ -12,9 +12,10 @@ class MultiLineFitter:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def fitLines(points, lines, consensusThreshold):
|
def fitLines(points, lines, consensusThreshold):
|
||||||
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
|
||||||
_, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
|
clusters, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
|
||||||
lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix4Clusters)
|
return (
|
||||||
return np.array(lines)[lineIndexes]
|
MultiLineFitter._getClusterPoints(points, clusters),
|
||||||
|
MultiLineFitter._getLines(lines, preferenceMatrix4Clusters))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
def _createPreferenceMatrix(points, lines, consensusThreshold):
|
||||||
@@ -56,6 +57,15 @@ class MultiLineFitter:
|
|||||||
union = np.count_nonzero(np.logical_or(setA, setB))
|
union = np.count_nonzero(np.logical_or(setA, setB))
|
||||||
return 1. * intersection / union
|
return 1. * intersection / union
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _getLines(lines, preferenceMatrix):
|
||||||
|
return np.array(lines)[MultiLineFitter._getLineIndexes(preferenceMatrix)]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _getLineIndexes(preferenceMatrix):
|
def _getLineIndexes(preferenceMatrix):
|
||||||
return [list(lines).index(1) for lines in preferenceMatrix]
|
return [list(lines).index(1) for lines in preferenceMatrix]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _getClusterPoints(points, clusters):
|
||||||
|
sortedClusters = [sorted(cluster) for cluster in clusters]
|
||||||
|
return [list(np.array(points)[cluster]) for cluster in sortedClusters]
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ class MultiLineFitterTest(unittest.TestCase):
|
|||||||
])
|
])
|
||||||
|
|
||||||
# When
|
# When
|
||||||
clusters, _ = MultiLineFitter._createClusters(preferenceMatrix)
|
clusters, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(
|
np.testing.assert_array_equal(
|
||||||
@@ -87,6 +87,13 @@ class MultiLineFitterTest(unittest.TestCase):
|
|||||||
[2, 1, 0],
|
[2, 1, 0],
|
||||||
[4, 3]
|
[4, 3]
|
||||||
]))
|
]))
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
preferenceMatrix4Clusters,
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
[1, 0],
|
||||||
|
[0, 1]
|
||||||
|
]))
|
||||||
|
|
||||||
def test_getLineIndexes(self):
|
def test_getLineIndexes(self):
|
||||||
# Given
|
# Given
|
||||||
@@ -110,19 +117,36 @@ class MultiLineFitterTest(unittest.TestCase):
|
|||||||
line3 = Line.from_points([0, 0], [0, 1])
|
line3 = Line.from_points([0, 0], [0, 1])
|
||||||
|
|
||||||
# When
|
# When
|
||||||
fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2, line3], consensusThreshold = 0.001)
|
clusters, fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2, line3], consensusThreshold = 0.001)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
np.testing.assert_array_equal(fittedLines, [line1, line2])
|
np.testing.assert_array_equal(
|
||||||
|
fittedLines,
|
||||||
|
[
|
||||||
|
line1,
|
||||||
|
line2
|
||||||
|
])
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
clusters,
|
||||||
|
[
|
||||||
|
[(1, 0), (2, 0), (3, 0)],
|
||||||
|
[(1, 1), (2, 2), (3, 3)]
|
||||||
|
])
|
||||||
|
|
||||||
def test_fitPointsByLines(self):
|
def test_fitPointsByLines(self):
|
||||||
# Given
|
# Given
|
||||||
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
|
points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
|
||||||
|
|
||||||
# When
|
# When
|
||||||
lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)
|
clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)
|
||||||
|
|
||||||
# Then
|
# Then
|
||||||
self.assertEqual(len(lines), 2)
|
self.assertEqual(len(lines), 2)
|
||||||
self.assertTrue(lines[0].is_close(Line.from_points([0, 0], [1, 0])))
|
self.assertTrue(lines[0].is_close(Line.from_points([0, 0], [1, 0])))
|
||||||
self.assertTrue(lines[1].is_close(Line.from_points([0, 0], [1, 1])))
|
self.assertTrue(lines[1].is_close(Line.from_points([0, 0], [1, 1])))
|
||||||
|
np.testing.assert_array_equal(
|
||||||
|
clusters,
|
||||||
|
[
|
||||||
|
[(1, 0), (2, 0), (3, 0)],
|
||||||
|
[(1, 1), (2, 2), (3, 3)]
|
||||||
|
])
|
||||||
|
|||||||
Reference in New Issue
Block a user