diff --git a/internal/api/admin.go b/internal/api/admin.go index d6fd17dac7..8983eac359 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -50,7 +50,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, userID, err := uuid.FromString(chi.URLParam(r, "user_id")) if err != nil { - return nil, badRequestError("user_id must be an UUID") + return nil, notFoundError(ErrorCodeValidationFailed, "user_id must be an UUID") } observability.LogEntrySetField(r, "user_id", userID) @@ -58,7 +58,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, u, err := models.FindUserByID(db, userID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("User not found") + return nil, notFoundError(ErrorCodeUserNotFound, "User not found") } return nil, internalServerError("Database error loading user").WithInternalError(err) } @@ -69,7 +69,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) { factorID, err := uuid.FromString(chi.URLParam(r, "factor_id")) if err != nil { - return nil, badRequestError("factor_id must be an UUID") + return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID") } observability.LogEntrySetField(r, "factor_id", factorID) @@ -77,7 +77,7 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex f, err := models.FindFactorByFactorID(a.db, factorID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("Factor not found") + return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found") } return nil, internalServerError("Database error loading factor").WithInternalError(err) } @@ -89,11 +89,11 @@ func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) { body, err := getBodyBytes(r) if err != nil { - return nil, badRequestError("Could not read body").WithInternalError(err) + return nil, internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, ¶ms); err != nil { - return nil, badRequestError("Could not decode admin user params: %v", err) + return nil, badRequestError(ErrorCodeBadJSON, "Could not decode admin user params").WithInternalError(err) } return ¶ms, nil @@ -107,12 +107,12 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { pageParams, err := paginate(r) if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err) } sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}}) if err != nil { - return badRequestError("Bad Sort Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err) } filter := r.URL.Query().Get("filter") @@ -166,7 +166,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError("invalid format for ban duration: %v", err) + return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) } } if terr := user.Ban(a.db, duration); terr != nil { @@ -314,7 +314,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { } if params.Email == "" && params.Phone == "" { - return unprocessableEntityError("Cannot create a user without either an email or phone") + return badRequestError(ErrorCodeValidationFailed, "Cannot create a user without either an email or phone") } var providers []string @@ -326,7 +326,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil { return internalServerError("Database error checking email").WithInternalError(err) } else if user != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } providers = append(providers, "email") } @@ -339,7 +339,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { return internalServerError("Database error checking phone").WithInternalError(err) } else if exists { - return unprocessableEntityError("Phone number already registered by another user") + return unprocessableEntityError(ErrorCodePhoneExists, "Phone number already registered by another user") } providers = append(providers, "phone") } @@ -435,7 +435,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError("invalid format for ban duration: %v", err) + return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err) } } if terr := user.Ban(a.db, duration); terr != nil { @@ -466,11 +466,11 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { params := &adminUserDeleteParams{} body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if len(body) > 0 { if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read params: %v", err) } } else { params.ShouldSoftDelete = false @@ -567,11 +567,11 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro params := &adminUserUpdateFactorParams{} body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read factor update params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read factor update params: %v", err).WithInternalError(err) } err = a.db.Transaction(func(tx *storage.Connection) error { @@ -582,7 +582,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro } if params.FactorType != "" { if params.FactorType != models.TOTP { - return badRequestError("Factor Type not valid") + return badRequestError(ErrorCodeValidationFailed, "Factor Type not valid") } if terr := factor.UpdateFactorType(tx, params.FactorType); terr != nil { return terr diff --git a/internal/api/apiversions.go b/internal/api/apiversions.go new file mode 100644 index 0000000000..6e1c68ca72 --- /dev/null +++ b/internal/api/apiversions.go @@ -0,0 +1,31 @@ +package api + +import ( + "time" +) + +const APIVersionHeaderName = "X-Supabase-Api-Version" + +type APIVersion = time.Time + +var ( + APIVersionInitial = time.Time{} + APIVersion20240101 = time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC) +) + +func DetermineClosestAPIVersion(date string) (APIVersion, error) { + if date == "" { + return APIVersionInitial, nil + } + + parsed, err := time.ParseInLocation("2006-01-02", date, time.UTC) + if err != nil { + return APIVersionInitial, err + } + + if parsed.Compare(APIVersion20240101) >= 0 { + return APIVersion20240101, nil + } + + return APIVersionInitial, nil +} diff --git a/internal/api/apiversions_test.go b/internal/api/apiversions_test.go new file mode 100644 index 0000000000..0a96221328 --- /dev/null +++ b/internal/api/apiversions_test.go @@ -0,0 +1,29 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDetermineClosestAPIVersion(t *testing.T) { + version, err := DetermineClosestAPIVersion("") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("Not a date") + require.Error(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2023-12-31") + require.NoError(t, err) + require.Equal(t, APIVersionInitial, version) + + version, err = DetermineClosestAPIVersion("2024-01-01") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) + + version, err = DetermineClosestAPIVersion("2024-01-02") + require.NoError(t, err) + require.Equal(t, APIVersion20240101, version) +} diff --git a/internal/api/audit.go b/internal/api/audit.go index 2cb99c6e75..351a7d2cdf 100644 --- a/internal/api/audit.go +++ b/internal/api/audit.go @@ -20,7 +20,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { // aud := a.requestAud(ctx, r) pageParams, err := paginate(r) if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) + return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err) } var col []string @@ -31,7 +31,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { qparts := strings.SplitN(q, ":", 2) col, exists = filterColumnMap[qparts[0]] if !exists || len(qparts) < 2 { - return badRequestError("Invalid query scope: %s", q) + return badRequestError(ErrorCodeValidationFailed, "Invalid query scope: %s", q) } qval = qparts[1] } diff --git a/internal/api/auth.go b/internal/api/auth.go index c1f43d5114..507d534da4 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -40,7 +40,7 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R claims := getClaims(ctx) if claims == nil { fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token") - return nil, unauthorizedError("Invalid token") + return nil, unauthorizedError(ErrorCodeBadJWT, "Invalid token") } adminRoles := a.config.JWT.AdminRoles @@ -51,14 +51,14 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R } fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'") - return nil, unauthorizedError("User not allowed") + return nil, unauthorizedError(ErrorCodeNotAdmin, "User not allowed") } func (a *API) extractBearerToken(r *http.Request) (string, error) { authHeader := r.Header.Get("Authorization") matches := bearerRegexp.FindStringSubmatch(authHeader) if len(matches) != 2 { - return "", unauthorizedError("This endpoint requires a Bearer token") + return "", unauthorizedError(ErrorCodeNoAuthorization, "This endpoint requires a Bearer token") } return matches[1], nil @@ -73,7 +73,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e return []byte(config.JWT.Secret), nil }) if err != nil { - return nil, unauthorizedError("invalid JWT: unable to parse or verify signature, %v", err) + return nil, unauthorizedError(ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err) } return withToken(ctx, token), nil @@ -84,23 +84,23 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro claims := getClaims(ctx) if claims == nil { - return ctx, unauthorizedError("invalid token: missing claims") + return ctx, unauthorizedError(ErrorCodeBadJWT, "invalid token: missing claims") } if claims.Subject == "" { - return nil, unauthorizedError("invalid claim: missing sub claim") + return nil, unauthorizedError(ErrorCodeBadJWT, "invalid claim: missing sub claim") } var user *models.User if claims.Subject != "" { userId, err := uuid.FromString(claims.Subject) if err != nil { - return ctx, badRequestError("invalid claim: sub claim must be a UUID").WithInternalError(err) + return ctx, badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err) } user, err = models.FindUserByID(db, userId) if err != nil { if models.IsNotFoundError(err) { - return ctx, notFoundError(err.Error()) + return ctx, unauthorizedError(ErrorCodeUserNotFound, "User from sub claim in JWT does not exist") } return ctx, err } @@ -111,11 +111,11 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() { sessionId, err := uuid.FromString(claims.SessionId) if err != nil { - return ctx, err + return ctx, badRequestError(ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err) } session, err = models.FindSessionByID(db, sessionId, false) if err != nil && !models.IsNotFoundError(err) { - return ctx, err + return ctx, unauthorizedError(ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist") } ctx = withSession(ctx, session) } diff --git a/internal/api/auth_test.go b/internal/api/auth_test.go index 59ad0613c3..6adc05b0b1 100644 --- a/internal/api/auth_test.go +++ b/internal/api/auth_test.go @@ -97,7 +97,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { }, Role: "authenticated", }, - ExpectedError: unauthorizedError("invalid claim: missing sub claim"), + ExpectedError: unauthorizedError(ErrorCodeBadJWT, "invalid claim: missing sub claim"), ExpectedUser: nil, }, { @@ -119,7 +119,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() { }, Role: "authenticated", }, - ExpectedError: badRequestError("invalid claim: sub claim must be a UUID"), + ExpectedError: badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"), ExpectedUser: nil, }, { diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go new file mode 100644 index 0000000000..a64eec64f7 --- /dev/null +++ b/internal/api/errorcodes.go @@ -0,0 +1,68 @@ +package api + +type ErrorCode = string + +const ( + ErrorCodeValidationFailed ErrorCode = "validation_failed" + ErrorCodeBadJSON ErrorCode = "bad_json" + ErrorCodeEmailExists ErrorCode = "email_exists" + ErrorCodePhoneExists ErrorCode = "phone_exists" + ErrorCodeBadJWT ErrorCode = "bad_jwt" + ErrorCodeNotAdmin ErrorCode = "not_admin" + ErrorCodeNoAuthorization ErrorCode = "no_authorization" + ErrorCodeUserNotFound ErrorCode = "user_not_found" + ErrorCodeSessionNotFound ErrorCode = "session_not_found" + ErrorCodeFlowStateNotFound ErrorCode = "flow_state_not_found" + ErrorCodeFlowStateExpired ErrorCode = "flow_state_expired" + ErrorCodeSignupDisabled ErrorCode = "signup_disabled" + ErrorCodeUserBanned ErrorCode = "user_banned" + ErrorCodeOverEmailSendRate ErrorCode = "over_email_send_rate" + ErrorCodeOverSMSSendRate ErrorCode = "over_sms_send_rate" + ErrorCodeProviderEmailNeedsVerification ErrorCode = "provider_email_needs_verification" + ErrorCodeInviteNotFound ErrorCode = "invite_not_found" + ErrorCodeBadOAuthState ErrorCode = "bad_oauth_state" + ErrorCodeBadOAuthCallback ErrorCode = "bad_oauth_callback" + ErrorCodeOAuthProviderNotSupported ErrorCode = "oauth_provider_not_supported" + ErrorCodeUnexpectedAudience ErrorCode = "unexpected_audience" + ErrorCodeLastIdentityNotDeletable ErrorCode = "last_identity_not_deletable" + ErrorCodeEmailConflictIdentityNotDeletable ErrorCode = "email_conflict_identity_not_deletable" + ErrorCodeIdentityAlreadyExists ErrorCode = "identity_already_exists" + ErrorCodeEmailProviderDisabled ErrorCode = "email_provider_disabled" + ErrorCodePhoneProviderDisabled ErrorCode = "phone_provider_disabled" + ErrorCodeTooManyEnrolledMFAFactors ErrorCode = "too_many_enrolled_mfa_factors" + ErrorCodeMFAFactorNameConflict ErrorCode = "mfa_factor_name_conflict" + ErrorCodeMFAFactorNotFound ErrorCode = "mfa_factor_not_found" + ErrorCodeMFAIPAddressMismatch ErrorCode = "mfa_ip_address_mismatch" + ErrorCodeMFAChallengeExpired ErrorCode = "mfa_challenge_expired" + ErrorCodeMFAVerificationFailed ErrorCode = "mfa_verification_failed" + ErrorCodeMFAVerificationRejected ErrorCode = "mfa_verification_rejected" + ErrorCodeInsufficientAAL ErrorCode = "insufficient_aal" + ErrorCodeCaptchaFailed ErrorCode = "captcha_failed" + ErrorCodeSAMLProviderDisabled ErrorCode = "saml_provider_disabled" + ErrorCodeManualLinkingDisabled ErrorCode = "manual_linking_disabled" + ErrorCodeSMSSendFailed ErrorCode = "sms_send_failed" + ErrorCodeEmailNotConfirmed ErrorCode = "email_not_confirmed" + ErrorCodePhoneNotConfirmed ErrorCode = "phone_not_confirmed" + ErrorCodeReauthNonceMissing ErrorCode = "reauth_nonce_missing" + ErrorCodeSAMLRelayStateNotFound ErrorCode = "saml_relay_state_not_found" + ErrorCodeSAMLRelayStateExpired ErrorCode = "saml_relay_state_expired" + ErrorCodeSAMLIdPNotFound ErrorCode = "saml_idp_not_found" + ErrorCodeSAMLAssertionNoUserID ErrorCode = "saml_assertion_no_user_id" + ErrorCodeSAMLAssertionNoEmail ErrorCode = "saml_assertion_no_email" + ErrorCodeUserAlreadyExists ErrorCode = "user_already_exists" + ErrorCodeSSOProviderNotFound ErrorCode = "sso_provider_not_found" + ErrorCodeSAMLMetadataFetchFailed ErrorCode = "saml_metadata_fetch_failed" + ErrorCodeSAMLIdPAlreadyExists ErrorCode = "saml_idp_already_exists" + ErrorCodeSSODomainAlreadyExists ErrorCode = "sso_domain_already_exists" + ErrorCodeSAMLEntityIDMismatch ErrorCode = "saml_entity_id_mismatch" + ErrorCodeConflict ErrorCode = "conflict" + ErrorCodeProviderDisabled ErrorCode = "provider_disabled" + ErrorCodeUserSSOManaged ErrorCode = "user_sso_managed" + ErrorCodeReauthenticationNeeded ErrorCode = "reauthentication_needed" + ErrorCodeSamePassword ErrorCode = "same_password" + ErrorCodeReauthenticationNotValid ErrorCode = "reauthentication_not_valid" + ErrorCodeOTPExpired ErrorCode = "otp_expired" + ErrorCodeOTPDisabled ErrorCode = "otp_disabled" + ErrorCodeIdentityNotFound ErrorCode = "identity_not_found" + ErrorCodeWeakPassword ErrorCode = "weak_password" +) diff --git a/internal/api/errors.go b/internal/api/errors.go index 56f404e3c9..8fb9fe343f 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -8,7 +8,6 @@ import ( "runtime/debug" "github.com/pkg/errors" - "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/utilities" ) @@ -65,26 +64,11 @@ func (e *OAuthError) Cause() error { return e } -func invalidSignupError(config *conf.GlobalConfiguration) *HTTPError { - var msg string - if config.External.Email.Enabled && config.External.Phone.Enabled { - msg = "To signup, please provide your email or phone number" - } else if config.External.Email.Enabled { - msg = "To signup, please provide your email" - } else if config.External.Phone.Enabled { - msg = "To signup, please provide your phone number" - } else { - // 3rd party OAuth signups - msg = "To signup, please provide required fields" - } - return unprocessableEntityError(msg) -} - func oauthError(err string, description string) *OAuthError { return &OAuthError{Err: err, Description: description} } -func badRequestError(fmtString string, args ...interface{}) *HTTPError { +func badRequestError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusBadRequest, fmtString, args...) } @@ -92,27 +76,27 @@ func internalServerError(fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusInternalServerError, fmtString, args...) } -func notFoundError(fmtString string, args ...interface{}) *HTTPError { +func notFoundError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusNotFound, fmtString, args...) } -func expiredTokenError(fmtString string, args ...interface{}) *HTTPError { +func expiredTokenError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusUnauthorized, fmtString, args...) } -func unauthorizedError(fmtString string, args ...interface{}) *HTTPError { +func unauthorizedError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusUnauthorized, fmtString, args...) } -func forbiddenError(fmtString string, args ...interface{}) *HTTPError { +func forbiddenError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusForbidden, fmtString, args...) } -func unprocessableEntityError(fmtString string, args ...interface{}) *HTTPError { +func unprocessableEntityError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusUnprocessableEntity, fmtString, args...) } -func tooManyRequestsError(fmtString string, args ...interface{}) *HTTPError { +func tooManyRequestsError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusTooManyRequests, fmtString, args...) } @@ -122,8 +106,9 @@ func conflictError(fmtString string, args ...interface{}) *HTTPError { // HTTPError is an error with a message and an HTTP status code. type HTTPError struct { - Code int `json:"code"` - Message string `json:"msg"` + HTTPStatus int `json:"code"` // do not rename the JSON tags! + ErrorCode string `json:"error_code"` // do not rename the JSON tags! + Message string `json:"msg"` // do not rename the JSON tags! InternalError error `json:"-"` InternalMessage string `json:"-"` ErrorID string `json:"error_id,omitempty"` @@ -133,7 +118,7 @@ func (e *HTTPError) Error() string { if e.InternalMessage != "" { return e.InternalMessage } - return fmt.Sprintf("%d: %s", e.Code, e.Message) + return fmt.Sprintf("%d: %s", e.HTTPStatus, e.Message) } func (e *HTTPError) Is(target error) bool { @@ -160,50 +145,11 @@ func (e *HTTPError) WithInternalMessage(fmtString string, args ...interface{}) * return e } -func httpError(code int, fmtString string, args ...interface{}) *HTTPError { +func httpError(httpStatus int, fmtString string, args ...interface{}) *HTTPError { return &HTTPError{ - Code: code, - Message: fmt.Sprintf(fmtString, args...), - } -} - -// OTPError is a custom error struct for phone auth errors -type OTPError struct { - Err string `json:"error"` - Description string `json:"error_description,omitempty"` - InternalError error `json:"-"` - InternalMessage string `json:"-"` -} - -func (e *OTPError) Error() string { - if e.InternalMessage != "" { - return e.InternalMessage - } - return fmt.Sprintf("%s: %s", e.Err, e.Description) -} - -// WithInternalError adds internal error information to the error -func (e *OTPError) WithInternalError(err error) *OTPError { - e.InternalError = err - return e -} - -// WithInternalMessage adds internal message information to the error -func (e *OTPError) WithInternalMessage(fmtString string, args ...interface{}) *OTPError { - e.InternalMessage = fmt.Sprintf(fmtString, args...) - return e -} - -// Cause returns the root cause error -func (e *OTPError) Cause() error { - if e.InternalError != nil { - return e.InternalError + HTTPStatus: httpStatus, + Message: fmt.Sprintf(fmtString, args...), } - return e -} - -func otpError(err string, description string) *OTPError { - return &OTPError{Err: err, Description: description} } // Recoverer is a middleware that recovers from panics, logs the panic (and a @@ -222,10 +168,10 @@ func recoverer(w http.ResponseWriter, r *http.Request) (context.Context, error) } se := &HTTPError{ - Code: http.StatusInternalServerError, - Message: http.StatusText(http.StatusInternalServerError), + HTTPStatus: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), } - handleError(se, w, r) + HandleResponseError(se, w, r) } }() @@ -237,28 +183,57 @@ type ErrorCause interface { Cause() error } -func handleError(err error, w http.ResponseWriter, r *http.Request) { +type HTTPErrorResponse20240101 struct { + Code ErrorCode `json:"code"` + Message string `json:"message"` +} + +func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { log := observability.GetLogEntry(r) errorID := getRequestID(r.Context()) + + apiVersion, averr := DetermineClosestAPIVersion(r.Header.Get(APIVersionHeaderName)) + if averr != nil { + log.WithError(averr).Warn("Invalid version passed to " + APIVersionHeaderName + " header, defaulting to initial version") + } + switch e := err.(type) { case *WeakPasswordError: - var output struct { - HTTPError - Payload struct { - Reasons []string `json:"reasons,omitempty"` - } `json:"weak_password,omitempty"` - } + if apiVersion.Compare(APIVersion20240101) >= 0 { + var output struct { + HTTPErrorResponse20240101 + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } - output.Code = http.StatusUnprocessableEntity - output.Message = e.Message - output.Payload.Reasons = e.Reasons + output.Code = ErrorCodeWeakPassword + output.Message = e.Message + output.Payload.Reasons = e.Reasons - if jsonErr := sendJSON(w, output.Code, output); jsonErr != nil { - handleError(jsonErr, w, r) + if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + + } else { + var output struct { + HTTPError + Payload struct { + Reasons []string `json:"reasons,omitempty"` + } `json:"weak_password,omitempty"` + } + + output.HTTPStatus = http.StatusUnprocessableEntity + output.Message = e.Message + output.Payload.Reasons = e.Reasons + + if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } } case *HTTPError: - if e.Code >= http.StatusInternalServerError { + if e.HTTPStatus >= http.StatusInternalServerError { e.ErrorID = errorID // this will get us the stack trace too log.WithError(e.Cause()).Error(e.Error()) @@ -266,29 +241,36 @@ func handleError(err error, w http.ResponseWriter, r *http.Request) { log.WithError(e.Cause()).Info(e.Error()) } - // Provide better error messages for certain user-triggered Postgres errors. - if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { - if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil { - handleError(jsonErr, w, r) + if apiVersion.Compare(APIVersion20240101) >= 0 { + resp := HTTPErrorResponse20240101{ + Code: e.ErrorCode, + Message: e.Message, + } + + if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + } else { + // Provide better error messages for certain user-triggered Postgres errors. + if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { + if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } + return } - return - } - if jsonErr := sendJSON(w, e.Code, e); jsonErr != nil { - handleError(jsonErr, w, r) + if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil { + HandleResponseError(jsonErr, w, r) + } } + case *OAuthError: log.WithError(e.Cause()).Info(e.Error()) if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { - handleError(jsonErr, w, r) - } - case *OTPError: - log.WithError(e.Cause()).Info(e.Error()) - if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { - handleError(jsonErr, w, r) + HandleResponseError(jsonErr, w, r) } case ErrorCause: - handleError(e.Cause(), w, r) + HandleResponseError(e.Cause(), w, r) default: log.WithError(e).Errorf("Unhandled server error: %s", e.Error()) // hide real error details from response to prevent info leaks diff --git a/internal/api/errors_test.go b/internal/api/errors_test.go new file mode 100644 index 0000000000..c35909e899 --- /dev/null +++ b/internal/api/errors_test.go @@ -0,0 +1,45 @@ +package api + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHandleResponseErrorWithHTTPError(t *testing.T) { + examples := []struct { + HTTPError *HTTPError + APIVersion string + ExpectedBody string + }{ + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "", + }, + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2023-12-31", + }, + { + HTTPError: badRequestError(ErrorCodeBadJSON, "Unable to parse JSON"), + APIVersion: "2024-01-01", + }, + } + + for _, example := range examples { + rec := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodPost, "http://example.com", nil) + require.NoError(t, err) + + if example.APIVersion != "" { + req.Header.Set(APIVersionHeaderName, example.APIVersion) + } + + HandleResponseError(example.HTTPError, rec, req) + + require.Equal(t, rec.Code, example.HTTPError.HTTPStatus) + require.Equal(t, rec.Body.String(), "") + } +} diff --git a/internal/api/external.go b/internal/api/external.go index a20e3ba597..68a9370f21 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -56,7 +56,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ p, err := a.Provider(ctx, providerType, scopes) if err != nil { - return "", badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return "", badRequestError(ErrorCodeValidationFailed, "Unsupported provider: %+v", err).WithInternalError(err) } inviteToken := query.Get("invite_token") @@ -64,7 +64,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ _, userErr := models.FindUserByConfirmationToken(db, inviteToken) if userErr != nil { if models.IsNotFoundError(userErr) { - return "", notFoundError(userErr.Error()) + return "", notFoundError(ErrorCodeUserNotFound, "User identified by token not found") } return "", internalServerError("Database error finding user").WithInternalError(userErr) } @@ -82,14 +82,12 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ if flowType == models.PKCEFlow { codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod) if err != nil { - return "", err - } - flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth) - if err != nil { - return "", err + return "", badRequestError(ErrorCodeValidationFailed, "Code challenge not valid").WithInternalError(err) } + flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth) + if err := a.db.Create(flowState); err != nil { - return "", err + return "", internalServerError("Failed to create flow state").WithInternalError(err) } flowStateID = flowState.ID.String() } @@ -134,6 +132,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ } authURL := p.AuthCodeURL(tokenString, authUrlParams...) + return authURL, nil } @@ -203,9 +202,12 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re // if there's a non-empty FlowStateID we perform PKCE Flow if flowStateID := getFlowStateID(ctx); flowStateID != "" { flowState, err = models.FindFlowStateByID(a.db, flowStateID) - if err != nil { - return err + if models.IsNotFoundError(err) { + return notFoundError(ErrorCodeFlowStateNotFound, "Flow state not found").WithInternalError(err) + } else if err != nil { + return internalServerError("Failed to find flow state").WithInternalError(err) } + } var user *models.User @@ -304,7 +306,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. case models.CreateAccount: if config.DisableSignup { - return nil, forbiddenError("Signups not allowed for this instance") + return nil, unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{ @@ -351,14 +353,14 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } case models.MultipleAccounts: - return nil, internalServerError(fmt.Sprintf("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain)) + return nil, internalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain) default: - return nil, internalServerError(fmt.Sprintf("Unknown automatic linking decision: %v", decision.Decision)) + return nil, internalServerError("Unknown automatic linking decision: %v", decision.Decision) } if user.IsBanned() { - return nil, unauthorizedError("User is unauthorized") + return nil, unauthorizedError(ErrorCodeUserBanned, "User is banned") } if !user.IsConfirmed() { @@ -391,7 +393,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. externalURL := getExternalHost(ctx) if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { - return nil, tooManyRequestsError("For security purposes, you can only request this once every minute") + return nil, tooManyRequestsError(ErrorCodeOverEmailSendRate, "For security purposes, you can only request this once every minute") } return nil, internalServerError("Error sending confirmation mail").WithInternalError(terr) } @@ -399,9 +401,9 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } if !config.Mailer.AllowUnverifiedEmailSignIns { if emailConfirmationSent { - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) + return nil, storage.NewCommitWithError(unauthorizedError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) } - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) + return nil, storage.NewCommitWithError(unauthorizedError(ErrorCodeProviderEmailNeedsVerification, fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) } } } else { @@ -423,7 +425,7 @@ func (a *API) processInvite(r *http.Request, ctx context.Context, tx *storage.Co user, err := models.FindUserByConfirmationToken(tx, inviteToken) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()) + return nil, notFoundError(ErrorCodeInviteNotFound, "Invite not found") } return nil, internalServerError("Database error finding user").WithInternalError(err) } @@ -439,7 +441,7 @@ func (a *API) processInvite(r *http.Request, ctx context.Context, tx *storage.Co } if emailData == nil { - return nil, badRequestError("Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) + return nil, badRequestError(ErrorCodeValidationFailed, "Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) } var identityData map[string]interface{} @@ -495,8 +497,11 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont _, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) { return []byte(config.JWT.Secret), nil }) - if err != nil || claims.Provider == "" { - return nil, badRequestError("OAuth state is invalid: %v", err) + if err != nil { + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err) + } + if claims.Provider == "" { + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (missing provider)") } if claims.InviteToken != "" { ctx = withInviteToken(ctx, claims.InviteToken) @@ -510,12 +515,12 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont if claims.LinkingTargetID != "" { linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID) if err != nil { - return nil, badRequestError("invalid target user id") + return nil, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state (linking_target_id must be UUID)") } u, err := models.FindUserByID(a.db, linkingTargetUserID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("Linking target user not found") + return nil, notFoundError(ErrorCodeUserNotFound, "Linking target user not found") } return nil, internalServerError("Database error loading user").WithInternalError(err) } @@ -606,12 +611,12 @@ func (a *API) redirectErrors(handler apiHandler, w http.ResponseWriter, r *http. func getErrorQueryString(err error, errorID string, log logrus.FieldLogger, q url.Values) *url.Values { switch e := err.(type) { case *HTTPError: - if str, ok := oauthErrorMap[e.Code]; ok { + if str, ok := oauthErrorMap[e.HTTPStatus]; ok { q.Set("error", str) } else { q.Set("error", "server_error") } - if e.Code >= http.StatusInternalServerError { + if e.HTTPStatus >= http.StatusInternalServerError { e.ErrorID = errorID // this will get us the stack trace too log.WithError(e.Cause()).Error(e.Error()) @@ -619,7 +624,7 @@ func getErrorQueryString(err error, errorID string, log logrus.FieldLogger, q ur log.WithError(e.Cause()).Info(e.Error()) } q.Set("error_description", e.Message) - q.Set("error_code", strconv.Itoa(e.Code)) + q.Set("error_code", strconv.Itoa(e.HTTPStatus)) case *OAuthError: q.Set("error", e.Err) q.Set("error_description", e.Description) diff --git a/internal/api/external_figma_test.go b/internal/api/external_figma_test.go index 56d2f478d1..758a4b8f26 100644 --- a/internal/api/external_figma_test.go +++ b/internal/api/external_figma_test.go @@ -260,5 +260,5 @@ func (ts *ExternalTestSuite) TestSignupExternalFigmaErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "figma", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "unauthorized_client", "") } diff --git a/internal/api/external_fly_test.go b/internal/api/external_fly_test.go index c469f29004..ac82e60109 100644 --- a/internal/api/external_fly_test.go +++ b/internal/api/external_fly_test.go @@ -260,5 +260,5 @@ func (ts *ExternalTestSuite) TestSignupExternalFlyErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "fly", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "unauthorized_client", "") } diff --git a/internal/api/external_github_test.go b/internal/api/external_github_test.go index b3ad584407..3006aafc83 100644 --- a/internal/api/external_github_test.go +++ b/internal/api/external_github_test.go @@ -296,5 +296,5 @@ func (ts *ExternalTestSuite) TestSignupExternalGitHubErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "github", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "unauthorized_client", "") } diff --git a/internal/api/external_kakao_test.go b/internal/api/external_kakao_test.go index 7882e1dce3..1ea92b66a0 100644 --- a/internal/api/external_kakao_test.go +++ b/internal/api/external_kakao_test.go @@ -234,5 +234,5 @@ func (ts *ExternalTestSuite) TestSignupExternalKakaoErrorWhenUserBanned() { require.NoError(ts.T(), ts.API.db.UpdateOnly(user, "banned_until")) u = performAuthorization(ts, "kakao", code, "") - assertAuthorizationFailure(ts, u, "User is unauthorized", "unauthorized_client", "") + assertAuthorizationFailure(ts, u, "User is banned", "unauthorized_client", "") } diff --git a/internal/api/external_oauth.go b/internal/api/external_oauth.go index e37b7d6820..63aaf895bc 100644 --- a/internal/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -2,6 +2,7 @@ package api import ( "context" + "fmt" "net/http" "net/url" @@ -30,7 +31,7 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con } if state == "" { - return nil, badRequestError("OAuth state parameter missing") + return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth state parameter missing") } ctx := r.Context() @@ -60,12 +61,12 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s oauthCode := rq.Get("code") if oauthCode == "" { - return nil, badRequestError("Authorization code missing") + return nil, badRequestError(ErrorCodeBadOAuthCallback, "OAuth callback with missing authorization code missing") } oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } log := observability.GetLogEntry(r) @@ -107,7 +108,7 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s func (a *API) oAuth1Callback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) { oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return nil, badRequestError(ErrorCodeOAuthProviderNotSupported, "Unsupported provider: %+v", err).WithInternalError(err) } oauthToken := getRequestToken(ctx) oauthVerifier := getOAuthVerifier(ctx) @@ -145,6 +146,6 @@ func (a *API) OAuthProvider(ctx context.Context, name string) (provider.OAuthPro case provider.OAuthProvider: return p, nil default: - return nil, badRequestError("Provider can not be used for OAuth") + return nil, fmt.Errorf("Provider %v cannot be used for OAuth", name) } } diff --git a/internal/api/helpers_test.go b/internal/api/helpers_test.go index ec5812e093..15f9ce4d67 100644 --- a/internal/api/helpers_test.go +++ b/internal/api/helpers_test.go @@ -16,12 +16,12 @@ func TestIsValidCodeChallenge(t *testing.T) { { challenge: "invalid", isValid: false, - expectedError: badRequestError("code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), + expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength), }, { challenge: "codechallengecontainsinvalidcharacterslike@$^&*", isValid: false, - expectedError: badRequestError("code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), + expectedError: badRequestError(ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes"), }, { challenge: "validchallengevalidchallengevalidchallengevalidchallenge", @@ -56,12 +56,12 @@ func TestIsValidPKCEParmas(t *testing.T) { { challengeMethod: "test", challenge: "", - expected: badRequestError(InvalidPKCEParamsErrorMessage), + expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), }, { challengeMethod: "", challenge: "test", - expected: badRequestError(InvalidPKCEParamsErrorMessage), + expected: badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage), }, } diff --git a/internal/api/hook_test.go b/internal/api/hook_test.go index a5d573db09..c9e7a703a3 100644 --- a/internal/api/hook_test.go +++ b/internal/api/hook_test.go @@ -158,7 +158,7 @@ func TestHookTimeout(t *testing.T) { require.Error(t, err) herr, ok := err.(*HTTPError) require.True(t, ok) - assert.Equal(t, http.StatusGatewayTimeout, herr.Code) + assert.Equal(t, http.StatusGatewayTimeout, herr.HTTPStatus) svr.Close() assert.Equal(t, 3, callCount) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 2fd31983e2..d988b4f143 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -135,7 +135,7 @@ func (w *Webhook) trigger() (io.ReadCloser, error) { } hooklog.Infof("Failed to process webhook for %s after %d attempts", w.URL, w.Retries) - return nil, unprocessableEntityError("Failed to handle signup webhook") + return nil, internalServerError("Failed to handle signup webhook") } func (w *Webhook) generateSignature() (string, error) { @@ -347,8 +347,8 @@ func (a *API) invokeHook(ctx context.Context, input, output any) error { } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -373,8 +373,8 @@ func (a *API) invokeHook(ctx context.Context, input, output any) error { } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -399,8 +399,8 @@ func (a *API) invokeHook(ctx context.Context, input, output any) error { } httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, + HTTPStatus: httpCode, + Message: hookOutput.HookError.Message, } return httpError.WithInternalError(&hookOutput.HookError) @@ -413,8 +413,8 @@ func (a *API) invokeHook(ctx context.Context, input, output any) error { } httpError := &HTTPError{ - Code: httpCode, - Message: err.Error(), + HTTPStatus: httpCode, + Message: err.Error(), } return httpError diff --git a/internal/api/identity.go b/internal/api/identity.go index 14f2c167d9..465d4ca797 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -17,22 +17,22 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { claims := getClaims(ctx) if claims == nil { - return badRequestError("Could not read claims") + return internalServerError("Could not read claims") } aud := a.requestAud(ctx, r) if aud != claims.Audience { - return badRequestError("Token audience doesn't match request audience") + return unauthorizedError(ErrorCodeUnexpectedAudience, "Token audience doesn't match request audience") } identityID, err := uuid.FromString(chi.URLParam(r, "identity_id")) if err != nil { - return badRequestError("identity_id must be an UUID") + return badRequestError(ErrorCodeValidationFailed, "identity_id must be an UUID") } user := getUser(ctx) if len(user.Identities) <= 1 { - return badRequestError("User must have at least 1 identity after unlinking") + return unprocessableEntityError(ErrorCodeLastIdentityNotDeletable, "User must have at least 1 identity after unlinking") } var identityToBeDeleted *models.Identity for i := range user.Identities { @@ -43,7 +43,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { } } if identityToBeDeleted == nil { - return badRequestError("Identity doesn't exist") + return notFoundError(ErrorCodeIdentityNotFound, "Identity doesn't exist") } err = a.db.Transaction(func(tx *storage.Connection) error { @@ -59,7 +59,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { } if terr := user.UpdateUserEmail(tx); terr != nil { if models.IsUniqueConstraintViolatedError(terr) { - return forbiddenError("Unable to unlink identity due to email conflict").WithInternalError(terr) + return unprocessableEntityError(ErrorCodeEmailConflictIdentityNotDeletable, "Unable to unlink identity due to email conflict").WithInternalError(terr) } return internalServerError("Database error updating user email").WithInternalError(terr) } @@ -102,9 +102,9 @@ func (a *API) linkIdentityToUser(ctx context.Context, tx *storage.Connection, us } if identity != nil { if identity.UserID == targetUser.ID { - return nil, badRequestError("Identity is already linked") + return nil, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked") } - return nil, badRequestError("Identity is already linked to another user") + return nil, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked to another user") } if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil { return nil, terr diff --git a/internal/api/identity_test.go b/internal/api/identity_test.go index 8238bf6c20..f6bca67dfd 100644 --- a/internal/api/identity_test.go +++ b/internal/api/identity_test.go @@ -72,6 +72,6 @@ func (ts *IdentityTestSuite) TestLinkIdentityToUser() { }, } u, err = ts.API.linkIdentityToUser(ctx, ts.API.db, testExistingUserData, "email") - require.ErrorIs(ts.T(), err, badRequestError("Identity is already linked")) + require.ErrorIs(ts.T(), err, unprocessableEntityError(ErrorCodeIdentityAlreadyExists, "Identity is already linked")) require.Nil(ts.T(), u) } diff --git a/internal/api/invite.go b/internal/api/invite.go index 2a0aeb51d9..7d4ea593e6 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -27,11 +27,11 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read Invite params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read Invite params: %v", err).WithInternalError(err) } params.Email, err = validateEmail(params.Email) @@ -48,7 +48,7 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { err = db.Transaction(func(tx *storage.Connection) error { if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } else { signupParams := SignupParams{ diff --git a/internal/api/logout.go b/internal/api/logout.go index ad95b22a49..cd1394edaa 100644 --- a/internal/api/logout.go +++ b/internal/api/logout.go @@ -36,7 +36,7 @@ func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { scope = LogoutOthers default: - return badRequestError(fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) + return badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) } } diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index e1b12caafd..307f9f5da5 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -24,7 +25,7 @@ type MagicLinkParams struct { func (p *MagicLinkParams) Validate() error { if p.Email == "" { - return unprocessableEntityError("Password recovery requires an email") + return unprocessableEntityError(ErrorCodeValidationFailed, "Password recovery requires an email") } var err error p.Email, err = validateEmail(p.Email) @@ -44,14 +45,14 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } params := &MagicLinkParams{} jsonDecoder := json.NewDecoder(r.Body) err := jsonDecoder.Decode(params) if err != nil { - return badRequestError("Could not read verification params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read verification params: %v", err).WithInternalError(err) } if err := params.Validate(); err != nil { @@ -82,7 +83,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { // Sign them up with temporary password. password, err := password.Generate(64, 10, 1, false, true) if err != nil { - internalServerError("error creating user").WithInternalError(err) + return internalServerError("error creating user").WithInternalError(err) } signUpParams := &SignupParams{ @@ -94,7 +95,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) } r.Body = io.NopCloser(strings.NewReader(string(newBodyContent))) r.ContentLength = int64(len(string(newBodyContent))) @@ -113,7 +115,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { } metadata, err := json.Marshal(newBodyContent) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must always be marshallable + panic(fmt.Errorf("failed to marshal SignupParams: %w", err)) } r.Body = io.NopCloser(bytes.NewReader(metadata)) return a.MagicLink(w, r) @@ -148,7 +151,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRate, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Error sending magic link").WithInternalError(err) } diff --git a/internal/api/mail.go b/internal/api/mail.go index 6d4bd0817c..4910a83df0 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -52,11 +52,11 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not parse JSON: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not parse JSON: %v", err).WithInternalError(err) } params.Email, err = validateEmail(params.Email) @@ -72,14 +72,17 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) if err != nil { if models.IsNotFoundError(err) { - if params.Type == magicLinkVerification { + switch params.Type { + case magicLinkVerification: params.Type = signupVerification params.Password, err = password.Generate(64, 10, 1, false, true) if err != nil { - return internalServerError("error creating user").WithInternalError(err) + // password generation must always succeed + panic(err) } - } else if params.Type == recoveryVerification || params.Type == "email_change_current" || params.Type == "email_change_new" { - return notFoundError(err.Error()) + + default: + return notFoundError(ErrorCodeUserNotFound, "User with this email not found") } } else { return internalServerError("Database error finding user").WithInternalError(err) @@ -90,7 +93,8 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { now := time.Now() otp, err := crypto.GenerateOtp(config.Mailer.OtpLength) if err != nil { - return err + // OTP generation must always succeed + panic(err) } hashedToken := crypto.GenerateTokenHash(params.Email, otp) @@ -124,11 +128,14 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } user.RecoveryToken = hashedToken user.RecoverySentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + terr = tx.UpdateOnly(user, "recovery_token", "recovery_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for recovery") + } case inviteVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } else { signupParams := &SignupParams{ @@ -168,11 +175,14 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { user.ConfirmationToken = hashedToken user.ConfirmationSentAt = &now user.InvitedAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for invite") + } case signupVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } if err := user.UpdateUserMetaData(tx, params.Data); err != nil { return internalServerError("Database error updating user").WithInternalError(err) @@ -197,19 +207,22 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } user.ConfirmationToken = hashedToken user.ConfirmationSentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for confirmation") + } case "email_change_current", "email_change_new": if !config.Mailer.SecureEmailChangeEnabled && params.Type == "email_change_current" { - return unprocessableEntityError("Enable secure email change to generate link for current email") + return badRequestError(ErrorCodeValidationFailed, "Enable secure email change to generate link for current email") } params.NewEmail, terr = validateEmail(params.NewEmail) if terr != nil { - return unprocessableEntityError("The new email address provided is invalid") + return terr } if duplicateUser, terr := models.IsDuplicatedEmail(tx, params.NewEmail, user.Aud, user); terr != nil { return internalServerError("Database error checking email").WithInternalError(terr) } else if duplicateUser != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } now := time.Now() user.EmailChangeSentAt = &now @@ -220,9 +233,12 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } else if params.Type == "email_change_new" { user.EmailChangeTokenNew = crypto.GenerateTokenHash(params.NewEmail, otp) } - terr = errors.Wrap(tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status"), "Database error updating user for email change") + terr = tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for email change") + } default: - return badRequestError("Invalid email action link type requested: %v", params.Type) + return badRequestError(ErrorCodeValidationFailed, "Invalid email action link type requested: %v", params.Type) } if terr != nil { @@ -261,7 +277,8 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail oldToken := u.ConfirmationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.ConfirmationToken = addFlowPrefixToToken(token, flowType) @@ -271,7 +288,12 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail return errors.Wrap(err, "Error sending confirmation email") } u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for confirmation") + } + + return nil } func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, externalURL *url.URL, otpLength int) error { @@ -279,7 +301,8 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re oldToken := u.ConfirmationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.ConfirmationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) now := time.Now() @@ -289,7 +312,12 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re } u.InvitedAt = &now u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for invite") + } + + return nil } func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { @@ -301,7 +329,8 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile oldToken := u.RecoveryToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) @@ -311,7 +340,12 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile return errors.Wrap(err, "Error sending recovery email") } u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for recovery") + } + + return nil } func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, otpLength int) error { @@ -323,19 +357,22 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma oldToken := u.ReauthenticationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) - if err != nil { - return err - } now := time.Now() if err := mailer.ReauthenticateMail(u, otp); err != nil { u.ReauthenticationToken = oldToken return errors.Wrap(err, "Error sending reauthentication email") } u.ReauthenticationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"), "Database error updating user for reauthentication") + err = tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for reauthentication") + } + + return nil } func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { @@ -348,7 +385,8 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile oldToken := u.RecoveryToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) @@ -359,7 +397,12 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile return errors.Wrap(err, "Error sending magic link email") } u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for recovery") + } + + return nil } // sendEmailChange sends out an email change token to the new email. @@ -370,7 +413,8 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu } otpNew, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.EmailChange = email token := crypto.GenerateTokenHash(u.EmailChange, otpNew) @@ -380,7 +424,8 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu if config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { otpCurrent, err = crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } currentToken := crypto.GenerateTokenHash(u.GetEmail(), otpCurrent) u.EmailChangeTokenCurrent = addFlowPrefixToToken(currentToken, flowType) @@ -396,22 +441,28 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu } u.EmailChangeSentAt = &now - return errors.Wrap(tx.UpdateOnly( + err = tx.UpdateOnly( u, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status", - ), "Database error updating user for email change") + ) + + if err != nil { + return errors.Wrap(err, "Database error updating user for email change") + } + + return nil } func validateEmail(email string) (string, error) { if email == "" { - return "", unprocessableEntityError("An email address is required") + return "", badRequestError(ErrorCodeValidationFailed, "An email address is required") } if err := checkmail.ValidateFormat(email); err != nil { - return "", unprocessableEntityError("Unable to validate email address: " + err.Error()) + return "", badRequestError(ErrorCodeValidationFailed, "Unable to validate email address: "+err.Error()) } return strings.ToLower(email), nil } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 434a7117c7..6c0e00052e 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -72,11 +72,11 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "invalid body: unable to parse JSON").WithInternalError(err) } if params.FactorType != models.TOTP { - return badRequestError("factor_type needs to be totp") + return badRequestError(ErrorCodeValidationFailed, "factor_type needs to be totp") } if params.Issuer == "" { @@ -96,7 +96,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if len(factors) >= int(config.MFA.MaxEnrolledFactors) { - return forbiddenError("Enrolled factors exceed allowed limit, unenroll to continue") + return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Enrolled factors exceed allowed limit, unenroll to continue") } numVerifiedFactors := 0 @@ -107,7 +107,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return forbiddenError("Maximum number of enrolled factors reached, unenroll to continue") + return forbiddenError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of enrolled factors reached, unenroll to continue") } key, err := totp.Generate(totp.GenerateOpts{ @@ -133,7 +133,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { if terr := tx.Create(factor); terr != nil { pgErr := utilities.NewPostgresError(terr) if pgErr.IsUniqueConstraintViolated() { - return badRequestError(fmt.Sprintf("a factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) + return unprocessableEntityError(ErrorCodeMFAFactorNameConflict, fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) } return terr @@ -209,7 +209,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "invalid body: unable to parse JSON").WithInternalError(err) } if !factor.IsOwnedBy(user) { @@ -219,13 +219,13 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { challenge, err := models.FindChallengeByChallengeID(a.db, params.ChallengeID) if err != nil { if models.IsNotFoundError(err) { - return notFoundError(err.Error()) + return notFoundError(ErrorCodeMFAFactorNotFound, "MFA factor with the provided challenge ID not found") } return internalServerError("Database error finding Challenge").WithInternalError(err) } if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { - return badRequestError("Challenge and verify IP addresses mismatch") + return unprocessableEntityError(ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch") } if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { @@ -239,7 +239,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID) + return unprocessableEntityError(ErrorCodeMFAChallengeExpired, "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) } valid := totp.Validate(params.Code, factor.Secret) @@ -267,11 +267,11 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { output.Message = hooks.DefaultMFAHookRejectionMessage } - return forbiddenError(output.Message) + return forbiddenError(ErrorCodeMFAVerificationRejected, output.Message) } } if !valid { - return badRequestError("Invalid TOTP code entered") + return unprocessableEntityError(ErrorCodeMFAVerificationFailed, "Invalid TOTP code entered") } var token *AccessTokenResponse @@ -332,7 +332,7 @@ func (a *API) UnenrollFactor(w http.ResponseWriter, r *http.Request) error { } if factor.IsVerified() && !session.IsAAL2() { - return badRequestError("AAL2 required to unenroll verified factor") + return unprocessableEntityError(ErrorCodeInsufficientAAL, "AAL2 required to unenroll verified factor") } if !factor.IsOwnedBy(user) { return internalServerError(InvalidFactorOwnerErrorMessage) diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 5fe7f0bc60..8d30be7d1d 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -5,13 +5,14 @@ import ( "context" "encoding/json" "fmt" - "github.com/gofrs/uuid" "net/http" "net/http/httptest" "strings" "testing" "time" + "github.com/gofrs/uuid" + "github.com/pquerna/otp" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" @@ -163,14 +164,14 @@ func (ts *MFATestSuite) TestDuplicateEnrollsReturnExpectedMessage() { token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, ts.TestUser, nil, models.TOTPSignIn) require.NoError(ts.T(), err) _ = performEnrollFlow(ts, token, friendlyName, models.TOTP, "https://issuer.com", http.StatusOK) - response := performEnrollFlow(ts, token, friendlyName, models.TOTP, "https://issuer.com", http.StatusBadRequest) + response := performEnrollFlow(ts, token, friendlyName, models.TOTP, "https://issuer.com", http.StatusUnprocessableEntity) var errorResponse HTTPError err = json.NewDecoder(response.Body).Decode(&errorResponse) require.NoError(ts.T(), err) // Convert the response body to a string and check for the expected error message - expectedErrorMessage := fmt.Sprintf("a factor with the friendly name %q for this user likely already exists", friendlyName) + expectedErrorMessage := fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", friendlyName) require.Contains(ts.T(), errorResponse.Message, expectedErrorMessage) } @@ -193,13 +194,13 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { desc: "Invalid: Valid code and expired challenge", validChallenge: false, validCode: true, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Invalid: Invalid code and valid challenge ", validChallenge: true, validCode: false, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Valid /verify request", @@ -284,7 +285,7 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() { { desc: "Verified Factor: AAL1", isAAL2: false, - expectedHTTPCode: http.StatusBadRequest, + expectedHTTPCode: http.StatusUnprocessableEntity, }, { desc: "Verified Factor: AAL2, Success", diff --git a/internal/api/middleware.go b/internal/api/middleware.go index c0f3d857a1..75eedace01 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -106,7 +106,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { } if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { - return c, badRequestError("Error invalid request body").WithInternalError(err) + return c, badRequestError(ErrorCodeBadJSON, "Error invalid request body").WithInternalError(err) } if shouldRateLimitEmail { @@ -117,7 +117,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { 1, attribute.String("path", req.URL.Path), ) - return c, httpError(http.StatusTooManyRequests, "Email rate limit exceeded") + return c, tooManyRequestsError(ErrorCodeOverEmailSendRate, "Email rate limit exceeded") } } } @@ -125,7 +125,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { if shouldRateLimitPhone { if requestBody.Phone != "" { if err := tollbooth.LimitByKeys(phoneLimiter, []string{"phone_functions"}); err != nil { - return c, httpError(http.StatusTooManyRequests, "Sms rate limit exceeded") + return c, tooManyRequestsError(ErrorCodeOverSMSSendRate, "SMS rate limit exceeded") } } } @@ -156,7 +156,7 @@ func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (co config := a.config if !config.External.Email.Enabled { - return nil, badRequestError("Email logins are disabled") + return nil, badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } return ctx, nil @@ -183,8 +183,7 @@ func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.C } if !verificationResult.Success { - return nil, badRequestError("captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) - + return nil, badRequestError(ErrorCodeCaptchaFailed, "captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) } return ctx, nil @@ -228,7 +227,7 @@ func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (con func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.SAML.Enabled { - return nil, notFoundError("SAML 2.0 is disabled") + return nil, notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled") } return ctx, nil } @@ -236,7 +235,7 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.Security.ManualLinkingEnabled { - return nil, notFoundError("Manual linking is disabled") + return nil, notFoundError(ErrorCodeManualLinkingDisabled, "Manual linking is disabled") } return ctx, nil } diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index c532a50ef6..e591121f86 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -176,7 +176,7 @@ func (ts *MiddlewareTestSuite) TestVerifyCaptchaInvalid() { w := httptest.NewRecorder() _, err := ts.API.verifyCaptcha(w, req) - require.Equal(ts.T(), c.expectedCode, err.(*HTTPError).Code) + require.Equal(ts.T(), c.expectedCode, err.(*HTTPError).HTTPStatus) require.Equal(ts.T(), c.expectedMsg, err.(*HTTPError).Message) }) } @@ -201,8 +201,8 @@ func (ts *MiddlewareTestSuite) TestLimitEmailOrPhoneSentHandler() { }, }, { - desc: "Sms rate limit exceeded", - expectedErrorMsg: "429: Sms rate limit exceeded", + desc: "SMS rate limit exceeded", + expectedErrorMsg: "429: SMS rate limit exceeded", requestBody: map[string]interface{}{ "phone": "+1233456789", }, @@ -269,7 +269,7 @@ func (ts *MiddlewareTestSuite) TestRequireSAMLEnabled() { { desc: "SAML not enabled", isEnabled: false, - expectedErr: notFoundError("SAML 2.0 is disabled"), + expectedErr: notFoundError(ErrorCodeSAMLProviderDisabled, "SAML 2.0 is disabled"), }, { desc: "SAML enabled", diff --git a/internal/api/otp.go b/internal/api/otp.go index 0700f09705..016a56159d 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -34,10 +34,10 @@ type SmsParams struct { func (p *OtpParams) Validate() error { if p.Email != "" && p.Phone != "" { - return badRequestError("Only an email address or phone number should be provided") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided") } if p.Email != "" && p.Channel != "" { - return badRequestError("Channel should only be specified with Phone OTP") + return badRequestError(ErrorCodeValidationFailed, "Channel should only be specified with Phone OTP") } if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { return err @@ -47,7 +47,7 @@ func (p *OtpParams) Validate() error { func (p *SmsParams) Validate(smsProvider string) error { if p.Phone != "" && !sms_provider.IsValidMessageChannel(p.Channel, smsProvider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } var err error @@ -74,7 +74,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { } if err = json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read verification params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read verification params: %v", err) } if err := params.Validate(); err != nil { @@ -85,7 +85,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { } if ok, err := a.shouldCreateUser(r, params); !ok { - return badRequestError("Signups not allowed for otp") + return unprocessableEntityError(ErrorCodeOTPDisabled, "Signups not allowed for otp") } else if err != nil { return err } @@ -96,7 +96,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { return a.SmsOtp(w, r) } - return otpError("unsupported_otp_type", "") + return badRequestError(ErrorCodeValidationFailed, "One of email or phone must be set") } type SmsOtpResponse struct { @@ -110,7 +110,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Phone.Enabled { - return badRequestError("Unsupported phone provider") + return badRequestError(ErrorCodePhoneProviderDisabled, "Unsupported phone provider") } var err error @@ -118,11 +118,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read sms otp params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read sms otp params: %v", err) } // For backwards compatibility, we default to SMS if params Channel is not specified if params.Phone != "" && params.Channel == "" { @@ -151,7 +151,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { // Sign them up with temporary password. password, err := password.Generate(64, 10, 1, false, true) if err != nil { - internalServerError("error creating user").WithInternalError(err) + return internalServerError("error creating user").WithInternalError(err) } signUpParams := &SignupParams{ @@ -162,7 +162,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must be marshallable + panic(err) } r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) @@ -180,7 +181,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must be marshallable + panic(err) } r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) return a.SmsOtp(w, r) @@ -201,11 +203,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Unable to get SMS provider").WithInternalError(err) } mID, serr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) if serr != nil { - return badRequestError("Error sending sms OTP: %v", serr) + return badRequestError(ErrorCodeSMSSendFailed, "Error sending sms OTP: %v", serr).WithInternalError(serr) } messageID = mID return nil diff --git a/internal/api/otp_test.go b/internal/api/otp_test.go index be3b181145..a007a232ae 100644 --- a/internal/api/otp_test.go +++ b/internal/api/otp_test.go @@ -115,10 +115,10 @@ func (ts *OtpTestSuite) TestOtpPKCE() { code int response map[string]interface{} }{ - http.StatusBadRequest, + http.StatusInternalServerError, map[string]interface{}{ - "code": float64(http.StatusBadRequest), - "msg": "Error sending sms:", + "code": float64(http.StatusInternalServerError), + "msg": "Unable to get SMS provider", }, }, }, diff --git a/internal/api/phone.go b/internal/api/phone.go index 5d6e9bda30..fcd28b51c4 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -25,7 +25,7 @@ const ( func validatePhone(phone string) (string, error) { phone = formatPhoneNumber(phone) if isValid := validateE164Format(phone); !isValid { - return "", unprocessableEntityError("Invalid phone number format (E.164 required)") + return "", badRequestError(ErrorCodeValidationFailed, "Invalid phone number format (E.164 required)") } return phone, nil } diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go index 911cb2c710..24cfbc06e3 100644 --- a/internal/api/phone_test.go +++ b/internal/api/phone_test.go @@ -178,8 +178,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "password": "testpassword", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending confirmation sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -191,8 +191,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "phone": "123456789", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -204,8 +204,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { "phone": "111111111", }, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, { @@ -215,8 +215,8 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { header: "", body: nil, expected: map[string]interface{}{ - "code": http.StatusBadRequest, - "message": "Error sending sms:", + "code": http.StatusInternalServerError, + "message": "Unable to get SMS provider", }, }, } @@ -245,7 +245,12 @@ func (ts *PhoneTestSuite) TestMissingSmsProviderConfig() { require.Equal(ts.T(), c.expected["code"], w.Code) body := w.Body.String() - require.True(ts.T(), strings.Contains(body, c.expected["message"].(string))) + require.True(ts.T(), + strings.Contains(body, "Unable to get SMS provider") || + strings.Contains(body, "Error finding SMS provider") || + strings.Contains(body, "Failed to get SMS provider"), + "unexpected body message %q", body, + ) }) } } diff --git a/internal/api/pkce.go b/internal/api/pkce.go index a186aa4649..bd2809e185 100644 --- a/internal/api/pkce.go +++ b/internal/api/pkce.go @@ -21,9 +21,9 @@ func isValidCodeChallenge(codeChallenge string) (bool, error) { // See RFC 7636 Section 4.2: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 switch codeChallengeLength := len(codeChallenge); { case codeChallengeLength < MinCodeChallengeLength, codeChallengeLength > MaxCodeChallengeLength: - return false, badRequestError("code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) + return false, badRequestError(ErrorCodeValidationFailed, "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) case !codeChallengePattern.MatchString(codeChallenge): - return false, badRequestError("code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") + return false, badRequestError(ErrorCodeValidationFailed, "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") default: return true, nil } @@ -41,7 +41,7 @@ func addFlowPrefixToToken(token string, flowType models.FlowType) string { func issueAuthCode(tx *storage.Connection, user *models.User, expiryDuration time.Duration, authenticationMethod models.AuthenticationMethod) (string, error) { flowState, err := models.FindFlowStateByUserID(tx, user.ID.String(), authenticationMethod) if err != nil && models.IsNotFoundError(err) { - return "", badRequestError("No valid flow state found for user.") + return "", unprocessableEntityError(ErrorCodeFlowStateNotFound, "No valid flow state found for user.") } else if err != nil { return "", err } @@ -59,7 +59,7 @@ func isImplicitFlow(flowType models.FlowType) bool { func validatePKCEParams(codeChallengeMethod, codeChallenge string) error { switch true { case (codeChallenge == "") != (codeChallengeMethod == ""): - return badRequestError(InvalidPKCEParamsErrorMessage) + return badRequestError(ErrorCodeValidationFailed, InvalidPKCEParamsErrorMessage) case codeChallenge != "": if valid, err := isValidCodeChallenge(codeChallenge); !valid { return err diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go index e29f2f7e32..8f8df9af72 100644 --- a/internal/api/reauthenticate.go +++ b/internal/api/reauthenticate.go @@ -23,16 +23,16 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { email, phone := user.GetEmail(), user.GetPhone() if email == "" && phone == "" { - return unprocessableEntityError("Reauthentication requires the user to have an email or a phone number") + return badRequestError(ErrorCodeValidationFailed, "Reauthentication requires the user to have an email or a phone number") } if email != "" { if !user.IsConfirmed() { - return badRequestError("Please verify your email first.") + return unprocessableEntityError(ErrorCodeEmailNotConfirmed, "Please verify your email first.") } } else if phone != "" { if !user.IsPhoneConfirmed() { - return badRequestError("Please verify your phone first.") + return unprocessableEntityError(ErrorCodePhoneNotConfirmed, "Please verify your phone first.") } } @@ -47,7 +47,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { } else if phone != "" { smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Failed to get SMS provider").WithInternalError(terr) } mID, err := a.sendPhoneConfirmation(ctx, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) if err != nil { @@ -60,7 +60,12 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + reason := ErrorCodeOverEmailSendRate + if phone != "" { + reason = ErrorCodeOverSMSSendRate + } + + return tooManyRequestsError(reason, "For security purposes, you can only request this once every 60 seconds") } return err } @@ -77,7 +82,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { // verifyReauthentication checks if the nonce provided is valid func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, config *conf.GlobalConfiguration, user *models.User) error { if user.ReauthenticationToken == "" || user.ReauthenticationSentAt == nil { - return badRequestError(InvalidNonceMessage) + return unprocessableEntityError(ErrorCodeReauthenticationNotValid, InvalidNonceMessage) } var isValid bool if user.GetEmail() != "" { @@ -87,7 +92,7 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi if config.Sms.IsTwilioVerifyProvider() { smsProvider, _ := sms_provider.GetSmsProvider(*config) if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(string(user.Phone), nonce); err != nil { - return expiredTokenError("Token has expired or is invalid").WithInternalError(err) + return expiredTokenError("CHECK", "Token has expired or is invalid").WithInternalError(err) } return nil } else { @@ -95,10 +100,10 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi isValid = isOtpValid(tokenHash, user.ReauthenticationToken, user.ReauthenticationSentAt, config.Sms.OtpExp) } } else { - return unprocessableEntityError("Reauthentication requires an email or a phone number") + return unprocessableEntityError("CHECK", "Reauthentication requires an email or a phone number") } if !isValid { - return badRequestError(InvalidNonceMessage) + return badRequestError("CHECK", InvalidNonceMessage) } if err := user.ConfirmReauthentication(tx); err != nil { return internalServerError("Error during reauthentication").WithInternalError(err) diff --git a/internal/api/recover.go b/internal/api/recover.go index 9a57575650..195939a17a 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -19,7 +19,7 @@ type RecoverParams struct { func (p *RecoverParams) Validate() error { if p.Email == "" { - return unprocessableEntityError("Password recovery requires an email") + return badRequestError(ErrorCodeValidationFailed, "Password recovery requires an email") } var err error if p.Email, err = validateEmail(p.Email); err != nil { @@ -40,11 +40,11 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read verification params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read verification params: %v", err) } flowType := getFlowFromChallenge(params.CodeChallenge) @@ -83,7 +83,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRate, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Unable to process request").WithInternalError(err) } diff --git a/internal/api/resend.go b/internal/api/resend.go index a2fb4a52be..dfbaa1cc5b 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -26,22 +26,22 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er break default: // type does not match one of the above - return badRequestError("Missing one of these types: signup, email_change, sms, phone_change") + return badRequestError(ErrorCodeValidationFailed, "Missing one of these types: signup, email_change, sms, phone_change") } if p.Email == "" && p.Type == signupVerification { - return badRequestError("Type provided requires an email address") + return badRequestError(ErrorCodeValidationFailed, "Type provided requires an email address") } if p.Phone == "" && p.Type == smsVerification { - return badRequestError("Type provided requires a phone number") + return badRequestError(ErrorCodeValidationFailed, "Type provided requires a phone number") } var err error if p.Email != "" && p.Phone != "" { - return badRequestError("Only an email address or phone number should be provided.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided.") } else if p.Email != "" { if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return badRequestError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } p.Email, err = validateEmail(p.Email) if err != nil { @@ -49,7 +49,7 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er } } else if p.Phone != "" { if !config.External.Phone.Enabled { - return badRequestError("Phone logins are disabled") + return badRequestError(ErrorCodePhoneProviderDisabled, "Phone logins are disabled") } p.Phone, err = validatePhone(p.Phone) if err != nil { @@ -57,7 +57,7 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er } } else { // both email and phone are empty - return badRequestError("Missing email address or phone number") + return badRequestError(ErrorCodeValidationFailed, "Missing email address or phone number") } return nil } @@ -71,11 +71,11 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read params: %v", err) } if err := params.Validate(config); err != nil { @@ -162,8 +162,13 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { + reason := ErrorCodeOverEmailSendRate + if params.Type == smsVerification || params.Type == phoneChangeVerification { + reason = ErrorCodeOverSMSSendRate + } + until := time.Until(user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency)) / time.Second - return tooManyRequestsError("For security purposes, you can only request this once every %d seconds.", until) + return tooManyRequestsError(reason, "For security purposes, you can only request this once every %d seconds.", until) } return internalServerError("Unable to process request").WithInternalError(err) } diff --git a/internal/api/router.go b/internal/api/router.go index c2f06ae2e6..70b41f22db 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -63,7 +63,7 @@ func handler(fn apiHandler) http.HandlerFunc { func (h apiHandler) serve(w http.ResponseWriter, r *http.Request) { if err := h(w, r); err != nil { - handleError(err, w, r) + HandleResponseError(err, w, r) } } @@ -78,7 +78,7 @@ func (m middlewareHandler) handler(next http.Handler) http.Handler { func (m middlewareHandler) serve(next http.Handler, w http.ResponseWriter, r *http.Request) { ctx, err := m(w, r) if err != nil { - handleError(err, w, r) + HandleResponseError(err, w, r) return } if ctx != nil { diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index a3932249a3..1492665339 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -67,7 +67,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { relayState, err := models.FindSAMLRelayStateByID(db, relayStateUUID) if models.IsNotFoundError(err) { - return badRequestError("SAML RelayState does not exist, try logging in again?") + return notFoundError(ErrorCodeSAMLRelayStateNotFound, "SAML RelayState does not exist, try logging in again?") } else if err != nil { return err } @@ -77,7 +77,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { return internalServerError("SAML RelayState has expired and destroying it failed. Try logging in again?").WithInternalError(err) } - return badRequestError("SAML RelayState has expired. Try loggin in again?") + return unprocessableEntityError(ErrorCodeSAMLRelayStateExpired, "SAML RelayState has expired. Try loggin in again?") } // TODO: add abuse detection to bind the RelayState UUID with a @@ -107,23 +107,23 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { // SAML Artifact responses are possible only when // RelayState can be used to identify the Identity // Provider. - return badRequestError("SAML Artifact response can only be used with SP initiated flow") + return badRequestError(ErrorCodeValidationFailed, "SAML Artifact response can only be used with SP initiated flow") } samlResponse := r.FormValue("SAMLResponse") if samlResponse == "" { - return badRequestError("SAMLResponse is missing") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is missing") } responseXML, err := base64.StdEncoding.DecodeString(samlResponse) if err != nil { - return badRequestError("SAMLResponse is not a valid Base64 string") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid Base64 string") } var peekResponse saml.Response err = xml.Unmarshal(responseXML, &peekResponse) if err != nil { - return badRequestError("SAMLResponse is not a valid XML SAML assertion") + return badRequestError(ErrorCodeValidationFailed, "SAMLResponse is not a valid XML SAML assertion").WithInternalError(err) } initiatedBy = "idp" @@ -131,12 +131,12 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { redirectTo = relayStateValue } else { // RelayState can't be identified, so SAML flow can't continue - return badRequestError("SAML RelayState is not a valid UUID or URL") + return badRequestError(ErrorCodeValidationFailed, "SAML RelayState is not a valid UUID or URL") } ssoProvider, err := models.FindSAMLProviderByEntityID(db, entityId) if models.IsNotFoundError(err) { - return badRequestError("A SAML connection has not been established with this Identity Provider") + return notFoundError(ErrorCodeSAMLIdPNotFound, "A SAML connection has not been established with this Identity Provider") } else if err != nil { return err } @@ -176,10 +176,10 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { spAssertion, err := serviceProvider.ParseResponse(r, requestIds) if err != nil { if ire, ok := err.(*saml.InvalidResponseError); ok { - return badRequestError("SAML Assertion is not valid").WithInternalError(ire.PrivateErr) + return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(ire.PrivateErr) } - return badRequestError("SAML Assertion is not valid").WithInternalError(err) + return badRequestError(ErrorCodeValidationFailed, "SAML Assertion is not valid").WithInternalError(err) } assertion := SAMLAssertion{ @@ -188,7 +188,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { userID := assertion.UserID() if userID == "" { - return badRequestError("SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") + return badRequestError(ErrorCodeSAMLAssertionNoUserID, "SAML Assertion did not contain a persistent Subject Identifier attribute or Subject NameID uniquely identifying this user") } claims := assertion.Process(ssoProvider.SAMLProvider.AttributeMapping) @@ -200,7 +200,7 @@ func (a *API) SAMLACS(w http.ResponseWriter, r *http.Request) error { } if email == "" { - return badRequestError("SAML Assertion does not contain an email address") + return badRequestError(ErrorCodeSAMLAssertionNoEmail, "SAML Assertion does not contain an email address") } else { claims["email"] = email } diff --git a/internal/api/signup.go b/internal/api/signup.go index 8b84ffd30d..fc180b4ba0 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -35,21 +35,21 @@ func (a *API) validateSignupParams(ctx context.Context, p *SignupParams) error { config := a.config if p.Password == "" { - return unprocessableEntityError("Signup requires a valid password") + return badRequestError(ErrorCodeValidationFailed, "Signup requires a valid password") } if err := a.checkPasswordStrength(ctx, p.Password); err != nil { return err } if p.Email != "" && p.Phone != "" { - return unprocessableEntityError("Only an email address or phone number should be provided on signup.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on signup.") } if p.Provider == "phone" && !sms_provider.IsValidMessageChannel(p.Channel, config.Sms.Provider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } // PKCE not needed as phone signups already return access token in body if p.Phone != "" && p.CodeChallenge != "" { - return badRequestError("PKCE not supported for phone signups") + return badRequestError(ErrorCodeValidationFailed, "PKCE not supported for phone signups") } if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { return err @@ -113,18 +113,18 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { db := a.db.WithContext(ctx) if config.DisableSignup { - return forbiddenError("Signups not allowed for this instance") + return unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance") } params := &SignupParams{} body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read Signup params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read Signup params: %v", err).WithInternalError(err) } params.ConfigureDefaults() @@ -152,7 +152,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { switch params.Provider { case "email": if !config.External.Email.Enabled { - return badRequestError("Email signups are disabled") + return badRequestError(ErrorCodeEmailProviderDisabled, "Email signups are disabled") } params.Email, err = validateEmail(params.Email) if err != nil { @@ -161,7 +161,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { user, err = models.IsDuplicatedEmail(db, params.Email, params.Aud, nil) case "phone": if !config.External.Phone.Enabled { - return badRequestError("Phone signups are disabled") + return badRequestError(ErrorCodePhoneProviderDisabled, "Phone signups are disabled") } params.Phone, err = validatePhone(params.Phone) if err != nil { @@ -169,7 +169,18 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } user, err = models.FindUserByPhoneAndAudience(db, params.Phone, params.Aud) default: - return invalidSignupError(config) + msg := "" + if config.External.Email.Enabled && config.External.Phone.Enabled { + msg = "Sign up only available with email or phone provider" + } else if config.External.Email.Enabled { + msg = "Sign up only available with email provider" + } else if config.External.Phone.Enabled { + msg = "Sign up only available with phone provider" + } else { + msg = "Sign up with this provider not possible" + } + + return badRequestError(ErrorCodeValidationFailed, msg) } if err != nil && !models.IsNotFoundError(err) { @@ -240,7 +251,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { if errors.Is(terr, MaxFrequencyLimitError) { now := time.Now() left := user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency).Sub(now) / time.Second - return tooManyRequestsError(fmt.Sprintf("For security purposes, you can only request this after %d seconds.", left)) + return tooManyRequestsError(ErrorCodeOverEmailSendRate, fmt.Sprintf("For security purposes, you can only request this after %d seconds.", left)) } return internalServerError("Error sending confirmation mail").WithInternalError(terr) } @@ -267,10 +278,10 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { } smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending confirmation sms: %v", terr) + return internalServerError("Unable to get SMS provider").WithInternalError(terr) } if _, terr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { - return badRequestError("Error sending confirmation sms: %v", terr) + return unprocessableEntityError(ErrorCodeSMSSendFailed, "Error sending confirmation sms: %v", terr).WithInternalError(terr) } } } @@ -279,10 +290,14 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { }) if err != nil { - if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every minute") + reason := ErrorCodeOverEmailSendRate + if params.Provider == "phone" { + reason = ErrorCodeOverSMSSendRate } - if errors.Is(err, UserExistsError) { + + if errors.Is(err, MaxFrequencyLimitError) { + return tooManyRequestsError(reason, "For security purposes, you can only request this once every minute") + } else if errors.Is(err, UserExistsError) { err = db.Transaction(func(tx *storage.Connection) error { if terr := models.NewAuditLogEntry(r, tx, user, models.UserRepeatedSignUpAction, "", map[string]interface{}{ "provider": params.Provider, @@ -295,7 +310,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { return err } if config.Mailer.Autoconfirm || config.Sms.Autoconfirm { - return badRequestError("User already registered") + return unprocessableEntityError(ErrorCodeUserAlreadyExists, "User already registered") } sanitizedUser, err := sanitizeUser(user, params) if err != nil { diff --git a/internal/api/sso.go b/internal/api/sso.go index d93ff82dc0..07e97c6c8b 100644 --- a/internal/api/sso.go +++ b/internal/api/sso.go @@ -28,9 +28,9 @@ func (p *SingleSignOnParams) validate() (bool, error) { hasDomain := p.Domain != "" if hasProviderID && hasDomain { - return hasProviderID, badRequestError("Only one of provider_id or domain supported") + return hasProviderID, badRequestError(ErrorCodeValidationFailed, "Only one of provider_id or domain supported") } else if !hasProviderID && !hasDomain { - return hasProviderID, badRequestError("A provider_id or domain needs to be provided") + return hasProviderID, badRequestError(ErrorCodeValidationFailed, "A provider_id or domain needs to be provided") } return hasProviderID, nil @@ -49,7 +49,7 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { var params SingleSignOnParams if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse request body as JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "Unable to parse request body as JSON").WithInternalError(err) } hasProviderID := false @@ -71,10 +71,7 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - flowState, err := models.NewFlowState(models.SSOSAML.String(), codeChallenge, codeChallengeMethodType, models.SSOSAML) - if err != nil { - return err - } + flowState := models.NewFlowState(models.SSOSAML.String(), codeChallenge, codeChallengeMethodType, models.SSOSAML) if err := a.db.Create(flowState); err != nil { return err } @@ -86,14 +83,14 @@ func (a *API) SingleSignOn(w http.ResponseWriter, r *http.Request) error { if hasProviderID { ssoProvider, err = models.FindSSOProviderByID(db, params.ProviderID) if models.IsNotFoundError(err) { - return notFoundError("No such SSO provider") + return notFoundError(ErrorCodeSSOProviderNotFound, "No such SSO provider") } else if err != nil { return internalServerError("Unable to find SSO provider by ID").WithInternalError(err) } } else { ssoProvider, err = models.FindSSOProviderByDomain(db, params.Domain) if models.IsNotFoundError(err) { - return notFoundError("No SSO provider assigned for this domain") + return notFoundError(ErrorCodeSSOProviderNotFound, "No SSO provider assigned for this domain") } else if err != nil { return internalServerError("Unable to find SSO provider by domain").WithInternalError(err) } diff --git a/internal/api/ssoadmin.go b/internal/api/ssoadmin.go index 4fdecc0f81..4f2d38a44f 100644 --- a/internal/api/ssoadmin.go +++ b/internal/api/ssoadmin.go @@ -29,14 +29,14 @@ func (a *API) loadSSOProvider(w http.ResponseWriter, r *http.Request) (context.C idpID, err := uuid.FromString(idpParam) if err != nil { // idpParam is not UUIDv4 - return nil, notFoundError("SSO Identity Provider not found") + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") } // idpParam is a UUIDv4 provider, err := models.FindSSOProviderByID(db, idpID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("SSO Identity Provider not found") + return nil, notFoundError(ErrorCodeSSOProviderNotFound, "SSO Identity Provider not found") } else { return nil, internalServerError("Database error finding SSO Identity Provider").WithInternalError(err) } @@ -79,19 +79,19 @@ type CreateSSOProviderParams struct { func (p *CreateSSOProviderParams) validate(forUpdate bool) error { if !forUpdate && p.Type != "saml" { - return badRequestError("Only 'saml' supported for SSO provider type") + return badRequestError(ErrorCodeValidationFailed, "Only 'saml' supported for SSO provider type") } else if p.MetadataURL != "" && p.MetadataXML != "" { - return badRequestError("Only one of metadata_xml or metadata_url needs to be set") + return badRequestError(ErrorCodeValidationFailed, "Only one of metadata_xml or metadata_url needs to be set") } else if !forUpdate && p.MetadataURL == "" && p.MetadataXML == "" { - return badRequestError("Either metadata_xml or metadata_url must be set") + return badRequestError(ErrorCodeValidationFailed, "Either metadata_xml or metadata_url must be set") } else if p.MetadataURL != "" { metadataURL, err := url.ParseRequestURI(p.MetadataURL) if err != nil { - return badRequestError("metadata_url is not a valid URL") + return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a valid URL") } if metadataURL.Scheme != "https" { - return badRequestError("metadata_url is not a HTTPS URL") + return badRequestError(ErrorCodeValidationFailed, "metadata_url is not a HTTPS URL") } } @@ -127,7 +127,7 @@ func (p *CreateSSOProviderParams) metadata(ctx context.Context) ([]byte, *saml.E func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { if !utf8.Valid(rawMetadata) { - return nil, badRequestError("SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata XML contains invalid UTF-8 characters, which are not supported at this time") } metadata, err := samlsp.ParseMetadata(rawMetadata) @@ -136,15 +136,15 @@ func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { } if metadata.EntityID == "" { - return nil, badRequestError("SAML Metadata does not contain an EntityID") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain an EntityID") } if len(metadata.IDPSSODescriptors) < 1 { - return nil, badRequestError("SAML Metadata does not contain any IDPSSODescriptor") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata does not contain any IDPSSODescriptor") } if len(metadata.IDPSSODescriptors) > 1 { - return nil, badRequestError("SAML Metadata contains multiple IDPSSODescriptors") + return nil, badRequestError(ErrorCodeValidationFailed, "SAML Metadata contains multiple IDPSSODescriptors") } return metadata, nil @@ -153,7 +153,7 @@ func parseSAMLMetadata(rawMetadata []byte) (*saml.EntityDescriptor, error) { func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { - return nil, badRequestError("Unable to create a request to metadata_url").WithInternalError(err) + return nil, internalServerError("Unable to create a request to metadata_url").WithInternalError(err) } req = req.WithContext(ctx) @@ -168,7 +168,7 @@ func fetchSAMLMetadata(ctx context.Context, url string) ([]byte, error) { defer utilities.SafeClose(resp.Body) if resp.StatusCode != http.StatusOK { - return nil, badRequestError("HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) + return nil, badRequestError(ErrorCodeSAMLMetadataFetchFailed, "HTTP %v error fetching SAML Metadata from URL '%s'", resp.StatusCode, url) } data, err := io.ReadAll(resp.Body) @@ -191,7 +191,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er var params CreateSSOProviderParams if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "Unable to parse JSON").WithInternalError(err) } if err := params.validate(false /* <- forUpdate */); err != nil { @@ -208,7 +208,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er return err } if existingProvider != nil { - return badRequestError("SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) + return unprocessableEntityError(ErrorCodeSAMLIdPAlreadyExists, "SAML Identity Provider with this EntityID (%s) already exists", metadata.EntityID) } provider := &models.SSOProvider{ @@ -231,7 +231,7 @@ func (a *API) adminSSOProvidersCreate(w http.ResponseWriter, r *http.Request) er return err } if existingProvider != nil { - return badRequestError("SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) + return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO Domain '%s' is already assigned to an SSO identity provider (%s)", domain, existingProvider.ID.String()) } provider.SSODomains = append(provider.SSODomains, models.SSODomain{ @@ -271,7 +271,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er var params CreateSSOProviderParams if err := json.Unmarshal(body, ¶ms); err != nil { - return badRequestError("Unable to parse JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "Unable to parse JSON").WithInternalError(err) } if err := params.validate(true /* <- forUpdate */); err != nil { @@ -291,7 +291,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er } if provider.SAMLProvider.EntityID != metadata.EntityID { - return badRequestError("SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) + return badRequestError(ErrorCodeSAMLEntityIDMismatch, "SAML Metadata can be updated only if the EntityID matches for the provider; expected '%s' but got '%s'", provider.SAMLProvider.EntityID, metadata.EntityID) } if params.MetadataURL != "" { @@ -320,7 +320,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er if existingProvider.ID == provider.ID { keepDomains[domain] = true } else { - return badRequestError("SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) + return badRequestError(ErrorCodeSSODomainAlreadyExists, "SSO domain '%s' already assigned to another provider (%s)", domain, existingProvider.ID.String()) } } else { modified = true @@ -370,7 +370,7 @@ func (a *API) adminSSOProvidersUpdate(w http.ResponseWriter, r *http.Request) er return tx.Eager().Load(provider) }); err != nil { - return unprocessableEntityError("Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) + return unprocessableEntityError(ErrorCodeConflict, "Updating SSO provider failed, likely due to a conflict. Try again?").WithInternalError(err) } } diff --git a/internal/api/token.go b/internal/api/token.go index 75d2e50878..8c0c9ca48d 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -103,18 +103,18 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read password grant params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read password grant params: %v", err) } aud := a.requestAud(ctx, r) config := a.config if params.Email != "" && params.Phone != "" { - return unprocessableEntityError("Only an email address or phone number should be provided on login.") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on login.") } var user *models.User var grantParams models.GrantParams @@ -125,13 +125,13 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri if params.Email != "" { provider = "email" if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return unprocessableEntityError(ErrorCodeEmailProviderDisabled, "Email logins are disabled") } user, err = models.FindUserByEmailAndAudience(db, params.Email, aud) } else if params.Phone != "" { provider = "phone" if !config.External.Phone.Enabled { - return badRequestError("Phone logins are disabled") + return unprocessableEntityError(ErrorCodePhoneProviderDisabled, "Phone logins are disabled") } params.Phone = formatPhoneNumber(params.Phone) user, err = models.FindUserByPhoneAndAudience(db, params.Phone, aud) @@ -183,7 +183,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri return err } } - return forbiddenError(output.Message) + return oauthError("invalid_grant", InvalidLoginMessage) } } if !isValidPassword { @@ -244,22 +244,22 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) } if err = json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + return badRequestError(ErrorCodeBadJSON, "invalid body: unable to parse JSON").WithInternalError(err) } if params.AuthCode == "" || params.CodeVerifier == "" { - return badRequestError("invalid request: both auth code and code verifier should be non-empty") + return badRequestError(ErrorCodeValidationFailed, "invalid request: both auth code and code verifier should be non-empty") } flowState, err := models.FindFlowStateByAuthCode(db, params.AuthCode) // Sanity check in case user ID was not set properly if models.IsNotFoundError(err) || flowState.UserID == nil { - return forbiddenError("invalid flow state, no valid flow state found") + return notFoundError(ErrorCodeFlowStateNotFound, "invalid flow state, no valid flow state found") } else if err != nil { return err } if flowState.IsExpired(a.config.External.FlowStateExpiryDuration) { - return forbiddenError("invalid flow state, flow state has expired") + return unprocessableEntityError(ErrorCodeFlowStateExpired, "invalid flow state, flow state has expired") } user, err := models.FindUserByID(db, *flowState.UserID) @@ -267,7 +267,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) return err } if err := flowState.VerifyPKCE(params.CodeVerifier); err != nil { - return forbiddenError(err.Error()) + return forbiddenError("CHECK", err.Error()) } var token *AccessTokenResponse @@ -433,17 +433,9 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, var tokenString string var expiresAt int64 var refreshToken *models.RefreshToken - currentClaims := getClaims(ctx) - sessionId, err := uuid.FromString(currentClaims.SessionId) - if err != nil { - return nil, internalServerError("Cannot read SessionId claim as UUID").WithInternalError(err) - } - err = tx.Transaction(func(tx *storage.Connection) error { - if terr := models.AddClaimToSession(tx, sessionId, authenticationMethod); terr != nil { - return terr - } - session, terr := models.FindSessionByID(tx, sessionId, false) - if terr != nil { + session := getSession(ctx) + err := tx.Transaction(func(tx *storage.Connection) error { + if terr := models.AddClaimToSession(tx, session.ID, authenticationMethod); terr != nil { return terr } currentToken, terr := models.FindTokenBySessionID(tx, &session.ID) @@ -467,7 +459,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, return err } - tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &sessionId, models.TOTPSignIn) + tokenString, expiresAt, terr = a.generateAccessToken(ctx, tx, user, &session.ID, models.TOTPSignIn) if terr != nil { httpErr, ok := terr.(*HTTPError) if ok { diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index c380856c33..7589f94762 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -55,7 +55,7 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa if issuer == "" || !provider.IsAzureIssuer(issuer) { detectedIssuer, err := provider.DetectAzureIDTokenIssuer(ctx, p.IdToken) if err != nil { - return nil, nil, "", nil, badRequestError("Unable to detect issuer in ID token for Azure provider").WithInternalError(err) + return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, "Unable to detect issuer in ID token for Azure provider").WithInternalError(err) } issuer = detectedIssuer } @@ -90,12 +90,12 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa } if !allowed { - return nil, nil, "", nil, badRequestError(fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) + return nil, nil, "", nil, badRequestError(ErrorCodeValidationFailed, fmt.Sprintf("Custom OIDC provider %q not allowed", p.Provider)) } } if cfg != nil && !cfg.Enabled { - return nil, nil, "", nil, badRequestError(fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) + return nil, nil, "", nil, badRequestError(ErrorCodeProviderDisabled, fmt.Sprintf("Provider (issuer %q) is not enabled", issuer)) } oidcProvider, err := oidc.NewProvider(ctx, issuer) @@ -117,11 +117,11 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read id token grant params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read id token grant params: %v", err) } if params.IdToken == "" { diff --git a/internal/api/token_refresh.go b/internal/api/token_refresh.go index 1cd665346f..a610f7f160 100644 --- a/internal/api/token_refresh.go +++ b/internal/api/token_refresh.go @@ -29,11 +29,11 @@ func (a *API) RefreshTokenGrant(ctx context.Context, w http.ResponseWriter, r *h body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read refresh token grant params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read refresh token grant params: %v", err) } if params.RefreshToken == "" { diff --git a/internal/api/token_test.go b/internal/api/token_test.go index 6a8acf6d3b..04bfbdd4dc 100644 --- a/internal/api/token_test.go +++ b/internal/api/token_test.go @@ -306,8 +306,7 @@ func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() { invalidVerifier := codeVerifier + "123" codeChallenge := sha256.Sum256([]byte(codeVerifier)) challenge := base64.RawURLEncoding.EncodeToString(codeChallenge[:]) - flowState, err := models.NewFlowState("github", challenge, models.SHA256, models.OAuth) - require.NoError(ts.T(), err) + flowState := models.NewFlowState("github", challenge, models.SHA256, models.OAuth) flowState.AuthCode = authCode require.NoError(ts.T(), ts.API.db.Create(flowState)) cases := []struct { @@ -344,7 +343,7 @@ func (ts *TokenTestSuite) TestTokenPKCEGrantFailure() { req.Header.Set("Content-Type", "application/json") w := httptest.NewRecorder() ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusForbidden, w.Code) + assert.Equal(ts.T(), http.StatusNotFound, w.Code) }) } } @@ -619,7 +618,7 @@ func (ts *TokenTestSuite) TestPasswordVerificationHook() { begin return json_build_object('decision', 'reject'); end; $$ language plpgsql;`, - expectedCode: http.StatusForbidden, + expectedCode: http.StatusBadRequest, }, } for _, c := range cases { diff --git a/internal/api/user.go b/internal/api/user.go index 02fd099b21..dfe2c9d0d9 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -48,7 +48,7 @@ func (a *API) validateUserUpdateParams(ctx context.Context, p *UserUpdateParams) p.Channel = sms_provider.SMSProvider } if !sms_provider.IsValidMessageChannel(p.Channel, config.Sms.Provider) { - return badRequestError(InvalidChannelError) + return badRequestError(ErrorCodeValidationFailed, InvalidChannelError) } } @@ -66,12 +66,12 @@ func (a *API) UserGet(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() claims := getClaims(ctx) if claims == nil { - return badRequestError("Could not read claims") + return badRequestError("CHECK", "Could not read claims") } aud := a.requestAud(ctx, r) if aud != claims.Audience { - return badRequestError("Token audience doesn't match request audience") + return badRequestError(ErrorCodeValidationFailed, "Token audience doesn't match request audience") } user := getUser(ctx) @@ -89,11 +89,11 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read User Update params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not read User Update params: %v", err) } user := getUser(ctx) @@ -105,7 +105,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if params.AppData != nil && !isAdmin(user, config) { if !isAdmin(user, config) { - return unauthorizedError("Updating app_metadata requires admin privileges") + return unauthorizedError(ErrorCodeNotAdmin, "Updating app_metadata requires admin privileges") } } @@ -118,7 +118,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { updatingForbiddenFields = updatingForbiddenFields || (params.Nonce != "") if updatingForbiddenFields { - return unprocessableEntityError("Updating email, phone, password of a SSO account only possible via SSO") + return unprocessableEntityError(ErrorCodeUserSSOManaged, "Updating email, phone, password of a SSO account only possible via SSO") } } @@ -126,7 +126,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if duplicateUser, err := models.IsDuplicatedEmail(db, params.Email, aud, user); err != nil { return internalServerError("Database error checking email").WithInternalError(err) } else if duplicateUser != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg) } } @@ -134,7 +134,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { return internalServerError("Database error checking phone").WithInternalError(err) } else if exists { - return unprocessableEntityError(DuplicatePhoneMsg) + return unprocessableEntityError(ErrorCodePhoneExists, DuplicatePhoneMsg) } } @@ -144,7 +144,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { // we require reauthentication if the user hasn't signed in recently in the current session if session == nil || now.After(session.CreatedAt.Add(24*time.Hour)) { if len(params.Nonce) == 0 { - return badRequestError("Password update requires reauthentication") + return badRequestError(ErrorCodeReauthenticationNeeded, "Password update requires reauthentication") } if err := a.verifyReauthentication(params.Nonce, db, config, user); err != nil { return err @@ -155,7 +155,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { password := *params.Password if password != "" { if user.EncryptedPassword != "" && user.Authenticate(ctx, password) { - return unprocessableEntityError("New password should be different from the old password.") + return unprocessableEntityError(ErrorCodeSamePassword, "New password should be different from the old password.") } } @@ -231,7 +231,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { externalURL := getExternalHost(ctx) if terr = a.sendEmailChange(tx, config, user, mailer, params.Email, referrer, externalURL, config.Mailer.OtpLength, flowType); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError(ErrorCodeOverEmailSendRate, "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Error sending change email").WithInternalError(terr) } @@ -264,7 +264,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } else { smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Error finding SMS provider").WithInternalError(terr) } if _, terr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { return internalServerError("Error sending phone change otp").WithInternalError(terr) diff --git a/internal/api/user_test.go b/internal/api/user_test.go index c449cb02c9..2496f15f41 100644 --- a/internal/api/user_test.go +++ b/internal/api/user_test.go @@ -274,7 +274,7 @@ func (ts *UserTestSuite) TestUserUpdatePassword() { nonce: "123456", requireReauthentication: true, sessionId: nil, - expected: expected{code: http.StatusBadRequest, isAuthenticated: false}, + expected: expected{code: http.StatusUnprocessableEntity, isAuthenticated: false}, }, { desc: "Valid password length", diff --git a/internal/api/verify.go b/internal/api/verify.go index c338064097..e326fb1171 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -52,18 +52,18 @@ type VerifyParams struct { func (p *VerifyParams) Validate(r *http.Request) error { var err error if p.Type == "" { - return badRequestError("Verify requires a verification type") + return badRequestError(ErrorCodeValidationFailed, "Verify requires a verification type") } switch r.Method { case http.MethodGet: if p.Token == "" { - return badRequestError("Verify requires a token or a token hash") + return badRequestError(ErrorCodeValidationFailed, "Verify requires a token or a token hash") } // TODO: deprecate the token query param from GET /verify and use token_hash instead (breaking change) p.TokenHash = p.Token case http.MethodPost: if (p.Token == "" && p.TokenHash == "") || (p.Token != "" && p.TokenHash != "") { - return badRequestError("Verify requires either a token or a token hash") + return badRequestError(ErrorCodeValidationFailed, "Verify requires either a token or a token hash") } if p.Token != "" { if isPhoneOtpVerification(p) { @@ -75,15 +75,15 @@ func (p *VerifyParams) Validate(r *http.Request) error { } else if isEmailOtpVerification(p) { p.Email, err = validateEmail(p.Email) if err != nil { - return unprocessableEntityError("Invalid email format").WithInternalError(err) + return unprocessableEntityError(ErrorCodeValidationFailed, "Invalid email format").WithInternalError(err) } p.TokenHash = crypto.GenerateTokenHash(p.Email, p.Token) } else { - return badRequestError("Only an email address or phone number should be provided on verify") + return badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify") } } else if p.TokenHash != "" { if p.Email != "" || p.Phone != "" || p.RedirectTo != "" { - return badRequestError("Only the token_hash and type should be provided") + return badRequestError(ErrorCodeValidationFailed, "Only the token_hash and type should be provided") } } default: @@ -107,17 +107,18 @@ func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { case http.MethodPost: body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not parse verification params: %v", err) + return badRequestError(ErrorCodeBadJSON, "Could not parse verification params: %v", err) } if err := params.Validate(r); err != nil { return err } return a.verifyPost(w, r, params) default: - return unprocessableEntityError("Only GET and POST methods are supported.") + // this should have been handled by Chi + panic("Only GET and POST methods allowed") } } @@ -169,7 +170,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa return nil } default: - return unprocessableEntityError("Unsupported verification type") + return badRequestError(ErrorCodeValidationFailed, "Unsupported verification type") } if terr != nil { @@ -187,7 +188,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa } } else if isPKCEFlow(flowType) { if authCode, terr = issueAuthCode(tx, user, a.config.External.FlowStateExpiryDuration, authenticationMethod); terr != nil { - return badRequestError("No associated flow state found. %s", terr) + return badRequestError(ErrorCodeFlowStateNotFound, "No associated flow state found. %s", terr) } } return nil @@ -260,7 +261,7 @@ func (a *API) verifyPost(w http.ResponseWriter, r *http.Request, params *VerifyP case smsVerification, phoneChangeVerification: user, terr = a.smsVerify(r, ctx, tx, user, params.Type) default: - return unprocessableEntityError("Unsupported verification type") + return badRequestError(ErrorCodeValidationFailed, "Unsupported verification type") } if terr != nil { @@ -296,7 +297,8 @@ func (a *API) signupVerify(r *http.Request, ctx context.Context, conn *storage.C // to present the user with a password set form password, err := password.Generate(64, 10, 0, false, true) if err != nil { - return nil, err + // password generation must succeed + panic(err) } if err := user.SetPassword(ctx, password); err != nil { @@ -410,14 +412,14 @@ func (a *API) prepErrorRedirectURL(err *HTTPError, w http.ResponseWriter, r *htt errorID := getRequestID(r.Context()) err.ErrorID = errorID log.WithError(err.Cause()).Info(err.Error()) - if str, ok := oauthErrorMap[err.Code]; ok { + if str, ok := oauthErrorMap[err.HTTPStatus]; ok { hq.Set("error", str) q.Set("error", str) } - hq.Set("error_code", strconv.Itoa(err.Code)) + hq.Set("error_code", strconv.Itoa(err.HTTPStatus)) hq.Set("error_description", err.Message) - q.Set("error_code", strconv.Itoa(err.Code)) + q.Set("error_code", strconv.Itoa(err.HTTPStatus)) q.Set("error_description", err.Message) if flowType == models.PKCEFlow { // Additionally, may override existing error query param if set to PKCE. @@ -517,18 +519,18 @@ func (a *API) verifyTokenHash(ctx context.Context, conn *storage.Connection, par case emailChangeVerification: user, err = models.FindUserByEmailChangeToken(conn, params.TokenHash) default: - return nil, badRequestError("Invalid email verification type") + return nil, badRequestError(ErrorCodeValidationFailed, "Invalid email verification type") } if err != nil { if models.IsNotFoundError(err) { - return nil, expiredTokenError("Email link is invalid or has expired").WithInternalError(err) + return nil, expiredTokenError(ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalError(err) } return nil, internalServerError("Database error finding user from email link").WithInternalError(err) } if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalMessage("user is banned") + return nil, unauthorizedError(ErrorCodeUserBanned, "User is banned").WithInternalMessage("user is banned") } var isExpired bool @@ -550,7 +552,7 @@ func (a *API) verifyTokenHash(ctx context.Context, conn *storage.Connection, par } if isExpired { - return nil, expiredTokenError("Email link is invalid or has expired").WithInternalMessage("email link has expired") + return nil, expiredTokenError(ErrorCodeOTPExpired, "Email link is invalid or has expired").WithInternalMessage("email link has expired") } return user, nil @@ -579,13 +581,13 @@ func (a *API) verifyUserAndToken(ctx context.Context, conn *storage.Connection, if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()).WithInternalError(err) + return nil, notFoundError(ErrorCodeUserNotFound, err.Error()).WithInternalError(err) } return nil, internalServerError("Database error finding user").WithInternalError(err) } if user.IsBanned() { - return nil, unauthorizedError("Error confirming user").WithInternalMessage("user is banned") + return nil, unauthorizedError(ErrorCodeUserBanned, "User is banned") } var isValid bool @@ -626,7 +628,7 @@ func (a *API) verifyUserAndToken(ctx context.Context, conn *storage.Connection, } } if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(phone, params.Token); err != nil { - return nil, expiredTokenError("Token has expired or is invalid").WithInternalError(err) + return nil, expiredTokenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalError(err) } return user, nil } @@ -634,7 +636,7 @@ func (a *API) verifyUserAndToken(ctx context.Context, conn *storage.Connection, } if !isValid { - return nil, expiredTokenError("Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") + return nil, expiredTokenError(ErrorCodeOTPExpired, "Token has expired or is invalid").WithInternalMessage("token has expired or is invalid") } return user, nil } diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index 1cdb43ba9e..1f04255e5d 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -313,7 +313,7 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { expected ResponseBody }{ { - desc: "Expired Sms OTP", + desc: "Expired SMS OTP", sentTime: time.Now().Add(-48 * time.Hour), body: map[string]interface{}{ "type": smsVerification, @@ -323,7 +323,7 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { expected: expectedResponse, }, { - desc: "Invalid Sms OTP", + desc: "Invalid SMS OTP", sentTime: time.Now(), body: map[string]interface{}{ "type": smsVerification, @@ -1103,7 +1103,7 @@ func (ts *VerifyTestSuite) TestPrepErrorRedirectURL() { ts.Run(c.desc, func() { w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) - rurl, err := ts.API.prepErrorRedirectURL(badRequestError(DefaultError), w, req, c.rurl, c.flowType) + rurl, err := ts.API.prepErrorRedirectURL(badRequestError(ErrorCodeValidationFailed, DefaultError), w, req, c.rurl, c.flowType) require.NoError(ts.T(), err) require.Equal(ts.T(), c.expected, rurl) }) @@ -1153,7 +1153,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { Token: "some-token", }, method: http.MethodPost, - expected: badRequestError("Only an email address or phone number should be provided on verify"), + expected: badRequestError(ErrorCodeValidationFailed, "Only an email address or phone number should be provided on verify"), }, { desc: "Cannot send both TokenHash and Token", @@ -1163,7 +1163,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { TokenHash: "some-token-hash", }, method: http.MethodPost, - expected: badRequestError("Verify requires either a token or a token hash"), + expected: badRequestError(ErrorCodeValidationFailed, "Verify requires either a token or a token hash"), }, { desc: "No verification type specified", @@ -1172,7 +1172,7 @@ func (ts *VerifyTestSuite) TestVerifyValidateParams() { Email: "email@example.com", }, method: http.MethodPost, - expected: badRequestError("Verify requires a verification type"), + expected: badRequestError(ErrorCodeValidationFailed, "Verify requires a verification type"), }, } diff --git a/internal/models/flow_state.go b/internal/models/flow_state.go index 6aced0b59e..d18a5bd32b 100644 --- a/internal/models/flow_state.go +++ b/internal/models/flow_state.go @@ -81,7 +81,7 @@ func (FlowState) TableName() string { return tableName } -func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod) (*FlowState, error) { +func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod) *FlowState { id := uuid.Must(uuid.NewV4()) authCode := uuid.Must(uuid.NewV4()) flowState := &FlowState{ @@ -92,7 +92,7 @@ func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeCh AuthCode: authCode.String(), AuthenticationMethod: authenticationMethod.String(), } - return flowState, nil + return flowState } func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) error {