-
Notifications
You must be signed in to change notification settings - Fork 39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
lsp/server: Cache AllRegoVersions at config load #1325
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,20 +88,21 @@ func NewLanguageServer(ctx context.Context, opts *LanguageServerOptions) *Langua | |
|
||
var ls *LanguageServer | ||
ls = &LanguageServer{ | ||
cache: c, | ||
regoStore: store, | ||
logWriter: opts.LogWriter, | ||
logLevel: opts.LogLevel, | ||
lintFileJobs: make(chan lintFileJob, 10), | ||
lintWorkspaceJobs: make(chan lintWorkspaceJob, 10), | ||
builtinsPositionJobs: make(chan lintFileJob, 10), | ||
commandRequest: make(chan types.ExecuteCommandParams, 10), | ||
templateFileJobs: make(chan lintFileJob, 10), | ||
configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{LogFunc: ls.logf}), | ||
completionsManager: completions.NewDefaultManager(ctx, c, store), | ||
webServer: web.NewServer(c), | ||
loadedBuiltins: concurrent.MapOf(make(map[string]map[string]*ast.Builtin)), | ||
workspaceDiagnosticsPoll: opts.WorkspaceDiagnosticsPoll, | ||
cache: c, | ||
regoStore: store, | ||
logWriter: opts.LogWriter, | ||
logLevel: opts.LogLevel, | ||
lintFileJobs: make(chan lintFileJob, 10), | ||
lintWorkspaceJobs: make(chan lintWorkspaceJob, 10), | ||
builtinsPositionJobs: make(chan lintFileJob, 10), | ||
commandRequest: make(chan types.ExecuteCommandParams, 10), | ||
templateFileJobs: make(chan lintFileJob, 10), | ||
configWatcher: lsconfig.NewWatcher(&lsconfig.WatcherOpts{LogFunc: ls.logf}), | ||
completionsManager: completions.NewDefaultManager(ctx, c, store), | ||
webServer: web.NewServer(c), | ||
loadedBuiltins: concurrent.MapOf(make(map[string]map[string]*ast.Builtin)), | ||
workspaceDiagnosticsPoll: opts.WorkspaceDiagnosticsPoll, | ||
loadedConfigAllRegoVersions: concurrent.MapOf(make(map[string]ast.RegoVersion)), | ||
} | ||
|
||
return ls | ||
|
@@ -114,10 +115,13 @@ type LanguageServer struct { | |
regoStore storage.Store | ||
conn *jsonrpc2.Conn | ||
|
||
configWatcher *lsconfig.Watcher | ||
loadedConfig *config.Config | ||
configWatcher *lsconfig.Watcher | ||
loadedConfig *config.Config | ||
// this is also used to lock the updates to the cache of enabled rules | ||
loadedConfigLock sync.Mutex | ||
loadedConfigEnabledNonAggregateRules []string | ||
loadedConfigEnabledAggregateRules []string | ||
loadedConfigAllRegoVersions *concurrent.Map[string, ast.RegoVersion] | ||
loadedBuiltins *concurrent.Map[string, map[string]*ast.Builtin] | ||
|
||
clientInitializationOptions types.InitializationOptions | ||
|
@@ -138,9 +142,6 @@ type LanguageServer struct { | |
workspaceRootURI string | ||
clientIdentifier clients.Identifier | ||
|
||
// this is also used to lock the updates to the cache of enabled rules | ||
loadedConfigLock sync.Mutex | ||
|
||
workspaceDiagnosticsPoll time.Duration | ||
} | ||
|
||
|
@@ -551,6 +552,22 @@ func (l *LanguageServer) StartConfigWorker(ctx context.Context) { | |
l.loadedConfig = &mergedConfig | ||
l.loadedConfigLock.Unlock() | ||
|
||
// Rego versions may have changed, so reload them. | ||
allRegoVersions, err := config.AllRegoVersions( | ||
uri.ToPath(l.clientIdentifier, l.workspaceRootURI), | ||
l.getLoadedConfig(), | ||
) | ||
if err != nil { | ||
l.logf(log.LevelMessage, "failed to reload rego versions: %s", err) | ||
} | ||
|
||
l.loadedConfigAllRegoVersions.Clear() | ||
|
||
for k, v := range allRegoVersions { | ||
l.loadedConfigAllRegoVersions.Set(k, v) | ||
} | ||
|
||
// Enabled rules might have changed with the new config, so reload. | ||
err = l.loadEnabledRulesFromConfig(ctx, mergedConfig) | ||
if err != nil { | ||
l.logf(log.LevelMessage, "failed to cache enabled rules: %s", err) | ||
|
@@ -1096,6 +1113,32 @@ func (l *LanguageServer) StartWebServer(ctx context.Context) { | |
l.webServer.Start(ctx) | ||
} | ||
|
||
func (l *LanguageServer) determineVersionForFile(fileURI string) ast.RegoVersion { | ||
var versionedDirs []string | ||
|
||
// if we have no information, then we can return the default | ||
if l.loadedConfigAllRegoVersions == nil || l.loadedConfigAllRegoVersions.Len() == 0 { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps we should have Len() always return 0 when its parent is nil? It would make the check here a little cleaner. |
||
return ast.RegoV1 | ||
} | ||
|
||
versionedDirs = util.Keys(l.loadedConfigAllRegoVersions.Clone()) | ||
slices.Sort(versionedDirs) | ||
slices.Reverse(versionedDirs) | ||
|
||
path := strings.TrimPrefix(fileURI, l.workspaceRootURI+"/") | ||
|
||
for _, versionedDir := range versionedDirs { | ||
if strings.HasPrefix(path, versionedDir) { | ||
val, ok := l.loadedConfigAllRegoVersions.Get(versionedDir) | ||
if ok { | ||
return val | ||
} | ||
} | ||
} | ||
|
||
return ast.RegoV1 | ||
} | ||
|
||
func (l *LanguageServer) templateContentsForFile(fileURI string) (string, error) { | ||
// this function should not be called with files in the root, but if it is, | ||
// then it is an error to prevent unwanted behavior. | ||
|
@@ -1185,7 +1228,13 @@ func (l *LanguageServer) templateContentsForFile(fileURI string) (string, error) | |
pkg += "_test" | ||
} | ||
|
||
return fmt.Sprintf("package %s\n\nimport rego.v1\n", pkg), nil | ||
version := l.determineVersionForFile(fileURI) | ||
|
||
if version == ast.RegoV0 { | ||
return fmt.Sprintf("package %s\n\nimport rego.v1\n", pkg), nil | ||
} | ||
|
||
return fmt.Sprintf("package %s\n\n", pkg), nil | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very good! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. a47ab6d |
||
} | ||
|
||
func (l *LanguageServer) fixEditParams( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
package lsp | ||
|
||
import ( | ||
"context" | ||
"os" | ||
"path/filepath" | ||
"testing" | ||
"time" | ||
|
||
"github.com/open-policy-agent/opa/v1/ast" | ||
|
||
"github.com/styrainc/regal/internal/lsp/clients" | ||
"github.com/styrainc/regal/internal/lsp/log" | ||
"github.com/styrainc/regal/internal/lsp/uri" | ||
"github.com/styrainc/regal/pkg/config" | ||
) | ||
|
||
func TestAllRegoVersions(t *testing.T) { | ||
t.Parallel() | ||
|
||
testCases := map[string]struct { | ||
FileKey string | ||
ExpectedVersion ast.RegoVersion | ||
DiskContents map[string]string | ||
}{ | ||
"unknown version": { | ||
FileKey: "foo/bar.rego", | ||
DiskContents: map[string]string{ | ||
"foo/bar.rego": "package foo", | ||
".regal/config.yaml": "", | ||
}, | ||
ExpectedVersion: ast.RegoV1, | ||
}, | ||
"version set in project config": { | ||
FileKey: "foo/bar.rego", | ||
DiskContents: map[string]string{ | ||
"foo/bar.rego": "package foo", | ||
".regal/config.yaml": ` | ||
project: | ||
rego-version: 0 | ||
`, | ||
}, | ||
ExpectedVersion: ast.RegoV0, | ||
}, | ||
"version set in root config": { | ||
FileKey: "foo/bar.rego", | ||
DiskContents: map[string]string{ | ||
"foo/bar.rego": "package foo", | ||
".regal/config.yaml": ` | ||
project: | ||
rego-version: 1 | ||
roots: | ||
- path: foo | ||
rego-version: 0 | ||
`, | ||
}, | ||
ExpectedVersion: ast.RegoV0, | ||
}, | ||
"version set in manifest": { | ||
FileKey: "foo/bar.rego", | ||
DiskContents: map[string]string{ | ||
"foo/bar.rego": "package foo", | ||
"foo/.manifest": `{"rego_version": 0}`, | ||
".regal/config.yaml": ``, | ||
}, | ||
ExpectedVersion: ast.RegoV0, | ||
}, | ||
"version set in manifest, overridden by config": { | ||
FileKey: "foo/bar.rego", | ||
DiskContents: map[string]string{ | ||
"foo/bar.rego": "package foo", | ||
"foo/.manifest": `{"rego_version": 1}`, | ||
".regal/config.yaml": ` | ||
project: | ||
roots: | ||
- path: foo | ||
rego-version: 0 | ||
`, | ||
}, | ||
ExpectedVersion: ast.RegoV0, | ||
}, | ||
} | ||
|
||
for name, tc := range testCases { | ||
t.Run(name, func(t *testing.T) { | ||
t.Parallel() | ||
|
||
td := t.TempDir() | ||
|
||
// init the state on disk | ||
for f, c := range tc.DiskContents { | ||
dir := filepath.Dir(f) | ||
|
||
if err := os.MkdirAll(filepath.Join(td, dir), 0o755); err != nil { | ||
t.Fatalf("failed to create directory %s: %s", dir, err) | ||
} | ||
|
||
if err := os.WriteFile(filepath.Join(td, f), []byte(c), 0o600); err != nil { | ||
t.Fatalf("failed to write file %s: %s", f, err) | ||
} | ||
} | ||
|
||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
|
||
ls := NewLanguageServer(ctx, &LanguageServerOptions{LogWriter: newTestLogger(t), LogLevel: log.LevelDebug}) | ||
ls.workspaceRootURI = uri.FromPath(clients.IdentifierGeneric, td) | ||
|
||
// have the server load the config | ||
go ls.StartConfigWorker(ctx) | ||
|
||
configFile, err := config.FindConfig(td) | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
|
||
ls.configWatcher.Watch(configFile.Name()) | ||
|
||
// wait for ls.loadedConfig to be set | ||
timeout := time.NewTimer(determineTimeout()) | ||
defer timeout.Stop() | ||
|
||
for success := false; !success; { | ||
select { | ||
default: | ||
if ls.getLoadedConfig() != nil { | ||
success = true | ||
|
||
break | ||
} | ||
|
||
time.Sleep(500 * time.Millisecond) | ||
case <-timeout.C: | ||
t.Fatalf("timed out waiting for config to be set") | ||
} | ||
} | ||
|
||
// check it has the correct version for the file of interest | ||
fileURI := uri.FromPath(clients.IdentifierGeneric, filepath.Join(td, tc.FileKey)) | ||
|
||
version := ls.determineVersionForFile(fileURI) | ||
|
||
if version != tc.ExpectedVersion { | ||
t.Errorf("expected version %v, got %v", tc.ExpectedVersion, version) | ||
} | ||
}) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat :)