diff --git a/auth/README.md b/auth/README.md index 4a991e0fb1..ee118b11ad 100644 --- a/auth/README.md +++ b/auth/README.md @@ -59,40 +59,42 @@ Domain consists of the following fields: The service is configured using the environment variables presented in the following table. Note that any unset variables will be replaced with their default values. -| Variable | Description | Default | -| ------------------------------ | ----------------------------------------------------------------------- | ------------------------------- | -| MG_AUTH_LOG_LEVEL | Log level for the Auth service (debug, info, warn, error) | info | -| MG_AUTH_DB_HOST | Database host address | localhost | -| MG_AUTH_DB_PORT | Database host port | 5432 | -| MG_AUTH_DB_USER | Database user | magistrala | -| MG_AUTH_DB_PASSWORD | Database password | magistrala | -| MG_AUTH_DB_NAME | Name of the database used by the service | auth | -| MG_AUTH_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | -| MG_AUTH_DB_SSL_CERT | Path to the PEM encoded certificate file | "" | -| MG_AUTH_DB_SSL_KEY | Path to the PEM encoded key file | "" | -| MG_AUTH_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" | -| MG_AUTH_HTTP_HOST | Auth service HTTP host | "" | -| MG_AUTH_HTTP_PORT | Auth service HTTP port | 8189 | -| MG_AUTH_HTTP_SERVER_CERT | Path to the PEM encoded HTTP server certificate file | "" | -| MG_AUTH_HTTP_SERVER_KEY | Path to the PEM encoded HTTP server key file | "" | -| MG_AUTH_GRPC_HOST | Auth service gRPC host | "" | -| MG_AUTH_GRPC_PORT | Auth service gRPC port | 8181 | -| MG_AUTH_GRPC_SERVER_CERT | Path to the PEM encoded gRPC server certificate file | "" | -| MG_AUTH_GRPC_SERVER_KEY | Path to the PEM encoded gRPC server key file | "" | -| MG_AUTH_GRPC_SERVER_CA_CERTS | Path to the PEM encoded gRPC server CA certificate file | "" | -| MG_AUTH_GRPC_CLIENT_CA_CERTS | Path to the PEM encoded gRPC client CA certificate file | "" | -| MG_AUTH_SECRET_KEY | String used for signing tokens | secret | -| MG_AUTH_ACCESS_TOKEN_DURATION | The access token expiration period | 1h | -| MG_AUTH_REFRESH_TOKEN_DURATION | The refresh token expiration period | 24h | -| MG_AUTH_INVITATION_DURATION | The invitation token expiration period | 168h | -| MG_SPICEDB_HOST | SpiceDB host address | localhost | -| MG_SPICEDB_PORT | SpiceDB host port | 50051 | -| MG_SPICEDB_PRE_SHARED_KEY | SpiceDB pre-shared key | 12345678 | -| MG_SPICEDB_SCHEMA_FILE | Path to SpiceDB schema file | ./docker/spicedb/schema.zed | +| Variable | Description | Default | +| ------------------------------ | ----------------------------------------------------------------------- | ------------------------------ | +| MG_AUTH_LOG_LEVEL | Log level for the Auth service (debug, info, warn, error) | info | +| MG_AUTH_DB_HOST | Database host address | localhost | +| MG_AUTH_DB_PORT | Database host port | 5432 | +| MG_AUTH_DB_USER | Database user | magistrala | +| MG_AUTH_DB_PASSWORD | Database password | magistrala | +| MG_AUTH_DB_NAME | Name of the database used by the service | auth | +| MG_AUTH_DB_SSL_MODE | Database connection SSL mode (disable, require, verify-ca, verify-full) | disable | +| MG_AUTH_DB_SSL_CERT | Path to the PEM encoded certificate file | "" | +| MG_AUTH_DB_SSL_KEY | Path to the PEM encoded key file | "" | +| MG_AUTH_DB_SSL_ROOT_CERT | Path to the PEM encoded root certificate file | "" | +| MG_AUTH_HTTP_HOST | Auth service HTTP host | "" | +| MG_AUTH_HTTP_PORT | Auth service HTTP port | 8189 | +| MG_AUTH_HTTP_SERVER_CERT | Path to the PEM encoded HTTP server certificate file | "" | +| MG_AUTH_HTTP_SERVER_KEY | Path to the PEM encoded HTTP server key file | "" | +| MG_AUTH_GRPC_HOST | Auth service gRPC host | "" | +| MG_AUTH_GRPC_PORT | Auth service gRPC port | 8181 | +| MG_AUTH_GRPC_SERVER_CERT | Path to the PEM encoded gRPC server certificate file | "" | +| MG_AUTH_GRPC_SERVER_KEY | Path to the PEM encoded gRPC server key file | "" | +| MG_AUTH_GRPC_SERVER_CA_CERTS | Path to the PEM encoded gRPC server CA certificate file | "" | +| MG_AUTH_GRPC_CLIENT_CA_CERTS | Path to the PEM encoded gRPC client CA certificate file | "" | +| MG_AUTH_SECRET_KEY | String used for signing tokens | secret | +| MG_AUTH_ACCESS_TOKEN_DURATION | The access token expiration period | 1h | +| MG_AUTH_REFRESH_TOKEN_DURATION | The refresh token expiration period | 24h | +| MG_AUTH_INVITATION_DURATION | The invitation token expiration period | 168h | +| MG_SPICEDB_HOST | SpiceDB host address | localhost | +| MG_SPICEDB_PORT | SpiceDB host port | 50051 | +| MG_SPICEDB_PRE_SHARED_KEY | SpiceDB pre-shared key | 12345678 | +| MG_SPICEDB_SCHEMA_FILE | Path to SpiceDB schema file | ./docker/spicedb/schema.zed | +| MG_AUTH_CACHE_URL | Cache server URL | "redis://localhost:6379/0" | +| MG_AUTH_CACHE_KEY_DURATION | Cache key expiration period | "1h" | | MG_JAEGER_URL | Jaeger server URL | | -| MG_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 | -| MG_SEND_TELEMETRY | Send telemetry to magistrala call home server | true | -| MG_AUTH_ADAPTER_INSTANCE_ID | Adapter instance ID | "" | +| MG_JAEGER_TRACE_RATIO | Jaeger sampling ratio | 1.0 | +| MG_SEND_TELEMETRY | Send telemetry to magistrala call home server | true | +| MG_AUTH_ADAPTER_INSTANCE_ID | Adapter instance ID | "" | ## Deployment @@ -142,6 +144,8 @@ MG_SPICEDB_HOST=localhost \ MG_SPICEDB_PORT=50051 \ MG_SPICEDB_PRE_SHARED_KEY=12345678 \ MG_SPICEDB_SCHEMA_FILE=./docker/spicedb/schema.zed \ +MG_AUTH_CACHE_URL=redis://localhost:6379/0 \ +MG_AUTH_CACHE_KEY_DURATION=1h \ MG_JAEGER_URL=http://localhost:14268/api/traces \ MG_JAEGER_TRACE_RATIO=1.0 \ MG_SEND_TELEMETRY=true \ diff --git a/auth/api/http/keys/endpoint.go b/auth/api/http/keys/endpoint.go index 4c3d1b7ecc..6aa1788b0f 100644 --- a/auth/api/http/keys/endpoint.go +++ b/auth/api/http/keys/endpoint.go @@ -85,3 +85,18 @@ func revokeEndpoint(svc auth.Service) endpoint.Endpoint { return revokeKeyRes{}, nil } } + +func revokeTokenEndpoint(svc auth.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + req := request.(revokeTokenReq) + if err := req.validate(); err != nil { + return nil, err + } + + if err := svc.RevokeToken(ctx, req.token); err != nil { + return nil, err + } + + return revokeKeyRes{}, nil + } +} diff --git a/auth/api/http/keys/endpoint_test.go b/auth/api/http/keys/endpoint_test.go index f200f9f5ee..01a59c74f6 100644 --- a/auth/api/http/keys/endpoint_test.go +++ b/auth/api/http/keys/endpoint_test.go @@ -4,7 +4,6 @@ package keys_test import ( - "context" "encoding/json" "fmt" "io" @@ -16,12 +15,11 @@ import ( "github.com/absmach/magistrala/auth" httpapi "github.com/absmach/magistrala/auth/api/http" - "github.com/absmach/magistrala/auth/jwt" "github.com/absmach/magistrala/auth/mocks" + "github.com/absmach/magistrala/internal/testsutil" mglog "github.com/absmach/magistrala/logger" "github.com/absmach/magistrala/pkg/apiutil" svcerr "github.com/absmach/magistrala/pkg/errors/service" - "github.com/absmach/magistrala/pkg/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) @@ -66,17 +64,6 @@ func (tr testRequest) make() (*http.Response, error) { return tr.client.Do(req) } -func newService() (auth.Service, *mocks.KeyRepository) { - krepo := new(mocks.KeyRepository) - prepo := new(mocks.PolicyAgent) - drepo := new(mocks.DomainsRepository) - idProvider := uuid.NewMock() - - t := jwt.New([]byte(secret)) - - return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration), krepo -} - func newServer(svc auth.Service) *httptest.Server { mux := httpapi.MakeHandler(svc, mglog.NewMock(), "") return httptest.NewServer(mux) @@ -91,9 +78,7 @@ func toJSON(data interface{}) string { } func TestIssue(t *testing.T) { - svc, krepo := newService() - token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) + svc := new(mocks.Service) ts := newServer(svc) defer ts.Close() @@ -108,11 +93,14 @@ func TestIssue(t *testing.T) { req string ct string token string + resp auth.Token + err error status int }{ { desc: "issue login key with empty token", req: toJSON(lk), + resp: auth.Token{AccessToken: "token"}, ct: contentType, token: "", status: http.StatusUnauthorized, @@ -120,29 +108,30 @@ func TestIssue(t *testing.T) { { desc: "issue API key", req: toJSON(ak), + resp: auth.Token{AccessToken: "token"}, ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusCreated, }, { desc: "issue recovery key", req: toJSON(rk), ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusCreated, }, { desc: "issue login key wrong content type", req: toJSON(lk), ct: "", - token: token.AccessToken, + token: "token", status: http.StatusUnsupportedMediaType, }, { desc: "issue recovery key wrong content type", req: toJSON(rk), ct: "", - token: token.AccessToken, + token: "token", status: http.StatusUnsupportedMediaType, }, { @@ -150,6 +139,7 @@ func TestIssue(t *testing.T) { req: toJSON(ak), ct: contentType, token: "wrong", + err: svcerr.ErrAuthentication, status: http.StatusUnauthorized, }, { @@ -157,27 +147,28 @@ func TestIssue(t *testing.T) { req: toJSON(rk), ct: contentType, token: "", + err: svcerr.ErrAuthentication, status: http.StatusUnauthorized, }, { desc: "issue key with invalid request", req: "{", ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusBadRequest, }, { desc: "issue key with invalid JSON", req: "{invalid}", ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusBadRequest, }, { desc: "issue key with invalid JSON content", req: `{"Type":{"key":"AccessToken"}}`, ct: contentType, - token: token.AccessToken, + token: "token", status: http.StatusBadRequest, }, } @@ -191,24 +182,16 @@ func TestIssue(t *testing.T) { token: tc.token, body: strings.NewReader(tc.req), } - repocall := krepo.On("Save", mock.Anything, mock.Anything).Return("", nil) + svcCall := svc.On("Issue", mock.Anything, tc.token, mock.Anything).Return(tc.resp, tc.err) res, err := req.make() assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) - repocall.Unset() + svcCall.Unset() } } func TestRetrieve(t *testing.T) { - svc, krepo := newService() - token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id} - - repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) - k, err := svc.Issue(context.Background(), token.AccessToken, key) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - repocall.Unset() + svc := new(mocks.Service) ts := newServer(svc) defer ts.Close() @@ -224,8 +207,8 @@ func TestRetrieve(t *testing.T) { }{ { desc: "retrieve an existing key", - id: k.AccessToken, - token: token.AccessToken, + id: testsutil.GenerateUUID(t), + token: "token", key: auth.Key{ Subject: id, Type: auth.AccessKey, @@ -238,13 +221,13 @@ func TestRetrieve(t *testing.T) { { desc: "retrieve a non-existing key", id: "non-existing", - token: token.AccessToken, - status: http.StatusBadRequest, + token: "token", + status: http.StatusNotFound, err: svcerr.ErrNotFound, }, { desc: "retrieve a key with an invalid token", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), token: "wrong", status: http.StatusUnauthorized, err: svcerr.ErrAuthentication, @@ -252,7 +235,7 @@ func TestRetrieve(t *testing.T) { { desc: "retrieve a key with an empty token", token: "", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), status: http.StatusUnauthorized, err: svcerr.ErrAuthentication, }, @@ -265,24 +248,16 @@ func TestRetrieve(t *testing.T) { url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id), token: tc.token, } - repocall := krepo.On("Retrieve", mock.Anything, mock.Anything, mock.Anything).Return(tc.key, tc.err) + svcCall := svc.On("RetrieveKey", mock.Anything, tc.token, tc.id).Return(tc.key, tc.err) res, err := req.make() assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) - repocall.Unset() + svcCall.Unset() } } func TestRevoke(t *testing.T) { - svc, krepo := newService() - token, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - key := auth.Key{Type: auth.APIKey, IssuedAt: time.Now(), Subject: id} - - repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) - k, err := svc.Issue(context.Background(), token.AccessToken, key) - assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) - repocall.Unset() + svc := new(mocks.Service) ts := newServer(svc) defer ts.Close() @@ -292,29 +267,31 @@ func TestRevoke(t *testing.T) { desc string id string token string + err error status int }{ { desc: "revoke an existing key", - id: k.AccessToken, - token: token.AccessToken, + id: testsutil.GenerateUUID(t), + token: "token", status: http.StatusNoContent, }, { desc: "revoke a non-existing key", id: "non-existing", - token: token.AccessToken, + token: "token", status: http.StatusNoContent, }, { desc: "revoke key with invalid token", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), token: "wrong", + err: svcerr.ErrAuthentication, status: http.StatusUnauthorized, }, { desc: "revoke key with empty token", - id: k.AccessToken, + id: testsutil.GenerateUUID(t), token: "", status: http.StatusUnauthorized, }, @@ -327,10 +304,63 @@ func TestRevoke(t *testing.T) { url: fmt.Sprintf("%s/keys/%s", ts.URL, tc.id), token: tc.token, } - repocall := krepo.On("Remove", mock.Anything, mock.Anything, mock.Anything).Return(nil) + svcCall := svc.On("Revoke", mock.Anything, tc.token, tc.id).Return(tc.err) + res, err := req.make() + assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) + assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) + svcCall.Unset() + } +} + +func TestRevokeToken(t *testing.T) { + svc := new(mocks.Service) + + ts := newServer(svc) + defer ts.Close() + client := ts.Client() + + cases := []struct { + desc string + id string + token string + err error + status int + }{ + { + desc: "revoke an existing token", + token: "token", + status: http.StatusNoContent, + }, + { + desc: "revoke a non-existing token", + token: "token", + err: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + }, + { + desc: "revoke invalid token", + token: "wrong", + err: svcerr.ErrAuthentication, + status: http.StatusUnauthorized, + }, + { + desc: "revoke empty token", + token: "", + status: http.StatusUnauthorized, + }, + } + + for _, tc := range cases { + req := testRequest{ + client: client, + method: http.MethodDelete, + url: fmt.Sprintf("%s/keys/", ts.URL), + token: tc.token, + } + svcCall := svc.On("RevokeToken", mock.Anything, tc.token).Return(tc.err) res, err := req.make() assert.Nil(t, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) assert.Equal(t, tc.status, res.StatusCode, fmt.Sprintf("%s: expected status code %d got %d", tc.desc, tc.status, res.StatusCode)) - repocall.Unset() + svcCall.Unset() } } diff --git a/auth/api/http/keys/requests.go b/auth/api/http/keys/requests.go index 53542c60e6..2500769b26 100644 --- a/auth/api/http/keys/requests.go +++ b/auth/api/http/keys/requests.go @@ -46,3 +46,15 @@ func (req keyReq) validate() error { } return nil } + +type revokeTokenReq struct { + token string +} + +func (req revokeTokenReq) validate() error { + if req.token == "" { + return apiutil.ErrBearerToken + } + + return nil +} diff --git a/auth/api/http/keys/requests_test.go b/auth/api/http/keys/requests_test.go index 6172f24347..2bc5ece50c 100644 --- a/auth/api/http/keys/requests_test.go +++ b/auth/api/http/keys/requests_test.go @@ -86,3 +86,30 @@ func TestKeyReqValidate(t *testing.T) { assert.Equal(t, tc.err, err) } } + +func TestRevokeTokenReqValidate(t *testing.T) { + cases := []struct { + desc string + req revokeTokenReq + err error + }{ + { + desc: "valid request", + req: revokeTokenReq{ + token: valid, + }, + err: nil, + }, + { + desc: "empty token", + req: revokeTokenReq{ + token: "", + }, + err: apiutil.ErrBearerToken, + }, + } + for _, tc := range cases { + err := tc.req.validate() + assert.Equal(t, tc.err, err) + } +} diff --git a/auth/api/http/keys/transport.go b/auth/api/http/keys/transport.go index 9554df3ba1..d09da3ea87 100644 --- a/auth/api/http/keys/transport.go +++ b/auth/api/http/keys/transport.go @@ -33,6 +33,13 @@ func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux { opts..., ).ServeHTTP) + r.Delete("/", kithttp.NewServer( + revokeTokenEndpoint(svc), + decodeRevokeTokenReq, + api.EncodeResponse, + opts..., + ).ServeHTTP) + r.Get("/{id}", kithttp.NewServer( (retrieveEndpoint(svc)), decodeKeyReq, @@ -70,3 +77,11 @@ func decodeKeyReq(_ context.Context, r *http.Request) (interface{}, error) { } return req, nil } + +func decodeRevokeTokenReq(_ context.Context, r *http.Request) (interface{}, error) { + req := revokeTokenReq{ + token: apiutil.ExtractBearerToken(r), + } + + return req, nil +} diff --git a/auth/api/logging.go b/auth/api/logging.go index 7240af2d27..36aadaaf20 100644 --- a/auth/api/logging.go +++ b/auth/api/logging.go @@ -193,6 +193,22 @@ func (lm *loggingMiddleware) Revoke(ctx context.Context, token, id string) (err return lm.svc.Revoke(ctx, token, id) } +func (lm *loggingMiddleware) RevokeToken(ctx context.Context, token string) (err error) { + defer func(begin time.Time) { + args := []any{ + slog.String("duration", time.Since(begin).String()), + } + if err != nil { + args = append(args, slog.Any("error", err)) + lm.logger.Warn("Revoke token failed to complete successfully", args...) + return + } + lm.logger.Info("Revoke token completed successfully", args...) + }(time.Now()) + + return lm.svc.RevokeToken(ctx, token) +} + func (lm *loggingMiddleware) RetrieveKey(ctx context.Context, token, id string) (key auth.Key, err error) { defer func(begin time.Time) { args := []any{ diff --git a/auth/api/metrics.go b/auth/api/metrics.go index 8ed201a82d..22d9831e02 100644 --- a/auth/api/metrics.go +++ b/auth/api/metrics.go @@ -109,6 +109,15 @@ func (ms *metricsMiddleware) Revoke(ctx context.Context, token, id string) error return ms.svc.Revoke(ctx, token, id) } +func (ms *metricsMiddleware) RevokeToken(ctx context.Context, token string) error { + defer func(begin time.Time) { + ms.counter.With("method", "revoke_token").Add(1) + ms.latency.With("method", "revoke_token").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return ms.svc.RevokeToken(ctx, token) +} + func (ms *metricsMiddleware) RetrieveKey(ctx context.Context, token, id string) (auth.Key, error) { defer func(begin time.Time) { ms.counter.With("method", "retrieve_key").Add(1) diff --git a/auth/cache/doc.go b/auth/cache/doc.go new file mode 100644 index 0000000000..6bf2be2e39 --- /dev/null +++ b/auth/cache/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +// Package cache contains the domain concept definitions needed to +// support Magistrala auth cache service functionality. +package cache diff --git a/auth/cache/policies.go b/auth/cache/policies.go new file mode 100644 index 0000000000..3cdc9aebae --- /dev/null +++ b/auth/cache/policies.go @@ -0,0 +1,87 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "context" + "strings" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/redis/go-redis/v9" +) + +const defLimit = 100 + +var _ auth.Cache = (*policiesCache)(nil) + +type policiesCache struct { + client *redis.Client + keyDuration time.Duration +} + +// NewPoliciesCache returns redis auth cache implementation. +func NewPoliciesCache(client *redis.Client, duration time.Duration) auth.Cache { + return &policiesCache{ + client: client, + keyDuration: duration, + } +} + +func (pc *policiesCache) Save(ctx context.Context, key, value string) error { + if err := pc.client.Set(ctx, key, value, pc.keyDuration).Err(); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (pc *policiesCache) Contains(ctx context.Context, key, value string) bool { + rval, err := pc.client.Get(ctx, key).Result() + if err != nil { + return false + } + if rval == value { + return true + } + + return false +} + +func (pc *policiesCache) Remove(ctx context.Context, key string) error { + if strings.Contains(key, "*") { + return pc.delete(ctx, key) + } + + if err := pc.client.Del(ctx, key).Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + return nil +} + +func (pc *policiesCache) delete(ctx context.Context, key string) error { + keys, cursor, err := pc.client.Scan(ctx, 0, key, defLimit).Result() + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + for cursor != 0 { + var newKeys []string + newKeys, cursor, err = pc.client.Scan(ctx, cursor, key, defLimit).Result() + if err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + keys = append(keys, newKeys...) + } + + for _, key := range keys { + if err := pc.client.Del(ctx, key).Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + } + + return nil +} diff --git a/auth/cache/policies_test.go b/auth/cache/policies_test.go new file mode 100644 index 0000000000..b82020cbca --- /dev/null +++ b/auth/cache/policies_test.go @@ -0,0 +1,345 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/auth/cache" + "github.com/absmach/magistrala/internal/testsutil" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +var policy = auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, +} + +func setupRedisCacheClient(t *testing.T) auth.Cache { + opts, err := redis.ParseURL(redisURL) + assert.Nil(t, err, fmt.Sprintf("got unexpected error on parsing redis URL: %s", err)) + redisClient := redis.NewClient(opts) + return cache.NewPoliciesCache(redisClient, 10*time.Minute) +} + +func TestSave(t *testing.T) { + authCache := setupRedisCacheClient(t) + + cases := []struct { + desc string + policy auth.PolicyReq + err error + }{ + { + desc: "Save policy", + policy: policy, + err: nil, + }, + { + desc: "Save already cached policy", + policy: policy, + err: nil, + }, + { + desc: "Save another policy", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Save another policy with domain", + policy: auth.PolicyReq{ + Domain: testsutil.GenerateUUID(&testing.T{}), + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Save policy with long key", + policy: auth.PolicyReq{ + SubjectType: strings.Repeat("a", 513*1024*1024), + Subject: strings.Repeat("a", 513*1024*1024), + ObjectType: strings.Repeat("a", 513*1024*1024), + Object: strings.Repeat("a", 513*1024*1024), + Permission: auth.ViewPermission, + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "Save policy with long value", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: strings.Repeat("a", 513*1024*1024), + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "Save policy with empty key", + policy: auth.PolicyReq{ + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Save policy with empty subject", + policy: auth.PolicyReq{ + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Save policy with empty object", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Save policy with empty value", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + }, + err: nil, + }, + { + desc: "Save policy with empty key and id", + policy: auth.PolicyReq{}, + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + key, val := tc.policy.KV() + err := authCache.Save(context.Background(), key, val) + if err == nil { + ok := authCache.Contains(context.Background(), key, val) + assert.True(t, ok) + } + assert.True(t, errors.Contains(err, tc.err)) + }) + } +} + +func TestContains(t *testing.T) { + authCache := setupRedisCacheClient(t) + + key, val := policy.KV() + err := authCache.Save(context.Background(), key, val) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + + cases := []struct { + desc string + policy auth.PolicyReq + ok bool + }{ + { + desc: "Contains existing policy", + policy: policy, + ok: true, + }, + { + desc: "Contains invalid policy", + policy: auth.PolicyReq{ + SubjectType: policy.SubjectType, + Subject: policy.Subject, + ObjectType: policy.ObjectType, + Object: policy.Object, + Permission: auth.EditPermission, + }, + }, + { + desc: "Contains non existing policy", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + }, + { + desc: "Contains non existing policy with domain", + policy: auth.PolicyReq{ + Domain: testsutil.GenerateUUID(&testing.T{}), + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + }, + { + desc: "Contains policy with empty key", + policy: auth.PolicyReq{ + Permission: auth.ViewPermission, + }, + }, + { + desc: "Contains policy with long key", + policy: auth.PolicyReq{ + SubjectType: strings.Repeat("a", 513*1024*1024), + Subject: strings.Repeat("a", 513*1024*1024), + ObjectType: strings.Repeat("a", 513*1024*1024), + Object: strings.Repeat("a", 513*1024*1024), + Permission: auth.ViewPermission, + }, + }, + { + desc: "Contains policy with empty value", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + key, val := tc.policy.KV() + ok := authCache.Contains(context.Background(), key, val) + assert.Equal(t, tc.ok, ok) + }) + } +} + +func TestRemove(t *testing.T) { + authCache := setupRedisCacheClient(t) + + subject := policy.Subject + object := policy.Object + + num := 200 + var policies []auth.PolicyReq + for i := 0; i < num; i++ { + policy.Subject = fmt.Sprintf("%s-%d", policy.Subject, i) + policy.Object = fmt.Sprintf("%s-%d", policy.Object, i) + key, val := policy.KV() + err := authCache.Save(context.Background(), key, val) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + policies = append(policies, policy) + } + + cases := []struct { + desc string + multiple bool + policy auth.PolicyReq + err error + }{ + { + desc: "Remove an existing policy from cache", + policy: policies[0], + err: nil, + }, + { + desc: "Remove multiple existing policies from cache with subject", + multiple: true, + policy: auth.PolicyReq{ + Subject: subject, + }, + err: nil, + }, + { + desc: "Remove multiple existing policies from cache with object", + multiple: true, + policy: auth.PolicyReq{ + Object: object, + }, + err: nil, + }, + { + desc: "Remove non existing policy from cache", + policy: auth.PolicyReq{ + SubjectType: auth.UserType, + Subject: testsutil.GenerateUUID(&testing.T{}), + ObjectType: auth.ThingType, + Object: testsutil.GenerateUUID(&testing.T{}), + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Remove policy with empty key from cache", + policy: auth.PolicyReq{ + Permission: auth.ViewPermission, + }, + err: nil, + }, + { + desc: "Remove policy with long key from cache", + policy: auth.PolicyReq{ + SubjectType: strings.Repeat("a", 513*1024*1024), + Subject: strings.Repeat("a", 513*1024*1024), + ObjectType: strings.Repeat("a", 513*1024*1024), + Object: strings.Repeat("a", 513*1024*1024), + Permission: auth.ViewPermission, + }, + err: repoerr.ErrRemoveEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := authCache.Remove(context.Background(), tc.policy.KeyForRemoval()) + assert.True(t, errors.Contains(err, tc.err)) + if err == nil { + key, val := tc.policy.KV() + ok := authCache.Contains(context.Background(), key, val) + assert.False(t, ok) + if tc.multiple { + switch { + case tc.policy.Subject != "": + for _, p := range policies { + if strings.HasPrefix(p.Subject, subject) { + key, val := p.KV() + ok := authCache.Contains(context.Background(), key, val) + assert.False(t, ok) + } + } + case tc.policy.Object != "": + for _, p := range policies { + if strings.HasPrefix(p.Object, object) { + key, val := p.KV() + ok := authCache.Contains(context.Background(), key, val) + assert.False(t, ok) + } + } + } + } + } + }) + } +} diff --git a/auth/cache/setup_test.go b/auth/cache/setup_test.go new file mode 100644 index 0000000000..078c3ec758 --- /dev/null +++ b/auth/cache/setup_test.go @@ -0,0 +1,75 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache_test + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + "testing" + + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/redis/go-redis/v9" +) + +var redisURL string + +func TestMain(m *testing.M) { + pool, err := dockertest.NewPool("") + if err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + container, err := pool.RunWithOptions(&dockertest.RunOptions{ + Repository: "redis", + Tag: "7.2.4-alpine", + }, func(config *docker.HostConfig) { + config.AutoRemove = true + config.RestartPolicy = docker.RestartPolicy{Name: "no"} + }) + if err != nil { + log.Fatalf("Could not start container: %s", err) + } + + handleInterrupt(pool, container) + + redisURL = fmt.Sprintf("redis://localhost:%s/0", container.GetPort("6379/tcp")) + opts, err := redis.ParseURL(redisURL) + if err != nil { + log.Fatalf("Could not parse redis URL: %s", err) + } + + if err := pool.Retry(func() error { + redisClient := redis.NewClient(opts) + + return redisClient.Ping(context.Background()).Err() + }); err != nil { + log.Fatalf("Could not connect to docker: %s", err) + } + + code := m.Run() + + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + + os.Exit(code) +} + +func handleInterrupt(pool *dockertest.Pool, container *dockertest.Resource) { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + go func() { + <-c + if err := pool.Purge(container); err != nil { + log.Fatalf("Could not purge container: %s", err) + } + os.Exit(0) + }() +} diff --git a/auth/cache/tokens.go b/auth/cache/tokens.go new file mode 100644 index 0000000000..fa54e1184e --- /dev/null +++ b/auth/cache/tokens.go @@ -0,0 +1,56 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache + +import ( + "context" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/redis/go-redis/v9" +) + +const defKey = "revoked_tokens" + +var _ auth.Cache = (*tokensCache)(nil) + +type tokensCache struct { + client *redis.Client + keyDuration time.Duration +} + +// NewTokensCache returns redis auth cache implementation. +func NewTokensCache(client *redis.Client, duration time.Duration) auth.Cache { + return &tokensCache{ + client: client, + keyDuration: duration, + } +} + +func (tc *tokensCache) Save(ctx context.Context, _, value string) error { + if err := tc.client.SAdd(ctx, defKey, value, tc.keyDuration).Err(); err != nil { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (tc *tokensCache) Contains(ctx context.Context, _, value string) bool { + ok, err := tc.client.SIsMember(ctx, defKey, value).Result() + if err != nil { + return false + } + + return ok +} + +func (tc *tokensCache) Remove(ctx context.Context, value string) error { + if err := tc.client.SRem(ctx, defKey, value).Err(); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + + return nil +} diff --git a/auth/cache/tokens_test.go b/auth/cache/tokens_test.go new file mode 100644 index 0000000000..8f9902073f --- /dev/null +++ b/auth/cache/tokens_test.go @@ -0,0 +1,184 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package cache_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/auth/cache" + "github.com/absmach/magistrala/internal/testsutil" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" +) + +var key = auth.Key{ + ID: testsutil.GenerateUUID(&testing.T{}), +} + +func setupRedisTokensClient(t *testing.T) auth.Cache { + opts, err := redis.ParseURL(redisURL) + assert.Nil(t, err, fmt.Sprintf("got unexpected error on parsing redis URL: %s", err)) + redisClient := redis.NewClient(opts) + return cache.NewPoliciesCache(redisClient, 10*time.Minute) +} + +func TestTokenSave(t *testing.T) { + tokensCache := setupRedisTokensClient(t) + + cases := []struct { + desc string + key auth.Key + err error + }{ + { + desc: "Save token", + key: key, + err: nil, + }, + { + desc: "Save already cached policy", + key: key, + err: nil, + }, + { + desc: "Save another policy", + key: auth.Key{ + ID: testsutil.GenerateUUID(&testing.T{}), + }, + err: nil, + }, + { + desc: "Save policy with long key", + key: auth.Key{ + ID: strings.Repeat("a", 513*1024*1024), + }, + err: repoerr.ErrCreateEntity, + }, + { + desc: "Save policy with empty key", + err: nil, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.Save(context.Background(), "", tc.key.ID) + if err == nil { + ok := tokensCache.Contains(context.Background(), "", tc.key.ID) + assert.True(t, ok) + } + assert.True(t, errors.Contains(err, tc.err)) + }) + } +} + +func TestTokenContains(t *testing.T) { + tokensCache := setupRedisTokensClient(t) + + err := tokensCache.Save(context.Background(), "", key.ID) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + + cases := []struct { + desc string + key auth.Key + ok bool + }{ + { + desc: "Contains existing key", + key: key, + ok: true, + }, + { + desc: "Contains non existing key", + key: auth.Key{ + ID: testsutil.GenerateUUID(&testing.T{}), + }, + }, + { + desc: "Contains key with long id", + key: auth.Key{ + ID: strings.Repeat("a", 513*1024*1024), + }, + }, + { + desc: "Contains key with empty id", + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + ok := tokensCache.Contains(context.Background(), "", tc.key.ID) + assert.Equal(t, tc.ok, ok) + }) + } +} + +func TestTokenRemove(t *testing.T) { + tokensCache := setupRedisTokensClient(t) + + num := 1000 + var ids []string + for i := 0; i < num; i++ { + id := testsutil.GenerateUUID(&testing.T{}) + err := tokensCache.Save(context.Background(), "", id) + assert.Nil(t, err, fmt.Sprintf("Unexpected error while trying to save: %s", err)) + ids = append(ids, id) + } + + cases := []struct { + desc string + id string + err error + }{ + { + desc: "Remove an existing id from cache", + id: ids[0], + err: nil, + }, + { + desc: "Remove multiple existing id from cache", + id: "*", + err: nil, + }, + { + desc: "Remove non existing id from cache", + id: testsutil.GenerateUUID(&testing.T{}), + err: nil, + }, + { + desc: "Remove policy with empty id from cache", + err: nil, + }, + { + desc: "Remove policy with long id from cache", + id: strings.Repeat("a", 513*1024*1024), + err: repoerr.ErrRemoveEntity, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + err := tokensCache.Remove(context.Background(), tc.id) + assert.True(t, errors.Contains(err, tc.err)) + if tc.id == "*" { + for _, id := range ids { + ok := tokensCache.Contains(context.Background(), "", id) + assert.False(t, ok) + } + return + } + if err == nil { + ok := tokensCache.Contains(context.Background(), "", tc.id) + assert.False(t, ok) + } + }) + } +} diff --git a/auth/events/streams.go b/auth/events/streams.go index 0081a962fa..3778c9dc10 100644 --- a/auth/events/streams.go +++ b/auth/events/streams.go @@ -203,6 +203,10 @@ func (es *eventStore) Revoke(ctx context.Context, token, id string) error { return es.svc.Revoke(ctx, token, id) } +func (es *eventStore) RevokeToken(ctx context.Context, token string) error { + return es.svc.RevokeToken(ctx, token) +} + func (es *eventStore) RetrieveKey(ctx context.Context, token, id string) (auth.Key, error) { return es.svc.RetrieveKey(ctx, token, id) } diff --git a/auth/jwt/token_test.go b/auth/jwt/token_test.go index 461adb95be..dafc990489 100644 --- a/auth/jwt/token_test.go +++ b/auth/jwt/token_test.go @@ -4,14 +4,17 @@ package jwt_test import ( + "context" "fmt" "testing" "time" "github.com/absmach/magistrala/auth" authjwt "github.com/absmach/magistrala/auth/jwt" + "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/internal/testsutil" "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" svcerr "github.com/absmach/magistrala/pkg/errors/service" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" @@ -55,7 +58,9 @@ func newToken(issuerName string, key auth.Key) string { } func TestIssue(t *testing.T) { - tokenizer := authjwt.New([]byte(secret)) + repo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + tokenizer := authjwt.New([]byte(secret), repo, cache) cases := []struct { desc string @@ -128,7 +133,9 @@ func TestIssue(t *testing.T) { } func TestParse(t *testing.T) { - tokenizer := authjwt.New([]byte(secret)) + repo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + tokenizer := authjwt.New([]byte(secret), repo, cache) token, err := tokenizer.Issue(key()) require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err)) @@ -162,11 +169,19 @@ func TestParse(t *testing.T) { inValidToken := newToken("invalid", key()) + refreshKey := key() + refreshKey.Type = auth.RefreshKey + refreshToken, err := tokenizer.Issue(refreshKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + cases := []struct { - desc string - key auth.Key - token string - err error + desc string + key auth.Key + token string + cacheContains bool + repoContains bool + cacheSave error + err error }{ { desc: "parse valid key", @@ -222,14 +237,191 @@ func TestParse(t *testing.T) { token: emptyToken, err: nil, }, + { + desc: "parse refresh token", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: false, + err: nil, + }, + { + desc: "parse revoked refresh token in cache", + key: refreshKey, + token: refreshToken, + cacheContains: true, + repoContains: false, + err: svcerr.ErrAuthentication, + }, + { + desc: "parse revoked refresh token not in cache", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: true, + err: svcerr.ErrAuthentication, + }, + { + desc: "parse revoked refresh token failed to save in cache", + key: refreshKey, + token: refreshToken, + cacheContains: false, + repoContains: true, + cacheSave: repoerr.ErrCreateEntity, + err: svcerr.ErrAuthentication, + }, } for _, tc := range cases { - key, err := tokenizer.Parse(tc.token) + cacheCall := cache.On("Contains", context.Background(), "", tc.key.ID).Return(tc.cacheContains) + repoCall := repo.On("Contains", context.Background(), tc.key.ID).Return(tc.repoContains) + cacheCall1 := cache.On("Save", context.Background(), "", tc.key.ID).Return(tc.cacheSave) + key, err := tokenizer.Parse(context.Background(), tc.token) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) if err == nil { assert.Equal(t, tc.key, key, fmt.Sprintf("%s expected %v, got %v", tc.desc, tc.key, key)) } + cacheCall.Unset() + repoCall.Unset() + cacheCall1.Unset() + } +} + +func TestRevoke(t *testing.T) { + repo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + tokenizer := authjwt.New([]byte(secret), repo, cache) + + token, err := tokenizer.Issue(key()) + require.Nil(t, err, fmt.Sprintf("issuing key expected to succeed: %s", err)) + + apiKey := key() + apiKey.Type = auth.APIKey + apiKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second) + apiToken, err := tokenizer.Issue(apiKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + expKey := key() + expKey.ExpiresAt = time.Now().UTC().Add(-1 * time.Minute).Round(time.Second) + expToken, err := tokenizer.Issue(expKey) + require.Nil(t, err, fmt.Sprintf("issuing expired key expected to succeed: %s", err)) + + emptyDomainKey := key() + emptyDomainKey.Domain = "" + emptyDomainToken, err := tokenizer.Issue(emptyDomainKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + emptySubjectKey := key() + emptySubjectKey.Subject = "" + emptySubjectToken, err := tokenizer.Issue(emptySubjectKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + emptyKey := key() + emptyKey.Domain = "" + emptyKey.Subject = "" + emptyToken, err := tokenizer.Issue(emptyKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + inValidToken := newToken("invalid", key()) + + refreshKey := key() + refreshKey.Type = auth.RefreshKey + refreshToken, err := tokenizer.Issue(refreshKey) + require.Nil(t, err, fmt.Sprintf("issuing user key expected to succeed: %s", err)) + + cases := []struct { + desc string + key auth.Key + token string + repoErr error + cacheErr error + err error + }{ + { + desc: "revoke valid key", + key: key(), + token: token, + err: nil, + }, + { + desc: "revoke invalid key", + key: auth.Key{}, + token: "invalid", + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke expired key", + key: auth.Key{}, + token: expToken, + err: auth.ErrExpiry, + }, + { + desc: "revoke expired API key", + key: apiKey, + token: apiToken, + err: auth.ErrExpiry, + }, + { + desc: "revoke token with invalid issuer", + key: auth.Key{}, + token: inValidToken, + err: errInvalidIssuer, + }, + { + desc: "revoke token with invalid content", + key: auth.Key{}, + token: newToken(issuerName, key()), + err: authjwt.ErrJSONHandle, + }, + { + desc: "revoke token with empty domain", + key: emptyDomainKey, + token: emptyDomainToken, + err: nil, + }, + { + desc: "revoke token with empty subject", + key: emptySubjectKey, + token: emptySubjectToken, + err: nil, + }, + { + desc: "revoke token with empty domain and subject", + key: emptyKey, + token: emptyToken, + err: nil, + }, + { + desc: "revoke refresh token", + key: refreshKey, + token: refreshToken, + err: nil, + }, + { + desc: "revoke revoked refresh token failed to save in cache", + key: refreshKey, + token: refreshToken, + repoErr: nil, + cacheErr: repoerr.ErrCreateEntity, + err: svcerr.ErrAuthentication, + }, + { + desc: "revoke revoked refresh token failed to save in cache", + key: refreshKey, + token: refreshToken, + repoErr: repoerr.ErrCreateEntity, + cacheErr: nil, + err: svcerr.ErrAuthentication, + }, + } + + for _, tc := range cases { + repoCall := repo.On("Save", context.Background(), tc.key.ID).Return(tc.repoErr) + cacheCall := cache.On("Save", context.Background(), "", tc.key.ID).Return(tc.cacheErr) + err := tokenizer.Revoke(context.Background(), tc.token) + assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s, got %s", tc.desc, tc.err, err)) + cacheCall.Unset() + repoCall.Unset() } } diff --git a/auth/jwt/tokenizer.go b/auth/jwt/tokenizer.go index ad79549016..0c48ac4577 100644 --- a/auth/jwt/tokenizer.go +++ b/auth/jwt/tokenizer.go @@ -26,6 +26,8 @@ var ( ErrValidateJWTToken = errors.New("failed to validate jwt token") // ErrJSONHandle indicates an error in handling JSON. ErrJSONHandle = errors.New("failed to perform operation JSON") + // errRevokedToken indicates that the token is revoked. + errRevokedToken = errors.New("token is revoked") ) const ( @@ -40,14 +42,18 @@ const ( type tokenizer struct { secret []byte + cache auth.Cache + repo auth.TokenRepository } var _ auth.Tokenizer = (*tokenizer)(nil) // NewRepository instantiates an implementation of Token repository. -func New(secret []byte) auth.Tokenizer { +func New(secret []byte, repo auth.TokenRepository, cache auth.Cache) auth.Tokenizer { return &tokenizer{ secret: secret, + repo: repo, + cache: cache, } } @@ -79,7 +85,7 @@ func (tok *tokenizer) Issue(key auth.Key) (string, error) { return string(signedTkn), nil } -func (tok *tokenizer) Parse(token string) (auth.Key, error) { +func (tok *tokenizer) Parse(ctx context.Context, token string) (auth.Key, error) { tkn, err := tok.validateToken(token) if err != nil { return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) @@ -90,9 +96,48 @@ func (tok *tokenizer) Parse(token string) (auth.Key, error) { return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) } + if key.Type == auth.RefreshKey { + switch tok.cache.Contains(ctx, "", key.ID) { + case true: + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedToken) + default: + if ok := tok.repo.Contains(ctx, key.ID); ok { + if err := tok.cache.Save(ctx, "", key.ID); err != nil { + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, err) + } + + return auth.Key{}, errors.Wrap(svcerr.ErrAuthentication, errRevokedToken) + } + } + } + return key, nil } +func (tok *tokenizer) Revoke(ctx context.Context, token string) error { + tkn, err := tok.validateToken(token) + if err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + key, err := toKey(tkn) + if err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + if key.Type == auth.RefreshKey { + if err := tok.repo.Save(ctx, key.ID); err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + + if err := tok.cache.Save(ctx, "", key.ID); err != nil { + return errors.Wrap(svcerr.ErrAuthentication, err) + } + } + + return nil +} + func (tok *tokenizer) validateToken(token string) (jwt.Token, error) { tkn, err := jwt.Parse( []byte(token), diff --git a/auth/mocks/cache.go b/auth/mocks/cache.go new file mode 100644 index 0000000000..f68b885bdb --- /dev/null +++ b/auth/mocks/cache.go @@ -0,0 +1,84 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// Cache is an autogenerated mock type for the Cache type +type Cache struct { + mock.Mock +} + +// Contains provides a mock function with given fields: ctx, key, value +func (_m *Cache) Contains(ctx context.Context, key string, value string) bool { + ret := _m.Called(ctx, key, value) + + if len(ret) == 0 { + panic("no return value specified for Contains") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, string, string) bool); ok { + r0 = rf(ctx, key, value) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Remove provides a mock function with given fields: ctx, key +func (_m *Cache) Remove(ctx context.Context, key string) error { + ret := _m.Called(ctx, key) + + if len(ret) == 0 { + panic("no return value specified for Remove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, key) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Save provides a mock function with given fields: ctx, key, value +func (_m *Cache) Save(ctx context.Context, key string, value string) error { + ret := _m.Called(ctx, key, value) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) error); ok { + r0 = rf(ctx, key, value) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewCache creates a new instance of Cache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewCache(t interface { + mock.TestingT + Cleanup(func()) +}) *Cache { + mock := &Cache{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/mocks/service.go b/auth/mocks/service.go index 45d3b3d7f9..919592f5d9 100644 --- a/auth/mocks/service.go +++ b/auth/mocks/service.go @@ -613,6 +613,24 @@ func (_m *Service) Revoke(ctx context.Context, token string, id string) error { return r0 } +// RevokeToken provides a mock function with given fields: ctx, token +func (_m *Service) RevokeToken(ctx context.Context, token string) error { + ret := _m.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for RevokeToken") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, token) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UnassignUser provides a mock function with given fields: ctx, token, id, userID func (_m *Service) UnassignUser(ctx context.Context, token string, id string, userID string) error { ret := _m.Called(ctx, token, id, userID) diff --git a/auth/mocks/token.go b/auth/mocks/token.go new file mode 100644 index 0000000000..0f4dfb8bf0 --- /dev/null +++ b/auth/mocks/token.go @@ -0,0 +1,66 @@ +// Code generated by mockery v2.43.2. DO NOT EDIT. + +// Copyright (c) Abstract Machines + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// TokenRepository is an autogenerated mock type for the TokenRepository type +type TokenRepository struct { + mock.Mock +} + +// Contains provides a mock function with given fields: ctx, id +func (_m *TokenRepository) Contains(ctx context.Context, id string) bool { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Contains") + } + + var r0 bool + if rf, ok := ret.Get(0).(func(context.Context, string) bool); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// Save provides a mock function with given fields: ctx, id +func (_m *TokenRepository) Save(ctx context.Context, id string) error { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for Save") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, id) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewTokenRepository creates a new instance of TokenRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTokenRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *TokenRepository { + mock := &TokenRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/auth/policies.go b/auth/policies.go index e2e416aed2..84f032bf47 100644 --- a/auth/policies.go +++ b/auth/policies.go @@ -104,6 +104,34 @@ func (pr PolicyReq) String() string { return string(data) } +// KV returns the key-value pair for the given PolicyReq. +func (pr PolicyReq) KV() (string, string) { + var key, val string + switch pr.Domain { + case "": + key = pr.SubjectType + ":" + pr.Subject + ":" + pr.ObjectType + ":" + pr.Object + default: + key = pr.Domain + ":" + pr.SubjectType + ":" + pr.Subject + ":" + pr.ObjectType + ":" + pr.Object + } + val = pr.Permission + + return key, val +} + +// KeyForRemoval returns the key for the given PolicyReq. It is used +// to remove a key from the cache. +func (pr PolicyReq) KeyForRemoval() string { + switch { + case pr.Subject != "" && pr.Object == "": + return "*" + pr.Subject + "*" + case pr.Object != "" && pr.Subject == "": + return "*" + pr.Object + "*" + default: + key, _ := pr.KV() + return key + } +} + type PolicyRes struct { Namespace string Subject string @@ -221,3 +249,18 @@ type PolicyAgent interface { // (ctx context.Context, pr PolicyReq, filterPermissions []string) ([]PolicyReq, error) RetrievePermissions(ctx context.Context, pr PolicyReq, filterPermission []string) (Permissions, error) } + +// Cache represents a cache repository. It exposes functionalities +// through `auth` to perform caching. +// +//go:generate mockery --name Cache --output=./mocks --filename cache.go --quiet --note "Copyright (c) Abstract Machines" +type Cache interface { + // Save saves the key-value pair in the cache. + Save(ctx context.Context, key, value string) error + + // Contains checks if the key-value pair exists in the cache. + Contains(ctx context.Context, key, value string) bool + + // Remove removes the key from the cache. + Remove(ctx context.Context, key string) error +} diff --git a/auth/postgres/init.go b/auth/postgres/init.go index ae69c3a0ca..eca0ad9132 100644 --- a/auth/postgres/init.go +++ b/auth/postgres/init.go @@ -57,6 +57,17 @@ func Migration() *migrate.MemoryMigrationSource { `ALTER TABLE domains ALTER COLUMN alias SET NOT NULL`, }, }, + { + Id: "auth_3", + Up: []string{ + `CREATE TABLE IF NOT EXISTS tokens ( + id VARCHAR(36) PRIMARY KEY + );`, + }, + Down: []string{ + `DROP TABLE IF EXISTS tokens`, + }, + }, }, } } diff --git a/auth/postgres/token.go b/auth/postgres/token.go new file mode 100644 index 0000000000..f63182a50e --- /dev/null +++ b/auth/postgres/token.go @@ -0,0 +1,61 @@ +// Copyright (c) Abstract Machines +// SPDX-License-Identifier: Apache-2.0 + +package postgres + +import ( + "context" + + "github.com/absmach/magistrala/auth" + "github.com/absmach/magistrala/pkg/errors" + repoerr "github.com/absmach/magistrala/pkg/errors/repository" + "github.com/absmach/magistrala/pkg/postgres" +) + +var _ auth.TokenRepository = (*tokenRepo)(nil) + +type tokenRepo struct { + db postgres.Database +} + +// NewTokensRepository instantiates a PostgreSQL implementation of tokens repository. +func NewTokensRepository(db postgres.Database) auth.TokenRepository { + return &tokenRepo{ + db: db, + } +} + +func (repo *tokenRepo) Save(ctx context.Context, id string) error { + q := `INSERT INTO tokens (id) VALUES ($1);` + + result, err := repo.db.ExecContext(ctx, q, id) + if err != nil { + return postgres.HandleError(repoerr.ErrCreateEntity, err) + } + if rows, err := result.RowsAffected(); rows == 0 { + return errors.Wrap(repoerr.ErrCreateEntity, err) + } + + return nil +} + +func (repo *tokenRepo) Contains(ctx context.Context, id string) bool { + q := `SELECT * FROM tokens WHERE id = $1;` + + rows, err := repo.db.QueryContext(ctx, q, id) + if err != nil { + return false + } + defer rows.Close() + + if rows.Next() { + id := "" + if err = rows.Scan(&id); err != nil { + return false + } + + return true + } + + return false +} diff --git a/auth/service.go b/auth/service.go index 3cf67c44e3..67ed7bae96 100644 --- a/auth/service.go +++ b/auth/service.go @@ -85,6 +85,9 @@ type Authn interface { // issued by the user identified by the provided key. Revoke(ctx context.Context, token, id string) error + // RevokeToken revokes the token. + RevokeToken(ctx context.Context, token string) error + // RetrieveKey retrieves data for the Key identified by the provided // ID, that is issued by the user identified by the provided key. RetrieveKey(ctx context.Context, token, id string) (Key, error) @@ -136,6 +139,12 @@ func New(keys KeyRepository, domains DomainsRepository, idp magistrala.IDProvide func (svc service) Issue(ctx context.Context, token string, key Key) (Token, error) { key.IssuedAt = time.Now().UTC() + id, err := svc.idProvider.ID() + if err != nil { + return Token{}, errors.Wrap(errIssueUser, err) + } + key.ID = id + switch key.Type { case APIKey: return svc.userKey(ctx, token, key) @@ -151,7 +160,7 @@ func (svc service) Issue(ctx context.Context, token string, key Key) (Token, err } func (svc service) Revoke(ctx context.Context, token, id string) error { - issuerID, _, err := svc.authenticate(token) + issuerID, _, err := svc.authenticate(ctx, token) if err != nil { return errors.Wrap(errRevoke, err) } @@ -161,8 +170,12 @@ func (svc service) Revoke(ctx context.Context, token, id string) error { return nil } +func (svc service) RevokeToken(ctx context.Context, token string) error { + return svc.tokenizer.Revoke(ctx, token) +} + func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, error) { - issuerID, _, err := svc.authenticate(token) + issuerID, _, err := svc.authenticate(ctx, token) if err != nil { return Key{}, errors.Wrap(errRetrieve, err) } @@ -175,7 +188,7 @@ func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, erro } func (svc service) Identify(ctx context.Context, token string) (Key, error) { - key, err := svc.tokenizer.Parse(token) + key, err := svc.tokenizer.Parse(ctx, token) if errors.Contains(err, ErrExpiry) { err = svc.keys.Remove(ctx, key.Issuer, key.ID) return Key{}, errors.Wrap(svcerr.ErrAuthentication, errors.Wrap(ErrKeyExpired, err)) @@ -464,7 +477,7 @@ func (svc service) invitationKey(ctx context.Context, key Key) (Token, error) { } func (svc service) refreshKey(ctx context.Context, token string, key Key) (Token, error) { - k, err := svc.tokenizer.Parse(token) + k, err := svc.tokenizer.Parse(ctx, token) if err != nil { return Token{}, errors.Wrap(errRetrieve, err) } @@ -528,7 +541,7 @@ func (svc service) checkUserDomain(ctx context.Context, key Key) (subject string } func (svc service) userKey(ctx context.Context, token string, key Key) (Token, error) { - id, sub, err := svc.authenticate(token) + id, sub, err := svc.authenticate(ctx, token) if err != nil { return Token{}, errors.Wrap(errIssueUser, err) } @@ -538,12 +551,6 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e key.Subject = sub } - keyID, err := svc.idProvider.ID() - if err != nil { - return Token{}, errors.Wrap(errIssueUser, err) - } - key.ID = keyID - if _, err := svc.keys.Save(ctx, key); err != nil { return Token{}, errors.Wrap(errIssueUser, err) } @@ -556,8 +563,8 @@ func (svc service) userKey(ctx context.Context, token string, key Key) (Token, e return Token{AccessToken: tkn}, nil } -func (svc service) authenticate(token string) (string, string, error) { - key, err := svc.tokenizer.Parse(token) +func (svc service) authenticate(ctx context.Context, token string) (string, string, error) { + key, err := svc.tokenizer.Parse(ctx, token) if err != nil { return "", "", errors.Wrap(svcerr.ErrAuthentication, err) } diff --git a/auth/service_test.go b/auth/service_test.go index 7547c87db5..cba530b63b 100644 --- a/auth/service_test.go +++ b/auth/service_test.go @@ -63,13 +63,15 @@ var ( drepo *mocks.DomainsRepository ) -func newService() (auth.Service, string) { +func newService() (auth.Service, *mocks.TokenRepository, *mocks.Cache, string) { krepo = new(mocks.KeyRepository) + trepo := new(mocks.TokenRepository) + cache := new(mocks.Cache) prepo = new(mocks.PolicyAgent) drepo = new(mocks.DomainsRepository) idProvider := uuid.NewMock() - t := jwt.New([]byte(secret)) + t := jwt.New([]byte(secret), trepo, cache) key := auth.Key{ IssuedAt: time.Now(), ExpiresAt: time.Now().Add(refreshDuration), @@ -80,13 +82,25 @@ func newService() (auth.Service, string) { } token, _ := t.Issue(key) - return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration), token + return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration), trepo, cache, token } -func TestIssue(t *testing.T) { - svc, accessToken := newService() +func newMinimalService() auth.Service { + krepo = new(mocks.KeyRepository) + trepo := new(mocks.TokenRepository) + cache := new(mocks.Cache) + prepo = new(mocks.PolicyAgent) + drepo = new(mocks.DomainsRepository) + idProvider := uuid.NewMock() + + t := jwt.New([]byte(secret), trepo, cache) - n := jwt.New([]byte(secret)) + return auth.New(krepo, drepo, idProvider, t, prepo, loginDuration, refreshDuration, invalidDuration) +} + +func TestIssue(t *testing.T) { + svc, trepo, cache, accessToken := newService() + n := jwt.New([]byte(secret), trepo, cache) apikey := auth.Key{ IssuedAt: time.Now(), @@ -379,6 +393,9 @@ func TestIssue(t *testing.T) { checkDOmainPolicyReq auth.PolicyReq checkPolicyErr error retrieveByIDErr error + cacheContains bool + repoContains bool + cacheSave error err error }{ { @@ -492,21 +509,82 @@ func TestIssue(t *testing.T) { retrieveByIDErr: repoerr.ErrNotFound, err: svcerr.ErrDomainAuthorization, }, + { + desc: "issue revoked refresh key in cache", + key: auth.Key{ + Type: auth.RefreshKey, + IssuedAt: time.Now(), + }, + checkPolicyRequest: auth.PolicyReq{ + Subject: email, + SubjectType: auth.UserType, + Object: auth.MagistralaObject, + ObjectType: auth.PlatformType, + Permission: auth.AdminPermission, + }, + cacheContains: true, + repoContains: false, + token: refreshToken, + err: svcerr.ErrAuthentication, + }, + { + desc: "issue revoked refresh key not in cache", + key: auth.Key{ + Type: auth.RefreshKey, + IssuedAt: time.Now(), + }, + checkPolicyRequest: auth.PolicyReq{ + Subject: email, + SubjectType: auth.UserType, + Object: auth.MagistralaObject, + ObjectType: auth.PlatformType, + Permission: auth.AdminPermission, + }, + cacheContains: false, + repoContains: true, + token: refreshToken, + err: svcerr.ErrAuthentication, + }, + { + desc: "issue revoked refresh key failed to save in cache", + key: auth.Key{ + Type: auth.RefreshKey, + IssuedAt: time.Now(), + }, + checkPolicyRequest: auth.PolicyReq{ + Subject: email, + SubjectType: auth.UserType, + Object: auth.MagistralaObject, + ObjectType: auth.PlatformType, + Permission: auth.AdminPermission, + }, + cacheContains: false, + repoContains: true, + cacheSave: repoerr.ErrCreateEntity, + token: refreshToken, + err: svcerr.ErrAuthentication, + }, } for _, tc := range cases4 { - repoCall := prepo.On("CheckPolicy", mock.Anything, tc.checkPolicyRequest).Return(tc.checkPolicyErr) - repoCall1 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retrieveByIDErr) - repoCall2 := prepo.On("CheckPolicy", mock.Anything, tc.checkDOmainPolicyReq).Return(tc.checkPolicyErr) + cacheCall := cache.On("Contains", context.Background(), "", refreshkey.ID).Return(tc.cacheContains) + repoCall := trepo.On("Contains", context.Background(), refreshkey.ID).Return(tc.repoContains) + cacheCall1 := cache.On("Save", context.Background(), "", refreshkey.ID).Return(tc.cacheSave) + repoCall1 := prepo.On("CheckPolicy", mock.Anything, tc.checkPolicyRequest).Return(tc.checkPolicyErr) + repoCall2 := drepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(auth.Domain{}, tc.retrieveByIDErr) + repoCall3 := prepo.On("CheckPolicy", mock.Anything, tc.checkDOmainPolicyReq).Return(tc.checkPolicyErr) _, err := svc.Issue(context.Background(), tc.token, tc.key) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s expected %s got %s\n", tc.desc, tc.err, err)) + cacheCall.Unset() + cacheCall1.Unset() repoCall.Unset() repoCall1.Unset() repoCall2.Unset() + repoCall3.Unset() } } func TestRevoke(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, errIssueUser) secret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) repocall.Unset() @@ -559,7 +637,7 @@ func TestRevoke(t *testing.T) { } func TestRetrieve(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) secret, err := svc.Issue(context.Background(), "", auth.Key{Type: auth.AccessKey, IssuedAt: time.Now(), Subject: id}) assert.Nil(t, err, fmt.Sprintf("Issuing login key expected to succeed: %s", err)) @@ -629,7 +707,7 @@ func TestRetrieve(t *testing.T) { } func TestIdentify(t *testing.T) { - svc, _ := newService() + svc, trepo, cache, _ := newService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) repocall1 := prepo.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil) @@ -655,7 +733,7 @@ func TestIdentify(t *testing.T) { assert.Nil(t, err, fmt.Sprintf("Issuing expired login key expected to succeed: %s", err)) repocall4.Unset() - te := jwt.New([]byte(secret)) + te := jwt.New([]byte(secret), trepo, cache) key := auth.Key{ IssuedAt: time.Now(), ExpiresAt: time.Now().Add(refreshDuration), @@ -734,7 +812,7 @@ func TestIdentify(t *testing.T) { } func TestAuthorize(t *testing.T) { - svc, accessToken := newService() + svc, trepo, cache, accessToken := newService() repocall := krepo.On("Save", mock.Anything, mock.Anything).Return(mock.Anything, nil) repocall1 := prepo.On("CheckPolicy", mock.Anything, mock.Anything).Return(nil) @@ -755,7 +833,7 @@ func TestAuthorize(t *testing.T) { repocall2.Unset() repocall3.Unset() - te := jwt.New([]byte(secret)) + te := jwt.New([]byte(secret), trepo, cache) key := auth.Key{ IssuedAt: time.Now(), ExpiresAt: time.Now().Add(refreshDuration), @@ -1198,7 +1276,7 @@ func TestAuthorize(t *testing.T) { } func TestAddPolicy(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() cases := []struct { desc string @@ -1240,7 +1318,7 @@ func TestAddPolicy(t *testing.T) { } func TestAddPolicies(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() cases := []struct { desc string @@ -1302,7 +1380,7 @@ func TestAddPolicies(t *testing.T) { } func TestDeletePolicy(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() cases := []struct { desc string @@ -1344,7 +1422,7 @@ func TestDeletePolicy(t *testing.T) { } func TestDeletePolicies(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() cases := []struct { desc string @@ -1406,7 +1484,7 @@ func TestDeletePolicies(t *testing.T) { } func TestListObjects(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() pageLen := 15 expectedPolicies := make([]auth.PolicyRes, pageLen) @@ -1459,7 +1537,7 @@ func TestListObjects(t *testing.T) { } func TestListAllObjects(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() pageLen := 15 expectedPolicies := make([]auth.PolicyRes, pageLen) @@ -1512,7 +1590,7 @@ func TestListAllObjects(t *testing.T) { } func TestCountObjects(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() pageLen := uint64(15) @@ -1524,7 +1602,7 @@ func TestCountObjects(t *testing.T) { } func TestListSubjects(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() pageLen := 15 expectedPolicies := make([]auth.PolicyRes, pageLen) @@ -1577,7 +1655,7 @@ func TestListSubjects(t *testing.T) { } func TestListAllSubjects(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() pageLen := 15 expectedPolicies := make([]auth.PolicyRes, pageLen) @@ -1630,7 +1708,7 @@ func TestListAllSubjects(t *testing.T) { } func TestCountSubjects(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() pageLen := uint64(15) repocall := prepo.On("RetrieveAllSubjectsCount", mock.Anything, mock.Anything, mock.Anything).Return(pageLen, nil) @@ -1641,7 +1719,7 @@ func TestCountSubjects(t *testing.T) { } func TestListPermissions(t *testing.T) { - svc, _ := newService() + svc := newMinimalService() pr := auth.PolicyReq{ Subject: id, @@ -1703,7 +1781,7 @@ func TestSwitchToPermission(t *testing.T) { } func TestCreateDomain(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -1828,7 +1906,7 @@ func TestCreateDomain(t *testing.T) { } func TestRetrieveDomain(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -1888,7 +1966,7 @@ func TestRetrieveDomain(t *testing.T) { } func TestRetrieveDomainPermissions(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -1947,7 +2025,7 @@ func TestRetrieveDomainPermissions(t *testing.T) { } func TestUpdateDomain(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -2027,7 +2105,7 @@ func TestUpdateDomain(t *testing.T) { } func TestChangeDomainStatus(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() disabledStatus := auth.DisabledStatus @@ -2104,7 +2182,7 @@ func TestChangeDomainStatus(t *testing.T) { } func TestListDomains(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -2170,7 +2248,7 @@ func TestListDomains(t *testing.T) { } func TestAssignUsers(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -2487,7 +2565,7 @@ func TestAssignUsers(t *testing.T) { } func TestUnassignUser(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string @@ -2714,7 +2792,7 @@ func TestUnassignUser(t *testing.T) { } func TestListUsersDomains(t *testing.T) { - svc, accessToken := newService() + svc, _, _, accessToken := newService() cases := []struct { desc string diff --git a/auth/spicedb/policies.go b/auth/spicedb/policies.go index 7ac2ba4a2c..cc83b5c639 100644 --- a/auth/spicedb/policies.go +++ b/auth/spicedb/policies.go @@ -35,17 +35,30 @@ type policyAgent struct { client *authzed.ClientWithExperimental permissionClient v1.PermissionsServiceClient logger *slog.Logger + cache auth.Cache } -func NewPolicyAgent(client *authzed.ClientWithExperimental, logger *slog.Logger) auth.PolicyAgent { +func NewPolicyAgent(client *authzed.ClientWithExperimental, logger *slog.Logger, cache auth.Cache) auth.PolicyAgent { return &policyAgent{ client: client, permissionClient: client.PermissionsServiceClient, logger: logger, + cache: cache, } } -func (pa *policyAgent) CheckPolicy(ctx context.Context, pr auth.PolicyReq) error { +func (pa *policyAgent) CheckPolicy(ctx context.Context, pr auth.PolicyReq) (err error) { + key, val := pr.KV() + if pa.cache.Contains(ctx, key, val) { + return nil + } + defer func() { + if err == nil { + cacheErr := pa.cache.Save(ctx, key, val) + err = errors.Wrap(err, cacheErr) + } + }() + checkReq := v1.CheckPermissionRequest{ // FullyConsistent means little caching will be available, which means performance will suffer. // Only use if a ZedToken is not available or absolutely latest information is required. @@ -134,6 +147,10 @@ func (pa *policyAgent) AddPolicy(ctx context.Context, pr auth.PolicyReq) error { func (pa *policyAgent) DeletePolicies(ctx context.Context, prs []auth.PolicyReq) error { updates := []*v1.RelationshipUpdate{} for _, pr := range prs { + if err := pa.cache.Remove(ctx, pr.KeyForRemoval()); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + updates = append(updates, &v1.RelationshipUpdate{ Operation: v1.RelationshipUpdate_OPERATION_DELETE, Relationship: &v1.Relationship{ @@ -154,6 +171,10 @@ func (pa *policyAgent) DeletePolicies(ctx context.Context, prs []auth.PolicyReq) } func (pa *policyAgent) DeletePolicyFilter(ctx context.Context, pr auth.PolicyReq) error { + if err := pa.cache.Remove(ctx, pr.KeyForRemoval()); err != nil { + return errors.Wrap(repoerr.ErrRemoveEntity, err) + } + req := &v1.DeleteRelationshipsRequest{ RelationshipFilter: &v1.RelationshipFilter{ ResourceType: pr.ObjectType, diff --git a/auth/tokenizer.go b/auth/tokenizer.go index 1aaed7df4f..991bbdf891 100644 --- a/auth/tokenizer.go +++ b/auth/tokenizer.go @@ -3,11 +3,27 @@ package auth +import "context" + // Tokenizer specifies API for encoding and decoding between string and Key. type Tokenizer interface { // Issue converts API Key to its string representation. Issue(key Key) (token string, err error) // Parse extracts API Key data from string token. - Parse(token string) (key Key, err error) + Parse(ctx context.Context, token string) (key Key, err error) + + // Revoke revokes the token. + Revoke(ctx context.Context, token string) error +} + +// TokenRepository specifies token persistence API. +// +//go:generate mockery --name TokenRepository --output=./mocks --filename token.go --quiet --note "Copyright (c) Abstract Machines" +type TokenRepository interface { + // Save persists the token. + Save(ctx context.Context, id string) (err error) + + // Contains checks if token with provided ID exists. + Contains(ctx context.Context, id string) (ok bool) } diff --git a/auth/tracing/tracing.go b/auth/tracing/tracing.go index fe58626b04..12bc9650c8 100644 --- a/auth/tracing/tracing.go +++ b/auth/tracing/tracing.go @@ -43,6 +43,13 @@ func (tm *tracingMiddleware) Revoke(ctx context.Context, token, id string) error return tm.svc.Revoke(ctx, token, id) } +func (tm *tracingMiddleware) RevokeToken(ctx context.Context, token string) error { + ctx, span := tm.tracer.Start(ctx, "revoke") + defer span.End() + + return tm.svc.RevokeToken(ctx, token) +} + func (tm *tracingMiddleware) RetrieveKey(ctx context.Context, token, id string) (auth.Key, error) { ctx, span := tm.tracer.Start(ctx, "retrieve_key", trace.WithAttributes( attribute.String("id", id), diff --git a/cmd/auth/main.go b/cmd/auth/main.go index 053effb5f4..1925746c83 100644 --- a/cmd/auth/main.go +++ b/cmd/auth/main.go @@ -18,11 +18,13 @@ import ( api "github.com/absmach/magistrala/auth/api" grpcapi "github.com/absmach/magistrala/auth/api/grpc" httpapi "github.com/absmach/magistrala/auth/api/http" + "github.com/absmach/magistrala/auth/cache" "github.com/absmach/magistrala/auth/events" "github.com/absmach/magistrala/auth/jwt" apostgres "github.com/absmach/magistrala/auth/postgres" "github.com/absmach/magistrala/auth/spicedb" "github.com/absmach/magistrala/auth/tracing" + redisclient "github.com/absmach/magistrala/internal/clients/redis" mglog "github.com/absmach/magistrala/logger" "github.com/absmach/magistrala/pkg/jaeger" "github.com/absmach/magistrala/pkg/postgres" @@ -37,6 +39,7 @@ import ( "github.com/authzed/grpcutil" "github.com/caarlos0/env/v11" "github.com/jmoiron/sqlx" + "github.com/redis/go-redis/v9" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -67,6 +70,8 @@ type config struct { SpicedbPort string `env:"MG_SPICEDB_PORT" envDefault:"50051"` SpicedbSchemaFile string `env:"MG_SPICEDB_SCHEMA_FILE" envDefault:"./docker/spicedb/schema.zed"` SpicedbPreSharedKey string `env:"MG_SPICEDB_PRE_SHARED_KEY" envDefault:"12345678"` + CacheURL string `env:"MG_AUTH_CACHE_URL" envDefault:"redis://localhost:6379/0"` + CacheKeyDuration time.Duration `env:"MG_AUTH_CACHE_KEY_DURATION" envDefault:"1h"` TraceRatio float64 `env:"MG_JAEGER_TRACE_RATIO" envDefault:"1.0"` ESURL string `env:"MG_ES_URL" envDefault:"nats://localhost:4222"` } @@ -122,6 +127,14 @@ func main() { }() tracer := tp.Tracer(svcName) + cacheclient, err := redisclient.Connect(cfg.CacheURL) + if err != nil { + logger.Error(err.Error()) + exitCode = 1 + return + } + defer cacheclient.Close() + spicedbclient, err := initSpiceDB(ctx, cfg) if err != nil { logger.Error(fmt.Sprintf("failed to init spicedb grpc client : %s\n", err.Error())) @@ -129,7 +142,7 @@ func main() { return } - svc := newService(ctx, db, tracer, cfg, dbConfig, logger, spicedbclient) + svc := newService(ctx, db, tracer, cfg, dbConfig, cacheclient, cfg.CacheKeyDuration, logger, spicedbclient) httpServerConfig := server.Config{Port: defSvcHTTPPort} if err := env.ParseWithOptions(&httpServerConfig, env.Options{Prefix: envPrefixHTTP}); err != nil { @@ -205,14 +218,18 @@ func initSchema(ctx context.Context, client *authzed.ClientWithExperimental, sch return nil } -func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental) auth.Service { +func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, cfg config, dbConfig pgclient.Config, cacheClient *redis.Client, keyDuration time.Duration, logger *slog.Logger, spicedbClient *authzed.ClientWithExperimental) auth.Service { database := postgres.NewDatabase(db, dbConfig, tracer) keysRepo := apostgres.New(database) + tokensRepo := apostgres.NewTokensRepository(database) domainsRepo := apostgres.NewDomainRepository(database) - pa := spicedb.NewPolicyAgent(spicedbClient, logger) + policiesCache := cache.NewPoliciesCache(cacheClient, keyDuration) + tokensCache := cache.NewTokensCache(cacheClient, keyDuration) + + pa := spicedb.NewPolicyAgent(spicedbClient, logger, policiesCache) idProvider := uuid.New() - t := jwt.New([]byte(cfg.SecretKey)) + t := jwt.New([]byte(cfg.SecretKey), tokensRepo, tokensCache) svc := auth.New(keysRepo, domainsRepo, idProvider, t, pa, cfg.AccessDuration, cfg.RefreshDuration, cfg.InvitationDuration) svc, err := events.NewEventStoreMiddleware(ctx, svc, cfg.ESURL) diff --git a/docker/.env b/docker/.env index 4dbd9048c8..2a86c2de98 100644 --- a/docker/.env +++ b/docker/.env @@ -93,6 +93,8 @@ MG_AUTH_DB_SSL_MODE=disable MG_AUTH_DB_SSL_CERT= MG_AUTH_DB_SSL_KEY= MG_AUTH_DB_SSL_ROOT_CERT= +MG_AUTH_CACHE_URL=redis://auth-redis:${MG_REDIS_TCP_PORT}/0 +MG_AUTH_CACHE_KEY_DURATION="1h" MG_AUTH_SECRET_KEY=HyE2D4RUt9nnKG6v8zKEqAp6g6ka8hhZsqUpzgKvnwpXrNVQSH MG_AUTH_ACCESS_TOKEN_DURATION="1h" MG_AUTH_REFRESH_TOKEN_DURATION="24h" diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 1283878a1b..35595d7547 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -17,6 +17,7 @@ volumes: magistrala-auth-db-volume: magistrala-invitations-db-volume: magistrala-ui-db-volume: + magistrala-auth-redis-volume: services: spicedb: @@ -64,6 +65,15 @@ services: volumes: - magistrala-spicedb-db-volume:/var/lib/postgresql/data + auth-redis: + image: redis:7.2.4-alpine + container_name: magistrala-auth-redis + restart: on-failure + networks: + - magistrala-base-net + volumes: + - magistrala-auth-redis-volume:/data + auth-db: image: postgres:16.2-alpine container_name: magistrala-auth-db @@ -83,6 +93,7 @@ services: image: magistrala/auth:${MG_RELEASE_TAG} container_name: magistrala-auth depends_on: + - auth-redis - auth-db - spicedb expose: @@ -120,6 +131,8 @@ services: MG_AUTH_DB_SSL_CERT: ${MG_AUTH_DB_SSL_CERT} MG_AUTH_DB_SSL_KEY: ${MG_AUTH_DB_SSL_KEY} MG_AUTH_DB_SSL_ROOT_CERT: ${MG_AUTH_DB_SSL_ROOT_CERT} + MG_AUTH_CACHE_URL: ${MG_AUTH_CACHE_URL} + MG_AUTH_CACHE_KEY_DURATION: ${MG_AUTH_CACHE_KEY_DURATION} MG_JAEGER_URL: ${MG_JAEGER_URL} MG_JAEGER_TRACE_RATIO: ${MG_JAEGER_TRACE_RATIO} MG_SEND_TELEMETRY: ${MG_SEND_TELEMETRY} @@ -295,6 +308,7 @@ services: image: magistrala/things:${MG_RELEASE_TAG} container_name: magistrala-things depends_on: + - things-redis - things-db - users - auth