From 237621c43a053c9d54e5eb3ac747ddfdec8b5578 Mon Sep 17 00:00:00 2001 From: Michael Chernov <4ernovm@gmail.com> Date: Fri, 1 Dec 2023 17:44:47 +0200 Subject: [PATCH] Extract get with timeout into separate function --- internal/client/deb_client.go | 8 ++++++-- internal/client/deb_client_test.go | 4 ++-- internal/client/request.go | 22 ++++++++++++++++----- internal/client/testdata/deb_client_mock.go | 4 ++-- internal/file/finder.go | 6 +++--- internal/report/license/report.go | 4 ++-- internal/report/vulnerability/report.go | 2 +- internal/upload/batch.go | 2 +- 8 files changed, 34 insertions(+), 18 deletions(-) diff --git a/internal/client/deb_client.go b/internal/client/deb_client.go index ec09c5d7..85855ee0 100644 --- a/internal/client/deb_client.go +++ b/internal/client/deb_client.go @@ -13,7 +13,7 @@ type IDebClient interface { // Post makes a POST request to one of Debricked's API endpoints Post(uri string, contentType string, body *bytes.Buffer, timeout int) (*http.Response, error) // Get makes a GET request to one of Debricked's API endpoints - Get(uri string, format string) (*http.Response, error) + Get(uri string, format string, timeout int) (*http.Response, error) SetAccessToken(accessToken *string) ConfigureClientSettings(retry bool, timeout int) } @@ -51,7 +51,11 @@ func (debClient *DebClient) Post(uri string, contentType string, body *bytes.Buf return post(uri, debClient, contentType, body, true) } -func (debClient *DebClient) Get(uri string, format string) (*http.Response, error) { +func (debClient *DebClient) Get(uri string, format string, timeout int) (*http.Response, error) { + if timeout > 0 { + return getWithTimeout(uri, debClient, false, format, timeout) + } + return get(uri, debClient, debClient.retry, format) } diff --git a/internal/client/deb_client_test.go b/internal/client/deb_client_test.go index 95e120b5..070892cf 100644 --- a/internal/client/deb_client_test.go +++ b/internal/client/deb_client_test.go @@ -84,7 +84,7 @@ func TestClientUnauthorized(t *testing.T) { }) client = NewDebClient(&tkn, clientMock) - res, err := client.Get("/api/1.0/open/user-profile/is-admin", "application/json") + res, err := client.Get("/api/1.0/open/user-profile/is-admin", "application/json", 0) if err == nil { t.Error("failed to assert client error") defer res.Body.Close() @@ -104,7 +104,7 @@ func TestGet(t *testing.T) { }) client = NewDebClient(&tkn, clientMock) - res, err := client.Get("/api/1.0/open/user-profile/is-admin", "application/json") + res, err := client.Get("/api/1.0/open/user-profile/is-admin", "application/json", 0) if err != nil { t.Fatal("failed to assert that no client error occurred. Error:", err) } diff --git a/internal/client/request.go b/internal/client/request.go index a64d8d92..af4a5c47 100644 --- a/internal/client/request.go +++ b/internal/client/request.go @@ -23,13 +23,25 @@ func get(uri string, debClient *DebClient, retry bool, format string) (*http.Res return nil, err } - if debClient.timeout > 0 { - timeoutDuration := time.Duration(debClient.timeout) * time.Second - ctx, cancel := context.WithTimeout(request.Context(), timeoutDuration) - defer cancel() - request = request.WithContext(ctx) + res, _ := debClient.httpClient.Do(request) + req := func() (*http.Response, error) { + return get(uri, debClient, false, format) } + return interpret(res, req, debClient, retry) +} + +func getWithTimeout(uri string, debClient *DebClient, retry bool, format string, timeout int) (*http.Response, error) { + request, err := newRequest("GET", *debClient.host+uri, debClient.jwtToken, format, nil) + if err != nil { + return nil, err + } + + timeoutDuration := time.Duration(debClient.timeout) * time.Second + ctx, cancel := context.WithTimeout(request.Context(), timeoutDuration) + defer cancel() + request = request.WithContext(ctx) + res, _ := debClient.httpClient.Do(request) req := func() (*http.Response, error) { return get(uri, debClient, false, format) diff --git a/internal/client/testdata/deb_client_mock.go b/internal/client/testdata/deb_client_mock.go index c53a28f0..ef9255bd 100644 --- a/internal/client/testdata/deb_client_mock.go +++ b/internal/client/testdata/deb_client_mock.go @@ -32,14 +32,14 @@ func NewDebClientMock() *DebClientMock { } } -func (mock *DebClientMock) Get(uri string, format string) (*http.Response, error) { +func (mock *DebClientMock) Get(uri string, format string, timeout int) (*http.Response, error) { response, err := mock.popResponse(mock.RemoveQueryParamsFromUri(uri)) if response != nil || !mock.serviceUp { return response, err } - return mock.realDebClient.Get(uri, format) + return mock.realDebClient.Get(uri, format, timeout) } func (mock *DebClientMock) Post(uri string, format string, body *bytes.Buffer, timeout int) (*http.Response, error) { diff --git a/internal/file/finder.go b/internal/file/finder.go index 7e415776..cc3756af 100644 --- a/internal/file/finder.go +++ b/internal/file/finder.go @@ -104,10 +104,10 @@ func (finder *Finder) GetSupportedFormats() ([]*CompiledFormat, error) { } func (finder *Finder) GetSupportedFormatsJson() ([]byte, error) { - finder.debClient.ConfigureClientSettings(false, 3) - defer finder.debClient.ConfigureClientSettings(true, 15) + //finder.debClient.ConfigureClientSettings(false, 3) + //defer finder.debClient.ConfigureClientSettings(true, 15) - res, err := finder.debClient.Get(SupportedFormatsUri, "application/json") + res, err := finder.debClient.Get(SupportedFormatsUri, "application/json", 3) if err != nil || res.StatusCode != http.StatusOK { fmt.Printf("%s Unable to get supported formats from the server. Using cached data instead.\n", color.YellowString("⚠️")) diff --git a/internal/report/license/report.go b/internal/report/license/report.go index d0f0feaa..737e6aeb 100644 --- a/internal/report/license/report.go +++ b/internal/report/license/report.go @@ -37,7 +37,7 @@ func (r Reporter) Order(args report.IOrderArgs) error { } uri := fmt.Sprintf("/api/1.0/open/licenses/get-licenses?order=asc&sortColumn=name&generateExcel=1&commitId=%d&email=%s", commitId, orderArgs.Email) - res, err := r.DebClient.Get(uri, "application/json") + res, err := r.DebClient.Get(uri, "application/json", 0) if err != nil { return err } @@ -62,7 +62,7 @@ type commit struct { func (r Reporter) getCommitId(hash string) (int, error) { uri := fmt.Sprintf("/api/1.0/open/releases/by/name?name=%s", hash) - res, err := r.DebClient.Get(uri, "application/json") + res, err := r.DebClient.Get(uri, "application/json", 0) if err != nil { return 0, err } diff --git a/internal/report/vulnerability/report.go b/internal/report/vulnerability/report.go index 7e208d45..66bc661c 100644 --- a/internal/report/vulnerability/report.go +++ b/internal/report/vulnerability/report.go @@ -29,7 +29,7 @@ func (r Reporter) Order(args report.IOrderArgs) error { } uri := fmt.Sprintf("/api/1.0/open/repositories/get-repositories?order=asc&generateExcel=1&email=%s", orderArgs.Email) - res, err := r.DebClient.Get(uri, "application/json") + res, err := r.DebClient.Get(uri, "application/json", 0) if err != nil { return err } diff --git a/internal/upload/batch.go b/internal/upload/batch.go index 2663e47b..1d17b5f7 100644 --- a/internal/upload/batch.go +++ b/internal/upload/batch.go @@ -200,7 +200,7 @@ func (uploadBatch *uploadBatch) wait() (*UploadResult, error) { var resultStatus *UploadResult uri := fmt.Sprintf("/api/1.0/open/ci/upload/status?ciUploadId=%s", strconv.Itoa(uploadBatch.ciUploadId)) for !bar.IsFinished() { - res, err := (*uploadBatch.client).Get(uri, "application/json") + res, err := (*uploadBatch.client).Get(uri, "application/json", 0) if err != nil { return nil, err }