Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPLYTM-467] feat: initial work to create assessment results generator #9

Merged
merged 15 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions framework/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

oscalTypes "github.com/defenseunicorns/go-oscal/src/types/oscal-1-1-2"
"github.com/hashicorp/go-hclog"
"github.com/oscal-compass/oscal-sdk-go/models/components"
"github.com/oscal-compass/oscal-sdk-go/rules"

"github.com/oscal-compass/compliance-to-policy-go/v2/plugin"
Expand Down Expand Up @@ -98,7 +99,8 @@ func ResolveOptions(config *C2PConfig) (*rules.MemoryStore, map[string]string, e
if err != nil {
return nil, nil, err
}
store, err := rules.NewMemoryStoreFromComponents(allComponents)
store := rules.NewMemoryStore()
err = store.IndexAll(allComponents)
if err != nil {
return store, titleByID, err
}
Expand All @@ -107,8 +109,8 @@ func ResolveOptions(config *C2PConfig) (*rules.MemoryStore, map[string]string, e

// resolveOptions returns processed OSCAL Components and a plugin identifier map. This performs most
// of the logic in ResolveOptions, but is broken out to make it easier to test.
func resolveOptions(config *C2PConfig) ([]oscalTypes.DefinedComponent, map[string]string, error) {
var allComponents []oscalTypes.DefinedComponent
func resolveOptions(config *C2PConfig) ([]components.Component, map[string]string, error) {
var allComponents []components.Component
titleByID := make(map[string]string)
for _, compDef := range config.ComponentDefinitions {
if compDef.Components == nil {
Expand All @@ -122,7 +124,8 @@ func resolveOptions(config *C2PConfig) ([]oscalTypes.DefinedComponent, map[strin
}
titleByID[pluginId] = component.Title
}
allComponents = append(allComponents, component)
compAdapter := components.NewDefinedComponentAdapter(component)
allComponents = append(allComponents, compAdapter)
}
}
return allComponents, titleByID, nil
Expand Down
23 changes: 13 additions & 10 deletions framework/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/hashicorp/go-hclog"
"github.com/oscal-compass/oscal-sdk-go/rules"
"github.com/oscal-compass/oscal-sdk-go/settings"

"github.com/oscal-compass/compliance-to-policy-go/v2/framework/config"
"github.com/oscal-compass/compliance-to-policy-go/v2/plugin"
Expand Down Expand Up @@ -101,29 +102,31 @@ func (m *PluginManager) LaunchPolicyPlugins() (map[string]policy.Provider, error
}

// GeneratePolicy identifies policy configuration for each provider in the given pluginSet to execute the Generate() method
// each policy.Provider.
func (m *PluginManager) GeneratePolicy(ctx context.Context, pluginSet map[string]policy.Provider) error {
// each policy.Provider. The rule set passed to each plugin can be configured with compliance specific settings with the
// complianceSettings input.
func (m *PluginManager) GeneratePolicy(ctx context.Context, pluginSet map[string]policy.Provider, complianceSettings settings.Settings) error {
for providerId, policyPlugin := range pluginSet {
componentTitle, ok := m.pluginIdMap[providerId]
if !ok {
return fmt.Errorf("missing title for provider %s", providerId)
}
m.log.Debug(fmt.Sprintf("Generating policy for provider %s", providerId))

ruleSets, err := m.rulesStore.FindByComponent(ctx, componentTitle)
appliedRuleSet, err := settings.ApplyToComponent(ctx, componentTitle, m.rulesStore, complianceSettings)
if err != nil {
return err
return fmt.Errorf("failed to get rule sets for component %s: %w", componentTitle, err)
}
if err := policyPlugin.Generate(ruleSets); err != nil {
if err := policyPlugin.Generate(appliedRuleSet); err != nil {
return fmt.Errorf("plugin %s: %w", providerId, err)
}
}
return nil
}

// AggregateResults identifies policy configuration for each provider in the given pluginSet to execute the GetResults() method
// each policy.Provider.
func (m *PluginManager) AggregateResults(ctx context.Context, pluginSet map[string]policy.Provider) ([]policy.PVPResult, error) {
// each policy.Provider. The rule set passed to each plugin can be configured with compliance specific settings with the
// // complianceSettings input.
func (m *PluginManager) AggregateResults(ctx context.Context, pluginSet map[string]policy.Provider, complianceSettings settings.Settings) ([]policy.PVPResult, error) {
var allResults []policy.PVPResult
for providerId, policyPlugin := range pluginSet {
// get the provider ids here to grab the policy
Expand All @@ -132,12 +135,12 @@ func (m *PluginManager) AggregateResults(ctx context.Context, pluginSet map[stri
return allResults, fmt.Errorf("missing title for provider %s", providerId)
}
m.log.Debug(fmt.Sprintf("Aggregating results for provider %s", providerId))
ruleSets, err := m.rulesStore.FindByComponent(ctx, componentTitle)
appliedRuleSet, err := settings.ApplyToComponent(ctx, componentTitle, m.rulesStore, complianceSettings)
if err != nil {
return allResults, err
return allResults, fmt.Errorf("failed to get rule sets for component %s: %w", componentTitle, err)
}

pluginResults, err := policyPlugin.GetResults(ruleSets)
pluginResults, err := policyPlugin.GetResults(appliedRuleSet)
if err != nil {
return allResults, fmt.Errorf("plugin %s: %w", providerId, err)
}
Expand Down
26 changes: 20 additions & 6 deletions framework/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@ import (
oscalTypes "github.com/defenseunicorns/go-oscal/src/types/oscal-1-1-2"
"github.com/oscal-compass/oscal-sdk-go/extensions"
"github.com/oscal-compass/oscal-sdk-go/generators"
"github.com/oscal-compass/oscal-sdk-go/settings"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

"github.com/oscal-compass/compliance-to-policy-go/v2/framework/config"
"github.com/oscal-compass/compliance-to-policy-go/v2/policy"
)

const testDataPath = "../test/testdata/component-definition-test.json"

var (
expectedCertFileRule = extensions.RuleSet{
Rule: extensions.Rule{
Expand Down Expand Up @@ -77,11 +76,14 @@ func TestPluginManager_GeneratePolicy(t *testing.T) {

// Create pluginSet
providerTestObj := new(policyProvider)
providerTestObj.On("Generate", policy.Policy{expectedKeyFileRule, expectedCertFileRule}).Return(nil)
providerTestObj.On("Generate", policy.Policy{expectedCertFileRule}).Return(nil)
pluginSet := map[string]policy.Provider{
"mypvpvalidator": providerTestObj,
}
err = pluginManager.GeneratePolicy(context.TODO(), pluginSet)

testSettings := settings.NewSettings(map[string]struct{}{"etcd_cert_file": {}}, map[string]string{})

err = pluginManager.GeneratePolicy(context.TODO(), pluginSet, testSettings)
require.NoError(t, err)
providerTestObj.AssertExpectations(t)
}
Expand All @@ -101,13 +103,25 @@ func TestPluginManager_AggregateResults(t *testing.T) {
},
}

updatedParam := &extensions.Parameter{
ID: "file_name",
Description: "A parameter for a file name",
Value: "my_file",
}

updatedKeyFileRule := expectedKeyFileRule
updatedKeyFileRule.Rule.Parameter = updatedParam

// Create pluginSet
providerTestObj := new(policyProvider)
providerTestObj.On("GetResults", policy.Policy{expectedKeyFileRule, expectedCertFileRule}).Return(wantResults, nil)
providerTestObj.On("GetResults", policy.Policy{updatedKeyFileRule}).Return(wantResults, nil)
pluginSet := map[string]policy.Provider{
"mypvpvalidator": providerTestObj,
}
gotResults, err := pluginManager.AggregateResults(context.TODO(), pluginSet)

testSettings := settings.NewSettings(map[string]struct{}{"etcd_key_file": {}}, map[string]string{"file_name": "my_file"})

gotResults, err := pluginManager.AggregateResults(context.TODO(), pluginSet, testSettings)
require.NoError(t, err)
providerTestObj.AssertExpectations(t)
require.Len(t, gotResults, 1)
Expand Down
Loading