diff --git a/check.go b/check.go index 5ec196a..0912520 100644 --- a/check.go +++ b/check.go @@ -167,7 +167,7 @@ next: // C is a type passed to all check functions to provide context. type C struct { Filename string - Includes []string + Dirs []string Program *ast.Program Check string Messages Messages @@ -204,7 +204,7 @@ func (c *C) Errorf(node ast.Node, message string, args ...interface{}) { // Resolve resolves a type reference. func (c *C) Resolve(ref ast.TypeReference) ast.Node { - if n, err := Resolve(ref, c.Program, c.Includes); err == nil { + if n, err := Resolve(ref, c.Program, c.Dirs); err == nil { return n } return nil @@ -212,7 +212,7 @@ func (c *C) Resolve(ref ast.TypeReference) ast.Node { // ResolveType resolves a type reference to its target type. func (c *C) ResolveType(ref ast.TypeReference) ast.Node { - if n, err := ResolveType(ref, c.Program, c.Includes); err == nil { + if n, err := ResolveType(ref, c.Program, c.Dirs); err == nil { return n } return nil diff --git a/checks/includes.go b/checks/includes.go index f24b60b..01ecc45 100644 --- a/checks/includes.go +++ b/checks/includes.go @@ -36,14 +36,8 @@ func CheckIncludePath() *thriftcheck.Check { return } - // Check the current directory first to match `thrift`s behavior. - dirs := c.Includes - if cwd, err := os.Getwd(); err != nil { - dirs = append([]string{cwd}, c.Includes...) - } - found := false - for _, dir := range dirs { + for _, dir := range c.Dirs { if _, err := os.Stat(filepath.Join(dir, i.Path)); err == nil { found = true break diff --git a/linter.go b/linter.go index 8fa3496..41789e8 100644 --- a/linter.go +++ b/linter.go @@ -21,6 +21,7 @@ import ( "io/ioutil" "log" "os" + "path/filepath" "strings" "go.uber.org/thriftrw/ast" @@ -115,7 +116,7 @@ func (l *Linter) lint(program *ast.Program, filename string, parseInfo *idl.Info ctx := &C{ Filename: filename, - Includes: l.includes, + Dirs: append([]string{filepath.Dir(filename)}, l.includes...), Program: program, logger: l.logger, parseInfo: parseInfo, diff --git a/parse.go b/parse.go index 909d926..b497165 100644 --- a/parse.go +++ b/parse.go @@ -47,11 +47,6 @@ func ParseFile(filename string, dirs []string) (*ast.Program, *idl.Info, error) return nil, nil, fmt.Errorf("%s not found", filename) } - // Check the current directory first to match `thrift`s behavior. - if cwd, err := os.Getwd(); err != nil { - dirs = append([]string{cwd}, dirs...) - } - for _, dir := range dirs { if f, err := os.Open(filepath.Join(dir, filename)); err == nil { return Parse(f)