Reinforce Adjoint Matching (RAM) Explained
Recently, a new diffusion RL method, REINFORCE Adjoint Matching (RAM), immediately caught my attention with its tagline "0 reward gradients, 0 SDE rollouts, 0 reward hacking, 50x fewer training steps". Well, the claims are true, and the final code implementation is dead simple. But understanding how the authors got there is not straightforward, especially for readers who know little about stochastic optimal control (SOC). So I built this notebook to explain RAM in an intuitive, self-contained way. By the end you will:
- understand the post-training problem RAM solves,
- see how RAM took shape via adjoint matching and the REINFORCE trick,
- understand the full RAM loss,
- train a tiny 2D model with flow-matching and RAM on CPU,
- and extend RAM to handle multi-reward post-training.
Edit History:
- 05-30-2026: initial version with Claude's help
0. Setup
We only need torch, numpy, matplotlib, and tqdm. Install them with your package manager of choice if any are missing.
```python
from __future__ import annotations

import math
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

torch.manual_seed(0)
np.random.seed(0)
DEVICE = torch.device("cpu")
torch.set_default_dtype(torch.float32)
print("torch:", torch.__version__)
```
torch: 2.11.0
A small bundle of plotting helpers, defined once so the rest of the notebook can stay focused on the math.
```python
def plot_samples(samples, ax=None, *, title=None, color="C0", s=6, alpha=0.4,
                 xlim=(-5, 5), ylim=(-5, 5), c=None, cmap=None, label=None):
    """Scatter-plot a tensor of 2D points.

    Pass ``c`` (per-point values) and ``cmap`` to color points by a category,
    otherwise the single ``color`` is used.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(4, 4))
    pts = samples.detach().cpu().numpy()
    if c is not None:
        ax.scatter(pts[:, 0], pts[:, 1], s=s, alpha=alpha, c=c, cmap=cmap, label=label)
    else:
        ax.scatter(pts[:, 0], pts[:, 1], s=s, alpha=alpha, color=color, label=label)
    ax.set_xlim(xlim); ax.set_ylim(ylim); ax.set_aspect("equal")
    if title:
        ax.set_title(title)
    return ax


def plot_density(density_fn, ax=None, *, title=None, n_grid=80,
                 xlim=(-5, 5), ylim=(-5, 5), cmap="viridis"):
    """Heatmap of `density_fn`, a callable mapping [N, 2] points to [N] scalars."""
    if ax is None:
        _, ax = plt.subplots(figsize=(4, 4))
    gx = torch.linspace(xlim[0], xlim[1], n_grid)
    gy = torch.linspace(ylim[0], ylim[1], n_grid)
    X, Y = torch.meshgrid(gx, gy, indexing="ij")
    pts = torch.stack([X.flatten(), Y.flatten()], dim=-1)
    Z = density_fn(pts).detach().cpu().reshape(n_grid, n_grid).numpy()
    ax.imshow(Z.T, origin="lower", extent=(*xlim, *ylim), cmap=cmap)
    ax.set_aspect("equal")
    if title:
        ax.set_title(title)
    return ax


def kde_density(grid_pts, samples, h=0.15):
    """Gaussian KDE (unnormalized values; we re-normalize when comparing)."""
    d2 = ((grid_pts[:, None, :] - samples[None, :, :]) ** 2).sum(-1)
    p = torch.exp(-0.5 * d2 / h**2).mean(-1)
    return p
```
1. Two stages of generative learning
Training a modern continuous-data generative model — for images, video, proteins, or robot actions — typically happens in two stages.
Stage 1 — Pretraining. Show the model a giant pile of "good" examples and teach it to imitate their distribution. The objective is supervised regression: take a clean datapoint, corrupt it analytically, then ask the model to predict either the corruption or the clean datapoint. This is what makes diffusion and flow-matching scale: the loss is plain MSE, the per-sample target is in closed form, and no sampling-during-training is required. The result is a generator that, given noisy input (a fresh sample from a simple distribution like a standard Gaussian), produces plausible outputs — samples that follow the training distribution and look like they could have come from it.
Stage 2 — Post-training. Pretraining alone cannot tell the model which of its many plausible outputs are desirable. "Desirable" usually comes as a scalar reward we can compute on a finished sample — an aesthetic score, a layout check, an OCR text match, a docking energy, a human-preference label. Post-training nudges the model toward higher reward, without losing the manifold of plausible outputs that pretraining gave us. This is reinforcement learning, applied to a generative model.
Both stages are essential. Pretraining alone gives you a fluent generator with no preferences; reward optimization alone (from scratch) has nothing plausible to start from. RAM is a method for Stage 2 — specifically, for flow-matching models on continuous data. The next section recaps Stage 1 with a small concrete model, then the rest of the notebook is all about Stage 2.
2. Stage 1: flow-matching pretraining on a 2D toy
The problem. We are given easy samples from a source distribution — the 2D standard Gaussian $\mathcal{N}(0, I)$, a single blob centered at the origin. We want a generator that turns those easy samples into samples from a target distribution we care about — an 8-mode Gaussian ring: eight tight clusters arranged evenly around a circle of radius 4.
The source is easy because we can draw from it with one line of code. The target has structure: eight separate modes, none of them at the origin, so a generator that simply averages its outputs would land right in the middle of the ring, missing every mode. The whole point of Stage 1 is to learn a transport from the easy source to the structured target.
```python
def sample_ring(n: int, modes: int = 8, radius: float = 4.0, sigma: float = 0.15,
                return_labels: bool = False):
    """Sample n points from an 8-mode Gaussian ring (the target distribution).

    Returns a [n, 2] tensor, or (samples, mode_idx) if ``return_labels=True``.
    """
    mode_idx = torch.randint(0, modes, (n,))
    angles = (2 * math.pi / modes) * mode_idx
    centers = torch.stack(
        [radius * torch.cos(angles), radius * torch.sin(angles)], dim=-1
    )
    samples = centers + sigma * torch.randn(n, 2)
    if return_labels:
        return samples, mode_idx
    return samples


def sample_source(n: int) -> torch.Tensor:
    """Sample n points from the source distribution, a 2D standard Gaussian."""
    return torch.randn(n, 2)
```
```python
# Visualize source and target side by side so we know what we're transporting.
# We color the target by mode index so the 8 clusters are visually distinct.
ring_samples, ring_labels = sample_ring(3000, return_labels=True)
fig, axes = plt.subplots(1, 2, figsize=(9, 4.5))
plot_samples(sample_source(3000), axes[0],
             title="source: standard Gaussian N(0, I)", color="0.25")
plot_samples(ring_samples, axes[1],
             title="target: 8-mode Gaussian ring",
             c=ring_labels.numpy(), cmap="tab10")
for ax in axes:
    ax.set_xlim(-5, 5); ax.set_ylim(-5, 5)
plt.tight_layout(); plt.show()
```
Flow matching, in three steps
§1 said pretraining is "corrupt a clean datapoint analytically, then ask the model to predict either the corruption or the clean datapoint." Flow matching is one concrete instance of that recipe, particularly well-suited to continuous data. The clever twist is that the "noise" we corrupt with is exactly the source distribution we want to start from at generation time — so undoing the corruption is the transport we are after.
Step 1 — Corrupt analytically. Take a clean target sample $X_0$ (from the ring) and a fresh source sample $\varepsilon \sim \mathcal{N}(0, I)$, then linearly interpolate:
$$ X_t \;=\; (1-t)\, X_0 \;+\; t\,\varepsilon, \quad\text{with}\quad X_0 \sim p_{\text{data}},\; \varepsilon \sim \mathcal{N}(0, I),\; t \sim \mathcal{U}[0, 1]. $$
At $t = 0$ we have the target ($X_0$); at $t = 1$ we have the source ($\varepsilon$). The parameter $t$ is the fraction of source mixed in. Crucially, there is no neural network in this step — the corruption is closed-form, so per-sample targets in the next step will be cheap to compute.
Step 2 — Predict the velocity. Given a noisy $X_t$, in which direction was it pushed? The cleanest answer is the straight-line "velocity" $(\varepsilon - X_0)$ that took us from target to source. The catch is that any specific location $x$ at time $t$ can be reached by many different $(X_0, \varepsilon)$ pairs (a target point pulled hard by one noise vector, or a different target pulled gently by another). So we ask the network $v^\theta(x, t)$ to predict the average velocity across all those pairs — formally, the conditional expectation over $(X_0, \varepsilon)$ given that they happen to land at $X_t = x$:
$$ v(x, t) \;=\; \mathbb{E}_{X_0 \sim p_{\text{data}},\; \varepsilon \sim \mathcal{N}(0, I)} \!\left[\,\varepsilon - X_0 \,\bigm|\, X_t = x\,\right]. $$
Why velocity? Predicting $X_0$ directly is mathematically equivalent (the two targets differ by a known affine transform of $(X_t, t)$ and induce the same sampling ODE), and recent work argues $X_0$-prediction can be preferable for high-dimensional raw data — see Li & He, arXiv:2511.13720. We use the velocity parameterization here simply because it is the standard for flow matching / rectified flow (Lipman et al., arXiv:2210.02747; Liu et al., arXiv:2209.03003) and the one the RAM paper is formulated around.
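The affine map between the two parameterizations is easy to check numerically. Because $X_t = X_0 + t(\varepsilon - X_0)$, a velocity $v$ converts to an $X_0$ prediction through $X_0 = X_t - t\,v$, and back through $v = (X_t - X_0)/t$ for $t > 0$. The map is affine, so it also carries conditional expectations across. The helper names below are hypothetical and exist only for this check.

```python
import torch


def x0_from_velocity(xt: torch.Tensor, v: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """X_0 = X_t - t * v, since X_t = X_0 + t * (eps - X_0)."""
    return xt - t[:, None] * v


def velocity_from_x0(xt: torch.Tensor, x0: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """v = (X_t - X_0) / t; only defined for t > 0."""
    return (xt - x0) / t[:, None]


torch.manual_seed(0)
x0 = 4 * torch.randn(16, 2)
eps = torch.randn(16, 2)
t = 0.05 + 0.9 * torch.rand(16)                # keep t away from 0 for the division
xt = (1 - t[:, None]) * x0 + t[:, None] * eps
v = eps - x0                                   # the flow-matching target

print(torch.allclose(x0_from_velocity(xt, v, t), x0, atol=1e-4),
      torch.allclose(velocity_from_x0(xt, x0, t), v, atol=1e-3))
```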
Step 3 — Minimize the loss. Plain MSE between the model's prediction and the per-sample target:
$$ \mathcal{L}_{\text{FM}}(\theta) \;=\; \mathbb{E}_{X_0 \sim p_{\text{data}},\; \varepsilon \sim \mathcal{N}(0, I),\; t \sim \mathcal{U}[0, 1]} \Bigl[\,\bigl\|\,v^\theta(X_t, t) - (\varepsilon - X_0)\,\bigr\|^2\,\Bigr]. $$
Worth pausing to appreciate this: flow matching reduces generative modeling to three modest ingredients — easy sampling from the source, a closed-form per-sample target, and plain MSE regression. Together they are exactly why flow matching scales cleanly and trains stably at modern model sizes.
Velocity model
A small MLP with a sinusoidal time embedding. ~35k parameters — plenty for a 2D toy and fast on CPU.
```python
class VelocityNet(nn.Module):
    """v_theta(x, t) for 2D flow matching.

    Inputs:
        x: [B, 2] spatial location.
        t: [B] time in [0, 1].
    Output:
        [B, 2] predicted velocity.
    """

    def __init__(self, hidden: int = 128, n_freqs: int = 8):
        super().__init__()
        # Fixed sinusoidal time features: sin/cos of t * 2^k * pi for k=0..n_freqs-1.
        self.register_buffer("freqs", 2 ** torch.arange(n_freqs).float() * math.pi)
        in_dim = 2 + 2 * n_freqs
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, 2),
        )

    def time_emb(self, t: torch.Tensor) -> torch.Tensor:
        ang = t[:, None] * self.freqs[None, :]                       # [B, n_freqs]
        return torch.cat([torch.sin(ang), torch.cos(ang)], dim=-1)  # [B, 2*n_freqs]

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([x, self.time_emb(t)], dim=-1))


def num_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())


_demo_net = VelocityNet()
print(f"VelocityNet has {num_params(_demo_net):,} parameters.")
```
VelocityNet has 35,714 parameters.
The training loop
Sample a batch of data, sample noise and a random $t$, build $X_t$, regress to $(\varepsilon - X_0)$. One full Stage-1 pretraining recipe in under twenty lines.
```python
def pretrain_flow_matching(
    model: VelocityNet,
    *,
    steps: int = 5000,
    batch_size: int = 256,
    lr: float = 1e-3,
) -> list[float]:
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    losses: list[float] = []
    pbar = tqdm(range(steps), desc="pretrain", mininterval=2.0)
    for s in pbar:
        x0 = sample_ring(batch_size)
        eps = torch.randn_like(x0)
        t = torch.rand(batch_size)
        xt = (1 - t[:, None]) * x0 + t[:, None] * eps
        target = eps - x0
        loss = ((model(xt, t) - target) ** 2).mean()
        opt.zero_grad(); loss.backward(); opt.step()
        losses.append(loss.item())
        if s % 500 == 0:
            pbar.set_postfix(loss=f"{loss.item():.3f}")
    return losses


v_ref = VelocityNet()
pretrain_losses = pretrain_flow_matching(v_ref, steps=5000)
v_ref.eval()
for p in v_ref.parameters():
    p.requires_grad_(False)  # freeze it -- this is our reference
```
pretrain: 100%|██████████| 5000/5000 [00:04<00:00, 1180.43it/s, loss=3.420]
Sampling from the pretrained model
We have a trained velocity field $v^\theta(x, t)$. It defines an ODE:
$$ \frac{\mathrm{d} X_t}{\mathrm{d} t} \;=\; v^\theta(X_t, t). $$
Pretraining oriented $v$ so that it points from target toward source as $t$ increases. To generate, then, we draw an easy source sample $\varepsilon \sim \mathcal{N}(0, I)$, place it at $t = 1$, and run the ODE backward to $t = 0$. Each Euler step subtracts the velocity:
```python
@torch.no_grad()
def euler_sample(model: VelocityNet, n: int, n_steps: int = 50) -> torch.Tensor:
    """Generate n samples by integrating v_theta backward from t=1 to t=0.

    Starts from n source samples ~ N(0, I) and pushes them toward the target.
    """
    x = sample_source(n)
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = 1.0 - k * dt
        x = x - model(x, torch.full((n,), t)) * dt
    return x
```
```python
# Visualize what the pretrained model has learned: target distribution, and source -> learned target.
ref_samples = euler_sample(v_ref, 3000)
src_samples = sample_source(3000)
gt_samples, gt_labels = sample_ring(3000, return_labels=True)
fig, axes = plt.subplots(1, 2, figsize=(10, 4.5))
plot_samples(gt_samples, axes[0], title="p_data (ground truth)",
             c=gt_labels.numpy(), cmap="tab10")
plot_samples(src_samples, axes[1], color="navy", label="source (t=1)")
plot_samples(ref_samples, axes[1], color="teal", label="target (t=0)",
             title="source N(0, I) -> target")
axes[1].legend(loc="upper right", fontsize=8, frameon=False, markerscale=2)
plt.tight_layout(); plt.show()
```
The pretrained model successfully transports source samples (a single blob) into the target distribution (eight separate modes). From here on, $v_{\text{ref}}$ is our frozen Stage-1 model; everything that follows is Stage 2.
3. The RL post-training problem
Look back at the right panel of the previous section: the pretrained model isn't perfect. Some samples drift between modes, and we have no say in which mode each sample lands on. The job of post-training is to tame the pretrained model into producing the outputs we prefer — and the lever for that is a scalar reward plus the machinery of reinforcement learning.
A reward is a scorecard. For any sample $x$, the reward $r(x)$ returns a real number saying how much we like that sample — higher is better. The grader can be anything we can compute or query: a classifier logit, an OCR match, a docking energy, a human-preference label. It is the only place we tell the model what we want; everything else (architecture, loss, sampler) is just plumbing for following it. Crucially, $r$ doesn't have to be differentiable, and we never need a closed-form expression — a black-box function call is enough.
Formally, we want to bias the pretrained model toward some scalar reward $r(x)$. The canonical objective for RL post-training is KL-regularized reward maximization:
$$ \max_p \;\; \mathbb{E}_{x \sim p}\bigl[r(x)\bigr] \;-\; D_{\mathrm{KL}}\bigl(p \,\big\|\, p_{\text{ref}}\bigr). $$
Here $D_{\mathrm{KL}}(p \,\|\, p_{\text{ref}})$ is the Kullback–Leibler divergence — a standard non-negative measure of how much one distribution differs from another, equal to zero iff $p = p_{\text{ref}}$ and growing as the two pull apart.
Two pieces in tension:
- $\mathbb{E}_{x \sim p}[r(x)]$ — the expected reward under our new model $p$ (which RL calls a policy; we'll use model and policy interchangeably throughout Stage 2); we want this high.
- $D_{\mathrm{KL}}\bigl(p \,\big\|\, p_{\text{ref}}\bigr)$ — how far $p$ has drifted from the pretrained model $p_{\text{ref}}$; we want this small.
So we are tilting toward reward, but not so far that we forget what pretraining taught the model.
The argmax has a clean closed form:
$$ p_{\text{target}}(x) \;\propto\; p_{\text{ref}}(x)\,\exp\!\bigl(r(x)\bigr). $$
So the optimal post-trained model is the pretrained density tilted multiplicatively by $\exp(r)$. The KL anchor keeps us close to the pretrained model in places where $r$ is uninformative; the reward shifts mass to where $r$ is high.
📐 Why is the optimum $p_{\text{ref}}\exp(r)$? Let $Z = \int p_{\text{ref}}(x)\exp(r(x))\,\mathrm{d}x$ and define $q(x) := p_{\text{ref}}(x)\exp(r(x)) / Z$. Then for any density $p$,
$$ \mathbb{E}_p[r] \;-\; D_{\mathrm{KL}}\!\bigl(p\,\big\|\,p_{\text{ref}}\bigr) \;=\; -\,D_{\mathrm{KL}}\!\bigl(p\,\big\|\,q\bigr) \;+\; \log Z. $$
$\log Z$ doesn't depend on $p$, and $D_{\mathrm{KL}}(p\|q) \geq 0$ with equality iff $p = q$, so the unique maximizer is $p^\star = q \propto p_{\text{ref}}\exp(r)$. This Boltzmann/Gibbs identity is standard in maximum-entropy RL and RLHF; see Levine, 2018 for a tutorial.
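Here is a quick numerical sanity check of the identity on a discrete toy. The six-atom reference distribution and the reward below are made up for illustration. For any $p$, the objective equals $-D_{\mathrm{KL}}(p\|q) + \log Z$. The tilted $q$ attains exactly $\log Z$, and random competitors never exceed it.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
p_ref = rng.dirichlet(np.ones(n))   # hypothetical reference distribution over 6 atoms
r = rng.normal(size=n)              # hypothetical reward per atom


def objective(p):
    """E_p[r] - KL(p || p_ref) for a strictly positive discrete distribution p."""
    return float(np.sum(p * r) - np.sum(p * np.log(p / p_ref)))


Z = float(np.sum(p_ref * np.exp(r)))
q = p_ref * np.exp(r) / Z           # the claimed maximizer p_ref * exp(r) / Z

# Identity: objective(p) = -KL(p || q) + log Z, for an arbitrary p.
p = rng.dirichlet(np.ones(n))
kl_pq = float(np.sum(p * np.log(p / q)))
print(objective(p), -kl_pq + np.log(Z))

# The tilted q attains log Z, the upper bound implied by KL >= 0.
print(objective(q), np.log(Z))
```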
A reward for our 2D ring. Let's pick a concrete reward $r$: a smooth Gaussian bump of width 2.5 centered at $(4, 0)$. On our ring of radius 4, it peaks on the right-most mode (reward 1) and still gives appreciable signal to that mode's two neighbors (about 0.47). The remaining modes get small rewards: about 0.08 for the top and bottom modes, and about 0.01 or less for the three on the left. The signal is therefore concentrated, but nowhere exactly zero. We also fix a tilting strength $\beta = 5$. This is the single dial controlling how hard the optimum $p_{\text{target}} \propto p_{\text{ref}}\,\exp(\beta r)$ pulls toward high-reward regions; equivalently, it is the objective above with reward $\beta r$. (We'll sweep $\beta$ in §7.2.)
```python
def reward(x: torch.Tensor, center=(4.0, 0.0), scale: float = 2.5) -> torch.Tensor:
    """Smooth, bounded-in-[0,1] reward favoring points near `center`."""
    c = torch.as_tensor(center, dtype=x.dtype, device=x.device)
    return torch.exp(-0.5 * ((x - c) ** 2).sum(-1) / scale**2)


BETA = 5.0  # reward scaling we will use throughout (a single dial)
```
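To get a feel for how strongly $\beta = 5$ tilts things, here is a back-of-the-envelope sketch. It assumes an idealized $p_{\text{ref}}$: the exact ring, with equal weight $1/8$ per mode. It also assumes the modes are narrow enough ($\sigma = 0.15$ against a reward width of 2.5) that $r$ is roughly constant within each one. Under those assumptions, $p_{\text{target}}$ gives mode $k$ a weight proportional to $\exp(\beta\, r(c_k))$, where $c_k$ is the mode center. The `reward` function is repeated so the block runs on its own.

```python
import math
import torch


def reward(x: torch.Tensor, center=(4.0, 0.0), scale: float = 2.5) -> torch.Tensor:
    """Same reward as the notebook's, repeated so this block runs on its own."""
    c = torch.as_tensor(center, dtype=x.dtype, device=x.device)
    return torch.exp(-0.5 * ((x - c) ** 2).sum(-1) / scale**2)


BETA = 5.0
modes, radius = 8, 4.0
angles = torch.arange(modes) * (2 * math.pi / modes)
centers = torch.stack([radius * torch.cos(angles), radius * torch.sin(angles)], dim=-1)

r_centers = reward(centers)                       # reward at each mode center
# Idealized p_ref: weight 1/8 per mode. Tilting by exp(BETA * r) and renormalizing
# is a softmax over BETA * r (the common 1/8 cancels).
weights = torch.softmax(BETA * r_centers, dim=0)

for k in range(modes):
    deg = math.degrees(angles[k].item())
    print(f"mode {k} ({deg:5.1f} deg): r = {r_centers[k].item():.3f}, "
          f"target weight = {weights[k].item():.3f}")
```

Under this idealization, the right-most mode takes the large majority of the mass, its two neighbors keep a few percent each, and the remaining modes are nearly emptied. The KDE-based plot below also reflects the pretrained model's actual imperfections.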
With the reward $r$ in hand, two distributions are now in play — worth keeping straight in your head:
| symbol | meaning | how we touch it in code |
|---|---|---|
| $p_{\text{ref}}$ | What the pretrained model actually produces (close to the ring, but imperfect). This is what RAM regularizes against. | euler_sample(v_ref, n) — draw samples |
| $p_{\text{target}}$ | $\propto p_{\text{ref}} \cdot \exp(\beta\, r)$. The KL-optimal post-trained density we're aiming for. | evaluate the density on a 2D grid for plotting; no direct sampler — building one is what RAM is for |
Let's see $p_{\text{ref}}$, $r$, and $p_{\text{target}}$ side by side. We estimate $p_{\text{ref}}$ from a big pile of pretrained samples using kernel density estimation (KDE) — drop a small Gaussian bump on each sample and sum them up — then multiply by $\exp(\beta r)$ and normalize to get $p_{\text{target}}$.
```python
# A big sample bag from the reference model to build a smooth KDE of p_ref.
ref_big = euler_sample(v_ref, 5000)


def p_ref_density(grid_pts):
    p = kde_density(grid_pts, ref_big, h=0.15)
    return p / (p.sum() + 1e-12)


def p_target_density(grid_pts, beta=BETA):
    p = kde_density(grid_pts, ref_big, h=0.15) * torch.exp(beta * reward(grid_pts))
    return p / (p.sum() + 1e-12)


fig, axes = plt.subplots(1, 3, figsize=(14, 4.5))
plot_density(p_ref_density, axes[0], title="p_ref (pretrained)")
plot_density(reward, axes[1], title=f"reward r(x) (we use beta={BETA})")
plot_density(p_target_density, axes[2],
             title=fr"p_target $\propto$ p_ref $\cdot$ exp({BETA} r)")
plt.tight_layout(); plt.show()
```
The middle panel is the reward, smooth and unimodal. The right panel is the KL-optimal post-trained density: it keeps the multi-modal structure of $p_{\text{ref}}$ but heavily favors the modes near the reward peak.
Our problem, restated. Train a velocity field $v^\theta$ whose generated samples look like the right panel — using only the pretrained model $v_{\text{ref}}$, an ODE sampler, and a scalar reward query.
4. Building RAM in three pieces
⚠️ Math-heavy ahead — but lighter than the paper. This section walks through the math behind RAM. We motivate each step and keep things informal; the original RAM paper has the rigorous measure-theoretic version. If you only want to use the loss, skim §4 and jump to the boxed loss (RAM) in §4.3.
We have a pretrained velocity field $v_{\text{ref}}$ that generates from $p_{\text{ref}}$. We want a new velocity field whose samples lean toward high reward while staying close to $p_{\text{ref}}$. The cleanest way to think about this is steering: keep $v_{\text{ref}}$ in place and add a correction.
$$ v^\theta(x, t) \;=\; v_{\text{ref}}(x, t) \;+\; \underbrace{(\text{correction})}_{\text{this is what we learn}}. $$
Three questions decompose the problem:
- What's the math of the optimal correction? If we knew the right formula, what would the correction look like at every $(x, t)$? This is answered by adjoint matching from stochastic optimal control (SOC) in §4.1.
- How do we estimate it efficiently? Answered in §4.2 by combining a REINFORCE-style estimator with one structural shortcut from the paper, then computing the Bayes bridge score in closed form.
- How do we turn this into a trainable loss? Answered in §4.3 by plugging the closed-form pieces together.
Adjoint matching + a REINFORCE-style estimator give the algorithm its name: REINFORCE Adjoint Matching, abbreviated RAM.
4.1 Adjoint matching: the math of the optimal correction
Think of generation as a tiny agent that starts from noise at $t=1$ and walks toward data at $t=0$, taking small steps prescribed by $v$. At any intermediate state $(x, t)$ — meaning "currently at position $x$, at time $t$" — we can ask:
If we let generation run to completion from this state onward, what reward should we expect?
That number is the value function:
$$ V(x, t) \;\;=\;\; \mathbb{E}\bigl[\,r(X_0) \,\big|\, X_t = x\,\bigr]. $$
It is high in regions of $(x, t)$-space from which rolling forward tends to land in high-reward places, and low elsewhere.
The spatial gradient of $V$ tells us which direction in $x$-space leads to higher expected reward:
$$ A(x, t) \;\;:=\;\; \nabla_x V(x, t). $$
This vector field is called the adjoint in stochastic optimal control.
📖 Why "adjoint"? In stochastic optimal control there is a forward equation for the state $X_t$ and a paired backward equation for its sensitivity to the future cost — i.e., how a tiny nudge to $X_t$ propagates into a change in the expected terminal reward $r(X_0)$. That backward-propagated sensitivity is the adjoint, and it equals $\nabla_x V(x, t)$.
The point of computing it this way is efficiency: a single backward sweep gives the gradient of a scalar w.r.t. a high-dimensional input $x$ in time independent of $\dim(x)$, far cheaper than perturbing each of $x$'s $d$ coordinates and running generation forward each time. The same idea shows up in autograd: backprop is the discrete adjoint of the forward pass, and the adjoint ODE (Pontryagin's principle, Neural ODEs) is its continuous-time cousin.
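To make "backprop is the discrete adjoint" concrete, here is a toy sketch. It uses a hypothetical 1D linear velocity field $v(x) = a\,x$ (time is ignored; this is not the notebook's model), rolled out with the same backward Euler update as `euler_sample`, plus a differentiable toy reward. One `.backward()` call sweeps the sensitivity from the endpoint back to the starting point, and the result matches the closed-form product of per-step Jacobians. Note that the sweep is seeded with $r'(x_{\text{end}})$. This is exactly the ingredient that goes missing when the reward is not differentiable.

```python
import torch

A_COEF, N_STEPS = 0.7, 50      # hypothetical toy field v(x, t) = A_COEF * x


def euler_rollout(x: torch.Tensor) -> torch.Tensor:
    """Backward-in-time Euler steps x <- x - v(x) * dt, as in euler_sample."""
    dt = 1.0 / N_STEPS
    for _ in range(N_STEPS):
        x = x - A_COEF * x * dt
    return x


def toy_reward(x0: torch.Tensor) -> torch.Tensor:
    return -(x0 - 1.0) ** 2     # differentiable on purpose: the sweep needs r'(x0)


x_start = torch.tensor([0.3], requires_grad=True)
x_end = euler_rollout(x_start)
toy_reward(x_end).sum().backward()   # one backward sweep from the endpoint to the start

# Each Euler step multiplies by (1 - A_COEF*dt), so dx_end/dx_start = (1 - A_COEF/N_STEPS)^N_STEPS.
analytic = -2 * (x_end.detach() - 1.0) * (1 - A_COEF / N_STEPS) ** N_STEPS
print(x_start.grad, analytic)
```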
The adjoint-matching theorem (Theorem 3.2 of the paper) is the following beautiful self-consistency statement:
A velocity field $v^\theta$ is optimal iff its correction $(v^\theta - v_{\text{ref}})$ is proportional to its own induced adjoint.
"Its own induced adjoint" is the subtle bit. The value function $V$ — and therefore the adjoint $A = \nabla_x V$ — depends on which policy rolls generation forward to compute the expected reward in $V(x, t) = \mathbb{E}[r(X_0)\mid X_t = x]$. Different $v^\theta$'s produce different endpoint distributions for $X_0$, hence different $V$'s and different $A$'s. Writing $A^{v^\theta}$ to make this dependence explicit, the condition
$$ v^\theta - v_{\text{ref}} \;\propto\; A^{v^\theta} \tag{AM} $$
is self-referential: the correction we want depends on the adjoint of the very policy we are trying to find. The natural reading is as a fixed-point iteration:
$$ \underbrace{v^\theta}_{\text{current policy}} \;\xrightarrow{\text{roll out, take }\nabla_x V}\; \underbrace{A^{v^\theta}}_{\text{its adjoint}} \;\xrightarrow{\text{install as new correction}}\; \underbrace{v_{\text{ref}} + A^{v^\theta}}_{\text{updated policy}}. $$
The theorem basically says an optimal solution $v^\theta$ is a fixed point that leaves this iteration unchanged. Done — in principle.
The snag: $\nabla V$ needs $\nabla r$. $V(x, t) = \mathbb{E}[r(X_0) \mid X_t = x]$ is the expected reward at the endpoint $X_0$ reached by rolling the policy forward from $(x, t)$. Differentiating it in $x$ propagates a sensitivity backward through the entire rollout to the endpoint, where we need $\nabla_{X_0} r$. The classical adjoint method (Pontryagin's principle / Neural ODEs) literally integrates a backward ODE seeded with the terminal condition $\nabla r(X_0)$ — and if $r$ is non-differentiable, the recipe has nothing to seed the backward sweep with.
📚 Prior work: Adjoint Matching. The original Adjoint Matching paper (Domingo-Enrich et al., arXiv:2409.08861, 2024) took exactly this route. They prove the adjoint-matching theorem we just stated, compute $A$ via the backward adjoint ODE seeded with $\nabla r(X_0)$, and regress the correction $(v^\theta - v_{\text{ref}})$ onto it. It works beautifully when $r$ is a neural network you can backprop through (CLIP scores, aesthetic predictors), but two issues motivate the next step:
- Many useful rewards are not differentiable — OCR edit distance, object-detection counts, yes/no preference labels, win-rate verifiers.
- The backward sweep is expensive at image scale — an extra integration over the entire generation trajectory at every training step, with variance that grows with rollout length.
RAM keeps the adjoint-matching condition (AM) but swaps the backward sweep for a REINFORCE-style estimator of $A$ that needs only scalar reward queries. That is what §4.2 derives.
4.2 REINFORCE: gradients of expectations without gradients of the integrand
Time for a brief detour into the REINFORCE identity (R), the workhorse of policy-gradient RL. It says: to differentiate the expected value of a function under a parametrized distribution, you don't need to differentiate the function — only the log-density of the distribution.
If $X \sim p_\theta$, then
$$ \nabla_\theta \, \mathbb{E}_{X \sim p_\theta}\bigl[\, f(X)\, \bigr] \;=\; \mathbb{E}_{X \sim p_\theta}\!\bigl[\, f(X) \cdot \nabla_\theta \log p_\theta(X)\,\bigr]. \tag{R} $$
The gradient passes from $f$ onto $\log p$ (which we do control by choosing $\theta$). $f$ enters the right-hand side only as a scalar weight — never differentiated. This is why classical RL works on discrete rewards, black-box simulators, and the like.
🧮 A one-line 1D check. Take $p_\theta = \mathcal{N}(\theta, 1)$ and $f(x) = x^2$. Then $\mathbb{E}[f(X)] = 1 + \theta^2$, whose derivative is $2\theta$. The REINFORCE side: $\nabla_\theta \log p_\theta(x) = x - \theta$, so the RHS is $\mathbb{E}[X^2 (X - \theta)] = \theta^3 + 3\theta - \theta(1 + \theta^2) = 2\theta$. ✓ Same answer, never used $\nabla f$.
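The same check works by Monte Carlo. Sample $X \sim \mathcal{N}(\theta, 1)$, weight each sample's score $x - \theta$ by $f(x) = x^2$, and average. The value of $\theta$ and the sample count are arbitrary choices for this sketch.

```python
import torch

torch.manual_seed(0)
theta = 0.8                                       # arbitrary parameter value
n = 2_000_000
x = theta + torch.randn(n, dtype=torch.float64)   # X ~ N(theta, 1)

f = x ** 2                                        # f is evaluated, never differentiated
score = x - theta                                 # d/dtheta log N(x; theta, 1)
reinforce_grad = (f * score).mean().item()

print(f"REINFORCE estimate: {reinforce_grad:.3f}   exact 2*theta: {2 * theta:.3f}")
```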
Applying the same trick to $\nabla_x V$. The exact same identity applies with $\nabla_x$ in place of $\nabla_\theta$. The "$\theta$-like variable" is now $x$; the "distribution" is the backward bridge $p(x_0 \mid x)$ — the conditional law of the endpoint $X_0$ given that the policy passed through state $x$ at time $t$; the "function" is the reward $r(X_0)$. Applying (R):
$$ A(x, t) \;=\; \mathbb{E}\!\bigl[\, r(X_0) \cdot \underbrace{\nabla_x \log p(X_0 \mid x)}_{\text{Bayes bridge score}} \;\big|\; X_t = x \,\bigr]. \tag{A} $$
The factor under the brace is the Bayes bridge score — the spatial gradient of the log-density of the backward bridge, evaluated at the "current state" $x$. It mirrors REINFORCE's $\nabla_\theta \log p_\theta$, with $x$ playing the role of $\theta$ and the bridge density playing the role of $p_\theta$.
Read this as: the adjoint at $(x, t)$ is a reward-weighted Bayes bridge score. Exactly like vanilla REINFORCE — $r$ enters as a scalar weight, never differentiated.
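Identity (A) can be checked on a 1D toy where everything is Gaussian. For illustration only, assume the policy's endpoints follow $\mathcal{N}(m, s^2)$ and $X_t \mid X_0$ follows the linear noising used throughout. The backward bridge $p(x_0 \mid x)$ is then Gaussian in closed form. We take a deliberately non-differentiable step reward, $r(x_0) = \mathbb{1}[x_0 > 0]$, and estimate $A$ as a reward-weighted bridge score. The score comes from autograd on $\log p(x_0 \mid x)$, never from $r$. We then compare against the exact $\partial_x V$, where $V(x) = \Phi\bigl(\mu(x)/\sigma\bigr)$ with posterior mean $\mu(x)$ and posterior standard deviation $\sigma$.

```python
import math
import torch

torch.manual_seed(0)
m, s, t = 0.5, 1.0, 0.4                        # toy endpoint law N(m, s^2) and a fixed t (assumptions)
var_t = (1 - t) ** 2 * s**2 + t**2             # Var(X_t) under X_t = (1-t) X_0 + t * eps
gain = (1 - t) * s**2 / var_t                  # d/dx E[X_0 | X_t = x]
post_var = s**2 - (1 - t) ** 2 * s**4 / var_t  # Var(X_0 | X_t = x), does not depend on x


def post_mean(x):
    """E[X_0 | X_t = x] for the Gaussian toy."""
    return m + gain * (x - (1 - t) * m)


def log_bridge(x0, x):
    """log p(x0 | X_t = x), dropping a constant that does not depend on x."""
    return -0.5 * (x0 - post_mean(x)) ** 2 / post_var


def r(x0):
    """Step reward: piecewise constant, so its gradient is useless."""
    return (x0 > 0).double()


x = torch.tensor(0.2, dtype=torch.float64, requires_grad=True)
n = 1_000_000
x0 = post_mean(x.detach()) + math.sqrt(post_var) * torch.randn(n, dtype=torch.float64)

# (A): average of r(X_0) * d/dx log p(X_0 | x). Autograd differentiates log_bridge only;
# r(x0) is a constant weight because x0 was sampled with x detached.
grad_mc = torch.autograd.grad((r(x0) * log_bridge(x0, x)).mean(), x)[0].item()

# Exact answer: V(x) = P(X_0 > 0 | X_t = x) = Phi(mu(x) / sd), so dV/dx = phi(mu/sd) * gain / sd.
sd = math.sqrt(post_var)
z = post_mean(x.detach()).item() / sd
grad_exact = math.exp(-0.5 * z**2) / math.sqrt(2 * math.pi) * gain / sd

print(f"REINFORCE adjoint: {grad_mc:.3f}   exact dV/dx: {grad_exact:.3f}")
```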
Plugging this REINFORCE-estimated adjoint back into the adjoint-matching condition of §4.1 gives REINFORCE Adjoint Matching (RAM):
Choose $v^\theta$ so that its correction $(v^\theta - v_{\text{ref}})$ equals a reward-weighted Bayes-bridge score, in expectation, under the policy's own endpoint distribution.
We are almost there: we have $v_{\text{ref}}$, we have the reward $r$; the only thing left is a lot of pairs $(X_0, X_t)$ drawn from the current policy $v^\theta$ to approximate the expectation. Getting an endpoint $X_0$ from the policy is unavoidable — that costs one full rollout from noise at $t=1$ down to data at $t=0$ (dozens of network evaluations), and we cannot cache across training steps because $v^\theta$ keeps updating. The matched $X_t$'s do fall out of the same rollout — every intermediate Euler step is one such $X_t$ — but two obstacles stand in the way of using them directly:
- They are not independent. All intermediates from one trajectory share the same endpoint $X_0$ and are correlated with each other; Monte Carlo wants many independent samples per expensive rollout.
- We cannot compute the bridge score on them. Evaluating $\nabla_x \log p(x_0 \mid x_t)$ requires a closed-form expression for the policy's conditional $p(x_t \mid x_0)$, which an arbitrary $v^\theta$ does not give us.
The paper's Theorem 3.1 removes both obstacles at once. It says we can use the same cheap analytic noising as pretraining:
The KL-optimal post-trained model uses the same conditional noising rule as pretraining. Only the endpoint distribution changes from $p_{\text{ref}}$ to $p_{\text{target}}$.
Concretely: in pretraining, training pairs $(X_0, X_t)$ are built by $X_0 \sim p_{\text{data}}$ and $X_t = (1-t)X_0 + t\varepsilon$, so that $X_t \mid X_0$ is Gaussian. The theorem says: at the post-training optimum, the conditional law $X_t \mid X_0$ is exactly the same Gaussian:
$$ X_t \mid X_0 \;\sim\; \mathcal{N}\bigl((1-t)X_0,\, t^2 I\bigr). \tag{T3.1} $$
(T3.1) resolves the two obstacles effectively:
- The bridge score becomes computable. With $p(x_t \mid x_0)$ in closed form, Bayes' rule gives a closed form for $p(x_0 \mid x_t)$ — and hence for the score $\nabla_x \log p(x_0 \mid x_t)$ that appears in (A). We work this out in a moment.
- Independence is restored. Once $X_t$ is a closed-form Gaussian function of $X_0$ and a fresh draw $(t, \varepsilon)$, we can re-noise the same $X_0$ as many times as we want with independent $(t_k, \varepsilon_k)$ pairs. One expensive rollout for $X_0$ yields $K$ uncorrelated training pairs from it — the K-targets trick we'll see in the algorithm.
Now we can happily generate $K$ on-policy training pairs in 3 steps:
- Sample an endpoint $X_0$ from the current model (any ODE sampler).
- Sample $K$ pairs $(\varepsilon_k, t_k)$ with $\varepsilon_k \sim \mathcal{N}(0, I)$ and $t_k \in (0, 1)$.
- Build $K$ noised states $X_{t_k} = (1 - t_k)\,X_0 + t_k\,\varepsilon_k$. ✨ No SDE rollout. ✨
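The three steps above fit in a few lines of PyTorch. This is a sketch under stated assumptions. The helper name `renoise_k` is ours, not the paper's, and $t$ is drawn uniformly from $(t_{\min}, 1 - t_{\min})$ with a hypothetical clamp `t_min` so that it stays strictly inside $(0, 1)$.

```python
import torch


def renoise_k(x0: torch.Tensor, K: int, t_min: float = 1e-3):
    """Turn B endpoints into B*K noised states via X_t = (1-t) X_0 + t * eps.

    x0: [B, d] endpoints (e.g. from one ODE rollout of the current policy).
    Returns (x0_rep, eps, t, xt) with shapes [B*K, d], [B*K, d], [B*K], [B*K, d].
    """
    x0_rep = x0.repeat_interleave(K, dim=0)                     # each endpoint reused K times
    eps = torch.randn_like(x0_rep)                              # fresh, independent noise per copy
    t = t_min + (1 - 2 * t_min) * torch.rand(x0_rep.shape[0])   # t ~ U(t_min, 1 - t_min)
    xt = (1 - t[:, None]) * x0_rep + t[:, None] * eps           # no SDE rollout, just this line
    return x0_rep, eps, t, xt
```

In the notebook's setting, `x0` would come from something like `euler_sample(v_theta, B)`, which already runs without gradients. Each of the B rollouts then feeds K training pairs.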
Last piece: the Bayes bridge score in closed form. The score $\nabla_x \log p(x_0 \mid x_t)$ inside (A) has a clean closed form once we plug in the Gaussian (T3.1). Using Bayes' rule and the score–velocity identity, it simplifies (paper Prop. 4.1) to the expression below, where $\varepsilon = \bigl(x_t - (1-t)\,x_0\bigr)/t$ is the noise that carries $x_0$ to $x_t$:
$$ \nabla_{x_t}\log p(x_0\mid x_t) \;=\; \tfrac{1-t}{t}\,\bigl(v^\theta(x_t, t) - (\varepsilon - x_0)\bigr). \tag{BS} $$
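(BS) can be verified on a Gaussian toy. The toy is an assumption chosen for tractability: 1D data $\mathcal{N}(m, s^2)$, for which the exact velocity $v(x, t) = \mathbb{E}[\varepsilon - X_0 \mid X_t = x]$ is available in closed form and plays the role of $v^\theta$. The left-hand side comes from autograd on $\log p(x_0 \mid x_t)$. The right-hand side is the formula, with $\varepsilon = (x_t - (1-t)x_0)/t$.

```python
import torch

m, s, t = 0.5, 1.2, 0.3                        # toy 1D data N(m, s^2) and a time t (assumptions)
var_t = (1 - t) ** 2 * s**2 + t**2             # Var(X_t)
post_var = s**2 - (1 - t) ** 2 * s**4 / var_t  # Var(X_0 | X_t), independent of x


def post_mean_x0(x):                           # E[X_0 | X_t = x]
    return m + (1 - t) * s**2 / var_t * (x - (1 - t) * m)


def post_mean_eps(x):                          # E[eps | X_t = x]
    return t / var_t * (x - (1 - t) * m)


def velocity(x):                               # exact v(x, t) = E[eps - X_0 | X_t = x]
    return post_mean_eps(x) - post_mean_x0(x)


x = torch.linspace(-3, 3, 7, dtype=torch.float64, requires_grad=True)
x0 = torch.tensor(1.7, dtype=torch.float64)    # an arbitrary endpoint

# log p(x0 | x) up to an x-independent constant; each entry depends only on its own x.
log_bridge = -0.5 * (x0 - post_mean_x0(x)) ** 2 / post_var
(autograd_score,) = torch.autograd.grad(log_bridge.sum(), x)

eps = (x.detach() - (1 - t) * x0) / t          # the noise that maps x0 to x at time t
bs_formula = (1 - t) / t * (velocity(x.detach()) - (eps - x0))
print(torch.max(torch.abs(autograd_score - bs_formula)))
```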
4.3 Assembling the RAM loss
We have all the math we need. Time to fold the four pieces together into a single regression loss:
- (AM) — the optimality condition $(v^\theta - v_{\text{ref}}) \propto A^{v^\theta}$,
- (A) — the REINFORCE expression for $A$ in terms of the bridge score,
- (BS) — the closed-form bridge score,
- (T3.1) — the cheap analytic noising that lets us sample $(X_0, X_t)$ jointly.
Three small steps.
Step 1 — turn the adjoint into a reward-weighted bridge score. Plug (BS) into (A). The deterministic prefactor $\tfrac{1-t}{t}$ pulls out of the expectation:
$$ A^{v^\theta}(x_t, t) \;=\; \tfrac{1-t}{t}\; \mathbb{E}\!\bigl[\, r(X_0)\,\bigl(v^\theta(x_t, t) - (\varepsilon - X_0)\bigr) \,\big|\, X_t = x_t \,\bigr]. $$
Step 2 — convert the matching condition to a fixed point for $v^\theta$. Substitute the above into (AM). We follow the paper's velocity-space conversion (§B.2), which absorbs the $\tfrac{1-t}{t}$ prefactor into an implicit loss weighting, fixes a sign from reverse-time generation, and fuses the proportionality constant into a single tunable gain $\beta > 0$. The result is the fixed-point condition $v^\theta(x_t, t) = T(x_t, t)$, with target velocity
$$ T(x_t, t) \;:=\; v_{\text{ref}}(x_t, t) \;+\; \beta\; \mathbb{E}\!\bigl[\, r(X_0)\,\bigl((\varepsilon - X_0) - v^\theta(x_t, t)\bigr) \,\big|\, X_t = x_t \,\bigr]. $$
Step 3 — regress $v^\theta$ onto $T$ with stop-gradient. Note that $v^\theta$ appears inside $T$. To keep the fixed-point structure clean we freeze the target with stop-gradient and minimize a plain MSE:
$$ \mathcal{L}(\theta) \;=\; \mathbb{E}\bigl\| v^\theta(X_t, t) - \mathrm{sg}\bigl(T(X_t, t)\bigr)\bigr\|^2. $$
🎯 Finally, the RAM loss¶
$$ \boxed{ \mathcal{L}_{\text{RAM}}(\theta) = \mathbb{E}_{X_0,\,\varepsilon,\,t}\!\left[\;\Big\|\,v^\theta(X_t, t) - \mathrm{sg}\!\left(\,v_{\text{ref}}(X_t, t) + r(X_0)\,\bigl((\varepsilon - X_0) - v^\theta(X_t, t)\bigr)\,\right)\Big\|^2\;\right] } \tag{RAM} $$
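with $X_0 \sim p^\theta$, $\varepsilon \sim \mathcal{N}(0, I)$,
$X_t = (1-t)X_0 + t\,\varepsilon$.
The box replaces the conditional expectation inside $T$ with a single-sample target. This is legitimate. Over functions of $X_t$, a squared error is minimized by the conditional mean of the target given $X_t$. Both $v_{\text{ref}}(X_t, t)$ and $\mathrm{sg}\bigl(v^\theta(X_t, t)\bigr)$ are already functions of $X_t$, so the regression targets exactly the expectation in $T$.
For clarity the box folds $\beta$ into the reward; we'll dial $\beta$ as a
single hyperparameter in §5 (Implementation). Compare to pretraining
$\bigl\|v^\theta(X_t, t) - (\varepsilon - X_0)\bigr\|^2$: same MSE shape,
but the target is now $v_{\text{ref}} + r\,\bigl((\varepsilon - X_0) - v^\theta\bigr)$
instead of just the pretraining target $\varepsilon - X_0$.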
Reading the target:
- If $r(x_0)$ is large and positive, the target pulls $v^\theta$ toward $(\varepsilon - x_0)$ — i.e. toward the pretraining target for this specific endpoint. That makes this endpoint more likely to be produced again. Good outcome → make it more frequent.
- If $r(x_0) \approx 0$, the target collapses to $v_{\text{ref}}$. Don't move. This is the KL anchor.
- If $r(x_0)$ is negative (only happens once we subtract a baseline, in §5 below), the target pushes $v^\theta$ away from $(\varepsilon - x_0)$. Bad outcome → make this endpoint less frequent.
The reference $v_{\text{ref}}$ stays in the target for the entire training run — that's what prevents reward hacking and keeps generations on the data manifold even when the reward is wrong.
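The three regimes above can be checked on a minimal per-sample sketch. It treats the network output at a single $(x_t, t)$ as a free vector. The helpers ram_target and per_sample_fixed_point are hypothetical names for illustration, and signal stands for $\beta\,r(x_0)$. Setting $v = v_{\text{ref}} + c\,\bigl((\varepsilon - x_0) - v\bigr)$ with $c = \beta r$ and solving for $v$ gives $v_\star = \bigl(v_{\text{ref}} + c\,(\varepsilon - x_0)\bigr)/(1 + c)$. This fixed point interpolates between the anchor ($c = 0$) and the pretraining target ($c \to \infty$).

```python
import torch

def ram_target(v_ref, v_theta_sg, eps, x0, signal):
    # Per-row RAM target: v_ref + signal * ((eps - x0) - sg(v_theta)), signal = beta * r(x0).
    return v_ref + signal[:, None] * ((eps - x0) - v_theta_sg)

def per_sample_fixed_point(v_ref, eps, x0, signal):
    # Solve v = ram_target(v_ref, v, eps, x0, signal) for v; requires 1 + signal != 0.
    c = signal[:, None]
    return (v_ref + c * (eps - x0)) / (1 + c)

torch.manual_seed(0)
# One made-up (v_ref, eps, x0) row, repeated so that only the signal varies across rows.
v_ref, eps, x0 = (torch.randn(2, dtype=torch.float64).repeat(5, 1) for _ in range(3))
signal = torch.tensor([0.0, 0.5, 2.0, 10.0, 100.0], dtype=torch.float64)
v_star = per_sample_fixed_point(v_ref, eps, x0, signal)

# 1) A fixed point reproduces itself as a target.
print(torch.allclose(ram_target(v_ref, v_star, eps, x0, signal), v_star))
# 2) Zero signal: the fixed point is exactly the anchor v_ref.
print(torch.allclose(v_star[0], v_ref[0]))
# 3) Larger signal: the fixed point moves toward the pretraining target eps - x0.
dist = (v_star - (eps - x0)).norm(dim=1)
print(dist)
```

The distance to $\varepsilon - x_0$ shrinks by the factor $1/(1 + c)$. This sketch ignores parameter sharing across rows and the averaging over all $(X_0, \varepsilon)$ pairs that produce the same $X_t$. It illustrates the regimes; it does not describe the trained model.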
⚠️ Faithfulness caveat. The boxed loss (RAM) is not the exact KL-optimal update. RAM (a) drops a "path-cost correction" term in the optimal-control decomposition, (b) uses a plug-in score that is exact only at initialization and at the true optimum, and (c) replaces an integral over a path of tilted distributions by a single fixed-point evaluation.
Per the paper, RAM coincides with the KL optimum exactly for a Gaussian reference under a linear reward, and agrees with it to first order in the reward more generally. In practice (on images, and as we'll see in 2D), it's a solid approximation. We will not prove any of this; the goal of this notebook is to understand the recipe.
5. Implementation¶
We take the boxed loss (RAM) literally, with two small practical additions worth flagging up front:
- $K$ targets per endpoint. ODE sampling and reward evaluation are the expensive operations; the loss itself is cheap. We reuse each $x_0$ for $K$ independent $(t, \varepsilon)$ draws, exposed as k_targets in the config.
- Group-relative advantage. Like all REINFORCE-style methods, RAM benefits from subtracting a baseline. We sample $G$ endpoints in a "group" and use the group mean as the baseline (the same trick as GRPO / Flow-GRPO / RLHF). Toggled via use_advantage.
@dataclass
class RAMConfig:
    outer_steps: int = 600          # number of optimizer steps
    group_size: int = 32            # endpoints per step (G)
    k_targets: int = 4              # (t, eps) draws per endpoint (K)
    sample_steps: int = 25          # Euler steps for endpoint sampling
    lr: float = 1e-4
    beta: float = 5.0               # reward scale (single dial)
    use_advantage: bool = False     # subtract group-mean baseline
    scale_advantage: bool = False   # also divide by group std
    log_every: int = 50

def train_ram(
    model: VelocityNet,
    model_ref: VelocityNet,
    reward_fn,
    cfg: RAMConfig,
) -> dict:
    """Train `model` with RAM. `model_ref` is the frozen pretrained reference.

    Returns a dict with training-curve arrays.
    """
    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    mean_reward_hist, loss_hist = [], []
    pbar = tqdm(range(cfg.outer_steps), desc=f"RAM (beta={cfg.beta})", mininterval=2.0)
    for step in pbar:
        # 1) Sample G endpoints on-policy.
        with torch.no_grad():
            x0 = euler_sample(model, cfg.group_size, n_steps=cfg.sample_steps)
        # 2) Score them with the reward.
        raw_reward = reward_fn(x0)  # [G]; the actual r(x_0)
        mean_reward_hist.append(raw_reward.mean().item())
        # 3) Build the "signal" that enters the loss.
        if cfg.use_advantage:
            advantage = raw_reward - raw_reward.mean()
            if cfg.scale_advantage:
                advantage = advantage / (raw_reward.std(correction=0) + 1e-4)
            signal = cfg.beta * advantage
        else:
            signal = cfg.beta * raw_reward  # the raw-reward loss
        # 4) Reuse each endpoint for K (t, eps) draws.
        x0_rep = x0.repeat_interleave(cfg.k_targets, dim=0)          # [G*K, 2]
        signal_rep = signal.repeat_interleave(cfg.k_targets, dim=0)  # [G*K]
        B = x0_rep.shape[0]
        eps = torch.randn_like(x0_rep)
        t = torch.rand(B)
        xt = (1 - t[:, None]) * x0_rep + t[:, None] * eps
        # 5) Build the RAM target (stop-gradient on everything inside).
        with torch.no_grad():
            v_ref_xt = model_ref(xt, t)
            v_theta_sg = model(xt, t)  # current model, no grad
            pretrain_target = eps - x0_rep
            target = v_ref_xt + signal_rep[:, None] * (pretrain_target - v_theta_sg)
        # 6) MSE between the model's prediction and the target.
        v_pred = model(xt, t)
        loss = ((v_pred - target) ** 2).mean()
        opt.zero_grad(); loss.backward(); opt.step()
        loss_hist.append(loss.item())
        if step % cfg.log_every == 0:
            pbar.set_postfix(loss=f"{loss.item():.3f}",
                             mean_r=f"{mean_reward_hist[-1]:.3f}")
    return {"mean_reward": mean_reward_hist, "loss": loss_hist}
A note on signal. With use_advantage=False we use raw $\beta\,r$ —
this is the version that literally matches the boxed loss (RAM) and whose fixed
point we will compare to the analytic $p_{\text{target}}$. With
use_advantage=True we subtract the per-group mean as a variance-reduction
baseline; this is the standard "GRPO"-style trick that the paper's SD3
script and most modern RLHF setups use. Both reach the same fixed point;
we mostly use raw reward in this notebook so the comparison to
$p_{\text{target}}$ remains apples-to-apples (the standardized variant
also has a long-horizon caveat we will see in §7.5).
6. Watching RAM tilt the ring¶
The headline experiment. We initialize $v^\theta \leftarrow v_{\text{ref}}$ and train with raw-reward RAM at $\beta=5$.
# Initialize the model as a copy of the pretrained reference.
v_theta = VelocityNet()
v_theta.load_state_dict(v_ref.state_dict())
history = train_ram(v_theta, v_ref, reward, RAMConfig(
outer_steps=600,
group_size=32,
k_targets=4,
beta=BETA,
lr=1e-4,
))
RAM (beta=5.0): 100%|██████████| 600/600 [00:01<00:00, 437.55it/s, loss=17.484, mean_r=0.689]
The mean-reward curve should climb from the pretrained baseline of about 0.24 to roughly 0.67. The baseline is low because essentially one mode sits inside the reward bump, while its two neighbors pick up only modest reward. The loss can increase during training, and that is expected here: the regression target itself grows with the reward. At initialization, for example, the per-sample loss is $\beta^2 r^2 \|(\varepsilon - x_0) - v^\theta\|^2$. What matters is that the residual stays bounded.
fig, axes = plt.subplots(1, 2, figsize=(11, 4))
axes[0].plot(history["mean_reward"])
axes[0].set_xlabel("outer step"); axes[0].set_ylabel("mean reward over group")
axes[0].set_title("RAM training reward")
axes[1].plot(history["loss"])
axes[1].set_xlabel("outer step"); axes[1].set_ylabel("RAM loss")
axes[1].set_title("RAM loss (note: grows with reward; not a failure mode)")
axes[1].set_yscale("log")
plt.tight_layout(); plt.show()
The four-panel comparison everyone is waiting for. Pretrained samples, the reward, the analytic target, and what RAM produced.
v_theta.eval()
ram_samples = euler_sample(v_theta, 3000)
ram_reward = reward(ram_samples).mean().item()
ref_reward = reward(ref_samples).mean().item()
print(f"mean reward — pretrained: {ref_reward:.3f} | RAM-trained: {ram_reward:.3f} "
f"(x{ram_reward/ref_reward:.1f})")
def p_ram_density(grid_pts):
    p = kde_density(grid_pts, ram_samples, h=0.15)
    return p / (p.sum() + 1e-12)
fig, axes = plt.subplots(1, 4, figsize=(18, 4.5))
plot_density(p_ref_density, axes[0], title="p_ref (pretrained)")
plot_density(reward, axes[1], title="reward r(x)")
plot_density(p_target_density, axes[2], title=fr"analytic $p_{{\rm target}} \propto p_{{\rm ref}} \cdot e^{{{BETA} r}}$")
plot_density(p_ram_density, axes[3], title=f"RAM-trained (mean r = {ram_reward:.2f})")
plt.tight_layout(); plt.show()
mean reward — pretrained: 0.244 | RAM-trained: 0.673 (x2.8)
The right-most panel should look very similar to the third panel (the analytic target): mass concentrates around $(4, 0)$, with smaller residual weight on the two adjacent modes. Mean reward jumps from $\sim 0.24$ to $\sim 0.67$ — about a $2.8\times$ improvement on this task.
7. Ablation studies¶
Five short experiments isolating one knob at a time: the $\beta = 0$ anchor sanity check (§7.1), tilting strength $\beta$ (§7.2), reuse $K$ (§7.3), group-relative advantage (§7.4), and a cautionary note on long-horizon advantage scaling (§7.5).
7.1 The reference anchor — $\beta = 0$ should be a no-op¶
If we set $\beta = 0$, every effective reward $\beta\,r(x_0)$ is zero. The target then collapses to $v_{\text{ref}}(x_t, t)$, and the loss becomes $\|v^\theta - v_{\text{ref}}\|^2$. This loss starts at zero (since $v^\theta = v_{\text{ref}}$) and should stay at zero, so the samples should not drift.
This is the empirical confirmation that $v_{\text{ref}}$ in the target acts as a hard anchor. (With $\beta = 0$ the signal is zero whether or not we subtract a baseline, so we simply run the default raw-reward configuration.)
v_theta_zero = VelocityNet()
v_theta_zero.load_state_dict(v_ref.state_dict())
hist_zero = train_ram(v_theta_zero, v_ref, reward, RAMConfig(
outer_steps=200, beta=0.0, lr=1e-4, log_every=200,
))
zero_samples = euler_sample(v_theta_zero, 3000)
print(f"max loss observed (should be ~0): {max(hist_zero['loss']):.2e}")
print(f"mean reward of beta=0 RAM samples: {reward(zero_samples).mean():.4f}")
print(f"mean reward of pretrained samples: {reward(ref_samples).mean():.4f}")
fig, axes = plt.subplots(1, 2, figsize=(9, 4))
plot_samples(ref_samples, axes[0], title="pretrained", color="teal")
plot_samples(zero_samples, axes[1], title="beta=0 RAM (should match)", color="C2")
plt.tight_layout(); plt.show()
RAM (beta=0.0): 100%|██████████| 200/200 [00:00<00:00, 435.23it/s, loss=0.000, mean_r=0.242]
max loss observed (should be ~0): 0.00e+00
mean reward of beta=0 RAM samples: 0.2420
mean reward of pretrained samples: 0.2436
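The loss stays at zero to printed precision, and the two sample plots are indistinguishable. The small gap in mean reward (0.2420 vs 0.2436) is sampling noise between two independent sample draws. The anchor works.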
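The loss is exactly zero, not just small, for a simple reason. With identical weights, $v^\theta$ and $v_{\text{ref}}$ produce bit-identical outputs on CPU, so the gradient is exactly zero. Adam, with its default weight_decay=0, turns a zero gradient into a zero update, because both of its moment estimates stay at zero. The minimal check below uses a tiny nn.Linear as a stand-in for VelocityNet, an assumption made for brevity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(2, 2)                      # stand-in for v_theta (assumption)
ref = nn.Linear(2, 2)                      # stand-in for the frozen v_ref
ref.load_state_dict(net.state_dict())      # beta = 0 setup: start as an exact copy
before = [p.detach().clone() for p in net.parameters()]

opt = torch.optim.Adam(net.parameters(), lr=1e-4)
x = torch.randn(64, 2)
for _ in range(10):
    with torch.no_grad():
        target = ref(x)                    # with beta = 0 the RAM target is just v_ref
    loss = ((net(x) - target) ** 2).mean() # identical weights -> identical outputs -> 0
    opt.zero_grad(); loss.backward(); opt.step()

unchanged = all(torch.equal(b, p) for b, p in zip(before, net.parameters()))
print(loss.item(), unchanged)
```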
7.2 Tilting strength — bigger $\beta$ tilts harder¶
As we scale the reward by $\beta$, the analytic target $p_{\text{ref}} \cdot \exp(\beta\,r)$ peaks more sharply. RAM-trained samples should follow. We sweep $\beta \in \{2, 5, 8\}$.
sweep_results = []
for beta in [2.0, 5.0, 8.0]:
    torch.manual_seed(7)
    v_sw = VelocityNet()
    v_sw.load_state_dict(v_ref.state_dict())
    h = train_ram(v_sw, v_ref, reward, RAMConfig(
        outer_steps=400, beta=beta, lr=1e-4, log_every=400,
    ))
    samp = euler_sample(v_sw, 3000)
    sweep_results.append((beta, samp, reward(samp).mean().item()))

fig, axes = plt.subplots(2, 3, figsize=(13, 8))
for col, (beta, samp, mean_r) in enumerate(sweep_results):
    # Top row: analytic target.
    def make_target(b=beta):
        def f(grid):
            p = kde_density(grid, ref_big, h=0.15) * torch.exp(b * reward(grid))
            return p / (p.sum() + 1e-12)
        return f
    plot_density(make_target(beta), axes[0, col],
                 title=fr"analytic $p_{{\rm target}}$, $\beta = {beta}$")
    # Bottom row: KDE density of the RAM-trained samples (same form as top
    # row, so the tilting is directly readable from peak sharpness).
    def make_ram_density(s=samp):
        def f(grid):
            p = kde_density(grid, s, h=0.15)
            return p / (p.sum() + 1e-12)
        return f
    plot_density(make_ram_density(samp), axes[1, col],
                 title=fr"RAM-trained, $\beta = {beta}$, mean $r$ = {mean_r:.2f}")
plt.tight_layout(); plt.show()
RAM (beta=2.0): 100%|██████████| 400/400 [00:00<00:00, 439.91it/s, loss=2.047, mean_r=0.288]
RAM (beta=5.0): 100%|██████████| 400/400 [00:00<00:00, 436.08it/s, loss=12.791, mean_r=0.288]
RAM (beta=8.0): 100%|██████████| 400/400 [00:00<00:00, 438.79it/s, loss=32.745, mean_r=0.288]
Top row: analytic target getting sharper with $\beta$. Bottom row: RAM samples following. With small $\beta$ several modes survive; with $\beta = 8$ the model concentrates essentially all probability mass onto the rewarded mode.
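A back-of-the-envelope version of this sweep shows where the sharpening comes from. Assume, hypothetically, that $p_{\text{ref}}$ is 8 equal-weight modes on a ring of radius 4, and that each mode is well approximated by a point mass. This is not the notebook's exact ring or reward width. Under that assumption, $p_{\text{ref}} \cdot e^{\beta r}$ simply reweights mode $k$ by $e^{\beta\, r_k}$. The odds of the rewarded mode against mode $k$ are then $e^{\beta(r_0 - r_k)}$, which grow exponentially in $\beta$.

```python
import numpy as np

# Hypothetical stand-in for the notebook's setup: 8 equal-weight modes on a ring of
# radius 4, and a Gaussian reward bump of width s = 1.5 centered on the mode at (4, 0).
R, n_modes, s = 4.0, 8, 1.5
angles = 2 * np.pi * np.arange(n_modes) / n_modes
modes = np.stack([R * np.cos(angles), R * np.sin(angles)], axis=1)
r_modes = np.exp(-((modes - np.array([R, 0.0])) ** 2).sum(axis=1) / (2 * s**2))

def tilted_mode_weights(beta):
    # Point-mass approximation of p_target proportional to p_ref * exp(beta * r):
    # equal prior weights times exp(beta * r_k), renormalized over modes.
    w = np.exp(beta * r_modes)
    return w / w.sum()

for beta in [0.0, 2.0, 5.0, 8.0]:
    w = tilted_mode_weights(beta)
    print(f"beta = {beta:.0f}: rewarded mode {w[0]:.3f}, its two neighbors {w[1] + w[-1]:.3f}")
```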
7.3 Reuse — more $(t, \varepsilon)$ draws per endpoint help¶
A central claim of the paper: because endpoint sampling and reward evaluation are the expensive operations, reusing each $x_0$ across $K$ independent $(t, \varepsilon)$ draws gives more independent gradient signal per ODE rollout than methods that build all their training states from a single SDE rollout (which produces correlated states).
We compare $K \in \{1, 4, 16\}$ at fixed $G$ and fixed outer-step count. Per outer step, $K=16$ does 16× more loss evals than $K=1$ — but the same amount of ODE sampling. Larger $K$ should reach higher reward in fewer sampling-bound steps.
k_results = {}
for K in [1, 4, 16]:
    torch.manual_seed(11)
    v_k = VelocityNet()
    v_k.load_state_dict(v_ref.state_dict())
    h = train_ram(v_k, v_ref, reward, RAMConfig(
        outer_steps=300, group_size=32, k_targets=K, beta=BETA, lr=1e-4,
        log_every=300,
    ))
    k_results[K] = h["mean_reward"]

plt.figure(figsize=(8, 4))
for K, mr in k_results.items():
    # Smooth a little for legibility.
    mr_smooth = np.convolve(mr, np.ones(20)/20, mode="valid")
    plt.plot(mr_smooth, label=f"K = {K}")
plt.xlabel("outer step (each step = one ODE rollout)")
plt.ylabel("mean reward over group (smoothed)")
plt.title("Targets per endpoint K: more reuse = more signal per sampling cost")
plt.legend()
plt.tight_layout(); plt.show()
RAM (beta=5.0): 100%|██████████| 300/300 [00:00<00:00, 487.26it/s, loss=11.599, mean_r=0.227]
RAM (beta=5.0): 100%|██████████| 300/300 [00:00<00:00, 435.67it/s, loss=28.244, mean_r=0.227]
RAM (beta=5.0): 100%|██████████| 300/300 [00:01<00:00, 257.27it/s, loss=15.991, mean_r=0.227]
Larger $K$ reaches higher mean reward in the same number of sampling steps. The pattern matches the paper's observation that RAM's conditionally-independent training states are more informative than correlated states from a single SDE rollout. On CPU this micro-experiment runs in a few seconds; at image scale, the difference is the difference between "feasible" and "infeasible".
7.4 Group-relative advantage — variance reduction¶
Subtracting a per-group baseline from the reward (and optionally dividing by the group's standard deviation) is the standard REINFORCE variance-reduction trick, used everywhere from GRPO to RLHF. Subtracting the baseline doesn't change the fixed point of the boxed loss (RAM): the baseline is constant within a group, so it averages out in expectation. What it changes is the gradient variance. Dividing by the standard deviation additionally rescales the effective $\beta$ by $1/\sigma_r$, which varies from group to group.
We compare three signals at fixed $\beta$ and seed:
- signal = beta * r (raw reward, what we've used so far),
- signal = beta * (r - mean(r)) (group-mean baseline),
- signal = beta * (r - mean(r)) / std(r) (standardized; the SD3 default).
adv_variants = [
("raw $\\beta r$", False, False),
("group-mean centered", True, False),
("group-mean + std-scaled", True, True),
]
adv_results = {}
for tag, use_adv, scale_adv in adv_variants:
    torch.manual_seed(21)
    v_a = VelocityNet()
    v_a.load_state_dict(v_ref.state_dict())
    h = train_ram(v_a, v_ref, reward, RAMConfig(
        outer_steps=400, group_size=32, k_targets=4,
        beta=BETA, lr=1e-4, log_every=400,
        use_advantage=use_adv, scale_advantage=scale_adv,
    ))
    adv_results[tag] = h["mean_reward"]

plt.figure(figsize=(8, 4))
for tag, mr in adv_results.items():
    mr_smooth = np.convolve(mr, np.ones(20)/20, mode="valid")
    plt.plot(mr_smooth, label=tag)
plt.xlabel("outer step")
plt.ylabel("mean reward over group (smoothed)")
plt.title(f"Reward signals at $\\beta = {BETA}$")
plt.legend()
plt.tight_layout(); plt.show()
RAM (beta=5.0): 100%|██████████| 400/400 [00:00<00:00, 441.91it/s, loss=8.657, mean_r=0.206]
RAM (beta=5.0): 100%|██████████| 400/400 [00:00<00:00, 441.90it/s, loss=6.677, mean_r=0.206]
RAM (beta=5.0): 100%|██████████| 400/400 [00:00<00:00, 436.31it/s, loss=64.702, mean_r=0.206]
All three curves climb to comparable mean reward — the fixed point is shared. The baselined variants typically reach it with smoother / slightly faster trajectories, which is exactly the variance-reduction story. At image scale this becomes much more pronounced.
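The variance-reduction argument is the standard REINFORCE one. It can be checked on a toy with nothing RAM-specific in it. Assume, hypothetically, a 1D Gaussian "policy" $\mathcal{N}(\mu, 1)$ and a reward bump at $x = 2$. The score-function estimator of $\partial_\mu \mathbb{E}[r]$ is $r(x)\,(x - \mu)$. Subtracting a constant $b$ leaves its mean unchanged, because $\mathbb{E}[x - \mu] = 0$. The snippet uses the batch mean as $b$, like a group baseline. Strictly, that biases the estimate by $O(1/n)$, which is negligible at this sample size.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, n = 0.0, 200_000
x = rng.normal(mu, 1.0, size=n)      # samples from a 1D Gaussian "policy" N(mu, 1)
r = np.exp(-0.5 * (x - 2.0) ** 2)    # hypothetical bounded reward bump centered at x = 2
score = x - mu                        # d/dmu log N(x; mu, 1)

b = r.mean()                          # batch-mean baseline, like a group mean
g_raw = r * score                     # REINFORCE estimate of d/dmu E[r(x)]
g_base = (r - b) * score              # the same estimate with the baseline subtracted

print(f"mean: raw = {g_raw.mean():.4f}   baselined = {g_base.mean():.4f}")
print(f"var:  raw = {g_raw.var():.4f}   baselined = {g_base.var():.4f}")
```

For these settings the exact gradient is $e^{-1}/\sqrt{2} \approx 0.260$. Both estimators land near it, and the baselined one has roughly half the per-sample variance.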
7.5 The other side of the advantage signal — over-training collapse¶
§7.4 introduced the std-scaled advantage $\text{adv} = (r - \bar r)/\sigma_r$ and showed it reaches the same plateau as the raw-reward signal with lower variance. Time to disclose its cost.
§7.1 showed that with $\beta = 0$, $v_{\text{ref}}$ in the per-sample target acts as a hard anchor — parameters don't move at all. With $\beta > 0$ the anchor only applies on samples whose advantage is near zero. Look at the per-sample target again: $$ v_{\text{ref}}(x_t, t) + \beta \cdot \text{adv}(x_0) \cdot \bigl((\varepsilon - x_0) - \mathrm{sg}\bigl(v^\theta(x_t, t)\bigr)\bigr). $$
With the raw-reward signal ($\text{adv} := r$, bounded in $[0, 1]$), the off-anchor term is bounded. The per-sample fixed point $v^\theta_\star = \bigl(v_{\text{ref}} + \beta r\,(\varepsilon - x_0)\bigr) / (1 + \beta r)$ always exists, since its denominator is at least 1, and training is well-behaved over long horizons. The std-scaled signal behaves differently once the model concentrates around its target mode. Intra-batch reward variance shrinks, and the standardized advantage $\text{adv} = (r - \bar r) / \sigma_r$ amplifies what is by then almost pure noise. (We clamp $\sigma_r$ from below, so the inflation is bounded but real.) The most damaging effect is on below-mean samples. The analogous fixed point $v^\theta_\star = \bigl(v_{\text{ref}} + \beta\,\text{adv}\,(\varepsilon - x_0)\bigr) / (1 + \beta\,\text{adv})$ blows up as $\beta \cdot \text{adv} \to -1$. Once $\beta \cdot \text{adv} < -1$, it becomes repelling: each stop-gradient regression step pushes $v^\theta$ further away from it instead of toward it. The model chases targets far from $v_{\text{ref}}$ and drifts off the data manifold, and the reward collapses.
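A scalar caricature makes the sign flip concrete. Treat the network output for one row as a free scalar $v$, write $u$ for $\varepsilon - x_0$ and $c$ for $\beta \cdot \text{adv}$, and run plain gradient descent on $\tfrac12\bigl(v - \mathrm{sg}(T)\bigr)^2$ with $T = v_{\text{ref}} + c\,(u - v)$. Each step multiplies the distance to the fixed point by $1 - \eta\,(1 + c)$. For a small step size $\eta$, the iteration therefore contracts when $1 + c > 0$ and grows when $1 + c < 0$. The helper run_sg_iteration is hypothetical. It ignores everything a real network adds, such as shared parameters, Adam, and per-step resampling of adv.

```python
def run_sg_iteration(c, v_ref=0.0, u=1.0, v0=0.0, lr=0.1, steps=200):
    # Scalar caricature of one RAM row: u plays eps - x0, c plays beta * adv.
    v = v0
    for _ in range(steps):
        target = v_ref + c * (u - v)  # built from the current v, then treated as a constant (sg)
        v = v - lr * (v - target)     # gradient step on 0.5 * (v - target)^2 w.r.t. v only
    return v

for c in [2.0, 0.5, -0.5, -1.5]:
    v_star = (0.0 + c * 1.0) / (1 + c)  # fixed point (v_ref + c * u) / (1 + c) with v_ref=0, u=1
    print(f"c = {c:+.1f}   fixed point = {v_star:+.3f}   v after 200 steps = {run_sg_iteration(c):+.3e}")
```

With $c = -1.5$ the fixed point still exists (at 3.0), but the iteration runs away from it. As $c$ approaches $-1$, the fixed point itself diverges.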
Let's run the same §6 setup but with std-scaled advantage, pushed all the way to 1000 steps.
# Same setup as §6 (single reward at (4, 0), beta=5, G=32, K=4) but with
# std-scaled advantage and pushed to 1000 outer steps. Snapshot the
# 3000-sample reward at intermediate steps.
torch.manual_seed(3)
v_long = VelocityNet()
v_long.load_state_dict(v_ref.state_dict())
opt_long = torch.optim.Adam(v_long.parameters(), lr=1e-4)
mean_r_hist = []
snap_steps = [100, 200, 300, 400, 500, 700, 900, 1000]
snap_rewards = []
for step in range(1, 1001):
    with torch.no_grad():
        x0 = euler_sample(v_long, 32)
        r = reward(x0)
        adv = (r - r.mean()) / r.std(correction=0).clamp_min(1e-3)  # the §7.4 variant
    x0_rep = x0.repeat_interleave(4, 0)
    eps = torch.randn_like(x0_rep)
    t = torch.rand(x0_rep.shape[0])
    xt = (1 - t[:, None]) * x0_rep + t[:, None] * eps
    with torch.no_grad():
        v_ref_xt = v_ref(xt, t)
        v_sg = v_long(xt, t)
        pretrain_target = eps - x0_rep
        sig = (BETA * adv).repeat_interleave(4, 0)
        target = v_ref_xt + sig[:, None] * (pretrain_target - v_sg)
    loss = ((v_long(xt, t) - target) ** 2).mean()
    opt_long.zero_grad(); loss.backward(); opt_long.step()
    mean_r_hist.append(r.mean().item())
    if step in snap_steps:
        with torch.no_grad():
            snap_rewards.append((step, reward(euler_sample(v_long, 3000)).mean().item()))

print(f"{'step':>5} | reward on 3000 ODE samples")
for s, r in snap_rewards:
    print(f"{s:5d} | {r:.3f}")
 step | reward on 3000 ODE samples
  100 | 0.404
  200 | 0.591
  300 | 0.712
  400 | 0.677
  500 | 0.666
  700 | 0.469
  900 | 0.026
 1000 | 0.006
mr = np.asarray(mean_r_hist)
W = 30
mr_smooth = np.convolve(mr, np.ones(W) / W, mode="valid")
fig, ax = plt.subplots(figsize=(8.5, 4.2))
ax.plot(mr, color="C0", alpha=0.18, lw=0.7, label="per-batch reward")
ax.plot(np.arange(W - 1, len(mr)), mr_smooth,
color="C0", lw=2.2, label=f"rolling mean ({W})")
xs, ys = zip(*snap_rewards)
ax.plot(xs, ys, "o", color="C3", markersize=7,
label="reward on 3000 ODE samples")
ax.axvspan(250, 500, color="green", alpha=0.10, label="safe plateau")
ax.axvspan(700, 1000, color="red", alpha=0.10, label="collapsing")
ax.set_xlabel("outer step")
ax.set_ylabel("mean reward")
ax.set_title(r"RAM with std-scaled advantage, 1000 steps: climb $\to$ plateau $\to$ collapse")
ax.legend(loc="lower left", fontsize=9)
ax.set_ylim(-0.05, 1.0)
plt.tight_layout(); plt.show()
With std-scaled advantage the reward climbs steeply through ~300 steps and holds a plateau through ~500. It then degrades, slowly at first (0.469 at step 700) and catastrophically by step 900. Different seeds shift the exact collapse onset by a hundred steps or so, but the shape is the same. (The raw-reward §6 setup keeps climbing well past 1000 steps; it's the standardization that breaks down once the batch becomes uniform.)
The practical takeaway: std-scaled advantage is useful early because it normalizes reward scale across batches and sharpens the relative preference signal among rollouts. But once the batch becomes nearly uniform, the standardization amplifies what is by then almost pure noise and destabilizes training. Two ways to live with that — early-stop while the plateau holds, or stay on raw $\beta r$ and lose nothing on the toy 2D problem.
8. RAM with multiple rewards¶
The single-reward setting is a useful pedagogical model, but real post-training rarely lives there. Aligning an image generator with human preferences typically balances composability (does the prompt's object structure render?), text rendering (OCR), aesthetics, safety, and more — five or more reward models running at once. In this section we extend RAM to handle multiple rewards in a single unified framework.
A common approach is a linear blend $r(x) = \sum_j w_j\, r_j(x)$ with fixed weights $w_j$. It is simple and often a good starting point, but choosing the $w_j$ takes care: they depend on each reward's scale and dynamic range, and sample-level mismatch means many rollouts are informative for one reward dimension but near-noise for others — collapsing them into a single weighted scalar discards that per-reward signal.
A complementary framing: each reward $r_j$ defines its own RAM loss $\mathcal{L}_j(\theta)$, and we want $\theta$ to descend all of them together. That is multi-objective optimization (MOO). The classical recipe is MGDA — Multiple Gradient Descent Algorithm (Désidéri, 2012; in the form widely used in deep multi-task learning, Sener & Koltun, NeurIPS 2018). At each step, MGDA computes the per-task gradients $g_j = \nabla_\theta\,\mathcal{L}_j$ and finds the convex combination
$$ g_\alpha \;=\; \sum_{j=1}^{J} \alpha_j\, g_j, \qquad \alpha \in \Delta_J \;\text{(probability simplex)}, $$
with minimum norm. The direction $-g_\alpha$ is a common-descent direction for the current minibatch surrogate losses. A small step along it decreases every $\mathcal{L}_j$ simultaneously whenever such a direction exists; otherwise we are at a Pareto stationary point. MARBLE (Zhao et al., 2026) brings this idea to diffusion RL post-training. It keeps per-reward advantage estimators, computes per-reward policy gradients, and combines them by solving the same simplex QP. The balance between rewards is therefore set automatically at each step rather than tuned in advance. We implement a simplified version of it for multi-reward RAM.
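For $J = 2$ the simplex QP has a closed form, which makes the common-descent property easy to check by hand. Minimizing $\|g_2 + a\,(g_1 - g_2)\|^2$ over $a \in [0, 1]$ gives $a^\star = \mathrm{clip}\bigl((g_2 - g_1)^\top g_2 / \|g_1 - g_2\|^2,\, 0,\, 1\bigr)$. At an interior solution, $g_\alpha$ has the same positive inner product $\|g_\alpha\|^2$ with both $g_j$. That is why stepping along $-g_\alpha$ decreases both losses to first order. The two gradients below are made up, and min_norm_two is a hypothetical helper.

```python
import numpy as np

def min_norm_two(g1, g2):
    # Closed-form minimizer of ||a * g1 + (1 - a) * g2||^2 over a in [0, 1].
    diff = g1 - g2
    denom = diff @ diff
    if denom < 1e-12:  # (near-)identical gradients: every a gives the same vector
        return 0.5
    return float(np.clip(((g2 - g1) @ g2) / denom, 0.0, 1.0))

# Two made-up per-reward gradients that partially conflict (g1 . g2 < 0).
g1 = np.array([1.0, 0.2])
g2 = np.array([-0.6, 1.0])
a = min_norm_two(g1, g2)
g_alpha = a * g1 + (1 - a) * g2
print(f"alpha = [{a:.3f}, {1 - a:.3f}]   g1.g2 = {g1 @ g2:+.3f}")
print(f"g_alpha.g1 = {g_alpha @ g1:.3f}   g_alpha.g2 = {g_alpha @ g2:.3f}   |g_alpha|^2 = {g_alpha @ g_alpha:.3f}")
```

Here $g_1^\top g_2 = -0.4 < 0$. Even so, both inner products with $g_\alpha$ come out equal to $\|g_\alpha\|^2 \approx 0.392$, which is positive.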
8.1 MGDA: per-reward gradients + Frank-Wolfe combiner¶
min_norm_simplex solves the small QP
$\min_{\alpha \in \Delta_J}\,\|\sum_j \alpha_j g_j\|^2$ by Frank-Wolfe.
Each iteration computes $M\alpha$ with the Gram matrix $M_{ij} = g_i^\top g_j$
and picks the simplex vertex $e_{j^\star}$ with $j^\star = \arg\min_j (M\alpha)_j$,
which is the vertex that minimizes the linearized objective. It then takes a
closed-form line-search step toward that vertex. For the small $J$ we use here,
the loop converges in a handful of iterations.
train_mgda_ram then plugs this into the RAM training loop. Per step:
one rollout shared across all $J$ rewards, one shared $(t, \varepsilon)$
batch and one shared forward pass; we build $J$ per-reward targets and
take $J$ backward passes through the same forward graph. Sharing the
rollout is deliberate — only the reward signal $r_j(X_0)$ differs
between the $J$ gradients, so $\alpha$ sees clean directional
information about the rewards themselves rather than batch sampling
noise. (This is what makes MARBLE's amortized backward possible at
image scale.) We then apply two MARBLE-style refinements: each
per-reward gradient is normalized to unit norm before solving for
$\alpha$ (so $\alpha$ reflects directional conflict between rewards,
not magnitude disparities). The combined direction is then rescaled
by the mean original norm, so the gradient Adam receives keeps its natural
magnitude. Finally, $\alpha$ is EMA-smoothed before it is applied.
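The unit-norm step has a consequence worth knowing. This is our own observation about the recipe, not something stated by the MARBLE paper as far as this notebook shows. For exactly $J = 2$ rewards, normalizing both gradients before the min-norm solve pins $\alpha$ at $[1/2, 1/2]$ whenever the two directions differ. The reason is that the min-norm point of the segment between two unit vectors $u_1, u_2$ is its midpoint: $(u_2 - u_1)^\top u_2 / \|u_1 - u_2\|^2 = (1 - \cos)/(2 - 2\cos) = 1/2$. The EMA'd $\alpha$ can then drift from $1/2$ only through floating-point error. This is consistent with the alpha=[0.50,0.50] postfixes in the §8.3 logs. In the two-reward experiments, MGDA-RAM therefore effectively averages the two unit-normalized gradients. With $J \ge 3$ the weights are no longer forced to be uniform. The snippet checks the $J = 2$ case on random, made-up gradients.

```python
import numpy as np

rng = np.random.default_rng(0)
alphas = []
for _ in range(5):
    # Two hypothetical per-reward gradients with random directions and very different norms.
    g1 = rng.normal(size=10) * rng.uniform(0.1, 10.0)
    g2 = rng.normal(size=10) * rng.uniform(0.1, 10.0)
    # MARBLE-style normalization to unit norm before solving for alpha.
    u1, u2 = g1 / np.linalg.norm(g1), g2 / np.linalg.norm(g2)
    # Closed-form J = 2 min-norm weight on u1; it equals 1/2 for any two distinct unit
    # vectors, so no clipping to [0, 1] is needed here.
    diff = u1 - u2
    a = ((u2 - u1) @ u2) / (diff @ diff)
    alphas.append(a)
    print(f"cos(g1, g2) = {u1 @ u2:+.3f}  ->  alpha = [{a:.3f}, {1 - a:.3f}]")
```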
def min_norm_simplex(grads, n_iters=20):
    """Frank-Wolfe minimizer of ||sum_j alpha_j g_j||^2 over the simplex."""
    G = torch.stack(grads)            # [J, P]
    device, dtype = G.device, G.dtype
    M = G @ G.T                       # [J, J] Gram
    n_tasks = G.shape[0]
    alpha = torch.full((n_tasks,), 1.0 / n_tasks, device=device, dtype=dtype)
    for _ in range(n_iters):
        v = M @ alpha
        j_min = int(v.argmin())
        e = torch.zeros(n_tasks, device=device, dtype=dtype); e[j_min] = 1.0
        d = e - alpha
        num = -(alpha @ (M @ d))
        den = (d @ (M @ d)).clamp(min=1e-12)
        gamma = float((num / den).clamp(0.0, 1.0))
        alpha = alpha + gamma * d
    return alpha

def flat_grad(loss, params, retain_graph):
    g = torch.autograd.grad(loss, params, retain_graph=retain_graph)
    return torch.cat([gi.flatten() for gi in g])

@dataclass
class MGDAConfig:
    outer_steps: int = 500
    group_size: int = 32
    k_targets: int = 4
    sample_steps: int = 25
    lr: float = 1e-4
    beta: float = 5.0
    alpha_ema: float = 0.9           # EMA smoothing on the simplex weights
    use_advantage: bool = False      # if False, signal = beta * r (raw, like §6)
    scale_advantage: bool = False    # if True (with use_advantage), std-scale
    log_every: int = 50

def train_mgda_ram(model, model_ref, reward_fns, cfg: MGDAConfig) -> dict:
    """Multi-reward RAM via MGDA. `reward_fns` is a list of J reward callables."""
    params = list(model.parameters())
    opt = torch.optim.Adam(params, lr=cfg.lr)
    J = len(reward_fns)
    per_reward_hist = [[] for _ in range(J)]
    alpha_hist = []
    alpha_ema = torch.full((J,), 1.0 / J)
    pbar = tqdm(range(cfg.outer_steps), desc="MGDA-RAM", mininterval=2.0)
    for step in pbar:
        # Shared rollout: one set of G endpoints, scored under each reward.
        with torch.no_grad():
            x0 = euler_sample(model, cfg.group_size, n_steps=cfg.sample_steps)
            per_reward = [r(x0) for r in reward_fns]
        for j, r_vals in enumerate(per_reward):
            per_reward_hist[j].append(r_vals.mean().item())
        # Shared (t, eps) batch and one forward pass through v_theta.
        x0_rep = x0.repeat_interleave(cfg.k_targets, dim=0)
        eps = torch.randn_like(x0_rep)
        t = torch.rand(x0_rep.shape[0])
        xt = (1 - t[:, None]) * x0_rep + t[:, None] * eps
        with torch.no_grad():
            v_ref_xt = model_ref(xt, t)
            v_theta_sg = model(xt, t)
            pretrain_target = eps - x0_rep
        v_pred = model(xt, t)  # one forward, graph reused by J backwards
        # J per-reward losses → J per-reward gradients.
        grads = []
        for j, r_vals in enumerate(per_reward):
            if cfg.use_advantage:
                adv = r_vals - r_vals.mean()
                if cfg.scale_advantage:
                    adv = adv / r_vals.std(correction=0).clamp_min(1e-3)
                sig = (cfg.beta * adv).repeat_interleave(cfg.k_targets, dim=0)
            else:
                sig = (cfg.beta * r_vals).repeat_interleave(cfg.k_targets, dim=0)
            target_j = v_ref_xt + sig[:, None] * (pretrain_target - v_theta_sg)
            loss_j = ((v_pred - target_j) ** 2).mean()
            grads.append(flat_grad(loss_j, params, retain_graph=(j < J - 1)))
        # MARBLE recipe: normalize each gradient to unit norm before solving for
        # alpha (so alpha reflects directional conflict, not magnitude disparities),
        # then rescale the combined direction by the mean original norm to keep
        # Adam's natural step size unchanged.
        norms = torch.stack([g.norm() for g in grads]).clamp_min(1e-8)
        grads_unit = [g / n for g, n in zip(grads, norms)]
        # MGDA on the simplex of unit-norm gradients, then EMA smooth.
        alpha_inst = min_norm_simplex(grads_unit).detach()
        alpha_ema = cfg.alpha_ema * alpha_ema + (1 - cfg.alpha_ema) * alpha_inst
        alpha_hist.append(alpha_ema.tolist())
        combined_unit = (alpha_ema[:, None] * torch.stack(grads_unit)).sum(0)
        combined = combined_unit * norms.mean()
        # Assign back as .grad and step.
        opt.zero_grad()
        offset = 0
        for p in params:
            n = p.numel()
            p.grad = combined[offset:offset + n].view_as(p).clone()
            offset += n
        opt.step()
        if step % cfg.log_every == 0:
            pbar.set_postfix(
                **{f"r{j}": f"{per_reward_hist[j][-1]:.3f}" for j in range(J)},
                alpha=f"[{','.join(f'{a:.2f}' for a in alpha_ema.tolist())}]",
            )
    return {"per_reward": per_reward_hist, "alpha": alpha_hist}
8.2 Three reward setups: from conflicting to aligned¶
To probe how multi-reward RAM behaves across reward geometries, we set up three two-reward problems by varying the angular half-spread $\phi$ between two identical Gaussian rewards on the data ring ($R = 4$, scale $s = 2.5$):
| setup | $\phi$ | peak separation | cross-reward $\approx$ |
|---|---|---|---|
| conflicting | $90°$ | $180°$ | $0.006$ |
| partial | $30°$ | $60°$ | $0.28$ |
| aligned | $5°$ | $10°$ | $0.96$ |
Cross-reward $= \exp\bigl(-(2R\sin\phi)^2 / (2s^2)\bigr)$ measures how much a sample sitting at one peak scores on the other reward. The peaks sit at angles $\pm\phi$ on the ring, so they are a chord $2R\sin\phi$ apart. Sweeping the cross-reward from $\sim 0$ to $\sim 1$ takes us across the full spectrum. At one end are peaks that share essentially no support (conflict), and at the other are peaks that nearly coincide (alignment). The middle case has peaks that meaningfully overlap but still favor different regions.
def make_radial_reward(center, scale=2.5):
    c = torch.tensor(center, dtype=torch.float32)
    return lambda x: torch.exp(-0.5 * ((x - c) ** 2).sum(-1) / scale**2)

R = 4.0
SETUPS = [
    ("conflicting", 90),
    ("partial overlap", 30),
    ("aligned", 5),
]

def make_setup(phi_deg):
    # Return (centers, rewards, cross_reward) for the two-reward setup
    # with half-spread phi_deg degrees from mode 0 on the ring.
    phi = math.radians(phi_deg)
    centers = [(R * math.cos(+phi), R * math.sin(+phi)),
               (R * math.cos(-phi), R * math.sin(-phi))]
    rewards = [make_radial_reward(c) for c in centers]
    cross_r = math.exp(-(2 * R * math.sin(phi))**2 / (2 * 2.5**2))
    return centers, rewards, cross_r

# Visualize the three reward geometries side-by-side. Each panel shows
# the SUM of the two rewards so both peaks are visible at once.
fig, axes = plt.subplots(1, 3, figsize=(13, 4.2))
for ax, (setup_name, phi_deg) in zip(axes, SETUPS):
    centers, rewards, cross_r = make_setup(phi_deg)
    r_sum = lambda x, rs=rewards: rs[0](x) + rs[1](x)
    plot_density(r_sum, ax,
                 title=f"{setup_name} ($\\phi=\\pm{phi_deg}°$, cross-r$\\approx${cross_r:.3f})")
    for nm, (cx, cy) in zip(["A", "B"], centers):
        ax.plot(cx, cy, marker="*", markersize=14,
                color="white", markeredgecolor="black")
        ax.text(cx + 0.2, cy + 0.25, nm, color="white", fontsize=11,
                bbox=dict(facecolor="black", alpha=0.5, pad=1))
plt.tight_layout(); plt.show()
8.3 Train all 9 models and compare¶
Three setups $\times$ {single-A, single-B, MGDA} = nine training runs. Same recipe across all of them: $\beta = 5$, 1500 outer steps, group size 32, $K = 4$ reuse, raw $\beta r$ signal throughout (matching §6 — bounded and doesn't collapse over long horizons). The nine trainings finish in $\sim 30$ seconds combined.
BETA = 5.0
STEPS = 1500
reward_names = ["A", "B"]
def fresh_copy_of_ref():
    m = VelocityNet()
    m.load_state_dict(v_ref.state_dict())
    return m

# results[setup_name] = dict(centers, rewards, cross_r, models, scores, samples)
results = {}
for setup_name, phi_deg in SETUPS:
    centers, rewards, cross_r = make_setup(phi_deg)
    models = {}
    for nm, rf in zip(reward_names, rewards):
        torch.manual_seed(31)
        m = fresh_copy_of_ref()
        train_ram(m, v_ref, rf, RAMConfig(
            outer_steps=STEPS, group_size=32, k_targets=4,
            beta=BETA, lr=1e-4, log_every=STEPS,
        ))
        models[f"single-{nm}"] = m
    torch.manual_seed(31)
    mm = fresh_copy_of_ref()
    train_mgda_ram(mm, v_ref, rewards, MGDAConfig(
        outer_steps=STEPS, group_size=32, k_targets=4,
        beta=BETA, lr=1e-4, alpha_ema=0.9, log_every=STEPS,
    ))
    models["multi-MGDA"] = mm
    # Evaluate every model on both rewards using a 3000-sample bag.
    scores = np.zeros((len(models), len(rewards)))
    samples = {}
    for i, (name, m) in enumerate(models.items()):
        m.eval()
        s = euler_sample(m, 3000)
        samples[name] = s
        for j, rf in enumerate(rewards):
            scores[i, j] = rf(s).mean().item()
    results[setup_name] = dict(centers=centers, rewards=rewards, cross_r=cross_r,
                               models=models, scores=scores, samples=samples)
RAM (beta=5.0): 100%|██████████| 1500/1500 [00:03<00:00, 418.89it/s, loss=9.554, mean_r=0.229]
RAM (beta=5.0): 100%|██████████| 1500/1500 [00:03<00:00, 434.09it/s, loss=21.375, mean_r=0.251]
MGDA-RAM: 100%|██████████| 1500/1500 [00:05<00:00, 290.90it/s, alpha=[0.50,0.50], r0=0.229, r1=0.251]
RAM (beta=5.0): 100%|██████████| 1500/1500 [00:03<00:00, 440.38it/s, loss=13.735, mean_r=0.247]
RAM (beta=5.0): 100%|██████████| 1500/1500 [00:03<00:00, 429.96it/s, loss=26.599, mean_r=0.337]
MGDA-RAM: 100%|██████████| 1500/1500 [00:05<00:00, 295.41it/s, alpha=[0.50,0.50], r0=0.247, r1=0.337]
RAM (beta=5.0): 100%|██████████| 1500/1500 [00:03<00:00, 438.25it/s, loss=27.951, mean_r=0.331]
RAM (beta=5.0): 100%|██████████| 1500/1500 [00:03<00:00, 446.31it/s, loss=30.374, mean_r=0.350]
MGDA-RAM: 100%|██████████| 1500/1500 [00:04<00:00, 304.68it/s, alpha=[0.50,0.50], r0=0.331, r1=0.350]
# Single table: 9 rows × {r_A, r_B, min, avg} columns.
print(f"{'setup':18s} {'cross_r':>8s} {'model':12s} "
      f"{'r_A':>6s} {'r_B':>6s} {'min':>6s} {'avg':>6s}")
print("-" * 70)
for setup_name, _ in SETUPS:
    r = results[setup_name]
    for i, name in enumerate(r["models"]):
        row = r["scores"][i]
        # Same field widths as the header so the columns line up.
        print(f"{setup_name:18s} {r['cross_r']:8.3f} {name:12s} "
              f"{row[0]:6.3f} {row[1]:6.3f} {row.min():6.3f} {row.mean():6.3f}")
    print()
setup               cross_r model           r_A    r_B    min    avg
----------------------------------------------------------------------
conflicting           0.006 single-A      0.796  0.034  0.034  0.415
conflicting           0.006 single-B      0.036  0.780  0.036  0.408
conflicting           0.006 multi-MGDA    0.358  0.364  0.358  0.361

partial overlap       0.278 single-A      0.704  0.322  0.322  0.513
partial overlap       0.278 single-B      0.314  0.736  0.314  0.525
partial overlap       0.278 multi-MGDA    0.535  0.550  0.535  0.542

aligned               0.962 single-A      0.775  0.765  0.765  0.770
aligned               0.962 single-B      0.769  0.776  0.769  0.773
aligned               0.962 multi-MGDA    0.785  0.783  0.783  0.784
# Per-reward bar chart: 1 row × 3 setups. Shared y-axis so the three
# panels can be compared at a glance.
model_names = list(results[SETUPS[0][0]]["models"].keys())
x = np.arange(len(model_names))
width = 0.35
fig, axes = plt.subplots(1, 3, figsize=(15, 4.5), sharey=True)
for ax, (setup_name, _) in zip(axes, SETUPS):
    r = results[setup_name]
    sc = r["scores"]
    ax.bar(x - width / 2, sc[:, 0], width, label=r"$r_A$", color="C0")
    ax.bar(x + width / 2, sc[:, 1], width, label=r"$r_B$", color="C1")
    for xi, v in zip(x - width / 2, sc[:, 0]):
        ax.text(xi, v + 0.015, f"{v:.2f}", ha="center", fontsize=8)
    for xi, v in zip(x + width / 2, sc[:, 1]):
        ax.text(xi, v + 0.015, f"{v:.2f}", ha="center", fontsize=8)
    ax.set_xticks(x); ax.set_xticklabels(model_names, rotation=10)
    ax.set_ylim(0, 0.92)
    ax.set_title(f"{setup_name} (cross-r $\\approx$ {r['cross_r']:.3f})")
axes[0].set_ylabel("mean reward over 3000 samples")
axes[0].legend(loc="upper right")
plt.tight_layout(); plt.show()
# 3 × 3 grid of endpoint density heatmaps: rows = setups, cols = models.
fig, axes = plt.subplots(3, 3, figsize=(11.5, 11.5))
for row, (setup_name, _) in enumerate(SETUPS):
    r = results[setup_name]
    for col, (name, samples) in enumerate(r["samples"].items()):
        ax = axes[row, col]
        # The default argument pins this model's samples inside the closure.
        def density(g, s=samples):
            p = kde_density(g, s, h=0.15)
            return p / (p.sum() + 1e-12)
        plot_density(density, ax, title=f"{setup_name}: {name}")
        for (cx, cy), nm in zip(r["centers"], ["A", "B"]):
            ax.plot(cx, cy, marker="*", markersize=12,
                    color="white", markeredgecolor="black")
            ax.text(cx + 0.2, cy + 0.25, nm, color="white", fontsize=9,
                    bbox=dict(facecolor="black", alpha=0.5, pad=1))
plt.tight_layout(); plt.show()
8.4 Takeaway
Two observations stand out from the 9-run sweep.
Multi-reward RAM works in every setup. In all three regimes MGDA returns a model that improves both rewards over the pretrained baseline ($\sim 0.24$): to $\sim 0.36$ on the conflicting peaks, $\sim 0.54$ on partial overlap, and $\sim 0.78$ on aligned peaks. No setup breaks it.
The value of MGDA over single-reward specialists depends entirely on reward geometry.
- Conflicting ($\phi = 90°$, cross-r $\sim 0$): each specialist hits its own reward at $\sim 0.79$ but scores $\sim 0.03$ on the other one, a strict Pareto trade-off enforced by near-disjoint support. MGDA gives up about half of the per-axis peak ($\sim 0.36$) but covers both rewards: its worst-axis reward is $\sim 10\times$ higher than either specialist's. Which model to pick depends on whether you want one strong axis or balanced coverage.
- Partial overlap ($\phi = 30°$, cross-r $\sim 0.28$): the Pareto trade-off is much milder. Each specialist still wins its own axis ($\sim 0.70$ to $0.74$ vs MGDA's $\sim 0.54$), but MGDA's worst axis ($\sim 0.54$) is about $1.7\times$ either specialist's worst axis ($\sim 0.32$), and MGDA's arithmetic mean ($\sim 0.54$) beats both specialists' means ($\sim 0.51$ and $\sim 0.53$). MGDA wins on both balance and average.
- Aligned ($\phi = 5°$, cross-r $\sim 0.96$): the two rewards agree on what a "good sample" looks like. Every model scores within $\sim 0.02$ on every axis. MGDA costs nothing here: one model matches, and in fact slightly exceeds, the specialists on both axes. This is the MARBLE regime, the one image-scale post-training usually lives in.
The 3 $\times$ 3 density heatmaps make the same point visually. In the conflicting row, each specialist parks all of its mass on its own ring mode; MGDA spreads mass to cover both stars. In the partial-overlap row the specialists' distributions already overlap the other reward's basin; MGDA sits between them with broader coverage. In the aligned row all three distributions look essentially identical.
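As a quick arithmetic check on the trade-offs listed above, the snippet below recomputes each setup's worst-axis and average comparisons directly from the nine $(r_A, r_B)$ pairs in the results table. It is plain Python over copied numbers, not part of the training pipeline; SCORES and summarize are names introduced here for illustration only.

```python
# Mean rewards copied from the 9-run results table: (r_A, r_B) per model.
SCORES = {
    "conflicting": {
        "single-A": (0.796, 0.034),
        "single-B": (0.036, 0.780),
        "multi-MGDA": (0.358, 0.364),
    },
    "partial overlap": {
        "single-A": (0.704, 0.322),
        "single-B": (0.314, 0.736),
        "multi-MGDA": (0.535, 0.550),
    },
    "aligned": {
        "single-A": (0.775, 0.765),
        "single-B": (0.769, 0.776),
        "multi-MGDA": (0.785, 0.783),
    },
}


def summarize(setup):
    """Compare MGDA against the better of the two specialists on one setup."""
    rows = SCORES[setup]
    worst = {m: min(pair) for m, pair in rows.items()}    # worst-axis reward
    avg = {m: sum(pair) / 2 for m, pair in rows.items()}  # arithmetic mean
    specialists = ("single-A", "single-B")
    return {
        # MGDA's worst axis as a multiple of the best specialist's worst axis.
        "worst_axis_ratio": worst["multi-MGDA"] / max(worst[m] for m in specialists),
        "mgda_avg": avg["multi-MGDA"],
        "best_specialist_avg": max(avg[m] for m in specialists),
    }


for setup in SCORES:
    s = summarize(setup)
    print(f"{setup:16s} worst-axis ratio {s['worst_axis_ratio']:5.2f}x | "
          f"avg MGDA {s['mgda_avg']:.3f} vs best specialist {s['best_specialist_avg']:.3f}")
```

On these numbers the conflicting setup gives a worst-axis ratio just under 10× while MGDA's average trails the best specialist's; partial overlap gives roughly 1.66× with MGDA ahead on average; aligned sits close to 1×.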
9. Recap
Everything we just built fits in two sentences.
- Pretraining a flow-matching model is the world's simplest regression: noise a clean sample analytically, regress against $(\varepsilon - X_0)$.
- RAM keeps it a regression — same noising rule, same loss shape — only with a reward-weighted correction toward better endpoints, anchored on the pretrained velocity.
No reward gradients. No SDE rollouts. No backward adjoint sweeps. Just a regression — exactly like pretraining.
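The two recap bullets can be put side by side in code. This is a minimal sketch under stated assumptions, not the notebook's train_ram: it assumes the linear path $X_t = (1 - t)\,X_0 + t\,\varepsilon$ (whose velocity is exactly $\varepsilon - X_0$), uses the §4.3 target $v_{\text{ref}} + \beta r\,(\varepsilon - X_0)$ with $v_{\text{ref}}$ read as evaluated at $(X_t, t)$ (an assumption) and the raw reward, and omits group sampling, target reuse ($K$), and the optimizer loop. TinyVelocity, noise, flow_matching_loss, ram_style_loss, and half_plane_reward are hypothetical names; the reward is deliberately non-differentiable to show that only reward values are ever needed.

```python
import torch
import torch.nn as nn


class TinyVelocity(nn.Module):
    """Hypothetical stand-in for VelocityNet: (x, t) in R^2 x [0, 1] -> 2D velocity."""

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(3, hidden), nn.SiLU(), nn.Linear(hidden, 2))

    def forward(self, x, t):
        return self.net(torch.cat([x, t], dim=1))


def noise(x0):
    """Analytic noising X_t = (1 - t) X_0 + t * eps; its velocity is eps - X_0."""
    eps = torch.randn_like(x0)
    t = torch.rand(x0.shape[0], 1)
    return (1 - t) * x0 + t * eps, t, eps


def flow_matching_loss(v_theta, x0):
    """Pretraining: regress v_theta(X_t, t) onto eps - X_0."""
    xt, t, eps = noise(x0)
    return ((v_theta(xt, t) - (eps - x0)) ** 2).sum(dim=1).mean()


def ram_style_loss(v_theta, v_ref, x0, reward_fn, beta):
    """Same noising and loss shape; target v_ref(X_t, t) + beta * r(X_0) * (eps - X_0)."""
    x0 = x0.detach()                  # policy endpoints enter as plain data
    with torch.no_grad():             # no gradient through the reward or the reference
        r = reward_fn(x0)             # one value per sample
        xt, t, eps = noise(x0)
        target = v_ref(xt, t) + beta * r[:, None] * (eps - x0)
    return ((v_theta(xt, t) - target) ** 2).sum(dim=1).mean()


def half_plane_reward(x):
    """Hand-written, non-differentiable reward: 1 if x lies right of the y-axis."""
    return (x[:, 0] > 0).float()


torch.manual_seed(0)
demo_ref = TinyVelocity()
demo_policy = TinyVelocity()
demo_policy.load_state_dict(demo_ref.state_dict())  # start from the "pretrained" weights
x0 = torch.randn(64, 2)                             # stand-in for sampled endpoints X_0
loss = ram_style_loss(demo_policy, demo_ref, x0, half_plane_reward, beta=5.0)
loss.backward()                                     # gradients reach demo_policy only
```

At initialization the policy equals the reference, so samples with $r = 0$ contribute zero loss and only rewarded samples see a target shifted by $\beta r\,(\varepsilon - X_0)$. Swapping ram_style_loss for flow_matching_loss changes only the target, not the noising rule or the loss shape.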
Why RAM is a promising direction for diffusion RL
Stepping back from the 2D toy, RAM is worth attention beyond what it does on the ring. It stands out from previous diffusion-RL recipes in five ways, several of which are visible in the code we just wrote:
- On-policy by construction. Every training batch we sampled in §6–§8 came from a fresh euler_sample(model) call against the current policy. No replay buffer, no off-policy correction, no separate "teacher" pre-training step. The loop is just: roll the policy, score, regress.
- Regression-style loss, not a clipped policy ratio. Our loss was an L2 regression onto the target $v_{\text{ref}} + \beta r\,(\varepsilon - X_0)$ (§4.3). There is no $\pi_\theta / \pi_{\text{old}}$ ratio in the gradient, so the ratio-clipping and KL-penalty machinery that PPO/GRPO-style methods (DDPO, FlowGRPO) inherit from language-model RL is not needed here. The pretrained $v_{\text{ref}}$ acts as the KL anchor on its own.
- Black-box rewards. Nowhere in the notebook did we compute $\nabla r$. The reward function $r$ is just a callable that returns one scalar per sample; it could equally well be a CLIP score, an OCR edit distance, a VLM judge, or a hand-written rule. Compare this with classic adjoint matching, which needs $\nabla r$ and propagates it backward through the sampler.
- Solver-agnostic, terminal-only rollouts. Each outer step needs exactly one thing from the sampler: the clean endpoint $X_0$. No trajectory storage, no per-step log-prob recording, no backward ODE solve through the sampler. Our 2D code used a deterministic Euler ODE solver for speed; in principle any black-box solver works. FlowGRPO-style methods, by contrast, rely on SDE samplers so that per-step log-probabilities are available for the ratio.
- Unified single ↔ multi-reward. Because the loss is a regression, multiple rewards just turn it into multi-task regression. §8 bolted on MGDA in ~30 lines of code, reused the same train_ram rollout, and recovered single-reward as a special case ($J = 1$ gives a trivial simplex). MARBLE makes this multi-task treatment first-class at image scale. By contrast, a two-stage setup like DiffusionOPD (arXiv:2605.15055), which trains one teacher per task and then distills them into a single student, collapses to one stage here.
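To make the $J = 1$ special case concrete, here is a minimal sketch of the textbook min-norm step MGDA takes for one or two objectives: pick the convex combination of per-reward gradients with the smallest norm. This is the standard two-objective closed form, not the notebook's train_mgda_ram (whose config also takes an alpha_ema smoothing parameter that this sketch omits); mgda_weights is a hypothetical helper name.

```python
import torch


def mgda_weights(grads):
    """Min-norm convex weights for J = 1 or J = 2 flattened gradient vectors.

    For J = 2, minimizes ||a * g1 + (1 - a) * g2||^2 over a in [0, 1];
    the closed-form solution is a = clip((g2 - g1) . g2 / ||g1 - g2||^2, 0, 1).
    """
    if len(grads) == 1:
        return torch.ones(1)  # J = 1: the simplex is a single point (plain single-reward)
    if len(grads) != 2:
        raise NotImplementedError("this sketch only covers J <= 2")
    g1, g2 = grads
    diff = g1 - g2
    denom = diff.dot(diff)
    if denom == 0:
        return torch.tensor([0.5, 0.5])  # identical gradients: every split gives the same direction
    a = ((g2 - g1).dot(g2) / denom).clamp(0.0, 1.0)
    return torch.stack([a, 1 - a])


# Orthogonal, equal-norm gradients get an even split.
g_A, g_B = torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0])
w = mgda_weights([g_A, g_B])
direction = w[0] * g_A + w[1] * g_B  # common update direction for both rewards
```

When one gradient is a positive multiple of the other, the weight collapses onto the smaller one: $g_1 = (1, 0)$, $g_2 = (2, 0)$ gives weights $(1, 0)$.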