Skip to content

Commit

Permalink
rust: Iterate eko/lib.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
felixhekhorn committed Jan 14, 2025
1 parent 680dae8 commit 5046f50
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 16 deletions.
31 changes: 17 additions & 14 deletions crates/eko/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ use std::ffi::c_void;
pub mod bib;
pub mod mellin;

/// Wrapper to pass arguments back to Python
/// Wrapper to pass arguments back to Python.
struct RawCmplx {
re: Vec<f64>,
im: Vec<f64>,
}

/// Map tensors to c-ordered list
/// (res is a vector with dim order_qcd filled with DIMxDIM matrices)
/// Map tensor with shape (o,d,d) to c-ordered list.
///
/// This is needed for the QCD singlet.
fn unravel<const DIM: usize>(res: Vec<[[Complex<f64>; DIM]; DIM]>, order_qcd: usize) -> RawCmplx {
let mut target = RawCmplx {
re: Vec::<f64>::new(),
Expand All @@ -31,8 +32,9 @@ fn unravel<const DIM: usize>(res: Vec<[[Complex<f64>; DIM]; DIM]>, order_qcd: us
target
}

/// Map tensors to c-ordered list in the QED singlet and valence case
/// (res is a matrix with dim order_qcd x order_qed filled with DIMxDIM matrices)
/// Map tensor with shape (o,o',d,d) to c-ordered list.
///
/// This is needed for the QED singlet and valence.
fn unravel_qed<const DIM: usize>(
res: Vec<Vec<[[Complex<f64>; DIM]; DIM]>>,
order_qcd: usize,
Expand All @@ -55,8 +57,9 @@ fn unravel_qed<const DIM: usize>(
target
}

/// Map tensors to c-ordered list in the QED non-singlet case
/// (res is a matrix with dim order_qcd x order_qed filled with complex numbers)
/// Map tensor with shape (o,o',d) to c-ordered list.
///
/// This is needed for the QED non-singlet.
fn unravel_qed_ns(res: Vec<Vec<Complex<f64>>>, order_qcd: usize, order_qed: usize) -> RawCmplx {
let mut target = RawCmplx {
re: Vec::<f64>::new(),
Expand All @@ -71,13 +74,13 @@ fn unravel_qed_ns(res: Vec<Vec<Complex<f64>>>, order_qcd: usize, order_qed: usiz
target
}

/// QCD intergration kernel inside quad.
/// Intergration kernel inside quad.
///
/// # Safety
/// This is the connection from Python, so we don't know what is on the other side.
#[no_mangle]
pub unsafe extern "C" fn rust_quad_ker_qcd(u: f64, rargs: *mut c_void) -> f64 {
let args = *(rargs as *mut QuadQCDargs);
pub unsafe extern "C" fn rust_quad_ker(u: f64, rargs: *mut c_void) -> f64 {
let args = *(rargs as *mut QuadArgs);

let is_singlet = (100 == args.mode0)
|| (21 == args.mode0)
Expand Down Expand Up @@ -248,7 +251,7 @@ type PyQuadKerQCDT = unsafe extern "C" fn(
#[allow(non_snake_case)]
#[repr(C)]
#[derive(Clone, Copy)]
pub struct QuadQCDargs {
pub struct QuadArgs {
pub order_qcd: usize,
pub order_qed: usize,
pub mode0: u16,
Expand Down Expand Up @@ -329,13 +332,13 @@ pub unsafe extern "C" fn my_py(
/// Return empty additional arguments.
///
/// This is required to make the arguments part of the API, otherwise it won't be added to the compiled
/// package (since it does not appear in the signature of `rust_quad_ker_qcd`).
/// package (since it does not appear in the signature of `rust_quad_ker`).
///
/// # Safety
/// This is the connection from and back to Python, so we don't know what is on the other side.
#[no_mangle]
pub unsafe extern "C" fn empty_qcd_args() -> QuadQCDargs {
QuadQCDargs {
pub unsafe extern "C" fn empty_args() -> QuadArgs {
QuadArgs {
order_qcd: 0,
order_qed: 0,
mode0: 0,
Expand Down
4 changes: 2 additions & 2 deletions src/eko/evolution_operator/__init__.py.patch
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ index bd1b19d6..de87651c 100644
+ labels = self.labels
start_time = time.perf_counter()
+ # start preparing C arguments
+ cfg = ekors.lib.empty_qcd_args()
+ cfg = ekors.lib.empty_args()
+ cfg.order_qcd = self.order[0]
+ cfg.order_qed = self.order[1]
+ cfg.is_polarized = self.config["polarized"]
Expand Down Expand Up @@ -680,7 +680,7 @@ index bd1b19d6..de87651c 100644
+ cfg.mode1 = label[1]
+ # construct the low level object
+ func = LowLevelCallable(
+ ekors.lib.rust_quad_ker_qcd, ekors.ffi.addressof(cfg)
+ ekors.lib.rust_quad_ker, ekors.ffi.addressof(cfg)
+ )
res = integrate.quad(
- self.quad_ker(
Expand Down

0 comments on commit 5046f50

Please sign in to comment.