Skip to content

Commit

Permalink
Merge branch 'main' into vault-27421-update-cap-ldap-dep
Browse files Browse the repository at this point in the history
  • Loading branch information
helenfufu authored Jan 7, 2025
2 parents a6be6c1 + e153846 commit adbad9e
Show file tree
Hide file tree
Showing 13 changed files with 589 additions and 39 deletions.
72 changes: 48 additions & 24 deletions builtin/logical/pki/acme_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"io"
"net"
"path"
"strings"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -311,7 +310,22 @@ func (a *acmeState) UpdateAccount(sc *storageContext, acct *acmeAccount) error {
// LoadAccount will load the account object based on the passed in keyId field value
// otherwise will return an error if the account does not exist.
func (a *acmeState) LoadAccount(ac *acmeContext, keyId string) (*acmeAccount, error) {
entry, err := ac.sc.Storage.Get(ac.sc.Context, acmeAccountPrefix+keyId)
acct, err := a.LoadAccountWithoutDirEnforcement(ac.sc, keyId)
if err != nil {
return acct, err
}

if acct.AcmeDirectory != ac.acmeDirectory {
return nil, fmt.Errorf("%w: account part of different ACME directory path", ErrMalformed)
}

return acct, nil
}

// LoadAccountWithoutDirEnforcement will load the account object based on the passed in keyId field value,
// but does not enforce the ACME directory path, normally this is used by non ACME specific APIs.
func (a *acmeState) LoadAccountWithoutDirEnforcement(sc *storageContext, keyId string) (*acmeAccount, error) {
entry, err := sc.Storage.Get(sc.Context, acmeAccountPrefix+keyId)
if err != nil {
return nil, fmt.Errorf("error loading account: %w", err)
}
Expand All @@ -324,13 +338,7 @@ func (a *acmeState) LoadAccount(ac *acmeContext, keyId string) (*acmeAccount, er
if err != nil {
return nil, fmt.Errorf("error decoding account: %w", err)
}

if acct.AcmeDirectory != ac.acmeDirectory {
return nil, fmt.Errorf("%w: account part of different ACME directory path", ErrMalformed)
}

acct.KeyId = keyId

return &acct, nil
}

Expand Down Expand Up @@ -536,6 +544,27 @@ func (a *acmeState) LoadOrder(ac *acmeContext, userCtx *jwsCtx, orderId string)
return &order, nil
}

// LoadAccountOrders will load all orders for a given account ID, this should be used by the
// management interface only, not through any of the ACME APIs.
func (a *acmeState) LoadAccountOrders(sc *storageContext, accountId string) ([]*acmeOrder, error) {
orderIds, err := a.ListOrderIds(sc, accountId)
if err != nil {
return nil, fmt.Errorf("failed listing order ids for account id %s: %w", accountId, err)
}

var orders []*acmeOrder
for _, orderId := range orderIds {
order, err := a.LoadOrder(&acmeContext{sc: sc}, &jwsCtx{Kid: accountId}, orderId)
if err != nil {
return nil, err
}

orders = append(orders, order)
}

return orders, nil
}

func (a *acmeState) SaveOrder(ac *acmeContext, order *acmeOrder) error {
if order.OrderId == "" {
return fmt.Errorf("invalid order, missing order id")
Expand Down Expand Up @@ -565,15 +594,7 @@ func (a *acmeState) ListOrderIds(sc *storageContext, accountId string) ([]string
return nil, fmt.Errorf("failed listing order ids for account %s: %w", accountId, err)
}

orderIds := []string{}
for _, order := range rawOrderIds {
if strings.HasSuffix(order, "/") {
// skip any folders we might have for some reason
continue
}
orderIds = append(orderIds, order)
}
return orderIds, nil
return filterDirEntries(rawOrderIds), nil
}

type acmeCertEntry struct {
Expand Down Expand Up @@ -672,17 +693,20 @@ func (a *acmeState) ListEabIds(sc *storageContext) ([]string, error) {
if err != nil {
return nil, err
}
var ids []string
for _, entry := range entries {
if strings.HasSuffix(entry, "/") {
continue
}
ids = append(ids, entry)
}
ids := filterDirEntries(entries)

return ids, nil
}

func (a *acmeState) ListAccountIds(sc *storageContext) ([]string, error) {
entries, err := sc.Storage.List(sc.Context, acmeAccountPrefix)
if err != nil {
return nil, fmt.Errorf("failed listing ACME account prefix directory %s: %w", acmeAccountPrefix, err)
}

return filterDirEntries(entries), nil
}

func getAcmeSerialToAccountTrackerPath(accountId string, serial string) string {
return acmeAccountPrefix + accountId + "/certs/" + normalizeSerial(serial)
}
Expand Down
2 changes: 2 additions & 0 deletions builtin/logical/pki/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ func Backend(conf *logical.BackendConfig) *backend {
pathAcmeConfig(&b),
pathAcmeEabList(&b),
pathAcmeEabDelete(&b),
pathAcmeMgmtAccountList(&b),
pathAcmeMgmtAccountRead(&b),
},

Secrets: []*framework.Secret{
Expand Down
8 changes: 7 additions & 1 deletion builtin/logical/pki/backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6831,6 +6831,7 @@ func TestProperAuthing(t *testing.T) {
}
serial := resp.Data["serial_number"].(string)
eabKid := "13b80844-e60d-42d2-b7e9-152a8e834b90"
acmeKeyId := "hrKmDYTvicHoHGVN2-3uzZV_BPGdE0W_dNaqYTtYqeo="
paths := map[string]pathAuthChecker{
"ca_chain": shouldBeUnauthedReadList,
"cert/ca_chain": shouldBeUnauthedReadList,
Expand Down Expand Up @@ -6950,6 +6951,8 @@ func TestProperAuthing(t *testing.T) {
"unified-ocsp/dGVzdAo=": shouldBeUnauthedReadList,
"eab/": shouldBeAuthed,
"eab/" + eabKid: shouldBeAuthed,
"acme/mgmt/account/keyid/": shouldBeAuthed,
"acme/mgmt/account/keyid/" + acmeKeyId: shouldBeAuthed,
}

entPaths := getEntProperAuthingPaths(serial)
Expand Down Expand Up @@ -7020,7 +7023,10 @@ func TestProperAuthing(t *testing.T) {
raw_path = strings.ReplaceAll(raw_path, "{serial}", serial)
}
if strings.Contains(raw_path, "acme/account/") && strings.Contains(raw_path, "{kid}") {
raw_path = strings.ReplaceAll(raw_path, "{kid}", "hrKmDYTvicHoHGVN2-3uzZV_BPGdE0W_dNaqYTtYqeo=")
raw_path = strings.ReplaceAll(raw_path, "{kid}", acmeKeyId)
}
if strings.Contains(raw_path, "acme/mgmt/account/") && strings.Contains(raw_path, "{keyid}") {
raw_path = strings.ReplaceAll(raw_path, "{keyid}", acmeKeyId)
}
if strings.Contains(raw_path, "acme/") && strings.Contains(raw_path, "{auth_id}") {
raw_path = strings.ReplaceAll(raw_path, "{auth_id}", "29da8c38-7a09-465e-b9a6-3d76802b1afd")
Expand Down
226 changes: 226 additions & 0 deletions builtin/logical/pki/path_acme_account_mgmt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

package pki

import (
"context"
"errors"
"fmt"
"strings"
"time"

"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
)

func pathAcmeMgmtAccountList(b *backend) *framework.Path {
return &framework.Path{
Pattern: "acme/mgmt/account/keyid/?$",

Operations: map[logical.Operation]framework.OperationHandler{
logical.ListOperation: &framework.PathOperation{
Callback: b.pathAcmeMgmtListAccounts,
DisplayAttrs: &framework.DisplayAttributes{
OperationPrefix: operationPrefixPKI,
OperationVerb: "list-acme-account-keys",
Description: "List all ACME account key identifiers.",
},
},
},

HelpSynopsis: "List all ACME account key identifiers.",
HelpDescription: `Allows an operator to list all ACME account key identifiers.`,
}
}

func pathAcmeMgmtAccountRead(b *backend) *framework.Path {
return &framework.Path{
Pattern: "acme/mgmt/account/keyid/" + framework.GenericNameRegex("keyid"),
Fields: map[string]*framework.FieldSchema{
"keyid": {
Type: framework.TypeString,
Description: "The key identifier of the account.",
Required: true,
},
"status": {
Type: framework.TypeString,
Description: "The status of the account.",
Required: true,
AllowedValues: []interface{}{AccountStatusValid.String(), AccountStatusRevoked.String()},
},
},
Operations: map[logical.Operation]framework.OperationHandler{
logical.ReadOperation: &framework.PathOperation{
Callback: b.pathAcmeMgmtReadAccount,
DisplayAttrs: &framework.DisplayAttributes{
OperationPrefix: operationPrefixPKI,
OperationSuffix: "acme-key-id",
},
},
logical.UpdateOperation: &framework.PathOperation{
Callback: b.pathAcmeMgmtUpdateAccount,
DisplayAttrs: &framework.DisplayAttributes{
OperationPrefix: operationPrefixPKI,
OperationSuffix: "acme-key-id",
},
},
},

HelpSynopsis: "Fetch the details or update the status of an ACME account by key identifier.",
HelpDescription: `Allows an operator to retrieve details of an ACME account and to update the account status.`,
}
}

func (b *backend) pathAcmeMgmtListAccounts(ctx context.Context, r *logical.Request, d *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, r.Storage)

accountIds, err := b.GetAcmeState().ListAccountIds(sc)
if err != nil {
return nil, err
}

return logical.ListResponse(accountIds), nil
}

func (b *backend) pathAcmeMgmtReadAccount(ctx context.Context, r *logical.Request, d *framework.FieldData) (*logical.Response, error) {
keyId := d.Get("keyid").(string)
if len(keyId) == 0 {
return logical.ErrorResponse("keyid is required"), logical.ErrInvalidRequest
}

sc := b.makeStorageContext(ctx, r.Storage)
as := b.GetAcmeState()

accountEntry, err := as.LoadAccountWithoutDirEnforcement(sc, keyId)
if err != nil {
if errors.Is(err, ErrAccountDoesNotExist) {
return logical.ErrorResponse("ACME key id %s did not exist", keyId), logical.ErrNotFound
}
return nil, fmt.Errorf("failed loading ACME account id %q: %w", keyId, err)
}

orders, err := as.LoadAccountOrders(sc, accountEntry.KeyId)
if err != nil {
return nil, fmt.Errorf("failed loading orders for account %q: %w", accountEntry.KeyId, err)
}

orderData := make([]map[string]interface{}, 0, len(orders))
for _, order := range orders {
orderData = append(orderData, acmeOrderToDataMap(order))
}

dataMap := acmeAccountToDataMap(accountEntry)
dataMap["orders"] = orderData
return &logical.Response{Data: dataMap}, nil
}

func (b *backend) pathAcmeMgmtUpdateAccount(ctx context.Context, r *logical.Request, d *framework.FieldData) (*logical.Response, error) {
keyId := d.Get("keyid").(string)
if len(keyId) == 0 {
return logical.ErrorResponse("keyid is required"), logical.ErrInvalidRequest
}

status, err := convertToAccountStatus(d.Get("status"))
if err != nil {
return logical.ErrorResponse(err.Error()), logical.ErrInvalidRequest
}
if status != AccountStatusValid && status != AccountStatusRevoked {
return logical.ErrorResponse("invalid status %q", status), logical.ErrInvalidRequest
}

sc := b.makeStorageContext(ctx, r.Storage)
as := b.GetAcmeState()

accountEntry, err := as.LoadAccountWithoutDirEnforcement(sc, keyId)
if err != nil {
if errors.Is(err, ErrAccountDoesNotExist) {
return logical.ErrorResponse("ACME key id %q did not exist", keyId), logical.ErrNotFound
}
return nil, fmt.Errorf("failed loading ACME account id %q: %w", keyId, err)
}

if accountEntry.Status != status {
accountEntry.Status = status

switch status {
case AccountStatusRevoked:
accountEntry.AccountRevokedDate = time.Now()
case AccountStatusValid:
accountEntry.AccountRevokedDate = time.Time{}
}

if err := as.UpdateAccount(sc, accountEntry); err != nil {
return nil, fmt.Errorf("failed saving account %q: %w", keyId, err)
}
}

dataMap := acmeAccountToDataMap(accountEntry)
return &logical.Response{Data: dataMap}, nil
}

func convertToAccountStatus(status any) (ACMEAccountStatus, error) {
if status == nil {
return "", fmt.Errorf("status is required")
}

statusStr, ok := status.(string)
if !ok {
return "", fmt.Errorf("status must be a string")
}

switch strings.ToLower(strings.TrimSpace(statusStr)) {
case AccountStatusValid.String():
return AccountStatusValid, nil
case AccountStatusRevoked.String():
return AccountStatusRevoked, nil
case AccountStatusDeactivated.String():
return AccountStatusDeactivated, nil
default:
return "", fmt.Errorf("invalid status %q", statusStr)
}
}

func acmeAccountToDataMap(accountEntry *acmeAccount) map[string]interface{} {
revokedDate := ""
if !accountEntry.AccountRevokedDate.IsZero() {
revokedDate = accountEntry.AccountRevokedDate.Format(time.RFC3339)
}

eab := map[string]string{}
if accountEntry.Eab != nil {
eab["eab_id"] = accountEntry.Eab.KeyID
eab["directory"] = accountEntry.Eab.AcmeDirectory
eab["created_time"] = accountEntry.Eab.CreatedOn.Format(time.RFC3339)
eab["key_type"] = accountEntry.Eab.KeyType
}

return map[string]interface{}{
"key_id": accountEntry.KeyId,
"status": accountEntry.Status,
"contacts": accountEntry.Contact,
"created_time": accountEntry.AccountCreatedDate.Format(time.RFC3339),
"revoked_time": revokedDate,
"directory": accountEntry.AcmeDirectory,
"eab": eab,
}
}

func acmeOrderToDataMap(order *acmeOrder) map[string]interface{} {
identifiers := make([]string, 0, len(order.Identifiers))
for _, identifier := range order.Identifiers {
identifiers = append(identifiers, identifier.Value)
}
var certExpiry string
if !order.CertificateExpiry.IsZero() {
certExpiry = order.CertificateExpiry.Format(time.RFC3339)
}
return map[string]interface{}{
"order_id": order.OrderId,
"status": string(order.Status),
"identifiers": identifiers,
"cert_serial_number": strings.ReplaceAll(order.CertificateSerialNumber, "-", ":"),
"cert_expiry": certExpiry,
"order_expiry": order.Expires.Format(time.RFC3339),
}
}
2 changes: 1 addition & 1 deletion builtin/logical/pki/path_acme_eab.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ a warning that it did not exist.`,
}

type eabType struct {
KeyID string `json:"-"`
KeyID string `json:"key-id"`
KeyType string `json:"key-type"`
PrivateBytes []byte `json:"private-bytes"`
AcmeDirectory string `json:"acme-directory"`
Expand Down
Loading

0 comments on commit adbad9e

Please sign in to comment.