refactoring

This commit is contained in:
frankknoll
2023-11-18 11:27:45 +01:00
parent d40116ba6f
commit b773e34f3d

View File

@@ -664,10 +664,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# https://scikit-learn.org/stable/auto_examples/linear_model/plot_ransac.html\n", "from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
"import numpy as np\n", "import numpy as np\n",
"from matplotlib import pyplot as plt\n", "from matplotlib import pyplot as plt\n",
"from sklearn import linear_model\n" "from skspatial.objects import Line\n"
] ]
}, },
{ {
@@ -707,9 +707,27 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n", "def draw(points, clusters, lines, symptomX, symptomY, minClusterSize):\n",
" _, ax = plt.subplots()\n",
" plt.scatter(_getXs(points), _getYs(points), color = \"blue\", marker = \".\", s = 100)\n",
" for cluster, line in zip(clusters, lines):\n",
" if len(cluster) >= minClusterSize:\n",
" _drawLine(line, cluster, ax)\n",
" plt.scatter(_getXs(cluster), _getYs(cluster), marker = \".\", s = 100)\n",
" plt.xlabel(symptomX)\n",
" plt.ylabel(symptomY)\n",
" plt.show()\n",
"\n", "\n",
"clusters, lines = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)" "def _drawLine(line, cluster, ax):\n",
" coords = line.transform_points(cluster)\n",
" magnitude = line.direction.norm()\n",
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",
"\n",
"def _getXs(xys):\n",
" return [x for (x, _) in xys]\n",
"\n",
"def _getYs(xys):\n",
" return [y for (_, y) in xys]"
] ]
}, },
{ {
@@ -718,30 +736,34 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import matplotlib.pyplot as plt\n", "clustersAscending, linesAscending = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)"
"from skspatial.objects import Line\n", ]
"\n", },
"_, ax = plt.subplots()\n", {
"plt.scatter(\n", "cell_type": "code",
" [x for (x, _) in points],\n", "execution_count": null,
" [y for (_, y) in points],\n", "metadata": {},
" color = \"blue\",\n", "outputs": [],
" marker = \".\",\n", "source": [
" s = 100,\n", "draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 3)"
" label = \"Dots\")\n", ]
"for cluster, line in zip(clusters, lines):\n", },
" if len(cluster) >= 3:\n", {
" coords = line.transform_points(cluster)\n", "cell_type": "code",
" magnitude = line.direction.norm()\n", "execution_count": null,
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n", "metadata": {},
" plt.scatter(\n", "outputs": [],
" [x for (x, _) in cluster],\n", "source": [
" [y for (_, y) in cluster],\n", "clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)"
" marker = \".\",\n", ]
" s = 100)\n", },
"plt.xlabel(symptomX)\n", {
"plt.ylabel(symptomY)\n", "cell_type": "code",
"plt.show()" "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"draw(points, clusters, lines, symptomX, symptomY, minClusterSize = 3)"
] ]
}, },
{ {