From c31886ed0803c135afc08b072d17a45ee637ab0a Mon Sep 17 00:00:00 2001 From: Charlie Egan Date: Mon, 15 Apr 2024 13:31:43 +0100 Subject: [PATCH] lsp: Load config from parent dirs (#650) * lsp: Load config from parent dirs Fixes https://github.com/StyraInc/regal/issues/626 Signed-off-by: Charlie Egan * lsp: watcher error, use := Signed-off-by: Charlie Egan * lsp: Wait in server tests Server tests refactored a little to allow messages to be sent in a different order. Signed-off-by: Charlie Egan * lsp: PR review comments Signed-off-by: Charlie Egan * lsp: further reduce the config watch timeout Signed-off-by: Charlie Egan * lsp: server test use shared timeout Signed-off-by: Charlie Egan --------- Signed-off-by: Charlie Egan --- cmd/languageserver.go | 1 + go.mod | 1 + go.sum | 2 + internal/lsp/clients/clients.go | 5 + internal/lsp/config/watcher.go | 132 ++++++++ internal/lsp/config/watcher_test.go | 71 ++++ internal/lsp/server.go | 111 ++++--- internal/lsp/server_test.go | 481 ++++++++++++++++++---------- 8 files changed, 597 insertions(+), 207 deletions(-) create mode 100644 internal/lsp/config/watcher.go create mode 100644 internal/lsp/config/watcher_test.go diff --git a/cmd/languageserver.go b/cmd/languageserver.go index a0d2c11b..012ffe8b 100644 --- a/cmd/languageserver.go +++ b/cmd/languageserver.go @@ -43,6 +43,7 @@ func init() { go ls.StartDiagnosticsWorker(ctx) go ls.StartHoverWorker(ctx) go ls.StartCommandWorker(ctx) + go ls.StartConfigWorker(ctx) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) diff --git a/go.mod b/go.mod index 496db943..f0c6ce0d 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ toolchain go1.22.0 require ( dario.cat/mergo v1.0.0 github.com/fatih/color v1.16.0 + github.com/fsnotify/fsnotify v1.7.0 github.com/gobwas/glob v0.2.3 github.com/google/go-cmp v0.6.0 github.com/olekukonko/tablewriter v0.0.5 diff --git a/go.sum b/go.sum index b2878924..c55e72b0 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8 github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7DlmewI= github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/gdamore/encoding v1.0.0 h1:+7OoQ1Bc6eTm5niUzBa0Ctsh6JbMW6Ra+YNuAtDBdko= github.com/gdamore/encoding v1.0.0/go.mod h1:alR0ol34c49FCSBLjhosxzcPHQbf2trDkoo5dl+VrEg= github.com/gdamore/tcell v1.1.4/go.mod h1:Hjvr+Ofd+gLglo7RYKxxnzCBmev3BzsS67MebKS4zMM= diff --git a/internal/lsp/clients/clients.go b/internal/lsp/clients/clients.go index 147525ea..8ca420f9 100644 --- a/internal/lsp/clients/clients.go +++ b/internal/lsp/clients/clients.go @@ -7,9 +7,14 @@ type Identifier int const ( IdentifierGeneric Identifier = iota IdentifierVSCode + IdentifierGoTest ) func DetermineClientIdentifier(clientName string) Identifier { + if clientName == "go test" { + return IdentifierGoTest + } + if clientName == "Visual Studio Code" { return IdentifierVSCode } diff --git a/internal/lsp/config/watcher.go b/internal/lsp/config/watcher.go new file mode 100644 index 00000000..08b8d52e --- /dev/null +++ b/internal/lsp/config/watcher.go @@ -0,0 +1,132 @@ +package config + +import ( + "context" + "fmt" + "io" + + "github.com/fsnotify/fsnotify" +) + +type Watcher struct { + Reload chan string + Drop chan struct{} + + path string + pathUpdates chan string + + fsWatcher *fsnotify.Watcher + + errorWriter io.Writer +} + +type WatcherOpts struct { + ErrorWriter io.Writer + Path string +} + +func NewWatcher(opts *WatcherOpts) *Watcher { + w := &Watcher{ + Reload: make(chan string, 1), + Drop: make(chan struct{}, 1), + pathUpdates: make(chan string, 1), + } + + if opts != nil { + w.errorWriter = opts.ErrorWriter + w.path = opts.Path + } + + return w +} + +func (w *Watcher) Start(ctx context.Context) error { + err := w.Stop() + if err != nil { + return fmt.Errorf("failed to stop existing watcher: %w", err) + } + + w.fsWatcher, err = fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("failed to create fsnotify watcher: %w", err) + } + + go func() { + w.loop(ctx) + }() + + return nil +} + +func (w *Watcher) loop(ctx context.Context) { + for { + select { + case path := <-w.pathUpdates: + if w.path != "" { + err := w.fsWatcher.Remove(w.path) + if err != nil { + fmt.Fprintf(w.errorWriter, "failed to remove existing watch: %v\n", err) + } + } + + err := w.fsWatcher.Add(path) + if err != nil { + fmt.Fprintf(w.errorWriter, "failed to add watch: %v\n", err) + } + + w.path = path + + // when the path itself is changed, then this is an event too + w.Reload <- path + case event, ok := <-w.fsWatcher.Events: + if !ok { + fmt.Fprintf(w.errorWriter, "config watcher event channel closed\n") + + return + } + + if event.Has(fsnotify.Write) { + w.Reload <- event.Name + } + + if event.Has(fsnotify.Remove) || event.Has(fsnotify.Rename) { + w.path = "" + w.Drop <- struct{}{} + } + case err := <-w.fsWatcher.Errors: + fmt.Fprintf(w.errorWriter, "config watcher error: %v\n", err) + case <-ctx.Done(): + err := w.Stop() + if err != nil { + fmt.Fprintf(w.errorWriter, "failed to stop watcher: %v\n", err) + } + + return + } + } +} + +func (w *Watcher) Watch(configFilePath string) { + w.pathUpdates <- configFilePath +} + +func (w *Watcher) Stop() error { + if w.fsWatcher != nil { + err := w.fsWatcher.Close() + if err != nil { + return fmt.Errorf("failed to close fsnotify watcher: %w", err) + } + + return nil + } + + return nil +} + +func (w *Watcher) IsWatching() bool { + if w.fsWatcher == nil { + return false + } + + return len(w.fsWatcher.WatchList()) > 0 +} diff --git a/internal/lsp/config/watcher_test.go b/internal/lsp/config/watcher_test.go new file mode 100644 index 00000000..cb1401ba --- /dev/null +++ b/internal/lsp/config/watcher_test.go @@ -0,0 +1,71 @@ +package config + +import ( + "context" + "os" + "testing" + "time" +) + +func TestWatcher(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + + configFilePath := tempDir + "/config.yaml" + + configFileContents := `--- +foo: bar +` + + err := os.WriteFile(configFilePath, []byte(configFileContents), 0o600) + if err != nil { + t.Fatal(err) + } + + watcher := NewWatcher(&WatcherOpts{ErrorWriter: os.Stderr}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + err = watcher.Start(ctx) + if err != nil { + t.Errorf("failed to start watcher: %v", err) + } + }() + + watcher.Watch(configFilePath) + + select { + case <-watcher.Reload: + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for initial config event") + } + + newConfigFileContents := `--- +foo: baz +` + + err = os.WriteFile(configFilePath, []byte(newConfigFileContents), 0o600) + if err != nil { + t.Fatal(err) + } + + select { + case <-watcher.Reload: + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for config event") + } + + err = os.Rename(configFilePath, configFilePath+".new") + if err != nil { + t.Fatal(err) + } + + select { + case <-watcher.Drop: + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout waiting for config drop event") + } +} diff --git a/internal/lsp/server.go b/internal/lsp/server.go index b42c4c73..45beac8c 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -18,6 +18,7 @@ import ( "github.com/open-policy-agent/opa/format" "github.com/styrainc/regal/internal/lsp/clients" + lsconfig "github.com/styrainc/regal/internal/lsp/config" "github.com/styrainc/regal/internal/lsp/uri" "github.com/styrainc/regal/pkg/config" ) @@ -25,6 +26,9 @@ import ( const ( methodTextDocumentPublishDiagnostics = "textDocument/publishDiagnostics" methodWorkspaceApplyEdit = "workspace/applyEdit" + + ruleNameOPAFmt = "opa-fmt" + ruleNameUseRegoV1 = "use-rego-v1" ) type LanguageServerOptions struct { @@ -39,6 +43,7 @@ func NewLanguageServer(opts *LanguageServerOptions) *LanguageServer { diagnosticRequestWorkspace: make(chan string, 10), builtinsPositionFile: make(chan fileUpdateEvent, 10), commandRequest: make(chan ExecuteCommandParams, 10), + configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{ErrorWriter: opts.ErrorLog}), } return ls @@ -51,6 +56,7 @@ type LanguageServer struct { errorLog io.Writer + configWatcher *lsconfig.Watcher loadedConfig *config.Config loadedConfigLock sync.Mutex @@ -185,8 +191,8 @@ func (l *LanguageServer) StartDiagnosticsWorker(ctx context.Context) { } // send diagnostics for all files - for uri := range l.cache.GetAllFiles() { - err = l.sendFileDiagnostics(ctx, uri) + for fileURI := range l.cache.GetAllFiles() { + err = l.sendFileDiagnostics(ctx, fileURI) if err != nil { l.logError(fmt.Errorf("failed to send diagnostic: %w", err)) } @@ -241,6 +247,55 @@ func (l *LanguageServer) formatToEdits(params ExecuteCommandParams, opts format. return ComputeEdits(oldContent, newContent), target, nil } +func (l *LanguageServer) StartConfigWorker(ctx context.Context) { + err := l.configWatcher.Start(ctx) + if err != nil { + l.logError(fmt.Errorf("failed to start config watcher: %w", err)) + + return + } + + for { + select { + case <-ctx.Done(): + return + case path := <-l.configWatcher.Reload: + configFile, err := os.Open(path) + if err != nil { + l.logError(fmt.Errorf("failed to open config file: %w", err)) + + continue + } + + var loadedConfig config.Config + + err = yaml.NewDecoder(configFile).Decode(&loadedConfig) + if err != nil && !errors.Is(err, io.EOF) { + l.logError(fmt.Errorf("failed to reload config: %w", err)) + + return + } + + // if the config is now blank, then we need to clear it + l.loadedConfigLock.Lock() + if errors.Is(err, io.EOF) { + l.loadedConfig = nil + } else { + l.loadedConfig = &loadedConfig + } + l.loadedConfigLock.Unlock() + + l.diagnosticRequestWorkspace <- "config file changed" + case <-l.configWatcher.Drop: + l.loadedConfigLock.Lock() + l.loadedConfig = nil + l.loadedConfigLock.Unlock() + + l.diagnosticRequestWorkspace <- "config file dropped" + } + } +} + func (l *LanguageServer) StartCommandWorker(ctx context.Context) { for { select { @@ -434,7 +489,7 @@ func (l *LanguageServer) handleTextDocumentCodeAction( for _, diag := range params.Context.Diagnostics { switch diag.Code { - case "opa-fmt": + case ruleNameOPAFmt: actions = append(actions, CodeAction{ Title: "Format using opa fmt", Kind: "quickfix", @@ -442,7 +497,7 @@ func (l *LanguageServer) handleTextDocumentCodeAction( IsPreferred: true, Command: FmtCommand([]string{params.TextDocument.URI}), }) - case "use-rego-v1": + case ruleNameUseRegoV1: actions = append(actions, CodeAction{ Title: "Format for Rego v1 using opa fmt", Kind: "quickfix", @@ -776,10 +831,9 @@ func (l *LanguageServer) handleInitialize( return nil, fmt.Errorf("failed to load workspace contents: %w", err) } - // attempt to load the config as it is found on disk - file, err := config.FindConfig(strings.TrimPrefix(l.clientRootURI, "file://")) + configFile, err := config.FindConfig(uri.ToPath(l.clientIdentifier, l.clientRootURI)) if err == nil { - l.reloadConfig(file, false) + l.configWatcher.Watch(configFile.Name()) } } @@ -820,39 +874,16 @@ func (l *LanguageServer) loadWorkspaceContents() error { return nil } -func (l *LanguageServer) reloadConfig(configReader io.Reader, runWorkspaceDiagnostics bool) { - l.loadedConfigLock.Lock() - defer l.loadedConfigLock.Unlock() - - var loadedConfig config.Config - - err := yaml.NewDecoder(configReader).Decode(&loadedConfig) - if err != nil && !errors.Is(err, io.EOF) { - l.logError(fmt.Errorf("failed to reload config: %w", err)) - - return - } - - // if the config is now blank, then we need to clear it - if errors.Is(err, io.EOF) { - l.loadedConfig = nil - } else { - l.loadedConfig = &loadedConfig - } - - // this can be set to false by callers to disable the running of diagnostics for the whole workspace. - // this is intended to be used at start up when a workspace run is already going to be taking place. - if runWorkspaceDiagnostics { - l.diagnosticRequestWorkspace <- "config file changed" - } -} - func (l *LanguageServer) handleInitialized( _ context.Context, _ *jsonrpc2.Conn, _ *jsonrpc2.Request, ) (result any, err error) { - l.diagnosticRequestWorkspace <- "initialized" + // if running without config, then we should send the diagnostic request now + // otherwise it'll happen when the config is loaded + if !l.configWatcher.IsWatching() { + l.diagnosticRequestWorkspace <- "initialized" + } return struct{}{}, nil } @@ -890,16 +921,6 @@ func (l *LanguageServer) handleWorkspaceDidChangeWatchedFiles( continue } - if strings.HasSuffix(change.URI, "/.regal/config.yaml") { - // attempt to load the config as it is found on disk - file, err := os.Open(strings.TrimPrefix(change.URI, "file://")) - if err == nil { - l.reloadConfig(file, true) - } - - continue - } - regoFiles = append(regoFiles, change.URI) } diff --git a/internal/lsp/server_test.go b/internal/lsp/server_test.go index 50adf8a4..e6ac7842 100644 --- a/internal/lsp/server_test.go +++ b/internal/lsp/server_test.go @@ -14,6 +14,10 @@ import ( "github.com/sourcegraph/jsonrpc2" ) +const mainRegoFileName = "/main.rego" + +const defaultTimeout = 3 * time.Second + // InMemoryReadWriteCloser is an in-memory implementation of jsonrpc2.ReadWriteCloser. type InMemoryReadWriteCloser struct { Buffer bytes.Buffer @@ -41,14 +45,14 @@ const fileURIScheme = "file://" // language server my making updates to both and validating that the correct diagnostics are sent to the client. // //nolint:gocognit,gocyclo,maintidx -func TestLanguageServerSingleFileWithConfig(t *testing.T) { +func TestLanguageServerSingleFile(t *testing.T) { t.Parallel() var err error // set up the workspace content with some example rego and regal config tempDir := t.TempDir() - mainRegoURI := fileURIScheme + tempDir + "/main.rego" + mainRegoURI := fileURIScheme + tempDir + mainRegoFileName err = os.MkdirAll(filepath.Join(tempDir, ".regal"), 0o755) if err != nil { @@ -81,11 +85,19 @@ allow = true ErrorLog: os.Stderr, }) go ls.StartDiagnosticsWorker(ctx) + go ls.StartConfigWorker(ctx) - receivedMessages := make(chan jsonrpc2.Request, 1) - testHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { + receivedMessages := make(chan FileDiagnostics, 1) + clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { if req.Method == methodTextDocumentPublishDiagnostics { - receivedMessages <- *req + var requestData FileDiagnostics + + err = json.Unmarshal(*req.Params, &requestData) + if err != nil { + t.Fatalf("failed to unmarshal diagnostics: %s", err) + } + + receivedMessages <- requestData return struct{}{}, nil } @@ -95,29 +107,15 @@ allow = true return struct{}{}, nil } - netConnServer, netConnClient := net.Pipe() - defer netConnServer.Close() - defer netConnClient.Close() - - connServer := jsonrpc2.NewConn( - ctx, - jsonrpc2.NewBufferedStream(netConnServer, jsonrpc2.VSCodeObjectCodec{}), - jsonrpc2.HandlerWithError(ls.Handle), - ) - defer connServer.Close() - - connClient := jsonrpc2.NewConn( - ctx, - jsonrpc2.NewBufferedStream(netConnClient, jsonrpc2.VSCodeObjectCodec{}), - jsonrpc2.HandlerWithError(testHandler), - ) - defer connClient.Close() + connServer, connClient, cleanup := createConnections(ctx, ls.Handle, clientHandler) + defer cleanup() ls.SetConn(connServer) // 1. Client sends initialize request request := InitializeParams{ - RootURI: fileURIScheme + tempDir, + RootURI: fileURIScheme + tempDir, + ClientInfo: Client{Name: "go test"}, } var response InitializeResult @@ -163,46 +161,21 @@ allow = true } // validate that the client received a diagnostics notification for the file - select { - case request := <-receivedMessages: - if request.Method != methodTextDocumentPublishDiagnostics { - t.Fatalf("expected diagnostics to be sent, got %v", request) - } - - // validate that the diagnostics are correct - var requestData FileDiagnostics - - err = json.Unmarshal(*request.Params, &requestData) - if err != nil { - t.Fatalf("failed to unmarshal diagnostics: %s", err) - } - - if requestData.URI != mainRegoURI { - t.Fatalf("expected diagnostics to be sent for main.rego, got %s", requestData.URI) + timeout := time.NewTimer(defaultTimeout) + defer timeout.Stop() + + for { + var success bool + select { + case requestData := <-receivedMessages: + success = testRequestDataCodes(t, requestData, mainRegoURI, []string{"opa-fmt", "use-assignment-operator"}) + case <-timeout.C: + t.Fatalf("timed out waiting for file diagnostics to be sent") } - if len(requestData.Items) != 2 { - t.Fatalf("expected 2 diagnostics, got %d, %v", len(requestData.Items), requestData) + if success { + break } - - expectedItems := map[string]bool{ - "opa-fmt": false, - "use-assignment-operator": false, - } - - for _, item := range requestData.Items { - t.Log(item.Code) - - expectedItems[item.Code] = true - } - - for item, found := range expectedItems { - if !found { - t.Fatalf("expected diagnostic %s to be found", item) - } - } - case <-time.After(3 * time.Second): - t.Fatalf("timed out waiting for file diagnostics to be sent") } // 3. Client sends textDocument/didChange notification with new contents for main.rego @@ -225,33 +198,21 @@ allow := true } // validate that the client received a new diagnostics notification for the file - select { - case request := <-receivedMessages: - if request.Method != methodTextDocumentPublishDiagnostics { - t.Fatalf("expected diagnostics to be sent, got %v", request) + timeout = time.NewTimer(defaultTimeout) + defer timeout.Stop() + + for { + var success bool + select { + case requestData := <-receivedMessages: + success = testRequestDataCodes(t, requestData, mainRegoURI, []string{"opa-fmt"}) + case <-timeout.C: + t.Fatalf("timed out waiting for file diagnostics to be sent") } - // validate that the diagnostics are correct - var requestData FileDiagnostics - - err = json.Unmarshal(*request.Params, &requestData) - if err != nil { - t.Fatalf("failed to unmarshal diagnostics: %s", err) + if success { + break } - - if requestData.URI != mainRegoURI { - t.Fatalf("expected diagnostics to be sent for main.rego, got %s", requestData.URI) - } - - if len(requestData.Items) != 1 { - t.Fatalf("expected 1 diagnostic, got %d", len(requestData.Items)) - } - - if requestData.Items[0].Code != "opa-fmt" { - t.Fatalf("expected diagnostic to be opa-fmt, got %s", requestData.Items[0].Code) - } - case <-time.After(3 * time.Second): - t.Fatalf("timed out waiting for file diagnostics to be sent") } // 4. Client sends workspace/didChangeWatchedFiles notification with new config @@ -267,42 +228,34 @@ rules: t.Fatalf("failed to write new config file: %s", err) } - err = connClient.Call(ctx, "workspace/didChangeWatchedFiles", WorkspaceDidChangeWatchedFilesParams{ - Changes: []FileEvent{ - { - Type: 1, - URI: fileURIScheme + tempDir + "/.regal/config.yaml", - }, - }, - }, nil) - if err != nil { - t.Fatalf("failed to send didChangeWatchedFiles notification: %s", err) - } - // validate that the client received a new, empty diagnostics notification for the file - select { - case request := <-receivedMessages: - if request.Method != methodTextDocumentPublishDiagnostics { - t.Fatalf("expected diagnostics to be sent, got %v", request) - } + timeout = time.NewTimer(defaultTimeout) + defer timeout.Stop() - // validate that the diagnostics are correct - var requestData FileDiagnostics + for { + var success bool + select { + case requestData := <-receivedMessages: + if requestData.URI != mainRegoURI { + t.Logf("expected diagnostics to be sent for main.rego, got %s", requestData.URI) - err = json.Unmarshal(*request.Params, &requestData) - if err != nil { - t.Fatalf("failed to unmarshal diagnostics: %s", err) - } + break + } + + if len(requestData.Items) != 0 { + t.Logf("expected 0 diagnostic, got %d", len(requestData.Items)) + + break + } - if requestData.URI != mainRegoURI { - t.Fatalf("expected diagnostics to be sent for main.rego, got %s", requestData.URI) + success = testRequestDataCodes(t, requestData, mainRegoURI, []string{}) + case <-timeout.C: + t.Fatalf("timed out waiting for file diagnostics to be sent") } - if len(requestData.Items) != 0 { - t.Fatalf("expected 1 diagnostic, got %d", len(requestData.Items)) + if success { + break } - case <-time.After(3 * time.Second): - t.Fatalf("timed out waiting for file diagnostics to be sent") } } @@ -376,11 +329,12 @@ ignore: ErrorLog: os.Stderr, }) go ls.StartDiagnosticsWorker(ctx) + go ls.StartConfigWorker(ctx) authzFileMessages := make(chan FileDiagnostics, 1) adminsFileMessages := make(chan FileDiagnostics, 1) ignoredFileMessages := make(chan FileDiagnostics, 1) - testHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { + clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { if req.Method == "textDocument/publishDiagnostics" { var requestData FileDiagnostics @@ -409,29 +363,15 @@ ignore: return struct{}{}, nil } - netConnServer, netConnClient := net.Pipe() - defer netConnServer.Close() - defer netConnClient.Close() - - connServer := jsonrpc2.NewConn( - ctx, - jsonrpc2.NewBufferedStream(netConnServer, jsonrpc2.VSCodeObjectCodec{}), - jsonrpc2.HandlerWithError(ls.Handle), - ) - defer connServer.Close() - - connClient := jsonrpc2.NewConn( - ctx, - jsonrpc2.NewBufferedStream(netConnClient, jsonrpc2.VSCodeObjectCodec{}), - jsonrpc2.HandlerWithError(testHandler), - ) - defer connClient.Close() + connServer, connClient, cleanup := createConnections(ctx, ls.Handle, clientHandler) + defer cleanup() ls.SetConn(connServer) // 1. Client sends initialize request request := InitializeParams{ - RootURI: fileURIScheme + tempDir, + RootURI: fileURIScheme + tempDir, + ClientInfo: Client{Name: "go test"}, } var response InitializeResult @@ -462,7 +402,7 @@ ignore: if requestData.Items[0].Code != "prefer-package-imports" { t.Fatalf("expected diagnostic to be prefer-package-imports, got %s", requestData.Items[0].Code) } - case <-time.After(3 * time.Second): + case <-time.After(defaultTimeout): t.Fatalf("timed out waiting for authz.rego diagnostics to be sent") } @@ -480,7 +420,7 @@ ignore: if requestData.Items[0].Code != "use-assignment-operator" { t.Fatalf("expected diagnostic to be use-assignment-operator, got %s", requestData.Items[0].Code) } - case <-time.After(3 * time.Second): + case <-time.After(defaultTimeout): t.Fatalf("timed out waiting for admins.rego diagnostics to be sent") } @@ -494,7 +434,7 @@ ignore: if len(requestData.Items) != 0 { t.Fatalf("expected 0 diagnostics, got %d, %v", len(requestData.Items), requestData) } - case <-time.After(3 * time.Second): + case <-time.After(defaultTimeout): t.Fatalf("timed out waiting for ignored/foo.rego diagnostics to be sent") } @@ -528,44 +468,261 @@ allow if input.user in admins.users // here we wait to receive a diagnostics notification for authz.rego with no diagnostics items, the file diagnostics // can arrive first which can still contain the old diagnostics items - ok := make(chan bool, 1) + timeout := time.NewTimer(defaultTimeout) + defer timeout.Stop() + + for { + var success bool + select { + case diags := <-authzFileMessages: + success = testRequestDataCodes(t, diags, authzRegoURI, []string{}) + case <-timeout.C: + t.Fatalf("timed out waiting for file diagnostics to be sent") + } - go func() { - for { - requestData := <-authzFileMessages - if requestData.URI != authzRegoURI { - t.Logf("expected diagnostics to be sent for authz.rego, got %s", requestData.URI) - } + if success { + break + } + } - if len(requestData.Items) != 0 { - continue + // we should also receive a diagnostics notification for admins.rego, since it is in the workspace, but it has not + // been changed, so the violations should be the same + timeout = time.NewTimer(defaultTimeout) + defer timeout.Stop() + + for { + var success bool + select { + case requestData := <-adminsFileMessages: + success = testRequestDataCodes(t, requestData, adminsRegoURI, []string{"use-assignment-operator"}) + case <-timeout.C: + t.Fatalf("timed out waiting for file diagnostics to be sent") + } + + if success { + break + } + } +} + +// TestLanguageServerParentDirConfig tests that regal config is loaded as it is for the +// Regal CLI, and that config files in a parent directory are loaded correctly +// even when the workspace is a child directory. +func TestLanguageServerParentDirConfig(t *testing.T) { + t.Parallel() + + var err error + + // this is the top level directory for the test + parentDir := t.TempDir() + // childDir will be the directory that the client is using as its workspace + childDirName := "child" + childDir := filepath.Join(parentDir, childDirName) + + for _, dir := range []string{childDirName, ".regal"} { + err = os.MkdirAll(filepath.Join(parentDir, dir), 0o755) + if err != nil { + t.Fatalf("failed to create %q directory under parent: %s", dir, err) + } + } + + mainRegoContents := `package main + +import rego.v1 +allow := true +` + + files := map[string]string{ + childDirName + mainRegoFileName: mainRegoContents, + ".regal/config.yaml": `rules: + style: + opa-fmt: + level: error +`, + } + + for f, fc := range files { + err = os.WriteFile(filepath.Join(parentDir, f), []byte(fc), 0o600) + if err != nil { + t.Fatalf("failed to write file %s: %s", f, err) + } + } + + // mainRegoFileURI is used throughout the test to refer to the main.rego file + // and so it is defined here for convenience + mainRegoFileURI := fileURIScheme + childDir + mainRegoFileName + + // set up the server and client connections + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ls := NewLanguageServer(&LanguageServerOptions{ + ErrorLog: os.Stderr, + }) + go ls.StartDiagnosticsWorker(ctx) + go ls.StartConfigWorker(ctx) + + receivedMessages := make(chan FileDiagnostics, 1) + clientHandler := func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error) { + if req.Method == methodTextDocumentPublishDiagnostics { + var requestData FileDiagnostics + + err = json.Unmarshal(*req.Params, &requestData) + if err != nil { + t.Fatalf("failed to unmarshal diagnostics: %s", err) } - ok <- true + receivedMessages <- requestData - return + return struct{}{}, nil } - }() - select { - case <-ok: - case <-time.After(3 * time.Second): - t.Fatalf("timed out waiting for authz.rego diagnostics to be sent") + t.Logf("unexpected request from server: %v", req) + + return struct{}{}, nil } - // we should also receive a diagnostics notification for admins.rego, since it is in the workspace, but it has not - // been changed, so the violations should be the same - select { - case requestData := <-adminsFileMessages: - if requestData.URI != adminsRegoURI { - t.Fatalf("expected diagnostics to be sent for admins.rego, got %s", requestData.URI) + connServer, connClient, cleanup := createConnections(ctx, ls.Handle, clientHandler) + defer cleanup() + + ls.SetConn(connServer) + + // Client sends initialize request + request := InitializeParams{ + RootURI: fileURIScheme + childDir, + ClientInfo: Client{Name: "go test"}, + } + + var response InitializeResult + + err = connClient.Call(ctx, "initialize", request, &response) + if err != nil { + t.Fatalf("failed to send initialize request: %s", err) + } + + if ls.clientRootURI != request.RootURI { + t.Fatalf("expected client root URI to be %s, got %s", request.RootURI, ls.clientRootURI) + } + + // Client sends initialized notification + // the response to the call is expected to be empty and is ignored + err = connClient.Call(ctx, "initialized", struct{}{}, nil) + if err != nil { + t.Fatalf("failed to send initialized notification: %s", err) + } + + timeout := time.NewTimer(defaultTimeout) + defer timeout.Stop() + + for { + var success bool + select { + case requestData := <-receivedMessages: + success = testRequestDataCodes(t, requestData, mainRegoFileURI, []string{"opa-fmt"}) + case <-timeout.C: + t.Fatalf("timed out waiting for file diagnostics to be sent") } - // this file is unchanged - if len(requestData.Items) != 1 { - t.Fatalf("expected 1 diagnostics, got %d", len(requestData.Items)) + if success { + break } - case <-time.After(3 * time.Second): - t.Fatalf("timed out waiting for admins.rego diagnostics to be sent") } + + // User updates config file contents in parent directory that is not + // part of the workspace + newConfigContents := `rules: + style: + opa-fmt: + level: ignore +` + + err = os.WriteFile(filepath.Join(parentDir, ".regal/config.yaml"), []byte(newConfigContents), 0o600) + if err != nil { + t.Fatalf("failed to write new config file: %s", err) + } + + // validate that the client received a new, empty diagnostics notification for the file + timeout = time.NewTimer(defaultTimeout) + defer timeout.Stop() + + for { + var success bool + select { + case requestData := <-receivedMessages: + success = testRequestDataCodes(t, requestData, mainRegoFileURI, []string{}) + case <-timeout.C: + t.Fatalf("timed out waiting for file diagnostics to be sent") + } + + if success { + break + } + } +} + +func testRequestDataCodes(t *testing.T, requestData FileDiagnostics, uri string, codes []string) bool { + t.Helper() + + if requestData.URI != uri { + t.Log("expected diagnostics to be sent for", uri, "got", requestData.URI) + + return false + } + + if len(requestData.Items) != len(codes) { + t.Log("expected", len(codes), "diagnostics, got", len(requestData.Items)) + + return false + } + + for _, v := range codes { + found := false + foundItems := make([]string, 0, len(requestData.Items)) + + for _, i := range requestData.Items { + foundItems = append(foundItems, i.Code) + + if i.Code == v { + found = true + + break + } + } + + if !found { + t.Log("expected diagnostic", v, "not found in", foundItems) + + return false + } + } + + return true +} + +func createConnections( + ctx context.Context, + serverHandler, clientHandler func(_ context.Context, _ *jsonrpc2.Conn, req *jsonrpc2.Request) (result any, err error), +) (*jsonrpc2.Conn, *jsonrpc2.Conn, func()) { + netConnServer, netConnClient := net.Pipe() + + connServer := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewBufferedStream(netConnServer, jsonrpc2.VSCodeObjectCodec{}), + jsonrpc2.HandlerWithError(serverHandler), + ) + + connClient := jsonrpc2.NewConn( + ctx, + jsonrpc2.NewBufferedStream(netConnClient, jsonrpc2.VSCodeObjectCodec{}), + jsonrpc2.HandlerWithError(clientHandler), + ) + + cleanup := func() { + netConnServer.Close() + netConnClient.Close() + connServer.Close() + connClient.Close() + } + + return connServer, connClient, cleanup }