From d035368a3647a7f700dcfd66df4c5bb8e8c9b838 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 8 Jan 2025 13:37:22 +0300 Subject: [PATCH] seperate pats middleware Signed-off-by: nyagamunene --- auth/service.go | 9 +- clients/middleware/authorization.go | 154 --------- clients/middleware/pat.go | 476 ++++++++++++++++++++++++++++ cmd/clients/main.go | 15 +- cmd/users/main.go | 15 +- domains/api/http/endpoint_test.go | 1 - pkg/authz/authsvc/authz.go | 38 +-- pkg/authz/authz.go | 22 +- pkg/authz/mocks/authz.go | 18 -- pkg/pat/pat.go | 72 +++++ pkg/sdk/channels_test.go | 2 - users/middleware/authorization.go | 210 ------------ users/middleware/pat.go | 210 ++++++++++++ 13 files changed, 815 insertions(+), 427 deletions(-) create mode 100644 clients/middleware/pat.go create mode 100644 pkg/pat/pat.go create mode 100644 users/middleware/pat.go diff --git a/auth/service.go b/auth/service.go index 579d8466fd..bf31a40850 100644 --- a/auth/service.go +++ b/auth/service.go @@ -514,13 +514,8 @@ func (svc service) UpdatePATDescription(ctx context.Context, token, patID, descr return pat, nil } -func (svc service) RetrievePAT(ctx context.Context, token, patID string) (PAT, error) { - key, err := svc.Identify(ctx, token) - if err != nil { - return PAT{}, err - } - - pat, err := svc.pats.Retrieve(ctx, key.User, patID) +func (svc service) RetrievePAT(ctx context.Context, userID, patID string) (PAT, error) { + pat, err := svc.pats.Retrieve(ctx, userID, patID) if err != nil { return PAT{}, errors.Wrap(errRetrievePAT, err) } diff --git a/clients/middleware/authorization.go b/clients/middleware/authorization.go index 3729d463c9..e9fdce595b 100644 --- a/clients/middleware/authorization.go +++ b/clients/middleware/authorization.go @@ -6,7 +6,6 @@ package middleware import ( "context" - "github.com/absmach/supermq/auth" "github.com/absmach/supermq/clients" "github.com/absmach/supermq/pkg/authn" smqauthz "github.com/absmach/supermq/pkg/authz" @@ -75,20 +74,6 @@ func AuthorizationMiddleware(entityType string, svc clients.Service, authz smqau } func (am *authorizationMiddleware) CreateClients(ctx context.Context, session authn.Session, client ...clients.Client) ([]clients.Client, []roles.RoleProvision, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.CreateOp, - EntityIDs: auth.AnyIDs{}.Values(), - }); err != nil { - return []clients.Client{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.extAuthorize(ctx, clients.DomainOpCreateClient, smqauthz.PolicyReq{ Domain: session.DomainID, SubjectType: policies.UserType, @@ -103,20 +88,6 @@ func (am *authorizationMiddleware) CreateClients(ctx context.Context, session au } func (am *authorizationMiddleware) View(ctx context.Context, session authn.Session, id string) (clients.Client, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.ReadOp, - EntityIDs: []string{id}, - }); err != nil { - return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.authorize(ctx, clients.OpViewClient, smqauthz.PolicyReq{ Domain: session.DomainID, SubjectType: policies.UserType, @@ -130,20 +101,6 @@ func (am *authorizationMiddleware) View(ctx context.Context, session authn.Sessi } func (am *authorizationMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pm clients.Page) (clients.ClientsPage, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.ListOp, - EntityIDs: auth.AnyIDs{}.Values(), - }); err != nil { - return clients.ClientsPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.checkSuperAdmin(ctx, session.UserID); err != nil { session.SuperAdmin = true } @@ -152,20 +109,6 @@ func (am *authorizationMiddleware) ListClients(ctx context.Context, session auth } func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{client.ID}, - }); err != nil { - return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.authorize(ctx, clients.OpUpdateClient, smqauthz.PolicyReq{ Domain: session.DomainID, SubjectType: policies.UserType, @@ -180,20 +123,6 @@ func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Ses } func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{client.ID}, - }); err != nil { - return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.authorize(ctx, clients.OpUpdateClientTags, smqauthz.PolicyReq{ Domain: session.DomainID, SubjectType: policies.UserType, @@ -208,20 +137,6 @@ func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn } func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session authn.Session, id, key string) (clients.Client, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, - }); err != nil { - return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.authorize(ctx, clients.OpUpdateClientSecret, smqauthz.PolicyReq{ Domain: session.DomainID, SubjectType: policies.UserType, @@ -235,20 +150,6 @@ func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session aut } func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Session, id string) (clients.Client, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, - }); err != nil { - return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.authorize(ctx, clients.OpEnableClient, smqauthz.PolicyReq{ Domain: session.DomainID, SubjectType: policies.UserType, @@ -263,20 +164,6 @@ func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Ses } func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Session, id string) (clients.Client, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, - }); err != nil { - return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.authorize(ctx, clients.OpDisableClient, smqauthz.PolicyReq{ Domain: session.DomainID, SubjectType: policies.UserType, @@ -290,19 +177,6 @@ func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Se } func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Session, id string) error { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainClientsScope, - Operation: auth.DeleteOp, - EntityIDs: []string{id}, - }); err != nil { - return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } if err := am.authorize(ctx, clients.OpDeleteClient, smqauthz.PolicyReq{ Domain: session.DomainID, SubjectType: policies.UserType, @@ -317,20 +191,6 @@ func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Ses } func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.UpdateOp, - EntityIDs: []string{id}, - }); err != nil { - return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.authorize(ctx, clients.OpSetParentGroup, smqauthz.PolicyReq{ Domain: session.DomainID, SubjectType: policies.UserType, @@ -354,20 +214,6 @@ func (am *authorizationMiddleware) SetParentGroup(ctx context.Context, session a } func (am *authorizationMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: auth.PlatformDomainsScope, - OptionalDomainID: session.DomainID, - OptionalDomainEntityType: auth.DomainGroupsScope, - Operation: auth.DeleteOp, - EntityIDs: []string{id}, - }); err != nil { - return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.authorize(ctx, clients.OpRemoveParentGroup, smqauthz.PolicyReq{ Domain: session.DomainID, SubjectType: policies.UserType, diff --git a/clients/middleware/pat.go b/clients/middleware/pat.go new file mode 100644 index 0000000000..c8dc0de560 --- /dev/null +++ b/clients/middleware/pat.go @@ -0,0 +1,476 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + smqauth "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/clients" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + smqpat "github.com/absmach/supermq/pkg/pat" + "github.com/absmach/supermq/pkg/roles" +) + +const emptyDomain = "" + +var _ clients.Service = (*patMiddleware)(nil) + +type patMiddleware struct { + svc clients.Service + pat smqpat.Authorization +} + +func PATMiddleware(svc clients.Service, pat smqpat.Authorization) clients.Service { + return &patMiddleware{ + svc: svc, + pat: pat, + } +} + +// authorizePAT validates Personal Access Token if present in the session +func (pm *patMiddleware) authorizePAT(ctx context.Context, session authn.Session, platformEntityType smqauth.PlatformEntityType, optionalDomainEntityType smqauth.DomainEntityType, OptionalDomainID string, operation smqauth.OperationType, entityIDs []string) error { + if session.Type != authn.PersonalAccessToken { + return nil + } + if session.ID == "" || session.UserID == "" { + return errors.Wrap(svcerr.ErrAuthentication, errors.New("invalid PAT credentials")) + } + + if err := pm.pat.AuthorizePAT(ctx, smqpat.PatReq{ + UserID: session.UserID, + PatID: session.ID, + PlatformEntityType: platformEntityType, + OptionalDomainEntityType: optionalDomainEntityType, + OptionalDomainID: OptionalDomainID, + Operation: operation, + EntityIDs: entityIDs, + }); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return nil +} + +func (pm *patMiddleware) CreateClients(ctx context.Context, session authn.Session, client ...clients.Client) ([]clients.Client, []roles.RoleProvision, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.CreateOp, + smqauth.AnyIDs{}.Values(), + ); err != nil { + return []clients.Client{}, []roles.RoleProvision{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.CreateClients(ctx, session, client...) +} + +func (pm *patMiddleware) View(ctx context.Context, session authn.Session, id string) (clients.Client, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.ReadOp, + []string{id}, + ); err != nil { + return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.View(ctx, session, id) +} + +func (pm *patMiddleware) ListClients(ctx context.Context, session authn.Session, reqUserID string, pg clients.Page) (clients.ClientsPage, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.ListOp, + smqauth.AnyIDs{}.Values(), + ); err != nil { + return clients.ClientsPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.ListClients(ctx, session, reqUserID, pg) +} + +func (pm *patMiddleware) Update(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.UpdateOp, + []string{client.ID}, + ); err != nil { + return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.Update(ctx, session, client) +} + +func (pm *patMiddleware) UpdateTags(ctx context.Context, session authn.Session, client clients.Client) (clients.Client, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.UpdateOp, + []string{client.ID}, + ); err != nil { + return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.UpdateTags(ctx, session, client) +} + +func (pm *patMiddleware) UpdateSecret(ctx context.Context, session authn.Session, id, key string) (clients.Client, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.UpdateOp, + []string{id}, + ); err != nil { + return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.UpdateSecret(ctx, session, id, key) +} + +func (pm *patMiddleware) Enable(ctx context.Context, session authn.Session, id string) (clients.Client, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.UpdateOp, + []string{id}, + ); err != nil { + return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.Enable(ctx, session, id) +} + +func (pm *patMiddleware) Disable(ctx context.Context, session authn.Session, id string) (clients.Client, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.UpdateOp, + []string{id}, + ); err != nil { + return clients.Client{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.Disable(ctx, session, id) +} + +func (pm *patMiddleware) Delete(ctx context.Context, session authn.Session, id string) error { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.DeleteOp, + []string{id}, + ); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.Delete(ctx, session, id) +} + +func (pm *patMiddleware) SetParentGroup(ctx context.Context, session authn.Session, parentGroupID string, id string) error { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainGroupsScope, + session.DomainID, + smqauth.UpdateOp, + []string{id}, + ); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.SetParentGroup(ctx, session, parentGroupID, id) +} + +func (pm *patMiddleware) RemoveParentGroup(ctx context.Context, session authn.Session, id string) error { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainGroupsScope, + session.DomainID, + smqauth.DeleteOp, + []string{id}, + ); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RemoveParentGroup(ctx, session, id) +} + +func (pm *patMiddleware) AddRole(ctx context.Context, session authn.Session, entityID string, roleName string, optionalActions []string, optionalMembers []string) (roles.RoleProvision, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.CreateOp, + []string{entityID}, + ); err != nil { + return roles.RoleProvision{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.AddRole(ctx, session, entityID, roleName, optionalActions, optionalMembers) +} + +func (pm *patMiddleware) ListAvailableActions(ctx context.Context, session authn.Session) ([]string, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.ListOp, + smqauth.AnyIDs{}.Values(), + ); err != nil { + return []string{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.ListAvailableActions(ctx, session) +} + +func (pm *patMiddleware) RemoveMemberFromAllRoles(ctx context.Context, session authn.Session, member string) error { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.DeleteOp, + smqauth.AnyIDs{}.Values(), + ); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RemoveMemberFromAllRoles(ctx, session, member) +} + +func (pm *patMiddleware) RemoveRole(ctx context.Context, session authn.Session, entityID, roleID string) error { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.DeleteOp, + []string{entityID}, + ); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RemoveRole(ctx, session, entityID, roleID) +} + +func (pm *patMiddleware) RetrieveAllRoles(ctx context.Context, session authn.Session, entityID string, limit, offset uint64) (roles.RolePage, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.ListOp, + []string{entityID}, + ); err != nil { + return roles.RolePage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RetrieveAllRoles(ctx, session, entityID, limit, offset) +} + +func (pm *patMiddleware) RetrieveRole(ctx context.Context, session authn.Session, entityID, roleID string) (roles.Role, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.ReadOp, + []string{entityID}, + ); err != nil { + return roles.Role{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RetrieveRole(ctx, session, entityID, roleID) +} + +func (pm *patMiddleware) RoleAddActions(ctx context.Context, session authn.Session, entityID, roleID string, actions []string) ([]string, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.UpdateOp, + []string{entityID}, + ); err != nil { + return []string{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RoleAddActions(ctx, session, entityID, roleID, actions) +} + +func (pm *patMiddleware) RoleListActions(ctx context.Context, session authn.Session, entityID, roleID string) ([]string, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.ReadOp, + []string{entityID}, + ); err != nil { + return []string{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RoleListActions(ctx, session, entityID, roleID) +} + +func (pm *patMiddleware) RoleCheckActionsExists(ctx context.Context, session authn.Session, entityID, roleID string, actions []string) (bool, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.ReadOp, + []string{entityID}, + ); err != nil { + return false, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RoleCheckActionsExists(ctx, session, entityID, roleID, actions) +} + +func (pm *patMiddleware) RoleRemoveActions(ctx context.Context, session authn.Session, entityID, roleID string, actions []string) error { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.DeleteOp, + []string{entityID}, + ); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RoleRemoveActions(ctx, session, entityID, roleID, actions) +} + +func (pm *patMiddleware) RoleRemoveAllActions(ctx context.Context, session authn.Session, entityID, roleID string) error { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.DeleteOp, + []string{entityID}, + ); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RoleRemoveAllActions(ctx, session, entityID, roleID) +} + +func (pm *patMiddleware) RoleAddMembers(ctx context.Context, session authn.Session, entityID, roleID string, members []string) ([]string, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.UpdateOp, + []string{entityID}, + ); err != nil { + return []string{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RoleAddMembers(ctx, session, entityID, roleID, members) +} + +func (pm *patMiddleware) RoleListMembers(ctx context.Context, session authn.Session, entityID, roleID string, limit, offset uint64) (roles.MembersPage, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.ReadOp, + []string{entityID}, + ); err != nil { + return roles.MembersPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RoleListMembers(ctx, session, entityID, roleID, limit, offset) +} + +func (pm *patMiddleware) RoleCheckMembersExists(ctx context.Context, session authn.Session, entityID, roleID string, members []string) (bool, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.ReadOp, + []string{entityID}, + ); err != nil { + return false, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RoleCheckMembersExists(ctx, session, entityID, roleID, members) +} + +func (pm *patMiddleware) RoleRemoveMembers(ctx context.Context, session authn.Session, entityID, roleID string, members []string) error { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.DeleteOp, + []string{entityID}, + ); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RoleRemoveMembers(ctx, session, entityID, roleID, members) +} + +func (pm *patMiddleware) RoleRemoveAllMembers(ctx context.Context, session authn.Session, entityID, roleID string) error { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.DeleteOp, + []string{entityID}, + ); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.RoleRemoveAllMembers(ctx, session, entityID, roleID) +} + +func (pm *patMiddleware) UpdateRoleName(ctx context.Context, session authn.Session, entityID, roleID, newRoleName string) (roles.Role, error) { + if err := pm.authorizePAT(ctx, + session, + smqauth.PlatformDomainsScope, + smqauth.DomainClientsScope, + session.DomainID, + smqauth.UpdateOp, + []string{entityID}, + ); err != nil { + return roles.Role{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return pm.svc.UpdateRoleName(ctx, session, entityID, roleID, newRoleName) +} diff --git a/cmd/clients/main.go b/cmd/clients/main.go index 94b9c402d3..47877d8d51 100644 --- a/cmd/clients/main.go +++ b/cmd/clients/main.go @@ -35,6 +35,7 @@ import ( domainsAuthz "github.com/absmach/supermq/pkg/domains/grpcclient" "github.com/absmach/supermq/pkg/grpcclient" jaegerclient "github.com/absmach/supermq/pkg/jaeger" + authsvcPat "github.com/absmach/supermq/pkg/pat" "github.com/absmach/supermq/pkg/policies" "github.com/absmach/supermq/pkg/policies/spicedb" pg "github.com/absmach/supermq/pkg/postgres" @@ -204,6 +205,15 @@ func main() { defer authzClient.Close() logger.Info("AuthZ successfully connected to auth gRPC server " + authnClient.Secure()) + pat, patHandler, err := authsvcPat.NewAuthorization(ctx, grpcCfg) + if err != nil { + logger.Error("failed to create authz " + err.Error()) + exitCode = 1 + return + } + defer patHandler.Close() + logger.Info("PAT successfully connected to auth gRPC server " + patHandler.Secure()) + chgrpccfg := grpcclient.Config{} if err := env.ParseWithOptions(&chgrpccfg, env.Options{Prefix: envPrefixChannels}); err != nil { logger.Error(fmt.Sprintf("failed to load channels gRPC client configuration : %s", err)) @@ -234,7 +244,7 @@ func main() { defer groupsHandler.Close() logger.Info("Groups gRPC client successfully connected to groups gRPC server " + groupsHandler.Secure()) - svc, psvc, err := newService(ctx, db, dbConfig, authz, policyEvaluator, policyService, cacheclient, cfg.CacheKeyDuration, cfg.ESURL, channelsgRPC, groupsClient, tracer, logger) + svc, psvc, err := newService(ctx, db, dbConfig, authz, pat, policyEvaluator, policyService, cacheclient, cfg.CacheKeyDuration, cfg.ESURL, channelsgRPC, groupsClient, tracer, logger) if err != nil { logger.Error(fmt.Sprintf("failed to create services: %s", err)) exitCode = 1 @@ -286,7 +296,7 @@ func main() { } } -func newService(ctx context.Context, db *sqlx.DB, dbConfig pgclient.Config, authz smqauthz.Authorization, pe policies.Evaluator, ps policies.Service, cacheClient *redis.Client, keyDuration time.Duration, esURL string, channels grpcChannelsV1.ChannelsServiceClient, groups grpcGroupsV1.GroupsServiceClient, tracer trace.Tracer, logger *slog.Logger) (clients.Service, pClients.Service, error) { +func newService(ctx context.Context, db *sqlx.DB, dbConfig pgclient.Config, authz smqauthz.Authorization, pat authsvcPat.Authorization, pe policies.Evaluator, ps policies.Service, cacheClient *redis.Client, keyDuration time.Duration, esURL string, channels grpcChannelsV1.ChannelsServiceClient, groups grpcGroupsV1.GroupsServiceClient, tracer trace.Tracer, logger *slog.Logger) (clients.Service, pClients.Service, error) { database := pg.NewDatabase(db, dbConfig, tracer) repo := postgres.NewRepository(database) @@ -315,6 +325,7 @@ func newService(ctx context.Context, db *sqlx.DB, dbConfig pgclient.Config, auth csvc = middleware.MetricsMiddleware(csvc, counter, latency) csvc = middleware.MetricsMiddleware(csvc, counter, latency) + csvc = middleware.PATMiddleware(csvc, pat) csvc, err = middleware.AuthorizationMiddleware(policies.ClientType, csvc, authz, repo, clients.NewOperationPermissionMap(), clients.NewRolesOperationPermissionMap(), clients.NewExternalOperationPermissionMap()) if err != nil { return nil, nil, err diff --git a/cmd/users/main.go b/cmd/users/main.go index 275a477358..009bf06747 100644 --- a/cmd/users/main.go +++ b/cmd/users/main.go @@ -28,6 +28,7 @@ import ( jaegerclient "github.com/absmach/supermq/pkg/jaeger" "github.com/absmach/supermq/pkg/oauth2" googleoauth "github.com/absmach/supermq/pkg/oauth2/google" + authsvcPat "github.com/absmach/supermq/pkg/pat" "github.com/absmach/supermq/pkg/policies" "github.com/absmach/supermq/pkg/policies/spicedb" pg "github.com/absmach/supermq/pkg/postgres" @@ -205,6 +206,15 @@ func main() { defer authzHandler.Close() logger.Info("AuthZ successfully connected to auth gRPC server " + authzHandler.Secure()) + pat, patHandler, err := authsvcPat.NewAuthorization(ctx, authClientConfig) + if err != nil { + logger.Error("failed to create authz " + err.Error()) + exitCode = 1 + return + } + defer patHandler.Close() + logger.Info("PAT successfully connected to auth gRPC server " + patHandler.Secure()) + policyService, err := newPolicyService(cfg, logger) if err != nil { logger.Error("failed to create new policies service " + err.Error()) @@ -213,7 +223,7 @@ func main() { } logger.Info("Policy client successfully connected to spicedb gRPC server") - csvc, err := newService(ctx, authz, tokenClient, policyService, domainsClient, db, dbConfig, tracer, cfg, ec, logger) + csvc, err := newService(ctx, authz, pat, tokenClient, policyService, domainsClient, db, dbConfig, tracer, cfg, ec, logger) if err != nil { logger.Error(fmt.Sprintf("failed to setup service: %s", err)) exitCode = 1 @@ -256,7 +266,7 @@ func main() { } } -func newService(ctx context.Context, authz smqauthz.Authorization, token grpcTokenV1.TokenServiceClient, policyService policies.Service, domainsClient grpcDomainsV1.DomainsServiceClient, db *sqlx.DB, dbConfig pgclient.Config, tracer trace.Tracer, c config, ec email.Config, logger *slog.Logger) (users.Service, error) { +func newService(ctx context.Context, authz smqauthz.Authorization, pat authsvcPat.Authorization, token grpcTokenV1.TokenServiceClient, policyService policies.Service, domainsClient grpcDomainsV1.DomainsServiceClient, db *sqlx.DB, dbConfig pgclient.Config, tracer trace.Tracer, c config, ec email.Config, logger *slog.Logger) (users.Service, error) { database := pg.NewDatabase(db, dbConfig, tracer) idp := uuid.New() hsr := hasher.New() @@ -274,6 +284,7 @@ func newService(ctx context.Context, authz smqauthz.Authorization, token grpcTok if err != nil { return nil, err } + svc = middleware.PATMiddleware(svc, pat) svc = middleware.AuthorizationMiddleware(svc, authz, c.SelfRegister) svc = tracing.New(svc, tracer) diff --git a/domains/api/http/endpoint_test.go b/domains/api/http/endpoint_test.go index 703acb2e74..51a9b31adb 100644 --- a/domains/api/http/endpoint_test.go +++ b/domains/api/http/endpoint_test.go @@ -816,7 +816,6 @@ func TestUpdateDomain(t *testing.T) { contentType: tc.contentType, token: tc.token, } - fmt.Println("req url", req.url) if tc.token == validToken { tc.session = authn.Session{UserID: userID, DomainID: tc.domainID, DomainUserID: tc.domainID + "_" + userID} diff --git a/pkg/authz/authsvc/authz.go b/pkg/authz/authsvc/authz.go index f0d1de592b..8412a1ca55 100644 --- a/pkg/authz/authsvc/authz.go +++ b/pkg/authz/authsvc/authz.go @@ -122,22 +122,22 @@ func (a authorization) checkDomain(ctx context.Context, subjectType, subject, do } } -func (a authorization) AuthorizePAT(ctx context.Context, pr authz.PatReq) error { - req := grpcAuthV1.AuthZPatReq{ - UserId: pr.UserID, - PatId: pr.PatID, - PlatformEntityType: uint32(pr.PlatformEntityType), - OptionalDomainId: pr.OptionalDomainID, - OptionalDomainEntityType: uint32(pr.OptionalDomainEntityType), - Operation: uint32(pr.Operation), - EntityIds: pr.EntityIDs, - } - res, err := a.authSvcClient.AuthorizePAT(ctx, &req) - if err != nil { - return errors.Wrap(errors.ErrAuthorization, err) - } - if !res.Authorized { - return errors.ErrAuthorization - } - return nil -} +// func (a authorization) AuthorizePAT(ctx context.Context, pr authz.PatReq) error { +// req := grpcAuthV1.AuthZPatReq{ +// UserId: pr.UserID, +// PatId: pr.PatID, +// PlatformEntityType: uint32(pr.PlatformEntityType), +// OptionalDomainId: pr.OptionalDomainID, +// OptionalDomainEntityType: uint32(pr.OptionalDomainEntityType), +// Operation: uint32(pr.Operation), +// EntityIds: pr.EntityIDs, +// } +// res, err := a.authSvcClient.AuthorizePAT(ctx, &req) +// if err != nil { +// return errors.Wrap(errors.ErrAuthorization, err) +// } +// if !res.Authorized { +// return errors.ErrAuthorization +// } +// return nil +// } diff --git a/pkg/authz/authz.go b/pkg/authz/authz.go index 93d807ac5e..4796333476 100644 --- a/pkg/authz/authz.go +++ b/pkg/authz/authz.go @@ -5,8 +5,6 @@ package authz import ( "context" - - "github.com/absmach/supermq/auth" ) type PolicyReq struct { @@ -46,20 +44,20 @@ type PolicyReq struct { Permission string `json:"permission,omitempty"` } -type PatReq struct { - UserID string `json:"user_id,omitempty"` // UserID - PatID string `json:"pat_id,omitempty"` // UserID - PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"` // Platform entity type - OptionalDomainID string `json:"optional_domainID,omitempty"` // Optional domain id - OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` // Optional domain entity type - Operation auth.OperationType `json:"operation,omitempty"` // Operation - EntityIDs []string `json:"entityIDs,omitempty"` // EntityIDs -} +// type PatReq struct { +// UserID string `json:"user_id,omitempty"` // UserID +// PatID string `json:"pat_id,omitempty"` // UserID +// PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"` // Platform entity type +// OptionalDomainID string `json:"optional_domainID,omitempty"` // Optional domain id +// OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` // Optional domain entity type +// Operation auth.OperationType `json:"operation,omitempty"` // Operation +// EntityIDs []string `json:"entityIDs,omitempty"` // EntityIDs +// } // Authz is supermq authorization library. // //go:generate mockery --name Authorization --output=./mocks --filename authz.go --quiet --note "Copyright (c) Abstract Machines" type Authorization interface { Authorize(ctx context.Context, pr PolicyReq) error - AuthorizePAT(ctx context.Context, pr PatReq) error + // AuthorizePAT(ctx context.Context, pr PatReq) error } diff --git a/pkg/authz/mocks/authz.go b/pkg/authz/mocks/authz.go index 69a75517bd..f3153680c9 100644 --- a/pkg/authz/mocks/authz.go +++ b/pkg/authz/mocks/authz.go @@ -35,24 +35,6 @@ func (_m *Authorization) Authorize(ctx context.Context, pr authz.PolicyReq) erro return r0 } -// AuthorizePAT provides a mock function with given fields: ctx, pr -func (_m *Authorization) AuthorizePAT(ctx context.Context, pr authz.PatReq) error { - ret := _m.Called(ctx, pr) - - if len(ret) == 0 { - panic("no return value specified for AuthorizePAT") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, authz.PatReq) error); ok { - r0 = rf(ctx, pr) - } else { - r0 = ret.Error(0) - } - - return r0 -} - // NewAuthorization creates a new instance of Authorization. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewAuthorization(t interface { diff --git a/pkg/pat/pat.go b/pkg/pat/pat.go new file mode 100644 index 0000000000..847650a353 --- /dev/null +++ b/pkg/pat/pat.go @@ -0,0 +1,72 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package pat + +import ( + "context" + + grpcAuthV1 "github.com/absmach/supermq/api/grpc/auth/v1" + smqauth "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/auth/api/grpc/auth" + "github.com/absmach/supermq/pkg/errors" + "github.com/absmach/supermq/pkg/grpcclient" + grpchealth "google.golang.org/grpc/health/grpc_health_v1" +) + +type PatReq struct { + UserID string `json:"user_id,omitempty"` // UserID + PatID string `json:"pat_id,omitempty"` // UserID + PlatformEntityType smqauth.PlatformEntityType `json:"platform_entity_type,omitempty"` // Platform entity type + OptionalDomainID string `json:"optional_domainID,omitempty"` // Optional domain id + OptionalDomainEntityType smqauth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` // Optional domain entity type + Operation smqauth.OperationType `json:"operation,omitempty"` // Operation + EntityIDs []string `json:"entityIDs,omitempty"` // EntityIDs +} + +type Authorization interface { + AuthorizePAT(ctx context.Context, pr PatReq) error +} + +type authorization struct { + authSvcClient grpcAuthV1.AuthServiceClient +} + +func NewAuthorization(ctx context.Context, cfg grpcclient.Config) (Authorization, grpcclient.Handler, error) { + client, err := grpcclient.NewHandler(cfg) + if err != nil { + return nil, nil, err + } + + health := grpchealth.NewHealthClient(client.Connection()) + resp, err := health.Check(ctx, &grpchealth.HealthCheckRequest{ + Service: "auth", + }) + if err != nil || resp.GetStatus() != grpchealth.HealthCheckResponse_SERVING { + return nil, nil, grpcclient.ErrSvcNotServing + } + authSvcClient := auth.NewAuthClient(client.Connection(), cfg.Timeout) + return authorization{ + authSvcClient: authSvcClient, + }, client, nil +} + +func (a authorization) AuthorizePAT(ctx context.Context, pr PatReq) error { + req := grpcAuthV1.AuthZPatReq{ + UserId: pr.UserID, + PatId: pr.PatID, + PlatformEntityType: uint32(pr.PlatformEntityType), + OptionalDomainId: pr.OptionalDomainID, + OptionalDomainEntityType: uint32(pr.OptionalDomainEntityType), + Operation: uint32(pr.Operation), + EntityIds: pr.EntityIDs, + } + res, err := a.authSvcClient.AuthorizePAT(ctx, &req) + if err != nil { + return errors.Wrap(errors.ErrAuthorization, err) + } + if !res.Authorized { + return errors.ErrAuthorization + } + return nil +} diff --git a/pkg/sdk/channels_test.go b/pkg/sdk/channels_test.go index d6f6bf8b71..4cad6ef9d1 100644 --- a/pkg/sdk/channels_test.go +++ b/pkg/sdk/channels_test.go @@ -1335,7 +1335,6 @@ func TestDisableChannel(t *testing.T) { authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := gsvc.On("DisableChannel", mock.Anything, tc.session, tc.channelID).Return(tc.svcRes, tc.svcErr) resp, err := mgsdk.DisableChannel(tc.channelID, tc.domainID, tc.token) - fmt.Println(resp) assert.Equal(t, tc.err, err) assert.Equal(t, tc.response, resp) if tc.err == nil { @@ -1535,7 +1534,6 @@ func TestConnect(t *testing.T) { authCall := auth.On("Authenticate", mock.Anything, tc.token).Return(tc.session, tc.authenticateErr) svcCall := gsvc.On("Connect", mock.Anything, tc.session, tc.connection.ChannelIDs, tc.connection.ClientIDs, connTypes).Return(tc.svcErr) err := mgsdk.Connect(tc.connection, tc.domainID, tc.token) - fmt.Println(err) assert.Equal(t, tc.err, err) if tc.err == nil { ok := svcCall.Parent.AssertCalled(t, "Connect", mock.Anything, tc.session, tc.connection.ChannelIDs, tc.connection.ClientIDs, connTypes) diff --git a/users/middleware/authorization.go b/users/middleware/authorization.go index b38e622f03..cba9d2d750 100644 --- a/users/middleware/authorization.go +++ b/users/middleware/authorization.go @@ -10,7 +10,6 @@ import ( smqauth "github.com/absmach/supermq/auth" "github.com/absmach/supermq/pkg/authn" smqauthz "github.com/absmach/supermq/pkg/authz" - "github.com/absmach/supermq/pkg/errors" svcerr "github.com/absmach/supermq/pkg/errors/service" "github.com/absmach/supermq/pkg/policies" "github.com/absmach/supermq/users" @@ -44,19 +43,6 @@ func (am *authorizationMiddleware) Register(ctx context.Context, session authn.S } func (am *authorizationMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.ReadOp, - EntityIDs: []string{id}, - }); err != nil { - return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } @@ -65,34 +51,10 @@ func (am *authorizationMiddleware) View(ctx context.Context, session authn.Sessi } func (am *authorizationMiddleware) ViewProfile(ctx context.Context, session authn.Session) (users.User, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.ReadOp, - EntityIDs: []string{session.UserID}, - }); err != nil { - return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } return am.svc.ViewProfile(ctx, session) } func (am *authorizationMiddleware) ListUsers(ctx context.Context, session authn.Session, pm users.Page) (users.UsersPage, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.ListOp, - EntityIDs: smqauth.AnyIDs{}.Values(), - }); err != nil { - return users.UsersPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } @@ -101,49 +63,6 @@ func (am *authorizationMiddleware) ListUsers(ctx context.Context, session authn. } func (am *authorizationMiddleware) ListMembers(ctx context.Context, session authn.Session, objectKind, objectID string, pm users.Page) (users.MembersPage, error) { - if session.Type == authn.PersonalAccessToken { - switch objectKind { - case policies.GroupsKind: - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - OptionalDomainID: session.DomainID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainGroupsScope, - Operation: smqauth.ListOp, - EntityIDs: smqauth.AnyIDs{}.Values(), - }); err != nil { - return users.MembersPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - case policies.DomainsKind: - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - OptionalDomainID: session.DomainID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainManagementScope, - Operation: smqauth.ListOp, - EntityIDs: smqauth.AnyIDs{}.Values(), - }); err != nil { - return users.MembersPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - case policies.ClientsKind: - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - OptionalDomainID: session.DomainID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainClientsScope, - Operation: smqauth.ListOp, - EntityIDs: smqauth.AnyIDs{}.Values(), - }); err != nil { - return users.MembersPage{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - default: - return users.MembersPage{}, svcerr.ErrAuthorization - } - } - if session.DomainUserID == "" { return users.MembersPage{}, svcerr.ErrDomainAuthorization } @@ -172,19 +91,6 @@ func (am *authorizationMiddleware) SearchUsers(ctx context.Context, pm users.Pag } func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Session, user users.User) (users.User, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{user.ID}, - }); err != nil { - return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } @@ -193,19 +99,6 @@ func (am *authorizationMiddleware) Update(ctx context.Context, session authn.Ses } func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn.Session, user users.User) (users.User, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{user.ID}, - }); err != nil { - return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } @@ -214,18 +107,6 @@ func (am *authorizationMiddleware) UpdateTags(ctx context.Context, session authn } func (am *authorizationMiddleware) UpdateEmail(ctx context.Context, session authn.Session, id, email string) (users.User, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{id}, - }); err != nil { - return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } @@ -234,19 +115,6 @@ func (am *authorizationMiddleware) UpdateEmail(ctx context.Context, session auth } func (am *authorizationMiddleware) UpdateUsername(ctx context.Context, session authn.Session, id, username string) (users.User, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{id}, - }); err != nil { - return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } @@ -255,19 +123,6 @@ func (am *authorizationMiddleware) UpdateUsername(ctx context.Context, session a } func (am *authorizationMiddleware) UpdateProfilePicture(ctx context.Context, session authn.Session, user users.User) (users.User, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{user.ID}, - }); err != nil { - return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } @@ -280,19 +135,6 @@ func (am *authorizationMiddleware) GenerateResetToken(ctx context.Context, email } func (am *authorizationMiddleware) UpdateSecret(ctx context.Context, session authn.Session, oldSecret, newSecret string) (users.User, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{session.UserID}, - }); err != nil { - return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - return am.svc.UpdateSecret(ctx, session, oldSecret, newSecret) } @@ -305,19 +147,6 @@ func (am *authorizationMiddleware) SendPasswordReset(ctx context.Context, host, } func (am *authorizationMiddleware) UpdateRole(ctx context.Context, session authn.Session, user users.User) (users.User, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{user.ID}, - }); err != nil { - return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.checkSuperAdmin(ctx, session.UserID); err != nil { return users.User{}, err } @@ -330,19 +159,6 @@ func (am *authorizationMiddleware) UpdateRole(ctx context.Context, session authn } func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Session, id string) (users.User, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{id}, - }); err != nil { - return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } @@ -351,19 +167,6 @@ func (am *authorizationMiddleware) Enable(ctx context.Context, session authn.Ses } func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Session, id string) (users.User, error) { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.UpdateOp, - EntityIDs: []string{id}, - }); err != nil { - return users.User{}, errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } @@ -372,19 +175,6 @@ func (am *authorizationMiddleware) Disable(ctx context.Context, session authn.Se } func (am *authorizationMiddleware) Delete(ctx context.Context, session authn.Session, id string) error { - if session.Type == authn.PersonalAccessToken { - if err := am.authz.AuthorizePAT(ctx, smqauthz.PatReq{ - UserID: session.UserID, - PatID: session.ID, - PlatformEntityType: smqauth.PlatformUsersScope, - OptionalDomainEntityType: smqauth.DomainNullScope, - Operation: smqauth.DeleteOp, - EntityIDs: []string{id}, - }); err != nil { - return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) - } - } - if err := am.checkSuperAdmin(ctx, session.UserID); err == nil { session.SuperAdmin = true } diff --git a/users/middleware/pat.go b/users/middleware/pat.go new file mode 100644 index 0000000000..6a770a41d8 --- /dev/null +++ b/users/middleware/pat.go @@ -0,0 +1,210 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + grpcTokenV1 "github.com/absmach/supermq/api/grpc/token/v1" + smqauth "github.com/absmach/supermq/auth" + "github.com/absmach/supermq/pkg/authn" + "github.com/absmach/supermq/pkg/errors" + svcerr "github.com/absmach/supermq/pkg/errors/service" + smqpat "github.com/absmach/supermq/pkg/pat" + "github.com/absmach/supermq/pkg/policies" + "github.com/absmach/supermq/users" +) + +const emptyDomain = "" + +var _ users.Service = (*patMiddleware)(nil) + +type patMiddleware struct { + svc users.Service + pat smqpat.Authorization +} + +// PATMiddleware adds PAT validation to the users service. +func PATMiddleware(svc users.Service, pat smqpat.Authorization) users.Service { + return &patMiddleware{ + svc: svc, + pat: pat, + } +} + +// authorizePAT validates Personal Access Token if present in the session +func (pm *patMiddleware) authorizePAT(ctx context.Context, session authn.Session, platformEntityType smqauth.PlatformEntityType, optionalDomainEntityType smqauth.DomainEntityType, OptionalDomainID string, operation smqauth.OperationType, entityIDs []string) error { + if session.Type != authn.PersonalAccessToken { + return nil + } + if session.ID == "" || session.UserID == "" { + return errors.Wrap(svcerr.ErrAuthentication, errors.New("invalid PAT credentials")) + } + + if err := pm.pat.AuthorizePAT(ctx, smqpat.PatReq{ + UserID: session.UserID, + PatID: session.ID, + PlatformEntityType: platformEntityType, + OptionalDomainEntityType: optionalDomainEntityType, + OptionalDomainID: OptionalDomainID, + Operation: operation, + EntityIDs: entityIDs, + }); err != nil { + return errors.Wrap(svcerr.ErrUnauthorizedPAT, err) + } + + return nil +} + +func (pm *patMiddleware) Register(ctx context.Context, session authn.Session, user users.User, selfRegister bool) (users.User, error) { + return pm.svc.Register(ctx, session, user, selfRegister) +} + +func (pm *patMiddleware) View(ctx context.Context, session authn.Session, id string) (users.User, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.ReadOp, []string{id}); err != nil { + return users.User{}, err + } + return pm.svc.View(ctx, session, id) +} + +func (pm *patMiddleware) ViewProfile(ctx context.Context, session authn.Session) (users.User, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.ReadOp, []string{session.UserID}); err != nil { + return users.User{}, err + } + return pm.svc.ViewProfile(ctx, session) +} + +func (pm *patMiddleware) ListUsers(ctx context.Context, session authn.Session, page users.Page) (users.UsersPage, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.ListOp, smqauth.AnyIDs{}.Values()); err != nil { + return users.UsersPage{}, err + } + return pm.svc.ListUsers(ctx, session, page) +} + +func (pm *patMiddleware) ListMembers(ctx context.Context, session authn.Session, objectKind, objectID string, page users.Page) (users.MembersPage, error) { + switch objectKind { + case policies.GroupsKind: + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainGroupsScope, session.DomainID, smqauth.ListOp, smqauth.AnyIDs{}.Values()); err != nil { + return users.MembersPage{}, err + } + case policies.DomainsKind: + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainManagementScope, session.DomainID, smqauth.ListOp, smqauth.AnyIDs{}.Values()); err != nil { + return users.MembersPage{}, err + } + case policies.ClientsKind: + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainClientsScope, session.DomainID, smqauth.ListOp, smqauth.AnyIDs{}.Values()); err != nil { + return users.MembersPage{}, err + } + default: + return users.MembersPage{}, svcerr.ErrAuthorization + } + + return pm.svc.ListMembers(ctx, session, objectKind, objectID, page) +} + +func (pm *patMiddleware) SearchUsers(ctx context.Context, page users.Page) (users.UsersPage, error) { + return pm.svc.SearchUsers(ctx, page) +} + +func (pm *patMiddleware) Update(ctx context.Context, session authn.Session, user users.User) (users.User, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.UpdateOp, []string{user.ID}); err != nil { + return users.User{}, err + } + return pm.svc.Update(ctx, session, user) +} + +func (pm *patMiddleware) UpdateTags(ctx context.Context, session authn.Session, user users.User) (users.User, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.UpdateOp, []string{user.ID}); err != nil { + return users.User{}, err + } + return pm.svc.UpdateTags(ctx, session, user) +} + +func (pm *patMiddleware) UpdateEmail(ctx context.Context, session authn.Session, id, email string) (users.User, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.UpdateOp, []string{id}); err != nil { + return users.User{}, err + } + return pm.svc.UpdateEmail(ctx, session, id, email) +} + +func (pm *patMiddleware) UpdateUsername(ctx context.Context, session authn.Session, id, username string) (users.User, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.UpdateOp, []string{id}); err != nil { + return users.User{}, err + } + return pm.svc.UpdateUsername(ctx, session, id, username) +} + +func (pm *patMiddleware) UpdateProfilePicture(ctx context.Context, session authn.Session, user users.User) (users.User, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.UpdateOp, []string{user.ID}); err != nil { + return users.User{}, err + } + return pm.svc.UpdateProfilePicture(ctx, session, user) +} + +func (pm *patMiddleware) GenerateResetToken(ctx context.Context, email, host string) error { + return pm.svc.GenerateResetToken(ctx, email, host) +} + +func (pm *patMiddleware) UpdateSecret(ctx context.Context, session authn.Session, oldSecret, newSecret string) (users.User, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.UpdateOp, []string{session.UserID}); err != nil { + return users.User{}, err + } + return pm.svc.UpdateSecret(ctx, session, oldSecret, newSecret) +} + +func (pm *patMiddleware) ResetSecret(ctx context.Context, session authn.Session, secret string) error { + return pm.svc.ResetSecret(ctx, session, secret) +} + +func (pm *patMiddleware) SendPasswordReset(ctx context.Context, host, email, user, token string) error { + return pm.svc.SendPasswordReset(ctx, host, email, user, token) +} + +func (pm *patMiddleware) UpdateRole(ctx context.Context, session authn.Session, user users.User) (users.User, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.UpdateOp, []string{user.ID}); err != nil { + return users.User{}, err + } + return pm.svc.UpdateRole(ctx, session, user) +} + +func (pm *patMiddleware) Enable(ctx context.Context, session authn.Session, id string) (users.User, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.UpdateOp, []string{id}); err != nil { + return users.User{}, err + } + return pm.svc.Enable(ctx, session, id) +} + +func (pm *patMiddleware) Disable(ctx context.Context, session authn.Session, id string) (users.User, error) { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.UpdateOp, []string{id}); err != nil { + return users.User{}, err + } + return pm.svc.Disable(ctx, session, id) +} + +func (pm *patMiddleware) Delete(ctx context.Context, session authn.Session, id string) error { + if err := pm.authorizePAT(ctx, session, smqauth.PlatformUsersScope, smqauth.DomainNullScope, emptyDomain, smqauth.DeleteOp, []string{id}); err != nil { + return err + } + return pm.svc.Delete(ctx, session, id) +} + +func (pm *patMiddleware) Identify(ctx context.Context, session authn.Session) (string, error) { + return pm.svc.Identify(ctx, session) +} + +func (pm *patMiddleware) IssueToken(ctx context.Context, username, secret string) (*grpcTokenV1.Token, error) { + return pm.svc.IssueToken(ctx, username, secret) +} + +func (pm *patMiddleware) RefreshToken(ctx context.Context, session authn.Session, refreshToken string) (*grpcTokenV1.Token, error) { + return pm.svc.RefreshToken(ctx, session, refreshToken) +} + +func (pm *patMiddleware) OAuthCallback(ctx context.Context, user users.User) (users.User, error) { + return pm.svc.OAuthCallback(ctx, user) +} + +func (pm *patMiddleware) OAuthAddUserPolicy(ctx context.Context, user users.User) error { + return pm.svc.OAuthAddUserPolicy(ctx, user) +}