A note on the Fokker-Planck-Kolmogorov equation
This is a test note for checking math and code rendering on the site. It is not meant to be a polished article yet.
One useful way to look at a stochastic differential equation is to follow not individual sample paths but the density of a population of particles evolving under the same dynamics.
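To make the population picture concrete, here is a minimal sketch that simulates many particles of one simple SDE and bins their final positions. It assumes a one-dimensional Ornstein-Uhlenbeck process \(dX_t = -\theta X_t\,dt + s\,dW_t\) with illustrative values \(\theta = 1\) and \(s = 0.5\), and it uses the Euler-Maruyama scheme. `simulate_particles` is a hypothetical helper, not a library function.

```python
import jax
import jax.numpy as jnp


def simulate_particles(key, x0, drift, s, t_final, n_steps):
    """Euler-Maruyama for dX = drift(X) dt + s dW, applied to every particle at once."""
    dt = t_final / n_steps

    def step(x, step_key):
        noise = jax.random.normal(step_key, x.shape)  # independent N(0, 1) draw per particle
        return x + drift(x) * dt + s * jnp.sqrt(dt) * noise, None

    keys = jax.random.split(key, n_steps)  # one PRNG key per time step
    x_final, _ = jax.lax.scan(step, x0, keys)
    return x_final


# Hypothetical example: 1D Ornstein-Uhlenbeck process, all particles starting at x = 2.
theta, s = 1.0, 0.5
x0 = jnp.full((20_000,), 2.0)
samples = simulate_particles(jax.random.PRNGKey(0), x0, lambda x: -theta * x, s, 1.0, 200)

# A normalized histogram of the particle positions is a rough estimate of rho_t at t = 1.
density, edges = jnp.histogram(samples, bins=50, density=True)
```

For this process the exact density at time \(t\) is Gaussian with mean \(2e^{-\theta t}\) and variance \(s^2(1 - e^{-2\theta t})/(2\theta)\), so `samples.mean()` and `samples.var()` should land near 0.74 and 0.11 at \(t = 1\), up to Monte Carlo and time-step error.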
Suppose \(X_t \in \mathbb{R}^d\) follows the Itô process
\[ dX_t = b(X_t,t)\,dt + \sigma(X_t,t)\,dW_t, \]
where \(b(x,t) \in \mathbb{R}^d\) is the drift, \(\sigma(x,t) \in \mathbb{R}^{d \times m}\) is the diffusion coefficient, and \(W_t\) is an \(m\)-dimensional standard Brownian motion. If \(\rho_t(x)\) is the probability density of \(X_t\), then \(\rho_t\) satisfies the Fokker-Planck-Kolmogorov equation
\[ \frac{\partial \rho_t(x)}{\partial t} = - \nabla \cdot \left( b(x,t)\rho_t(x) \right) + \frac{1}{2} \sum_{i,j} \frac{\partial^2}{\partial x_i \partial x_j} \left( a_{ij}(x,t)\rho_t(x) \right), \qquad a = \sigma \sigma^\top . \]
In this form, the equation says that probability mass moves because of two effects: directed transport along the drift field \(b\) (the divergence term), and spreading due to diffusion (the second-order term in \(a\)). This split is one reason the equation keeps reappearing in flow matching, score-based models, stochastic control, and physical simulation.
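Integrating the equation directly in one dimension shows both effects at work. The sketch below is a minimal explicit finite-difference scheme, assuming an Ornstein-Uhlenbeck drift \(b(x) = -\theta x\), a constant \(\sigma = s\) (so \(a = s^2\)), a Gaussian initial density, and a grid wide enough that the density is effectively zero at its edges. `fpk_step`, the grid, and all parameter values are illustrative choices, not a recommended solver.

```python
import jax
import jax.numpy as jnp


def fpk_step(rho, x, dx, dt, theta, s):
    """One explicit Euler step of the 1D equation with b(x) = -theta * x and a = s**2."""
    flux = -theta * x * rho  # b * rho: transport by the drift
    spread = s**2 * rho      # a * rho: spreading by diffusion
    interior = (
        -(flux[2:] - flux[:-2]) / (2 * dx)                             # -d(b rho)/dx, central difference
        + 0.5 * (spread[2:] - 2 * spread[1:-1] + spread[:-2]) / dx**2  # 1/2 d^2(a rho)/dx^2
    )
    # Edge values are left unchanged; the density is assumed to be ~0 there.
    return rho.at[1:-1].add(dt * interior)


# Illustrative setup: grid on [-2, 4], initial Gaussian with mean 2 and variance 0.01.
theta, s = 1.0, 0.5
x = jnp.linspace(-2.0, 4.0, 601)
dx = float(x[1] - x[0])
rho0 = jnp.exp(-0.5 * (x - 2.0) ** 2 / 0.01)
rho0 = rho0 / (rho0.sum() * dx)

# Integrate to t = 1; dt is kept below dx**2 / s**2, the explicit-diffusion stability limit.
dt, n_steps = 2e-4, 5000
rho1 = jax.lax.fori_loop(0, n_steps, lambda _, r: fpk_step(r, x, dx, dt, theta, s), rho0)

mass = rho1.sum() * dx
mean = (x * rho1).sum() * dx / mass
var = ((x - mean) ** 2 * rho1).sum() * dx / mass
```

Because the drift is linear, the exact solution stays Gaussian with mean \(2e^{-\theta t}\) and variance \(0.01\,e^{-2\theta t} + s^2(1-e^{-2\theta t})/(2\theta)\), so `mean` and `var` can be compared with roughly 0.736 and 0.109 at \(t = 1\). Central differences for the drift term are reasonable here because diffusion dominates on this grid; strongly drift-dominated problems usually call for upwinding instead.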
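In code, one compact way to write the right-hand side is to treat the density \(\rho(x,t) = \rho_t(x)\), the drift \(b(x,t)\), and the diffusion matrix \(a(x,t)\) as callable functions and let automatic differentiation handle the derivatives. The Ornstein-Uhlenbeck inputs at the end of the block are illustrative choices used only to make it runnable.

```python
import jax
import jax.numpy as jnp

def fpk_rhs(rho, drift, diffusion, x, t):
    """Right-hand side of the FPK equation at (x, t); diffusion(y, t) must return a = sigma @ sigma.T."""
    def flux(y):
        return drift(y, t) * rho(y, t)  # b * rho, shape (d,)

    def diffusion_term(y):
        return diffusion(y, t) * rho(y, t)  # a * rho, shape (d, d)

    # -div(b rho) is minus the trace of the Jacobian of the flux.
    transport = -jnp.trace(jax.jacfwd(flux)(x))
    # hessian[i, j, k, l] = d^2 (a_ij rho) / dx_k dx_l; sum the entries with k = i, l = j.
    hessian = jax.hessian(diffusion_term)(x)
    spreading = 0.5 * jnp.einsum("ijij->", hessian)
    return transport + spreading

# Illustrative inputs (not part of the method): dX = -X dt + sqrt(2) dW, an
# Ornstein-Uhlenbeck process whose stationary density is the standard Gaussian.
d = 2
rho = lambda y, t: jnp.exp(-0.5 * y @ y) / (2 * jnp.pi) ** (d / 2)
drift = lambda y, t: -y
diffusion = lambda y, t: 2.0 * jnp.eye(d)
x, t = jnp.array([0.3, -1.2]), 0.0

# Gradient with respect to x; the right-hand side vanishes for these inputs,
# so this should be close to zero.
grad_fpk_rhs = jax.grad(lambda x, t: fpk_rhs(rho, drift, diffusion, x, t))
gradient_at_x = grad_fpk_rhs(x, t)
```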
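The same right-hand side can also be organized as a continuity equation, \(\partial_t \rho_t = -\nabla \cdot J\), with probability current \(J_i = b_i \rho_t - \tfrac{1}{2} \sum_j \partial_{x_j}(a_{ij}\rho_t)\); expanding the divergence recovers the two terms of the equation. The sketch below computes \(J\) with automatic differentiation, using the same callable conventions as `fpk_rhs`, and checks the result against pure diffusion (\(b = 0\), \(a = I\)), where a Gaussian with covariance \((1+t)I\) is an exact solution. The function names and test values are hypothetical.

```python
import jax
import jax.numpy as jnp


def probability_current(rho, drift, diffusion, x, t):
    """J_i = b_i * rho - 1/2 * sum_j d(a_ij * rho)/dx_j at the point x."""
    def a_rho(y):
        return diffusion(y, t) * rho(y, t)  # shape (d, d)

    jac = jax.jacfwd(a_rho)(x)  # jac[i, j, k] = d(a_ij rho)/dx_k
    return drift(x, t) * rho(x, t) - 0.5 * jnp.einsum("ijj->i", jac)


def fpk_rhs_from_current(rho, drift, diffusion, x, t):
    """d rho/dt = -div J: the continuity-equation form of the right-hand side."""
    current = lambda y: probability_current(rho, drift, diffusion, y, t)
    return -jnp.trace(jax.jacfwd(current)(x))


# Illustrative check: pure diffusion (b = 0, a = I) in d = 2, where a Gaussian
# with covariance (1 + t) I solves the equation exactly.
d = 2


def heat_rho(y, t):
    var = 1.0 + t
    return jnp.exp(-0.5 * (y @ y) / var) / (2 * jnp.pi * var) ** (d / 2)


zero_drift = lambda y, t: jnp.zeros(d)
unit_diffusion = lambda y, t: jnp.eye(d)

x, t = jnp.array([0.4, -0.7]), 0.5
lhs = jax.grad(heat_rho, argnums=1)(x, t)  # exact time derivative of the density
rhs = fpk_rhs_from_current(heat_rho, zero_drift, unit_diffusion, x, t)
```

Keeping \(J\) as its own function can be convenient when the current itself is of interest, for instance when checking a zero-flux boundary condition \(J \cdot n = 0\).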