# Source code for menpo.transform.thinplatesplines

import numpy as np
from .base import Transform, Alignment, Invertible
from .rbf import R2LogR2RBF


# Note: we inherit from Alignment first to get its n_dims behavior.
class ThinPlateSplines(Alignment, Transform, Invertible):
    r"""
    The thin plate splines (TPS) alignment between 2D source and target
    landmarks.

    `kernel` can be used to specify an alternative kernel function. If
    `None` is supplied, a :class:`menpo.transform.rbf.R2LogR2RBF` kernel
    built from the source points will be used.

    The warp is found by solving the standard TPS linear system::

        [ K   P ] [ w ]   [ v ]
        [ P^T 0 ] [ a ] = [ 0 ]

    where ``K`` is the kernel evaluated between the source points, ``P``
    holds the source points in homogeneous form, ``v`` the target
    coordinates, ``w`` the non-affine weights and ``a`` the affine part.

    Parameters
    ----------
    source : pointcloud-like
        Object exposing an (N, 2) ``points`` ndarray; the source points to
        apply the tps from (presumably a menpo ``PointCloud``).
    target : pointcloud-like
        Object exposing an (N, 2) ``points`` ndarray; the target points to
        apply the tps to.
    kernel : kernel object, optional
        The kernel to apply. It must provide ``apply(points)`` returning an
        (M, N) array of kernel values between the M given points and the N
        source points.
        Default: :class:`menpo.transform.rbf.R2LogR2RBF` built from
        ``source.points``.

    Raises
    ------
    ValueError
        TPS only works on 2-dimensional data.
    """

    def __init__(self, source, target, kernel=None):
        Alignment.__init__(self, source, target)
        if self.n_dims != 2:
            raise ValueError('TPS can only be used on 2D data.')
        if kernel is None:
            kernel = R2LogR2RBF(source.points)
        self.kernel = kernel
        # k[i, j] is the rbf weighting between source i and j. For a radial
        # kernel built on the source points k is symmetric (and for the
        # default r^2 log r^2 kernel its diagonal is presumably zero).
        self.k = self.kernel.apply(self.source.points)
        # p is a homogeneous version of the source points: (N, 3) = [1, x, y]
        self.p = np.concatenate(
            [np.ones([self.n_points, 1]), self.source.points], axis=1)
        # Assemble the (N + 3, N + 3) TPS system matrix [[K, P], [P^T, 0]].
        o = np.zeros([3, 3])
        top_l = np.concatenate([self.k, self.p], axis=1)
        bot_l = np.concatenate([self.p.T, o], axis=1)
        self.l = np.concatenate([top_l, bot_l], axis=0)
        self.v, self.y, self.coefficients = None, None, None
        self._build_coefficients()

    def _build_coefficients(self):
        r"""
        Solve the TPS system for the current target points.

        Sets ``self.coefficients`` to an (N + 3, 2) array: the first N rows
        are the non-affine kernel weights, the last 3 rows the affine
        (constant, x, y) coefficients.
        """
        # (2, N) target coordinates.
        self.v = self.target.points.T.copy()
        # Right-hand side: target coordinates padded with zeros for the
        # 3 affine side conditions (P^T w = 0).
        self.y = np.hstack([self.v, np.zeros([self.n_dims, 3])])
        self.coefficients = np.linalg.solve(self.l, self.y.T)

    def _sync_state_from_target(self):
        # now the target is updated, we only have to rebuild the
        # coefficients (the system matrix depends only on the source).
        self._build_coefficients()

    def _apply(self, points, **kwargs):
        r"""
        Performs a TPS transform on the given points.

        Parameters
        ----------
        points : (N, D) ndarray
            The points to transform.

        Returns
        -------
        f : (N, D) ndarray
            The transformed points

        Raises
        ------
        ValueError
            If ``points`` is not 2-dimensional data.
        """
        if points.shape[1] != self.n_dims:
            raise ValueError('TPS can only be applied to 2D data.')
        x = points[..., 0][:, None]
        y = points[..., 1][:, None]
        # calculate the affine coefficients of the warp
        # (C = Constant component, then X, Y respectively)
        c_affine_c = self.coefficients[-3]
        c_affine_x = self.coefficients[-2]
        c_affine_y = self.coefficients[-1]
        # the affine warp component
        f_affine = c_affine_c + c_affine_x * x + c_affine_y * y
        # kernel values between every input point and every source point
        kernel_dist = self.kernel.apply(points)
        # grab the affine free components of the warp
        c_affine_free = self.coefficients[:-3]
        # build the affine free warp component
        f_affine_free = kernel_dist.dot(c_affine_free)
        return f_affine + f_affine_free

    @property
    def has_true_inverse(self):
        return False

    def pseudoinverse(self):
        r"""
        Return a TPS mapping ``self.target`` back onto ``self.source``.

        This is not an exact inverse of this warp (see
        ``has_true_inverse``); it is a new TPS solved with the roles of
        source and target swapped.

        Returns
        -------
        inverse : :class:`ThinPlateSplines`
            The TPS from ``self.target`` to ``self.source``.
        """
        kernel = self.kernel
        if type(kernel) is R2LogR2RBF:
            # The kernel is built from the source points of the TPS that
            # uses it (see __init__). The inverse's source is our target,
            # so reusing a kernel built on self.source would give a wrong
            # system matrix and a wrong warp. Rebuild it on the new source.
            kernel = R2LogR2RBF(self.target.points)
        # NOTE(review): custom kernel types are reused as-is; if such a
        # kernel is tied to the original source points the inverse will be
        # incorrect - confirm how custom kernels should be re-centred.
        return ThinPlateSplines(self.target, self.source, kernel=kernel)