Skip to content

Commit

Permalink
refactor(codegen): get locals from calldata directly in external func…
Browse files Browse the repository at this point in the history
…tions (#244)

* feat(codegen): get locals of exported functions from calldata directly

* refactor(codegen): remove parameter processor in dispatcher
  • Loading branch information
clearloop authored Oct 6, 2024
1 parent 9e4f05e commit 57ce7f7
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 153 deletions.
148 changes: 8 additions & 140 deletions codegen/src/codegen/dispatcher.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
//! Code generator for EVM dispatcher.
use crate::{
codegen::code::ExtFunc,
wasm::{self, Env, Functions, ToLSBytes},
Error, JumpTable, MacroAssembler, Result,
wasm::{self, Env, Functions},
JumpTable, MacroAssembler, Result,
};
use std::collections::BTreeMap;
use wasmparser::FuncType;
Expand Down Expand Up @@ -60,166 +59,35 @@ impl Dispatcher {
Ok(self.asm.buffer().into())
}

/// Register the return sequence of an external function.
///
/// The sequence produced by `main_return` for the results of `sig` is only
/// emitted temporarily: the freshly written bytes are captured, prefixed
/// with a jump destination byte (`0x5b`), and the assembler is rolled back
/// to the state it had before the emission. The captured bytes are then
/// recorded in the jump table as an `ExtFunc` keyed by the current program
/// counter offset.
fn ext_return(&mut self, sig: &FuncType) -> Result<()> {
    self.asm.increment_sp(1)?;

    // Snapshot taken after the stack pointer bump, so the bump survives
    // the rollback below.
    let snapshot = self.asm.clone();
    let emitted_from = snapshot.buffer().len();

    self.asm.main_return(sig.results())?;

    // Jump destination byte followed by everything written since the snapshot.
    let mut bytecode = vec![0x5b];
    bytecode.extend_from_slice(&self.asm.buffer()[emitted_from..]);

    self.asm = snapshot;
    self.table.ext(
        self.asm.pc_offset(),
        ExtFunc {
            bytecode,
            stack_in: 0,
            stack_out: 0,
        },
    );
    Ok(())
}

// Process to the selected function.
//
// 1. drop selector.
// 2. load calldata to stack.
// 3. jump to the callee function.
//
// TODO: Parse bytes from the selector.
fn process(&mut self, len: usize, last: bool) -> Result<bool> {
let len = len as u8;
if last && len == 0 {
return Ok(false);
}

self.asm.increment_sp(1)?;
let asm = self.asm.clone();
{
if !last {
// TODO: check the safety of this.
//
// [ ret, callee, selector ] -> [ selector, ret, callee ]
self.asm.shift_stack(2, false)?;
// [ selector, ret, callee ] -> [ ret, callee ]
self.asm._drop()?;
} else {
self.asm._swap1()?;
}

if len > 0 {
// FIXME: Using the length of parameters here
// is incorrect once we have params have length
// over than 4 bytes.
//
// 1. decode the abi from signature, if contains
// bytes type, use `calldatacopy` to load the data
// on stack.
//
// 2. if the present param is a 4 bytes value, use
// `calldataload[n]` directly.
//
// Actually 1. is more closed to the common cases,
// what 4 bytes for in EVM?

// [ ret, callee ] -> [ param * len, ret, callee ]
for p in (0..len).rev() {
let offset = 4 + p * 32;
self.asm.push(&offset.to_ls_bytes())?;
self.asm._calldataload()?;
}

// [ param * len, ret, callee ] -> [ ret, param * len, callee ]
self.asm.shift_stack(len, false)?;
// [ ret, param * len, callee ] -> [ callee, ret, param * len ]
self.asm.shift_stack(len + 1, false)?;
} else {
self.asm._swap1()?;
}

self.asm._jump()?;
}

let bytecode = {
let jumpdest = vec![0x5b];
let ret = self.asm.buffer()[asm.buffer().len()..].to_vec();
[jumpdest, ret].concat()
};
self.asm = asm;
let ret = ExtFunc {
bytecode,
stack_in: len,
stack_out: 1,
};
self.table.ext(self.asm.pc_offset(), ret);
Ok(true)
}

/// Emit selector to buffer.
///
/// Loads the ABI of `selector`, records it in `self.abi`, resolves the
/// callee by `abi.name`, and emits the comparison of the selector bytes
/// followed by a conditional jump.
///
/// When `last` is set the stack is prepared with swaps instead of dups and
/// the trailing drops are skipped.
fn emit_selector(&mut self, selector: &wasm::Function<'_>, last: bool) -> Result<()> {
let abi = self.env.load_abi(selector)?;

// TODO: refactor this. (#206)
self.abi.push(abi.clone());

let selector_bytes = abi.selector();

tracing::trace!(
"Emitting selector {:?} for function: {}",
selector_bytes,
abi.signature(),
);

let func = self.env.query_func(&abi.name)?;
// Signature of the callee; fails with `Error::FuncNotFound` when `func`
// has no entry in `self.funcs`.
let sig = self
.funcs
.get(&func)
.ok_or(Error::FuncNotFound(func))?
.clone();

// TODO: optimize this on parameter length (#165)
{
// Prepare the `PC` of the callee function.
//
// TODO: remove this (#160)
{
self.asm.increment_sp(1)?;
self.table.call(self.asm.pc_offset(), func);
self.asm._jumpdest()?;
}
self.asm.increment_sp(1)?;

// Jump to the end of the current function.
//
// TODO: detect the bytes of the position. (#157)
self.ext_return(&sig)?;
}
// Prepare the `PC` of the callee function.
self.table.call(self.asm.pc_offset(), func);

// NOTE(review): the exact stack layout expected at this point is not
// visible from this function alone — confirm against the caller.
if last {
self.asm._swap2()?;
self.asm._swap1()?;
} else {
self.asm._dup3()?;
self.asm._dup2()?;
}

// Compare the value on the stack against this function's selector bytes.
self.asm.push(&selector_bytes)?;
self.asm._eq()?;
// `processed` is `false` only when `last` is set and the callee takes no
// parameters; in that case one extra swap is emitted.
let processed = self.process(sig.params().len(), last)?;
if last && !processed {
self.asm._swap1()?;
}
self.asm._swap1()?;
self.asm._jumpi()?;

if !last {
// drop the PC of the previous callee function.
self.asm._drop()?;
// drop the PC of the previous callee function preprocessor.
self.asm._drop()?;
}

Ok(())
}
}
12 changes: 10 additions & 2 deletions codegen/src/codegen/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ impl Function {
params_count = ty.params().len() as u8;
}

let is_external = abi.is_some();
let mut codegen = Self {
abi,
backtrace: Backtrace::default(),
Expand All @@ -54,8 +55,15 @@ impl Function {
is_main,
};

if is_main {
return Ok(codegen);
}

// post process program counter and stack pointer.
if !is_main {
if is_external {
codegen.masm.increment_sp(1)?;
codegen.masm._jumpdest()?;
} else {
// Mock the stack frame for the callee function
//
// STACK: PC + params
Expand Down Expand Up @@ -129,7 +137,7 @@ impl Function {
/// Finish code generation.
pub fn finish(self, jump_table: &mut JumpTable, pc: u16) -> Result<Buffer> {
let sp = self.masm.sp();
if !self.is_main && self.masm.sp() != self.ty.results().len() as u8 {
if !self.is_main && self.abi.is_none() && self.masm.sp() != self.ty.results().len() as u8 {
return Err(Error::StackNotBalanced(sp));
}

Expand Down
9 changes: 5 additions & 4 deletions codegen/src/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ impl Locals {
let offset = if local.ty() == &LocalSlotType::Parameter {
self.inner[..index].iter().fold(0, |acc, x| acc + x.align())
} else {
self.inner[..index]
.iter()
.filter(|x| x.ty() == &LocalSlotType::Variable)
.fold(0, |acc, x| acc + x.align())
panic!("This should never be reached");
// self.inner[..index]
// .iter()
// .filter(|x| x.ty() == &LocalSlotType::Variable)
// .fold(0, |acc, x| acc + x.align())
}
.to_ls_bytes()
.to_vec()
Expand Down
5 changes: 5 additions & 0 deletions codegen/src/visitor/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ impl Function {

/// The call instruction calls a function specified by its index.
pub fn _call(&mut self, index: u32) -> Result<()> {
if self.env.is_external(index) {
// TODO: throw with error
panic!("External functions could not be called internally");
}

if self.env.imports.len() as u32 > index {
self.call_imported(index)
} else {
Expand Down
8 changes: 4 additions & 4 deletions codegen/src/visitor/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ impl Function {
pub fn _end(&mut self) -> Result<()> {
// If a control frame can be popped, this `end` belongs to that frame.
if let Ok(frame) = self.control.pop() {
self.handle_frame_popping(frame)
} else if !self.is_main {
tracing::trace!("end of call");
self.handle_call_return()
} else {
// The main function and functions carrying an ABI share the same return
// path; the trace message below is emitted for both.
} else if self.is_main || self.abi.is_some() {
tracing::trace!("end of main function");
self.handle_return()
} else {
// Everything else returns through the call-return path.
tracing::trace!("end of call");
self.handle_call_return()
}
}

Expand Down
10 changes: 7 additions & 3 deletions codegen/src/visitor/local.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//! Local instructions
use crate::{Error, Function, Result};
use crate::{wasm::ToLSBytes, Error, Function, Result};

impl Function {
/// This instruction gets the value of a variable.
pub fn _local_get(&mut self, local_index: u32) -> Result<()> {
let local_index = local_index as usize;
if self.is_main && local_index < self.ty.params().len() {
if (self.is_main && local_index < self.ty.params().len()) || self.abi.is_some() {
// Parsing data from selector.
self._local_get_calldata(local_index)
} else {
Expand Down Expand Up @@ -49,7 +49,11 @@ impl Function {

/// Local get from calldata.
fn _local_get_calldata(&mut self, local_index: usize) -> Result<()> {
let offset = self.locals.offset_of(local_index)?;
let mut offset = self.locals.offset_of(local_index)?;
if self.abi.is_some() {
offset = (4 + local_index * 32).to_ls_bytes().to_vec().into();
}

self.masm.push(&offset)?;
self.masm._calldataload()?;

Expand Down
10 changes: 10 additions & 0 deletions codegen/src/wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ impl Env {
Err(Error::FuncNotImported(name.into()))
}

/// Tell whether the function at `index` is an external function.
///
/// A function counts as external when it is exported and another export
/// exists whose name is this function's export name followed by
/// `_selector`. Functions that are not exported are never external.
pub fn is_external(&self, index: u32) -> bool {
    match self.exports.get(&index) {
        Some(name) => {
            let expected = name.to_owned() + "_selector";
            self.exports
                .iter()
                .any(|(_, export)| **export == expected)
        }
        None => false,
    }
}

/// If the present function index is the main function
pub fn is_main(&self, index: u32) -> bool {
self.imports.len() as u32 == index
Expand Down
1 change: 1 addition & 0 deletions compiler/filetests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ impl Test {
let mut compiler = zinkc::Compiler::default();
// TODO: after #166
if name == "fibonacci" {
// return Ok(());
compiler.config = compiler.config.dispatcher(true);
}
compiler.compile(&wasm)?;
Expand Down
5 changes: 5 additions & 0 deletions examples/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ extern crate zink;
/// Calculates the nth fibonacci number.
///
/// This is the external entry point (`#[zink::external]`); the computation
/// itself is delegated to `internal_rec`.
#[zink::external]
pub fn fib(n: u64) -> u64 {
internal_rec(n)
}

#[inline(never)]
fn internal_rec(n: u64) -> u64 {
if n < 2 {
n
} else {
Expand Down

0 comments on commit 57ce7f7

Please sign in to comment.