diff --git a/internal/auth/context.go b/internal/auth/context.go index 40ebe5de30..62eb39a26f 100644 --- a/internal/auth/context.go +++ b/internal/auth/context.go @@ -9,14 +9,17 @@ type idContextKeyType struct{} var idContextKey idContextKeyType +// WithIdentityContext stores the identity in the context. func WithIdentityContext(ctx context.Context, identity *Identity) context.Context { return context.WithValue(ctx, idContextKey, identity) } +// IdentityFromContext retrieves the caller's identity from the context. +// This may return `nil` or an empty Identity if the user is not authenticated. func IdentityFromContext(ctx context.Context) *Identity { id, ok := ctx.Value(idContextKey).(*Identity) if !ok { return nil } return id -} \ No newline at end of file +} diff --git a/internal/auth/jwt/dynamic/dynamic_fetch_test.go b/internal/auth/jwt/dynamic/dynamic_fetch_test.go index cc15e0f18c..2c66058bc3 100644 --- a/internal/auth/jwt/dynamic/dynamic_fetch_test.go +++ b/internal/auth/jwt/dynamic/dynamic_fetch_test.go @@ -52,40 +52,40 @@ func TestValidator_ParseAndValidate(t *testing.T) { require.NoError(t, err) keySet := jwk.NewSet() - keySet.AddKey(pubKey) + require.NoError(t, keySet.AddKey(pubKey)) keySetJSON, err := json.Marshal(keySet) require.NoError(t, err) mux := http.NewServeMux() - mux.HandleFunc("/certs", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/certs", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") - w.Write(keySetJSON) + _, _ = w.Write(keySetJSON) }) server := httptest.NewServer(mux) t.Cleanup(server.Close) // We need to add this to the mux after server start, because it includes the server.URL - mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") - w.Write([]byte(fmt.Sprintf(`{ + _, _ = w.Write([]byte(fmt.Sprintf(`{ "issuer":"%[1]s", "jwks_uri":"%[1]s/certs", "scopes_supported":["openid","email","profile"], "claims_supported":["sub","email","iss","aud","iat","exp"] }`, server.URL))) }) - mux.HandleFunc("/other/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/other/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") - w.Write([]byte(fmt.Sprintf(`{ + _, _ = w.Write([]byte(fmt.Sprintf(`{ "issuer":"%[1]s/other", "jwks_uri":"%[1]s/certs", "scopes_supported":["openid","email","profile"], "claims_supported":["sub","email","iss","aud","iat","exp"] }`, server.URL))) }) - mux.HandleFunc("/elsewhere/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/elsewhere/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "application/json") - w.Write([]byte(fmt.Sprintf(`{ + _, _ = w.Write([]byte(fmt.Sprintf(`{ "issuer":"%[1]s/elsewhere", "jwks_uri":"%[1]s/non-existent", "scopes_supported":["openid","email","profile"], @@ -100,6 +100,7 @@ func TestValidator_ParseAndValidate(t *testing.T) { }{{ name: "valid token", getToken: func(t *testing.T) (string, openid.Token) { + t.Helper() token, err := openid.NewBuilder(). Issuer(server.URL). Subject("test"). @@ -115,6 +116,7 @@ func TestValidator_ParseAndValidate(t *testing.T) { }, { name: "valid token, other issuer", getToken: func(t *testing.T) (string, openid.Token) { + t.Helper() token, err := openid.NewBuilder(). Issuer(server.URL + "/other"). Subject("test"). @@ -130,12 +132,14 @@ func TestValidator_ParseAndValidate(t *testing.T) { }, { name: "invalid signature", getToken: func(_ *testing.T) (string, openid.Token) { + t.Helper() return "invalid", nil }, wantErr: `failed to split compact JWT: invalid number of segments`, }, { name: "expired jwt", getToken: func(_ *testing.T) (string, openid.Token) { + t.Helper() token, err := openid.NewBuilder(). Issuer(server.URL + "/elsewhere"). Subject("test"). @@ -152,6 +156,7 @@ func TestValidator_ParseAndValidate(t *testing.T) { }, { name: "bad well-known URL", getToken: func(t *testing.T) (string, openid.Token) { + t.Helper() token, err := openid.NewBuilder(). Issuer(server.URL + "/elsewhere"). Subject("test"). @@ -168,6 +173,7 @@ func TestValidator_ParseAndValidate(t *testing.T) { }, { name: "bad issuer", getToken: func(t *testing.T) (string, openid.Token) { + t.Helper() token, err := openid.NewBuilder(). Issuer(server.URL + "/nothing"). Subject("test"). diff --git a/internal/authz/authz.go b/internal/authz/authz.go index 5dc610e5ee..80865f826a 100644 --- a/internal/authz/authz.go +++ b/internal/authz/authz.go @@ -38,10 +38,9 @@ var ( // It is used to provide a common interface for the client and a way to // refresh authentication to the authz provider when needed. type ClientWrapper struct { - cfg *srvconfig.AuthzConfig - cli *fgaclient.OpenFgaClient - resolver auth.Resolver - l *zerolog.Logger + cfg *srvconfig.AuthzConfig + cli *fgaclient.OpenFgaClient + l *zerolog.Logger } var _ Client = &ClientWrapper{} @@ -53,8 +52,8 @@ func NewAuthzClient(cfg *srvconfig.AuthzConfig, l *zerolog.Logger) (Client, erro } cliWrap := &ClientWrapper{ - cfg: cfg, - l: l, + cfg: cfg, + l: l, } if err := cliWrap.initAuthzClient(); err != nil { diff --git a/internal/controlplane/handlers_authz.go b/internal/controlplane/handlers_authz.go index 54ec3c7ce4..d7236cb99a 100644 --- a/internal/controlplane/handlers_authz.go +++ b/internal/controlplane/handlers_authz.go @@ -326,7 +326,7 @@ func (s *Server) AssignRole(ctx context.Context, req *minder.AssignRoleRequest) if sub == "" && inviteeEmail != "" { if flags.Bool(ctx, s.featureFlags, flags.UserManagement) { invitation, err := db.WithTransaction(s.store, func(qtx db.ExtendQuerier) (*minder.Invitation, error) { - return s.invites.CreateInvite(ctx, qtx, s.idClient, s.evt, s.cfg.Email, targetProject, authzRole, inviteeEmail) + return s.invites.CreateInvite(ctx, qtx, s.evt, s.cfg.Email, targetProject, authzRole, inviteeEmail) }) if err != nil { return nil, err @@ -433,7 +433,7 @@ func (s *Server) UpdateRole(ctx context.Context, req *minder.UpdateRoleRequest) if sub == "" && inviteeEmail != "" { if flags.Bool(ctx, s.featureFlags, flags.UserManagement) { updatedInvitation, err := db.WithTransaction(s.store, func(qtx db.ExtendQuerier) (*minder.Invitation, error) { - return s.invites.UpdateInvite(ctx, qtx, s.idClient, s.evt, s.cfg.Email, targetProject, authzRole, inviteeEmail) + return s.invites.UpdateInvite(ctx, qtx, s.evt, s.cfg.Email, targetProject, authzRole, inviteeEmail) }) if err != nil { return nil, err diff --git a/internal/controlplane/handlers_authz_test.go b/internal/controlplane/handlers_authz_test.go index bc91687f81..4b08899d92 100644 --- a/internal/controlplane/handlers_authz_test.go +++ b/internal/controlplane/handlers_authz_test.go @@ -601,7 +601,7 @@ func TestUpdateRole(t *testing.T) { mockInviteService := mockinvites.NewMockInviteService(ctrl) if tc.expectedInvitation { - mockInviteService.EXPECT().UpdateInvite(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + mockInviteService.EXPECT().UpdateInvite(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), authzRole, tc.inviteeEmail).Return(&minder.Invitation{}, nil) } mockRoleService := mockroles.NewMockRoleService(ctrl) @@ -736,7 +736,7 @@ func TestAssignRole(t *testing.T) { mockInviteService := mockinvites.NewMockInviteService(ctrl) if tc.inviteByEmail { - mockInviteService.EXPECT().CreateInvite(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), + mockInviteService.EXPECT().CreateInvite(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), authzRole, tc.inviteeEmail).Return(&minder.Invitation{ Role: authzRole.String(), Project: projectIdString, diff --git a/internal/invites/mock/service.go b/internal/invites/mock/service.go index 420d8c1e6d..b8c12dbd9c 100644 --- a/internal/invites/mock/service.go +++ b/internal/invites/mock/service.go @@ -48,18 +48,18 @@ func (m *MockInviteService) EXPECT() *MockInviteServiceMockRecorder { } // CreateInvite mocks base method. -func (m *MockInviteService) CreateInvite(ctx context.Context, qtx db.Querier, idClient auth.Resolver, eventsPub interfaces.Publisher, emailConfig server.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string) (*v1.Invitation, error) { +func (m *MockInviteService) CreateInvite(ctx context.Context, qtx db.Querier, eventsPub interfaces.Publisher, emailConfig server.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string) (*v1.Invitation, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateInvite", ctx, qtx, idClient, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail) + ret := m.ctrl.Call(m, "CreateInvite", ctx, qtx, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail) ret0, _ := ret[0].(*v1.Invitation) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateInvite indicates an expected call of CreateInvite. -func (mr *MockInviteServiceMockRecorder) CreateInvite(ctx, qtx, idClient, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail any) *gomock.Call { +func (mr *MockInviteServiceMockRecorder) CreateInvite(ctx, qtx, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateInvite", reflect.TypeOf((*MockInviteService)(nil).CreateInvite), ctx, qtx, idClient, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateInvite", reflect.TypeOf((*MockInviteService)(nil).CreateInvite), ctx, qtx, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail) } // RemoveInvite mocks base method. @@ -78,16 +78,16 @@ func (mr *MockInviteServiceMockRecorder) RemoveInvite(ctx, qtx, idClient, target } // UpdateInvite mocks base method. -func (m *MockInviteService) UpdateInvite(ctx context.Context, qtx db.Querier, idClient auth.Resolver, eventsPub interfaces.Publisher, emailConfig server.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string) (*v1.Invitation, error) { +func (m *MockInviteService) UpdateInvite(ctx context.Context, qtx db.Querier, eventsPub interfaces.Publisher, emailConfig server.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string) (*v1.Invitation, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateInvite", ctx, qtx, idClient, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail) + ret := m.ctrl.Call(m, "UpdateInvite", ctx, qtx, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail) ret0, _ := ret[0].(*v1.Invitation) ret1, _ := ret[1].(error) return ret0, ret1 } // UpdateInvite indicates an expected call of UpdateInvite. -func (mr *MockInviteServiceMockRecorder) UpdateInvite(ctx, qtx, idClient, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail any) *gomock.Call { +func (mr *MockInviteServiceMockRecorder) UpdateInvite(ctx, qtx, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateInvite", reflect.TypeOf((*MockInviteService)(nil).UpdateInvite), ctx, qtx, idClient, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateInvite", reflect.TypeOf((*MockInviteService)(nil).UpdateInvite), ctx, qtx, eventsPub, emailConfig, targetProject, authzRole, inviteeEmail) } diff --git a/internal/invites/service.go b/internal/invites/service.go index 38a96135d5..7037c2be1b 100644 --- a/internal/invites/service.go +++ b/internal/invites/service.go @@ -32,12 +32,12 @@ import ( // InviteService encapsulates the methods to manage user invites to a project type InviteService interface { // CreateInvite creates a new user invite - CreateInvite(ctx context.Context, qtx db.Querier, idClient auth.Resolver, eventsPub interfaces.Publisher, + CreateInvite(ctx context.Context, qtx db.Querier, eventsPub interfaces.Publisher, emailConfig serverconfig.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string, ) (*minder.Invitation, error) // UpdateInvite updates the invite status - UpdateInvite(ctx context.Context, qtx db.Querier, idClient auth.Resolver, eventsPub interfaces.Publisher, + UpdateInvite(ctx context.Context, qtx db.Querier, eventsPub interfaces.Publisher, emailConfig serverconfig.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string, ) (*minder.Invitation, error) @@ -55,7 +55,7 @@ func NewInviteService() InviteService { return &inviteService{} } -func (_ *inviteService) UpdateInvite(ctx context.Context, qtx db.Querier, idClient auth.Resolver, eventsPub interfaces.Publisher, +func (_ *inviteService) UpdateInvite(ctx context.Context, qtx db.Querier, eventsPub interfaces.Publisher, emailConfig serverconfig.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string, ) (*minder.Invitation, error) { var userInvite db.UserInvite @@ -225,11 +225,12 @@ func (_ *inviteService) RemoveInvite(ctx context.Context, qtx db.Querier, idClie }, nil } -func (_ *inviteService) CreateInvite(ctx context.Context, qtx db.Querier, idClient auth.Resolver, eventsPub interfaces.Publisher, +func (_ *inviteService) CreateInvite(ctx context.Context, qtx db.Querier, eventsPub interfaces.Publisher, emailConfig serverconfig.EmailConfig, targetProject uuid.UUID, authzRole authz.Role, inviteeEmail string, ) (*minder.Invitation, error) { identity := auth.IdentityFromContext(ctx) - if identity.Provider.String() != "" { + // Slight hack -- only the null/default provider has String == UserID + if identity == nil || identity.String() != identity.UserID { return nil, util.UserVisibleError(codes.PermissionDenied, "only human users can create invites") } // Get the sponsor's user information (current user) diff --git a/internal/invites/service_test.go b/internal/invites/service_test.go index 4edf0252b9..a30cbed4f8 100644 --- a/internal/invites/service_test.go +++ b/internal/invites/service_test.go @@ -76,12 +76,9 @@ func TestCreateInvite(t *testing.T) { assert.NoError(t, user.Set("sub", userEmail)) ctx := context.Background() - ctx = authjwt.WithAuthTokenContext(ctx, user) - - idClient := mockauth.NewMockResolver(ctrl) - idClient.EXPECT().Resolve(ctx, userSubject).Return(&auth.Identity{ + ctx = auth.WithIdentityContext(ctx, &auth.Identity{ UserID: userSubject, - }, nil).AnyTimes() + }) publisher := mockevents.NewMockPublisher(ctrl) if scenario.publisherSetup != nil { @@ -93,7 +90,7 @@ func TestCreateInvite(t *testing.T) { } service := NewInviteService() - invite, err := service.CreateInvite(ctx, scenario.dBSetup(ctrl), idClient, publisher, emailConfig, projectId, userRole, userEmail) + invite, err := service.CreateInvite(ctx, scenario.dBSetup(ctrl), publisher, emailConfig, projectId, userRole, userEmail) if scenario.expectedError != "" { require.ErrorContains(t, err, scenario.expectedError) @@ -122,7 +119,7 @@ func TestUpdateInvite(t *testing.T) { { name: "error when no existing invites", dBSetup: dbf.NewDBMock( - withGetUserBySubject(validUser), + // withGetUserBySubject(validUser), withExistingInvites(noInvites), ), expectedError: "no invitations found for this email and project", @@ -130,7 +127,7 @@ func TestUpdateInvite(t *testing.T) { { name: "error when multiple existing invites", dBSetup: dbf.NewDBMock( - withGetUserBySubject(validUser), + // withGetUserBySubject(validUser), withExistingInvites(multipleInvites), ), expectedError: "multiple invitations found for this email and project", @@ -138,7 +135,7 @@ func TestUpdateInvite(t *testing.T) { { name: "no message sent when role is the same", dBSetup: dbf.NewDBMock( - withGetUserBySubject(validUser), + // withGetUserBySubject(validUser), withExistingInvites(singleInviteWithSameRole), withInviteRoleUpdate(userInvite, nil), withProject(), @@ -147,7 +144,7 @@ func TestUpdateInvite(t *testing.T) { { name: "invite updated and message sent successfully", dBSetup: dbf.NewDBMock( - withGetUserBySubject(validUser), + // withGetUserBySubject(validUser), withExistingInvites(singleInviteWithDifferentRole), withInviteRoleUpdate(userInvite, nil), withProject(), @@ -174,12 +171,10 @@ func TestUpdateInvite(t *testing.T) { assert.NoError(t, user.Set("sub", userEmail)) ctx := context.Background() - ctx = authjwt.WithAuthTokenContext(ctx, user) - - idClient := mockauth.NewMockResolver(ctrl) - idClient.EXPECT().Resolve(ctx, userSubject).Return(&auth.Identity{ + // ctx = authjwt.WithAuthTokenContext(ctx, user) + ctx = auth.WithIdentityContext(ctx, &auth.Identity{ UserID: userSubject, - }, nil).AnyTimes() + }) publisher := mockevents.NewMockPublisher(ctrl) if scenario.publisherSetup != nil { @@ -191,7 +186,7 @@ func TestUpdateInvite(t *testing.T) { } service := NewInviteService() - invite, err := service.UpdateInvite(ctx, scenario.dBSetup(ctrl), idClient, publisher, emailConfig, projectId, userRole, userEmail) + invite, err := service.UpdateInvite(ctx, scenario.dBSetup(ctrl), publisher, emailConfig, projectId, userRole, userEmail) if scenario.expectedError != "" { require.ErrorContains(t, err, scenario.expectedError) @@ -280,7 +275,7 @@ func TestRemoveInvite(t *testing.T) { var ( projectId = uuid.New() userEmail = "test@example.com" - userSubject = "subject" + userSubject = uuid.New().String() userRole = authz.RoleAdmin inviteCode = "code" baseUrl = "https://minder.example.com" diff --git a/internal/roles/service_test.go b/internal/roles/service_test.go index 0800e0fb0b..3b612b650f 100644 --- a/internal/roles/service_test.go +++ b/internal/roles/service_test.go @@ -82,7 +82,7 @@ func TestCreateRoleAssignment(t *testing.T) { idClient := mockauth.NewMockResolver(ctrl) idClient.EXPECT().Resolve(ctx, subject).Return(&auth.Identity{ - UserID: subject, + UserID: subject, Provider: &keycloak.KeyCloak{}, }, nil) @@ -153,7 +153,7 @@ func TestUpdateRoleAssignment(t *testing.T) { idClient := mockauth.NewMockResolver(ctrl) idClient.EXPECT().Resolve(ctx, subject).Return(&auth.Identity{ - UserID: subject, + UserID: subject, Provider: &keycloak.KeyCloak{}, }, nil) @@ -232,7 +232,7 @@ func TestRemoveRole(t *testing.T) { idClient := mockauth.NewMockResolver(ctrl) idClient.EXPECT().Resolve(ctx, subject).Return(&auth.Identity{ - UserID: subject, + UserID: subject, Provider: &keycloak.KeyCloak{}, }, nil)