zoukankan      html  css  js  c++  java
  • Golang 实现 Redis(6): 实现 pipeline 模式的 redis 客户端

    本文是使用 golang 实现 redis 系列的第六篇, 将介绍如何实现一个 Pipeline 模式的 Redis 客户端。

    本文的完整代码在Github:Godis/redis/client

    通常 TCP 客户端的通信模式都是阻塞式的: 客户端发送请求 -> 等待服务端响应 -> 发送下一个请求。因为需要等待网络传输数据,完成一次请求循环需要等待较多时间。

    我们能否不等待服务端响应直接发送下一条请求呢?答案是肯定的。

    TCP 作为全双工协议可以同时进行上行和下行通信,不必担心客户端和服务端同时发包会导致冲突。

    p.s. 打电话的时候两个人同时讲话就会冲突听不清,只能轮流讲。这种通信方式称为半双工。广播只能由电台发送到收音机不能反向传输,这种方式称为单工。

    我们为每一个 tcp 连接分配了一个 goroutine 可以保证先收到的请求先先回复。另一个方面,tcp 协议会保证数据流的有序性,同一个 tcp 连接上先发送的请求服务端先接收,先回复的响应客户端先收到。因此我们不必担心混淆响应所对应的请求。

    这种在服务端未响应时客户端继续向服务端发送请求的模式称为 Pipeline 模式。因为减少等待网络传输的时间,Pipeline 模式可以极大的提高吞吐量,减少所需使用的 tcp 链接数。

    pipeline 模式的 redis 客户端需要有两个后台协程程负责 tcp 通信,调用方通过 channel 向后台协程发送指令,并阻塞等待直到收到响应,这是一个典型的异步编程模式。

    我们先来定义 client 的结构:

    type Client struct {
        conn        net.Conn // 与服务端的 tcp 连接
        sendingReqs chan *Request // 等待发送的请求
        waitingReqs chan *Request // 等待服务器响应的请求
        ticker      *time.Ticker // 用于触发心跳包的计时器
        addr        string
    
        ctx        context.Context
        cancelFunc context.CancelFunc
        writing    *sync.WaitGroup // 有请求正在处理不能立即停止,用于实现 graceful shutdown
    }
    
    type Request struct {
        id        uint64 // 请求id
        args      [][]byte // 上行参数
        reply     redis.Reply // 收到的返回值
        heartbeat bool // 标记是否是心跳请求
        waiting   *wait.Wait // 调用协程发送请求后通过 waitgroup 等待请求异步处理完成
        err       error
    }
    

    调用者将请求发送给后台协程,并通过 wait group 等待异步处理完成:

    func (client *Client) Send(args [][]byte) redis.Reply {
        request := &Request{
            args:      args,
            heartbeat: false,
            waiting:   &wait.Wait{},
        }
        request.waiting.Add(1) 
        client.sendingReqs <- request // 将请求发往处理队列
        timeout := request.waiting.WaitWithTimeout(maxWait) // 等待请求处理完成或者超时
        if timeout {
            return reply.MakeErrReply("server time out")
        }
        if request.err != nil {
            return reply.MakeErrReply("request failed: " + err.Error())
        }
        return request.reply
    }
    

    client 的核心部分是后台的读写协程。先从写协程开始:

    // 写协程入口
    func (client *Client) handleWrite() {
    loop:
        for {
            select {
            case req := <-client.sendingReqs: // 从 channel 中取出请求
                client.writing.Add(1) // 未完成请求数+1
                client.doRequest(req) // 发送请求
            case <-client.ctx.Done():
                break loop
            }
        }
    }
    
    // 发送请求
    func (client *Client) doRequest(req *Request) {
        bytes := reply.MakeMultiBulkReply(req.args).ToBytes() // 序列化
        _, err := client.conn.Write(bytes) // 通过 tcp connection 发送
        i := 0
        for err != nil && i < 3 { // 失败重试
            err = client.handleConnectionError(err) 
            if err == nil {
                _, err = client.conn.Write(bytes)
            }
            i++
        }
        if err == nil {
            client.waitingReqs <- req // 将发送成功的请求放入等待响应的队列
        } else {
            // 发送失败
            req.err = err
            req.waiting.Done() // 结束调用者的等待
            client.writing.Done() // 未完成请求数 -1
        }
    }
    

    读协程是我们熟悉的协议解析器模板, 不熟悉的朋友可以到实现 Redis 协议解析器了解更多。

    // 收到服务端的响应
    func (client *Client) finishRequest(reply redis.Reply) {
        request := <-client.waitingReqs // 取出等待响应的 request
        request.reply = reply
        if request.waiting != nil {
            request.waiting.Done() // 结束调用者的等待
        }
        client.writing.Done() // 未完成请求数-1
    }
    
    // 读协程是个 RESP 协议解析器,不熟悉的朋友可以
    func (client *Client) handleRead() error {
        reader := bufio.NewReader(client.conn)
        downloading := false
        expectedArgsCount := 0
        receivedCount := 0
        msgType := byte(0) // first char of msg
        var args [][]byte
        var fixedLen int64 = 0
        var err error
        var msg []byte
        for {
            // read line
            if fixedLen == 0 { // read normal line
                msg, err = reader.ReadBytes('
    ')
                if err != nil {
                    if err == io.EOF || err == io.ErrUnexpectedEOF {
                        logger.Info("connection close")
                    } else {
                        logger.Warn(err)
                    }
    
                    return errors.New("connection closed")
                }
                if len(msg) == 0 || msg[len(msg)-2] != '
    ' {
                    return errors.New("protocol error")
                }
            } else { // read bulk line (binary safe)
                msg = make([]byte, fixedLen+2)
                _, err = io.ReadFull(reader, msg)
                if err != nil {
                    if err == io.EOF || err == io.ErrUnexpectedEOF {
                        return errors.New("connection closed")
                    } else {
                        return err
                    }
                }
                if len(msg) == 0 ||
                    msg[len(msg)-2] != '
    ' ||
                    msg[len(msg)-1] != '
    ' {
                    return errors.New("protocol error")
                }
                fixedLen = 0
            }
    
            // parse line
            if !downloading {
                // receive new response
                if msg[0] == '*' { // multi bulk response
                    // bulk multi msg
                    expectedLine, err := strconv.ParseUint(string(msg[1:len(msg)-2]), 10, 32)
                    if err != nil {
                        return errors.New("protocol error: " + err.Error())
                    }
                    if expectedLine == 0 {
                        client.finishRequest(&reply.EmptyMultiBulkReply{})
                    } else if expectedLine > 0 {
                        msgType = msg[0]
                        downloading = true
                        expectedArgsCount = int(expectedLine)
                        receivedCount = 0
                        args = make([][]byte, expectedLine)
                    } else {
                        return errors.New("protocol error")
                    }
                } else if msg[0] == '$' { // bulk response
                    fixedLen, err = strconv.ParseInt(string(msg[1:len(msg)-2]), 10, 64)
                    if err != nil {
                        return err
                    }
                    if fixedLen == -1 { // null bulk
                        client.finishRequest(&reply.NullBulkReply{})
                        fixedLen = 0
                    } else if fixedLen > 0 {
                        msgType = msg[0]
                        downloading = true
                        expectedArgsCount = 1
                        receivedCount = 0
                        args = make([][]byte, 1)
                    } else {
                        return errors.New("protocol error")
                    }
                } else { // single line response
                    str := strings.TrimSuffix(string(msg), "
    ")
                    str = strings.TrimSuffix(str, "
    ")
                    var result redis.Reply
                    switch msg[0] {
                    case '+':
                        result = reply.MakeStatusReply(str[1:])
                    case '-':
                        result = reply.MakeErrReply(str[1:])
                    case ':':
                        val, err := strconv.ParseInt(str[1:], 10, 64)
                        if err != nil {
                            return errors.New("protocol error")
                        }
                        result = reply.MakeIntReply(val)
                    }
                    client.finishRequest(result)
                }
            } else {
                // receive following part of a request
                line := msg[0 : len(msg)-2]
                if line[0] == '$' {
                    fixedLen, err = strconv.ParseInt(string(line[1:]), 10, 64)
                    if err != nil {
                        return err
                    }
                    if fixedLen <= 0 { // null bulk in multi bulks
                        args[receivedCount] = []byte{}
                        receivedCount++
                        fixedLen = 0
                    }
                } else {
                    args[receivedCount] = line
                    receivedCount++
                }
    
                // if sending finished
                if receivedCount == expectedArgsCount {
                    downloading = false // finish downloading progress
    
                    if msgType == '*' {
                        reply := reply.MakeMultiBulkReply(args)
                        client.finishRequest(reply)
                    } else if msgType == '$' {
                        reply := reply.MakeBulkReply(args[0])
                        client.finishRequest(reply)
                    }
    
    
                    // finish reply
                    expectedArgsCount = 0
                    receivedCount = 0
                    args = nil
                    msgType = byte(0)
                }
            }
        }
    }
    

    最后编写 client 的构造器和启动异步协程的代码:

    func MakeClient(addr string) (*Client, error) {
        conn, err := net.Dial("tcp", addr)
        if err != nil {
            return nil, err
        }
        ctx, cancel := context.WithCancel(context.Background())
        return &Client{
            addr:        addr,
            conn:        conn,
            sendingReqs: make(chan *Request, chanSize),
            waitingReqs: make(chan *Request, chanSize),
            ctx:         ctx,
            cancelFunc:  cancel,
            writing:     &sync.WaitGroup{},
        }, nil
    }
    
    func (client *Client) Start() {
        client.ticker = time.NewTicker(10 * time.Second)
        go client.handleWrite()
        go func() {
            err := client.handleRead()
            logger.Warn(err)
        }()
        go client.heartbeat()
    }
    

    关闭 client 的时候记得等待请求完成:

    func (client *Client) Close() {
        // 先阻止新请求进入队列
        close(client.sendingReqs)
    
        // 等待处理中的请求完成
        client.writing.Wait()
    
        // 释放资源
        _ = client.conn.Close() // 关闭与服务端的连接,连接关闭后读协程会退出
        client.cancelFunc() // 使用 context 关闭读协程
        close(client.waitingReqs) // 关闭队列
    }
    

    测试一下:

    func TestClient(t *testing.T) {
        client, err := MakeClient("localhost:6379")
        if err != nil {
            t.Error(err)
        }
        client.Start()
    
        result = client.Send([][]byte{
            []byte("SET"),
            []byte("a"),
            []byte("a"),
        })
        if statusRet, ok := result.(*reply.StatusReply); ok {
            if statusRet.Status != "OK" {
                t.Error("`set` failed, result: " + statusRet.Status)
            }
        }
    
        result = client.Send([][]byte{
            []byte("GET"),
            []byte("a"),
        })
        if bulkRet, ok := result.(*reply.BulkReply); ok {
            if string(bulkRet.Arg) != "a" {
                t.Error("`get` failed, result: " + string(bulkRet.Arg))
            }
        }
    }
    
  • 相关阅读:
    LeetCode题解之Flipping an Image
    LeetCode 之Find Minimum in Rotated Sorted Array
    LeetCode题解Transpose Matrix
    LeetCode 题解之Minimum Index Sum of Two Lists
    LeetCode题解之Intersection of Two Linked Lists
    LeetCode 题解之Add Two Numbers II
    LeetCode题解之Add two numbers
    href="#"与href="javascript:void(0)"的区别
    有关ie9 以下不支持placeholder属性以及获得焦点placeholder的移除
    ie7下属性书写不规范造成的easyui 弹窗布局紊乱
  • 原文地址:https://www.cnblogs.com/Finley/p/14028402.html
Copyright © 2011-2022 走看看