refactoring

This commit is contained in:
frankknoll
2023-11-18 11:27:45 +01:00
parent d40116ba6f
commit b773e34f3d

View File

@@ -664,10 +664,10 @@
"metadata": {},
"outputs": [],
"source": [
"# https://scikit-learn.org/stable/auto_examples/linear_model/plot_ransac.html\n",
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"from sklearn import linear_model\n"
"from skspatial.objects import Line\n"
]
},
{
@@ -707,9 +707,27 @@
"metadata": {},
"outputs": [],
"source": [
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
"def draw(points, clusters, lines, symptomX, symptomY, minClusterSize):\n",
" _, ax = plt.subplots()\n",
" plt.scatter(_getXs(points), _getYs(points), color = \"blue\", marker = \".\", s = 100)\n",
" for cluster, line in zip(clusters, lines):\n",
" if len(cluster) >= minClusterSize:\n",
" _drawLine(line, cluster, ax)\n",
" plt.scatter(_getXs(cluster), _getYs(cluster), marker = \".\", s = 100)\n",
" plt.xlabel(symptomX)\n",
" plt.ylabel(symptomY)\n",
" plt.show()\n",
"\n",
"clusters, lines = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)"
"def _drawLine(line, cluster, ax):\n",
" coords = line.transform_points(cluster)\n",
" magnitude = line.direction.norm()\n",
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",
"\n",
"def _getXs(xys):\n",
" return [x for (x, _) in xys]\n",
"\n",
"def _getYs(xys):\n",
" return [y for (_, y) in xys]"
]
},
{
@@ -718,30 +736,34 @@
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"from skspatial.objects import Line\n",
"\n",
"_, ax = plt.subplots()\n",
"plt.scatter(\n",
" [x for (x, _) in points],\n",
" [y for (_, y) in points],\n",
" color = \"blue\",\n",
" marker = \".\",\n",
" s = 100,\n",
" label = \"Dots\")\n",
"for cluster, line in zip(clusters, lines):\n",
" if len(cluster) >= 3:\n",
" coords = line.transform_points(cluster)\n",
" magnitude = line.direction.norm()\n",
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",
" plt.scatter(\n",
" [x for (x, _) in cluster],\n",
" [y for (_, y) in cluster],\n",
" marker = \".\",\n",
" s = 100)\n",
"plt.xlabel(symptomX)\n",
"plt.ylabel(symptomY)\n",
"plt.show()"
"clustersAscending, linesAscending = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"draw(points, clusters, lines, symptomX, symptomY, minClusterSize = 3)"
]
},
{