From 8adb96d5721bd293a53a2a9a6e4568c697f36a73 Mon Sep 17 00:00:00 2001 From: Charlie Egan Date: Thu, 18 Apr 2024 14:40:40 +0100 Subject: [PATCH] lsp: Implement code actions for new fixes https://github.com/StyraInc/regal/pull/653 added fixes for no-whitespace-comment and use-assignment-operator. This PR adds code actions for these in the regal lsp. Signed-off-by: Charlie Egan --- internal/lsp/command.go | 18 ++ internal/lsp/commands/parse.go | 88 ++++++++ internal/lsp/commands/parse_test.go | 110 ++++++++++ internal/lsp/messages.go | 5 - internal/lsp/server.go | 227 +++++++++++--------- internal/lsp/types/types.go | 6 + pkg/fixer/fixes/fmt_test.go | 2 +- pkg/fixer/fixes/nowhitespacecomment.go | 6 +- pkg/fixer/fixes/nowhitespacecomment_test.go | 41 +--- pkg/fixer/fixes/useassignmentoperator.go | 6 +- 10 files changed, 360 insertions(+), 149 deletions(-) create mode 100644 internal/lsp/commands/parse.go create mode 100644 internal/lsp/commands/parse_test.go create mode 100644 internal/lsp/types/types.go diff --git a/internal/lsp/command.go b/internal/lsp/command.go index 418387d88..d6d05cb58 100644 --- a/internal/lsp/command.go +++ b/internal/lsp/command.go @@ -26,3 +26,21 @@ func FmtV1Command(args []string) Command { Arguments: toAnySlice(args), } } + +func UseAssignmentOperatorCommand(args []string) Command { + return Command{ + Title: "Replace = with := in assignment", + Command: "regal.use-assignment-operator", + Tooltip: "Replace = with := in assignment", + Arguments: toAnySlice(args), + } +} + +func NoWhiteSpaceCommentCommand(args []string) Command { + return Command{ + Title: "Format comment to have leading whitespace", + Command: "regal.no-whitespace-comment", + Tooltip: "Format comment to have leading whitespace", + Arguments: toAnySlice(args), + } +} diff --git a/internal/lsp/commands/parse.go b/internal/lsp/commands/parse.go new file mode 100644 index 000000000..e81fff031 --- /dev/null +++ b/internal/lsp/commands/parse.go @@ -0,0 +1,88 @@ +package commands + +import ( + "errors" + "fmt" + "strconv" + + "github.com/open-policy-agent/opa/ast" + + "github.com/styrainc/regal/internal/lsp/types" +) + +type ParseOptions struct { + TargetArgIndex int + RowArgIndex int + ColArgIndex int +} + +type ParseResult struct { + Target string + Location *ast.Location +} + +// Parse is responsible for extracting the target and location from the given params command params sent from the client +// after acting on a Code Action. +func Parse(params types.ExecuteCommandParams, opts ParseOptions) (*ParseResult, error) { + if len(params.Arguments) == 0 { + return nil, errors.New("no args supplied") + } + + target := "" + + if opts.TargetArgIndex < len(params.Arguments) { + target = fmt.Sprintf("%s", params.Arguments[opts.TargetArgIndex]) + } + + // we can't extract a location from the same location as the target, so location arg positions + // must not have been set in the opts. + if opts.RowArgIndex == opts.TargetArgIndex { + return &ParseResult{ + Target: target, + }, nil + } + + var loc *ast.Location + + if opts.RowArgIndex < len(params.Arguments) && opts.ColArgIndex < len(params.Arguments) { + var row, col int + + switch v := params.Arguments[opts.RowArgIndex].(type) { + case int: + row = v + case string: + var err error + + row, err = strconv.Atoi(v) + if err != nil { + return nil, fmt.Errorf("failed to parse row: %w", err) + } + default: + return nil, fmt.Errorf("unexpected type for row: %T", params.Arguments[opts.RowArgIndex]) + } + + switch v := params.Arguments[opts.ColArgIndex].(type) { + case int: + col = v + case string: + var err error + + col, err = strconv.Atoi(v) + if err != nil { + return nil, fmt.Errorf("failed to parse col: %w", err) + } + default: + return nil, fmt.Errorf("unexpected type for col: %T", params.Arguments[opts.ColArgIndex]) + } + + loc = &ast.Location{ + Row: row, + Col: col, + } + } + + return &ParseResult{ + Target: target, + Location: loc, + }, nil +} diff --git a/internal/lsp/commands/parse_test.go b/internal/lsp/commands/parse_test.go new file mode 100644 index 000000000..13f4b84c8 --- /dev/null +++ b/internal/lsp/commands/parse_test.go @@ -0,0 +1,110 @@ +package commands + +import ( + "testing" + + "github.com/open-policy-agent/opa/ast" + + "github.com/styrainc/regal/internal/lsp/types" +) + +func TestParse(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + ExecuteCommandParams types.ExecuteCommandParams + ParseOptions ParseOptions + ExpectedTarget string + ExpectedLocation *ast.Location + }{ + "extract target only": { + ExecuteCommandParams: types.ExecuteCommandParams{ + Command: "example", + Arguments: []interface{}{"target"}, + }, + ParseOptions: ParseOptions{TargetArgIndex: 0}, + ExpectedTarget: "target", + ExpectedLocation: nil, + }, + "extract target and location": { + ExecuteCommandParams: types.ExecuteCommandParams{ + Command: "example", + Arguments: []interface{}{"target", "1", 2}, // different types for testing, but should be strings + }, + ParseOptions: ParseOptions{TargetArgIndex: 0, RowArgIndex: 1, ColArgIndex: 2}, + ExpectedTarget: "target", + ExpectedLocation: &ast.Location{Row: 1, Col: 2}, + }, + } + + for name, tc := range testCases { + tc := tc + + t.Run(name, func(t *testing.T) { + t.Parallel() + + result, err := Parse(tc.ExecuteCommandParams, tc.ParseOptions) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Target != tc.ExpectedTarget { + t.Fatalf("expected target %q, got %q", tc.ExpectedTarget, result.Target) + } + + if tc.ExpectedLocation == nil && result.Location != nil { + t.Fatalf("expected location to be nil, got %v", result.Location) + } + + if tc.ExpectedLocation != nil { + if result.Location == nil { + t.Fatalf("expected location to be %v, got nil", tc.ExpectedLocation) + } + + if result.Location.Row != tc.ExpectedLocation.Row { + t.Fatalf("expected row %d, got %d", tc.ExpectedLocation.Row, result.Location.Row) + } + + if result.Location.Col != tc.ExpectedLocation.Col { + t.Fatalf("expected col %d, got %d", tc.ExpectedLocation.Col, result.Location.Col) + } + } + }) + } +} + +func TestParse_Errors(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + ExecuteCommandParams types.ExecuteCommandParams + ParseOptions ParseOptions + ExpectedError string + }{ + "error extracting target": { + ExecuteCommandParams: types.ExecuteCommandParams{ + Command: "example", + Arguments: []interface{}{}, // empty and so nothing can be extracted + }, + ParseOptions: ParseOptions{TargetArgIndex: 0}, + ExpectedError: "no args supplied", + }, + } + + for name, tc := range testCases { + tc := tc + + t.Run(name, func(t *testing.T) { + t.Parallel() + + _, err := Parse(tc.ExecuteCommandParams, tc.ParseOptions) + if err == nil { + t.Fatalf("expected error %q, got nil", tc.ExpectedError) + } + + if err.Error() != tc.ExpectedError { + t.Fatalf("expected error %q, got %q", tc.ExpectedError, err.Error()) + } + }) + } +} diff --git a/internal/lsp/messages.go b/internal/lsp/messages.go index eb33087f2..caceb3da5 100644 --- a/internal/lsp/messages.go +++ b/internal/lsp/messages.go @@ -124,11 +124,6 @@ type ExecuteCommandOptions struct { Commands []string `json:"commands"` } -type ExecuteCommandParams struct { - Command string `json:"command"` - Arguments []any `json:"arguments"` -} - type ApplyWorkspaceEditParams struct { Label string `json:"label"` Edit WorkspaceEdit `json:"edit"` diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 27e7cf806..77508d507 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -8,6 +8,7 @@ import ( "io" "os" "path/filepath" + "strconv" "strings" "sync" @@ -18,7 +19,9 @@ import ( "github.com/open-policy-agent/opa/format" "github.com/styrainc/regal/internal/lsp/clients" + "github.com/styrainc/regal/internal/lsp/commands" lsconfig "github.com/styrainc/regal/internal/lsp/config" + "github.com/styrainc/regal/internal/lsp/types" "github.com/styrainc/regal/internal/lsp/uri" "github.com/styrainc/regal/pkg/config" "github.com/styrainc/regal/pkg/fixer/fixes" @@ -43,7 +46,7 @@ func NewLanguageServer(opts *LanguageServerOptions) *LanguageServer { diagnosticRequestFile: make(chan fileUpdateEvent, 10), diagnosticRequestWorkspace: make(chan string, 10), builtinsPositionFile: make(chan fileUpdateEvent, 10), - commandRequest: make(chan ExecuteCommandParams, 10), + commandRequest: make(chan types.ExecuteCommandParams, 10), configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{ErrorWriter: opts.ErrorLog}), } @@ -66,7 +69,7 @@ type LanguageServer struct { builtinsPositionFile chan fileUpdateEvent - commandRequest chan ExecuteCommandParams + commandRequest chan types.ExecuteCommandParams clientRootURI string clientIdentifier clients.Identifier @@ -216,47 +219,6 @@ func (l *LanguageServer) StartHoverWorker(ctx context.Context) { } } -func getTargetURIFromParams(params ExecuteCommandParams) (string, error) { - if len(params.Arguments) == 0 { - return "", fmt.Errorf("expected at least one argument in command %v", params.Arguments) - } - - target, ok := params.Arguments[0].(string) - if !ok { - return "", fmt.Errorf("expected argument to be a string in command %v", params.Command) - } - - return target, nil -} - -func (l *LanguageServer) formatToEdits(params ExecuteCommandParams, opts format.Opts) ([]TextEdit, string, error) { - target, err := getTargetURIFromParams(params) - if err != nil { - return nil, "", fmt.Errorf("failed to get target uri: %w", err) - } - - oldContent, ok := l.cache.GetFileContents(target) - if !ok { - return nil, target, fmt.Errorf("could not get file contents for uri %q", target) - } - - f := &fixes.Fmt{OPAFmtOpts: opts} - - fixResults, err := f.Fix(&fixes.FixCandidate{ - Filename: filepath.Base(uri.ToPath(l.clientIdentifier, target)), - Contents: []byte(oldContent), - }, nil) - if err != nil { - return nil, target, fmt.Errorf("failed to format file: %w", err) - } - - if len(fixResults) == 0 { - return []TextEdit{}, target, nil - } - - return ComputeEdits(oldContent, string(fixResults[0].Contents)), target, nil -} - func (l *LanguageServer) StartConfigWorker(ctx context.Context) { err := l.configWatcher.Start(ctx) if err != nil { @@ -307,73 +269,60 @@ func (l *LanguageServer) StartConfigWorker(ctx context.Context) { } func (l *LanguageServer) StartCommandWorker(ctx context.Context) { + // note, in this function conn.Call is used as the workspace/applyEdit message is a request, not a notification + // as per the spec. In order to be 'routed' to the correct handler on the client it must have an ID + // receive responses too. + // Note however that the responses from the client are not needed by the server. for { select { case <-ctx.Done(): return case params := <-l.commandRequest: - switch params.Command { - case "regal.fmt": - edits, target, err := l.formatToEdits(params, format.Opts{}) - if err != nil { - l.logError(err) + var editParams *ApplyWorkspaceEditParams - break - } + var err error - editParams := ApplyWorkspaceEditParams{ - Label: "Format using opa fmt", - Edit: WorkspaceEdit{ - DocumentChanges: []TextDocumentEdit{ - { - TextDocument: OptionalVersionedTextDocumentIdentifier{URI: target}, - Edits: edits, - }, - }, - }, - } + var fixed bool - // note, here conn.Call is used as the workspace/applyEdit message is a request, not a notification - // as per the spec. In order to be 'routed' to the correct handler on the client it must have an ID - // receive responses too. - err = l.conn.Call( - ctx, - methodWorkspaceApplyEdit, - editParams, - nil, // however, the response content is not important + switch params.Command { + case "regal.fmt": + fixed, editParams, err = l.fixEditParams( + "Format using opa fmt", + &fixes.Fmt{OPAFmtOpts: format.Opts{}}, + commands.ParseOptions{TargetArgIndex: 0}, + params, ) - if err != nil { - l.logError(fmt.Errorf("failed %s notify: %v", methodWorkspaceApplyEdit, err.Error())) - } case "regal.fmt.v1": - edits, target, err := l.formatToEdits(params, format.Opts{RegoVersion: ast.RegoV0CompatV1}) - if err != nil { - l.logError(err) + fixed, editParams, err = l.fixEditParams( + "Format for Rego v1 using opa-fmt", + &fixes.Fmt{OPAFmtOpts: format.Opts{RegoVersion: ast.RegoV0CompatV1}}, + commands.ParseOptions{TargetArgIndex: 0}, + params, + ) + case "regal.use-assignment-operator": + fixed, editParams, err = l.fixEditParams( + "Replace = with := in assignment", + &fixes.UseAssignmentOperator{}, + commands.ParseOptions{TargetArgIndex: 0, RowArgIndex: 1, ColArgIndex: 2}, + params, + ) + case "regal.no-whitespace-comment": + fixed, editParams, err = l.fixEditParams( + "Format comment to have leading whitespace", + &fixes.NoWhitespaceComment{}, + commands.ParseOptions{TargetArgIndex: 0, RowArgIndex: 1, ColArgIndex: 2}, + params, + ) + } - break - } + if err != nil { + l.logError(err) - editParams := ApplyWorkspaceEditParams{ - Label: "Format for Rego v1 using opa fmt", - Edit: WorkspaceEdit{ - DocumentChanges: []TextDocumentEdit{ - { - TextDocument: OptionalVersionedTextDocumentIdentifier{URI: target}, - Edits: edits, - }, - }, - }, - } + break + } - // note, here conn.Call is used as the workspace/applyEdit message is a request, not a notification - // as per the spec. In order to be 'routed' to the correct handler on the client it must have an ID - // receive responses too. - err = l.conn.Call( - ctx, - methodWorkspaceApplyEdit, - editParams, - nil, // however, the response content is not important - ) + if fixed { + err = l.conn.Call(ctx, methodWorkspaceApplyEdit, editParams, nil) if err != nil { l.logError(fmt.Errorf("failed %s notify: %v", methodWorkspaceApplyEdit, err.Error())) } @@ -382,6 +331,57 @@ func (l *LanguageServer) StartCommandWorker(ctx context.Context) { } } +func (l *LanguageServer) fixEditParams( + label string, + fix fixes.Fix, + commandParseOpts commands.ParseOptions, + params types.ExecuteCommandParams, +) (bool, *ApplyWorkspaceEditParams, error) { + pr, err := commands.Parse(params, commandParseOpts) + if err != nil { + return false, nil, fmt.Errorf("failed to parse command params: %w", err) + } + + oldContent, ok := l.cache.GetFileContents(pr.Target) + if !ok { + return false, nil, fmt.Errorf("could not get file contents for uri %q", pr.Target) + } + + var rto *fixes.RuntimeOptions + if pr.Location != nil { + rto = &fixes.RuntimeOptions{Locations: []ast.Location{*pr.Location}} + } + + fixResults, err := fix.Fix( + &fixes.FixCandidate{ + Filename: filepath.Base(uri.ToPath(l.clientIdentifier, pr.Target)), + Contents: []byte(oldContent), + }, + rto, + ) + if err != nil { + return false, nil, fmt.Errorf("failed to fix: %w", err) + } + + if len(fixResults) == 0 { + return false, &ApplyWorkspaceEditParams{}, nil + } + + editParams := &ApplyWorkspaceEditParams{ + Label: label, + Edit: WorkspaceEdit{ + DocumentChanges: []TextDocumentEdit{ + { + TextDocument: OptionalVersionedTextDocumentIdentifier{URI: pr.Target}, + Edits: ComputeEdits(oldContent, string(fixResults[0].Contents)), + }, + }, + }, + } + + return true, editParams, nil +} + // processTextContentUpdate updates the cache with the new content for the file at the given URI, attempts to parse the // file, and returns whether the parse was successful. If it was not successful, the parse errors will be sent // on the diagnostic channel. @@ -515,6 +515,30 @@ func (l *LanguageServer) handleTextDocumentCodeAction( IsPreferred: true, Command: FmtV1Command([]string{params.TextDocument.URI}), }) + case "use-assignment-operator": + actions = append(actions, CodeAction{ + Title: "Replace = with := in assignment", + Kind: "quickfix", + Diagnostics: []Diagnostic{diag}, + IsPreferred: true, + Command: UseAssignmentOperatorCommand([]string{ + params.TextDocument.URI, + strconv.FormatUint(uint64(diag.Range.Start.Line+1), 10), + strconv.FormatUint(uint64(diag.Range.Start.Character+1), 10), + }), + }) + case "no-whitespace-comment": + actions = append(actions, CodeAction{ + Title: "Format comment to have leading whitespace", + Kind: "quickfix", + Diagnostics: []Diagnostic{diag}, + IsPreferred: true, + Command: NoWhiteSpaceCommentCommand([]string{ + params.TextDocument.URI, + strconv.FormatUint(uint64(diag.Range.Start.Line+1), 10), + strconv.FormatUint(uint64(diag.Range.Start.Character+1), 10), + }), + }) } if l.clientIdentifier == clients.IdentifierVSCode { @@ -543,7 +567,7 @@ func (l *LanguageServer) handleWorkspaceExecuteCommand( _ *jsonrpc2.Conn, req *jsonrpc2.Request, ) (result any, err error) { - var params ExecuteCommandParams + var params types.ExecuteCommandParams if err := json.Unmarshal(*req.Params, ¶ms); err != nil { return nil, fmt.Errorf("failed to unmarshal params: %w", err) } @@ -836,7 +860,12 @@ func (l *LanguageServer) handleInitialize( CodeActionKinds: []string{"quickfix"}, }, ExecuteCommandProvider: ExecuteCommandOptions{ - Commands: []string{"regal.fmt", "regal.fmt.v1"}, + Commands: []string{ + "regal.fmt", + "regal.fmt.v1", + "regal.use-assignment-operator", + "regal.no-whitespace-comment", + }, }, DocumentFormattingProvider: true, }, diff --git a/internal/lsp/types/types.go b/internal/lsp/types/types.go new file mode 100644 index 000000000..c224eae10 --- /dev/null +++ b/internal/lsp/types/types.go @@ -0,0 +1,6 @@ +package types + +type ExecuteCommandParams struct { + Command string `json:"command"` + Arguments []any `json:"arguments"` +} diff --git a/pkg/fixer/fixes/fmt_test.go b/pkg/fixer/fixes/fmt_test.go index f02fee13e..fc5dbbe20 100644 --- a/pkg/fixer/fixes/fmt_test.go +++ b/pkg/fixer/fixes/fmt_test.go @@ -57,7 +57,7 @@ allow := true t.Run(testName, func(t *testing.T) { t.Parallel() - fixResults, err := tc.fmt.Fix(tc.fc, &RuntimeOptions{}) + fixResults, err := tc.fmt.Fix(tc.fc, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/pkg/fixer/fixes/nowhitespacecomment.go b/pkg/fixer/fixes/nowhitespacecomment.go index 56211ce4a..0f40895b0 100644 --- a/pkg/fixer/fixes/nowhitespacecomment.go +++ b/pkg/fixer/fixes/nowhitespacecomment.go @@ -2,6 +2,7 @@ package fixes import ( "bytes" + "errors" "slices" ) @@ -14,9 +15,8 @@ func (*NoWhitespaceComment) Name() string { func (*NoWhitespaceComment) Fix(fc *FixCandidate, opts *RuntimeOptions) ([]FixResult, error) { lines := bytes.Split(fc.Contents, []byte("\n")) - // this fix must have locations - if len(opts.Locations) == 0 { - return nil, nil + if opts == nil { + return nil, errors.New("missing runtime options") } fixed := false diff --git a/pkg/fixer/fixes/nowhitespacecomment_test.go b/pkg/fixer/fixes/nowhitespacecomment_test.go index 826f7f7d6..3a9813eaa 100644 --- a/pkg/fixer/fixes/nowhitespacecomment_test.go +++ b/pkg/fixer/fixes/nowhitespacecomment_test.go @@ -27,23 +27,10 @@ func TestNoWhitespaceComment(t *testing.T) { # this is a comment `), - fixExpected: false, - runtimeOptions: &RuntimeOptions{}, - }, - "no change made because no location": { - fc: &FixCandidate{ - Filename: "test.rego", - Contents: []byte(`package test\n - -#this is a comment -`), + fixExpected: false, + runtimeOptions: &RuntimeOptions{ + Locations: []ast.Location{}, }, - contentAfterFix: []byte(`package test\n - -#this is a comment -`), - fixExpected: false, - runtimeOptions: &RuntimeOptions{}, }, "single change": { fc: &FixCandidate{ @@ -67,28 +54,6 @@ func TestNoWhitespaceComment(t *testing.T) { }, }, }, - "bad change": { - fc: &FixCandidate{ - Filename: "test.rego", - Contents: []byte(`package test\n - -#this is a comment -`), - }, - contentAfterFix: []byte(`package test\n - -#this is a comment -`), - fixExpected: false, - runtimeOptions: &RuntimeOptions{ - Locations: []ast.Location{ - { - Row: 3, - Col: 9, // this is wrong and should not be fixed - }, - }, - }, - }, "many changes": { fc: &FixCandidate{ Filename: "test.rego", diff --git a/pkg/fixer/fixes/useassignmentoperator.go b/pkg/fixer/fixes/useassignmentoperator.go index 0bc0afac0..7b8b06f3e 100644 --- a/pkg/fixer/fixes/useassignmentoperator.go +++ b/pkg/fixer/fixes/useassignmentoperator.go @@ -2,6 +2,7 @@ package fixes import ( "bytes" + "errors" "slices" ) @@ -14,9 +15,8 @@ func (*UseAssignmentOperator) Name() string { func (*UseAssignmentOperator) Fix(fc *FixCandidate, opts *RuntimeOptions) ([]FixResult, error) { lines := bytes.Split(fc.Contents, []byte("\n")) - // this fix must have locations - if len(opts.Locations) == 0 { - return nil, nil + if opts == nil { + return nil, errors.New("missing runtime options") } fixed := false