# Source code for menpo.transform.rbf

import numpy as np
from scipy.spatial.distance import cdist
from .base import Transform


class RadialBasisFunction(Transform):
    r"""
    Common base for the radial basis function transforms consumed by
    :map:`ThinPlateSplines`.

    Concrete basis functions are expected to be able to take their own
    radial derivative, which :map:`ThinPlateSplines` needs in order to take
    its own total derivative.

    Parameters
    ----------
    c : ``(n_centres, n_dims)`` `ndarray`
        The centres that define the basis. Usually these are a set of source
        landmarks.
    """

    def __init__(self, c):
        # The centres are kept exactly as passed in (no copy, no validation).
        self.c = c

    @property
    def n_centres(self):
        r"""
        How many centres the basis is built from, i.e. the length of the
        first axis of ``c``.

        :type: `int`
        """
        centre_count = self.c.shape[0]
        return centre_count

    @property
    def n_dims(self):
        r"""
        The dimensionality of the centres, i.e. the length of the second
        axis of ``c``. The RBF can only be applied to points that share this
        dimensionality.

        :type: `int`
        """
        centre_dims = self.c.shape[1]
        return centre_dims

    @property
    def n_dims_output(self):
        r"""
        The dimensionality of the transformed points: one output dimension
        (weight) per centre, so this is always equal to ``n_centres``.

        :type: `int`
        """
        return self.n_centres


class R2LogR2RBF(RadialBasisFunction):
    r"""
    The :math:`r^2 \log{r^2}` basis function.

    The derivative of this function is :math:`2 r (\log{r^2} + 1)`.

    .. note::

        :math:`r = \lVert x - c \rVert`

    Parameters
    ----------
    c : ``(n_centres, n_dims)`` `ndarray`
        The set of centers that make the basis. Usually represents a set of
        source landmarks.
    """

    def __init__(self, c):
        super(R2LogR2RBF, self).__init__(c)

    def _apply(self, x, **kwargs):
        r"""
        Apply the basis function.

        .. note::

            :math:`r^2 \log{r^2} \equiv 2 r^2 \log{r}`

        Parameters
        ----------
        x : ``(n_points, n_dims)`` `ndarray`
            Set of points to apply the basis to.
        kwargs : `dict`, optional
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        u : ``(n_points, n_centres)`` `ndarray`
            The basis function applied to each distance,
            :math:`\lVert x - c \rVert`. Entries whose distance is exactly
            zero are set to ``0``.
        """
        # Pairwise distances between every point and every centre.
        euclidean_distance = cdist(x, self.c)
        # Remember the zero distances before taking the logarithm.
        mask = euclidean_distance == 0
        # log(0) is -inf and 0 * -inf is nan; both floating point warnings
        # are expected here, so silence them for this expression only.
        with np.errstate(divide='ignore', invalid='ignore'):
            u = euclidean_distance ** 2 * (2 * np.log(euclidean_distance))
        # reset singularities to 0 (the limit of r^2 log(r^2) as r -> 0)
        u[mask] = 0
        return u
class R2LogRRBF(RadialBasisFunction):
    r"""
    Calculates the :math:`r^2 \log{r}` basis function.

    The derivative of this function is :math:`r (1 + 2 \log{r})`.

    .. note::

        :math:`r = \lVert x - c \rVert`

    Parameters
    ----------
    c : ``(n_centres, n_dims)`` `ndarray`
        The set of centers that make the basis. Usually represents a set of
        source landmarks.
    """

    def __init__(self, c):
        super(R2LogRRBF, self).__init__(c)

    def _apply(self, points, **kwargs):
        r"""
        Apply the basis function :math:`r^2 \log{r}`.

        Parameters
        ----------
        points : ``(n_points, n_dims)`` `ndarray`
            Set of points to apply the basis to.
        kwargs : `dict`, optional
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        u : ``(n_points, n_centres)`` `ndarray`
            The basis function applied to each distance,
            :math:`\lVert points - c \rVert`. Entries whose distance is
            exactly zero are set to ``0``.
        """
        # Pairwise distances between every point and every centre.
        euclidean_distance = cdist(points, self.c)
        # Remember the zero distances before taking the logarithm.
        mask = euclidean_distance == 0
        # log(0) is -inf and 0 * -inf is nan; both floating point warnings
        # are expected here, so silence them for this expression only.
        with np.errstate(divide='ignore', invalid='ignore'):
            u = euclidean_distance ** 2 * np.log(euclidean_distance)
        # reset singularities to 0 (the limit of r^2 log(r) as r -> 0)
        u[mask] = 0
        return u