From aa5c5c94b2b753d055a6fff415aa425c4271354a Mon Sep 17 00:00:00 2001 From: jason <94618524+mellowcroc@users.noreply.github.com> Date: Sat, 11 Jan 2025 16:39:42 +0900 Subject: [PATCH 1/6] feat: add qm31 wgpu implementation --- crates/prover/src/core/backend/gpu/mod.rs | 1 + crates/prover/src/core/backend/gpu/qm31.rs | 736 +++++++++++++++++++ crates/prover/src/core/backend/gpu/qm31.wgsl | 219 ++++++ 3 files changed, 956 insertions(+) create mode 100644 crates/prover/src/core/backend/gpu/qm31.rs create mode 100644 crates/prover/src/core/backend/gpu/qm31.wgsl diff --git a/crates/prover/src/core/backend/gpu/mod.rs b/crates/prover/src/core/backend/gpu/mod.rs index 8ddbf6302..eb19a1d81 100644 --- a/crates/prover/src/core/backend/gpu/mod.rs +++ b/crates/prover/src/core/backend/gpu/mod.rs @@ -3,3 +3,4 @@ pub mod gen_trace_interpolate_columns; pub mod gen_trace_parallel; pub mod gen_trace_parallel_no_packed; pub mod gen_trace_parallel_no_packed_parallel_columns; +pub mod qm31; diff --git a/crates/prover/src/core/backend/gpu/qm31.rs b/crates/prover/src/core/backend/gpu/qm31.rs new file mode 100644 index 000000000..691ed1195 --- /dev/null +++ b/crates/prover/src/core/backend/gpu/qm31.rs @@ -0,0 +1,736 @@ +use wgpu::util::DeviceExt; + +use crate::core::fields::cm31::CM31; +use crate::core::fields::m31::M31; +use crate::core::fields::qm31::QM31; + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct GpuM31 { + pub data: u32, +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct GpuCM31 { + pub a: GpuM31, + pub b: GpuM31, +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct GpuQM31 { + pub a: GpuCM31, + pub b: GpuCM31, +} + +impl From for GpuM31 { + fn from(value: M31) -> Self { + GpuM31 { data: value.into() } + } +} + +impl From for GpuCM31 { + fn from(value: CM31) -> Self { + GpuCM31 { + a: value.0.into(), + b: value.1.into(), + } + } +} + +impl From for GpuQM31 { + fn from(value: QM31) -> Self { + GpuQM31 { + a: value.0.into(), + b: value.1.into(), + } + } +} + +pub trait ByteSerialize: Sized { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + (self as *const Self) as *const u8, + std::mem::size_of::(), + ) + } + } + + fn from_bytes(bytes: &[u8]) -> &Self { + assert!(bytes.len() >= std::mem::size_of::()); + unsafe { &*(bytes.as_ptr() as *const Self) } + } +} + +impl ByteSerialize for GpuM31 {} +impl ByteSerialize for GpuCM31 {} +impl ByteSerialize for GpuQM31 {} + +pub struct GpuFieldInstance { + pub device: wgpu::Device, + pub queue: wgpu::Queue, + pub input_buffer: wgpu::Buffer, + pub output_buffer: wgpu::Buffer, + pub staging_buffer: wgpu::Buffer, +} + +impl GpuFieldInstance { + pub async fn new(input_data: &T, output_size: usize) -> Self { + let instance = wgpu::Instance::default(); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::HighPerformance, + compatible_surface: None, + force_fallback_adapter: false, + }) + .await + .unwrap(); + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("Field Operations Device"), + required_features: wgpu::Features::SHADER_INT64, + required_limits: wgpu::Limits::default(), + memory_hints: wgpu::MemoryHints::Performance, + }, + None, + ) + .await + .unwrap(); + + // Create input buffer + let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("Field Input Buffer"), + contents: input_data.as_bytes(), + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + }); + + 
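        // Layout note: the input struct crosses the CPU/GPU boundary as raw bytes
        // via ByteSerialize, so the #[repr(C)] Gpu* types must match the WGSL
        // struct layout field-for-field. Every field here is ultimately a u32,
        // so no padding is introduced on either side.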
// Create output buffer + let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Output Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + // Create staging buffer for reading results + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Staging Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + Self { + device, + queue, + input_buffer, + output_buffer, + staging_buffer, + } + } + + pub fn create_pipeline( + &self, + shader_source: &str, + entry_point: &str, + ) -> (wgpu::ComputePipeline, wgpu::BindGroup) { + let shader = self + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Field Operations Shader"), + source: wgpu::ShaderSource::Wgsl(shader_source.into()), + }); + + let bind_group_layout = + self.device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + entries: &[ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + label: Some("Field Operations Bind Group Layout"), + }); + + let pipeline_layout = self + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Field Operations Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + + let pipeline = self + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Field Operations Pipeline"), + layout: Some(&pipeline_layout), + module: &shader, + entry_point: Some(entry_point), + cache: None, + compilation_options: Default::default(), + }); + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: self.input_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: self.output_buffer.as_entire_binding(), + }, + ], + label: Some("Field Operations Bind Group"), + }); + + (pipeline, bind_group) + } + + pub async fn run_computation( + &self, + pipeline: &wgpu::ComputePipeline, + bind_group: &wgpu::BindGroup, + workgroup_count: (u32, u32, u32), + ) -> T { + let mut encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Field Operations Encoder"), + }); + + { + let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("Field Operations Compute Pass"), + timestamp_writes: None, + }); + compute_pass.set_pipeline(pipeline); + compute_pass.set_bind_group(0, bind_group, &[]); + compute_pass.dispatch_workgroups( + workgroup_count.0, + workgroup_count.1, + workgroup_count.2, + ); + } + + encoder.copy_buffer_to_buffer( + &self.output_buffer, + 0, + &self.staging_buffer, + 0, + self.staging_buffer.size(), + ); + + self.queue.submit(Some(encoder.finish())); + + let buffer_slice = self.staging_buffer.slice(..); + let (tx, rx) = flume::bounded(1); + 
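        // map_async only *schedules* the mapping; the callback fires while
        // device.poll() drives the device, and the flume channel hands the
        // result back to this async context.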
buffer_slice.map_async(wgpu::MapMode::Read, move |result| { + tx.send(result).unwrap(); + }); + + self.device.poll(wgpu::Maintain::wait()); + + rx.recv_async().await.unwrap().unwrap(); + let data = buffer_slice.get_mapped_range(); + let result = *T::from_bytes(&data); + drop(data); + self.staging_buffer.unmap(); + + result + } +} + +#[derive(Debug)] +pub enum GpuFieldOperation { + Add, + Multiply, + Subtract, + Divide, + Negate, +} + +impl GpuFieldOperation { + fn shader_source(&self) -> String { + let base_shader = include_str!("qm31.wgsl"); + + let inputs = r#" + struct Inputs { + first: QM31, + second: QM31, + } + + @group(0) @binding(0) var inputs: Inputs; + "#; + + let output = r#" + @group(0) @binding(1) var output: QM31; + "#; + + let operation = match self { + GpuFieldOperation::Add => { + r#" + @compute @workgroup_size(1) + fn main() { + output.a = cm31_add(inputs.first.a, inputs.second.a); + output.b = cm31_add(inputs.first.b, inputs.second.b); + } + "# + } + GpuFieldOperation::Multiply => { + r#" + @compute @workgroup_size(1) + fn main() { + output = qm31_mul(inputs.first, inputs.second); + } + "# + } + GpuFieldOperation::Subtract => { + r#" + @compute @workgroup_size(1) + fn main() { + output.a = cm31_sub(inputs.first.a, inputs.second.a); + output.b = cm31_sub(inputs.first.b, inputs.second.b); + } + "# + } + GpuFieldOperation::Divide => { + r#" + @compute @workgroup_size(1) + fn main() { + let inv_b = qm31_inverse(inputs.second); + output = qm31_mul(inputs.first, inv_b); + } + "# + } + GpuFieldOperation::Negate => { + r#" + @compute @workgroup_size(1) + fn main() { + output.a = cm31_neg(inputs.first.a); + output.b = cm31_neg(inputs.first.b); + } + "# + } + }; + + format!("{}\n{}\n{}\n{}\n", base_shader, inputs, output, operation) + } +} + +#[derive(Debug)] +pub struct GpuFieldInputs { + pub first: GpuQM31, + pub second: GpuQM31, +} + +impl ByteSerialize for GpuFieldInputs {} + +pub async fn compute_field_operation(a: QM31, b: QM31, operation: GpuFieldOperation) -> QM31 { + let inputs = GpuFieldInputs { + first: GpuQM31::from(a), + second: GpuQM31::from(b), + }; + + let instance = GpuFieldInstance::new(&inputs, std::mem::size_of::()).await; + + let shader_source = operation.shader_source(); + let (pipeline, bind_group) = instance.create_pipeline(&shader_source, "main"); + + let result = instance + .run_computation::(&pipeline, &bind_group, (1, 1, 1)) + .await; + + QM31( + CM31(result.a.a.data.into(), result.a.b.data.into()), + CM31(result.b.a.data.into(), result.b.b.data.into()), + ) +} + +#[cfg(test)] +mod tests { + use num_traits::Zero; + + use super::*; + use crate::core::fields::m31::{M31, P}; + use crate::core::fields::qm31::QM31; + use crate::{cm31, qm31}; + + #[test] + fn test_gpu_field_values() { + let qm0 = qm31!(1, 2, 3, 4); + let qm1 = qm31!(4, 5, 6, 7); + + // Test round-trip conversion CPU -> GPU -> CPU + let gpu_qm0 = GpuQM31::from(qm0); + let gpu_qm1 = GpuQM31::from(qm1); + + let cpu_qm0 = QM31( + CM31(gpu_qm0.a.a.data.into(), gpu_qm0.a.b.data.into()), + CM31(gpu_qm0.b.a.data.into(), gpu_qm0.b.b.data.into()), + ); + + let cpu_qm1 = QM31( + CM31(gpu_qm1.a.a.data.into(), gpu_qm1.a.b.data.into()), + CM31(gpu_qm1.b.a.data.into(), gpu_qm1.b.b.data.into()), + ); + + assert_eq!( + qm0, cpu_qm0, + "Round-trip conversion should preserve values for qm0" + ); + assert_eq!( + qm1, cpu_qm1, + "Round-trip conversion should preserve values for qm1" + ); + } + + #[test] + fn test_gpu_m31_field_arithmetic() { + // Test M31 field operations + let m = M31::from(19u32); + let one = 
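        // The GPU path only exposes QM31 kernels, so each M31 operand in this
        // test is embedded as the first coordinate of a QM31 before dispatch.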
M31::from(1u32); + let zero = M31::zero(); + + // Create QM31 values for GPU computation + let m_qm = QM31(CM31(m, zero), CM31::zero()); + let one_qm = QM31(CM31(one, zero), CM31::zero()); + let zero_qm = QM31(CM31(zero, zero), CM31::zero()); + + // Test multiplication + let cpu_mul = m * one; + let gpu_mul = pollster::block_on(compute_field_operation( + m_qm, + one_qm, + GpuFieldOperation::Multiply, + )); + assert_eq!(gpu_mul.0 .0, cpu_mul, "M31 multiplication failed"); + + // Test addition + let cpu_add = m + one; + let gpu_add = pollster::block_on(compute_field_operation( + m_qm, + one_qm, + GpuFieldOperation::Add, + )); + assert_eq!(gpu_add.0 .0, cpu_add, "M31 addition failed"); + + // Test subtraction + let cpu_sub = m - one; + let gpu_sub = pollster::block_on(compute_field_operation( + m_qm, + one_qm, + GpuFieldOperation::Subtract, + )); + assert_eq!(gpu_sub.0 .0, cpu_sub, "M31 subtraction failed"); + + // Test negation + let cpu_neg = -m; + let gpu_neg = pollster::block_on(compute_field_operation( + m_qm, + zero_qm, + GpuFieldOperation::Negate, + )); + assert_eq!(gpu_neg.0 .0, cpu_neg, "M31 negation failed"); + + // Test division and inverse + let cpu_div = one / m; + let gpu_div = pollster::block_on(compute_field_operation( + one_qm, + m_qm, + GpuFieldOperation::Divide, + )); + assert_eq!(gpu_div.0 .0, cpu_div, "M31 division failed"); + + // Test with large numbers (near P) + let large = M31::from(P - 1); + let large_qm = QM31(CM31(large, zero), CM31::zero()); + + // Test large number multiplication + let cpu_large_mul = large * m; + let gpu_large_mul = pollster::block_on(compute_field_operation( + large_qm, + m_qm, + GpuFieldOperation::Multiply, + )); + assert_eq!( + gpu_large_mul.0 .0, cpu_large_mul, + "M31 large number multiplication failed" + ); + + // Test large number division + let cpu_large_div = one / large; + let gpu_large_div = pollster::block_on(compute_field_operation( + one_qm, + large_qm, + GpuFieldOperation::Divide, + )); + assert_eq!( + gpu_large_div.0 .0, cpu_large_div, + "M31 large number division failed" + ); + } + + #[test] + fn test_gpu_cm31_field_arithmetic() { + let cm0 = cm31!(1, 2); + let cm1 = cm31!(4, 5); + let m = M31::from(8u32); + let cm = CM31::from(m); + let zero = CM31::zero(); + + // Test multiplication + let cpu_mul = cm0 * cm1; + let gpu_mul = pollster::block_on(compute_field_operation( + QM31(cm0, zero), + QM31(cm1, zero), + GpuFieldOperation::Multiply, + )); + assert_eq!(gpu_mul.0, cpu_mul, "CM31 multiplication failed"); + + // Test addition + let cpu_add = cm0 + cm1; + let gpu_add = pollster::block_on(compute_field_operation( + QM31(cm0, zero), + QM31(cm1, zero), + GpuFieldOperation::Add, + )); + assert_eq!(gpu_add.0, cpu_add, "CM31 addition failed"); + + // Test subtraction + let cpu_sub = cm0 - cm1; + let gpu_sub = pollster::block_on(compute_field_operation( + QM31(cm0, zero), + QM31(cm1, zero), + GpuFieldOperation::Subtract, + )); + assert_eq!(gpu_sub.0, cpu_sub, "CM31 subtraction failed"); + + // Test negation + let cpu_neg = -cm0; + let gpu_neg = pollster::block_on(compute_field_operation( + QM31(cm0, zero), + QM31(zero, zero), + GpuFieldOperation::Negate, + )); + assert_eq!(gpu_neg.0, cpu_neg, "CM31 negation failed"); + + // Test division + let cpu_div = cm0 / cm1; + let gpu_div = pollster::block_on(compute_field_operation( + QM31(cm0, zero), + QM31(cm1, zero), + GpuFieldOperation::Divide, + )); + assert_eq!(gpu_div.0, cpu_div, "CM31 division failed"); + + // Test scalar operations + let cpu_scalar_mul = cm1 * m; + let gpu_scalar_mul = 
pollster::block_on(compute_field_operation( + QM31(cm1, zero), + QM31(cm, zero), + GpuFieldOperation::Multiply, + )); + assert_eq!( + gpu_scalar_mul.0, cpu_scalar_mul, + "CM31 scalar multiplication failed" + ); + + let cpu_scalar_add = cm1 + m; + let gpu_scalar_add = pollster::block_on(compute_field_operation( + QM31(cm1, zero), + QM31(cm, zero), + GpuFieldOperation::Add, + )); + assert_eq!( + gpu_scalar_add.0, cpu_scalar_add, + "CM31 scalar addition failed" + ); + + let cpu_scalar_sub = cm1 - m; + let gpu_scalar_sub = pollster::block_on(compute_field_operation( + QM31(cm1, zero), + QM31(cm, zero), + GpuFieldOperation::Subtract, + )); + assert_eq!( + gpu_scalar_sub.0, cpu_scalar_sub, + "CM31 scalar subtraction failed" + ); + + let cpu_scalar_div = cm1 / m; + let gpu_scalar_div = pollster::block_on(compute_field_operation( + QM31(cm1, zero), + QM31(cm, zero), + GpuFieldOperation::Divide, + )); + assert_eq!( + gpu_scalar_div.0, cpu_scalar_div, + "CM31 scalar division failed" + ); + + // Test with large numbers (near P) + let large = cm31!(P - 1, P - 2); + let large_qm = QM31(large, zero); + + // Test large number multiplication + let cpu_large_mul = large * cm1; + let gpu_large_mul = pollster::block_on(compute_field_operation( + large_qm, + QM31(cm1, zero), + GpuFieldOperation::Multiply, + )); + assert_eq!( + gpu_large_mul.0, cpu_large_mul, + "CM31 large number multiplication failed" + ); + + // Test large number division + let cpu_large_div = large / cm1; + let gpu_large_div = pollster::block_on(compute_field_operation( + large_qm, + QM31(cm1, zero), + GpuFieldOperation::Divide, + )); + assert_eq!( + gpu_large_div.0, cpu_large_div, + "CM31 large number division failed" + ); + } + + #[test] + fn test_gpu_qm31_field_arithmetic() { + let qm0 = qm31!(1, 2, 3, 4); + let qm1 = qm31!(4, 5, 6, 7); + let m = M31::from(8u32); + let qm = QM31::from(m); + let zero = QM31::zero(); + + // Test multiplication + let cpu_mul = qm0 * qm1; + let gpu_mul = pollster::block_on(compute_field_operation( + qm0, + qm1, + GpuFieldOperation::Multiply, + )); + assert_eq!(gpu_mul, cpu_mul, "QM31 multiplication failed"); + + // Test addition + let cpu_add = qm0 + qm1; + let gpu_add = pollster::block_on(compute_field_operation(qm0, qm1, GpuFieldOperation::Add)); + assert_eq!(gpu_add, cpu_add, "QM31 addition failed"); + + // Test subtraction + let cpu_sub = qm0 - qm1; + let gpu_sub = pollster::block_on(compute_field_operation( + qm0, + qm1, + GpuFieldOperation::Subtract, + )); + assert_eq!(gpu_sub, cpu_sub, "QM31 subtraction failed"); + + // Test negation + let cpu_neg = -qm0; + let gpu_neg = pollster::block_on(compute_field_operation( + qm0, + zero, + GpuFieldOperation::Negate, + )); + assert_eq!(gpu_neg, cpu_neg, "QM31 negation failed"); + + // Test division + let cpu_div = qm0 / qm1; + let gpu_div = + pollster::block_on(compute_field_operation(qm0, qm1, GpuFieldOperation::Divide)); + assert_eq!(gpu_div, cpu_div, "QM31 division failed"); + + // Test scalar operations + let cpu_scalar_mul = qm1 * m; + let gpu_scalar_mul = pollster::block_on(compute_field_operation( + qm1, + qm, + GpuFieldOperation::Multiply, + )); + assert_eq!( + cpu_scalar_mul, gpu_scalar_mul, + "QM31 scalar multiplication failed" + ); + + let cpu_scalar_add = qm1 + m; + let gpu_scalar_add = + pollster::block_on(compute_field_operation(qm1, qm, GpuFieldOperation::Add)); + assert_eq!( + cpu_scalar_add, gpu_scalar_add, + "QM31 scalar addition failed" + ); + + let cpu_scalar_sub = qm1 - m; + let gpu_scalar_sub = pollster::block_on(compute_field_operation( + 
qm1, + qm, + GpuFieldOperation::Subtract, + )); + assert_eq!( + cpu_scalar_sub, gpu_scalar_sub, + "QM31 scalar subtraction failed" + ); + + let cpu_scalar_div = qm1 / m; + let gpu_scalar_div = + pollster::block_on(compute_field_operation(qm1, qm, GpuFieldOperation::Divide)); + assert_eq!( + cpu_scalar_div, gpu_scalar_div, + "QM31 scalar division failed" + ); + + // Test with large numbers (near P) + let large = qm31!(P - 1, P - 2, P - 3, P - 4); + + // Test large number multiplication + let cpu_large_mul = large * qm1; + let gpu_large_mul = pollster::block_on(compute_field_operation( + large, + qm1, + GpuFieldOperation::Multiply, + )); + assert_eq!( + gpu_large_mul, cpu_large_mul, + "QM31 large number multiplication failed" + ); + + // Test large number division + let cpu_large_div = large / qm1; + let gpu_large_div = pollster::block_on(compute_field_operation( + large, + qm1, + GpuFieldOperation::Divide, + )); + assert_eq!( + gpu_large_div, cpu_large_div, + "QM31 large number division failed" + ); + } +} diff --git a/crates/prover/src/core/backend/gpu/qm31.wgsl b/crates/prover/src/core/backend/gpu/qm31.wgsl new file mode 100644 index 000000000..4807724ec --- /dev/null +++ b/crates/prover/src/core/backend/gpu/qm31.wgsl @@ -0,0 +1,219 @@ +// This shader contains implementations for QM31/CM31/M31 operations. +// It is stateless, i.e. it does not contain any storage variables, and also it does not include +// any entrypoint functions, which means that it can be used as a library in other shaders. +// Note that the variable names that are used in this shader cannot be used in other shaders. +const P: u32 = 0x7FFFFFFF; // 2^31 - 1 +const MODULUS_BITS: u32 = 31u; +const HALF_BITS: u32 = 16u; + +struct M31 { + data: u32, +} + +struct CM31 { + a: M31, + b: M31, +} + +struct QM31 { + a: CM31, + b: CM31, +} + +fn m31_add(a: M31, b: M31) -> M31 { + return M31(partial_reduce(a.data + b.data)); +} + +fn m31_sub(a: M31, b: M31) -> M31 { + return m31_add(a, m31_neg(b)); +} + +fn m31_mul(a: M31, b: M31) -> M31 { + // Split into 16-bit parts + let a1 = a.data >> HALF_BITS; + let a0 = a.data & 0xFFFFu; + let b1 = b.data >> HALF_BITS; + let b0 = b.data & 0xFFFFu; + + // Compute partial products + let m0 = partial_reduce(a0 * b0); + let m1 = partial_reduce(a0 * b1); + let m2 = partial_reduce(a1 * b0); + let m3 = partial_reduce(a1 * b1); + + // Combine middle terms with reduction + let mid = partial_reduce(m1 + m2); + + // Combine parts with partial reduction + let shifted_mid = partial_reduce(mid << HALF_BITS); + let low = partial_reduce(m0 + shifted_mid); + + let high_part = partial_reduce(m3 + (mid >> HALF_BITS)); + + // Final combination using Mersenne prime property + let result = partial_reduce( + partial_reduce((high_part << 1u)) + + partial_reduce((low >> MODULUS_BITS)) + + partial_reduce(low & P) + ); + return M31(result); +} + +fn m31_neg(a: M31) -> M31 { + return M31(partial_reduce(P - a.data)); +} + +fn m31_square(x: M31, n: u32) -> M31 { + var result = x; + for (var i = 0u; i < n; i += 1u) { + result = m31_mul(result, result); + } + return result; +} + +fn m31_inverse(x: M31) -> M31 { + // Computes x^(2^31-2) using the same sequence as pow2147483645 + // This is equivalent to x^(P-2) where P = 2^31-1 + + // t0 = x^5 + let t0 = m31_mul(m31_square(x, 2u), x); + + // t1 = x^15 + let t1 = m31_mul(m31_square(t0, 1u), t0); + + // t2 = x^125 + let t2 = m31_mul(m31_square(t1, 3u), t0); + + // t3 = x^255 + let t3 = m31_mul(m31_square(t2, 1u), t0); + + // t4 = x^65535 + let t4 = m31_mul(m31_square(t3, 8u), 
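    // exponent bookkeeping: (x^255)^(2^8) * x^255 = x^(255 * 257) = x^65535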
t3); + + // t5 = x^16777215 + let t5 = m31_mul(m31_square(t4, 8u), t3); + + // result = x^2147483520 + var result = m31_square(t5, 7u); + result = m31_mul(result, t2); + + return result; +} + +// Complex field operations for CM31 +fn cm31_add(a: CM31, b: CM31) -> CM31 { + return CM31( + m31_add(a.a, b.a), + m31_add(a.b, b.b) + ); +} + +fn cm31_sub(a: CM31, b: CM31) -> CM31 { + return CM31( + m31_sub(a.a, b.a), + m31_sub(a.b, b.b) + ); +} + +fn cm31_mul(a: CM31, b: CM31) -> CM31 { + // (a + bi)(c + di) = (ac - bd) + (ad + bc)i + let ac = m31_mul(a.a, b.a); + let bd = m31_mul(a.b, b.b); + let ad = m31_mul(a.a, b.b); + let bc = m31_mul(a.b, b.a); + + return CM31( + m31_sub(ac, bd), + m31_add(ad, bc) + ); +} + +fn cm31_neg(a: CM31) -> CM31 { + return CM31(m31_neg(a.a), m31_neg(a.b)); +} + +fn cm31_square(x: CM31) -> CM31 { + return cm31_mul(x, x); +} + +fn cm31_inverse(x: CM31) -> CM31 { + // 1/(a + bi) = (a - bi)/(a² + b²) + let a_sq = m31_square(x.a, 1u); + let b_sq = m31_square(x.b, 1u); + let denom = m31_add(a_sq, b_sq); + let denom_inv = m31_inverse(denom); + + // Multiply by conjugate and divide by norm + return cm31_mul( + CM31(x.a, m31_neg(x.b)), + CM31(denom_inv, M31(0u)) + ); +} + +// Quadratic extension field operations for QM31 +fn qm31_add(a: QM31, b: QM31) -> QM31 { + return QM31( + cm31_add(a.a, b.a), + cm31_add(a.b, b.b) + ); +} + +fn qm31_sub(a: QM31, b: QM31) -> QM31 { + return QM31( + cm31_sub(a.a, b.a), + cm31_sub(a.b, b.b) + ); +} + +fn qm31_mul(a: QM31, b: QM31) -> QM31 { + // (a + bu)(c + du) = (ac + rbd) + (ad + bc)u + // where r = 2 + i is the irreducible polynomial coefficient + let ac = cm31_mul(a.a, b.a); + let bd = cm31_mul(a.b, b.b); + let ad = cm31_mul(a.a, b.b); + let bc = cm31_mul(a.b, b.a); + + // r = 2 + i + let r = CM31(M31(2u), M31(1u)); + let rbd = cm31_mul(r, bd); + + return QM31( + cm31_add(ac, rbd), + cm31_add(ad, bc) + ); +} + +fn qm31_neg(a: QM31) -> QM31 { + return QM31(cm31_neg(a.a), cm31_neg(a.b)); +} + +fn qm31_square(x: QM31) -> QM31 { + return qm31_mul(x, x); +} + +fn qm31_inverse(x: QM31) -> QM31 { + // (a + bu)^-1 = (a - bu)/(a^2 - (2+i)b^2) + let b2 = cm31_square(x.b); + + // Create 2+i + let r = CM31(M31(2u), M31(1u)); + + let rb2 = cm31_mul(r, b2); + let a2 = cm31_square(x.a); + let denom = cm31_sub(a2, rb2); + let denom_inv = cm31_inverse(denom); + + // Compute (a - bu) + let neg_b = cm31_neg(x.b); + + return QM31( + cm31_mul(x.a, denom_inv), + cm31_mul(neg_b, denom_inv) + ); +} + +// Utility functions +fn partial_reduce(val: u32) -> u32 { + let reduced = val - P; + return select(val, reduced, reduced < val); +} \ No newline at end of file From 2324013578cb39b074c5e67ed6d16aad6c98d4b3 Mon Sep 17 00:00:00 2001 From: jason <94618524+mellowcroc@users.noreply.github.com> Date: Sun, 12 Jan 2025 17:12:40 +0900 Subject: [PATCH 2/6] feat: add fraction.wgsl --- .../prover/src/core/backend/gpu/fraction.rs | 362 ++++++++++++++++++ .../prover/src/core/backend/gpu/fraction.wgsl | 17 + crates/prover/src/core/backend/gpu/mod.rs | 1 + crates/prover/src/core/backend/gpu/qm31.rs | 11 + 4 files changed, 391 insertions(+) create mode 100644 crates/prover/src/core/backend/gpu/fraction.rs create mode 100644 crates/prover/src/core/backend/gpu/fraction.wgsl diff --git a/crates/prover/src/core/backend/gpu/fraction.rs b/crates/prover/src/core/backend/gpu/fraction.rs new file mode 100644 index 000000000..d9cd3b2f9 --- /dev/null +++ b/crates/prover/src/core/backend/gpu/fraction.rs @@ -0,0 +1,362 @@ +use wgpu::util::DeviceExt; + +use super::qm31::GpuQM31; +use 
crate::core::fields::qm31::QM31; +use crate::core::lookups::utils::Fraction as CpuFraction; + +#[derive(Copy, Clone, Debug)] +pub struct GpuFraction { + numerator: GpuQM31, + denominator: GpuQM31, +} + +impl From> for GpuFraction { + fn from(f: CpuFraction) -> Self { + GpuFraction { + numerator: f.numerator.into(), + denominator: f.denominator.into(), + } + } +} + +impl From> for GpuFraction { + fn from(f: CpuFraction) -> Self { + GpuFraction { + numerator: f.numerator.into(), + denominator: f.denominator.into(), + } + } +} + +// GPU computation structures +#[repr(C)] +#[derive(Copy, Clone, Debug)] +struct ComputeInputs { + first: GpuFraction, + second: GpuFraction, +} + +pub trait ByteSerialize: Sized { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + (self as *const Self) as *const u8, + std::mem::size_of::(), + ) + } + } + + fn from_bytes(bytes: &[u8]) -> &Self { + assert!(bytes.len() >= std::mem::size_of::()); + unsafe { &*(bytes.as_ptr() as *const Self) } + } +} + +impl ByteSerialize for GpuFraction {} +impl ByteSerialize for ComputeInputs {} + +pub struct GpuFractionInstance { + pub device: wgpu::Device, + pub queue: wgpu::Queue, + pub input_buffer: wgpu::Buffer, + pub output_buffer: wgpu::Buffer, + pub staging_buffer: wgpu::Buffer, +} + +impl GpuFractionInstance { + pub async fn new(input_data: &T, output_size: usize) -> Self { + let instance = wgpu::Instance::default(); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::HighPerformance, + compatible_surface: None, + force_fallback_adapter: false, + }) + .await + .unwrap(); + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("Field Operations Device"), + required_features: wgpu::Features::SHADER_INT64, + required_limits: wgpu::Limits::default(), + memory_hints: wgpu::MemoryHints::Performance, + }, + None, + ) + .await + .unwrap(); + + // Create input buffer + let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("Field Input Buffer"), + contents: input_data.as_bytes(), + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + }); + + // Create output buffer + let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Output Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + // Create staging buffer for reading results + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Staging Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + Self { + device, + queue, + input_buffer, + output_buffer, + staging_buffer, + } + } + + pub fn create_pipeline( + &self, + shader_source: &str, + entry_point: &str, + ) -> (wgpu::ComputePipeline, wgpu::BindGroup) { + let shader = self + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Field Operations Shader"), + source: wgpu::ShaderSource::Wgsl(shader_source.into()), + }); + + let bind_group_layout = + self.device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + entries: &[ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: 
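                            // `count` is only used for binding arrays; a
                            // single buffer binding uses None.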
None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + label: Some("Field Operations Bind Group Layout"), + }); + + let pipeline_layout = self + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Field Operations Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + + let pipeline = self + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Field Operations Pipeline"), + layout: Some(&pipeline_layout), + module: &shader, + entry_point: Some(entry_point), + cache: None, + compilation_options: Default::default(), + }); + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: self.input_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: self.output_buffer.as_entire_binding(), + }, + ], + label: Some("Field Operations Bind Group"), + }); + + (pipeline, bind_group) + } + + pub async fn run_computation( + &self, + pipeline: &wgpu::ComputePipeline, + bind_group: &wgpu::BindGroup, + workgroup_count: (u32, u32, u32), + ) -> T { + let mut encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Field Operations Encoder"), + }); + + { + let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("Field Operations Compute Pass"), + timestamp_writes: None, + }); + compute_pass.set_pipeline(pipeline); + compute_pass.set_bind_group(0, bind_group, &[]); + compute_pass.dispatch_workgroups( + workgroup_count.0, + workgroup_count.1, + workgroup_count.2, + ); + } + + encoder.copy_buffer_to_buffer( + &self.output_buffer, + 0, + &self.staging_buffer, + 0, + self.staging_buffer.size(), + ); + + self.queue.submit(Some(encoder.finish())); + + let buffer_slice = self.staging_buffer.slice(..); + let (tx, rx) = flume::bounded(1); + buffer_slice.map_async(wgpu::MapMode::Read, move |result| { + tx.send(result).unwrap(); + }); + + self.device.poll(wgpu::Maintain::wait()); + + rx.recv_async().await.unwrap().unwrap(); + let data = buffer_slice.get_mapped_range(); + let result = *T::from_bytes(&data); + drop(data); + self.staging_buffer.unmap(); + + result + } +} + +#[derive(Debug)] +pub enum GpuFractionOperation { + Add, + Zero, +} + +impl GpuFractionOperation { + pub fn shader_source(&self) -> String { + let base_source = include_str!("fraction.wgsl"); + let qm31_source = include_str!("qm31.wgsl"); + + let inputs = r#" + struct Inputs { + first: Fraction, + second: Fraction, + } + + @group(0) @binding(0) var inputs: Inputs; + "#; + + let output = r#" + @group(0) @binding(1) var output: Fraction; + "#; + + let operation = match self { + GpuFractionOperation::Add => { + r#" + @compute @workgroup_size(1) + fn main() {{ + let result = fraction_add(inputs.first, inputs.second); + output = result; + }} + "# + } + GpuFractionOperation::Zero => { + r#" + @compute @workgroup_size(1) + fn main() {{ + let result = fraction_zero(); + output = result; + }} + "# + } + }; + + format!( + "{}\n{}\n{}\n{}\n{}\n", + qm31_source, base_source, inputs, output, operation + ) + } +} + +impl From for CpuFraction { + fn from(f: GpuFraction) -> Self { + CpuFraction::new(f.numerator.into(), 
f.denominator.into()) + } +} + +pub async fn compute_fraction_operation( + a: CpuFraction, + b: CpuFraction, + operation: GpuFractionOperation, +) -> CpuFraction { + let inputs = ComputeInputs { + first: a.into(), + second: b.into(), + }; + + let instance = GpuFractionInstance::new(&inputs, std::mem::size_of::()).await; + + let shader_source = operation.shader_source(); + let (pipeline, bind_group) = instance.create_pipeline(&shader_source, "main"); + + let gpu_result = instance + .run_computation::(&pipeline, &bind_group, (1, 1, 1)) + .await; + + gpu_result.into() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::fields::qm31::QM31; + + #[test] + fn test_fraction_add() { + // CPU implementation + let cpu_a = CpuFraction::new( + QM31::from_u32_unchecked(1u32, 0u32, 0u32, 0u32), + QM31::from_u32_unchecked(3u32, 0u32, 0u32, 0u32), + ); + let cpu_b = CpuFraction::new( + QM31::from_u32_unchecked(2u32, 0u32, 0u32, 0u32), + QM31::from_u32_unchecked(6u32, 0u32, 0u32, 0u32), + ); + let cpu_result = cpu_a + cpu_b; + + // GPU implementation + let gpu_result = pollster::block_on(compute_fraction_operation( + cpu_a, + cpu_b, + GpuFractionOperation::Add, + )); + + assert_eq!(cpu_result.numerator, gpu_result.numerator); + assert_eq!(cpu_result.denominator, gpu_result.denominator); + } +} diff --git a/crates/prover/src/core/backend/gpu/fraction.wgsl b/crates/prover/src/core/backend/gpu/fraction.wgsl new file mode 100644 index 000000000..93e13c4f5 --- /dev/null +++ b/crates/prover/src/core/backend/gpu/fraction.wgsl @@ -0,0 +1,17 @@ +// This shader contains implementations for fraction operations. +// It is stateless and can be used as a library in other shaders. + +struct Fraction { + numerator: QM31, + denominator: QM31, +} + +// Add two fractions: (a/b + c/d) = (ad + bc)/(bd) +fn fraction_add(a: Fraction, b: Fraction) -> Fraction { + let numerator = qm31_add( + qm31_mul(a.numerator, b.denominator), + qm31_mul(b.numerator, a.denominator) + ); + let denominator = qm31_mul(a.denominator, b.denominator); + return Fraction(numerator, denominator); +} \ No newline at end of file diff --git a/crates/prover/src/core/backend/gpu/mod.rs b/crates/prover/src/core/backend/gpu/mod.rs index eb19a1d81..8e0db270c 100644 --- a/crates/prover/src/core/backend/gpu/mod.rs +++ b/crates/prover/src/core/backend/gpu/mod.rs @@ -1,3 +1,4 @@ +pub mod fraction; pub mod gen_trace; pub mod gen_trace_interpolate_columns; pub mod gen_trace_parallel; diff --git a/crates/prover/src/core/backend/gpu/qm31.rs b/crates/prover/src/core/backend/gpu/qm31.rs index 691ed1195..759163da5 100644 --- a/crates/prover/src/core/backend/gpu/qm31.rs +++ b/crates/prover/src/core/backend/gpu/qm31.rs @@ -48,6 +48,17 @@ impl From for GpuQM31 { } } +impl From for QM31 { + fn from(value: GpuQM31) -> Self { + QM31::from_m31_array([ + value.a.a.data.into(), + value.a.b.data.into(), + value.b.a.data.into(), + value.b.b.data.into(), + ]) + } +} + pub trait ByteSerialize: Sized { fn as_bytes(&self) -> &[u8] { unsafe { From 54a33d1d3bc7550d9868d94d891f3818ae1382a9 Mon Sep 17 00:00:00 2001 From: jason <94618524+mellowcroc@users.noreply.github.com> Date: Mon, 13 Jan 2025 16:10:57 +0900 Subject: [PATCH 3/6] feat: add utils.wgsl --- crates/prover/src/core/backend/gpu/mod.rs | 1 + crates/prover/src/core/backend/gpu/utils.rs | 349 ++++++++++++++++++ crates/prover/src/core/backend/gpu/utils.wgsl | 52 +++ 3 files changed, 402 insertions(+) create mode 100644 crates/prover/src/core/backend/gpu/utils.rs create mode 100644 
crates/prover/src/core/backend/gpu/utils.wgsl diff --git a/crates/prover/src/core/backend/gpu/mod.rs b/crates/prover/src/core/backend/gpu/mod.rs index 8e0db270c..34934ffba 100644 --- a/crates/prover/src/core/backend/gpu/mod.rs +++ b/crates/prover/src/core/backend/gpu/mod.rs @@ -5,3 +5,4 @@ pub mod gen_trace_parallel; pub mod gen_trace_parallel_no_packed; pub mod gen_trace_parallel_no_packed_parallel_columns; pub mod qm31; +pub mod utils; diff --git a/crates/prover/src/core/backend/gpu/utils.rs b/crates/prover/src/core/backend/gpu/utils.rs new file mode 100644 index 000000000..495a3fbb2 --- /dev/null +++ b/crates/prover/src/core/backend/gpu/utils.rs @@ -0,0 +1,349 @@ +use wgpu::util::DeviceExt; + +/// Input data for the GPU computation +#[repr(C)] +#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)] +pub struct ComputeInput { + pub i: u32, + pub domain_log_size: u32, + pub eval_log_size: u32, + pub offset: i32, +} + +/// Output data from the GPU computation +#[repr(C)] +#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)] +pub struct ComputeOutput { + pub result: u32, +} + +impl From for usize { + fn from(output: ComputeOutput) -> Self { + output.result as usize + } +} + +pub trait ByteSerialize: Sized { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + (self as *const Self) as *const u8, + std::mem::size_of::(), + ) + } + } + + fn from_bytes(bytes: &[u8]) -> &Self { + assert!(bytes.len() >= std::mem::size_of::()); + unsafe { &*(bytes.as_ptr() as *const Self) } + } +} + +impl ByteSerialize for ComputeInput {} +impl ByteSerialize for ComputeOutput {} + +/// GPU instance for utility computations +pub struct GpuUtilsInstance { + device: wgpu::Device, + queue: wgpu::Queue, + input_buffer: wgpu::Buffer, + output_buffer: wgpu::Buffer, + staging_buffer: wgpu::Buffer, +} + +impl GpuUtilsInstance { + pub async fn new(input_data: &T, output_size: usize) -> Self { + let instance = wgpu::Instance::default(); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::HighPerformance, + compatible_surface: None, + force_fallback_adapter: false, + }) + .await + .unwrap(); + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("Field Operations Device"), + required_features: wgpu::Features::SHADER_INT64, + required_limits: wgpu::Limits::default(), + memory_hints: wgpu::MemoryHints::Performance, + }, + None, + ) + .await + .unwrap(); + + // Create input buffer + let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("Field Input Buffer"), + contents: input_data.as_bytes(), + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + }); + + // Create output buffer + let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Output Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + // Create staging buffer for reading results + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Staging Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + Self { + device, + queue, + input_buffer, + output_buffer, + staging_buffer, + } + } + + /// Creates a compute pipeline for the operation + pub fn create_pipeline( + &self, + shader_source: &str, + entry_point: &str, + ) -> 
(wgpu::ComputePipeline, wgpu::BindGroup) { + let shader = self + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Field Operations Shader"), + source: wgpu::ShaderSource::Wgsl(shader_source.into()), + }); + + let bind_group_layout = + self.device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + entries: &[ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + label: Some("Field Operations Bind Group Layout"), + }); + + let pipeline_layout = self + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Field Operations Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + + let pipeline = self + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Field Operations Pipeline"), + layout: Some(&pipeline_layout), + module: &shader, + entry_point: Some(entry_point), + cache: None, + compilation_options: Default::default(), + }); + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: self.input_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: self.output_buffer.as_entire_binding(), + }, + ], + label: Some("Field Operations Bind Group"), + }); + + (pipeline, bind_group) + } + + /// Runs the computation on the GPU + async fn run_computation( + &self, + pipeline: &wgpu::ComputePipeline, + bind_group: &wgpu::BindGroup, + workgroup_count: (u32, u32, u32), + ) -> T { + // Create command encoder and compute pass + let mut encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + compute_pass.set_pipeline(&pipeline); + compute_pass.set_bind_group(0, bind_group, &[]); + compute_pass.dispatch_workgroups( + workgroup_count.0, + workgroup_count.1, + workgroup_count.2, + ); + } + + // Copy results to staging buffer + encoder.copy_buffer_to_buffer( + &self.output_buffer, + 0, + &self.staging_buffer, + 0, + self.staging_buffer.size(), + ); + + // Submit command buffer and wait for results + self.queue.submit(Some(encoder.finish())); + + // Read results from staging buffer + let slice = self.staging_buffer.slice(..); + let (sender, receiver) = flume::bounded(1); + slice.map_async(wgpu::MapMode::Read, move |result| { + sender.send(result).unwrap(); + }); + self.device.poll(wgpu::Maintain::Wait); + + receiver.recv_async().await.unwrap().unwrap(); + let data = slice.get_mapped_range(); + let result = *T::from_bytes(&data); + drop(data); + self.staging_buffer.unmap(); + + result + } +} + +#[derive(Debug)] +pub enum GpuUtilsOperation { + OffsetBitReversedCircleDomainIndex, +} + +impl GpuUtilsOperation { + pub fn shader_source(&self) -> String { + let base_source = include_str!("utils.wgsl"); + let qm31_source = include_str!("qm31.wgsl"); + + let inputs = r#" + struct 
Inputs { + i: u32, + domain_log_size: u32, + eval_log_size: u32, + offset: i32, + } + + @group(0) @binding(0) var inputs: Inputs; + "#; + + let output = r#" + struct Output { + result: u32, + } + + @group(0) @binding(1) var output: Output; + "#; + + let operation = match self { + GpuUtilsOperation::OffsetBitReversedCircleDomainIndex => { + r#" + @compute @workgroup_size(1) + fn main() {{ + let i = inputs.i; + let domain_log_size = inputs.domain_log_size; + let eval_log_size = inputs.eval_log_size; + let offset = inputs.offset; + + let result = offset_bit_reversed_circle_domain_index(i, domain_log_size, eval_log_size, offset); + output.result = result; + }} + "# + } + }; + + format!("{base_source}\n{qm31_source}\n{inputs}\n{output}\n{operation}") + } +} + +/// Computes the offset bit reversed circle domain index using the GPU +pub async fn compute_offset_bit_reversed_circle_domain_index( + i: usize, + domain_log_size: u32, + eval_log_size: u32, + offset: i32, +) -> usize { + let input = ComputeInput { + i: i as u32, + domain_log_size, + eval_log_size, + offset, + }; + + let instance = GpuUtilsInstance::new(&input, std::mem::size_of::()).await; + + let shader_source = GpuUtilsOperation::OffsetBitReversedCircleDomainIndex.shader_source(); + let (pipeline, bind_group) = instance.create_pipeline(&shader_source, "main"); + + let gpu_result = instance + .run_computation::(&pipeline, &bind_group, (1, 1, 1)) + .await; + gpu_result.into() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::utils::offset_bit_reversed_circle_domain_index as cpu_offset_bit_reversed_circle_domain_index; + + #[test] + fn test_offset_bit_reversed_circle_domain_index() { + // Test parameters from the CPU test + let domain_log_size = 3; + let eval_log_size = 6; + let initial_index = 5; + let offset = -2; + + let gpu_result = pollster::block_on(compute_offset_bit_reversed_circle_domain_index( + initial_index, + domain_log_size, + eval_log_size, + offset, + )); + println!("gpu_result: {}", gpu_result); + + let cpu_result = cpu_offset_bit_reversed_circle_domain_index( + initial_index, + domain_log_size, + eval_log_size, + offset as isize, + ); + + assert_eq!(gpu_result, cpu_result, "GPU and CPU results should match"); + } +} diff --git a/crates/prover/src/core/backend/gpu/utils.wgsl b/crates/prover/src/core/backend/gpu/utils.wgsl new file mode 100644 index 000000000..604b14eab --- /dev/null +++ b/crates/prover/src/core/backend/gpu/utils.wgsl @@ -0,0 +1,52 @@ +// This shader contains utility functions for bit manipulation and index transformations. +// It is stateless and can be used as a library in other shaders. + +/// Returns the bit reversed index of `i` which is represented by `log_size` bits. +fn bit_reverse_index(i: u32, log_size: u32) -> u32 { + if (log_size == 0u) { + return i; + } + let bits = reverse_bits_u32(i); + return bits >> (32u - log_size); +} + +fn reverse_bits_u32(x: u32) -> u32 { + var x_mut = x; + var result = 0u; + + for (var i = 0u; i < 32u; i = i + 1u) { + result = (result << 1u) | (x_mut & 1u); + x_mut = x_mut >> 1u; + } + + return result; +} + +/// Returns the index of the offset element in a bit reversed circle evaluation +/// of log size `eval_log_size` relative to a smaller domain of size `domain_log_size`. 
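+//
+// Sketch of the logic (mirroring the CPU implementation): undo the bit
+// reversal, step by `offset` scaled to the evaluation domain within the
+// matching half (the two halves of the domain are traversed in opposite
+// directions), wrap with a Euclidean modulo, then bit-reverse back.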
+fn offset_bit_reversed_circle_domain_index( + i: u32, + domain_log_size: u32, + eval_log_size: u32, + offset: i32, +) -> u32 { + var prev_index = bit_reverse_index(i, eval_log_size); + let half_size = 1u << (eval_log_size - 1u); + let step_size = i32(1u << (eval_log_size - domain_log_size - 1u)) * offset; + + if (prev_index < half_size) { + let temp = i32(prev_index) + step_size; + // Implement rem_euclid for positive modulo + let m = i32(half_size); + let rem = temp % m; + prev_index = u32(select(rem + m, rem, rem >= 0)); + } else { + let temp = i32(prev_index - half_size) - step_size; + // Implement rem_euclid for positive modulo + let m = i32(half_size); + let rem = temp % m; + prev_index = u32(select(rem + m, rem, rem >= 0)) + half_size; + } + + return bit_reverse_index(prev_index, eval_log_size); +} From d3630b3fe6f4c21dd8418ad54aed65db856c1d4a Mon Sep 17 00:00:00 2001 From: jason <94618524+mellowcroc@users.noreply.github.com> Date: Tue, 14 Jan 2025 14:58:20 +0900 Subject: [PATCH 4/6] refactor: create `gpu_common.rs` --- .../prover/src/core/backend/gpu/fraction.rs | 344 ++-------- .../prover/src/core/backend/gpu/fraction.wgsl | 2 +- .../prover/src/core/backend/gpu/gpu_common.rs | 226 +++++++ crates/prover/src/core/backend/gpu/mod.rs | 1 + crates/prover/src/core/backend/gpu/qm31.rs | 619 ++++-------------- crates/prover/src/core/backend/gpu/qm31.wgsl | 4 + 6 files changed, 410 insertions(+), 786 deletions(-) create mode 100644 crates/prover/src/core/backend/gpu/gpu_common.rs diff --git a/crates/prover/src/core/backend/gpu/fraction.rs b/crates/prover/src/core/backend/gpu/fraction.rs index d9cd3b2f9..f22c0fc32 100644 --- a/crates/prover/src/core/backend/gpu/fraction.rs +++ b/crates/prover/src/core/backend/gpu/fraction.rs @@ -1,334 +1,102 @@ -use wgpu::util::DeviceExt; +use std::borrow::Cow; -use super::qm31::GpuQM31; +use super::compute_composition_polynomial::GpuQM31; +use super::gpu_common::{ByteSerialize, GpuComputeInstance, GpuOperation}; use crate::core::fields::qm31::QM31; -use crate::core::lookups::utils::Fraction as CpuFraction; +use crate::core::lookups::utils::Fraction; -#[derive(Copy, Clone, Debug)] +#[repr(C)] +#[derive(Copy, Clone, Debug, PartialEq)] pub struct GpuFraction { - numerator: GpuQM31, - denominator: GpuQM31, -} - -impl From> for GpuFraction { - fn from(f: CpuFraction) -> Self { - GpuFraction { - numerator: f.numerator.into(), - denominator: f.denominator.into(), - } - } + pub numerator: GpuQM31, + pub denominator: GpuQM31, } -impl From> for GpuFraction { - fn from(f: CpuFraction) -> Self { - GpuFraction { - numerator: f.numerator.into(), - denominator: f.denominator.into(), - } - } -} - -// GPU computation structures #[repr(C)] #[derive(Copy, Clone, Debug)] -struct ComputeInputs { - first: GpuFraction, - second: GpuFraction, +pub struct ComputeInput { + pub first: GpuFraction, + pub second: GpuFraction, } -pub trait ByteSerialize: Sized { - fn as_bytes(&self) -> &[u8] { - unsafe { - std::slice::from_raw_parts( - (self as *const Self) as *const u8, - std::mem::size_of::(), - ) - } - } - - fn from_bytes(bytes: &[u8]) -> &Self { - assert!(bytes.len() >= std::mem::size_of::()); - unsafe { &*(bytes.as_ptr() as *const Self) } - } -} - -impl ByteSerialize for GpuFraction {} -impl ByteSerialize for ComputeInputs {} - -pub struct GpuFractionInstance { - pub device: wgpu::Device, - pub queue: wgpu::Queue, - pub input_buffer: wgpu::Buffer, - pub output_buffer: wgpu::Buffer, - pub staging_buffer: wgpu::Buffer, +#[repr(C)] +#[derive(Copy, Clone, Debug)] +pub struct 
ComputeOutput { + pub result: GpuFraction, } -impl GpuFractionInstance { - pub async fn new(input_data: &T, output_size: usize) -> Self { - let instance = wgpu::Instance::default(); - let adapter = instance - .request_adapter(&wgpu::RequestAdapterOptions { - power_preference: wgpu::PowerPreference::HighPerformance, - compatible_surface: None, - force_fallback_adapter: false, - }) - .await - .unwrap(); +impl ByteSerialize for ComputeInput {} +impl ByteSerialize for ComputeOutput {} - let (device, queue) = adapter - .request_device( - &wgpu::DeviceDescriptor { - label: Some("Field Operations Device"), - required_features: wgpu::Features::SHADER_INT64, - required_limits: wgpu::Limits::default(), - memory_hints: wgpu::MemoryHints::Performance, - }, - None, - ) - .await - .unwrap(); - - // Create input buffer - let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { - label: Some("Field Input Buffer"), - contents: input_data.as_bytes(), - usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, - }); - - // Create output buffer - let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { - label: Some("Field Output Buffer"), - size: output_size as u64, - usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, - mapped_at_creation: false, - }); - - // Create staging buffer for reading results - let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { - label: Some("Field Staging Buffer"), - size: output_size as u64, - usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, - mapped_at_creation: false, - }); - - Self { - device, - queue, - input_buffer, - output_buffer, - staging_buffer, - } - } - - pub fn create_pipeline( - &self, - shader_source: &str, - entry_point: &str, - ) -> (wgpu::ComputePipeline, wgpu::BindGroup) { - let shader = self - .device - .create_shader_module(wgpu::ShaderModuleDescriptor { - label: Some("Field Operations Shader"), - source: wgpu::ShaderSource::Wgsl(shader_source.into()), - }); - - let bind_group_layout = - self.device - .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { - entries: &[ - wgpu::BindGroupLayoutEntry { - binding: 0, - visibility: wgpu::ShaderStages::COMPUTE, - ty: wgpu::BindingType::Buffer { - ty: wgpu::BufferBindingType::Storage { read_only: true }, - has_dynamic_offset: false, - min_binding_size: None, - }, - count: None, - }, - wgpu::BindGroupLayoutEntry { - binding: 1, - visibility: wgpu::ShaderStages::COMPUTE, - ty: wgpu::BindingType::Buffer { - ty: wgpu::BufferBindingType::Storage { read_only: false }, - has_dynamic_offset: false, - min_binding_size: None, - }, - count: None, - }, - ], - label: Some("Field Operations Bind Group Layout"), - }); - - let pipeline_layout = self - .device - .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { - label: Some("Field Operations Pipeline Layout"), - bind_group_layouts: &[&bind_group_layout], - push_constant_ranges: &[], - }); - - let pipeline = self - .device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: Some("Field Operations Pipeline"), - layout: Some(&pipeline_layout), - module: &shader, - entry_point: Some(entry_point), - cache: None, - compilation_options: Default::default(), - }); - - let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { - layout: &bind_group_layout, - entries: &[ - wgpu::BindGroupEntry { - binding: 0, - resource: self.input_buffer.as_entire_binding(), - }, - wgpu::BindGroupEntry { - binding: 1, - resource: 
self.output_buffer.as_entire_binding(), - }, - ], - label: Some("Field Operations Bind Group"), - }); - - (pipeline, bind_group) - } - - pub async fn run_computation( - &self, - pipeline: &wgpu::ComputePipeline, - bind_group: &wgpu::BindGroup, - workgroup_count: (u32, u32, u32), - ) -> T { - let mut encoder = self - .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { - label: Some("Field Operations Encoder"), - }); - - { - let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { - label: Some("Field Operations Compute Pass"), - timestamp_writes: None, - }); - compute_pass.set_pipeline(pipeline); - compute_pass.set_bind_group(0, bind_group, &[]); - compute_pass.dispatch_workgroups( - workgroup_count.0, - workgroup_count.1, - workgroup_count.2, - ); +impl From> for GpuFraction { + fn from(value: Fraction) -> Self { + GpuFraction { + numerator: GpuQM31::from(value.numerator), + denominator: GpuQM31::from(value.denominator), } - - encoder.copy_buffer_to_buffer( - &self.output_buffer, - 0, - &self.staging_buffer, - 0, - self.staging_buffer.size(), - ); - - self.queue.submit(Some(encoder.finish())); - - let buffer_slice = self.staging_buffer.slice(..); - let (tx, rx) = flume::bounded(1); - buffer_slice.map_async(wgpu::MapMode::Read, move |result| { - tx.send(result).unwrap(); - }); - - self.device.poll(wgpu::Maintain::wait()); - - rx.recv_async().await.unwrap().unwrap(); - let data = buffer_slice.get_mapped_range(); - let result = *T::from_bytes(&data); - drop(data); - self.staging_buffer.unmap(); - - result } } -#[derive(Debug)] -pub enum GpuFractionOperation { +pub enum FractionOperation { Add, - Zero, } -impl GpuFractionOperation { - pub fn shader_source(&self) -> String { +impl GpuOperation for FractionOperation { + fn shader_source(&self) -> Cow<'static, str> { let base_source = include_str!("fraction.wgsl"); let qm31_source = include_str!("qm31.wgsl"); let inputs = r#" - struct Inputs { + struct ComputeInput { first: Fraction, second: Fraction, } - @group(0) @binding(0) var inputs: Inputs; + @group(0) @binding(0) var input: ComputeInput; "#; let output = r#" - @group(0) @binding(1) var output: Fraction; + struct ComputeOutput { + result: Fraction, + } + + @group(0) @binding(1) var output: ComputeOutput; "#; let operation = match self { - GpuFractionOperation::Add => { + FractionOperation::Add => { r#" @compute @workgroup_size(1) - fn main() {{ - let result = fraction_add(inputs.first, inputs.second); - output = result; - }} - "# - } - GpuFractionOperation::Zero => { - r#" - @compute @workgroup_size(1) - fn main() {{ - let result = fraction_zero(); - output = result; - }} - "# + fn main() { + output.result = fraction_add(input.first, input.second); + } + "# } }; - format!( - "{}\n{}\n{}\n{}\n{}\n", - qm31_source, base_source, inputs, output, operation - ) - } -} - -impl From for CpuFraction { - fn from(f: GpuFraction) -> Self { - CpuFraction::new(f.numerator.into(), f.denominator.into()) + format!("{qm31_source}\n{base_source}\n{inputs}\n{output}\n{operation}").into() } } pub async fn compute_fraction_operation( - a: CpuFraction, - b: CpuFraction, - operation: GpuFractionOperation, -) -> CpuFraction { - let inputs = ComputeInputs { - first: a.into(), - second: b.into(), + operation: FractionOperation, + first: Fraction, + second: Fraction, +) -> ComputeOutput { + let input = ComputeInput { + first: first.into(), + second: second.into(), }; - let instance = GpuFractionInstance::new(&inputs, std::mem::size_of::()).await; - - let shader_source = 
operation.shader_source(); - let (pipeline, bind_group) = instance.create_pipeline(&shader_source, "main"); + let instance = GpuComputeInstance::new(&input, std::mem::size_of::()).await; + let (pipeline, bind_group) = + instance.create_pipeline(&operation.shader_source(), operation.entry_point()); - let gpu_result = instance - .run_computation::(&pipeline, &bind_group, (1, 1, 1)) + let output = instance + .run_computation::(&pipeline, &bind_group, (1, 1, 1)) .await; - gpu_result.into() + output } #[cfg(test)] @@ -339,11 +107,11 @@ mod tests { #[test] fn test_fraction_add() { // CPU implementation - let cpu_a = CpuFraction::new( + let cpu_a = Fraction::new( QM31::from_u32_unchecked(1u32, 0u32, 0u32, 0u32), QM31::from_u32_unchecked(3u32, 0u32, 0u32, 0u32), ); - let cpu_b = CpuFraction::new( + let cpu_b = Fraction::new( QM31::from_u32_unchecked(2u32, 0u32, 0u32, 0u32), QM31::from_u32_unchecked(6u32, 0u32, 0u32, 0u32), ); @@ -351,12 +119,12 @@ mod tests { // GPU implementation let gpu_result = pollster::block_on(compute_fraction_operation( + FractionOperation::Add, cpu_a, cpu_b, - GpuFractionOperation::Add, )); - assert_eq!(cpu_result.numerator, gpu_result.numerator); - assert_eq!(cpu_result.denominator, gpu_result.denominator); + assert_eq!(cpu_result.numerator, gpu_result.result.numerator.into()); + assert_eq!(cpu_result.denominator, gpu_result.result.denominator.into()); } } diff --git a/crates/prover/src/core/backend/gpu/fraction.wgsl b/crates/prover/src/core/backend/gpu/fraction.wgsl index 93e13c4f5..79f54455a 100644 --- a/crates/prover/src/core/backend/gpu/fraction.wgsl +++ b/crates/prover/src/core/backend/gpu/fraction.wgsl @@ -14,4 +14,4 @@ fn fraction_add(a: Fraction, b: Fraction) -> Fraction { ); let denominator = qm31_mul(a.denominator, b.denominator); return Fraction(numerator, denominator); -} \ No newline at end of file +} diff --git a/crates/prover/src/core/backend/gpu/gpu_common.rs b/crates/prover/src/core/backend/gpu/gpu_common.rs new file mode 100644 index 000000000..66691a8b2 --- /dev/null +++ b/crates/prover/src/core/backend/gpu/gpu_common.rs @@ -0,0 +1,226 @@ +use std::borrow::Cow; + +use wgpu::util::DeviceExt; + +/// Common trait for GPU input/output types +pub trait ByteSerialize: Sized { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + (self as *const Self) as *const u8, + std::mem::size_of::(), + ) + } + } + + fn from_bytes(bytes: &[u8]) -> &Self { + assert!(bytes.len() >= std::mem::size_of::()); + unsafe { &*(bytes.as_ptr() as *const Self) } + } +} + +/// Base GPU instance for field computations +pub struct GpuComputeInstance { + pub device: wgpu::Device, + pub queue: wgpu::Queue, + pub input_buffer: wgpu::Buffer, + pub output_buffer: wgpu::Buffer, + pub staging_buffer: wgpu::Buffer, +} + +impl GpuComputeInstance { + pub async fn new(input_data: &T, output_size: usize) -> Self { + let instance = wgpu::Instance::default(); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::HighPerformance, + compatible_surface: None, + force_fallback_adapter: false, + }) + .await + .unwrap(); + + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("Field Operations Device"), + required_features: wgpu::Features::SHADER_INT64, + required_limits: wgpu::Limits::default(), + memory_hints: wgpu::MemoryHints::Performance, + }, + None, + ) + .await + .unwrap(); + + // Create input buffer + let input_buffer = 
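        // create_buffer_init uploads `input_data` at creation time; STORAGE
        // makes it visible to the shader and COPY_DST allows rewriting it later.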
device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("Field Input Buffer"), + contents: input_data.as_bytes(), + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + }); + + // Create output buffer + let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Output Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + // Create staging buffer for reading results + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Field Staging Buffer"), + size: output_size as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + Self { + device, + queue, + input_buffer, + output_buffer, + staging_buffer, + } + } + + pub fn create_pipeline( + &self, + shader_source: &str, + entry_point: &str, + ) -> (wgpu::ComputePipeline, wgpu::BindGroup) { + let shader = self + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Field Operations Shader"), + source: wgpu::ShaderSource::Wgsl(shader_source.into()), + }); + + let bind_group_layout = + self.device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + entries: &[ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + label: Some("Field Operations Bind Group Layout"), + }); + + let pipeline_layout = self + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("Field Operations Pipeline Layout"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + + let pipeline = self + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Field Operations Pipeline"), + layout: Some(&pipeline_layout), + module: &shader, + entry_point: Some(entry_point), + cache: None, + compilation_options: Default::default(), + }); + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: self.input_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: self.output_buffer.as_entire_binding(), + }, + ], + label: Some("Field Operations Bind Group"), + }); + + (pipeline, bind_group) + } + + pub async fn run_computation( + &self, + pipeline: &wgpu::ComputePipeline, + bind_group: &wgpu::BindGroup, + workgroup_count: (u32, u32, u32), + ) -> T { + let mut encoder = self + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Field Operations Encoder"), + }); + + { + let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("Field Operations Compute Pass"), + timestamp_writes: None, + }); + compute_pass.set_pipeline(pipeline); + compute_pass.set_bind_group(0, bind_group, &[]); + compute_pass.dispatch_workgroups( + workgroup_count.0, + workgroup_count.1, + workgroup_count.2, + ); + } + + encoder.copy_buffer_to_buffer( + &self.output_buffer, + 0, + 
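+        // The output buffer is STORAGE | COPY_SRC and can never be mapped directly:
+        // wgpu restricts mappable buffers to MAP_READ | COPY_DST. Results therefore
+        // take a hop through the staging buffer, which is mapped below once
+        // device.poll(Maintain::wait()) has driven the map_async callback to completion.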
&self.staging_buffer, + 0, + self.staging_buffer.size(), + ); + + self.queue.submit(Some(encoder.finish())); + + let buffer_slice = self.staging_buffer.slice(..); + let (tx, rx) = flume::bounded(1); + buffer_slice.map_async(wgpu::MapMode::Read, move |result| { + tx.send(result).unwrap(); + }); + + self.device.poll(wgpu::Maintain::wait()); + + rx.recv_async().await.unwrap().unwrap(); + let data = buffer_slice.get_mapped_range(); + let result = *T::from_bytes(&data); + drop(data); + self.staging_buffer.unmap(); + + result + } +} + +/// Trait for GPU operations that require shader source generation +pub trait GpuOperation { + fn shader_source(&self) -> Cow<'static, str>; + + fn entry_point(&self) -> &'static str { + "main" + } +} diff --git a/crates/prover/src/core/backend/gpu/mod.rs b/crates/prover/src/core/backend/gpu/mod.rs index 34934ffba..20a73b47b 100644 --- a/crates/prover/src/core/backend/gpu/mod.rs +++ b/crates/prover/src/core/backend/gpu/mod.rs @@ -4,5 +4,6 @@ pub mod gen_trace_interpolate_columns; pub mod gen_trace_parallel; pub mod gen_trace_parallel_no_packed; pub mod gen_trace_parallel_no_packed_parallel_columns; +pub mod gpu_common; pub mod qm31; pub mod utils; diff --git a/crates/prover/src/core/backend/gpu/qm31.rs b/crates/prover/src/core/backend/gpu/qm31.rs index 759163da5..aa0ce536a 100644 --- a/crates/prover/src/core/backend/gpu/qm31.rs +++ b/crates/prover/src/core/backend/gpu/qm31.rs @@ -1,384 +1,116 @@ -use wgpu::util::DeviceExt; +use std::borrow::Cow; -use crate::core::fields::cm31::CM31; -use crate::core::fields::m31::M31; +use super::compute_composition_polynomial::GpuQM31; +use super::gpu_common::{ByteSerialize, GpuComputeInstance, GpuOperation}; use crate::core::fields::qm31::QM31; -#[derive(Debug, Clone, Copy)] #[repr(C)] -pub struct GpuM31 { - pub data: u32, -} - -#[derive(Debug, Clone, Copy)] -#[repr(C)] -pub struct GpuCM31 { - pub a: GpuM31, - pub b: GpuM31, +#[derive(Copy, Clone, Debug, PartialEq)] +pub struct ComputeInput { + pub first: GpuQM31, + pub second: GpuQM31, } -#[derive(Debug, Clone, Copy)] #[repr(C)] -pub struct GpuQM31 { - pub a: GpuCM31, - pub b: GpuCM31, -} - -impl From for GpuM31 { - fn from(value: M31) -> Self { - GpuM31 { data: value.into() } - } -} - -impl From for GpuCM31 { - fn from(value: CM31) -> Self { - GpuCM31 { - a: value.0.into(), - b: value.1.into(), - } - } -} - -impl From for GpuQM31 { - fn from(value: QM31) -> Self { - GpuQM31 { - a: value.0.into(), - b: value.1.into(), - } - } -} - -impl From for QM31 { - fn from(value: GpuQM31) -> Self { - QM31::from_m31_array([ - value.a.a.data.into(), - value.a.b.data.into(), - value.b.a.data.into(), - value.b.b.data.into(), - ]) - } -} - -pub trait ByteSerialize: Sized { - fn as_bytes(&self) -> &[u8] { - unsafe { - std::slice::from_raw_parts( - (self as *const Self) as *const u8, - std::mem::size_of::(), - ) - } - } - - fn from_bytes(bytes: &[u8]) -> &Self { - assert!(bytes.len() >= std::mem::size_of::()); - unsafe { &*(bytes.as_ptr() as *const Self) } - } +#[derive(Copy, Clone, Debug, PartialEq)] +pub struct ComputeOutput { + pub result: GpuQM31, } -impl ByteSerialize for GpuM31 {} -impl ByteSerialize for GpuCM31 {} -impl ByteSerialize for GpuQM31 {} - -pub struct GpuFieldInstance { - pub device: wgpu::Device, - pub queue: wgpu::Queue, - pub input_buffer: wgpu::Buffer, - pub output_buffer: wgpu::Buffer, - pub staging_buffer: wgpu::Buffer, -} - -impl GpuFieldInstance { - pub async fn new(input_data: &T, output_size: usize) -> Self { - let instance = wgpu::Instance::default(); - let adapter = 
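
With `gpu_common` in place, adding a new operation is just a matter of implementing `GpuOperation` against the fixed binding contract (binding 0: read-only storage input, binding 1: read-write storage output). A minimal sketch with an invented `Double` op and `Value` type, not part of this patch:

    use std::borrow::Cow;

    use crate::core::backend::gpu::gpu_common::{ByteSerialize, GpuComputeInstance, GpuOperation};

    #[repr(C)]
    #[derive(Clone, Copy)]
    struct Value {
        data: u32,
    }

    impl ByteSerialize for Value {}

    struct Double;

    impl GpuOperation for Double {
        fn shader_source(&self) -> Cow<'static, str> {
            // The WGSL must mirror the bind group layout built by create_pipeline.
            r#"
            struct Value { data: u32 }

            @group(0) @binding(0) var<storage, read> input: Value;
            @group(0) @binding(1) var<storage, read_write> output: Value;

            @compute @workgroup_size(1)
            fn main() {
                output.data = input.data * 2u;
            }
            "#
            .into()
        }
    }

    async fn double_on_gpu(x: u32) -> u32 {
        let input = Value { data: x };
        let instance = GpuComputeInstance::new(&input, std::mem::size_of::<Value>()).await;
        let (pipeline, bind_group) =
            instance.create_pipeline(&Double.shader_source(), Double.entry_point());
        instance
            .run_computation::<Value>(&pipeline, &bind_group, (1, 1, 1))
            .await
            .data
    }

The default `entry_point()` of `"main"` matches the generated shaders, so only `shader_source` has to be provided.
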
instance - .request_adapter(&wgpu::RequestAdapterOptions { - power_preference: wgpu::PowerPreference::HighPerformance, - compatible_surface: None, - force_fallback_adapter: false, - }) - .await - .unwrap(); - - let (device, queue) = adapter - .request_device( - &wgpu::DeviceDescriptor { - label: Some("Field Operations Device"), - required_features: wgpu::Features::SHADER_INT64, - required_limits: wgpu::Limits::default(), - memory_hints: wgpu::MemoryHints::Performance, - }, - None, - ) - .await - .unwrap(); - - // Create input buffer - let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { - label: Some("Field Input Buffer"), - contents: input_data.as_bytes(), - usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, - }); - - // Create output buffer - let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { - label: Some("Field Output Buffer"), - size: output_size as u64, - usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, - mapped_at_creation: false, - }); - - // Create staging buffer for reading results - let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { - label: Some("Field Staging Buffer"), - size: output_size as u64, - usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, - mapped_at_creation: false, - }); - - Self { - device, - queue, - input_buffer, - output_buffer, - staging_buffer, - } - } - - pub fn create_pipeline( - &self, - shader_source: &str, - entry_point: &str, - ) -> (wgpu::ComputePipeline, wgpu::BindGroup) { - let shader = self - .device - .create_shader_module(wgpu::ShaderModuleDescriptor { - label: Some("Field Operations Shader"), - source: wgpu::ShaderSource::Wgsl(shader_source.into()), - }); - - let bind_group_layout = - self.device - .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { - entries: &[ - wgpu::BindGroupLayoutEntry { - binding: 0, - visibility: wgpu::ShaderStages::COMPUTE, - ty: wgpu::BindingType::Buffer { - ty: wgpu::BufferBindingType::Storage { read_only: true }, - has_dynamic_offset: false, - min_binding_size: None, - }, - count: None, - }, - wgpu::BindGroupLayoutEntry { - binding: 1, - visibility: wgpu::ShaderStages::COMPUTE, - ty: wgpu::BindingType::Buffer { - ty: wgpu::BufferBindingType::Storage { read_only: false }, - has_dynamic_offset: false, - min_binding_size: None, - }, - count: None, - }, - ], - label: Some("Field Operations Bind Group Layout"), - }); - - let pipeline_layout = self - .device - .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { - label: Some("Field Operations Pipeline Layout"), - bind_group_layouts: &[&bind_group_layout], - push_constant_ranges: &[], - }); - - let pipeline = self - .device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: Some("Field Operations Pipeline"), - layout: Some(&pipeline_layout), - module: &shader, - entry_point: Some(entry_point), - cache: None, - compilation_options: Default::default(), - }); - - let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { - layout: &bind_group_layout, - entries: &[ - wgpu::BindGroupEntry { - binding: 0, - resource: self.input_buffer.as_entire_binding(), - }, - wgpu::BindGroupEntry { - binding: 1, - resource: self.output_buffer.as_entire_binding(), - }, - ], - label: Some("Field Operations Bind Group"), - }); - - (pipeline, bind_group) - } - - pub async fn run_computation( - &self, - pipeline: &wgpu::ComputePipeline, - bind_group: &wgpu::BindGroup, - workgroup_count: (u32, u32, u32), - ) -> T { - let mut encoder = self 
- .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { - label: Some("Field Operations Encoder"), - }); - - { - let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { - label: Some("Field Operations Compute Pass"), - timestamp_writes: None, - }); - compute_pass.set_pipeline(pipeline); - compute_pass.set_bind_group(0, bind_group, &[]); - compute_pass.dispatch_workgroups( - workgroup_count.0, - workgroup_count.1, - workgroup_count.2, - ); - } - - encoder.copy_buffer_to_buffer( - &self.output_buffer, - 0, - &self.staging_buffer, - 0, - self.staging_buffer.size(), - ); - - self.queue.submit(Some(encoder.finish())); - - let buffer_slice = self.staging_buffer.slice(..); - let (tx, rx) = flume::bounded(1); - buffer_slice.map_async(wgpu::MapMode::Read, move |result| { - tx.send(result).unwrap(); - }); +impl ByteSerialize for ComputeInput {} +impl ByteSerialize for ComputeOutput {} - self.device.poll(wgpu::Maintain::wait()); - - rx.recv_async().await.unwrap().unwrap(); - let data = buffer_slice.get_mapped_range(); - let result = *T::from_bytes(&data); - drop(data); - self.staging_buffer.unmap(); - - result - } -} - -#[derive(Debug)] -pub enum GpuFieldOperation { +pub enum QM31Operation { Add, - Multiply, Subtract, - Divide, + Multiply, Negate, + Inverse, } -impl GpuFieldOperation { - fn shader_source(&self) -> String { - let base_shader = include_str!("qm31.wgsl"); +impl GpuOperation for QM31Operation { + fn shader_source(&self) -> Cow<'static, str> { + let base_source = include_str!("qm31.wgsl"); let inputs = r#" - struct Inputs { + struct ComputeInput { first: QM31, second: QM31, } - @group(0) @binding(0) var inputs: Inputs; + @group(0) @binding(0) var input: ComputeInput; "#; let output = r#" - @group(0) @binding(1) var output: QM31; + struct ComputeOutput { + result: QM31, + } + + @group(0) @binding(1) var output: ComputeOutput; "#; let operation = match self { - GpuFieldOperation::Add => { + QM31Operation::Add => { r#" @compute @workgroup_size(1) fn main() { - output.a = cm31_add(inputs.first.a, inputs.second.a); - output.b = cm31_add(inputs.first.b, inputs.second.b); + output.result = qm31_add(input.first, input.second); } "# } - GpuFieldOperation::Multiply => { + QM31Operation::Multiply => { r#" @compute @workgroup_size(1) fn main() { - output = qm31_mul(inputs.first, inputs.second); + output.result = qm31_mul(input.first, input.second); } "# } - GpuFieldOperation::Subtract => { + QM31Operation::Subtract => { r#" @compute @workgroup_size(1) fn main() { - output.a = cm31_sub(inputs.first.a, inputs.second.a); - output.b = cm31_sub(inputs.first.b, inputs.second.b); + output.result = qm31_sub(input.first, input.second); } "# } - GpuFieldOperation::Divide => { + QM31Operation::Negate => { r#" @compute @workgroup_size(1) fn main() { - let inv_b = qm31_inverse(inputs.second); - output = qm31_mul(inputs.first, inv_b); + output.result = qm31_neg(input.first); } "# } - GpuFieldOperation::Negate => { + QM31Operation::Inverse => { r#" @compute @workgroup_size(1) fn main() { - output.a = cm31_neg(inputs.first.a); - output.b = cm31_neg(inputs.first.b); + output.result = qm31_inverse(input.first); } "# } }; - format!("{}\n{}\n{}\n{}\n", base_shader, inputs, output, operation) + format!("{base_source}\n{inputs}\n{output}\n{operation}").into() } } -#[derive(Debug)] -pub struct GpuFieldInputs { - pub first: GpuQM31, - pub second: GpuQM31, -} - -impl ByteSerialize for GpuFieldInputs {} - -pub async fn compute_field_operation(a: QM31, b: QM31, operation: 
GpuFieldOperation) -> QM31 { - let inputs = GpuFieldInputs { - first: GpuQM31::from(a), - second: GpuQM31::from(b), +pub async fn compute_field_operation(operation: QM31Operation, first: QM31, second: QM31) -> QM31 { + let input = ComputeInput { + first: first.into(), + second: second.into(), }; - let instance = GpuFieldInstance::new(&inputs, std::mem::size_of::()).await; - - let shader_source = operation.shader_source(); - let (pipeline, bind_group) = instance.create_pipeline(&shader_source, "main"); + let instance = GpuComputeInstance::new(&input, std::mem::size_of::()).await; + let (pipeline, bind_group) = + instance.create_pipeline(&operation.shader_source(), operation.entry_point()); - let result = instance - .run_computation::(&pipeline, &bind_group, (1, 1, 1)) + let output = instance + .run_computation::(&pipeline, &bind_group, (1, 1, 1)) .await; - QM31( - CM31(result.a.a.data.into(), result.a.b.data.into()), - CM31(result.b.a.data.into(), result.b.b.data.into()), - ) + output.result.into() } #[cfg(test)] @@ -386,8 +118,10 @@ mod tests { use num_traits::Zero; use super::*; + use crate::core::fields::cm31::CM31; use crate::core::fields::m31::{M31, P}; use crate::core::fields::qm31::QM31; + use crate::core::fields::FieldExpOps; use crate::{cm31, qm31}; #[test] @@ -431,50 +165,46 @@ mod tests { let one_qm = QM31(CM31(one, zero), CM31::zero()); let zero_qm = QM31(CM31(zero, zero), CM31::zero()); - // Test multiplication - let cpu_mul = m * one; - let gpu_mul = pollster::block_on(compute_field_operation( - m_qm, - one_qm, - GpuFieldOperation::Multiply, - )); - assert_eq!(gpu_mul.0 .0, cpu_mul, "M31 multiplication failed"); - // Test addition let cpu_add = m + one; - let gpu_add = pollster::block_on(compute_field_operation( - m_qm, - one_qm, - GpuFieldOperation::Add, - )); + let gpu_add = pollster::block_on(compute_field_operation(QM31Operation::Add, m_qm, one_qm)); assert_eq!(gpu_add.0 .0, cpu_add, "M31 addition failed"); // Test subtraction let cpu_sub = m - one; let gpu_sub = pollster::block_on(compute_field_operation( + QM31Operation::Subtract, m_qm, one_qm, - GpuFieldOperation::Subtract, )); assert_eq!(gpu_sub.0 .0, cpu_sub, "M31 subtraction failed"); + // Test multiplication + let cpu_mul = m * one; + let gpu_mul = pollster::block_on(compute_field_operation( + QM31Operation::Multiply, + m_qm, + one_qm, + )); + assert_eq!(gpu_mul.0 .0, cpu_mul, "M31 multiplication failed"); + // Test negation let cpu_neg = -m; let gpu_neg = pollster::block_on(compute_field_operation( + QM31Operation::Negate, m_qm, zero_qm, - GpuFieldOperation::Negate, )); assert_eq!(gpu_neg.0 .0, cpu_neg, "M31 negation failed"); - // Test division and inverse - let cpu_div = one / m; - let gpu_div = pollster::block_on(compute_field_operation( - one_qm, + // Test inverse + let cpu_inv = m.inverse(); + let gpu_inv = pollster::block_on(compute_field_operation( + QM31Operation::Inverse, m_qm, - GpuFieldOperation::Divide, + zero_qm, )); - assert_eq!(gpu_div.0 .0, cpu_div, "M31 division failed"); + assert_eq!(gpu_inv.0 .0, cpu_inv, "M31 inverse failed"); // Test with large numbers (near P) let large = M31::from(P - 1); @@ -483,25 +213,25 @@ mod tests { // Test large number multiplication let cpu_large_mul = large * m; let gpu_large_mul = pollster::block_on(compute_field_operation( + QM31Operation::Multiply, large_qm, m_qm, - GpuFieldOperation::Multiply, )); assert_eq!( gpu_large_mul.0 .0, cpu_large_mul, "M31 large number multiplication failed" ); - // Test large number division - let cpu_large_div = one / large; - let 
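
Note that `Negate` and `Inverse` are unary: their shaders read only `input.first`, so callers pass any convenient dummy value (the tests use `zero_qm`/`zero`) as the second operand. The old `Divide` variant is gone; since it was implemented as multiplication by `qm31_inverse` anyway, a division `a / b` is now expressed on the caller's side as `a * b.inverse()`.
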
gpu_large_div = pollster::block_on(compute_field_operation( - one_qm, + // Test large number inverse + let cpu_large_inv = one / large; + let gpu_large_inv = pollster::block_on(compute_field_operation( + QM31Operation::Inverse, large_qm, - GpuFieldOperation::Divide, + zero_qm, )); assert_eq!( - gpu_large_div.0 .0, cpu_large_div, - "M31 large number division failed" + gpu_large_inv.0 .0, cpu_large_inv, + "M31 large number inverse failed" ); } @@ -509,99 +239,52 @@ mod tests { fn test_gpu_cm31_field_arithmetic() { let cm0 = cm31!(1, 2); let cm1 = cm31!(4, 5); - let m = M31::from(8u32); - let cm = CM31::from(m); let zero = CM31::zero(); - // Test multiplication - let cpu_mul = cm0 * cm1; - let gpu_mul = pollster::block_on(compute_field_operation( - QM31(cm0, zero), - QM31(cm1, zero), - GpuFieldOperation::Multiply, - )); - assert_eq!(gpu_mul.0, cpu_mul, "CM31 multiplication failed"); - // Test addition let cpu_add = cm0 + cm1; let gpu_add = pollster::block_on(compute_field_operation( + QM31Operation::Add, QM31(cm0, zero), QM31(cm1, zero), - GpuFieldOperation::Add, )); assert_eq!(gpu_add.0, cpu_add, "CM31 addition failed"); // Test subtraction let cpu_sub = cm0 - cm1; let gpu_sub = pollster::block_on(compute_field_operation( + QM31Operation::Subtract, QM31(cm0, zero), QM31(cm1, zero), - GpuFieldOperation::Subtract, )); assert_eq!(gpu_sub.0, cpu_sub, "CM31 subtraction failed"); + // Test multiplication + let cpu_mul = cm0 * cm1; + let gpu_mul = pollster::block_on(compute_field_operation( + QM31Operation::Multiply, + QM31(cm0, zero), + QM31(cm1, zero), + )); + assert_eq!(gpu_mul.0, cpu_mul, "CM31 multiplication failed"); + // Test negation let cpu_neg = -cm0; let gpu_neg = pollster::block_on(compute_field_operation( + QM31Operation::Negate, QM31(cm0, zero), QM31(zero, zero), - GpuFieldOperation::Negate, )); assert_eq!(gpu_neg.0, cpu_neg, "CM31 negation failed"); - // Test division - let cpu_div = cm0 / cm1; - let gpu_div = pollster::block_on(compute_field_operation( + // Test inverse + let cpu_inv = cm0.inverse(); + let gpu_inv = pollster::block_on(compute_field_operation( + QM31Operation::Inverse, QM31(cm0, zero), - QM31(cm1, zero), - GpuFieldOperation::Divide, - )); - assert_eq!(gpu_div.0, cpu_div, "CM31 division failed"); - - // Test scalar operations - let cpu_scalar_mul = cm1 * m; - let gpu_scalar_mul = pollster::block_on(compute_field_operation( - QM31(cm1, zero), - QM31(cm, zero), - GpuFieldOperation::Multiply, - )); - assert_eq!( - gpu_scalar_mul.0, cpu_scalar_mul, - "CM31 scalar multiplication failed" - ); - - let cpu_scalar_add = cm1 + m; - let gpu_scalar_add = pollster::block_on(compute_field_operation( - QM31(cm1, zero), - QM31(cm, zero), - GpuFieldOperation::Add, - )); - assert_eq!( - gpu_scalar_add.0, cpu_scalar_add, - "CM31 scalar addition failed" - ); - - let cpu_scalar_sub = cm1 - m; - let gpu_scalar_sub = pollster::block_on(compute_field_operation( - QM31(cm1, zero), - QM31(cm, zero), - GpuFieldOperation::Subtract, - )); - assert_eq!( - gpu_scalar_sub.0, cpu_scalar_sub, - "CM31 scalar subtraction failed" - ); - - let cpu_scalar_div = cm1 / m; - let gpu_scalar_div = pollster::block_on(compute_field_operation( - QM31(cm1, zero), - QM31(cm, zero), - GpuFieldOperation::Divide, + QM31(zero, zero), )); - assert_eq!( - gpu_scalar_div.0, cpu_scalar_div, - "CM31 scalar division failed" - ); + assert_eq!(gpu_inv.0, cpu_inv, "CM31 inverse failed"); // Test with large numbers (near P) let large = cm31!(P - 1, P - 2); @@ -610,25 +293,25 @@ mod tests { // Test large number multiplication let 
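
The near-P cases are worth a moment: since P - 1 ≡ -1 (mod P), multiplying by `large` just negates, e.g. (P - 1) · 2 = 2P - 2 ≡ P - 2 (mod P). Operands this close to the modulus are exactly the ones that exercise the shader's reduction-after-multiplication path, which tests on small values would never stress.
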
cpu_large_mul = large * cm1; let gpu_large_mul = pollster::block_on(compute_field_operation( + QM31Operation::Multiply, large_qm, QM31(cm1, zero), - GpuFieldOperation::Multiply, )); assert_eq!( gpu_large_mul.0, cpu_large_mul, "CM31 large number multiplication failed" ); - // Test large number division - let cpu_large_div = large / cm1; - let gpu_large_div = pollster::block_on(compute_field_operation( + // Test large number inverse + let cpu_large_inv = large.inverse(); + let gpu_large_inv = pollster::block_on(compute_field_operation( + QM31Operation::Inverse, large_qm, - QM31(cm1, zero), - GpuFieldOperation::Divide, + QM31(zero, zero), )); assert_eq!( - gpu_large_div.0, cpu_large_div, - "CM31 large number division failed" + gpu_large_inv.0, cpu_large_inv, + "CM31 large number inverse failed" ); } @@ -636,112 +319,54 @@ mod tests { fn test_gpu_qm31_field_arithmetic() { let qm0 = qm31!(1, 2, 3, 4); let qm1 = qm31!(4, 5, 6, 7); - let m = M31::from(8u32); - let qm = QM31::from(m); let zero = QM31::zero(); - // Test multiplication - let cpu_mul = qm0 * qm1; - let gpu_mul = pollster::block_on(compute_field_operation( - qm0, - qm1, - GpuFieldOperation::Multiply, - )); - assert_eq!(gpu_mul, cpu_mul, "QM31 multiplication failed"); - // Test addition let cpu_add = qm0 + qm1; - let gpu_add = pollster::block_on(compute_field_operation(qm0, qm1, GpuFieldOperation::Add)); + let gpu_add = pollster::block_on(compute_field_operation(QM31Operation::Add, qm0, qm1)); assert_eq!(gpu_add, cpu_add, "QM31 addition failed"); // Test subtraction let cpu_sub = qm0 - qm1; - let gpu_sub = pollster::block_on(compute_field_operation( - qm0, - qm1, - GpuFieldOperation::Subtract, - )); + let gpu_sub = + pollster::block_on(compute_field_operation(QM31Operation::Subtract, qm0, qm1)); assert_eq!(gpu_sub, cpu_sub, "QM31 subtraction failed"); + // Test multiplication + let cpu_mul = qm0 * qm1; + let gpu_mul = + pollster::block_on(compute_field_operation(QM31Operation::Multiply, qm0, qm1)); + assert_eq!(gpu_mul, cpu_mul, "QM31 multiplication failed"); + // Test negation let cpu_neg = -qm0; - let gpu_neg = pollster::block_on(compute_field_operation( - qm0, - zero, - GpuFieldOperation::Negate, - )); + let gpu_neg = pollster::block_on(compute_field_operation(QM31Operation::Negate, qm0, zero)); assert_eq!(gpu_neg, cpu_neg, "QM31 negation failed"); - // Test division - let cpu_div = qm0 / qm1; - let gpu_div = - pollster::block_on(compute_field_operation(qm0, qm1, GpuFieldOperation::Divide)); - assert_eq!(gpu_div, cpu_div, "QM31 division failed"); - - // Test scalar operations - let cpu_scalar_mul = qm1 * m; - let gpu_scalar_mul = pollster::block_on(compute_field_operation( - qm1, - qm, - GpuFieldOperation::Multiply, - )); - assert_eq!( - cpu_scalar_mul, gpu_scalar_mul, - "QM31 scalar multiplication failed" - ); - - let cpu_scalar_add = qm1 + m; - let gpu_scalar_add = - pollster::block_on(compute_field_operation(qm1, qm, GpuFieldOperation::Add)); - assert_eq!( - cpu_scalar_add, gpu_scalar_add, - "QM31 scalar addition failed" - ); - - let cpu_scalar_sub = qm1 - m; - let gpu_scalar_sub = pollster::block_on(compute_field_operation( - qm1, - qm, - GpuFieldOperation::Subtract, - )); - assert_eq!( - cpu_scalar_sub, gpu_scalar_sub, - "QM31 scalar subtraction failed" - ); - - let cpu_scalar_div = qm1 / m; - let gpu_scalar_div = - pollster::block_on(compute_field_operation(qm1, qm, GpuFieldOperation::Divide)); - assert_eq!( - cpu_scalar_div, gpu_scalar_div, - "QM31 scalar division failed" - ); + // Test inverse + let cpu_inv = qm0.inverse(); 
+ let gpu_inv = pollster::block_on(compute_field_operation(QM31Operation::Inverse, qm0, qm1)); + assert_eq!(gpu_inv, cpu_inv, "QM31 inverse failed"); // Test with large numbers (near P) let large = qm31!(P - 1, P - 2, P - 3, P - 4); // Test large number multiplication let cpu_large_mul = large * qm1; - let gpu_large_mul = pollster::block_on(compute_field_operation( - large, - qm1, - GpuFieldOperation::Multiply, - )); + let gpu_large_mul = + pollster::block_on(compute_field_operation(QM31Operation::Multiply, large, qm1)); assert_eq!( gpu_large_mul, cpu_large_mul, "QM31 large number multiplication failed" ); - // Test large number division - let cpu_large_div = large / qm1; - let gpu_large_div = pollster::block_on(compute_field_operation( - large, - qm1, - GpuFieldOperation::Divide, - )); + // Test large number inverse + let cpu_large_inv = qm1.inverse(); + let gpu_large_inv = + pollster::block_on(compute_field_operation(QM31Operation::Inverse, qm1, zero)); assert_eq!( - gpu_large_div, cpu_large_div, - "QM31 large number division failed" + gpu_large_inv, cpu_large_inv, + "QM31 large number inverse failed" ); } } diff --git a/crates/prover/src/core/backend/gpu/qm31.wgsl b/crates/prover/src/core/backend/gpu/qm31.wgsl index 4807724ec..4d044a1bd 100644 --- a/crates/prover/src/core/backend/gpu/qm31.wgsl +++ b/crates/prover/src/core/backend/gpu/qm31.wgsl @@ -71,6 +71,10 @@ fn m31_square(x: M31, n: u32) -> M31 { return result; } +fn m31_pow5(x: M31) -> M31 { + return m31_mul(m31_square(x, 2u), x); +} + fn m31_inverse(x: M31) -> M31 { // Computes x^(2^31-2) using the same sequence as pow2147483645 // This is equivalent to x^(P-2) where P = 2^31-1 From 0c752d26f63e4b2821d4967193821e5844897cfc Mon Sep 17 00:00:00 2001 From: jason <94618524+mellowcroc@users.noreply.github.com> Date: Wed, 22 Jan 2025 17:22:51 +0900 Subject: [PATCH 5/6] feat: add compute_composition_polynomial.wgsl Run `evaluate_constraint_quotients_on_domain` on the GPU and compare results Note: need to uncomment different hardcoded lookup_elements depending on the `LOG_N_INSTANCES` value when running `prove_poseidon` (values for 9 to 15 LOG_N_INSTANCES have been added to the code). 
This is because the current code structure doesn't allow retrieving the lookup elements inside `component.rs` --- .../src/constraint_framework/component.rs | 171 ++++++- crates/prover/src/constraint_framework/mod.rs | 4 +- .../gpu/compute_composition_polynomial.rs | 478 ++++++++++++++++++ .../gpu/compute_composition_polynomial.wgsl | 336 ++++++++++++ .../prover/src/core/backend/gpu/fraction.rs | 2 +- .../prover/src/core/backend/gpu/fraction.wgsl | 13 + .../prover/src/core/backend/gpu/gpu_common.rs | 20 +- crates/prover/src/core/backend/gpu/mod.rs | 1 + crates/prover/src/core/backend/gpu/qm31.rs | 67 ++- 9 files changed, 1078 insertions(+), 14 deletions(-) create mode 100644 crates/prover/src/core/backend/gpu/compute_composition_polynomial.rs create mode 100644 crates/prover/src/core/backend/gpu/compute_composition_polynomial.wgsl diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs index 23981cc06..6664f2ddf 100644 --- a/crates/prover/src/constraint_framework/component.rs +++ b/crates/prover/src/constraint_framework/component.rs @@ -3,6 +3,8 @@ use std::collections::HashMap; use std::fmt::{self, Display, Formatter}; use std::iter::zip; use std::ops::Deref; +#[cfg(not(target_family = "wasm"))] +use std::time::Instant; use itertools::Itertools; #[cfg(feature = "parallel")] @@ -15,16 +17,19 @@ use super::preprocessed_columns::PreprocessedColumn; use super::{ EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX, }; +use crate::constraint_framework::logup::LookupElements; use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; use crate::core::air::{Component, ComponentProver, Trace}; +use crate::core::backend::gpu::compute_composition_polynomial::compute_composition_polynomial_gpu; use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords; use crate::core::backend::simd::m31::LOG_N_LANES; use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS}; use crate::core::backend::simd::SimdBackend; use crate::core::circle::CirclePoint; use crate::core::constraints::coset_vanishing; -use crate::core::fields::m31::BaseField; -use crate::core::fields::qm31::SecureField; +use crate::core::fields::cm31::CM31; +use crate::core::fields::m31::{BaseField, M31}; +use crate::core::fields::qm31::{SecureField, QM31}; use crate::core::fields::secure_column::SecureColumnByCoords; use crate::core::fields::FieldExpOps; use crate::core::pcs::{TreeSubspan, TreeVec}; @@ -329,6 +334,133 @@ impl ComponentProver for FrameworkComponen *accum.col = col; return; } + let trace_cols = trace.as_cols_ref().map_cols(|c| c.to_cpu()); + let trace_cols = trace_cols.as_cols_ref(); + + let mut lookup_elements: LookupElements<16> = LookupElements::dummy(); + // // 2^9 instances + // lookup_elements.z = QM31::from_m31_array([ + // M31::from(1620680704), + // M31::from(1901317872), + // M31::from(913853993), + // M31::from(1286799353), + // ]); + // lookup_elements.alpha = QM31::from_m31_array([ + // M31::from(2011422255), + // M31::from(1962282213), + // M31::from(69078916), + // M31::from(407074834), + // ]); + + // // 2^10 instances + // lookup_elements.z = QM31::from_m31_array([ + // M31::from(1465862614), + // M31::from(1583357442), + // M31::from(1715957657), + // M31::from(977402081), + // ]); + // lookup_elements.alpha = QM31::from_m31_array([ + // M31::from(1058568156), + // M31::from(1376697150), + // M31::from(1770783003), + // 
M31::from(1982948122), + // ]); + + // // 2^11 instances + // lookup_elements.z = QM31::from_m31_array([ + // M31::from(1465160322), + // M31::from(1969992531), + // M31::from(2003820064), + // M31::from(1543307892), + // ]); + // lookup_elements.alpha = QM31::from_m31_array([ + // M31::from(1503490752), + // M31::from(1175877240), + // M31::from(1430566545), + // M31::from(1673011189), + // ]); + + // 2^12 instances + lookup_elements.z = QM31::from_m31_array([ + M31::from(589075703), + M31::from(149359250), + M31::from(1907284710), + M31::from(729671227), + ]); + lookup_elements.alpha = QM31::from_m31_array([ + M31::from(318198925), + M31::from(1203679427), + M31::from(870875217), + M31::from(1185640677), + ]); + + // // 2^13 instances + // lookup_elements.z = QM31::from_m31_array([ + // M31::from(1628655791), + // M31::from(1055381932), + // M31::from(980792236), + // M31::from(1563574579), + // ]); + // lookup_elements.alpha = QM31::from_m31_array([ + // M31::from(758947366), + // M31::from(782855802), + // M31::from(792359994), + // M31::from(1161959256), + // ]); + + // // 2^14 instances + // lookup_elements.z = QM31::from_m31_array([ + // M31::from(668979421), + // M31::from(2097978502), + // M31::from(428317414), + // M31::from(1503540921), + // ]); + // lookup_elements.alpha = QM31::from_m31_array([ + // M31::from(962480916), + // M31::from(462545530), + // M31::from(118859601), + // M31::from(1868751663), + // ]); + + // // 2^15 instances + // lookup_elements.z = QM31::from_m31_array([ + // M31::from(1185288908), + // M31::from(1548569092), + // M31::from(792634712), + // M31::from(779398798), + // ]); + // lookup_elements.alpha = QM31::from_m31_array([ + // M31::from(138774446), + // M31::from(799972521), + // M31::from(2070047733), + // M31::from(2053058841), + // ]); + let mut cur = QM31::from(1); + lookup_elements.alpha_powers = std::array::from_fn(|_| { + let res = cur; + cur *= lookup_elements.alpha; + res + }); + + #[cfg(not(target_family = "wasm"))] + let gpu_start = Instant::now(); + + #[cfg(not(target_family = "wasm"))] + let gpu_results = pollster::block_on(compute_composition_polynomial_gpu( + trace_cols, + denom_inv.clone(), + accum.random_coeff_powers.clone(), + lookup_elements, + trace_domain.log_size(), + eval_domain.log_size(), + self.logup_sums.0, + )); + + #[cfg(not(target_family = "wasm"))] + println!("GPU time: {:?}", gpu_start.elapsed()); + + #[cfg(not(target_family = "wasm"))] + let cpu_start = Instant::now(); let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) }; @@ -378,6 +510,41 @@ impl ComponentProver for FrameworkComponen } } }); + + #[cfg(not(target_family = "wasm"))] + let cpu_end = Instant::now(); + #[cfg(not(target_family = "wasm"))] + println!("CPU time: {:?}", cpu_end - cpu_start); + + // Compare CPU and GPU results + let cpu_vec = col.to_vec(); + let gpu_vec: Vec = gpu_results + .output + .poly + .iter() + .flat_map(|inner| { + inner.iter().map(|&gpu_qm| { + QM31( + CM31::from_m31(gpu_qm.a.a.data.into(), gpu_qm.a.b.data.into()), + CM31::from_m31(gpu_qm.b.a.data.into(), gpu_qm.b.b.data.into()), + ) + }) + }) + .collect(); + + assert_eq!( + cpu_vec.len(), + gpu_vec.len(), + "CPU and GPU result lengths differ" + ); + + for (_i, (cpu_val, gpu_val)) in zip(cpu_vec.iter(), gpu_vec.iter()).enumerate() { + assert_eq!( + cpu_val, gpu_val, + "Mismatch at index {}: CPU = {:?}, GPU = {:?}", + _i, cpu_val, gpu_val + ); + } } } diff --git a/crates/prover/src/constraint_framework/mod.rs 
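
Only `z` and `alpha` have to be pinned per instance size; the `alpha_powers` array is rebuilt deterministically as the running product 1, α, α², …, matching how `LookupElements` normally populates it after drawing `z` and `alpha` from the channel. Given matching elements, the element-wise `assert_eq!` over `cpu_vec` and `gpu_vec` turns every `prove_poseidon` run into an integration test of the GPU composition-polynomial path, with the two `Instant` timers giving a rough GPU-vs-CPU comparison.
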
b/crates/prover/src/constraint_framework/mod.rs index 044f05ac8..5a744c9d1 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -1,7 +1,7 @@ /// ! This module contains helpers to express and use constraints for components. mod assert; mod component; -mod cpu_domain; +pub mod cpu_domain; pub mod expr; mod info; pub mod logup; @@ -247,7 +247,7 @@ impl<'a, F: Clone, EF: RelationEFTraitBound, R: Relation> RelationEntr macro_rules! relation { ($name:tt, $size:tt) => { #[derive(Clone, Debug, PartialEq)] - pub struct $name($crate::constraint_framework::logup::LookupElements<$size>); + pub struct $name(pub $crate::constraint_framework::logup::LookupElements<$size>); #[allow(dead_code)] impl $name { diff --git a/crates/prover/src/core/backend/gpu/compute_composition_polynomial.rs b/crates/prover/src/core/backend/gpu/compute_composition_polynomial.rs new file mode 100644 index 000000000..9287634eb --- /dev/null +++ b/crates/prover/src/core/backend/gpu/compute_composition_polynomial.rs @@ -0,0 +1,478 @@ +use std::collections::HashMap; + +use itertools::Itertools; +use wgpu::util::DeviceExt; + +use super::qm31::GpuQM31; +use crate::constraint_framework::logup::LookupElements; +use crate::constraint_framework::{ + INTERACTION_TRACE_IDX, ORIGINAL_TRACE_IDX, PREPROCESSED_TRACE_IDX, +}; +use crate::core::backend::gpu::qm31::GpuM31; +use crate::core::backend::CpuBackend; +use crate::core::fields::m31::M31; +use crate::core::fields::qm31::QM31; +use crate::core::pcs::TreeVec; +use crate::core::poly::circle::CircleEvaluation; +use crate::core::poly::BitReversedOrder; + +pub const N_ROWS: u32 = 32; +pub const N_STATE: u32 = 16; +pub const N_LOG_INSTANCES_PER_ROW: u32 = 3; +pub const N_INSTANCES_PER_ROW: u32 = 1 << N_LOG_INSTANCES_PER_ROW; +pub const N_LANES: u32 = 16; +pub const N_EXTENDED_ROWS: u32 = N_ROWS * 4; +pub const N_CONSTRAINTS: u32 = 1144; +pub const N_COLUMNS: u32 = 1264; +pub const N_INTERACTION_COLUMNS: u32 = N_INSTANCES_PER_ROW * 4; +pub const N_WORKGROUPS: u32 = N_EXTENDED_ROWS * N_LANES / THREADS_PER_WORKGROUP; +pub const THREADS_PER_WORKGROUP: u32 = 256; +pub const N_HALF_FULL_ROUNDS: u32 = 4; +pub const N_PARTIAL_ROUNDS: u32 = 14; + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct GpuExtendedColumn { + pub data: [[GpuM31; N_LANES as usize]; N_EXTENDED_ROWS as usize], + pub length: u32, +} + +impl From<&CircleEvaluation> for GpuExtendedColumn { + fn from(value: &CircleEvaluation) -> Self { + let mut data = [[GpuM31 { data: 0 }; N_LANES as usize]; N_EXTENDED_ROWS as usize]; + for (i, chunk) in value.values.chunks(N_LANES as usize).enumerate() { + let mut row = [GpuM31 { data: 0 }; N_LANES as usize]; + for (j, &val) in chunk.iter().enumerate() { + row[j] = val.into(); + } + data[i] = row; + } + GpuExtendedColumn { + data, + length: N_EXTENDED_ROWS, + } + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(C)] +pub struct GpuLookupElements { + pub z: GpuQM31, + pub alpha: GpuQM31, + pub alpha_powers: [GpuQM31; N_STATE as usize], +} + +impl From> for GpuLookupElements +where + [GpuQM31; N]: Sized, +{ + fn from(value: LookupElements) -> Self { + GpuLookupElements { + z: value.z.into(), + alpha: value.alpha.into(), + alpha_powers: value + .alpha_powers + .iter() + .map(|&x| x.into()) + .collect::>() + .try_into() + .unwrap(), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ComputeCompositionPolynomialInput { + extended_preprocessed_trace: GpuExtendedColumn, + extended_trace: [GpuExtendedColumn; N_COLUMNS as usize], + 
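+    // Sizing, for orientation: N_ROWS (32) SIMD rows of N_LANES (16) lanes give
+    // 512 trace rows; at N_INSTANCES_PER_ROW (8) that is 2^12 Poseidon instances,
+    // which is why the 2^12 lookup-element block is the one left uncommented in
+    // component.rs. The 4x blowup yields N_EXTENDED_ROWS * N_LANES = 2048
+    // evaluation rows, dispatched as 2048 / THREADS_PER_WORKGROUP = 8 workgroups.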
extended_interaction_trace: [GpuExtendedColumn; N_INTERACTION_COLUMNS as usize], + denom_inv: [GpuM31; 4], + random_coeff_powers: [GpuQM31; N_CONSTRAINTS as usize], + lookup_elements: GpuLookupElements, + trace_domain_log_size: u32, + eval_domain_log_size: u32, + total_sum: GpuQM31, +} + +#[derive(Debug, Clone, Copy)] +pub struct ComputeCompositionPolynomialOutput { + pub poly: [[GpuQM31; N_LANES as usize]; N_EXTENDED_ROWS as usize], +} + +#[derive(Debug, Clone)] +pub struct ComputationResults { + pub output: ComputeCompositionPolynomialOutput, +} + +pub trait ByteSerialize: Sized { + fn as_bytes(&self) -> &[u8] { + unsafe { + std::slice::from_raw_parts( + (self as *const Self) as *const u8, + std::mem::size_of::(), + ) + } + } + + fn from_bytes(bytes: &[u8]) -> &Self { + assert!(bytes.len() >= std::mem::size_of::()); + unsafe { &*(bytes.as_ptr() as *const Self) } + } +} + +impl ByteSerialize for GpuExtendedColumn {} +impl ByteSerialize for ComputeCompositionPolynomialOutput {} + +impl ComputeCompositionPolynomialInput { + fn as_bytes(&self) -> &[u8] { + let total_size = std::mem::size_of::(); + let mut bytes = Vec::with_capacity(total_size); + bytes.extend_from_slice(unsafe { + std::slice::from_raw_parts( + &self.extended_preprocessed_trace as *const GpuExtendedColumn as *const u8, + std::mem::size_of::(), + ) + }); + bytes.extend_from_slice(unsafe { + std::slice::from_raw_parts( + &self.extended_trace as *const GpuExtendedColumn as *const u8, + N_COLUMNS as usize * std::mem::size_of::(), + ) + }); + + bytes.extend_from_slice(unsafe { + std::slice::from_raw_parts( + &self.extended_interaction_trace as *const GpuExtendedColumn as *const u8, + N_INTERACTION_COLUMNS as usize * std::mem::size_of::(), + ) + }); + + bytes.extend_from_slice(unsafe { + std::slice::from_raw_parts( + &self.denom_inv as *const GpuM31 as *const u8, + 4 * std::mem::size_of::(), + ) + }); + + bytes.extend_from_slice(unsafe { + std::slice::from_raw_parts( + &self.random_coeff_powers as *const GpuQM31 as *const u8, + N_CONSTRAINTS as usize * std::mem::size_of::(), + ) + }); + + bytes.extend_from_slice(unsafe { + std::slice::from_raw_parts( + &self.lookup_elements as *const GpuLookupElements as *const u8, + std::mem::size_of::(), + ) + }); + + bytes.extend_from_slice(unsafe { + std::slice::from_raw_parts( + &self.trace_domain_log_size as *const u32 as *const u8, + std::mem::size_of::(), + ) + }); + + bytes.extend_from_slice(unsafe { + std::slice::from_raw_parts( + &self.eval_domain_log_size as *const u32 as *const u8, + std::mem::size_of::(), + ) + }); + + bytes.extend_from_slice(unsafe { + std::slice::from_raw_parts( + &self.total_sum as *const GpuQM31 as *const u8, + std::mem::size_of::(), + ) + }); + + Box::leak(bytes.into_boxed_slice()) + } +} + +impl ComputeCompositionPolynomialOutput { + fn from_bytes(bytes: &[u8]) -> Self { + unsafe { *(bytes.as_ptr() as *const Self) } + } +} + +pub struct WgpuInstance { + pub instance: wgpu::Instance, + pub adapter: wgpu::Adapter, + pub device: wgpu::Device, + pub queue: wgpu::Queue, + pub staging_buffer: wgpu::Buffer, + pub encoder: wgpu::CommandEncoder, +} + +async fn init( + trace: TreeVec>>, + denom_inv: Vec, + random_coeff_powers: Vec, + lookup_elements: LookupElements, + trace_domain_log_size: u32, + eval_domain_log_size: u32, + total_sum: QM31, +) -> WgpuInstance { + let instance = wgpu::Instance::default(); + let adapter = instance + .request_adapter(&wgpu::RequestAdapterOptions { + power_preference: wgpu::PowerPreference::HighPerformance, + compatible_surface: None, + 
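
One deliberate wart in `as_bytes` above: the byte image is assembled field by field into a `Vec` and then `Box::leak`ed to obtain a `&'static [u8]`, so every call leaks its staging copy. That is harmless for a one-shot upload but worth knowing before calling it in a loop. The field-by-field copy also has to agree byte-for-byte with the WGSL struct layout; that works out here because every member is a sequence of 4-byte words, leaving no padding for Rust and the shader to disagree over.
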
force_fallback_adapter: false, + }) + .await + .unwrap(); + let (device, queue) = adapter + .request_device( + &wgpu::DeviceDescriptor { + label: Some("Device"), + required_features: wgpu::Features::SHADER_INT64, + required_limits: wgpu::Limits::default(), + memory_hints: wgpu::MemoryHints::Performance, + }, + None, + ) + .await + .unwrap(); + + let input_data = create_gpu_input::( + trace, + denom_inv, + random_coeff_powers, + lookup_elements, + trace_domain_log_size, + eval_domain_log_size, + total_sum, + ); + + // Create buffers + let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some("Input Buffer"), + contents: input_data.as_bytes(), + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST, + }); + + let buffer_size = std::mem::size_of::(); + let output_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Output Buffer"), + size: buffer_size as wgpu::BufferAddress, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + // Load shader + let qm31_shader = include_str!("qm31.wgsl"); + let fraction_shader = include_str!("fraction.wgsl"); + let utils_shader = include_str!("utils.wgsl"); + let composition_shader = include_str!("compute_composition_polynomial.wgsl"); + let combined_shader = format!( + "{}\n + {}\n + {}\n + {}", + qm31_shader, fraction_shader, utils_shader, composition_shader + ); + let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("Compute Composition Polynomial Shader"), + source: wgpu::ShaderSource::Wgsl(combined_shader.into()), + }); + + // Bind group layout + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + entries: &[ + // Binding 0: Input buffer + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + // Binding 1: Output buffer + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ], + label: Some("Compute Composition Polynomial Bind Group Layout"), + }); + + // Create bind group + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + entries: &[ + wgpu::BindGroupEntry { + binding: 0, + resource: input_buffer.as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: output_buffer.as_entire_binding(), + }, + ], + label: Some("Compute Composition Polynomial Bind Group"), + }); + + // Pipeline layout + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + label: Some("Compute Composition Polynomial Pipeline Layout"), + }); + + // Compute pipeline + let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("Compute Composition Polynomial Compute Pipeline"), + layout: Some(&pipeline_layout), + module: &shader_module, + entry_point: Some("compute_composition_polynomial"), + cache: None, + compilation_options: wgpu::PipelineCompilationOptions { + constants: &HashMap::from([]), + zero_initialize_workgroup_memory: true, + }, + }); + + // Create 
encoder + let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Compute Composition Polynomial Command Encoder"), + }); + + // Dispatch the compute shader + { + let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("Compute Composition Polynomial Compute Pass"), + timestamp_writes: None, + }); + compute_pass.set_pipeline(&compute_pipeline); + compute_pass.set_bind_group(0, &bind_group, &[]); + compute_pass.dispatch_workgroups(N_WORKGROUPS, 1, 1); + } + + // Copy output to staging buffer for read access + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Staging Buffer"), + size: buffer_size as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, staging_buffer.size()); + + WgpuInstance { + instance, + adapter, + device, + queue, + staging_buffer, + encoder, + } +} + +fn create_gpu_input( + trace: TreeVec>>, + denom_inv: Vec, + random_coeff_powers: Vec, + lookup_elements: LookupElements, + trace_domain_log_size: u32, + eval_domain_log_size: u32, + total_sum: QM31, +) -> ComputeCompositionPolynomialInput { + let extended_preprocessed_trace_gpu = + GpuExtendedColumn::from(trace.0[PREPROCESSED_TRACE_IDX][0]); + + let extended_trace_gpu: [GpuExtendedColumn; N_COLUMNS as usize] = trace.0[ORIGINAL_TRACE_IDX] + .iter() + .map(|eval| GpuExtendedColumn::from(*eval)) + .collect_vec() + .try_into() + .expect("Wrong length"); + + let extended_interaction_trace_gpu: [GpuExtendedColumn; N_INTERACTION_COLUMNS as usize] = trace + .0[INTERACTION_TRACE_IDX] + .iter() + .map(|eval| GpuExtendedColumn::from(*eval)) + .collect_vec() + .try_into() + .expect("Wrong length"); + + let denom_inv_gpu: [GpuM31; 4] = denom_inv + .into_iter() + .map(GpuM31::from) + .collect::>() + .try_into() + .expect("Wrong length"); + + let random_coeff_powers_gpu: [GpuQM31; N_CONSTRAINTS as usize] = random_coeff_powers + .into_iter() + .map(GpuQM31::from) + .collect::>() + .try_into() + .expect("Wrong length"); + + let lookup_elements_gpu = GpuLookupElements::from(lookup_elements); + + ComputeCompositionPolynomialInput { + extended_preprocessed_trace: extended_preprocessed_trace_gpu, + extended_trace: extended_trace_gpu, + extended_interaction_trace: extended_interaction_trace_gpu, + denom_inv: denom_inv_gpu, + random_coeff_powers: random_coeff_powers_gpu, + lookup_elements: lookup_elements_gpu, + trace_domain_log_size, + eval_domain_log_size, + total_sum: total_sum.into(), + } +} + +pub async fn compute_composition_polynomial_gpu<'a, const N: usize>( + trace: TreeVec>>, + denom_inv: Vec, + random_coeff_powers: Vec, + lookup_elements: LookupElements, + trace_domain_log_size: u32, + eval_domain_log_size: u32, + total_sum: QM31, +) -> ComputationResults { + let instance = init( + trace, + denom_inv, + random_coeff_powers, + lookup_elements, + trace_domain_log_size, + eval_domain_log_size, + total_sum, + ) + .await; + instance.queue.submit(Some(instance.encoder.finish())); + let output_slice = instance.staging_buffer.slice(..); + let (sender, receiver) = flume::bounded(1); + output_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap()); + instance + .device + .poll(wgpu::Maintain::wait()) + .panic_on_timeout(); + let result = async { + receiver.recv_async().await.unwrap().unwrap(); + let data = output_slice.get_mapped_range(); + let output = 
ComputeCompositionPolynomialOutput::from_bytes(&data); + drop(data); + instance.staging_buffer.unmap(); + output + }; + + let output = result.await; + ComputationResults { output } +} diff --git a/crates/prover/src/core/backend/gpu/compute_composition_polynomial.wgsl b/crates/prover/src/core/backend/gpu/compute_composition_polynomial.wgsl new file mode 100644 index 000000000..8d53bd101 --- /dev/null +++ b/crates/prover/src/core/backend/gpu/compute_composition_polynomial.wgsl @@ -0,0 +1,336 @@ +// Note: depends on qm31.wgsl, fraction.wgsl, utils.wgsl +// Define constants +const N_ROWS: u32 = 32; +const N_EXTENDED_ROWS: u32 = N_ROWS * 4; +const N_STATE: u32 = 16; +const N_INSTANCES_PER_ROW: u32 = 8; +const N_COLUMNS: u32 = N_INSTANCES_PER_ROW * N_COLUMNS_PER_REP; +const N_INTERACTION_COLUMNS: u32 = N_INSTANCES_PER_ROW * 4; +const N_HALF_FULL_ROUNDS: u32 = 4; +const FULL_ROUNDS: u32 = 2u * N_HALF_FULL_ROUNDS; +const N_PARTIAL_ROUNDS: u32 = 14; +const N_LANES: u32 = 16; +const N_COLUMNS_PER_REP: u32 = N_STATE * (1 + FULL_ROUNDS) + N_PARTIAL_ROUNDS; +const LOG_N_LANES: u32 = 4; +const N_WORKGROUPS: u32 = N_EXTENDED_ROWS * N_LANES / THREADS_PER_WORKGROUP; +const THREADS_PER_WORKGROUP: u32 = 256; +const MAX_ARRAY_LOG_SIZE: u32 = 20; +const MAX_ARRAY_SIZE: u32 = 1u << MAX_ARRAY_LOG_SIZE; +const N_CONSTRAINTS: u32 = 1144; +const R: CM31 = CM31(M31(2u), M31(1u)); +const ONE = QM31(CM31(M31(1u), M31(0u)), CM31(M31(0u), M31(0u))); + +// Initialize EXTERNAL_ROUND_CONSTS with explicit values +const EXTERNAL_ROUND_CONSTS: array, FULL_ROUNDS> = array, FULL_ROUNDS>( + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), + array(1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u, 1234u), +); + +// Initialize INTERNAL_ROUND_CONSTS with explicit values +const INTERNAL_ROUND_CONSTS: array = array( + 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234, 1234 +); + +struct BaseColumn { + data: array, N_EXTENDED_ROWS>, + length: u32, +} + +struct LookupElements { + z: QM31, + alpha: QM31, + alpha_powers: array, +} + +struct ComputeCompositionPolynomialInput { + extended_preprocessed_trace: BaseColumn, + extended_trace: array, + extended_interaction_trace: array, + denom_inv: array, + random_coeff_powers: array, + lookup_elements: LookupElements, + trace_domain_log_size: u32, + eval_domain_log_size: u32, + total_sum: QM31, +} + +struct ComputeCompositionPolynomialOutput { + poly: array, N_EXTENDED_ROWS>, +} + +struct RelationEntry { + multiplicity: QM31, + values: array, +} + +@group(0) @binding(0) +var input: ComputeCompositionPolynomialInput; + +@group(0) @binding(1) +var output: ComputeCompositionPolynomialOutput; + +var constraint_index: u32 = 0u; + +var 
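+// Note: EXTERNAL_ROUND_CONSTS / INTERNAL_ROUND_CONSTS above are filled with the
+// 1234 placeholder value, apparently mirroring the placeholder round constants in
+// the Rust Poseidon example rather than real Poseidon2 constants; the CPU/GPU
+// comparison only holds because both sides agree on the same placeholders.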
prev_col_cumsum: QM31 = QM31(CM31(M31(0u), M31(0u)), CM31(M31(0u), M31(0u))); + +var cur_frac: Fraction = ZERO_FRACTION; + +var is_first: M31 = M31(0u); + +var is_finalized: bool = false; + +@compute @workgroup_size(THREADS_PER_WORKGROUP) +fn compute_composition_polynomial( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(local_invocation_id) local_invocation_id: vec3, + @builtin(global_invocation_id) global_invocation_id: vec3, + @builtin(local_invocation_index) local_invocation_index: u32, + @builtin(num_workgroups) num_workgroups: vec3, +) { + let workgroup_index = + workgroup_id.x + + workgroup_id.y * num_workgroups.x + + workgroup_id.z * num_workgroups.x * num_workgroups.y; + + let global_invocation_index = workgroup_index * THREADS_PER_WORKGROUP + local_invocation_index; // [0, 512) + + var vec_index = global_invocation_index / N_LANES; + var inner_vec_index = global_invocation_index % N_LANES; + var col_index = 0u; + + for (var rep_i = 0u; rep_i < N_INSTANCES_PER_ROW; rep_i++) { + var state: array; + for (var j = 0u; j < N_STATE; j++) { + state[j] = next_trace_mask(col_index, vec_index, inner_vec_index); + col_index += 1u; + } + var initial_state = state; + + // 4 full rounds + for (var i = 0u; i < N_HALF_FULL_ROUNDS; i++) { + for (var j = 0u; j < N_STATE; j++) { + state[j] = m31_add(state[j], M31(EXTERNAL_ROUND_CONSTS[i][j])); + } + state = apply_external_round_matrix(state); + for (var j = 0u; j < N_STATE; j++) { + state[j] = m31_pow5(state[j]); + } + for (var j = 0u; j < N_STATE; j++) { + var m_1 = next_trace_mask(col_index, vec_index, inner_vec_index); + let constraint = m31_sub(state[j], m_1); + add_constraint(constraint, vec_index, inner_vec_index); + + state[j] = m_1; + col_index += 1u; + } + } + // Partial rounds + for (var i = 0u; i < N_PARTIAL_ROUNDS; i++) { + state[0] = m31_add(state[0], M31(INTERNAL_ROUND_CONSTS[i])); + state = apply_internal_round_matrix(state); + state[0] = m31_pow5(state[0]); + var m_1 = next_trace_mask(col_index, vec_index, inner_vec_index); + let constraint = m31_sub(state[0], m_1); + add_constraint(constraint, vec_index, inner_vec_index); + + state[0] = m_1; + col_index += 1u; + } + // 4 full rounds + for (var i = 0u; i < N_HALF_FULL_ROUNDS; i++) { + for (var j = 0u; j < N_STATE; j++) { + state[j] = m31_add(state[j], M31(EXTERNAL_ROUND_CONSTS[i + N_HALF_FULL_ROUNDS][j])); + } + state = apply_external_round_matrix(state); + for (var j = 0u; j < N_STATE; j++) { + state[j] = m31_pow5(state[j]); + } + for (var j = 0u; j < N_STATE; j++) { + var m_1 = next_trace_mask(col_index, vec_index, inner_vec_index); + let constraint = m31_sub(state[j], m_1); + add_constraint(constraint, vec_index, inner_vec_index); + state[j] = m_1; + col_index += 1u; + } + } + + // Store the final relation constraints + let relation_constraint_1 = qm31_mul(ONE, QM31(CM31(state[0], M31(0u)), CM31(M31(0u), M31(0u)))); + let relation_constraint_2 = qm31_mul(qm31_neg(ONE), QM31(CM31(state[0], M31(0u)), CM31(M31(0u), M31(0u)))); + + add_to_relation(array( + RelationEntry(ONE, initial_state), + RelationEntry(qm31_neg(ONE), state) + ), vec_index, inner_vec_index, rep_i); + } + finalize_logup(vec_index, inner_vec_index); + + let row = vec_index * N_STATE + inner_vec_index; + let denom_inv = input.denom_inv[row >> input.trace_domain_log_size]; + output.poly[vec_index][inner_vec_index] = qm31_mul(output.poly[vec_index][inner_vec_index], QM31(CM31(denom_inv, M31(0u)), CM31(M31(0u), M31(0u)))); +} + +fn add_constraint(constraint: M31, vec_index: u32, inner_vec_index: u32) { + 
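+    // Accumulation scheme: each invocation owns one (vec_index, inner_vec_index)
+    // cell and folds every constraint into
+    //   poly[cell] = sum_k random_coeff_powers[k] * constraint_k(cell),
+    // which main() finally multiplies by denom_inv, i.e. divides by the vanishing
+    // polynomial of the trace domain evaluated on the 4x extended domain.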
add_constraint_qm31(QM31(CM31(constraint, M31(0u)), CM31(M31(0u), M31(0u))), vec_index, inner_vec_index); +} + +fn add_constraint_qm31(constraint: QM31, vec_index: u32, inner_vec_index: u32) { + var new_add = qm31_mul(constraint, input.random_coeff_powers[constraint_index]); + output.poly[vec_index][inner_vec_index] = qm31_add(output.poly[vec_index][inner_vec_index], new_add); + constraint_index += 1u; +} + +fn add_to_relation(entries: array, vec_index: u32, inner_vec_index: u32, rep_i: u32) { + var frac_sum = Fraction(QM31(CM31(M31(0u), M31(0u)), CM31(M31(0u), M31(0u))), QM31(CM31(M31(1u), M31(0u)), CM31(M31(0u), M31(0u)))); + for (var i = 0u; i < 2; i++) { + var combined_value = QM31(CM31(M31(0u), M31(0u)), CM31(M31(0u), M31(0u))); + for (var j = 0u; j < N_STATE; j++) { + let value = QM31(CM31(entries[i].values[j], M31(0u)), CM31(M31(0u), M31(0u))); + combined_value = qm31_add(combined_value, qm31_mul(input.lookup_elements.alpha_powers[j], value)); + } + + combined_value = qm31_sub(combined_value, input.lookup_elements.z); + + frac_sum = fraction_add(frac_sum, Fraction(entries[i].multiplicity, combined_value)); + } + write_logup_frac(frac_sum, vec_index, inner_vec_index, rep_i); +} + +fn write_logup_frac(frac: Fraction, vec_index: u32, inner_vec_index: u32, rep_i: u32) { + if (!fraction_eq(cur_frac, ZERO_FRACTION)) { + var interaction_col_index = (rep_i - 1u) * 4; // TODO: Improve this. + var cur_cumsum = next_interaction_trace_mask(interaction_col_index, vec_index, inner_vec_index); + var diff = qm31_sub(cur_cumsum, prev_col_cumsum); + prev_col_cumsum = cur_cumsum; + var constraint = qm31_sub(qm31_mul(diff, cur_frac.denominator), cur_frac.numerator); + add_constraint_qm31(constraint, vec_index, inner_vec_index); + } else { + is_first = input.extended_preprocessed_trace.data[vec_index][inner_vec_index]; + is_finalized = false; + } + cur_frac = frac; +} + +fn finalize_logup(vec_index: u32, inner_vec_index: u32) { + if (is_finalized) { + return; + } + // TODO: add support for when claimed_sum is not None. 
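+    // What this checks: the last interaction column stores the row-wise cumulative
+    // logup sum, so with prev_row_cumsum fetched at bit-reversed offset -1 and
+    // corrected by is_first * total_sum on the wrap-around row, the constraint is
+    //   (cur_cumsum - fixed_prev_row_cumsum - prev_col_cumsum) * denom - num = 0.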
+    var last_interaction_col_index = (N_INSTANCES_PER_ROW - 1u) * 4u;
+
+    var cur_cumsum = next_interaction_trace_mask(last_interaction_col_index, vec_index, inner_vec_index);
+
+    var prev_row_cumsum = next_interaction_trace_mask_offset(last_interaction_col_index, vec_index, inner_vec_index, -1);
+
+    var total_sum_mod = qm31_mul(QM31(CM31(is_first, M31(0u)), CM31(M31(0u), M31(0u))), input.total_sum);
+
+    var fixed_prev_row_cumsum = qm31_sub(prev_row_cumsum, total_sum_mod);
+
+    var diff = qm31_sub(qm31_sub(cur_cumsum, fixed_prev_row_cumsum), prev_col_cumsum);
+    var constraint = qm31_sub(qm31_mul(diff, cur_frac.denominator), cur_frac.numerator);
+    add_constraint_qm31(constraint, vec_index, inner_vec_index);
+    is_finalized = true;
+}
+
+fn next_trace_mask(col_index: u32, vec_index: u32, inner_vec_index: u32) -> M31 {
+    return input.extended_trace[col_index].data[vec_index][inner_vec_index];
+}
+
+fn next_interaction_trace_mask(col_index: u32, vec_index: u32, inner_vec_index: u32) -> QM31 {
+    // Read one QM31 whose four M31 coordinates live in four consecutive
+    // interaction trace columns.
+    return QM31(
+        CM31(
+            input.extended_interaction_trace[col_index].data[vec_index][inner_vec_index],
+            input.extended_interaction_trace[col_index + 1].data[vec_index][inner_vec_index]
+        ),
+        CM31(
+            input.extended_interaction_trace[col_index + 2].data[vec_index][inner_vec_index],
+            input.extended_interaction_trace[col_index + 3].data[vec_index][inner_vec_index]
+        )
+    );
+}
+
+fn next_interaction_trace_mask_offset(col_index: u32, vec_index: u32, inner_vec_index: u32, offset: i32) -> QM31 {
+    var curr_row = vec_index * N_LANES + inner_vec_index;
+
+    var row = offset_bit_reversed_circle_domain_index(curr_row, input.trace_domain_log_size, input.eval_domain_log_size, offset);
+
+    var new_vec_index = row / N_LANES;
+    var new_inner_vec_index = row % N_LANES;
+    return QM31(
+        CM31(
+            input.extended_interaction_trace[col_index].data[new_vec_index][new_inner_vec_index],
+            input.extended_interaction_trace[col_index + 1].data[new_vec_index][new_inner_vec_index]
+        ),
+        CM31(
+            input.extended_interaction_trace[col_index + 2].data[new_vec_index][new_inner_vec_index],
+            input.extended_interaction_trace[col_index + 3].data[new_vec_index][new_inner_vec_index]
+        )
+    );
+}
+
+/// Applies the external round matrix.
+/// See 5.1 and Appendix B.
+fn apply_external_round_matrix(state: array<M31, N_STATE>) -> array<M31, N_STATE> {
+    // Applies circ(2M4, M4, M4, M4).
+    var modified_state = state;
+    for (var i = 0u; i < 4u; i++) {
+        let partial_state = array<M31, 4>(
+            state[4 * i],
+            state[4 * i + 1],
+            state[4 * i + 2],
+            state[4 * i + 3],
+        );
+        let modified_partial_state = apply_m4(partial_state);
+        modified_state[4 * i] = modified_partial_state[0];
+        modified_state[4 * i + 1] = modified_partial_state[1];
+        modified_state[4 * i + 2] = modified_partial_state[2];
+        modified_state[4 * i + 3] = modified_partial_state[3];
+    }
+    for (var j = 0u; j < 4u; j++) {
+        let s = m31_add(m31_add(modified_state[j], modified_state[j + 4]), m31_add(modified_state[j + 8], modified_state[j + 12]));
+        for (var i = 0u; i < 4u; i++) {
+            modified_state[4 * i + j] = m31_add(modified_state[4 * i + j], s);
+        }
+    }
+    return modified_state;
+}
+
+// Applies the internal round matrix.
+// mu_i = 2^{i+1} + 1.
+// See 5.2.
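+// Equivalently, result[i] = 2^(i+1) * state[i] + sum(state): the matrix is
+// diag(2^1, ..., 2^16) plus the all-ones matrix, which is where the
+// mu_i = 2^(i+1) + 1 diagonal entries come from.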
+fn apply_internal_round_matrix(state: array<M31, N_STATE>) -> array<M31, N_STATE> {
+    var sum = state[0];
+    for (var i = 1u; i < N_STATE; i++) {
+        sum = m31_add(sum, state[i]);
+    }
+
+    var result = array<M31, N_STATE>();
+    for (var i = 0u; i < N_STATE; i++) {
+        let factor = partial_reduce(1u << (i + 1));
+        result[i] = m31_add(m31_mul(M31(factor), state[i]), sum);
+    }
+
+    return result;
+}
+
+/// Applies the M4 MDS matrix described in 5.1.
+fn apply_m4(x: array<M31, 4>) -> array<M31, 4> {
+    let t0 = m31_add(x[0], x[1]);
+    let t02 = m31_add(t0, t0);
+    let t1 = m31_add(x[2], x[3]);
+    let t12 = m31_add(t1, t1);
+    let t2 = m31_add(m31_add(x[1], x[1]), t1);
+    let t3 = m31_add(m31_add(x[3], x[3]), t0);
+    let t4 = m31_add(m31_add(t12, t12), t3);
+    let t5 = m31_add(m31_add(t02, t02), t2);
+    let t6 = m31_add(t3, t5);
+    let t7 = m31_add(t2, t4);
+    return array<M31, 4>(t6, t5, t7, t4);
+}
diff --git a/crates/prover/src/core/backend/gpu/fraction.rs b/crates/prover/src/core/backend/gpu/fraction.rs
index f22c0fc32..94725cd04 100644
--- a/crates/prover/src/core/backend/gpu/fraction.rs
+++ b/crates/prover/src/core/backend/gpu/fraction.rs
@@ -1,7 +1,7 @@
 use std::borrow::Cow;
 
-use super::compute_composition_polynomial::GpuQM31;
 use super::gpu_common::{ByteSerialize, GpuComputeInstance, GpuOperation};
+use super::qm31::GpuQM31;
 use crate::core::fields::qm31::QM31;
 use crate::core::lookups::utils::Fraction;
 
diff --git a/crates/prover/src/core/backend/gpu/fraction.wgsl b/crates/prover/src/core/backend/gpu/fraction.wgsl
index 79f54455a..19037a534 100644
--- a/crates/prover/src/core/backend/gpu/fraction.wgsl
+++ b/crates/prover/src/core/backend/gpu/fraction.wgsl
@@ -1,6 +1,8 @@
 // This shader contains implementations for fraction operations.
 // It is stateless and can be used as a library in other shaders.
 
+const ZERO_FRACTION: Fraction = Fraction(QM31(CM31(M31(0u), M31(0u)), CM31(M31(0u), M31(0u))), QM31(CM31(M31(1u), M31(0u)), CM31(M31(0u), M31(0u))));
+
 struct Fraction {
     numerator: QM31,
     denominator: QM31,
@@ -15,3 +17,14 @@ fn fraction_add(a: Fraction, b: Fraction) -> Fraction {
     let denominator = qm31_mul(a.denominator, b.denominator);
     return Fraction(numerator, denominator);
 }
+
+fn fraction_eq(a: Fraction, b: Fraction) -> bool {
+    return a.numerator.a.a.data == b.numerator.a.a.data
+        && a.numerator.a.b.data == b.numerator.a.b.data
+        && a.numerator.b.a.data == b.numerator.b.a.data
+        && a.numerator.b.b.data == b.numerator.b.b.data
+        && a.denominator.a.a.data == b.denominator.a.a.data
+        && a.denominator.a.b.data == b.denominator.a.b.data
+        && a.denominator.b.a.data == b.denominator.b.a.data
+        && a.denominator.b.b.data == b.denominator.b.b.data;
+}
diff --git a/crates/prover/src/core/backend/gpu/gpu_common.rs b/crates/prover/src/core/backend/gpu/gpu_common.rs
index 66691a8b2..69df3680c 100644
--- a/crates/prover/src/core/backend/gpu/gpu_common.rs
+++ b/crates/prover/src/core/backend/gpu/gpu_common.rs
@@ -13,9 +13,9 @@ pub trait ByteSerialize: Sized {
         }
     }
 
-    fn from_bytes(bytes: &[u8]) -> &Self {
+    fn from_bytes(bytes: &[u8]) -> Self {
         assert!(bytes.len() >= std::mem::size_of::<Self>());
-        unsafe { &*(bytes.as_ptr() as *const Self) }
+        unsafe { std::ptr::read_unaligned(bytes.as_ptr() as *const Self) }
     }
 }
 
@@ -162,7 +162,7 @@ impl GpuComputeInstance {
         (pipeline, bind_group)
     }
 
-    pub async fn run_computation<T: ByteSerialize + Copy>(
+    pub async fn run_computation<T: ByteSerialize>(
         &self,
         pipeline: &wgpu::ComputePipeline,
         bind_group: &wgpu::BindGroup,
@@ -206,11 +206,15 @@ impl GpuComputeInstance {
 
         self.device.poll(wgpu::Maintain::wait());
 
-        rx.recv_async().await.unwrap().unwrap();
-        let data = buffer_slice.get_mapped_range();
-        let result = *T::from_bytes(&data);
-        drop(data);
-        self.staging_buffer.unmap();
+        let result = async {
+            rx.recv_async().await.unwrap().unwrap();
+            let data = buffer_slice.get_mapped_range();
+            let result = T::from_bytes(&data);
+            drop(data);
+            self.staging_buffer.unmap();
+            result
+        };
+        let result = result.await;
 
         result
     }
diff --git a/crates/prover/src/core/backend/gpu/mod.rs b/crates/prover/src/core/backend/gpu/mod.rs
index 20a73b47b..f0a5b5732 100644
--- a/crates/prover/src/core/backend/gpu/mod.rs
+++ b/crates/prover/src/core/backend/gpu/mod.rs
@@ -1,3 +1,4 @@
+pub mod compute_composition_polynomial;
 pub mod fraction;
 pub mod gen_trace;
 pub mod gen_trace_interpolate_columns;
diff --git a/crates/prover/src/core/backend/gpu/qm31.rs b/crates/prover/src/core/backend/gpu/qm31.rs
index aa0ce536a..3585b343e 100644
--- a/crates/prover/src/core/backend/gpu/qm31.rs
+++ b/crates/prover/src/core/backend/gpu/qm31.rs
@@ -1,9 +1,74 @@
 use std::borrow::Cow;
 
-use super::compute_composition_polynomial::GpuQM31;
 use super::gpu_common::{ByteSerialize, GpuComputeInstance, GpuOperation};
+use crate::core::fields::cm31::CM31;
+use crate::core::fields::m31::M31;
 use crate::core::fields::qm31::QM31;
 
+#[derive(Debug, Clone, Copy, PartialEq)]
+#[repr(C)]
+pub struct GpuM31 {
+    pub data: u32,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+#[repr(C)]
+pub struct GpuCM31 {
+    pub a: GpuM31,
+    pub b: GpuM31,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+#[repr(C)]
+pub struct GpuQM31 {
+    pub a: GpuCM31,
+    pub b: GpuCM31,
+}
+
+impl From<QM31> for GpuQM31 {
+    fn from(value: QM31) -> Self {
+        GpuQM31 {
+            a: GpuCM31 {
+                a: GpuM31 {
+                    data: value.0 .0.into(),
+                },
+                b: GpuM31 {
+                    data: value.0 .1.into(),
+                },
+            },
+            b: GpuCM31 {
+                a: GpuM31 {
+                    data: value.1 .0.into(),
+                },
+                b: GpuM31 {
+                    data: value.1 .1.into(),
+                },
+            },
+        }
+    }
+}
+
+impl From<GpuQM31> for QM31 {
+    fn from(value: GpuQM31) -> Self {
+        QM31(
+            CM31::from_m31(value.a.a.data.into(), value.a.b.data.into()),
+            CM31::from_m31(value.b.a.data.into(), value.b.b.data.into()),
+        )
+    }
+}
+
+impl From<M31> for GpuM31 {
+    fn from(value: M31) -> Self {
+        GpuM31 { data: value.into() }
+    }
+}
+
+impl From<GpuM31> for M31 {
+    fn from(value: GpuM31) -> Self {
+        M31::from(value.data)
+    }
+}
+
 #[repr(C)]
 #[derive(Copy, Clone, Debug, PartialEq)]
 pub struct ComputeInput {

From 769ea40a0e433efc5ba96725ad7ac2130f6d796d Mon Sep 17 00:00:00 2001
From: jason <94618524+mellowcroc@users.noreply.github.com>
Date: Thu, 23 Jan 2025 18:28:30 +0900
Subject: [PATCH 6/6] feat: support evaluating constraints in wasm

---
 .cargo/config.toml                            |   5 +
 .../src/constraint_framework/component.rs     |  54 ++++--
 .../gpu/compute_composition_polynomial.rs     |  28 ++-
 .../gpu/gen_trace_interpolate_columns.rs      |  13 +-
 crates/prover/src/core/backend/gpu/mod.rs     |   1 +
 crates/prover/src/core/backend/gpu/prove.rs   | 165 ++++++++++++++++++
 crates/prover/src/core/poly/circle/mod.rs     |   2 +-
 crates/prover/src/examples/poseidon/mod.rs    |  79 ++++++++-
 8 files changed, 311 insertions(+), 36 deletions(-)
 create mode 100644 crates/prover/src/core/backend/gpu/prove.rs

diff --git a/.cargo/config.toml b/.cargo/config.toml
index f26d8449a..4ec64d6d0 100644
--- a/.cargo/config.toml
+++ b/.cargo/config.toml
@@ -1,3 +1,8 @@
 [build]
 # Configuration for these lints should be placed in `.clippy.toml` at the crate root.
 rustflags = ["-Dwarnings"]
+
+[target.wasm32-unknown-unknown]
+rustflags = [
+    "-C", "link-args=-z stack-size=268435456",
+]
\ No newline at end of file
diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs
index 6664f2ddf..de62864df 100644
--- a/crates/prover/src/constraint_framework/component.rs
+++ b/crates/prover/src/constraint_framework/component.rs
@@ -17,9 +17,9 @@ use super::preprocessed_columns::PreprocessedColumn;
 use super::{
     EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX,
 };
-use crate::constraint_framework::logup::LookupElements;
 use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
 use crate::core::air::{Component, ComponentProver, Trace};
+#[cfg(not(target_family = "wasm"))]
 use crate::core::backend::gpu::compute_composition_polynomial::compute_composition_polynomial_gpu;
 use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
 use crate::core::backend::simd::m31::LOG_N_LANES;
@@ -27,6 +27,7 @@ use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS};
 use crate::core::backend::simd::SimdBackend;
 use crate::core::circle::CirclePoint;
 use crate::core::constraints::coset_vanishing;
+#[cfg(not(target_family = "wasm"))]
 use crate::core::fields::cm31::CM31;
 use crate::core::fields::m31::{BaseField, M31};
 use crate::core::fields::qm31::{SecureField, QM31};
@@ -36,6 +37,7 @@ use crate::core::pcs::{TreeSubspan, TreeVec};
 use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
 use crate::core::poly::BitReversedOrder;
 use crate::core::{utils, ColumnVec};
+use crate::examples::poseidon::PoseidonElements;
 
 const CHUNK_SIZE: usize = 1;
 
@@ -177,6 +179,10 @@ impl<E: FrameworkEval> FrameworkComponent<E> {
     pub fn trace_locations(&self) -> &[TreeSubspan] {
         &self.trace_locations
     }
+
+    pub fn eval(&self) -> &E {
+        &self.eval
+    }
 }
 
 impl<E: FrameworkEval> Component for FrameworkComponent<E> {
@@ -246,6 +252,7 @@ impl<E: FrameworkEval> Component for FrameworkComponent<E> {
     }
 }
 
+#[allow(unused_variables)]
 impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponent<E> {
     fn evaluate_constraint_quotients_on_domain(
         &self,
@@ -337,15 +344,15 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponent<E> {
         let trace_cols = trace.as_cols_ref().map_cols(|c| c.to_cpu());
         let trace_cols = trace_cols.as_cols_ref();
 
-        let mut lookup_elements: LookupElements<16> = LookupElements::dummy();
+        let mut lookup_elements: PoseidonElements = PoseidonElements::dummy();
         // // 2^9 instances
-        // lookup_elements.z = QM31::from_m31_array([
+        // lookup_elements.0.z = QM31::from_m31_array([
         //     M31::from(1620680704),
         //     M31::from(1901317872),
         //     M31::from(913853993),
         //     M31::from(1286799353),
         // ]);
-        // lookup_elements.alpha = QM31::from_m31_array([
+        // lookup_elements.0.alpha = QM31::from_m31_array([
         //     M31::from(2011422255),
         //     M31::from(1962282213),
         //     M31::from(69078916),
         // ]);
 
         // // 2^10 instances
-        // lookup_elements.z = QM31::from_m31_array([
+        // lookup_elements.0.z = QM31::from_m31_array([
         //     M31::from(1465862614),
         //     M31::from(1583357442),
         //     M31::from(1715957657),
         //     M31::from(977402081),
         // ]);
-        // lookup_elements.alpha = QM31::from_m31_array([
+        // lookup_elements.0.alpha = QM31::from_m31_array([
         //     M31::from(1058568156),
         //     M31::from(1376697150),
         //     M31::from(1770783003),
         // ]);
 
         // // 2^11 instances
-        // lookup_elements.z = QM31::from_m31_array([
+        // lookup_elements.0.z = QM31::from_m31_array([
         //     M31::from(1465160322),
         //     M31::from(1969992531),
         //     M31::from(2003820064),
         //     M31::from(1543307892),
         // ]);
-        // lookup_elements.alpha = QM31::from_m31_array([
+        // lookup_elements.0.alpha = QM31::from_m31_array([
         //     M31::from(1503490752),
         //     M31::from(1175877240),
         //     M31::from(1430566545),
         // ]);
 
         // 2^12 instances
-        lookup_elements.z = QM31::from_m31_array([
+        lookup_elements.0.z = QM31::from_m31_array([
             M31::from(589075703),
             M31::from(149359250),
             M31::from(1907284710),
             M31::from(729671227),
         ]);
-        lookup_elements.alpha = QM31::from_m31_array([
+        lookup_elements.0.alpha = QM31::from_m31_array([
             M31::from(318198925),
             M31::from(1203679427),
             M31::from(870875217),
         ]);
 
         // // 2^13 instances
-        // lookup_elements.z = QM31::from_m31_array([
+        // lookup_elements.0.z = QM31::from_m31_array([
         //     M31::from(1628655791),
         //     M31::from(1055381932),
         //     M31::from(980792236),
         //     M31::from(1563574579),
         // ]);
-        // lookup_elements.alpha = QM31::from_m31_array([
+        // lookup_elements.0.alpha = QM31::from_m31_array([
         //     M31::from(758947366),
         //     M31::from(782855802),
         //     M31::from(792359994),
         // ]);
 
         // // 2^14 instances
-        // lookup_elements.z = QM31::from_m31_array([
+        // lookup_elements.0.z = QM31::from_m31_array([
         //     M31::from(668979421),
         //     M31::from(2097978502),
         //     M31::from(428317414),
         //     M31::from(1503540921),
         // ]);
-        // lookup_elements.alpha = QM31::from_m31_array([
+        // lookup_elements.0.alpha = QM31::from_m31_array([
         //     M31::from(962480916),
         //     M31::from(462545530),
         //     M31::from(118859601),
         // ]);
 
         // // 2^15 instances
-        // lookup_elements.z = QM31::from_m31_array([
+        // lookup_elements.0.z = QM31::from_m31_array([
         //     M31::from(1185288908),
         //     M31::from(1548569092),
         //     M31::from(792634712),
         //     M31::from(2053058841),
         // ]);
         let mut cur = QM31::from(1);
-        lookup_elements.alpha_powers = std::array::from_fn(|_| {
+        lookup_elements.0.alpha_powers = std::array::from_fn(|_| {
             let res = cur;
-            cur *= lookup_elements.alpha;
+            cur *= lookup_elements.0.alpha;
             res
         });
 
         #[cfg(not(target_family = "wasm"))]
         let cpu_start = Instant::now();
 
+        #[cfg(target_family = "wasm")]
+        let cpu_start = web_sys::window().unwrap().performance().unwrap().now();
+
         let col = unsafe { VeryPackedSecureColumnByCoords::transform_under_mut(accum.col) };
 
         let range = 0..(1 << (eval_domain.log_size() - LOG_N_LANES - LOG_N_VERY_PACKED_ELEMS));
 
         #[cfg(not(target_family = "wasm"))]
         let cpu_end = Instant::now();
         #[cfg(not(target_family = "wasm"))]
         println!("CPU time: {:?}", cpu_end - cpu_start);
+        #[cfg(target_family = "wasm")]
+        let cpu_end = web_sys::window().unwrap().performance().unwrap().now();
+        #[cfg(target_family = "wasm")]
+        web_sys::console::log_1(
+            &format!("Evaluate Constraints CPU time: {:?}", cpu_end - cpu_start).into(),
+        );
 
         // Compare CPU and GPU results
+        #[cfg(not(target_family = "wasm"))]
         let cpu_vec = col.to_vec();
+        #[cfg(not(target_family = "wasm"))]
         let gpu_vec: Vec<QM31> = gpu_results
             .output
             .poly
             .iter()
             .flat_map(|x| {
                 x.iter().map(|y| {
                     QM31(
                         CM31(M31::from(y.a.a.data), M31::from(y.a.b.data)),
                         CM31(M31::from(y.b.a.data), M31::from(y.b.b.data)),
                     )
                 })
             })
             .collect();
 
+        #[cfg(not(target_family = "wasm"))]
         assert_eq!(
             cpu_vec.len(),
             gpu_vec.len(),
             "CPU and GPU result lengths differ"
         );
 
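+        // Element-wise check: the SIMD/CPU path and the wgpu path must agree at
+        // every evaluation point, not just in length.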
+        #[cfg(not(target_family = "wasm"))]
         for (_i, (cpu_val, gpu_val)) in zip(cpu_vec.iter(), gpu_vec.iter()).enumerate() {
             assert_eq!(
                 cpu_val, gpu_val,
diff --git a/crates/prover/src/core/backend/gpu/compute_composition_polynomial.rs b/crates/prover/src/core/backend/gpu/compute_composition_polynomial.rs
index 9287634eb..686fe8394 100644
--- a/crates/prover/src/core/backend/gpu/compute_composition_polynomial.rs
+++ b/crates/prover/src/core/backend/gpu/compute_composition_polynomial.rs
@@ -4,7 +4,6 @@ use itertools::Itertools;
 use wgpu::util::DeviceExt;
 
 use super::qm31::GpuQM31;
-use crate::constraint_framework::logup::LookupElements;
 use crate::constraint_framework::{
     INTERACTION_TRACE_IDX, ORIGINAL_TRACE_IDX, PREPROCESSED_TRACE_IDX,
 };
@@ -15,6 +14,7 @@ use crate::core::fields::qm31::QM31;
 use crate::core::pcs::TreeVec;
 use crate::core::poly::circle::CircleEvaluation;
 use crate::core::poly::BitReversedOrder;
+use crate::examples::poseidon::PoseidonElements;
 
 pub const N_ROWS: u32 = 32;
 pub const N_STATE: u32 = 16;
@@ -62,15 +62,13 @@ pub struct GpuLookupElements {
     pub alpha_powers: [GpuQM31; N_STATE as usize],
 }
 
-impl<const N: usize> From<LookupElements<N>> for GpuLookupElements
-where
-    [GpuQM31; N]: Sized,
-{
-    fn from(value: LookupElements<N>) -> Self {
+impl From<PoseidonElements> for GpuLookupElements {
+    fn from(value: PoseidonElements) -> Self {
         GpuLookupElements {
-            z: value.z.into(),
-            alpha: value.alpha.into(),
+            z: value.0.z.into(),
+            alpha: value.0.alpha.into(),
             alpha_powers: value
+                .0
                 .alpha_powers
                 .iter()
                 .map(|&x| x.into())
@@ -208,11 +206,11 @@ pub struct WgpuInstance {
     pub encoder: wgpu::CommandEncoder,
 }
 
-async fn init<const N: usize>(
+async fn init(
     trace: TreeVec<ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
     denom_inv: Vec<BaseField>,
     random_coeff_powers: Vec<QM31>,
-    lookup_elements: LookupElements<N>,
+    lookup_elements: PoseidonElements,
     trace_domain_log_size: u32,
     eval_domain_log_size: u32,
     total_sum: QM31,
@@ -239,7 +237,7 @@ async fn init(
     .await
     .unwrap();
 
-    let input_data = create_gpu_input::<N>(
+    let input_data = create_gpu_input(
         trace,
         denom_inv,
         random_coeff_powers,
@@ -381,11 +379,11 @@ async fn init(
     }
 }
 
-fn create_gpu_input<const N: usize>(
+fn create_gpu_input(
     trace: TreeVec<ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
     denom_inv: Vec<BaseField>,
     random_coeff_powers: Vec<QM31>,
-    lookup_elements: LookupElements<N>,
+    lookup_elements: PoseidonElements,
     trace_domain_log_size: u32,
     eval_domain_log_size: u32,
     total_sum: QM31,
@@ -437,11 +435,11 @@ fn create_gpu_input(
     }
 }
 
-pub async fn compute_composition_polynomial_gpu<'a, const N: usize>(
+pub async fn compute_composition_polynomial_gpu<'a>(
     trace: TreeVec<ColumnVec<&CircleEvaluation<CpuBackend, BaseField, BitReversedOrder>>>,
     denom_inv: Vec<BaseField>,
     random_coeff_powers: Vec<QM31>,
-    lookup_elements: LookupElements<N>,
+    lookup_elements: PoseidonElements,
     trace_domain_log_size: u32,
     eval_domain_log_size: u32,
     total_sum: QM31,
diff --git a/crates/prover/src/core/backend/gpu/gen_trace_interpolate_columns.rs b/crates/prover/src/core/backend/gpu/gen_trace_interpolate_columns.rs
index 8c2a01769..523075d44 100644
--- a/crates/prover/src/core/backend/gpu/gen_trace_interpolate_columns.rs
+++ b/crates/prover/src/core/backend/gpu/gen_trace_interpolate_columns.rs
@@ -704,12 +704,21 @@ pub async fn gen_trace_interpolate_columns(
     let _interpolate_output = interpolate_result.await;
 
     #[cfg(not(target_family = "wasm"))]
-    println!("GPU time: {:?}", gpu_start.elapsed());
+    println!(
+        "Gen Trace Interpolate Columns GPU time: {:?}",
+        gpu_start.elapsed()
+    );
 
     #[cfg(target_family = "wasm")]
     let gpu_end = web_sys::window().unwrap().performance().unwrap().now();
     #[cfg(target_family = "wasm")]
-    web_sys::console::log_1(&format!("GPU time: {:?}ms", gpu_end - gpu_start).into());
+    web_sys::console::log_1(
+        &format!(
+            "Gen Trace Interpolate Columns GPU time: {:?}ms",
+            gpu_end - gpu_start
+        )
+        .into(),
+    );
 
     let lookup_data = LookupData {
         initial_state: std::array::from_fn(|_| std::array::from_fn(|_| BaseColumn::zeros(1))),
diff --git a/crates/prover/src/core/backend/gpu/mod.rs b/crates/prover/src/core/backend/gpu/mod.rs
index f0a5b5732..efa892718 100644
--- a/crates/prover/src/core/backend/gpu/mod.rs
+++ b/crates/prover/src/core/backend/gpu/mod.rs
@@ -6,5 +6,6 @@ pub mod gen_trace_parallel;
 pub mod gen_trace_parallel_no_packed;
 pub mod gen_trace_parallel_no_packed_parallel_columns;
 pub mod gpu_common;
+pub mod prove;
 pub mod qm31;
 pub mod utils;
diff --git a/crates/prover/src/core/backend/gpu/prove.rs b/crates/prover/src/core/backend/gpu/prove.rs
new file mode 100644
index 000000000..aa106f328
--- /dev/null
+++ b/crates/prover/src/core/backend/gpu/prove.rs
@@ -0,0 +1,165 @@
+use std::any::Any;
+use std::borrow::Cow;
+
+use itertools::Itertools;
+use tracing::{span, Level};
+
+use crate::constraint_framework::{FrameworkComponent, FrameworkEval, PREPROCESSED_TRACE_IDX};
+use crate::core::air::accumulation::DomainEvaluationAccumulator;
+use crate::core::air::{Component, Trace};
+use crate::core::backend::gpu::compute_composition_polynomial::compute_composition_polynomial_gpu as compute_composition_polynomial_gpu_poseidon;
+use crate::core::backend::simd::SimdBackend;
+use crate::core::backend::BackendForChannel;
+use crate::core::channel::{Channel, MerkleChannel};
+use crate::core::constraints::coset_vanishing;
+use crate::core::fields::m31::BaseField;
+use crate::core::fields::qm31::SecureField;
+// use crate::core::poly::circle::SecureCirclePoly;
+use crate::core::fields::FieldExpOps;
+use crate::core::pcs::{CommitmentSchemeProver, TreeVec};
+use crate::core::poly::circle::ops::PolyOps;
+use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
+use crate::core::poly::BitReversedOrder;
+use crate::core::utils;
+use crate::examples::poseidon::PoseidonComponent;
+
+#[allow(unused_mut)]
+#[allow(unused_variables)]
+pub async fn prove_gpu<E: FrameworkEval + 'static, MC: MerkleChannel>(
+    components: &[&FrameworkComponent<E>],
+    channel: &mut MC::C,
+    mut commitment_scheme: CommitmentSchemeProver<'_, SimdBackend, MC>,
+) where
+    SimdBackend: BackendForChannel<MC>,
+{
+    let n_preprocessed_columns = commitment_scheme.trees[PREPROCESSED_TRACE_IDX]
+        .polynomials
+        .len();
+    let trace = commitment_scheme.trace();
+
+    // Evaluate and commit on composition polynomial.
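+    // random_coeff is the verifier challenge; each constraint quotient is
+    // multiplied by a distinct power of it before being summed into the
+    // composition polynomial.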
+    let random_coeff = channel.draw_felt();
+
+    let span = span!(Level::INFO, "Composition").entered();
+    let span1 = span!(Level::INFO, "Generation").entered();
+    // let composition_poly =
+    compute_composition_polynomial_gpu::<MC, E>(components, random_coeff, &trace).await;
+    span1.exit();
+
+    // let mut tree_builder = commitment_scheme.tree_builder();
+    // tree_builder.extend_polys(composition_poly.into_coordinate_polys());
+    // tree_builder.commit(channel);
+    // span.exit();
+}

+#[allow(unused_variables)]
+async fn compute_composition_polynomial_gpu<'a, MC: MerkleChannel, E: FrameworkEval + 'static>(
+    components: &'a [&'a FrameworkComponent<E>],
+    random_coeff: SecureField,
+    trace: &'a Trace<'a, SimdBackend>,
+)
+// -> SecureCirclePoly<SimdBackend>
+where
+    SimdBackend: BackendForChannel<MC>,
+{
+    let total_constraints: usize = components.iter().map(|c| c.n_constraints()).sum();
+    let composition_log_degree_bound = components
+        .iter()
+        .map(|c| c.max_constraint_log_degree_bound())
+        .max()
+        .unwrap();
+    #[allow(unused_mut)]
+    let mut accumulator = DomainEvaluationAccumulator::new(
+        random_coeff,
+        composition_log_degree_bound,
+        total_constraints,
+    );
+
+    for &component in components {
+        evaluate_constraint_quotients_on_domain_gpu(component, trace, &mut accumulator).await;
+    }
+    // accumulator.finalize()
+}
+
+#[allow(unused_variables)]
+async fn evaluate_constraint_quotients_on_domain_gpu<'a, E: FrameworkEval + Any + 'static>(
+    component: &'a FrameworkComponent<E>,
+    trace: &'a Trace<'a, SimdBackend>,
+    evaluation_accumulator: &mut DomainEvaluationAccumulator<SimdBackend>,
+) {
+    if component.n_constraints() == 0 {
+        return;
+    }
+
+    let eval_domain =
+        CanonicCoset::new(component.max_constraint_log_degree_bound()).circle_domain();
+    let trace_domain = CanonicCoset::new(component.eval().log_size());
+
+    let mut component_polys = trace.polys.sub_tree(&component.trace_locations());
+    component_polys[PREPROCESSED_TRACE_IDX] = component
+        .preproccessed_column_indices()
+        .iter()
+        .map(|idx| &trace.polys[PREPROCESSED_TRACE_IDX][*idx])
+        .collect();
+
+    let mut component_evals = trace.evals.sub_tree(&component.trace_locations());
+    component_evals[PREPROCESSED_TRACE_IDX] = component
+        .preproccessed_column_indices()
+        .iter()
+        .map(|idx| &trace.evals[PREPROCESSED_TRACE_IDX][*idx])
+        .collect();
+
+    // Extend trace if necessary.
+    // TODO: Don't extend when eval_size < committed_size. Instead, pick a good
+    // subdomain. (For larger blowup factors).
+    let need_to_extend = component_evals
+        .iter()
+        .flatten()
+        .any(|c| c.domain != eval_domain);
+    let trace: TreeVec<Vec<Cow<'_, CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>>>> =
+        if need_to_extend {
+            let _span = span!(Level::INFO, "Extension").entered();
+            let twiddles = SimdBackend::precompute_twiddles(eval_domain.half_coset);
+            component_polys
+                .as_cols_ref()
+                .map_cols(|col| Cow::Owned(col.evaluate_with_twiddles(eval_domain, &twiddles)))
+        } else {
+            component_evals.clone().map_cols(|c| Cow::Borrowed(*c))
+        };
+
+    // Denom inverses.
+    let log_expand = eval_domain.log_size() - trace_domain.log_size();
+    let mut denom_inv = (0..1 << log_expand)
+        .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse())
+        .collect_vec();
+    utils::bit_reverse(&mut denom_inv);
+
+    // Accumulator.
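+    // random_coeff_powers is reversed so that random_coeff_powers[constraint_index]
+    // in the shader lines up with the order in which constraints are accumulated.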
+    let [mut accum] =
+        evaluation_accumulator.columns([(eval_domain.log_size(), component.n_constraints())]);
+    accum.random_coeff_powers.reverse();
+
+    let trace_cols = trace.as_cols_ref().map_cols(|c| c.to_cpu());
+    let trace_cols = trace_cols.as_cols_ref();
+
+    #[cfg(target_family = "wasm")]
+    let gpu_start = web_sys::window().unwrap().performance().unwrap().now();
+
+    if let Some(poseidon_component) = (component as &dyn Any).downcast_ref::<PoseidonComponent>() {
+        let gpu_results = compute_composition_polynomial_gpu_poseidon(
+            trace_cols,
+            denom_inv.clone(),
+            accum.random_coeff_powers.clone(),
+            poseidon_component.eval().lookup_elements.clone(),
+            trace_domain.log_size(),
+            eval_domain.log_size(),
+            poseidon_component.eval().total_sum,
+        )
+        .await;
+    }
+
+    #[cfg(target_family = "wasm")]
+    let gpu_end = web_sys::window().unwrap().performance().unwrap().now();
+    #[cfg(target_family = "wasm")]
+    web_sys::console::log_1(&format!("Time spent on GPU: {:?}ms", gpu_end - gpu_start).into());
+}
diff --git a/crates/prover/src/core/poly/circle/mod.rs b/crates/prover/src/core/poly/circle/mod.rs
index f2532d5bf..2f486b94c 100644
--- a/crates/prover/src/core/poly/circle/mod.rs
+++ b/crates/prover/src/core/poly/circle/mod.rs
@@ -1,7 +1,7 @@
 mod canonic;
 mod domain;
 mod evaluation;
-mod ops;
+pub mod ops;
 mod poly;
 mod secure_poly;
 
diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs
index f6450bbab..d7593883c 100644
--- a/crates/prover/src/examples/poseidon/mod.rs
+++ b/crates/prover/src/examples/poseidon/mod.rs
@@ -406,6 +407,7 @@ mod tests {
     use crate::constraint_framework::assert_constraints;
     use crate::constraint_framework::preprocessed_columns::gen_is_first;
     use crate::core::air::Component;
+    use crate::core::backend::gpu::prove::prove_gpu;
     use crate::core::backend::CpuBackend;
     use crate::core::channel::Blake2sChannel;
     use crate::core::fields::m31::BaseField;
@@ -512,7 +513,13 @@ mod tests {
         let circle_evals: Vec<_> = _trace.iter().map(|eval| eval.to_cpu()).collect();
         let _cpu_trace_polys = CpuBackend::interpolate_columns(circle_evals, &twiddles);
         let checkpoint2 = web_sys::window().unwrap().performance().unwrap().now();
-        web_sys::console::log_1(&format!("CPU time: {:?}ms", checkpoint2 - checkpoint1).into());
+        web_sys::console::log_1(
+            &format!(
+                "Gen Trace Interpolate Columns CPU time: {:?}ms",
+                checkpoint2 - checkpoint1
+            )
+            .into(),
+        );
         let _cpu_trace = _trace.into_iter().map(|c| c.values.clone()).collect_vec();
         // assert_eq!(_cpu_trace, _gpu_trace);
         // assert_eq!(_lookup_data, _gpu_lookup_data);
@@ -603,4 +610,74 @@ mod tests {
 
         verify(&[&component], channel, commitment_scheme, proof).unwrap();
     }
+
+    #[wasm_bindgen_test::wasm_bindgen_test]
+    async fn test_poseidon_prove_wasm_gpu_cpu() {
+        use crate::constraint_framework::TraceLocationAllocator;
+        use crate::core::pcs::CommitmentSchemeProver;
+        use crate::examples::poseidon::{
+            PoseidonComponent, PoseidonEval, SimdBackend, LOG_EXPAND, N_LOG_INSTANCES_PER_ROW,
+        };
+
+        let log_n_instances = env::var("LOG_N_INSTANCES")
+            .unwrap_or_else(|_| "12".to_string())
+            .parse::<u32>()
+            .unwrap();
+        let config = PcsConfig {
+            pow_bits: 10,
+            fri_config: FriConfig::new(5, 1, 64),
+        };
+
+        assert!(log_n_instances >= N_LOG_INSTANCES_PER_ROW as u32);
+        let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32;
+
+        // Precompute twiddles.
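+        // The twiddle table has to cover the largest domain used in the proof:
+        // the trace domain extended by the constraint-degree blowup (LOG_EXPAND)
+        // plus the FRI blowup factor.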
+        let twiddles = SimdBackend::precompute_twiddles(
+            CanonicCoset::new(log_n_rows + LOG_EXPAND + config.fri_config.log_blowup_factor)
+                .circle_domain()
+                .half_coset,
+        );
+
+        // Setup protocol.
+        let channel = &mut Blake2sChannel::default();
+        let mut commitment_scheme =
+            CommitmentSchemeProver::<_, Blake2sMerkleChannel>::new(config, &twiddles);
+
+        // Preprocessed trace.
+        let mut tree_builder = commitment_scheme.tree_builder();
+        let constant_trace = vec![gen_is_first(log_n_rows)];
+        tree_builder.extend_evals(constant_trace);
+        tree_builder.commit(channel);
+
+        // Trace.
+        let (trace, lookup_data) = gen_trace(log_n_rows);
+        let mut tree_builder = commitment_scheme.tree_builder();
+        tree_builder.extend_evals(trace);
+        tree_builder.commit(channel);
+
+        // Draw lookup elements.
+        let lookup_elements = PoseidonElements::draw(channel);
+
+        // Interaction trace.
+        let (trace, total_sum) = gen_interaction_trace(log_n_rows, lookup_data, &lookup_elements);
+        let mut tree_builder = commitment_scheme.tree_builder();
+        tree_builder.extend_evals(trace);
+        tree_builder.commit(channel);
+
+        // Prove constraints.
+        let component = PoseidonComponent::new(
+            &mut TraceLocationAllocator::default(),
+            PoseidonEval {
+                log_n_rows,
+                lookup_elements,
+                total_sum,
+            },
+            (total_sum, None),
+        );
+
+        prove_gpu(&[&component], channel, commitment_scheme).await;
+
+        // Prove on the CPU for comparison.
+        let (_component, _proof) = prove_poseidon(log_n_instances, config);
+    }
 }