diff --git a/internal/controlplane/handlers_datasource.go b/internal/controlplane/handlers_datasource.go index d7d65bc87c..cf1ce3aca2 100644 --- a/internal/controlplane/handlers_datasource.go +++ b/internal/controlplane/handlers_datasource.go @@ -13,6 +13,7 @@ import ( "github.com/mindersec/minder/internal/datasources/service" "github.com/mindersec/minder/internal/engine/engcontext" "github.com/mindersec/minder/internal/flags" + "github.com/mindersec/minder/internal/util" minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1" ) @@ -25,6 +26,14 @@ func (s *Server) CreateDataSource(ctx context.Context, return nil, status.Errorf(codes.Unavailable, "DataSources feature is disabled") } + entityCtx := engcontext.EntityFromContext(ctx) + err := entityCtx.ValidateProject(ctx, s.store) + if err != nil { + return nil, util.UserVisibleError(codes.InvalidArgument, "error in entity context: %v", err) + } + + projectID := entityCtx.Project.ID + // Get the data source from the request dsReq := in.GetDataSource() if dsReq == nil { @@ -36,7 +45,7 @@ func (s *Server) CreateDataSource(ctx context.Context, } // Process the request - ret, err := s.dataSourcesService.Create(ctx, uuid.Nil, dsReq, nil) + ret, err := s.dataSourcesService.Create(ctx, projectID, uuid.Nil, dsReq, nil) if err != nil { return nil, err } @@ -156,6 +165,14 @@ func (s *Server) UpdateDataSource(ctx context.Context, return nil, status.Errorf(codes.Unavailable, "DataSources feature is disabled") } + entityCtx := engcontext.EntityFromContext(ctx) + err := entityCtx.ValidateProject(ctx, s.store) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "error in entity context: %v", err) + } + + projectID := entityCtx.Project.ID + // Get the data source from the request dsReq := in.GetDataSource() if dsReq == nil { @@ -167,7 +184,7 @@ func (s *Server) UpdateDataSource(ctx context.Context, } // Process the request - ret, err := s.dataSourcesService.Update(ctx, uuid.Nil, dsReq, nil) + ret, err := s.dataSourcesService.Update(ctx, projectID, uuid.Nil, dsReq, nil) if err != nil { return nil, err } diff --git a/internal/controlplane/handlers_datasource_test.go b/internal/controlplane/handlers_datasource_test.go index 9fcc7cccd0..7ccb1dcb4d 100644 --- a/internal/controlplane/handlers_datasource_test.go +++ b/internal/controlplane/handlers_datasource_test.go @@ -37,7 +37,7 @@ func TestCreateDataSource(t *testing.T) { setupMocks: func(dsService *mock_service.MockDataSourcesService, featureClient *flags.FakeClient) { featureClient.Data = map[string]interface{}{"data_sources": true} dsService.EXPECT(). - Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Create(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(&minderv1.DataSource{Name: "test-ds"}, nil) }, request: &minderv1.CreateDataSourceRequest{ @@ -417,7 +417,7 @@ func TestUpdateDataSource(t *testing.T) { setupMocks: func(dsService *mock_service.MockDataSourcesService, featureClient *flags.FakeClient, _ *mockdb.MockStore) { featureClient.Data = map[string]interface{}{"data_sources": true} dsService.EXPECT(). - Update(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Update(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(&minderv1.DataSource{Id: dsIDStr, Name: "updated-ds"}, nil) }, request: &minderv1.UpdateDataSourceRequest{ diff --git a/internal/controlplane/handlers_user.go b/internal/controlplane/handlers_user.go index 7621f4fe23..88808cb469 100644 --- a/internal/controlplane/handlers_user.go +++ b/internal/controlplane/handlers_user.go @@ -114,7 +114,7 @@ func (s *Server) CreateUser(ctx context.Context, }, nil } -func (s *Server) claimGitHubInstalls(ctx context.Context, qtx db.Querier) []*db.Project { +func (s *Server) claimGitHubInstalls(ctx context.Context, qtx db.ExtendQuerier) []*db.Project { ghId, ok := jwt.GetUserClaimFromContext[string](ctx, "gh_id") if !ok || ghId == "" { return nil @@ -484,7 +484,7 @@ func (s *Server) acceptInvitation(ctx context.Context, userInvite db.GetInvitati return nil } -func ensureUser(ctx context.Context, s *Server, store db.Querier) (db.User, error) { +func ensureUser(ctx context.Context, s *Server, store db.ExtendQuerier) (db.User, error) { sub := jwt.GetUserSubjectFromContext(ctx) if sub == "" { return db.User{}, status.Error(codes.Internal, "failed to get user subject") diff --git a/internal/datasources/service/mock/fixtures/service.go b/internal/datasources/service/mock/fixtures/service.go index 8bb3f00174..532e99553a 100644 --- a/internal/datasources/service/mock/fixtures/service.go +++ b/internal/datasources/service/mock/fixtures/service.go @@ -75,3 +75,9 @@ func WithFailedGetByName() func(DataSourcesSvcMock) { Return(nil, errDefault) } } + +func WithSuccessfulUpsertDataSource(mock DataSourcesSvcMock) { + mock.EXPECT(). + Upsert(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil) +} diff --git a/internal/datasources/service/mock/service.go b/internal/datasources/service/mock/service.go index 9b8d142531..11b00100da 100644 --- a/internal/datasources/service/mock/service.go +++ b/internal/datasources/service/mock/service.go @@ -60,18 +60,18 @@ func (mr *MockDataSourcesServiceMockRecorder) BuildDataSourceRegistry(ctx, rt, o } // Create mocks base method. -func (m *MockDataSourcesService) Create(ctx context.Context, subscriptionID uuid.UUID, ds *v1.DataSource, opts *service.Options) (*v1.DataSource, error) { +func (m *MockDataSourcesService) Create(ctx context.Context, projectID, subscriptionID uuid.UUID, ds *v1.DataSource, opts *service.Options) (*v1.DataSource, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Create", ctx, subscriptionID, ds, opts) + ret := m.ctrl.Call(m, "Create", ctx, projectID, subscriptionID, ds, opts) ret0, _ := ret[0].(*v1.DataSource) ret1, _ := ret[1].(error) return ret0, ret1 } // Create indicates an expected call of Create. -func (mr *MockDataSourcesServiceMockRecorder) Create(ctx, subscriptionID, ds, opts any) *gomock.Call { +func (mr *MockDataSourcesServiceMockRecorder) Create(ctx, projectID, subscriptionID, ds, opts any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockDataSourcesService)(nil).Create), ctx, subscriptionID, ds, opts) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockDataSourcesService)(nil).Create), ctx, projectID, subscriptionID, ds, opts) } // Delete mocks base method. @@ -134,16 +134,30 @@ func (mr *MockDataSourcesServiceMockRecorder) List(ctx, project, opts any) *gomo } // Update mocks base method. -func (m *MockDataSourcesService) Update(ctx context.Context, subscriptionID uuid.UUID, ds *v1.DataSource, opts *service.Options) (*v1.DataSource, error) { +func (m *MockDataSourcesService) Update(ctx context.Context, projectID, subscriptionID uuid.UUID, ds *v1.DataSource, opts *service.Options) (*v1.DataSource, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Update", ctx, subscriptionID, ds, opts) + ret := m.ctrl.Call(m, "Update", ctx, projectID, subscriptionID, ds, opts) ret0, _ := ret[0].(*v1.DataSource) ret1, _ := ret[1].(error) return ret0, ret1 } // Update indicates an expected call of Update. -func (mr *MockDataSourcesServiceMockRecorder) Update(ctx, subscriptionID, ds, opts any) *gomock.Call { +func (mr *MockDataSourcesServiceMockRecorder) Update(ctx, projectID, subscriptionID, ds, opts any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDataSourcesService)(nil).Update), ctx, subscriptionID, ds, opts) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDataSourcesService)(nil).Update), ctx, projectID, subscriptionID, ds, opts) +} + +// Upsert mocks base method. +func (m *MockDataSourcesService) Upsert(ctx context.Context, projectID, subscriptionID uuid.UUID, ds *v1.DataSource, opts *service.Options) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Upsert", ctx, projectID, subscriptionID, ds, opts) + ret0, _ := ret[0].(error) + return ret0 +} + +// Upsert indicates an expected call of Upsert. +func (mr *MockDataSourcesServiceMockRecorder) Upsert(ctx, projectID, subscriptionID, ds, opts any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Upsert", reflect.TypeOf((*MockDataSourcesService)(nil).Upsert), ctx, projectID, subscriptionID, ds, opts) } diff --git a/internal/datasources/service/service.go b/internal/datasources/service/service.go index 5ec472a5fe..c2116b19cf 100644 --- a/internal/datasources/service/service.go +++ b/internal/datasources/service/service.go @@ -25,6 +25,11 @@ import ( //go:generate go run go.uber.org/mock/mockgen -package mock_$GOPACKAGE -destination=./mock/$GOFILE -source=./$GOFILE +var ( + // ErrDataSourceAlreadyExists is returned when a data source already exists + ErrDataSourceAlreadyExists = util.UserVisibleError(codes.AlreadyExists, "data source already exists") +) + // DataSourcesService is an interface that defines the methods for the data sources service. type DataSourcesService interface { // GetByName returns a data source by name. @@ -37,10 +42,26 @@ type DataSourcesService interface { List(ctx context.Context, project uuid.UUID, opts *ReadOptions) ([]*minderv1.DataSource, error) // Create creates a new data source. - Create(ctx context.Context, subscriptionID uuid.UUID, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) + Create( + ctx context.Context, + projectID uuid.UUID, + subscriptionID uuid.UUID, + ds *minderv1.DataSource, + opts *Options, + ) (*minderv1.DataSource, error) // Update updates an existing data source. - Update(ctx context.Context, subscriptionID uuid.UUID, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) + Update( + ctx context.Context, + projectID uuid.UUID, + subscriptionID uuid.UUID, + ds *minderv1.DataSource, + opts *Options, + ) (*minderv1.DataSource, error) + + // Upsert creates a new data source if it does not exist or updates it if it already exists. + // This is used in the subscription logic. + Upsert(ctx context.Context, projectID uuid.UUID, subscriptionID uuid.UUID, ds *minderv1.DataSource, opts *Options) error // Delete deletes a data source in the given project. // @@ -150,7 +171,12 @@ func (d *dataSourceService) List( // We first validate the data source name uniqueness, then create the data source record. // Finally, we create function records based on the driver type. func (d *dataSourceService) Create( - ctx context.Context, subscriptionID uuid.UUID, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) { + ctx context.Context, + projectID uuid.UUID, + subscriptionID uuid.UUID, + ds *minderv1.DataSource, + opts *Options, +) (*minderv1.DataSource, error) { if err := ds.Validate(); err != nil { return nil, fmt.Errorf("data source validation failed: %w", err) } @@ -173,11 +199,6 @@ func (d *dataSourceService) Create( tx := stx.Q() - projectID, err := uuid.Parse(ds.GetContext().GetProjectId()) - if err != nil { - return nil, fmt.Errorf("invalid project ID: %w", err) - } - // Check if such data source already exists in project hierarchy projs, err := listRelevantProjects(ctx, tx, projectID, true) if err != nil { @@ -191,8 +212,7 @@ func (d *dataSourceService) Create( return nil, fmt.Errorf("failed to check for existing data source: %w", err) } if existing.ID != uuid.Nil { - return nil, util.UserVisibleError(codes.AlreadyExists, - "data source with name %s already exists", ds.GetName()) + return nil, ErrDataSourceAlreadyExists } // Create data source record @@ -226,7 +246,12 @@ func (d *dataSourceService) Create( // because it's simpler and safer - it ensures consistency and avoids partial updates. // All functions must use the same driver type to maintain data source integrity. func (d *dataSourceService) Update( - ctx context.Context, subscriptionID uuid.UUID, ds *minderv1.DataSource, opts *Options) (*minderv1.DataSource, error) { + ctx context.Context, + projectID uuid.UUID, + subscriptionID uuid.UUID, + ds *minderv1.DataSource, + opts *Options, +) (*minderv1.DataSource, error) { if err := ds.Validate(); err != nil { return nil, fmt.Errorf("data source validation failed: %w", err) } @@ -245,11 +270,6 @@ func (d *dataSourceService) Update( tx := stx.Q() - projectID, err := uuid.Parse(ds.GetContext().GetProjectId()) - if err != nil { - return nil, fmt.Errorf("invalid project ID: %w", err) - } - // Validate the subscription ID if present existingDS, err := getDataSourceFromDb(ctx, projectID, ReadBuilder().WithTransaction(tx), tx, func(ctx context.Context, tx db.ExtendQuerier, projs []uuid.UUID) (db.DataSource, error) { @@ -301,6 +321,33 @@ func (d *dataSourceService) Update( return ds, nil } +// Upsert creates the data source if it does not already exist +// or updates it if it already exists. This is used in the subscription +// logic. +func (d *dataSourceService) Upsert( + ctx context.Context, + projectID uuid.UUID, + subscriptionID uuid.UUID, + ds *minderv1.DataSource, + opts *Options, +) error { + // Simulate upsert semantics by trying to create, then trying to update. + _, err := d.Create(ctx, projectID, subscriptionID, ds, opts) + if err == nil { + // Rule successfully created, we can stop here. + return nil + } else if !errors.Is(err, ErrDataSourceAlreadyExists) { + return fmt.Errorf("error while creating data source: %w", err) + } + + // If we get here: data source already exists. Let's update it. + _, err = d.Update(ctx, projectID, subscriptionID, ds, opts) + if err != nil { + return fmt.Errorf("error while updating data source: %w", err) + } + return nil +} + // Delete deletes a data source in the given project. func (d *dataSourceService) Delete( ctx context.Context, id uuid.UUID, project uuid.UUID, opts *Options) error { diff --git a/internal/datasources/service/service_test.go b/internal/datasources/service/service_test.go index 297a0b0930..e7d1fdf88e 100644 --- a/internal/datasources/service/service_test.go +++ b/internal/datasources/service/service_test.go @@ -25,6 +25,7 @@ import ( ) var ( + projectID = uuid.New() subscriptionID = uuid.New() validRESTDriverFixture = &minderv1.DataSource_Rest{ Rest: &minderv1.RestDataSource{ @@ -554,7 +555,7 @@ func TestCreate(t *testing.T) { } tt.setup(mockStore) - got, err := svc.Create(context.Background(), tt.args.subscriptionId, tt.args.ds, tt.args.opts) + got, err := svc.Create(context.Background(), projectID, tt.args.subscriptionId, tt.args.ds, tt.args.opts) if tt.wantErr { assert.Error(t, err) return @@ -1535,7 +1536,7 @@ func TestUpdate(t *testing.T) { tt.setup(mockStore) - got, err := svc.Update(context.Background(), tt.args.subscriptionId, tt.args.ds, tt.args.opts) + got, err := svc.Update(context.Background(), projectID, tt.args.subscriptionId, tt.args.ds, tt.args.opts) if tt.wantErr { assert.Error(t, err) return @@ -1546,3 +1547,174 @@ func TestUpdate(t *testing.T) { }) } } + +func TestUpsert(t *testing.T) { + t.Parallel() + + type args struct { + subscriptionID uuid.UUID + ds *minderv1.DataSource + opts *Options + } + tests := []struct { + name string + args args + setup func(mockDB *mockdb.MockStore) + wantErr bool + }{ + { + name: "Successfully create data source", + args: args{ + subscriptionID: subscriptionID, + ds: &minderv1.DataSource{ + Name: "namespace/test_ds", + Context: &minderv1.ContextV2{ + ProjectId: uuid.New().String(), + }, + Driver: validRESTDriverFixture, + }, + opts: &Options{}, + }, + setup: func(mockDB *mockdb.MockStore) { + mockDB.EXPECT().GetParentProjects(gomock.Any(), gomock.Any()). + Return([]uuid.UUID{uuid.New()}, nil) + + mockDB.EXPECT().GetDataSourceByName(gomock.Any(), gomock.Any()). + Return(db.DataSource{}, sql.ErrNoRows) + + mockDB.EXPECT().CreateDataSource(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: uuid.New(), + Name: "namespace/test_ds", + }, nil) + + mockDB.EXPECT().AddDataSourceFunction(gomock.Any(), gomock.Any()). + Return(db.DataSourcesFunction{}, nil) + }, + wantErr: false, + }, + { + name: "Successfully update existing data source", + args: args{ + subscriptionID: subscriptionID, + ds: &minderv1.DataSource{ + Name: "namespace/test_ds", + Context: &minderv1.ContextV2{ + ProjectId: uuid.New().String(), + }, + Driver: validRESTDriverFixture, + }, + opts: &Options{}, + }, + setup: func(mockDB *mockdb.MockStore) { + dsID := uuid.New() + mockDB.EXPECT().GetParentProjects(gomock.Any(), gomock.Any()). + Return([]uuid.UUID{uuid.New()}, nil) + + // The data source already exists + mockDB.EXPECT().GetDataSourceByName(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: dsID, + Name: "namespace/test_ds", + SubscriptionID: uuid.NullUUID{Valid: true, UUID: subscriptionID}, + }, nil).AnyTimes() + + mockDB.EXPECT().ListDataSourceFunctions(gomock.Any(), gomock.Any()). + Return([]db.DataSourcesFunction{ + { + ID: uuid.New(), + DataSourceID: dsID, + Name: "test_function", + Type: v1.DataSourceDriverRest, + Definition: restDriverToJson(t, &minderv1.RestDataSource_Def{ + Endpoint: "http://example.com/updated", + InputSchema: func() *structpb.Struct { + s, _ := structpb.NewStruct(map[string]any{}) + return s + }(), + }), + }, + }, nil) + + mockDB.EXPECT().UpdateDataSource(gomock.Any(), gomock.Any()). + Return(db.DataSource{ + ID: uuid.New(), + Name: "test_ds", + }, nil) + + mockDB.EXPECT().DeleteDataSourceFunctions(gomock.Any(), gomock.Any()). + Return(nil, nil) + + mockDB.EXPECT().AddDataSourceFunction(gomock.Any(), gomock.Any()). + Return(db.DataSourcesFunction{}, nil) + }, + wantErr: false, + }, + { + name: "Invalid namespace name", + args: args{ + ds: &minderv1.DataSource{ + Name: "name-with-no-namespace", + Context: &minderv1.ContextV2{ + ProjectId: uuid.New().String(), + }, + Driver: validRESTDriverFixture, + }, + subscriptionID: subscriptionID, + opts: &Options{}, + }, + setup: func(_ *mockdb.MockStore) {}, + wantErr: true, + }, + { + name: "Nil data source", + args: args{ + ds: nil, + opts: &Options{}, + }, + setup: func(_ *mockdb.MockStore) {}, + wantErr: true, + }, + { + name: "Invalid project ID", + args: args{ + ds: &minderv1.DataSource{ + Context: &minderv1.ContextV2{ + ProjectId: "invalid-uuid", + }, + }, + opts: &Options{}, + }, + setup: func(_ *mockdb.MockStore) {}, + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockStore := mockdb.NewMockStore(ctrl) + + svc := NewDataSourceService(mockStore) + svc.txBuilder = func(_ *dataSourceService, _ txGetter) (serviceTX, error) { + return &fakeTxBuilder{ + store: mockStore, + }, nil + } + tt.setup(mockStore) + + err := svc.Upsert(context.Background(), projectID, tt.args.subscriptionID, tt.args.ds, tt.args.opts) + if tt.wantErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + }) + } +} diff --git a/internal/marketplaces/bundles/mock/fixtures/reader.go b/internal/marketplaces/bundles/mock/fixtures/reader.go index 88f37f5cfc..8588baa742 100644 --- a/internal/marketplaces/bundles/mock/fixtures/reader.go +++ b/internal/marketplaces/bundles/mock/fixtures/reader.go @@ -76,3 +76,19 @@ func WithFailedForEachRuleType(mock BundleMock) { ForEachRuleType(gomock.Any()). Return(errDefault) } + +func WithSuccessfulForEachDataSource(mock BundleMock) { + type argType = func(source *v1.DataSource) error + var argument argType + mock.EXPECT(). + ForEachDataSource(gomock.AssignableToTypeOf(argument)). + DoAndReturn(func(fn argType) error { + return fn(&v1.DataSource{}) + }) +} + +func WithFailedForEachDataSource(mock BundleMock) { + mock.EXPECT(). + ForEachDataSource(gomock.Any()). + Return(errDefault) +} diff --git a/internal/marketplaces/bundles/mock/reader.go b/internal/marketplaces/bundles/mock/reader.go index 59ccef8e33..e7dec2b606 100644 --- a/internal/marketplaces/bundles/mock/reader.go +++ b/internal/marketplaces/bundles/mock/reader.go @@ -41,6 +41,20 @@ func (m *MockBundleReader) EXPECT() *MockBundleReaderMockRecorder { return m.recorder } +// ForEachDataSource mocks base method. +func (m *MockBundleReader) ForEachDataSource(arg0 func(*v1.DataSource) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ForEachDataSource", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// ForEachDataSource indicates an expected call of ForEachDataSource. +func (mr *MockBundleReaderMockRecorder) ForEachDataSource(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForEachDataSource", reflect.TypeOf((*MockBundleReader)(nil).ForEachDataSource), arg0) +} + // ForEachRuleType mocks base method. func (m *MockBundleReader) ForEachRuleType(arg0 func(*v1.RuleType) error) error { m.ctrl.T.Helper() diff --git a/internal/marketplaces/factory.go b/internal/marketplaces/factory.go index 1233a2cac6..bde3615afd 100644 --- a/internal/marketplaces/factory.go +++ b/internal/marketplaces/factory.go @@ -8,6 +8,7 @@ import ( "fmt" "path/filepath" + datasourceservice "github.com/mindersec/minder/internal/datasources/service" sub "github.com/mindersec/minder/internal/marketplaces/subscriptions" "github.com/mindersec/minder/pkg/config/server" "github.com/mindersec/minder/pkg/mindpak" @@ -25,6 +26,7 @@ func NewMarketplaceFromServiceConfig( config server.MarketplaceConfig, profile profiles.ProfileService, ruleType ruletypes.RuleTypeService, + dataSource datasourceservice.DataSourcesService, ) (Marketplace, error) { if !config.Enabled { return NewNoopMarketplace(), nil @@ -54,7 +56,7 @@ func NewMarketplaceFromServiceConfig( newSources[i] = source } - subscription := sub.NewSubscriptionService(profile, ruleType) + subscription := sub.NewSubscriptionService(profile, ruleType, dataSource) marketplace, err := NewMarketplace(newSources, subscription) if err != nil { return nil, fmt.Errorf("error while creating marketplace: %w", err) diff --git a/internal/marketplaces/service.go b/internal/marketplaces/service.go index af0e1eddf9..1ae07df0cf 100644 --- a/internal/marketplaces/service.go +++ b/internal/marketplaces/service.go @@ -28,7 +28,7 @@ type Marketplace interface { ctx context.Context, projectID uuid.UUID, bundleID mindpak.BundleID, - qtx db.Querier, + qtx db.ExtendQuerier, ) error // AddProfile adds the specified profile from the bundle to the project. AddProfile( @@ -53,7 +53,7 @@ func (s *marketplace) Subscribe( ctx context.Context, projectID uuid.UUID, bundleID mindpak.BundleID, - qtx db.Querier, + qtx db.ExtendQuerier, ) error { bundle, err := s.getBundle(bundleID) if err != nil { @@ -100,7 +100,7 @@ func (s *marketplace) getBundle(bundleID mindpak.BundleID) (reader.BundleReader, // This is used when the Marketplace functionality is disabled type noopMarketplace struct{} -func (_ *noopMarketplace) Subscribe(_ context.Context, _ uuid.UUID, _ mindpak.BundleID, _ db.Querier) error { +func (_ *noopMarketplace) Subscribe(_ context.Context, _ uuid.UUID, _ mindpak.BundleID, _ db.ExtendQuerier) error { return nil } diff --git a/internal/marketplaces/subscriptions/mock/service.go b/internal/marketplaces/subscriptions/mock/service.go index b95b69a84a..1576b24d54 100644 --- a/internal/marketplaces/subscriptions/mock/service.go +++ b/internal/marketplaces/subscriptions/mock/service.go @@ -58,7 +58,7 @@ func (mr *MockSubscriptionServiceMockRecorder) CreateProfile(ctx, projectID, bun } // Subscribe mocks base method. -func (m *MockSubscriptionService) Subscribe(ctx context.Context, projectID uuid.UUID, bundle reader.BundleReader, qtx db.Querier) error { +func (m *MockSubscriptionService) Subscribe(ctx context.Context, projectID uuid.UUID, bundle reader.BundleReader, qtx db.ExtendQuerier) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Subscribe", ctx, projectID, bundle, qtx) ret0, _ := ret[0].(error) diff --git a/internal/marketplaces/subscriptions/service.go b/internal/marketplaces/subscriptions/service.go index 2d8a156d05..37f282ed58 100644 --- a/internal/marketplaces/subscriptions/service.go +++ b/internal/marketplaces/subscriptions/service.go @@ -14,6 +14,7 @@ import ( "github.com/google/uuid" + datasourceservice "github.com/mindersec/minder/internal/datasources/service" "github.com/mindersec/minder/internal/db" minderv1 "github.com/mindersec/minder/pkg/api/protobuf/go/minder/v1" "github.com/mindersec/minder/pkg/mindpak" @@ -35,7 +36,7 @@ type SubscriptionService interface { ctx context.Context, projectID uuid.UUID, bundle reader.BundleReader, - qtx db.Querier, + qtx db.ExtendQuerier, ) error // CreateProfile creates the specified profile from the bundle in the project. CreateProfile( @@ -48,18 +49,21 @@ type SubscriptionService interface { } type subscriptionService struct { - profiles profsvc.ProfileService - rules ruletypes.RuleTypeService + profiles profsvc.ProfileService + rules ruletypes.RuleTypeService + dataSources datasourceservice.DataSourcesService } // NewSubscriptionService creates an instance of the SubscriptionService interface func NewSubscriptionService( profiles profsvc.ProfileService, rules ruletypes.RuleTypeService, + dataSources datasourceservice.DataSourcesService, ) SubscriptionService { return &subscriptionService{ - profiles: profiles, - rules: rules, + profiles: profiles, + rules: rules, + dataSources: dataSources, } } @@ -67,7 +71,7 @@ func (s *subscriptionService) Subscribe( ctx context.Context, projectID uuid.UUID, bundle reader.BundleReader, - qtx db.Querier, + qtx db.ExtendQuerier, ) error { metadata := bundle.GetMetadata() _, err := qtx.GetSubscriptionByProjectBundle(ctx, db.GetSubscriptionByProjectBundleParams{ @@ -100,11 +104,19 @@ func (s *subscriptionService) Subscribe( return fmt.Errorf("error while creating subscription: %w", err) } + // populate all data sources from this bundle into the project + // this should happen before populating the rules, as rules may depend on data sources + err = s.upsertBundleDataSources(ctx, qtx, projectID, bundle, subscription.ID) + if err != nil { + return fmt.Errorf("error while creating data sources in project: %w", err) + } + // populate all rule types from this bundle into the project err = s.upsertBundleRules(ctx, qtx, projectID, bundle, subscription.ID) if err != nil { return fmt.Errorf("error while creating rules in project: %w", err) } + return nil } @@ -194,3 +206,15 @@ func (s *subscriptionService) upsertBundleRules( return s.rules.UpsertRuleType(ctx, projectID, subscriptionID, ruleType, qtx) }) } + +func (s *subscriptionService) upsertBundleDataSources( + ctx context.Context, + qtx db.ExtendQuerier, + projectID uuid.UUID, + bundle reader.BundleReader, + subscriptionID uuid.UUID, +) error { + return bundle.ForEachDataSource(func(dataSource *minderv1.DataSource) error { + return s.dataSources.Upsert(ctx, projectID, subscriptionID, dataSource, datasourceservice.OptionsBuilder().WithTransaction(qtx)) + }) +} diff --git a/internal/marketplaces/subscriptions/service_test.go b/internal/marketplaces/subscriptions/service_test.go index 8703776f08..bd203e75df 100644 --- a/internal/marketplaces/subscriptions/service_test.go +++ b/internal/marketplaces/subscriptions/service_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" + datasourceservice "github.com/mindersec/minder/internal/datasources/service" + dsf "github.com/mindersec/minder/internal/datasources/service/mock/fixtures" "github.com/mindersec/minder/internal/db" dbf "github.com/mindersec/minder/internal/db/fixtures" brf "github.com/mindersec/minder/internal/marketplaces/bundles/mock/fixtures" @@ -27,11 +29,12 @@ import ( func TestSubscriptionService_Subscribe(t *testing.T) { t.Parallel() scenarios := []struct { - Name string - DBSetup dbf.DBMockBuilder - BundleSetup brf.BundleMockBuilder - RuleTypeSetup rsf.RuleTypeSvcMockBuilder - ExpectedError string + Name string + DBSetup dbf.DBMockBuilder + BundleSetup brf.BundleMockBuilder + RuleTypeSetup rsf.RuleTypeSvcMockBuilder + DataSourceSetup dsf.DataSourcesSvcMockBuilder + ExpectedError string }{ { Name: "Subscribe is a no-op when the subscription already exists", @@ -57,23 +60,32 @@ func TestSubscriptionService_Subscribe(t *testing.T) { ExpectedError: "error while creating subscription", }, { - Name: "Subscribe returns error if rules cannot be read from bundle", - DBSetup: dbf.NewDBMock(withNotFoundFindSubscription, withBundleUpsert, withSuccessfulCreateSubscription), - BundleSetup: brf.NewBundleReaderMock(brf.WithMetadata, brf.WithFailedForEachRuleType), - ExpectedError: "error while creating rules in project", + Name: "Subscribe returns error if rules cannot be read from bundle", + DBSetup: dbf.NewDBMock(withNotFoundFindSubscription, withBundleUpsert, withSuccessfulCreateSubscription), + BundleSetup: brf.NewBundleReaderMock(brf.WithMetadata, brf.WithFailedForEachRuleType, brf.WithSuccessfulForEachDataSource), + DataSourceSetup: dsf.NewDataSourcesServiceMock(dsf.WithSuccessfulUpsertDataSource), + ExpectedError: "error while creating rules in project", }, { - Name: "Subscribe returns error if rules cannot be upserted into database", + Name: "Subscribe returns error if rules cannot be upserted into database", + DBSetup: dbf.NewDBMock(withNotFoundFindSubscription, withBundleUpsert, withSuccessfulCreateSubscription), + BundleSetup: brf.NewBundleReaderMock(brf.WithMetadata, brf.WithSuccessfulForEachRuleType, brf.WithSuccessfulForEachDataSource), + DataSourceSetup: dsf.NewDataSourcesServiceMock(dsf.WithSuccessfulUpsertDataSource), + RuleTypeSetup: rsf.NewRuleTypeServiceMock(rsf.WithFailedUpsertRuleType), + ExpectedError: "error while creating rules in project", + }, + { + Name: "Subscribe returns error if data sources cannot be read from bundle", DBSetup: dbf.NewDBMock(withNotFoundFindSubscription, withBundleUpsert, withSuccessfulCreateSubscription), - BundleSetup: brf.NewBundleReaderMock(brf.WithMetadata, brf.WithSuccessfulForEachRuleType), - RuleTypeSetup: rsf.NewRuleTypeServiceMock(rsf.WithFailedUpsertRuleType), - ExpectedError: "error while creating rules in project", + BundleSetup: brf.NewBundleReaderMock(brf.WithMetadata, brf.WithFailedForEachDataSource), + ExpectedError: "error while creating data sources in project", }, { - Name: "Subscribe creates subscription", - DBSetup: dbf.NewDBMock(withNotFoundFindSubscription, withSuccessfulCreateSubscription, withBundleUpsert), - BundleSetup: brf.NewBundleReaderMock(brf.WithMetadata, brf.WithSuccessfulForEachRuleType), - RuleTypeSetup: rsf.NewRuleTypeServiceMock(rsf.WithSuccessfulUpsertRuleType), + Name: "Subscribe creates subscription", + DBSetup: dbf.NewDBMock(withNotFoundFindSubscription, withSuccessfulCreateSubscription, withBundleUpsert), + BundleSetup: brf.NewBundleReaderMock(brf.WithMetadata, brf.WithSuccessfulForEachRuleType, brf.WithSuccessfulForEachDataSource), + RuleTypeSetup: rsf.NewRuleTypeServiceMock(rsf.WithSuccessfulUpsertRuleType), + DataSourceSetup: dsf.NewDataSourcesServiceMock(dsf.WithSuccessfulUpsertDataSource), }, } @@ -92,7 +104,7 @@ func TestSubscriptionService_Subscribe(t *testing.T) { querier := getQuerier(ctrl, scenario.DBSetup) - svc := createService(ctrl, nil, scenario.RuleTypeSetup) + svc := createService(ctrl, nil, scenario.RuleTypeSetup, scenario.DataSourceSetup) err := svc.Subscribe(ctx, projectID, bundle, querier) if scenario.ExpectedError == "" { require.NoError(t, err) @@ -156,7 +168,7 @@ func TestSubscriptionService_CreateProfile(t *testing.T) { bundle := scenario.BundleSetup(ctrl) querier := getQuerier(ctrl, scenario.DBSetup) - svc := createService(ctrl, scenario.ProfileSetup, nil) + svc := createService(ctrl, scenario.ProfileSetup, nil, nil) err := svc.CreateProfile(ctx, projectID, bundle, profileName, querier) if scenario.ExpectedError == "" { require.NoError(t, err) @@ -228,6 +240,7 @@ func createService( ctrl *gomock.Controller, profileSetup psf.ProfileSvcMockBuilder, ruleTypeSetup rsf.RuleTypeSvcMockBuilder, + dataSourceSetup dsf.DataSourcesSvcMockBuilder, ) subscriptions.SubscriptionService { var rules ruletypes.RuleTypeService if ruleTypeSetup != nil { @@ -239,7 +252,12 @@ func createService( profSvc = profileSetup(ctrl) } - return subscriptions.NewSubscriptionService(profSvc, rules) + var dataSources datasourceservice.DataSourcesService + if dataSourceSetup != nil { + dataSources = dataSourceSetup(ctrl) + } + + return subscriptions.NewSubscriptionService(profSvc, rules, dataSources) } func getQuerier(ctrl *gomock.Controller, dbSetup dbf.DBMockBuilder) db.ExtendQuerier { diff --git a/internal/projects/creator.go b/internal/projects/creator.go index 13b4f422a8..4558b2493d 100644 --- a/internal/projects/creator.go +++ b/internal/projects/creator.go @@ -29,7 +29,7 @@ type ProjectCreator interface { // (project, marketplace subscriptions, etc.) but *does not* create a project. ProvisionSelfEnrolledProject( ctx context.Context, - qtx db.Querier, + qtx db.ExtendQuerier, projectName string, userSub string, ) (outproj *db.Project, projerr error) @@ -63,7 +63,7 @@ var ( func (p *projectCreator) ProvisionSelfEnrolledProject( ctx context.Context, - qtx db.Querier, + qtx db.ExtendQuerier, projectName string, userSub string, ) (outproj *db.Project, projerr error) { diff --git a/internal/providers/github/service/mock/service.go b/internal/providers/github/service/mock/service.go index f32d241ad2..cc74916dc4 100644 --- a/internal/providers/github/service/mock/service.go +++ b/internal/providers/github/service/mock/service.go @@ -60,7 +60,7 @@ func (mr *MockGitHubProviderServiceMockRecorder) CreateGitHubAppProvider(ctx, to } // CreateGitHubAppWithoutInvitation mocks base method. -func (m *MockGitHubProviderService) CreateGitHubAppWithoutInvitation(ctx context.Context, qtx db.Querier, userID, installationID int64) (*db.Project, error) { +func (m *MockGitHubProviderService) CreateGitHubAppWithoutInvitation(ctx context.Context, qtx db.ExtendQuerier, userID, installationID int64) (*db.Project, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateGitHubAppWithoutInvitation", ctx, qtx, userID, installationID) ret0, _ := ret[0].(*db.Project) diff --git a/internal/providers/github/service/service.go b/internal/providers/github/service/service.go index 63e239f112..68b1c3ba99 100644 --- a/internal/providers/github/service/service.go +++ b/internal/providers/github/service/service.go @@ -41,7 +41,7 @@ type GitHubProviderService interface { // the installation in preparation for creating a new project when the authorizing user logs in. // // Note that this function may return nil, nil if the installation user is not known to Minder. - CreateGitHubAppWithoutInvitation(ctx context.Context, qtx db.Querier, userID int64, + CreateGitHubAppWithoutInvitation(ctx context.Context, qtx db.ExtendQuerier, userID int64, installationID int64) (*db.Project, error) // ValidateGitHubInstallationId checks if the supplied GitHub token has access to the installation ID ValidateGitHubInstallationId(ctx context.Context, token *oauth2.Token, installationID int64) error @@ -67,7 +67,7 @@ var ErrInvalidTokenIdentity = errors.New("invalid token identity") // present in the system. If a db.Project is returned, it should be used as the // location to create a Provider corresponding to the GitHub App installation. type ProjectFactory func( - ctx context.Context, qtx db.Querier, name string, user int64) (*db.Project, error) + ctx context.Context, qtx db.ExtendQuerier, name string, user int64) (*db.Project, error) type ghProviderService struct { store db.Store @@ -167,7 +167,7 @@ func (p *ghProviderService) CreateGitHubAppProvider( // Note that this function may return nil, nil if the installation user is not known to Minder. func (p *ghProviderService) CreateGitHubAppWithoutInvitation( ctx context.Context, - qtx db.Querier, + qtx db.ExtendQuerier, userID int64, installationID int64, ) (*db.Project, error) { diff --git a/internal/providers/github/service/service_test.go b/internal/providers/github/service/service_test.go index 1aa8803f66..2687d02077 100644 --- a/internal/providers/github/service/service_test.go +++ b/internal/providers/github/service/service_test.go @@ -308,7 +308,7 @@ func TestProviderService_CreateGitHubAppWithNewProject(t *testing.T) { PrivateKey: pvtKeyFile.Name(), }, } - factory := func(_ context.Context, qtx db.Querier, name string, _ int64) (*db.Project, error) { + factory := func(_ context.Context, qtx db.ExtendQuerier, name string, _ int64) (*db.Project, error) { project, err := qtx.CreateProject(context.Background(), db.CreateProjectParams{ Name: name, Metadata: []byte(`{}`), @@ -374,7 +374,7 @@ func TestProviderService_CreateUnclaimedGitHubAppInstallation(t *testing.T) { }, } - factory := func(context.Context, db.Querier, string, int64) (*db.Project, error) { + factory := func(context.Context, db.ExtendQuerier, string, int64) (*db.Project, error) { return nil, errors.New("error getting user for GitHub ID: 404 not found") } diff --git a/internal/service/service.go b/internal/service/service.go index 06fce58ac4..3c8d6a695a 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -95,7 +95,8 @@ func AllInOneServerService( profileSvc := profiles.NewProfileService(evt, selChecker) ruleSvc := ruletypes.NewRuleTypeService() roleScv := roles.NewRoleService() - marketplace, err := marketplaces.NewMarketplaceFromServiceConfig(cfg.Marketplace, profileSvc, ruleSvc) + dataSourcesSvc := datasourcessvc.NewDataSourceService(store) + marketplace, err := marketplaces.NewMarketplaceFromServiceConfig(cfg.Marketplace, profileSvc, ruleSvc, dataSourcesSvc) if err != nil { return fmt.Errorf("failed to create marketplace: %w", err) } @@ -171,7 +172,6 @@ func AllInOneServerService( repos := repositories.NewRepositoryService(store, propSvc, evt, providerManager) projectDeleter := projects.NewProjectDeleter(authzClient, providerManager) sessionsService := session.NewProviderSessionService(providerManager, providerStore, store) - dataSourcesSvc := datasourcessvc.NewDataSourceService(store) s := controlplane.NewServer( store, @@ -332,7 +332,7 @@ func makeProjectFactory( ) service.ProjectFactory { return func( ctx context.Context, - qtx db.Querier, + qtx db.ExtendQuerier, name string, ghUser int64, ) (*db.Project, error) { diff --git a/pkg/mindpak/bundle.go b/pkg/mindpak/bundle.go index 2ea8046ff6..590c163be0 100644 --- a/pkg/mindpak/bundle.go +++ b/pkg/mindpak/bundle.go @@ -109,14 +109,16 @@ func (b *Bundle) ReadSource() error { b.Manifest = &Manifest{ Metadata: &Metadata{}, Files: &Files{ - Profiles: []*File{}, - RuleTypes: []*File{}, + Profiles: []*File{}, + RuleTypes: []*File{}, + DataSources: []*File{}, }, } b.Files = &Files{ - Profiles: []*File{}, - RuleTypes: []*File{}, + Profiles: []*File{}, + RuleTypes: []*File{}, + DataSources: []*File{}, } err := fs.WalkDir(b.Source, ".", func(path string, d fs.DirEntry, err error) error { @@ -127,9 +129,7 @@ func (b *Bundle) ReadSource() error { return nil } - if !strings.HasPrefix(path, PathProfiles+"/") && - !strings.HasPrefix(path, PathRuleTypes+"/") && - !strings.HasPrefix(path, ManifestFileName) { + if !pathInKnownDirectory(path) { return fmt.Errorf("found unexpected entry in mindpak source: %q", path) } @@ -165,17 +165,26 @@ func (b *Bundle) ReadSource() error { b.Files.Profiles = append(b.Files.Profiles, &fentry) case strings.HasPrefix(path, PathRuleTypes): b.Files.RuleTypes = append(b.Files.RuleTypes, &fentry) + case strings.HasPrefix(path, PathDataSources): + b.Files.DataSources = append(b.Files.DataSources, &fentry) } return nil }) if err != nil { - return fmt.Errorf("traversing bundle data source: %w", err) + return fmt.Errorf("traversing bundle: %w", err) } return nil } +func pathInKnownDirectory(path string) bool { + return strings.HasPrefix(path, PathProfiles+"/") || + strings.HasPrefix(path, PathRuleTypes+"/") || + strings.HasPrefix(path, PathDataSources+"/") || + strings.HasPrefix(path, ManifestFileName) +} + // Verify checks the contents of the bundle against its manifest func (_ *Bundle) Verify() error { // FIXME(puerco): Implement @@ -185,11 +194,16 @@ func (_ *Bundle) Verify() error { func copyTarIntoMemory(tarReader *tar.Reader) (fs.StatFS, error) { // create the memfs instance, and create the directories we need sourceFS := afero.NewIOFS(afero.NewMemMapFs()) - if err := sourceFS.MkdirAll("/"+PathProfiles, 0700); err != nil { - return nil, fmt.Errorf("error creating directory in memfs: %w", err) + initialDirs := []string{ + "/" + PathProfiles, + "/" + PathRuleTypes, + "/" + PathDataSources, } - if err := sourceFS.MkdirAll("/"+PathRuleTypes, 0700); err != nil { - return nil, fmt.Errorf("error creating directory in memfs: %w", err) + + for _, dir := range initialDirs { + if err := sourceFS.MkdirAll(dir, 0700); err != nil { + return nil, fmt.Errorf("error creating directory in memfs: %w", err) + } } var memFile afero.File @@ -230,5 +244,34 @@ func copyTarIntoMemory(tarReader *tar.Reader) (fs.StatFS, error) { } } + // Remove the in-memory directories if they are empty + for _, dir := range initialDirs { + err := removeDirIfEmpty(sourceFS, dir) + if err != nil { + return nil, err + } + } + return sourceFS, nil } + +func removeDirIfEmpty(sourceFS afero.IOFS, dir string) error { + isEmpty, err := isDirEmpty(sourceFS, dir) + if err != nil { + return fmt.Errorf("error checking if directory is empty: %w", err) + } + if isEmpty { + if err := sourceFS.Remove(dir); err != nil { + return fmt.Errorf("error removing empty directory: %w", err) + } + } + return nil +} + +func isDirEmpty(sourceFS afero.IOFS, path string) (bool, error) { + entries, err := sourceFS.ReadDir(path) + if err != nil { + return false, err + } + return len(entries) == 0, nil +} diff --git a/pkg/mindpak/bundle_test.go b/pkg/mindpak/bundle_test.go index 2eddfa99ed..3700475889 100644 --- a/pkg/mindpak/bundle_test.go +++ b/pkg/mindpak/bundle_test.go @@ -43,6 +43,7 @@ func TestReadSource(t *testing.T) { Hashes: map[HashAlgorithm]string{SHA256: "3857bca2ccabdac3d136eb3df4549ddd87a00ddef9fdcf88d8f824e5e796d34c"}, }, }, + DataSources: []*File{}, }, Source: nil, }, diff --git a/pkg/mindpak/mindpack.go b/pkg/mindpak/mindpack.go index 82f5a55ef3..8c55bc4f6a 100644 --- a/pkg/mindpak/mindpack.go +++ b/pkg/mindpak/mindpack.go @@ -17,6 +17,9 @@ const ( // PathRuleTypes is the name of the directory holding the rule types of a bundle PathRuleTypes = "rule_types" + // PathDataSources is the name of the directory holding the data sources of a bundle + PathDataSources = "data_sources" + // ManifestFileName is the defaul filename for the manifest ManifestFileName = "manifest.json" ) @@ -42,6 +45,7 @@ type File struct { // Files is a collection of the files included in the bundle organized by type type Files struct { - Profiles []*File `json:"profiles,omitempty"` - RuleTypes []*File `json:"ruleTypes,omitempty"` + Profiles []*File `json:"profiles,omitempty"` + RuleTypes []*File `json:"ruleTypes,omitempty"` + DataSources []*File `json:"dataSources,omitempty"` } diff --git a/pkg/mindpak/reader/reader.go b/pkg/mindpak/reader/reader.go index ecfa49fd06..aaba15eae5 100644 --- a/pkg/mindpak/reader/reader.go +++ b/pkg/mindpak/reader/reader.go @@ -27,6 +27,10 @@ type BundleReader interface { // and parse the rule type, and then applies the specified anonymous // function to the rule type ForEachRuleType(func(*v1.RuleType) error) error + // ForEachDataSource walks each data source in the bundle, attempts to read + // and parse the data source, and then applies the specified anonymous + // function to the rule type + ForEachDataSource(func(source *v1.DataSource) error) error } type profileSetType = map[string]struct{} @@ -91,7 +95,10 @@ func (b *bundleReader) ForEachRuleType(fn func(*v1.RuleType) error) error { var file fs.File // used for error handling if we return during the loop defer func() { - _ = file.Close() + // Add precaution to close file only if it was assigned + if file != nil { + _ = file.Close() + } }() for _, ruleType := range b.original.Files.RuleTypes { @@ -121,6 +128,44 @@ func (b *bundleReader) ForEachRuleType(fn func(*v1.RuleType) error) error { return nil } +func (b *bundleReader) ForEachDataSource(fn func(source *v1.DataSource) error) error { + var err error + var file fs.File + // used for error handling if we return during the loop + defer func() { + // Add precaution to close file only if it was assigned + if file != nil { + _ = file.Close() + } + }() + + for _, dataSource := range b.original.Files.DataSources { + // read from bundle + path := fmt.Sprintf("%s/%s", mindpak.PathDataSources, dataSource.Name) + file, err = b.original.Source.Open(path) + if err != nil { + return fmt.Errorf("error reading data source from bundle: %w", err) + } + + // parse data source from YAML + parsedDataSource := &v1.DataSource{} + if err := v1.ParseResourceProto(file, parsedDataSource); err != nil { + return fmt.Errorf("error parsing data source yaml: %w", err) + } + if err = file.Close(); err != nil { + return fmt.Errorf("error closing file: %w", err) + } + + // apply operation from caller + err = fn(parsedDataSource) + if err != nil { + return err + } + } + + return nil +} + func ensureYamlSuffix(name string) string { if strings.HasSuffix(name, fileSuffix) { return name diff --git a/pkg/mindpak/reader/reader_test.go b/pkg/mindpak/reader/reader_test.go index 4e12b4225b..bb03a38beb 100644 --- a/pkg/mindpak/reader/reader_test.go +++ b/pkg/mindpak/reader/reader_test.go @@ -16,7 +16,7 @@ import ( func TestBundle_GetMetadata(t *testing.T) { t.Parallel() - bundle := loadBundle(t) + bundle := loadBundle(t, testDataPath) metadata := bundle.GetMetadata() require.NotNil(t, metadata) require.Equal(t, "t2", metadata.Name) @@ -65,7 +65,7 @@ func TestBundle_GetProfile(t *testing.T) { } // immutable - can be shared across parallel runs - bundle := loadBundle(t) + bundle := loadBundle(t, testDataPath) for i := range scenarios { scenario := scenarios[i] t.Run(scenario.Name, func(t *testing.T) { @@ -86,7 +86,7 @@ func TestBundle_GetProfile(t *testing.T) { func TestBundle_ForEachRuleType(t *testing.T) { t.Parallel() results := []string{} - bundle := loadBundle(t) + bundle := loadBundle(t, testDataPath) err := bundle.ForEachRuleType(func(ruleType *minderv1.RuleType) error { results = append(results, ruleType.Name) return nil @@ -98,18 +98,52 @@ func TestBundle_ForEachRuleType(t *testing.T) { func TestBundle_ForEachRuleTypeError(t *testing.T) { t.Parallel() errorMessage := "oh no" - bundle := loadBundle(t) + bundle := loadBundle(t, testDataPath) err := bundle.ForEachRuleType(func(_ *minderv1.RuleType) error { return errors.New(errorMessage) }) require.ErrorContains(t, err, errorMessage) } -func loadBundle(t *testing.T) reader.BundleReader { +func TestBundle_ForEachDataSource(t *testing.T) { + t.Parallel() + results := []string{} + bundle := loadBundle(t, testDataPath) + err := bundle.ForEachDataSource(func(dataSource *minderv1.DataSource) error { + results = append(results, dataSource.Name) + return nil + }) + require.NoError(t, err) + require.Equal(t, []string{"osv"}, results) +} + +func TestBundle_ForEachDataSource_NoDataSources(t *testing.T) { + t.Parallel() + results := []string{} + bundle := loadBundle(t, noDataSourcesPath) + err := bundle.ForEachDataSource(func(dataSource *minderv1.DataSource) error { + results = append(results, dataSource.Name) + return nil + }) + require.NoError(t, err) + require.Equal(t, []string{}, results) +} + +func TestBundle_ForEachDataSourceError(t *testing.T) { + t.Parallel() + errorMessage := "oh no" + bundle := loadBundle(t, testDataPath) + err := bundle.ForEachDataSource(func(_ *minderv1.DataSource) error { + return errors.New(errorMessage) + }) + require.ErrorContains(t, err, errorMessage) +} + +func loadBundle(t *testing.T, path string) reader.BundleReader { t.Helper() - bundle, err := mindpak.NewBundleFromDirectory(testDataPath) + bundle, err := mindpak.NewBundleFromDirectory(path) if err != nil { - t.Fatalf("Unable to load test data from %s: %v", testDataPath, err) + t.Fatalf("Unable to load test data from %s: %v", path, err) } return reader.NewBundleReader(bundle) } @@ -117,4 +151,5 @@ func loadBundle(t *testing.T) reader.BundleReader { const ( testDataPath = "../testdata/t2" expectedProfileName = "branch-protection-github-profile" + noDataSourcesPath = "../testdata/no-data-sources" ) diff --git a/pkg/mindpak/testdata/no-data-sources/profiles/branch-protection-github-profile.yaml b/pkg/mindpak/testdata/no-data-sources/profiles/branch-protection-github-profile.yaml new file mode 100644 index 0000000000..8190f70938 --- /dev/null +++ b/pkg/mindpak/testdata/no-data-sources/profiles/branch-protection-github-profile.yaml @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright 2025 The Minder Authors +# SPDX-License-Identifier: Apache-2.0 + +--- +# A profile to verify branch protection settings +version: v1 +type: profile +name: branch-protection-github-profile +context: + provider: github +alert: "off" +remediate: "off" +repository: + - type: branch_protection_enabled + params: + branch: "" + def: {} + - type: branch_protection_allow_deletions + params: + branch: "" + def: + allow_deletions: false + - type: branch_protection_allow_force_pushes + params: + branch: "" + def: + allow_force_pushes: false + - type: branch_protection_enforce_admins + params: + branch: "" + def: + enforce_admins: true + - type: branch_protection_lock_branch + params: + branch: "" + def: + lock_branch: false + - type: branch_protection_require_conversation_resolution + params: + branch: "" + def: + required_conversation_resolution: false + - type: branch_protection_require_pull_request_approving_review_count + params: + branch: "" + def: + required_approving_review_count: 1 + - type: branch_protection_require_pull_request_code_owners_review + params: + branch: "" + def: + require_code_owner_reviews: false + - type: branch_protection_require_pull_request_dismiss_stale_reviews + params: + branch: "" + def: + dismiss_stale_reviews: true + - type: branch_protection_require_pull_request_last_push_approval + params: + branch: "" + def: + require_last_push_approval: true + - type: branch_protection_require_pull_requests + params: + branch: "" + def: + required_pull_request_reviews: true + - type: branch_protection_require_signatures + params: + branch: "" + def: + required_signatures: false diff --git a/pkg/mindpak/testdata/no-data-sources/rule_types/branch_protection_enabled.yaml b/pkg/mindpak/testdata/no-data-sources/rule_types/branch_protection_enabled.yaml new file mode 100644 index 0000000000..af6a92d04b --- /dev/null +++ b/pkg/mindpak/testdata/no-data-sources/rule_types/branch_protection_enabled.yaml @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright 2025 The Minder Authors +# SPDX-License-Identifier: Apache-2.0 + +--- +version: v1 +type: rule-type +name: branch_protection_enabled +context: + provider: github +description: Verifies that a branch has a branch protection rule +guidance: | + You can protect important branches by setting branch protection rules, which define whether + collaborators can delete or force push to the branch and set requirements for any pushes to the branch, + such as passing status checks or a linear commit history. + + For more information, see + https://docs.github.com/en/repositories/configuring-branches-and-merges-in-your-repository/managing-protected-branches/managing-a-branch-protection-rule +def: + # Defines the section of the pipeline the rule will appear in. + # This will affect the template used to render multiple parts + # of the rule. + in_entity: repository + # Defines the schema for parameters that will be passed to the rule + param_schema: + properties: + branch: + type: string + description: "The name of the branch to check. If left empty, the default branch will be used." + required: + - branch + rule_schema: {} + # Defines the configuration for ingesting data relevant for the rule + ingest: + type: rest + rest: + # This is the path to the data source. Given that this will evaluate + # for each repository in the organization, we use a template that + # will be evaluated for each repository. The structure to use is the + # protobuf structure for the entity that is being evaluated. + endpoint: '{{ $branch_param := index .Params "branch" }}/repos/{{.Entity.Owner}}/{{.Entity.Name}}/branches/{{if ne $branch_param "" }}{{ $branch_param }}{{ else }}{{ .Entity.DefaultBranch }}{{ end }}/protection' + # This is the method to use to retrieve the data. It should already default to JSON + parse: json + fallback: + - http_code: 404 + body: | + {"http_status": 404, "message": "Not Protected"} + eval: + type: rego + rego: + type: deny-by-default + def: | + package minder + + import future.keywords.every + import future.keywords.if + + default allow := false + + allow if { + input.ingested.url != "" + } + # Defines the configuration for alerting on the rule + alert: + type: security_advisory + security_advisory: + severity: "medium" diff --git a/pkg/mindpak/testdata/t2/data_sources/osv.yaml b/pkg/mindpak/testdata/t2/data_sources/osv.yaml new file mode 100644 index 0000000000..6a3e95d00e --- /dev/null +++ b/pkg/mindpak/testdata/t2/data_sources/osv.yaml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright 2024 The Minder Authors +# SPDX-License-Identifier: Apache-2.0 + +--- +version: v1 +type: data-source +name: osv +context: {} +rest: + def: + query: + endpoint: 'https://api.osv.dev/v1/query' + parse: json + method: POST + body_from_field: query + input_schema: + type: object + properties: + query: + type: object + properties: + version: + type: string + package: + type: object + properties: + ecosystem: + type: string + description: The ecosystem the dependency belongs to + name: + type: string + description: The name of the dependency + required: + - query