diff --git a/Cargo.toml b/Cargo.toml index 7970c8e..108d39b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,10 +13,12 @@ edition = '2021' [dependencies] serde = "1.0" -bnum = {version = "0.5.0", features = ["num-traits"]} +bnum = {version = "0.5.0", features = ["numtraits"]} num-traits = "0.2" +num-integer = "0.1" [dev-dependencies] serde_json = "1.0" serde_derive = "1.0" lazy_static = "1.4" +rand = "0" \ No newline at end of file diff --git a/src/int256.rs b/src/int256.rs index 030e0ae..9fa59ff 100644 --- a/src/int256.rs +++ b/src/int256.rs @@ -1,5 +1,6 @@ use bnum::types::I256; use bnum::BInt; +use num_integer::Roots; use num_traits::{ Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, Num, One, Pow, Signed, ToPrimitive, Zero, @@ -57,7 +58,7 @@ impl Int256 { /// Square root pub fn sqrt(&self) -> Uint256 { - self.0.ilog(self.0).into() + Uint256(self.0.sqrt().unsigned_abs()) } /// Checked conversion to Uint256 @@ -321,7 +322,7 @@ macro_rules! forward_checked_op { impl $trait_ for $type_ { fn $method(&self, $type_(b): &$type_) -> Option<$type_> { let $type_(a) = self; - let value = a.$method(*b); + let value = a.$method(b); match value { Some(value) => Some(Int256(value)), None => None, @@ -381,7 +382,7 @@ fn check_from_str_radix_overflow_underflow() { #[test] fn test_to_primitive_64() { use num_traits::ToPrimitive; - let u32_max: u64 = std::u32::MAX.into(); + let u32_max: u64 = u32::MAX.into(); let mut i = 0u64; while i < 100_000 { let i_int256: Int256 = i.into(); @@ -410,7 +411,7 @@ fn test_to_primitive_64() { /// Check the ToPrimitive impl for +-[0, 100k] and +-[2^64-100, ~2^64+100k] #[test] fn test_to_primitive_128() { - let u64_max: u128 = std::u64::MAX.into(); + let u64_max: u128 = u64::MAX.into(); use num_traits::ToPrimitive; let mut i = 0u128; while i < 100_000 { @@ -444,9 +445,9 @@ fn test_to_from_bytes() { let n1 = Int256::from(-1i8); let z = Int256::zero(); let p1 = Int256::from(1i8); - let p_u64max = Int256::from(std::u64::MAX); + let p_u64max = Int256::from(u64::MAX); let n_u64max = p_u64max.neg(); - let p_u128max = Int256::from(std::u128::MAX); + let p_u128max = Int256::from(u128::MAX); let n_u128max = p_u128max.neg(); let int256_max = Int256::max_value(); let int256_min = Int256::min_value(); @@ -459,3 +460,15 @@ fn test_to_from_bytes() { assert_eq!(tc, to_from_le(tc)); } } + +#[test] +fn test_sqrt() { + use rand::prelude::*; + + for _ in 1..100000 { + let r: i128 = random(); + let n = Int256::from(r.abs()); + let sqrt = (n.mul(n)).sqrt(); + assert!(sqrt == n.to_uint256().unwrap()); + } +} diff --git a/src/uint256.rs b/src/uint256.rs index 495550b..2aafd45 100644 --- a/src/uint256.rs +++ b/src/uint256.rs @@ -1,6 +1,7 @@ pub use super::Int256; use bnum::types::U256; use bnum::BUint; +use num_integer::Roots; use num_traits::{ Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive, Num, One, Pow, ToPrimitive, Zero, @@ -66,7 +67,7 @@ impl Uint256 { /// Square root pub fn sqrt(&self) -> Uint256 { - self.0.ilog(self.0).into() + Self(self.0.sqrt()) } } @@ -357,7 +358,7 @@ macro_rules! forward_checked_op { impl $trait_ for $type_ { fn $method(&self, $type_(b): &$type_) -> Option<$type_> { let $type_(a) = self; - let value = a.$method(*b); + let value = a.$method(b); match value { Some(value) => Some(Uint256(value)), None => None, @@ -503,7 +504,7 @@ fn check_from_str_radix_overflow() { /// Check the ToPrimitive impl for [0, 100k] + [2^32-100, ~2^32+100k] #[test] fn test_to_primitive_64() { - let u32_max: u64 = std::u32::MAX.into(); + let u32_max: u64 = u32::MAX.into(); use num_traits::ToPrimitive; let mut i = 0u64; while i < 100_000 { @@ -529,7 +530,7 @@ fn test_to_primitive_64() { /// The default ToPrimitive impl breaks on values above 2^64 #[test] fn test_to_primitive_128() { - let u64_max: u128 = std::u64::MAX.into(); + let u64_max: u128 = u64::MAX.into(); use num_traits::ToPrimitive; let mut i = 0u128; while i < 100_000 { @@ -549,3 +550,15 @@ fn test_to_primitive_128() { i += 1 } } + +#[test] +fn test_sqrt() { + use rand::prelude::*; + + for _ in 1..100000 { + let r: u128 = random(); + let n = Uint256::from(r); + let sqrt = (n.mul(n)).sqrt(); + assert!(sqrt == n); + } +}