package nsqd
import (
"bufio"
"compress/flate"
"crypto/tls"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/mreiferson/go-snappystream"
"github.com/nsqio/nsq/internal/auth"
)
// defaultBufferSize is the size passed to bufio when newClientV2 creates the
// client's reader and writer, and the initial value of OutputBufferSize.
const defaultBufferSize = 16 * 1024
// Values held in clientV2.State. A new client starts in stateInit and
// StartClose stores stateClosing.
const (
stateInit = iota
stateDisconnected
stateConnected
stateSubscribed
stateClosing
)
// identifyDataV2 holds the parameters a client supplies to Identify. The
// struct tags give the JSON key of each field.
//
// HeartbeatInterval, OutputBufferTimeout and MsgTimeout are expressed in
// milliseconds. For HeartbeatInterval, OutputBufferSize and
// OutputBufferTimeout a value of 0 keeps the current setting and -1 turns the
// feature off (for the buffer size: a 1-byte, effectively unbuffered writer).
// See the corresponding Set* methods on clientV2 for the accepted ranges.
type identifyDataV2 struct {
// ShortID is the legacy fallback for ClientID.
ShortID string `json:"short_id"` // TODO: deprecated, remove in 1.0
// LongID is the legacy fallback for Hostname.
LongID string `json:"long_id"` // TODO: deprecated, remove in 1.0
ClientID string `json:"client_id"`
Hostname string `json:"hostname"`
HeartbeatInterval int `json:"heartbeat_interval"`
OutputBufferSize int `json:"output_buffer_size"`
OutputBufferTimeout int `json:"output_buffer_timeout"`
FeatureNegotiation bool `json:"feature_negotiation"`
TLSv1 bool `json:"tls_v1"`
Deflate bool `json:"deflate"`
DeflateLevel int `json:"deflate_level"`
Snappy bool `json:"snappy"`
SampleRate int32 `json:"sample_rate"`
UserAgent string `json:"user_agent"`
MsgTimeout int `json:"msg_timeout"`
}
// identifyEvent is the snapshot of client settings that Identify offers on
// clientV2.IdentifyEventChan after all settings have been applied.
type identifyEvent struct {
OutputBufferTimeout time.Duration
HeartbeatInterval time.Duration
SampleRate int32
MsgTimeout time.Duration
}
// clientV2 holds the per-connection state of one client: the underlying
// connection and its buffered reader/writer, the settings applied by Identify,
// message counters, and authorization state.
type clientV2 struct {
// 64bit atomic vars need to be first for proper alignment on 32bit platforms
ReadyCount int64
InFlightCount int64
MessageCount uint64
FinishCount uint64
RequeueCount uint64
// writeLock is held by the Set* and Upgrade* methods while they change
// Reader/Writer, OutputBufferSize, OutputBufferTimeout, HeartbeatInterval
// or MsgTimeout.
writeLock sync.RWMutex
// metaLock is held while Identify writes, and Stats reads, ClientID,
// Hostname and UserAgent.
metaLock sync.RWMutex
ID int64
ctx *context
UserAgent string
// original connection
net.Conn
// connections based on negotiated features
tlsConn *tls.Conn
flateWriter *flate.Writer
// reading/writing interfaces
Reader *bufio.Reader
Writer *bufio.Writer
OutputBufferSize int
OutputBufferTimeout time.Duration
HeartbeatInterval time.Duration
MsgTimeout time.Duration
// State holds one of the state* constants; Stats and StartClose access it
// through sync/atomic.
State int32
ConnectTime time.Time
Channel *Channel
// ReadyStateChan receives a non-blocking signal from tryUpdateReadyState.
ReadyStateChan chan int
ExitChan chan int
ClientID string
Hostname string
// SampleRate is stored and loaded with sync/atomic (SetSampleRate, Stats).
SampleRate int32
IdentifyEventChan chan identifyEvent
SubEventChan chan *Channel
// TLS, Snappy and Deflate are set to 1 with sync/atomic by the matching
// Upgrade* method.
TLS int32
Snappy int32
Deflate int32
// re-usable buffer for reading the 4-byte lengths off the wire;
// lenSlice is a slice over lenBuf, set up by newClientV2
lenBuf [4]byte
lenSlice []byte
AuthSecret string
AuthState *auth.State
}
// newClientV2 builds the state for a freshly accepted connection.
//
// The client starts in stateInit with a reader and writer of
// defaultBufferSize wrapped around conn, a 250ms output buffer timeout, the
// configured MsgTimeout, and a heartbeat interval of half the configured
// ClientTimeout. When conn is non-nil the host part of its remote address is
// used as the initial ClientID and Hostname; a failure to split the address
// is ignored and leaves both empty.
func newClientV2(id int64, conn net.Conn, ctx *context) *clientV2 {
	var remoteHost string
	if conn != nil {
		remoteHost, _, _ = net.SplitHostPort(conn.RemoteAddr().String())
	}

	client := &clientV2{
		ID:  id,
		ctx: ctx,

		Conn: conn,

		Reader: bufio.NewReaderSize(conn, defaultBufferSize),
		Writer: bufio.NewWriterSize(conn, defaultBufferSize),

		OutputBufferSize:    defaultBufferSize,
		OutputBufferTimeout: 250 * time.Millisecond,

		MsgTimeout: ctx.nsqd.getOpts().MsgTimeout,

		// a one-slot buffer lets a ready-state signal raised during a race
		// be retained instead of being lost
		ReadyStateChan: make(chan int, 1),
		ExitChan:       make(chan int),
		ConnectTime:    time.Now(),
		State:          stateInit,

		ClientID: remoteHost,
		Hostname: remoteHost,

		SubEventChan:      make(chan *Channel, 1),
		IdentifyEventChan: make(chan identifyEvent, 1),

		// clients may override this via Identify; the default is half the
		// configured client timeout
		HeartbeatInterval: ctx.nsqd.getOpts().ClientTimeout / 2,
	}
	client.lenSlice = client.lenBuf[:]
	return client
}
// String returns the remote address of the client's connection in string
// form. It is what the "[%s]" prefix in this type's log messages prints.
func (c *clientV2) String() string {
return c.RemoteAddr().String()
}
// Identify applies the settings a client supplied in data.
//
// The client ID, hostname and user agent are stored first, under metaLock
// (falling back to the deprecated ShortID/LongID when the newer fields are
// empty). The heartbeat interval, output buffer size, output buffer timeout,
// sample rate and message timeout are then validated and applied in that
// order. The first setter that fails aborts the call and its error is
// returned; anything applied before the failure stays in place.
//
// On success the resulting settings are offered on IdentifyEventChan with a
// non-blocking send, so the event is dropped if the channel is already full.
func (c *clientV2) Identify(data identifyDataV2) error {
	c.ctx.nsqd.logf("[%s] IDENTIFY: %+v", c, data)

	// TODO: for backwards compatibility, remove in 1.0
	hostname := data.Hostname
	if hostname == "" {
		hostname = data.LongID
	}

	// TODO: for backwards compatibility, remove in 1.0
	clientID := data.ClientID
	if clientID == "" {
		clientID = data.ShortID
	}

	c.metaLock.Lock()
	c.ClientID = clientID
	c.Hostname = hostname
	c.UserAgent = data.UserAgent
	c.metaLock.Unlock()

	if err := c.SetHeartbeatInterval(data.HeartbeatInterval); err != nil {
		return err
	}
	if err := c.SetOutputBufferSize(data.OutputBufferSize); err != nil {
		return err
	}
	if err := c.SetOutputBufferTimeout(data.OutputBufferTimeout); err != nil {
		return err
	}
	if err := c.SetSampleRate(data.SampleRate); err != nil {
		return err
	}
	if err := c.SetMsgTimeout(data.MsgTimeout); err != nil {
		return err
	}

	// SampleRate is written with atomic.StoreInt32 in SetSampleRate and read
	// with atomic.LoadInt32 in Stats, so it is loaded atomically here as well
	// rather than mixing plain and atomic access to the same field.
	sampleRate := atomic.LoadInt32(&c.SampleRate)

	ie := identifyEvent{
		OutputBufferTimeout: c.OutputBufferTimeout,
		HeartbeatInterval:   c.HeartbeatInterval,
		SampleRate:          sampleRate,
		MsgTimeout:          c.MsgTimeout,
	}

	// update the client's message pump
	select {
	case c.IdentifyEventChan <- ie:
	default:
	}

	return nil
}
// Stats returns a snapshot of the client's identity, counters and negotiated
// features.
//
// ClientID, Hostname, UserAgent and the auth identity are read under
// metaLock; the counters, State, SampleRate and the TLS/Deflate/Snappy flags
// are loaded atomically. The TLS details are filled in only when the TLS flag
// is set.
func (c *clientV2) Stats() ClientStats {
	var identity, identityURL string

	c.metaLock.RLock()
	clientID, hostname, userAgent := c.ClientID, c.Hostname, c.UserAgent
	if authState := c.AuthState; authState != nil {
		identity, identityURL = authState.Identity, authState.IdentityURL
	}
	c.metaLock.RUnlock()

	stats := ClientStats{
		// TODO: deprecated, remove in 1.0
		Name: clientID,

		Version:         "V2",
		RemoteAddress:   c.RemoteAddr().String(),
		ClientID:        clientID,
		Hostname:        hostname,
		UserAgent:       userAgent,
		State:           atomic.LoadInt32(&c.State),
		ReadyCount:      atomic.LoadInt64(&c.ReadyCount),
		InFlightCount:   atomic.LoadInt64(&c.InFlightCount),
		MessageCount:    atomic.LoadUint64(&c.MessageCount),
		FinishCount:     atomic.LoadUint64(&c.FinishCount),
		RequeueCount:    atomic.LoadUint64(&c.RequeueCount),
		ConnectTime:     c.ConnectTime.Unix(),
		SampleRate:      atomic.LoadInt32(&c.SampleRate),
		TLS:             atomic.LoadInt32(&c.TLS) == 1,
		Deflate:         atomic.LoadInt32(&c.Deflate) == 1,
		Snappy:          atomic.LoadInt32(&c.Snappy) == 1,
		Authed:          c.HasAuthorizations(),
		AuthIdentity:    identity,
		AuthIdentityURL: identityURL,
	}

	if !stats.TLS {
		return stats
	}

	// translate the numeric TLS parameters into readable strings
	connState := prettyConnectionState{c.tlsConn.ConnectionState()}
	stats.CipherSuite = connState.GetCipherSuite()
	stats.TLSVersion = connState.GetVersion()
	stats.TLSNegotiatedProtocol = connState.NegotiatedProtocol
	stats.TLSNegotiatedProtocolIsMutual = connState.NegotiatedProtocolIsMutual
	return stats
}
// prettyConnectionState wraps a tls.ConnectionState to render its numeric
// cipher suite and protocol version as human readable strings.
type prettyConnectionState struct {
	tls.ConnectionState
}

// tlsCipherSuiteNames maps each cipher suite ID that GetCipherSuite recognizes
// to the name of its crypto/tls constant.
var tlsCipherSuiteNames = map[uint16]string{
	tls.TLS_RSA_WITH_RC4_128_SHA:                "TLS_RSA_WITH_RC4_128_SHA",
	tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA:           "TLS_RSA_WITH_3DES_EDE_CBC_SHA",
	tls.TLS_RSA_WITH_AES_128_CBC_SHA:            "TLS_RSA_WITH_AES_128_CBC_SHA",
	tls.TLS_RSA_WITH_AES_256_CBC_SHA:            "TLS_RSA_WITH_AES_256_CBC_SHA",
	tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA:        "TLS_ECDHE_ECDSA_WITH_RC4_128_SHA",
	tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA:    "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
	tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA:    "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
	tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA:          "TLS_ECDHE_RSA_WITH_RC4_128_SHA",
	tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA:     "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA",
	tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA:      "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
	tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:      "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
	tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256:   "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
	tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
}

// tlsVersionNames maps each protocol version that GetVersion recognizes to
// its display string.
var tlsVersionNames = map[uint16]string{
	tls.VersionSSL30: "SSL30",
	tls.VersionTLS10: "TLS1.0",
	tls.VersionTLS11: "TLS1.1",
	tls.VersionTLS12: "TLS1.2",
}

// GetCipherSuite returns the name of the negotiated cipher suite, or
// "Unknown <id>" (decimal id) when the suite is not one of the recognized
// constants.
func (p *prettyConnectionState) GetCipherSuite() string {
	if name, ok := tlsCipherSuiteNames[p.CipherSuite]; ok {
		return name
	}
	return fmt.Sprintf("Unknown %d", p.CipherSuite)
}

// GetVersion returns the negotiated protocol version as "SSL30", "TLS1.0",
// "TLS1.1" or "TLS1.2", or "Unknown <version>" (decimal) for anything else.
func (p *prettyConnectionState) GetVersion() string {
	if name, ok := tlsVersionNames[p.Version]; ok {
		return name
	}
	return fmt.Sprintf("Unknown %d", p.Version)
}
// IsReadyForMessages reports whether the client can be sent another message.
//
// It returns false when the client's channel is paused, when the ready count
// is not positive, or when the in-flight count has reached the ready count.
// With the Verbose option set, the two counters are logged on every call that
// gets past the pause check.
func (c *clientV2) IsReadyForMessages() bool {
	if c.Channel.IsPaused() {
		return false
	}

	ready := atomic.LoadInt64(&c.ReadyCount)
	inFlight := atomic.LoadInt64(&c.InFlightCount)

	if c.ctx.nsqd.getOpts().Verbose {
		c.ctx.nsqd.logf("[%s] state rdy: %4d inflt: %4d",
			c, ready, inFlight)
	}

	return ready > 0 && inFlight < ready
}
// SetReadyCount atomically stores count as the client's ready count and then
// signals a ready-state update.
func (c *clientV2) SetReadyCount(count int64) {
atomic.StoreInt64(&c.ReadyCount, count)
c.tryUpdateReadyState()
}
// tryUpdateReadyState performs a non-blocking send on ReadyStateChan to signal
// that the client's ready state may have changed. If the send cannot proceed
// immediately it is skipped.
func (c *clientV2) tryUpdateReadyState() {
// you can always *try* to write to ReadyStateChan because in the cases
// where you cannot the message pump loop would have iterated anyway.
// the atomic integer operations guarantee correctness of the value.
select {
case c.ReadyStateChan <- 1:
default:
}
}
// FinishedMessage atomically increments the finish counter, decrements the
// in-flight counter, and signals a ready-state update.
func (c *clientV2) FinishedMessage() {
atomic.AddUint64(&c.FinishCount, 1)
atomic.AddInt64(&c.InFlightCount, -1)
c.tryUpdateReadyState()
}
// Empty atomically resets the in-flight counter to zero and signals a
// ready-state update.
func (c *clientV2) Empty() {
atomic.StoreInt64(&c.InFlightCount, 0)
c.tryUpdateReadyState()
}
// SendingMessage atomically increments the in-flight counter and the total
// message counter. Unlike the other counter updates it does not signal a
// ready-state update.
func (c *clientV2) SendingMessage() {
atomic.AddInt64(&c.InFlightCount, 1)
atomic.AddUint64(&c.MessageCount, 1)
}
// TimedOutMessage atomically decrements the in-flight counter and signals a
// ready-state update.
func (c *clientV2) TimedOutMessage() {
atomic.AddInt64(&c.InFlightCount, -1)
c.tryUpdateReadyState()
}
// RequeuedMessage atomically increments the requeue counter, decrements the
// in-flight counter, and signals a ready-state update.
func (c *clientV2) RequeuedMessage() {
atomic.AddUint64(&c.RequeueCount, 1)
atomic.AddInt64(&c.InFlightCount, -1)
c.tryUpdateReadyState()
}
// StartClose begins closing the client: it sets the ready count to zero
// (which also signals a ready-state update) and then atomically stores
// stateClosing in State.
func (c *clientV2) StartClose() {
// Force the client into ready 0
c.SetReadyCount(0)
// mark this client as closing
atomic.StoreInt32(&c.State, stateClosing)
}
// Pause signals a ready-state update. It changes no client field itself;
// IsReadyForMessages consults the channel's paused status directly.
func (c *clientV2) Pause() {
c.tryUpdateReadyState()
}
// UnPause signals a ready-state update. It changes no client field itself;
// IsReadyForMessages consults the channel's paused status directly.
func (c *clientV2) UnPause() {
c.tryUpdateReadyState()
}
// SetHeartbeatInterval updates the client's heartbeat interval under
// writeLock. desiredInterval is in milliseconds:
//
//	-1                 sets the interval to zero
//	0                  keeps the current interval
//	1000..max          sets the interval, where max is the
//	                   MaxHeartbeatInterval option in milliseconds
//
// Any other value is rejected with an error.
func (c *clientV2) SetHeartbeatInterval(desiredInterval int) error {
	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	if desiredInterval == -1 {
		c.HeartbeatInterval = 0
		return nil
	}
	if desiredInterval == 0 {
		// keep the current (default) interval
		return nil
	}
	if desiredInterval < 1000 ||
		desiredInterval > int(c.ctx.nsqd.getOpts().MaxHeartbeatInterval/time.Millisecond) {
		return fmt.Errorf("heartbeat interval (%d) is invalid", desiredInterval)
	}

	c.HeartbeatInterval = time.Duration(desiredInterval) * time.Millisecond
	return nil
}
// SetOutputBufferSize changes the size of the client's output buffer.
//
//	-1         use a 1-byte buffer, so writes effectively go straight to the
//	           wrapped connection
//	0          keep the current buffer
//	64..max    use that size, where max is the MaxOutputBufferSize option
//
// Any other value is rejected with an error. When the size changes, the
// existing writer is flushed under writeLock and replaced by a new one; a
// flush error is returned, in which case OutputBufferSize has already been
// updated but the old writer is kept.
func (c *clientV2) SetOutputBufferSize(desiredSize int) error {
	if desiredSize == 0 {
		// keep the current (default) buffer
		return nil
	}

	newSize := desiredSize
	if desiredSize == -1 {
		newSize = 1
	} else if desiredSize < 64 || desiredSize > int(c.ctx.nsqd.getOpts().MaxOutputBufferSize) {
		return fmt.Errorf("output buffer size (%d) is invalid", desiredSize)
	}

	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	c.OutputBufferSize = newSize
	if err := c.Writer.Flush(); err != nil {
		return err
	}
	// NOTE(review): the replacement writer wraps c.Conn directly, bypassing
	// any TLS/deflate/snappy layer - presumably this only runs before an
	// upgrade; confirm against the caller.
	c.Writer = bufio.NewWriterSize(c.Conn, newSize)
	return nil
}
// SetOutputBufferTimeout updates the client's output buffer timeout under
// writeLock. desiredTimeout is in milliseconds:
//
//	-1        sets the timeout to zero
//	0         keeps the current timeout
//	1..max    sets the timeout, where max is the MaxOutputBufferTimeout
//	          option in milliseconds
//
// Any other value is rejected with an error.
func (c *clientV2) SetOutputBufferTimeout(desiredTimeout int) error {
	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	if desiredTimeout == -1 {
		c.OutputBufferTimeout = 0
		return nil
	}
	if desiredTimeout == 0 {
		// keep the current (default) timeout
		return nil
	}
	if desiredTimeout < 1 ||
		desiredTimeout > int(c.ctx.nsqd.getOpts().MaxOutputBufferTimeout/time.Millisecond) {
		return fmt.Errorf("output buffer timeout (%d) is invalid", desiredTimeout)
	}

	c.OutputBufferTimeout = time.Duration(desiredTimeout) * time.Millisecond
	return nil
}
// SetSampleRate atomically stores sampleRate as the client's sample rate.
// Values outside 0..99 are rejected with an error and leave the current rate
// unchanged.
func (c *clientV2) SetSampleRate(sampleRate int32) error {
	inRange := sampleRate >= 0 && sampleRate <= 99
	if !inRange {
		return fmt.Errorf("sample rate (%d) is invalid", sampleRate)
	}

	atomic.StoreInt32(&c.SampleRate, sampleRate)
	return nil
}
// SetMsgTimeout updates the client's message timeout under writeLock.
// msgTimeout is in milliseconds: 0 keeps the current timeout, and a value
// from 1000 up to the MaxMsgTimeout option (in milliseconds) replaces it.
// Any other value is rejected with an error.
func (c *clientV2) SetMsgTimeout(msgTimeout int) error {
	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	if msgTimeout == 0 {
		// keep the current (default) timeout
		return nil
	}
	if msgTimeout < 1000 ||
		msgTimeout > int(c.ctx.nsqd.getOpts().MaxMsgTimeout/time.Millisecond) {
		return fmt.Errorf("msg timeout (%d) is invalid", msgTimeout)
	}

	c.MsgTimeout = time.Duration(msgTimeout) * time.Millisecond
	return nil
}
// UpgradeTLS performs a server-side TLS handshake on the original connection
// and, on success, switches the client's reader and writer to the TLS
// connection and sets the TLS flag.
//
// The handshake runs under writeLock with a deadline five seconds from now.
// A handshake error is returned and leaves the client's reader, writer and
// TLS flag unchanged.
func (c *clientV2) UpgradeTLS() error {
	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	secured := tls.Server(c.Conn, c.ctx.nsqd.tlsConfig)
	handshakeDeadline := time.Now().Add(5 * time.Second)
	secured.SetDeadline(handshakeDeadline)
	if err := secured.Handshake(); err != nil {
		return err
	}

	c.tlsConn = secured
	c.Reader = bufio.NewReaderSize(secured, defaultBufferSize)
	c.Writer = bufio.NewWriterSize(secured, c.OutputBufferSize)

	atomic.StoreInt32(&c.TLS, 1)
	return nil
}
// UpgradeDeflate switches the client's reader and writer to DEFLATE
// compression at the given level. The compression is layered over the TLS
// connection when one has been set up by UpgradeTLS, and over the original
// connection otherwise.
//
// It returns the error from flate.NewWriter when level is not accepted by
// compress/flate. In that case the client's reader, writer, flateWriter and
// Deflate flag are left untouched.
func (c *clientV2) UpgradeDeflate(level int) error {
	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	conn := c.Conn
	if c.tlsConn != nil {
		conn = c.tlsConn
	}

	// Build the writer before touching any client state: it is the only step
	// that can fail, and a failed call must not leave a nil flateWriter
	// wrapped by c.Writer.
	fw, err := flate.NewWriter(conn, level)
	if err != nil {
		return err
	}

	c.Reader = bufio.NewReaderSize(flate.NewReader(conn), defaultBufferSize)
	c.flateWriter = fw
	c.Writer = bufio.NewWriterSize(fw, c.OutputBufferSize)

	atomic.StoreInt32(&c.Deflate, 1)
	return nil
}
// UpgradeSnappy switches the client's reader and writer to snappystream
// compression and sets the Snappy flag. The compression is layered over the
// TLS connection when one has been set up by UpgradeTLS, and over the
// original connection otherwise. The reader is created with
// snappystream.SkipVerifyChecksum. The returned error is always nil.
func (c *clientV2) UpgradeSnappy() error {
	c.writeLock.Lock()
	defer c.writeLock.Unlock()

	transport := c.Conn
	if c.tlsConn != nil {
		transport = c.tlsConn
	}

	decompressed := snappystream.NewReader(transport, snappystream.SkipVerifyChecksum)
	c.Reader = bufio.NewReaderSize(decompressed, defaultBufferSize)

	compressed := snappystream.NewWriter(transport)
	c.Writer = bufio.NewWriterSize(compressed, c.OutputBufferSize)

	atomic.StoreInt32(&c.Snappy, 1)
	return nil
}
// Flush writes any buffered output to the connection.
//
// Before flushing it sets the connection's write deadline to now plus
// HeartbeatInterval, or clears the deadline when the interval is not
// positive. After the buffered writer is flushed, the deflate writer is
// flushed as well if one is set. The first flush error is returned.
func (c *clientV2) Flush() error {
	// the zero time.Time clears the write deadline
	var deadline time.Time
	if c.HeartbeatInterval > 0 {
		deadline = time.Now().Add(c.HeartbeatInterval)
	}
	c.SetWriteDeadline(deadline)

	if err := c.Writer.Flush(); err != nil {
		return err
	}

	if c.flateWriter == nil {
		return nil
	}
	return c.flateWriter.Flush()
}
// QueryAuthd asks the configured auth servers for this client's
// authorization state and stores the answer in AuthState.
//
// The request carries the host part of the client's remote address, whether
// the connection has been upgraded to TLS ("true"/"false"), and AuthSecret.
// An error from splitting the address or from auth.QueryAnyAuthd is returned
// and leaves AuthState unchanged.
func (c *clientV2) QueryAuthd() error {
	remoteIP, _, err := net.SplitHostPort(c.String())
	if err != nil {
		return err
	}

	tlsEnabled := "false"
	if atomic.LoadInt32(&c.TLS) == 1 {
		tlsEnabled = "true"
	}

	authAddresses := c.ctx.nsqd.getOpts().AuthHTTPAddresses
	secret := c.AuthSecret
	connectTimeout := c.ctx.nsqd.getOpts().HTTPClientConnectTimeout
	requestTimeout := c.ctx.nsqd.getOpts().HTTPClientRequestTimeout

	state, err := auth.QueryAnyAuthd(authAddresses, remoteIP, tlsEnabled,
		secret, connectTimeout, requestTimeout)
	if err != nil {
		return err
	}

	c.AuthState = state
	return nil
}
// Auth stores secret as the client's AuthSecret and then calls QueryAuthd,
// returning its error. The secret stays stored even when the query fails.
func (c *clientV2) Auth(secret string) error {
c.AuthSecret = secret
return c.QueryAuthd()
}
// IsAuthorized reports whether the client's current AuthState allows access
// to the given topic and channel.
//
// It returns false when the client has no AuthState. When the state has
// expired it is refreshed through QueryAuthd first; a refresh error is
// returned together with false.
func (c *clientV2) IsAuthorized(topic, channel string) (bool, error) {
	if c.AuthState == nil {
		return false, nil
	}

	if c.AuthState.IsExpired() {
		if err := c.QueryAuthd(); err != nil {
			return false, err
		}
	}

	return c.AuthState.IsAllowed(topic, channel), nil
}
// HasAuthorizations reports whether the client has an AuthState containing
// at least one authorization.
func (c *clientV2) HasAuthorizations() bool {
	return c.AuthState != nil && len(c.AuthState.Authorizations) != 0
}