diff --git a/.github/workflows/conventional-commits-lint.js b/.github/workflows/conventional-commits-lint.js index a3815e5e7..96d6c9828 100644 --- a/.github/workflows/conventional-commits-lint.js +++ b/.github/workflows/conventional-commits-lint.js @@ -46,7 +46,12 @@ let failed = false; validate.forEach((payload) => { if (payload.title) { - const { groups } = payload.title.match(TITLE_PATTERN); + const match = payload.title.match(TITLE_PATTERN); + if (!match) { + return + } + + const { groups } = match if (groups) { if (groups.breaking) { diff --git a/internal/api/admin.go b/internal/api/admin.go index 89f7af975..f7acf3b45 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -50,7 +50,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, userID, err := uuid.FromString(chi.URLParam(r, "user_id")) if err != nil { - return nil, badRequestError("user_id must be an UUID") + return nil, notFoundError(ErrorCodeValidationFailed, "user_id must be an UUID") } observability.LogEntrySetField(r, "user_id", userID) @@ -58,7 +58,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, u, err := models.FindUserByID(db, userID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("User not found") + return nil, notFoundError(ErrorCodeUserNotFound, "User not found") } return nil, internalServerError("Database error loading user").WithInternalError(err) } @@ -69,7 +69,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) { factorID, err := uuid.FromString(chi.URLParam(r, "factor_id")) if err != nil { - return nil, badRequestError("factor_id must be an UUID") + return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID") } observability.LogEntrySetField(r, "factor_id", factorID) @@ -77,7 +77,7 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex f, err := models.FindFactorByFactorID(a.db, factorID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("Factor not found") + return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found") } return nil, internalServerError("Database error loading factor").WithInternalError(err) } @@ -101,12 +101,12 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { pageParams, err := paginate(r) if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err) } sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}}) if err != nil { - return badRequestError("Bad Sort Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err) } filter := r.URL.Query().Get("filter") @@ -160,7 +160,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError("invalid format for ban duration: %v", err) + return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) } } if terr := user.Ban(a.db, duration); terr != nil { @@ -308,7 +308,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { } if params.Email == "" && params.Phone == "" { - return unprocessableEntityError("Cannot create a user without either an email or phone") + return badRequestError(ErrorCodeValidationFailed, "Cannot create a user without either an email or phone") } var providers []string @@ -320,7 +320,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil { return internalServerError("Database error checking email").WithInternalError(err) } else if user != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } providers = append(providers, "email") } @@ -333,7 +333,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { return internalServerError("Database error checking phone").WithInternalError(err) } else if exists { - return unprocessableEntityError("Phone number already registered by another user") + return unprocessableEntityError(ErrorCodePhoneExists, "Phone number already registered by another user") } providers = append(providers, "phone") } @@ -429,7 +429,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError("invalid format for ban duration: %v", err) + return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) } } if terr := user.Ban(a.db, duration); terr != nil { @@ -460,11 +460,11 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { params := &adminUserDeleteParams{} body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if len(body) > 0 { if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read params: %v", err) } } else { params.ShouldSoftDelete = false @@ -559,6 +559,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro user := getUser(ctx) adminUser := getAdminUser(ctx) params := &adminUserUpdateFactorParams{} + if err := retrieveRequestParams(r, params); err != nil { return err } @@ -571,7 +572,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro } if params.FactorType != "" { if params.FactorType != models.TOTP { - return badRequestError("Factor Type not valid") + return badRequestError(ErrorCodeValidationFailed, "Factor Type not valid") } if terr := factor.UpdateFactorType(tx, params.FactorType); terr != nil { return terr diff --git a/internal/api/anonymous.go b/internal/api/anonymous.go index 5316525a4..4024d5947 100644 --- a/internal/api/anonymous.go +++ b/internal/api/anonymous.go @@ -15,7 +15,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error { aud := a.requestAud(ctx, r) if config.DisableSignup { - return forbiddenError("Signups not allowed for this instance") + return unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{} diff --git a/internal/api/api.go b/internal/api/api.go index 73d810fa2..edac716a6 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -155,7 +155,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati } if params.Email == "" && params.Phone == "" { if !api.config.External.AnonymousUsers.Enabled { - return unprocessableEntityError("Anonymous sign-ins are disabled") + return unprocessableEntityError(ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled") } if _, err := api.limitHandler(limiter)(w, r); err != nil { return err diff --git a/internal/api/apiversions.go b/internal/api/apiversions.go new file mode 100644 index 000000000..b5394a5fc --- /dev/null +++ b/internal/api/apiversions.go @@ -0,0 +1,35 @@ +package api + +import ( + "time" +) + +const APIVersionHeaderName = "X-Supabase-Api-Version" + +type APIVersion = time.Time + +var ( + APIVersionInitial = time.Time{} + APIVersion20240101 = time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC) +) + +func DetermineClosestAPIVersion(date string) (APIVersion, error) { + if date == "" { + return APIVersionInitial, nil + } + + parsed, err := time.ParseInLocation("2006-01-02", date, time.UTC) + if err != nil { + return APIVersionInitial, err + } + + if parsed.Compare(APIVersion20240101) >= 0 { + return APIVersion20240101, nil + } + + return APIVersionInitial, nil +} + +func FormatAPIVersion(apiVersion APIVersion) string { + return apiVersion.Format("2006-01-02") +} diff --git a/internal/api/apiversions_test.go b/internal/api/apiversions_test.go new file mode 100644 index 000000000..0a9622132 --- /dev/null +++ b/internal/api/apiversions_test.go @@ -0,0 +1,29 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDetermineClosestAPIVersion(t *testing.T) { + version, err := DetermineClosestAPIVersion("") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("Not a date") + require.Error(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2023-12-31") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2024-01-01") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) + + version, err = DetermineClosestAPIVersion("2024-01-02") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) +} diff --git a/internal/api/audit.go b/internal/api/audit.go index 2cb99c6e7..351a7d2cd 100644 --- a/internal/api/audit.go +++ b/internal/api/audit.go @@ -20,7 +20,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { // aud := a.requestAud(ctx, r) pageParams, err := paginate(r) if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err) } var col []string @@ -31,7 +31,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { qparts := strings.SplitN(q, ":", 2) col, exists = filterColumnMap[qparts[0]] if !exists || len(qparts) < 2 { - return badRequestError("Invalid query scope: %s", q) + return badRequestError(ErrorCodeValidationFailed, "Invalid query scope: %s", q) } qval = qparts[1] } diff --git a/internal/api/auth.go b/internal/api/auth.go index dbd0278bd..3e69d8c6c 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -39,7 +39,7 @@ func (a *API) requireNotAnonymous(w http.ResponseWriter, r *http.Request) (conte ctx := r.Context() claims := getClaims(ctx) if claims.IsAnonymous { - return nil, forbiddenError("Anonymous user not allowed to perform these actions") + return nil, forbiddenError(ErrorCodeNoAuthorization, "Anonymous user not allowed to perform these actions") } return ctx, nil } @@ -49,7 +49,7 @@ func (a *API) requireAdmin(ctx context.Context, r *http.Request) (context.Contex claims := getClaims(ctx) if claims == nil { fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token") - return nil, unauthorizedError("Invalid token") + return nil, forbiddenError(ErrorCodeBadJWT, "Invalid token") } adminRoles := a.config.JWT.AdminRoles @@ -60,14 +60,14 @@ func (a *API) requireAdmin(ctx context.Context, r *http.Request) (context.Contex } fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'") - return nil, unauthorizedError("User not allowed") + return nil, forbiddenError(ErrorCodeNotAdmin, "User not allowed") } func (a *API) extractBearerToken(r *http.Request) (string, error) { authHeader := r.Header.Get("Authorization") matches := bearerRegexp.FindStringSubmatch(authHeader) if len(matches) != 2 { - return "", unauthorizedError("This endpoint requires a Bearer token") + return "", httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "This endpoint requires a Bearer token") } return matches[1], nil @@ -82,7 +82,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e return []byte(config.JWT.Secret), nil }) if err != nil { - return nil, unauthorizedError("invalid JWT: unable to parse or verify signature, %v", err) + return nil, forbiddenError(ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err) } return withToken(ctx, token), nil @@ -93,23 +93,23 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro claims := getClaims(ctx) if claims == nil { - return ctx, unauthorizedError("invalid token: missing claims") + return ctx, forbiddenError(ErrorCodeBadJWT, "invalid token: missing claims") } if claims.Subject == "" { - return nil, unauthorizedError("invalid claim: missing sub claim") + return nil, forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim") } var user *models.User if claims.Subject != "" { userId, err := uuid.FromString(claims.Subject) if err != nil { - return ctx, badRequestError("invalid claim: sub claim must be a UUID").WithInternalError(err) + return ctx, badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err) } user, err = models.FindUserByID(db, userId) if err != nil { if models.IsNotFoundError(err) { - return ctx, notFoundError(err.Error()) + return ctx, forbiddenError(ErrorCodeUserNotFound, "User from sub claim in JWT does not exist") } return ctx, err } @@ -120,11 +120,11 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() { sessionId, err := uuid.FromString(claims.SessionId) if err != nil { - return ctx, err + return ctx, forbiddenError(ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err) } session, err = models.FindSessionByID(db, sessionId, false) if err != nil && !models.IsNotFoundError(err) { - return ctx, err + return ctx, forbiddenError(ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist") } ctx = withSession(ctx, session) } diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index 6c95a0bd9..f404e1cb7 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -96,7 +96,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { }, Role: "authenticated", }, - ExpectedError: unauthorizedError("invalid claim: missing sub claim"), + ExpectedError: forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim"), ExpectedUser: nil, }, { @@ -118,7 +118,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { }, Role: "authenticated", }, - ExpectedError: badRequestError("invalid claim: sub claim must be a UUID"), + ExpectedError: badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"), ExpectedUser: nil, }, { diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go new file mode 100644 index 000000000..45dec0dd7 --- /dev/null +++ b/internal/api/errorcodes.go @@ -0,0 +1,77 @@ +package api + +type ErrorCode = string + +const ( + // ErrorCodeUnknown should not be used directly, it only indicates a failure in the error handling system in such a way that an error code was not assigned properly. + ErrorCodeUnknown ErrorCode = "unknown" + + // ErrorCodeUnexpectedFailure signals an unexpected failure such as a 500 Internal Server Error. + ErrorCodeUnexpectedFailure ErrorCode = "unexpected_failure" + + ErrorCodeValidationFailed ErrorCode = "validation_failed" + ErrorCodeBadJSON ErrorCode = "bad_json" + ErrorCodeEmailExists ErrorCode = "email_exists" + ErrorCodePhoneExists ErrorCode = "phone_exists" + ErrorCodeBadJWT ErrorCode = "bad_jwt" + ErrorCodeNotAdmin ErrorCode = "not_admin" + ErrorCodeNoAuthorization ErrorCode = "no_authorization" + ErrorCodeUserNotFound ErrorCode = "user_not_found" + ErrorCodeSessionNotFound ErrorCode = "session_not_found" + ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found" + ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired" + ErrorCodeSignupDisabled ErrorCode = "signup_disabled" + ErrorCodeUserBanned ErrorCode = "user_banned" + ErrorCodeProviderEmailNeedsVerification ErrorCode = "provider_email_needs_verification" + ErrorCodeInviteNotFound ErrorCode = "invite_not_found" + ErrorCodeBadOAuthState ErrorCode = "bad_oauth_state" + ErrorCodeBadOAuthCallback ErrorCode = "bad_oauth_callback" + ErrorCodeOAuthProviderNotSupported ErrorCode = "oauth_provider_not_supported" + ErrorCodeUnexpectedAudience ErrorCode = "unexpected_audience" + ErrorCodeSingleIdentityNotDeletable ErrorCode = "single_identity_not_deletable" + ErrorCodeEmailConflictIdentityNotDeletable ErrorCode = "email_conflict_identity_not_deletable" + ErrorCodeIdentityAlreadyExists ErrorCode = "identity_already_exists" + ErrorCodeEmailProviderDisabled ErrorCode = "email_provider_disabled" + ErrorCodePhoneProviderDisabled ErrorCode = "phone_provider_disabled" + ErrorCodeTooManyEnrolledMFAFactors ErrorCode = "too_many_enrolled_mfa_factors" + ErrorCodeMFAFactorNameConflict ErrorCode = "mfa_factor_name_conflict" + ErrorCodeMFAFactorNotFound ErrorCode = "mfa_factor_not_found" + ErrorCodeMFAIPAddressMismatch ErrorCode = "mfa_ip_address_mismatch" + ErrorCodeMFAChallengeExpired ErrorCode = "mfa_challenge_expired" + ErrorCodeMFAVerificationFailed ErrorCode = "mfa_verification_failed" + ErrorCodeMFAVerificationRejected ErrorCode = "mfa_verification_rejected" + ErrorCodeInsufficientAAL ErrorCode = "insufficient_aal" + ErrorCodeCaptchaFailed ErrorCode = "captcha_failed" + ErrorCodeSAMLProviderDisabled ErrorCode = "saml_provider_disabled" + ErrorCodeManualLinkingDisabled ErrorCode = "manual_linking_disabled" + ErrorCodeSMSSendFailed ErrorCode = "sms_send_failed" + ErrorCodeEmailNotConfirmed ErrorCode = "email_not_confirmed" + ErrorCodePhoneNotConfirmed ErrorCode = "phone_not_confirmed" + ErrorCodeReauthNonceMissing ErrorCode = "reauth_nonce_missing" + ErrorCodeSAMLRelayStateNotFound ErrorCode = "saml_relay_state_not_found" + ErrorCodeSAMLRelayStateExpired ErrorCode = "saml_relay_state_expired" + ErrorCodeSAMLIdPNotFound ErrorCode = "saml_idp_not_found" + ErrorCodeSAMLAssertionNoUserID ErrorCode = "saml_assertion_no_user_id" + ErrorCodeSAMLAssertionNoEmail ErrorCode = "saml_assertion_no_email" + ErrorCodeUserAlreadyExists ErrorCode = "user_already_exists" + ErrorCodeSSOProviderNotFound ErrorCode = "sso_provider_not_found" + ErrorCodeSAMLMetadataFetchFailed ErrorCode = "saml_metadata_fetch_failed" + ErrorCodeSAMLIdPAlreadyExists ErrorCode = "saml_idp_already_exists" + ErrorCodeSSODomainAlreadyExists ErrorCode = "sso_domain_already_exists" + ErrorCodeSAMLEntityIDMismatch ErrorCode = "saml_entity_id_mismatch" + ErrorCodeConflict ErrorCode = "conflict" + ErrorCodeProviderDisabled ErrorCode = "provider_disabled" + ErrorCodeUserSSOManaged ErrorCode = "user_sso_managed" + ErrorCodeReauthenticationNeeded ErrorCode = "reauthentication_needed" + ErrorCodeSamePassword ErrorCode = "same_password" + ErrorCodeReauthenticationNotValid ErrorCode = "reauthentication_not_valid" + ErrorCodeOTPExpired ErrorCode = "otp_expired" + ErrorCodeOTPDisabled ErrorCode = "otp_disabled" + ErrorCodeIdentityNotFound ErrorCode = "identity_not_found" + ErrorCodeWeakPassword ErrorCode = "weak_password" + ErrorCodeOverRequestRateLimit ErrorCode = "over_request_rate_limit" + ErrorCodeOverEmailSendRateLimit ErrorCode = "over_email_send_rate_limit" + ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit" + ErrorBadCodeVerifier ErrorCode = "bad_code_verifier" + ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled" +) diff --git a/internal/api/errors.go b/internal/api/errors.go index 56f404e3c..cc6ba877b 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -8,7 +8,6 @@ import ( "runtime/debug" "github.com/pkg/errors" - "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/utilities" ) @@ -65,65 +64,43 @@ func (e *OAuthError) Cause() error { return e } -func invalidSignupError(config *conf.GlobalConfiguration) *HTTPError { - var msg string - if config.External.Email.Enabled && config.External.Phone.Enabled { - msg = "To signup, please provide your email or phone number" - } else if config.External.Email.Enabled { - msg = "To signup, please provide your email" - } else if config.External.Phone.Enabled { - msg = "To signup, please provide your phone number" - } else { - // 3rd party OAuth signups - msg = "To signup, please provide required fields" - } - return unprocessableEntityError(msg) -} - func oauthError(err string, description string) *OAuthError { return &OAuthError{Err: err, Description: description} } -func badRequestError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusBadRequest, fmtString, args...) +func badRequestError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusBadRequest, errorCode, fmtString, args...) } func internalServerError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusInternalServerError, fmtString, args...) -} - -func notFoundError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusNotFound, fmtString, args...) + return httpError(http.StatusInternalServerError, ErrorCodeUnexpectedFailure, fmtString, args...) } -func expiredTokenError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnauthorized, fmtString, args...) +func notFoundError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusNotFound, errorCode, fmtString, args...) } -func unauthorizedError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnauthorized, fmtString, args...) +func forbiddenError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusForbidden, errorCode, fmtString, args...) } -func forbiddenError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusForbidden, fmtString, args...) +func unprocessableEntityError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusUnprocessableEntity, errorCode, fmtString, args...) } -func unprocessableEntityError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusUnprocessableEntity, fmtString, args...) -} - -func tooManyRequestsError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusTooManyRequests, fmtString, args...) +func tooManyRequestsError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusTooManyRequests, errorCode, fmtString, args...) } func conflictError(fmtString string, args ...interface{}) *HTTPError { - return httpError(http.StatusConflict, fmtString, args...) + return httpError(http.StatusConflict, ErrorCodeConflict, fmtString, args...) } // HTTPError is an error with a message and an HTTP status code. type HTTPError struct { - Code int `json:"code"` - Message string `json:"msg"` + HTTPStatus int `json:"code"` // do not rename the JSON tags! + ErrorCode string `json:"error_code,omitempty"` // do not rename the JSON tags! + Message string `json:"msg"` // do not rename the JSON tags! InternalError error `json:"-"` InternalMessage string `json:"-"` ErrorID string `json:"error_id,omitempty"` @@ -133,7 +110,7 @@ func (e *HTTPError) Error() string { if e.InternalMessage != "" { return e.InternalMessage } - return fmt.Sprintf("%d: %s", e.Code, e.Message) + return fmt.Sprintf("%d: %s", e.HTTPStatus, e.Message) } func (e *HTTPError) Is(target error) bool { @@ -160,50 +137,12 @@ func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) * return e } -func httpError(code int, fmtString string, args ...interface{}) *HTTPError { +func httpError(httpStatus int, errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { return &HTTPError{ - Code: code, - Message: fmt.Sprintf(fmtString, args...), - } -} - -// OTPError is a custom error struct for phone auth errors -type OTPError struct { - Err string `json:"error"` - Description string `json:"error_description,omitempty"` - InternalError error `json:"-"` - InternalMessage string `json:"-"` -} - -func (e *OTPError) Error() string { - if e.InternalMessage != "" { - return e.InternalMessage - } - return fmt.Sprintf("%s: %s", e.Err, e.Description) -} - -// WithInternalError adds internal error information to the error -func (e *OTPError) WithInternalError(err error) *OTPError { - e.InternalError = err - return e -} - -// WithInternalMessage adds internal message information to the error -func (e *OTPError) WithInternalMessage(fmtString string, args ...interface{}) *OTPError { - e.InternalMessage = fmt.Sprintf(fmtString, args...) - return e -} - -// Cause returns the root cause error -func (e *OTPError) Cause() error { - if e.InternalError != nil { - return e.InternalError + HTTPStatus: httpStatus, + ErrorCode: errorCode, + Message: fmt.Sprintf(fmtString, args...), } - return e -} - -func otpError(err string, description string) *OTPError { - return &OTPError{Err: err, Description: description} } // Recoverer is a middleware that recovers from panics, logs the panic (and a @@ -222,10 +161,10 @@ func recoverer(w http.ResponseWriter, r *http.Request) (context.Context, error) } se := &HTTPError{ - Code: http.StatusInternalServerError, - Message: http.StatusText(http.StatusInternalServerError), + HTTPStatus: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), } - handleError(se, w, r) + HandleResponseError(se, w, r) } }() @@ -237,28 +176,61 @@ type ErrorCause interface { Cause() error } -func handleError(err error, w http.ResponseWriter, r *http.Request) { +type HTTPErrorResponse20240101 struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` +} + +func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { log := observability.GetLogEntry(r) errorID := getRequestID(r.Context()) + + apiVersion, averr := DetermineClosestAPIVersion(r.Header.Get(APIVersionHeaderName)) + if averr != nil { + log.WithError(averr).Warn("Invalid version passed to " + APIVersionHeaderName + " header, defaulting to initial version") + } else if apiVersion != APIVersionInitial { + // Echo back the determined API version from the request + w.Header().Set(APIVersionHeaderName, FormatAPIVersion(apiVersion)) + } + switch e := err.(type) { case *WeakPasswordError: - var output struct { - HTTPError - Payload struct { - Reasons []string `json:"reasons,omitempty"` - } `json:"weak_password,omitempty"` - } + if apiVersion.Compare(APIVersion20240101) >= 0 { + var output struct { + HTTPErrorResponse20240101 + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } - output.Code = http.StatusUnprocessableEntity - output.Message = e.Message - output.Payload.Reasons = e.Reasons + output.Code = ErrorCodeWeakPassword + output.Message = e.Message + output.Payload.Reasons = e.Reasons - if jsonErr := sendJSON(w, output.Code, output); jsonErr != nil { - handleError(jsonErr, w, r) + if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + + } else { + var output struct { + HTTPError + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } + + output.HTTPStatus = http.StatusUnprocessableEntity + output.ErrorCode = ErrorCodeWeakPassword + output.Message = e.Message + output.Payload.Reasons = e.Reasons + + if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } } case *HTTPError: - if e.Code >= http.StatusInternalServerError { + if e.HTTPStatus >= http.StatusInternalServerError { e.ErrorID = errorID // this will get us the stack trace too log.WithError(e.Cause()).Error(e.Error()) @@ -266,35 +238,76 @@ func handleError(err error, w http.ResponseWriter, r *http.Request) { log.WithError(e.Cause()).Info(e.Error()) } - // Provide better error messages for certain user-triggered Postgres errors. - if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { - if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil { - handleError(jsonErr, w, r) + if apiVersion.Compare(APIVersion20240101) >= 0 { + resp := HTTPErrorResponse20240101{ + Code: e.ErrorCode, + Message: e.Message, + } + + if resp.Code == "" { + if e.HTTPStatus == http.StatusInternalServerError { + resp.Code = ErrorCodeUnexpectedFailure + } else { + resp.Code = ErrorCodeUnknown + } + } + + if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + } else { + if e.ErrorCode == "" { + if e.HTTPStatus == http.StatusInternalServerError { + e.ErrorCode = ErrorCodeUnexpectedFailure + } else { + e.ErrorCode = ErrorCodeUnknown + } } - return - } - if jsonErr := sendJSON(w, e.Code, e); jsonErr != nil { - handleError(jsonErr, w, r) + // Provide better error messages for certain user-triggered Postgres errors. + if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { + if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + return + } + + if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } } + case *OAuthError: log.WithError(e.Cause()).Info(e.Error()) if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { - handleError(jsonErr, w, r) - } - case *OTPError: - log.WithError(e.Cause()).Info(e.Error()) - if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { - handleError(jsonErr, w, r) + HandleResponseError(jsonErr, w, r) } + case ErrorCause: - handleError(e.Cause(), w, r) + HandleResponseError(e.Cause(), w, r) + default: log.WithError(e).Errorf("Unhandled server error: %s", e.Error()) - // hide real error details from response to prevent info leaks - w.WriteHeader(http.StatusInternalServerError) - if _, writeErr := w.Write([]byte(`{"code":500,"msg":"Internal server error","error_id":"` + errorID + `"}`)); writeErr != nil { - log.WithError(writeErr).Error("Error writing generic error message") + + if apiVersion.Compare(APIVersion20240101) >= 0 { + resp := HTTPErrorResponse20240101{ + Code: ErrorCodeUnexpectedFailure, + Message: "Unexpected failure, please check server logs for more information", + } + + if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + } else { + httpError := HTTPError{ + HTTPStatus: http.StatusInternalServerError, + ErrorCode: ErrorCodeUnexpectedFailure, + Message: "Unexpected failure, please check server logs for more information", + } + + if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } } } } diff --git a/internal/api/errors_test.go b/internal/api/errors_test.go new file mode 100644 index 000000000..fc6135205 --- /dev/null +++ b/internal/api/errors_test.go @@ -0,0 +1,64 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHandleResponseErrorWithHTTPError(t *testing.T) { + examples := []struct { + HTTPError *HTTPError + APIVersion string + ExpectedBody string + }{ + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "", + ExpectedBody: "{\"code\":400,\"error_code\":\"" + ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", + }, + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2023-12-31", + ExpectedBody: "{\"code\":400,\"error_code\":\"" + ErrorCodeBadJSON + "\",\"msg\":\"Unable to parse JSON\"}", + }, + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeBadJSON + "\",\"message\":\"Unable to parse JSON\"}", + }, + { + HTTPError: &HTTPError{ + HTTPStatus: http.StatusBadRequest, + Message: "Uncoded failure", + }, + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeUnknown + "\",\"message\":\"Uncoded failure\"}", + }, + { + HTTPError: &HTTPError{ + HTTPStatus: http.StatusInternalServerError, + Message: "Unexpected failure", + }, + APIVersion: "2024-01-01", + ExpectedBody: "{\"code\":\"" + ErrorCodeUnexpectedFailure + "\",\"message\":\"Unexpected failure\"}", + }, + } + + for _, example := range examples { + rec := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com", nil) + require.NoError(t, err) + + if example.APIVersion != "" { + req.Header.Set(APIVersionHeaderName, example.APIVersion) + } + + HandleResponseError(example.HTTPError, rec, req) + + require.Equal(t, example.HTTPError.HTTPStatus, rec.Code) + require.Equal(t, example.ExpectedBody, rec.Body.String()) + } +} diff --git a/internal/api/external.go b/internal/api/external.go index 177d20059..8fa27f4ae 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -56,7 +56,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ p, err := a.Provider(ctx, providerType, scopes) if err != nil { - return "", badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return "", badRequestError(ErrorCodeValidationFailed, "Unsupported provider: %+v", err).WithInternalError(err) } inviteToken := query.Get("invite_token") @@ -64,7 +64,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ _, userErr := models.FindUserByConfirmationToken(db, inviteToken) if userErr != nil { if models.IsNotFoundError(userErr) { - return "", notFoundError(userErr.Error()) + return "", notFoundError(ErrorCodeUserNotFound, "User identified by token not found") } return "", internalServerError("Database error finding user").WithInternalError(userErr) } @@ -127,6 +127,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ } authURL := p.AuthCodeURL(tokenString, authUrlParams...) + return authURL, nil } @@ -196,9 +197,12 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re // if there's a non-empty FlowStateID we perform PKCE Flow if flowStateID := getFlowStateID(ctx); flowStateID != "" { flowState, err = models.FindFlowStateByID(a.db, flowStateID) - if err != nil { - return err + if models.IsNotFoundError(err) { + return unprocessableEntityError(ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err) + } else if err != nil { + return internalServerError("Failed to find flow state").WithInternalError(err) } + } var user *models.User @@ -300,7 +304,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. case models.CreateAccount: if config.DisableSignup { - return nil, forbiddenError("Signups not allowed for this instance") + return nil, unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{ @@ -347,14 +351,14 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } case models.MultipleAccounts: - return nil, internalServerError(fmt.Sprintf("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain)) + return nil, internalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain) default: - return nil, internalServerError(fmt.Sprintf("Unknown automatic linking decision: %v", decision.Decision)) + return nil, internalServerError("Unknown automatic linking decision: %v", decision.Decision) } if user.IsBanned() { - return nil, unauthorizedError("User is unauthorized") + return nil, forbiddenError(ErrorCodeUserBanned, "User is banned") } if !user.IsConfirmed() { @@ -383,7 +387,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. externalURL := getExternalHost(ctx) if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { - return nil, tooManyRequestsError("For security purposes, you can only request this once every minute") + return nil, tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "For security purposes, you can only request this once every minute") } return nil, internalServerError("Error sending confirmation mail").WithInternalError(terr) } @@ -391,9 +395,9 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } if !config.Mailer.AllowUnverifiedEmailSignIns { if emailConfirmationSent { - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) + return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) } - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) + return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) } } } else { @@ -411,7 +415,7 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p user, err := models.FindUserByConfirmationToken(tx, inviteToken) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()) + return nil, notFoundError(ErrorCodeInviteNotFound, "Invite not found") } return nil, internalServerError("Database error finding user").WithInternalError(err) } @@ -427,7 +431,7 @@ func (a *API) processInvite(r *http.Request, tx *storage.Connection, userData *p } if emailData == nil { - return nil, badRequestError("Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) + return nil, badRequestError(ErrorCodeValidationFailed, "Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) } var identityData map[string]interface{} @@ -480,8 +484,11 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont _, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) { return []byte(config.JWT.Secret), nil }) - if err != nil || claims.Provider == "" { - return nil, badRequestError("OAuth state is invalid: %v", err) + if err != nil { + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) + } + if claims.Provider == "" { + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)") } if claims.InviteToken != "" { ctx = withInviteToken(ctx, claims.InviteToken) @@ -495,12 +502,12 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont if claims.LinkingTargetID != "" { linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID) if err != nil { - return nil, badRequestError("invalid target user id") + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)") } u, err := models.FindUserByID(a.db, linkingTargetUserID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("Linking target user not found") + return nil, unprocessableEntityError(ErrorCodeUserNotFound, "Linking target user not found") } return nil, internalServerError("Database error loading user").WithInternalError(err) } @@ -591,12 +598,18 @@ func (a *API) redirectErrors(handler apiHandler, w http.ResponseWriter, r *http. func getErrorQueryString(err error, errorID string, log logrus.FieldLogger, q url.Values) *url.Values { switch e := err.(type) { case *HTTPError: - if str, ok := oauthErrorMap[e.Code]; ok { + if e.ErrorCode == ErrorCodeSignupDisabled { + q.Set("error", "access_denied") + } else if e.ErrorCode == ErrorCodeUserBanned { + q.Set("error", "access_denied") + } else if e.ErrorCode == ErrorCodeProviderEmailNeedsVerification { + q.Set("error", "access_denied") + } else if str, ok := oauthErrorMap[e.HTTPStatus]; ok { q.Set("error", str) } else { q.Set("error", "server_error") } - if e.Code >= http.StatusInternalServerError { + if e.HTTPStatus >= http.StatusInternalServerError { e.ErrorID = errorID // this will get us the stack trace too log.WithError(e.Cause()).Error(e.Error()) @@ -604,7 +617,7 @@ func getErrorQueryString(err error, errorID string, log logrus.FieldLogger, q ur log.WithError(e.Cause()).Info(e.Error()) } q.Set("error_description", e.Message) - q.Set("error_code", strconv.Itoa(e.Code)) + q.Set("error_code", strconv.Itoa(e.HTTPStatus)) case *OAuthError: q.Set("error", e.Err) q.Set("error_description", e.Description) diff --git a/internal/api/external_figma_test.go b/internal/api/external_figma_test.go index 56d2f478d..bd7a8c29c 100644 --- a/internal/api/external_figma_test.go +++ b/internal/api/external_figma_test.go @@ -260,5 +260,5 @@ func (ts *ExternalTestSuite) TestSignupExternalFigmaErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "figma", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_fly_test.go b/internal/api/external_fly_test.go index c469f2900..3c33c53e2 100644 --- a/internal/api/external_fly_test.go +++ b/internal/api/external_fly_test.go @@ -260,5 +260,5 @@ func (ts *ExternalTestSuite) TestSignupExternalFlyErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "fly", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_github_test.go b/internal/api/external_github_test.go index b3ad58440..f6f4334d7 100644 --- a/internal/api/external_github_test.go +++ b/internal/api/external_github_test.go @@ -276,7 +276,7 @@ func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenVerifiedFalse() { u := performAuthorization(ts, "github", code, "") - assertAuthorizationFailure(ts, u, "Unverified email with github. A confirmation email has been sent to your github email", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "Unverified email with github. A confirmation email has been sent to your github email", "access_denied", "") } func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenUserBanned() { @@ -296,5 +296,5 @@ func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "github", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_kakao_test.go b/internal/api/external_kakao_test.go index 7882e1dce..cd2bd2b29 100644 --- a/internal/api/external_kakao_test.go +++ b/internal/api/external_kakao_test.go @@ -214,7 +214,7 @@ func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenVerifiedFalse() { u := performAuthorization(ts, "kakao", code, "") - assertAuthorizationFailure(ts, u, "Unverified email with kakao. A confirmation email has been sent to your kakao email", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "Unverified email with kakao. A confirmation email has been sent to your kakao email", "access_denied", "") } func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenUserBanned() { @@ -234,5 +234,5 @@ func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "kakao", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "access_denied", "") } diff --git a/internal/api/external_oauth.go b/internal/api/external_oauth.go index 5352299ac..6c0972ea8 100644 --- a/internal/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -2,6 +2,7 @@ package api import ( "context" + "fmt" "net/http" "net/url" @@ -30,7 +31,7 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con } if state == "" { - return nil, badRequestError("OAuth state parameter missing") + return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing") } ctx := r.Context() @@ -60,12 +61,12 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s oauthCode := rq.Get("code") if oauthCode == "" { - return nil, badRequestError("Authorization code missing") + return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing") } oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } log := observability.GetLogEntry(r) @@ -107,7 +108,7 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s func (a *API) oAuth1Callback(ctx context.Context, providerType string) (*OAuthProviderData, error) { oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } oauthToken := getRequestToken(ctx) oauthVerifier := getOAuthVerifier(ctx) @@ -145,6 +146,6 @@ func (a *API) OAuthProvider(ctx context.Context, name string) (provider.OAuthPro case provider.OAuthProvider: return p, nil default: - return nil, badRequestError("Provider can not be used for OAuth") + return nil, fmt.Errorf("Provider %v cannot be used for OAuth", name) } } diff --git a/internal/api/helpers.go b/internal/api/helpers.go index ea4102f2e..d771dca40 100644 --- a/internal/api/helpers.go +++ b/internal/api/helpers.go @@ -109,7 +109,7 @@ func retrieveRequestParams[A RequestParams](r *http.Request, params *A) error { return internalServerError("Could not read body into byte slice").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read request body: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not parse request body as JSON: %v", err) } return nil } diff --git a/internal/api/helpers_test.go b/internal/api/helpers_test.go index ec5812e09..15f9ce4d6 100644 --- a/internal/api/helpers_test.go +++ b/internal/api/helpers_test.go @@ -16,12 +16,12 @@ func TestIsValidCodeChallenge(t *testing.T) { { challenge: "invalid", isValid: false, - expectedError: badRequestError("code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), + expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), }, { challenge: "codechallengecontainsinvalidcharacterslike@$^&*", isValid: false, - expectedError: badRequestError("code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), + expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), }, { challenge: "validchallengevalidchallengevalidchallengevalidchallenge", @@ -56,12 +56,12 @@ func TestIsValidPKCEParmas(t *testing.T) { { challengeMethod: "test", challenge: "", - expected: badRequestError(InvalidPKCEParamsErrorMessage), + expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), }, { challengeMethod: "", challenge: "test", - expected: badRequestError(InvalidPKCEParamsErrorMessage), + expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), }, } diff --git a/internal/api/hooks.go b/internal/api/hooks.go index f3a9e11f1..5368339d8 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -80,8 +80,8 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -106,8 +106,8 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -132,8 +132,8 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -146,8 +146,8 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out } httpError := &HTTPError{ - Code: httpCode, - Message: err.Error(), + HTTPStatus: httpCode, + Message: err.Error(), } return httpError diff --git a/internal/api/identity.go b/internal/api/identity.go index f47708555..858810f70 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -2,7 +2,6 @@ package api import ( "context" - "fmt" "net/http" "github.com/fatih/structs" @@ -20,22 +19,22 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { claims := getClaims(ctx) if claims == nil { - return badRequestError("Could not read claims") - } - - aud := a.requestAud(ctx, r) - if aud != claims.Audience { - return badRequestError("Token audience doesn't match request audience") + return internalServerError("Could not read claims") } identityID, err := uuid.FromString(chi.URLParam(r, "identity_id")) if err != nil { - return badRequestError("identity_id must be an UUID") + return notFoundError(ErrorCodeValidationFailed, "identity_id must be an UUID") + } + + aud := a.requestAud(ctx, r) + if aud != claims.Audience { + return forbiddenError(ErrorCodeUnexpectedAudience, "Token audience doesn't match request audience") } user := getUser(ctx) if len(user.Identities) <= 1 { - return badRequestError("User must have at least 1 identity after unlinking") + return unprocessableEntityError(ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking") } var identityToBeDeleted *models.Identity for i := range user.Identities { @@ -46,7 +45,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { } } if identityToBeDeleted == nil { - return badRequestError("Identity doesn't exist") + return unprocessableEntityError(ErrorCodeIdentityNotFound, "Identity doesn't exist") } err = a.db.Transaction(func(tx *storage.Connection) error { @@ -73,7 +72,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { default: if terr := user.UpdateUserEmailFromIdentities(tx); terr != nil { if models.IsUniqueConstraintViolatedError(terr) { - return forbiddenError("Unable to unlink identity due to email conflict").WithInternalError(terr) + return unprocessableEntityError(ErrorCodeEmailConflictIdentityNotDeletable, "Unable to unlink identity due to email conflict").WithInternalError(terr) } return internalServerError("Database error updating user email").WithInternalError(terr) } @@ -117,9 +116,9 @@ func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *stora } if identity != nil { if identity.UserID == targetUser.ID { - return nil, badRequestError("Identity is already linked") + return nil, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked") } - return nil, badRequestError("Identity is already linked to another user") + return nil, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked to another user") } if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil { return nil, terr @@ -128,7 +127,7 @@ func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *stora if targetUser.GetEmail() == "" { if terr := targetUser.UpdateUserEmailFromIdentities(tx); terr != nil { if models.IsUniqueConstraintViolatedError(terr) { - return nil, badRequestError(DuplicateEmailMsg) + return nil, badRequestError(ErrorCodeEmailExists, DuplicateEmailMsg) } return nil, terr } @@ -138,10 +137,10 @@ func (a *API) linkIdentityToUser(r *http.Request, ctx context.Context, tx *stora externalURL := getExternalHost(ctx) if terr := sendConfirmation(tx, targetUser, mailer, a.config.SMTP.MaxFrequency, referrer, externalURL, a.config.Mailer.OtpLength, models.ImplicitFlow); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { - return nil, tooManyRequestsError("For security purposes, you can only request this once every minute") + return nil, tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "For security purposes, you can only request this once every minute") } } - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) + return nil, storage.NewCommitWithError(unprocessableEntityError(ErrorCodeEmailNotConfirmed, "Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType)) } if terr := targetUser.Confirm(tx); terr != nil { return nil, terr diff --git a/internal/api/identity_test.go b/internal/api/identity_test.go index 7f70af416..2b193cd21 100644 --- a/internal/api/identity_test.go +++ b/internal/api/identity_test.go @@ -101,7 +101,7 @@ func (ts *IdentityTestSuite) TestLinkIdentityToUser() { }, } u, err = ts.API.linkIdentityToUser(r, ctx, ts.API.db, testExistingUserData, "email") - require.ErrorIs(ts.T(), err, badRequestError("Identity is already linked")) + require.ErrorIs(ts.T(), err, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked")) require.Nil(ts.T(), u) } @@ -122,13 +122,13 @@ func (ts *IdentityTestSuite) TestUnlinkIdentityError() { desc: "User must have at least 1 identity after unlinking", user: userWithOneIdentity, identityId: userWithOneIdentity.Identities[0].ID, - expectedError: badRequestError("User must have at least 1 identity after unlinking"), + expectedError: unprocessableEntityError(ErrorCodeSingleIdentityNotDeletable, "User must have at least 1 identity after unlinking"), }, { desc: "Identity doesn't exist", user: userWithTwoIdentities, identityId: uuid.Must(uuid.NewV4()), - expectedError: badRequestError("Identity doesn't exist"), + expectedError: unprocessableEntityError(ErrorCodeIdentityNotFound, "Identity doesn't exist"), }, } @@ -141,7 +141,7 @@ func (ts *IdentityTestSuite) TestUnlinkIdentityError() { w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), c.expectedError.Code, w.Code) + require.Equal(ts.T(), c.expectedError.HTTPStatus, w.Code) var data HTTPError require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) diff --git a/internal/api/invite.go b/internal/api/invite.go index 45d94878c..2e912b79c 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -42,7 +42,7 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { err = db.Transaction(func(tx *storage.Connection) error { if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } else { signupParams := SignupParams{ diff --git a/internal/api/invite_test.go b/internal/api/invite_test.go index 466682028..c525e8747 100644 --- a/internal/api/invite_test.go +++ b/internal/api/invite_test.go @@ -162,7 +162,7 @@ func (ts *InviteTestSuite) TestInvite_WithoutAccess() { w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) + assert.Equal(ts.T(), http.StatusUnauthorized, w.Code) // 401 OK because the invite request above has no Authorization header } func (ts *InviteTestSuite) TestVerifyInvite() { diff --git a/internal/api/logout.go b/internal/api/logout.go index ad95b22a4..cd1394eda 100644 --- a/internal/api/logout.go +++ b/internal/api/logout.go @@ -36,7 +36,7 @@ func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { scope = LogoutOthers default: - return badRequestError(fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) + return badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) } } diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index ddd3dba37..c0aaded7a 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -24,7 +25,7 @@ type MagicLinkParams struct { func (p *MagicLinkParams) Validate() error { if p.Email == "" { - return unprocessableEntityError("Password recovery requires an email") + return unprocessableEntityError(ErrorCodeValidationFailed, "Password recovery requires an email") } var err error p.Email, err = validateEmail(p.Email) @@ -44,14 +45,14 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } params := &MagicLinkParams{} jsonDecoder := json.NewDecoder(r.Body) err := jsonDecoder.Decode(params) if err != nil { - return badRequestError("Could not read verification params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read verification params: %v", err).WithInternalError(err) } if err := params.Validate(); err != nil { @@ -82,7 +83,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { // Sign them up with temporary password. password, err := password.Generate(64, 10, 1, false, true) if err != nil { - internalServerError("error creating user").WithInternalError(err) + return internalServerError("error creating user").WithInternalError(err) } signUpParams := &SignupParams{ @@ -94,7 +95,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) } r.Body = io.NopCloser(strings.NewReader(string(newBodyContent))) r.ContentLength = int64(len(string(newBodyContent))) @@ -113,7 +115,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { } metadata, err := json.Marshal(newBodyContent) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) } r.Body = io.NopCloser(bytes.NewReader(metadata)) return a.MagicLink(w, r) @@ -143,7 +146,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Error sending magic link").WithInternalError(err) } diff --git a/internal/api/mail.go b/internal/api/mail.go index 0ab561ab7..448f5a038 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -66,14 +66,17 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) if err != nil { if models.IsNotFoundError(err) { - if params.Type == magicLinkVerification { + switch params.Type { + case magicLinkVerification: params.Type = signupVerification params.Password, err = password.Generate(64, 10, 1, false, true) if err != nil { - return internalServerError("error creating user").WithInternalError(err) + // password generation must always succeed + panic(err) } - } else if params.Type == recoveryVerification || params.Type == "email_change_current" || params.Type == "email_change_new" { - return notFoundError(err.Error()) + + default: + return notFoundError(ErrorCodeUserNotFound, "User with this email not found") } } else { return internalServerError("Database error finding user").WithInternalError(err) @@ -84,7 +87,8 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { now := time.Now() otp, err := crypto.GenerateOtp(config.Mailer.OtpLength) if err != nil { - return err + // OTP generation must always succeed + panic(err) } hashedToken := crypto.GenerateTokenHash(params.Email, otp) @@ -118,11 +122,14 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } user.RecoveryToken = hashedToken user.RecoverySentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + terr = tx.UpdateOnly(user, "recovery_token", "recovery_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for recovery") + } case inviteVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } else { signupParams := &SignupParams{ @@ -162,11 +169,14 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { user.ConfirmationToken = hashedToken user.ConfirmationSentAt = &now user.InvitedAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for invite") + } case signupVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } if err := user.UpdateUserMetaData(tx, params.Data); err != nil { return internalServerError("Database error updating user").WithInternalError(err) @@ -191,19 +201,22 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } user.ConfirmationToken = hashedToken user.ConfirmationSentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for confirmation") + } case "email_change_current", "email_change_new": if !config.Mailer.SecureEmailChangeEnabled && params.Type == "email_change_current" { - return unprocessableEntityError("Enable secure email change to generate link for current email") + return badRequestError(ErrorCodeValidationFailed, "Enable secure email change to generate link for current email") } params.NewEmail, terr = validateEmail(params.NewEmail) if terr != nil { - return unprocessableEntityError("The new email address provided is invalid") + return terr } if duplicateUser, terr := models.IsDuplicatedEmail(tx, params.NewEmail, user.Aud, user); terr != nil { return internalServerError("Database error checking email").WithInternalError(terr) } else if duplicateUser != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } now := time.Now() user.EmailChangeSentAt = &now @@ -214,9 +227,12 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } else if params.Type == "email_change_new" { user.EmailChangeTokenNew = crypto.GenerateTokenHash(params.NewEmail, otp) } - terr = errors.Wrap(tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status"), "Database error updating user for email change") + terr = tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for email change") + } default: - return badRequestError("Invalid email action link type requested: %v", params.Type) + return badRequestError(ErrorCodeValidationFailed, "Invalid email action link type requested: %v", params.Type) } if terr != nil { @@ -255,7 +271,8 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail oldToken := u.ConfirmationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.ConfirmationToken = addFlowPrefixToToken(token, flowType) @@ -265,7 +282,12 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail return errors.Wrap(err, "Error sending confirmation email") } u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for confirmation") + } + + return nil } func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, externalURL *url.URL, otpLength int) error { @@ -273,7 +295,8 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re oldToken := u.ConfirmationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.ConfirmationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) now := time.Now() @@ -283,7 +306,12 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re } u.InvitedAt = &now u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for invite") + } + + return nil } func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { @@ -295,7 +323,8 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile oldToken := u.RecoveryToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) @@ -305,7 +334,12 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile return errors.Wrap(err, "Error sending recovery email") } u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for recovery") + } + + return nil } func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, otpLength int) error { @@ -317,7 +351,8 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma oldToken := u.ReauthenticationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) now := time.Now() @@ -326,7 +361,12 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma return errors.Wrap(err, "Error sending reauthentication email") } u.ReauthenticationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"), "Database error updating user for reauthentication") + err = tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for reauthentication") + } + + return nil } func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { @@ -339,7 +379,8 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile oldToken := u.RecoveryToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) @@ -350,7 +391,12 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile return errors.Wrap(err, "Error sending magic link email") } u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for recovery") + } + + return nil } // sendEmailChange sends out an email change token to the new email. @@ -361,7 +407,8 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu } otpNew, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.EmailChange = email token := crypto.GenerateTokenHash(u.EmailChange, otpNew) @@ -371,7 +418,8 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu if config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { otpCurrent, err = crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } currentToken := crypto.GenerateTokenHash(u.GetEmail(), otpCurrent) u.EmailChangeTokenCurrent = addFlowPrefixToToken(currentToken, flowType) @@ -384,22 +432,28 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu } u.EmailChangeSentAt = &now - return errors.Wrap(tx.UpdateOnly( + err = tx.UpdateOnly( u, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status", - ), "Database error updating user for email change") + ) + + if err != nil { + return errors.Wrap(err, "Database error updating user for email change") + } + + return nil } func validateEmail(email string) (string, error) { if email == "" { - return "", unprocessableEntityError("An email address is required") + return "", badRequestError(ErrorCodeValidationFailed, "An email address is required") } if err := checkmail.ValidateFormat(email); err != nil { - return "", unprocessableEntityError("Unable to validate email address: " + err.Error()) + return "", badRequestError(ErrorCodeValidationFailed, "Unable to validate email address: "+err.Error()) } return strings.ToLower(email), nil } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 6bdb3e596..3919cb781 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -73,11 +73,11 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { return err } - issuer := "" if params.FactorType != models.TOTP { - return badRequestError("factor_type needs to be totp") + return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be totp") } + issuer := "" if params.Issuer == "" { u, err := url.ParseRequestURI(config.SiteURL) if err != nil { @@ -103,15 +103,15 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if factorCount >= int(config.MFA.MaxEnrolledFactors) { - return forbiddenError("Enrolled factors exceed allowed limit, unenroll to continue") + return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") } if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return forbiddenError("Maximum number of verified factors reached, unenroll to continue") + return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") } if numVerifiedFactors > 0 && !session.IsAAL2() { - return forbiddenError("AAL2 required to enroll a new factor") + return forbiddenError(ErrorCodeInsufficientAAL, "AAL2 required to enroll a new factor") } key, err := totp.Generate(totp.GenerateOpts{ @@ -138,7 +138,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { if terr := tx.Create(factor); terr != nil { pgErr := utilities.NewPostgresError(terr) if pgErr.IsUniqueConstraintViolated() { - return badRequestError(fmt.Sprintf("a factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) + return unprocessableEntityError(ErrorCodeMFAFactorNameConflict, fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) } return terr @@ -216,20 +216,20 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { challenge, err := models.FindChallengeByID(a.db, params.ChallengeID) if err != nil && models.IsNotFoundError(err) { - return notFoundError(err.Error()) + return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found") } else if err != nil { return internalServerError("Database error finding Challenge").WithInternalError(err) } if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { - return badRequestError("Challenge and verify IP addresses mismatch") + return unprocessableEntityError(ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch") } if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { if err := a.db.Destroy(challenge); err != nil { return internalServerError("Database error deleting challenge").WithInternalError(err) } - return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID) + return unprocessableEntityError(ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) } valid := totp.Validate(params.Code, factor.Secret) @@ -257,11 +257,11 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { output.Message = hooks.DefaultMFAHookRejectionMessage } - return forbiddenError(output.Message) + return forbiddenError(ErrorCodeMFAVerificationRejected, output.Message) } } if !valid { - return badRequestError("Invalid TOTP code entered") + return unprocessableEntityError(ErrorCodeMFAVerificationFailed, "Invalid TOTP code entered") } var token *AccessTokenResponse @@ -322,7 +322,7 @@ func (a *API) UnenrollFactor(w http.ResponseWriter, r *http.Request) error { } if factor.IsVerified() && !session.IsAAL2() { - return badRequestError("AAL2 required to unenroll verified factor") + return unprocessableEntityError(ErrorCodeInsufficientAAL, "AAL2 required to unenroll verified factor") } if !factor.IsOwnedBy(user) { return internalServerError(InvalidFactorOwnerErrorMessage) diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index bb3c91968..39ec9f2cc 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -168,16 +168,15 @@ func (ts *MFATestSuite) TestDuplicateEnrollsReturnExpectedMessage() { issuer := "https://issuer.com" token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) _ = performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, http.StatusOK) - response := performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, http.StatusBadRequest) + response := performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, http.StatusUnprocessableEntity) var errorResponse HTTPError err := json.NewDecoder(response.Body).Decode(&errorResponse) require.NoError(ts.T(), err) // Convert the response body to a string and check for the expected error message - expectedErrorMessage := fmt.Sprintf("a factor with the friendly name %q for this user likely already exists", friendlyName) + expectedErrorMessage := fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", friendlyName) require.Contains(ts.T(), errorResponse.Message, expectedErrorMessage) - } func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { @@ -226,13 +225,13 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { desc: "Invalid: Valid code and expired challenge", validChallenge: false, validCode: true, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Invalid: Invalid code and valid challenge ", validChallenge: true, validCode: false, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Valid /verify request", @@ -309,7 +308,7 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() { { desc: "Verified Factor: AAL1", isAAL2: false, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Verified Factor: AAL2, Success", diff --git a/internal/api/middleware.go b/internal/api/middleware.go index ab72e32c0..6a6d68a25 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -66,7 +66,7 @@ func (a *API) limitHandler(lmt *limiter.Limiter) middlewareHandler { } else { err := tollbooth.LimitByKeys(lmt, []string{key}) if err != nil { - return c, httpError(http.StatusTooManyRequests, "Rate limit exceeded") + return c, tooManyRequestsError(ErrorCodeOverRequestRateLimit, "Request rate limit reached") } } } @@ -101,7 +101,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { } if err := retrieveRequestParams(req, &requestBody); err != nil { - return c, badRequestError("Error invalid request body").WithInternalError(err) + return c, err } if shouldRateLimitEmail { @@ -112,7 +112,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { 1, attribute.String("path", req.URL.Path), ) - return c, httpError(http.StatusTooManyRequests, "Email rate limit exceeded") + return c, tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "Email rate limit exceeded") } } } @@ -120,7 +120,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { if shouldRateLimitPhone { if requestBody.Phone != "" { if err := tollbooth.LimitByKeys(phoneLimiter, []string{"phone_functions"}); err != nil { - return c, httpError(http.StatusTooManyRequests, "Sms rate limit exceeded") + return c, tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, "SMS rate limit exceeded") } } } @@ -151,7 +151,7 @@ func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (co config := a.config if !config.External.Email.Enabled { - return nil, badRequestError("Email logins are disabled") + return nil, badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } return ctx, nil @@ -178,8 +178,7 @@ func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.C } if !verificationResult.Success { - return nil, badRequestError("captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) - + return nil, badRequestError(ErrorCodeCaptchaFailed, "captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) } return ctx, nil @@ -223,7 +222,7 @@ func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (con func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.SAML.Enabled { - return nil, notFoundError("SAML 2.0 is disabled") + return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled") } return ctx, nil } @@ -231,7 +230,7 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.Security.ManualLinkingEnabled { - return nil, notFoundError("Manual linking is disabled") + return nil, notFoundError(ErrorCodeManualLinkingDisabled, "Manual linking is disabled") } return ctx, nil } diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index c532a50ef..e591121f8 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -176,7 +176,7 @@ func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() { w := httptest.NewRecorder() _, err := ts.API.verifyCaptcha(w, req) - require.Equal(ts.T(), c.expectedCode, err.(*HTTPError).Code) + require.Equal(ts.T(), c.expectedCode, err.(*HTTPError).HTTPStatus) require.Equal(ts.T(), c.expectedMsg, err.(*HTTPError).Message) }) } @@ -201,8 +201,8 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() { }, }, { - desc: "Sms rate limit exceeded", - expectedErrorMsg: "429: Sms rate limit exceeded", + desc: "SMS rate limit exceeded", + expectedErrorMsg: "429: SMS rate limit exceeded", requestBody: map[string]interface{}{ "phone": "+1233456789", }, @@ -269,7 +269,7 @@ func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() { { desc: "SAML not enabled", isEnabled: false, - expectedErr: notFoundError("SAML 2.0 is disabled"), + expectedErr: notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled"), }, { desc: "SAML enabled", diff --git a/internal/api/otp.go b/internal/api/otp.go index 0e437faa1..99b7bae32 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -34,10 +34,10 @@ type SmsParams struct { func (p *OtpParams) Validate() error { if p.Email != "" && p.Phone != "" { - return badRequestError("Only an email address or phone number should be provided") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided") } if p.Email != "" && p.Channel != "" { - return badRequestError("Channel should only be specified with Phone OTP") + return badRequestError(ErrorCodeValidationFailed, "Channel should only be specified with Phone OTP") } if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { return err @@ -47,7 +47,7 @@ func (p *OtpParams) Validate() error { func (p *SmsParams) Validate(smsProvider string) error { if p.Phone != "" && !sms_provider.IsValidMessageChannel(p.Channel, smsProvider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } var err error @@ -80,7 +80,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { } if ok, err := a.shouldCreateUser(r, params); !ok { - return badRequestError("Signups not allowed for otp") + return unprocessableEntityError(ErrorCodeOTPDisabled, "Signups not allowed for otp") } else if err != nil { return err } @@ -91,7 +91,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { return a.SmsOtp(w, r) } - return otpError("unsupported_otp_type", "") + return badRequestError(ErrorCodeValidationFailed, "One of email or phone must be set") } type SmsOtpResponse struct { @@ -105,7 +105,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Phone.Enabled { - return badRequestError("Unsupported phone provider") + return badRequestError(ErrorCodePhoneProviderDisabled, "Unsupported phone provider") } var err error @@ -141,7 +141,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { // Sign them up with temporary password. password, err := password.Generate(64, 10, 1, false, true) if err != nil { - internalServerError("error creating user").WithInternalError(err) + return internalServerError("error creating user").WithInternalError(err) } signUpParams := &SignupParams{ @@ -152,7 +152,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must be marshallable + panic(err) } r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) @@ -170,7 +171,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must be marshallable + panic(err) } r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) return a.SmsOtp(w, r) @@ -191,11 +193,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Unable to get SMS provider").WithInternalError(err) } mID, serr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) if serr != nil { - return badRequestError("Error sending sms OTP: %v", serr) + return badRequestError(ErrorCodeSMSSendFailed, "Error sending sms OTP: %v", serr).WithInternalError(serr) } messageID = mID return nil diff --git a/internal/api/otp_test.go b/internal/api/otp_test.go index be3b18114..c72fbc361 100644 --- a/internal/api/otp_test.go +++ b/internal/api/otp_test.go @@ -80,8 +80,9 @@ func (ts *OtpTestSuite) TestOtpPKCE() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "PKCE flow requires code_challenge_method and code_challenge", + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "PKCE flow requires code_challenge_method and code_challenge", }, }, }, @@ -98,8 +99,9 @@ func (ts *OtpTestSuite) TestOtpPKCE() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "PKCE flow requires code_challenge_method and code_challenge", + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "PKCE flow requires code_challenge_method and code_challenge", }, }, }, @@ -115,10 +117,10 @@ func (ts *OtpTestSuite) TestOtpPKCE() { code int response map[string]interface{} }{ - http.StatusBadRequest, + http.StatusInternalServerError, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Error sending sms:", + "code": float64(http.StatusInternalServerError), + "msg": "Unable to get SMS provider", }, }, }, @@ -182,8 +184,9 @@ func (ts *OtpTestSuite) TestOtp() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Only an email address or phone number should be provided", + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": "Only an email address or phone number should be provided", }, }, }, @@ -200,8 +203,9 @@ func (ts *OtpTestSuite) TestOtp() { }{ http.StatusBadRequest, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": InvalidChannelError, + "code": float64(http.StatusBadRequest), + "error_code": ErrorCodeValidationFailed, + "msg": InvalidChannelError, }, }, }, @@ -244,15 +248,16 @@ func (ts *OtpTestSuite) TestNoSignupsForOtp() { ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusBadRequest, w.Code) + require.Equal(ts.T(), http.StatusUnprocessableEntity, w.Code) data := make(map[string]interface{}) require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) // response should be empty assert.Equal(ts.T(), data, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Signups not allowed for otp", + "code": float64(http.StatusUnprocessableEntity), + "error_code": ErrorCodeOTPDisabled, + "msg": "Signups not allowed for otp", }) } diff --git a/internal/api/phone.go b/internal/api/phone.go index cea94a944..f85caa6fd 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -24,7 +24,7 @@ const ( func validatePhone(phone string) (string, error) { phone = formatPhoneNumber(phone) if isValid := validateE164Format(phone); !isValid { - return "", unprocessableEntityError("Invalid phone number format (E.164 required)") + return "", badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") } return phone, nil } diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go index 3c543634d..09810e288 100644 --- a/internal/api/phone_test.go +++ b/internal/api/phone_test.go @@ -177,8 +177,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "password": "testpassword", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending confirmation sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -190,8 +190,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "phone": "123456789", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -203,8 +203,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "phone": "111111111", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -214,8 +214,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { header: "", body: nil, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, } @@ -244,7 +244,12 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { require.Equal(ts.T(), c.expected["code"], w.Code) body := w.Body.String() - require.True(ts.T(), strings.Contains(body, c.expected["message"].(string))) + require.True(ts.T(), + strings.Contains(body, "Unable to get SMS provider") || + strings.Contains(body, "Error finding SMS provider") || + strings.Contains(body, "Failed to get SMS provider"), + "unexpected body message %q", body, + ) }) } } diff --git a/internal/api/pkce.go b/internal/api/pkce.go index 56ac1acb9..5ac75668d 100644 --- a/internal/api/pkce.go +++ b/internal/api/pkce.go @@ -21,9 +21,9 @@ func isValidCodeChallenge(codeChallenge string) (bool, error) { // See RFC 7636 Section 4.2: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 switch codeChallengeLength := len(codeChallenge); { case codeChallengeLength < MinCodeChallengeLength, codeChallengeLength > MaxCodeChallengeLength: - return false, badRequestError("code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) + return false, badRequestError(ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) case !codeChallengePattern.MatchString(codeChallenge): - return false, badRequestError("code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") + return false, badRequestError(ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") default: return true, nil } @@ -41,7 +41,7 @@ func addFlowPrefixToToken(token string, flowType models.FlowType) string { func issueAuthCode(tx *storage.Connection, user *models.User, authenticationMethod models.AuthenticationMethod) (string, error) { flowState, err := models.FindFlowStateByUserID(tx, user.ID.String(), authenticationMethod) if err != nil && models.IsNotFoundError(err) { - return "", badRequestError("No valid flow state found for user.") + return "", unprocessableEntityError(ErrorCodeFlowStateNotFound, "No valid flow state found for user.") } else if err != nil { return "", err } @@ -63,7 +63,7 @@ func isImplicitFlow(flowType models.FlowType) bool { func validatePKCEParams(codeChallengeMethod, codeChallenge string) error { switch true { case (codeChallenge == "") != (codeChallengeMethod == ""): - return badRequestError(InvalidPKCEParamsErrorMessage) + return badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage) case codeChallenge != "": if valid, err := isValidCodeChallenge(codeChallenge); !valid { return err diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go index b62a51fc0..84b080070 100644 --- a/internal/api/reauthenticate.go +++ b/internal/api/reauthenticate.go @@ -23,16 +23,16 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { email, phone := user.GetEmail(), user.GetPhone() if email == "" && phone == "" { - return unprocessableEntityError("Reauthentication requires the user to have an email or a phone number") + return badRequestError(ErrorCodeValidationFailed, "Reauthentication requires the user to have an email or a phone number") } if email != "" { if !user.IsConfirmed() { - return badRequestError("Please verify your email first.") + return unprocessableEntityError(ErrorCodeEmailNotConfirmed, "Please verify your email first.") } } else if phone != "" { if !user.IsPhoneConfirmed() { - return badRequestError("Please verify your phone first.") + return unprocessableEntityError(ErrorCodePhoneNotConfirmed, "Please verify your phone first.") } } @@ -47,7 +47,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { } else if phone != "" { smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Failed to get SMS provider").WithInternalError(terr) } mID, err := a.sendPhoneConfirmation(tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) if err != nil { @@ -60,7 +60,12 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + reason := ErrorCodeOverEmailSendRateLimit + if phone != "" { + reason = ErrorCodeOverSMSSendRateLimit + } + + return tooManyRequestsError(reason, "For security purposes, you can only request this once every 60 seconds") } return err } @@ -77,7 +82,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { // verifyReauthentication checks if the nonce provided is valid func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, config *conf.GlobalConfiguration, user *models.User) error { if user.ReauthenticationToken == "" || user.ReauthenticationSentAt == nil { - return badRequestError(InvalidNonceMessage) + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, InvalidNonceMessage) } var isValid bool if user.GetEmail() != "" { @@ -87,7 +92,7 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi if config.Sms.IsTwilioVerifyProvider() { smsProvider, _ := sms_provider.GetSmsProvider(*config) if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(string(user.Phone), nonce); err != nil { - return expiredTokenError("Token has expired or is invalid").WithInternalError(err) + return forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } return nil } else { @@ -95,10 +100,10 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi isValid = isOtpValid(tokenHash, user.ReauthenticationToken, user.ReauthenticationSentAt, config.Sms.OtpExp) } } else { - return unprocessableEntityError("Reauthentication requires an email or a phone number") + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, "Reauthentication requires an email or a phone number") } if !isValid { - return badRequestError(InvalidNonceMessage) + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, InvalidNonceMessage) } if err := user.ConfirmReauthentication(tx); err != nil { return internalServerError("Error during reauthentication").WithInternalError(err) diff --git a/internal/api/recover.go b/internal/api/recover.go index 77e3c068d..a3201852d 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -18,7 +18,7 @@ type RecoverParams struct { func (p *RecoverParams) Validate() error { if p.Email == "" { - return unprocessableEntityError("Password recovery requires an email") + return badRequestError(ErrorCodeValidationFailed, "Password recovery requires an email") } var err error if p.Email, err = validateEmail(p.Email); err != nil { @@ -73,7 +73,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Unable to process request").WithInternalError(err) } diff --git a/internal/api/resend.go b/internal/api/resend.go index cb8c4da24..fdad38c43 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -25,22 +25,22 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er break default: // type does not match one of the above - return badRequestError("Missing one of these types: signup, email_change, sms, phone_change") + return badRequestError(ErrorCodeValidationFailed, "Missing one of these types: signup, email_change, sms, phone_change") } if p.Email == "" && p.Type == signupVerification { - return badRequestError("Type provided requires an email address") + return badRequestError(ErrorCodeValidationFailed, "Type provided requires an email address") } if p.Phone == "" && p.Type == smsVerification { - return badRequestError("Type provided requires a phone number") + return badRequestError(ErrorCodeValidationFailed, "Type provided requires a phone number") } var err error if p.Email != "" && p.Phone != "" { - return badRequestError("Only an email address or phone number should be provided.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided.") } else if p.Email != "" { if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } p.Email, err = validateEmail(p.Email) if err != nil { @@ -48,7 +48,7 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er } } else if p.Phone != "" { if !config.External.Phone.Enabled { - return badRequestError("Phone logins are disabled") + return badRequestError(ErrorCodePhoneProviderDisabled, "Phone logins are disabled") } p.Phone, err = validatePhone(p.Phone) if err != nil { @@ -56,7 +56,7 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er } } else { // both email and phone are empty - return badRequestError("Missing email address or phone number") + return badRequestError(ErrorCodeValidationFailed, "Missing email address or phone number") } return nil } @@ -156,8 +156,13 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { + reason := ErrorCodeOverEmailSendRateLimit + if params.Type == smsVerification || params.Type == phoneChangeVerification { + reason = ErrorCodeOverSMSSendRateLimit + } + until := time.Until(user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency)) / time.Second - return tooManyRequestsError("For security purposes, you can only request this once every %d seconds.", until) + return tooManyRequestsError(reason, "For security purposes, you can only request this once every %d seconds.", until) } return internalServerError("Unable to process request").WithInternalError(err) } diff --git a/internal/api/router.go b/internal/api/router.go index c2f06ae2e..70b41f22d 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -63,7 +63,7 @@ func handler(fn apiHandler) http.HandlerFunc { func (h apiHandler) serve(w http.ResponseWriter, r *http.Request) { if err := h(w, r); err != nil { - handleError(err, w, r) + HandleResponseError(err, w, r) } } @@ -78,7 +78,7 @@ func (m middlewareHandler) handler(next http.Handler) http.Handler { func (m middlewareHandler) serve(next http.Handler, w http.ResponseWriter, r *http.Request) { ctx, err := m(w, r) if err != nil { - handleError(err, w, r) + HandleResponseError(err, w, r) return } if ctx != nil { diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index a3932249a..d82117748 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -67,7 +67,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { relayState, err := models.FindSAMLRelayStateByID(db, relayStateUUID) if models.IsNotFoundError(err) { - return badRequestError("SAML RelayState does not exist, try logging in again?") + return notFoundError(ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") } else if err != nil { return err } @@ -77,7 +77,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { return internalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) } - return badRequestError("SAML RelayState has expired. Try loggin in again?") + return unprocessableEntityError(ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try logging in again?") } // TODO: add abuse detection to bind the RelayState UUID with a @@ -107,23 +107,23 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { // SAML Artifact responses are possible only when // RelayState can be used to identify the Identity // Provider. - return badRequestError("SAML Artifact response can only be used with SP initiated flow") + return badRequestError(ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") } samlResponse := r.FormValue("SAMLResponse") if samlResponse == "" { - return badRequestError("SAMLResponse is missing") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is missing") } responseXML, err := base64.StdEncoding.DecodeString(samlResponse) if err != nil { - return badRequestError("SAMLResponse is not a valid Base64 string") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") } var peekResponse saml.Response err = xml.Unmarshal(responseXML, &peekResponse) if err != nil { - return badRequestError("SAMLResponse is not a valid XML SAML assertion") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) } initiatedBy = "idp" @@ -131,12 +131,12 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { redirectTo = relayStateValue } else { // RelayState can't be identified, so SAML flow can't continue - return badRequestError("SAML RelayState is not a valid UUID or URL") + return badRequestError(ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") } ssoProvider, err := models.FindSAMLProviderByEntityID(db, entityId) if models.IsNotFoundError(err) { - return badRequestError("A SAML connection has not been established with this Identity Provider") + return notFoundError(ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") } else if err != nil { return err } @@ -176,10 +176,10 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { spAssertion, err := serviceProvider.ParseResponse(r, requestIds) if err != nil { if ire, ok := err.(*saml.InvalidResponseError); ok { - return badRequestError("SAML Assertion is not valid").WithInternalError(ire.PrivateErr) + return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(ire.PrivateErr) } - return badRequestError("SAML Assertion is not valid").WithInternalError(err) + return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) } assertion := SAMLAssertion{ @@ -188,7 +188,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { userID := assertion.UserID() if userID == "" { - return badRequestError("SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") + return badRequestError(ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") } claims := assertion.Process(ssoProvider.SAMLProvider.AttributeMapping) @@ -200,7 +200,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { } if email == "" { - return badRequestError("SAML Assertion does not contain an email address") + return badRequestError(ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") } else { claims["email"] = email } diff --git a/internal/api/signup.go b/internal/api/signup.go index 3d7a19000..5c7e588b8 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -34,21 +34,21 @@ func (a *API) validateSignupParams(ctx context.Context, p *SignupParams) error { config := a.config if p.Password == "" { - return unprocessableEntityError("Signup requires a valid password") + return badRequestError(ErrorCodeValidationFailed, "Signup requires a valid password") } if err := a.checkPasswordStrength(ctx, p.Password); err != nil { return err } if p.Email != "" && p.Phone != "" { - return unprocessableEntityError("Only an email address or phone number should be provided on signup.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on signup.") } if p.Provider == "phone" && !sms_provider.IsValidMessageChannel(p.Channel, config.Sms.Provider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } // PKCE not needed as phone signups already return access token in body if p.Phone != "" && p.CodeChallenge != "" { - return badRequestError("PKCE not supported for phone signups") + return badRequestError(ErrorCodeValidationFailed, "PKCE not supported for phone signups") } if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { return err @@ -114,7 +114,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) if config.DisableSignup { - return forbiddenError("Signups not allowed for this instance") + return unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{} @@ -141,7 +141,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { switch params.Provider { case "email": if !config.External.Email.Enabled { - return badRequestError("Email signups are disabled") + return badRequestError(ErrorCodeEmailProviderDisabled, "Email signups are disabled") } params.Email, err = validateEmail(params.Email) if err != nil { @@ -150,7 +150,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { user, err = models.IsDuplicatedEmail(db, params.Email, params.Aud, nil) case "phone": if !config.External.Phone.Enabled { - return badRequestError("Phone signups are disabled") + return badRequestError(ErrorCodePhoneProviderDisabled, "Phone signups are disabled") } params.Phone, err = validatePhone(params.Phone) if err != nil { @@ -158,7 +158,18 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } user, err = models.FindUserByPhoneAndAudience(db, params.Phone, params.Aud) default: - return invalidSignupError(config) + msg := "" + if config.External.Email.Enabled && config.External.Phone.Enabled { + msg = "Sign up only available with email or phone provider" + } else if config.External.Email.Enabled { + msg = "Sign up only available with email provider" + } else if config.External.Phone.Enabled { + msg = "Sign up only available with phone provider" + } else { + msg = "Sign up with this provider not possible" + } + + return badRequestError(ErrorCodeValidationFailed, msg) } if err != nil && !models.IsNotFoundError(err) { @@ -241,7 +252,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { if errors.Is(terr, MaxFrequencyLimitError) { now := time.Now() left := user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency).Sub(now) / time.Second - return tooManyRequestsError(fmt.Sprintf("For security purposes, you can only request this after %d seconds.", left)) + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, fmt.Sprintf("For security purposes, you can only request this after %d seconds.", left)) } return internalServerError("Error sending confirmation mail").WithInternalError(terr) } @@ -265,10 +276,10 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending confirmation sms: %v", terr) + return internalServerError("Unable to get SMS provider").WithInternalError(terr) } if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { - return badRequestError("Error sending confirmation sms: %v", terr) + return unprocessableEntityError(ErrorCodeSMSSendFailed, "Error sending confirmation sms: %v", terr).WithInternalError(terr) } } } @@ -277,10 +288,14 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { }) if err != nil { - if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every minute") + reason := ErrorCodeOverEmailSendRateLimit + if params.Provider == "phone" { + reason = ErrorCodeOverSMSSendRateLimit } - if errors.Is(err, UserExistsError) { + + if errors.Is(err, MaxFrequencyLimitError) { + return tooManyRequestsError(reason, "For security purposes, you can only request this once every minute") + } else if errors.Is(err, UserExistsError) { err = db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(r, tx, user, models.UserRepeatedSignUpAction, "", map[string]interface{}{ "provider": params.Provider, @@ -293,7 +308,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return err } if config.Mailer.Autoconfirm || config.Sms.Autoconfirm { - return badRequestError("User already registered") + return unprocessableEntityError(ErrorCodeUserAlreadyExists, "User already registered") } sanitizedUser, err := sanitizeUser(user, params) if err != nil { diff --git a/internal/api/sso.go b/internal/api/sso.go index 08ca4c616..e50d1a369 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -27,9 +27,9 @@ func (p *SingleSignOnParams) validate() (bool, error) { hasDomain := p.Domain != "" if hasProviderID && hasDomain { - return hasProviderID, badRequestError("Only one of provider_id or domain supported") + return hasProviderID, badRequestError(ErrorCodeValidationFailed, "Only one of provider_id or domain supported") } else if !hasProviderID && !hasDomain { - return hasProviderID, badRequestError("A provider_id or domain needs to be provided") + return hasProviderID, badRequestError(ErrorCodeValidationFailed, "A provider_id or domain needs to be provided") } return hasProviderID, nil @@ -73,14 +73,14 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { if hasProviderID { ssoProvider, err = models.FindSSOProviderByID(db, params.ProviderID) if models.IsNotFoundError(err) { - return notFoundError("No such SSO provider") + return notFoundError(ErrorCodeSSOProviderNotFound, "No such SSO provider") } else if err != nil { return internalServerError("Unable to find SSO provider by ID").WithInternalError(err) } } else { ssoProvider, err = models.FindSSOProviderByDomain(db, params.Domain) if models.IsNotFoundError(err) { - return notFoundError("No SSO provider assigned for this domain") + return notFoundError(ErrorCodeSSOProviderNotFound, "No SSO provider assigned for this domain") } else if err != nil { return internalServerError("Unable to find SSO provider by domain").WithInternalError(err) } diff --git a/internal/api/sso_test.go b/internal/api/sso_test.go index 7da8b6eb2..5fc46b2d0 100644 --- a/internal/api/sso_test.go +++ b/internal/api/sso_test.go @@ -277,7 +277,7 @@ func (ts *SSOTestSuite) TestAdminCreateSSOProvider() { }, }, { - StatusCode: http.StatusBadRequest, + StatusCode: http.StatusUnprocessableEntity, Request: map[string]interface{}{ "type": "saml", "metadata_xml": validSAMLIDPMetadata("https://accounts.google.com/o/saml2?idpid=EXAMPLE-DUPLICATE"), diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go index 0f966780e..6e52fc8ff 100644 --- a/internal/api/ssoadmin.go +++ b/internal/api/ssoadmin.go @@ -28,14 +28,14 @@ func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.C idpID, err := uuid.FromString(idpParam) if err != nil { // idpParam is not UUIDv4 - return nil, notFoundError("SSO Identity Provider not found") + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") } // idpParam is a UUIDv4 provider, err := models.FindSSOProviderByID(db, idpID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("SSO Identity Provider not found") + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") } else { return nil, internalServerError("Database error finding SSO Identity Provider").WithInternalError(err) } @@ -78,19 +78,19 @@ type CreateSSOProviderParams struct { func (p *CreateSSOProviderParams) validate(forUpdate bool) error { if !forUpdate && p.Type != "saml" { - return badRequestError("Only 'saml' supported for SSO provider type") + return badRequestError(ErrorCodeValidationFailed, "Only 'saml' supported for SSO provider type") } else if p.MetadataURL != "" && p.MetadataXML != "" { - return badRequestError("Only one of metadata_xml or metadata_url needs to be set") + return badRequestError(ErrorCodeValidationFailed, "Only one of metadata_xml or metadata_url needs to be set") } else if !forUpdate && p.MetadataURL == "" && p.MetadataXML == "" { - return badRequestError("Either metadata_xml or metadata_url must be set") + return badRequestError(ErrorCodeValidationFailed, "Either metadata_xml or metadata_url must be set") } else if p.MetadataURL != "" { metadataURL, err := url.ParseRequestURI(p.MetadataURL) if err != nil { - return badRequestError("metadata_url is not a valid URL") + return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a valid URL") } if metadataURL.Scheme != "https" { - return badRequestError("metadata_url is not a HTTPS URL") + return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a HTTPS URL") } } @@ -126,7 +126,7 @@ func (p *CreateSSOProviderParams) metadata(ctx context.Context) ([]byte, *saml.E func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { if !utf8.Valid(rawMetadata) { - return nil, badRequestError("SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") } metadata, err := samlsp.ParseMetadata(rawMetadata) @@ -135,15 +135,15 @@ func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { } if metadata.EntityID == "" { - return nil, badRequestError("SAML Metadata does not contain an EntityID") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain an EntityID") } if len(metadata.IDPSSODescriptors) < 1 { - return nil, badRequestError("SAML Metadata does not contain any IDPSSODescriptor") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain any IDPSSODescriptor") } if len(metadata.IDPSSODescriptors) > 1 { - return nil, badRequestError("SAML Metadata contains multiple IDPSSODescriptors") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata contains multiple IDPSSODescriptors") } return metadata, nil @@ -152,7 +152,7 @@ func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return nil, badRequestError("Unable to create a request to metadata_url").WithInternalError(err) + return nil, internalServerError("Unable to create a request to metadata_url").WithInternalError(err) } req = req.WithContext(ctx) @@ -167,7 +167,7 @@ func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { defer utilities.SafeClose(resp.Body) if resp.StatusCode != http.StatusOK { - return nil, badRequestError("HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) + return nil, badRequestError(ErrorCodeSAMLMetadataFetchFailed, "HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) } data, err := io.ReadAll(resp.Body) @@ -202,7 +202,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er return err } if existingProvider != nil { - return badRequestError("SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) + return unprocessableEntityError(ErrorCodeSAMLIdPAlreadyExists, "SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) } provider := &models.SSOProvider{ @@ -225,7 +225,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er return err } if existingProvider != nil { - return badRequestError("SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) + return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) } provider.SSODomains = append(provider.SSODomains, models.SSODomain{ @@ -280,7 +280,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er } if provider.SAMLProvider.EntityID != metadata.EntityID { - return badRequestError("SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) + return badRequestError(ErrorCodeSAMLEntityIDMismatch, "SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) } if params.MetadataURL != "" { @@ -309,7 +309,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er if existingProvider.ID == provider.ID { keepDomains[domain] = true } else { - return badRequestError("SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) + return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) } } else { modified = true @@ -359,7 +359,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er return tx.Eager().Load(provider) }); err != nil { - return unprocessableEntityError("Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) + return unprocessableEntityError(ErrorCodeConflict, "Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) } } diff --git a/internal/api/token.go b/internal/api/token.go index 2f6f9e3b2..df0292711 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -108,7 +108,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri config := a.config if params.Email != "" && params.Phone != "" { - return unprocessableEntityError("Only an email address or phone number should be provided on login.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on login.") } var user *models.User var grantParams models.GrantParams @@ -120,13 +120,13 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri if params.Email != "" { provider = "email" if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) } else if params.Phone != "" { provider = "phone" if !config.External.Phone.Enabled { - return badRequestError("Phone logins are disabled") + return unprocessableEntityError(ErrorCodePhoneProviderDisabled, "Phone logins are disabled") } params.Phone = formatPhoneNumber(params.Phone) user, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud) @@ -178,7 +178,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri return err } } - return forbiddenError(output.Message) + return oauthError("invalid_grant", InvalidLoginMessage) } } if !isValidPassword { @@ -230,24 +230,23 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) grantParams.FillGrantParams(r) params := &PKCEGrantParams{} - if err := retrieveRequestParams(r, params); err != nil { return err } if params.AuthCode == "" || params.CodeVerifier == "" { - return badRequestError("invalid request: both auth code and code verifier should be non-empty") + return badRequestError(ErrorCodeValidationFailed, "invalid request: both auth code and code verifier should be non-empty") } flowState, err := models.FindFlowStateByAuthCode(db, params.AuthCode) // Sanity check in case user ID was not set properly if models.IsNotFoundError(err) || flowState.UserID == nil { - return forbiddenError("invalid flow state, no valid flow state found") + return notFoundError(ErrorCodeFlowStateNotFound, "invalid flow state, no valid flow state found") } else if err != nil { return err } if flowState.IsExpired(a.config.External.FlowStateExpiryDuration) { - return forbiddenError("invalid flow state, flow state has expired") + return unprocessableEntityError(ErrorCodeFlowStateExpired, "invalid flow state, flow state has expired") } user, err := models.FindUserByID(db, *flowState.UserID) @@ -255,7 +254,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) return err } if err := flowState.VerifyPKCE(params.CodeVerifier); err != nil { - return forbiddenError(err.Error()) + return badRequestError(ErrorBadCodeVerifier, err.Error()) } var token *AccessTokenResponse @@ -427,6 +426,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, if err != nil { return nil, internalServerError("Cannot read SessionId claim as UUID").WithInternalError(err) } + err = tx.Transaction(func(tx *storage.Connection) error { if terr := models.AddClaimToSession(tx, sessionId, authenticationMethod); terr != nil { return terr @@ -459,7 +459,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, return err } - tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &sessionId, models.TOTPSignIn) + tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &session.ID, models.TOTPSignIn) if terr != nil { httpErr, ok := terr.(*HTTPError) if ok { diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index 58e022afd..0574c3bb8 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -54,7 +54,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa if issuer == "" || !provider.IsAzureIssuer(issuer) { detectedIssuer, err := provider.DetectAzureIDTokenIssuer(ctx, p.IdToken) if err != nil { - return nil, nil, "", nil, badRequestError("Unable to detect issuer in ID token for Azure provider").WithInternalError(err) + return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err) } issuer = detectedIssuer } @@ -95,12 +95,12 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa } if !allowed { - return nil, nil, "", nil, badRequestError(fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) + return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) } } if cfg != nil && !cfg.Enabled { - return nil, nil, "", nil, badRequestError(fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) + return nil, nil, "", nil, badRequestError(ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) } oidcProvider, err := oidc.NewProvider(ctx, issuer) diff --git a/internal/api/token_test.go b/internal/api/token_test.go index b12a79a8e..53a492b11 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -343,7 +343,7 @@ func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusForbidden, w.Code) + assert.Equal(ts.T(), http.StatusNotFound, w.Code) }) } } @@ -618,7 +618,7 @@ func (ts *TokenTestSuite) TestPasswordVerificationHook() { begin return jsonb_build_object('decision', 'reject', 'message', 'You shall not pass!'); end; $$ language plpgsql;`, - expectedCode: http.StatusForbidden, + expectedCode: http.StatusBadRequest, }, } for _, c := range cases { diff --git a/internal/api/user.go b/internal/api/user.go index 723521c17..9fe0dcef8 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -45,7 +45,7 @@ func (a *API) validateUserUpdateParams(ctx context.Context, p *UserUpdateParams) p.Channel = sms_provider.SMSProvider } if !sms_provider.IsValidMessageChannel(p.Channel, config.Sms.Provider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } } @@ -63,12 +63,12 @@ func (a *API) UserGet(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() claims := getClaims(ctx) if claims == nil { - return badRequestError("Could not read claims") + return internalServerError("Could not read claims") } aud := a.requestAud(ctx, r) if aud != claims.Audience { - return badRequestError("Token audience doesn't match request audience") + return badRequestError(ErrorCodeValidationFailed, "Token audience doesn't match request audience") } user := getUser(ctx) @@ -96,7 +96,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if params.AppData != nil && !isAdmin(user, config) { if !isAdmin(user, config) { - return unauthorizedError("Updating app_metadata requires admin privileges") + return forbiddenError(ErrorCodeNotAdmin, "Updating app_metadata requires admin privileges") } } @@ -104,7 +104,8 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { updatingForbiddenFields := false updatingForbiddenFields = updatingForbiddenFields || (params.Password != nil && *params.Password != "") if updatingForbiddenFields { - return unprocessableEntityError("Updating password of an anonymous user is not possible") + // CHECK + return unprocessableEntityError(ErrorCodeUnknown, "Updating password of an anonymous user is not possible") } } @@ -117,7 +118,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { updatingForbiddenFields = updatingForbiddenFields || (params.Nonce != "") if updatingForbiddenFields { - return unprocessableEntityError("Updating email, phone, password of a SSO account only possible via SSO") + return unprocessableEntityError(ErrorCodeUserSSOManaged, "Updating email, phone, password of a SSO account only possible via SSO") } } @@ -125,7 +126,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if duplicateUser, err := models.IsDuplicatedEmail(db, params.Email, aud, user); err != nil { return internalServerError("Database error checking email").WithInternalError(err) } else if duplicateUser != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } @@ -133,7 +134,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { return internalServerError("Database error checking phone").WithInternalError(err) } else if exists { - return unprocessableEntityError(DuplicatePhoneMsg) + return unprocessableEntityError(ErrorCodePhoneExists, DuplicatePhoneMsg) } } @@ -143,7 +144,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { // we require reauthentication if the user hasn't signed in recently in the current session if session == nil || now.After(session.CreatedAt.Add(24*time.Hour)) { if len(params.Nonce) == 0 { - return badRequestError("Password update requires reauthentication") + return badRequestError(ErrorCodeReauthenticationNeeded, "Password update requires reauthentication") } if err := a.verifyReauthentication(params.Nonce, db, config, user); err != nil { return err @@ -154,7 +155,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { password := *params.Password if password != "" { if user.EncryptedPassword != "" && user.Authenticate(ctx, password) { - return unprocessableEntityError("New password should be different from the old password.") + return unprocessableEntityError(ErrorCodeSamePassword, "New password should be different from the old password.") } } @@ -206,7 +207,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { externalURL := getExternalHost(ctx) if terr = a.sendEmailChange(tx, config, user, mailer, params.Email, referrer, externalURL, config.Mailer.OtpLength, flowType); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRateLimit, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Error sending change email").WithInternalError(terr) } @@ -224,7 +225,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } else { smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Error finding SMS provider").WithInternalError(terr) } if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { return internalServerError("Error sending phone change otp").WithInternalError(terr) diff --git a/internal/api/user_test.go b/internal/api/user_test.go index 2136f29c7..18ed6ec85 100644 --- a/internal/api/user_test.go +++ b/internal/api/user_test.go @@ -281,7 +281,7 @@ func (ts *UserTestSuite) TestUserUpdatePassword() { nonce: "123456", requireReauthentication: true, sessionId: nil, - expected: expected{code: http.StatusBadRequest, isAuthenticated: false}, + expected: expected{code: http.StatusUnprocessableEntity, isAuthenticated: false}, }, { desc: "Valid password length", diff --git a/internal/api/verify.go b/internal/api/verify.go index 4af21e023..d6d5dc541 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -53,18 +53,18 @@ type VerifyParams struct { func (p *VerifyParams) Validate(r *http.Request) error { var err error if p.Type == "" { - return badRequestError("Verify requires a verification type") + return badRequestError(ErrorCodeValidationFailed, "Verify requires a verification type") } switch r.Method { case http.MethodGet: if p.Token == "" { - return badRequestError("Verify requires a token or a token hash") + return badRequestError(ErrorCodeValidationFailed, "Verify requires a token or a token hash") } // TODO: deprecate the token query param from GET /verify and use token_hash instead (breaking change) p.TokenHash = p.Token case http.MethodPost: if (p.Token == "" && p.TokenHash == "") || (p.Token != "" && p.TokenHash != "") { - return badRequestError("Verify requires either a token or a token hash") + return badRequestError(ErrorCodeValidationFailed, "Verify requires either a token or a token hash") } if p.Token != "" { if isPhoneOtpVerification(p) { @@ -76,15 +76,15 @@ func (p *VerifyParams) Validate(r *http.Request) error { } else if isEmailOtpVerification(p) { p.Email, err = validateEmail(p.Email) if err != nil { - return unprocessableEntityError("Invalid email format").WithInternalError(err) + return unprocessableEntityError(ErrorCodeValidationFailed, "Invalid email format").WithInternalError(err) } p.TokenHash = crypto.GenerateTokenHash(p.Email, p.Token) } else { - return badRequestError("Only an email address or phone number should be provided on verify") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify") } } else if p.TokenHash != "" { if p.Email != "" || p.Phone != "" || p.RedirectTo != "" { - return badRequestError("Only the token_hash and type should be provided") + return badRequestError(ErrorCodeValidationFailed, "Only the token_hash and type should be provided") } } default: @@ -114,7 +114,8 @@ func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { } return a.verifyPost(w, r, params) default: - return unprocessableEntityError("Only GET and POST methods are supported.") + // this should have been handled by Chi + panic("Only GET and POST methods allowed") } } @@ -165,7 +166,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa return nil } default: - return unprocessableEntityError("Unsupported verification type") + return badRequestError(ErrorCodeValidationFailed, "Unsupported verification type") } if terr != nil { @@ -193,7 +194,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa } } else if isPKCEFlow(flowType) { if authCode, terr = issueAuthCode(tx, user, authenticationMethod); terr != nil { - return badRequestError("No associated flow state found. %s", terr) + return badRequestError(ErrorCodeFlowStateNotFound, "No associated flow state found. %s", terr) } } return nil @@ -266,7 +267,7 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyP case smsVerification, phoneChangeVerification: user, terr = a.smsVerify(r, tx, user, params) default: - return unprocessableEntityError("Unsupported verification type") + return badRequestError(ErrorCodeValidationFailed, "Unsupported verification type") } if terr != nil { @@ -310,7 +311,8 @@ func (a *API) signupVerify(r *http.Request, ctx context.Context, conn *storage.C // to present the user with a password set form password, err := password.Generate(64, 10, 0, false, true) if err != nil { - return nil, err + // password generation must succeed + panic(err) } if err := user.SetPassword(ctx, password); err != nil { @@ -433,14 +435,14 @@ func (a *API) prepErrorRedirectURL(err *HTTPError, r *http.Request, rurl string, errorID := getRequestID(r.Context()) err.ErrorID = errorID log.WithError(err.Cause()).Info(err.Error()) - if str, ok := oauthErrorMap[err.Code]; ok { + if str, ok := oauthErrorMap[err.HTTPStatus]; ok { hq.Set("error", str) q.Set("error", str) } - hq.Set("error_code", strconv.Itoa(err.Code)) + hq.Set("error_code", strconv.Itoa(err.HTTPStatus)) hq.Set("error_description", err.Message) - q.Set("error_code", strconv.Itoa(err.Code)) + q.Set("error_code", strconv.Itoa(err.HTTPStatus)) q.Set("error_description", err.Message) if flowType == models.PKCEFlow { // Additionally, may override existing error query param if set to PKCE. @@ -563,18 +565,18 @@ func (a *API) verifyTokenHash(conn *storage.Connection, params *VerifyParams) (* case emailChangeVerification: user, err = models.FindUserByEmailChangeToken(conn, params.TokenHash) default: - return nil, badRequestError("Invalid email verification type") + return nil, badRequestError(ErrorCodeValidationFailed, "Invalid email verification type") } if err != nil { if models.IsNotFoundError(err) { - return nil, expiredTokenError("Email link is invalid or has expired").WithInternalError(err) + return nil, forbiddenError(ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalError(err) } return nil, internalServerError("Database error finding user from email link").WithInternalError(err) } if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalMessage("user is banned") + return nil, forbiddenError(ErrorCodeUserBanned, "User is banned") } var isExpired bool @@ -596,7 +598,7 @@ func (a *API) verifyTokenHash(conn *storage.Connection, params *VerifyParams) (* } if isExpired { - return nil, expiredTokenError("Email link is invalid or has expired").WithInternalMessage("email link has expired") + return nil, forbiddenError(ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalMessage("email link has expired") } return user, nil @@ -625,13 +627,13 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, if err != nil { if models.IsNotFoundError(err) { - return nil, expiredTokenError("Token has expired or is invalid").WithInternalError(err) + return nil, forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } return nil, internalServerError("Database error finding user").WithInternalError(err) } if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalMessage("user is banned") + return nil, forbiddenError(ErrorCodeUserBanned, "User is banned") } var isValid bool @@ -672,7 +674,7 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, } } if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(phone, params.Token); err != nil { - return nil, expiredTokenError("Token has expired or is invalid").WithInternalError(err) + return nil, forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } return user, nil } @@ -680,7 +682,7 @@ func (a *API) verifyUserAndToken(conn *storage.Connection, params *VerifyParams, } if !isValid { - return nil, expiredTokenError("Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") + return nil, forbiddenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") } return user, nil } diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index c782446fb..ac6e79b71 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -280,9 +280,9 @@ func (ts *VerifyTestSuite) TestExpiredConfirmationToken() { f, err := url.ParseQuery(rurl.Fragment) require.NoError(ts.T(), err) - assert.Equal(ts.T(), "401", f.Get("error_code")) + assert.Equal(ts.T(), "403", f.Get("error_code")) assert.Equal(ts.T(), "Email link is invalid or has expired", f.Get("error_description")) - assert.Equal(ts.T(), "unauthorized_client", f.Get("error")) + assert.Equal(ts.T(), "access_denied", f.Get("error")) } func (ts *VerifyTestSuite) TestInvalidOtp() { @@ -302,7 +302,7 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { } expectedResponse := ResponseBody{ - Code: http.StatusUnauthorized, + Code: http.StatusForbidden, Msg: "Token has expired or is invalid", } @@ -313,7 +313,7 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { expected ResponseBody }{ { - desc: "Expired Sms OTP", + desc: "Expired SMS OTP", sentTime: time.Now().Add(-48 * time.Hour), body: map[string]interface{}{ "type": smsVerification, @@ -323,7 +323,7 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { expected: expectedResponse, }, { - desc: "Invalid Sms OTP", + desc: "Invalid SMS OTP", sentTime: time.Now(), body: map[string]interface{}{ "type": smsVerification, @@ -760,7 +760,7 @@ func (ts *VerifyTestSuite) TestVerifyBannedUser() { f, err := url.ParseQuery(rurl.Fragment) require.NoError(ts.T(), err) - assert.Equal(ts.T(), "401", f.Get("error_code")) + assert.Equal(ts.T(), "403", f.Get("error_code")) }) } } @@ -973,7 +973,7 @@ func (ts *VerifyTestSuite) TestSecureEmailChangeWithTokenHash() { "type": emailChangeVerification, "token_hash": currentEmailChangeToken, }, - expectedStatus: http.StatusUnauthorized, + expectedStatus: http.StatusForbidden, }, } for _, c := range cases { @@ -1102,7 +1102,7 @@ func (ts *VerifyTestSuite) TestPrepErrorRedirectURL() { for _, c := range cases { ts.Run(c.desc, func() { req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) - rurl, err := ts.API.prepErrorRedirectURL(badRequestError(DefaultError), req, c.rurl, c.flowType) + rurl, err := ts.API.prepErrorRedirectURL(badRequestError(ErrorCodeValidationFailed, DefaultError), req, c.rurl, c.flowType) require.NoError(ts.T(), err) require.Equal(ts.T(), c.expected, rurl) }) @@ -1152,7 +1152,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { Token: "some-token", }, method: http.MethodPost, - expected: badRequestError("Only an email address or phone number should be provided on verify"), + expected: badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify"), }, { desc: "Cannot send both TokenHash and Token", @@ -1162,7 +1162,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { TokenHash: "some-token-hash", }, method: http.MethodPost, - expected: badRequestError("Verify requires either a token or a token hash"), + expected: badRequestError(ErrorCodeValidationFailed, "Verify requires either a token or a token hash"), }, { desc: "No verification type specified", @@ -1171,7 +1171,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { Email: "email@example.com", }, method: http.MethodPost, - expected: badRequestError("Verify requires a verification type"), + expected: badRequestError(ErrorCodeValidationFailed, "Verify requires a verification type"), }, }