Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add qm31 wgpu implementation #3

Open
wants to merge 4 commits into
base: gpu-gen-trace
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions crates/prover/src/core/backend/gpu/fraction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
use std::borrow::Cow;

use super::compute_composition_polynomial::GpuQM31;
use super::gpu_common::{ByteSerialize, GpuComputeInstance, GpuOperation};
use crate::core::fields::qm31::QM31;
use crate::core::lookups::utils::Fraction;

#[repr(C)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct GpuFraction {
pub numerator: GpuQM31,
pub denominator: GpuQM31,
}

#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct ComputeInput {
pub first: GpuFraction,
pub second: GpuFraction,
}

#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct ComputeOutput {
pub result: GpuFraction,
}

impl ByteSerialize for ComputeInput {}
impl ByteSerialize for ComputeOutput {}

impl From<Fraction<QM31, QM31>> for GpuFraction {
fn from(value: Fraction<QM31, QM31>) -> Self {
GpuFraction {
numerator: GpuQM31::from(value.numerator),
denominator: GpuQM31::from(value.denominator),
}
}
}

pub enum FractionOperation {
Add,
}

impl GpuOperation for FractionOperation {
fn shader_source(&self) -> Cow<'static, str> {
let base_source = include_str!("fraction.wgsl");
let qm31_source = include_str!("qm31.wgsl");

let inputs = r#"
struct ComputeInput {
first: Fraction,
second: Fraction,
}

@group(0) @binding(0) var<storage, read> input: ComputeInput;
"#;

let output = r#"
struct ComputeOutput {
result: Fraction,
}

@group(0) @binding(1) var<storage, read_write> output: ComputeOutput;
"#;

let operation = match self {
FractionOperation::Add => {
r#"
@compute @workgroup_size(1)
fn main() {
output.result = fraction_add(input.first, input.second);
}
"#
}
};

format!("{qm31_source}\n{base_source}\n{inputs}\n{output}\n{operation}").into()
}
}

pub async fn compute_fraction_operation(
operation: FractionOperation,
first: Fraction<QM31, QM31>,
second: Fraction<QM31, QM31>,
) -> ComputeOutput {
let input = ComputeInput {
first: first.into(),
second: second.into(),
};

let instance = GpuComputeInstance::new(&input, std::mem::size_of::<ComputeOutput>()).await;
let (pipeline, bind_group) =
instance.create_pipeline(&operation.shader_source(), operation.entry_point());

let output = instance
.run_computation::<ComputeOutput>(&pipeline, &bind_group, (1, 1, 1))
.await;

output
}

#[cfg(test)]
mod tests {
use super::*;
use crate::core::fields::qm31::QM31;

#[test]
fn test_fraction_add() {
// CPU implementation
let cpu_a = Fraction::new(
QM31::from_u32_unchecked(1u32, 0u32, 0u32, 0u32),
QM31::from_u32_unchecked(3u32, 0u32, 0u32, 0u32),
);
let cpu_b = Fraction::new(
QM31::from_u32_unchecked(2u32, 0u32, 0u32, 0u32),
QM31::from_u32_unchecked(6u32, 0u32, 0u32, 0u32),
);
let cpu_result = cpu_a + cpu_b;

// GPU implementation
let gpu_result = pollster::block_on(compute_fraction_operation(
FractionOperation::Add,
cpu_a,
cpu_b,
));

assert_eq!(cpu_result.numerator, gpu_result.result.numerator.into());
assert_eq!(cpu_result.denominator, gpu_result.result.denominator.into());
}
}
17 changes: 17 additions & 0 deletions crates/prover/src/core/backend/gpu/fraction.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// This shader contains implementations for fraction operations.
// It is stateless and can be used as a library in other shaders.

struct Fraction {
numerator: QM31,
denominator: QM31,
}

// Add two fractions: (a/b + c/d) = (ad + bc)/(bd)
fn fraction_add(a: Fraction, b: Fraction) -> Fraction {
let numerator = qm31_add(
qm31_mul(a.numerator, b.denominator),
qm31_mul(b.numerator, a.denominator)
);
let denominator = qm31_mul(a.denominator, b.denominator);
return Fraction(numerator, denominator);
}
226 changes: 226 additions & 0 deletions crates/prover/src/core/backend/gpu/gpu_common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
use std::borrow::Cow;

use wgpu::util::DeviceExt;

/// Common trait for GPU input/output types
pub trait ByteSerialize: Sized {
fn as_bytes(&self) -> &[u8] {
unsafe {
std::slice::from_raw_parts(
(self as *const Self) as *const u8,
std::mem::size_of::<Self>(),
)
}
}

fn from_bytes(bytes: &[u8]) -> &Self {
assert!(bytes.len() >= std::mem::size_of::<Self>());
unsafe { &*(bytes.as_ptr() as *const Self) }
}
}

/// Base GPU instance for field computations
pub struct GpuComputeInstance {
pub device: wgpu::Device,
pub queue: wgpu::Queue,
pub input_buffer: wgpu::Buffer,
pub output_buffer: wgpu::Buffer,
pub staging_buffer: wgpu::Buffer,
}

impl GpuComputeInstance {
pub async fn new<T: ByteSerialize>(input_data: &T, output_size: usize) -> Self {
let instance = wgpu::Instance::default();
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
compatible_surface: None,
force_fallback_adapter: false,
})
.await
.unwrap();

let (device, queue) = adapter
.request_device(
&wgpu::DeviceDescriptor {
label: Some("Field Operations Device"),
required_features: wgpu::Features::SHADER_INT64,
required_limits: wgpu::Limits::default(),
memory_hints: wgpu::MemoryHints::Performance,
},
None,
)
.await
.unwrap();

// Create input buffer
let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Field Input Buffer"),
contents: input_data.as_bytes(),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});

// Create output buffer
let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Field Output Buffer"),
size: output_size as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

// Create staging buffer for reading results
let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Field Staging Buffer"),
size: output_size as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});

Self {
device,
queue,
input_buffer,
output_buffer,
staging_buffer,
}
}

pub fn create_pipeline(
&self,
shader_source: &str,
entry_point: &str,
) -> (wgpu::ComputePipeline, wgpu::BindGroup) {
let shader = self
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("Field Operations Shader"),
source: wgpu::ShaderSource::Wgsl(shader_source.into()),
});

let bind_group_layout =
self.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
label: Some("Field Operations Bind Group Layout"),
});

let pipeline_layout = self
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("Field Operations Pipeline Layout"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});

let pipeline = self
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Field Operations Pipeline"),
layout: Some(&pipeline_layout),
module: &shader,
entry_point: Some(entry_point),
cache: None,
compilation_options: Default::default(),
});

let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: self.input_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: self.output_buffer.as_entire_binding(),
},
],
label: Some("Field Operations Bind Group"),
});

(pipeline, bind_group)
}

pub async fn run_computation<T: ByteSerialize + Copy>(
&self,
pipeline: &wgpu::ComputePipeline,
bind_group: &wgpu::BindGroup,
workgroup_count: (u32, u32, u32),
) -> T {
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Field Operations Encoder"),
});

{
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("Field Operations Compute Pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(pipeline);
compute_pass.set_bind_group(0, bind_group, &[]);
compute_pass.dispatch_workgroups(
workgroup_count.0,
workgroup_count.1,
workgroup_count.2,
);
}

encoder.copy_buffer_to_buffer(
&self.output_buffer,
0,
&self.staging_buffer,
0,
self.staging_buffer.size(),
);

self.queue.submit(Some(encoder.finish()));

let buffer_slice = self.staging_buffer.slice(..);
let (tx, rx) = flume::bounded(1);
buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
tx.send(result).unwrap();
});

self.device.poll(wgpu::Maintain::wait());

rx.recv_async().await.unwrap().unwrap();
let data = buffer_slice.get_mapped_range();
let result = *T::from_bytes(&data);
drop(data);
self.staging_buffer.unmap();

result
}
}

/// Trait for GPU operations that require shader source generation
pub trait GpuOperation {
fn shader_source(&self) -> Cow<'static, str>;

fn entry_point(&self) -> &'static str {
"main"
}
}
4 changes: 4 additions & 0 deletions crates/prover/src/core/backend/gpu/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
pub mod fraction;
pub mod gen_trace;
pub mod gen_trace_interpolate_columns;
pub mod gen_trace_parallel;
pub mod gen_trace_parallel_no_packed;
pub mod gen_trace_parallel_no_packed_parallel_columns;
pub mod gpu_common;
pub mod qm31;
pub mod utils;
Loading