diff --git a/cmd/agent/run.go b/cmd/agent/run.go
index 5f3baca1..54a9269a 100644
--- a/cmd/agent/run.go
+++ b/cmd/agent/run.go
@@ -43,7 +43,7 @@ func run(ctx context.Context) error {
 	remoteLogger := logrus.New()
 	remoteLogger.SetLevel(logrus.Level(cfg.Log.Level))
 	remoteLogger.SetFormatter(&textFormatter)
-	log := remoteLogger.WithField("version", ctx.Value("agentVersion").(*config.AgentVersion).Version)
+	log := remoteLogger.WithField("version", config.VersionInfo.Version)
 	if podName := os.Getenv("SELF_POD_NAME"); podName != "" {
 		log = log.WithField("component_pod_name", podName)
 	}
@@ -112,7 +112,7 @@ func runAgentMode(ctx context.Context, castaiclient castai.Client, log *logrus.E
 	ctx, ctxCancel := context.WithCancel(ctx)
 	defer ctxCancel()
 
-	agentVersion := ctx.Value("agentVersion").(*config.AgentVersion)
+	agentVersion := config.VersionInfo
 
 	// buffer will allow for all senders to push, even though we will only read first error and cancel context after it;
 	// all errors from exitCh are logged
diff --git a/cmd/dump/run.go b/cmd/dump/run.go
index c4678a99..e485f7ef 100644
--- a/cmd/dump/run.go
+++ b/cmd/dump/run.go
@@ -22,7 +22,7 @@ func run(ctx context.Context) error {
 
 	logger := logrus.New()
 	logger.SetLevel(logrus.Level(cfg.Log.Level))
-	log := logger.WithField("version", ctx.Value("agentVersion").(*config.AgentVersion).Version)
+	log := logger.WithField("version", config.VersionInfo.Version)
 
 	log.Infof("starting dump of cluster snapshot")
 
diff --git a/cmd/monitor/run.go b/cmd/monitor/run.go
index defa12d3..4579875a 100644
--- a/cmd/monitor/run.go
+++ b/cmd/monitor/run.go
@@ -25,7 +25,7 @@ func run(ctx context.Context) error {
 
 	remoteLogger := logrus.New()
 	remoteLogger.SetLevel(logrus.Level(cfg.Log.Level))
-	log := remoteLogger.WithField("version", ctx.Value("agentVersion").(*config.AgentVersion).Version)
+	log := remoteLogger.WithField("version", config.VersionInfo.Version)
 
 	localLog := logrus.New()
 	localLog.SetLevel(logrus.DebugLevel)
diff --git a/internal/castai/castai.go b/internal/castai/castai.go
index a56dc7b6..13da9354 100644
--- a/internal/castai/castai.go
+++ b/internal/castai/castai.go
@@ -27,13 +27,13 @@ import (
 
 const (
 	defaultRetryCount     = 3
-	defaultTimeout        = 10 * time.Second
-	sendDeltaReadTimeout  = 2 * time.Minute
-	totalSendDeltaTimeout = 5 * time.Minute
 	headerAPIKey          = "X-Api-Key"
 	headerContinuityToken = "Continuity-Token"
 	headerContentType     = "Content-Type"
 	headerContentEncoding = "Content-Encoding"
+	headerUserAgent       = "User-Agent"
+
+	respHeaderRequestID = "X-Castai-Request-Id"
 )
 
 var (
@@ -73,13 +73,17 @@ func NewDefaultRestyClient() (*resty.Client, error) {
 	}
 
 	restyClient := resty.NewWithClient(&http.Client{
-		Timeout:   defaultTimeout,
+		Timeout:   cfg.Timeout,
 		Transport: clientTransport,
 	})
 
 	restyClient.SetBaseURL(cfg.URL)
 	restyClient.SetRetryCount(defaultRetryCount)
 	restyClient.Header.Set(headerAPIKey, cfg.Key)
+	if host := cfg.HostHeaderOverride; host != "" {
+		restyClient.Header.Set("Host", host)
+	}
+	addUA(restyClient.Header)
 
 	return restyClient, nil
 }
@@ -93,7 +97,7 @@ func NewDefaultDeltaHTTPClient() (*http.Client, error) {
 	}
 
 	return &http.Client{
-		Timeout:   sendDeltaReadTimeout,
+		Timeout:   config.Get().API.DeltaReadTimeout,
 		Transport: clientTransport,
 	}, nil
 }
@@ -194,7 +198,7 @@ func (c *client) SendDelta(ctx context.Context, clusterID string, delta *Delta)
 		}
 	}()
 
-	ctx, cancel := context.WithTimeout(ctx, totalSendDeltaTimeout)
+	ctx, cancel := context.WithTimeout(ctx, cfg.TotalSendDeltaTimeout)
 	defer cancel()
 
 	req, err := http.NewRequestWithContext(ctx, http.MethodPost, uri.String(), pipeReader)
@@ -206,6 +210,13 @@ func (c *client) SendDelta(ctx context.Context, clusterID string, delta *Delta)
 	req.Header.Set(headerContentEncoding, "gzip")
 	req.Header.Set(headerAPIKey, cfg.Key)
 	req.Header.Set(headerContinuityToken, c.continuityToken)
+	addUA(req.Header)
+
+	if host := cfg.HostHeaderOverride; host != "" {
+		// net/http ignores a "Host" entry in req.Header for outgoing requests;
+		// the Host header that is sent comes from req.Host.
+		req.Host = host
+	}
 
 	var resp *http.Response
 
@@ -253,7 +264,8 @@ func (c *client) SendDelta(ctx context.Context, clusterID string, delta *Delta)
 		if strings.Contains(buf.String(), ErrInvalidContinuityToken.Error()) {
 			return ErrInvalidContinuityToken
 		}
-		return fmt.Errorf("delta request error status_code=%d body=%s", resp.StatusCode, buf.String())
+		reqID := resp.Header.Get(respHeaderRequestID)
+		return fmt.Errorf("delta request error request_id=%q status_code=%d body=%s", reqID, resp.StatusCode, buf.String())
 	}
 	c.continuityToken = resp.Header.Get(headerContinuityToken)
 	log.Infof("delta upload finished")
@@ -312,3 +324,13 @@ func (c *client) ExchangeAgentTelemetry(ctx context.Context, clusterID string, r
 
 	return body, nil
 }
+
+// addUA sets the User-Agent header to "castai-agent/<version>", using
+// "unknown" as the version when config.VersionInfo is nil.
+func addUA(header http.Header) {
+	version := "unknown"
+	if vi := config.VersionInfo; vi != nil {
+		version = vi.Version
+	}
+	header.Set(headerUserAgent, fmt.Sprintf("castai-agent/%s", version))
+}
diff --git a/internal/config/config.go b/internal/config/config.go
index d3b159fc..1ec8133a 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -62,8 +62,12 @@ type Pod struct {
 }
 
 type API struct {
-	Key string `mapstructure:"key"`
-	URL string `mapstructure:"url"`
+	Key                   string        `mapstructure:"key"`
+	URL                   string        `mapstructure:"url"`
+	HostHeaderOverride    string        `mapstructure:"host_header_override"`
+	Timeout               time.Duration `mapstructure:"timeout"`
+	DeltaReadTimeout      time.Duration `mapstructure:"delta_read_timeout"`
+	TotalSendDeltaTimeout time.Duration `mapstructure:"total_send_delta_timeout"`
 }
 
 type EKS struct {
@@ -144,6 +148,10 @@ func Get() Config {
 		return *cfg
 	}
 
+	viper.SetDefault("api.timeout", 10*time.Second)
+	viper.SetDefault("api.delta_read_timeout", 2*time.Minute)
+	viper.SetDefault("api.total_send_delta_timeout", 5*time.Minute)
+
 	viper.SetDefault("controller.interval", 15*time.Second)
 	viper.SetDefault("controller.prep_timeout", 10*time.Minute)
 	viper.SetDefault("controller.memory_pressure_interval", 3*time.Second)
diff --git a/internal/config/version.go b/internal/config/version.go
index e8d2e8cf..460ab596 100644
--- a/internal/config/version.go
+++ b/internal/config/version.go
@@ -2,6 +2,10 @@ package config
 
 import "fmt"
 
+// VersionInfo holds the build metadata of the running agent. It is assigned
+// in main before any command runs and is nil if nothing has set it.
+var VersionInfo *AgentVersion
+
 type AgentVersion struct {
 	GitCommit, GitRef, Version string
 }
diff --git a/internal/services/controller/controller.go b/internal/services/controller/controller.go
index af0c22b8..0f4939a0 100644
--- a/internal/services/controller/controller.go
+++ b/internal/services/controller/controller.go
@@ -151,9 +151,12 @@ func CollectSingleSnapshot(ctx context.Context,
 
 	defer queue.ShutDown()
 
-	agentVersion := ctx.Value("agentVersion").(*config.AgentVersion)
+	agentVersion := "unknown"
+	if vi := config.VersionInfo; vi != nil {
+		agentVersion = vi.Version
+	}
 
-	d := delta.New(log, clusterID, v.Full(), agentVersion.Version)
+	d := delta.New(log, clusterID, v.Full(), agentVersion)
 	go func() {
 		for {
 			i, _ := queue.Get()
@@ -649,7 +652,11 @@ func throttleLog(ctx context.Context, log logrus.FieldLogger, objType string, wa
 				} else {
 					log.Infof("Informer cache for %v synced after %v", objType, time.Since(waitStartedAt))
 				}
-				time.Sleep(window)
+				select {
+				case <-time.After(window):
+				case <-ctx.Done():
+					return
+				}
 			}
 		}
 	}()
diff --git a/main.go b/main.go
index e1c167ef..0b4d7a79 100644
--- a/main.go
+++ b/main.go
@@ -1,7 +1,6 @@
 package main
 
 import (
-	"context"
 	_ "net/http/pprof"
 
 	"github.com/KimMachineGun/automemlimit/memlimit"
@@ -30,12 +29,12 @@ var (
 )
 
 func main() {
 	ctx := signals.SetupSignalHandler()
-	ctx = context.WithValue(ctx, "agentVersion", &config.AgentVersion{
+	config.VersionInfo = &config.AgentVersion{
 		GitCommit: GitCommit,
 		GitRef:    GitRef,
 		Version:   Version,
-	})
+	}
 
 	cmd.Execute(ctx)
 }