This is an automated email from the ASF dual-hosted git repository.

charlie pushed a commit to branch density-notebooks
in repository https://gitbox.apache.org/repos/asf/datasketches-python.git

commit 6e7b37f0e483d4392e0047f91a58bd8ac45b813c
Author: Charlie Dickens <[email protected]>
AuthorDate: Fri Nov 3 15:53:43 2023 +0000

    Adding density sketch examples
---
 .../1-introduction-kde-coreset.ipynb               | 1630 ++++++++++++++++++++
 .../2-density-sketch-visualization.ipynb           |  543 +++++++
 jupyter/density_sketch/kernel_density.py           |   61 +
 jupyter/density_sketch/naive_bayes_classifier.py   |  135 ++
 4 files changed, 2369 insertions(+)

diff --git a/jupyter/density_sketch/1-introduction-kde-coreset.ipynb 
b/jupyter/density_sketch/1-introduction-kde-coreset.ipynb
new file mode 100644
index 0000000..cb7f5f5
--- /dev/null
+++ b/jupyter/density_sketch/1-introduction-kde-coreset.ipynb
@@ -0,0 +1,1630 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "eced4a42",
+   "metadata": {},
+   "source": [
+    "# Introduction to Kernel Density Estimation (KDE)\n",
+    "\n",
+    "\n",
+    "**Objectives:** \n",
+    "- To understand exactly what a kernel density estimate is 
approximating.\n",
+    "- To understand the parameters that affect the performance of a KDE.\n",
+    "\n",
+    "\n",
+    "## 1. Problem Setup.\n",
+    "We need the following notation:\n",
+    "- $f$ is an underlying probability distribution function (pdf) for a 
probability distribution $\\mathcal{D}$.\n",
+    "- Samples $x_i$ are drawn independently and identically from the 
distribution $\\mathcal{D}$.\n",
+    "\n",
+    "The high-level aim of KDE is to use a smoothing function, known as the 
_kernel function_, $K$, to estimate the function $f$.\n",
+    "\n",
+    "The _kernel density estimator_ is the function $\\hat{f}$.  It is 
evaluated at a test (or query) point $x^*$ and satisfies the following 
relationship:\n",
+    "\\begin{align}\n",
+    "f(x^*) \\star K(x^*) &= \\int_{\\text{domain}(f)} K(x^* - x) f(x) dx 
\\\\\n",
+    "&= \\mathbf{E}_{f(x)} \\left[ K(x^* - x) \\right] \\\\ \n",
+    "&\\approx \\frac1n \\sum_{i=1}^n K(x^* - x_i) \\\\ \n",
+    "&= \\hat{f}(x^*).\n",
+    "\\end{align}\n",
+    "\n",
+    "In words, the KDE at $x^*$ approximates the convolution of the density $f$ with the kernel $K$, evaluated at $x^*$.\n",
+    "\n",
+    "Notably, this _does not_ mean that the KDE $\\hat{f}$ is necessarily a 
good estimator for the probability density function, despite that being the 
motivation for the estimation problem.\n",
+    "\n",
+    "## 2. Initial Comparison\n",
+    "\n",
+    "We will investigate this estimation of $\\hat{f}(x^*) \\approx f(x^*) 
\\star K(x^*)$."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "1d1be136",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "from scipy import stats\n",
+    "import matplotlib.pyplot as plt\n",
+    "from sklearn.metrics.pairwise import rbf_kernel\n",
+    "import pandas as pd\n",
+    "import scipy.integrate as integrate\n",
+    "from scipy.spatial.distance import cdist\n",
+    "%matplotlib inline"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 41,
+   "id": "a27e06b4-e5a6-4f29-931a-c1108f470368",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import warnings\n",
+    "warnings.filterwarnings('ignore')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "213d42b3",
+   "metadata": {},
+   "source": [
+    "First, we will consider the case where the distribution under study is the standard normal distribution $N(0,1)$.  We also draw a _sample_ of points from $N(0,1)$."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "58a54fab",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "n_sample = 10000\n",
+    "Z = stats.norm(0, 1)\n",
+    "X_train = Z.rvs(size=n_sample)[:, np.newaxis]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "082a089f",
+   "metadata": {},
+   "source": [
+    "The distribution is plotted below along with the samples.  "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "1991a6d6",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<matplotlib.legend.Legend at 0x13fe2a690>"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAhgAAAFzCAYAAAB8X3AUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAABSZUlEQVR4nO3deXxTVd4/8E+WJumW0r0UCgVkEVnK1oqCglZ2QUFFxkcW10cBl+qMwKPwqONUHEQc9CcOo6LjKDygIDNCWSqbimylgOxbW6B7S5M2bZM2ub8/Lk1704UuaW+TfN6vV0buyfZNpm0+OefccxSCIAggIiIiciKl3AUQERGR+2HAICIiIqdjwCAiIiKnY8AgIiIip2PAICIiIqdjwCAiIiKnY8AgIiIip2PAICIiIqdTy11AW7PZbMjMzIS/vz8UCoXc5RAREbkMQRBQXFyMyMhIKJUN91F4XMDIzMxE
 [...]
+      "text/plain": [
+       "<Figure size 600x400 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "fig, ax = plt.subplots(figsize=(6, 4))\n",
+    "_x = np.linspace(-5, 5)\n",
+    "ax.plot(_x, Z.pdf(_x), label=r\"$N(0,1)$\", color='red', lw=3.)\n",
+    "ax.scatter(X_train, \n",
+    "           0.01*np.random.uniform(size=X_train.shape[0])-0.05, \n",
+    "           marker='.',\n",
+    "           s=1.,\n",
+    "          alpha=0.5, label=\"Samples: $x_i$\")\n",
+    "ax.set_ylabel(\"Density\")\n",
+    "ax.set_xlabel(r\"$x$\")\n",
+    "ax.legend()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "795ac4cd",
+   "metadata": {},
+   "source": [
+    "Next, we define a kernel function for single variable inputs which takes 
a free parameter $\\gamma > 0$ that we may choose.\n",
+    "\n",
+    "\\begin{align}\n",
+    "K(u) = \\exp(-\\gamma u^2).\n",
+    "\\end{align}\n",
+    "\n",
+    "For multivariable inputs we use\n",
+    "\\begin{align}\n",
+    "K(u) = \\exp(-\\gamma ||u||^2).\n",
+    "\\end{align}\n",
+    "\n",
+    "If we choose the constant $\\gamma = 1 / 2 h^2$ for some $h$, then $K$ 
represents a Gaussian kernel provided that \n",
+    "we return the following rescaled answer for the kernel density 
estimate:\n",
+    "\n",
+    "\\begin{align}\n",
+    "\\hat{f}(x^*) = \\frac{1}{h \\sqrt{2\\pi} }\\frac1n \\sum_{i=1}^n K(x^* - 
x_i).\n",
+    "\\end{align}"
+   ]
+  },
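Before implementing this, a quick check (not part of the original notebook): with $\gamma = 1/2h^2$, the rescaled kernel $\frac{1}{h\sqrt{2\pi}}K(u)$ is exactly a $N(0, h^2)$ density, so it should integrate to one for any $h > 0$.

```python
import numpy as np
from scipy import integrate

def rescaled_kernel(u, h):
    # (1/(h*sqrt(2*pi))) * exp(-u^2 / (2*h^2)): the N(0, h^2) density
    return np.exp(-u**2 / (2 * h**2)) / (h * np.sqrt(2 * np.pi))

for h in [0.25, 1.0, 3.0]:
    total, _ = integrate.quad(rescaled_kernel, -np.inf, np.inf, args=(h,))
    print(f"h = {h}: integral = {total:.6f}")  # ~1.0 in every case
```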
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "28043f3f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def kernel_density(Xtrain, Xtest, bandwidth=1.):\n",
+    "    \"\"\"\n",
+    "    Returns the kernel density estimate at each point of Xtest:\n",
+    "        (1/n) * (1/(bandwidth*sqrt(2*pi)))^d * sum_{i=1}^n K( (x* - x_i) / bandwidth )\n",
+    "    The mean over axis 0 supplies the 1/n factor.\n",
+    "    \"\"\"\n",
+    "    # Ensure 2-d arrays of shape (n_points, n_features) for cdist.\n",
+    "    if Xtrain.ndim == 1:\n",
+    "        Xtrain = Xtrain.reshape(-1, 1)\n",
+    "    if Xtest.ndim == 1:\n",
+    "        Xtest = Xtest.reshape(-1, 1)\n",
+    "    g = (1./bandwidth)**2  # gamma = 1/h^2; the factor of 1/2 is applied below\n",
+    "    K = np.exp(-cdist(Xtrain, Xtest, metric='sqeuclidean')*g/2)\n",
+    "    K *= 1./(bandwidth*np.sqrt(2*np.pi))**Xtrain.shape[1]\n",
+    "    return np.mean(K, axis=0)\n"
+   ]
+  },
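A useful property check on this helper (restated here so the snippet is self-contained): the estimate is an average of $N(x_i, h^2)$ densities, so it is itself a probability density and should integrate to one.

```python
import numpy as np
from scipy import integrate
from scipy.spatial.distance import cdist

def kernel_density(Xtrain, Xtest, bandwidth=1.):
    # Gaussian KDE: mean of N(x_i, bandwidth^2) densities evaluated at Xtest.
    g = (1. / bandwidth) ** 2
    K = np.exp(-cdist(Xtrain, Xtest, metric='sqeuclidean') * g / 2)
    K *= 1. / (bandwidth * np.sqrt(2 * np.pi)) ** Xtrain.shape[1]
    return np.mean(K, axis=0)

rng = np.random.default_rng(0)
Xtrain = rng.standard_normal((5000, 1))
# A KDE is itself a density, so its integral over the real line is ~1.
area, _ = integrate.quad(lambda x: kernel_density(Xtrain, np.array([[x]]), 0.3)[0],
                         -np.inf, np.inf)
print(f"Integral of the KDE: {area:.4f}")
```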
+  {
+   "cell_type": "markdown",
+   "id": "6d4be102",
+   "metadata": {},
+   "source": [
+    "Recall that we want to compare $\\hat{f}$ to the convolution below at a 
test point $x^*$. \n",
+    "Assuming that $h=1$ we have:\n",
+    "\n",
+    "\\begin{align}\n",
+    "f(x^*) \\star K(x^*) &= \\frac{1}{\\sqrt{2\\pi}}\\int_{\\text{domain}(f)} 
K(x^* - x) f(x) dx \\\\ \n",
+    "&= \\frac{1}{\\sqrt{2\\pi}} \\int_{\\text{domain}(f)} \\exp\\left(-\\tfrac{1}{2}\\|x^* - x\\|^2\\right) f(x) dx\n",
+    "\\end{align}\n",
+    "\n",
+    "The convolution can be estimated through the following code."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "a45dd5c1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Target: 0.3969525474770118\n"
+     ]
+    }
+   ],
+   "source": [
+    "x_test = 0.1\n",
+    "x_test_arr = np.array([x_test])[:, np.newaxis] # needed for some of the 
vectorised functions\n",
+    "pdf_x_test = Z.pdf(x_test)\n",
+    "print(f\"Target: {pdf_x_test}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "7caaf141",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def convolution_at_q(x, q, h):\n",
+    "    return Z.pdf(x)*(np.exp(-0.5*np.linalg.norm((1./h)*(x - q))**2) / 
(h*np.sqrt(2*np.pi)))\n",
+    "\n",
+    "# This function evaluates the convolution (f \\star K) at x*\n",
+    "conv_estimate = integrate.quad(convolution_at_q, \n",
+    "                                -np.inf, np.inf,\n",
+    "                                args=(x_test_arr, 1.))[0]"
+   ]
+  },
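With $h=1$ the kernel here is itself a standard normal density, so the convolution has a closed form we can check the quadrature against: convolving two $N(0,1)$ densities gives a $N(0,2)$ density (variances add). A self-contained sketch of that check:

```python
import numpy as np
from scipy import stats, integrate

Z = stats.norm(0, 1)

def convolution_at_q(x, q, h):
    # Integrand of (f * K)(q) with a Gaussian kernel of bandwidth h.
    return Z.pdf(x) * np.exp(-0.5 * ((x - q) / h) ** 2) / (h * np.sqrt(2 * np.pi))

x_test = 0.1
numeric = integrate.quad(convolution_at_q, -np.inf, np.inf, args=(x_test, 1.))[0]
# Closed form: the N(0, 2) density at x_test.
closed_form = stats.norm(0, np.sqrt(2)).pdf(x_test)
print(f"quad: {numeric:.6f}, closed form: {closed_form:.6f}")
```

This also explains the value printed below: it is close to $0.281$, the $N(0,2)$ density at $0.1$, not the $N(0,1)$ density.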
+  {
+   "cell_type": "markdown",
+   "id": "5525515b",
+   "metadata": {},
+   "source": [
+    "Let's do a quick check on the two quantities that we are comparing."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "330bb223",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Convolution: 0.281\n",
+      "KDE        : 0.280\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(f\"Convolution: {conv_estimate:.3f}\")\n",
+    "print(f\"KDE        : {kernel_density(X_train, x_test_arr)[0]:.3f}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "52e23e0b",
+   "metadata": {},
+   "source": [
+    "They are comparable, which is good news.  Now we will experiment further and check that the KDE converges to the convolution value.\n",
+    "\n",
+    "**Experimental Setup.**\n",
+    "We fix a number of trials: `num_trials` and choose a test point 
`query`.\n",
+    "We generate an independent sample/training set for every trial and obtain the two estimates: one from the convolution integral, one from the KDE.\n",
+    "\n",
+    "We will then plot the two error curves.\n",
+    "\n",
+    "If this experiment is too slow then try varying the number of trials or 
using \n",
+    "```\n",
+    "nn = np.linspace(100, 1000000, dtype=np.int64)\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "99b87e87",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "num_trials = 25\n",
+    "query = np.array([1.])[:, np.newaxis]\n",
+    "nn = np.logspace(2, 7, endpoint=False, dtype=np.int64, num=20) \n",
+    "\n",
+    "kernel_estimates = {i:np.zeros((len(nn),), dtype=float) for i in 
range(num_trials)}\n",
+    "\n",
+    "conv_estimates = {i:np.zeros((len(nn),), dtype=float) for i in 
range(num_trials)}\n",
+    "\n",
+    "for i,n in enumerate(nn):\n",
+    "    for t in range(num_trials):\n",
+    "        Xtrain = stats.norm(0, 1).rvs(size=n)[:, np.newaxis]\n",
+    "        estimate = kernel_density(Xtrain, query)[0]\n",
+    "        kernel_estimates[t][i] = estimate\n",
+    "        conv_estimates[t][i] = integrate.quad(convolution_at_q, -100., 
100., args=(query,1.))[0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "21a08001",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "kernel_df = pd.DataFrame.from_dict(kernel_estimates)\n",
+    "conv_df = pd.DataFrame.from_dict(conv_estimates)\n",
+    "error_df = kernel_df - conv_df\n",
+    "abs_error_df = error_df.abs()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1ca5edb9",
+   "metadata": {},
+   "source": [
+    "First, we plot the error $\\hat{f}(x^*) - f(x^*) \\star K(x^*)$ versus 
the sample size, showing that as the sample size increases, the error 
decreases.  We plot the median and a $90\\%$ confidence interval."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "ed624e78",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<matplotlib.legend.Legend at 0x17a910ad0>"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABEQAAAGHCAYAAAC5wYeeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAACk00lEQVR4nOzdeXxTVfrH8U+S7vsGZYcKqIjsIiCCLIKKiiiL6LihMzIirowK+nNwxhkRd1xQUEcHHRRFUXBgZFMHZKcsiiOgFAqyt3Tf0uT+/igJSZu2aZuQFr7v1ysvbu6959znlrRwn57zHJNhGAYiIiIiIiIiImcRc6ADEBERERERERE53ZQQEREREREREZGzjhIiIiIiIiIiInLWUUJERERERERERM46SoiIiIiIiIiIyFlHCREREREREREROesoISIiIiIiIiIiZx0lRERERERERETkrKOEiIiIiIiIiIic
 [...]
+      "text/plain": [
+       "<Figure size 1200x400 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "fig, ax = plt.subplots(figsize=(12,4))\n",
+    "\n",
+    "ax.plot(nn, error_df.quantile(q=0.95, axis=1), label='Q95')\n",
+    "ax.plot(nn, error_df.quantile(q=0.5, axis=1), label=\"Q50\")\n",
+    "ax.plot(nn, error_df.quantile(q=0.05, axis=1), label='Q05')\n",
+    "ax.set_ylabel(r\"$\\hat{f}(x^*) - f(x^*) \\star K(x^*)$\")\n",
+    "ax.set_xlabel(r\"Sample size\")\n",
+    "ax.set_ylim(-0.001, 0.001)\n",
+    "ax.grid()\n",
+    "\n",
+    "ax.tick_params(axis='both', which='major', labelsize=16)\n",
+    "ax.xaxis.get_label().set_fontsize(16)\n",
+    "ax.yaxis.get_label().set_fontsize(16)\n",
+    "ax.legend(prop={'size': 16})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a049da34",
+   "metadata": {},
+   "source": [
+    "Once $n$ is moderately large, we see that the median concentrates about $0$ and we appear to over- and under-estimate the quantity in roughly equal proportion.  This suggests that the two quantities are close and $\\hat{f}(x^*)$ is an unbiased estimator of $f(x^*) \\star K(x^*)$.  More trials would smooth out the curves.\n",
+    "\n",
+    "However, we also want to understand the rate of convergence so we measure 
the absolute error.  The absolute error only tells us about the magnitude of 
the error, not the sign."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "e0b271ac",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<matplotlib.legend.Legend at 0x17aa3ab90>"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABBoAAAGECAYAAACYiEPLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAADfQ0lEQVR4nOzdd3hb5fXA8a+W5b333tnLGWTvBAi7rJRCIS2lLWGPQgJpS2nZowECbX8UKFD2KoEQsheQvZf3ivfelmVZvz+uLVux7NiObdnO+TzPfSTde3Xvq+RGkY7Oe47KbDabEUIIIYQQQgghhOgFansPQAghhBBCCCGEEEOHBBqEEEIIIYQQQgjRayTQIIQQQgghhBBCiF4jgQYhhBBCCCGEEEL0Ggk0CCGEEEIIIYQQotdIoEEIIYQQQgghhBC9RgINQgghhBBCCCGE6DUSaBBCCCGEEEIIIUSv0dp7AKJn
 [...]
+      "text/plain": [
+       "<Figure size 1200x400 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "fig, ax = plt.subplots(figsize=(12,4))\n",
+    "\n",
+    "ax.plot(nn, abs_error_df.quantile(q=0.975,axis=1))\n",
+    "ax.plot(nn, abs_error_df.quantile(q=0.5, axis=1))\n",
+    "ax.plot(nn, abs_error_df.quantile(q=0.025,axis=1))\n",
+    "ax.plot(nn, (1/16)*(nn)**(-0.5),  linestyle='--', color='black', 
linewidth=3.0, label=r\"$\\frac{1}{16}n^{-1/2}$\")\n",
+    "ax.set_ylabel(r\"$|\\hat{f}(x^*) - f(x^*) \\star K(x^*)|$\")\n",
+    "ax.set_xlabel(r\"Sample size\")\n",
+    "ax.set_yscale('log')\n",
+    "ax.set_xscale('log')\n",
+    "ax.legend()\n",
+    "ax.grid()\n",
+    "\n",
+    "ax.tick_params(axis='both', which='major', labelsize=16)\n",
+    "ax.xaxis.get_label().set_fontsize(16)\n",
+    "ax.yaxis.get_label().set_fontsize(16)\n",
+    "ax.legend(prop={'size': 16})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f018a895",
+   "metadata": {},
+   "source": [
+    "This plot shows that the KDE converges to the convolution at a rate of about $1/\\sqrt{n}$.\n",
+    "\n",
+    "### 2.1 Illustrating the KDE\n",
+    "\n",
+    "Now that we have a handle on _what_ the KDE is approximating, let's 
illustrate the returned curve.\n",
+    "Again, we sample points from the normal distribution but now we also plot 
the KDE over a uniform test grid on $[-3,3]$."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "97df7763",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Text(0.5, 0, 'x')"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAkAAAAGwCAYAAABB4NqyAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAADJU0lEQVR4nOzdd1xUV/r48c+dAkMdOkMvFhAVsRJb1NhjYjS9bIymbepuwm/TEzXZZE0zm7L5xo1ZE03RNDWmmSixxIZGJXYUBFF6h6HMDDP398foyIAoIDCg5/163ZfMmXPvPNcQeTj3OedIsizLCIIgCIIgXEYUjg5AEARBEAShs4kESBAEQRCEy45IgARBEARBuOyIBEgQBEEQhMuOSIAEQRAEQbjsiARIEARBEITLjkiABEEQBEG47KgcHUBXZLFYyM3NxcPDA0mSHB2OIAiCIAgtIMsyVVVVBAcHo1Ccf4xH
 [...]
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "x_test = np.linspace(-3., 3.)[:, np.newaxis]\n",
+    "nn = np.logspace(2, 4, endpoint=False, dtype=np.int64, num=5) \n",
+    "\n",
+    "kernel_estimates = {i:np.zeros((len(nn),), dtype=float) for i in 
range(num_trials)}\n",
+    "\n",
+    "conv_estimates = {i:np.zeros((len(nn),), dtype=float) for i in 
range(num_trials)}\n",
+    "\n",
+    "fig, ax = plt.subplots()\n",
+    "for i,n in enumerate(nn):\n",
+    "    Xtrain = stats.norm(0, 1).rvs(size=n)[:, np.newaxis]\n",
+    "    estimate = kernel_density(Xtrain, x_test)\n",
+    "    ax.plot(x_test, estimate, alpha=0.5, color=\"C\"+str(i), 
label=f\"n_train: {n}\")\n",
+    "ax.plot(x_test, stats.norm(0,1).pdf(x_test), lw=3.0, color='red', 
label=\"PDF: N(0,1)\")    \n",
+    "ax.legend()\n",
+    "ax.set_ylabel(\"Density\")\n",
+    "ax.set_xlabel(\"x\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2c55bd95",
+   "metadata": {},
+   "source": [
+    "The shape looks correct and, as expected from the preceding plots, we appear to be converging to a curve similar to the distribution... *but didn't we want to estimate the underlying distribution function itself?*\n",
+    "Generally speaking, we only claim to estimate the density function over a 
sample, not the true underlying distribution.\n",
+    "However, if certain conditions are met then, as outlined in the next 
section, we also estimate the underlying distribution.\n",
+    "\n",
+    "\n",
+    "## 3. Bandwidth Selection\n",
+    "\n",
+    "One detail that we have thus far omitted is the `bandwidth` parameter, 
$h$. \n",
+    "This parameter $h$ controls the width of the kernels applied at every 
point. \n",
+    "In general, it might be difficult to choose a priori the correct 
bandwidth.\n",
+    "If the bandwidth is chosen too small, then the returned curve will be too 
spiky, and if the bandwidth is too large, then the curve will appear too 
flat.\n",
+    "\n",
+    "Suppose that we sample $(x_1, x_2, x_3) = (-2, 0.05, 0.15) $ from a 
standard normal distribution and fit a KDE model.\n",
+    "When we query a test point $x^*$, we are taking a sum over distances from 
$x^*$ to the points $x_1, x_2, x_3$.\n",
+    "Ignoring the constants, the KDE at $x^*$ is $\\hat{f}(x^*) = K(x^* - x_1) 
+ K(x^* - x_2) + K(x^* - x_3)$.\n",
+    "When a query point $x^*$ arrives at test time, its distance to every $x_i$ is computed and an amount proportional to $\\exp(-\\frac{1}{2h^2} |x_i - x^*|^2)$ is contributed to $\\hat{f}(x^*)$.  The contribution should be large when the distance is small and small when the distance is large.\n",
+    "\n",
+    "In the plot below, the query point $x^*$ is close to the two points on the right-hand side, so it picks up a non-trivial contribution from each.  On the other hand, it is far from the point on the left-hand side, so the exponential function makes this point contribute almost nothing to the overall sum."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 38,
+   "id": "621d01c9",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<matplotlib.legend.Legend at 0x30ed57050>"
+      ]
+     },
+     "execution_count": 38,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAp8AAAGsCAYAAACb7syWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAACKqUlEQVR4nO3dd3hc1bXw4d+Zqt67LFuyLXdcwcZUAwZTE5JQEhKK84UklATiNCABh+QCISGEhOZcLmDSaYGEGExxYqqxwca9F1myrN771O+PozOSbEk+M9KZut7nOY/Go9kzS7I8Xtp7r7UVr9frRQghhBBCiCAwhToAIYQQQggROyT5FEIIIYQQQSPJpxBCCCGECBpJPoUQQgghRNBI8imEEEIIIYJGkk8hhBBCCBE0knwKIYQQQoigsYQ6AD08Hg9Hjx4lOTkZRVFCHY4QQgghhDiG1+ulra2NgoICTKah5zcj
 [...]
+      "text/plain": [
+       "<Figure size 800x500 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "x_vals = [-2, 0.05, 0.15]\n",
+    "query = -0.1\n",
+    "fig, ax = plt.subplots(figsize=(8,5))\n",
+    "ax.scatter(x_vals, np.zeros_like(x_vals), label=r\"$(x_1, x_2, 
x_3)$\")\n",
+    "\n",
+    "h = 0.25\n",
+    "window = np.linspace(-1., 1.) \n",
+    "for i, x in enumerate(x_vals):\n",
+    "    ax.plot(x+window, np.exp(-0.5*(window/h)**2))\n",
+    "ax.vlines(query, 0, 1., color='red', linestyle=\"--\")\n",
+    "ax.plot(query, 0, label=r\"Query: $x^*$\", marker= 'x', 
markersize=10.)\n",
+    "ax.legend()"
+   ]
+  },
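To make the picture concrete, the per-point contributions at the query can be computed directly, using the same points and bandwidth as the plotting cell above:

```python
import numpy as np

x_vals = np.array([-2.0, 0.05, 0.15])
query = -0.1
h = 0.25

# Unnormalised contribution of each sample point to the KDE at the query.
contrib = np.exp(-0.5 * ((x_vals - query) / h) ** 2)
for x, c in zip(x_vals, contrib):
    print(f"x_i = {x:5.2f}: contribution = {c:.4f}")
```

The two nearby points contribute roughly $0.84$ and $0.61$, while the far point at $-2$ contributes essentially zero.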
+  {
+   "cell_type": "markdown",
+   "id": "1a9808a6",
+   "metadata": {},
+   "source": [
+    "### Small $h$\n",
+    "\n",
+    "When $h$ is too small, the kernels are extremely narrow (think of a 
Gaussian distribution with small variance).  Then any test point can easily lie 
in the tails of many of the kernel functions about each of the datapoints.  
This means that the contribution is small from almost all of the sample data.  
However, when a test point lies close to a sample point, it gets essentially 
the entirety of its contribution to $\\hat{f}$ from that datapoint, so we see 
large spikes in the returned KDE."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "6daf7fdf",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<matplotlib.legend.Legend at 0x17f859650>"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAApIAAAGxCAYAAADRWFZjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAACoDUlEQVR4nOz9eXxb1Z0//r+udsuWJe/7GjuJHcdOnISsZCGEBAIJS6AsMy3T6XTa0gLTzpQyj++UMrQN7fxop+20lGn7gU5LyhL2kIQA2UhwyL7Hifd9lS3Jtqz13t8fihTJkqx7ZW2W38/Hww/w9T33HjvSuW+d5X0YjuM4EEIIIYQQIpAo2hUghBBCCCHTEwWShBBCCCEkKBRIEkIIIYSQoFAgSQghhBBCgkKBJCGEEEIICQoFkoQQQgghJCgUSBJCCCGEkKBQIEkIIYQQQoJCgSQhhBBCCAkKBZKEEEIIISQo
 [...]
+      "text/plain": [
+       "<Figure size 800x500 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "x_vals = [-2, -0.15, -0.02, 0.05, 0.2, 2.5]\n",
+    "\n",
+    "fig, ax = plt.subplots(figsize=(8,5))\n",
+    "ax.scatter(x_vals, np.zeros_like(x_vals), label='Samples')\n",
+    "\n",
+    "h = 0.25\n",
+    "window = np.linspace(-1., 1.) \n",
+    "for i, x in enumerate(x_vals):\n",
+    "    if i == 0:\n",
+    "        ax.plot(x+window, np.exp(-0.5*(window/h)**2)/h, alpha=0.75, 
linestyle=\":\", color='grey', label=\"Kernels\")\n",
+    "    else:\n",
+    "        ax.plot(x+window, np.exp(-0.5*(window/h)**2)/h, alpha=0.75, 
linestyle=\":\", color='grey')\n",
+    "    \n",
+    "test_points = np.linspace(-3, 3, 25)[:, np.newaxis]\n",
+    "kde = kernel_density(np.array(x_vals)[:, np.newaxis], test_points, 
0.01)\n",
+    "ax.plot(test_points, kde, label=\"KDE\")\n",
+    "ax.scatter(test_points, -0.5*np.ones_like(test_points), label=\"Test 
points\")\n",
+    "ax.legend()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6b817079",
+   "metadata": {},
+   "source": [
+    "### Large $h$\n",
+    "\n",
+    "On the other hand, when $h$ is too large, almost all of the points 
contribute to the kernel sum for $\\hat{f}$, so this flattens out the curve as 
it is more difficult to distinguish between groups of points with large and 
small distances."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "c058ff01",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<matplotlib.legend.Legend at 0x17f911ed0>"
+      ]
+     },
+     "execution_count": 15,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAArMAAAHMCAYAAADGeyCSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAABuGklEQVR4nO3deXxU5d0+/muWzEwmy0z2ZLKHJCQhkgAhYd8EglYURVmqVailT1VsfbBV6a+KVltcqvJ1xepTsa0WKu6KKCIgS9jCvgRIQsg62TOTZDKZ9fdHzDFDQgRMMnOS6/16zQtycubkM5mTM9e5z33fR+J0Op0gIiIiIhIhqbsLICIiIiK6WgyzRERERCRaDLNEREREJFoMs0REREQkWgyzRERERCRaDLNEREREJFoMs0REREQkWgyzRERERCRacncXQOLkdDphs9lgt9vdXQoREZFHkslkkMvlkEgk7i5l
 [...]
+      "text/plain": [
+       "<Figure size 800x500 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "x_vals = [-2, -0.15, -0.02, 0.05, 0.2, 2.5]\n",
+    "\n",
+    "fig, ax = plt.subplots(figsize=(8,5))\n",
+    "ax.scatter(x_vals, np.zeros_like(x_vals), label='Samples')\n",
+    "\n",
+    "h = 5.\n",
+    "window = np.linspace(-2., 2.) \n",
+    "for i, x in enumerate(x_vals):\n",
+    "    if i == 0:\n",
+    "        ax.plot(x+window, np.exp(-0.5*(window/h)**2)/h, alpha=0.75, 
linestyle=\":\", color='grey', label=\"Kernels\")\n",
+    "    else:\n",
+    "        ax.plot(x+window, np.exp(-0.5*(window/h)**2)/h, alpha=0.75, 
linestyle=\":\", color='grey')\n",
+    "    \n",
+    "test_points = np.linspace(-3, 3, 25)[:, np.newaxis]\n",
+    "kde = kernel_density(np.array(x_vals)[:, np.newaxis], test_points, 
3.0)\n",
+    "ax.plot(test_points, kde, label=\"KDE\")\n",
+    "ax.scatter(test_points, -0.05*np.ones_like(test_points), label=\"Test 
points\")\n",
+    "ax.legend(loc='upper center', ncol=4, bbox_to_anchor=(0.5, 1.1),)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f0aa9475",
+   "metadata": {},
+   "source": [
+    "However, all is not lost as we can vary the bandwidth parameter to get 
closer to the target distribution."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "886b5b3e",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<matplotlib.legend.Legend at 0x17a94e750>"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAC8tElEQVR4nOzdeXxU9b34/9eZfcu+J4RshIQtBMIiKovs1t66tNblWpVavde2t/VLe+3lZ4vW9l5atba2euutrQpqldatdUMwsiqENeyELSGB7Hsms8+c3x8DA5GwBEMmCe/n43EeMud8zmfeB2LmPZ9VUVVVRQghhBCiH9OEOwAhhBBCiAuRhEUIIYQQ/Z4kLEIIIYTo9yRhEUIIIUS/JwmLEEIIIfo9SViEEEII0e9JwiKEEEKIfk8SFiGEEEL0e7pwB9AbAoEA1dXVREREoChKuMMRQgghxEVQVZWOjg5SU1PR
 [...]
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "hh = [0.1, 0.2, 0.4, 0.8]\n",
+    "n_train = 10000\n",
+    "x_test = np.linspace(-3., 3.)[:, np.newaxis]\n",
+    "true_density = stats.norm(0,1).pdf(x_test)\n",
+    "Xtrain = stats.norm(0, 1).rvs(size=n_train)[:, np.newaxis]\n",
+    "all_estimates = {str(h): np.zeros((len(x_test),)) for h in hh}\n",
+    "\n",
+    "fig, ax = plt.subplots()\n",
+    "ax.plot(x_test, true_density, lw=3., color='red', label=\"True 
density\")\n",
+    "for i, h in enumerate(hh):\n",
+    "    estimate = kernel_density(Xtrain, x_test, h)\n",
+    "    ax.plot(x_test, estimate, alpha=0.5, color=\"C\"+str(i), 
label=f\"h:{h}\")\n",
+    "ax.legend()\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "842054e0",
+   "metadata": {},
+   "source": [
+    "It turns out that if we define the bandwidth as a function of $n$, say $h = h_n$, then provided $h_n \\rightarrow 0$ while $n h_n \\rightarrow \\infty$ (i.e. the sample size grows faster than the bandwidth decays), the kernel density estimate converges pointwise to the density function at each point.\n",
+    "A stronger form of convergence (known as uniform convergence) is obtained under the slightly stronger condition $n h_n^2 \\rightarrow \\infty$.\n",
+    "\n",
+    "<!-- It looks as if the density estimate is converging to some consistent 
level of error.  This is the unresolvable model error.\n",
+    "Some further sources can be found at:\n",
+    "- https://www.projectrhea.org/rhea/images/4/4c/Parzen_window_method_and_classification.pdf\n",
+    "- https://www.ehu.eus/ccwintco/uploads/8/89/Borja-Parzen-windows.pdf 
-->\n",
+    "\n",
+    "We will now show that allowing the bandwidth to decay with $n$ causes the 
KDE to converge to the true distribution.\n",
+    "We use a uniform grid over $[-2, 2]$ as the test set."
+   ]
+  },
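A common heuristic for choosing such a decaying bandwidth, not used in this notebook (which picks $h$ by hand), is Silverman's rule of thumb for a one-dimensional Gaussian kernel, $h = 1.06\,\hat{\sigma}\,n^{-1/5}$. A small sketch:

```python
import numpy as np

# Silverman's rule of thumb: h = 1.06 * sigma_hat * n^(-1/5).
# Note n*h_n -> infinity and n*h_n^2 -> infinity, so both the pointwise
# and uniform convergence conditions discussed above are satisfied.
rng = np.random.default_rng(1)
for n in [100, 10_000, 1_000_000]:
    X = rng.standard_normal(n)
    h = 1.06 * np.std(X) * n ** (-1 / 5)
    print(f"n = {n:>9}: h ~ {h:.4f}")
```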
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "1d224351",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[<matplotlib.lines.Line2D at 0x17fad48d0>]"
+      ]
+     },
+     "execution_count": 17,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9WYxl2X3fe37Xnvc++4xxYsyxBhZHcRAlarDV9r1WXxlo+Mo9AIJhW4Zg68GAnwg/mIAhQ3rRgwFBhiFAhtqCDbgBCw247cb1hWxctgxTlihKpCiyisUacozxxJn3PK7VDyczq5KZWVVZA7OG9QECVRl5hh1xIjJ+sdZ//f9CKaXQNE3TNE17QownfQGapmmapn206TCiaZqmadoTpcOIpmmapmlPlA4jmqZpmqY9UTqMaJqmaZr2ROkwommapmnaE6XDiKZpmqZpT5QOI5qmaZqmPVHWk76At0JK
 [...]
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "num_trials = 10\n",
+    "query = np.linspace(-2., 2., num=20)[:, np.newaxis]\n",
+    "true_dist = stats.norm(0,1).pdf(query)\n",
+    "nn = np.logspace(2, 7, endpoint=False, dtype=np.int64, num=20) \n",
+    "\n",
+    "kernel_errors = {i:np.zeros((len(nn),), dtype=float) for i in 
range(num_trials)}\n",
+    "#conv_estimates = {i:np.zeros((len(nn),), dtype=float) for i in 
range(num_trials)}\n",
+    "\n",
+    "fig, ax = plt.subplots()\n",
+    "for i,n in enumerate(nn):\n",
+    "    for t in range(num_trials):\n",
+    "        Xtrain = stats.norm(0, 1).rvs(size=n)[:, np.newaxis]\n",
+    "        estimate = kernel_density(Xtrain, query, 2./n**0.5)\n",
+    "        error = np.linalg.norm(estimate.flatten() - 
true_dist.flatten())\n",
+    "        kernel_errors[t][i] = error\n",
+    "        ax.plot(query, estimate, alpha=0.2)\n",
+    "        \n",
+    "ax.plot(query, true_dist, lw=3., color='red')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "c0c088c1",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<matplotlib.legend.Legend at 0x17f979490>"
+      ]
+     },
+     "execution_count": 18,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABBoAAAGECAYAAACYiEPLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAADo0UlEQVR4nOzdeXhU5fXA8e9MMtn3PWRPSAIBAkkm7GGHCggqilh3a9VWqGvr3qpFbSvaX1HRim1xQQXBDRCUfSdkJ2xZyJ4A2fd1kszvjyEDAwkEyM75PM99Jrn3vnfei9fJ3HPPe16FVqvVIoQQQgghhBBCCNEFlL3dASGEEEIIIYQQQgwcEmgQQgghhBBCCCFEl5FAgxBCCCGEEEIIIbqMBBqEEEIIIYQQQgjRZSTQIIQQQgghhBBCiC4jgQYhhBBCCCGEEEJ0GQk0CCGEEEIIIYQQostIoEEIIYQQQgghhBBd
 [...]
+      "text/plain": [
+       "<Figure size 1200x400 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "dist_df = pd.DataFrame.from_dict(kernel_errors)\n",
+    "\n",
+    "fig, ax = plt.subplots(figsize=(12,4))\n",
+    "\n",
+    "ax.plot(nn, dist_df.quantile(q=0.95,axis=1), label='Q95')\n",
+    "ax.plot(nn, dist_df.quantile(q=0.5, axis=1), label=\"Q50\")\n",
+    "ax.plot(nn, dist_df.quantile(q=0.05,axis=1), label='Q05')\n",
+    "ax.plot(nn, (nn)**(-0.2), linestyle='--', color='black', linewidth=3.0, label=r\"$n^{-0.2}$\")\n",
+    "#ax.plot(nn, (1/16)*(nn)**(-0.5),  linestyle='--', color='black', 
linewidth=3.0, label=r\"$\\frac{1}{16}n^{-1/2}$\")\n",
+    "ax.set_ylabel(r\"$|\\hat{f}(x^*) - f(x^*)|$\")\n",
+    "ax.set_xlabel(r\"Sample size\")\n",
+    "ax.set_yscale('log')\n",
+    "ax.set_xscale('log')\n",
+    "ax.legend(title=r\"$h_n = 2n^{-0.5}$\")\n",
+    "ax.grid()\n",
+    "\n",
+    "ax.tick_params(axis='both', which='major', labelsize=16)\n",
+    "ax.xaxis.get_label().set_fontsize(16)\n",
+    "ax.yaxis.get_label().set_fontsize(16)\n",
+    "ax.legend(prop={'size': 16})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "97866c6a",
+   "metadata": {},
+   "source": [
+    "Setting the bandwidth parameter $h_n \\approx n^{-0.5}$ means that the sample size grows faster than the rate at which $h_n$ shrinks, so the effective number of points per window, $n h_n$, still grows.  Under this scaling the plots show that convergence to the true distribution is achieved over the entire input domain."
+   ]
+  },
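The decaying-bandwidth experiment above can be reproduced in a few lines of NumPy. The `gaussian_kde_1d` helper below is a hypothetical stand-in for the notebook's `kernel_density` function, written under the assumption that a Gaussian kernel is used:

```python
import numpy as np

def gaussian_kde_1d(train, query, h):
    # Average of Gaussian kernels of width h centred at each training point.
    diffs = query[:, None] - train[None, :]                  # shape (q, n)
    return np.exp(-0.5 * (diffs / h) ** 2).sum(axis=1) / (len(train) * h * np.sqrt(2 * np.pi))

rng = np.random.default_rng(0)
query = np.linspace(-2.0, 2.0, 20)
true_pdf = np.exp(-0.5 * query**2) / np.sqrt(2 * np.pi)     # standard normal pdf

# Let the bandwidth decay as h_n = 2 / sqrt(n); the error should shrink as n grows.
errors = []
for n in (100, 10_000, 100_000):
    train = rng.standard_normal(n)
    estimate = gaussian_kde_1d(train, query, 2.0 / np.sqrt(n))
    errors.append(np.linalg.norm(estimate - true_pdf))
```

With a fixed seed the error for the largest sample should come out well below the error for the smallest, mirroring the quantile plot above.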
+  {
+   "cell_type": "markdown",
+   "id": "5b4cd191",
+   "metadata": {},
+   "source": [
+    "## 4. Using a Kernel Density Estimator"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "23a98b66",
+   "metadata": {},
+   "source": [
+    "### 4.1 Distribution Visualisation\n",
+    "\n",
+    "A canonical use case for a KDE is quick visualization of simple univariate or bivariate distributions.  For example, below we sample from a simple Gaussian mixture model (a modification of the [scikit-learn](https://scikit-learn.org/stable/auto_examples/neighbors/plot_kde_1d.html) tutorial)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "id": "15e265da",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "N = 10000\n",
+    "np.random.seed(1)\n",
+    "\n",
+    "lw= 3.\n",
+    "left_mass_param = 0.3\n",
+    "right_mass_param = 1. - left_mass_param\n",
+    "X_train = np.concatenate(\n",
+    "    (np.random.normal(0, 1, int(left_mass_param * N)), \n",
+    "     np.random.normal(5, 1, int(right_mass_param * N)))\n",
+    ")[:, np.newaxis]\n",
+    "\n",
+    "X_plot = np.linspace(-5, 10, 1000)[:, np.newaxis] # Linearly spaced 
points on the interval.\n",
+    "\n",
+    "################ True density function ################\n",
+    "true_dens = left_mass_param * stats.norm(0, 1).pdf(X_plot[:, 0]) + 
right_mass_param * stats.norm(5, 1).pdf(X_plot[:, 0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "dbc04423",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def make_basic_plot():\n",
+    "    fig, ax = plt.subplots(figsize=(16,8))\n",
+    "    ax.fill(X_plot[:, 0], true_dens, fc=\"black\", alpha=0.2, 
label=\"input distribution\")\n",
+    "    lw = 2\n",
+    "\n",
+    "    ax.legend(loc=\"upper left\", title=\"N={0} points\".format(N))\n",
+    "    ax.plot(X_train[:, 0], -0.005 - 0.01 * 
np.random.random(X_train.shape[0]), \"+k\", alpha=0.1, label='samples')\n",
+    "\n",
+    "    ax.set_xlim(-4, 9)\n",
+    "    return fig, ax"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "97c19eb9",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<matplotlib.legend.Legend at 0x17fcca810>"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAABRQAAAKTCAYAAABo9IQGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3xV9f3H8ffNHmQACYQRCHvJRgFxIKLIco+qde9Zi6i1toraVqq49yhDRVx14KIqihNwYADZIwMyCcldGXf//uBH4GSRfTJez8cjj3o+95xzPwlpkvu+32EJBAIBAQAAAAAAAEAtBJndAAAAAAAAAIDWg0ARAAAAAAAAQK0RKAIAAAAAAACoNQJFAAAAAAAAALVGoAgAAAAAAACg1ggUAQAAAAAAANQagSIAAAAAAACAWgsxu4HG4Pf7lZ2drZiYGFksFrPbAQAAAAAAAFqVQCAgh8Oh7t27Kyio
 [...]
+      "text/plain": [
+       "<Figure size 1600x800 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "################ KDE functions ################\n",
+    "colors = [\"navy\"]\n",
+    "kernels = [\"gaussian\"]\n",
+    "fig, ax = make_basic_plot()\n",
+    "for color, kernel in zip(colors, kernels):\n",
+    "    kde = kernel_density(X_train, X_plot, 0.5) \n",
+    "    ax.plot(\n",
+    "        X_plot[:, 0],\n",
+    "        kde,\n",
+    "        color=color,\n",
+    "        lw=lw,\n",
+    "        linestyle=\"-\",\n",
+    "        label=\"kernel = '{0}'\".format(kernel),\n",
+    "    )\n",
+    "ax.legend()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "505777e8",
+   "metadata": {},
+   "source": [
+    "We can also plot bivariate examples.  The next cell generates samples from a bivariate Gaussian distribution."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 40,
+   "id": "c7dc0f22",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = np.linspace(-3, 3)\n",
+    "y = np.linspace(-3, 3)\n",
+    "\n",
+    "X_test, Y_test = np.meshgrid(x, y)\n",
+    "X_plot = np.c_[X_test.ravel(), Y_test.ravel()]\n",
+    "\n",
+    "# Generate input distribution\n",
+    "means = [0, 0]\n",
+    "covariances = np.array([[1., 0.4],\n",
+    "                        [0.4, 1.]])  # covariance must be symmetric\n",
+    "mvn = stats.multivariate_normal(mean=means, cov=covariances)\n",
+    "X_train = mvn.rvs(size=10000)\n",
+    "dist = mvn.pdf(X_plot)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "372e55ba",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Text(0.5, 1.0, 'Samples from Bivariate Normal')"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAiIAAAGzCAYAAAASZnxRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAB6PklEQVR4nO3deXgUZdY+/rsTks4C2QjQAVnCpoawCLKJw2YQFAV1dARlBPSLyuKCjiLOKDLMDCL+Bn0FETecdxTcFVSMA4IiGMRXiBAjCjFBJiRAFjohIQvp+v0Rq+mlqruquqq7Ork/18V1kU5315NOp+vU85znHIsgCAKIiIiIQiAi1AMgIiKi1ouBCBEREYUMAxEiIiIKGQYiREREFDIMRIiIiChkGIgQERFRyDAQISIiopBhIEJEREQhw0CEiIiIQoaBCJFGFosFTzzxRFCPeeLECdx4441o3749LBYLnnnm
 [...]
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "plt.scatter(X_train[:,0], X_train[:,1])\n",
+    "plt.title(\"Samples from Bivariate Normal\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "5505c290",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      
"/var/folders/ts/39plpy691lg3xd9rvzndrmt40000gq/T/ipykernel_20006/3788610756.py:29:
 UserWarning: set_ticklabels() should only be used with a fixed number of 
ticks, i.e. after set_ticks() or using a FixedLocator.\n",
+      "  ax_both.set_yticklabels(ylabels)\n",
+      
"/var/folders/ts/39plpy691lg3xd9rvzndrmt40000gq/T/ipykernel_20006/3788610756.py:33:
 UserWarning: set_ticklabels() should only be used with a fixed number of 
ticks, i.e. after set_ticks() or using a FixedLocator.\n",
+      "  ax_both.set_xticklabels(ylabels)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "Text(0.5, 0.98, 'True Distribution')"
+      ]
+     },
+     "execution_count": 24,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAp4AAALjCAYAAAC7ygWMAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAB2s0lEQVR4nO3dd3hUZd6H8e9Meg8hCS2hl9C7KCgERQTFipV1BXtfXV0Vyyoqroq9d9HVtfeuSFVRkS6d0FKAUEIK6cmc949AXpCWMjPPmZn7c125LkmZ+SVq5uY55zzHYVmWJQAAAMDDnKYHAAAAQGAgPAEAAOAVhCcAAAC8gvAEAACAVxCeAAAA8ArCEwAAAF5BeAIAAMArCE8AAAB4BeEJAAAAryA8AT/mcDjq/Zaenm567DqbNWvWAfOHhIQoISFBnTt31tlnn60nn3xS27ZtO+RjbNy4UQ6HQ23btvXe4Iex
 [...]
+      "text/plain": [
+       "<Figure size 800x800 with 3 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "fig = plt.figure()\n",
+    "fig.set_figheight(8)\n",
+    "fig.set_figwidth(8)\n",
+    "\n",
+    "ax_component1 = plt.subplot2grid(shape=(3, 3), loc=(0, 0), colspan=2)\n",
+    "ax_component2 = plt.subplot2grid(shape=(3, 3), loc=(1, 2), rowspan=2)\n",
+    "ax_both = plt.subplot2grid((3, 3), (1, 0), rowspan=2, colspan=2)\n",
+    " \n",
+    "# plot the marginal density along each axis\n",
+    "ax_component1.plot(x, stats.multivariate_normal.pdf(x, mean=means[0], 
cov=covariances[0][0]))\n",
+    "ax_component2.plot( stats.multivariate_normal.pdf(x, mean=means[1], 
cov=covariances[1][1]), x)\n",
+    "\n",
+    "\n",
+    "ax_both.contourf(X_test, Y_test, dist.reshape(X_test.shape), 
cmap=\"Blues\")\n",
+    "#ax_both.contour(X_train[:,0], X_train[:,1], dist)\n",
+    "#ax_both.scatter(X_train[:,0], X_train[:,1],alpha=0.1, marker='.')\n",
+    "\n",
+    " \n",
+    "plt.subplots_adjust(wspace=0, hspace=0)\n",
+    "ax_both.axis('equal')\n",
+    "\n",
+    "for a in [ax_component1, ax_component2]:\n",
+    "    a.get_xaxis().set_visible(False)\n",
+    "    a.get_yaxis().set_visible(False)\n",
+    "\n",
+    "plt.gcf().canvas.draw()\n",
+    "ylabels = ax_both.get_yticklabels()\n",
+    "ylabels[0] = ylabels[-1] = \"\"\n",
+    "ax_both.set_yticklabels(ylabels)\n",
+    "#ax_both.scatter(X_test, Y_test, alpha=0.2)\n",
+    "xlabels = ax_both.get_xticklabels()\n",
+    "xlabels[0] = xlabels[-1] = \"\"\n",
+    "ax_both.set_xticklabels(xlabels)\n",
+    "fig.suptitle('True Distribution', fontsize=16)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3eb3dd9d",
+   "metadata": {},
+   "source": [
+    "Now, let's plot the KDE evaluated over the same test grid."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "60d88ad8",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      
"/var/folders/ts/39plpy691lg3xd9rvzndrmt40000gq/T/ipykernel_20006/3462036152.py:28:
 UserWarning: set_ticklabels() should only be used with a fixed number of 
ticks, i.e. after set_ticks() or using a FixedLocator.\n",
+      "  ax_both.set_yticklabels(ylabels)\n",
+      
"/var/folders/ts/39plpy691lg3xd9rvzndrmt40000gq/T/ipykernel_20006/3462036152.py:31:
 UserWarning: set_ticklabels() should only be used with a fixed number of 
ticks, i.e. after set_ticks() or using a FixedLocator.\n",
+      "  ax_both.set_xticklabels(ylabels)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "Text(0.5, 0.98, 'Kernel Density Estimate')"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAgMAAAI1CAYAAABL+PnnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAABlIUlEQVR4nO3dd3gU1f4G8Hc3ZdMTEtI7JCTU0JGeKNIUpIqCCqiICOgVC0Uv5f5ULFjBXsCrYgMRUHrvcOk1IQECKYQkQCqpu+f3R8hKIEDK7s7Mzvt5nn0e2J3d+c5usvPmnDPnaIQQAkRERKRaWqkLICIiImkxDBAREakcwwAREZHKMQwQERGpHMMAERGRyjEMEBERqRzDABERkcoxDBAREakcwwAREZHKMQxQtcLCwqDRaLBo0aJqH8/KykL79u2h0WjQvHlzpKWlWbZAM0pOToZGo0FYWFitnhcbGwuNRlPl
 [...]
+      "text/plain": [
+       "<Figure size 600x600 with 3 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "fig = plt.figure()\n",
+    "fig.set_figheight(6)\n",
+    "fig.set_figwidth(6)\n",
+    "\n",
+    "ax_component1 = plt.subplot2grid(shape=(3, 3), loc=(0, 0), colspan=2)\n",
+    "ax_component2 = plt.subplot2grid(shape=(3, 3), loc=(1, 2), rowspan=2)\n",
+    "ax_both = plt.subplot2grid((3, 3), (1, 0), rowspan=2, colspan=2)\n",
+    " \n",
+    "# plot the marginal density along each axis\n",
+    "ax_component1.plot(x, stats.multivariate_normal.pdf(x, mean=means[0], 
cov=covariances[0][0]))\n",
+    "ax_component2.plot( stats.multivariate_normal.pdf(x, mean=means[1], 
cov=covariances[1][1]), x)\n",
+    "\n",
+    "\n",
+    "# Fit and plot the model\n",
+    "kde =  kernel_density(X_train, X_plot, 0.5).reshape(X_test.shape)\n",
+    "ax_both.contourf(X_test, Y_test, kde, cmap=\"Blues\")\n",
+    "\n",
+    "plt.subplots_adjust(wspace=0, hspace=0)\n",
+    "ax_both.axis('equal')\n",
+    "\n",
+    "for a in [ax_component1, ax_component2]:\n",
+    "    a.get_xaxis().set_visible(False)\n",
+    "    a.get_yaxis().set_visible(False)\n",
+    "\n",
+    "#plt.gcf().canvas.draw()\n",
+    "ylabels = ax_both.get_yticklabels()\n",
+    "ylabels[0] = ylabels[-1] = \"\"\n",
+    "ax_both.set_yticklabels(ylabels)\n",
+    "xlabels = ax_both.get_xticklabels()\n",
+    "xlabels[0] = xlabels[-1] = \"\"\n",
+    "ax_both.set_xticklabels(xlabels)\n",
+    "fig.suptitle('Kernel Density Estimate', fontsize=16)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "947aa055",
+   "metadata": {},
+   "source": [
+    "Plot the two side by side."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "279b5838",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Text(0.5, 1.0, 'KDE')"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAiIAAAGzCAYAAAASZnxRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1XUlEQVR4nO3de3RV5Z3/8U+IkhAugUCAWsIdL5Riu6ggFBEW96k6dFk6zowKVC0z5bJcuMYRZ2rUaRe1sFqWShXHFjpjmXGkWjt4g3Ep1lurokN1BA2CchEIIEmMEiTZvz/4nZDLOSfnsp/9PHvv92uts5ZJzjn7yWl5ziff73fvU+B5nicAAAALOtleAAAAiC+CCAAAsIYgAgAArCGIAAAAawgiAADAGoIIAACwhiACAACsIYgAAABrCCIAAMAagggyNnjwYM2fP9/4cfbs2aOCggKtX7+++Xvz589Xt27djB87
 [...]
+      "text/plain": [
+       "<Figure size 640x480 with 2 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "fig, ax = plt.subplots(ncols=2)\n",
+    "ax[0].contourf(X_test, Y_test, dist.reshape(X_test.shape), 
cmap=\"Blues\")\n",
+    "ax[1].contourf(X_test, Y_test, kde, cmap=\"Blues\")\n",
+    "ax[0].set_title(\"True Distribution\")\n",
+    "ax[1].set_title(\"KDE\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "518cf82c",
+   "metadata": {},
+   "source": [
+    "In practice, we rarely have the true distribution for comparison.  Instead, the user should vary the bandwidth and evaluate each candidate against held-out data from their dataset to choose the best value.\n",
+    "An example of this approach on real data can be seen on the classical _iris_ dataset.\n",
+    "We choose the first two features, which are `sepal_length` and `sepal_width`, respectively."
+   ]
+  },
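That held-out selection loop can be sketched directly. This is an illustration rather than the notebook's own code: it uses scikit-learn's `KernelDensity` (imported in the second notebook) in place of the `kernel_density` helper, and scores each candidate bandwidth by the log-likelihood it assigns to held-out points:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KernelDensity

X = load_iris().data[:, :2]                     # sepal length, sepal width
X_fit, X_heldout = train_test_split(X, test_size=0.3, random_state=0)

# Keep the bandwidth with the highest total held-out log-likelihood.
best_h, best_score = None, -np.inf
for h in (0.05, 0.1, 0.2, 0.5, 1.0):
    kde = KernelDensity(kernel="gaussian", bandwidth=h).fit(X_fit)
    score = kde.score(X_heldout)                # sum of log-densities
    if score > best_score:
        best_h, best_score = h, score

print(f"best bandwidth: {best_h}")
```

The candidate grid here is arbitrary; a finer grid (or `GridSearchCV`) would be used in practice.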
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "id": "4e746800",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from sklearn.datasets import load_iris"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 28,
+   "id": "4bc85d34",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "iris = load_iris()\n",
+    "X_iris = iris.data[:, :2]\n",
+    "y_iris = iris.target"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 29,
+   "id": "66ec81f8",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Text(0, 0.5, 'Sepal width')"
+      ]
+     },
+     "execution_count": 29,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAjgAAAGzCAYAAAAi6m1wAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAACgMElEQVR4nOzdd1hU19YG8Hc6I2XoKIoNEUVEsIK9gAoWMBpLVIzRWGIvMWrKTWLUm+RLLCRejQY1QRN7SWLsShQBETV2jb0rIkXaMGV9f6AjE2aUOkNZv+fhubl7zrDXGSmLM2fvV0BEBMYYY4yxSkRo7gIYY4wxxkobNziMMcYYq3S4wWGMMcZYpcMNDmOMMcYqHW5wGGOMMVbpcIPDGGOMsUqHGxzGGGOMVTrc4DDGGGOs0uEGhzHGGGOVDjc4jDHGGKt0xOYu4IX//ve/mDNnDqZMmYLFixcbPGbNmjUYOXKk
 [...]
+      "text/plain": [
+       "<Figure size 640x480 with 1 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Plot the training points\n",
+    "plt.scatter(X_iris[:, 0], X_iris[:, 1], c=y_iris, cmap=plt.cm.Set1, 
edgecolor=\"k\")\n",
+    "plt.xlabel(\"Sepal length\")\n",
+    "plt.ylabel(\"Sepal width\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d747686d",
+   "metadata": {},
+   "source": [
+    "The plotted contours represent the two-dimensional kernel density estimate.  The one-dimensional estimates for each feature are plotted on the margins."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "id": "25fd9e42",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      
"/var/folders/ts/39plpy691lg3xd9rvzndrmt40000gq/T/ipykernel_20006/975648639.py:44:
 UserWarning: set_ticklabels() should only be used with a fixed number of 
ticks, i.e. after set_ticks() or using a FixedLocator.\n",
+      "  ax_both.set_yticklabels(ylabels)\n",
+      
"/var/folders/ts/39plpy691lg3xd9rvzndrmt40000gq/T/ipykernel_20006/975648639.py:47:
 UserWarning: set_ticklabels() should only be used with a fixed number of 
ticks, i.e. after set_ticks() or using a FixedLocator.\n",
+      "  ax_both.set_xticklabels(ylabels)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "Text(0.5, 0.98, 'Kernel Density Estimate')"
+      ]
+     },
+     "execution_count": 30,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "<Figure size 640x480 with 0 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAqYAAAL3CAYAAABYqYfNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3hUZd7G8e9Meg8hCQkkhJ7QEnqviiCCKEVQUJrYsK2uvpbVtaxrb+vaRYoFCyoCCggoIL13Qi+ppJLeM+f9I5I1AkpJcibJ/bmuubg458yZ30wymXue8xSLYRgGIiIiIiIms5pdgIiIiIgIKJiKiIiIiJ1QMBURERERu6BgKiIiIiJ2QcFUREREROyCgqmIiIiI2AUFUxERERGxCwqmIiIiImIXFExFRERExC4omIpUsSZNmmCxWJg9e/Y596ekpNClSxcsFgtt27YlPj6+egusQidOnMBisdCk
 [...]
+      "text/plain": [
+       "<Figure size 800x800 with 3 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "plt.clf()\n",
+    "fig = plt.figure()\n",
+    "fig.set_figheight(8)\n",
+    "fig.set_figwidth(8)\n",
+    "\n",
+    "ax_component1 = plt.subplot2grid(shape=(3, 3), loc=(0, 0), colspan=2)\n",
+    "ax_component2 = plt.subplot2grid(shape=(3, 3), loc=(1, 2), rowspan=2)\n",
+    "ax_both = plt.subplot2grid((3, 3), (1, 0), rowspan=2, colspan=2)\n",
+    " \n",
+    "# first_kde\n",
+    "first_x = np.linspace(X_iris[:,0].min(), X_iris[:,0].max(), 
num=len(X_iris))[:, np.newaxis]\n",
+    "first_kde = kernel_density(X_iris[:,0][:, np.newaxis], first_x) \n",
+    "ax_component1.plot(first_x, first_kde)\n",
+    "\n",
+    "# Second kde\n",
+    "second_x = np.linspace(X_iris[:,1].min(), X_iris[:,1].max(), 
num=len(X_iris))[:, np.newaxis]\n",
+    "second_kde = kernel_density(X_iris[:,1][:, np.newaxis], second_x) \n",
+    "ax_component2.plot(second_kde,second_x)\n",
+    "\n",
+    "\n",
+    "# Fit and plot the model\n",
+    "buffer = 0.5\n",
+    "x = np.linspace(X_iris[:,0].min()-buffer, X_iris[:,0].max()+buffer, 
100)\n",
+    "y = np.linspace(X_iris[:,1].min()-buffer, X_iris[:,1].max()+buffer, 
100)\n",
+    "X_test, Y_test = np.meshgrid(x, y)\n",
+    "X_plot = np.c_[X_test.ravel(), Y_test.ravel()]\n",
+    "ests = kernel_density(X_iris, X_plot, 0.5).reshape(X_test.shape)\n",
+    "ax_both.contour(X_test, Y_test, ests)\n",
+    "ax_both.scatter(X_iris[:, 0], X_iris[:, 1], c=y_iris, cmap=plt.cm.Set1, 
edgecolor=\"k\")\n",
+    "ax_both.set_xlabel(\"Sepal length\")\n",
+    "ax_both.set_ylabel(\"Sepal width\")\n",
+    "\n",
+    "# automatically adjust padding horizontally as well as vertically.\n",
+    "plt.subplots_adjust(wspace=0, hspace=0)\n",
+    "ax_both.axis('equal')\n",
+    "\n",
+    "for a in [ax_component1, ax_component2]:\n",
+    "    a.get_xaxis().set_visible(False)\n",
+    "    a.get_yaxis().set_visible(False)\n",
+    "\n",
+    "plt.gcf().canvas.draw()\n",
+    "ylabels = ax_both.get_yticklabels()\n",
+    "ylabels[0] = ylabels[-1] = \"\"\n",
+    "ax_both.set_yticklabels(ylabels)\n",
+    "xlabels = ax_both.get_xticklabels()\n",
+    "xlabels[0] = xlabels[-1] = \"\"\n",
+    "ax_both.set_xticklabels(xlabels)\n",
+    "fig.suptitle('Kernel Density Estimate', fontsize=16)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f6a73ab2",
+   "metadata": {},
+   "source": [
+    "## 5. The Shape of Things to Come: Coresets\n",
+    "\n",
+    "One problem with the KDE method is that the query (test) time is high.\n",
+    "For every query point the user provides, we must evaluate $\\frac1n \\sum_{i=1}^n K(x_i - x^*)$.\n",
+    "This is a sum over $n$ vectors of length $d$, so it costs $O(nd)$ time to return the estimate.  In addition, $O(n)$ space is required for any query since the entire dataset is needed.  Clearly, this is not feasible for a large dataset, so we need a more scalable method.\n",
+    "\n",
+    "The method that we have developed is called a _coreset_: a specially chosen weighted subset of the data.  The coreset should be a small representation of the entire dataset for the kernel query.\n",
+    "\n",
+    "The idea is that a query on test point $x^*$,\n",
+    "\\begin{align}\n",
+    "q^* = \\frac1n \\sum_{i=1}^n K(x^* - x_i),\n",
+    "\\end{align}\n",
+    "should be well approximated by a weighted sum over the weight-point pairs $(w_i, v_i)$ retained in the coreset:\n",
+    "\\begin{align}\n",
+    "\\hat{q} = \\frac{1}{\\sum_{i=1}^m w_i} \\sum_{i=1}^m w_i \\, K(x^* - v_i).\n",
+    "\\end{align}"
+   ]
+  },
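To make the cost argument concrete, here is an illustrative comparison of a full-data kernel query with a weighted-subset query. The "coreset" in this sketch is just a uniform random subsample with equal weights, purely to show the shape of the computation; the `density_sketch` used below chooses its retained points and weights far more carefully:

```python
import numpy as np

def kernel_query(points, weights, x_star, h=0.5):
    # Weighted average of Gaussian kernel evaluations:
    #   sum_i w_i * K(x* - v_i) / sum_i w_i
    sq_dists = ((points - x_star) ** 2).sum(axis=1)
    k = np.exp(-0.5 * sq_dists / h**2)
    return float(weights @ k / weights.sum())

rng = np.random.default_rng(1)
X = rng.standard_normal((50_000, 2))       # n = 50,000 points in d = 2 dims
x_star = np.zeros(2)

# Full query touches all n points: O(nd) time, O(n) space.
full = kernel_query(X, np.ones(len(X)), x_star)

# A query over m << n weighted points costs only O(md).
m = 500
idx = rng.choice(len(X), size=m, replace=False)
small = kernel_query(X[idx], np.ones(m), x_star)
```

Even this naive subsample lands close to the full query; the sketch's weighted coreset aims to do so with guarantees.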
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "id": "d595e479",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasketches import density_sketch"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e177768a",
+   "metadata": {},
+   "source": [
+    "We will build a density estimator from the coreset, so we define a wrapper class around the sketch."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 32,
+   "id": "d0247f18",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CoresetKDE:\n",
+    "    def __init__(self, k_, dim_, bandwidth_=1):\n",
+    "        self.k = k_\n",
+    "        self.dim = dim_\n",
+    "        self.bandwidth = bandwidth_\n",
+    "        self.normalisation_factor = np.sqrt(2.)*self.bandwidth\n",
+    "        self.estimation_scale_factor = 
1./(np.sqrt(2*np.pi)*self.bandwidth)\n",
+    "        self.sketch = density_sketch(self.k,  self.dim)\n",
+    "    \n",
+    "    def fit(self, X_):\n",
+    "        \"\"\"Fits the coreset to the data\"\"\"\n",
+    "        for x_ in X_:\n",
+    "            self.sketch.update(x_ / self.normalisation_factor)\n",
+    "            \n",
+    "    def predict(self, X):\n",
+    "        \"\"\"Returns density estimates over array X\"\"\"\n",
+    "        predictions = np.zeros((len(X)), dtype=X.dtype)\n",
+    "        for i, x in enumerate(X):\n",
+    "            predictions[i] = 
self.estimation_scale_factor*self.sketch.get_estimate( x / 
self.normalisation_factor)\n",
+    "        return predictions\n",
+    "    \n",
+    "    def get_coreset(self):\n",
+    "        \"\"\"Returns the weighted coreset as (points, weights) arrays\"\"\"\n",
+    "        samples = np.zeros((self.sketch.get_num_retained(), self.dim))\n",
+    "        weights = np.zeros((self.sketch.get_num_retained(),))\n",
+    "        for i, pw in enumerate(self.sketch):  # each item is a (point, weight) pair\n",
+    "            samples[i] = pw[0]\n",
+    "            weights[i] = pw[1]\n",
+    "        # Undo the bandwidth normalisation applied in fit()\n",
+    "        return samples*self.normalisation_factor, weights"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5f32708f",
+   "metadata": {},
+   "source": [
+    "We plot the original dataset in white.  The coreset is smaller than the original input and has a weight associated with each point, indicated by the colorbar.  Only the coloured points are retained by the coreset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "id": "5ba2fa78",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<matplotlib.legend.Legend at 0x30e78d050>"
+      ]
+     },
+     "execution_count": 33,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAf4AAAGiCAYAAAAGI6SpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAABqoUlEQVR4nO3deVxU5f4H8M+ZAQYEBndERSHFNSwtSzBcbpK5Xci05Wdp273dMJcyTK17tUUx0SwrbLtpXTNvKmKR5lVT08TcDdxSc0EFXNhRtjnP7w9jamQ7M3OY9fN+vc4fHJ5zzvOcmcPDebavJIQQICIiIregsXcGiIiIyHZY8RMREbkRVvxERERuhBU/ERGRG2HFT0RE5EZY8RMREbkRVvxERERuhBU/ERGRG2HFT0RE5EZY8RMREbkRVvxEREROwGAw4J///CdCQ0Ph4+ODDh064I033oC5K+97NFD+iIiI
 [...]
+      "text/plain": [
+       "<Figure size 640x480 with 2 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "bandwidth = 0.5\n",
+    "c = CoresetKDE(10, 2, bandwidth)\n",
+    "c.fit(X_iris)\n",
+    "coreset_points, coreset_weights = c.get_coreset()\n",
+    "\n",
+    "fig, ax = plt.subplots()\n",
+    "ax.scatter(X_iris[:, 0], X_iris[:, 1], c=\"white\", edgecolor=\"k\")\n",
+    "sc = ax.scatter(coreset_points[:, 0], coreset_points[:, 1], \n",
+    "                c=coreset_weights,cmap=\"coolwarm\", edgecolor=None, 
label=f\"Coreset size:{c.sketch.get_num_retained()}\")\n",
+    "plt.colorbar(sc)\n",
+    "ax.legend()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "70dbe9d3",
+   "metadata": {},
+   "source": [
+    "We can visualise the two estimates as follows.  First we define a test grid over the input domain; then we evaluate both the full KDE and the coreset estimator over it."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 34,
+   "id": "5b08682e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "buffer = 0.5\n",
+    "x = np.linspace(X_iris[:,0].min()-buffer, X_iris[:,0].max()+buffer, 
100)\n",
+    "y = np.linspace(X_iris[:,1].min()-buffer, X_iris[:,1].max()+buffer, 
100)\n",
+    "X_test, Y_test = np.meshgrid(x, y)\n",
+    "X_plot = np.c_[X_test.ravel(), Y_test.ravel()]\n",
+    "\n",
+    "\n",
+    "ests = kernel_density(X_iris, X_plot, bandwidth)\n",
+    "ests = ests.reshape(X_test.shape)\n",
+    "coreset_ests = c.predict(X_plot).reshape(X_test.shape)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
+   "id": "b5eefd9b",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      
"/var/folders/ts/39plpy691lg3xd9rvzndrmt40000gq/T/ipykernel_20006/3989765740.py:4:
 UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will 
be ignored\n",
+      "  ax_kde.scatter(X_iris[:, 0], X_iris[:, 1], c=\"white\", 
cmap=plt.cm.Set1, edgecolor=\"k\")\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAA8UAAAGJCAYAAACjN5FSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzddXQUV/vA8e9u3N1diSIJGry4OxXeQvVHS52+FepGqVBXqm8N2uJOgOIQSJAEQtzd3bPy+yNlSwqUJE12I/dzzh5OZ2dn7qY7c+e58lyJUqlUIgiCIAiCIAiCIAh9kFTTBRAEQRAEQRAEQRAETRFBsSAIgiAIgiAIgtBniaBYEARBEARBEARB6LNEUCwIgiAIgiAIgiD0WSIoFgRBEARBEARBEPosERQLgiAIgiAIgiAIfZYIigVBEARBEARBEIQ+SwTFgiAIgiAIgiAIQp8lgmJBEARBEARBEASh
 [...]
+      "text/plain": [
+       "<Figure size 1200x400 with 3 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "fig, ax = plt.subplots(figsize=(12,4),ncols=2)\n",
+    "ax_kde, ax_coreset = ax\n",
+    "ax_kde.contour(X_test, Y_test, ests)\n",
+    "ax_kde.scatter(X_iris[:, 0], X_iris[:, 1], c=\"white\", cmap=plt.cm.Set1, 
edgecolor=\"k\")  \n",
+    "ax_coreset.contour(X_test, Y_test, coreset_ests)\n",
+    "ax_coreset.scatter(X_iris[:, 0], X_iris[:, 1], c=\"white\", 
edgecolor=\"k\")\n",
+    "sc = ax_coreset.scatter(coreset_points[:, 0], coreset_points[:, 1], \n",
+    "                c=coreset_weights, cmap=\"coolwarm\", edgecolor=None, 
label=f\"Coreset size:{c.sketch.get_num_retained()}\")\n",
+    "plt.colorbar(sc)\n",
+    "ax_kde.set_title(\"KDE\")\n",
+    "ax_coreset.set_title(\"Coreset\")\n",
+    "ax_kde.set_xlabel(\"Sepal length\")\n",
+    "ax_kde.set_ylabel(\"Sepal width\")\n",
+    "plt.subplots_adjust(wspace=0, hspace=0)\n",
+    "\n",
+    "ax_coreset.get_yaxis().set_visible(False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "30bc6ee9",
+   "metadata": {},
+   "source": [
+    "### Conclusion\n",
+    "\n",
+    "**Objective 1**: To understand exactly what function a kernel density 
estimate is approximating.\n",
+    "\n",
+    "We have seen that the KDE is a function evaluated over a finite sample 
from a distribution.  \n",
+    "In practice this distribution is likely unknown, so the KDE may not have the same scale as the true distribution, but it should resemble its shape.\n",
+    "\n",
+    "\n",
+    "**Objective 2**: To understand the parameters that affect the performance of a KDE.\n",
+    "\n",
+    "We have seen that the KDE converges quickly to the convolution function as the sample size increases.  \n",
+    "When we know the distribution from which the data is sampled, we can also measure how far the convolution form is from the true distribution; we saw that decreasing the bandwidth while increasing the sample size allows the KDE to approach the true distribution.\n",
+    "In practice, we will likely not know the distribution from which a fixed-size sample is drawn, so candidate bandwidths must be evaluated empirically to choose the best option."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
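The bandwidth discussion in the conclusion above can be made concrete with a few lines of standalone Python. This is a minimal 1-d sketch (the sample size, bandwidths and seed are illustrative choices, not values from the notebook): an over-large bandwidth flattens the peak of a standard normal, while a smaller bandwidth with a large sample lands close to the true density.

```python
import math
import random

def gaussian_kde_1d(samples, x_star, h):
    # f_hat(x*) = (1 / (n * h * sqrt(2*pi))) * sum_i exp(-(x* - x_i)^2 / (2*h^2))
    n = len(samples)
    s = sum(math.exp(-((x_star - xi) ** 2) / (2.0 * h * h)) for xi in samples)
    return s / (n * h * math.sqrt(2.0 * math.pi))

random.seed(0)
samples = [random.gauss(0.0, 1.0) for _ in range(5000)]

true_peak = 1.0 / math.sqrt(2.0 * math.pi)   # standard normal density at 0
wide = gaussian_kde_1d(samples, 0.0, h=2.0)  # over-smoothed: flattens the peak
narrow = gaussian_kde_1d(samples, 0.0, h=0.2)

print(f"true f(0) = {true_peak:.3f}, h=2.0 -> {wide:.3f}, h=0.2 -> {narrow:.3f}")
```

The wide-bandwidth estimate is close to the convolution of the true density with a much wider kernel, which is why it undershoots the peak.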
diff --git a/jupyter/density_sketch/2-density-sketch-visualization.ipynb 
b/jupyter/density_sketch/2-density-sketch-visualization.ipynb
new file mode 100644
index 0000000..8e4145d
--- /dev/null
+++ b/jupyter/density_sketch/2-density-sketch-visualization.ipynb
@@ -0,0 +1,543 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "badd26a1",
+   "metadata": {},
+   "source": [
+    "# Visualising Species Distribution\n",
+    "\n",
+    "This is an example of how coresets can be used to generate the species 
distribution visualisation from 
[scikit-learn](https://scikit-learn.org/stable/auto_examples/neighbors/plot_species_kde.html#sphx-glr-auto-examples-neighbors-plot-species-kde-py).\n",
+    "Much of this code has been lifted from the relevant sklearn page and then 
adapted for use with a coreset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "0624fa23",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import matplotlib.pyplot as plt\n",
+    "from sklearn.datasets import fetch_species_distributions\n",
+    "from sklearn.metrics.pairwise import rbf_kernel\n",
+    "from sklearn.neighbors import KernelDensity\n",
+    "from datasketches import density_sketch\n",
+    "%matplotlib inline\n",
+    "# if basemap is available, we'll use it.\n",
+    "# otherwise, we'll improvise later...\n",
+    "try:\n",
+    "    from mpl_toolkits.basemap import Basemap\n",
+    "\n",
+    "    basemap = True\n",
+    "except ImportError:\n",
+    "    basemap = False"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "ff796ec0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def construct_grids(batch):\n",
+    "    \"\"\"Construct the map grid from the batch object\n",
+    "\n",
+    "    Parameters\n",
+    "    ----------\n",
+    "    batch : Batch object\n",
+    "        The object returned by :func:`fetch_species_distributions`\n",
+    "\n",
+    "    Returns\n",
+    "    -------\n",
+    "    (xgrid, ygrid) : 1-D arrays\n",
+    "        The grid corresponding to the values in batch.coverages\n",
+    "    \"\"\"\n",
+    "    # x,y coordinates for corner cells\n",
+    "    xmin = batch.x_left_lower_corner + batch.grid_size\n",
+    "    xmax = xmin + (batch.Nx * batch.grid_size)\n",
+    "    ymin = batch.y_left_lower_corner + batch.grid_size\n",
+    "    ymax = ymin + (batch.Ny * batch.grid_size)\n",
+    "\n",
+    "    # x coordinates of the grid cells\n",
+    "    xgrid = np.arange(xmin, xmax, batch.grid_size)\n",
+    "    # y coordinates of the grid cells\n",
+    "    ygrid = np.arange(ymin, ymax, batch.grid_size)\n",
+    "\n",
+    "    return (xgrid, ygrid)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "730603b1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Get matrices/arrays of species IDs and locations\n",
+    "# if fetch_species_distributions() call does not work you may need \n",
+    "# import ssl\n",
+    "# ssl._create_default_https_context = ssl._create_unverified_context\n",
+    "\n",
+    "data = fetch_species_distributions()\n",
+    "species_names = [\"Bradypus Variegatus\", \"Microryzomys Minutus\"]\n",
+    "\n",
+    "Xtrain = np.vstack([data[\"train\"][\"dd lat\"], data[\"train\"][\"dd 
long\"]]).T\n",
+    "ytrain = np.array(\n",
+    "    [d.decode(\"ascii\").startswith(\"micro\") for d in 
data[\"train\"][\"species\"]],\n",
+    "    dtype=\"int\",\n",
+    ")\n",
+    "Xtrain *= np.pi / 180.0  # Convert lat/long to radians"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "df7329d2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set up the data grid for the contour plot\n",
+    "xgrid, ygrid = construct_grids(data)\n",
+    "X, Y = np.meshgrid(xgrid[::5], ygrid[::5][::-1])\n",
+    "land_reference = data.coverages[6][::5, ::5]\n",
+    "land_mask = (land_reference > -9999).ravel()\n",
+    "\n",
+    "xy = np.vstack([Y.ravel(), X.ravel()]).T\n",
+    "xy = xy[land_mask]\n",
+    "xy *= np.pi / 180.0"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "16348ec1",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " - computing KDE in spherical coordinates\n",
+      " - plot coastlines from coverage\n",
+      " - computing KDE in spherical coordinates\n",
+      " - plot coastlines from coverage\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlQAAAGbCAYAAAACzg7VAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAADO7ElEQVR4nOydd1gTWRfG30AIvYsooqgoKmBXLKiAoCj2rmvva6/fquu6a++9rF3Xjr2jWCiKZe1l7YqKHaR3CLnfHzFDJj0hIYD39zzzaGbu3LmZkJN3zj3nXA4hhIBCoVAoFAqFojEG+h4AhUKhUCgUSnGHCioKhUKhUCiUAkIFFYVCoVAoFEoBoYKKQqFQKBQKpYBQQUWhUCgUCoVSQKigolAoFAqFQikgVFBRKBQKhUKhFBAqqCgUCoVCoVAKCBVUFAqFQqFQKAWECqoixLt378DhcPDPP//oeyjFhtmzZ4PD
 [...]
+      "text/plain": [
+       "<Figure size 640x480 with 2 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Plot map of South America with distributions of each species\n",
+    "fig = plt.figure()\n",
+    "fig.subplots_adjust(left=0.05, right=0.95, wspace=0.05)\n",
+    "\n",
+    "for i in range(2):\n",
+    "    plt.subplot(1, 2, i + 1)\n",
+    "\n",
+    "    # construct a kernel density estimate of the distribution\n",
+    "    print(\" - computing KDE in spherical coordinates\")\n",
+    "    kde = KernelDensity(\n",
+    "        bandwidth=0.04, metric=\"haversine\", kernel=\"gaussian\", 
algorithm=\"ball_tree\"\n",
+    "    )\n",
+    "    kde.fit(Xtrain[ytrain == i])\n",
+    "\n",
+    "    # evaluate only on the land: -9999 indicates ocean\n",
+    "    Z = np.full(land_mask.shape[0], -9999, dtype=\"int\")\n",
+    "    Z[land_mask] = np.exp(kde.score_samples(xy))\n",
+    "    Z = Z.reshape(X.shape)\n",
+    "\n",
+    "    # plot contours of the density\n",
+    "    levels = np.linspace(0, Z.max(), 25)\n",
+    "    plt.contourf(X, Y, Z, levels=levels, cmap=plt.cm.Reds)\n",
+    "\n",
+    "    if basemap:\n",
+    "        print(\" - plot coastlines using basemap\")\n",
+    "        m = Basemap(\n",
+    "            projection=\"cyl\",\n",
+    "            llcrnrlat=Y.min(),\n",
+    "            urcrnrlat=Y.max(),\n",
+    "            llcrnrlon=X.min(),\n",
+    "            urcrnrlon=X.max(),\n",
+    "            resolution=\"c\",\n",
+    "        )\n",
+    "        m.drawcoastlines()\n",
+    "        m.drawcountries()\n",
+    "    else:\n",
+    "        print(\" - plot coastlines from coverage\")\n",
+    "        plt.contour(\n",
+    "            X, Y, land_reference, levels=[-9998], colors=\"k\", 
linestyles=\"solid\"\n",
+    "        )\n",
+    "        plt.xticks([])\n",
+    "        plt.yticks([])\n",
+    "\n",
+    "    plt.title(species_names[i])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7b84aedb",
+   "metadata": {},
+   "source": [
+    "This method uses the Haversine distance.  We have only implemented the Euclidean distance, so we will use it for the comparisons."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "4442f401",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " - computing KDE with the Euclidean metric\n",
+      " - plot coastlines from coverage\n",
+      " - computing KDE with the Euclidean metric\n",
+      " - plot coastlines from coverage\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlQAAAGbCAYAAAACzg7VAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAADOCElEQVR4nOydd1gTWRfG3xAIoTcRRRSFtQFW7KiAYO9d197XXndd3bV31967a8eGvWChWdcu+9kVewPpPYHc74/sjOmFJATw/p5nHs3MnTs3Q3LyzrnnnsMhhBBQKBQKhUKhUPKNibEHQKFQKBQKhVLUoYKKQqFQKBQKRUeooKJQKBQKhULRESqoKBQKhUKhUHSECioKhUKhUCgUHaGCikKhUCgUCkVHqKCiUCgUCoVC0REqqCgUCoVCoVB0hAoqCoVCoVAoFB2hgqoQ8ebNG3A4HPz999/GHkqRYfbs2eBwOMYe
 [...]
+      "text/plain": [
+       "<Figure size 640x480 with 2 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Plot map of South America with distributions of each species\n",
+    "fig = plt.figure()\n",
+    "fig.subplots_adjust(left=0.05, right=0.95, wspace=0.05)\n",
+    "\n",
+    "for i in range(2):\n",
+    "    plt.subplot(1, 2, i + 1)\n",
+    "\n",
+    "    # construct a kernel density estimate of the distribution\n",
+    "    print(\" - computing KDE with the Euclidean metric\")\n",
+    "    kde = KernelDensity(\n",
+    "        bandwidth=0.04, metric=\"euclidean\", kernel=\"gaussian\", 
algorithm=\"ball_tree\"\n",
+    "    )\n",
+    "    kde.fit(Xtrain[ytrain == i])\n",
+    "\n",
+    "    # evaluate only on the land: -9999 indicates ocean\n",
+    "    Z = np.full(land_mask.shape[0], -9999, dtype=\"int\")\n",
+    "    Z[land_mask] = np.exp(kde.score_samples(xy))\n",
+    "    Z = Z.reshape(X.shape)\n",
+    "\n",
+    "    # plot contours of the density\n",
+    "    levels = np.linspace(0, Z.max(), 25)\n",
+    "    plt.contourf(X, Y, Z, levels=levels, cmap=plt.cm.Blues)\n",
+    "\n",
+    "    if basemap:\n",
+    "        print(\" - plot coastlines using basemap\")\n",
+    "        m = Basemap(\n",
+    "            projection=\"cyl\",\n",
+    "            llcrnrlat=Y.min(),\n",
+    "            urcrnrlat=Y.max(),\n",
+    "            llcrnrlon=X.min(),\n",
+    "            urcrnrlon=X.max(),\n",
+    "            resolution=\"c\",\n",
+    "        )\n",
+    "        m.drawcoastlines()\n",
+    "        m.drawcountries()\n",
+    "    else:\n",
+    "        print(\" - plot coastlines from coverage\")\n",
+    "        plt.contour(\n",
+    "            X, Y, land_reference, levels=[-9998], colors=\"k\", 
linestyles=\"solid\"\n",
+    "        )\n",
+    "        plt.xticks([])\n",
+    "        plt.yticks([])\n",
+    "\n",
+    "    plt.title(species_names[i])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "109b8027",
+   "metadata": {},
+   "source": [
+    "Now we will define our own kernel function so that we can check the density sketch properly.  The scikit-learn implementation computes the Gaussian kernel density estimate\n",
+    "\\begin{align}\n",
+    "\\hat{f}(x^*) = \\frac{1}{(2 \\pi)^{d/2} h^d} \\cdot \\frac1n \\sum_{i=1}^n \\exp\\left( -\\frac{\\| x^* - x_i \\|^2}{2 h^2} \\right).\n",
+    "\\end{align}\n",
+    "\n",
+    "However, scikit-learn uses a tree-based implementation: points far from the test point are grouped by leaf so that they contribute a common value to the sum, which reduces the total number of kernel evaluations.\n",
+    "Unlike the tree-based implementation, this coreset implementation evaluates the full sum over the retained points."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "4cb9e222",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def kernel_density(Xtrain, Xtest, bandwidth=1.):\n",
+    "    \"\"\"\n",
+    "    Returns the kernel density estimate of Xtest against Xtrain:\n",
+    "        (1/n) * (1/(bandwidth*sqrt(2*pi)))^d * sum_{i=1}^n exp(-||x* - x_i||^2 / (2*bandwidth^2))\n",
+    "    rbf_kernel takes gamma = 1/(2*bandwidth^2), and the mean over axis 0 supplies the 1/n factor.\n",
+    "    \"\"\"\n",
+    "    # reshape 1-d inputs to column vectors (ndarray.reshape returns a copy)\n",
+    "    if Xtrain.ndim == 1:\n",
+    "        Xtrain = Xtrain.reshape(-1, 1)\n",
+    "    if Xtest.ndim == 1:\n",
+    "        Xtest = Xtest.reshape(-1, 1)\n",
+    "    assert Xtrain.shape[1] == Xtest.shape[1]\n",
+    "    d = Xtest.shape[1]\n",
+    "    g = (1. / bandwidth) ** 2\n",
+    "    K = rbf_kernel(Xtrain, Xtest, gamma=0.5 * g)\n",
+    "    K /= (bandwidth * np.sqrt(2. * np.pi)) ** d\n",
+    "    return np.mean(K, axis=0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "9b6dc25a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      " - computing KDE with the Euclidean metric\n",
+      " - plot coastlines from coverage\n",
+      " - computing KDE with the Euclidean metric\n",
+      " - plot coastlines from coverage\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlQAAAGbCAYAAAACzg7VAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAADOCElEQVR4nOydd1gTWRfG3xAIoTcRRRSFtQFW7KiAYO9d197XXndd3bV31967a8eGvWChWdcu+9kVewPpPYHc74/sjOmFJATw/p5nHs3MnTs3Q3LyzrnnnsMhhBBQKBQKhUKhUPKNibEHQKFQKBQKhVLUoYKKQqFQKBQKRUeooKJQKBQKhULRESqoKBQKhUKhUHSECioKhUKhUCgUHaGCikKhUCgUCkVHqKCiUCgUCoVC0REqqCgUCoVCoVB0hAoqCoVCoVAoFB2hgqoQ8ebNG3A4HPz999/GHkqRYfbs2eBwOMYe
 [...]
+      "text/plain": [
+       "<Figure size 640x480 with 2 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Plot map of South America with distributions of each species\n",
+    "fig = plt.figure()\n",
+    "fig.subplots_adjust(left=0.05, right=0.95, wspace=0.05)\n",
+    "\n",
+    "for i in range(2):\n",
+    "    plt.subplot(1, 2, i + 1)\n",
+    "\n",
+    "    # construct a kernel density estimate with our own kernel_density\n",
+    "    print(\" - computing KDE with the Euclidean metric\")\n",
+    "    my_kde = kernel_density(Xtrain[ytrain == i], xy, bandwidth=0.04)\n",
+    "\n",
+    "    # evaluate only on the land: -9999 indicates ocean\n",
+    "    Z = np.full(land_mask.shape[0], -9999, dtype=\"int\")\n",
+    "    Z[land_mask] = my_kde \n",
+    "    Z = Z.reshape(X.shape)\n",
+    "\n",
+    "    # plot contours of the density\n",
+    "    levels = np.linspace(0, Z.max(), 25)\n",
+    "    plt.contourf(X, Y, Z, levels=levels, cmap=plt.cm.Blues)\n",
+    "\n",
+    "    if basemap:\n",
+    "        print(\" - plot coastlines using basemap\")\n",
+    "        m = Basemap(\n",
+    "            projection=\"cyl\",\n",
+    "            llcrnrlat=Y.min(),\n",
+    "            urcrnrlat=Y.max(),\n",
+    "            llcrnrlon=X.min(),\n",
+    "            urcrnrlon=X.max(),\n",
+    "            resolution=\"c\",\n",
+    "        )\n",
+    "        m.drawcoastlines()\n",
+    "        m.drawcountries()\n",
+    "    else:\n",
+    "        print(\" - plot coastlines from coverage\")\n",
+    "        plt.contour(\n",
+    "            X, Y, land_reference, levels=[-9998], colors=\"k\", 
linestyles=\"solid\"\n",
+    "        )\n",
+    "        plt.xticks([])\n",
+    "        plt.yticks([])\n",
+    "\n",
+    "    plt.title(species_names[i])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "766d889f",
+   "metadata": {},
+   "source": [
+    "This looks almost identical to the sklearn implementation.  Now let's use 
a coreset and compare the two."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "13f60986",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CoresetKDE:\n",
+    "    def __init__(self, k_, dim_, bandwidth_=1):\n",
+    "        self.k = k_\n",
+    "        self.dim = dim_\n",
+    "        self.bandwidth = bandwidth_\n",
+    "        self.normalisation_factor = np.sqrt(2.)*self.bandwidth\n",
+    "        self.estimation_scale_factor = 
1./(self.bandwidth*np.sqrt(2.*np.pi))**self.dim\n",
+    "        self.sketch = density_sketch(self.k,  self.dim)\n",
+    "    \n",
+    "    def fit(self, X_):\n",
+    "        \"\"\"Fits the coreset to the data\"\"\"\n",
+    "        for x_ in X_:\n",
+    "            self.sketch.update(x_ / self.normalisation_factor)\n",
+    "            \n",
+    "    def predict(self, X):\n",
+    "        \"\"\"Returns density estimates over array X\"\"\"\n",
+    "        predictions = np.zeros((len(X)), dtype=X.dtype)\n",
+    "        for i, x in enumerate(X):\n",
+    "            predictions[i] = 
self.estimation_scale_factor*self.sketch.get_estimate( x / 
self.normalisation_factor)\n",
+    "        return predictions\n",
+    "    \n",
+    "    def get_coreset(self):\n",
+    "        \"\"\"Returns the weighted coreset\"\"\"\n",
+    "        samples = np.zeros((self.sketch.get_num_retained(), self.dim))\n",
+    "        weights = np.zeros((self.sketch.get_num_retained(),))\n",
+    "        for i, pw in enumerate(self.sketch):\n",
+    "            samples[i] = pw[0]\n",
+    "            weights[i] = pw[1]\n",
+    "        return samples*self.normalisation_factor, weights"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "8de799c7",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Coreset size: 81 of 926\n",
+      " - plot coastlines from coverage\n",
+      "Coreset size: 90 of 698\n",
+      " - plot coastlines from coverage\n"
+     ]
+    },
+    {
+     "data": {
+      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAlQAAAGbCAYAAAACzg7VAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAADQIUlEQVR4nOydd1gTWRfG3xAIoYOIKKKgrFjAhooFCwgWsHdde1+7q7tr2XXX3tbeu2vvvSIKgljWLn52RcUO0nvL/f7IZkyZ9IQEvL/nmUczc+fOzYScvHPuuedwCCEEFAqFQqFQKBSNMTH0ACgUCoVCoVCKO1RQUSgUCoVCoWgJFVQUCoVCoVAoWkIFFYVCoVAoFIqWUEFFoVAoFAqFoiVUUFEoFAqFQqFoCRVUFAqFQqFQKFpCBRWFQqFQKBSKllBBRaFQKBQKhaIlVFAZEW/evAGHw8E///xj6KEUG2bOnAkO
 [...]
+      "text/plain": [
+       "<Figure size 640x480 with 2 Axes>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "# Plot map of South America with distributions of each species\n",
+    "fig = plt.figure()\n",
+    "fig.subplots_adjust(left=0.05, right=0.95, wspace=0.05)\n",
+    "\n",
+    "for i in range(2):\n",
+    "    plt.subplot(1, 2, i + 1)\n",
+    "\n",
+    "    # construct a kernel density estimate of the distribution using the 
coreset.\n",
+    "    c = CoresetKDE(16, 2, bandwidth_=0.04)\n",
+    "    c.fit(Xtrain[ytrain==i])\n",
+    "    c_ests = c.predict(xy)\n",
+    "    print(f\"Coreset size: {c.sketch.get_num_retained()} of 
{c.sketch.get_n()}\")\n",
+    "    #coreset_points, coreset_weights = c.get_coreset()\n",
+    "    # evaluate only on the land: -9999 indicates ocean\n",
+    "    Z = np.full(land_mask.shape[0], -9999, dtype=\"int\")\n",
+    "    Z[land_mask] = c_ests \n",
+    "    Z = Z.reshape(X.shape)\n",
+    "\n",
+    "    # plot contours of the density\n",
+    "    levels = np.linspace(0, Z.max(), 25)\n",
+    "    plt.contourf(X, Y, Z, levels=levels, cmap=plt.cm.Blues)\n",
+    "\n",
+    "    if basemap:\n",
+    "        print(\" - plot coastlines using basemap\")\n",
+    "        m = Basemap(\n",
+    "            projection=\"cyl\",\n",
+    "            llcrnrlat=Y.min(),\n",
+    "            urcrnrlat=Y.max(),\n",
+    "            llcrnrlon=X.min(),\n",
+    "            urcrnrlon=X.max(),\n",
+    "            resolution=\"c\",\n",
+    "        )\n",
+    "        m.drawcoastlines()\n",
+    "        m.drawcountries()\n",
+    "    else:\n",
+    "        print(\" - plot coastlines from coverage\")\n",
+    "        plt.contour(\n",
+    "            X, Y, land_reference, levels=[-9998], colors=\"k\", 
linestyles=\"solid\"\n",
+    "        )\n",
+    "        plt.xticks([])\n",
+    "        plt.yticks([])\n",
+    "\n",
+    "    plt.title(species_names[i])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a49eec84",
+   "metadata": {},
+   "source": [
+    "The coreset is very similar to the brute force approach but is evaluated 
over many fewer points."
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
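The weighted-coreset evaluation used by `CoresetKDE` can be sketched in plain Python. The real density sketch chooses its retained points and weights by compaction; here a uniform subsample with equal weights serves as a crude stand-in (an assumption for illustration only) to show how a weighted KDE is evaluated over far fewer points than the full data.

```python
import math
import random

def weighted_kde(points, weights, x_star, h):
    # f_hat(x*) = (1 / (W * h * sqrt(2*pi))) * sum_i w_i * exp(-(x* - p_i)^2 / (2*h^2))
    total_w = sum(weights)
    s = sum(w * math.exp(-((x_star - p) ** 2) / (2.0 * h * h))
            for p, w in zip(points, weights))
    return s / (total_w * h * math.sqrt(2.0 * math.pi))

random.seed(1)
data = [random.gauss(0.0, 1.0) for _ in range(4000)]

# Full KDE: every point carries unit weight.
full = weighted_kde(data, [1.0] * len(data), 0.0, h=0.3)

# "Coreset": a uniform subsample where each retained point stands in for
# len(data)/m originals -- a crude stand-in for the sketch's weighted sample.
m = 200
core = random.sample(data, m)
core_w = [len(data) / m] * m
approx = weighted_kde(core, core_w, 0.0, h=0.3)

print(f"full KDE at 0: {full:.3f}, coreset KDE at 0: {approx:.3f}")
```

The point of the sketch-based coreset is that it gives error guarantees for this kind of weighted evaluation while a uniform subsample does not; the evaluation formula, however, is the same.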
diff --git a/jupyter/density_sketch/kernel_density.py 
b/jupyter/density_sketch/kernel_density.py
new file mode 100644
index 0000000..214f45d
--- /dev/null
+++ b/jupyter/density_sketch/kernel_density.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+from scipy.spatial.distance import cdist
+from scipy.stats import multivariate_normal
+from sklearn.neighbors import KernelDensity
+from sklearn.datasets import make_blobs
+
+def kernel_density(Xtrain, Xtest, bandwidth=1.):
+    """
+    Returns the kernel density estimate of Xtest against Xtrain:
+        (1/n) * (1/(bandwidth*sqrt(2*pi)))^d * sum_{i=1}^n exp(-||x* - x_i||^2 / (2*bandwidth^2))
+    cdist gives squared Euclidean distances, scaled by g/2 = 1/(2*bandwidth^2);
+    the mean over axis 0 supplies the 1/n factor.
+    """
+    # reshape 1-d inputs to column vectors (ndarray.reshape returns a copy)
+    if Xtrain.ndim == 1:
+        Xtrain = Xtrain.reshape(-1, 1)
+    if Xtest.ndim == 1:
+        Xtest = Xtest.reshape(-1, 1)
+    g = (1./bandwidth)**2
+    K = np.exp(-cdist(Xtrain, Xtest, metric='sqeuclidean')*g/2)
+    K *= 1./(bandwidth*np.sqrt(2*np.pi))**Xtrain.shape[1]
+    return np.mean(K, axis=0)
+
+def test_kernel_density():
+    # Generate random data
+    np.random.seed(0)
+    X, _ = make_blobs(n_samples=1000, centers=3, n_features=2, random_state=0)
+
+    # Reference densities: a single Gaussian fitted to the sample moments
+    # (a coarse reference, since the data are a mixture of three blobs)
+    true_densities = multivariate_normal.pdf(X, mean=X.mean(axis=0), cov=X.var(axis=0))
+
+    # Estimated densities from kernel_density
+    estimated_densities = kernel_density(X, X, bandwidth=1.)
+
+    # Check that the estimated densities are close to the reference densities
+    assert np.allclose(estimated_densities, true_densities, atol=0.05)
+
+    # Check that kernel_density matches sklearn's KernelDensity
+    kde = KernelDensity(bandwidth=1., kernel='gaussian').fit(X)
+    estimated_densities_sklearn = np.exp(kde.score_samples(X))
+    assert np.allclose(estimated_densities, estimated_densities_sklearn, atol=0.05)
+
+
+if __name__ == "__main__":
+    test_kernel_density()
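A further sanity check on the normalisation used in `kernel_density`: a properly normalised KDE should integrate to one over its domain. The stdlib-only 1-d version below mirrors the same Gaussian-kernel formula (the training points, bandwidth and integration grid are illustrative choices):

```python
import math

def kde_1d(train, x_star, h):
    # Same Gaussian-kernel formula as kernel_density, restricted to 1-d.
    n = len(train)
    s = sum(math.exp(-((x_star - t) ** 2) / (2.0 * h * h)) for t in train)
    return s / (n * h * math.sqrt(2.0 * math.pi))

train = [-1.0, -0.2, 0.1, 0.4, 1.3]
h = 0.5

# Riemann sum of the KDE over a wide grid: should be close to 1.
step = 0.01
grid = [-8.0 + step * k for k in range(int(16.0 / step))]
mass = sum(kde_1d(train, x, h) for x in grid) * step
print(f"total probability mass ~= {mass:.4f}")
```

If the `(bandwidth*sqrt(2*pi))^d` factor were dropped or misplaced, this mass would come out as a multiple of one rather than one, which is exactly the kind of error the check catches.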
diff --git a/jupyter/density_sketch/naive_bayes_classifier.py 
b/jupyter/density_sketch/naive_bayes_classifier.py
new file mode 100644
index 0000000..0af10e7
--- /dev/null
+++ b/jupyter/density_sketch/naive_bayes_classifier.py
@@ -0,0 +1,135 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+from sklearn.naive_bayes import GaussianNB
+from sklearn.metrics import accuracy_score
+from scipy.stats import norm
+from sklearn.neighbors import KernelDensity
+import matplotlib.pyplot as plt
+from kernel_density import kernel_density
+from datasketches import density_sketch
+
+class NaiveBayes:
+    def __init__(self, kernel='gaussian', bandwidth=1.0, coreset=False, 
coreset_k=None, coreset_dim=None):
+        self.kernel = kernel
+        self.bandwidth = bandwidth
+        self.class_priors = None
+        self.class_kdes = None
+        self.using_coreset = coreset
+        if self.using_coreset:
+            self.class_coresets = {}
+            self.k = coreset_k
+            self.d = coreset_dim
+
+    def fit(self, X, y):
+        self.X = X
+        self.y = y
+        self.class_priors = {}
+        self.class_kdes = {}
+
+        for c in np.unique(y):
+            X_c = X[y == c]
+            self.class_priors[c] = len(X_c) / len(X)
+            self.class_kdes[c] = []
+
+            if self.using_coreset:
+                self.class_coresets[c] = []
+                for i in range(X.shape[1]):
+                    # perform 1d density estimation over every feature:
+                    # each feature gets its own sketch
+                    coreset = density_sketch(self.k, 1)
+                    for sample in X_c[:, i]:
+                        coreset.update([sample / self.bandwidth])
+                    self.class_coresets[c].append(coreset)
+
+    def predict_proba(self, X):
+        posteriors = []
+        for i in range(X.shape[0]):
+            joints = []
+            for c in self.class_priors:
+                prior = self.class_priors[c]
+                if self.using_coreset:
+                    kde = 1.
+                    for j in range(X.shape[1]):
+                        kde *= self.class_coresets[c][j].get_estimate([X[i, j] / self.bandwidth])
+                else:
+                    kde = kernel_density(self.X[self.y == c], X[i, :].reshape(1, -1), self.bandwidth)[0]
+                joints.append(prior * kde)
+            # normalise the joint probabilities (prior * likelihood) into posteriors
+            posterior = np.array(joints) / np.sum(joints)
+            posteriors.append(posterior)
+        return np.array(posteriors)
+
+    def predict(self, X):
+        posteriors = self.predict_proba(X)
+        return np.argmax(posteriors, axis=1)
+
+
+def main():
+    np.random.seed(42)
+
+    # Generate some random data
+    n_sample = 200
+    X = np.concatenate([np.random.normal(0, 1, size=(n_sample, 2)), 
np.random.normal(2, 1, size=(n_sample, 2))], axis=0)
+    y = np.concatenate([np.zeros(n_sample), np.ones(n_sample)], axis=0)
+
+    # Split the data into training and testing sets
+    indices = np.random.permutation(X.shape[0])
+    train_indices, test_indices = indices[:int(0.8 * X.shape[0])], 
indices[int(0.8 * X.shape[0]):]
+    X_train, y_train = X[train_indices], y[train_indices]
+    X_test, y_test = X[test_indices], y[test_indices]
+
+    # Train and test the Naive Bayes classifier
+    nb = NaiveBayes()
+    nb.fit(X_train, y_train)
+    y_pred = nb.predict(X_test)
+    print(f"Custom Naive Bayes Accuracy: {accuracy_score(y_test, y_pred)}")
+
+    # Train and test the Naive Bayes classifier using a coreset
+    nbc = NaiveBayes(coreset=True, coreset_k=8, coreset_dim=X.shape[1])
+    nbc.fit(X_train, y_train)
+    y_pred = nbc.predict(X_test)
+    print(f"Coreset Naive Bayes Accuracy: {accuracy_score(y_test, y_pred)}")
+
+    # Train and test the Sklearn Naive Bayes classifier
+    gnb = GaussianNB()
+    gnb.fit(X_train, y_train)
+    y_pred = gnb.predict(X_test)
+    print(f"Sklearn Naive Bayes Accuracy: {accuracy_score(y_test, y_pred)}")
+
+    fig, ax = plt.subplots()
+    fig.suptitle("Coreset Naive Bayes")
+    buffer = 0.5
+    xs = np.linspace(X_train[:, 0].min() - buffer, X_train[:, 0].max() + buffer, 200)
+    ys = np.linspace(X_train[:, 1].min() - buffer, X_train[:, 1].max() + buffer, 200)
+    X_plot, Y_plot = np.meshgrid(xs, ys)
+    X_plot_2d = np.c_[X_plot.ravel(), Y_plot.ravel()]
+    plot_preds = nbc.predict(X_plot_2d)
+    ax.scatter(X_plot, Y_plot, c=plot_preds, alpha=0.01)
+
+    # plot the markers on top of the color background
+    ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, marker='d', alpha=1., 
edgecolor="black", label='Training data')
+    ax.scatter(X_test[:, 0], X_test[:, 1], c="white", edgecolor="black", 
marker='o', label='Test data')
+    x_star = np.array([5, -2])[np.newaxis, :]
+    print("Predicting: ", x_star)
+    print("Predicted class: ", nbc.predict(x_star))
+    ax.legend()
+    plt.show()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
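The per-feature product in `NaiveBayes.predict_proba` can be illustrated without sklearn or datasketches. This stdlib-only sketch (the two synthetic classes, priors and bandwidth are made up for illustration) computes `prior_c * prod_j f_{c,j}(x_j)` per class and then normalises, which is the same structure as the coreset branch above with a 1-d KDE standing in for each per-feature sketch estimate.

```python
import math
import random

def kde_1d(train, x, h):
    # 1-d Gaussian KDE, one per feature, as in NaiveBayes.predict_proba
    s = sum(math.exp(-((x - t) ** 2) / (2.0 * h * h)) for t in train)
    return s / (len(train) * h * math.sqrt(2.0 * math.pi))

def nb_posterior(features, class_data, priors, h=1.0):
    # joint_c = prior_c * prod_j f_{c,j}(x_j), then normalise over classes
    joints = {}
    for c, data in class_data.items():
        lik = priors[c]
        for j, x in enumerate(features):
            lik *= kde_1d([row[j] for row in data], x, h)
        joints[c] = lik
    total = sum(joints.values())
    return {c: v / total for c, v in joints.items()}

random.seed(2)
class_data = {
    0: [(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(150)],
    1: [(random.gauss(3, 1), random.gauss(3, 1)) for _ in range(150)],
}
priors = {0: 0.5, 1: 0.5}

post = nb_posterior((0.1, -0.2), class_data, priors)
print(post)  # a point near the class-0 centre should favour class 0
```

Note that any constant normalisation factor shared by all classes cancels when the joints are normalised, which is why the coreset branch can skip the `1/(h*sqrt(2*pi))` factor without changing the predicted class.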

