Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: refactor mfa tests #1322

Merged
merged 5 commits into from
Nov 28, 2023
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 88 additions & 74 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/gofrs/uuid"
"net/http"
"net/http/httptest"
"strings"
Expand All @@ -25,11 +26,12 @@ import (

type MFATestSuite struct {
suite.Suite
API *API
Config *conf.GlobalConfiguration
TestDomain string
TestEmail string
TestOTPKey *otp.Key
API *API
Config *conf.GlobalConfiguration
TestDomain string
TestEmail string
TestOTPKey *otp.Key
TestPassword string
}

func TestMFA(t *testing.T) {
Expand All @@ -53,7 +55,7 @@ func (ts *MFATestSuite) SetupTest() {
f, err := models.NewFactor(u, "test_factor", models.TOTP, models.FactorStateUnverified, "secretkey")
require.NoError(ts.T(), err, "Error creating test factor model")
require.NoError(ts.T(), ts.API.db.Create(f), "Error saving new test factor")
// Create corresponding sessoin
// Create corresponding session
s, err := models.NewSession()
require.NoError(ts.T(), err, "Error creating test session")
s.UserID = u.ID
Expand All @@ -67,6 +69,7 @@ func (ts *MFATestSuite) SetupTest() {
testDomain := strings.Split(testEmail, "@")[1]
ts.TestDomain = testDomain
ts.TestEmail = testEmail
ts.TestPassword = "password"

key, err := totp.Generate(totp.GenerateOpts{
Issuer: ts.TestDomain,
Expand All @@ -80,6 +83,12 @@ func (ts *MFATestSuite) SetupTest() {
func (ts *MFATestSuite) TestEnrollFactor() {
testFriendlyName := "bob"
alternativeFriendlyName := "john"
user, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
ts.Require().NoError(err)

token, _, err := generateAccessToken(ts.API.db, user, nil, &ts.Config.JWT)

require.NoError(ts.T(), err)
var cases = []struct {
desc string
friendlyName string
Expand Down Expand Up @@ -119,20 +128,8 @@ func (ts *MFATestSuite) TestEnrollFactor() {
}
for _, c := range cases {
ts.Run(c.desc, func() {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"friendly_name": c.friendlyName, "factor_type": c.factorType, "issuer": c.issuer}))
J0 marked this conversation as resolved.
Show resolved Hide resolved
user, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
ts.Require().NoError(err)

token, _, err := generateAccessToken(ts.API.db, user, nil, &ts.Config.JWT)
require.NoError(ts.T(), err)

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/factors", &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), c.expectedCode, w.Code)
w := enroll(ts, token, c.friendlyName, c.factorType, c.issuer, c.expectedCode)

factors, err := models.FindFactorsByUser(ts.API.db, user)
ts.Require().NoError(err)
Expand Down Expand Up @@ -176,6 +173,8 @@ func (ts *MFATestSuite) TestChallengeFactor() {
}

func (ts *MFATestSuite) TestMFAVerifyFactor() {
user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.TestEmail, ts.Config.JWT.Aud)
J0 marked this conversation as resolved.
Show resolved Hide resolved
ts.Require().NoError(err)
cases := []struct {
desc string
validChallenge bool
Expand Down Expand Up @@ -204,8 +203,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
for _, v := range cases {
ts.Run(v.desc, func() {
// Authenticate users and set secret
user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.TestEmail, ts.Config.JWT.Aud)
ts.Require().NoError(err)

var buffer bytes.Buffer
r, err := models.GrantAuthenticatedUser(ts.API.db, user, models.GrantParams{})
require.NoError(ts.T(), err)
Expand Down Expand Up @@ -271,7 +269,18 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
}
}

func (ts *MFATestSuite) setupUserAndSession() (*models.User, *models.Session) {
user, err := models.FindUserByEmailAndAudience(ts.API.db, ts.TestEmail, ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
session, err := models.FindSessionByUserID(ts.API.db, user.ID)
require.NoError(ts.T(), err)
return user, session
}

func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {

user, session := ts.setupUserAndSession()
J0 marked this conversation as resolved.
Show resolved Hide resolved

cases := []struct {
desc string
isAAL2 bool
Expand All @@ -289,25 +298,20 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
},
}
for _, v := range cases {

ts.Run(v.desc, func() {
// Create User
u, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
s, err := models.FindSessionByUserID(ts.API.db, u.ID)
require.NoError(ts.T(), err)
if v.isAAL2 {
s.UpdateAssociatedAAL(ts.API.db, models.AAL2.String())
session.UpdateAssociatedAAL(ts.API.db, models.AAL2.String())
}
var secondarySession *models.Session

// Create Session to test behaviour which downgrades other sessions
factors, err := models.FindFactorsByUser(ts.API.db, u)
factors, err := models.FindFactorsByUser(ts.API.db, user)
require.NoError(ts.T(), err, "error finding factors")
f := factors[0]
secondarySession, err = models.NewSession()
require.NoError(ts.T(), err, "Error creating test session")
secondarySession.UserID = u.ID
secondarySession.UserID = user.ID
secondarySession.FactorID = &f.ID
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

Expand All @@ -319,7 +323,7 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {

var buffer bytes.Buffer

token, _, err := generateAccessToken(ts.API.db, u, &s.ID, &ts.Config.JWT)
token, _, err := generateAccessToken(ts.API.db, user, &session.ID, &ts.Config.JWT)
require.NoError(ts.T(), err)

w := httptest.NewRecorder()
Expand All @@ -342,17 +346,15 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
}

func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {
u, err := models.FindUserByEmailAndAudience(ts.API.db, "[email protected]", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
s, err := models.FindSessionByUserID(ts.API.db, u.ID)
require.NoError(ts.T(), err)
user, session := ts.setupUserAndSession()

var secondarySession *models.Session
factors, err := models.FindFactorsByUser(ts.API.db, u)
factors, err := models.FindFactorsByUser(ts.API.db, user)
require.NoError(ts.T(), err, "error finding factors")
f := factors[0]
secondarySession, err = models.NewSession()
require.NoError(ts.T(), err, "Error creating test session")
secondarySession.UserID = u.ID
secondarySession.UserID = user.ID
secondarySession.FactorID = &f.ID
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

Expand All @@ -361,7 +363,7 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {

var buffer bytes.Buffer

token, _, err := generateAccessToken(ts.API.db, u, &s.ID, &ts.Config.JWT)
token, _, err := generateAccessToken(ts.API.db, user, &session.ID, &ts.Config.JWT)
require.NoError(ts.T(), err)
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"factor_id": f.ID,
Expand All @@ -374,22 +376,24 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {
require.Equal(ts.T(), http.StatusOK, w.Code)
_, err = models.FindFactorByFactorID(ts.API.db, f.ID)
require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error())
session, _ := models.FindSessionByID(ts.API.db, secondarySession.ID, false)
session, _ = models.FindSessionByID(ts.API.db, secondarySession.ID, false)
require.Equal(ts.T(), models.AAL1.String(), session.GetAAL())
require.Nil(ts.T(), session.FactorID)

}

// Integration Tests
func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {
email := "[email protected]"
password := "test123"
token := signUpAndVerify(ts, email, password)
resp := signUpAndVerify(ts, ts.TestEmail, ts.TestPassword)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

ts.Config.Security.RefreshTokenRotationEnabled = true
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": token.RefreshToken,
"refresh_token": accessTokenResp.RefreshToken,
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=refresh_token", &buffer)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
Expand All @@ -408,14 +412,15 @@ func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {

// Performing MFA Verification followed by a sign in should return an AAL1 session and an AAL2 session
func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() {
email := "[email protected]"
password := "test123"
token := signUpAndVerify(ts, email, password)
resp := signUpAndVerify(ts, ts.TestEmail, ts.TestPassword)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

ts.Config.Security.RefreshTokenRotationEnabled = true
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"email": email,
"password": password,
"email": ts.TestEmail,
"password": ts.TestPassword,
}))
req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
req.Header.Set("Content-Type", "application/json")
Expand All @@ -430,7 +435,7 @@ func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() {
ctx, err = ts.API.maybeLoadUserOrSession(ctx)
require.NoError(ts.T(), err)
require.Equal(ts.T(), models.AAL1.String(), getSession(ctx).GetAAL())
session, err := models.FindSessionByUserID(ts.API.db, token.User.ID)
session, err := models.FindSessionByUserID(ts.API.db, accessTokenResp.User.ID)
require.NoError(ts.T(), err)
require.True(ts.T(), session.IsAAL2())
}
Expand All @@ -455,43 +460,30 @@ func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenRes
return data
}

func signUpAndVerify(ts *MFATestSuite, email, password string) (verifyResp *AccessTokenResponse) {
func signUpAndVerify(ts *MFATestSuite, email, password string) *httptest.ResponseRecorder {
J0 marked this conversation as resolved.
Show resolved Hide resolved

signUpResp := signUp(ts, email, password)
verifyResp = enrollAndVerify(ts, signUpResp.User, signUpResp.Token)
resp := enrollAndVerify(ts, signUpResp.User, signUpResp.Token)

return verifyResp
return resp

}

func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) (verifyResp *AccessTokenResponse) {
func enroll(ts *MFATestSuite, token, friendlyName, factorType, issuer string, expectedCode int) *httptest.ResponseRecorder {
J0 marked this conversation as resolved.
Show resolved Hide resolved
var buffer bytes.Buffer
w := httptest.NewRecorder()
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"friendly_name": "john", "factor_type": models.TOTP, "issuer": ts.TestDomain}))
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"friendly_name": friendlyName, "factor_type": factorType, "issuer": issuer}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/factors/", &buffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")

ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
enrollResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp))
factorID := enrollResp.ID
require.Equal(ts.T(), expectedCode, w.Code)
return w

// Challenge
var challengeBuffer bytes.Buffer
x := httptest.NewRecorder()
req = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), &challengeBuffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")
ts.API.handler.ServeHTTP(x, req)
require.Equal(ts.T(), http.StatusOK, x.Code)
challengeResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(x.Body).Decode(&challengeResp))
challengeID := challengeResp.ID

// Verify
}
func verify(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, expectedCode int) *httptest.ResponseRecorder {
J0 marked this conversation as resolved.
Show resolved Hide resolved
var verifyBuffer bytes.Buffer
y := httptest.NewRecorder()

Expand All @@ -511,13 +503,35 @@ func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) (verifyR
"challenge_id": challengeID,
"code": code,
}))
req = httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), &verifyBuffer)
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", factorID), &verifyBuffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")

ts.API.handler.ServeHTTP(y, req)
require.Equal(ts.T(), http.StatusOK, y.Code)
verifyResp = &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(y.Body).Decode(&verifyResp))
return verifyResp
return y
}

func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) *httptest.ResponseRecorder {
w := enroll(ts, token, "", models.TOTP, ts.TestDomain, http.StatusOK)
enrollResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp))
factorID := enrollResp.ID

// Challenge
var challengeBuffer bytes.Buffer
x := httptest.NewRecorder()
J0 marked this conversation as resolved.
Show resolved Hide resolved
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost/factors/%s/challenge", factorID), &challengeBuffer)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Content-Type", "application/json")
ts.API.handler.ServeHTTP(x, req)
require.Equal(ts.T(), http.StatusOK, x.Code)
challengeResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(x.Body).Decode(&challengeResp))
challengeID := challengeResp.ID

// Verify
y := verify(ts, challengeID, factorID, token, http.StatusOK)

return y
}
Loading