package auth
import (
"errors"
"fmt"
"log"
"net/url"
"regexp"
"time"
"github.com/nsqio/nsq/internal/http_api"
)
type (
	// Authorization is a single grant returned by an auth server. It is
	// decoded from the JSON keys "topic", "channels" and "permissions".
	Authorization struct {
		// Topic is a regular expression matched against topic names.
		Topic string `json:"topic"`
		// Channels are regular expressions matched against channel names.
		Channels []string `json:"channels"`
		// Permissions lists the actions granted ("subscribe" and/or
		// "publish" are the values accepted by QueryAuthd).
		Permissions []string `json:"permissions"`
	}

	// State is the decoded response of an auth server for one client.
	State struct {
		// TTL is the lifetime of the response in seconds.
		TTL int `json:"ttl"`
		// Authorizations are the grants; access is allowed if any matches.
		Authorizations []Authorization `json:"authorizations"`
		// Identity and IdentityURL are informational fields from the server.
		Identity    string `json:"identity"`
		IdentityURL string `json:"identity_url"`
		// Expires carries no json tag; it is computed locally from TTL
		// after the response is decoded.
		Expires time.Time
	}
)
// HasPermission reports whether permission appears verbatim (exact,
// case-sensitive match) in a.Permissions.
func (a *Authorization) HasPermission(permission string) bool {
	for i := range a.Permissions {
		if a.Permissions[i] == permission {
			return true
		}
	}
	return false
}
// IsAllowed reports whether this authorization grants access to topic and
// channel.
//
// A non-empty channel is treated as a subscribe and requires the
// "subscribe" permission; an empty channel is treated as a publish and
// requires "publish". In both cases the topic must match the Topic pattern
// and channel (possibly "") must match at least one of the Channels
// patterns. Patterns are Go regular expressions applied with MatchString,
// so they are unanchored unless the pattern itself uses ^ and $.
//
// A pattern that fails to compile is treated as matching nothing (access is
// denied) instead of panicking, so a malformed Authorization that did not
// pass through QueryAuthd's validation cannot crash the caller.
func (a *Authorization) IsAllowed(topic, channel string) bool {
	required := "publish"
	if channel != "" {
		required = "subscribe"
	}
	if !a.HasPermission(required) {
		return false
	}

	// NOTE: patterns are compiled on every call; State keeps no compiled form.
	if ok, err := regexp.MatchString(a.Topic, topic); err != nil || !ok {
		return false
	}

	for _, c := range a.Channels {
		if ok, err := regexp.MatchString(c, channel); err == nil && ok {
			return true
		}
	}
	return false
}
// IsAllowed reports whether at least one of the state's authorizations
// permits access to topic/channel; with no authorizations it returns false.
func (a *State) IsAllowed(topic, channel string) bool {
	for i := range a.Authorizations {
		if a.Authorizations[i].IsAllowed(topic, channel) {
			return true
		}
	}
	return false
}
// IsExpired reports whether the current time is strictly past a.Expires.
func (a *State) IsExpired() bool {
	return time.Now().After(a.Expires)
}
func QueryAnyAuthd(authd []string, remoteIP, tlsEnabled, authSecret string,
connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) { for _, a := range authd { authState, err := QueryAuthd(a, remoteIP, tlsEnabled, authSecret, connectTimeout, requestTimeout)
if err != nil { log.Printf("Error: failed auth against %s %s", a, err) continue
}
return authState, nil
}
return nil, errors.New("Unable to access auth server")}
func QueryAuthd(authd, remoteIP, tlsEnabled, authSecret string,
connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) { v := url.Values{} v.Set("remote_ip", remoteIP) v.Set("tls", tlsEnabled) v.Set("secret", authSecret)
endpoint := fmt.Sprintf("http://%s/auth?%s", authd, v.Encode())
var authState State
client := http_api.NewClient(nil, connectTimeout, requestTimeout)
if err := client.GETV1(endpoint, &authState); err != nil { return nil, err
}
// validation on response
for _, auth := range authState.Authorizations { for _, p := range auth.Permissions { switch p { case "subscribe", "publish":
default:
return nil, fmt.Errorf("unknown permission %s", p) }
}
if _, err := regexp.Compile(auth.Topic); err != nil { return nil, fmt.Errorf("unable to compile topic %q %s", auth.Topic, err) }
for _, channel := range auth.Channels { if _, err := regexp.Compile(channel); err != nil { return nil, fmt.Errorf("unable to compile channel %q %s", channel, err) }
}
}
if authState.TTL <= 0 { return nil, fmt.Errorf("invalid TTL %d (must be >0)", authState.TTL) }
authState.Expires = time.Now().Add(time.Duration(authState.TTL) * time.Second)
return &authState, nil
}