diff --git a/event-handler/teleport_event_test.go b/event-handler/teleport_event_test.go index 6ed79549d..085eca442 100644 --- a/event-handler/teleport_event_test.go +++ b/event-handler/teleport_event_test.go @@ -23,6 +23,7 @@ import ( auditlogpb "github.com/gravitational/teleport/api/gen/proto/go/teleport/auditlog/v1" "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/trace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/structpb" @@ -39,7 +40,10 @@ func TestNew(t *testing.T) { }, } - event, err := NewTeleportEvent(eventToJSON(t, events.AuditEvent(e)), "cursor") + protoEvent, err := eventToProto(events.AuditEvent(e)) + require.NoError(t, err) + + event, err := NewTeleportEvent(protoEvent, "cursor") require.NoError(t, err) assert.Equal(t, "test", event.ID) assert.Equal(t, "mock", event.Type) @@ -49,7 +53,10 @@ func TestNew(t *testing.T) { func TestGenID(t *testing.T) { e := &events.SessionPrint{} - event, err := NewTeleportEvent(eventToJSON(t, events.AuditEvent(e)), "cursor") + protoEvent, err := eventToProto(events.AuditEvent(e)) + require.NoError(t, err) + + event, err := NewTeleportEvent(protoEvent, "cursor") require.NoError(t, err) assert.NotEmpty(t, event.ID) } @@ -64,7 +71,10 @@ func TestSessionEnd(t *testing.T) { }, } - event, err := NewTeleportEvent(eventToJSON(t, events.AuditEvent(e)), "cursor") + protoEvent, err := eventToProto(events.AuditEvent(e)) + require.NoError(t, err) + + event, err := NewTeleportEvent(protoEvent, "cursor") require.NoError(t, err) assert.NotEmpty(t, event.ID) assert.NotEmpty(t, event.SessionID) @@ -81,7 +91,10 @@ func TestFailedLogin(t *testing.T) { }, } - event, err := NewTeleportEvent(eventToJSON(t, events.AuditEvent(e)), "cursor") + protoEvent, err := eventToProto(events.AuditEvent(e)) + require.NoError(t, err) + + event, err := NewTeleportEvent(protoEvent, "cursor") require.NoError(t, err) assert.NotEmpty(t, event.ID) assert.True(t, event.IsFailedLogin) @@ -97,28 +110,49 @@ func TestSuccessLogin(t *testing.T) { }, } - event, err := NewTeleportEvent(eventToJSON(t, events.AuditEvent(e)), "cursor") + protoEvent, err := eventToProto(events.AuditEvent(e)) + require.NoError(t, err) + + event, err := NewTeleportEvent(protoEvent, "cursor") require.NoError(t, err) assert.NotEmpty(t, event.ID) assert.False(t, event.IsFailedLogin) } -func eventToJSON(t *testing.T, e events.AuditEvent) *auditlogpb.EventUnstructured { +func eventToProto(e events.AuditEvent) (*auditlogpb.EventUnstructured, error) { data, err := lib.FastMarshal(e) - require.NoError(t, err) + if err != nil { + return nil, trace.Wrap(err) + } + str := &structpb.Struct{} - err = str.UnmarshalJSON(data) - require.NoError(t, err) + if err = str.UnmarshalJSON(data); err != nil { + return nil, trace.Wrap(err) + } + id := e.GetID() if id == "" { hash := sha256.Sum256(data) id = hex.EncodeToString(hash[:]) } + return &auditlogpb.EventUnstructured{ Type: e.GetType(), Unstructured: str, Id: id, Index: e.GetIndex(), Time: timestamppb.New(e.GetTime()), + }, nil +} + +func eventsToProto(events []events.AuditEvent) ([]*auditlogpb.EventUnstructured, error) { + protoEvents := make([]*auditlogpb.EventUnstructured, len(events)) + for i, event := range events { + protoEvent, err := eventToProto(event) + if err != nil { + return nil, trace.Wrap(err) + } + protoEvents[i] = protoEvent } + return protoEvents, nil } diff --git a/event-handler/teleport_events_watcher.go b/event-handler/teleport_events_watcher.go index 44c85b3d0..0c1444632 100644 --- a/event-handler/teleport_events_watcher.go +++ b/event-handler/teleport_events_watcher.go @@ -17,6 +17,7 @@ limitations under the License. package main import ( + "context" "fmt" "time" @@ -30,7 +31,6 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/trace" log "github.com/sirupsen/logrus" - "golang.org/x/net/context" ) const ( @@ -293,7 +293,7 @@ func (t *TeleportEventsWatcher) Events(ctx context.Context) (chan *TeleportEvent err := t.fetch(ctx) if err != nil { e <- trace.Wrap(err) - continue + break } // If there is still nothing new on current page, sleep diff --git a/event-handler/teleport_events_watcher_test.go b/event-handler/teleport_events_watcher_test.go index 7510238ba..b3242ac7c 100644 --- a/event-handler/teleport_events_watcher_test.go +++ b/event-handler/teleport_events_watcher_test.go @@ -17,6 +17,8 @@ limitations under the License. package main import ( + "strconv" + "sync" "testing" "time" @@ -31,41 +33,79 @@ import ( // mockTeleportEventWatcher is Teleport client mock type mockTeleportEventWatcher struct { + mu sync.Mutex // events is the mock list of events events []events.AuditEvent - t *testing.T + // mockSearchErr is an error to return + mockSearchErr error +} + +func (c *mockTeleportEventWatcher) setEvents(events []events.AuditEvent) { + c.mu.Lock() + defer c.mu.Unlock() + + c.events = events +} + +func (c *mockTeleportEventWatcher) setSearchEventsError(err error) { + c.mu.Lock() + defer c.mu.Unlock() + + c.mockSearchErr = err } -// SearchEvents is mock SearchEvents method which returns events func (c *mockTeleportEventWatcher) SearchEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]events.AuditEvent, string, error) { - e := c.events - c.events = make([]events.AuditEvent, 0) // nullify events - return e, "test", nil + c.mu.Lock() + defer c.mu.Unlock() + + if c.mockSearchErr != nil { + return nil, "", c.mockSearchErr + } + + var startIndex int + if startKey != "" { + startIndex, _ = strconv.Atoi(startKey) + } + + endIndex := startIndex + limit + if endIndex >= len(c.events) { + endIndex = len(c.events) + } + + // Get the next page + e := c.events[startIndex:endIndex] + + // Check if we finished the page + var lastKey string + if len(e) == limit { + lastKey = strconv.Itoa(startIndex + (len(e) - 1)) + } + + return e, lastKey, nil } -// StreamSessionEvents returns session events stream func (c *mockTeleportEventWatcher) StreamSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan events.AuditEvent, chan error) { return nil, nil } -// SearchEvents is mock SearchEvents method which returns events func (c *mockTeleportEventWatcher) SearchUnstructuredEvents(ctx context.Context, fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]*auditlogpb.EventUnstructured, string, error) { - e := c.events - c.events = make([]events.AuditEvent, 0) // nullify events + events, lastKey, err := c.SearchEvents(ctx, fromUTC, toUTC, namespace, eventTypes, limit, order, startKey) + if err != nil { + return nil, "", trace.Wrap(err) + } - events := make([]*auditlogpb.EventUnstructured, len(e)) - for i, event := range e { - events[i] = eventToJSON(c.t, event) + protoEvents, err := eventsToProto(events) + if err != nil { + return nil, "", trace.Wrap(err) } - return events, "test", nil + + return protoEvents, lastKey, nil } -// StreamSessionEvents returns session events stream func (c *mockTeleportEventWatcher) StreamUnstructuredSessionEvents(ctx context.Context, sessionID string, startIndex int64) (chan *auditlogpb.EventUnstructured, chan error) { return nil, nil } -// UsertLock is mock UpsertLock method func (c *mockTeleportEventWatcher) UpsertLock(ctx context.Context, lock types.Lock) error { return nil } @@ -76,16 +116,13 @@ func (c *mockTeleportEventWatcher) Ping(ctx context.Context) (proto.PingResponse }, nil } -// Close is mock close method func (c *mockTeleportEventWatcher) Close() error { return nil } -func newTeleportEventWatcher(t *testing.T, e []events.AuditEvent) *TeleportEventsWatcher { - teleportEventWatcher := &mockTeleportEventWatcher{events: e, t: t} - +func newTeleportEventWatcher(t *testing.T, eventsClient TeleportSearchEventsClient) *TeleportEventsWatcher { client := &TeleportEventsWatcher{ - client: teleportEventWatcher, + client: eventsClient, pos: -1, config: &StartCmdConfig{ IngestConfig: IngestConfig{ @@ -98,41 +135,206 @@ func newTeleportEventWatcher(t *testing.T, e []events.AuditEvent) *TeleportEvent return client } -func TestNext(t *testing.T) { - e := []events.AuditEvent{ - &events.UserCreate{ +func TestEvents(t *testing.T) { + ctx := context.Background() + + // create fake audit events with ids 0-19 + testAuditEvents := make([]events.AuditEvent, 20) + for i := 0; i < 20; i++ { + testAuditEvents[i] = &events.UserCreate{ Metadata: events.Metadata{ - ID: "1", + ID: strconv.Itoa(i), }, - }, - &events.UserDelete{ + } + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Add the 20 events to a mock event watcher. + mockEventWatcher := &mockTeleportEventWatcher{events: testAuditEvents} + client := newTeleportEventWatcher(t, mockEventWatcher) + + // Start the events goroutine + chEvt, chErr := client.Events(ctx) + + // Collect all 20 events + for i := 0; i < 20; i++ { + select { + case event, ok := <-chEvt: + require.NotNil(t, event, "Expected an event but got nil. i: %v", i) + require.Equal(t, strconv.Itoa(i), event.ID) + if !ok { + return + } + case err := <-chErr: + t.Fatalf("Received unexpected error from error channel: %v", err) + return + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") + } + } + + // Both channels should be closed once the last event is reached. + select { + case _, ok := <-chEvt: + require.False(t, ok, "Events channel should be closed") + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") + } + + select { + case _, ok := <-chErr: + require.False(t, ok, "Error channel should be closed") + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") + } + + // Events goroutine should return next page errors + mockErr := trace.Errorf("error") + mockEventWatcher.setSearchEventsError(mockErr) + + select { + case err := <-chErr: + require.Error(t, mockErr, err) + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") + } + + // Both channels should be closed + select { + case _, ok := <-chEvt: + require.False(t, ok, "Events channel should be closed") + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") + } + + select { + case _, ok := <-chErr: + require.False(t, ok, "Error channel should be closed") + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") + } +} + +func TestUpdatePage(t *testing.T) { + ctx := context.Background() + + // create fake audit events with ids 0-9 + testAuditEvents := make([]events.AuditEvent, 10) + for i := 0; i < 10; i++ { + testAuditEvents[i] = &events.UserCreate{ Metadata: events.Metadata{ - ID: "", + ID: strconv.Itoa(i), }, - }, + } } + ctx, cancel := context.WithCancel(ctx) + defer cancel() - client := newTeleportEventWatcher(t, e) - chEvt, chErr := client.Events(context.Background()) + mockEventWatcher := &mockTeleportEventWatcher{} + client := newTeleportEventWatcher(t, mockEventWatcher) + client.config.ExitOnLastEvent = false + // Start the events goroutine + chEvt, chErr := client.Events(ctx) + + // Add an incomplete page of 3 events and collect them. + mockEventWatcher.setEvents(testAuditEvents[:3]) + var i int + for ; i < 3; i++ { + select { + case event, ok := <-chEvt: + require.NotNil(t, event, "Expected an event but got nil") + require.Equal(t, strconv.Itoa(i), event.ID) + if !ok { + return + } + case err := <-chErr: + t.Fatalf("Received unexpected error from error channel: %v", err) + return + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") + } + } + + // Both channels should still be open and empty. select { - case err := <-chErr: - require.NoError(t, err) - case e := <-chEvt: - require.NotNil(t, e.Event) - require.Equal(t, e.ID, "1") - case <-time.After(time.Second): - require.Fail(t, "No events were sent") + case <-chEvt: + t.Fatalf("Events channel should be open") + case <-chErr: + t.Fatalf("Events channel should be open") + case <-time.After(100 * time.Millisecond): } + // Update the event watcher with the full page of events an collect. + mockEventWatcher.setEvents(testAuditEvents[:5]) + for ; i < 5; i++ { + select { + case event, ok := <-chEvt: + require.NotNil(t, event, "Expected an event but got nil") + require.Equal(t, strconv.Itoa(i), event.ID) + if !ok { + return + } + case err := <-chErr: + t.Fatalf("Received unexpected error from error channel: %v", err) + return + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") + } + } + + // Both channels should still be open and empty. + select { + case <-chEvt: + t.Fatalf("Events channel should be open") + case <-chErr: + t.Fatalf("Events channel should be open") + case <-time.After(100 * time.Millisecond): + } + + // Add another partial page and collect the events + mockEventWatcher.setEvents(testAuditEvents[:7]) + for ; i < 7; i++ { + select { + case event, ok := <-chEvt: + require.NotNil(t, event, "Expected an event but got nil") + require.Equal(t, strconv.Itoa(i), event.ID) + if !ok { + return + } + case err := <-chErr: + t.Fatalf("Received unexpected error from error channel: %v", err) + return + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") + } + } + + // Events goroutine should return update page errors + mockErr := trace.Errorf("error") + mockEventWatcher.setSearchEventsError(mockErr) + select { case err := <-chErr: - require.NoError(t, err) - case e := <-chEvt: - require.NotNil(t, e.Event) - require.Equal(t, "081ca05eea09ac0cd06e2d2acd06bec424146b254aa500de37bdc2c2b0a4dd0f", e.ID) - case <-time.After(time.Second): - require.Fail(t, "No events were sent") + require.Error(t, mockErr, err) + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") + } + + // Both channels should be closed + select { + case _, ok := <-chEvt: + require.False(t, ok, "Events channel should be closed") + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") + } + + select { + case _, ok := <-chErr: + require.False(t, ok, "Error channel should be closed") + case <-time.After(100 * time.Millisecond): + t.Fatalf("No events received within deadline") } }