diff --git a/internal/lsp/server.go b/internal/lsp/server.go index d07a59d6..58672c2e 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -88,20 +88,21 @@ func NewLanguageServer(ctx context.Context, opts *LanguageServerOptions) *Langua var ls *LanguageServer ls = &LanguageServer{ - cache: c, - regoStore: store, - logWriter: opts.LogWriter, - logLevel: opts.LogLevel, - lintFileJobs: make(chan lintFileJob, 10), - lintWorkspaceJobs: make(chan lintWorkspaceJob, 10), - builtinsPositionJobs: make(chan lintFileJob, 10), - commandRequest: make(chan types.ExecuteCommandParams, 10), - templateFileJobs: make(chan lintFileJob, 10), - configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{LogFunc: ls.logf}), - completionsManager: completions.NewDefaultManager(ctx, c, store), - webServer: web.NewServer(c), - loadedBuiltins: concurrent.MapOf(make(map[string]map[string]*ast.Builtin)), - workspaceDiagnosticsPoll: opts.WorkspaceDiagnosticsPoll, + cache: c, + regoStore: store, + logWriter: opts.LogWriter, + logLevel: opts.LogLevel, + lintFileJobs: make(chan lintFileJob, 10), + lintWorkspaceJobs: make(chan lintWorkspaceJob, 10), + builtinsPositionJobs: make(chan lintFileJob, 10), + commandRequest: make(chan types.ExecuteCommandParams, 10), + templateFileJobs: make(chan lintFileJob, 10), + configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{LogFunc: ls.logf}), + completionsManager: completions.NewDefaultManager(ctx, c, store), + webServer: web.NewServer(c), + loadedBuiltins: concurrent.MapOf(make(map[string]map[string]*ast.Builtin)), + workspaceDiagnosticsPoll: opts.WorkspaceDiagnosticsPoll, + loadedConfigAllRegoVersions: concurrent.MapOf(make(map[string]ast.RegoVersion)), } return ls @@ -114,10 +115,13 @@ type LanguageServer struct { regoStore storage.Store conn *jsonrpc2.Conn - configWatcher *lsconfig.Watcher - loadedConfig *config.Config + configWatcher *lsconfig.Watcher + loadedConfig *config.Config + // this is also used to lock the updates to the cache of enabled rules + loadedConfigLock sync.Mutex loadedConfigEnabledNonAggregateRules []string loadedConfigEnabledAggregateRules []string + loadedConfigAllRegoVersions *concurrent.Map[string, ast.RegoVersion] loadedBuiltins *concurrent.Map[string, map[string]*ast.Builtin] clientInitializationOptions types.InitializationOptions @@ -138,9 +142,6 @@ type LanguageServer struct { workspaceRootURI string clientIdentifier clients.Identifier - // this is also used to lock the updates to the cache of enabled rules - loadedConfigLock sync.Mutex - workspaceDiagnosticsPoll time.Duration } @@ -551,6 +552,22 @@ func (l *LanguageServer) StartConfigWorker(ctx context.Context) { l.loadedConfig = &mergedConfig l.loadedConfigLock.Unlock() + // Rego versions may have changed, so reload them. + allRegoVersions, err := config.AllRegoVersions( + uri.ToPath(l.clientIdentifier, l.workspaceRootURI), + l.getLoadedConfig(), + ) + if err != nil { + l.logf(log.LevelMessage, "failed to reload rego versions: %s", err) + } + + l.loadedConfigAllRegoVersions.Clear() + + for k, v := range allRegoVersions { + l.loadedConfigAllRegoVersions.Set(k, v) + } + + // Enabled rules might have changed with the new config, so reload. err = l.loadEnabledRulesFromConfig(ctx, mergedConfig) if err != nil { l.logf(log.LevelMessage, "failed to cache enabled rules: %s", err) @@ -1096,6 +1113,32 @@ func (l *LanguageServer) StartWebServer(ctx context.Context) { l.webServer.Start(ctx) } +func (l *LanguageServer) determineVersionForFile(fileURI string) ast.RegoVersion { + var versionedDirs []string + + // if we have no information, then we can return the default + if l.loadedConfigAllRegoVersions.Len() == 0 { + return ast.RegoV1 + } + + versionedDirs = util.Keys(l.loadedConfigAllRegoVersions.Clone()) + slices.Sort(versionedDirs) + slices.Reverse(versionedDirs) + + path := strings.TrimPrefix(fileURI, l.workspaceRootURI+"/") + + for _, versionedDir := range versionedDirs { + if strings.HasPrefix(path, versionedDir) { + val, ok := l.loadedConfigAllRegoVersions.Get(versionedDir) + if ok { + return val + } + } + } + + return ast.RegoV1 +} + func (l *LanguageServer) templateContentsForFile(fileURI string) (string, error) { // this function should not be called with files in the root, but if it is, // then it is an error to prevent unwanted behavior. @@ -1185,7 +1228,13 @@ func (l *LanguageServer) templateContentsForFile(fileURI string) (string, error) pkg += "_test" } - return fmt.Sprintf("package %s\n\nimport rego.v1\n", pkg), nil + version := l.determineVersionForFile(fileURI) + + if version == ast.RegoV0 { + return fmt.Sprintf("package %s\n\nimport rego.v1\n", pkg), nil + } + + return fmt.Sprintf("package %s\n\n", pkg), nil } func (l *LanguageServer) fixEditParams( diff --git a/internal/lsp/server_all_rego_versions_test.go b/internal/lsp/server_all_rego_versions_test.go new file mode 100644 index 00000000..856967c1 --- /dev/null +++ b/internal/lsp/server_all_rego_versions_test.go @@ -0,0 +1,148 @@ +package lsp + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/open-policy-agent/opa/v1/ast" + + "github.com/styrainc/regal/internal/lsp/clients" + "github.com/styrainc/regal/internal/lsp/log" + "github.com/styrainc/regal/internal/lsp/uri" + "github.com/styrainc/regal/pkg/config" +) + +func TestAllRegoVersions(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + FileKey string + ExpectedVersion ast.RegoVersion + DiskContents map[string]string + }{ + "unknown version": { + FileKey: "foo/bar.rego", + DiskContents: map[string]string{ + "foo/bar.rego": "package foo", + ".regal/config.yaml": "", + }, + ExpectedVersion: ast.RegoV1, + }, + "version set in project config": { + FileKey: "foo/bar.rego", + DiskContents: map[string]string{ + "foo/bar.rego": "package foo", + ".regal/config.yaml": ` +project: + rego-version: 0 +`, + }, + ExpectedVersion: ast.RegoV0, + }, + "version set in root config": { + FileKey: "foo/bar.rego", + DiskContents: map[string]string{ + "foo/bar.rego": "package foo", + ".regal/config.yaml": ` +project: + rego-version: 1 + roots: + - path: foo + rego-version: 0 +`, + }, + ExpectedVersion: ast.RegoV0, + }, + "version set in manifest": { + FileKey: "foo/bar.rego", + DiskContents: map[string]string{ + "foo/bar.rego": "package foo", + "foo/.manifest": `{"rego_version": 0}`, + ".regal/config.yaml": ``, + }, + ExpectedVersion: ast.RegoV0, + }, + "version set in manifest, overridden by config": { + FileKey: "foo/bar.rego", + DiskContents: map[string]string{ + "foo/bar.rego": "package foo", + "foo/.manifest": `{"rego_version": 1}`, + ".regal/config.yaml": ` +project: + roots: + - path: foo + rego-version: 0 +`, + }, + ExpectedVersion: ast.RegoV0, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + td := t.TempDir() + + // init the state on disk + for f, c := range tc.DiskContents { + dir := filepath.Dir(f) + + if err := os.MkdirAll(filepath.Join(td, dir), 0o755); err != nil { + t.Fatalf("failed to create directory %s: %s", dir, err) + } + + if err := os.WriteFile(filepath.Join(td, f), []byte(c), 0o600); err != nil { + t.Fatalf("failed to write file %s: %s", f, err) + } + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ls := NewLanguageServer(ctx, &LanguageServerOptions{LogWriter: newTestLogger(t), LogLevel: log.LevelDebug}) + ls.workspaceRootURI = uri.FromPath(clients.IdentifierGeneric, td) + + // have the server load the config + go ls.StartConfigWorker(ctx) + + configFile, err := config.FindConfig(td) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + ls.configWatcher.Watch(configFile.Name()) + + // wait for ls.loadedConfig to be set + timeout := time.NewTimer(determineTimeout()) + defer timeout.Stop() + + for success := false; !success; { + select { + default: + if ls.getLoadedConfig() != nil { + success = true + + break + } + + time.Sleep(500 * time.Millisecond) + case <-timeout.C: + t.Fatalf("timed out waiting for config to be set") + } + } + + // check it has the correct version for the file of interest + fileURI := uri.FromPath(clients.IdentifierGeneric, filepath.Join(td, tc.FileKey)) + + version := ls.determineVersionForFile(fileURI) + + if version != tc.ExpectedVersion { + t.Errorf("expected version %v, got %v", tc.ExpectedVersion, version) + } + }) + } +} diff --git a/internal/lsp/server_template_test.go b/internal/lsp/server_template_test.go index dc8b77b9..441044ca 100644 --- a/internal/lsp/server_template_test.go +++ b/internal/lsp/server_template_test.go @@ -12,22 +12,26 @@ import ( "github.com/sourcegraph/jsonrpc2" + "github.com/open-policy-agent/opa/v1/ast" + "github.com/styrainc/regal/internal/lsp/clients" "github.com/styrainc/regal/internal/lsp/log" "github.com/styrainc/regal/internal/lsp/types" "github.com/styrainc/regal/internal/lsp/uri" + "github.com/styrainc/regal/internal/util/concurrent" ) func TestTemplateContentsForFile(t *testing.T) { t.Parallel() testCases := map[string]struct { - FileKey string - CacheFileContents string - DiskContents map[string]string - RequireConfig bool - ExpectedContents string - ExpectedError string + FileKey string + CacheFileContents string + DiskContents map[string]string + RequireConfig bool + ServerAllRegoVersions *concurrent.Map[string, ast.RegoVersion] + ExpectedContents string + ExpectedError string }{ "existing contents in file": { FileKey: "foo/bar.rego", @@ -49,7 +53,7 @@ func TestTemplateContentsForFile(t *testing.T) { "foo/bar.rego": "", ".regal/config.yaml": "", }, - ExpectedContents: "package foo\n\nimport rego.v1\n", + ExpectedContents: "package foo\n\n", }, "empty test file is templated based on root": { FileKey: "foo/bar_test.rego", @@ -59,7 +63,7 @@ func TestTemplateContentsForFile(t *testing.T) { ".regal/config.yaml": "", }, RequireConfig: true, - ExpectedContents: "package foo_test\n\nimport rego.v1\n", + ExpectedContents: "package foo_test\n\n", }, "empty deeply nested file is templated based on root": { FileKey: "foo/bar/baz/bax.rego", @@ -68,8 +72,32 @@ func TestTemplateContentsForFile(t *testing.T) { "foo/bar/baz/bax.rego": "", ".regal/config.yaml": "", }, + ExpectedContents: "package foo.bar.baz\n\n", + }, + "v0 templating using rego version setting": { + FileKey: "foo/bar/baz/bax.rego", + CacheFileContents: "", + ServerAllRegoVersions: concurrent.MapOf(map[string]ast.RegoVersion{ + "foo": ast.RegoV0, + }), + DiskContents: map[string]string{ + "foo/bar/baz/bax.rego": "", + ".regal/config.yaml": "", // we manually set the versions, config not loaded in these tests + }, ExpectedContents: "package foo.bar.baz\n\nimport rego.v1\n", }, + "v1 templating using rego version setting": { + FileKey: "foo/bar/baz/bax.rego", + CacheFileContents: "", + ServerAllRegoVersions: concurrent.MapOf(map[string]ast.RegoVersion{ + "foo": ast.RegoV1, + }), + DiskContents: map[string]string{ + "foo/bar/baz/bax.rego": "", + ".regal/config.yaml": "", // we manually set the versions, config not loaded in these tests + }, + ExpectedContents: "package foo.bar.baz\n\n", + }, } for name, tc := range testCases { @@ -100,6 +128,8 @@ func TestTemplateContentsForFile(t *testing.T) { ls.workspaceRootURI = uri.FromPath(clients.IdentifierGeneric, td) + ls.loadedConfigAllRegoVersions = tc.ServerAllRegoVersions + fileURI := uri.FromPath(clients.IdentifierGeneric, filepath.Join(td, tc.FileKey)) ls.cache.SetFileContents(fileURI, tc.CacheFileContents) @@ -188,7 +218,6 @@ func TestTemplateContentsForFileWithUnknownRoot(t *testing.T) { exp := `package foo -import rego.v1 ` if exp != newContents { t.Errorf("unexpected content: %s, want %s", newContents, exp) @@ -279,7 +308,7 @@ func TestNewFileTemplating(t *testing.T) { { "edits": [ { - "newText": "package foo.bar_test\n\nimport rego.v1\n", + "newText": "package foo.bar_test\n\n", "range": { "end": { "character": 0, diff --git a/internal/util/concurrent/map.go b/internal/util/concurrent/map.go index c4bcf396..f0f8bf71 100644 --- a/internal/util/concurrent/map.go +++ b/internal/util/concurrent/map.go @@ -86,6 +86,10 @@ func (cm *Map[K, V]) Values() []V { } func (cm *Map[K, V]) Len() int { + if cm == nil { + return 0 + } + cm.murw.RLock() l := len(cm.m)