diff --git a/README.md b/README.md index 0b77a3e..0cf0196 100644 --- a/README.md +++ b/README.md @@ -14,9 +14,8 @@ library has been tested. Vendoring a working version of this library is highly r * [Leveraging the library](docs/leveraging-the-library.md) ## To Do -* Allow library to use ZK to find the master * Create or import a custom transport that uses https://github.com/jmcvetta/napping to improve efficiency * End to end testing with Vagrant setup ## Contributions -Contributions are very much welcome. Please raise an issue so that the contribution may be discussed before it's made. +Contributions are very much welcome. Please raise an issue so that the contribution may be discussed before it's made. \ No newline at end of file diff --git a/clusters_test.go b/clusters_test.go new file mode 100644 index 0000000..ab79120 --- /dev/null +++ b/clusters_test.go @@ -0,0 +1,36 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package realis + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestLoadClusters(t *testing.T) { + + clusters, err := LoadClusters("examples/clusters.json") + if err != nil { + fmt.Print(err) + } + + assert.Equal(t, clusters[0].Name, "devcluster") + assert.Equal(t, clusters[0].ZK, "192.168.33.7") + assert.Equal(t, clusters[0].SchedZKPath, "/aurora/scheduler") + assert.Equal(t, clusters[0].AuthMechanism, "UNAUTHENTICATED") + assert.Equal(t, clusters[0].AgentRunDir, "latest") + assert.Equal(t, clusters[0].AgentRoot, "/var/lib/mesos") +} diff --git a/examples/client.go b/examples/client.go index 36f41dd..d681886 100644 --- a/examples/client.go +++ b/examples/client.go @@ -28,15 +28,33 @@ func main() { cmd := flag.String("cmd", "", "Job request type to send to Aurora Scheduler") executor := flag.String("executor", "thermos", "Executor to use") url := flag.String("url", "", "URL at which the Aurora Scheduler exists as [url]:[port]") + clustersConfig := flag.String("clusters", "", "Location of the clusters.json file used by aurora.") updateId := flag.String("updateId", "", "Update ID to operate on") username := flag.String("username", "aurora", "Username to use for authorization") password := flag.String("password", "secret", "Password to use for authorization") flag.Parse() + // Attempt to load leader from zookeeper + if *clustersConfig != "" { + clusters, err := realis.LoadClusters(*clustersConfig) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + cluster, _ := clusters["devcluster"] + + *url, err = realis.LeaderFromZK(cluster) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + } + //Create new configuration with default transport layer config, err := realis.NewDefaultConfig(*url) if err != nil { - fmt.Print(err) + fmt.Println(err) os.Exit(1) } @@ -51,7 +69,7 @@ func main() { case "thermos": payload, err := ioutil.ReadFile("examples/thermos_payload.json") if err != nil { - fmt.Print("Error reading json config file: ", err) + fmt.Println("Error reading json config file: ", err) os.Exit(1) } @@ -94,41 +112,41 @@ func main() { fmt.Println("Creating job") response, err := r.CreateJob(job) if err != nil { - fmt.Print(err) + fmt.Println(err) os.Exit(1) } - fmt.Print(response.String()) + fmt.Println(response.String()) break case "kill": fmt.Println("Killing job") response, err := r.KillJob(job.JobKey()) if err != nil { - fmt.Print(err) + fmt.Println(err) os.Exit(1) } - fmt.Print(response.String()) + fmt.Println(response.String()) break case "restart": fmt.Println("Restarting job") response, err := r.RestartJob(job.JobKey()) if err != nil { - fmt.Print(err) + fmt.Println(err) os.Exit(1) } - fmt.Print(response.String()) + fmt.Println(response.String()) break case "flexUp": fmt.Println("Flexing up job") response, err := r.AddInstances(&aurora.InstanceKey{job.JobKey(), 0}, 5) if err != nil { - fmt.Print(err) + fmt.Println(err) os.Exit(1) } - fmt.Print(response.String()) + fmt.Println(response.String()) break case "update": fmt.Println("Updating a job with a new name") @@ -138,19 +156,19 @@ func main() { resposne, err := r.StartJobUpdate(updateJob, "") if err != nil { - fmt.Print(err) + fmt.Println(err) os.Exit(1) } - fmt.Print(resposne.String()) + fmt.Println(resposne.String()) break case "abortUpdate": fmt.Println("Abort update") response, err := r.AbortJobUpdate(job.JobKey(), *updateId, "") if err != nil { - fmt.Print(err) + fmt.Println(err) os.Exit(1) } - fmt.Print(response.String()) + fmt.Println(response.String()) break default: fmt.Println("Only create, kill, restart, flexUp, update, and abortUpdate are supported now") diff --git a/examples/clusters.json b/examples/clusters.json new file mode 100644 index 0000000..287a618 --- /dev/null +++ b/examples/clusters.json @@ -0,0 +1,8 @@ +[{ + "name": "devcluster", + "zk": "192.168.33.7", + "scheduler_zk_path": "/aurora/scheduler", + "auth_mechanism": "UNAUTHENTICATED", + "slave_run_directory": "latest", + "slave_root": "/var/lib/mesos" +}] \ No newline at end of file diff --git a/realis.go b/realis.go index 712bdc0..d7f0803 100644 --- a/realis.go +++ b/realis.go @@ -63,7 +63,7 @@ func NewDefaultConfig(url string) (RealisConfig, error) { jar, err := cookiejar.New(nil) if err != nil { - return RealisConfig{}, errors.Wrap(err, "Error creating Cookie Jar.") + return RealisConfig{}, errors.Wrap(err, "Error creating Cookie Jar") } //Custom client to timeout after 10 seconds to avoid hanging @@ -71,12 +71,12 @@ func NewDefaultConfig(url string) (RealisConfig, error) { thrift.THttpClientOptions{Client: &http.Client{Timeout: time.Second * 10, Jar: jar}}) if err != nil { - return RealisConfig{}, errors.Wrap(err, "Error creating transport.") + return RealisConfig{}, errors.Wrap(err, "Error creating transport") } if err := trans.Open(); err != nil { fmt.Fprintln(os.Stderr) - return RealisConfig{}, errors.Wrapf(err, "Error opening connection to %s.", url) + return RealisConfig{}, errors.Wrapf(err, "Error opening connection to %s", url) } return RealisConfig{transport: trans}, nil @@ -108,7 +108,7 @@ func (r realisClient) getActiveInstanceIds(key *aurora.JobKey) (map[int32]bool, response, err := r.client.GetTasksWithoutConfigs(taskQ) if err != nil { - return nil, errors.Wrap(err, "Error querying Aurora Scheduler") + return nil, errors.Wrap(err, "Error querying Aurora Scheduler for active IDs") } tasks := response.GetResult_().GetScheduleStatusResult_().GetTasks() @@ -130,7 +130,7 @@ func (r realisClient) KillInstance(key *aurora.JobKey, instanceId int32) (*auror response, err := r.client.KillTasks(key, instanceIds) if err != nil { - return nil, errors.Wrap(err, "Error sending Kill command to Aurora Scheduler.") + return nil, errors.Wrap(err, "Error sending Kill command to Aurora Scheduler") } return response, nil @@ -141,19 +141,19 @@ func (r realisClient) KillJob(key *aurora.JobKey) (*aurora.Response, error) { instanceIds, err := r.getActiveInstanceIds(key) if err != nil { - return nil, errors.Wrap(err, "Could not retrieve relevant task instance IDs.") + return nil, errors.Wrap(err, "Could not retrieve relevant task instance IDs") } if len(instanceIds) > 0 { response, err := r.client.KillTasks(key, instanceIds) if err != nil { - return nil, errors.Wrap(err, "Error sending Kill command to Aurora Scheduler.") + return nil, errors.Wrap(err, "Error sending Kill command to Aurora Scheduler") } return response, nil } else { - return nil, errors.New("No tasks in the Active state.") + return nil, errors.New("No tasks in the Active state") } } @@ -162,7 +162,7 @@ func (r realisClient) CreateJob(auroraJob *Job) (*aurora.Response, error) { response, err := r.client.CreateJob(auroraJob.jobConfig) if err != nil { - return nil, errors.Wrap(err, "Error sending Create command to Aurora Scheduler.") + return nil, errors.Wrap(err, "Error sending Create command to Aurora Scheduler") } return response, nil @@ -173,29 +173,29 @@ func (r realisClient) RestartJob(key *aurora.JobKey) (*aurora.Response, error) { instanceIds, err := r.getActiveInstanceIds(key) if err != nil { - return nil, errors.Wrap(err, "Could not retrieve relevant task instance IDs.") + return nil, errors.Wrap(err, "Could not retrieve relevant task instance IDs") } if len(instanceIds) > 0 { response, err := r.client.RestartShards(key, instanceIds) if err != nil { - return nil, errors.Wrap(err, "Error sending Restart command to Aurora Scheduler.") + return nil, errors.Wrap(err, "Error sending Restart command to Aurora Scheduler") } return response, nil } else { - return nil, errors.New("No tasks in the Active state.") + return nil, errors.New("No tasks in the Active state") } } -// Update all tasks under a job configuration. Currently there's no support for canary deployments. +// Update all tasks under a job configuration. Currently gorealis doesn't support for canary deployments. func (r realisClient) StartJobUpdate(updateJob *UpdateJob, message string) (*aurora.Response, error) { response, err := r.client.StartJobUpdate(updateJob.req, message) if err != nil { - return nil, errors.Wrap(err, "Error sending StartJobUpdate command to Aurora Scheduler.") + return nil, errors.Wrap(err, "Error sending StartJobUpdate command to Aurora Scheduler") } return response, nil @@ -210,7 +210,7 @@ func (r realisClient) AbortJobUpdate( response, err := r.client.AbortJobUpdate(&aurora.JobUpdateKey{key, updateId}, message) if err != nil { - return nil, errors.Wrap(err, "Error sending AbortJobUpdate command to Aurora Scheduler.") + return nil, errors.Wrap(err, "Error sending AbortJobUpdate command to Aurora Scheduler") } return response, nil @@ -223,7 +223,7 @@ func (r realisClient) AddInstances(instKey *aurora.InstanceKey, count int32) (*a response, err := r.client.AddInstances(instKey, count) if err != nil { - return nil, errors.Wrap(err, "Error sending AddInstances command to Aurora Scheduler.") + return nil, errors.Wrap(err, "Error sending AddInstances command to Aurora Scheduler") } return response, nil diff --git a/vendor/git.apache.org/thrift.git/lib/go/thrift/framed_transport.go b/vendor/git.apache.org/thrift.git/lib/go/thrift/framed_transport.go index d0bae21..c8bb887 100644 --- a/vendor/git.apache.org/thrift.git/lib/go/thrift/framed_transport.go +++ b/vendor/git.apache.org/thrift.git/lib/go/thrift/framed_transport.go @@ -48,7 +48,7 @@ func NewTFramedTransportFactory(factory TTransportFactory) TTransportFactory { } func NewTFramedTransportFactoryMaxLength(factory TTransportFactory, maxLength uint32) TTransportFactory { - return &tFramedTransportFactory{factory: factory, maxLength: maxLength} + return &tFramedTransportFactory{factory: factory, maxLength: maxLength} } func (p *tFramedTransportFactory) GetTransport(base TTransport) TTransport { @@ -164,4 +164,3 @@ func (p *TFramedTransport) readFrameHeader() (uint32, error) { func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) { return uint64(p.frameSize) } - diff --git a/vendor/git.apache.org/thrift.git/lib/go/thrift/iostream_transport.go b/vendor/git.apache.org/thrift.git/lib/go/thrift/iostream_transport.go index 794872f..70bede9 100644 --- a/vendor/git.apache.org/thrift.git/lib/go/thrift/iostream_transport.go +++ b/vendor/git.apache.org/thrift.git/lib/go/thrift/iostream_transport.go @@ -209,6 +209,5 @@ func (p *StreamTransport) WriteString(s string) (n int, err error) { func (p *StreamTransport) RemainingBytes() (num_bytes uint64) { const maxSize = ^uint64(0) - return maxSize // the thruth is, we just don't know unless framed is used + return maxSize // the thruth is, we just don't know unless framed is used } - diff --git a/vendor/git.apache.org/thrift.git/lib/go/thrift/protocol.go b/vendor/git.apache.org/thrift.git/lib/go/thrift/protocol.go index 45fa202..5b77363 100644 --- a/vendor/git.apache.org/thrift.git/lib/go/thrift/protocol.go +++ b/vendor/git.apache.org/thrift.git/lib/go/thrift/protocol.go @@ -88,9 +88,9 @@ func SkipDefaultDepth(prot TProtocol, typeId TType) (err error) { // Skips over the next data element from the provided input TProtocol object. func Skip(self TProtocol, fieldType TType, maxDepth int) (err error) { - - if maxDepth <= 0 { - return NewTProtocolExceptionWithType( DEPTH_LIMIT, errors.New("Depth limit exceeded")) + + if maxDepth <= 0 { + return NewTProtocolExceptionWithType(DEPTH_LIMIT, errors.New("Depth limit exceeded")) } switch fieldType { diff --git a/vendor/git.apache.org/thrift.git/lib/go/thrift/protocol_exception.go b/vendor/git.apache.org/thrift.git/lib/go/thrift/protocol_exception.go index 6e357ee..29ab75d 100644 --- a/vendor/git.apache.org/thrift.git/lib/go/thrift/protocol_exception.go +++ b/vendor/git.apache.org/thrift.git/lib/go/thrift/protocol_exception.go @@ -60,7 +60,7 @@ func NewTProtocolException(err error) TProtocolException { if err == nil { return nil } - if e,ok := err.(TProtocolException); ok { + if e, ok := err.(TProtocolException); ok { return e } if _, ok := err.(base64.CorruptInputError); ok { @@ -75,4 +75,3 @@ func NewTProtocolExceptionWithType(errType int, err error) TProtocolException { } return &tProtocolException{errType, err.Error()} } - diff --git a/vendor/git.apache.org/thrift.git/lib/go/thrift/rich_transport.go b/vendor/git.apache.org/thrift.git/lib/go/thrift/rich_transport.go index 8e296a9..4025beb 100644 --- a/vendor/git.apache.org/thrift.git/lib/go/thrift/rich_transport.go +++ b/vendor/git.apache.org/thrift.git/lib/go/thrift/rich_transport.go @@ -66,4 +66,3 @@ func writeByte(w io.Writer, c byte) error { _, err := w.Write(v[0:1]) return err } - diff --git a/vendor/git.apache.org/thrift.git/lib/go/thrift/simple_server.go b/vendor/git.apache.org/thrift.git/lib/go/thrift/simple_server.go index 6b3811e..bdf4428 100644 --- a/vendor/git.apache.org/thrift.git/lib/go/thrift/simple_server.go +++ b/vendor/git.apache.org/thrift.git/lib/go/thrift/simple_server.go @@ -27,7 +27,7 @@ import ( // Simple, non-concurrent server for testing. type TSimpleServer struct { - quit chan struct{} + quit chan struct{} stopped int64 processorFactory TProcessorFactory diff --git a/vendor/git.apache.org/thrift.git/lib/go/thrift/socket.go b/vendor/git.apache.org/thrift.git/lib/go/thrift/socket.go index 82e28b4..383b1fe 100644 --- a/vendor/git.apache.org/thrift.git/lib/go/thrift/socket.go +++ b/vendor/git.apache.org/thrift.git/lib/go/thrift/socket.go @@ -161,6 +161,5 @@ func (p *TSocket) Interrupt() error { func (p *TSocket) RemainingBytes() (num_bytes uint64) { const maxSize = ^uint64(0) - return maxSize // the thruth is, we just don't know unless framed is used + return maxSize // the thruth is, we just don't know unless framed is used } - diff --git a/vendor/git.apache.org/thrift.git/lib/go/thrift/ssl_server_socket.go b/vendor/git.apache.org/thrift.git/lib/go/thrift/ssl_server_socket.go index 58f859b..0615528 100644 --- a/vendor/git.apache.org/thrift.git/lib/go/thrift/ssl_server_socket.go +++ b/vendor/git.apache.org/thrift.git/lib/go/thrift/ssl_server_socket.go @@ -20,9 +20,9 @@ package thrift import ( + "crypto/tls" "net" "time" - "crypto/tls" ) type TSSLServerSocket struct { diff --git a/vendor/git.apache.org/thrift.git/lib/go/thrift/ssl_socket.go b/vendor/git.apache.org/thrift.git/lib/go/thrift/ssl_socket.go index 04d3850..86a68a3 100644 --- a/vendor/git.apache.org/thrift.git/lib/go/thrift/ssl_socket.go +++ b/vendor/git.apache.org/thrift.git/lib/go/thrift/ssl_socket.go @@ -166,6 +166,5 @@ func (p *TSSLSocket) Interrupt() error { func (p *TSSLSocket) RemainingBytes() (num_bytes uint64) { const maxSize = ^uint64(0) - return maxSize // the thruth is, we just don't know unless framed is used + return maxSize // the thruth is, we just don't know unless framed is used } - diff --git a/vendor/git.apache.org/thrift.git/lib/go/thrift/transport.go b/vendor/git.apache.org/thrift.git/lib/go/thrift/transport.go index 4538996..70a85a8 100644 --- a/vendor/git.apache.org/thrift.git/lib/go/thrift/transport.go +++ b/vendor/git.apache.org/thrift.git/lib/go/thrift/transport.go @@ -34,7 +34,6 @@ type ReadSizeProvider interface { RemainingBytes() (num_bytes uint64) } - // Encapsulates the I/O layer type TTransport interface { io.ReadWriteCloser @@ -52,7 +51,6 @@ type stringWriter interface { WriteString(s string) (n int, err error) } - // This is "enchanced" transport with extra capabilities. You need to use one of these // to construct protocol. // Notably, TSocket does not implement this interface, and it is always a mistake to use @@ -65,4 +63,3 @@ type TRichTransport interface { Flusher ReadSizeProvider } - diff --git a/vendor/github.com/samuel/go-zookeeper/LICENSE b/vendor/github.com/samuel/go-zookeeper/LICENSE new file mode 100644 index 0000000..bc00498 --- /dev/null +++ b/vendor/github.com/samuel/go-zookeeper/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2013, Samuel Stauffer +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. +* Neither the name of the author nor the + names of its contributors may be used to endorse or promote products + derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL BE LIABLE FOR ANY +DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/samuel/go-zookeeper/zk/conn.go b/vendor/github.com/samuel/go-zookeeper/zk/conn.go new file mode 100644 index 0000000..5ca8e2b --- /dev/null +++ b/vendor/github.com/samuel/go-zookeeper/zk/conn.go @@ -0,0 +1,935 @@ +// Package zk is a native Go client library for the ZooKeeper orchestration service. +package zk + +/* +TODO: +* make sure a ping response comes back in a reasonable time + +Possible watcher events: +* Event{Type: EventNotWatching, State: StateDisconnected, Path: path, Err: err} +*/ + +import ( + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// ErrNoServer indicates that an operation cannot be completed +// because attempts to connect to all servers in the list failed. +var ErrNoServer = errors.New("zk: could not connect to a server") + +// ErrInvalidPath indicates that an operation was being attempted on +// an invalid path. (e.g. empty path) +var ErrInvalidPath = errors.New("zk: invalid path") + +// DefaultLogger uses the stdlib log package for logging. +var DefaultLogger Logger = defaultLogger{} + +const ( + bufferSize = 1536 * 1024 + eventChanSize = 6 + sendChanSize = 16 + protectedPrefix = "_c_" +) + +type watchType int + +const ( + watchTypeData = iota + watchTypeExist = iota + watchTypeChild = iota +) + +type watchPathType struct { + path string + wType watchType +} + +type Dialer func(network, address string, timeout time.Duration) (net.Conn, error) + +// Logger is an interface that can be implemented to provide custom log output. +type Logger interface { + Printf(string, ...interface{}) +} + +// NoOp logger -- http://stackoverflow.com/questions/10571182/go-disable-a-log-logger +type NopLogger struct { + *log.Logger +} + +func (l NopLogger) Printf(string, ...interface{}) { + // noop +} + +type Conn struct { + lastZxid int64 + sessionID int64 + state State // must be 32-bit aligned + xid uint32 + sessionTimeoutMs int32 // session timeout in milliseconds + passwd []byte + + dialer Dialer + hostProvider HostProvider + serverMu sync.Mutex // protects server + server string // remember the address/port of the current server + conn net.Conn + eventChan chan Event + shouldQuit chan struct{} + pingInterval time.Duration + recvTimeout time.Duration + connectTimeout time.Duration + + sendChan chan *request + requests map[int32]*request // Xid -> pending request + requestsLock sync.Mutex + watchers map[watchPathType][]chan Event + watchersLock sync.Mutex + + // Debug (used by unit tests) + reconnectDelay time.Duration + + logger Logger +} + +// connOption represents a connection option. +type connOption func(c *Conn) + +type request struct { + xid int32 + opcode int32 + pkt interface{} + recvStruct interface{} + recvChan chan response + + // Because sending and receiving happen in separate go routines, there's + // a possible race condition when creating watches from outside the read + // loop. We must ensure that a watcher gets added to the list synchronously + // with the response from the server on any request that creates a watch. + // In order to not hard code the watch logic for each opcode in the recv + // loop the caller can use recvFunc to insert some synchronously code + // after a response. + recvFunc func(*request, *responseHeader, error) +} + +type response struct { + zxid int64 + err error +} + +type Event struct { + Type EventType + State State + Path string // For non-session events, the path of the watched node. + Err error + Server string // For connection events +} + +// HostProvider is used to represent a set of hosts a ZooKeeper client should connect to. +// It is an analog of the Java equivalent: +// http://svn.apache.org/viewvc/zookeeper/trunk/src/java/main/org/apache/zookeeper/client/HostProvider.java?view=markup +type HostProvider interface { + // Init is called first, with the servers specified in the connection string. + Init(servers []string) error + // Len returns the number of servers. + Len() int + // Next returns the next server to connect to. retryStart will be true if we've looped through + // all known servers without Connected() being called. + Next() (server string, retryStart bool) + // Notify the HostProvider of a successful connection. + Connected() +} + +// ConnectWithDialer establishes a new connection to a pool of zookeeper servers +// using a custom Dialer. See Connect for further information about session timeout. +// This method is deprecated and provided for compatibility: use the WithDialer option instead. +func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) { + return Connect(servers, sessionTimeout, WithDialer(dialer)) +} + +// Connect establishes a new connection to a pool of zookeeper +// servers. The provided session timeout sets the amount of time for which +// a session is considered valid after losing connection to a server. Within +// the session timeout it's possible to reestablish a connection to a different +// server and keep the same session. This is means any ephemeral nodes and +// watches are maintained. +func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) { + if len(servers) == 0 { + return nil, nil, errors.New("zk: server list must not be empty") + } + + srvs := make([]string, len(servers)) + + for i, addr := range servers { + if strings.Contains(addr, ":") { + srvs[i] = addr + } else { + srvs[i] = addr + ":" + strconv.Itoa(DefaultPort) + } + } + + // Randomize the order of the servers to avoid creating hotspots + stringShuffle(srvs) + + ec := make(chan Event, eventChanSize) + conn := &Conn{ + dialer: net.DialTimeout, + hostProvider: &DNSHostProvider{}, + conn: nil, + state: StateDisconnected, + eventChan: ec, + shouldQuit: make(chan struct{}), + connectTimeout: 1 * time.Second, + sendChan: make(chan *request, sendChanSize), + requests: make(map[int32]*request), + watchers: make(map[watchPathType][]chan Event), + passwd: emptyPassword, + logger: DefaultLogger, + + // Debug + reconnectDelay: 0, + } + + // Set provided options. + for _, option := range options { + option(conn) + } + + if err := conn.hostProvider.Init(srvs); err != nil { + return nil, nil, err + } + + conn.setTimeouts(int32(sessionTimeout / time.Millisecond)) + + go func() { + conn.loop() + conn.flushRequests(ErrClosing) + conn.invalidateWatches(ErrClosing) + close(conn.eventChan) + }() + return conn, ec, nil +} + +// WithDialer returns a connection option specifying a non-default Dialer. +func WithDialer(dialer Dialer) connOption { + return func(c *Conn) { + c.dialer = dialer + } +} + +// WithHostProvider returns a connection option specifying a non-default HostProvider. +func WithHostProvider(hostProvider HostProvider) connOption { + return func(c *Conn) { + c.hostProvider = hostProvider + } +} + +// WithLogger returns a connection option specifying a non-default logger +func WithLogger(logger Logger) connOption { + return func(c *Conn) { + c.logger = logger + } +} + +// WithLogger returns a connection option specifying a non-default logger +func WithoutLogger() connOption { + return func(c *Conn) { + c.logger = NopLogger{log.New(os.Stderr, "", log.LstdFlags)} + } +} + +func (c *Conn) Close() { + close(c.shouldQuit) + + select { + case <-c.queueRequest(opClose, &closeRequest{}, &closeResponse{}, nil): + case <-time.After(time.Second): + } +} + +// State returns the current state of the connection. +func (c *Conn) State() State { + return State(atomic.LoadInt32((*int32)(&c.state))) +} + +// SessionId returns the current session id of the connection. +func (c *Conn) SessionID() int64 { + return atomic.LoadInt64(&c.sessionID) +} + +// SetLogger sets the logger to be used for printing errors. +// Logger is an interface provided by this package. +func (c *Conn) SetLogger(l Logger) { + c.logger = l +} + +func (c *Conn) setTimeouts(sessionTimeoutMs int32) { + c.sessionTimeoutMs = sessionTimeoutMs + sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond + c.recvTimeout = sessionTimeout * 2 / 3 + c.pingInterval = c.recvTimeout / 2 +} + +func (c *Conn) setState(state State) { + atomic.StoreInt32((*int32)(&c.state), int32(state)) + select { + case c.eventChan <- Event{Type: EventSession, State: state, Server: c.Server()}: + default: + // panic("zk: event channel full - it must be monitored and never allowed to be full") + } +} + +func (c *Conn) connect() error { + var retryStart bool + for { + c.serverMu.Lock() + c.server, retryStart = c.hostProvider.Next() + c.serverMu.Unlock() + c.setState(StateConnecting) + if retryStart { + c.flushUnsentRequests(ErrNoServer) + select { + case <-time.After(time.Second): + // pass + case <-c.shouldQuit: + c.setState(StateDisconnected) + c.flushUnsentRequests(ErrClosing) + return ErrClosing + } + } + + zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout) + if err == nil { + c.conn = zkConn + c.setState(StateConnected) + c.logger.Printf("Connected to %s", c.Server()) + return nil + } + + c.logger.Printf("Failed to connect to %s: %+v", c.Server(), err) + } +} + +func (c *Conn) loop() { + for { + if err := c.connect(); err != nil { + // c.Close() was called + return + } + + err := c.authenticate() + switch { + case err == ErrSessionExpired: + c.logger.Printf("Authentication failed: %s", err) + c.invalidateWatches(err) + case err != nil && c.conn != nil: + c.logger.Printf("Authentication failed: %s", err) + c.conn.Close() + case err == nil: + c.logger.Printf("Authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs) + c.hostProvider.Connected() // mark success + closeChan := make(chan struct{}) // channel to tell send loop stop + var wg sync.WaitGroup + + wg.Add(1) + go func() { + err := c.sendLoop(c.conn, closeChan) + c.logger.Printf("Send loop terminated: err=%v", err) + c.conn.Close() // causes recv loop to EOF/exit + wg.Done() + }() + + wg.Add(1) + go func() { + err := c.recvLoop(c.conn) + c.logger.Printf("Recv loop terminated: err=%v", err) + if err == nil { + panic("zk: recvLoop should never return nil error") + } + close(closeChan) // tell send loop to exit + wg.Done() + }() + + c.sendSetWatches() + wg.Wait() + } + + c.setState(StateDisconnected) + + select { + case <-c.shouldQuit: + c.flushRequests(ErrClosing) + return + default: + } + + if err != ErrSessionExpired { + err = ErrConnectionClosed + } + c.flushRequests(err) + + if c.reconnectDelay > 0 { + select { + case <-c.shouldQuit: + return + case <-time.After(c.reconnectDelay): + } + } + } +} + +func (c *Conn) flushUnsentRequests(err error) { + for { + select { + default: + return + case req := <-c.sendChan: + req.recvChan <- response{-1, err} + } + } +} + +// Send error to all pending requests and clear request map +func (c *Conn) flushRequests(err error) { + c.requestsLock.Lock() + for _, req := range c.requests { + req.recvChan <- response{-1, err} + } + c.requests = make(map[int32]*request) + c.requestsLock.Unlock() +} + +// Send error to all watchers and clear watchers map +func (c *Conn) invalidateWatches(err error) { + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + + if len(c.watchers) >= 0 { + for pathType, watchers := range c.watchers { + ev := Event{Type: EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err} + for _, ch := range watchers { + ch <- ev + close(ch) + } + } + c.watchers = make(map[watchPathType][]chan Event) + } +} + +func (c *Conn) sendSetWatches() { + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + + if len(c.watchers) == 0 { + return + } + + req := &setWatchesRequest{ + RelativeZxid: c.lastZxid, + DataWatches: make([]string, 0), + ExistWatches: make([]string, 0), + ChildWatches: make([]string, 0), + } + n := 0 + for pathType, watchers := range c.watchers { + if len(watchers) == 0 { + continue + } + switch pathType.wType { + case watchTypeData: + req.DataWatches = append(req.DataWatches, pathType.path) + case watchTypeExist: + req.ExistWatches = append(req.ExistWatches, pathType.path) + case watchTypeChild: + req.ChildWatches = append(req.ChildWatches, pathType.path) + } + n++ + } + if n == 0 { + return + } + + go func() { + res := &setWatchesResponse{} + _, err := c.request(opSetWatches, req, res, nil) + if err != nil { + c.logger.Printf("Failed to set previous watches: %s", err.Error()) + } + }() +} + +func (c *Conn) authenticate() error { + buf := make([]byte, 256) + + // Encode and send a connect request. + n, err := encodePacket(buf[4:], &connectRequest{ + ProtocolVersion: protocolVersion, + LastZxidSeen: c.lastZxid, + TimeOut: c.sessionTimeoutMs, + SessionID: c.SessionID(), + Passwd: c.passwd, + }) + if err != nil { + return err + } + + binary.BigEndian.PutUint32(buf[:4], uint32(n)) + + c.conn.SetWriteDeadline(time.Now().Add(c.recvTimeout * 10)) + _, err = c.conn.Write(buf[:n+4]) + c.conn.SetWriteDeadline(time.Time{}) + if err != nil { + return err + } + + // Receive and decode a connect response. + c.conn.SetReadDeadline(time.Now().Add(c.recvTimeout * 10)) + _, err = io.ReadFull(c.conn, buf[:4]) + c.conn.SetReadDeadline(time.Time{}) + if err != nil { + return err + } + + blen := int(binary.BigEndian.Uint32(buf[:4])) + if cap(buf) < blen { + buf = make([]byte, blen) + } + + _, err = io.ReadFull(c.conn, buf[:blen]) + if err != nil { + return err + } + + r := connectResponse{} + _, err = decodePacket(buf[:blen], &r) + if err != nil { + return err + } + if r.SessionID == 0 { + atomic.StoreInt64(&c.sessionID, int64(0)) + c.passwd = emptyPassword + c.lastZxid = 0 + c.setState(StateExpired) + return ErrSessionExpired + } + + atomic.StoreInt64(&c.sessionID, r.SessionID) + c.setTimeouts(r.TimeOut) + c.passwd = r.Passwd + c.setState(StateHasSession) + + return nil +} + +func (c *Conn) sendLoop(conn net.Conn, closeChan <-chan struct{}) error { + pingTicker := time.NewTicker(c.pingInterval) + defer pingTicker.Stop() + + buf := make([]byte, bufferSize) + for { + select { + case req := <-c.sendChan: + header := &requestHeader{req.xid, req.opcode} + n, err := encodePacket(buf[4:], header) + if err != nil { + req.recvChan <- response{-1, err} + continue + } + + n2, err := encodePacket(buf[4+n:], req.pkt) + if err != nil { + req.recvChan <- response{-1, err} + continue + } + + n += n2 + + binary.BigEndian.PutUint32(buf[:4], uint32(n)) + + c.requestsLock.Lock() + select { + case <-closeChan: + req.recvChan <- response{-1, ErrConnectionClosed} + c.requestsLock.Unlock() + return ErrConnectionClosed + default: + } + c.requests[req.xid] = req + c.requestsLock.Unlock() + + conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) + _, err = conn.Write(buf[:n+4]) + conn.SetWriteDeadline(time.Time{}) + if err != nil { + req.recvChan <- response{-1, err} + conn.Close() + return err + } + case <-pingTicker.C: + n, err := encodePacket(buf[4:], &requestHeader{Xid: -2, Opcode: opPing}) + if err != nil { + panic("zk: opPing should never fail to serialize") + } + + binary.BigEndian.PutUint32(buf[:4], uint32(n)) + + conn.SetWriteDeadline(time.Now().Add(c.recvTimeout)) + _, err = conn.Write(buf[:n+4]) + conn.SetWriteDeadline(time.Time{}) + if err != nil { + conn.Close() + return err + } + case <-closeChan: + return nil + } + } +} + +func (c *Conn) recvLoop(conn net.Conn) error { + buf := make([]byte, bufferSize) + for { + // package length + conn.SetReadDeadline(time.Now().Add(c.recvTimeout)) + _, err := io.ReadFull(conn, buf[:4]) + if err != nil { + return err + } + + blen := int(binary.BigEndian.Uint32(buf[:4])) + if cap(buf) < blen { + buf = make([]byte, blen) + } + + _, err = io.ReadFull(conn, buf[:blen]) + conn.SetReadDeadline(time.Time{}) + if err != nil { + return err + } + + res := responseHeader{} + _, err = decodePacket(buf[:16], &res) + if err != nil { + return err + } + + if res.Xid == -1 { + res := &watcherEvent{} + _, err := decodePacket(buf[16:blen], res) + if err != nil { + return err + } + ev := Event{ + Type: res.Type, + State: res.State, + Path: res.Path, + Err: nil, + } + select { + case c.eventChan <- ev: + default: + } + wTypes := make([]watchType, 0, 2) + switch res.Type { + case EventNodeCreated: + wTypes = append(wTypes, watchTypeExist) + case EventNodeDeleted, EventNodeDataChanged: + wTypes = append(wTypes, watchTypeExist, watchTypeData, watchTypeChild) + case EventNodeChildrenChanged: + wTypes = append(wTypes, watchTypeChild) + } + c.watchersLock.Lock() + for _, t := range wTypes { + wpt := watchPathType{res.Path, t} + if watchers := c.watchers[wpt]; watchers != nil && len(watchers) > 0 { + for _, ch := range watchers { + ch <- ev + close(ch) + } + delete(c.watchers, wpt) + } + } + c.watchersLock.Unlock() + } else if res.Xid == -2 { + // Ping response. Ignore. + } else if res.Xid < 0 { + c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid) + } else { + if res.Zxid > 0 { + c.lastZxid = res.Zxid + } + + c.requestsLock.Lock() + req, ok := c.requests[res.Xid] + if ok { + delete(c.requests, res.Xid) + } + c.requestsLock.Unlock() + + if !ok { + c.logger.Printf("Response for unknown request with xid %d", res.Xid) + } else { + if res.Err != 0 { + err = res.Err.toError() + } else { + _, err = decodePacket(buf[16:blen], req.recvStruct) + } + if req.recvFunc != nil { + req.recvFunc(req, &res, err) + } + req.recvChan <- response{res.Zxid, err} + if req.opcode == opClose { + return io.EOF + } + } + } + } +} + +func (c *Conn) nextXid() int32 { + return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff) +} + +func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event { + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + + ch := make(chan Event, 1) + wpt := watchPathType{path, watchType} + c.watchers[wpt] = append(c.watchers[wpt], ch) + return ch +} + +func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response { + rq := &request{ + xid: c.nextXid(), + opcode: opcode, + pkt: req, + recvStruct: res, + recvChan: make(chan response, 1), + recvFunc: recvFunc, + } + c.sendChan <- rq + return rq.recvChan +} + +func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { + r := <-c.queueRequest(opcode, req, res, recvFunc) + return r.zxid, r.err +} + +func (c *Conn) AddAuth(scheme string, auth []byte) error { + _, err := c.request(opSetAuth, &setAuthRequest{Type: 0, Scheme: scheme, Auth: auth}, &setAuthResponse{}, nil) + return err +} + +func (c *Conn) Children(path string) ([]string, *Stat, error) { + res := &getChildren2Response{} + _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil) + return res.Children, &res.Stat, err +} + +func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) { + var ech <-chan Event + res := &getChildren2Response{} + _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { + if err == nil { + ech = c.addWatcher(path, watchTypeChild) + } + }) + if err != nil { + return nil, nil, nil, err + } + return res.Children, &res.Stat, ech, err +} + +func (c *Conn) Get(path string) ([]byte, *Stat, error) { + res := &getDataResponse{} + _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil) + return res.Data, &res.Stat, err +} + +// GetW returns the contents of a znode and sets a watch +func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) { + var ech <-chan Event + res := &getDataResponse{} + _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { + if err == nil { + ech = c.addWatcher(path, watchTypeData) + } + }) + if err != nil { + return nil, nil, nil, err + } + return res.Data, &res.Stat, ech, err +} + +func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) { + if path == "" { + return nil, ErrInvalidPath + } + res := &setDataResponse{} + _, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil) + return &res.Stat, err +} + +func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, error) { + res := &createResponse{} + _, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil) + return res.Path, err +} + +// CreateProtectedEphemeralSequential fixes a race condition if the server crashes +// after it creates the node. On reconnect the session may still be valid so the +// ephemeral node still exists. Therefore, on reconnect we need to check if a node +// with a GUID generated on create exists. +func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl []ACL) (string, error) { + var guid [16]byte + _, err := io.ReadFull(rand.Reader, guid[:16]) + if err != nil { + return "", err + } + guidStr := fmt.Sprintf("%x", guid) + + parts := strings.Split(path, "/") + parts[len(parts)-1] = fmt.Sprintf("%s%s-%s", protectedPrefix, guidStr, parts[len(parts)-1]) + rootPath := strings.Join(parts[:len(parts)-1], "/") + protectedPath := strings.Join(parts, "/") + + var newPath string + for i := 0; i < 3; i++ { + newPath, err = c.Create(protectedPath, data, FlagEphemeral|FlagSequence, acl) + switch err { + case ErrSessionExpired: + // No need to search for the node since it can't exist. Just try again. + case ErrConnectionClosed: + children, _, err := c.Children(rootPath) + if err != nil { + return "", err + } + for _, p := range children { + parts := strings.Split(p, "/") + if pth := parts[len(parts)-1]; strings.HasPrefix(pth, protectedPrefix) { + if g := pth[len(protectedPrefix) : len(protectedPrefix)+32]; g == guidStr { + return rootPath + "/" + p, nil + } + } + } + case nil: + return newPath, nil + default: + return "", err + } + } + return "", err +} + +func (c *Conn) Delete(path string, version int32) error { + _, err := c.request(opDelete, &DeleteRequest{path, version}, &deleteResponse{}, nil) + return err +} + +func (c *Conn) Exists(path string) (bool, *Stat, error) { + res := &existsResponse{} + _, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil) + exists := true + if err == ErrNoNode { + exists = false + err = nil + } + return exists, &res.Stat, err +} + +func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) { + var ech <-chan Event + res := &existsResponse{} + _, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { + if err == nil { + ech = c.addWatcher(path, watchTypeData) + } else if err == ErrNoNode { + ech = c.addWatcher(path, watchTypeExist) + } + }) + exists := true + if err == ErrNoNode { + exists = false + err = nil + } + if err != nil { + return false, nil, nil, err + } + return exists, &res.Stat, ech, err +} + +func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) { + res := &getAclResponse{} + _, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil) + return res.Acl, &res.Stat, err +} +func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) { + res := &setAclResponse{} + _, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil) + return &res.Stat, err +} + +func (c *Conn) Sync(path string) (string, error) { + res := &syncResponse{} + _, err := c.request(opSync, &syncRequest{Path: path}, res, nil) + return res.Path, err +} + +type MultiResponse struct { + Stat *Stat + String string +} + +// Multi executes multiple ZooKeeper operations or none of them. The provided +// ops must be one of *CreateRequest, *DeleteRequest, *SetDataRequest, or +// *CheckVersionRequest. +func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { + req := &multiRequest{ + Ops: make([]multiRequestOp, 0, len(ops)), + DoneHeader: multiHeader{Type: -1, Done: true, Err: -1}, + } + for _, op := range ops { + var opCode int32 + switch op.(type) { + case *CreateRequest: + opCode = opCreate + case *SetDataRequest: + opCode = opSetData + case *DeleteRequest: + opCode = opDelete + case *CheckVersionRequest: + opCode = opCheck + default: + return nil, fmt.Errorf("unknown operation type %T", op) + } + req.Ops = append(req.Ops, multiRequestOp{multiHeader{opCode, false, -1}, op}) + } + res := &multiResponse{} + _, err := c.request(opMulti, req, res, nil) + mr := make([]MultiResponse, len(res.Ops)) + for i, op := range res.Ops { + mr[i] = MultiResponse{Stat: op.Stat, String: op.String} + } + return mr, err +} + +// Server returns the current or last-connected server name. +func (c *Conn) Server() string { + c.serverMu.Lock() + defer c.serverMu.Unlock() + return c.server +} diff --git a/vendor/github.com/samuel/go-zookeeper/zk/constants.go b/vendor/github.com/samuel/go-zookeeper/zk/constants.go new file mode 100644 index 0000000..f9b39b9 --- /dev/null +++ b/vendor/github.com/samuel/go-zookeeper/zk/constants.go @@ -0,0 +1,240 @@ +package zk + +import ( + "errors" +) + +const ( + protocolVersion = 0 + + DefaultPort = 2181 +) + +const ( + opNotify = 0 + opCreate = 1 + opDelete = 2 + opExists = 3 + opGetData = 4 + opSetData = 5 + opGetAcl = 6 + opSetAcl = 7 + opGetChildren = 8 + opSync = 9 + opPing = 11 + opGetChildren2 = 12 + opCheck = 13 + opMulti = 14 + opClose = -11 + opSetAuth = 100 + opSetWatches = 101 + // Not in protocol, used internally + opWatcherEvent = -2 +) + +const ( + EventNodeCreated = EventType(1) + EventNodeDeleted = EventType(2) + EventNodeDataChanged = EventType(3) + EventNodeChildrenChanged = EventType(4) + + EventSession = EventType(-1) + EventNotWatching = EventType(-2) +) + +var ( + eventNames = map[EventType]string{ + EventNodeCreated: "EventNodeCreated", + EventNodeDeleted: "EventNodeDeleted", + EventNodeDataChanged: "EventNodeDataChanged", + EventNodeChildrenChanged: "EventNodeChildrenChanged", + EventSession: "EventSession", + EventNotWatching: "EventNotWatching", + } +) + +const ( + StateUnknown = State(-1) + StateDisconnected = State(0) + StateConnecting = State(1) + StateAuthFailed = State(4) + StateConnectedReadOnly = State(5) + StateSaslAuthenticated = State(6) + StateExpired = State(-112) + // StateAuthFailed = State(-113) + + StateConnected = State(100) + StateHasSession = State(101) +) + +const ( + FlagEphemeral = 1 + FlagSequence = 2 +) + +var ( + stateNames = map[State]string{ + StateUnknown: "StateUnknown", + StateDisconnected: "StateDisconnected", + StateConnectedReadOnly: "StateConnectedReadOnly", + StateSaslAuthenticated: "StateSaslAuthenticated", + StateExpired: "StateExpired", + StateAuthFailed: "StateAuthFailed", + StateConnecting: "StateConnecting", + StateConnected: "StateConnected", + StateHasSession: "StateHasSession", + } +) + +type State int32 + +func (s State) String() string { + if name := stateNames[s]; name != "" { + return name + } + return "Unknown" +} + +type ErrCode int32 + +var ( + ErrConnectionClosed = errors.New("zk: connection closed") + ErrUnknown = errors.New("zk: unknown error") + ErrAPIError = errors.New("zk: api error") + ErrNoNode = errors.New("zk: node does not exist") + ErrNoAuth = errors.New("zk: not authenticated") + ErrBadVersion = errors.New("zk: version conflict") + ErrNoChildrenForEphemerals = errors.New("zk: ephemeral nodes may not have children") + ErrNodeExists = errors.New("zk: node already exists") + ErrNotEmpty = errors.New("zk: node has children") + ErrSessionExpired = errors.New("zk: session has been expired by the server") + ErrInvalidACL = errors.New("zk: invalid ACL specified") + ErrAuthFailed = errors.New("zk: client authentication failed") + ErrClosing = errors.New("zk: zookeeper is closing") + ErrNothing = errors.New("zk: no server responsees to process") + ErrSessionMoved = errors.New("zk: session moved to another server, so operation is ignored") + + // ErrInvalidCallback = errors.New("zk: invalid callback specified") + errCodeToError = map[ErrCode]error{ + 0: nil, + errAPIError: ErrAPIError, + errNoNode: ErrNoNode, + errNoAuth: ErrNoAuth, + errBadVersion: ErrBadVersion, + errNoChildrenForEphemerals: ErrNoChildrenForEphemerals, + errNodeExists: ErrNodeExists, + errNotEmpty: ErrNotEmpty, + errSessionExpired: ErrSessionExpired, + // errInvalidCallback: ErrInvalidCallback, + errInvalidAcl: ErrInvalidACL, + errAuthFailed: ErrAuthFailed, + errClosing: ErrClosing, + errNothing: ErrNothing, + errSessionMoved: ErrSessionMoved, + } +) + +func (e ErrCode) toError() error { + if err, ok := errCodeToError[e]; ok { + return err + } + return ErrUnknown +} + +const ( + errOk = 0 + // System and server-side errors + errSystemError = -1 + errRuntimeInconsistency = -2 + errDataInconsistency = -3 + errConnectionLoss = -4 + errMarshallingError = -5 + errUnimplemented = -6 + errOperationTimeout = -7 + errBadArguments = -8 + errInvalidState = -9 + // API errors + errAPIError = ErrCode(-100) + errNoNode = ErrCode(-101) // * + errNoAuth = ErrCode(-102) + errBadVersion = ErrCode(-103) // * + errNoChildrenForEphemerals = ErrCode(-108) + errNodeExists = ErrCode(-110) // * + errNotEmpty = ErrCode(-111) + errSessionExpired = ErrCode(-112) + errInvalidCallback = ErrCode(-113) + errInvalidAcl = ErrCode(-114) + errAuthFailed = ErrCode(-115) + errClosing = ErrCode(-116) + errNothing = ErrCode(-117) + errSessionMoved = ErrCode(-118) +) + +// Constants for ACL permissions +const ( + PermRead = 1 << iota + PermWrite + PermCreate + PermDelete + PermAdmin + PermAll = 0x1f +) + +var ( + emptyPassword = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + opNames = map[int32]string{ + opNotify: "notify", + opCreate: "create", + opDelete: "delete", + opExists: "exists", + opGetData: "getData", + opSetData: "setData", + opGetAcl: "getACL", + opSetAcl: "setACL", + opGetChildren: "getChildren", + opSync: "sync", + opPing: "ping", + opGetChildren2: "getChildren2", + opCheck: "check", + opMulti: "multi", + opClose: "close", + opSetAuth: "setAuth", + opSetWatches: "setWatches", + + opWatcherEvent: "watcherEvent", + } +) + +type EventType int32 + +func (t EventType) String() string { + if name := eventNames[t]; name != "" { + return name + } + return "Unknown" +} + +// Mode is used to build custom server modes (leader|follower|standalone). +type Mode uint8 + +func (m Mode) String() string { + if name := modeNames[m]; name != "" { + return name + } + return "unknown" +} + +const ( + ModeUnknown Mode = iota + ModeLeader Mode = iota + ModeFollower Mode = iota + ModeStandalone Mode = iota +) + +var ( + modeNames = map[Mode]string{ + ModeLeader: "leader", + ModeFollower: "follower", + ModeStandalone: "standalone", + } +) diff --git a/vendor/github.com/samuel/go-zookeeper/zk/dnshostprovider.go b/vendor/github.com/samuel/go-zookeeper/zk/dnshostprovider.go new file mode 100644 index 0000000..f4bba8d --- /dev/null +++ b/vendor/github.com/samuel/go-zookeeper/zk/dnshostprovider.go @@ -0,0 +1,88 @@ +package zk + +import ( + "fmt" + "net" + "sync" +) + +// DNSHostProvider is the default HostProvider. It currently matches +// the Java StaticHostProvider, resolving hosts from DNS once during +// the call to Init. It could be easily extended to re-query DNS +// periodically or if there is trouble connecting. +type DNSHostProvider struct { + mu sync.Mutex // Protects everything, so we can add asynchronous updates later. + servers []string + curr int + last int + lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing. +} + +// Init is called first, with the servers specified in the connection +// string. It uses DNS to look up addresses for each server, then +// shuffles them all together. +func (hp *DNSHostProvider) Init(servers []string) error { + hp.mu.Lock() + defer hp.mu.Unlock() + + lookupHost := hp.lookupHost + if lookupHost == nil { + lookupHost = net.LookupHost + } + + found := []string{} + for _, server := range servers { + host, port, err := net.SplitHostPort(server) + if err != nil { + return err + } + addrs, err := lookupHost(host) + if err != nil { + return err + } + for _, addr := range addrs { + found = append(found, net.JoinHostPort(addr, port)) + } + } + + if len(found) == 0 { + return fmt.Errorf("No hosts found for addresses %q", servers) + } + + // Randomize the order of the servers to avoid creating hotspots + stringShuffle(found) + + hp.servers = found + hp.curr = -1 + hp.last = -1 + + return nil +} + +// Len returns the number of servers available +func (hp *DNSHostProvider) Len() int { + hp.mu.Lock() + defer hp.mu.Unlock() + return len(hp.servers) +} + +// Next returns the next server to connect to. retryStart will be true +// if we've looped through all known servers without Connected() being +// called. +func (hp *DNSHostProvider) Next() (server string, retryStart bool) { + hp.mu.Lock() + defer hp.mu.Unlock() + hp.curr = (hp.curr + 1) % len(hp.servers) + retryStart = hp.curr == hp.last + if hp.last == -1 { + hp.last = 0 + } + return hp.servers[hp.curr], retryStart +} + +// Connected notifies the HostProvider of a successful connection. +func (hp *DNSHostProvider) Connected() { + hp.mu.Lock() + defer hp.mu.Unlock() + hp.last = hp.curr +} diff --git a/vendor/github.com/samuel/go-zookeeper/zk/flw.go b/vendor/github.com/samuel/go-zookeeper/zk/flw.go new file mode 100644 index 0000000..3e97f96 --- /dev/null +++ b/vendor/github.com/samuel/go-zookeeper/zk/flw.go @@ -0,0 +1,266 @@ +package zk + +import ( + "bufio" + "bytes" + "fmt" + "io/ioutil" + "net" + "regexp" + "strconv" + "strings" + "time" +) + +// FLWSrvr is a FourLetterWord helper function. In particular, this function pulls the srvr output +// from the zookeeper instances and parses the output. A slice of *ServerStats structs are returned +// as well as a boolean value to indicate whether this function processed successfully. +// +// If the boolean value is false there was a problem. If the *ServerStats slice is empty or nil, +// then the error happened before we started to obtain 'srvr' values. Otherwise, one of the +// servers had an issue and the "Error" value in the struct should be inspected to determine +// which server had the issue. +func FLWSrvr(servers []string, timeout time.Duration) ([]*ServerStats, bool) { + // different parts of the regular expression that are required to parse the srvr output + const ( + zrVer = `^Zookeeper version: ([A-Za-z0-9\.\-]+), built on (\d\d/\d\d/\d\d\d\d \d\d:\d\d [A-Za-z0-9:\+\-]+)` + zrLat = `^Latency min/avg/max: (\d+)/(\d+)/(\d+)` + zrNet = `^Received: (\d+).*\n^Sent: (\d+).*\n^Connections: (\d+).*\n^Outstanding: (\d+)` + zrState = `^Zxid: (0x[A-Za-z0-9]+).*\n^Mode: (\w+).*\n^Node count: (\d+)` + ) + + // build the regex from the pieces above + re, err := regexp.Compile(fmt.Sprintf(`(?m:\A%v.*\n%v.*\n%v.*\n%v)`, zrVer, zrLat, zrNet, zrState)) + if err != nil { + return nil, false + } + + imOk := true + servers = FormatServers(servers) + ss := make([]*ServerStats, len(servers)) + + for i := range ss { + response, err := fourLetterWord(servers[i], "srvr", timeout) + + if err != nil { + ss[i] = &ServerStats{Error: err} + imOk = false + continue + } + + matches := re.FindAllStringSubmatch(string(response), -1) + + if matches == nil { + err := fmt.Errorf("unable to parse fields from zookeeper response (no regex matches)") + ss[i] = &ServerStats{Error: err} + imOk = false + continue + } + + match := matches[0][1:] + + // determine current server + var srvrMode Mode + switch match[10] { + case "leader": + srvrMode = ModeLeader + case "follower": + srvrMode = ModeFollower + case "standalone": + srvrMode = ModeStandalone + default: + srvrMode = ModeUnknown + } + + buildTime, err := time.Parse("01/02/2006 15:04 MST", match[1]) + + if err != nil { + ss[i] = &ServerStats{Error: err} + imOk = false + continue + } + + parsedInt, err := strconv.ParseInt(match[9], 0, 64) + + if err != nil { + ss[i] = &ServerStats{Error: err} + imOk = false + continue + } + + // the ZxID value is an int64 with two int32s packed inside + // the high int32 is the epoch (i.e., number of leader elections) + // the low int32 is the counter + epoch := int32(parsedInt >> 32) + counter := int32(parsedInt & 0xFFFFFFFF) + + // within the regex above, these values must be numerical + // so we can avoid useless checking of the error return value + minLatency, _ := strconv.ParseInt(match[2], 0, 64) + avgLatency, _ := strconv.ParseInt(match[3], 0, 64) + maxLatency, _ := strconv.ParseInt(match[4], 0, 64) + recv, _ := strconv.ParseInt(match[5], 0, 64) + sent, _ := strconv.ParseInt(match[6], 0, 64) + cons, _ := strconv.ParseInt(match[7], 0, 64) + outs, _ := strconv.ParseInt(match[8], 0, 64) + ncnt, _ := strconv.ParseInt(match[11], 0, 64) + + ss[i] = &ServerStats{ + Sent: sent, + Received: recv, + NodeCount: ncnt, + MinLatency: minLatency, + AvgLatency: avgLatency, + MaxLatency: maxLatency, + Connections: cons, + Outstanding: outs, + Epoch: epoch, + Counter: counter, + BuildTime: buildTime, + Mode: srvrMode, + Version: match[0], + } + } + + return ss, imOk +} + +// FLWRuok is a FourLetterWord helper function. In particular, this function +// pulls the ruok output from each server. +func FLWRuok(servers []string, timeout time.Duration) []bool { + servers = FormatServers(servers) + oks := make([]bool, len(servers)) + + for i := range oks { + response, err := fourLetterWord(servers[i], "ruok", timeout) + + if err != nil { + continue + } + + if bytes.Equal(response[:4], []byte("imok")) { + oks[i] = true + } + } + return oks +} + +// FLWCons is a FourLetterWord helper function. In particular, this function +// pulls the ruok output from each server. +// +// As with FLWSrvr, the boolean value indicates whether one of the requests had +// an issue. The Clients struct has an Error value that can be checked. +func FLWCons(servers []string, timeout time.Duration) ([]*ServerClients, bool) { + const ( + zrAddr = `^ /((?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?):(?:\d+))\[\d+\]` + zrPac = `\(queued=(\d+),recved=(\d+),sent=(\d+),sid=(0x[A-Za-z0-9]+),lop=(\w+),est=(\d+),to=(\d+),` + zrSesh = `lcxid=(0x[A-Za-z0-9]+),lzxid=(0x[A-Za-z0-9]+),lresp=(\d+),llat=(\d+),minlat=(\d+),avglat=(\d+),maxlat=(\d+)\)` + ) + + re, err := regexp.Compile(fmt.Sprintf("%v%v%v", zrAddr, zrPac, zrSesh)) + if err != nil { + return nil, false + } + + servers = FormatServers(servers) + sc := make([]*ServerClients, len(servers)) + imOk := true + + for i := range sc { + response, err := fourLetterWord(servers[i], "cons", timeout) + + if err != nil { + sc[i] = &ServerClients{Error: err} + imOk = false + continue + } + + scan := bufio.NewScanner(bytes.NewReader(response)) + + var clients []*ServerClient + + for scan.Scan() { + line := scan.Bytes() + + if len(line) == 0 { + continue + } + + m := re.FindAllStringSubmatch(string(line), -1) + + if m == nil { + err := fmt.Errorf("unable to parse fields from zookeeper response (no regex matches)") + sc[i] = &ServerClients{Error: err} + imOk = false + continue + } + + match := m[0][1:] + + queued, _ := strconv.ParseInt(match[1], 0, 64) + recvd, _ := strconv.ParseInt(match[2], 0, 64) + sent, _ := strconv.ParseInt(match[3], 0, 64) + sid, _ := strconv.ParseInt(match[4], 0, 64) + est, _ := strconv.ParseInt(match[6], 0, 64) + timeout, _ := strconv.ParseInt(match[7], 0, 32) + lcxid, _ := parseInt64(match[8]) + lzxid, _ := parseInt64(match[9]) + lresp, _ := strconv.ParseInt(match[10], 0, 64) + llat, _ := strconv.ParseInt(match[11], 0, 32) + minlat, _ := strconv.ParseInt(match[12], 0, 32) + avglat, _ := strconv.ParseInt(match[13], 0, 32) + maxlat, _ := strconv.ParseInt(match[14], 0, 32) + + clients = append(clients, &ServerClient{ + Queued: queued, + Received: recvd, + Sent: sent, + SessionID: sid, + Lcxid: int64(lcxid), + Lzxid: int64(lzxid), + Timeout: int32(timeout), + LastLatency: int32(llat), + MinLatency: int32(minlat), + AvgLatency: int32(avglat), + MaxLatency: int32(maxlat), + Established: time.Unix(est, 0), + LastResponse: time.Unix(lresp, 0), + Addr: match[0], + LastOperation: match[5], + }) + } + + sc[i] = &ServerClients{Clients: clients} + } + + return sc, imOk +} + +// parseInt64 is similar to strconv.ParseInt, but it also handles hex values that represent negative numbers +func parseInt64(s string) (int64, error) { + if strings.HasPrefix(s, "0x") { + i, err := strconv.ParseUint(s, 0, 64) + return int64(i), err + } + return strconv.ParseInt(s, 0, 64) +} + +func fourLetterWord(server, command string, timeout time.Duration) ([]byte, error) { + conn, err := net.DialTimeout("tcp", server, timeout) + if err != nil { + return nil, err + } + + // the zookeeper server should automatically close this socket + // once the command has been processed, but better safe than sorry + defer conn.Close() + + conn.SetWriteDeadline(time.Now().Add(timeout)) + _, err = conn.Write([]byte(command)) + if err != nil { + return nil, err + } + + conn.SetReadDeadline(time.Now().Add(timeout)) + return ioutil.ReadAll(conn) +} diff --git a/vendor/github.com/samuel/go-zookeeper/zk/lock.go b/vendor/github.com/samuel/go-zookeeper/zk/lock.go new file mode 100644 index 0000000..f13a8b0 --- /dev/null +++ b/vendor/github.com/samuel/go-zookeeper/zk/lock.go @@ -0,0 +1,142 @@ +package zk + +import ( + "errors" + "fmt" + "strconv" + "strings" +) + +var ( + // ErrDeadlock is returned by Lock when trying to lock twice without unlocking first + ErrDeadlock = errors.New("zk: trying to acquire a lock twice") + // ErrNotLocked is returned by Unlock when trying to release a lock that has not first be acquired. + ErrNotLocked = errors.New("zk: not locked") +) + +// Lock is a mutual exclusion lock. +type Lock struct { + c *Conn + path string + acl []ACL + lockPath string + seq int +} + +// NewLock creates a new lock instance using the provided connection, path, and acl. +// The path must be a node that is only used by this lock. A lock instances starts +// unlocked until Lock() is called. +func NewLock(c *Conn, path string, acl []ACL) *Lock { + return &Lock{ + c: c, + path: path, + acl: acl, + } +} + +func parseSeq(path string) (int, error) { + parts := strings.Split(path, "-") + return strconv.Atoi(parts[len(parts)-1]) +} + +// Lock attempts to acquire the lock. It will wait to return until the lock +// is acquired or an error occurs. If this instance already has the lock +// then ErrDeadlock is returned. +func (l *Lock) Lock() error { + if l.lockPath != "" { + return ErrDeadlock + } + + prefix := fmt.Sprintf("%s/lock-", l.path) + + path := "" + var err error + for i := 0; i < 3; i++ { + path, err = l.c.CreateProtectedEphemeralSequential(prefix, []byte{}, l.acl) + if err == ErrNoNode { + // Create parent node. + parts := strings.Split(l.path, "/") + pth := "" + for _, p := range parts[1:] { + pth += "/" + p + _, err := l.c.Create(pth, []byte{}, 0, l.acl) + if err != nil && err != ErrNodeExists { + return err + } + } + } else if err == nil { + break + } else { + return err + } + } + if err != nil { + return err + } + + seq, err := parseSeq(path) + if err != nil { + return err + } + + for { + children, _, err := l.c.Children(l.path) + if err != nil { + return err + } + + lowestSeq := seq + prevSeq := 0 + prevSeqPath := "" + for _, p := range children { + s, err := parseSeq(p) + if err != nil { + return err + } + if s < lowestSeq { + lowestSeq = s + } + if s < seq && s > prevSeq { + prevSeq = s + prevSeqPath = p + } + } + + if seq == lowestSeq { + // Acquired the lock + break + } + + // Wait on the node next in line for the lock + _, _, ch, err := l.c.GetW(l.path + "/" + prevSeqPath) + if err != nil && err != ErrNoNode { + return err + } else if err != nil && err == ErrNoNode { + // try again + continue + } + + ev := <-ch + if ev.Err != nil { + return ev.Err + } + } + + l.seq = seq + l.lockPath = path + return nil +} + +// Unlock releases an acquired lock. If the lock is not currently acquired by +// this Lock instance than ErrNotLocked is returned. +func (l *Lock) Unlock() error { + if l.lockPath == "" { + return ErrNotLocked + } + if err := l.c.Delete(l.lockPath, -1); err != nil { + return err + } + l.lockPath = "" + l.seq = 0 + return nil +} diff --git a/vendor/github.com/samuel/go-zookeeper/zk/server_help.go b/vendor/github.com/samuel/go-zookeeper/zk/server_help.go new file mode 100644 index 0000000..618185a --- /dev/null +++ b/vendor/github.com/samuel/go-zookeeper/zk/server_help.go @@ -0,0 +1,190 @@ +package zk + +import ( + "fmt" + "io" + "io/ioutil" + "math/rand" + "os" + "path/filepath" + "strings" + "time" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +type TestServer struct { + Port int + Path string + Srv *Server +} + +type TestCluster struct { + Path string + Servers []TestServer +} + +func StartTestCluster(size int, stdout, stderr io.Writer) (*TestCluster, error) { + tmpPath, err := ioutil.TempDir("", "gozk") + if err != nil { + return nil, err + } + success := false + startPort := int(rand.Int31n(6000) + 10000) + cluster := &TestCluster{Path: tmpPath} + defer func() { + if !success { + cluster.Stop() + } + }() + for serverN := 0; serverN < size; serverN++ { + srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv%d", serverN)) + if err := os.Mkdir(srvPath, 0700); err != nil { + return nil, err + } + port := startPort + serverN*3 + cfg := ServerConfig{ + ClientPort: port, + DataDir: srvPath, + } + for i := 0; i < size; i++ { + cfg.Servers = append(cfg.Servers, ServerConfigServer{ + ID: i + 1, + Host: "127.0.0.1", + PeerPort: startPort + i*3 + 1, + LeaderElectionPort: startPort + i*3 + 2, + }) + } + cfgPath := filepath.Join(srvPath, "zoo.cfg") + fi, err := os.Create(cfgPath) + if err != nil { + return nil, err + } + err = cfg.Marshall(fi) + fi.Close() + if err != nil { + return nil, err + } + + fi, err = os.Create(filepath.Join(srvPath, "myid")) + if err != nil { + return nil, err + } + _, err = fmt.Fprintf(fi, "%d\n", serverN+1) + fi.Close() + if err != nil { + return nil, err + } + + srv := &Server{ + ConfigPath: cfgPath, + Stdout: stdout, + Stderr: stderr, + } + if err := srv.Start(); err != nil { + return nil, err + } + cluster.Servers = append(cluster.Servers, TestServer{ + Path: srvPath, + Port: cfg.ClientPort, + Srv: srv, + }) + } + if err := cluster.waitForStart(10, time.Second); err != nil { + return nil, err + } + success = true + return cluster, nil +} + +func (ts *TestCluster) Connect(idx int) (*Conn, error) { + zk, _, err := Connect([]string{fmt.Sprintf("127.0.0.1:%d", ts.Servers[idx].Port)}, time.Second*15) + return zk, err +} + +func (ts *TestCluster) ConnectAll() (*Conn, <-chan Event, error) { + return ts.ConnectAllTimeout(time.Second * 15) +} + +func (ts *TestCluster) ConnectAllTimeout(sessionTimeout time.Duration) (*Conn, <-chan Event, error) { + hosts := make([]string, len(ts.Servers)) + for i, srv := range ts.Servers { + hosts[i] = fmt.Sprintf("127.0.0.1:%d", srv.Port) + } + zk, ch, err := Connect(hosts, sessionTimeout) + return zk, ch, err +} + +func (ts *TestCluster) Stop() error { + for _, srv := range ts.Servers { + srv.Srv.Stop() + } + defer os.RemoveAll(ts.Path) + return ts.waitForStop(5, time.Second) +} + +// waitForStart blocks until the cluster is up +func (ts *TestCluster) waitForStart(maxRetry int, interval time.Duration) error { + // verify that the servers are up with SRVR + serverAddrs := make([]string, len(ts.Servers)) + for i, s := range ts.Servers { + serverAddrs[i] = fmt.Sprintf("127.0.0.1:%d", s.Port) + } + + for i := 0; i < maxRetry; i++ { + _, ok := FLWSrvr(serverAddrs, time.Second) + if ok { + return nil + } + time.Sleep(interval) + } + return fmt.Errorf("unable to verify health of servers") +} + +// waitForStop blocks until the cluster is down +func (ts *TestCluster) waitForStop(maxRetry int, interval time.Duration) error { + // verify that the servers are up with RUOK + serverAddrs := make([]string, len(ts.Servers)) + for i, s := range ts.Servers { + serverAddrs[i] = fmt.Sprintf("127.0.0.1:%d", s.Port) + } + + var success bool + for i := 0; i < maxRetry && !success; i++ { + success = true + for _, ok := range FLWRuok(serverAddrs, time.Second) { + if ok { + success = false + } + } + if !success { + time.Sleep(interval) + } + } + if !success { + return fmt.Errorf("unable to verify servers are down") + } + return nil +} + +func (tc *TestCluster) StartServer(server string) { + for _, s := range tc.Servers { + if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) { + s.Srv.Start() + return + } + } + panic(fmt.Sprintf("Unknown server: %s", server)) +} + +func (tc *TestCluster) StopServer(server string) { + for _, s := range tc.Servers { + if strings.HasSuffix(server, fmt.Sprintf(":%d", s.Port)) { + s.Srv.Stop() + return + } + } + panic(fmt.Sprintf("Unknown server: %s", server)) +} diff --git a/vendor/github.com/samuel/go-zookeeper/zk/server_java.go b/vendor/github.com/samuel/go-zookeeper/zk/server_java.go new file mode 100644 index 0000000..e553ec1 --- /dev/null +++ b/vendor/github.com/samuel/go-zookeeper/zk/server_java.go @@ -0,0 +1,136 @@ +package zk + +import ( + "fmt" + "io" + "os" + "os/exec" + "path/filepath" +) + +type ErrMissingServerConfigField string + +func (e ErrMissingServerConfigField) Error() string { + return fmt.Sprintf("zk: missing server config field '%s'", string(e)) +} + +const ( + DefaultServerTickTime = 2000 + DefaultServerInitLimit = 10 + DefaultServerSyncLimit = 5 + DefaultServerAutoPurgeSnapRetainCount = 3 + DefaultPeerPort = 2888 + DefaultLeaderElectionPort = 3888 +) + +type ServerConfigServer struct { + ID int + Host string + PeerPort int + LeaderElectionPort int +} + +type ServerConfig struct { + TickTime int // Number of milliseconds of each tick + InitLimit int // Number of ticks that the initial synchronization phase can take + SyncLimit int // Number of ticks that can pass between sending a request and getting an acknowledgement + DataDir string // Direcrory where the snapshot is stored + ClientPort int // Port at which clients will connect + AutoPurgeSnapRetainCount int // Number of snapshots to retain in dataDir + AutoPurgePurgeInterval int // Purge task internal in hours (0 to disable auto purge) + Servers []ServerConfigServer +} + +func (sc ServerConfig) Marshall(w io.Writer) error { + if sc.DataDir == "" { + return ErrMissingServerConfigField("dataDir") + } + fmt.Fprintf(w, "dataDir=%s\n", sc.DataDir) + if sc.TickTime <= 0 { + sc.TickTime = DefaultServerTickTime + } + fmt.Fprintf(w, "tickTime=%d\n", sc.TickTime) + if sc.InitLimit <= 0 { + sc.InitLimit = DefaultServerInitLimit + } + fmt.Fprintf(w, "initLimit=%d\n", sc.InitLimit) + if sc.SyncLimit <= 0 { + sc.SyncLimit = DefaultServerSyncLimit + } + fmt.Fprintf(w, "syncLimit=%d\n", sc.SyncLimit) + if sc.ClientPort <= 0 { + sc.ClientPort = DefaultPort + } + fmt.Fprintf(w, "clientPort=%d\n", sc.ClientPort) + if sc.AutoPurgePurgeInterval > 0 { + if sc.AutoPurgeSnapRetainCount <= 0 { + sc.AutoPurgeSnapRetainCount = DefaultServerAutoPurgeSnapRetainCount + } + fmt.Fprintf(w, "autopurge.snapRetainCount=%d\n", sc.AutoPurgeSnapRetainCount) + fmt.Fprintf(w, "autopurge.purgeInterval=%d\n", sc.AutoPurgePurgeInterval) + } + if len(sc.Servers) > 0 { + for _, srv := range sc.Servers { + if srv.PeerPort <= 0 { + srv.PeerPort = DefaultPeerPort + } + if srv.LeaderElectionPort <= 0 { + srv.LeaderElectionPort = DefaultLeaderElectionPort + } + fmt.Fprintf(w, "server.%d=%s:%d:%d\n", srv.ID, srv.Host, srv.PeerPort, srv.LeaderElectionPort) + } + } + return nil +} + +var jarSearchPaths = []string{ + "zookeeper-*/contrib/fatjar/zookeeper-*-fatjar.jar", + "../zookeeper-*/contrib/fatjar/zookeeper-*-fatjar.jar", + "/usr/share/java/zookeeper-*.jar", + "/usr/local/zookeeper-*/contrib/fatjar/zookeeper-*-fatjar.jar", + "/usr/local/Cellar/zookeeper/*/libexec/contrib/fatjar/zookeeper-*-fatjar.jar", +} + +func findZookeeperFatJar() string { + var paths []string + zkPath := os.Getenv("ZOOKEEPER_PATH") + if zkPath == "" { + paths = jarSearchPaths + } else { + paths = []string{filepath.Join(zkPath, "contrib/fatjar/zookeeper-*-fatjar.jar")} + } + for _, path := range paths { + matches, _ := filepath.Glob(path) + // TODO: could sort by version and pick latest + if len(matches) > 0 { + return matches[0] + } + } + return "" +} + +type Server struct { + JarPath string + ConfigPath string + Stdout, Stderr io.Writer + + cmd *exec.Cmd +} + +func (srv *Server) Start() error { + if srv.JarPath == "" { + srv.JarPath = findZookeeperFatJar() + if srv.JarPath == "" { + return fmt.Errorf("zk: unable to find server jar") + } + } + srv.cmd = exec.Command("java", "-jar", srv.JarPath, "server", srv.ConfigPath) + srv.cmd.Stdout = srv.Stdout + srv.cmd.Stderr = srv.Stderr + return srv.cmd.Start() +} + +func (srv *Server) Stop() error { + srv.cmd.Process.Signal(os.Kill) + return srv.cmd.Wait() +} diff --git a/vendor/github.com/samuel/go-zookeeper/zk/structs.go b/vendor/github.com/samuel/go-zookeeper/zk/structs.go new file mode 100644 index 0000000..02cd3f3 --- /dev/null +++ b/vendor/github.com/samuel/go-zookeeper/zk/structs.go @@ -0,0 +1,600 @@ +package zk + +import ( + "encoding/binary" + "errors" + "log" + "reflect" + "runtime" + "time" +) + +var ( + ErrUnhandledFieldType = errors.New("zk: unhandled field type") + ErrPtrExpected = errors.New("zk: encode/decode expect a non-nil pointer to struct") + ErrShortBuffer = errors.New("zk: buffer too small") +) + +type defaultLogger struct{} + +func (defaultLogger) Printf(format string, a ...interface{}) { + log.Printf(format, a...) +} + +type ACL struct { + Perms int32 + Scheme string + ID string +} + +type Stat struct { + Czxid int64 // The zxid of the change that caused this znode to be created. + Mzxid int64 // The zxid of the change that last modified this znode. + Ctime int64 // The time in milliseconds from epoch when this znode was created. + Mtime int64 // The time in milliseconds from epoch when this znode was last modified. + Version int32 // The number of changes to the data of this znode. + Cversion int32 // The number of changes to the children of this znode. + Aversion int32 // The number of changes to the ACL of this znode. + EphemeralOwner int64 // The session id of the owner of this znode if the znode is an ephemeral node. If it is not an ephemeral node, it will be zero. + DataLength int32 // The length of the data field of this znode. + NumChildren int32 // The number of children of this znode. + Pzxid int64 // last modified children +} + +// ServerClient is the information for a single Zookeeper client and its session. +// This is used to parse/extract the output fo the `cons` command. +type ServerClient struct { + Queued int64 + Received int64 + Sent int64 + SessionID int64 + Lcxid int64 + Lzxid int64 + Timeout int32 + LastLatency int32 + MinLatency int32 + AvgLatency int32 + MaxLatency int32 + Established time.Time + LastResponse time.Time + Addr string + LastOperation string // maybe? + Error error +} + +// ServerClients is a struct for the FLWCons() function. It's used to provide +// the list of Clients. +// +// This is needed because FLWCons() takes multiple servers. +type ServerClients struct { + Clients []*ServerClient + Error error +} + +// ServerStats is the information pulled from the Zookeeper `stat` command. +type ServerStats struct { + Sent int64 + Received int64 + NodeCount int64 + MinLatency int64 + AvgLatency int64 + MaxLatency int64 + Connections int64 + Outstanding int64 + Epoch int32 + Counter int32 + BuildTime time.Time + Mode Mode + Version string + Error error +} + +type requestHeader struct { + Xid int32 + Opcode int32 +} + +type responseHeader struct { + Xid int32 + Zxid int64 + Err ErrCode +} + +type multiHeader struct { + Type int32 + Done bool + Err ErrCode +} + +type auth struct { + Type int32 + Scheme string + Auth []byte +} + +// Generic request structs + +type pathRequest struct { + Path string +} + +type PathVersionRequest struct { + Path string + Version int32 +} + +type pathWatchRequest struct { + Path string + Watch bool +} + +type pathResponse struct { + Path string +} + +type statResponse struct { + Stat Stat +} + +// + +type CheckVersionRequest PathVersionRequest +type closeRequest struct{} +type closeResponse struct{} + +type connectRequest struct { + ProtocolVersion int32 + LastZxidSeen int64 + TimeOut int32 + SessionID int64 + Passwd []byte +} + +type connectResponse struct { + ProtocolVersion int32 + TimeOut int32 + SessionID int64 + Passwd []byte +} + +type CreateRequest struct { + Path string + Data []byte + Acl []ACL + Flags int32 +} + +type createResponse pathResponse +type DeleteRequest PathVersionRequest +type deleteResponse struct{} + +type errorResponse struct { + Err int32 +} + +type existsRequest pathWatchRequest +type existsResponse statResponse +type getAclRequest pathRequest + +type getAclResponse struct { + Acl []ACL + Stat Stat +} + +type getChildrenRequest pathRequest + +type getChildrenResponse struct { + Children []string +} + +type getChildren2Request pathWatchRequest + +type getChildren2Response struct { + Children []string + Stat Stat +} + +type getDataRequest pathWatchRequest + +type getDataResponse struct { + Data []byte + Stat Stat +} + +type getMaxChildrenRequest pathRequest + +type getMaxChildrenResponse struct { + Max int32 +} + +type getSaslRequest struct { + Token []byte +} + +type pingRequest struct{} +type pingResponse struct{} + +type setAclRequest struct { + Path string + Acl []ACL + Version int32 +} + +type setAclResponse statResponse + +type SetDataRequest struct { + Path string + Data []byte + Version int32 +} + +type setDataResponse statResponse + +type setMaxChildren struct { + Path string + Max int32 +} + +type setSaslRequest struct { + Token string +} + +type setSaslResponse struct { + Token string +} + +type setWatchesRequest struct { + RelativeZxid int64 + DataWatches []string + ExistWatches []string + ChildWatches []string +} + +type setWatchesResponse struct{} + +type syncRequest pathRequest +type syncResponse pathResponse + +type setAuthRequest auth +type setAuthResponse struct{} + +type multiRequestOp struct { + Header multiHeader + Op interface{} +} +type multiRequest struct { + Ops []multiRequestOp + DoneHeader multiHeader +} +type multiResponseOp struct { + Header multiHeader + String string + Stat *Stat +} +type multiResponse struct { + Ops []multiResponseOp + DoneHeader multiHeader +} + +func (r *multiRequest) Encode(buf []byte) (int, error) { + total := 0 + for _, op := range r.Ops { + op.Header.Done = false + n, err := encodePacketValue(buf[total:], reflect.ValueOf(op)) + if err != nil { + return total, err + } + total += n + } + r.DoneHeader.Done = true + n, err := encodePacketValue(buf[total:], reflect.ValueOf(r.DoneHeader)) + if err != nil { + return total, err + } + total += n + + return total, nil +} + +func (r *multiRequest) Decode(buf []byte) (int, error) { + r.Ops = make([]multiRequestOp, 0) + r.DoneHeader = multiHeader{-1, true, -1} + total := 0 + for { + header := &multiHeader{} + n, err := decodePacketValue(buf[total:], reflect.ValueOf(header)) + if err != nil { + return total, err + } + total += n + if header.Done { + r.DoneHeader = *header + break + } + + req := requestStructForOp(header.Type) + if req == nil { + return total, ErrAPIError + } + n, err = decodePacketValue(buf[total:], reflect.ValueOf(req)) + if err != nil { + return total, err + } + total += n + r.Ops = append(r.Ops, multiRequestOp{*header, req}) + } + return total, nil +} + +func (r *multiResponse) Decode(buf []byte) (int, error) { + r.Ops = make([]multiResponseOp, 0) + r.DoneHeader = multiHeader{-1, true, -1} + total := 0 + for { + header := &multiHeader{} + n, err := decodePacketValue(buf[total:], reflect.ValueOf(header)) + if err != nil { + return total, err + } + total += n + if header.Done { + r.DoneHeader = *header + break + } + + res := multiResponseOp{Header: *header} + var w reflect.Value + switch header.Type { + default: + return total, ErrAPIError + case opCreate: + w = reflect.ValueOf(&res.String) + case opSetData: + res.Stat = new(Stat) + w = reflect.ValueOf(res.Stat) + case opCheck, opDelete: + } + if w.IsValid() { + n, err := decodePacketValue(buf[total:], w) + if err != nil { + return total, err + } + total += n + } + r.Ops = append(r.Ops, res) + } + return total, nil +} + +type watcherEvent struct { + Type EventType + State State + Path string +} + +type decoder interface { + Decode(buf []byte) (int, error) +} + +type encoder interface { + Encode(buf []byte) (int, error) +} + +func decodePacket(buf []byte, st interface{}) (n int, err error) { + defer func() { + if r := recover(); r != nil { + if e, ok := r.(runtime.Error); ok && e.Error() == "runtime error: slice bounds out of range" { + err = ErrShortBuffer + } else { + panic(r) + } + } + }() + + v := reflect.ValueOf(st) + if v.Kind() != reflect.Ptr || v.IsNil() { + return 0, ErrPtrExpected + } + return decodePacketValue(buf, v) +} + +func decodePacketValue(buf []byte, v reflect.Value) (int, error) { + rv := v + kind := v.Kind() + if kind == reflect.Ptr { + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + v = v.Elem() + kind = v.Kind() + } + + n := 0 + switch kind { + default: + return n, ErrUnhandledFieldType + case reflect.Struct: + if de, ok := rv.Interface().(decoder); ok { + return de.Decode(buf) + } else if de, ok := v.Interface().(decoder); ok { + return de.Decode(buf) + } else { + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + n2, err := decodePacketValue(buf[n:], field) + n += n2 + if err != nil { + return n, err + } + } + } + case reflect.Bool: + v.SetBool(buf[n] != 0) + n++ + case reflect.Int32: + v.SetInt(int64(binary.BigEndian.Uint32(buf[n : n+4]))) + n += 4 + case reflect.Int64: + v.SetInt(int64(binary.BigEndian.Uint64(buf[n : n+8]))) + n += 8 + case reflect.String: + ln := int(binary.BigEndian.Uint32(buf[n : n+4])) + v.SetString(string(buf[n+4 : n+4+ln])) + n += 4 + ln + case reflect.Slice: + switch v.Type().Elem().Kind() { + default: + count := int(binary.BigEndian.Uint32(buf[n : n+4])) + n += 4 + values := reflect.MakeSlice(v.Type(), count, count) + v.Set(values) + for i := 0; i < count; i++ { + n2, err := decodePacketValue(buf[n:], values.Index(i)) + n += n2 + if err != nil { + return n, err + } + } + case reflect.Uint8: + ln := int(int32(binary.BigEndian.Uint32(buf[n : n+4]))) + if ln < 0 { + n += 4 + v.SetBytes(nil) + } else { + bytes := make([]byte, ln) + copy(bytes, buf[n+4:n+4+ln]) + v.SetBytes(bytes) + n += 4 + ln + } + } + } + return n, nil +} + +func encodePacket(buf []byte, st interface{}) (n int, err error) { + defer func() { + if r := recover(); r != nil { + if e, ok := r.(runtime.Error); ok && e.Error() == "runtime error: slice bounds out of range" { + err = ErrShortBuffer + } else { + panic(r) + } + } + }() + + v := reflect.ValueOf(st) + if v.Kind() != reflect.Ptr || v.IsNil() { + return 0, ErrPtrExpected + } + return encodePacketValue(buf, v) +} + +func encodePacketValue(buf []byte, v reflect.Value) (int, error) { + rv := v + for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { + v = v.Elem() + } + + n := 0 + switch v.Kind() { + default: + return n, ErrUnhandledFieldType + case reflect.Struct: + if en, ok := rv.Interface().(encoder); ok { + return en.Encode(buf) + } else if en, ok := v.Interface().(encoder); ok { + return en.Encode(buf) + } else { + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + n2, err := encodePacketValue(buf[n:], field) + n += n2 + if err != nil { + return n, err + } + } + } + case reflect.Bool: + if v.Bool() { + buf[n] = 1 + } else { + buf[n] = 0 + } + n++ + case reflect.Int32: + binary.BigEndian.PutUint32(buf[n:n+4], uint32(v.Int())) + n += 4 + case reflect.Int64: + binary.BigEndian.PutUint64(buf[n:n+8], uint64(v.Int())) + n += 8 + case reflect.String: + str := v.String() + binary.BigEndian.PutUint32(buf[n:n+4], uint32(len(str))) + copy(buf[n+4:n+4+len(str)], []byte(str)) + n += 4 + len(str) + case reflect.Slice: + switch v.Type().Elem().Kind() { + default: + count := v.Len() + startN := n + n += 4 + for i := 0; i < count; i++ { + n2, err := encodePacketValue(buf[n:], v.Index(i)) + n += n2 + if err != nil { + return n, err + } + } + binary.BigEndian.PutUint32(buf[startN:startN+4], uint32(count)) + case reflect.Uint8: + if v.IsNil() { + binary.BigEndian.PutUint32(buf[n:n+4], uint32(0xffffffff)) + n += 4 + } else { + bytes := v.Bytes() + binary.BigEndian.PutUint32(buf[n:n+4], uint32(len(bytes))) + copy(buf[n+4:n+4+len(bytes)], bytes) + n += 4 + len(bytes) + } + } + } + return n, nil +} + +func requestStructForOp(op int32) interface{} { + switch op { + case opClose: + return &closeRequest{} + case opCreate: + return &CreateRequest{} + case opDelete: + return &DeleteRequest{} + case opExists: + return &existsRequest{} + case opGetAcl: + return &getAclRequest{} + case opGetChildren: + return &getChildrenRequest{} + case opGetChildren2: + return &getChildren2Request{} + case opGetData: + return &getDataRequest{} + case opPing: + return &pingRequest{} + case opSetAcl: + return &setAclRequest{} + case opSetData: + return &SetDataRequest{} + case opSetWatches: + return &setWatchesRequest{} + case opSync: + return &syncRequest{} + case opSetAuth: + return &setAuthRequest{} + case opCheck: + return &CheckVersionRequest{} + case opMulti: + return &multiRequest{} + } + return nil +} diff --git a/vendor/github.com/samuel/go-zookeeper/zk/util.go b/vendor/github.com/samuel/go-zookeeper/zk/util.go new file mode 100644 index 0000000..769bbe8 --- /dev/null +++ b/vendor/github.com/samuel/go-zookeeper/zk/util.go @@ -0,0 +1,54 @@ +package zk + +import ( + "crypto/sha1" + "encoding/base64" + "fmt" + "math/rand" + "strconv" + "strings" +) + +// AuthACL produces an ACL list containing a single ACL which uses the +// provided permissions, with the scheme "auth", and ID "", which is used +// by ZooKeeper to represent any authenticated user. +func AuthACL(perms int32) []ACL { + return []ACL{{perms, "auth", ""}} +} + +// WorldACL produces an ACL list containing a single ACL which uses the +// provided permissions, with the scheme "world", and ID "anyone", which +// is used by ZooKeeper to represent any user at all. +func WorldACL(perms int32) []ACL { + return []ACL{{perms, "world", "anyone"}} +} + +func DigestACL(perms int32, user, password string) []ACL { + userPass := []byte(fmt.Sprintf("%s:%s", user, password)) + h := sha1.New() + if n, err := h.Write(userPass); err != nil || n != len(userPass) { + panic("SHA1 failed") + } + digest := base64.StdEncoding.EncodeToString(h.Sum(nil)) + return []ACL{{perms, "digest", fmt.Sprintf("%s:%s", user, digest)}} +} + +// FormatServers takes a slice of addresses, and makes sure they are in a format +// that resembles :. If the server has no port provided, the +// DefaultPort constant is added to the end. +func FormatServers(servers []string) []string { + for i := range servers { + if !strings.Contains(servers[i], ":") { + servers[i] = servers[i] + ":" + strconv.Itoa(DefaultPort) + } + } + return servers +} + +// stringShuffle performs a Fisher-Yates shuffle on a slice of strings +func stringShuffle(s []string) { + for i := len(s) - 1; i > 0; i-- { + j := rand.Intn(i + 1) + s[i], s[j] = s[j], s[i] + } +} diff --git a/vendor/github.com/stretchr/testify/LICENSE b/vendor/github.com/stretchr/testify/LICENSE new file mode 100644 index 0000000..473b670 --- /dev/null +++ b/vendor/github.com/stretchr/testify/LICENSE @@ -0,0 +1,22 @@ +Copyright (c) 2012 - 2013 Mat Ryer and Tyler Bunnell + +Please consider promoting this project if you find it useful. + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without restriction, +including without limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of the Software, +and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included +in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT +OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE +OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/stretchr/testify/assert/assertion_forward.go b/vendor/github.com/stretchr/testify/assert/assertion_forward.go new file mode 100644 index 0000000..368992b --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/assertion_forward.go @@ -0,0 +1,348 @@ +/* +* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen +* THIS FILE MUST NOT BE EDITED BY HAND + */ + +package assert + +import ( + http "net/http" + url "net/url" + time "time" +) + +// Condition uses a Comparison to assert a complex condition. +func (a *Assertions) Condition(comp Comparison, msgAndArgs ...interface{}) bool { + return Condition(a.t, comp, msgAndArgs...) +} + +// Contains asserts that the specified string, list(array, slice...) or map contains the +// specified substring or element. +// +// a.Contains("Hello World", "World", "But 'Hello World' does contain 'World'") +// a.Contains(["Hello", "World"], "World", "But ["Hello", "World"] does contain 'World'") +// a.Contains({"Hello": "World"}, "Hello", "But {'Hello': 'World'} does contain 'Hello'") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) Contains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool { + return Contains(a.t, s, contains, msgAndArgs...) +} + +// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// a.Empty(obj) +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool { + return Empty(a.t, object, msgAndArgs...) +} + +// Equal asserts that two objects are equal. +// +// a.Equal(123, 123, "123 and 123 should be equal") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) Equal(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + return Equal(a.t, expected, actual, msgAndArgs...) +} + +// EqualError asserts that a function returned an error (i.e. not `nil`) +// and that it is equal to the provided error. +// +// actualObj, err := SomeFunction() +// if assert.Error(t, err, "An error was expected") { +// assert.Equal(t, err, expectedError) +// } +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) EqualError(theError error, errString string, msgAndArgs ...interface{}) bool { + return EqualError(a.t, theError, errString, msgAndArgs...) +} + +// EqualValues asserts that two objects are equal or convertable to the same types +// and equal. +// +// a.EqualValues(uint32(123), int32(123), "123 and 123 should be equal") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + return EqualValues(a.t, expected, actual, msgAndArgs...) +} + +// Error asserts that a function returned an error (i.e. not `nil`). +// +// actualObj, err := SomeFunction() +// if a.Error(err, "An error was expected") { +// assert.Equal(t, err, expectedError) +// } +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool { + return Error(a.t, err, msgAndArgs...) +} + +// Exactly asserts that two objects are equal is value and type. +// +// a.Exactly(int32(123), int64(123), "123 and 123 should NOT be equal") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) Exactly(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + return Exactly(a.t, expected, actual, msgAndArgs...) +} + +// Fail reports a failure through +func (a *Assertions) Fail(failureMessage string, msgAndArgs ...interface{}) bool { + return Fail(a.t, failureMessage, msgAndArgs...) +} + +// FailNow fails test +func (a *Assertions) FailNow(failureMessage string, msgAndArgs ...interface{}) bool { + return FailNow(a.t, failureMessage, msgAndArgs...) +} + +// False asserts that the specified value is false. +// +// a.False(myBool, "myBool should be false") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) False(value bool, msgAndArgs ...interface{}) bool { + return False(a.t, value, msgAndArgs...) +} + +// HTTPBodyContains asserts that a specified handler returns a +// body that contains a string. +// +// a.HTTPBodyContains(myHandler, "www.google.com", nil, "I'm Feeling Lucky") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPBodyContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) bool { + return HTTPBodyContains(a.t, handler, method, url, values, str) +} + +// HTTPBodyNotContains asserts that a specified handler returns a +// body that does not contain a string. +// +// a.HTTPBodyNotContains(myHandler, "www.google.com", nil, "I'm Feeling Lucky") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPBodyNotContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}) bool { + return HTTPBodyNotContains(a.t, handler, method, url, values, str) +} + +// HTTPError asserts that a specified handler returns an error status code. +// +// a.HTTPError(myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPError(handler http.HandlerFunc, method string, url string, values url.Values) bool { + return HTTPError(a.t, handler, method, url, values) +} + +// HTTPRedirect asserts that a specified handler returns a redirect status code. +// +// a.HTTPRedirect(myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPRedirect(handler http.HandlerFunc, method string, url string, values url.Values) bool { + return HTTPRedirect(a.t, handler, method, url, values) +} + +// HTTPSuccess asserts that a specified handler returns a success status code. +// +// a.HTTPSuccess(myHandler, "POST", "http://www.google.com", nil) +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) HTTPSuccess(handler http.HandlerFunc, method string, url string, values url.Values) bool { + return HTTPSuccess(a.t, handler, method, url, values) +} + +// Implements asserts that an object is implemented by the specified interface. +// +// a.Implements((*MyInterface)(nil), new(MyObject), "MyObject") +func (a *Assertions) Implements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool { + return Implements(a.t, interfaceObject, object, msgAndArgs...) +} + +// InDelta asserts that the two numerals are within delta of each other. +// +// a.InDelta(math.Pi, (22 / 7.0), 0.01) +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) InDelta(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + return InDelta(a.t, expected, actual, delta, msgAndArgs...) +} + +// InDeltaSlice is the same as InDelta, except it compares two slices. +func (a *Assertions) InDeltaSlice(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + return InDeltaSlice(a.t, expected, actual, delta, msgAndArgs...) +} + +// InEpsilon asserts that expected and actual have a relative error less than epsilon +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) InEpsilon(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { + return InEpsilon(a.t, expected, actual, epsilon, msgAndArgs...) +} + +// InEpsilonSlice is the same as InEpsilon, except it compares two slices. +func (a *Assertions) InEpsilonSlice(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + return InEpsilonSlice(a.t, expected, actual, delta, msgAndArgs...) +} + +// IsType asserts that the specified objects are of the same type. +func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { + return IsType(a.t, expectedType, object, msgAndArgs...) +} + +// JSONEq asserts that two JSON strings are equivalent. +// +// a.JSONEq(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) JSONEq(expected string, actual string, msgAndArgs ...interface{}) bool { + return JSONEq(a.t, expected, actual, msgAndArgs...) +} + +// Len asserts that the specified object has specific length. +// Len also fails if the object has a type that len() not accept. +// +// a.Len(mySlice, 3, "The size of slice is not 3") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) Len(object interface{}, length int, msgAndArgs ...interface{}) bool { + return Len(a.t, object, length, msgAndArgs...) +} + +// Nil asserts that the specified object is nil. +// +// a.Nil(err, "err should be nothing") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) Nil(object interface{}, msgAndArgs ...interface{}) bool { + return Nil(a.t, object, msgAndArgs...) +} + +// NoError asserts that a function returned no error (i.e. `nil`). +// +// actualObj, err := SomeFunction() +// if a.NoError(err) { +// assert.Equal(t, actualObj, expectedObj) +// } +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) NoError(err error, msgAndArgs ...interface{}) bool { + return NoError(a.t, err, msgAndArgs...) +} + +// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the +// specified substring or element. +// +// a.NotContains("Hello World", "Earth", "But 'Hello World' does NOT contain 'Earth'") +// a.NotContains(["Hello", "World"], "Earth", "But ['Hello', 'World'] does NOT contain 'Earth'") +// a.NotContains({"Hello": "World"}, "Earth", "But {'Hello': 'World'} does NOT contain 'Earth'") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) NotContains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool { + return NotContains(a.t, s, contains, msgAndArgs...) +} + +// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// if a.NotEmpty(obj) { +// assert.Equal(t, "two", obj[1]) +// } +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) bool { + return NotEmpty(a.t, object, msgAndArgs...) +} + +// NotEqual asserts that the specified values are NOT equal. +// +// a.NotEqual(obj1, obj2, "two objects shouldn't be equal") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) NotEqual(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { + return NotEqual(a.t, expected, actual, msgAndArgs...) +} + +// NotNil asserts that the specified object is not nil. +// +// a.NotNil(err, "err should be something") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) NotNil(object interface{}, msgAndArgs ...interface{}) bool { + return NotNil(a.t, object, msgAndArgs...) +} + +// NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic. +// +// a.NotPanics(func(){ +// RemainCalm() +// }, "Calling RemainCalm() should NOT panic") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) NotPanics(f PanicTestFunc, msgAndArgs ...interface{}) bool { + return NotPanics(a.t, f, msgAndArgs...) +} + +// NotRegexp asserts that a specified regexp does not match a string. +// +// a.NotRegexp(regexp.MustCompile("starts"), "it's starting") +// a.NotRegexp("^start", "it's not starting") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) NotRegexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { + return NotRegexp(a.t, rx, str, msgAndArgs...) +} + +// NotZero asserts that i is not the zero value for its type and returns the truth. +func (a *Assertions) NotZero(i interface{}, msgAndArgs ...interface{}) bool { + return NotZero(a.t, i, msgAndArgs...) +} + +// Panics asserts that the code inside the specified PanicTestFunc panics. +// +// a.Panics(func(){ +// GoCrazy() +// }, "Calling GoCrazy() should panic") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) Panics(f PanicTestFunc, msgAndArgs ...interface{}) bool { + return Panics(a.t, f, msgAndArgs...) +} + +// Regexp asserts that a specified regexp matches a string. +// +// a.Regexp(regexp.MustCompile("start"), "it's starting") +// a.Regexp("start...$", "it's not starting") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) Regexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { + return Regexp(a.t, rx, str, msgAndArgs...) +} + +// True asserts that the specified value is true. +// +// a.True(myBool, "myBool should be true") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) True(value bool, msgAndArgs ...interface{}) bool { + return True(a.t, value, msgAndArgs...) +} + +// WithinDuration asserts that the two times are within duration delta of each other. +// +// a.WithinDuration(time.Now(), time.Now(), 10*time.Second, "The difference should not be more than 10s") +// +// Returns whether the assertion was successful (true) or not (false). +func (a *Assertions) WithinDuration(expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool { + return WithinDuration(a.t, expected, actual, delta, msgAndArgs...) +} + +// Zero asserts that i is the zero value for its type and returns the truth. +func (a *Assertions) Zero(i interface{}, msgAndArgs ...interface{}) bool { + return Zero(a.t, i, msgAndArgs...) +} diff --git a/vendor/github.com/stretchr/testify/assert/assertion_forward.go.tmpl b/vendor/github.com/stretchr/testify/assert/assertion_forward.go.tmpl new file mode 100644 index 0000000..99f9acf --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/assertion_forward.go.tmpl @@ -0,0 +1,4 @@ +{{.CommentWithoutT "a"}} +func (a *Assertions) {{.DocInfo.Name}}({{.Params}}) bool { + return {{.DocInfo.Name}}(a.t, {{.ForwardedParams}}) +} diff --git a/vendor/github.com/stretchr/testify/assert/assertions.go b/vendor/github.com/stretchr/testify/assert/assertions.go new file mode 100644 index 0000000..348d5f1 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/assertions.go @@ -0,0 +1,1007 @@ +package assert + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "math" + "reflect" + "regexp" + "runtime" + "strings" + "time" + "unicode" + "unicode/utf8" + + "github.com/davecgh/go-spew/spew" + "github.com/pmezard/go-difflib/difflib" +) + +// TestingT is an interface wrapper around *testing.T +type TestingT interface { + Errorf(format string, args ...interface{}) +} + +// Comparison a custom function that returns true on success and false on failure +type Comparison func() (success bool) + +/* + Helper functions +*/ + +// ObjectsAreEqual determines if two objects are considered equal. +// +// This function does no assertion of any kind. +func ObjectsAreEqual(expected, actual interface{}) bool { + + if expected == nil || actual == nil { + return expected == actual + } + + return reflect.DeepEqual(expected, actual) + +} + +// ObjectsAreEqualValues gets whether two objects are equal, or if their +// values are equal. +func ObjectsAreEqualValues(expected, actual interface{}) bool { + if ObjectsAreEqual(expected, actual) { + return true + } + + actualType := reflect.TypeOf(actual) + if actualType == nil { + return false + } + expectedValue := reflect.ValueOf(expected) + if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { + // Attempt comparison after type conversion + return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) + } + + return false +} + +/* CallerInfo is necessary because the assert functions use the testing object +internally, causing it to print the file:line of the assert method, rather than where +the problem actually occured in calling code.*/ + +// CallerInfo returns an array of strings containing the file and line number +// of each stack frame leading from the current test to the assert call that +// failed. +func CallerInfo() []string { + + pc := uintptr(0) + file := "" + line := 0 + ok := false + name := "" + + callers := []string{} + for i := 0; ; i++ { + pc, file, line, ok = runtime.Caller(i) + if !ok { + return nil + } + + // This is a huge edge case, but it will panic if this is the case, see #180 + if file == "" { + break + } + + parts := strings.Split(file, "/") + dir := parts[len(parts)-2] + file = parts[len(parts)-1] + if (dir != "assert" && dir != "mock" && dir != "require") || file == "mock_test.go" { + callers = append(callers, fmt.Sprintf("%s:%d", file, line)) + } + + f := runtime.FuncForPC(pc) + if f == nil { + break + } + name = f.Name() + // Drop the package + segments := strings.Split(name, ".") + name = segments[len(segments)-1] + if isTest(name, "Test") || + isTest(name, "Benchmark") || + isTest(name, "Example") { + break + } + } + + return callers +} + +// Stolen from the `go test` tool. +// isTest tells whether name looks like a test (or benchmark, according to prefix). +// It is a Test (say) if there is a character after Test that is not a lower-case letter. +// We don't want TesticularCancer. +func isTest(name, prefix string) bool { + if !strings.HasPrefix(name, prefix) { + return false + } + if len(name) == len(prefix) { // "Test" is ok + return true + } + rune, _ := utf8.DecodeRuneInString(name[len(prefix):]) + return !unicode.IsLower(rune) +} + +// getWhitespaceString returns a string that is long enough to overwrite the default +// output from the go testing framework. +func getWhitespaceString() string { + + _, file, line, ok := runtime.Caller(1) + if !ok { + return "" + } + parts := strings.Split(file, "/") + file = parts[len(parts)-1] + + return strings.Repeat(" ", len(fmt.Sprintf("%s:%d: ", file, line))) + +} + +func messageFromMsgAndArgs(msgAndArgs ...interface{}) string { + if len(msgAndArgs) == 0 || msgAndArgs == nil { + return "" + } + if len(msgAndArgs) == 1 { + return msgAndArgs[0].(string) + } + if len(msgAndArgs) > 1 { + return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) + } + return "" +} + +// Indents all lines of the message by appending a number of tabs to each line, in an output format compatible with Go's +// test printing (see inner comment for specifics) +func indentMessageLines(message string, tabs int) string { + outBuf := new(bytes.Buffer) + + for i, scanner := 0, bufio.NewScanner(strings.NewReader(message)); scanner.Scan(); i++ { + if i != 0 { + outBuf.WriteRune('\n') + } + for ii := 0; ii < tabs; ii++ { + outBuf.WriteRune('\t') + // Bizarrely, all lines except the first need one fewer tabs prepended, so deliberately advance the counter + // by 1 prematurely. + if ii == 0 && i > 0 { + ii++ + } + } + outBuf.WriteString(scanner.Text()) + } + + return outBuf.String() +} + +type failNower interface { + FailNow() +} + +// FailNow fails test +func FailNow(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool { + Fail(t, failureMessage, msgAndArgs...) + + // We cannot extend TestingT with FailNow() and + // maintain backwards compatibility, so we fallback + // to panicking when FailNow is not available in + // TestingT. + // See issue #263 + + if t, ok := t.(failNower); ok { + t.FailNow() + } else { + panic("test failed and t is missing `FailNow()`") + } + return false +} + +// Fail reports a failure through +func Fail(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool { + + message := messageFromMsgAndArgs(msgAndArgs...) + + errorTrace := strings.Join(CallerInfo(), "\n\r\t\t\t") + if len(message) > 0 { + t.Errorf("\r%s\r\tError Trace:\t%s\n"+ + "\r\tError:%s\n"+ + "\r\tMessages:\t%s\n\r", + getWhitespaceString(), + errorTrace, + indentMessageLines(failureMessage, 2), + message) + } else { + t.Errorf("\r%s\r\tError Trace:\t%s\n"+ + "\r\tError:%s\n\r", + getWhitespaceString(), + errorTrace, + indentMessageLines(failureMessage, 2)) + } + + return false +} + +// Implements asserts that an object is implemented by the specified interface. +// +// assert.Implements(t, (*MyInterface)(nil), new(MyObject), "MyObject") +func Implements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool { + + interfaceType := reflect.TypeOf(interfaceObject).Elem() + + if !reflect.TypeOf(object).Implements(interfaceType) { + return Fail(t, fmt.Sprintf("%T must implement %v", object, interfaceType), msgAndArgs...) + } + + return true + +} + +// IsType asserts that the specified objects are of the same type. +func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { + + if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) { + return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...) + } + + return true +} + +// Equal asserts that two objects are equal. +// +// assert.Equal(t, 123, 123, "123 and 123 should be equal") +// +// Returns whether the assertion was successful (true) or not (false). +func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + + if !ObjectsAreEqual(expected, actual) { + diff := diff(expected, actual) + return Fail(t, fmt.Sprintf("Not equal: %#v (expected)\n"+ + " != %#v (actual)%s", expected, actual, diff), msgAndArgs...) + } + + return true + +} + +// EqualValues asserts that two objects are equal or convertable to the same types +// and equal. +// +// assert.EqualValues(t, uint32(123), int32(123), "123 and 123 should be equal") +// +// Returns whether the assertion was successful (true) or not (false). +func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + + if !ObjectsAreEqualValues(expected, actual) { + return Fail(t, fmt.Sprintf("Not equal: %#v (expected)\n"+ + " != %#v (actual)", expected, actual), msgAndArgs...) + } + + return true + +} + +// Exactly asserts that two objects are equal is value and type. +// +// assert.Exactly(t, int32(123), int64(123), "123 and 123 should NOT be equal") +// +// Returns whether the assertion was successful (true) or not (false). +func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + + aType := reflect.TypeOf(expected) + bType := reflect.TypeOf(actual) + + if aType != bType { + return Fail(t, fmt.Sprintf("Types expected to match exactly\n\r\t%v != %v", aType, bType), msgAndArgs...) + } + + return Equal(t, expected, actual, msgAndArgs...) + +} + +// NotNil asserts that the specified object is not nil. +// +// assert.NotNil(t, err, "err should be something") +// +// Returns whether the assertion was successful (true) or not (false). +func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + if !isNil(object) { + return true + } + return Fail(t, "Expected value not to be nil.", msgAndArgs...) +} + +// isNil checks if a specified object is nil or not, without Failing. +func isNil(object interface{}) bool { + if object == nil { + return true + } + + value := reflect.ValueOf(object) + kind := value.Kind() + if kind >= reflect.Chan && kind <= reflect.Slice && value.IsNil() { + return true + } + + return false +} + +// Nil asserts that the specified object is nil. +// +// assert.Nil(t, err, "err should be nothing") +// +// Returns whether the assertion was successful (true) or not (false). +func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + if isNil(object) { + return true + } + return Fail(t, fmt.Sprintf("Expected nil, but got: %#v", object), msgAndArgs...) +} + +var numericZeros = []interface{}{ + int(0), + int8(0), + int16(0), + int32(0), + int64(0), + uint(0), + uint8(0), + uint16(0), + uint32(0), + uint64(0), + float32(0), + float64(0), +} + +// isEmpty gets whether the specified object is considered empty or not. +func isEmpty(object interface{}) bool { + + if object == nil { + return true + } else if object == "" { + return true + } else if object == false { + return true + } + + for _, v := range numericZeros { + if object == v { + return true + } + } + + objValue := reflect.ValueOf(object) + + switch objValue.Kind() { + case reflect.Map: + fallthrough + case reflect.Slice, reflect.Chan: + { + return (objValue.Len() == 0) + } + case reflect.Struct: + switch object.(type) { + case time.Time: + return object.(time.Time).IsZero() + } + case reflect.Ptr: + { + if objValue.IsNil() { + return true + } + switch object.(type) { + case *time.Time: + return object.(*time.Time).IsZero() + default: + return false + } + } + } + return false +} + +// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// assert.Empty(t, obj) +// +// Returns whether the assertion was successful (true) or not (false). +func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + + pass := isEmpty(object) + if !pass { + Fail(t, fmt.Sprintf("Should be empty, but was %v", object), msgAndArgs...) + } + + return pass + +} + +// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either +// a slice or a channel with len == 0. +// +// if assert.NotEmpty(t, obj) { +// assert.Equal(t, "two", obj[1]) +// } +// +// Returns whether the assertion was successful (true) or not (false). +func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + + pass := !isEmpty(object) + if !pass { + Fail(t, fmt.Sprintf("Should NOT be empty, but was %v", object), msgAndArgs...) + } + + return pass + +} + +// getLen try to get length of object. +// return (false, 0) if impossible. +func getLen(x interface{}) (ok bool, length int) { + v := reflect.ValueOf(x) + defer func() { + if e := recover(); e != nil { + ok = false + } + }() + return true, v.Len() +} + +// Len asserts that the specified object has specific length. +// Len also fails if the object has a type that len() not accept. +// +// assert.Len(t, mySlice, 3, "The size of slice is not 3") +// +// Returns whether the assertion was successful (true) or not (false). +func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) bool { + ok, l := getLen(object) + if !ok { + return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", object), msgAndArgs...) + } + + if l != length { + return Fail(t, fmt.Sprintf("\"%s\" should have %d item(s), but has %d", object, length, l), msgAndArgs...) + } + return true +} + +// True asserts that the specified value is true. +// +// assert.True(t, myBool, "myBool should be true") +// +// Returns whether the assertion was successful (true) or not (false). +func True(t TestingT, value bool, msgAndArgs ...interface{}) bool { + + if value != true { + return Fail(t, "Should be true", msgAndArgs...) + } + + return true + +} + +// False asserts that the specified value is false. +// +// assert.False(t, myBool, "myBool should be false") +// +// Returns whether the assertion was successful (true) or not (false). +func False(t TestingT, value bool, msgAndArgs ...interface{}) bool { + + if value != false { + return Fail(t, "Should be false", msgAndArgs...) + } + + return true + +} + +// NotEqual asserts that the specified values are NOT equal. +// +// assert.NotEqual(t, obj1, obj2, "two objects shouldn't be equal") +// +// Returns whether the assertion was successful (true) or not (false). +func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { + + if ObjectsAreEqual(expected, actual) { + return Fail(t, fmt.Sprintf("Should not be: %#v\n", actual), msgAndArgs...) + } + + return true + +} + +// containsElement try loop over the list check if the list includes the element. +// return (false, false) if impossible. +// return (true, false) if element was not found. +// return (true, true) if element was found. +func includeElement(list interface{}, element interface{}) (ok, found bool) { + + listValue := reflect.ValueOf(list) + elementValue := reflect.ValueOf(element) + defer func() { + if e := recover(); e != nil { + ok = false + found = false + } + }() + + if reflect.TypeOf(list).Kind() == reflect.String { + return true, strings.Contains(listValue.String(), elementValue.String()) + } + + if reflect.TypeOf(list).Kind() == reflect.Map { + mapKeys := listValue.MapKeys() + for i := 0; i < len(mapKeys); i++ { + if ObjectsAreEqual(mapKeys[i].Interface(), element) { + return true, true + } + } + return true, false + } + + for i := 0; i < listValue.Len(); i++ { + if ObjectsAreEqual(listValue.Index(i).Interface(), element) { + return true, true + } + } + return true, false + +} + +// Contains asserts that the specified string, list(array, slice...) or map contains the +// specified substring or element. +// +// assert.Contains(t, "Hello World", "World", "But 'Hello World' does contain 'World'") +// assert.Contains(t, ["Hello", "World"], "World", "But ["Hello", "World"] does contain 'World'") +// assert.Contains(t, {"Hello": "World"}, "Hello", "But {'Hello': 'World'} does contain 'Hello'") +// +// Returns whether the assertion was successful (true) or not (false). +func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { + + ok, found := includeElement(s, contains) + if !ok { + return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...) + } + if !found { + return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", s, contains), msgAndArgs...) + } + + return true + +} + +// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the +// specified substring or element. +// +// assert.NotContains(t, "Hello World", "Earth", "But 'Hello World' does NOT contain 'Earth'") +// assert.NotContains(t, ["Hello", "World"], "Earth", "But ['Hello', 'World'] does NOT contain 'Earth'") +// assert.NotContains(t, {"Hello": "World"}, "Earth", "But {'Hello': 'World'} does NOT contain 'Earth'") +// +// Returns whether the assertion was successful (true) or not (false). +func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { + + ok, found := includeElement(s, contains) + if !ok { + return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...) + } + if found { + return Fail(t, fmt.Sprintf("\"%s\" should not contain \"%s\"", s, contains), msgAndArgs...) + } + + return true + +} + +// Condition uses a Comparison to assert a complex condition. +func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool { + result := comp() + if !result { + Fail(t, "Condition failed!", msgAndArgs...) + } + return result +} + +// PanicTestFunc defines a func that should be passed to the assert.Panics and assert.NotPanics +// methods, and represents a simple func that takes no arguments, and returns nothing. +type PanicTestFunc func() + +// didPanic returns true if the function passed to it panics. Otherwise, it returns false. +func didPanic(f PanicTestFunc) (bool, interface{}) { + + didPanic := false + var message interface{} + func() { + + defer func() { + if message = recover(); message != nil { + didPanic = true + } + }() + + // call the target function + f() + + }() + + return didPanic, message + +} + +// Panics asserts that the code inside the specified PanicTestFunc panics. +// +// assert.Panics(t, func(){ +// GoCrazy() +// }, "Calling GoCrazy() should panic") +// +// Returns whether the assertion was successful (true) or not (false). +func Panics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { + + if funcDidPanic, panicValue := didPanic(f); !funcDidPanic { + return Fail(t, fmt.Sprintf("func %#v should panic\n\r\tPanic value:\t%v", f, panicValue), msgAndArgs...) + } + + return true +} + +// NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic. +// +// assert.NotPanics(t, func(){ +// RemainCalm() +// }, "Calling RemainCalm() should NOT panic") +// +// Returns whether the assertion was successful (true) or not (false). +func NotPanics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { + + if funcDidPanic, panicValue := didPanic(f); funcDidPanic { + return Fail(t, fmt.Sprintf("func %#v should not panic\n\r\tPanic value:\t%v", f, panicValue), msgAndArgs...) + } + + return true +} + +// WithinDuration asserts that the two times are within duration delta of each other. +// +// assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second, "The difference should not be more than 10s") +// +// Returns whether the assertion was successful (true) or not (false). +func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool { + + dt := expected.Sub(actual) + if dt < -delta || dt > delta { + return Fail(t, fmt.Sprintf("Max difference between %v and %v allowed is %v, but difference was %v", expected, actual, delta, dt), msgAndArgs...) + } + + return true +} + +func toFloat(x interface{}) (float64, bool) { + var xf float64 + xok := true + + switch xn := x.(type) { + case uint8: + xf = float64(xn) + case uint16: + xf = float64(xn) + case uint32: + xf = float64(xn) + case uint64: + xf = float64(xn) + case int: + xf = float64(xn) + case int8: + xf = float64(xn) + case int16: + xf = float64(xn) + case int32: + xf = float64(xn) + case int64: + xf = float64(xn) + case float32: + xf = float64(xn) + case float64: + xf = float64(xn) + default: + xok = false + } + + return xf, xok +} + +// InDelta asserts that the two numerals are within delta of each other. +// +// assert.InDelta(t, math.Pi, (22 / 7.0), 0.01) +// +// Returns whether the assertion was successful (true) or not (false). +func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + + af, aok := toFloat(expected) + bf, bok := toFloat(actual) + + if !aok || !bok { + return Fail(t, fmt.Sprintf("Parameters must be numerical"), msgAndArgs...) + } + + if math.IsNaN(af) { + return Fail(t, fmt.Sprintf("Actual must not be NaN"), msgAndArgs...) + } + + if math.IsNaN(bf) { + return Fail(t, fmt.Sprintf("Expected %v with delta %v, but was NaN", expected, delta), msgAndArgs...) + } + + dt := af - bf + if dt < -delta || dt > delta { + return Fail(t, fmt.Sprintf("Max difference between %v and %v allowed is %v, but difference was %v", expected, actual, delta, dt), msgAndArgs...) + } + + return true +} + +// InDeltaSlice is the same as InDelta, except it compares two slices. +func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { + if expected == nil || actual == nil || + reflect.TypeOf(actual).Kind() != reflect.Slice || + reflect.TypeOf(expected).Kind() != reflect.Slice { + return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...) + } + + actualSlice := reflect.ValueOf(actual) + expectedSlice := reflect.ValueOf(expected) + + for i := 0; i < actualSlice.Len(); i++ { + result := InDelta(t, actualSlice.Index(i).Interface(), expectedSlice.Index(i).Interface(), delta) + if !result { + return result + } + } + + return true +} + +func calcRelativeError(expected, actual interface{}) (float64, error) { + af, aok := toFloat(expected) + if !aok { + return 0, fmt.Errorf("expected value %q cannot be converted to float", expected) + } + if af == 0 { + return 0, fmt.Errorf("expected value must have a value other than zero to calculate the relative error") + } + bf, bok := toFloat(actual) + if !bok { + return 0, fmt.Errorf("expected value %q cannot be converted to float", actual) + } + + return math.Abs(af-bf) / math.Abs(af), nil +} + +// InEpsilon asserts that expected and actual have a relative error less than epsilon +// +// Returns whether the assertion was successful (true) or not (false). +func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { + actualEpsilon, err := calcRelativeError(expected, actual) + if err != nil { + return Fail(t, err.Error(), msgAndArgs...) + } + if actualEpsilon > epsilon { + return Fail(t, fmt.Sprintf("Relative error is too high: %#v (expected)\n"+ + " < %#v (actual)", actualEpsilon, epsilon), msgAndArgs...) + } + + return true +} + +// InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices. +func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { + if expected == nil || actual == nil || + reflect.TypeOf(actual).Kind() != reflect.Slice || + reflect.TypeOf(expected).Kind() != reflect.Slice { + return Fail(t, fmt.Sprintf("Parameters must be slice"), msgAndArgs...) + } + + actualSlice := reflect.ValueOf(actual) + expectedSlice := reflect.ValueOf(expected) + + for i := 0; i < actualSlice.Len(); i++ { + result := InEpsilon(t, actualSlice.Index(i).Interface(), expectedSlice.Index(i).Interface(), epsilon) + if !result { + return result + } + } + + return true +} + +/* + Errors +*/ + +// NoError asserts that a function returned no error (i.e. `nil`). +// +// actualObj, err := SomeFunction() +// if assert.NoError(t, err) { +// assert.Equal(t, actualObj, expectedObj) +// } +// +// Returns whether the assertion was successful (true) or not (false). +func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool { + if err != nil { + return Fail(t, fmt.Sprintf("Received unexpected error %q", err), msgAndArgs...) + } + + return true +} + +// Error asserts that a function returned an error (i.e. not `nil`). +// +// actualObj, err := SomeFunction() +// if assert.Error(t, err, "An error was expected") { +// assert.Equal(t, err, expectedError) +// } +// +// Returns whether the assertion was successful (true) or not (false). +func Error(t TestingT, err error, msgAndArgs ...interface{}) bool { + + message := messageFromMsgAndArgs(msgAndArgs...) + if err == nil { + return Fail(t, "An error is expected but got nil. %s", message) + } + + return true +} + +// EqualError asserts that a function returned an error (i.e. not `nil`) +// and that it is equal to the provided error. +// +// actualObj, err := SomeFunction() +// if assert.Error(t, err, "An error was expected") { +// assert.Equal(t, err, expectedError) +// } +// +// Returns whether the assertion was successful (true) or not (false). +func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) bool { + + message := messageFromMsgAndArgs(msgAndArgs...) + if !NotNil(t, theError, "An error is expected but got nil. %s", message) { + return false + } + s := "An error with value \"%s\" is expected but got \"%s\". %s" + return Equal(t, errString, theError.Error(), + s, errString, theError.Error(), message) +} + +// matchRegexp return true if a specified regexp matches a string. +func matchRegexp(rx interface{}, str interface{}) bool { + + var r *regexp.Regexp + if rr, ok := rx.(*regexp.Regexp); ok { + r = rr + } else { + r = regexp.MustCompile(fmt.Sprint(rx)) + } + + return (r.FindStringIndex(fmt.Sprint(str)) != nil) + +} + +// Regexp asserts that a specified regexp matches a string. +// +// assert.Regexp(t, regexp.MustCompile("start"), "it's starting") +// assert.Regexp(t, "start...$", "it's not starting") +// +// Returns whether the assertion was successful (true) or not (false). +func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { + + match := matchRegexp(rx, str) + + if !match { + Fail(t, fmt.Sprintf("Expect \"%v\" to match \"%v\"", str, rx), msgAndArgs...) + } + + return match +} + +// NotRegexp asserts that a specified regexp does not match a string. +// +// assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting") +// assert.NotRegexp(t, "^start", "it's not starting") +// +// Returns whether the assertion was successful (true) or not (false). +func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { + match := matchRegexp(rx, str) + + if match { + Fail(t, fmt.Sprintf("Expect \"%v\" to NOT match \"%v\"", str, rx), msgAndArgs...) + } + + return !match + +} + +// Zero asserts that i is the zero value for its type and returns the truth. +func Zero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool { + if i != nil && !reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { + return Fail(t, fmt.Sprintf("Should be zero, but was %v", i), msgAndArgs...) + } + return true +} + +// NotZero asserts that i is not the zero value for its type and returns the truth. +func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool { + if i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { + return Fail(t, fmt.Sprintf("Should not be zero, but was %v", i), msgAndArgs...) + } + return true +} + +// JSONEq asserts that two JSON strings are equivalent. +// +// assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) +// +// Returns whether the assertion was successful (true) or not (false). +func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) bool { + var expectedJSONAsInterface, actualJSONAsInterface interface{} + + if err := json.Unmarshal([]byte(expected), &expectedJSONAsInterface); err != nil { + return Fail(t, fmt.Sprintf("Expected value ('%s') is not valid json.\nJSON parsing error: '%s'", expected, err.Error()), msgAndArgs...) + } + + if err := json.Unmarshal([]byte(actual), &actualJSONAsInterface); err != nil { + return Fail(t, fmt.Sprintf("Input ('%s') needs to be valid json.\nJSON parsing error: '%s'", actual, err.Error()), msgAndArgs...) + } + + return Equal(t, expectedJSONAsInterface, actualJSONAsInterface, msgAndArgs...) +} + +func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) { + t := reflect.TypeOf(v) + k := t.Kind() + + if k == reflect.Ptr { + t = t.Elem() + k = t.Kind() + } + return t, k +} + +// diff returns a diff of both values as long as both are of the same type and +// are a struct, map, slice or array. Otherwise it returns an empty string. +func diff(expected interface{}, actual interface{}) string { + if expected == nil || actual == nil { + return "" + } + + et, ek := typeAndKind(expected) + at, _ := typeAndKind(actual) + + if et != at { + return "" + } + + if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array { + return "" + } + + spew.Config.SortKeys = true + e := spew.Sdump(expected) + a := spew.Sdump(actual) + + diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ + A: difflib.SplitLines(e), + B: difflib.SplitLines(a), + FromFile: "Expected", + FromDate: "", + ToFile: "Actual", + ToDate: "", + Context: 1, + }) + + return "\n\nDiff:\n" + diff +} diff --git a/vendor/github.com/stretchr/testify/assert/doc.go b/vendor/github.com/stretchr/testify/assert/doc.go new file mode 100644 index 0000000..c9dccc4 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/doc.go @@ -0,0 +1,45 @@ +// Package assert provides a set of comprehensive testing tools for use with the normal Go testing system. +// +// Example Usage +// +// The following is a complete example using assert in a standard test function: +// import ( +// "testing" +// "github.com/stretchr/testify/assert" +// ) +// +// func TestSomething(t *testing.T) { +// +// var a string = "Hello" +// var b string = "Hello" +// +// assert.Equal(t, a, b, "The two words should be the same.") +// +// } +// +// if you assert many times, use the format below: +// +// import ( +// "testing" +// "github.com/stretchr/testify/assert" +// ) +// +// func TestSomething(t *testing.T) { +// assert := assert.New(t) +// +// var a string = "Hello" +// var b string = "Hello" +// +// assert.Equal(a, b, "The two words should be the same.") +// } +// +// Assertions +// +// Assertions allow you to easily write test code, and are global funcs in the `assert` package. +// All assertion functions take, as the first argument, the `*testing.T` object provided by the +// testing framework. This allows the assertion funcs to write the failings and other details to +// the correct place. +// +// Every assertion function also takes an optional string message as the final argument, +// allowing custom error messages to be appended to the message the assertion method outputs. +package assert diff --git a/vendor/github.com/stretchr/testify/assert/errors.go b/vendor/github.com/stretchr/testify/assert/errors.go new file mode 100644 index 0000000..ac9dc9d --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/errors.go @@ -0,0 +1,10 @@ +package assert + +import ( + "errors" +) + +// AnError is an error instance useful for testing. If the code does not care +// about error specifics, and only needs to return the error for example, this +// error should be used to make the test code more readable. +var AnError = errors.New("assert.AnError general error for testing") diff --git a/vendor/github.com/stretchr/testify/assert/forward_assertions.go b/vendor/github.com/stretchr/testify/assert/forward_assertions.go new file mode 100644 index 0000000..b867e95 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/forward_assertions.go @@ -0,0 +1,16 @@ +package assert + +// Assertions provides assertion methods around the +// TestingT interface. +type Assertions struct { + t TestingT +} + +// New makes a new Assertions object for the specified TestingT. +func New(t TestingT) *Assertions { + return &Assertions{ + t: t, + } +} + +//go:generate go run ../_codegen/main.go -output-package=assert -template=assertion_forward.go.tmpl diff --git a/vendor/github.com/stretchr/testify/assert/http_assertions.go b/vendor/github.com/stretchr/testify/assert/http_assertions.go new file mode 100644 index 0000000..e1b9442 --- /dev/null +++ b/vendor/github.com/stretchr/testify/assert/http_assertions.go @@ -0,0 +1,106 @@ +package assert + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" +) + +// httpCode is a helper that returns HTTP code of the response. It returns -1 +// if building a new request fails. +func httpCode(handler http.HandlerFunc, method, url string, values url.Values) int { + w := httptest.NewRecorder() + req, err := http.NewRequest(method, url+"?"+values.Encode(), nil) + if err != nil { + return -1 + } + handler(w, req) + return w.Code +} + +// HTTPSuccess asserts that a specified handler returns a success status code. +// +// assert.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil) +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, values url.Values) bool { + code := httpCode(handler, method, url, values) + if code == -1 { + return false + } + return code >= http.StatusOK && code <= http.StatusPartialContent +} + +// HTTPRedirect asserts that a specified handler returns a redirect status code. +// +// assert.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, values url.Values) bool { + code := httpCode(handler, method, url, values) + if code == -1 { + return false + } + return code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect +} + +// HTTPError asserts that a specified handler returns an error status code. +// +// assert.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}} +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values url.Values) bool { + code := httpCode(handler, method, url, values) + if code == -1 { + return false + } + return code >= http.StatusBadRequest +} + +// HTTPBody is a helper that returns HTTP body of the response. It returns +// empty string if building a new request fails. +func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string { + w := httptest.NewRecorder() + req, err := http.NewRequest(method, url+"?"+values.Encode(), nil) + if err != nil { + return "" + } + handler(w, req) + return w.Body.String() +} + +// HTTPBodyContains asserts that a specified handler returns a +// body that contains a string. +// +// assert.HTTPBodyContains(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky") +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}) bool { + body := HTTPBody(handler, method, url, values) + + contains := strings.Contains(body, fmt.Sprint(str)) + if !contains { + Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body)) + } + + return contains +} + +// HTTPBodyNotContains asserts that a specified handler returns a +// body that does not contain a string. +// +// assert.HTTPBodyNotContains(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky") +// +// Returns whether the assertion was successful (true) or not (false). +func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}) bool { + body := HTTPBody(handler, method, url, values) + + contains := strings.Contains(body, fmt.Sprint(str)) + if contains { + Fail(t, "Expected response body for %s to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body) + } + + return !contains +} diff --git a/vendor/vendor.json b/vendor/vendor.json index f8edbfe..074e58f 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -13,6 +13,18 @@ "path": "github.com/pkg/errors", "revision": "a2d6902c6d2a2f194eb3fb474981ab7867c81505", "revisionTime": "2016-06-27T22:23:52Z" + }, + { + "checksumSHA1": "dF3fORwN1HTgrlrdmll9K2cOjOg=", + "path": "github.com/samuel/go-zookeeper/zk", + "revision": "e64db453f3512cade908163702045e0f31137843", + "revisionTime": "2016-06-16T02:49:54Z" + }, + { + "checksumSHA1": "iydUphwYqZRq3WhstEdGsbvBAKs=", + "path": "github.com/stretchr/testify/assert", + "revision": "d77da356e56a7428ad25149ca77381849a6a5232", + "revisionTime": "2016-06-15T09:26:46Z" } ], "rootPath": "github.com/rdelval/gorealis" diff --git a/zk.go b/zk.go new file mode 100644 index 0000000..4fd09bb --- /dev/null +++ b/zk.go @@ -0,0 +1,88 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package realis + +import ( + "encoding/json" + "fmt" + "github.com/pkg/errors" + "github.com/samuel/go-zookeeper/zk" + "strconv" + "strings" + "time" +) + +type Endpoint struct { + Host string `json:"host"` + Port int `json:"port"` +} + +type ServiceInstance struct { + Service Endpoint `json:"serviceEndpoint"` + AdditionalEndpoints map[string]Endpoint `json:"additionalEndpoints"` + Status string `json:"status"` +} + +// Loads leader from ZK endpoint. +func LeaderFromZK(cluster Cluster) (string, error) { + + endpoints := strings.Split(cluster.ZK, ",") + //TODO (rdelvalle): When enabling debugging, change logger here + c, _, err := zk.Connect(endpoints, time.Second*10, zk.WithoutLogger()) + defer c.Close() + if err != nil { + return "", errors.Wrap(err, "Failed to connect to Zookeeper at "+cluster.ZK) + } + + children, _, _, err := c.ChildrenW(cluster.SchedZKPath) + if err != nil { + return "", errors.Wrapf(err, "Path %s doesn't exist on Zookeeper ", cluster.SchedZKPath) + } + + serviceInst := new(ServiceInstance) + + for _, child := range children { + + // Only the leader will start with member_ + if strings.HasPrefix(child, "member_") { + + data, _, err := c.Get(cluster.SchedZKPath + "/" + child) + if err != nil { + return "", errors.Wrap(err, "Error fetching contents of leader") + } + + err = json.Unmarshal([]byte(data), serviceInst) + if err != nil { + return "", errors.Wrap(err, "Unable to unmarshall contents of leader") + } + + // Should only be one endpoint + if len(serviceInst.AdditionalEndpoints) > 1 { + fmt.Errorf("Ambiguous end points schemes") + } + + var scheme, host, port string + for k, v := range serviceInst.AdditionalEndpoints { + scheme = k + host = v.Host + port = strconv.Itoa(v.Port) + } + + return scheme + "://" + host + ":" + port, nil + } + } + + return "", errors.New("No leader found") +}