plotting fitted lines
This commit is contained in:
@@ -705,15 +705,6 @@
|
||||
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"clusters"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -724,141 +715,28 @@
|
||||
"from skspatial.objects import Line\n",
|
||||
"\n",
|
||||
"_, ax = plt.subplots()\n",
|
||||
"for line in lines:\n",
|
||||
" line.plot_2d(ax, label = \"line\")\n",
|
||||
"for cluster in clusters:\n",
|
||||
" plt.scatter(\n",
|
||||
" [x for (x, _) in cluster],\n",
|
||||
" [y for (_, y) in cluster],\n",
|
||||
" marker = \".\",\n",
|
||||
" s = 100,\n",
|
||||
" label = \"Cluster\")\n",
|
||||
"# plt.scatter(\n",
|
||||
"# [x for (x, _) in points],\n",
|
||||
"# [y for (_, y) in points],\n",
|
||||
"# color = \"blue\",\n",
|
||||
"# marker = \".\",\n",
|
||||
"# s = 100,\n",
|
||||
"# label = \"Dots\")\n",
|
||||
"plt.xlabel(symptomX)\n",
|
||||
"plt.ylabel(symptomY)\n",
|
||||
"plt.legend(loc=\"lower right\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Fit line using all data\n",
|
||||
"lr = linear_model.LinearRegression()\n",
|
||||
"lr.fit(X, y)\n",
|
||||
"\n",
|
||||
"# Robustly fit linear model with RANSAC algorithm\n",
|
||||
"ransac = linear_model.RANSACRegressor(random_state = 0)\n",
|
||||
"ransac.fit(X, y)\n",
|
||||
"inlier_mask = ransac.inlier_mask_\n",
|
||||
"outlier_mask = np.logical_not(inlier_mask)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X2 = X[outlier_mask]\n",
|
||||
"y2 = y[outlier_mask]\n",
|
||||
"ransac2 = linear_model.RANSACRegressor(random_state = 0)\n",
|
||||
"ransac2.fit(X2, y2)\n",
|
||||
"inlier_mask2 = ransac2.inlier_mask_\n",
|
||||
"outlier_mask2 = np.logical_not(inlier_mask2)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X3 = X2[outlier_mask2]\n",
|
||||
"y3 = y2[outlier_mask2]\n",
|
||||
"ransac3 = linear_model.RANSACRegressor(random_state = 0)\n",
|
||||
"ransac3.fit(X3, y3)\n",
|
||||
"inlier_mask3 = ransac3.inlier_mask_\n",
|
||||
"outlier_mask3 = np.logical_not(inlier_mask3)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_ransac_list = [(X, ransac), (X2, ransac2), (X3, ransac3)]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from matplotlib.pyplot import figure\n",
|
||||
"figure(figsize=(8, 6), dpi=80)\n",
|
||||
"\n",
|
||||
"def plotRANSACResult(X, y, X_ransac_list):\n",
|
||||
" figure(figsize=(8, 6), dpi=80)\n",
|
||||
" plt.scatter(X, y, color=\"yellowgreen\", marker=\".\", label=\"Dots\")\n",
|
||||
" for (X, ransac) in X_ransac_list:\n",
|
||||
" line_X = np.arange(X.min(), X.max())[:, np.newaxis]\n",
|
||||
" line_y_ransac = ransac.predict(line_X)\n",
|
||||
" plt.plot(\n",
|
||||
" line_X,\n",
|
||||
" line_y_ransac,\n",
|
||||
" color = \"cornflowerblue\",\n",
|
||||
" linewidth = 2,\n",
|
||||
" label = \"RANSAC regressor\")\n",
|
||||
" plt.legend(loc=\"lower right\")\n",
|
||||
" plt.xlabel(symptomX)\n",
|
||||
" plt.ylabel(symptomY)\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"# Predict data of estimated models\n",
|
||||
"line_X = np.arange(X.min(), X.max())[:, np.newaxis]\n",
|
||||
"line_y = lr.predict(line_X)\n",
|
||||
"line_y_ransac = ransac.predict(line_X)\n",
|
||||
"\n",
|
||||
"# Compare estimated coefficients\n",
|
||||
"print(\"Estimated coefficients (true, linear regression, RANSAC):\")\n",
|
||||
"print(lr.coef_, ransac.estimator_.coef_)\n",
|
||||
"\n",
|
||||
"lw = 2\n",
|
||||
"plt.scatter(X, y, color=\"blue\", marker=\".\", label=\"Inliers\")\n",
|
||||
"#plt.plot(line_X, line_y, color=\"navy\", linewidth=lw, label=\"Linear regressor\")\n",
|
||||
"#plt.plot(\n",
|
||||
"# line_X,\n",
|
||||
"# line_y_ransac,\n",
|
||||
"# color=\"cornflowerblue\",\n",
|
||||
"# linewidth=lw,\n",
|
||||
"# label=\"RANSAC regressor\")\n",
|
||||
"plt.legend(loc=\"lower right\")\n",
|
||||
"plt.scatter(\n",
|
||||
" [x for (x, _) in points],\n",
|
||||
" [y for (_, y) in points],\n",
|
||||
" color = \"blue\",\n",
|
||||
" marker = \".\",\n",
|
||||
" s = 100,\n",
|
||||
" label = \"Dots\")\n",
|
||||
"for cluster, line in zip(clusters, lines):\n",
|
||||
" if(len(cluster) > 2):\n",
|
||||
" coords = line.transform_points(cluster)\n",
|
||||
" magnitude = line.direction.norm()\n",
|
||||
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",
|
||||
" plt.scatter(\n",
|
||||
" [x for (x, _) in cluster],\n",
|
||||
" [y for (_, y) in cluster],\n",
|
||||
" marker = \".\",\n",
|
||||
" s = 100)\n",
|
||||
"plt.xlabel(symptomX)\n",
|
||||
"plt.ylabel(symptomY)\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plotRANSACResult(X, y, X_ransac_list)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
||||
Reference in New Issue
Block a user