Skip to content

Commit

Permalink
Also use variables for known inputs in machine calls. (#2315)
Browse files Browse the repository at this point in the history
This introduces variables not only for outputs but only for inputs. This
allows us to back-propagate range constraints we get from the
sub-machine call also to inputs.
  • Loading branch information
chriseth authored Jan 9, 2025
1 parent 9c6315e commit 40aa582
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 116 deletions.
29 changes: 21 additions & 8 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ impl<T: FieldElement> FixedEvaluator<T> for &BlockMachineProcessor<'_, T> {
mod test {
use std::fs::read_to_string;

use pretty_assertions::assert_eq;
use test_log::test;

use powdr_number::GoldilocksField;
Expand Down Expand Up @@ -420,8 +421,11 @@ main_binary::B_byte[2] = ((main_binary::B[3] & 4278190080) // 16777216);
main_binary::B[2] = (main_binary::B[3] & 16777215);
assert (main_binary::B[3] & 18446744069414584320) == 0;
main_binary::operation_id_next[2] = main_binary::operation_id[3];
machine_call(9, [Known(main_binary::operation_id_next[2]), Known(main_binary::A_byte[2]), Known(main_binary::B_byte[2]), Unknown(ret(9, 2, 3))]);
main_binary::C_byte[2] = ret(9, 2, 3);
call_var(9, 2, 0) = main_binary::operation_id_next[2];
call_var(9, 2, 1) = main_binary::A_byte[2];
call_var(9, 2, 2) = main_binary::B_byte[2];
machine_call(9, [Known(call_var(9, 2, 0)), Known(call_var(9, 2, 1)), Known(call_var(9, 2, 2)), Unknown(call_var(9, 2, 3))]);
main_binary::C_byte[2] = call_var(9, 2, 3);
main_binary::operation_id[1] = main_binary::operation_id[2];
main_binary::A_byte[1] = ((main_binary::A[2] & 16711680) // 65536);
main_binary::A[1] = (main_binary::A[2] & 65535);
Expand All @@ -430,8 +434,11 @@ main_binary::B_byte[1] = ((main_binary::B[2] & 16711680) // 65536);
main_binary::B[1] = (main_binary::B[2] & 65535);
assert (main_binary::B[2] & 18446744073692774400) == 0;
main_binary::operation_id_next[1] = main_binary::operation_id[2];
machine_call(9, [Known(main_binary::operation_id_next[1]), Known(main_binary::A_byte[1]), Known(main_binary::B_byte[1]), Unknown(ret(9, 1, 3))]);
main_binary::C_byte[1] = ret(9, 1, 3);
call_var(9, 1, 0) = main_binary::operation_id_next[1];
call_var(9, 1, 1) = main_binary::A_byte[1];
call_var(9, 1, 2) = main_binary::B_byte[1];
machine_call(9, [Known(call_var(9, 1, 0)), Known(call_var(9, 1, 1)), Known(call_var(9, 1, 2)), Unknown(call_var(9, 1, 3))]);
main_binary::C_byte[1] = call_var(9, 1, 3);
main_binary::operation_id[0] = main_binary::operation_id[1];
main_binary::A_byte[0] = ((main_binary::A[1] & 65280) // 256);
main_binary::A[0] = (main_binary::A[1] & 255);
Expand All @@ -440,13 +447,19 @@ main_binary::B_byte[0] = ((main_binary::B[1] & 65280) // 256);
main_binary::B[0] = (main_binary::B[1] & 255);
assert (main_binary::B[1] & 18446744073709486080) == 0;
main_binary::operation_id_next[0] = main_binary::operation_id[1];
machine_call(9, [Known(main_binary::operation_id_next[0]), Known(main_binary::A_byte[0]), Known(main_binary::B_byte[0]), Unknown(ret(9, 0, 3))]);
main_binary::C_byte[0] = ret(9, 0, 3);
call_var(9, 0, 0) = main_binary::operation_id_next[0];
call_var(9, 0, 1) = main_binary::A_byte[0];
call_var(9, 0, 2) = main_binary::B_byte[0];
machine_call(9, [Known(call_var(9, 0, 0)), Known(call_var(9, 0, 1)), Known(call_var(9, 0, 2)), Unknown(call_var(9, 0, 3))]);
main_binary::C_byte[0] = call_var(9, 0, 3);
main_binary::A_byte[-1] = main_binary::A[0];
main_binary::B_byte[-1] = main_binary::B[0];
main_binary::operation_id_next[-1] = main_binary::operation_id[0];
machine_call(9, [Known(main_binary::operation_id_next[-1]), Known(main_binary::A_byte[-1]), Known(main_binary::B_byte[-1]), Unknown(ret(9, -1, 3))]);
main_binary::C_byte[-1] = ret(9, -1, 3);
call_var(9, -1, 0) = main_binary::operation_id_next[-1];
call_var(9, -1, 1) = main_binary::A_byte[-1];
call_var(9, -1, 2) = main_binary::B_byte[-1];
machine_call(9, [Known(call_var(9, -1, 0)), Known(call_var(9, -1, 1)), Known(call_var(9, -1, 2)), Unknown(call_var(9, -1, 3))]);
main_binary::C_byte[-1] = call_var(9, -1, 3);
main_binary::C[0] = main_binary::C_byte[-1];
main_binary::C[1] = (main_binary::C[0] + (main_binary::C_byte[0] * 256));
main_binary::C[2] = (main_binary::C[1] + (main_binary::C_byte[1] * 65536));
Expand Down
87 changes: 44 additions & 43 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use powdr_number::{FieldElement, KnownField};

use crate::witgen::{
data_structures::{finalizable_data::CompactDataRef, mutable_state::MutableState},
jit::effect::MachineCallArgument,
machines::{
profiling::{record_end, record_start},
LookupCell,
Expand Down Expand Up @@ -150,8 +149,8 @@ fn witgen_code<T: FieldElement>(
format!("get(data, row_offset, {}, {})", c.row_offset, c.id)
}
Variable::Param(i) => format!("get_param(params, {i})"),
Variable::MachineCallReturnValue(_) => {
unreachable!("Machine call return values should not be pre-known.")
Variable::MachineCallParam(_) => {
unreachable!("Machine call variables should not be pre-known.")
}
};
format!(" let {var_name} = {value};")
Expand All @@ -173,7 +172,7 @@ fn witgen_code<T: FieldElement>(
cell.row_offset, cell.id,
)),
Variable::Param(i) => Some(format!(" set_param(params, {i}, {value});")),
Variable::MachineCallReturnValue(_) => {
Variable::MachineCallParam(_) => {
// This is just an internal variable.
None
}
Expand All @@ -186,7 +185,7 @@ fn witgen_code<T: FieldElement>(
.iter()
.filter_map(|var| match var {
Variable::Cell(cell) => Some(cell),
Variable::Param(_) | Variable::MachineCallReturnValue(_) => None,
Variable::Param(_) | Variable::MachineCallParam(_) => None,
})
.map(|cell| {
format!(
Expand Down Expand Up @@ -234,10 +233,11 @@ fn written_vars_in_effect<T: FieldElement>(
Effect::Assignment(var, _) => Box::new(iter::once((var, false))),
Effect::RangeConstraint(..) => unreachable!(),
Effect::Assertion(..) => Box::new(iter::empty()),
Effect::MachineCall(_, arguments) => Box::new(arguments.iter().flat_map(|e| match e {
MachineCallArgument::Unknown(v) => Some((v, true)),
MachineCallArgument::Known(_) => None,
})),
Effect::MachineCall(_, known, vars) => Box::new(
vars.iter()
.zip_eq(known)
.flat_map(|(v, known)| (!known).then_some((v, true))),
),
Effect::Branch(_, first, second) => Box::new(
first
.iter()
Expand Down Expand Up @@ -287,21 +287,21 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
if *expected_equal { "==" } else { "!=" },
format_expression(rhs)
),
Effect::MachineCall(id, arguments) => {
Effect::MachineCall(id, known, vars) => {
let mut result_vars = vec![];
let args = arguments
let args = vars
.iter()
.map(|a| match a {
MachineCallArgument::Unknown(v) => {
let var_name = variable_to_string(v);
.zip_eq(known)
.map(|(v, known)| {
let var_name = variable_to_string(v);
if known {
format!("LookupCell::Input(&{var_name})")
} else {
if is_top_level {
result_vars.push(var_name.clone());
}
format!("LookupCell::Output(&mut {var_name})")
}
MachineCallArgument::Known(v) => {
format!("LookupCell::Input(&{})", format_expression(v))
}
})
.format(", ")
.to_string();
Expand Down Expand Up @@ -396,12 +396,12 @@ fn variable_to_string(v: &Variable) -> String {
format_row_offset(cell.row_offset)
),
Variable::Param(i) => format!("p_{i}"),
Variable::MachineCallReturnValue(ret) => {
Variable::MachineCallParam(call_var) => {
format!(
"ret_{}_{}_{}",
ret.identity_id,
format_row_offset(ret.row_offset),
ret.index
"call_var_{}_{}_{}",
call_var.identity_id,
format_row_offset(call_var.row_offset),
call_var.index
)
}
}
Expand Down Expand Up @@ -483,13 +483,14 @@ fn util_code<T: FieldElement>(first_column_id: u64, column_count: usize) -> Resu

#[cfg(test)]
mod tests {

use pretty_assertions::assert_eq;
use test_log::test;

use powdr_number::GoldilocksField;

use crate::witgen::jit::variable::Cell;
use crate::witgen::jit::variable::MachineCallReturnVariable;
use crate::witgen::jit::variable::MachineCallVariable;
use crate::witgen::range_constraints::RangeConstraint;

use super::*;
Expand Down Expand Up @@ -518,8 +519,8 @@ mod tests {
Variable::Param(i)
}

fn ret_val(identity_id: u64, row_offset: i32, index: usize) -> Variable {
Variable::MachineCallReturnValue(MachineCallReturnVariable {
fn call_var(identity_id: u64, row_offset: i32, index: usize) -> Variable {
Variable::MachineCallParam(MachineCallVariable {
identity_id,
row_offset,
index,
Expand Down Expand Up @@ -547,15 +548,15 @@ mod tests {
let x0 = cell("x", 0, 0);
let ym1 = cell("y", 1, -1);
let yp1 = cell("y", 1, 1);
let r1 = ret_val(7, 1, 1);
let cv1 = call_var(7, 1, 0);
let r1 = call_var(7, 1, 1);
let effects = vec![
assignment(&x0, number(7) * symbol(&a0)),
assignment(&cv1, symbol(&x0)),
Effect::MachineCall(
7,
vec![
MachineCallArgument::Unknown(r1.clone()),
MachineCallArgument::Known(symbol(&x0)),
],
[false, true].into_iter().collect(),
vec![r1.clone(), cv1.clone()],
),
assignment(&ym1, symbol(&r1)),
assignment(&yp1, symbol(&ym1) + symbol(&x0)),
Expand All @@ -568,8 +569,8 @@ mod tests {
let known_inputs = vec![a0.clone()];
let code = witgen_code(&known_inputs, &effects);
assert_eq!(
code,
"
code,
"
#[no_mangle]
extern \"C\" fn witgen(
WitgenFunctionParams{
Expand All @@ -588,9 +589,10 @@ extern \"C\" fn witgen(
let c_a_2_0 = get(data, row_offset, 0, 2);
let c_x_0_0 = (FieldElement::from(7) * c_a_2_0);
let mut ret_7_1_1 = FieldElement::default();
assert!(call_machine(mutable_state, 7, MutSlice::from((&mut [LookupCell::Output(&mut ret_7_1_1), LookupCell::Input(&c_x_0_0)]).as_mut_slice())));
let c_y_1_m1 = ret_7_1_1;
let call_var_7_1_0 = c_x_0_0;
let mut call_var_7_1_1 = FieldElement::default();
assert!(call_machine(mutable_state, 7, MutSlice::from((&mut [LookupCell::Output(&mut call_var_7_1_1), LookupCell::Input(&call_var_7_1_0)]).as_mut_slice())));
let c_y_1_m1 = call_var_7_1_1;
let c_y_1_1 = (c_y_1_m1 + c_x_0_0);
assert!(c_y_1_m1 == c_x_0_0);
Expand All @@ -603,7 +605,7 @@ extern \"C\" fn witgen(
set_known(known, row_offset, 1, 1);
}
"
);
);
}

extern "C" fn no_call_machine(
Expand Down Expand Up @@ -817,16 +819,15 @@ extern \"C\" fn witgen(
fn submachine_calls() {
let x = cell("x", 0, 0);
let y = cell("y", 1, 0);
let r1 = ret_val(7, 0, 1);
let r2 = ret_val(7, 0, 2);
let v1 = call_var(7, 0, 0);
let r1 = call_var(7, 0, 1);
let r2 = call_var(7, 0, 2);
let effects = vec![
Effect::Assignment(v1.clone(), number(7)),
Effect::MachineCall(
7,
vec![
MachineCallArgument::Known(number(7)),
MachineCallArgument::Unknown(r1.clone()),
MachineCallArgument::Unknown(r2.clone()),
],
[true, false, false].into_iter().collect(),
vec![v1, r1.clone(), r2.clone()],
),
Effect::Assignment(x.clone(), symbol(&r1)),
Effect::Assignment(y.clone(), symbol(&r2)),
Expand Down
23 changes: 10 additions & 13 deletions executor/src/witgen/jit/effect.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::cmp::Ordering;

use bit_vec::BitVec;
use itertools::Itertools;
use powdr_ast::indent;
use powdr_number::FieldElement;
Expand All @@ -17,8 +18,8 @@ pub enum Effect<T: FieldElement, V> {
RangeConstraint(V, RangeConstraint<T>),
/// A run-time assertion. If this fails, we have conflicting constraints.
Assertion(Assertion<T, V>),
/// A call to a different machine.
MachineCall(u64, Vec<MachineCallArgument<T, V>>),
/// A call to a different machine, with identity ID, known inputs and argument variables.
MachineCall(u64, BitVec, Vec<V>),
/// A branch on a variable.
Branch(BranchCondition<T, V>, Vec<Effect<T, V>>, Vec<Effect<T, V>>),
}
Expand Down Expand Up @@ -59,12 +60,6 @@ impl<T: FieldElement, V> Assertion<T, V> {
}
}

#[derive(Clone, PartialEq, Eq)]
pub enum MachineCallArgument<T: FieldElement, V> {
Known(SymbolicExpression<T, V>),
Unknown(V),
}

#[derive(Clone, PartialEq, Eq)]
pub struct BranchCondition<T: FieldElement, V> {
pub variable: V,
Expand All @@ -88,13 +83,15 @@ pub fn format_code<T: FieldElement>(effects: &[Effect<T, Variable>]) -> String {
if *expected_equal { "==" } else { "!=" }
)
}
Effect::MachineCall(id, args) => {
Effect::MachineCall(id, known, vars) => {
format!(
"machine_call({id}, [{}]);",
args.iter()
.map(|arg| match arg {
MachineCallArgument::Known(k) => format!("Known({k})"),
MachineCallArgument::Unknown(u) => format!("Unknown({u})"),
vars.iter()
.zip(known)
.map(|(v, known)| if known {
format!("Known({v})")
} else {
format!("Unknown({v})")
})
.join(", ")
)
Expand Down
7 changes: 4 additions & 3 deletions executor/src/witgen/jit/single_step_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,10 @@ mod test {
format_code(&code),
"\
VM::pc[1] = (VM::pc[0] + 1);
machine_call(1, [Known(VM::pc[1]), Unknown(ret(1, 1, 1)), Unknown(ret(1, 1, 2))]);
VM::instr_add[1] = ret(1, 1, 1);
VM::instr_mul[1] = ret(1, 1, 2);
call_var(1, 1, 0) = VM::pc[1];
machine_call(1, [Known(call_var(1, 1, 0)), Unknown(call_var(1, 1, 1)), Unknown(call_var(1, 1, 2))]);
VM::instr_add[1] = call_var(1, 1, 1);
VM::instr_mul[1] = call_var(1, 1, 2);
VM::B[1] = VM::B[0];
if (VM::instr_add[0] == 1) {
if (VM::instr_mul[0] == 1) {
Expand Down
21 changes: 6 additions & 15 deletions executor/src/witgen/jit/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ use std::{
};

use powdr_ast::analyzed::{AlgebraicReference, PolyID, PolynomialType};
use powdr_number::FieldElement;

use super::effect::MachineCallArgument;

#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Debug)]
/// A variable that can be used in the inference engine.
Expand All @@ -16,20 +13,20 @@ pub enum Variable {
/// A parameter (input or output) of the machine.
#[allow(dead_code)]
Param(usize),
/// The return value of a machine call on a certain
/// An input or output value of a machine call on a certain
/// identity on a certain row offset.
MachineCallReturnValue(MachineCallReturnVariable),
MachineCallParam(MachineCallVariable),
}

impl Display for Variable {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Variable::Cell(cell) => write!(f, "{cell}"),
Variable::Param(i) => write!(f, "params[{i}]"),
Variable::MachineCallReturnValue(ret) => {
Variable::MachineCallParam(ret) => {
write!(
f,
"ret({}, {}, {})",
"call_var({}, {}, {})",
ret.identity_id, ret.row_offset, ret.index
)
}
Expand All @@ -55,24 +52,18 @@ impl Variable {
id: cell.id,
ptype: PolynomialType::Committed,
}),
Variable::Param(_) | Variable::MachineCallReturnValue(_) => None,
Variable::Param(_) | Variable::MachineCallParam(_) => None,
}
}
}

#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Debug)]
pub struct MachineCallReturnVariable {
pub struct MachineCallVariable {
pub identity_id: u64,
pub row_offset: i32,
pub index: usize,
}

impl MachineCallReturnVariable {
pub fn into_argument<T: FieldElement>(self) -> MachineCallArgument<T, Variable> {
MachineCallArgument::Unknown(Variable::MachineCallReturnValue(self))
}
}

/// The identifier of a witness cell in the trace table.
/// The `row_offset` is relative to a certain "zero row" defined
/// by the component that uses this data structure.
Expand Down
Loading

0 comments on commit 40aa582

Please sign in to comment.