Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(codegen): get locals from calldata directly in external functions #244

Merged
merged 2 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 8 additions & 140 deletions codegen/src/codegen/dispatcher.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
//! Code generator for EVM dispatcher.

use crate::{
codegen::code::ExtFunc,
wasm::{self, Env, Functions, ToLSBytes},
Error, JumpTable, MacroAssembler, Result,
wasm::{self, Env, Functions},
JumpTable, MacroAssembler, Result,
};
use std::collections::BTreeMap;
use wasmparser::FuncType;
Expand Down Expand Up @@ -60,166 +59,35 @@ impl Dispatcher {
Ok(self.asm.buffer().into())
}

/// Emit return of ext function.
fn ext_return(&mut self, sig: &FuncType) -> Result<()> {
self.asm.increment_sp(1)?;
let asm = self.asm.clone();

{
self.asm.main_return(sig.results())?;
}

let bytecode = {
let jumpdest = vec![0x5b];
let ret = self.asm.buffer()[asm.buffer().len()..].to_vec();
[jumpdest, ret].concat()
};

self.asm = asm;
let ret = ExtFunc {
bytecode,
stack_in: 0,
stack_out: 0,
};
self.table.ext(self.asm.pc_offset(), ret);
Ok(())
}

// Process to the selected function.
//
// 1. drop selector.
// 2. load calldata to stack.
// 3. jump to the callee function.
//
// TODO: Parse bytes from the selector.
fn process(&mut self, len: usize, last: bool) -> Result<bool> {
let len = len as u8;
if last && len == 0 {
return Ok(false);
}

self.asm.increment_sp(1)?;
let asm = self.asm.clone();
{
if !last {
// TODO: check the safety of this.
//
// [ ret, callee, selector ] -> [ selector, ret, callee ]
self.asm.shift_stack(2, false)?;
// [ selector, ret, callee ] -> [ ret, callee ]
self.asm._drop()?;
} else {
self.asm._swap1()?;
}

if len > 0 {
// FIXME: Using the length of parameters here
// is incorrect once we have params have length
// over than 4 bytes.
//
// 1. decode the abi from signature, if contains
// bytes type, use `calldatacopy` to load the data
// on stack.
//
// 2. if the present param is a 4 bytes value, use
// `calldataload[n]` directly.
//
// Actually 1. is more closed to the common cases,
// what 4 bytes for in EVM?

// [ ret, callee ] -> [ param * len, ret, callee ]
for p in (0..len).rev() {
let offset = 4 + p * 32;
self.asm.push(&offset.to_ls_bytes())?;
self.asm._calldataload()?;
}

// [ param * len, ret, callee ] -> [ ret, param * len, callee ]
self.asm.shift_stack(len, false)?;
// [ ret, param * len, callee ] -> [ callee, ret, param * len ]
self.asm.shift_stack(len + 1, false)?;
} else {
self.asm._swap1()?;
}

self.asm._jump()?;
}

let bytecode = {
let jumpdest = vec![0x5b];
let ret = self.asm.buffer()[asm.buffer().len()..].to_vec();
[jumpdest, ret].concat()
};
self.asm = asm;
let ret = ExtFunc {
bytecode,
stack_in: len,
stack_out: 1,
};
self.table.ext(self.asm.pc_offset(), ret);
Ok(true)
}

/// Emit selector to buffer.
fn emit_selector(&mut self, selector: &wasm::Function<'_>, last: bool) -> Result<()> {
let abi = self.env.load_abi(selector)?;

// TODO: refactor this. (#206)
self.abi.push(abi.clone());

let selector_bytes = abi.selector();

tracing::trace!(
"Emitting selector {:?} for function: {}",
selector_bytes,
abi.signature(),
);

let func = self.env.query_func(&abi.name)?;
let sig = self
.funcs
.get(&func)
.ok_or(Error::FuncNotFound(func))?
.clone();

// TODO: optimize this on parameter length (#165)
{
// Prepare the `PC` of the callee function.
//
// TODO: remove this (#160)
{
self.asm.increment_sp(1)?;
self.table.call(self.asm.pc_offset(), func);
self.asm._jumpdest()?;
}
self.asm.increment_sp(1)?;

// Jump to the end of the current function.
//
// TODO: detect the bytes of the position. (#157)
self.ext_return(&sig)?;
}
// Prepare the `PC` of the callee function.
self.table.call(self.asm.pc_offset(), func);

if last {
self.asm._swap2()?;
self.asm._swap1()?;
} else {
self.asm._dup3()?;
self.asm._dup2()?;
}

self.asm.push(&selector_bytes)?;
self.asm._eq()?;
let processed = self.process(sig.params().len(), last)?;
if last && !processed {
self.asm._swap1()?;
}
self.asm._swap1()?;
self.asm._jumpi()?;

if !last {
// drop the PC of the previous callee function.
self.asm._drop()?;
// drop the PC of the previous callee function preprocessor.
self.asm._drop()?;
}

Ok(())
}
}
12 changes: 10 additions & 2 deletions codegen/src/codegen/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ impl Function {
params_count = ty.params().len() as u8;
}

let is_external = abi.is_some();
let mut codegen = Self {
abi,
backtrace: Backtrace::default(),
Expand All @@ -54,8 +55,15 @@ impl Function {
is_main,
};

if is_main {
return Ok(codegen);
}

// post process program counter and stack pointer.
if !is_main {
if is_external {
codegen.masm.increment_sp(1)?;
codegen.masm._jumpdest()?;
} else {
// Mock the stack frame for the callee function
//
// STACK: PC + params
Expand Down Expand Up @@ -129,7 +137,7 @@ impl Function {
/// Finish code generation.
pub fn finish(self, jump_table: &mut JumpTable, pc: u16) -> Result<Buffer> {
let sp = self.masm.sp();
if !self.is_main && self.masm.sp() != self.ty.results().len() as u8 {
if !self.is_main && self.abi.is_none() && self.masm.sp() != self.ty.results().len() as u8 {
return Err(Error::StackNotBalanced(sp));
}

Expand Down
9 changes: 5 additions & 4 deletions codegen/src/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ impl Locals {
let offset = if local.ty() == &LocalSlotType::Parameter {
self.inner[..index].iter().fold(0, |acc, x| acc + x.align())
} else {
self.inner[..index]
.iter()
.filter(|x| x.ty() == &LocalSlotType::Variable)
.fold(0, |acc, x| acc + x.align())
panic!("This should never be reached");
// self.inner[..index]
// .iter()
// .filter(|x| x.ty() == &LocalSlotType::Variable)
// .fold(0, |acc, x| acc + x.align())
}
.to_ls_bytes()
.to_vec()
Expand Down
5 changes: 5 additions & 0 deletions codegen/src/visitor/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ impl Function {

/// The call instruction calls a function specified by its index.
pub fn _call(&mut self, index: u32) -> Result<()> {
if self.env.is_external(index) {
// TODO: throw with error
panic!("External functions could not be called internally");
}

if self.env.imports.len() as u32 > index {
self.call_imported(index)
} else {
Expand Down
8 changes: 4 additions & 4 deletions codegen/src/visitor/control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ impl Function {
pub fn _end(&mut self) -> Result<()> {
if let Ok(frame) = self.control.pop() {
self.handle_frame_popping(frame)
} else if !self.is_main {
tracing::trace!("end of call");
self.handle_call_return()
} else {
} else if self.is_main || self.abi.is_some() {
tracing::trace!("end of main function");
self.handle_return()
} else {
tracing::trace!("end of call");
self.handle_call_return()
}
}

Expand Down
10 changes: 7 additions & 3 deletions codegen/src/visitor/local.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//! Local instructions

use crate::{Error, Function, Result};
use crate::{wasm::ToLSBytes, Error, Function, Result};

impl Function {
/// This instruction gets the value of a variable.
pub fn _local_get(&mut self, local_index: u32) -> Result<()> {
let local_index = local_index as usize;
if self.is_main && local_index < self.ty.params().len() {
if (self.is_main && local_index < self.ty.params().len()) || self.abi.is_some() {
// Parsing data from selector.
self._local_get_calldata(local_index)
} else {
Expand Down Expand Up @@ -49,7 +49,11 @@ impl Function {

/// Local get from calldata.
fn _local_get_calldata(&mut self, local_index: usize) -> Result<()> {
let offset = self.locals.offset_of(local_index)?;
let mut offset = self.locals.offset_of(local_index)?;
if self.abi.is_some() {
offset = (4 + local_index * 32).to_ls_bytes().to_vec().into();
}

self.masm.push(&offset)?;
self.masm._calldataload()?;

Expand Down
10 changes: 10 additions & 0 deletions codegen/src/wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,16 @@ impl Env {
Err(Error::FuncNotImported(name.into()))
}

/// Check if the input function is external function
pub fn is_external(&self, index: u32) -> bool {
let Some(name) = self.exports.get(&index) else {
return false;
};

let selector = name.to_owned() + "_selector";
self.exports.iter().any(|(_, n)| **n == selector)
}

/// If the present function index is the main function
pub fn is_main(&self, index: u32) -> bool {
self.imports.len() as u32 == index
Expand Down
1 change: 1 addition & 0 deletions compiler/filetests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ impl Test {
let mut compiler = zinkc::Compiler::default();
// TODO: after #166
if name == "fibonacci" {
// return Ok(());
compiler.config = compiler.config.dispatcher(true);
}
compiler.compile(&wasm)?;
Expand Down
5 changes: 5 additions & 0 deletions examples/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ extern crate zink;
/// Calculates the nth fibonacci number.
#[zink::external]
pub fn fib(n: u64) -> u64 {
internal_rec(n)
}

#[inline(never)]
fn internal_rec(n: u64) -> u64 {
if n < 2 {
n
} else {
Expand Down
Loading