refactoring
This commit is contained in:
@@ -664,10 +664,10 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# https://scikit-learn.org/stable/auto_examples/linear_model/plot_ransac.html\n",
|
||||
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
|
||||
"import numpy as np\n",
|
||||
"from matplotlib import pyplot as plt\n",
|
||||
"from sklearn import linear_model\n"
|
||||
"from skspatial.objects import Line\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -707,9 +707,27 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
|
||||
"def draw(points, clusters, lines, symptomX, symptomY, minClusterSize):\n",
|
||||
" _, ax = plt.subplots()\n",
|
||||
" plt.scatter(_getXs(points), _getYs(points), color = \"blue\", marker = \".\", s = 100)\n",
|
||||
" for cluster, line in zip(clusters, lines):\n",
|
||||
" if len(cluster) >= minClusterSize:\n",
|
||||
" _drawLine(line, cluster, ax)\n",
|
||||
" plt.scatter(_getXs(cluster), _getYs(cluster), marker = \".\", s = 100)\n",
|
||||
" plt.xlabel(symptomX)\n",
|
||||
" plt.ylabel(symptomY)\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"clusters, lines = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)"
|
||||
"def _drawLine(line, cluster, ax):\n",
|
||||
" coords = line.transform_points(cluster)\n",
|
||||
" magnitude = line.direction.norm()\n",
|
||||
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",
|
||||
"\n",
|
||||
"def _getXs(xys):\n",
|
||||
" return [x for (x, _) in xys]\n",
|
||||
"\n",
|
||||
"def _getYs(xys):\n",
|
||||
" return [y for (_, y) in xys]"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -718,30 +736,34 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from skspatial.objects import Line\n",
|
||||
"\n",
|
||||
"_, ax = plt.subplots()\n",
|
||||
"plt.scatter(\n",
|
||||
" [x for (x, _) in points],\n",
|
||||
" [y for (_, y) in points],\n",
|
||||
" color = \"blue\",\n",
|
||||
" marker = \".\",\n",
|
||||
" s = 100,\n",
|
||||
" label = \"Dots\")\n",
|
||||
"for cluster, line in zip(clusters, lines):\n",
|
||||
" if len(cluster) >= 3:\n",
|
||||
" coords = line.transform_points(cluster)\n",
|
||||
" magnitude = line.direction.norm()\n",
|
||||
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",
|
||||
" plt.scatter(\n",
|
||||
" [x for (x, _) in cluster],\n",
|
||||
" [y for (_, y) in cluster],\n",
|
||||
" marker = \".\",\n",
|
||||
" s = 100)\n",
|
||||
"plt.xlabel(symptomX)\n",
|
||||
"plt.ylabel(symptomY)\n",
|
||||
"plt.show()"
|
||||
"clustersAscending, linesAscending = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"draw(points, clusters, lines, symptomX, symptomY, minClusterSize = 3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user