{ "cells": [ { "cell_type": "markdown", "id": "aad8e99a", "metadata": {}, "source": [ "# Low-rank approximations\n", "\n", "Assume we have a training set composed of $N$ data points.\n", "Fitting a (full-rank) GPR model requires:\n", "- $\\mathcal{O}(N^2)$ memory (for storing the $N\\times N$ kernel matrix)\n", "- $\\mathcal{O}(N^3)$ time (dominated by inverting, or factorising, this matrix)\n", "\n", "Further, assume we want to make predictions at $T$ test points. Making predictions requires:\n", "- $\\mathcal{O}(NT)$ time and memory to compute the predictive mean (dominated by calculating the $N\\times T$ matrix of kernel evaluations)\n", "- $\\mathcal{O}(N^2T)$ time to compute the predictive (diagonal) variance\n", "\n", "Because of these costs, full-rank GPR models quickly become prohibitively expensive for large datasets.\n", "\n", "A popular approach to reducing this cost is to approximate the $N\\times N$ kernel matrix taken over the training data with a lower-rank matrix. To do this, a set of $M$ \"sparse\" or \"inducing\" points is chosen at representative locations in the training data. This usually sacrifices some accuracy, but can dramatically reduce the cost of fitting and inference when $M \\ll N$:\n", "\n", "- $\\mathcal{O}(NM + M^2)$ memory (for storing the $N\\times M$ kernel matrix between the training and inducing points, and the $M\\times M$ kernel matrix over the inducing points)\n", "- $\\mathcal{O}(NM^2 + M^3)$ time to fit the model\n", "- $\\mathcal{O}(TM)$ time to make mean predictions\n", "- $\\mathcal{O}(TM^2)$ time to compute the predictive (diagonal) variance\n", "\n", "`mini-gpr` implements the [Subset of Regressors](https://uk.mathworks.com/help/stats/subset-of-regressors-approximation-for-gpr-models.html) (`SoR`) approach: this model conforms to the same interface as the full-rank `GPR` model, and so can be used as a drop-in replacement.\n", "\n", "Below, we show examples of optimising both GPR and SoR models for a synthetic 1D system:" ] }, { "cell_type": "code", "execution_count": 1, "id": "5ae1498d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'sigma': 2.2437003072015034, 'scale': 1.4595497350774491}\n" ] }, { "data": { "image/svg+xml": [ "\n", " 2025-09-13T14:56:51.761632\n", " image/svg+xml\n", " Matplotlib v3.10.3, https://matplotlib.org/\n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from mini_gpr.kernels import RBF\n", "from mini_gpr.models import GPR\n", "from mini_gpr.opt import maximise_log_likelihood, optimise_model\n", "from mini_gpr.tutorials import sample_toy_1d_system\n", "from mini_gpr.viz import show_model_predictions\n", "\n", "x_train, y_train = sample_toy_1d_system()\n", "full_rank_model = optimise_model(\n", "    GPR(RBF(), noise=0.1),\n", "    maximise_log_likelihood,\n", "    x_train,\n", "    y_train,\n", "    optimise_noise=True,\n", ")\n", "show_model_predictions(full_rank_model, x_train, y_train)\n", "print(full_rank_model.kernel.params)" ] }, { "cell_type": "code", "execution_count": 2, "id": "fbf06b3a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'sigma': 2.5377083633879733, 'scale': 2.2140375820386287}\n" ] }, { "data": { "image/svg+xml": [ "\n", " 2025-09-13T14:56:51.837835\n", " image/svg+xml\n", " Matplotlib v3.10.3, https://matplotlib.org/\n", "\n" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "\n", "from mini_gpr.models import SoR\n", "\n", "model = SoR(\n", "    kernel=RBF(),\n", "    # use 5 inducing points, spaced evenly across the range of the training data\n", "    sparse_points=np.linspace(1, 9, num=5),\n", "    noise=0.1,\n", ")\n", "low_rank_model = optimise_model(\n", "    model,\n", "    maximise_log_likelihood,\n", "    x_train,\n", "    y_train,\n", "    optimise_noise=True,\n", ")\n", "show_model_predictions(low_rank_model, x_train, y_train)\n", "print(low_rank_model.kernel.params)" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 }