Skip to content

Commit

Permalink
recognize standard package from go root
Browse files Browse the repository at this point in the history
  • Loading branch information
wusendong committed May 30, 2019
1 parent 510d79b commit ac401eb
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 20 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
module github.com/wusendong/gogimport
98 changes: 78 additions & 20 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,42 @@ import (
"bytes"
"errors"
"flag"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/printer"
"go/token"
"io"
"io/ioutil"
"log"
"os"
"os/exec"
"os/user"
"runtime"
"sort"
"strings"
)

var rootPkg = flag.String("local", "", "local package name")

func main() {
var (
err error
)

flag.Parse()
if len(*rootPkg) <= 0 {
flag.Usage()
log.Fatalln("local must set")
}
files := os.Args[3:]

err = initStdPkg()
if err != nil {
log.Fatalf("init std package failed: %v", err)
}

files := flag.Args()
for _, filename := range files {
sortFile(filename)
}
Expand Down Expand Up @@ -107,12 +121,13 @@ func (st *Sorter) sortSpecs(specs []ast.Spec) (results []ast.Spec) {
for _, spec := range specs {
switch im := spec.(type) {
case *ast.ImportSpec:
if strings.HasPrefix(im.Path.Value, `"`+st.rootPkg) {
switch {
case strings.HasPrefix(im.Path.Value, `"`+st.rootPkg):
appPkg = append(appPkg, im)
} else if isThirparty(im.Path.Value) {
thirdpartyPkg = append(thirdpartyPkg, im)
} else {
case stdPkgs[im.Path.Value]:
innerPkg = append(innerPkg, im)
default:
thirdpartyPkg = append(thirdpartyPkg, im)
}
if lowestPos >= im.Pos() {
lowestPos = im.Pos()
Expand Down Expand Up @@ -269,21 +284,6 @@ func deduline(lines []int) []int {
return lines
}

func isThirparty(path string) bool {
for _, pkg := range thirdpartyPrefix {
if strings.HasPrefix(path, `"`+pkg) {
return true
}
}
return false
}

var thirdpartyPrefix = []string{
"github",
"gitlab",
"gopkg",
}

// Sorter gogimport sorter
type Sorter struct {
filename string
Expand All @@ -304,3 +304,61 @@ const (
MaxInt = int(MaxUint >> 1)
MinInt = -MaxInt - 1
)

var stdPkgs = map[string]bool{}

func initStdPkg() error {
me, err := user.Current()
if err != nil {
return err
}

cacheDir := me.HomeDir + "/.gogimport"
if err = os.MkdirAll(cacheDir, 0666); err != nil {
return fmt.Errorf("mkdir failed: %v", err)
}
cacheFileName := cacheDir + "/" + runtime.Version()
stat, statErr := os.Stat(cacheFileName)
if statErr != nil && !os.IsNotExist(statErr) {
return statErr
}

var reader io.ReadWriter
cacheFile, err := os.OpenFile(cacheFileName, os.O_RDWR|os.O_CREATE, 0666)
if err != nil {
return fmt.Errorf("open cache file failed: %v", err)
}
defer cacheFile.Close()
if os.IsNotExist(statErr) || stat.Size() < 10 {
cmd := exec.Command("go", "list", "./...")
cmd.Dir = strings.TrimSpace(runtime.GOROOT()) + "/src/"
var stderr bytes.Buffer
var stdout bytes.Buffer

cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err = cmd.Run(); err != nil {
if _, ok := err.(*exec.ExitError); ok {
return fmt.Errorf("list standard package failed: %s", stderr.Bytes())
}
return fmt.Errorf("list standard package failed: %v", err.Error())
}
if err = cacheFile.Truncate(0); err != nil {
return fmt.Errorf("truncate cache file failed: %v", err)
}

reader = &bytes.Buffer{}
if _, err = stdout.WriteTo(io.MultiWriter(reader, cacheFile)); err != nil {
return fmt.Errorf("write cache file failed: %v", err)
}
} else {
reader = cacheFile
}

// find std packages
sc := bufio.NewScanner(reader)
for sc.Scan() {
stdPkgs[sc.Text()] = true
}
return nil
}

0 comments on commit ac401eb

Please sign in to comment.