JAX跑得快的技巧其實很簡單:通過組合變換讓XLA能看到大塊連續的計算,比如説批處理、融合、分片,讓每一步在單設備或多設備同步時都像一個乾淨的kernel。
我們今天就來總結7個能夠提高運行速度的JAX變換組合
1、 jit 優先,形狀穩定
jit
對函數做一次追蹤後XLA負責融合算子,形狀穩定、無副作用時,Python處理的開銷就被分攤掉,可以提高運行速度。
形狀創建和靜態參數要麼挪到step外部,要麼顯式標記為static。
donate_argnums
能讓JAX複用緩衝區,省掉不必要的內存拷貝。step之間保持dtype和shape一致,trace結果才能被緩存下來。
import jax, jax.numpy as jnp
@jax.jit(donate_argnums=(0,))
def sgd_step(params, batch, lr):
x, y = batch
def loss_fn(p):
preds = model_apply(p, x) # pure function
return jnp.mean((preds - y) ** 2)
grads = jax.grad(loss_fn)(params)
return jax.tree_map(lambda p, g: p - lr * g, params, grads)
每個(shape, dtype, static-arg)組合只追蹤一次。頻繁retrace多半是輸入shape在變,或者Python邏輯泄漏進了計算圖。
2、vmap替換Python循環
vmap
在leading axis上做向量化,XLA直接把batch融進kernel。for循環沒了設備launch就少了,內存訪問也更連續。
# per-example loss
def example_loss(params, x, y):
pred = model_apply(params, x)
return jnp.mean((pred - y) ** 2)
# batch it without writing loops
batched_loss = jax.vmap(example_loss, in_axes=(None, 0, 0)) # params broadcasted
嵌套
vmap
可以搞2D batch,比如time × batch,只要別超HBM容量。
vmap
適合做內層微批處理,比如ensemble或MC sampling這類場景,外層維度留給分片。
3、長循環的融合利器Scan
RNN、展開解碼、迭代求解器,這些場景用
scan
比Python循環快。
scan
只編譯一次循環體跑在XLA的while-loop裏,Python開銷基本為0,融合和內存複用也更激進。
from jax import lax
def rnn_cell(carry, x):
h = carry
h = jnp.tanh(W_hh @ h + W_xh @ x + b)
y = W_hy @ h
return h, y # (carry, output)
def rnn_forward(h0, xs):
hT, ys = lax.scan(rnn_cell, h0, xs) # xs: [T, B, D]
return hT, ys
循環狀態用
carry
傳遞,body保持小而純淨,要注意保持形狀不要變,比如:序列模型、diffusion step循環、定點迭代、beam解碼(形狀穩定時)都適用。
4、remat可以用計算換內存
批次大了TPU/GPU的FLOP利用率往往更高。
remat
(也叫checkpoint)會丟掉部分中間激活,反向時重算這樣峯值顯存下來batch就能開的更大。
from jax import remat
def block(params, x):
x = jax.nn.gelu(x @ params['w1'])
x = x @ params['w2']
return x
fast_block = remat(block) # checkpointed
@jax.jit
def forward(params, x):
for _ in range(6):
x = x + fast_block(params, x)
return x
只包最重的子塊就行,比如attention加MLP那幾層。同時配合
vmap
或分片,全局batch能再往上拉。不過需要一些額外FLOPs,但如果換來1.3到2倍的batch increase,wall-clock往往更短。
5、pmap單機多卡數據並行
pmap
把函數複製到單主機的多個設備上(8卡工作站、單節點8核TPU),梯度可以自動all-reduce,並且每設備只編譯一次。
from jax import pmap, lax
@pmap(axis_name='d')
def train_step(params, batch, lr):
x, y = batch # each device sees [local_B, ...]
def loss_fn(p):
pred = model_apply(p, x)
loss = jnp.mean((pred - y) ** 2)
return loss
loss, grads = jax.value_and_grad(loss_fn)(params)
loss = lax.pmean(loss, axis_name='d')
grads = lax.pmean(grads, axis_name='d')
params = jax.tree_map(lambda p, g: p - lr * g, params, grads)
return params, loss
batch在leading axis分片,
lax.pmean
聚合loss和grads。單機場景下
pmap
簡單可靠。跨主機擴展或者想做張量級細粒度分片可以成換
pjit
。
6、pjit+ 命名分片:SPMD並行
pjit
編譯出單一SPMD程序可以跨設備跨主機運行。用mesh和
PartitionSpec
描述數組怎麼切,JAX處理collective通信,這樣數據並行、張量並行、混合並行都能做。
import jax
from jax.sharding import Mesh, PartitionSpec as P
import numpy as np
devices = np.array(jax.devices()).reshape(2, 4) # 2 × 4 mesh (dp × mp)
mesh = Mesh(devices, ('dp', 'mp'))
@jax.jit # jit is optional when using pjit; shown when composing
def model_apply_sharded(params, x):
return model_apply(params, x)
from jax.experimental.pjit import pjit
with mesh:
in_shard = (P('mp',), P('dp',)) # example; tailor to your shapes
out_shard = P('dp',) # e.g., shard batch across dp
step = pjit(model_apply_sharded,
in_shardings=(P('mp',), P('dp',)),
out_shardings=out_shard)
y = step(params_sharded, x_sharded)
一般都是batch軸走
dp
,大矩陣維度(hidden size、heads)走
mp
。分片數需要跟設備拓撲對齊,跨主機流量才少。
7、value_and_grad的正確堆疊方式
規範寫法是
jit(value_and_grad(loss, has_aux=True))
,外面可以再套一層
pmap
或
pjit
。這樣forward只跑一遍metrics留在aux裏帶出來。
def loss_with_aux(params, batch):
x, y = batch
pred = model_apply(params, x)
loss = jnp.mean((pred - y) ** 2)
aux = {'mse': loss, 'mean_pred': jnp.mean(pred)}
return loss, aux
@jax.jit
def train_step(params, opt_state, batch, lr):
(loss, aux), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(params, batch)
updates, opt_state = optimizer_update(grads, opt_state, params, lr)
params = optax_apply(updates, params)
return params, opt_state, loss, aux
value_and_grad
放
jit
裏面,JAX會把forward和backward一起stage。返回
(loss, aux)
日誌指標不用再跑一遍forward。
這套組合很靈活:
vmap
做微批次,
scan
跑時序循環,外面套
pmap
或
pjit
,
donate_argnums
標上buffer。
總結
變長序列pad加mask,shape穩定是前提條件。traced代碼裏不要添加Python隨機性,比如PRNG key要在外面split好。矩陣乘用
bfloat16
,這樣數值穩定性也夠用,吞吐量在TPU/GPU上表現的也很好。性能profile要重點看warm-up之後的tokens/sec或samples/sec。日誌只看標量aux metrics就行,每step把大數組傳回host是性能殺手。
JAX的性能不是黑盒:
jit
- shape可以穩定打底,
vmap
做batch,
scan
融合循環,
remat
回收顯存,
pmap
或
pjit
做擴展,
value_and_grad(..., has_aux=True)
讓每一步只跑一次forward一次backward。
https://avoid.overfit.cn/post/84e4e28e3ca8473488a0e9248d1ec51b
作者:Nexumo