Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

context propagation: pkg/database/bouncers #3249

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions cmd/crowdsec-cli/clibouncer/bouncers.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@
return cmd
}

func (cli *cliBouncers) add(bouncerName string, key string) error {
func (cli *cliBouncers) add(ctx context.Context, bouncerName string, key string) error {
var err error

keyLength := 32
Expand All @@ -220,7 +220,7 @@
}
}

_, err = cli.db.CreateBouncer(bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType)
_, err = cli.db.CreateBouncer(ctx, bouncerName, "", middlewares.HashSHA512(key), types.ApiKeyAuthType)
if err != nil {
return fmt.Errorf("unable to create bouncer: %w", err)
}
Expand Down Expand Up @@ -254,8 +254,8 @@
cscli bouncers add MyBouncerName --key <random-key>`,
Args: cobra.ExactArgs(1),
DisableAutoGenTag: true,
RunE: func(_ *cobra.Command, args []string) error {
return cli.add(args[0], key)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.add(cmd.Context(), args[0], key)
},
}

Expand Down Expand Up @@ -304,9 +304,9 @@
return ret, cobra.ShellCompDirectiveNoFileComp
}

func (cli *cliBouncers) delete(bouncers []string, ignoreMissing bool) error {
func (cli *cliBouncers) delete(ctx context.Context, bouncers []string, ignoreMissing bool) error {
for _, bouncerID := range bouncers {
if err := cli.db.DeleteBouncer(bouncerID); err != nil {
if err := cli.db.DeleteBouncer(ctx, bouncerID); err != nil {
var notFoundErr *database.BouncerNotFoundError
if ignoreMissing && errors.As(err, &notFoundErr) {
return nil
Expand All @@ -332,8 +332,8 @@
Aliases: []string{"remove"},
DisableAutoGenTag: true,
ValidArgsFunction: cli.validBouncerID,
RunE: func(_ *cobra.Command, args []string) error {
return cli.delete(args, ignoreMissing)
RunE: func(cmd *cobra.Command, args []string) error {
return cli.delete(cmd.Context(), args, ignoreMissing)
},
}

Expand All @@ -343,7 +343,7 @@
return cmd
}

func (cli *cliBouncers) prune(duration time.Duration, force bool) error {
func (cli *cliBouncers) prune(ctx context.Context, duration time.Duration, force bool) error {
if duration < 2*time.Minute {
if yes, err := ask.YesNo(
"The duration you provided is less than 2 minutes. "+
Expand All @@ -355,7 +355,7 @@
}
}

bouncers, err := cli.db.QueryBouncersInactiveSince(time.Now().UTC().Add(-duration))
bouncers, err := cli.db.QueryBouncersInactiveSince(ctx, time.Now().UTC().Add(-duration))
if err != nil {
return fmt.Errorf("unable to query bouncers: %w", err)
}
Expand All @@ -378,7 +378,7 @@
}
}

deleted, err := cli.db.BulkDeleteBouncers(bouncers)
deleted, err := cli.db.BulkDeleteBouncers(ctx, bouncers)

Check warning on line 381 in cmd/crowdsec-cli/clibouncer/bouncers.go

View check run for this annotation

Codecov / codecov/patch

cmd/crowdsec-cli/clibouncer/bouncers.go#L381

Added line #L381 was not covered by tests
if err != nil {
return fmt.Errorf("unable to prune bouncers: %w", err)
}
Expand All @@ -403,8 +403,8 @@
DisableAutoGenTag: true,
Example: `cscli bouncers prune -d 45m
cscli bouncers prune -d 45m --force`,
RunE: func(_ *cobra.Command, _ []string) error {
return cli.prune(duration, force)
RunE: func(cmd *cobra.Command, _ []string) error {
return cli.prune(cmd.Context(), duration, force)
},
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/apiserver/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ func CreateTestBouncer(t *testing.T, config *csconfig.DatabaseCfg) string {
apiKey, err := middlewares.GenerateAPIKey(keyLength)
require.NoError(t, err)

_, err = dbClient.CreateBouncer("test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
_, err = dbClient.CreateBouncer(ctx, "test", "127.0.0.1", middlewares.HashSHA512(apiKey), types.ApiKeyAuthType)
require.NoError(t, err)

return apiKey
Expand Down
8 changes: 6 additions & 2 deletions pkg/apiserver/controllers/v1/decisions.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
data []*ent.Decision
)

ctx := gctx.Request.Context()

Check warning on line 47 in pkg/apiserver/controllers/v1/decisions.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/controllers/v1/decisions.go#L47

Added line #L47 was not covered by tests
bouncerInfo, err := getBouncerFromContext(gctx)
if err != nil {
gctx.JSON(http.StatusUnauthorized, gin.H{"message": "not allowed"})
Expand Down Expand Up @@ -73,7 +75,7 @@
}

if bouncerInfo.LastPull == nil || time.Now().UTC().Sub(*bouncerInfo.LastPull) >= time.Minute {
if err := c.DBClient.UpdateBouncerLastPull(time.Now().UTC(), bouncerInfo.ID); err != nil {
if err := c.DBClient.UpdateBouncerLastPull(ctx, time.Now().UTC(), bouncerInfo.ID); err != nil {
log.Errorf("failed to update bouncer last pull: %v", err)
}
}
Expand Down Expand Up @@ -370,6 +372,8 @@
func (c *Controller) StreamDecision(gctx *gin.Context) {
var err error

ctx := gctx.Request.Context()

Check warning on line 376 in pkg/apiserver/controllers/v1/decisions.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/controllers/v1/decisions.go#L376

Added line #L376 was not covered by tests
streamStartTime := time.Now().UTC()

bouncerInfo, err := getBouncerFromContext(gctx)
Expand Down Expand Up @@ -400,7 +404,7 @@

if err == nil {
//Only update the last pull time if no error occurred when sending the decisions to avoid missing decisions
if err := c.DBClient.UpdateBouncerLastPull(streamStartTime, bouncerInfo.ID); err != nil {
if err := c.DBClient.UpdateBouncerLastPull(ctx, streamStartTime, bouncerInfo.ID); err != nil {
log.Errorf("unable to update bouncer '%s' pull: %v", bouncerInfo.Name, err)
}
}
Expand Down
18 changes: 12 additions & 6 deletions pkg/apiserver/middlewares/v1/api_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
return nil
}

ctx := c.Request.Context()

Check warning on line 68 in pkg/apiserver/middlewares/v1/api_key.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/middlewares/v1/api_key.go#L68

Added line #L68 was not covered by tests
extractedCN, err := a.TlsAuth.ValidateCert(c)
if err != nil {
logger.Warn(err)
Expand All @@ -73,7 +75,7 @@
logger = logger.WithField("cn", extractedCN)

bouncerName := fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
bouncer, err := a.DbClient.SelectBouncerByName(bouncerName)
bouncer, err := a.DbClient.SelectBouncerByName(ctx, bouncerName)

// This is likely not the proper way, but isNotFound does not seem to work
if err != nil && strings.Contains(err.Error(), "bouncer not found") {
Expand All @@ -87,7 +89,7 @@

logger.Infof("Creating bouncer %s", bouncerName)

bouncer, err = a.DbClient.CreateBouncer(bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType)
bouncer, err = a.DbClient.CreateBouncer(ctx, bouncerName, c.ClientIP(), HashSHA512(apiKey), types.TlsAuthType)
if err != nil {
logger.Errorf("while creating bouncer db entry: %s", err)
return nil
Expand All @@ -112,9 +114,11 @@
return nil
}

ctx := c.Request.Context()

hashStr := HashSHA512(val[0])

bouncer, err := a.DbClient.SelectBouncer(hashStr)
bouncer, err := a.DbClient.SelectBouncer(ctx, hashStr)
if err != nil {
logger.Errorf("while fetching bouncer info: %s", err)
return nil
Expand All @@ -132,6 +136,8 @@
return func(c *gin.Context) {
var bouncer *ent.Bouncer

ctx := c.Request.Context()

clientIP := c.ClientIP()

logger := log.WithField("ip", clientIP)
Expand All @@ -153,7 +159,7 @@
logger = logger.WithField("name", bouncer.Name)

if bouncer.IPAddress == "" {
if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil {
if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil {
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
Expand All @@ -166,7 +172,7 @@
if bouncer.IPAddress != clientIP && bouncer.IPAddress != "" && c.Request.Method != http.MethodHead {
log.Warningf("new IP address detected for bouncer '%s': %s (old: %s)", bouncer.Name, clientIP, bouncer.IPAddress)

if err := a.DbClient.UpdateBouncerIP(clientIP, bouncer.ID); err != nil {
if err := a.DbClient.UpdateBouncerIP(ctx, clientIP, bouncer.ID); err != nil {

Check warning on line 175 in pkg/apiserver/middlewares/v1/api_key.go

View check run for this annotation

Codecov / codecov/patch

pkg/apiserver/middlewares/v1/api_key.go#L175

Added line #L175 was not covered by tests
logger.Errorf("Failed to update ip address for '%s': %s\n", bouncer.Name, err)
c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
c.Abort()
Expand All @@ -182,7 +188,7 @@
}

if bouncer.Version != useragent[1] || bouncer.Type != useragent[0] {
if err := a.DbClient.UpdateBouncerTypeAndVersion(useragent[0], useragent[1], bouncer.ID); err != nil {
if err := a.DbClient.UpdateBouncerTypeAndVersion(ctx, useragent[0], useragent[1], bouncer.ID); err != nil {
logger.Errorf("failed to update bouncer version and type: %s", err)
c.JSON(http.StatusForbidden, gin.H{"message": "bad user agent"})
c.Abort()
Expand Down
2 changes: 1 addition & 1 deletion pkg/apiserver/usage_metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func TestRCMetrics(t *testing.T) {
assert.Equal(t, tt.expectedStatusCode, w.Code)
assert.Contains(t, w.Body.String(), tt.expectedResponse)

bouncer, _ := dbClient.SelectBouncerByName("test")
bouncer, _ := dbClient.SelectBouncerByName(ctx, "test")
metrics, _ := dbClient.GetBouncerUsageMetricsByName(ctx, "test")

assert.Len(t, metrics, tt.expectedMetricsCount)
Expand Down
36 changes: 18 additions & 18 deletions pkg/database/bouncers.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@
return nil
}

func (c *Client) SelectBouncer(apiKeyHash string) (*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(c.CTX)
func (c *Client) SelectBouncer(ctx context.Context, apiKeyHash string) (*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().Where(bouncer.APIKeyEQ(apiKeyHash)).First(ctx)
if err != nil {
return nil, err
}

return result, nil
}

func (c *Client) SelectBouncerByName(bouncerName string) (*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(c.CTX)
func (c *Client) SelectBouncerByName(ctx context.Context, bouncerName string) (*ent.Bouncer, error) {
result, err := c.Ent.Bouncer.Query().Where(bouncer.NameEQ(bouncerName)).First(ctx)
if err != nil {
return nil, err
}
Expand All @@ -68,14 +68,14 @@
return result, nil
}

func (c *Client) CreateBouncer(name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) {
func (c *Client) CreateBouncer(ctx context.Context, name string, ipAddr string, apiKey string, authType string) (*ent.Bouncer, error) {
bouncer, err := c.Ent.Bouncer.
Create().
SetName(name).
SetAPIKey(apiKey).
SetRevoked(false).
SetAuthType(authType).
Save(c.CTX)
Save(ctx)
if err != nil {
if ent.IsConstraintError(err) {
return nil, fmt.Errorf("bouncer %s already exists", name)
Expand All @@ -87,11 +87,11 @@
return bouncer, nil
}

func (c *Client) DeleteBouncer(name string) error {
func (c *Client) DeleteBouncer(ctx context.Context, name string) error {
nbDeleted, err := c.Ent.Bouncer.
Delete().
Where(bouncer.NameEQ(name)).
Exec(c.CTX)
Exec(ctx)
if err != nil {
return err
}
Expand All @@ -103,50 +103,50 @@
return nil
}

func (c *Client) BulkDeleteBouncers(bouncers []*ent.Bouncer) (int, error) {
func (c *Client) BulkDeleteBouncers(ctx context.Context, bouncers []*ent.Bouncer) (int, error) {

Check warning on line 106 in pkg/database/bouncers.go

View check run for this annotation

Codecov / codecov/patch

pkg/database/bouncers.go#L106

Added line #L106 was not covered by tests
ids := make([]int, len(bouncers))
for i, b := range bouncers {
ids[i] = b.ID
}

nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(c.CTX)
nbDeleted, err := c.Ent.Bouncer.Delete().Where(bouncer.IDIn(ids...)).Exec(ctx)

Check warning on line 112 in pkg/database/bouncers.go

View check run for this annotation

Codecov / codecov/patch

pkg/database/bouncers.go#L112

Added line #L112 was not covered by tests
if err != nil {
return nbDeleted, fmt.Errorf("unable to delete bouncers: %w", err)
}

return nbDeleted, nil
}

func (c *Client) UpdateBouncerLastPull(lastPull time.Time, id int) error {
func (c *Client) UpdateBouncerLastPull(ctx context.Context, lastPull time.Time, id int) error {
_, err := c.Ent.Bouncer.UpdateOneID(id).
SetLastPull(lastPull).
Save(c.CTX)
Save(ctx)
if err != nil {
return fmt.Errorf("unable to update machine last pull in database: %w", err)
}

return nil
}

func (c *Client) UpdateBouncerIP(ipAddr string, id int) error {
_, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(c.CTX)
func (c *Client) UpdateBouncerIP(ctx context.Context, ipAddr string, id int) error {
_, err := c.Ent.Bouncer.UpdateOneID(id).SetIPAddress(ipAddr).Save(ctx)
if err != nil {
return fmt.Errorf("unable to update bouncer ip address in database: %w", err)
}

return nil
}

func (c *Client) UpdateBouncerTypeAndVersion(bType string, version string, id int) error {
_, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(c.CTX)
func (c *Client) UpdateBouncerTypeAndVersion(ctx context.Context, bType string, version string, id int) error {
_, err := c.Ent.Bouncer.UpdateOneID(id).SetVersion(version).SetType(bType).Save(ctx)
if err != nil {
return fmt.Errorf("unable to update bouncer type and version in database: %w", err)
}

return nil
}

func (c *Client) QueryBouncersInactiveSince(t time.Time) ([]*ent.Bouncer, error) {
func (c *Client) QueryBouncersInactiveSince(ctx context.Context, t time.Time) ([]*ent.Bouncer, error) {
return c.Ent.Bouncer.Query().Where(
// poor man's coalesce
bouncer.Or(
Expand All @@ -156,5 +156,5 @@
bouncer.CreatedAtLT(t),
),
),
).All(c.CTX)
).All(ctx)
}