Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
arjan-bal committed Jan 27, 2025
1 parent 89b45a5 commit 79159f2
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 19 deletions.
60 changes: 44 additions & 16 deletions balancer/endpointsharding/endpointsharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,22 @@ type endpointSharding struct {
bOpts balancer.BuildOptions
enableAutoReconnect bool

childMu sync.Mutex // syncs balancer.Balancer calls into children
// childMu synchronizes calls to any single child. It must be held for all
// calls into a child. To avoid deadlocks, do not acquire childMu while
// holding mu.
childMu sync.Mutex
children atomic.Pointer[resolver.EndpointMap]

// inhibitChildUpdates is set during UpdateClientConnState/ResolverError
// calls (calls to children will each produce an update, only want one
// update).
inhibitChildUpdates atomic.Bool

mu sync.Mutex // Sync updateState callouts and childState recent state updates
// mu synchronizes access to the stored children balancer states.
// It must not be held during calls into a child since synchronous calls
// back from the child may require taking mu, causing a deadlock. To avoid
// deadlocks, do not acquire childMu while holding mu.
mu sync.Mutex
}

// UpdateClientConnState creates a child for new endpoints and deletes children
Expand Down Expand Up @@ -153,10 +160,10 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState
es: es,
}
bal.childState.Balancer = bal
bal.Balancer = gracefulswitch.NewBalancer(bal, es.bOpts)
bal.child = gracefulswitch.NewBalancer(bal, es.bOpts)
}
newChildren.Set(endpoint, bal)
if err := bal.UpdateClientConnState(balancer.ClientConnState{
if err := bal.updateClientConnStateLocked(balancer.ClientConnState{
BalancerConfig: state.BalancerConfig,
ResolverState: resolver.State{
Endpoints: []resolver.Endpoint{endpoint},
Expand All @@ -175,7 +182,7 @@ func (es *endpointSharding) UpdateClientConnState(state balancer.ClientConnState
child, _ := children.Get(e)
bal := child.(*balancerWrapper)
if _, ok := newChildren.Get(e); !ok {
bal.Close()
bal.closeLocked()
}
}
es.children.Store(newChildren)
Expand Down Expand Up @@ -213,7 +220,7 @@ func (es *endpointSharding) Close() {
children := es.children.Load()
for _, child := range children.Values() {
bal := child.(*balancerWrapper)
bal.Close()
bal.closeLocked()
}
}

Expand Down Expand Up @@ -310,11 +317,20 @@ func ChildStatesFromPicker(picker balancer.Picker) []ChildState {
// balancerWrapper is a wrapper of a balancer. It ID's a child balancer by
// endpoint, and persists recent child balancer state.
type balancerWrapper struct {
balancer.Balancer // Simply forward balancer.Balancer operations.
// The following fields are initialized at build time and read-only after
// that and therefore do not need to be guarded by a mutex.

// child contains the wrapped balancer. Access it's methods only through
// methods on balancerWrapper to ensure proper synchronization
child balancer.Balancer
balancer.ClientConn // embed to intercept UpdateState, doesn't deal with SubConns
es *endpointSharding
childState ChildState
isClosed bool

es *endpointSharding

// Access to the following fields is guarded by es.mu.

childState ChildState
isClosed bool
}

func (bw *balancerWrapper) UpdateState(state balancer.State) {
Expand All @@ -327,15 +343,10 @@ func (bw *balancerWrapper) UpdateState(state balancer.State) {
bw.es.updateState()
}

func (bw *balancerWrapper) Close() {
bw.Balancer.Close()
bw.isClosed = true
}

// ExitIdle pings an IDLE child balancer to exit idle in a new goroutine to
// avoid deadlocks due to synchronous balancer state updates.
func (bw *balancerWrapper) ExitIdle() {
if ei, ok := bw.Balancer.(balancer.ExitIdler); ok {
if ei, ok := bw.child.(balancer.ExitIdler); ok {
go func() {
bw.es.childMu.Lock()
if !bw.isClosed {
Expand All @@ -346,6 +357,23 @@ func (bw *balancerWrapper) ExitIdle() {
}
}

// updateClientConnStateLocked delivers the ClientConnState to the child
// balancer. Callers must hold the child mutex of the parent endpointsharding
// balancer.
func (bw *balancerWrapper) updateClientConnStateLocked(ccs balancer.ClientConnState) error {
return bw.child.UpdateClientConnState(ccs)
}

// closeLocked closes the child balancer. Callers must hold the child mutext of
// the parent endpointsharding balancer.
func (bw *balancerWrapper) closeLocked() {
if bw.isClosed {
return
}
bw.child.Close()
bw.isClosed = true
}

// ParseConfig parses a child config list and returns an LB config to use with
// the endpointsharding balancer.
//
Expand Down
24 changes: 21 additions & 3 deletions balancer/endpointsharding/endpointsharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,27 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils/roundrobin"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/status"

testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)

var (
defaultTestTimeout = time.Second * 10
defaultTestTimeout = time.Second * 10
defaultTestShortTimeout = time.Millisecond * 10
)

type s struct {
Expand Down Expand Up @@ -221,7 +226,20 @@ func (s) TestEndpointShardingReconnectDisabled(t *testing.T) {
// remain IDLE and will not try to connect to the second backend in the same
// endpoint.
backend1.Stop()
if err = roundrobin.CheckRoundRobinRPCs(ctx, client, []resolver.Address{{Addr: backend3.Address}}); err != nil {
t.Fatalf("error in expected round robin: %v", err)

// Verify requests go only to backend3.
shortCtx, cancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer cancel()
for ; shortCtx.Err() == nil; <-time.After(time.Millisecond) {
var peer peer.Peer
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer)); err != nil {
if status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() returned unexpected error %v", err)
}
break
}
if got, want := peer.Addr.String(), backend3.Address; got != want {
t.Fatalf("EmptyCall() went to unexpected backend: got %q, want %q", got, want)
}
}
}

0 comments on commit 79159f2

Please sign in to comment.