From 7fa1760a0d67b57dc935198ae69cdec230db10be Mon Sep 17 00:00:00 2001 From: Michelangelo Mori Date: Tue, 17 Dec 2024 18:04:50 +0100 Subject: [PATCH 1/2] Add REGO debugger to Mindev. This change adds the possibility to start evaluate a REGO-based rule type in a debugger. The debugger allows setting breakpoints, stepping, printing source, and a few other simple utilities. The debugger is currently very, very, VERY rough around the edges and could use some love, especially in the reception of events from the debuggee, which is done inline and not asynchronously. --- cmd/dev/app/rule_type/rttst.go | 7 +- internal/engine/eval/rego/debug.go | 857 ++++++++++++++++++++++++++++ internal/engine/eval/rego/eval.go | 36 +- internal/engine/eval/rego/result.go | 17 +- internal/engine/options/options.go | 20 + internal/util/cli/styles.go | 3 +- 6 files changed, 927 insertions(+), 13 deletions(-) create mode 100644 internal/engine/eval/rego/debug.go diff --git a/cmd/dev/app/rule_type/rttst.go b/cmd/dev/app/rule_type/rttst.go index 7c263bb956..2236414313 100644 --- a/cmd/dev/app/rule_type/rttst.go +++ b/cmd/dev/app/rule_type/rttst.go @@ -67,6 +67,7 @@ func CmdTest() *cobra.Command { testCmd.Flags().StringP("token", "t", "", "token to authenticate to the provider."+ "Can also be set via the TEST_AUTH_TOKEN environment variable.") testCmd.Flags().StringArrayP("data-source", "d", []string{}, "YAML file containing the data source to test the rule with") + testCmd.Flags().BoolP("debug", "", false, "Start REGO debugger (only works for REGO-based rules types)") if err := testCmd.MarkFlagRequired("rule-type"); err != nil { fmt.Fprintf(os.Stderr, "Error marking flag as required: %s\n", err) @@ -98,6 +99,7 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { token := viper.GetString("test.auth.token") providerclass := cmd.Flag("provider") providerconfig := cmd.Flag("provider-config") + debug := cmd.Flag("debug").Value.String() == "true" dataSourceFileStrings, err := cmd.Flags().GetStringArray("data-source") if err != nil { @@ -197,7 +199,10 @@ func testCmdRun(cmd *cobra.Command, _ []string) error { // TODO: use cobra context here ctx := context.Background() - eng, err := rtengine.NewRuleTypeEngine(ctx, ruletype, prov, nil /*experiments*/, options.WithDataSources(dsRegistry)) + eng, err := rtengine.NewRuleTypeEngine(ctx, ruletype, prov, nil, /*experiments*/ + options.WithDataSources(dsRegistry), + options.WithDebugger(debug), + ) if err != nil { return fmt.Errorf("cannot create rule type engine: %w", err) } diff --git a/internal/engine/eval/rego/debug.go b/internal/engine/eval/rego/debug.go new file mode 100644 index 0000000000..dc1110431d --- /dev/null +++ b/internal/engine/eval/rego/debug.go @@ -0,0 +1,857 @@ +// SPDX-FileCopyrightText: Copyright 2024 The Minder Authors +// SPDX-License-Identifier: Apache-2.0 + +// Package rego provides the rego rule evaluator +package rego + +import ( + "bufio" + "context" + "errors" + "fmt" + "math" + "os" + "regexp" + "slices" + "strconv" + "strings" + + "github.com/open-policy-agent/opa/ast/location" + "github.com/open-policy-agent/opa/debug" + "github.com/open-policy-agent/opa/rego" + + "github.com/mindersec/minder/internal/util/cli" + "github.com/mindersec/minder/pkg/engine/v1/interfaces" +) + +func MakeEventHandler(ch chan<- *debug.Event) func(debug.Event) { + return func(event debug.Event) { + ch <- &event + } +} + +func MakeTracingEventHandler(ch chan<- *debug.Event) func(debug.Event) { + return func(event debug.Event) { + fmt.Fprintf(os.Stderr, "%+v\n", event) + ch <- &event + } +} + +func (ds *debugSession) WaitFor( + ctx context.Context, + eventTypes ...debug.EventType, +) *debug.Event { + for { + select { + case e := <-ds.ch: + if slices.Contains(eventTypes, e.Type) { + return e + } + case <-ctx.Done(): + return nil + } + } +} + +var ( + errEmptySource = errors.New("empty source code") + errInvalidInput = errors.New("invalid input") + errInvalidInstr = errors.New("invalid instruction") + errInvalidBP = errors.New("invalid breakpoint") +) + +// Debug implements an interactive debugger for REGO-based evaluators. +func (e *Evaluator) Debug( + ctx context.Context, + _ *interfaces.Result, + input *Input, + funcs ...func(*rego.Rego), +) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + allOpts := make([]func(*rego.Rego), 0, len(e.regoOpts)+len(funcs)) + allOpts = append(allOpts, e.regoOpts...) + allOpts = append(allOpts, funcs...) + + ds, err := newDebugSession( + withPrompt("(mindbg)"), + withSource(e.cfg.Def), + withInput(input), + withQuery(e.reseval.getQueryString()), + withOpts(allOpts...), + withTracingEventHandler(), + ) + if err != nil { + return fmt.Errorf("error initializing debugger: %w", err) + } + + return ds.Start(ctx) +} + +type debugSession struct { + prompt string + src string + lines int + input *Input + query string + opts []debug.LaunchOption + ch chan *debug.Event + handler debug.EventHandler + + // fields initialized after starting the session + session debug.Session +} + +type debugSessionOption func(*debugSession) error + +func withPrompt(prompt string) debugSessionOption { + return func(ds *debugSession) error { + ds.prompt = prompt + return nil + } +} + +func withSource(src string) debugSessionOption { + return func(ds *debugSession) error { + if len(src) == 0 { + return errEmptySource + } + ds.src = src + ds.lines = len(strings.Split(src, "\n")) + return nil + } +} + +func withInput(input any) debugSessionOption { + return func(ds *debugSession) error { + inner, ok := input.(*Input) + if !ok { + return fmt.Errorf("%w: wrong type %T", errInvalidInput, input) + } + ds.input = inner + return nil + } +} + +func withQuery(query string) debugSessionOption { + return func(ds *debugSession) error { + ds.query = query + return nil + } +} + +func withOpts(opts ...func(*rego.Rego)) debugSessionOption { + return func(ds *debugSession) error { + var res []debug.LaunchOption + if ds.opts == nil { + res = make([]debug.LaunchOption, 0, len(opts)) + } else { + res = ds.opts + } + + for _, opt := range opts { + res = append(res, debug.RegoOption(opt)) + } + + ds.opts = res + return nil + } +} + +func withTracingEventHandler() debugSessionOption { + return func(ds *debugSession) error { + ch := make(chan *debug.Event, 10) + ds.ch = ch + ds.handler = MakeTracingEventHandler(ch) + return nil + } +} + +func newDebugSession( + opts ...debugSessionOption, +) (*debugSession, error) { + ds := &debugSession{} + + for _, opt := range opts { + if err := opt(ds); err != nil { + return nil, err + } + } + + if ds.handler == nil { + ch := make(chan *debug.Event, 10) + ds.ch = ch + ds.handler = MakeEventHandler(ch) + } + + return ds, nil +} + +func (ds *debugSession) startDebugger( + ctx context.Context, +) error { + debugger := debug.NewDebugger( + debug.SetEventHandler(ds.handler), + ) + launchProps := debug.LaunchEvalProperties{ + LaunchProperties: debug.LaunchProperties{ + StopOnEntry: false, + StopOnFail: false, + StopOnResult: true, + EnablePrint: true, + RuleIndexing: false, + }, + Input: ds.input, + Query: ds.query, + } + + session, err := debugger.LaunchEval(ctx, launchProps, ds.opts...) + if err != nil { + return err + } + + ds.session = session + + return nil +} + +//nolint:gocyclo +func (ds *debugSession) Start(ctx context.Context) error { + err := ds.startDebugger(ctx) + if err != nil { + return fmt.Errorf("error launching debugger: %w", err) + } + + thr := debug.ThreadID(1) + fmt.Printf("%s ", ds.prompt) + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + var b strings.Builder + switch { + case line == "": + // There's nothing to do here, but it is + // useful to let the user spam enter to see if + // it's working. + case line == "r": + err = ds.startDebugger(ctx) + if err != nil { + return fmt.Errorf("error restarting debugger: %w", err) + } + fmt.Fprintf(&b, "Restarted") + case line == "c": + if err := ds.session.Resume(thr); err != nil { + return fmt.Errorf("error resuming execution: %w", err) + } + + evt := ds.WaitFor(ctx, + debug.ExceptionEventType, + debug.StoppedEventType, + debug.StdoutEventType, + debug.TerminatedEventType, + ) + switch evt.Type { + case debug.ExceptionEventType: + fmt.Fprintf(&b, "\nException\n") + if err := printLocals(&b, ds.session, evt.Thread); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + case debug.StoppedEventType: + fmt.Fprintf(&b, "\nStopped\n") + if err := printLocals(&b, ds.session, evt.Thread); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + case debug.StdoutEventType: + fmt.Fprintf(&b, "\nFinished\n") + if err := printLocals(&b, ds.session, evt.Thread); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + fmt.Fprintf(&b, "\nResult: ") + err := printVar(&b, + fmt.Sprintf("%s.*", RegoQueryPrefix), + ds.session, + evt.Thread, + ) + if err != nil { + return fmt.Errorf("error printing variable: %w", err) + } + case debug.TerminatedEventType: + fmt.Fprintf(&b, "\nTerminated\n") + } + case line == "locals": + if err := printLocals(&b, ds.session, thr); err != nil { + return fmt.Errorf("error printing locals: %w", err) + } + case line == "bp": + bps, err := ds.session.Breakpoints() + if err != nil { + return fmt.Errorf("error getting breakpoints: %w", err) + } + printBreakpoints(&b, bps) + case line == "bt": + stack, err := ds.session.StackTrace(thr) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printStackTrace(&b, stack, 10) + case line == "list", line == "l": + stack, err := ds.session.StackTrace(thr) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printSource(&b, ds.src, stack) + case line == "trs": + threads, err := ds.session.Threads() + if err != nil { + return fmt.Errorf("error getting threads: %w", err) + } + printThreads(&b, threads) + + // "clearall" command currently removes all + // breakpoints, both user-defined and internal + // ones. This is not desirable for the very same + // reasons described in the comment related to the + // "next" command. + case line == "cla", + line == "clearall": + if err := ds.session.ClearBreakpoints(); err != nil { + return fmt.Errorf("error clearing breakpoints: %w", err) + } + + // "next" is a bit quirky, since it requires a few + // steps to function, namely: + // + // * adding a so called "internal breakpoint" + // * running until it's reached, and finally + // * removing the breakpoint + // + // Internal breakpoints should be managed separately + // from user-defined breakpoints, as the user should + // neither see them nor be allowed to remove them + // since it could invalidate some assumptions the code + // does around them. + // + // TODO: add two lists of breakpoints to + // `debugSession` struct and add routines to manage + // them. + case line == "n", + line == "next": + stack, err := ds.session.StackTrace(thr) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + if loc := getCurrentLocation(stack); loc != nil { + loc.Row += 1 // let's hope it always exists... + loc.Col = 0 + + // add internal breakpoint + bp, err := ds.session.AddBreakpoint(*loc) + if err != nil { + return fmt.Errorf("error setting breakpoint: %w", err) + } + + // resume execution + if err := ds.session.Resume(thr); err != nil { + return fmt.Errorf("error resuming execution: %w", err) + } + + evt := ds.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + + // clear internal breakpoint, even if + // we stopped for another reason. + if _, err := ds.session.RemoveBreakpoint(bp.ID()); err != nil { + return fmt.Errorf("error removing breakpoing: %w", err) + } + + printSource(&b, ds.src, stack) + } + case line == "s", + line == "sv": + if err := ds.session.StepOver(thr); err != nil { + panic(err) + } + evt := ds.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printSource(&b, ds.src, stack) + case line == "si": + go func() { + if err := ds.session.StepIn(thr); err != nil { + panic(err) + } + }() + evt := ds.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printSource(&b, ds.src, stack) + case line == "so": + go func() { + if err := ds.session.StepOut(thr); err != nil { + panic(err) + } + }() + evt := ds.WaitFor(ctx, debug.StoppedEventType) + stack, err := ds.session.StackTrace(evt.Thread) + if err != nil { + return fmt.Errorf("error getting stack trace: %w", err) + } + printSource(&b, ds.src, stack) + case line == "q": + return fmt.Errorf("user abort") + case line == "h", + line == "help": + printHelp(&b) + case strings.HasPrefix(line, "p"): + varname, err := toVarName(line) + if err != nil { + fmt.Fprintln(&b, err) + continue + } + // printVar function accepts a regexp as + // variable name, allowing the caller to match + // multiple variables. + // + // We don't want to expose this functionality + // to the user, as the general case (fetching + // a specific variable) becomes awkward, + // requiring the user to specify the full + // regex. + // + // To solve this, we always wrap the received + // variable name in ^ and $. + r := fmt.Sprintf("^%s$", varname) + if err := printVar(&b, r, ds.session, thr); err != nil { + return fmt.Errorf("error printing variables: %w", err) + } + case strings.HasPrefix(line, "b"): + loc, err := toLocation(line, ds.lines) + if err != nil { + fmt.Fprintln(&b, err) + } else { + bp, err := ds.session.AddBreakpoint(*loc) + if err != nil { + return fmt.Errorf("error setting breakpoint: %w", err) + } + fmt.Fprintln(&b) + printBreakpoint(&b, bp) + } + + // "clear" command currently allows removing all + // breakpoints, both user-defined and internal + // ones. This is not desirable for the very same + // reasons described in the comment related to the + // "next" command. + case strings.HasPrefix(line, "cl "), + strings.HasPrefix(line, "clear "): + ids := make([]debug.BreakpointID, 0) + bps, err := ds.session.Breakpoints() + if err != nil { + return fmt.Errorf("error gettin breakpoints: %w", err) + } + for _, bp := range bps { + ids = append(ids, bp.ID()) + } + id, err := toBreakpointID(line, ids) + if err != nil { + fmt.Fprintln(&b, err) + } else { + if _, err := ds.session.RemoveBreakpoint(id); err != nil { + return fmt.Errorf("error removing breakpoint: %w", err) + } + } + default: + fmt.Fprintf(&b, "Invalid command: %s\nPress h for help\n", line) + } + + output := b.String() + if output != "" { + fmt.Printf("%s\n%s ", output, ds.prompt) + } else { + fmt.Printf("%s ", ds.prompt) + } + } + + return scanner.Err() +} + +func toLocation(line string, lineCount int) (*location.Location, error) { + num, ok := strings.CutPrefix(line, "b ") + if !ok { + return nil, fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + i, err := strconv.ParseInt(num, 10, 64) + if err != nil { + return nil, fmt.Errorf(`%w: invalid line "%s": %s`, errInvalidBP, num, err) + } + if i < 1 || int(i) > lineCount { + return nil, fmt.Errorf("%w: invalid line %d", errInvalidBP, i) + } + return &location.Location{File: "minder.rego", Row: int(i)}, nil +} + +func toBreakpointID(line string, ids []debug.BreakpointID) (debug.BreakpointID, error) { + num1, ok1 := strings.CutPrefix(line, "cl ") + num2, ok2 := strings.CutPrefix(line, "clear ") + if !ok1 && !ok2 { + return debug.BreakpointID(-1), fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + + var num string + if !ok1 { + num = num2 + } + if !ok2 { + num = num1 + } + + i, err := strconv.ParseInt(num, 10, 64) + if err != nil { + return debug.BreakpointID(-1), fmt.Errorf( + `%w: invalid breakpoint id %s`, + errInvalidBP, num, + ) + } + + if i < 1 { + return debug.BreakpointID(-1), fmt.Errorf( + "%w: negative line id", + errInvalidBP, + ) + } + + if !slices.Contains(ids, debug.BreakpointID(i)) { + return debug.BreakpointID(-1), fmt.Errorf( + "%w: breakpoint does not exist", + errInvalidBP, + ) + } + + return debug.BreakpointID(i), nil +} + +func toVarName(line string) (string, error) { + varname, ok := strings.CutPrefix(line, "p ") + if !ok { + return "", fmt.Errorf(`%w: "%s"`, errInvalidInstr, line) + } + return varname, nil +} + +func printBreakpoints(b *strings.Builder, bps []debug.Breakpoint) { + if len(bps) == 0 { + return + } + fmt.Fprintln(b) + for _, bp := range bps { + printBreakpoint(b, bp) + } +} + +func printBreakpoint(b *strings.Builder, bp debug.Breakpoint) { + fmt.Fprintf(b, "Breakpoint %d set at %s:%d\n", + bp.ID(), + bp.Location().File, + bp.Location().Row, + ) +} + +func printThreads(b *strings.Builder, threads []debug.Thread) { + if len(threads) == 0 { + return + } + fmt.Fprintln(b) + for _, thread := range threads { + fmt.Fprintf(b, "Thread %d\n", thread.ID()) + } +} + +func getCurrentLocation(stack debug.StackTrace) *location.Location { + if len(stack) == 0 { + return nil + } + + frame := stack[0] + return frame.Location() +} + +func printStackTrace(b *strings.Builder, stack debug.StackTrace, limit int) { + if len(stack) == 0 { + return + } + + fmt.Fprintln(b) + for _, frame := range stack[:limit] { + if loc := frame.Location(); loc != nil { + fmt.Fprintf(b, "Frame %d at %s:%d.%d\n", + frame.ID(), + loc.File, + loc.Row, + loc.Col, + ) + } + } + if len(stack) > limit { + fmt.Fprintf(b, "...\n") + } +} + +func printSource(b *strings.Builder, src string, stack debug.StackTrace) { + if len(stack) == 0 { + printSourceSimple(b, src) + return + } + + lines := strings.Split(src, "\n") + padding := int64(math.Floor(math.Log10(float64(len(lines)))) + 1) + + fmt.Fprintln(b) + frame := stack[0] + if loc := frame.Location(); loc != nil { + fmt.Fprintf(b, "Frame %d at %s:%d.%d\n", + frame.ID(), + loc.File, + loc.Row, + loc.Col, + ) + + for idx, line := range strings.Split(src, "\n") { + fmt.Fprintf(b, "%*d: %s", padding, idx+1, line) + if idx+1 == loc.Row { + theline := strings.Split(string(loc.Text), "\n")[0] + fmt.Fprintf(b, "\n%s%s", + strings.Repeat(" ", loc.Col+int(padding)+2-1), + cli.SimpleBoldStyle.Render(strings.Repeat("^", len(theline))), + ) + } + fmt.Fprintln(b) + } + } +} + +func printSourceSimple(b *strings.Builder, source string) { + fmt.Fprintln(b) + lines := strings.Split(source, "\n") + padding := int64(math.Floor(math.Log10(float64(len(lines)))) + 1) + for idx, line := range lines { + fmt.Fprintf(b, "%*d: %s\n", padding, idx+1, line) + } +} + +func printLocals(b *strings.Builder, s debug.Session, thrID debug.ThreadID) error { + trace, err := s.StackTrace(thrID) + if err != nil { + return fmt.Errorf("error getting stacktrace: %w", err) + } + + if len(trace) == 0 { + return nil + } + + // The first trace in the list is the one related to the + // current stack frame. + scopes, err := s.Scopes(trace[0].ID()) + if err != nil { + return fmt.Errorf("error getting scopes: %w", err) + } + + for _, scope := range scopes { + vars, err := s.Variables(scope.VariablesReference()) + if err != nil { + return fmt.Errorf("error getting variables: %w", err) + } + for _, v := range vars { + fmt.Fprintf(b, "%s %s = %s\n", v.Type(), v.Name(), v.Value()) + } + } + + return nil +} + +func printVar( + b *strings.Builder, + varname string, + s debug.Session, + thrID debug.ThreadID, +) error { + r, err := regexp.Compile(varname) + if err != nil { + return fmt.Errorf("error instantiating regex: %w", err) + } + + trace, err := s.StackTrace(thrID) + if err != nil { + return fmt.Errorf("error getting stacktrace: %w", err) + } + + if len(trace) == 0 { + return nil + } + + // The first trace in the list is the one related to the + // current stack frame. + scopes, err := s.Scopes(trace[0].ID()) + if err != nil { + return fmt.Errorf("error getting scopes: %w", err) + } + + for _, scope := range scopes { + if err := printVariablesInScope(b, r, s, scope.VariablesReference()); err != nil { + return err + } + } + + return nil +} + +func printVariablesInScope( + b *strings.Builder, + r *regexp.Regexp, + s debug.Session, + varRef debug.VarRef, +) error { + if varRef == 0 { + return nil + } + + vars, err := s.Variables(varRef) + if err != nil { + return fmt.Errorf("error getting variables: %w", err) + } + for _, v := range vars { + if r.MatchString(v.Name()) { + var b1 strings.Builder + if err := varToString(&b1, v, s, 0); err != nil { + return err + } + fmt.Fprintf(b, "%s %s = %s\n", v.Type(), v.Name(), b1.String()) + + // We break early here despite the fact that + // multiple variables might match the given + // `varname`. This is done to honour lexical + // scope, showing just the only variable that + // is actually being used for evaluation in + // the given frame. + return nil + } + } + + return nil +} + +func varToString( + b *strings.Builder, + v debug.Variable, + s debug.Session, + indentation int, +) error { + padding := strings.Repeat(" ", indentation) + switch v.Type() { + case "array": + return elementsToString(b, v, s, indentation, "[", "]", + func(elem debug.Variable) error { + fmt.Fprintf(b, " %s", padding) + err := varToString(b, elem, s, indentation) + if err != nil { + return err + } + fmt.Fprintf(b, ",\n") + return nil + }, + ) + case "set": + return elementsToString(b, v, s, indentation, "{", "}", + func(elem debug.Variable) error { + fmt.Fprintf(b, " %s", padding) + err := varToString(b, elem, s, indentation) + if err != nil { + return err + } + fmt.Fprintf(b, ",\n") + return nil + }, + ) + case "object": + return elementsToString(b, v, s, indentation, "{", "}", + func(elem debug.Variable) error { + fmt.Fprintf(b, " %s%s: ", padding, elem.Name()) + err := varToString(b, elem, s, indentation) + if err != nil { + return err + } + fmt.Fprintf(b, ",\n") + return nil + }, + ) + default: + fmt.Fprintf(b, "%s%s", padding, v.Value()) + } + + return nil +} + +func elementsToString( + b *strings.Builder, + v debug.Variable, + s debug.Session, + indentation int, + leftDelimiter string, + rightDelimiter string, + formatter func(debug.Variable) error, +) error { + padding := strings.Repeat(" ", indentation) + fmt.Fprintf(b, "%s%s\n", padding, leftDelimiter) + elems, err := s.Variables(v.VariablesReference()) + if err != nil { + return err + } + for _, elem := range elems { + if err := formatter(elem); err != nil { + return err + } + } + fmt.Fprintf(b, "%s%s", padding, rightDelimiter) + + return nil +} + +var helpMsg = ` +Controlling execution: + c ------------- continue + r ------------- restart debugging session + q ------------- quit + +Printing: + bt ------------ print stack trace (top 10) + trs ----------- print threads + list/l -------- list source + locals -------- print local variables + +Breakpoints: + bp ------------ show breakpoints + b ------- set breakpoint at line + clear/cl - clear breakpoint with id + clearall/cla -- clear all breakpoints + +Stepping: + n ------------- next line + s/sv ---------- step over + so ------------ step out + si ------------ step into + +Help: + help/h -------- print help +` + +func printHelp(b *strings.Builder) { + fmt.Fprint(b, helpMsg) +} diff --git a/internal/engine/eval/rego/eval.go b/internal/engine/eval/rego/eval.go index fd2a597360..726d1494c2 100644 --- a/internal/engine/eval/rego/eval.go +++ b/internal/engine/eval/rego/eval.go @@ -44,6 +44,15 @@ type Evaluator struct { regoOpts []func(*rego.Rego) reseval resultEvaluator datasources *v1datasources.DataSourceRegistry + debug bool +} + +var _ eoptions.HasDebuggerSupport = (*Evaluator)(nil) + +// SetDebugFlag implements `HasDebuggerSupport` interface. +func (e *Evaluator) SetDebugFlag(flag bool) error { + e.debug = flag + return nil } // Input is the input for the rego evaluator @@ -132,6 +141,26 @@ func (e *Evaluator) Eval( // If the evaluator has data sources defined, expose their functions regoFuncOptions = append(regoFuncOptions, buildDataSourceOptions(res, e.datasources)...) + input := &Input{ + Profile: pol, + Ingested: obj, + OutputFormat: e.cfg.ViolationFormat, + } + enrichInputWithEntityProps(input, entity) + + if e.debug { + err := e.Debug( + ctx, + res, + input, + regoFuncOptions..., + ) + if err != nil { + return nil, err + } + return nil, nil + } + // Create the rego object r := e.newRegoFromOptions( regoFuncOptions..., @@ -142,13 +171,6 @@ func (e *Evaluator) Eval( return nil, fmt.Errorf("could not prepare Rego: %w", err) } - input := &Input{ - Profile: pol, - Ingested: obj, - OutputFormat: e.cfg.ViolationFormat, - } - - enrichInputWithEntityProps(input, entity) rs, err := pq.Eval(ctx, rego.EvalInput(input)) if err != nil { return nil, fmt.Errorf("error evaluating profile. Might be wrong input: %w", err) diff --git a/internal/engine/eval/rego/result.go b/internal/engine/eval/rego/result.go index d713aa2557..300b903b3e 100644 --- a/internal/engine/eval/rego/result.go +++ b/internal/engine/eval/rego/result.go @@ -53,6 +53,7 @@ func (c ConstraintsViolationsFormat) String() string { } type resultEvaluator interface { + getQueryString() string getQuery() func(*rego.Rego) parseResult(rego.ResultSet, protoreflect.ProtoMessage) (*interfaces.EvaluationResult, error) } @@ -60,8 +61,12 @@ type resultEvaluator interface { type denyByDefaultEvaluator struct { } -func (*denyByDefaultEvaluator) getQuery() func(r *rego.Rego) { - return rego.Query(RegoQueryPrefix) +func (*denyByDefaultEvaluator) getQueryString() string { + return RegoQueryPrefix +} + +func (d *denyByDefaultEvaluator) getQuery() func(r *rego.Rego) { + return rego.Query(d.getQueryString()) } func (*denyByDefaultEvaluator) parseResult(rs rego.ResultSet, entity protoreflect.ProtoMessage, @@ -168,8 +173,12 @@ type constraintsEvaluator struct { format ConstraintsViolationsFormat } -func (*constraintsEvaluator) getQuery() func(r *rego.Rego) { - return rego.Query(fmt.Sprintf("%s.violations[details]", RegoQueryPrefix)) +func (*constraintsEvaluator) getQueryString() string { + return fmt.Sprintf("%s.violations[details]", RegoQueryPrefix) +} + +func (c *constraintsEvaluator) getQuery() func(r *rego.Rego) { + return rego.Query(c.getQueryString()) } func (c *constraintsEvaluator) parseResult(rs rego.ResultSet, _ protoreflect.ProtoMessage) (*interfaces.EvaluationResult, error) { diff --git a/internal/engine/options/options.go b/internal/engine/options/options.go index 0da6223418..b71b1eb3ab 100644 --- a/internal/engine/options/options.go +++ b/internal/engine/options/options.go @@ -35,6 +35,26 @@ func WithFlagsClient(client openfeature.IClient) Option { } } +// HasDebuggerSupport interface should be implemented by evaluation +// engines that support interactive debugger. Currently, only +// REGO-based engines should implement this. +type HasDebuggerSupport interface { + SetDebugFlag(bool) error +} + +// WithDebugger sets the evaluation engine to start an interactive +// debugging session. This MUST NOT be used in backend servers, and is +// only meant to be used in CLI tools. +func WithDebugger(flag bool) Option { + return func(e interfaces.Evaluator) error { + inner, ok := e.(HasDebuggerSupport) + if !ok { + return nil + } + return inner.SetDebugFlag(flag) + } +} + // SupportsDataSources interface advertises the fact that the implementer // can register data sources with the evaluator. type SupportsDataSources interface { diff --git a/internal/util/cli/styles.go b/internal/util/cli/styles.go index 4973c54b85..453db8c72b 100644 --- a/internal/util/cli/styles.go +++ b/internal/util/cli/styles.go @@ -27,7 +27,8 @@ var ( // Common styles var ( - CursorStyle = lipgloss.NewStyle().Foreground(SecondaryColor) + CursorStyle = lipgloss.NewStyle().Foreground(SecondaryColor) + SimpleBoldStyle = lipgloss.NewStyle().Bold(true) ) // Banner styles From 1d0fe2c181e30f09335a21e893ab6bc026c0b9fc Mon Sep 17 00:00:00 2001 From: Michelangelo Mori Date: Mon, 23 Dec 2024 11:01:03 +0100 Subject: [PATCH 2/2] Accept `any` as input and add comments. --- internal/engine/eval/rego/debug.go | 170 +++++++++++++++++++---------- 1 file changed, 113 insertions(+), 57 deletions(-) diff --git a/internal/engine/eval/rego/debug.go b/internal/engine/eval/rego/debug.go index dc1110431d..5d68ff178d 100644 --- a/internal/engine/eval/rego/debug.go +++ b/internal/engine/eval/rego/debug.go @@ -24,40 +24,25 @@ import ( "github.com/mindersec/minder/pkg/engine/v1/interfaces" ) -func MakeEventHandler(ch chan<- *debug.Event) func(debug.Event) { +func makeEventHandler(ch chan<- *debug.Event) func(debug.Event) { return func(event debug.Event) { ch <- &event } } -func MakeTracingEventHandler(ch chan<- *debug.Event) func(debug.Event) { +//nolint:unused +func makeTracingEventHandler(ch chan<- *debug.Event) func(debug.Event) { return func(event debug.Event) { fmt.Fprintf(os.Stderr, "%+v\n", event) ch <- &event } } -func (ds *debugSession) WaitFor( - ctx context.Context, - eventTypes ...debug.EventType, -) *debug.Event { - for { - select { - case e := <-ds.ch: - if slices.Contains(eventTypes, e.Type) { - return e - } - case <-ctx.Done(): - return nil - } - } -} - var ( errEmptySource = errors.New("empty source code") - errInvalidInput = errors.New("invalid input") errInvalidInstr = errors.New("invalid instruction") errInvalidBP = errors.New("invalid breakpoint") + errUserAbort = errors.New("user abort") ) // Debug implements an interactive debugger for REGO-based evaluators. @@ -67,9 +52,6 @@ func (e *Evaluator) Debug( input *Input, funcs ...func(*rego.Rego), ) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - allOpts := make([]func(*rego.Rego), 0, len(e.regoOpts)+len(funcs)) allOpts = append(allOpts, e.regoOpts...) allOpts = append(allOpts, funcs...) @@ -80,7 +62,7 @@ func (e *Evaluator) Debug( withInput(input), withQuery(e.reseval.getQueryString()), withOpts(allOpts...), - withTracingEventHandler(), + // withTracingEventHandler(), ) if err != nil { return fmt.Errorf("error initializing debugger: %w", err) @@ -93,7 +75,7 @@ type debugSession struct { prompt string src string lines int - input *Input + input any query string opts []debug.LaunchOption ch chan *debug.Event @@ -125,11 +107,7 @@ func withSource(src string) debugSessionOption { func withInput(input any) debugSessionOption { return func(ds *debugSession) error { - inner, ok := input.(*Input) - if !ok { - return fmt.Errorf("%w: wrong type %T", errInvalidInput, input) - } - ds.input = inner + ds.input = input return nil } } @@ -159,11 +137,20 @@ func withOpts(opts ...func(*rego.Rego)) debugSessionOption { } } +//nolint:unused func withTracingEventHandler() debugSessionOption { return func(ds *debugSession) error { + // NOTE: this channel must be buffered, because REGO + // interpreter emits several events that we're + // currently handling in the same thread of execition + // of the CLI interface. + // + // The solution would be handling CLI events and + // debuggee events asynchronously, but we're not there + // yet. ch := make(chan *debug.Event, 10) ds.ch = ch - ds.handler = MakeTracingEventHandler(ch) + ds.handler = makeTracingEventHandler(ch) return nil } } @@ -180,20 +167,62 @@ func newDebugSession( } if ds.handler == nil { + // NOTE: this channel must be buffered, because REGO + // interpreter emits several events that we're + // currently handling in the same thread of execition + // of the CLI interface. + // + // The solution would be handling CLI events and + // debuggee events asynchronously, but we're not there + // yet. ch := make(chan *debug.Event, 10) ds.ch = ch - ds.handler = MakeEventHandler(ch) + ds.handler = makeEventHandler(ch) } return ds, nil } +func (ds *debugSession) waitFor( + ctx context.Context, + eventTypes ...debug.EventType, +) *debug.Event { + for { + select { + case e := <-ds.ch: + if slices.Contains(eventTypes, e.Type) { + return e + } + case <-ctx.Done(): + return nil + } + } +} + func (ds *debugSession) startDebugger( ctx context.Context, ) error { debugger := debug.NewDebugger( debug.SetEventHandler(ds.handler), ) + // This combination of flags provides roughly the same user + // experience as one would have while debugging imperative + // languages using a standard debugger like lldb or gdb. + // + // Specifically, `StopOnEntry` stops when entering an + // expression, which is like stepping through some, but not + // all, lines and even inside the same line multiple times in + // the case of list/set comprehensions, while `StopOnFail` + // results in stopping at all expressions producing a `false` + // value, which is similar to the previous case in that it + // stops every time a check fails during a list/set + // comprehension. + // + // The previous descriptions must be taken with a grain of + // salt and are likely missing useful cases. That said, the + // described cases are hardly seen when debugging imperative + // languages, which is the user experience we want to provide + // at the moment. Of course, this might change in the future. launchProps := debug.LaunchEvalProperties{ LaunchProperties: debug.LaunchProperties{ StopOnEntry: false, @@ -242,11 +271,11 @@ func (ds *debugSession) Start(ctx context.Context) error { } fmt.Fprintf(&b, "Restarted") case line == "c": - if err := ds.session.Resume(thr); err != nil { + if err := ds.session.ResumeAll(); err != nil { return fmt.Errorf("error resuming execution: %w", err) } - evt := ds.WaitFor(ctx, + evt := ds.waitFor(ctx, debug.ExceptionEventType, debug.StoppedEventType, debug.StdoutEventType, @@ -280,6 +309,9 @@ func (ds *debugSession) Start(ctx context.Context) error { case debug.TerminatedEventType: fmt.Fprintf(&b, "\nTerminated\n") } + case line == "q": + return errUserAbort + case line == "locals": if err := printLocals(&b, ds.session, thr); err != nil { return fmt.Errorf("error printing locals: %w", err) @@ -343,11 +375,23 @@ func (ds *debugSession) Start(ctx context.Context) error { return fmt.Errorf("error getting stack trace: %w", err) } if loc := getCurrentLocation(stack); loc != nil { - loc.Row += 1 // let's hope it always exists... - loc.Col = 0 + // Unfortunately, getting the column + // right is tricky, since source-level + // breakpoints only look at line + // numbers in the REGO interpreter, so + // the safest assumption is starting + // from 0. + // + // It would be great if the frame + // struct contained details about the + // position in the source. + nextloc := location.Location{ + Row: loc.Row + 1, // let's hope it always exists... + Col: 0, + } // add internal breakpoint - bp, err := ds.session.AddBreakpoint(*loc) + bp, err := ds.session.AddBreakpoint(nextloc) if err != nil { return fmt.Errorf("error setting breakpoint: %w", err) } @@ -357,7 +401,7 @@ func (ds *debugSession) Start(ctx context.Context) error { return fmt.Errorf("error resuming execution: %w", err) } - evt := ds.WaitFor(ctx, debug.StoppedEventType) + evt := ds.waitFor(ctx, debug.StoppedEventType) stack, err := ds.session.StackTrace(evt.Thread) if err != nil { return fmt.Errorf("error getting stack trace: %w", err) @@ -374,43 +418,39 @@ func (ds *debugSession) Start(ctx context.Context) error { case line == "s", line == "sv": if err := ds.session.StepOver(thr); err != nil { - panic(err) + return fmt.Errorf("error on step-over: %w", err) } - evt := ds.WaitFor(ctx, debug.StoppedEventType) + evt := ds.waitFor(ctx, debug.StoppedEventType) stack, err := ds.session.StackTrace(evt.Thread) if err != nil { return fmt.Errorf("error getting stack trace: %w", err) } printSource(&b, ds.src, stack) case line == "si": - go func() { - if err := ds.session.StepIn(thr); err != nil { - panic(err) - } - }() - evt := ds.WaitFor(ctx, debug.StoppedEventType) + if err := ds.session.StepIn(thr); err != nil { + return fmt.Errorf("error on step-in: %w", err) + } + evt := ds.waitFor(ctx, debug.StoppedEventType) stack, err := ds.session.StackTrace(evt.Thread) if err != nil { return fmt.Errorf("error getting stack trace: %w", err) } printSource(&b, ds.src, stack) case line == "so": - go func() { - if err := ds.session.StepOut(thr); err != nil { - panic(err) - } - }() - evt := ds.WaitFor(ctx, debug.StoppedEventType) + if err := ds.session.StepOut(thr); err != nil { + return fmt.Errorf("error on step-out: %w", err) + } + evt := ds.waitFor(ctx, debug.StoppedEventType) stack, err := ds.session.StackTrace(evt.Thread) if err != nil { return fmt.Errorf("error getting stack trace: %w", err) } printSource(&b, ds.src, stack) - case line == "q": - return fmt.Errorf("user abort") + case line == "h", line == "help": printHelp(&b) + case strings.HasPrefix(line, "p"): varname, err := toVarName(line) if err != nil { @@ -433,6 +473,7 @@ func (ds *debugSession) Start(ctx context.Context) error { if err := printVar(&b, r, ds.session, thr); err != nil { return fmt.Errorf("error printing variables: %w", err) } + case strings.HasPrefix(line, "b"): loc, err := toLocation(line, ds.lines) if err != nil { @@ -627,10 +668,24 @@ func printSource(b *strings.Builder, src string, stack debug.StackTrace) { for idx, line := range strings.Split(src, "\n") { fmt.Fprintf(b, "%*d: %s", padding, idx+1, line) if idx+1 == loc.Row { - theline := strings.Split(string(loc.Text), "\n")[0] + // `theline` is the very first line of + // the expression starting at the + // given position. + // + // In REGO expressions can span + // multiple lines (for example, rules + // do), but we really are interested + // in underlining only the first line + // of the given expression. + // + // For weird underlyining starting + // from column 0 of the line, see + // comment on setting source-level + // breakpoints. + theline := strings.Split(line, "\n")[0] fmt.Fprintf(b, "\n%s%s", - strings.Repeat(" ", loc.Col+int(padding)+2-1), - cli.SimpleBoldStyle.Render(strings.Repeat("^", len(theline))), + strings.Repeat(" ", int(padding)+2+loc.Col-1), + cli.SimpleBoldStyle.Render(strings.Repeat("^", len(theline)-loc.Col+1)), ) } fmt.Fprintln(b) @@ -654,6 +709,7 @@ func printLocals(b *strings.Builder, s debug.Session, thrID debug.ThreadID) erro } if len(trace) == 0 { + fmt.Fprintln(b, "No locals") return nil } @@ -843,7 +899,7 @@ Breakpoints: clearall/cla -- clear all breakpoints Stepping: - n ------------- next line + n/next--------- next line s/sv ---------- step over so ------------ step out si ------------ step into