# License: Non-Commercial Use Only
#
# Permission is granted to use, copy, modify, and distribute this software
# for non-commercial purposes only, with attribution to the original author.
# Commercial use requires explicit permission.
#
# This software is provided "as is", without warranty of any kind.
"""Sparse weight construction for transferring fields between meshes."""
from typing import Tuple
from numpy.typing import NDArray
import numpy as np
from scipy import sparse
import trimesh
from .mesh.base import Mesh, RectangularMesh
from .helpers import query_nearest
# ---------------------------------------------------------------------------
# Private weight-computation helpers
# ---------------------------------------------------------------------------
def _rbf_kernel(r: NDArray, kernel: str) -> NDArray:
    """Apply a radial basis kernel element-wise to the distances *r*.

    Supported kernels:

    - ``"thin_plate_spline"``: ``r**2 * log(r)``, defined as 0 where
      ``r <= 0``.
    - ``"cubic"``: ``r**3``.
    - ``"linear"``: ``r`` itself (the input object is returned as-is).

    Args:
        r: Distance array (any shape).
        kernel: Kernel name, one of the values listed above.

    Returns:
        Kernel values, same shape as *r*.

    Raises:
        ValueError: If *kernel* is not one of the supported names.
    """
    if kernel == "linear":
        return r
    if kernel == "cubic":
        return r ** 3
    if kernel != "thin_plate_spline":
        raise ValueError(f"Unknown kernel: {kernel!r}")
    # log(0) -> -inf and 0 * -inf -> nan are expected at r == 0; np.where
    # replaces those entries with 0, so the warnings are silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(r > 0, r ** 2 * np.log(r), 0.0)
def _knn_weights(
    distances: NDArray,
    indices: NDArray,
    method: str,
) -> Tuple[NDArray, NDArray, NDArray]:
    """Turn k-nearest-neighbor query results into sparse weight triplets.

    Args:
        distances: Neighbor distances, shape (M, k).
        indices: Neighbor indices into the sender nodes, shape (M, k).
        method: How distances become weights:

            - ``"nearest"``: weight 1 on column 0 of *indices*.
            - ``"idw"``: weights proportional to ``1 / distance``.
            - ``"gaussian"``: weights ``exp(-0.2 * (d / d_min)**2)``,
              where ``d_min`` is the row's smallest distance.

            ``"idw"`` and ``"gaussian"`` weights are normalised to sum
            to 1 for every receiver.

    Returns:
        ``(row, col, val)`` arrays for sparse matrix construction.

    Raises:
        ValueError: If *method* is not recognised.
    """
    n_rows, n_nbrs = indices.shape
    row_ids = np.arange(n_rows)
    if method == "nearest":
        return row_ids, indices[:, 0], np.ones(n_rows)

    eps = 1e-16  # keeps zero distances from producing a division by zero
    if method == "idw":
        raw = 1.0 / (distances + eps)
    elif method == "gaussian":
        scale = np.min(distances, axis=1, keepdims=True) + eps
        raw = np.exp(-0.2 * (distances / scale) ** 2)
    else:
        raise ValueError(f"Unknown knn method: {method}")

    normalised = raw / raw.sum(axis=1, keepdims=True)
    return np.repeat(row_ids, n_nbrs), indices.ravel(), normalised.ravel()
def _barycentric_weights(
    sender_mesh: Mesh,
    receiver_nodes: NDArray,
) -> Tuple[NDArray, NDArray, NDArray]:
    """Compute barycentric interpolation weights (degree=3).

    Every receiver node is snapped to the closest point on the sender's
    triangulated surface. Its barycentric coordinates in the triangle
    that was hit become the weights of that triangle's three vertices.
    The coordinates are clipped to ``[0, 1]`` and renormalised to sum
    to 1.

    Args:
        sender_mesh: The sender mesh (any type -- quads are triangulated).
        receiver_nodes: Query points, shape (M, 3).

    Returns:
        ``(row, col, val)`` arrays for sparse matrix construction.
    """
    tri_mesh = sender_mesh.build_trimesh()
    projected, _, hit_tris = trimesh.proximity.closest_point(
        tri_mesh, receiver_nodes
    )

    # Vertex indices of the hit triangle per receiver, (M, 3); these index
    # sender_mesh.nodes directly (already in masked-node space).
    corners = tri_mesh.faces[hit_tris]
    origin = sender_mesh.nodes[corners[:, 0]]
    edge_b = sender_mesh.nodes[corners[:, 1]] - origin
    edge_c = sender_mesh.nodes[corners[:, 2]] - origin
    offset = projected - origin

    def rowdot(lhs, rhs):
        """Row-wise dot product of two (M, 3) arrays."""
        return np.sum(lhs * rhs, axis=1)

    # Solve the 2x2 Gram system for the coordinates along edge_b / edge_c
    # by Cramer's rule; the tiny constant guards degenerate triangles.
    bb = rowdot(edge_b, edge_b)
    bc = rowdot(edge_b, edge_c)
    cc = rowdot(edge_c, edge_c)
    pb = rowdot(offset, edge_b)
    pc = rowdot(offset, edge_c)
    gram_det = bb * cc - bc * bc + 1e-30
    lam_b = (cc * pb - bc * pc) / gram_det
    lam_c = (bb * pc - bc * pb) / gram_det
    lam_a = 1.0 - lam_b - lam_c

    weights = np.clip(np.column_stack([lam_a, lam_b, lam_c]), 0, 1)
    weights /= weights.sum(axis=1, keepdims=True)

    n_recv = receiver_nodes.shape[0]
    return np.repeat(np.arange(n_recv), 3), corners.ravel(), weights.ravel()
def _solve_gram_2x2(
    e1: NDArray,
    e2: NDArray,
    rhs: NDArray,
) -> Tuple[NDArray, NDArray]:
    """Row-wise least-squares coefficients for ``a * e1 + b * e2 ≈ rhs``.

    Solves the 2x2 normal equations
    ``[e1 | e2]^T [e1 | e2] [a; b] = [e1 | e2]^T rhs`` independently for
    every row of the (M, 3) inputs using Cramer's rule. A tiny constant
    is added to the determinant so collinear or degenerate edge pairs do
    not divide by zero.

    Returns:
        ``(a, b)``, each of shape (M,).
    """
    g11 = np.sum(e1 * e1, axis=1)
    g12 = np.sum(e1 * e2, axis=1)
    g22 = np.sum(e2 * e2, axis=1)
    b1 = np.sum(rhs * e1, axis=1)
    b2 = np.sum(rhs * e2, axis=1)
    det = g11 * g22 - g12 * g12 + 1e-30
    return (g22 * b1 - g12 * b2) / det, (g11 * b2 - g12 * b1) / det


def _invert_bilinear(
    A: NDArray,
    B: NDArray,
    C: NDArray,
    D: NDArray,
    P: NDArray,
    n_iter: int,
) -> Tuple[NDArray, NDArray]:
    """Local coordinates ``(u, v)`` of points *P* on bilinear patches ABCD.

    Each patch is ``X(u, v) = A + u e_u + v e_v + u v t`` with
    ``e_u = B - A``, ``e_v = D - A`` and twist ``t = A - B + C - D``.
    The twist is zero exactly when ABCD is a parallelogram.

    The estimate that ignores ``t`` is the starting point. It is refined
    by up to *n_iter* Gauss-Newton steps on ``|X(u, v) - P|^2``. A step
    is kept only for rows where it lowers that residual. The result is
    therefore never farther from *P* than the starting estimate, and
    non-finite steps (e.g. from degenerate quads) are discarded.

    Args:
        A, B, C, D: Corner coordinates, each shape (M, 3).
        P: Points to locate, shape (M, 3).
        n_iter: Maximum number of refinement steps (0 disables them).

    Returns:
        ``(u, v)``, each of shape (M,), not clipped to ``[0, 1]``.
    """
    e_u = B - A
    e_v = D - A
    twist = A - B + C - D
    dp = P - A

    def _residual(uu, vv):
        """``P - X(uu, vv)`` per row, shape (M, 3)."""
        return dp - (
            uu[:, None] * e_u + vv[:, None] * e_v + (uu * vv)[:, None] * twist
        )

    # Parallelogram estimate (exact when the twist vanishes).
    u, v = _solve_gram_2x2(e_u, e_v, dp)
    r = _residual(u, v)
    err = np.sum(r * r, axis=1)
    for _ in range(n_iter):
        # Jacobian columns dX/du and dX/dv at the current estimate.
        du, dv = _solve_gram_2x2(
            e_u + v[:, None] * twist, e_v + u[:, None] * twist, r
        )
        u_try = u + du
        v_try = v + dv
        r_try = _residual(u_try, v_try)
        err_try = np.sum(r_try * r_try, axis=1)
        # Comparisons with NaN are False, so unstable steps are dropped.
        improved = err_try < err
        if not improved.any():
            break
        u = np.where(improved, u_try, u)
        v = np.where(improved, v_try, v)
        r = np.where(improved[:, None], r_try, r)
        err = np.where(improved, err_try, err)
    return u, v


def _bilinear_weights(
    sender_mesh: RectangularMesh,
    receiver_nodes: NDArray,
    newton_iters: int = 5,
) -> Tuple[NDArray, NDArray, NDArray]:
    """Compute bilinear interpolation weights on quad faces (degree=4).

    Each receiver node is projected onto the nearest triangle of the
    triangulated sender surface. That triangle is mapped back to its quad
    face ``ABCD`` (corners in ``faces`` order), and the bilinear map

        X(u, v) = (1-u)(1-v) A + u(1-v) B + u v C + (1-u) v D

    is inverted to obtain the four node weights. The full bilinear map is
    inverted, not just the parallelogram spanned by ``B - A`` and
    ``D - A``. This keeps the weights exact at the corners of
    non-parallelogram cells such as trapezoids.

    Args:
        sender_mesh: A rectangular mesh.
        receiver_nodes: Query points, shape (M, 3).
        newton_iters: Maximum number of Gauss-Newton refinement steps.
            ``0`` reproduces the parallelogram-only estimate.

    Returns:
        ``(row, col, val)`` arrays for sparse matrix construction.
    """
    sender_trimesh = sender_mesh.build_trimesh()
    closest_pts, _, tri_indices = trimesh.proximity.closest_point(
        sender_trimesh, receiver_nodes
    )

    # Map triangle -> quad face, then gather the quad's corner indices.
    face_indices = sender_mesh.triangle2face_index(tri_indices)
    quad_verts = sender_mesh.faces[face_indices]  # (M, 4)
    A = sender_mesh.nodes[quad_verts[:, 0]]
    B = sender_mesh.nodes[quad_verts[:, 1]]
    C = sender_mesh.nodes[quad_verts[:, 2]]
    D = sender_mesh.nodes[quad_verts[:, 3]]

    u, v = _invert_bilinear(A, B, C, D, closest_pts, newton_iters)
    # The projected point lies on the triangulated surface, which may
    # deviate slightly from the bilinear patch; keep weights convex.
    u = np.clip(u, 0, 1)
    v = np.clip(v, 0, 1)

    # Bilinear weights: A=(1-u)(1-v), B=u(1-v), C=uv, D=(1-u)v
    weights = np.column_stack(
        [(1 - u) * (1 - v), u * (1 - v), u * v, (1 - u) * v]
    )
    weights /= weights.sum(axis=1, keepdims=True)

    M = receiver_nodes.shape[0]
    rows = np.repeat(np.arange(M), 4)
    cols = quad_verts.ravel()
    vals = weights.ravel()
    return rows, cols, vals
def _poly_basis(X: NDArray, degree: int) -> NDArray:
    """Evaluate every 3-D monomial of total degree <= *degree* at *X*.

    Columns are ordered by total degree first (``1``; ``x, y, z``; ...).
    Within one degree they go lexicographically by descending power of
    x, then y, so degree 2 contributes ``x^2, xy, xz, y^2, yz, z^2``.

    Args:
        X: Coordinates on the last axis, shape ``(..., 3)``.
        degree: Maximum total polynomial degree; values below 1 yield
            only the constant column.

    Returns:
        Basis matrix, shape ``(..., m)`` where
        ``m = (d+1)(d+2)(d+3)/6``.
    """
    blocks = [np.ones(X.shape[:-1] + (1,))]
    if degree < 1:
        return np.concatenate(blocks, axis=-1)

    # Degree 1: the coordinates themselves.
    blocks.append(X)

    if degree >= 2:
        x, y, z = X[..., 0:1], X[..., 1:2], X[..., 2:3]
        exponents = [
            (a, b, total - a - b)
            for total in range(2, degree + 1)
            for a in range(total, -1, -1)
            for b in range(total - a, -1, -1)
        ]
        blocks.extend(x ** a * y ** b * z ** c for a, b, c in exponents)

    return np.concatenate(blocks, axis=-1)
def _local_rbf_weights(
    sender_nodes: NDArray,
    receiver_nodes: NDArray,
    distances: NDArray,
    indices: NDArray,
    kernel: str = "thin_plate_spline",
    degree: int = 1,
) -> Tuple[NDArray, NDArray, NDArray]:
    """Compute local RBF interpolation weights (vectorized).

    For each receiver node, a local kernel system is built from its k
    nearest sender neighbors, optionally augmented with a polynomial
    basis, and solved for the interpolation weights. The M systems are
    solved with batched ``np.linalg.solve`` calls.

    With ``degree >= 1``, some rows are re-solved with a constant
    (degree 0) augmentation instead. These are rows whose augmented
    system is rank deficient, including every row when ``k`` is smaller
    than the number of polynomial terms, and rows whose weights come out
    unstable.

    Args:
        sender_nodes: Sender node coordinates, shape (N, 3).
        receiver_nodes: Receiver node coordinates, shape (M, 3).
        distances: Pre-computed distances to k neighbors, shape (M, k).
        indices: Indices of k neighbors in sender_nodes, shape (M, k).
        kernel: RBF kernel name (see ``_rbf_kernel``).
        degree: Polynomial augmentation degree (0=constant, 1=linear,
            2=quadratic, etc.). ``None`` or a negative value disables
            augmentation (pure RBF solve).

    Returns:
        ``(row, col, val)`` arrays for sparse matrix construction.

    Raises:
        numpy.linalg.LinAlgError: If ``np.linalg.solve`` reports a
            singular system in one of the batched solves.
    """
    M, k = indices.shape

    # Gather local neighborhoods: (M, k, 3)
    X = sender_nodes[indices]

    # Pairwise distances within each neighborhood, (M, k, k), via
    # ||a-b||^2 = ||a||^2 + ||b||^2 - 2*a.b to avoid an (M, k, k, 3) tensor.
    XX = np.sum(X ** 2, axis=-1)  # (M, k)
    dots = np.einsum("mij,mkj->mik", X, X)  # (M, k, k)
    dist_sq = XX[:, :, None] + XX[:, None, :] - 2 * dots
    # Cancellation can make tiny squared distances slightly negative.
    dist_sq = np.maximum(dist_sq, 0)
    Phi = _rbf_kernel(np.sqrt(dist_sq), kernel)

    # Kernel between each receiver and its neighbors: (M, k)
    phi = _rbf_kernel(distances, kernel)

    # Minimal regularization to prevent exact singularity. The post-solve
    # weight check (phase 3 below) handles ill-conditioned systems by
    # falling back to degree=0.
    Phi[:, range(k), range(k)] += 1e-12

    def _build_augmented(Phi_sub, phi_sub, P_sub, p_q_sub):
        """Build and solve the augmented RBF system [Phi P; P^T 0]."""
        M_s, k_s = phi_sub.shape
        m_s = P_sub.shape[2]
        n_s = k_s + m_s
        A = np.zeros((M_s, n_s, n_s))
        A[:, :k_s, :k_s] = Phi_sub
        A[:, :k_s, k_s:] = P_sub
        A[:, k_s:, :k_s] = P_sub.transpose(0, 2, 1)
        rhs = np.zeros((M_s, n_s, 1))
        rhs[:, :k_s, 0] = phi_sub
        rhs[:, k_s:, 0] = p_q_sub
        # Only the kernel coefficients are weights; drop polynomial ones.
        return np.linalg.solve(A, rhs)[:, :k_s, 0]

    if degree is not None and degree >= 0:
        # Polynomial augmentation for conditionally positive definite
        # kernels. Coordinates are centered on the local mean to improve
        # conditioning: on the unit sphere raw xyz values are ~1 while
        # local variations are ~0.01, which causes near-singularity.
        center = X.mean(axis=1, keepdims=True)  # (M, 1, 3)
        X_c = X - center  # (M, k, 3)
        q_c = receiver_nodes - center[:, 0, :]  # (M, 3)
        P = _poly_basis(X_c, degree)  # (M, k, m)
        p_q = _poly_basis(q_c[:, None, :], degree)[:, 0, :]  # (M, m)

        if degree == 0:
            w = _build_augmented(Phi, phi, P, p_q)
        else:
            m = P.shape[2]
            # Phase 1: flag rows whose polynomial block is rank deficient.
            if k < m:
                # Fewer neighbors than polynomial terms: the augmented
                # matrix has rank <= 2k < k + m, i.e. it is singular for
                # every row. An SVD of P cannot reveal this, because P
                # can still have full row rank.
                bad = np.ones(M, dtype=bool)
            else:
                sv = np.linalg.svd(P, compute_uv=False)
                bad = sv[:, -1] < 1e-10 * sv[:, 0]
            attempt = ~bad

            w = np.empty((M, k))
            if attempt.any():
                # Phase 2: solve the rows that passed the rank check.
                w[attempt] = _build_augmented(
                    Phi[attempt], phi[attempt], P[attempt], p_q[attempt])
                # Phase 3: post-solve safety net -- catch remaining
                # instabilities not detected by the check on P alone.
                max_abs = np.max(np.abs(w[attempt]), axis=1)
                solve_bad = max_abs > 2.0
                if solve_bad.any():
                    bad[np.flatnonzero(attempt)[solve_bad]] = True

            # Phase 4: fall back to constant augmentation for bad rows.
            if bad.any():
                n_bad = int(bad.sum())
                P0 = np.ones((n_bad, k, 1))
                p_q0 = np.ones((n_bad, 1))
                w[bad] = _build_augmented(Phi[bad], phi[bad], P0, p_q0)
    else:
        # Pure RBF interpolation without polynomial augmentation.
        w = np.linalg.solve(Phi, phi[:, :, None])[:, :, 0]

    rows = np.repeat(np.arange(M), k)
    cols = indices.ravel()
    vals = w.ravel()
    return rows, cols, vals
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _poly_terms(degree: int, dim: int) -> int:
    """Count the monomials of total degree <= *degree* in *dim* variables.

    This is the binomial coefficient ``C(degree + dim, dim)``, which for
    ``dim = 3`` equals ``(d+1)(d+2)(d+3)/6`` -- the number of columns
    produced by ``_poly_basis``.
    """
    import math  # pylint: disable=C0415

    top = degree + dim
    return math.comb(top, dim)
# Sentinel default for MeshTransfer.build_weights: distinguishes an omitted
# argument from an explicitly passed ``None`` (e.g. ``max_dist=None``).
_UNSET: object = object()
# ---------------------------------------------------------------------------
# MeshTransfer class
# ---------------------------------------------------------------------------
class MeshTransfer:
    """Sparse regridder between two spherical meshes.

    Builds a sparse weight matrix ``W`` of shape
    ``(receiver.num_nodes, sender.num_nodes)`` so that ``W @ data``
    transfers values from the sender grid to the receiver grid.
    Weights are computed lazily on the first :meth:`transform` call (or
    first access of :attr:`weights`), or explicitly via
    :meth:`build_weights`.

    Args:
        sender: Source mesh (any :class:`Mesh` subclass).
        receiver: Target mesh.
        method: Interpolation method.

            - ``"nearest"``: 1-nearest-neighbor.
            - ``"idw"``: Inverse-distance weighting.
            - ``"gaussian"``: Gaussian distance weighting.
            - ``"barycentric"``: Barycentric on triangulated faces.
            - ``"bilinear"``: Bilinear on quad faces
              (:class:`RectangularMesh` sender only).
            - ``"local_rbf"``: Local RBF interpolation (default).

        k: Number of neighbors for kNN-based methods.
        kernel: RBF kernel name (for ``"local_rbf"``).
        degree: Polynomial augmentation degree (for ``"local_rbf"``).
            0 = constant, 1 = linear, 2 = quadratic, etc. ``None`` is
            treated as 1.
        max_dist: Maximum neighbor distance (used by ``"local_rbf"``
            only). Neighbors beyond this are pruned. Use ``"auto"`` for
            the 90th-percentile heuristic.

    Example::

        regridder = MeshTransfer(ocean_mesh, target_grid,
                                 method="local_rbf", k=16, degree=0)
        sst_regridded = regridder.transform(sst_ocean)

        # Or use the @ operator
        result = regridder @ data

        # Access the sparse matrix directly
        W = regridder.weights
    """

    def __init__(
        self,
        sender: Mesh,
        receiver: Mesh,
        method: str = "local_rbf",
        k: int = 16,
        kernel: str = "thin_plate_spline",
        degree: int = 1,
        max_dist: float = None,
    ) -> None:
        self.sender = sender
        self.receiver = receiver
        self.method = method
        self.k = k
        self.kernel = kernel
        self.degree = degree
        self.max_dist = max_dist
        # Sparse weight matrix, built lazily by build_weights().
        self._weights = None
        # Cached k-NN results (see _ensure_neighbors), reused across
        # rebuilds while enough neighbors are available.
        self._nearest_senders = None
        self._nearest_distances = None

    def __repr__(self) -> str:
        n_s = f"{self.sender.num_nodes:,}"
        n_r = f"{self.receiver.num_nodes:,}"
        s = f"MeshTransfer({n_s} -> {n_r}"
        s += f", method='{self.method}', k={self.k}"
        if self.method == "local_rbf":
            s += f", kernel='{self.kernel}', degree={self.degree}"
        if self.max_dist is not None:
            s += f", max_dist={self.max_dist}"
        built = "built" if self._weights is not None else "not built"
        s += f", {built})"
        return s

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the weight matrix: ``(num_receiver, num_sender)``."""
        return (self.receiver.num_nodes, self.sender.num_nodes)

    @property
    def weights(self) -> sparse.csr_matrix:
        """The sparse weight matrix. Built lazily on first access."""
        if self._weights is None:
            self.build_weights()
        return self._weights

    # ------------------------------------------------------------------
    # Core public API
    # ------------------------------------------------------------------
    def build_weights(
        self,
        method: str = _UNSET,
        k: int = _UNSET,
        kernel: str = _UNSET,
        degree: int = _UNSET,
        max_dist: float = _UNSET,
    ) -> sparse.csr_matrix:
        """Build the sparse interpolation weight matrix.

        When called with no arguments, uses the parameters stored on the
        instance. Any provided argument first updates the stored
        configuration (see the class docstring for meanings). The matrix
        is rebuilt on every call.

        Args:
            method, k, kernel, degree, max_dist: Optional overrides of the
                corresponding instance attributes.

        Returns:
            Sparse CSR matrix of shape :attr:`shape`.

        Raises:
            ValueError: If the method is unknown, ``"bilinear"`` is used
                with a non-rectangular sender, ``k < 1`` for a
                neighbor-based method, or ``max_dist`` is a string other
                than ``"auto"``.
        """
        if method is not _UNSET:
            self.method = method
        if k is not _UNSET:
            self.k = k
        if kernel is not _UNSET:
            self.kernel = kernel
        if degree is not _UNSET:
            self.degree = degree
        if max_dist is not _UNSET:
            self.max_dist = max_dist
        # Discard the old matrix first so that a failed rebuild cannot
        # leave weights from a previous configuration behind.
        self._weights = None

        n_recv = self.receiver.num_nodes
        n_send = self.sender.num_nodes

        if self.method == "barycentric":
            rows, cols, vals = _barycentric_weights(
                self.sender, self.receiver.nodes
            )
        elif self.method == "bilinear":
            if not isinstance(self.sender, RectangularMesh):
                raise ValueError(
                    "bilinear method requires a RectangularMesh sender"
                )
            rows, cols, vals = _bilinear_weights(
                self.sender, self.receiver.nodes
            )
        elif self.method in ("nearest", "idw", "gaussian", "local_rbf"):
            n_nbrs = self.k
            if n_nbrs < 1:
                raise ValueError(f"k must be >= 1, got {n_nbrs}")
            self._ensure_neighbors(n_nbrs)
            dists = self._nearest_distances[:, :n_nbrs]
            idxs = self._nearest_senders[:, :n_nbrs]
            if self.method != "local_rbf":
                rows, cols, vals = _knn_weights(dists, idxs, self.method)
            else:
                # degree=None is treated as linear augmentation.
                deg = self.degree if self.degree is not None else 1
                max_d = self.max_dist
                if isinstance(max_d, str):
                    if max_d != "auto":
                        raise ValueError(
                            "max_dist must be a number, None or 'auto', "
                            f"got {max_d!r}"
                        )
                    # Heuristic cutoff: 90th percentile of each
                    # receiver's mean neighbor distance.
                    max_d = np.percentile(dists.mean(axis=1), 90)
                if max_d is not None:
                    rows, cols, vals = self._build_weights_thresholded(
                        dists, idxs, self.kernel, deg, max_d,
                    )
                else:
                    rows, cols, vals = _local_rbf_weights(
                        self.sender.nodes,
                        self.receiver.nodes,
                        dists,
                        idxs,
                        kernel=self.kernel,
                        degree=deg,
                    )
        else:
            # Checked before any neighbor query so a typo fails fast.
            raise ValueError(f"Unknown method: {self.method}")

        W = sparse.csr_matrix((vals, (rows, cols)), shape=(n_recv, n_send))
        self._weights = W
        return W

    def transform(self, data: NDArray) -> NDArray:
        """Transfer values from the sender nodes to the receiver nodes.

        Builds the weights on first use, then returns ``W @ data``.

        Args:
            data: Values on the sender nodes. The first axis must have
                length ``sender.num_nodes``; any trailing axes are
                carried through unchanged.

        Returns:
            Array of shape ``(receiver.num_nodes,) + data.shape[1:]``.

        Raises:
            ValueError: If the leading dimension of *data* does not match
                the number of sender nodes.
        """
        data = np.asarray(data)
        n_send = self.sender.num_nodes
        if data.ndim == 0 or data.shape[0] != n_send:
            raise ValueError(
                f"data has shape {data.shape}; expected leading dimension "
                f"{n_send} (sender.num_nodes)"
            )
        W = self.weights
        trailing = data.shape[1:]
        # Sparse @ dense works on 1-D/2-D operands, so fold trailing axes.
        flat = data.reshape(n_send, int(np.prod(trailing, dtype=int)))
        out = np.asarray(W @ flat)
        return out.reshape((W.shape[0],) + trailing)

    def __matmul__(self, data: NDArray) -> NDArray:
        """Allow ``regridder @ data`` as shorthand for :meth:`transform`."""
        return self.transform(data)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _ensure_neighbors(self, k: int, recompute: bool = False):
        """Ensure at least *k* neighbors per receiver have been computed.

        Results are cached on the instance as 2-D ``(num_receiver, k)``
        arrays. They are re-queried only when more neighbors are needed
        or *recompute* is true.
        """
        need = (
            self._nearest_senders is None
            or self._nearest_senders.shape[1] < k
            or recompute
        )
        if not need:
            return
        dists, idxs = query_nearest(
            self.sender.nodes,
            self.receiver.nodes,
            n_neighbors=k,
        )
        # NOTE(review): query_nearest is expected to return 1-D arrays
        # when a single neighbor is requested -- confirm. Normalise by the
        # actual rank so already-2-D results are not expanded again.
        dists = np.asarray(dists)
        idxs = np.asarray(idxs)
        if dists.ndim == 1:
            dists = dists[:, None]
        if idxs.ndim == 1:
            idxs = idxs[:, None]
        self._nearest_distances = dists
        self._nearest_senders = idxs

    def _build_weights_thresholded(
        self,
        dists: NDArray,
        idxs: NDArray,
        kernel: str,
        degree: int,
        max_dist: float,
    ) -> Tuple[NDArray, NDArray, NDArray]:
        """Build local RBF weights with distance-based neighbor pruning.

        Receivers are grouped by their number ``nv`` of neighbors within
        *max_dist*, so each group can be solved as one batch:

        - ``nv == 0``: no weights are emitted. The receiver's matrix row
          is empty and it receives 0 from ``W @ data``.
        - ``nv == 1``: the single neighbor gets weight 1.
        - otherwise: local RBF with *degree*, lowered to 0 when ``nv`` is
          smaller than the number of polynomial terms.

        NOTE(review): using the first ``nv`` columns assumes the neighbor
        arrays are sorted by increasing distance -- confirm for
        ``query_nearest``.
        """
        n_valid = np.sum(dists <= max_dist, axis=1)
        rows_all, cols_all, vals_all = [], [], []
        sender_nodes = self.sender.nodes
        receiver_nodes = self.receiver.nodes
        for nv in np.unique(n_valid):
            if nv == 0:
                # No sender within max_dist: leave this receiver's row empty.
                continue
            group = np.where(n_valid == nv)[0]
            if nv == 1:
                rows_all.append(group)
                cols_all.append(idxs[group, 0])
                vals_all.append(np.ones(len(group)))
                continue
            deg = degree
            if deg >= 1 and nv < _poly_terms(deg, 3):
                # Too few neighbors for the requested polynomial.
                deg = 0
            r, c, v = _local_rbf_weights(
                sender_nodes,
                receiver_nodes[group],
                dists[group, :nv],
                idxs[group, :nv],
                kernel=kernel,
                degree=deg,
            )
            # Map group-local row ids back to global receiver ids.
            rows_all.append(group[r])
            cols_all.append(c)
            vals_all.append(v)
        if rows_all:
            return (np.concatenate(rows_all),
                    np.concatenate(cols_all),
                    np.concatenate(vals_all))
        return np.array([], dtype=int), np.array([], dtype=int), np.array([])