diff --git a/client.go b/client.go index a7e0e46..571f1cb 100644 --- a/client.go +++ b/client.go @@ -42,8 +42,11 @@ type Client struct { Transactions *TransactionsService } -// NewRequest prepares new http request. -func (c *Client) newRequest(method string, urlStr string, body interface{}) (*http.Request, error) { +func (c *Client) newGetRequest(ctx context.Context, urlStr string) (*http.Request, error) { + return c.newRequest(ctx, http.MethodGet, urlStr, nil) +} + +func (c *Client) newRequest(ctx context.Context, method string, urlStr string, body interface{}) (*http.Request, error) { var buf io.ReadWriter if body != nil { buf = new(bytes.Buffer) @@ -51,24 +54,12 @@ func (c *Client) newRequest(method string, urlStr string, body interface{}) (*ht return nil, err } } - return http.NewRequest(method, urlStr, buf) + return http.NewRequestWithContext(ctx, method, urlStr, buf) } -// Get prepares new GET http request. -func (c *Client) get(urlStr string) (*http.Request, error) { - return c.newRequest(http.MethodGet, urlStr, nil) -} - -// Do performs http request and returns response. -func (c *Client) do(ctx context.Context, req *http.Request) (*http.Response, error) { - req = req.WithContext(ctx) +func (c *Client) do(req *http.Request) (*http.Response, error) { resp, err := c.client.Do(req) if err != nil { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } return nil, err } @@ -119,9 +110,8 @@ func (c *Client) checkResponse(r *http.Response) error { // 400 invalid date format in url // 200 ok - switch { // try to handle validation error - case r.StatusCode == http.StatusInternalServerError && strings.Contains(r.Header.Get("Content-Type"), "text/xml"): + if r.StatusCode == http.StatusInternalServerError && strings.Contains(r.Header.Get("Content-Type"), "text/xml") { var errResp xmlErrorResponse if err := xml.NewDecoder(r.Body).Decode(&errResp); err == nil { resp.Message = errResp.Result.Message @@ -137,7 +127,7 @@ func SanitizeURL(token string, u *url.URL) *url.URL { return u } - redacted := strings.Replace(u.String(), token, "REDACTED", -1) + redacted := strings.ReplaceAll(u.String(), token, "REDACTED") redactedURL, _ := url.Parse(redacted) return redactedURL } diff --git a/transactions.go b/transactions.go index 2adedf6..7fbea32 100644 --- a/transactions.go +++ b/transactions.go @@ -71,12 +71,12 @@ type TransactionsService struct { // ByPeriod returns transactions in date period. func (s *TransactionsService) ByPeriod(ctx context.Context, opts ByPeriodOptions) (*TransactionsResponse, error) { urlStr := s.client.buildURL("ib_api/rest/periods", fmtDate(opts.DateFrom), fmtDate(opts.DateTo), "transactions.xml") - req, err := s.client.get(urlStr) + req, err := s.client.newGetRequest(ctx, urlStr) if err != nil { return nil, err } - resp, err := s.client.do(ctx, req) + resp, err := s.client.do(req) if err != nil { return nil, err } @@ -96,12 +96,12 @@ type ExportOptions struct { func (s *TransactionsService) Export(ctx context.Context, opts ExportOptions, w io.Writer) error { exportFmt := fmt.Sprintf("transactions.%v", opts.Format) urlStr := s.client.buildURL("ib_api/rest/periods", fmtDate(opts.DateFrom), fmtDate(opts.DateTo), exportFmt) - req, err := s.client.get(urlStr) + req, err := s.client.newGetRequest(ctx, urlStr) if err != nil { return err } - resp, err := s.client.do(ctx, req) + resp, err := s.client.do(req) if err != nil { return err } @@ -120,12 +120,12 @@ type GetStatementOptions struct { // GetStatement returns statement by its year/id. func (s *TransactionsService) GetStatement(ctx context.Context, opts GetStatementOptions) (*TransactionsResponse, error) { urlStr := s.client.buildURL("ib_api/rest/by-id", strconv.Itoa(opts.Year), strconv.Itoa(opts.ID), "transactions.xml") - req, err := s.client.get(urlStr) + req, err := s.client.newGetRequest(ctx, urlStr) if err != nil { return nil, err } - resp, err := s.client.do(ctx, req) + resp, err := s.client.do(req) if err != nil { return nil, err } @@ -144,12 +144,12 @@ type ExportStatementOptions struct { func (s *TransactionsService) ExportStatement(ctx context.Context, opts ExportStatementOptions, w io.Writer) error { exportFmt := fmt.Sprintf("transactions.%v", opts.Format) urlStr := s.client.buildURL("ib_api/rest/by-id", strconv.Itoa(opts.Year), strconv.Itoa(opts.ID), exportFmt) - req, err := s.client.get(urlStr) + req, err := s.client.newGetRequest(ctx, urlStr) if err != nil { return err } - resp, err := s.client.do(ctx, req) + resp, err := s.client.do(req) if err != nil { return err } @@ -162,12 +162,12 @@ func (s *TransactionsService) ExportStatement(ctx context.Context, opts ExportSt // SinceLastDownload returns transactions since last download. func (s *TransactionsService) SinceLastDownload(ctx context.Context) (*TransactionsResponse, error) { urlStr := s.client.buildURL("ib_api/rest/last", "transactions.xml") - req, err := s.client.get(urlStr) + req, err := s.client.newGetRequest(ctx, urlStr) if err != nil { return nil, err } - resp, err := s.client.do(ctx, req) + resp, err := s.client.do(req) if err != nil { return nil, err } @@ -184,12 +184,12 @@ type SetLastDownloadIDOptions struct { // SetLastDownloadID sets the last downloaded id of statement. func (s *TransactionsService) SetLastDownloadID(ctx context.Context, opts SetLastDownloadIDOptions) error { urlStr := s.client.buildURL("ib_api/rest/set-last-id", strconv.Itoa(opts.ID), "") - req, err := s.client.get(urlStr) + req, err := s.client.newGetRequest(ctx, urlStr) if err != nil { return err } - resp, err := s.client.do(ctx, req) + resp, err := s.client.do(req) if err != nil { return err } @@ -204,12 +204,12 @@ type SetLastDownloadDateOptions struct { // SetLastDownloadDate sets the last download date of statement. func (s *TransactionsService) SetLastDownloadDate(ctx context.Context, opts SetLastDownloadDateOptions) error { urlStr := s.client.buildURL("ib_api/rest/set-last-date", fmtDate(opts.Date), "") - req, err := s.client.get(urlStr) + req, err := s.client.newGetRequest(ctx, urlStr) if err != nil { return err } - resp, err := s.client.do(ctx, req) + resp, err := s.client.do(req) if err != nil { return err } diff --git a/transactions_test.go b/transactions_test.go index 402eff7..b22d276 100644 --- a/transactions_test.go +++ b/transactions_test.go @@ -218,6 +218,8 @@ func TestGetStatement(t *testing.T) { } func assertEqualTransactionsResp(t *testing.T, want *TransactionsResponse, resp *TransactionsResponse) { + t.Helper() + require.Equal(t, want.Info.AccountID, resp.Info.AccountID) require.Equal(t, want.Info.BankID, resp.Info.BankID) require.Equal(t, want.Info.Currency, resp.Info.Currency) @@ -239,6 +241,8 @@ func assertEqualTransactionsResp(t *testing.T, want *TransactionsResponse, resp } func assertEqualTransaction(t *testing.T, want Transaction, got Transaction) { + t.Helper() + require.Equal(t, want.ID, got.ID) require.Equal(t, want.Amount, got.Amount) require.Equal(t, want.Currency, got.Currency)