From a0df34f3fab7e373b15e02179639ddcfec42ff42 Mon Sep 17 00:00:00 2001 From: Lucas Clemente Vella Date: Mon, 2 Oct 2023 16:09:26 +0100 Subject: [PATCH] RISCV executor --- Cargo.toml | 1 + analysis/src/lib.rs | 27 +- analysis/src/vm/mod.rs | 4 - ast/src/asm_analysis/mod.rs | 6 +- ast/src/parsed/asm.rs | 1 + ast/src/parsed/display.rs | 3 + ast/src/parsed/mod.rs | 4 +- compiler/benches/executor_benchmark.rs | 5 +- compiler/src/lib.rs | 80 ++- compiler/src/verify.rs | 3 +- halo2/src/mock_prover.rs | 4 +- linker/src/lib.rs | 4 +- parser/src/powdr.lalrpop | 4 +- powdr_cli/Cargo.toml | 3 +- powdr_cli/src/main.rs | 102 ++- riscv/Cargo.toml | 3 +- riscv/src/compiler.rs | 16 +- riscv/src/disambiguator.rs | 13 +- riscv/tests/instructions.rs | 7 +- riscv/tests/riscv.rs | 8 +- riscv/tests/riscv_data/trivial.rs | 3 +- riscv_executor/Cargo.toml | 13 + riscv_executor/src/lib.rs | 841 +++++++++++++++++++++++++ 23 files changed, 1074 insertions(+), 81 deletions(-) create mode 100644 riscv_executor/Cargo.toml create mode 100644 riscv_executor/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index f7c5862b4c..71d0107654 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ members = [ "asm_utils", "airgen", "type_check", + "riscv_executor", ] [patch."https://github.com/privacy-scaling-explorations/halo2.git"] diff --git a/analysis/src/lib.rs b/analysis/src/lib.rs index 4494f3354b..9de1633311 100644 --- a/analysis/src/lib.rs +++ b/analysis/src/lib.rs @@ -10,9 +10,18 @@ pub use macro_expansion::MacroExpander; use ast::{asm_analysis::AnalysisASMFile, parsed::asm::ASMProgram, DiffMonitor}; use number::FieldElement; -pub fn analyze(file: ASMProgram) -> Result, Vec> { +pub fn convert_asm_to_pil( + file: ASMProgram, +) -> Result, Vec> { let mut monitor = DiffMonitor::default(); + let file = analyze(file, &mut monitor)?; + Ok(convert_analyzed_to_pil_constraints(file, &mut monitor)) +} +pub fn analyze( + file: ASMProgram, + monitor: &mut DiffMonitor, +) -> Result, Vec> { // expand macros log::debug!("Run expand analysis step"); let file = macro_expansion::expand(file); @@ -23,15 +32,27 @@ pub fn analyze(file: ASMProgram) -> Result( + file: AnalysisASMFile, + monitor: &mut DiffMonitor, +) -> AnalysisASMFile { + // remove all asm (except external instructions) + log::debug!("Run asm_to_pil"); + let file = asm_to_pil::compile(file); + monitor.push(&file); + // enforce blocks using `operation_id` and `latch` log::debug!("Run enforce_block analysis step"); let file = block_enforcer::enforce(file); monitor.push(&file); - Ok(file) + file } pub mod utils { diff --git a/analysis/src/vm/mod.rs b/analysis/src/vm/mod.rs index cd92c2fc9e..735997e2f3 100644 --- a/analysis/src/vm/mod.rs +++ b/analysis/src/vm/mod.rs @@ -19,10 +19,6 @@ pub fn analyze( log::debug!("Run batch analysis step"); let file = batcher::batch(file); monitor.push(&file); - // remove all asm (except external instructions) - log::debug!("Run asm_to_pil analysis step"); - let file = asm_to_pil::compile(file); - monitor.push(&file); Ok(file) } diff --git a/ast/src/asm_analysis/mod.rs b/ast/src/asm_analysis/mod.rs index 58a1fc159d..cf6048a020 100644 --- a/ast/src/asm_analysis/mod.rs +++ b/ast/src/asm_analysis/mod.rs @@ -87,7 +87,7 @@ pub struct FunctionStatements { } pub struct BatchRef<'a, T> { - statements: &'a [FunctionStatement], + pub statements: &'a [FunctionStatement], reason: &'a Option, } @@ -151,7 +151,7 @@ impl FunctionStatements { } /// iterate over the batches by reference - fn iter_batches(&self) -> impl Iterator> { + pub fn iter_batches(&self) -> impl Iterator> { match &self.batches { Some(batches) => Either::Left(batches.iter()), None => Either::Right( @@ -246,7 +246,7 @@ pub struct CallableSymbolDefinition { } #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Default)] -pub struct CallableSymbolDefinitions(BTreeMap>); +pub struct CallableSymbolDefinitions(pub BTreeMap>); impl IntoIterator for CallableSymbolDefinitions { type Item = CallableSymbolDefinition; diff --git a/ast/src/parsed/asm.rs b/ast/src/parsed/asm.rs index ab92e06b12..eb6e6eb610 100644 --- a/ast/src/parsed/asm.rs +++ b/ast/src/parsed/asm.rs @@ -325,6 +325,7 @@ pub enum FunctionStatement { pub enum DebugDirective { File(usize, String, String), Loc(usize, usize, usize), + OriginalInstruction(String), } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index d52ef657c7..87d864da0f 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -238,6 +238,9 @@ impl Display for DebugDirective { DebugDirective::Loc(file, line, col) => { write!(f, "debug loc {file} {line} {col};") } + DebugDirective::OriginalInstruction(insn) => { + write!(f, "debug insn \"{insn}\";") + } } } } diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 043a45d217..872fc3b6b9 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -161,9 +161,9 @@ pub struct PolynomialName { /// A polynomial with an optional namespace pub struct NamespacedPolynomialReference { /// The optional namespace, if `None` then this polynomial inherits the next enclosing namespace, if any - namespace: Option, + pub namespace: Option, /// The underlying polynomial - pol: IndexedPolynomialReference, + pub pol: IndexedPolynomialReference, } impl NamespacedPolynomialReference { diff --git a/compiler/benches/executor_benchmark.rs b/compiler/benches/executor_benchmark.rs index 15943c2888..10e6d5c464 100644 --- a/compiler/benches/executor_benchmark.rs +++ b/compiler/benches/executor_benchmark.rs @@ -1,5 +1,4 @@ use ::compiler::inputs_to_query_callback; -use analysis::analyze; use ast::analyzed::Analyzed; use criterion::{criterion_group, criterion_main, Criterion}; @@ -19,7 +18,7 @@ fn get_pil() -> Analyzed { let contents = compiler::compile(riscv_asm_files, &CoProcessors::base()); let parsed = parser::parse_asm::(None, &contents).unwrap(); let resolved = importer::resolve(None, parsed).unwrap(); - let analyzed = analyze(resolved).unwrap(); + let analyzed = analysis::convert_asm_to_pil(resolved).unwrap(); let graph = airgen::compile(analyzed); let pil = linker::link(graph).unwrap(); let analyzed = pil_analyzer::analyze_string(&format!("{pil}")); @@ -27,7 +26,7 @@ fn get_pil() -> Analyzed { } fn run_witgen(analyzed: &Analyzed, input: Vec) { - let query_callback = inputs_to_query_callback(input); + let query_callback = inputs_to_query_callback(&input); let (constants, degree) = constant_evaluator::generate(analyzed); executor::witgen::WitnessGenerator::new(analyzed, degree, &constants, query_callback) .generate(); diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index 47b2294f23..49a3cb938d 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -10,12 +10,15 @@ use std::path::Path; use std::path::PathBuf; use std::time::Instant; +use analysis::analyze; +use analysis::convert_analyzed_to_pil_constraints; use ast::analyzed::Analyzed; +use ast::DiffMonitor; pub mod util; mod verify; -use analysis::analyze; +use ast::asm_analysis::AnalysisASMFile; pub use backend::{BackendType, Proof}; use executor::witgen::QueryCallback; use number::write_polys_file; @@ -54,7 +57,7 @@ pub fn compile_pil_or_asm( Ok(Some(compile_pil( Path::new(file_name), output_dir, - inputs_to_query_callback(inputs), + inputs_to_query_callback(&inputs), prove_with, external_witness_values, ))) @@ -123,7 +126,8 @@ pub fn compile_asm( Ok(compile_asm_string( file_name, &contents, - inputs, + &inputs, + None, output_dir, force_overwrite, prove_with, @@ -132,19 +136,11 @@ pub fn compile_asm( .1) } -/// Compiles the contents of a .asm file, outputs the PIL on stdout and tries to generate -/// fixed and witness columns. -/// -/// Returns the relative pil file name and the compilation result if any compilation was done. -pub fn compile_asm_string( +pub fn compile_asm_string_to_analyzed_ast( file_name: &str, contents: &str, - inputs: Vec, - output_dir: &Path, - force_overwrite: bool, - prove_with: Option, - external_witness_values: Vec<(&str, Vec)>, -) -> Result<(PathBuf, Option>), Vec> { + monitor: &mut DiffMonitor, +) -> Result, Vec> { let parsed = parser::parse_asm(Some(file_name), contents).unwrap_or_else(|err| { eprintln!("Error parsing .asm file:"); err.output_to_stderr(); @@ -154,11 +150,26 @@ pub fn compile_asm_string( let resolved = importer::resolve(Some(PathBuf::from(file_name)), parsed).map_err(|e| vec![e])?; log::debug!("Run analysis"); - let analysed = analyze(resolved).unwrap(); + let analyzed = analyze(resolved, monitor)?; log::debug!("Analysis done"); - log::trace!("{analysed}"); + log::trace!("{analyzed}"); + + Ok(analyzed) +} + +pub fn convert_analyzed_to_pil( + file_name: &str, + monitor: &mut DiffMonitor, + analyzed: AnalysisASMFile, + inputs: &[T], + output_dir: &Path, + force_overwrite: bool, + prove_with: Option, + external_witness_values: Vec<(&str, Vec)>, +) -> Result<(PathBuf, Option>), Vec> { + let constraints = convert_analyzed_to_pil_constraints(analyzed, monitor); log::debug!("Run airgen"); - let graph = airgen::compile(analysed); + let graph = airgen::compile(constraints); log::debug!("Airgen done"); log::trace!("{graph}"); log::debug!("Run linker"); @@ -196,6 +207,37 @@ pub fn compile_asm_string( )) } +/// Compiles the contents of a .asm file, outputs the PIL on stdout and tries to generate +/// fixed and witness columns. +/// +/// Returns the relative pil file name and the compilation result if any compilation was done. +pub fn compile_asm_string( + file_name: &str, + contents: &str, + inputs: &[T], + analyzed_hook: Option<&mut dyn FnMut(&AnalysisASMFile)>, + output_dir: &Path, + force_overwrite: bool, + prove_with: Option, + external_witness_values: Vec<(&str, Vec)>, +) -> Result<(PathBuf, Option>), Vec> { + let mut monitor = DiffMonitor::default(); + let analyzed = compile_asm_string_to_analyzed_ast(file_name, contents, &mut monitor)?; + if let Some(hook) = analyzed_hook { + hook(&analyzed); + }; + convert_analyzed_to_pil( + file_name, + &mut monitor, + analyzed, + inputs, + output_dir, + force_overwrite, + prove_with, + external_witness_values, + ) +} + pub struct CompilationResult { /// Constant columns, potentially incomplete (if success is false) pub constants: Vec<(String, Vec)>, @@ -336,7 +378,9 @@ fn write_commits_to_fs( } #[allow(clippy::print_stdout)] -pub fn inputs_to_query_callback(inputs: Vec) -> impl Fn(&str) -> Option { +pub fn inputs_to_query_callback<'a, T: FieldElement>( + inputs: &'a [T], +) -> impl Fn(&str) -> Option + 'a { move |query: &str| -> Option { let items = query.split(',').map(|s| s.trim()).collect::>(); match items[0] { diff --git a/compiler/src/verify.rs b/compiler/src/verify.rs index a3272e1c49..43ebba9b18 100644 --- a/compiler/src/verify.rs +++ b/compiler/src/verify.rs @@ -9,7 +9,8 @@ pub fn verify_asm_string(file_name: &str, contents: &str, input compile_asm_string( file_name, contents, - inputs, + &inputs, + None, &temp_dir, true, Some(BackendType::PilStarkCli), diff --git a/halo2/src/mock_prover.rs b/halo2/src/mock_prover.rs index 24fefaa5d4..95ca8c872c 100644 --- a/halo2/src/mock_prover.rs +++ b/halo2/src/mock_prover.rs @@ -35,7 +35,7 @@ pub fn mock_prove( mod test { use std::{fs, path::PathBuf}; - use analysis::analyze; + use analysis::convert_asm_to_pil; use number::Bn254Field; use parser::parse_asm; use test_log::test; @@ -53,7 +53,7 @@ mod test { let contents = fs::read_to_string(&location).unwrap(); let parsed = parse_asm::(Some(&location), &contents).unwrap(); let resolved = importer::resolve(Some(PathBuf::from(location)), parsed).unwrap(); - let analysed = analyze(resolved).unwrap(); + let analysed = convert_asm_to_pil(resolved).unwrap(); let graph = airgen::compile(analysed); let pil = linker::link(graph).unwrap(); diff --git a/linker/src/lib.rs b/linker/src/lib.rs index 3964995c4e..ae700c64be 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -151,7 +151,7 @@ mod test { }; use number::{Bn254Field, FieldElement, GoldilocksField}; - use analysis::analyze; + use analysis::convert_asm_to_pil; use parser::parse_asm; use pretty_assertions::assert_eq; @@ -161,7 +161,7 @@ mod test { fn parse_analyse_and_compile(input: &str) -> PILGraph { let parsed = parse_asm(None, input).unwrap(); let resolved = importer::resolve(None, parsed).unwrap(); - airgen::compile(analyze(resolved).unwrap()) + airgen::compile(convert_asm_to_pil(resolved).unwrap()) } #[test] diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 6599e22c7e..aae8670ab4 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -333,6 +333,8 @@ DebugDirectiveStatement: FunctionStatement = { => FunctionStatement::DebugDirective(l, DebugDirective::File(n.try_into().unwrap(), d, f)), "debug" "loc" ";" => FunctionStatement::DebugDirective(l, DebugDirective::Loc(f.try_into().unwrap(), line.try_into().unwrap(), col.try_into().unwrap())), + "debug" "insn" ";" + => FunctionStatement::DebugDirective(l, DebugDirective::OriginalInstruction(insn)), } LabelStatement: FunctionStatement = { @@ -549,4 +551,4 @@ FieldElement: T = { Integer: AbstractNumberType = { r"[0-9][0-9_]*" => AbstractNumberType::from_str(&<>.replace('_', "")).unwrap(), r"0x[0-9A-Fa-f][0-9A-Fa-f_]*" => AbstractNumberType::from_str_radix(&<>[2..].replace('_', ""), 16).unwrap(), -} \ No newline at end of file +} diff --git a/powdr_cli/Cargo.toml b/powdr_cli/Cargo.toml index 0f9c1d195d..e1aaeebbff 100644 --- a/powdr_cli/Cargo.toml +++ b/powdr_cli/Cargo.toml @@ -4,7 +4,7 @@ version = "0.1.0" edition = "2021" [features] -default = [] # halo2 is disabled by default +default = [] # halo2 is disabled by default halo2 = ["dep:halo2", "backend/halo2", "compiler/halo2"] [dependencies] @@ -14,6 +14,7 @@ log = "0.4.17" compiler = { path = "../compiler" } parser = { path = "../parser" } riscv = { path = "../riscv" } +riscv_executor = { path = "../riscv_executor" } number = { path = "../number" } halo2 = { path = "../halo2", optional = true } backend = { path = "../backend" } diff --git a/powdr_cli/src/main.rs b/powdr_cli/src/main.rs index b75bad3e5d..3ce3ba95db 100644 --- a/powdr_cli/src/main.rs +++ b/powdr_cli/src/main.rs @@ -94,6 +94,11 @@ enum Commands { #[arg(default_value_t = CsvRenderModeCLI::Hex)] #[arg(value_parser = clap_enum_variants!(CsvRenderModeCLI))] csv_mode: CsvRenderModeCLI, + + /// Just execute in the RISCV/Powdr executor + #[arg(short, long)] + #[arg(default_value_t = false)] + just_execute: bool, }, /// Compiles (no-std) rust code to riscv assembly, then to powdr assembly /// and finally to PIL and generates fixed and witness columns. @@ -131,6 +136,11 @@ enum Commands { /// Comma-separated list of coprocessors. #[arg(long)] coprocessors: Option, + + /// Just execute in the RISCV/Powdr executor + #[arg(short, long)] + #[arg(default_value_t = false)] + just_execute: bool, }, /// Compiles riscv assembly to powdr assembly and then to PIL @@ -169,6 +179,11 @@ enum Commands { /// Comma-separated list of coprocessors. #[arg(long)] coprocessors: Option, + + /// Just execute in the RISCV/Powdr executor + #[arg(short, long)] + #[arg(default_value_t = false)] + just_execute: bool, }, Prove { @@ -299,6 +314,7 @@ fn run_command(command: Commands) { force, prove_with, coprocessors, + just_execute, } => { let coprocessors = match coprocessors { Some(list) => { @@ -312,7 +328,8 @@ fn run_command(command: Commands) { Path::new(&output_directory), force, prove_with, - coprocessors + coprocessors, + just_execute )) { eprintln!("Errors:"); for e in errors { @@ -328,6 +345,7 @@ fn run_command(command: Commands) { force, prove_with, coprocessors, + just_execute, } => { assert!(!files.is_empty()); let name = if files.len() == 1 { @@ -349,7 +367,8 @@ fn run_command(command: Commands) { Path::new(&output_directory), force, prove_with, - coprocessors + coprocessors, + just_execute )) { eprintln!("Errors:"); for e in errors { @@ -377,25 +396,33 @@ fn run_command(command: Commands) { prove_with, export_csv, csv_mode, + just_execute, } => { - match call_with_field!(compile_with_csv_export::( - file, - output_directory, - witness_values, - inputs, - force, - prove_with, - export_csv, - csv_mode - )) { - Ok(()) => {} - Err(errors) => { - eprintln!("Errors:"); - for e in errors { - eprintln!("{e}"); + if just_execute { + // assume input is riscv asm and just execute it + let contents = fs::read_to_string(file).unwrap(); + let inputs = split_inputs(&inputs); + riscv_executor::execute::(&contents, &inputs); + } else { + match call_with_field!(compile_with_csv_export::( + file, + output_directory, + witness_values, + inputs, + force, + prove_with, + export_csv, + csv_mode + )) { + Ok(()) => {} + Err(errors) => { + eprintln!("Errors:"); + for e in errors { + eprintln!("{e}"); + } } - } - }; + }; + } } Commands::Prove { file, @@ -442,19 +469,20 @@ fn run_rust( force_overwrite: bool, prove_with: Option, coprocessors: riscv::CoProcessors, + just_execute: bool, ) -> Result<(), Vec> { let (asm_file_path, asm_contents) = compile_rust(file_name, output_dir, force_overwrite, &coprocessors) .ok_or_else(|| vec!["could not compile rust".to_string()])?; - compile_asm_string( + handle_riscv_asm( asm_file_path.to_str().unwrap(), &asm_contents, inputs, output_dir, force_overwrite, prove_with, - vec![], + just_execute, )?; Ok(()) } @@ -467,6 +495,7 @@ fn run_riscv_asm( force_overwrite: bool, prove_with: Option, coprocessors: riscv::CoProcessors, + just_execute: bool, ) -> Result<(), Vec> { let (asm_file_path, asm_contents) = compile_riscv_asm( original_file_name, @@ -477,18 +506,44 @@ fn run_riscv_asm( ) .ok_or_else(|| vec!["could not compile RISC-V assembly".to_string()])?; - compile_asm_string( + handle_riscv_asm( asm_file_path.to_str().unwrap(), &asm_contents, inputs, output_dir, force_overwrite, prove_with, - vec![], + just_execute, )?; Ok(()) } +fn handle_riscv_asm( + file_name: &str, + contents: &str, + inputs: Vec, + output_dir: &Path, + force_overwrite: bool, + prove_with: Option, + just_execute: bool, +) -> Result<(), Vec> { + if just_execute { + riscv_executor::execute::(contents, &inputs); + } else { + compile_asm_string( + file_name, + contents, + &inputs, + None, + output_dir, + force_overwrite, + prove_with, + vec![] + )?; + } + Ok(()) +} + #[allow(clippy::too_many_arguments)] fn compile_with_csv_export( file: String, @@ -629,6 +684,7 @@ mod test { prove_with: Some(BackendType::PilStarkCli), export_csv: true, csv_mode: CsvRenderModeCLI::Hex, + just_execute: false, }; run_command(pil_command); diff --git a/riscv/Cargo.toml b/riscv/Cargo.toml index 50f6075164..259766bb4a 100644 --- a/riscv/Cargo.toml +++ b/riscv/Cargo.toml @@ -24,6 +24,7 @@ lalrpop = "^0.19" [dev-dependencies] test-log = "0.2.12" env_logger = "0.10.0" +hex = "0.4.3" number = { path = "../number" } compiler = { path = "../compiler" } -hex = "0.4.3" +riscv_executor = { path = "../riscv_executor" } diff --git a/riscv/src/compiler.rs b/riscv/src/compiler.rs index ed290197a8..a8435f4416 100644 --- a/riscv/src/compiler.rs +++ b/riscv/src/compiler.rs @@ -759,10 +759,18 @@ fn process_statement(s: Statement, coprocessors: &CoProcessors) -> Vec { args.iter().format(", ") ), }, - Statement::Instruction(instr, args) => process_instruction(instr, args, coprocessors) - .into_iter() - .map(|s| " ".to_string() + &s) - .collect(), + Statement::Instruction(instr, args) => { + let stmt_str = format!("{s}"); + // remove indentation and trailing newline + let stmt_str = &stmt_str[2..(stmt_str.len() - 1)]; + let mut ret = vec![format!(" debug insn \"{stmt_str}\";")]; + ret.extend( + process_instruction(instr, args, coprocessors) + .into_iter() + .map(|s| " ".to_string() + &s), + ); + ret + } } } diff --git a/riscv/src/disambiguator.rs b/riscv/src/disambiguator.rs index 8359e8fd74..06c9ff9d73 100644 --- a/riscv/src/disambiguator.rs +++ b/riscv/src/disambiguator.rs @@ -88,11 +88,14 @@ fn disambiguate_file_ids( .iter() .flat_map(|(name, statements)| extract_file_ids(name, statements)) .collect::>(); - let debug_file_id_mapping = debug_file_ids - .iter() - .enumerate() - .map(|(i, (asm_name, file_id, ..))| ((asm_name.to_string(), *file_id), i as i64 + 1)) - .collect::>(); + // ensure the ids are densely packed: + let debug_file_id_mapping = { + let mut map = HashMap::new(); + for (asm_name, file_id, ..) in debug_file_ids.iter() { + map.insert((asm_name.to_string(), *file_id), map.len() as i64 + 1); + } + map + }; let new_debug_file_ids = debug_file_ids .into_iter() .map(|(asm_file, id, dir, file)| { diff --git a/riscv/tests/instructions.rs b/riscv/tests/instructions.rs index 9248d6c174..5445b2d441 100644 --- a/riscv/tests/instructions.rs +++ b/riscv/tests/instructions.rs @@ -1,6 +1,7 @@ +mod common; + mod instruction_tests { - use compiler::verify_asm_string; - use number::GoldilocksField; + use crate::common::verify_riscv_asm_string; use riscv::compiler::compile; use riscv::CoProcessors; use test_log::test; @@ -12,7 +13,7 @@ mod instruction_tests { &CoProcessors::base(), ); - verify_asm_string::(&format!("{name}.asm"), &powdr_asm, vec![]); + verify_riscv_asm_string(&format!("{name}.asm"), &powdr_asm, vec![]); } include!(concat!(env!("OUT_DIR"), "/instruction_tests.rs")); diff --git a/riscv/tests/riscv.rs b/riscv/tests/riscv.rs index 9b3f24e040..aae5129a56 100644 --- a/riscv/tests/riscv.rs +++ b/riscv/tests/riscv.rs @@ -1,4 +1,6 @@ -use compiler::verify_asm_string; +mod common; + +use common::verify_riscv_asm_string; use mktemp::Temp; use number::GoldilocksField; use test_log::test; @@ -143,7 +145,7 @@ fn verify_file(case: &str, inputs: Vec, coprocessors: &CoProces riscv::compile_rust_to_riscv_asm(&format!("tests/riscv_data/{case}"), &temp_dir); let powdr_asm = riscv::compiler::compile(riscv_asm, coprocessors); - verify_asm_string(&format!("{case}.asm"), &powdr_asm, inputs); + verify_riscv_asm_string(&format!("{case}.asm"), &powdr_asm, inputs); } fn verify_crate(case: &str, inputs: Vec, coprocessors: &CoProcessors) { @@ -154,5 +156,5 @@ fn verify_crate(case: &str, inputs: Vec, coprocessors: &CoProce ); let powdr_asm = riscv::compiler::compile(riscv_asm, coprocessors); - verify_asm_string(&format!("{case}.asm"), &powdr_asm, inputs); + verify_riscv_asm_string(&format!("{case}.asm"), &powdr_asm, inputs); } diff --git a/riscv/tests/riscv_data/trivial.rs b/riscv/tests/riscv_data/trivial.rs index 3aadb5c8f8..b8eeca29d7 100644 --- a/riscv/tests/riscv_data/trivial.rs +++ b/riscv/tests/riscv_data/trivial.rs @@ -1,5 +1,4 @@ #![no_std] #[no_mangle] -pub fn main() { -} +pub fn main() {} diff --git a/riscv_executor/Cargo.toml b/riscv_executor/Cargo.toml new file mode 100644 index 0000000000..5b2d05b62e --- /dev/null +++ b/riscv_executor/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "riscv_executor" +version = "0.1.0" +edition = "2021" + +[dependencies] +log = "0.4.17" +itertools = "0.11" +ast = { path = "../ast" } +number = { path = "../number" } +parser = { path = "../parser" } +importer = { path = "../importer" } +analysis = { path = "../analysis" } diff --git a/riscv_executor/src/lib.rs b/riscv_executor/src/lib.rs new file mode 100644 index 0000000000..57feba76dd --- /dev/null +++ b/riscv_executor/src/lib.rs @@ -0,0 +1,841 @@ +//! A specialized executor for our RISC-V assembly that can speedup witgen. +//! +//! WARNING: the general witness generation/execution code over the polynomial +//! constraints try to ensure the determinism of the instructions. If we bypass +//! much of witness generation using the present module, we lose the +//! non-determinism verification. +//! +//! TODO: perform determinism verification for each instruction independently +//! from execution. + +use std::{ + collections::HashMap, + io::{self, Write}, +}; + +use ast::{ + asm_analysis::{AnalysisASMFile, CallableSymbol, FunctionStatement, LabelStatement, Machine}, + parsed::{ + asm::{AssignmentRegister, DebugDirective}, + Expression, + }, +}; +use builder::{MemoryBuilder, TraceBuilder}; +use number::{BigInt, FieldElement}; + +#[derive(Clone, Copy, PartialEq, Eq)] +struct Elem(i64); + +impl Elem { + const fn zero() -> Self { + Self(0) + } + + fn u(&self) -> u32 { + self.0.try_into().unwrap() + } + + fn s(&self) -> i32 { + self.0.try_into().unwrap() + } +} + +impl From for Elem { + fn from(value: u32) -> Self { + Self(value as i64) + } +} + +impl From for Elem { + fn from(value: i64) -> Self { + Self(value) + } +} + +impl From for Elem { + fn from(value: i32) -> Self { + Self(value as i64) + } +} + +pub struct ExecutionTrace<'a> { + reg_map: HashMap<&'a str, usize>, + + /// Values of the registers in the execution trace. + /// + /// Each N elements is a row with all registers. + values: Vec, +} + +mod builder { + use std::collections::HashMap; + + use ast::asm_analysis::Machine; + use number::FieldElement; + + use crate::{Elem, ExecutionTrace}; + + fn register_names(main: &Machine) -> Vec<&str> { + main.registers.iter().map(|stmnt| &stmnt.name[..]).collect() + } + + pub struct TraceBuilder<'a, 'b> { + trace: ExecutionTrace<'a>, + + /// First register of current row. + /// Next row is reg_map.len() elems ahead. + curr_idx: usize, + + // index of special case registers to look after: + x0_idx: usize, + pc_idx: usize, + + /// The PC in the register bank refers to the batches, we have to track our + /// actual program counter independently. + next_statement_line: u32, + + /// When PC is written, we need to know what line to actually execute next + /// from this map of batch to statement line. + batch_to_line_map: &'b [u32], + } + + impl<'a, 'b> TraceBuilder<'a, 'b> { + // creates a new builder + pub fn new(main: &'a Machine, batch_to_line_map: &'b [u32]) -> Self { + let reg_map = register_names(main) + .into_iter() + .enumerate() + .map(|(i, name)| (name, i)) + .collect::>(); + + // first row has all values zeroed + let values = vec![Elem::zero(); 2 * reg_map.len()]; + + let mut ret = Self { + curr_idx: 0, + x0_idx: reg_map["x0"], + pc_idx: reg_map["pc"], + trace: ExecutionTrace { reg_map, values }, + next_statement_line: 1, + batch_to_line_map, + }; + + ret.set_next_pc(); + + ret + } + + /// get current value of register + pub(crate) fn g(&self, idx: &str) -> Elem { + self.g_idx(self.trace.reg_map[idx]) + } + + /// get current value of register by register index instead of name + fn g_idx(&self, idx: usize) -> Elem { + self.trace.values[self.curr_idx + idx] + } + + fn g_idx_next(&self, idx: usize) -> Elem { + self.trace.values[self.curr_idx + self.reg_len() + idx] + } + + /// set next value of register, accounting to x0 or pc writes + pub(crate) fn s(&mut self, idx: &str, value: impl Into) { + self.s_impl(idx, value.into()) + } + + fn s_impl(&mut self, idx: &str, value: Elem) { + let idx = self.trace.reg_map[idx]; + if idx == self.x0_idx { + return; + } else if idx == self.pc_idx { + // PC has been written, so we must update our statement-based + // program counter accordingly: + self.next_statement_line = self.batch_to_line_map[value.u() as usize]; + } + self.s_idx(idx, value); + } + + /// raw set next value of register by register index instead of name + fn s_idx(&mut self, idx: usize, value: Elem) { + let final_idx = self.curr_idx + self.reg_len() + idx; + self.trace.values[final_idx] = value; + } + + /// advance to next row, returns the index to the statement that must be + /// executed now + pub fn advance(&mut self, was_nop: bool) -> u32 { + if self.g_idx(self.pc_idx) != self.g_idx_next(self.pc_idx) { + // PC changed, create a new line + self.curr_idx += self.reg_len(); + self.trace.values.extend_from_within(self.curr_idx..); + } else { + // PC didn't change, execution was inside same batch, + // so there is no need to create a new row, just update curr + if !was_nop { + let next_idx = self.curr_idx + self.reg_len(); + self.trace.values.copy_within(next_idx.., self.curr_idx); + } + } + + // advance the next statement + let curr_line = self.next_statement_line; + self.next_statement_line += 1; + + // optimistically write next PC, but the code might rewrite it + self.set_next_pc(); + + curr_line + } + + pub fn finish(self) -> ExecutionTrace<'a> { + self.trace + } + + fn reg_len(&self) -> usize { + self.trace.reg_map.len() + } + + fn set_next_pc(&mut self) { + let curr_pc = self.g_idx(self.pc_idx).u(); + + let line_of_next_batch = self.batch_to_line_map[curr_pc as usize + 1]; + + self.s_idx( + self.pc_idx, + if self.next_statement_line >= line_of_next_batch { + assert_eq!(self.next_statement_line, line_of_next_batch); + curr_pc + 1 + } else { + curr_pc + } + .into(), + ); + } + } + + pub struct MemoryBuilder( + // TODO: track modifications to help build the memory machine + HashMap, + ); + + impl MemoryBuilder { + pub fn new() -> Self { + Self(HashMap::new()) + } + + pub(crate) fn s(&mut self, addr: u32, val: Elem) { + if val.u() != 0 { + self.0.insert(addr, val); + } else { + self.0.remove(&addr); + } + } + + pub(crate) fn g(&mut self, addr: u32) -> Elem { + *self.0.get(&addr).unwrap_or(&Elem::zero()) + } + } +} + +fn get_main_machine(program: &AnalysisASMFile) -> &Machine { + for (name, m) in program.machines.iter() { + if name.parts.len() == 1 && name.parts[0] == "Main" { + return m; + } + } + panic!(); +} + +/// Returns the list of instructions, directly indexable by PC, the map from +/// labels to indices into that list, and the list with the start of each batch. +fn preprocess_main_function( + machine: &Machine, +) -> ( + Vec<&FunctionStatement>, + HashMap<&str, Elem>, + Vec, + Vec<(&str, &str)>, +) { + let CallableSymbol::Function(main_function) = &machine.callable.0["main"] else { + panic!("main function missing") + }; + + let orig_statements = &main_function.body.statements; + + let mut statements = Vec::new(); + let mut label_map = HashMap::new(); + let mut batch_to_line_map = Vec::new(); + let mut debug_files = Vec::new(); + + for (batch_idx, batch) in orig_statements.iter_batches().enumerate() { + batch_to_line_map.push(statements.len() as u32); + let mut statement_seen = false; + for s in batch.statements { + match s { + FunctionStatement::Assignment(_) + | FunctionStatement::Instruction(_) + | FunctionStatement::Return(_) => { + statement_seen = true; + statements.push(s) + } + FunctionStatement::DebugDirective(d) => { + match &d.directive { + DebugDirective::File(idx, dir, file) => { + // debug files should be densely packed starting + // from 1, so the idx should match vec size + 1: + assert_eq!(*idx, debug_files.len() + 1); + debug_files.push((dir.as_str(), file.as_str())); + } + DebugDirective::Loc(_, _, _) | DebugDirective::OriginalInstruction(_) => { + // keep debug locs for debugging purposes + statements.push(s); + } + } + } + FunctionStatement::Label(LabelStatement { start: _, name }) => { + // assert there are no statements in the middle of a block + assert!(!statement_seen); + label_map.insert(name.as_str(), (batch_idx as i64).into()); + } + } + } + } + assert!(statements.len() <= u32::MAX as usize); + + // add a final element to the map so the queries don't overflow: + batch_to_line_map.push(statements.len() as u32); + + (statements, label_map, batch_to_line_map, debug_files) +} + +struct Executor<'a, 'b, F: FieldElement> { + proc: TraceBuilder<'a, 'b>, + mem: MemoryBuilder, + label_map: HashMap<&'a str, Elem>, + inputs: &'b [F], + stdout: io::Stdout, +} + +impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { + fn exec_instruction(&mut self, name: &str, args: &[Expression]) -> Vec { + let args = args + .iter() + .map(|expr| self.eval_expression(expr)[0]) + .collect::>(); + + match name { + "mstore" => { + // input + self.proc.s("Y", args[0]); + self.proc.s("Z", args[1]); + + // execution + let addr = args[0].0 as u32; + assert_eq!(addr % 4, 0); + self.mem.s(args[0].0 as u32, args[1]); + + // no output + Vec::new() + } + "mload" => { + // input + self.proc.s("Y", args[0]); + + // execution + let addr = args[0].0 as u32; + let val = self.mem.g(addr & 0xfffffffc); + let rem = addr % 4; + + // output + self.proc.s("X", val); + self.proc.s("Z", rem); + vec![val, rem.into()] + } + "jump" => { + // no register input + + // execution + self.proc.s("pc", args[0]); + + // no output + Vec::new() + } + "load_label" => { + // no register input + + // no execution + + // output + self.proc.s("X", args[0]); + args + } + "jump_dyn" => { + // input + self.proc.s("X", args[0]); + + // execution + self.proc.s("pc", args[0]); + + // no output + Vec::new() + } + "jump_and_link_dyn" => { + // input + self.proc.s("X", args[0]); + + // execution + let pc = self.proc.g("pc"); + self.proc.s("x1", pc.u() + 1); + self.proc.s("pc", args[0]); + + // no output + Vec::new() + } + "call" => { + // no register input + + // execution + let pc = self.proc.g("pc"); + self.proc.s("x1", pc.u() + 1); + self.proc.s("pc", args[0]); + + // no output + Vec::new() + } + "tail" => { + // no register input + + // execution + self.proc.s("pc", args[0]); + self.proc.s("x6", args[0]); + + // no output + Vec::new() + } + "ret" => { + // no input + + // execution + let target = self.proc.g("x1"); + self.proc.s("pc", target); + + // no output + Vec::new() + } + "branch_if_nonzero" => { + // input + self.proc.s("X", args[0]); + + // execution + if args[0].0 != 0 { + self.proc.s("pc", args[1]); + } + + // no output + Vec::new() + } + "branch_if_zero" => { + // input + self.proc.s("X", args[0]); + + // execution + if args[0].0 == 0 { + self.proc.s("pc", args[1]); + } + + // no output + Vec::new() + } + "skip_if_zero" => { + // input + self.proc.s("X", args[0]); + self.proc.s("Y", args[1]); + + // execution + if args[0].0 == 0 { + let pc = self.proc.g("pc").s(); + self.proc.s("pc", pc + args[1].s() + 1); + } + + // no output + Vec::new() + } + "branch_if_positive" => { + // input + self.proc.s("X", args[0]); + + // execution + if args[0].0 > 0 { + self.proc.s("pc", args[1]); + } + + // no output + Vec::new() + } + "is_positive" => { + // input + self.proc.s("X", args[0]); + + // execution + let r = if args[0].0 > 0 { 1 } else { 0 }; + + // output + self.proc.s("Y", r); + vec![r.into()] + } + "is_equal_zero" => { + // input + self.proc.s("X", args[0]); + + // execution + let r = if args[0].0 == 0 { 1 } else { 0 }; + + // output + self.proc.s("Y", r); + vec![r.into()] + } + "is_not_equal_zero" => { + // input + self.proc.s("X", args[0]); + + // execution + let r = if args[0].0 != 0 { 1 } else { 0 }; + + // output + self.proc.s("Y", r); + vec![r.into()] + } + "wrap" | "wrap16" => { + // input + self.proc.s("Y", args[0]); + + // execution + let r = args[0].0 as u32; + + // output + self.proc.s("X", r); + vec![r.into()] + } + "wrap_signed" => { + // input + self.proc.s("Y", args[0]); + + // execution + let r = (args[0].0 + 0x100000000) as u32; + + // output + self.proc.s("X", r); + vec![r.into()] + } + "sign_extend_byte" => { + // input + self.proc.s("Y", args[0]); + + // execution + let r = args[0].u() as i8 as u32; + + // output + self.proc.s("X", r); + vec![r.into()] + } + "sign_extend_16_bits" => { + // input + self.proc.s("Y", args[0]); + + // execution + let r = args[0].u() as i16 as u32; + + // output + self.proc.s("X", r); + vec![r.into()] + } + "to_signed" => { + // input + self.proc.s("Y", args[0]); + + // execution + let r = args[0].u() as i32; + + // output + self.proc.s("X", r); + vec![r.into()] + } + "fail" => { + // TODO: handle it better + panic!("reached a fail instruction") + } + "divremu" => { + // input + self.proc.s("Y", args[0]); + self.proc.s("X", args[1]); + + // execution + let y = args[0].u(); + let x = args[1].u(); + let div; + let rem; + if x != 0 { + div = y / x; + rem = y % x; + } else { + div = 0xffffffff; + rem = y; + } + + // output + self.proc.s("Z", div); + self.proc.s("W", rem); + vec![div.into(), rem.into()] + } + "mul" => { + // input + self.proc.s("Z", args[0]); + self.proc.s("W", args[1]); + + // execution + let r = args[0].u() as u64 * args[1].u() as u64; + let lo = r as u32; + let hi = (r >> 32) as u32; + + // output + self.proc.s("X", lo); + self.proc.s("Y", hi); + vec![lo.into(), hi.into()] + } + bin_op => { + // input + self.proc.s("Y", args[0]); + self.proc.s("Z", args[1]); + + let val = match bin_op { + "poseidon" => todo!(), + "and" => (args[0].u() & args[1].u()).into(), + "or" => (args[0].u() | args[1].u()).into(), + "xor" => (args[0].u() ^ args[1].u()).into(), + "shl" => (args[0].u() << args[1].u()).into(), + "shr" => (args[0].u() >> args[1].u()).into(), + _ => { + unreachable!() + } + }; + + // output + self.proc.s("X", val); + vec![val] + } + } + } + + fn eval_expression(&mut self, expression: &Expression) -> Vec { + match expression { + Expression::Reference(r) => { + // an identifier looks like this: + assert!(r.namespace.is_none()); + assert!(r.pol.index().is_none()); + + let name = r.pol.name(); + + // labels share the identifier space with registers: + // try one, then the other + let val = self + .label_map + .get(name) + .cloned() + .unwrap_or_else(|| self.proc.g(name)); + vec![val] + } + Expression::PublicReference(_) => todo!(), + Expression::Number(n) => { + vec![if let Some(unsigned) = to_u32(n) { + unsigned.into() + } else { + panic!("Value does not fit in 32 bits.") + }] + } + Expression::String(_) => todo!(), + Expression::Tuple(_) => todo!(), + Expression::LambdaExpression(_) => todo!(), + Expression::ArrayLiteral(_) => todo!(), + Expression::BinaryOperation(l, op, r) => { + let l = self.eval_expression(l)[0]; + let r = self.eval_expression(r)[0]; + + let result = match op { + ast::parsed::BinaryOperator::Add => l.0 + r.0, + ast::parsed::BinaryOperator::Sub => l.0 - r.0, + ast::parsed::BinaryOperator::Mul => l.0 * r.0, + ast::parsed::BinaryOperator::Div => l.0 / r.0, + ast::parsed::BinaryOperator::Mod => l.0 % r.0, + ast::parsed::BinaryOperator::Pow => l.0.pow(r.u()), + ast::parsed::BinaryOperator::BinaryAnd => todo!(), + ast::parsed::BinaryOperator::BinaryXor => todo!(), + ast::parsed::BinaryOperator::BinaryOr => todo!(), + ast::parsed::BinaryOperator::ShiftLeft => todo!(), + ast::parsed::BinaryOperator::ShiftRight => todo!(), + ast::parsed::BinaryOperator::LogicalOr => todo!(), + ast::parsed::BinaryOperator::LogicalAnd => todo!(), + ast::parsed::BinaryOperator::Less => todo!(), + ast::parsed::BinaryOperator::LessEqual => todo!(), + ast::parsed::BinaryOperator::Equal => todo!(), + ast::parsed::BinaryOperator::NotEqual => todo!(), + ast::parsed::BinaryOperator::GreaterEqual => todo!(), + ast::parsed::BinaryOperator::Greater => todo!(), + }; + + vec![result.into()] + } + Expression::UnaryOperation(op, arg) => { + let arg = self.eval_expression(arg)[0]; + let result = match op { + ast::parsed::UnaryOperator::Plus => arg.0, + ast::parsed::UnaryOperator::Minus => -arg.0, + ast::parsed::UnaryOperator::LogicalNot => todo!(), + ast::parsed::UnaryOperator::Next => unreachable!(), + }; + + vec![result.into()] + } + Expression::FunctionCall(f) => self.exec_instruction(&f.id, &f.arguments), + Expression::FreeInput(expr) => 'input: { + if let Expression::Tuple(t) = &**expr { + if let Expression::String(name) = &t[0] { + let val = self.eval_expression(&t[1])[0]; + break 'input vec![match name.as_str() { + "input" => { + let idx = val.u() as usize; + to_u32(&self.inputs[idx]).unwrap().into() + } + "print_char" => { + self.stdout.write(&[val.u() as u8]).unwrap(); + // what is print_char supposed to return? + Elem::zero() + } + unk => { + panic!("unknown IO command: {unk}"); + } + }]; + } + }; + panic!("does not matched IO pattern") + } + Expression::MatchExpression(_, _) => todo!(), + } + } +} + +pub fn execute_ast<'a, T: FieldElement>( + program: &'a AnalysisASMFile, + inputs: &[T], +) -> ExecutionTrace<'a> { + let main_machine = get_main_machine(program); + let (statements, label_map, batch_to_line_map, debug_files) = + preprocess_main_function(main_machine); + + let mut e = Executor { + proc: TraceBuilder::new(main_machine, &batch_to_line_map), + mem: MemoryBuilder::new(), + label_map, + inputs, + stdout: io::stdout(), + }; + + let mut curr_pc = 0u32; + loop { + let stm = statements[curr_pc as usize]; + + //println!("l {curr_pc}: {stm}",); + + let is_nop = match stm { + FunctionStatement::Assignment(a) => { + let results = e.eval_expression(a.rhs.as_ref()); + assert_eq!(a.lhs_with_reg.len(), results.len()); + for ((dest, reg), val) in a.lhs_with_reg.iter().zip(results) { + let AssignmentRegister::Register(reg) = reg else { + panic!(); + }; + e.proc.s(reg, val); + e.proc.s(dest, val); + } + + false + } + FunctionStatement::Instruction(i) => { + e.exec_instruction(&i.instruction, &i.inputs); + + false + } + FunctionStatement::Return(_) => break, + FunctionStatement::DebugDirective(dd) => { + match &dd.directive { + DebugDirective::Loc(file, line, column) => { + let (dir, file) = debug_files[file - 1]; + println!("Executed {dir}/{file}:{line}:{column}"); + } + DebugDirective::OriginalInstruction(insn) => { + println!(" {insn}"); + } + DebugDirective::File(_, _, _) => unreachable!(), + }; + + true + } + FunctionStatement::Label(_) => { + unreachable!() + } + }; + + curr_pc = e.proc.advance(is_nop); + } + + e.proc.finish() +} + +/// Execute a Powdr/RISCV assembly source. +/// +/// The FieldElement is just used by the parser, before everything is converted +/// to i64, so it is probably not very important. +pub fn execute(asm_source: &str, inputs: &[F]) { + log::info!("Parsing..."); + let parsed = parser::parse_asm::(None, asm_source).unwrap(); + log::info!("Resolving imports..."); + let resolved = importer::resolve(None, parsed).unwrap(); + log::info!("Analyzing..."); + let analyzed = analysis::analyze(resolved, &mut ast::DiffMonitor::default()).unwrap(); + + log::info!("Executing..."); + execute_ast(&analyzed, inputs); +} + +fn to_u32(val: &F) -> Option { + val.to_arbitrary_integer().try_into().ok().or_else(|| { + // Number is negative, gets it binary representation as u32. + let modulus = F::modulus().to_arbitrary_integer(); + let diff = modulus - val.to_arbitrary_integer(); + if diff <= 0x80000000u32.into() { + let negated: i64 = diff.try_into().unwrap(); + Some((-negated) as u32) + } else { + None + } + }) +} + +#[cfg(test)] +mod test { + use crate::execute; + use number::GoldilocksField; + use std::fs; + + #[test] + fn execute_from_file() { + println!("{}", std::env::current_dir().unwrap().to_string_lossy()); + + println!("Loading..."); + let asm = fs::read("../tmp/evm.asm").unwrap(); + println!("Validating UTF-8..."); + let asm_str = std::str::from_utf8(&asm).unwrap(); + + execute::(asm_str, &[]); + } +}