Skip to content

Commit

Permalink
feat(codegen): test jump targets with mocked dispatcher environment
Browse files Browse the repository at this point in the history
  • Loading branch information
clearloop committed Nov 28, 2024
1 parent 71d4811 commit ea0e2db
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 19 deletions.
61 changes: 54 additions & 7 deletions codegen/src/jump/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,23 @@ mod tests {
}

#[test]
fn test_nested_internal_calls() -> anyhow::Result<()> {
fn test_nested_function_calls() -> anyhow::Result<()> {
let mut table = JumpTable::default();

// Simulate transfer_from calling both _spend_allowance and _transfer
table.register(0x10, Jump::Label(0x100)); // transfer_from -> _spend_allowance
table.register(0x100, Jump::Label(0x200)); // _spend_allowance -> _approve
table.register(0x20, Jump::Label(0x300)); // transfer_from -> _transfer
table.register(0x300, Jump::Label(0x400)); // _transfer -> _update
// Simulate ERC20's approve -> _approve call chain
table.register(0x100, Jump::Label(0x200)); // approve entry
table.register(0x110, Jump::Label(0x300)); // approve -> _approve
table.register(0x200, Jump::Label(0x400)); // _approve entry

assert_target_shift_vs_relocation(table)
let mut buffer = smallvec![0; table.max_target() as usize];
table.relocate(&mut buffer)?;

// Check if all jumps use correct PUSH instructions
assert_eq!(buffer[0x100], 0x61); // PUSH2
assert_eq!(buffer[0x113], 0x61); // PUSH2
assert_eq!(buffer[0x206], 0x61); // PUSH2

Ok(())
}

#[test]
Expand Down Expand Up @@ -189,4 +196,44 @@ mod tests {

Ok(())
}

#[test]
fn test_dispatcher_jump_targets() -> anyhow::Result<()> {
let mut table = JumpTable::default();
let selectors = 5;

// Register jumps for each selector check
for i in 0..selectors {
let i = i as u16;
let check_pc = 0x10 + i * 0x20;
let target_pc = 0x100 + i * 0x40;

// Register both the comparison jump and function jump
table.register(check_pc, Jump::Label(check_pc + 0x10));
table.register(check_pc + 0x10, Jump::Label(target_pc));
}

let mut buffer = smallvec![0; table.max_target() as usize];
table.relocate(&mut buffer)?;

// Verify each selector's jump chain
let mut total_offset = 0;
for i in 0..selectors {
let check_pc = 0x10 + i * 0x20 + total_offset;
let check_pc_offset = if check_pc + 0x10 > 0xff { 3 } else { 2 };

let func_pc = check_pc + 0x10 + check_pc_offset;

let check_jump = buffer[check_pc];
let func_jump = buffer[func_pc];

assert_eq!(check_jump, if func_pc > 0xff { 0x61 } else { 0x60 });
assert_eq!(func_jump, 0x61);

// Update total offset for next iteration
total_offset += check_pc_offset + 3;
}

Ok(())
}
}
28 changes: 16 additions & 12 deletions codegen/src/jump/target.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ impl JumpTable {
pub fn shift_targets(&mut self) -> Result<()> {
let mut total_offset = 0;
let mut target_sizes = Vec::new();
let jumps = self.jump.clone();

// First pass: calculate all target sizes
for (original_pc, jump) in self.jump.clone().iter() {
// First pass: calculate all target sizes and accumulate offsets
for (original_pc, jump) in jumps.iter() {
let pc = original_pc + total_offset;
let target = self.target(jump)? + total_offset;

// Calculate instruction size based on target value
// Calculate instruction size based on absolute target value
let instr_size = if target > 0xff {
3 // PUSH2 + 2 bytes
} else {
Expand All @@ -45,7 +46,7 @@ impl JumpTable {
total_offset += instr_size;
}

// Second pass: apply shifts with correct accumulated offsets
// Second pass: apply shifts with accumulated offsets
total_offset = 0;
for (pc, size) in target_sizes {
tracing::debug!("shift target at pc=0x{pc:x} with size={size}");
Expand All @@ -61,17 +62,19 @@ impl JumpTable {
/// This function handles the shifting of the code section, label targets, and
/// function targets.
pub fn shift_target(&mut self, ptr: u16, offset: u16) -> Result<()> {
// First shift the code section
self.code.shift(offset);

// Only shift targets that are after ptr
self.shift_label_target(ptr, offset)?;
self.shift_func_target(ptr, offset)
}

/// Shifts the program counter for functions.
pub fn shift_func_target(&mut self, ptr: u16, offset: u16) -> Result<()> {
self.func.iter_mut().try_for_each(|(index, target)| {
let next_target = *target + offset;

if *target > ptr {
let next_target = *target + offset;
tracing::trace!(
"shift Func({index}) target with offset={offset}: 0x{target:x}(0x{ptr:x}) -> 0x{:x}",
next_target
Expand All @@ -91,13 +94,14 @@ impl JumpTable {
continue;
};

let next_target = *target + offset;
if *target > ptr {
// Only shift targets that come after ptr AND
// only if the jump instruction itself comes after ptr
if *target > ptr && *pc >= ptr {
let next_target = *target + offset;
tracing::trace!(
"shift Label(0x{pc:x}) target with offset={offset}: 0x{target:x}(0x{ptr:x}) -> 0x{:x}",
next_target,
);

"shift Label target with offset={offset}: 0x{target:x}(0x{ptr:x}) -> 0x{:x}",
next_target,
);
*target = next_target;
}
}
Expand Down

0 comments on commit ea0e2db

Please sign in to comment.