diff --git a/docs/content/configuration.md b/docs/content/configuration.md index e44ff58888..8faec0d708 100644 --- a/docs/content/configuration.md +++ b/docs/content/configuration.md @@ -856,11 +856,15 @@ Caching represents the configuration of the inter-query cache that built-in func functions provided by OPA, `http.send` is currently the only one to utilize the inter-query cache. See the documentation on the [http.send built-in function](../policy-reference/#http) for information about the available caching options. -| Field | Type | Required | Description | -| --- | --- | --- | --- | -| `caching.inter_query_builtin_cache.max_size_bytes` | `int64` | No | Inter-query cache size limit in bytes. OPA will drop old items from the cache if this limit is exceeded. By default, no limit is set. | -| `caching.inter_query_builtin_cache.forced_eviction_threshold_percentage` | `int64` | No | Threshold limit configured as percentage of `caching.inter_query_builtin_cache.max_size_bytes`, when exceeded OPA will start dropping old items permaturely. By default, set to `100`. | -| `caching.inter_query_builtin_cache.stale_entry_eviction_period_seconds` | `int64` | No | Stale entry eviction period in seconds. OPA will drop expired items from the cache every `stale_entry_eviction_period_seconds`. By default, set to `0` indicating stale entry eviction is disabled. | +It also represents the configuration of the inter-query value cache that built-in functions can utilize. Currently, this +cache is utilized by the `regex` and `glob` built-in functions for compiled regex and glob match patterns respectively. + +| Field | Type | Required | Description | +|--------------------------------------------------------------------------| --- | --- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `caching.inter_query_builtin_cache.max_size_bytes` | `int64` | No | Inter-query cache size limit in bytes. OPA will drop old items from the cache if this limit is exceeded. By default, no limit is set. | +| `caching.inter_query_builtin_cache.forced_eviction_threshold_percentage` | `int64` | No | Threshold limit configured as percentage of `caching.inter_query_builtin_cache.max_size_bytes`, when exceeded OPA will start dropping old items permaturely. By default, set to `100`. | +| `caching.inter_query_builtin_cache.stale_entry_eviction_period_seconds` | `int64` | No | Stale entry eviction period in seconds. OPA will drop expired items from the cache every `stale_entry_eviction_period_seconds`. By default, set to `0` indicating stale entry eviction is disabled. | +| `caching.inter_query_builtin_value_cache.max_num_entries` | `int` | No | Maximum number of entries in the Inter-query value cache. OPA will drop random items from the cache if this limit is exceeded. By default, set to `0` indicating unlimited size. | ## Distributed tracing diff --git a/internal/rego/opa/options.go b/internal/rego/opa/options.go index ea1e339c1b..b58a05ee8e 100644 --- a/internal/rego/opa/options.go +++ b/internal/rego/opa/options.go @@ -18,13 +18,14 @@ type Result struct { // EvalOpts define options for performing an evaluation. type EvalOpts struct { - Input *interface{} - Metrics metrics.Metrics - Entrypoint int32 - Time time.Time - Seed io.Reader - InterQueryBuiltinCache cache.InterQueryCache - NDBuiltinCache builtins.NDBCache - PrintHook print.Hook - Capabilities *ast.Capabilities + Input *interface{} + Metrics metrics.Metrics + Entrypoint int32 + Time time.Time + Seed io.Reader + InterQueryBuiltinCache cache.InterQueryCache + InterQueryBuiltinValueCache cache.InterQueryValueCache + NDBuiltinCache builtins.NDBCache + PrintHook print.Hook + Capabilities *ast.Capabilities } diff --git a/plugins/discovery/discovery_test.go b/plugins/discovery/discovery_test.go index c6aa726bd4..5367e8f895 100644 --- a/plugins/discovery/discovery_test.go +++ b/plugins/discovery/discovery_test.go @@ -1905,7 +1905,11 @@ func TestReconfigureWithLocalOverride(t *testing.T) { *period = 10 threshold := new(int64) *threshold = 90 - expectedCacheConf := &cache.Config{InterQueryBuiltinCache: cache.InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, StaleEntryEvictionPeriodSeconds: period, ForcedEvictionThresholdPercentage: threshold}} + maxNumEntriesInterQueryValueCache := new(int) + *maxNumEntriesInterQueryValueCache = 0 + + expectedCacheConf := &cache.Config{InterQueryBuiltinCache: cache.InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, StaleEntryEvictionPeriodSeconds: period, ForcedEvictionThresholdPercentage: threshold}, + InterQueryBuiltinValueCache: cache.InterQueryBuiltinValueCacheConfig{MaxNumEntries: maxNumEntriesInterQueryValueCache}} if !reflect.DeepEqual(cacheConf, expectedCacheConf) { t.Fatalf("want %v got %v", expectedCacheConf, cacheConf) diff --git a/plugins/plugins.go b/plugins/plugins.go index bacdd15076..567acfb817 100644 --- a/plugins/plugins.go +++ b/plugins/plugins.go @@ -576,7 +576,7 @@ func (m *Manager) Labels() map[string]string { return m.Config.Labels } -// InterQueryBuiltinCacheConfig returns the configuration for the inter-query cache. +// InterQueryBuiltinCacheConfig returns the configuration for the inter-query caches. func (m *Manager) InterQueryBuiltinCacheConfig() *cache.Config { m.mtx.Lock() defer m.mtx.Unlock() diff --git a/plugins/plugins_test.go b/plugins/plugins_test.go index b04096bab7..960d44ae8f 100644 --- a/plugins/plugins_test.go +++ b/plugins/plugins_test.go @@ -396,7 +396,7 @@ func TestPluginManagerInitIdempotence(t *testing.T) { } func TestManagerWithCachingConfig(t *testing.T) { - m, err := New([]byte(`{"caching": {"inter_query_builtin_cache": {"max_size_bytes": 100}}}`), "test", inmem.New()) + m, err := New([]byte(`{"caching": {"inter_query_builtin_cache": {"max_size_bytes": 100}, "inter_query_builtin_value_cache": {"max_num_entries": 100}}}`), "test", inmem.New()) if err != nil { t.Fatal(err) } @@ -404,6 +404,8 @@ func TestManagerWithCachingConfig(t *testing.T) { expected, _ := cache.ParseCachingConfig(nil) limit := int64(100) expected.InterQueryBuiltinCache.MaxSizeBytes = &limit + maxNumEntriesInterQueryValueCache := int(100) + expected.InterQueryBuiltinValueCache.MaxNumEntries = &maxNumEntriesInterQueryValueCache if !reflect.DeepEqual(m.InterQueryBuiltinCacheConfig(), expected) { t.Fatalf("want %+v got %+v", expected, m.interQueryBuiltinCacheConfig) @@ -414,6 +416,12 @@ func TestManagerWithCachingConfig(t *testing.T) { if err == nil { t.Fatal("expected error but got nil") } + + // config error + _, err = New([]byte(`{"caching": {"inter_query_builtin_value_cache": {"max_num_entries": "100"}}}`), "test", inmem.New()) + if err == nil { + t.Fatal("expected error but got nil") + } } func TestManagerWithNDCachingConfig(t *testing.T) { diff --git a/rego/rego.go b/rego/rego.go index df349b25cb..8672efd669 100644 --- a/rego/rego.go +++ b/rego/rego.go @@ -99,32 +99,33 @@ type preparedQuery struct { // EvalContext defines the set of options allowed to be set at evaluation // time. Any other options will need to be set on a new Rego object. type EvalContext struct { - hasInput bool - time time.Time - seed io.Reader - rawInput *interface{} - parsedInput ast.Value - metrics metrics.Metrics - txn storage.Transaction - instrument bool - instrumentation *topdown.Instrumentation - partialNamespace string - queryTracers []topdown.QueryTracer - compiledQuery compiledQuery - unknowns []string - disableInlining []ast.Ref - parsedUnknowns []*ast.Term - indexing bool - earlyExit bool - interQueryBuiltinCache cache.InterQueryCache - ndBuiltinCache builtins.NDBCache - resolvers []refResolver - sortSets bool - copyMaps bool - printHook print.Hook - capabilities *ast.Capabilities - strictBuiltinErrors bool - virtualCache topdown.VirtualCache + hasInput bool + time time.Time + seed io.Reader + rawInput *interface{} + parsedInput ast.Value + metrics metrics.Metrics + txn storage.Transaction + instrument bool + instrumentation *topdown.Instrumentation + partialNamespace string + queryTracers []topdown.QueryTracer + compiledQuery compiledQuery + unknowns []string + disableInlining []ast.Ref + parsedUnknowns []*ast.Term + indexing bool + earlyExit bool + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + ndBuiltinCache builtins.NDBCache + resolvers []refResolver + sortSets bool + copyMaps bool + printHook print.Hook + capabilities *ast.Capabilities + strictBuiltinErrors bool + virtualCache topdown.VirtualCache } func (e *EvalContext) RawInput() *interface{} { @@ -147,6 +148,10 @@ func (e *EvalContext) InterQueryBuiltinCache() cache.InterQueryCache { return e.interQueryBuiltinCache } +func (e *EvalContext) InterQueryBuiltinValueCache() cache.InterQueryValueCache { + return e.interQueryBuiltinValueCache +} + func (e *EvalContext) PrintHook() print.Hook { return e.printHook } @@ -307,6 +312,14 @@ func EvalInterQueryBuiltinCache(c cache.InterQueryCache) EvalOption { } } +// EvalInterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize +// during evaluation. +func EvalInterQueryBuiltinValueCache(c cache.InterQueryValueCache) EvalOption { + return func(e *EvalContext) { + e.interQueryBuiltinValueCache = c + } +} + // EvalNDBuiltinCache sets the non-deterministic builtin cache that built-in functions can // use during evaluation. func EvalNDBuiltinCache(c builtins.NDBCache) EvalOption { @@ -546,64 +559,65 @@ type loadPaths struct { // Rego constructs a query and can be evaluated to obtain results. type Rego struct { - query string - parsedQuery ast.Body - compiledQueries map[queryType]compiledQuery - pkg string - parsedPackage *ast.Package - imports []string - parsedImports []*ast.Import - rawInput *interface{} - parsedInput ast.Value - unknowns []string - parsedUnknowns []*ast.Term - disableInlining []string - shallowInlining bool - skipPartialNamespace bool - partialNamespace string - modules []rawModule - parsedModules map[string]*ast.Module - compiler *ast.Compiler - store storage.Store - ownStore bool - txn storage.Transaction - metrics metrics.Metrics - queryTracers []topdown.QueryTracer - tracebuf *topdown.BufferTracer - trace bool - instrumentation *topdown.Instrumentation - instrument bool - capture map[*ast.Expr]ast.Var // map exprs to generated capture vars - termVarID int - dump io.Writer - runtime *ast.Term - time time.Time - seed io.Reader - capabilities *ast.Capabilities - builtinDecls map[string]*ast.Builtin - builtinFuncs map[string]*topdown.Builtin - unsafeBuiltins map[string]struct{} - loadPaths loadPaths - bundlePaths []string - bundles map[string]*bundle.Bundle - skipBundleVerification bool - interQueryBuiltinCache cache.InterQueryCache - ndBuiltinCache builtins.NDBCache - strictBuiltinErrors bool - builtinErrorList *[]topdown.Error - resolvers []refResolver - schemaSet *ast.SchemaSet - target string // target type (wasm, rego, etc.) - opa opa.EvalEngine - generateJSON func(*ast.Term, *EvalContext) (interface{}, error) - printHook print.Hook - enablePrintStatements bool - distributedTacingOpts tracing.Options - strict bool - pluginMgr *plugins.Manager - plugins []TargetPlugin - targetPrepState TargetPluginEval - regoVersion ast.RegoVersion + query string + parsedQuery ast.Body + compiledQueries map[queryType]compiledQuery + pkg string + parsedPackage *ast.Package + imports []string + parsedImports []*ast.Import + rawInput *interface{} + parsedInput ast.Value + unknowns []string + parsedUnknowns []*ast.Term + disableInlining []string + shallowInlining bool + skipPartialNamespace bool + partialNamespace string + modules []rawModule + parsedModules map[string]*ast.Module + compiler *ast.Compiler + store storage.Store + ownStore bool + txn storage.Transaction + metrics metrics.Metrics + queryTracers []topdown.QueryTracer + tracebuf *topdown.BufferTracer + trace bool + instrumentation *topdown.Instrumentation + instrument bool + capture map[*ast.Expr]ast.Var // map exprs to generated capture vars + termVarID int + dump io.Writer + runtime *ast.Term + time time.Time + seed io.Reader + capabilities *ast.Capabilities + builtinDecls map[string]*ast.Builtin + builtinFuncs map[string]*topdown.Builtin + unsafeBuiltins map[string]struct{} + loadPaths loadPaths + bundlePaths []string + bundles map[string]*bundle.Bundle + skipBundleVerification bool + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + ndBuiltinCache builtins.NDBCache + strictBuiltinErrors bool + builtinErrorList *[]topdown.Error + resolvers []refResolver + schemaSet *ast.SchemaSet + target string // target type (wasm, rego, etc.) + opa opa.EvalEngine + generateJSON func(*ast.Term, *EvalContext) (interface{}, error) + printHook print.Hook + enablePrintStatements bool + distributedTacingOpts tracing.Options + strict bool + pluginMgr *plugins.Manager + plugins []TargetPlugin + targetPrepState TargetPluginEval + regoVersion ast.RegoVersion } // Function represents a built-in function that is callable in Rego. @@ -1114,6 +1128,14 @@ func InterQueryBuiltinCache(c cache.InterQueryCache) func(r *Rego) { } } +// InterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize +// during evaluation. +func InterQueryBuiltinValueCache(c cache.InterQueryValueCache) func(r *Rego) { + return func(r *Rego) { + r.interQueryBuiltinValueCache = c + } +} + // NDBuiltinCache sets the non-deterministic builtins cache. func NDBuiltinCache(c builtins.NDBCache) func(r *Rego) { return func(r *Rego) { @@ -1309,6 +1331,7 @@ func (r *Rego) Eval(ctx context.Context) (ResultSet, error) { EvalInstrument(r.instrument), EvalTime(r.time), EvalInterQueryBuiltinCache(r.interQueryBuiltinCache), + EvalInterQueryBuiltinValueCache(r.interQueryBuiltinValueCache), EvalSeed(r.seed), } @@ -1386,6 +1409,7 @@ func (r *Rego) Partial(ctx context.Context) (*PartialQueries, error) { EvalMetrics(r.metrics), EvalInstrument(r.instrument), EvalInterQueryBuiltinCache(r.interQueryBuiltinCache), + EvalInterQueryBuiltinValueCache(r.interQueryBuiltinValueCache), } if r.ndBuiltinCache != nil { @@ -2106,6 +2130,7 @@ func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) { WithIndexing(ectx.indexing). WithEarlyExit(ectx.earlyExit). WithInterQueryBuiltinCache(ectx.interQueryBuiltinCache). + WithInterQueryBuiltinValueCache(ectx.interQueryBuiltinValueCache). WithStrictBuiltinErrors(r.strictBuiltinErrors). WithBuiltinErrorList(r.builtinErrorList). WithSeed(ectx.seed). @@ -2164,7 +2189,6 @@ func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) { } func (r *Rego) evalWasm(ctx context.Context, ectx *EvalContext) (ResultSet, error) { - input := ectx.rawInput if ectx.parsedInput != nil { i := interface{}(ectx.parsedInput) @@ -2393,6 +2417,7 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, WithSkipPartialNamespace(r.skipPartialNamespace). WithShallowInlining(r.shallowInlining). WithInterQueryBuiltinCache(ectx.interQueryBuiltinCache). + WithInterQueryBuiltinValueCache(ectx.interQueryBuiltinValueCache). WithStrictBuiltinErrors(ectx.strictBuiltinErrors). WithSeed(ectx.seed). WithPrintHook(ectx.printHook) diff --git a/rego/rego_test.go b/rego/rego_test.go index 86550e35f5..208ce5123f 100644 --- a/rego/rego_test.go +++ b/rego/rego_test.go @@ -2432,6 +2432,55 @@ func TestEvalWithInterQueryCache(t *testing.T) { } } +func TestEvalWithInterQueryValueCache(t *testing.T) { + ctx := context.Background() + + // add an inter-query value cache + config, _ := cache.ParseCachingConfig(nil) + interQueryValueCache := cache.NewInterQueryValueCache(ctx, config) + + m := metrics.New() + + query := `regex.match("foo.*", "foobar")` + _, err := New(Query(query), InterQueryBuiltinValueCache(interQueryValueCache), Metrics(m)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + // eval again with same query + // this request should be served by the cache + _, err = New(Query(query), InterQueryBuiltinValueCache(interQueryValueCache), Metrics(m)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + if exp, act := uint64(1), m.Counter("rego_builtin_regex_interquery_value_cache_hits").Value(); exp != act { + t.Fatalf("expected %d cache hits, got %d", exp, act) + } + + query = `glob.match("*.example.com", ["."], "api.example.com")` + _, err = New(Query(query), InterQueryBuiltinValueCache(interQueryValueCache), Metrics(m)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + // eval again with same query + // this request should be served by the cache + _, err = New(Query(query), InterQueryBuiltinValueCache(interQueryValueCache), Metrics(m)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + _, err = New(Query(query), InterQueryBuiltinValueCache(interQueryValueCache), Metrics(m)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + if exp, act := uint64(2), m.Counter("rego_builtin_glob_interquery_value_cache_hits").Value(); exp != act { + t.Fatalf("expected %d cache hits, got %d", exp, act) + } +} + // We use http.send to ensure the NDBuiltinCache is involved. func TestEvalWithNDCache(t *testing.T) { var requests []*http.Request diff --git a/sdk/opa.go b/sdk/opa.go index 7203df042c..8b180ea1a6 100644 --- a/sdk/opa.go +++ b/sdk/opa.go @@ -53,9 +53,10 @@ type OPA struct { } type state struct { - manager *plugins.Manager - interQueryBuiltinCache cache.InterQueryCache - queryCache *queryCache + manager *plugins.Manager + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + queryCache *queryCache } // New returns a new OPA object. This function should minimally be called with @@ -235,6 +236,7 @@ func (opa *OPA) configure(ctx context.Context, bs []byte, ready chan struct{}, b opa.state.manager = manager opa.state.queryCache.Clear() opa.state.interQueryBuiltinCache = cache.NewInterQueryCacheWithContext(ctx, manager.InterQueryBuiltinCacheConfig()) + opa.state.interQueryBuiltinValueCache = cache.NewInterQueryValueCache(ctx, manager.InterQueryBuiltinCacheConfig()) opa.config = bs return nil @@ -277,22 +279,23 @@ func (opa *OPA) Decision(ctx context.Context, options DecisionOptions) (*Decisio &record, func(s state, result *DecisionResult) { result.Result, result.Provenance, record.InputAST, record.Bundles, record.Error = evaluate(ctx, evalArgs{ - runtime: s.manager.Info, - printHook: s.manager.PrintHook(), - compiler: s.manager.GetCompiler(), - store: s.manager.Store, - queryCache: s.queryCache, - interQueryCache: s.interQueryBuiltinCache, - ndbcache: ndbc, - txn: record.Txn, - now: record.Timestamp, - path: record.Path, - input: *record.Input, - m: record.Metrics, - strictBuiltinErrors: options.StrictBuiltinErrors, - tracer: options.Tracer, - profiler: options.Profiler, - instrument: options.Instrument, + runtime: s.manager.Info, + printHook: s.manager.PrintHook(), + compiler: s.manager.GetCompiler(), + store: s.manager.Store, + queryCache: s.queryCache, + interQueryCache: s.interQueryBuiltinCache, + interQueryBuiltinValueCache: s.interQueryBuiltinValueCache, + ndbcache: ndbc, + txn: record.Txn, + now: record.Timestamp, + path: record.Path, + input: *record.Input, + m: record.Metrics, + strictBuiltinErrors: options.StrictBuiltinErrors, + tracer: options.Tracer, + profiler: options.Profiler, + instrument: options.Instrument, }) if record.Error == nil { record.Results = &result.Result @@ -506,22 +509,23 @@ func IsUndefinedErr(err error) bool { } type evalArgs struct { - runtime *ast.Term - printHook print.Hook - compiler *ast.Compiler - store storage.Store - txn storage.Transaction - queryCache *queryCache - interQueryCache cache.InterQueryCache - now time.Time - path string - input interface{} - ndbcache builtins.NDBCache - m metrics.Metrics - strictBuiltinErrors bool - tracer topdown.QueryTracer - profiler topdown.QueryTracer - instrument bool + runtime *ast.Term + printHook print.Hook + compiler *ast.Compiler + store storage.Store + txn storage.Transaction + queryCache *queryCache + interQueryCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + now time.Time + path string + input interface{} + ndbcache builtins.NDBCache + m metrics.Metrics + strictBuiltinErrors bool + tracer topdown.QueryTracer + profiler topdown.QueryTracer + instrument bool } func evaluate(ctx context.Context, args evalArgs) (interface{}, types.ProvenanceV1, ast.Value, map[string]server.BundleInfo, error) { @@ -581,6 +585,7 @@ func evaluate(ctx context.Context, args evalArgs) (interface{}, types.Provenance rego.EvalTransaction(args.txn), rego.EvalMetrics(args.m), rego.EvalInterQueryBuiltinCache(args.interQueryCache), + rego.EvalInterQueryBuiltinValueCache(args.interQueryBuiltinValueCache), rego.EvalNDBuiltinCache(args.ndbcache), rego.EvalQueryTracer(args.tracer), rego.EvalMetrics(args.m), diff --git a/server/authorizer/authorizer.go b/server/authorizer/authorizer.go index 8dcc3c3394..4240f40377 100644 --- a/server/authorizer/authorizer.go +++ b/server/authorizer/authorizer.go @@ -32,6 +32,7 @@ type Basic struct { printHook print.Hook enablePrintStatements bool interQueryCache cache.InterQueryCache + interQueryValueCache cache.InterQueryValueCache } // Runtime returns an argument that sets the runtime on the authorizer. @@ -73,6 +74,13 @@ func InterQueryCache(interQueryCache cache.InterQueryCache) func(*Basic) { } } +// InterQueryValueCache enables the inter-query value cache on the authorizer +func InterQueryValueCache(interQueryValueCache cache.InterQueryValueCache) func(*Basic) { + return func(b *Basic) { + b.interQueryValueCache = interQueryValueCache + } +} + // NewBasic returns a new Basic object. func NewBasic(inner http.Handler, compiler func() *ast.Compiler, store storage.Store, opts ...func(*Basic)) http.Handler { b := &Basic{ @@ -107,6 +115,7 @@ func (h *Basic) ServeHTTP(w http.ResponseWriter, r *http.Request) { rego.EnablePrintStatements(h.enablePrintStatements), rego.PrintHook(h.printHook), rego.InterQueryBuiltinCache(h.interQueryCache), + rego.InterQueryBuiltinValueCache(h.interQueryValueCache), ) rs, err := rego.Eval(r.Context()) diff --git a/server/authorizer/authorizer_test.go b/server/authorizer/authorizer_test.go index 17478c8f1a..712142054a 100644 --- a/server/authorizer/authorizer_test.go +++ b/server/authorizer/authorizer_test.go @@ -6,6 +6,7 @@ package authorizer import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -516,6 +517,43 @@ func TestInterQueryCache(t *testing.T) { } } +func TestInterQueryValueCache(t *testing.T) { + + compiler := func() *ast.Compiler { + module := ` + package system.authz + import rego.v1 + + allow if { + regex.match("foo.*", "foobar") + }` + c := ast.NewCompiler() + c.Compile(map[string]*ast.Module{ + "test.rego": ast.MustParseModule(module), + }) + if c.Failed() { + t.Fatalf("Unexpected error compiling test module: %v", c.Errors) + } + return c + } + + recorder := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "http://localhost:8181/v1/data", nil) + if err != nil { + t.Fatal(err) + } + + config, _ := cache.ParseCachingConfig(nil) + interQueryValueCache := cache.NewInterQueryValueCache(context.Background(), config) + + basic := NewBasic(&mockHandler{}, compiler, inmem.New(), InterQueryValueCache(interQueryValueCache), Decision(func() ast.Ref { + return ast.MustParseRef("data.system.authz.allow") + })) + + // Execute the policy + basic.ServeHTTP(recorder, req) +} + func Equal(a, b []string) bool { if len(a) != len(b) { return false diff --git a/server/server.go b/server/server.go index c059719a96..bb5c70767e 100644 --- a/server/server.go +++ b/server/server.go @@ -111,42 +111,43 @@ type Server struct { Handler http.Handler DiagnosticHandler http.Handler - router *mux.Router - addrs []string - diagAddrs []string - h2cEnabled bool - authentication AuthenticationScheme - authorization AuthorizationScheme - cert *tls.Certificate - tlsConfigMtx sync.RWMutex - certFile string - certFileHash []byte - certKeyFile string - certKeyFileHash []byte - certRefresh time.Duration - certPool *x509.CertPool - certPoolFile string - certPoolFileHash []byte - minTLSVersion uint16 - mtx sync.RWMutex - partials map[string]rego.PartialResult - preparedEvalQueries *cache - store storage.Store - manager *plugins.Manager - decisionIDFactory func() string - logger func(context.Context, *Info) error - errLimit int - pprofEnabled bool - runtime *ast.Term - httpListeners []httpListener - metrics Metrics - defaultDecisionPath string - interQueryBuiltinCache iCache.InterQueryCache - allPluginsOkOnce bool - distributedTracingOpts tracing.Options - ndbCacheEnabled bool - unixSocketPerm *string - cipherSuites *[]uint16 + router *mux.Router + addrs []string + diagAddrs []string + h2cEnabled bool + authentication AuthenticationScheme + authorization AuthorizationScheme + cert *tls.Certificate + tlsConfigMtx sync.RWMutex + certFile string + certFileHash []byte + certKeyFile string + certKeyFileHash []byte + certRefresh time.Duration + certPool *x509.CertPool + certPoolFile string + certPoolFileHash []byte + minTLSVersion uint16 + mtx sync.RWMutex + partials map[string]rego.PartialResult + preparedEvalQueries *cache + store storage.Store + manager *plugins.Manager + decisionIDFactory func() string + logger func(context.Context, *Info) error + errLimit int + pprofEnabled bool + runtime *ast.Term + httpListeners []httpListener + metrics Metrics + defaultDecisionPath string + interQueryBuiltinCache iCache.InterQueryCache + interQueryBuiltinValueCache iCache.InterQueryValueCache + allPluginsOkOnce bool + distributedTracingOpts tracing.Options + ndbCacheEnabled bool + unixSocketPerm *string + cipherSuites *[]uint16 } // Metrics defines the interface that the server requires for recording HTTP @@ -748,7 +749,8 @@ func (s *Server) initHandlerAuthz(handler http.Handler) http.Handler { authorizer.Decision(s.manager.Config.DefaultAuthorizationDecisionRef), authorizer.PrintHook(s.manager.PrintHook()), authorizer.EnablePrintStatements(s.manager.EnablePrintStatements()), - authorizer.InterQueryCache(s.interQueryBuiltinCache)) + authorizer.InterQueryCache(s.interQueryBuiltinCache), + authorizer.InterQueryValueCache(s.interQueryBuiltinValueCache)) if s.metrics != nil { handler = s.instrumentHandler(handler.ServeHTTP, PromHandlerAPIAuthz) @@ -800,7 +802,12 @@ func (s *Server) initRouters(ctx context.Context) { diagRouter := mux.NewRouter() // authorizer, if configured, needs the iCache to be set up already - s.interQueryBuiltinCache = iCache.NewInterQueryCacheWithContext(ctx, s.manager.InterQueryBuiltinCacheConfig()) + + cacheConfig := s.manager.InterQueryBuiltinCacheConfig() + + s.interQueryBuiltinCache = iCache.NewInterQueryCacheWithContext(ctx, cacheConfig) + s.interQueryBuiltinValueCache = iCache.NewInterQueryValueCache(ctx, cacheConfig) + s.manager.RegisterCacheTrigger(s.updateCacheConfig) // Add authorization handler. This must come BEFORE authentication handler @@ -933,6 +940,7 @@ func (s *Server) execQuery(ctx context.Context, br bundleRevisions, txn storage. rego.Runtime(s.runtime), rego.UnsafeBuiltins(unsafeBuiltinsMap), rego.InterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.InterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), rego.PrintHook(s.manager.PrintHook()), rego.EnablePrintStatements(s.manager.EnablePrintStatements()), rego.DistributedTracingOpts(s.distributedTracingOpts), @@ -1121,6 +1129,7 @@ func (s *Server) v0QueryPath(w http.ResponseWriter, r *http.Request, urlPath str rego.EvalParsedInput(input), rego.EvalMetrics(m), rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), rego.EvalNDBuiltinCache(ndbCache), } @@ -1402,6 +1411,7 @@ func (s *Server) v1CompilePost(w http.ResponseWriter, r *http.Request) { rego.Runtime(s.runtime), rego.UnsafeBuiltins(unsafeBuiltinsMap), rego.InterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.InterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), rego.PrintHook(s.manager.PrintHook()), ) @@ -1541,6 +1551,7 @@ func (s *Server) v1DataGet(w http.ResponseWriter, r *http.Request) { rego.EvalMetrics(m), rego.EvalQueryTracer(buf), rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), rego.EvalInstrument(includeInstrumentation), rego.EvalNDBuiltinCache(ndbCache), } @@ -1760,6 +1771,7 @@ func (s *Server) v1DataPost(w http.ResponseWriter, r *http.Request) { rego.EvalMetrics(m), rego.EvalQueryTracer(buf), rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), rego.EvalInstrument(includeInstrumentation), rego.EvalNDBuiltinCache(ndbCache), } @@ -2655,6 +2667,7 @@ func isPathOwned(path, root []string) bool { func (s *Server) updateCacheConfig(cacheConfig *iCache.Config) { s.interQueryBuiltinCache.UpdateConfig(cacheConfig) + s.interQueryBuiltinValueCache.UpdateConfig(cacheConfig) } func (s *Server) updateNDCache(enabled bool) { diff --git a/topdown/builtins.go b/topdown/builtins.go index 30c488050f..cf694d4331 100644 --- a/topdown/builtins.go +++ b/topdown/builtins.go @@ -35,25 +35,26 @@ type ( // BuiltinContext contains context from the evaluator that may be used by // built-in functions. BuiltinContext struct { - Context context.Context // request context that was passed when query started - Metrics metrics.Metrics // metrics registry for recording built-in specific metrics - Seed io.Reader // randomization source - Time *ast.Term // wall clock time - Cancel Cancel // atomic value that signals evaluation to halt - Runtime *ast.Term // runtime information on the OPA instance - Cache builtins.Cache // built-in function state cache - InterQueryBuiltinCache cache.InterQueryCache // cross-query built-in function state cache - NDBuiltinCache builtins.NDBCache // cache for non-deterministic built-in state - Location *ast.Location // location of built-in call - Tracers []Tracer // Deprecated: Use QueryTracers instead - QueryTracers []QueryTracer // tracer objects for trace() built-in function - TraceEnabled bool // indicates whether tracing is enabled for the evaluation - QueryID uint64 // identifies query being evaluated - ParentID uint64 // identifies parent of query being evaluated - PrintHook print.Hook // provides callback function to use for printing - DistributedTracingOpts tracing.Options // options to be used by distributed tracing. - rand *rand.Rand // randomization source for non-security-sensitive operations - Capabilities *ast.Capabilities + Context context.Context // request context that was passed when query started + Metrics metrics.Metrics // metrics registry for recording built-in specific metrics + Seed io.Reader // randomization source + Time *ast.Term // wall clock time + Cancel Cancel // atomic value that signals evaluation to halt + Runtime *ast.Term // runtime information on the OPA instance + Cache builtins.Cache // built-in function state cache + InterQueryBuiltinCache cache.InterQueryCache // cross-query built-in function state cache + InterQueryBuiltinValueCache cache.InterQueryValueCache // cross-query built-in function state value cache. this cache is useful for scenarios where the entry size cannot be calculated + NDBuiltinCache builtins.NDBCache // cache for non-deterministic built-in state + Location *ast.Location // location of built-in call + Tracers []Tracer // Deprecated: Use QueryTracers instead + QueryTracers []QueryTracer // tracer objects for trace() built-in function + TraceEnabled bool // indicates whether tracing is enabled for the evaluation + QueryID uint64 // identifies query being evaluated + ParentID uint64 // identifies parent of query being evaluated + PrintHook print.Hook // provides callback function to use for printing + DistributedTracingOpts tracing.Options // options to be used by distributed tracing. + rand *rand.Rand // randomization source for non-security-sensitive operations + Capabilities *ast.Capabilities } // BuiltinFunc defines an interface for implementing built-in functions. diff --git a/topdown/cache/cache.go b/topdown/cache/cache.go index c83c9828bf..55ed340619 100644 --- a/topdown/cache/cache.go +++ b/topdown/cache/cache.go @@ -18,14 +18,22 @@ import ( ) const ( + defaultInterQueryBuiltinValueCacheSize = int(0) // unlimited defaultMaxSizeBytes = int64(0) // unlimited defaultForcedEvictionThresholdPercentage = int64(100) // trigger at max_size_bytes defaultStaleEntryEvictionPeriodSeconds = int64(0) // never ) -// Config represents the configuration of the inter-query cache. +// Config represents the configuration for the inter-query builtin cache. type Config struct { - InterQueryBuiltinCache InterQueryBuiltinCacheConfig `json:"inter_query_builtin_cache"` + InterQueryBuiltinCache InterQueryBuiltinCacheConfig `json:"inter_query_builtin_cache"` + InterQueryBuiltinValueCache InterQueryBuiltinValueCacheConfig `json:"inter_query_builtin_value_cache"` +} + +// InterQueryBuiltinValueCacheConfig represents the configuration of the inter-query value cache that built-in functions can utilize. +// MaxNumEntries - max number of cache entries +type InterQueryBuiltinValueCacheConfig struct { + MaxNumEntries *int `json:"max_num_entries,omitempty"` } // InterQueryBuiltinCacheConfig represents the configuration of the inter-query cache that built-in functions can utilize. @@ -47,7 +55,12 @@ func ParseCachingConfig(raw []byte) (*Config, error) { *threshold = defaultForcedEvictionThresholdPercentage period := new(int64) *period = defaultStaleEntryEvictionPeriodSeconds - return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, ForcedEvictionThresholdPercentage: threshold, StaleEntryEvictionPeriodSeconds: period}}, nil + + maxInterQueryBuiltinValueCacheSize := new(int) + *maxInterQueryBuiltinValueCacheSize = defaultInterQueryBuiltinValueCacheSize + + return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, ForcedEvictionThresholdPercentage: threshold, StaleEntryEvictionPeriodSeconds: period}, + InterQueryBuiltinValueCache: InterQueryBuiltinValueCacheConfig{MaxNumEntries: maxInterQueryBuiltinValueCacheSize}}, nil } var config Config @@ -89,6 +102,18 @@ func (c *Config) validateAndInjectDefaults() error { return fmt.Errorf("invalid stale_entry_eviction_period_seconds %v", period) } } + + if c.InterQueryBuiltinValueCache.MaxNumEntries == nil { + maxSize := new(int) + *maxSize = defaultInterQueryBuiltinValueCacheSize + c.InterQueryBuiltinValueCache.MaxNumEntries = maxSize + } else { + numEntries := *c.InterQueryBuiltinValueCache.MaxNumEntries + if numEntries < 0 { + return fmt.Errorf("invalid max_num_entries %v", numEntries) + } + } + return nil } @@ -301,3 +326,81 @@ func (c *cache) cleanStaleValues() (dropped int) { } return dropped } + +type InterQueryValueCache interface { + Get(key ast.Value) (value any, found bool) + Insert(key ast.Value, value any) int + Delete(key ast.Value) + UpdateConfig(config *Config) +} + +type interQueryValueCache struct { + items map[string]any + config *Config + mtx sync.RWMutex +} + +// Get returns the value in the cache for k. +func (c *interQueryValueCache) Get(k ast.Value) (any, bool) { + c.mtx.RLock() + defer c.mtx.RUnlock() + value, ok := c.items[k.String()] + return value, ok +} + +// Insert inserts a key k into the cache with value v. +func (c *interQueryValueCache) Insert(k ast.Value, v any) (dropped int) { + c.mtx.Lock() + defer c.mtx.Unlock() + + maxEntries := c.maxNumEntries() + if maxEntries > 0 { + if len(c.items) >= maxEntries { + itemsToRemove := len(c.items) - maxEntries + 1 + + // Delete a (semi-)random key to make room for the new one. + for k := range c.items { + delete(c.items, k) + dropped++ + + if itemsToRemove == dropped { + break + } + } + } + } + + c.items[k.String()] = v + return dropped +} + +// Delete deletes the value in the cache for k. +func (c *interQueryValueCache) Delete(k ast.Value) { + c.mtx.Lock() + defer c.mtx.Unlock() + delete(c.items, k.String()) +} + +// UpdateConfig updates the cache config. +func (c *interQueryValueCache) UpdateConfig(config *Config) { + if config == nil { + return + } + c.mtx.Lock() + defer c.mtx.Unlock() + c.config = config +} + +func (c *interQueryValueCache) maxNumEntries() int { + if c.config == nil { + return defaultInterQueryBuiltinValueCacheSize + } + return *c.config.InterQueryBuiltinValueCache.MaxNumEntries +} + +func NewInterQueryValueCache(_ context.Context, config *Config) InterQueryValueCache { + return &interQueryValueCache{ + items: map[string]any{}, + config: config, + } +} diff --git a/topdown/cache/cache_test.go b/topdown/cache/cache_test.go index 85ccb20913..b5c1a3865a 100644 --- a/topdown/cache/cache_test.go +++ b/topdown/cache/cache_test.go @@ -21,7 +21,11 @@ func TestParseCachingConfig(t *testing.T) { *period = defaultStaleEntryEvictionPeriodSeconds threshold := new(int64) *threshold = defaultForcedEvictionThresholdPercentage - expected := &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, StaleEntryEvictionPeriodSeconds: period, ForcedEvictionThresholdPercentage: threshold}} + maxNumEntriesInterQueryValueCache := new(int) + *maxNumEntriesInterQueryValueCache = defaultInterQueryBuiltinValueCacheSize + + expected := &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, StaleEntryEvictionPeriodSeconds: period, ForcedEvictionThresholdPercentage: threshold}, + InterQueryBuiltinValueCache: InterQueryBuiltinValueCacheConfig{MaxNumEntries: maxNumEntriesInterQueryValueCache}} tests := map[string]struct { input []byte @@ -35,10 +39,18 @@ func TestParseCachingConfig(t *testing.T) { input: []byte(`{"inter_query_builtin_cache": {},}`), wantErr: false, }, + "default_num_entries": { + input: []byte(`{"inter_query_builtin_value_cache": {},}`), + wantErr: false, + }, "bad_limit": { input: []byte(`{"inter_query_builtin_cache": {"max_size_bytes": "100"},}`), wantErr: true, }, + "bad_num_entries": { + input: []byte(`{"inter_query_builtin_value_cache": {"max_num_entries": "100"},}`), + wantErr: true, + }, } for name, tc := range tests { @@ -165,6 +177,97 @@ func TestInsert(t *testing.T) { } } +func TestInterQueryValueCache(t *testing.T) { + + in := `{"inter_query_builtin_value_cache": {"max_num_entries": 4},}` + + config, err := ParseCachingConfig([]byte(in)) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + cache := NewInterQueryValueCache(context.Background(), config) + + cache.Insert(ast.StringTerm("foo").Value, "bar") + cache.Insert(ast.StringTerm("foo2").Value, "bar2") + cache.Insert(ast.StringTerm("hello").Value, "world") + dropped := cache.Insert(ast.StringTerm("hello2").Value, "world2") + + if dropped != 0 { + t.Fatal("Expected dropped to be zero") + } + + value, found := cache.Get(ast.StringTerm("foo").Value) + if !found { + t.Fatal("Expected key \"foo\" in cache") + } + + actual, ok := value.(string) + if !ok { + t.Fatal("Expected string value") + } + + if actual != "bar" { + t.Fatalf("Expected value \"bar\" but got %v", actual) + } + + dropped = cache.Insert(ast.StringTerm("foo3").Value, "bar3") + if dropped != 1 { + t.Fatal("Expected dropped to be one") + } + + _, found = cache.Get(ast.StringTerm("foo3").Value) + if !found { + t.Fatal("Expected key \"foo3\" in cache") + } + + // update the cache config + in = `{"inter_query_builtin_value_cache": {"max_num_entries": 0},}` // unlimited + config, err = ParseCachingConfig([]byte(in)) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + cache.UpdateConfig(config) + + cache.Insert(ast.StringTerm("a").Value, "b") + cache.Insert(ast.StringTerm("c").Value, "d") + cache.Insert(ast.StringTerm("e").Value, "f") + dropped = cache.Insert(ast.StringTerm("g").Value, "h") + + if dropped != 0 { + t.Fatal("Expected dropped to be zero") + } + + // at this point the cache should have 8 entries + // update the cache size and verify multiple items dropped + in = `{"inter_query_builtin_value_cache": {"max_num_entries": 6},}` + config, err = ParseCachingConfig([]byte(in)) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + cache.UpdateConfig(config) + + dropped = cache.Insert(ast.StringTerm("i").Value, "j") + + if dropped != 3 { + t.Fatal("Expected dropped to be three") + } + + _, found = cache.Get(ast.StringTerm("i").Value) + if !found { + t.Fatal("Expected key \"i\" in cache") + } + + cache.Delete(ast.StringTerm("i").Value) + + _, found = cache.Get(ast.StringTerm("i").Value) + if found { + t.Fatal("Unexpected key \"i\" in cache") + } +} + func TestConcurrentInsert(t *testing.T) { in := `{"inter_query_builtin_cache": {"max_size_bytes": 20},}` // 20 byte limit for test purposes diff --git a/topdown/eval.go b/topdown/eval.go index 2fcc431c80..7884ac01e0 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -58,55 +58,56 @@ func (ee deferredEarlyExitError) Error() string { } type eval struct { - ctx context.Context - metrics metrics.Metrics - seed io.Reader - time *ast.Term - queryID uint64 - queryIDFact *queryIDFactory - parent *eval - caller *eval - cancel Cancel - query ast.Body - queryCompiler ast.QueryCompiler - index int - indexing bool - earlyExit bool - bindings *bindings - store storage.Store - baseCache *baseCache - txn storage.Transaction - compiler *ast.Compiler - input *ast.Term - data *ast.Term - external *resolverTrie - targetStack *refStack - tracers []QueryTracer - traceEnabled bool - traceLastLocation *ast.Location // Last location of a trace event. - plugTraceVars bool - instr *Instrumentation - builtins map[string]*Builtin - builtinCache builtins.Cache - ndBuiltinCache builtins.NDBCache - functionMocks *functionMocksStack - virtualCache VirtualCache - comprehensionCache *comprehensionCache - interQueryBuiltinCache cache.InterQueryCache - saveSet *saveSet - saveStack *saveStack - saveSupport *saveSupport - saveNamespace *ast.Term - skipSaveNamespace bool - inliningControl *inliningControl - genvarprefix string - genvarid int - runtime *ast.Term - builtinErrors *builtinErrors - printHook print.Hook - tracingOpts tracing.Options - findOne bool - strictObjects bool + ctx context.Context + metrics metrics.Metrics + seed io.Reader + time *ast.Term + queryID uint64 + queryIDFact *queryIDFactory + parent *eval + caller *eval + cancel Cancel + query ast.Body + queryCompiler ast.QueryCompiler + index int + indexing bool + earlyExit bool + bindings *bindings + store storage.Store + baseCache *baseCache + txn storage.Transaction + compiler *ast.Compiler + input *ast.Term + data *ast.Term + external *resolverTrie + targetStack *refStack + tracers []QueryTracer + traceEnabled bool + traceLastLocation *ast.Location // Last location of a trace event. + plugTraceVars bool + instr *Instrumentation + builtins map[string]*Builtin + builtinCache builtins.Cache + ndBuiltinCache builtins.NDBCache + functionMocks *functionMocksStack + virtualCache VirtualCache + comprehensionCache *comprehensionCache + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + saveSet *saveSet + saveStack *saveStack + saveSupport *saveSupport + saveNamespace *ast.Term + skipSaveNamespace bool + inliningControl *inliningControl + genvarprefix string + genvarid int + runtime *ast.Term + builtinErrors *builtinErrors + printHook print.Hook + tracingOpts tracing.Options + findOne bool + strictObjects bool } func (e *eval) Run(iter evalIterator) error { @@ -817,23 +818,24 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error { } bctx := BuiltinContext{ - Context: e.ctx, - Metrics: e.metrics, - Seed: e.seed, - Time: e.time, - Cancel: e.cancel, - Runtime: e.runtime, - Cache: e.builtinCache, - InterQueryBuiltinCache: e.interQueryBuiltinCache, - NDBuiltinCache: e.ndBuiltinCache, - Location: e.query[e.index].Location, - QueryTracers: e.tracers, - TraceEnabled: e.traceEnabled, - QueryID: e.queryID, - ParentID: parentID, - PrintHook: e.printHook, - DistributedTracingOpts: e.tracingOpts, - Capabilities: capabilities, + Context: e.ctx, + Metrics: e.metrics, + Seed: e.seed, + Time: e.time, + Cancel: e.cancel, + Runtime: e.runtime, + Cache: e.builtinCache, + InterQueryBuiltinCache: e.interQueryBuiltinCache, + InterQueryBuiltinValueCache: e.interQueryBuiltinValueCache, + NDBuiltinCache: e.ndBuiltinCache, + Location: e.query[e.index].Location, + QueryTracers: e.tracers, + TraceEnabled: e.traceEnabled, + QueryID: e.queryID, + ParentID: parentID, + PrintHook: e.printHook, + DistributedTracingOpts: e.tracingOpts, + Capabilities: capabilities, } eval := evalBuiltin{ diff --git a/topdown/glob.go b/topdown/glob.go index 116602db74..3f769d8369 100644 --- a/topdown/glob.go +++ b/topdown/glob.go @@ -15,7 +15,9 @@ const globCacheMaxSize = 100 var globCacheLock = sync.Mutex{} var globCache map[string]glob.Glob -func builtinGlobMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +var globInterQueryValueCacheHits = "rego_builtin_glob_interquery_value_cache_hits" + +func builtinGlobMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { pattern, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -50,14 +52,46 @@ func builtinGlobMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter } id := builder.String() - m, err := globCompileAndMatch(id, string(pattern), string(match), delimiters) + m, err := globCompileAndMatch(bctx, id, string(pattern), string(match), delimiters) if err != nil { return err } return iter(ast.BooleanTerm(m)) } -func globCompileAndMatch(id, pattern, match string, delimiters []rune) (bool, error) { +func globCompileAndMatch(bctx BuiltinContext, id, pattern, match string, delimiters []rune) (bool, error) { + + if bctx.InterQueryBuiltinValueCache != nil { + val, ok := bctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(id).Value) + if ok { + pat, valid := val.(glob.Glob) + if !valid { + // The cache key may exist for a different value type (eg. regex). + // In this case, we calculate the glob and return the result w/o updating the cache. + var res glob.Glob + var err error + if res, err = glob.Compile(pattern, delimiters...); err != nil { + return false, err + } + out := res.Match(match) + return out, nil + } + bctx.Metrics.Counter(globInterQueryValueCacheHits).Incr() + out := pat.Match(match) + return out, nil + } + + var res glob.Glob + var err error + if res, err = glob.Compile(pattern, delimiters...); err != nil { + return false, err + } + bctx.InterQueryBuiltinValueCache.Insert(ast.StringTerm(id).Value, res) + + out := res.Match(match) + return out, nil + } + globCacheLock.Lock() defer globCacheLock.Unlock() p, ok := globCache[id] diff --git a/topdown/glob_test.go b/topdown/glob_test.go index f72e5d8b2c..2ec7f3732e 100644 --- a/topdown/glob_test.go +++ b/topdown/glob_test.go @@ -5,10 +5,12 @@ package topdown import ( + "context" "fmt" "testing" "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/topdown/cache" ) func TestGlobBuiltinCache(t *testing.T) { @@ -69,3 +71,108 @@ func TestGlobBuiltinCache(t *testing.T) { t.Fatalf("Expected glob to be cached: %v", glob2) } } + +func TestGlobBuiltinInterQueryValueCache(t *testing.T) { + ip := []byte(`{"inter_query_builtin_value_cache": {"max_num_entries": "10"},}`) + config, _ := cache.ParseCachingConfig(ip) + interQueryValueCache := cache.NewInterQueryValueCache(context.Background(), config) + + ctx := BuiltinContext{InterQueryBuiltinValueCache: interQueryValueCache} + iter := func(*ast.Term) error { return nil } + + // A novel glob pattern is cached. + glob1 := "foo/*" + operands := []*ast.Term{ + ast.NewTerm(ast.String(glob1)), + ast.NullTerm(), + ast.NewTerm(ast.String("foo/bar")), + } + err := builtinGlobMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // the glob id will have a trailing '-' rune. + if _, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(fmt.Sprintf("%s-", glob1)).Value); !ok { + t.Fatalf("Expected glob to be cached: %v", glob1) + } + + // Fill up the cache. + for i := 0; i < 9; i++ { + operands := []*ast.Term{ + ast.NewTerm(ast.String(fmt.Sprintf("foo/%d/*", i))), + ast.NullTerm(), + ast.NewTerm(ast.String(fmt.Sprintf("foo/%d/bar", i))), + } + err := builtinGlobMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + } + + // A new glob pattern is cached and a random pattern is evicted. + glob2 := "bar/*" + operands = []*ast.Term{ + ast.NewTerm(ast.String(glob2)), + ast.NullTerm(), + ast.NewTerm(ast.String("bar/baz")), + } + err = builtinGlobMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if _, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(fmt.Sprintf("%s-", glob2)).Value); !ok { + t.Fatalf("Expected glob to be cached: %v", glob2) + } +} + +func TestGlobBuiltinInterQueryValueCacheTypeMismatch(t *testing.T) { + + ip := []byte(`{"inter_query_builtin_value_cache": {"max_num_entries": "10"},}`) + config, _ := cache.ParseCachingConfig(ip) + interQueryValueCache := cache.NewInterQueryValueCache(context.Background(), config) + + ctx := BuiltinContext{InterQueryBuiltinValueCache: interQueryValueCache} + iter := func(*ast.Term) error { return nil } + + key := "foo.*" + + operands := []*ast.Term{ + ast.NewTerm(ast.String(key)), + ast.NullTerm(), + ast.NewTerm(ast.String("foo/bar")), + } + err := builtinGlobMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // the glob id will have a trailing '-' rune. + if _, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(fmt.Sprintf("%s-", key)).Value); !ok { + t.Fatalf("Expected glob to be cached: %v", key) + } + + // update the cache entry + ctx.InterQueryBuiltinValueCache.Insert(ast.StringTerm(fmt.Sprintf("%s-", key)).Value, "bar") + + err = builtinGlobMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // verify the cache entry is unchanged + value, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(fmt.Sprintf("%s-", key)).Value) + if !ok { + t.Fatal("Expected key \"foo.*-\" in cache") + } + + actual, ok := value.(string) + if !ok { + t.Fatal("Expected string value") + } + + if actual != "bar" { + t.Fatalf("Expected value \"bar\" but got %v", actual) + } +} diff --git a/topdown/query.go b/topdown/query.go index bbb4ba58f3..8406cfdd87 100644 --- a/topdown/query.go +++ b/topdown/query.go @@ -27,38 +27,39 @@ type QueryResult map[ast.Var]*ast.Term // Query provides a configurable interface for performing query evaluation. type Query struct { - seed io.Reader - time time.Time - cancel Cancel - query ast.Body - queryCompiler ast.QueryCompiler - compiler *ast.Compiler - store storage.Store - txn storage.Transaction - input *ast.Term - external *resolverTrie - tracers []QueryTracer - plugTraceVars bool - unknowns []*ast.Term - partialNamespace string - skipSaveNamespace bool - metrics metrics.Metrics - instr *Instrumentation - disableInlining []ast.Ref - shallowInlining bool - genvarprefix string - runtime *ast.Term - builtins map[string]*Builtin - indexing bool - earlyExit bool - interQueryBuiltinCache cache.InterQueryCache - ndBuiltinCache builtins.NDBCache - strictBuiltinErrors bool - builtinErrorList *[]Error - strictObjects bool - printHook print.Hook - tracingOpts tracing.Options - virtualCache VirtualCache + seed io.Reader + time time.Time + cancel Cancel + query ast.Body + queryCompiler ast.QueryCompiler + compiler *ast.Compiler + store storage.Store + txn storage.Transaction + input *ast.Term + external *resolverTrie + tracers []QueryTracer + plugTraceVars bool + unknowns []*ast.Term + partialNamespace string + skipSaveNamespace bool + metrics metrics.Metrics + instr *Instrumentation + disableInlining []ast.Ref + shallowInlining bool + genvarprefix string + runtime *ast.Term + builtins map[string]*Builtin + indexing bool + earlyExit bool + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + ndBuiltinCache builtins.NDBCache + strictBuiltinErrors bool + builtinErrorList *[]Error + strictObjects bool + printHook print.Hook + tracingOpts tracing.Options + virtualCache VirtualCache } // Builtin represents a built-in function that queries can call. @@ -246,6 +247,12 @@ func (q *Query) WithInterQueryBuiltinCache(c cache.InterQueryCache) *Query { return q } +// WithInterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize. +func (q *Query) WithInterQueryBuiltinValueCache(c cache.InterQueryValueCache) *Query { + q.interQueryBuiltinValueCache = c + return q +} + // WithNDBuiltinCache sets the non-deterministic builtin cache. func (q *Query) WithNDBuiltinCache(c builtins.NDBCache) *Query { q.ndBuiltinCache = c @@ -331,39 +338,40 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support [] } e := &eval{ - ctx: ctx, - metrics: q.metrics, - seed: q.seed, - time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), - cancel: q.cancel, - query: q.query, - queryCompiler: q.queryCompiler, - queryIDFact: f, - queryID: f.Next(), - bindings: b, - compiler: q.compiler, - store: q.store, - baseCache: newBaseCache(), - targetStack: newRefStack(), - txn: q.txn, - input: q.input, - external: q.external, - tracers: q.tracers, - traceEnabled: len(q.tracers) > 0, - plugTraceVars: q.plugTraceVars, - instr: q.instr, - builtins: q.builtins, - builtinCache: builtins.Cache{}, - functionMocks: newFunctionMocksStack(), - interQueryBuiltinCache: q.interQueryBuiltinCache, - ndBuiltinCache: q.ndBuiltinCache, - virtualCache: vc, - comprehensionCache: newComprehensionCache(), - saveSet: newSaveSet(q.unknowns, b, q.instr), - saveStack: newSaveStack(), - saveSupport: newSaveSupport(), - saveNamespace: ast.StringTerm(q.partialNamespace), - skipSaveNamespace: q.skipSaveNamespace, + ctx: ctx, + metrics: q.metrics, + seed: q.seed, + time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), + cancel: q.cancel, + query: q.query, + queryCompiler: q.queryCompiler, + queryIDFact: f, + queryID: f.Next(), + bindings: b, + compiler: q.compiler, + store: q.store, + baseCache: newBaseCache(), + targetStack: newRefStack(), + txn: q.txn, + input: q.input, + external: q.external, + tracers: q.tracers, + traceEnabled: len(q.tracers) > 0, + plugTraceVars: q.plugTraceVars, + instr: q.instr, + builtins: q.builtins, + builtinCache: builtins.Cache{}, + functionMocks: newFunctionMocksStack(), + interQueryBuiltinCache: q.interQueryBuiltinCache, + interQueryBuiltinValueCache: q.interQueryBuiltinValueCache, + ndBuiltinCache: q.ndBuiltinCache, + virtualCache: vc, + comprehensionCache: newComprehensionCache(), + saveSet: newSaveSet(q.unknowns, b, q.instr), + saveStack: newSaveStack(), + saveSupport: newSaveSupport(), + saveNamespace: ast.StringTerm(q.partialNamespace), + skipSaveNamespace: q.skipSaveNamespace, inliningControl: &inliningControl{ shallow: q.shallowInlining, }, @@ -516,42 +524,43 @@ func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error { } e := &eval{ - ctx: ctx, - metrics: q.metrics, - seed: q.seed, - time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), - cancel: q.cancel, - query: q.query, - queryCompiler: q.queryCompiler, - queryIDFact: f, - queryID: f.Next(), - bindings: newBindings(0, q.instr), - compiler: q.compiler, - store: q.store, - baseCache: newBaseCache(), - targetStack: newRefStack(), - txn: q.txn, - input: q.input, - external: q.external, - tracers: q.tracers, - traceEnabled: len(q.tracers) > 0, - plugTraceVars: q.plugTraceVars, - instr: q.instr, - builtins: q.builtins, - builtinCache: builtins.Cache{}, - functionMocks: newFunctionMocksStack(), - interQueryBuiltinCache: q.interQueryBuiltinCache, - ndBuiltinCache: q.ndBuiltinCache, - virtualCache: vc, - comprehensionCache: newComprehensionCache(), - genvarprefix: q.genvarprefix, - runtime: q.runtime, - indexing: q.indexing, - earlyExit: q.earlyExit, - builtinErrors: &builtinErrors{}, - printHook: q.printHook, - tracingOpts: q.tracingOpts, - strictObjects: q.strictObjects, + ctx: ctx, + metrics: q.metrics, + seed: q.seed, + time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), + cancel: q.cancel, + query: q.query, + queryCompiler: q.queryCompiler, + queryIDFact: f, + queryID: f.Next(), + bindings: newBindings(0, q.instr), + compiler: q.compiler, + store: q.store, + baseCache: newBaseCache(), + targetStack: newRefStack(), + txn: q.txn, + input: q.input, + external: q.external, + tracers: q.tracers, + traceEnabled: len(q.tracers) > 0, + plugTraceVars: q.plugTraceVars, + instr: q.instr, + builtins: q.builtins, + builtinCache: builtins.Cache{}, + functionMocks: newFunctionMocksStack(), + interQueryBuiltinCache: q.interQueryBuiltinCache, + interQueryBuiltinValueCache: q.interQueryBuiltinValueCache, + ndBuiltinCache: q.ndBuiltinCache, + virtualCache: vc, + comprehensionCache: newComprehensionCache(), + genvarprefix: q.genvarprefix, + runtime: q.runtime, + indexing: q.indexing, + earlyExit: q.earlyExit, + builtinErrors: &builtinErrors{}, + printHook: q.printHook, + tracingOpts: q.tracingOpts, + strictObjects: q.strictObjects, } e.caller = e q.metrics.Timer(metrics.RegoQueryEval).Start() diff --git a/topdown/regex.go b/topdown/regex.go index 877f19e233..7ddffe8d57 100644 --- a/topdown/regex.go +++ b/topdown/regex.go @@ -20,6 +20,8 @@ const regexCacheMaxSize = 100 var regexpCacheLock = sync.Mutex{} var regexpCache map[string]*regexp.Regexp +var regexInterQueryValueCacheHits = "rego_builtin_regex_interquery_value_cache_hits" + func builtinRegexIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { s, err := builtins.StringOperand(operands[0].Value, 1) @@ -35,7 +37,7 @@ func builtinRegexIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast. return iter(ast.BooleanTerm(true)) } -func builtinRegexMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +func builtinRegexMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { s1, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -44,7 +46,7 @@ func builtinRegexMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te if err != nil { return err } - re, err := getRegexp(string(s1)) + re, err := getRegexp(bctx, string(s1)) if err != nil { return err } @@ -81,7 +83,7 @@ func builtinRegexMatchTemplate(_ BuiltinContext, operands []*ast.Term, iter func return iter(ast.BooleanTerm(re.MatchString(string(match)))) } -func builtinRegexSplit(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +func builtinRegexSplit(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { s1, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -90,7 +92,7 @@ func builtinRegexSplit(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te if err != nil { return err } - re, err := getRegexp(string(s1)) + re, err := getRegexp(bctx, string(s1)) if err != nil { return err } @@ -103,7 +105,33 @@ func builtinRegexSplit(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te return iter(ast.NewTerm(ast.NewArray(arr...))) } -func getRegexp(pat string) (*regexp.Regexp, error) { +func getRegexp(bctx BuiltinContext, pat string) (*regexp.Regexp, error) { + if bctx.InterQueryBuiltinValueCache != nil { + val, ok := bctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(pat).Value) + if ok { + res, valid := val.(*regexp.Regexp) + if !valid { + // The cache key may exist for a different value type (eg. glob). + // In this case, we calculate the regex and return the result w/o updating the cache. + re, err := regexp.Compile(pat) + if err != nil { + return nil, err + } + return re, nil + } + + bctx.Metrics.Counter(regexInterQueryValueCacheHits).Incr() + return res, nil + } + + re, err := regexp.Compile(pat) + if err != nil { + return nil, err + } + bctx.InterQueryBuiltinValueCache.Insert(ast.StringTerm(pat).Value, re) + return re, nil + } + regexpCacheLock.Lock() defer regexpCacheLock.Unlock() re, ok := regexpCache[pat] @@ -156,7 +184,7 @@ func builtinGlobsMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te return iter(ast.BooleanTerm(ne)) } -func builtinRegexFind(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +func builtinRegexFind(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { s1, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -169,7 +197,7 @@ func builtinRegexFind(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter if err != nil { return err } - re, err := getRegexp(string(s1)) + re, err := getRegexp(bctx, string(s1)) if err != nil { return err } @@ -182,7 +210,7 @@ func builtinRegexFind(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter return iter(ast.NewTerm(ast.NewArray(arr...))) } -func builtinRegexFindAllStringSubmatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +func builtinRegexFindAllStringSubmatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { s1, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -196,7 +224,7 @@ func builtinRegexFindAllStringSubmatch(_ BuiltinContext, operands []*ast.Term, i return err } - re, err := getRegexp(string(s1)) + re, err := getRegexp(bctx, string(s1)) if err != nil { return err } @@ -214,7 +242,7 @@ func builtinRegexFindAllStringSubmatch(_ BuiltinContext, operands []*ast.Term, i return iter(ast.NewTerm(ast.NewArray(outer...))) } -func builtinRegexReplace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +func builtinRegexReplace(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { base, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -230,7 +258,7 @@ func builtinRegexReplace(_ BuiltinContext, operands []*ast.Term, iter func(*ast. return err } - re, err := getRegexp(string(pattern)) + re, err := getRegexp(bctx, string(pattern)) if err != nil { return err } diff --git a/topdown/regex_test.go b/topdown/regex_test.go index 90957e9d1b..bfa429d5e7 100644 --- a/topdown/regex_test.go +++ b/topdown/regex_test.go @@ -5,10 +5,12 @@ package topdown import ( + "context" "fmt" "testing" "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/topdown/cache" ) func TestRegexBuiltinCache(t *testing.T) { @@ -65,3 +67,93 @@ func TestRegexBuiltinCache(t *testing.T) { t.Fatalf("Expected regex to be cached: %v", regex2) } } + +func TestRegexBuiltinInterQueryValueCache(t *testing.T) { + + ip := []byte(`{"inter_query_builtin_value_cache": {"max_num_entries": "10"},}`) + config, _ := cache.ParseCachingConfig(ip) + interQueryValueCache := cache.NewInterQueryValueCache(context.Background(), config) + + ctx := BuiltinContext{InterQueryBuiltinValueCache: interQueryValueCache} + iter := func(*ast.Term) error { return nil } + + // A novel regex pattern is cached. + regex1 := "foo.*" + operands := []*ast.Term{ + ast.NewTerm(ast.String(regex1)), + ast.NewTerm(ast.String("foobar")), + } + err := builtinRegexMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if _, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(regex1).Value); !ok { + t.Fatalf("Expected regex to be cached: %v", regex1) + } + + // Fill up the cache. + for i := 0; i < 9; i++ { + operands := []*ast.Term{ + ast.NewTerm(ast.String(fmt.Sprintf("foo%d.*", i))), + ast.NewTerm(ast.String(fmt.Sprintf("foo%dbar", i))), + } + err := builtinRegexMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + } + + // A new regex pattern is cached and a random pattern is evicted. + regex2 := "bar.*" + operands = []*ast.Term{ + ast.NewTerm(ast.String(regex2)), + ast.NewTerm(ast.String("barbaz")), + } + err = builtinRegexMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if _, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(regex2).Value); !ok { + t.Fatalf("Expected regex to be cached: %v", regex2) + } +} + +func TestRegexBuiltinInterQueryValueCacheTypeMismatch(t *testing.T) { + + ip := []byte(`{"inter_query_builtin_value_cache": {"max_num_entries": "10"},}`) + config, _ := cache.ParseCachingConfig(ip) + interQueryValueCache := cache.NewInterQueryValueCache(context.Background(), config) + + ctx := BuiltinContext{InterQueryBuiltinValueCache: interQueryValueCache} + iter := func(*ast.Term) error { return nil } + + key := "foo.*" + + ctx.InterQueryBuiltinValueCache.Insert(ast.StringTerm(key).Value, "bar") + + operands := []*ast.Term{ + ast.NewTerm(ast.String(key)), + ast.NewTerm(ast.String("foobar")), + } + err := builtinRegexMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // verify the original cache entry is unchanged + value, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(key).Value) + if !ok { + t.Fatal("Expected key \"foo.*\" in cache") + } + + actual, ok := value.(string) + if !ok { + t.Fatal("Expected string value") + } + + if actual != "bar" { + t.Fatalf("Expected value \"bar\" but got %v", actual) + } +} diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index 63c0555517..7dce7d3947 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -2119,6 +2119,7 @@ func assertTopDownWithPathAndContext(ctx context.Context, t *testing.T, compiler // add an inter-query cache config, _ := iCache.ParseCachingConfig(nil) interQueryCache := iCache.NewInterQueryCache(config) + interQueryValueCache := iCache.NewInterQueryValueCache(ctx, config) var strictBuiltinErrors bool @@ -2133,6 +2134,7 @@ func assertTopDownWithPathAndContext(ctx context.Context, t *testing.T, compiler WithTransaction(txn). WithInput(inputTerm). WithInterQueryBuiltinCache(interQueryCache). + WithInterQueryBuiltinValueCache(interQueryValueCache). WithStrictBuiltinErrors(strictBuiltinErrors) var tracer BufferTracer @@ -2212,13 +2214,15 @@ func runTopDownPartialTestCase(ctx context.Context, t *testing.T, compiler *ast. // add an inter-query cache config, _ := iCache.ParseCachingConfig(nil) interQueryCache := iCache.NewInterQueryCache(config) + interQueryValueCache := iCache.NewInterQueryValueCache(ctx, config) partialQuery := NewQuery(body). WithCompiler(compiler). WithStore(store). WithUnknowns([]*ast.Term{ast.MustParseTerm("input")}). WithTransaction(txn). - WithInterQueryBuiltinCache(interQueryCache) + WithInterQueryBuiltinCache(interQueryCache). + WithInterQueryBuiltinValueCache(interQueryValueCache) partials, support, err := partialQuery.PartialRun(ctx) @@ -2251,7 +2255,8 @@ func runTopDownPartialTestCase(ctx context.Context, t *testing.T, compiler *ast. WithStore(store). WithTransaction(txn). WithInput(input). - WithInterQueryBuiltinCache(interQueryCache) + WithInterQueryBuiltinCache(interQueryCache). + WithInterQueryBuiltinValueCache(interQueryValueCache) qrs, err := query.Run(ctx) if err != nil {