diff --git a/internal/pkg/beskar/plugin.go b/internal/pkg/beskar/plugin.go index d26bd10..bf449b3 100644 --- a/internal/pkg/beskar/plugin.go +++ b/internal/pkg/beskar/plugin.go @@ -28,6 +28,7 @@ import ( "github.com/distribution/distribution/v3" "github.com/hashicorp/memberlist" "github.com/sirupsen/logrus" + "go.ciq.dev/beskar/internal/pkg/config" "go.ciq.dev/beskar/internal/pkg/gossip" "go.ciq.dev/beskar/internal/pkg/router" eventv1 "go.ciq.dev/beskar/pkg/api/event/v1" @@ -118,7 +119,7 @@ func (pm *pluginManager) ServeHTTP(w http.ResponseWriter, r *http.Request) { pl.ServeHTTP(w, r) } -func (pm *pluginManager) register(node *memberlist.Node, meta *gossip.BeskarMeta) error { +func (pm *pluginManager) register(node *memberlist.Node, meta *gossip.BeskarMeta, beskarConfig *config.BeskarConfig) error { hostport := net.JoinHostPort(node.Addr.String(), strconv.Itoa(int(meta.ServicePort))) info, err := pm.getPluginInfo(hostport) if err != nil { @@ -146,7 +147,7 @@ func (pm *pluginManager) register(node *memberlist.Node, meta *gossip.BeskarMeta logger: pm.logger, } - if err := pl.initRouter(info); err != nil { + if err := pl.initRouter(info, beskarConfig.Router.BodyLimit); err != nil { return err } @@ -160,7 +161,7 @@ func (pm *pluginManager) register(node *memberlist.Node, meta *gossip.BeskarMeta pl.version = info.Version pl.mediaTypes = mediaTypes - if err := pl.initRouter(info); err != nil { + if err := pl.initRouter(info, beskarConfig.Router.BodyLimit); err != nil { return err } } @@ -343,7 +344,7 @@ func (p *plugin) sendEvent(ctx context.Context, event *eventv1.EventPayload, nod }, backoff.WithContext(eb, ctx)) } -func (p *plugin) initRouter(info *pluginv1.Info) error { +func (p *plugin) initRouter(info *pluginv1.Info, bodyLimit int64) error { var routerOptions []router.RegoRouterOption if info.Router == nil { @@ -352,6 +353,9 @@ func (p *plugin) initRouter(info *pluginv1.Info) error { if len(info.Router.Data) > 0 { routerOptions = append(routerOptions, router.WithData(bytes.NewReader(info.Router.Data))) } + if bodyLimit > 0 { + routerOptions = append(routerOptions, router.WithBodyLimit(bodyLimit)) + } rr, err := router.New(info.Name, string(info.Router.Rego), routerOptions...) if err != nil { return err diff --git a/internal/pkg/beskar/registry.go b/internal/pkg/beskar/registry.go index 44e7b5a..2e81017 100644 --- a/internal/pkg/beskar/registry.go +++ b/internal/pkg/beskar/registry.go @@ -158,7 +158,7 @@ func (br *Registry) startGossipWatcher() { br.logger.Debugf("Added groupcache peer %s", peer) case gossip.PluginInstance: br.logger.Infof("Register plugin") - if err := br.pluginManager.register(node, meta); err != nil { + if err := br.pluginManager.register(node, meta, br.beskarConfig); err != nil { br.logger.Errorf("plugin register error: %s", err) } } @@ -182,7 +182,7 @@ func (br *Registry) startGossipWatcher() { br.logger.Debugf("Added groupcache peer %s", peer) case gossip.PluginInstance: br.logger.Infof("Register plugin") - if err := br.pluginManager.register(node, meta); err != nil { + if err := br.pluginManager.register(node, meta, br.beskarConfig); err != nil { br.logger.Errorf("plugin register error: %s", err) } } diff --git a/internal/pkg/config/beskar.go b/internal/pkg/config/beskar.go index 6575fef..f184d4c 100644 --- a/internal/pkg/config/beskar.go +++ b/internal/pkg/config/beskar.go @@ -31,6 +31,10 @@ type Cache struct { Size uint32 `yaml:"size"` } +type Router struct { + BodyLimit int64 `yaml:"bodylimit"` +} + type BeskarConfig struct { Version string `yaml:"version"` Profiling bool `yaml:"profiling"` @@ -38,6 +42,7 @@ type BeskarConfig struct { Cache Cache `yaml:"cache"` Gossip gossip.Config `yaml:"gossip"` Registry *configuration.Configuration `yaml:"registry"` + Router Router `yaml:"router"` } type BeskarConfigV1 BeskarConfig diff --git a/internal/pkg/config/default/beskar.yaml b/internal/pkg/config/default/beskar.yaml index 482d0c3..224cb7c 100644 --- a/internal/pkg/config/default/beskar.yaml +++ b/internal/pkg/config/default/beskar.yaml @@ -11,6 +11,9 @@ gossip: key: XD1IOhcp0HWFgZJ/HAaARqMKJwfMWtz284Yj7wxmerA= peers: [] +router: + bodyLimit: 8192 + # hostname returned to plugins to access registry service, # automatically set when deployed on kubernetes hostname: localhost diff --git a/internal/pkg/router/builtin.go b/internal/pkg/router/builtin.go index 8e1e29d..93d4027 100644 --- a/internal/pkg/router/builtin.go +++ b/internal/pkg/router/builtin.go @@ -11,7 +11,6 @@ import ( "io" "net/http" "strings" - "sync" "github.com/distribution/distribution/v3" "github.com/distribution/reference" @@ -28,6 +27,7 @@ type funcContext struct { req *http.Request registry distribution.Namespace builtinErr error + bodyLimit int64 } var ociBlobDigestBuiltin = rego.Function3( @@ -129,22 +129,6 @@ var ociBlobDigestBuiltin = rego.Function3( }, ) -type bodyCloser struct { - io.Reader - closeFn func() error -} - -func (bc bodyCloser) Close() error { - return bc.closeFn() -} - -var bufferPool = sync.Pool{ - New: func() interface{} { - buffer := make([]byte, 8192) - return &buffer - }, -} - var requestBodyBuiltin = rego.FunctionDyn( ®o.Function{ Name: "request.body", @@ -165,38 +149,33 @@ var requestBodyBuiltin = rego.FunctionDyn( } }() - if funcContext.req.Body != nil && funcContext.req.Body != http.NoBody { - buf := bufferPool.Get().(*[]byte) - - n, err := io.ReadAtLeast(funcContext.req.Body, *buf, 1) - if err != nil { - return nil, fmt.Errorf("empty body request") - } - - bodyReader := bytes.NewReader((*buf)[:n]) - - v, err := ast.ValueFromReader(bodyReader) - if err != nil { - return nil, err - } + if funcContext.req.Body == nil || funcContext.req.Body == http.NoBody { + v, err := ast.InterfaceToValue(nil) + return ast.NewTerm(v), err + } - _, _ = bodyReader.Seek(0, io.SeekStart) + buf := new(bytes.Buffer) - originalBody := funcContext.req.Body + // plugin API are receiving small JSON objects, so we limit it to 8KB by default + // configurable via beskar router configuration. + _, err := buf.ReadFrom(io.LimitReader(funcContext.req.Body, funcContext.bodyLimit)) + if err != nil { + return nil, err + } else if err := funcContext.req.Body.Close(); err != nil { + return nil, err + } - funcContext.req.Body = &bodyCloser{ - Reader: bodyReader, - closeFn: func() error { - defer bufferPool.Put(buf) - return originalBody.Close() - }, - } + bodyReader := bytes.NewReader(buf.Bytes()) - return ast.NewTerm(v), nil + v, err := ast.ValueFromReader(bodyReader) + if err != nil { + return nil, err + } else if _, err = bodyReader.Seek(0, io.SeekStart); err != nil { + return nil, err } - v, err := ast.InterfaceToValue(nil) + funcContext.req.Body = io.NopCloser(bodyReader) - return ast.NewTerm(v), err + return ast.NewTerm(v), nil }, ) diff --git a/internal/pkg/router/router.go b/internal/pkg/router/router.go index fbc6b8c..15efcc2 100644 --- a/internal/pkg/router/router.go +++ b/internal/pkg/router/router.go @@ -28,9 +28,10 @@ type Result struct { type RegoOption = func(r *rego.Rego) type RegoRouter struct { - name string - options []RegoOption - peq rego.PreparedEvalQuery + name string + options []RegoOption + peq rego.PreparedEvalQuery + bodyLimit int64 } type RegoRouterOption func(r *RegoRouter) error @@ -59,6 +60,13 @@ func WithOption(option RegoOption) RegoRouterOption { } } +func WithBodyLimit(limit int64) RegoRouterOption { + return func(r *RegoRouter) error { + r.bodyLimit = limit + return nil + } +} + func New(name, module string, options ...RegoRouterOption) (_ *RegoRouter, err error) { router := &RegoRouter{ name: name, @@ -68,6 +76,7 @@ func New(name, module string, options ...RegoRouterOption) (_ *RegoRouter, err e ociBlobDigestBuiltin, requestBodyBuiltin, }, + bodyLimit: 8192, } for _, opt := range options { @@ -86,8 +95,9 @@ func New(name, module string, options ...RegoRouterOption) (_ *RegoRouter, err e func (rr *RegoRouter) Decision(req *http.Request, registry distribution.Namespace) (*Result, error) { fctx := &funcContext{ - req: req, - registry: registry, + req: req, + registry: registry, + bodyLimit: rr.bodyLimit, } ctx := context.WithValue(req.Context(), &funcContextKey, fctx)