Source code for menpo.transform.rbf

import numpy as np
from scipy.spatial.distance import cdist
from .base import Transform


class RadialBasisFunction(Transform):
    r"""
    Base class for radial basis functions, the family of transforms that TPS
    builds on. A basis has to be able to take its own radial derivative so
    that TPS is able to take its own total derivative.

    Parameters
    ----------
    c : (n_centres, n_dims) ndarray
        The set of centres that make up the basis. Usually represents a set
        of source landmarks.
    """

    def __init__(self, c):
        # The centres are stored as given; every property below reads the
        # shape of this array.
        self.c = c

    @property
    def n_centres(self):
        r"""
        The number of centres in the basis, i.e. the length of the first
        axis of ``c``.
        """
        centre_count = self.c.shape[0]
        return centre_count

    @property
    def n_dims(self):
        r"""
        The dimensionality of the centres (second axis of ``c``). The RBF can
        only be applied to points that share this dimensionality.
        """
        centre_dims = self.c.shape[1]
        return centre_dims

    @property
    def n_dims_output(self):
        r"""
        The dimensionality of the transform's result: one dimension (weight)
        per centre.
        """
        return self.n_centres


class R2LogR2RBF(RadialBasisFunction):
    r"""
    The :math:`r^2 \log{r^2}` basis function.

    The derivative of this function is :math:`2 r (\log{r^2} + 1)`.

    .. note::

        :math:`r = \lVert x - c \rVert`

    Parameters
    ----------
    c : (n_centres, n_dims) ndarray
        The set of centers that make the basis. Usually represents a set of
        source landmarks.
    """

    def __init__(self, c):
        super(R2LogR2RBF, self).__init__(c)

    def _apply(self, x, **kwargs):
        r"""
        Apply the basis function.

        .. note::

            :math:`r^2 \log{r^2} \equiv 2 r^2 \log{r}`

        Parameters
        ----------
        x : (n_points, n_dims) ndarray
            Set of points to apply the basis to.
        kwargs : dict
            Accepted but not used by this method.

        Returns
        -------
        u : (n_points, n_centres) ndarray
            The basis function applied to each distance,
            :math:`\lVert x - c \rVert`. Entries where the distance is
            exactly zero are set to 0.
        """
        # NOTE: this docstring must stay a raw string; otherwise ``\r`` in
        # ``\rVert`` is interpreted as a carriage return.
        euclidean_distance = cdist(x, self.c)
        mask = euclidean_distance == 0
        # log(0) is -inf and 0 * -inf is nan; silence the resulting warnings
        # because those entries are overwritten just below.
        with np.errstate(divide='ignore', invalid='ignore'):
            u = (euclidean_distance ** 2 *
                 (2 * np.log(euclidean_distance)))
        # reset singularities to 0
        u[mask] = 0
        return u


class R2LogRRBF(RadialBasisFunction):
    r"""
    Calculates the :math:`r^2 \log{r}` basis function.

    The derivative of this function is :math:`r (1 + 2 \log{r})`.

    .. note::

        :math:`r = \lVert x - c \rVert`

    Parameters
    ----------
    c : (n_centres, n_dims) ndarray
        The set of centers that make the basis. Usually represents a set of
        source landmarks.
    """

    def __init__(self, c):
        super(R2LogRRBF, self).__init__(c)

    def _apply(self, points, **kwargs):
        r"""
        Apply the basis function :math:`r^2 \log{r}`.

        Parameters
        ----------
        points : (n_points, n_dims) ndarray
            Set of points to apply the basis to.
        kwargs : dict
            Accepted but not used by this method.

        Returns
        -------
        u : (n_points, n_centres) ndarray
            The basis function applied to each distance,
            :math:`\lVert points - c \rVert`. Entries where the distance is
            exactly zero are set to 0.
        """
        # NOTE: this docstring must stay a raw string; otherwise ``\r`` in
        # ``\rVert`` is interpreted as a carriage return.
        euclidean_distance = cdist(points, self.c)
        mask = euclidean_distance == 0
        # log(0) is -inf and 0 * -inf is nan; silence the resulting warnings
        # because those entries are overwritten just below.
        with np.errstate(divide='ignore', invalid='ignore'):
            u = euclidean_distance ** 2 * np.log(euclidean_distance)
        # reset singularities to 0
        u[mask] = 0
        return u