diff --git a/auth/api/http/keys/endpoint_test.go b/auth/api/http/keys/endpoint_test.go index f200f9f5ee..64f5e872c9 100644 --- a/auth/api/http/keys/endpoint_test.go +++ b/auth/api/http/keys/endpoint_test.go @@ -70,11 +70,13 @@ func newService() (auth.Service, *mocks.KeyRepository) { krepo := new(mocks.KeyRepository) prepo := new(mocks.PolicyAgent) drepo := new(mocks.DomainsRepository) + patsRepo := new(mocks.PATSRepository) + hasher := new(mocks.Hasher) idProvider := uuid.NewMock() t := jwt.New([]byte(secret)) - return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration), krepo + return auth.New(krepo, drepo, patsRepo, hasher, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration), krepo } func newServer(svc auth.Service) *httptest.Server { diff --git a/auth/api/http/pats/endpoint.go b/auth/api/http/pats/endpoint.go new file mode 100644 index 0000000000..9e83bf1581 --- /dev/null +++ b/auth/api/http/pats/endpoint.go @@ -0,0 +1,202 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package pats + +import ( + "context" + + "github.com/absmach/magistrala/auth" + "github.com/go-kit/kit/endpoint" +) + +func createPATEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(createPatReq) + if err := req.validate(); err != nil { + return nil, err + } + + pat, err := svc.CreatePAT(ctx, req.token, req.Name, req.Description, req.Duration, req.Scope) + if err != nil { + return nil, err + } + + return createPatRes{pat}, nil + } +} + +func retrievePATEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(retrievePatReq) + if err := req.validate(); err != nil { + return nil, err + } + + pat, err := svc.RetrievePAT(ctx, req.token, req.id) + if err != nil { + return nil, err + } + + return retrievePatRes{pat}, nil + } +} + +func updatePATNameEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(updatePatNameReq) + if err := req.validate(); err != nil { + return nil, err + } + + pat, err := svc.UpdatePATName(ctx, req.token, req.id, req.Name) + if err != nil { + return nil, err + } + + return updatePatNameRes{pat}, nil + } +} + +func updatePATDescriptionEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(updatePatDescriptionReq) + if err := req.validate(); err != nil { + return nil, err + } + + pat, err := svc.UpdatePATDescription(ctx, req.token, req.id, req.Description) + if err != nil { + return nil, err + } + + return updatePatDescriptionRes{pat}, nil + } +} + +func listPATSEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(listPatsReq) + if err := req.validate(); err != nil { + return nil, err + } + + pm := auth.PATSPageMeta{ + Limit: req.limit, + Offset: req.offset, + } + patsPage, err := svc.ListPATS(ctx, req.token, pm) + if err != nil { + return nil, err + } + + return listPatsRes{patsPage}, nil + } +} + +func deletePATEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(deletePatReq) + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.DeletePAT(ctx, req.token, req.id); err != nil { + return nil, err + } + + return deletePatRes{}, nil + } +} + +func resetPATSecretEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(resetPatSecretReq) + if err := req.validate(); err != nil { + return nil, err + } + + pat, err := svc.ResetPATSecret(ctx, req.token, req.id, req.Duration) + if err != nil { + return nil, err + } + + return resetPatSecretRes{pat}, nil + } +} + +func revokePATSecretEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(revokePatSecretReq) + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.RevokePATSecret(ctx, req.token, req.id); err != nil { + return nil, err + } + + return revokePatSecretRes{}, nil + } +} + +func addPATScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(addPatScopeEntryReq) + if err := req.validate(); err != nil { + return nil, err + } + + scope, err := svc.AddPATScopeEntry(ctx, req.token, req.id, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...) + if err != nil { + return nil, err + } + + return addPatScopeEntryRes{scope}, nil + } +} + +func removePATScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(removePatScopeEntryReq) + if err := req.validate(); err != nil { + return nil, err + } + + scope, err := svc.RemovePATScopeEntry(ctx, req.token, req.id, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...) + if err != nil { + return nil, err + } + return removePatScopeEntryRes{scope}, nil + } +} + +func clearPATAllScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(clearAllScopeEntryReq) + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.ClearPATAllScopeEntry(ctx, req.token, req.id); err != nil { + return nil, err + } + + return clearAllScopeEntryRes{}, nil + } +} + +func authorizePATEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(authorizePATReq) + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.AuthorizePAT(ctx, req.token, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...); err != nil { + return nil, err + } + + return authorizePATRes{}, nil + } +} diff --git a/auth/api/http/pats/requests.go b/auth/api/http/pats/requests.go new file mode 100644 index 0000000000..61584021fd --- /dev/null +++ b/auth/api/http/pats/requests.go @@ -0,0 +1,361 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package pats + +import ( + "encoding/json" + "strings" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/pkg/apiutil" +) + +type createPatReq struct { + token string + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Duration time.Duration `json:"duration,omitempty"` + Scope auth.Scope `json:"scope,omitempty"` +} + +func (cpr *createPatReq) UnmarshalJSON(data []byte) error { + var temp struct { + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Duration string `json:"duration,omitempty"` + Scope auth.Scope `json:"scope,omitempty"` + } + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + duration, err := time.ParseDuration(temp.Duration) + if err != nil { + return err + } + cpr.Name = temp.Name + cpr.Description = temp.Description + cpr.Duration = duration + cpr.Scope = temp.Scope + return nil +} + +func (req createPatReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + + if strings.TrimSpace(req.Name) == "" { + return apiutil.ErrMissingName + } + + return nil +} + +type retrievePatReq struct { + token string + id string +} + +func (req retrievePatReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type updatePatNameReq struct { + token string + id string + Name string `json:"name,omitempty"` +} + +func (req updatePatNameReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + if strings.TrimSpace(req.Name) == "" { + return apiutil.ErrMissingName + } + return nil +} + +type updatePatDescriptionReq struct { + token string + id string + Description string `json:"description,omitempty"` +} + +func (req updatePatDescriptionReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + if strings.TrimSpace(req.Description) == "" { + return apiutil.ErrMissingDescription + } + return nil +} + +type listPatsReq struct { + token string + offset uint64 + limit uint64 +} + +func (req listPatsReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + return nil +} + +type deletePatReq struct { + token string + id string +} + +func (req deletePatReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type resetPatSecretReq struct { + token string + id string + Duration time.Duration `json:"duration,omitempty"` +} + +func (rspr *resetPatSecretReq) UnmarshalJSON(data []byte) error { + var temp struct { + Duration string `json:"duration,omitempty"` + } + + err := json.Unmarshal(data, &temp) + if err != nil { + return err + } + rspr.Duration, err = time.ParseDuration(temp.Duration) + if err != nil { + return err + } + return nil +} + +func (req resetPatSecretReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type revokePatSecretReq struct { + token string + id string +} + +func (req revokePatSecretReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type addPatScopeEntryReq struct { + token string + id string + PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"` + OptionalDomainID string `json:"optional_domain_id,omitempty"` + OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` + Operation auth.OperationType `json:"operation,omitempty"` + EntityIDs []string `json:"entity_ids,omitempty"` +} + +func (apser *addPatScopeEntryReq) UnmarshalJSON(data []byte) error { + var temp struct { + PlatformEntityType string `json:"platform_entity_type,omitempty"` + OptionalDomainID string `json:"optional_domain_id,omitempty"` + OptionalDomainEntityType string `json:"optional_domain_entity_type,omitempty"` + Operation string `json:"operation,omitempty"` + EntityIDs []string `json:"entity_ids,omitempty"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + pet, err := auth.ParsePlatformEntityType(temp.PlatformEntityType) + if err != nil { + return err + } + odt, err := auth.ParseDomainEntityType(temp.OptionalDomainEntityType) + if err != nil { + return err + } + op, err := auth.ParseOperationType(temp.Operation) + if err != nil { + return err + } + apser.PlatformEntityType = pet + apser.OptionalDomainID = temp.OptionalDomainID + apser.OptionalDomainEntityType = odt + apser.Operation = op + apser.EntityIDs = temp.EntityIDs + return nil +} + +func (req addPatScopeEntryReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type removePatScopeEntryReq struct { + token string + id string + PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"` + OptionalDomainID string `json:"optional_domain_id,omitempty"` + OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` + Operation auth.OperationType `json:"operation,omitempty"` + EntityIDs []string `json:"entity_ids,omitempty"` +} + +func (rpser *removePatScopeEntryReq) UnmarshalJSON(data []byte) error { + var temp struct { + PlatformEntityType string `json:"platform_entity_type,omitempty"` + OptionalDomainID string `json:"optional_domain_id,omitempty"` + OptionalDomainEntityType string `json:"optional_domain_entity_type,omitempty"` + Operation string `json:"operation,omitempty"` + EntityIDs []string `json:"entity_ids,omitempty"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + pet, err := auth.ParsePlatformEntityType(temp.PlatformEntityType) + if err != nil { + return err + } + odt, err := auth.ParseDomainEntityType(temp.OptionalDomainEntityType) + if err != nil { + return err + } + op, err := auth.ParseOperationType(temp.Operation) + if err != nil { + return err + } + rpser.PlatformEntityType = pet + rpser.OptionalDomainID = temp.OptionalDomainID + rpser.OptionalDomainEntityType = odt + rpser.Operation = op + rpser.EntityIDs = temp.EntityIDs + return nil +} + +func (req removePatScopeEntryReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type clearAllScopeEntryReq struct { + token string + id string +} + +func (req clearAllScopeEntryReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + if req.id == "" { + return apiutil.ErrMissingID + } + return nil +} + +type authorizePATReq struct { + token string + PlatformEntityType auth.PlatformEntityType `json:"platform_entity_type,omitempty"` + OptionalDomainID string `json:"optional_domain_id,omitempty"` + OptionalDomainEntityType auth.DomainEntityType `json:"optional_domain_entity_type,omitempty"` + Operation auth.OperationType `json:"operation,omitempty"` + EntityIDs []string `json:"entity_ids,omitempty"` +} + +func (tcpsr *authorizePATReq) UnmarshalJSON(data []byte) error { + var temp struct { + PlatformEntityType string `json:"platform_entity_type,omitempty"` + OptionalDomainID string `json:"optional_domain_id,omitempty"` + OptionalDomainEntityType string `json:"optional_domain_entity_type,omitempty"` + Operation string `json:"operation,omitempty"` + EntityIDs []string `json:"entity_ids,omitempty"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + tcpsr.OptionalDomainID = temp.OptionalDomainID + tcpsr.EntityIDs = temp.EntityIDs + + pet, err := auth.ParsePlatformEntityType(temp.PlatformEntityType) + if err != nil { + return err + } + tcpsr.PlatformEntityType = pet + + if temp.OptionalDomainEntityType != "" { + odt, err := auth.ParseDomainEntityType(temp.OptionalDomainEntityType) + if err != nil { + return err + } + tcpsr.OptionalDomainEntityType = odt + } + + if temp.OptionalDomainID != "" { + op, err := auth.ParseOperationType(temp.Operation) + if err != nil { + return err + } + tcpsr.Operation = op + } + + return nil +} + +func (req authorizePATReq) validate() (err error) { + if req.token == "" { + return apiutil.ErrBearerToken + } + + return nil +} diff --git a/auth/api/http/pats/responses.go b/auth/api/http/pats/responses.go new file mode 100644 index 0000000000..01a18722d2 --- /dev/null +++ b/auth/api/http/pats/responses.go @@ -0,0 +1,208 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package pats + +import ( + "net/http" + + "github.com/absmach/magistrala" + "github.com/absmach/magistrala/auth" +) + +var ( + _ magistrala.Response = (*createPatRes)(nil) + _ magistrala.Response = (*retrievePatRes)(nil) + _ magistrala.Response = (*updatePatNameRes)(nil) + _ magistrala.Response = (*updatePatDescriptionRes)(nil) + _ magistrala.Response = (*deletePatRes)(nil) + _ magistrala.Response = (*resetPatSecretRes)(nil) + _ magistrala.Response = (*revokePatSecretRes)(nil) + _ magistrala.Response = (*addPatScopeEntryRes)(nil) + _ magistrala.Response = (*removePatScopeEntryRes)(nil) + _ magistrala.Response = (*clearAllScopeEntryRes)(nil) +) + +type createPatRes struct { + auth.PAT +} + +func (res createPatRes) Code() int { + return http.StatusCreated +} + +func (res createPatRes) Headers() map[string]string { + return map[string]string{} +} + +func (res createPatRes) Empty() bool { + return false +} + +type retrievePatRes struct { + auth.PAT +} + +func (res retrievePatRes) Code() int { + return http.StatusOK +} + +func (res retrievePatRes) Headers() map[string]string { + return map[string]string{} +} + +func (res retrievePatRes) Empty() bool { + return false +} + +type updatePatNameRes struct { + auth.PAT +} + +func (res updatePatNameRes) Code() int { + return http.StatusAccepted +} + +func (res updatePatNameRes) Headers() map[string]string { + return map[string]string{} +} + +func (res updatePatNameRes) Empty() bool { + return false +} + +type updatePatDescriptionRes struct { + auth.PAT +} + +func (res updatePatDescriptionRes) Code() int { + return http.StatusAccepted +} + +func (res updatePatDescriptionRes) Headers() map[string]string { + return map[string]string{} +} + +func (res updatePatDescriptionRes) Empty() bool { + return false +} + +type listPatsRes struct { + auth.PATSPage +} + +func (res listPatsRes) Code() int { + return http.StatusOK +} + +func (res listPatsRes) Headers() map[string]string { + return map[string]string{} +} + +func (res listPatsRes) Empty() bool { + return false +} + +type deletePatRes struct{} + +func (res deletePatRes) Code() int { + return http.StatusNoContent +} + +func (res deletePatRes) Headers() map[string]string { + return map[string]string{} +} + +func (res deletePatRes) Empty() bool { + return true +} + +type resetPatSecretRes struct { + auth.PAT +} + +func (res resetPatSecretRes) Code() int { + return http.StatusOK +} + +func (res resetPatSecretRes) Headers() map[string]string { + return map[string]string{} +} + +func (res resetPatSecretRes) Empty() bool { + return false +} + +type revokePatSecretRes struct{} + +func (res revokePatSecretRes) Code() int { + return http.StatusNoContent +} + +func (res revokePatSecretRes) Headers() map[string]string { + return map[string]string{} +} + +func (res revokePatSecretRes) Empty() bool { + return true +} + +type addPatScopeEntryRes struct { + auth.Scope +} + +func (res addPatScopeEntryRes) Code() int { + return http.StatusAccepted +} + +func (res addPatScopeEntryRes) Headers() map[string]string { + return map[string]string{} +} + +func (res addPatScopeEntryRes) Empty() bool { + return false +} + +type removePatScopeEntryRes struct { + auth.Scope +} + +func (res removePatScopeEntryRes) Code() int { + return http.StatusAccepted +} + +func (res removePatScopeEntryRes) Headers() map[string]string { + return map[string]string{} +} + +func (res removePatScopeEntryRes) Empty() bool { + return false +} + +type clearAllScopeEntryRes struct{} + +func (res clearAllScopeEntryRes) Code() int { + return http.StatusOK +} + +func (res clearAllScopeEntryRes) Headers() map[string]string { + return map[string]string{} +} + +func (res clearAllScopeEntryRes) Empty() bool { + return true +} + +type authorizePATRes struct{} + +func (res authorizePATRes) Code() int { + return http.StatusNoContent +} + +func (res authorizePATRes) Headers() map[string]string { + return map[string]string{} +} + +func (res authorizePATRes) Empty() bool { + return true +} diff --git a/auth/api/http/pats/transport.go b/auth/api/http/pats/transport.go new file mode 100644 index 0000000000..0c7dda9cc3 --- /dev/null +++ b/auth/api/http/pats/transport.go @@ -0,0 +1,266 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package pats + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "strings" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/internal/api" + "github.com/absmach/magistrala/pkg/apiutil" + "github.com/absmach/magistrala/pkg/errors" + "github.com/go-chi/chi/v5" + kithttp "github.com/go-kit/kit/transport/http" +) + +const ( + contentType = "application/json" + defInterval = "30d" +) + +// MakeHandler returns a HTTP handler for API endpoints. +func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux { + opts := []kithttp.ServerOption{ + kithttp.ServerErrorEncoder(apiutil.LoggingErrorEncoder(logger, api.EncodeError)), + } + mux.Route("/pats", func(r chi.Router) { + r.Post("/", kithttp.NewServer( + createPATEndpoint(svc), + decodeCreatePATRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Get("/{id}", kithttp.NewServer( + (retrievePATEndpoint(svc)), + decodeRetrievePATRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Put("/{id}/name", kithttp.NewServer( + (updatePATNameEndpoint(svc)), + decodeUpdatePATNameRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Put("/{id}/description", kithttp.NewServer( + (updatePATDescriptionEndpoint(svc)), + decodeUpdatePATDescriptionRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Get("/", kithttp.NewServer( + (listPATSEndpoint(svc)), + decodeListPATSRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Delete("/{id}", kithttp.NewServer( + (deletePATEndpoint(svc)), + decodeDeletePATRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Put("/{id}/secret/reset", kithttp.NewServer( + (resetPATSecretEndpoint(svc)), + decodeResetPATSecretRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Put("/{id}/secret/revoke", kithttp.NewServer( + (revokePATSecretEndpoint(svc)), + decodeRevokePATSecretRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Put("/{id}/scope/add", kithttp.NewServer( + (addPATScopeEntryEndpoint(svc)), + decodeAddPATScopeEntryRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Put("/{id}/scope/remove", kithttp.NewServer( + (removePATScopeEntryEndpoint(svc)), + decodeRemovePATScopeEntryRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Delete("/{id}/scope", kithttp.NewServer( + (clearPATAllScopeEntryEndpoint(svc)), + decodeClearPATAllScopeEntryRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + + r.Post("/authorize", kithttp.NewServer( + (authorizePATEndpoint(svc)), + decodeAuthorizePATRequest, + api.EncodeResponse, + opts..., + ).ServeHTTP) + }) + return mux +} + +func decodeCreatePATRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := createPatReq{token: apiutil.ExtractBearerToken(r)} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, errors.Wrap(err, errors.ErrMalformedEntity)) + } + return req, nil +} + +func decodeRetrievePATRequest(_ context.Context, r *http.Request) (interface{}, error) { + req := retrievePatReq{ + token: apiutil.ExtractBearerToken(r), + id: chi.URLParam(r, "id"), + } + return req, nil +} + +func decodeUpdatePATNameRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := updatePatNameReq{ + token: apiutil.ExtractBearerToken(r), + id: chi.URLParam(r, "id"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + return req, nil +} + +func decodeUpdatePATDescriptionRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := updatePatDescriptionReq{ + token: apiutil.ExtractBearerToken(r), + id: chi.URLParam(r, "id"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + return req, nil +} + +func decodeListPATSRequest(_ context.Context, r *http.Request) (interface{}, error) { + l, err := apiutil.ReadNumQuery[uint64](r, api.LimitKey, api.DefLimit) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + o, err := apiutil.ReadNumQuery[uint64](r, api.OffsetKey, api.DefOffset) + if err != nil { + return nil, errors.Wrap(apiutil.ErrValidation, err) + } + req := listPatsReq{ + token: apiutil.ExtractBearerToken(r), + limit: l, + offset: o, + } + return req, nil +} + +func decodeDeletePATRequest(_ context.Context, r *http.Request) (interface{}, error) { + return deletePatReq{ + token: apiutil.ExtractBearerToken(r), + id: chi.URLParam(r, "id"), + }, nil +} + +func decodeResetPATSecretRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := resetPatSecretReq{ + token: apiutil.ExtractBearerToken(r), + id: chi.URLParam(r, "id"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + return req, nil +} + +func decodeRevokePATSecretRequest(_ context.Context, r *http.Request) (interface{}, error) { + return revokePatSecretReq{ + token: apiutil.ExtractBearerToken(r), + id: chi.URLParam(r, "id"), + }, nil +} + +func decodeAddPATScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := addPatScopeEntryReq{ + token: apiutil.ExtractBearerToken(r), + id: chi.URLParam(r, "id"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + return req, nil +} + +func decodeRemovePATScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := removePatScopeEntryReq{ + token: apiutil.ExtractBearerToken(r), + id: chi.URLParam(r, "id"), + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + return req, nil +} + +func decodeClearPATAllScopeEntryRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + return clearAllScopeEntryReq{ + token: apiutil.ExtractBearerToken(r), + id: chi.URLParam(r, "id"), + }, nil +} + +func decodeAuthorizePATRequest(_ context.Context, r *http.Request) (interface{}, error) { + if !strings.Contains(r.Header.Get("Content-Type"), contentType) { + return nil, apiutil.ErrUnsupportedContentType + } + + req := authorizePATReq{token: apiutil.ExtractBearerToken(r)} + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + return nil, errors.Wrap(errors.ErrMalformedEntity, err) + } + return req, nil +} diff --git a/auth/api/http/transport.go b/auth/api/http/transport.go index 5e31ee553f..d3d8306871 100644 --- a/auth/api/http/transport.go +++ b/auth/api/http/transport.go @@ -10,6 +10,7 @@ import ( "github.com/absmach/magistrala/auth" "github.com/absmach/magistrala/auth/api/http/domains" "github.com/absmach/magistrala/auth/api/http/keys" + "github.com/absmach/magistrala/auth/api/http/pats" "github.com/go-chi/chi/v5" "github.com/prometheus/client_golang/prometheus/promhttp" ) @@ -20,6 +21,7 @@ func MakeHandler(svc auth.Service, logger *slog.Logger, instanceID string) http. mux = keys.MakeHandler(svc, mux, logger) mux = domains.MakeHandler(svc, mux, logger) + mux = pats.MakeHandler(svc, mux, logger) mux.Get("/health", magistrala.Health("auth", instanceID)) mux.Handle("/metrics", promhttp.Handler()) diff --git a/auth/api/logging.go b/auth/api/logging.go index 7240af2d27..046bc778b7 100644 --- a/auth/api/logging.go +++ b/auth/api/logging.go @@ -507,3 +507,253 @@ func (lm *loggingMiddleware) DeleteEntityPolicies(ctx context.Context, entityTyp }(time.Now()) return lm.svc.DeleteEntityPolicies(ctx, entityType, id) } + +func (lm *loggingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("name", name), + slog.String("description", description), + slog.String("pat_duration", duration.String()), + slog.String("scope", scope.String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Create PAT failed to complete successfully", args...) + return + } + lm.logger.Info("Create PAT completed successfully", args...) + }(time.Now()) + return lm.svc.CreatePAT(ctx, token, name, description, duration, scope) +} + +func (lm *loggingMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + slog.String("name", name), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Update PAT name failed to complete successfully", args...) + return + } + lm.logger.Info("Update PAT name completed successfully", args...) + }(time.Now()) + return lm.svc.UpdatePATName(ctx, token, patID, name) +} + +func (lm *loggingMiddleware) UpdatePATDescription(ctx context.Context, token, patID, description string) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + slog.String("description", description), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Update PAT description failed to complete successfully", args...) + return + } + lm.logger.Info("Update PAT description completed successfully", args...) + }(time.Now()) + return lm.svc.UpdatePATDescription(ctx, token, patID, description) +} + +func (lm *loggingMiddleware) RetrievePAT(ctx context.Context, token, patID string) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Retrieve PAT failed to complete successfully", args...) + return + } + lm.logger.Info("Retrieve PAT completed successfully", args...) + }(time.Now()) + return lm.svc.RetrievePAT(ctx, token, patID) +} + +func (lm *loggingMiddleware) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (pp auth.PATSPage, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.Uint64("limit", pm.Limit), + slog.Uint64("offset", pm.Offset), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("List PATS failed to complete successfully", args...) + return + } + lm.logger.Info("List PATS completed successfully", args...) + }(time.Now()) + return lm.svc.ListPATS(ctx, token, pm) +} + +func (lm *loggingMiddleware) DeletePAT(ctx context.Context, token, patID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Delete PAT failed to complete successfully", args...) + return + } + lm.logger.Info("Delete PAT completed successfully", args...) + }(time.Now()) + return lm.svc.DeletePAT(ctx, token, patID) +} + +func (lm *loggingMiddleware) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + slog.String("pat_duration", duration.String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Reset PAT secret failed to complete successfully", args...) + return + } + lm.logger.Info("Reset PAT secret completed successfully", args...) + }(time.Now()) + return lm.svc.ResetPATSecret(ctx, token, patID, duration) +} + +func (lm *loggingMiddleware) RevokePATSecret(ctx context.Context, token, patID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Revoke PAT secret failed to complete successfully", args...) + return + } + lm.logger.Info("Revoke PAT secret completed successfully", args...) + }(time.Now()) + return lm.svc.RevokePATSecret(ctx, token, patID) +} + +func (lm *loggingMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (sc auth.Scope, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + slog.String("platform_entity_type", platformEntityType.String()), + slog.String("optional_domain_id", optionalDomainID), + slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), + slog.String("operation", operation.String()), + slog.Any("entities", entityIDs), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Add entry to PAT scope failed to complete successfully", args...) + return + } + lm.logger.Info("Add entry to PAT scope completed successfully", args...) + }(time.Now()) + return lm.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (lm *loggingMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (sc auth.Scope, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + slog.String("platform_entity_type", platformEntityType.String()), + slog.String("optional_domain_id", optionalDomainID), + slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), + slog.String("operation", operation.String()), + slog.Any("entities", entityIDs), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Remove entry from PAT scope failed to complete successfully", args...) + return + } + lm.logger.Info("Remove entry from PAT scope completed successfully", args...) + }(time.Now()) + return lm.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (lm *loggingMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("pat_id", patID), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Clear all entry from PAT scope failed to complete successfully", args...) + return + } + lm.logger.Info("Clear all entry from PAT scope completed successfully", args...) + }(time.Now()) + return lm.svc.ClearPATAllScopeEntry(ctx, token, patID) +} + +func (lm *loggingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (pa auth.PAT, err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Identify PAT failed to complete successfully", args...) + return + } + lm.logger.Info("Identify PAT completed successfully", args...) + }(time.Now()) + return lm.svc.IdentifyPAT(ctx, paToken) +} + +func (lm *loggingMiddleware) AuthorizePAT(ctx context.Context, paToken string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("platform_entity_type", platformEntityType.String()), + slog.String("optional_domain_id", optionalDomainID), + slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), + slog.String("operation", operation.String()), + slog.Any("entities", entityIDs), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Authorize PAT failed complete successfully", args...) + return + } + lm.logger.Info("Authorize PAT completed successfully", args...) + }(time.Now()) + return lm.svc.AuthorizePAT(ctx, paToken, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (lm *loggingMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + slog.String("user_id", userID), + slog.String("pat_id", patID), + slog.String("platform_entity_type", platformEntityType.String()), + slog.String("optional_domain_id", optionalDomainID), + slog.String("optional_domain_entity_type", optionalDomainEntityType.String()), + slog.String("operation", operation.String()), + slog.Any("entities", entityIDs), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Check PAT failed complete successfully", args...) + return + } + lm.logger.Info("Check PAT completed successfully", args...) + }(time.Now()) + return lm.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} diff --git a/auth/api/metrics.go b/auth/api/metrics.go index 8ed201a82d..3808efc096 100644 --- a/auth/api/metrics.go +++ b/auth/api/metrics.go @@ -247,3 +247,115 @@ func (ms *metricsMiddleware) DeleteEntityPolicies(ctx context.Context, entityTyp }(time.Now()) return ms.svc.DeleteEntityPolicies(ctx, entityType, id) } + +func (ms *metricsMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "create_pat").Add(1) + ms.latency.With("method", "create_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.CreatePAT(ctx, token, name, description, duration, scope) +} + +func (ms *metricsMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "update_pat_name").Add(1) + ms.latency.With("method", "update_pat_name").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.UpdatePATName(ctx, token, patID, name) +} + +func (ms *metricsMiddleware) UpdatePATDescription(ctx context.Context, token, patID, description string) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "update_pat_description").Add(1) + ms.latency.With("method", "update_pat_description").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.UpdatePATDescription(ctx, token, patID, description) +} + +func (ms *metricsMiddleware) RetrievePAT(ctx context.Context, token, patID string) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "retrieve_pat").Add(1) + ms.latency.With("method", "retrieve_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.RetrievePAT(ctx, token, patID) +} + +func (ms *metricsMiddleware) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + defer func(begin time.Time) { + ms.counter.With("method", "list_pats").Add(1) + ms.latency.With("method", "list_pats").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.ListPATS(ctx, token, pm) +} + +func (ms *metricsMiddleware) DeletePAT(ctx context.Context, token, patID string) error { + defer func(begin time.Time) { + ms.counter.With("method", "delete_pat").Add(1) + ms.latency.With("method", "delete_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.DeletePAT(ctx, token, patID) +} + +func (ms *metricsMiddleware) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "reset_pat_secret").Add(1) + ms.latency.With("method", "reset_pat_secret").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.ResetPATSecret(ctx, token, patID, duration) +} + +func (ms *metricsMiddleware) RevokePATSecret(ctx context.Context, token, patID string) error { + defer func(begin time.Time) { + ms.counter.With("method", "revoke_pat_secret").Add(1) + ms.latency.With("method", "revoke_pat_secret").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.RevokePATSecret(ctx, token, patID) +} + +func (ms *metricsMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + defer func(begin time.Time) { + ms.counter.With("method", "add_pat_scope_entry").Add(1) + ms.latency.With("method", "add_pat_scope_entry").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (ms *metricsMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + defer func(begin time.Time) { + ms.counter.With("method", "remove_pat_scope_entry").Add(1) + ms.latency.With("method", "remove_pat_scope_entry").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (ms *metricsMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error { + defer func(begin time.Time) { + ms.counter.With("method", "clear_pat_all_scope_entry").Add(1) + ms.latency.With("method", "clear_pat_all_scope_entry").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.ClearPATAllScopeEntry(ctx, token, patID) +} + +func (ms *metricsMiddleware) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) { + defer func(begin time.Time) { + ms.counter.With("method", "identify_pat").Add(1) + ms.latency.With("method", "identify_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.IdentifyPAT(ctx, paToken) +} + +func (ms *metricsMiddleware) AuthorizePAT(ctx context.Context, paToken string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + defer func(begin time.Time) { + ms.counter.With("method", "authorize_pat").Add(1) + ms.latency.With("method", "authorize_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.AuthorizePAT(ctx, paToken, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (ms *metricsMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + defer func(begin time.Time) { + ms.counter.With("method", "check_pat").Add(1) + ms.latency.With("method", "check_pat").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return ms.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} diff --git a/auth/bolt/doc.go b/auth/bolt/doc.go new file mode 100644 index 0000000000..dcd06ac566 --- /dev/null +++ b/auth/bolt/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package bolt contains PAT repository implementations using +// bolt as the underlying database. +package bolt diff --git a/auth/bolt/init.go b/auth/bolt/init.go new file mode 100644 index 0000000000..9d496e65ca --- /dev/null +++ b/auth/bolt/init.go @@ -0,0 +1,21 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package bolt contains PAT repository implementations using +// bolt as the underlying database. +package bolt + +import ( + "github.com/absmach/magistrala/pkg/errors" + bolt "go.etcd.io/bbolt" +) + +var errInit = errors.New("failed to initialize BoltDB") + +func Init(tx *bolt.Tx, bucket string) error { + _, err := tx.CreateBucketIfNotExists([]byte(bucket)) + if err != nil { + return errors.Wrap(errInit, err) + } + return nil +} diff --git a/auth/bolt/pat.go b/auth/bolt/pat.go new file mode 100644 index 0000000000..4534dc4e85 --- /dev/null +++ b/auth/bolt/pat.go @@ -0,0 +1,773 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bolt + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "strings" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + bolt "go.etcd.io/bbolt" +) + +const ( + idKey = "id" + userKey = "user" + nameKey = "name" + descriptionKey = "description" + secretKey = "secret_key" + scopeKey = "scope" + issuedAtKey = "issued_at" + expiresAtKey = "expires_at" + updatedAtKey = "updated_at" + lastUsedAtKey = "last_used_at" + revokedKey = "revoked" + revokedAtKey = "revoked_at" + platformEntitiesKey = "platform_entities" + patKey = "pat" + + keySeparator = ":" + anyID = "*" +) + +var ( + activateValue = []byte{0x00} + revokedValue = []byte{0x01} + entityValue = []byte{0x02} + anyIDValue = []byte{0x03} + selectedIDsValue = []byte{0x04} +) + +type patRepo struct { + db *bolt.DB + bucketName string +} + +// NewPATSRepository instantiates a bolt +// implementation of PAT repository. +func NewPATSRepository(db *bolt.DB, bucketName string) auth.PATSRepository { + return &patRepo{ + db: db, + bucketName: bucketName, + } +} + +func (pr *patRepo) Save(ctx context.Context, pat auth.PAT) error { + idxKey := []byte(pat.User + keySeparator + patKey + keySeparator + pat.ID) + kv, err := patToKeyValue(pat) + if err != nil { + return err + } + return pr.db.Update(func(tx *bolt.Tx) error { + rootBucket, err := pr.retrieveRootBucket(tx) + if err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + b, err := pr.createUserBucket(rootBucket, pat.User) + if err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + for key, value := range kv { + fullKey := []byte(pat.ID + keySeparator + key) + if err := b.Put(fullKey, value); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + } + if err := rootBucket.Put(idxKey, []byte(pat.ID)); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + return nil + }) +} + +func (pr *patRepo) Retrieve(ctx context.Context, userID, patID string) (auth.PAT, error) { + prefix := []byte(patID + keySeparator) + kv := map[string][]byte{} + if err := pr.db.View(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) + if err != nil { + return err + } + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + kv[string(k)] = v + } + return nil + }); err != nil { + return auth.PAT{}, err + } + + return keyValueToPAT(kv) +} + +func (pr *patRepo) RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, error) { + revoked := true + keySecret := patID + keySeparator + secretKey + keyRevoked := patID + keySeparator + revokedKey + var secretHash string + if err := pr.db.View(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) + if err != nil { + return err + } + secretHash = string(b.Get([]byte(keySecret))) + revoked = bytesToBoolean(b.Get([]byte(keyRevoked))) + return nil + }); err != nil { + return "", true, err + } + return secretHash, revoked, nil +} + +func (pr *patRepo) UpdateName(ctx context.Context, userID, patID, name string) (auth.PAT, error) { + return pr.updatePATField(ctx, userID, patID, nameKey, []byte(name)) +} + +func (pr *patRepo) UpdateDescription(ctx context.Context, userID, patID, description string) (auth.PAT, error) { + return pr.updatePATField(ctx, userID, patID, descriptionKey, []byte(description)) +} + +func (pr *patRepo) UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (auth.PAT, error) { + prefix := []byte(patID + keySeparator) + kv := map[string][]byte{} + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) + if err != nil { + return err + } + if err := b.Put([]byte(patID+keySeparator+secretKey), []byte(tokenHash)); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + if err := b.Put([]byte(patID+keySeparator+expiresAtKey), timeToBytes(expiryAt)); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + kv[string(k)] = v + } + return nil + }); err != nil { + return auth.PAT{}, err + } + return keyValueToPAT(kv) +} + +func (pr *patRepo) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + prefix := []byte(userID + keySeparator + patKey + keySeparator) + + patIDs := []string{} + if err := pr.db.View(func(tx *bolt.Tx) error { + b, err := pr.retrieveRootBucket(tx) + if err != nil { + return errors.Wrap(repoerr.ErrViewEntity, err) + } + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + if v != nil { + patIDs = append(patIDs, string(v)) + } + } + return nil + }); err != nil { + return auth.PATSPage{}, err + } + + total := len(patIDs) + + var pats []auth.PAT + + patsPage := auth.PATSPage{ + Total: uint64(total), + Limit: pm.Limit, + Offset: pm.Offset, + PATS: pats, + } + + if int(pm.Offset) >= total { + return patsPage, nil + } + + aLimit := pm.Limit + if rLimit := total - int(pm.Offset); int(pm.Limit) > rLimit { + aLimit = uint64(rLimit) + } + + for i := pm.Offset; i < pm.Offset+aLimit; i++ { + if int(i) < total { + pat, err := pr.Retrieve(ctx, userID, patIDs[i]) + if err != nil { + return patsPage, err + } + patsPage.PATS = append(patsPage.PATS, pat) + } + } + + return patsPage, nil +} + +func (pr *patRepo) Revoke(ctx context.Context, userID, patID string) error { + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) + if err != nil { + return err + } + if err := b.Put([]byte(patID+keySeparator+revokedKey), revokedValue); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + if err := b.Put([]byte(patID+keySeparator+revokedAtKey), timeToBytes(time.Now())); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + return nil + }); err != nil { + return err + } + return nil +} + +func (pr *patRepo) Reactivate(ctx context.Context, userID, patID string) error { + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) + if err != nil { + return err + } + if err := b.Put([]byte(patID+keySeparator+revokedKey), activateValue); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + if err := b.Put([]byte(patID+keySeparator+revokedAtKey), []byte{}); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + return nil + }); err != nil { + return err + } + return nil +} + +func (pr *patRepo) Remove(ctx context.Context, userID, patID string) error { + prefix := []byte(patID + keySeparator) + idxKey := []byte(userID + keySeparator + patKey + keySeparator + patID) + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity) + if err != nil { + return err + } + c := b.Cursor() + for k, _ := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, _ = c.Next() { + if err := b.Delete(k); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + } + rb, err := pr.retrieveRootBucket(tx) + if err != nil { + return err + } + if err := rb.Delete(idxKey); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + return nil + }); err != nil { + return err + } + + return nil +} + +func (pr *patRepo) AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + prefix := []byte(patID + keySeparator + scopeKey) + var rKV map[string][]byte + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrCreateEntity) + if err != nil { + return err + } + kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if err != nil { + return err + } + for key, value := range kv { + fullKey := []byte(patID + keySeparator + key) + if err := b.Put(fullKey, value); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + } + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + rKV[string(k)] = v + } + return nil + }); err != nil { + return auth.Scope{}, err + } + + return parseKeyValueToScope(rKV) +} + +func (pr *patRepo) RemoveScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + if len(entityIDs) == 0 { + return auth.Scope{}, repoerr.ErrMalformedEntity + } + prefix := []byte(patID + keySeparator + scopeKey) + var rKV map[string][]byte + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrRemoveEntity) + if err != nil { + return err + } + kv, err := scopeEntryToKeyValue(platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if err != nil { + return err + } + for key := range kv { + fullKey := []byte(patID + keySeparator + key) + if err := b.Delete(fullKey); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + } + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + rKV[string(k)] = v + } + return nil + }); err != nil { + return auth.Scope{}, err + } + return parseKeyValueToScope(rKV) +} + +func (pr *patRepo) CheckScopeEntry(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + return pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrViewEntity) + if err != nil { + return errors.Wrap(repoerr.ErrViewEntity, err) + } + srootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + if err != nil { + return errors.Wrap(repoerr.ErrViewEntity, err) + } + + rootKey := patID + keySeparator + srootKey + if value := b.Get([]byte(rootKey)); bytes.Equal(value, anyIDValue) { + return nil + } + for _, entity := range entityIDs { + value := b.Get([]byte(rootKey + keySeparator + entity)) + if !bytes.Equal(value, entityValue) { + return repoerr.ErrNotFound + } + } + return nil + }) +} + +func (pr *patRepo) RemoveAllScopeEntry(ctx context.Context, userID, patID string) error { + return nil +} + +func (pr *patRepo) updatePATField(_ context.Context, userID, patID, key string, value []byte) (auth.PAT, error) { + prefix := []byte(patID + keySeparator) + kv := map[string][]byte{} + if err := pr.db.Update(func(tx *bolt.Tx) error { + b, err := pr.retrieveUserBucket(tx, userID, patID, repoerr.ErrUpdateEntity) + if err != nil { + return err + } + if err := b.Put([]byte(patID+keySeparator+key), value); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + if err := b.Put([]byte(patID+keySeparator+updatedAtKey), timeToBytes(time.Now())); err != nil { + return errors.Wrap(repoerr.ErrUpdateEntity, err) + } + c := b.Cursor() + for k, v := c.Seek(prefix); k != nil && bytes.HasPrefix(k, prefix); k, v = c.Next() { + kv[string(k)] = v + } + return nil + }); err != nil { + return auth.PAT{}, err + } + return keyValueToPAT(kv) +} + +func (pr *patRepo) createUserBucket(rootBucket *bolt.Bucket, userID string) (*bolt.Bucket, error) { + userBucket, err := rootBucket.CreateBucketIfNotExists([]byte(userID)) + if err != nil { + return nil, errors.Wrap(repoerr.ErrCreateEntity, fmt.Errorf("failed to retrieve or create bucket for user %s : %w", userID, err)) + } + + return userBucket, nil +} + +func (pr *patRepo) retrieveUserBucket(tx *bolt.Tx, userID, patID string, wrap error) (*bolt.Bucket, error) { + rootBucket, err := pr.retrieveRootBucket(tx) + if err != nil { + return nil, errors.Wrap(wrap, err) + } + + vPatID := rootBucket.Get([]byte(userID + keySeparator + patKey + keySeparator + patID)) + if vPatID == nil { + return nil, repoerr.ErrNotFound + } + + userBucket := rootBucket.Bucket([]byte(userID)) + if userBucket == nil { + return nil, errors.Wrap(wrap, fmt.Errorf("user %s not found", userID)) + } + return userBucket, nil +} + +func (pr *patRepo) retrieveRootBucket(tx *bolt.Tx) (*bolt.Bucket, error) { + rootBucket := tx.Bucket([]byte(pr.bucketName)) + if rootBucket == nil { + return nil, fmt.Errorf("bucket %s not found", pr.bucketName) + } + return rootBucket, nil +} + +func patToKeyValue(pat auth.PAT) (map[string][]byte, error) { + kv := map[string][]byte{ + idKey: []byte(pat.ID), + userKey: []byte(pat.User), + nameKey: []byte(pat.Name), + descriptionKey: []byte(pat.Description), + secretKey: []byte(pat.Secret), + issuedAtKey: timeToBytes(pat.IssuedAt), + expiresAtKey: timeToBytes(pat.ExpiresAt), + updatedAtKey: timeToBytes(pat.UpdatedAt), + lastUsedAtKey: timeToBytes(pat.LastUsedAt), + revokedKey: booleanToBytes(pat.Revoked), + revokedAtKey: timeToBytes(pat.RevokedAt), + } + scopeKV, err := scopeToKeyValue(pat.Scope) + if err != nil { + return nil, err + } + for k, v := range scopeKV { + kv[k] = v + } + return kv, nil +} + +func scopeToKeyValue(scope auth.Scope) (map[string][]byte, error) { + kv := map[string][]byte{} + for opType, scopeValue := range scope.Users { + tempKV, err := scopeEntryToKeyValue(auth.PlatformUsersScope, "", auth.DomainNullScope, opType, scopeValue.Values()...) + if err != nil { + return nil, err + } + for k, v := range tempKV { + kv[k] = v + } + } + for domainID, domainScope := range scope.Domains { + for opType, scopeValue := range domainScope.DomainManagement { + tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, auth.DomainManagementScope, opType, scopeValue.Values()...) + if err != nil { + return nil, errors.Wrap(repoerr.ErrCreateEntity, err) + } + for k, v := range tempKV { + kv[k] = v + } + } + for entityType, scope := range domainScope.Entities { + for opType, scopeValue := range scope { + tempKV, err := scopeEntryToKeyValue(auth.PlatformDomainsScope, domainID, entityType, opType, scopeValue.Values()...) + if err != nil { + return nil, errors.Wrap(repoerr.ErrCreateEntity, err) + } + for k, v := range tempKV { + kv[k] = v + } + } + } + } + return kv, nil +} + +func scopeEntryToKeyValue(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (map[string][]byte, error) { + if len(entityIDs) == 0 { + return nil, repoerr.ErrMalformedEntity + } + + rootKey, err := scopeRootKey(platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + if err != nil { + return nil, err + } + if len(entityIDs) == 1 && entityIDs[0] == anyID { + return map[string][]byte{rootKey: anyIDValue}, nil + } + + kv := map[string][]byte{rootKey: selectedIDsValue} + + for _, entryID := range entityIDs { + if entryID == anyID { + return nil, repoerr.ErrMalformedEntity + } + kv[rootKey+keySeparator+entryID] = entityValue + } + + return kv, nil +} + +func scopeRootKey(platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType) (string, error) { + op, err := operation.ValidString() + if err != nil { + return "", errors.Wrap(repoerr.ErrMalformedEntity, err) + } + + var rootKey strings.Builder + + rootKey.WriteString(scopeKey) + rootKey.WriteString(keySeparator) + rootKey.WriteString(platformEntityType.String()) + rootKey.WriteString(keySeparator) + + switch platformEntityType { + case auth.PlatformUsersScope: + rootKey.WriteString(op) + case auth.PlatformDomainsScope: + if optionalDomainID == "" { + return "", fmt.Errorf("failed to add platform %s scope: invalid domain id", platformEntityType.String()) + } + odet, err := optionalDomainEntityType.ValidString() + if err != nil { + return "", errors.Wrap(repoerr.ErrMalformedEntity, err) + } + rootKey.WriteString(optionalDomainID) + rootKey.WriteString(keySeparator) + rootKey.WriteString(odet) + rootKey.WriteString(keySeparator) + rootKey.WriteString(op) + default: + return "", errors.Wrap(repoerr.ErrMalformedEntity, fmt.Errorf("invalid platform entity type %s", platformEntityType.String())) + } + + return rootKey.String(), nil +} + +func keyValueToBasicPAT(kv map[string][]byte) auth.PAT { + var pat auth.PAT + for k, v := range kv { + switch { + case strings.HasSuffix(k, keySeparator+idKey): + pat.ID = string(v) + case strings.HasSuffix(k, keySeparator+userKey): + pat.User = string(v) + case strings.HasSuffix(k, keySeparator+nameKey): + pat.Name = string(v) + case strings.HasSuffix(k, keySeparator+descriptionKey): + pat.Description = string(v) + case strings.HasSuffix(k, keySeparator+issuedAtKey): + pat.IssuedAt = bytesToTime(v) + case strings.HasSuffix(k, keySeparator+expiresAtKey): + pat.ExpiresAt = bytesToTime(v) + case strings.HasSuffix(k, keySeparator+updatedAtKey): + pat.UpdatedAt = bytesToTime(v) + case strings.HasSuffix(k, keySeparator+lastUsedAtKey): + pat.LastUsedAt = bytesToTime(v) + case strings.HasSuffix(k, keySeparator+revokedKey): + pat.Revoked = bytesToBoolean(v) + case strings.HasSuffix(k, keySeparator+revokedAtKey): + pat.RevokedAt = bytesToTime(v) + } + } + return pat +} + +func keyValueToPAT(kv map[string][]byte) (auth.PAT, error) { + pat := keyValueToBasicPAT(kv) + scope, err := parseKeyValueToScope(kv) + if err != nil { + return auth.PAT{}, err + } + pat.Scope = scope + return pat, nil +} + +func parseKeyValueToScope(kv map[string][]byte) (auth.Scope, error) { + scope := auth.Scope{ + Domains: make(map[string]auth.DomainScope), + } + for key, value := range kv { + if strings.Index(key, keySeparator+scopeKey+keySeparator) > 0 { + keyParts := strings.Split(key, keySeparator) + + platformEntityType, err := auth.ParsePlatformEntityType(keyParts[2]) + if err != nil { + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + switch platformEntityType { + case auth.PlatformUsersScope: + scope.Users, err = parseOperation(platformEntityType, scope.Users, key, keyParts, value) + if err != nil { + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + + case auth.PlatformDomainsScope: + if len(keyParts) < 6 { + return auth.Scope{}, fmt.Errorf("invalid scope key format: %s", key) + } + domainID := keyParts[3] + if scope.Domains == nil { + scope.Domains = make(map[string]auth.DomainScope) + } + if _, ok := scope.Domains[domainID]; !ok { + scope.Domains[domainID] = auth.DomainScope{} + } + domainScope := scope.Domains[domainID] + + entityType := keyParts[4] + + switch entityType { + case auth.DomainManagementScope.String(): + domainScope.DomainManagement, err = parseOperation(platformEntityType, domainScope.DomainManagement, key, keyParts, value) + if err != nil { + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + default: + etype, err := auth.ParseDomainEntityType(entityType) + if err != nil { + return auth.Scope{}, fmt.Errorf("key %s invalid entity type %s : %w", key, entityType, err) + } + if domainScope.Entities == nil { + domainScope.Entities = make(map[auth.DomainEntityType]auth.OperationScope) + } + if _, ok := domainScope.Entities[etype]; !ok { + domainScope.Entities[etype] = auth.OperationScope{} + } + entityOperationScope := domainScope.Entities[etype] + entityOperationScope, err = parseOperation(platformEntityType, entityOperationScope, key, keyParts, value) + if err != nil { + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + domainScope.Entities[etype] = entityOperationScope + } + scope.Domains[domainID] = domainScope + default: + return auth.Scope{}, errors.Wrap(repoerr.ErrViewEntity, fmt.Errorf("invalid platform entity type : %s", platformEntityType.String())) + } + } + } + return scope, nil +} + +func parseOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) (auth.OperationScope, error) { + if opScope == nil { + opScope = make(map[auth.OperationType]auth.ScopeValue) + } + + if err := validateOperation(platformEntityType, opScope, key, keyParts, value); err != nil { + return auth.OperationScope{}, err + } + + switch string(value) { + case string(entityValue): + opType, err := auth.ParseOperationType(keyParts[len(keyParts)-2]) + if err != nil { + return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + entityID := keyParts[len(keyParts)-1] + + if _, oValueExists := opScope[opType]; !oValueExists { + opScope[opType] = &auth.SelectedIDs{} + } + oValue := opScope[opType] + if err := oValue.AddValues(entityID); err != nil { + return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity value %v : %w", key, entityID, err) + } + opScope[opType] = oValue + case string(anyIDValue): + opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1]) + if err != nil { + return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + if oValue, oValueExists := opScope[opType]; oValueExists && oValue != nil { + if _, ok := oValue.(*auth.AnyIDs); !ok { + return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity anyIDs scope value : key already initialized with different type", key) + } + } + opScope[opType] = &auth.AnyIDs{} + case string(selectedIDsValue): + opType, err := auth.ParseOperationType(keyParts[len(keyParts)-1]) + if err != nil { + return auth.OperationScope{}, errors.Wrap(repoerr.ErrViewEntity, err) + } + oValue, oValueExists := opScope[opType] + if oValueExists && oValue != nil { + if _, ok := oValue.(*auth.SelectedIDs); !ok { + return auth.OperationScope{}, fmt.Errorf("failed to add scope key %s with entity selectedIDs scope value : key already initialized with different type", key) + } + } + if !oValueExists { + opScope[opType] = &auth.SelectedIDs{} + } + default: + return auth.OperationScope{}, fmt.Errorf("key %s have invalid value %v", key, value) + } + return opScope, nil +} + +func validateOperation(platformEntityType auth.PlatformEntityType, opScope auth.OperationScope, key string, keyParts []string, value []byte) error { + expectedKeyPartsLength := 0 + switch string(value) { + case string(entityValue): + switch platformEntityType { + case auth.PlatformDomainsScope: + expectedKeyPartsLength = 7 + case auth.PlatformUsersScope: + expectedKeyPartsLength = 5 + default: + return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String()) + } + case string(selectedIDsValue), string(anyIDValue): + switch platformEntityType { + case auth.PlatformDomainsScope: + expectedKeyPartsLength = 6 + case auth.PlatformUsersScope: + expectedKeyPartsLength = 4 + default: + return fmt.Errorf("invalid platform entity type : %s", platformEntityType.String()) + } + default: + return fmt.Errorf("key %s have invalid value %v", key, value) + } + if len(keyParts) != expectedKeyPartsLength { + return fmt.Errorf("invalid scope key format: %s", key) + } + return nil +} + +func timeToBytes(t time.Time) []byte { + timeBytes := make([]byte, 8) + binary.BigEndian.PutUint64(timeBytes, uint64(t.Unix())) + return timeBytes +} + +func bytesToTime(b []byte) time.Time { + timeAtSeconds := binary.BigEndian.Uint64(b) + return time.Unix(int64(timeAtSeconds), 0) +} + +func booleanToBytes(b bool) []byte { + if b { + return []byte{1} + } + return []byte{0} +} + +func bytesToBoolean(b []byte) bool { + if len(b) > 1 || b[0] != activateValue[0] { + return true + } + return false +} diff --git a/auth/events/streams.go b/auth/events/streams.go index 0081a962fa..f2e87f5783 100644 --- a/auth/events/streams.go +++ b/auth/events/streams.go @@ -5,6 +5,7 @@ package events import ( "context" + "time" "github.com/absmach/magistrala/auth" "github.com/absmach/magistrala/pkg/events" @@ -262,3 +263,59 @@ func (es *eventStore) CountSubjects(ctx context.Context, pr auth.PolicyReq) (uin func (es *eventStore) ListPermissions(ctx context.Context, pr auth.PolicyReq, filterPermission []string) (auth.Permissions, error) { return es.svc.ListPermissions(ctx, pr, filterPermission) } + +func (es *eventStore) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { + return es.svc.CreatePAT(ctx, token, name, description, duration, scope) +} + +func (es *eventStore) UpdatePATName(ctx context.Context, token, patID, name string) (auth.PAT, error) { + return es.svc.UpdatePATName(ctx, token, patID, name) +} + +func (es *eventStore) UpdatePATDescription(ctx context.Context, token, patID, description string) (auth.PAT, error) { + return es.svc.UpdatePATDescription(ctx, token, patID, description) +} + +func (es *eventStore) RetrievePAT(ctx context.Context, token, patID string) (auth.PAT, error) { + return es.svc.RetrievePAT(ctx, token, patID) +} + +func (es *eventStore) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + return es.svc.ListPATS(ctx, token, pm) +} + +func (es *eventStore) DeletePAT(ctx context.Context, token, patID string) error { + return es.svc.DeletePAT(ctx, token, patID) +} + +func (es *eventStore) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (auth.PAT, error) { + return es.svc.ResetPATSecret(ctx, token, patID, duration) +} + +func (es *eventStore) RevokePATSecret(ctx context.Context, token, patID string) error { + return es.svc.RevokePATSecret(ctx, token, patID) +} + +func (es *eventStore) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + return es.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (es *eventStore) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + return es.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (es *eventStore) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error { + return es.svc.ClearPATAllScopeEntry(ctx, token, patID) +} + +func (es *eventStore) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) { + return es.svc.IdentifyPAT(ctx, paToken) +} + +func (es *eventStore) AuthorizePAT(ctx context.Context, paToken string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + return es.svc.AuthorizePAT(ctx, paToken, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (es *eventStore) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + return es.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} diff --git a/auth/hasher.go b/auth/hasher.go new file mode 100644 index 0000000000..ada2352bbe --- /dev/null +++ b/auth/hasher.go @@ -0,0 +1,17 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package auth + +// Hasher specifies an API for generating hashes of an arbitrary textual +// content. +// +//go:generate mockery --name Hasher --output=./mocks --filename hasher.go --quiet --note "Copyright (c) Abstract Machines" +type Hasher interface { + // Hash generates the hashed string from plain-text. + Hash(string) (string, error) + + // Compare compares plain-text version to the hashed one. An error should + // indicate failed comparison. + Compare(string, string) error +} diff --git a/auth/hasher/doc.go b/auth/hasher/doc.go new file mode 100644 index 0000000000..98be992262 --- /dev/null +++ b/auth/hasher/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package hasher contains the domain concept definitions needed to +// support Magistrala users password hasher sub-service functionality. +package hasher diff --git a/auth/hasher/hasher.go b/auth/hasher/hasher.go new file mode 100644 index 0000000000..c417bf7b80 --- /dev/null +++ b/auth/hasher/hasher.go @@ -0,0 +1,86 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package hasher + +import ( + "encoding/base64" + "fmt" + "math/rand" + "strings" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/pkg/errors" + "golang.org/x/crypto/scrypt" +) + +var ( + errHashToken = errors.New("failed to generate hash for token") + errHashCompare = errors.New("failed to generate hash for given compare string") + errToken = errors.New("given token and hash are not same") + errSalt = errors.New("failed to generate salt") + errInvalidHashStore = errors.New("invalid stored hash format") + errDecode = errors.New("failed to decode") +) + +var _ auth.Hasher = (*bcryptHasher)(nil) + +type bcryptHasher struct{} + +// New instantiates a bcrypt-based hasher implementation. +func New() auth.Hasher { + return &bcryptHasher{} +} + +func (bh *bcryptHasher) Hash(token string) (string, error) { + salt, err := generateSalt(25) + if err != nil { + return "", err + } + // N is kept 16384 to make faster and added large salt, since PAT will be access by automation scripts in high frequency. + hash, err := scrypt.Key([]byte(token), salt, 16384, 8, 1, 32) + if err != nil { + return "", errors.Wrap(errHashToken, err) + } + + return fmt.Sprintf("%s.%s", base64.StdEncoding.EncodeToString(hash), base64.StdEncoding.EncodeToString(salt)), nil +} + +func (bh *bcryptHasher) Compare(plain, hashed string) error { + parts := strings.Split(hashed, ".") + if len(parts) != 2 { + return errInvalidHashStore + } + + actHash, err := base64.StdEncoding.DecodeString(parts[0]) + if err != nil { + return errors.Wrap(errDecode, err) + } + + salt, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return errors.Wrap(errDecode, err) + } + + derivedHash, err := scrypt.Key([]byte(plain), salt, 16384, 8, 1, 32) + if err != nil { + return errors.Wrap(errHashCompare, err) + } + + if string(derivedHash) == string(actHash) { + return nil + } + + return errToken +} + +func generateSalt(length int) ([]byte, error) { + rand.New(rand.NewSource(time.Now().UnixNano())) + salt := make([]byte, length) + _, err := rand.Read(salt) + if err != nil { + return nil, errors.Wrap(errSalt, err) + } + return salt, nil +} diff --git a/auth/keys.go b/auth/keys.go index aa21ee48e0..e273119014 100644 --- a/auth/keys.go +++ b/auth/keys.go @@ -30,6 +30,8 @@ const ( RecoveryKey // APIKey enables the one to act on behalf of the user. APIKey + // PersonalAccessToken represents token generated by user for automation. + PersonalAccessToken // InvitationKey is a key for inviting new users. InvitationKey ) @@ -44,6 +46,8 @@ func (kt KeyType) String() string { return "recovery" case APIKey: return "API" + case PersonalAccessToken: + return "pat" default: return "unknown" } diff --git a/auth/mocks/hasher.go b/auth/mocks/hasher.go new file mode 100644 index 0000000000..4c4425b257 --- /dev/null +++ b/auth/mocks/hasher.go @@ -0,0 +1,72 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// Hasher is an autogenerated mock type for the Hasher type +type Hasher struct { + mock.Mock +} + +// Compare provides a mock function with given fields: _a0, _a1 +func (_m *Hasher) Compare(_a0 string, _a1 string) error { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for Compare") + } + + var r0 error + if rf, ok := ret.Get(0).(func(string, string) error); ok { + r0 = rf(_a0, _a1) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Hash provides a mock function with given fields: _a0 +func (_m *Hasher) Hash(_a0 string) (string, error) { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for Hash") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(string) (string, error)); ok { + return rf(_a0) + } + if rf, ok := ret.Get(0).(func(string) string); ok { + r0 = rf(_a0) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(_a0) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewHasher creates a new instance of Hasher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewHasher(t interface { + mock.TestingT + Cleanup(func()) +}) *Hasher { + mock := &Hasher{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/mocks/pats.go b/auth/mocks/pats.go new file mode 100644 index 0000000000..c18cddcef0 --- /dev/null +++ b/auth/mocks/pats.go @@ -0,0 +1,404 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + auth "github.com/absmach/magistrala/auth" + + mock "github.com/stretchr/testify/mock" + + time "time" +) + +// PATS is an autogenerated mock type for the PATS type +type PATS struct { + mock.Mock +} + +// AddPATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATS) AddPATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for AddPATScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// AuthorizePAT provides a mock function with given fields: ctx, paToken, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATS) AuthorizePAT(ctx context.Context, paToken string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, paToken, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for AuthorizePAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r0 = rf(ctx, paToken, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CheckPAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATS) CheckPAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CheckPAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ClearPATAllScopeEntry provides a mock function with given fields: ctx, token, patID +func (_m *PATS) ClearPATAllScopeEntry(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for ClearPATAllScopeEntry") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration, scope +func (_m *PATS) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { + ret := _m.Called(ctx, token, name, description, duration, scope) + + if len(ret) == 0 { + panic("no return value specified for CreatePAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) (auth.PAT, error)); ok { + return rf(ctx, token, name, description, duration, scope) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) auth.PAT); ok { + r0 = rf(ctx, token, name, description, duration, scope) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration, auth.Scope) error); ok { + r1 = rf(ctx, token, name, description, duration, scope) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeletePAT provides a mock function with given fields: ctx, token, patID +func (_m *PATS) DeletePAT(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for DeletePAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// IdentifyPAT provides a mock function with given fields: ctx, paToken +func (_m *PATS) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) { + ret := _m.Called(ctx, paToken) + + if len(ret) == 0 { + panic("no return value specified for IdentifyPAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (auth.PAT, error)); ok { + return rf(ctx, paToken) + } + if rf, ok := ret.Get(0).(func(context.Context, string) auth.PAT); ok { + r0 = rf(ctx, paToken) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, paToken) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ListPATS provides a mock function with given fields: ctx, token, pm +func (_m *PATS) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + ret := _m.Called(ctx, token, pm) + + if len(ret) == 0 { + panic("no return value specified for ListPATS") + } + + var r0 auth.PATSPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) (auth.PATSPage, error)); ok { + return rf(ctx, token, pm) + } + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) auth.PATSPage); ok { + r0 = rf(ctx, token, pm) + } else { + r0 = ret.Get(0).(auth.PATSPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, auth.PATSPageMeta) error); ok { + r1 = rf(ctx, token, pm) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RemovePATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATS) RemovePATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RemovePATScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ResetPATSecret provides a mock function with given fields: ctx, token, patID, duration +func (_m *PATS) ResetPATSecret(ctx context.Context, token string, patID string, duration time.Duration) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, duration) + + if len(ret) == 0 { + panic("no return value specified for ResetPATSecret") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) (auth.PAT, error)); ok { + return rf(ctx, token, patID, duration) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) auth.PAT); ok { + r0 = rf(ctx, token, patID, duration) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, time.Duration) error); ok { + r1 = rf(ctx, token, patID, duration) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RetrievePAT provides a mock function with given fields: ctx, token, patID +func (_m *PATS) RetrievePAT(ctx context.Context, token string, patID string) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for RetrievePAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (auth.PAT, error)); ok { + return rf(ctx, token, patID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) auth.PAT); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, token, patID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RevokePATSecret provides a mock function with given fields: ctx, token, patID +func (_m *PATS) RevokePATSecret(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for RevokePATSecret") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdatePATDescription provides a mock function with given fields: ctx, token, patID, description +func (_m *PATS) UpdatePATDescription(ctx context.Context, token string, patID string, description string) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, description) + + if len(ret) == 0 { + panic("no return value specified for UpdatePATDescription") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, token, patID, description) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, token, patID, description) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, token, patID, description) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdatePATName provides a mock function with given fields: ctx, token, patID, name +func (_m *PATS) UpdatePATName(ctx context.Context, token string, patID string, name string) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, name) + + if len(ret) == 0 { + panic("no return value specified for UpdatePATName") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, token, patID, name) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, token, patID, name) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, token, patID, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewPATS creates a new instance of PATS. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewPATS(t interface { + mock.TestingT + Cleanup(func()) +}) *PATS { + mock := &PATS{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/mocks/patsrepo.go b/auth/mocks/patsrepo.go new file mode 100644 index 0000000000..323baca097 --- /dev/null +++ b/auth/mocks/patsrepo.go @@ -0,0 +1,394 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + auth "github.com/absmach/magistrala/auth" + + mock "github.com/stretchr/testify/mock" + + time "time" +) + +// PATSRepository is an autogenerated mock type for the PATSRepository type +type PATSRepository struct { + mock.Mock +} + +// AddScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATSRepository) AddScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for AddScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CheckScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATSRepository) CheckScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CheckScopeEntry") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Reactivate provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) Reactivate(ctx context.Context, userID string, patID string) error { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for Reactivate") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Remove provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) Remove(ctx context.Context, userID string, patID string) error { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for Remove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemoveAllScopeEntry provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) RemoveAllScopeEntry(ctx context.Context, userID string, patID string) error { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for RemoveAllScopeEntry") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// RemoveScopeEntry provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *PATSRepository) RemoveScopeEntry(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RemoveScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Retrieve provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) Retrieve(ctx context.Context, userID string, patID string) (auth.PAT, error) { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for Retrieve") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (auth.PAT, error)); ok { + return rf(ctx, userID, patID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) auth.PAT); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, userID, patID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RetrieveAll provides a mock function with given fields: ctx, userID, pm +func (_m *PATSRepository) RetrieveAll(ctx context.Context, userID string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + ret := _m.Called(ctx, userID, pm) + + if len(ret) == 0 { + panic("no return value specified for RetrieveAll") + } + + var r0 auth.PATSPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) (auth.PATSPage, error)); ok { + return rf(ctx, userID, pm) + } + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) auth.PATSPage); ok { + r0 = rf(ctx, userID, pm) + } else { + r0 = ret.Get(0).(auth.PATSPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, auth.PATSPageMeta) error); ok { + r1 = rf(ctx, userID, pm) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// RetrieveSecretAndRevokeStatus provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) RetrieveSecretAndRevokeStatus(ctx context.Context, userID string, patID string) (string, bool, error) { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for RetrieveSecretAndRevokeStatus") + } + + var r0 string + var r1 bool + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (string, bool, error)); ok { + return rf(ctx, userID, patID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) string); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) bool); ok { + r1 = rf(ctx, userID, patID) + } else { + r1 = ret.Get(1).(bool) + } + + if rf, ok := ret.Get(2).(func(context.Context, string, string) error); ok { + r2 = rf(ctx, userID, patID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// Revoke provides a mock function with given fields: ctx, userID, patID +func (_m *PATSRepository) Revoke(ctx context.Context, userID string, patID string) error { + ret := _m.Called(ctx, userID, patID) + + if len(ret) == 0 { + panic("no return value specified for Revoke") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, userID, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Save provides a mock function with given fields: ctx, pat +func (_m *PATSRepository) Save(ctx context.Context, pat auth.PAT) error { + ret := _m.Called(ctx, pat) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, auth.PAT) error); ok { + r0 = rf(ctx, pat) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateDescription provides a mock function with given fields: ctx, userID, patID, description +func (_m *PATSRepository) UpdateDescription(ctx context.Context, userID string, patID string, description string) (auth.PAT, error) { + ret := _m.Called(ctx, userID, patID, description) + + if len(ret) == 0 { + panic("no return value specified for UpdateDescription") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, userID, patID, description) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, userID, patID, description) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, userID, patID, description) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateName provides a mock function with given fields: ctx, userID, patID, name +func (_m *PATSRepository) UpdateName(ctx context.Context, userID string, patID string, name string) (auth.PAT, error) { + ret := _m.Called(ctx, userID, patID, name) + + if len(ret) == 0 { + panic("no return value specified for UpdateName") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, userID, patID, name) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, userID, patID, name) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, userID, patID, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdateTokenHash provides a mock function with given fields: ctx, userID, patID, tokenHash, expiryAt +func (_m *PATSRepository) UpdateTokenHash(ctx context.Context, userID string, patID string, tokenHash string, expiryAt time.Time) (auth.PAT, error) { + ret := _m.Called(ctx, userID, patID, tokenHash, expiryAt) + + if len(ret) == 0 { + panic("no return value specified for UpdateTokenHash") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Time) (auth.PAT, error)); ok { + return rf(ctx, userID, patID, tokenHash, expiryAt) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Time) auth.PAT); ok { + r0 = rf(ctx, userID, patID, tokenHash, expiryAt) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Time) error); ok { + r1 = rf(ctx, userID, patID, tokenHash, expiryAt) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewPATSRepository creates a new instance of PATSRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewPATSRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *PATSRepository { + mock := &PATSRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/mocks/service.go b/auth/mocks/service.go index 45d3b3d7f9..2700de304e 100644 --- a/auth/mocks/service.go +++ b/auth/mocks/service.go @@ -10,6 +10,8 @@ import ( auth "github.com/absmach/magistrala/auth" mock "github.com/stretchr/testify/mock" + + time "time" ) // Service is an autogenerated mock type for the Service type @@ -17,6 +19,41 @@ type Service struct { mock.Mock } +// AddPATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *Service) AddPATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for AddPATScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // AddPolicies provides a mock function with given fields: ctx, prs func (_m *Service) AddPolicies(ctx context.Context, prs []auth.PolicyReq) error { ret := _m.Called(ctx, prs) @@ -89,6 +126,31 @@ func (_m *Service) Authorize(ctx context.Context, pr auth.PolicyReq) error { return r0 } +// AuthorizePAT provides a mock function with given fields: ctx, paToken, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *Service) AuthorizePAT(ctx context.Context, paToken string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, paToken, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for AuthorizePAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r0 = rf(ctx, paToken, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // ChangeDomainStatus provides a mock function with given fields: ctx, token, id, d func (_m *Service) ChangeDomainStatus(ctx context.Context, token string, id string, d auth.DomainReq) (auth.Domain, error) { ret := _m.Called(ctx, token, id, d) @@ -117,6 +179,49 @@ func (_m *Service) ChangeDomainStatus(ctx context.Context, token string, id stri return r0, r1 } +// CheckPAT provides a mock function with given fields: ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *Service) CheckPAT(ctx context.Context, userID string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CheckPAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r0 = rf(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ClearPATAllScopeEntry provides a mock function with given fields: ctx, token, patID +func (_m *Service) ClearPATAllScopeEntry(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for ClearPATAllScopeEntry") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // CountObjects provides a mock function with given fields: ctx, pr func (_m *Service) CountObjects(ctx context.Context, pr auth.PolicyReq) (uint64, error) { ret := _m.Called(ctx, pr) @@ -201,6 +306,34 @@ func (_m *Service) CreateDomain(ctx context.Context, token string, d auth.Domain return r0, r1 } +// CreatePAT provides a mock function with given fields: ctx, token, name, description, duration, scope +func (_m *Service) CreatePAT(ctx context.Context, token string, name string, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { + ret := _m.Called(ctx, token, name, description, duration, scope) + + if len(ret) == 0 { + panic("no return value specified for CreatePAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) (auth.PAT, error)); ok { + return rf(ctx, token, name, description, duration, scope) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string, time.Duration, auth.Scope) auth.PAT); ok { + r0 = rf(ctx, token, name, description, duration, scope) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string, time.Duration, auth.Scope) error); ok { + r1 = rf(ctx, token, name, description, duration, scope) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // DeleteEntityPolicies provides a mock function with given fields: ctx, entityType, id func (_m *Service) DeleteEntityPolicies(ctx context.Context, entityType string, id string) error { ret := _m.Called(ctx, entityType, id) @@ -219,6 +352,24 @@ func (_m *Service) DeleteEntityPolicies(ctx context.Context, entityType string, return r0 } +// DeletePAT provides a mock function with given fields: ctx, token, patID +func (_m *Service) DeletePAT(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for DeletePAT") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // DeletePolicies provides a mock function with given fields: ctx, prs func (_m *Service) DeletePolicies(ctx context.Context, prs []auth.PolicyReq) error { ret := _m.Called(ctx, prs) @@ -283,6 +434,34 @@ func (_m *Service) Identify(ctx context.Context, token string) (auth.Key, error) return r0, r1 } +// IdentifyPAT provides a mock function with given fields: ctx, paToken +func (_m *Service) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) { + ret := _m.Called(ctx, paToken) + + if len(ret) == 0 { + panic("no return value specified for IdentifyPAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (auth.PAT, error)); ok { + return rf(ctx, paToken) + } + if rf, ok := ret.Get(0).(func(context.Context, string) auth.PAT); ok { + r0 = rf(ctx, paToken) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, paToken) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Issue provides a mock function with given fields: ctx, token, key func (_m *Service) Issue(ctx context.Context, token string, key auth.Key) (auth.Token, error) { ret := _m.Called(ctx, token, key) @@ -423,6 +602,34 @@ func (_m *Service) ListObjects(ctx context.Context, pr auth.PolicyReq, nextPageT return r0, r1 } +// ListPATS provides a mock function with given fields: ctx, token, pm +func (_m *Service) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + ret := _m.Called(ctx, token, pm) + + if len(ret) == 0 { + panic("no return value specified for ListPATS") + } + + var r0 auth.PATSPage + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) (auth.PATSPage, error)); ok { + return rf(ctx, token, pm) + } + if rf, ok := ret.Get(0).(func(context.Context, string, auth.PATSPageMeta) auth.PATSPage); ok { + r0 = rf(ctx, token, pm) + } else { + r0 = ret.Get(0).(auth.PATSPage) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, auth.PATSPageMeta) error); ok { + r1 = rf(ctx, token, pm) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // ListPermissions provides a mock function with given fields: ctx, pr, filterPermission func (_m *Service) ListPermissions(ctx context.Context, pr auth.PolicyReq, filterPermission []string) (auth.Permissions, error) { ret := _m.Called(ctx, pr, filterPermission) @@ -509,6 +716,69 @@ func (_m *Service) ListUserDomains(ctx context.Context, token string, userID str return r0, r1 } +// RemovePATScopeEntry provides a mock function with given fields: ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs +func (_m *Service) RemovePATScopeEntry(ctx context.Context, token string, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + _va := make([]interface{}, len(entityIDs)) + for _i := range entityIDs { + _va[_i] = entityIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for RemovePATScopeEntry") + } + + var r0 auth.Scope + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) (auth.Scope, error)); ok { + return rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) auth.Scope); ok { + r0 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r0 = ret.Get(0).(auth.Scope) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, auth.PlatformEntityType, string, auth.DomainEntityType, auth.OperationType, ...string) error); ok { + r1 = rf(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ResetPATSecret provides a mock function with given fields: ctx, token, patID, duration +func (_m *Service) ResetPATSecret(ctx context.Context, token string, patID string, duration time.Duration) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, duration) + + if len(ret) == 0 { + panic("no return value specified for ResetPATSecret") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) (auth.PAT, error)); ok { + return rf(ctx, token, patID, duration) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, time.Duration) auth.PAT); ok { + r0 = rf(ctx, token, patID, duration) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, time.Duration) error); ok { + r1 = rf(ctx, token, patID, duration) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // RetrieveDomain provides a mock function with given fields: ctx, token, id func (_m *Service) RetrieveDomain(ctx context.Context, token string, id string) (auth.Domain, error) { ret := _m.Called(ctx, token, id) @@ -595,6 +865,34 @@ func (_m *Service) RetrieveKey(ctx context.Context, token string, id string) (au return r0, r1 } +// RetrievePAT provides a mock function with given fields: ctx, token, patID +func (_m *Service) RetrievePAT(ctx context.Context, token string, patID string) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for RetrievePAT") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (auth.PAT, error)); ok { + return rf(ctx, token, patID) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) auth.PAT); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { + r1 = rf(ctx, token, patID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Revoke provides a mock function with given fields: ctx, token, id func (_m *Service) Revoke(ctx context.Context, token string, id string) error { ret := _m.Called(ctx, token, id) @@ -613,6 +911,24 @@ func (_m *Service) Revoke(ctx context.Context, token string, id string) error { return r0 } +// RevokePATSecret provides a mock function with given fields: ctx, token, patID +func (_m *Service) RevokePATSecret(ctx context.Context, token string, patID string) error { + ret := _m.Called(ctx, token, patID) + + if len(ret) == 0 { + panic("no return value specified for RevokePATSecret") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, token, patID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UnassignUser provides a mock function with given fields: ctx, token, id, userID func (_m *Service) UnassignUser(ctx context.Context, token string, id string, userID string) error { ret := _m.Called(ctx, token, id, userID) @@ -659,6 +975,62 @@ func (_m *Service) UpdateDomain(ctx context.Context, token string, id string, d return r0, r1 } +// UpdatePATDescription provides a mock function with given fields: ctx, token, patID, description +func (_m *Service) UpdatePATDescription(ctx context.Context, token string, patID string, description string) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, description) + + if len(ret) == 0 { + panic("no return value specified for UpdatePATDescription") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, token, patID, description) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, token, patID, description) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, token, patID, description) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// UpdatePATName provides a mock function with given fields: ctx, token, patID, name +func (_m *Service) UpdatePATName(ctx context.Context, token string, patID string, name string) (auth.PAT, error) { + ret := _m.Called(ctx, token, patID, name) + + if len(ret) == 0 { + panic("no return value specified for UpdatePATName") + } + + var r0 auth.PAT + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (auth.PAT, error)); ok { + return rf(ctx, token, patID, name) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) auth.PAT); ok { + r0 = rf(ctx, token, patID, name) + } else { + r0 = ret.Get(0).(auth.PAT) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { + r1 = rf(ctx, token, patID, name) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewService(t interface { diff --git a/auth/pat.go b/auth/pat.go new file mode 100644 index 0000000000..7e26065264 --- /dev/null +++ b/auth/pat.go @@ -0,0 +1,752 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package auth + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/absmach/magistrala/pkg/errors" +) + +var errAddEntityToAnyIDs = errors.New("could not add entity id to any ID scope value") + +// Define OperationType. +type OperationType uint32 + +const ( + CreateOp OperationType = iota + ReadOp + ListOp + UpdateOp + DeleteOp +) + +const ( + createOpStr = "create" + readOpStr = "read" + listOpStr = "list" + updateOpStr = "update" + deleteOpStr = "delete" +) + +func (ot OperationType) String() string { + switch ot { + case CreateOp: + return createOpStr + case ReadOp: + return readOpStr + case ListOp: + return listOpStr + case UpdateOp: + return updateOpStr + case DeleteOp: + return deleteOpStr + default: + return fmt.Sprintf("unknown operation type %d", ot) + } +} + +func (ot OperationType) ValidString() (string, error) { + str := ot.String() + if str == fmt.Sprintf("unknown operation type %d", ot) { + return "", errors.New(str) + } + return str, nil +} + +func ParseOperationType(ot string) (OperationType, error) { + switch ot { + case createOpStr: + return CreateOp, nil + case readOpStr: + return ReadOp, nil + case listOpStr: + return ListOp, nil + case updateOpStr: + return UpdateOp, nil + case deleteOpStr: + return DeleteOp, nil + default: + return 0, fmt.Errorf("unknown operation type %s", ot) + } +} + +func (ot OperationType) MarshalJSON() ([]byte, error) { + return []byte(ot.String()), nil +} + +func (ot OperationType) MarshalText() (text []byte, err error) { + return []byte(ot.String()), nil +} + +func (ot *OperationType) UnmarshalText(data []byte) (err error) { + *ot, err = ParseOperationType(string(data)) + return err +} + +// Define DomainEntityType. +type DomainEntityType uint32 + +const ( + DomainManagementScope DomainEntityType = iota + DomainGroupsScope + DomainChannelsScope + DomainThingsScope + DomainNullScope +) + +const ( + domainManagementScopeStr = "domain_management" + domainGroupsScopeStr = "groups" + domainChannelsScopeStr = "channels" + domainThingsScopeStr = "things" +) + +func (det DomainEntityType) String() string { + switch det { + case DomainManagementScope: + return domainManagementScopeStr + case DomainGroupsScope: + return domainGroupsScopeStr + case DomainChannelsScope: + return domainChannelsScopeStr + case DomainThingsScope: + return domainThingsScopeStr + default: + return fmt.Sprintf("unknown domain entity type %d", det) + } +} + +func (det DomainEntityType) ValidString() (string, error) { + str := det.String() + if str == fmt.Sprintf("unknown operation type %d", det) { + return "", errors.New(str) + } + return str, nil +} + +func ParseDomainEntityType(det string) (DomainEntityType, error) { + switch det { + case domainManagementScopeStr: + return DomainManagementScope, nil + case domainGroupsScopeStr: + return DomainGroupsScope, nil + case domainChannelsScopeStr: + return DomainChannelsScope, nil + case domainThingsScopeStr: + return DomainThingsScope, nil + default: + return 0, fmt.Errorf("unknown domain entity type %s", det) + } +} + +func (det DomainEntityType) MarshalJSON() ([]byte, error) { + return []byte(det.String()), nil +} + +func (det DomainEntityType) MarshalText() ([]byte, error) { + return []byte(det.String()), nil +} + +func (det *DomainEntityType) UnmarshalText(data []byte) (err error) { + *det, err = ParseDomainEntityType(string(data)) + return err +} + +// Define DomainEntityType. +type PlatformEntityType uint32 + +const ( + PlatformUsersScope PlatformEntityType = iota + PlatformDomainsScope +) + +const ( + platformUsersScopeStr = "users" + platformDomainsScopeStr = "domains" +) + +func (pet PlatformEntityType) String() string { + switch pet { + case PlatformUsersScope: + return platformUsersScopeStr + case PlatformDomainsScope: + return platformDomainsScopeStr + default: + return fmt.Sprintf("unknown platform entity type %d", pet) + } +} + +func (pet PlatformEntityType) ValidString() (string, error) { + str := pet.String() + if str == fmt.Sprintf("unknown platform entity type %d", pet) { + return "", errors.New(str) + } + return str, nil +} + +func ParsePlatformEntityType(pet string) (PlatformEntityType, error) { + switch pet { + case platformUsersScopeStr: + return PlatformUsersScope, nil + case platformDomainsScopeStr: + return PlatformDomainsScope, nil + default: + return 0, fmt.Errorf("unknown platform entity type %s", pet) + } +} + +func (pet PlatformEntityType) MarshalJSON() ([]byte, error) { + return []byte(pet.String()), nil +} + +func (pet PlatformEntityType) MarshalText() (text []byte, err error) { + return []byte(pet.String()), nil +} + +func (pet *PlatformEntityType) UnmarshalText(data []byte) (err error) { + *pet, err = ParsePlatformEntityType(string(data)) + return err +} + +// ScopeValue interface for Any entity ids or for sets of entity ids. +type ScopeValue interface { + Contains(id string) bool + Values() []string + AddValues(ids ...string) error + RemoveValues(ids ...string) error +} + +// AnyIDs implements ScopeValue for any entity id value. +type AnyIDs struct{} + +func (s AnyIDs) Contains(id string) bool { return true } +func (s AnyIDs) Values() []string { return []string{"*"} } +func (s *AnyIDs) AddValues(ids ...string) error { return errAddEntityToAnyIDs } +func (s *AnyIDs) RemoveValues(ids ...string) error { return errAddEntityToAnyIDs } + +// SelectedIDs implements ScopeValue for sets of entity ids. +type SelectedIDs map[string]struct{} + +func (s SelectedIDs) Contains(id string) bool { _, ok := s[id]; return ok } +func (s SelectedIDs) Values() []string { + values := []string{} + for value := range s { + values = append(values, value) + } + return values +} + +func (s *SelectedIDs) AddValues(ids ...string) error { + if *s == nil { + *s = make(SelectedIDs) + } + for _, id := range ids { + (*s)[id] = struct{}{} + } + return nil +} + +func (s *SelectedIDs) RemoveValues(ids ...string) error { + if *s == nil { + return nil + } + for _, id := range ids { + delete(*s, id) + } + return nil +} + +// OperationScope contains map of OperationType with value of AnyIDs or SelectedIDs. +type OperationScope map[OperationType]ScopeValue + +func (os *OperationScope) UnmarshalJSON(data []byte) error { + type tempOperationScope map[OperationType]json.RawMessage + + var tempScope tempOperationScope + if err := json.Unmarshal(data, &tempScope); err != nil { + return err + } + // Initialize the Operations map + *os = OperationScope{} + + for opType, rawMessage := range tempScope { + var stringValue string + var stringArrayValue []string + + // Try to unmarshal as string + if err := json.Unmarshal(rawMessage, &stringValue); err == nil { + if err := os.Add(opType, stringValue); err != nil { + return err + } + continue + } + + // Try to unmarshal as []string + if err := json.Unmarshal(rawMessage, &stringArrayValue); err == nil { + if err := os.Add(opType, stringArrayValue...); err != nil { + return err + } + continue + } + + // If neither unmarshalling succeeded, return an error + return fmt.Errorf("invalid ScopeValue for OperationType %v", opType) + } + + return nil +} + +func (os OperationScope) MarshalJSON() ([]byte, error) { + tempOperationScope := make(map[OperationType]interface{}) + for oType, scope := range os { + value := scope.Values() + if len(value) == 1 && value[0] == "*" { + tempOperationScope[oType] = "*" + continue + } + tempOperationScope[oType] = value + } + + b, err := json.Marshal(tempOperationScope) + if err != nil { + return nil, err + } + return b, nil +} + +func (os *OperationScope) Add(operation OperationType, entityIDs ...string) error { + var value ScopeValue + + if os == nil { + os = &OperationScope{} + } + + if len(entityIDs) == 0 { + return fmt.Errorf("entity ID is missing") + } + switch { + case len(entityIDs) == 1 && entityIDs[0] == "*": + value = &AnyIDs{} + default: + var sids SelectedIDs + for _, entityID := range entityIDs { + if entityID == "*" { + return fmt.Errorf("list contains wildcard") + } + if sids == nil { + sids = make(SelectedIDs) + } + sids[entityID] = struct{}{} + } + value = &sids + } + (*os)[operation] = value + return nil +} + +func (os *OperationScope) Delete(operation OperationType, entityIDs ...string) error { + if os == nil { + return nil + } + + opEntityIDs, exists := (*os)[operation] + if !exists { + return nil + } + + if len(entityIDs) == 0 { + return fmt.Errorf("failed to delete operation %s: entity ID is missing", operation.String()) + } + + switch eIDs := opEntityIDs.(type) { + case *AnyIDs: + if !(len(entityIDs) == 1 && entityIDs[0] == "*") { + return fmt.Errorf("failed to delete operation %s: invalid list", operation.String()) + } + delete((*os), operation) + return nil + case *SelectedIDs: + for _, entityID := range entityIDs { + if !eIDs.Contains(entityID) { + return fmt.Errorf("failed to delete operation %s: invalid entity ID in list", operation.String()) + } + } + for _, entityID := range entityIDs { + delete(*eIDs, entityID) + if len(*eIDs) == 0 { + delete((*os), operation) + } + } + return nil + default: + return fmt.Errorf("failed to delete operation: invalid entity id type %d", operation) + } +} + +func (os *OperationScope) Check(operation OperationType, entityIDs ...string) bool { + if os == nil { + return false + } + + if scopeValue, ok := (*os)[operation]; ok { + if len(entityIDs) == 0 { + _, ok := scopeValue.(*AnyIDs) + return ok + } + for _, entityID := range entityIDs { + if !scopeValue.Contains(entityID) { + return false + } + } + return true + } + + return false +} + +type DomainScope struct { + DomainManagement OperationScope `json:"domain_management,omitempty"` + Entities map[DomainEntityType]OperationScope `json:"entities,omitempty"` +} + +// Add entry in Domain scope. +func (ds *DomainScope) Add(domainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + if ds == nil { + return fmt.Errorf("failed to add domain %s scope: domain_scope is nil and not initialized", domainEntityType) + } + + if domainEntityType < DomainManagementScope || domainEntityType > DomainThingsScope { + return fmt.Errorf("failed to add domain %d scope: invalid domain entity type", domainEntityType) + } + if domainEntityType == DomainManagementScope { + if err := ds.DomainManagement.Add(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete domain management scope: %w", err) + } + } + + if ds.Entities == nil { + ds.Entities = make(map[DomainEntityType]OperationScope) + } + + opReg, ok := ds.Entities[domainEntityType] + if !ok { + opReg = OperationScope{} + } + + if err := opReg.Add(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to add domain %s scope: %w ", domainEntityType.String(), err) + } + ds.Entities[domainEntityType] = opReg + return nil +} + +// Delete entry in Domain scope. +func (ds *DomainScope) Delete(domainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + if ds == nil { + return nil + } + + if domainEntityType < DomainManagementScope || domainEntityType > DomainThingsScope { + return fmt.Errorf("failed to delete domain %d scope: invalid domain entity type", domainEntityType) + } + if ds.Entities == nil { + return nil + } + + if domainEntityType == DomainManagementScope { + if err := ds.DomainManagement.Delete(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete domain management scope: %w", err) + } + } + + os, exists := ds.Entities[domainEntityType] + if !exists { + return nil + } + + if err := os.Delete(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete domain %s scope: %w", domainEntityType.String(), err) + } + + if len(os) == 0 { + delete(ds.Entities, domainEntityType) + } + return nil +} + +// Check entry in Domain scope. +func (ds *DomainScope) Check(domainEntityType DomainEntityType, operation OperationType, ids ...string) bool { + if ds.Entities == nil { + return false + } + if domainEntityType < DomainManagementScope || domainEntityType > DomainThingsScope { + return false + } + if domainEntityType == DomainManagementScope { + return ds.DomainManagement.Check(operation, ids...) + } + os, exists := ds.Entities[domainEntityType] + if !exists { + return false + } + + return os.Check(operation, ids...) +} + +// Example Scope as JSON +// +// { +// "platform": { +// "users": { +// "create": {}, +// "read": {}, +// "list": {}, +// "update": {}, +// "delete": {} +// } +// }, +// "domains": { +// "domain_1": { +// "entities": { +// "groups": { +// "create": {}, // this for all groups in domain +// }, +// "channels": { +// // for particular channel in domain +// "delete": { +// "channel1": {}, +// "channel2":{} +// } +// }, +// "things": { +// "update": {} // this for all things in domain +// } +// } +// } +// } +// } +type Scope struct { + Users OperationScope `json:"users,omitempty"` + Domains map[string]DomainScope `json:"domains,omitempty"` +} + +// Add entry in Domain scope. +func (s *Scope) Add(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + if s == nil { + return fmt.Errorf("failed to add platform %s scope: scope is nil and not initialized", platformEntityType.String()) + } + switch platformEntityType { + case PlatformUsersScope: + if err := s.Users.Add(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to add platform %s scope: %w", platformEntityType.String(), err) + } + case PlatformDomainsScope: + if optionalDomainID == "" { + return fmt.Errorf("failed to add platform %s scope: invalid domain id", platformEntityType.String()) + } + if s.Domains == nil || len(s.Domains) == 0 { + s.Domains = make(map[string]DomainScope) + } + + ds, ok := s.Domains[optionalDomainID] + if !ok { + ds = DomainScope{} + } + if err := ds.Add(optionalDomainEntityType, operation, entityIDs...); err != nil { + return fmt.Errorf("failed to add platform %s id %s scope : %w", platformEntityType.String(), optionalDomainID, err) + } + s.Domains[optionalDomainID] = ds + default: + return fmt.Errorf("failed to add platform %d scope: invalid platform entity type ", platformEntityType) + } + return nil +} + +// Delete entry in Domain scope. +func (s *Scope) Delete(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + if s == nil { + return nil + } + switch platformEntityType { + case PlatformUsersScope: + if err := s.Users.Delete(operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete platform %s scope: %w", platformEntityType.String(), err) + } + case PlatformDomainsScope: + if optionalDomainID == "" { + return fmt.Errorf("failed to delete platform %s scope: invalid domain id", platformEntityType.String()) + } + ds, ok := s.Domains[optionalDomainID] + if !ok { + return nil + } + if err := ds.Delete(optionalDomainEntityType, operation, entityIDs...); err != nil { + return fmt.Errorf("failed to delete platform %s id %s scope : %w", platformEntityType.String(), optionalDomainID, err) + } + default: + return fmt.Errorf("failed to add platform %d scope: invalid platform entity type ", platformEntityType) + } + return nil +} + +// Check entry in Domain scope. +func (s *Scope) Check(platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) bool { + if s == nil { + return false + } + switch platformEntityType { + case PlatformUsersScope: + return s.Users.Check(operation, entityIDs...) + case PlatformDomainsScope: + ds, ok := s.Domains[optionalDomainID] + if !ok { + return false + } + return ds.Check(optionalDomainEntityType, operation, entityIDs...) + default: + return false + } +} + +func (s *Scope) String() string { + str, err := json.Marshal(s) // , "", " ") + if err != nil { + return fmt.Sprintf("failed to convert scope to string: json marshal error :%s", err.Error()) + } + return string(str) +} + +// PAT represents Personal Access Token. +type PAT struct { + ID string `json:"id,omitempty"` + User string `json:"user,omitempty"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Secret string `json:"secret,omitempty"` + Scope Scope `json:"scope,omitempty"` + IssuedAt time.Time `json:"issued_at,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty"` + LastUsedAt time.Time `json:"last_used_at,omitempty"` + Revoked bool `json:"revoked,omitempty"` + RevokedAt time.Time `json:"revoked_at,omitempty"` +} + +type PATSPageMeta struct { + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` +} +type PATSPage struct { + Total uint64 `json:"total"` + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + PATS []PAT `json:"pats"` +} + +func (pat *PAT) String() string { + str, err := json.MarshalIndent(pat, "", " ") + if err != nil { + return fmt.Sprintf("failed to convert PAT to string: json marshal error :%s", err.Error()) + } + return string(str) +} + +// Expired verifies if the key is expired. +func (pat PAT) Expired() bool { + return pat.ExpiresAt.UTC().Before(time.Now().UTC()) +} + +// PATS specifies function which are required for Personal access Token implementation. +//go:generate mockery --name PATS --output=./mocks --filename pats.go --quiet --note "Copyright (c) Abstract Machines" + +type PATS interface { + // Create function creates new PAT for given valid inputs. + CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope Scope) (PAT, error) + + // UpdateName function updates the name for the given PAT ID. + UpdatePATName(ctx context.Context, token, patID, name string) (PAT, error) + + // UpdateDescription function updates the description for the given PAT ID. + UpdatePATDescription(ctx context.Context, token, patID, description string) (PAT, error) + + // Retrieve function retrieves the PAT for given ID. + RetrievePAT(ctx context.Context, token, patID string) (PAT, error) + + // List function lists all the PATs for the user. + ListPATS(ctx context.Context, token string, pm PATSPageMeta) (PATSPage, error) + + // Delete function deletes the PAT for given ID. + DeletePAT(ctx context.Context, token, patID string) error + + // ResetSecret function reset the secret and creates new secret for the given ID. + ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (PAT, error) + + // RevokeSecret function revokes the secret for the given ID. + RevokePATSecret(ctx context.Context, token, patID string) error + + // AddScope function adds a new scope entry. + AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + + // RemoveScope function removes a scope entry. + RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + + // ClearAllScope function removes all scope entry. + ClearPATAllScopeEntry(ctx context.Context, token, patID string) error + + // IdentifyPAT function will valid the secret. + IdentifyPAT(ctx context.Context, paToken string) (PAT, error) + + // AuthorizePAT function will valid the secret and check the given scope exists. + AuthorizePAT(ctx context.Context, paToken string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error + + // CheckPAT function will check the given scope exists. + CheckPAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error +} + +// PATSRepository specifies PATS persistence API. +// +//go:generate mockery --name PATSRepository --output=./mocks --filename patsrepo.go --quiet --note "Copyright (c) Abstract Machines" +type PATSRepository interface { + // Save persists the PAT + Save(ctx context.Context, pat PAT) (err error) + + // Retrieve retrieves users PAT by its unique identifier. + Retrieve(ctx context.Context, userID, patID string) (pat PAT, err error) + + // RetrieveSecretAndRevokeStatus retrieves secret and revoke status of PAT by its unique identifier. + RetrieveSecretAndRevokeStatus(ctx context.Context, userID, patID string) (string, bool, error) + + // UpdateName updates the name of a PAT. + UpdateName(ctx context.Context, userID, patID, name string) (PAT, error) + + // UpdateDescription updates the description of a PAT. + UpdateDescription(ctx context.Context, userID, patID, description string) (PAT, error) + + // UpdateTokenHash updates the token hash of a PAT. + UpdateTokenHash(ctx context.Context, userID, patID, tokenHash string, expiryAt time.Time) (PAT, error) + + // RetrieveAll retrieves all PATs belongs to userID. + RetrieveAll(ctx context.Context, userID string, pm PATSPageMeta) (pats PATSPage, err error) + + // Revoke PAT with provided ID. + Revoke(ctx context.Context, userID, patID string) error + + // Reactivate PAT with provided ID. + Reactivate(ctx context.Context, userID, patID string) error + + // Remove removes Key with provided ID. + Remove(ctx context.Context, userID, patID string) error + + AddScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + + RemoveScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) + + CheckScopeEntry(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error + + RemoveAllScopeEntry(ctx context.Context, userID, patID string) error +} diff --git a/auth/service.go b/auth/service.go index 8bb5f1342d..26965a40f0 100644 --- a/auth/service.go +++ b/auth/service.go @@ -5,18 +5,24 @@ package auth import ( "context" + "encoding/base64" "fmt" + "math/rand" "strings" "time" "github.com/absmach/magistrala" "github.com/absmach/magistrala/pkg/errors" svcerr "github.com/absmach/magistrala/pkg/errors/service" + "github.com/google/uuid" ) const ( - recoveryDuration = 5 * time.Minute - defLimit = 100 + recoveryDuration = 5 * time.Minute + defLimit = 100 + randStr = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890!@#$%^&&*|+-=" + patPrefix = "pat" + patSecretSeparator = "_" ) var ( @@ -36,7 +42,11 @@ var ( errRemoveLocalPolicy = errors.New("failed to remove from local policy copy") errRemovePolicyEngine = errors.New("failed to remove from policy engine") // errInvalidEntityType indicates invalid entity type. - errInvalidEntityType = errors.New("invalid entity type") + errInvalidEntityType = errors.New("invalid entity type") + errMalformedPAT = errors.New("malformed personal access token") + errFailedToParseUUID = errors.New("failed to parse string to UUID") + errInvalidLenFor2UUIDs = errors.New("invalid input length for 2 UUID, excepted 32 byte") + errRevokedPAT = errors.New("revoked pat") ) var ( @@ -71,6 +81,13 @@ var ( AdminPermission, MembershipPermission, } + + errCreatePAT = errors.New("failed to create PAT") + errUpdatePAT = errors.New("failed to update PAT") + errRetrievePAT = errors.New("failed to retrieve PAT") + errDeletePAT = errors.New("failed to delete PAT") + errRevokePAT = errors.New("failed to revoke PAT") + errClearAllScope = errors.New("failed to clear all entry in scope") ) // Authn specifies an API that must be fulfilled by the domain service @@ -105,6 +122,7 @@ type Service interface { Authn Authz Domains + PATS } var _ Service = (*service)(nil) @@ -112,6 +130,8 @@ var _ Service = (*service)(nil) type service struct { keys KeyRepository domains DomainsRepository + pats PATSRepository + hasher Hasher idProvider magistrala.IDProvider agent PolicyAgent tokenizer Tokenizer @@ -121,11 +141,13 @@ type service struct { } // New instantiates the auth service implementation. -func New(keys KeyRepository, domains DomainsRepository, idp magistrala.IDProvider, tokenizer Tokenizer, policyAgent PolicyAgent, loginDuration, refreshDuration, invitationDuration time.Duration) Service { +func New(keys KeyRepository, domains DomainsRepository, pats PATSRepository, hasher Hasher, idp magistrala.IDProvider, tokenizer Tokenizer, policyAgent PolicyAgent, loginDuration, refreshDuration, invitationDuration time.Duration) Service { return &service{ tokenizer: tokenizer, domains: domains, keys: keys, + pats: pats, + hasher: hasher, idProvider: idp, agent: policyAgent, loginDuration: loginDuration, @@ -208,10 +230,14 @@ func (svc service) Authorize(ctx context.Context, pr PolicyReq) error { return errors.Wrap(svcerr.ErrAuthentication, err) } if key.Subject == "" { - if pr.ObjectType == GroupType || pr.ObjectType == ThingType || pr.ObjectType == DomainType { + switch { + case pr.ObjectType == GroupType || pr.ObjectType == ThingType || pr.ObjectType == DomainType: return svcerr.ErrDomainAuthorization + case pr.ObjectType == UserType: + key.Subject = key.User + default: + return svcerr.ErrAuthentication } - return svcerr.ErrAuthentication } pr.Subject = key.Subject pr.Domain = key.Domain @@ -1066,3 +1092,256 @@ func (svc service) DeleteEntityPolicies(ctx context.Context, entityType, id stri return errInvalidEntityType } } + +func (svc service) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope Scope) (PAT, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PAT{}, err + } + + id, err := svc.idProvider.ID() + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrCreateEntity, err) + } + secret, hash, err := svc.generateSecretAndHash(key.User, id) + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrCreateEntity, err) + } + + now := time.Now() + pat := PAT{ + ID: id, + User: key.User, + Name: name, + Description: description, + Secret: hash, + IssuedAt: now, + ExpiresAt: now.Add(duration), + Scope: scope, + } + if err := svc.pats.Save(ctx, pat); err != nil { + return PAT{}, errors.Wrap(errCreatePAT, err) + } + pat.Secret = secret + return pat, nil +} + +func (svc service) UpdatePATName(ctx context.Context, token, patID, name string) (PAT, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PAT{}, err + } + pat, err := svc.pats.UpdateName(ctx, key.User, patID, name) + if err != nil { + return PAT{}, errors.Wrap(errUpdatePAT, err) + } + return pat, nil +} + +func (svc service) UpdatePATDescription(ctx context.Context, token, patID, description string) (PAT, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PAT{}, err + } + pat, err := svc.pats.UpdateDescription(ctx, key.User, patID, description) + if err != nil { + return PAT{}, errors.Wrap(errUpdatePAT, err) + } + return pat, nil +} + +func (svc service) RetrievePAT(ctx context.Context, token, patID string) (PAT, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PAT{}, err + } + + pat, err := svc.pats.Retrieve(ctx, key.User, patID) + if err != nil { + return PAT{}, errors.Wrap(errRetrievePAT, err) + } + return pat, nil +} + +func (svc service) ListPATS(ctx context.Context, token string, pm PATSPageMeta) (PATSPage, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PATSPage{}, err + } + patsPage, err := svc.pats.RetrieveAll(ctx, key.User, pm) + if err != nil { + return PATSPage{}, errors.Wrap(errRetrievePAT, err) + } + return patsPage, nil +} + +func (svc service) DeletePAT(ctx context.Context, token, patID string) error { + key, err := svc.Identify(ctx, token) + if err != nil { + return err + } + if err := svc.pats.Remove(ctx, key.User, patID); err != nil { + return errors.Wrap(errDeletePAT, err) + } + return nil +} + +func (svc service) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (PAT, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return PAT{}, err + } + + // Generate new HashToken take place here + secret, hash, err := svc.generateSecretAndHash(key.User, patID) + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + pat, err := svc.pats.UpdateTokenHash(ctx, key.User, patID, hash, time.Now().Add(duration)) + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + + if err := svc.pats.Reactivate(ctx, key.User, patID); err != nil { + return PAT{}, errors.Wrap(svcerr.ErrUpdateEntity, err) + } + pat.Secret = secret + pat.Revoked = false + pat.RevokedAt = time.Time{} + return pat, nil +} + +func (svc service) RevokePATSecret(ctx context.Context, token, patID string) error { + key, err := svc.Identify(ctx, token) + if err != nil { + return err + } + + if err := svc.pats.Revoke(ctx, key.User, patID); err != nil { + return errors.Wrap(errRevokePAT, err) + } + return nil +} + +func (svc service) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return Scope{}, err + } + scope, err := svc.pats.AddScopeEntry(ctx, key.User, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if err != nil { + return Scope{}, errors.Wrap(errRevokePAT, err) + } + return scope, nil +} + +func (svc service) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) (Scope, error) { + key, err := svc.Identify(ctx, token) + if err != nil { + return Scope{}, err + } + scope, err := svc.pats.RemoveScopeEntry(ctx, key.User, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) + if err != nil { + return Scope{}, err + } + return scope, nil +} + +func (svc service) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error { + key, err := svc.Identify(ctx, token) + if err != nil { + return err + } + if err := svc.pats.RemoveAllScopeEntry(ctx, key.User, patID); err != nil { + return errors.Wrap(errClearAllScope, err) + } + return nil +} + +func (svc service) IdentifyPAT(ctx context.Context, secret string) (PAT, error) { + parts := strings.Split(secret, patSecretSeparator) + if len(parts) != 3 && parts[0] != patPrefix { + return PAT{}, errors.Wrap(svcerr.ErrAuthentication, errMalformedPAT) + } + userID, patID, err := decode(parts[1]) + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrAuthentication, errMalformedPAT) + } + secretHash, revoked, err := svc.pats.RetrieveSecretAndRevokeStatus(ctx, userID.String(), patID.String()) + if err != nil { + return PAT{}, errors.Wrap(svcerr.ErrAuthentication, err) + } + if revoked { + return PAT{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedPAT) + } + if err := svc.hasher.Compare(secret, secretHash); err != nil { + return PAT{}, errors.Wrap(svcerr.ErrAuthentication, err) + } + return PAT{ID: patID.String(), User: userID.String()}, nil +} + +func (svc service) AuthorizePAT(ctx context.Context, paToken string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + res, err := svc.IdentifyPAT(ctx, paToken) + if err != nil { + return err + } + if err := svc.pats.CheckScopeEntry(ctx, res.User, res.ID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...); err != nil { + return errors.Wrap(svcerr.ErrAuthorization, err) + } + return nil +} + +func (svc service) CheckPAT(ctx context.Context, userID, patID string, platformEntityType PlatformEntityType, optionalDomainID string, optionalDomainEntityType DomainEntityType, operation OperationType, entityIDs ...string) error { + if err := svc.pats.CheckScopeEntry(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...); err != nil { + return errors.Wrap(svcerr.ErrAuthorization, err) + } + return nil +} + +func (svc service) generateSecretAndHash(userID, patID string) (string, string, error) { + uID, err := uuid.Parse(userID) + if err != nil { + return "", "", errors.Wrap(errFailedToParseUUID, err) + } + pID, err := uuid.Parse(patID) + if err != nil { + return "", "", errors.Wrap(errFailedToParseUUID, err) + } + + secret := patPrefix + patSecretSeparator + encode(uID, pID) + patSecretSeparator + generateRandomString(100) + secretHash, err := svc.hasher.Hash(secret) + return secret, secretHash, err +} + +func encode(userID, patID uuid.UUID) string { + c := append(userID[:], patID[:]...) + return base64.StdEncoding.EncodeToString(c) +} + +func decode(encoded string) (uuid.UUID, uuid.UUID, error) { + data, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + return uuid.Nil, uuid.Nil, err + } + + if len(data) != 32 { + return uuid.Nil, uuid.Nil, errInvalidLenFor2UUIDs + } + + var userID, patID uuid.UUID + copy(userID[:], data[:16]) + copy(patID[:], data[16:]) + + return userID, patID, nil +} + +func generateRandomString(n int) string { + letterRunes := []rune(randStr) + rand.New(rand.NewSource(time.Now().UnixNano())) + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) +} diff --git a/auth/service_test.go b/auth/service_test.go index 7547c87db5..4987683e47 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -58,15 +58,19 @@ var ( ) var ( - krepo *mocks.KeyRepository - prepo *mocks.PolicyAgent - drepo *mocks.DomainsRepository + krepo *mocks.KeyRepository + prepo *mocks.PolicyAgent + drepo *mocks.DomainsRepository + patsrepo *mocks.PATSRepository + hasher *mocks.Hasher ) func newService() (auth.Service, string) { krepo = new(mocks.KeyRepository) prepo = new(mocks.PolicyAgent) drepo = new(mocks.DomainsRepository) + patsrepo = new(mocks.PATSRepository) + hasher = new(mocks.Hasher) idProvider := uuid.NewMock() t := jwt.New([]byte(secret)) @@ -80,7 +84,7 @@ func newService() (auth.Service, string) { } token, _ := t.Issue(key) - return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration), token + return auth.New(krepo, drepo, patsrepo, hasher, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration), token } func TestIssue(t *testing.T) { diff --git a/auth/tracing/tracing.go b/auth/tracing/tracing.go index fe58626b04..2be137e71a 100644 --- a/auth/tracing/tracing.go +++ b/auth/tracing/tracing.go @@ -6,6 +6,7 @@ package tracing import ( "context" "fmt" + "time" "github.com/absmach/magistrala/auth" "go.opentelemetry.io/otel/attribute" @@ -312,3 +313,141 @@ func (tm *tracingMiddleware) DeleteEntityPolicies(ctx context.Context, entityTyp defer span.End() return tm.svc.DeleteEntityPolicies(ctx, entityType, id) } + +func (tm *tracingMiddleware) CreatePAT(ctx context.Context, token, name, description string, duration time.Duration, scope auth.Scope) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "create_pat", trace.WithAttributes( + attribute.String("name", name), + attribute.String("description", description), + attribute.String("duration", duration.String()), + attribute.String("scope", scope.String()), + )) + defer span.End() + return tm.svc.CreatePAT(ctx, token, name, description, duration, scope) +} + +func (tm *tracingMiddleware) UpdatePATName(ctx context.Context, token, patID, name string) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "update_pat_name", trace.WithAttributes( + attribute.String("pat_id", patID), + attribute.String("name", name), + )) + defer span.End() + return tm.svc.UpdatePATName(ctx, token, patID, name) +} + +func (tm *tracingMiddleware) UpdatePATDescription(ctx context.Context, token, patID, description string) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "update_pat_description", trace.WithAttributes( + attribute.String("pat_id", patID), + attribute.String("description", description), + )) + defer span.End() + return tm.svc.UpdatePATDescription(ctx, token, patID, description) +} + +func (tm *tracingMiddleware) RetrievePAT(ctx context.Context, token, patID string) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "retrieve_pat", trace.WithAttributes( + attribute.String("pat_id", patID), + )) + defer span.End() + return tm.svc.RetrievePAT(ctx, token, patID) +} + +func (tm *tracingMiddleware) ListPATS(ctx context.Context, token string, pm auth.PATSPageMeta) (auth.PATSPage, error) { + ctx, span := tm.tracer.Start(ctx, "list_pat", trace.WithAttributes( + attribute.Int64("limit", int64(pm.Limit)), + attribute.Int64("offset", int64(pm.Offset)), + )) + defer span.End() + return tm.svc.ListPATS(ctx, token, pm) +} + +func (tm *tracingMiddleware) DeletePAT(ctx context.Context, token, patID string) error { + ctx, span := tm.tracer.Start(ctx, "delete_pat", trace.WithAttributes( + attribute.String("pat_id", patID), + )) + defer span.End() + return tm.svc.DeletePAT(ctx, token, patID) +} + +func (tm *tracingMiddleware) ResetPATSecret(ctx context.Context, token, patID string, duration time.Duration) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "reset_pat_secret", trace.WithAttributes( + attribute.String("pat_id", patID), + attribute.String("duration", duration.String()), + )) + defer span.End() + return tm.svc.ResetPATSecret(ctx, token, patID, duration) +} + +func (tm *tracingMiddleware) RevokePATSecret(ctx context.Context, token, patID string) error { + ctx, span := tm.tracer.Start(ctx, "revoke_pat_secret", trace.WithAttributes( + attribute.String("pat_id", patID), + )) + defer span.End() + return tm.svc.RevokePATSecret(ctx, token, patID) +} + +func (tm *tracingMiddleware) AddPATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + ctx, span := tm.tracer.Start(ctx, "add_pat_scope_entry", trace.WithAttributes( + attribute.String("pat_id", patID), + attribute.String("platform_entity", platformEntityType.String()), + attribute.String("optional_domain_id", optionalDomainID), + attribute.String("optional_domain_entity", optionalDomainEntityType.String()), + attribute.String("operation", operation.String()), + attribute.StringSlice("entities", entityIDs), + )) + defer span.End() + return tm.svc.AddPATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (tm *tracingMiddleware) RemovePATScopeEntry(ctx context.Context, token, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) (auth.Scope, error) { + ctx, span := tm.tracer.Start(ctx, "remove_pat_scope_entry", trace.WithAttributes( + attribute.String("pat_id", patID), + attribute.String("platform_entity", platformEntityType.String()), + attribute.String("optional_domain_id", optionalDomainID), + attribute.String("optional_domain_entity", optionalDomainEntityType.String()), + attribute.String("operation", operation.String()), + attribute.StringSlice("entities", entityIDs), + )) + defer span.End() + return tm.svc.RemovePATScopeEntry(ctx, token, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (tm *tracingMiddleware) ClearPATAllScopeEntry(ctx context.Context, token, patID string) error { + ctx, span := tm.tracer.Start(ctx, "clear_pat_all_scope_entry", trace.WithAttributes( + attribute.String("pat_id", patID), + )) + defer span.End() + return tm.svc.ClearPATAllScopeEntry(ctx, token, patID) +} + +func (tm *tracingMiddleware) IdentifyPAT(ctx context.Context, paToken string) (auth.PAT, error) { + ctx, span := tm.tracer.Start(ctx, "identity_pat") + defer span.End() + return tm.svc.IdentifyPAT(ctx, paToken) +} + +func (tm *tracingMiddleware) AuthorizePAT(ctx context.Context, paToken string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + ctx, span := tm.tracer.Start(ctx, "authorize_pat", trace.WithAttributes( + attribute.String("personal_access_token", paToken), + attribute.String("platform_entity", platformEntityType.String()), + attribute.String("optional_domain_id", optionalDomainID), + attribute.String("optional_domain_entity", optionalDomainEntityType.String()), + attribute.String("operation", operation.String()), + attribute.StringSlice("entities", entityIDs), + )) + defer span.End() + return tm.svc.AuthorizePAT(ctx, paToken, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} + +func (tm *tracingMiddleware) CheckPAT(ctx context.Context, userID, patID string, platformEntityType auth.PlatformEntityType, optionalDomainID string, optionalDomainEntityType auth.DomainEntityType, operation auth.OperationType, entityIDs ...string) error { + ctx, span := tm.tracer.Start(ctx, "check_pat", trace.WithAttributes( + attribute.String("user_id", userID), + attribute.String("patID", patID), + attribute.String("platform_entity", platformEntityType.String()), + attribute.String("optional_domain_id", optionalDomainID), + attribute.String("optional_domain_entity", optionalDomainEntityType.String()), + attribute.String("operation", operation.String()), + attribute.StringSlice("entities", entityIDs), + )) + defer span.End() + return tm.svc.CheckPAT(ctx, userID, patID, platformEntityType, optionalDomainID, optionalDomainEntityType, operation, entityIDs...) +} diff --git a/bootstrap/postgres/configs.go b/bootstrap/postgres/configs.go index c8b90c09a5..66cbe7d00d 100644 --- a/bootstrap/postgres/configs.go +++ b/bootstrap/postgres/configs.go @@ -282,7 +282,7 @@ func (cr configRepository) Update(ctx context.Context, cfg bootstrap.Config) err } func (cr configRepository) UpdateCert(ctx context.Context, domainID, thingID, clientCert, clientKey, caCert string) (bootstrap.Config, error) { - q := `UPDATE configs SET client_cert = :client_cert, client_key = :client_key, ca_cert = :ca_cert WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id + q := `UPDATE configs SET client_cert = :client_cert, client_key = :client_key, ca_cert = :ca_cert WHERE magistrala_thing = :magistrala_thing AND domain_id = :domain_id RETURNING magistrala_thing, client_cert, client_key, ca_cert` dbcfg := dbConfig{ @@ -443,7 +443,7 @@ func (cr configRepository) UpdateChannel(ctx context.Context, c bootstrap.Channe return errors.Wrap(repoerr.ErrUpdateEntity, err) } - q := `UPDATE channels SET name = :name, metadata = :metadata, updated_at = :updated_at, updated_by = :updated_by + q := `UPDATE channels SET name = :name, metadata = :metadata, updated_at = :updated_at, updated_by = :updated_by WHERE magistrala_channel = :magistrala_channel` if _, err = cr.db.NamedExecContext(ctx, q, dbch); err != nil { return errors.Wrap(errUpdateChannels, err) diff --git a/cmd/auth/main.go b/cmd/auth/main.go index 1a3ae89aa3..67f2c254ff 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -18,11 +18,14 @@ import ( api "github.com/absmach/magistrala/auth/api" grpcapi "github.com/absmach/magistrala/auth/api/grpc" httpapi "github.com/absmach/magistrala/auth/api/http" + "github.com/absmach/magistrala/auth/bolt" "github.com/absmach/magistrala/auth/events" + "github.com/absmach/magistrala/auth/hasher" "github.com/absmach/magistrala/auth/jwt" apostgres "github.com/absmach/magistrala/auth/postgres" "github.com/absmach/magistrala/auth/spicedb" "github.com/absmach/magistrala/auth/tracing" + boltclient "github.com/absmach/magistrala/internal/clients/bolt" mglog "github.com/absmach/magistrala/logger" "github.com/absmach/magistrala/pkg/jaeger" "github.com/absmach/magistrala/pkg/postgres" @@ -37,6 +40,7 @@ import ( "github.com/authzed/grpcutil" "github.com/caarlos0/env/v10" "github.com/jmoiron/sqlx" + "go.etcd.io/bbolt" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -49,6 +53,7 @@ const ( envPrefixHTTP = "MG_AUTH_HTTP_" envPrefixGrpc = "MG_AUTH_GRPC_" envPrefixDB = "MG_AUTH_DB_" + envPrefixPATDB = "MG_AUTH_PAT_DB_" defDB = "auth" defSvcHTTPPort = "8189" defSvcGRPCPort = "8181" @@ -129,7 +134,22 @@ func main() { return } - svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient) + boltDBConfig := boltclient.Config{} + if err := env.ParseWithOptions(&boltDBConfig, env.Options{Prefix: envPrefixPATDB}); err != nil { + logger.Error(fmt.Sprintf("failed to parse bolt db config : %s\n", err.Error())) + exitCode = 1 + return + } + + bClient, err := boltclient.Connect(boltDBConfig, bolt.Init) + if err != nil { + logger.Error(fmt.Sprintf("failed to connect to bolt db : %s\n", err.Error())) + exitCode = 1 + return + } + defer bClient.Close() + + svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient, bClient, boltDBConfig) httpServerConfig := server.Config{Port: defSvcHTTPPort} if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { @@ -203,16 +223,18 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch return nil } -func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental) auth.Service { +func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental, bClient *bbolt.DB, bConfig boltclient.Config) auth.Service { database := postgres.NewDatabase(db, dbConfig, tracer) keysRepo := apostgres.New(database) domainsRepo := apostgres.NewDomainRepository(database) + patsRepo := bolt.NewPATSRepository(bClient, bConfig.Bucket) + hasher := hasher.New() pa := spicedb.NewPolicyAgent(spicedbClient, logger) idProvider := uuid.New() t := jwt.New([]byte(cfg.SecretKey)) - svc := auth.New(keysRepo, domainsRepo, idProvider, t, pa, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) + svc := auth.New(keysRepo, domainsRepo, patsRepo, hasher, idProvider, t, pa, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) svc, err := events.NewEventStoreMiddleware(ctx, svc, cfg.ESURL) if err != nil { logger.Error(fmt.Sprintf("failed to init event store middleware : %s", err)) diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index c206e7529c..ab333e37da 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -15,6 +15,7 @@ volumes: magistrala-mqtt-broker-volume: magistrala-spicedb-db-volume: magistrala-auth-db-volume: + magistrala-pat-db-volume: magistrala-invitations-db-volume: magistrala-ui-db-volume: @@ -132,6 +133,7 @@ services: - magistrala-base-net volumes: - ./spicedb/schema.zed:${MG_SPICEDB_SCHEMA_FILE} + - magistrala-pat-db-volume:/magistrala-data # Auth gRPC mTLS server certificates - type: bind source: ${MG_AUTH_GRPC_SERVER_CERT:-ssl/certs/dummy/server_cert} diff --git a/docker/nginx/nginx-key.conf b/docker/nginx/nginx-key.conf index 153a7b7a42..c9a59e7c34 100644 --- a/docker/nginx/nginx-key.conf +++ b/docker/nginx/nginx-key.conf @@ -114,7 +114,7 @@ http { # Proxy pass to auth service - location ~ ^/(domains) { + location ~ ^/(domains|keys|pats) { include snippets/proxy-headers.conf; add_header Access-Control-Expose-Headers Location; proxy_pass http://auth:${MG_AUTH_HTTP_PORT}; diff --git a/go.mod b/go.mod index 2e367d7ed6..f81a340d75 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/go-redis/redis/v8 v8.11.5 github.com/gocql/gocql v1.6.0 github.com/gofrs/uuid v4.4.0+incompatible + github.com/google/uuid v1.6.0 github.com/gookit/color v1.5.4 github.com/gopcua/opcua v0.1.6 github.com/gorilla/websocket v1.5.3 @@ -44,6 +45,7 @@ require ( github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.9.0 + go.etcd.io/bbolt v1.3.10 go.mongodb.org/mongo-driver v1.15.0 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 @@ -100,7 +102,6 @@ require ( github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect diff --git a/go.sum b/go.sum index bb37ab1192..d3828dc2b2 100644 --- a/go.sum +++ b/go.sum @@ -482,6 +482,8 @@ github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.etcd.io/bbolt v1.3.10 h1:+BqfJTcCzTItrop8mq/lbzL8wSGtj94UO/3U31shqG0= +go.etcd.io/bbolt v1.3.10/go.mod h1:bK3UQLPJZly7IlNmV7uVHJDxfe5aK9Ll93e/74Y9oEQ= go.mongodb.org/mongo-driver v1.15.0 h1:rJCKC8eEliewXjZGf0ddURtl7tTVy1TK3bfl0gkUSLc= go.mongodb.org/mongo-driver v1.15.0/go.mod h1:Vzb0Mk/pa7e6cWw85R4F/endUC3u0U9jGcNU603k65c= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.53.0 h1:9G6E0TXzGFVfTnawRzrPl83iHOAV7L8NJiR8RSGYV1g= diff --git a/internal/clients/bolt/bolt.go b/internal/clients/bolt/bolt.go new file mode 100644 index 0000000000..6db2d1a276 --- /dev/null +++ b/internal/clients/bolt/bolt.go @@ -0,0 +1,83 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package bolt + +import ( + "io/fs" + "strconv" + "time" + + "github.com/absmach/magistrala/pkg/errors" + "github.com/caarlos0/env/v10" + bolt "go.etcd.io/bbolt" +) + +var ( + errConfig = errors.New("failed to load BoltDB configuration") + errConnect = errors.New("failed to connect to BoltDB database") + errInit = errors.New("failed to initialize to BoltDB database") +) + +type FileMode fs.FileMode + +func (fm *FileMode) UnmarshalText(text []byte) error { + temp, err := strconv.ParseUint(string(text), 8, 32) + if err != nil { + return err + } + *fm = FileMode(temp) + return nil +} + +// Config contains BoltDB specific parameters. +type Config struct { + FileDirPath string `env:"FILE_DIR_PATH" envDefault:"./magistrala-data"` + FileName string `env:"FILE_NAME" envDefault:"magistrala-pat.db"` + FileMode FileMode `env:"FILE_MODE" envDefault:"0600"` + Bucket string `env:"BUCKET" envDefault:"magistrala"` + Timeout time.Duration `env:"TIMEOUT" envDefault:"0"` +} + +// Setup load configuration from environment and creates new BoltDB. +func Setup(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { + return SetupDB(envPrefix, initFn) +} + +// SetupDB load configuration from environment,. +func SetupDB(envPrefix string, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { + cfg := Config{} + if err := env.ParseWithOptions(&cfg, env.Options{Prefix: envPrefix}); err != nil { + return nil, errors.Wrap(errConfig, err) + } + bdb, err := Connect(cfg, initFn) + if err != nil { + return nil, err + } + + return bdb, nil +} + +// Connect establishes connection to the BoltDB. +func Connect(cfg Config, initFn func(*bolt.Tx, string) error) (*bolt.DB, error) { + filePath := cfg.FileDirPath + "/" + cfg.FileName + db, err := bolt.Open(filePath, fs.FileMode(cfg.FileMode), nil) + if err != nil { + return nil, errors.Wrap(errConnect, err) + } + if initFn != nil { + if err := Init(db, cfg, initFn); err != nil { + return nil, err + } + } + return db, nil +} + +func Init(db *bolt.DB, cfg Config, initFn func(*bolt.Tx, string) error) error { + if err := db.Update(func(tx *bolt.Tx) error { + return initFn(tx, cfg.Bucket) + }); err != nil { + return errors.Wrap(errInit, err) + } + return nil +} diff --git a/internal/clients/bolt/doc.go b/internal/clients/bolt/doc.go new file mode 100644 index 0000000000..3941091882 --- /dev/null +++ b/internal/clients/bolt/doc.go @@ -0,0 +1,9 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package BoltDB contains the domain concept definitions needed to support +// Magistrala BoltDB database functionality. +// +// It provides the abstraction of the BoltDB database service, which is used +// to configure, setup and connect to the BoltDB database. +package bolt diff --git a/pkg/apiutil/errors.go b/pkg/apiutil/errors.go index 833f8ecbd8..58b56f0276 100644 --- a/pkg/apiutil/errors.go +++ b/pkg/apiutil/errors.go @@ -90,6 +90,9 @@ var ( // ErrMissingHost indicates missing host. ErrMissingHost = errors.New("missing host") + // ErrMissingDescription indicates missing description. + ErrMissingDescription = errors.New("missing description") + // ErrMissingPass indicates missing password. ErrMissingPass = errors.New("missing password") @@ -179,4 +182,7 @@ var ( // ErrInvalidTimeFormat indicates invalid time format i.e not unix time. ErrInvalidTimeFormat = errors.New("invalid time format use unix time") + + // ErrInvalidDuration indicates invalid duration value. + ErrInvalidDuration = errors.New("invalid duration value") )