Skip to content

Commit

Permalink
Merge pull request #3 from sei-protocol/mj/bigint
Browse files Browse the repository at this point in the history
Use big.Int instead of uint64s
  • Loading branch information
mj850 authored Nov 25, 2024
2 parents 878abf6 + e6dd4ac commit ef88c69
Show file tree
Hide file tree
Showing 13 changed files with 230 additions and 179 deletions.
36 changes: 22 additions & 14 deletions pkg/encryption/aes.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"github.com/ethereum/go-ethereum/crypto/secp256k1"
"io"
"math/big"

"github.com/ethereum/go-ethereum/crypto/secp256k1"

"golang.org/x/crypto/hkdf"
)
Expand Down Expand Up @@ -49,8 +51,14 @@ func GetAESKey(privKey ecdsa.PrivateKey, denom string) ([]byte, error) {
return aesKey, nil
}

// EncryptAESGCM Key must be a len 32 byte array for AES-256
func EncryptAESGCM(value uint64, key []byte) (string, error) {
// EncryptAESGCM encrypts a big.Int value using AES-GCM with a 32-byte key.
// Key must be a len 32 byte array for AES-256
func EncryptAESGCM(value *big.Int, key []byte) (string, error) {
// Validate the key length
if len(key) != 32 {
return "", errors.New("key must be 32 bytes for AES-256")
}

// Create a GCM cipher mode instance
aesgcm, err := getCipher(key)
if err != nil {
Expand All @@ -63,45 +71,45 @@ func EncryptAESGCM(value uint64, key []byte) (string, error) {
return "", err
}

plaintext := make([]byte, 8)
// Convert the integer to []byte using BigEndian or LittleEndian
binary.BigEndian.PutUint64(plaintext, value)
// Serialize the big.Int value as a big-endian byte array
valueBytes := value.Bytes()

// Encrypt the data
ciphertext := aesgcm.Seal(nonce, nonce, plaintext, nil)
ciphertext := aesgcm.Seal(nonce, nonce, valueBytes, nil)

// Encode to Base64 for storage or transmission
return base64.StdEncoding.EncodeToString(ciphertext), nil
}

// DecryptAESGCM Key must be a len 32 byte array for AES-256
func DecryptAESGCM(ciphertextBase64 string, key []byte) (uint64, error) {
func DecryptAESGCM(ciphertextBase64 string, key []byte) (*big.Int, error) {
// Decode the Base64-encoded ciphertext
ciphertext, err := base64.StdEncoding.DecodeString(ciphertextBase64)
if err != nil {
return 0, err
return nil, err
}

// Create a GCM cipher mode instance
aesgcm, err := getCipher(key)
if err != nil {
return 0, err
return nil, err
}

// Extract the nonce
nonceSize := aesgcm.NonceSize()
if len(ciphertext) < nonceSize {
return 0, fmt.Errorf("ciphertext too short")
return nil, fmt.Errorf("ciphertext too short")
}
nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:]

// Decrypt the data
plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return 0, err
return nil, err
}

value := binary.BigEndian.Uint64(plaintext)
// Convert the plaintext (byte array) to a big.Int
value := new(big.Int).SetBytes(plaintext)

return value, nil
}
Expand Down
54 changes: 30 additions & 24 deletions pkg/encryption/aes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@ package encryption

import (
"crypto/ecdsa"
"math/big"
"testing"

"github.com/stretchr/testify/require"
)

const (
TestDenom = "factory/sei1239081236470/testToken"
TestKey = "examplekey12345678901234567890ab"
)

func TestGetAESKey(t *testing.T) {
tests := []struct {
name string
Expand All @@ -19,27 +25,27 @@ func TestGetAESKey(t *testing.T) {
{
name: "Deterministic Key Generation",
privateKey: generateTestKey(t),
denom: "factory/sei1239081236470/testToken",
denom: TestDenom,
expectEqual: true,
},
{
name: "Different Denom (Salt) Generates Different Key",
privateKey: generateTestKey(t),
denom: "factory/sei1239081236470/testToken",
anotherDenom: "factory/sei1239081236470/testToken1",
denom: TestDenom,
anotherDenom: TestDenom + "1",
expectEqual: false,
},
{
name: "Different Denom (Salt) of same length Generates Different Key",
privateKey: generateTestKey(t),
denom: "factory/sei1239081236470/testToken1",
anotherDenom: "factory/sei1239081236470/testToken2",
denom: TestDenom + "1",
anotherDenom: TestDenom + "2",
expectEqual: false,
},
{
name: "Different PrivateKey Generates Different Key",
privateKey: generateTestKey(t),
denom: "factory/sei1239081236470/testTokenN",
denom: TestDenom + "N",
anotherKey: generateTestKey(t),
expectEqual: false,
},
Expand Down Expand Up @@ -74,11 +80,11 @@ func TestGetAESKey(t *testing.T) {

func TestGetAESKey_InvalidInput(t *testing.T) {
// Nil private key
_, err := GetAESKey(*new(ecdsa.PrivateKey), "valid/denom")
_, err := GetAESKey(*new(ecdsa.PrivateKey), TestDenom)
require.Error(t, err, "Should return error for nil private key")

invalidPrivateKey := &ecdsa.PrivateKey{ /* Invalid key data */ }
_, err = GetAESKey(*invalidPrivateKey, "valid/denom")
_, err = GetAESKey(*invalidPrivateKey, TestDenom)
require.Error(t, err, "Should return error for invalid private key")

validPrivateKey := generateTestKey(t)
Expand All @@ -91,48 +97,48 @@ func TestAESEncryptionDecryption(t *testing.T) {
name string
key []byte
anotherKey []byte
value uint64
value *big.Int
expectError bool
decryptWithKey []byte
encryptAgain bool
}{
{
name: "Successful Encryption and Decryption",
key: []byte("examplekey12345678901234567890ab"), // 32 bytes for AES-256
value: 3023,
key: []byte(TestKey), // 32 bytes for AES-256
value: big.NewInt(3023),
expectError: false,
},
{
name: "Encryption Yields Different Ciphertext If Encrypted Again",
key: []byte("examplekey12345678901234567890ab"),
value: 3023,
key: []byte(TestKey),
value: big.NewInt(3023),
encryptAgain: true,
expectError: false,
},
{
name: "Different Key Produces Different Ciphertext",
key: []byte("examplekey12345678901234567890ab"),
key: []byte(TestKey),
anotherKey: []byte("randomkey12345678901234567890abc"), // 32 bytes for AES-256
value: 3023,
value: big.NewInt(3023),
expectError: false,
},
{
name: "Decryption with Wrong Key",
key: []byte("examplekey12345678901234567890ab"),
value: 3023,
key: []byte(TestKey),
value: big.NewInt(3023),
expectError: true,
decryptWithKey: []byte("wrongkey12345678901234567890ab"),
},
{
name: "Edge Case: Zero Value",
key: []byte("examplekey12345678901234567890ab"),
value: 0,
key: []byte(TestKey),
value: big.NewInt(0),
expectError: false,
},
{
name: "Edge Case: Maximum Uint64",
key: []byte("examplekey12345678901234567890ab"),
value: ^uint64(0),
name: "Maximum Uint64",
key: []byte(TestKey),
value: new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil).Sub(new(big.Int).Exp(big.NewInt(2), big.NewInt(256), nil), big.NewInt(1)), // 2^256 - 1
expectError: false,
},
}
Expand Down Expand Up @@ -185,7 +191,7 @@ func TestEncryptAESGCM_InvalidKeyLength(t *testing.T) {
[]byte("thiskeyiswaytoolongforaesgcm"), // Too long
}

value := uint64(1234)
value := big.NewInt(1234)

for _, key := range invalidKeys {
t.Run("InvalidKeyLength", func(t *testing.T) {
Expand All @@ -196,7 +202,7 @@ func TestEncryptAESGCM_InvalidKeyLength(t *testing.T) {
}

func TestDecryptAESGCM_InvalidCiphertext(t *testing.T) {
key := []byte("examplekey12345678901234567890ab")
key := []byte(TestKey)
invalidCiphertexts := [][]byte{
{}, // Empty ciphertext
[]byte("invalidciphertext"),
Expand Down
30 changes: 15 additions & 15 deletions pkg/encryption/elgamal/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package elgamal
import (
crand "crypto/rand"
"fmt"
"math/big"

"github.com/bwesterb/go-ristretto"
"github.com/coinbase/kryptology/pkg/core/curves"
"math/big"
)

type TwistedElGamal struct {
Expand All @@ -32,15 +33,15 @@ func NewTwistedElgamal() *TwistedElGamal {
}

// Encrypt encrypts a message using the public key pk.
func (teg TwistedElGamal) Encrypt(pk curves.Point, message uint64) (*Ciphertext, curves.Scalar, error) {
func (teg TwistedElGamal) Encrypt(pk curves.Point, message *big.Int) (*Ciphertext, curves.Scalar, error) {
// Generate a random scalar r
randomFactor := teg.curve.Scalar.Random(crand.Reader)

return teg.encryptWithRand(pk, message, randomFactor)
}

// EncryptWithRand encrypts a message using the public key pk and a given random factor.
func (teg TwistedElGamal) encryptWithRand(pk curves.Point, message uint64, randomFactor curves.Scalar) (*Ciphertext, curves.Scalar, error) {
func (teg TwistedElGamal) encryptWithRand(pk curves.Point, message *big.Int, randomFactor curves.Scalar) (*Ciphertext, curves.Scalar, error) {
if pk == nil {
return nil, nil, fmt.Errorf("invalid public key")
}
Expand All @@ -54,8 +55,7 @@ func (teg TwistedElGamal) encryptWithRand(pk curves.Point, message uint64, rando
G := teg.GetG()

// Convert message x (big.Int) to a scalar on the elliptic curve
bigIntMessage := new(big.Int).SetUint64(message)
x, _ := teg.curve.Scalar.SetBigInt(bigIntMessage)
x, _ := teg.curve.Scalar.SetBigInt(message)

// Compute the Pedersen commitment: C = r * H + x * G
rH := H.Mul(randomFactor) // r * H
Expand All @@ -72,15 +72,15 @@ func (teg TwistedElGamal) encryptWithRand(pk curves.Point, message uint64, rando
return &ciphertext, randomFactor, nil
}

// Decrypt decrypts the ciphertext ct using the private key sk = s.
// Decrypt decrypts the ciphertext ct using the private key sk = s. It can realistically only decrypt up to a maximum of a uint48.
// MaxBits denotes the maximum size of the decrypted message. The lower this can be set, the faster we can decrypt the message.
func (teg TwistedElGamal) Decrypt(sk curves.Scalar, ct *Ciphertext, maxBits MaxBits) (uint64, error) {
func (teg TwistedElGamal) Decrypt(sk curves.Scalar, ct *Ciphertext, maxBits MaxBits) (*big.Int, error) {
if sk == nil {
return 0, fmt.Errorf("invalid private key")
return nil, fmt.Errorf("invalid private key")
}

if ct == nil || ct.C == nil || ct.D == nil {
return 0, fmt.Errorf("invalid ciphertext")
return nil, fmt.Errorf("invalid ciphertext")
}

G := teg.GetG()
Expand Down Expand Up @@ -122,11 +122,11 @@ func (teg TwistedElGamal) Decrypt(sk curves.Scalar, ct *Ciphertext, maxBits MaxB
continue
}
xComputed := xHiMultiplied + i
return xComputed, nil
return big.NewInt(int64(xComputed)), nil
}
}

return 0, fmt.Errorf("could not find x")
return nil, fmt.Errorf("could not find x")
}

// updateIterMap Helper function to create large maps used by the decryption funciton.
Expand All @@ -151,14 +151,14 @@ func (teg TwistedElGamal) updateIterMap(maxBits MaxBits) {
// DecryptLargeNumber Optimistically decrypt up to a 48 bit number.
// Since creating the map for a 48 bit number takes a large amount of time, we work our way up in hopes that we find
// the answer before having to create the 48 bit map.
func (teg TwistedElGamal) DecryptLargeNumber(sk curves.Scalar, ct *Ciphertext, maxBits MaxBits) (uint64, error) {
func (teg TwistedElGamal) DecryptLargeNumber(sk curves.Scalar, ct *Ciphertext, maxBits MaxBits) (*big.Int, error) {
if maxBits > MaxBits48 {
return 0, fmt.Errorf("maxBits must be at most 48, provided (%d)", maxBits)
return nil, fmt.Errorf("maxBits must be at most 48, provided (%d)", maxBits)
}
values := []MaxBits{MaxBits16, MaxBits32, MaxBits40, MaxBits48}
for _, bits := range values {
if bits > maxBits {
return 0, fmt.Errorf("failed to find value")
return nil, fmt.Errorf("failed to find value")
}

res, err := teg.Decrypt(sk, ct, bits)
Expand All @@ -167,7 +167,7 @@ func (teg TwistedElGamal) DecryptLargeNumber(sk curves.Scalar, ct *Ciphertext, m
}
}

return 0, fmt.Errorf("failed to find value")
return nil, fmt.Errorf("failed to find value")
}

func getCompressedKeyString(key curves.Point) string {
Expand Down
Loading

0 comments on commit ef88c69

Please sign in to comment.