diff --git a/cmd/test.go b/cmd/test.go new file mode 100644 index 0000000000..bcb56b92bf --- /dev/null +++ b/cmd/test.go @@ -0,0 +1,120 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package cmd + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/tester" + "github.com/spf13/cobra" +) + +var testParams = struct { + verbose bool + errLimit int + timeout time.Duration +}{} + +var testCommand = &cobra.Command{ + Use: "test", + Short: "Execute Rego test cases", + Long: `Execute Rego test cases. + +The 'test' command takes a file or directory path as input and executes all +test cases discovered in matching files. Test cases are rules whose names have the prefix "test_". + +Example policy (example/authz.rego): + + package authz + + allow { + input.path = ["users"] + input.method = "POST" + } + + allow { + input.path = ["users", profile_id] + input.method = "GET" + profile_id = input.user_id + } + +Example test (example/authz_test.rego): + + package authz + + test_post_allowed { + allow with input as {"path": ["users"], "method": "POST"} + } + + test_get_denied { + not allow with input as {"path": ["users"], "method": "GET"} + } + + test_get_user_allowed { + allow with input as {"path": ["users", "bob"], "method": "GET", "user_id": "bob"} + } + + test_get_another_user_denied { + not allow with input as {"path": ["users", "bob"], "method": "GET", "user_id": "alice"} + } + +Example test run: + + $ opa test ./example/ +`, + Run: func(cmd *cobra.Command, args []string) { + os.Exit(opaTest(args)) + }, +} + +func opaTest(args []string) int { + + compiler := ast.NewCompiler().SetErrorLimit(testParams.errLimit) + + ctx, cancel := context.WithTimeout(context.Background(), testParams.timeout) + defer cancel() + + ch, err := tester.NewRunner().SetCompiler(compiler).Paths(ctx, args...) + if err != nil { + fmt.Fprintln(os.Stderr, err) + return 1 + } + + reporter := tester.PrettyReporter{ + Verbose: testParams.verbose, + Output: os.Stdout, + } + + exitCode := 0 + dup := make(chan *tester.Result) + + go func() { + defer close(dup) + for tr := range ch { + if !tr.Pass() { + exitCode = 1 + } + dup <- tr + } + }() + + if err := reporter.Report(dup); err != nil { + fmt.Fprintln(os.Stderr, err) + return 1 + } + + return exitCode +} + +func init() { + testCommand.Flags().BoolVarP(&testParams.verbose, "verbose", "v", false, "set verbose reporting mode") + testCommand.Flags().DurationVarP(&testParams.timeout, "timeout", "t", time.Second*5, "set test timeout") + setMaxErrors(testCommand.Flags(), &testParams.errLimit) + RootCommand.AddCommand(testCommand) +} diff --git a/tester/reporter.go b/tester/reporter.go new file mode 100644 index 0000000000..a2f70aa7ac --- /dev/null +++ b/tester/reporter.go @@ -0,0 +1,63 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package tester + +import ( + "fmt" + "io" + "strings" +) + +// PrettyReporter reports test results in a simple human readable format. +type PrettyReporter struct { + Output io.Writer + Verbose bool +} + +// Report prints the test report to the reporter's output. +func (r PrettyReporter) Report(ch chan *Result) error { + + dirty := false + var pass, fail, errs int + + // Report individual tests. + for tr := range ch { + if tr.Pass() { + pass++ + } else if tr.Error != nil { + errs++ + } else if tr.Fail != nil { + fail++ + } + if !tr.Pass() || r.Verbose { + fmt.Fprintln(r.Output, tr) + dirty = true + } + if tr.Error != nil { + fmt.Fprintf(r.Output, " %v\n", tr.Error) + } + } + + // Report summary of test. + if dirty { + fmt.Fprintln(r.Output, strings.Repeat("-", 80)) + } + + total := pass + fail + errs + + if pass != 0 { + fmt.Fprintln(r.Output, "PASS:", fmt.Sprintf("%d/%d", pass, total)) + } + + if fail != 0 { + fmt.Fprintln(r.Output, "FAIL:", fmt.Sprintf("%d/%d", fail, total)) + } + + if errs != 0 { + fmt.Fprintln(r.Output, "ERROR:", fmt.Sprintf("%d/%d", errs, total)) + } + + return nil +} diff --git a/tester/reporter_test.go b/tester/reporter_test.go new file mode 100644 index 0000000000..596803278a --- /dev/null +++ b/tester/reporter_test.go @@ -0,0 +1,51 @@ +package tester + +import ( + "bytes" + "fmt" + "testing" +) + +func TestPrettyReporter(t *testing.T) { + + var badResult interface{} = "fail" + + ts := []*Result{ + {nil, "data.foo.bar", "test_baz", nil, nil, 0}, + {nil, "data.foo.bar", "test_qux", nil, fmt.Errorf("some err"), 0}, + {nil, "data.foo.bar", "test_corge", &badResult, nil, 0}, + } + + var buf bytes.Buffer + + r := PrettyReporter{ + Output: &buf, + Verbose: true, + } + + ch := make(chan *Result) + go func() { + for _, tr := range ts { + ch <- tr + } + close(ch) + }() + + if err := r.Report(ch); err != nil { + t.Fatal(err) + } + + exp := `data.foo.bar.test_baz: PASS (0s) +data.foo.bar.test_qux: ERROR (0s) + some err +data.foo.bar.test_corge: FAIL (0s) +-------------------------------------------------------------------------------- +PASS: 1/3 +FAIL: 1/3 +ERROR: 1/3 +` + + if exp != buf.String() { + t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, buf.String()) + } +} diff --git a/tester/runner.go b/tester/runner.go new file mode 100644 index 0000000000..e684a1ac5a --- /dev/null +++ b/tester/runner.go @@ -0,0 +1,178 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +// Package tester contains utilities for executing Rego tests. +package tester + +import ( + "context" + "fmt" + "sort" + "strings" + "time" + + "github.com/open-policy-agent/opa/topdown" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/loader" + "github.com/open-policy-agent/opa/rego" +) + +// TestPrefix declares the prefix for all rules. +const TestPrefix = "test_" + +// Run executes all test cases found under files in path. +func Run(ctx context.Context, path ...string) ([]*Result, error) { + ch, err := NewRunner().Paths(ctx, path...) + if err != nil { + return nil, err + } + result := []*Result{} + for r := range ch { + result = append(result, r) + } + return result, nil +} + +// Result represents a single test case result. +type Result struct { + Location *ast.Location `json:"location"` + Package string `json:"package"` + Name string `json:"name"` + Fail *interface{} `json:"fail,omitempty"` + Error error `json:"error,omitempty"` + Duration time.Duration `json:"duration"` +} + +func newResult(loc *ast.Location, pkg, name string, duration time.Duration) *Result { + return &Result{ + Location: loc, + Package: pkg, + Name: name, + Duration: duration, + } +} + +// Pass returns true if the test case passed. +func (r Result) Pass() bool { + return r.Fail == nil && r.Error == nil +} + +func (r *Result) String() string { + return fmt.Sprintf("%v.%v: %v (%v)", r.Package, r.Name, r.outcome(), r.Duration/time.Microsecond) +} + +func (r *Result) outcome() string { + if r.Pass() { + return "PASS" + } + if r.Fail != nil { + return "FAIL" + } + return "ERROR" +} + +func (r *Result) setFail(fail interface{}) { + r.Fail = &fail +} + +// Runner implements simple test discovery and execution. +type Runner struct { + compiler *ast.Compiler +} + +// NewRunner returns a new runner. +func NewRunner() *Runner { + return &Runner{} +} + +// SetCompiler sets the compiler used by the runner. +func (r *Runner) SetCompiler(compiler *ast.Compiler) *Runner { + r.compiler = compiler + return r +} + +// Paths executes all tests contained in policies under the specified paths. +func (r *Runner) Paths(ctx context.Context, path ...string) (ch chan *Result, err error) { + + if r.compiler == nil { + r.compiler = ast.NewCompiler() + } + + result, err := loader.AllRegos(path) + if err != nil { + return nil, err + } + + modules := map[string]*ast.Module{} + for _, m := range result.Modules { + modules[m.Name] = m.Parsed + } + + return r.Modules(ctx, modules) +} + +// Modules executes all tests contained in the specified modules. +func (r *Runner) Modules(ctx context.Context, modules map[string]*ast.Module) (ch chan *Result, err error) { + + filenames := make([]string, 0, len(modules)) + for name := range modules { + filenames = append(filenames, name) + } + + sort.Strings(filenames) + + if r.compiler.Compile(modules); r.compiler.Failed() { + return nil, r.compiler.Errors + } + + ch = make(chan *Result) + + go func() { + defer close(ch) + for _, name := range filenames { + module := r.compiler.Modules[name] + for _, rule := range module.Rules { + if !strings.HasPrefix(string(rule.Head.Name), TestPrefix) { + continue + } + tr, stop := r.runTest(ctx, module, rule) + ch <- tr + if stop { + return + } + } + } + }() + + return ch, nil +} + +func (r *Runner) runTest(ctx context.Context, mod *ast.Module, rule *ast.Rule) (*Result, bool) { + + rego := rego.New( + rego.Compiler(r.compiler), + rego.Query(rule.Path().String()), + ) + + t0 := time.Now() + rs, err := rego.Eval(ctx) + dt := time.Since(t0) + + tr := newResult(rule.Loc(), mod.Package.Path.String(), string(rule.Head.Name), dt) + var stop bool + + if err != nil { + tr.Error = err + if err, ok := err.(*topdown.Error); ok && err.Code == topdown.CancelErr { + stop = true + } + } else if len(rs) == 0 { + tr.setFail(false) + } else if b, ok := rs[0].Expressions[0].Value.(bool); !ok || !b { + tr.setFail(rs[0].Expressions[0].Value) + } + + return tr, stop +} diff --git a/tester/runner_test.go b/tester/runner_test.go new file mode 100644 index 0000000000..3be3d62c54 --- /dev/null +++ b/tester/runner_test.go @@ -0,0 +1,110 @@ +// Copyright 2017 The OPA Authors. All rights reserved. +// Use of this source code is governed by an Apache2 +// license that can be found in the LICENSE file. + +package tester + +import ( + "context" + "testing" + "time" + + "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/topdown" + "github.com/open-policy-agent/opa/types" + "github.com/open-policy-agent/opa/util/test" +) + +func TestRun(t *testing.T) { + + ctx := context.Background() + + files := map[string]string{ + "/a.rego": `package foo + allow { true } + `, + "/a_test.rego": `package foo + test_pass { allow } + non_test { true } + test_fail { not allow } + test_fail_non_bool = 100 + test_err { conflict } + conflict = true + conflict = false + `, + } + + tests := map[[2]string]struct { + wantErr bool + wantFail bool + }{ + {"data.foo", "test_pass"}: {false, false}, + {"data.foo", "test_fail"}: {false, true}, + {"data.foo", "test_fail_non_bool"}: {false, true}, + {"data.foo", "test_err"}: {true, false}, + } + + test.WithTempFS(files, func(d string) { + rs, err := Run(ctx, d) + if err != nil { + t.Fatal(err) + } + seen := map[[2]string]struct{}{} + for i := range rs { + k := [2]string{rs[i].Package, rs[i].Name} + seen[k] = struct{}{} + exp, ok := tests[k] + if !ok { + t.Errorf("Unexpected result for %v", k) + } else if exp.wantErr != (rs[i].Error != nil) || exp.wantFail != (rs[i].Fail != nil) { + t.Errorf("Expected %v for %v but got: %v", exp, k, rs[i]) + } + } + for k := range tests { + if _, ok := seen[k]; !ok { + t.Errorf("Expected result for %v", k) + } + } + }) +} + +func TestRunnerCancel(t *testing.T) { + + ast.RegisterBuiltin(&ast.Builtin{ + Name: ast.String("test.sleep"), + Args: []types.Type{ + types.S, + }, + }) + + topdown.RegisterFunctionalBuiltinVoid1("test.sleep", func(a ast.Value) error { + d, _ := time.ParseDuration(string(a.(ast.String))) + time.Sleep(d) + return nil + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + module := `package foo + + test_1 { test.sleep("100ms") } + test_2 { true }` + + files := map[string]string{ + "/a_test.rego": module, + } + + test.WithTempFS(files, func(d string) { + ch, err := NewRunner().Paths(ctx, d) + if err != nil { + t.Fatal(err) + } + <-ch + _, ok := <-ch + if ok { + t.Fatal("Expected channel to be closed after first test") + } + }) + +}