diff --git a/cmd/fix.go b/cmd/fix.go index d08f43ac..1bd3967e 100644 --- a/cmd/fix.go +++ b/cmd/fix.go @@ -237,7 +237,7 @@ func fix(args []string, params *fixCommandParams) error { var userConfig config.Config - userConfigFile, err := readUserConfig(params, regalDir) + userConfigFile, err := readUserConfig(params, configSearchPath) switch { case err == nil: diff --git a/cmd/lint.go b/cmd/lint.go index dac93dc3..dc3e630d 100644 --- a/cmd/lint.go +++ b/cmd/lint.go @@ -299,7 +299,7 @@ func lint(args []string, params *lintCommandParams) (report.Report, error) { var userConfig config.Config - userConfigFile, err := readUserConfig(params, regalDir) + userConfigFile, err := readUserConfig(params, configSearchPath) switch { case err == nil: diff --git a/cmd/utils.go b/cmd/utils.go index 1f02c302..b6717a05 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -15,21 +15,18 @@ type configFileParams interface { getConfigFile() string } -func readUserConfig(params configFileParams, regalDir *os.File) (userConfig *os.File, err error) { +func readUserConfig(params configFileParams, searchPath string) (userConfig *os.File, err error) { if cfgFile := params.getConfigFile(); cfgFile != "" { userConfig, err = os.Open(cfgFile) if err != nil { return nil, fmt.Errorf("failed to open config file %w", err) } } else { - searchPath, _ := os.Getwd() - if regalDir != nil { - searchPath = regalDir.Name() + if searchPath == "" { + searchPath, _ = os.Getwd() } - if searchPath != "" { - userConfig, err = config.FindConfig(searchPath) - } + userConfig, err = config.FindConfig(searchPath) } return userConfig, err //nolint:wrapcheck diff --git a/e2e/cli_test.go b/e2e/cli_test.go index 8f42ab36..62982a1a 100644 --- a/e2e/cli_test.go +++ b/e2e/cli_test.go @@ -1201,6 +1201,81 @@ import rego.v1 } } +func TestFixSingleFileNested(t *testing.T) { + t.Parallel() + + stdout, stderr := bytes.Buffer{}, bytes.Buffer{} + td := t.TempDir() + + initialState := map[string]string{ + ".regal/config.yaml": ` +project: + rego-version: 1 +`, + "foo/.regal.yaml": ` +project: + rego-version: 1 +rules: + style: + opa-fmt: + level: ignore +`, + "foo/foo.rego": `package wow`, + } + + for file, content := range initialState { + mustWriteToFile(t, filepath.Join(td, file), content) + } + + // --force is required to make the changes when there is no git repo + err := regal(&stdout, &stderr)( + "fix", + "--force", + filepath.Join(td, "foo/foo.rego"), + ) + + // 0 exit status is expected as all violations should have been fixed + expectExitCode(t, err, 0, &stdout, &stderr) + + exp := fmt.Sprintf(`1 fix applied: +In project root: %[1]s +foo.rego -> wow/foo.rego: +- directory-package-mismatch +`, filepath.Join(td, "foo")) + + if act := stdout.String(); exp != act { + t.Fatalf("expected stdout:\n%s\ngot:\n%s", exp, act) + } + + if exp, act := "", stderr.String(); exp != act { + t.Fatalf("expected stderr %q, got %q", exp, act) + } + + expectedState := map[string]string{ + ".regal/config.yaml": ` +project: + rego-version: 1 +`, + "foo/.regal.yaml": ` +project: + rego-version: 1 +rules: + style: + opa-fmt: + level: ignore +`, + "foo/wow/foo.rego": `package wow`, + } + + for file, expectedContent := range expectedState { + bs := testutil.Must(os.ReadFile(filepath.Join(td, file)))(t) + + if act := string(bs); expectedContent != act { + t.Errorf("expected %s contents:\n%s\ngot\n%s", file, expectedContent, act) + } + } +} + // verify fix for https://github.com/StyraInc/regal/issues/1082 func TestLintAnnotationCustomAttributeMultipleItems(t *testing.T) { t.Parallel() diff --git a/pkg/config/config.go b/pkg/config/config.go index 3be6ae3d..e3614f36 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -246,7 +246,7 @@ func FindBundleRootDirectories(path string) ([]string, error) { // This will traverse the tree **upwards** searching for a .regal directory regalDir, err := FindRegalDirectory(path) if err == nil { - roots, err := rootsFromRegalDirectory(regalDir) + roots, err := rootsFromRegalConfigDirOrFile(regalDir) if err != nil { return nil, fmt.Errorf("failed to get roots from .regal directory: %w", err) } @@ -254,6 +254,17 @@ func FindBundleRootDirectories(path string) ([]string, error) { foundBundleRoots = append(foundBundleRoots, roots...) } + // This will traverse the tree **upwards** searching for a .regal.yaml file + regalConfigFile, err := FindRegalConfigFile(path) + if err == nil { + roots, err := rootsFromRegalConfigDirOrFile(regalConfigFile) + if err != nil { + return nil, fmt.Errorf("failed to get roots from .regal.yaml: %w", err) + } + + foundBundleRoots = append(foundBundleRoots, roots...) + } + // This will traverse the tree **downwards** searching for .regal directories // Not using rio.WalkFiles here as we're specifically looking for directories if err := filepath.WalkDir(path, func(path string, info os.DirEntry, err error) error { @@ -272,7 +283,7 @@ func FindBundleRootDirectories(path string) ([]string, error) { defer rd.Close() - roots, err := rootsFromRegalDirectory(rd) + roots, err := rootsFromRegalConfigDirOrFile(rd) if err != nil { return fmt.Errorf("failed to get roots from .regal directory: %w", err) } @@ -295,23 +306,38 @@ func FindBundleRootDirectories(path string) ([]string, error) { return slices.Compact(foundBundleRoots), nil } -func rootsFromRegalDirectory(regalDir *os.File) ([]string, error) { - foundBundleRoots := make([]string, 0) +func rootsFromRegalConfigDirOrFile(file *os.File) ([]string, error) { + defer file.Close() - defer regalDir.Close() + fileInfo, err := file.Stat() + if err != nil { + return nil, fmt.Errorf("failed to stat file: %w", err) + } + + if (fileInfo.IsDir() && filepath.Base(file.Name()) != ".regal") || + (!fileInfo.IsDir() && filepath.Base(file.Name()) != ".regal.yaml") { + return nil, fmt.Errorf( + "expected a directory named '.regal' or a file named '.regal.yaml', got '%s'", + filepath.Base(file.Name()), + ) + } - parent, _ := filepath.Split(regalDir.Name()) + parent := filepath.Dir(file.Name()) - parent = filepath.Clean(parent) + foundBundleRoots := []string{parent} - // add the parent directory of .regal - foundBundleRoots = append(foundBundleRoots, parent) + var configFilePath string - file, err := os.ReadFile(filepath.Join(regalDir.Name(), "config.yaml")) + if fileInfo.IsDir() { + configFilePath = filepath.Join(file.Name(), "config.yaml") + } else { + configFilePath = file.Name() + } + + fileContent, err := os.ReadFile(configFilePath) if err == nil { var conf Config - - if err = yaml.Unmarshal(file, &conf); err != nil { + if err = yaml.Unmarshal(fileContent, &conf); err != nil { return nil, fmt.Errorf("failed to unmarshal config file: %w", err) } @@ -322,14 +348,18 @@ func rootsFromRegalDirectory(regalDir *os.File) ([]string, error) { } } - customRulesDir := filepath.Join(regalDir.Name(), "rules") + // Include the "rules" directory when loading from a .regal dir + if fileInfo.IsDir() { + customDir := filepath.Join(file.Name(), "rules") - info, err := os.Stat(customRulesDir) - if err == nil && info.IsDir() { - foundBundleRoots = append(foundBundleRoots, customRulesDir) + info, err := os.Stat(customDir) + if err == nil && info.IsDir() { + foundBundleRoots = append(foundBundleRoots, customDir) + } } - manifestRoots, err := rio.FindManifestLocations(filepath.Dir(regalDir.Name())) + // Include a search for manifest files + manifestRoots, err := rio.FindManifestLocations(parent) if err != nil { return nil, fmt.Errorf("failed while looking for manifest locations: %w", err) } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index cb133c5a..6c5423f9 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -175,7 +175,42 @@ project: expected := util.Map(util.FilepathJoiner(root), []string{"", ".regal/rules", "baz", "bundle", "foo/bar"}) if !slices.Equal(expected, locations) { - t.Errorf("expected %v, got %v", expected, locations) + t.Errorf("expected\n%s\ngot\n%s", strings.Join(expected, "\n"), strings.Join(locations, "\n")) + } + }) +} + +func TestFindBundleRootDirectoriesWithStandaloneConfig(t *testing.T) { + t.Parallel() + + cfg := ` +project: + roots: + - foo/bar + - baz +` + + fs := map[string]string{ + "/.regal.yaml": cfg, // root from config + "/bundle/.manifest": "", // bundle from .manifest + "/foo/bar/baz/policy.rego": "", // foo/bar from config + "/baz": "", // baz from config + } + + test.WithTempFS(fs, func(root string) { + locations, err := FindBundleRootDirectories(root) + if err != nil { + t.Error(err) + } + + if len(locations) != 4 { + t.Errorf("expected 5 locations, got %d", len(locations)) + } + + expected := util.Map(util.FilepathJoiner(root), []string{"", "baz", "bundle", "foo/bar"}) + + if !slices.Equal(expected, locations) { + t.Errorf("expected\n%s\ngot\n%s", strings.Join(expected, "\n"), strings.Join(locations, "\n")) } }) }