Skip to content

Commit

Permalink
fixer: First draft implementation of --fix
Browse files Browse the repository at this point in the history
WIP while I work on some other fixes to prove out the interfaces.

Signed-off-by: Charlie Egan <[email protected]>
  • Loading branch information
charlieegan3 committed Apr 15, 2024
1 parent c31886e commit 0460b60
Show file tree
Hide file tree
Showing 9 changed files with 457 additions and 18 deletions.
51 changes: 51 additions & 0 deletions cmd/lint.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
rio "github.com/styrainc/regal/internal/io"
regalmetrics "github.com/styrainc/regal/internal/metrics"
"github.com/styrainc/regal/pkg/config"
"github.com/styrainc/regal/pkg/fixer"
"github.com/styrainc/regal/pkg/fixer/fixes"
"github.com/styrainc/regal/pkg/linter"
"github.com/styrainc/regal/pkg/report"
"github.com/styrainc/regal/pkg/reporter"
Expand All @@ -38,6 +40,7 @@ type lintCommandParams struct {
enablePrint bool
metrics bool
profile bool
fix bool
disable repeatedStringFlag
disableAll bool
disableCategory repeatedStringFlag
Expand Down Expand Up @@ -153,6 +156,8 @@ func init() {
"enable metrics reporting (currently supported only for JSON output format)")
lintCommand.Flags().BoolVar(&params.profile, "profile", false,
"enable profiling metrics to be added to reporting (currently supported only for JSON output format)")
lintCommand.Flags().BoolVar(&params.fix, "fix", false,
"enable automatic fixing of violations where supported")

lintCommand.Flags().VarP(&params.disable, "disable", "d",
"disable specific rule(s). This flag can be repeated.")
Expand Down Expand Up @@ -307,6 +312,13 @@ func lint(args []string, params lintCommandParams) (report.Report, error) {
return report.Report{}, fmt.Errorf("error(s) encountered while linting: %w", err)
}

if params.fix {
err = fixReport(&result)
if err != nil {
return report.Report{}, fmt.Errorf("error(s) encountered while fixing: %w", err)
}
}

rep, err := getReporter(params.format, outputWriter)
if err != nil {
return report.Report{}, fmt.Errorf("failed to get reporter: %w", err)
Expand Down Expand Up @@ -383,3 +395,42 @@ func getWriterForOutputFile(filename string) (io.Writer, error) {

return f, nil
}

func fixReport(rep *report.Report) error {
fileReaders := make(map[string]io.Reader)

for _, v := range rep.Violations {
f, err := os.Open(v.Location.File)
if err != nil {
return fmt.Errorf("failed to open file for fixing %s: %w", v.Location.File, err)
}

defer f.Close()

fileReaders[v.Location.File] = f
}

fixResult, err := fixer.Fix(rep, fileReaders, fixes.Options{})
if err != nil {
return fmt.Errorf("error encountered while fixing: %w", err)
}

for file, content := range fixResult {
f, err := os.OpenFile(file, os.O_RDWR|os.O_TRUNC, 0o755)
if err != nil {
return fmt.Errorf("failed to open file %s: %w", file, err)
}

_, err = f.Write(content)
if err != nil {
return fmt.Errorf("failed to write to file %s: %w", file, err)
}

err = f.Close()
if err != nil {
return fmt.Errorf("failed to close file %s: %w", file, err)
}
}

return nil
}
30 changes: 29 additions & 1 deletion e2e/cli_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,42 @@ func TestLintNonExistentDir(t *testing.T) {
expectExitCode(t, err, 1, &stdout, &stderr)

if exp, act := "", stdout.String(); exp != act {
t.Errorf("expected stderr %q, got %q", exp, act)
t.Errorf("expected stdout %q, got %q", exp, act)
}

if exp, act := "error(s) encountered while linting: errors encountered when reading files to lint: "+
"failed to filter paths:\nstat "+td+filepath.FromSlash("/what/ever")+": no such file or directory\n",
stderr.String(); exp != act {
t.Errorf("expected stderr %q, got %q", exp, act)
}
}

func TestLintAndFix(t *testing.T) {
t.Parallel()

stdout := bytes.Buffer{}
stderr := bytes.Buffer{}
td := t.TempDir()

// only violation is for the opa-fmt rule
unformattedContents := []byte("package test\nimport rego.v1\nallow := true")
err := os.WriteFile(filepath.Join(td, "main.rego"), unformattedContents, 0644)
if err != nil {
t.Fatalf("failed to write main.rego: %v", err)
}

err = regal(&stdout, &stderr)("lint", "--fix", td)

// 0 exit status is expected as all violations should have been fixed
expectExitCode(t, err, 0, &stdout, &stderr)

if exp, act := "1 file linted. No violations found.\n", stdout.String(); exp != act {
t.Errorf("expected stdout %q, got %q", exp, act)
}

if exp, act := "", stderr.String(); exp != act {
t.Errorf("expected stderr %q, got %q", exp, act)
}
}

func TestLintAllViolations(t *testing.T) {
Expand Down
13 changes: 0 additions & 13 deletions internal/lsp/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,9 @@
package lsp

import (
"fmt"
"path/filepath"
"strings"

"github.com/open-policy-agent/opa/format"
)

func Format(path, contents string, opts format.Opts) (string, error) {
formatted, err := format.SourceWithOpts(filepath.Base(path), []byte(contents), opts)
if err != nil {
return "", fmt.Errorf("failed to format Rego source file: %w", err)
}

return string(formatted), nil
}

// ComputeEdits computes diff edits from 2 string inputs.
func ComputeEdits(before, after string) []TextEdit {
ops := operations(splitLines(before), splitLines(after))
Expand Down
22 changes: 18 additions & 4 deletions internal/lsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
lsconfig "github.com/styrainc/regal/internal/lsp/config"
"github.com/styrainc/regal/internal/lsp/uri"
"github.com/styrainc/regal/pkg/config"
"github.com/styrainc/regal/pkg/fixer/fixes"
)

const (
Expand Down Expand Up @@ -239,12 +240,19 @@ func (l *LanguageServer) formatToEdits(params ExecuteCommandParams, opts format.
return nil, target, fmt.Errorf("could not get file contents for uri %q", target)
}

newContent, err := Format(uri.ToPath(l.clientIdentifier, target), oldContent, opts)
fixed, formattedContent, err := fixes.Fmt([]byte(oldContent), &fixes.FmtOptions{
Filename: uri.ToPath(l.clientIdentifier, target),
OPAFmtOpts: opts,
})
if err != nil {
return nil, target, fmt.Errorf("failed to format file: %w", err)
}

return ComputeEdits(oldContent, newContent), target, nil
if !fixed {
return []TextEdit{}, target, nil
}

return ComputeEdits(oldContent, string(formattedContent)), target, nil
}

func (l *LanguageServer) StartConfigWorker(ctx context.Context) {
Expand Down Expand Up @@ -633,12 +641,18 @@ func (l *LanguageServer) handleTextDocumentFormatting(
return nil, fmt.Errorf("failed to get file contents for uri %q", params.TextDocument.URI)
}

newContent, err := Format(params.TextDocument.URI, oldContent, format.Opts{})
fixed, formattedContent, err := fixes.Fmt([]byte(oldContent), &fixes.FmtOptions{
Filename: uri.ToPath(l.clientIdentifier, params.TextDocument.URI),
})
if err != nil {
return nil, fmt.Errorf("failed to format file: %w", err)
}

return ComputeEdits(oldContent, newContent), nil
if !fixed {
return []TextEdit{}, nil
}

return ComputeEdits(oldContent, string(formattedContent)), nil
}

func (l *LanguageServer) handleWorkspaceDidCreateFiles(
Expand Down
117 changes: 117 additions & 0 deletions pkg/fixer/fixer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package fixer

import (
"fmt"
"io"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/format"

"github.com/styrainc/regal/pkg/fixer/fixes"
"github.com/styrainc/regal/pkg/report"
)

func Fix(rep *report.Report, readers map[string]io.Reader, _ fixes.Options) (map[string][]byte, error) {
fixableViolations := map[string]struct{}{
"opa-fmt": {},
"use-rego-v1": {},
}

filesToFix, err := computeFilesToFix(rep, readers, fixableViolations)
if err != nil {
return nil, fmt.Errorf("failed to determine files to fix: %w", err)
}

fixResults := make(map[string][]byte)

var fixedViolations []int

for file, content := range filesToFix {
for i, violation := range rep.Violations {
_, ok := fixableViolations[violation.Title]
if !ok {
continue
}

fixed := true

switch violation.Title {
case "opa-fmt":
fixed, fixedContent, err := fixes.Fmt(content, &fixes.FmtOptions{
Filename: file,
})
if err != nil {
return nil, fmt.Errorf("failed to fix %s: %w", file, err)
}

if fixed {
fixResults[file] = fixedContent
}
case "use-rego-v1":
fixed, fixedContent, err := fixes.Fmt(content, &fixes.FmtOptions{
Filename: file,
OPAFmtOpts: format.Opts{
RegoVersion: ast.RegoV0CompatV1,
},
})
if err != nil {
return nil, fmt.Errorf("failed to fix %s: %w", file, err)
}

if fixed {
fixResults[file] = fixedContent
}
default:
fixed = false
}

if fixed {
fixedViolations = append(fixedViolations, i)
}
}
}

for i := len(fixedViolations) - 1; i >= 0; i-- {
rep.Violations = append(rep.Violations[:fixedViolations[i]], rep.Violations[fixedViolations[i]+1:]...)
}

rep.Summary.NumViolations = len(rep.Violations)

return fixResults, nil
}

func computeFilesToFix(
rep *report.Report,
readers map[string]io.Reader,
fixableViolations map[string]struct{},
) (map[string][]byte, error) {
filesToFix := make(map[string][]byte)

// determine which files need to be fixed
for _, violation := range rep.Violations {
file := violation.Location.File

// skip files already marked for fixing
if _, ok := filesToFix[file]; ok {
continue
}

// skip violations that are not fixable
if _, ok := fixableViolations[violation.Title]; !ok {
continue
}

if _, ok := readers[file]; !ok {
return nil, fmt.Errorf("no reader for fixable file %s", file)
}

bs, err := io.ReadAll(readers[file])
if err != nil {
return nil, fmt.Errorf("failed to read file %s: %w", file, err)
}

filesToFix[violation.Location.File] = bs
}

return filesToFix, nil
}
Loading

0 comments on commit 0460b60

Please sign in to comment.