diff --git a/src/halmos/__main__.py b/src/halmos/__main__.py index e5e7b076..9965aa1c 100644 --- a/src/halmos/__main__.py +++ b/src/halmos/__main__.py @@ -25,11 +25,15 @@ def parse_args(args) -> argparse.Namespace: parser.add_argument('--contract', metavar='CONTRACT_NAME', help='run tests in the given contract only') parser.add_argument('--function', metavar='FUNCTION_NAME_PREFIX', default='test', help='run tests matching the given prefix only (default: %(default)s)') + parser.add_argument('--bytecode', metavar='HEX_STRING', help='execute the given bytecode') + parser.add_argument('--loop', metavar='MAX_BOUND', type=int, default=2, help='set loop unrolling bounds (default: %(default)s)') parser.add_argument('--width', metavar='MAX_WIDTH', type=int, help='set the max number of paths') parser.add_argument('--depth', metavar='MAX_DEPTH', type=int, help='set the max path length') parser.add_argument('--array-lengths', metavar='NAME1=LENGTH1,NAME2=LENGTH2,...', help='set the length of dynamic-sized arrays including bytes and string (default: loop unrolling bound)') + parser.add_argument('--symbolic-jump', action='store_true', help='support symbolic jump destination (experimental)') + parser.add_argument('--no-smt-add', action='store_true', help='do not interpret `+`') parser.add_argument('--no-smt-sub', action='store_true', help='do not interpret `-`') parser.add_argument('--no-smt-mul', action='store_true', help='do not interpret `*`') @@ -136,6 +140,68 @@ def decode_hex(hexcode: str) -> Tuple[List[Opcode], List[Any]]: pgm = ops_to_pgm(ops) return (pgm, code) +def mk_callvalue() -> Word: + return BitVec('msg_value', 256) + +def mk_balance() -> Word: + return BitVec('this_balance', 256) + +def mk_caller(solver) -> Word: + caller = BitVec('msg_sender', 256) + solver.add(Extract(255, 160, caller) == BitVecVal(0, 96)) + return caller + +def mk_this(solver) -> Word: + this = BitVec('this_address', 256) + solver.add(Extract(255, 160, this) == BitVecVal(0, 96)) + return this + +def mk_solver(args: argparse.Namespace): + solver = SolverFor('QF_AUFBV') # quantifier-free bitvector + array theory; https://smtlib.cs.uiowa.edu/logics.shtml + solver.set(timeout=args.solver_timeout_branching) + return solver + +def run_bytecode(hexcode: str, args: argparse.Namespace, options: Dict) -> List[Exec]: + (pgm, code) = decode_hex(hexcode) + + storage = {} + + solver = mk_solver(args) + + balance = mk_balance() + callvalue = mk_callvalue() + caller = mk_caller(solver) + this = mk_this(solver) + + sevm = SEVM(options) + ex = sevm.mk_exec( + pgm = { this: pgm }, + code = { this: code }, + storage = { this: storage }, + balance = { this: balance }, + calldata = [], + callvalue = callvalue, + caller = caller, + this = this, + symbolic = True, + solver = solver, + ) + (exs, _) = sevm.run(ex) + + models = [] + for idx, ex in enumerate(exs): + opcode = ex.pgm[ex.this][ex.pc].op[0] + if is_bv_value(opcode) and opcode.as_long() in [EVM.STOP, EVM.RETURN, EVM.REVERT]: + gen_model(args, models, idx, ex) + print(f'Final opcode: {opcode.as_long()} | Return data: {ex.output} | Input example: {models[-1][0]}') + else: + print(color_warn('Not supported: ' + opcode + ' ' + ex.error)) + if args.verbose >= 1: + print(f'# {idx+1} / {len(exs)}') + print(ex) + + return exs + def setup( hexcode: str, abi: List, @@ -146,36 +212,22 @@ def setup( args: argparse.Namespace, options: Dict ) -> Exec: - # bytecode (pgm, code) = decode_hex(hexcode) - # solver - solver = SolverFor('QF_AUFBV') # quantifier-free bitvector + array theory; https://smtlib.cs.uiowa.edu/logics.shtml - solver.set(timeout=args.solver_timeout_branching) - - # storage - storage = {} - - # caller - caller = BitVec('msg_sender', 256) - solver.add(Extract(255, 160, caller) == BitVecVal(0, 96)) + solver = mk_solver(args) - # this - this = BitVec('this_address', 256) - solver.add(Extract(255, 160, this) == BitVecVal(0, 96)) - - # run setup if any + this = mk_this(solver) sevm = SEVM(options) setup_ex = sevm.mk_exec( pgm = { this: pgm }, code = { this: code }, - storage = { this: storage }, + storage = { this: {} }, balance = { this: con(0) }, calldata = [], callvalue = con(0), - caller = caller, + caller = mk_caller(solver), this = this, symbolic = False, solver = solver, @@ -227,13 +279,13 @@ def run( # callvalue # - callvalue = BitVec('msg_value', 256) + callvalue = mk_callvalue() # # balance # - balance = BitVec('this_balance', 256) + balance = mk_balance() # # run @@ -416,6 +468,23 @@ def select(var): else: return '[' + ', '.join(sorted(map(lambda decl: f'{decl} = {model[decl]}', filter(select, model)))) + ']' +def mk_options(args: argparse.Namespace) -> Dict: + return { + 'target': args.target, + 'verbose': args.verbose, + 'debug': args.debug, + 'log': args.log, + 'add': not args.no_smt_add, + 'sub': not args.no_smt_sub, + 'mul': not args.no_smt_mul, + 'div': args.smt_div, + 'divByConst': args.smt_div_by_const, + 'modByConst': args.smt_mod_by_const, + 'expByConst': args.smt_exp_by_const, + 'timeout': args.solver_timeout_branching, + 'sym_jump': args.symbolic_jump, + } + def main() -> int: # # z3 global options @@ -431,20 +500,7 @@ def main() -> int: args = parse_args(sys.argv[1:]) - options = { - 'target': args.target, - 'verbose': args.verbose, - 'debug': args.debug, - 'log': args.log, - 'add': not args.no_smt_add, - 'sub': not args.no_smt_sub, - 'mul': not args.no_smt_mul, - 'div': args.smt_div, - 'divByConst': args.smt_div_by_const, - 'modByConst': args.smt_mod_by_const, - 'expByConst': args.smt_exp_by_const, - 'timeout': args.solver_timeout_branching, - } + options = mk_options(args) if args.width is not None: options['max_width'] = args.width @@ -462,6 +518,11 @@ def main() -> int: size = assign[1].strip() arrlen[name] = int(size) + # quick bytecode execution mode + if args.bytecode is not None: + run_bytecode(args.bytecode, args, options) + return 0 + # # compile # diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index f72ae080..36b9a1c9 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -5,7 +5,7 @@ from copy import deepcopy from collections import defaultdict -from typing import List, Dict, Tuple, Any +from typing import List, Dict, Set, Tuple, Any from functools import reduce from z3 import * @@ -78,6 +78,20 @@ def wstore_bytes(mem: List[Byte], loc: int, size: int, arr: List[Byte]) -> None: def create_address(cnt: int) -> Word: return con(0x220E + cnt) +def valid_jump_destinations(pgm: List[Opcode]) -> Set[int]: + jumpdests = set() + i = 0 + while i < len(pgm): + opcode = pgm[i].op[0] + if is_bv_value(opcode): + opcode = opcode.as_long() + if opcode == EVM.JUMPDEST: + jumpdests.add(i) + elif EVM.PUSH1 <= opcode <= EVM.PUSH32: + i += opcode - EVM.PUSH1 + 1 + i += 1 + return jumpdests + class State: stack: List[Word] memory: List[Byte] @@ -885,39 +899,7 @@ def jumpi(self, ex: Exec, stack: List[Tuple[Exec,int]], step_id: int) -> None: cond_true = simplify(is_non_zero(cond)) ex.solver.add(cond_true) if ex.solver.check() != unsat: # jump - new_solver = SolverFor('QF_AUFBV') - new_solver.set(timeout=self.options['timeout']) - new_solver.add(ex.solver.assertions()) - new_path = deepcopy(ex.path) - new_path.append(str(cond_true)) - new_ex_true = Exec( - pgm = ex.pgm.copy(), # shallow copy for potential new contract creation; existing code doesn't change - code = ex.code.copy(), # shallow copy - storage = deepcopy(ex.storage), - balance = deepcopy(ex.balance), - # - calldata = ex.calldata, - callvalue= ex.callvalue, - caller = ex.caller, - this = ex.this, - # - pc = target, - st = deepcopy(ex.st), - jumpis = deepcopy(ex.jumpis), - output = deepcopy(ex.output), - symbolic = ex.symbolic, - # - solver = new_solver, - path = new_path, - # - log = deepcopy(ex.log), - cnts = deepcopy(ex.cnts), - sha3s = deepcopy(ex.sha3s), - storages = deepcopy(ex.storages), - calls = deepcopy(ex.calls), - failed = ex.failed, - error = ex.error, - ) + new_ex_true = self.create_branch(ex, str(cond_true), target) ex.solver.pop() cond_false = simplify(is_zero(cond)) @@ -941,6 +923,66 @@ def jumpi(self, ex: Exec, stack: List[Tuple[Exec,int]], step_id: int) -> None: else: pass # this may happen if the previous path condition was considered unknown but turns out to be unsat later + def jump(self, ex: Exec, stack: List[Tuple[Exec,int]], step_id: int) -> None: + dst = ex.st.pop() + + # if dst is concrete, just jump + if is_bv_value(dst): + ex.pc = int(str(dst)) + stack.append((ex, step_id)) + + # otherwise, create a new execution for feasible targets + elif self.options['sym_jump']: + for target in valid_jump_destinations(ex.pgm[ex.this]): + ex.solver.push() + target_reachable = simplify(dst == target) + ex.solver.add(target_reachable) + if ex.solver.check() != unsat: # jump + if self.options.get('debug'): + print(f"we can jump to {target} with model {ex.solver.model()}") + new_ex = self.create_branch(ex, str(target_reachable), target) + stack.append((new_ex, step_id)) + ex.solver.pop() + + else: + raise ValueError(dst) + + def create_branch(self, ex: Exec, cond: str, target: int) -> Exec: + new_solver = SolverFor('QF_AUFBV') + new_solver.set(timeout=self.options['timeout']) + new_solver.add(ex.solver.assertions()) + new_path = deepcopy(ex.path) + new_path.append(cond) + new_ex = Exec( + pgm = ex.pgm.copy(), # shallow copy for potential new contract creation; existing code doesn't change + code = ex.code.copy(), # shallow copy + storage = deepcopy(ex.storage), + balance = deepcopy(ex.balance), + # + calldata = ex.calldata, + callvalue= ex.callvalue, + caller = ex.caller, + this = ex.this, + # + pc = target, + st = deepcopy(ex.st), + jumpis = deepcopy(ex.jumpis), + output = deepcopy(ex.output), + symbolic = ex.symbolic, + # + solver = new_solver, + path = new_path, + # + log = deepcopy(ex.log), + cnts = deepcopy(ex.cnts), + sha3s = deepcopy(ex.sha3s), + storages = deepcopy(ex.storages), + calls = deepcopy(ex.calls), + failed = ex.failed, + error = ex.error, + ) + return new_ex + def run(self, ex0: Exec) -> Tuple[List[Exec], Steps]: out: List[Exec] = [] steps: Steps = {} @@ -997,9 +1039,8 @@ def run(self, ex0: Exec) -> Tuple[List[Exec], Steps]: continue elif opcode == EVM.JUMP: - source: int = ex.pc - target: int = int(str(ex.st.pop())) # target must be concrete - ex.pc = target + self.jump(ex, stack, step_id) + continue elif opcode == EVM.JUMPDEST: pass diff --git a/tests/test_cli.py b/tests/test_cli.py index cb2ce9db..fab68b2e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -7,7 +7,7 @@ from halmos.byte2op import decode -from halmos.__main__ import str_abi, parse_args, decode_hex +from halmos.__main__ import str_abi, parse_args, decode_hex, mk_options, run_bytecode import halmos.__main__ @pytest.fixture @@ -42,19 +42,15 @@ def args(): @pytest.fixture def options(args): - return { - 'verbose': args.verbose, - 'debug': args.debug, - 'log': args.log, - 'add': not args.no_smt_add, - 'sub': not args.no_smt_sub, - 'mul': not args.no_smt_mul, - 'div': args.smt_div, - 'divByConst': args.smt_div_by_const, - 'modByConst': args.smt_mod_by_const, - 'expByConst': args.smt_exp_by_const, - 'timeout': args.solver_timeout_branching, - } + return mk_options(args) + +def test_run_bytecode(args, options): + hexcode = '34381856FDFDFDFDFDFD5B00' + options['sym_jump'] = True + exs = run_bytecode(hexcode, args, options) + assert len(exs) == 1 + ex = exs[0] + assert str(ex.pgm[ex.this][ex.pc].op[0]) == str(EVM.STOP) def test_setup(setup_abi, setup_name, setup_sig, setup_selector, args, options): hexcode = '600100' diff --git a/tests/test_sevm.py b/tests/test_sevm.py index b876a241..5fd60009 100644 --- a/tests/test_sevm.py +++ b/tests/test_sevm.py @@ -9,7 +9,7 @@ from halmos.sevm import SEVM, con, ops_to_pgm, f_div, f_sdiv, f_mod, f_smod, f_exp, f_orig_balance, f_origin -from halmos.__main__ import parse_args +from halmos.__main__ import parse_args, mk_options @pytest.fixture def args(): @@ -17,19 +17,7 @@ def args(): @pytest.fixture def options(args): - return { - 'verbose': args.verbose, - 'debug': args.debug, - 'log': args.log, - 'add': not args.no_smt_add, - 'sub': not args.no_smt_sub, - 'mul': not args.no_smt_mul, - 'div': args.smt_div, - 'divByConst': args.smt_div_by_const, - 'modByConst': args.smt_mod_by_const, - 'expByConst': args.smt_exp_by_const, - 'timeout': args.solver_timeout_branching, - } + return mk_options(args) @pytest.fixture def sevm(options):