diff --git a/internal/dcs/config.go b/internal/dcs/config.go index 5047e00..b199723 100644 --- a/internal/dcs/config.go +++ b/internal/dcs/config.go @@ -9,27 +9,38 @@ import ( // ZookeeperConfig contains Zookeeper connection info type ZookeeperConfig struct { - Hostname string `config:"hostname" yaml:"hostname"` - SessionTimeout time.Duration `config:"session_timeout" yaml:"session_timeout"` - Namespace string `config:"namespace,required"` - Hosts []string `config:"hosts,required"` - BackoffInterval time.Duration `config:"backoff_interval" yaml:"backoff_interval"` - BackoffRandFactor float64 `config:"backoff_rand_factor" yaml:"backoff_rand_factor"` - BackoffMultiplier float64 `config:"backoff_multiplier" yaml:"backoff_multiplier"` - BackoffMaxInterval time.Duration `config:"backoff_max_interval" yaml:"backoff_max_interval"` - BackoffMaxElapsedTime time.Duration `config:"backoff_max_elapsed_time" yaml:"backoff_max_elapsed_time"` - BackoffMaxRetries uint64 `config:"backoff_max_retries" yaml:"backoff_max_retries"` - Auth bool `config:"auth" yaml:"auth"` - Username string `config:"username" yaml:"username"` - Password string `config:"password" yaml:"password"` - UseSSL bool `config:"use_ssl" yaml:"use_ssl"` - KeyFile string `config:"keyfile" yaml:"keyfile"` - CertFile string `config:"certfile" yaml:"certfile"` - CACert string `config:"ca_cert" yaml:"ca_cert"` - VerifyCerts bool `config:"verify_certs" yaml:"verify_certs"` + Hostname string `config:"hostname" yaml:"hostname"` + SessionTimeout time.Duration `config:"session_timeout" yaml:"session_timeout"` + Namespace string `config:"namespace,required"` + Hosts []string `config:"hosts,required"` + BackoffInterval time.Duration `config:"backoff_interval" yaml:"backoff_interval"` + BackoffRandFactor float64 `config:"backoff_rand_factor" yaml:"backoff_rand_factor"` + BackoffMultiplier float64 `config:"backoff_multiplier" yaml:"backoff_multiplier"` + BackoffMaxInterval time.Duration `config:"backoff_max_interval" yaml:"backoff_max_interval"` + BackoffMaxElapsedTime time.Duration `config:"backoff_max_elapsed_time" yaml:"backoff_max_elapsed_time"` + BackoffMaxRetries uint64 `config:"backoff_max_retries" yaml:"backoff_max_retries"` + RandomHostProvider RandomHostProviderConfig `config:"random_host_provider" yaml:"random_host_provider"` + Auth bool `config:"auth" yaml:"auth"` + Username string `config:"username" yaml:"username"` + Password string `config:"password" yaml:"password"` + UseSSL bool `config:"use_ssl" yaml:"use_ssl"` + KeyFile string `config:"keyfile" yaml:"keyfile"` + CertFile string `config:"certfile" yaml:"certfile"` + CACert string `config:"ca_cert" yaml:"ca_cert"` + VerifyCerts bool `config:"verify_certs" yaml:"verify_certs"` } -// DefaultZookeeperConfig return default Zookeeper connection configuration +type RandomHostProviderConfig struct { + LookupTTL time.Duration `config:"lookup_ttl" yaml:"lookup_ttl"` +} + +func DefaultRandomHostProviderConfig() RandomHostProviderConfig { + return RandomHostProviderConfig{ + LookupTTL: 300 * time.Second, + } +} + +// DefaultZookeeperConfig returns default Zookeeper connection configuration func DefaultZookeeperConfig() (ZookeeperConfig, error) { hostname, err := os.Hostname() if err != nil { @@ -44,6 +55,7 @@ func DefaultZookeeperConfig() (ZookeeperConfig, error) { BackoffMaxInterval: backoff.DefaultMaxInterval, BackoffMaxElapsedTime: backoff.DefaultMaxElapsedTime, BackoffMaxRetries: 10, + RandomHostProvider: DefaultRandomHostProviderConfig(), } return config, nil } diff --git a/internal/dcs/zk.go b/internal/dcs/zk.go index 9139c95..00c9419 100644 --- a/internal/dcs/zk.go +++ b/internal/dcs/zk.go @@ -71,6 +71,9 @@ func NewZookeeper(config *ZookeeperConfig, logger *slog.Logger) (DCS, error) { proxyLogger := logger.With("module", "dcs") var operation func() error + + hostProvider := NewRandomHostProvider(&config.RandomHostProvider, logger) + if config.UseSSL { if config.CACert == "" || config.KeyFile == "" || config.CertFile == "" { return nil, fmt.Errorf("zookeeper ssl not configured, fill ca_cert/key_file/cert_file in config or disable use_ssl flag") @@ -85,12 +88,12 @@ func NewZookeeper(config *ZookeeperConfig, logger *slog.Logger) (DCS, error) { return nil, err } operation = func() error { - conn, ec, err = zk.Connect(config.Hosts, config.SessionTimeout, zk.WithLogger(zkLoggerProxy{proxyLogger}), zk.WithDialer(dialer)) + conn, ec, err = zk.Connect(config.Hosts, config.SessionTimeout, zk.WithLogger(zkLoggerProxy{proxyLogger}), zk.WithDialer(dialer), zk.WithHostProvider(hostProvider)) return err } } else { operation = func() error { - conn, ec, err = zk.Connect(config.Hosts, config.SessionTimeout, zk.WithLogger(zkLoggerProxy{proxyLogger})) + conn, ec, err = zk.Connect(config.Hosts, config.SessionTimeout, zk.WithLogger(zkLoggerProxy{proxyLogger}), zk.WithHostProvider(hostProvider)) return err } } diff --git a/internal/dcs/zk_host_provider.go b/internal/dcs/zk_host_provider.go new file mode 100644 index 0000000..4eb3cb4 --- /dev/null +++ b/internal/dcs/zk_host_provider.go @@ -0,0 +1,122 @@ +package dcs + +import ( + "fmt" + "log/slog" + "math/rand" + "net" + "sync" + "time" +) + +type RandomHostProvider struct { + lock sync.Mutex + servers []string + resolved []string + tried map[string]struct{} + logger *slog.Logger + lastLookup time.Time + lookupTTL time.Duration +} + +func NewRandomHostProvider(config *RandomHostProviderConfig, logger *slog.Logger) *RandomHostProvider { + return &RandomHostProvider{ + lookupTTL: config.LookupTTL, + logger: logger, + tried: make(map[string]struct{}), + } +} + +func (rhp *RandomHostProvider) Init(servers []string) error { + rhp.lock.Lock() + defer rhp.lock.Unlock() + + rhp.servers = servers + + err := rhp.resolveHosts() + + if err != nil { + return fmt.Errorf("failed to init zk host provider %v", err) + } + + return nil +} + +func (rhp *RandomHostProvider) resolveHosts() error { + resolved := []string{} + for _, server := range rhp.servers { + host, port, err := net.SplitHostPort(server) + if err != nil { + return err + } + addrs, err := net.LookupHost(host) + if err != nil { + rhp.logger.Error(fmt.Sprintf("unable to resolve %s", host), "error", err) + } + for _, addr := range addrs { + resolved = append(resolved, net.JoinHostPort(addr, port)) + } + } + + if len(resolved) == 0 { + return fmt.Errorf("no hosts resolved for %q", rhp.servers) + } + + rhp.lastLookup = time.Now() + rhp.resolved = resolved + + rand.Shuffle(len(rhp.resolved), func(i, j int) { rhp.resolved[i], rhp.resolved[j] = rhp.resolved[j], rhp.resolved[i] }) + + return nil +} + +func (rhp *RandomHostProvider) Len() int { + rhp.lock.Lock() + defer rhp.lock.Unlock() + return len(rhp.resolved) +} + +func (rhp *RandomHostProvider) Next() (server string, retryStart bool) { + rhp.lock.Lock() + defer rhp.lock.Unlock() + lastTime := time.Since(rhp.lastLookup) + needRetry := false + if lastTime > rhp.lookupTTL { + err := rhp.resolveHosts() + if err != nil { + rhp.logger.Error("resolve zk hosts failed", "error", err) + } + } + + notTried := []string{} + + for _, addr := range rhp.resolved { + if _, ok := rhp.tried[addr]; !ok { + notTried = append(notTried, addr) + } + } + + var selected string + + if len(notTried) == 0 { + needRetry = true + for k := range rhp.tried { + delete(rhp.tried, k) + } + selected = rhp.resolved[rand.Intn(len(rhp.resolved))] + } else { + selected = notTried[rand.Intn(len(notTried))] + } + + rhp.tried[selected] = struct{}{} + + return selected, needRetry +} + +func (rhp *RandomHostProvider) Connected() { + rhp.lock.Lock() + defer rhp.lock.Unlock() + for k := range rhp.tried { + delete(rhp.tried, k) + } +}