A note on the Fokker-Planck-Kolmogorov equation

This is a test note for checking math and code rendering on the site. It is not meant to be a polished article yet.

One useful way to look at a stochastic differential equation is to follow not individual sample paths but the density of a whole population of particles evolving under the same dynamics.

Suppose \(X_t \in \mathbb{R}^d\) follows the Itô process

\[ dX_t = b(X_t,t)\,dt + \sigma(X_t,t)\,dW_t, \]

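where \(b(x,t) \in \mathbb{R}^d\) is the drift, \(\sigma(x,t) \in \mathbb{R}^{d \times m}\) is the diffusion coefficient, and \(W_t\) is a standard \(m\)-dimensional Brownian motion. If \(X_t\) has a probability density \(\rho_t(x)\), then \(\rho_t\) satisfies the Fokker-Planck-Kolmogorov (forward Kolmogorov) equation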

\[ \frac{\partial \rho_t(x)}{\partial t} = - \nabla \cdot \left( b(x,t)\rho_t(x) \right) + \frac{1}{2} \sum_{i,j} \frac{\partial^2}{\partial x_i \partial x_j} \left( a_{ij}(x,t)\rho_t(x) \right), \qquad a = \sigma \sigma^\top . \]

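In this form, the equation says that probability mass moves under two effects: directed transport along the drift field, through the divergence term \(-\nabla \cdot (b\rho_t)\), and spreading from diffusion, through the second-order term in \(a\). Because it describes how a whole distribution evolves rather than a single trajectory, the same equation keeps reappearing in flow matching, score-based models, stochastic control, and physical simulation.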
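To see both effects on a population of particles, take the simplest case: \(d = 1\) with constant drift \(b\) and constant \(\sigma\). The equation then reduces to

\[ \frac{\partial \rho_t}{\partial t} = -b\,\frac{\partial \rho_t}{\partial x} + \frac{\sigma^2}{2}\,\frac{\partial^2 \rho_t}{\partial x^2}, \]

and if every particle starts at \(x = 0\), the solution is a Gaussian with mean \(bt\) and variance \(\sigma^2 t\): the drift moves the center and diffusion widens it. The sketch below simulates many sample paths with the Euler-Maruyama scheme (a standard discretization swapped in here, not something this note prescribes) and compares the empirical mean and variance with that prediction. The helper `euler_maruyama` and the values of `b`, `sigma`, `t_end`, and the step count are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def euler_maruyama(key, x0, b, sigma, t_end, n_steps):
    """Simulate dX = b dt + sigma dW for many 1D particles at once.

    Hypothetical helper; assumes constant scalar drift b and diffusion sigma.
    """
    dt = t_end / n_steps

    def step(x, k):
        # Drift moves every particle by b * dt; noise adds an independent
        # Gaussian increment with variance dt to each particle.
        dw = jnp.sqrt(dt) * jax.random.normal(k, x.shape)
        return x + b * dt + sigma * dw, None

    keys = jax.random.split(key, n_steps)
    x_final, _ = jax.lax.scan(step, x0, keys)
    return x_final


# Illustrative parameters
b, sigma, t_end = 1.5, 0.8, 2.0
x0 = jnp.zeros(100_000)  # every particle starts at x = 0
x_T = euler_maruyama(jax.random.PRNGKey(0), x0, b, sigma, t_end, n_steps=200)

# FPK prediction at t_end: mean b * t_end = 3.0, variance sigma^2 * t_end = 1.28
print(float(x_T.mean()), float(x_T.var()))
```

With constant coefficients each Euler-Maruyama step adds an exact Gaussian increment, so any mismatch here comes only from the finite number of particles.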

In code, one compact way to evaluate the right-hand side at a point \((x, t)\) is to pass \(\rho(x,t)\), \(b(x,t)\), and \(a(x,t)\) as callable functions and let automatic differentiation compute the derivatives.

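```python
import jax
import jax.numpy as jnp


def fpk_rhs(rho, drift, diffusion, x, t):
    """Right-hand side of the Fokker-Planck-Kolmogorov equation at (x, t).

    rho(y, t) is a scalar density, drift(y, t) returns b with shape (d,),
    and diffusion(y, t) returns a = sigma sigma^T with shape (d, d).
    """
    def flux(y):
        return drift(y, t) * rho(y, t)  # b * rho, shape (d,)

    def diffusion_term(y):
        return diffusion(y, t) * rho(y, t)  # a * rho, shape (d, d)

    # -div(b rho) is minus the trace of the Jacobian of the flux
    transport = -jnp.trace(jax.jacfwd(flux)(x))
    # hessian[i, j, k, l] = d^2 (a_ij rho) / dx_k dx_l
    hessian = jax.hessian(diffusion_term)(x)
    # Keep the k = i, l = j terms: 1/2 sum_ij d^2 (a_ij rho) / dx_i dx_j
    spreading = 0.5 * jnp.einsum("ijij->", hessian)

    return transport + spreading


# Example inputs (an illustrative choice): the Ornstein-Uhlenbeck process
# dX = -X dt + sqrt(2) dW has stationary density N(0, I), so the RHS is ~0.
def rho(x, t):
    return jnp.exp(-0.5 * jnp.dot(x, x)) / (2 * jnp.pi) ** (x.shape[0] / 2)


def drift(x, t):
    return -x


def diffusion(x, t):
    return 2.0 * jnp.eye(x.shape[0])  # sigma = sqrt(2) I, so a = 2 I


x = jnp.array([0.3, -1.2])
t = 0.0
rhs_at_x = fpk_rhs(rho, drift, diffusion, x, t)

# Autodiff composes: the RHS can itself be differentiated with respect to x
grad_fpk_rhs = jax.grad(lambda x, t: fpk_rhs(rho, drift, diffusion, x, t))
gradient_at_x = grad_fpk_rhs(x, t)
```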
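As a sanity check on the autodiff approach, one can differentiate a known solution in time and compare it with the right-hand side. For \(d = 1\), constant drift \(b\), constant \(\sigma\), and all mass starting at \(x = 0\), the density is the Gaussian \(\mathcal{N}(bt, \sigma^2 t)\). The sketch below assumes illustrative values for \(b\) and \(\sigma\), and it writes the one-dimensional right-hand side \(-b\,\partial_x \rho + \tfrac{\sigma^2}{2}\,\partial_x^2 \rho\) directly with `jax.grad` instead of calling `fpk_rhs`, so that it runs on its own; `rhs_1d` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

# Illustrative constants: dX = b dt + sigma dW with X_0 = 0
b, sigma = 1.5, 0.8


def rho(x, t):
    """Gaussian density N(b t, sigma^2 t), the exact 1D solution."""
    var = sigma**2 * t
    return jnp.exp(-0.5 * (x - b * t) ** 2 / var) / jnp.sqrt(2 * jnp.pi * var)


# Derivatives by automatic differentiation
drho_dt = jax.grad(rho, argnums=1)
drho_dx = jax.grad(rho, argnums=0)
d2rho_dx2 = jax.grad(drho_dx, argnums=0)


def rhs_1d(x, t):
    # -d(b rho)/dx + 1/2 d^2(sigma^2 rho)/dx^2 with constant b and sigma
    return -b * drho_dx(x, t) + 0.5 * sigma**2 * d2rho_dx2(x, t)


xs = jnp.linspace(-1.0, 6.0, 8)
t = 1.3
lhs_vals = jax.vmap(drho_dt, in_axes=(0, None))(xs, t)
rhs_vals = jax.vmap(rhs_1d, in_axes=(0, None))(xs, t)
max_err = float(jnp.max(jnp.abs(lhs_vals - rhs_vals)))
print(max_err)  # expected to be at float32 round-off level
```

The same pattern, with `fpk_rhs` in place of `rhs_1d`, gives a quick test for multidimensional drift and diffusion functions whenever a closed-form density is available.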