From 18cf8cde83e5694c7eb85d56c1fe488f81afc34b Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Wed, 17 Jan 2024 08:42:15 +0100 Subject: [PATCH] feat: add error codes --- internal/api/admin.go | 36 ++-- internal/api/apiversions.go | 31 ++++ internal/api/apiversions_test.go | 29 ++++ internal/api/audit.go | 4 +- internal/api/auth.go | 20 +-- internal/api/auth_test.go | 4 +- internal/api/errorcodes.go | 76 +++++++++ internal/api/errors.go | 242 ++++++++++++++------------- internal/api/errors_test.go | 64 +++++++ internal/api/external.go | 63 ++++--- internal/api/external_figma_test.go | 2 +- internal/api/external_fly_test.go | 2 +- internal/api/external_github_test.go | 4 +- internal/api/external_kakao_test.go | 4 +- internal/api/external_oauth.go | 11 +- internal/api/helpers_test.go | 8 +- internal/api/hook_test.go | 2 +- internal/api/hooks.go | 20 +-- internal/api/identity.go | 24 +-- internal/api/identity_test.go | 2 +- internal/api/invite.go | 6 +- internal/api/invite_test.go | 2 +- internal/api/logout.go | 2 +- internal/api/magic_link.go | 17 +- internal/api/mail.go | 123 ++++++++++---- internal/api/mfa.go | 24 +-- internal/api/mfa_test.go | 13 +- internal/api/middleware.go | 17 +- internal/api/middleware_test.go | 8 +- internal/api/otp.go | 30 ++-- internal/api/otp_test.go | 33 ++-- internal/api/phone.go | 2 +- internal/api/phone_test.go | 23 ++- internal/api/pkce.go | 8 +- internal/api/reauthenticate.go | 23 ++- internal/api/recover.go | 8 +- internal/api/resend.go | 25 +-- internal/api/router.go | 4 +- internal/api/samlacs.go | 24 +-- internal/api/signup.go | 49 ++++-- internal/api/sso.go | 15 +- internal/api/sso_test.go | 2 +- internal/api/ssoadmin.go | 40 ++--- internal/api/token.go | 38 ++--- internal/api/token_oidc.go | 10 +- internal/api/token_refresh.go | 4 +- internal/api/token_test.go | 7 +- internal/api/user.go | 26 +-- internal/api/user_test.go | 2 +- internal/api/verify.go | 50 +++--- internal/api/verify_test.go | 22 +-- internal/models/flow_state.go | 4 +- 52 files changed, 806 insertions(+), 503 deletions(-) create mode 100644 internal/api/apiversions.go create mode 100644 internal/api/apiversions_test.go create mode 100644 internal/api/errorcodes.go create mode 100644 internal/api/errors_test.go diff --git a/internal/api/admin.go b/internal/api/admin.go index d6fd17dac..8983eac35 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -50,7 +50,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, userID, err := uuid.FromString(chi.URLParam(r, "user_id")) if err != nil { - return nil, badRequestError("user_id must be an UUID") + return nil, notFoundError(ErrorCodeValidationFailed, "user_id must be an UUID") } observability.LogEntrySetField(r, "user_id", userID) @@ -58,7 +58,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, u, err := models.FindUserByID(db, userID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("User not found") + return nil, notFoundError(ErrorCodeUserNotFound, "User not found") } return nil, internalServerError("Database error loading user").WithInternalError(err) } @@ -69,7 +69,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) { factorID, err := uuid.FromString(chi.URLParam(r, "factor_id")) if err != nil { - return nil, badRequestError("factor_id must be an UUID") + return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID") } observability.LogEntrySetField(r, "factor_id", factorID) @@ -77,7 +77,7 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex f, err := models.FindFactorByFactorID(a.db, factorID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("Factor not found") + return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found") } return nil, internalServerError("Database error loading factor").WithInternalError(err) } @@ -89,11 +89,11 @@ func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) { body, err := getBodyBytes(r) if err != nil { - return nil, badRequestError("Could not read body").WithInternalError(err) + return nil, internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, ¶ms); err != nil { - return nil, badRequestError("Could not decode admin user params: %v", err) + return nil, badRequestError(ErrorCodeBadJSON, "Could not decode admin user params").WithInternalError(err) } return ¶ms, nil @@ -107,12 +107,12 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { pageParams, err := paginate(r) if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err) } sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}}) if err != nil { - return badRequestError("Bad Sort Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err) } filter := r.URL.Query().Get("filter") @@ -166,7 +166,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError("invalid format for ban duration: %v", err) + return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) } } if terr := user.Ban(a.db, duration); terr != nil { @@ -314,7 +314,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { } if params.Email == "" && params.Phone == "" { - return unprocessableEntityError("Cannot create a user without either an email or phone") + return badRequestError(ErrorCodeValidationFailed, "Cannot create a user without either an email or phone") } var providers []string @@ -326,7 +326,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil { return internalServerError("Database error checking email").WithInternalError(err) } else if user != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } providers = append(providers, "email") } @@ -339,7 +339,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { return internalServerError("Database error checking phone").WithInternalError(err) } else if exists { - return unprocessableEntityError("Phone number already registered by another user") + return unprocessableEntityError(ErrorCodePhoneExists, "Phone number already registered by another user") } providers = append(providers, "phone") } @@ -435,7 +435,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError("invalid format for ban duration: %v", err) + return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) } } if terr := user.Ban(a.db, duration); terr != nil { @@ -466,11 +466,11 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { params := &adminUserDeleteParams{} body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if len(body) > 0 { if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read params: %v", err) } } else { params.ShouldSoftDelete = false @@ -567,11 +567,11 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro params := &adminUserUpdateFactorParams{} body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read factor update params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read factor update params: %v", err).WithInternalError(err) } err = a.db.Transaction(func(tx *storage.Connection) error { @@ -582,7 +582,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro } if params.FactorType != "" { if params.FactorType != models.TOTP { - return badRequestError("Factor Type not valid") + return badRequestError(ErrorCodeValidationFailed, "Factor Type not valid") } if terr := factor.UpdateFactorType(tx, params.FactorType); terr != nil { return terr diff --git a/internal/api/apiversions.go b/internal/api/apiversions.go new file mode 100644 index 000000000..6e1c68ca7 --- /dev/null +++ b/internal/api/apiversions.go @@ -0,0 +1,31 @@ +package api + +import ( + "time" +) + +const APIVersionHeaderName = "X-Supabase-Api-Version" + +type APIVersion = time.Time + +var ( + APIVersionInitial = time.Time{} + APIVersion20240101 = time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC) +) + +func DetermineClosestAPIVersion(date string) (APIVersion, error) { + if date == "" { + return APIVersionInitial, nil + } + + parsed, err := time.ParseInLocation("2006-01-02", date, time.UTC) + if err != nil { + return APIVersionInitial, err + } + + if parsed.Compare(APIVersion20240101) >= 0 { + return APIVersion20240101, nil + } + + return APIVersionInitial, nil +} diff --git a/internal/api/apiversions_test.go b/internal/api/apiversions_test.go new file mode 100644 index 000000000..0a9622132 --- /dev/null +++ b/internal/api/apiversions_test.go @@ -0,0 +1,29 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDetermineClosestAPIVersion(t *testing.T) { + version, err := DetermineClosestAPIVersion("") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("Not a date") + require.Error(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2023-12-31") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2024-01-01") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) + + version, err = DetermineClosestAPIVersion("2024-01-02") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) +} diff --git a/internal/api/audit.go b/internal/api/audit.go index 2cb99c6e7..351a7d2cd 100644 --- a/internal/api/audit.go +++ b/internal/api/audit.go @@ -20,7 +20,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { // aud := a.requestAud(ctx, r) pageParams, err := paginate(r) if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err) } var col []string @@ -31,7 +31,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { qparts := strings.SplitN(q, ":", 2) col, exists = filterColumnMap[qparts[0]] if !exists || len(qparts) < 2 { - return badRequestError("Invalid query scope: %s", q) + return badRequestError(ErrorCodeValidationFailed, "Invalid query scope: %s", q) } qval = qparts[1] } diff --git a/internal/api/auth.go b/internal/api/auth.go index c1f43d511..49d04d16c 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -40,7 +40,7 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R claims := getClaims(ctx) if claims == nil { fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token") - return nil, unauthorizedError("Invalid token") + return nil, forbiddenError(ErrorCodeBadJWT, "Invalid token") } adminRoles := a.config.JWT.AdminRoles @@ -51,14 +51,14 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R } fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'") - return nil, unauthorizedError("User not allowed") + return nil, forbiddenError(ErrorCodeNotAdmin, "User not allowed") } func (a *API) extractBearerToken(r *http.Request) (string, error) { authHeader := r.Header.Get("Authorization") matches := bearerRegexp.FindStringSubmatch(authHeader) if len(matches) != 2 { - return "", unauthorizedError("This endpoint requires a Bearer token") + return "", httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "This endpoint requires a Bearer token") } return matches[1], nil @@ -73,7 +73,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e return []byte(config.JWT.Secret), nil }) if err != nil { - return nil, unauthorizedError("invalid JWT: unable to parse or verify signature, %v", err) + return nil, forbiddenError(ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err) } return withToken(ctx, token), nil @@ -84,23 +84,23 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro claims := getClaims(ctx) if claims == nil { - return ctx, unauthorizedError("invalid token: missing claims") + return ctx, forbiddenError(ErrorCodeBadJWT, "invalid token: missing claims") } if claims.Subject == "" { - return nil, unauthorizedError("invalid claim: missing sub claim") + return nil, forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim") } var user *models.User if claims.Subject != "" { userId, err := uuid.FromString(claims.Subject) if err != nil { - return ctx, badRequestError("invalid claim: sub claim must be a UUID").WithInternalError(err) + return ctx, badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err) } user, err = models.FindUserByID(db, userId) if err != nil { if models.IsNotFoundError(err) { - return ctx, notFoundError(err.Error()) + return ctx, forbiddenError(ErrorCodeUserNotFound, "User from sub claim in JWT does not exist") } return ctx, err } @@ -111,11 +111,11 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() { sessionId, err := uuid.FromString(claims.SessionId) if err != nil { - return ctx, err + return ctx, forbiddenError(ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err) } session, err = models.FindSessionByID(db, sessionId, false) if err != nil && !models.IsNotFoundError(err) { - return ctx, err + return ctx, forbiddenError(ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist") } ctx = withSession(ctx, session) } diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index 59ad0613c..2f18d7ce9 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -97,7 +97,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { }, Role: "authenticated", }, - ExpectedError: unauthorizedError("invalid claim: missing sub claim"), + ExpectedError: forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim"), ExpectedUser: nil, }, { @@ -119,7 +119,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { }, Role: "authenticated", }, - ExpectedError: badRequestError("invalid claim: sub claim must be a UUID"), + ExpectedError: badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"), ExpectedUser: nil, }, { diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go new file mode 100644 index 000000000..368851598 --- /dev/null +++ b/internal/api/errorcodes.go @@ -0,0 +1,76 @@ +package api + +type ErrorCode = string + +const ( + // ErrorCodeUnknown should not be used directly, it only indicates a failure in the error handling system in such a way that an error code was not assigned properly. + ErrorCodeUnknown ErrorCode = "unknown" + + // ErrorCodeUnexpectedFailure signals an unexpected failure such as a 500 Internal Server Error. + ErrorCodeUnexpectedFailure ErrorCode = "unexpected_failure" + + ErrorCodeValidationFailed ErrorCode = "validation_failed" + ErrorCodeBadJSON ErrorCode = "bad_json" + ErrorCodeEmailExists ErrorCode = "email_exists" + ErrorCodePhoneExists ErrorCode = "phone_exists" + ErrorCodeBadJWT ErrorCode = "bad_jwt" + ErrorCodeNotAdmin ErrorCode = "not_admin" + ErrorCodeNoAuthorization ErrorCode = "no_authorization" + ErrorCodeUserNotFound ErrorCode = "user_not_found" + ErrorCodeSessionNotFound ErrorCode = "session_not_found" + ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found" + ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired" + ErrorCodeSignupDisabled ErrorCode = "signup_disabled" + ErrorCodeUserBanned ErrorCode = "user_banned" + ErrorCodeOverEmailSendRate ErrorCode = "over_email_send_rate" + ErrorCodeOverSMSSendRate ErrorCode = "over_sms_send_rate" + ErrorCodeProviderEmailNeedsVerification ErrorCode = "provider_email_needs_verification" + ErrorCodeInviteNotFound ErrorCode = "invite_not_found" + ErrorCodeBadOAuthState ErrorCode = "bad_oauth_state" + ErrorCodeBadOAuthCallback ErrorCode = "bad_oauth_callback" + ErrorCodeOAuthProviderNotSupported ErrorCode = "oauth_provider_not_supported" + ErrorCodeUnexpectedAudience ErrorCode = "unexpected_audience" + ErrorCodeLastIdentityNotDeletable ErrorCode = "last_identity_not_deletable" + ErrorCodeEmailConflictIdentityNotDeletable ErrorCode = "email_conflict_identity_not_deletable" + ErrorCodeIdentityAlreadyExists ErrorCode = "identity_already_exists" + ErrorCodeEmailProviderDisabled ErrorCode = "email_provider_disabled" + ErrorCodePhoneProviderDisabled ErrorCode = "phone_provider_disabled" + ErrorCodeTooManyEnrolledMFAFactors ErrorCode = "too_many_enrolled_mfa_factors" + ErrorCodeMFAFactorNameConflict ErrorCode = "mfa_factor_name_conflict" + ErrorCodeMFAFactorNotFound ErrorCode = "mfa_factor_not_found" + ErrorCodeMFAIPAddressMismatch ErrorCode = "mfa_ip_address_mismatch" + ErrorCodeMFAChallengeExpired ErrorCode = "mfa_challenge_expired" + ErrorCodeMFAVerificationFailed ErrorCode = "mfa_verification_failed" + ErrorCodeMFAVerificationRejected ErrorCode = "mfa_verification_rejected" + ErrorCodeInsufficientAAL ErrorCode = "insufficient_aal" + ErrorCodeCaptchaFailed ErrorCode = "captcha_failed" + ErrorCodeSAMLProviderDisabled ErrorCode = "saml_provider_disabled" + ErrorCodeManualLinkingDisabled ErrorCode = "manual_linking_disabled" + ErrorCodeSMSSendFailed ErrorCode = "sms_send_failed" + ErrorCodeEmailNotConfirmed ErrorCode = "email_not_confirmed" + ErrorCodePhoneNotConfirmed ErrorCode = "phone_not_confirmed" + ErrorCodeReauthNonceMissing ErrorCode = "reauth_nonce_missing" + ErrorCodeSAMLRelayStateNotFound ErrorCode = "saml_relay_state_not_found" + ErrorCodeSAMLRelayStateExpired ErrorCode = "saml_relay_state_expired" + ErrorCodeSAMLIdPNotFound ErrorCode = "saml_idp_not_found" + ErrorCodeSAMLAssertionNoUserID ErrorCode = "saml_assertion_no_user_id" + ErrorCodeSAMLAssertionNoEmail ErrorCode = "saml_assertion_no_email" + ErrorCodeUserAlreadyExists ErrorCode = "user_already_exists" + ErrorCodeSSOProviderNotFound ErrorCode = "sso_provider_not_found" + ErrorCodeSAMLMetadataFetchFailed ErrorCode = "saml_metadata_fetch_failed" + ErrorCodeSAMLIdPAlreadyExists ErrorCode = "saml_idp_already_exists" + ErrorCodeSSODomainAlreadyExists ErrorCode = "sso_domain_already_exists" + ErrorCodeSAMLEntityIDMismatch ErrorCode = "saml_entity_id_mismatch" + ErrorCodeConflict ErrorCode = "conflict" + ErrorCodeProviderDisabled ErrorCode = "provider_disabled" + ErrorCodeUserSSOManaged ErrorCode = "user_sso_managed" + ErrorCodeReauthenticationNeeded ErrorCode = "reauthentication_needed" + ErrorCodeSamePassword ErrorCode = "same_password" + ErrorCodeReauthenticationNotValid ErrorCode = "reauthentication_not_valid" + ErrorCodeOTPExpired ErrorCode = "otp_expired" + ErrorCodeOTPDisabled ErrorCode = "otp_disabled" + ErrorCodeIdentityNotFound ErrorCode = "identity_not_found" + ErrorCodeWeakPassword ErrorCode = "weak_password" + ErrorCodeOverRequestRate ErrorCode = "over_request_rate" + ErrorBadCodeVerifier ErrorCode = "bad_code_verifier" +) diff --git a/internal/api/errors.go b/internal/api/errors.go index 56f404e3c..41e17c1fb 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -8,7 +8,6 @@ import ( "runtime/debug" "github.com/pkg/errors" - "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/utilities" ) @@ -65,65 +64,43 @@ func (e *OAuthError) Cause() error { return e } -func invalidSignupError(config *conf.GlobalConfiguration) *HTTPError { - var msg string - if config.External.Email.Enabled && config.External.Phone.Enabled { - msg = "To signup, please provide your email or phone number" - } else if config.External.Email.Enabled { - msg = "To signup, please provide your email" - } else if config.External.Phone.Enabled { - msg = "To signup, please provide your phone number" - } else { - // 3rd party OAuth signups - msg = "To signup, please provide required fields" - } - return unprocessableEntityError(msg) -} - func oauthError(err string, description string) *OAuthError { return &OAuthError{Err: err, Description: description} } -func badRequestError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusBadRequest, fmtString, args...) +func badRequestError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusBadRequest, errorCode, fmtString, args...) } func internalServerError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusInternalServerError, fmtString, args...) -} - -func notFoundError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusNotFound, fmtString, args...) + return httpError(http.StatusInternalServerError, ErrorCodeUnexpectedFailure, fmtString, args...) } -func expiredTokenError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnauthorized, fmtString, args...) +func notFoundError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusNotFound, errorCode, fmtString, args...) } -func unauthorizedError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnauthorized, fmtString, args...) +func forbiddenError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusForbidden, errorCode, fmtString, args...) } -func forbiddenError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusForbidden, fmtString, args...) +func unprocessableEntityError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusUnprocessableEntity, errorCode, fmtString, args...) } -func unprocessableEntityError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnprocessableEntity, fmtString, args...) -} - -func tooManyRequestsError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusTooManyRequests, fmtString, args...) +func tooManyRequestsError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusTooManyRequests, errorCode, fmtString, args...) } func conflictError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusConflict, fmtString, args...) + return httpError(http.StatusConflict, ErrorCodeConflict, fmtString, args...) } // HTTPError is an error with a message and an HTTP status code. type HTTPError struct { - Code int `json:"code"` - Message string `json:"msg"` + HTTPStatus int `json:"code"` // do not rename the JSON tags! + ErrorCode string `json:"error_code,omitempty"` // do not rename the JSON tags! + Message string `json:"msg"` // do not rename the JSON tags! InternalError error `json:"-"` InternalMessage string `json:"-"` ErrorID string `json:"error_id,omitempty"` @@ -133,7 +110,7 @@ func (e *HTTPError) Error() string { if e.InternalMessage != "" { return e.InternalMessage } - return fmt.Sprintf("%d: %s", e.Code, e.Message) + return fmt.Sprintf("%d: %s", e.HTTPStatus, e.Message) } func (e *HTTPError) Is(target error) bool { @@ -160,50 +137,12 @@ func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) * return e } -func httpError(code int, fmtString string, args ...interface{}) *HTTPError { +func httpError(httpStatus int, errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { return &HTTPError{ - Code: code, - Message: fmt.Sprintf(fmtString, args...), - } -} - -// OTPError is a custom error struct for phone auth errors -type OTPError struct { - Err string `json:"error"` - Description string `json:"error_description,omitempty"` - InternalError error `json:"-"` - InternalMessage string `json:"-"` -} - -func (e *OTPError) Error() string { - if e.InternalMessage != "" { - return e.InternalMessage - } - return fmt.Sprintf("%s: %s", e.Err, e.Description) -} - -// WithInternalError adds internal error information to the error -func (e *OTPError) WithInternalError(err error) *OTPError { - e.InternalError = err - return e -} - -// WithInternalMessage adds internal message information to the error -func (e *OTPError) WithInternalMessage(fmtString string, args ...interface{}) *OTPError { - e.InternalMessage = fmt.Sprintf(fmtString, args...) - return e -} - -// Cause returns the root cause error -func (e *OTPError) Cause() error { - if e.InternalError != nil { - return e.InternalError + HTTPStatus: httpStatus, + ErrorCode: errorCode, + Message: fmt.Sprintf(fmtString, args...), } - return e -} - -func otpError(err string, description string) *OTPError { - return &OTPError{Err: err, Description: description} } // Recoverer is a middleware that recovers from panics, logs the panic (and a @@ -222,10 +161,10 @@ func recoverer(w http.ResponseWriter, r *http.Request) (context.Context, error) } se := &HTTPError{ - Code: http.StatusInternalServerError, - Message: http.StatusText(http.StatusInternalServerError), + HTTPStatus: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), } - handleError(se, w, r) + HandleResponseError(se, w, r) } }() @@ -237,28 +176,58 @@ type ErrorCause interface { Cause() error } -func handleError(err error, w http.ResponseWriter, r *http.Request) { +type HTTPErrorResponse20240101 struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` +} + +func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { log := observability.GetLogEntry(r) errorID := getRequestID(r.Context()) + + apiVersion, averr := DetermineClosestAPIVersion(r.Header.Get(APIVersionHeaderName)) + if averr != nil { + log.WithError(averr).Warn("Invalid version passed to " + APIVersionHeaderName + " header, defaulting to initial version") + } + switch e := err.(type) { case *WeakPasswordError: - var output struct { - HTTPError - Payload struct { - Reasons []string `json:"reasons,omitempty"` - } `json:"weak_password,omitempty"` - } + if apiVersion.Compare(APIVersion20240101) >= 0 { + var output struct { + HTTPErrorResponse20240101 + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } - output.Code = http.StatusUnprocessableEntity - output.Message = e.Message - output.Payload.Reasons = e.Reasons + output.Code = ErrorCodeWeakPassword + output.Message = e.Message + output.Payload.Reasons = e.Reasons - if jsonErr := sendJSON(w, output.Code, output); jsonErr != nil { - handleError(jsonErr, w, r) + if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + + } else { + var output struct { + HTTPError + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } + + output.HTTPStatus = http.StatusUnprocessableEntity + output.ErrorCode = ErrorCodeWeakPassword + output.Message = e.Message + output.Payload.Reasons = e.Reasons + + if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } } case *HTTPError: - if e.Code >= http.StatusInternalServerError { + if e.HTTPStatus >= http.StatusInternalServerError { e.ErrorID = errorID // this will get us the stack trace too log.WithError(e.Cause()).Error(e.Error()) @@ -266,35 +235,76 @@ func handleError(err error, w http.ResponseWriter, r *http.Request) { log.WithError(e.Cause()).Info(e.Error()) } - // Provide better error messages for certain user-triggered Postgres errors. - if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { - if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil { - handleError(jsonErr, w, r) + if apiVersion.Compare(APIVersion20240101) >= 0 { + resp := HTTPErrorResponse20240101{ + Code: e.ErrorCode, + Message: e.Message, + } + + if resp.Code == "" { + if e.HTTPStatus == http.StatusInternalServerError { + resp.Code = ErrorCodeUnexpectedFailure + } else { + resp.Code = ErrorCodeUnknown + } + } + + if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + } else { + if e.ErrorCode == "" { + if e.HTTPStatus == http.StatusInternalServerError { + e.ErrorCode = ErrorCodeUnexpectedFailure + } else { + e.ErrorCode = ErrorCodeUnknown + } } - return - } - if jsonErr := sendJSON(w, e.Code, e); jsonErr != nil { - handleError(jsonErr, w, r) + // Provide better error messages for certain user-triggered Postgres errors. + if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { + if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + return + } + + if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } } + case *OAuthError: log.WithError(e.Cause()).Info(e.Error()) if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { - handleError(jsonErr, w, r) - } - case *OTPError: - log.WithError(e.Cause()).Info(e.Error()) - if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { - handleError(jsonErr, w, r) + HandleResponseError(jsonErr, w, r) } + case ErrorCause: - handleError(e.Cause(), w, r) + HandleResponseError(e.Cause(), w, r) + default: log.WithError(e).Errorf("Unhandled server error: %s", e.Error()) - // hide real error details from response to prevent info leaks - w.WriteHeader(http.StatusInternalServerError) - if _, writeErr := w.Write([]byte(`{"code":500,"msg":"Internal server error","error_id":"` + errorID + `"}`)); writeErr != nil { - log.WithError(writeErr).Error("Error writing generic error message") + + if apiVersion.Compare(APIVersion20240101) >= 0 { + resp := HTTPErrorResponse20240101{ + Code: ErrorCodeUnexpectedFailure, + Message: "Unexpected failure, please check server logs for more information", + } + + if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + } else { + httpError := HTTPError{ + HTTPStatus: http.StatusInternalServerError, + ErrorCode: ErrorCodeUnexpectedFailure, + Message: "Unexpected failure, please check server logs for more information", + } + + if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } } } } diff --git a/internal/api/errors_test.go b/internal/api/errors_test.go new file mode 100644 index 000000000..fc6135205 --- /dev/null +++ b/internal/api/errors_test.go @@ -0,0 +1,64 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHandleResponseErrorWithHTTPError(t *testing.T) { + examples := []struct { + HTTPError *HTTPError + APIVersion string + ExpectedBody string + }{ + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "", + ExpectedBody: "{\"code\":400,\"error_code\":\"" + ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", + }, + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2023-12-31", + ExpectedBody: "{\"code\":400,\"error_code\":\"" + ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", + }, + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeBadJSON + "\",\"message\":\"Unable to parse JSON\"}", + }, + { + HTTPError: &HTTPError{ + HTTPStatus: http.StatusBadRequest, + Message: "Uncoded failure", + }, + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeUnknown + "\",\"message\":\"Uncoded failure\"}", + }, + { + HTTPError: &HTTPError{ + HTTPStatus: http.StatusInternalServerError, + Message: "Unexpected failure", + }, + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeUnexpectedFailure + "\",\"message\":\"Unexpected failure\"}", + }, + } + + for _, example := range examples { + rec := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com", nil) + require.NoError(t, err) + + if example.APIVersion != "" { + req.Header.Set(APIVersionHeaderName, example.APIVersion) + } + + HandleResponseError(example.HTTPError, rec, req) + + require.Equal(t, example.HTTPError.HTTPStatus, rec.Code) + require.Equal(t, example.ExpectedBody, rec.Body.String()) + } +} diff --git a/internal/api/external.go b/internal/api/external.go index a20e3ba59..46a8bd9ff 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -56,7 +56,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ p, err := a.Provider(ctx, providerType, scopes) if err != nil { - return "", badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return "", badRequestError(ErrorCodeValidationFailed, "Unsupported provider: %+v", err).WithInternalError(err) } inviteToken := query.Get("invite_token") @@ -64,7 +64,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ _, userErr := models.FindUserByConfirmationToken(db, inviteToken) if userErr != nil { if models.IsNotFoundError(userErr) { - return "", notFoundError(userErr.Error()) + return "", notFoundError(ErrorCodeUserNotFound, "User identified by token not found") } return "", internalServerError("Database error finding user").WithInternalError(userErr) } @@ -82,14 +82,12 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ if flowType == models.PKCEFlow { codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod) if err != nil { - return "", err - } - flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth) - if err != nil { - return "", err + return "", badRequestError(ErrorCodeValidationFailed, "Code challenge not valid").WithInternalError(err) } + flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth) + if err := a.db.Create(flowState); err != nil { - return "", err + return "", internalServerError("Failed to create flow state").WithInternalError(err) } flowStateID = flowState.ID.String() } @@ -134,6 +132,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ } authURL := p.AuthCodeURL(tokenString, authUrlParams...) + return authURL, nil } @@ -203,9 +202,12 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re // if there's a non-empty FlowStateID we perform PKCE Flow if flowStateID := getFlowStateID(ctx); flowStateID != "" { flowState, err = models.FindFlowStateByID(a.db, flowStateID) - if err != nil { - return err + if models.IsNotFoundError(err) { + return unprocessableEntityError(ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err) + } else if err != nil { + return internalServerError("Failed to find flow state").WithInternalError(err) } + } var user *models.User @@ -304,7 +306,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. case models.CreateAccount: if config.DisableSignup { - return nil, forbiddenError("Signups not allowed for this instance") + return nil, unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{ @@ -351,14 +353,14 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } case models.MultipleAccounts: - return nil, internalServerError(fmt.Sprintf("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain)) + return nil, internalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain) default: - return nil, internalServerError(fmt.Sprintf("Unknown automatic linking decision: %v", decision.Decision)) + return nil, internalServerError("Unknown automatic linking decision: %v", decision.Decision) } if user.IsBanned() { - return nil, unauthorizedError("User is unauthorized") + return nil, unprocessableEntityError(ErrorCodeUserBanned, "User is banned") } if !user.IsConfirmed() { @@ -391,7 +393,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. externalURL := getExternalHost(ctx) if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { - return nil, tooManyRequestsError("For security purposes, you can only request this once every minute") + return nil, tooManyRequestsError(ErrorCodeOverEmailSendRate, "For security purposes, you can only request this once every minute") } return nil, internalServerError("Error sending confirmation mail").WithInternalError(terr) } @@ -399,9 +401,9 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } if !config.Mailer.AllowUnverifiedEmailSignIns { if emailConfirmationSent { - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) + return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) } - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) + return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) } } } else { @@ -423,7 +425,7 @@ func (a *API) processInvite(r *http.Request, ctx context.Context, tx *storage.Co user, err := models.FindUserByConfirmationToken(tx, inviteToken) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()) + return nil, notFoundError(ErrorCodeInviteNotFound, "Invite not found") } return nil, internalServerError("Database error finding user").WithInternalError(err) } @@ -439,7 +441,7 @@ func (a *API) processInvite(r *http.Request, ctx context.Context, tx *storage.Co } if emailData == nil { - return nil, badRequestError("Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) + return nil, badRequestError(ErrorCodeValidationFailed, "Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) } var identityData map[string]interface{} @@ -495,8 +497,11 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont _, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) { return []byte(config.JWT.Secret), nil }) - if err != nil || claims.Provider == "" { - return nil, badRequestError("OAuth state is invalid: %v", err) + if err != nil { + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) + } + if claims.Provider == "" { + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)") } if claims.InviteToken != "" { ctx = withInviteToken(ctx, claims.InviteToken) @@ -510,12 +515,12 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont if claims.LinkingTargetID != "" { linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID) if err != nil { - return nil, badRequestError("invalid target user id") + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)") } u, err := models.FindUserByID(a.db, linkingTargetUserID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("Linking target user not found") + return nil, unprocessableEntityError(ErrorCodeUserNotFound, "Linking target user not found") } return nil, internalServerError("Database error loading user").WithInternalError(err) } @@ -606,12 +611,18 @@ func (a *API) redirectErrors(handler apiHandler, w http.ResponseWriter, r *http. func getErrorQueryString(err error, errorID string, log logrus.FieldLogger, q url.Values) *url.Values { switch e := err.(type) { case *HTTPError: - if str, ok := oauthErrorMap[e.Code]; ok { + if e.ErrorCode == ErrorCodeSignupDisabled { + q.Set("error", "access_denied") + } else if e.ErrorCode == ErrorCodeUserBanned { + q.Set("error", "access_denied") + } else if e.ErrorCode == ErrorCodeProviderEmailNeedsVerification { + q.Set("error", "access_denied") + } else if str, ok := oauthErrorMap[e.HTTPStatus]; ok { q.Set("error", str) } else { q.Set("error", "server_error") } - if e.Code >= http.StatusInternalServerError { + if e.HTTPStatus >= http.StatusInternalServerError { e.ErrorID = errorID // this will get us the stack trace too log.WithError(e.Cause()).Error(e.Error()) @@ -619,7 +630,7 @@ func getErrorQueryString(err error, errorID string, log logrus.FieldLogger, q ur log.WithError(e.Cause()).Info(e.Error()) } q.Set("error_description", e.Message) - q.Set("error_code", strconv.Itoa(e.Code)) + q.Set("error_code", strconv.Itoa(e.HTTPStatus)) case *OAuthError: q.Set("error", e.Err) q.Set("error_description", e.Description) diff --git a/internal/api/external_figma_test.go b/internal/api/external_figma_test.go index 56d2f478d..bd7a8c29c 100644 --- a/internal/api/external_figma_test.go +++ b/internal/api/external_figma_test.go @@ -260,5 +260,5 @@ func (ts *ExternalTestSuite) TestSignupExternalFigmaErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "figma", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_fly_test.go b/internal/api/external_fly_test.go index c469f2900..3c33c53e2 100644 --- a/internal/api/external_fly_test.go +++ b/internal/api/external_fly_test.go @@ -260,5 +260,5 @@ func (ts *ExternalTestSuite) TestSignupExternalFlyErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "fly", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_github_test.go b/internal/api/external_github_test.go index b3ad58440..f6f4334d7 100644 --- a/internal/api/external_github_test.go +++ b/internal/api/external_github_test.go @@ -276,7 +276,7 @@ func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenVerifiedFalse() { u := performAuthorization(ts, "github", code, "") - assertAuthorizationFailure(ts, u, "Unverified email with github. A confirmation email has been sent to your github email", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "Unverified email with github. A confirmation email has been sent to your github email", "access_denied", "") } func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenUserBanned() { @@ -296,5 +296,5 @@ func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "github", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_kakao_test.go b/internal/api/external_kakao_test.go index 7882e1dce..cd2bd2b29 100644 --- a/internal/api/external_kakao_test.go +++ b/internal/api/external_kakao_test.go @@ -214,7 +214,7 @@ func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenVerifiedFalse() { u := performAuthorization(ts, "kakao", code, "") - assertAuthorizationFailure(ts, u, "Unverified email with kakao. A confirmation email has been sent to your kakao email", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "Unverified email with kakao. A confirmation email has been sent to your kakao email", "access_denied", "") } func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenUserBanned() { @@ -234,5 +234,5 @@ func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "kakao", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_oauth.go b/internal/api/external_oauth.go index e37b7d682..63aaf895b 100644 --- a/internal/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -2,6 +2,7 @@ package api import ( "context" + "fmt" "net/http" "net/url" @@ -30,7 +31,7 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con } if state == "" { - return nil, badRequestError("OAuth state parameter missing") + return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing") } ctx := r.Context() @@ -60,12 +61,12 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s oauthCode := rq.Get("code") if oauthCode == "" { - return nil, badRequestError("Authorization code missing") + return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing") } oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } log := observability.GetLogEntry(r) @@ -107,7 +108,7 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s func (a *API) oAuth1Callback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) { oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } oauthToken := getRequestToken(ctx) oauthVerifier := getOAuthVerifier(ctx) @@ -145,6 +146,6 @@ func (a *API) OAuthProvider(ctx context.Context, name string) (provider.OAuthPro case provider.OAuthProvider: return p, nil default: - return nil, badRequestError("Provider can not be used for OAuth") + return nil, fmt.Errorf("Provider %v cannot be used for OAuth", name) } } diff --git a/internal/api/helpers_test.go b/internal/api/helpers_test.go index ec5812e09..15f9ce4d6 100644 --- a/internal/api/helpers_test.go +++ b/internal/api/helpers_test.go @@ -16,12 +16,12 @@ func TestIsValidCodeChallenge(t *testing.T) { { challenge: "invalid", isValid: false, - expectedError: badRequestError("code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), + expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), }, { challenge: "codechallengecontainsinvalidcharacterslike@$^&*", isValid: false, - expectedError: badRequestError("code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), + expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), }, { challenge: "validchallengevalidchallengevalidchallengevalidchallenge", @@ -56,12 +56,12 @@ func TestIsValidPKCEParmas(t *testing.T) { { challengeMethod: "test", challenge: "", - expected: badRequestError(InvalidPKCEParamsErrorMessage), + expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), }, { challengeMethod: "", challenge: "test", - expected: badRequestError(InvalidPKCEParamsErrorMessage), + expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), }, } diff --git a/internal/api/hook_test.go b/internal/api/hook_test.go index a5d573db0..c9e7a703a 100644 --- a/internal/api/hook_test.go +++ b/internal/api/hook_test.go @@ -158,7 +158,7 @@ func TestHookTimeout(t *testing.T) { require.Error(t, err) herr, ok := err.(*HTTPError) require.True(t, ok) - assert.Equal(t, http.StatusGatewayTimeout, herr.Code) + assert.Equal(t, http.StatusGatewayTimeout, herr.HTTPStatus) svr.Close() assert.Equal(t, 3, callCount) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 2fd31983e..8895ef930 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -104,7 +104,7 @@ func (w *Webhook) trigger() (io.ReadCloser, error) { // timed out - try again? if i == w.Retries-1 { closeBody(rsp) - return nil, httpError(http.StatusGatewayTimeout, "Failed to perform webhook in time frame (%v seconds)", timeout.Seconds()) + return nil, httpError(http.StatusGatewayTimeout, ErrorCodeUnexpectedFailure, "Failed to perform webhook in time frame (%v seconds)", timeout.Seconds()) } hooklog.Info("Request timed out") continue @@ -135,7 +135,7 @@ func (w *Webhook) trigger() (io.ReadCloser, error) { } hooklog.Infof("Failed to process webhook for %s after %d attempts", w.URL, w.Retries) - return nil, unprocessableEntityError("Failed to handle signup webhook") + return nil, internalServerError("Failed to handle signup webhook") } func (w *Webhook) generateSignature() (string, error) { @@ -347,8 +347,8 @@ func (a *API) invokeHook(ctx context.Context, input, output any) error { } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -373,8 +373,8 @@ func (a *API) invokeHook(ctx context.Context, input, output any) error { } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -399,8 +399,8 @@ func (a *API) invokeHook(ctx context.Context, input, output any) error { } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -413,8 +413,8 @@ func (a *API) invokeHook(ctx context.Context, input, output any) error { } httpError := &HTTPError{ - Code: httpCode, - Message: err.Error(), + HTTPStatus: httpCode, + Message: err.Error(), } return httpError diff --git a/internal/api/identity.go b/internal/api/identity.go index 14f2c167d..4f1a6cc51 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -17,22 +17,22 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { claims := getClaims(ctx) if claims == nil { - return badRequestError("Could not read claims") - } - - aud := a.requestAud(ctx, r) - if aud != claims.Audience { - return badRequestError("Token audience doesn't match request audience") + return internalServerError("Could not read claims") } identityID, err := uuid.FromString(chi.URLParam(r, "identity_id")) if err != nil { - return badRequestError("identity_id must be an UUID") + return notFoundError(ErrorCodeValidationFailed, "identity_id must be an UUID") + } + + aud := a.requestAud(ctx, r) + if aud != claims.Audience { + return forbiddenError(ErrorCodeUnexpectedAudience, "Token audience doesn't match request audience") } user := getUser(ctx) if len(user.Identities) <= 1 { - return badRequestError("User must have at least 1 identity after unlinking") + return unprocessableEntityError(ErrorCodeLastIdentityNotDeletable, "User must have at least 1 identity after unlinking") } var identityToBeDeleted *models.Identity for i := range user.Identities { @@ -43,7 +43,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { } } if identityToBeDeleted == nil { - return badRequestError("Identity doesn't exist") + return notFoundError(ErrorCodeIdentityNotFound, "Identity doesn't exist") } err = a.db.Transaction(func(tx *storage.Connection) error { @@ -59,7 +59,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { } if terr := user.UpdateUserEmail(tx); terr != nil { if models.IsUniqueConstraintViolatedError(terr) { - return forbiddenError("Unable to unlink identity due to email conflict").WithInternalError(terr) + return unprocessableEntityError(ErrorCodeEmailConflictIdentityNotDeletable, "Unable to unlink identity due to email conflict").WithInternalError(terr) } return internalServerError("Database error updating user email").WithInternalError(terr) } @@ -102,9 +102,9 @@ func (a *API) linkIdentityToUser(ctx context.Context, tx *storage.Connection, us } if identity != nil { if identity.UserID == targetUser.ID { - return nil, badRequestError("Identity is already linked") + return nil, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked") } - return nil, badRequestError("Identity is already linked to another user") + return nil, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked to another user") } if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil { return nil, terr diff --git a/internal/api/identity_test.go b/internal/api/identity_test.go index 8238bf6c2..f6bca67df 100644 --- a/internal/api/identity_test.go +++ b/internal/api/identity_test.go @@ -72,6 +72,6 @@ func (ts *IdentityTestSuite) TestLinkIdentityToUser() { }, } u, err = ts.API.linkIdentityToUser(ctx, ts.API.db, testExistingUserData, "email") - require.ErrorIs(ts.T(), err, badRequestError("Identity is already linked")) + require.ErrorIs(ts.T(), err, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked")) require.Nil(ts.T(), u) } diff --git a/internal/api/invite.go b/internal/api/invite.go index 2a0aeb51d..7d4ea593e 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -27,11 +27,11 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read Invite params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read Invite params: %v", err).WithInternalError(err) } params.Email, err = validateEmail(params.Email) @@ -48,7 +48,7 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { err = db.Transaction(func(tx *storage.Connection) error { if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } else { signupParams := SignupParams{ diff --git a/internal/api/invite_test.go b/internal/api/invite_test.go index 466682028..c525e8747 100644 --- a/internal/api/invite_test.go +++ b/internal/api/invite_test.go @@ -162,7 +162,7 @@ func (ts *InviteTestSuite) TestInvite_WithoutAccess() { w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) + assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) // 401 OK because the invite request above has no Authorization header } func (ts *InviteTestSuite) TestVerifyInvite() { diff --git a/internal/api/logout.go b/internal/api/logout.go index ad95b22a4..cd1394eda 100644 --- a/internal/api/logout.go +++ b/internal/api/logout.go @@ -36,7 +36,7 @@ func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { scope = LogoutOthers default: - return badRequestError(fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) + return badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) } } diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index e1b12caaf..307f9f5da 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -24,7 +25,7 @@ type MagicLinkParams struct { func (p *MagicLinkParams) Validate() error { if p.Email == "" { - return unprocessableEntityError("Password recovery requires an email") + return unprocessableEntityError(ErrorCodeValidationFailed, "Password recovery requires an email") } var err error p.Email, err = validateEmail(p.Email) @@ -44,14 +45,14 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } params := &MagicLinkParams{} jsonDecoder := json.NewDecoder(r.Body) err := jsonDecoder.Decode(params) if err != nil { - return badRequestError("Could not read verification params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read verification params: %v", err).WithInternalError(err) } if err := params.Validate(); err != nil { @@ -82,7 +83,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { // Sign them up with temporary password. password, err := password.Generate(64, 10, 1, false, true) if err != nil { - internalServerError("error creating user").WithInternalError(err) + return internalServerError("error creating user").WithInternalError(err) } signUpParams := &SignupParams{ @@ -94,7 +95,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) } r.Body = io.NopCloser(strings.NewReader(string(newBodyContent))) r.ContentLength = int64(len(string(newBodyContent))) @@ -113,7 +115,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { } metadata, err := json.Marshal(newBodyContent) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) } r.Body = io.NopCloser(bytes.NewReader(metadata)) return a.MagicLink(w, r) @@ -148,7 +151,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRate, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Error sending magic link").WithInternalError(err) } diff --git a/internal/api/mail.go b/internal/api/mail.go index 6d4bd0817..4910a83df 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -52,11 +52,11 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not parse JSON: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not parse JSON: %v", err).WithInternalError(err) } params.Email, err = validateEmail(params.Email) @@ -72,14 +72,17 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) if err != nil { if models.IsNotFoundError(err) { - if params.Type == magicLinkVerification { + switch params.Type { + case magicLinkVerification: params.Type = signupVerification params.Password, err = password.Generate(64, 10, 1, false, true) if err != nil { - return internalServerError("error creating user").WithInternalError(err) + // password generation must always succeed + panic(err) } - } else if params.Type == recoveryVerification || params.Type == "email_change_current" || params.Type == "email_change_new" { - return notFoundError(err.Error()) + + default: + return notFoundError(ErrorCodeUserNotFound, "User with this email not found") } } else { return internalServerError("Database error finding user").WithInternalError(err) @@ -90,7 +93,8 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { now := time.Now() otp, err := crypto.GenerateOtp(config.Mailer.OtpLength) if err != nil { - return err + // OTP generation must always succeed + panic(err) } hashedToken := crypto.GenerateTokenHash(params.Email, otp) @@ -124,11 +128,14 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } user.RecoveryToken = hashedToken user.RecoverySentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + terr = tx.UpdateOnly(user, "recovery_token", "recovery_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for recovery") + } case inviteVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } else { signupParams := &SignupParams{ @@ -168,11 +175,14 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { user.ConfirmationToken = hashedToken user.ConfirmationSentAt = &now user.InvitedAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for invite") + } case signupVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } if err := user.UpdateUserMetaData(tx, params.Data); err != nil { return internalServerError("Database error updating user").WithInternalError(err) @@ -197,19 +207,22 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } user.ConfirmationToken = hashedToken user.ConfirmationSentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for confirmation") + } case "email_change_current", "email_change_new": if !config.Mailer.SecureEmailChangeEnabled && params.Type == "email_change_current" { - return unprocessableEntityError("Enable secure email change to generate link for current email") + return badRequestError(ErrorCodeValidationFailed, "Enable secure email change to generate link for current email") } params.NewEmail, terr = validateEmail(params.NewEmail) if terr != nil { - return unprocessableEntityError("The new email address provided is invalid") + return terr } if duplicateUser, terr := models.IsDuplicatedEmail(tx, params.NewEmail, user.Aud, user); terr != nil { return internalServerError("Database error checking email").WithInternalError(terr) } else if duplicateUser != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } now := time.Now() user.EmailChangeSentAt = &now @@ -220,9 +233,12 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } else if params.Type == "email_change_new" { user.EmailChangeTokenNew = crypto.GenerateTokenHash(params.NewEmail, otp) } - terr = errors.Wrap(tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status"), "Database error updating user for email change") + terr = tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for email change") + } default: - return badRequestError("Invalid email action link type requested: %v", params.Type) + return badRequestError(ErrorCodeValidationFailed, "Invalid email action link type requested: %v", params.Type) } if terr != nil { @@ -261,7 +277,8 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail oldToken := u.ConfirmationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.ConfirmationToken = addFlowPrefixToToken(token, flowType) @@ -271,7 +288,12 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail return errors.Wrap(err, "Error sending confirmation email") } u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for confirmation") + } + + return nil } func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, externalURL *url.URL, otpLength int) error { @@ -279,7 +301,8 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re oldToken := u.ConfirmationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.ConfirmationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) now := time.Now() @@ -289,7 +312,12 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re } u.InvitedAt = &now u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for invite") + } + + return nil } func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { @@ -301,7 +329,8 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile oldToken := u.RecoveryToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) @@ -311,7 +340,12 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile return errors.Wrap(err, "Error sending recovery email") } u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for recovery") + } + + return nil } func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, otpLength int) error { @@ -323,19 +357,22 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma oldToken := u.ReauthenticationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) - if err != nil { - return err - } now := time.Now() if err := mailer.ReauthenticateMail(u, otp); err != nil { u.ReauthenticationToken = oldToken return errors.Wrap(err, "Error sending reauthentication email") } u.ReauthenticationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"), "Database error updating user for reauthentication") + err = tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for reauthentication") + } + + return nil } func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { @@ -348,7 +385,8 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile oldToken := u.RecoveryToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) @@ -359,7 +397,12 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile return errors.Wrap(err, "Error sending magic link email") } u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for recovery") + } + + return nil } // sendEmailChange sends out an email change token to the new email. @@ -370,7 +413,8 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu } otpNew, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.EmailChange = email token := crypto.GenerateTokenHash(u.EmailChange, otpNew) @@ -380,7 +424,8 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu if config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { otpCurrent, err = crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } currentToken := crypto.GenerateTokenHash(u.GetEmail(), otpCurrent) u.EmailChangeTokenCurrent = addFlowPrefixToToken(currentToken, flowType) @@ -396,22 +441,28 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu } u.EmailChangeSentAt = &now - return errors.Wrap(tx.UpdateOnly( + err = tx.UpdateOnly( u, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status", - ), "Database error updating user for email change") + ) + + if err != nil { + return errors.Wrap(err, "Database error updating user for email change") + } + + return nil } func validateEmail(email string) (string, error) { if email == "" { - return "", unprocessableEntityError("An email address is required") + return "", badRequestError(ErrorCodeValidationFailed, "An email address is required") } if err := checkmail.ValidateFormat(email); err != nil { - return "", unprocessableEntityError("Unable to validate email address: " + err.Error()) + return "", badRequestError(ErrorCodeValidationFailed, "Unable to validate email address: "+err.Error()) } return strings.ToLower(email), nil } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 434a7117c..6c0e00052 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -72,11 +72,11 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "invalid body: unable to parse JSON").WithInternalError(err) } if params.FactorType != models.TOTP { - return badRequestError("factor_type needs to be totp") + return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be totp") } if params.Issuer == "" { @@ -96,7 +96,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if len(factors) >= int(config.MFA.MaxEnrolledFactors) { - return forbiddenError("Enrolled factors exceed allowed limit, unenroll to continue") + return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Enrolled factors exceed allowed limit, unenroll to continue") } numVerifiedFactors := 0 @@ -107,7 +107,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return forbiddenError("Maximum number of enrolled factors reached, unenroll to continue") + return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of enrolled factors reached, unenroll to continue") } key, err := totp.Generate(totp.GenerateOpts{ @@ -133,7 +133,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { if terr := tx.Create(factor); terr != nil { pgErr := utilities.NewPostgresError(terr) if pgErr.IsUniqueConstraintViolated() { - return badRequestError(fmt.Sprintf("a factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) + return unprocessableEntityError(ErrorCodeMFAFactorNameConflict, fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) } return terr @@ -209,7 +209,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "invalid body: unable to parse JSON").WithInternalError(err) } if !factor.IsOwnedBy(user) { @@ -219,13 +219,13 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { challenge, err := models.FindChallengeByChallengeID(a.db, params.ChallengeID) if err != nil { if models.IsNotFoundError(err) { - return notFoundError(err.Error()) + return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found") } return internalServerError("Database error finding Challenge").WithInternalError(err) } if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { - return badRequestError("Challenge and verify IP addresses mismatch") + return unprocessableEntityError(ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch") } if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { @@ -239,7 +239,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID) + return unprocessableEntityError(ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) } valid := totp.Validate(params.Code, factor.Secret) @@ -267,11 +267,11 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { output.Message = hooks.DefaultMFAHookRejectionMessage } - return forbiddenError(output.Message) + return forbiddenError(ErrorCodeMFAVerificationRejected, output.Message) } } if !valid { - return badRequestError("Invalid TOTP code entered") + return unprocessableEntityError(ErrorCodeMFAVerificationFailed, "Invalid TOTP code entered") } var token *AccessTokenResponse @@ -332,7 +332,7 @@ func (a *API) UnenrollFactor(w http.ResponseWriter, r *http.Request) error { } if factor.IsVerified() && !session.IsAAL2() { - return badRequestError("AAL2 required to unenroll verified factor") + return unprocessableEntityError(ErrorCodeInsufficientAAL, "AAL2 required to unenroll verified factor") } if !factor.IsOwnedBy(user) { return internalServerError(InvalidFactorOwnerErrorMessage) diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 5fe7f0bc6..8d30be7d1 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -5,13 +5,14 @@ import ( "context" "encoding/json" "fmt" - "github.com/gofrs/uuid" "net/http" "net/http/httptest" "strings" "testing" "time" + "github.com/gofrs/uuid" + "github.com/pquerna/otp" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" @@ -163,14 +164,14 @@ func (ts *MFATestSuite) TestDuplicateEnrollsReturnExpectedMessage() { token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, ts.TestUser, nil, models.TOTPSignIn) require.NoError(ts.T(), err) _ = performEnrollFlow(ts, token, friendlyName, models.TOTP, "https://issuer.com", http.StatusOK) - response := performEnrollFlow(ts, token, friendlyName, models.TOTP, "https://issuer.com", http.StatusBadRequest) + response := performEnrollFlow(ts, token, friendlyName, models.TOTP, "https://issuer.com", http.StatusUnprocessableEntity) var errorResponse HTTPError err = json.NewDecoder(response.Body).Decode(&errorResponse) require.NoError(ts.T(), err) // Convert the response body to a string and check for the expected error message - expectedErrorMessage := fmt.Sprintf("a factor with the friendly name %q for this user likely already exists", friendlyName) + expectedErrorMessage := fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", friendlyName) require.Contains(ts.T(), errorResponse.Message, expectedErrorMessage) } @@ -193,13 +194,13 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { desc: "Invalid: Valid code and expired challenge", validChallenge: false, validCode: true, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Invalid: Invalid code and valid challenge ", validChallenge: true, validCode: false, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Valid /verify request", @@ -284,7 +285,7 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() { { desc: "Verified Factor: AAL1", isAAL2: false, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Verified Factor: AAL2, Success", diff --git a/internal/api/middleware.go b/internal/api/middleware.go index c0f3d857a..a620c2ecb 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -66,7 +66,7 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { } else { err := tollbooth.LimitByKeys(lmt, []string{key}) if err != nil { - return c, httpError(http.StatusTooManyRequests, "Rate limit exceeded") + return c, tooManyRequestsError(ErrorCodeOverRequestRate, "Request rate limit reached") } } } @@ -106,7 +106,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { } if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { - return c, badRequestError("Error invalid request body").WithInternalError(err) + return c, badRequestError(ErrorCodeBadJSON, "Error invalid request body").WithInternalError(err) } if shouldRateLimitEmail { @@ -117,7 +117,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { 1, attribute.String("path", req.URL.Path), ) - return c, httpError(http.StatusTooManyRequests, "Email rate limit exceeded") + return c, tooManyRequestsError(ErrorCodeOverEmailSendRate, "Email rate limit exceeded") } } } @@ -125,7 +125,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { if shouldRateLimitPhone { if requestBody.Phone != "" { if err := tollbooth.LimitByKeys(phoneLimiter, []string{"phone_functions"}); err != nil { - return c, httpError(http.StatusTooManyRequests, "Sms rate limit exceeded") + return c, tooManyRequestsError(ErrorCodeOverSMSSendRate, "SMS rate limit exceeded") } } } @@ -156,7 +156,7 @@ func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (co config := a.config if !config.External.Email.Enabled { - return nil, badRequestError("Email logins are disabled") + return nil, badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } return ctx, nil @@ -183,8 +183,7 @@ func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.C } if !verificationResult.Success { - return nil, badRequestError("captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) - + return nil, badRequestError(ErrorCodeCaptchaFailed, "captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) } return ctx, nil @@ -228,7 +227,7 @@ func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (con func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.SAML.Enabled { - return nil, notFoundError("SAML 2.0 is disabled") + return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled") } return ctx, nil } @@ -236,7 +235,7 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.Security.ManualLinkingEnabled { - return nil, notFoundError("Manual linking is disabled") + return nil, notFoundError(ErrorCodeManualLinkingDisabled, "Manual linking is disabled") } return ctx, nil } diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index c532a50ef..e591121f8 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -176,7 +176,7 @@ func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() { w := httptest.NewRecorder() _, err := ts.API.verifyCaptcha(w, req) - require.Equal(ts.T(), c.expectedCode, err.(*HTTPError).Code) + require.Equal(ts.T(), c.expectedCode, err.(*HTTPError).HTTPStatus) require.Equal(ts.T(), c.expectedMsg, err.(*HTTPError).Message) }) } @@ -201,8 +201,8 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() { }, }, { - desc: "Sms rate limit exceeded", - expectedErrorMsg: "429: Sms rate limit exceeded", + desc: "SMS rate limit exceeded", + expectedErrorMsg: "429: SMS rate limit exceeded", requestBody: map[string]interface{}{ "phone": "+1233456789", }, @@ -269,7 +269,7 @@ func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() { { desc: "SAML not enabled", isEnabled: false, - expectedErr: notFoundError("SAML 2.0 is disabled"), + expectedErr: notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled"), }, { desc: "SAML enabled", diff --git a/internal/api/otp.go b/internal/api/otp.go index 0700f0970..016a56159 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -34,10 +34,10 @@ type SmsParams struct { func (p *OtpParams) Validate() error { if p.Email != "" && p.Phone != "" { - return badRequestError("Only an email address or phone number should be provided") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided") } if p.Email != "" && p.Channel != "" { - return badRequestError("Channel should only be specified with Phone OTP") + return badRequestError(ErrorCodeValidationFailed, "Channel should only be specified with Phone OTP") } if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { return err @@ -47,7 +47,7 @@ func (p *OtpParams) Validate() error { func (p *SmsParams) Validate(smsProvider string) error { if p.Phone != "" && !sms_provider.IsValidMessageChannel(p.Channel, smsProvider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } var err error @@ -74,7 +74,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { } if err = json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read verification params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read verification params: %v", err) } if err := params.Validate(); err != nil { @@ -85,7 +85,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { } if ok, err := a.shouldCreateUser(r, params); !ok { - return badRequestError("Signups not allowed for otp") + return unprocessableEntityError(ErrorCodeOTPDisabled, "Signups not allowed for otp") } else if err != nil { return err } @@ -96,7 +96,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { return a.SmsOtp(w, r) } - return otpError("unsupported_otp_type", "") + return badRequestError(ErrorCodeValidationFailed, "One of email or phone must be set") } type SmsOtpResponse struct { @@ -110,7 +110,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Phone.Enabled { - return badRequestError("Unsupported phone provider") + return badRequestError(ErrorCodePhoneProviderDisabled, "Unsupported phone provider") } var err error @@ -118,11 +118,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read sms otp params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read sms otp params: %v", err) } // For backwards compatibility, we default to SMS if params Channel is not specified if params.Phone != "" && params.Channel == "" { @@ -151,7 +151,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { // Sign them up with temporary password. password, err := password.Generate(64, 10, 1, false, true) if err != nil { - internalServerError("error creating user").WithInternalError(err) + return internalServerError("error creating user").WithInternalError(err) } signUpParams := &SignupParams{ @@ -162,7 +162,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must be marshallable + panic(err) } r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) @@ -180,7 +181,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must be marshallable + panic(err) } r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) return a.SmsOtp(w, r) @@ -201,11 +203,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Unable to get SMS provider").WithInternalError(err) } mID, serr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) if serr != nil { - return badRequestError("Error sending sms OTP: %v", serr) + return badRequestError(ErrorCodeSMSSendFailed, "Error sending sms OTP: %v", serr).WithInternalError(serr) } messageID = mID return nil diff --git a/internal/api/otp_test.go b/internal/api/otp_test.go index be3b18114..c72fbc361 100644 --- a/internal/api/otp_test.go +++ b/internal/api/otp_test.go @@ -80,8 +80,9 @@ func (ts *OtpTestSuite) TestOtpPKCE() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "PKCE flow requires code_challenge_method and code_challenge", + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "PKCE flow requires code_challenge_method and code_challenge", }, }, }, @@ -98,8 +99,9 @@ func (ts *OtpTestSuite) TestOtpPKCE() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "PKCE flow requires code_challenge_method and code_challenge", + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "PKCE flow requires code_challenge_method and code_challenge", }, }, }, @@ -115,10 +117,10 @@ func (ts *OtpTestSuite) TestOtpPKCE() { code int response map[string]interface{} }{ - http.StatusBadRequest, + http.StatusInternalServerError, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Error sending sms:", + "code": float64(http.StatusInternalServerError), + "msg": "Unable to get SMS provider", }, }, }, @@ -182,8 +184,9 @@ func (ts *OtpTestSuite) TestOtp() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Only an email address or phone number should be provided", + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "Only an email address or phone number should be provided", }, }, }, @@ -200,8 +203,9 @@ func (ts *OtpTestSuite) TestOtp() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": InvalidChannelError, + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": InvalidChannelError, }, }, }, @@ -244,15 +248,16 @@ func (ts *OtpTestSuite) TestNoSignupsForOtp() { ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusBadRequest, w.Code) + require.Equal(ts.T(), http.StatusUnprocessableEntity, w.Code) data := make(map[string]interface{}) require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) // response should be empty assert.Equal(ts.T(), data, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Signups not allowed for otp", + "code": float64(http.StatusUnprocessableEntity), + "error_code": ErrorCodeOTPDisabled, + "msg": "Signups not allowed for otp", }) } diff --git a/internal/api/phone.go b/internal/api/phone.go index 5d6e9bda3..fcd28b51c 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -25,7 +25,7 @@ const ( func validatePhone(phone string) (string, error) { phone = formatPhoneNumber(phone) if isValid := validateE164Format(phone); !isValid { - return "", unprocessableEntityError("Invalid phone number format (E.164 required)") + return "", badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") } return phone, nil } diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go index 911cb2c71..24cfbc06e 100644 --- a/internal/api/phone_test.go +++ b/internal/api/phone_test.go @@ -178,8 +178,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "password": "testpassword", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending confirmation sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -191,8 +191,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "phone": "123456789", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -204,8 +204,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "phone": "111111111", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -215,8 +215,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { header: "", body: nil, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, } @@ -245,7 +245,12 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { require.Equal(ts.T(), c.expected["code"], w.Code) body := w.Body.String() - require.True(ts.T(), strings.Contains(body, c.expected["message"].(string))) + require.True(ts.T(), + strings.Contains(body, "Unable to get SMS provider") || + strings.Contains(body, "Error finding SMS provider") || + strings.Contains(body, "Failed to get SMS provider"), + "unexpected body message %q", body, + ) }) } } diff --git a/internal/api/pkce.go b/internal/api/pkce.go index a186aa464..bd2809e18 100644 --- a/internal/api/pkce.go +++ b/internal/api/pkce.go @@ -21,9 +21,9 @@ func isValidCodeChallenge(codeChallenge string) (bool, error) { // See RFC 7636 Section 4.2: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 switch codeChallengeLength := len(codeChallenge); { case codeChallengeLength < MinCodeChallengeLength, codeChallengeLength > MaxCodeChallengeLength: - return false, badRequestError("code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) + return false, badRequestError(ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) case !codeChallengePattern.MatchString(codeChallenge): - return false, badRequestError("code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") + return false, badRequestError(ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") default: return true, nil } @@ -41,7 +41,7 @@ func addFlowPrefixToToken(token string, flowType models.FlowType) string { func issueAuthCode(tx *storage.Connection, user *models.User, expiryDuration time.Duration, authenticationMethod models.AuthenticationMethod) (string, error) { flowState, err := models.FindFlowStateByUserID(tx, user.ID.String(), authenticationMethod) if err != nil && models.IsNotFoundError(err) { - return "", badRequestError("No valid flow state found for user.") + return "", unprocessableEntityError(ErrorCodeFlowStateNotFound, "No valid flow state found for user.") } else if err != nil { return "", err } @@ -59,7 +59,7 @@ func isImplicitFlow(flowType models.FlowType) bool { func validatePKCEParams(codeChallengeMethod, codeChallenge string) error { switch true { case (codeChallenge == "") != (codeChallengeMethod == ""): - return badRequestError(InvalidPKCEParamsErrorMessage) + return badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage) case codeChallenge != "": if valid, err := isValidCodeChallenge(codeChallenge); !valid { return err diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go index e29f2f7e3..6d4de9abe 100644 --- a/internal/api/reauthenticate.go +++ b/internal/api/reauthenticate.go @@ -23,16 +23,16 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { email, phone := user.GetEmail(), user.GetPhone() if email == "" && phone == "" { - return unprocessableEntityError("Reauthentication requires the user to have an email or a phone number") + return badRequestError(ErrorCodeValidationFailed, "Reauthentication requires the user to have an email or a phone number") } if email != "" { if !user.IsConfirmed() { - return badRequestError("Please verify your email first.") + return unprocessableEntityError(ErrorCodeEmailNotConfirmed, "Please verify your email first.") } } else if phone != "" { if !user.IsPhoneConfirmed() { - return badRequestError("Please verify your phone first.") + return unprocessableEntityError(ErrorCodePhoneNotConfirmed, "Please verify your phone first.") } } @@ -47,7 +47,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { } else if phone != "" { smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Failed to get SMS provider").WithInternalError(terr) } mID, err := a.sendPhoneConfirmation(ctx, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) if err != nil { @@ -60,7 +60,12 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + reason := ErrorCodeOverEmailSendRate + if phone != "" { + reason = ErrorCodeOverSMSSendRate + } + + return tooManyRequestsError(reason, "For security purposes, you can only request this once every 60 seconds") } return err } @@ -77,7 +82,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { // verifyReauthentication checks if the nonce provided is valid func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, config *conf.GlobalConfiguration, user *models.User) error { if user.ReauthenticationToken == "" || user.ReauthenticationSentAt == nil { - return badRequestError(InvalidNonceMessage) + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, InvalidNonceMessage) } var isValid bool if user.GetEmail() != "" { @@ -87,7 +92,7 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi if config.Sms.IsTwilioVerifyProvider() { smsProvider, _ := sms_provider.GetSmsProvider(*config) if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(string(user.Phone), nonce); err != nil { - return expiredTokenError("Token has expired or is invalid").WithInternalError(err) + return forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } return nil } else { @@ -95,10 +100,10 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi isValid = isOtpValid(tokenHash, user.ReauthenticationToken, user.ReauthenticationSentAt, config.Sms.OtpExp) } } else { - return unprocessableEntityError("Reauthentication requires an email or a phone number") + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, "Reauthentication requires an email or a phone number") } if !isValid { - return badRequestError(InvalidNonceMessage) + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, InvalidNonceMessage) } if err := user.ConfirmReauthentication(tx); err != nil { return internalServerError("Error during reauthentication").WithInternalError(err) diff --git a/internal/api/recover.go b/internal/api/recover.go index 9a5757565..195939a17 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -19,7 +19,7 @@ type RecoverParams struct { func (p *RecoverParams) Validate() error { if p.Email == "" { - return unprocessableEntityError("Password recovery requires an email") + return badRequestError(ErrorCodeValidationFailed, "Password recovery requires an email") } var err error if p.Email, err = validateEmail(p.Email); err != nil { @@ -40,11 +40,11 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read verification params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read verification params: %v", err) } flowType := getFlowFromChallenge(params.CodeChallenge) @@ -83,7 +83,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRate, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Unable to process request").WithInternalError(err) } diff --git a/internal/api/resend.go b/internal/api/resend.go index a2fb4a52b..dfbaa1cc5 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -26,22 +26,22 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er break default: // type does not match one of the above - return badRequestError("Missing one of these types: signup, email_change, sms, phone_change") + return badRequestError(ErrorCodeValidationFailed, "Missing one of these types: signup, email_change, sms, phone_change") } if p.Email == "" && p.Type == signupVerification { - return badRequestError("Type provided requires an email address") + return badRequestError(ErrorCodeValidationFailed, "Type provided requires an email address") } if p.Phone == "" && p.Type == smsVerification { - return badRequestError("Type provided requires a phone number") + return badRequestError(ErrorCodeValidationFailed, "Type provided requires a phone number") } var err error if p.Email != "" && p.Phone != "" { - return badRequestError("Only an email address or phone number should be provided.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided.") } else if p.Email != "" { if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } p.Email, err = validateEmail(p.Email) if err != nil { @@ -49,7 +49,7 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er } } else if p.Phone != "" { if !config.External.Phone.Enabled { - return badRequestError("Phone logins are disabled") + return badRequestError(ErrorCodePhoneProviderDisabled, "Phone logins are disabled") } p.Phone, err = validatePhone(p.Phone) if err != nil { @@ -57,7 +57,7 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er } } else { // both email and phone are empty - return badRequestError("Missing email address or phone number") + return badRequestError(ErrorCodeValidationFailed, "Missing email address or phone number") } return nil } @@ -71,11 +71,11 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read params: %v", err) } if err := params.Validate(config); err != nil { @@ -162,8 +162,13 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { + reason := ErrorCodeOverEmailSendRate + if params.Type == smsVerification || params.Type == phoneChangeVerification { + reason = ErrorCodeOverSMSSendRate + } + until := time.Until(user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency)) / time.Second - return tooManyRequestsError("For security purposes, you can only request this once every %d seconds.", until) + return tooManyRequestsError(reason, "For security purposes, you can only request this once every %d seconds.", until) } return internalServerError("Unable to process request").WithInternalError(err) } diff --git a/internal/api/router.go b/internal/api/router.go index c2f06ae2e..70b41f22d 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -63,7 +63,7 @@ func handler(fn apiHandler) http.HandlerFunc { func (h apiHandler) serve(w http.ResponseWriter, r *http.Request) { if err := h(w, r); err != nil { - handleError(err, w, r) + HandleResponseError(err, w, r) } } @@ -78,7 +78,7 @@ func (m middlewareHandler) handler(next http.Handler) http.Handler { func (m middlewareHandler) serve(next http.Handler, w http.ResponseWriter, r *http.Request) { ctx, err := m(w, r) if err != nil { - handleError(err, w, r) + HandleResponseError(err, w, r) return } if ctx != nil { diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index a3932249a..149266533 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -67,7 +67,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { relayState, err := models.FindSAMLRelayStateByID(db, relayStateUUID) if models.IsNotFoundError(err) { - return badRequestError("SAML RelayState does not exist, try logging in again?") + return notFoundError(ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") } else if err != nil { return err } @@ -77,7 +77,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { return internalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) } - return badRequestError("SAML RelayState has expired. Try loggin in again?") + return unprocessableEntityError(ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try loggin in again?") } // TODO: add abuse detection to bind the RelayState UUID with a @@ -107,23 +107,23 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { // SAML Artifact responses are possible only when // RelayState can be used to identify the Identity // Provider. - return badRequestError("SAML Artifact response can only be used with SP initiated flow") + return badRequestError(ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") } samlResponse := r.FormValue("SAMLResponse") if samlResponse == "" { - return badRequestError("SAMLResponse is missing") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is missing") } responseXML, err := base64.StdEncoding.DecodeString(samlResponse) if err != nil { - return badRequestError("SAMLResponse is not a valid Base64 string") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") } var peekResponse saml.Response err = xml.Unmarshal(responseXML, &peekResponse) if err != nil { - return badRequestError("SAMLResponse is not a valid XML SAML assertion") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) } initiatedBy = "idp" @@ -131,12 +131,12 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { redirectTo = relayStateValue } else { // RelayState can't be identified, so SAML flow can't continue - return badRequestError("SAML RelayState is not a valid UUID or URL") + return badRequestError(ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") } ssoProvider, err := models.FindSAMLProviderByEntityID(db, entityId) if models.IsNotFoundError(err) { - return badRequestError("A SAML connection has not been established with this Identity Provider") + return notFoundError(ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") } else if err != nil { return err } @@ -176,10 +176,10 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { spAssertion, err := serviceProvider.ParseResponse(r, requestIds) if err != nil { if ire, ok := err.(*saml.InvalidResponseError); ok { - return badRequestError("SAML Assertion is not valid").WithInternalError(ire.PrivateErr) + return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(ire.PrivateErr) } - return badRequestError("SAML Assertion is not valid").WithInternalError(err) + return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) } assertion := SAMLAssertion{ @@ -188,7 +188,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { userID := assertion.UserID() if userID == "" { - return badRequestError("SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") + return badRequestError(ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") } claims := assertion.Process(ssoProvider.SAMLProvider.AttributeMapping) @@ -200,7 +200,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { } if email == "" { - return badRequestError("SAML Assertion does not contain an email address") + return badRequestError(ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") } else { claims["email"] = email } diff --git a/internal/api/signup.go b/internal/api/signup.go index 8b84ffd30..fc180b4ba 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -35,21 +35,21 @@ func (a *API) validateSignupParams(ctx context.Context, p *SignupParams) error { config := a.config if p.Password == "" { - return unprocessableEntityError("Signup requires a valid password") + return badRequestError(ErrorCodeValidationFailed, "Signup requires a valid password") } if err := a.checkPasswordStrength(ctx, p.Password); err != nil { return err } if p.Email != "" && p.Phone != "" { - return unprocessableEntityError("Only an email address or phone number should be provided on signup.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on signup.") } if p.Provider == "phone" && !sms_provider.IsValidMessageChannel(p.Channel, config.Sms.Provider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } // PKCE not needed as phone signups already return access token in body if p.Phone != "" && p.CodeChallenge != "" { - return badRequestError("PKCE not supported for phone signups") + return badRequestError(ErrorCodeValidationFailed, "PKCE not supported for phone signups") } if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { return err @@ -113,18 +113,18 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) if config.DisableSignup { - return forbiddenError("Signups not allowed for this instance") + return unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{} body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read Signup params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read Signup params: %v", err).WithInternalError(err) } params.ConfigureDefaults() @@ -152,7 +152,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { switch params.Provider { case "email": if !config.External.Email.Enabled { - return badRequestError("Email signups are disabled") + return badRequestError(ErrorCodeEmailProviderDisabled, "Email signups are disabled") } params.Email, err = validateEmail(params.Email) if err != nil { @@ -161,7 +161,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { user, err = models.IsDuplicatedEmail(db, params.Email, params.Aud, nil) case "phone": if !config.External.Phone.Enabled { - return badRequestError("Phone signups are disabled") + return badRequestError(ErrorCodePhoneProviderDisabled, "Phone signups are disabled") } params.Phone, err = validatePhone(params.Phone) if err != nil { @@ -169,7 +169,18 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } user, err = models.FindUserByPhoneAndAudience(db, params.Phone, params.Aud) default: - return invalidSignupError(config) + msg := "" + if config.External.Email.Enabled && config.External.Phone.Enabled { + msg = "Sign up only available with email or phone provider" + } else if config.External.Email.Enabled { + msg = "Sign up only available with email provider" + } else if config.External.Phone.Enabled { + msg = "Sign up only available with phone provider" + } else { + msg = "Sign up with this provider not possible" + } + + return badRequestError(ErrorCodeValidationFailed, msg) } if err != nil && !models.IsNotFoundError(err) { @@ -240,7 +251,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { if errors.Is(terr, MaxFrequencyLimitError) { now := time.Now() left := user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency).Sub(now) / time.Second - return tooManyRequestsError(fmt.Sprintf("For security purposes, you can only request this after %d seconds.", left)) + return tooManyRequestsError(ErrorCodeOverEmailSendRate, fmt.Sprintf("For security purposes, you can only request this after %d seconds.", left)) } return internalServerError("Error sending confirmation mail").WithInternalError(terr) } @@ -267,10 +278,10 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending confirmation sms: %v", terr) + return internalServerError("Unable to get SMS provider").WithInternalError(terr) } if _, terr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { - return badRequestError("Error sending confirmation sms: %v", terr) + return unprocessableEntityError(ErrorCodeSMSSendFailed, "Error sending confirmation sms: %v", terr).WithInternalError(terr) } } } @@ -279,10 +290,14 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { }) if err != nil { - if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every minute") + reason := ErrorCodeOverEmailSendRate + if params.Provider == "phone" { + reason = ErrorCodeOverSMSSendRate } - if errors.Is(err, UserExistsError) { + + if errors.Is(err, MaxFrequencyLimitError) { + return tooManyRequestsError(reason, "For security purposes, you can only request this once every minute") + } else if errors.Is(err, UserExistsError) { err = db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(r, tx, user, models.UserRepeatedSignUpAction, "", map[string]interface{}{ "provider": params.Provider, @@ -295,7 +310,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return err } if config.Mailer.Autoconfirm || config.Sms.Autoconfirm { - return badRequestError("User already registered") + return unprocessableEntityError(ErrorCodeUserAlreadyExists, "User already registered") } sanitizedUser, err := sanitizeUser(user, params) if err != nil { diff --git a/internal/api/sso.go b/internal/api/sso.go index d93ff82dc..07e97c6c8 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -28,9 +28,9 @@ func (p *SingleSignOnParams) validate() (bool, error) { hasDomain := p.Domain != "" if hasProviderID && hasDomain { - return hasProviderID, badRequestError("Only one of provider_id or domain supported") + return hasProviderID, badRequestError(ErrorCodeValidationFailed, "Only one of provider_id or domain supported") } else if !hasProviderID && !hasDomain { - return hasProviderID, badRequestError("A provider_id or domain needs to be provided") + return hasProviderID, badRequestError(ErrorCodeValidationFailed, "A provider_id or domain needs to be provided") } return hasProviderID, nil @@ -49,7 +49,7 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { var params SingleSignOnParams if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse request body as JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "Unable to parse request body as JSON").WithInternalError(err) } hasProviderID := false @@ -71,10 +71,7 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - flowState, err := models.NewFlowState(models.SSOSAML.String(), codeChallenge, codeChallengeMethodType, models.SSOSAML) - if err != nil { - return err - } + flowState := models.NewFlowState(models.SSOSAML.String(), codeChallenge, codeChallengeMethodType, models.SSOSAML) if err := a.db.Create(flowState); err != nil { return err } @@ -86,14 +83,14 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { if hasProviderID { ssoProvider, err = models.FindSSOProviderByID(db, params.ProviderID) if models.IsNotFoundError(err) { - return notFoundError("No such SSO provider") + return notFoundError(ErrorCodeSSOProviderNotFound, "No such SSO provider") } else if err != nil { return internalServerError("Unable to find SSO provider by ID").WithInternalError(err) } } else { ssoProvider, err = models.FindSSOProviderByDomain(db, params.Domain) if models.IsNotFoundError(err) { - return notFoundError("No SSO provider assigned for this domain") + return notFoundError(ErrorCodeSSOProviderNotFound, "No SSO provider assigned for this domain") } else if err != nil { return internalServerError("Unable to find SSO provider by domain").WithInternalError(err) } diff --git a/internal/api/sso_test.go b/internal/api/sso_test.go index 7da8b6eb2..5fc46b2d0 100644 --- a/internal/api/sso_test.go +++ b/internal/api/sso_test.go @@ -277,7 +277,7 @@ func (ts *SSOTestSuite) TestAdminCreateSSOProvider() { }, }, { - StatusCode: http.StatusBadRequest, + StatusCode: http.StatusUnprocessableEntity, Request: map[string]interface{}{ "type": "saml", "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-DUPLICATE"), diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go index 4fdecc0f8..4f2d38a44 100644 --- a/internal/api/ssoadmin.go +++ b/internal/api/ssoadmin.go @@ -29,14 +29,14 @@ func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.C idpID, err := uuid.FromString(idpParam) if err != nil { // idpParam is not UUIDv4 - return nil, notFoundError("SSO Identity Provider not found") + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") } // idpParam is a UUIDv4 provider, err := models.FindSSOProviderByID(db, idpID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("SSO Identity Provider not found") + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") } else { return nil, internalServerError("Database error finding SSO Identity Provider").WithInternalError(err) } @@ -79,19 +79,19 @@ type CreateSSOProviderParams struct { func (p *CreateSSOProviderParams) validate(forUpdate bool) error { if !forUpdate && p.Type != "saml" { - return badRequestError("Only 'saml' supported for SSO provider type") + return badRequestError(ErrorCodeValidationFailed, "Only 'saml' supported for SSO provider type") } else if p.MetadataURL != "" && p.MetadataXML != "" { - return badRequestError("Only one of metadata_xml or metadata_url needs to be set") + return badRequestError(ErrorCodeValidationFailed, "Only one of metadata_xml or metadata_url needs to be set") } else if !forUpdate && p.MetadataURL == "" && p.MetadataXML == "" { - return badRequestError("Either metadata_xml or metadata_url must be set") + return badRequestError(ErrorCodeValidationFailed, "Either metadata_xml or metadata_url must be set") } else if p.MetadataURL != "" { metadataURL, err := url.ParseRequestURI(p.MetadataURL) if err != nil { - return badRequestError("metadata_url is not a valid URL") + return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a valid URL") } if metadataURL.Scheme != "https" { - return badRequestError("metadata_url is not a HTTPS URL") + return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a HTTPS URL") } } @@ -127,7 +127,7 @@ func (p *CreateSSOProviderParams) metadata(ctx context.Context) ([]byte, *saml.E func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { if !utf8.Valid(rawMetadata) { - return nil, badRequestError("SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") } metadata, err := samlsp.ParseMetadata(rawMetadata) @@ -136,15 +136,15 @@ func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { } if metadata.EntityID == "" { - return nil, badRequestError("SAML Metadata does not contain an EntityID") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain an EntityID") } if len(metadata.IDPSSODescriptors) < 1 { - return nil, badRequestError("SAML Metadata does not contain any IDPSSODescriptor") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain any IDPSSODescriptor") } if len(metadata.IDPSSODescriptors) > 1 { - return nil, badRequestError("SAML Metadata contains multiple IDPSSODescriptors") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata contains multiple IDPSSODescriptors") } return metadata, nil @@ -153,7 +153,7 @@ func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return nil, badRequestError("Unable to create a request to metadata_url").WithInternalError(err) + return nil, internalServerError("Unable to create a request to metadata_url").WithInternalError(err) } req = req.WithContext(ctx) @@ -168,7 +168,7 @@ func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { defer utilities.SafeClose(resp.Body) if resp.StatusCode != http.StatusOK { - return nil, badRequestError("HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) + return nil, badRequestError(ErrorCodeSAMLMetadataFetchFailed, "HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) } data, err := io.ReadAll(resp.Body) @@ -191,7 +191,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er var params CreateSSOProviderParams if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "Unable to parse JSON").WithInternalError(err) } if err := params.validate(false /* <- forUpdate */); err != nil { @@ -208,7 +208,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er return err } if existingProvider != nil { - return badRequestError("SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) + return unprocessableEntityError(ErrorCodeSAMLIdPAlreadyExists, "SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) } provider := &models.SSOProvider{ @@ -231,7 +231,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er return err } if existingProvider != nil { - return badRequestError("SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) + return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) } provider.SSODomains = append(provider.SSODomains, models.SSODomain{ @@ -271,7 +271,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er var params CreateSSOProviderParams if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "Unable to parse JSON").WithInternalError(err) } if err := params.validate(true /* <- forUpdate */); err != nil { @@ -291,7 +291,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er } if provider.SAMLProvider.EntityID != metadata.EntityID { - return badRequestError("SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) + return badRequestError(ErrorCodeSAMLEntityIDMismatch, "SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) } if params.MetadataURL != "" { @@ -320,7 +320,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er if existingProvider.ID == provider.ID { keepDomains[domain] = true } else { - return badRequestError("SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) + return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) } } else { modified = true @@ -370,7 +370,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er return tx.Eager().Load(provider) }); err != nil { - return unprocessableEntityError("Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) + return unprocessableEntityError(ErrorCodeConflict, "Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) } } diff --git a/internal/api/token.go b/internal/api/token.go index 75d2e5087..18233b047 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -103,18 +103,18 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read password grant params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read password grant params: %v", err) } aud := a.requestAud(ctx, r) config := a.config if params.Email != "" && params.Phone != "" { - return unprocessableEntityError("Only an email address or phone number should be provided on login.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on login.") } var user *models.User var grantParams models.GrantParams @@ -125,13 +125,13 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri if params.Email != "" { provider = "email" if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) } else if params.Phone != "" { provider = "phone" if !config.External.Phone.Enabled { - return badRequestError("Phone logins are disabled") + return unprocessableEntityError(ErrorCodePhoneProviderDisabled, "Phone logins are disabled") } params.Phone = formatPhoneNumber(params.Phone) user, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud) @@ -183,7 +183,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri return err } } - return forbiddenError(output.Message) + return oauthError("invalid_grant", InvalidLoginMessage) } } if !isValidPassword { @@ -244,22 +244,22 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) } if err = json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "invalid body: unable to parse JSON").WithInternalError(err) } if params.AuthCode == "" || params.CodeVerifier == "" { - return badRequestError("invalid request: both auth code and code verifier should be non-empty") + return badRequestError(ErrorCodeValidationFailed, "invalid request: both auth code and code verifier should be non-empty") } flowState, err := models.FindFlowStateByAuthCode(db, params.AuthCode) // Sanity check in case user ID was not set properly if models.IsNotFoundError(err) || flowState.UserID == nil { - return forbiddenError("invalid flow state, no valid flow state found") + return notFoundError(ErrorCodeFlowStateNotFound, "invalid flow state, no valid flow state found") } else if err != nil { return err } if flowState.IsExpired(a.config.External.FlowStateExpiryDuration) { - return forbiddenError("invalid flow state, flow state has expired") + return unprocessableEntityError(ErrorCodeFlowStateExpired, "invalid flow state, flow state has expired") } user, err := models.FindUserByID(db, *flowState.UserID) @@ -267,7 +267,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) return err } if err := flowState.VerifyPKCE(params.CodeVerifier); err != nil { - return forbiddenError(err.Error()) + return badRequestError(ErrorBadCodeVerifier, err.Error()) } var token *AccessTokenResponse @@ -433,17 +433,9 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, var tokenString string var expiresAt int64 var refreshToken *models.RefreshToken - currentClaims := getClaims(ctx) - sessionId, err := uuid.FromString(currentClaims.SessionId) - if err != nil { - return nil, internalServerError("Cannot read SessionId claim as UUID").WithInternalError(err) - } - err = tx.Transaction(func(tx *storage.Connection) error { - if terr := models.AddClaimToSession(tx, sessionId, authenticationMethod); terr != nil { - return terr - } - session, terr := models.FindSessionByID(tx, sessionId, false) - if terr != nil { + session := getSession(ctx) + err := tx.Transaction(func(tx *storage.Connection) error { + if terr := models.AddClaimToSession(tx, session.ID, authenticationMethod); terr != nil { return terr } currentToken, terr := models.FindTokenBySessionID(tx, &session.ID) @@ -467,7 +459,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, return err } - tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &sessionId, models.TOTPSignIn) + tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &session.ID, models.TOTPSignIn) if terr != nil { httpErr, ok := terr.(*HTTPError) if ok { diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index c380856c3..7589f9476 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -55,7 +55,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa if issuer == "" || !provider.IsAzureIssuer(issuer) { detectedIssuer, err := provider.DetectAzureIDTokenIssuer(ctx, p.IdToken) if err != nil { - return nil, nil, "", nil, badRequestError("Unable to detect issuer in ID token for Azure provider").WithInternalError(err) + return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err) } issuer = detectedIssuer } @@ -90,12 +90,12 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa } if !allowed { - return nil, nil, "", nil, badRequestError(fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) + return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) } } if cfg != nil && !cfg.Enabled { - return nil, nil, "", nil, badRequestError(fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) + return nil, nil, "", nil, badRequestError(ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) } oidcProvider, err := oidc.NewProvider(ctx, issuer) @@ -117,11 +117,11 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read id token grant params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read id token grant params: %v", err) } if params.IdToken == "" { diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index 1cd665346..a610f7f16 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -29,11 +29,11 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read refresh token grant params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read refresh token grant params: %v", err) } if params.RefreshToken == "" { diff --git a/internal/api/token_test.go b/internal/api/token_test.go index 6a8acf6d3..04bfbdd4d 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -306,8 +306,7 @@ func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() { invalidVerifier := codeVerifier + "123" codeChallenge := sha256.Sum256([]byte(codeVerifier)) challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:]) - flowState, err := models.NewFlowState("github", challenge, models.SHA256, models.OAuth) - require.NoError(ts.T(), err) + flowState := models.NewFlowState("github", challenge, models.SHA256, models.OAuth) flowState.AuthCode = authCode require.NoError(ts.T(), ts.API.db.Create(flowState)) cases := []struct { @@ -344,7 +343,7 @@ func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusForbidden, w.Code) + assert.Equal(ts.T(), http.StatusNotFound, w.Code) }) } } @@ -619,7 +618,7 @@ func (ts *TokenTestSuite) TestPasswordVerificationHook() { begin return json_build_object('decision', 'reject'); end; $$ language plpgsql;`, - expectedCode: http.StatusForbidden, + expectedCode: http.StatusBadRequest, }, } for _, c := range cases { diff --git a/internal/api/user.go b/internal/api/user.go index 83d067a26..b9d28611f 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -46,7 +46,7 @@ func (a *API) validateUserUpdateParams(ctx context.Context, p *UserUpdateParams) p.Channel = sms_provider.SMSProvider } if !sms_provider.IsValidMessageChannel(p.Channel, config.Sms.Provider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } } @@ -64,12 +64,12 @@ func (a *API) UserGet(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() claims := getClaims(ctx) if claims == nil { - return badRequestError("Could not read claims") + return internalServerError("Could not read claims") } aud := a.requestAud(ctx, r) if aud != claims.Audience { - return badRequestError("Token audience doesn't match request audience") + return badRequestError(ErrorCodeValidationFailed, "Token audience doesn't match request audience") } user := getUser(ctx) @@ -87,11 +87,11 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read User Update params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read User Update params: %v", err) } user := getUser(ctx) @@ -103,7 +103,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if params.AppData != nil && !isAdmin(user, config) { if !isAdmin(user, config) { - return unauthorizedError("Updating app_metadata requires admin privileges") + return forbiddenError(ErrorCodeNotAdmin, "Updating app_metadata requires admin privileges") } } @@ -116,7 +116,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { updatingForbiddenFields = updatingForbiddenFields || (params.Nonce != "") if updatingForbiddenFields { - return unprocessableEntityError("Updating email, phone, password of a SSO account only possible via SSO") + return unprocessableEntityError(ErrorCodeUserSSOManaged, "Updating email, phone, password of a SSO account only possible via SSO") } } @@ -124,7 +124,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if duplicateUser, err := models.IsDuplicatedEmail(db, params.Email, aud, user); err != nil { return internalServerError("Database error checking email").WithInternalError(err) } else if duplicateUser != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } @@ -132,7 +132,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { return internalServerError("Database error checking phone").WithInternalError(err) } else if exists { - return unprocessableEntityError(DuplicatePhoneMsg) + return unprocessableEntityError(ErrorCodePhoneExists, DuplicatePhoneMsg) } } @@ -142,7 +142,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { // we require reauthentication if the user hasn't signed in recently in the current session if session == nil || now.After(session.CreatedAt.Add(24*time.Hour)) { if len(params.Nonce) == 0 { - return badRequestError("Password update requires reauthentication") + return badRequestError(ErrorCodeReauthenticationNeeded, "Password update requires reauthentication") } if err := a.verifyReauthentication(params.Nonce, db, config, user); err != nil { return err @@ -153,7 +153,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { password := *params.Password if password != "" { if user.EncryptedPassword != "" && user.Authenticate(ctx, password) { - return unprocessableEntityError("New password should be different from the old password.") + return unprocessableEntityError(ErrorCodeSamePassword, "New password should be different from the old password.") } } @@ -207,7 +207,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { externalURL := getExternalHost(ctx) if terr = a.sendEmailChange(tx, config, user, mailer, params.Email, referrer, externalURL, config.Mailer.OtpLength, flowType); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRate, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Error sending change email").WithInternalError(terr) } @@ -224,7 +224,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } else { smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Error finding SMS provider").WithInternalError(terr) } if _, terr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { return internalServerError("Error sending phone change otp").WithInternalError(terr) diff --git a/internal/api/user_test.go b/internal/api/user_test.go index c449cb02c..2496f15f4 100644 --- a/internal/api/user_test.go +++ b/internal/api/user_test.go @@ -274,7 +274,7 @@ func (ts *UserTestSuite) TestUserUpdatePassword() { nonce: "123456", requireReauthentication: true, sessionId: nil, - expected: expected{code: http.StatusBadRequest, isAuthenticated: false}, + expected: expected{code: http.StatusUnprocessableEntity, isAuthenticated: false}, }, { desc: "Valid password length", diff --git a/internal/api/verify.go b/internal/api/verify.go index ecacc1e9a..8f7c64b8f 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -54,18 +54,18 @@ type VerifyParams struct { func (p *VerifyParams) Validate(r *http.Request) error { var err error if p.Type == "" { - return badRequestError("Verify requires a verification type") + return badRequestError(ErrorCodeValidationFailed, "Verify requires a verification type") } switch r.Method { case http.MethodGet: if p.Token == "" { - return badRequestError("Verify requires a token or a token hash") + return badRequestError(ErrorCodeValidationFailed, "Verify requires a token or a token hash") } // TODO: deprecate the token query param from GET /verify and use token_hash instead (breaking change) p.TokenHash = p.Token case http.MethodPost: if (p.Token == "" && p.TokenHash == "") || (p.Token != "" && p.TokenHash != "") { - return badRequestError("Verify requires either a token or a token hash") + return badRequestError(ErrorCodeValidationFailed, "Verify requires either a token or a token hash") } if p.Token != "" { if isPhoneOtpVerification(p) { @@ -77,15 +77,15 @@ func (p *VerifyParams) Validate(r *http.Request) error { } else if isEmailOtpVerification(p) { p.Email, err = validateEmail(p.Email) if err != nil { - return unprocessableEntityError("Invalid email format").WithInternalError(err) + return unprocessableEntityError(ErrorCodeValidationFailed, "Invalid email format").WithInternalError(err) } p.TokenHash = crypto.GenerateTokenHash(p.Email, p.Token) } else { - return badRequestError("Only an email address or phone number should be provided on verify") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify") } } else if p.TokenHash != "" { if p.Email != "" || p.Phone != "" || p.RedirectTo != "" { - return badRequestError("Only the token_hash and type should be provided") + return badRequestError(ErrorCodeValidationFailed, "Only the token_hash and type should be provided") } } default: @@ -109,17 +109,18 @@ func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { case http.MethodPost: body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not parse verification params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not parse verification params: %v", err) } if err := params.Validate(r); err != nil { return err } return a.verifyPost(w, r, params) default: - return unprocessableEntityError("Only GET and POST methods are supported.") + // this should have been handled by Chi + panic("Only GET and POST methods allowed") } } @@ -171,7 +172,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa return nil } default: - return unprocessableEntityError("Unsupported verification type") + return badRequestError(ErrorCodeValidationFailed, "Unsupported verification type") } if terr != nil { @@ -189,7 +190,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa } } else if isPKCEFlow(flowType) { if authCode, terr = issueAuthCode(tx, user, a.config.External.FlowStateExpiryDuration, authenticationMethod); terr != nil { - return badRequestError("No associated flow state found. %s", terr) + return badRequestError(ErrorCodeFlowStateNotFound, "No associated flow state found. %s", terr) } } return nil @@ -262,7 +263,7 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyP case smsVerification, phoneChangeVerification: user, terr = a.smsVerify(r, ctx, tx, user, params) default: - return unprocessableEntityError("Unsupported verification type") + return badRequestError(ErrorCodeValidationFailed, "Unsupported verification type") } if terr != nil { @@ -298,7 +299,8 @@ func (a *API) signupVerify(r *http.Request, ctx context.Context, conn *storage.C // to present the user with a password set form password, err := password.Generate(64, 10, 0, false, true) if err != nil { - return nil, err + // password generation must succeed + panic(err) } if err := user.SetPassword(ctx, password); err != nil { @@ -437,14 +439,14 @@ func (a *API) prepErrorRedirectURL(err *HTTPError, w http.ResponseWriter, r *htt errorID := getRequestID(r.Context()) err.ErrorID = errorID log.WithError(err.Cause()).Info(err.Error()) - if str, ok := oauthErrorMap[err.Code]; ok { + if str, ok := oauthErrorMap[err.HTTPStatus]; ok { hq.Set("error", str) q.Set("error", str) } - hq.Set("error_code", strconv.Itoa(err.Code)) + hq.Set("error_code", strconv.Itoa(err.HTTPStatus)) hq.Set("error_description", err.Message) - q.Set("error_code", strconv.Itoa(err.Code)) + q.Set("error_code", strconv.Itoa(err.HTTPStatus)) q.Set("error_description", err.Message) if flowType == models.PKCEFlow { // Additionally, may override existing error query param if set to PKCE. @@ -565,18 +567,18 @@ func (a *API) verifyTokenHash(ctx context.Context, conn *storage.Connection, par case emailChangeVerification: user, err = models.FindUserByEmailChangeToken(conn, params.TokenHash) default: - return nil, badRequestError("Invalid email verification type") + return nil, badRequestError(ErrorCodeValidationFailed, "Invalid email verification type") } if err != nil { if models.IsNotFoundError(err) { - return nil, expiredTokenError("Email link is invalid or has expired").WithInternalError(err) + return nil, forbiddenError(ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalError(err) } return nil, internalServerError("Database error finding user from email link").WithInternalError(err) } if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalMessage("user is banned") + return nil, forbiddenError(ErrorCodeUserBanned, "User is banned") } var isExpired bool @@ -598,7 +600,7 @@ func (a *API) verifyTokenHash(ctx context.Context, conn *storage.Connection, par } if isExpired { - return nil, expiredTokenError("Email link is invalid or has expired").WithInternalMessage("email link has expired") + return nil, forbiddenError(ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalMessage("email link has expired") } return user, nil @@ -627,13 +629,13 @@ func (a *API) verifyUserAndToken(ctx context.Context, conn *storage.Connection, if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()).WithInternalError(err) + return nil, notFoundError(ErrorCodeUserNotFound, err.Error()).WithInternalError(err) } return nil, internalServerError("Database error finding user").WithInternalError(err) } if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalMessage("user is banned") + return nil, forbiddenError(ErrorCodeUserBanned, "User is banned") } var isValid bool @@ -674,7 +676,7 @@ func (a *API) verifyUserAndToken(ctx context.Context, conn *storage.Connection, } } if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(phone, params.Token); err != nil { - return nil, expiredTokenError("Token has expired or is invalid").WithInternalError(err) + return nil, forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } return user, nil } @@ -682,7 +684,7 @@ func (a *API) verifyUserAndToken(ctx context.Context, conn *storage.Connection, } if !isValid { - return nil, expiredTokenError("Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") + return nil, forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") } return user, nil } diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index 1cdb43ba9..a10efd604 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -280,9 +280,9 @@ func (ts *VerifyTestSuite) TestExpiredConfirmationToken() { f, err := url.ParseQuery(rurl.Fragment) require.NoError(ts.T(), err) - assert.Equal(ts.T(), "401", f.Get("error_code")) + assert.Equal(ts.T(), "403", f.Get("error_code")) assert.Equal(ts.T(), "Email link is invalid or has expired", f.Get("error_description")) - assert.Equal(ts.T(), "unauthorized_client", f.Get("error")) + assert.Equal(ts.T(), "access_denied", f.Get("error")) } func (ts *VerifyTestSuite) TestInvalidOtp() { @@ -302,7 +302,7 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { } expectedResponse := ResponseBody{ - Code: http.StatusUnauthorized, + Code: http.StatusForbidden, Msg: "Token has expired or is invalid", } @@ -313,7 +313,7 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { expected ResponseBody }{ { - desc: "Expired Sms OTP", + desc: "Expired SMS OTP", sentTime: time.Now().Add(-48 * time.Hour), body: map[string]interface{}{ "type": smsVerification, @@ -323,7 +323,7 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { expected: expectedResponse, }, { - desc: "Invalid Sms OTP", + desc: "Invalid SMS OTP", sentTime: time.Now(), body: map[string]interface{}{ "type": smsVerification, @@ -760,7 +760,7 @@ func (ts *VerifyTestSuite) TestVerifyBannedUser() { f, err := url.ParseQuery(rurl.Fragment) require.NoError(ts.T(), err) - assert.Equal(ts.T(), "401", f.Get("error_code")) + assert.Equal(ts.T(), "403", f.Get("error_code")) }) } } @@ -973,7 +973,7 @@ func (ts *VerifyTestSuite) TestSecureEmailChangeWithTokenHash() { "type": emailChangeVerification, "token_hash": currentEmailChangeToken, }, - expectedStatus: http.StatusUnauthorized, + expectedStatus: http.StatusForbidden, }, } for _, c := range cases { @@ -1103,7 +1103,7 @@ func (ts *VerifyTestSuite) TestPrepErrorRedirectURL() { ts.Run(c.desc, func() { w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) - rurl, err := ts.API.prepErrorRedirectURL(badRequestError(DefaultError), w, req, c.rurl, c.flowType) + rurl, err := ts.API.prepErrorRedirectURL(badRequestError(ErrorCodeValidationFailed, DefaultError), w, req, c.rurl, c.flowType) require.NoError(ts.T(), err) require.Equal(ts.T(), c.expected, rurl) }) @@ -1153,7 +1153,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { Token: "some-token", }, method: http.MethodPost, - expected: badRequestError("Only an email address or phone number should be provided on verify"), + expected: badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify"), }, { desc: "Cannot send both TokenHash and Token", @@ -1163,7 +1163,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { TokenHash: "some-token-hash", }, method: http.MethodPost, - expected: badRequestError("Verify requires either a token or a token hash"), + expected: badRequestError(ErrorCodeValidationFailed, "Verify requires either a token or a token hash"), }, { desc: "No verification type specified", @@ -1172,7 +1172,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { Email: "email@example.com", }, method: http.MethodPost, - expected: badRequestError("Verify requires a verification type"), + expected: badRequestError(ErrorCodeValidationFailed, "Verify requires a verification type"), }, } diff --git a/internal/models/flow_state.go b/internal/models/flow_state.go index 6aced0b59..d18a5bd32 100644 --- a/internal/models/flow_state.go +++ b/internal/models/flow_state.go @@ -81,7 +81,7 @@ func (FlowState) TableName() string { return tableName } -func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod) (*FlowState, error) { +func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod) *FlowState { id := uuid.Must(uuid.NewV4()) authCode := uuid.Must(uuid.NewV4()) flowState := &FlowState{ @@ -92,7 +92,7 @@ func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeCh AuthCode: authCode.String(), AuthenticationMethod: authenticationMethod.String(), } - return flowState, nil + return flowState } func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) error {