Skip to content

Commit

Permalink
Fix lint and test errors, simplify interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
evankanderson committed Jan 30, 2025
1 parent cde854a commit e10fca5
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 53 deletions.
5 changes: 4 additions & 1 deletion internal/auth/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ type idContextKeyType struct{}

var idContextKey idContextKeyType

// WithIdentityContext stores the identity in the context.
func WithIdentityContext(ctx context.Context, identity *Identity) context.Context {
return context.WithValue(ctx, idContextKey, identity)
}

// IdentityFromContext retrieves the caller's identity from the context.
// This may return `nil` or an empty Identity if the user is not authenticated.
func IdentityFromContext(ctx context.Context) *Identity {
id, ok := ctx.Value(idContextKey).(*Identity)
if !ok {
return nil
}
return id
}
}
24 changes: 15 additions & 9 deletions internal/auth/jwt/dynamic/dynamic_fetch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,40 +52,40 @@ func TestValidator_ParseAndValidate(t *testing.T) {
require.NoError(t, err)

keySet := jwk.NewSet()
keySet.AddKey(pubKey)
require.NoError(t, keySet.AddKey(pubKey))
keySetJSON, err := json.Marshal(keySet)
require.NoError(t, err)

mux := http.NewServeMux()
mux.HandleFunc("/certs", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/certs", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write(keySetJSON)
_, _ = w.Write(keySetJSON)
})
server := httptest.NewServer(mux)
t.Cleanup(server.Close)

// We need to add this to the mux after server start, because it includes the server.URL
mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(fmt.Sprintf(`{
_, _ = w.Write([]byte(fmt.Sprintf(`{
"issuer":"%[1]s",
"jwks_uri":"%[1]s/certs",
"scopes_supported":["openid","email","profile"],
"claims_supported":["sub","email","iss","aud","iat","exp"]
}`, server.URL)))
})
mux.HandleFunc("/other/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/other/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(fmt.Sprintf(`{
_, _ = w.Write([]byte(fmt.Sprintf(`{
"issuer":"%[1]s/other",
"jwks_uri":"%[1]s/certs",
"scopes_supported":["openid","email","profile"],
"claims_supported":["sub","email","iss","aud","iat","exp"]
}`, server.URL)))
})
mux.HandleFunc("/elsewhere/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/elsewhere/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(fmt.Sprintf(`{
_, _ = w.Write([]byte(fmt.Sprintf(`{
"issuer":"%[1]s/elsewhere",
"jwks_uri":"%[1]s/non-existent",
"scopes_supported":["openid","email","profile"],
Expand All @@ -100,6 +100,7 @@ func TestValidator_ParseAndValidate(t *testing.T) {
}{{
name: "valid token",
getToken: func(t *testing.T) (string, openid.Token) {
t.Helper()
token, err := openid.NewBuilder().
Issuer(server.URL).
Subject("test").
Expand All @@ -115,6 +116,7 @@ func TestValidator_ParseAndValidate(t *testing.T) {
}, {
name: "valid token, other issuer",
getToken: func(t *testing.T) (string, openid.Token) {
t.Helper()
token, err := openid.NewBuilder().
Issuer(server.URL + "/other").
Subject("test").
Expand All @@ -130,12 +132,14 @@ func TestValidator_ParseAndValidate(t *testing.T) {
}, {
name: "invalid signature",
getToken: func(_ *testing.T) (string, openid.Token) {
t.Helper()
return "invalid", nil
},
wantErr: `failed to split compact JWT: invalid number of segments`,
}, {
name: "expired jwt",
getToken: func(_ *testing.T) (string, openid.Token) {
t.Helper()
token, err := openid.NewBuilder().
Issuer(server.URL + "/elsewhere").
Subject("test").
Expand All @@ -152,6 +156,7 @@ func TestValidator_ParseAndValidate(t *testing.T) {
}, {
name: "bad well-known URL",
getToken: func(t *testing.T) (string, openid.Token) {
t.Helper()
token, err := openid.NewBuilder().
Issuer(server.URL + "/elsewhere").
Subject("test").
Expand All @@ -168,6 +173,7 @@ func TestValidator_ParseAndValidate(t *testing.T) {
}, {
name: "bad issuer",
getToken: func(t *testing.T) (string, openid.Token) {
t.Helper()
token, err := openid.NewBuilder().
Issuer(server.URL + "/nothing").
Subject("test").
Expand Down
11 changes: 5 additions & 6 deletions internal/authz/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,9 @@ var (
// It is used to provide a common interface for the client and a way to
// refresh authentication to the authz provider when needed.
type ClientWrapper struct {
cfg *srvconfig.AuthzConfig
cli *fgaclient.OpenFgaClient
resolver auth.Resolver
l *zerolog.Logger
cfg *srvconfig.AuthzConfig
cli *fgaclient.OpenFgaClient
l *zerolog.Logger
}

var _ Client = &ClientWrapper{}
Expand All @@ -53,8 +52,8 @@ func NewAuthzClient(cfg *srvconfig.AuthzConfig, l *zerolog.Logger) (Client, erro
}

cliWrap := &ClientWrapper{
cfg: cfg,
l: l,
cfg: cfg,
l: l,
}

if err := cliWrap.initAuthzClient(); err != nil {
Expand Down
4 changes: 2 additions & 2 deletions internal/controlplane/handlers_authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ func (s *Server) AssignRole(ctx context.Context, req *minder.AssignRoleRequest)
if sub == "" && inviteeEmail != "" {
if flags.Bool(ctx, s.featureFlags, flags.UserManagement) {
invitation, err := db.WithTransaction(s.store, func(qtx db.ExtendQuerier) (*minder.Invitation, error) {
return s.invites.CreateInvite(ctx, qtx, s.idClient, s.evt, s.cfg.Email, targetProject, authzRole, inviteeEmail)
return s.invites.CreateInvite(ctx, qtx, s.evt, s.cfg.Email, targetProject, authzRole, inviteeEmail)
})
if err != nil {
return nil, err
Expand Down Expand Up @@ -433,7 +433,7 @@ func (s *Server) UpdateRole(ctx context.Context, req *minder.UpdateRoleRequest)
if sub == "" && inviteeEmail != "" {
if flags.Bool(ctx, s.featureFlags, flags.UserManagement) {
updatedInvitation, err := db.WithTransaction(s.store, func(qtx db.ExtendQuerier) (*minder.Invitation, error) {
return s.invites.UpdateInvite(ctx, qtx, s.idClient, s.evt, s.cfg.Email, targetProject, authzRole, inviteeEmail)
return s.invites.UpdateInvite(ctx, qtx, s.evt, s.cfg.Email, targetProject, authzRole, inviteeEmail)
})
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions internal/controlplane/handlers_authz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ func TestUpdateRole(t *testing.T) {

mockInviteService := mockinvites.NewMockInviteService(ctrl)
if tc.expectedInvitation {
mockInviteService.EXPECT().UpdateInvite(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
mockInviteService.EXPECT().UpdateInvite(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Any(), authzRole, tc.inviteeEmail).Return(&minder.Invitation{}, nil)
}
mockRoleService := mockroles.NewMockRoleService(ctrl)
Expand Down Expand Up @@ -736,7 +736,7 @@ func TestAssignRole(t *testing.T) {

mockInviteService := mockinvites.NewMockInviteService(ctrl)
if tc.inviteByEmail {
mockInviteService.EXPECT().CreateInvite(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
mockInviteService.EXPECT().CreateInvite(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(),
gomock.Any(), authzRole, tc.inviteeEmail).Return(&minder.Invitation{
Role: authzRole.String(),
Project: projectIdString,
Expand Down
16 changes: 8 additions & 8 deletions internal/invites/mock/service.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 6 additions & 5 deletions internal/invites/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ import (
// InviteService encapsulates the methods to manage user invites to a project
type InviteService interface {
// CreateInvite creates a new user invite
CreateInvite(ctx context.Context, qtx db.Querier, idClient auth.Resolver, eventsPub interfaces.Publisher,
CreateInvite(ctx context.Context, qtx db.Querier, eventsPub interfaces.Publisher,
emailConfig serverconfig.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string,
) (*minder.Invitation, error)

// UpdateInvite updates the invite status
UpdateInvite(ctx context.Context, qtx db.Querier, idClient auth.Resolver, eventsPub interfaces.Publisher,
UpdateInvite(ctx context.Context, qtx db.Querier, eventsPub interfaces.Publisher,
emailConfig serverconfig.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string,
) (*minder.Invitation, error)

Expand All @@ -55,7 +55,7 @@ func NewInviteService() InviteService {
return &inviteService{}
}

func (_ *inviteService) UpdateInvite(ctx context.Context, qtx db.Querier, idClient auth.Resolver, eventsPub interfaces.Publisher,
func (_ *inviteService) UpdateInvite(ctx context.Context, qtx db.Querier, eventsPub interfaces.Publisher,
emailConfig serverconfig.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string,
) (*minder.Invitation, error) {
var userInvite db.UserInvite
Expand Down Expand Up @@ -225,11 +225,12 @@ func (_ *inviteService) RemoveInvite(ctx context.Context, qtx db.Querier, idClie
}, nil
}

func (_ *inviteService) CreateInvite(ctx context.Context, qtx db.Querier, idClient auth.Resolver, eventsPub interfaces.Publisher,
func (_ *inviteService) CreateInvite(ctx context.Context, qtx db.Querier, eventsPub interfaces.Publisher,
emailConfig serverconfig.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string,
) (*minder.Invitation, error) {
identity := auth.IdentityFromContext(ctx)
if identity.Provider.String() != "" {
// Slight hack -- only the null/default provider has String == UserID
if identity == nil || identity.String() != identity.UserID {
return nil, util.UserVisibleError(codes.PermissionDenied, "only human users can create invites")
}
// Get the sponsor's user information (current user)
Expand Down
29 changes: 12 additions & 17 deletions internal/invites/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,9 @@ func TestCreateInvite(t *testing.T) {
assert.NoError(t, user.Set("sub", userEmail))

ctx := context.Background()
ctx = authjwt.WithAuthTokenContext(ctx, user)

idClient := mockauth.NewMockResolver(ctrl)
idClient.EXPECT().Resolve(ctx, userSubject).Return(&auth.Identity{
ctx = auth.WithIdentityContext(ctx, &auth.Identity{
UserID: userSubject,
}, nil).AnyTimes()
})

publisher := mockevents.NewMockPublisher(ctrl)
if scenario.publisherSetup != nil {
Expand All @@ -93,7 +90,7 @@ func TestCreateInvite(t *testing.T) {
}

service := NewInviteService()
invite, err := service.CreateInvite(ctx, scenario.dBSetup(ctrl), idClient, publisher, emailConfig, projectId, userRole, userEmail)
invite, err := service.CreateInvite(ctx, scenario.dBSetup(ctrl), publisher, emailConfig, projectId, userRole, userEmail)

if scenario.expectedError != "" {
require.ErrorContains(t, err, scenario.expectedError)
Expand Down Expand Up @@ -122,23 +119,23 @@ func TestUpdateInvite(t *testing.T) {
{
name: "error when no existing invites",
dBSetup: dbf.NewDBMock(
withGetUserBySubject(validUser),
// withGetUserBySubject(validUser),
withExistingInvites(noInvites),
),
expectedError: "no invitations found for this email and project",
},
{
name: "error when multiple existing invites",
dBSetup: dbf.NewDBMock(
withGetUserBySubject(validUser),
// withGetUserBySubject(validUser),
withExistingInvites(multipleInvites),
),
expectedError: "multiple invitations found for this email and project",
},
{
name: "no message sent when role is the same",
dBSetup: dbf.NewDBMock(
withGetUserBySubject(validUser),
// withGetUserBySubject(validUser),
withExistingInvites(singleInviteWithSameRole),
withInviteRoleUpdate(userInvite, nil),
withProject(),
Expand All @@ -147,7 +144,7 @@ func TestUpdateInvite(t *testing.T) {
{
name: "invite updated and message sent successfully",
dBSetup: dbf.NewDBMock(
withGetUserBySubject(validUser),
// withGetUserBySubject(validUser),
withExistingInvites(singleInviteWithDifferentRole),
withInviteRoleUpdate(userInvite, nil),
withProject(),
Expand All @@ -174,12 +171,10 @@ func TestUpdateInvite(t *testing.T) {
assert.NoError(t, user.Set("sub", userEmail))

ctx := context.Background()
ctx = authjwt.WithAuthTokenContext(ctx, user)

idClient := mockauth.NewMockResolver(ctrl)
idClient.EXPECT().Resolve(ctx, userSubject).Return(&auth.Identity{
// ctx = authjwt.WithAuthTokenContext(ctx, user)
ctx = auth.WithIdentityContext(ctx, &auth.Identity{
UserID: userSubject,
}, nil).AnyTimes()
})

publisher := mockevents.NewMockPublisher(ctrl)
if scenario.publisherSetup != nil {
Expand All @@ -191,7 +186,7 @@ func TestUpdateInvite(t *testing.T) {
}

service := NewInviteService()
invite, err := service.UpdateInvite(ctx, scenario.dBSetup(ctrl), idClient, publisher, emailConfig, projectId, userRole, userEmail)
invite, err := service.UpdateInvite(ctx, scenario.dBSetup(ctrl), publisher, emailConfig, projectId, userRole, userEmail)

if scenario.expectedError != "" {
require.ErrorContains(t, err, scenario.expectedError)
Expand Down Expand Up @@ -280,7 +275,7 @@ func TestRemoveInvite(t *testing.T) {
var (
projectId = uuid.New()
userEmail = "[email protected]"
userSubject = "subject"
userSubject = uuid.New().String()
userRole = authz.RoleAdmin
inviteCode = "code"
baseUrl = "https://minder.example.com"
Expand Down
6 changes: 3 additions & 3 deletions internal/roles/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TestCreateRoleAssignment(t *testing.T) {

idClient := mockauth.NewMockResolver(ctrl)
idClient.EXPECT().Resolve(ctx, subject).Return(&auth.Identity{
UserID: subject,
UserID: subject,
Provider: &keycloak.KeyCloak{},
}, nil)

Expand Down Expand Up @@ -153,7 +153,7 @@ func TestUpdateRoleAssignment(t *testing.T) {

idClient := mockauth.NewMockResolver(ctrl)
idClient.EXPECT().Resolve(ctx, subject).Return(&auth.Identity{
UserID: subject,
UserID: subject,
Provider: &keycloak.KeyCloak{},
}, nil)

Expand Down Expand Up @@ -232,7 +232,7 @@ func TestRemoveRole(t *testing.T) {

idClient := mockauth.NewMockResolver(ctrl)
idClient.EXPECT().Resolve(ctx, subject).Return(&auth.Identity{
UserID: subject,
UserID: subject,
Provider: &keycloak.KeyCloak{},
}, nil)

Expand Down

0 comments on commit e10fca5

Please sign in to comment.