Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add additional options for service auth #25

Merged
merged 12 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ require (
github.com/apache/thrift v0.16.0
github.com/gofrs/uuid v3.2.0+incompatible
github.com/golang-jwt/jwt/v5 v5.0.0
github.com/google/go-cmp v0.5.6
github.com/reddit/baseplate.go v0.9.6
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e
)
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
Expand Down
90 changes: 90 additions & 0 deletions lib/go/edgecontext/edgecontext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/apache/thrift/lib/go/thrift"
"github.com/gofrs/uuid"
"github.com/google/go-cmp/cmp"
"github.com/reddit/baseplate.go/detach"
"github.com/reddit/baseplate.go/experiments"
"github.com/reddit/baseplate.go/timebp"
Expand All @@ -32,6 +33,9 @@ const (
"\x00" +
// end of struct
"\x00")

headerWithValidServiceAuth = "\x0c\x00\x01\x00\x0c\x00\x02\x00\x0b\x00\x03\x00\x00\x02\x0aeyJhbGciOiJSUzI1NiIsImtpZCI6IlNIQTI1NjpsWjBoa1dSc0RwYXBlQnUyZWtYOVdZMm9ZSW5Id2RSYVhUd3RCZWNEaWNJIiwidHlwIjoiSldUIn0.eyJzdWIiOiJzZXJ2aWNlL3Rlc3Qtc2VydmljZSIsImV4cCI6MjUyNDYwODAwMH0.P41Iahxu-Bbg5srTFSQTBkzwiff4ytlhVBUYuyYTFGY_7XCyKdZywUmVHRY_Q2w8Q2uaybnmuoM95JhRpdNYcTPIYWEby4Z5DSV-zMqqmHnP22aH_sAckFQl86Yw_2pdZpKKJ-KQkyT0vEkxe-vNs5HhEdBr6Rae0g2SKEr7RaPMoToq6xpucDAREVWa7yJMtyyNtiVixeLoxTegRLOZTFEVt4TTYKDuT2FdEY5P2b8BOSpFMoiv9w51gZO1qvn9Zjrl00Z-lI_onihMIkrG_viWVAlzEl8d5ZWuJVjHJvm7O0CS4OuhZocE2qbYQrw9THSS1Mh4YR-_r2v1ArYnVA\x0c\x00\x04\x00\x0c\x00\x05\x0b\x00\x01\x00\x00\x00\x0eorigin/service\x00\x0c\x00\x06\x00\x0c\x00\x07\x0b\x00\x01\x00\x00\x00$1566dce9-9567-4952-b23b-9fd72e111162\x00\x00"
headerWithValidServiceAuthAndOptions = "\x0c\x00\x01\x00\x0c\x00\x02\x00\x0b\x00\x03\x00\x00\x02VeyJhbGciOiJSUzI1NiIsImtpZCI6IlNIQTI1NjpsWjBoa1dSc0RwYXBlQnUyZWtYOVdZMm9ZSW5Id2RSYVhUd3RCZWNEaWNJIiwidHlwIjoiSldUIn0.eyJzdWIiOiJzZXJ2aWNlL3Rlc3Qtc2VydmljZSIsImV4cCI6MjUyNDYwODAwMCwib2JvIjp7ImFpZCI6InQyX2RlYWRiZWVmIiwicm9sZXMiOlsiYWRtaW4iXX0sInNlYSI6dHJ1ZX0.PVefAKWUFfk_7QKen6Iz0Cfu95Yp92lYETlrxCUacLsa9u-qz36aet21iwFrdnJiz7gDeJRH7sOJyh6jRmkD0ptWs4Zl7VqpZY-ALgDOdhwSHoUIoV2L7twT-Dm3Tdyfbzq01fOni9ioq5akKnETC5IbLSOqp1ssWJcgo_9g-X-SdRiuf5u8YHD2Mrep5U21bkbYnm4rK9tX_oCnhrrp4rbXi5yogx594oNmOWUedIeyv6QY_xVGbaXOz7deBIWQY2fSYG3cpiBNtSYEJ4yDTbjGY0G1Vp78bX8YZlboc13TGoDpARdfHuHeQU0wAQEhi7pu0Q4FufEVua4q1f0P3A\x0c\x00\x04\x00\x0c\x00\x05\x0b\x00\x01\x00\x00\x00\x0eorigin/service\x00\x0c\x00\x06\x00\x0c\x00\x07\x0b\x00\x01\x00\x00\x00$a3b2d5c2-ab27-4948-9dae-78a3ffb46957\x00\x00"
)

const (
Expand All @@ -42,6 +46,7 @@ const (
expectedOrigin = "baseplate"
expectedSessionID = "beefdead"
expectedRequestID = "2adaff94-9067-4de0-a00b-79fded5cff9e"
expectedServiceName = "test-service"

emptyDeviceID = "00000000-0000-0000-0000-000000000000"
)
Expand Down Expand Up @@ -431,6 +436,16 @@ func TestFromHeader(t *testing.T) {
)
},
)

t.Run(
"service",
func(t *testing.T) {
_, ok := e.Service()
if ok {
t.Errorf("Expected service to be false, got true")
}
},
)
},
)

Expand Down Expand Up @@ -769,6 +784,81 @@ func TestFromHeader(t *testing.T) {
}
},
)

t.Run(
"service",
func(t *testing.T) {
t.Run("service auth only", func(t *testing.T) {
e, err := edgecontext.FromHeader(context.Background(), headerWithValidServiceAuth, globalTestImpl)
if err != nil {
t.Fatal(err)
}

svc, ok := e.Service()
if !ok {
t.Fatal("Expected service to be true, got false")
}
name, ok := svc.Name()
if !ok {
t.Fatal("Failed to get service name")
}
if name != expectedServiceName {
t.Errorf("Expected service name %q, got %q", expectedServiceName, name)
}

if id, ok := svc.OnBehalfOfID(); ok {
t.Errorf("expected no id, got %q", id)
}

if roles, ok := svc.OnBehalfOfRoles(); ok {
t.Errorf("expected no roles, got %q", roles)
}

if svc.RequestsElevatedAccess() {
t.Errorf("expected no elevated access, got true")
}
})

t.Run("with additional options", func(t *testing.T) {
e, err := edgecontext.FromHeader(context.Background(), headerWithValidServiceAuthAndOptions, globalTestImpl)
if err != nil {
t.Fatal(err)
}

svc, ok := e.Service()
if !ok {
t.Fatal("Expected service to be true, got false")
}
name, ok := svc.Name()
if !ok {
t.Fatal("Failed to get service name")
}
if name != expectedServiceName {
t.Errorf("Expected service name %q, got %q", expectedServiceName, name)
}

id, ok := svc.OnBehalfOfID()
if !ok {
t.Fatal("Failed to get on behalf of id")
}
if id != expectedLoID {
t.Errorf("Expected on behalf of id %q, got %q", expectedLoID, id)
}

roles, ok := svc.OnBehalfOfRoles()
if !ok {
t.Fatal("Failed to get on behalf of roles")
}
if diff := cmp.Diff([]string{"admin"}, roles); diff != "" {
t.Errorf("mismatch (-want +got)\n%s\n", diff)
}

if !svc.RequestsElevatedAccess() {
t.Errorf("expected elevated access, got false")
}
})
},
)
}

func TestDetachIntegration(t *testing.T) {
Expand Down
4 changes: 2 additions & 2 deletions lib/go/edgecontext/oauth_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ func (o OAuthClient) ID() string {
//
// For example, use:
//
// if client.IsType("third_party")
// if client.IsType("third_party")
//
// Instead of:
//
// if !client.IsType("first_party")
// if !client.IsType("first_party")
func (o OAuthClient) IsType(types ...string) bool {
clientType := AuthenticationToken(o).OAuthClientType
for _, t := range types {
Expand Down
52 changes: 49 additions & 3 deletions lib/go/edgecontext/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,60 @@ const servicePrefix = "service/"
// service talking to us.
type Service AuthenticationToken

func (s Service) isService() bool {
subject := AuthenticationToken(s).Subject()
return strings.HasPrefix(subject, servicePrefix)
}

// Name returns the name of the service.
//
// If it's not coming from an authenticated service,
// ("", false) will be returned.
func (s Service) Name() (name string, ok bool) {
subject := AuthenticationToken(s).Subject()
if strings.HasPrefix(subject, servicePrefix) {
if s.isService() {
subject := AuthenticationToken(s).Subject()
return subject[len(servicePrefix):], true
}
return
return "", false
}

// OnBehalfOfID returns the ID of the user on whose behalf the service is acting.
//
// If it's not coming from an authenticated service,
// ("", false) will be returned.
func (s Service) OnBehalfOfID() (id string, ok bool) {
if s.isService() {
token := AuthenticationToken(s)
if token.OnBehalfOf == nil {
return "", false
}
if strings.HasPrefix(token.OnBehalfOf.AccountID, userPrefix) {
return token.OnBehalfOf.AccountID, true
}
return "", false
}
return "", false
}

// OnBehalfOfRoles returns the roles of the user on whose behalf the service is acting.
//
// If it's not coming from an authenticated service,
// (nil, false) will be returned.
func (s Service) OnBehalfOfRoles() (roles []string, ok bool) {
if s.isService() {
token := AuthenticationToken(s)
if token.OnBehalfOf == nil {
return nil, false
}
return token.OnBehalfOf.Roles, true
}
return nil, false
}

// RequestsElevatedAccess returns whether the service requested elevated access.
func (s Service) RequestsElevatedAccess() bool {
if s.isService() {
return AuthenticationToken(s).ServiceRequestedElevatedAccess
}
return false
}
7 changes: 7 additions & 0 deletions lib/go/edgecontext/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ type AuthenticationToken struct {
ID string `json:"id,omitempty"`
CreatedAt timebp.TimestampMillisecond `json:"created_ms,omitempty"`
} `json:"loid,omitempty"`

OnBehalfOf *struct {
AccountID string `json:"aid,omitempty"`
Roles []string `json:"roles,omitempty"`
} `json:"obo,omitempty"`

ServiceRequestedElevatedAccess bool `json:"sea,omitempty"`
}

// Subject returns the subject field of the token.
Expand Down
2 changes: 1 addition & 1 deletion lib/go/edgecontext/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (u User) HasRole(role string) bool {
// Since in most cases the roles slice would be quite small,
// it's better to iterate them than converting the slice into a set.
for _, r := range token.Roles {
if strings.ToLower(role) == strings.ToLower(r) {
if strings.EqualFold(role, r) {
return true
}
}
Expand Down
67 changes: 62 additions & 5 deletions lib/py/reddit_edgecontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ def loid(self) -> Optional[str]:
def loid_created_ms(self) -> Optional[int]:
raise NotImplementedError

@property
def on_behalf_of_id(self) -> Optional[str]:
raise NotImplementedError

@property
def on_behalf_of_roles(self) -> Optional[Set[str]]:
raise NotImplementedError

@property
def requests_elevated_access(self) -> Optional[bool]:
raise NotImplementedError


class ValidatedAuthenticationToken(AuthenticationToken):
def __init__(self, payload: Dict[str, Any]):
Expand Down Expand Up @@ -150,6 +162,18 @@ def loid(self) -> Optional[str]:
def loid_created_ms(self) -> Optional[int]:
return (self.payload.get("loid") or {}).get("created_ms")

@property
def on_behalf_of_id(self) -> Optional[str]:
return (self.payload.get("obo") or {}).get("aid")

@property
def on_behalf_of_roles(self) -> Optional[Set[str]]:
return (self.payload.get("obo") or {}).get("roles")

@property
def requests_elevated_access(self) -> Optional[bool]:
return self.payload.get("sea")


class InvalidAuthenticationToken(AuthenticationToken):
@property
Expand Down Expand Up @@ -180,6 +204,18 @@ def loid(self) -> Optional[str]:
def loid_created_ms(self) -> Optional[int]:
raise NoAuthenticationError

@property
def on_behalf_of_id(self) -> Optional[str]:
raise NoAuthenticationError

@property
def on_behalf_of_roles(self) -> Optional[Set[str]]:
raise NoAuthenticationError

@property
def requests_elevated_access(self) -> Optional[bool]:
raise NoAuthenticationError


class Session(NamedTuple):
"""Wrapper for the session values in the EdgeContext."""
Expand Down Expand Up @@ -384,6 +420,12 @@ class Service(NamedTuple):
authentication_token: AuthenticationToken
"""The authentication token for this request."""

def is_service(self) -> bool:
subject = self.authentication_token.subject
if subject is None or subject == "":
return False
return subject.startswith("service/")

@property
def name(self) -> str:
"""Return the authenticated service name.
Expand All @@ -395,12 +437,29 @@ def name(self) -> str:

"""
subject = self.authentication_token.subject
if not (subject and subject.startswith("service/")):
if subject is None or subject == "":
raise NoAuthenticationError

name = subject[len("service/") :]
return name

@property
def on_behalf_of_id(self) -> Optional[str]:
if not self.is_service():
raise NoAuthenticationError
return self.authentication_token.on_behalf_of_id

@property
def on_behalf_of_roles(self) -> Optional[Set[str]]:
if not self.is_service():
raise NoAuthenticationError
return self.authentication_token.on_behalf_of_roles

@property
def requests_elevated_access(self) -> Optional[bool]:
if not self.is_service():
raise NoAuthenticationError
return self.authentication_token.requests_elevated_access


class EdgeContext:
"""Contextual information about the initial request to an edge service.
Expand Down Expand Up @@ -480,9 +539,7 @@ def request_id(self) -> RequestId:
@cached_property
def locale(self) -> Locale:
""":py:class:`~reddit_edgecontext.Locale` object for the current context."""
return Locale(
locale_code=self._t_request.locale.locale_code,
)
return Locale(locale_code=self._t_request.locale.locale_code)

@cached_property
def _t_request(self) -> TRequest:
Expand Down
Loading
Loading