博客 / 詳情

返回

JAX性能優化實戰:7個變換讓TPU/GPU吃滿算力

JAX跑得快的技巧其實很簡單:通過組合變換讓XLA能看到大塊連續的計算,比如説批處理、融合、分片,讓每一步在單設備或多設備同步時都像一個乾淨的kernel。

我們今天就來總結7個能夠提高運行速度的JAX變換組合

1、 jit 優先,形狀穩定

jit

對函數做一次追蹤後XLA負責融合算子,形狀穩定、無副作用時,Python處理的開銷就被分攤掉,可以提高運行速度。

形狀創建和靜態參數要麼挪到step外部,要麼顯式標記為static。

donate_argnums

能讓JAX複用緩衝區,省掉不必要的內存拷貝。step之間保持dtype和shape一致,trace結果才能被緩存下來。

 import jax, jax.numpy as jnp  

@jax.jit(donate_argnums=(0,))  
def sgd_step(params, batch, lr):  
    x, y = batch  
    def loss_fn(p):  
        preds = model_apply(p, x)  # pure function  
        return jnp.mean((preds - y) ** 2)  
    grads = jax.grad(loss_fn)(params)  
     return jax.tree_map(lambda p, g: p - lr * g, params, grads)

每個(shape, dtype, static-arg)組合只追蹤一次。頻繁retrace多半是輸入shape在變,或者Python邏輯泄漏進了計算圖。

2、vmap替換Python循環

vmap

在leading axis上做向量化,XLA直接把batch融進kernel。for循環沒了設備launch就少了,內存訪問也更連續。

 # per-example loss  
 def example_loss(params, x, y):  
     pred = model_apply(params, x)  
     return jnp.mean((pred - y) ** 2)  
   
 # batch it without writing loops  
 batched_loss = jax.vmap(example_loss, in_axes=(None, 0, 0))  # params broadcasted

嵌套

vmap

可以搞2D batch,比如time × batch,只要別超HBM容量。

vmap

適合做內層微批處理,比如ensemble或MC sampling這類場景,外層維度留給分片。

3、長循環的融合利器Scan

RNN、展開解碼、迭代求解器,這些場景用

scan

比Python循環快。

scan

只編譯一次循環體跑在XLA的while-loop裏,Python開銷基本為0,融合和內存複用也更激進。

 from jax import lax  

def rnn_cell(carry, x):  
    h = carry  
    h = jnp.tanh(W_hh @ h + W_xh @ x + b)  
    y = W_hy @ h  
    return h, y  # (carry, output)  

def rnn_forward(h0, xs):  
    hT, ys = lax.scan(rnn_cell, h0, xs)  # xs: [T, B, D]  
     return hT, ys

循環狀態用

carry

傳遞,body保持小而純淨,要注意保持形狀不要變,比如:序列模型、diffusion step循環、定點迭代、beam解碼(形狀穩定時)都適用。

4、remat可以用計算換內存

批次大了TPU/GPU的FLOP利用率往往更高。

remat

(也叫checkpoint)會丟掉部分中間激活,反向時重算這樣峯值顯存下來batch就能開的更大。

 from jax import remat  

def block(params, x):  
    x = jax.nn.gelu(x @ params['w1'])  
    x = x @ params['w2']  
    return x  

fast_block = remat(block)  # checkpointed  

@jax.jit  
def forward(params, x):  
    for _ in range(6):  
        x = x + fast_block(params, x)  
     return x

只包最重的子塊就行,比如attention加MLP那幾層。同時配合

vmap

或分片,全局batch能再往上拉。不過需要一些額外FLOPs,但如果換來1.3到2倍的batch increase,wall-clock往往更短。

5、pmap單機多卡數據並行

pmap

把函數複製到單主機的多個設備上(8卡工作站、單節點8核TPU),梯度可以自動all-reduce,並且每設備只編譯一次。

 from jax import pmap, lax  

@pmap(axis_name='d')  
def train_step(params, batch, lr):  
    x, y = batch  # each device sees [local_B, ...]  
    def loss_fn(p):  
        pred = model_apply(p, x)  
        loss = jnp.mean((pred - y) ** 2)  
        return loss  
    loss, grads = jax.value_and_grad(loss_fn)(params)  
    loss = lax.pmean(loss, axis_name='d')  
    grads = lax.pmean(grads, axis_name='d')  
    params = jax.tree_map(lambda p, g: p - lr * g, params, grads)  
     return params, loss

batch在leading axis分片,

lax.pmean

聚合loss和grads。單機場景下

pmap

簡單可靠。跨主機擴展或者想做張量級細粒度分片可以成換

pjit

6、pjit+ 命名分片:SPMD並行

pjit

編譯出單一SPMD程序可以跨設備跨主機運行。用mesh和

PartitionSpec

描述數組怎麼切,JAX處理collective通信,這樣數據並行、張量並行、混合並行都能做。

 import jax  
from jax.sharding import Mesh, PartitionSpec as P  
import numpy as np  

devices = np.array(jax.devices()).reshape(2, 4)  # 2 × 4 mesh (dp × mp)  
mesh = Mesh(devices, ('dp', 'mp'))  

@jax.jit  # jit is optional when using pjit; shown when composing  
def model_apply_sharded(params, x):  
    return model_apply(params, x)  

from jax.experimental.pjit import pjit  

with mesh:  
    in_shard  = (P('mp',), P('dp',))  # example; tailor to your shapes  
    out_shard = P('dp',)              # e.g., shard batch across dp  
    step = pjit(model_apply_sharded,  
                in_shardings=(P('mp',), P('dp',)),  
                out_shardings=out_shard)  
     y = step(params_sharded, x_sharded)

一般都是batch軸走

dp

,大矩陣維度(hidden size、heads)走

mp

。分片數需要跟設備拓撲對齊,跨主機流量才少。

7、value_and_grad的正確堆疊方式

規範寫法是

jit(value_and_grad(loss, has_aux=True))

,外面可以再套一層

pmap

pjit

。這樣forward只跑一遍metrics留在aux裏帶出來。

 def loss_with_aux(params, batch):  
    x, y = batch  
    pred = model_apply(params, x)  
    loss = jnp.mean((pred - y) ** 2)  
    aux  = {'mse': loss, 'mean_pred': jnp.mean(pred)}  
    return loss, aux  

@jax.jit  
def train_step(params, opt_state, batch, lr):  
    (loss, aux), grads = jax.value_and_grad(loss_with_aux, has_aux=True)(params, batch)  
    updates, opt_state = optimizer_update(grads, opt_state, params, lr)  
    params = optax_apply(updates, params)  
     return params, opt_state, loss, aux
value_and_grad

jit

裏面,JAX會把forward和backward一起stage。返回

(loss, aux)

日誌指標不用再跑一遍forward。

這套組合很靈活:

vmap

做微批次,

scan

跑時序循環,外面套

pmap

pjit

donate_argnums

標上buffer。

總結

變長序列pad加mask,shape穩定是前提條件。traced代碼裏不要添加Python隨機性,比如PRNG key要在外面split好。矩陣乘用

bfloat16

,這樣數值穩定性也夠用,吞吐量在TPU/GPU上表現的也很好。性能profile要重點看warm-up之後的tokens/sec或samples/sec。日誌只看標量aux metrics就行,每step把大數組傳回host是性能殺手。

JAX的性能不是黑盒:

jit
  • shape可以穩定打底,
vmap

做batch,

scan

融合循環,

remat

回收顯存,

pmap

pjit

做擴展,

value_and_grad(..., has_aux=True)

讓每一步只跑一次forward一次backward。

https://avoid.overfit.cn/post/84e4e28e3ca8473488a0e9248d1ec51b

作者:Nexumo

user avatar
0 位用戶收藏了這個故事!

發佈 評論

Some HTML is okay.