Commit
Implement the 'where' primitive for conditional selection (ml-explore…
Rifur13 authored Feb 22, 2024
1 parent ad4a45e commit 126c986
Showing 23 changed files with 991 additions and 56 deletions.
6 changes: 6 additions & 0 deletions benchmarks/cpp/single_ops.cpp
@@ -73,6 +73,7 @@ void time_unary_ops() {

void time_binary_ops() {
int M = 1000, N = 100, K = 10;
+ auto condition = random::randint(0, 2, {M, N, K});
auto a = random::uniform({M, N, K});
auto b = random::uniform({M, N, K});
auto device = default_device();
@@ -84,7 +85,9 @@ void time_binary_ops() {
TIME(divide, a, b, device);
TIME(maximum, a, b, device);
TIME(minimum, a, b, device);
+ TIME(where, condition, a, b, device);

+ condition = array({true});
b = random::uniform({1});
eval(b);
TIMEM("scalar", add, a, b, device);
@@ -93,14 +96,17 @@ void time_binary_ops() {
TIMEM("scalar", multiply, a, b, device);
TIMEM("vector-scalar", divide, a, b, device);
TIMEM("scalar-vector", divide, b, a, device);
+ TIMEM("scalar-vector", where, condition, a, b, device);

+ condition = broadcast_to(array({true}), {1000, 100});
a = broadcast_to(random::uniform({1}), {1000, 100});
b = broadcast_to(random::uniform({1}), {1000, 100});
eval(a, b);
TIMEM("scalar-scalar broadcast", add, a, b, device);
TIMEM("scalar-scalar broadcast", subtract, a, b, device);
TIMEM("scalar-scalar broadcast", multiply, a, b, device);
TIMEM("scalar-scalar broadcast", divide, a, b, device);
+ TIMEM("scalar-scalar broadcast", where, condition, a, b, device);
}

void time_strided_ops() {
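For reference, a minimal usage sketch of the benchmarked where op (not part of the commit; the values are illustrative, but the calls follow the MLX C++ API used above):

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  // Elementwise conditional selection with broadcasting:
  // out[i] = condition[i] ? a[i] : b[i].
  auto condition = array({true, false, true});
  auto a = array({1.0f, 2.0f, 3.0f});
  auto b = array({-1.0f, -2.0f, -3.0f});
  auto out = where(condition, a, b); // expected: {1.0f, -2.0f, 3.0f}
  eval(out);
  return 0;
}
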
1 change: 1 addition & 0 deletions mlx/backend/accelerate/primitives.cpp
@@ -64,6 +64,7 @@ DEFAULT(Reshape)
DEFAULT(Remainder)
DEFAULT(Round)
DEFAULT(Scatter)
+ DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Slice)
1 change: 1 addition & 0 deletions mlx/backend/common/CMakeLists.txt
@@ -43,6 +43,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp
${CMAKE_CURRENT_SOURCE_DIR}/rope.cpp
${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/select.cpp
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp
44 changes: 22 additions & 22 deletions mlx/backend/common/binary.h
@@ -9,7 +9,7 @@ namespace mlx::core {

namespace {

- enum BinaryOpType {
+ enum class BinaryOpType {
ScalarScalar,
ScalarVector,
VectorScalar,
@@ -20,17 +20,17 @@ enum BinaryOpType {
BinaryOpType get_binary_op_type(const array& a, const array& b) {
BinaryOpType bopt;
if (a.data_size() == 1 && b.data_size() == 1) {
- bopt = ScalarScalar;
+ bopt = BinaryOpType::ScalarScalar;
} else if (a.data_size() == 1 && b.flags().contiguous) {
- bopt = ScalarVector;
+ bopt = BinaryOpType::ScalarVector;
} else if (b.data_size() == 1 && a.flags().contiguous) {
- bopt = VectorScalar;
+ bopt = BinaryOpType::VectorScalar;
} else if (
a.flags().row_contiguous && b.flags().row_contiguous ||
a.flags().col_contiguous && b.flags().col_contiguous) {
- bopt = VectorVector;
+ bopt = BinaryOpType::VectorVector;
} else {
- bopt = General;
+ bopt = BinaryOpType::General;
}
return bopt;
}
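
Aside: the switch to a scoped enum is the bulk of this file's diff. A standalone sketch (not from the commit) of what enum class buys here:

#include <cstdio>

// A scoped enum's values must be qualified and do not convert implicitly
// to int, so a BinaryOpType cannot be mixed up with another enum or an
// integral flag by accident.
enum class BinaryOpType { ScalarScalar, ScalarVector, VectorScalar, VectorVector, General };

int main() {
  BinaryOpType bopt = BinaryOpType::ScalarVector;
  // int i = bopt; // would not compile: no implicit conversion to int
  if (bopt == BinaryOpType::ScalarVector) {
    std::puts("scalar-vector fast path");
  }
  return 0;
}
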
@@ -42,11 +42,11 @@ void set_binary_op_output_data(
BinaryOpType bopt,
bool donate_with_move = false) {
switch (bopt) {
- case ScalarScalar:
+ case BinaryOpType::ScalarScalar:
out.set_data(
allocator::malloc_or_wait(out.itemsize()), 1, a.strides(), a.flags());
break;
- case ScalarVector:
+ case BinaryOpType::ScalarVector:
if (b.is_donatable() && b.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(b);
@@ -61,7 +61,7 @@
b.flags());
}
break;
- case VectorScalar:
+ case BinaryOpType::VectorScalar:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
@@ -76,7 +76,7 @@
a.flags());
}
break;
- case VectorVector:
+ case BinaryOpType::VectorVector:
if (a.is_donatable() && a.itemsize() == out.itemsize()) {
if (donate_with_move) {
out.move_shared_buffer(a);
@@ -97,7 +97,7 @@
a.flags());
}
break;
- case General:
+ case BinaryOpType::General:
if (a.is_donatable() && a.flags().row_contiguous &&
a.itemsize() == out.itemsize() && a.size() == out.size()) {
if (donate_with_move) {
@@ -424,25 +424,25 @@ void binary_op(
set_binary_op_output_data(a, b, out, bopt);

// The full computation is scalar scalar so call the base op once
- if (bopt == ScalarScalar) {
+ if (bopt == BinaryOpType::ScalarScalar) {
*(out.data<U>()) = op(*a.data<T>(), *b.data<T>());
return;
}

// The full computation is scalar vector so delegate to the op
- if (bopt == ScalarVector) {
+ if (bopt == BinaryOpType::ScalarVector) {
opsv(a.data<T>(), b.data<T>(), out.data<U>(), b.data_size());
return;
}

// The full computation is vector scalar so delegate to the op
- if (bopt == VectorScalar) {
+ if (bopt == BinaryOpType::VectorScalar) {
opvs(a.data<T>(), b.data<T>(), out.data<U>(), a.data_size());
return;
}

// The full computation is vector vector so delegate to the op
- if (bopt == VectorVector) {
+ if (bopt == BinaryOpType::VectorVector) {
opvv(a.data<T>(), b.data<T>(), out.data<U>(), out.size());
return;
}
@@ -475,17 +475,17 @@
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
- bopt = VectorVector;
+ bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
- bopt = VectorScalar;
+ bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
- bopt = ScalarVector;
+ bopt = BinaryOpType::ScalarVector;
dim = d;
}

@@ -495,20 +495,20 @@
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
- bopt = General;
+ bopt = BinaryOpType::General;
dim = ndim;
} else {
stride = strides[dim - 1];
}

switch (bopt) {
- case VectorVector:
+ case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out, opvv, dim, stride);
break;
- case VectorScalar:
+ case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out, opvs, dim, stride);
break;
- case ScalarVector:
+ case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out, opsv, dim, stride);
break;
default:
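Aside: the dispatch above exists because contiguous inputs admit a single flat loop while strided inputs do not. A simplified sketch of the two paths (names are illustrative, not the MLX internals):

#include <cstddef>

// Contiguous ("vector-vector") inputs: one unit-stride loop that the
// compiler can vectorize trivially.
template <typename T, typename Op>
void vv_path(const T* a, const T* b, T* out, size_t n, Op op) {
  for (size_t i = 0; i < n; ++i)
    out[i] = op(a[i], b[i]);
}

// General case: every element is gathered through explicit strides,
// which is why binary_op falls back to it only when it has to.
template <typename T, typename Op>
void general_path(
    const T* a, const T* b, T* out,
    size_t n, size_t a_stride, size_t b_stride, Op op) {
  for (size_t i = 0; i < n; ++i)
    out[i] = op(a[i * a_stride], b[i * b_stride]);
}
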
22 changes: 11 additions & 11 deletions mlx/backend/common/binary_two.h
@@ -260,14 +260,14 @@ void binary_op(
set_binary_op_output_data(a, b, out_b, bopt);

// The full computation is scalar scalar so call the base op once
- if (bopt == ScalarScalar) {
+ if (bopt == BinaryOpType::ScalarScalar) {
std::tie(*(out_a.data<U>()), *(out_b.data<U>())) =
op(*a.data<T>(), *b.data<T>());
return;
}

// The full computation is scalar vector so delegate to the op
- if (bopt == ScalarVector) {
+ if (bopt == BinaryOpType::ScalarVector) {
opsv(
a.data<T>(),
b.data<T>(),
@@ -278,7 +278,7 @@
}

// The full computation is vector scalar so delegate to the op
- if (bopt == VectorScalar) {
+ if (bopt == BinaryOpType::VectorScalar) {
opvs(
a.data<T>(),
b.data<T>(),
@@ -289,7 +289,7 @@
}

// The full computation is vector vector so delegate to the op
- if (bopt == VectorVector) {
+ if (bopt == BinaryOpType::VectorVector) {
opvv(
a.data<T>(),
b.data<T>(),
@@ -327,17 +327,17 @@
// Case 1: LxM and FxM where L and F are broadcastable and M is row contiguous
int dim = ndim;
if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) {
- bopt = VectorVector;
+ bopt = BinaryOpType::VectorVector;
dim = d;
// Case 2: LxM and Fx1 where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) {
- bopt = VectorScalar;
+ bopt = BinaryOpType::VectorScalar;
dim = d;
// Case 3: Lx1 and FxM where L and F are broadcastable and M is row
// contiguous
} else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) {
- bopt = ScalarVector;
+ bopt = BinaryOpType::ScalarVector;
dim = d;
}

@@ -347,20 +347,20 @@
size_t stride;
if (dim == 0 || strides[dim - 1] < 16) {
stride = 1;
- bopt = General;
+ bopt = BinaryOpType::General;
dim = ndim;
} else {
stride = strides[dim - 1];
}

switch (bopt) {
- case VectorVector:
+ case BinaryOpType::VectorVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvv, dim, stride);
break;
- case VectorScalar:
+ case BinaryOpType::VectorScalar:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opvs, dim, stride);
break;
- case ScalarVector:
+ case BinaryOpType::ScalarVector:
binary_op_dispatch_dims<T, U>(a, b, out_a, out_b, opsv, dim, stride);
break;
default:
1 change: 1 addition & 0 deletions mlx/backend/common/default_primitives.cpp
@@ -87,6 +87,7 @@ DEFAULT(Reshape)
DEFAULT(Round)
DEFAULT(Scan)
DEFAULT(Scatter)
+ DEFAULT(Select)
DEFAULT(Sigmoid)
DEFAULT(Sign)
DEFAULT(Sin)
7 changes: 7 additions & 0 deletions mlx/backend/common/ops.h
@@ -588,4 +588,11 @@ struct LogicalOr {
};
};

+ struct Select {
+   template <typename T>
+   T operator()(bool condition, T x, T y) {
+     return condition ? x : y;
+   }
+ };

} // namespace mlx::core::detail
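
The functor simply wraps the ternary conditional, so it can be exercised directly; a self-contained sketch (in MLX it is applied elementwise through ternary_op rather than called by hand):

#include <cassert>

struct Select {
  template <typename T>
  T operator()(bool condition, T x, T y) {
    return condition ? x : y;
  }
};

int main() {
  Select select;
  assert(select(true, 1.0f, 2.0f) == 1.0f); // true -> first value
  assert(select(false, 1, 2) == 2);         // false -> second value
  return 0;
}
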
72 changes: 72 additions & 0 deletions mlx/backend/common/select.cpp
@@ -0,0 +1,72 @@
// Copyright © 2023 Apple Inc.

#include <cassert>

#include "mlx/backend/common/ternary.h"
#include "mlx/primitives.h"

namespace mlx::core {

namespace {

template <typename Op>
void select_op(
const array& a,
const array& b,
const array& c,
array& out,
Op op) {
switch (out.dtype()) {
case bool_:
ternary_op<bool, bool, bool, bool>(a, b, c, out, op);
break;
case uint8:
ternary_op<bool, uint8_t, uint8_t, uint8_t>(a, b, c, out, op);
break;
case uint16:
ternary_op<bool, uint16_t, uint16_t, uint16_t>(a, b, c, out, op);
break;
case uint32:
ternary_op<bool, uint32_t, uint32_t, uint32_t>(a, b, c, out, op);
break;
case uint64:
ternary_op<bool, uint64_t, uint64_t, uint64_t>(a, b, c, out, op);
break;
case int8:
ternary_op<bool, int8_t, int8_t, int8_t>(a, b, c, out, op);
break;
case int16:
ternary_op<bool, int16_t, int16_t, int16_t>(a, b, c, out, op);
break;
case int32:
ternary_op<bool, int32_t, int32_t, int32_t>(a, b, c, out, op);
break;
case int64:
ternary_op<bool, int64_t, int64_t, int64_t>(a, b, c, out, op);
break;
case float16:
ternary_op<bool, float16_t, float16_t, float16_t>(a, b, c, out, op);
break;
case float32:
ternary_op<bool, float, float, float>(a, b, c, out, op);
break;
case bfloat16:
ternary_op<bool, bfloat16_t, bfloat16_t, bfloat16_t>(a, b, c, out, op);
break;
case complex64:
ternary_op<bool, complex64_t, complex64_t, complex64_t>(a, b, c, out, op);
break;
}
}

} // namespace

void Select::eval(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
const auto& condition = inputs[0];
const auto& a = inputs[1];
const auto& b = inputs[2];
select_op(condition, a, b, out, detail::Select());
}

} // namespace mlx::core
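
The switch in select_op is the usual dtype-dispatch pattern: a runtime dtype tag picks a fully typed template instantiation. A minimal standalone sketch of the same pattern (the Dtype enum and function names here are illustrative, not MLX's):

#include <cstdint>

enum class Dtype { int32, float32 };

// Fully typed kernel: the same rule as detail::Select, applied per element.
template <typename T>
void select_kernel(const bool* c, const T* x, const T* y, T* out, int n) {
  for (int i = 0; i < n; ++i)
    out[i] = c[i] ? x[i] : y[i];
}

// One switch maps the runtime tag to the right instantiation.
void select_dispatch(
    Dtype dtype,
    const bool* c,
    const void* x,
    const void* y,
    void* out,
    int n) {
  switch (dtype) {
    case Dtype::int32:
      select_kernel(
          c,
          static_cast<const int32_t*>(x),
          static_cast<const int32_t*>(y),
          static_cast<int32_t*>(out),
          n);
      break;
    case Dtype::float32:
      select_kernel(
          c,
          static_cast<const float*>(x),
          static_cast<const float*>(y),
          static_cast<float*>(out),
          n);
      break;
  }
}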