package zk

import (
	"fmt"
	"log"
	"testing"
	"time"
)

// localhostLookupHost stands in for net.LookupHost in tests. Whatever host
// name it is asked to resolve, it answers with the single IPv4 loopback
// address and never fails.
func localhostLookupHost(host string) ([]string, error) {
	addrs := []string{"127.0.0.1"}
	return addrs, nil
}

// TestDNSHostProviderCreate mirrors TestCreate, except that the client is
// configured with a DNSHostProvider whose resolver ignores the host name it
// is given and always answers with the loopback address.
func TestDNSHostProviderCreate(t *testing.T) {
	ts, err := StartTestCluster(1, nil, logWriter{t: t, p: "[ZKERR] "})
	if err != nil {
		t.Fatal(err)
	}
	defer ts.Stop()

	// The host name is deliberately bogus: only the port is real, and the
	// stubbed lookup turns any name into 127.0.0.1.
	addr := fmt.Sprintf("foo.example.com:%d", ts.Servers[0].Port)
	provider := &DNSHostProvider{lookupHost: localhostLookupHost}
	conn, _, err := Connect([]string{addr}, 15*time.Second, WithHostProvider(provider))
	if err != nil {
		t.Fatalf("Connect returned error: %+v", err)
	}
	defer conn.Close()

	const path = "/gozk-test"

	// Clear out any node left behind by an earlier run.
	err = conn.Delete(path, -1)
	if err != nil && err != ErrNoNode {
		t.Fatalf("Delete returned error: %+v", err)
	}

	created, err := conn.Create(path, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll))
	if err != nil {
		t.Fatalf("Create returned error: %+v", err)
	}
	if created != path {
		t.Fatalf("Create returned different path '%s' != '%s'", created, path)
	}

	data, stat, err := conn.Get(path)
	switch {
	case err != nil:
		t.Fatalf("Get returned error: %+v", err)
	case stat == nil:
		t.Fatal("Get returned nil stat")
	case len(data) < 4:
		t.Fatal("Get returned wrong size data")
	}
}

// localHostPortsFacade decorates another HostProvider so that every address
// it hands out is rewritten to "localhost:$PORT". Ports are taken from the
// supplied list in order, one per distinct address. Each assignment is
// remembered, so a given address always maps to the same local port.
type localHostPortsFacade struct {
	inner    HostProvider      // provider being decorated
	ports    []int             // local ports available for assignment
	nextPort int               // index into ports of the next unused port
	mapped   map[string]string // address -> "localhost:$PORT" assignments made so far
}

// newLocalHostPortsFacade returns a facade over inner that assigns the
// addresses inner produces to the given ports, first come first served.
func newLocalHostPortsFacade(inner HostProvider, ports []int) *localHostPortsFacade {
	f := &localHostPortsFacade{inner: inner, ports: ports}
	f.mapped = map[string]string{}
	return f
}

// Len forwards to the wrapped provider.
func (f *localHostPortsFacade) Len() int { return f.inner.Len() }

// Connected forwards to the wrapped provider.
func (f *localHostPortsFacade) Connected() { f.inner.Connected() }

// Init forwards to the wrapped provider.
func (f *localHostPortsFacade) Init(servers []string) error {
	return f.inner.Init(servers)
}

// Next asks the wrapped provider for its next address and translates it to
// a localhost address. A fresh port is assigned the first time an address
// is seen. retryStart is passed through unchanged. If an unseen address
// arrives after every port has been handed out, it logs the current
// mapping and exits the process via log.Fatalf.
func (f *localHostPortsFacade) Next() (string, bool) {
	addr, retryStart := f.inner.Next()

	if local, ok := f.mapped[addr]; ok {
		return local, retryStart
	}

	if f.nextPort >= len(f.ports) {
		log.Fatalf("localHostPortsFacade out of ports to assign to %q; current config: %q", addr, f.mapped)
	}

	local := fmt.Sprintf("localhost:%d", f.ports[f.nextPort])
	f.nextPort++
	f.mapped[addr] = local
	return local, retryStart
}

// Compile-time check that the facade satisfies HostProvider.
var _ HostProvider = (*localHostPortsFacade)(nil)

// TestDNSHostProviderReconnect tests that the zk.Conn correctly reconnects
// when the ZooKeeper instance it is connected to restarts. The
// DNSHostProvider's resolver is stubbed to return three documentation-range
// addresses (192.0.2.0/24). The provider is wrapped in a localHostPortsFacade
// that remaps each of those addresses to localhost:$PORT, where $PORT belongs
// to one of the test cluster's servers.
func TestDNSHostProviderReconnect(t *testing.T) {
	ts, err := StartTestCluster(3, nil, logWriter{t: t, p: "[ZKERR] "})
	if err != nil {
		t.Fatal(err)
	}
	defer ts.Stop()

	innerHp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) {
		return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil
	}}
	ports := make([]int, 0, len(ts.Servers))
	for _, server := range ts.Servers {
		ports = append(ports, server.Port)
	}
	hp := newLocalHostPortsFacade(innerHp, ports)

	zk, _, err := Connect([]string{"foo.example.com:12345"}, time.Second, WithHostProvider(hp))
	if err != nil {
		t.Fatalf("Connect returned error: %+v", err)
	}
	defer zk.Close()

	path := "/gozk-test"

	// Initial operation to force connection.
	if err := zk.Delete(path, -1); err != nil && err != ErrNoNode {
		t.Fatalf("Delete returned error: %+v", err)
	}

	// Figure out which test server the connection landed on by matching the
	// connection's reported server against each remapped localhost address.
	currentServer := zk.Server()
	t.Logf("Connected to %q. Finding test server index…", currentServer)
	serverIndex := -1
	for i, srv := range ts.Servers {
		addr := fmt.Sprintf("localhost:%d", srv.Port)
		t.Logf("…trying %q", addr)
		if currentServer == addr {
			serverIndex = i
			t.Logf("…found at index %d", i)
			break
		}
	}
	if serverIndex == -1 {
		t.Fatal("Cannot determine test server index.")
	}

	// Restart the connected server. Stop's error is ignored, as before. A
	// failed Start, however, must abort the test. Otherwise the operations
	// below could still succeed by failing over to another cluster member,
	// and the test would pass without the server ever coming back.
	ts.Servers[serverIndex].Srv.Stop()
	if err := ts.Servers[serverIndex].Srv.Start(); err != nil {
		t.Fatalf("restarting test server %d failed: %+v", serverIndex, err)
	}

	// Continue with the basic TestCreate tests.
	if p, err := zk.Create(path, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != nil {
		t.Fatalf("Create returned error: %+v", err)
	} else if p != path {
		t.Fatalf("Create returned different path '%s' != '%s'", p, path)
	}
	if data, stat, err := zk.Get(path); err != nil {
		t.Fatalf("Get returned error: %+v", err)
	} else if stat == nil {
		t.Fatal("Get returned nil stat")
	} else if len(data) < 4 {
		t.Fatal("Get returned wrong size data")
	}

	// The restart should have forced the client onto a different server.
	if zk.Server() == currentServer {
		t.Errorf("Still connected to %q after restart.", currentServer)
	}
}

// TestDNSHostProviderRetryStart exercises the retryStart flag that
// DNSHostProvider.Next reports. The table below doubles as a step-by-step
// illustration of when that flag is raised: each row is one call to Next,
// optionally followed by a report of a successful connection.
func TestDNSHostProviderRetryStart(t *testing.T) {
	t.Parallel()

	// The stubbed resolver returns the same three addresses for any host.
	hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) {
		return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil
	}}
	if err := hp.Init([]string{"foo.example.com:12345"}); err != nil {
		t.Fatal(err)
	}

	// wantRetryStart is the flag Next should return at this step. connect
	// says whether the test then calls Connected to simulate a success.
	type step struct {
		wantRetryStart bool
		connect        bool
	}
	steps := []step{
		// Repeated failures.
		{wantRetryStart: false, connect: false},
		{wantRetryStart: false, connect: false},
		{wantRetryStart: false, connect: false},
		{wantRetryStart: true, connect: false},
		{wantRetryStart: false, connect: false},
		{wantRetryStart: false, connect: false},
		{wantRetryStart: true, connect: true},

		// One success offsets things.
		{wantRetryStart: false, connect: false},
		{wantRetryStart: false, connect: true},
		{wantRetryStart: false, connect: true},

		// Repeated successes.
		{wantRetryStart: false, connect: true},
		{wantRetryStart: false, connect: true},
		{wantRetryStart: false, connect: true},
		{wantRetryStart: false, connect: true},
		{wantRetryStart: false, connect: true},

		// And some more failures.
		{wantRetryStart: false, connect: false},
		{wantRetryStart: false, connect: false},
		{wantRetryStart: true, connect: false}, // Looped back to last known good server: all alternates failed.
		{wantRetryStart: false, connect: false},
	}

	for i := range steps {
		_, gotRetryStart := hp.Next()
		if want := steps[i].wantRetryStart; gotRetryStart != want {
			t.Errorf("%d: retryStart=%v; want %v", i, gotRetryStart, want)
		}
		if steps[i].connect {
			hp.Connected()
		}
	}
}