博客 / 詳情

返回

自動微分-初步教學

省流總結

數學分析部分

鏈式法則的基本規則:

$$ y = f(u, v), 則 \; \frac{\partial y}{\partial x} = \frac{\partial f}{\partial u} \cdot \frac{\partial u}{\partial x} + \frac{\partial f}{\partial v} \cdot \frac{\partial v}{\partial x} $$

用鏈式法則我們可以有這個公式。我們定義P性質\( P(n,g) \),當n=root,g=1的時候,就是我們想要的自動微分性質
$$\begin{aligned} P(n,\,g)\;\stackrel{\text{def}}{=}\;\text{調用 }n.\text{backward}(g)\text{ 能夠保證} \\ \quad\text{對所有滿足 }req(m)\text{ 且 }m\text{ 位於 }n\text{ 的子圖中的節點,} \\ \text{最終} \quad grad(m)\;\text{ 增加 }\;g\; \cdot \;\frac{\partial val(n)}{\partial val(m)}\text{}\end{aligned}$$
當我們確保代碼中每一個算子(add/mul等等)滿足下面的公式時,則P性質對於全圖成立(可由數學歸納法證明)
$$ \Bigl(\,\bigwedge_{i=1}^{n} P\!\bigl(x_i,\; g' \cdot \frac{\partial\,val(f)}{\partial\,val(x_i)}\bigr) \Bigr), \; \Delta grad(x_i)+=g’ \;\Longrightarrow\; P\!\bigl(f(x_1,\dots,x_n),\; g'\bigr) $$
也就是説,我們在反向傳播函數裏,向子節點傳播 \( \displaystyle g’ \cdot \frac{\partial\,val(f)}{\partial\,val(x_i)} \),便可保證梯度回傳的正確性。而對於backward到的當前節點,其將會將所傳入的梯度累加進已有梯度中(初始為0)

代碼實現部分

反向傳播函數:

    def backward(self, incoming_grad: Number = 1):
        """
        incoming_grad: 父節點傳下來的梯度 (∂L/∂self)
        """
        if not self.require_grad:
            return

        # 累加梯度
        self.grad = (self.grad or 0) + incoming_grad

        # 葉子節點結束
        if self.operator is None or not self.children:
            return

        # Operator 只給局部梯度,其餘遞歸由 Node 處理
        local_grads = self.operator.get_grad(self.children)  # list 與 children 對齊
        for child, local_g in zip(self.children, local_grads):
            child.backward(incoming_grad * local_g)

總結啓發部分

對於梯度不合法的處理:

  1. Autograd 層 應提供:
    • 定義域檢查 → nan/inf 報警或安全替代
    • 穩定公式、裁剪、loss scaling
    • 確定性次梯度與自定義梯度接口
  2. 使用者 應通過正規化、ε-修正、梯度裁剪、合適激活/優化器、監控與混合精度等手段,讓“數學上合法”的梯度在 數值實現 裏依舊穩定可靠。

更新日誌

2025-12-11 完成初稿

前置知識

  1. 微積分:鏈式求導法則
  2. 數據結構與算法:樹結構與搜索

    自動微分器教學

    數學知識

    自動微分器的核心在於鏈式法則,快速回顧:
    文字定義:如果一個函數是另一個函數的複合,那麼其導數可以通過內外函數的導數相乘來計算。

$$ \text{如果 } y = f(g(x)) \text{,則 } \frac{dy}{dx} = f'(g(x)) \cdot g'(x) $$

換一種更加常用的表述就是:

$$ y = f(u, v), 則 \; \frac{\partial y}{\partial x} = \frac{\partial f}{\partial u} \cdot \frac{\partial u}{\partial x} + \frac{\partial f}{\partial v} \cdot \frac{\partial v}{\partial x} $$

樹形解析式

基本數學公式可以被拆解為一棵樹(而且往往可以用二叉樹來表示),我們這裏使用多叉樹表示,其表示形式為

class Node:
    def __init__(self, value, children: list["Node"] = None, operator: AbstractOperator = None):
        self.value = value
        self.children = children
        self.operator = operator

在各種編程語言(包括python)解析算術式時,其也會將表達式拆解為一顆類似的樹。該樹的結構特點在於

  1. 葉子節點為“常量值”,沒有操作符。例如 x = 2 * 3,2,3在解析樹中就會在葉子節點
  2. 非葉子節點為“中間值”,包含操作符,其value由左右兩個子節點bottom-up得到
    因此,我們可以使用深度優先算法遞歸的求解每個節點的node值,並最終求解得到根節點的value。其具體過程略,在此放出代碼
# 策略模式
class Node:
    def __init__(self, value, children: list["Node"] = None, operator: AbstractOperator = None):
        self.value = value
        self.children = children
        self.operator = operator
        
    def calculate():
        self.value = operator.calculate(children)
        return self.value
        
# 例子:對於Add Operator
class AddOperator(AbstractOperator):  
    def calculate(self, children: list["Node"]):
        x0, x1 = children
        return x0.calculate() + x1.calculate()
  

用遞歸實現鏈式求導

在樹上實現鏈式求導遞歸的關鍵在於將鏈式法使用遞歸語言進行描述,我們考慮:

$$ y = f(u, v), 則 \; \frac{\partial y}{\partial x} = \frac{\partial f}{\partial u} \cdot \frac{\partial u}{\partial x} + \frac{\partial f}{\partial v} \cdot \frac{\partial v}{\partial x} $$

在進行鏈式求導的過程中,如上式所示,我們可以先將y(x)視為y(u,v),以u為例,先針對f求u的偏導,即求解 \( \displaystyle g_u=\frac{\partial f}{\partial u} \),求解u對x的偏導,並將g值傳遞給下一個求導項,即 \( \displaystyle g_{ux} = g_u * \frac{\partial u}{\partial x} \),同理我們得到 \( \displaystyle g_{vx} \) ,並將二者進行加和,即得到我們所需要求解的值 \( \displaystyle \frac{\partial y}{\partial x} = g_{ux} + g_{vx} \)。

graph TD
  y["y = f(u,v)"]
  u["u(x)"]
  v["v(x)"]
  x["x"]

  y --> u
  y --> v
  u --> x
  v --> x

上面的語言便描述了完整的backward遞歸過程,我們使用一個簡單的例子與節點圖以更清晰的描述這期間發生的事情。在dfs遍歷節點樹時:

  1. 首先,對於節點\( y = f(u,v) \),我們求解本節點(u,v)相較於父節點 y 的偏導數值。即\( \displaystyle \frac{\partial y}{\partial u},\frac{\partial y}{\partial v} \)
  2. 隨後將偏導數值傳遞給子節點。對於節點$u(x)$,我們計算\( \displaystyle \frac{\partial u}{\partial x} \),並累乘上前一步傳遞下來的梯度值 \( \displaystyle \frac{\partial y}{\partial u} \) ,得到梯度 \( \displaystyle \frac{\partial y}{\partial u} \cdot \frac{\partial u}{\partial x} \) ,並將梯度值傳遞給子節點。$v(x)$同理。
  3. 對於子節點$x$,其同時被傳遞了兩個不同方向上的梯度,即 \( g_{ux},g_{vx} \) 。其所需要做的事情就是將這兩個方向上的梯度進行加和,得到最終的根節點對自身的偏導 \( \displaystyle \frac{\partial y}{\partial x} = g_{ux} + g_{vx} \)

翻譯成遞歸代碼語言,我們定義可遞歸函數backward(bw_grad),其行為為

  1. 終止條件:當前節點為根節點(代碼實現是額外提供給用户一個可選項:require_grad == false,選擇性關閉部分節點的梯度計算(如常量節點的梯度就是沒有意義的))
  2. 遞歸循環:

    1. 將bw_grad累加到sef.grad中(self.grad初始化為0)
    2. 依據該節點算子 \( parent=f(left,right) \) ,調用 \( \displaystyle left.backward(bw\_grad*\frac{\partial parent}{\partial left}) \)

上述流程寫成代碼就是這個樣子:

class Node:
    def __init__(self, value, children: list["Node"] = None, operator: AbstractOperator = None, require_grad=True):
        self.value = value
        self.children = children
        self.operator = operator
        self.grad = None
        self.require_grad = require_grad

    def backward(self, backward_grad=None):
        if not self.require_grad:
            return
        # 1. 計算當前這次傳入的梯度(初始化)
        if backward_grad is None:  # 根結點梯度為1
            backward_grad = 1

        # 2. 累加到自己的梯度(STEP3)
        if self.grad is None:
            self.grad = 0
        self.grad += backward_grad

        # 3. 繼續傳播 *本次* backward_grad
        if self.operator is not None:
            self.operator.backward(self.children, backward_grad)
            
class AddOperator(AbstractOperator):
    def backward(self, children: list["Node"], grad: Number): # (STEP2)
        for e in AU.getN(children, 2): # getN就是檢查參數是倆個,這裏簡單實現保證加法tree是二叉樹了
            e: Node
            e.backward(grad) # y = u + v偏導數u,v都是1,所以直接grad*1

因此我們發現核心在於backward部分求解的梯度會因為$y=f(u,v)$,的運算子$f$不同而不同,下面給出乘法的例子

數學分析:

$$ 對於 \; y = u \cdot v, 有\; \frac{\partial y}{\partial x} = v \cdot \frac{\partial u}{\partial x} + u \cdot \frac{\partial v}{\partial x} $$

寫成代碼就是這樣(記得乘上父節點傳下來的梯度 grad)

class MulOperator(AbstractOperator):

    def calculate(self, children: list["Node"]) -> "Node":
        x0, x1 = AU.getN(children, 2)
        return Node(value=x0.value * x1.value, children=children, operator=self)

    def backward(self, children: list["Node"], grad: Number):
        x0, x1 = AU.getN(children, 2)
        x0: Node
        x1: Node
        x0.backward(x1.value * grad) # 乘法運算符的偏導數,詳細見公式
        x1.backward(x0.value * grad)

同理,我們也可以構造出減法、除法、冪函數、指數、對數、三角函數節點,從而覆蓋深度學習框架中的絕大多數運算符
【同理實現特殊的函數們,比如softmax,relu,gelu,sqrt,abs等等)

值得注意的是,對於指數和對數。我們有換底公式
$$f_1(u,v) = u(x)^{v(x)} = e^{u(x)ln(v(x))}$$

$$ f_2(u,v) = log_{u(x)}{v(x)} = \frac{ln(u(x))}{ln(v(x))} $$

因此我們只需要實現 \( f_3(u) = e^{u(x)},f_4(u)= ln(u(x)) \)算子即可,可以避免較為複雜的求導
(但其實硬要求導實現 \( f_1,f_2 \) 也是可以的。注意需要考慮邊界條件,在出現違背有效取值範圍時梯度應該被置為nan)

真實優化器中出現inf和nan的原因往往是因為梯度爆炸/過小造成

完整代碼示例

一個更合適更加容易理解的代碼版本:

from __future__ import annotations
from abc import ABC, abstractmethod
from numbers import Number
from enum import Enum
from typing import List


# ========== Operator 定義 ==========
class AbstractOperator(ABC):
    """每種算子只需要關心
       1. 前向值  forward(children)
       2. 對各 child 的局部梯度 get_grad(children) -> list[Number]
    """

    @abstractmethod
    def forward(self, children: List["Node"]) -> Number:
        pass

    @abstractmethod
    def get_grad(self, children: List["Node"]) -> List[Number]:
        pass


class AddOperator(AbstractOperator):
    def forward(self, children: List["Node"]) -> Number:
        a, b = children
        return a.value + b.value

    def get_grad(self, children: List["Node"]) -> List[Number]:
        # ∂(a+b)/∂a = 1,  ∂(a+b)/∂b = 1
        return [1, 1]


class MulOperator(AbstractOperator):
    def forward(self, children: List["Node"]) -> Number:
        a, b = children
        return a.value * b.value

    def get_grad(self, children: List["Node"]) -> List[Number]:
        a, b = children
        # ∂(a·b)/∂a = b,  ∂(a·b)/∂b = a
        return [b.value, a.value]


class OperatorTypes(Enum):
    ADD = AddOperator()
    MUL = MulOperator()


# ========== Node 定義 ==========
class Node:
    def __init__(
        self,
        value: Number | None = None,
        children: List["Node"] | None = None,
        operator: AbstractOperator | None = None,
        require_grad: bool = True,
    ):
        self._value = value
        self.children = children
        self.operator = operator
        self.grad: Number | None = None
        self.require_grad = require_grad

    # ---------- 前向 ----------
    @property
    def value(self) -> Number:
        # lazy 計算:只有在需要時、且自身沒有 value,才通過 forward 計算
        if self._value is None and self.operator is not None:
            self.calculate()
        return self._value

    def calculate(self):
        """調用 operator.forward 並更新自身 value"""
        if self.operator is None:
            return self._value
        # 先確保子節點都已計算
        for c in self.children:
            c.calculate()
        self._value = self.operator.forward(self.children)
        return self._value

    def update(self):
        """從葉子到根整棵樹執行一次前向"""
        for c in (self.children or []):
            c.update()
        self.calculate()

    # ---------- 反向 ----------
    def backward(self, incoming_grad: Number = 1):
        """
        incoming_grad: 父節點傳下來的梯度 (∂L/∂self)
        """
        if not self.require_grad:
            return

        # 累加梯度
        self.grad = (self.grad or 0) + incoming_grad

        # 葉子節點結束
        if self.operator is None or not self.children:
            return

        # Operator 只給局部梯度,其餘遞歸由 Node 處理
        local_grads = self.operator.get_grad(self.children)  # list 與 children 對齊
        for child, local_g in zip(self.children, local_grads):
            child.backward(incoming_grad * local_g)

    # ---------- 可視化 ----------
    def get_tree(self, depth: int = 0) -> str:
        pad = "  " * depth
        if self.operator is None:
            return f"{pad}{self.value}\n"
        name = self.operator.__class__.__name__.replace("Operator", "")
        s = f"{pad}{name}({self.value})\n"
        for c in self.children:
            s += c.get_tree(depth + 1)
        return s

    # ---------- 清零梯度 ----------
    def zero_grad(self):
        self.grad = None
        for c in self.children or []:
            c.zero_grad()

    # 工廠:由 children + operator 構造中間節點
    @classmethod
    def op(cls, children: List["Node"], op_type: OperatorTypes) -> "Node":
        return cls(value=None, children=children, operator=op_type.value)


# ========== 輕量包裝,支持 +、*、backward ==========
class Tensor:
    def __init__(self, node: Node):
        self.node = node

    @property
    def value(self):
        return self.node.value

    @property
    def grad(self):
        return self.node.grad

    # ----- 運算符重載 -----
    def __add__(self, other: "Tensor"):
        return Tensor(Node.op([self.node, other.node], OperatorTypes.ADD))

    def __mul__(self, other: "Tensor"):
        return Tensor(Node.op([self.node, other.node], OperatorTypes.MUL))

    # ----- autograd -----
    def backward(self):
        self.node.backward()

    def zero_grad(self):
        self.node.zero_grad()

    def tree(self):
        print(self.node.get_tree())


class Var(Tensor):
    def __init__(self, value: Number):
        super().__init__(Node(value=value))


class Const(Tensor):
    def __init__(self, value: Number):
        super().__init__(Node(value=value, require_grad=False))


# ========== Demo ==========
if __name__ == "__main__":
    c = Const(2)
    z = Var(3)
    x = Var(4)
    k = Var(5)
    g = Var(6)

    y = c * z * x + k * x + g     # 複合表達式
    y.backward()

    print("=== y = c*z*x + k*x + g ===")
    print("x.grad:", x.grad)   # (c*z + k)
    print("z.grad:", z.grad)   # (c*x)
    print("k.grad:", k.grad)   # (x)
    print("g.grad:", g.grad)   # 1
    y.tree()

    # 斷言
    assert x.grad == (c * z + k).value
    assert z.grad == (c * x).value
    assert k.grad == x.value
    assert g.grad == 1
    print("Forward / Backward ✔")

    # 第二個例子:y2 = x * x
    y.zero_grad()
    y2 = x * x
    y2.backward()
    print("\n=== y2 = x*x ===")
    print("x.grad:", x.grad)   # 2*x
    assert x.grad == (Const(2) * x).value
    print("Forward / Backward ✔")

最開始寫的代碼Node和Operator耦合了,僅供參考


from abc import ABC, abstractmethod
from enum import Enum
from numbers import Number

from utils.ArgsUtils import AU


class AbstractOperator(ABC):
    @abstractmethod
    def calculate(self, children: list["Node"]) -> "Node":
        pass

    @abstractmethod
    def backward(self, children: list["Node"], grad: Number):
        pass


class AddOperator(AbstractOperator):
    def calculate(self, children: list["Node"]) -> "Node":
        x0, x1 = AU.getN(children, 2)
        return Node(value=x0.value + x1.value, children=children, operator=self)

    def backward(self, children: list["Node"], grad: Number):
        for e in AU.getN(children, 2):
            e: Node
            e.backward(grad)


class MulOperator(AbstractOperator):

    def calculate(self, children: list["Node"]) -> "Node":
        x0, x1 = AU.getN(children, 2)
        return Node(value=x0.value * x1.value, children=children, operator=self)

    def backward(self, children: list["Node"], grad: Number):
        x0, x1 = AU.getN(children, 2)
        x0: Node
        x1: Node
        x0.backward(x1.value * grad)
        x1.backward(x0.value * grad)


class OperatorTypes(Enum):
    ADD = AddOperator()
    MUL = MulOperator()


class Node:
    def __init__(self, value, children: list["Node"] = None, operator: AbstractOperator = None, require_grad=True):
        self.value = value
        self.children = children
        self.operator = operator
        self.grad = None
        self.require_grad = require_grad

    @classmethod
    def from_child(cls, nodes: list["Node"], operator_type: OperatorTypes) -> "Node":
        return operator_type.value.calculate(nodes)

    def backward(self, backward_grad=None):
        if not self.require_grad:
            return
        # 1. 計算當前這次傳入的梯度
        if backward_grad is None:  # 根結點
            backward_grad = 1

        # 2. 累加到自己的梯度
        if self.grad is None:
            self.grad = 0
        self.grad += backward_grad

        # 3. 繼續傳播 *本次* backward_grad
        if self.operator is not None:
            self.operator.backward(self.children, backward_grad)

    def zero_grad(self):
        self.grad = None
        if self.children is None:
            return
        for c in self.children:
            c.zero_grad()


class NodeWrapper:
    def __init__(self, node):
        self.node = node

    @property
    def value(self):
        return self.node.value

    @property
    def grad(self):
        return self.node.grad

    def __add__(self, other: "NodeWrapper"):
        return NodeWrapper(Node.from_child([self.node, other.node], OperatorTypes.ADD))

    def __mul__(self, other: "NodeWrapper"):
        return NodeWrapper(Node.from_child([self.node, other.node], OperatorTypes.MUL))

    def backward(self):
        self.node.backward()

    def zero_grad(self):
        self.node.zero_grad()


class Var(NodeWrapper):
    def __init__(self, value):
        super().__init__(Node(value))


class Const(NodeWrapper):

    def __init__(self, value):
        super().__init__(Node(value, require_grad=False))


if __name__ == "__main__":
    c = Const(2)
    z = Var(3)
    x = Var(4)
    k = Var(5)
    g = Var(6)
    y = c * z * x + k * x + g

    y.backward()
    print(f"y:  x.grad:{x.grad};z.grad:{z.grad};k.grad:{k.grad};g.grad:{g.grad}")
    assert x.grad == (c * z + k).value
    assert z.grad == (c * x).value
    assert g.grad == 1
    assert k.grad == x.value
    print("y assertion passed!")
    y.zero_grad()

    y2 = x*x
    y2.backward()
    print(f"y2:  x.grad:{x.grad}")
    assert x.grad == (Const(2) * x).value
    print("y2 assertion passed!")

遞歸正確性證明

對於上述樹遞歸的正確性,我們可以通過上面的實例較好的理解。下面,我們使用數學歸納法對算法的正確性進行嚴格證明
我們只需證明性質P(歸納假設)即可

  1. 記 $val(n)$ 為 n.value,$grad(n)$ 為調用 backward 結束後的 n.grad
  2. 記 $req(n)$ 為謂詞:節點 n.require_grad == True
  3. 記 $child(n)$ 為 n.children(若為空則為 $\varnothing$)。
  4. 關鍵性質(待證):
    $$\begin{aligned} P(n,\,g)\;\stackrel{\text{def}}{=}\;\text{調用 }n.\text{backward}(g)\text{ 能夠保證} \\ \quad\text{對所有滿足 }req(m)\text{ 且 }m\text{ 位於 }n\text{ 的子圖中的節點,} \\ \text{最終} \quad grad(m)\;\text{ 增加 }\;g\; \cdot \;\frac{\partial val(n)}{\partial val(m)}\text{}\end{aligned}$$
    當 $g = 1$ 且 $n = root\_node$ 時,就得到我們想要的結論。

很容易證明我們的樹擁有這個性質:
基礎步驟:全局只有葉子節點時

  • req(n) == False,則空洞成立
  • req(n) == True,則葉子節點將執行 grad(n) += g而沒有進一步遞歸,因此顯然成立 ($\displaystyle \frac{\partial val(n)}{\partial val(n)} = 1$)

遞歸演繹:(以mul節點為例)
對於mul節點:兩個節點用MUL父節點結合後,父節點仍滿足這一性質,即有

$$ P(x_1, g’ \cdot val(x_2)) \; \land \; P(x_2, g’ \cdot val(x_1)) \Longrightarrow P(mul(x_1, x_2), g’) $$

證明,由乘法求導有:
$$ \frac{\partial val(mul)}{\partial val(x_1)} = val(x_2), \quad \frac{\partial val(mul)}{\partial val(x_2)} = val(x_1) $$

若 x1 與 x2 不同節點,則顯然我們發現在前面兩個條件下, \( P(n,g) \) 的性質的對 $P(mul(x_1, x_2), g’)$ 保持了。對於樹上的所有節點:

  • mul(x1,x2)節點被初始化為 \( g’ \),則是正確的
  • 對於x1,x2節點,其grad的增加為 \( g’ \),所以對這兩個節點成立
  • 而對於x1,x2的所有子節點,由遞歸演繹可知,這兩個節點滿足$P$ 性質(“n-1"層成立)。因此我們可以寫出:對於x1子樹的某個節點\( m \),其梯度\( grad(m) \)將增加:

$$ \Delta grad(m) = g’ \cdot val(x_2) \cdot \frac{\partial val(x_1)}{\partial val(m)} = g’ \cdot \frac{\partial val(mul)}{\partial val(x_1)} \cdot \frac{\partial val(x_1)}{\partial val(m)} = g’ \cdot \frac{\partial val(mul)}{\partial val(m)} $$

x2子樹同理,因此我們證明了所有的mul樹的子節點都將增加相應的梯度,則性質對mul節點成立

總結
同樣基於上述方法,我們將可以同樣證明add等其他任意節點的正確性,只要其正確的在回傳時回傳了 backward_grad * 當前節點相對於父節點的grad。也就是有一個通用公式
$$ \Bigl(\,\bigwedge_{i=1}^{n} P\!\bigl(x_i,\; g' \cdot \frac{\partial\,val(f)}{\partial\,val(x_i)}\bigr) \Bigr) \;\Longrightarrow\; P\!\bigl(f(x_1,\dots,x_n),\; g'\bigr) $$

常見疑問澄清

關於控制流生成圖

  1. 如果深度學習的函數式子中有分支語句怎麼辦?
    回答:由於計算圖是運行時生成的,假設分支語句分出A,B兩枝,在生成圖的過程中,一定只有A枝或者B枝會被執行,因此生成得到的計算圖將會是確定的

    關於梯度不合法的情況

    在真實數學求導中,梯度可能會不存在(輸入不在定義域內,nan)、或者數學意義上區域無窮inf,因此此類情況需要進行特殊判斷。

數學邏輯上,如果正常發生了前向,那麼反向梯度多數情況也是合法的。但存在如下反例:

  1. 導數本身可能不存在,例如: \( y=∣x∣ \),幾乎所有主流框架在$x=0$處會給一個“次梯度”:0。因此如果有任何動力需求,需要覆蓋框架的實現
  2. 導數可存在但無限大:例如 \( y = sqrt(x) \) 在 \( x=0 \) 處區域有 \( \frac{\partial y}{\partial x} = \frac{1}{2\sqrt{x}} \to \infty \),雖然其在此處的定義域是合法的

我們同時需要注意,計算機中常用編碼下所能存儲的數字是有界且離散的,因此存在:

  • 浮點數運算過程中的非精確問題,因此往往使用isclose等類似方法,當數字與臨界/極限差值已經非常小時,便直接判定不合法
  • 溢出,則判定為inf
    因此,一些模型在進行反向傳播的過程中,會出現”數學上理應保證合法,但實際上爆出inf/nan的情況“

    實用建議(@GPT):對於Autograd開發者
  • Autograd 層 應提供:
    • 定義域檢查 → nan/inf 報警或安全替代
    • 穩定公式、裁剪、loss scaling
    • 確定性次梯度與自定義梯度接口
關注點 建議做法 典型實現/説明
1. 定義域檢測 ‐ 在 前向 明確檢查輸入是否落在數學定義域。
‐ 對於非法輸入直接產生 nan/inf 並拋 Error(易於定位);或返回安全替代(見下)。
TF:tf.debugging.assert_all_finite
PyTorch:torch.finfo, torch.isnan, torch.isinf; torch.autograd.detect_anomaly()
2. 數值穩定公式 提供 穩定算子logsumexpsoftplussafe_log1p 等,把極端值映射回可表達範圍。 內部自動減去 max, 加 ε,或採用泰勒展開。
3. 不可導點選擇次梯度 對分段函數或絕對值,在不可導點固定返回 0 或其他確定值,保持可重複性。 sign(0)=0abs'(0)=0relu'(0)=0;可自定義。
4. ∞ 導數→飽和 / 截斷 對導數 (\ g\ >g_{\max}) 設閾值;或在前向加 ε 避免 0 除。 例:sqrt(x) 內部做 max(x, ε)rsqrt 內部 clip。
5. nan/inf 傳播與診斷 保持 IEEE-754 傳播(出錯即沿圖擴散),並提供異常探針;遇到非法梯度可觸發梯度 NaN skip / zero PyTorch GradScaler, AMP 自動跳過 NaN step。
6. 數據類型與動態縮放 在 FP16/bfloat16 訓練中內置 Loss Scaling,防止梯度下溢;在累加時轉 FP32。 Apex/AMP:動態 scale;XLA mixed precision。
7. 提供 梯度裁剪 原語 clip_by_norm, clip_by_value, clip_grad_norm_, clip_grad_value_ 優化器調用前後自動裁剪。
8. 允許 用户自定義梯度 暴露 custom_gradient / autograd.Function 接口,對特殊邊界寫出精確或近似梯度。 解決諸如 (\sqrt{\text{ReLU}}) 等 0·∞ 問題。
實用建議(@GPT):對於模型訓練者
  1. 使用者 通過正規化、ε-修正、梯度裁剪、合適激活/優化器、監控與混合精度等手段,才能讓“數學上合法”的梯度在 數值實現 裏依舊穩定可靠。
# 建議 關鍵點 / 示例
1 輸入歸一化 訓練前做 z-score / min-max,避免過大過小激活。
2 選擇數值穩定的激活/損失 Softplus ≈ ReLU+ε;用 logsumexp 替代 log(softmax);用 silu, gelu 代替裸 exp
3 顯式加 ε / clamp log(x+ε), sqrt(clamp(x,min=ε));ε 取 1e-6~1e-12 依精度而定。
4 梯度裁剪 常用:L2‐norm 裁剪 clip_grad_norm_(model.parameters(), 1.0);或每元素裁剪。
5 合理初始化 He/Kaiming, Xavier, or LSUV;避免早期梯度爆炸/消失。
6 優化器超參穩健化 學習率 warm-up、餘弦衰減;Adam/Adagrad 自適應;謹慎設置 β₁, β₂,避免動量累積過大。
7 Mixed Precision + Loss Scaling 提高吞吐的同時自動放大梯度,防止 underflow;若 step 出現 NaN,自動回滾並增減 scale。
8 監控 & 斷點恢復 每 N step 記錄 grad_norm, param_norm, loss; 出現異常立刻 checkpoint & debug。
9 自定義安全運算 如需 sqrt(relu(x)),實現 safe_sqrt()sqrt(relu(x)+ε),或手寫 backward 返回掩碼。
10 靜態/動態圖審計 使用 torch.fx, tf.function 的 graph-walker 或 ptxas/nvdisasm 看是否有隱含除 0、exp 溢出。

拓展閲讀

[待續]

user avatar
0 位用戶收藏了這個故事!

發佈 評論

Some HTML is okay.