Skip to content

Commit

Permalink
refact pkg/database: context propagation (start) (#3226)
Browse files Browse the repository at this point in the history
* refact pkg/database: context propagation (part)

* more context propagation (usagemetrics)

* propagate errors when updating metrics
  • Loading branch information
mmetc authored Sep 12, 2024
1 parent cae76ba commit 6810b41
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 20 deletions.
6 changes: 4 additions & 2 deletions cmd/crowdsec/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha
return
}

decisions, err := dbClient.QueryDecisionCountByScenario()
ctx := r.Context()

decisions, err := dbClient.QueryDecisionCountByScenario(ctx)
if err != nil {
log.Errorf("Error querying decisions for metrics: %v", err)
next.ServeHTTP(w, r)
Expand All @@ -138,7 +140,7 @@ func computeDynamicMetrics(next http.Handler, dbClient *database.Client) http.Ha
"include_capi": {"false"},
}

alerts, err := dbClient.AlertsCountPerScenario(alertsFilter)
alerts, err := dbClient.AlertsCountPerScenario(ctx, alertsFilter)
if err != nil {
log.Errorf("Error querying alerts for metrics: %v", err)
next.ServeHTTP(w, r)
Expand Down
15 changes: 8 additions & 7 deletions pkg/apiserver/controllers/v1/usagemetrics.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package v1

import (
"context"
"encoding/json"
"errors"
"net/http"
Expand All @@ -18,17 +19,15 @@ import (
)

// updateBaseMetrics updates the base metrics for a machine or bouncer
func (c *Controller) updateBaseMetrics(machineID string, bouncer *ent.Bouncer, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error {
func (c *Controller) updateBaseMetrics(ctx context.Context, machineID string, bouncer *ent.Bouncer, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error {
switch {
case machineID != "":
c.DBClient.MachineUpdateBaseMetrics(machineID, baseMetrics, hubItems, datasources)
return c.DBClient.MachineUpdateBaseMetrics(ctx, machineID, baseMetrics, hubItems, datasources)
case bouncer != nil:
c.DBClient.BouncerUpdateBaseMetrics(bouncer.Name, bouncer.Type, baseMetrics)
return c.DBClient.BouncerUpdateBaseMetrics(ctx, bouncer.Name, bouncer.Type, baseMetrics)
default:
return errors.New("no machineID or bouncerName set")
}

return nil
}

// UsageMetrics receives metrics from log processors and remediation components
Expand Down Expand Up @@ -172,7 +171,9 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) {
}
}

err := c.updateBaseMetrics(machineID, bouncer, baseMetrics, hubItems, datasources)
ctx := gctx.Request.Context()

err := c.updateBaseMetrics(ctx, machineID, bouncer, baseMetrics, hubItems, datasources)
if err != nil {
logger.Errorf("Failed to update base metrics: %s", err)
c.HandleDBErrors(gctx, err)
Expand All @@ -190,7 +191,7 @@ func (c *Controller) UsageMetrics(gctx *gin.Context) {

receivedAt := time.Now().UTC()

if _, err := c.DBClient.CreateMetric(generatedType, generatedBy, receivedAt, string(jsonPayload)); err != nil {
if _, err := c.DBClient.CreateMetric(ctx, generatedType, generatedBy, receivedAt, string(jsonPayload)); err != nil {
logger.Error(err)
c.HandleDBErrors(gctx, err)

Expand Down
4 changes: 1 addition & 3 deletions pkg/database/alerts.go
Original file line number Diff line number Diff line change
Expand Up @@ -941,14 +941,12 @@ func BuildAlertRequestFromFilter(alerts *ent.AlertQuery, filter map[string][]str
return alerts.Where(preds...), nil
}

func (c *Client) AlertsCountPerScenario(filters map[string][]string) (map[string]int, error) {
func (c *Client) AlertsCountPerScenario(ctx context.Context, filters map[string][]string) (map[string]int, error) {
var res []struct {
Scenario string
Count int
}

ctx := context.TODO()

query := c.Ent.Alert.Query()

query, err := BuildAlertRequestFromFilter(query, filters)
Expand Down
5 changes: 3 additions & 2 deletions pkg/database/bouncers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package database

import (
"context"
"fmt"
"strings"
"time"
Expand All @@ -20,7 +21,7 @@ func (e *BouncerNotFoundError) Error() string {
return fmt.Sprintf("'%s' does not exist", e.BouncerName)
}

func (c *Client) BouncerUpdateBaseMetrics(bouncerName string, bouncerType string, baseMetrics models.BaseMetrics) error {
func (c *Client) BouncerUpdateBaseMetrics(ctx context.Context, bouncerName string, bouncerType string, baseMetrics models.BaseMetrics) error {
os := baseMetrics.Os
features := strings.Join(baseMetrics.FeatureFlags, ",")

Expand All @@ -32,7 +33,7 @@ func (c *Client) BouncerUpdateBaseMetrics(bouncerName string, bouncerType string
SetOsversion(*os.Version).
SetFeatureflags(features).
SetType(bouncerType).
Save(c.CTX)
Save(ctx)
if err != nil {
return fmt.Errorf("unable to update base bouncer metrics in database: %w", err)
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/database/decisions.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package database

import (
"context"
"fmt"
"strconv"
"strings"
Expand Down Expand Up @@ -173,7 +174,7 @@ func (c *Client) QueryExpiredDecisionsWithFilters(filters map[string][]string) (
return data, nil
}

func (c *Client) QueryDecisionCountByScenario() ([]*DecisionsByScenario, error) {
func (c *Client) QueryDecisionCountByScenario(ctx context.Context) ([]*DecisionsByScenario, error) {
query := c.Ent.Decision.Query().Where(
decision.UntilGT(time.Now().UTC()),
)
Expand All @@ -186,7 +187,7 @@ func (c *Client) QueryDecisionCountByScenario() ([]*DecisionsByScenario, error)

var r []*DecisionsByScenario

err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(c.CTX, &r)
err = query.GroupBy(decision.FieldScenario, decision.FieldOrigin, decision.FieldType).Aggregate(ent.Count()).Scan(ctx, &r)
if err != nil {
c.Log.Warningf("QueryDecisionCountByScenario : %s", err)
return nil, errors.Wrap(QueryFail, "count all decisions with filters")
Expand Down
5 changes: 3 additions & 2 deletions pkg/database/machines.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package database

import (
"context"
"fmt"
"strings"
"time"
Expand Down Expand Up @@ -29,7 +30,7 @@ func (e *MachineNotFoundError) Error() string {
return fmt.Sprintf("'%s' does not exist", e.MachineID)
}

func (c *Client) MachineUpdateBaseMetrics(machineID string, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error {
func (c *Client) MachineUpdateBaseMetrics(ctx context.Context, machineID string, baseMetrics models.BaseMetrics, hubItems models.HubItems, datasources map[string]int64) error {
os := baseMetrics.Os
features := strings.Join(baseMetrics.FeatureFlags, ",")

Expand Down Expand Up @@ -63,7 +64,7 @@ func (c *Client) MachineUpdateBaseMetrics(machineID string, baseMetrics models.B
SetLastHeartbeat(heartbeat).
SetHubstate(hubState).
SetDatasources(datasources).
Save(c.CTX)
Save(ctx)
if err != nil {
return fmt.Errorf("unable to update base machine metrics in database: %w", err)
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/database/metrics.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
package database

import (
"context"
"fmt"
"time"

"github.com/crowdsecurity/crowdsec/pkg/database/ent"
"github.com/crowdsecurity/crowdsec/pkg/database/ent/metric"
)

func (c *Client) CreateMetric(generatedType metric.GeneratedType, generatedBy string, receivedAt time.Time, payload string) (*ent.Metric, error) {
func (c *Client) CreateMetric(ctx context.Context, generatedType metric.GeneratedType, generatedBy string, receivedAt time.Time, payload string) (*ent.Metric, error) {
metric, err := c.Ent.Metric.
Create().
SetGeneratedType(generatedType).
SetGeneratedBy(generatedBy).
SetReceivedAt(receivedAt).
SetPayload(payload).
Save(c.CTX)
Save(ctx)
if err != nil {
c.Log.Warningf("CreateMetric: %s", err)
return nil, fmt.Errorf("storing metrics snapshot for '%s' at %s: %w", generatedBy, receivedAt, InsertFail)
Expand Down

0 comments on commit 6810b41

Please sign in to comment.