Skip to content

Commit

Permalink
Add new opa test sub-command
Browse files Browse the repository at this point in the history
These changes add a new sub-command to execute policy tests.

Fixes #428
  • Loading branch information
tsandall committed Sep 7, 2017
1 parent 9ab9d89 commit 003c630
Show file tree
Hide file tree
Showing 5 changed files with 522 additions and 0 deletions.
120 changes: 120 additions & 0 deletions cmd/test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.

package cmd

import (
"context"
"fmt"
"os"
"time"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/tester"
"github.com/spf13/cobra"
)

var testParams = struct {
verbose bool
errLimit int
timeout time.Duration
}{}

var testCommand = &cobra.Command{
Use: "test",
Short: "Execute Rego test cases",
Long: `Execute Rego test cases.
The 'test' command takes a file or directory path as input and executes all
test cases discovered in matching files. Test cases are rules whose names have the prefix "test_".
Example policy (example/authz.rego):
package authz
allow {
input.path = ["users"]
input.method = "POST"
}
allow {
input.path = ["users", profile_id]
input.method = "GET"
profile_id = input.user_id
}
Example test (example/authz_test.rego):
package authz
test_post_allowed {
allow with input as {"path": ["users"], "method": "POST"}
}
test_get_denied {
not allow with input as {"path": ["users"], "method": "GET"}
}
test_get_user_allowed {
allow with input as {"path": ["users", "bob"], "method": "GET", "user_id": "bob"}
}
test_get_another_user_denied {
not allow with input as {"path": ["users", "bob"], "method": "GET", "user_id": "alice"}
}
Example test run:
$ opa test ./example/
`,
Run: func(cmd *cobra.Command, args []string) {
os.Exit(opaTest(args))
},
}

func opaTest(args []string) int {

compiler := ast.NewCompiler().SetErrorLimit(testParams.errLimit)

ctx, cancel := context.WithTimeout(context.Background(), testParams.timeout)
defer cancel()

ch, err := tester.NewRunner().SetCompiler(compiler).Paths(ctx, args...)
if err != nil {
fmt.Fprintln(os.Stderr, err)
return 1
}

reporter := tester.PrettyReporter{
Verbose: testParams.verbose,
Output: os.Stdout,
}

exitCode := 0
dup := make(chan *tester.Result)

go func() {
defer close(dup)
for tr := range ch {
if !tr.Pass() {
exitCode = 1
}
dup <- tr
}
}()

if err := reporter.Report(dup); err != nil {
fmt.Fprintln(os.Stderr, err)
return 1
}

return exitCode
}

func init() {
testCommand.Flags().BoolVarP(&testParams.verbose, "verbose", "v", false, "set verbose reporting mode")
testCommand.Flags().DurationVarP(&testParams.timeout, "timeout", "t", time.Second*5, "set test timeout")
setMaxErrors(testCommand.Flags(), &testParams.errLimit)
RootCommand.AddCommand(testCommand)
}
63 changes: 63 additions & 0 deletions tester/reporter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.

package tester

import (
"fmt"
"io"
"strings"
)

// PrettyReporter reports test results in a simple human readable format.
type PrettyReporter struct {
Output io.Writer
Verbose bool
}

// Report prints the test report to the reporter's output.
func (r PrettyReporter) Report(ch chan *Result) error {

dirty := false
var pass, fail, errs int

// Report individual tests.
for tr := range ch {
if tr.Pass() {
pass++
} else if tr.Error != nil {
errs++
} else if tr.Fail != nil {
fail++
}
if !tr.Pass() || r.Verbose {
fmt.Fprintln(r.Output, tr)
dirty = true
}
if tr.Error != nil {
fmt.Fprintf(r.Output, " %v\n", tr.Error)
}
}

// Report summary of test.
if dirty {
fmt.Fprintln(r.Output, strings.Repeat("-", 80))
}

total := pass + fail + errs

if pass != 0 {
fmt.Fprintln(r.Output, "PASS:", fmt.Sprintf("%d/%d", pass, total))
}

if fail != 0 {
fmt.Fprintln(r.Output, "FAIL:", fmt.Sprintf("%d/%d", fail, total))
}

if errs != 0 {
fmt.Fprintln(r.Output, "ERROR:", fmt.Sprintf("%d/%d", errs, total))
}

return nil
}
51 changes: 51 additions & 0 deletions tester/reporter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package tester

import (
"bytes"
"fmt"
"testing"
)

func TestPrettyReporter(t *testing.T) {

var badResult interface{} = "fail"

ts := []*Result{
{nil, "data.foo.bar", "test_baz", nil, nil, 0},
{nil, "data.foo.bar", "test_qux", nil, fmt.Errorf("some err"), 0},
{nil, "data.foo.bar", "test_corge", &badResult, nil, 0},
}

var buf bytes.Buffer

r := PrettyReporter{
Output: &buf,
Verbose: true,
}

ch := make(chan *Result)
go func() {
for _, tr := range ts {
ch <- tr
}
close(ch)
}()

if err := r.Report(ch); err != nil {
t.Fatal(err)
}

exp := `data.foo.bar.test_baz: PASS (0s)
data.foo.bar.test_qux: ERROR (0s)
some err
data.foo.bar.test_corge: FAIL (0s)
--------------------------------------------------------------------------------
PASS: 1/3
FAIL: 1/3
ERROR: 1/3
`

if exp != buf.String() {
t.Fatalf("Expected:\n\n%v\n\nGot:\n\n%v", exp, buf.String())
}
}
Loading

0 comments on commit 003c630

Please sign in to comment.