diff --git a/main.go b/main.go index d581750..e0c6a11 100644 --- a/main.go +++ b/main.go @@ -3,8 +3,11 @@ package main import ( "context" "flag" - "io/ioutil" + "io" "log" + "net" + "net/http" + "net/http/pprof" "os" "os/signal" "strings" @@ -17,36 +20,10 @@ import ( "github.com/voc/srtrelay/srt" ) -func handleSignal(ctx context.Context, cancel context.CancelFunc) { - // Set up channel on which to send signal notifications. - // We must use a buffered channel or risk missing the signal - // if we're not ready to receive when the signal is sent. - c := make(chan os.Signal, 1) - signal.Notify(c, - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM) - - go func() { - for { - select { - case <-ctx.Done(): - return - case s := <-c: - log.Println("caught signal", s) - if s == syscall.SIGHUP { - continue - } - cancel() - } - } - }() -} - func main() { // allow specifying config path configFlags := flag.NewFlagSet("config", flag.ContinueOnError) - configFlags.SetOutput(ioutil.Discard) + configFlags.SetOutput(io.Discard) configPath := configFlags.String("config", "config.toml", "") configFlags.Parse(os.Args[1:]) @@ -65,8 +42,16 @@ func main() { flag.UintVar(&conf.App.Latency, "latency", conf.App.Latency, "srt protocol latency in ms") flag.UintVar(&conf.App.Buffersize, "buffersize", conf.App.Buffersize, `relay buffer size in bytes, determines maximum delay of a client`) + profile := flag.String("pprof", "", "enable profiling server on given address") flag.Parse() + if *profile != "" { + log.Println("Enabling profiling on", *profile) + if err := enablePprof(*profile); err != nil { + log.Println("failed to enable profiling:", err) + } + } + conf.App.Addresses = strings.Split(addresses, ",") auth, err := config.GetAuthenticator(conf.Auth) @@ -117,3 +102,48 @@ func main() { } srtgo.CleanupSRT() } + +func enablePprof(addr string) error { + conn, err := net.Listen("tcp", addr) + if err != nil { + return err + } + mux := http.NewServeMux() + mux.HandleFunc("/debug/pprof/", pprof.Index) + srv := http.Server{ + Handler: mux, + } + go func() { + err := srv.Serve(conn) + if err != nil && err != http.ErrServerClosed { + log.Println(err) + } + }() + return nil +} + +func handleSignal(ctx context.Context, cancel context.CancelFunc) { + // Set up channel on which to send signal notifications. + // We must use a buffered channel or risk missing the signal + // if we're not ready to receive when the signal is sent. + c := make(chan os.Signal, 1) + signal.Notify(c, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGTERM) + + go func() { + for { + select { + case <-ctx.Done(): + return + case s := <-c: + log.Println("caught signal", s) + if s == syscall.SIGHUP { + continue + } + cancel() + } + } + }() +}