From 471eba3069491332301bf7af218069134be0e65e Mon Sep 17 00:00:00 2001 From: Andrew Westberg Date: Mon, 9 Sep 2024 14:03:01 +0000 Subject: [PATCH] fix[pallas-math]: use malachite as default --- .github/workflows/validate.yml | 31 +- pallas-crypto/Cargo.toml | 3 +- pallas-crypto/src/nonce/epoch_nonce.rs | 86 -- pallas-crypto/src/nonce/mod.rs | 205 +++- pallas-crypto/src/nonce/rolling_nonce.rs | 168 --- pallas-crypto/src/vrf/mod.rs | 143 ++- pallas-math/Cargo.toml | 11 +- pallas-math/src/lib.rs | 14 +- pallas-math/src/math.rs | 290 ++++- pallas-math/src/math_gmp.rs | 1062 ----------------- .../src/{math_num.rs => math_malachite.rs} | 352 ++++-- 11 files changed, 832 insertions(+), 1533 deletions(-) delete mode 100644 pallas-crypto/src/nonce/epoch_nonce.rs delete mode 100644 pallas-crypto/src/nonce/rolling_nonce.rs delete mode 100644 pallas-math/src/math_gmp.rs rename pallas-math/src/{math_num.rs => math_malachite.rs} (59%) diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index dd24b439..84cbfb30 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -26,31 +26,18 @@ jobs: toolchain: ${{ matrix.rust }} - name: Run cargo check Windows - if: matrix.os == 'windows-latest' - run: cargo check --no-default-features --features num - - - name: Run cargo check - if: matrix.os != 'windows-latest' run: cargo check test: name: Test Suite - runs-on: ubuntu-latest - steps: - - name: Checkout sources - uses: actions/checkout@v2 - - - name: Install stable toolchain - uses: dtolnay/rust-toolchain@stable - with: - toolchain: stable + strategy: + fail-fast: false + matrix: + os: [ windows-latest, ubuntu-latest, macOS-latest ] + rust: [ stable ] - - name: Run cargo test - run: cargo test + runs-on: ${{ matrix.os }} - test-windows: - name: Test Suite Windows - runs-on: windows-latest steps: - name: Checkout sources uses: actions/checkout@v2 @@ -61,7 +48,7 @@ jobs: toolchain: stable - name: Run cargo test - run: cargo test --no-default-features --features num + run: cargo test lints: name: Lints @@ -80,6 +67,4 @@ jobs: run: cargo fmt --all -- --check - name: Run cargo clippy - run: | - cargo clippy -- -D warnings - cargo clippy --no-default-features --features num -- -D warnings \ No newline at end of file + run: cargo clippy -- -D warnings diff --git a/pallas-crypto/Cargo.toml b/pallas-crypto/Cargo.toml index 6ba4181e..9b925777 100644 --- a/pallas-crypto/Cargo.toml +++ b/pallas-crypto/Cargo.toml @@ -21,10 +21,9 @@ rand_core = "0.6" pallas-codec = { version = "=0.30.2", path = "../pallas-codec" } serde = "1.0.143" -# FIXME: This needs to be a properly deployed crate from the input-output-hk/vrf repository after my PR is merged # The vrf crate has not been fully tested in production environments and still has several upstream issues that # are open PRs but not merged yet. -vrf_dalek = { git = "https://github.com/AndrewWestberg/vrf", rev = "6fc1440b197098feb6d75e2b71517019b8e2e9c2" } +vrf_dalek = { git = "https://github.com/input-output-hk/vrf", rev = "2b6bffe9e1506341601cae51e02659d12389c5be" } [dev-dependencies] itertools = "0.13" diff --git a/pallas-crypto/src/nonce/epoch_nonce.rs b/pallas-crypto/src/nonce/epoch_nonce.rs deleted file mode 100644 index a3ea955d..00000000 --- a/pallas-crypto/src/nonce/epoch_nonce.rs +++ /dev/null @@ -1,86 +0,0 @@ -use crate::hash::{Hash, Hasher}; -use crate::nonce::{Error, NonceGenerator}; - -/// A nonce generator that calculates an epoch nonce from the eta_v value (nc) of the block right before -/// the stability window and the block hash of the first block from the previous epoch (nh). -#[derive(Debug, Clone)] -pub struct EpochNonceGenerator { - pub nonce: Hash<32>, -} - -impl EpochNonceGenerator { - /// Create a new [`EpochNonceGenerator`] generator. - /// params: - /// - nc: the eta_v value of the block right before the stability window. - /// - nh: the block hash of the first block from the previous epoch. - /// - extra_entropy: optional extra entropy to be used in the nonce calculation. - pub fn new(nc: Hash<32>, nh: Hash<32>, extra_entropy: Option<&[u8]>) -> Self { - let mut hasher = Hasher::<256>::new(); - hasher.input(nc.as_ref()); - hasher.input(nh.as_ref()); - let epoch_nonce = hasher.finalize(); - if let Some(extra_entropy) = extra_entropy { - let mut hasher = Hasher::<256>::new(); - hasher.input(epoch_nonce.as_ref()); - hasher.input(extra_entropy); - let extra_nonce = hasher.finalize(); - Self { nonce: extra_nonce } - } else { - Self { nonce: epoch_nonce } - } - } -} - -impl NonceGenerator for EpochNonceGenerator { - fn finalize(&mut self) -> Result, Error> { - Ok(self.nonce) - } -} - -#[cfg(test)] -mod tests { - use itertools::izip; - - use crate::hash::Hash; - - use super::*; - - #[test] - fn test_epoch_nonce() { - let nc_values = vec![ - hex::decode("e86e133bd48ff5e79bec43af1ac3e348b539172f33e502d2c96735e8c51bd04d") - .unwrap(), - hex::decode("d1340a9c1491f0face38d41fd5c82953d0eb48320d65e952414a0c5ebaf87587") - .unwrap(), - ]; - let nh_values = vec![ - hex::decode("d7a1ff2a365abed59c9ae346cba842b6d3df06d055dba79a113e0704b44cc3e9") - .unwrap(), - hex::decode("ee91d679b0a6ce3015b894c575c799e971efac35c7a8cbdc2b3f579005e69abd") - .unwrap(), - ]; - let ee = hex::decode("d982e06fd33e7440b43cefad529b7ecafbaa255e38178ad4189a37e4ce9bf1fa") - .unwrap(); - let extra_entropy_values: Vec> = vec![None, Some(&ee)]; - let expected_epoch_nonces = vec![ - hex::decode("e536a0081ddd6d19786e9d708a85819a5c3492c0da7349f59c8ad3e17e4acd98") - .unwrap(), - hex::decode("0022cfa563a5328c4fb5c8017121329e964c26ade5d167b1bd9b2ec967772b60") - .unwrap(), - ]; - - for (nc_value, nh_value, extra_entropy_value, expected_epoch_nonce) in izip!( - nc_values.iter(), - nh_values.iter(), - extra_entropy_values.iter(), - expected_epoch_nonces.iter() - ) { - let nc: Hash<32> = Hash::from(nc_value.as_slice()); - let nh: Hash<32> = Hash::from(nh_value.as_slice()); - let extra_entropy = *extra_entropy_value; - let mut epoch_nonce = EpochNonceGenerator::new(nc, nh, extra_entropy); - let nonce = epoch_nonce.finalize().unwrap(); - assert_eq!(nonce.as_ref(), expected_epoch_nonce.as_slice()); - } - } -} diff --git a/pallas-crypto/src/nonce/mod.rs b/pallas-crypto/src/nonce/mod.rs index 174b7498..8b59d79d 100644 --- a/pallas-crypto/src/nonce/mod.rs +++ b/pallas-crypto/src/nonce/mod.rs @@ -1,17 +1,198 @@ -use thiserror::Error; +use crate::hash::{Hash, Hasher}; -use crate::hash::Hash; - -pub mod epoch_nonce; -pub mod rolling_nonce; +/// A nonce generator function that calculates an epoch nonce from the eta_v value (nc) of the block right before +/// the stability window and the block hash of the first block from the previous epoch (nh). +pub fn generate_epoch_nonce(nc: Hash<32>, nh: Hash<32>, extra_entropy: Option<&[u8]>) -> Hash<32> { + let mut hasher = Hasher::<256>::new(); + hasher.input(nc.as_ref()); + hasher.input(nh.as_ref()); + let epoch_nonce = hasher.finalize(); + if let Some(extra_entropy) = extra_entropy { + let mut hasher = Hasher::<256>::new(); + hasher.input(epoch_nonce.as_ref()); + hasher.input(extra_entropy); + hasher.finalize() + } else { + epoch_nonce + } +} -#[derive(Error, Debug)] -pub enum Error { - #[error("Nonce error: {0}")] - Nonce(String), +/// A nonce generator function that calculates a rolling nonce (eta_v) by applying each cardano block in +/// the shelley era and beyond. These rolling nonce values are used to help calculate the epoch +/// nonce values used in consensus for the Ouroboros protocols (tpraos, praos, cpraos). +/// +/// # Panic +/// +/// This function may panic if the `block_eta_vrf_0` argument is not a slice of +/// either 32 bytes or 64 bytes. +/// +pub fn generate_rolling_nonce(previous_block_eta_v: Hash<32>, block_eta_vrf_0: &[u8]) -> Hash<32> { + assert!( + block_eta_vrf_0.len() == 32 || block_eta_vrf_0.len() == 64, + "Invalid block_eta_vrf_0 length: {}, expected 32 or 64", + block_eta_vrf_0.len() + ); + let mut hasher = Hasher::<256>::new(); + hasher.input(previous_block_eta_v.as_ref()); + hasher.input(Hasher::<256>::hash(block_eta_vrf_0).as_ref()); + hasher.finalize() } -/// A trait for generating nonces. -pub trait NonceGenerator: Sized { - fn finalize(&mut self) -> Result, Error>; +#[cfg(test)] +mod tests { + use itertools::izip; + + use crate::hash::Hash; + + use super::*; + + #[test] + fn test_epoch_nonce() { + let nc_values = vec![ + hex::decode("e86e133bd48ff5e79bec43af1ac3e348b539172f33e502d2c96735e8c51bd04d") + .unwrap(), + hex::decode("d1340a9c1491f0face38d41fd5c82953d0eb48320d65e952414a0c5ebaf87587") + .unwrap(), + ]; + let nh_values = vec![ + hex::decode("d7a1ff2a365abed59c9ae346cba842b6d3df06d055dba79a113e0704b44cc3e9") + .unwrap(), + hex::decode("ee91d679b0a6ce3015b894c575c799e971efac35c7a8cbdc2b3f579005e69abd") + .unwrap(), + ]; + let ee = hex::decode("d982e06fd33e7440b43cefad529b7ecafbaa255e38178ad4189a37e4ce9bf1fa") + .unwrap(); + let extra_entropy_values: Vec> = vec![None, Some(&ee)]; + let expected_epoch_nonces = vec![ + hex::decode("e536a0081ddd6d19786e9d708a85819a5c3492c0da7349f59c8ad3e17e4acd98") + .unwrap(), + hex::decode("0022cfa563a5328c4fb5c8017121329e964c26ade5d167b1bd9b2ec967772b60") + .unwrap(), + ]; + + for (nc_value, nh_value, extra_entropy_value, expected_epoch_nonce) in izip!( + nc_values.iter(), + nh_values.iter(), + extra_entropy_values.iter(), + expected_epoch_nonces.iter() + ) { + let nc: Hash<32> = Hash::from(nc_value.as_slice()); + let nh: Hash<32> = Hash::from(nh_value.as_slice()); + let extra_entropy = *extra_entropy_value; + let epoch_nonce = generate_epoch_nonce(nc, nh, extra_entropy); + assert_eq!(epoch_nonce.as_ref(), expected_epoch_nonce.as_slice()); + } + } + + #[test] + fn test_rolling_nonce() { + let shelley_genesis_hash = + hex::decode("1a3be38bcbb7911969283716ad7aa550250226b76a61fc51cc9a9a35d9276d81") + .unwrap(); + + let eta_vrf_0_values = vec![ + hex::decode("36ec5378d1f5041a59eb8d96e61de96f0950fb41b49ff511f7bc7fd109d4383e1d24be7034e6749c6612700dd5ceb0c66577b88a19ae286b1321d15bce1ab736").unwrap(), + hex::decode("e0bf34a6b73481302f22987cde4c12807cbc2c3fea3f7fcb77261385a50e8ccdda3226db3efff73e9fb15eecf841bbc85ce37550de0435ebcdcb205e0ed08467").unwrap(), + hex::decode("7107ef8c16058b09f4489715297e55d145a45fc0df75dfb419cab079cd28992854a034ad9dc4c764544fb70badd30a9611a942a03523c6f3d8967cf680c4ca6b").unwrap(), + hex::decode("6f561aad83884ee0d7b19fd3d757c6af096bfd085465d1290b13a9dfc817dfcdfb0b59ca06300206c64d1ba75fd222a88ea03c54fbbd5d320b4fbcf1c228ba4e").unwrap(), + hex::decode("3d3ba80724db0a028783afa56a85d684ee778ae45b9aa9af3120f5e1847be1983bd4868caf97fcfd82d5a3b0b7c1a6d53491d75440a75198014eb4e707785cad").unwrap(), + hex::decode("0b07976bc04321c2e7ba0f1acb3c61bd92b5fc780a855632e30e6746ab4ac4081490d816928762debd3e512d22ad512a558612adc569718df1784261f5c26aff").unwrap(), + hex::decode("5e9e001fb1e2ddb0dc7ff40af917ecf4ba9892491d4bcbf2c81db2efc57627d40d7aac509c9bcf5070d4966faaeb84fd76bb285af2e51af21a8c024089f598c1").unwrap(), + hex::decode("182e83f8c67ad2e6bddead128e7108499ebcbc272b50c42783ef08f035aa688fecc7d15be15a90dbfe7fe5d7cd9926987b6ec12b05f2eadfe0eb6cad5130aca4").unwrap(), + hex::decode("275e7404b2385a9d606d67d0e29f5516fb84c1c14aaaf91afa9a9b3dcdfe09075efdadbaf158cfa1e9f250cc7c691ed2db4a29288d2426bd74a371a2a4b91b57").unwrap(), + hex::decode("0f35c7217792f8b0cbb721ae4ae5c9ae7f2869df49a3db256aacc10d23997a09e0273261b44ebbcecd6bf916f2c1cd79cf25b0c2851645d75dd0747a8f6f92f5").unwrap(), + hex::decode("14c28bf9b10421e9f90ffc9ab05df0dc8c8a07ffac1c51725fba7e2b7972d0769baea248f93ed0f2067d11d719c2858c62fc1d8d59927b41d4c0fbc68d805b32").unwrap(), + hex::decode("e4ce96fee9deb9378a107db48587438cddf8e20a69e21e5e4fbd35ef0c56530df77eba666cb152812111ba66bbd333ed44f627c727115f8f4f15b31726049a19").unwrap(), + hex::decode("b38f315e3ce369ea2551bf4f44e723dd15c7d67ba4b3763997909f65e46267d6540b9b00a7a65ae3d1f3a3316e57a821aeaac33e4e42ded415205073134cd185").unwrap(), + hex::decode("4bcbf774af9c8ff24d4d96099001ec06a24802c88fea81680ea2411392d32dbd9b9828a690a462954b894708d511124a2db34ec4179841e07a897169f0f1ac0e").unwrap(), + hex::decode("65247ace6355f978a12235265410c44f3ded02849ec8f8e6db2ac705c3f57d322ea073c13cf698e15d7e1d7f2bc95e7b3533be0dee26f58864f1664df0c1ebba").unwrap(), + hex::decode("d0c2bb451d0a3465a7fef7770718e5e49bf092a85dbf5af66ea26ec9c1b359026905fc1457e2b98b01ede7ba42aedcc525301f747a0ed9a9b61c37f27f9d8812").unwrap(), + hex::decode("250d9ec7ebec73e885798ae9427e1ea47b5ae66059b465b7c0fd132d17a9c2dcae29ba72863c1861cfb776d342812c4e9000981c4a40819430d0e84aa8bfeb0d").unwrap(), + hex::decode("0549cc0a5e5b9920796b88784c49b7d9a04cf2e86ab18d5af7b00780e60fb0fb5a7129945f4f918201dbad5348d4ccface4370f266540f8e072cdb46d3705930").unwrap(), + hex::decode("e543a26031dbdc8597b1beeba48a4f1cf6ab90c0e5b9343936b6e948a791198fc4fa22928e21edec812a04d0c9629772bf78e475d91a323cd8a8a6e005f92b4d").unwrap(), + hex::decode("4e4be69ad170fb8b3b17835913391ee537098d49e4452844a71ab2147ac55e45871c8943271806034ee9450b31c9486db9d26942946f48040ece7eea81424af1").unwrap(), + hex::decode("cb8a528288f902349250f9e8015e8334b0e24c2eeb9bb7d75e73c39024685804577565e62aca35948d2686ea38e9f8de97837ea30d2fb08347768394416e4a38").unwrap(), + hex::decode("fce94c47196a56a5cb94d5151ca429daf1c563ae889d0a42c2d03cfe43c94a636221c7e21b0668de9e5b6b32ee1e78b2c9aabc16537bf79c7b85eb956f433ac7").unwrap(), + hex::decode("fc8a125c9e2418c87907db4437a0ad6a378bba728ac8e0ce0e64f2a2f4b8201315e1b08d7983ce597cb68be2a2400d6d0d59b7359fe3dc9daca73d468da48972").unwrap(), + hex::decode("49290417311420d67f029a80b013b754150dd0097aa64de1c14a2467ab2e26cc2724071c04cb90cb0cf6c6353cf31f63235af7849d6ba023fd0fc0bc79d32f0b").unwrap(), + hex::decode("45c65effdc8007c9f2fc9057af986e94eb5c12b755465058d4b933ee37638452c5eeca4b43b8cbddabc60f29cbe5676b0bc55c0da88f8d0c36068e7d17ee603a").unwrap(), + hex::decode("a51e4e0f28aee3024207d87a5a1965313bdba4df44c6b845f7ca3408e5dabfe873df6b6ba26000e841f83f69e1de7857122ba538b42f255da2d013208af806ba").unwrap(), + hex::decode("5dbd891bf3bcfd5d054274759c13552aeaa187949875d81ee62ed394253ae25182e78b3a4a1976a7674e425bab860931d57f8a1d4fdc81fa4c3e8e8bf9016d5d").unwrap(), + hex::decode("3b5b044026e9066d62ce2f5a1fb01052a8cfe200dea28d421fc70f42c4d2b890b90ffef5675de1e47e4a20c9ca8700ceea23a61338ac759a098d167fa71642cb").unwrap(), + hex::decode("bb4017880cfa1e37f256dfe2a9cdb1349ed5dea8f69de75dc5933540dcf49e69afc33c837ba8a791857e16fad8581c4e9046778c49ca1ecd1fb675983be6d721").unwrap(), + hex::decode("517bbdb6e9e5f4702193064543204e780f5d33a866d0dcd65ada19f05715dea60ca81b842de5dca8f6b84a9cf469c8fb81991369dba21571476cc9c8d4ff2136").unwrap(), + ]; + + let expected_eta_v_values = vec![ + hex::decode("2af15f57076a8ff225746624882a77c8d2736fe41d3db70154a22b50af851246") + .unwrap(), + hex::decode("a815ff978369b57df09b0072485c26920dc0ec8e924a852a42f0715981cf0042") + .unwrap(), + hex::decode("f112d91435b911b6b5acaf27198762905b1cdec8c5a7b712f925ce3c5c76bb5f") + .unwrap(), + hex::decode("5450d95d9be4194a0ded40fbb4036b48d1f1d6da796e933fefd2c5c888794b4b") + .unwrap(), + hex::decode("c5c0f406cb522ad3fead4ecc60bce9c31e80879bc17eb1bb9acaa9b998cdf8bf") + .unwrap(), + hex::decode("5857048c728580549de645e087ba20ef20bb7c51cc84b5bc89df6b8b0ed98c41") + .unwrap(), + hex::decode("d6f40ef403687115db061b2cb9b1ab4ddeb98222075d5a3e03c8d217d4d7c40e") + .unwrap(), + hex::decode("5489d75a9f4971c1824462b5e2338609a91f121241f21fee09811bd5772ae0a8") + .unwrap(), + hex::decode("04716326833ecdb595153adac9566a4b39e5c16e8d02526cb4166e4099a00b1a") + .unwrap(), + hex::decode("39db709f50c8a279f0a94adcefb9360dbda6cdce168aed4288329a9cd53492b6") + .unwrap(), + hex::decode("c784b8c8678e0a04748a3ad851dd7c34ed67141cd9dc0c50ceaff4df804699a7") + .unwrap(), + hex::decode("cc1a5861358c075de93a26a91c5a951d5e71190d569aa2dc786d4ca8fc80cc38") + .unwrap(), + hex::decode("514979c89313c49e8f59fb8445113fa7623e99375cc4917fe79df54f8d4bdfce") + .unwrap(), + hex::decode("6a783e04481b9e04e8f3498a3b74c90c06a1031fb663b6793ce592a6c26f56f4") + .unwrap(), + hex::decode("1190f5254599dcee4f3cf1afdf4181085c36a6db6c30f334bfe6e6f320a6ed91") + .unwrap(), + hex::decode("91c777d6db066fe58edd67cd751fc7240268869b365393f6910e0e8f0fa58af3") + .unwrap(), + hex::decode("c545d83926c011b5c68a72de9a4e2f9da402703f4aab1b967456eae73d9f89b3") + .unwrap(), + hex::decode("ec31d2348bf543482842843a61d5b32691dedf801f198d68126c423ddf391e8b") + .unwrap(), + hex::decode("de223867d5c972895dd99ac0280a3e02947a7fb018ed42ed048266f913d2dfc2") + .unwrap(), + hex::decode("4dd9801752aade9c6e06bf03e9d2ec8a30ef7c6f30106790a23a9599e90ee08a") + .unwrap(), + hex::decode("fcb183abd512271f40408a5872827ce79cc2dda685a986a7dbdc61d842495a91") + .unwrap(), + hex::decode("e834d8ffd6dd042167b13e38512c62afdaf4d635d5b1ab0d513e08e9bef0ef63") + .unwrap(), + hex::decode("270a78257a958cd5fdb26f0b9ab302df2d2196fd04989f7ca1bb703e4dd904f0") + .unwrap(), + hex::decode("7e324f67af787dfddee10354128c60c60bf601bd8147c867d2471749a7b0f334") + .unwrap(), + hex::decode("54521ed42e0e782b5268ec55f80cff582162bc23fdcee5cdaa0f1a2ce7fa1f02") + .unwrap(), + hex::decode("557c296a71d8c9cb3fe7dcd95fbf4d70f6a3974d93c71b450d62a41b9a85d5a1") + .unwrap(), + hex::decode("20e078301ca282857378bbf10ac40965445c4c9fa73a160e0a116b4cf808b4b4") + .unwrap(), + hex::decode("b5a741dd3ff6a5a3d27b4d046dfb7a3901aacd37df7e931ba05e1320ad155c1c") + .unwrap(), + hex::decode("8b445f35f4a7b76e5d279d71fa9e05376a7c4533ca8b2b98fd2dbaf814d3bf8f") + .unwrap(), + hex::decode("08e7b5277abc139deb50f61264375fa091c580f8a85f259be78a002f7023c31f") + .unwrap(), + ]; + + let mut previous_block_eta_v = Hash::<32>::from(shelley_genesis_hash.as_slice()); + + for (eta_vrf_0, expected_eta_v) in eta_vrf_0_values.iter().zip(expected_eta_v_values.iter()) + { + let rolling_nonce = generate_rolling_nonce(previous_block_eta_v, eta_vrf_0); + assert_eq!(rolling_nonce.as_ref(), expected_eta_v.as_slice()); + previous_block_eta_v = rolling_nonce; + } + } } diff --git a/pallas-crypto/src/nonce/rolling_nonce.rs b/pallas-crypto/src/nonce/rolling_nonce.rs deleted file mode 100644 index 84cf8f65..00000000 --- a/pallas-crypto/src/nonce/rolling_nonce.rs +++ /dev/null @@ -1,168 +0,0 @@ -use crate::hash::{Hash, Hasher}; -use crate::nonce::{Error, NonceGenerator}; - -/// A nonce generator that calculates a rolling nonce by applying each cardano block in -/// the shelley era and beyond. These rolling nonce values are used to help calculate the epoch -/// nonce values used in consensus for the Ouroboros protocols (tpraos, praos, cpraos). -#[derive(Debug, Clone)] -pub struct RollingNonceGenerator { - pub nonce: Hash<32>, - block_eta_v: Option>, -} - -impl RollingNonceGenerator { - pub fn new(nonce: Hash<32>) -> Self { - Self { - nonce, - block_eta_v: None, - } - } - - pub fn apply_block(&mut self, eta_vrf_0: &[u8]) -> Result<(), Error> { - let len = eta_vrf_0.len(); - if len != 64 && len != 32 { - return Err(Error::Nonce(format!( - "Invalid eta_vrf_0 length: {}, expected 32 or 64", - eta_vrf_0.len() - ))); - } - self.block_eta_v = Some(Hasher::<256>::hash(eta_vrf_0)); - Ok(()) - } -} - -impl NonceGenerator for RollingNonceGenerator { - fn finalize(&mut self) -> Result, Error> { - if self.block_eta_v.is_none() { - return Err(Error::Nonce( - "Must call apply_block before finalize!".to_string(), - )); - } - let mut hasher = Hasher::<256>::new(); - hasher.input(self.nonce.as_ref()); - hasher.input(self.block_eta_v.unwrap().as_ref()); - Ok(hasher.finalize()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_rolling_nonce() { - let shelley_genesis_hash = - hex::decode("1a3be38bcbb7911969283716ad7aa550250226b76a61fc51cc9a9a35d9276d81") - .unwrap(); - - let eta_vrf_0_values = vec![ - hex::decode("36ec5378d1f5041a59eb8d96e61de96f0950fb41b49ff511f7bc7fd109d4383e1d24be7034e6749c6612700dd5ceb0c66577b88a19ae286b1321d15bce1ab736").unwrap(), - hex::decode("e0bf34a6b73481302f22987cde4c12807cbc2c3fea3f7fcb77261385a50e8ccdda3226db3efff73e9fb15eecf841bbc85ce37550de0435ebcdcb205e0ed08467").unwrap(), - hex::decode("7107ef8c16058b09f4489715297e55d145a45fc0df75dfb419cab079cd28992854a034ad9dc4c764544fb70badd30a9611a942a03523c6f3d8967cf680c4ca6b").unwrap(), - hex::decode("6f561aad83884ee0d7b19fd3d757c6af096bfd085465d1290b13a9dfc817dfcdfb0b59ca06300206c64d1ba75fd222a88ea03c54fbbd5d320b4fbcf1c228ba4e").unwrap(), - hex::decode("3d3ba80724db0a028783afa56a85d684ee778ae45b9aa9af3120f5e1847be1983bd4868caf97fcfd82d5a3b0b7c1a6d53491d75440a75198014eb4e707785cad").unwrap(), - hex::decode("0b07976bc04321c2e7ba0f1acb3c61bd92b5fc780a855632e30e6746ab4ac4081490d816928762debd3e512d22ad512a558612adc569718df1784261f5c26aff").unwrap(), - hex::decode("5e9e001fb1e2ddb0dc7ff40af917ecf4ba9892491d4bcbf2c81db2efc57627d40d7aac509c9bcf5070d4966faaeb84fd76bb285af2e51af21a8c024089f598c1").unwrap(), - hex::decode("182e83f8c67ad2e6bddead128e7108499ebcbc272b50c42783ef08f035aa688fecc7d15be15a90dbfe7fe5d7cd9926987b6ec12b05f2eadfe0eb6cad5130aca4").unwrap(), - hex::decode("275e7404b2385a9d606d67d0e29f5516fb84c1c14aaaf91afa9a9b3dcdfe09075efdadbaf158cfa1e9f250cc7c691ed2db4a29288d2426bd74a371a2a4b91b57").unwrap(), - hex::decode("0f35c7217792f8b0cbb721ae4ae5c9ae7f2869df49a3db256aacc10d23997a09e0273261b44ebbcecd6bf916f2c1cd79cf25b0c2851645d75dd0747a8f6f92f5").unwrap(), - hex::decode("14c28bf9b10421e9f90ffc9ab05df0dc8c8a07ffac1c51725fba7e2b7972d0769baea248f93ed0f2067d11d719c2858c62fc1d8d59927b41d4c0fbc68d805b32").unwrap(), - hex::decode("e4ce96fee9deb9378a107db48587438cddf8e20a69e21e5e4fbd35ef0c56530df77eba666cb152812111ba66bbd333ed44f627c727115f8f4f15b31726049a19").unwrap(), - hex::decode("b38f315e3ce369ea2551bf4f44e723dd15c7d67ba4b3763997909f65e46267d6540b9b00a7a65ae3d1f3a3316e57a821aeaac33e4e42ded415205073134cd185").unwrap(), - hex::decode("4bcbf774af9c8ff24d4d96099001ec06a24802c88fea81680ea2411392d32dbd9b9828a690a462954b894708d511124a2db34ec4179841e07a897169f0f1ac0e").unwrap(), - hex::decode("65247ace6355f978a12235265410c44f3ded02849ec8f8e6db2ac705c3f57d322ea073c13cf698e15d7e1d7f2bc95e7b3533be0dee26f58864f1664df0c1ebba").unwrap(), - hex::decode("d0c2bb451d0a3465a7fef7770718e5e49bf092a85dbf5af66ea26ec9c1b359026905fc1457e2b98b01ede7ba42aedcc525301f747a0ed9a9b61c37f27f9d8812").unwrap(), - hex::decode("250d9ec7ebec73e885798ae9427e1ea47b5ae66059b465b7c0fd132d17a9c2dcae29ba72863c1861cfb776d342812c4e9000981c4a40819430d0e84aa8bfeb0d").unwrap(), - hex::decode("0549cc0a5e5b9920796b88784c49b7d9a04cf2e86ab18d5af7b00780e60fb0fb5a7129945f4f918201dbad5348d4ccface4370f266540f8e072cdb46d3705930").unwrap(), - hex::decode("e543a26031dbdc8597b1beeba48a4f1cf6ab90c0e5b9343936b6e948a791198fc4fa22928e21edec812a04d0c9629772bf78e475d91a323cd8a8a6e005f92b4d").unwrap(), - hex::decode("4e4be69ad170fb8b3b17835913391ee537098d49e4452844a71ab2147ac55e45871c8943271806034ee9450b31c9486db9d26942946f48040ece7eea81424af1").unwrap(), - hex::decode("cb8a528288f902349250f9e8015e8334b0e24c2eeb9bb7d75e73c39024685804577565e62aca35948d2686ea38e9f8de97837ea30d2fb08347768394416e4a38").unwrap(), - hex::decode("fce94c47196a56a5cb94d5151ca429daf1c563ae889d0a42c2d03cfe43c94a636221c7e21b0668de9e5b6b32ee1e78b2c9aabc16537bf79c7b85eb956f433ac7").unwrap(), - hex::decode("fc8a125c9e2418c87907db4437a0ad6a378bba728ac8e0ce0e64f2a2f4b8201315e1b08d7983ce597cb68be2a2400d6d0d59b7359fe3dc9daca73d468da48972").unwrap(), - hex::decode("49290417311420d67f029a80b013b754150dd0097aa64de1c14a2467ab2e26cc2724071c04cb90cb0cf6c6353cf31f63235af7849d6ba023fd0fc0bc79d32f0b").unwrap(), - hex::decode("45c65effdc8007c9f2fc9057af986e94eb5c12b755465058d4b933ee37638452c5eeca4b43b8cbddabc60f29cbe5676b0bc55c0da88f8d0c36068e7d17ee603a").unwrap(), - hex::decode("a51e4e0f28aee3024207d87a5a1965313bdba4df44c6b845f7ca3408e5dabfe873df6b6ba26000e841f83f69e1de7857122ba538b42f255da2d013208af806ba").unwrap(), - hex::decode("5dbd891bf3bcfd5d054274759c13552aeaa187949875d81ee62ed394253ae25182e78b3a4a1976a7674e425bab860931d57f8a1d4fdc81fa4c3e8e8bf9016d5d").unwrap(), - hex::decode("3b5b044026e9066d62ce2f5a1fb01052a8cfe200dea28d421fc70f42c4d2b890b90ffef5675de1e47e4a20c9ca8700ceea23a61338ac759a098d167fa71642cb").unwrap(), - hex::decode("bb4017880cfa1e37f256dfe2a9cdb1349ed5dea8f69de75dc5933540dcf49e69afc33c837ba8a791857e16fad8581c4e9046778c49ca1ecd1fb675983be6d721").unwrap(), - hex::decode("517bbdb6e9e5f4702193064543204e780f5d33a866d0dcd65ada19f05715dea60ca81b842de5dca8f6b84a9cf469c8fb81991369dba21571476cc9c8d4ff2136").unwrap(), - ]; - - let expected_eta_v_values = vec![ - hex::decode("2af15f57076a8ff225746624882a77c8d2736fe41d3db70154a22b50af851246") - .unwrap(), - hex::decode("a815ff978369b57df09b0072485c26920dc0ec8e924a852a42f0715981cf0042") - .unwrap(), - hex::decode("f112d91435b911b6b5acaf27198762905b1cdec8c5a7b712f925ce3c5c76bb5f") - .unwrap(), - hex::decode("5450d95d9be4194a0ded40fbb4036b48d1f1d6da796e933fefd2c5c888794b4b") - .unwrap(), - hex::decode("c5c0f406cb522ad3fead4ecc60bce9c31e80879bc17eb1bb9acaa9b998cdf8bf") - .unwrap(), - hex::decode("5857048c728580549de645e087ba20ef20bb7c51cc84b5bc89df6b8b0ed98c41") - .unwrap(), - hex::decode("d6f40ef403687115db061b2cb9b1ab4ddeb98222075d5a3e03c8d217d4d7c40e") - .unwrap(), - hex::decode("5489d75a9f4971c1824462b5e2338609a91f121241f21fee09811bd5772ae0a8") - .unwrap(), - hex::decode("04716326833ecdb595153adac9566a4b39e5c16e8d02526cb4166e4099a00b1a") - .unwrap(), - hex::decode("39db709f50c8a279f0a94adcefb9360dbda6cdce168aed4288329a9cd53492b6") - .unwrap(), - hex::decode("c784b8c8678e0a04748a3ad851dd7c34ed67141cd9dc0c50ceaff4df804699a7") - .unwrap(), - hex::decode("cc1a5861358c075de93a26a91c5a951d5e71190d569aa2dc786d4ca8fc80cc38") - .unwrap(), - hex::decode("514979c89313c49e8f59fb8445113fa7623e99375cc4917fe79df54f8d4bdfce") - .unwrap(), - hex::decode("6a783e04481b9e04e8f3498a3b74c90c06a1031fb663b6793ce592a6c26f56f4") - .unwrap(), - hex::decode("1190f5254599dcee4f3cf1afdf4181085c36a6db6c30f334bfe6e6f320a6ed91") - .unwrap(), - hex::decode("91c777d6db066fe58edd67cd751fc7240268869b365393f6910e0e8f0fa58af3") - .unwrap(), - hex::decode("c545d83926c011b5c68a72de9a4e2f9da402703f4aab1b967456eae73d9f89b3") - .unwrap(), - hex::decode("ec31d2348bf543482842843a61d5b32691dedf801f198d68126c423ddf391e8b") - .unwrap(), - hex::decode("de223867d5c972895dd99ac0280a3e02947a7fb018ed42ed048266f913d2dfc2") - .unwrap(), - hex::decode("4dd9801752aade9c6e06bf03e9d2ec8a30ef7c6f30106790a23a9599e90ee08a") - .unwrap(), - hex::decode("fcb183abd512271f40408a5872827ce79cc2dda685a986a7dbdc61d842495a91") - .unwrap(), - hex::decode("e834d8ffd6dd042167b13e38512c62afdaf4d635d5b1ab0d513e08e9bef0ef63") - .unwrap(), - hex::decode("270a78257a958cd5fdb26f0b9ab302df2d2196fd04989f7ca1bb703e4dd904f0") - .unwrap(), - hex::decode("7e324f67af787dfddee10354128c60c60bf601bd8147c867d2471749a7b0f334") - .unwrap(), - hex::decode("54521ed42e0e782b5268ec55f80cff582162bc23fdcee5cdaa0f1a2ce7fa1f02") - .unwrap(), - hex::decode("557c296a71d8c9cb3fe7dcd95fbf4d70f6a3974d93c71b450d62a41b9a85d5a1") - .unwrap(), - hex::decode("20e078301ca282857378bbf10ac40965445c4c9fa73a160e0a116b4cf808b4b4") - .unwrap(), - hex::decode("b5a741dd3ff6a5a3d27b4d046dfb7a3901aacd37df7e931ba05e1320ad155c1c") - .unwrap(), - hex::decode("8b445f35f4a7b76e5d279d71fa9e05376a7c4533ca8b2b98fd2dbaf814d3bf8f") - .unwrap(), - hex::decode("08e7b5277abc139deb50f61264375fa091c580f8a85f259be78a002f7023c31f") - .unwrap(), - ]; - - let mut rolling_nonce_generator = - RollingNonceGenerator::new(Hash::from(shelley_genesis_hash.as_slice())); - - for (eta_vrf_0, expected_eta_v) in eta_vrf_0_values.iter().zip(expected_eta_v_values.iter()) - { - rolling_nonce_generator.apply_block(eta_vrf_0).unwrap(); - rolling_nonce_generator = - RollingNonceGenerator::new(rolling_nonce_generator.finalize().unwrap()); - assert_eq!( - rolling_nonce_generator.nonce.as_ref(), - expected_eta_v.as_slice() - ); - } - } -} diff --git a/pallas-crypto/src/vrf/mod.rs b/pallas-crypto/src/vrf/mod.rs index 250c23f6..0dd8eaf8 100644 --- a/pallas-crypto/src/vrf/mod.rs +++ b/pallas-crypto/src/vrf/mod.rs @@ -1,36 +1,107 @@ +use crate::hash::Hash; use thiserror::Error; use vrf_dalek::vrf03::{PublicKey03, SecretKey03, VrfProof03}; +/// error that can be returned if the verification of a [`VrfProof`] fails +/// see [`VrfProof::verify`] +/// #[derive(Error, Debug)] -pub enum Error { - #[error("TryFromSlice {0}")] - TryFromSlice(#[from] std::array::TryFromSliceError), +#[error("VRF Proof Verification failed.")] +pub struct VerificationError( + #[from] + #[source] + vrf_dalek::errors::VrfError, +); - #[error("VrfError {0}")] - VrfError(#[from] vrf_dalek::errors::VrfError), +pub const VRF_SEED_SIZE: usize = 32; +pub const VRF_PROOF_SIZE: usize = 80; +pub const VRF_PUBLIC_KEY_SIZE: usize = 32; +pub const VRF_SECRET_KEY_SIZE: usize = 32; +pub const VRF_PROOF_HASH_SIZE: usize = 64; + +// Wrapper for VRF secret key +pub struct VrfSecretKey { + secret_key_03: SecretKey03, +} + +// Wrapper for VRF public key +pub struct VrfPublicKey { + public_key_03: PublicKey03, +} + +// Wrapper for VRF proof +pub struct VrfProof { + proof_03: VrfProof03, +} + +// Create a VrfSecretKey from a slice +impl From<&[u8; VRF_SECRET_KEY_SIZE]> for VrfSecretKey { + fn from(slice: &[u8; VRF_SECRET_KEY_SIZE]) -> Self { + VrfSecretKey { + secret_key_03: SecretKey03::from_bytes(slice), + } + } +} + +// Create a VrfPublicKey from a slice +impl From<&[u8; VRF_PUBLIC_KEY_SIZE]> for VrfPublicKey { + fn from(slice: &[u8; VRF_PUBLIC_KEY_SIZE]) -> Self { + VrfPublicKey { + public_key_03: PublicKey03::from_bytes(slice), + } + } +} + +// Create a VrfProof from a slice +impl From<&[u8; VRF_PROOF_SIZE]> for VrfProof { + fn from(slice: &[u8; VRF_PROOF_SIZE]) -> Self { + VrfProof { + proof_03: VrfProof03::from_bytes(slice).unwrap(), + } + } } -/// Sign a seed value with a vrf secret key and produce a proof signature -pub fn vrf_prove(secret_key: &[u8], seed: &[u8]) -> Result, Error> { - let sk = SecretKey03::from_bytes(secret_key[..32].try_into()?); - let pk = PublicKey03::from(&sk); - let proof = VrfProof03::generate(&pk, &sk, seed); - Ok(proof.to_bytes().to_vec()) +// Create a VrfPublicKey from a VrfSecretKey +impl From<&VrfSecretKey> for VrfPublicKey { + fn from(secret_key: &VrfSecretKey) -> Self { + VrfPublicKey { + public_key_03: PublicKey03::from(&secret_key.secret_key_03), + } + } } -/// Convert a proof signature to a hash -pub fn vrf_proof_to_hash(proof: &[u8]) -> Result, Error> { - let proof = VrfProof03::from_bytes(proof[..80].try_into()?)?; - Ok(proof.proof_to_hash().to_vec()) +impl VrfSecretKey { + /// Sign a challenge message value with a vrf secret key and produce a proof signature + pub fn prove(&self, challenge: &[u8]) -> VrfProof { + let pk = PublicKey03::from(&self.secret_key_03); + let proof = VrfProof03::generate(&pk, &self.secret_key_03, challenge); + VrfProof { proof_03: proof } + } } -/// Verify a proof signature with a vrf public key. This will return a hash to compare with the original -/// signature hash, but any non-error result is considered a successful verification without needing -/// to do the extra comparison check. -pub fn vrf_verify(public_key: &[u8], signature: &[u8], seed: &[u8]) -> Result, Error> { - let pk = PublicKey03::from_bytes(public_key.try_into()?); - let proof = VrfProof03::from_bytes(signature.try_into()?)?; - Ok(proof.verify(&pk, seed)?.to_vec()) +impl VrfProof { + /// Return the created proof signature + pub fn signature(&self) -> [u8; VRF_PROOF_SIZE] { + self.proof_03.to_bytes() + } + + /// Convert a proof signature to a hash + pub fn to_hash(&self) -> Hash { + Hash::from(self.proof_03.proof_to_hash()) + } + + /// Verify a proof signature with a vrf public key. This will return a hash to compare with the original + /// signature hash, but any non-error result is considered a successful verification without needing + /// to do the extra comparison check. + pub fn verify( + &self, + public_key: &VrfPublicKey, + seed: &[u8], + ) -> Result, VerificationError> { + Ok(Hash::from( + self.proof_03.verify(&public_key.public_key_03, seed)?, + )) + } } #[cfg(test)] @@ -53,22 +124,32 @@ mod tests { // "description": "VRF Signing Key", // "cborHex": "5840adb9c97bec60189aa90d01d113e3ef405f03477d82a94f81da926c90cd46a374e0ff2371508ac339431b50af7d69cde0f120d952bb876806d3136f9a7fda4381" // } - - let vrf_skey = hex::decode("adb9c97bec60189aa90d01d113e3ef405f03477d82a94f81da926c90cd46a374e0ff2371508ac339431b50af7d69cde0f120d952bb876806d3136f9a7fda4381").unwrap(); - let vrf_vkey = + let raw_vrf_skey: Vec = hex::decode("adb9c97bec60189aa90d01d113e3ef405f03477d82a94f81da926c90cd46a374e0ff2371508ac339431b50af7d69cde0f120d952bb876806d3136f9a7fda4381").unwrap(); + let raw_vrf_vkey: Vec = hex::decode("e0ff2371508ac339431b50af7d69cde0f120d952bb876806d3136f9a7fda4381") .unwrap(); - // random seed to sign with vrf_skey - let mut seed = [0u8; 64]; - thread_rng().fill(&mut seed); + let vrf_skey = VrfSecretKey::from(&raw_vrf_skey[..VRF_SECRET_KEY_SIZE].try_into().unwrap()); + let vrf_vkey = + VrfPublicKey::from(&raw_vrf_vkey[..VRF_PUBLIC_KEY_SIZE].try_into().unwrap() + as &[u8; VRF_PUBLIC_KEY_SIZE]); + + let calculated_vrf_vkey = VrfPublicKey::from(&vrf_skey); + assert_eq!( + vrf_vkey.public_key_03.as_bytes(), + calculated_vrf_vkey.public_key_03.as_bytes() + ); + + // random challenge to sign with vrf_skey + let mut challenge = [0u8; 64]; + thread_rng().fill(&mut challenge); // create a proof signature and hash of the seed - let proof_signature = vrf_prove(&vrf_skey, &seed).unwrap(); - let proof_hash = vrf_proof_to_hash(&proof_signature).unwrap(); + let proof = vrf_skey.prove(&challenge); + let proof_hash = proof.to_hash(); // verify the proof signature with the public vrf public key - let verified_hash = vrf_verify(&vrf_vkey, &proof_signature, &seed).unwrap(); + let verified_hash = proof.verify(&vrf_vkey, &challenge).unwrap(); assert_eq!(proof_hash, verified_hash); } } diff --git a/pallas-math/Cargo.toml b/pallas-math/Cargo.toml index fe3a9cd7..e6d14be0 100644 --- a/pallas-math/Cargo.toml +++ b/pallas-math/Cargo.toml @@ -11,17 +11,10 @@ readme = "README.md" authors = ["Andrew Westberg "] exclude = ["tests/data/*"] -[features] -default = ["gmp"] -gmp = ["dep:gmp-mpfr-sys"] -num = ["dep:num-bigint", "dep:num-integer", "dep:num-traits"] - [dependencies] -gmp-mpfr-sys = { version = "1.6.4", features = ["mpc"], default-features = false, optional = true } once_cell = "1.19.0" -num-bigint = { version = "0.4.6", optional = true } -num-integer = { version = "0.1.46", optional = true } -num-traits = { version = "0.2.19", optional = true } +malachite = "0.4.16" +malachite-base = "0.4.16" regex = "1.10.5" thiserror = "1.0.61" diff --git a/pallas-math/src/lib.rs b/pallas-math/src/lib.rs index 485178c8..0726af4f 100644 --- a/pallas-math/src/lib.rs +++ b/pallas-math/src/lib.rs @@ -1,14 +1,2 @@ pub mod math; - -// Ensure only one of `gmp` or `num` is enabled, not both. -#[cfg(all(feature = "gmp", feature = "num"))] -compile_error!("Features `gmp` and `num` are mutually exclusive."); - -#[cfg(all(not(feature = "gmp"), not(feature = "num")))] -compile_error!("One of the features `gmp` or `num` must be enabled."); - -#[cfg(feature = "gmp")] -pub mod math_gmp; - -#[cfg(feature = "num")] -pub mod math_num; +pub mod math_malachite; diff --git a/pallas-math/src/math.rs b/pallas-math/src/math.rs index df9eba81..8f103b82 100644 --- a/pallas-math/src/math.rs +++ b/pallas-math/src/math.rs @@ -7,10 +7,7 @@ use std::ops::{Div, Mul, Neg, Sub}; use thiserror::Error; -#[cfg(feature = "gmp")] -use crate::math_gmp::Decimal; -#[cfg(feature = "num")] -use crate::math_num::Decimal; +pub type FixedDecimal = crate::math_malachite::Decimal; #[derive(Debug, Error)] pub enum Error { @@ -49,6 +46,22 @@ pub trait FixedPrecision: /// Entry point for bounded iterations for comparing two exp values. fn exp_cmp(&self, max_n: u64, bound_self: i64, compare: &Self) -> ExpCmpOrdering; + + /// Round to the nearest integer number + #[must_use] + fn round(&self) -> Self; + + /// Round down to the nearest integer number + #[must_use] + fn floor(&self) -> Self; + + /// Round up to the nearest integer number + #[must_use] + fn ceil(&self) -> Self; + + /// Truncate to the nearest integer number + #[must_use] + fn trunc(&self) -> Self; } #[derive(Debug, Clone, PartialEq)] @@ -72,54 +85,51 @@ impl From<&str> for ExpOrdering { pub struct ExpCmpOrdering { pub iterations: u64, pub estimation: ExpOrdering, - pub approx: Decimal, + pub approx: FixedDecimal, } #[cfg(test)] mod tests { + use super::*; use std::fs::File; use std::io::BufRead; use std::path::PathBuf; - #[cfg(feature = "gmp")] - use crate::math_gmp::Decimal; - #[cfg(feature = "num")] - use crate::math_num::Decimal; - - use super::*; - #[test] fn test_fixed_precision() { - let fp: Decimal = Decimal::new(34); + let fp: FixedDecimal = FixedDecimal::new(34); assert_eq!(fp.precision(), 34); assert_eq!(fp.to_string(), "0.0000000000000000000000000000000000"); } #[test] fn test_fixed_precision_eq() { - let fp1: Decimal = Decimal::new(34); - let fp2: Decimal = Decimal::new(34); + let fp1: FixedDecimal = FixedDecimal::new(34); + let fp2: FixedDecimal = FixedDecimal::new(34); assert_eq!(fp1, fp2); } #[test] fn test_fixed_precision_from_str() { - let fp: Decimal = Decimal::from_str("1234567890123456789012345678901234", 34).unwrap(); + let fp: FixedDecimal = + FixedDecimal::from_str("1234567890123456789012345678901234", 34).unwrap(); assert_eq!(fp.precision(), 34); assert_eq!(fp.to_string(), "0.1234567890123456789012345678901234"); - let fp: Decimal = Decimal::from_str("-1234567890123456789012345678901234", 30).unwrap(); + let fp: FixedDecimal = + FixedDecimal::from_str("-1234567890123456789012345678901234", 30).unwrap(); assert_eq!(fp.precision(), 30); assert_eq!(fp.to_string(), "-1234.567890123456789012345678901234"); - let fp: Decimal = Decimal::from_str("-1234567890123456789012345678901234", 34).unwrap(); + let fp: FixedDecimal = + FixedDecimal::from_str("-1234567890123456789012345678901234", 34).unwrap(); assert_eq!(fp.precision(), 34); assert_eq!(fp.to_string(), "-0.1234567890123456789012345678901234"); } #[test] fn test_fixed_precision_exp() { - let fp: Decimal = Decimal::from(1u64); + let fp: FixedDecimal = FixedDecimal::from(1u64); assert_eq!(fp.to_string(), "1.0000000000000000000000000000000000"); let exp_fp = fp.exp(); assert_eq!(exp_fp.to_string(), "2.7182818284590452353602874043083282"); @@ -127,8 +137,10 @@ mod tests { #[test] fn test_fixed_precision_mul() { - let fp1: Decimal = Decimal::from_str("52500000000000000000000000000000000", 34).unwrap(); - let fp2: Decimal = Decimal::from_str("43000000000000000000000000000000000", 34).unwrap(); + let fp1: FixedDecimal = + FixedDecimal::from_str("52500000000000000000000000000000000", 34).unwrap(); + let fp2: FixedDecimal = + FixedDecimal::from_str("43000000000000000000000000000000000", 34).unwrap(); let fp3 = &fp1 * &fp2; assert_eq!(fp3.to_string(), "22.5750000000000000000000000000000000"); let fp4 = fp1 * fp2; @@ -137,8 +149,8 @@ mod tests { #[test] fn test_fixed_precision_div() { - let fp1: Decimal = Decimal::from_str("1", 34).unwrap(); - let fp2: Decimal = Decimal::from_str("10", 34).unwrap(); + let fp1: FixedDecimal = FixedDecimal::from_str("1", 34).unwrap(); + let fp2: FixedDecimal = FixedDecimal::from_str("10", 34).unwrap(); let fp3 = &fp1 / &fp2; assert_eq!(fp3.to_string(), "0.1000000000000000000000000000000000"); let fp4 = fp1 / fp2; @@ -147,9 +159,9 @@ mod tests { #[test] fn test_fixed_precision_sub() { - let fp1: Decimal = Decimal::from_str("1", 34).unwrap(); + let fp1: FixedDecimal = FixedDecimal::from_str("1", 34).unwrap(); assert_eq!(fp1.to_string(), "0.0000000000000000000000000000000001"); - let fp2: Decimal = Decimal::from_str("10", 34).unwrap(); + let fp2: FixedDecimal = FixedDecimal::from_str("10", 34).unwrap(); assert_eq!(fp2.to_string(), "0.0000000000000000000000000000000010"); let fp3 = &fp1 - &fp2; assert_eq!(fp3.to_string(), "-0.0000000000000000000000000000000009"); @@ -157,6 +169,214 @@ mod tests { assert_eq!(fp4.to_string(), "-0.0000000000000000000000000000000009"); } + #[test] + fn test_fixed_precision_round() { + let fp1: FixedDecimal = + FixedDecimal::from_str("11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp1.round().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.round().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.round().to_string(), + "2.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("1500", 3).unwrap(); + assert_eq!(fp4.round().to_string(), "2.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("1499", 3).unwrap(); + assert_eq!(fp5.round().to_string(), "1.000"); + let fp6: FixedDecimal = + FixedDecimal::from_str("-11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp6.round().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("-14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.round().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("-15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.round().to_string(), + "-2.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("-1500", 3).unwrap(); + assert_eq!(fp4.round().to_string(), "-2.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("-1499", 3).unwrap(); + assert_eq!(fp5.round().to_string(), "-1.000"); + let fp6: FixedDecimal = FixedDecimal::from_str("1000", 3).unwrap(); + assert_eq!(fp6.round().to_string(), "1.000"); + let fp7: FixedDecimal = FixedDecimal::from_str("-1000", 3).unwrap(); + assert_eq!(fp7.round().to_string(), "-1.000"); + } + + #[test] + fn test_fixed_precision_floor() { + let fp1: FixedDecimal = + FixedDecimal::from_str("11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp1.floor().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.floor().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.floor().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("1500", 3).unwrap(); + assert_eq!(fp4.floor().to_string(), "1.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("1499", 3).unwrap(); + assert_eq!(fp5.floor().to_string(), "1.000"); + let fp6: FixedDecimal = + FixedDecimal::from_str("-11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp6.floor().to_string(), + "-2.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("-14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.floor().to_string(), + "-2.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("-15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.floor().to_string(), + "-2.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("-1500", 3).unwrap(); + assert_eq!(fp4.floor().to_string(), "-2.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("-1499", 3).unwrap(); + assert_eq!(fp5.floor().to_string(), "-2.000"); + let fp6: FixedDecimal = FixedDecimal::from_str("1000", 3).unwrap(); + assert_eq!(fp6.floor().to_string(), "1.000"); + let fp7: FixedDecimal = FixedDecimal::from_str("-1000", 3).unwrap(); + assert_eq!(fp7.floor().to_string(), "-1.000"); + } + + #[test] + fn test_fixed_precision_ceil() { + let fp1: FixedDecimal = + FixedDecimal::from_str("11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp1.ceil().to_string(), + "2.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.ceil().to_string(), + "2.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.ceil().to_string(), + "2.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("1500", 3).unwrap(); + assert_eq!(fp4.ceil().to_string(), "2.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("1499", 3).unwrap(); + assert_eq!(fp5.ceil().to_string(), "2.000"); + let fp6: FixedDecimal = + FixedDecimal::from_str("-11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp6.ceil().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("-14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.ceil().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("-15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.ceil().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("-1500", 3).unwrap(); + assert_eq!(fp4.ceil().to_string(), "-1.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("-1499", 3).unwrap(); + assert_eq!(fp5.ceil().to_string(), "-1.000"); + let fp6: FixedDecimal = FixedDecimal::from_str("1000", 3).unwrap(); + assert_eq!(fp6.ceil().to_string(), "1.000"); + let fp7: FixedDecimal = FixedDecimal::from_str("-1000", 3).unwrap(); + assert_eq!(fp7.ceil().to_string(), "-1.000"); + } + + #[test] + fn test_fixed_precision_trunc() { + let fp1: FixedDecimal = + FixedDecimal::from_str("11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp1.trunc().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.trunc().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.trunc().to_string(), + "1.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("1500", 3).unwrap(); + assert_eq!(fp4.trunc().to_string(), "1.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("1499", 3).unwrap(); + assert_eq!(fp5.trunc().to_string(), "1.000"); + let fp6: FixedDecimal = + FixedDecimal::from_str("-11234567890123456789012345678901234", 34).unwrap(); + assert_eq!( + fp6.trunc().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp2: FixedDecimal = + FixedDecimal::from_str("-14999999999999999999999999999999999", 34).unwrap(); + assert_eq!( + fp2.trunc().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp3: FixedDecimal = + FixedDecimal::from_str("-15000000000000000000000000000000000", 34).unwrap(); + assert_eq!( + fp3.trunc().to_string(), + "-1.0000000000000000000000000000000000" + ); + let fp4: FixedDecimal = FixedDecimal::from_str("-1500", 3).unwrap(); + assert_eq!(fp4.trunc().to_string(), "-1.000"); + let fp5: FixedDecimal = FixedDecimal::from_str("-1499", 3).unwrap(); + assert_eq!(fp5.trunc().to_string(), "-1.000"); + let fp6: FixedDecimal = FixedDecimal::from_str("1000", 3).unwrap(); + assert_eq!(fp6.trunc().to_string(), "1.000"); + let fp7: FixedDecimal = FixedDecimal::from_str("-1000", 3).unwrap(); + assert_eq!(fp7.trunc().to_string(), "-1.000"); + } + #[test] fn golden_tests() { let mut data_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); @@ -172,20 +392,20 @@ mod tests { let file = File::open(data_path).expect("golden_tests_result.txt: file not found"); let result_reader = std::io::BufReader::new(file); - let one: Decimal = Decimal::from(1u64); - let ten: Decimal = Decimal::from(10u64); - let f: Decimal = &one / &ten; + let one: FixedDecimal = FixedDecimal::from(1u64); + let ten: FixedDecimal = FixedDecimal::from(10u64); + let f: FixedDecimal = &one / &ten; assert_eq!(f.to_string(), "0.1000000000000000000000000000000000"); for (test_line, result_line) in reader.lines().zip(result_reader.lines()) { let test_line = test_line.expect("failed to read line"); // println!("test_line: {}", test_line); let mut parts = test_line.split_whitespace(); - let x = Decimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) + let x = FixedDecimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) .expect("failed to parse x"); - let a = Decimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) + let a = FixedDecimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) .expect("failed to parse a"); - let b = Decimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) + let b = FixedDecimal::from_str(parts.next().unwrap(), DEFAULT_PRECISION) .expect("failed to parse b"); let result_line = result_line.expect("failed to read line"); // println!("result_line: {}", result_line); @@ -210,7 +430,13 @@ mod tests { let c = &one - &f; assert_eq!(c.to_string(), "0.9000000000000000000000000000000000"); let threshold_b = c.pow(&b); - assert_eq!((&one - &threshold_b).to_string(), expected_threshold_b); + assert_eq!( + (&one - &threshold_b).to_string(), + expected_threshold_b, + "(1 - f) *** b failed to match! - (1 - f)={}, b={}", + &c, + &b + ); // do Taylor approximation for // a < 1 - (1 - f) *** b <=> 1/(1-a) < exp(-b * ln' (1 - f)) diff --git a/pallas-math/src/math_gmp.rs b/pallas-math/src/math_gmp.rs deleted file mode 100644 index 3a3b2e22..00000000 --- a/pallas-math/src/math_gmp.rs +++ /dev/null @@ -1,1062 +0,0 @@ -/*! -# Cardano Math functions using the GNU Multiple Precision Arithmetic Library (GMP) - */ - -use std::cmp::Ordering; -use std::ffi::{CStr, CString}; -use std::fmt::{Display, Formatter}; -use std::mem::MaybeUninit; -use std::ops::{Div, Mul, Neg, Sub}; -use std::ptr::null_mut; - -use gmp_mpfr_sys::gmp::{ - mpz_add, mpz_cdiv_q, mpz_clear, mpz_cmp, mpz_cmpabs, mpz_get_str, mpz_get_ui, mpz_init, - mpz_init_set_ui, mpz_mul, mpz_mul_si, mpz_mul_ui, mpz_neg, mpz_pow_ui, mpz_ptr, mpz_set, - mpz_set_si, mpz_set_str, mpz_set_ui, mpz_srcptr, mpz_sub, mpz_sub_ui, mpz_t, mpz_tdiv_q_ui, - mpz_tdiv_qr, -}; -use gmp_mpfr_sys::mpc::free_str; -use once_cell::sync::Lazy; -use regex::Regex; - -use crate::math::{Error, ExpCmpOrdering, ExpOrdering, FixedPrecision, DEFAULT_PRECISION}; - -#[derive(Debug, Clone)] -pub struct Decimal { - precision: u64, - precision_multiplier: mpz_t, - data: mpz_t, -} - -impl Drop for Decimal { - fn drop(&mut self) { - unsafe { - mpz_clear(&mut self.precision_multiplier); - mpz_clear(&mut self.data); - } - } -} - -impl PartialEq for Decimal { - fn eq(&self, other: &Self) -> bool { - unsafe { - self.precision == other.precision - && mpz_cmp(&self.precision_multiplier, &other.precision_multiplier) == 0 - && mpz_cmp(&self.data, &other.data) == 0 - } - } -} - -impl PartialOrd for Decimal { - fn partial_cmp(&self, other: &Self) -> Option { - unsafe { - if self.precision != other.precision - || mpz_cmp(&self.precision_multiplier, &other.precision_multiplier) != 0 - { - return None; - } - match mpz_cmp(&self.data, &other.data) { - cmp if cmp < 0 => Some(Ordering::Less), - cmp if cmp > 0 => Some(Ordering::Greater), - _ => Some(Ordering::Equal), - } - } - } -} - -impl Display for Decimal { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - unsafe { - write!( - f, - "{}", - print_fixedp( - &self.data, - &self.precision_multiplier, - self.precision as usize, - ) - ) - } - } -} - -impl From for Decimal { - fn from(n: u64) -> Self { - unsafe { - let mut result = Decimal::new(DEFAULT_PRECISION); - mpz_set_ui(&mut result.data, n); - mpz_mul(&mut result.data, &result.data, &result.precision_multiplier); - result - } - } -} - -impl From for Decimal { - fn from(n: i64) -> Self { - unsafe { - let mut result = Decimal::new(DEFAULT_PRECISION); - mpz_set_si(&mut result.data, n); - mpz_mul(&mut result.data, &result.data, &result.precision_multiplier); - result - } - } -} - -impl From<&mpz_t> for Decimal { - fn from(n: &mpz_t) -> Self { - unsafe { - let mut result = Decimal::new(DEFAULT_PRECISION); - mpz_set(&mut result.data, n); - result - } - } -} - -impl Neg for Decimal { - type Output = Self; - - fn neg(self) -> Self::Output { - unsafe { - let mut result = Decimal::new(self.precision); - mpz_neg(&mut result.data, &self.data); - result - } - } -} - -impl Mul for Decimal { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - unsafe { - let mut result = Decimal::new(self.precision); - mpz_mul(&mut result.data, &self.data, &rhs.data); - scale(&mut result.data); - result - } - } -} - -// Implement Mul for a reference to Decimal -impl<'a, 'b> Mul<&'b Decimal> for &'a Decimal { - type Output = Decimal; - - fn mul(self, rhs: &'b Decimal) -> Self::Output { - unsafe { - let mut result = Decimal::new(self.precision); - mpz_mul(&mut result.data, &self.data, &rhs.data); - scale(&mut result.data); - result - } - } -} - -impl Div for Decimal { - type Output = Self; - - fn div(self, rhs: Self) -> Self::Output { - unsafe { - let mut result = Decimal::new(self.precision); - div(&mut result.data, &self.data, &rhs.data); - result - } - } -} - -// Implement Div for a reference to Decimal -impl<'a, 'b> Div<&'b Decimal> for &'a Decimal { - type Output = Decimal; - - fn div(self, rhs: &'b Decimal) -> Self::Output { - unsafe { - let mut result = Decimal::new(self.precision); - div(&mut result.data, &self.data, &rhs.data); - result - } - } -} - -impl Sub for Decimal { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - unsafe { - let mut result = Decimal::new(self.precision); - mpz_sub(&mut result.data, &self.data, &rhs.data); - result - } - } -} - -// Implement Sub for a reference to Decimal -impl<'a, 'b> Sub<&'b Decimal> for &'a Decimal { - type Output = Decimal; - - fn sub(self, rhs: &'b Decimal) -> Self::Output { - unsafe { - let mut result = Decimal::new(self.precision); - mpz_sub(&mut result.data, &self.data, &rhs.data); - result - } - } -} - -impl FixedPrecision for Decimal { - fn new(precision: u64) -> Self { - unsafe { - let precision_multiplier: mpz_t = { - let mut precision_multiplier = MaybeUninit::uninit(); - mpz_init(precision_multiplier.as_mut_ptr()); - mpz_pow_ui(precision_multiplier.as_mut_ptr(), &TEN.value, precision); - precision_multiplier.assume_init() - }; - let data: mpz_t = { - let mut data = MaybeUninit::uninit(); - mpz_init_set_ui(data.as_mut_ptr(), 0); - data.assume_init() - }; - Decimal { - precision, - precision_multiplier, - data, - } - } - } - - fn from_str(s: &str, precision: u64) -> Result { - unsafe { - // assert that s contains only digits using a regex - if !DIGITS_REGEX.is_match(s) { - return Err(Error::RegexFailure(regex::Error::Syntax( - "string contained non-digits".to_string(), - ))); - } - - let mut decimal = Decimal::new(precision); - let c_string = CString::new(s)?; - mpz_set_str(&mut decimal.data, c_string.as_ptr(), 10); - Ok(decimal) - } - } - - fn precision(&self) -> u64 { - self.precision - } - - fn exp(&self) -> Self { - unsafe { - let mut exp_x = Decimal::new(self.precision); - ref_exp(&mut exp_x.data, &self.data); - exp_x - } - } - - fn ln(&self) -> Self { - unsafe { - let mut ln_x = Decimal::new(self.precision); - ref_ln(&mut ln_x.data, &self.data); - ln_x - } - } - - fn pow(&self, rhs: &Self) -> Self { - unsafe { - let mut pow_x = Decimal::new(self.precision); - ref_pow(&mut pow_x.data, &self.data, &rhs.data); - pow_x - } - } - - fn exp_cmp(&self, max_n: u64, bound_self: i64, compare: &Self) -> ExpCmpOrdering { - unsafe { - let mut output = Decimal::new(self.precision); - ref_exp_cmp( - &mut output.data, - max_n, - &self.data, - bound_self, - &compare.data, - ) - } - } -} - -/// # Safety -/// This function is unsafe because it dereferences raw pointers. -unsafe fn print_fixedp(n: &mpz_t, precision: &mpz_t, width: usize) -> String { - let mut temp_r: mpz_t = { - let mut temp_r = MaybeUninit::uninit(); - mpz_init(temp_r.as_mut_ptr()); - temp_r.assume_init() - }; - let mut temp_q: mpz_t = { - let mut temp_q = MaybeUninit::uninit(); - mpz_init(temp_q.as_mut_ptr()); - temp_q.assume_init() - }; - // use truncate rounding here for consistency - mpz_tdiv_qr(&mut temp_q, &mut temp_r, n, precision); - - let is_negative_q = mpz_cmp(&temp_q, &ZERO.value) < 0; - let is_negative_r = mpz_cmp(&temp_r, &ZERO.value) < 0; - - if is_negative_q { - mpz_neg(&mut temp_q, &temp_q); - } - if is_negative_r { - mpz_neg(&mut temp_r, &temp_r); - } - - let mut s = String::new(); - if is_negative_q || is_negative_r { - s.push('-'); - } - let q_char_c = mpz_get_str(null_mut(), 10, &temp_q); - let r_char_c = mpz_get_str(null_mut(), 10, &temp_r); - let q_cstr = CStr::from_ptr(q_char_c); - let r_cstr = CStr::from_ptr(r_char_c); - let r_len = r_cstr.to_bytes().len(); - s.push_str(q_cstr.to_str().unwrap()); - s.push('.'); - // fill with zeroes up to width for the fractional part - if r_len < width { - s.push_str(&"0".repeat(width - r_len)); - } - s.push_str(r_cstr.to_str().unwrap()); - - free_str(q_char_c); - free_str(r_char_c); - - mpz_clear(&mut temp_r); - mpz_clear(&mut temp_q); - - s -} - -struct Constant { - value: mpz_t, -} - -impl Constant { - pub fn new(init: fn() -> mpz_t) -> Constant { - Constant { value: init() } - } -} - -impl Drop for Constant { - fn drop(&mut self) { - unsafe { - mpz_clear(&mut self.value); - } - } -} - -unsafe impl Sync for Constant {} -unsafe impl Send for Constant {} - -static DIGITS_REGEX: Lazy = Lazy::new(|| Regex::new(r"^-?\d+$").unwrap()); - -static TEN: Lazy = Lazy::new(|| { - Constant::new(|| unsafe { - let mut ten: mpz_t = { - let mut ten = MaybeUninit::uninit(); - mpz_init(ten.as_mut_ptr()); - ten.assume_init() - }; - mpz_set_ui(&mut ten, 10); - ten - }) -}); - -static PRECISION: Lazy = Lazy::new(|| { - Constant::new(|| unsafe { - let mut precision: mpz_t = { - let mut precision = MaybeUninit::uninit(); - mpz_init(precision.as_mut_ptr()); - precision.assume_init() - }; - mpz_pow_ui(&mut precision, &TEN.value, 34); - precision - }) -}); - -static EPS: Lazy = Lazy::new(|| { - Constant::new(|| unsafe { - let mut epsilon: mpz_t = { - let mut epsilon = MaybeUninit::uninit(); - mpz_init(epsilon.as_mut_ptr()); - epsilon.assume_init() - }; - mpz_pow_ui(&mut epsilon, &TEN.value, 34 - 24); - epsilon - }) -}); - -static _RESOLUTION: Lazy = Lazy::new(|| { - Constant::new(|| unsafe { - let mut resolution: mpz_t = { - let mut resolution = MaybeUninit::uninit(); - mpz_init(resolution.as_mut_ptr()); - resolution.assume_init() - }; - mpz_pow_ui(&mut resolution, &TEN.value, 17); - resolution - }) -}); - -static ONE: Lazy = Lazy::new(|| { - Constant::new(|| unsafe { - let mut one: mpz_t = { - let mut one = MaybeUninit::uninit(); - mpz_init(one.as_mut_ptr()); - one.assume_init() - }; - mpz_set_ui(&mut one, 1); - mpz_mul(&mut one, &one, &PRECISION.value); - one - }) -}); - -static ZERO: Lazy = Lazy::new(|| { - Constant::new(|| unsafe { - let mut zero: mpz_t = { - let mut zero = MaybeUninit::uninit(); - mpz_init(zero.as_mut_ptr()); - zero.assume_init() - }; - mpz_set_ui(&mut zero, 0); - zero - }) -}); - -static E: Lazy = Lazy::new(|| { - Constant::new(|| unsafe { - let mut e: mpz_t = { - let mut e = MaybeUninit::uninit(); - mpz_init(e.as_mut_ptr()); - e.assume_init() - }; - ref_exp(&mut e, &ONE.value); - e - }) -}); - -/// Entry point for 'exp' approximation. First does the scaling of 'x' to [0,1] -/// and then calls the continued fraction approximation function. -/// -/// # Safety -/// This function is unsafe because it dereferences raw pointers. -unsafe fn ref_exp(rop: mpz_ptr, x: mpz_srcptr) -> i32 { - let mut iterations = 0; - - match mpz_cmp(x, &ZERO.value) { - 0 => mpz_set(rop, &ONE.value), - v if v < 0 => { - let mut x_: mpz_t = { - let mut x_ = MaybeUninit::uninit(); - mpz_init(x_.as_mut_ptr()); - x_.assume_init() - }; - mpz_neg(&mut x_, x); - let mut temp: mpz_t = { - let mut temp = MaybeUninit::uninit(); - mpz_init(temp.as_mut_ptr()); - temp.assume_init() - }; - - iterations = ref_exp(&mut temp, &x_); - - div(rop, &ONE.value, &temp); - - mpz_clear(&mut x_); - mpz_clear(&mut temp); - } - _ => { - let mut n_exponent: mpz_t = { - let mut n_exponent = MaybeUninit::uninit(); - mpz_init(n_exponent.as_mut_ptr()); - n_exponent.assume_init() - }; - let mut x_: mpz_t = { - let mut x_ = MaybeUninit::uninit(); - mpz_init(x_.as_mut_ptr()); - x_.assume_init() - }; - let mut temp_r: mpz_t = { - let mut temp_r = MaybeUninit::uninit(); - mpz_init(temp_r.as_mut_ptr()); - temp_r.assume_init() - }; - let mut temp_q: mpz_t = { - let mut temp_q = MaybeUninit::uninit(); - mpz_init(temp_q.as_mut_ptr()); - temp_q.assume_init() - }; - - mpz_cdiv_q(&mut n_exponent, x, &PRECISION.value); - let n = mpz_get_ui(&n_exponent); - mpz_mul(&mut n_exponent, &n_exponent, &PRECISION.value); /* ceil(x) */ - - mpz_tdiv_q_ui(&mut x_, x, n); - iterations = mp_exp_taylor(rop, 1000, &x_, &EPS.value); - - ipow(rop, &*rop, n as i64); - mpz_clear(&mut n_exponent); - mpz_clear(&mut x_); - mpz_clear(&mut temp_r); - mpz_clear(&mut temp_q); - } - } - - iterations -} - -/// Division with quotent and remainder -/// -/// # Safety -/// This function is unsafe because it dereferences raw pointers. -pub unsafe fn div_qr(q: mpz_ptr, r: mpz_ptr, x: &mpz_t, y: &mpz_t) { - let mut temp_r: mpz_t = { - let mut temp_r = MaybeUninit::uninit(); - mpz_init(temp_r.as_mut_ptr()); - temp_r.assume_init() - }; - let mut temp_q: mpz_t = { - let mut temp_q = MaybeUninit::uninit(); - mpz_init(temp_q.as_mut_ptr()); - temp_q.assume_init() - }; - mpz_tdiv_qr(&mut temp_q, &mut temp_r, x, y); - mpz_set(r, &temp_r); - mpz_set(q, &temp_q); - - mpz_clear(&mut temp_r); - mpz_clear(&mut temp_q); -} - -/// Division -/// -/// # Safety -/// This function is unsafe because it dereferences raw pointers. -pub unsafe fn div(rop: mpz_ptr, x: &mpz_t, y: &mpz_t) { - let mut temp_r: mpz_t = { - let mut temp_r = MaybeUninit::uninit(); - mpz_init(temp_r.as_mut_ptr()); - temp_r.assume_init() - }; - let mut temp_q: mpz_t = { - let mut temp_q = MaybeUninit::uninit(); - mpz_init(temp_q.as_mut_ptr()); - temp_q.assume_init() - }; - let mut temp: mpz_t = { - let mut temp = MaybeUninit::uninit(); - mpz_init(temp.as_mut_ptr()); - temp.assume_init() - }; - - div_qr(&mut temp_q, &mut temp_r, x, y); - - mpz_mul(&mut temp, &temp_q, &PRECISION.value); - mpz_mul(&mut temp_r, &temp_r, &PRECISION.value); - div_qr(&mut temp_q, &mut temp_r, &temp_r, y); - - mpz_add(&mut temp, &temp, &temp_q); - mpz_set(rop, &temp); - - mpz_clear(&mut temp_r); - mpz_clear(&mut temp_q); - mpz_clear(&mut temp); -} -/// Taylor / MacLaurin series approximation -/// -/// # Safety -/// This function is unsafe because it dereferences raw pointers. -pub unsafe fn mp_exp_taylor(rop: mpz_ptr, max_n: i32, x: &mpz_t, epsilon: &mpz_t) -> i32 { - let mut divisor: mpz_t = { - let mut divisor = MaybeUninit::uninit(); - mpz_init(divisor.as_mut_ptr()); - divisor.assume_init() - }; - mpz_set(&mut divisor, &ONE.value); - let mut last_x: mpz_t = { - let mut last_x = MaybeUninit::uninit(); - mpz_init(last_x.as_mut_ptr()); - last_x.assume_init() - }; - mpz_set(&mut last_x, &ONE.value); - let mut next_x: mpz_t = { - let mut next_x = MaybeUninit::uninit(); - mpz_init(next_x.as_mut_ptr()); - next_x.assume_init() - }; - mpz_set(rop, &ONE.value); - let mut n = 0; - while n < max_n { - mpz_mul(&mut next_x, x, &last_x); - scale(&mut next_x); - div(&mut next_x, &next_x, &divisor); - - if mpz_cmpabs(&next_x, epsilon) < 0 { - break; - } - - mpz_add(&mut divisor, &divisor, &ONE.value); - mpz_add(rop, rop, &next_x); - - mpz_set(&mut last_x, &next_x); - n += 1; - } - - mpz_clear(&mut divisor); - mpz_clear(&mut last_x); - mpz_clear(&mut next_x); - n -} - -/// #Safety -/// This function is unsafe because it dereferences raw pointers. -unsafe fn scale(rop: mpz_ptr) { - let mut temp: mpz_t = { - let mut temp = MaybeUninit::uninit(); - mpz_init(temp.as_mut_ptr()); - temp.assume_init() - }; - let mut a: mpz_t = { - let mut a = MaybeUninit::uninit(); - mpz_init(a.as_mut_ptr()); - a.assume_init() - }; - - div_qr(&mut a, &mut temp, &*rop, &PRECISION.value); - if mpz_cmp(rop, &ZERO.value) < 0 && mpz_cmp(&temp, &ZERO.value) != 0 { - mpz_sub_ui(&mut a, &a, 1); - } - - mpz_set(rop, &a); - mpz_clear(&mut temp); - mpz_clear(&mut a); -} - -/// Integer power internal function -/// -/// # Safety -/// This function is unsafe because it dereferences raw pointers. -unsafe fn ipow_(rop: mpz_ptr, x: &mpz_t, n: i64) { - if n == 0 { - mpz_set(rop, &ONE.value); - } else if n % 2 == 0 { - let mut res: mpz_t = { - let mut res = MaybeUninit::uninit(); - mpz_init(res.as_mut_ptr()); - res.assume_init() - }; - ipow_(&mut res, x, n / 2); - mpz_mul(rop, &res, &res); - scale(rop); - mpz_clear(&mut res); - } else { - let mut res: mpz_t = { - let mut res = MaybeUninit::uninit(); - mpz_init(res.as_mut_ptr()); - res.assume_init() - }; - ipow_(&mut res, x, n - 1); - mpz_mul(rop, &res, x); - scale(rop); - mpz_clear(&mut res); - } -} - -/// Integer power -/// -/// # Safety -/// This function is unsafe because it dereferences raw pointers. -pub unsafe fn ipow(rop: mpz_ptr, x: &mpz_t, n: i64) { - if n < 0 { - let mut temp: mpz_t = { - let mut temp = MaybeUninit::uninit(); - mpz_init(temp.as_mut_ptr()); - temp.assume_init() - }; - ipow_(&mut temp, x, -n); - div(rop, &ONE.value, &temp); - mpz_clear(&mut temp); - } else { - ipow_(rop, x, n); - } -} - -/// Compute an approximation of 'ln(1 + x)' via continued fractions. Either for a -/// maximum of 'maxN' iterations or until the absolute difference between two -/// succeeding convergents is smaller than 'eps'. Assumes 'x' to be within -/// [1,e). -unsafe fn mp_ln_n(rop: mpz_ptr, max_n: i32, x: &mpz_t, epsilon: &mpz_t) { - let mut an_m2: mpz_t = { - let mut an_m2 = MaybeUninit::uninit(); - mpz_init(an_m2.as_mut_ptr()); - an_m2.assume_init() - }; - let mut bn_m2: mpz_t = { - let mut bn_m2 = MaybeUninit::uninit(); - mpz_init(bn_m2.as_mut_ptr()); - bn_m2.assume_init() - }; - let mut an_m1: mpz_t = { - let mut an_m1 = MaybeUninit::uninit(); - mpz_init(an_m1.as_mut_ptr()); - an_m1.assume_init() - }; - let mut bn_m1: mpz_t = { - let mut bn_m1 = MaybeUninit::uninit(); - mpz_init(bn_m1.as_mut_ptr()); - bn_m1.assume_init() - }; - let mut ba: mpz_t = { - let mut ba = MaybeUninit::uninit(); - mpz_init(ba.as_mut_ptr()); - ba.assume_init() - }; - let mut aa: mpz_t = { - let mut aa = MaybeUninit::uninit(); - mpz_init(aa.as_mut_ptr()); - aa.assume_init() - }; - let mut a_: mpz_t = { - let mut a_ = MaybeUninit::uninit(); - mpz_init(a_.as_mut_ptr()); - a_.assume_init() - }; - let mut bb: mpz_t = { - let mut bb = MaybeUninit::uninit(); - mpz_init(bb.as_mut_ptr()); - bb.assume_init() - }; - let mut ab: mpz_t = { - let mut ab = MaybeUninit::uninit(); - mpz_init(ab.as_mut_ptr()); - ab.assume_init() - }; - let mut b_: mpz_t = { - let mut b_ = MaybeUninit::uninit(); - mpz_init(b_.as_mut_ptr()); - b_.assume_init() - }; - let mut convergent: mpz_t = { - let mut convergent = MaybeUninit::uninit(); - mpz_init(convergent.as_mut_ptr()); - convergent.assume_init() - }; - let mut last: mpz_t = { - let mut last = MaybeUninit::uninit(); - mpz_init(last.as_mut_ptr()); - last.assume_init() - }; - let mut a: mpz_t = { - let mut a = MaybeUninit::uninit(); - mpz_init(a.as_mut_ptr()); - a.assume_init() - }; - let mut b: mpz_t = { - let mut b = MaybeUninit::uninit(); - mpz_init(b.as_mut_ptr()); - b.assume_init() - }; - let mut diff: mpz_t = { - let mut diff = MaybeUninit::uninit(); - mpz_init(diff.as_mut_ptr()); - diff.assume_init() - }; - - let mut first = true; - let mut n = 1; - - mpz_set(&mut a, x); - mpz_set(&mut b, &ONE.value); - - mpz_set(&mut an_m2, &ONE.value); - mpz_set_ui(&mut bn_m2, 0); - mpz_set_ui(&mut an_m1, 0); - mpz_set(&mut bn_m1, &ONE.value); - - let mut curr_a = 1; - - while n <= max_n + 2 { - let curr_a_2 = curr_a * curr_a; - mpz_mul_ui(&mut a, x, curr_a_2); - if n > 1 && n % 2 == 1 { - curr_a += 1; - } - - mpz_mul(&mut ba, &b, &an_m1); - scale(&mut ba); - mpz_mul(&mut aa, &a, &an_m2); - scale(&mut aa); - mpz_add(&mut a_, &ba, &aa); - - mpz_mul(&mut bb, &b, &bn_m1); - scale(&mut bb); - mpz_mul(&mut ab, &a, &bn_m2); - scale(&mut ab); - mpz_add(&mut b_, &bb, &ab); - - div(&mut convergent, &a_, &b_); - - if first { - first = false; - } else { - mpz_sub(&mut diff, &convergent, &last); - if mpz_cmpabs(&diff, epsilon) < 0 { - break; - } - } - - mpz_set(&mut last, &convergent); - - n += 1; - mpz_set(&mut an_m2, &an_m1); - mpz_set(&mut bn_m2, &bn_m1); - mpz_set(&mut an_m1, &a_); - mpz_set(&mut bn_m1, &b_); - - mpz_add(&mut b, &b, &ONE.value); - } - - mpz_set(rop, &convergent); - - mpz_clear(&mut an_m2); - mpz_clear(&mut bn_m2); - mpz_clear(&mut an_m1); - mpz_clear(&mut bn_m1); - mpz_clear(&mut ba); - mpz_clear(&mut aa); - mpz_clear(&mut bb); - mpz_clear(&mut ab); - mpz_clear(&mut a_); - mpz_clear(&mut b_); - mpz_clear(&mut a); - mpz_clear(&mut b); - mpz_clear(&mut diff); - mpz_clear(&mut convergent); - mpz_clear(&mut last); -} - -unsafe fn find_e(x: &mpz_t) -> i64 { - let mut x_: mpz_t = { - let mut x_ = MaybeUninit::uninit(); - mpz_init(x_.as_mut_ptr()); - x_.assume_init() - }; - let mut x__: mpz_t = { - let mut x__ = MaybeUninit::uninit(); - mpz_init(x__.as_mut_ptr()); - x__.assume_init() - }; - - div(&mut x_, &ONE.value, &E.value); - mpz_set(&mut x__, &E.value); - - let mut l = -1; - let mut u = 1; - while mpz_cmp(&x_, x) > 0 || mpz_cmp(&x__, x) < 0 { - mpz_mul(&mut x_, &x_, &x_); - scale(&mut x_); - - mpz_mul(&mut x__, &x__, &x__); - scale(&mut x__); - - l *= 2; - u *= 2; - } - - while l + 1 != u { - let mid = l + ((u - l) / 2); - - ipow(&mut x_, &E.value, mid); - if mpz_cmp(x, &x_) < 0 { - u = mid; - } else { - l = mid; - } - } - - mpz_clear(&mut x_); - mpz_clear(&mut x__); - l -} - -/// Entry point for 'ln' approximation. First does the necessary scaling, and -/// then calls the continued fraction calculation. For any value outside the -/// domain, i.e., 'x in (-inf,0]', the function returns '-INFINITY'. -unsafe fn ref_ln(rop: mpz_ptr, x: &mpz_t) -> bool { - if mpz_cmp(x, &ZERO.value) <= 0 { - return false; - } - - let n = find_e(x); - - let mut temp_r: mpz_t = { - let mut temp_r = MaybeUninit::uninit(); - mpz_init(temp_r.as_mut_ptr()); - temp_r.assume_init() - }; - let mut temp_q: mpz_t = { - let mut temp_q = MaybeUninit::uninit(); - mpz_init(temp_q.as_mut_ptr()); - temp_q.assume_init() - }; - let mut x_: mpz_t = { - let mut x_ = MaybeUninit::uninit(); - mpz_init(x_.as_mut_ptr()); - x_.assume_init() - }; - let mut factor: mpz_t = { - let mut factor = MaybeUninit::uninit(); - mpz_init(factor.as_mut_ptr()); - factor.assume_init() - }; - - mpz_set_si(rop, n); - mpz_mul(rop, rop, &PRECISION.value); - ref_exp(&mut factor, rop); - - div(&mut x_, x, &factor); - - mpz_sub(&mut x_, &x_, &ONE.value); - - mp_ln_n(&mut x_, 1000, &x_, &EPS.value); - mpz_add(rop, rop, &x_); - - mpz_clear(&mut temp_r); - mpz_clear(&mut temp_q); - mpz_clear(&mut x_); - mpz_clear(&mut factor); - - true -} - -unsafe fn ref_pow(rop: mpz_ptr, base: &mpz_t, exponent: &mpz_t) { - /* x^y = exp(y * ln x) */ - - let mut tmp: mpz_t = { - let mut tmp = MaybeUninit::uninit(); - mpz_init(tmp.as_mut_ptr()); - tmp.assume_init() - }; - - ref_ln(&mut tmp, base); - mpz_mul(&mut tmp, &tmp, exponent); - scale(&mut tmp); - ref_exp(rop, &tmp); - - mpz_clear(&mut tmp); -} - -/// `bound_x` is the bound for exp in the interval x is chosen from -/// `compare` the value to compare to -/// -/// if the result is GT, then the computed value is guaranteed to be greater, if -/// the result is LT, the computed value is guaranteed to be less than -/// `compare`. In the case of `UNKNOWN` no conclusion was possible for the -/// selected precision. -/// -/// Lagrange remainder require knowledge of the maximum value to compute the -/// maximal error of the remainder. -unsafe fn ref_exp_cmp( - rop: mpz_ptr, - max_n: u64, - x: &mpz_t, - bound_x: i64, - compare: &mpz_t, -) -> ExpCmpOrdering { - mpz_set(rop, &ONE.value); - let mut n = 0u64; - let mut divisor: mpz_t = { - let mut divisor = MaybeUninit::uninit(); - mpz_init(divisor.as_mut_ptr()); - divisor.assume_init() - }; - let mut next_x: mpz_t = { - let mut next_x = MaybeUninit::uninit(); - mpz_init(next_x.as_mut_ptr()); - next_x.assume_init() - }; - let mut error: mpz_t = { - let mut error = MaybeUninit::uninit(); - mpz_init(error.as_mut_ptr()); - error.assume_init() - }; - let mut upper: mpz_t = { - let mut upper = MaybeUninit::uninit(); - mpz_init(upper.as_mut_ptr()); - upper.assume_init() - }; - let mut lower: mpz_t = { - let mut lower = MaybeUninit::uninit(); - mpz_init(lower.as_mut_ptr()); - lower.assume_init() - }; - let mut error_term: mpz_t = { - let mut error_term = MaybeUninit::uninit(); - mpz_init(error_term.as_mut_ptr()); - error_term.assume_init() - }; - - mpz_set(&mut divisor, &ONE.value); - mpz_set(&mut error, x); - - let mut estimate = ExpOrdering::UNKNOWN; - while n < max_n { - mpz_set(&mut next_x, &error); - - if mpz_cmpabs(&next_x, &EPS.value) < 0 { - break; - } - - mpz_add(&mut divisor, &divisor, &ONE.value); - - // update error estimation, this is initially bound_x * x and in general - // bound_x * x^(n+1)/(n + 1)! we use `error` to store the x^n part and a - // single integral multiplication with the bound - mpz_mul(&mut error, &error, x); - scale(&mut error); - div(&mut error, &error, &divisor); - - mpz_mul_si(&mut error_term, &error, bound_x); - - mpz_add(rop, rop, &next_x); - - /* compare is guaranteed to be above overall result */ - mpz_add(&mut upper, rop, &error_term); - - if mpz_cmp(compare, &upper) > 0 { - estimate = ExpOrdering::GT; - n += 1; - break; - } - - mpz_sub(&mut lower, rop, &error_term); - - /* compare is guaranteed to be below overall result */ - if mpz_cmp(compare, &lower) < 0 { - estimate = ExpOrdering::LT; - n += 1; - break; - } - - n += 1; - } - - mpz_clear(&mut divisor); - mpz_clear(&mut next_x); - mpz_clear(&mut error); - mpz_clear(&mut upper); - mpz_clear(&mut lower); - mpz_clear(&mut error_term); - - ExpCmpOrdering { - iterations: n, - estimation: estimate, - approx: Decimal::from(&*rop), - } -} diff --git a/pallas-math/src/math_num.rs b/pallas-math/src/math_malachite.rs similarity index 59% rename from pallas-math/src/math_num.rs rename to pallas-math/src/math_malachite.rs index 0ae7ae37..824c46eb 100644 --- a/pallas-math/src/math_num.rs +++ b/pallas-math/src/math_malachite.rs @@ -2,24 +2,25 @@ # Cardano Math functions using the num-bigint crate */ +use crate::math::{Error, ExpCmpOrdering, ExpOrdering, FixedPrecision, DEFAULT_PRECISION}; +use malachite::num::arithmetic::traits::{Abs, DivRem, DivRound, Pow, PowAssign}; +use malachite::num::basic::traits::One; +use malachite::platform_64::Limb; +use malachite::rounding_modes::RoundingMode; +use malachite::{Integer, Natural}; +use malachite_base::num::arithmetic::traits::Sign; +use once_cell::sync::Lazy; +use regex::Regex; use std::cmp::Ordering; use std::fmt::{Display, Formatter}; -use std::ops::{Div, Mul, Neg, Sub}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}; use std::str::FromStr; -use num_bigint::BigInt; -use num_integer::Integer; -use num_traits::{Signed, ToPrimitive}; -use once_cell::sync::Lazy; -use regex::Regex; - -use crate::math::{Error, ExpCmpOrdering, ExpOrdering, FixedPrecision, DEFAULT_PRECISION}; - #[derive(Debug, Clone)] pub struct Decimal { precision: u64, - precision_multiplier: BigInt, - data: BigInt, + precision_multiplier: Integer, + data: Integer, } impl PartialEq for Decimal { @@ -58,7 +59,7 @@ impl Display for Decimal { impl From for Decimal { fn from(n: u64) -> Self { let mut result = Decimal::new(DEFAULT_PRECISION); - result.data = BigInt::from(n) * &result.precision_multiplier; + result.data = Integer::from(n) * &result.precision_multiplier; result } } @@ -66,19 +67,53 @@ impl From for Decimal { impl From for Decimal { fn from(n: i64) -> Self { let mut result = Decimal::new(DEFAULT_PRECISION); - result.data = BigInt::from(n) * &result.precision_multiplier; + result.data = Integer::from(n) * &result.precision_multiplier; + result + } +} + +impl From for Decimal { + fn from(n: Integer) -> Self { + let mut result = Decimal::new(DEFAULT_PRECISION); + result.data = n * &result.precision_multiplier; + result + } +} + +impl From<&Integer> for Decimal { + fn from(n: &Integer) -> Self { + let mut result = Decimal::new(DEFAULT_PRECISION); + result.data = n * &result.precision_multiplier; + result + } +} + +impl From for Decimal { + fn from(n: Natural) -> Self { + let mut result = Decimal::new(DEFAULT_PRECISION); + result.data = Integer::from(n) * &result.precision_multiplier; result } } -impl From<&BigInt> for Decimal { - fn from(n: &BigInt) -> Self { +impl From<&Natural> for Decimal { + fn from(n: &Natural) -> Self { let mut result = Decimal::new(DEFAULT_PRECISION); - result.data.clone_from(n); + result.data = Integer::from(n) * &result.precision_multiplier; result } } +impl From<&[u8]> for Decimal { + fn from(n: &[u8]) -> Self { + let limbs = n + .chunks(size_of::()) + .map(|chunk| Limb::from_be_bytes(chunk.try_into().expect("Infallible"))) + .collect(); + Decimal::from(Natural::from_owned_limbs_desc(limbs)) + } +} + impl Neg for Decimal { type Output = Self; @@ -89,6 +124,17 @@ impl Neg for Decimal { } } +// Implement Neg for a reference to Decimal +impl<'a> Neg for &'a Decimal { + type Output = Decimal; + + fn neg(self) -> Self::Output { + let mut result = Decimal::new(self.precision); + result.data = -&self.data; + result + } +} + impl Mul for Decimal { type Output = Self; @@ -100,6 +146,13 @@ impl Mul for Decimal { } } +impl MulAssign for Decimal { + fn mul_assign(&mut self, rhs: Self) { + self.data *= &rhs.data; + scale(&mut self.data); + } +} + // Implement Mul for a reference to Decimal impl<'a, 'b> Mul<&'b Decimal> for &'a Decimal { type Output = Decimal; @@ -112,6 +165,13 @@ impl<'a, 'b> Mul<&'b Decimal> for &'a Decimal { } } +impl<'a, 'b> MulAssign<&'b Decimal> for &'a mut Decimal { + fn mul_assign(&mut self, rhs: &'b Decimal) { + self.data *= &rhs.data; + scale(&mut self.data); + } +} + impl Div for Decimal { type Output = Self; @@ -122,6 +182,13 @@ impl Div for Decimal { } } +impl DivAssign for Decimal { + fn div_assign(&mut self, rhs: Self) { + let temp = self.data.clone(); + div(&mut self.data, &temp, &rhs.data); + } +} + // Implement Div for a reference to Decimal impl<'a, 'b> Div<&'b Decimal> for &'a Decimal { type Output = Decimal; @@ -133,6 +200,13 @@ impl<'a, 'b> Div<&'b Decimal> for &'a Decimal { } } +impl<'a, 'b> DivAssign<&'b Decimal> for &'a mut Decimal { + fn div_assign(&mut self, rhs: &'b Decimal) { + let temp = self.data.clone(); + div(&mut self.data, &temp, &rhs.data); + } +} + impl Sub for Decimal { type Output = Self; @@ -143,6 +217,12 @@ impl Sub for Decimal { } } +impl SubAssign for Decimal { + fn sub_assign(&mut self, rhs: Self) { + self.data -= &rhs.data; + } +} + // Implement Sub for a reference to Decimal impl<'a, 'b> Sub<&'b Decimal> for &'a Decimal { type Output = Decimal; @@ -154,11 +234,50 @@ impl<'a, 'b> Sub<&'b Decimal> for &'a Decimal { } } +impl<'a, 'b> SubAssign<&'b Decimal> for &'a mut Decimal { + fn sub_assign(&mut self, rhs: &'b Decimal) { + self.data -= &rhs.data; + } +} + +impl Add for Decimal { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + let mut result = Decimal::new(self.precision); + result.data = &self.data + &rhs.data; + result + } +} + +impl AddAssign for Decimal { + fn add_assign(&mut self, rhs: Self) { + self.data += &rhs.data; + } +} + +// Implement Add for a reference to Decimal +impl<'a, 'b> Add<&'b Decimal> for &'a Decimal { + type Output = Decimal; + + fn add(self, rhs: &'b Decimal) -> Self::Output { + let mut result = Decimal::new(self.precision); + result.data = &self.data + &rhs.data; + result + } +} + +impl<'a, 'b> AddAssign<&'b Decimal> for &'a mut Decimal { + fn add_assign(&mut self, rhs: &'b Decimal) { + self.data += &rhs.data; + } +} + impl FixedPrecision for Decimal { fn new(precision: u64) -> Self { - let ten = BigInt::from(10); - let precision_multiplier = ten.pow(precision as u32); - let data = BigInt::from(0); + let mut precision_multiplier = Integer::from(10); + precision_multiplier.pow_assign(precision); + let data = Integer::from(0); Decimal { precision, precision_multiplier, @@ -175,7 +294,7 @@ impl FixedPrecision for Decimal { } let mut decimal = Decimal::new(precision); - decimal.data = BigInt::from_str(s).unwrap(); + decimal.data = Integer::from_str(s).unwrap(); Ok(decimal) } @@ -211,9 +330,51 @@ impl FixedPrecision for Decimal { &compare.data, ) } + + fn round(&self) -> Self { + let mut result = self.clone(); + let half = &self.precision_multiplier / Integer::from(2); + let remainder = &self.data % &self.precision_multiplier; + if (&remainder).abs() >= half { + if self.data.sign() == Ordering::Less { + result.data -= &self.precision_multiplier + remainder; + } else { + result.data += &self.precision_multiplier - remainder; + } + } else { + result.data -= remainder; + } + result + } + + fn floor(&self) -> Self { + let mut result = self.clone(); + let remainder = &self.data % &self.precision_multiplier; + if self.data.sign() == Ordering::Less && remainder != 0 { + result.data -= &self.precision_multiplier; + } + result.data -= remainder; + result + } + + fn ceil(&self) -> Self { + let mut result = self.clone(); + let remainder = &self.data % &self.precision_multiplier; + if self.data.sign() == Ordering::Greater && remainder != 0 { + result.data += &self.precision_multiplier; + } + result.data -= remainder; + result + } + + fn trunc(&self) -> Self { + let mut result = self.clone(); + result.data -= &self.data % &self.precision_multiplier; + result + } } -fn print_fixedp(n: &BigInt, precision: &BigInt, width: usize) -> String { +fn print_fixedp(n: &Integer, precision: &Integer, width: usize) -> String { let (mut temp_q, mut temp_r) = n.div_rem(precision); let is_negative_q = temp_q < ZERO.value; @@ -243,11 +404,11 @@ fn print_fixedp(n: &BigInt, precision: &BigInt, width: usize) -> String { } struct Constant { - value: BigInt, + value: Integer, } impl Constant { - pub fn new(init: fn() -> BigInt) -> Constant { + pub fn new(init: fn() -> Integer) -> Constant { Constant { value: init() } } } @@ -256,14 +417,14 @@ unsafe impl Sync for Constant {} unsafe impl Send for Constant {} static DIGITS_REGEX: Lazy = Lazy::new(|| Regex::new(r"^-?\d+$").unwrap()); -static TEN: Lazy = Lazy::new(|| Constant::new(|| BigInt::from(10))); -static PRECISION: Lazy = Lazy::new(|| Constant::new(|| TEN.value.pow(34))); -static EPS: Lazy = Lazy::new(|| Constant::new(|| TEN.value.pow(34 - 24))); -static ONE: Lazy = Lazy::new(|| Constant::new(|| BigInt::from(1) * &PRECISION.value)); -static ZERO: Lazy = Lazy::new(|| Constant::new(|| BigInt::from(0))); +static TEN: Lazy = Lazy::new(|| Constant::new(|| Integer::from(10))); +static PRECISION: Lazy = Lazy::new(|| Constant::new(|| TEN.value.clone().pow(34))); +static EPS: Lazy = Lazy::new(|| Constant::new(|| TEN.value.clone().pow(34 - 24))); +static ONE: Lazy = Lazy::new(|| Constant::new(|| Integer::from(1) * &PRECISION.value)); +static ZERO: Lazy = Lazy::new(|| Constant::new(|| Integer::from(0))); static E: Lazy = Lazy::new(|| { Constant::new(|| { - let mut e = BigInt::from(0); + let mut e = Integer::from(0); ref_exp(&mut e, &ONE.value); e }) @@ -271,29 +432,28 @@ static E: Lazy = Lazy::new(|| { /// Entry point for 'exp' approximation. First does the scaling of 'x' to [0,1] /// and then calls the continued fraction approximation function. -fn ref_exp(rop: &mut BigInt, x: &BigInt) -> i32 { +fn ref_exp(rop: &mut Integer, x: &Integer) -> i32 { let mut iterations = 0; match x.cmp(&ZERO.value) { - std::cmp::Ordering::Equal => { + Ordering::Equal => { // rop = 1 rop.clone_from(&ONE.value); } - std::cmp::Ordering::Less => { + Ordering::Less => { let x_ = -x; - let mut temp = BigInt::from(0); + let mut temp = Integer::from(0); iterations = ref_exp(&mut temp, &x_); // rop = 1 / temp div(rop, &ONE.value, &temp); } - std::cmp::Ordering::Greater => { - let mut n_exponent = x.div_ceil(&PRECISION.value); - let n = n_exponent.to_u32().expect("n_exponent to_u32 failed"); - n_exponent *= &PRECISION.value; /* ceil(x) */ - let x_ = x / n; + Ordering::Greater => { + let (n_exponent, _) = x.div_round(&PRECISION.value, RoundingMode::Ceiling); + let x_ = x / &n_exponent; iterations = mp_exp_taylor(rop, 1000, &x_, &EPS.value); // rop = rop.pow(n) - ipow(rop, &rop.clone(), n as i64); + let n_exponent_i64: i64 = i64::try_from(&n_exponent).expect("n_exponent to_i64 failed"); + ipow(rop, &rop.clone(), n_exponent_i64); } } @@ -302,15 +462,15 @@ fn ref_exp(rop: &mut BigInt, x: &BigInt) -> i32 { /// Division with quotent and remainder #[inline] -fn div_qr(q: &mut BigInt, r: &mut BigInt, x: &BigInt, y: &BigInt) { +fn div_qr(q: &mut Integer, r: &mut Integer, x: &Integer, y: &Integer) { (*q, *r) = x.div_rem(y); } /// Division -pub fn div(rop: &mut BigInt, x: &BigInt, y: &BigInt) { - let mut temp_q = BigInt::from(0); - let mut temp_r = BigInt::from(0); - let mut temp: BigInt; +pub fn div(rop: &mut Integer, x: &Integer, y: &Integer) { + let mut temp_q = Integer::from(0); + let mut temp_r = Integer::from(0); + let mut temp: Integer; div_qr(&mut temp_q, &mut temp_r, x, y); temp = &temp_q * &PRECISION.value; @@ -322,7 +482,7 @@ pub fn div(rop: &mut BigInt, x: &BigInt, y: &BigInt) { *rop = temp; } /// Taylor / MacLaurin series approximation -fn mp_exp_taylor(rop: &mut BigInt, max_n: i32, x: &BigInt, epsilon: &BigInt) -> i32 { +fn mp_exp_taylor(rop: &mut Integer, max_n: i32, x: &Integer, epsilon: &Integer) -> i32 { let mut divisor = ONE.value.clone(); let mut last_x = ONE.value.clone(); rop.clone_from(&ONE.value); @@ -333,12 +493,12 @@ fn mp_exp_taylor(rop: &mut BigInt, max_n: i32, x: &BigInt, epsilon: &BigInt) -> let next_x2 = next_x.clone(); div(&mut next_x, &next_x2, &divisor); - if next_x.abs() < epsilon.abs() { + if (&next_x).abs() < epsilon.abs() { break; } divisor += &ONE.value; - *rop += &next_x; + *rop = &*rop + &next_x; last_x.clone_from(&next_x); n += 1; } @@ -346,27 +506,27 @@ fn mp_exp_taylor(rop: &mut BigInt, max_n: i32, x: &BigInt, epsilon: &BigInt) -> n } -fn scale(rop: &mut BigInt) { - let mut temp = BigInt::from(0); - let mut a = BigInt::from(0); +pub(crate) fn scale(rop: &mut Integer) { + let mut temp = Integer::from(0); + let mut a = Integer::from(0); div_qr(&mut a, &mut temp, rop, &PRECISION.value); if *rop < ZERO.value && temp != ZERO.value { - a -= 1; + a -= Integer::ONE; } *rop = a; } /// Integer power internal function -fn ipow_(rop: &mut BigInt, x: &BigInt, n: i64) { +fn ipow_(rop: &mut Integer, x: &Integer, n: i64) { if n == 0 { rop.clone_from(&ONE.value); } else if n % 2 == 0 { - let mut res = BigInt::from(0); + let mut res = Integer::from(0); ipow_(&mut res, x, n / 2); *rop = &res * &res; scale(rop); } else { - let mut res = BigInt::from(0); + let mut res = Integer::from(0); ipow_(&mut res, x, n - 1); *rop = res * x; scale(rop); @@ -374,9 +534,9 @@ fn ipow_(rop: &mut BigInt, x: &BigInt, n: i64) { } /// Integer power -fn ipow(rop: &mut BigInt, x: &BigInt, n: i64) { +fn ipow(rop: &mut Integer, x: &Integer, n: i64) { if n < 0 { - let mut temp = BigInt::from(0); + let mut temp = Integer::from(0); ipow_(&mut temp, x, -n); div(rop, &ONE.value, &temp); } else { @@ -388,32 +548,32 @@ fn ipow(rop: &mut BigInt, x: &BigInt, n: i64) { /// maximum of 'maxN' iterations or until the absolute difference between two /// succeeding convergents is smaller than 'eps'. Assumes 'x' to be within /// [1,e). -fn mp_ln_n(rop: &mut BigInt, max_n: i32, x: &BigInt, epsilon: &BigInt) { - let mut ba: BigInt; - let mut aa: BigInt; - let mut ab: BigInt; - let mut bb: BigInt; - let mut a_: BigInt; - let mut b_: BigInt; - let mut diff: BigInt; - let mut convergent: BigInt = BigInt::from(0); - let mut last: BigInt = BigInt::from(0); +fn mp_ln_n(rop: &mut Integer, max_n: i32, x: &Integer, epsilon: &Integer) { + let mut ba: Integer; + let mut aa: Integer; + let mut ab: Integer; + let mut bb: Integer; + let mut a_: Integer; + let mut b_: Integer; + let mut diff: Integer; + let mut convergent: Integer = Integer::from(0); + let mut last: Integer = Integer::from(0); let mut first = true; let mut n = 1; - let mut a: BigInt; + let mut a: Integer; let mut b = ONE.value.clone(); let mut an_m2 = ONE.value.clone(); - let mut bn_m2 = BigInt::from(0); - let mut an_m1 = BigInt::from(0); + let mut bn_m2 = Integer::from(0); + let mut an_m1 = Integer::from(0); let mut bn_m1 = ONE.value.clone(); let mut curr_a = 1; while n <= max_n + 2 { let curr_a_2 = curr_a * curr_a; - a = x * curr_a_2; + a = x * Integer::from(curr_a_2); if n > 1 && n % 2 == 1 { curr_a += 1; } @@ -455,12 +615,11 @@ fn mp_ln_n(rop: &mut BigInt, max_n: i32, x: &BigInt, epsilon: &BigInt) { *rop = convergent; } -fn find_e(x: &BigInt) -> i64 { - let mut x_: BigInt = BigInt::from(0); - let mut x__: BigInt; +fn find_e(x: &Integer) -> i64 { + let mut x_: Integer = Integer::from(0); + let mut x__: Integer = E.value.clone(); div(&mut x_, &ONE.value, &E.value); - x__ = E.value.clone(); let mut l = -1; let mut u = 1; @@ -491,17 +650,17 @@ fn find_e(x: &BigInt) -> i64 { /// Entry point for 'ln' approximation. First does the necessary scaling, and /// then calls the continued fraction calculation. For any value outside the /// domain, i.e., 'x in (-inf,0]', the function returns '-INFINITY'. -fn ref_ln(rop: &mut BigInt, x: &BigInt) -> bool { - let mut factor = BigInt::from(0); - let mut x_ = BigInt::from(0); +fn ref_ln(rop: &mut Integer, x: &Integer) -> bool { + let mut factor = Integer::from(0); + let mut x_ = Integer::from(0); if x <= &ZERO.value { return false; } let n = find_e(x); - *rop = BigInt::from(n); - *rop = rop.clone() * &PRECISION.value; + *rop = Integer::from(n); + *rop = &*rop * &PRECISION.value; ref_exp(&mut factor, rop); div(&mut x_, x, &factor); @@ -510,14 +669,14 @@ fn ref_ln(rop: &mut BigInt, x: &BigInt) -> bool { let x_2 = x_.clone(); mp_ln_n(&mut x_, 1000, &x_2, &EPS.value); - *rop = rop.clone() + &x_; + *rop = &*rop + &x_; true } -fn ref_pow(rop: &mut BigInt, base: &BigInt, exponent: &BigInt) { +fn ref_pow(rop: &mut Integer, base: &Integer, exponent: &Integer) { /* x^y = exp(y * ln x) */ - let mut tmp: BigInt = BigInt::from(0); + let mut tmp: Integer = Integer::from(0); ref_ln(&mut tmp, base); tmp *= exponent; scale(&mut tmp); @@ -535,20 +694,20 @@ fn ref_pow(rop: &mut BigInt, base: &BigInt, exponent: &BigInt) { /// Lagrange remainder require knowledge of the maximum value to compute the /// maximal error of the remainder. fn ref_exp_cmp( - rop: &mut BigInt, + rop: &mut Integer, max_n: u64, - x: &BigInt, + x: &Integer, bound_x: i64, - compare: &BigInt, + compare: &Integer, ) -> ExpCmpOrdering { rop.clone_from(&ONE.value); let mut n = 0u64; - let mut divisor: BigInt; - let mut next_x: BigInt; - let mut error: BigInt; - let mut upper: BigInt; - let mut lower: BigInt; - let mut error_term: BigInt; + let mut divisor: Integer; + let mut next_x: Integer; + let mut error: Integer; + let mut upper: Integer; + let mut lower: Integer; + let mut error_term: Integer; divisor = ONE.value.clone(); error = x.clone(); @@ -556,7 +715,7 @@ fn ref_exp_cmp( let mut estimate = ExpOrdering::UNKNOWN; while n < max_n { next_x = error.clone(); - if next_x.abs() < EPS.value.abs() { + if (&next_x).abs() < (&EPS.value).abs() { break; } divisor += &ONE.value; @@ -568,8 +727,8 @@ fn ref_exp_cmp( scale(&mut error); let e2 = error.clone(); div(&mut error, &e2, &divisor); - error_term = &error * bound_x; - *rop += &next_x; + error_term = &error * Integer::from(bound_x); + *rop = &*rop + &next_x; /* compare is guaranteed to be above overall result */ upper = &*rop + &error_term; @@ -589,9 +748,12 @@ fn ref_exp_cmp( n += 1; } + let mut approx = Decimal::new(DEFAULT_PRECISION); + approx.data = rop.clone(); + ExpCmpOrdering { iterations: n, estimation: estimate, - approx: Decimal::from(&*rop), + approx, } }