基础

使用

WaitGroup 可以理解为 Wait-Goroutine-Group,等待一组 goroutine 结束,如果某个 goroutine 需要等待其他 goroutine 全部完成,那么使用 WaitGroup 可以轻松搞定

A WaitGroup waits for a collection of goroutines to finish. The main goroutine calls Add to set the number of goroutines to wait for. Then each of the goroutines runs and calls Done when finished. At the same time, Wait can be used to block until all goroutines have finished.

must not be copied after first use

https://play.golang.org/p/Bjt5YKF2cN7

package main

import (
	"fmt"
	"sync"
)

func main() {

	var wg sync.WaitGroup

	wg.Add(1)

	go func() {
		fmt.Println("ok")
		wg.Done()
	}()

	wg.Wait()
}

源码

race 的代码直接省略

WaitGroup 结构定义如下,

// A WaitGroup must not be copied after first use.
type WaitGroup struct {
	noCopy noCopy

	// 64-bit value: high 32 bits are counter, low 32 bits are waiter count.
	// 64-bit atomic operations require 64-bit alignment, but 32-bit
	// compilers do not ensure it. So we allocate 12 bytes and then use
	// the aligned 8 bytes in them as state, and the other 4 as storage
	// for the sema.
	state1 [3]uint32
}
  • noCopy 是为了保证结构体首次使用后不会被复制,会用于 go vet 检测,https://golang.org/issues/8005#issuecomment-190753527
  • state1 是个长度为3的数组,其中包含了 state 和 semaphore 信号量,而 state 实际上是两个计数器:
    • counter: 当前还未执行结束的goroutine计数器
    • waiter count: 等待goroutine-group结束的goroutine数量,即有多少个等候者
    • semaphore: 信号量
  • state1 对于 64 位系统,高 32 位为 counter,低 32 位为 waiter 计数,64 位系统的原子操作需要 64 位的对齐,但是 32 位系统不能保证。所以,分配了 12 byte 对齐的 8 个 byte 作为状态,然后用剩下的 4 byte 作为 semaphore 信号量的存储
  • state1 的存储都是用 3*32 位(12 byte)存储,但是不同系统存储的格式不一样。

state1 还涉及到不同位数的操作系统原子对齐问题,所以可以简化为下面的结构。

最多支持协程的数量是由 semaphore 信号量决定的,而 semaphore 是 4 byte,所以最多可以支持 232(即 24*8)个协程。

WaitGroup 主要由三个方法 Add()、Done() 和 Wait() ,但是 Done() 本质就是 Add(-1),所以主要是 Add() 和 Wait() 方法,而这两个方法都涉及了 wg.state() 方法

state

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
		// 32 位系统
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}
  • state() 是从 wg.state1 中获取 state 和 semaphore 的,根据指针地址是否能被 8 整除来判断是否内存对齐,判断是 64 位系统还是 32 位系统。因为在不同位数的操作系统内 state 和 semaphore 存储的结构是不一样的。
  • unsafe.Pointer 是不能进行计算的,所以需要转变成 uintptr 进行取余。
                    32 位 系 统
+---------------+-------------------+-----------------+
|    counter    |      waiter       |    semaphore    |
+---------------+-------------------+-----------------+

                    64 位 系 统
+---------------+-------------------+-----------------+
|    semaphore  |     counter       |      waiter     |
+---------------+-------------------+-----------------+

Add

Add() 功能是添加计数(可以是负数)到 WaitGroup 的计数中

  • 主要功能:
    • 把 delta 累加到 counter
    • 当 counter 为负值时 panic
    • 为 0 时根据 waiter 数值释放等量的 goroutine
  • 注意事项:
    • 如果添加正数则必须要在 wait 之前
    • 但是可以在任意时刻添加负数或者 0
    • 所以最好在创建 goroutine 或者其他需要等待的事件之前 Add
    • 如果想重用 WaitGroup 则一定要在 Wait 完成后再执行 Add 方法
func (wg *WaitGroup) Add(delta int) {

	// 获得 state 和 semaphore 的地址
	statep, semap := wg.state()

	// 把 delta 累加到 counter
	state := atomic.AddUint64(statep, uint64(delta)<<32)

	v := int32(state >> 32) // 获得 counter
	w := uint32(state)      // 获得 waiter

	// state 之前是 64 位,然后强转为 32 位,所以会把高 32 舍弃,正好可以直接获得低 32 位
	// 可以参考 https://play.golang.org/p/HJSJmT10kVx

	// 增加 delta 后,如果 counter 是负值则 panic
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}

	// 判断是否并发调用
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}

	// 此时 counter >= 0
	// 如果 counter 大于 0 说明是添加成功,不需要释放信号
	// waiter 等于 0 则说明没有等待者则直接释放退出
	if v > 0 || w == 0 {
		return
	}

	// 此刻,counter 肯定等于 0,waiter 一定大于 0 (waiter 内部维护一定大于等于 0)
	if *statep != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}

	// 把 counter 置为 0,再释放与 waiter 数量相等的信号量
	*statep = 0
	for ; w != 0; w-- {
		runtime_Semrelease(semap, false) // 释放 semap 信号量
	}
}

Wait

Wait()方法也做了两件事,一是累加waiter, 二是阻塞等待信号量

func (wg *WaitGroup) Wait() {
	// 和 Add 一样获得 state 和 semaphore 的地址
	statep, semap := wg.state()

	for {
		// 获得 state 的具体地址
		state := atomic.LoadUint64(statep)

		v := int32(state >> 32) // 获得 counter
		w := uint32(state)      // 获得 waiter
		if v == 0 {
			// 如果 counter 为 0说明不需要阻塞,直接退出
			return
		}

		// 使用CAS(比较交换算法)累加 waiter,累加可能会失败,失败后通过for loop下次重试
		// CAS算法保证有多个 goroutine 同时执行Wait()时也能正确累加waiter
		if atomic.CompareAndSwapUint64(statep, state, state+1) {

			runtime_Semacquire(semap) // 累加信号,相当于订阅

			if *statep != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			return
		}
	}
}

参考资料