-
-
Notifications
You must be signed in to change notification settings - Fork 825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat[venom]: add loop invariant hoisting pass #4175
base: master
Are you sure you want to change the base?
Changes from 10 commits
13944eb
f626b68
942c731
399a0a4
2a8dd4a
302aa21
013840c
1ee0068
f703408
c5dbf05
cd655fc
ecb272a
18c7610
5583cc5
bdb2896
788bd0d
715b128
8edce11
cf6b25e
77f97b6
e62c8cd
139f79b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
import pytest | ||
|
||
from vyper.venom.analysis.analysis import IRAnalysesCache | ||
from vyper.venom.analysis.loop_detection import LoopDetectionAnalysis | ||
from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRVariable | ||
from vyper.venom.context import IRContext | ||
from vyper.venom.function import IRFunction | ||
from vyper.venom.passes.loop_invariant_hosting import LoopInvariantHoisting | ||
|
||
|
||
def _create_loops(fn, depth, loop_id, body_fn=lambda *_: (), top=True):
    """
    Append a nest of `depth` loops to `fn`.

    Each nesting level adds a `cond<loop_id><depth>` block, a
    `body<loop_id><depth>` block and an exit block. The outermost level's
    exit block is labelled `exit_top...`, inner levels use `exit...`.

    `body_fn(fn, loop_id, depth)` is called once per level, after the
    condition block has been terminated, to emit the loop body. The default
    is a no-op that accepts any arguments.
    """
    # block from which the loop is entered
    entry_bb = fn.get_basic_block()

    cond = IRBasicBlock(IRLabel(f"cond{loop_id}{depth}"), fn)
    body = IRBasicBlock(IRLabel(f"body{loop_id}{depth}"), fn)
    exit_prefix = "exit_top" if top else "exit"
    exit_block = IRBasicBlock(IRLabel(f"{exit_prefix}{loop_id}{depth}"), fn)

    fn.append_basic_block(cond)
    fn.append_basic_block(body)

    entry_bb.append_instruction("jmp", cond.label)

    cond_var = IRVariable(f"cond_var{loop_id}{depth}")
    cond.append_instruction("iszero", 0, ret=cond_var)
    cond.append_instruction("jnz", cond_var, body.label, exit_block.label)

    body_fn(fn, loop_id, depth)
    if depth > 1:
        _create_loops(fn, depth - 1, loop_id, body_fn, top=False)

    # close the loop: jump back to the condition from whichever block
    # `fn.get_basic_block()` returns after the body (and any inner loops)
    # have been emitted
    fn.get_basic_block().append_instruction("jmp", cond.label)

    # the exit block is appended last
    fn.append_basic_block(exit_block)
|
||
|
||
def _simple_body(fn, loop_id, depth):
    """
    Loop body with a single constant `add`, written to
    `add_var<loop_id><depth>`, appended to the block returned by
    `fn.get_basic_block()`.
    """
    assert isinstance(fn, IRFunction)
    target = IRVariable(f"add_var{loop_id}{depth}")
    fn.get_basic_block().append_instruction("add", 1, 2, ret=target)
|
||
|
||
def _hoistable_body(fn, loop_id, depth):
    """
    Loop body with two chained `add` instructions: the first uses only
    literals, the second consumes the result of the first.
    """
    assert isinstance(fn, IRFunction)
    block = fn.get_basic_block()

    first = IRVariable(f"add_var_a{loop_id}{depth}")
    block.append_instruction("add", 1, 2, ret=first)

    second = IRVariable(f"add_var_b{loop_id}{depth}")
    block.append_instruction("add", first, 2, ret=second)
|
||
|
||
@pytest.mark.parametrize("depth", range(1, 4))
@pytest.mark.parametrize("count", range(1, 4))
def test_loop_detection_analysis(depth, count):
    """
    Build `count` loop nests, each `depth` levels deep, and check that the
    loop detection analysis reports one loop per nesting level.
    """
    ctx = IRContext()
    fn = ctx.create_function("_global")

    for loop_id in range(count):
        _create_loops(fn, depth, loop_id, _simple_body)

    fn.get_basic_block().append_instruction("ret")

    cache = IRAnalysesCache(fn)
    detected = cache.request_analysis(LoopDetectionAnalysis)
    assert len(detected.loops) == depth * count
|
||
|
||
@pytest.mark.parametrize("depth", range(1, 4))
@pytest.mark.parametrize("count", range(1, 4))
def test_loop_invariant_hoisting_simple(depth, count):
    """
    Every loop level holds one constant `add` plus the `iszero` of the
    condition. After the pass, all of them must be assigned in the entry
    block or in one of the `exit_top*` blocks.
    """
    ctx = IRContext()
    fn = ctx.create_function("_global")

    for loop_id in range(count):
        _create_loops(fn, depth, loop_id, _simple_body)

    fn.get_basic_block().append_instruction("ret")

    LoopInvariantHoisting(IRAnalysesCache(fn), fn).run_pass()

    # names assigned in the entry block ...
    assignments = [var.value for var in fn.entry.get_assignments()]
    # ... and in every top-level exit block
    for block in fn.get_basic_blocks():
        if block.label.name.startswith("exit_top"):
            assignments.extend(var.value for var in block.get_assignments())

    assert len(assignments) == depth * count * 2
    for loop_id in range(count):
        for d in range(1, depth + 1):
            assert f"%add_var{loop_id}{d}" in assignments, repr(fn)
            assert f"%cond_var{loop_id}{d}" in assignments, repr(fn)
|
||
|
||
@pytest.mark.parametrize("depth", range(1, 4))
@pytest.mark.parametrize("count", range(1, 4))
def test_loop_invariant_hoisting_dependant(depth, count):
    """
    Every loop level holds two chained `add` instructions plus the `iszero`
    of the condition. After the pass, all three must be assigned in the
    entry block or in one of the `exit_top*` blocks.
    """
    ctx = IRContext()
    fn = ctx.create_function("_global")

    for loop_id in range(count):
        _create_loops(fn, depth, loop_id, _hoistable_body)

    fn.get_basic_block().append_instruction("ret")

    LoopInvariantHoisting(IRAnalysesCache(fn), fn).run_pass()

    # names assigned in the entry block ...
    assignments = [var.value for var in fn.entry.get_assignments()]
    # ... and in every top-level exit block
    for block in fn.get_basic_blocks():
        if block.label.name.startswith("exit_top"):
            assignments.extend(var.value for var in block.get_assignments())

    assert len(assignments) == depth * count * 3
    for loop_id in range(count):
        for d in range(1, depth + 1):
            assert f"%add_var_a{loop_id}{d}" in assignments, repr(fn)
            assert f"%add_var_b{loop_id}{d}" in assignments, repr(fn)
            assert f"%cond_var{loop_id}{d}" in assignments, repr(fn)
|
||
|
||
def _unhoistable_body(fn, loop_id, depth):
    """
    Loop body with an `mload` followed by an `add` that consumes the loaded
    value. The accompanying test expects neither to be hoisted.
    """
    assert isinstance(fn, IRFunction)
    block = fn.get_basic_block()

    loaded = IRVariable(f"add_var_a{loop_id}{depth}")
    block.append_instruction("mload", 64, ret=loaded)

    summed = IRVariable(f"add_var_b{loop_id}{depth}")
    block.append_instruction("add", loaded, 2, ret=summed)
|
||
|
||
@pytest.mark.parametrize("depth", range(1, 4))
@pytest.mark.parametrize("count", range(1, 4))
def test_loop_invariant_hoisting_unhoistable(depth, count):
    """
    Every loop level holds an `mload` and a dependent `add`. Only the
    `iszero` of each loop condition is expected to be assigned in the entry
    block or in one of the `exit_top*` blocks after the pass.
    """
    ctx = IRContext()
    fn = ctx.create_function("_global")

    for loop_id in range(count):
        _create_loops(fn, depth, loop_id, _unhoistable_body)

    fn.get_basic_block().append_instruction("ret")

    LoopInvariantHoisting(IRAnalysesCache(fn), fn).run_pass()

    # names assigned in the entry block ...
    assignments = [var.value for var in fn.entry.get_assignments()]
    # ... and in every top-level exit block
    for block in fn.get_basic_blocks():
        if block.label.name.startswith("exit_top"):
            assignments.extend(var.value for var in block.get_assignments())

    assert len(assignments) == depth * count
    for loop_id in range(count):
        for d in range(1, depth + 1):
            assert f"%cond_var{loop_id}{d}" in assignments, repr(fn)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from vyper.utils import OrderedSet | ||
from vyper.venom.analysis.analysis import IRAnalysis | ||
from vyper.venom.analysis.cfg import CFGAnalysis | ||
from vyper.venom.basicblock import IRBasicBlock | ||
|
||
|
||
class LoopDetectionAnalysis(IRAnalysis):
    """
    Detects loops in the function's CFG.

    For every loop found, records the set of basic blocks collected for it,
    keyed by the block through which the loop is entered from outside.
    """

    # key = block entering the loop from outside (last bb not in the loop)
    # value = all the blocks collected for that loop
    loops: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]

    # blocks whose DFS processing has finished
    done: OrderedSet[IRBasicBlock]
    # blocks the DFS has entered at least once
    visited: OrderedSet[IRBasicBlock]

    def analyze(self):
        """Reset all state and walk the CFG depth-first from the entry block."""
        self.analyses_cache.request_analysis(CFGAnalysis)
        self.loops = {}
        self.done = OrderedSet()
        self.visited = OrderedSet()
        self._dfs_r(self.function.entry)

    def _dfs_r(self, bb: IRBasicBlock, before: IRBasicBlock = None):
        """
        Depth-first walk over `cfg_out` edges.

        `before` is the block from which `bb` was reached (None only for
        the initial call). Reaching a block that was already visited is
        treated as closing a loop whose header is `bb`.
        """
        if bb not in self.visited:
            self.visited.add(bb)
            for successor in bb.cfg_out:
                if successor in self.done:
                    continue
                self._dfs_r(successor, bb)
            self.done.add(bb)
            return

        # `bb` was seen before: the edge before -> bb closes a loop
        assert before is not None, "Loop must have one basic block before it"
        loop_blocks = self._collect_path(before, bb)

        # apart from the closing edge, exactly one block may lead into `bb`
        entries = bb.cfg_in.difference({before})
        assert len(entries) == 1, "Loop must have one input basic block"

        self.loops[entries.first()] = loop_blocks
        self.done.add(bb)

    def _collect_path(self, bb_from: IRBasicBlock, bb_to: IRBasicBlock) -> OrderedSet[IRBasicBlock]:
        """
        Return the blocks reached by following `cfg_in` edges backwards from
        `bb_from`, not going past `bb_to` (both ends included).
        """
        collected: OrderedSet[IRBasicBlock] = OrderedSet()
        seen: OrderedSet[IRBasicBlock] = OrderedSet()
        self._collect_path_r(bb_from, bb_to, collected, seen)
        return collected

    def _collect_path_r(
        self,
        act_bb: IRBasicBlock,
        bb_to: IRBasicBlock,
        loop: OrderedSet[IRBasicBlock],
        collect_visit: OrderedSet[IRBasicBlock],
    ):
        """
        Recursive worker for `_collect_path`.

        Adds `act_bb` to `loop` (and to the `collect_visit` guard set), then
        recurses into its predecessors unless `act_bb` is `bb_to`.
        """
        if act_bb in collect_visit:
            return

        collect_visit.add(act_bb)
        loop.add(act_bb)

        if act_bb != bb_to:
            for predecessor in act_bb.cfg_in:
                self._collect_path_r(predecessor, bb_to, loop, collect_visit)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from vyper.utils import OrderedSet | ||
from vyper.venom.analysis.cfg import CFGAnalysis | ||
from vyper.venom.analysis.dfg import DFGAnalysis | ||
from vyper.venom.analysis.liveness import LivenessAnalysis | ||
from vyper.venom.analysis.loop_detection import LoopDetectionAnalysis | ||
from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRLabel, IRVariable | ||
from vyper.venom.function import IRFunction | ||
from vyper.venom.passes.base_pass import IRPass | ||
|
||
|
||
def _ignore_instruction(instruction: IRInstruction) -> bool:
    """
    Return True for instructions the hoisting pass must never move:
    volatile instructions, basic-block terminators, `returndatasize`, `phi`,
    and `add` whose second operand is a label.
    """
    if instruction.is_volatile or instruction.is_bb_terminator:
        return True
    if instruction.opcode in ("returndatasize", "phi"):
        return True
    return instruction.opcode == "add" and isinstance(instruction.operands[1], IRLabel)
|
||
|
||
def _is_correct_store(instruction: IRInstruction) -> bool:
    """Return True when the instruction's opcode is `store`."""
    opcode = instruction.opcode
    return opcode == "store"
|
||
|
||
class LoopInvariantHoisting(IRPass):
    """
    This pass detects invariants in loops and hoists them above the loop body.

    Instructions rejected by `_ignore_instruction` (volatile instructions,
    basic-block terminators, `returndatasize`, `phi` and `add` with a label
    operand) are never hoisted. `store` instructions are not hoisted on
    their own; they are moved only together with a hoisted instruction that
    reads their output.
    """

    function: IRFunction
    # key = block in front of the loop (hoisting target),
    # value = blocks belonging to that loop
    loop_analysis: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]
    dfg: DFGAnalysis

    def run_pass(self):
        """
        Hoist until a full sweep over all loops moves nothing.

        Repeated sweeps are needed because an instruction only becomes
        hoistable once the in-loop instructions it depends on have been
        moved out by an earlier sweep.
        """
        self.analyses_cache.request_analysis(CFGAnalysis)
        self.dfg = self.analyses_cache.request_analysis(DFGAnalysis)
        loops = self.analyses_cache.request_analysis(LoopDetectionAnalysis)
        self.loop_analysis = loops.loops

        hoisted_any = False
        while True:
            changed = False
            for from_bb, loop in self.loop_analysis.items():
                hoistable = self._get_hoistable_loop(from_bb, loop)
                if not hoistable:
                    continue
                changed = True
                self._hoist(from_bb, hoistable)
            if not changed:
                break
            hoisted_any = True

        # only need to invalidate if you did some hoisting
        if hoisted_any:
            self.analyses_cache.invalidate_analysis(LivenessAnalysis)

    def _hoist(self, target_bb: IRBasicBlock, hoistable: list[IRInstruction]):
        """Move each instruction in `hoistable`, in order, into `target_bb`."""
        for inst in hoistable:
            inst.parent.remove_instruction(inst)
            # `target_bb` is already terminated, so the instruction has to go
            # in front of its last instruction instead of being appended
            target_bb.insert_instruction(inst, index=len(target_bb.instructions) - 1)

    def _get_hoistable_loop(
        self, from_bb: IRBasicBlock, loop: OrderedSet[IRBasicBlock]
    ) -> list[IRInstruction]:
        """
        Collect the instructions of `loop` that can be moved into `from_bb`.

        Each instruction appears once, at its first position. A store read
        by several hoistable instructions would otherwise be queued once per
        reader, and moving it a second time would place it after its first
        reader.
        """
        result: list[IRInstruction] = []
        queued: set[int] = set()
        for bb in loop:
            for inst in self._get_hoistable_bb(bb, from_bb):
                if id(inst) in queued:
                    continue
                queued.add(id(inst))
                result.append(inst)
        return result

    def _get_hoistable_bb(self, bb: IRBasicBlock, loop_idx: IRBasicBlock) -> list[IRInstruction]:
        """
        Collect hoistable instructions of `bb`, a block of the loop keyed by
        `loop_idx`. Each one is preceded by the in-loop stores it reads.
        """
        loop = self.loop_analysis[loop_idx]
        result: list[IRInstruction] = []
        for instruction in bb.instructions:
            if not self._can_hoist_instruction_ignore_stores(instruction, loop):
                continue
            result.extend(self._store_dependencies(instruction, loop_idx))
            result.append(instruction)
        return result

    # query store dependencies of instruction (they are not handled otherwise)
    def _store_dependencies(
        self, inst: IRInstruction, loop_idx: IRBasicBlock
    ) -> list[IRInstruction]:
        """
        Return the `store` instructions inside the loop keyed by `loop_idx`
        that produce an input variable of `inst`.
        """
        # NOTE(review): the operand of the store itself is not inspected; if
        # it can be a variable produced inside the loop, moving the store
        # would place it above that definition - confirm this cannot occur.
        loop = self.loop_analysis[loop_idx]
        result: list[IRInstruction] = []
        for var in inst.get_input_variables():
            source_inst = self.dfg.get_producing_instruction(var)
            assert isinstance(source_inst, IRInstruction)
            if not _is_correct_store(source_inst):
                continue
            if any(source_inst.parent == bb for bb in loop):
                result.append(source_inst)
        return result

    # Stores are never selected here on their own: they are moved only as
    # dependencies of a hoisted instruction (see _store_dependencies), and an
    # input produced by a store does not block hoisting of its reader.
    def _can_hoist_instruction_ignore_stores(
        self, instruction: IRInstruction, loop: OrderedSet[IRBasicBlock]
    ) -> bool:
        """
        Return True if `instruction` may be moved out of `loop`: it is not
        ignored, is not a store, and none of its inputs is produced by a
        non-store instruction located in a block of `loop`.
        """
        if _ignore_instruction(instruction) or _is_correct_store(instruction):
            return False
        return not any(self._dependant_in_bb(instruction, bb) for bb in loop)

    def _dependant_in_bb(self, instruction: IRInstruction, bb: IRBasicBlock) -> bool:
        """
        Return True if some input variable of `instruction` is produced by a
        non-store instruction whose parent block is `bb`.
        """
        for in_var in instruction.get_input_variables():
            assert isinstance(in_var, IRVariable)
            source_inst = self.dfg.get_producing_instruction(in_var)
            assert source_inst is not None

            # producers that are stores are skipped here; they are moved
            # along with their reader by _store_dependencies
            if _is_correct_store(source_inst):
                continue

            if source_inst.parent == bb:
                return True
        return False
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this might be clearer as `if before in bb.cfg_in: ...`. Does that introduce any bugs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just because it checks for natural loops, so I check if it has only one input into the loop.