diff --git a/internal/bigmod/nat.go b/internal/bigmod/nat.go index 2cc5f0f..7ac6720 100644 --- a/internal/bigmod/nat.go +++ b/internal/bigmod/nat.go @@ -7,7 +7,6 @@ package bigmod import ( "encoding/binary" "errors" - "math/big" "math/bits" ) @@ -104,26 +103,34 @@ func (x *Nat) reset(n int) *Nat { return x } -// set assigns x = y, optionally resizing x to the appropriate size. -func (x *Nat) Set(y *Nat) *Nat { - x.reset(len(y.limbs)) - copy(x.limbs, y.limbs) - return x -} - -// SetBig assigns x = n, optionally resizing n to the appropriate size. +// resetToBytes assigns x = b, where b is a slice of big-endian bytes, resizing +// n to the appropriate size. // // The announced length of x is set based on the actual bit size of the input, // ignoring leading zeroes. -func (x *Nat) SetBig(n *big.Int) *Nat { - limbs := n.Bits() - x.reset(len(limbs)) - for i := range limbs { - x.limbs[i] = uint(limbs[i]) +func (x *Nat) resetToBytes(b []byte) *Nat { + x.reset((len(b) + _S - 1) / _S) + if err := x.setBytes(b); err != nil { + panic("bigmod: internal error: bad arithmetic") + } + // Trim most significant (trailing in little-endian) zero limbs. + // We assume comparison with zero (but not the branch) is constant time. + for i := len(x.limbs) - 1; i >= 0; i-- { + if x.limbs[i] != 0 { + break + } + x.limbs = x.limbs[:i] } return x } +// set assigns x = y, optionally resizing x to the appropriate size. +func (x *Nat) Set(y *Nat) *Nat { + x.reset(len(y.limbs)) + copy(x.limbs, y.limbs) + return x +} + // Bytes returns x as a zero-extended big-endian byte slice. The size of the // slice will match the size of m. // @@ -152,7 +159,8 @@ func (x *Nat) Bytes(m *Modulus) []byte { // // The output will be resized to the size of m and overwritten. func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { - if err := x.setBytes(b, m); err != nil { + x.resetFor(m) + if err := x.setBytes(b); err != nil { return nil, err } if x.CmpGeq(m.nat) == yes { @@ -167,7 +175,8 @@ func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { // // The output will be resized to the size of m and overwritten. func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) { - if err := x.setBytes(b, m); err != nil { + x.resetFor(m) + if err := x.setBytes(b); err != nil { return nil, err } leading := _W - bitLen(x.limbs[len(x.limbs)-1]) @@ -178,6 +187,19 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) { return x, nil } +// SetOverflowedBytes assigns x = (b mode (m-1)) + 1, where b is a slice of big-endian bytes. +// +// The output will be resized to the size of m and overwritten. +func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat { + mMinusOne := NewNat().Set(m.nat) + mMinusOne.limbs[0]-- // due to m is odd, so we can safely subtract 1 + one := NewNat().resetFor(m) + one.limbs[0] = 1 + x.resetToBytes(b) + x = NewNat().modNat(x, mMinusOne) + return x.Add(one, m) +} + // bigEndianUint returns the contents of buf interpreted as a // big-endian encoded uint value. func bigEndianUint(buf []byte) uint { @@ -187,8 +209,7 @@ func bigEndianUint(buf []byte) uint { return uint(binary.BigEndian.Uint32(buf)) } -func (x *Nat) setBytes(b []byte, m *Modulus) error { - x.resetFor(m) +func (x *Nat) setBytes(b []byte) error { i, k := len(b), 0 for k < len(x.limbs) && i >= _S { x.limbs[k] = bigEndianUint(b[i-_S : i]) @@ -381,18 +402,16 @@ func minusInverseModW(x uint) uint { return -y } -// NewModulusFromBig creates a new Modulus from a [big.Int]. +// NewModulus creates a new Modulus from a slice of big-endian bytes. // -// The Int must be odd. The number of significant bits (and nothing else) is +// The value must be odd. The number of significant bits (and nothing else) is // leaked through timing side-channels. -func NewModulusFromBig(n *big.Int) (*Modulus, error) { - if b := n.Bits(); len(b) == 0 { - return nil, errors.New("modulus must be >= 0") - } else if b[0]&1 != 1 { - return nil, errors.New("modulus must be odd") +func NewModulus(b []byte) (*Modulus, error) { + if len(b) == 0 || b[len(b)-1]&1 != 1 { + return nil, errors.New("modulus must be > 0 and odd") } m := &Modulus{} - m.nat = NewNat().SetBig(n) + m.nat = NewNat().resetToBytes(b) m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1]) m.m0inv = minusInverseModW(m.nat.limbs[0]) m.rr = rr(m) @@ -478,7 +497,7 @@ func (x *Nat) shiftInNat(y uint, m *Nat) *Nat { // // The output will be resized to the size of m and overwritten. func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { - return out.ModNat(x, m.nat) + return out.modNat(x, m.nat) } // Mod calculates out = x mod m. @@ -486,7 +505,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { // This works regardless how large the value of x is. // // The output will be resized to the size of m and overwritten. -func (out *Nat) ModNat(x *Nat, m *Nat) *Nat { +func (out *Nat) modNat(x *Nat, m *Nat) *Nat { out.reset(len(m.limbs)) // Working our way from the most significant to the least significant limb, // we can insert each limb at the least significant position, shifting all @@ -683,7 +702,7 @@ func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { } copy(x.reset(n).limbs, T[n:]) x.maybeSubtractModulus(choice(c), m) - + case 1024 / _W: const n = 1024 / _W // compiler hint T := make([]uint, n*2) diff --git a/internal/bigmod/nat_test.go b/internal/bigmod/nat_test.go index 2821610..18516a6 100644 --- a/internal/bigmod/nat_test.go +++ b/internal/bigmod/nat_test.go @@ -5,6 +5,8 @@ package bigmod import ( + "bytes" + "encoding/hex" "fmt" "math/big" "math/bits" @@ -70,7 +72,7 @@ func TestMontgomeryRoundtrip(t *testing.T) { one.limbs[0] = 1 aPlusOne := new(big.Int).SetBytes(natBytes(a)) aPlusOne.Add(aPlusOne, big.NewInt(1)) - m, _ := NewModulusFromBig(aPlusOne) + m, _ := NewModulus(aPlusOne.Bytes()) monty := new(Nat).Set(a) monty.montgomeryRepresentation(m) aAgain := new(Nat).Set(monty) @@ -310,6 +312,19 @@ func TestExpShort(t *testing.T) { } } +// setBig assigns x = n, optionally resizing n to the appropriate size. +// +// The announced length of x is set based on the actual bit size of the input, +// ignoring leading zeroes. +func (x *Nat) setBig(n *big.Int) *Nat { + limbs := n.Bits() + x.reset(len(limbs)) + for i := range limbs { + x.limbs[i] = uint(limbs[i]) + } + return x +} + // TestMulReductions tests that Mul reduces results equal or slightly greater // than the modulus. Some Montgomery algorithms don't and need extra care to // return correct results. See https://go.dev/issue/13907. @@ -319,19 +334,19 @@ func TestMulReductions(t *testing.T) { b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10) n := new(big.Int).Mul(a, b) - N, _ := NewModulusFromBig(n) - A := NewNat().SetBig(a).ExpandFor(N) - B := NewNat().SetBig(b).ExpandFor(N) + N, _ := NewModulus(n.Bytes()) + A := NewNat().setBig(a).ExpandFor(N) + B := NewNat().setBig(b).ExpandFor(N) if A.Mul(B, N).IsZero() != 1 { t.Error("a * b mod (a * b) != 0") } i := new(big.Int).ModInverse(a, b) - N, _ = NewModulusFromBig(b) - A = NewNat().SetBig(a).ExpandFor(N) - I := NewNat().SetBig(i).ExpandFor(N) - one := NewNat().SetBig(big.NewInt(1)).ExpandFor(N) + N, _ = NewModulus(b.Bytes()) + A = NewNat().setBig(a).ExpandFor(N) + I := NewNat().setBig(i).ExpandFor(N) + one := NewNat().setBig(big.NewInt(1)).ExpandFor(N) if A.Mul(I, N).Equal(one) != 1 { t.Error("a * inv(a) mod b != 1") @@ -345,12 +360,12 @@ func natBytes(n *Nat) []byte { func natFromBytes(b []byte) *Nat { // Must not use Nat.SetBytes as it's used in TestSetBytes. bb := new(big.Int).SetBytes(b) - return NewNat().SetBig(bb) + return NewNat().setBig(bb) } func modulusFromBytes(b []byte) *Modulus { bb := new(big.Int).SetBytes(b) - m, _ := NewModulusFromBig(bb) + m, _ := NewModulus(bb.Bytes()) return m } @@ -359,7 +374,7 @@ func maxModulus(n uint) *Modulus { b := big.NewInt(1) b.Lsh(b, n*_W) b.Sub(b, big.NewInt(1)) - m, _ := NewModulusFromBig(b) + m, _ := NewModulus(b.Bytes()) return m } @@ -483,16 +498,56 @@ func BenchmarkExp(b *testing.B) { } } -func TestNewModFromBigZero(t *testing.T) { - expected := "modulus must be >= 0" - _, err := NewModulusFromBig(big.NewInt(0)) +func TestNewModulus(t *testing.T) { + expected := "modulus must be > 0 and odd" + _, err := NewModulus([]byte{}) if err == nil || err.Error() != expected { - t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected) + t.Errorf("NewModulus(0) got %q, want %q", err, expected) } - - expected = "modulus must be odd" - _, err = NewModulusFromBig(big.NewInt(2)) + _, err = NewModulus([]byte{0}) if err == nil || err.Error() != expected { - t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected) + t.Errorf("NewModulus(0) got %q, want %q", err, expected) + } + _, err = NewModulus([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) + if err == nil || err.Error() != expected { + t.Errorf("NewModulus(0) got %q, want %q", err, expected) + } + _, err = NewModulus([]byte{1, 1, 1, 1, 2}) + if err == nil || err.Error() != expected { + t.Errorf("NewModulus(2) got %q, want %q", err, expected) + } +} + +func TestOverflowedBytes(t *testing.T) { + cases := []string{ + "b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", + "b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf23", + "b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24", + "b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24b640000002a3a6f1", + "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", + "00", + } + mBytes, _ := hex.DecodeString(cases[0]) + m, err := NewModulus(mBytes) + if err != nil { + t.Fatal(err) + } + bigOne := big.NewInt(1) + mBigInt := new(big.Int).SetBytes(mBytes) + mMinusOne := new(big.Int).Sub(mBigInt, bigOne) + + for _, c := range cases { + d, _ := hex.DecodeString(c) + k := new(big.Int).SetBytes(d) + k = new(big.Int).Mod(k, mMinusOne) + k = new(big.Int).Add(k, bigOne) + k = new(big.Int).Mod(k, mBigInt) + + kNat := NewNat().SetOverflowedBytes(d, m) + k2 := new(big.Int).SetBytes(kNat.Bytes(m)) + + if !bytes.Equal(k2.Bytes(), k.Bytes()) { + t.Errorf("%s, expected %x, got %x", c, k.Bytes(), k2.Bytes()) + } } } diff --git a/sm2/sm2.go b/sm2/sm2.go index 94d8efc..df6a682 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -1062,8 +1062,8 @@ func p256() *sm2Curve { func precomputeParams(c *sm2Curve, curve elliptic.Curve) { params := curve.Params() c.curve = curve - c.N, _ = bigmod.NewModulusFromBig(params.N) - c.P, _ = bigmod.NewModulusFromBig(params.P) + c.N, _ = bigmod.NewModulus(params.N.Bytes()) + c.P, _ = bigmod.NewModulus(params.P.Bytes()) c.nMinus2 = new(big.Int).Sub(params.N, big.NewInt(2)).Bytes() c.nMinus1, _ = bigmod.NewNat().SetBytes(new(big.Int).Sub(params.N, big.NewInt(1)).Bytes(), c.N) } diff --git a/sm9/sm9.go b/sm9/sm9.go index 02c842f..a0f097f 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -19,15 +19,14 @@ import ( ) // SM9 ASN.1 format reference: Information security technology - SM9 cryptographic algorithm application specification - -var orderNat, _ = bigmod.NewModulusFromBig(bn256.Order) -var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes() -var bigOne = big.NewInt(1) -var bigOneNat *bigmod.Nat -var orderMinus1 = bigmod.NewNat().SetBig(new(big.Int).Sub(bn256.Order, bigOne)) +var ( + orderMinus2 []byte + orderNat *bigmod.Modulus +) func init() { - bigOneNat, _ = bigmod.NewNat().SetBytes(bigOne.Bytes(), orderNat) + orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes() + orderNat, _ = bigmod.NewModulus(bn256.Order.Bytes()) } type hashMode byte @@ -70,11 +69,7 @@ func hash(z []byte, h hashMode) *bigmod.Nat { md.Write(countBytes[:]) copy(ha[sm3.Size:], md.Sum(nil)) - k := new(big.Int).SetBytes(ha[:40]) - kNat := bigmod.NewNat().SetBig(k) - kNat = bigmod.NewNat().ModNat(kNat, orderMinus1) - kNat.Add(bigOneNat, orderNat) - return kNat + return bigmod.NewNat().SetOverflowedBytes(ha[:40], orderNat) } func hashH1(z []byte) *bigmod.Nat {