Skip to content

Commit

Permalink
add a flag for Enum Generate methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt Jones committed Oct 7, 2015
1 parent 1821c94 commit 2631e8f
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions generator/go.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ import (
)

var (
flagGoBinarystring = flag.Bool("go.binarystring", false, "Always use string for binary instead of []byte")
flagGoJSONEnumnum = flag.Bool("go.json.enumnum", false, "For JSON marshal enums by number instead of name")
flagGoPointers = flag.Bool("go.pointers", false, "Make all fields pointers")
flagGoImportPrefix = flag.String("go.importprefix", "", "Prefix for thrift-generated go package imports")
flagGoBinarystring = flag.Bool("go.binarystring", false, "Always use string for binary instead of []byte")
flagGoJSONEnumnum = flag.Bool("go.json.enumnum", false, "For JSON marshal enums by number instead of name")
flagGoPointers = flag.Bool("go.pointers", false, "Make all fields pointers")
flagGoImportPrefix = flag.String("go.importprefix", "", "Prefix for thrift-generated go package imports")
flagGoGenerateMethods = flag.Bool("go.generate", false, "Add testing/quick compatible Generate methods to enum types")
)

var (
Expand Down Expand Up @@ -423,21 +424,23 @@ func (e *%s) UnmarshalJSON(b []byte) error {
}
`, enumName, enumName, enumName, enumName)

valueStrings := make([]string, 0, len(enum.Values))
for _, val := range enum.Values {
valueStrings = append(valueStrings, strconv.FormatInt(int64(val.Value), 10))
}
sort.Strings(valueStrings)
valueStringsName := strings.ToLower(enumName) + "Values"
if *flagGoGenerateMethods {
valueStrings := make([]string, 0, len(enum.Values))
for _, val := range enum.Values {
valueStrings = append(valueStrings, strconv.FormatInt(int64(val.Value), 10))
}
sort.Strings(valueStrings)
valueStringsName := strings.ToLower(enumName) + "Values"

g.write(out, `
g.write(out, `
var %s = []int32{%s}
func (e *%s) Generate(rand *rand.Rand, size int) reflect.Value {
v := %s(%s[rand.Intn(%d)])
return reflect.ValueOf(&v)
}
`, valueStringsName, strings.Join(valueStrings, ", "), enumName, enumName, valueStringsName, len(valueNames))
}

return nil
}
Expand Down Expand Up @@ -637,7 +640,11 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p
// Imports
imports := []string{"fmt"}
if len(thrift.Enums) > 0 {
imports = append(imports, "strconv", "math/rand", "reflect")
imports = append(imports, "strconv")

if *flagGoGenerateMethods {
imports = append(imports, "math/rand", "reflect")
}
}
if len(thrift.Includes) > 0 {
for _, path := range thrift.Includes {
Expand Down

0 comments on commit 2631e8f

Please sign in to comment.