From 5ba08d82a7f143b84126ffaacff2b1c8757b8cee Mon Sep 17 00:00:00 2001 From: Arvindh Date: Thu, 26 Dec 2024 15:48:29 +0530 Subject: [PATCH] domains: return role provisioned on create Signed-off-by: Arvindh --- domains/api/http/endpoint.go | 2 +- domains/domains.go | 2 +- domains/events/events.go | 15 +++++++++------ domains/events/streams.go | 14 ++++++++------ domains/middleware/authorization.go | 3 ++- domains/middleware/logging.go | 3 ++- domains/middleware/metrics.go | 3 ++- domains/mocks/service.go | 21 +++++++++++++++------ domains/service.go | 19 ++++++++++--------- domains/service_test.go | 2 +- domains/tracing/tracing.go | 3 ++- pkg/roles/roles.go | 4 ++-- 12 files changed, 55 insertions(+), 36 deletions(-) diff --git a/domains/api/http/endpoint.go b/domains/api/http/endpoint.go index 40fb2c1182b..b50fe898a6e 100644 --- a/domains/api/http/endpoint.go +++ b/domains/api/http/endpoint.go @@ -33,7 +33,7 @@ func createDomainEndpoint(svc domains.Service) endpoint.Endpoint { Tags: req.Tags, Alias: req.Alias, } - domain, err := svc.CreateDomain(ctx, session, d) + domain, _, err := svc.CreateDomain(ctx, session, d) if err != nil { return nil, err } diff --git a/domains/domains.go b/domains/domains.go index e9629f61375..a601608491d 100644 --- a/domains/domains.go +++ b/domains/domains.go @@ -162,7 +162,7 @@ func (page DomainsPage) MarshalJSON() ([]byte, error) { //go:generate mockery --name Service --output=./mocks --filename service.go --quiet --note "Copyright (c) Abstract Machines" type Service interface { - CreateDomain(ctx context.Context, sesssion authn.Session, d Domain) (Domain, error) + CreateDomain(ctx context.Context, sesssion authn.Session, d Domain) (Domain, []roles.RoleProvision, error) RetrieveDomain(ctx context.Context, sesssion authn.Session, id string) (Domain, error) UpdateDomain(ctx context.Context, sesssion authn.Session, id string, d DomainReq) (Domain, error) EnableDomain(ctx context.Context, sesssion authn.Session, id string) (Domain, error) diff --git a/domains/events/events.go b/domains/events/events.go index 1bf5dd84679..9cc707a6475 100644 --- a/domains/events/events.go +++ b/domains/events/events.go @@ -8,6 +8,7 @@ import ( "github.com/absmach/supermq/domains" "github.com/absmach/supermq/pkg/events" + "github.com/absmach/supermq/pkg/roles" ) const ( @@ -36,16 +37,18 @@ var ( type createDomainEvent struct { domains.Domain + rolesProvisioned []roles.RoleProvision } func (cde createDomainEvent) Encode() (map[string]interface{}, error) { val := map[string]interface{}{ - "operation": domainCreate, - "id": cde.ID, - "alias": cde.Alias, - "status": cde.Status.String(), - "created_at": cde.CreatedAt, - "created_by": cde.CreatedBy, + "operation": domainCreate, + "id": cde.ID, + "alias": cde.Alias, + "status": cde.Status.String(), + "created_at": cde.CreatedAt, + "created_by": cde.CreatedBy, + "roles_provisioned": cde.rolesProvisioned, } if cde.Name != "" { diff --git a/domains/events/streams.go b/domains/events/streams.go index aa54584bb6c..4175a13753d 100644 --- a/domains/events/streams.go +++ b/domains/events/streams.go @@ -10,6 +10,7 @@ import ( "github.com/absmach/supermq/pkg/authn" "github.com/absmach/supermq/pkg/events" "github.com/absmach/supermq/pkg/events/store" + "github.com/absmach/supermq/pkg/roles" rmEvents "github.com/absmach/supermq/pkg/roles/rolemanager/events" ) @@ -40,21 +41,22 @@ func NewEventStoreMiddleware(ctx context.Context, svc domains.Service, url strin }, nil } -func (es *eventStore) CreateDomain(ctx context.Context, session authn.Session, domain domains.Domain) (domains.Domain, error) { - domain, err := es.svc.CreateDomain(ctx, session, domain) +func (es *eventStore) CreateDomain(ctx context.Context, session authn.Session, domain domains.Domain) (domains.Domain, []roles.RoleProvision, error) { + domain, rps, err := es.svc.CreateDomain(ctx, session, domain) if err != nil { - return domain, err + return domain, rps, err } event := createDomainEvent{ - domain, + Domain: domain, + rolesProvisioned: rps, } if err := es.Publish(ctx, event); err != nil { - return domain, err + return domain, rps, err } - return domain, nil + return domain, rps, nil } func (es *eventStore) RetrieveDomain(ctx context.Context, session authn.Session, id string) (domains.Domain, error) { diff --git a/domains/middleware/authorization.go b/domains/middleware/authorization.go index af57da6a94f..0b8e9ef6368 100644 --- a/domains/middleware/authorization.go +++ b/domains/middleware/authorization.go @@ -11,6 +11,7 @@ import ( "github.com/absmach/supermq/pkg/authz" smqauthz "github.com/absmach/supermq/pkg/authz" "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/pkg/roles" rmMW "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" "github.com/absmach/supermq/pkg/svcutil" ) @@ -46,7 +47,7 @@ func AuthorizationMiddleware(entityType string, svc domains.Service, authz smqau }, nil } -func (am *authorizationMiddleware) CreateDomain(ctx context.Context, session authn.Session, d domains.Domain) (domains.Domain, error) { +func (am *authorizationMiddleware) CreateDomain(ctx context.Context, session authn.Session, d domains.Domain) (domains.Domain, []roles.RoleProvision, error) { return am.svc.CreateDomain(ctx, session, d) } diff --git a/domains/middleware/logging.go b/domains/middleware/logging.go index 9eb4e5a0044..ef6da0c03de 100644 --- a/domains/middleware/logging.go +++ b/domains/middleware/logging.go @@ -12,6 +12,7 @@ import ( "github.com/absmach/supermq/domains" "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/roles" rmMW "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" ) @@ -33,7 +34,7 @@ func LoggingMiddleware(svc domains.Service, logger *slog.Logger) domains.Service } } -func (lm *loggingMiddleware) CreateDomain(ctx context.Context, session authn.Session, d domains.Domain) (do domains.Domain, err error) { +func (lm *loggingMiddleware) CreateDomain(ctx context.Context, session authn.Session, d domains.Domain) (do domains.Domain, rps []roles.RoleProvision, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), diff --git a/domains/middleware/metrics.go b/domains/middleware/metrics.go index 4aff05498a3..6ed64318440 100644 --- a/domains/middleware/metrics.go +++ b/domains/middleware/metrics.go @@ -11,6 +11,7 @@ import ( "github.com/absmach/supermq/domains" "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/roles" rmMW "github.com/absmach/supermq/pkg/roles/rolemanager/middleware" "github.com/go-kit/kit/metrics" ) @@ -36,7 +37,7 @@ func MetricsMiddleware(svc domains.Service, counter metrics.Counter, latency met } } -func (ms *metricsMiddleware) CreateDomain(ctx context.Context, session authn.Session, d domains.Domain) (domains.Domain, error) { +func (ms *metricsMiddleware) CreateDomain(ctx context.Context, session authn.Session, d domains.Domain) (domains.Domain, []roles.RoleProvision, error) { defer func(begin time.Time) { ms.counter.With("method", "create_domain").Add(1) ms.latency.With("method", "create_domain").Observe(time.Since(begin).Seconds()) diff --git a/domains/mocks/service.go b/domains/mocks/service.go index 6c258eba23b..6025633c306 100644 --- a/domains/mocks/service.go +++ b/domains/mocks/service.go @@ -50,7 +50,7 @@ func (_m *Service) AddRole(ctx context.Context, session authn.Session, entityID } // CreateDomain provides a mock function with given fields: ctx, sesssion, d -func (_m *Service) CreateDomain(ctx context.Context, sesssion authn.Session, d domains.Domain) (domains.Domain, error) { +func (_m *Service) CreateDomain(ctx context.Context, sesssion authn.Session, d domains.Domain) (domains.Domain, []roles.RoleProvision, error) { ret := _m.Called(ctx, sesssion, d) if len(ret) == 0 { @@ -58,8 +58,9 @@ func (_m *Service) CreateDomain(ctx context.Context, sesssion authn.Session, d d } var r0 domains.Domain - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, authn.Session, domains.Domain) (domains.Domain, error)); ok { + var r1 []roles.RoleProvision + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, authn.Session, domains.Domain) (domains.Domain, []roles.RoleProvision, error)); ok { return rf(ctx, sesssion, d) } if rf, ok := ret.Get(0).(func(context.Context, authn.Session, domains.Domain) domains.Domain); ok { @@ -68,13 +69,21 @@ func (_m *Service) CreateDomain(ctx context.Context, sesssion authn.Session, d d r0 = ret.Get(0).(domains.Domain) } - if rf, ok := ret.Get(1).(func(context.Context, authn.Session, domains.Domain) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, authn.Session, domains.Domain) []roles.RoleProvision); ok { r1 = rf(ctx, sesssion, d) } else { - r1 = ret.Error(1) + if ret.Get(1) != nil { + r1 = ret.Get(1).([]roles.RoleProvision) + } } - return r0, r1 + if rf, ok := ret.Get(2).(func(context.Context, authn.Session, domains.Domain) error); ok { + r2 = rf(ctx, sesssion, d) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 } // DisableDomain provides a mock function with given fields: ctx, sesssion, id diff --git a/domains/service.go b/domains/service.go index 58d523bfdb2..1483f048e38 100644 --- a/domains/service.go +++ b/domains/service.go @@ -45,17 +45,17 @@ func New(repo Repository, cache Cache, policy policies.Service, idProvider super }, nil } -func (svc service) CreateDomain(ctx context.Context, session authn.Session, d Domain) (do Domain, err error) { +func (svc service) CreateDomain(ctx context.Context, session authn.Session, d Domain) (retDo Domain, retRps []roles.RoleProvision, retErr error) { d.CreatedBy = session.UserID domainID, err := svc.idProvider.ID() if err != nil { - return Domain{}, errors.Wrap(svcerr.ErrCreateEntity, err) + return Domain{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrCreateEntity, err) } d.ID = domainID if d.Status != DisabledStatus && d.Status != EnabledStatus { - return Domain{}, svcerr.ErrInvalidStatus + return Domain{}, []roles.RoleProvision{}, svcerr.ErrInvalidStatus } d.CreatedAt = time.Now() @@ -63,12 +63,12 @@ func (svc service) CreateDomain(ctx context.Context, session authn.Session, d Do // Domain is created in repo first, because Roles table have foreign key relation with Domain ID dom, err := svc.repo.Save(ctx, d) if err != nil { - return Domain{}, errors.Wrap(svcerr.ErrCreateEntity, err) + return Domain{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrCreateEntity, err) } defer func() { - if err != nil { + if retErr != nil { if errRollBack := svc.repo.Delete(ctx, domainID); errRollBack != nil { - err = errors.Wrap(err, errors.Wrap(errRollbackRepo, errRollBack)) + retErr = errors.Wrap(retErr, errors.Wrap(errRollbackRepo, errRollBack)) } } }() @@ -87,11 +87,12 @@ func (svc service) CreateDomain(ctx context.Context, session authn.Session, d Do }, } - if _, err := svc.AddNewEntitiesRoles(ctx, domainID, session.UserID, []string{domainID}, optionalPolicies, newBuiltInRoleMembers); err != nil { - return Domain{}, errors.Wrap(errCreateDomainPolicy, err) + rps, err := svc.AddNewEntitiesRoles(ctx, domainID, session.UserID, []string{domainID}, optionalPolicies, newBuiltInRoleMembers) + if err != nil { + return Domain{}, []roles.RoleProvision{}, errors.Wrap(errCreateDomainPolicy, err) } - return dom, nil + return dom, rps, nil } func (svc service) RetrieveDomain(ctx context.Context, session authn.Session, id string) (Domain, error) { diff --git a/domains/service_test.go b/domains/service_test.go index c1cab14e83f..34792b722ea 100644 --- a/domains/service_test.go +++ b/domains/service_test.go @@ -171,7 +171,7 @@ func TestCreateDomain(t *testing.T) { repoCall2 := drepo.On("AddRoles", mock.Anything, mock.Anything).Return([]roles.Role{}, tc.addRolesErr) policyCall := policy.On("AddPolicies", mock.Anything, mock.Anything).Return(tc.addPoliciesErr) policyCall1 := policy.On("DeletePolicies", mock.Anything, mock.Anything).Return(tc.deletePoliciesErr) - _, err := svc.CreateDomain(context.Background(), tc.session, tc.d) + _, _, err := svc.CreateDomain(context.Background(), tc.session, tc.d) assert.True(t, errors.Contains(err, tc.err)) repoCall.Unset() repoCall1.Unset() diff --git a/domains/tracing/tracing.go b/domains/tracing/tracing.go index 76616dda59b..40e675ea380 100644 --- a/domains/tracing/tracing.go +++ b/domains/tracing/tracing.go @@ -8,6 +8,7 @@ import ( "github.com/absmach/supermq/domains" "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/roles" rmTrace "github.com/absmach/supermq/pkg/roles/rolemanager/tracing" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -26,7 +27,7 @@ func New(svc domains.Service, tracer trace.Tracer) domains.Service { return &tracingMiddleware{tracer, svc, rmTrace.NewRoleManagerTracing("domain", svc, tracer)} } -func (tm *tracingMiddleware) CreateDomain(ctx context.Context, session authn.Session, d domains.Domain) (domains.Domain, error) { +func (tm *tracingMiddleware) CreateDomain(ctx context.Context, session authn.Session, d domains.Domain) (domains.Domain, []roles.RoleProvision, error) { ctx, span := tm.tracer.Start(ctx, "create_domain", trace.WithAttributes( attribute.String("name", d.Name), )) diff --git a/pkg/roles/roles.go b/pkg/roles/roles.go index a247ec6a684..54cc9261117 100644 --- a/pkg/roles/roles.go +++ b/pkg/roles/roles.go @@ -52,8 +52,8 @@ type Role struct { type RoleProvision struct { Role - OptionalActions []string `json:"-"` - OptionalMembers []string `json:"-"` + OptionalActions []string `json:"optional_actions"` + OptionalMembers []string `json:"optional_members"` } type RolePage struct {