Skip to content

Commit

Permalink
MG-2048 - Implement Personal Access Tokens (PATs) (#2492)
Browse files Browse the repository at this point in the history
Signed-off-by: nyagamunene <[email protected]>
  • Loading branch information
nyagamunene authored Dec 18, 2024
1 parent 0d126c3 commit e8e17f5
Show file tree
Hide file tree
Showing 42 changed files with 5,545 additions and 67 deletions.
76 changes: 73 additions & 3 deletions auth/api/grpc/auth/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"time"

"github.com/absmach/supermq/auth"
grpcapi "github.com/absmach/supermq/auth/api/grpc"
grpcAuthV1 "github.com/absmach/supermq/internal/grpc/auth/v1"
"github.com/go-kit/kit/endpoint"
Expand All @@ -17,9 +18,11 @@ import (
const authSvcName = "auth.v1.AuthService"

type authGrpcClient struct {
authenticate endpoint.Endpoint
authorize endpoint.Endpoint
timeout time.Duration
authenticate endpoint.Endpoint
authenticatePAT endpoint.Endpoint
authorize endpoint.Endpoint
authorizePAT endpoint.Endpoint
timeout time.Duration
}

var _ grpcAuthV1.AuthServiceClient = (*authGrpcClient)(nil)
Expand All @@ -35,6 +38,14 @@ func NewAuthClient(conn *grpc.ClientConn, timeout time.Duration) grpcAuthV1.Auth
decodeIdentifyResponse,
grpcAuthV1.AuthNRes{},
).Endpoint(),
authenticatePAT: kitgrpc.NewClient(
conn,
authSvcName,
"AuthenticatePAT",
encodeIdentifyRequest,
decodeIdentifyPATResponse,
grpcAuthV1.AuthNRes{},
).Endpoint(),
authorize: kitgrpc.NewClient(
conn,
authSvcName,
Expand All @@ -43,6 +54,14 @@ func NewAuthClient(conn *grpc.ClientConn, timeout time.Duration) grpcAuthV1.Auth
decodeAuthorizeResponse,
grpcAuthV1.AuthZRes{},
).Endpoint(),
authorizePAT: kitgrpc.NewClient(
conn,
authSvcName,
"AuthorizePAT",
encodeAuthorizePATRequest,
decodeAuthorizeResponse,
grpcAuthV1.AuthZRes{},
).Endpoint(),
timeout: timeout,
}
}
Expand All @@ -69,6 +88,23 @@ func decodeIdentifyResponse(_ context.Context, grpcRes interface{}) (interface{}
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), domainID: res.GetDomainId()}, nil
}

func (client authGrpcClient) AuthenticatePAT(ctx context.Context, token *grpcAuthV1.AuthNReq, _ ...grpc.CallOption) (*grpcAuthV1.AuthNRes, error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()

res, err := client.authenticatePAT(ctx, authenticateReq{token: token.GetToken()})
if err != nil {
return &grpcAuthV1.AuthNRes{}, grpcapi.DecodeError(err)
}
ir := res.(authenticateRes)
return &grpcAuthV1.AuthNRes{Id: ir.id, UserId: ir.userID}, nil
}

func decodeIdentifyPATResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
res := grpcRes.(*grpcAuthV1.AuthNRes)
return authenticateRes{id: res.GetId(), userID: res.GetUserId()}, nil
}

func (client authGrpcClient) Authorize(ctx context.Context, req *grpcAuthV1.AuthZReq, _ ...grpc.CallOption) (r *grpcAuthV1.AuthZRes, err error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
Expand Down Expand Up @@ -109,3 +145,37 @@ func encodeAuthorizeRequest(_ context.Context, grpcReq interface{}) (interface{}
Object: req.Object,
}, nil
}

func (client authGrpcClient) AuthorizePAT(ctx context.Context, req *grpcAuthV1.AuthZPatReq, _ ...grpc.CallOption) (r *grpcAuthV1.AuthZRes, err error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()

res, err := client.authorizePAT(ctx, authPATReq{
userID: req.GetUserId(),
patID: req.GetPatId(),
platformEntityType: auth.PlatformEntityType(req.GetPlatformEntityType()),
optionalDomainID: req.GetOptionalDomainId(),
optionalDomainEntityType: auth.DomainEntityType(req.GetOptionalDomainEntityType()),
operation: auth.OperationType(req.GetOperation()),
entityIDs: req.GetEntityIds(),
})
if err != nil {
return &grpcAuthV1.AuthZRes{}, grpcapi.DecodeError(err)
}

ar := res.(authorizeRes)
return &grpcAuthV1.AuthZRes{Authorized: ar.authorized, Id: ar.id}, nil
}

func encodeAuthorizePATRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(authPATReq)
return &grpcAuthV1.AuthZPatReq{
UserId: req.userID,
PatId: req.patID,
PlatformEntityType: uint32(req.platformEntityType),
OptionalDomainId: req.optionalDomainID,
OptionalDomainEntityType: uint32(req.optionalDomainEntityType),
Operation: uint32(req.operation),
EntityIds: req.entityIDs,
}, nil
}
31 changes: 31 additions & 0 deletions auth/api/grpc/auth/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ func authenticateEndpoint(svc auth.Service) endpoint.Endpoint {
}
}

func authenticatePATEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(authenticateReq)
if err := req.validate(); err != nil {
return authenticateRes{}, err
}

pat, err := svc.IdentifyPAT(ctx, req.token)
if err != nil {
return authenticateRes{}, err
}

return authenticateRes{id: pat.ID, userID: pat.User}, nil
}
}

func authorizeEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(authReq)
Expand All @@ -50,3 +66,18 @@ func authorizeEndpoint(svc auth.Service) endpoint.Endpoint {
return authorizeRes{authorized: true}, nil
}
}

func authorizePATEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(authPATReq)

if err := req.validate(); err != nil {
return authorizeRes{}, err
}
err := svc.AuthorizePAT(ctx, req.userID, req.patID, req.platformEntityType, req.optionalDomainID, req.optionalDomainEntityType, req.operation, req.entityIDs...)
if err != nil {
return authorizeRes{authorized: false}, err
}
return authorizeRes{authorized: true}, nil
}
}
180 changes: 164 additions & 16 deletions auth/api/grpc/auth/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ const (
invalidDuration = 7 * 24 * time.Hour
validToken = "valid"
inValidToken = "invalid"
validPATToken = "valid"
inValidPATToken = "invalid"
validPolicy = "valid"
)

var (
domainID = testsutil.GenerateUUID(&testing.T{})
authAddr = fmt.Sprintf("localhost:%d", port)
clientID = testsutil.GenerateUUID(&testing.T{})
)

func startGRPCServer(svc auth.Service, port int) *grpc.Server {
Expand All @@ -63,8 +66,8 @@ func startGRPCServer(svc auth.Service, port int) *grpc.Server {

func TestIdentify(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
defer conn.Close()
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
defer conn.Close()
grpcClient := grpcapi.NewAuthClient(conn, time.Second)

cases := []struct {
Expand Down Expand Up @@ -96,20 +99,23 @@ func TestIdentify(t *testing.T) {
}

for _, tc := range cases {
svcCall := svc.On("Identify", mock.Anything, mock.Anything, mock.Anything).Return(auth.Key{Subject: id, User: email, Domain: domainID}, tc.svcErr)
idt, err := grpcClient.Authenticate(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token})
if idt != nil {
assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("Identify", mock.Anything, mock.Anything).Return(auth.Key{Subject: id, User: email, Domain: domainID}, tc.svcErr)
idt, err := grpcClient.Authenticate(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token})
if idt != nil {
assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}

func TestAuthorize(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
defer conn.Close()
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
defer conn.Close()

grpcClient := grpcapi.NewAuthClient(conn, time.Second)

cases := []struct {
Expand Down Expand Up @@ -219,12 +225,154 @@ func TestAuthorize(t *testing.T) {
},
}
for _, tc := range cases {
svccall := svc.On("Authorize", mock.Anything, mock.Anything).Return(tc.err)
ar, err := grpcClient.Authorize(context.Background(), tc.authRequest)
if ar != nil {
assert.Equal(t, tc.authResponse, ar, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.authResponse, ar))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svccall.Unset()
t.Run(tc.desc, func(t *testing.T) {
svccall := svc.On("Authorize", mock.Anything, mock.Anything).Return(tc.err)
ar, err := grpcClient.Authorize(context.Background(), tc.authRequest)
if ar != nil {
assert.Equal(t, tc.authResponse, ar, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.authResponse, ar))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svccall.Unset()
})
}
}

func TestIdentifyPAT(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
defer conn.Close()
grpcClient := grpcapi.NewAuthClient(conn, time.Second)

cases := []struct {
desc string
token string
idt *grpcAuthV1.AuthNRes
svcErr error
err error
}{
{
desc: "authenticate user with valid user token",
token: validToken,
idt: &grpcAuthV1.AuthNRes{Id: id, UserId: clientID},
err: nil,
},
{
desc: "authenticate user with invalid user token",
token: "invalid",
idt: &grpcAuthV1.AuthNRes{},
svcErr: svcerr.ErrAuthentication,
err: svcerr.ErrAuthentication,
},
{
desc: "authenticate user with empty token",
token: "",
idt: &grpcAuthV1.AuthNRes{},
err: apiutil.ErrBearerToken,
},
}

for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svcCall := svc.On("IdentifyPAT", mock.Anything, tc.token).Return(auth.PAT{ID: id, User: clientID, IssuedAt: time.Now()}, tc.svcErr)
idt, err := grpcClient.AuthenticatePAT(context.Background(), &grpcAuthV1.AuthNReq{Token: tc.token})
if idt != nil {
assert.Equal(t, tc.idt, idt, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.idt, idt))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svcCall.Unset()
})
}
}

func TestAuthorizePAT(t *testing.T) {
conn, err := grpc.NewClient(authAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
assert.Nil(t, err, fmt.Sprintf("Unexpected error creating client connection %s", err))
defer conn.Close()

grpcClient := grpcapi.NewAuthClient(conn, time.Second)
cases := []struct {
desc string
token string
authRequest *grpcAuthV1.AuthZPatReq
authResponse *grpcAuthV1.AuthZRes
err error
}{
{
desc: "authorize user with authorized token",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
UserId: id,
PatId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: true},
err: nil,
},
{
desc: "authorize user with unauthorized token",
token: inValidPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
UserId: id,
PatId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: svcerr.ErrAuthorization,
},
{
desc: "authorize user with missing user id",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
PatId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingUserID,
},
{
desc: "authorize user with missing pat id",
token: validPATToken,
authRequest: &grpcAuthV1.AuthZPatReq{
UserId: id,
PlatformEntityType: uint32(auth.PlatformDomainsScope),
OptionalDomainId: domainID,
OptionalDomainEntityType: uint32(auth.DomainClientsScope),
Operation: uint32(auth.CreateOp),
EntityIds: []string{clientID},
},
authResponse: &grpcAuthV1.AuthZRes{Authorized: false},
err: apiutil.ErrMissingPATID,
},
}
for _, tc := range cases {
t.Run(tc.desc, func(t *testing.T) {
svccall := svc.On("AuthorizePAT",
mock.Anything,
tc.authRequest.UserId,
tc.authRequest.PatId,
mock.Anything,
tc.authRequest.OptionalDomainId,
mock.Anything,
mock.Anything,
mock.Anything).Return(tc.err)
ar, err := grpcClient.AuthorizePAT(context.Background(), tc.authRequest)
if ar != nil {
assert.Equal(t, tc.authResponse, ar, fmt.Sprintf("%s: expected %v got %v", tc.desc, tc.authResponse, ar))
}
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
svccall.Unset()
})
}
}
Loading

0 comments on commit e8e17f5

Please sign in to comment.