diff --git a/angrop/chain_builder/builder.py b/angrop/chain_builder/builder.py index ede3c48..c7f5171 100644 --- a/angrop/chain_builder/builder.py +++ b/angrop/chain_builder/builder.py @@ -1,7 +1,6 @@ import struct from abc import abstractmethod from functools import cmp_to_key -from collections import defaultdict import claripy @@ -103,7 +102,8 @@ def _get_ptr_to_null(self): return addr return None - def _ast_contains_stack_data(self, ast): + @staticmethod + def _ast_contains_stack_data(ast): vs = ast.variables return len(vs) == 1 and list(vs)[0].startswith('symbolic_stack_') @@ -323,7 +323,7 @@ def __filter_gadgets(self, gadgets): g1 = gadgets.pop() # check if nothing is better than g1 for g2 in bests|gadgets: - if self._better_than(g2, g1): + if self._better_than(g2, g1): #pylint: disable=arguments-out-of-order break else: bests.add(g1) diff --git a/angrop/chain_builder/func_caller.py b/angrop/chain_builder/func_caller.py index d23c50c..f554371 100644 --- a/angrop/chain_builder/func_caller.py +++ b/angrop/chain_builder/func_caller.py @@ -69,7 +69,8 @@ def _func_call(self, func_gadget, cc, args, extra_regs=None, preserve_regs=None, # 1. handle stack arguments # 2. handle function return address to maintain the control flow if stack_arguments: - cleaner = self.chain_builder.shift((len(stack_arguments)+1)*arch_bytes, next_pc_idx=-1, preserve_regs=preserve_regs) + shift_bytes = (len(stack_arguments)+1)*arch_bytes + cleaner = self.chain_builder.shift(shift_bytes, next_pc_idx=-1, preserve_regs=preserve_regs) chain.add_gadget(cleaner._gadgets[0]) for arg in stack_arguments: chain.add_value(arg) diff --git a/angrop/chain_builder/mem_changer.py b/angrop/chain_builder/mem_changer.py index c718970..7af6e3c 100644 --- a/angrop/chain_builder/mem_changer.py +++ b/angrop/chain_builder/mem_changer.py @@ -96,7 +96,8 @@ def add_to_mem(self, addr, value, data_size=None): # get the data from trying to set all the registers registers = dict((reg, 0x41) for reg in self.chain_builder.arch.reg_set) l.debug("getting reg data for mem adds") - _, _, reg_data = self.chain_builder._reg_setter.find_candidate_chains_graph_search(max_stack_change=0x50, **registers) + _, _, reg_data = self.chain_builder._reg_setter.find_candidate_chains_graph_search(max_stack_change=0x50, + **registers) l.debug("trying mem_add gadgets") # filter out gadgets that certainly cannot be used for add_mem diff --git a/angrop/chain_builder/mem_writer.py b/angrop/chain_builder/mem_writer.py index e771c9a..d6977ee 100644 --- a/angrop/chain_builder/mem_writer.py +++ b/angrop/chain_builder/mem_writer.py @@ -61,10 +61,10 @@ def _gen_mem_write_gadgets(self, string_data): # generate from the cache first if self._good_mem_write_gadgets: - for g in self._good_mem_write_gadgets: - yield g + yield from self._good_mem_write_gadgets - possible_gadgets = {g for g in self._mem_write_gadgets.copy() if g.transit_type != 'jmp_reg'} - self._good_mem_write_gadgets + possible_gadgets = {g for g in self._mem_write_gadgets.copy() if g.transit_type != 'jmp_reg'} + possible_gadgets -= self._good_mem_write_gadgets # already yield these # use the graph-search to gain a rough idea about (stack_change, register setting) registers = dict((reg, 0x41) for reg in self.arch.reg_set) @@ -98,8 +98,11 @@ def _gen_mem_write_gadgets(self, string_data): if stack_change == best_stack_change and self._better_than(g, best_gadget): best_gadget = g - yield best_gadget - possible_gadgets.remove(best_gadget) + if best_gadget: + possible_gadgets.remove(best_gadget) + yield best_gadget + else: + break @rop_utils.timeout(5) def _try_write_to_mem(self, gadget, use_partial_controllers, addr, string_data, fill_byte): @@ -251,7 +254,8 @@ def _write_to_mem_with_gadget(self, gadget, addr_val, data, use_partial_controll break reg_vals[reg] = var - chain = self._set_regs(use_partial_controllers=use_partial_controllers, **reg_vals) + + chain = self._set_regs(**reg_vals) chain.add_gadget(gadget) bytes_per_pop = self.project.arch.bytes diff --git a/angrop/chain_builder/reg_setter.py b/angrop/chain_builder/reg_setter.py index 9046d08..eccd7c4 100644 --- a/angrop/chain_builder/reg_setter.py +++ b/angrop/chain_builder/reg_setter.py @@ -16,6 +16,12 @@ l = logging.getLogger("angrop.chain_builder.reg_setter") class RegSetter(Builder): + """ + a chain builder that aims to set registers using different algorithms + 1. algo1: graph-search, fast, not reliable + 2. algo2: pop-only bfs search, fast, reliable, can generate chains to bypass bad-bytes + 3. algo3: riscy-rop inspired backward search, slow, can utilize gadgets containing conditional branches + """ def __init__(self, chain_builder): super().__init__(chain_builder) self._reg_setting_gadgets = None # all the gadgets that can set registers @@ -54,12 +60,14 @@ def verify(self, chain, preserve_regs, registers): offset -= act.offset % self.project.arch.bytes reg_name = self.project.arch.translate_register_name(offset) if reg_name in preserve_regs: - l.exception("Somehow angrop thinks \n%s\n can be used for the chain generation - 1.\ntarget registers: %s", chain_str, registers) + l.exception("Somehow angrop thinks\n%s\ncan be used for the chain generation-1.\nregisters: %s", + chain_str, registers) return False for reg, val in registers.items(): bv = getattr(state.regs, reg) if (val.symbolic != bv.symbolic) or state.solver.eval(bv != val.data): - l.exception("Somehow angrop thinks \n%s\n can be used for the chain generation - 2.\ntarget registers: %s", chain_str, registers) + l.exception("Somehow angrop thinks\n%s\ncan be used for the chain generation-2.\nregisters: %s", + chain_str, registers) return False # the next pc must come from the stack or just marked as the next_pc if len(state.regs.pc.variables) != 1: @@ -67,7 +75,7 @@ def verify(self, chain, preserve_regs, registers): pc_var = set(state.regs.pc.variables).pop() return pc_var.startswith("symbolic_stack") or pc_var.startswith("next_pc") - def run(self, modifiable_memory_range=None, use_partial_controllers=False, preserve_regs=None, max_length=10, **registers): + def run(self, modifiable_memory_range=None, preserve_regs=None, max_length=10, **registers): if len(registers) == 0: return RopChain(self.project, None, badbytes=self.badbytes) @@ -106,24 +114,20 @@ def iterate_candidate_chains(self, modifiable_memory_range, preserve_regs, max_l yield gadgets # algorithm2 - gadgets_list = self.find_candidate_chains_pop_only_bfs_search( + yield from self.find_candidate_chains_pop_only_bfs_search( self._find_relevant_gadgets(**registers), preserve_regs.copy(), **registers) - for gadgets in gadgets_list: - yield gadgets # algorithm3 - for gadgets in self.find_candidate_chains_backwards_recursive_search( + yield from self.find_candidate_chains_backwards_recursive_search( self._reg_setting_gadgets, set(registers), current_chain=[], preserve_regs=preserve_regs.copy(), modifiable_memory_range=modifiable_memory_range, visited={}, - max_length=max_length): - yield gadgets - return + max_length=max_length) #### Chain Building Algorithm 1: fast but unreliable graph-based search #### @@ -137,6 +141,7 @@ def _tuple_to_gadgets(data, reg_tuple): curr_tuple = reg_tuple else: gadgets_reverse = reg_tuple[2] + curr_tuple = () while curr_tuple != (): gadgets_reverse.append(data[curr_tuple][2]) curr_tuple = data[curr_tuple][0] @@ -586,11 +591,11 @@ def _get_remaining_regs(self, gadget: RopGadget, registers: set[str]) -> set[str for reg in registers: if reg in gadget.popped_regs: - vars = gadget.popped_reg_vars[reg] - if not vars.isdisjoint(stack_dependencies): + reg_vars = gadget.popped_reg_vars[reg] + if not reg_vars.isdisjoint(stack_dependencies): # Two registers are popped from the same location on the stack. return None - stack_dependencies |= vars + stack_dependencies |= reg_vars continue new_reg = reg for reg_move in gadget.reg_moves: @@ -681,4 +686,4 @@ def filter_gadgets(self, gadgets): bests = bests.union(self._filter_gadgets(equal_class)) gadgets -= equal_class - return bests \ No newline at end of file + return bests diff --git a/angrop/chain_builder/sys_caller.py b/angrop/chain_builder/sys_caller.py index eb85097..2c1ecb9 100644 --- a/angrop/chain_builder/sys_caller.py +++ b/angrop/chain_builder/sys_caller.py @@ -51,8 +51,7 @@ def supported_os(os): def update(self): self.syscall_gadgets = self._filter_gadgets(self.chain_builder.syscall_gadgets) - @staticmethod - def _filter_gadgets(gadgets): + def _filter_gadgets(self, gadgets): return sorted(gadgets, key=functools.cmp_to_key(cmp)) def _try_invoke_execve(self, path_addr): @@ -169,7 +168,7 @@ def key_func(x): try: return self._func_call(gadget, cc, args, extra_regs=extra_regs, needs_return=needs_return, preserve_regs=preserve_regs, **kwargs) - except Exception: # pylint: disable=broad-exception-caught + except Exception: # pylint:disable=broad-exception-caught continue raise RopException(f"Fail to invoke syscall {syscall_num} with arguments: {args}!") diff --git a/angrop/gadget_finder/gadget_analyzer.py b/angrop/gadget_finder/gadget_analyzer.py index 192053f..4bd555c 100644 --- a/angrop/gadget_finder/gadget_analyzer.py +++ b/angrop/gadget_finder/gadget_analyzer.py @@ -82,7 +82,7 @@ def _step_to_gadget_stopping_states(self, init_state): try: simgr = self.project.factory.simulation_manager(init_state, save_unconstrained=True) - def filter(state): + def filter_func(state): if not state.ip.concrete: return None if self.project.is_hooked(state.addr): @@ -94,8 +94,9 @@ def filter(state): return simgr.DROP return None - simgr.run(n=2, filter_func=filter) - simgr.move(from_stash='active', to_stash='syscall', filter_func=lambda s: rop_utils.is_in_kernel(self.project, s)) + simgr.run(n=2, filter_func=filter_func) + simgr.move(from_stash='active', to_stash='syscall', + filter_func=lambda s: rop_utils.is_in_kernel(self.project, s)) except (claripy.errors.ClaripySolverInterruptError, claripy.errors.ClaripyZ3Error, ValueError): return [], [] @@ -282,7 +283,7 @@ def _can_reach_stopping_states(self, addr, allow_conditional_branches, max_steps def _try_stepping_past_syscall(self, state): try: return rop_utils.step_to_unconstrained_successor(self.project, state, max_steps=3) - except Exception: # pylint: disable=broad-exception-caught + except Exception: # pylint:disable=broad-exception-caught return state def _identify_transit_type(self, final_state, ctrl_type): diff --git a/angrop/rop.py b/angrop/rop.py index 0e12c87..2645467 100644 --- a/angrop/rop.py +++ b/angrop/rop.py @@ -160,14 +160,16 @@ def find_gadgets_single_threaded(self, show_progress=True): return self.rop_gadgets def _get_cache_tuple(self): - all_gadgets = [x for x in self._all_gadgets] - for g in all_gadgets: g.project = None + all_gadgets = self._all_gadgets + for g in all_gadgets: + g.project = None return (all_gadgets, self._duplicates) def _load_cache_tuple(self, tup): self._all_gadgets = tup[0] self._duplicates = tup[1] - for g in self._all_gadgets: g.project = self.project + for g in self._all_gadgets: + g.project = self.project self._screen_gadgets() def save_gadgets(self, path): @@ -177,7 +179,8 @@ def save_gadgets(self, path): """ with open(path, "wb") as f: pickle.dump(self._get_cache_tuple(), f) - for g in self._all_gadgets: g.project = self.project + for g in self._all_gadgets: + g.project = self.project def load_gadgets(self, path): """ diff --git a/angrop/rop_gadget.py b/angrop/rop_gadget.py index b38d750..53ba85a 100644 --- a/angrop/rop_gadget.py +++ b/angrop/rop_gadget.py @@ -205,7 +205,8 @@ def __repr__(self): return "<Gadget %#x>" % self.addr def copy(self): - out = RopGadget(self.project, self.addr) + out = RopGadget(self.addr) + out.project = self.project out.addr = self.addr out.changed_regs = set(self.changed_regs) out.popped_regs = set(self.popped_regs) @@ -303,4 +304,4 @@ def __init__(self, addr, symbol): def dstr(self): if self.symbol: return f"<{self.symbol}>" - return f"<func_{self.addr:#x}>" \ No newline at end of file + return f"<func_{self.addr:#x}>" diff --git a/angrop/rop_utils.py b/angrop/rop_utils.py index 41c008f..235955e 100644 --- a/angrop/rop_utils.py +++ b/angrop/rop_utils.py @@ -152,7 +152,7 @@ def _asts_must_be_equal(state, ast1, ast2): return True -def fast_uninitialized_filler(name, addr, size, state): +def fast_uninitialized_filler(_, addr, size, state): return state.solver.BVS("uninitialized" + hex(addr), size, explicit_name=True) diff --git a/tests/test_chainbuilder.py b/tests/test_chainbuilder.py index b1118c4..4c17d1c 100644 --- a/tests/test_chainbuilder.py +++ b/tests/test_chainbuilder.py @@ -1,8 +1,9 @@ import os +import claripy + import angr import angrop # pylint: disable=unused-import -import claripy from angrop.rop_value import RopValue from angrop.errors import RopException