diff --git a/groups/api/http/endpoints.go b/groups/api/http/endpoints.go index 4a5f238c2d..99bb6f9617 100644 --- a/groups/api/http/endpoints.go +++ b/groups/api/http/endpoints.go @@ -27,7 +27,7 @@ func CreateGroupEndpoint(svc groups.Service) endpoint.Endpoint { return createGroupRes{created: false}, svcerr.ErrAuthentication } - group, err := svc.CreateGroup(ctx, session, req.Group) + group, _, err := svc.CreateGroup(ctx, session, req.Group) if err != nil { return createGroupRes{created: false}, err } diff --git a/groups/events/events.go b/groups/events/events.go index 50b4dc3b6a..629ef0da6e 100644 --- a/groups/events/events.go +++ b/groups/events/events.go @@ -8,6 +8,7 @@ import ( groups "github.com/absmach/supermq/groups" "github.com/absmach/supermq/pkg/events" + "github.com/absmach/supermq/pkg/roles" ) var ( @@ -49,14 +50,16 @@ var ( type createGroupEvent struct { groups.Group + rolesProvisioned []roles.RoleProvision } func (cge createGroupEvent) Encode() (map[string]interface{}, error) { val := map[string]interface{}{ - "operation": groupCreate, - "id": cge.ID, - "status": cge.Status.String(), - "created_at": cge.CreatedAt, + "operation": groupCreate, + "id": cge.ID, + "roles_provisioned": cge.rolesProvisioned, + "status": cge.Status.String(), + "created_at": cge.CreatedAt, } if cge.Domain != "" { diff --git a/groups/events/streams.go b/groups/events/streams.go index 6886fcd8f1..793554daa0 100644 --- a/groups/events/streams.go +++ b/groups/events/streams.go @@ -10,6 +10,7 @@ import ( "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/events" "github.com/absmach/supermq/pkg/events/store" + "github.com/absmach/supermq/pkg/roles" rmEvents "github.com/absmach/supermq/pkg/roles/rolemanager/events" ) @@ -39,21 +40,22 @@ func New(ctx context.Context, svc groups.Service, url string) (groups.Service, e }, nil } -func (es eventStore) CreateGroup(ctx context.Context, session authn.Session, group groups.Group) (groups.Group, error) { - group, err := es.svc.CreateGroup(ctx, session, group) +func (es eventStore) CreateGroup(ctx context.Context, session authn.Session, group groups.Group) (groups.Group, []roles.RoleProvision, error) { + group, rps, err := es.svc.CreateGroup(ctx, session, group) if err != nil { - return group, err + return group, rps, err } event := createGroupEvent{ - group, + Group: group, + rolesProvisioned: rps, } if err := es.Publish(ctx, event); err != nil { - return group, err + return group, rps, err } - return group, nil + return group, rps, nil } func (es eventStore) UpdateGroup(ctx context.Context, session authn.Session, group groups.Group) (groups.Group, error) { diff --git a/groups/groups.go b/groups/groups.go index c9a8bde68e..b9bf07046f 100644 --- a/groups/groups.go +++ b/groups/groups.go @@ -134,7 +134,7 @@ type Repository interface { //go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines" --unroll-variadic=false type Service interface { // CreateGroup creates new group. - CreateGroup(ctx context.Context, session authn.Session, g Group) (Group, error) + CreateGroup(ctx context.Context, session authn.Session, g Group) (Group, []roles.RoleProvision, error) // UpdateGroup updates the group identified by the provided ID. UpdateGroup(ctx context.Context, session authn.Session, g Group) (Group, error) diff --git a/groups/middleware/authorization.go b/groups/middleware/authorization.go index d8cfb2742d..0beccdf563 100644 --- a/groups/middleware/authorization.go +++ b/groups/middleware/authorization.go @@ -14,6 +14,7 @@ import ( "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/pkg/roles" rmMW "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" "github.com/absmach/supermq/pkg/svcutil" ) @@ -79,7 +80,7 @@ func AuthorizationMiddleware(entityType string, svc groups.Service, repo groups. }, nil } -func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, error) { +func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, []roles.RoleProvision, error) { if err := am.extAuthorize(ctx, groups.DomainOpCreateGroup, smqauthz.PolicyReq{ Domain: session.DomainID, SubjectType: policies.UserType, @@ -88,7 +89,7 @@ func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session auth Object: session.DomainID, ObjectType: policies.DomainType, }); err != nil { - return groups.Group{}, errors.Wrap(errDomainCreateGroups, err) + return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errDomainCreateGroups, err) } if g.Parent != "" { @@ -100,7 +101,7 @@ func (am *authorizationMiddleware) CreateGroup(ctx context.Context, session auth Object: g.Parent, ObjectType: policies.GroupType, }); err != nil { - return groups.Group{}, errors.Wrap(errParentGroupSetChildGroup, err) + return groups.Group{}, []roles.RoleProvision{}, errors.Wrap(errParentGroupSetChildGroup, err) } } diff --git a/groups/middleware/logging.go b/groups/middleware/logging.go index f9bab7c3be..3bc7fddf1a 100644 --- a/groups/middleware/logging.go +++ b/groups/middleware/logging.go @@ -10,6 +10,7 @@ import ( "github.com/absmach/supermq/groups" "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/roles" rmMW "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" ) @@ -28,7 +29,7 @@ func LoggingMiddleware(svc groups.Service, logger *slog.Logger) groups.Service { // CreateGroup logs the create_group request. It logs the group name, id and token and the time it took to complete the request. // If the request fails, it logs the error. -func (lm *loggingMiddleware) CreateGroup(ctx context.Context, session authn.Session, group groups.Group) (g groups.Group, err error) { +func (lm *loggingMiddleware) CreateGroup(ctx context.Context, session authn.Session, group groups.Group) (g groups.Group, rps []roles.RoleProvision, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), diff --git a/groups/middleware/metrics.go b/groups/middleware/metrics.go index 9c3e349b54..68cf578a05 100644 --- a/groups/middleware/metrics.go +++ b/groups/middleware/metrics.go @@ -9,6 +9,7 @@ import ( "github.com/absmach/supermq/groups" "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/roles" rmMW "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" "github.com/go-kit/kit/metrics" ) @@ -34,7 +35,7 @@ func MetricsMiddleware(svc groups.Service, counter metrics.Counter, latency metr } // CreateGroup instruments CreateGroup method with metrics. -func (ms *metricsMiddleware) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, error) { +func (ms *metricsMiddleware) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, []roles.RoleProvision, error) { defer func(begin time.Time) { ms.counter.With("method", "create_group").Add(1) ms.latency.With("method", "create_group").Observe(time.Since(begin).Seconds()) diff --git a/groups/mocks/service.go b/groups/mocks/service.go index 4dda23ee9c..e0726aead8 100644 --- a/groups/mocks/service.go +++ b/groups/mocks/service.go @@ -86,7 +86,7 @@ func (_m *Service) AddRole(ctx context.Context, session authn.Session, entityID } // CreateGroup provides a mock function with given fields: ctx, session, g -func (_m *Service) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, error) { +func (_m *Service) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, []roles.RoleProvision, error) { ret := _m.Called(ctx, session, g) if len(ret) == 0 { @@ -94,8 +94,9 @@ func (_m *Service) CreateGroup(ctx context.Context, session authn.Session, g gro } var r0 groups.Group - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, groups.Group) (groups.Group, error)); ok { + var r1 []roles.RoleProvision + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, groups.Group) (groups.Group, []roles.RoleProvision, error)); ok { return rf(ctx, session, g) } if rf, ok := ret.Get(0).(func(context.Context, authn.Session, groups.Group) groups.Group); ok { @@ -104,13 +105,21 @@ func (_m *Service) CreateGroup(ctx context.Context, session authn.Session, g gro r0 = ret.Get(0).(groups.Group) } - if rf, ok := ret.Get(1).(func(context.Context, authn.Session, groups.Group) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, authn.Session, groups.Group) []roles.RoleProvision); ok { r1 = rf(ctx, session, g) } else { - r1 = ret.Error(1) + if ret.Get(1) != nil { + r1 = ret.Get(1).([]roles.RoleProvision) + } } - return r0, r1 + if rf, ok := ret.Get(2).(func(context.Context, authn.Session, groups.Group) error); ok { + r2 = rf(ctx, session, g) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 } // DeleteGroup provides a mock function with given fields: ctx, session, id diff --git a/groups/service.go b/groups/service.go index 1dd7bf7df2..b8755d8b9a 100644 --- a/groups/service.go +++ b/groups/service.go @@ -47,13 +47,13 @@ func NewService(repo Repository, policy policies.Service, idp supermq.IDProvider }, nil } -func (svc service) CreateGroup(ctx context.Context, session smqauthn.Session, g Group) (gr Group, retErr error) { +func (svc service) CreateGroup(ctx context.Context, session smqauthn.Session, g Group) (retGr Group, retRps []roles.RoleProvision, retErr error) { groupID, err := svc.idProvider.ID() if err != nil { - return Group{}, err + return Group{}, []roles.RoleProvision{}, err } if g.Status != EnabledStatus && g.Status != DisabledStatus { - return Group{}, svcerr.ErrInvalidStatus + return Group{}, []roles.RoleProvision{}, svcerr.ErrInvalidStatus } g.ID = groupID @@ -62,7 +62,7 @@ func (svc service) CreateGroup(ctx context.Context, session smqauthn.Session, g saved, err := svc.repo.Save(ctx, g) if err != nil { - return Group{}, errors.Wrap(svcerr.ErrCreateEntity, err) + return Group{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrCreateEntity, err) } defer func() { @@ -97,11 +97,12 @@ func (svc service) CreateGroup(ctx context.Context, session smqauthn.Session, g newBuiltInRoleMembers := map[roles.BuiltInRoleName][]roles.Member{ BuiltInRoleAdmin: {roles.Member(session.UserID)}, } - if _, err := svc.AddNewEntitiesRoles(ctx, session.DomainID, session.UserID, []string{saved.ID}, oprs, newBuiltInRoleMembers); err != nil { - return Group{}, errors.Wrap(svcerr.ErrAddPolicies, err) + nrps, err := svc.AddNewEntitiesRoles(ctx, session.DomainID, session.UserID, []string{saved.ID}, oprs, newBuiltInRoleMembers) + if err != nil { + return Group{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrAddPolicies, err) } - return saved, nil + return saved, nrps, nil } func (svc service) ViewGroup(ctx context.Context, session smqauthn.Session, id string) (Group, error) { diff --git a/groups/service_test.go b/groups/service_test.go index f4312413f4..a02d85a30d 100644 --- a/groups/service_test.go +++ b/groups/service_test.go @@ -201,9 +201,9 @@ func TestCreateGroup(t *testing.T) { repoCall := repo.On("Save", context.Background(), mock.Anything).Return(tc.saveResp, tc.saveErr) policyCall := policies.On("AddPolicies", context.Background(), mock.Anything).Return(tc.addPoliciesErr) policyCall1 := policies.On("DeletePolicies", context.Background(), mock.Anything).Return(tc.deletePoliciesErr) - repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.Role{}, tc.addRoleErr) + repoCall1 := repo.On("AddRoles", context.Background(), mock.Anything).Return([]roles.RoleProvision{}, tc.addRoleErr) repoCall2 := repo.On("Delete", context.Background(), mock.Anything).Return(tc.deleteErr) - got, err := svc.CreateGroup(context.Background(), validSession, tc.group) + got, _, err := svc.CreateGroup(context.Background(), validSession, tc.group) assert.Equal(t, tc.err, err, fmt.Sprintf("expected error %v but got %v", tc.err, err)) if err == nil { assert.NotEmpty(t, got.ID) diff --git a/groups/tracing/tracing.go b/groups/tracing/tracing.go index 719bab3101..80e369dcbc 100644 --- a/groups/tracing/tracing.go +++ b/groups/tracing/tracing.go @@ -9,6 +9,7 @@ import ( "github.com/absmach/supermq/groups" "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/roles" rmTrace "github.com/absmach/supermq/pkg/roles/rolemanager/tracing" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -28,7 +29,7 @@ func New(svc groups.Service, tracer trace.Tracer) groups.Service { } // CreateGroup traces the "CreateGroup" operation of the wrapped groups.Service. -func (tm *tracingMiddleware) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, error) { +func (tm *tracingMiddleware) CreateGroup(ctx context.Context, session authn.Session, g groups.Group) (groups.Group, []roles.RoleProvision, error) { ctx, span := tm.tracer.Start(ctx, "svc_create_group") defer span.End()