diff --git a/go.mod b/go.mod index a94a4e2f43..e532f4d475 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( require ( github.com/SaveTheRbtz/mph v0.1.2 github.com/k0kubun/pp v3.0.1+incompatible + github.com/onflow/crypto v0.25.0 golang.org/x/exp v0.0.0-20230321023759-10a507213a29 ) @@ -55,5 +56,6 @@ require ( golang.org/x/lint v0.0.0-20200302205851-738671d3881b // indirect golang.org/x/sys v0.6.0 // indirect golang.org/x/term v0.6.0 // indirect + gonum.org/v1/gonum v0.6.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index df32a1ac3a..178998405a 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,6 @@ github.com/SaveTheRbtz/mph v0.1.2 h1:5l3W496Up+7BNOVJQnJhzcGBh+wWfxWdmPUAkx3WmaM= github.com/SaveTheRbtz/mph v0.1.2/go.mod h1:V4+WtKQPe2+dEA5os1WnGsEB0NR9qgqqgIiSt73+sT4= +github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= github.com/bits-and-blooms/bitset v1.5.0 h1:NpE8frKRLGHIcEzkR+gZhiioW1+WbYV6fKwD6ZIpQT8= github.com/bits-and-blooms/bitset v1.5.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= github.com/bytecodealliance/wasmtime-go/v7 v7.0.0 h1:/rBNjgFju2HCZnkPb1eL+W4GBwP8DMbaQu7i+GR9DH4= @@ -13,11 +14,14 @@ github.com/dave/jennifer v1.5.0 h1:HmgPN93bVDpkQyYbqhCHj5QlgvUkvEOzMyEvKLgCRrg= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/fxamacker/cbor/v2 v2.4.1-0.20230228173756-c0c9f774e40c h1:5tm/Wbs9d9r+qZaUFXk59CWDD0+77PBqDREffYkyi5c= github.com/fxamacker/cbor/v2 v2.4.1-0.20230228173756-c0c9f774e40c/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/fxamacker/circlehash v0.3.0 h1:XKdvTtIJV9t7DDUtsf0RIpC1OcxZtPbmgIH7ekx28WA= github.com/fxamacker/circlehash v0.3.0/go.mod h1:3aq3OfVvsWtkWMb6A1owjOQFA+TLsD5FgJflnaQwtMM= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88 h1:uC1QfSlInpQF+M0ao65imhwqKnz3Q2z/d8PWZRMQvDM= github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= github.com/k0kubun/go-ansi v0.0.0-20180517002512-3bf9e2903213/go.mod h1:vNUNkEQ1e29fT/6vq2aBdFsgNPmy8qMdSay1npru+Sw= @@ -60,6 +64,8 @@ github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0 github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/onflow/atree v0.6.1-0.20230711151834-86040b30171f h1:Z8/PgTqOgOg02MTRpTBYO2k16FE6z4wEOtaC2WBR9Xo= github.com/onflow/atree v0.6.1-0.20230711151834-86040b30171f/go.mod h1:xvP61FoOs95K7IYdIYRnNcYQGf4nbF/uuJ0tHf4DRuM= +github.com/onflow/crypto v0.25.0 h1:BeWbLsh3ZD13Ej+Uky6kg1PL1ZIVBDVX+2MVBNwqddg= +github.com/onflow/crypto v0.25.0/go.mod h1:C8FbaX0x8y+FxWjbkHy0Q4EASCDR9bSPWZqlpCLYyVI= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/term v1.2.0-beta.2 h1:L3y/h2jkuBVFdWiJvNfYfKmzcCnILw7mJWm2JQuMppw= github.com/pkg/term v1.2.0-beta.2/go.mod h1:E25nymQcrSllhX42Ok8MRm1+hyBdHY0dCeiKZ9jpNGw= @@ -103,8 +109,12 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug= golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= +golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= @@ -134,6 +144,8 @@ golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= @@ -143,6 +155,12 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= +gonum.org/v1/gonum v0.6.1 h1:/LSrTrgZtpbXyAR6+0e152SROCkJJSh7goYWVmdPFGc= +gonum.org/v1/gonum v0.6.1/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= +gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b h1:QRR6H1YWRnHb4Y/HeNFCTJLFVxaq6wH4YuVdsUOr75U= @@ -150,3 +168,4 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= lukechampine.com/blake3 v1.1.7 h1:GgRMhmdsuK8+ii6UZFDL8Nb+VyMwadAgcJyfYHxG6n0= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index 2d713abda7..9a2daa5ce8 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -19,11 +19,11 @@ package runtime_test import ( - "encoding/binary" + "crypto/rand" "encoding/hex" "errors" "fmt" - "math" + "math/big" "sync" "sync/atomic" "testing" @@ -4523,139 +4523,229 @@ func TestRuntimeBlock(t *testing.T) { ) } -func TestRuntimeRandomWithUnsafeRandom(t *testing.T) { +func TestRuntimeRandom(t *testing.T) { t.Parallel() - runtime := NewTestInterpreterRuntime() - - script := []byte(` + transactionSource := ` transaction { prepare() { - let rand1 = revertibleRandom() - log(rand1) + let rand = revertibleRandom<%[1]s>(%[2]s) + log(rand) } } - `) - - var loggedMessages []string + ` + scriptSource := ` + access(all) fun main(): %[1]s { + let rand = revertibleRandom<%[1]s>(%[2]s) + return rand + } + ` - runtimeInterface := &TestRuntimeInterface{ - OnReadRandom: func(buffer []byte) error { - binary.BigEndian.PutUint64(buffer, 7558174677681708339) - return nil - }, - OnProgramLog: func(message string) { - loggedMessages = append(loggedMessages, message) - }, + // read randoms from `crypto/rand` when the random values + // do not matter in the rest of the test + readCryptoRandom := func(buffer []byte) error { + // random value does not matter in this test + _, err := rand.Read(buffer) + return err } - nextTransactionLocation := NewTransactionLocationGenerator() + executeScript := func( + ty sema.Type, + moduloArgument string, + randomGenerator func(buffer []byte) error, + ) (cadence.Value, error) { - err := runtime.ExecuteTransaction( - Script{ - Source: script, - }, - Context{ - Interface: runtimeInterface, - Location: nextTransactionLocation(), - }, - ) - require.NoError(t, err) + nextScriptLocation := NewScriptLocationGenerator() + runtime := NewTestInterpreterRuntime() - assert.Equal(t, - []string{ - "7558174677681708339", - }, - loggedMessages, - ) -} + if moduloArgument != "" { + // example "modulo: UInt8(77)" + moduloArgument = fmt.Sprintf("modulo: %s(%s)", ty.String(), moduloArgument) + } + return runtime.ExecuteScript( + Script{ + Source: []byte( + fmt.Sprintf(scriptSource, + ty.String(), + moduloArgument, + )), + }, + Context{ + Interface: &TestRuntimeInterface{ + OnReadRandom: randomGenerator, + }, + Location: nextScriptLocation(), + }, + ) + } + + testTypes := func(t *testing.T, testType func(*testing.T, sema.Type)) { + for _, ty := range sema.AllFixedSizeUnsignedIntegerTypes { + tyCopy := ty + t.Run(ty.String(), func(t *testing.T) { + t.Parallel() -func getFixedSizeUnsignedIntegerForSemaType(ty sema.Type) cadence.Value { - switch ty { - case sema.UInt8Type: - return cadence.NewUInt8(math.MaxUint8) - case sema.UInt16Type: - return cadence.NewUInt16(math.MaxUint16) - case sema.UInt32Type: - return cadence.NewUInt32(math.MaxUint32) - case sema.UInt64Type: - return cadence.NewUInt64(math.MaxUint64) - case sema.UInt128Type: - value, _ := cadence.NewUInt128FromBig(sema.UInt128TypeMaxIntBig) - return value - case sema.UInt256Type: - value, _ := cadence.NewUInt256FromBig(sema.UInt256TypeMaxIntBig) - return value - - case sema.Word8Type: - return cadence.NewWord8(math.MaxUint8) - case sema.Word16Type: - return cadence.NewWord16(math.MaxUint16) - case sema.Word32Type: - return cadence.NewWord32(math.MaxUint32) - case sema.Word64Type: - return cadence.NewWord64(math.MaxUint64) - case sema.Word128Type: - value, _ := cadence.NewWord128FromBig(sema.Word128TypeMaxIntBig) - return value - case sema.Word256Type: - value, _ := cadence.NewWord256FromBig(sema.Word256TypeMaxIntBig) - return value + testType(t, tyCopy) + }) + } } - panic(fmt.Sprintf("Broken test. Trying to get fixed size unsigned integer for ty: %s", ty)) -} + typeToBytes := func(t *testing.T, ty sema.Type) int { + require.IsType(t, &sema.NumericType{}, ty) + return ty.(*sema.NumericType).ByteSize() + } -func TestRuntimeRandom(t *testing.T) { + newRandBuffer := func(t *testing.T) []byte { + // `randBuffer` is the random source + randBuffer := make([]byte, 32) + _, err := rand.Read(randBuffer) + require.NoError(t, err) - t.Parallel() + return randBuffer + } - script := ` - access(all) fun main(): %[1]s { - let rand = revertibleRandom<%[1]s>() - return rand + newReadFromBuffer := func(readBuffer []byte) func(buffer []byte) error { + return func(buffer []byte) error { + // randoms are read from the random source + copy(buffer, readBuffer) + return nil } - ` + } - runValidCaseWithoutModulo := func(t *testing.T, ty sema.Type) { - t.Run(ty.String(), func(t *testing.T) { - t.Parallel() + // test based on a transaction, all other tests are script-based - test all types + t.Run("transaction without modulo", func(t *testing.T) { + t.Parallel() - nextScriptLocation := NewScriptLocationGenerator() + runValidCaseWithoutModulo := func(t *testing.T, ty sema.Type) { + + randBuffer := newRandBuffer(t) + + var loggedMessage string + runtimeInterface := &TestRuntimeInterface{ + OnReadRandom: newReadFromBuffer(randBuffer), + OnProgramLog: func(message string) { + loggedMessage = message + }, + } runtime := NewTestInterpreterRuntime() - value, err := runtime.ExecuteScript( + + nextTransactionLocation := NewTransactionLocationGenerator() + err := runtime.ExecuteTransaction( Script{ - Source: []byte(fmt.Sprintf(script, ty.String())), + Source: []byte( + fmt.Sprintf(transactionSource, ty.String(), ""), + ), }, Context{ - Interface: &TestRuntimeInterface{ - OnReadRandom: func(buffer []byte) error { - for i := 0; i < len(buffer); i++ { - buffer[i] = 0xff - } - return nil - }, - }, - Location: nextScriptLocation(), + Interface: runtimeInterface, + Location: nextTransactionLocation(), }, ) require.NoError(t, err) - require.Equal(t, getFixedSizeUnsignedIntegerForSemaType(ty), value) + // prepare the expected value from the random source + expected := new(big.Int).SetBytes(randBuffer[:typeToBytes(t, ty)]) + assert.Equal(t, expected.String(), loggedMessage) + } + testTypes(t, runValidCaseWithoutModulo) + }) + + // no modulo is passed - test all types + t.Run("script without modulo", func(t *testing.T) { + t.Parallel() + + runValidCaseWithoutModulo := func(t *testing.T, ty sema.Type) { + randBuffer := newRandBuffer(t) + + value, err := executeScript(ty, "", newReadFromBuffer(randBuffer)) + require.NoError(t, err) + // prepare the expected value from the random source + expected := new(big.Int).SetBytes(randBuffer[:typeToBytes(t, ty)]) + assert.Equal(t, expected.String(), value.String()) + } + testTypes(t, runValidCaseWithoutModulo) + }) + + // random modulo is passed as the modulo argument - test all types + t.Run("script with modulo all types", func(t *testing.T) { + t.Parallel() + + runValidCaseWithModulo := func(t *testing.T, ty sema.Type) { + moduloBuffer := newRandBuffer(t) + + // build a big Int from the modulo buffer, with the required `ty` size + // big.Int are used as they cover all the tested types including the small ones (UInt8 ..) + modulo := new(big.Int).SetBytes(moduloBuffer[:typeToBytes(t, ty)]) + + value, err := executeScript(ty, modulo.String(), readCryptoRandom) + require.NoError(t, err) + // convert `value` to big Int for comparison + valueBig, ok := new(big.Int).SetString(value.String(), 10) + require.True(t, ok) + // check that modulo > value + require.Equal(t, 1, modulo.Cmp(valueBig)) + } + testTypes(t, runValidCaseWithModulo) + }) + + // test valid edge cases of the value modulo - test all types + t.Run("script with modulo edge cases all types", func(t *testing.T) { + + t.Run("max modulo", func(t *testing.T) { + t.Parallel() + + // case where modulo is the max value of the type + runValidCaseWithMaxModulo := func(t *testing.T, ty sema.Type) { + + // set modulo to the max value of the type: (1 << bitSize) - 1 + // big.Int are used as they cover all the tested types including the small ones (UInt8 ..) + bitSize := typeToBytes(t, ty) << 3 + one := big.NewInt(1) + modulo := new(big.Int).Lsh(one, uint(bitSize)) + modulo.Sub(modulo, one) + + value, err := executeScript(ty, modulo.String(), readCryptoRandom) + require.NoError(t, err) + // convert `value` to big Int for comparison + valueBig, ok := new(big.Int).SetString(value.String(), 10) + require.True(t, ok) + // check that modulo > value + require.Equal(t, 1, modulo.Cmp(valueBig)) + } + testTypes(t, runValidCaseWithMaxModulo) }) - } - for _, ty := range sema.AllFixedSizeUnsignedIntegerTypes { - switch ty { - case sema.FixedSizeUnsignedIntegerType: - continue + t.Run("one modulo", func(t *testing.T) { + t.Parallel() + + // case where modulo is 1 and expected value in 0 + runValidCaseWithOneModulo := func(t *testing.T, ty sema.Type) { + // set modulo to 1 + value, err := executeScript(ty, "1", readCryptoRandom) + require.NoError(t, err) + + // check that value is zero + require.Equal(t, "0", value.String()) + } - default: - runValidCaseWithoutModulo(t, ty) + testTypes(t, runValidCaseWithOneModulo) + }) + }) + + // function should error if zero is used as modulo - test all types + t.Run("script with zero modulo", func(t *testing.T) { + t.Parallel() + + runCaseWithZeroModulo := func(t *testing.T, ty sema.Type) { + // set modulo to "0" + _, err := executeScript(ty, "0", readCryptoRandom) + assertUserError(t, err) + require.ErrorContains(t, err, stdlib.ZeroModuloError.Error()) } - } + testTypes(t, runCaseWithZeroModulo) + }) } func TestRuntimeTransactionTopLevelDeclarations(t *testing.T) { diff --git a/runtime/sema/type.go b/runtime/sema/type.go index 9419b7f759..d850907bb3 100644 --- a/runtime/sema/type.go +++ b/runtime/sema/type.go @@ -1159,6 +1159,7 @@ type SaturatingArithmeticSupport struct { type NumericType struct { minInt *big.Int maxInt *big.Int + byteSize int memberResolvers map[string]MemberResolver name string tag TypeTag @@ -1190,6 +1191,11 @@ func (t *NumericType) WithIntRange(min *big.Int, max *big.Int) *NumericType { return t } +func (t *NumericType) WithByteSize(size int) *NumericType { + t.byteSize = size + return t +} + func (t *NumericType) WithSaturatingFunctions(saturatingArithmetic SaturatingArithmeticSupport) *NumericType { t.saturatingArithmetic = saturatingArithmetic @@ -1291,6 +1297,10 @@ func (t *NumericType) MaxInt() *big.Int { return t.maxInt } +func (t *NumericType) ByteSize() int { + return t.byteSize +} + func (*NumericType) Unify( _ Type, _ *TypeParameterTypeOrderedMap, @@ -1599,6 +1609,7 @@ var ( Int8Type = NewNumericType(Int8TypeName). WithTag(Int8TypeTag). WithIntRange(Int8TypeMinInt, Int8TypeMaxInt). + WithByteSize(1). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1612,6 +1623,7 @@ var ( Int16Type = NewNumericType(Int16TypeName). WithTag(Int16TypeTag). WithIntRange(Int16TypeMinInt, Int16TypeMaxInt). + WithByteSize(2). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1625,6 +1637,7 @@ var ( Int32Type = NewNumericType(Int32TypeName). WithTag(Int32TypeTag). WithIntRange(Int32TypeMinInt, Int32TypeMaxInt). + WithByteSize(4). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1638,6 +1651,7 @@ var ( Int64Type = NewNumericType(Int64TypeName). WithTag(Int64TypeTag). WithIntRange(Int64TypeMinInt, Int64TypeMaxInt). + WithByteSize(8). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1651,6 +1665,7 @@ var ( Int128Type = NewNumericType(Int128TypeName). WithTag(Int128TypeTag). WithIntRange(Int128TypeMinIntBig, Int128TypeMaxIntBig). + WithByteSize(16). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1664,6 +1679,7 @@ var ( Int256Type = NewNumericType(Int256TypeName). WithTag(Int256TypeTag). WithIntRange(Int256TypeMinIntBig, Int256TypeMaxIntBig). + WithByteSize(32). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1688,6 +1704,7 @@ var ( UInt8Type = NewNumericType(UInt8TypeName). WithTag(UInt8TypeTag). WithIntRange(UInt8TypeMinInt, UInt8TypeMaxInt). + WithByteSize(1). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1701,6 +1718,7 @@ var ( UInt16Type = NewNumericType(UInt16TypeName). WithTag(UInt16TypeTag). WithIntRange(UInt16TypeMinInt, UInt16TypeMaxInt). + WithByteSize(2). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1714,6 +1732,7 @@ var ( UInt32Type = NewNumericType(UInt32TypeName). WithTag(UInt32TypeTag). WithIntRange(UInt32TypeMinInt, UInt32TypeMaxInt). + WithByteSize(4). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1727,6 +1746,7 @@ var ( UInt64Type = NewNumericType(UInt64TypeName). WithTag(UInt64TypeTag). WithIntRange(UInt64TypeMinInt, UInt64TypeMaxInt). + WithByteSize(8). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1740,6 +1760,7 @@ var ( UInt128Type = NewNumericType(UInt128TypeName). WithTag(UInt128TypeTag). WithIntRange(UInt128TypeMinIntBig, UInt128TypeMaxIntBig). + WithByteSize(16). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1753,6 +1774,7 @@ var ( UInt256Type = NewNumericType(UInt256TypeName). WithTag(UInt256TypeTag). WithIntRange(UInt256TypeMinIntBig, UInt256TypeMaxIntBig). + WithByteSize(32). WithSaturatingFunctions(SaturatingArithmeticSupport{ Add: true, Subtract: true, @@ -1765,6 +1787,7 @@ var ( // which does NOT check for overflow and underflow Word8Type = NewNumericType(Word8TypeName). WithTag(Word8TypeTag). + WithByteSize(1). WithIntRange(Word8TypeMinInt, Word8TypeMaxInt) Word8TypeAnnotation = NewTypeAnnotation(Word8Type) @@ -1773,6 +1796,7 @@ var ( // which does NOT check for overflow and underflow Word16Type = NewNumericType(Word16TypeName). WithTag(Word16TypeTag). + WithByteSize(2). WithIntRange(Word16TypeMinInt, Word16TypeMaxInt) Word16TypeAnnotation = NewTypeAnnotation(Word16Type) @@ -1781,6 +1805,7 @@ var ( // which does NOT check for overflow and underflow Word32Type = NewNumericType(Word32TypeName). WithTag(Word32TypeTag). + WithByteSize(4). WithIntRange(Word32TypeMinInt, Word32TypeMaxInt) Word32TypeAnnotation = NewTypeAnnotation(Word32Type) @@ -1789,6 +1814,7 @@ var ( // which does NOT check for overflow and underflow Word64Type = NewNumericType(Word64TypeName). WithTag(Word64TypeTag). + WithByteSize(8). WithIntRange(Word64TypeMinInt, Word64TypeMaxInt) Word64TypeAnnotation = NewTypeAnnotation(Word64Type) @@ -1797,6 +1823,7 @@ var ( // which does NOT check for overflow and underflow Word128Type = NewNumericType(Word128TypeName). WithTag(Word128TypeTag). + WithByteSize(16). WithIntRange(Word128TypeMinIntBig, Word128TypeMaxIntBig) Word128TypeAnnotation = NewTypeAnnotation(Word128Type) @@ -1805,6 +1832,7 @@ var ( // which does NOT check for overflow and underflow Word256Type = NewNumericType(Word256TypeName). WithTag(Word256TypeTag). + WithByteSize(32). WithIntRange(Word256TypeMinIntBig, Word256TypeMaxIntBig) Word256TypeAnnotation = NewTypeAnnotation(Word256Type) diff --git a/runtime/stdlib/random.go b/runtime/stdlib/random.go index 76a813f889..46e6c821d9 100644 --- a/runtime/stdlib/random.go +++ b/runtime/stdlib/random.go @@ -22,6 +22,7 @@ import ( "encoding/binary" "math/big" + "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/errors" "github.com/onflow/cadence/runtime/interpreter" "github.com/onflow/cadence/runtime/sema" @@ -69,20 +70,18 @@ type RandomGenerator interface { ReadRandom([]byte) error } -func getRandomBytes(generator RandomGenerator, numBytes int) []byte { - buffer := make([]byte, numBytes) - +func getRandomBytes(buffer []byte, generator RandomGenerator) { var err error errors.WrapPanic(func() { - err = generator.ReadRandom(buffer[:]) + err = generator.ReadRandom(buffer) }) if err != nil { panic(interpreter.WrappedExternalError(err)) } - - return buffer } +var ZeroModuloError = errors.NewDefaultUserError("modulo argument cannot be zero") + func NewRevertibleRandomFunction(generator RandomGenerator) StandardLibraryValue { return NewStandardLibraryFunction( "revertibleRandom", @@ -91,107 +90,343 @@ func NewRevertibleRandomFunction(generator RandomGenerator) StandardLibraryValue func(invocation interpreter.Invocation) interpreter.Value { inter := invocation.Interpreter - // TODO: Check if invocation has an argument and implement modulo operation. - returnIntegerType := invocation.TypeParameterTypes.Oldest().Value - switch returnIntegerType { - // UInt* - case sema.UInt8Type: - return interpreter.NewUInt8Value( - inter, - func() uint8 { - return getRandomBytes(generator, 1)[0] - }, - ) - case sema.UInt16Type: - return interpreter.NewUInt16Value( - inter, - func() uint16 { - return binary.BigEndian.Uint16(getRandomBytes(generator, 2)) - }, - ) - case sema.UInt32Type: - return interpreter.NewUInt32Value( - inter, - func() uint32 { - return binary.BigEndian.Uint32(getRandomBytes(generator, 4)) - }, - ) - case sema.UInt64Type: - return interpreter.NewUInt64Value( - inter, - func() uint64 { - return binary.BigEndian.Uint64(getRandomBytes(generator, 8)) - }, - ) - case sema.UInt128Type: - return interpreter.NewUInt128ValueFromBigInt( - inter, - func() *big.Int { - buffer := getRandomBytes(generator, 16) - return interpreter.BigEndianBytesToUnsignedBigInt(buffer) - }, - ) - case sema.UInt256Type: - return interpreter.NewUInt256ValueFromBigInt( - inter, - func() *big.Int { - buffer := getRandomBytes(generator, 32) - return interpreter.BigEndianBytesToUnsignedBigInt(buffer) - }, - ) - - // Word* - case sema.Word8Type: - return interpreter.NewWord8Value( - inter, - func() uint8 { - return getRandomBytes(generator, 1)[0] - }, - ) - case sema.Word16Type: - return interpreter.NewWord16Value( - inter, - func() uint16 { - return binary.BigEndian.Uint16(getRandomBytes(generator, 2)) - }, - ) - case sema.Word32Type: - return interpreter.NewWord32Value( - inter, - func() uint32 { - return binary.BigEndian.Uint32(getRandomBytes(generator, 4)) - }, - ) - case sema.Word64Type: - return interpreter.NewWord64Value( - inter, - func() uint64 { - return binary.BigEndian.Uint64(getRandomBytes(generator, 8)) - }, - ) - case sema.Word128Type: - return interpreter.NewWord128ValueFromBigInt( - inter, - func() *big.Int { - buffer := getRandomBytes(generator, 16) - return interpreter.BigEndianBytesToUnsignedBigInt(buffer) - }, - ) - case sema.Word256Type: - return interpreter.NewWord256ValueFromBigInt( - inter, - func() *big.Int { - buffer := getRandomBytes(generator, 32) - return interpreter.BigEndianBytesToUnsignedBigInt(buffer) - }, - ) - - default: - // Checker should prevent this. - panic(errors.NewUnreachableError()) + // arguments should be 0 or 1 at this point + var moduloValue interpreter.Value + if len(invocation.Arguments) == 1 { + moduloValue = invocation.Arguments[0] } + + return RevertibleRandom( + generator, + inter, + returnIntegerType, + moduloValue, + ) }, ) } + +func RevertibleRandom( + generator RandomGenerator, + memoryGauge common.MemoryGauge, + returnIntegerType sema.Type, + moduloValue interpreter.Value, +) interpreter.Value { + switch returnIntegerType { + // UInt* + case sema.UInt8Type: + randomUint64 := getUint64RandomNumber(generator, returnIntegerType, moduloValue) + return interpreter.NewUInt8Value( + memoryGauge, + func() uint8 { + return uint8(randomUint64) + }, + ) + case sema.UInt16Type: + randomUint64 := getUint64RandomNumber(generator, returnIntegerType, moduloValue) + return interpreter.NewUInt16Value( + memoryGauge, + func() uint16 { + return uint16(randomUint64) + }, + ) + case sema.UInt32Type: + randomUint64 := getUint64RandomNumber(generator, returnIntegerType, moduloValue) + return interpreter.NewUInt32Value( + memoryGauge, + func() uint32 { + return uint32(randomUint64) + }, + ) + case sema.UInt64Type: + randomUint64 := getUint64RandomNumber(generator, returnIntegerType, moduloValue) + return interpreter.NewUInt64Value( + memoryGauge, + func() uint64 { + return randomUint64 + }, + ) + case sema.UInt128Type: + randomBig := getBigRandomNumber(generator, memoryGauge, returnIntegerType, moduloValue) + return interpreter.NewUInt128ValueFromBigInt( + memoryGauge, + func() *big.Int { + return randomBig + }, + ) + case sema.UInt256Type: + randomBig := getBigRandomNumber(generator, memoryGauge, returnIntegerType, moduloValue) + return interpreter.NewUInt256ValueFromBigInt( + memoryGauge, + func() *big.Int { + return randomBig + }, + ) + + // Word* + case sema.Word8Type: + randomUint64 := getUint64RandomNumber(generator, returnIntegerType, moduloValue) + return interpreter.NewWord8Value( + memoryGauge, + func() uint8 { + return uint8(randomUint64) + }, + ) + case sema.Word16Type: + randomUint64 := getUint64RandomNumber(generator, returnIntegerType, moduloValue) + return interpreter.NewWord16Value( + memoryGauge, + func() uint16 { + return uint16(randomUint64) + }, + ) + case sema.Word32Type: + randomUint64 := getUint64RandomNumber(generator, returnIntegerType, moduloValue) + return interpreter.NewWord32Value( + memoryGauge, + func() uint32 { + return uint32(randomUint64) + }, + ) + case sema.Word64Type: + randomUint64 := getUint64RandomNumber(generator, returnIntegerType, moduloValue) + return interpreter.NewWord64Value( + memoryGauge, + func() uint64 { + return randomUint64 + }, + ) + case sema.Word128Type: + randomBig := getBigRandomNumber(generator, memoryGauge, returnIntegerType, moduloValue) + return interpreter.NewWord128ValueFromBigInt( + memoryGauge, + func() *big.Int { + return randomBig + }, + ) + case sema.Word256Type: + randomBig := getBigRandomNumber(generator, memoryGauge, returnIntegerType, moduloValue) + return interpreter.NewWord256ValueFromBigInt( + memoryGauge, + func() *big.Int { + return randomBig + }, + ) + + default: + // Checker should prevent this. + panic(errors.NewUnreachableError()) + } +} + +// cases of a random number of size 8 bytes or less can be all treated +// by the same function, based on the uint64 type. +// Although the final output is a `uint64`, it can be safely +// casted into the desired output type because the extra bytes are guaranteed +// to be zeros. +func getUint64RandomNumber( + generator RandomGenerator, + ty sema.Type, + moduloArg interpreter.Value, +) uint64 { + + // buffer to get random bytes from the generator + // 8 is the size of the largest type supported, it is also the size needed for + // the `binary.BigEndian.Uint64` call + const bufferSize = 8 + var buffer [bufferSize]byte + + // case where no modulo argument was provided + if moduloArg == nil { + numericType, ok := ty.(*sema.NumericType) + if !ok { + // checker should prevent this + panic(errors.NewUnreachableError()) + } + bytes := numericType.ByteSize() + getRandomBytes(buffer[bufferSize-bytes:], generator) + return binary.BigEndian.Uint64(buffer[:]) + } + + var ok bool + var modulo uint64 + + switch ty { + case sema.UInt8Type: + var moduloVal interpreter.UInt8Value + moduloVal, ok = moduloArg.(interpreter.UInt8Value) + modulo = uint64(moduloVal) + case sema.UInt16Type: + var moduloVal interpreter.UInt16Value + moduloVal, ok = moduloArg.(interpreter.UInt16Value) + modulo = uint64(moduloVal) + case sema.UInt32Type: + var moduloVal interpreter.UInt32Value + moduloVal, ok = moduloArg.(interpreter.UInt32Value) + modulo = uint64(moduloVal) + case sema.UInt64Type: + var moduloVal interpreter.UInt64Value + moduloVal, ok = moduloArg.(interpreter.UInt64Value) + modulo = uint64(moduloVal) + case sema.Word8Type: + var moduloVal interpreter.Word8Value + moduloVal, ok = moduloArg.(interpreter.Word8Value) + modulo = uint64(moduloVal) + case sema.Word16Type: + var moduloVal interpreter.Word16Value + moduloVal, ok = moduloArg.(interpreter.Word16Value) + modulo = uint64(moduloVal) + case sema.Word32Type: + var moduloVal interpreter.Word32Value + moduloVal, ok = moduloArg.(interpreter.Word32Value) + modulo = uint64(moduloVal) + case sema.Word64Type: + var moduloVal interpreter.Word64Value + moduloVal, ok = moduloArg.(interpreter.Word64Value) + modulo = uint64(moduloVal) + default: + // sanity check: shouldn't reach here + panic(errors.NewUnreachableError()) + } + + if !ok { + // checker should prevent this + panic(errors.NewUnreachableError()) + } + + // user error if modulo is zero + if modulo == 0 { + panic(ZeroModuloError) + } + + // `max` is the maximum value that can be returned + max := modulo - 1 + // get a bit mask (0b11..11) that covers all `max` bits, + // and count the byte size of `max` + mask := uint64(0) + bitSize := 0 + for max&mask != max { + bitSize++ + mask = (mask << 1) | 1 + } + byteSize := (bitSize + 7) >> 3 + + // Generate a number less or equal than `max`. + // use the reject-sample method to avoid the modulo bias. + // the function isn't constant-time in this case and may take longer than computing + // a modular reduction. + // However, sampling exactly the size of `max` in bits makes the loop return fast: + // the probability of the loop running for (k) iterations is at most (1/2)^k. + // + // (a different approach would be to pull 128 bits more bits than the size of `max` + // from the random generator and use big number reduction by `modulo`) + for { + // only generate `byteSize` random bytes + getRandomBytes(buffer[bufferSize-byteSize:], generator) + // big endianness must be used in this case + random := binary.BigEndian.Uint64(buffer[:]) + // truncate to the bit size of `max` + random &= mask + if random <= max { + return random + } + } +} + +// cases of a random number of size larger than 8 bytes can be all treated +// by the same function, based on the big.Int type. +func getBigRandomNumber( + generator RandomGenerator, + gauge common.MemoryGauge, + ty sema.Type, + moduloArg interpreter.Value, +) *big.Int { + + // get the numeric type byte size + numericType, ok := ty.(*sema.NumericType) + if !ok { + // checker should prevent this + panic(errors.NewUnreachableError()) + } + bytes := numericType.ByteSize() + // buffer to get random bytes from the generator + common.UseMemory(gauge, common.NewBytesMemoryUsage(bytes)) + buffer := make([]byte, bytes) + + // case where no modulo argument was provided + if moduloArg == nil { + getRandomBytes(buffer, generator) + // SetBytes considers big endianness (although little endian could be used too) + common.UseMemory(gauge, common.NewBigIntMemoryUsage(len(buffer))) + return new(big.Int).SetBytes(buffer) + } + + var modulo *big.Int + switch ty { + case sema.UInt128Type: + var moduloVal interpreter.UInt128Value + moduloVal, ok = moduloArg.(interpreter.UInt128Value) + modulo = moduloVal.BigInt + case sema.UInt256Type: + var moduloVal interpreter.UInt256Value + moduloVal, ok = moduloArg.(interpreter.UInt256Value) + modulo = moduloVal.BigInt + case sema.Word128Type: + var moduloVal interpreter.Word128Value + moduloVal, ok = moduloArg.(interpreter.Word128Value) + modulo = moduloVal.BigInt + case sema.Word256Type: + var moduloVal interpreter.Word256Value + moduloVal, ok = moduloArg.(interpreter.Word256Value) + modulo = moduloVal.BigInt + default: + // sanity check: shouldn't reach here + panic(errors.NewUnreachableError()) + } + + if !ok { + // checker should prevent this + panic(errors.NewUnreachableError()) + } + + // user error if modulo is zero + if modulo.Sign() == 0 { + panic(ZeroModuloError) + } + + // `max` is the maximum value that can be returned (modulo - 1) + one := big.NewInt(1) + max := new(big.Int).Sub(modulo, one) + // count the byte size of `max` + bitSize := max.BitLen() + byteSize := (bitSize + 7) >> 3 + // get a bit mask (0b11..11) that covers all `max`'s bits: + // `mask` can be computed as: (1 << bitSize) -1 + mask := new(big.Int).Lsh(one, uint(bitSize)) + mask.Sub(mask, one) + + // Generate a number less or equal than `max` + // use the reject-sample method to avoid the modulo bias. + // the function isn't constant-time in this case and may take longer than computing + // a modular reduction. + // However, sampling exactly the size of `max` in bits makes the loop return fast: + // the probability of the loop running for (k) iterations is at most (1/2)^k. + // + // (a different approach would be to pull 128 bits more bits than the size of `max` + // from the random generator and use big number reduction by `modulo`) + common.UseMemory(gauge, common.NewBigIntMemoryUsage(byteSize)) + random := new(big.Int) + for { + // only generate `byteSize` random bytes + getRandomBytes(buffer[:byteSize], generator) + // big endianness is used for consistency (but little can be used too) + random.SetBytes(buffer[:byteSize]) + // truncate to the bit size of `max` + random.And(random, mask) + if random.Cmp(max) <= 0 { + return random + } + } +} diff --git a/runtime/stdlib/random_test.go b/runtime/stdlib/random_test.go new file mode 100644 index 0000000000..1f2c0af529 --- /dev/null +++ b/runtime/stdlib/random_test.go @@ -0,0 +1,109 @@ +/* + * Cadence - The resource-oriented smart contract programming language + * + * Copyright Dapper Labs, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package stdlib + +import ( + "crypto/rand" + "strconv" + "testing" + + "github.com/onflow/crypto/random" + "github.com/stretchr/testify/require" + + "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" +) + +type testCryptRandomGenerator struct{} + +var _ RandomGenerator = testCryptRandomGenerator{} + +func (t testCryptRandomGenerator) ReadRandom(buffer []byte) error { + _, err := rand.Read(buffer) + return err +} + +// TestRandomBasicUniformityWithModulo is a sanity statistical test +// to make sure the random numbers less than modulo are uniform in [0,modulo-1]. +// The test requires the original random source (here `crypto/rand`) to be uniform. +// The test uses the same small values for all types: +// one is a power of 2 and the other is not. +func TestRandomBasicUniformityWithModulo(t *testing.T) { + + t.Parallel() + + if testing.Short() { + // skipped because the test is slow + t.Skip() + } + + testTypes := func(t *testing.T, testType func(*testing.T, sema.Type)) { + for _, ty := range sema.AllFixedSizeUnsignedIntegerTypes { + tyCopy := ty + t.Run(ty.String(), func(t *testing.T) { + t.Parallel() + + testType(t, tyCopy) + }) + } + } + + // dummy interpreter, just use for ConvertAndBox + inter := newInterpreter(t, ``) + + runStatisticsWithModulo := func(modulo int) func(*testing.T, sema.Type) { + return func(t *testing.T, ty sema.Type) { + // make sure modulo fits in 8 bits + require.Less(t, modulo, 1<<8) + + moduloValue := inter.ConvertAndBox( + interpreter.EmptyLocationRange, + interpreter.NewUnmeteredUIntValueFromUint64(uint64(modulo)), + sema.UIntType, + ty, + ) + + f := func() (uint64, error) { + + value := RevertibleRandom( + testCryptRandomGenerator{}, + nil, + ty, + moduloValue, + ) + + return strconv.ParseUint(value.String(), 10, 8) + } + + random.BasicDistributionTest(t, uint64(modulo), 1, f) + } + } + + t.Run("power of 2 (that fits in 8 bits)", func(t *testing.T) { + t.Parallel() + + testTypes(t, runStatisticsWithModulo(64)) + }) + + t.Run("non-power of 2", func(t *testing.T) { + t.Parallel() + + testTypes(t, runStatisticsWithModulo(71)) + }) +}