Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

lsp/server: Cache AllRegoVersions at config load #1325

Merged
merged 2 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 69 additions & 20 deletions internal/lsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,20 +88,21 @@ func NewLanguageServer(ctx context.Context, opts *LanguageServerOptions) *Langua

var ls *LanguageServer
ls = &LanguageServer{
cache: c,
regoStore: store,
logWriter: opts.LogWriter,
logLevel: opts.LogLevel,
lintFileJobs: make(chan lintFileJob, 10),
lintWorkspaceJobs: make(chan lintWorkspaceJob, 10),
builtinsPositionJobs: make(chan lintFileJob, 10),
commandRequest: make(chan types.ExecuteCommandParams, 10),
templateFileJobs: make(chan lintFileJob, 10),
configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{LogFunc: ls.logf}),
completionsManager: completions.NewDefaultManager(ctx, c, store),
webServer: web.NewServer(c),
loadedBuiltins: concurrent.MapOf(make(map[string]map[string]*ast.Builtin)),
workspaceDiagnosticsPoll: opts.WorkspaceDiagnosticsPoll,
cache: c,
regoStore: store,
logWriter: opts.LogWriter,
logLevel: opts.LogLevel,
lintFileJobs: make(chan lintFileJob, 10),
lintWorkspaceJobs: make(chan lintWorkspaceJob, 10),
builtinsPositionJobs: make(chan lintFileJob, 10),
commandRequest: make(chan types.ExecuteCommandParams, 10),
templateFileJobs: make(chan lintFileJob, 10),
configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{LogFunc: ls.logf}),
completionsManager: completions.NewDefaultManager(ctx, c, store),
webServer: web.NewServer(c),
loadedBuiltins: concurrent.MapOf(make(map[string]map[string]*ast.Builtin)),
workspaceDiagnosticsPoll: opts.WorkspaceDiagnosticsPoll,
loadedConfigAllRegoVersions: concurrent.MapOf(make(map[string]ast.RegoVersion)),
}

return ls
Expand All @@ -114,10 +115,13 @@ type LanguageServer struct {
regoStore storage.Store
conn *jsonrpc2.Conn

configWatcher *lsconfig.Watcher
loadedConfig *config.Config
configWatcher *lsconfig.Watcher
loadedConfig *config.Config
// this is also used to lock the updates to the cache of enabled rules
loadedConfigLock sync.Mutex
loadedConfigEnabledNonAggregateRules []string
loadedConfigEnabledAggregateRules []string
loadedConfigAllRegoVersions *concurrent.Map[string, ast.RegoVersion]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat :)

loadedBuiltins *concurrent.Map[string, map[string]*ast.Builtin]

clientInitializationOptions types.InitializationOptions
Expand All @@ -138,9 +142,6 @@ type LanguageServer struct {
workspaceRootURI string
clientIdentifier clients.Identifier

// this is also used to lock the updates to the cache of enabled rules
loadedConfigLock sync.Mutex

workspaceDiagnosticsPoll time.Duration
}

Expand Down Expand Up @@ -551,6 +552,22 @@ func (l *LanguageServer) StartConfigWorker(ctx context.Context) {
l.loadedConfig = &mergedConfig
l.loadedConfigLock.Unlock()

// Rego versions may have changed, so reload them.
allRegoVersions, err := config.AllRegoVersions(
uri.ToPath(l.clientIdentifier, l.workspaceRootURI),
l.getLoadedConfig(),
)
if err != nil {
l.logf(log.LevelMessage, "failed to reload rego versions: %s", err)
}

l.loadedConfigAllRegoVersions.Clear()

for k, v := range allRegoVersions {
l.loadedConfigAllRegoVersions.Set(k, v)
}

// Enabled rules might have changed with the new config, so reload.
err = l.loadEnabledRulesFromConfig(ctx, mergedConfig)
if err != nil {
l.logf(log.LevelMessage, "failed to cache enabled rules: %s", err)
Expand Down Expand Up @@ -1096,6 +1113,32 @@ func (l *LanguageServer) StartWebServer(ctx context.Context) {
l.webServer.Start(ctx)
}

func (l *LanguageServer) determineVersionForFile(fileURI string) ast.RegoVersion {
var versionedDirs []string

// if we have no information, then we can return the default
if l.loadedConfigAllRegoVersions == nil || l.loadedConfigAllRegoVersions.Len() == 0 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we should have Len() always return 0 when its parent is nil? It would make the check here a little cleaner.

return ast.RegoV1
}

versionedDirs = util.Keys(l.loadedConfigAllRegoVersions.Clone())
slices.Sort(versionedDirs)
slices.Reverse(versionedDirs)

path := strings.TrimPrefix(fileURI, l.workspaceRootURI+"/")

for _, versionedDir := range versionedDirs {
if strings.HasPrefix(path, versionedDir) {
val, ok := l.loadedConfigAllRegoVersions.Get(versionedDir)
if ok {
return val
}
}
}

return ast.RegoV1
}

func (l *LanguageServer) templateContentsForFile(fileURI string) (string, error) {
// this function should not be called with files in the root, but if it is,
// then it is an error to prevent unwanted behavior.
Expand Down Expand Up @@ -1185,7 +1228,13 @@ func (l *LanguageServer) templateContentsForFile(fileURI string) (string, error)
pkg += "_test"
}

return fmt.Sprintf("package %s\n\nimport rego.v1\n", pkg), nil
version := l.determineVersionForFile(fileURI)

if version == ast.RegoV0 {
return fmt.Sprintf("package %s\n\nimport rego.v1\n", pkg), nil
}

return fmt.Sprintf("package %s\n\n", pkg), nil
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a47ab6d
fixed here!

}

func (l *LanguageServer) fixEditParams(
Expand Down
148 changes: 148 additions & 0 deletions internal/lsp/server_all_rego_versions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package lsp

import (
"context"
"os"
"path/filepath"
"testing"
"time"

"github.com/open-policy-agent/opa/v1/ast"

"github.com/styrainc/regal/internal/lsp/clients"
"github.com/styrainc/regal/internal/lsp/log"
"github.com/styrainc/regal/internal/lsp/uri"
"github.com/styrainc/regal/pkg/config"
)

func TestAllRegoVersions(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
FileKey string
ExpectedVersion ast.RegoVersion
DiskContents map[string]string
}{
"unknown version": {
FileKey: "foo/bar.rego",
DiskContents: map[string]string{
"foo/bar.rego": "package foo",
".regal/config.yaml": "",
},
ExpectedVersion: ast.RegoV1,
},
"version set in project config": {
FileKey: "foo/bar.rego",
DiskContents: map[string]string{
"foo/bar.rego": "package foo",
".regal/config.yaml": `
project:
rego-version: 0
`,
},
ExpectedVersion: ast.RegoV0,
},
"version set in root config": {
FileKey: "foo/bar.rego",
DiskContents: map[string]string{
"foo/bar.rego": "package foo",
".regal/config.yaml": `
project:
rego-version: 1
roots:
- path: foo
rego-version: 0
`,
},
ExpectedVersion: ast.RegoV0,
},
"version set in manifest": {
FileKey: "foo/bar.rego",
DiskContents: map[string]string{
"foo/bar.rego": "package foo",
"foo/.manifest": `{"rego_version": 0}`,
".regal/config.yaml": ``,
},
ExpectedVersion: ast.RegoV0,
},
"version set in manifest, overridden by config": {
FileKey: "foo/bar.rego",
DiskContents: map[string]string{
"foo/bar.rego": "package foo",
"foo/.manifest": `{"rego_version": 1}`,
".regal/config.yaml": `
project:
roots:
- path: foo
rego-version: 0
`,
},
ExpectedVersion: ast.RegoV0,
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
t.Parallel()

td := t.TempDir()

// init the state on disk
for f, c := range tc.DiskContents {
dir := filepath.Dir(f)

if err := os.MkdirAll(filepath.Join(td, dir), 0o755); err != nil {
t.Fatalf("failed to create directory %s: %s", dir, err)
}

if err := os.WriteFile(filepath.Join(td, f), []byte(c), 0o600); err != nil {
t.Fatalf("failed to write file %s: %s", f, err)
}
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

ls := NewLanguageServer(ctx, &LanguageServerOptions{LogWriter: newTestLogger(t), LogLevel: log.LevelDebug})
ls.workspaceRootURI = uri.FromPath(clients.IdentifierGeneric, td)

// have the server load the config
go ls.StartConfigWorker(ctx)

configFile, err := config.FindConfig(td)
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

ls.configWatcher.Watch(configFile.Name())

// wait for ls.loadedConfig to be set
timeout := time.NewTimer(determineTimeout())
defer timeout.Stop()

for success := false; !success; {
select {
default:
if ls.getLoadedConfig() != nil {
success = true

break
}

time.Sleep(500 * time.Millisecond)
case <-timeout.C:
t.Fatalf("timed out waiting for config to be set")
}
}

// check it has the correct version for the file of interest
fileURI := uri.FromPath(clients.IdentifierGeneric, filepath.Join(td, tc.FileKey))

version := ls.determineVersionForFile(fileURI)

if version != tc.ExpectedVersion {
t.Errorf("expected version %v, got %v", tc.ExpectedVersion, version)
}
})
}
}
49 changes: 39 additions & 10 deletions internal/lsp/server_template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,26 @@ import (

"github.com/sourcegraph/jsonrpc2"

"github.com/open-policy-agent/opa/v1/ast"

"github.com/styrainc/regal/internal/lsp/clients"
"github.com/styrainc/regal/internal/lsp/log"
"github.com/styrainc/regal/internal/lsp/types"
"github.com/styrainc/regal/internal/lsp/uri"
"github.com/styrainc/regal/internal/util/concurrent"
)

func TestTemplateContentsForFile(t *testing.T) {
t.Parallel()

testCases := map[string]struct {
FileKey string
CacheFileContents string
DiskContents map[string]string
RequireConfig bool
ExpectedContents string
ExpectedError string
FileKey string
CacheFileContents string
DiskContents map[string]string
RequireConfig bool
ServerAllRegoVersions *concurrent.Map[string, ast.RegoVersion]
ExpectedContents string
ExpectedError string
}{
"existing contents in file": {
FileKey: "foo/bar.rego",
Expand All @@ -49,7 +53,7 @@ func TestTemplateContentsForFile(t *testing.T) {
"foo/bar.rego": "",
".regal/config.yaml": "",
},
ExpectedContents: "package foo\n\nimport rego.v1\n",
ExpectedContents: "package foo\n\n",
},
"empty test file is templated based on root": {
FileKey: "foo/bar_test.rego",
Expand All @@ -59,7 +63,7 @@ func TestTemplateContentsForFile(t *testing.T) {
".regal/config.yaml": "",
},
RequireConfig: true,
ExpectedContents: "package foo_test\n\nimport rego.v1\n",
ExpectedContents: "package foo_test\n\n",
},
"empty deeply nested file is templated based on root": {
FileKey: "foo/bar/baz/bax.rego",
Expand All @@ -68,8 +72,32 @@ func TestTemplateContentsForFile(t *testing.T) {
"foo/bar/baz/bax.rego": "",
".regal/config.yaml": "",
},
ExpectedContents: "package foo.bar.baz\n\n",
},
"v0 templating using rego version setting": {
FileKey: "foo/bar/baz/bax.rego",
CacheFileContents: "",
ServerAllRegoVersions: concurrent.MapOf(map[string]ast.RegoVersion{
"foo": ast.RegoV0,
}),
DiskContents: map[string]string{
"foo/bar/baz/bax.rego": "",
".regal/config.yaml": "", // we manually set the versions, config not loaded in these tests
},
ExpectedContents: "package foo.bar.baz\n\nimport rego.v1\n",
},
"v1 templating using rego version setting": {
FileKey: "foo/bar/baz/bax.rego",
CacheFileContents: "",
ServerAllRegoVersions: concurrent.MapOf(map[string]ast.RegoVersion{
"foo": ast.RegoV1,
}),
DiskContents: map[string]string{
"foo/bar/baz/bax.rego": "",
".regal/config.yaml": "", // we manually set the versions, config not loaded in these tests
},
ExpectedContents: "package foo.bar.baz\n\n",
},
}

for name, tc := range testCases {
Expand Down Expand Up @@ -100,6 +128,8 @@ func TestTemplateContentsForFile(t *testing.T) {

ls.workspaceRootURI = uri.FromPath(clients.IdentifierGeneric, td)

ls.loadedConfigAllRegoVersions = tc.ServerAllRegoVersions

fileURI := uri.FromPath(clients.IdentifierGeneric, filepath.Join(td, tc.FileKey))

ls.cache.SetFileContents(fileURI, tc.CacheFileContents)
Expand Down Expand Up @@ -188,7 +218,6 @@ func TestTemplateContentsForFileWithUnknownRoot(t *testing.T) {

exp := `package foo
import rego.v1
`
if exp != newContents {
t.Errorf("unexpected content: %s, want %s", newContents, exp)
Expand Down Expand Up @@ -279,7 +308,7 @@ func TestNewFileTemplating(t *testing.T) {
{
"edits": [
{
"newText": "package foo.bar_test\n\nimport rego.v1\n",
"newText": "package foo.bar_test\n\n",
"range": {
"end": {
"character": 0,
Expand Down
Loading