From e605442f451186b868f06925ffe26e68d3c88a52 Mon Sep 17 00:00:00 2001 From: Arthur EICHELBERGER Date: Tue, 13 Dec 2022 18:25:51 +0100 Subject: [PATCH] feat(assert): safe type assertion (#39) --- add_on.go | 30 ++++++++++++++++++++++++------ billable_metric.go | 25 ++++++++++++++++++++----- coupon.go | 35 ++++++++++++++++++++++++++++------- credit_note.go | 30 ++++++++++++++++++++++++------ customer.go | 20 ++++++++++++++++---- error.go | 12 +++++++++++- group.go | 5 ++++- invoice.go | 20 ++++++++++++++++---- lago.go | 42 ++++++++++++++++++++++++++++++++++++------ organization.go | 5 ++++- plan.go | 25 ++++++++++++++++++++----- subscription.go | 12 +++++++++--- wallet.go | 25 ++++++++++++++++++++----- wallet_transaction.go | 2 +- 14 files changed, 233 insertions(+), 55 deletions(-) diff --git a/add_on.go b/add_on.go index 32441ec..3790201 100644 --- a/add_on.go +++ b/add_on.go @@ -95,7 +95,10 @@ func (adr *AddOnRequest) Get(ctx context.Context, addOnCode string) (*AddOn, *Er return nil, err } - addOnResult := result.(*AddOnResult) + addOnResult, ok := result.(*AddOnResult) + if !ok { + return nil, &ErrorTypeAssert + } return addOnResult.AddOn, nil } @@ -122,7 +125,10 @@ func (adr *AddOnRequest) GetList(ctx context.Context, addOnListInput *AddOnListI return nil, clientErr } - addOnResult := result.(*AddOnResult) + addOnResult, ok := result.(*AddOnResult) + if !ok { + return nil, &ErrorTypeAssert + } return addOnResult, nil } @@ -143,7 +149,10 @@ func (adr *AddOnRequest) Create(ctx context.Context, addOnInput *AddOnInput) (*A return nil, err } - addOnResult := result.(*AddOnResult) + addOnResult, ok := result.(*AddOnResult) + if !ok { + return nil, &ErrorTypeAssert + } return addOnResult.AddOn, nil } @@ -165,7 +174,10 @@ func (adr *AddOnRequest) Update(ctx context.Context, addOnInput *AddOnInput) (*A return nil, err } - addOnResult := result.(*AddOnResult) + addOnResult, ok := result.(*AddOnResult) + if !ok { + return nil, &ErrorTypeAssert + } return addOnResult.AddOn, nil } @@ -183,7 +195,10 @@ func (adr *AddOnRequest) Delete(ctx context.Context, addOnCode string) (*AddOn, return nil, err } - addOnResult := result.(*AddOnResult) + addOnResult, ok := result.(*AddOnResult) + if !ok { + return nil, &ErrorTypeAssert + } return addOnResult.AddOn, nil } @@ -204,7 +219,10 @@ func (adr *AddOnRequest) ApplyToCustomer(ctx context.Context, applyAddOnInput *A return nil, err } - appliedAddOnResult := result.(*AppliedAddOnResult) + appliedAddOnResult, ok := result.(*AppliedAddOnResult) + if !ok { + return nil, &ErrorTypeAssert + } return appliedAddOnResult.AppliedAddOn, nil } diff --git a/billable_metric.go b/billable_metric.go index 767a46f..687e5d8 100644 --- a/billable_metric.go +++ b/billable_metric.go @@ -68,7 +68,10 @@ func (bmr *BillableMetricRequest) Get(ctx context.Context, billableMetricCode st return nil, err } - billableMetricResult := result.(*BillableMetricResult) + billableMetricResult, ok := result.(*BillableMetricResult) + if !ok { + return nil, &ErrorTypeAssert + } return billableMetricResult.BillableMetric, nil } @@ -95,7 +98,10 @@ func (bmr *BillableMetricRequest) GetList(ctx context.Context, billableMetricLis return nil, clientErr } - billableMetricResult := result.(*BillableMetricResult) + billableMetricResult, ok := result.(*BillableMetricResult) + if !ok { + return nil, &ErrorTypeAssert + } return billableMetricResult, nil } @@ -112,7 +118,10 @@ func (bmr *BillableMetricRequest) Create(ctx context.Context, billableMetricInpu return nil, err } - billableMetricResult := result.(*BillableMetricResult) + billableMetricResult, ok := result.(*BillableMetricResult) + if !ok { + return nil, &ErrorTypeAssert + } return billableMetricResult.BillableMetric, nil } @@ -130,7 +139,10 @@ func (bmr *BillableMetricRequest) Update(ctx context.Context, billableMetricInpu return nil, err } - billableMetricResult := result.(*BillableMetricResult) + billableMetricResult, ok := result.(*BillableMetricResult) + if !ok { + return nil, &ErrorTypeAssert + } return billableMetricResult.BillableMetric, nil } @@ -147,7 +159,10 @@ func (bmr *BillableMetricRequest) Delete(ctx context.Context, billableMetricCode return nil, err } - billableMetricResult := result.(*BillableMetricResult) + billableMetricResult, ok := result.(*BillableMetricResult) + if !ok { + return nil, &ErrorTypeAssert + } return billableMetricResult.BillableMetric, nil } diff --git a/coupon.go b/coupon.go index 0a66e45..c9fcbf4 100644 --- a/coupon.go +++ b/coupon.go @@ -162,7 +162,10 @@ func (cr *CouponRequest) Get(ctx context.Context, couponCode string) (*Coupon, * return nil, err } - couponResult := result.(*CouponResult) + couponResult, ok := result.(*CouponResult) + if !ok { + return nil, &ErrorTypeAssert + } return couponResult.Coupon, nil } @@ -189,7 +192,10 @@ func (cr *CouponRequest) GetList(ctx context.Context, couponListInput *CouponLis return nil, clientErr } - couponResult := result.(*CouponResult) + couponResult, ok := result.(*CouponResult) + if !ok { + return nil, &ErrorTypeAssert + } return couponResult, nil } @@ -210,7 +216,10 @@ func (cr *CouponRequest) Create(ctx context.Context, couponInput *CouponInput) ( return nil, err } - couponResult := result.(*CouponResult) + couponResult, ok := result.(*CouponResult) + if !ok { + return nil, &ErrorTypeAssert + } return couponResult.Coupon, nil } @@ -232,7 +241,10 @@ func (cr *CouponRequest) Update(ctx context.Context, couponInput *CouponInput) ( return nil, err } - couponResult := result.(*CouponResult) + couponResult, ok := result.(*CouponResult) + if !ok { + return nil, &ErrorTypeAssert + } return couponResult.Coupon, nil } @@ -249,7 +261,10 @@ func (cr *CouponRequest) Delete(ctx context.Context, couponCode string) (*Coupon return nil, err } - couponResult := result.(*CouponResult) + couponResult, ok := result.(*CouponResult) + if !ok { + return nil, &ErrorTypeAssert + } return couponResult.Coupon, nil } @@ -276,7 +291,10 @@ func (cr *AppliedCouponRequest) GetList(ctx context.Context, appliedCouponListIn return nil, clientErr } - appliedCouponResult := result.(*AppliedCouponResult) + appliedCouponResult, ok := result.(*AppliedCouponResult) + if !ok { + return nil, &ErrorTypeAssert + } return appliedCouponResult, nil } @@ -297,7 +315,10 @@ func (cr *CouponRequest) ApplyToCustomer(ctx context.Context, applyCouponInput * return nil, err } - appliedCouponResult := result.(*AppliedCouponResult) + appliedCouponResult, ok := result.(*AppliedCouponResult) + if !ok { + return nil, &ErrorTypeAssert + } return appliedCouponResult.AppliedCoupon, nil } diff --git a/credit_note.go b/credit_note.go index 416ae24..11ba637 100644 --- a/credit_note.go +++ b/credit_note.go @@ -132,7 +132,10 @@ func (cr *CreditNoteRequest) Get(ctx context.Context, creditNoteID uuid.UUID) (* return nil, err } - creditNoteResult := result.(*CreditNoteResult) + creditNoteResult, ok := result.(*CreditNoteResult) + if !ok { + return nil, &ErrorTypeAssert + } return creditNoteResult.CreditNote, nil } @@ -150,7 +153,10 @@ func (cr *CreditNoteRequest) Download(ctx context.Context, creditNoteID string) } if result != nil { - creditNoteResult := result.(*CreditNoteResult) + creditNoteResult, ok := result.(*CreditNoteResult) + if !ok { + return nil, &ErrorTypeAssert + } return creditNoteResult.CreditNote, nil } @@ -180,7 +186,10 @@ func (cr *CreditNoteRequest) GetList(ctx context.Context, creditNoteListInput *C return nil, clientErr } - creditNoteResult := result.(*CreditNoteResult) + creditNoteResult, ok := result.(*CreditNoteResult) + if !ok { + return nil, &ErrorTypeAssert + } return creditNoteResult, nil } @@ -201,7 +210,10 @@ func (cr *CreditNoteRequest) Create(ctx context.Context, creditNoteInput *Credit return nil, err } - creditNoteResult := result.(*CreditNoteResult) + creditNoteResult, ok := result.(*CreditNoteResult) + if !ok { + return nil, &ErrorTypeAssert + } return creditNoteResult.CreditNote, nil } @@ -223,7 +235,10 @@ func (cr *CreditNoteRequest) Update(ctx context.Context, creditNoteUpdateInput * return nil, err } - creditNoteResult := result.(*CreditNoteResult) + creditNoteResult, ok := result.(*CreditNoteResult) + if !ok { + return nil, &ErrorTypeAssert + } return creditNoteResult.CreditNote, nil } @@ -241,7 +256,10 @@ func (cr *CreditNoteRequest) Void(ctx context.Context, creditNoteID string) (*Cr return nil, err } - creditNoteResult := result.(*CreditNoteResult) + creditNoteResult, ok := result.(*CreditNoteResult) + if !ok { + return nil, &ErrorTypeAssert + } return creditNoteResult.CreditNote, nil } diff --git a/customer.go b/customer.go index b2b2f8c..b5b4f4f 100644 --- a/customer.go +++ b/customer.go @@ -147,7 +147,10 @@ func (cr *CustomerRequest) Create(ctx context.Context, customerInput *CustomerIn return nil, err } - customerResult := result.(*CustomerResult) + customerResult, ok := result.(*CustomerResult) + if !ok { + return nil, &ErrorTypeAssert + } return customerResult.Customer, nil } @@ -182,7 +185,10 @@ func (cr *CustomerRequest) CurrentUsage(ctx context.Context, externalCustomerID return nil, clientErr } - currentUsageResult := result.(*CustomerUsageResult) + currentUsageResult, ok := result.(*CustomerUsageResult) + if !ok { + return nil, &ErrorTypeAssert + } return currentUsageResult.CustomerUsage, nil } @@ -199,7 +205,10 @@ func (cr *CustomerRequest) Get(ctx context.Context, externalCustomerID string) ( return nil, err } - customerResult := result.(*CustomerResult) + customerResult, ok := result.(*CustomerResult) + if !ok { + return nil, &ErrorTypeAssert + } return customerResult.Customer, nil } @@ -226,7 +235,10 @@ func (cr *CustomerRequest) GetList(ctx context.Context, customerListInput *Custo return nil, clientErr } - customerResult := result.(*CustomerResult) + customerResult, ok := result.(*CustomerResult) + if !ok { + return nil, &ErrorTypeAssert + } return customerResult, nil } diff --git a/error.go b/error.go index 6c278b0..961ed16 100644 --- a/error.go +++ b/error.go @@ -1,13 +1,23 @@ package lago +import ( + "errors" + "net/http" +) + type ErrorCode string const ( ErrorCodeAlreadyExist ErrorCode = "value_already_exist" ErrorCodeInvalidValue - ErrorTypeAssert ErrorCode = "error_type_assert" ) +var ErrorTypeAssert = Error{ + Err: errors.New("type assertion failed"), + HTTPStatusCode: http.StatusUnprocessableEntity, + Msg: "type assertion failed", +} + type ErrorDetail struct { ErrorCode []ErrorCode `json:"code,omitempty"` } diff --git a/group.go b/group.go index ec79339..81f71bb 100644 --- a/group.go +++ b/group.go @@ -58,7 +58,10 @@ func (cr *GroupRequest) GetList(ctx context.Context, groupListInput *GroupListIn return nil, clientErr } - groupResult := result.(*GroupResult) + groupResult, ok := result.(*GroupResult) + if !ok { + return nil, &ErrorTypeAssert + } return groupResult, nil } diff --git a/invoice.go b/invoice.go index a2cb9f1..70504e7 100644 --- a/invoice.go +++ b/invoice.go @@ -129,7 +129,10 @@ func (ir *InvoiceRequest) Get(ctx context.Context, invoiceID string) (*Invoice, return nil, err } - invoiceResult := result.(*InvoiceResult) + invoiceResult, ok := result.(*InvoiceResult) + if !ok { + return nil, &ErrorTypeAssert + } return invoiceResult.Invoice, nil } @@ -156,7 +159,10 @@ func (ir *InvoiceRequest) GetList(ctx context.Context, invoiceListInput *Invoice return nil, clientErr } - invoiceResult := result.(*InvoiceResult) + invoiceResult, ok := result.(*InvoiceResult) + if !ok { + return nil, &ErrorTypeAssert + } return invoiceResult, nil } @@ -178,7 +184,10 @@ func (ir *InvoiceRequest) Update(ctx context.Context, invoiceInput *InvoiceInput return nil, err } - invoiceResult := result.(*InvoiceResult) + invoiceResult, ok := result.(*InvoiceResult) + if !ok { + return nil, &ErrorTypeAssert + } return invoiceResult.Invoice, nil } @@ -196,7 +205,10 @@ func (ir *InvoiceRequest) Download(ctx context.Context, invoiceID string) (*Invo } if result != nil { - invoiceResult := result.(*InvoiceResult) + invoiceResult, ok := result.(*InvoiceResult) + if !ok { + return nil, &ErrorTypeAssert + } return invoiceResult.Invoice, nil } diff --git a/lago.go b/lago.go index 23c18fe..91dac5f 100644 --- a/lago.go +++ b/lago.go @@ -79,7 +79,12 @@ func (c *Client) Get(ctx context.Context, cr *ClientRequest) (interface{}, *Erro } if resp.IsError() { - return nil, resp.Error().(*Error) + err, ok := resp.Error().(*Error) + if !ok { + return nil, &ErrorTypeAssert + } + + return nil, err } return resp.Result(), nil @@ -102,7 +107,12 @@ func (c *Client) Post(ctx context.Context, cr *ClientRequest) (interface{}, *Err } if resp.IsError() { - return nil, resp.Error().(*Error) + err, ok := resp.Error().(*Error) + if !ok { + return nil, &ErrorTypeAssert + } + + return nil, err } return resp.Result(), nil @@ -124,7 +134,12 @@ func (c *Client) PostWithoutResult(ctx context.Context, cr *ClientRequest) *Erro } if resp.IsError() { - return resp.Error().(*Error) + err, ok := resp.Error().(*Error) + if !ok { + return &ErrorTypeAssert + } + + return err } return nil @@ -145,7 +160,12 @@ func (c *Client) PostWithoutBody(ctx context.Context, cr *ClientRequest) (interf } if resp.IsError() { - return nil, resp.Error().(*Error) + err, ok := resp.Error().(*Error) + if !ok { + return nil, &ErrorTypeAssert + } + + return nil, err } return resp.Result(), nil @@ -168,7 +188,12 @@ func (c *Client) Put(ctx context.Context, cr *ClientRequest) (interface{}, *Erro } if resp.IsError() { - return nil, resp.Error().(*Error) + err, ok := resp.Error().(*Error) + if !ok { + return nil, &ErrorTypeAssert + } + + return nil, err } return resp.Result(), nil @@ -191,7 +216,12 @@ func (c *Client) Delete(ctx context.Context, cr *ClientRequest) (interface{}, *E } if resp.IsError() { - return nil, resp.Error().(*Error) + err, ok := resp.Error().(*Error) + if !ok { + return nil, &ErrorTypeAssert + } + + return nil, err } return resp.Result(), nil diff --git a/organization.go b/organization.go index 644e03c..1803cc7 100644 --- a/organization.go +++ b/organization.go @@ -82,7 +82,10 @@ func (or *OrganizationRequest) Update(ctx context.Context, organizationInput *Or return nil, err } - organizationResult := result.(*OrganizationResult) + organizationResult, ok := result.(*OrganizationResult) + if !ok { + return nil, &ErrorTypeAssert + } return organizationResult.Organization, nil } diff --git a/plan.go b/plan.go index 86311a8..4ab4c94 100644 --- a/plan.go +++ b/plan.go @@ -88,7 +88,10 @@ func (pr *PlanRequest) Get(ctx context.Context, planCode string) (*Plan, *Error) return nil, err } - planResult := result.(*PlanResult) + planResult, ok := result.(*PlanResult) + if !ok { + return nil, &ErrorTypeAssert + } return planResult.Plan, nil } @@ -115,7 +118,10 @@ func (pr *PlanRequest) GetList(ctx context.Context, planListInput *PlanListInput return nil, clientErr } - planResult := result.(*PlanResult) + planResult, ok := result.(*PlanResult) + if !ok { + return nil, &ErrorTypeAssert + } return planResult, nil } @@ -136,7 +142,10 @@ func (pr *PlanRequest) Create(ctx context.Context, planInput *PlanInput) (*Plan, return nil, err } - planResult := result.(*PlanResult) + planResult, ok := result.(*PlanResult) + if !ok { + return nil, &ErrorTypeAssert + } return planResult.Plan, nil } @@ -158,7 +167,10 @@ func (pr *PlanRequest) Update(ctx context.Context, planInput *PlanInput) (*Plan, return nil, err } - planResult := result.(*PlanResult) + planResult, ok := result.(*PlanResult) + if !ok { + return nil, &ErrorTypeAssert + } return planResult.Plan, nil } @@ -176,7 +188,10 @@ func (pr *PlanRequest) Delete(ctx context.Context, planCode string) (*Plan, *Err return nil, err } - planResult := result.(*PlanResult) + planResult, ok := result.(*PlanResult) + if !ok { + return nil, &ErrorTypeAssert + } return planResult.Plan, nil } diff --git a/subscription.go b/subscription.go index a0e820c..8ec91c2 100644 --- a/subscription.go +++ b/subscription.go @@ -100,7 +100,10 @@ func (sr *SubscriptionRequest) Create(ctx context.Context, subscriptionInput *Su return nil, err } - subscriptionResult := result.(*SubscriptionResult) + subscriptionResult, ok := result.(*SubscriptionResult) + if !ok { + return nil, &ErrorTypeAssert + } return subscriptionResult.Subscription, nil } @@ -118,7 +121,10 @@ func (sr *SubscriptionRequest) Terminate(ctx context.Context, externalID string) return nil, err } - subscriptionResult := result.(*SubscriptionResult) + subscriptionResult, ok := result.(*SubscriptionResult) + if !ok { + return nil, &ErrorTypeAssert + } return subscriptionResult.Subscription, nil } @@ -147,7 +153,7 @@ func (sr *SubscriptionRequest) GetList(ctx context.Context, subscriptionListInpu subscriptionResult, ok := result.(*SubscriptionResult) if !ok { - return nil, &Error{Err: ErrorTypeAssert} + return nil, &ErrorTypeAssert } return subscriptionResult, nil diff --git a/wallet.go b/wallet.go index c88d307..1770292 100644 --- a/wallet.go +++ b/wallet.go @@ -81,7 +81,10 @@ func (bmr *WalletRequest) Get(ctx context.Context, walletId string) (*Wallet, *E return nil, err } - walletResult := result.(*WalletResult) + walletResult, ok := result.(*WalletResult) + if !ok { + return nil, &ErrorTypeAssert + } return walletResult.Wallet, nil } @@ -108,7 +111,10 @@ func (bmr *WalletRequest) GetList(ctx context.Context, walletListInput *WalletLi return nil, clientErr } - walletResult := result.(*WalletResult) + walletResult, ok := result.(*WalletResult) + if !ok { + return nil, &ErrorTypeAssert + } return walletResult, nil } @@ -125,7 +131,10 @@ func (bmr *WalletRequest) Create(ctx context.Context, walletInput *WalletInput) return nil, err } - walletResult := result.(*WalletResult) + walletResult, ok := result.(*WalletResult) + if !ok { + return nil, &ErrorTypeAssert + } return walletResult.Wallet, nil } @@ -143,7 +152,10 @@ func (bmr *WalletRequest) Update(ctx context.Context, walletInput *WalletInput, return nil, err } - walletResult := result.(*WalletResult) + walletResult, ok := result.(*WalletResult) + if !ok { + return nil, &ErrorTypeAssert + } return walletResult.Wallet, nil } @@ -160,7 +172,10 @@ func (bmr *WalletRequest) Delete(ctx context.Context, walletId string) (*Wallet, return nil, err } - walletResult := result.(*WalletResult) + walletResult, ok := result.(*WalletResult) + if !ok { + return nil, &ErrorTypeAssert + } return walletResult.Wallet, nil } diff --git a/wallet_transaction.go b/wallet_transaction.go index 582cb27..88806cd 100644 --- a/wallet_transaction.go +++ b/wallet_transaction.go @@ -70,7 +70,7 @@ func (bmr *WalletTransactionRequest) Create(ctx context.Context, walletTransacti walletTransactionResult, ok := result.(*WalletTransactionResult) if !ok { - return nil, err + return nil, &ErrorTypeAssert } return walletTransactionResult, nil