Skip to content

Commit

Permalink
Adding fallback lib name options for dynamic loading (#240)
Browse files Browse the repository at this point in the history
* #219 #232 Adding fallback options for dynamic loading

* Adding another option

* Fixing cargo check

* Fix clippy
  • Loading branch information
coreylowman authored May 30, 2024
1 parent 5d99bc6 commit 1fa82db
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 31 deletions.
63 changes: 39 additions & 24 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,43 @@ fn main() {
println!("cargo:rerun-if-env-changed=CUDA_PATH");
println!("cargo:rerun-if-env-changed=CUDA_TOOLKIT_ROOT_DIR");

#[cfg(not(any(
feature = "cuda-version-from-build-system",
feature = "cuda-12050",
feature = "cuda-12040",
feature = "cuda-12030",
feature = "cuda-12020",
feature = "cuda-12010",
feature = "cuda-12000",
feature = "cuda-11080",
feature = "cuda-11070",
)))]
compile_error!("Must specify one of the following features: [cuda-version-from-build-system, cuda-12050, cuda-12040, cuda-12030, cuda-12020, cuda-12010, cuda-12000, cuda-11080, cuda-11070]");

#[cfg(feature = "cuda-version-from-build-system")]
cuda_version_from_build_system();
let (major, minor): (usize, usize) = if cfg!(feature = "cuda-12050") {
(12, 5)
} else if cfg!(feature = "cuda-12040") {
(12, 4)
} else if cfg!(feature = "cuda-12030") {
(12, 3)
} else if cfg!(feature = "cuda-12020") {
(12, 2)
} else if cfg!(feature = "cuda-12010") {
(12, 1)
} else if cfg!(feature = "cuda-12000") {
(12, 0)
} else if cfg!(feature = "cuda-11080") {
(11, 8)
} else if cfg!(feature = "cuda-11070") {
(11, 7)
} else {
#[cfg(not(feature = "cuda-version-from-build-system"))]
panic!("Must specify one of the following features: [cuda-version-from-build-system, cuda-12050, cuda-12040, cuda-12030, cuda-12020, cuda-12010, cuda-12000, cuda-11080, cuda-11070]");

#[cfg(feature = "cuda-version-from-build-system")]
{
let (major, minor) = cuda_version_from_build_system();
println!("cargo:rustc-cfg=feature=\"cuda-{major}0{minor}0\"");
(major, minor)
}
};

println!("cargo:rustc-env=CUDA_MAJOR_VERSION={major}");
println!("cargo:rustc-env=CUDA_MINOR_VERSION={minor}");

#[cfg(feature = "dynamic-linking")]
dynamic_linking();
}

#[allow(unused)]
fn cuda_version_from_build_system() {
fn cuda_version_from_build_system() -> (usize, usize) {
let toolkit_root = root_candidates()
.find(|path| path.join("include").join("cuda.h").is_file())
.unwrap_or_else(|| {
Expand All @@ -45,14 +60,14 @@ fn cuda_version_from_build_system() {
let key = "CUDA_VERSION ";
let start = key.len() + contents.find(key).unwrap();
match contents[start..].lines().next().unwrap() {
"12050" => println!("cargo:rustc-cfg=feature=\"cuda-12050\""),
"12040" => println!("cargo:rustc-cfg=feature=\"cuda-12040\""),
"12030" => println!("cargo:rustc-cfg=feature=\"cuda-12030\""),
"12020" => println!("cargo:rustc-cfg=feature=\"cuda-12020\""),
"12010" => println!("cargo:rustc-cfg=feature=\"cuda-12010\""),
"12000" => println!("cargo:rustc-cfg=feature=\"cuda-12000\""),
"11080" => println!("cargo:rustc-cfg=feature=\"cuda-11080\""),
"11070" => println!("cargo:rustc-cfg=feature=\"cuda-11070\""),
"12050" => (12, 5),
"12040" => (12, 4),
"12030" => (12, 3),
"12020" => (12, 2),
"12010" => (12, 1),
"12000" => (12, 0),
"11080" => (11, 8),
"11070" => (11, 7),
v => panic!("Unsupported cuda toolkit version: `{v}`. Please raise a github issue."),
}
}
Expand Down
13 changes: 12 additions & 1 deletion src/cublas/sys/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,16 @@ pub use sys_12050::*;

pub unsafe fn lib() -> &'static Lib {
static LIB: std::sync::OnceLock<Lib> = std::sync::OnceLock::new();
LIB.get_or_init(|| Lib::new(libloading::library_filename("cublas")).unwrap())
LIB.get_or_init(|| {
let lib_name = "cublas";
let choices = crate::get_lib_name_candidates(lib_name);
for choice in choices.iter() {
if let Ok(lib) = Lib::new(libloading::library_filename(choice)) {
return lib;
}
}
panic!(
"Unable to find {lib_name} lib under the names {choices:?}. Please open GitHub issue."
);
})
}
13 changes: 12 additions & 1 deletion src/cublaslt/sys/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,16 @@ pub use sys_12050::*;

pub unsafe fn lib() -> &'static Lib {
static LIB: std::sync::OnceLock<Lib> = std::sync::OnceLock::new();
LIB.get_or_init(|| Lib::new(libloading::library_filename("cublasLt")).unwrap())
LIB.get_or_init(|| {
let lib_name = "cublasLt";
let choices = crate::get_lib_name_candidates(lib_name);
for choice in choices.iter() {
if let Ok(lib) = Lib::new(libloading::library_filename(choice)) {
return lib;
}
}
panic!(
"Unable to find {lib_name} lib under the names {choices:?}. Please open GitHub issue."
);
})
}
13 changes: 12 additions & 1 deletion src/cudnn/sys/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,16 @@ pub use sys_12050::*;

pub unsafe fn lib() -> &'static Lib {
static LIB: std::sync::OnceLock<Lib> = std::sync::OnceLock::new();
LIB.get_or_init(|| Lib::new(libloading::library_filename("cudnn")).unwrap())
LIB.get_or_init(|| {
let lib_name = "cudnn";
let choices = crate::get_lib_name_candidates(lib_name);
for choice in choices.iter() {
if let Ok(lib) = Lib::new(libloading::library_filename(choice)) {
return lib;
}
}
panic!(
"Unable to find {lib_name} lib under the names {choices:?}. Please open GitHub issue."
);
})
}
13 changes: 12 additions & 1 deletion src/curand/sys/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,16 @@ pub use sys_12050::*;

pub unsafe fn lib() -> &'static Lib {
static LIB: std::sync::OnceLock<Lib> = std::sync::OnceLock::new();
LIB.get_or_init(|| Lib::new(libloading::library_filename("curand")).unwrap())
LIB.get_or_init(|| {
let lib_name = "curand";
let choices = crate::get_lib_name_candidates(lib_name);
for choice in choices.iter() {
if let Ok(lib) = Lib::new(libloading::library_filename(choice)) {
return lib;
}
}
panic!(
"Unable to find {lib_name} lib under the names {choices:?}. Please open GitHub issue."
);
})
}
13 changes: 12 additions & 1 deletion src/driver/sys/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,16 @@ pub use sys_12050::*;

pub unsafe fn lib() -> &'static Lib {
static LIB: std::sync::OnceLock<Lib> = std::sync::OnceLock::new();
LIB.get_or_init(|| Lib::new(libloading::library_filename("cuda")).unwrap())
LIB.get_or_init(|| {
let lib_name = "cuda";
let choices = [lib_name, "nvcuda"];
for choice in choices {
if let Ok(lib) = Lib::new(libloading::library_filename(choice)) {
return lib;
}
}
panic!(
"Unable to find {lib_name} lib under the names {choices:?}. Please open GitHub issue."
);
})
}
23 changes: 23 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,26 @@ pub mod nccl;
pub mod nvrtc;

pub mod types;

pub(crate) fn get_lib_name_candidates(lib_name: &str) -> std::vec::Vec<std::string::String> {
let pointer_width = if cfg!(target_pointer_width = "32") {
"32"
} else if cfg!(target_pointer_width = "64") {
"64"
} else {
panic!("Unsupported target pointer width")
};

let major = env!("CUDA_MAJOR_VERSION");
let minor = env!("CUDA_MINOR_VERSION");

[
lib_name.into(),
std::format!("{lib_name}{pointer_width}"),
std::format!("{lib_name}{pointer_width}_{major}"),
std::format!("{lib_name}{pointer_width}_{major}{minor}"),
std::format!("{lib_name}{pointer_width}_{major}{minor}_0"),
std::format!("{lib_name}{pointer_width}_{major}0_{minor}"),
]
.into()
}
13 changes: 12 additions & 1 deletion src/nccl/sys/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,16 @@ pub use sys_12050::*;

pub unsafe fn lib() -> &'static Lib {
static LIB: std::sync::OnceLock<Lib> = std::sync::OnceLock::new();
LIB.get_or_init(|| Lib::new(libloading::library_filename("nccl")).unwrap())
LIB.get_or_init(|| {
let lib_name = "nccl";
let choices = crate::get_lib_name_candidates(lib_name);
for choice in choices.iter() {
if let Ok(lib) = Lib::new(libloading::library_filename(choice)) {
return lib;
}
}
panic!(
"Unable to find {lib_name} lib under the names {choices:?}. Please open GitHub issue."
);
})
}
13 changes: 12 additions & 1 deletion src/nvrtc/sys/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,16 @@ pub use sys_12050::*;

pub unsafe fn lib() -> &'static Lib {
static LIB: std::sync::OnceLock<Lib> = std::sync::OnceLock::new();
LIB.get_or_init(|| Lib::new(libloading::library_filename("nvrtc")).unwrap())
LIB.get_or_init(|| {
let lib_name = "nvrtc";
let choices = crate::get_lib_name_candidates(lib_name);
for choice in choices.iter() {
if let Ok(lib) = Lib::new(libloading::library_filename(choice)) {
return lib;
}
}
panic!(
"Unable to find {lib_name} lib under the names {choices:?}. Please open GitHub issue."
);
})
}

0 comments on commit 1fa82db

Please sign in to comment.