diff --git a/crates/prover/src/core/backend/gpu/fraction.rs b/crates/prover/src/core/backend/gpu/fraction.rs
new file mode 100644
index 000000000..f22c0fc32
--- /dev/null
+++ b/crates/prover/src/core/backend/gpu/fraction.rs
@@ -0,0 +1,130 @@
+use std::borrow::Cow;
+
+use super::compute_composition_polynomial::GpuQM31;
+use super::gpu_common::{ByteSerialize, GpuComputeInstance, GpuOperation};
+use crate::core::fields::qm31::QM31;
+use crate::core::lookups::utils::Fraction;
+
+#[repr(C)]
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub struct GpuFraction {
+    pub numerator: GpuQM31,
+    pub denominator: GpuQM31,
+}
+
+#[repr(C)]
+#[derive(Copy, Clone, Debug)]
+pub struct ComputeInput {
+    pub first: GpuFraction,
+    pub second: GpuFraction,
+}
+
+#[repr(C)]
+#[derive(Copy, Clone, Debug)]
+pub struct ComputeOutput {
+    pub result: GpuFraction,
+}
+
+impl ByteSerialize for ComputeInput {}
+impl ByteSerialize for ComputeOutput {}
+
+impl From<Fraction<QM31, QM31>> for GpuFraction {
+    fn from(value: Fraction<QM31, QM31>) -> Self {
+        GpuFraction {
+            numerator: GpuQM31::from(value.numerator),
+            denominator: GpuQM31::from(value.denominator),
+        }
+    }
+}
+
+pub enum FractionOperation {
+    Add,
+}
+
+impl GpuOperation for FractionOperation {
+    fn shader_source(&self) -> Cow<'static, str> {
+        let base_source = include_str!("fraction.wgsl");
+        let qm31_source = include_str!("qm31.wgsl");
+
+        let inputs = r#"
+            struct ComputeInput {
+                first: Fraction,
+                second: Fraction,
+            }
+
+            @group(0) @binding(0) var<storage, read> input: ComputeInput;
+        "#;
+
+        let output = r#"
+            struct ComputeOutput {
+                result: Fraction,
+            }
+
+            @group(0) @binding(1) var<storage, read_write> output: ComputeOutput;
+        "#;
+
+        let operation = match self {
+            FractionOperation::Add => {
+                r#"
+                @compute @workgroup_size(1)
+                fn main() {
+                    output.result = fraction_add(input.first, input.second);
+                }
+                "#
+            }
+        };
+
+        format!("{qm31_source}\n{base_source}\n{inputs}\n{output}\n{operation}").into()
+    }
+}
+
+pub async fn compute_fraction_operation(
+    operation: FractionOperation,
+    first: Fraction<QM31, QM31>,
+    second: Fraction<QM31, QM31>,
+) -> ComputeOutput {
+    let input = ComputeInput {
+        first: first.into(),
+        second: second.into(),
+    };
+
+    let instance = GpuComputeInstance::new(&input, std::mem::size_of::<ComputeOutput>()).await;
+    let (pipeline, bind_group) =
+        instance.create_pipeline(&operation.shader_source(), operation.entry_point());
+
+    instance
+        .run_computation::<ComputeOutput>(&pipeline, &bind_group, (1, 1, 1))
+        .await
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::core::fields::qm31::QM31;
+
+    #[test]
+    fn test_fraction_add() {
+        // CPU implementation
+        let cpu_a = Fraction::new(
+            QM31::from_u32_unchecked(1u32, 0u32, 0u32, 0u32),
+            QM31::from_u32_unchecked(3u32, 0u32, 0u32, 0u32),
+        );
+        let cpu_b = Fraction::new(
+            QM31::from_u32_unchecked(2u32, 0u32, 0u32, 0u32),
+            QM31::from_u32_unchecked(6u32, 0u32, 0u32, 0u32),
+        );
+        let cpu_result = cpu_a + cpu_b;
+
+        // GPU implementation
+        let gpu_result = pollster::block_on(compute_fraction_operation(
+            FractionOperation::Add,
+            cpu_a,
+            cpu_b,
+        ));
+
+        assert_eq!(cpu_result.numerator, gpu_result.result.numerator.into());
+        assert_eq!(cpu_result.denominator, gpu_result.result.denominator.into());
+    }
+}
diff --git a/crates/prover/src/core/backend/gpu/fraction.wgsl b/crates/prover/src/core/backend/gpu/fraction.wgsl
new file mode 100644
index 000000000..79f54455a
--- /dev/null
+++ b/crates/prover/src/core/backend/gpu/fraction.wgsl
@@ -0,0 +1,17 @@
+// This shader contains implementations for fraction operations.
+// It is stateless and can be used as a library in other shaders.
+
+struct Fraction {
+    numerator: QM31,
+    denominator: QM31,
+}
+
+// Add two fractions without reducing: a/b + c/d = (ad + bc)/(bd).
+fn fraction_add(a: Fraction, b: Fraction) -> Fraction {
+    let numerator = qm31_add(
+        qm31_mul(a.numerator, b.denominator),
+        qm31_mul(b.numerator, a.denominator)
+    );
+    let denominator = qm31_mul(a.denominator, b.denominator);
+    return Fraction(numerator, denominator);
+}
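+
+// Worked example (the values used in the Rust test in `fraction.rs`): on the first
+// coordinate, 1/3 + 2/6 = (1*6 + 2*3)/(3*6) = 12/18. Numerators and denominators are
+// deliberately kept unreduced; 12/18 equals 2/3 as a field element.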
diff --git a/crates/prover/src/core/backend/gpu/gpu_common.rs b/crates/prover/src/core/backend/gpu/gpu_common.rs
new file mode 100644
index 000000000..66691a8b2
--- /dev/null
+++ b/crates/prover/src/core/backend/gpu/gpu_common.rs
@@ -0,0 +1,226 @@
+use std::borrow::Cow;
+
+use wgpu::util::DeviceExt;
+
+/// Common trait for GPU input/output types.
+///
+/// Implementors must be `#[repr(C)]` with a layout matching the corresponding WGSL
+/// struct, since the bytes are reinterpreted directly.
+pub trait ByteSerialize: Sized {
+    fn as_bytes(&self) -> &[u8] {
+        unsafe {
+            std::slice::from_raw_parts(
+                (self as *const Self) as *const u8,
+                std::mem::size_of::<Self>(),
+            )
+        }
+    }
+
+    fn from_bytes(bytes: &[u8]) -> &Self {
+        assert!(bytes.len() >= std::mem::size_of::<Self>());
+        unsafe { &*(bytes.as_ptr() as *const Self) }
+    }
+}
+
+/// Base GPU instance for field computations
+pub struct GpuComputeInstance {
+    pub device: wgpu::Device,
+    pub queue: wgpu::Queue,
+    pub input_buffer: wgpu::Buffer,
+    pub output_buffer: wgpu::Buffer,
+    pub staging_buffer: wgpu::Buffer,
+}
+
+impl GpuComputeInstance {
+    pub async fn new<T: ByteSerialize>(input_data: &T, output_size: usize) -> Self {
+        let instance = wgpu::Instance::default();
+        let adapter = instance
+            .request_adapter(&wgpu::RequestAdapterOptions {
+                power_preference: wgpu::PowerPreference::HighPerformance,
+                compatible_surface: None,
+                force_fallback_adapter: false,
+            })
+            .await
+            .unwrap();
+
+        let (device, queue) = adapter
+            .request_device(
+                &wgpu::DeviceDescriptor {
+                    label: Some("Field Operations Device"),
+                    required_features: wgpu::Features::SHADER_INT64,
+                    required_limits: wgpu::Limits::default(),
+                    memory_hints: wgpu::MemoryHints::Performance,
+                },
+                None,
+            )
+            .await
+            .unwrap();
+
+        // Create input buffer
+        let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
+            label: Some("Field Input Buffer"),
+            contents: input_data.as_bytes(),
+            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
+        });
+
+        // Create output buffer
+        let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
+            label: Some("Field Output Buffer"),
+            size: output_size as u64,
+            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
+            mapped_at_creation: false,
+        });
+
+        // Create staging buffer for reading results
+        let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
+            label: Some("Field Staging Buffer"),
+            size: output_size as u64,
+            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
+            mapped_at_creation: false,
+        });
+
+        Self {
+            device,
+            queue,
+            input_buffer,
+            output_buffer,
+            staging_buffer,
+        }
+    }
+
+    pub fn create_pipeline(
+        &self,
+        shader_source: &str,
+        entry_point: &str,
+    ) -> (wgpu::ComputePipeline, wgpu::BindGroup) {
+        let shader = self
+            .device
+            .create_shader_module(wgpu::ShaderModuleDescriptor {
+                label: Some("Field Operations Shader"),
+                source: wgpu::ShaderSource::Wgsl(shader_source.into()),
+            });
+
+        let bind_group_layout =
+            self.device
+                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
+                    entries: &[
+                        wgpu::BindGroupLayoutEntry {
+                            binding: 0,
+                            visibility: wgpu::ShaderStages::COMPUTE,
+                            ty: wgpu::BindingType::Buffer {
+                                ty: wgpu::BufferBindingType::Storage { read_only: true },
+                                has_dynamic_offset: false,
+                                min_binding_size: None,
+                            },
+                            count: None,
+                        },
+                        wgpu::BindGroupLayoutEntry {
+                            binding: 1,
+                            visibility: wgpu::ShaderStages::COMPUTE,
+                            ty: wgpu::BindingType::Buffer {
+                                ty: wgpu::BufferBindingType::Storage { read_only: false },
+                                has_dynamic_offset: false,
+                                min_binding_size: None,
+                            },
+                            count: None,
+                        },
+                    ],
+                    label: Some("Field Operations Bind Group Layout"),
+                });
+
+        let pipeline_layout = self
+            .device
+            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
+                label: Some("Field Operations Pipeline Layout"),
+                bind_group_layouts: &[&bind_group_layout],
+                push_constant_ranges: &[],
+            });
+
+        let pipeline = self
+            .device
+            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
+                label: Some("Field Operations Pipeline"),
+                layout: Some(&pipeline_layout),
+                module: &shader,
+                entry_point: Some(entry_point),
+                cache: None,
+                compilation_options: Default::default(),
+            });
+
+        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
+            layout: &bind_group_layout,
+            entries: &[
+                wgpu::BindGroupEntry {
+                    binding: 0,
+                    resource: self.input_buffer.as_entire_binding(),
+                },
+                wgpu::BindGroupEntry {
+                    binding: 1,
+                    resource: self.output_buffer.as_entire_binding(),
+                },
+            ],
+            label: Some("Field Operations Bind Group"),
+        });
+
+        (pipeline, bind_group)
+    }
+
+    pub async fn run_computation<T: ByteSerialize + Copy>(
+        &self,
+        pipeline: &wgpu::ComputePipeline,
+        bind_group: &wgpu::BindGroup,
+        workgroup_count: (u32, u32, u32),
+    ) -> T {
+        let mut encoder = self
+            .device
+            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
+                label: Some("Field Operations Encoder"),
+            });
+
+        {
+            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
+                label: Some("Field Operations Compute Pass"),
+                timestamp_writes: None,
+            });
+            compute_pass.set_pipeline(pipeline);
+            compute_pass.set_bind_group(0, bind_group, &[]);
+            compute_pass.dispatch_workgroups(
+                workgroup_count.0,
+                workgroup_count.1,
+                workgroup_count.2,
+            );
+        }
+
+        encoder.copy_buffer_to_buffer(
+            &self.output_buffer,
+            0,
+            &self.staging_buffer,
+            0,
+            self.staging_buffer.size(),
+        );
+
+        self.queue.submit(Some(encoder.finish()));
+
+        let buffer_slice = self.staging_buffer.slice(..);
+        let (tx, rx) = flume::bounded(1);
+        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
+            tx.send(result).unwrap();
+        });
+
+        self.device.poll(wgpu::Maintain::wait());
+
+        rx.recv_async().await.unwrap().unwrap();
+        let data = buffer_slice.get_mapped_range();
+        let result = *T::from_bytes(&data);
+        drop(data);
+        self.staging_buffer.unmap();
+
+        result
+    }
+}
+
+/// Trait for GPU operations that require shader source generation
+pub trait GpuOperation {
+    fn shader_source(&self) -> Cow<'static, str>;
+
+    fn entry_point(&self) -> &'static str {
+        "main"
+    }
+}
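+
+// Typical usage, as in `fraction.rs` and `qm31.rs`:
+//
+//     let instance = GpuComputeInstance::new(&input, std::mem::size_of::<ComputeOutput>()).await;
+//     let (pipeline, bind_group) =
+//         instance.create_pipeline(&operation.shader_source(), operation.entry_point());
+//     let output = instance
+//         .run_computation::<ComputeOutput>(&pipeline, &bind_group, (1, 1, 1))
+//         .await;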
diff --git a/crates/prover/src/core/backend/gpu/mod.rs b/crates/prover/src/core/backend/gpu/mod.rs
index 8ddbf6302..20a73b47b 100644
--- a/crates/prover/src/core/backend/gpu/mod.rs
+++ b/crates/prover/src/core/backend/gpu/mod.rs
@@ -1,5 +1,9 @@
+pub mod fraction;
 pub mod gen_trace;
 pub mod gen_trace_interpolate_columns;
 pub mod gen_trace_parallel;
 pub mod gen_trace_parallel_no_packed;
 pub mod gen_trace_parallel_no_packed_parallel_columns;
+pub mod gpu_common;
+pub mod qm31;
+pub mod utils;
diff --git a/crates/prover/src/core/backend/gpu/qm31.rs b/crates/prover/src/core/backend/gpu/qm31.rs
new file mode 100644
index 000000000..aa0ce536a
--- /dev/null
+++ b/crates/prover/src/core/backend/gpu/qm31.rs
@@ -0,0 +1,372 @@
+use std::borrow::Cow;
+
+use super::compute_composition_polynomial::GpuQM31;
+use super::gpu_common::{ByteSerialize, GpuComputeInstance, GpuOperation};
+use crate::core::fields::qm31::QM31;
+
+#[repr(C)]
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub struct ComputeInput {
+    pub first: GpuQM31,
+    pub second: GpuQM31,
+}
+
+#[repr(C)]
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub struct ComputeOutput {
+    pub result: GpuQM31,
+}
+
+impl ByteSerialize for ComputeInput {}
+impl ByteSerialize for ComputeOutput {}
+
+pub enum QM31Operation {
+    Add,
+    Subtract,
+    Multiply,
+    Negate,
+    Inverse,
+}
+
+impl GpuOperation for QM31Operation {
+    fn shader_source(&self) -> Cow<'static, str> {
+        let base_source = include_str!("qm31.wgsl");
+
+        let inputs = r#"
+            struct ComputeInput {
+                first: QM31,
+                second: QM31,
+            }
+
+            @group(0) @binding(0) var<storage, read> input: ComputeInput;
+        "#;
+
+        let output = r#"
+            struct ComputeOutput {
+                result: QM31,
+            }
+
+            @group(0) @binding(1) var<storage, read_write> output: ComputeOutput;
+        "#;
+
+        let operation = match self {
+            QM31Operation::Add => {
+                r#"
+                @compute @workgroup_size(1)
+                fn main() {
+                    output.result = qm31_add(input.first, input.second);
+                }
+                "#
+            }
+            QM31Operation::Multiply => {
+                r#"
+                @compute @workgroup_size(1)
+                fn main() {
+                    output.result = qm31_mul(input.first, input.second);
+                }
+                "#
+            }
+            QM31Operation::Subtract => {
+                r#"
+                @compute @workgroup_size(1)
+                fn main() {
+                    output.result = qm31_sub(input.first, input.second);
+                }
+                "#
+            }
+            QM31Operation::Negate => {
+                r#"
+                @compute @workgroup_size(1)
+                fn main() {
+                    output.result = qm31_neg(input.first);
+                }
+                "#
+            }
+            QM31Operation::Inverse => {
+                r#"
+                @compute @workgroup_size(1)
+                fn main() {
+                    output.result = qm31_inverse(input.first);
+                }
+                "#
+            }
+        };
+
+        format!("{base_source}\n{inputs}\n{output}\n{operation}").into()
+    }
+}
+
+pub async fn compute_field_operation(operation: QM31Operation, first: QM31, second: QM31) -> QM31 {
+    let input = ComputeInput {
+        first: first.into(),
+        second: second.into(),
+    };
+
+    let instance = GpuComputeInstance::new(&input, std::mem::size_of::<ComputeOutput>()).await;
+    let (pipeline, bind_group) =
+        instance.create_pipeline(&operation.shader_source(), operation.entry_point());
+
+    let output = instance
+        .run_computation::<ComputeOutput>(&pipeline, &bind_group, (1, 1, 1))
+        .await;
+
+    output.result.into()
+}
+
+#[cfg(test)]
+mod tests {
+    use num_traits::Zero;
+
+    use super::*;
+    use crate::core::fields::cm31::CM31;
+    use crate::core::fields::m31::{M31, P};
+    use crate::core::fields::qm31::QM31;
+    use crate::core::fields::FieldExpOps;
+    use crate::{cm31, qm31};
+
+    #[test]
+    fn test_gpu_field_values() {
+        let qm0 = qm31!(1, 2, 3, 4);
+        let qm1 = qm31!(4, 5, 6, 7);
+
+        // Test round-trip conversion CPU -> GPU -> CPU
+        let gpu_qm0 = GpuQM31::from(qm0);
+        let gpu_qm1 = GpuQM31::from(qm1);
+
+        let cpu_qm0 = QM31(
+            CM31(gpu_qm0.a.a.data.into(), gpu_qm0.a.b.data.into()),
+            CM31(gpu_qm0.b.a.data.into(), gpu_qm0.b.b.data.into()),
+        );
+
+        let cpu_qm1 = QM31(
+            CM31(gpu_qm1.a.a.data.into(), gpu_qm1.a.b.data.into()),
+            CM31(gpu_qm1.b.a.data.into(), gpu_qm1.b.b.data.into()),
+        );
+
+        assert_eq!(
+            qm0, cpu_qm0,
+            "Round-trip conversion should preserve values for qm0"
+        );
+        assert_eq!(
+            qm1, cpu_qm1,
+            "Round-trip conversion should preserve values for qm1"
+        );
+    }
+
+    #[test]
+    fn test_gpu_m31_field_arithmetic() {
+        // Test M31 field operations
+        let m = M31::from(19u32);
+        let one = M31::from(1u32);
+        let zero = M31::zero();
+
+        // Create QM31 values for GPU computation
+        let m_qm = QM31(CM31(m, zero), CM31::zero());
+        let one_qm = QM31(CM31(one, zero), CM31::zero());
+        let zero_qm = QM31(CM31(zero, zero), CM31::zero());
+
+        // Test addition
+        let cpu_add = m + one;
+        let gpu_add = pollster::block_on(compute_field_operation(QM31Operation::Add, m_qm, one_qm));
+        assert_eq!(gpu_add.0 .0, cpu_add, "M31 addition failed");
+
+        // Test subtraction
+        let cpu_sub = m - one;
+        let gpu_sub = pollster::block_on(compute_field_operation(
+            QM31Operation::Subtract,
+            m_qm,
+            one_qm,
+        ));
+        assert_eq!(gpu_sub.0 .0, cpu_sub, "M31 subtraction failed");
+
+        // Test multiplication
+        let cpu_mul = m * one;
+        let gpu_mul = pollster::block_on(compute_field_operation(
+            QM31Operation::Multiply,
+            m_qm,
+            one_qm,
+        ));
+        assert_eq!(gpu_mul.0 .0, cpu_mul, "M31 multiplication failed");
+
+        // Test negation
+        let cpu_neg = -m;
+        let gpu_neg = pollster::block_on(compute_field_operation(
+            QM31Operation::Negate,
+            m_qm,
+            zero_qm,
+        ));
+        assert_eq!(gpu_neg.0 .0, cpu_neg, "M31 negation failed");
+
+        // Test inverse
+        let cpu_inv = m.inverse();
+        let gpu_inv = pollster::block_on(compute_field_operation(
+            QM31Operation::Inverse,
+            m_qm,
+            zero_qm,
+        ));
+        assert_eq!(gpu_inv.0 .0, cpu_inv, "M31 inverse failed");
+
+        // Test with large numbers (near P)
+        let large = M31::from(P - 1);
+        let large_qm = QM31(CM31(large, zero), CM31::zero());
+
+        // Test large number multiplication
+        let cpu_large_mul = large * m;
+        let gpu_large_mul = pollster::block_on(compute_field_operation(
+            QM31Operation::Multiply,
+            large_qm,
+            m_qm,
+        ));
+        assert_eq!(
+            gpu_large_mul.0 .0, cpu_large_mul,
+            "M31 large number multiplication failed"
+        );
+
+        // Test large number inverse
+        let cpu_large_inv = one / large;
+        let gpu_large_inv = pollster::block_on(compute_field_operation(
+            QM31Operation::Inverse,
+            large_qm,
+            zero_qm,
+        ));
+        assert_eq!(
+            gpu_large_inv.0 .0, cpu_large_inv,
+            "M31 large number inverse failed"
+        );
+    }
+
+    #[test]
+    fn test_gpu_cm31_field_arithmetic() {
+        let cm0 = cm31!(1, 2);
+        let cm1 = cm31!(4, 5);
+        let zero = CM31::zero();
+
+        // Test addition
+        let cpu_add = cm0 + cm1;
+        let gpu_add = pollster::block_on(compute_field_operation(
+            QM31Operation::Add,
+            QM31(cm0, zero),
+            QM31(cm1, zero),
+        ));
+        assert_eq!(gpu_add.0, cpu_add, "CM31 addition failed");
+
+        // Test subtraction
+        let cpu_sub = cm0 - cm1;
+        let gpu_sub = pollster::block_on(compute_field_operation(
+            QM31Operation::Subtract,
+            QM31(cm0, zero),
+            QM31(cm1, zero),
+        ));
+        assert_eq!(gpu_sub.0, cpu_sub, "CM31 subtraction failed");
+
+        // Test multiplication
+        let cpu_mul = cm0 * cm1;
+        let gpu_mul = pollster::block_on(compute_field_operation(
+            QM31Operation::Multiply,
+            QM31(cm0, zero),
+            QM31(cm1, zero),
+        ));
+        assert_eq!(gpu_mul.0, cpu_mul, "CM31 multiplication failed");
+
+        // Test negation
+        let cpu_neg = -cm0;
+        let gpu_neg = pollster::block_on(compute_field_operation(
+            QM31Operation::Negate,
+            QM31(cm0, zero),
+            QM31(zero, zero),
+        ));
+        assert_eq!(gpu_neg.0, cpu_neg, "CM31 negation failed");
+
+        // Test inverse
+        let cpu_inv = cm0.inverse();
+        let gpu_inv = pollster::block_on(compute_field_operation(
+            QM31Operation::Inverse,
+            QM31(cm0, zero),
+            QM31(zero, zero),
+        ));
+        assert_eq!(gpu_inv.0, cpu_inv, "CM31 inverse failed");
+
+        // Test with large numbers (near P)
+        let large = cm31!(P - 1, P - 2);
+        let large_qm = QM31(large, zero);
+
+        // Test large number multiplication
+        let cpu_large_mul = large * cm1;
+        let gpu_large_mul = pollster::block_on(compute_field_operation(
+            QM31Operation::Multiply,
+            large_qm,
+            QM31(cm1, zero),
+        ));
+        assert_eq!(
+            gpu_large_mul.0, cpu_large_mul,
+            "CM31 large number multiplication failed"
+        );
+
+        // Test large number inverse
+        let cpu_large_inv = large.inverse();
+        let gpu_large_inv = pollster::block_on(compute_field_operation(
+            QM31Operation::Inverse,
+            large_qm,
+            QM31(zero, zero),
+        ));
+        assert_eq!(
+            gpu_large_inv.0, cpu_large_inv,
+            "CM31 large number inverse failed"
+        );
+    }
+
+    #[test]
+    fn test_gpu_qm31_field_arithmetic() {
+        let qm0 = qm31!(1, 2, 3, 4);
+        let qm1 = qm31!(4, 5, 6, 7);
+        let zero = QM31::zero();
+
+        // Test addition
+        let cpu_add = qm0 + qm1;
+        let gpu_add = pollster::block_on(compute_field_operation(QM31Operation::Add, qm0, qm1));
+        assert_eq!(gpu_add, cpu_add, "QM31 addition failed");
+
+        // Test subtraction
+        let cpu_sub = qm0 - qm1;
+        let gpu_sub =
+            pollster::block_on(compute_field_operation(QM31Operation::Subtract, qm0, qm1));
+        assert_eq!(gpu_sub, cpu_sub, "QM31 subtraction failed");
+
+        // Test multiplication
+        let cpu_mul = qm0 * qm1;
+        let gpu_mul =
+            pollster::block_on(compute_field_operation(QM31Operation::Multiply, qm0, qm1));
+        assert_eq!(gpu_mul, cpu_mul, "QM31 multiplication failed");
+
+        // Test negation
+        let cpu_neg = -qm0;
+        let gpu_neg = pollster::block_on(compute_field_operation(QM31Operation::Negate, qm0, zero));
+        assert_eq!(gpu_neg, cpu_neg, "QM31 negation failed");
+
+        // Test inverse
+        let cpu_inv = qm0.inverse();
+        let gpu_inv =
+            pollster::block_on(compute_field_operation(QM31Operation::Inverse, qm0, zero));
+        assert_eq!(gpu_inv, cpu_inv, "QM31 inverse failed");
+
+        // Test with large numbers (near P)
+        let large = qm31!(P - 1, P - 2, P - 3, P - 4);
+
+        // Test large number multiplication
+        let cpu_large_mul = large * qm1;
+        let gpu_large_mul =
+            pollster::block_on(compute_field_operation(QM31Operation::Multiply, large, qm1));
+        assert_eq!(
+            gpu_large_mul, cpu_large_mul,
+            "QM31 large number multiplication failed"
+        );
+
+        // Test large number inverse
+        let cpu_large_inv = large.inverse();
+        let gpu_large_inv =
+            pollster::block_on(compute_field_operation(QM31Operation::Inverse, large, zero));
+        assert_eq!(
+            gpu_large_inv, cpu_large_inv,
+            "QM31 large number inverse failed"
+        );
+    }
+}
diff --git a/crates/prover/src/core/backend/gpu/qm31.wgsl b/crates/prover/src/core/backend/gpu/qm31.wgsl
new file mode 100644
index 000000000..4d044a1bd
--- /dev/null
+++ b/crates/prover/src/core/backend/gpu/qm31.wgsl
@@ -0,0 +1,223 @@
+// This shader contains implementations for QM31/CM31/M31 operations.
+// It is stateless, i.e. it contains no storage variables and no entrypoint functions,
+// so it can be used as a library in other shaders.
+// Note that identifiers declared in this shader must not be redeclared by shaders
+// that include it.
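+
+// Field tower implemented below:
+//   M31  = GF(P), P = 2^31 - 1 (a Mersenne prime);
+//   CM31 = M31[i] / (i^2 + 1), elements written a + bi;
+//   QM31 = CM31[u] / (u^2 - (2 + i)), elements written a + bu.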
+const P: u32 = 0x7FFFFFFF; // 2^31 - 1
+const MODULUS_BITS: u32 = 31u;
+const HALF_BITS: u32 = 16u;
+
+struct M31 {
+    data: u32,
+}
+
+struct CM31 {
+    a: M31,
+    b: M31,
+}
+
+struct QM31 {
+    a: CM31,
+    b: CM31,
+}
+
+fn m31_add(a: M31, b: M31) -> M31 {
+    return M31(partial_reduce(a.data + b.data));
+}
+
+fn m31_sub(a: M31, b: M31) -> M31 {
+    return m31_add(a, m31_neg(b));
+}
+
+fn m31_mul(a: M31, b: M31) -> M31 {
+    // Split into 16-bit parts
+    let a1 = a.data >> HALF_BITS;
+    let a0 = a.data & 0xFFFFu;
+    let b1 = b.data >> HALF_BITS;
+    let b0 = b.data & 0xFFFFu;
+
+    // Compute partial products
+    let m0 = partial_reduce(a0 * b0);
+    let m1 = partial_reduce(a0 * b1);
+    let m2 = partial_reduce(a1 * b0);
+    let m3 = partial_reduce(a1 * b1);
+
+    // Combine middle terms with reduction
+    let mid = partial_reduce(m1 + m2);
+
+    // Combine parts with partial reduction
+    let shifted_mid = partial_reduce(mid << HALF_BITS);
+    let low = partial_reduce(m0 + shifted_mid);
+
+    let high_part = partial_reduce(m3 + (mid >> HALF_BITS));
+
+    // Final combination using the Mersenne prime property: 2^31 = 1 (mod P), so
+    // 2^32 = 2 (mod P) and the bits of `low` above bit 31 fold back in with weight 1.
+    let result = partial_reduce(
+        partial_reduce((high_part << 1u))
+        + partial_reduce((low >> MODULUS_BITS))
+        + partial_reduce(low & P)
+    );
+    return M31(result);
+}
+
+fn m31_neg(a: M31) -> M31 {
+    return M31(partial_reduce(P - a.data));
+}
+
+// Squares `x` repeatedly, `n` times: returns x^(2^n).
+fn m31_square(x: M31, n: u32) -> M31 {
+    var result = x;
+    for (var i = 0u; i < n; i += 1u) {
+        result = m31_mul(result, result);
+    }
+    return result;
+}
+
+fn m31_pow5(x: M31) -> M31 {
+    return m31_mul(m31_square(x, 2u), x);
+}
+
+fn m31_inverse(x: M31) -> M31 {
+    // Computes x^(2^31 - 2) = x^(P - 2), i.e. the inverse by Fermat's little theorem,
+    // using the same addition chain as pow2147483645.
+
+    // t0 = x^5
+    let t0 = m31_pow5(x);
+
+    // t1 = x^15
+    let t1 = m31_mul(m31_square(t0, 1u), t0);
+
+    // t2 = x^125
+    let t2 = m31_mul(m31_square(t1, 3u), t0);
+
+    // t3 = x^255
+    let t3 = m31_mul(m31_square(t2, 1u), t0);
+
+    // t4 = x^65535
+    let t4 = m31_mul(m31_square(t3, 8u), t3);
+
+    // t5 = x^16777215
+    let t5 = m31_mul(m31_square(t4, 8u), t3);
+
+    // result = x^2147483520 * x^125 = x^(P - 2)
+    var result = m31_square(t5, 7u);
+    result = m31_mul(result, t2);
+
+    return result;
+}
+
+// Complex field operations for CM31
+fn cm31_add(a: CM31, b: CM31) -> CM31 {
+    return CM31(
+        m31_add(a.a, b.a),
+        m31_add(a.b, b.b)
+    );
+}
+
+fn cm31_sub(a: CM31, b: CM31) -> CM31 {
+    return CM31(
+        m31_sub(a.a, b.a),
+        m31_sub(a.b, b.b)
+    );
+}
+
+fn cm31_mul(a: CM31, b: CM31) -> CM31 {
+    // (a + bi)(c + di) = (ac - bd) + (ad + bc)i
+    let ac = m31_mul(a.a, b.a);
+    let bd = m31_mul(a.b, b.b);
+    let ad = m31_mul(a.a, b.b);
+    let bc = m31_mul(a.b, b.a);
+
+    return CM31(
+        m31_sub(ac, bd),
+        m31_add(ad, bc)
+    );
+}
+
+fn cm31_neg(a: CM31) -> CM31 {
+    return CM31(m31_neg(a.a), m31_neg(a.b));
+}
+
+fn cm31_square(x: CM31) -> CM31 {
+    return cm31_mul(x, x);
+}
+
+fn cm31_inverse(x: CM31) -> CM31 {
+    // 1/(a + bi) = (a - bi)/(a^2 + b^2)
+    let a_sq = m31_square(x.a, 1u);
+    let b_sq = m31_square(x.b, 1u);
+    let denom = m31_add(a_sq, b_sq);
+    let denom_inv = m31_inverse(denom);
+
+    // Multiply by conjugate and divide by norm
+    return cm31_mul(
+        CM31(x.a, m31_neg(x.b)),
+        CM31(denom_inv, M31(0u))
+    );
+}
+
+// Quadratic extension field operations for QM31
+fn qm31_add(a: QM31, b: QM31) -> QM31 {
+    return QM31(
+        cm31_add(a.a, b.a),
+        cm31_add(a.b, b.b)
+    );
+}
+
+fn qm31_sub(a: QM31, b: QM31) -> QM31 {
+    return QM31(
+        cm31_sub(a.a, b.a),
+        cm31_sub(a.b, b.b)
+    );
+}
+
+fn qm31_mul(a: QM31, b: QM31) -> QM31 {
+    // (a + bu)(c + du) = (ac + r*bd) + (ad + bc)u, where u^2 = r = 2 + i.
+    let ac = cm31_mul(a.a, b.a);
+    let bd = cm31_mul(a.b, b.b);
+    let ad = cm31_mul(a.a, b.b);
+    let bc = cm31_mul(a.b, b.a);
+
+    // r = 2 + i
+    let r = CM31(M31(2u), M31(1u));
+    let rbd = cm31_mul(r, bd);
+
+    return QM31(
+        cm31_add(ac, rbd),
+        cm31_add(ad, bc)
+    );
+}
+
+fn qm31_neg(a: QM31) -> QM31 {
+    return QM31(cm31_neg(a.a), cm31_neg(a.b));
+}
+
+fn qm31_square(x: QM31) -> QM31 {
+    return qm31_mul(x, x);
+}
+
+fn qm31_inverse(x: QM31) -> QM31 {
+    // (a + bu)^-1 = (a - bu)/(a^2 - (2 + i)b^2)
+    let b2 = cm31_square(x.b);
+
+    // Create 2 + i
+    let r = CM31(M31(2u), M31(1u));
+
+    let rb2 = cm31_mul(r, b2);
+    let a2 = cm31_square(x.a);
+    let denom = cm31_sub(a2, rb2);
+    let denom_inv = cm31_inverse(denom);
+
+    // Compute (a - bu) * denom^-1
+    let neg_b = cm31_neg(x.b);
+
+    return QM31(
+        cm31_mul(x.a, denom_inv),
+        cm31_mul(neg_b, denom_inv)
+    );
+}
+
+// Utility functions
+
+// Single conditional subtraction of P; `reduced < val` detects that no wrap-around
+// occurred, i.e. that val >= P.
+fn partial_reduce(val: u32) -> u32 {
+    let reduced = val - P;
+    return select(val, reduced, reduced < val);
+}
\ No newline at end of file
diff --git a/crates/prover/src/core/backend/gpu/utils.rs b/crates/prover/src/core/backend/gpu/utils.rs
new file mode 100644
index 000000000..495a3fbb2
--- /dev/null
+++ b/crates/prover/src/core/backend/gpu/utils.rs
@@ -0,0 +1,349 @@
+use wgpu::util::DeviceExt;
+
+/// Input data for the GPU computation
+#[repr(C)]
+#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
+pub struct ComputeInput {
+    pub i: u32,
+    pub domain_log_size: u32,
+    pub eval_log_size: u32,
+    pub offset: i32,
+}
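+
+// These fields mirror the arguments of
+// `crate::core::utils::offset_bit_reversed_circle_domain_index`; `offset` is the only
+// signed field, matching the `isize` offset on the CPU side.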
+
+/// Output data from the GPU computation
+#[repr(C)]
+#[derive(Copy, Clone, Debug, bytemuck::Pod, bytemuck::Zeroable)]
+pub struct ComputeOutput {
+    pub result: u32,
+}
+
+impl From<ComputeOutput> for usize {
+    fn from(output: ComputeOutput) -> Self {
+        output.result as usize
+    }
+}
+
+pub trait ByteSerialize: Sized {
+    fn as_bytes(&self) -> &[u8] {
+        unsafe {
+            std::slice::from_raw_parts(
+                (self as *const Self) as *const u8,
+                std::mem::size_of::<Self>(),
+            )
+        }
+    }
+
+    fn from_bytes(bytes: &[u8]) -> &Self {
+        assert!(bytes.len() >= std::mem::size_of::<Self>());
+        unsafe { &*(bytes.as_ptr() as *const Self) }
+    }
+}
+
+impl ByteSerialize for ComputeInput {}
+impl ByteSerialize for ComputeOutput {}
+
+/// GPU instance for utility computations
+pub struct GpuUtilsInstance {
+    device: wgpu::Device,
+    queue: wgpu::Queue,
+    input_buffer: wgpu::Buffer,
+    output_buffer: wgpu::Buffer,
+    staging_buffer: wgpu::Buffer,
+}
+
+impl GpuUtilsInstance {
+    pub async fn new<T: ByteSerialize>(input_data: &T, output_size: usize) -> Self {
+        let instance = wgpu::Instance::default();
+        let adapter = instance
+            .request_adapter(&wgpu::RequestAdapterOptions {
+                power_preference: wgpu::PowerPreference::HighPerformance,
+                compatible_surface: None,
+                force_fallback_adapter: false,
+            })
+            .await
+            .unwrap();
+
+        let (device, queue) = adapter
+            .request_device(
+                &wgpu::DeviceDescriptor {
+                    label: Some("Field Operations Device"),
+                    required_features: wgpu::Features::SHADER_INT64,
+                    required_limits: wgpu::Limits::default(),
+                    memory_hints: wgpu::MemoryHints::Performance,
+                },
+                None,
+            )
+            .await
+            .unwrap();
+
+        // Create input buffer
+        let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
+            label: Some("Field Input Buffer"),
+            contents: input_data.as_bytes(),
+            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
+        });
+
+        // Create output buffer
+        let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
+            label: Some("Field Output Buffer"),
+            size: output_size as u64,
+            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
+            mapped_at_creation: false,
+        });
+
+        // Create staging buffer for reading results
+        let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
+            label: Some("Field Staging Buffer"),
+            size: output_size as u64,
+            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
+            mapped_at_creation: false,
+        });
+
+        Self {
+            device,
+            queue,
+            input_buffer,
+            output_buffer,
+            staging_buffer,
+        }
+    }
+
+    /// Creates a compute pipeline for the operation
+    pub fn create_pipeline(
+        &self,
+        shader_source: &str,
+        entry_point: &str,
+    ) -> (wgpu::ComputePipeline, wgpu::BindGroup) {
+        let shader = self
+            .device
+            .create_shader_module(wgpu::ShaderModuleDescriptor {
+                label: Some("Field Operations Shader"),
+                source: wgpu::ShaderSource::Wgsl(shader_source.into()),
+            });
+
+        let bind_group_layout =
+            self.device
+                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
+                    entries: &[
+                        wgpu::BindGroupLayoutEntry {
+                            binding: 0,
+                            visibility: wgpu::ShaderStages::COMPUTE,
+                            ty: wgpu::BindingType::Buffer {
+                                ty: wgpu::BufferBindingType::Storage { read_only: true },
+                                has_dynamic_offset: false,
+                                min_binding_size: None,
+                            },
+                            count: None,
+                        },
+                        wgpu::BindGroupLayoutEntry {
+                            binding: 1,
+                            visibility: wgpu::ShaderStages::COMPUTE,
+                            ty: wgpu::BindingType::Buffer {
+                                ty: wgpu::BufferBindingType::Storage { read_only: false },
+                                has_dynamic_offset: false,
+                                min_binding_size: None,
+                            },
+                            count: None,
+                        },
+                    ],
+                    label: Some("Field Operations Bind Group Layout"),
+                });
+
+        let pipeline_layout = self
+            .device
+            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
+                label: Some("Field Operations Pipeline Layout"),
+                bind_group_layouts: &[&bind_group_layout],
+                push_constant_ranges: &[],
+            });
+
+        let pipeline = self
+            .device
+            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
+                label: Some("Field Operations Pipeline"),
+                layout: Some(&pipeline_layout),
+                module: &shader,
+                entry_point: Some(entry_point),
+                cache: None,
+                compilation_options: Default::default(),
+            });
+
+        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
+            layout: &bind_group_layout,
+            entries: &[
+                wgpu::BindGroupEntry {
+                    binding: 0,
+                    resource: self.input_buffer.as_entire_binding(),
+                },
+                wgpu::BindGroupEntry {
+                    binding: 1,
+                    resource: self.output_buffer.as_entire_binding(),
+                },
+            ],
+            label: Some("Field Operations Bind Group"),
+        });
+
+        (pipeline, bind_group)
+    }
+
+    /// Runs the computation on the GPU
+    async fn run_computation<T: ByteSerialize + Copy>(
+        &self,
+        pipeline: &wgpu::ComputePipeline,
+        bind_group: &wgpu::BindGroup,
+        workgroup_count: (u32, u32, u32),
+    ) -> T {
+        // Create command encoder and compute pass
+        let mut encoder = self
+            .device
+            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
+        {
+            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
+                label: None,
+                timestamp_writes: None,
+            });
+            compute_pass.set_pipeline(pipeline);
+            compute_pass.set_bind_group(0, bind_group, &[]);
+            compute_pass.dispatch_workgroups(
+                workgroup_count.0,
+                workgroup_count.1,
+                workgroup_count.2,
+            );
+        }
+
+        // Copy results to staging buffer
+        encoder.copy_buffer_to_buffer(
+            &self.output_buffer,
+            0,
+            &self.staging_buffer,
+            0,
+            self.staging_buffer.size(),
+        );
+
+        // Submit command buffer and wait for results
+        self.queue.submit(Some(encoder.finish()));
+
+        // Read results from staging buffer
+        let slice = self.staging_buffer.slice(..);
+        let (sender, receiver) = flume::bounded(1);
+        slice.map_async(wgpu::MapMode::Read, move |result| {
+            sender.send(result).unwrap();
+        });
+        self.device.poll(wgpu::Maintain::Wait);
+
+        receiver.recv_async().await.unwrap().unwrap();
+        let data = slice.get_mapped_range();
+        let result = *T::from_bytes(&data);
+        drop(data);
+        self.staging_buffer.unmap();
+
+        result
+    }
+}
+
+#[derive(Debug)]
+pub enum GpuUtilsOperation {
+    OffsetBitReversedCircleDomainIndex,
+}
+
+impl GpuUtilsOperation {
+    pub fn shader_source(&self) -> String {
+        let base_source = include_str!("utils.wgsl");
+        let qm31_source = include_str!("qm31.wgsl");
+
+        let inputs = r#"
+            struct Inputs {
+                i: u32,
+                domain_log_size: u32,
+                eval_log_size: u32,
+                offset: i32,
+            }
+
+            @group(0) @binding(0) var<storage, read> inputs: Inputs;
+        "#;
+
+        let output = r#"
+            struct Output {
+                result: u32,
+            }
+
+            @group(0) @binding(1) var<storage, read_write> output: Output;
+        "#;
+
+        let operation = match self {
+            GpuUtilsOperation::OffsetBitReversedCircleDomainIndex => {
+                r#"
+                @compute @workgroup_size(1)
+                fn main() {
+                    let i = inputs.i;
+                    let domain_log_size = inputs.domain_log_size;
+                    let eval_log_size = inputs.eval_log_size;
+                    let offset = inputs.offset;
+
+                    let result = offset_bit_reversed_circle_domain_index(i, domain_log_size, eval_log_size, offset);
+                    output.result = result;
+                }
+                "#
+            }
+        };
+
+        format!("{base_source}\n{qm31_source}\n{inputs}\n{output}\n{operation}")
+    }
+}
+
+/// Computes the offset bit reversed circle domain index using the GPU
+pub async fn compute_offset_bit_reversed_circle_domain_index(
+    i: usize,
+    domain_log_size: u32,
+    eval_log_size: u32,
+    offset: i32,
+) -> usize {
+    let input = ComputeInput {
+        i: i as u32,
+        domain_log_size,
+        eval_log_size,
+        offset,
+    };
+
+    let instance = GpuUtilsInstance::new(&input, std::mem::size_of::<ComputeOutput>()).await;
+
+    let shader_source = GpuUtilsOperation::OffsetBitReversedCircleDomainIndex.shader_source();
+    let (pipeline, bind_group) = instance.create_pipeline(&shader_source, "main");
+
+    let gpu_result = instance
+        .run_computation::<ComputeOutput>(&pipeline, &bind_group, (1, 1, 1))
+        .await;
+    gpu_result.into()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::core::utils::offset_bit_reversed_circle_domain_index as cpu_offset_bit_reversed_circle_domain_index;
+
+    #[test]
+    fn test_offset_bit_reversed_circle_domain_index() {
+        // Test parameters from the CPU test
+        let domain_log_size = 3;
+        let eval_log_size = 6;
+        let initial_index = 5;
+        let offset = -2;
+
+        let gpu_result = pollster::block_on(compute_offset_bit_reversed_circle_domain_index(
+            initial_index,
+            domain_log_size,
+            eval_log_size,
+            offset,
+        ));
+
+        let cpu_result = cpu_offset_bit_reversed_circle_domain_index(
+            initial_index,
+            domain_log_size,
+            eval_log_size,
+            offset as isize,
+        );
+
+        assert_eq!(gpu_result, cpu_result, "GPU and CPU results should match");
+    }
+}
diff --git a/crates/prover/src/core/backend/gpu/utils.wgsl b/crates/prover/src/core/backend/gpu/utils.wgsl
new file mode 100644
index 000000000..604b14eab
--- /dev/null
+++ b/crates/prover/src/core/backend/gpu/utils.wgsl
@@ -0,0 +1,52 @@
+// This shader contains utility functions for bit manipulation and index transformations.
+// It is stateless and can be used as a library in other shaders.
+
+/// Returns the bit reversed index of `i`, which is represented by `log_size` bits.
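+/// For example, with log_size = 3, i = 3 (0b011) maps to 6 (0b110).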
+fn bit_reverse_index(i: u32, log_size: u32) -> u32 {
+    if (log_size == 0u) {
+        return i;
+    }
+    let bits = reverse_bits_u32(i);
+    return bits >> (32u - log_size);
+}
+
+fn reverse_bits_u32(x: u32) -> u32 {
+    var x_mut = x;
+    var result = 0u;
+
+    for (var i = 0u; i < 32u; i = i + 1u) {
+        result = (result << 1u) | (x_mut & 1u);
+        x_mut = x_mut >> 1u;
+    }
+
+    return result;
+}
+
+/// Returns the index of the offset element in a bit reversed circle evaluation
+/// of log size `eval_log_size` relative to a smaller domain of size `domain_log_size`.
+fn offset_bit_reversed_circle_domain_index(
+    i: u32,
+    domain_log_size: u32,
+    eval_log_size: u32,
+    offset: i32,
+) -> u32 {
+    var prev_index = bit_reverse_index(i, eval_log_size);
+    let half_size = 1u << (eval_log_size - 1u);
+    let step_size = i32(1u << (eval_log_size - domain_log_size - 1u)) * offset;
+
+    if (prev_index < half_size) {
+        let temp = i32(prev_index) + step_size;
+        // Implement rem_euclid for positive modulo
+        let m = i32(half_size);
+        let rem = temp % m;
+        prev_index = u32(select(rem + m, rem, rem >= 0));
+    } else {
+        let temp = i32(prev_index - half_size) - step_size;
+        // Implement rem_euclid for positive modulo
+        let m = i32(half_size);
+        let rem = temp % m;
+        prev_index = u32(select(rem + m, rem, rem >= 0)) + half_size;
+    }
+
+    return bit_reverse_index(prev_index, eval_log_size);
+}
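+
+// Worked example with the values from the Rust test in `utils.rs`:
+// domain_log_size = 3, eval_log_size = 6, offset = -2 gives half_size = 32 and
+// step_size = (1 << 2) * -2 = -8, so an index in the first half moves back 8 slots
+// modulo 32 between the two bit reversals.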