GPU / JAX / CuPy
findiff can transparently operate on JAX and CuPy arrays in addition
to standard NumPy arrays. All derivative operators (Diff, Gradient,
Divergence, Curl, Laplacian) and operator compositions (addition,
multiplication, exponentiation) work out of the box — just pass an array from
any supported backend.
Installation
findiff detects JAX and CuPy at runtime; neither is a hard dependency. Install whichever backend you need:
# JAX — CPU only
pip install jax
# JAX — NVIDIA GPU
pip install "jax[cuda12]"
# CuPy — NVIDIA GPU
pip install cupy-cuda12x
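To check that a GPU is actually visible to the chosen backend, a quick sanity check (assuming a CUDA device is present) is:
import jax
print(jax.devices())                      # e.g. [CudaDevice(id=0)] on GPU, [CpuDevice(id=0)] on CPU
import cupy as cp
print(cp.cuda.runtime.getDeviceCount())   # number of visible CUDA devices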
Basic usage
Pass a JAX (or CuPy) array where you would normally pass a NumPy array:
import jax
import jax.numpy as jnp
from findiff import Diff
jax.config.update("jax_enable_x64", True)
x = jnp.linspace(0, 2 * jnp.pi, 1000)
dx = float(x[1] - x[0])
f = jnp.sin(x)
d_dx = Diff(0, dx)
df_dx = d_dx(f) # returns a JAX array
type(df_dx) # jaxlib.xla_extension.ArrayImpl
The result stays on the same backend as the input — no implicit copies back to NumPy.
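The CuPy path looks the same; a minimal sketch, assuming cupy-cuda12x is installed and a CUDA device is available:
import cupy as cp
from findiff import Diff
x = cp.linspace(0, 2 * cp.pi, 1000)
dx = float(x[1] - x[0])
f = cp.sin(x)
d_dx = Diff(0, dx)
df_dx = d_dx(f)   # returns a CuPy array; the data stays on the GPU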
Using jax.jit for speed
Without JIT, JAX dispatches each NumPy-like operation individually, which can be slower than plain NumPy due to dispatch overhead. The real speedup comes from JIT compilation, which fuses all the slice and arithmetic operations inside the operator into a single optimized kernel:
d_dx_jit = jax.jit(d_dx)
# First call traces + compiles (slow):
result = d_dx_jit(f)
# Subsequent calls reuse the compiled kernel (fast):
result = d_dx_jit(f)
This works for any operator, including composed operators and vector calculus shortcuts:
from findiff import Laplacian
lap = Laplacian(h=[dx, dy, dz])   # dx, dy, dz: spacings of a 3D grid
lap_jit = jax.jit(lap)
result = lap_jit(f_3d)            # f_3d: a 3D JAX array
Tip
Always call .block_until_ready() when benchmarking JAX, since JAX uses
asynchronous dispatch:
result = lap_jit(f_3d).block_until_ready()
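For example, a rough timing of the compiled Laplacian from above could look like this (time.perf_counter and the warm-up call are illustrative, not part of findiff):
import time
lap_jit = jax.jit(lap)
lap_jit(f_3d).block_until_ready()              # warm-up: trace + compile
t0 = time.perf_counter()
lap_jit(f_3d).block_until_ready()              # timed: compiled kernel only
print(f"compiled call: {time.perf_counter() - t0:.6f} s")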
Non-uniform grids
Non-uniform grid coordinates can be passed as JAX or CuPy arrays. They are converted to NumPy internally for coefficient computation (which happens once at operator construction time), while the operator application still runs on the GPU backend:
import numpy as np
x = np.linspace(0, np.pi, 500) ** 2   # non-uniform (quadratically spaced) NumPy coords for construction
f = jnp.array(np.sin(x))              # JAX array for application
d_dx = Diff(0, x)
result = d_dx(f)                      # returns a JAX array
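Passing the coordinates as a JAX array works as well, as noted above; a sketch of the equivalent construction:
x_jax = jnp.array(x)      # JAX coords; converted to NumPy once at construction
d_dx = Diff(0, x_jax)
result = d_dx(f)          # application still returns a JAX array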
Vector calculus
Gradient, Divergence, Curl, and Laplacian all support
alternative backends:
from findiff import Gradient, Laplacian
grad = Gradient(h=[dx, dy], acc=4)
grad_f = grad(f_2d_jax) # returns a JAX array
lap = Laplacian(h=[dx, dy])
lap_f = lap(f_2d_jax) # returns a JAX array
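Divergence and Curl follow the same pattern; a sketch, assuming f_vec_jax is a 3-component vector field stored as a JAX array of shape (3, nx, ny, nz) on a grid with spacings dx, dy, dz:
from findiff import Divergence, Curl
div = Divergence(h=[dx, dy, dz])
curl = Curl(h=[dx, dy, dz])
div_f = div(f_vec_jax)     # scalar field, returned as a JAX array
curl_f = curl(f_vec_jax)   # vector field, returned as a JAX array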
Operator composition
Composed operators with scalar or array coefficients work as expected:
from findiff import Diff, Identity
d2 = Diff(0, dx) ** 2
L = d2 + Identity() # d²/dx² + 1
L_jit = jax.jit(L)
result = L_jit(f_jax)
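A multi-dimensional composition built from the same pieces can also be JIT-compiled; a sketch, assuming f_2d_jax is a 2D JAX array on a grid with spacings dx and dy:
from findiff import Diff, Identity
helmholtz = Diff(0, dx) ** 2 + Diff(1, dy) ** 2 + Identity()   # d²/dx² + d²/dy² + 1
helmholtz_jit = jax.jit(helmholtz)
result = helmholtz_jit(f_2d_jax)   # returns a JAX array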
What is not supported on GPU
The following features remain NumPy / SciPy only and will raise errors or return NumPy arrays if called with GPU data:
- .matrix(shape) — returns a scipy.sparse matrix
- PDE and BoundaryConditions — use scipy.sparse.linalg solvers
- TimeDependentPDE / MOLSolution — implicit time steppers use sparse solvers
- CompactScheme operators — use scipy.sparse.linalg.splu
- Stencil / StencilSet — construction depends on the matrix path
- .eigs() / .eigsh() — use scipy.sparse.linalg
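If one of these features is needed, transfer the data back to NumPy first; a minimal sketch (f_jax and f_cp stand for whatever GPU arrays you hold):
import numpy as np
f_np = np.asarray(f_jax)        # JAX -> NumPy (copies back to the host)
# f_np = cp.asnumpy(f_cp)       # CuPy -> NumPy equivalent
mat = d_dx.matrix(f_np.shape)   # scipy.sparse matrix, built on the CPU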
Running the benchmarks
The test suite includes benchmarks that compare NumPy and JAX performance.
They are excluded from normal pytest runs and can be executed with:
pytest -m benchmark -v -s --override-ini='addopts='