动态

详情 返回 返回

Go 併發控制:sync.WaitGroup 詳解 - 动态 详情

首發地址:https://mp.weixin.qq.com/s/-FtDLcHW39vgvqSMUVM-yw

前段時間我在《Go 併發控制:errgroup 詳解》一文中講解了 errgroup 的用法和源碼,通過源碼我們知道 errgroup 內部是使用 sync.WaitGroup 實現的,那麼本文就更進一步,來探索下 sync.WaitGroup 源碼是如何實現的。

使用示例

sync.WaitGroup 可以用來阻塞等待一組併發任務(goroutine)的完成,使用示例如下:

package main

import (
    "fmt"
    "net/http"
    "sync"
)

func main() {
    var urls = []string{
        "http://www.golang.org/",
        "http://www.google.com/",
        "http://www.somestupidname.com/", // 這是一個錯誤的 URL,會導致任務失敗
    }

    var wg sync.WaitGroup
    errs := make([]error, len(urls)) // 使用 slice 收集錯誤

    for i, url := range urls {
        wg.Add(1)
        go func() {
            defer wg.Done()
            resp, err := http.Get(url)
            if err != nil {
                errs[i] = fmt.Errorf("failed to fetch %s: %v", url, err)
                return
            }
            defer resp.Body.Close()
            fmt.Printf("fetch url %s status %s\n", url, resp.Status)
        }()
    }

    wg.Wait()

    // 處理所有錯誤
    for i, err := range errs {
        if err != nil {
            fmt.Printf("fetch url %s error: %s\n", urls[i], err)
        }
    }
}

示例中,我們使用 sync.WaitGroup 來啓動 3 個 goroutine 併發訪問 3 個不同的 URL,並在成功時打印響應狀態碼,或失敗時記錄錯誤信息。

執行示例代碼,得到如下輸出:

$ go run waitgroup/main.go
fetch url http://www.google.com/ status 200 OK
fetch url http://www.golang.org/ status 200 OK
fetch url http://www.somestupidname.com/ error: failed to fetch http://www.somestupidname.com/: Get "http://www.somestupidname.com/": dial tcp: lookup www.somestupidname.com: no such host

我們得到了兩個成功的響應,並記錄了一條錯誤信息。

根據示例,我們可以抽象出 sync.WaitGroup 最典型的慣用法:

var wg sync.WaitGroup

for ... {
    wg.Add(1)

    go func() {
        defer wg.Done()
        // do something
    }()
}

wg.Wait()

sync.WaitGroup 零值可用,它會在內部維護一個計數器,wg.Add(1)會將 sync.WaitGroup 計數器的值加 1,表示增加一個 goroutine 計數;wg.Done() 則將計數器的值減 1,表示一個 goroutine 任務已經完成;wg.Wait() 會阻塞調用者所在的 goroutine,直到計數器的值為 0。

源碼解讀

本文以 Go 1.23.0 版本源碼為基礎進行講解。

WaitGroup 結構體

首先 sync.WaitGroup 定義如下:

https://github.com/golang/go/blob/go1.23.0/src/sync/waitgroup.go
// WaitGroup 結構體
type WaitGroup struct {
    noCopy noCopy // 避免複製

    state atomic.Uint64 // 高 32 位是計數器(counter)的值,低 32 位是等待者(waiter)的數量
    sema  uint32        // 信號量,用於 阻塞/喚醒 waiter
}

sync.WaitGroup 是一個結構體,所以這也是其零值可用的原因。

這個結構體包含 3 個字段:

  • noCopy 字段的類型也叫 noCopy,這個字段用於標識 sync.WaitGroup 結構體不可被複制,vet 工具能夠識別它。這個字段的具體細節我們暫且不必深究,它不是 sync.WaitGroup 的核心功能,在文章最後再來解釋它。
  • state 字段是一個原子類型 atomic.Uint64,所以對 state 字段的修改能夠保證原子性。它比較有意思,sync.WaitGroup 結構體使用這一個字段來表示兩個“變量”值,高 32 位是計數器(counter)的值,低 32 位是等待者(waiter)的數量。我們調用 wg.Add(1)wg.Done() 時操作的就是計數器 counter;調用 wg.Wait() 時等待者 waiter 數量就會加 1。
  • sema 是一個信號量,用於阻塞/喚醒 waiter,即調用 wg.Wait() 時的阻塞和喚醒都依賴這個信號量。

sync.WaitGroup 結構體只有 3 個方法:AddDoneWait

Done 方法

我們先來看 Done 方法的源碼實現:

// Done 將計數器(counter)值減 1
func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

調用 Done 方法可以讓 counter 值減 1。可以發現,調用 wg.Done() 方法實際上等價於調用 wg.Add(-1),所以我們的重點還是要關注 Add 方法。

Add 方法

Add 方法源碼實現如下:

// Add 為計數器(counter)的值增加 delta(delta 可能為負數)
func (wg *WaitGroup) Add(delta int) {
    state := wg.state.Add(uint64(delta) << 32) // delta 左移 32 位後與 state 相加,即為 counter 值加上 delta
    v := int32(state >> 32)                    // state 右移 32 位得到 counter 的值
    w := uint32(state)                         // state 轉成 uint32 拿到低 32 位的值,得到 waiter 的值

    // 如果 counter 值為負數,直接 panic
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    
    // 併發調用 Wait 和 Add 會觸發 panic
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }

    // 條件成立説明 counter 值加上 delta 操作成功,返回
    if v > 0 || w == 0 {
        return
    }

    // 如果 counter 值為 0,並且還有被阻塞的 waiter,程序繼續向下執行

    // 併發調用 Wait 和 Add 會觸發 panic
    if wg.state.Load() != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }

    // 目前 counter 值已經為 0,這裏重置 waiter 數量為 0
    wg.state.Store(0)
    for ; w != 0; w-- { // 喚醒所有 waiter
        runtime_Semrelease(&wg.sema, false, 0)
    }
}

NOTE:

為了方便你理解,我將 Add 方法源碼中使用 race 包做競態檢查部分的代碼去掉了,並適當的增加了空行。

Add 方法接收一個 int 類型的 delta 值,在方法第一行,先將這個值轉換成 uint64 類型,然後通過移位操作,將其左移 32 位後與 wg.state 值相加,得到新的 state。我在前文説過,wg.state 字段的高 32 位表示 counter 值,所以這行代碼的作用就是為 counter 值加上 delta。當 delta 值為正數,counter 值增加,當 delta 值為負數,counter 值減少。

接着,將 state 值右移 32 位,得到高 32 位的 counterv;使用 uint32(state) 操作將 uint64 類型強轉成 uint32,捨棄高 32 位,得到低 32 位的 waiterw。注意,這裏拿到的 vw 是與 delta 計算後的最新值。

接下來會做兩個校驗,先對 counter 進行判斷,如果 v 的值為負數,會觸發 panic,所以我們在使用時要小心 counter 不能為負;然後又對併發調用 WaitAdd 方法的場景做了校驗,如果併發調用二者,同樣會觸發 panic

關於判斷是否併發調用 WaitAdd 方法的場景,我在詳細解釋下:

  • w != 0 表明有等待者 waiter 存在,即已經有 goroutine 調用了 wg.Wait() 方法,正在阻塞等待,還未返回。
  • delta > 0 表明這次調用 Add 方法是要增加計數器 counter 的值,這也説明肯定不是通過調用 wg.Done() 方法觸發的。
  • v == int32(delta) 表明在調用 Add 方法之前,counter 的值為 0。因為 v 是計算後的 counter 值,它等於 delta,就説明在計算之前 counter 的值為 0。

如果這三個條件同時滿足,即 w != 0 && delta > 0 && v == int32(delta)true,就説明我們在調用 wg.Wait() 方法以後,還未等到喚醒它,就馬上又調用了 wg.Add(delta) 方法,此時就會觸發 panic。所以,我們在使用時要記住,一定要在調用 wg.Wait() 之前調用 wg.Add(delta)

做完了校驗以後,就到了 Add 方法的第一個出口,如果 v > 0 説明我們正常調用了 Add 方法或 Done 方法,計數器 counter 此時還未清零,那麼無需喚醒 wg.Wait() 的阻塞等待,直接返回即可;或者如果 w == 0 説明當前沒有正在等待的 waiter,即還未調用 wg.Wait(),那麼也可以直接返回。

那麼現在,Add 方法還能繼續往下執行的條件是:v == 0 && w > 0,即 counter 值為 0,並且還有被阻塞的 waiter

既然計數器 counter 值已經為 0,那麼就可以喚醒所有被阻塞的 wg.Wait() 調用了,這也是接下來的程序邏輯。

不過,這裏會再次對併發調用 WaitAdd 方法的場景進行校驗。如果此時從 wg.state 字段獲取到的最新值與變量 state 值不一致,即 wg.state.Load() != statetrue,則會觸發 panic。所以,當 counter 值變為 0,程序即將喚醒被阻塞的 waiter 之前這一小段時間,不要併發的調用 wg.Add(delta) 來改變計數器的值。

最後,通過 wg.state.Store(0)waiter 的值置為 0(因為此時 counter 值已經是 0 了,所以這個操作的目的是將 waiter 值置 0),並使用 runtime_Semrelease 來喚醒所有被阻塞的 waiter

NOTE:

關於 runtime_Semrelease 以及下文將要介紹的 runtime_Semacquire 方法則不必深究,這是 Go 語言底層 runtime 為我們實現的用於喚醒或阻塞當前 goroutine 的函數。

至此,Add 方法就分析完成了。

我們現在可以總結下 Add 方法的作用:Add 為計數器 counter 的值增加 deltadelta 可能為負數),如果計算結果 counter 為負數,則觸發 panic;如果 counter 為正數,則正常返回;如果 counter 為 0,則喚醒所有被阻塞的 waiter

所以 Add 方法主要用來管理計數器 counter,並在 counter 為 0 時,喚醒 waiter

Wait 方法

現在,我們再來看下 sync.WaitGroup 結構體最後一個方法 Wait 的源碼實現:

// Wait 阻塞調用者當前的 goroutine(waiter),直到計數器(counter)值為 0
func (wg *WaitGroup) Wait() {
    for { // 開啓無限循環保證 CAS 操作成功
        state := wg.state.Load()
        v := int32(state >> 32) // 拿到 counter 值
        // w := uint32(state)   // 拿到 waiter 值

        if v == 0 { // 如果 counter 值已經為 0,直接返回
            return
        }

        // 使用 CAS 操作增加 waiter 的數量
        if wg.state.CompareAndSwap(state, state+1) {
            runtime_Semacquire(&wg.sema) // 阻塞當前 waiter 所在的 goroutine,等待被喚醒

            // 併發調用 Wait 和 Add 會觸發 panic
            if wg.state.Load() != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }

            return // 如果 state 值為 0,説明 waiter 所等待的任務全部完成,成功返回
        }
    }
}

NOTE:

為了方便你理解,我同樣將 Wait 方法源碼中使用 race 包做競態檢查部分的代碼去掉了,並適當的增加了空行。

首先,這裏開啓了一個無限 for 循環,這是為了重試下面的 CAS 操作,保證其執行成功。

Wait 方法同樣使用移位操作拿到到高 32 位的 counterv 和低 32 位的 waiterw。因為 w 只會在競態檢查的代碼中被用到,所以被我手動註釋掉了。

接下來判斷計數器 counter 的值是否為 0,如果 v 已經為 0,那麼無需阻塞 waiter,直接返回即可。否則,需要對 waiter 的值進行加 1 操作,這裏使用 CAS 操作(即 Compare And Swap)來完成。

所謂的 CAS 操作,就是先 Compare 再 Swap。當我們調用 wg.state.CompareAndSwap(state, state+1) 時,CompareAndSwap 方法會先判斷 wg.state 值是否等於傳進來的第一個參數 state,如果相等,則將其替換為第二個參數 state+1 的值,並返回 true;如果 wg.state 值與 state 不相等,則不會修改 wg.state,並返回 false。這樣,就保證了對 wg.state 的修改是原子性的。

在併發場景中,CAS 操作可能失敗,返回 false,所以需要結合最外層的 for 無限循環,來保證 CAS 操作成功。

一旦 CAS 操作成功,即 waiter 的數量加 1,就會使用 runtime_Semacquire 來阻塞當前 waiter 所在的 goroutine,等待被喚醒。而喚醒時機,就是在 Add 方法的最後對 runtime_Semrelease(&wg.sema, false, 0) 的調用。

waiter 被喚醒後,會對併發調用 WaitAdd 方法的場景進行校驗。如果 wg.state.Load() != 0true,則會觸發 panic。因為 Add 方法在調用 runtime_Semrelease 喚醒所有 waiter 之前,已經通過 wg.state.Store(0)waiter 的值置為 0 了,所以在不出現併發調用的情況下,wg.state.Load() 的值必然為 0。

而如果沒有出現併發調用 WaitAdd 方法,則説明 waiter 所等待的任務全部完成,正常返回即可。

至此,sync.WaitGroup 結構體最後一個方法Wait 就分析完成了。

根據源碼,我們能夠分析出:Wait 方法主要用來管理 waiter,它會阻塞所有 waiter,並等待被 Add 喚醒。

noCopy 結構體

現在 sync.WaitGroup 結構體的核心功能就全部講解完成了,是時候介紹下 noCopy 了。

noCopy 實際上也是一個結構體,其定義如下:

// noCopy may be added to structs which must not be copied
// after the first use.
//
// See https://golang.org/issues/8005#issuecomment-190753527
// for details.
//
// Note that it must not be embedded, due to the Lock and Unlock methods.
type noCopy struct{}

// Lock is a no-op used by -copylocks checker from `go vet`.
func (*noCopy) Lock()   {}
func (*noCopy) Unlock() {}

noCopy 非常簡單,就是一個空結構體實現了 Locker 接口。noCopy 結構體的唯一功能,就是用於輔助 vet 工具檢查用的,vet 工具遇到它就會知道,這個結構體是不能被複制的,僅此而已。

此外,我在《Go 中空結構體慣用法,我幫你總結全了!》一文中 #標識符 小節介紹瞭如何將它用在我們自定義的結構體中,感興趣的讀者可以點擊鏈接跳轉過去學習。

好了,sync.WaitGroup 源碼解析就講解到這裏。

總結

你一定要記住 sync.WaitGroup 的慣用法,首先,它無需初始化,零值可用;其次,它會在內部維護一個計數器 counter,通過 wg.Add(delta)wg.Done() 來操作計數器的值;它還會維護一個等待者數量 waiter,調用 wg.Wait() 會阻塞 waiter 所在的 goroutine;當計數器 counter 的值為 0,所有 waiter 都會被喚醒。

還要注意不要併發調用 AddWait 方法,也不要讓計數器 counter 的值為負數,不然會觸發 panic

雖然 sync.WaitGroup 的源碼很少,可卻因為裏面使用了移位操作和一些邊界條件的檢查,使其不太容易理解。為此,我專門畫了一副 sync.WaitGroup 三大方法的執行流程圖,來助你分析 sync.WaitGroup 各個方法的執行流程和關聯關係。

流程圖如下:

image.png

Done 方法沒什麼好解釋的,等價於 Add(-1)

Add 方法在第一次出現 return 之前的代碼(即 檢查 counter 大於 0,或 waiter 等於 0),其實可以看作是增加計數器的功能,即 delta 值大於 0 的情況;而接下來的代碼,則可以看作是調用 Done 方法,減少計數器的功能,即 delta 值小於 0 的情況。當計數器的值為 0,就會喚醒所有 waiter

Wait 方法則用來管理 waiter,並阻塞 waiter,等待被 Add 方法喚醒。

如果你對上面的源碼分析理解還覺得有點不夠透徹,可以對照這幅圖,多梳理幾遍。看懂了這幅圖,那麼你就完全掌握了 sync.WaitGroup

本文示例源碼我都放在了 GitHub 中,歡迎點擊查看。

希望此文能對你有所啓發。

聯繫我

  • 公眾號:Go編程世界
  • 微信:jianghushinian
  • 郵箱:jianghushinian007@outlook.com
  • 博客:https://jianghushinian.cn
  • GitHub:https://github.com/jianghushinian

Add a new 评论

Some HTML is okay.