Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add configuration file #140

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 74 additions & 62 deletions cmd/rest-server/main.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
package main

import (
"errors"
"context"
"fmt"
"log"
"net/http"
"os"
"path/filepath"
"runtime"
"runtime/pprof"

"github.com/PowerDNS/go-tlsconfig"
"github.com/c2h5oh/datasize"
restserver "github.com/restic/rest-server"
"github.com/restic/rest-server/config"
"github.com/spf13/cobra"
)

Expand All @@ -25,57 +27,37 @@ var cmdRoot = &cobra.Command{
//Version: fmt.Sprintf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH),
}

var server = restserver.Server{
Path: "/tmp/restic",
Listen: ":8000",
}

var (
showVersion bool
cpuProfile string
showVersion bool
cpuProfile string
maxSizeBytes uint64
tlsEnabled bool
configFile string
flagConfig = config.Config{}
)

func init() {
flags := cmdRoot.Flags()
flags.StringVarP(&configFile, "config", "c", configFile, "path to YAML config file")
flags.StringVar(&cpuProfile, "cpu-profile", cpuProfile, "write CPU profile to file")
flags.BoolVar(&server.Debug, "debug", server.Debug, "output debug messages")
flags.StringVar(&server.Listen, "listen", server.Listen, "listen address")
flags.StringVar(&server.Log, "log", server.Log, "log HTTP requests in the combined log format")
flags.Int64Var(&server.MaxRepoSize, "max-size", server.MaxRepoSize, "the maximum size of the repository in bytes")
flags.StringVar(&server.Path, "path", server.Path, "data directory")
flags.BoolVar(&server.TLS, "tls", server.TLS, "turn on TLS support")
flags.StringVar(&server.TLSCert, "tls-cert", server.TLSCert, "TLS certificate path")
flags.StringVar(&server.TLSKey, "tls-key", server.TLSKey, "TLS key path")
flags.BoolVar(&server.NoAuth, "no-auth", server.NoAuth, "disable .htpasswd authentication")
flags.BoolVar(&server.AppendOnly, "append-only", server.AppendOnly, "enable append only mode")
flags.BoolVar(&server.PrivateRepos, "private-repos", server.PrivateRepos, "users can only access their private repo")
flags.BoolVar(&server.Prometheus, "prometheus", server.Prometheus, "enable Prometheus metrics")
flags.BoolVar(&server.Prometheus, "prometheus-no-auth", server.PrometheusNoAuth, "disable auth for Prometheus /metrics endpoint")
flags.BoolVar(&flagConfig.Debug, "debug", flagConfig.Debug, "output debug messages")
flags.StringVar(&flagConfig.Listen, "listen", flagConfig.Listen, "listen address")
flags.StringVar(&flagConfig.AccessLog, "log", flagConfig.AccessLog, "log HTTP requests in the combined log format")
flags.Uint64Var(&maxSizeBytes, "max-size", uint64(flagConfig.Quota.MaxSize), "the maximum size of the repository in bytes")
flags.StringVar(&flagConfig.Path, "path", flagConfig.Path, "data directory")
flags.BoolVar(&tlsEnabled, "tls", flagConfig.TLS.HasCertWithKey(), "turn on TLS support")
flags.StringVar(&flagConfig.TLS.CertFile, "tls-cert", flagConfig.TLS.CertFile, "TLS certificate path")
flags.StringVar(&flagConfig.TLS.KeyFile, "tls-key", flagConfig.TLS.KeyFile, "TLS key path")
flags.BoolVar(&flagConfig.Auth.Disabled, "no-auth", flagConfig.Auth.Disabled, "disable .htpasswd authentication")
flags.BoolVar(&flagConfig.AppendOnly, "append-only", flagConfig.AppendOnly, "enable append only mode")
flags.BoolVar(&flagConfig.PrivateRepos, "private-repos", flagConfig.PrivateRepos, "users can only access their private repo")
flags.BoolVar(&flagConfig.Metrics.Enabled, "prometheus", flagConfig.Metrics.Enabled, "enable Prometheus metrics")
flags.BoolVar(&flagConfig.Metrics.NoAuth, "prometheus-no-auth", flagConfig.Metrics.NoAuth, "disable auth for Prometheus /metrics endpoint")
flags.BoolVarP(&showVersion, "version", "V", showVersion, "output version and exit")
}

var version = "0.10.0-dev"

func tlsSettings() (bool, string, string, error) {
var key, cert string
if !server.TLS && (server.TLSKey != "" || server.TLSCert != "") {
return false, "", "", errors.New("requires enabled TLS")
} else if !server.TLS {
return false, "", "", nil
}
if server.TLSKey != "" {
key = server.TLSKey
} else {
key = filepath.Join(server.Path, "private_key")
}
if server.TLSCert != "" {
cert = server.TLSCert
} else {
cert = filepath.Join(server.Path, "public_key")
}
return server.TLS, key, cert, nil
}

func runRoot(cmd *cobra.Command, args []string) error {
if showVersion {
fmt.Printf("rest-server %s compiled with %v on %v/%v\n", version, runtime.Version(), runtime.GOOS, runtime.GOARCH)
Expand All @@ -84,7 +66,26 @@ func runRoot(cmd *cobra.Command, args []string) error {

log.SetFlags(0)

log.Printf("Data directory: %s", server.Path)
// Load config
conf := config.Default()
if configFile != "" {
if err := conf.LoadYAMLFile(configFile); err != nil {
return err
}
}

// Merge flag config
conf.Quota.MaxSize = datasize.ByteSize(maxSizeBytes)
conf.MergeFlags(flagConfig)
if conf.Debug {
log.Printf("Effective config:\n%s", conf.String())
}
if err := conf.Check(); err != nil {
return err
}
if tlsEnabled && !conf.TLS.HasCertWithKey() {
return fmt.Errorf("--tls set, but key and cert not configured")
}

if cpuProfile != "" {
f, err := os.Create(cpuProfile)
Expand All @@ -98,40 +99,51 @@ func runRoot(cmd *cobra.Command, args []string) error {
defer pprof.StopCPUProfile()
}

if server.NoAuth {
log.Printf("Data directory: %s", conf.Path)
if conf.Auth.Disabled {
log.Println("Authentication disabled")
} else {
log.Println("Authentication enabled")
}

handler, err := restserver.NewHandler(&server)
if err != nil {
log.Fatalf("error: %v", err)
}

if server.PrivateRepos {
if conf.PrivateRepos {
log.Println("Private repositories enabled")
} else {
log.Println("Private repositories disabled")
}

enabledTLS, privateKey, publicKey, err := tlsSettings()
server, err := restserver.NewServer(*conf)
if err != nil {
return err
}
handler, err := restserver.NewHandler(server)
if err != nil {
return err
}
if !enabledTLS {
log.Printf("Starting server on %s\n", server.Listen)
err = http.ListenAndServe(server.Listen, handler)
} else {

ctx := context.Background()
if !conf.TLS.HasCertWithKey() {
log.Printf("Starting server on %s\n", conf.Listen)
return http.ListenAndServe(conf.Listen, handler)
} else {
log.Println("TLS enabled")
log.Printf("Private key: %s", privateKey)
log.Printf("Public key(certificate): %s", publicKey)
log.Printf("Starting server on %s\n", server.Listen)
err = http.ListenAndServeTLS(server.Listen, publicKey, privateKey, handler)
log.Printf("Starting server on %s\n", conf.Listen)
manager, err := tlsconfig.NewManager(ctx, conf.TLS, tlsconfig.Options{
IsServer: true,
})
if err != nil {
return err
}
tlsConfig, err := manager.TLSConfig()
if err != nil {
return err
}
hs := http.Server{
Addr: conf.Listen,
Handler: handler,
TLSConfig: tlsConfig,
}
return hs.ListenAndServeTLS("", "") // Certificates are handled by TLSConfig
}

return err
}

func main() {
Expand Down
65 changes: 0 additions & 65 deletions cmd/rest-server/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,71 +9,6 @@ import (
restserver "github.com/restic/rest-server"
)

func TestTLSSettings(t *testing.T) {
type expected struct {
TLSKey string
TLSCert string
Error bool
}
type passed struct {
Path string
TLS bool
TLSKey string
TLSCert string
}

var tests = []struct {
passed passed
expected expected
}{
{passed{TLS: false}, expected{"", "", false}},
{passed{TLS: true}, expected{"/tmp/restic/private_key", "/tmp/restic/public_key", false}},
{passed{Path: "/tmp", TLS: true}, expected{"/tmp/private_key", "/tmp/public_key", false}},
{passed{Path: "/tmp", TLS: true, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"/etc/restic/key", "/etc/restic/cert", false}},
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key", TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
{passed{Path: "/tmp", TLS: false, TLSKey: "/etc/restic/key"}, expected{"", "", true}},
{passed{Path: "/tmp", TLS: false, TLSCert: "/etc/restic/cert"}, expected{"", "", true}},
}

for _, test := range tests {

t.Run("", func(t *testing.T) {
// defer func() { restserver.Server = defaultConfig }()
if test.passed.Path != "" {
server.Path = test.passed.Path
}
server.TLS = test.passed.TLS
server.TLSKey = test.passed.TLSKey
server.TLSCert = test.passed.TLSCert

gotTLS, gotKey, gotCert, err := tlsSettings()
if err != nil && !test.expected.Error {
t.Fatalf("tls_settings returned err (%v)", err)
}
if test.expected.Error {
if err == nil {
t.Fatalf("Error not returned properly (%v)", test)
} else {
return
}
}
if gotTLS != test.passed.TLS {
t.Errorf("TLS enabled, want (%v), got (%v)", test.passed.TLS, gotTLS)
}
wantKey := test.expected.TLSKey
if gotKey != wantKey {
t.Errorf("wrong TLSPrivPath path, want (%v), got (%v)", wantKey, gotKey)
}

wantCert := test.expected.TLSCert
if gotCert != wantCert {
t.Errorf("wrong TLSCertPath path, want (%v), got (%v)", wantCert, gotCert)
}

})
}
}

func TestGetHandler(t *testing.T) {
dir, err := ioutil.TempDir("", "rest-server-test")
if err != nil {
Expand Down
121 changes: 121 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Package config contains the configuration structures for rest-server
package config

import (
"fmt"
"io/ioutil"
"log"

"github.com/PowerDNS/go-tlsconfig"
"github.com/c2h5oh/datasize"
"gopkg.in/yaml.v2"
)

// Config is the config root object
type Config struct {
Path string `yaml:"path"`
AppendOnly bool `yaml:"append_only"`
PrivateRepos bool `yaml:"private_repos"`
Listen string `yaml:"listen"` // Address like ":8000"
TLS tlsconfig.Config `yaml:"tls"`
AccessLog string `yaml:"access_log"`
Debug bool `yaml:"debug"`
Quota Quota `yaml:"quota"`
Metrics Metrics `yaml:"metrics"`
Auth Auth `yaml:"auth"`
Users map[string]User `yaml:"users"`
}

// Quota configures disk usage quota enforcements
type Quota struct {
Scope string `yaml:"scope,omitempty"`
MaxSize datasize.ByteSize `yaml:"max_size"`
}

// Metrics configures Prometheus metrics
type Metrics struct {
Enabled bool `yaml:"enabled"`
NoAuth bool `yaml:"no_auth"`
}

// Auth configures authentication
type Auth struct {
Disabled bool `yaml:"disabled"`
Backend string `yaml:"backend,omitempty"`
HTPasswdFile string `yaml:"htpasswd_file"`
}

// User configures user overrides
type User struct {
AppendOnly *bool `yaml:"append_only,omitempty"`
PrivateRepos *bool `yaml:"private_repos,omitempty"`
}

// Check validates a Config instance
func (c Config) Check() error {
return nil
}

// String returns the config as a YAML string
func (c Config) String() string {
y, err := yaml.Marshal(c)
if err != nil {
log.Panicf("YAML marshal of config failed: %v", err) // Should never happen
}
return string(y)
}

// LoadYAML loads config from YAML. Any set value overwrites any existing value,
// but omitted keys are untouched.
func (c *Config) LoadYAML(yamlContents []byte) error {
return yaml.UnmarshalStrict(yamlContents, c)
}

// LoadYAML loads config from a YAML file. Any set value overwrites any existing value,
// but omitted keys are untouched.
func (c *Config) LoadYAMLFile(fpath string) error {
contents, err := ioutil.ReadFile(fpath)
if err != nil {
return fmt.Errorf("open yaml file: %w", err)
}
return c.LoadYAML(contents)
}

func mergeString(a, b string) string {
if b != "" {
return b
}
return a
}

// MergeFlags merges configuration set by commandline flags into the current Config
func (c *Config) MergeFlags(fc Config) {
c.Debug = c.Debug || fc.Debug
c.Listen = mergeString(c.Listen, fc.Listen)
c.AccessLog = mergeString(c.AccessLog, fc.AccessLog)
if fc.Quota.MaxSize > 0 {
c.Quota.MaxSize = fc.Quota.MaxSize
}
c.Path = mergeString(c.Path, fc.Path)
c.TLS.CertFile = mergeString(c.TLS.CertFile, fc.TLS.CertFile)
c.TLS.KeyFile = mergeString(c.TLS.KeyFile, fc.TLS.KeyFile)
c.Auth.Disabled = c.Auth.Disabled || fc.Auth.Disabled
c.AppendOnly = c.AppendOnly || fc.AppendOnly
c.PrivateRepos = c.PrivateRepos || fc.PrivateRepos
c.Metrics.Enabled = c.Metrics.Enabled || fc.Metrics.Enabled
c.Metrics.NoAuth = c.Metrics.NoAuth || fc.Metrics.NoAuth
}

// Default returns a Config with default settings
func Default() *Config {
return &Config{
Path: "/tmp/restic",
Listen: ":8000",
Users: make(map[string]User),
Auth: Auth{
Disabled: false,
Backend: "htpasswd",
HTPasswdFile: ".htpasswd",
},
}
}
Loading