Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: minor tweaks #45

Merged
merged 1 commit into from
Nov 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 21 additions & 26 deletions env.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Package env provides an API for loading environment variables into structs.
// Package env implements loading environment variables into a config struct.
package env

import (
Expand All @@ -8,55 +8,51 @@ import (
"strings"
)

// Options are options for the [Load] function.
// Options are the options for the [Load] function.
type Options struct {
Source Source // The source of environment variables. The default is [OS].
SliceSep string // The separator used to parse slice values. The default is space.
}

// NotSetError is returned when environment variables are marked as required but not set.
type NotSetError struct {
// The names of the missing required environment variables.
Names []string
Names []string // The names of the missing environment variables.
}

// Error implements the error interface.
func (e *NotSetError) Error() string {
return fmt.Sprintf("env: %v are required but not set", e.Names)
if len(e.Names) == 1 {
return fmt.Sprintf("env: %s is required but not set", e.Names[0])
}
return fmt.Sprintf("env: %s are required but not set", strings.Join(e.Names, " "))
}

// Load loads environment variables into the provided struct using the [OS] [Source].
// Load loads environment variables into the given struct.
// cfg must be a non-nil struct pointer, otherwise Load panics.
// If opts is nil, the default [Options] are used.
//
// The struct fields must have the `env:"VAR"` struct tag, where VAR is the name of the corresponding environment variable.
// The struct fields must have the `env:"VAR"` struct tag,
// where VAR is the name of the corresponding environment variable.
// Unexported fields are ignored.
//
// # Supported types
//
// The following types are supported:
// - int (any kind)
// - float (any kind)
// - bool
// - string
// - [time.Duration]
// - [encoding.TextUnmarshaler]
// - slices of any type above (space is the default separator for values)
//
// See the [strconv].Parse* functions for parsing rules.
// Implementing the [encoding.TextUnmarshaler] interface is enough to use any user-defined type.
// Nested structs of any depth level are supported, only the leaves of the config tree must have the `env` tag.
//
// # Default values
// - slices of any type above
// - nested structs of any depth
//
// Default values can be specified either using the `default` struct tag (has a higher priority) or by initializing the struct fields directly.
// See the [strconv].Parse* functions for the parsing rules.
// User-defined types can be used by implementing the [encoding.TextUnmarshaler] interface.
//
// # Per-variable options
//
// The name of the environment variable can be followed by comma-separated options in the form of `env:"VAR,option1,option2,..."`:
// Default values can be specified using the `default:"VALUE"` struct tag.
//
// The name of an environment variable can be followed by comma-separated options:
// - required: marks the environment variable as required
// - expand: expands the value of the environment variable using [os.Expand]
//
// If environment variables are marked as required but not set, an error of type [NotSetError] will be returned.
func Load(cfg any, opts *Options) error {
if opts == nil {
opts = new(Options)
Expand Down Expand Up @@ -113,20 +109,19 @@ func parseVars(v reflect.Value) []Var {
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if !field.CanSet() {
continue // skip unexported fields.
continue
}

// special case: a nested struct, parse its fields recursively.
if kindOf(field, reflect.Struct) && !implements(field, unmarshalerIface) {
nested := parseVars(field)
vars = append(vars, nested...)
vars = append(vars, parseVars(field)...)
continue
}

sf := v.Type().Field(i)
value, ok := sf.Tag.Lookup("env")
if !ok {
continue // skip fields without the `env` tag.
continue
}

parts := strings.Split(value, ",")
Expand Down
2 changes: 1 addition & 1 deletion env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestLoad(t *testing.T) {
})

t.Run("unsupported type", func(t *testing.T) {
m := env.Map{"FOO": "1 + 2i"}
m := env.Map{"FOO": "1+2i"}

var cfg struct {
Foo complex64 `env:"FOO"`
Expand Down
14 changes: 6 additions & 8 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,19 @@ func ExampleLoad_nestedStruct() {
}

func ExampleLoad_required() {
os.Unsetenv("HOST")
os.Unsetenv("PORT")

var cfg struct {
Host string `env:"HOST,required"`
Port int `env:"PORT,required"`
Port int `env:"PORT,required"`
}
if err := env.Load(&cfg, nil); err != nil {
var notSetErr *env.NotSetError
if errors.As(err, &notSetErr) {
fmt.Println(notSetErr.Names)
fmt.Println(notSetErr)
}
}

// Output: [HOST PORT]
// Output: env: PORT is required but not set
}

func ExampleLoad_expand() {
Expand Down Expand Up @@ -100,12 +98,12 @@ func ExampleLoad_source() {
}

func ExampleLoad_sliceSeparator() {
os.Setenv("PORTS", "8080;8081;8082")
os.Setenv("PORTS", "8080,8081,8082")

var cfg struct {
Ports []int `env:"PORTS"`
}
if err := env.Load(&cfg, &env.Options{SliceSep: ";"}); err != nil {
if err := env.Load(&cfg, &env.Options{SliceSep: ","}); err != nil {
fmt.Println(err)
}

Expand All @@ -129,7 +127,7 @@ func ExampleUsage() {
env.Usage(&cfg, os.Stdout)
}

// Output: env: [DB_HOST DB_PORT] are required but not set
// Output: env: DB_HOST DB_PORT are required but not set
// Usage:
// DB_HOST string required database host
// DB_PORT int required database port
Expand Down
16 changes: 4 additions & 12 deletions reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ var (
unmarshalerIface = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
)

// typeOf reports whether v's type is one of the provided types.
func typeOf(v reflect.Value, types ...reflect.Type) bool {
for _, t := range types {
if t == v.Type() {
Expand All @@ -23,7 +22,6 @@ func typeOf(v reflect.Value, types ...reflect.Type) bool {
return false
}

// kindOf reports whether v's kind is one of the provided kinds.
func kindOf(v reflect.Value, kinds ...reflect.Kind) bool {
for _, k := range kinds {
if k == v.Kind() {
Expand All @@ -33,22 +31,19 @@ func kindOf(v reflect.Value, kinds ...reflect.Kind) bool {
return false
}

// implements reports whether v's type implements one of the provided interfaces.
func implements(v reflect.Value, ifaces ...reflect.Type) bool {
for _, iface := range ifaces {
if t := v.Type(); t.Implements(iface) || reflect.PtrTo(v.Type()).Implements(iface) {
if t := v.Type(); t.Implements(iface) || reflect.PtrTo(t).Implements(iface) {
return true
}
}
return false
}

// structPtr reports whether v is a non-nil struct pointer.
func structPtr(v reflect.Value) bool {
return v.IsValid() && v.Kind() == reflect.Ptr && v.Elem().Kind() == reflect.Struct && !v.IsNil()
}

// setValue parses s based on v's type/kind and sets v's underlying value to the result.
func setValue(v reflect.Value, s string) error {
switch {
case typeOf(v, durationType):
Expand All @@ -71,8 +66,7 @@ func setValue(v reflect.Value, s string) error {
}

func setInt(v reflect.Value, s string) error {
bits := v.Type().Bits()
i, err := strconv.ParseInt(s, 10, bits)
i, err := strconv.ParseInt(s, 10, v.Type().Bits())
if err != nil {
return fmt.Errorf("parsing int: %w", err)
}
Expand All @@ -81,8 +75,7 @@ func setInt(v reflect.Value, s string) error {
}

func setUint(v reflect.Value, s string) error {
bits := v.Type().Bits()
u, err := strconv.ParseUint(s, 10, bits)
u, err := strconv.ParseUint(s, 10, v.Type().Bits())
if err != nil {
return fmt.Errorf("parsing uint: %w", err)
}
Expand All @@ -91,8 +84,7 @@ func setUint(v reflect.Value, s string) error {
}

func setFloat(v reflect.Value, s string) error {
bits := v.Type().Bits()
f, err := strconv.ParseFloat(s, bits)
f, err := strconv.ParseFloat(s, v.Type().Bits())
if err != nil {
return fmt.Errorf("parsing float: %w", err)
}
Expand Down