Skip to content

Commit

Permalink
Merge pull request #321 from fahadnaeemkhan/cancel_propagation
Browse files Browse the repository at this point in the history
added context in gnmi_server.handleStreamSubscriptionRequest() for proper cleanup
  • Loading branch information
karimra authored Dec 22, 2023
2 parents 0bba558 + fec4642 commit 2027acb
Showing 1 changed file with 57 additions and 40 deletions.
97 changes: 57 additions & 40 deletions pkg/app/gnmi_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,8 +499,9 @@ func (a *App) Subscribe(stream gnmi.GNMI_SubscribeServer) error {
a.Logger.Printf("received a subscribe request mode=%v from %q for target %q", sc.req.GetSubscribe().GetMode(), pr.Addr, sc.target)
defer a.Logger.Printf("subscription from peer %q terminated", pr.Addr)

errChan := make(chan error, 3)
sc.errChan = make(chan error, 3)
// closing of this channel is handled by the respective goroutines that send errors on it
errChan := make(chan error, len(sc.req.GetSubscribe().GetSubscription()))
sc.errChan = errChan // send-only

a.Logger.Printf("acquiring subscription spot for target %q", sc.target)
ok := a.subscribeRPCsem.TryAcquire(1)
Expand All @@ -520,19 +521,30 @@ func (a *App) Subscribe(stream gnmi.GNMI_SubscribeServer) error {
})
close(errChan)
}()

case gnmi.SubscriptionList_POLL:
go a.handlePolledSubscription(sc)

case gnmi.SubscriptionList_STREAM:
go a.handleStreamSubscriptionRequest(sc)
default:
return status.Errorf(codes.InvalidArgument, "unrecognized subscription mode: %v", sc.req.GetSubscribe().GetMode())
}

// flushing the errChan
defer func() {
a.Logger.Printf("flushing subscription errChan")
for range errChan {
}
}()

// returning the first non-nil error and flushing the rest in the defer
for err := range errChan {
if err != nil {
return status.Errorf(codes.Internal, "%v", err)
}
}

return nil
}

Expand Down Expand Up @@ -588,20 +600,28 @@ func (a *App) handleONCESubscriptionRequest(sc *streamClient) {

func (a *App) handleStreamSubscriptionRequest(sc *streamClient) {
peer, _ := peer.FromContext(sc.stream.Context())

errChan := make(chan error)
defer close(errChan)

// this context is required to signal this goroutine and the `handleSampledQuery` goroutine that an error has happened in the cache
ctx, cancel := context.WithCancel(sc.stream.Context())
a.Logger.Printf("processing STREAM subscription from %q to target %q", peer.Addr, sc.target)

go func() {
defer close(sc.errChan)

for err := range errChan {
if err == nil {
a.Logger.Printf("subscription request from %q to target %q processed", peer.Addr, sc.target)
} else if errors.Is(err, context.Canceled) {
a.Logger.Printf("subscription to target %q canceled", sc.target)
sc.errChan <- err
cancel()
} else {
a.Logger.Printf("error processing STREAM subscription to target %q: %v", sc.target, err)
sc.errChan <- err
cancel()
}
}
}()
Expand All @@ -613,6 +633,7 @@ func (a *App) handleStreamSubscriptionRequest(sc *streamClient) {

if err != nil {
errChan <- err
return
}
}
var pr *gnmi.Path
Expand All @@ -624,13 +645,17 @@ func (a *App) handleStreamSubscriptionRequest(sc *streamClient) {
subs := sc.req.GetSubscribe().GetSubscription()
wg := new(sync.WaitGroup)
wg.Add(len(subs))

for i, sub := range subs {
a.Logger.Printf("handling subscriptionList item[%d]: target %q, %q", i, sc.target, sub.String())

go func(sub *gnmi.Subscription) {
defer wg.Done()
var ro *cache.ReadOpts

switch sub.GetMode() {
case gnmi.SubscriptionMode_ON_CHANGE, gnmi.SubscriptionMode_TARGET_DEFINED:
ro := &cache.ReadOpts{
ro = &cache.ReadOpts{
Target: sc.target,
Paths: []*gnmi.Path{
{
Expand All @@ -644,31 +669,14 @@ func (a *App) handleStreamSubscriptionRequest(sc *streamClient) {
SuppressRedundant: sub.GetSuppressRedundant(),
UpdatesOnly: sc.req.GetSubscribe().GetUpdatesOnly(),
}
a.Logger.Printf("cache subscribe: %+v", ro)
for n := range a.c.Subscribe(sc.stream.Context(), ro) {
if n.Err != nil {
errChan <- n.Err
return
}
err := sc.stream.Send(&gnmi.SubscribeResponse{
Response: &gnmi.SubscribeResponse_Update{
Update: n.Notification,
},
})
if err != nil {
errChan <- n.Err
return
}
}
return
case gnmi.SubscriptionMode_SAMPLE:
period := time.Duration(sub.GetSampleInterval())
if period == 0 {
period = a.Config.GnmiServer.DefaultSampleInterval
} else if period < a.Config.GnmiServer.MinSampleInterval {
period = a.Config.GnmiServer.MinSampleInterval
}
ro := &cache.ReadOpts{
ro = &cache.ReadOpts{
Target: sc.target,
Paths: []*gnmi.Path{
{
Expand All @@ -682,31 +690,39 @@ func (a *App) handleStreamSubscriptionRequest(sc *streamClient) {
SuppressRedundant: sub.GetSuppressRedundant(),
UpdatesOnly: sc.req.GetSubscribe().GetUpdatesOnly(),
}
a.Logger.Printf("cache subscribe: %+v", ro)
for n := range a.c.Subscribe(sc.stream.Context(), ro) {
if n.Err != nil {
errChan <- n.Err
a.Logger.Printf("cache subscribe failed: %+v: %v", ro, n.Err)
return
}
err := sc.stream.Send(&gnmi.SubscribeResponse{
Response: &gnmi.SubscribeResponse_Update{
Update: n.Notification,
},
})
if err != nil {
errChan <- n.Err
return
}
}

a.Logger.Printf("cache subscribe: %+v", ro)

for n := range a.c.Subscribe(ctx, ro) {
// `errChan <- n.Err` should trigger the gnmi-server side cleanup
// the only remaining wait is for the cache to close the channel
if n.Err != nil {
errChan <- n.Err
a.Logger.Printf("cache subscribe failed: %+v: %v", ro, n.Err)

// the reader should only stop once the channel is closed by the sender; otherwise
// it could block senders that do not know an error has happened

continue
}

err := sc.stream.Send(&gnmi.SubscribeResponse{
Response: &gnmi.SubscribeResponse_Update{
Update: n.Notification,
},
})

if err != nil {
errChan <- n.Err
}
return
}
}(sub)
}

// wait for ctx to be done
<-sc.stream.Context().Done()
errChan <- sc.stream.Context().Err()
<-ctx.Done()
errChan <- ctx.Err()
wg.Wait()
}

Expand All @@ -717,6 +733,7 @@ func (a *App) handlePolledSubscription(sc *streamClient) {
for {
_, err = sc.stream.Recv()
if errors.Is(err, io.EOF) {
sc.errChan <- err
return
}
if err != nil {
Expand Down

0 comments on commit 2027acb

Please sign in to comment.