{
"cells": [
{
"cell_type": "markdown",
"id": "ea1fe951",
"metadata": {},
"source": [
"# Choosing a kernel\n",
"\n",
"Different kernel functions induce different _priors_ over the space of all functions that a GPR model can represent.\n",
"\n",
"Choosing an appropriate kernel (and its hyperparameters) is therefore a key part of fitting a GPR model to data!\n",
"For an excellent discussion of this subject, see [the kernel cookbook](https://www.cs.toronto.edu/~duvenaud/cookbook/).\n",
"\n",
"Below, we show functions sampled from the priors induced by several of the kernels available in `mini-gpr`.\n",
"\n",
"## RBF kernel\n",
"\n",
"Perhaps the most commonly used kernel, the RBF (or squared-exponential) kernel, limits the functions that can be represented by a GPR model to those that are smooth, infinitely differentiable, and that have finite second moments:"
]
},
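{
"cell_type": "markdown",
"id": "rbf-kernel-sketch",
"metadata": {},
"source": [
"To make the roles of `sigma` and `scale` concrete, here is a minimal NumPy sketch of a squared-exponential kernel and of drawing functions from the prior it induces. It assumes the textbook form `k(x, x') = scale**2 * exp(-(x - x')**2 / (2 * sigma**2))`; `mini-gpr`'s own `RBF` may be parameterized differently, and the helpers `rbf` and `sample_prior` are hypothetical (they are not part of `mini-gpr`). Drawing samples via a Cholesky factor of the kernel matrix is one standard approach; we have not checked that `sample_kernel` works this way.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def rbf(x1, x2, sigma=1.0, scale=1.0):\n",
"    # hypothetical helper, assumed form: scale^2 * exp(-(x - x')^2 / (2 * sigma^2))\n",
"    d = x1[:, None] - x2[None, :]\n",
"    return scale**2 * np.exp(-0.5 * (d / sigma) ** 2)\n",
"\n",
"\n",
"def sample_prior(kernel, x, n_samples=3, seed=0, jitter=1e-6):\n",
"    # hypothetical helper: draw f ~ N(0, K) from a Cholesky factor of K;\n",
"    # a small jitter on the diagonal keeps the factorization numerically stable\n",
"    K = kernel(x, x) + jitter * np.eye(len(x))\n",
"    L = np.linalg.cholesky(K)\n",
"    z = np.random.default_rng(seed).standard_normal((len(x), n_samples))\n",
"    return L @ z  # shape (len(x), n_samples): one column per sampled function\n",
"\n",
"\n",
"x = np.linspace(-3, 3, 200)\n",
"# the prior variance at every input is k(x, x) = scale**2 (here 0.04)\n",
"prior_var = np.diag(rbf(x, x, scale=0.2))\n",
"samples = sample_prior(lambda a, b: rbf(a, b, sigma=1.0, scale=0.2), x)\n",
"```"
]
},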
{
"cell_type": "code",
"execution_count": 1,
"id": "329cc523",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"from mini_gpr import kernels\n",
"from mini_gpr.viz import sample_kernel\n",
"\n",
"sample_kernel(kernels.RBF(sigma=1, scale=1), c=\"black\", n_samples=3)\n",
"sample_kernel(kernels.RBF(sigma=1, scale=0.2), c=\"red\", n_samples=3)\n",
"sample_kernel(kernels.RBF(sigma=0.4, scale=1), c=\"blue\", n_samples=3)"
]
},
{
"cell_type": "markdown",
"id": "0cb7bc5c",
"metadata": {},
"source": [
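"As you can see above, the `scale` parameter controls the standard deviation (in the vertical direction) of the prior, while the `sigma` parameter controls its lengthscale (in the horizontal direction): the red samples (`scale=0.2`) vary much less vertically than the black ones (`scale=1`), while the blue samples (`sigma=0.4`) vary more rapidly from left to right.\n",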
"\n",
"## Linear kernel\n",
"\n",
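"The linear kernel, defined by the `m` and `scale` parameters, limits the functions that can be represented by a GPR model to straight lines that:\n",
"- pass through the point `(m, 0)`\n",
"- have slopes distributed according to $\\text{slope} \\sim \\mathcal{N}(0, \\text{scale}^2)$\n",
"\n",
"In the plot below, the dashed red line marks $x = m$, and the solid red lines pass through `(m, 0)` with slopes of $\\pm\\,\\text{scale}$."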
]
},
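{
"cell_type": "markdown",
"id": "linear-kernel-sketch",
"metadata": {},
"source": [
"One common way of writing a linear kernel with these two parameters is `k(x, x') = scale**2 * (x - m) * (x' - m)`. We assume this form in the NumPy sketch below (it is consistent with the guide lines drawn in the next cell, but `mini-gpr`'s `Linear` may differ in detail, and the `linear` helper is hypothetical). Under this assumption the kernel matrix has rank one, so every prior sample is a straight line `f(x) = w * (x - m)` through `(m, 0)` with a single random slope `w ~ N(0, scale**2)`. The sketch draws such lines directly and compares their empirical covariance with the kernel matrix.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def linear(x1, x2, m=2.0, scale=0.2):\n",
"    # hypothetical helper, assumed form: scale^2 * (x - m) * (x' - m)\n",
"    return scale**2 * np.outer(x1 - m, x2 - m)\n",
"\n",
"\n",
"m, scale = 2.0, 0.2\n",
"x = np.linspace(0, 6, 50)\n",
"K = linear(x, x, m, scale)\n",
"\n",
"# K = scale^2 * outer(x - m, x - m) has rank one, so each sample is w * (x - m)\n",
"rank = np.linalg.matrix_rank(K)\n",
"\n",
"rng = np.random.default_rng(42)\n",
"slopes = rng.normal(0.0, scale, size=300)     # one random slope per sampled line\n",
"samples = (x - m)[:, None] * slopes[None, :]  # shape (50, 300)\n",
"\n",
"# with many samples, the empirical covariance approaches K\n",
"emp_cov = samples @ samples.T / samples.shape[1]\n",
"```\n",
"\n",
"Comparing `emp_cov` with `K` is a quick sanity check; the agreement improves as the number of sampled lines grows."
]
},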
{
"cell_type": "code",
"execution_count": 2,
"id": "b46b3fb8",
"metadata": {},
"outputs": [],
"source": [
"# feel free to change these values\n",
"scale = 0.2\n",
"m = 2\n",
"\n",
"sample_kernel(\n",
" kernels.Linear(m=m, scale=scale), n_samples=300, c=\"k\", alpha=0.1, seed=42\n",
")\n",
"plt.axvline(m, color=\"crimson\", ls=\"--\", lw=1)\n",
"plt.axline((m, 0), slope=scale, color=\"crimson\", lw=1)\n",
"plt.axline((m, 0), slope=-scale, color=\"crimson\", lw=1)\n",
"plt.xlim(0, 6);"
]
},
{
"cell_type": "markdown",
"id": "9e1a6724",
"metadata": {},
"source": [
"## Periodic kernel\n",
"\n",
"The periodic kernel limits the functions that can be represented by a GPR model to those that are perfectly periodic, repeating exactly every `period` units along the input axis. The degree of oscillation within each period is controlled by the `sigma` parameter:"
]
},
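{
"cell_type": "markdown",
"id": "periodic-kernel-sketch",
"metadata": {},
"source": [
"A widely used periodic kernel is the exp-sine-squared form `k(x, x') = scale**2 * exp(-2 * sin(pi * |x - x'| / period)**2 / sigma**2)`. We assume this form in the NumPy sketch below (the `periodic` helper is hypothetical, and `mini-gpr`'s `Periodic` may be parameterized differently). It checks the property that makes samples perfectly periodic: shifting one input by a whole period leaves every covariance unchanged.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def periodic(x1, x2, period=3.0, sigma=0.3, scale=1.0):\n",
"    # hypothetical helper, assumed exp-sine-squared form\n",
"    d = np.abs(x1[:, None] - x2[None, :])\n",
"    return scale**2 * np.exp(-2 * np.sin(np.pi * d / period) ** 2 / sigma**2)\n",
"\n",
"\n",
"period = 3.0\n",
"x = np.linspace(0, period, 40)\n",
"K = periodic(x, x, period=period)\n",
"\n",
"# k(x + period, x') == k(x, x'), so f(x + period) and f(x) are perfectly correlated\n",
"K_shifted = periodic(x + period, x, period=period)\n",
"print(np.allclose(K, K_shifted))  # True\n",
"```\n",
"\n",
"In this assumed form, a smaller `sigma` makes the covariance fall off faster within each period, giving more oscillation per period."
]
},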
{
"cell_type": "code",
"execution_count": 3,
"id": "e92adc14",
"metadata": {},
"outputs": [],
"source": [
"sample_kernel(kernels.Periodic(period=3, sigma=2, scale=3), c=\"black\", n_samples=1)\n",
"sample_kernel(kernels.Periodic(period=3, sigma=0.3, scale=1), c=\"red\", n_samples=1)"
]
},
{
"cell_type": "markdown",
"id": "3c8e6637",
"metadata": {},
"source": [
"## Combined kernels\n",
"\n",
"Priors with more structure can be constructed by combining multiple kernels.\n",
"\n",
"One way to combine kernels is to add their output values together: in `mini-gpr`, this is done using the `+` operator:"
]
},
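{
"cell_type": "markdown",
"id": "sum-kernel-sketch",
"metadata": {},
"source": [
"Adding kernels corresponds to adding independent random functions: if `f1` and `f2` are drawn independently from priors with kernels `k1` and `k2`, their sum `f1 + f2` has covariance `k1 + k2`. The NumPy sketch below uses assumed textbook forms of the periodic and linear kernels (hypothetical helpers, not `mini-gpr`'s code) to build the summed kernel matrix, confirm that it is still positive semi-definite, and draw one sample from it.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def periodic(x1, x2, period=3.0, sigma=0.3, scale=1.0):\n",
"    # hypothetical helper, assumed exp-sine-squared form\n",
"    d = np.abs(x1[:, None] - x2[None, :])\n",
"    return scale**2 * np.exp(-2 * np.sin(np.pi * d / period) ** 2 / sigma**2)\n",
"\n",
"\n",
"def linear(x1, x2, m=2.0, scale=3.0):\n",
"    # hypothetical helper, assumed form: scale^2 * (x - m) * (x' - m)\n",
"    return scale**2 * np.outer(x1 - m, x2 - m)\n",
"\n",
"\n",
"x = np.linspace(0, 6, 100)\n",
"# element-wise sum of the two covariance matrices\n",
"K_sum = periodic(x, x) + linear(x, x)\n",
"\n",
"# a sum of positive semi-definite kernels is positive semi-definite,\n",
"# so K_sum is a valid prior covariance (up to floating-point round-off)\n",
"min_eig = np.linalg.eigvalsh(K_sum).min()\n",
"\n",
"# draw one prior sample via a Cholesky factor (with a small diagonal jitter)\n",
"L = np.linalg.cholesky(K_sum + 1e-6 * np.eye(len(x)))\n",
"f = L @ np.random.default_rng(0).standard_normal(len(x))\n",
"```\n",
"\n",
"Each such sample looks like a periodic wiggle superimposed on a random straight line through `(2, 0)`."
]
},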
{
"cell_type": "code",
"execution_count": 4,
"id": "ae970e45",
"metadata": {},
"outputs": [],
"source": [
"sample_kernel(\n",
" kernels.Periodic(period=3, sigma=0.3) + kernels.Linear(m=2, scale=3)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "d3e039b8",
"metadata": {},
"source": [
"Another way to combine kernels is to multiply their output values together: in `mini-gpr`, this is done using the `*` operator:"
]
},
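{
"cell_type": "markdown",
"id": "product-kernel-sketch",
"metadata": {},
"source": [
"Multiplying kernels multiplies their covariances point by point. By the Schur product theorem, the element-wise product of two positive semi-definite matrices is itself positive semi-definite, so the product is always a valid kernel. The NumPy sketch below (assuming textbook forms of the periodic and linear kernels, via hypothetical helpers rather than `mini-gpr`'s code) checks this, and shows that the prior standard deviation of the product grows linearly away from `x = m`, so samples look like periodic oscillations whose amplitude grows with distance from `m`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def periodic(x1, x2, period=3.0, sigma=0.3, scale=1.0):\n",
"    # hypothetical helper, assumed exp-sine-squared form\n",
"    d = np.abs(x1[:, None] - x2[None, :])\n",
"    return scale**2 * np.exp(-2 * np.sin(np.pi * d / period) ** 2 / sigma**2)\n",
"\n",
"\n",
"def linear(x1, x2, m=2.0, scale=3.0):\n",
"    # hypothetical helper, assumed form: scale^2 * (x - m) * (x' - m)\n",
"    return scale**2 * np.outer(x1 - m, x2 - m)\n",
"\n",
"\n",
"x = np.linspace(0, 6, 100)\n",
"# element-wise product of the two covariance matrices\n",
"K_prod = periodic(x, x) * linear(x, x)\n",
"\n",
"# Schur product theorem: the result is positive semi-definite (up to round-off)\n",
"min_eig = np.linalg.eigvalsh(K_prod).min()\n",
"\n",
"# prior std = sqrt(1 * 9 * (x - 2)^2) = 3 * |x - 2|: zero at x = m, growing linearly\n",
"prior_std = np.sqrt(np.diag(K_prod))\n",
"```"
]
},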
{
"cell_type": "code",
"execution_count": 5,
"id": "7c5ee625",
"metadata": {},
"outputs": [],
"source": [
"sample_kernel(\n",
" kernels.Periodic(period=3, sigma=0.3) * kernels.Linear(m=2, scale=3)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "05586af7",
"metadata": {},
"source": [
"The `**` exponentiation operator can be used to raise a kernel to a power:"
]
},
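{
"cell_type": "markdown",
"id": "power-kernel-sketch",
"metadata": {},
"source": [
"For a positive integer power, we assume `k ** n` behaves like multiplying `k` by itself `n` times, i.e. an element-wise power of the kernel matrix. Under the linear kernel form we also assume here (`k(x, x') = scale**2 * (x - m) * (x' - m)`; `mini-gpr`'s `Linear` may differ), squaring gives a rank-one kernel whose samples are parabolas `w * (x - m)**2` with their vertex at `(m, 0)`. The NumPy sketch below, using a hypothetical `linear` helper, checks this identity.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def linear(x1, x2, m=2.0, scale=3.0):\n",
"    # hypothetical helper, assumed form: scale^2 * (x - m) * (x' - m)\n",
"    return scale**2 * np.outer(x1 - m, x2 - m)\n",
"\n",
"\n",
"m, scale = 2.0, 3.0\n",
"x = np.linspace(-1, 5, 50)\n",
"\n",
"# squaring the kernel = element-wise square of its matrix (= linear * linear)\n",
"K_sq = linear(x, x, m, scale) ** 2\n",
"\n",
"# K_sq = scale^4 * outer(phi, phi) with phi = (x - m)^2, a rank-one matrix,\n",
"# so every prior sample is a parabola w * (x - m)^2 with w ~ N(0, scale^4)\n",
"phi = (x - m) ** 2\n",
"print(np.allclose(K_sq, scale**4 * np.outer(phi, phi)))  # True\n",
"\n",
"w = np.random.default_rng(0).normal(0.0, scale**2)\n",
"parabola = w * phi\n",
"```"
]
},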
{
"cell_type": "code",
"execution_count": 6,
"id": "e5197950",
"metadata": {},
"outputs": [],
"source": [
"sample_kernel(kernels.Linear(m=2, scale=3) ** 2)"
]
},
{
"cell_type": "markdown",
"id": "738b06c7",
"metadata": {},
"source": [
"## Higher dimensions\n",
"\n",
"Of course, in most practical applications, the input space is multi-dimensional.\n",
"\n",
"All of the kernels defined in `mini-gpr` support multi-dimensional input spaces.\n",
"\n",
"As examples, here are some samples drawn over a 2D input space using the kernels below.\n",
"\n",
"### RBF kernel"
]
},
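{
"cell_type": "markdown",
"id": "rbf-2d-kernel-sketch",
"metadata": {},
"source": [
"The `sigma=[1, 0.3]` argument in the next cell suggests one lengthscale per input dimension (often called automatic relevance determination). Assuming that interpretation, a squared-exponential kernel over vector inputs can be sketched as below; the `rbf_ard` helper is hypothetical, not `mini-gpr`'s implementation. Moving the same distance along each axis shows that the covariance decays much faster along the dimension with the shorter lengthscale, so prior samples vary more quickly in that direction.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def rbf_ard(X1, X2, sigma=(1.0, 0.3), scale=1.0):\n",
"    # hypothetical helper, assumed form with one lengthscale per dimension:\n",
"    # scale^2 * exp(-0.5 * sum_d ((x_d - x'_d) / sigma_d)^2)\n",
"    sigma = np.asarray(sigma, dtype=float)\n",
"    diff = (X1[:, None, :] - X2[None, :, :]) / sigma  # shape (n1, n2, n_dims)\n",
"    return scale**2 * np.exp(-0.5 * np.sum(diff**2, axis=-1))\n",
"\n",
"\n",
"origin = np.array([[0.0, 0.0]])\n",
"step_dim1 = np.array([[0.5, 0.0]])  # move 0.5 along the first dimension\n",
"step_dim2 = np.array([[0.0, 0.5]])  # move 0.5 along the second dimension\n",
"\n",
"k1 = rbf_ard(origin, step_dim1)[0, 0]  # exp(-0.5 * (0.5 / 1.0)**2), about 0.882\n",
"k2 = rbf_ard(origin, step_dim2)[0, 0]  # exp(-0.5 * (0.5 / 0.3)**2), about 0.249\n",
"```"
]
},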
{
"cell_type": "code",
"execution_count": 7,
"id": "7d4373ad",
"metadata": {},
"outputs": [],
"source": [
"from mini_gpr.viz import sample_2d_kernel\n",
"\n",
"sample_2d_kernel(kernels.RBF(sigma=[1, 0.3]))"
]
},
{
"cell_type": "markdown",
"id": "61424cb2",
"metadata": {},
"source": [
"### Combined DotProduct + RBF kernel"
]
},
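{
"cell_type": "markdown",
"id": "dotproduct-rbf-kernel-sketch",
"metadata": {},
"source": [
"`DotProduct` is not described elsewhere in this notebook. Assuming it is the plain dot-product kernel `k(x, x') = x · x'` (some libraries add a constant offset), its samples are planes through the origin, so adding a small-scale RBF kernel gives a tilted plane decorated with short-lengthscale bumps. The NumPy sketch below builds this combined kernel on a 2D grid with hypothetical helpers (not `mini-gpr`'s code) and draws one sample.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def dot_product(X1, X2):\n",
"    # hypothetical helper, assumed form: k(x, x') = x . x' (no offset term)\n",
"    return X1 @ X2.T\n",
"\n",
"\n",
"def rbf(X1, X2, sigma=0.15, scale=0.2):\n",
"    # hypothetical helper, assumed isotropic squared-exponential form\n",
"    sq_dist = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)\n",
"    return scale**2 * np.exp(-0.5 * sq_dist / sigma**2)\n",
"\n",
"\n",
"# a 15 x 15 grid of input points covering [0, 1]^2\n",
"g = np.linspace(0, 1, 15)\n",
"X = np.stack(np.meshgrid(g, g), axis=-1).reshape(-1, 2)  # shape (225, 2)\n",
"\n",
"K = dot_product(X, X) + rbf(X, X)\n",
"\n",
"# one prior sample: a random tilted plane plus small, short-lengthscale bumps\n",
"L = np.linalg.cholesky(K + 1e-6 * np.eye(len(X)))\n",
"f = (L @ np.random.default_rng(0).standard_normal(len(X))).reshape(15, 15)\n",
"```"
]
},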
{
"cell_type": "code",
"execution_count": 9,
"id": "15d0e198",
"metadata": {},
"outputs": [],
"source": [
"combined_kernel = kernels.DotProduct() + kernels.RBF(sigma=0.15, scale=0.2)\n",
"sample_2d_kernel(combined_kernel, cmap=\"RdBu\", n=3**2)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}