Merge pull request #252 from brave-intl/tlv2-comments

Tlv2 comments

husobee authored Oct 26, 2022
2 parents 6c1ed68 + 111fb32 commit 1f1b586
Showing 7 changed files with 361 additions and 76 deletions.
6 changes: 4 additions & 2 deletions kafka/main.go
@@ -96,9 +96,11 @@ func StartConsumers(providedServer *server.Server, logger *zerolog.Logger) error
failureCount++
continue
}
logger.Info().Msgf("Processing message for topic %s at offset %d", msg.Topic, msg.Offset)
logger.Info().Msgf("Reader Stats: %#v", consumer.Stats())
logger.Debug().Msgf("Processing message for topic %s at offset %d", msg.Topic, msg.Offset)
logger.Debug().Msgf("Reader Stats: %#v", consumer.Stats())
logger.Debug().Msgf("topicMappings: %+v", topicMappings)
for _, topicMapping := range topicMappings {
logger.Debug().Msgf("topic: %+v, topicMapping: %+v", msg.Topic, topicMapping.Topic)
if msg.Topic == topicMapping.Topic {
go func(
msg kafka.Message,
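For context on the new debug lines, a minimal sketch of the dispatch loop they instrument, assuming a simplified `TopicMapping` with a `Topic` name and a `Processor` callback (the real struct in `kafka/main.go` carries more wiring):

```go
// Hypothetical sketch only; TopicMapping is simplified from the real struct.
package kafkasketch

import (
	"github.com/rs/zerolog"
	kafka "github.com/segmentio/kafka-go"
)

// TopicMapping pairs a Kafka topic with the handler that processes its messages.
type TopicMapping struct {
	Topic     string
	Processor func(data []byte, logger *zerolog.Logger) error
}

// dispatch hands a consumed message to every mapping whose topic matches,
// mirroring the loop the new debug logging wraps.
func dispatch(msg kafka.Message, topicMappings []TopicMapping, logger *zerolog.Logger) {
	for _, topicMapping := range topicMappings {
		logger.Debug().Msgf("topic: %+v, topicMapping: %+v", msg.Topic, topicMapping.Topic)
		if msg.Topic == topicMapping.Topic {
			// Process asynchronously so one slow handler does not block the consumer.
			go func(msg kafka.Message, mapping TopicMapping) {
				if err := mapping.Processor(msg.Value, logger); err != nil {
					logger.Error().Err(err).Msg("message processor failed")
				}
			}(msg, topicMapping)
		}
	}
}
```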
49 changes: 42 additions & 7 deletions kafka/signed_blinded_token_issuer_handler.go
@@ -26,13 +26,19 @@ func SignedBlindedTokenIssuerHandler(data []byte, producer *kafka.Writer, server
issuerError = 2
)

log.Debug().Msg("starting blinded token processor")

log.Info().Msg("deserialize signing request")

blindedTokenRequestSet, err := avroSchema.DeserializeSigningRequestSet(bytes.NewReader(data))
if err != nil {
return fmt.Errorf("request %s: failed avro deserialization: %w", blindedTokenRequestSet.Request_id, err)
}

logger := log.With().Str("request_id", blindedTokenRequestSet.Request_id).Logger()

logger.Debug().Msg("processing blinded token request for request_id")

var blindedTokenResults []avroSchema.SigningResultV2
if len(blindedTokenRequestSet.Data) > 1 {
// NOTE: When we start supporting multiple requests we will need to review
@@ -44,6 +50,7 @@ func SignedBlindedTokenIssuerHandler(data []byte, producer *kafka.Writer, server

OUTER:
for _, request := range blindedTokenRequestSet.Data {
logger.Debug().Msgf("processing request: %+v", request)
if request.Blinded_tokens == nil {
logger.Error().Err(errors.New("blinded tokens is empty")).Msg("")
blindedTokenResults = append(blindedTokenResults, avroSchema.SigningResultV2{
@@ -56,6 +63,7 @@ OUTER:
}

// check to see if issuer cohort will overflow
logger.Debug().Msgf("checking request cohort: %+v", request)
if request.Issuer_cohort > math.MaxInt16 || request.Issuer_cohort < math.MinInt16 {
logger.Error().Msg("invalid cohort")
blindedTokenResults = append(blindedTokenResults, avroSchema.SigningResultV2{
@@ -67,6 +75,7 @@ OUTER:
break OUTER
}

logger.Debug().Msgf("getting latest issuer: %+v - %+v", request.Issuer_type, request.Issuer_cohort)
issuer, appErr := server.GetLatestIssuer(request.Issuer_type, int16(request.Issuer_cohort))
if appErr != nil {
logger.Error().Err(appErr).Msg("error retrieving issuer")
@@ -79,6 +88,7 @@ OUTER:
break OUTER
}

logger.Debug().Msgf("checking if issuer is version 3: %+v", issuer)
// if this is a time aware issuer, make sure the request contains the appropriate number of blinded tokens
if issuer.Version == 3 && issuer.Buffer > 0 {
if len(request.Blinded_tokens)%(issuer.Buffer+issuer.Overlap) != 0 {
@@ -93,10 +103,12 @@ OUTER:
}
}

logger.Debug().Msgf("checking blinded tokens: %+v", request.Blinded_tokens)
var blindedTokens []*crypto.BlindedToken
// Iterate over the provided tokens and create data structure from them,
// grouping into a slice for approval
for _, stringBlindedToken := range request.Blinded_tokens {
logger.Debug().Msgf("blinded token: %+v", stringBlindedToken)
blindedToken := crypto.BlindedToken{}
err := blindedToken.UnmarshalText([]byte(stringBlindedToken))
if err != nil {
@@ -113,24 +125,38 @@ OUTER:
blindedTokens = append(blindedTokens, &blindedToken)
}

logger.Debug().Msgf("checking if issuer is time aware: %+v - %+v", issuer.Version, issuer.Buffer)
// if the issuer is time aware, we need to approve tokens
if issuer.Version == 3 && issuer.Buffer > 0 {
// number of tokens per signing key
// Calculate the number of tokens per signing key.
// Given the mod check above, the total number of tokens in the request is an exact multiple of this.
var numT = len(request.Blinded_tokens) / (issuer.Buffer + issuer.Overlap)
// sign tokens with all the keys in buffer+overlap
for i := issuer.Buffer + issuer.Overlap; i > 0; i-- {
count := 0
for i := 0; i < len(blindedTokens); i += numT {
count++

logger.Debug().Msgf("version 3 issuer: %+v , numT: %+v", issuer, numT)
var (
blindedTokensSlice []*crypto.BlindedToken
signingKey *crypto.SigningKey
validFrom string
validTo string
)

signingKey = issuer.Keys[len(issuer.Keys)-i].SigningKey
validFrom = issuer.Keys[len(issuer.Keys)-i].StartAt.Format(time.RFC3339)
validTo = issuer.Keys[len(issuer.Keys)-i].EndAt.Format(time.RFC3339)
signingKey = issuer.Keys[len(issuer.Keys)-count].SigningKey
validFrom = issuer.Keys[len(issuer.Keys)-count].StartAt.Format(time.RFC3339)
validTo = issuer.Keys[len(issuer.Keys)-count].EndAt.Format(time.RFC3339)

blindedTokensSlice = blindedTokens[(i - numT):i]
// Calculate the next step size to retrieve. Given previous checks end should never
// be greater than the total number of tokens.
end := i + numT
if end > len(blindedTokens) {
return fmt.Errorf("request %s: error invalid token step length",
blindedTokenRequestSet.Request_id)
}

// Get the next group of tokens and approve
blindedTokensSlice = blindedTokens[i:end]
signedTokens, DLEQProof, err := btd.ApproveTokens(blindedTokensSlice, signingKey)
if err != nil {
// @TODO: If one token fails they will all fail. Assess this behavior
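The core of this hunk is how a version 3 (time-aware) issuer's request is split across its signing keys: the blinded tokens are divided into `buffer + overlap` equal groups, and each group is signed with one key, counting back from the end of the key list. A minimal sketch of that slicing, with placeholder `Token`, `Key`, and `signBatch` standing in for the real crypto types and `btd.ApproveTokens`:

```go
// Sketch of the buffer+overlap token grouping; Token, Key and signBatch are
// placeholders for crypto.BlindedToken, the issuer key record and btd.ApproveTokens.
package sketch

import "fmt"

type Token string

type Key struct{ ID int }

func signBatch(tokens []Token, key Key) error { return nil }

// signWithBufferedKeys splits the request into buffer+overlap equal groups and
// signs each group with one key, counting back from the end of the key list,
// as the loop above does with issuer.Keys[len(issuer.Keys)-count].
func signWithBufferedKeys(tokens []Token, keys []Key, buffer, overlap int) error {
	// The request must carry an exact multiple of buffer+overlap tokens
	// (the mod check earlier in the handler).
	if len(tokens)%(buffer+overlap) != 0 {
		return fmt.Errorf("token count %d is not a multiple of %d", len(tokens), buffer+overlap)
	}
	numT := len(tokens) / (buffer + overlap) // tokens per signing key

	count := 0
	for i := 0; i < len(tokens); i += numT {
		count++
		// The real code also assumes at least buffer+overlap keys exist.
		key := keys[len(keys)-count]

		end := i + numT
		if end > len(tokens) {
			return fmt.Errorf("invalid token step length")
		}
		if err := signBatch(tokens[i:end], key); err != nil {
			return err
		}
	}
	return nil
}
```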
@@ -145,6 +171,8 @@ OUTER:
break OUTER
}

logger.Debug().Msg("marshalling proof")

marshaledDLEQProof, err := DLEQProof.MarshalText()
if err != nil {
return fmt.Errorf("request %s: could not marshal dleq proof: %w", blindedTokenRequestSet.Request_id, err)
@@ -170,6 +198,7 @@ OUTER:
marshaledSignedTokens = append(marshaledSignedTokens, string(marshaledToken[:]))
}

logger.Debug().Msg("getting public key")
publicKey := signingKey.PublicKey()
marshaledPublicKey, err := publicKey.MarshalText()
if err != nil {
@@ -195,6 +224,7 @@ OUTER:
signingKey = issuer.Keys[len(issuer.Keys)-1].SigningKey
}

logger.Debug().Msgf("approving tokens: %+v", blindedTokens)
// @TODO: If one token fails they will all fail. Assess this behavior
signedTokens, DLEQProof, err := btd.ApproveTokens(blindedTokens, signingKey)
if err != nil {
@@ -256,6 +286,7 @@ OUTER:
Request_id: blindedTokenRequestSet.Request_id,
Data: blindedTokenResults,
}
logger.Debug().Msgf("resultSet: %+v", resultSet)

var resultSetBuffer bytes.Buffer
err = resultSet.Serialize(&resultSetBuffer)
@@ -264,11 +295,15 @@ OUTER:
blindedTokenRequestSet.Request_id, resultSetBuffer.String(), err)
}

logger.Debug().Msg("ending blinded token request processor loop")
logger.Debug().Msgf("about to emit: %+v", resultSet)
err = Emit(producer, resultSetBuffer.Bytes(), log)
if err != nil {
logger.Error().Msgf("failed to emit: %+v", resultSet)
return fmt.Errorf("request %s: failed to emit results to topic %s: %w",
blindedTokenRequestSet.Request_id, producer.Topic, err)
}
logger.Debug().Msgf("emitted: %+v", resultSet)

return nil
}
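`Emit` itself is not part of this diff; assuming it wraps a segmentio/kafka-go `Writer` (the same package that provides `kafka.Writer` above), a hypothetical sketch of the final publish step could look like:

```go
// Hypothetical sketch: Emit is not shown in this diff, so this only assumes it
// ends up handing the serialized Avro payload to kafka-go's Writer.
package sketch

import (
	"context"

	"github.com/rs/zerolog"
	kafka "github.com/segmentio/kafka-go"
)

// emit writes an already-serialized result set to the producer's configured topic.
func emit(producer *kafka.Writer, message []byte, logger *zerolog.Logger) error {
	logger.Debug().Msgf("emitting %d bytes to topic %s", len(message), producer.Topic)
	return producer.WriteMessages(context.Background(), kafka.Message{Value: message})
}
```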
1 change: 1 addition & 0 deletions main.go
@@ -25,6 +25,7 @@ func main() {
if os.Getenv("ENV") != "production" {
zerolog.SetGlobalLevel(zerolog.TraceLevel)
}
zerolog.SetGlobalLevel(zerolog.TraceLevel)

srv := *server.DefaultServer

9 changes: 9 additions & 0 deletions server/cron.go
@@ -27,5 +27,14 @@ func (c *Server) SetupCronTasks() {
}); err != nil {
panic(err)
}
if _, err := cron.AddFunc(cadence, func() {
rows, err := c.deleteIssuerKeys("P1M")
if err != nil {
panic(err)
}
c.Logger.Infof("cron: delete issuers keys removed %d", rows)
}); err != nil {
panic(err)
}
cron.Start()
}
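A self-contained sketch of this registration, assuming robfig/cron v3 (which matches the `cron.AddFunc(cadence, func())` shape used here) and a placeholder in place of `Server.deleteIssuerKeys`; the "P1M" argument is an ISO 8601 duration string that the delete query later casts to a Postgres interval:

```go
// Sketch assuming robfig/cron v3; deleteIssuerKeys below is a placeholder for
// the Server.deleteIssuerKeys method added in server/db.go.
package main

import (
	"log"

	"github.com/robfig/cron/v3"
)

func deleteIssuerKeys(duration string) (int64, error) {
	// Placeholder: the real method deletes v3 issuer keys older than the
	// given ISO 8601 duration and returns the number of rows removed.
	return 0, nil
}

func main() {
	c := cron.New()
	cadence := "@hourly" // placeholder; the real cadence comes from configuration

	if _, err := c.AddFunc(cadence, func() {
		rows, err := deleteIssuerKeys("P1M") // drop keys that ended more than a month ago
		if err != nil {
			panic(err)
		}
		log.Printf("cron: delete issuer keys removed %d", rows)
	}); err != nil {
		panic(err)
	}
	c.Start()
}
```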
106 changes: 87 additions & 19 deletions server/db.go
@@ -1,6 +1,7 @@
package server

import (
"context"
"database/sql"
"errors"
"fmt"
@@ -118,7 +119,7 @@ type RedemptionV2 struct {
TTL int64 `json:"TTL"`
}

// CacheInterface cach functions
// CacheInterface cache functions
type CacheInterface interface {
Get(k string) (interface{}, bool)
Delete(k string)
@@ -397,6 +398,56 @@ func (c *Server) fetchIssuersByCohort(issuerType string, issuerCohort int16) (*[
return &issuers, nil
}

func (c *Server) fetchIssuerByType(ctx context.Context, issuerType string) (*Issuer, error) {
if c.caches != nil {
if cached, found := c.caches["issuer"].Get(issuerType); found {
// TODO: check this
return cached.(*Issuer), nil
}
}

var issuerV3 issuer
err := c.db.GetContext(ctx, &issuerV3,
`SELECT *
FROM v3_issuers
WHERE issuer_type=$1
ORDER BY expires_at DESC NULLS LAST, created_at DESC`, issuerType)
if err != nil {
return nil, err
}

convertedIssuer, err := c.convertDBIssuer(issuerV3)
if err != nil {
return nil, err
}

if convertedIssuer.Keys == nil {
convertedIssuer.Keys = []IssuerKeys{}
}

var fetchIssuerKeys []issuerKeys
err = c.db.SelectContext(ctx, &fetchIssuerKeys, `SELECT * FROM v3_issuer_keys where issuer_id=$1
ORDER BY end_at DESC NULLS LAST, start_at DESC`, issuerV3.ID)
if err != nil {
return nil, err
}

for _, v := range fetchIssuerKeys {
k, err := c.convertDBIssuerKeys(v)
if err != nil {
c.Logger.Error("Failed to convert issuer keys from DB")
return nil, err
}
convertedIssuer.Keys = append(convertedIssuer.Keys, *k)
}

if c.caches != nil {
c.caches["issuer"].SetDefault(issuerType, issuerV3)
}

return convertedIssuer, nil
}

func (c *Server) fetchIssuers(issuerType string) (*[]Issuer, error) {
if c.caches != nil {
if cached, found := c.caches["issuers"].Get(issuerType); found {
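`fetchIssuerByType` follows the cache-aside pattern used by the other fetchers: serve from the in-memory cache when possible, otherwise query Postgres and populate the cache (the `TODO: check this` above flags the type of the cached value). A reduced sketch of that flow, assuming a go-cache style `Get`/`SetDefault` store — the diff's `CacheInterface` has the same shape — and a placeholder for the DB lookup and conversion:

```go
// Sketch of the cache-aside flow in fetchIssuerByType; Issuer and loadIssuer
// are placeholders, and patrickmn/go-cache is assumed only because the diff's
// CacheInterface (Get/Delete/SetDefault) matches its API shape.
package sketch

import (
	"context"
	"time"

	gocache "github.com/patrickmn/go-cache"
)

// Issuer is a trimmed stand-in for the server's converted issuer type.
type Issuer struct {
	IssuerType string
}

// loadIssuer is a placeholder for the v3_issuers/v3_issuer_keys queries plus
// convertDBIssuer / convertDBIssuerKeys.
func loadIssuer(ctx context.Context, issuerType string) (*Issuer, error) {
	return &Issuer{IssuerType: issuerType}, nil
}

// fetchIssuerByType serves from cache when possible, otherwise loads from the
// database and caches the converted result for the default TTL.
func fetchIssuerByType(ctx context.Context, cache *gocache.Cache, issuerType string) (*Issuer, error) {
	if cached, found := cache.Get(issuerType); found {
		return cached.(*Issuer), nil
	}

	issuer, err := loadIssuer(ctx, issuerType)
	if err != nil {
		return nil, err
	}

	cache.SetDefault(issuerType, issuer)
	return issuer, nil
}

// Example cache construction: 5 minute default TTL, 10 minute cleanup interval.
var issuerCache = gocache.New(5*time.Minute, 10*time.Minute)
```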
@@ -564,14 +615,14 @@ func (c *Server) rotateIssuers() error {
err = tx.Commit()
}()

fetchedIssuers := []issuer{}
var fetchedIssuers []issuer
err = tx.Select(
&fetchedIssuers,
`SELECT * FROM v3_issuers
WHERE expires_at IS NOT NULL
AND last_rotated_at < NOW() - $1 * INTERVAL '1 day'
AND expires_at < NOW() + $1 * INTERVAL '1 day'
AND version >= 2
AND version <= 2
FOR UPDATE SKIP LOCKED`, cfg.DefaultDaysBeforeExpiry,
)
if err != nil {
@@ -619,26 +670,24 @@ func (c *Server) rotateIssuersV3() error {

fetchedIssuers := []issuer{}

// we need to get all of the v3 issuers that
// 1. are not expired
// we need to get all the v3 issuers that are
// 1. not expired
// 2. now is after valid_from
// 3. have max(issuer_v3.end_at) < buffer

err = tx.Select(
&fetchedIssuers,
`
select
i.issuer_id, i.issuer_type, i.issuer_cohort, i.max_tokens, i.version,
i.buffer, i.valid_from, i.last_rotated_at, i.expires_at, i.duration,
i.created_at
from
v3_issuers i
join v3_issuer_keys ik on (ik.issuer_id = i.issuer_id)
where
i.version = 3
and i.expires_at is not null and i.expires_at < now()
and greatest(ik.end_at) < now() + i.buffer * i.duration::interval
for update skip locked
select
i.issuer_id, i.issuer_type, i.issuer_cohort, i.max_tokens, i.version,i.buffer, i.valid_from, i.last_rotated_at, i.expires_at, i.duration,i.created_at
from
v3_issuers i
where
i.version = 3 and
i.expires_at is not null and
i.expires_at < now()
and (select max(end_at) from v3_issuer_keys where issuer_id=i.issuer_id) < now() + i.buffer * i.duration::interval
for update skip locked
`,
)
if err != nil {
@@ -669,6 +718,21 @@ func (c *Server) rotateIssuersV3() error {
return nil
}

// deleteIssuerKeys deletes v3 issuer keys whose end_at is more than the given duration in the past.
func (c *Server) deleteIssuerKeys(duration string) (int64, error) {
result, err := c.db.Exec(`delete from v3_issuer_keys where issuer_id in (select issuer_id from v3_issuers where version = 3) and end_at < now() - $1::interval`, duration)
if err != nil {
return 0, fmt.Errorf("error deleting v3 issuer keys: %w", err)
}

rows, err := result.RowsAffected()
if err != nil {
return 0, fmt.Errorf("error deleting v3 issuer keys row affected: %w", err)
}

return rows, nil
}

// createIssuer - creation of a v3 issuer
func (c *Server) createV3Issuer(issuer Issuer) error {
defer incrementCounter(createIssuerCounter)
@@ -731,9 +795,11 @@ func txPopulateIssuerKeys(logger *logrus.Logger, tx *sqlx.Tx, issuer Issuer) err
err error
)

logger.Debug("checking if v3")
if issuer.Version == 3 {
// get the duration from the issuer
if issuer.Duration != nil {
logger.Debug("making sure duration is not nil")
duration, err = timeutils.ParseDuration(*issuer.Duration)
if err != nil {
return fmt.Errorf("failed to parse issuer duration: %w", err)
@@ -762,13 +828,14 @@ func txPopulateIssuerKeys(logger *logrus.Logger, tx *sqlx.Tx, issuer Issuer) err
start = &tmp
i = len(issuer.Keys)
}
logger.Debug("about to make the issuer keys")

valueFmtStr := ""

var keys []issuerKeys
var position = 0
// for i in buffer, create signing keys for each
for ; i < issuer.Buffer; i++ {
// Create signing keys for buffer and overlap
for ; i < issuer.Buffer+issuer.Overlap; i++ {
end := new(time.Time)
if duration != nil {
// start/end, increment every iteration
@@ -799,6 +866,7 @@ func txPopulateIssuerKeys(logger *logrus.Logger, tx *sqlx.Tx, issuer Issuer) err
tx.Rollback()
return err
}
logger.Infof("iteration key pubkey: %+v", pubKeyTxt)

tmpStart := *start
tmpEnd := *end
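`txPopulateIssuerKeys` now creates signing keys for `buffer + overlap` windows instead of just `buffer`, each window starting where the previous one ends. A rough sketch of that window generation, using a plain `time.Duration` in place of the repo's `timeutils` ISO 8601 duration parsing:

```go
// Sketch of the key-window generation in txPopulateIssuerKeys; a plain
// time.Duration stands in for the ISO 8601 duration parsed by timeutils.
package main

import (
	"fmt"
	"time"
)

// keyWindow is a trimmed signing-key record: just its validity window.
type keyWindow struct {
	StartAt time.Time
	EndAt   time.Time
}

// makeKeyWindows creates buffer+overlap consecutive windows starting at
// validFrom, each one duration long, mirroring the loop that now runs to
// issuer.Buffer+issuer.Overlap instead of issuer.Buffer.
func makeKeyWindows(validFrom time.Time, duration time.Duration, buffer, overlap int) []keyWindow {
	keys := make([]keyWindow, 0, buffer+overlap)
	start := validFrom
	for i := 0; i < buffer+overlap; i++ {
		end := start.Add(duration) // start/end advance by one duration per iteration
		keys = append(keys, keyWindow{StartAt: start, EndAt: end})
		start = end
	}
	return keys
}

func main() {
	// Example: a daily-rotating issuer with a buffer of 2 and an overlap of 1.
	for _, k := range makeKeyWindows(time.Now().UTC(), 24*time.Hour, 2, 1) {
		fmt.Printf("key valid %s -> %s\n",
			k.StartAt.Format(time.RFC3339), k.EndAt.Format(time.RFC3339))
	}
}
```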