Cholesky decompositions for sparse kernels in Gaussian Processes

This repository contains the implementation part of my thesis. The GPR class (GP.py) implements Gaussian Process Regression based on a custom banded Cholesky decomposition in JAX. The code is benchmarked on a number of examples to measure the effectiveness of sparse Cholesky implementations in Gaussian processes. Locally supported covariance functions ("sparse kernels") are used in this context.
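
The banded routine itself is defined in GP.py and is not reproduced here. Purely as an illustration, the following is a minimal reference sketch of what a banded Cholesky factorisation computes, assuming a dense symmetric positive definite matrix whose nonzero entries lie within a fixed bandwidth; the function name and loop-based structure are illustrative, not the repository's implementation.

import jax.numpy as jnp

def banded_cholesky_reference(A, bandwidth):
    # Reference (unoptimised) banded Cholesky: only entries L[i, j] with
    # 0 <= i - j <= bandwidth are computed; everything else stays zero.
    n = A.shape[0]
    L = jnp.zeros_like(A)
    for j in range(n):
        # Diagonal entry from the already-computed part of row j.
        d = jnp.sqrt(A[j, j] - jnp.sum(L[j, :j] ** 2))
        L = L.at[j, j].set(d)
        # Sub-diagonal entries, restricted to the band below the diagonal.
        for i in range(j + 1, min(n, j + bandwidth + 1)):
            v = (A[i, j] - jnp.dot(L[i, :j], L[j, :j])) / d
            L = L.at[i, j].set(v)
    return L

For bandwidth b this reduces the cost from O(n³) for a dense factorisation to roughly O(n b²), which is the complexity reduction referred to in the Results section.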

Motivation

Using covariance tapering in conjunction with specialised sparse Cholesky factorisation algorithms, it may be possible to reduce the time complexity of larger GP models. This effect is investigated on real datasets.

Features

  • Gaussian Process Regression in JAX
    • Banded implementation using pure JAX
    • Sparse implementation using a custom primitive based on Eigen
    • Optimizing the marginal log-likelihood using Optax (a minimal sketch follows this list)
    • Optimizing the marginal log-likelihood using JAXopt
  • Kernels
  • Benchmarks based on real data
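
How the marginal log-likelihood is wired into Optax in this repository is not shown here; the snippet below is only a generic sketch of an Optax optimisation loop, where neg_mll is a hypothetical stand-in for the negative marginal log-likelihood exposed by the GPR class.

import jax
import jax.numpy as jnp
import optax

def neg_mll(params):
    # Hypothetical placeholder objective; in practice this would be the
    # negative marginal log-likelihood of the GP as a function of the
    # kernel hyperparameters.
    return jnp.sum((params - jnp.array([2.0, 0.5])) ** 2)

params = jnp.array([1.0, 1.0])              # initial hyperparameters
optimizer = optax.adam(learning_rate=1e-2)
opt_state = optimizer.init(params)

@jax.jit
def step(params, opt_state):
    loss, grads = jax.value_and_grad(neg_mll)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for _ in range(500):
    params, opt_state, loss = step(params, opt_state)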

Requirements

The project runs on Python 3.11.

To install requirements:

pip install -r requirements.txt

Additionally, the package liesel-sparse is required for the sparse algorithms; it is currently only available at liesel-devs/liesel-sparse. A virtual environment is recommended. The required datasets are included in the repository.
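
The exact installation route for liesel-sparse is not specified here; assuming it can be installed directly from its GitHub repository, something along these lines should work:

pip install git+https://github.com/liesel-devs/liesel-sparse.git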

Datasets

The data directory contains the following processed datasets and other relevant files:

Code example

import jax.numpy as jnp
# GPR is defined in GP.py; MaternKernel32 and WendlandTapering are the
# repository's kernel and tapering functions.

def kernel_(s, l, x, y):
    # Matérn-3/2 kernel multiplied by a Wendland taper gives a compactly
    # supported ("sparse") kernel.
    return MaternKernel32(s, l, x, y) * WendlandTapering(3, 8, x, y)

gpr = GPR(X_train, y_train, kernel_, params=jnp.array([37**2, 3]), eps=0.01)
gpr.fit(X_train, y_train)
mean_pred = gpr.predict(X_test, False)

Results

The thesis shows that sparse kernels can make GPs more scalable if only the Cholesky decomposition within the sparse algorithms is considered. Overall, however, the sparse implementations are slower than the pure JAX version. This is due to the overhead of the algorithms, in particular the conversion between different matrix formats. The compactly supported covariance functions are effective at producing banded covariance matrices and reduce the theoretical computational complexity.

License

Licensed under the MIT License.