{
"cells": [
{
"cell_type": "markdown",
"id": "a1bea4ba",
"metadata": {},
"source": [
"# GPR using `mini-gpr`\n",
"\n",
"[Gaussian process regression](https://en.wikipedia.org/wiki/Gaussian_process) (GPR) is a powerful and flexible tool for regression problems. It is particularly useful when we want to quantify how (un)certain a model is, or when we want to model the noise in the data as well as its underlying mean value.\n",
"\n",
"Let's use the `mini-gpr` package to explore what GPR can do. We'll start by generating a toy 1D dataset:"
]
},
{
"cell_type": "markdown",
"id": "6d86a743",
"metadata": {},
"source": [
"## A toy system"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "97d959ea",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"from mini_gpr.tutorials import sample_toy_1d_system\n",
"\n",
"x_train, y_train = sample_toy_1d_system()\n",
"plt.plot(x_train, y_train, \"ok\", alpha=0.5, ms=4);"
]
},
{
"cell_type": "markdown",
"id": "477408ea",
"metadata": {},
"source": [
"## Model definition\n",
"\n",
"GPR models are defined by specifying a [kernel](https://en.wikipedia.org/wiki/Kernel_(machine_learning)): a function that quantifies how similar two input locations are.\n",
"\n",
"GPR models use this similarity to determine how strongly the values of the function at different locations are correlated.\n",
"\n",
"Below, we define a GPR model with a (very) simple radial basis function (RBF) kernel and set its `noise` parameter to 0.3:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f7496332",
"metadata": {},
"outputs": [],
"source": [
"from mini_gpr.kernels import RBF\n",
"from mini_gpr.models import GPR\n",
"\n",
"model = GPR(kernel=RBF(), noise=0.3)"
]
},
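{
"cell_type": "markdown",
"id": "3c2a9f10",
"metadata": {},
"source": [
"The cell below is a minimal NumPy sketch of an RBF kernel that makes the idea of a kernel more concrete. It assumes the common form $k(x, x') = \\exp\\left(-\\frac{(x - x')^2}{2\\ell^2}\\right)$ with unit amplitude. The length-scale $\\ell$ is set by a hypothetical `length_scale` argument, and the exact parameterisation used by `mini_gpr.kernels.RBF` may differ. Identical inputs have a similarity of 1, and the similarity decays smoothly towards 0 as the inputs move apart."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c2a9f11",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def rbf_kernel(a, b, length_scale=1.0):\n",
"    # hypothetical stand-alone RBF kernel, not mini-gpr's implementation\n",
"    # pairwise squared distances between 1D inputs a (shape n) and b (shape m)\n",
"    sq_dist = (a[:, None] - b[None, :]) ** 2\n",
"    # similarity is 1 for identical inputs and decays towards 0 with distance\n",
"    return np.exp(-0.5 * sq_dist / length_scale**2)\n",
"\n",
"\n",
"points = np.array([0.0, 0.5, 3.0])\n",
"similarity = rbf_kernel(points, points)\n",
"print(np.round(similarity, 3))"
]
},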
{
"cell_type": "markdown",
"id": "dad7d93d",
"metadata": {},
"source": [
"There are infinitely many functions of a 1D input. The kernel (and its hyperparameters) of a GPR model induces a [prior distribution](https://en.wikipedia.org/wiki/Prior_probability) over these functions. We can use the `sample_prior` method to draw samples from this distribution:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "380eac26",
"metadata": {},
"outputs": [],
"source": [
"xx = np.linspace(0, 10, 250)\n",
"plt.plot(xx, model.sample_prior(xx, n_samples=3, rng=42));"
]
},
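{
"cell_type": "markdown",
"id": "3c2a9f20",
"metadata": {},
"source": [
"Mathematically, a GP prior is a multivariate normal distribution over function values. If we assume a zero prior mean, drawing samples at a set of input locations amounts to sampling from a normal distribution whose covariance matrix is the kernel evaluated between every pair of locations. The cell below sketches this using a hypothetical `rbf_kernel` helper and a Cholesky factorisation. The helper has unit amplitude and length-scale 1, which need not match the defaults of `RBF()`. A small 'jitter' term is added to the diagonal for numerical stability. This illustrates the general technique and is not necessarily how `sample_prior` is implemented. The samples can be plotted with `plt.plot(grid, prior_samples)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c2a9f21",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def rbf_kernel(a, b, length_scale=1.0):\n",
"    # hypothetical RBF kernel, used here as a stand-in for mini_gpr.kernels.RBF\n",
"    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / length_scale**2)\n",
"\n",
"\n",
"grid = np.linspace(0, 10, 250)\n",
"# prior covariance between every pair of grid locations\n",
"prior_cov = rbf_kernel(grid, grid)\n",
"# a small jitter on the diagonal keeps the Cholesky factorisation numerically stable\n",
"chol = np.linalg.cholesky(prior_cov + 1e-6 * np.eye(grid.size))\n",
"# each column is one function drawn from the (assumed) zero-mean prior\n",
"rng = np.random.default_rng(42)\n",
"prior_samples = chol @ rng.standard_normal((grid.size, 3))"
]
},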
{
"cell_type": "markdown",
"id": "11ffdaed",
"metadata": {},
"source": [
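"Fitting the model to data reweights this distribution, favouring functions that match the data well over those that do not. The result is the [posterior distribution](https://en.wikipedia.org/wiki/Posterior_probability).\n",
"\n",
"Below, we use the `fit` method to fit the model to the training data, and then the `sample_posterior` method to draw some samples from the posterior:"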
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d07c1f1f",
"metadata": {},
"outputs": [],
"source": [
"model.fit(x_train, y_train)\n",
"\n",
"plt.plot(x_train, y_train, \"ok\", alpha=0.5, ms=4)\n",
"plt.plot(xx, model.sample_posterior(xx, n_samples=3, rng=42));"
]
},
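{
"cell_type": "markdown",
"id": "3c2a9f30",
"metadata": {},
"source": [
"Conditioning a GP on noisy observations gives a posterior that is again multivariate normal. Its mean and covariance at test inputs are $\\mu_* = K_*^\\top (K + \\sigma_n^2 I)^{-1} \\mathbf{y}$ and $\\Sigma_* = K_{**} - K_*^\\top (K + \\sigma_n^2 I)^{-1} K_*$. Here $K$, $K_*$ and $K_{**}$ are the train/train, train/test and test/test kernel matrices, and $\\sigma_n^2$ is the noise variance. The cell below is a minimal sketch of these standard equations under the following assumptions:\n",
"\n",
"- the prior mean is zero;\n",
"- the kernel is a hypothetical `rbf_kernel` helper with unit amplitude and length-scale 1;\n",
"- the noise is Gaussian with standard deviation 0.3 (we have not checked that this is what `noise=0.3` means in `mini-gpr`);\n",
"- the training data are a hypothetical noisy sine wave rather than the tutorial's toy system."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c2a9f31",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def rbf_kernel(a, b, length_scale=1.0):\n",
"    # hypothetical RBF kernel, used here as a stand-in for mini_gpr.kernels.RBF\n",
"    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / length_scale**2)\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"# hypothetical noisy training data, not the tutorial's toy system\n",
"x_demo = rng.uniform(0, 10, 30)\n",
"y_demo = np.sin(x_demo) + rng.normal(0, 0.3, x_demo.size)\n",
"noise_std = 0.3  # assumed noise standard deviation\n",
"grid = np.linspace(0, 10, 250)\n",
"\n",
"K_noisy = rbf_kernel(x_demo, x_demo) + noise_std**2 * np.eye(x_demo.size)\n",
"K_s = rbf_kernel(x_demo, grid)  # train/test covariances, shape (30, 250)\n",
"K_ss = rbf_kernel(grid, grid)  # test/test covariances, shape (250, 250)\n",
"\n",
"# posterior mean and covariance of the latent function on the grid\n",
"post_mean = K_s.T @ np.linalg.solve(K_noisy, y_demo)\n",
"post_cov = K_ss - K_s.T @ np.linalg.solve(K_noisy, K_s)\n",
"post_cov = 0.5 * (post_cov + post_cov.T)  # remove round-off asymmetry\n",
"\n",
"# three posterior samples; the jitter keeps the Cholesky factorisation stable\n",
"chol = np.linalg.cholesky(post_cov + 1e-6 * np.eye(grid.size))\n",
"post_samples = post_mean[:, None] + chol @ rng.standard_normal((grid.size, 3))"
]
},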
{
"cell_type": "markdown",
"id": "b9c50eb4",
"metadata": {},
"source": [
"Finally, we can use the `predict` method to generate predictions at new locations. This returns the posterior mean of $f(x)$: the average over all possible functions, weighted by their posterior probability given the data."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "046d7801",
"metadata": {},
"outputs": [],
"source": [
"yy = model.predict(xx)\n",
"plt.plot(x_train, y_train, \"ok\", alpha=0.5, ms=4)\n",
"plt.plot(xx, yy, \"k-\", lw=2);"
]
},
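{
"cell_type": "markdown",
"id": "3c2a9f40",
"metadata": {},
"source": [
"The posterior mean has a closed form. Assume a zero prior mean and Gaussian noise with variance $\\sigma_n^2$. The mean at a new input $x_*$ is then $\\bar{f}(x_*) = \\mathbf{k}_*^\\top (K + \\sigma_n^2 I)^{-1} \\mathbf{y}$. In this formula, $K$ is the kernel matrix over the training inputs, $\\mathbf{k}_*$ holds the kernel values between $x_*$ and each training input, and $\\mathbf{y}$ holds the training targets. The prediction is therefore a kernel-weighted sum over the training points. The cell below is a minimal NumPy sketch of this formula. It uses a hypothetical `rbf_kernel` helper (unit amplitude, length-scale 1), an assumed noise standard deviation of 0.3, and hypothetical noisy-sine training data. The `predict` method in `mini-gpr` may differ in its details."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c2a9f41",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def rbf_kernel(a, b, length_scale=1.0):\n",
"    # hypothetical RBF kernel, used here as a stand-in for mini_gpr.kernels.RBF\n",
"    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / length_scale**2)\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"# hypothetical noisy training data, not the tutorial's toy system\n",
"x_demo = rng.uniform(0, 10, 30)\n",
"y_demo = np.sin(x_demo) + rng.normal(0, 0.3, x_demo.size)\n",
"noise_std = 0.3  # assumed noise standard deviation\n",
"\n",
"# alpha = (K + sigma_n^2 I)^{-1} y only depends on the training data\n",
"K_noisy = rbf_kernel(x_demo, x_demo) + noise_std**2 * np.eye(x_demo.size)\n",
"alpha = np.linalg.solve(K_noisy, y_demo)\n",
"\n",
"# posterior mean at new inputs: a kernel-weighted sum over the training points\n",
"grid = np.linspace(0, 10, 250)\n",
"mean_demo = rbf_kernel(grid, x_demo) @ alpha"
]
},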
{
"cell_type": "markdown",
"id": "6c5170e0",
"metadata": {},
"source": [
"Beyond the posterior mean, GPR models also provide two measures of uncertainty:\n",
"\n",
"1. the \"latent\" (or epistemic) uncertainty, which quantifies the model's uncertainty about the *mean value* of the function at a given input;\n",
"2. the \"predictive\" uncertainty, which quantifies the model's uncertainty about an observed *value* of the function at a given input, and so combines the latent uncertainty with the (aleatoric) noise in the data.\n",
"\n",
"If the data contain noise with standard deviation $\\sigma(x)$, then given sufficient training data the latent uncertainty should tend to zero, while the predictive uncertainty should tend to $\\sigma(x)$.\n",
"\n",
"We can use the `latent_uncertainty` and `predictive_uncertainty` methods to compute these uncertainties for a fitted GPR model:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "911e7971",
"metadata": {},
"outputs": [],
"source": [
"latent_std = model.latent_uncertainty(xx)\n",
"yy_std = model.predictive_uncertainty(xx)\n",
"\n",
"_, (left, right) = plt.subplots(1, 2, figsize=(6, 3), sharey=True)\n",
"\n",
"for ax in (left, right):\n",
"    ax.plot(x_train, y_train, \"ok\", alpha=0.5, ms=4)\n",
"    ax.plot(xx, yy, \"k-\", lw=2)\n",
"\n",
"left.fill_between(xx, yy - latent_std, yy + latent_std, color=\"cornflowerblue\")\n",
"left.set_title(\"Latent uncertainty\", fontsize=10)\n",
"\n",
"right.fill_between(xx, yy - yy_std, yy + yy_std, color=\"pink\")\n",
"right.set_title(\"Predictive uncertainty\", fontsize=10);"
]
},
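{
"cell_type": "markdown",
"id": "3c2a9f50",
"metadata": {},
"source": [
"Consider Gaussian noise with a constant standard deviation $\\sigma_n$. The two uncertainties are then related by $\\sigma_\\text{pred}^2(x_*) = \\sigma_\\text{latent}^2(x_*) + \\sigma_n^2$, where $\\sigma_\\text{latent}^2(x_*) = k(x_*, x_*) - \\mathbf{k}_*^\\top (K + \\sigma_n^2 I)^{-1} \\mathbf{k}_*$. Here $K$ is the kernel matrix over the training inputs and $\\mathbf{k}_*$ holds the kernel values between $x_*$ and each training input. The cell below sketches both quantities in NumPy under the following assumptions:\n",
"\n",
"- the prior mean is zero;\n",
"- the kernel is a hypothetical `rbf_kernel` helper with unit amplitude (so that $k(x, x) = 1$) and length-scale 1;\n",
"- the noise standard deviation is 0.3;\n",
"- the training data are a hypothetical noisy sine wave.\n",
"\n",
"We have not checked that `latent_uncertainty` and `predictive_uncertainty` are computed in exactly this way."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3c2a9f51",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def rbf_kernel(a, b, length_scale=1.0):\n",
"    # hypothetical RBF kernel, used here as a stand-in for mini_gpr.kernels.RBF\n",
"    return np.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / length_scale**2)\n",
"\n",
"\n",
"rng = np.random.default_rng(0)\n",
"# hypothetical noisy training data, not the tutorial's toy system\n",
"x_demo = rng.uniform(0, 10, 30)\n",
"y_demo = np.sin(x_demo) + rng.normal(0, 0.3, x_demo.size)\n",
"noise_std = 0.3  # assumed noise standard deviation\n",
"grid = np.linspace(0, 10, 250)\n",
"\n",
"K_noisy = rbf_kernel(x_demo, x_demo) + noise_std**2 * np.eye(x_demo.size)\n",
"K_s = rbf_kernel(x_demo, grid)  # train/test covariances, shape (30, 250)\n",
"\n",
"# latent variance: prior variance k(x, x) = 1 minus the part explained by the data\n",
"latent_var = 1.0 - np.sum(K_s * np.linalg.solve(K_noisy, K_s), axis=0)\n",
"latent_std_demo = np.sqrt(np.clip(latent_var, 0.0, None))\n",
"\n",
"# predictive variance adds the observation noise on top of the latent variance\n",
"pred_std_demo = np.sqrt(latent_std_demo**2 + noise_std**2)"
]
},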
{
"cell_type": "markdown",
"id": "a8b9ab13",
"metadata": {},
"source": [
"In subsequent notebooks, we will show the model's predictive uncertainty by default. For 1D systems, the `mini_gpr.viz.show_model_predictions` function provides a useful visualisation of this quantity:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "85142b69",
"metadata": {},
"outputs": [],
"source": [
"from mini_gpr.viz import show_model_predictions\n",
"\n",
"show_model_predictions(model, x_train, y_train)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}