Skip to content

Commit

Permalink
refactor: replace custom pow with recently added native candle pow
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Jan 15, 2024
1 parent 8f6b8be commit 413a127
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 9 deletions.
2 changes: 1 addition & 1 deletion native/candlex/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 0 additions & 4 deletions native/candlex/src/kernels/custom_binary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ DEVICE_FN_FLOAT_WRAPPER(atan2)
DEVICE_FN_DOUBLE_WRAPPER(atan2)
DEVICE_FN_FLOAT_WRAPPER(fmod)
DEVICE_FN_DOUBLE_WRAPPER(fmod)
DEVICE_FN_FLOAT_WRAPPER(pow)
DEVICE_FN_DOUBLE_WRAPPER(pow)

#define CUSTOM_BINARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
Expand Down Expand Up @@ -90,8 +88,6 @@ CUSTOM_BINARY_OP(uint32_t, bit_or_u32, x | y)
CUSTOM_BINARY_OP(int64_t, bit_or_i64, x | y)
CUSTOM_BINARY_OP(uint32_t, bit_xor_u32, x ^ y)
CUSTOM_BINARY_OP(int64_t, bit_xor_i64, x ^ y)
CUSTOM_BINARY_OP(float, pow_f32, powg(x, y))
CUSTOM_BINARY_OP(double, pow_f64, powg(x, y))
CUSTOM_BINARY_OP(uint8_t, remainder_u8, x % y)
CUSTOM_BINARY_OP(int64_t, remainder_i64, x % y)
CUSTOM_BINARY_OP(float, remainder_f32, fmodg(x, y))
Expand Down
1 change: 0 additions & 1 deletion native/candlex/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,6 @@ custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64));
custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64));
custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2, (U32, I64));
custom_binary_op!(Atan2, "atan2", |v1, v2| v1.atan2(v2), (F32, F64));
custom_binary_op!(Pow, "pow", |v1, v2| v1.powf(v2), (F32, F64));
custom_binary_op!(
Remainder,
"remainder",
Expand Down
6 changes: 3 additions & 3 deletions native/candlex/src/tensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::atoms;
use crate::error::CandlexError;
use crate::ops::{
Acos, Acosh, Argsort, Asin, Asinh, Atan, Atan2, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt,
Cosh, ErfInv, Erfc, Expm1, IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow,
Remainder, Shl, Shr, Sigmoid, Sign, Sinh, Tan,
Cosh, ErfInv, Erfc, Expm1, IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Remainder,
Shl, Shr, Sigmoid, Sign, Sinh, Tan,
};
use candle_core::{DType, Device, Tensor};
use half::{bf16, f16};
Expand Down Expand Up @@ -549,6 +549,7 @@ binary_nif!(add, broadcast_add);
binary_nif!(subtract, broadcast_sub);
binary_nif!(multiply, broadcast_mul);
binary_nif!(quotient, broadcast_div);
binary_nif!(pow, broadcast_pow);
binary_nif!(max, broadcast_maximum);
binary_nif!(min, broadcast_minimum);
binary_nif!(equal, eq);
Expand All @@ -567,7 +568,6 @@ custom_binary_nif!(left_shift, Shl);
custom_binary_nif!(logical_and, LogicalAnd);
custom_binary_nif!(logical_or, LogicalOr);
custom_binary_nif!(logical_xor, LogicalXor);
custom_binary_nif!(pow, Pow);
custom_binary_nif!(right_shift, Shr);
custom_binary_nif!(remainder, Remainder);

Expand Down

0 comments on commit 413a127

Please sign in to comment.