
Attentive Kernel

sgptools.kernels.attentive.Attentive

Bases: Kernel

Attentive kernel (non-stationary kernel).

This kernel uses a Multi-Layer Perceptron (MLP) to compute attention representations that weight a mixture of RBF components, producing a locally adaptive, non-stationary covariance function.

Implementation based on Weizhe-Chen/attentive_kernels (https://github.com/Weizhe-Chen/attentive_kernels).

Refer to the following paper for more details:
  • AK: Attentive Kernel for Information Gathering [Chen et al., 2022]

Attributes:

  • _free_amplitude (tf.Variable): Trainable scalar amplitude parameter applied to the final covariance.
  • lengthscales (tf.Variable): 1D tensor of fixed lengthscale values for the RBF mixture components.
  • num_lengthscales (int): Number of RBF mixture components.
  • nn (NN): Neural network that maps input points to latent attention representations.

Source code in sgptools/kernels/attentive.py
class Attentive(gpflow.kernels.Kernel):
    """Attentive kernel (non-stationary kernel).

    This kernel uses a Multi-Layer Perceptron (MLP) to compute attention
    representations that weight a mixture of RBF components, producing a
    locally adaptive, non-stationary covariance function.

    Implementation based on [Weizhe-Chen/attentive_kernels](https://github.com/Weizhe-Chen/attentive_kernels).

    Refer to the following paper for more details:
        - AK: Attentive Kernel for Information Gathering [Chen et al., 2022]

    Attributes:
        _free_amplitude (tf.Variable):
            Trainable scalar amplitude parameter applied to the final covariance.
        lengthscales (tf.Variable):
            1D tensor of fixed lengthscale values for the RBF mixture components.
        num_lengthscales (int):
            Number of RBF mixture components.
        nn (NN):
            Neural network that maps input points to latent attention
            representations.
    """

    def __init__(
        self,
        lengthscales: Union[List[float], np.ndarray] = None,
        hidden_sizes: List[int] = None,
        amplitude: float = 1.0,
        num_dim: int = 2,
    ):
        """Initialize an Attentive kernel.

        Args:
            lengthscales (List[float] | np.ndarray | None):
                Positive lengthscale values used for the fixed RBF mixture
                components. These are treated as non-trainable parameters.
                If None, a default grid ``np.linspace(0.01, 2.0, 10)`` is used.
            hidden_sizes (List[int] | None):
                Hidden-layer widths of the MLP. The length of this list
                determines the number of hidden layers. If None, defaults to
                ``[10, 10]``.
            amplitude (float):
                Initial value for the trainable scalar amplitude parameter used
                to rescale the final covariance.
            num_dim (int):
                Dimensionality of each input data point (e.g. 2 for 2D inputs).

        Returns:
            None

        Usage:
            Basic usage with fixed lengthscales for 2D data::

                ```python
                import gpflow
                import numpy as np
                from sgptools.kernels.attentive import Attentive

                # Example: 10 fixed lengthscales ranging from 0.01 to 2.0
                l_scales = np.linspace(0.01, 2.0, 10).astype(np.float32)

                # Initialize Attentive kernel for 2D data
                kernel = Attentive(
                    lengthscales=l_scales,
                    hidden_sizes=[10, 10],
                    amplitude=1.0,
                    num_dim=2,
                )

                # Use this kernel in a GPflow model:
                # model = gpflow.models.GPR(
                #     data=(X_train, Y_train),
                #     kernel=kernel,
                #     noise_variance=0.1,
                # )
                # optimize_model(model)
                ```
        """
        super().__init__()
        if lengthscales is None:
            lengthscales = np.linspace(0.01, 2.0, 10)

        if hidden_sizes is None:
            hidden_sizes = [10, 10]
        else:
            hidden_sizes = list(hidden_sizes)

        with self.name_scope:
            self.num_lengthscales = len(lengthscales)
            self._free_amplitude = tf.Variable(
                amplitude,
                shape=[],
                trainable=True,
                dtype=default_float(),
            )

            # Lengthscales are fixed, not optimized.
            self.lengthscales = tf.Variable(
                tf.cast(lengthscales, default_float()),
                shape=[self.num_lengthscales],
                trainable=False,
                dtype=default_float(),
            )

            self.nn = NN(
                [num_dim] + hidden_sizes + [self.num_lengthscales],
                output_activation_fn="softplus",
            )

    @tf.autograph.experimental.do_not_convert
    def get_representations(self, X: tf.Tensor) -> tf.Tensor:
        """Compute normalized latent attention representations.

        Args:
            X (tf.Tensor):
                Tensor of shape (N, D). Input data points.

        Returns:
            tf.Tensor:
                Tensor of shape (N, num_lengthscales) containing unit-norm
                latent representation vectors used for generating attention
                weights.
        """
        Z = self.nn(X)
        representations = Z / tf.norm(Z, axis=1, keepdims=True)
        return representations

    @tf.autograph.experimental.do_not_convert
    def K(self, X: tf.Tensor, X2: Optional[tf.Tensor] = None) -> tf.Tensor:
        """Compute full covariance matrix between X and X2.

        The covariance is a weighted sum of RBF mixture components modulated
        by attention representations in the learned latent space.

        Args:
            X (tf.Tensor):
                Tensor of shape (N1, D). First set of input points.
            X2 (tf.Tensor | None):
                Tensor of shape (N2, D). Optional second set of input points.
                If None, `X` is used for both arguments.

        Returns:
            tf.Tensor:
                Tensor of shape (N1, N2) containing the covariance matrix
                K(X, X2).
        """
        repre1 = self.get_representations(X)
        if X2 is None:
            repre2 = repre1
            X2_internal = X
        else:
            X2_internal = X2
            repre2 = self.get_representations(X2_internal)
        dist = cdist(X, X2_internal)

        def get_mixture_component(i: tf.Tensor) -> tf.Tensor:
            """Compute a single mixture RBF component.

            Args:
                i (tf.Tensor):
                    Scalar integer tensor representing a lengthscale index.

            Returns:
                tf.Tensor:
                    Tensor of shape (N1, N2) containing the i-th mixture
                    kernel component.
            """
            attention_lengthscales = tf.tensordot(
                repre1[:, i], repre2[:, i], axes=0
            )
            return rbf(dist, self.lengthscales[i]) * attention_lengthscales

        cov_mat_per_ls = tf.map_fn(
            fn=get_mixture_component,
            elems=tf.range(self.num_lengthscales, dtype=tf.int64),
            fn_output_signature=dist.dtype,
        )

        cov_mat_summed = tf.reduce_sum(cov_mat_per_ls, axis=0)
        attention_inputs = tf.matmul(repre1, repre2, transpose_b=True)

        return self._free_amplitude * attention_inputs * cov_mat_summed

    @tf.autograph.experimental.do_not_convert
    def K_diag(self, X: tf.Tensor) -> tf.Tensor:
        """Compute the diagonal of K(X, X).

        Args:
            X (tf.Tensor):
                Tensor of shape (N, D). Input points.

        Returns:
            tf.Tensor:
                Tensor of shape (N,) containing the diagonal of the covariance
                matrix (constant when representations are unit norm).
        """
        return self._free_amplitude * tf.ones((X.shape[0],), dtype=X.dtype)

    def get_lengthscales(self, X: np.ndarray) -> np.ndarray:
        """Compute non-stationary effective lengthscales.

        Args:
            X (np.ndarray):
                Array of shape (N, D). Input points at which to estimate
                effective lengthscales.

        Returns:
            np.ndarray:
                Array of shape (N,) containing effective spatially varying
                lengthscale values at the given input locations.
        """
        lengthscales = self.lengthscales.numpy()
        preds = np.zeros(len(X))

        repre = self.get_representations(X)
        for i in range(len(lengthscales)):
            attention = tf.tensordot(
                repre[:, i], tf.transpose(repre[:, i]), axes=0
            )
            preds += np.diag(attention) * lengthscales[i]
        return preds

K(X, X2=None)

Compute full covariance matrix between X and X2.

The covariance is a weighted sum of RBF mixture components modulated by attention representations in the learned latent space.

Parameters:

  • X (tf.Tensor): Tensor of shape (N1, D). First set of input points. Required.
  • X2 (tf.Tensor | None): Tensor of shape (N2, D). Optional second set of input points. If None, X is used for both arguments. Defaults to None.

Returns:

  • tf.Tensor: Tensor of shape (N1, N2) containing the covariance matrix K(X, X2).

Source code in sgptools/kernels/attentive.py
@tf.autograph.experimental.do_not_convert
def K(self, X: tf.Tensor, X2: Optional[tf.Tensor] = None) -> tf.Tensor:
    """Compute full covariance matrix between X and X2.

    The covariance is a weighted sum of RBF mixture components modulated
    by attention representations in the learned latent space.

    Args:
        X (tf.Tensor):
            Tensor of shape (N1, D). First set of input points.
        X2 (tf.Tensor | None):
            Tensor of shape (N2, D). Optional second set of input points.
            If None, `X` is used for both arguments.

    Returns:
        tf.Tensor:
            Tensor of shape (N1, N2) containing the covariance matrix
            K(X, X2).
    """
    repre1 = self.get_representations(X)
    if X2 is None:
        repre2 = repre1
        X2_internal = X
    else:
        X2_internal = X2
        repre2 = self.get_representations(X2_internal)
    dist = cdist(X, X2_internal)

    def get_mixture_component(i: tf.Tensor) -> tf.Tensor:
        """Compute a single mixture RBF component.

        Args:
            i (tf.Tensor):
                Scalar integer tensor representing a lengthscale index.

        Returns:
            tf.Tensor:
                Tensor of shape (N1, N2) containing the i-th mixture
                kernel component.
        """
        attention_lengthscales = tf.tensordot(
            repre1[:, i], repre2[:, i], axes=0
        )
        return rbf(dist, self.lengthscales[i]) * attention_lengthscales

    cov_mat_per_ls = tf.map_fn(
        fn=get_mixture_component,
        elems=tf.range(self.num_lengthscales, dtype=tf.int64),
        fn_output_signature=dist.dtype,
    )

    cov_mat_summed = tf.reduce_sum(cov_mat_per_ls, axis=0)
    attention_inputs = tf.matmul(repre1, repre2, transpose_b=True)

    return self._free_amplitude * attention_inputs * cov_mat_summed
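Inside `K`, the per-component gate `tf.tensordot(repre1[:, i], repre2[:, i], axes=0)` is simply an outer product: entry (m, n) is zᵢ(xₘ) · zᵢ(x'ₙ). A quick NumPy check of that equivalence (illustrative vectors only):

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])   # stand-in for repre1[:, i]
b = np.array([4.0, 5.0])        # stand-in for repre2[:, i]

# tensordot with axes=0 on two vectors yields the outer product:
# gate[m, n] = a[m] * b[n]
gate = np.tensordot(a, b, axes=0)
```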

K_diag(X)

Compute the diagonal of K(X, X).

Parameters:

  • X (tf.Tensor): Tensor of shape (N, D). Input points. Required.

Returns:

  • tf.Tensor: Tensor of shape (N,) containing the diagonal of the covariance matrix (constant when representations are unit norm).

Source code in sgptools/kernels/attentive.py
@tf.autograph.experimental.do_not_convert
def K_diag(self, X: tf.Tensor) -> tf.Tensor:
    """Compute the diagonal of K(X, X).

    Args:
        X (tf.Tensor):
            Tensor of shape (N, D). Input points.

    Returns:
        tf.Tensor:
            Tensor of shape (N,) containing the diagonal of the covariance
            matrix (constant when representations are unit norm).
    """
    return self._free_amplitude * tf.ones((X.shape[0],), dtype=X.dtype)
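The constant diagonal follows from the unit-norm representations: at zero distance every RBF component equals 1, so k(x, x) = a · ⟨z, z⟩ · Σᵢ zᵢ² = a · 1 · 1 = a. A quick NumPy check with an arbitrary unit vector standing in for a representation:

```python
import numpy as np

z = np.array([0.3, -0.5, 0.2, 0.7])
z = z / np.linalg.norm(z)   # unit-norm representation vector

# At zero distance every RBF component equals 1, so the mixture sum
# collapses to sum_i z_i^2 = ||z||^2 = 1 and k(x, x) = amplitude.
amplitude = 1.7
k_xx = amplitude * (z @ z) * np.sum(z ** 2)
```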

__init__(lengthscales=None, hidden_sizes=None, amplitude=1.0, num_dim=2)

Initialize an Attentive kernel.

Parameters:

  • lengthscales (List[float] | np.ndarray | None): Positive lengthscale values used for the fixed RBF mixture components. These are treated as non-trainable parameters. If None, a default grid np.linspace(0.01, 2.0, 10) is used. Defaults to None.
  • hidden_sizes (List[int] | None): Hidden-layer widths of the MLP. The length of this list determines the number of hidden layers. If None, defaults to [10, 10].
  • amplitude (float): Initial value for the trainable scalar amplitude parameter used to rescale the final covariance. Defaults to 1.0.
  • num_dim (int): Dimensionality of each input data point (e.g. 2 for 2D inputs). Defaults to 2.

Returns:

  • None

Usage

Basic usage with fixed lengthscales for 2D data:

```python
import gpflow
import numpy as np
from sgptools.kernels.attentive import Attentive

# Example: 10 fixed lengthscales ranging from 0.01 to 2.0
l_scales = np.linspace(0.01, 2.0, 10).astype(np.float32)

# Initialize Attentive kernel for 2D data
kernel = Attentive(
    lengthscales=l_scales,
    hidden_sizes=[10, 10],
    amplitude=1.0,
    num_dim=2,
)

# Use this kernel in a GPflow model:
# model = gpflow.models.GPR(
#     data=(X_train, Y_train),
#     kernel=kernel,
#     noise_variance=0.1,
# )
# optimize_model(model)
```
Source code in sgptools/kernels/attentive.py
def __init__(
    self,
    lengthscales: Union[List[float], np.ndarray] = None,
    hidden_sizes: List[int] = None,
    amplitude: float = 1.0,
    num_dim: int = 2,
):
    """Initialize an Attentive kernel.

    Args:
        lengthscales (List[float] | np.ndarray | None):
            Positive lengthscale values used for the fixed RBF mixture
            components. These are treated as non-trainable parameters.
            If None, a default grid ``np.linspace(0.01, 2.0, 10)`` is used.
        hidden_sizes (List[int] | None):
            Hidden-layer widths of the MLP. The length of this list
            determines the number of hidden layers. If None, defaults to
            ``[10, 10]``.
        amplitude (float):
            Initial value for the trainable scalar amplitude parameter used
            to rescale the final covariance.
        num_dim (int):
            Dimensionality of each input data point (e.g. 2 for 2D inputs).

    Returns:
        None

    Usage:
        Basic usage with fixed lengthscales for 2D data::

            ```python
            import gpflow
            import numpy as np
            from sgptools.kernels.attentive import Attentive

            # Example: 10 fixed lengthscales ranging from 0.01 to 2.0
            l_scales = np.linspace(0.01, 2.0, 10).astype(np.float32)

            # Initialize Attentive kernel for 2D data
            kernel = Attentive(
                lengthscales=l_scales,
                hidden_sizes=[10, 10],
                amplitude=1.0,
                num_dim=2,
            )

            # Use this kernel in a GPflow model:
            # model = gpflow.models.GPR(
            #     data=(X_train, Y_train),
            #     kernel=kernel,
            #     noise_variance=0.1,
            # )
            # optimize_model(model)
            ```
    """
    super().__init__()
    if lengthscales is None:
        lengthscales = np.linspace(0.01, 2.0, 10)

    if hidden_sizes is None:
        hidden_sizes = [10, 10]
    else:
        hidden_sizes = list(hidden_sizes)

    with self.name_scope:
        self.num_lengthscales = len(lengthscales)
        self._free_amplitude = tf.Variable(
            amplitude,
            shape=[],
            trainable=True,
            dtype=default_float(),
        )

        # Lengthscales are fixed, not optimized.
        self.lengthscales = tf.Variable(
            tf.cast(lengthscales, default_float()),
            shape=[self.num_lengthscales],
            trainable=False,
            dtype=default_float(),
        )

        self.nn = NN(
            [num_dim] + hidden_sizes + [self.num_lengthscales],
            output_activation_fn="softplus",
        )

get_lengthscales(X)

Compute non-stationary effective lengthscales.

Parameters:

  • X (np.ndarray): Array of shape (N, D). Input points at which to estimate effective lengthscales. Required.

Returns:

  • np.ndarray: Array of shape (N,) containing effective spatially varying lengthscale values at the given input locations.

Source code in sgptools/kernels/attentive.py
def get_lengthscales(self, X: np.ndarray) -> np.ndarray:
    """Compute non-stationary effective lengthscales.

    Args:
        X (np.ndarray):
            Array of shape (N, D). Input points at which to estimate
            effective lengthscales.

    Returns:
        np.ndarray:
            Array of shape (N,) containing effective spatially varying
            lengthscale values at the given input locations.
    """
    lengthscales = self.lengthscales.numpy()
    preds = np.zeros(len(X))

    repre = self.get_representations(X)
    for i in range(len(lengthscales)):
        attention = tf.tensordot(
            repre[:, i], tf.transpose(repre[:, i]), axes=0
        )
        preds += np.diag(attention) * lengthscales[i]
    return preds
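In `get_lengthscales`, the diagonal of the outer product of `repre[:, i]` with itself is just `repre[:, i]**2`, so the effective lengthscale at each point is a convex combination of the fixed lengthscales weighted by the squared unit-norm representation: l_eff(x) = Σᵢ zᵢ(x)² lᵢ. A NumPy sketch under that reading (random stand-in representations, not MLP outputs):

```python
import numpy as np

lengthscales = np.linspace(0.01, 2.0, 10)

# Hypothetical unit-norm representations for 3 input points
rng = np.random.default_rng(1)
repre = np.abs(rng.normal(size=(3, 10)))
repre = repre / np.linalg.norm(repre, axis=1, keepdims=True)

# diag of the outer product of repre[:, i] with itself is repre[:, i]**2,
# so the effective lengthscale is a convex combination of the fixed ones
l_eff = (repre ** 2) @ lengthscales
```

Since each row of `repre**2` sums to 1, every effective lengthscale lies between the smallest and largest fixed lengthscale.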

get_representations(X)

Compute normalized latent attention representations.

Parameters:

  • X (tf.Tensor): Tensor of shape (N, D). Input data points. Required.

Returns:

  • tf.Tensor: Tensor of shape (N, num_lengthscales) containing unit-norm latent representation vectors used for generating attention weights.

Source code in sgptools/kernels/attentive.py
@tf.autograph.experimental.do_not_convert
def get_representations(self, X: tf.Tensor) -> tf.Tensor:
    """Compute normalized latent attention representations.

    Args:
        X (tf.Tensor):
            Tensor of shape (N, D). Input data points.

    Returns:
        tf.Tensor:
            Tensor of shape (N, num_lengthscales) containing unit-norm
            latent representation vectors used for generating attention
            weights.
    """
    Z = self.nn(X)
    representations = Z / tf.norm(Z, axis=1, keepdims=True)
    return representations