refactoring
This commit is contained in:
@@ -664,10 +664,10 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# https://scikit-learn.org/stable/auto_examples/linear_model/plot_ransac.html\n",
|
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"from matplotlib import pyplot as plt\n",
|
"from matplotlib import pyplot as plt\n",
|
||||||
"from sklearn import linear_model\n"
|
"from skspatial.objects import Line\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -707,9 +707,27 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from SymptomsCausedByVaccines.MultiLineFitting.MultiLineFitter import MultiLineFitter\n",
|
"def draw(points, clusters, lines, symptomX, symptomY, minClusterSize):\n",
|
||||||
|
" _, ax = plt.subplots()\n",
|
||||||
|
" plt.scatter(_getXs(points), _getYs(points), color = \"blue\", marker = \".\", s = 100)\n",
|
||||||
|
" for cluster, line in zip(clusters, lines):\n",
|
||||||
|
" if len(cluster) >= minClusterSize:\n",
|
||||||
|
" _drawLine(line, cluster, ax)\n",
|
||||||
|
" plt.scatter(_getXs(cluster), _getYs(cluster), marker = \".\", s = 100)\n",
|
||||||
|
" plt.xlabel(symptomX)\n",
|
||||||
|
" plt.ylabel(symptomY)\n",
|
||||||
|
" plt.show()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"clusters, lines = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)"
|
"def _drawLine(line, cluster, ax):\n",
|
||||||
|
" coords = line.transform_points(cluster)\n",
|
||||||
|
" magnitude = line.direction.norm()\n",
|
||||||
|
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",
|
||||||
|
"\n",
|
||||||
|
"def _getXs(xys):\n",
|
||||||
|
" return [x for (x, _) in xys]\n",
|
||||||
|
"\n",
|
||||||
|
"def _getYs(xys):\n",
|
||||||
|
" return [y for (_, y) in xys]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -718,30 +736,34 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import matplotlib.pyplot as plt\n",
|
"clustersAscending, linesAscending = MultiLineFitter.fitPointsByAscendingLines(points, consensusThreshold = 0.001)"
|
||||||
"from skspatial.objects import Line\n",
|
]
|
||||||
"\n",
|
},
|
||||||
"_, ax = plt.subplots()\n",
|
{
|
||||||
"plt.scatter(\n",
|
"cell_type": "code",
|
||||||
" [x for (x, _) in points],\n",
|
"execution_count": null,
|
||||||
" [y for (_, y) in points],\n",
|
"metadata": {},
|
||||||
" color = \"blue\",\n",
|
"outputs": [],
|
||||||
" marker = \".\",\n",
|
"source": [
|
||||||
" s = 100,\n",
|
"draw(points, clustersAscending, linesAscending, symptomX, symptomY, minClusterSize = 3)"
|
||||||
" label = \"Dots\")\n",
|
]
|
||||||
"for cluster, line in zip(clusters, lines):\n",
|
},
|
||||||
" if len(cluster) >= 3:\n",
|
{
|
||||||
" coords = line.transform_points(cluster)\n",
|
"cell_type": "code",
|
||||||
" magnitude = line.direction.norm()\n",
|
"execution_count": null,
|
||||||
" line.plot_2d(ax, t_1 = min(coords) / magnitude, t_2 = max(coords) / magnitude)\n",
|
"metadata": {},
|
||||||
" plt.scatter(\n",
|
"outputs": [],
|
||||||
" [x for (x, _) in cluster],\n",
|
"source": [
|
||||||
" [y for (_, y) in cluster],\n",
|
"clusters, lines = MultiLineFitter.fitPointsByLines(points, consensusThreshold = 0.001)"
|
||||||
" marker = \".\",\n",
|
]
|
||||||
" s = 100)\n",
|
},
|
||||||
"plt.xlabel(symptomX)\n",
|
{
|
||||||
"plt.ylabel(symptomY)\n",
|
"cell_type": "code",
|
||||||
"plt.show()"
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"draw(points, clusters, lines, symptomX, symptomY, minClusterSize = 3)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user