Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(relay/client):error handling for rpc's and replacing errgroup with sync.WaitGroup #55

Merged
merged 1 commit into from
May 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 82 additions & 53 deletions relay-server/server/relayServer.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,6 @@ func NewClient(server string) *LogClient {
kg.Warnf("Failed to call WatchLogs (%s)\n err=%s", server, err.Error())
return nil
}
// == //

// set wait group
lc.WgServer, lc.Context = errgroup.WithContext(context.Background())

return lc
}
Expand All @@ -402,30 +398,35 @@ func (lc *LogClient) DoHealthCheck() bool {
}

// WatchMessages Function
func (lc *LogClient) WatchMessages(ctx context.Context) error {
func (lc *LogClient) WatchMessages(wg *sync.WaitGroup, stop chan struct{}, errCh chan error) {

defer wg.Done()

var err error

for lc.Running {
var res *pb.Message

if res, err = lc.MsgStream.Recv(); err != nil {
return fmt.Errorf("failed to receive a message (%s) %s", lc.Server, err.Error())

}
select {
case MsgBufferChannel <- res:
case <-ctx.Done():
// The context is over, stop processing results
return nil
case <-stop:
return
default:
//not able to add it to Log buffer
if res, err = lc.MsgStream.Recv(); err != nil {
errCh <- fmt.Errorf("failed to receive a message (%s) %s", lc.Server, err.Error())
return
}

select {
case MsgBufferChannel <- res:
case <-stop:
return
default:
// Not able to add it to Message buffer
}
}
}

kg.Print("Stopped watching messages from " + lc.Server)

return nil
}

// AddMsgFromBuffChan Adds Msg from MsgBufferChannel into MsgStructs
Expand Down Expand Up @@ -461,30 +462,35 @@ func (rs *RelayServer) AddMsgFromBuffChan() {
}

// WatchAlerts Function
func (lc *LogClient) WatchAlerts(ctx context.Context) error {
func (lc *LogClient) WatchAlerts(wg *sync.WaitGroup, stop chan struct{}, errCh chan error) {

defer wg.Done()

var err error

for lc.Running {
var res *pb.Alert

if res, err = lc.AlertStream.Recv(); err != nil {
return fmt.Errorf("failed to receive a alert (%s) %s", lc.Server, err.Error())
}

select {
case AlertBufferChannel <- res:
case <-ctx.Done():
// The context is over, stop processing results
return nil
case <-stop:
return
default:
//not able to add it to Log buffer
if res, err = lc.AlertStream.Recv(); err != nil {
errCh <- fmt.Errorf("failed to receive an alert (%s) %s", lc.Server, err.Error())
return
}

select {
case AlertBufferChannel <- res:
case <-stop:
return
default:
// Not able to add it to Alert buffer
}
}
}

kg.Print("Stopped watching alerts from " + lc.Server)

return nil
}

// AddAlertFromBuffChan Adds ALert from AlertBufferChannel into AlertStructs
Expand Down Expand Up @@ -520,30 +526,34 @@ func (rs *RelayServer) AddAlertFromBuffChan() {
}

// WatchLogs Function
func (lc *LogClient) WatchLogs(ctx context.Context) error {
func (lc *LogClient) WatchLogs(wg *sync.WaitGroup, stop chan struct{}, errCh chan error) {
defer wg.Done()

var err error

for lc.Running {
var res *pb.Log

if res, err = lc.LogStream.Recv(); err != nil {
return fmt.Errorf("failed to receive a log (%s) %s", lc.Server, err.Error())
}

select {
case LogBufferChannel <- res:
case <-ctx.Done():
// The context is over, stop processing results
return nil
case <-stop:
return
default:
//not able to add it to Log buffer
if res, err = lc.LogStream.Recv(); err != nil {
errCh <- fmt.Errorf("failed to receive a log (%s) %s", lc.Server, err.Error())
return
}

select {
case LogBufferChannel <- res:
case <-stop:
return
default:
// Not able to add it to Log buffer
}
}
}

kg.Print("Stopped watching logs from " + lc.Server)

return nil
}

// AddLogFromBuffChan Adds Log from LogBufferChannel into LogStructs
Expand Down Expand Up @@ -744,26 +754,45 @@ func connectToKubeArmor(nodeIP, port string) error {
}
kg.Printf("Checked the liveness of KubeArmor's gRPC service (%s)", server)

// watch messages
client.WgServer.Go(func() error {
return client.WatchMessages(client.Context)
})
var wg sync.WaitGroup
stop := make(chan struct{})
errCh := make(chan error, 1)

// Start watching messages
wg.Add(1)
go func() {
client.WatchMessages(&wg, stop, errCh)
}()
kg.Print("Started to watch messages from " + server)

// watch alerts
client.WgServer.Go(func() error {
return client.WatchAlerts(client.Context)
})
// Start watching alerts
wg.Add(1)
go func() {
client.WatchAlerts(&wg, stop, errCh)
}()
kg.Print("Started to watch alerts from " + server)

// watch logs
client.WgServer.Go(func() error {
return client.WatchLogs(client.Context)
})
// Start watching logs
wg.Add(1)
go func() {
client.WatchLogs(&wg, stop, errCh)
}()
kg.Print("Started to watch logs from " + server)

if err := client.WgServer.Wait(); err != nil {
// Wait for an error or all goroutines to finish
select {
case err := <-errCh:
close(stop) // Stop other goroutines
kg.Warn(err.Error())
case <-func() chan struct{} {
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
return done
}():
// All goroutines finished without error
}

if err := client.DestroyClient(); err != nil {
Expand Down
Loading