zoukankan      html  css  js  c++  java
  • protocol_v2.go

    package nsqd

    import (
        "bytes"
        "encoding/binary"
        "encoding/json"
        "errors"
        "fmt"
        "io"
        "math"
        "math/rand"
        "net"
        "sync/atomic"
        "time"
        "unsafe"

        "github.com/nsqio/nsq/internal/protocol"
        "github.com/nsqio/nsq/internal/version"
    )

    // maxTimeout is one hour. NOTE(review): it is not referenced in the part of
    // this file shown here; presumably it caps a client-supplied timeout elsewhere.
    const maxTimeout = time.Hour

    // Frame type identifiers handed to protocol.SendFramedResponse by Send.
    const (
        frameTypeResponse int32 = 0 // command responses, including heartbeats
        frameTypeError    int32 = 1 // error descriptions returned by Exec
        frameTypeMessage  int32 = 2 // serialized messages written by SendMessage
    )

    // Byte literals shared by the protocol handlers so they are allocated once.
    var (
        separatorBytes = []byte(" ")           // splits a command line into params
        heartbeatBytes = []byte("_heartbeat_") // payload of heartbeat responses
        okBytes        = []byte("OK")          // generic success response
    )

    type protocolV2 struct {
        ctx *context
    }

    func (p *protocolV2) IOLoop(conn net.Conn) error {
        var err error
        var line []byte
        var zeroTime time.Time

        clientID := atomic.AddInt64(&p.ctx.nsqd.clientIDSequence, 1)
        client := newClientV2(clientID, conn, p.ctx)

        // synchronize the startup of messagePump in order
        // to guarantee that it gets a chance to initialize
        // goroutine local state derived from client attributes
        // and avoid a potential race with IDENTIFY (where a client
        // could have changed or disabled said attributes)
        messagePumpStartedChan := make(chan bool)
        go p.messagePump(client, messagePumpStartedChan)
        <-messagePumpStartedChan

        for {
            if client.HeartbeatInterval > 0 {
                client.SetReadDeadline(time.Now().Add(client.HeartbeatInterval * 2))
            } else {
                client.SetReadDeadline(zeroTime)
            }

            // ReadSlice does not allocate new space for the data each request
            // ie. the returned slice is only valid until the next call to it
            line, err = client.Reader.ReadSlice('
    ')
            if err != nil {
                if err == io.EOF {
                    err = nil
                } else {
                    err = fmt.Errorf("failed to read command - %s", err)
                }
                break
            }

            // trim the '
    '
            line = line[:len(line)-1]
            // optionally trim the '
    '
            if len(line) > 0 && line[len(line)-1] == '
    ' {
                line = line[:len(line)-1]
            }
            params := bytes.Split(line, separatorBytes)

            if p.ctx.nsqd.getOpts().Verbose {
                p.ctx.nsqd.logf("PROTOCOL(V2): [%s] %s", client, params)
            }

            var response []byte
            response, err = p.Exec(client, params)
            if err != nil {
                ctx := ""
                if parentErr := err.(protocol.ChildErr).Parent(); parentErr != nil {
                    ctx = " - " + parentErr.Error()
                }
                p.ctx.nsqd.logf("ERROR: [%s] - %s%s", client, err, ctx)

                sendErr := p.Send(client, frameTypeError, []byte(err.Error()))
                if sendErr != nil {
                    p.ctx.nsqd.logf("ERROR: [%s] - %s%s", client, sendErr, ctx)
                    break
                }

                // errors of type FatalClientErr should forceably close the connection
                if _, ok := err.(*protocol.FatalClientErr); ok {
                    break
                }
                continue
            }

            if response != nil {
                err = p.Send(client, frameTypeResponse, response)
                if err != nil {
                    err = fmt.Errorf("failed to send response - %s", err)
                    break
                }
            }
        }

        p.ctx.nsqd.logf("PROTOCOL(V2): [%s] exiting ioloop", client)
        conn.Close()
        close(client.ExitChan)
        if client.Channel != nil {
            client.Channel.RemoveClient(client.ID)
        }

        return err
    }

    func (p *protocolV2) SendMessage(client *clientV2, msg *Message, buf *bytes.Buffer) error {
        if p.ctx.nsqd.getOpts().Verbose {
            p.ctx.nsqd.logf("PROTOCOL(V2): writing msg(%s) to client(%s) - %s",
                msg.ID, client, msg.Body)
        }

        buf.Reset()
        _, err := msg.WriteTo(buf)
        if err != nil {
            return err
        }

        err = p.Send(client, frameTypeMessage, buf.Bytes())
        if err != nil {
            return err
        }

        return nil
    }

    func (p *protocolV2) Send(client *clientV2, frameType int32, data []byte) error {
        client.writeLock.Lock()

        var zeroTime time.Time
        if client.HeartbeatInterval > 0 {
            client.SetWriteDeadline(time.Now().Add(client.HeartbeatInterval))
        } else {
            client.SetWriteDeadline(zeroTime)
        }

        _, err := protocol.SendFramedResponse(client.Writer, frameType, data)
        if err != nil {
            client.writeLock.Unlock()
            return err
        }

        if frameType != frameTypeMessage {
            err = client.Flush()
        }

        client.writeLock.Unlock()

        return err
    }

    func (p *protocolV2) Exec(client *clientV2, params [][]byte) ([]byte, error) {
        if bytes.Equal(params[0], []byte("IDENTIFY")) {
            return p.IDENTIFY(client, params)
        }
        err := enforceTLSPolicy(client, p, params[0])
        if err != nil {
            return nil, err
        }
        switch {
        case bytes.Equal(params[0], []byte("FIN")):
            return p.FIN(client, params)
        case bytes.Equal(params[0], []byte("RDY")):
            return p.RDY(client, params)
        case bytes.Equal(params[0], []byte("REQ")):
            return p.REQ(client, params)
        case bytes.Equal(params[0], []byte("PUB")):
            return p.PUB(client, params)
        case bytes.Equal(params[0], []byte("MPUB")):
            return p.MPUB(client, params)
        case bytes.Equal(params[0], []byte("DPUB")):
            return p.DPUB(client, params)
        case bytes.Equal(params[0], []byte("NOP")):
            return p.NOP(client, params)
        case bytes.Equal(params[0], []byte("TOUCH")):
            return p.TOUCH(client, params)
        case bytes.Equal(params[0], []byte("SUB")):
            return p.SUB(client, params)
        case bytes.Equal(params[0], []byte("CLS")):
            return p.CLS(client, params)
        case bytes.Equal(params[0], []byte("AUTH")):
            return p.AUTH(client, params)
        }
        return nil, protocol.NewFatalClientErr(nil, "E_INVALID", fmt.Sprintf("invalid command %s", params[0]))
    }

    func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) {
        var err error
        var buf bytes.Buffer
        var memoryMsgChan chan *Message
        var backendMsgChan chan []byte
        var subChannel *Channel
        // NOTE: `flusherChan` is used to bound message latency for
        // the pathological case of a channel on a low volume topic
        // with >1 clients having >1 RDY counts
        var flusherChan <-chan time.Time
        var sampleRate int32

        subEventChan := client.SubEventChan
        identifyEventChan := client.IdentifyEventChan
        outputBufferTicker := time.NewTicker(client.OutputBufferTimeout)
        heartbeatTicker := time.NewTicker(client.HeartbeatInterval)
        heartbeatChan := heartbeatTicker.C
        msgTimeout := client.MsgTimeout

        // v2 opportunistically buffers data to clients to reduce write system calls
        // we force flush in two cases:
        //    1. when the client is not ready to receive messages
        //    2. we're buffered and the channel has nothing left to send us
        //       (ie. we would block in this loop anyway)
        //
        flushed := true

        // signal to the goroutine that started the messagePump
        // that we've started up
        close(startedChan)

        for {
            if subChannel == nil || !client.IsReadyForMessages() {
                // the client is not ready to receive messages...
                memoryMsgChan = nil
                backendMsgChan = nil
                flusherChan = nil
                // force flush
                client.writeLock.Lock()
                err = client.Flush()
                client.writeLock.Unlock()
                if err != nil {
                    goto exit
                }
                flushed = true
            } else if flushed {
                // last iteration we flushed...
                // do not select on the flusher ticker channel
                memoryMsgChan = subChannel.memoryMsgChan
                backendMsgChan = subChannel.backend.ReadChan()
                flusherChan = nil
            } else {
                // we're buffered (if there isn't any more data we should flush)...
                // select on the flusher ticker channel, too
                memoryMsgChan = subChannel.memoryMsgChan
                backendMsgChan = subChannel.backend.ReadChan()
                flusherChan = outputBufferTicker.C
            }

            select {
            case <-flusherChan:
                // if this case wins, we're either starved
                // or we won the race between other channels...
                // in either case, force flush
                client.writeLock.Lock()
                err = client.Flush()
                client.writeLock.Unlock()
                if err != nil {
                    goto exit
                }
                flushed = true
            case <-client.ReadyStateChan:
            case subChannel = <-subEventChan:
                // you can't SUB anymore
                subEventChan = nil
            case identifyData := <-identifyEventChan:
                // you can't IDENTIFY anymore
                identifyEventChan = nil

                outputBufferTicker.Stop()
                if identifyData.OutputBufferTimeout > 0 {
                    outputBufferTicker = time.NewTicker(identifyData.OutputBufferTimeout)
                }

                heartbeatTicker.Stop()
                heartbeatChan = nil
                if identifyData.HeartbeatInterval > 0 {
                    heartbeatTicker = time.NewTicker(identifyData.HeartbeatInterval)
                    heartbeatChan = heartbeatTicker.C
                }

                if identifyData.SampleRate > 0 {
                    sampleRate = identifyData.SampleRate
                }

                msgTimeout = identifyData.MsgTimeout
            case <-heartbeatChan:
                err = p.Send(client, frameTypeResponse, heartbeatBytes)
                if err != nil {
                    goto exit
                }
            case b := <-backendMsgChan:
                if sampleRate > 0 && rand.Int31n(100) > sampleRate {
                    continue
                }

                msg, err := decodeMessage(b)
                if err != nil {
                    p.ctx.nsqd.logf("ERROR: failed to decode message - %s", err)
                    continue
                }
                msg.Attempts++

                subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
                client.SendingMessage()
                err = p.SendMessage(client, msg, &buf)
                if err != nil {
                    goto exit
                }
                flushed = false
            case msg := <-memoryMsgChan:
                if sampleRate > 0 && rand.Int31n(100) > sampleRate {
                    continue
                }
                msg.Attempts++

                subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
                client.SendingMessage()
                err = p.SendMessage(client, msg, &buf)
                if err != nil {
                    goto exit
                }
                flushed = false
            case <-client.ExitChan:
                goto exit
            }
        }

    exit:
        p.ctx.nsqd.logf("PROTOCOL(V2): [%s] exiting messagePump", client)
        heartbeatTicker.Stop()
        outputBufferTicker.Stop()
        if err != nil {
            p.ctx.nsqd.logf("PROTOCOL(V2): [%s] messagePump error - %s", client, err)
        }
    }

    func (p *protocolV2) IDENTIFY(client *clientV2, params [][]byte) ([]byte, error) {
        var err error

        if atomic.LoadInt32(&client.State) != stateInit {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot IDENTIFY in current state")
        }

        bodyLen, err := readLen(client.Reader, client.lenSlice)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body size")
        }

        if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxBodySize {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
                fmt.Sprintf("IDENTIFY body too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxBodySize))
        }

        if bodyLen <= 0 {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
                fmt.Sprintf("IDENTIFY invalid body size %d", bodyLen))
        }

        body := make([]byte, bodyLen)
        _, err = io.ReadFull(client.Reader, body)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body")
        }

        // body is a json structure with producer information
        var identifyData identifyDataV2
        err = json.Unmarshal(body, &identifyData)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to decode JSON body")
        }

        if p.ctx.nsqd.getOpts().Verbose {
            p.ctx.nsqd.logf("PROTOCOL(V2): [%s] %+v", client, identifyData)
        }

        err = client.Identify(identifyData)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY "+err.Error())
        }

        // bail out early if we're not negotiating features
        if !identifyData.FeatureNegotiation {
            return okBytes, nil
        }

        tlsv1 := p.ctx.nsqd.tlsConfig != nil && identifyData.TLSv1
        deflate := p.ctx.nsqd.getOpts().DeflateEnabled && identifyData.Deflate
        deflateLevel := 0
        if deflate {
            if identifyData.DeflateLevel <= 0 {
                deflateLevel = 6
            }
            deflateLevel = int(math.Min(float64(deflateLevel), float64(p.ctx.nsqd.getOpts().MaxDeflateLevel)))
        }
        snappy := p.ctx.nsqd.getOpts().SnappyEnabled && identifyData.Snappy

        if deflate && snappy {
            return nil, protocol.NewFatalClientErr(nil, "E_IDENTIFY_FAILED", "cannot enable both deflate and snappy compression")
        }

        resp, err := json.Marshal(struct {
            MaxRdyCount         int64  `json:"max_rdy_count"`
            Version             string `json:"version"`
            MaxMsgTimeout       int64  `json:"max_msg_timeout"`
            MsgTimeout          int64  `json:"msg_timeout"`
            TLSv1               bool   `json:"tls_v1"`
            Deflate             bool   `json:"deflate"`
            DeflateLevel        int    `json:"deflate_level"`
            MaxDeflateLevel     int    `json:"max_deflate_level"`
            Snappy              bool   `json:"snappy"`
            SampleRate          int32  `json:"sample_rate"`
            AuthRequired        bool   `json:"auth_required"`
            OutputBufferSize    int    `json:"output_buffer_size"`
            OutputBufferTimeout int64  `json:"output_buffer_timeout"`
        }{
            MaxRdyCount:         p.ctx.nsqd.getOpts().MaxRdyCount,
            Version:             version.Binary,
            MaxMsgTimeout:       int64(p.ctx.nsqd.getOpts().MaxMsgTimeout / time.Millisecond),
            MsgTimeout:          int64(client.MsgTimeout / time.Millisecond),
            TLSv1:               tlsv1,
            Deflate:             deflate,
            DeflateLevel:        deflateLevel,
            MaxDeflateLevel:     p.ctx.nsqd.getOpts().MaxDeflateLevel,
            Snappy:              snappy,
            SampleRate:          client.SampleRate,
            AuthRequired:        p.ctx.nsqd.IsAuthEnabled(),
            OutputBufferSize:    client.OutputBufferSize,
            OutputBufferTimeout: int64(client.OutputBufferTimeout / time.Millisecond),
        })
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
        }

        err = p.Send(client, frameTypeResponse, resp)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
        }

        if tlsv1 {
            p.ctx.nsqd.logf("PROTOCOL(V2): [%s] upgrading connection to TLS", client)
            err = client.UpgradeTLS()
            if err != nil {
                return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
            }

            err = p.Send(client, frameTypeResponse, okBytes)
            if err != nil {
                return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
            }
        }

        if snappy {
            p.ctx.nsqd.logf("PROTOCOL(V2): [%s] upgrading connection to snappy", client)
            err = client.UpgradeSnappy()
            if err != nil {
                return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
            }

            err = p.Send(client, frameTypeResponse, okBytes)
            if err != nil {
                return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
            }
        }

        if deflate {
            p.ctx.nsqd.logf("PROTOCOL(V2): [%s] upgrading connection to deflate", client)
            err = client.UpgradeDeflate(deflateLevel)
            if err != nil {
                return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
            }

            err = p.Send(client, frameTypeResponse, okBytes)
            if err != nil {
                return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
            }
        }

        return nil, nil
    }

    func (p *protocolV2) AUTH(client *clientV2, params [][]byte) ([]byte, error) {
        if atomic.LoadInt32(&client.State) != stateInit {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot AUTH in current state")
        }

        if len(params) != 1 {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH invalid number of parameters")
        }

        bodyLen, err := readLen(client.Reader, client.lenSlice)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body size")
        }

        if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxBodySize {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
                fmt.Sprintf("AUTH body too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxBodySize))
        }

        if bodyLen <= 0 {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
                fmt.Sprintf("AUTH invalid body size %d", bodyLen))
        }

        body := make([]byte, bodyLen)
        _, err = io.ReadFull(client.Reader, body)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body")
        }

        if client.HasAuthorizations() {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH Already set")
        }

        if !client.ctx.nsqd.IsAuthEnabled() {
            return nil, protocol.NewFatalClientErr(err, "E_AUTH_DISABLED", "AUTH Disabled")
        }

        if err := client.Auth(string(body)); err != nil {
            // we don't want to leak errors contacting the auth server to untrusted clients
            p.ctx.nsqd.logf("PROTOCOL(V2): [%s] Auth Failed %s", client, err)
            return nil, protocol.NewFatalClientErr(err, "E_AUTH_FAILED", "AUTH failed")
        }

        if !client.HasAuthorizations() {
            return nil, protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED", "AUTH No authorizations found")
        }

        resp, err := json.Marshal(struct {
            Identity        string `json:"identity"`
            IdentityURL     string `json:"identity_url"`
            PermissionCount int    `json:"permission_count"`
        }{
            Identity:        client.AuthState.Identity,
            IdentityURL:     client.AuthState.IdentityURL,
            PermissionCount: len(client.AuthState.Authorizations),
        })
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
        }

        err = p.Send(client, frameTypeResponse, resp)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
        }

        return nil, nil

    }

    func (p *protocolV2) CheckAuth(client *clientV2, cmd, topicName, channelName string) error {
        // if auth is enabled, the client must have authorized already
        // compare topic/channel against cached authorization data (refetching if expired)
        if client.ctx.nsqd.IsAuthEnabled() {
            if !client.HasAuthorizations() {
                return protocol.NewFatalClientErr(nil, "E_AUTH_FIRST",
                    fmt.Sprintf("AUTH required before %s", cmd))
            }
            ok, err := client.IsAuthorized(topicName, channelName)
            if err != nil {
                // we don't want to leak errors contacting the auth server to untrusted clients
                p.ctx.nsqd.logf("PROTOCOL(V2): [%s] Auth Failed %s", client, err)
                return protocol.NewFatalClientErr(nil, "E_AUTH_FAILED", "AUTH failed")
            }
            if !ok {
                return protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED",
                    fmt.Sprintf("AUTH failed for %s on %q %q", cmd, topicName, channelName))
            }
        }
        return nil
    }

    func (p *protocolV2) SUB(client *clientV2, params [][]byte) ([]byte, error) {
        if atomic.LoadInt32(&client.State) != stateInit {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB in current state")
        }

        if client.HeartbeatInterval <= 0 {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB with heartbeats disabled")
        }

        if len(params) < 3 {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "SUB insufficient number of parameters")
        }

        topicName := string(params[1])
        if !protocol.IsValidTopicName(topicName) {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
                fmt.Sprintf("SUB topic name %q is not valid", topicName))
        }

        channelName := string(params[2])
        if !protocol.IsValidChannelName(channelName) {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_CHANNEL",
                fmt.Sprintf("SUB channel name %q is not valid", channelName))
        }

        if err := p.CheckAuth(client, "SUB", topicName, channelName); err != nil {
            return nil, err
        }

        topic := p.ctx.nsqd.GetTopic(topicName)
        channel := topic.GetChannel(channelName)
        channel.AddClient(client.ID, client)

        atomic.StoreInt32(&client.State, stateSubscribed)
        client.Channel = channel
        // update message pump
        client.SubEventChan <- channel

        return okBytes, nil
    }

    func (p *protocolV2) RDY(client *clientV2, params [][]byte) ([]byte, error) {
        state := atomic.LoadInt32(&client.State)

        if state == stateClosing {
            // just ignore ready changes on a closing channel
            p.ctx.nsqd.logf(
                "PROTOCOL(V2): [%s] ignoring RDY after CLS in state ClientStateV2Closing",
                client)
            return nil, nil
        }

        if state != stateSubscribed {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot RDY in current state")
        }

        count := int64(1)
        if len(params) > 1 {
            b10, err := protocol.ByteToBase10(params[1])
            if err != nil {
                return nil, protocol.NewFatalClientErr(err, "E_INVALID",
                    fmt.Sprintf("RDY could not parse count %s", params[1]))
            }
            count = int64(b10)
        }

        if count < 0 || count > p.ctx.nsqd.getOpts().MaxRdyCount {
            // this needs to be a fatal error otherwise clients would have
            // inconsistent state
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
                fmt.Sprintf("RDY count %d out of range 0-%d", count, p.ctx.nsqd.getOpts().MaxRdyCount))
        }

        client.SetReadyCount(count)

        return nil, nil
    }

    func (p *protocolV2) FIN(client *clientV2, params [][]byte) ([]byte, error) {
        state := atomic.LoadInt32(&client.State)
        if state != stateSubscribed && state != stateClosing {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot FIN in current state")
        }

        if len(params) < 2 {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "FIN insufficient number of params")
        }

        id, err := getMessageID(params[1])
        if err != nil {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
        }

        err = client.Channel.FinishMessage(client.ID, *id)
        if err != nil {
            return nil, protocol.NewClientErr(err, "E_FIN_FAILED",
                fmt.Sprintf("FIN %s failed %s", *id, err.Error()))
        }

        client.FinishedMessage()

        return nil, nil
    }

    func (p *protocolV2) REQ(client *clientV2, params [][]byte) ([]byte, error) {
        state := atomic.LoadInt32(&client.State)
        if state != stateSubscribed && state != stateClosing {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot REQ in current state")
        }

        if len(params) < 3 {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "REQ insufficient number of params")
        }

        id, err := getMessageID(params[1])
        if err != nil {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
        }

        timeoutMs, err := protocol.ByteToBase10(params[2])
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_INVALID",
                fmt.Sprintf("REQ could not parse timeout %s", params[2]))
        }
        timeoutDuration := time.Duration(timeoutMs) * time.Millisecond

        if timeoutDuration < 0 || timeoutDuration > p.ctx.nsqd.getOpts().MaxReqTimeout {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
                fmt.Sprintf("REQ timeout %d out of range 0-%d", timeoutDuration, p.ctx.nsqd.getOpts().MaxReqTimeout))
        }

        err = client.Channel.RequeueMessage(client.ID, *id, timeoutDuration)
        if err != nil {
            return nil, protocol.NewClientErr(err, "E_REQ_FAILED",
                fmt.Sprintf("REQ %s failed %s", *id, err.Error()))
        }

        client.RequeuedMessage()

        return nil, nil
    }

    func (p *protocolV2) CLS(client *clientV2, params [][]byte) ([]byte, error) {
        if atomic.LoadInt32(&client.State) != stateSubscribed {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot CLS in current state")
        }

        client.StartClose()

        return []byte("CLOSE_WAIT"), nil
    }

    func (p *protocolV2) NOP(client *clientV2, params [][]byte) ([]byte, error) {
        return nil, nil
    }

    func (p *protocolV2) PUB(client *clientV2, params [][]byte) ([]byte, error) {
        var err error

        if len(params) < 2 {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "PUB insufficient number of parameters")
        }

        topicName := string(params[1])
        if !protocol.IsValidTopicName(topicName) {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
                fmt.Sprintf("PUB topic name %q is not valid", topicName))
        }

        bodyLen, err := readLen(client.Reader, client.lenSlice)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body size")
        }

        if bodyLen <= 0 {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
                fmt.Sprintf("PUB invalid message body size %d", bodyLen))
        }

        if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxMsgSize {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
                fmt.Sprintf("PUB message too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxMsgSize))
        }

        messageBody := make([]byte, bodyLen)
        _, err = io.ReadFull(client.Reader, messageBody)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body")
        }

        if err := p.CheckAuth(client, "PUB", topicName, ""); err != nil {
            return nil, err
        }

        topic := p.ctx.nsqd.GetTopic(topicName)
        msg := NewMessage(<-p.ctx.nsqd.idChan, messageBody)
        err = topic.PutMessage(msg)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_PUB_FAILED", "PUB failed "+err.Error())
        }

        return okBytes, nil
    }

    func (p *protocolV2) MPUB(client *clientV2, params [][]byte) ([]byte, error) {
        var err error

        if len(params) < 2 {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "MPUB insufficient number of parameters")
        }

        topicName := string(params[1])
        if !protocol.IsValidTopicName(topicName) {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
                fmt.Sprintf("E_BAD_TOPIC MPUB topic name %q is not valid", topicName))
        }

        bodyLen, err := readLen(client.Reader, client.lenSlice)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read body size")
        }

        if bodyLen <= 0 {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
                fmt.Sprintf("MPUB invalid body size %d", bodyLen))
        }

        if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxBodySize {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
                fmt.Sprintf("MPUB body too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxBodySize))
        }

        messages, err := readMPUB(client.Reader, client.lenSlice, p.ctx.nsqd.idChan,
            p.ctx.nsqd.getOpts().MaxMsgSize)
        if err != nil {
            return nil, err
        }

        if err := p.CheckAuth(client, "MPUB", topicName, ""); err != nil {
            return nil, err
        }

        topic := p.ctx.nsqd.GetTopic(topicName)

        // if we've made it this far we've validated all the input,
        // the only possible error is that the topic is exiting during
        // this next call (and no messages will be queued in that case)
        err = topic.PutMessages(messages)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_MPUB_FAILED", "MPUB failed "+err.Error())
        }

        return okBytes, nil
    }

    func (p *protocolV2) DPUB(client *clientV2, params [][]byte) ([]byte, error) {
        var err error

        if len(params) < 3 {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "DPUB insufficient number of parameters")
        }

        topicName := string(params[1])
        if !protocol.IsValidTopicName(topicName) {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
                fmt.Sprintf("DPUB topic name %q is not valid", topicName))
        }

        timeoutMs, err := protocol.ByteToBase10(params[2])
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_INVALID",
                fmt.Sprintf("DPUB could not parse timeout %s", params[2]))
        }
        timeoutDuration := time.Duration(timeoutMs) * time.Millisecond

        if timeoutDuration < 0 || timeoutDuration > p.ctx.nsqd.getOpts().MaxReqTimeout {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
                fmt.Sprintf("DPUB timeout %d out of range 0-%d",
                    timeoutMs, p.ctx.nsqd.getOpts().MaxReqTimeout/time.Millisecond))
        }

        bodyLen, err := readLen(client.Reader, client.lenSlice)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body size")
        }

        if bodyLen <= 0 {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
                fmt.Sprintf("DPUB invalid message body size %d", bodyLen))
        }

        if int64(bodyLen) > p.ctx.nsqd.getOpts().MaxMsgSize {
            return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
                fmt.Sprintf("DPUB message too big %d > %d", bodyLen, p.ctx.nsqd.getOpts().MaxMsgSize))
        }

        messageBody := make([]byte, bodyLen)
        _, err = io.ReadFull(client.Reader, messageBody)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body")
        }

        if err := p.CheckAuth(client, "DPUB", topicName, ""); err != nil {
            return nil, err
        }

        topic := p.ctx.nsqd.GetTopic(topicName)
        msg := NewMessage(<-p.ctx.nsqd.idChan, messageBody)
        msg.deferred = timeoutDuration
        err = topic.PutMessage(msg)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_DPUB_FAILED", "DPUB failed "+err.Error())
        }

        return okBytes, nil
    }

    func (p *protocolV2) TOUCH(client *clientV2, params [][]byte) ([]byte, error) {
        state := atomic.LoadInt32(&client.State)
        if state != stateSubscribed && state != stateClosing {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot TOUCH in current state")
        }

        if len(params) < 2 {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "TOUCH insufficient number of params")
        }

        id, err := getMessageID(params[1])
        if err != nil {
            return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
        }

        client.writeLock.RLock()
        msgTimeout := client.MsgTimeout
        client.writeLock.RUnlock()
        err = client.Channel.TouchMessage(client.ID, *id, msgTimeout)
        if err != nil {
            return nil, protocol.NewClientErr(err, "E_TOUCH_FAILED",
                fmt.Sprintf("TOUCH %s failed %s", *id, err.Error()))
        }

        return nil, nil
    }

    func readMPUB(r io.Reader, tmp []byte, idChan chan MessageID, maxMessageSize int64) ([]*Message, error) {
        numMessages, err := readLen(r, tmp)
        if err != nil {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read message count")
        }

        if numMessages <= 0 {
            return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY",
                fmt.Sprintf("MPUB invalid message count %d", numMessages))
        }

        messages := make([]*Message, 0, numMessages)
        for i := int32(0); i < numMessages; i++ {
            messageSize, err := readLen(r, tmp)
            if err != nil {
                return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE",
                    fmt.Sprintf("MPUB failed to read message(%d) body size", i))
            }

            if messageSize <= 0 {
                return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
                    fmt.Sprintf("MPUB invalid message(%d) body size %d", i, messageSize))
            }

            if int64(messageSize) > maxMessageSize {
                return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
                    fmt.Sprintf("MPUB message too big %d > %d", messageSize, maxMessageSize))
            }

            msgBody := make([]byte, messageSize)
            _, err = io.ReadFull(r, msgBody)
            if err != nil {
                return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "MPUB failed to read message body")
            }

            messages = append(messages, NewMessage(<-idChan, msgBody))
        }

        return messages, nil
    }

    // validate and cast the bytes on the wire to a message ID
    func getMessageID(p []byte) (*MessageID, error) {
        if len(p) != MsgIDLength {
            return nil, errors.New("Invalid Message ID")
        }
        return (*MessageID)(unsafe.Pointer(&p[0])), nil
    }

    // readLen fills tmp completely from r (io.ReadFull) and decodes its first four
    // bytes as a big-endian uint32, returned as an int32. Values >= 2^31 therefore
    // come back negative.
    //
    // tmp must be at least 4 bytes long, or binary.BigEndian.Uint32 panics. On a
    // read error it returns 0 and the error from io.ReadFull (io.EOF when nothing
    // was read, io.ErrUnexpectedEOF when tmp was only partially filled).
    func readLen(r io.Reader, tmp []byte) (int32, error) {
        if _, readErr := io.ReadFull(r, tmp); readErr != nil {
            return 0, readErr
        }
        raw := binary.BigEndian.Uint32(tmp)
        return int32(raw), nil
    }

    func enforceTLSPolicy(client *clientV2, p *protocolV2, command []byte) error {
        if p.ctx.nsqd.getOpts().TLSRequired != TLSNotRequired && atomic.LoadInt32(&client.TLS) != 1 {
            return protocol.NewFatalClientErr(nil, "E_INVALID",
                fmt.Sprintf("cannot %s in current state (TLS required)", command))
        }
        return nil
    }

  • 相关阅读:
    [linux] shell脚本编程-ubuntu创建vsftpd服务
    [linux] C语言Linux系统编程-做成守护进程
    [编程] C语言Linux系统编程-等待终止的子进程(僵死进程)
    [Linux]C语言Linux系统编程创建进程
    [linux] C语言Linux系统编程进程基本概念
    [编程] C语言枚举类型(Enum)
    [编程] C语言结构体指针作为函数参数
    [编程] C语言的二级指针
    [编程] C语言的结构体
    [编程] C语言循环结构计算π的值
  • 原文地址:https://www.cnblogs.com/zhangboyu/p/7457356.html
Copyright © 2011-2022 走看看