From 2c562936954feee312b21f7bebd21b9edb437599 Mon Sep 17 00:00:00 2001 From: Ashutosh Narkar Date: Wed, 4 Sep 2024 17:53:48 -0700 Subject: [PATCH] Add a new inter-query value cache to cache data across queries This commit adds a new inter-query value cache that built-in functions can use to cache information across queries. For example, the `regex` and `glob` builtins can use this to cache compiled regex and glob match patterns respectively. The number of entries in the cache can be configured via the OPA config. By default there is no limit. Fixes: #6908 Signed-off-by: Ashutosh Narkar --- docs/content/configuration.md | 14 +- internal/rego/opa/options.go | 19 +-- plugins/discovery/discovery_test.go | 6 +- plugins/plugins.go | 2 +- plugins/plugins_test.go | 10 +- rego/rego.go | 195 ++++++++++++++----------- rego/rego_test.go | 49 +++++++ sdk/opa.go | 75 +++++----- server/authorizer/authorizer.go | 9 ++ server/authorizer/authorizer_test.go | 38 +++++ server/server.go | 89 ++++++----- topdown/builtins.go | 39 ++--- topdown/cache/cache.go | 109 +++++++++++++- topdown/cache/cache_test.go | 105 ++++++++++++- topdown/eval.go | 134 ++++++++--------- topdown/glob.go | 40 ++++- topdown/glob_test.go | 107 ++++++++++++++ topdown/query.go | 211 ++++++++++++++------------- topdown/regex.go | 50 +++++-- topdown/regex_test.go | 92 ++++++++++++ topdown/topdown_test.go | 9 +- 21 files changed, 1021 insertions(+), 381 deletions(-) diff --git a/docs/content/configuration.md b/docs/content/configuration.md index e44ff58888..8faec0d708 100644 --- a/docs/content/configuration.md +++ b/docs/content/configuration.md @@ -856,11 +856,15 @@ Caching represents the configuration of the inter-query cache that built-in func functions provided by OPA, `http.send` is currently the only one to utilize the inter-query cache. See the documentation on the [http.send built-in function](../policy-reference/#http) for information about the available caching options. -| Field | Type | Required | Description | -| --- | --- | --- | --- | -| `caching.inter_query_builtin_cache.max_size_bytes` | `int64` | No | Inter-query cache size limit in bytes. OPA will drop old items from the cache if this limit is exceeded. By default, no limit is set. | -| `caching.inter_query_builtin_cache.forced_eviction_threshold_percentage` | `int64` | No | Threshold limit configured as percentage of `caching.inter_query_builtin_cache.max_size_bytes`, when exceeded OPA will start dropping old items permaturely. By default, set to `100`. | -| `caching.inter_query_builtin_cache.stale_entry_eviction_period_seconds` | `int64` | No | Stale entry eviction period in seconds. OPA will drop expired items from the cache every `stale_entry_eviction_period_seconds`. By default, set to `0` indicating stale entry eviction is disabled. | +It also represents the configuration of the inter-query value cache that built-in functions can utilize. Currently, this +cache is utilized by the `regex` and `glob` built-in functions for compiled regex and glob match patterns respectively. + +| Field | Type | Required | Description | +|--------------------------------------------------------------------------| --- | --- |-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `caching.inter_query_builtin_cache.max_size_bytes` | `int64` | No | Inter-query cache size limit in bytes. OPA will drop old items from the cache if this limit is exceeded. By default, no limit is set. | +| `caching.inter_query_builtin_cache.forced_eviction_threshold_percentage` | `int64` | No | Threshold limit configured as percentage of `caching.inter_query_builtin_cache.max_size_bytes`, when exceeded OPA will start dropping old items permaturely. By default, set to `100`. | +| `caching.inter_query_builtin_cache.stale_entry_eviction_period_seconds` | `int64` | No | Stale entry eviction period in seconds. OPA will drop expired items from the cache every `stale_entry_eviction_period_seconds`. By default, set to `0` indicating stale entry eviction is disabled. | +| `caching.inter_query_builtin_value_cache.max_num_entries` | `int` | No | Maximum number of entries in the Inter-query value cache. OPA will drop random items from the cache if this limit is exceeded. By default, set to `0` indicating unlimited size. | ## Distributed tracing diff --git a/internal/rego/opa/options.go b/internal/rego/opa/options.go index ea1e339c1b..b58a05ee8e 100644 --- a/internal/rego/opa/options.go +++ b/internal/rego/opa/options.go @@ -18,13 +18,14 @@ type Result struct { // EvalOpts define options for performing an evaluation. type EvalOpts struct { - Input *interface{} - Metrics metrics.Metrics - Entrypoint int32 - Time time.Time - Seed io.Reader - InterQueryBuiltinCache cache.InterQueryCache - NDBuiltinCache builtins.NDBCache - PrintHook print.Hook - Capabilities *ast.Capabilities + Input *interface{} + Metrics metrics.Metrics + Entrypoint int32 + Time time.Time + Seed io.Reader + InterQueryBuiltinCache cache.InterQueryCache + InterQueryBuiltinValueCache cache.InterQueryValueCache + NDBuiltinCache builtins.NDBCache + PrintHook print.Hook + Capabilities *ast.Capabilities } diff --git a/plugins/discovery/discovery_test.go b/plugins/discovery/discovery_test.go index c6aa726bd4..5367e8f895 100644 --- a/plugins/discovery/discovery_test.go +++ b/plugins/discovery/discovery_test.go @@ -1905,7 +1905,11 @@ func TestReconfigureWithLocalOverride(t *testing.T) { *period = 10 threshold := new(int64) *threshold = 90 - expectedCacheConf := &cache.Config{InterQueryBuiltinCache: cache.InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, StaleEntryEvictionPeriodSeconds: period, ForcedEvictionThresholdPercentage: threshold}} + maxNumEntriesInterQueryValueCache := new(int) + *maxNumEntriesInterQueryValueCache = 0 + + expectedCacheConf := &cache.Config{InterQueryBuiltinCache: cache.InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, StaleEntryEvictionPeriodSeconds: period, ForcedEvictionThresholdPercentage: threshold}, + InterQueryBuiltinValueCache: cache.InterQueryBuiltinValueCacheConfig{MaxNumEntries: maxNumEntriesInterQueryValueCache}} if !reflect.DeepEqual(cacheConf, expectedCacheConf) { t.Fatalf("want %v got %v", expectedCacheConf, cacheConf) diff --git a/plugins/plugins.go b/plugins/plugins.go index bacdd15076..567acfb817 100644 --- a/plugins/plugins.go +++ b/plugins/plugins.go @@ -576,7 +576,7 @@ func (m *Manager) Labels() map[string]string { return m.Config.Labels } -// InterQueryBuiltinCacheConfig returns the configuration for the inter-query cache. +// InterQueryBuiltinCacheConfig returns the configuration for the inter-query caches. func (m *Manager) InterQueryBuiltinCacheConfig() *cache.Config { m.mtx.Lock() defer m.mtx.Unlock() diff --git a/plugins/plugins_test.go b/plugins/plugins_test.go index b04096bab7..960d44ae8f 100644 --- a/plugins/plugins_test.go +++ b/plugins/plugins_test.go @@ -396,7 +396,7 @@ func TestPluginManagerInitIdempotence(t *testing.T) { } func TestManagerWithCachingConfig(t *testing.T) { - m, err := New([]byte(`{"caching": {"inter_query_builtin_cache": {"max_size_bytes": 100}}}`), "test", inmem.New()) + m, err := New([]byte(`{"caching": {"inter_query_builtin_cache": {"max_size_bytes": 100}, "inter_query_builtin_value_cache": {"max_num_entries": 100}}}`), "test", inmem.New()) if err != nil { t.Fatal(err) } @@ -404,6 +404,8 @@ func TestManagerWithCachingConfig(t *testing.T) { expected, _ := cache.ParseCachingConfig(nil) limit := int64(100) expected.InterQueryBuiltinCache.MaxSizeBytes = &limit + maxNumEntriesInterQueryValueCache := int(100) + expected.InterQueryBuiltinValueCache.MaxNumEntries = &maxNumEntriesInterQueryValueCache if !reflect.DeepEqual(m.InterQueryBuiltinCacheConfig(), expected) { t.Fatalf("want %+v got %+v", expected, m.interQueryBuiltinCacheConfig) @@ -414,6 +416,12 @@ func TestManagerWithCachingConfig(t *testing.T) { if err == nil { t.Fatal("expected error but got nil") } + + // config error + _, err = New([]byte(`{"caching": {"inter_query_builtin_value_cache": {"max_num_entries": "100"}}}`), "test", inmem.New()) + if err == nil { + t.Fatal("expected error but got nil") + } } func TestManagerWithNDCachingConfig(t *testing.T) { diff --git a/rego/rego.go b/rego/rego.go index df349b25cb..8672efd669 100644 --- a/rego/rego.go +++ b/rego/rego.go @@ -99,32 +99,33 @@ type preparedQuery struct { // EvalContext defines the set of options allowed to be set at evaluation // time. Any other options will need to be set on a new Rego object. type EvalContext struct { - hasInput bool - time time.Time - seed io.Reader - rawInput *interface{} - parsedInput ast.Value - metrics metrics.Metrics - txn storage.Transaction - instrument bool - instrumentation *topdown.Instrumentation - partialNamespace string - queryTracers []topdown.QueryTracer - compiledQuery compiledQuery - unknowns []string - disableInlining []ast.Ref - parsedUnknowns []*ast.Term - indexing bool - earlyExit bool - interQueryBuiltinCache cache.InterQueryCache - ndBuiltinCache builtins.NDBCache - resolvers []refResolver - sortSets bool - copyMaps bool - printHook print.Hook - capabilities *ast.Capabilities - strictBuiltinErrors bool - virtualCache topdown.VirtualCache + hasInput bool + time time.Time + seed io.Reader + rawInput *interface{} + parsedInput ast.Value + metrics metrics.Metrics + txn storage.Transaction + instrument bool + instrumentation *topdown.Instrumentation + partialNamespace string + queryTracers []topdown.QueryTracer + compiledQuery compiledQuery + unknowns []string + disableInlining []ast.Ref + parsedUnknowns []*ast.Term + indexing bool + earlyExit bool + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + ndBuiltinCache builtins.NDBCache + resolvers []refResolver + sortSets bool + copyMaps bool + printHook print.Hook + capabilities *ast.Capabilities + strictBuiltinErrors bool + virtualCache topdown.VirtualCache } func (e *EvalContext) RawInput() *interface{} { @@ -147,6 +148,10 @@ func (e *EvalContext) InterQueryBuiltinCache() cache.InterQueryCache { return e.interQueryBuiltinCache } +func (e *EvalContext) InterQueryBuiltinValueCache() cache.InterQueryValueCache { + return e.interQueryBuiltinValueCache +} + func (e *EvalContext) PrintHook() print.Hook { return e.printHook } @@ -307,6 +312,14 @@ func EvalInterQueryBuiltinCache(c cache.InterQueryCache) EvalOption { } } +// EvalInterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize +// during evaluation. +func EvalInterQueryBuiltinValueCache(c cache.InterQueryValueCache) EvalOption { + return func(e *EvalContext) { + e.interQueryBuiltinValueCache = c + } +} + // EvalNDBuiltinCache sets the non-deterministic builtin cache that built-in functions can // use during evaluation. func EvalNDBuiltinCache(c builtins.NDBCache) EvalOption { @@ -546,64 +559,65 @@ type loadPaths struct { // Rego constructs a query and can be evaluated to obtain results. type Rego struct { - query string - parsedQuery ast.Body - compiledQueries map[queryType]compiledQuery - pkg string - parsedPackage *ast.Package - imports []string - parsedImports []*ast.Import - rawInput *interface{} - parsedInput ast.Value - unknowns []string - parsedUnknowns []*ast.Term - disableInlining []string - shallowInlining bool - skipPartialNamespace bool - partialNamespace string - modules []rawModule - parsedModules map[string]*ast.Module - compiler *ast.Compiler - store storage.Store - ownStore bool - txn storage.Transaction - metrics metrics.Metrics - queryTracers []topdown.QueryTracer - tracebuf *topdown.BufferTracer - trace bool - instrumentation *topdown.Instrumentation - instrument bool - capture map[*ast.Expr]ast.Var // map exprs to generated capture vars - termVarID int - dump io.Writer - runtime *ast.Term - time time.Time - seed io.Reader - capabilities *ast.Capabilities - builtinDecls map[string]*ast.Builtin - builtinFuncs map[string]*topdown.Builtin - unsafeBuiltins map[string]struct{} - loadPaths loadPaths - bundlePaths []string - bundles map[string]*bundle.Bundle - skipBundleVerification bool - interQueryBuiltinCache cache.InterQueryCache - ndBuiltinCache builtins.NDBCache - strictBuiltinErrors bool - builtinErrorList *[]topdown.Error - resolvers []refResolver - schemaSet *ast.SchemaSet - target string // target type (wasm, rego, etc.) - opa opa.EvalEngine - generateJSON func(*ast.Term, *EvalContext) (interface{}, error) - printHook print.Hook - enablePrintStatements bool - distributedTacingOpts tracing.Options - strict bool - pluginMgr *plugins.Manager - plugins []TargetPlugin - targetPrepState TargetPluginEval - regoVersion ast.RegoVersion + query string + parsedQuery ast.Body + compiledQueries map[queryType]compiledQuery + pkg string + parsedPackage *ast.Package + imports []string + parsedImports []*ast.Import + rawInput *interface{} + parsedInput ast.Value + unknowns []string + parsedUnknowns []*ast.Term + disableInlining []string + shallowInlining bool + skipPartialNamespace bool + partialNamespace string + modules []rawModule + parsedModules map[string]*ast.Module + compiler *ast.Compiler + store storage.Store + ownStore bool + txn storage.Transaction + metrics metrics.Metrics + queryTracers []topdown.QueryTracer + tracebuf *topdown.BufferTracer + trace bool + instrumentation *topdown.Instrumentation + instrument bool + capture map[*ast.Expr]ast.Var // map exprs to generated capture vars + termVarID int + dump io.Writer + runtime *ast.Term + time time.Time + seed io.Reader + capabilities *ast.Capabilities + builtinDecls map[string]*ast.Builtin + builtinFuncs map[string]*topdown.Builtin + unsafeBuiltins map[string]struct{} + loadPaths loadPaths + bundlePaths []string + bundles map[string]*bundle.Bundle + skipBundleVerification bool + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + ndBuiltinCache builtins.NDBCache + strictBuiltinErrors bool + builtinErrorList *[]topdown.Error + resolvers []refResolver + schemaSet *ast.SchemaSet + target string // target type (wasm, rego, etc.) + opa opa.EvalEngine + generateJSON func(*ast.Term, *EvalContext) (interface{}, error) + printHook print.Hook + enablePrintStatements bool + distributedTacingOpts tracing.Options + strict bool + pluginMgr *plugins.Manager + plugins []TargetPlugin + targetPrepState TargetPluginEval + regoVersion ast.RegoVersion } // Function represents a built-in function that is callable in Rego. @@ -1114,6 +1128,14 @@ func InterQueryBuiltinCache(c cache.InterQueryCache) func(r *Rego) { } } +// InterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize +// during evaluation. +func InterQueryBuiltinValueCache(c cache.InterQueryValueCache) func(r *Rego) { + return func(r *Rego) { + r.interQueryBuiltinValueCache = c + } +} + // NDBuiltinCache sets the non-deterministic builtins cache. func NDBuiltinCache(c builtins.NDBCache) func(r *Rego) { return func(r *Rego) { @@ -1309,6 +1331,7 @@ func (r *Rego) Eval(ctx context.Context) (ResultSet, error) { EvalInstrument(r.instrument), EvalTime(r.time), EvalInterQueryBuiltinCache(r.interQueryBuiltinCache), + EvalInterQueryBuiltinValueCache(r.interQueryBuiltinValueCache), EvalSeed(r.seed), } @@ -1386,6 +1409,7 @@ func (r *Rego) Partial(ctx context.Context) (*PartialQueries, error) { EvalMetrics(r.metrics), EvalInstrument(r.instrument), EvalInterQueryBuiltinCache(r.interQueryBuiltinCache), + EvalInterQueryBuiltinValueCache(r.interQueryBuiltinValueCache), } if r.ndBuiltinCache != nil { @@ -2106,6 +2130,7 @@ func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) { WithIndexing(ectx.indexing). WithEarlyExit(ectx.earlyExit). WithInterQueryBuiltinCache(ectx.interQueryBuiltinCache). + WithInterQueryBuiltinValueCache(ectx.interQueryBuiltinValueCache). WithStrictBuiltinErrors(r.strictBuiltinErrors). WithBuiltinErrorList(r.builtinErrorList). WithSeed(ectx.seed). @@ -2164,7 +2189,6 @@ func (r *Rego) eval(ctx context.Context, ectx *EvalContext) (ResultSet, error) { } func (r *Rego) evalWasm(ctx context.Context, ectx *EvalContext) (ResultSet, error) { - input := ectx.rawInput if ectx.parsedInput != nil { i := interface{}(ectx.parsedInput) @@ -2393,6 +2417,7 @@ func (r *Rego) partial(ctx context.Context, ectx *EvalContext) (*PartialQueries, WithSkipPartialNamespace(r.skipPartialNamespace). WithShallowInlining(r.shallowInlining). WithInterQueryBuiltinCache(ectx.interQueryBuiltinCache). + WithInterQueryBuiltinValueCache(ectx.interQueryBuiltinValueCache). WithStrictBuiltinErrors(ectx.strictBuiltinErrors). WithSeed(ectx.seed). WithPrintHook(ectx.printHook) diff --git a/rego/rego_test.go b/rego/rego_test.go index 86550e35f5..208ce5123f 100644 --- a/rego/rego_test.go +++ b/rego/rego_test.go @@ -2432,6 +2432,55 @@ func TestEvalWithInterQueryCache(t *testing.T) { } } +func TestEvalWithInterQueryValueCache(t *testing.T) { + ctx := context.Background() + + // add an inter-query value cache + config, _ := cache.ParseCachingConfig(nil) + interQueryValueCache := cache.NewInterQueryValueCache(ctx, config) + + m := metrics.New() + + query := `regex.match("foo.*", "foobar")` + _, err := New(Query(query), InterQueryBuiltinValueCache(interQueryValueCache), Metrics(m)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + // eval again with same query + // this request should be served by the cache + _, err = New(Query(query), InterQueryBuiltinValueCache(interQueryValueCache), Metrics(m)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + if exp, act := uint64(1), m.Counter("rego_builtin_regex_interquery_value_cache_hits").Value(); exp != act { + t.Fatalf("expected %d cache hits, got %d", exp, act) + } + + query = `glob.match("*.example.com", ["."], "api.example.com")` + _, err = New(Query(query), InterQueryBuiltinValueCache(interQueryValueCache), Metrics(m)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + // eval again with same query + // this request should be served by the cache + _, err = New(Query(query), InterQueryBuiltinValueCache(interQueryValueCache), Metrics(m)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + _, err = New(Query(query), InterQueryBuiltinValueCache(interQueryValueCache), Metrics(m)).Eval(ctx) + if err != nil { + t.Fatal(err) + } + + if exp, act := uint64(2), m.Counter("rego_builtin_glob_interquery_value_cache_hits").Value(); exp != act { + t.Fatalf("expected %d cache hits, got %d", exp, act) + } +} + // We use http.send to ensure the NDBuiltinCache is involved. func TestEvalWithNDCache(t *testing.T) { var requests []*http.Request diff --git a/sdk/opa.go b/sdk/opa.go index 7203df042c..8b180ea1a6 100644 --- a/sdk/opa.go +++ b/sdk/opa.go @@ -53,9 +53,10 @@ type OPA struct { } type state struct { - manager *plugins.Manager - interQueryBuiltinCache cache.InterQueryCache - queryCache *queryCache + manager *plugins.Manager + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + queryCache *queryCache } // New returns a new OPA object. This function should minimally be called with @@ -235,6 +236,7 @@ func (opa *OPA) configure(ctx context.Context, bs []byte, ready chan struct{}, b opa.state.manager = manager opa.state.queryCache.Clear() opa.state.interQueryBuiltinCache = cache.NewInterQueryCacheWithContext(ctx, manager.InterQueryBuiltinCacheConfig()) + opa.state.interQueryBuiltinValueCache = cache.NewInterQueryValueCache(ctx, manager.InterQueryBuiltinCacheConfig()) opa.config = bs return nil @@ -277,22 +279,23 @@ func (opa *OPA) Decision(ctx context.Context, options DecisionOptions) (*Decisio &record, func(s state, result *DecisionResult) { result.Result, result.Provenance, record.InputAST, record.Bundles, record.Error = evaluate(ctx, evalArgs{ - runtime: s.manager.Info, - printHook: s.manager.PrintHook(), - compiler: s.manager.GetCompiler(), - store: s.manager.Store, - queryCache: s.queryCache, - interQueryCache: s.interQueryBuiltinCache, - ndbcache: ndbc, - txn: record.Txn, - now: record.Timestamp, - path: record.Path, - input: *record.Input, - m: record.Metrics, - strictBuiltinErrors: options.StrictBuiltinErrors, - tracer: options.Tracer, - profiler: options.Profiler, - instrument: options.Instrument, + runtime: s.manager.Info, + printHook: s.manager.PrintHook(), + compiler: s.manager.GetCompiler(), + store: s.manager.Store, + queryCache: s.queryCache, + interQueryCache: s.interQueryBuiltinCache, + interQueryBuiltinValueCache: s.interQueryBuiltinValueCache, + ndbcache: ndbc, + txn: record.Txn, + now: record.Timestamp, + path: record.Path, + input: *record.Input, + m: record.Metrics, + strictBuiltinErrors: options.StrictBuiltinErrors, + tracer: options.Tracer, + profiler: options.Profiler, + instrument: options.Instrument, }) if record.Error == nil { record.Results = &result.Result @@ -506,22 +509,23 @@ func IsUndefinedErr(err error) bool { } type evalArgs struct { - runtime *ast.Term - printHook print.Hook - compiler *ast.Compiler - store storage.Store - txn storage.Transaction - queryCache *queryCache - interQueryCache cache.InterQueryCache - now time.Time - path string - input interface{} - ndbcache builtins.NDBCache - m metrics.Metrics - strictBuiltinErrors bool - tracer topdown.QueryTracer - profiler topdown.QueryTracer - instrument bool + runtime *ast.Term + printHook print.Hook + compiler *ast.Compiler + store storage.Store + txn storage.Transaction + queryCache *queryCache + interQueryCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + now time.Time + path string + input interface{} + ndbcache builtins.NDBCache + m metrics.Metrics + strictBuiltinErrors bool + tracer topdown.QueryTracer + profiler topdown.QueryTracer + instrument bool } func evaluate(ctx context.Context, args evalArgs) (interface{}, types.ProvenanceV1, ast.Value, map[string]server.BundleInfo, error) { @@ -581,6 +585,7 @@ func evaluate(ctx context.Context, args evalArgs) (interface{}, types.Provenance rego.EvalTransaction(args.txn), rego.EvalMetrics(args.m), rego.EvalInterQueryBuiltinCache(args.interQueryCache), + rego.EvalInterQueryBuiltinValueCache(args.interQueryBuiltinValueCache), rego.EvalNDBuiltinCache(args.ndbcache), rego.EvalQueryTracer(args.tracer), rego.EvalMetrics(args.m), diff --git a/server/authorizer/authorizer.go b/server/authorizer/authorizer.go index 8dcc3c3394..4240f40377 100644 --- a/server/authorizer/authorizer.go +++ b/server/authorizer/authorizer.go @@ -32,6 +32,7 @@ type Basic struct { printHook print.Hook enablePrintStatements bool interQueryCache cache.InterQueryCache + interQueryValueCache cache.InterQueryValueCache } // Runtime returns an argument that sets the runtime on the authorizer. @@ -73,6 +74,13 @@ func InterQueryCache(interQueryCache cache.InterQueryCache) func(*Basic) { } } +// InterQueryValueCache enables the inter-query value cache on the authorizer +func InterQueryValueCache(interQueryValueCache cache.InterQueryValueCache) func(*Basic) { + return func(b *Basic) { + b.interQueryValueCache = interQueryValueCache + } +} + // NewBasic returns a new Basic object. func NewBasic(inner http.Handler, compiler func() *ast.Compiler, store storage.Store, opts ...func(*Basic)) http.Handler { b := &Basic{ @@ -107,6 +115,7 @@ func (h *Basic) ServeHTTP(w http.ResponseWriter, r *http.Request) { rego.EnablePrintStatements(h.enablePrintStatements), rego.PrintHook(h.printHook), rego.InterQueryBuiltinCache(h.interQueryCache), + rego.InterQueryBuiltinValueCache(h.interQueryValueCache), ) rs, err := rego.Eval(r.Context()) diff --git a/server/authorizer/authorizer_test.go b/server/authorizer/authorizer_test.go index 17478c8f1a..712142054a 100644 --- a/server/authorizer/authorizer_test.go +++ b/server/authorizer/authorizer_test.go @@ -6,6 +6,7 @@ package authorizer import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -516,6 +517,43 @@ func TestInterQueryCache(t *testing.T) { } } +func TestInterQueryValueCache(t *testing.T) { + + compiler := func() *ast.Compiler { + module := ` + package system.authz + import rego.v1 + + allow if { + regex.match("foo.*", "foobar") + }` + c := ast.NewCompiler() + c.Compile(map[string]*ast.Module{ + "test.rego": ast.MustParseModule(module), + }) + if c.Failed() { + t.Fatalf("Unexpected error compiling test module: %v", c.Errors) + } + return c + } + + recorder := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "http://localhost:8181/v1/data", nil) + if err != nil { + t.Fatal(err) + } + + config, _ := cache.ParseCachingConfig(nil) + interQueryValueCache := cache.NewInterQueryValueCache(context.Background(), config) + + basic := NewBasic(&mockHandler{}, compiler, inmem.New(), InterQueryValueCache(interQueryValueCache), Decision(func() ast.Ref { + return ast.MustParseRef("data.system.authz.allow") + })) + + // Execute the policy + basic.ServeHTTP(recorder, req) +} + func Equal(a, b []string) bool { if len(a) != len(b) { return false diff --git a/server/server.go b/server/server.go index c059719a96..bb5c70767e 100644 --- a/server/server.go +++ b/server/server.go @@ -111,42 +111,43 @@ type Server struct { Handler http.Handler DiagnosticHandler http.Handler - router *mux.Router - addrs []string - diagAddrs []string - h2cEnabled bool - authentication AuthenticationScheme - authorization AuthorizationScheme - cert *tls.Certificate - tlsConfigMtx sync.RWMutex - certFile string - certFileHash []byte - certKeyFile string - certKeyFileHash []byte - certRefresh time.Duration - certPool *x509.CertPool - certPoolFile string - certPoolFileHash []byte - minTLSVersion uint16 - mtx sync.RWMutex - partials map[string]rego.PartialResult - preparedEvalQueries *cache - store storage.Store - manager *plugins.Manager - decisionIDFactory func() string - logger func(context.Context, *Info) error - errLimit int - pprofEnabled bool - runtime *ast.Term - httpListeners []httpListener - metrics Metrics - defaultDecisionPath string - interQueryBuiltinCache iCache.InterQueryCache - allPluginsOkOnce bool - distributedTracingOpts tracing.Options - ndbCacheEnabled bool - unixSocketPerm *string - cipherSuites *[]uint16 + router *mux.Router + addrs []string + diagAddrs []string + h2cEnabled bool + authentication AuthenticationScheme + authorization AuthorizationScheme + cert *tls.Certificate + tlsConfigMtx sync.RWMutex + certFile string + certFileHash []byte + certKeyFile string + certKeyFileHash []byte + certRefresh time.Duration + certPool *x509.CertPool + certPoolFile string + certPoolFileHash []byte + minTLSVersion uint16 + mtx sync.RWMutex + partials map[string]rego.PartialResult + preparedEvalQueries *cache + store storage.Store + manager *plugins.Manager + decisionIDFactory func() string + logger func(context.Context, *Info) error + errLimit int + pprofEnabled bool + runtime *ast.Term + httpListeners []httpListener + metrics Metrics + defaultDecisionPath string + interQueryBuiltinCache iCache.InterQueryCache + interQueryBuiltinValueCache iCache.InterQueryValueCache + allPluginsOkOnce bool + distributedTracingOpts tracing.Options + ndbCacheEnabled bool + unixSocketPerm *string + cipherSuites *[]uint16 } // Metrics defines the interface that the server requires for recording HTTP @@ -748,7 +749,8 @@ func (s *Server) initHandlerAuthz(handler http.Handler) http.Handler { authorizer.Decision(s.manager.Config.DefaultAuthorizationDecisionRef), authorizer.PrintHook(s.manager.PrintHook()), authorizer.EnablePrintStatements(s.manager.EnablePrintStatements()), - authorizer.InterQueryCache(s.interQueryBuiltinCache)) + authorizer.InterQueryCache(s.interQueryBuiltinCache), + authorizer.InterQueryValueCache(s.interQueryBuiltinValueCache)) if s.metrics != nil { handler = s.instrumentHandler(handler.ServeHTTP, PromHandlerAPIAuthz) @@ -800,7 +802,12 @@ func (s *Server) initRouters(ctx context.Context) { diagRouter := mux.NewRouter() // authorizer, if configured, needs the iCache to be set up already - s.interQueryBuiltinCache = iCache.NewInterQueryCacheWithContext(ctx, s.manager.InterQueryBuiltinCacheConfig()) + + cacheConfig := s.manager.InterQueryBuiltinCacheConfig() + + s.interQueryBuiltinCache = iCache.NewInterQueryCacheWithContext(ctx, cacheConfig) + s.interQueryBuiltinValueCache = iCache.NewInterQueryValueCache(ctx, cacheConfig) + s.manager.RegisterCacheTrigger(s.updateCacheConfig) // Add authorization handler. This must come BEFORE authentication handler @@ -933,6 +940,7 @@ func (s *Server) execQuery(ctx context.Context, br bundleRevisions, txn storage. rego.Runtime(s.runtime), rego.UnsafeBuiltins(unsafeBuiltinsMap), rego.InterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.InterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), rego.PrintHook(s.manager.PrintHook()), rego.EnablePrintStatements(s.manager.EnablePrintStatements()), rego.DistributedTracingOpts(s.distributedTracingOpts), @@ -1121,6 +1129,7 @@ func (s *Server) v0QueryPath(w http.ResponseWriter, r *http.Request, urlPath str rego.EvalParsedInput(input), rego.EvalMetrics(m), rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), rego.EvalNDBuiltinCache(ndbCache), } @@ -1402,6 +1411,7 @@ func (s *Server) v1CompilePost(w http.ResponseWriter, r *http.Request) { rego.Runtime(s.runtime), rego.UnsafeBuiltins(unsafeBuiltinsMap), rego.InterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.InterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), rego.PrintHook(s.manager.PrintHook()), ) @@ -1541,6 +1551,7 @@ func (s *Server) v1DataGet(w http.ResponseWriter, r *http.Request) { rego.EvalMetrics(m), rego.EvalQueryTracer(buf), rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), rego.EvalInstrument(includeInstrumentation), rego.EvalNDBuiltinCache(ndbCache), } @@ -1760,6 +1771,7 @@ func (s *Server) v1DataPost(w http.ResponseWriter, r *http.Request) { rego.EvalMetrics(m), rego.EvalQueryTracer(buf), rego.EvalInterQueryBuiltinCache(s.interQueryBuiltinCache), + rego.EvalInterQueryBuiltinValueCache(s.interQueryBuiltinValueCache), rego.EvalInstrument(includeInstrumentation), rego.EvalNDBuiltinCache(ndbCache), } @@ -2655,6 +2667,7 @@ func isPathOwned(path, root []string) bool { func (s *Server) updateCacheConfig(cacheConfig *iCache.Config) { s.interQueryBuiltinCache.UpdateConfig(cacheConfig) + s.interQueryBuiltinValueCache.UpdateConfig(cacheConfig) } func (s *Server) updateNDCache(enabled bool) { diff --git a/topdown/builtins.go b/topdown/builtins.go index 30c488050f..cf694d4331 100644 --- a/topdown/builtins.go +++ b/topdown/builtins.go @@ -35,25 +35,26 @@ type ( // BuiltinContext contains context from the evaluator that may be used by // built-in functions. BuiltinContext struct { - Context context.Context // request context that was passed when query started - Metrics metrics.Metrics // metrics registry for recording built-in specific metrics - Seed io.Reader // randomization source - Time *ast.Term // wall clock time - Cancel Cancel // atomic value that signals evaluation to halt - Runtime *ast.Term // runtime information on the OPA instance - Cache builtins.Cache // built-in function state cache - InterQueryBuiltinCache cache.InterQueryCache // cross-query built-in function state cache - NDBuiltinCache builtins.NDBCache // cache for non-deterministic built-in state - Location *ast.Location // location of built-in call - Tracers []Tracer // Deprecated: Use QueryTracers instead - QueryTracers []QueryTracer // tracer objects for trace() built-in function - TraceEnabled bool // indicates whether tracing is enabled for the evaluation - QueryID uint64 // identifies query being evaluated - ParentID uint64 // identifies parent of query being evaluated - PrintHook print.Hook // provides callback function to use for printing - DistributedTracingOpts tracing.Options // options to be used by distributed tracing. - rand *rand.Rand // randomization source for non-security-sensitive operations - Capabilities *ast.Capabilities + Context context.Context // request context that was passed when query started + Metrics metrics.Metrics // metrics registry for recording built-in specific metrics + Seed io.Reader // randomization source + Time *ast.Term // wall clock time + Cancel Cancel // atomic value that signals evaluation to halt + Runtime *ast.Term // runtime information on the OPA instance + Cache builtins.Cache // built-in function state cache + InterQueryBuiltinCache cache.InterQueryCache // cross-query built-in function state cache + InterQueryBuiltinValueCache cache.InterQueryValueCache // cross-query built-in function state value cache. this cache is useful for scenarios where the entry size cannot be calculated + NDBuiltinCache builtins.NDBCache // cache for non-deterministic built-in state + Location *ast.Location // location of built-in call + Tracers []Tracer // Deprecated: Use QueryTracers instead + QueryTracers []QueryTracer // tracer objects for trace() built-in function + TraceEnabled bool // indicates whether tracing is enabled for the evaluation + QueryID uint64 // identifies query being evaluated + ParentID uint64 // identifies parent of query being evaluated + PrintHook print.Hook // provides callback function to use for printing + DistributedTracingOpts tracing.Options // options to be used by distributed tracing. + rand *rand.Rand // randomization source for non-security-sensitive operations + Capabilities *ast.Capabilities } // BuiltinFunc defines an interface for implementing built-in functions. diff --git a/topdown/cache/cache.go b/topdown/cache/cache.go index c83c9828bf..55ed340619 100644 --- a/topdown/cache/cache.go +++ b/topdown/cache/cache.go @@ -18,14 +18,22 @@ import ( ) const ( + defaultInterQueryBuiltinValueCacheSize = int(0) // unlimited defaultMaxSizeBytes = int64(0) // unlimited defaultForcedEvictionThresholdPercentage = int64(100) // trigger at max_size_bytes defaultStaleEntryEvictionPeriodSeconds = int64(0) // never ) -// Config represents the configuration of the inter-query cache. +// Config represents the configuration for the inter-query builtin cache. type Config struct { - InterQueryBuiltinCache InterQueryBuiltinCacheConfig `json:"inter_query_builtin_cache"` + InterQueryBuiltinCache InterQueryBuiltinCacheConfig `json:"inter_query_builtin_cache"` + InterQueryBuiltinValueCache InterQueryBuiltinValueCacheConfig `json:"inter_query_builtin_value_cache"` +} + +// InterQueryBuiltinValueCacheConfig represents the configuration of the inter-query value cache that built-in functions can utilize. +// MaxNumEntries - max number of cache entries +type InterQueryBuiltinValueCacheConfig struct { + MaxNumEntries *int `json:"max_num_entries,omitempty"` } // InterQueryBuiltinCacheConfig represents the configuration of the inter-query cache that built-in functions can utilize. @@ -47,7 +55,12 @@ func ParseCachingConfig(raw []byte) (*Config, error) { *threshold = defaultForcedEvictionThresholdPercentage period := new(int64) *period = defaultStaleEntryEvictionPeriodSeconds - return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, ForcedEvictionThresholdPercentage: threshold, StaleEntryEvictionPeriodSeconds: period}}, nil + + maxInterQueryBuiltinValueCacheSize := new(int) + *maxInterQueryBuiltinValueCacheSize = defaultInterQueryBuiltinValueCacheSize + + return &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, ForcedEvictionThresholdPercentage: threshold, StaleEntryEvictionPeriodSeconds: period}, + InterQueryBuiltinValueCache: InterQueryBuiltinValueCacheConfig{MaxNumEntries: maxInterQueryBuiltinValueCacheSize}}, nil } var config Config @@ -89,6 +102,18 @@ func (c *Config) validateAndInjectDefaults() error { return fmt.Errorf("invalid stale_entry_eviction_period_seconds %v", period) } } + + if c.InterQueryBuiltinValueCache.MaxNumEntries == nil { + maxSize := new(int) + *maxSize = defaultInterQueryBuiltinValueCacheSize + c.InterQueryBuiltinValueCache.MaxNumEntries = maxSize + } else { + numEntries := *c.InterQueryBuiltinValueCache.MaxNumEntries + if numEntries < 0 { + return fmt.Errorf("invalid max_num_entries %v", numEntries) + } + } + return nil } @@ -301,3 +326,81 @@ func (c *cache) cleanStaleValues() (dropped int) { } return dropped } + +type InterQueryValueCache interface { + Get(key ast.Value) (value any, found bool) + Insert(key ast.Value, value any) int + Delete(key ast.Value) + UpdateConfig(config *Config) +} + +type interQueryValueCache struct { + items map[string]any + config *Config + mtx sync.RWMutex +} + +// Get returns the value in the cache for k. +func (c *interQueryValueCache) Get(k ast.Value) (any, bool) { + c.mtx.RLock() + defer c.mtx.RUnlock() + value, ok := c.items[k.String()] + return value, ok +} + +// Insert inserts a key k into the cache with value v. +func (c *interQueryValueCache) Insert(k ast.Value, v any) (dropped int) { + c.mtx.Lock() + defer c.mtx.Unlock() + + maxEntries := c.maxNumEntries() + if maxEntries > 0 { + if len(c.items) >= maxEntries { + itemsToRemove := len(c.items) - maxEntries + 1 + + // Delete a (semi-)random key to make room for the new one. + for k := range c.items { + delete(c.items, k) + dropped++ + + if itemsToRemove == dropped { + break + } + } + } + } + + c.items[k.String()] = v + return dropped +} + +// Delete deletes the value in the cache for k. +func (c *interQueryValueCache) Delete(k ast.Value) { + c.mtx.Lock() + defer c.mtx.Unlock() + delete(c.items, k.String()) +} + +// UpdateConfig updates the cache config. +func (c *interQueryValueCache) UpdateConfig(config *Config) { + if config == nil { + return + } + c.mtx.Lock() + defer c.mtx.Unlock() + c.config = config +} + +func (c *interQueryValueCache) maxNumEntries() int { + if c.config == nil { + return defaultInterQueryBuiltinValueCacheSize + } + return *c.config.InterQueryBuiltinValueCache.MaxNumEntries +} + +func NewInterQueryValueCache(_ context.Context, config *Config) InterQueryValueCache { + return &interQueryValueCache{ + items: map[string]any{}, + config: config, + } +} diff --git a/topdown/cache/cache_test.go b/topdown/cache/cache_test.go index 85ccb20913..b5c1a3865a 100644 --- a/topdown/cache/cache_test.go +++ b/topdown/cache/cache_test.go @@ -21,7 +21,11 @@ func TestParseCachingConfig(t *testing.T) { *period = defaultStaleEntryEvictionPeriodSeconds threshold := new(int64) *threshold = defaultForcedEvictionThresholdPercentage - expected := &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, StaleEntryEvictionPeriodSeconds: period, ForcedEvictionThresholdPercentage: threshold}} + maxNumEntriesInterQueryValueCache := new(int) + *maxNumEntriesInterQueryValueCache = defaultInterQueryBuiltinValueCacheSize + + expected := &Config{InterQueryBuiltinCache: InterQueryBuiltinCacheConfig{MaxSizeBytes: maxSize, StaleEntryEvictionPeriodSeconds: period, ForcedEvictionThresholdPercentage: threshold}, + InterQueryBuiltinValueCache: InterQueryBuiltinValueCacheConfig{MaxNumEntries: maxNumEntriesInterQueryValueCache}} tests := map[string]struct { input []byte @@ -35,10 +39,18 @@ func TestParseCachingConfig(t *testing.T) { input: []byte(`{"inter_query_builtin_cache": {},}`), wantErr: false, }, + "default_num_entries": { + input: []byte(`{"inter_query_builtin_value_cache": {},}`), + wantErr: false, + }, "bad_limit": { input: []byte(`{"inter_query_builtin_cache": {"max_size_bytes": "100"},}`), wantErr: true, }, + "bad_num_entries": { + input: []byte(`{"inter_query_builtin_value_cache": {"max_num_entries": "100"},}`), + wantErr: true, + }, } for name, tc := range tests { @@ -165,6 +177,97 @@ func TestInsert(t *testing.T) { } } +func TestInterQueryValueCache(t *testing.T) { + + in := `{"inter_query_builtin_value_cache": {"max_num_entries": 4},}` + + config, err := ParseCachingConfig([]byte(in)) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + cache := NewInterQueryValueCache(context.Background(), config) + + cache.Insert(ast.StringTerm("foo").Value, "bar") + cache.Insert(ast.StringTerm("foo2").Value, "bar2") + cache.Insert(ast.StringTerm("hello").Value, "world") + dropped := cache.Insert(ast.StringTerm("hello2").Value, "world2") + + if dropped != 0 { + t.Fatal("Expected dropped to be zero") + } + + value, found := cache.Get(ast.StringTerm("foo").Value) + if !found { + t.Fatal("Expected key \"foo\" in cache") + } + + actual, ok := value.(string) + if !ok { + t.Fatal("Expected string value") + } + + if actual != "bar" { + t.Fatalf("Expected value \"bar\" but got %v", actual) + } + + dropped = cache.Insert(ast.StringTerm("foo3").Value, "bar3") + if dropped != 1 { + t.Fatal("Expected dropped to be one") + } + + _, found = cache.Get(ast.StringTerm("foo3").Value) + if !found { + t.Fatal("Expected key \"foo3\" in cache") + } + + // update the cache config + in = `{"inter_query_builtin_value_cache": {"max_num_entries": 0},}` // unlimited + config, err = ParseCachingConfig([]byte(in)) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + cache.UpdateConfig(config) + + cache.Insert(ast.StringTerm("a").Value, "b") + cache.Insert(ast.StringTerm("c").Value, "d") + cache.Insert(ast.StringTerm("e").Value, "f") + dropped = cache.Insert(ast.StringTerm("g").Value, "h") + + if dropped != 0 { + t.Fatal("Expected dropped to be zero") + } + + // at this point the cache should have 8 entries + // update the cache size and verify multiple items dropped + in = `{"inter_query_builtin_value_cache": {"max_num_entries": 6},}` + config, err = ParseCachingConfig([]byte(in)) + if err != nil { + t.Fatalf("Unexpected error %v", err) + } + + cache.UpdateConfig(config) + + dropped = cache.Insert(ast.StringTerm("i").Value, "j") + + if dropped != 3 { + t.Fatal("Expected dropped to be three") + } + + _, found = cache.Get(ast.StringTerm("i").Value) + if !found { + t.Fatal("Expected key \"i\" in cache") + } + + cache.Delete(ast.StringTerm("i").Value) + + _, found = cache.Get(ast.StringTerm("i").Value) + if found { + t.Fatal("Unexpected key \"i\" in cache") + } +} + func TestConcurrentInsert(t *testing.T) { in := `{"inter_query_builtin_cache": {"max_size_bytes": 20},}` // 20 byte limit for test purposes diff --git a/topdown/eval.go b/topdown/eval.go index 2fcc431c80..7884ac01e0 100644 --- a/topdown/eval.go +++ b/topdown/eval.go @@ -58,55 +58,56 @@ func (ee deferredEarlyExitError) Error() string { } type eval struct { - ctx context.Context - metrics metrics.Metrics - seed io.Reader - time *ast.Term - queryID uint64 - queryIDFact *queryIDFactory - parent *eval - caller *eval - cancel Cancel - query ast.Body - queryCompiler ast.QueryCompiler - index int - indexing bool - earlyExit bool - bindings *bindings - store storage.Store - baseCache *baseCache - txn storage.Transaction - compiler *ast.Compiler - input *ast.Term - data *ast.Term - external *resolverTrie - targetStack *refStack - tracers []QueryTracer - traceEnabled bool - traceLastLocation *ast.Location // Last location of a trace event. - plugTraceVars bool - instr *Instrumentation - builtins map[string]*Builtin - builtinCache builtins.Cache - ndBuiltinCache builtins.NDBCache - functionMocks *functionMocksStack - virtualCache VirtualCache - comprehensionCache *comprehensionCache - interQueryBuiltinCache cache.InterQueryCache - saveSet *saveSet - saveStack *saveStack - saveSupport *saveSupport - saveNamespace *ast.Term - skipSaveNamespace bool - inliningControl *inliningControl - genvarprefix string - genvarid int - runtime *ast.Term - builtinErrors *builtinErrors - printHook print.Hook - tracingOpts tracing.Options - findOne bool - strictObjects bool + ctx context.Context + metrics metrics.Metrics + seed io.Reader + time *ast.Term + queryID uint64 + queryIDFact *queryIDFactory + parent *eval + caller *eval + cancel Cancel + query ast.Body + queryCompiler ast.QueryCompiler + index int + indexing bool + earlyExit bool + bindings *bindings + store storage.Store + baseCache *baseCache + txn storage.Transaction + compiler *ast.Compiler + input *ast.Term + data *ast.Term + external *resolverTrie + targetStack *refStack + tracers []QueryTracer + traceEnabled bool + traceLastLocation *ast.Location // Last location of a trace event. + plugTraceVars bool + instr *Instrumentation + builtins map[string]*Builtin + builtinCache builtins.Cache + ndBuiltinCache builtins.NDBCache + functionMocks *functionMocksStack + virtualCache VirtualCache + comprehensionCache *comprehensionCache + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + saveSet *saveSet + saveStack *saveStack + saveSupport *saveSupport + saveNamespace *ast.Term + skipSaveNamespace bool + inliningControl *inliningControl + genvarprefix string + genvarid int + runtime *ast.Term + builtinErrors *builtinErrors + printHook print.Hook + tracingOpts tracing.Options + findOne bool + strictObjects bool } func (e *eval) Run(iter evalIterator) error { @@ -817,23 +818,24 @@ func (e *eval) evalCall(terms []*ast.Term, iter unifyIterator) error { } bctx := BuiltinContext{ - Context: e.ctx, - Metrics: e.metrics, - Seed: e.seed, - Time: e.time, - Cancel: e.cancel, - Runtime: e.runtime, - Cache: e.builtinCache, - InterQueryBuiltinCache: e.interQueryBuiltinCache, - NDBuiltinCache: e.ndBuiltinCache, - Location: e.query[e.index].Location, - QueryTracers: e.tracers, - TraceEnabled: e.traceEnabled, - QueryID: e.queryID, - ParentID: parentID, - PrintHook: e.printHook, - DistributedTracingOpts: e.tracingOpts, - Capabilities: capabilities, + Context: e.ctx, + Metrics: e.metrics, + Seed: e.seed, + Time: e.time, + Cancel: e.cancel, + Runtime: e.runtime, + Cache: e.builtinCache, + InterQueryBuiltinCache: e.interQueryBuiltinCache, + InterQueryBuiltinValueCache: e.interQueryBuiltinValueCache, + NDBuiltinCache: e.ndBuiltinCache, + Location: e.query[e.index].Location, + QueryTracers: e.tracers, + TraceEnabled: e.traceEnabled, + QueryID: e.queryID, + ParentID: parentID, + PrintHook: e.printHook, + DistributedTracingOpts: e.tracingOpts, + Capabilities: capabilities, } eval := evalBuiltin{ diff --git a/topdown/glob.go b/topdown/glob.go index 116602db74..3f769d8369 100644 --- a/topdown/glob.go +++ b/topdown/glob.go @@ -15,7 +15,9 @@ const globCacheMaxSize = 100 var globCacheLock = sync.Mutex{} var globCache map[string]glob.Glob -func builtinGlobMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +var globInterQueryValueCacheHits = "rego_builtin_glob_interquery_value_cache_hits" + +func builtinGlobMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { pattern, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -50,14 +52,46 @@ func builtinGlobMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter } id := builder.String() - m, err := globCompileAndMatch(id, string(pattern), string(match), delimiters) + m, err := globCompileAndMatch(bctx, id, string(pattern), string(match), delimiters) if err != nil { return err } return iter(ast.BooleanTerm(m)) } -func globCompileAndMatch(id, pattern, match string, delimiters []rune) (bool, error) { +func globCompileAndMatch(bctx BuiltinContext, id, pattern, match string, delimiters []rune) (bool, error) { + + if bctx.InterQueryBuiltinValueCache != nil { + val, ok := bctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(id).Value) + if ok { + pat, valid := val.(glob.Glob) + if !valid { + // The cache key may exist for a different value type (eg. regex). + // In this case, we calculate the glob and return the result w/o updating the cache. + var res glob.Glob + var err error + if res, err = glob.Compile(pattern, delimiters...); err != nil { + return false, err + } + out := res.Match(match) + return out, nil + } + bctx.Metrics.Counter(globInterQueryValueCacheHits).Incr() + out := pat.Match(match) + return out, nil + } + + var res glob.Glob + var err error + if res, err = glob.Compile(pattern, delimiters...); err != nil { + return false, err + } + bctx.InterQueryBuiltinValueCache.Insert(ast.StringTerm(id).Value, res) + + out := res.Match(match) + return out, nil + } + globCacheLock.Lock() defer globCacheLock.Unlock() p, ok := globCache[id] diff --git a/topdown/glob_test.go b/topdown/glob_test.go index f72e5d8b2c..2ec7f3732e 100644 --- a/topdown/glob_test.go +++ b/topdown/glob_test.go @@ -5,10 +5,12 @@ package topdown import ( + "context" "fmt" "testing" "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/topdown/cache" ) func TestGlobBuiltinCache(t *testing.T) { @@ -69,3 +71,108 @@ func TestGlobBuiltinCache(t *testing.T) { t.Fatalf("Expected glob to be cached: %v", glob2) } } + +func TestGlobBuiltinInterQueryValueCache(t *testing.T) { + ip := []byte(`{"inter_query_builtin_value_cache": {"max_num_entries": "10"},}`) + config, _ := cache.ParseCachingConfig(ip) + interQueryValueCache := cache.NewInterQueryValueCache(context.Background(), config) + + ctx := BuiltinContext{InterQueryBuiltinValueCache: interQueryValueCache} + iter := func(*ast.Term) error { return nil } + + // A novel glob pattern is cached. + glob1 := "foo/*" + operands := []*ast.Term{ + ast.NewTerm(ast.String(glob1)), + ast.NullTerm(), + ast.NewTerm(ast.String("foo/bar")), + } + err := builtinGlobMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // the glob id will have a trailing '-' rune. + if _, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(fmt.Sprintf("%s-", glob1)).Value); !ok { + t.Fatalf("Expected glob to be cached: %v", glob1) + } + + // Fill up the cache. + for i := 0; i < 9; i++ { + operands := []*ast.Term{ + ast.NewTerm(ast.String(fmt.Sprintf("foo/%d/*", i))), + ast.NullTerm(), + ast.NewTerm(ast.String(fmt.Sprintf("foo/%d/bar", i))), + } + err := builtinGlobMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + } + + // A new glob pattern is cached and a random pattern is evicted. + glob2 := "bar/*" + operands = []*ast.Term{ + ast.NewTerm(ast.String(glob2)), + ast.NullTerm(), + ast.NewTerm(ast.String("bar/baz")), + } + err = builtinGlobMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if _, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(fmt.Sprintf("%s-", glob2)).Value); !ok { + t.Fatalf("Expected glob to be cached: %v", glob2) + } +} + +func TestGlobBuiltinInterQueryValueCacheTypeMismatch(t *testing.T) { + + ip := []byte(`{"inter_query_builtin_value_cache": {"max_num_entries": "10"},}`) + config, _ := cache.ParseCachingConfig(ip) + interQueryValueCache := cache.NewInterQueryValueCache(context.Background(), config) + + ctx := BuiltinContext{InterQueryBuiltinValueCache: interQueryValueCache} + iter := func(*ast.Term) error { return nil } + + key := "foo.*" + + operands := []*ast.Term{ + ast.NewTerm(ast.String(key)), + ast.NullTerm(), + ast.NewTerm(ast.String("foo/bar")), + } + err := builtinGlobMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // the glob id will have a trailing '-' rune. + if _, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(fmt.Sprintf("%s-", key)).Value); !ok { + t.Fatalf("Expected glob to be cached: %v", key) + } + + // update the cache entry + ctx.InterQueryBuiltinValueCache.Insert(ast.StringTerm(fmt.Sprintf("%s-", key)).Value, "bar") + + err = builtinGlobMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // verify the cache entry is unchanged + value, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(fmt.Sprintf("%s-", key)).Value) + if !ok { + t.Fatal("Expected key \"foo.*-\" in cache") + } + + actual, ok := value.(string) + if !ok { + t.Fatal("Expected string value") + } + + if actual != "bar" { + t.Fatalf("Expected value \"bar\" but got %v", actual) + } +} diff --git a/topdown/query.go b/topdown/query.go index bbb4ba58f3..8406cfdd87 100644 --- a/topdown/query.go +++ b/topdown/query.go @@ -27,38 +27,39 @@ type QueryResult map[ast.Var]*ast.Term // Query provides a configurable interface for performing query evaluation. type Query struct { - seed io.Reader - time time.Time - cancel Cancel - query ast.Body - queryCompiler ast.QueryCompiler - compiler *ast.Compiler - store storage.Store - txn storage.Transaction - input *ast.Term - external *resolverTrie - tracers []QueryTracer - plugTraceVars bool - unknowns []*ast.Term - partialNamespace string - skipSaveNamespace bool - metrics metrics.Metrics - instr *Instrumentation - disableInlining []ast.Ref - shallowInlining bool - genvarprefix string - runtime *ast.Term - builtins map[string]*Builtin - indexing bool - earlyExit bool - interQueryBuiltinCache cache.InterQueryCache - ndBuiltinCache builtins.NDBCache - strictBuiltinErrors bool - builtinErrorList *[]Error - strictObjects bool - printHook print.Hook - tracingOpts tracing.Options - virtualCache VirtualCache + seed io.Reader + time time.Time + cancel Cancel + query ast.Body + queryCompiler ast.QueryCompiler + compiler *ast.Compiler + store storage.Store + txn storage.Transaction + input *ast.Term + external *resolverTrie + tracers []QueryTracer + plugTraceVars bool + unknowns []*ast.Term + partialNamespace string + skipSaveNamespace bool + metrics metrics.Metrics + instr *Instrumentation + disableInlining []ast.Ref + shallowInlining bool + genvarprefix string + runtime *ast.Term + builtins map[string]*Builtin + indexing bool + earlyExit bool + interQueryBuiltinCache cache.InterQueryCache + interQueryBuiltinValueCache cache.InterQueryValueCache + ndBuiltinCache builtins.NDBCache + strictBuiltinErrors bool + builtinErrorList *[]Error + strictObjects bool + printHook print.Hook + tracingOpts tracing.Options + virtualCache VirtualCache } // Builtin represents a built-in function that queries can call. @@ -246,6 +247,12 @@ func (q *Query) WithInterQueryBuiltinCache(c cache.InterQueryCache) *Query { return q } +// WithInterQueryBuiltinValueCache sets the inter-query value cache that built-in functions can utilize. +func (q *Query) WithInterQueryBuiltinValueCache(c cache.InterQueryValueCache) *Query { + q.interQueryBuiltinValueCache = c + return q +} + // WithNDBuiltinCache sets the non-deterministic builtin cache. func (q *Query) WithNDBuiltinCache(c builtins.NDBCache) *Query { q.ndBuiltinCache = c @@ -331,39 +338,40 @@ func (q *Query) PartialRun(ctx context.Context) (partials []ast.Body, support [] } e := &eval{ - ctx: ctx, - metrics: q.metrics, - seed: q.seed, - time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), - cancel: q.cancel, - query: q.query, - queryCompiler: q.queryCompiler, - queryIDFact: f, - queryID: f.Next(), - bindings: b, - compiler: q.compiler, - store: q.store, - baseCache: newBaseCache(), - targetStack: newRefStack(), - txn: q.txn, - input: q.input, - external: q.external, - tracers: q.tracers, - traceEnabled: len(q.tracers) > 0, - plugTraceVars: q.plugTraceVars, - instr: q.instr, - builtins: q.builtins, - builtinCache: builtins.Cache{}, - functionMocks: newFunctionMocksStack(), - interQueryBuiltinCache: q.interQueryBuiltinCache, - ndBuiltinCache: q.ndBuiltinCache, - virtualCache: vc, - comprehensionCache: newComprehensionCache(), - saveSet: newSaveSet(q.unknowns, b, q.instr), - saveStack: newSaveStack(), - saveSupport: newSaveSupport(), - saveNamespace: ast.StringTerm(q.partialNamespace), - skipSaveNamespace: q.skipSaveNamespace, + ctx: ctx, + metrics: q.metrics, + seed: q.seed, + time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), + cancel: q.cancel, + query: q.query, + queryCompiler: q.queryCompiler, + queryIDFact: f, + queryID: f.Next(), + bindings: b, + compiler: q.compiler, + store: q.store, + baseCache: newBaseCache(), + targetStack: newRefStack(), + txn: q.txn, + input: q.input, + external: q.external, + tracers: q.tracers, + traceEnabled: len(q.tracers) > 0, + plugTraceVars: q.plugTraceVars, + instr: q.instr, + builtins: q.builtins, + builtinCache: builtins.Cache{}, + functionMocks: newFunctionMocksStack(), + interQueryBuiltinCache: q.interQueryBuiltinCache, + interQueryBuiltinValueCache: q.interQueryBuiltinValueCache, + ndBuiltinCache: q.ndBuiltinCache, + virtualCache: vc, + comprehensionCache: newComprehensionCache(), + saveSet: newSaveSet(q.unknowns, b, q.instr), + saveStack: newSaveStack(), + saveSupport: newSaveSupport(), + saveNamespace: ast.StringTerm(q.partialNamespace), + skipSaveNamespace: q.skipSaveNamespace, inliningControl: &inliningControl{ shallow: q.shallowInlining, }, @@ -516,42 +524,43 @@ func (q *Query) Iter(ctx context.Context, iter func(QueryResult) error) error { } e := &eval{ - ctx: ctx, - metrics: q.metrics, - seed: q.seed, - time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), - cancel: q.cancel, - query: q.query, - queryCompiler: q.queryCompiler, - queryIDFact: f, - queryID: f.Next(), - bindings: newBindings(0, q.instr), - compiler: q.compiler, - store: q.store, - baseCache: newBaseCache(), - targetStack: newRefStack(), - txn: q.txn, - input: q.input, - external: q.external, - tracers: q.tracers, - traceEnabled: len(q.tracers) > 0, - plugTraceVars: q.plugTraceVars, - instr: q.instr, - builtins: q.builtins, - builtinCache: builtins.Cache{}, - functionMocks: newFunctionMocksStack(), - interQueryBuiltinCache: q.interQueryBuiltinCache, - ndBuiltinCache: q.ndBuiltinCache, - virtualCache: vc, - comprehensionCache: newComprehensionCache(), - genvarprefix: q.genvarprefix, - runtime: q.runtime, - indexing: q.indexing, - earlyExit: q.earlyExit, - builtinErrors: &builtinErrors{}, - printHook: q.printHook, - tracingOpts: q.tracingOpts, - strictObjects: q.strictObjects, + ctx: ctx, + metrics: q.metrics, + seed: q.seed, + time: ast.NumberTerm(int64ToJSONNumber(q.time.UnixNano())), + cancel: q.cancel, + query: q.query, + queryCompiler: q.queryCompiler, + queryIDFact: f, + queryID: f.Next(), + bindings: newBindings(0, q.instr), + compiler: q.compiler, + store: q.store, + baseCache: newBaseCache(), + targetStack: newRefStack(), + txn: q.txn, + input: q.input, + external: q.external, + tracers: q.tracers, + traceEnabled: len(q.tracers) > 0, + plugTraceVars: q.plugTraceVars, + instr: q.instr, + builtins: q.builtins, + builtinCache: builtins.Cache{}, + functionMocks: newFunctionMocksStack(), + interQueryBuiltinCache: q.interQueryBuiltinCache, + interQueryBuiltinValueCache: q.interQueryBuiltinValueCache, + ndBuiltinCache: q.ndBuiltinCache, + virtualCache: vc, + comprehensionCache: newComprehensionCache(), + genvarprefix: q.genvarprefix, + runtime: q.runtime, + indexing: q.indexing, + earlyExit: q.earlyExit, + builtinErrors: &builtinErrors{}, + printHook: q.printHook, + tracingOpts: q.tracingOpts, + strictObjects: q.strictObjects, } e.caller = e q.metrics.Timer(metrics.RegoQueryEval).Start() diff --git a/topdown/regex.go b/topdown/regex.go index 877f19e233..7ddffe8d57 100644 --- a/topdown/regex.go +++ b/topdown/regex.go @@ -20,6 +20,8 @@ const regexCacheMaxSize = 100 var regexpCacheLock = sync.Mutex{} var regexpCache map[string]*regexp.Regexp +var regexInterQueryValueCacheHits = "rego_builtin_regex_interquery_value_cache_hits" + func builtinRegexIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { s, err := builtins.StringOperand(operands[0].Value, 1) @@ -35,7 +37,7 @@ func builtinRegexIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast. return iter(ast.BooleanTerm(true)) } -func builtinRegexMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +func builtinRegexMatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { s1, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -44,7 +46,7 @@ func builtinRegexMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te if err != nil { return err } - re, err := getRegexp(string(s1)) + re, err := getRegexp(bctx, string(s1)) if err != nil { return err } @@ -81,7 +83,7 @@ func builtinRegexMatchTemplate(_ BuiltinContext, operands []*ast.Term, iter func return iter(ast.BooleanTerm(re.MatchString(string(match)))) } -func builtinRegexSplit(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +func builtinRegexSplit(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { s1, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -90,7 +92,7 @@ func builtinRegexSplit(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te if err != nil { return err } - re, err := getRegexp(string(s1)) + re, err := getRegexp(bctx, string(s1)) if err != nil { return err } @@ -103,7 +105,33 @@ func builtinRegexSplit(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te return iter(ast.NewTerm(ast.NewArray(arr...))) } -func getRegexp(pat string) (*regexp.Regexp, error) { +func getRegexp(bctx BuiltinContext, pat string) (*regexp.Regexp, error) { + if bctx.InterQueryBuiltinValueCache != nil { + val, ok := bctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(pat).Value) + if ok { + res, valid := val.(*regexp.Regexp) + if !valid { + // The cache key may exist for a different value type (eg. glob). + // In this case, we calculate the regex and return the result w/o updating the cache. + re, err := regexp.Compile(pat) + if err != nil { + return nil, err + } + return re, nil + } + + bctx.Metrics.Counter(regexInterQueryValueCacheHits).Incr() + return res, nil + } + + re, err := regexp.Compile(pat) + if err != nil { + return nil, err + } + bctx.InterQueryBuiltinValueCache.Insert(ast.StringTerm(pat).Value, re) + return re, nil + } + regexpCacheLock.Lock() defer regexpCacheLock.Unlock() re, ok := regexpCache[pat] @@ -156,7 +184,7 @@ func builtinGlobsMatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Te return iter(ast.BooleanTerm(ne)) } -func builtinRegexFind(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +func builtinRegexFind(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { s1, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -169,7 +197,7 @@ func builtinRegexFind(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter if err != nil { return err } - re, err := getRegexp(string(s1)) + re, err := getRegexp(bctx, string(s1)) if err != nil { return err } @@ -182,7 +210,7 @@ func builtinRegexFind(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Ter return iter(ast.NewTerm(ast.NewArray(arr...))) } -func builtinRegexFindAllStringSubmatch(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +func builtinRegexFindAllStringSubmatch(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { s1, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -196,7 +224,7 @@ func builtinRegexFindAllStringSubmatch(_ BuiltinContext, operands []*ast.Term, i return err } - re, err := getRegexp(string(s1)) + re, err := getRegexp(bctx, string(s1)) if err != nil { return err } @@ -214,7 +242,7 @@ func builtinRegexFindAllStringSubmatch(_ BuiltinContext, operands []*ast.Term, i return iter(ast.NewTerm(ast.NewArray(outer...))) } -func builtinRegexReplace(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { +func builtinRegexReplace(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { base, err := builtins.StringOperand(operands[0].Value, 1) if err != nil { return err @@ -230,7 +258,7 @@ func builtinRegexReplace(_ BuiltinContext, operands []*ast.Term, iter func(*ast. return err } - re, err := getRegexp(string(pattern)) + re, err := getRegexp(bctx, string(pattern)) if err != nil { return err } diff --git a/topdown/regex_test.go b/topdown/regex_test.go index 90957e9d1b..bfa429d5e7 100644 --- a/topdown/regex_test.go +++ b/topdown/regex_test.go @@ -5,10 +5,12 @@ package topdown import ( + "context" "fmt" "testing" "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/topdown/cache" ) func TestRegexBuiltinCache(t *testing.T) { @@ -65,3 +67,93 @@ func TestRegexBuiltinCache(t *testing.T) { t.Fatalf("Expected regex to be cached: %v", regex2) } } + +func TestRegexBuiltinInterQueryValueCache(t *testing.T) { + + ip := []byte(`{"inter_query_builtin_value_cache": {"max_num_entries": "10"},}`) + config, _ := cache.ParseCachingConfig(ip) + interQueryValueCache := cache.NewInterQueryValueCache(context.Background(), config) + + ctx := BuiltinContext{InterQueryBuiltinValueCache: interQueryValueCache} + iter := func(*ast.Term) error { return nil } + + // A novel regex pattern is cached. + regex1 := "foo.*" + operands := []*ast.Term{ + ast.NewTerm(ast.String(regex1)), + ast.NewTerm(ast.String("foobar")), + } + err := builtinRegexMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if _, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(regex1).Value); !ok { + t.Fatalf("Expected regex to be cached: %v", regex1) + } + + // Fill up the cache. + for i := 0; i < 9; i++ { + operands := []*ast.Term{ + ast.NewTerm(ast.String(fmt.Sprintf("foo%d.*", i))), + ast.NewTerm(ast.String(fmt.Sprintf("foo%dbar", i))), + } + err := builtinRegexMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + } + + // A new regex pattern is cached and a random pattern is evicted. + regex2 := "bar.*" + operands = []*ast.Term{ + ast.NewTerm(ast.String(regex2)), + ast.NewTerm(ast.String("barbaz")), + } + err = builtinRegexMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if _, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(regex2).Value); !ok { + t.Fatalf("Expected regex to be cached: %v", regex2) + } +} + +func TestRegexBuiltinInterQueryValueCacheTypeMismatch(t *testing.T) { + + ip := []byte(`{"inter_query_builtin_value_cache": {"max_num_entries": "10"},}`) + config, _ := cache.ParseCachingConfig(ip) + interQueryValueCache := cache.NewInterQueryValueCache(context.Background(), config) + + ctx := BuiltinContext{InterQueryBuiltinValueCache: interQueryValueCache} + iter := func(*ast.Term) error { return nil } + + key := "foo.*" + + ctx.InterQueryBuiltinValueCache.Insert(ast.StringTerm(key).Value, "bar") + + operands := []*ast.Term{ + ast.NewTerm(ast.String(key)), + ast.NewTerm(ast.String("foobar")), + } + err := builtinRegexMatch(ctx, operands, iter) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // verify the original cache entry is unchanged + value, ok := ctx.InterQueryBuiltinValueCache.Get(ast.StringTerm(key).Value) + if !ok { + t.Fatal("Expected key \"foo.*\" in cache") + } + + actual, ok := value.(string) + if !ok { + t.Fatal("Expected string value") + } + + if actual != "bar" { + t.Fatalf("Expected value \"bar\" but got %v", actual) + } +} diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index 63c0555517..7dce7d3947 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -2119,6 +2119,7 @@ func assertTopDownWithPathAndContext(ctx context.Context, t *testing.T, compiler // add an inter-query cache config, _ := iCache.ParseCachingConfig(nil) interQueryCache := iCache.NewInterQueryCache(config) + interQueryValueCache := iCache.NewInterQueryValueCache(ctx, config) var strictBuiltinErrors bool @@ -2133,6 +2134,7 @@ func assertTopDownWithPathAndContext(ctx context.Context, t *testing.T, compiler WithTransaction(txn). WithInput(inputTerm). WithInterQueryBuiltinCache(interQueryCache). + WithInterQueryBuiltinValueCache(interQueryValueCache). WithStrictBuiltinErrors(strictBuiltinErrors) var tracer BufferTracer @@ -2212,13 +2214,15 @@ func runTopDownPartialTestCase(ctx context.Context, t *testing.T, compiler *ast. // add an inter-query cache config, _ := iCache.ParseCachingConfig(nil) interQueryCache := iCache.NewInterQueryCache(config) + interQueryValueCache := iCache.NewInterQueryValueCache(ctx, config) partialQuery := NewQuery(body). WithCompiler(compiler). WithStore(store). WithUnknowns([]*ast.Term{ast.MustParseTerm("input")}). WithTransaction(txn). - WithInterQueryBuiltinCache(interQueryCache) + WithInterQueryBuiltinCache(interQueryCache). + WithInterQueryBuiltinValueCache(interQueryValueCache) partials, support, err := partialQuery.PartialRun(ctx) @@ -2251,7 +2255,8 @@ func runTopDownPartialTestCase(ctx context.Context, t *testing.T, compiler *ast. WithStore(store). WithTransaction(txn). WithInput(input). - WithInterQueryBuiltinCache(interQueryCache) + WithInterQueryBuiltinCache(interQueryCache). + WithInterQueryBuiltinValueCache(interQueryValueCache) qrs, err := query.Run(ctx) if err != nil {