zoukankan      html  css  js  c++  java
  • Golang 实现 Redis(6): 实现 pipeline 模式的 redis 客户端

    本文是使用 golang 实现 redis 系列的第六篇, 将介绍如何实现一个 Pipeline 模式的 Redis 客户端。

    本文的完整代码在Github:Godis/redis/client

    通常 TCP 客户端的通信模式都是阻塞式的: 客户端发送请求 -> 等待服务端响应 -> 发送下一个请求。因为需要等待网络传输数据,完成一次请求循环需要等待较多时间。

    我们能否不等待服务端响应直接发送下一条请求呢?答案是肯定的。

    TCP 作为全双工协议可以同时进行上行和下行通信,不必担心客户端和服务端同时发包会导致冲突。

    p.s. 打电话的时候两个人同时讲话就会冲突听不清,只能轮流讲。这种通信方式称为半双工。广播只能由电台发送到收音机不能反向传输,这种方式称为单工。

    我们为每一个 tcp 连接分配了一个 goroutine 可以保证先收到的请求先先回复。另一个方面,tcp 协议会保证数据流的有序性,同一个 tcp 连接上先发送的请求服务端先接收,先回复的响应客户端先收到。因此我们不必担心混淆响应所对应的请求。

    这种在服务端未响应时客户端继续向服务端发送请求的模式称为 Pipeline 模式。因为减少等待网络传输的时间,Pipeline 模式可以极大的提高吞吐量,减少所需使用的 tcp 链接数。

    pipeline 模式的 redis 客户端需要有两个后台协程程负责 tcp 通信,调用方通过 channel 向后台协程发送指令,并阻塞等待直到收到响应,这是一个典型的异步编程模式。

    我们先来定义 client 的结构:

    type Client struct {
        conn        net.Conn // 与服务端的 tcp 连接
        sendingReqs chan *Request // 等待发送的请求
        waitingReqs chan *Request // 等待服务器响应的请求
        ticker      *time.Ticker // 用于触发心跳包的计时器
        addr        string
    
        ctx        context.Context
        cancelFunc context.CancelFunc
        writing    *sync.WaitGroup // 有请求正在处理不能立即停止,用于实现 graceful shutdown
    }
    
    type Request struct {
        id        uint64 // 请求id
        args      [][]byte // 上行参数
        reply     redis.Reply // 收到的返回值
        heartbeat bool // 标记是否是心跳请求
        waiting   *wait.Wait // 调用协程发送请求后通过 waitgroup 等待请求异步处理完成
        err       error
    }
    

    调用者将请求发送给后台协程,并通过 wait group 等待异步处理完成:

    func (client *Client) Send(args [][]byte) redis.Reply {
        request := &Request{
            args:      args,
            heartbeat: false,
            waiting:   &wait.Wait{},
        }
        request.waiting.Add(1) 
        client.sendingReqs <- request // 将请求发往处理队列
        timeout := request.waiting.WaitWithTimeout(maxWait) // 等待请求处理完成或者超时
        if timeout {
            return reply.MakeErrReply("server time out")
        }
        if request.err != nil {
            return reply.MakeErrReply("request failed: " + err.Error())
        }
        return request.reply
    }
    

    client 的核心部分是后台的读写协程。先从写协程开始:

    // 写协程入口
    func (client *Client) handleWrite() {
    loop:
        for {
            select {
            case req := <-client.sendingReqs: // 从 channel 中取出请求
                client.writing.Add(1) // 未完成请求数+1
                client.doRequest(req) // 发送请求
            case <-client.ctx.Done():
                break loop
            }
        }
    }
    
    // 发送请求
    func (client *Client) doRequest(req *Request) {
        bytes := reply.MakeMultiBulkReply(req.args).ToBytes() // 序列化
        _, err := client.conn.Write(bytes) // 通过 tcp connection 发送
        i := 0
        for err != nil && i < 3 { // 失败重试
            err = client.handleConnectionError(err) 
            if err == nil {
                _, err = client.conn.Write(bytes)
            }
            i++
        }
        if err == nil {
            client.waitingReqs <- req // 将发送成功的请求放入等待响应的队列
        } else {
            // 发送失败
            req.err = err
            req.waiting.Done() // 结束调用者的等待
            client.writing.Done() // 未完成请求数 -1
        }
    }
    

    读协程是我们熟悉的协议解析器模板, 不熟悉的朋友可以到实现 Redis 协议解析器了解更多。

    // 收到服务端的响应
    func (client *Client) finishRequest(reply redis.Reply) {
        request := <-client.waitingReqs // 取出等待响应的 request
        request.reply = reply
        if request.waiting != nil {
            request.waiting.Done() // 结束调用者的等待
        }
        client.writing.Done() // 未完成请求数-1
    }
    
    // 读协程是个 RESP 协议解析器,不熟悉的朋友可以
    func (client *Client) handleRead() error {
        reader := bufio.NewReader(client.conn)
        downloading := false
        expectedArgsCount := 0
        receivedCount := 0
        msgType := byte(0) // first char of msg
        var args [][]byte
        var fixedLen int64 = 0
        var err error
        var msg []byte
        for {
            // read line
            if fixedLen == 0 { // read normal line
                msg, err = reader.ReadBytes('
    ')
                if err != nil {
                    if err == io.EOF || err == io.ErrUnexpectedEOF {
                        logger.Info("connection close")
                    } else {
                        logger.Warn(err)
                    }
    
                    return errors.New("connection closed")
                }
                if len(msg) == 0 || msg[len(msg)-2] != '
    ' {
                    return errors.New("protocol error")
                }
            } else { // read bulk line (binary safe)
                msg = make([]byte, fixedLen+2)
                _, err = io.ReadFull(reader, msg)
                if err != nil {
                    if err == io.EOF || err == io.ErrUnexpectedEOF {
                        return errors.New("connection closed")
                    } else {
                        return err
                    }
                }
                if len(msg) == 0 ||
                    msg[len(msg)-2] != '
    ' ||
                    msg[len(msg)-1] != '
    ' {
                    return errors.New("protocol error")
                }
                fixedLen = 0
            }
    
            // parse line
            if !downloading {
                // receive new response
                if msg[0] == '*' { // multi bulk response
                    // bulk multi msg
                    expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
                    if err != nil {
                        return errors.New("protocol error: " + err.Error())
                    }
                    if expectedLine == 0 {
                        client.finishRequest(&reply.EmptyMultiBulkReply{})
                    } else if expectedLine > 0 {
                        msgType = msg[0]
                        downloading = true
                        expectedArgsCount = int(expectedLine)
                        receivedCount = 0
                        args = make([][]byte, expectedLine)
                    } else {
                        return errors.New("protocol error")
                    }
                } else if msg[0] == '$' { // bulk response
                    fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
                    if err != nil {
                        return err
                    }
                    if fixedLen == -1 { // null bulk
                        client.finishRequest(&reply.NullBulkReply{})
                        fixedLen = 0
                    } else if fixedLen > 0 {
                        msgType = msg[0]
                        downloading = true
                        expectedArgsCount = 1
                        receivedCount = 0
                        args = make([][]byte, 1)
                    } else {
                        return errors.New("protocol error")
                    }
                } else { // single line response
                    str := strings.TrimSuffix(string(msg), "
    ")
                    str = strings.TrimSuffix(str, "
    ")
                    var result redis.Reply
                    switch msg[0] {
                    case '+':
                        result = reply.MakeStatusReply(str[1:])
                    case '-':
                        result = reply.MakeErrReply(str[1:])
                    case ':':
                        val, err := strconv.ParseInt(str[1:], 10, 64)
                        if err != nil {
                            return errors.New("protocol error")
                        }
                        result = reply.MakeIntReply(val)
                    }
                    client.finishRequest(result)
                }
            } else {
                // receive following part of a request
                line := msg[0 : len(msg)-2]
                if line[0] == '$' {
                    fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
                    if err != nil {
                        return err
                    }
                    if fixedLen <= 0 { // null bulk in multi bulks
                        args[receivedCount] = []byte{}
                        receivedCount++
                        fixedLen = 0
                    }
                } else {
                    args[receivedCount] = line
                    receivedCount++
                }
    
                // if sending finished
                if receivedCount == expectedArgsCount {
                    downloading = false // finish downloading progress
    
                    if msgType == '*' {
                        reply := reply.MakeMultiBulkReply(args)
                        client.finishRequest(reply)
                    } else if msgType == '$' {
                        reply := reply.MakeBulkReply(args[0])
                        client.finishRequest(reply)
                    }
    
    
                    // finish reply
                    expectedArgsCount = 0
                    receivedCount = 0
                    args = nil
                    msgType = byte(0)
                }
            }
        }
    }
    

    最后编写 client 的构造器和启动异步协程的代码:

    func MakeClient(addr string) (*Client, error) {
        conn, err := net.Dial("tcp", addr)
        if err != nil {
            return nil, err
        }
        ctx, cancel := context.WithCancel(context.Background())
        return &Client{
            addr:        addr,
            conn:        conn,
            sendingReqs: make(chan *Request, chanSize),
            waitingReqs: make(chan *Request, chanSize),
            ctx:         ctx,
            cancelFunc:  cancel,
            writing:     &sync.WaitGroup{},
        }, nil
    }
    
    func (client *Client) Start() {
        client.ticker = time.NewTicker(10 * time.Second)
        go client.handleWrite()
        go func() {
            err := client.handleRead()
            logger.Warn(err)
        }()
        go client.heartbeat()
    }
    

    关闭 client 的时候记得等待请求完成:

    func (client *Client) Close() {
        // 先阻止新请求进入队列
        close(client.sendingReqs)
    
        // 等待处理中的请求完成
        client.writing.Wait()
    
        // 释放资源
        _ = client.conn.Close() // 关闭与服务端的连接,连接关闭后读协程会退出
        client.cancelFunc() // 使用 context 关闭读协程
        close(client.waitingReqs) // 关闭队列
    }
    

    测试一下:

    func TestClient(t *testing.T) {
        client, err := MakeClient("localhost:6379")
        if err != nil {
            t.Error(err)
        }
        client.Start()
    
        result = client.Send([][]byte{
            []byte("SET"),
            []byte("a"),
            []byte("a"),
        })
        if statusRet, ok := result.(*reply.StatusReply); ok {
            if statusRet.Status != "OK" {
                t.Error("`set` failed, result: " + statusRet.Status)
            }
        }
    
        result = client.Send([][]byte{
            []byte("GET"),
            []byte("a"),
        })
        if bulkRet, ok := result.(*reply.BulkReply); ok {
            if string(bulkRet.Arg) != "a" {
                t.Error("`get` failed, result: " + string(bulkRet.Arg))
            }
        }
    }
    
  • 相关阅读:
    自我介绍 x
    第一次作业 x
    第二次作业 x
    第三次作业 x
    [C#] 用一种更优美的方式来替换掉又多又长的switchcase代码段
    通过设置光标形状实现拖拽控件时跟随一张透明图片的效果
    spring 入门笔记(一)
    PAT IO01. 表格输出(5)
    Maven 安装记
    华为机试 求最大三位数
  • 原文地址:https://www.cnblogs.com/Finley/p/14028402.html
Copyright © 2011-2022 走看看