diff --git a/src/HowBadIsMyBatch.ipynb b/src/HowBadIsMyBatch.ipynb index c8e0233cad9..29e3fc3637f 100644 --- a/src/HowBadIsMyBatch.ipynb +++ b/src/HowBadIsMyBatch.ipynb @@ -664,10 +664,10 @@ "metadata": {}, "outputs": [], "source": [ - "# https://scikit-learn.org/stable/auto_examples/linear_model/plot_ransac.html\n", + "from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", - "from sklearn import linear_model\n" + "from skspatial.objects import Line\n" ] }, { @@ -707,9 +707,27 @@ "metadata": {}, "outputs": [], "source": [ - "from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n", + "def draw(points, clusters, lines, symptomX, symptomY, minClusterSize):\n", + " _, ax = plt.subplots()\n", + " plt.scatter(_getXs(points), _getYs(points), color = \"blue\", marker = \".\", s = 100)\n", + " for cluster, line in zip(clusters, lines):\n", + " if len(cluster) >= minClusterSize:\n", + " _drawLine(line, cluster, ax)\n", + " plt.scatter(_getXs(cluster), _getYs(cluster), marker = \".\", s = 100)\n", + " plt.xlabel(symptomX)\n", + " plt.ylabel(symptomY)\n", + " plt.show()\n", "\n", - "clusters, lines = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)" + "def _drawLine(line, cluster, ax):\n", + " coords = line.transform_points(cluster)\n", + " magnitude = line.direction.norm()\n", + " line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n", + "\n", + "def _getXs(xys):\n", + " return [x for (x, _) in xys]\n", + "\n", + "def _getYs(xys):\n", + " return [y for (_, y) in xys]" ] }, { @@ -718,30 +736,34 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", - "from skspatial.objects import Line\n", - "\n", - "_, ax = plt.subplots()\n", - "plt.scatter(\n", - " [x for (x, _) in points],\n", - " [y for (_, y) in points],\n", - " color = \"blue\",\n", - " marker = \".\",\n", - " s = 100,\n", - " label = \"Dots\")\n", - "for cluster, line in zip(clusters, lines):\n", - " if len(cluster) >= 3:\n", - " coords = line.transform_points(cluster)\n", - " magnitude = line.direction.norm()\n", - " line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n", - " plt.scatter(\n", - " [x for (x, _) in cluster],\n", - " [y for (_, y) in cluster],\n", - " marker = \".\",\n", - " s = 100)\n", - "plt.xlabel(symptomX)\n", - "plt.ylabel(symptomY)\n", - "plt.show()" + "clustersAscending, linesAscending = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "draw(points, clusters, lines, symptomX, symptomY, minClusterSize = 3)" ] }, {