diff --git a/go.mod b/go.mod index 2a5a895..b756bfb 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/apache/thrift v0.16.0 github.com/gofrs/uuid v3.2.0+incompatible github.com/golang-jwt/jwt/v5 v5.0.0 + github.com/google/go-cmp v0.5.6 github.com/reddit/baseplate.go v0.9.6 golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e ) diff --git a/go.sum b/go.sum index 1653ce6..8004bf7 100644 --- a/go.sum +++ b/go.sum @@ -90,6 +90,7 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= diff --git a/lib/go/edgecontext/edgecontext_test.go b/lib/go/edgecontext/edgecontext_test.go index efda2d4..cb7e1d7 100644 --- a/lib/go/edgecontext/edgecontext_test.go +++ b/lib/go/edgecontext/edgecontext_test.go @@ -8,6 +8,7 @@ import ( "github.com/apache/thrift/lib/go/thrift" "github.com/gofrs/uuid" + "github.com/google/go-cmp/cmp" "github.com/reddit/baseplate.go/detach" "github.com/reddit/baseplate.go/experiments" "github.com/reddit/baseplate.go/timebp" @@ -32,6 +33,9 @@ const ( "\x00" + // end of struct "\x00") + + headerWithValidServiceAuth = "\x0c\x00\x01\x00\x0c\x00\x02\x00\x0b\x00\x03\x00\x00\x02\x0aeyJhbGciOiJSUzI1NiIsImtpZCI6IlNIQTI1NjpsWjBoa1dSc0RwYXBlQnUyZWtYOVdZMm9ZSW5Id2RSYVhUd3RCZWNEaWNJIiwidHlwIjoiSldUIn0.eyJzdWIiOiJzZXJ2aWNlL3Rlc3Qtc2VydmljZSIsImV4cCI6MjUyNDYwODAwMH0.P41Iahxu-Bbg5srTFSQTBkzwiff4ytlhVBUYuyYTFGY_7XCyKdZywUmVHRY_Q2w8Q2uaybnmuoM95JhRpdNYcTPIYWEby4Z5DSV-zMqqmHnP22aH_sAckFQl86Yw_2pdZpKKJ-KQkyT0vEkxe-vNs5HhEdBr6Rae0g2SKEr7RaPMoToq6xpucDAREVWa7yJMtyyNtiVixeLoxTegRLOZTFEVt4TTYKDuT2FdEY5P2b8BOSpFMoiv9w51gZO1qvn9Zjrl00Z-lI_onihMIkrG_viWVAlzEl8d5ZWuJVjHJvm7O0CS4OuhZocE2qbYQrw9THSS1Mh4YR-_r2v1ArYnVA\x0c\x00\x04\x00\x0c\x00\x05\x0b\x00\x01\x00\x00\x00\x0eorigin/service\x00\x0c\x00\x06\x00\x0c\x00\x07\x0b\x00\x01\x00\x00\x00$1566dce9-9567-4952-b23b-9fd72e111162\x00\x00" + headerWithValidServiceAuthAndOptions = "\x0c\x00\x01\x00\x0c\x00\x02\x00\x0b\x00\x03\x00\x00\x02VeyJhbGciOiJSUzI1NiIsImtpZCI6IlNIQTI1NjpsWjBoa1dSc0RwYXBlQnUyZWtYOVdZMm9ZSW5Id2RSYVhUd3RCZWNEaWNJIiwidHlwIjoiSldUIn0.eyJzdWIiOiJzZXJ2aWNlL3Rlc3Qtc2VydmljZSIsImV4cCI6MjUyNDYwODAwMCwib2JvIjp7ImFpZCI6InQyX2RlYWRiZWVmIiwicm9sZXMiOlsiYWRtaW4iXX0sInNlYSI6dHJ1ZX0.PVefAKWUFfk_7QKen6Iz0Cfu95Yp92lYETlrxCUacLsa9u-qz36aet21iwFrdnJiz7gDeJRH7sOJyh6jRmkD0ptWs4Zl7VqpZY-ALgDOdhwSHoUIoV2L7twT-Dm3Tdyfbzq01fOni9ioq5akKnETC5IbLSOqp1ssWJcgo_9g-X-SdRiuf5u8YHD2Mrep5U21bkbYnm4rK9tX_oCnhrrp4rbXi5yogx594oNmOWUedIeyv6QY_xVGbaXOz7deBIWQY2fSYG3cpiBNtSYEJ4yDTbjGY0G1Vp78bX8YZlboc13TGoDpARdfHuHeQU0wAQEhi7pu0Q4FufEVua4q1f0P3A\x0c\x00\x04\x00\x0c\x00\x05\x0b\x00\x01\x00\x00\x00\x0eorigin/service\x00\x0c\x00\x06\x00\x0c\x00\x07\x0b\x00\x01\x00\x00\x00$a3b2d5c2-ab27-4948-9dae-78a3ffb46957\x00\x00" ) const ( @@ -42,6 +46,7 @@ const ( expectedOrigin = "baseplate" expectedSessionID = "beefdead" expectedRequestID = "2adaff94-9067-4de0-a00b-79fded5cff9e" + expectedServiceName = "test-service" emptyDeviceID = "00000000-0000-0000-0000-000000000000" ) @@ -431,6 +436,16 @@ func TestFromHeader(t *testing.T) { ) }, ) + + t.Run( + "service", + func(t *testing.T) { + _, ok := e.Service() + if ok { + t.Errorf("Expected service to be false, got true") + } + }, + ) }, ) @@ -769,6 +784,81 @@ func TestFromHeader(t *testing.T) { } }, ) + + t.Run( + "service", + func(t *testing.T) { + t.Run("service auth only", func(t *testing.T) { + e, err := edgecontext.FromHeader(context.Background(), headerWithValidServiceAuth, globalTestImpl) + if err != nil { + t.Fatal(err) + } + + svc, ok := e.Service() + if !ok { + t.Fatal("Expected service to be true, got false") + } + name, ok := svc.Name() + if !ok { + t.Fatal("Failed to get service name") + } + if name != expectedServiceName { + t.Errorf("Expected service name %q, got %q", expectedServiceName, name) + } + + if id, ok := svc.OnBehalfOfID(); ok { + t.Errorf("expected no id, got %q", id) + } + + if roles, ok := svc.OnBehalfOfRoles(); ok { + t.Errorf("expected no roles, got %q", roles) + } + + if svc.RequestsElevatedAccess() { + t.Errorf("expected no elevated access, got true") + } + }) + + t.Run("with additional options", func(t *testing.T) { + e, err := edgecontext.FromHeader(context.Background(), headerWithValidServiceAuthAndOptions, globalTestImpl) + if err != nil { + t.Fatal(err) + } + + svc, ok := e.Service() + if !ok { + t.Fatal("Expected service to be true, got false") + } + name, ok := svc.Name() + if !ok { + t.Fatal("Failed to get service name") + } + if name != expectedServiceName { + t.Errorf("Expected service name %q, got %q", expectedServiceName, name) + } + + id, ok := svc.OnBehalfOfID() + if !ok { + t.Fatal("Failed to get on behalf of id") + } + if id != expectedLoID { + t.Errorf("Expected on behalf of id %q, got %q", expectedLoID, id) + } + + roles, ok := svc.OnBehalfOfRoles() + if !ok { + t.Fatal("Failed to get on behalf of roles") + } + if diff := cmp.Diff([]string{"admin"}, roles); diff != "" { + t.Errorf("mismatch (-want +got)\n%s\n", diff) + } + + if !svc.RequestsElevatedAccess() { + t.Errorf("expected elevated access, got false") + } + }) + }, + ) } func TestDetachIntegration(t *testing.T) { diff --git a/lib/go/edgecontext/oauth_client.go b/lib/go/edgecontext/oauth_client.go index 7b66651..ac0dbd6 100644 --- a/lib/go/edgecontext/oauth_client.go +++ b/lib/go/edgecontext/oauth_client.go @@ -23,11 +23,11 @@ func (o OAuthClient) ID() string { // // For example, use: // -// if client.IsType("third_party") +// if client.IsType("third_party") // // Instead of: // -// if !client.IsType("first_party") +// if !client.IsType("first_party") func (o OAuthClient) IsType(types ...string) bool { clientType := AuthenticationToken(o).OAuthClientType for _, t := range types { diff --git a/lib/go/edgecontext/service.go b/lib/go/edgecontext/service.go index 48f1cc3..9b8cc81 100644 --- a/lib/go/edgecontext/service.go +++ b/lib/go/edgecontext/service.go @@ -10,14 +10,60 @@ const servicePrefix = "service/" // service talking to us. type Service AuthenticationToken +func (s Service) isService() bool { + subject := AuthenticationToken(s).Subject() + return strings.HasPrefix(subject, servicePrefix) +} + // Name returns the name of the service. // // If it's not coming from an authenticated service, // ("", false) will be returned. func (s Service) Name() (name string, ok bool) { - subject := AuthenticationToken(s).Subject() - if strings.HasPrefix(subject, servicePrefix) { + if s.isService() { + subject := AuthenticationToken(s).Subject() return subject[len(servicePrefix):], true } - return + return "", false +} + +// OnBehalfOfID returns the ID of the user on whose behalf the service is acting. +// +// If it's not coming from an authenticated service, +// ("", false) will be returned. +func (s Service) OnBehalfOfID() (id string, ok bool) { + if s.isService() { + token := AuthenticationToken(s) + if token.OnBehalfOf == nil { + return "", false + } + if strings.HasPrefix(token.OnBehalfOf.AccountID, userPrefix) { + return token.OnBehalfOf.AccountID, true + } + return "", false + } + return "", false +} + +// OnBehalfOfRoles returns the roles of the user on whose behalf the service is acting. +// +// If it's not coming from an authenticated service, +// (nil, false) will be returned. +func (s Service) OnBehalfOfRoles() (roles []string, ok bool) { + if s.isService() { + token := AuthenticationToken(s) + if token.OnBehalfOf == nil { + return nil, false + } + return token.OnBehalfOf.Roles, true + } + return nil, false +} + +// RequestsElevatedAccess returns whether the service requested elevated access. +func (s Service) RequestsElevatedAccess() bool { + if s.isService() { + return AuthenticationToken(s).ServiceRequestedElevatedAccess + } + return false } diff --git a/lib/go/edgecontext/token.go b/lib/go/edgecontext/token.go index f02e27b..ab46c8f 100644 --- a/lib/go/edgecontext/token.go +++ b/lib/go/edgecontext/token.go @@ -21,6 +21,13 @@ type AuthenticationToken struct { ID string `json:"id,omitempty"` CreatedAt timebp.TimestampMillisecond `json:"created_ms,omitempty"` } `json:"loid,omitempty"` + + OnBehalfOf *struct { + AccountID string `json:"aid,omitempty"` + Roles []string `json:"roles,omitempty"` + } `json:"obo,omitempty"` + + ServiceRequestedElevatedAccess bool `json:"sea,omitempty"` } // Subject returns the subject field of the token. diff --git a/lib/go/edgecontext/user.go b/lib/go/edgecontext/user.go index e3e205b..1d97df9 100644 --- a/lib/go/edgecontext/user.go +++ b/lib/go/edgecontext/user.go @@ -88,7 +88,7 @@ func (u User) HasRole(role string) bool { // Since in most cases the roles slice would be quite small, // it's better to iterate them than converting the slice into a set. for _, r := range token.Roles { - if strings.ToLower(role) == strings.ToLower(r) { + if strings.EqualFold(role, r) { return true } } diff --git a/lib/py/reddit_edgecontext/__init__.py b/lib/py/reddit_edgecontext/__init__.py index 2b4eeae..5e0a97f 100644 --- a/lib/py/reddit_edgecontext/__init__.py +++ b/lib/py/reddit_edgecontext/__init__.py @@ -117,6 +117,18 @@ def loid(self) -> Optional[str]: def loid_created_ms(self) -> Optional[int]: raise NotImplementedError + @property + def on_behalf_of_id(self) -> Optional[str]: + raise NotImplementedError + + @property + def on_behalf_of_roles(self) -> Optional[Set[str]]: + raise NotImplementedError + + @property + def requests_elevated_access(self) -> Optional[bool]: + raise NotImplementedError + class ValidatedAuthenticationToken(AuthenticationToken): def __init__(self, payload: Dict[str, Any]): @@ -150,6 +162,18 @@ def loid(self) -> Optional[str]: def loid_created_ms(self) -> Optional[int]: return (self.payload.get("loid") or {}).get("created_ms") + @property + def on_behalf_of_id(self) -> Optional[str]: + return (self.payload.get("obo") or {}).get("aid") + + @property + def on_behalf_of_roles(self) -> Optional[Set[str]]: + return (self.payload.get("obo") or {}).get("roles") + + @property + def requests_elevated_access(self) -> Optional[bool]: + return self.payload.get("sea") + class InvalidAuthenticationToken(AuthenticationToken): @property @@ -180,6 +204,18 @@ def loid(self) -> Optional[str]: def loid_created_ms(self) -> Optional[int]: raise NoAuthenticationError + @property + def on_behalf_of_id(self) -> Optional[str]: + raise NoAuthenticationError + + @property + def on_behalf_of_roles(self) -> Optional[Set[str]]: + raise NoAuthenticationError + + @property + def requests_elevated_access(self) -> Optional[bool]: + raise NoAuthenticationError + class Session(NamedTuple): """Wrapper for the session values in the EdgeContext.""" @@ -384,6 +420,12 @@ class Service(NamedTuple): authentication_token: AuthenticationToken """The authentication token for this request.""" + def is_service(self) -> bool: + subject = self.authentication_token.subject + if subject is None or subject == "": + return False + return subject.startswith("service/") + @property def name(self) -> str: """Return the authenticated service name. @@ -395,12 +437,29 @@ def name(self) -> str: """ subject = self.authentication_token.subject - if not (subject and subject.startswith("service/")): + if subject is None or subject == "": raise NoAuthenticationError - name = subject[len("service/") :] return name + @property + def on_behalf_of_id(self) -> Optional[str]: + if not self.is_service(): + raise NoAuthenticationError + return self.authentication_token.on_behalf_of_id + + @property + def on_behalf_of_roles(self) -> Optional[Set[str]]: + if not self.is_service(): + raise NoAuthenticationError + return self.authentication_token.on_behalf_of_roles + + @property + def requests_elevated_access(self) -> Optional[bool]: + if not self.is_service(): + raise NoAuthenticationError + return self.authentication_token.requests_elevated_access + class EdgeContext: """Contextual information about the initial request to an edge service. @@ -480,9 +539,7 @@ def request_id(self) -> RequestId: @cached_property def locale(self) -> Locale: """:py:class:`~reddit_edgecontext.Locale` object for the current context.""" - return Locale( - locale_code=self._t_request.locale.locale_code, - ) + return Locale(locale_code=self._t_request.locale.locale_code) @cached_property def _t_request(self) -> TRequest: diff --git a/lib/py/reddit_edgecontext/thrift/ttypes.py b/lib/py/reddit_edgecontext/thrift/ttypes.py index 8538076..aaa3f05 100644 --- a/lib/py/reddit_edgecontext/thrift/ttypes.py +++ b/lib/py/reddit_edgecontext/thrift/ttypes.py @@ -37,16 +37,9 @@ class Loid(object): """ - __slots__ = ( - "id", - "created_ms", - ) + __slots__ = ("id", "created_ms") - def __init__( - self, - id=None, - created_ms=None, - ): + def __init__(self, id=None, created_ms=None): self.id = id self.created_ms = created_ms @@ -137,10 +130,7 @@ class Session(object): __slots__ = ("id",) - def __init__( - self, - id=None, - ): + def __init__(self, id=None): self.id = id def read(self, iprot): @@ -221,10 +211,7 @@ class Device(object): __slots__ = ("id",) - def __init__( - self, - id=None, - ): + def __init__(self, id=None): self.id = id def read(self, iprot): @@ -306,10 +293,7 @@ class OriginService(object): __slots__ = ("name",) - def __init__( - self, - name=None, - ): + def __init__(self, name=None): self.name = name def read(self, iprot): @@ -389,10 +373,7 @@ class Geolocation(object): __slots__ = ("country_code",) - def __init__( - self, - country_code=None, - ): + def __init__(self, country_code=None): self.country_code = country_code def read(self, iprot): @@ -473,10 +454,7 @@ class RequestId(object): __slots__ = ("readable_id",) - def __init__( - self, - readable_id=None, - ): + def __init__(self, readable_id=None): self.readable_id = readable_id def read(self, iprot): @@ -560,10 +538,7 @@ class Locale(object): __slots__ = ("locale_code",) - def __init__( - self, - locale_code=None, - ): + def __init__(self, locale_code=None): self.locale_code = locale_code def read(self, iprot): @@ -818,146 +793,32 @@ def __ne__(self, other): all_structs.append(Loid) Loid.thrift_spec = ( None, # 0 - ( - 1, - TType.STRING, - "id", - "UTF8", - None, - ), # 1 - ( - 2, - TType.I64, - "created_ms", - None, - None, - ), # 2 + (1, TType.STRING, "id", "UTF8", None), # 1 + (2, TType.I64, "created_ms", None, None), # 2 ) all_structs.append(Session) -Session.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "id", - "UTF8", - None, - ), # 1 -) +Session.thrift_spec = (None, (1, TType.STRING, "id", "UTF8", None)) # 0 # 1 all_structs.append(Device) -Device.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "id", - "UTF8", - None, - ), # 1 -) +Device.thrift_spec = (None, (1, TType.STRING, "id", "UTF8", None)) # 0 # 1 all_structs.append(OriginService) -OriginService.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "name", - "UTF8", - None, - ), # 1 -) +OriginService.thrift_spec = (None, (1, TType.STRING, "name", "UTF8", None)) # 0 # 1 all_structs.append(Geolocation) -Geolocation.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "country_code", - "UTF8", - None, - ), # 1 -) +Geolocation.thrift_spec = (None, (1, TType.STRING, "country_code", "UTF8", None)) # 0 # 1 all_structs.append(RequestId) -RequestId.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "readable_id", - "UTF8", - None, - ), # 1 -) +RequestId.thrift_spec = (None, (1, TType.STRING, "readable_id", "UTF8", None)) # 0 # 1 all_structs.append(Locale) -Locale.thrift_spec = ( - None, # 0 - ( - 1, - TType.STRING, - "locale_code", - "UTF8", - None, - ), # 1 -) +Locale.thrift_spec = (None, (1, TType.STRING, "locale_code", "UTF8", None)) # 0 # 1 all_structs.append(Request) Request.thrift_spec = ( None, # 0 - ( - 1, - TType.STRUCT, - "loid", - [Loid, None], - None, - ), # 1 - ( - 2, - TType.STRUCT, - "session", - [Session, None], - None, - ), # 2 - ( - 3, - TType.STRING, - "authentication_token", - "UTF8", - None, - ), # 3 - ( - 4, - TType.STRUCT, - "device", - [Device, None], - None, - ), # 4 - ( - 5, - TType.STRUCT, - "origin_service", - [OriginService, None], - None, - ), # 5 - ( - 6, - TType.STRUCT, - "geolocation", - [Geolocation, None], - None, - ), # 6 - ( - 7, - TType.STRUCT, - "request_id", - [RequestId, None], - None, - ), # 7 - ( - 8, - TType.STRUCT, - "locale", - [Locale, None], - None, - ), # 8 + (1, TType.STRUCT, "loid", [Loid, None], None), # 1 + (2, TType.STRUCT, "session", [Session, None], None), # 2 + (3, TType.STRING, "authentication_token", "UTF8", None), # 3 + (4, TType.STRUCT, "device", [Device, None], None), # 4 + (5, TType.STRUCT, "origin_service", [OriginService, None], None), # 5 + (6, TType.STRUCT, "geolocation", [Geolocation, None], None), # 6 + (7, TType.STRUCT, "request_id", [RequestId, None], None), # 7 + (8, TType.STRUCT, "locale", [Locale, None], None), # 8 ) fix_spec(all_structs) del all_structs diff --git a/lib/py/tests/edge_context_tests.py b/lib/py/tests/edge_context_tests.py index 0abfa93..b05ee35 100644 --- a/lib/py/tests/edge_context_tests.py +++ b/lib/py/tests/edge_context_tests.py @@ -81,6 +81,9 @@ + b"\x00" ) +SERIALIZED_EDGECONTEXT_WITH_VALID_SERVICE_AUTH = b"\x0c\x00\x01\x00\x0c\x00\x02\x00\x0b\x00\x03\x00\x00\x02\x0aeyJhbGciOiJSUzI1NiIsImtpZCI6IlNIQTI1NjpsWjBoa1dSc0RwYXBlQnUyZWtYOVdZMm9ZSW5Id2RSYVhUd3RCZWNEaWNJIiwidHlwIjoiSldUIn0.eyJzdWIiOiJzZXJ2aWNlL3Rlc3Qtc2VydmljZSIsImV4cCI6MjUyNDYwODAwMH0.P41Iahxu-Bbg5srTFSQTBkzwiff4ytlhVBUYuyYTFGY_7XCyKdZywUmVHRY_Q2w8Q2uaybnmuoM95JhRpdNYcTPIYWEby4Z5DSV-zMqqmHnP22aH_sAckFQl86Yw_2pdZpKKJ-KQkyT0vEkxe-vNs5HhEdBr6Rae0g2SKEr7RaPMoToq6xpucDAREVWa7yJMtyyNtiVixeLoxTegRLOZTFEVt4TTYKDuT2FdEY5P2b8BOSpFMoiv9w51gZO1qvn9Zjrl00Z-lI_onihMIkrG_viWVAlzEl8d5ZWuJVjHJvm7O0CS4OuhZocE2qbYQrw9THSS1Mh4YR-_r2v1ArYnVA\x0c\x00\x04\x00\x0c\x00\x05\x0b\x00\x01\x00\x00\x00\x0eorigin/service\x00\x0c\x00\x06\x00\x0c\x00\x07\x0b\x00\x01\x00\x00\x00$1566dce9-9567-4952-b23b-9fd72e111162\x00\x00" +SERIALIZED_EDGECONTEXT_WITH_VALID_SERVICE_AUTH_AND_OPTIONS = b"\x0c\x00\x01\x00\x0c\x00\x02\x00\x0b\x00\x03\x00\x00\x02VeyJhbGciOiJSUzI1NiIsImtpZCI6IlNIQTI1NjpsWjBoa1dSc0RwYXBlQnUyZWtYOVdZMm9ZSW5Id2RSYVhUd3RCZWNEaWNJIiwidHlwIjoiSldUIn0.eyJzdWIiOiJzZXJ2aWNlL3Rlc3Qtc2VydmljZSIsImV4cCI6MjUyNDYwODAwMCwib2JvIjp7ImFpZCI6InQyX2RlYWRiZWVmIiwicm9sZXMiOlsiYWRtaW4iXX0sInNlYSI6dHJ1ZX0.PVefAKWUFfk_7QKen6Iz0Cfu95Yp92lYETlrxCUacLsa9u-qz36aet21iwFrdnJiz7gDeJRH7sOJyh6jRmkD0ptWs4Zl7VqpZY-ALgDOdhwSHoUIoV2L7twT-Dm3Tdyfbzq01fOni9ioq5akKnETC5IbLSOqp1ssWJcgo_9g-X-SdRiuf5u8YHD2Mrep5U21bkbYnm4rK9tX_oCnhrrp4rbXi5yogx594oNmOWUedIeyv6QY_xVGbaXOz7deBIWQY2fSYG3cpiBNtSYEJ4yDTbjGY0G1Vp78bX8YZlboc13TGoDpARdfHuHeQU0wAQEhi7pu0Q4FufEVua4q1f0P3A\x0c\x00\x04\x00\x0c\x00\x05\x0b\x00\x01\x00\x00\x00\x0eorigin/service\x00\x0c\x00\x06\x00\x0c\x00\x07\x0b\x00\x01\x00\x00\x00$a3b2d5c2-ab27-4948-9dae-78a3ffb46957\x00\x00" + class AuthenticationTokenTests(unittest.TestCase): def test_validated_authentication_token(self): @@ -128,6 +131,20 @@ def test_invalidated_authentication_token(self): with self.assertRaises(NoAuthenticationError): getattr(token, attr) + def test_validated_service_authentication_token(self): + payload = { + "sub": "service/test-service", + "exp": 1574458470, + "obo": {"aid": "t2_deadbeef", "roles": ["admin"]}, + "sea": True, + } + + token = ValidatedAuthenticationToken(payload) + self.assertEqual(token.subject, "service/test-service") + self.assertEqual(token.on_behalf_of_roles, ["admin"]) + self.assertEqual(token.on_behalf_of_id, "t2_deadbeef") + self.assertEqual(token.requests_elevated_access, True) + class EdgeContextTests(unittest.TestCase): LOID_ID = "t2_deadbeef" @@ -324,3 +341,21 @@ def test_request_id(self): "edge_request_id": REQUEST_ID, }, ) + + def test_service_auth(self): + request_context = self.factory.from_upstream(SERIALIZED_EDGECONTEXT_WITH_VALID_SERVICE_AUTH) + + self.assertEqual(request_context.service.name, "test-service") + self.assertEqual(request_context.service.on_behalf_of_id, None) + self.assertEqual(request_context.service.on_behalf_of_roles, None) + self.assertEqual(request_context.service.requests_elevated_access, None) + + def test_service_auth_with_additional_options(self): + request_context = self.factory.from_upstream( + SERIALIZED_EDGECONTEXT_WITH_VALID_SERVICE_AUTH_AND_OPTIONS + ) + + self.assertEqual(request_context.service.name, "test-service") + self.assertEqual(request_context.service.on_behalf_of_id, "t2_deadbeef") + self.assertEqual(request_context.service.on_behalf_of_roles, ["admin"]) + self.assertEqual(request_context.service.requests_elevated_access, True)