Skip to content

Commit

Permalink
[swarming] Support Swarming bot session in "dry run" mode.
Browse files Browse the repository at this point in the history
Bot session tokens are recognized, checked, and refreshed. But
any errors are logged and ignored.

Also bot dimensions are now fetched from datastore (eventually
all bot RPC calls other than /bot/poll will fetch them from
datastore: that way we can guarantee all of them use the same
consistent set of dimensions).

No new tests, because this is a temporary messy state. Tests will
be written for the final "clean" state.

Also remove unused /swarming/api/v1/bot/rbe/ping.

[email protected]
BUG=b/355013282

Change-Id: I9460cb21975480a29f5718bc177d6f0460be21e9
Reviewed-on: https://chromium-review.googlesource.com/c/infra/luci/luci-go/+/6024091
Reviewed-by: Chan Li <[email protected]>
Commit-Queue: Vadim Shtayura <[email protected]>
  • Loading branch information
vadimsht authored and LUCI CQ committed Nov 15, 2024
1 parent ca21b75 commit 19f5647
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 113 deletions.
1 change: 1 addition & 0 deletions swarming/server/botapi/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func NewBotAPIServer(cfg *cfg.Provider, project string) *BotAPIServer {
// UnimplementedRequest is used as a placeholder in unimplemented handlers.
//
// Every Extract* method visible here returns nil, i.e. such a request carries
// no tokens and no dimensions.
type UnimplementedRequest struct{}

// ExtractSession always returns nil.
func (r *UnimplementedRequest) ExtractSession() []byte { return nil }
// ExtractPollToken always returns nil.
func (r *UnimplementedRequest) ExtractPollToken() []byte { return nil }
// ExtractSessionToken always returns nil.
func (r *UnimplementedRequest) ExtractSessionToken() []byte { return nil }
// ExtractDimensions always returns nil.
func (r *UnimplementedRequest) ExtractDimensions() map[string][]string { return nil }
Expand Down
4 changes: 4 additions & 0 deletions swarming/server/botsession/botsession.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package botsession
import (
"context"
"fmt"
"time"

"go.opentelemetry.io/otel/trace"
"google.golang.org/protobuf/encoding/prototext"
Expand All @@ -31,6 +32,9 @@ import (
"go.chromium.org/luci/swarming/server/hmactoken"
)

// Expiry is how long a new Swarming session token will last (one hour).
const Expiry = time.Hour

// cryptoCtx is used when signing and checking the token as a cryptographic
// context (to make sure produced token can't be incorrectly used in other
// protocols that use the same secret key).
Expand Down
187 changes: 119 additions & 68 deletions swarming/server/botsrv/botsrv.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (
"go.chromium.org/luci/tokenserver/auth/machine"

internalspb "go.chromium.org/luci/swarming/proto/internals"
"go.chromium.org/luci/swarming/server/botsession"
"go.chromium.org/luci/swarming/server/cfg"
"go.chromium.org/luci/swarming/server/hmactoken"
"go.chromium.org/luci/swarming/server/pyproxy"
Expand All @@ -51,19 +52,21 @@ import (
// RequestBody should be implemented by a JSON-serializable struct representing
// format of some particular request.
type RequestBody interface {
ExtractSession() []byte // the token with bot Session proto
ExtractPollToken() []byte // the poll token, if present
ExtractSessionToken() []byte // the session token, if present
ExtractSessionToken() []byte // the RBE session token, if present
ExtractDimensions() map[string][]string // dimensions reported by the bot, if present
ExtractDebugRequest() any // serialized as JSON and logged on errors
}

// Request is extracted from an authenticated request from a bot.
type Request struct {
BotID string // validated bot ID
SessionID string // validated RBE bot session ID, if present
SessionTokenExpired bool // true if the request has expired session token
PollState *internalspb.PollState // validated poll state
Dimensions map[string][]string // validated dimensions
BotID string // validated bot ID, TODO: delete (part of Session)
SessionID string // validated RBE bot session ID, if present, TODO: delete
SessionTokenExpired bool // true if the request has expired session token, TODO: delete
PollState *internalspb.PollState // validated poll state, TODO: delete
Dimensions map[string][]string // validated dimensions, TODO: delete
Session *internalspb.Session // the bot session from the session token
}

// Response is serialized as JSON and sent to the bot.
Expand All @@ -78,16 +81,31 @@ type Response any
// a gRPC error code that will be converted into an HTTP error.
type Handler[B any] func(ctx context.Context, body *B, req *Request) (Response, error)

// KnownBotInfo is information about a bot registered in the datastore.
//
// It is the value produced by a KnownBotProvider for a registered bot.
type KnownBotInfo struct {
// SessionID is the current bot session ID of this bot.
SessionID string
// Dimensions is a flat list of "k:v" dimensions registered by this bot in the
// last poll.
Dimensions []string
}

// KnownBotProvider knows how to return information about existing bots.
//
// It takes the ID of the bot to look up. It returns nil and no error if the
// bot is not registered in the datastore. Any returned error can be considered
// transient.
type KnownBotProvider func(ctx context.Context, botID string) (*KnownBotInfo, error)

// Server knows how to authenticate bot requests and route them to handlers.
type Server struct {
// router is where bot API routes are installed.
router *router.Router
// middlewares is the middleware chain passed along when installing routes.
middlewares router.MiddlewareChain
// hmacSecret is the secret used when checking bot tokens.
hmacSecret *hmactoken.Secret
// cfg provides the server config (e.g. bot groups).
cfg *cfg.Provider
// knownBots returns information about bots registered in the datastore.
knownBots KnownBotProvider
}

// New constructs new Server.
func New(ctx context.Context, cfg *cfg.Provider, r *router.Router, prx *pyproxy.Proxy, projectID string, hmacSecret *hmactoken.Secret) *Server {
func New(ctx context.Context, cfg *cfg.Provider, r *router.Router, prx *pyproxy.Proxy, bots KnownBotProvider, projectID string, hmacSecret *hmactoken.Secret) *Server {
gaeAppDomain := fmt.Sprintf("%s.appspot.com", projectID)

// Redirect to Python for eligible requests before hitting any other
Expand Down Expand Up @@ -121,6 +139,7 @@ func New(ctx context.Context, cfg *cfg.Provider, r *router.Router, prx *pyproxy.
),
hmacSecret: hmacSecret,
cfg: cfg,
knownBots: bots,
}
}

Expand Down Expand Up @@ -168,6 +187,8 @@ func GET(s *Server, route string, handler router.Handler) {
// some JSON-serialized response.
//
// It performs bot authentication and authorization based on bots.cfg config.
//
// TODO: Most of this will be gone once Bot Session tokens are rolled out.
func JSON[B any, RB RequestBodyConstraint[B]](s *Server, route string, h Handler[B]) {
s.router.POST(route, s.middlewares, func(c *router.Context) {
ctx := c.Request.Context()
Expand Down Expand Up @@ -303,36 +324,8 @@ func JSON[B any, RB RequestBodyConstraint[B]](s *Server, route string, h Handler
}
}

// Authenticate the bot based on the config. Do it in a dry run mode for now
// to compare with the authentication based on the poll state token. Check
// various other bits of the poll state token as well (like enforced
// dimensions). Once it's confirmed there are no differences, the poll state
// token mechanism can be retired.
var dryRunAuthErr error
var ignoreDryRunAuthErr bool
dims := RB(body).ExtractDimensions()
botID, err := botIDFromDimensions(dims)
if err != nil {
logging.Errorf(ctx, "bot_auth: bad bot ID in dims: %s", err)
dryRunAuthErr = err
} else {
if fromPollState := extractBotID(pollState); fromPollState != botID {
logging.Errorf(ctx, "bot_auth: mismatch in bot ID from poll state (%q) and dims (%q)", fromPollState, botID)
}
botGroup := s.cfg.Cached(ctx).BotGroup(botID)
if !sameEnforcedDims(botID, botGroup.Dimensions, pollState.EnforcedDimensions) {
logging.Errorf(ctx, "bot_auth: mismatch in enforced dimensions (%v vs %v)", botGroup.Dimensions, pollState.EnforcedDimensions)
}
dryRunAuthErr = AuthorizeBot(ctx, botID, botGroup.Auth)
if transient.Tag.In(dryRunAuthErr) {
logging.Errorf(ctx, "bot_auth: ignoring transient error when checking bot creds: %s", dryRunAuthErr)
dryRunAuthErr = nil
ignoreDryRunAuthErr = true
}
}

// Extract bot ID from the validated PollToken.
botID = extractBotID(pollState)
botID := extractBotID(pollState)
if botID == "" {
writeErr(status.Errorf(codes.InvalidArgument, "no bot ID"))
return
Expand All @@ -348,21 +341,15 @@ func JSON[B any, RB RequestBodyConstraint[B]](s *Server, route string, h Handler
if transient.Tag.In(err) {
writeErr(status.Errorf(codes.Internal, "transient error checking bot credentials: %s", err))
} else {
if !ignoreDryRunAuthErr && dryRunAuthErr == nil {
logging.Errorf(ctx, "bot_auth: bot auth mismatch: bots.cfg method succeeded, but poll token method failed")
}
writeErr(status.Errorf(codes.Unauthenticated, "bad bot credentials: %s", err))
}
return
}

if !ignoreDryRunAuthErr && dryRunAuthErr != nil {
logging.Errorf(ctx, "bot_auth: bot auth mismatch: bots.cfg method failed, but poll token method succeeded; poll token:\n%s", prettyProto(pollState))
}

// Apply verified state stored in PollState on top of whatever was reported
// by the bot. Normally functioning bots should report the same values as
// stored in the token.
dims := RB(body).ExtractDimensions()
for _, dim := range pollState.EnforcedDimensions {
reported := dims[dim.Key]
if !slices.Equal(reported, dim.Values) {
Expand All @@ -379,13 +366,50 @@ func JSON[B any, RB RequestBodyConstraint[B]](s *Server, route string, h Handler
return
}

// We'll soon switch to using the bot dimensions stored in the datastore.
// Verify they match what the bot is sending via tokens.
var swarmingSID string
switch knownBot, err := s.knownBots(ctx, botID); {
case err != nil:
logging.Errorf(ctx, "Failed to fetch BotInfo of %s: %s", botID, err)
case knownBot == nil:
logging.Errorf(ctx, "Missing BotInfo of %s", botID)
default:
swarmingSID = knownBot.SessionID
fromDS := dimsFlatToString(knownBot.Dimensions)
fromTok := dimsMapToString(dims)
if fromDS != fromTok {
logging.Errorf(ctx, "Dims from datastore != dims from token:\nDatastore: %s\nToken: %s", fromDS, fromTok)
}
}

// If the request has a Swarming bot session token, verify it is valid and
// the information inside matches what was extracted above from other
// (deprecated) tokens. This is a dry run check for now.
var swarmingSession *internalspb.Session
if sessionTok := RB(body).ExtractSession(); len(sessionTok) != 0 {
swarmingSession = checkSwarmingSession(ctx,
sessionTok,
s.hmacSecret,
botID,
swarmingSID,
sessionState.GetRbeBotSessionId(),
)
if swarmingSession != nil {
logging.Infof(ctx, "Swarming session ID: %s", swarmingSession.SessionId)
}
} else {
logging.Infof(ctx, "No Swarming session token")
}

// The request is valid, dispatch it to the handler.
resp, err := h(ctx, body, &Request{
BotID: botID,
SessionID: sessionState.GetRbeBotSessionId(),
SessionTokenExpired: sessionTokenExpired,
PollState: pollState,
Dimensions: dims,
Session: swarmingSession,
})
if err != nil {
writeErr(err)
Expand All @@ -406,6 +430,41 @@ func JSON[B any, RB RequestBodyConstraint[B]](s *Server, route string, h Handler
})
}

// checkSwarmingSession decodes a Swarming bot session token using the given
// secret and cross-checks the session inside against the supplied values,
// logging every discrepancy it finds.
//
// It returns nil when the token can't be unmarshaled, when the session has
// expired, or when its bot ID or RBE session ID differ from botID or rbeSID.
// A failing bot authorization or a session ID different from swarmingSID is
// only logged: the session is still returned in that case.
//
// This is a dry run check used until session tokens become authoritative.
func checkSwarmingSession(ctx context.Context, tok []byte, s *hmactoken.Secret, botID, swarmingSID, rbeSID string) *internalspb.Session {
	sess, err := botsession.Unmarshal(tok, s)
	if err != nil {
		logging.Errorf(ctx, "Bad session token: %s", err)
		return nil
	}

	// Discrepancies that make the session unusable. They are examined in order
	// and only the first one found is reported.
	switch {
	case clock.Now(ctx).After(sess.Expiry.AsTime()):
		logging.Errorf(ctx, "Expired session:\n%s", botsession.FormatForDebug(sess))
		return nil
	case sess.BotId != botID:
		logging.Errorf(ctx, "Wrong bot ID:\n%s", botsession.FormatForDebug(sess))
		return nil
	case sess.RbeBotSessionId != rbeSID:
		logging.Errorf(ctx, "Wrong RBE session ID:\n%s", botsession.FormatForDebug(sess))
		return nil
	}

	// The checks below are tolerated for now: if they fail, something is still
	// broken, but the token keeps being used. They will all become real errors
	// once Swarming bot sessions are fully implemented.
	if authErr := AuthorizeBot(ctx, sess.BotId, sess.BotConfig.GetBotAuth()); authErr != nil {
		logging.Errorf(ctx, "Failing bot authorization: %s", authErr)
		logging.Errorf(ctx, "Session:\n%s", botsession.FormatForDebug(sess))
	}
	if sess.SessionId != swarmingSID {
		logging.Errorf(ctx, "Wrong session ID: %s != %s", sess.SessionId, swarmingSID)
	}

	return sess
}

// prettyProto formats a proto message for logs.
func prettyProto(msg proto.Message) string {
blob, err := prototext.MarshalOptions{
Expand Down Expand Up @@ -435,31 +494,6 @@ func botIDFromDimensions(dims map[string][]string) (string, error) {
}
}

// sameEnforcedDims reports whether the enforced dimensions from bots.cfg
// match the ones carried by the poll state.
//
// Both sides are flattened into "key:value" pairs, sorted, and compared as
// lists, so the original ordering does not matter. The config side always gets
// an extra "id:<botID>" pair.
//
// This is temporary until the poll state token is removed.
func sameEnforcedDims(botID string, cfgDims map[string][]string, tokDims []*internalspb.PollState_Dimension) bool {
	expected := []string{"id:" + botID}
	for k, vs := range cfgDims {
		for _, v := range vs {
			expected = append(expected, k+":"+v)
		}
	}
	sort.Strings(expected)

	var actual []string
	for _, dim := range tokDims {
		for _, v := range dim.Values {
			actual = append(actual, dim.Key+":"+v)
		}
	}
	sort.Strings(actual)

	return slices.Equal(expected, actual)
}

// checkCredentials checks the bot credentials in the context match what is
// required by the PollState.
//
Expand Down Expand Up @@ -548,3 +582,20 @@ func extractBotID(s *internalspb.PollState) string {
}
return ""
}

// dimsFlatToString renders a flat list of dimensions as a single debug string
// with entries separated by " | ".
func dimsFlatToString(dims []string) string {
	const sep = " | "
	return strings.Join(dims, sep)
}

// dimsMapToString converts a dimensions map into a debug string.
func dimsMapToString(dims map[string][]string) string {
var kv []string
for k, vals := range dims {
for _, v := range vals {
kv = append(kv, fmt.Sprintf("%s:%s", k, v))
}
}
sort.Strings(kv)
return dimsFlatToString(kv)
}
5 changes: 5 additions & 0 deletions swarming/server/botsrv/botsrv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,13 @@ import (
)

// testRequest is a request body used in tests. Each Extract* method visible
// here returns the corresponding field as is.
type testRequest struct {
Session []byte
Dimensions map[string][]string
PollToken []byte
SessionToken []byte
}

func (r *testRequest) ExtractSession() []byte { return r.Session }
func (r *testRequest) ExtractPollToken() []byte { return r.PollToken }
func (r *testRequest) ExtractSessionToken() []byte { return r.SessionToken }
func (r *testRequest) ExtractDimensions() map[string][]string { return r.Dimensions }
Expand Down Expand Up @@ -85,6 +87,9 @@ func TestBotHandler(t *testing.T) {
Passive: [][]byte{[]byte("also-secret")},
}),
cfg: cfgtest.MockConfigs(ctx, cfgtest.NewMockedConfigs()),
knownBots: func(ctx context.Context, botID string) (*KnownBotInfo, error) {
return nil, nil
},
}

var lastBody *testRequest
Expand Down
Loading

0 comments on commit 19f5647

Please sign in to comment.