diff --git a/handlers/routeservice_test.go b/handlers/routeservice_test.go index 7be3af61..7caa4cc5 100644 --- a/handlers/routeservice_test.go +++ b/handlers/routeservice_test.go @@ -129,7 +129,7 @@ var _ = Describe("Route Service Handler", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{}) added := routePool.Put(endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) }) It("should not add route service metadata to the request for normal routes", func() { handler.ServeHTTP(resp, req) @@ -152,7 +152,7 @@ var _ = Describe("Route Service Handler", func() { BeforeEach(func() { endpoint := route.NewEndpoint(&route.EndpointOpts{RouteServiceUrl: "route-service.com"}) added := routePool.Put(endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) }) It("returns 502 Bad Gateway", func() { @@ -172,7 +172,7 @@ var _ = Describe("Route Service Handler", func() { BeforeEach(func() { endpoint := route.NewEndpoint(&route.EndpointOpts{}) added := routePool.Put(endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) }) It("should not add route service metadata to the request for normal routes", func() { handler.ServeHTTP(resp, req) @@ -211,7 +211,7 @@ var _ = Describe("Route Service Handler", func() { BeforeEach(func() { endpoint := route.NewEndpoint(&route.EndpointOpts{RouteServiceUrl: "https://route-service.com"}) added := routePool.Put(endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) }) It("sends the request to the route service with X-CF-Forwarded-Url using https scheme", func() { @@ -699,7 +699,7 @@ var _ = Describe("Route Service Handler", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{RouteServiceUrl: "https://goodrouteservice.com"}) added := routePool.Put(endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) req.Header.Set("connection", "upgrade") req.Header.Set("upgrade", "websocket") @@ -718,7 +718,7 @@ var _ = Describe("Route Service Handler", func() { BeforeEach(func() { endpoint := route.NewEndpoint(&route.EndpointOpts{RouteServiceUrl: "https://bad%20service.com"}) added := routePool.Put(endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) }) It("returns a 500 internal server error response", func() { diff --git a/metrics/compositereporter.go b/metrics/compositereporter.go index 9631893d..56255f3b 100644 --- a/metrics/compositereporter.go +++ b/metrics/compositereporter.go @@ -44,10 +44,10 @@ type RouteRegistryReporter interface { CaptureRouteStats(totalRoutes int, msSinceLastUpdate int64) CaptureRoutesPruned(prunedRoutes uint64) CaptureLookupTime(t time.Duration) - CaptureRegistryMessage(msg ComponentTagged) + CaptureRegistryMessage(msg ComponentTagged, action string) CaptureRouteRegistrationLatency(t time.Duration) UnmuzzleRouteRegistrationLatency() - CaptureUnregistryMessage(msg ComponentTagged) + CaptureUnregistryMessage(msg ComponentTagged, action string) } type CompositeReporter struct { diff --git a/metrics/fakes/fake_registry_reporter.go b/metrics/fakes/fake_registry_reporter.go index 7a8dbd72..3dafcb05 100644 --- a/metrics/fakes/fake_registry_reporter.go +++ b/metrics/fakes/fake_registry_reporter.go @@ -14,10 +14,11 @@ type FakeRouteRegistryReporter struct { captureLookupTimeArgsForCall []struct { arg1 time.Duration } - CaptureRegistryMessageStub func(metrics.ComponentTagged) + CaptureRegistryMessageStub func(metrics.ComponentTagged, string) captureRegistryMessageMutex sync.RWMutex captureRegistryMessageArgsForCall []struct { arg1 metrics.ComponentTagged + arg2 string } CaptureRouteRegistrationLatencyStub func(time.Duration) captureRouteRegistrationLatencyMutex sync.RWMutex @@ -35,10 +36,11 @@ type FakeRouteRegistryReporter struct { captureRoutesPrunedArgsForCall []struct { arg1 uint64 } - CaptureUnregistryMessageStub func(metrics.ComponentTagged) + CaptureUnregistryMessageStub func(metrics.ComponentTagged, string) captureUnregistryMessageMutex sync.RWMutex captureUnregistryMessageArgsForCall []struct { arg1 metrics.ComponentTagged + arg2 string } UnmuzzleRouteRegistrationLatencyStub func() unmuzzleRouteRegistrationLatencyMutex sync.RWMutex @@ -80,16 +82,17 @@ func (fake *FakeRouteRegistryReporter) CaptureLookupTimeArgsForCall(i int) time. return argsForCall.arg1 } -func (fake *FakeRouteRegistryReporter) CaptureRegistryMessage(arg1 metrics.ComponentTagged) { +func (fake *FakeRouteRegistryReporter) CaptureRegistryMessage(arg1 metrics.ComponentTagged, arg2 string) { fake.captureRegistryMessageMutex.Lock() fake.captureRegistryMessageArgsForCall = append(fake.captureRegistryMessageArgsForCall, struct { arg1 metrics.ComponentTagged - }{arg1}) + arg2 string + }{arg1, arg2}) stub := fake.CaptureRegistryMessageStub - fake.recordInvocation("CaptureRegistryMessage", []interface{}{arg1}) + fake.recordInvocation("CaptureRegistryMessage", []interface{}{arg1, arg2}) fake.captureRegistryMessageMutex.Unlock() if stub != nil { - fake.CaptureRegistryMessageStub(arg1) + fake.CaptureRegistryMessageStub(arg1, arg2) } } @@ -99,17 +102,17 @@ func (fake *FakeRouteRegistryReporter) CaptureRegistryMessageCallCount() int { return len(fake.captureRegistryMessageArgsForCall) } -func (fake *FakeRouteRegistryReporter) CaptureRegistryMessageCalls(stub func(metrics.ComponentTagged)) { +func (fake *FakeRouteRegistryReporter) CaptureRegistryMessageCalls(stub func(metrics.ComponentTagged, string)) { fake.captureRegistryMessageMutex.Lock() defer fake.captureRegistryMessageMutex.Unlock() fake.CaptureRegistryMessageStub = stub } -func (fake *FakeRouteRegistryReporter) CaptureRegistryMessageArgsForCall(i int) metrics.ComponentTagged { +func (fake *FakeRouteRegistryReporter) CaptureRegistryMessageArgsForCall(i int) (metrics.ComponentTagged, string) { fake.captureRegistryMessageMutex.RLock() defer fake.captureRegistryMessageMutex.RUnlock() argsForCall := fake.captureRegistryMessageArgsForCall[i] - return argsForCall.arg1 + return argsForCall.arg1, argsForCall.arg2 } func (fake *FakeRouteRegistryReporter) CaptureRouteRegistrationLatency(arg1 time.Duration) { @@ -209,16 +212,17 @@ func (fake *FakeRouteRegistryReporter) CaptureRoutesPrunedArgsForCall(i int) uin return argsForCall.arg1 } -func (fake *FakeRouteRegistryReporter) CaptureUnregistryMessage(arg1 metrics.ComponentTagged) { +func (fake *FakeRouteRegistryReporter) CaptureUnregistryMessage(arg1 metrics.ComponentTagged, arg2 string) { fake.captureUnregistryMessageMutex.Lock() fake.captureUnregistryMessageArgsForCall = append(fake.captureUnregistryMessageArgsForCall, struct { arg1 metrics.ComponentTagged - }{arg1}) + arg2 string + }{arg1, arg2}) stub := fake.CaptureUnregistryMessageStub - fake.recordInvocation("CaptureUnregistryMessage", []interface{}{arg1}) + fake.recordInvocation("CaptureUnregistryMessage", []interface{}{arg1, arg2}) fake.captureUnregistryMessageMutex.Unlock() if stub != nil { - fake.CaptureUnregistryMessageStub(arg1) + fake.CaptureUnregistryMessageStub(arg1, arg2) } } @@ -228,17 +232,17 @@ func (fake *FakeRouteRegistryReporter) CaptureUnregistryMessageCallCount() int { return len(fake.captureUnregistryMessageArgsForCall) } -func (fake *FakeRouteRegistryReporter) CaptureUnregistryMessageCalls(stub func(metrics.ComponentTagged)) { +func (fake *FakeRouteRegistryReporter) CaptureUnregistryMessageCalls(stub func(metrics.ComponentTagged, string)) { fake.captureUnregistryMessageMutex.Lock() defer fake.captureUnregistryMessageMutex.Unlock() fake.CaptureUnregistryMessageStub = stub } -func (fake *FakeRouteRegistryReporter) CaptureUnregistryMessageArgsForCall(i int) metrics.ComponentTagged { +func (fake *FakeRouteRegistryReporter) CaptureUnregistryMessageArgsForCall(i int) (metrics.ComponentTagged, string) { fake.captureUnregistryMessageMutex.RLock() defer fake.captureUnregistryMessageMutex.RUnlock() argsForCall := fake.captureUnregistryMessageArgsForCall[i] - return argsForCall.arg1 + return argsForCall.arg1, argsForCall.arg2 } func (fake *FakeRouteRegistryReporter) UnmuzzleRouteRegistrationLatency() { diff --git a/metrics/metricsreporter.go b/metrics/metricsreporter.go index 79fd6356..6ead6963 100644 --- a/metrics/metricsreporter.go +++ b/metrics/metricsreporter.go @@ -135,27 +135,22 @@ func (m *MetricsReporter) CaptureRoutesPruned(routesPruned uint64) { m.Batcher.BatchAddCounter("routes_pruned", routesPruned) } -func (m *MetricsReporter) CaptureRegistryMessage(msg ComponentTagged) { +func (m *MetricsReporter) CaptureRegistryMessage(msg ComponentTagged, action string) { var componentName string if msg.Component() == "" { - componentName = "registry_message" + componentName = "registry_message." + action } else { - componentName = "registry_message." + msg.Component() + componentName = "registry_message." + action + "." + msg.Component() } m.Batcher.BatchIncrementCounter(componentName) } -func (m *MetricsReporter) CaptureUnregistryMessage(msg ComponentTagged) { - var componentName string - if msg.Component() == "" { - componentName = "unregistry_message" - } else { - componentName = "unregistry_message." + msg.Component() - } - err := m.Sender.IncrementCounter(componentName) - if err != nil { - m.Logger.Debug("failed-sending-metric", log.ErrAttr(err), slog.String("metric", componentName)) +func (m *MetricsReporter) CaptureUnregistryMessage(msg ComponentTagged, action string) { + unregisterMsg := "unregistry_message." + action + if msg.Component() != "" { + unregisterMsg = unregisterMsg + "." + msg.Component() } + m.Batcher.BatchIncrementCounter(unregisterMsg) } func (m *MetricsReporter) CaptureWebSocketUpdate() { diff --git a/metrics/metricsreporter_test.go b/metrics/metricsreporter_test.go index 5508f172..aa584ce0 100644 --- a/metrics/metricsreporter_test.go +++ b/metrics/metricsreporter_test.go @@ -448,22 +448,22 @@ var _ = Describe("MetricsReporter", func() { It("sends number of nats messages received from each component", func() { endpoint.Tags = map[string]string{} - metricReporter.CaptureRegistryMessage(endpoint) + metricReporter.CaptureRegistryMessage(endpoint, "some-action") Expect(batcher.BatchIncrementCounterCallCount()).To(Equal(1)) - Expect(batcher.BatchIncrementCounterArgsForCall(0)).To(Equal("registry_message")) + Expect(batcher.BatchIncrementCounterArgsForCall(0)).To(Equal("registry_message.some-action")) }) It("sends number of nats messages received from each component", func() { endpoint.Tags = map[string]string{"component": "uaa"} - metricReporter.CaptureRegistryMessage(endpoint) + metricReporter.CaptureRegistryMessage(endpoint, "some-action") endpoint.Tags = map[string]string{"component": "route-emitter"} - metricReporter.CaptureRegistryMessage(endpoint) + metricReporter.CaptureRegistryMessage(endpoint, "some-action") Expect(batcher.BatchIncrementCounterCallCount()).To(Equal(2)) - Expect(batcher.BatchIncrementCounterArgsForCall(0)).To(Equal("registry_message.uaa")) - Expect(batcher.BatchIncrementCounterArgsForCall(1)).To(Equal("registry_message.route-emitter")) + Expect(batcher.BatchIncrementCounterArgsForCall(0)).To(Equal("registry_message.some-action.uaa")) + Expect(batcher.BatchIncrementCounterArgsForCall(1)).To(Equal("registry_message.some-action.route-emitter")) }) It("sends the total routes", func() { @@ -517,33 +517,33 @@ var _ = Describe("MetricsReporter", func() { BeforeEach(func() { endpoint = new(route.Endpoint) endpoint.Tags = map[string]string{"component": "oauth-server"} - metricReporter.CaptureUnregistryMessage(endpoint) + metricReporter.CaptureUnregistryMessage(endpoint, "some-action") }) It("increments the counter metric", func() { - Expect(sender.IncrementCounterCallCount()).To(Equal(1)) - Expect(sender.IncrementCounterArgsForCall(0)).To(Equal("unregistry_message.oauth-server")) + Expect(batcher.BatchIncrementCounterCallCount()).To(Equal(1)) + Expect(batcher.BatchIncrementCounterArgsForCall(0)).To(Equal("unregistry_message.some-action.oauth-server")) }) It("increments the counter metric for each component unregistered", func() { endpointTwo := new(route.Endpoint) endpointTwo.Tags = map[string]string{"component": "api-server"} - metricReporter.CaptureUnregistryMessage(endpointTwo) + metricReporter.CaptureUnregistryMessage(endpointTwo, "some-action") - Expect(sender.IncrementCounterCallCount()).To(Equal(2)) - Expect(sender.IncrementCounterArgsForCall(0)).To(Equal("unregistry_message.oauth-server")) - Expect(sender.IncrementCounterArgsForCall(1)).To(Equal("unregistry_message.api-server")) + Expect(batcher.BatchIncrementCounterCallCount()).To(Equal(2)) + Expect(batcher.BatchIncrementCounterArgsForCall(0)).To(Equal("unregistry_message.some-action.oauth-server")) + Expect(batcher.BatchIncrementCounterArgsForCall(1)).To(Equal("unregistry_message.some-action.api-server")) }) }) Context("when unregister msg with empty component name is incremented", func() { BeforeEach(func() { endpoint = new(route.Endpoint) endpoint.Tags = map[string]string{} - metricReporter.CaptureUnregistryMessage(endpoint) + metricReporter.CaptureUnregistryMessage(endpoint, "some-action") }) It("increments the counter metric", func() { - Expect(sender.IncrementCounterCallCount()).To(Equal(1)) - Expect(sender.IncrementCounterArgsForCall(0)).To(Equal("unregistry_message")) + Expect(batcher.BatchIncrementCounterCallCount()).To(Equal(1)) + Expect(batcher.BatchIncrementCounterArgsForCall(0)).To(Equal("unregistry_message.some-action")) }) }) }) diff --git a/proxy/round_tripper/proxy_round_tripper_test.go b/proxy/round_tripper/proxy_round_tripper_test.go index 416a7f54..7da6fd71 100644 --- a/proxy/round_tripper/proxy_round_tripper_test.go +++ b/proxy/round_tripper/proxy_round_tripper_test.go @@ -151,7 +151,7 @@ var _ = Describe("ProxyRoundTripper", func() { }) added := routePool.Put(endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) } proxyRoundTripper = round_tripper.NewProxyRoundTripper( @@ -483,7 +483,7 @@ var _ = Describe("ProxyRoundTripper", func() { }) added := routePool.Put(endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) _, err := proxyRoundTripper.RoundTrip(req) Expect(err).To(MatchError(ContainSubstring("tls: handshake failure"))) @@ -655,7 +655,7 @@ var _ = Describe("ProxyRoundTripper", func() { Context("when there are no more endpoints available", func() { JustBeforeEach(func() { removed := routePool.Remove(endpoint) - Expect(removed).To(BeTrue()) + Expect(removed).To(Equal(route.EndpointUnregistered)) }) It("returns a 502 Bad Gateway response", func() { @@ -738,14 +738,14 @@ var _ = Describe("ProxyRoundTripper", func() { Port: 20222, UseTLS: true, }) - Expect(routePool.Put(tlsEndpoint)).To(Equal(route.ADDED)) + Expect(routePool.Put(tlsEndpoint)).To(Equal(route.EndpointAdded)) nonTLSEndpoint := route.NewEndpoint(&route.EndpointOpts{ Host: "3.3.3.3", Port: 30333, UseTLS: false, }) - Expect(routePool.Put(nonTLSEndpoint)).To(Equal(route.ADDED)) + Expect(routePool.Put(nonTLSEndpoint)).To(Equal(route.EndpointAdded)) }) Context("when retrying different backends", func() { @@ -799,7 +799,7 @@ var _ = Describe("ProxyRoundTripper", func() { }) added := routePool.Put(endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) transport.RoundTripReturns( &http.Response{StatusCode: http.StatusTeapot}, nil, ) @@ -822,7 +822,7 @@ var _ = Describe("ProxyRoundTripper", func() { }) added := routePool.Put(endpoint) - Expect(added).To(Equal(route.UPDATED)) + Expect(added).To(Equal(route.EndpointUpdated)) transport.RoundTripReturns( &http.Response{StatusCode: http.StatusTeapot}, nil, ) @@ -859,7 +859,7 @@ var _ = Describe("ProxyRoundTripper", func() { Host: "1.1.1.1", Port: 9091, UseTLS: true, PrivateInstanceId: "instanceId-2", }) added := routePool.Put(endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) _, err := proxyRoundTripper.RoundTrip(req) Expect(err).ToNot(HaveOccurred()) @@ -1143,11 +1143,11 @@ var _ = Describe("ProxyRoundTripper", func() { }) added := routePool.Put(endpoint1) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) added = routePool.Put(endpoint2) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) removed := routePool.Remove(endpoint) - Expect(removed).To(BeTrue()) + Expect(removed).To(Equal(route.EndpointUnregistered)) }) Context("when there are no cookies on the request", func() { @@ -1439,14 +1439,14 @@ var _ = Describe("ProxyRoundTripper", func() { JustBeforeEach(func() { removed := routePool.Remove(endpoint1) - Expect(removed).To(BeTrue()) + Expect(removed).To(Equal(route.EndpointUnregistered)) removed = routePool.Remove(endpoint2) - Expect(removed).To(BeTrue()) + Expect(removed).To(Equal(route.EndpointUnregistered)) new_endpoint := route.NewEndpoint(&route.EndpointOpts{PrivateInstanceId: "id-5"}) added := routePool.Put(new_endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) }) Context("when route service headers are not on the request", func() { @@ -1502,14 +1502,14 @@ var _ = Describe("ProxyRoundTripper", func() { JustBeforeEach(func() { removed := routePool.Remove(endpoint1) - Expect(removed).To(BeTrue()) + Expect(removed).To(Equal(route.EndpointUnregistered)) removed = routePool.Remove(endpoint2) - Expect(removed).To(BeTrue()) + Expect(removed).To(Equal(route.EndpointUnregistered)) new_endpoint := route.NewEndpoint(&route.EndpointOpts{PrivateInstanceId: "id-5"}) added := routePool.Put(new_endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) }) Context("when route service headers are not on the request", func() { @@ -1568,14 +1568,14 @@ var _ = Describe("ProxyRoundTripper", func() { JustBeforeEach(func() { removed := routePool.Remove(endpoint1) - Expect(removed).To(BeTrue()) + Expect(removed).To(Equal(route.EndpointUnregistered)) removed = routePool.Remove(endpoint2) - Expect(removed).To(BeTrue()) + Expect(removed).To(Equal(route.EndpointUnregistered)) new_endpoint := route.NewEndpoint(&route.EndpointOpts{PrivateInstanceId: "id-5"}) added := routePool.Put(new_endpoint) - Expect(added).To(Equal(route.ADDED)) + Expect(added).To(Equal(route.EndpointAdded)) }) Context("when route service headers are not on the request", func() { diff --git a/registry/registry.go b/registry/registry.go index d70ff316..3701d794 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -87,35 +87,42 @@ func (r *RouteRegistry) Register(uri route.Uri, endpoint *route.Endpoint) { return } - endpointAdded := r.register(uri, endpoint) + r.RLock() + defer r.RUnlock() + + t := time.Now() + registerRouteResult, pool := r.registerRoute(uri) + if registerRouteResult == route.RouteRegistered { + r.reporter.CaptureRegistryMessage(endpoint, string(route.RouteRegistered)) + if r.logger.Enabled(context.Background(), slog.LevelInfo) { + r.logger.Info(string(route.RouteRegistered), buildSlogAttrs(uri, endpoint)...) + } + } - r.reporter.CaptureRegistryMessage(endpoint) + endpointPutResult := r.registerEndpoint(endpoint, pool) - if endpointAdded == route.ADDED && !endpoint.UpdatedAt.IsZero() { + if endpointPutResult == route.EndpointAdded && !endpoint.UpdatedAt.IsZero() { r.reporter.CaptureRouteRegistrationLatency(time.Since(endpoint.UpdatedAt)) } - switch endpointAdded { - case route.ADDED: + r.reporter.CaptureRegistryMessage(endpoint, string(endpointPutResult)) + + switch endpointPutResult { + case route.EndpointAdded: if r.logger.Enabled(context.Background(), slog.LevelInfo) { - r.logger.Info("endpoint-registered", buildSlogAttrs(uri, endpoint)...) - } - case route.UPDATED: - if r.logger.Enabled(context.Background(), slog.LevelDebug) { - r.logger.Debug("endpoint-registered", buildSlogAttrs(uri, endpoint)...) + r.logger.Info(string(endpointPutResult), buildSlogAttrs(uri, endpoint)...) } default: if r.logger.Enabled(context.Background(), slog.LevelDebug) { - r.logger.Debug("endpoint-not-registered", buildSlogAttrs(uri, endpoint)...) + r.logger.Debug(string(endpointPutResult), buildSlogAttrs(uri, endpoint)...) } } -} -func (r *RouteRegistry) register(uri route.Uri, endpoint *route.Endpoint) route.PoolPutResult { - r.RLock() - defer r.RUnlock() + r.SetTimeOfLastUpdate(t) +} - t := time.Now() +func (r *RouteRegistry) registerRoute(uri route.Uri) (route.PoolRegisterRouteResult, *route.EndpointPool) { + poolRegisterRouteResult := route.RouteAlreadyExists routekey := uri.RouteKey() pool := r.byURI.Find(routekey) @@ -124,18 +131,21 @@ func (r *RouteRegistry) register(uri route.Uri, endpoint *route.Endpoint) route. r.RUnlock() pool = r.insertRouteKey(routekey, uri) r.RLock() + poolRegisterRouteResult = route.RouteRegistered } + return poolRegisterRouteResult, pool +} +func (r *RouteRegistry) registerEndpoint(endpoint *route.Endpoint, pool *route.EndpointPool) route.PoolRegisterEndpointResult { if endpoint.StaleThreshold > r.dropletStaleThreshold || endpoint.StaleThreshold == 0 { endpoint.StaleThreshold = r.dropletStaleThreshold } - endpointAdded := pool.Put(endpoint) + endpointAddedResult := pool.Put(endpoint) // Overwrites the load balancing algorithm of a pool by that of a specified endpoint, if that is valid. pool.SetPoolLoadBalancingAlgorithm(endpoint) - r.SetTimeOfLastUpdate(t) - return endpointAdded + return endpointAddedResult } // insertRouteKey acquires a write lock, inserts the route key into the registry and releases the write lock. @@ -156,7 +166,7 @@ func (r *RouteRegistry) insertRouteKey(routekey route.Uri, uri route.Uri) *route LoadBalancingAlgorithm: r.DefaultLoadBalancingAlgorithm, }) r.byURI.Insert(routekey, pool) - r.logger.Info("route-registered", slog.Any("uri", routekey)) + r.logger.Info(string(route.RouteRegistered), slog.Any("uri", routekey)) // for backward compatibility: r.logger.Debug("uri-added", slog.Any("uri", routekey)) } @@ -168,42 +178,47 @@ func (r *RouteRegistry) Unregister(uri route.Uri, endpoint *route.Endpoint) { return } - r.unregister(uri, endpoint) - - r.reporter.CaptureUnregistryMessage(endpoint) -} - -func (r *RouteRegistry) unregister(uri route.Uri, endpoint *route.Endpoint) { r.Lock() defer r.Unlock() - uri = uri.RouteKey() + routeKey := uri.RouteKey() + endpointUnregisteredResult, pool := r.unregisterEndpoint(routeKey, endpoint) + if pool == nil { + return + } - pool := r.byURI.Find(uri) - if pool != nil { - endpointRemoved := pool.Remove(endpoint) - if endpointRemoved { - if r.logger.Enabled(context.Background(), slog.LevelInfo) { - r.logger.Info("endpoint-unregistered", buildSlogAttrs(uri, endpoint)...) - } - } else { - if r.logger.Enabled(context.Background(), slog.LevelInfo) { - r.logger.Info("endpoint-not-unregistered", buildSlogAttrs(uri, endpoint)...) - } + r.reporter.CaptureUnregistryMessage(endpoint, string(endpointUnregisteredResult)) + if r.logger.Enabled(context.Background(), slog.LevelInfo) { + r.logger.Info(string(endpointUnregisteredResult), buildSlogAttrs(routeKey, endpoint)...) + } + + routeUnregisteredResult := r.deleteRouteWithoutEndpoint(routeKey, pool) + switch routeUnregisteredResult { + case route.RouteUnregistered: + r.reporter.CaptureUnregistryMessage(endpoint, string(routeUnregisteredResult)) + if r.logger.Enabled(context.Background(), slog.LevelInfo) { + r.logger.Info(string(routeUnregisteredResult), slog.Any("uri", routeKey)) } + } +} - if pool.IsEmpty() { - if r.EmptyPoolResponseCode503 && r.EmptyPoolTimeout > 0 { - if time.Since(pool.LastUpdated()) > r.EmptyPoolTimeout { - r.byURI.Delete(uri) - r.logger.Info("route-unregistered", slog.Any("uri", uri)) - } - } else { - r.byURI.Delete(uri) - r.logger.Info("route-unregistered", slog.Any("uri", uri)) - } +func (r *RouteRegistry) unregisterEndpoint(routeKey route.Uri, endpoint *route.Endpoint) (route.PoolRemoveEndpointResult, *route.EndpointPool) { + pool := r.byURI.Find(routeKey) + if pool == nil { + return route.EndpointNotUnregistered, nil + } + return pool.Remove(endpoint), pool +} + +func (r *RouteRegistry) deleteRouteWithoutEndpoint(routeKey route.Uri, pool *route.EndpointPool) route.PoolRemoveRouteResult { + if pool.IsEmpty() { + if !(r.EmptyPoolResponseCode503 && r.EmptyPoolTimeout > 0) || + (r.EmptyPoolResponseCode503 && r.EmptyPoolTimeout > 0 && time.Since(pool.LastUpdated()) > r.EmptyPoolTimeout) { + r.byURI.Delete(routeKey) + return route.RouteUnregistered } } + return route.RouteNotUnregistered } func (r *RouteRegistry) Lookup(uri route.Uri) *route.EndpointPool { @@ -301,11 +316,11 @@ func (r *RouteRegistry) StopPruningCycle() { } } -func (registry *RouteRegistry) NumUris() int { - registry.RLock() - defer registry.RUnlock() +func (r *RouteRegistry) NumUris() int { + r.RLock() + defer r.RUnlock() - return registry.byURI.PoolCount() + return r.byURI.PoolCount() } func (r *RouteRegistry) MSSinceLastUpdate() int64 { diff --git a/registry/registry_test.go b/registry/registry_test.go index 9454d069..f6939943 100644 --- a/registry/registry_test.go +++ b/registry/registry_test.go @@ -70,9 +70,59 @@ var _ = Describe("RouteRegistry", func() { }) Context("Register", func() { - It("emits message_count metrics", func() { - r.Register("foo", fooEndpoint) - Expect(reporter.CaptureRegistryMessageCallCount()).To(Equal(1)) + Context("when a new endpoint is registered", func() { + It("emits endpoint-registered message_count metrics", func() { + r.Register("foo", fooEndpoint) + Expect(reporter.CaptureRegistryMessageCallCount()).To(Equal(2)) + endpoint1, action1 := reporter.CaptureRegistryMessageArgsForCall(0) + Expect(endpoint1).To(Equal(fooEndpoint)) + Expect(action1).To(Equal(string(route.RouteRegistered))) + endpoint2, action2 := reporter.CaptureRegistryMessageArgsForCall(1) + Expect(endpoint2).To(Equal(fooEndpoint)) + Expect(action2).To(Equal(string(route.EndpointAdded))) + }) + }) + + Context("when an endpoint is updated", func() { + It("emits endpoint-updated message_count metrics", func() { + modTag1 := models.ModificationTag{Guid: "abc", Index: 0} + endpoint1 := route.NewEndpoint(&route.EndpointOpts{ModificationTag: modTag1}) + modTag2 := models.ModificationTag{Guid: "abc", Index: 1} + endpoint2 := route.NewEndpoint(&route.EndpointOpts{ModificationTag: modTag2}) + r.Register("foo", endpoint1) + r.Register("foo", endpoint2) + Expect(reporter.CaptureRegistryMessageCallCount()).To(Equal(3)) + endpointR1, action1 := reporter.CaptureRegistryMessageArgsForCall(0) + Expect(endpointR1).To(Equal(endpoint1)) + Expect(action1).To(Equal(string(route.RouteRegistered))) + endpointR2, action2 := reporter.CaptureRegistryMessageArgsForCall(1) + Expect(endpointR2).To(Equal(endpoint1)) + Expect(action2).To(Equal(string(route.EndpointAdded))) + endpointR3, action3 := reporter.CaptureRegistryMessageArgsForCall(2) + Expect(endpointR3).To(Equal(endpoint2)) + Expect(action3).To(Equal(string(route.EndpointUpdated))) + }) + }) + + Context("when modificationTag is older so that the endpoint is not updated", func() { + It("emits endpoint-bot-updated message_count metrics", func() { + modTag1 := models.ModificationTag{Guid: "abc", Index: 1} + endpoint1 := route.NewEndpoint(&route.EndpointOpts{ModificationTag: modTag1}) + modTag2 := models.ModificationTag{Guid: "abc", Index: 0} + endpoint2 := route.NewEndpoint(&route.EndpointOpts{ModificationTag: modTag2}) + r.Register("foo", endpoint1) + r.Register("foo", endpoint2) + Expect(reporter.CaptureRegistryMessageCallCount()).To(Equal(3)) + endpointR1, action1 := reporter.CaptureRegistryMessageArgsForCall(0) + Expect(endpointR1).To(Equal(endpoint1)) + Expect(action1).To(Equal(string(route.RouteRegistered))) + endpointR2, action2 := reporter.CaptureRegistryMessageArgsForCall(1) + Expect(endpointR2).To(Equal(endpoint1)) + Expect(action2).To(Equal(string(route.EndpointAdded))) + endpointR3, action3 := reporter.CaptureRegistryMessageArgsForCall(2) + Expect(endpointR3).To(Equal(endpoint2)) + Expect(action3).To(Equal(string(route.EndpointNotUpdated))) + }) }) Context("when the endpoint has an UpdatedAt timestamp", func() { @@ -237,7 +287,7 @@ var _ = Describe("RouteRegistry", func() { r.Register("a.route", fooEndpoint) Eventually(logger).Should(gbytes.Say(`"log_level":1.*route-registered.*a\.route`)) - Eventually(logger).Should(gbytes.Say(`"log_level":1.*endpoint-registered.*a\.route.*192\.168\.1\.1`)) + Eventually(logger).Should(gbytes.Say(`"log_level":1.*endpoint-added.*a\.route.*192\.168\.1\.1`)) }) It("logs 'uri-added' at debug level for backward compatibility", func() { @@ -253,7 +303,7 @@ var _ = Describe("RouteRegistry", func() { Expect(logger).NotTo(gbytes.Say(`uri-added.*.*a\.route`)) By("not providing IsolationSegment property") r.Register("a.route", fooEndpoint) - //TODO: use pattern matching to make sure we are asserting on the unregister line + //TODO: use pattern matching to make sure we are asserting on the unregisterEndpoint line Eventually(logger).Should(gbytes.Say(`"isolation_segment":"-"`)) }) @@ -263,7 +313,7 @@ var _ = Describe("RouteRegistry", func() { }) r.Register("a.route", isoSegEndpoint) - //TODO: use pattern matching to make sure we are asserting on the unregister line + //TODO: use pattern matching to make sure we are asserting on the unregisterEndpoint line Eventually(logger).Should(gbytes.Say(`"isolation_segment":"is1"`)) }) @@ -516,25 +566,73 @@ var _ = Describe("RouteRegistry", func() { }) Context("Unregister", func() { - Context("when endpoint has component tagged", func() { + Context("when route is registered", func() { + Context("when endpoint has component tagged", func() { + BeforeEach(func() { + fooEndpoint.Tags = map[string]string{"component": "oauth-server"} + r.Register("foo", fooEndpoint) + Expect(reporter.CaptureRegistryMessageCallCount()).To(Equal(2)) + }) + + It("emits counter metrics for unregister endpoint and route", func() { + r.Unregister("foo", fooEndpoint) + Expect(reporter.CaptureUnregistryMessageCallCount()).To(Equal(2)) + endpoint1, action1 := reporter.CaptureUnregistryMessageArgsForCall(0) + Expect(endpoint1).To(Equal(fooEndpoint)) + Expect(action1).To(Equal(string(route.EndpointUnregistered))) + endpoint2, action2 := reporter.CaptureUnregistryMessageArgsForCall(1) + Expect(endpoint2).To(Equal(fooEndpoint)) + Expect(action2).To(Equal(string(route.RouteUnregistered))) + }) + }) + + Context("when endpoint does not have component tag", func() { + BeforeEach(func() { + fooEndpoint.Tags = map[string]string{} + r.Register("foo", fooEndpoint) + Expect(reporter.CaptureRegistryMessageCallCount()).To(Equal(2)) + }) + It("emits counter metrics for unregister endpoint and route", func() { + r.Unregister("foo", fooEndpoint) + Expect(reporter.CaptureUnregistryMessageCallCount()).To(Equal(2)) + endpoint1, action1 := reporter.CaptureUnregistryMessageArgsForCall(0) + Expect(endpoint1).To(Equal(fooEndpoint)) + Expect(action1).To(Equal(string(route.EndpointUnregistered))) + endpoint2, action2 := reporter.CaptureUnregistryMessageArgsForCall(1) + Expect(endpoint2).To(Equal(fooEndpoint)) + Expect(action2).To(Equal(string(route.RouteUnregistered))) + }) + }) + }) + Context("when route has multiple endpoints", func() { BeforeEach(func() { - fooEndpoint.Tags = map[string]string{"component": "oauth-server"} + fooEndpoint.Tags = map[string]string{} + fooEndpoint2 := route.NewEndpoint(&route.EndpointOpts{ + Host: "192.168.1.2", + Tags: map[string]string{ + "runtime": "ruby18", + "framework": "sinatra", + }}) + + r.Register("foo", fooEndpoint) + r.Register("foo", fooEndpoint2) + Expect(reporter.CaptureRegistryMessageCallCount()).To(Equal(3)) }) - It("emits counter metrics", func() { + It("emits counter metrics for unregister endpoint only", func() { r.Unregister("foo", fooEndpoint) Expect(reporter.CaptureUnregistryMessageCallCount()).To(Equal(1)) - Expect(reporter.CaptureUnregistryMessageArgsForCall(0)).To(Equal(fooEndpoint)) + endpoint1, action1 := reporter.CaptureUnregistryMessageArgsForCall(0) + Expect(endpoint1).To(Equal(fooEndpoint)) + Expect(action1).To(Equal(string(route.EndpointUnregistered))) }) }) - - Context("when endpoint does not have component tag", func() { + Context("when route is not registered", func() { BeforeEach(func() { fooEndpoint.Tags = map[string]string{} }) - It("emits counter metrics", func() { + It("does not emit counter metrics for unregister", func() { r.Unregister("foo", fooEndpoint) - Expect(reporter.CaptureUnregistryMessageCallCount()).To(Equal(1)) - Expect(reporter.CaptureUnregistryMessageArgsForCall(0)).To(Equal(fooEndpoint)) + Expect(reporter.CaptureUnregistryMessageCallCount()).To(Equal(0)) }) }) @@ -701,7 +799,7 @@ var _ = Describe("RouteRegistry", func() { BeforeEach(func() { fooEndpoint.IsolationSegment = "" }) - It("does not log an unregister message", func() { + It("does not log an unregisterEndpoint message", func() { r.Unregister("a.route", fooEndpoint) Expect(r.NumUris()).To(Equal(3)) Expect(r.NumEndpoints()).To(Equal(3)) @@ -809,21 +907,21 @@ var _ = Describe("RouteRegistry", func() { It("only logs unregistration for existing routes", func() { r.Unregister("non-existent-route", fooEndpoint) - Expect(logger).NotTo(gbytes.Say(`unregister.*.*a\.non-existent-route`)) + Expect(logger).NotTo(gbytes.Say(`unregisterEndpoint.*.*a\.non-existent-route`)) By("not providing IsolationSegment property") r.Unregister("a.route", fooEndpoint) - //TODO: use pattern matching to make sure we are asserting on the unregister line + //TODO: use pattern matching to make sure we are asserting on the unregisterEndpoint line Eventually(logger).Should(gbytes.Say(`"isolation_segment":"-"`)) }) - It("logs unregister message with IsolationSegment when it's provided", func() { + It("logs unregisterEndpoint message with IsolationSegment when it's provided", func() { isoSegEndpoint := route.NewEndpoint(&route.EndpointOpts{ IsolationSegment: "is1", }) r.Register("a.isoSegRoute", isoSegEndpoint) r.Unregister("a.isoSegRoute", isoSegEndpoint) - //TODO: use pattern matching to make sure we are asserting on the unregister line + //TODO: use pattern matching to make sure we are asserting on the unregisterEndpoint line Eventually(logger).Should(gbytes.Say(`"isolation_segment":"is1"`)) }) }) @@ -849,7 +947,7 @@ var _ = Describe("RouteRegistry", func() { Expect(r.NumEndpoints()).To(Equal(0)) }) - It("does not unregister route if modification tag older", func() { + It("does not unregisterEndpoint route if modification tag older", func() { modTag2 := models.ModificationTag{ Guid: "abc", Index: 8, diff --git a/route/leastconnection_test.go b/route/leastconnection_test.go index 45cf814d..afe599ea 100644 --- a/route/leastconnection_test.go +++ b/route/leastconnection_test.go @@ -187,7 +187,7 @@ var _ = Describe("LeastConnection", func() { Context("when there is only one endpoint", func() { BeforeEach(func() { - Expect(pool.Remove(epOne)).To(BeTrue()) + Expect(pool.Remove(epOne)).To(Equal(route.EndpointUnregistered)) epTwo.Stats.NumberConnections.Increment() epTwo.Stats.NumberConnections.Increment() }) diff --git a/route/pool.go b/route/pool.go index bf903c19..8b50f66e 100644 --- a/route/pool.go +++ b/route/pool.go @@ -20,12 +20,33 @@ type Counter struct { value int64 } -type PoolPutResult int +type PoolRegisterEndpointResult string const ( - UNMODIFIED = PoolPutResult(iota) - UPDATED - ADDED + EndpointNotUpdated PoolRegisterEndpointResult = "endpoint-not-updated" + EndpointUpdated PoolRegisterEndpointResult = "endpoint-updated" + EndpointAdded PoolRegisterEndpointResult = "endpoint-added" +) + +type PoolRegisterRouteResult string + +const ( + RouteRegistered PoolRegisterRouteResult = "route-registered" + RouteAlreadyExists PoolRegisterRouteResult = "route-already-exists" +) + +type PoolRemoveEndpointResult string + +const ( + EndpointUnregistered PoolRemoveEndpointResult = "endpoint-unregistered" + EndpointNotUnregistered PoolRemoveEndpointResult = "endpoint-not-unregistered" +) + +type PoolRemoveRouteResult string + +const ( + RouteUnregistered PoolRemoveRouteResult = "route-unregistered" + RouteNotUnregistered PoolRemoveRouteResult = "route-not-unregistered" ) func NewCounter(initial int64) *Counter { @@ -254,20 +275,20 @@ func (p *EndpointPool) Update() { p.updatedAt = time.Now() } -func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { +func (p *EndpointPool) Put(endpoint *Endpoint) PoolRegisterEndpointResult { p.Lock() defer p.Unlock() - var result PoolPutResult + var result PoolRegisterEndpointResult e, found := p.index[endpoint.CanonicalAddr()] if found { - result = UPDATED + result = EndpointUpdated if !e.endpoint.Equal(endpoint) { e.Lock() defer e.Unlock() if !e.endpoint.ModificationTag.SucceededBy(&endpoint.ModificationTag) { - return UNMODIFIED + return EndpointNotUpdated } oldEndpoint := e.endpoint @@ -283,7 +304,7 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { } } } else { - result = ADDED + result = EndpointAdded e = &endpointElem{ endpoint: endpoint, index: len(p.endpoints), @@ -341,8 +362,8 @@ func (p *EndpointPool) PruneEndpoints() []*Endpoint { return prunedEndpoints } -// Returns true if the endpoint was removed from the EndpointPool, false otherwise. -func (p *EndpointPool) Remove(endpoint *Endpoint) bool { +// Remove Returns true if the endpoint was removed from the EndpointPool, false otherwise. +func (p *EndpointPool) Remove(endpoint *Endpoint) PoolRemoveEndpointResult { var e *endpointElem p.Lock() @@ -352,11 +373,11 @@ func (p *EndpointPool) Remove(endpoint *Endpoint) bool { e = p.index[endpoint.CanonicalAddr()] if e != nil && e.endpoint.modificationTagSameOrNewer(endpoint) { p.removeEndpoint(e) - return true + return EndpointUnregistered } } - return false + return EndpointNotUnregistered } func (p *EndpointPool) removeEndpoint(e *endpointElem) { diff --git a/route/pool_test.go b/route/pool_test.go index ee4bd31e..dfc70f4e 100644 --- a/route/pool_test.go +++ b/route/pool_test.go @@ -142,7 +142,7 @@ var _ = Describe("EndpointPool", func() { endpoint := &route.Endpoint{} b := pool.Put(endpoint) - Expect(b).To(Equal(route.ADDED)) + Expect(b).To(Equal(route.EndpointAdded)) }) It("handles duplicate endpoints", func() { @@ -152,7 +152,7 @@ var _ = Describe("EndpointPool", func() { pool.MarkUpdated(time.Now().Add(-(10 * time.Minute))) b := pool.Put(endpoint) - Expect(b).To(Equal(route.UPDATED)) + Expect(b).To(Equal(route.EndpointUpdated)) prunedEndpoints := pool.PruneEndpoints() Expect(prunedEndpoints).To(BeEmpty()) @@ -163,7 +163,7 @@ var _ = Describe("EndpointPool", func() { endpoint2 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678}) pool.Put(endpoint1) - Expect(pool.Put(endpoint2)).To(Equal(route.UPDATED)) + Expect(pool.Put(endpoint2)).To(Equal(route.EndpointUpdated)) }) Context("with modification tags", func() { @@ -175,13 +175,13 @@ var _ = Describe("EndpointPool", func() { modTag2 = models.ModificationTag{Guid: "abc"} endpoint1 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: modTag}) - Expect(pool.Put(endpoint1)).To(Equal(route.ADDED)) + Expect(pool.Put(endpoint1)).To(Equal(route.EndpointAdded)) }) It("updates an endpoint with modification tag", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: modTag2}) - Expect(pool.Put(endpoint)).To(Equal(route.UPDATED)) + Expect(pool.Put(endpoint)).To(Equal(route.EndpointUpdated)) Expect(pool.Endpoints(logger.Logger, "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag2)) }) @@ -196,7 +196,7 @@ var _ = Describe("EndpointPool", func() { olderModTag := models.ModificationTag{Guid: "abc"} endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: olderModTag}) - Expect(pool.Put(endpoint)).To(Equal(route.UNMODIFIED)) + Expect(pool.Put(endpoint)).To(Equal(route.EndpointNotUpdated)) Expect(pool.Endpoints(logger.Logger, "", false, azPreference, az).Next(0).ModificationTag).To(Equal(modTag2)) }) }) @@ -340,13 +340,13 @@ var _ = Describe("EndpointPool", func() { endpoint := &route.Endpoint{} endpointRS := &route.Endpoint{RouteServiceUrl: "my-url"} b := pool.Put(endpoint) - Expect(b).To(Equal(route.ADDED)) + Expect(b).To(Equal(route.EndpointAdded)) url := pool.RouteServiceUrl() Expect(url).To(BeEmpty()) b = pool.Put(endpointRS) - Expect(b).To(Equal(route.UPDATED)) + Expect(b).To(Equal(route.EndpointUpdated)) url = pool.RouteServiceUrl() Expect(url).To(Equal("my-url")) }) @@ -362,25 +362,25 @@ var _ = Describe("EndpointPool", func() { endpointRS1 := route.NewEndpoint(&route.EndpointOpts{Host: "host-1", Port: 1234, RouteServiceUrl: "first-url"}) endpointRS2 := route.NewEndpoint(&route.EndpointOpts{Host: "host-2", Port: 2234, RouteServiceUrl: "second-url"}) b := pool.Put(endpointRS1) - Expect(b).To(Equal(route.ADDED)) + Expect(b).To(Equal(route.EndpointAdded)) url := pool.RouteServiceUrl() Expect(url).To(Equal("first-url")) b = pool.Put(endpointRS2) - Expect(b).To(Equal(route.ADDED)) + Expect(b).To(Equal(route.EndpointAdded)) url = pool.RouteServiceUrl() Expect(url).To(Equal("second-url")) endpointRS1.RouteServiceUrl = "third-url" b = pool.Put(endpointRS1) - Expect(b).To(Equal(route.UPDATED)) + Expect(b).To(Equal(route.EndpointUpdated)) url = pool.RouteServiceUrl() Expect(url).To(Equal("third-url")) endpointRS2.RouteServiceUrl = "fourth-url" b = pool.Put(endpointRS2) - Expect(b).To(Equal(route.UPDATED)) + Expect(b).To(Equal(route.EndpointUpdated)) url = pool.RouteServiceUrl() Expect(url).To(Equal("fourth-url")) }) @@ -469,7 +469,7 @@ var _ = Describe("EndpointPool", func() { pool.Put(endpoint) b := pool.Remove(endpoint) - Expect(b).To(BeTrue()) + Expect(b).To(Equal(route.EndpointUnregistered)) Expect(pool.IsEmpty()).To(BeTrue()) }) @@ -478,7 +478,7 @@ var _ = Describe("EndpointPool", func() { b := pool.Remove(endpoint) - Expect(b).To(BeFalse()) + Expect(b).To(Equal(route.EndpointNotUnregistered)) }) Context("with modification tags", func() { @@ -487,12 +487,12 @@ var _ = Describe("EndpointPool", func() { modTag = models.ModificationTag{Guid: "abc"} endpoint1 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: modTag}) - Expect(pool.Put(endpoint1)).To(Equal(route.ADDED)) + Expect(pool.Put(endpoint1)).To(Equal(route.EndpointAdded)) }) It("removes an endpoint with modification tag", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: modTag}) - Expect(pool.Remove(endpoint)).To(BeTrue()) + Expect(pool.Remove(endpoint)).To(Equal(route.EndpointUnregistered)) Expect(pool.IsEmpty()).To(BeTrue()) }) @@ -505,7 +505,7 @@ var _ = Describe("EndpointPool", func() { It("removes an endpoint", func() { endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: modTag}) - Expect(pool.Remove(endpoint)).To(BeTrue()) + Expect(pool.Remove(endpoint)).To(Equal(route.EndpointUnregistered)) Expect(pool.IsEmpty()).To(BeTrue()) }) }) @@ -521,7 +521,7 @@ var _ = Describe("EndpointPool", func() { olderModTag := models.ModificationTag{Guid: "abc"} endpoint := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, ModificationTag: olderModTag}) - Expect(pool.Remove(endpoint)).To(BeFalse()) + Expect(pool.Remove(endpoint)).To(Equal(route.EndpointNotUnregistered)) Expect(pool.IsEmpty()).To(BeFalse()) }) })