diff --git a/common/src/rv_trace.rs b/common/src/rv_trace.rs
index b6088a6ea..9d7593bcf 100644
--- a/common/src/rv_trace.rs
+++ b/common/src/rv_trace.rs
@@ -95,7 +95,8 @@ impl From<&RVTraceRow> for [MemoryOp; MEMORY_OPS_PER_INSTRUCTION] {
                 MemoryOp::noop_read(),
             ],
             RV32InstructionFormat::I => match val.instruction.opcode {
-                RV32IM::VIRTUAL_ASSERT_HALFWORD_ALIGNMENT => [
+                RV32IM::VIRTUAL_ASSERT_HALFWORD_ALIGNMENT
+                | RV32IM::VIRTUAL_ASSERT_WORD_ALIGNMENT => [
                     rs1_read(),
                     MemoryOp::noop_read(),
                     MemoryOp::noop_write(),
@@ -232,7 +233,8 @@ impl ELFInstruction {
                 | RV32IM::JALR
                 | RV32IM::SW
                 | RV32IM::LW
-                | RV32IM::VIRTUAL_ASSERT_HALFWORD_ALIGNMENT,
+                | RV32IM::VIRTUAL_ASSERT_HALFWORD_ALIGNMENT
+                | RV32IM::VIRTUAL_ASSERT_WORD_ALIGNMENT,
         );
 
         flags[CircuitFlags::Load as usize] = matches!(
@@ -275,6 +277,7 @@ impl ELFInstruction {
                 | RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER
                 | RV32IM::VIRTUAL_ASSERT_VALID_UNSIGNED_REMAINDER
                 | RV32IM::VIRTUAL_ASSERT_HALFWORD_ALIGNMENT
+                | RV32IM::VIRTUAL_ASSERT_WORD_ALIGNMENT
         );
 
         flags[CircuitFlags::ConcatLookupQueryChunks as usize] = matches!(
@@ -314,12 +317,10 @@ impl ELFInstruction {
             RV32IM::VIRTUAL_ASSERT_EQ |
             RV32IM::VIRTUAL_ASSERT_LTE |
             RV32IM::VIRTUAL_ASSERT_HALFWORD_ALIGNMENT |
+            RV32IM::VIRTUAL_ASSERT_WORD_ALIGNMENT |
             RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER |
             RV32IM::VIRTUAL_ASSERT_VALID_UNSIGNED_REMAINDER |
-            RV32IM::VIRTUAL_ASSERT_VALID_DIV0 |
-            // SW and LW perform a `AssertAlignedMemoryAccessInstruction` lookup
-            RV32IM::SW |
-            RV32IM::LW
+            RV32IM::VIRTUAL_ASSERT_VALID_DIV0
         );
 
         // All instructions in virtual sequence are mapped from the same
@@ -429,6 +430,7 @@ pub enum RV32IM {
     VIRTUAL_ASSERT_EQ,
     VIRTUAL_ASSERT_VALID_DIV0,
     VIRTUAL_ASSERT_HALFWORD_ALIGNMENT,
+    VIRTUAL_ASSERT_WORD_ALIGNMENT,
 }
 
 impl FromStr for RV32IM {
@@ -537,6 +539,7 @@ impl RV32IM {
             RV32IM::SLTIU |
             RV32IM::VIRTUAL_MOVE |
             RV32IM::VIRTUAL_ASSERT_HALFWORD_ALIGNMENT |
+            RV32IM::VIRTUAL_ASSERT_WORD_ALIGNMENT |
             RV32IM::VIRTUAL_MOVSIGN => RV32InstructionFormat::I,
 
             RV32IM::LB |
diff --git a/jolt-core/src/host/mod.rs b/jolt-core/src/host/mod.rs
index 09758fa00..7f9ea3a11 100644
--- a/jolt-core/src/host/mod.rs
+++ b/jolt-core/src/host/mod.rs
@@ -25,9 +25,9 @@ use crate::{
     jolt::{
         instruction::{
            div::DIVInstruction, divu::DIVUInstruction, lb::LBInstruction, lbu::LBUInstruction,
-            lh::LHInstruction, lhu::LHUInstruction, mulh::MULHInstruction,
+            lh::LHInstruction, lhu::LHUInstruction, lw::LWInstruction, mulh::MULHInstruction,
             mulhsu::MULHSUInstruction, rem::REMInstruction, remu::REMUInstruction,
-            sb::SBInstruction, sh::SHInstruction, VirtualInstructionSequence,
+            sb::SBInstruction, sh::SHInstruction, sw::SWInstruction, VirtualInstructionSequence,
         },
         vm::{bytecode::BytecodeRow, rv32i_vm::RV32I, JoltTraceStep},
     },
@@ -197,8 +197,10 @@ impl Program {
             tracer::RV32IM::DIVU => DIVUInstruction::<32>::virtual_trace(row),
             tracer::RV32IM::REM => REMInstruction::<32>::virtual_trace(row),
             tracer::RV32IM::REMU => REMUInstruction::<32>::virtual_trace(row),
+            tracer::RV32IM::SW => SWInstruction::<32>::virtual_trace(row),
             tracer::RV32IM::SH => SHInstruction::<32>::virtual_trace(row),
             tracer::RV32IM::SB => SBInstruction::<32>::virtual_trace(row),
+            tracer::RV32IM::LW => LWInstruction::<32>::virtual_trace(row),
             tracer::RV32IM::LBU => LBUInstruction::<32>::virtual_trace(row),
             tracer::RV32IM::LHU => LHUInstruction::<32>::virtual_trace(row),
             tracer::RV32IM::LB => LBInstruction::<32>::virtual_trace(row),
diff --git a/jolt-core/src/jolt/instruction/lw.rs b/jolt-core/src/jolt/instruction/lw.rs
new file mode 100644
index 000000000..0f78cb635
--- /dev/null
+++ b/jolt-core/src/jolt/instruction/lw.rs
@@ -0,0 +1,134 @@
+use tracer::{ELFInstruction, MemoryState, RVTraceRow, RegisterState, RV32IM};
+
+use super::VirtualInstructionSequence;
+use crate::jolt::instruction::{
+    virtual_assert_aligned_memory_access::AssertAlignedMemoryAccessInstruction, JoltInstruction,
+};
+/// Loads a word from memory
+pub struct LWInstruction<const WORD_SIZE: usize>;
+
+impl<const WORD_SIZE: usize> VirtualInstructionSequence for LWInstruction<WORD_SIZE> {
+    const SEQUENCE_LENGTH: usize = 2;
+
+    fn virtual_trace(mut trace_row: RVTraceRow) -> Vec<RVTraceRow> {
+        assert_eq!(trace_row.instruction.opcode, RV32IM::LW);
+        // LW source registers
+        let rs1 = trace_row.instruction.rs1;
+        // LW operands
+        let rs1_val = trace_row.register_state.rs1_val.unwrap();
+        let offset = trace_row.instruction.imm.unwrap();
+
+        let mut virtual_trace = vec![];
+
+        let offset_unsigned = match WORD_SIZE {
+            32 => (offset & u32::MAX as i64) as u64,
+            64 => offset as u64,
+            _ => panic!("Unsupported WORD_SIZE: {}", WORD_SIZE),
+        };
+
+        let is_aligned =
+            AssertAlignedMemoryAccessInstruction::<WORD_SIZE, 4>(rs1_val, offset_unsigned)
+                .lookup_entry();
+        debug_assert_eq!(is_aligned, 1);
+        virtual_trace.push(RVTraceRow {
+            instruction: ELFInstruction {
+                address: trace_row.instruction.address,
+                opcode: RV32IM::VIRTUAL_ASSERT_WORD_ALIGNMENT,
+                rs1,
+                rs2: None,
+                rd: None,
+                imm: Some(offset_unsigned as i64),
+                virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1),
+            },
+            register_state: RegisterState {
+                rs1_val: Some(rs1_val),
+                rs2_val: None,
+                rd_post_val: None,
+            },
+            memory_state: None,
+            advice_value: None,
+        });
+
+        trace_row.instruction.virtual_sequence_remaining = Some(0);
+        virtual_trace.push(trace_row);
+
+        virtual_trace
+    }
+
+    fn sequence_output(_: u64, _: u64) -> u64 {
+        unimplemented!("")
+    }
+
+    fn virtual_sequence(instruction: ELFInstruction) -> Vec<ELFInstruction> {
+        let dummy_trace_row = RVTraceRow {
+            instruction,
+            register_state: RegisterState {
+                rs1_val: Some(0),
+                rs2_val: Some(0),
+                rd_post_val: Some(0),
+            },
+            memory_state: Some(MemoryState::Read {
+                address: 0,
+                value: 0,
+            }),
+            advice_value: None,
+        };
+        Self::virtual_trace(dummy_trace_row)
+            .into_iter()
+            .map(|trace_row| trace_row.instruction)
+            .collect()
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use ark_std::test_rng;
+    use rand_core::RngCore;
+
+    use super::*;
+
+    #[test]
+    fn lw_virtual_sequence_32() {
+        let mut rng = test_rng();
+        for _ in 0..256 {
+            let rs1 = rng.next_u64() % 32;
+            let rd = rng.next_u64() % 32;
+
+            let mut rs1_val = rng.next_u32() as u64;
+            let mut imm = rng.next_u64() as i64 % (1 << 12);
+
+            // Reroll rs1_val and imm until dest is aligned to a word
+            while (rs1_val as i64 + imm as i64) % 4 != 0 || (rs1_val as i64 + imm as i64) < 0 {
+                rs1_val = rng.next_u32() as u64;
+                imm = rng.next_u64() as i64 % (1 << 12);
+            }
+            let address = (rs1_val as i64 + imm as i64) as u64;
+            let word = rng.next_u32() as u64;
+
+            let lw_trace_row = RVTraceRow {
+                instruction: ELFInstruction {
+                    address: rng.next_u64(),
+                    opcode: RV32IM::LW,
+                    rs1: Some(rs1),
+                    rs2: None,
+                    rd: Some(rd),
+                    imm: Some(imm),
+                    virtual_sequence_remaining: None,
+                },
+                register_state: RegisterState {
+                    rs1_val: Some(rs1_val),
+                    rs2_val: None,
+                    rd_post_val: Some(word),
+                },
+                memory_state: Some(MemoryState::Read {
+                    address,
+                    value: word,
+                }),
+                advice_value: None,
+            };
+
+            let trace = LWInstruction::<32>::virtual_trace(lw_trace_row);
+            assert_eq!(trace.len(), LWInstruction::<32>::SEQUENCE_LENGTH);
+        }
+    }
+}
diff --git a/jolt-core/src/jolt/instruction/mod.rs b/jolt-core/src/jolt/instruction/mod.rs
index ee5362c7a..33ec242e1 100644
--- a/jolt-core/src/jolt/instruction/mod.rs
+++ b/jolt-core/src/jolt/instruction/mod.rs
@@ -159,6 +159,7 @@ pub mod lb;
 pub mod lbu;
 pub mod lh;
 pub mod lhu;
+pub mod lw;
 pub mod mul;
 pub mod mulh;
 pub mod mulhsu;
@@ -175,6 +176,7 @@ pub mod sltu;
 pub mod sra;
 pub mod srl;
 pub mod sub;
+pub mod sw;
 pub mod virtual_advice;
 pub mod virtual_assert_aligned_memory_access;
 pub mod virtual_assert_lte;
diff --git a/jolt-core/src/jolt/instruction/sw.rs b/jolt-core/src/jolt/instruction/sw.rs
new file mode 100644
index 000000000..b0d9d0251
--- /dev/null
+++ b/jolt-core/src/jolt/instruction/sw.rs
@@ -0,0 +1,138 @@
+use tracer::{ELFInstruction, MemoryState, RVTraceRow, RegisterState, RV32IM};
+
+use super::VirtualInstructionSequence;
+use crate::jolt::instruction::{
+    virtual_assert_aligned_memory_access::AssertAlignedMemoryAccessInstruction, JoltInstruction,
+};
+/// Stores a word to memory
+pub struct SWInstruction<const WORD_SIZE: usize>;
+
+impl<const WORD_SIZE: usize> VirtualInstructionSequence for SWInstruction<WORD_SIZE> {
+    const SEQUENCE_LENGTH: usize = 2;
+
+    fn virtual_trace(mut trace_row: RVTraceRow) -> Vec<RVTraceRow> {
+        assert_eq!(trace_row.instruction.opcode, RV32IM::SW);
+        // SW source registers
+        let rs1 = trace_row.instruction.rs1;
+        // SW operands
+        let dest = trace_row.register_state.rs1_val.unwrap();
+        let offset = trace_row.instruction.imm.unwrap();
+
+        let mut virtual_trace = vec![];
+
+        let offset_unsigned = match WORD_SIZE {
+            32 => (offset & u32::MAX as i64) as u64,
+            64 => offset as u64,
+            _ => panic!("Unsupported WORD_SIZE: {}", WORD_SIZE),
+        };
+
+        let is_aligned =
+            AssertAlignedMemoryAccessInstruction::<WORD_SIZE, 4>(dest, offset_unsigned)
+                .lookup_entry();
+        debug_assert_eq!(is_aligned, 1);
+        virtual_trace.push(RVTraceRow {
+            instruction: ELFInstruction {
+                address: trace_row.instruction.address,
+                opcode: RV32IM::VIRTUAL_ASSERT_WORD_ALIGNMENT,
+                rs1,
+                rs2: None,
+                rd: None,
+                imm: Some(offset_unsigned as i64),
+                virtual_sequence_remaining: Some(Self::SEQUENCE_LENGTH - virtual_trace.len() - 1),
+            },
+            register_state: RegisterState {
+                rs1_val: Some(dest),
+                rs2_val: None,
+                rd_post_val: None,
+            },
+            memory_state: None,
+            advice_value: None,
+        });
+
+        trace_row.instruction.virtual_sequence_remaining = Some(0);
+        virtual_trace.push(trace_row);
+
+        virtual_trace
+    }
+
+    fn sequence_output(_: u64, _: u64) -> u64 {
+        unimplemented!("")
+    }
+
+    fn virtual_sequence(instruction: ELFInstruction) -> Vec<ELFInstruction> {
+        let dummy_trace_row = RVTraceRow {
+            instruction,
+            register_state: RegisterState {
+                rs1_val: Some(0),
+                rs2_val: Some(0),
+                rd_post_val: Some(0),
+            },
+            memory_state: Some(MemoryState::Read {
+                address: 0,
+                value: 0,
+            }),
+            advice_value: None,
+        };
+        Self::virtual_trace(dummy_trace_row)
+            .into_iter()
+            .map(|trace_row| trace_row.instruction)
+            .collect()
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use ark_std::test_rng;
+    use rand_core::RngCore;
+
+    use super::*;
+
+    #[test]
+    fn sw_virtual_sequence_32() {
+        let mut rng = test_rng();
+        for _ in 0..256 {
+            let rs1 = rng.next_u64() % 32;
+            let rs2 = rng.next_u64() % 32;
+
+            let mut rs1_val = rng.next_u32() as u64;
+            let mut imm = rng.next_u64() as i64 % (1 << 12);
+
+            // Reroll rs1_val and imm until dest is aligned to a word
+            while (rs1_val as i64 + imm as i64) % 4 != 0 || (rs1_val as i64 + imm as i64) < 0 {
+                rs1_val = rng.next_u32() as u64;
+                imm = rng.next_u64() as i64 % (1 << 12);
+            }
+
+            let dest = (rs1_val as i64 + imm as i64) as u64;
+
+            let word_before = rng.next_u32() as u64;
+            let word_after = rng.next_u32() as u64;
+
+            let sw_trace_row = RVTraceRow {
+                instruction: ELFInstruction {
+                    address: rng.next_u64(),
+                    opcode: RV32IM::SW,
+                    rs1: Some(rs1),
+                    rs2: Some(rs2),
+                    rd: None,
+                    imm: Some(imm),
+                    virtual_sequence_remaining: None,
+                },
+                register_state: RegisterState {
+                    rs1_val: Some(rs1_val),
+                    rs2_val: Some(word_after),
+                    rd_post_val: None,
+                },
+                memory_state: Some(MemoryState::Write {
+                    address: dest,
+                    pre_value: word_before,
+                    post_value: word_after,
+                }),
+                advice_value: None,
+            };
+
+            let trace = SWInstruction::<32>::virtual_trace(sw_trace_row);
+            assert_eq!(trace.len(), SWInstruction::<32>::SEQUENCE_LENGTH);
+        }
+    }
+}
diff --git a/jolt-core/src/jolt/trace/rv.rs b/jolt-core/src/jolt/trace/rv.rs
index c9965095c..ebc6ae4e0 100644
--- a/jolt-core/src/jolt/trace/rv.rs
+++ b/jolt-core/src/jolt/trace/rv.rs
@@ -76,9 +76,7 @@ impl TryFrom<&ELFInstruction> for RV32I {
             RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER => Ok(AssertValidSignedRemainderInstruction::default().into()),
             RV32IM::VIRTUAL_ASSERT_VALID_DIV0 => Ok(AssertValidDiv0Instruction::default().into()),
             RV32IM::VIRTUAL_ASSERT_HALFWORD_ALIGNMENT => Ok(AssertAlignedMemoryAccessInstruction::<32, 2>::default().into()),
-
-            RV32IM::LW => Ok(AssertAlignedMemoryAccessInstruction::<32, 4>::default().into()),
-            RV32IM::SW => Ok(AssertAlignedMemoryAccessInstruction::<32, 4>::default().into()),
+            RV32IM::VIRTUAL_ASSERT_WORD_ALIGNMENT => Ok(AssertAlignedMemoryAccessInstruction::<32, 4>::default().into()),
 
             _ => Err("No corresponding RV32I instruction")
         }
@@ -136,9 +134,7 @@ impl TryFrom<&RVTraceRow> for RV32I {
             RV32IM::VIRTUAL_ASSERT_VALID_SIGNED_REMAINDER => Ok(AssertValidSignedRemainderInstruction(row.register_state.rs1_val.unwrap(), row.register_state.rs2_val.unwrap()).into()),
             RV32IM::VIRTUAL_ASSERT_VALID_DIV0 => Ok(AssertValidDiv0Instruction(row.register_state.rs1_val.unwrap(), row.register_state.rs2_val.unwrap()).into()),
             RV32IM::VIRTUAL_ASSERT_HALFWORD_ALIGNMENT => Ok(AssertAlignedMemoryAccessInstruction::<32, 2>(row.register_state.rs1_val.unwrap(), row.imm_u32() as u64).into()),
-
-            RV32IM::LW => Ok(AssertAlignedMemoryAccessInstruction::<32, 4>(row.register_state.rs1_val.unwrap(), row.imm_u32() as u64).into()),
-            RV32IM::SW => Ok(AssertAlignedMemoryAccessInstruction::<32, 4>(row.register_state.rs1_val.unwrap(), row.imm_u32() as u64).into()),
+            RV32IM::VIRTUAL_ASSERT_WORD_ALIGNMENT => Ok(AssertAlignedMemoryAccessInstruction::<32, 4>(row.register_state.rs1_val.unwrap(), row.imm_u32() as u64).into()),
 
             _ => Err("No corresponding RV32I instruction")
         }
diff --git a/jolt-core/src/jolt/vm/mod.rs b/jolt-core/src/jolt/vm/mod.rs
index d1834caf7..e4d5eb595 100644
--- a/jolt-core/src/jolt/vm/mod.rs
+++ b/jolt-core/src/jolt/vm/mod.rs
@@ -52,8 +52,10 @@ use super::instruction::lb::LBInstruction;
 use super::instruction::lbu::LBUInstruction;
 use super::instruction::lh::LHInstruction;
 use super::instruction::lhu::LHUInstruction;
+use super::instruction::lw::LWInstruction;
 use super::instruction::sb::SBInstruction;
 use super::instruction::sh::SHInstruction;
+use super::instruction::sw::SWInstruction;
 use super::instruction::JoltInstructionSet;
 
 #[derive(Clone)]
@@ -350,8 +352,10 @@ where
             tracer::RV32IM::DIVU => DIVUInstruction::<32>::virtual_sequence(instruction),
             tracer::RV32IM::REM => REMInstruction::<32>::virtual_sequence(instruction),
             tracer::RV32IM::REMU => REMUInstruction::<32>::virtual_sequence(instruction),
+            tracer::RV32IM::SW => SWInstruction::<32>::virtual_sequence(instruction),
             tracer::RV32IM::SH => SHInstruction::<32>::virtual_sequence(instruction),
             tracer::RV32IM::SB => SBInstruction::<32>::virtual_sequence(instruction),
+            tracer::RV32IM::LW => LWInstruction::<32>::virtual_sequence(instruction),
             tracer::RV32IM::LBU => LBUInstruction::<32>::virtual_sequence(instruction),
             tracer::RV32IM::LHU => LHUInstruction::<32>::virtual_sequence(instruction),
             tracer::RV32IM::LB => LBInstruction::<32>::virtual_sequence(instruction),
diff --git a/jolt-core/src/r1cs/builder.rs b/jolt-core/src/r1cs/builder.rs
index 729bfc0fe..1184cfa8e 100644
--- a/jolt-core/src/r1cs/builder.rs
+++ b/jolt-core/src/r1cs/builder.rs
@@ -70,7 +70,7 @@ impl Constraint {
     }
 }
 
-type AuxComputationFunction = dyn Fn(&[i64]) -> i128 + Send + Sync;
+type AuxComputationFunction = dyn Fn(&[i128]) -> i128 + Send + Sync;
 
 struct AuxComputation {
     symbolic_inputs: Vec<LC>,
@@ -131,7 +131,7 @@ impl AuxComputation {
             .map(|var| var.get_ref(jolt_polynomials))
             .collect();
 
-        let mut aux_poly: Vec<u64> = vec![0; poly_len];
+        let mut aux_poly: Vec<i128> = vec![0; poly_len];
         let num_threads = rayon::current_num_threads();
         let chunk_size = poly_len.div_ceil(num_threads);
 
@@ -149,20 +149,29 @@
                         for term in lc.terms().iter() {
                             match term.0 {
                                 Variable::Input(index) | Variable::Auxiliary(index) => {
-                                    input += flattened_polys[index].get_coeff_i64(global_index)
-                                        * term.1;
+                                    input += flattened_polys[index]
+                                        .get_coeff_i128(global_index)
+                                        * term.1 as i128;
                                 }
-                                Variable::Constant => input += term.1,
+                                Variable::Constant => input += term.1 as i128,
                             }
                         }
                         input
                     })
                     .collect();
 
-                *result = u64::try_from((self.compute)(&compute_inputs)).unwrap();
+                *result = (self.compute)(&compute_inputs);
            });
        });
 
-        MultilinearPolynomial::from(aux_poly)
+        let contains_negative_values = aux_poly.iter().any(|x| x.is_negative());
+        println!("contains_negative_values: {contains_negative_values}");
+        if contains_negative_values {
+            let aux_poly: Vec<_> = aux_poly.iter().map(|x| *x as i64).collect();
+            MultilinearPolynomial::from(aux_poly)
+        } else {
+            let aux_poly: Vec<_> = aux_poly.iter().map(|x| *x as i64 as u64).collect();
+            MultilinearPolynomial::from(aux_poly)
+        }
     }
 }
@@ -295,16 +304,16 @@ impl R1CSBuilder {
         result_false: &LC,
     ) -> Variable {
         // aux = (condition == 1) ? result_true : result_false;
-        let if_else = |values: &[i64]| -> i128 {
+        let if_else = |values: &[i128]| -> i128 {
             assert_eq!(values.len(), 3);
             let condition = values[0];
             let result_true = values[1];
             let result_false = values[2];
 
             if condition.is_one() {
-                result_true as i128
+                result_true
             } else {
-                result_false as i128
+                result_false
             }
         };
 
@@ -386,9 +395,9 @@ impl R1CSBuilder {
     }
 
     fn aux_prod(&mut self, aux_symbol: I, x: &LC, y: &LC) -> Variable {
-        let prod = |values: &[i64]| {
+        let prod = |values: &[i128]| {
             assert_eq!(values.len(), 2);
-            (values[0] as i128) * (values[1] as i128)
+            (values[0]) * (values[1])
         };
 
         let symbolic_inputs = vec![x.clone(), y.clone()];
diff --git a/jolt-core/src/r1cs/constraints.rs b/jolt-core/src/r1cs/constraints.rs
index 6cc0c165c..324bc2a6e 100644
--- a/jolt-core/src/r1cs/constraints.rs
+++ b/jolt-core/src/r1cs/constraints.rs
@@ -138,7 +138,11 @@ impl R1CSConstraints for JoltRV32IMConstrain
         let is_mul = JoltR1CSInputs::InstructionFlags(MULInstruction::default().into())
             + JoltR1CSInputs::InstructionFlags(MULUInstruction::default().into())
             + JoltR1CSInputs::InstructionFlags(MULHUInstruction::default().into());
-        let product = cs.allocate_prod(JoltR1CSInputs::Aux(AuxVariable::Product), x, y);
+        let product = cs.allocate_prod(
+            JoltR1CSInputs::Aux(AuxVariable::Product),
+            JoltR1CSInputs::RS1_Read,
+            JoltR1CSInputs::RS2_Read,
+        );
         cs.constrain_eq_conditional(is_mul, packed_query.clone(), product);
         cs.constrain_eq_conditional(
             JoltR1CSInputs::InstructionFlags(MOVSIGNInstruction::default().into())
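
// ---------------------------------------------------------------------------
// Illustrative sketch only -- not part of the patch above. It shows how a bare
// LW is expected to expand under this change into a two-row virtual sequence:
// a VIRTUAL_ASSERT_WORD_ALIGNMENT row followed by the original LW. Assumed:
// the module paths `jolt_core::jolt::instruction::{lw, VirtualInstructionSequence}`
// and the `tracer` re-exports are importable as laid out in the files above, and
// `ELFInstruction` has exactly the fields used by the new tests.
// ---------------------------------------------------------------------------
use jolt_core::jolt::instruction::{lw::LWInstruction, VirtualInstructionSequence};
use tracer::{ELFInstruction, RV32IM};

fn main() {
    // A bare `lw x3, 8(x5)`; the offset of 8 keeps the (dummy) effective
    // address word-aligned, so the debug assertion in `virtual_trace` holds.
    let lw = ELFInstruction {
        address: 0x8000_0000,
        opcode: RV32IM::LW,
        rs1: Some(5),
        rs2: None,
        rd: Some(3),
        imm: Some(8),
        virtual_sequence_remaining: None,
    };

    let sequence = LWInstruction::<32>::virtual_sequence(lw);
    assert_eq!(sequence.len(), LWInstruction::<32>::SEQUENCE_LENGTH);

    // First the word-alignment assertion on rs1 + imm, then LW itself, which
    // closes the sequence with `virtual_sequence_remaining` set to 0.
    assert_eq!(sequence[0].opcode, RV32IM::VIRTUAL_ASSERT_WORD_ALIGNMENT);
    assert_eq!(sequence[1].opcode, RV32IM::LW);
    assert_eq!(sequence[1].virtual_sequence_remaining, Some(0));
}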