diff --git a/internal/cast/cast.go b/internal/cast/cast.go index d40ba76f..e4819217 100644 --- a/internal/cast/cast.go +++ b/internal/cast/cast.go @@ -9,12 +9,23 @@ import ( "fmt" "github.com/go-resty/resty/v2" "github.com/sirupsen/logrus" + "io" + "mime/multipart" + "net/http" + "net/textproto" + "net/url" "time" ) const ( defaultRetryCount = 3 defaultTimeout = 10 * time.Second + headerAPIKey = "X-API-Key" +) + +var ( + hdrContentType = http.CanonicalHeaderKey("Content-Type") + hdrContentDisposition = http.CanonicalHeaderKey("Content-Disposition") ) // Client responsible for communication between the agent and CAST AI API. @@ -42,7 +53,7 @@ func NewDefaultClient() *resty.Client { client.SetHostURL(fmt.Sprintf("https://%s", cfg.URL)) client.SetRetryCount(defaultRetryCount) client.SetTimeout(defaultTimeout) - client.Header.Set("X-API-Key", cfg.Key) + client.Header.Set(headerAPIKey, cfg.Key) return client } @@ -72,30 +83,81 @@ func (c *client) RegisterCluster(ctx context.Context, req *RegisterClusterReques } func (c *client) SendClusterSnapshot(ctx context.Context, snap *Snapshot) error { - payload, err := json.Marshal(snap) + cfg := config.Get().API + + uri, err := url.Parse(fmt.Sprintf("https://%s/v1/agent/snapshot", cfg.URL)) if err != nil { - return fmt.Errorf("marshaling snapshot payload: %w", err) + return fmt.Errorf("invalid url: %w", err) } - buf := bytes.NewBuffer(payload) - resp, err := c.rest.R(). - SetFileReader("payload", "payload.json", buf). - SetResult(&RegisterClusterResponse{}). - SetContext(ctx). - Post("/v1/agent/snapshot") + r, w := io.Pipe() + mw := multipart.NewWriter(w) + + go func() { + defer func() { + if err := w.Close(); err != nil { + c.log.Errorf("closing pipe: %v", err) + } + }() + defer func() { + if err := mw.Close(); err != nil { + c.log.Errorf("closing multipart writer: %w", err) + } + }() + if err := writeSnapshotPart(mw, snap); err != nil { + c.log.Errorf("writing snapshot content: %v", err) + } + }() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, uri.String(), r) if err != nil { - return err + return fmt.Errorf("creating snapshot request: %w", err) } - if resp.IsError() { - return fmt.Errorf("request error status_code=%d body=%s", resp.StatusCode(), resp.Body()) + + req.Header.Set(hdrContentType, mw.FormDataContentType()) + req.Header.Set(headerAPIKey, cfg.Key) + + resp, err := c.rest.GetClient().Do(req) + if err != nil { + return fmt.Errorf("sending snapshot request: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + c.log.Errorf("closing response body: %v", err) + } + }() + + if resp.StatusCode > 399 { + var buf bytes.Buffer + if _, err := buf.ReadFrom(resp.Body); err != nil { + c.log.Errorf("failed reading error response body: %v", err) + } + return fmt.Errorf("snapshot request error status_code=%d body=%s", resp.StatusCode, buf.String()) } c.log.Infof( "snapshot with nodes[%d], pods[%d] sent, response_code=%d", len(snap.NodeList.Items), len(snap.PodList.Items), - resp.StatusCode(), + resp.StatusCode, ) return nil } + +func writeSnapshotPart(mw *multipart.Writer, snap *Snapshot) error { + header := textproto.MIMEHeader{} + header.Set(hdrContentDisposition, `form-data; name="payload"; filename="payload.json"`) + header.Set(hdrContentType, "application/json") + + bw, err := mw.CreatePart(header) + if err != nil { + return fmt.Errorf("creating payload part: %w", err) + } + + if err := json.NewEncoder(bw).Encode(snap); err != nil { + return fmt.Errorf("marshaling snapshot payload: %w", err) + } + + return nil +} diff --git a/internal/cast/cast_test.go b/internal/cast/cast_test.go index c9a89863..a6ab365a 100644 --- a/internal/cast/cast_test.go +++ b/internal/cast/cast_test.go @@ -12,6 +12,7 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "net/http" + "os" "testing" ) @@ -39,6 +40,9 @@ func TestClient_RegisterCluster(t *testing.T) { } func TestClient_SendClusterSnapshot(t *testing.T) { + require.NoError(t, os.Setenv("API_KEY", "api-key")) + require.NoError(t, os.Setenv("API_URL", "localhost")) + rest := resty.New() httpmock.ActivateNonDefault(rest.GetClient()) defer httpmock.Reset() @@ -69,7 +73,7 @@ func TestClient_SendClusterSnapshot(t *testing.T) { }, } - httpmock.RegisterResponder(http.MethodPost, "/v1/agent/snapshot", func(req *http.Request) (*http.Response, error) { + httpmock.RegisterResponder(http.MethodPost, "https://localhost/v1/agent/snapshot", func(req *http.Request) (*http.Response, error) { f, _, err := req.FormFile("payload") require.NoError(t, err) @@ -78,6 +82,8 @@ func TestClient_SendClusterSnapshot(t *testing.T) { require.Equal(t, snapshot, actualRequest) + require.Equal(t, "api-key", req.Header.Get(headerAPIKey)) + return httpmock.NewStringResponse(http.StatusNoContent, "ok"), nil })