Skip to content

Commit

Permalink
Ordered fixes
Browse files Browse the repository at this point in the history
Signed-off-by: Charlie Egan <[email protected]>
  • Loading branch information
charlieegan3 committed Apr 16, 2024
1 parent c243dd3 commit 3e9faac
Show file tree
Hide file tree
Showing 9 changed files with 267 additions and 47 deletions.
30 changes: 25 additions & 5 deletions e2e/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,14 @@ func TestFix(t *testing.T) {
td := t.TempDir()

// only violation is for the opa-fmt rule
unformattedContents := []byte("package test\nimport rego.v1\nallow := true")
unformattedContents := []byte(`package wow
#comment
allow = 1 {
input.foo == true
}
`)
err := os.WriteFile(filepath.Join(td, "main.rego"), unformattedContents, 0644)
if err != nil {
t.Fatalf("failed to write main.rego: %v", err)
Expand All @@ -747,9 +754,11 @@ func TestFix(t *testing.T) {
// 0 exit status is expected as all violations should have been fixed
expectExitCode(t, err, 0, &stdout, &stderr)

exp := fmt.Sprintf(`1 fix applied:
exp := fmt.Sprintf(`3 fixes applied:
%s/main.rego:
- opa-fmt
- no-whitespace-comment
- use-assignment-operator
- use-rego-v1
`, td)

if act := stdout.String(); exp != act {
Expand All @@ -766,8 +775,19 @@ func TestFix(t *testing.T) {
t.Fatalf("failed to read main.rego: %v", err)
}

if exp, act := "package test\n\nimport rego.v1\n\nallow := true\n", string(bs); exp != act {
t.Errorf("expected\n%s, got\n%s", exp, act)
expectedContent := `package wow
import rego.v1
# comment
allow := 1 if {
input.foo == true
}
`

if act := string(bs); expectedContent != act {
t.Errorf("expected\n%s, got\n%s", expectedContent, act)
}
}

Expand Down
101 changes: 62 additions & 39 deletions pkg/fixer/fixer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ func (r *Report) SetFileContents(file string, content []byte) {
r.fileContents[file] = content
}

func (r *Report) GetFileContents(file string) ([]byte, bool) {
content, ok := r.fileContents[file]

return content, ok
}

func (r *Report) SetFileFixedViolation(file string, violation string) {
if _, ok := r.fileFixedViolations[file]; !ok {
r.fileFixedViolations[file] = make(map[string]struct{})
Expand Down Expand Up @@ -71,35 +77,17 @@ func (r *Report) TotalFixes() int {
return r.totalFixes
}

type FixToggles struct {
OPAFmt bool
OPAFmtRegoV1 bool
}

func (f *FixToggles) IsEnabled(key string) bool {
if f == nil {
return false
}

switch key {
case "opa-fmt":
return f.OPAFmt
case "use-rego-v1":
return f.OPAFmtRegoV1
}

return false
}

func NewDefaultFixes() []fixes.Fix {
return []fixes.Fix{
&fixes.Fmt{},
&fixes.Fmt{
KeyOverride: "use-rego-v1",
OPAFmtOpts: format.Opts{
RegoVersion: ast.RegoV0CompatV1,
},
},
&fixes.UseAssignmentOperator{},
&fixes.NoWhitespaceComment{},
}
}

Expand Down Expand Up @@ -131,6 +119,28 @@ func (f *Fixer) GetFixForKey(key string) (fixes.Fix, bool) {
return fixInstance, true
}

func (f *Fixer) OrderedFixes() []fixes.Fix {
orderedFixes := make([]fixes.Fix, 0)
wholeFileFixes := make([]fixes.Fix, 0)

for _, fix := range f.registeredFixes {
fixInstance, ok := fix.(fixes.Fix)
if !ok {
continue
}

if fixInstance.WholeFile() {
wholeFileFixes = append(wholeFileFixes, fixInstance)

continue
}

orderedFixes = append(orderedFixes, fixInstance)
}

return append(orderedFixes, wholeFileFixes...)
}

func (f *Fixer) Fix(rep *report.Report, readers map[string]io.Reader) (*Report, error) {
filesToFix, err := computeFilesToFix(f, rep, readers)
if err != nil {
Expand All @@ -139,25 +149,38 @@ func (f *Fixer) Fix(rep *report.Report, readers map[string]io.Reader) (*Report,

fixReport := NewReport()

for file, content := range filesToFix {
for _, violation := range rep.Violations {
fixInstance, ok := f.GetFixForKey(violation.Title)
if !ok {
continue
}

fixed, fixedContent, err := fixInstance.Fix(content, &fixes.RuntimeOptions{
Metadata: fixes.RuntimeMetadata{
Filename: file,
},
})
if err != nil {
return nil, fmt.Errorf("failed to fix %s: %w", file, err)
}

if fixed {
fixReport.SetFileContents(file, fixedContent)
fixReport.SetFileFixedViolation(file, violation.Title)
for _, fixInstance := range f.OrderedFixes() {
// fix by line
for file, content := range filesToFix {
for _, violation := range rep.Violations {
if violation.Title != fixInstance.Key() {
continue
}

// if the file has been fixed, use the fixed content from other fixes
if fixedContent, ok := fixReport.GetFileContents(file); ok {
content = fixedContent
}

fixed, fixedContent, err := fixInstance.Fix(content, &fixes.RuntimeOptions{
Metadata: fixes.RuntimeMetadata{
Filename: file,
},
Locations: []ast.Location{
{
Row: violation.Location.Row,
Col: violation.Location.Column,
},
},
})
if err != nil {
return nil, fmt.Errorf("failed to fix %s: %w", file, err)
}

if fixed {
fixReport.SetFileContents(file, fixedContent)
fixReport.SetFileFixedViolation(file, violation.Title)
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions pkg/fixer/fixes/fixes.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import "github.com/open-policy-agent/opa/ast"

type Fix interface {
Key() string
WholeFile() bool
Fix(in []byte, opts *RuntimeOptions) (bool, []byte, error)
}

Expand Down
4 changes: 4 additions & 0 deletions pkg/fixer/fixes/fmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ func (f *Fmt) Key() string {
return "opa-fmt"
}

func (*Fmt) WholeFile() bool {
return true
}

func (f *Fmt) Fix(in []byte, opts *RuntimeOptions) (bool, []byte, error) {
filename := ""

Expand Down
43 changes: 43 additions & 0 deletions pkg/fixer/fixes/nowhitespacecomment.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package fixes

import (
"bytes"
)

type NoWhitespaceComment struct{}

func (*NoWhitespaceComment) Key() string {
return "no-whitespace-comment"
}

func (*NoWhitespaceComment) WholeFile() bool {
return false
}

func (*NoWhitespaceComment) Fix(in []byte, opts *RuntimeOptions) (bool, []byte, error) {
lines := bytes.Split(in, []byte("\n"))

// this fix must have locations
if len(opts.Locations) == 0 {
return false, nil, nil
}

for _, loc := range opts.Locations {
if loc.Row > len(lines) {
return false, nil, nil
}

if loc.Col != 1 {
// current impl only understands the first column
return false, nil, nil
}

line := lines[loc.Row-1]

if bytes.HasPrefix(line, []byte("#")) && !bytes.HasPrefix(line, []byte("# ")) {
lines[loc.Row-1] = bytes.Replace(line, []byte("#"), []byte("# "), 1)
}
}

return true, bytes.Join(lines, []byte("\n")), nil
}
120 changes: 120 additions & 0 deletions pkg/fixer/fixes/nowhitespacecomment_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package fixes

import (
"testing"

"github.com/open-policy-agent/opa/ast"
)

func TestNoWhitespaceComment(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
runtimeOptions *RuntimeOptions

beforeFix []byte
afterFix []byte

fixExpected bool
}{
"no change needed": {
beforeFix: []byte(`package test\n
# this is a comment
`),
afterFix: []byte(`package test\n
# this is a comment
`),
fixExpected: false,
runtimeOptions: &RuntimeOptions{},
},
"no change made because no location": {
beforeFix: []byte(`package test\n
#this is a comment
`),
afterFix: []byte(`package test\n
#this is a comment
`),
fixExpected: false,
runtimeOptions: &RuntimeOptions{},
},
"single change": {
beforeFix: []byte(`package test\n
#this is a comment
`),
afterFix: []byte(`package test\n
# this is a comment
`),
fixExpected: true,
runtimeOptions: &RuntimeOptions{
Locations: []ast.Location{
{
Row: 3,
Col: 1, // this is what the rule outputs at the moment
},
},
},
},
"many changes": {
beforeFix: []byte(`package test\n
#this is a comment
#this is a comment
#this is a comment
`),
afterFix: []byte(`package test\n
# this is a comment
# this is a comment
# this is a comment
`),
fixExpected: true,
runtimeOptions: &RuntimeOptions{
Locations: []ast.Location{
{
Row: 3,
Col: 1,
},
{
Row: 4,
Col: 1,
},
{
Row: 5,
Col: 1,
},
},
},
},
}

for testName, tc := range testCases {
tc := tc

nwc := NoWhitespaceComment{}

t.Run(testName, func(t *testing.T) {
t.Parallel()

fixed, fixedContent, err := nwc.Fix(tc.beforeFix, tc.runtimeOptions)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if tc.fixExpected && !fixed {
t.Fatalf("expected fix to be applied")
}

if tc.fixExpected && string(fixedContent) != string(tc.afterFix) {
t.Fatalf("unexpected content, got:\n%s---\nexpected:\n%s---",
string(fixedContent),
string(tc.afterFix))
}
})
}
}
4 changes: 4 additions & 0 deletions pkg/fixer/fixes/useassignmentoperator.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ func (*UseAssignmentOperator) Key() string {
return "use-assignment-operator"
}

func (*UseAssignmentOperator) WholeFile() bool {
return false
}

func (*UseAssignmentOperator) Fix(in []byte, opts *RuntimeOptions) (bool, []byte, error) {
lines := bytes.Split(in, []byte("\n"))

Expand Down
Loading

0 comments on commit 3e9faac

Please sign in to comment.