zoukankan      html  css  js  c++  java
  • nsqd.go

    package nsqd

    import (
        "crypto/tls"
        "crypto/x509"
        "encoding/json"
        "errors"
        "fmt"
        "io/ioutil"
        "math/rand"
        "net"
        "os"
        "path"
        "runtime"
        "strings"
        "sync"
        "sync/atomic"
        "time"

        "github.com/nsqio/nsq/internal/clusterinfo"
        "github.com/nsqio/nsq/internal/dirlock"
        "github.com/nsqio/nsq/internal/http_api"
        "github.com/nsqio/nsq/internal/protocol"
        "github.com/nsqio/nsq/internal/statsd"
        "github.com/nsqio/nsq/internal/util"
        "github.com/nsqio/nsq/internal/version"
    )

    // TLS requirement levels. Options.TLSRequired is compared against these
    // values in New and Main.
    const (
        // TLSNotRequired (0) is the level at which New does not insist on a
        // TLS key and cert being configured.
        TLSNotRequired = iota
        // TLSRequiredExceptHTTP (1): presumably TLS is required everywhere but
        // the plain HTTP listener — confirm against the protocol/HTTP code.
        TLSRequiredExceptHTTP
        // TLSRequired (2) is the only level for which Main passes a true
        // "TLS required" flag to the plain HTTP server.
        TLSRequired
    )

    // errStore wraps an error (possibly nil) so that NSQD.errValue, an
    // atomic.Value, always holds a value of the same concrete type.
    type errStore struct {
        err error
    }

    // NSQD is the daemon's top-level state: options, listeners, the topic
    // registry and the channels used to coordinate background goroutines.
    type NSQD struct {
        // 64bit atomic vars need to be first for proper alignment on 32bit platforms
        clientIDSequence int64

        // Embedded lock; guards topicMap and the listener fields.
        sync.RWMutex

        // Holds the current *Options; read via getOpts, replaced via swapOpts.
        opts atomic.Value

        // Lock on the data directory, acquired in New and released in Exit.
        dl        *dirlock.DirLock
        // Set to 1 (atomically) while LoadMetadata runs, 0 otherwise.
        isLoading int32
        // Holds an errStore describing the current health (see SetHealth).
        errValue  atomic.Value
        // Time at which New was called.
        startTime time.Time

        // Topics by name.
        topicMap map[string]*Topic

        // NOTE(review): not written anywhere in this section; presumably the
        // set of lookupd peers maintained by lookupLoop — confirm.
        lookupPeers atomic.Value

        // Listeners opened by Main; httpsListener stays nil unless a TLS
        // config and an HTTPS address are both present.
        tcpListener   net.Listener
        httpListener  net.Listener
        httpsListener net.Listener
        // Result of buildTLSConfig; nil when no cert/key is configured.
        tlsConfig     *tls.Config

        // Number of queueScanWorker goroutines currently accounted for by
        // resizePool.
        poolSize int

        // Buffered stream of message IDs produced by idPump.
        idChan               chan MessageID
        // Unbuffered channel on which Notify delivers its argument.
        notifyChan           chan interface{}
        // Capacity-1 channel signalled (non-blocking) by triggerOptsNotification.
        optsNotificationChan chan struct{}
        // Closed by Exit to stop the background goroutines.
        exitChan             chan int
        // Tracks goroutines started through Wrap; Exit waits on it.
        waitGroup            util.WaitGroupWrapper

        // Client used by GetTopic to ask lookupd for a topic's channels.
        ci *clusterinfo.ClusterInfo
    }

    func New(opts *Options) *NSQD {
        dataPath := opts.DataPath
        if opts.DataPath == "" {
            cwd, _ := os.Getwd()
            dataPath = cwd
        }

        n := &NSQD{
            startTime:            time.Now(),
            topicMap:             make(map[string]*Topic),
            idChan:               make(chan MessageID, 4096),
            exitChan:             make(chan int),
            notifyChan:           make(chan interface{}),
            optsNotificationChan: make(chan struct{}, 1),
            ci:                   clusterinfo.New(opts.Logger, http_api.NewClient(nil, opts.HTTPClientConnectTimeout, opts.HTTPClientRequestTimeout)),
            dl:                   dirlock.New(dataPath),
        }
        n.swapOpts(opts)
        n.errValue.Store(errStore{})

        err := n.dl.Lock()
        if err != nil {
            n.logf("FATAL: --data-path=%s in use (possibly by another instance of nsqd)", dataPath)
            os.Exit(1)
        }

        if opts.MaxDeflateLevel < 1 || opts.MaxDeflateLevel > 9 {
            n.logf("FATAL: --max-deflate-level must be [1,9]")
            os.Exit(1)
        }

        if opts.ID < 0 || opts.ID >= 1024 {
            n.logf("FATAL: --worker-id must be [0,1024)")
            os.Exit(1)
        }

        if opts.StatsdPrefix != "" {
            var port string
            _, port, err = net.SplitHostPort(opts.HTTPAddress)
            if err != nil {
                n.logf("ERROR: failed to parse HTTP address (%s) - %s", opts.HTTPAddress, err)
                os.Exit(1)
            }
            statsdHostKey := statsd.HostKey(net.JoinHostPort(opts.BroadcastAddress, port))
            prefixWithHost := strings.Replace(opts.StatsdPrefix, "%s", statsdHostKey, -1)
            if prefixWithHost[len(prefixWithHost)-1] != '.' {
                prefixWithHost += "."
            }
            opts.StatsdPrefix = prefixWithHost
        }

        if opts.TLSClientAuthPolicy != "" && opts.TLSRequired == TLSNotRequired {
            opts.TLSRequired = TLSRequired
        }

        tlsConfig, err := buildTLSConfig(opts)
        if err != nil {
            n.logf("FATAL: failed to build TLS config - %s", err)
            os.Exit(1)
        }
        if tlsConfig == nil && opts.TLSRequired != TLSNotRequired {
            n.logf("FATAL: cannot require TLS client connections without TLS key and cert")
            os.Exit(1)
        }
        n.tlsConfig = tlsConfig

        n.logf(version.String("nsqd"))
        n.logf("ID: %d", opts.ID)

        return n
    }

    func (n *NSQD) logf(f string, args ...interface{}) {
        if n.getOpts().Logger == nil {
            return
        }
        n.getOpts().Logger.Output(2, fmt.Sprintf(f, args...))
    }

    func (n *NSQD) getOpts() *Options {
        return n.opts.Load().(*Options)
    }

    // swapOpts atomically replaces the options returned by getOpts.
    func (n *NSQD) swapOpts(opts *Options) {
        n.opts.Store(opts)
    }

    func (n *NSQD) triggerOptsNotification() {
        select {
        case n.optsNotificationChan <- struct{}{}:
        default:
        }
    }

    func (n *NSQD) RealTCPAddr() *net.TCPAddr {
        n.RLock()
        defer n.RUnlock()
        return n.tcpListener.Addr().(*net.TCPAddr)
    }

    func (n *NSQD) RealHTTPAddr() *net.TCPAddr {
        n.RLock()
        defer n.RUnlock()
        return n.httpListener.Addr().(*net.TCPAddr)
    }

    func (n *NSQD) RealHTTPSAddr() *net.TCPAddr {
        n.RLock()
        defer n.RUnlock()
        return n.httpsListener.Addr().(*net.TCPAddr)
    }

    func (n *NSQD) SetHealth(err error) {
        n.errValue.Store(errStore{err: err})
    }

    func (n *NSQD) IsHealthy() bool {
        return n.GetError() == nil
    }

    func (n *NSQD) GetError() error {
        errValue := n.errValue.Load()
        return errValue.(errStore).err
    }

    func (n *NSQD) GetHealth() string {
        err := n.GetError()
        if err != nil {
            return fmt.Sprintf("NOK - %s", err)
        }
        return "OK"
    }

    // GetStartTime returns the time recorded when New created this NSQD.
    func (n *NSQD) GetStartTime() time.Time {
        return n.startTime
    }

    func (n *NSQD) Main() {
        var httpListener net.Listener
        var httpsListener net.Listener

        ctx := &context{n}

        tcpListener, err := net.Listen("tcp", n.getOpts().TCPAddress)
        if err != nil {
            n.logf("FATAL: listen (%s) failed - %s", n.getOpts().TCPAddress, err)
            os.Exit(1)
        }
        n.Lock()
        n.tcpListener = tcpListener
        n.Unlock()
        tcpServer := &tcpServer{ctx: ctx}
        n.waitGroup.Wrap(func() {
            protocol.TCPServer(n.tcpListener, tcpServer, n.getOpts().Logger)
        })

        if n.tlsConfig != nil && n.getOpts().HTTPSAddress != "" {
            httpsListener, err = tls.Listen("tcp", n.getOpts().HTTPSAddress, n.tlsConfig)
            if err != nil {
                n.logf("FATAL: listen (%s) failed - %s", n.getOpts().HTTPSAddress, err)
                os.Exit(1)
            }
            n.Lock()
            n.httpsListener = httpsListener
            n.Unlock()
            httpsServer := newHTTPServer(ctx, true, true)
            n.waitGroup.Wrap(func() {
                http_api.Serve(n.httpsListener, httpsServer, "HTTPS", n.getOpts().Logger)
            })
        }
        httpListener, err = net.Listen("tcp", n.getOpts().HTTPAddress)
        if err != nil {
            n.logf("FATAL: listen (%s) failed - %s", n.getOpts().HTTPAddress, err)
            os.Exit(1)
        }
        n.Lock()
        n.httpListener = httpListener
        n.Unlock()
        httpServer := newHTTPServer(ctx, false, n.getOpts().TLSRequired == TLSRequired)
        n.waitGroup.Wrap(func() {
            http_api.Serve(n.httpListener, httpServer, "HTTP", n.getOpts().Logger)
        })

        n.waitGroup.Wrap(func() { n.queueScanLoop() })
        n.waitGroup.Wrap(func() { n.idPump() })
        n.waitGroup.Wrap(func() { n.lookupLoop() })
        if n.getOpts().StatsdAddress != "" {
            n.waitGroup.Wrap(func() { n.statsdLoop() })
        }
    }

    // meta is the decoded form of the metadata file read by LoadMetadata:
    // a list of topics, each with its paused flag and its channels (name and
    // paused flag). Other keys in the file are ignored when decoding.
    type meta struct {
        Topics []struct {
            Name     string `json:"name"`
            Paused   bool   `json:"paused"`
            Channels []struct {
                Name   string `json:"name"`
                Paused bool   `json:"paused"`
            } `json:"channels"`
        } `json:"topics"`
    }

    func (n *NSQD) LoadMetadata() {
        atomic.StoreInt32(&n.isLoading, 1)
        defer atomic.StoreInt32(&n.isLoading, 0)

        fn := fmt.Sprintf(path.Join(n.getOpts().DataPath, "nsqd.%d.dat"), n.getOpts().ID)
        data, err := ioutil.ReadFile(fn)
        if err != nil {
            if !os.IsNotExist(err) {
                n.logf("ERROR: failed to read channel metadata from %s - %s", fn, err)
            }
            return
        }

        var m meta
        err = json.Unmarshal(data, &m)
        if err != nil {
            n.logf("ERROR: failed to parse metadata - %s", err)
            return
        }

        for _, t := range m.Topics {
            if !protocol.IsValidTopicName(t.Name) {
                n.logf("WARNING: skipping creation of invalid topic %s", t.Name)
                continue
            }
            topic := n.GetTopic(t.Name)
            if t.Paused {
                topic.Pause()
            }

            for _, c := range t.Channels {
                if !protocol.IsValidChannelName(c.Name) {
                    n.logf("WARNING: skipping creation of invalid channel %s", c.Name)
                    continue
                }
                channel := topic.GetChannel(c.Name)
                if c.Paused {
                    channel.Pause()
                }
            }
        }
    }

    func (n *NSQD) PersistMetadata() error {
        // persist metadata about what topics/channels we have
        // so that upon restart we can get back to the same state
        fileName := fmt.Sprintf(path.Join(n.getOpts().DataPath, "nsqd.%d.dat"), n.getOpts().ID)
        n.logf("NSQ: persisting topic/channel metadata to %s", fileName)

        js := make(map[string]interface{})
        topics := []interface{}{}
        for _, topic := range n.topicMap {
            if topic.ephemeral {
                continue
            }
            topicData := make(map[string]interface{})
            topicData["name"] = topic.name
            topicData["paused"] = topic.IsPaused()
            channels := []interface{}{}
            topic.Lock()
            for _, channel := range topic.channelMap {
                channel.Lock()
                if channel.ephemeral {
                    channel.Unlock()
                    continue
                }
                channelData := make(map[string]interface{})
                channelData["name"] = channel.name
                channelData["paused"] = channel.IsPaused()
                channels = append(channels, channelData)
                channel.Unlock()
            }
            topic.Unlock()
            topicData["channels"] = channels
            topics = append(topics, topicData)
        }
        js["version"] = version.Binary
        js["topics"] = topics

        data, err := json.Marshal(&js)
        if err != nil {
            return err
        }

        tmpFileName := fmt.Sprintf("%s.%d.tmp", fileName, rand.Int())
        f, err := os.OpenFile(tmpFileName, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
        if err != nil {
            return err
        }

        _, err = f.Write(data)
        if err != nil {
            f.Close()
            return err
        }
        f.Sync()
        f.Close()

        err = atomicRename(tmpFileName, fileName)
        if err != nil {
            return err
        }

        return nil
    }

    func (n *NSQD) Exit() {
        if n.tcpListener != nil {
            n.tcpListener.Close()
        }

        if n.httpListener != nil {
            n.httpListener.Close()
        }

        if n.httpsListener != nil {
            n.httpsListener.Close()
        }

        n.Lock()
        err := n.PersistMetadata()
        if err != nil {
            n.logf("ERROR: failed to persist metadata - %s", err)
        }
        n.logf("NSQ: closing topics")
        for _, topic := range n.topicMap {
            topic.Close()
        }
        n.Unlock()

        // we want to do this last as it closes the idPump (if closed first it
        // could potentially starve items in process and deadlock)
        close(n.exitChan)
        n.waitGroup.Wait()

        n.dl.Unlock()
    }

    // GetTopic performs a thread safe operation
    // to return a pointer to a Topic object (potentially new)
    func (n *NSQD) GetTopic(topicName string) *Topic {
        // most likely, we already have this topic, so try read lock first.
        n.RLock()
        t, ok := n.topicMap[topicName]
        n.RUnlock()
        if ok {
            return t
        }

        n.Lock()

        t, ok = n.topicMap[topicName]
        if ok {
            n.Unlock()
            return t
        }
        deleteCallback := func(t *Topic) {
            n.DeleteExistingTopic(t.name)
        }
        t = NewTopic(topicName, &context{n}, deleteCallback)
        n.topicMap[topicName] = t

        n.logf("TOPIC(%s): created", t.name)

        // release our global nsqd lock, and switch to a more granular topic lock while we init our
        // channels from lookupd. This blocks concurrent PutMessages to this topic.
        t.Lock()
        n.Unlock()

        // if using lookupd, make a blocking call to get the topics, and immediately create them.
        // this makes sure that any message received is buffered to the right channels
        lookupdHTTPAddrs := n.lookupdHTTPAddrs()
        if len(lookupdHTTPAddrs) > 0 {
            channelNames, _ := n.ci.GetLookupdTopicChannels(t.name, lookupdHTTPAddrs)
            for _, channelName := range channelNames {
                if strings.HasSuffix(channelName, "#ephemeral") {
                    // we don't want to pre-create ephemeral channels
                    // because there isn't a client connected
                    continue
                }
                t.getOrCreateChannel(channelName)
            }
        }

        t.Unlock()

        // NOTE: I would prefer for this to only happen in topic.GetChannel() but we're special
        // casing the code above so that we can control the locks such that it is impossible
        // for a message to be written to a (new) topic while we're looking up channels
        // from lookupd...
        //
        // update messagePump state
        select {
        case t.channelUpdateChan <- 1:
        case <-t.exitChan:
        }
        return t
    }

    // GetExistingTopic gets a topic only if it exists
    func (n *NSQD) GetExistingTopic(topicName string) (*Topic, error) {
        n.RLock()
        defer n.RUnlock()
        topic, ok := n.topicMap[topicName]
        if !ok {
            return nil, errors.New("topic does not exist")
        }
        return topic, nil
    }

    // DeleteExistingTopic removes a topic only if it exists
    func (n *NSQD) DeleteExistingTopic(topicName string) error {
        n.RLock()
        topic, ok := n.topicMap[topicName]
        if !ok {
            n.RUnlock()
            return errors.New("topic does not exist")
        }
        n.RUnlock()

        // delete empties all channels and the topic itself before closing
        // (so that we dont leave any messages around)
        //
        // we do this before removing the topic from map below (with no lock)
        // so that any incoming writes will error and not create a new topic
        // to enforce ordering
        topic.Delete()

        n.Lock()
        delete(n.topicMap, topicName)
        n.Unlock()

        return nil
    }

    func (n *NSQD) idPump() {
        factory := &guidFactory{}
        lastError := time.Unix(0, 0)
        workerID := n.getOpts().ID
        for {
            id, err := factory.NewGUID(workerID)
            if err != nil {
                now := time.Now()
                if now.Sub(lastError) > time.Second {
                    // only print the error once/second
                    n.logf("ERROR: %s", err)
                    lastError = now
                }
                runtime.Gosched()
                continue
            }
            select {
            case n.idChan <- id.Hex():
            case <-n.exitChan:
                goto exit
            }
        }

    exit:
        n.logf("ID: closing")
    }

    func (n *NSQD) Notify(v interface{}) {
        // since the in-memory metadata is incomplete,
        // should not persist metadata while loading it.
        // nsqd will call `PersistMetadata` it after loading
        persist := atomic.LoadInt32(&n.isLoading) == 0
        n.waitGroup.Wrap(func() {
            // by selecting on exitChan we guarantee that
            // we do not block exit, see issue #123
            select {
            case <-n.exitChan:
            case n.notifyChan <- v:
                if !persist {
                    return
                }
                n.Lock()
                err := n.PersistMetadata()
                if err != nil {
                    n.logf("ERROR: failed to persist metadata - %s", err)
                }
                n.Unlock()
            }
        })
    }

    // channels returns a flat slice of all channels in all topics
    func (n *NSQD) channels() []*Channel {
        var channels []*Channel
        n.RLock()
        for _, t := range n.topicMap {
            t.RLock()
            for _, c := range t.channelMap {
                channels = append(channels, c)
            }
            t.RUnlock()
        }
        n.RUnlock()
        return channels
    }

    // resizePool adjusts the size of the pool of queueScanWorker goroutines
    //
    //     1 <= pool <= min(num * 0.25, QueueScanWorkerPoolMax)
    //
    func (n *NSQD) resizePool(num int, workCh chan *Channel, responseCh chan bool, closeCh chan int) {
        idealPoolSize := int(float64(num) * 0.25)
        if idealPoolSize < 1 {
            idealPoolSize = 1
        } else if idealPoolSize > n.getOpts().QueueScanWorkerPoolMax {
            idealPoolSize = n.getOpts().QueueScanWorkerPoolMax
        }
        for {
            if idealPoolSize == n.poolSize {
                break
            } else if idealPoolSize < n.poolSize {
                // contract
                closeCh <- 1
                n.poolSize--
            } else {
                // expand
                n.waitGroup.Wrap(func() {
                    n.queueScanWorker(workCh, responseCh, closeCh)
                })
                n.poolSize++
            }
        }
    }

    // queueScanWorker receives work (in the form of a channel) from queueScanLoop
    // and processes the deferred and in-flight queues
    func (n *NSQD) queueScanWorker(workCh chan *Channel, responseCh chan bool, closeCh chan int) {
        for {
            select {
            case c := <-workCh:
                now := time.Now().UnixNano()
                dirty := false
                if c.processInFlightQueue(now) {
                    dirty = true
                }
                if c.processDeferredQueue(now) {
                    dirty = true
                }
                responseCh <- dirty
            case <-closeCh:
                return
            }
        }
    }

    // queueScanLoop runs in a single goroutine to process in-flight and deferred
    // priority queues. It manages a pool of queueScanWorker (configurable max of
    // QueueScanWorkerPoolMax (default: 4)) that process channels concurrently.
    //
    // It copies Redis's probabilistic expiration algorithm: it wakes up every
    // QueueScanInterval (default: 100ms) to select a random QueueScanSelectionCount
    // (default: 20) channels from a locally cached list (refreshed every
    // QueueScanRefreshInterval (default: 5s)).
    //
    // If either of the queues had work to do the channel is considered "dirty".
    //
    // If more than QueueScanDirtyPercent (default: 25%) of the selected channels
    // were dirty, another selection is scanned immediately without waiting for
    // the next tick.
    //
    // The loop ends when exitChan is closed; it then closes closeCh, which stops
    // every worker, and stops both tickers.
    func (n *NSQD) queueScanLoop() {
        // Both channels are buffered to the selection count, the upper bound on
        // the number of channels handed out per round.
        workCh := make(chan *Channel, n.getOpts().QueueScanSelectionCount)
        responseCh := make(chan bool, n.getOpts().QueueScanSelectionCount)
        closeCh := make(chan int)

        workTicker := time.NewTicker(n.getOpts().QueueScanInterval)
        refreshTicker := time.NewTicker(n.getOpts().QueueScanRefreshInterval)

        channels := n.channels()
        n.resizePool(len(channels), workCh, responseCh, closeCh)

        for {
            select {
            case <-workTicker.C:
                // nothing to scan; wait for the next tick or refresh
                if len(channels) == 0 {
                    continue
                }
            case <-refreshTicker.C:
                // refresh the cached channel list and resize the pool to match
                channels = n.channels()
                n.resizePool(len(channels), workCh, responseCh, closeCh)
                continue
            case <-n.exitChan:
                goto exit
            }

            // never select more channels than exist
            num := n.getOpts().QueueScanSelectionCount
            if num > len(channels) {
                num = len(channels)
            }

        loop:
            // hand num randomly chosen channels to the workers
            for _, i := range util.UniqRands(num, len(channels)) {
                workCh <- channels[i]
            }

            // collect exactly one response per dispatched channel
            numDirty := 0
            for i := 0; i < num; i++ {
                if <-responseCh {
                    numDirty++
                }
            }

            // enough dirty channels: scan another selection right away
            if float64(numDirty)/float64(num) > n.getOpts().QueueScanDirtyPercent {
                goto loop
            }
        }

    exit:
        n.logf("QUEUESCAN: closing")
        close(closeCh)
        workTicker.Stop()
        refreshTicker.Stop()
    }

    func buildTLSConfig(opts *Options) (*tls.Config, error) {
        var tlsConfig *tls.Config

        if opts.TLSCert == "" && opts.TLSKey == "" {
            return nil, nil
        }

        tlsClientAuthPolicy := tls.VerifyClientCertIfGiven

        cert, err := tls.LoadX509KeyPair(opts.TLSCert, opts.TLSKey)
        if err != nil {
            return nil, err
        }
        switch opts.TLSClientAuthPolicy {
        case "require":
            tlsClientAuthPolicy = tls.RequireAnyClientCert
        case "require-verify":
            tlsClientAuthPolicy = tls.RequireAndVerifyClientCert
        default:
            tlsClientAuthPolicy = tls.NoClientCert
        }

        tlsConfig = &tls.Config{
            Certificates: []tls.Certificate{cert},
            ClientAuth:   tlsClientAuthPolicy,
            MinVersion:   opts.TLSMinVersion,
            MaxVersion:   tls.VersionTLS12, // enable TLS_FALLBACK_SCSV prior to Go 1.5: https://go-review.googlesource.com/#/c/1776/
        }

        if opts.TLSRootCAFile != "" {
            tlsCertPool := x509.NewCertPool()
            caCertFile, err := ioutil.ReadFile(opts.TLSRootCAFile)
            if err != nil {
                return nil, err
            }
            if !tlsCertPool.AppendCertsFromPEM(caCertFile) {
                return nil, errors.New("failed to append certificate to pool")
            }
            tlsConfig.ClientCAs = tlsCertPool
        }

        tlsConfig.BuildNameToCertificate()

        return tlsConfig, nil
    }

    func (n *NSQD) IsAuthEnabled() bool {
        return len(n.getOpts().AuthHTTPAddresses) != 0
    }

  • 相关阅读:
    写了一个html5音乐播放器
    跨站脚本攻击之反射型XSS漏洞
    org.hibernate.ObjectNotFoundException异常一则
    addEventListener第二个参数的handleEvent
    IE的fireEvent方法
    DexMongoDB索引优化工具
    javascript 位运算符
    javascript contains方法来判断元素包含关系
    将uglifyjs添加到鼠标右键菜单
    过滤XSS(跨站脚本攻击)的函数和防止svn版本库被浏览
  • 原文地址:https://www.cnblogs.com/zhangboyu/p/7457350.html
Copyright © 2011-2022 走看看