Skip to content

Commit

Permalink
Merge pull request #185 from a16z/optimize/fast-montgomery-conversion
Browse files Browse the repository at this point in the history
Fast u64 montgomery conversion
  • Loading branch information
moodlezoup authored Feb 28, 2024
2 parents 7aa760b + e5d1592 commit 1151696
Show file tree
Hide file tree
Showing 34 changed files with 539 additions and 381 deletions.
18 changes: 12 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
# authors who contributed to Lasso
"Michael Zhu <[email protected]>",
"Sam Ragsdale <[email protected]>",
"Noah Citron <[email protected]>"
"Noah Citron <[email protected]>",
]
edition = "2021"
description = "The lookup singularity. Based on Spartan; built on Arkworks."
Expand All @@ -27,15 +27,21 @@ members = [
"examples/fibonacci",
"examples/sha2-ex",
"examples/sha3-ex",
"integration-tests",
"integration-tests",
]

[profile.release]
debug = 1
codegen-units = 1
lto = "fat"
lto = "off"
incremental = true

[profile.build-fast]
[profile.extra-fast]
inherits = "release"
incremental = true
lto = "off"
lto = "fat"
incremental = false

[patch.crates-io]
ark-ff = { git = "https://github.com/a16z/arkworks-algebra", branch = "optimize/field-from-u64" }
ark-ec = { git = "https://github.com/a16z/arkworks-algebra", branch = "optimize/field-from-u64" }
ark-serialize = { git = "https://github.com/a16z/arkworks-algebra", branch = "optimize/field-from-u64" }
3 changes: 1 addition & 2 deletions jolt-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,11 @@ witness = { git = "https://github.com/sragss/circom-witness-rs", branch = "non-r
ruint = "1.11.1"
spartan2 = { git = "https://github.com/a16z/Spartan2.git" }
ark-bn254 = "0.4.0"

lazy_static = "1.4.0"

[build-dependencies]
common = { path = "../common" }


[lib]
name = "liblasso"
path = "src/lib.rs"
Expand Down
92 changes: 57 additions & 35 deletions jolt-core/src/benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::jolt::vm::read_write_memory::{
use crate::jolt::vm::rv32i_vm::{RV32IJoltVM, C, M, RV32I};
use crate::jolt::vm::Jolt;
use crate::poly::dense_mlpoly::bench::{init_commit_bench, run_commit_bench};
use ark_bn254::{G1Projective, Fr};
use ark_bn254::{Fr, G1Projective};
use common::constants::MEMORY_OPS_PER_INSTRUCTION;
use common::ELFInstruction;
use criterion::black_box;
Expand Down Expand Up @@ -57,9 +57,9 @@ fn prove_e2e_except_r1cs(
) -> Vec<(tracing::Span, Box<dyn FnOnce()>)> {
let mut rng = rand::rngs::StdRng::seed_from_u64(1234567890);

let memory_size = memory_size.unwrap_or(1 << 22); // 4,194,304 = 4 MB
let bytecode_size = bytecode_size.unwrap_or(1 << 16); // 65,536 = 64 kB
let num_cycles = num_cycles.unwrap_or(1 << 16); // 65,536
let memory_size = memory_size.unwrap_or(1 << 22); // 4,194,304 = 4 MB
let bytecode_size = bytecode_size.unwrap_or(1 << 16); // 65,536 = 64 kB
let num_cycles = num_cycles.unwrap_or(1 << 16); // 65,536

let ops: Vec<RV32I> = std::iter::repeat_with(|| RV32I::random_instruction(&mut rng))
.take(num_cycles)
Expand All @@ -74,14 +74,20 @@ fn prove_e2e_except_r1cs(
.collect();
let bytecode_trace = random_bytecode_trace(&bytecode_rows, num_cycles, &mut rng);

let work = Box::new(|| {
let mut transcript = Transcript::new(b"example");
let _: (_, BytecodePolynomials<Fr, G1Projective>, _) =
RV32IJoltVM::prove_bytecode(bytecode_rows, bytecode_trace, &mut transcript);
let generators = RV32IJoltVM::preprocess(bytecode_size, memory_size, num_cycles);
let mut transcript = Transcript::new(b"example");

let work = Box::new(move || {
let _: (_, BytecodePolynomials<Fr, G1Projective>, _) = RV32IJoltVM::prove_bytecode(
bytecode_rows,
bytecode_trace,
&generators,
&mut transcript,
);
let _: (_, ReadWriteMemory<Fr, G1Projective>, _) =
RV32IJoltVM::prove_memory(bytecode, memory_trace, &mut transcript);
RV32IJoltVM::prove_memory(bytecode, memory_trace, &generators, &mut transcript);
let _: (_, InstructionPolynomials<Fr, G1Projective>, _) =
RV32IJoltVM::prove_instruction_lookups(ops, &mut transcript);
RV32IJoltVM::prove_instruction_lookups(ops, &generators, &mut transcript);
});
vec![(
tracing::info_span!("prove_bytecode + prove_memory + prove_instruction_lookups"),
Expand All @@ -103,10 +109,16 @@ fn prove_bytecode(
.collect();
let bytecode_trace = random_bytecode_trace(&bytecode_rows, num_cycles, &mut rng);

let work = Box::new(|| {
let mut transcript = Transcript::new(b"example");
let _: (_, BytecodePolynomials<Fr, G1Projective>, _) =
RV32IJoltVM::prove_bytecode(bytecode_rows, bytecode_trace, &mut transcript);
let generators = RV32IJoltVM::preprocess(bytecode_size, 1, num_cycles);
let mut transcript = Transcript::new(b"example");

let work = Box::new(move || {
let _: (_, BytecodePolynomials<Fr, G1Projective>, _) = RV32IJoltVM::prove_bytecode(
bytecode_rows,
bytecode_trace,
&generators,
&mut transcript,
);
});
vec![(tracing::info_span!("prove_bytecode"), work)]
}
Expand All @@ -127,10 +139,12 @@ fn prove_memory(
.collect();
let memory_trace = random_memory_trace(&bytecode, memory_size, num_cycles, &mut rng);

let work = Box::new(|| {
let generators = RV32IJoltVM::preprocess(bytecode_size, memory_size, num_cycles);

let work = Box::new(move || {
let mut transcript = Transcript::new(b"example");
let _: (_, ReadWriteMemory<Fr, G1Projective>, _) =
RV32IJoltVM::prove_memory(bytecode, memory_trace, &mut transcript);
RV32IJoltVM::prove_memory(bytecode, memory_trace, &generators, &mut transcript);
});
vec![(tracing::info_span!("prove_memory"), work)]
}
Expand All @@ -143,10 +157,12 @@ fn prove_instruction_lookups(num_cycles: Option<usize>) -> Vec<(tracing::Span, B
.take(num_cycles)
.collect();

let work = Box::new(|| {
let mut transcript = Transcript::new(b"example");
let generators = RV32IJoltVM::preprocess(1, 1, num_cycles);
let mut transcript = Transcript::new(b"example");

let work = Box::new(move || {
let _: (_, InstructionPolynomials<Fr, G1Projective>, _) =
RV32IJoltVM::prove_instruction_lookups(ops, &mut transcript);
RV32IJoltVM::prove_instruction_lookups(ops, &generators, &mut transcript);
});
vec![(tracing::info_span!("prove_instruction_lookups"), work)]
}
Expand Down Expand Up @@ -225,13 +241,17 @@ fn prove_example(example_name: &str) -> Vec<(tracing::Span, Box<dyn FnOnce()>)>
.flat_map(|row| row.to_circuit_flags::<Fr>())
.collect::<Vec<_>>();

let (jolt_proof, jolt_commitments) = <RV32IJoltVM as Jolt<'_, _, G1Projective, C, M>>::prove(
bytecode,
bytecode_trace,
memory_trace,
instructions_r1cs,
circuit_flags,
);
let generators = RV32IJoltVM::preprocess(1 << 20, 1 << 20, 1 << 22);

let (jolt_proof, jolt_commitments) =
<RV32IJoltVM as Jolt<'_, _, G1Projective, C, M>>::prove(
bytecode,
bytecode_trace,
memory_trace,
instructions_r1cs,
circuit_flags,
generators,
);

assert!(RV32IJoltVM::verify(jolt_proof, jolt_commitments).is_ok());
};
Expand Down Expand Up @@ -284,22 +304,24 @@ fn fibonacci() -> Vec<(tracing::Span, Box<dyn FnOnce()>)> {
let memory_trace: Vec<[MemoryOp; MEMORY_OPS_PER_INSTRUCTION]> = converted_trace
.clone()
.into_iter()

.map(|row| row.to_ram_ops().try_into().unwrap())
.collect();
let circuit_flags = converted_trace
.iter()
.flat_map(|row| row.to_circuit_flags::<Fr>())
.collect::<Vec<_>>();

let (jolt_proof, jolt_commitments) = <RV32IJoltVM as Jolt<'_, _, G1Projective, C, M>>::prove(
bytecode,
bytecode_trace,
memory_trace,
instructions_r1cs,
circuit_flags,
);

let generators = RV32IJoltVM::preprocess(1 << 20, 1 << 20, 1 << 22);

let (jolt_proof, jolt_commitments) =
<RV32IJoltVM as Jolt<'_, _, G1Projective, C, M>>::prove(
bytecode,
bytecode_trace,
memory_trace,
instructions_r1cs,
circuit_flags,
generators,
);

assert!(RV32IJoltVM::verify(jolt_proof, jolt_commitments).is_ok());
};
Expand Down
4 changes: 2 additions & 2 deletions jolt-core/src/jolt/instruction/srl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ mod test {
#[test]
fn srl_instruction_e2e() {
let mut rng = test_rng();
const C: usize = 3;
const M: usize = 1 << 22;
const C: usize = 4;
const M: usize = 1 << 16;
const WORD_SIZE: usize = 32;

for _ in 0..256 {
Expand Down
4 changes: 3 additions & 1 deletion jolt-core/src/jolt/instruction/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
/// 4. Checks that the result equals the expected value, given by the `lookup_output`
macro_rules! jolt_instruction_test {
($instr:expr) => {
use ark_ff::PrimeField;

let materialized_subtables: Vec<_> = $instr
.subtables::<Fr>(C)
.iter()
Expand All @@ -23,7 +25,7 @@ macro_rules! jolt_instruction_test {
}

let actual = $instr.combine_lookups(&subtable_values, C, M);
let expected = Fr::from($instr.lookup_entry());
let expected = Fr::from_u64($instr.lookup_entry()).unwrap();

assert_eq!(actual, expected, "{:?}", $instr);
};
Expand Down
4 changes: 2 additions & 2 deletions jolt-core/src/jolt/subtable/and.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl<F: PrimeField> LassoSubtable<F> for AndSubtable<F> {
// Materialize table entries in order where (x | y) ranges 0..M
for idx in 0..M {
let (x, y) = split_bits(idx, bits_per_operand);
let row = F::from((x & y) as u64);
let row = F::from_u64((x & y) as u64).unwrap();
entries.push(row);
}
entries
Expand All @@ -42,7 +42,7 @@ impl<F: PrimeField> LassoSubtable<F> for AndSubtable<F> {
for i in 0..b {
let x = x[b - i - 1];
let y = y[b - i - 1];
result += F::from(1u64 << i) * x * y;
result += F::from_u64(1u64 << i).unwrap() * x * y;
}
result
}
Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/jolt/subtable/eq_msb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl<F: PrimeField> LassoSubtable<F> for EqMSBSubtable<F> {
for idx in 0..M {
let (x, y) = split_bits(idx, bits_per_operand);
let row = (x & high_bit) == (y & high_bit);
entries.push(F::from(row));
entries.push(if row { F::one() } else { F::zero() });
}
entries
}
Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/jolt/subtable/gt_msb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl<F: PrimeField> LassoSubtable<F> for GtMSBSubtable<F> {
for idx in 0..M {
let (x, y) = split_bits(idx, bits_per_operand);
let row = (x & high_bit) > (y & high_bit);
entries.push(F::from(row));
entries.push(if row { F::one() } else { F::zero() });
}
entries
}
Expand Down
4 changes: 2 additions & 2 deletions jolt-core/src/jolt/subtable/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ impl<F: PrimeField> IdentitySubtable<F> {

impl<F: PrimeField> LassoSubtable<F> for IdentitySubtable<F> {
fn materialize(&self, M: usize) -> Vec<F> {
(0..M).map(|i| F::from(i as u64)).collect()
(0..M).map(|i| F::from_u64(i as u64).unwrap()).collect()
}

fn evaluate_mle(&self, point: &[F]) -> F {
let mut result = F::zero();
for i in 0..point.len() {
result += F::from(1u64 << i) * point[point.len() - 1 - i];
result += F::from_u64(1u64 << i).unwrap() * point[point.len() - 1 - i];
}
result
}
Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/jolt/subtable/lt_abs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl<F: PrimeField> LassoSubtable<F> for LtAbsSubtable<F> {
// Skip i=0
for i in 1..b {
result += (F::one() - x[i]) * y[i] * eq_term;
eq_term *= F::one() - x[i] - y[i] + F::from(2u64) * x[i] * y[i];
eq_term *= F::one() - x[i] - y[i] + F::from_u64(2u64).unwrap() * x[i] * y[i];
}
result
}
Expand Down
2 changes: 1 addition & 1 deletion jolt-core/src/jolt/subtable/ltu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl<F: PrimeField> LassoSubtable<F> for LtuSubtable<F> {
let mut eq_term = F::one();
for i in 0..b {
result += (F::one() - x[i]) * y[i] * eq_term;
eq_term *= F::one() - x[i] - y[i] + F::from(2u64) * x[i] * y[i];
eq_term *= F::one() - x[i] - y[i] + F::from_u64(2u64).unwrap() * x[i] * y[i];
}
result
}
Expand Down
4 changes: 2 additions & 2 deletions jolt-core/src/jolt/subtable/or.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl<F: PrimeField> LassoSubtable<F> for OrSubtable<F> {
// Materialize table entries in order where (x | y) ranges 0..M
for idx in 0..M {
let (x, y) = split_bits(idx, bits_per_operand);
let row = F::from((x | y) as u64);
let row = F::from_u64((x | y) as u64).unwrap();
entries.push(row);
}
entries
Expand All @@ -42,7 +42,7 @@ impl<F: PrimeField> LassoSubtable<F> for OrSubtable<F> {
for i in 0..b {
let x = x[b - i - 1];
let y = y[b - i - 1];
result += F::from(1u64 << i) * (x + y - x * y);
result += F::from_u64(1u64 << i).unwrap() * (x + y - x * y);
}
result
}
Expand Down
6 changes: 3 additions & 3 deletions jolt-core/src/jolt/subtable/sll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl<F: PrimeField, const CHUNK_INDEX: usize, const WORD_SIZE: usize> LassoSubta
.checked_shr(suffix_length as u32)
.unwrap_or(0);

entries.push(F::from(row as u64));
entries.push(F::from_u64(row as u64).unwrap());
}
entries
}
Expand All @@ -64,7 +64,7 @@ impl<F: PrimeField, const CHUNK_INDEX: usize, const WORD_SIZE: usize> LassoSubta
let k_bits = (k as usize)
.get_bits(log_WORD_SIZE)
.iter()
.map(|bit| F::from(*bit as u64))
.map(|bit| F::from_u64(*bit as u64).unwrap())
.collect::<Vec<F>>(); // big-endian

let mut eq_term = F::one();
Expand All @@ -84,7 +84,7 @@ impl<F: PrimeField, const CHUNK_INDEX: usize, const WORD_SIZE: usize> LassoSubta

let shift_x_by_k = (0..m_prime)
.enumerate()
.map(|(j, _)| F::from(1_u64 << (j + k)) * x[b - 1 - j])
.map(|(j, _)| F::from_u64(1_u64 << (j + k)).unwrap() * x[b - 1 - j])
.fold(F::zero(), |acc, val| acc + val);

result += eq_term * shift_x_by_k;
Expand Down
10 changes: 5 additions & 5 deletions jolt-core/src/jolt/subtable/sra_sign.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@ impl<F: PrimeField, const WORD_SIZE: usize> LassoSubtable<F> for SraSignSubtable
for idx in 0..M {
let (x, y) = split_bits(idx, operand_chunk_width);

let x_sign = F::from(((x >> sign_bit_index) & 1) as u64);
let x_sign = F::from_u64(((x >> sign_bit_index) & 1) as u64).unwrap();

let row = (0..(y % WORD_SIZE) as u32)
.into_iter()
.fold(F::zero(), |acc, i: u32| {
acc + F::from(1_u64 << (WORD_SIZE as u32 - 1 - i)) * x_sign
acc + F::from_u64(1_u64 << (WORD_SIZE as u32 - 1 - i)).unwrap() * x_sign
});

entries.push(F::from(row));
entries.push(row);
}
entries
}
Expand All @@ -64,7 +64,7 @@ impl<F: PrimeField, const WORD_SIZE: usize> LassoSubtable<F> for SraSignSubtable
let k_bits = (k as usize)
.get_bits(log_WORD_SIZE)
.iter()
.map(|bit| F::from(*bit as u64))
.map(|bit| if *bit { F::one() } else { F::zero() })
.collect::<Vec<F>>(); // big-endian

let mut eq_term = F::one();
Expand All @@ -75,7 +75,7 @@ impl<F: PrimeField, const WORD_SIZE: usize> LassoSubtable<F> for SraSignSubtable
}

let x_sign_upper = (0..k).into_iter().fold(F::zero(), |acc, i| {
acc + F::from(1_u64 << (WORD_SIZE - 1 - i)) * x_sign
acc + F::from_u64(1_u64 << (WORD_SIZE - 1 - i)).unwrap() * x_sign
});

result += eq_term * x_sign_upper;
Expand Down
Loading

0 comments on commit 1151696

Please sign in to comment.