diff --git a/examples/leaderregistry/main_test.go b/examples/leaderregistry/main_test.go
index fc921cb..1646974 100644
--- a/examples/leaderregistry/main_test.go
+++ b/examples/leaderregistry/main_test.go
@@ -488,6 +488,60 @@ func TestSurviveReplicaFailureWithRandomStrategy(t *testing.T) {
 	}
 }
 
+// TestRemoteServerTimeout tests the scenario where an actor is invoked with 2 replicas.
+// The invocation is done using a context with a deadline, and invoking a special operation that receives the expected
+// timeout as the payload. The actor operation checks whether the expected deadline matches the context deadline received by
+// the server. Due to HTTP lag, the test verifies that the deadlines are no more than 1 second apart.
+func TestRemoteServerTimeout(t *testing.T) {
+	var (
+		lp          = &leaderProvider{}
+		portServer1 = nextPort()
+	)
+
+	// Set the leader address for the registry.
+	lp.setLeader(registry.Address{
+		IP:   net.ParseIP("127.0.0.1"),
+		Port: baseRegistryPort + portServer1,
+	})
+
+	var (
+		server1, _, cleanupFn1 = newServer(t, lp, portServer1)
+		server2, _, cleanupFn2 = newServer(t, lp, nextPort())
+	)
+
+	// Clean up resources at the end of the test.
+	defer cleanupFn1()
+	defer cleanupFn2()
+	defer server1.Close(context.Background())
+	defer server2.Close(context.Background())
+
+	// Sleep for a few seconds to allow servers to heartbeat and actors to activate.
+	time.Sleep(5 * time.Second)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	// Define the actor options with 1 extra replica and the random selection strategy.
+	options := types.CreateIfNotExist{Options: types.ActorOptions{
+		ExtraReplicas:       1,
+		ReplicationStrategy: types.ReplicaSelectionStrategyRandom,
+	}}
+
+	// We ensure that the actor invocation reaches two different servers, thus verifying that at least one
+	// of the requests went over the network. We use the "ctx-timeout-check" operation to propagate the
+	// context deadline as a payload to the actor invocation. The environment used for the test is created
+	// with the ForceRemoteProcedureCalls option enabled. This ensures that even localhost calls are done
+	// using HTTP, allowing us to validate network communication between different servers.
+	require.Eventually(t, func() bool {
+		deadline, _ := ctx.Deadline()
+		_, err := server1.InvokeActor(
+			ctx, namespace, actorID(0), module, "ctx-timeout-check", []byte(time.Until(deadline).String()), options)
+		require.NoError(t, err)
+
+		return server1.NumActivatedActors() == 1 && server2.NumActivatedActors() == 1
+	}, 10*time.Second, time.Millisecond, "actor should be replicated in all servers")
+}
+
 // TestSurviveReplicaFailureWithSortedStrategy tests the ability to survive replica failure with a sorted strategy (biased towards a single replica).
 // It validates that the actor can handle replica failures and still maintain high availability.
 // The test creates an actor with multiple replicas, ensures that the actor is replicated only on the server that is biased towards,
@@ -638,6 +692,19 @@ func newServer(
 	require.NoError(t, err)
 	require.NoError(t, env.RegisterGoModule(types.NewNamespacedIDNoType(namespace, module), &testModule{}))
 
+	server := newServerWithEnv(t, env, idx)
+	return env, reg, func() {
+		server.Stop(context.Background())
+		env.Close(context.Background())
+		reg.Close(context.Background())
+	}
+}
+
+func newServerWithEnv(t *testing.T, env virtual.Environment, idx int) *virtual.Server {
+	var (
+		envPort = baseEnvPort + idx
+	)
+
 	server := virtual.NewServer(registry.NewNoopModuleStore(), env)
 	go func() {
 		if err := server.Start(envPort); err != nil {
@@ -648,11 +715,7 @@ func newServer(
 		}
 	}()
 
-	return env, reg, func() {
-		reg.Close(context.Background())
-		env.Close(context.Background())
-		server.Stop(context.Background())
-	}
+	return server
 }
 
 func newRegistry(
@@ -735,6 +798,23 @@ func (ta *testActor) Invoke(
 	case "inc-memory-usage":
 		ta.count++
 		return nil, nil
+	case "ctx-timeout-check":
+		// Handle the special "ctx-timeout-check" operation where the expected timeout
+		// value is passed as the payload. The function verifies whether the received
+		// context's deadline matches the expected timeout. This ensures that the context
+		// deadlines are propagated correctly over RPCs.
+		expectedTimeout, err := time.ParseDuration(string(payload))
+		if err != nil {
+			return nil, fmt.Errorf("failed to parse duration: %w", err)
+		}
+		got, ok := ctx.Deadline()
+		if !ok {
+			return nil, fmt.Errorf("context has no deadline")
+		}
+		if expected := time.Now().Add(expectedTimeout); !areWithinDuration(expected, got, time.Second) {
+			return nil, fmt.Errorf("context deadline is not within expected duration: expected %s got %s", expected, got)
+		}
+		return nil, nil
 	default:
 		return nil, fmt.Errorf("testActor: unhandled operation: %s", operation)
 	}
@@ -753,3 +833,15 @@ func actorID(idx int) string {
 func nextPort() int {
 	return int(atomic.AddInt64(&nextServerPort, 1))
 }
+
+// areWithinDuration checks whether the duration between two given time values
+// is within the specified maximum duration.
+func areWithinDuration(t1, t2 time.Time, maxDuration time.Duration) bool {
+	duration := t1.Sub(t2)
+	absDuration := duration
+	if absDuration < 0 {
+		absDuration = -absDuration
+	}
+
+	return absDuration <= maxDuration
+}
diff --git a/virtual/client.go b/virtual/client.go
index 5b764eb..57aa86f 100644
--- a/virtual/client.go
+++ b/virtual/client.go
@@ -51,6 +51,12 @@ func (h *httpClient) InvokeActorRemote(
 		return nil, fmt.Errorf("HTTPClient: InvokeDirect: error constructing request: %w", err)
 	}
 
+	deadline, ok := ctx.Deadline()
+	if ok {
+		timeout := time.Until(deadline)
+		req.Header.Add(types.HTTPHeaderTimeout, timeout.String())
+	}
+
 	resp, err := h.c.Do(req)
 	if err != nil {
 		return nil, fmt.Errorf("HTTPClient: InvokeDirect: error running request: %w", err)
diff --git a/virtual/server.go b/virtual/server.go
index 38dc673..a60cc67 100644
--- a/virtual/server.go
+++ b/virtual/server.go
@@ -16,6 +16,10 @@ import (
 	"github.com/richardartoul/nola/virtual/types"
 )
 
+// DefaultHTTPRequestTimeout is the handler timeout applied when a request does
+// not carry a parseable types.HTTPHeaderTimeout header.
+const DefaultHTTPRequestTimeout = 15 * time.Second
+
 type Server struct {
 	sync.Mutex
 
@@ -97,7 +101,7 @@ func (s *Server) registerModule(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	ctx, cc := context.WithTimeout(context.Background(), 60*time.Second)
+	ctx, cc := getContextFromRequest(r)
 	defer cc()
 	result, err := s.moduleStore.RegisterModule(ctx, namespace, moduleID, moduleBytes, registry.ModuleOptions{})
 	if err != nil {
@@ -156,8 +160,7 @@ func (s *Server) invoke(w http.ResponseWriter, r *http.Request) {
 		req.Payload = marshaled
 	}
 
-	// TODO: This should be configurable, probably in a header with some maximum.
-	ctx, cc := context.WithTimeout(context.Background(), 5*time.Second)
+	ctx, cc := getContextFromRequest(r)
 	defer cc()
 	result, err := s.environment.InvokeActorStream(
 		ctx, req.Namespace, req.ActorID, req.ModuleID, req.Operation, req.Payload, req.CreateIfNotExist)
@@ -203,8 +206,7 @@ func (s *Server) invokeDirect(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO: This should be configurable, probably in a header with some maximum.
-	ctx, cc := context.WithTimeout(context.Background(), 5*time.Second)
+	ctx, cc := getContextFromRequest(r)
 	defer cc()
 
 	ref, err := types.NewVirtualActorReference(req.Namespace, req.ModuleID, req.ActorID, uint64(req.Generation))
@@ -256,8 +258,7 @@ func (s *Server) invokeWorker(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// TODO: This should be configurable, probably in a header with some maximum.
-	ctx, cc := context.WithTimeout(context.Background(), 5*time.Second)
+	ctx, cc := getContextFromRequest(r)
 	defer cc()
 
 	result, err := s.environment.InvokeWorkerStream(
@@ -348,3 +349,22 @@ func copyResultIntoStreamAndCloseResult(
 		terminateConnection(w)
 	}
 }
+
+// getContextFromRequest derives the context used to serve r. The context is a
+// child of the request's own context and expires after the duration carried in
+// the types.HTTPHeaderTimeout header, or after DefaultHTTPRequestTimeout when
+// the header is absent or cannot be parsed. The caller must invoke the returned
+// CancelFunc once the request has been handled so the timer is released
+// promptly instead of lingering until the parent context is cancelled.
+func getContextFromRequest(r *http.Request) (context.Context, context.CancelFunc) {
+	timeout := DefaultHTTPRequestTimeout
+
+	if headerValue := r.Header.Get(types.HTTPHeaderTimeout); headerValue != "" {
+		headerTimeout, err := time.ParseDuration(headerValue)
+		if err == nil {
+			timeout = headerTimeout
+		}
+	}
+
+	return context.WithTimeout(r.Context(), timeout)
+}
diff --git a/virtual/types/http.go b/virtual/types/http.go
new file mode 100644
index 0000000..ce59d3a
--- /dev/null
+++ b/virtual/types/http.go
@@ -0,0 +1,7 @@
+package types
+
+const (
+	// HTTPHeaderTimeout is the request header carrying the caller's remaining
+	// context timeout, formatted with time.Duration.String.
+	HTTPHeaderTimeout = "nola-context-timeout"
+)