diff --git a/src/HowBadIsMyBatch.ipynb b/src/HowBadIsMyBatch.ipynb index 046e6bdff68..d550c87d619 100644 --- a/src/HowBadIsMyBatch.ipynb +++ b/src/HowBadIsMyBatch.ipynb @@ -651,6 +651,214 @@ " htmlFile = os.path.normpath(webAppBaseDir + '/index.html'))" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# https://scikit-learn.org/stable/auto_examples/linear_model/plot_ransac.html\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "from sklearn import linear_model\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "symptomX = 'Immunosuppression'\n", + "symptomY = 'Infection' # 'Immunoglobulin therapy'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = prrByLotAndSymptom[[symptomX, symptomY]]\n", + "df = df[(df[symptomX] != 0) & (df[symptomY] != 0)]\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "points = [(x, y) for [x, y] in df.values]\n", + "points" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n", + "\n", + "clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "clusters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "from skspatial.objects import Line\n", + "\n", + "_, ax = plt.subplots()\n", + "for line in lines:\n", + " line.plot_2d(ax, label = \"line\")\n", + "for cluster in clusters:\n", + " plt.scatter(\n", + " [x for (x, _) in cluster],\n", + " [y for (_, y) in 
cluster],\n", + " marker = \".\",\n", + " s = 100,\n", + " label = \"Cluster\")\n", + "# plt.scatter(\n", + "# [x for (x, _) in points],\n", + "# [y for (_, y) in points],\n", + "# color = \"blue\",\n", + "# marker = \".\",\n", + "# s = 100,\n", + "# label = \"Dots\")\n", + "plt.xlabel(symptomX)\n", + "plt.ylabel(symptomY)\n", + "plt.legend(loc=\"lower right\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Fit line using all data\n", + "lr = linear_model.LinearRegression()\n", + "lr.fit(X, y)\n", + "\n", + "# Robustly fit linear model with RANSAC algorithm\n", + "ransac = linear_model.RANSACRegressor(random_state = 0)\n", + "ransac.fit(X, y)\n", + "inlier_mask = ransac.inlier_mask_\n", + "outlier_mask = np.logical_not(inlier_mask)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X2 = X[outlier_mask]\n", + "y2 = y[outlier_mask]\n", + "ransac2 = linear_model.RANSACRegressor(random_state = 0)\n", + "ransac2.fit(X2, y2)\n", + "inlier_mask2 = ransac2.inlier_mask_\n", + "outlier_mask2 = np.logical_not(inlier_mask2)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X3 = X2[outlier_mask2]\n", + "y3 = y2[outlier_mask2]\n", + "ransac3 = linear_model.RANSACRegressor(random_state = 0)\n", + "ransac3.fit(X3, y3)\n", + "inlier_mask3 = ransac3.inlier_mask_\n", + "outlier_mask3 = np.logical_not(inlier_mask3)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X_ransac_list = [(X, ransac), (X2, ransac2), (X3, ransac3)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from matplotlib.pyplot import figure\n", + "figure(figsize=(8, 6), dpi=80)\n", + "\n", + "def plotRANSACResult(X, y, X_ransac_list):\n", + " figure(figsize=(8, 6), 
dpi=80)\n", + " plt.scatter(X, y, color=\"yellowgreen\", marker=\".\", label=\"Dots\")\n", + " for (X, ransac) in X_ransac_list:\n", + " line_X = np.arange(X.min(), X.max())[:, np.newaxis]\n", + " line_y_ransac = ransac.predict(line_X)\n", + " plt.plot(\n", + " line_X,\n", + " line_y_ransac,\n", + " color = \"cornflowerblue\",\n", + " linewidth = 2,\n", + " label = \"RANSAC regressor\")\n", + " plt.legend(loc=\"lower right\")\n", + " plt.xlabel(symptomX)\n", + " plt.ylabel(symptomY)\n", + " plt.show()\n", + "\n", + "# Predict data of estimated models\n", + "line_X = np.arange(X.min(), X.max())[:, np.newaxis]\n", + "line_y = lr.predict(line_X)\n", + "line_y_ransac = ransac.predict(line_X)\n", + "\n", + "# Compare estimated coefficients\n", + "print(\"Estimated coefficients (true, linear regression, RANSAC):\")\n", + "print(lr.coef_, ransac.estimator_.coef_)\n", + "\n", + "lw = 2\n", + "plt.scatter(X, y, color=\"blue\", marker=\".\", label=\"Inliers\")\n", + "#plt.plot(line_X, line_y, color=\"navy\", linewidth=lw, label=\"Linear regressor\")\n", + "#plt.plot(\n", + "# line_X,\n", + "# line_y_ransac,\n", + "# color=\"cornflowerblue\",\n", + "# linewidth=lw,\n", + "# label=\"RANSAC regressor\")\n", + "plt.legend(loc=\"lower right\")\n", + "plt.xlabel(symptomX)\n", + "plt.ylabel(symptomY)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotRANSACResult(X, y, X_ransac_list)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -659,9 +867,9 @@ ], "metadata": { "kernelspec": { - "display_name": "howbadismybatch-venv", + "display_name": "howbadismybatch-venv-kernel", "language": "python", - "name": "python3" + "name": "howbadismybatch-venv-kernel" }, "language_info": { "codemirror_mode": { diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py index a370fc4672b..5e76b668030 100644 --- 
a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py
+++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitter.py
@@ -12,9 +12,16 @@ class MultiLineFitter:
     @staticmethod
     def fitLines(points, lines, consensusThreshold):
+        """Return the clustered points together with the selected lines.
+
+        The result is a tuple (clusters, lines): the points of each cluster
+        (see _getClusterPoints) and the entries of `lines` picked from the
+        clusters' preference matrix (see _getLines).
+        """
         preferenceMatrix = MultiLineFitter._createPreferenceMatrix(points, lines, consensusThreshold)
-        _, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
-        lineIndexes = MultiLineFitter._getLineIndexes(preferenceMatrix4Clusters)
-        return np.array(lines)[lineIndexes]
+        clusters, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
+        return (
+            MultiLineFitter._getClusterPoints(points, clusters),
+            MultiLineFitter._getLines(lines, preferenceMatrix4Clusters))
 
     @staticmethod
     def _createPreferenceMatrix(points, lines, consensusThreshold):
@@ -56,6 +63,18 @@ class MultiLineFitter:
         union = np.count_nonzero(np.logical_or(setA, setB))
         return 1. * intersection / union
 
+    @staticmethod
+    def _getLines(lines, preferenceMatrix):
+        """Return, as a numpy array, the entries of `lines` at the indexes given by _getLineIndexes."""
+        return np.array(lines)[MultiLineFitter._getLineIndexes(preferenceMatrix)]
+
     @staticmethod
     def _getLineIndexes(preferenceMatrix):
         return [list(lines).index(1) for lines in preferenceMatrix]
+
+    @staticmethod
+    def _getClusterPoints(points, clusters):
+        """Return one list of points per cluster of point indexes, each in ascending index order."""
+        # Convert once instead of once per cluster.
+        pointsArray = np.array(points)
+        return [list(pointsArray[sorted(cluster)]) for cluster in clusters]
diff --git a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py
index 081a920a3b0..15dda8d0a98 100644
--- a/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py
+++ b/src/SymptomsCausedByVaccines/MultiLineFitting/MultiLineFitterTest.py
@@ -77,7 +77,7 @@ class MultiLineFitterTest(unittest.TestCase):
             ])
 
         # When
-        clusters, _ = MultiLineFitter._createClusters(preferenceMatrix)
+        clusters, preferenceMatrix4Clusters = MultiLineFitter._createClusters(preferenceMatrix)
 
         # Then
         np.testing.assert_array_equal(
@@ -87,6 +87,13 @@ class MultiLineFitterTest(unittest.TestCase):
                     [2, 1, 0],
                     [4, 3]
                 ]))
+        np.testing.assert_array_equal(
+            preferenceMatrix4Clusters,
+            np.array(
+                [
+                    [1, 0],
+                    [0, 1]
+                ]))
 
     def test_getLineIndexes(self):
         # Given
@@ -110,19 +117,36 @@ class MultiLineFitterTest(unittest.TestCase):
         line3 = Line.from_points([0, 0], [0, 1])
 
         # When
-        fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2, line3], consensusThreshold = 0.001)
+        clusters, fittedLines = MultiLineFitter.fitLines(points, lines = [line1, line2, line3], consensusThreshold = 0.001)
 
         # Then
-        np.testing.assert_array_equal(fittedLines, [line1, line2])
+        np.testing.assert_array_equal(
+            fittedLines,
+            [
+                line1,
+                line2
+            ])
+        np.testing.assert_array_equal(
+            clusters,
+            [
+                [(1, 0), (2, 0), (3, 0)],
+                [(1, 1), (2, 2), (3, 3)]
+            ])
 
     def test_fitPointsByLines(self):
         # Given
         points = [(1, 0), (2, 0), (3, 0), (1, 1), (2, 2), (3, 3)]
 
         # When
-        lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)
+        clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)
 
         # Then
         self.assertEqual(len(lines), 2)
         self.assertTrue(lines[0].is_close(Line.from_points([0, 0], [1, 0])))
         self.assertTrue(lines[1].is_close(Line.from_points([0, 0], [1, 1])))
+        np.testing.assert_array_equal(
+            clusters,
+            [
+                [(1, 0), (2, 0), (3, 0)],
+                [(1, 1), (2, 2), (3, 3)]
+            ])