Skip to content

Commit

Permalink
Fix unintuitive behaviour of RegisterAt with custom registry (#181)
Browse files Browse the repository at this point in the history
This commit fixes some counterintuitive behaviour when creating the middleware
with a custom registry.

Previously, using a custom registry with additional metrics requires avoiding
the RegisterAt function, which used only the prometheus.DefaultGatherer.

Given a *prometheus.Registry, instead of the more intuitive:

```
fp := fiberprometheus.NewWithRegistry(reg, "my-application", "http", "", nil)
fp.RegisterAt(app, "/metrics")
```

One would have to do this:

```
app.Get("/metrics", adaptor.HTTPHandler(promhttp.HandlerFor(reg, promhttp.HandlerOpts{})))
```

The change here detects when the `prometheus.Registerer` is also a
`prometheus.Gatherer`, and then uses the `prometheus.Gatherer` when calling
`RegisterAt`.

This change will _not_ fix work in the case of a custom implementation of
`prometheus.Registerer` when the that implementation is not _also_ a
`prometheus.Gatherer`, which is an unlikely situation for inexperienced users of
prometheus.

For the more common case of passing a `*prometheus.Registry`, this change will
provide a more intuitive experience.
  • Loading branch information
calloway-jacob authored Oct 21, 2023
1 parent 3f27009 commit 7015afd
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
11 changes: 10 additions & 1 deletion middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (

// FiberPrometheus ...
type FiberPrometheus struct {
gatherer prometheus.Gatherer
requestsTotal *prometheus.CounterVec
requestDuration *prometheus.HistogramVec
requestInFlight *prometheus.GaugeVec
Expand Down Expand Up @@ -107,7 +108,15 @@ func create(registry prometheus.Registerer, serviceName, namespace, subsystem st
ConstLabels: constLabels,
}, []string{"method"})

// If the registerer is also a gatherer, use it, falling back to the
// DefaultGatherer.
gatherer, ok := registry.(prometheus.Gatherer)
if !ok {
gatherer = prometheus.DefaultGatherer
}

return &FiberPrometheus{
gatherer: gatherer,
requestsTotal: counter,
requestDuration: histogram,
requestInFlight: gauge,
Expand Down Expand Up @@ -160,7 +169,7 @@ func NewWithRegistry(registry prometheus.Registerer, serviceName, namespace, sub
func (ps *FiberPrometheus) RegisterAt(app fiber.Router, url string, handlers ...fiber.Handler) {
ps.defaultURL = url

h := append(handlers, adaptor.HTTPHandler(promhttp.Handler()))
h := append(handlers, adaptor.HTTPHandler(promhttp.HandlerFor(ps.gatherer, promhttp.HandlerOpts{})))
app.Get(ps.defaultURL, h...)
}

Expand Down
46 changes: 46 additions & 0 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package fiberprometheus

import (
"fmt"
"io"
"net/http/httptest"
"strings"
Expand All @@ -30,6 +31,7 @@ import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/basicauth"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/prometheus/client_golang/prometheus/promhttp"
)

Expand Down Expand Up @@ -332,3 +334,47 @@ func TestMiddlewareWithCustomRegistry(t *testing.T) {
t.Errorf("got %s; want %s", got, want)
}
}

func TestCustomRegistryRegisterAt(t *testing.T) {
app := fiber.New()
registry := prometheus.NewRegistry()
registry.Register(collectors.NewGoCollector())
registry.Register(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
fpCustom := NewWithRegistry(registry, "custom-registry", "custom_name", "http", nil)
fpCustom.RegisterAt(app, "/metrics")

app.Use(fpCustom.Middleware)

app.Get("/", func(c *fiber.Ctx) error {
return c.SendString("Hello, world!")
})
req := httptest.NewRequest("GET", "/", nil)
res, err := app.Test(req, -1)
if err != nil {
t.Fatal(fmt.Errorf("GET / failed: %w", err))
}
defer res.Body.Close()
if res.StatusCode != 200 {
t.Fatal(fmt.Errorf("GET /: Status=%d", res.StatusCode))
}

req = httptest.NewRequest("GET", "/metrics", nil)
resMetr, err := app.Test(req, -1)
if err != nil {
t.Fatal(fmt.Errorf("GET /metrics failed: %W", err))
}
defer resMetr.Body.Close()
if res.StatusCode != 200 {
t.Fatal(fmt.Errorf("GET /metrics: Status=%d", resMetr.StatusCode))
}
body, err := io.ReadAll(resMetr.Body)
if err != nil {
t.Fatal(fmt.Errorf("GET /metrics: read body: %w", err))
}
got := string(body)

want := `custom_name_http_requests_total{method="GET",path="/",service="custom-registry",status_code="200"} 1`
if !strings.Contains(got, want) {
t.Errorf("got %s; want %s", got, want)
}
}

0 comments on commit 7015afd

Please sign in to comment.