diff --git a/internal/lsp/cache.go b/internal/lsp/cache.go index f40b0b0f..bc015125 100644 --- a/internal/lsp/cache.go +++ b/internal/lsp/cache.go @@ -2,7 +2,6 @@ package lsp import ( "fmt" - "net/url" "os" "sync" @@ -196,17 +195,8 @@ func (c *Cache) Delete(uri string) { c.diagnosticsParseMu.Unlock() } -func updateCacheForURIFromDisk(cache *Cache, uri string) (string, error) { - parsedURI, err := url.Parse(uri) - if err != nil { - return "", fmt.Errorf("failed to parse URI: %w", err) - } - - if parsedURI.Scheme != "file" { - return "", fmt.Errorf("only file:// URIs are supported, got %q", parsedURI.String()) - } - - content, err := os.ReadFile(parsedURI.Path) +func updateCacheForURIFromDisk(cache *Cache, uri, path string) (string, error) { + content, err := os.ReadFile(path) if err != nil { return "", fmt.Errorf("failed to read file: %w", err) } diff --git a/internal/lsp/clients/clients.go b/internal/lsp/clients/clients.go new file mode 100644 index 00000000..147525ea --- /dev/null +++ b/internal/lsp/clients/clients.go @@ -0,0 +1,18 @@ +package clients + +// Identifier represent different supported clients and can be used to toggle or change +// server behavior based on the client. +type Identifier int + +const ( + IdentifierGeneric Identifier = iota + IdentifierVSCode +) + +func DetermineClientIdentifier(clientName string) Identifier { + if clientName == "Visual Studio Code" { + return IdentifierVSCode + } + + return IdentifierGeneric +} diff --git a/internal/lsp/messages.go b/internal/lsp/messages.go index 8924e6d8..3dd72d6f 100644 --- a/internal/lsp/messages.go +++ b/internal/lsp/messages.go @@ -120,6 +120,15 @@ type TextDocumentIdentifier struct { URI string `json:"uri"` } +type TextDocumentDidChangeParams struct { + TextDocument TextDocumentIdentifier `json:"textDocument"` + ContentChanges []TextDocumentContentChangeEvent `json:"contentChanges"` +} + +type TextDocumentContentChangeEvent struct { + Text string `json:"text"` +} + type Diagnostic struct { Range Range `json:"range"` Message string `json:"message"` diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 419bcaa6..1dc91dec 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "net/url" "os" "path/filepath" "strings" @@ -15,6 +14,8 @@ import ( "github.com/sourcegraph/jsonrpc2" "gopkg.in/yaml.v3" + "github.com/styrainc/regal/internal/lsp/clients" + "github.com/styrainc/regal/internal/lsp/uri" "github.com/styrainc/regal/pkg/config" ) @@ -51,7 +52,16 @@ type LanguageServer struct { diagnosticRequestFile chan fileDiagnosticRequiredEvent diagnosticRequestWorkspace chan string - clientRootURI string + clientRootURI string + clientIdentifier clients.Identifier +} + +// fileDiagnosticRequiredEvent is sent to the diagnosticRequestFile channel when +// diagnostics are required for a file. +type fileDiagnosticRequiredEvent struct { + Reason string + URI string + Content string } func (l *LanguageServer) Handle( @@ -256,14 +266,6 @@ func (l *LanguageServer) logOutboundMessage(method string, message any) { } } -// fileDiagnosticRequiredEvent is sent to the diagnosticRequestFile channel when -// diagnostics are required for a file. -type fileDiagnosticRequiredEvent struct { - Reason string - URI string - Content string -} - func (l *LanguageServer) handleTextDocumentDidOpen( _ context.Context, _ *jsonrpc2.Conn, @@ -283,15 +285,6 @@ func (l *LanguageServer) handleTextDocumentDidOpen( return struct{}{}, nil } -type TextDocumentDidChangeParams struct { - TextDocument TextDocumentIdentifier `json:"textDocument"` - ContentChanges []TextDocumentContentChangeEvent `json:"contentChanges"` -} - -type TextDocumentContentChangeEvent struct { - Text string `json:"text"` -} - func (l *LanguageServer) handleTextDocumentDidChange( _ context.Context, _ *jsonrpc2.Conn, @@ -322,7 +315,11 @@ func (l *LanguageServer) handleWorkspaceDidCreateFiles( } for _, createOp := range params.Files { - _, err = updateCacheForURIFromDisk(l.cache, createOp.URI) + _, err = updateCacheForURIFromDisk( + l.cache, + uri.FromPath(l.clientIdentifier, createOp.URI), + uri.ToPath(l.clientIdentifier, createOp.URI), + ) if err != nil { return nil, fmt.Errorf("failed to update cache for uri %q: %w", createOp.URI, err) } @@ -369,7 +366,11 @@ func (l *LanguageServer) handleWorkspaceDidRenameFiles( } for _, renameOp := range params.Files { - content, err := updateCacheForURIFromDisk(l.cache, renameOp.NewURI) + content, err := updateCacheForURIFromDisk( + l.cache, + uri.FromPath(l.clientIdentifier, renameOp.NewURI), + uri.ToPath(l.clientIdentifier, renameOp.NewURI), + ) if err != nil { return nil, fmt.Errorf("failed to update cache for uri %q: %w", renameOp.NewURI, err) } @@ -418,6 +419,13 @@ func (l *LanguageServer) handleInitialize( } l.clientRootURI = params.RootURI + l.clientIdentifier = clients.DetermineClientIdentifier(params.ClientInfo.Name) + + if l.clientIdentifier == clients.IdentifierGeneric { + l.log( + "Unable to match client identifier for initializing client, using generic functionality: " + params.ClientInfo.Name, + ) + } regoFilter := FileOperationFilter{ Scheme: "file", @@ -453,15 +461,12 @@ func (l *LanguageServer) handleInitialize( }, } - folderURI, err := url.Parse(l.clientRootURI) - if err != nil { - return nil, fmt.Errorf("failed to parse URI: %w", err) - } + workspaceRootPath := uri.ToPath(l.clientIdentifier, l.clientRootURI) // load the rego source files into the cache - err = filepath.WalkDir(folderURI.Path, func(path string, d os.DirEntry, err error) error { + err = filepath.WalkDir(workspaceRootPath, func(path string, d os.DirEntry, err error) error { if err != nil { - return fmt.Errorf("failed to walk workspace dir %q: %w", folderURI.Path, err) + return fmt.Errorf("failed to walk workspace dir %q: %w", d.Name(), err) } // TODO(charlieegan3): make this configurable for things like .rq etc? @@ -469,12 +474,14 @@ func (l *LanguageServer) handleInitialize( return nil } - _, err = updateCacheForURIFromDisk(l.cache, "file://"+path) + fileURI := uri.FromPath(l.clientIdentifier, path) + + _, err = updateCacheForURIFromDisk(l.cache, fileURI, path) if err != nil { return fmt.Errorf("failed to update cache for uri %q: %w", path, err) } - _, err = updateParse(l.cache, "file://"+path) + _, err = updateParse(l.cache, fileURI) if err != nil { return fmt.Errorf("failed to update parse: %w", err) } @@ -482,7 +489,7 @@ func (l *LanguageServer) handleInitialize( return nil }) if err != nil { - return nil, fmt.Errorf("failed to walk workspace dir %q: %w", folderURI.Path, err) + return nil, fmt.Errorf("failed to walk workspace dir %q: %w", workspaceRootPath, err) } // attempt to load the config as it is found on disk diff --git a/internal/lsp/uri/uri.go b/internal/lsp/uri/uri.go new file mode 100644 index 00000000..f6316fb6 --- /dev/null +++ b/internal/lsp/uri/uri.go @@ -0,0 +1,45 @@ +package uri + +import ( + "path/filepath" + "strings" + + "github.com/styrainc/regal/internal/lsp/clients" +) + +// FromPath converts a file path to a URI for a given client. +// Since clients expect URIs to be in a specific format, this function +// will convert the path to the appropriate format for the client. +func FromPath(client clients.Identifier, path string) string { + path = strings.TrimPrefix(path, "file://") + path = strings.TrimPrefix(path, "/") + + if client == clients.IdentifierVSCode { + // Convert Windows path separators to Unix separators + path = filepath.ToSlash(path) + + // If the path is a Windows path, the colon after the drive letter needs to be + // percent-encoded. + if parts := strings.Split(path, ":"); len(parts) > 1 { + path = parts[0] + "%3A" + parts[1] + } + } + + return "file://" + "/" + path +} + +// ToPath converts a URI to a file path from a format for a given client. +// Some clients represent URIs differently, and so this function exists to convert +// client URIs into a standard file paths. +func ToPath(client clients.Identifier, uri string) string { + path := strings.TrimPrefix(uri, "file://") + + if client == clients.IdentifierVSCode { + if strings.Contains(path, ":") || strings.Contains(path, "%3A") { + path = strings.Replace(path, "%3A", ":", 1) + path = strings.TrimPrefix(path, "/") + } + } + + return path +} diff --git a/internal/lsp/uri/uri_test.go b/internal/lsp/uri/uri_test.go new file mode 100644 index 00000000..0578e1c6 --- /dev/null +++ b/internal/lsp/uri/uri_test.go @@ -0,0 +1,166 @@ +package uri + +import ( + "testing" + + "github.com/styrainc/regal/internal/lsp/clients" +) + +func TestPathToURI(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + path string + want string + }{ + "unix simple": { + path: "/foo/bar", + want: "file:///foo/bar", + }, + "unix prefixed": { + path: "file:///foo/bar", + want: "file:///foo/bar", + }, + "windows not encoded": { + path: "c:/foo/bar", + want: "file:///c:/foo/bar", + }, + } + + for label, tc := range testCases { + tt := tc + + t.Run(label, func(t *testing.T) { + t.Parallel() + + got := FromPath(clients.IdentifierGeneric, tt.path) + + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestPathToURI_VSCode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + path string + want string + }{ + "unix simple": { + path: "/foo/bar", + want: "file:///foo/bar", + }, + "unix prefixed": { + path: "file:///foo/bar", + want: "file:///foo/bar", + }, + "windows encoded": { + path: "c%3A/foo/bar", + want: "file:///c%3A/foo/bar", + }, + "windows not encoded": { + path: "c:/foo/bar", + want: "file:///c%3A/foo/bar", + }, + } + + for label, tc := range testCases { + tt := tc + + t.Run(label, func(t *testing.T) { + t.Parallel() + + got := FromPath(clients.IdentifierVSCode, tt.path) + + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestURIToPath(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + uri string + want string + }{ + "unix unprefixed": { + uri: "/foo/bar", + want: "/foo/bar", + }, + "unix simple": { + uri: "file:///foo/bar", + want: "/foo/bar", + }, + "windows not encoded": { + uri: "file://c:/foo/bar", + want: "c:/foo/bar", + }, + } + + for label, tc := range testCases { + tt := tc + + t.Run(label, func(t *testing.T) { + t.Parallel() + + got := ToPath(clients.IdentifierGeneric, tt.uri) + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestURIToPath_VSCode(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + uri string + want string + }{ + "unix unprefixed": { + uri: "/foo/bar", + want: "/foo/bar", + }, + "unix simple": { + uri: "file:///foo/bar", + want: "/foo/bar", + }, + "windows encoded": { + uri: "file:///c%3A/foo/bar", + want: "c:/foo/bar", + }, + // these other examples shouldn't happen, but we should handle them + "windows not encoded": { + uri: "file://c:/foo/bar", + want: "c:/foo/bar", + }, + "windows not prefixed": { + uri: "c:/foo/bar", + want: "c:/foo/bar", + }, + "windows not prefixed, but encoded": { + uri: "c%3A/foo/bar", + want: "c:/foo/bar", + }, + } + + for label, tc := range testCases { + tt := tc + + t.Run(label, func(t *testing.T) { + t.Parallel() + + got := ToPath(clients.IdentifierVSCode, tt.uri) + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +}