Skip to content

Commit

Permalink
toolexec: Add required-packages flag
Browse files Browse the repository at this point in the history
toolexec mode only rewrites packages that have import the errtrace
package, and others are silently ignored which can lead to unintended
missed rewrites.

Accept package selectors of packages that must import errtrace if a
rewrite is required.

Most of this change is setting up the machinery for toolexec flags:
 - Set up a new flagset for toolexec flags, which are only parsed once
   we know we're in toolexec mode.
 - Include a hash of the flag values in the version key as a previously
   cached output may not be valid when flags change. E.g., a previous
   package success should fail if the required-packages flag changes and
   it contains a package that needs a rewrite but is missing the import.
  • Loading branch information
prashantv committed Nov 10, 2024
1 parent e12a8ed commit ae4efd5
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 50 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Experimental support for instrumenting code with errtrace automatically
as part of the Go build process.
Try this out with `go build -toolexec=errtrace pkg/to/build`.
Automatic instrumentation only rewrites packages that import errtrace.
The flag `-required-packages` can be used to specify which packages
are expected to import errtrace if they require rewrites.
Example: `go build -toolexec="errtrace -required-packages pkg/..." pkg/to/build`

### Changed
- Update `go` directive in go.mod to 1.21.
Expand Down
161 changes: 127 additions & 34 deletions cmd/errtrace/toolexec.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@ import (
"bytes"
"crypto/md5"
"encoding/hex"
"flag"
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"runtime/debug"
"slices"
"strings"

"braces.dev/errtrace"
Expand All @@ -21,33 +24,95 @@ func (cmd *mainCmd) handleToolExec(args []string) (exitCode int, handled bool) {
return -1, false
}

for _, arg := range args {
if arg == "-V=full" {
// compile is run first with "-V=full" to get a version number
// for caching build IDs.
// No TOOLEXEC_IMPORTPATH is set in this case.
return cmd.toolExecVersion(args), true
}
}

if cmd.Getenv == nil {
cmd.Getenv = os.Getenv
}
// When "-toolexec" is used, the go cmd sets the package being compiled in the env.
if pkg := cmd.Getenv("TOOLEXEC_IMPORTPATH"); pkg != "" {
return cmd.toolExecRewrite(pkg, args), true

// compile is run first with "-V=full" to get a version number
// for caching build IDs.
// No TOOLEXEC_IMPORTPATH is set in this case.
version := slices.Contains(args, "-V=full")
pkg := cmd.Getenv("TOOLEXEC_IMPORTPATH")
if !version && pkg == "" {
return -1, false
}

var p toolExecParams
if err := p.Parse(os.Stdout, args); err != nil {
cmd.log.Print(err)
return 1, true
}

return -1, false
if version {
return cmd.toolExecVersion(p), true
}
return cmd.toolExecRewrite(pkg, p), true
}

func (cmd *mainCmd) toolExecVersion(args []string) int {
type toolExecParams struct {
RequiredPkgSelectors []string

Tool string
ToolArgs []string

flags *flag.FlagSet
}

func (p *toolExecParams) Parse(w io.Writer, args []string) error {
p.flags = flag.NewFlagSet("errtrace (toolexec)", flag.ContinueOnError)
flag.Usage = func() {
logln(w, `usage with go build/run/test: -toolexec="errtrace [options]"`)
flag.PrintDefaults()
}
var requiredPkgs string
p.flags.StringVar(&requiredPkgs, "required-packages", "", "comma-separated list of package selectors "+
"that are expected to be import errtrace if they return errors.")

// Flag parsing stops at the first non-flag argument (no "-").
if err := p.flags.Parse(args); err != nil {
return errtrace.Wrap(err)
}

remArgs := p.flags.Args()
if len(remArgs) == 0 {
return errtrace.New("toolexec expected tool arguments")
}

p.Tool = remArgs[0]
p.ToolArgs = remArgs[1:]
p.RequiredPkgSelectors = strings.Split(requiredPkgs, ",")
return nil
}

// Options affect the generated code, so use a hash
// of any options for the toolexec version.
func (p *toolExecParams) versionCacheKey() string {
withoutTool := *p
withoutTool.flags = nil
withoutTool.Tool = ""
withoutTool.ToolArgs = nil

optStr := fmt.Sprintf("%v", withoutTool)
optHash := md5.Sum([]byte(optStr))
return hex.EncodeToString(optHash[:])
}

func (p *toolExecParams) requiredPackage(pkg string) bool {
for _, selector := range p.RequiredPkgSelectors {
if packageSelectorMatch(selector, pkg) {
return true
}
}
return false
}

func (cmd *mainCmd) toolExecVersion(p toolExecParams) int {
version, err := binaryVersion()
if err != nil {
logf(cmd.Stderr, "errtrace version failed: %v", err)
}

tool := exec.Command(args[0], args[1:]...)
tool := exec.Command(p.Tool, p.ToolArgs...)
var stdout bytes.Buffer
tool.Stdout = &stdout
tool.Stderr = cmd.Stderr
Expand All @@ -56,30 +121,36 @@ func (cmd *mainCmd) toolExecVersion(args []string) int {
return exitErr.ExitCode()
}

logf(cmd.Stderr, "tool %v failed: %v", args[0], err)
logf(cmd.Stderr, "tool %v failed: %v", p.Tool, err)
return 1
}

if _, err := fmt.Fprintf(cmd.Stdout, "%s-errtrace-%s\n", strings.TrimSpace(stdout.String()), version); err != nil {
if _, err := fmt.Fprintf(
cmd.Stdout,
"%s-errtrace-%s%s\n",
strings.TrimSpace(stdout.String()),
version,
p.versionCacheKey(),
); err != nil {
logf(cmd.Stderr, "failed to write version to stdout: %v", err)
return 1
}

return 0
}

func (cmd *mainCmd) toolExecRewrite(pkg string, args []string) (exitCode int) {
func (cmd *mainCmd) toolExecRewrite(pkg string, p toolExecParams) (exitCode int) {
// We only need to modify the arguments for "compile" calls which work with .go files.
if !isCompile(args[0]) {
return cmd.runOriginal(args)
if !isCompile(p.Tool) {
return cmd.runOriginal(p)
}

// We only modify files that import errtrace, so stdlib is never eliglble.
if isStdLib(args) {
return cmd.runOriginal(args)
if isStdLib(p.ToolArgs) {
return cmd.runOriginal(p)
}

exitCode, err := cmd.rewriteCompile(pkg, args)
exitCode, err := cmd.rewriteCompile(pkg, p)
if err != nil {
cmd.log.Print(err)
return 1
Expand All @@ -88,10 +159,10 @@ func (cmd *mainCmd) toolExecRewrite(pkg string, args []string) (exitCode int) {
return exitCode
}

func (cmd *mainCmd) rewriteCompile(pkg string, args []string) (exitCode int, _ error) {
parsed := make(map[string]parsedFile)
func (cmd *mainCmd) rewriteCompile(pkg string, p toolExecParams) (exitCode int, _ error) {
var canRewrite, needRewrite bool
for _, arg := range args {
parsed := make(map[string]parsedFile)
for _, arg := range p.ToolArgs {
if !isGoFile(arg) {
continue
}
Expand All @@ -116,8 +187,16 @@ func (cmd *mainCmd) rewriteCompile(pkg string, args []string) (exitCode int, _ e
}
}

if !canRewrite || !needRewrite {
return cmd.runOriginal(args), nil
if !needRewrite {
return cmd.runOriginal(p), nil
}

if !canRewrite {
if p.requiredPackage(pkg) {
logf(cmd.Stderr, "errtrace required package %v missing errtrace import, needs rewrite", pkg)
return 1, nil
}
return cmd.runOriginal(p), nil
}

// Use a temporary directory per-package that is rewritten.
Expand All @@ -127,8 +206,8 @@ func (cmd *mainCmd) rewriteCompile(pkg string, args []string) (exitCode int, _ e
}
defer os.RemoveAll(tempDir) //nolint:errcheck // best-effort removal of temp files.

newArgs := make([]string, 0, len(args))
for _, arg := range args {
newArgs := make([]string, 0, len(p.ToolArgs))
for _, arg := range p.ToolArgs {
f, ok := parsed[arg]
if !ok || len(f.inserts) == 0 {
newArgs = append(newArgs, arg)
Expand All @@ -152,7 +231,8 @@ func (cmd *mainCmd) rewriteCompile(pkg string, args []string) (exitCode int, _ e
newArgs = append(newArgs, newFile)
}

return cmd.runOriginal(newArgs), nil
p.ToolArgs = newArgs
return cmd.runOriginal(p), nil
}

func isCompile(arg string) bool {
Expand All @@ -166,8 +246,8 @@ func isGoFile(arg string) bool {
return strings.HasSuffix(arg, ".go")
}

func (cmd *mainCmd) runOriginal(args []string) (exitCode int) {
tool := exec.Command(args[0], args[1:]...)
func (cmd *mainCmd) runOriginal(p toolExecParams) (exitCode int) {
tool := exec.Command(p.Tool, p.ToolArgs...)
tool.Stdin = cmd.Stdin
tool.Stdout = cmd.Stdout
tool.Stderr = cmd.Stderr
Expand All @@ -176,7 +256,7 @@ func (cmd *mainCmd) runOriginal(args []string) (exitCode int) {
if exitErr, ok := err.(*exec.ExitError); ok {
return exitErr.ExitCode()
}
logf(cmd.Stderr, "tool %v failed: %v", args[0], err)
logf(cmd.Stderr, "tool %v failed: %v", p.Tool, err)
return 1
}

Expand Down Expand Up @@ -231,3 +311,16 @@ func readBuildSHA() (_ string, ok bool) {
func isStdLib(args []string) bool {
return slicesContains(args, "-std")
}

func packageSelectorMatch(selector, importPath string) bool {
if pkgPrefix, ok := strings.CutSuffix(selector, "..."); ok {
// foo/... should match foo, but not foobar so we want
// the pkgPrefix to contain the /.
if strings.TrimSuffix(pkgPrefix, "/") == importPath {
return true
}
return strings.HasPrefix(importPath, pkgPrefix)
}

return selector == importPath
}
Loading

0 comments on commit ae4efd5

Please sign in to comment.