Skip to content

Commit

Permalink
Merge branch 'master' into feat/binopt
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-cooper authored Jan 20, 2025
2 parents ef035cc + 7136eab commit 8642223
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 5 deletions.
42 changes: 41 additions & 1 deletion tests/functional/builtins/codegen/test_ecrecover.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import contextlib

from eth_account import Account
from eth_account._utils.signing import to_bytes32

from tests.utils import ZERO_ADDRESS
from tests.utils import ZERO_ADDRESS, check_precompile_asserts
from vyper.compiler.settings import OptimizationLevel


def test_ecrecover_test(get_contract):
Expand Down Expand Up @@ -86,3 +89,40 @@ def test_ecrecover() -> bool:
"""
c = get_contract(code)
assert c.test_ecrecover() is True


def test_ecrecover_oog_handling(env, get_contract, tx_failed, optimize, experimental_codegen):
# GHSA-vgf2-gvx8-xwc3
code = """
@external
@view
def do_ecrecover(hash: bytes32, v: uint256, r:uint256, s:uint256) -> address:
return ecrecover(hash, v, r, s)
"""
check_precompile_asserts(code)

c = get_contract(code)

h = b"\x35" * 32
local_account = Account.from_key(b"\x46" * 32)
sig = local_account.signHash(h)
v, r, s = sig.v, sig.r, sig.s

assert c.do_ecrecover(h, v, r, s) == local_account.address

gas_used = env.last_result.gas_used

if optimize == OptimizationLevel.NONE and not experimental_codegen:
# if optimizations are off, enough gas is used by the contract
# that the gas provided to ecrecover (63/64ths rule) is enough
# for it to succeed
ctx = contextlib.nullcontext
else:
# in other cases, the gas forwarded is small enough for ecrecover
# to fail with oog, which we handle by reverting.
ctx = tx_failed

with ctx():
# provide enough spare gas for the top-level call to not oog but
# not enough for ecrecover to succeed
c.do_ecrecover(h, v, r, s, gas=gas_used)
60 changes: 59 additions & 1 deletion tests/functional/codegen/types/test_dynamic_array.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import contextlib
import itertools
from typing import Any, Callable

import pytest

from tests.utils import decimal_to_int
from tests.utils import check_precompile_asserts, decimal_to_int
from vyper.compiler import compile_code
from vyper.evm.opcodes import version_check
from vyper.exceptions import (
ArgumentException,
ArrayIndexException,
Expand Down Expand Up @@ -1909,3 +1911,59 @@ def foo():
c = get_contract(code)
with tx_failed():
c.foo()


def test_dynarray_copy_oog(env, get_contract, tx_failed):
# GHSA-vgf2-gvx8-xwc3
code = """
@external
def foo(a: DynArray[uint256, 4000]) -> uint256:
b: DynArray[uint256, 4000] = a
return b[0]
"""
check_precompile_asserts(code)

c = get_contract(code)
dynarray = [2] * 4000
assert c.foo(dynarray) == 2

gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed

with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(dynarray, gas=gas_used)


def test_dynarray_copy_oog2(env, get_contract, tx_failed):
# GHSA-vgf2-gvx8-xwc3
code = """
@external
@view
def foo(x: String[1000000], y: String[1000000]) -> DynArray[String[1000000], 2]:
z: DynArray[String[1000000], 2] = [x, y]
# Some code
return z
"""
check_precompile_asserts(code)

c = get_contract(code)
calldata0 = "a" * 10
calldata1 = "b" * 1000000
assert c.foo(calldata0, calldata1) == [calldata0, calldata1]

gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed

with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(calldata0, calldata1, gas=gas_used)
76 changes: 75 additions & 1 deletion tests/functional/codegen/types/test_lists.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import contextlib
import itertools

import pytest

from tests.utils import decimal_to_int
from tests.evm_backends.base_env import EvmError
from tests.utils import check_precompile_asserts, decimal_to_int
from vyper.compiler.settings import OptimizationLevel
from vyper.evm.opcodes import version_check
from vyper.exceptions import ArrayIndexException, OverflowException, TypeMismatch


Expand Down Expand Up @@ -848,3 +852,73 @@ def foo() -> {return_type}:
return MY_CONSTANT[0][0]
"""
assert_compile_failed(lambda: get_contract(code), TypeMismatch)


def test_array_copy_oog(env, get_contract, tx_failed, optimize, experimental_codegen, request):
# GHSA-vgf2-gvx8-xwc3
code = """
@internal
def bar(x: uint256[3000]) -> uint256[3000]:
a: uint256[3000] = x
return a
@external
def foo(x: uint256[3000]) -> uint256:
s: uint256[3000] = self.bar(x)
return s[0]
"""
check_precompile_asserts(code)

if optimize == OptimizationLevel.NONE and not experimental_codegen:
# fails in bytecode generation due to jumpdests too large
with pytest.raises(AssertionError):
get_contract(code)
return

c = get_contract(code)
array = [2] * 3000
assert c.foo(array) == array[0]

# get the minimum gas for the contract complete execution
gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed
with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(array, gas=gas_used)


def test_array_copy_oog2(env, get_contract, tx_failed, optimize, experimental_codegen, request):
# GHSA-vgf2-gvx8-xwc3
code = """
@external
def foo(x: uint256[2500]) -> uint256:
s: uint256[2500] = x
t: uint256[2500] = s
return t[0]
"""
check_precompile_asserts(code)

if optimize == OptimizationLevel.NONE and not experimental_codegen:
# fails in creating contract due to code too large
with tx_failed(EvmError):
get_contract(code)
return

c = get_contract(code)
array = [2] * 2500
assert c.foo(array) == array[0]

# get the minimum gas for the contract complete execution
gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed
with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(array, gas=gas_used)
58 changes: 58 additions & 0 deletions tests/functional/codegen/types/test_string.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import contextlib

import pytest

from tests.utils import check_precompile_asserts
from vyper.evm.opcodes import version_check


def test_string_return(get_contract):
code = """
Expand Down Expand Up @@ -359,3 +364,56 @@ def compare_var_storage_not_equal_false() -> bool:
assert c.compare_var_storage_equal_false() is False
assert c.compare_var_storage_not_equal_true() is True
assert c.compare_var_storage_not_equal_false() is False


def test_string_copy_oog(env, get_contract, tx_failed):
# GHSA-vgf2-gvx8-xwc3
code = """
@external
@view
def foo(x: String[1000000]) -> String[1000000]:
return x
"""
check_precompile_asserts(code)

c = get_contract(code)
calldata = "a" * 1000000
assert c.foo(calldata) == calldata

gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed

with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(calldata, gas=gas_used)


def test_string_copy_oog2(env, get_contract, tx_failed):
# GHSA-vgf2-gvx8-xwc3
code = """
@external
@view
def foo(x: String[1000000]) -> uint256:
y: String[1000000] = x
return len(y)
"""
check_precompile_asserts(code)

c = get_contract(code)
calldata = "a" * 1000000
assert c.foo(calldata) == len(calldata)

gas_used = env.last_result.gas_used
if version_check(begin="cancun"):
ctx = contextlib.nullcontext
else:
ctx = tx_failed

with ctx():
# depends on EVM version. pre-cancun, will revert due to checking
# success flag from identity precompile.
c.foo(calldata, gas=gas_used)
22 changes: 22 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

from vyper import ast as vy_ast
from vyper.compiler.phases import CompilerData
from vyper.semantics.analysis.constant_folding import constant_fold
from vyper.utils import DECIMAL_EPSILON, round_towards_zero

Expand All @@ -28,3 +29,24 @@ def parse_and_fold(source_code):
def decimal_to_int(*args):
s = decimal.Decimal(*args)
return round_towards_zero(s / DECIMAL_EPSILON)


def check_precompile_asserts(source_code):
# common sanity check for some tests, that calls to precompiles
# are correctly wrapped in an assert.

compiler_data = CompilerData(source_code)
deploy_ir = compiler_data.ir_nodes
runtime_ir = compiler_data.ir_runtime

def _check(ir_node, parent=None):
if ir_node.value == "staticcall":
precompile_addr = ir_node.args[1]
if isinstance(precompile_addr.value, int) and precompile_addr.value < 10:
assert parent is not None and parent.value == "assert"
for arg in ir_node.args:
_check(arg, ir_node)

_check(deploy_ir)
# technically runtime_ir is contained in deploy_ir, but check it anyways.
_check(runtime_ir)
2 changes: 1 addition & 1 deletion vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ def build_IR(self, expr, args, kwargs, context):
["mstore", add_ofst(input_buf, 32), args[1]],
["mstore", add_ofst(input_buf, 64), args[2]],
["mstore", add_ofst(input_buf, 96), args[3]],
["staticcall", "gas", 1, input_buf, 128, output_buf, 32],
["assert", ["staticcall", "gas", 1, input_buf, 128, output_buf, 32]],
["mload", output_buf],
],
typ=AddressT(),
Expand Down
2 changes: 1 addition & 1 deletion vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def copy_bytes(dst, src, length, length_bound):
copy_op = ["mcopy", dst, src, length]
gas_bound = _mcopy_gas_bound(length_bound)
else:
copy_op = ["staticcall", "gas", 4, src, length, dst, length]
copy_op = ["assert", ["staticcall", "gas", 4, src, length, dst, length]]
gas_bound = _identity_gas_bound(length_bound)
elif src.location == CALLDATA:
copy_op = ["calldatacopy", dst, src, length]
Expand Down

0 comments on commit 8642223

Please sign in to comment.