diff --git a/internal/config/config.go b/internal/config/config.go
index b87ab8e4..46a49aa4 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -72,9 +72,11 @@ type Static struct {
 }
 
 type Controller struct {
-    Interval             time.Duration `mapstructure:"interval"`
-    PrepTimeout          time.Duration `mapstructure:"prep_timeout"`
-    InitialSleepDuration time.Duration `mapstructure:"initial_sleep_duration"`
+    Interval                       time.Duration `mapstructure:"interval"`
+    PrepTimeout                    time.Duration `mapstructure:"prep_timeout"`
+    InitialSleepDuration           time.Duration `mapstructure:"initial_sleep_duration"`
+    HealthySnapshotIntervalLimit   time.Duration `mapstructure:"healthy_snapshot_interval_limit"`
+    InitializationTimeoutExtension time.Duration `mapstructure:"initialization_timeout_extension"`
 }
 
 var cfg *Config
@@ -95,6 +97,8 @@ func Get() Config {
     viper.SetDefault("controller.interval", 15*time.Second)
     viper.SetDefault("controller.prep_timeout", 10*time.Minute)
     viper.SetDefault("controller.initial_sleep_duration", 30*time.Second)
+    viper.SetDefault("controller.healthy_snapshot_interval_limit", 10*time.Minute)
+    viper.SetDefault("controller.initialization_timeout_extension", 5*time.Minute)
 
     viper.AutomaticEnv()
     viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
diff --git a/internal/services/controller/controller.go b/internal/services/controller/controller.go
index 16d7b017..5258015c 100644
--- a/internal/services/controller/controller.go
+++ b/internal/services/controller/controller.go
@@ -47,7 +47,8 @@ type Controller struct {
     delta   *delta
     deltaMu sync.Mutex
 
-    agentVersion *config.AgentVersion
+    agentVersion    *config.AgentVersion
+    healthzProvider *HealthzProvider
 }
 
 func New(
@@ -59,7 +60,10 @@ func New(
     cfg *config.Controller,
     v version.Interface,
     agentVersion *config.AgentVersion,
+    healthzProvider *HealthzProvider,
 ) *Controller {
+    healthzProvider.Initializing()
+
     typeInformerMap := map[reflect.Type]cache.SharedInformer{
         reflect.TypeOf(&corev1.Node{}): f.Core().V1().Nodes().Informer(),
         reflect.TypeOf(&corev1.Pod{}):  f.Core().V1().Pods().Informer(),
@@ -86,15 +90,16 @@ func New(
     }
 
     c := &Controller{
-        log:          log,
-        clusterID:    clusterID,
-        castaiclient: castaiclient,
-        provider:     provider,
-        cfg:          cfg,
-        delta:        newDelta(log, clusterID, v.Full()),
-        queue:        workqueue.NewNamed("castai-agent"),
-        informers:    typeInformerMap,
-        agentVersion: agentVersion,
+        log:             log,
+        clusterID:       clusterID,
+        castaiclient:    castaiclient,
+        provider:        provider,
+        cfg:             cfg,
+        delta:           newDelta(log, clusterID, v.Full()),
+        queue:           workqueue.NewNamed("castai-agent"),
+        informers:       typeInformerMap,
+        agentVersion:    agentVersion,
+        healthzProvider: healthzProvider,
     }
 
     c.registerEventHandlers()
@@ -228,7 +233,7 @@ func removeSensitiveEnvVars(obj interface{}) {
     }
 }
 
-func (c *Controller) Run(ctx context.Context) {
+func (c *Controller) Run(ctx context.Context) error {
     defer c.queue.ShutDown()
 
     ctx, cancel := context.WithCancel(ctx)
@@ -243,7 +248,7 @@ func (c *Controller) Run(ctx context.Context) {
     c.log.Info("waiting for informers cache to sync")
     if !cache.WaitForCacheSync(ctx.Done(), syncs...) {
         c.log.Error("failed to sync")
-        return
+        return fmt.Errorf("failed to wait for cache sync")
     }
     c.log.Infof("informers cache synced after %v", time.Since(waitStartedAt))
 
@@ -282,6 +287,8 @@ func (c *Controller) Run(ctx context.Context) {
         c.log.Infof("sleeping for %s before starting to send cluster deltas", c.cfg.InitialSleepDuration)
         time.Sleep(c.cfg.InitialSleepDuration)
 
+        c.healthzProvider.Initialized()
+
         c.log.Infof("sending cluster deltas every %s", c.cfg.Interval)
         wait.Until(func() {
             c.send(ctx)
@@ -294,6 +301,8 @@ func (c *Controller) Run(ctx context.Context) {
     }()
 
     c.pollQueueUntilShutdown()
+
+    return nil
 }
 
 // collectInitialSnapshot is used to add a time buffer to collect the initial snapshot which is larger than periodic
@@ -389,6 +398,8 @@ func (c *Controller) send(ctx context.Context) {
         return
     }
 
+    c.healthzProvider.SnapshotSent()
+
     c.delta.clear()
 }
 
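Note: with the AutomaticEnv and SetEnvKeyReplacer wiring shown above, the two new settings can also be supplied via environment variables. A minimal sketch (illustrative, not part of the change; it relies on viper's default duration decoding, which applies here because defaults are registered for these keys):

package main

import (
	"fmt"
	"os"

	"castai-agent/internal/config"
)

func main() {
	// controller.healthy_snapshot_interval_limit maps to
	// CONTROLLER_HEALTHY_SNAPSHOT_INTERVAL_LIMIT via the "." -> "_" replacer;
	// the string value is parsed into a time.Duration.
	os.Setenv("CONTROLLER_HEALTHY_SNAPSHOT_INTERVAL_LIMIT", "20m")

	cfg := config.Get()
	fmt.Println(cfg.Controller.HealthySnapshotIntervalLimit)   // 20m0s (env override)
	fmt.Println(cfg.Controller.InitializationTimeoutExtension) // 5m0s (default)
}
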
diff --git a/internal/services/controller/controller_exclude_race_test.go b/internal/services/controller/controller_exclude_race_test.go
index 11a6e676..d9ca19dd 100644
--- a/internal/services/controller/controller_exclude_race_test.go
+++ b/internal/services/controller/controller_exclude_race_test.go
@@ -121,24 +121,17 @@ func TestController_ShouldKeepDeltaAfterDelete(t *testing.T) {
     log := logrus.New()
     log.SetLevel(logrus.DebugLevel)
-    ctrl := New(
-        log,
-        f,
-        castaiclient,
-        provider,
-        clusterID.String(),
-        &config.Controller{
-            Interval:             2 * time.Second,
-            PrepTimeout:          2 * time.Second,
-            InitialSleepDuration: 10 * time.Millisecond,
-        },
-        version,
-        agentVersion,
-    )
+    ctrl := New(log, f, castaiclient, provider, clusterID.String(), &config.Controller{
+        Interval:             2 * time.Second,
+        PrepTimeout:          2 * time.Second,
+        InitialSleepDuration: 10 * time.Millisecond,
+    }, version, agentVersion, NewHealthzProvider(defaultHealthzCfg))
     f.Start(ctx.Done())
 
-    go ctrl.Run(ctx)
+    go func() {
+        require.NoError(t, ctrl.Run(ctx))
+    }()
 
     wait.Until(func() {
         if atomic.LoadInt64(&invocations) >= 3 {
diff --git a/internal/services/controller/controller_test.go b/internal/services/controller/controller_test.go
index 9df5d36c..71929565 100644
--- a/internal/services/controller/controller_test.go
+++ b/internal/services/controller/controller_test.go
@@ -32,6 +32,14 @@ import (
     "castai-agent/pkg/labels"
 )
 
+var defaultHealthzCfg = config.Config{Controller: &config.Controller{
+    Interval:                       15 * time.Second,
+    PrepTimeout:                    10 * time.Minute,
+    InitialSleepDuration:           30 * time.Second,
+    HealthySnapshotIntervalLimit:   10 * time.Minute,
+    InitializationTimeoutExtension: 5 * time.Minute,
+}}
+
 func TestMain(m *testing.M) {
     goleak.VerifyTestMain(
         m,
@@ -101,23 +109,16 @@ func TestController_HappyPath(t *testing.T) {
     f := informers.NewSharedInformerFactory(clientset, 0)
     log := logrus.New()
     log.SetLevel(logrus.DebugLevel)
-    ctrl := New(
-        log,
-        f,
-        castaiclient,
-        provider,
-        clusterID.String(),
-        &config.Controller{
-            Interval:             15 * time.Second,
-            PrepTimeout:          2 * time.Second,
-            InitialSleepDuration: 10 * time.Millisecond,
-        },
-        version,
-        agentVersion,
-    )
+    ctrl := New(log, f, castaiclient, provider, clusterID.String(), &config.Controller{
+        Interval:             15 * time.Second,
+        PrepTimeout:          2 * time.Second,
+        InitialSleepDuration: 10 * time.Millisecond,
+    }, version, agentVersion, NewHealthzProvider(defaultHealthzCfg))
     f.Start(ctx.Done())
 
-    go ctrl.Run(ctx)
+    go func() {
+        require.NoError(t, ctrl.Run(ctx))
+    }()
 
     wait.Until(func() {
         if atomic.LoadInt64(&invocations) >= 1 {
diff --git a/internal/services/controller/healthz.go b/internal/services/controller/healthz.go
new file mode 100644
index 00000000..1fbb6dfc
--- /dev/null
+++ b/internal/services/controller/healthz.go
@@ -0,0 +1,67 @@
+package controller
+
+import (
+    "fmt"
+    "net/http"
+    "time"
+
+    "castai-agent/internal/config"
+)
+
+func NewHealthzProvider(cfg config.Config) *HealthzProvider {
+    return &HealthzProvider{
+        cfg:             cfg,
+        initHardTimeout: cfg.Controller.PrepTimeout + cfg.Controller.InitialSleepDuration + cfg.Controller.InitializationTimeoutExtension,
+    }
+}
+
+type HealthzProvider struct {
+    cfg             config.Config
+    initHardTimeout time.Duration
+
+    initializeStartedAt *time.Time
+    lastHealthyActionAt *time.Time
+}
+
+func (h *HealthzProvider) Check(_ *http.Request) error {
+    if h.lastHealthyActionAt != nil {
+        if time.Since(*h.lastHealthyActionAt) > h.cfg.Controller.HealthySnapshotIntervalLimit {
+            return fmt.Errorf("time since initialization or last snapshot sent is over the considered healthy limit of %s", h.cfg.Controller.HealthySnapshotIntervalLimit)
+        }
+        return nil
+    }
+
+    if h.initializeStartedAt != nil {
+        if time.Since(*h.initializeStartedAt) > h.initHardTimeout {
+            return fmt.Errorf("controller initialization is taking longer than the hard timeout of %s", h.initHardTimeout)
+        }
+        return nil
+    }
+
+    return fmt.Errorf("healthz not initialized")
+}
+
+func (h *HealthzProvider) Initializing() {
+    if h.initializeStartedAt == nil {
+        h.initializeStartedAt = nowPtr()
+        h.lastHealthyActionAt = nil
+    }
+}
+
+func (h *HealthzProvider) Initialized() {
+    h.healthyAction()
+}
+
+func (h *HealthzProvider) SnapshotSent() {
+    h.healthyAction()
+}
+
+func (h *HealthzProvider) healthyAction() {
+    h.initializeStartedAt = nil
+    h.lastHealthyActionAt = nowPtr()
+}
+
+func nowPtr() *time.Time {
+    now := time.Now()
+    return &now
+}
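Note: the provider is a small state machine. Check passes while initialization stays within initHardTimeout (PrepTimeout + InitialSleepDuration + InitializationTimeoutExtension) and, after the first healthy action, while the last snapshot stays within HealthySnapshotIntervalLimit. A lifecycle sketch (illustrative, not part of the change; it only uses the API introduced above):

package main

import (
	"fmt"

	"castai-agent/internal/config"
	"castai-agent/internal/services/controller"
)

func main() {
	h := controller.NewHealthzProvider(config.Get())

	fmt.Println(h.Check(nil)) // error: healthz not initialized

	h.Initializing()          // controller.New() calls this; the init hard timeout now applies
	fmt.Println(h.Check(nil)) // nil while inside the init window (15m30s with the defaults)

	h.Initialized()           // Run() calls this after the initial sleep
	h.SnapshotSent()          // send() calls this after each delivered delta
	fmt.Println(h.Check(nil)) // nil while the last healthy action is within 10m (default)
}
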
diff --git a/internal/services/controller/healthz_test.go b/internal/services/controller/healthz_test.go
new file mode 100644
index 00000000..6ab5454c
--- /dev/null
+++ b/internal/services/controller/healthz_test.go
@@ -0,0 +1,74 @@
+package controller
+
+import (
+    "testing"
+    "time"
+
+    "github.com/stretchr/testify/require"
+
+    "castai-agent/internal/config"
+)
+
+func TestNewHealthzProvider(t *testing.T) {
+    t.Run("unhealthy statuses", func(t *testing.T) {
+        cfg := config.Config{Controller: &config.Controller{
+            Interval:                       15 * time.Second,
+            PrepTimeout:                    time.Millisecond,
+            InitialSleepDuration:           time.Millisecond,
+            InitializationTimeoutExtension: time.Millisecond,
+            HealthySnapshotIntervalLimit:   time.Millisecond,
+        }}
+
+        h := NewHealthzProvider(cfg)
+
+        t.Run("should return not initialized error", func(t *testing.T) {
+            require.Error(t, h.Check(nil))
+        })
+
+        t.Run("should return initialize timeout error", func(t *testing.T) {
+            h.Initializing()
+
+            time.Sleep(5 * time.Millisecond)
+
+            require.Error(t, h.Check(nil))
+        })
+
+        t.Run("should return snapshot timeout error", func(t *testing.T) {
+            h.healthyAction()
+
+            time.Sleep(5 * time.Millisecond)
+
+            require.Error(t, h.Check(nil))
+        })
+    })
+
+    t.Run("healthy statuses", func(t *testing.T) {
+        cfg := config.Config{Controller: &config.Controller{
+            Interval:                       15 * time.Second,
+            PrepTimeout:                    10 * time.Minute,
+            InitialSleepDuration:           30 * time.Second,
+            InitializationTimeoutExtension: 5 * time.Minute,
+            HealthySnapshotIntervalLimit:   10 * time.Minute,
+        }}
+
+        h := NewHealthzProvider(cfg)
+
+        t.Run("should return no error when still initializing", func(t *testing.T) {
+            h.Initializing()
+
+            require.NoError(t, h.Check(nil))
+        })
+
+        t.Run("should return no error when timeout after initialization has not yet passed", func(t *testing.T) {
+            h.Initialized()
+
+            require.NoError(t, h.Check(nil))
+        })
+
+        t.Run("should return no error when time since last snapshot has not been long", func(t *testing.T) {
+            h.SnapshotSent()
+
+            require.NoError(t, h.Check(nil))
+        })
+    })
+}
diff --git a/internal/services/controller/worker.go b/internal/services/controller/worker.go
new file mode 100644
index 00000000..73e0de11
--- /dev/null
+++ b/internal/services/controller/worker.go
@@ -0,0 +1,111 @@
+package controller
+
+import (
+    "context"
+    "fmt"
+    "time"
+
+    "github.com/google/uuid"
+    "github.com/sirupsen/logrus"
+    "k8s.io/client-go/informers"
+    "k8s.io/client-go/kubernetes"
+
+    "castai-agent/internal/castai"
+    "castai-agent/internal/config"
+    "castai-agent/internal/services/providers/types"
+    "castai-agent/internal/services/version"
+)
+
+type Worker struct {
+    Fn func(ctx context.Context) error
+
+    stop   context.CancelFunc
+    exitCh chan struct{}
+}
+
+func (w *Worker) Start(ctx context.Context) error {
+    ctx, cancel := context.WithCancel(ctx)
+    defer cancel()
+
+    w.exitCh = make(chan struct{})
+    defer close(w.exitCh)
+
+    w.stop = cancel
+
+    for {
+        select {
+        case <-ctx.Done():
+            return ctx.Err()
+        default:
+        }
+
+        if err := w.Fn(ctx); err != nil {
+            return fmt.Errorf("running controller function: %w", err)
+        }
+    }
+}
+
+func (w *Worker) Stop(log logrus.FieldLogger) {
+    if w.stop == nil {
+        return
+    }
+
+    w.stop()
+
+    ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+    defer cancel()
+
+    select {
+    case <-w.exitCh:
+    case <-ctx.Done():
+        log.Errorf("waiting for controller to exit: %v", ctx.Err())
+    }
+}
+
+func RunController(
+    log logrus.FieldLogger,
+    clientset kubernetes.Interface,
+    castaiclient castai.Client,
+    provider types.Provider,
+    clusterID string,
+    cfg config.Config,
+    agentVersion *config.AgentVersion,
+    healthzProvider *HealthzProvider,
+) func(ctx context.Context) error {
+    return func(ctx context.Context) error {
+        log = log.WithField("controller_id", uuid.New().String())
+
+        defer func() {
+            if err := recover(); err != nil {
+                log.Errorf("panic: runtime error: %v", err)
+            }
+        }()
+
+        ctrlCtx, cancelCtrlCtx := context.WithCancel(ctx)
+        defer cancelCtrlCtx()
+
+        v, err := version.Get(log, clientset)
+        if err != nil {
+            return fmt.Errorf("getting kubernetes version: %w", err)
+        }
+
+        log = log.WithField("k8s_version", v.Full())
+
+        f := informers.NewSharedInformerFactory(clientset, 0)
+        ctrl := New(
+            log,
+            f,
+            castaiclient,
+            provider,
+            clusterID,
+            cfg.Controller,
+            v,
+            agentVersion,
+            healthzProvider,
+        )
+        f.Start(ctrlCtx.Done())
+
+        // Run the controller. This is a blocking call.
+        return ctrl.Run(ctrlCtx)
+    }
+}
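Note: the Worker restarts Fn whenever it returns nil, which is how a panic recovered inside RunController yields a fresh controller (the recover only logs, so the closure returns nil and the loop starts over); the first non-nil error or context cancellation ends the loop, and Stop cancels Fn's context and waits up to 30 seconds for Start to return. A usage sketch (illustrative; runOnce is a hypothetical stand-in for the closure RunController builds):

package main

import (
	"context"
	"time"

	"github.com/sirupsen/logrus"

	"castai-agent/internal/services/controller"
)

// runOnce is one controller lifecycle; returning nil asks the Worker
// to run it again, returning an error stops the Worker.
func runOnce(ctx context.Context) error {
	<-ctx.Done()
	return nil
}

func main() {
	log := logrus.New()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	w := &controller.Worker{Fn: runOnce}

	errCh := make(chan error, 1)
	go func() { errCh <- w.Start(ctx) }() // blocks until an error or ctx is cancelled

	time.Sleep(time.Second)
	w.Stop(log) // cancels Fn's context, waits up to 30s for Start to exit
	<-errCh     // context.Canceled
}
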
diff --git a/main.go b/main.go
index d8b82a0a..9595cd0f 100644
--- a/main.go
+++ b/main.go
@@ -9,11 +9,11 @@ import (
     _ "net/http/pprof"
     "time"
 
+    "sigs.k8s.io/controller-runtime/pkg/healthz"
+
     castailog "castai-agent/pkg/log"
 
     "github.com/sirupsen/logrus"
-    "k8s.io/apimachinery/pkg/util/wait"
-    "k8s.io/client-go/informers"
     "k8s.io/client-go/kubernetes"
     "k8s.io/client-go/rest"
     "k8s.io/client-go/tools/clientcmd"
@@ -23,7 +23,6 @@ import (
     "castai-agent/internal/config"
     "castai-agent/internal/services/controller"
     "castai-agent/internal/services/providers"
-    "castai-agent/internal/services/version"
 )
 
 // These should be set via `go build` during a release
@@ -51,9 +50,14 @@ func main() {
         }
         log.Fatalf("agent failed: %v", err)
     }
+
+    log.Info("agent shutdown")
 }
 
 func run(ctx context.Context, castaiclient castai.Client, logger *logrus.Logger, cfg config.Config) (reterr error) {
+    ctx, cancel := context.WithCancel(ctx)
+    defer cancel()
+
     fields := logrus.Fields{}
 
     defer func() {
@@ -118,44 +122,55 @@ func run(ctx context.Context, castaiclient castai.Client, logger *logrus.Logger,
     fields["cluster_id"] = clusterID
     log = log.WithFields(fields)
 
+    exitCh := make(chan error)
+
     if cfg.PprofPort != 0 {
+        addr := fmt.Sprintf(":%d", cfg.PprofPort)
+        pprofSrv := &http.Server{Addr: addr, Handler: http.DefaultServeMux}
+        defer func() {
+            if err := pprofSrv.Close(); err != nil {
+                log.Errorf("closing pprof server: %v", err)
+            }
+        }()
+
         go func() {
-            addr := fmt.Sprintf(":%d", cfg.PprofPort)
             log.Infof("starting pprof server on %s", addr)
-            if err := http.ListenAndServe(addr, http.DefaultServeMux); err != nil {
-                log.Errorf("failed to start pprof http server: %v", err)
-            }
+            exitCh <- fmt.Errorf("pprof server: %w", pprofSrv.ListenAndServe())
         }()
     }
 
-    wait.Until(func() {
-        ctrlCtx, cancelCtrlCtx := context.WithCancel(ctx)
-        defer cancelCtrlCtx()
+    ctrlHealthz := controller.NewHealthzProvider(cfg)
 
-        v, err := version.Get(log, clientset)
-        if err != nil {
-            log.Fatalf("failed getting kubernetes version: %v", err)
-        }
+    healthzSrv := &http.Server{Addr: ":9876", Handler: &healthz.Handler{Checks: map[string]healthz.Checker{
+        "server":     healthz.Ping,
+        "controller": ctrlHealthz.Check,
+    }}}
+    defer func() {
+        if err := healthzSrv.Close(); err != nil {
+            log.Errorf("closing healthz server: %v", err)
+        }
+    }()
+
+    go func() {
+        exitCh <- fmt.Errorf("healthz server: %w", healthzSrv.ListenAndServe())
+    }()
+
+    w := &controller.Worker{
+        Fn: controller.RunController(log, clientset, castaiclient, provider, clusterID, cfg, agentVersion, ctrlHealthz),
+    }
+    defer w.Stop(log)
 
-        fields["k8s_version"] = v.Full()
-        log = log.WithFields(fields)
-
-        f := informers.NewSharedInformerFactory(clientset, 0)
-        ctrl := controller.New(
-            log,
-            f,
-            castaiclient,
-            provider,
-            clusterID,
-            cfg.Controller,
-            v,
-            agentVersion,
-        )
-        f.Start(ctrlCtx.Done())
-        ctrl.Run(ctrlCtx)
-    }, 0, ctx.Done())
-
-    return nil
+    go func() {
+        exitCh <- fmt.Errorf("controller loop: %w", w.Start(ctx))
+    }()
+
+    select {
+    case err := <-exitCh:
+        cancel()
+        return err
+    case <-ctx.Done():
+        return nil
+    }
 }
 
 func kubeConfigFromEnv() (*rest.Config, error) {
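Note: per controller-runtime's documented healthz.Handler behavior, the root path aggregates all registered checks and each named check is also served on its own subpath, so the server above should answer on /, /server, and /controller. A probe sketch (illustrative; assumes the agent is reachable on localhost:9876, e.g. via a port-forward):

package main

import (
	"fmt"
	"net/http"
)

func main() {
	for _, path := range []string{"/", "/server", "/controller"} {
		resp, err := http.Get("http://localhost:9876" + path)
		if err != nil {
			fmt.Println(path, err)
			continue
		}
		resp.Body.Close()
		fmt.Println(path, resp.Status) // 200 OK when healthy; 500 when a check fails
	}
}

A Kubernetes livenessProbe pointing at one of these paths on port 9876 would complete the wiring.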