golang中的errgroup

0.1、索引

https://waterflow.link/articles/

专注于为中小企业提供做网站、成都做网站服务,电脑端+手机端+微信端的三站合一,更高效的管理,为中小企业陕西免费做网站提供优质的服务。我们立足成都,凝聚了一批互联网行业人才,有力地推动了数千家企业的稳健成长,帮助中小企业通过网站建设实现规模扩充和转变。

1、串行执行

假如我们需要查询一个课件列表,其中有课件的信息,还有课件创建者的信息,和课件的缩略图信息。但是此时我们已经对服务做了拆分,假设有课件服务用户服务还有文件服务

我们通常的做法是,当我们查询课件列表时,我们首先调用课件服务,比如查询10条课件记录,然后获取到课件的创建人ID,课件的缩略图ID;再通过这些创建人ID去用户服务查询用户信息,通过缩略图ID去文件服务查询文件信息;然后再写到这10条课件记录中返回给前端。

像下面这样:

package main

import (
	"fmt"
	"time"
)

type Courseware struct {
	Id         int64
	Name       string
	Code       string
	CreateId   int64
	CreateName string
	CoverId   int64
	CoverPath string
}

type User struct {
	Id   int64
	Name string
}

type File struct {
	Id   int64
	Path string
}

var coursewares []Courseware
var users map[int64]User
var files map[int64]File
var err error

func main() {
	// 查询课件
	coursewares, err = CoursewareList()
	if err != nil {
		fmt.Println("获取课件错误")
		return
	}

	// 获取用户ID、文件ID
	userIds := make([]int64, 0)
	fileIds := make([]int64, 0)
	for _, courseware := range coursewares {
		userIds = append(userIds, courseware.CreateId)
		fileIds = append(fileIds, courseware.CoverId)
	}

	// 批量获取用户信息
	users, err = UserMap(userIds)
	if err != nil {
		fmt.Println("获取用户错误")
		return
	}

	// 批量获取文件信息
	files, err = FileMap(fileIds)
	if err != nil {
		fmt.Println("获取文件错误")
		return
	}

	// 填充
	for i, courseware := range coursewares {
		if user, ok := users[courseware.CreateId]; ok {
			coursewares[i].CreateName = user.Name
		}

		if file, ok := files[courseware.CoverId]; ok {
			coursewares[i].CoverPath = file.Path
		}
	}
	fmt.Println(coursewares)
}

func UserMap(ids []int64) (map[int64]User, error) {
	time.Sleep(3 * time.Second) // 模拟数据库请求
	return map[int64]User{
		1: {Id: 1, Name: "liu"},
		2: {Id: 2, Name: "kang"},
	}, nil
}

func FileMap(ids []int64) (map[int64]File, error) {
	time.Sleep(3 * time.Second) // 模拟数据库请求
	return map[int64]File{
		1: {Id: 1, Path: "/a/b/c.jpg"},
		2: {Id: 2, Path: "/a/b/c/d.jpg"},
	}, nil
}

func CoursewareList() ([]Courseware, error) {
	time.Sleep(3 * time.Second)
	return []Courseware{
		{Id: 1, Name: "课件1", Code: "CW1", CreateId: 1, CreateName: "", CoverId: 1, CoverPath: ""},
		{Id: 2, Name: "课件2", Code: "CW2", CreateId: 2, CreateName: "", CoverId: 2, CoverPath: ""},
	}, nil
}

2、并发执行

但我们获取课件之后,填充用户信息和文件信息是可以并行执行的,我们可以修改获取用户和文件的代码,把他们放到协程里面,这样就可以并行执行了:

...

	// 此处放到协程里
	go func() {
		// 批量获取用户信息
		users, err = UserMap(userIds)
		if err != nil {
			fmt.Println("获取用户错误")
			return
		}
	}()

	// 此处放到协程里
	go func() {
		// 批量获取文件信息
		files, err = FileMap(fileIds)
		if err != nil {
			fmt.Println("获取文件错误")
			return
		}
	}()

	...

但是当你执行的时候你会发现这样是有问题的,因为下面的填充数据的代码有可能会在这两个协程执行完成之前去执行。也就是说最终的数据有可能没有填充用户信息和文件信息。那怎么办呢?这是我们就可以使用golang的waitgroup了,主要作用就是协程的编排。

我们可以等2个协程都执行完成再去走下面的填充逻辑

我们继续修改代码成下面的样子

...

// 初始化一个sync.WaitGroup
var wg sync.WaitGroup

func main() {
	// 查询课件
	...
	// 获取用户ID、文件ID
	...

	// 此处放到协程里
	wg.Add(1) // 计数器+1
	go func() {
		defer wg.Done() // 计数器-1
		// 批量获取用户信息
		users, err = UserMap(userIds)
		if err != nil {
			fmt.Println("获取用户错误")
			return
		}
	}()

	// 此处放到协程里
	wg.Add(1) // 计数器+1
	go func() {
		defer wg.Done() // 计数器-1
		// 批量获取文件信息
		files, err = FileMap(fileIds)
		if err != nil {
			fmt.Println("获取文件错误")
			return
		}
	}()

  // 阻塞等待计数器小于等于0
	wg.Wait()

	// 填充
	for i, courseware := range coursewares {
		if user, ok := users[courseware.CreateId]; ok {
			coursewares[i].CreateName = user.Name
		}

		if file, ok := files[courseware.CoverId]; ok {
			coursewares[i].CoverPath = file.Path
		}
	}
	fmt.Println(coursewares)
}

...

我们初始化一个sync.WaitGroup,调用wg.Add(1)给计数器加一,调用wg.Done()计数器减一,wg.Wait()阻塞等待直到计数器小于等于0,结束阻塞,继续往下执行。

3、errgroup

但是我们现在又有这样的需求,我们希望如果获取用户或者获取文件有任何一方报错了,直接抛错,不再组装数据。

我们可以像下面这样写

...

var goErr error
var wg sync.WaitGroup

...

func main() {
	...

	// 此处放到协程里
	wg.Add(1)
	go func() {
		defer wg.Done()
		// 批量获取用户信息
		users, err = UserMap(userIds)
		if err != nil {
			goErr = err
			fmt.Println("获取用户错误:", err)
			return
		}
	}()

	// 此处放到协程里
	wg.Add(1)
	go func() {
		defer wg.Done()
		// 批量获取文件信息
		files, err = FileMap(fileIds)
		if err != nil {
			goErr = err
			fmt.Println("获取文件错误:", err)
			return
		}
	}()

	wg.Wait()

	if goErr != nil {
		fmt.Println("goroutine err:", err)
		return
	}

	...
}

...

把错误放在goErr中,结束阻塞后判断协程调用是否抛错。

那golang里面有没有类似这样的实现呢?答案是有的,那就是errgroup。其实和我们上面的方法差不多,但是errgroup包做了一层结构体的封装,也不需要在每个协程里面判断error传给errGo了。

下面是errgroup的实现

package main

import (
	"errors"
	"fmt"
	"golang.org/x/sync/errgroup"
	"time"
)

type Courseware struct {
	Id         int64
	Name       string
	Code       string
	CreateId   int64
	CreateName string
	CoverId   int64
	CoverPath string
}

type User struct {
	Id   int64
	Name string
}

type File struct {
	Id   int64
	Path string
}

var coursewares []Courseware
var users map[int64]User
var files map[int64]File
var err error
// 定义一个errgroup
var eg errgroup.Group

func main() {
	// 查询课件
	coursewares, err = CoursewareList()
	if err != nil {
		fmt.Println("获取课件错误:", err)
		return
	}

	// 获取用户ID、文件ID
	userIds := make([]int64, 0)
	fileIds := make([]int64, 0)
	for _, courseware := range coursewares {
		userIds = append(userIds, courseware.CreateId)
		fileIds = append(fileIds, courseware.CoverId)
	}


	// 此处放到协程里
	eg.Go(func() error {
		// 批量获取用户信息
		users, err = UserMap(userIds)
		if err != nil {
			fmt.Println("获取用户错误:", err)
			return err
		}
		return nil
	})

	// 此处放到协程里
	eg.Go(func() error {
		// 批量获取文件信息
		files, err = FileMap(fileIds)
		if err != nil {
			fmt.Println("获取文件错误:", err)
			return err
		}
		return nil
	})

  // 判断group中是否有报错
	if goErr := eg.Wait(); goErr != nil {
		fmt.Println("goroutine err:", err)
		return
	}

	// 填充
	for i, courseware := range coursewares {
		if user, ok := users[courseware.CreateId]; ok {
			coursewares[i].CreateName = user.Name
		}

		if file, ok := files[courseware.CoverId]; ok {
			coursewares[i].CoverPath = file.Path
		}
	}
	fmt.Println(coursewares)
}

func UserMap(ids []int64) (map[int64]User, error) {
	time.Sleep(3 * time.Second)
	return map[int64]User{
		1: {Id: 1, Name: "liu"},
		2: {Id: 2, Name: "kang"},
	}, errors.New("sql err")
}

func FileMap(ids []int64) (map[int64]File, error) {
	time.Sleep(3 * time.Second)
	return map[int64]File{
		1: {Id: 1, Path: "/a/b/c.jpg"},
		2: {Id: 2, Path: "/a/b/c/d.jpg"},
	}, nil
}

func CoursewareList() ([]Courseware, error) {
	time.Sleep(3 * time.Second)
	return []Courseware{
		{Id: 1, Name: "课件1", Code: "CW1", CreateId: 1, CreateName: "", CoverId: 1, CoverPath: ""},
		{Id: 2, Name: "课件2", Code: "CW2", CreateId: 2, CreateName: "", CoverId: 2, CoverPath: ""},
	}, nil
}

当然,errgroup中也有针对上下文的errgroup.WithContext函数,如果我们想控制请求接口的时间,用这个是最合适不过的。如果请求超时会返回一个关闭上下文的报错,像下面这样

package main

import (
	"context"
	"fmt"
	"golang.org/x/sync/errgroup"
	"time"
)

type Courseware struct {
	Id         int64
	Name       string
	Code       string
	CreateId   int64
	CreateName string
	CoverId    int64
	CoverPath  string
}

type User struct {
	Id   int64
	Name string
}

type File struct {
	Id   int64
	Path string
}

var coursewares []Courseware
var users map[int64]User
var files map[int64]File
var err error

func main() {
	// 查询课件
	...

	// 获取用户ID、文件ID
	...

  // 定义一个带超时时间的上下文,1秒钟超时
	ctx, cancelFunc := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancelFunc()
  // 定义一个带上下文的errgroup,使用上面带有超时时间的上下文
	eg, ctx := errgroup.WithContext(ctx)
	// 此处放到协程里
	eg.Go(func() error {
		// 批量获取用户信息
		users, err = UserMap(ctx, userIds)
		if err != nil {
			fmt.Println("获取用户错误:", err)
			return err
		}
		return nil
	})

	// 此处放到协程里
	eg.Go(func() error {
		// 批量获取文件信息
		files, err = FileMap(ctx, fileIds)
		if err != nil {
			fmt.Println("获取文件错误:", err)
			return err
		}
		return nil
	})

	if goErr := eg.Wait(); goErr != nil {
		fmt.Println("goroutine err:", err)
		return
	}

	// 填充
	for i, courseware := range coursewares {
		if user, ok := users[courseware.CreateId]; ok {
			coursewares[i].CreateName = user.Name
		}

		if file, ok := files[courseware.CoverId]; ok {
			coursewares[i].CoverPath = file.Path
		}
	}
	fmt.Println(coursewares)
}

func UserMap(ctx context.Context, ids []int64) (map[int64]User, error) {
	result := make(chan map[int64]User)
	go func() {
		time.Sleep(2 * time.Second) // 假装请求超过1秒钟
		result <- map[int64]User{
			1: {Id: 1, Name: "liu"},
			2: {Id: 2, Name: "kang"},
		}
	}()

	select {
	case <-ctx.Done(): // 如果上下文结束直接返回错误信息
		return nil, ctx.Err()
	case res := <-result: // 返回正确结果
		return res, nil
	}
}

func FileMap(ctx context.Context, ids []int64) (map[int64]File, error) {
	return map[int64]File{
		1: {Id: 1, Path: "/a/b/c.jpg"},
		2: {Id: 2, Path: "/a/b/c/d.jpg"},
	}, nil
}

func CoursewareList() ([]Courseware, error) {
	time.Sleep(3 * time.Second)
	return []Courseware{
		{Id: 1, Name: "课件1", Code: "CW1", CreateId: 1, CreateName: "", CoverId: 1, CoverPath: ""},
		{Id: 2, Name: "课件2", Code: "CW2", CreateId: 2, CreateName: "", CoverId: 2, CoverPath: ""},
	}, nil
}

执行上面的代码:

go run waitgroup.go
获取用户错误: context deadline exceeded
goroutine err: context deadline exceeded

文章标题:golang中的errgroup
标题链接:http://csdahua.cn/article/dsoidhg.html
扫二维码与项目经理沟通

我们在微信上24小时期待你的声音

解答本文疑问/技术咨询/运营咨询/技术建议/互联网交流