From ee524ef03653a69df05a582ddc7ebddd925163bc Mon Sep 17 00:00:00 2001 From: frankknoll Date: Fri, 17 Nov 2023 18:42:38 +0100 Subject: [PATCH] plotting fitted lines --- src/HowBadIsMyBatch.ipynb | 156 +++++--------------------------------- 1 file changed, 17 insertions(+), 139 deletions(-) diff --git a/src/HowBadIsMyBatch.ipynb b/src/HowBadIsMyBatch.ipynb index d550c87d619..30b8b89c495 100644 --- a/src/HowBadIsMyBatch.ipynb +++ b/src/HowBadIsMyBatch.ipynb @@ -705,15 +705,6 @@ "clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "clusters" - ] - }, { "cell_type": "code", "execution_count": null, @@ -724,141 +715,28 @@ "from skspatial.objects import Line\n", "\n", "_, ax = plt.subplots()\n", - "for line in lines:\n", - " line.plot_2d(ax, label = \"line\")\n", - "for cluster in clusters:\n", - " plt.scatter(\n", - " [x for (x, _) in cluster],\n", - " [y for (_, y) in cluster],\n", - " marker = \".\",\n", - " s = 100,\n", - " label = \"Cluster\")\n", - "# plt.scatter(\n", - "# [x for (x, _) in points],\n", - "# [y for (_, y) in points],\n", - "# color = \"blue\",\n", - "# marker = \".\",\n", - "# s = 100,\n", - "# label = \"Dots\")\n", - "plt.xlabel(symptomX)\n", - "plt.ylabel(symptomY)\n", - "plt.legend(loc=\"lower right\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Fit line using all data\n", - "lr = linear_model.LinearRegression()\n", - "lr.fit(X, y)\n", - "\n", - "# Robustly fit linear model with RANSAC algorithm\n", - "ransac = linear_model.RANSACRegressor(random_state = 0)\n", - "ransac.fit(X, y)\n", - "inlier_mask = ransac.inlier_mask_\n", - "outlier_mask = np.logical_not(inlier_mask)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "X2 = X[outlier_mask]\n", - "y2 = y[outlier_mask]\n", - "ransac2 = linear_model.RANSACRegressor(random_state = 0)\n", - "ransac2.fit(X2, y2)\n", - "inlier_mask2 = ransac2.inlier_mask_\n", - "outlier_mask2 = np.logical_not(inlier_mask2)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "X3 = X2[outlier_mask2]\n", - "y3 = y2[outlier_mask2]\n", - "ransac3 = linear_model.RANSACRegressor(random_state = 0)\n", - "ransac3.fit(X3, y3)\n", - "inlier_mask3 = ransac3.inlier_mask_\n", - "outlier_mask3 = np.logical_not(inlier_mask3)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "X_ransac_list = [(X, ransac), (X2, ransac2), (X3, ransac3)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from matplotlib.pyplot import figure\n", - "figure(figsize=(8, 6), dpi=80)\n", - "\n", - "def plotRANSACResult(X, y, X_ransac_list):\n", - " figure(figsize=(8, 6), dpi=80)\n", - " plt.scatter(X, y, color=\"yellowgreen\", marker=\".\", label=\"Dots\")\n", - " for (X, ransac) in X_ransac_list:\n", - " line_X = np.arange(X.min(), X.max())[:, np.newaxis]\n", - " line_y_ransac = ransac.predict(line_X)\n", - " plt.plot(\n", - " line_X,\n", - " line_y_ransac,\n", - " color = \"cornflowerblue\",\n", - " linewidth = 2,\n", - " label = \"RANSAC regressor\")\n", - " plt.legend(loc=\"lower right\")\n", - " plt.xlabel(symptomX)\n", - " plt.ylabel(symptomY)\n", - " plt.show()\n", - "\n", - "# Predict data of estimated models\n", - "line_X = np.arange(X.min(), X.max())[:, np.newaxis]\n", - "line_y = lr.predict(line_X)\n", - "line_y_ransac = ransac.predict(line_X)\n", - "\n", - "# Compare estimated coefficients\n", - "print(\"Estimated coefficients (true, linear regression, RANSAC):\")\n", - "print(lr.coef_, ransac.estimator_.coef_)\n", - "\n", - "lw = 2\n", - "plt.scatter(X, y, color=\"blue\", marker=\".\", label=\"Inliers\")\n", - "#plt.plot(line_X, line_y, color=\"navy\", linewidth=lw, label=\"Linear regressor\")\n", - "#plt.plot(\n", - "# line_X,\n", - "# line_y_ransac,\n", - "# color=\"cornflowerblue\",\n", - "# linewidth=lw,\n", - "# label=\"RANSAC regressor\")\n", - "plt.legend(loc=\"lower right\")\n", + "plt.scatter(\n", + " [x for (x, _) in points],\n", + " [y for (_, y) in points],\n", + " color = \"blue\",\n", + " marker = \".\",\n", + " s = 100,\n", + " label = \"Dots\")\n", + "for cluster, line in zip(clusters, lines):\n", + " if(len(cluster) > 2):\n", + " coords = line.transform_points(cluster)\n", + " magnitude = line.direction.norm()\n", + " line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n", + " plt.scatter(\n", + " [x for (x, _) in cluster],\n", + " [y for (_, y) in cluster],\n", + " marker = \".\",\n", + " s = 100)\n", "plt.xlabel(symptomX)\n", "plt.ylabel(symptomY)\n", "plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plotRANSACResult(X, y, X_ransac_list)" - ] - }, { "cell_type": "markdown", "metadata": {},