Skip to content

Commit

Permalink
Add authenticate PAT grpc endpoint
Browse files Browse the repository at this point in the history
Signed-off-by: nyagamunene <[email protected]>
  • Loading branch information
nyagamunene committed Nov 19, 2024
1 parent 9140565 commit 1fc0369
Show file tree
Hide file tree
Showing 14 changed files with 260 additions and 139 deletions.
32 changes: 29 additions & 3 deletions auth/api/grpc/auth/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ import (
const authSvcName = "auth.v1.AuthService"

type authGrpcClient struct {
authenticate endpoint.Endpoint
authorize endpoint.Endpoint
timeout time.Duration
authenticate endpoint.Endpoint
authenticatePAT endpoint.Endpoint
authorize endpoint.Endpoint
timeout time.Duration
}

var _ grpcAuthV1.AuthServiceClient = (*authGrpcClient)(nil)
Expand All @@ -35,6 +36,14 @@ func NewAuthClient(conn *grpc.ClientConn, timeout time.Duration) grpcAuthV1.Auth
decodeIdentifyResponse,
grpcAuthV1.AuthNRes{},
).Endpoint(),
authenticatePAT: kitgrpc.NewClient(
conn,
authSvcName,
"AuthenticatePAT",
encodeIdentifyRequest,
decodeIdentifyPATResponse,
grpcAuthV1.AuthNPATRes{},
).Endpoint(),
authorize: kitgrpc.NewClient(
conn,
authSvcName,
Expand Down Expand Up @@ -69,6 +78,23 @@ func decodeIdentifyResponse(_ context.Context, grpcRes interface{}) (interface{}
return authenticateRes{id: res.GetId(), userID: res.GetUserId(), domainID: res.GetDomainId()}, nil
}

func (client authGrpcClient) AuthenticatePAT(ctx context.Context, token *grpcAuthV1.AuthNReq, _ ...grpc.CallOption) (*grpcAuthV1.AuthNPATRes, error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()

res, err := client.authenticatePAT(ctx, authenticateReq{token: token.GetToken()})
if err != nil {
return &grpcAuthV1.AuthNPATRes{}, grpcapi.DecodeError(err)
}
ir := res.(authenticateRes)
return &grpcAuthV1.AuthNPATRes{Id: ir.id, UserId: ir.userID}, nil
}

func decodeIdentifyPATResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
res := grpcRes.(*grpcAuthV1.AuthNPATRes)
return authenticateRes{id: res.GetId(), userID: res.GetUserId()}, nil
}

func (client authGrpcClient) Authorize(ctx context.Context, req *grpcAuthV1.AuthZReq, _ ...grpc.CallOption) (r *grpcAuthV1.AuthZRes, err error) {
ctx, cancel := context.WithTimeout(ctx, client.timeout)
defer cancel()
Expand Down
16 changes: 16 additions & 0 deletions auth/api/grpc/auth/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ func authenticateEndpoint(svc auth.Service) endpoint.Endpoint {
}
}

func authenticatePATEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(authenticateReq)
if err := req.validate(); err != nil {
return authenticateRes{}, err
}

pat, err := svc.IdentifyPAT(ctx, req.token)
if err != nil {
return authenticateRes{}, err
}

return authenticateRes{id: pat.ID, userID: pat.User}, nil
}
}

func authorizeEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(authReq)
Expand Down
16 changes: 0 additions & 16 deletions auth/api/grpc/auth/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,3 @@ func (req authReq) validate() error {

return nil
}

type authPATReq struct {
paToken string
platformEntityType string
optionalDomainID string
optionalDomainEntityType string
operation string
entityIDs []string
}

func (req authPATReq) validate() error {
if req.paToken == "" {
return apiutil.ErrBearerToken
}
return nil
}
6 changes: 3 additions & 3 deletions auth/api/grpc/auth/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
package auth

type authenticateRes struct {
id string
userID string
domainID string
id string
userID string
domainID string
}

type authorizeRes struct {
Expand Down
24 changes: 22 additions & 2 deletions auth/api/grpc/auth/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ var _ grpcAuthV1.AuthServiceServer = (*authGrpcServer)(nil)

type authGrpcServer struct {
grpcAuthV1.UnimplementedAuthServiceServer
authorize kitgrpc.Handler
authenticate kitgrpc.Handler
authorize kitgrpc.Handler
authenticate kitgrpc.Handler
authenticatePAT kitgrpc.Handler
}

// NewAuthServer returns new AuthnServiceServer instance.
Expand All @@ -34,6 +35,12 @@ func NewAuthServer(svc auth.Service) grpcAuthV1.AuthServiceServer {
decodeAuthenticateRequest,
encodeAuthenticateResponse,
),

authenticatePAT: kitgrpc.NewServer(
(authenticatePATEndpoint(svc)),
decodeAuthenticateRequest,
encodeAuthenticatePATResponse,
),
}
}

Expand All @@ -45,6 +52,14 @@ func (s *authGrpcServer) Authenticate(ctx context.Context, req *grpcAuthV1.AuthN
return res.(*grpcAuthV1.AuthNRes), nil
}

func (s *authGrpcServer) AuthenticatePAT(ctx context.Context, req *grpcAuthV1.AuthNReq) (*grpcAuthV1.AuthNPATRes, error) {
_, res, err := s.authenticatePAT.ServeGRPC(ctx, req)
if err != nil {
return nil, grpcapi.EncodeError(err)
}
return res.(*grpcAuthV1.AuthNPATRes), nil
}

func (s *authGrpcServer) Authorize(ctx context.Context, req *grpcAuthV1.AuthZReq) (*grpcAuthV1.AuthZRes, error) {
_, res, err := s.authorize.ServeGRPC(ctx, req)
if err != nil {
Expand All @@ -63,6 +78,11 @@ func encodeAuthenticateResponse(_ context.Context, grpcRes interface{}) (interfa
return &grpcAuthV1.AuthNRes{Id: res.id, UserId: res.userID, DomainId: res.domainID}, nil
}

func encodeAuthenticatePATResponse(_ context.Context, grpcRes interface{}) (interface{}, error) {
res := grpcRes.(authenticateRes)
return &grpcAuthV1.AuthNPATRes{Id: res.id, UserId: res.userID}, nil
}

func decodeAuthorizeRequest(_ context.Context, grpcReq interface{}) (interface{}, error) {
req := grpcReq.(*grpcAuthV1.AuthZReq)
return authReq{
Expand Down
15 changes: 0 additions & 15 deletions auth/api/http/pats/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,18 +185,3 @@ func clearPATAllScopeEntryEndpoint(svc auth.Service) endpoint.Endpoint {
return clearAllScopeEntryRes{}, nil
}
}

func authorizePATEndpoint(svc auth.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
req := request.(authorizePATReq)
if err := req.validate(); err != nil {
return nil, err
}

if err := svc.AuthorizePAT(ctx, req.token, req.PlatformEntityType, req.OptionalDomainID, req.OptionalDomainEntityType, req.Operation, req.EntityIDs...); err != nil {
return nil, err
}

return authorizePATRes{}, nil
}
}
19 changes: 0 additions & 19 deletions auth/api/http/pats/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,13 +105,6 @@ func MakeHandler(svc auth.Service, mux *chi.Mux, logger *slog.Logger) *chi.Mux {
api.EncodeResponse,
opts...,
).ServeHTTP)

r.Post("/authorize", kithttp.NewServer(
(authorizePATEndpoint(svc)),
decodeAuthorizePATRequest,
api.EncodeResponse,
opts...,
).ServeHTTP)
})
return mux
}
Expand Down Expand Up @@ -252,15 +245,3 @@ func decodeClearPATAllScopeEntryRequest(_ context.Context, r *http.Request) (int
id: chi.URLParam(r, "id"),
}, nil
}

func decodeAuthorizePATRequest(_ context.Context, r *http.Request) (interface{}, error) {
if !strings.Contains(r.Header.Get("Content-Type"), contentType) {
return nil, apiutil.ErrUnsupportedContentType
}

req := authorizePATReq{token: apiutil.ExtractBearerToken(r)}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, errors.Wrap(errors.ErrMalformedEntity, err)
}
return req, nil
}
15 changes: 0 additions & 15 deletions auth/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,21 +173,6 @@ func (svc service) RetrieveKey(ctx context.Context, token, id string) (Key, erro
}

func (svc service) Identify(ctx context.Context, token string) (Key, error) {
if strings.HasPrefix(token, patPrefix+patSecretSeparator) {
pat, err := svc.IdentifyPAT(ctx, token)
if err != nil {
return Key{}, err
}
return Key{
ID: pat.ID,
Type: PersonalAccessToken,
Subject: pat.User,
User: pat.User,
IssuedAt: pat.IssuedAt,
ExpiresAt: pat.ExpiresAt,
}, nil
}

key, err := svc.tokenizer.Parse(token)
if errors.Contains(err, ErrExpiry) {
err = svc.keys.Remove(ctx, key.Issuer, key.ID)
Expand Down
29 changes: 17 additions & 12 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"strings"

"github.com/absmach/magistrala/auth"
"github.com/absmach/magistrala/pkg/apiutil"
mgauthn "github.com/absmach/magistrala/pkg/authn"
"github.com/go-chi/chi/v5"
Expand All @@ -17,7 +18,7 @@ type sessionKeyType string

const (
SessionKey = sessionKeyType("session")
seperator = "_"
patPrefix = "pat_"
)

func AuthenticateMiddleware(authn mgauthn.Authentication, domainCheck bool) func(http.Handler) http.Handler {
Expand All @@ -28,16 +29,20 @@ func AuthenticateMiddleware(authn mgauthn.Authentication, domainCheck bool) func
EncodeError(r.Context(), apiutil.ErrBearerToken, w)
return
}

resp, err := authn.Authenticate(r.Context(), token)
if err != nil {
EncodeError(r.Context(), err, w)
return
}

resp.Type = mgauthn.AccessToken
if strings.HasPrefix(token, "pat"+seperator) {
resp.Type = mgauthn.PersonalAccessToken
var resp mgauthn.Session
var err error
if strings.HasPrefix(token, patPrefix) {
resp, err = authn.AuthenticatePAT(r.Context(), token)
if err != nil {
EncodeError(r.Context(), err, w)
return
}
} else {
resp, err = authn.Authenticate(r.Context(), token)
if err != nil {
EncodeError(r.Context(), err, w)
return
}
}

if domainCheck {
Expand All @@ -47,7 +52,7 @@ func AuthenticateMiddleware(authn mgauthn.Authentication, domainCheck bool) func
return
}
resp.DomainID = domain
resp.DomainUserID = domain + seperator + resp.UserID
resp.DomainUserID = auth.EncodeDomainUserID(domain, resp.UserID)
}

ctx := context.WithValue(r.Context(), SessionKey, resp)
Expand Down
Loading

0 comments on commit 1fc0369

Please sign in to comment.