博客 / 詳情

返回

【Triton 教程】層標準化

Triton 是一種用於並行編程的語言和編譯器。它旨在提供一個基於 Python 的編程環境,以高效編寫自定義 DNN 計算內核,並能夠在現代 GPU 硬件上以最大吞吐量運行。

更多 Triton 中文文檔可訪問 →https://triton.hyper.ai/

在本教程中,你將編寫一個比 PyTorch 實現運行更快的高性能層標準化 (layer normalization) 內核。

在此過程中,你將瞭解:

  • 在 Triton 中實現反向傳播 (backward pass)。
  • 在 Triton 中實現並行歸約 (parallel reduction)。

動機​

層標準化 (LayerNorm) 算子最先在 BA2016 中提出,旨在提高序列模型(例如 Transformers)或小 batchsize 神經網絡的性能。它以向量 x 作為輸入,並生成與輸入 shape 相同的向量 y 作為輸出。 標準化是通過減去均值併除以 x 的標準差來實現的。 標準化後,會應用帶有權重 w 和偏置 b 的可學習線性變換。

首先讓我們看看前向傳播的實現。

import torch
import triton
import triton.language as tl


try:
    # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it
    # should not be added to extras_require in setup.py.
    # 這是 https://github.com/NVIDIA/apex,不是 PyPi 的 apex,
    # 所以不應該加進 setup.py 的額外依賴中
    import apex
    HAS_APEX = True
except ModuleNotFoundError:
    HAS_APEX = False




@triton.jit
def _layer_norm_fwd_fused(
    X,  # pointer to the input 輸入指針
    Y,  # pointer to the output 輸出指針
    W,  # pointer to the weights 權重指針
    B,  # pointer to the biases 偏差指針
    Mean,  # pointer to the mean 均值指針
    Rstd,  # pointer to the 1/std 1/std 指針
    stride,  # how much to increase the pointer when moving by 1 row 指針移動一行應該增加多少
    N,  # number of columns in X X 的列數
    eps,  # epsilon to avoid division by zero 用於避免除以 0 的 epsilon
    BLOCK_SIZE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    # 映射程序 id 到對應計算的 X 和 Y 的行
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride
    # Compute mean
    # 計算均值
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # Compute variance
    # 計算方差
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        x = tl.where(cols < N, x - mean, 0.)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Write mean / rstd
    # 寫入 mean / rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    # 歸一化並應用線性變換
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)

反向傳播

層標準化算子的反向傳播比前向傳播要複雜一些。

由於在同一批次中的所有行使用相同的權重 w 和偏差 b,它們的梯度需要累加。為了高效地執行此步驟,我們使用並行歸約策略:每個內核實例將某些行的部分 ∇w 和 ∇b 累積到 GROUP_SIZE_M 個獨立緩衝區之一中。這些緩衝區保存在 L2 緩存中,然後通過另一個函數進一步歸約以計算實際的∇w 和 ∇b。

設輸入行數 M=4 和 GROUP_SIZE_M=2,以下是 ∇w 的並行歸約策略圖示(為簡潔起見,省略 ∇b):

在這裏插入圖片描述

在第一階段,同色的 X 行共享同一個緩衝區,因此使用 lock 以確保一次只有一個內核實例寫入緩衝區。在第二階段,這些緩衝區會進一步歸約以計算最終的 ∇w 和 ∇b。在以下實現中,第一階段由函數 _layer_norm_bwd_dx_fused 實現,第二階段由函數 _layer_norm_bwd_dwdb 實現。

@triton.jit
def _layer_norm_bwd_dx_fused(DX,  # pointer to the input gradient 輸入梯度指針
                             DY,  # pointer to the output gradient 輸出梯度指針
                             DW,  # pointer to the partial sum of weights gradient 權重和梯度指針
                             DB,  # pointer to the partial sum of biases gradient 偏差梯度部分和指針
                             X,  # pointer to the input 輸入指針
                             W,  # pointer to the weights 權重指針
                             Mean,  # pointer to the mean 均值指針
                             Rstd,  # pointer to the 1/std 1/std 指針
                             Lock,  # pointer to the lock 鎖指針
                             stride,  # how much to increase the pointer when moving by 1 row 指針移動一行應該增加多少
                             N,  # number of columns in X X 的列數
                             GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    # Map the program id to the elements of X, DX, and DY it should compute.
    # 映射程序 id 到對應計算的 X, DX, DY
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N
    X += row * stride
    DY += row * stride
    DX += row * stride
    # Offset locks and weights/biases gradient pointer for parallel reduction
    # 偏移鎖和權重/偏差梯度指針以並行歸約
    lock_id = row % GROUP_SIZE_M
    Lock += lock_id
    Count = Lock + GROUP_SIZE_M
    DW = DW + lock_id * N + cols
    DB = DB + lock_id * N + cols
    # Load data to SRAM
    # 讀取數據到 SRAM
    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
    dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    mean = tl.load(Mean + row)
    rstd = tl.load(Rstd + row)
    # Compute dx
    # 計算 ds
    xhat = (x - mean) * rstd
    wdy = w * dy
    xhat = tl.where(mask, xhat, 0.)
    wdy = tl.where(mask, wdy, 0.)
    c1 = tl.sum(xhat * wdy, axis=0) / N
    c2 = tl.sum(wdy, axis=0) / N
    dx = (wdy - (xhat * c1 + c2)) * rstd
    # Write dx
    # 寫入 dx
    tl.store(DX + cols, dx, mask=mask)
    # Accumulate partial sums for dw/db
    # 累加 dw 和 db 的部分和
    partial_dw = (dy * xhat).to(w.dtype)
    partial_db = (dy).to(w.dtype)
    while tl.atomic_cas(Lock, 0, 1) == 1:
        pass
    count = tl.load(Count)
    # First store doesn't accumulate
    # 第一個儲存不累加
    if count == 0:
        tl.atomic_xchg(Count, 1)
    else:
        partial_dw += tl.load(DW, mask=mask)
        partial_db += tl.load(DB, mask=mask)
    tl.store(DW, partial_dw, mask=mask)
    tl.store(DB, partial_db, mask=mask)
    # Release the lock
    # 釋放鎖
    tl.atomic_xchg(Lock, 0)




@triton.jit
def _layer_norm_bwd_dwdb(DW,  # pointer to the partial sum of weights gradient 權重部分和指針
                         DB,  # pointer to the partial sum of biases gradient 偏差梯度部分和指針
                         FINAL_DW,  # pointer to the weights gradient 權重梯度指針
                         FINAL_DB,  # pointer to the biases gradient 偏差梯度指針
                         M,  # GROUP_SIZE_M
                         N,  # number of columns 列數
                         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    # Map the program id to the elements of DW and DB it should compute.
    # 映射程序 id 到對應計算的 DW 和 DB
    pid = tl.program_id(0)
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Iterate through the rows of DW and DB to sum the partial sums.
    #迭代通過 DW 和 DB 的行,對部分和進行求和。
    for i in range(0, M, BLOCK_SIZE_M):
        rows = i + tl.arange(0, BLOCK_SIZE_M)
        mask = (rows[:, None] < M) & (cols[None, :] < N)
        offs = rows[:, None] * N + cols[None, :]
        dw += tl.load(DW + offs, mask=mask, other=0.)
        db += tl.load(DB + offs, mask=mask, other=0.)
    # Write the final sum to the output.
    # 將最終結果寫入輸出
    sum_dw = tl.sum(dw, axis=0)
    sum_db = tl.sum(db, axis=0)
    tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
    tl.store(FINAL_DB + cols, sum_db, mask=cols < N)

基準測試

現在我們可以比較 Triton 內核與 PyTorch 的性能了。這裏以每個特徵少於 64KB 的輸入為例進行講解。具體來説,可以設置 mode: 'backward' 來進行後向傳播的基準測試。

class LayerNorm(torch.autograd.Function):


    @staticmethod
    def forward(ctx, x, normalized_shape, weight, bias, eps):
        # allocate output
        # 分配輸出
        y = torch.empty_like(x)
        # reshape input data into 2D tensor
        # 將輸入數據的形狀改為 2D 張量
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
        rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
        # Less than 64KB per feature: enqueue fused kernel
        # 少於 64KB 每個特徵:入隊融合內核
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # heuristics for number of warps
        # 對 warp 數量的啓發算法
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # enqueue kernel
        # 入隊內核
        _layer_norm_fwd_fused[(M, )](  #
            x_arg, y, weight, bias, mean, rstd,  #
            x_arg.stride(0), N, eps,  #
            BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
        ctx.save_for_backward(x, weight, bias, mean, rstd)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.eps = eps
        return y


    @staticmethod
    def backward(ctx, dy):
        x, w, b, m, v = ctx.saved_tensors
        # heuristics for amount of parallel reduction stream for DW/DB
        # 計算對 DW/DB 並行規約流數量的啓發算法
        N = w.shape[0]
        GROUP_SIZE_M = 64
        if N <= 8192: GROUP_SIZE_M = 96
        if N <= 4096: GROUP_SIZE_M = 128
        if N <= 1024: GROUP_SIZE_M = 256
        # allocate output
        # 分配輸出
        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device)
        _dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)
        _db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)
        dw = torch.empty((N, ), dtype=w.dtype, device=w.device)
        db = torch.empty((N, ), dtype=w.dtype, device=w.device)
        dx = torch.empty_like(dy)
        # enqueue kernel using forward pass heuristics
        # 使用前向傳播啓發算法入隊內核
        # also compute partial sums for DW and DB
        # 同樣用於計算 DW 和 DB 的部分和
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        _layer_norm_bwd_dx_fused[(M, )](  #
            dx, dy, _dw, _db, x, w, m, v, locks,  #
            x_arg.stride(0), N,  #
            BLOCK_SIZE_N=ctx.BLOCK_SIZE,  #
            GROUP_SIZE_M=GROUP_SIZE_M,  #
            num_warps=ctx.num_warps)
        grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
        # accumulate partial sums in separate kernel
        # 在單獨的內核中累加部分和
        _layer_norm_bwd_dwdb[grid](
            _dw, _db, dw, db, min(GROUP_SIZE_M, M), N,  #
            BLOCK_SIZE_M=32,  #
            BLOCK_SIZE_N=128, num_ctas=1)
        return dx, None, dw, db, None




layer_norm = LayerNorm.apply




def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'):
    # create data
    # 創建數據
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
    dy = .1 * torch.randn_like(x)
    x.requires_grad_(True)
    # forward pass
    # 前向傳播
    y_tri = layer_norm(x, w_shape, weight, bias, eps)
    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
    # backward pass (triton)
    # 反向傳播 (triton)
    y_tri.backward(dy, retain_graph=True)
    dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
    x.grad, weight.grad, bias.grad = None, None, None
    # backward pass (torch)
    # 反向傳播 (torch)
    y_ref.backward(dy, retain_graph=True)
    dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
    # 比較
    assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
    assert torch.allclose(dx_tri, dx_ref, atol=1e-2, rtol=0)
    assert torch.allclose(db_tri, db_ref, atol=1e-2, rtol=0)
    assert torch.allclose(dw_tri, dw_ref, atol=1e-2, rtol=0)




@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[512 * i for i in range(2, 32)],
        line_arg='provider',
        line_vals=['triton', 'torch'] + (['apex'] if HAS_APEX else []),
        line_names=['Triton', 'Torch'] + (['Apex'] if HAS_APEX else []),
        styles=[('blue', '-'), ('green', '-'), ('orange', '-')],
        ylabel='GB/s',
        plot_name='layer-norm-backward',
        args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'},
    ))
def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'):
    # create data
    # 創建數據
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
    dy = .1 * torch.randn_like(x)
    x.requires_grad_(True)
    quantiles = [0.5, 0.2, 0.8]


    def y_fwd():


        if provider == "triton":
            return layer_norm(x, w_shape, weight, bias, eps)  # noqa: F811, E704


        if provider == "torch":
            return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps)  # noqa: F811, E704


        if provider == "apex":
            apex_layer_norm = (apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype))
            return apex_layer_norm(x)  # noqa: F811, E704


    # forward pass
    # 前向傳播
    if mode == 'forward':
        gbps = lambda ms: 2 * x.numel() * x.element_size() / ms * 1e-6
        ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500)
    # backward pass
    # 反向傳播
    if mode == 'backward':
        y = y_fwd()
        gbps = lambda ms: 3 * x.numel() * x.element_size() / ms * 1e-6  # noqa: F811, E704
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles,
                                                     grad_to_none=[x], rep=500)
    return gbps(ms), gbps(max_ms), gbps(min_ms)




test_layer_norm(1151, 8192, torch.float16)
bench_layer_norm.run(save_path='.', print_data=True)

在這裏插入圖片描述

Out:

layer-norm-backward:

在這裏插入圖片描述
在這裏插入圖片描述

參考文獻

[BA2016] Jimmy Lei Ba and Jamie Ryan Kiros and Geoffrey E. Hinton, “Layer Normalization”, Arxiv 2016

Download Jupyter notebook: 05-layer-norm.ipynb

Download Python source code: 05-layer-norm.py

Download zipped: 05-layer-norm.zip

user avatar qutianhang 頭像 fedl 頭像 kevinwan 頭像 chongdongdeludeng 頭像 xialiwei 頭像 danxiaodechahu_dmjjjs 頭像
6 位用戶收藏了這個故事!

發佈 評論

Some HTML is okay.