Skip to content

Commit

Permalink
[MPS][BE] Delete MacOS-12.3 specific checks (pytorch#133141)
Browse files Browse the repository at this point in the history
And make MPS device unavailable on Sonoma releases As lots of those checks 2 years old, are no longer validated in CI and probably much more such checks are missing

Pull Request resolved: pytorch#133141
Approved by: https://github.com/kulinseth, https://github.com/clee2000, https://github.com/atalman
  • Loading branch information
malfet authored and pytorchmergebot committed Aug 14, 2024
1 parent 7b269cc commit 07c73a9
Show file tree
Hide file tree
Showing 22 changed files with 135 additions and 619 deletions.
95 changes: 0 additions & 95 deletions aten/src/ATen/mps/IndexKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,13 @@ static const char * indexing_metal_shaders = R"INDEX_METAL(
using namespace metal;
#if __METAL_VERSION__ < 300
struct IndexAB {
// Allow up to 16 indices
metal::array<constant void *, 16> indexArray [[ id(0) ]];
};
#else
struct IndexAB {
constant int64_t* indexArray;
};
#endif
template<typename T, typename OffsetsT>
kernel void index_select(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant OffsetsT * offsets [[buffer(3)]],
Expand All @@ -38,11 +26,7 @@ kernel void index_select(
constant int64_t * index_strides = (constant int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
constant int64_t* indexArray = indexAB[i].indexArray;
#else
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
Expand All @@ -56,11 +40,7 @@ kernel void index_select(
template<typename T, typename OffsetsT>
void index_put_impl(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB,
#else
constant IndexAB & indexAB,
#endif
constant int64_t * index_sizes,
constant int64_t * index_strides,
constant OffsetsT * offsets,
Expand All @@ -70,11 +50,7 @@ void index_put_impl(
uint thread_index) {
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
constant int64_t* indexArray = indexAB[i].indexArray;
#else
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
Expand All @@ -89,11 +65,7 @@ void index_put_impl(
template<typename T, typename OffsetsT>
kernel void index_put_serial(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant OffsetsT * offsets [[buffer(3)]],
Expand All @@ -113,11 +85,7 @@ kernel void index_put_serial(
template<typename T, typename OffsetsT>
kernel void index_put(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant OffsetsT * offsets [[buffer(3)]],
Expand All @@ -131,20 +99,6 @@ kernel void index_put(
index_put_impl<T>(indexAB, index_sizes, index_strides, offsets, inputData, outputData, num_indices, thread_index);
}
#if __METAL_VERSION__ < 300
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
constant IndexAB & indexAB [[buffer(0)]], \
constant void * indexSizes [[buffer(1)]], \
constant void * indexStrides [[buffer(2)]], \
constant IDX_DTYPE * offsets [[buffer(3)]], \
constant void * inputData [[buffer(4)]], \
device void * outputData [[buffer(5)]], \
constant uint32_t & num_indices [[buffer(6)]], \
uint thread_index [[thread_position_in_grid]]);
#else
#define REGISTER_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
Expand All @@ -157,7 +111,6 @@ kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(
device void * outputData [[buffer(5)]], \
constant uint32_t & num_indices [[buffer(6)]], \
uint thread_index [[thread_position_in_grid]]);
#endif
#define REGISTER_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
REGISTER_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
Expand All @@ -172,21 +125,6 @@ kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(
REGISTER_INDEX_OP_ALL_DTYPES(select);
REGISTER_INDEX_OP_ALL_DTYPES(put);
#if __METAL_VERSION__ < 300
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>( \
constant IndexAB & indexAB [[buffer(0)]], \
constant void * indexSizes [[buffer(1)]], \
constant void * indexStrides [[buffer(2)]], \
constant IDX_DTYPE * offsets [[buffer(3)]], \
constant void * inputData [[buffer(4)]], \
device void * outputData [[buffer(5)]], \
constant uint32_t & num_indices [[buffer(6)]], \
constant uint * numIters [[buffer(7)]], \
uint thread_index [[thread_position_in_grid]]);
#else
#define REGISTER_SINGLE_THREADED_INDEX_OP(DTYPE_SIZE, IDX_SIZE, DTYPE, INDEX_OP_TYPE, IDX_DTYPE) \
template \
[[host_name("index_" #INDEX_OP_TYPE "_" #DTYPE_SIZE "_" #IDX_SIZE)]] \
Expand All @@ -200,7 +138,6 @@ kernel void index_ ## INDEX_OP_TYPE<DTYPE, IDX_DTYPE>(
constant uint32_t & num_indices [[buffer(6)]], \
constant uint * numIters [[buffer(7)]], \
uint thread_index [[thread_position_in_grid]]);
#endif
#define REGISTER_SINGLE_THREADED_INDEX_OP_ALL_DTYPES(INDEX_OP_TYPE) \
REGISTER_SINGLE_THREADED_INDEX_OP(8bit, idx32, char, INDEX_OP_TYPE, uint3); \
Expand Down Expand Up @@ -250,11 +187,7 @@ kernel void kernel_index_offsets<packed_uint3, ulong3>(
template<typename T, typename E, typename OffsetsT>
kernel void index_put_accumulate_native_dtypes(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant OffsetsT * offsets [[buffer(3)]],
Expand All @@ -266,11 +199,7 @@ kernel void index_put_accumulate_native_dtypes(
constant int64_t * index_strides = (constant int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
constant int64_t* indexArray = indexAB[i].indexArray;
#else
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
Expand All @@ -294,11 +223,7 @@ __attribute__((__always_inline__)) void atomic_fetch_add_relaxed(device void * a
template<typename T, typename OffsetsT>
kernel void atomic_index_put_accumulate(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant OffsetsT * offsets [[buffer(3)]],
Expand All @@ -310,11 +235,7 @@ kernel void atomic_index_put_accumulate(
constant int64_t * index_strides = (constant int64_t *)indexStrides;
int64_t offset = 0;
for (uint32_t i = 0; i < num_indices; i++) {
#if __METAL_VERSION__ >= 300
constant int64_t* indexArray = indexAB[i].indexArray;
#else
constant int64_t* indexArray = (constant int64_t*)indexAB.indexArray[i];
#endif
int64_t index = indexArray[offsets[thread_index].z / sizeof(int64_t)];
if (index < 0) {
index += index_sizes[i];
Expand All @@ -329,11 +250,7 @@ kernel void atomic_index_put_accumulate(
template
[[host_name("index_put_accumulate_32bit_float_idx32")]]
kernel void atomic_index_put_accumulate<float, uint3>(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
Expand All @@ -345,11 +262,7 @@ kernel void atomic_index_put_accumulate<float, uint3>(
template
[[host_name("index_put_accumulate_32bit_float_idx64")]]
kernel void atomic_index_put_accumulate<float, ulong3>(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant ulong3 * offsets [[buffer(3)]],
Expand All @@ -361,11 +274,7 @@ kernel void atomic_index_put_accumulate<float, ulong3>(
template
[[host_name("index_put_accumulate_32bit_int_idx32")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
Expand All @@ -377,11 +286,7 @@ kernel void index_put_accumulate_native_dtypes<atomic_int, int, uint3>(
template
[[host_name("index_put_accumulate_32bit_int_idx64")]]
kernel void index_put_accumulate_native_dtypes<atomic_int, int, ulong3>(
#if __METAL_VERSION__ >= 300
constant IndexAB * indexAB [[buffer(0)]],
#else
constant IndexAB & indexAB [[buffer(0)]],
#endif
constant void * indexSizes [[buffer(1)]],
constant void * indexStrides [[buffer(2)]],
constant ulong3 * offsets [[buffer(3)]],
Expand Down
5 changes: 2 additions & 3 deletions aten/src/ATen/mps/MPSDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ namespace at::mps {

// Helper enum to check if a MPSGraph op is supported in a given macOS version
enum class MacOSVersion : uint32_t {
MACOS_VER_13_0_PLUS = 0,
MACOS_VER_13_1_PLUS,
MACOS_VER_13_1_PLUS = 0,
MACOS_VER_13_2_PLUS,
MACOS_VER_13_3_PLUS,
MACOS_VER_14_0_PLUS,
Expand Down Expand Up @@ -79,7 +78,7 @@ class TORCH_API MPSDevice {
};

TORCH_API bool is_available();
TORCH_API bool is_macos_13_or_newer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
TORCH_API bool is_macos_13_or_newer(MacOSVersion version);
TORCH_API at::Allocator* GetMPSAllocator(bool useSharedAllocator = false);

} // namespace at::mps
23 changes: 6 additions & 17 deletions aten/src/ATen/mps/MPSDevice.mm
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,11 @@
static std::unique_ptr<MPSDevice> mps_device;
static c10::once_flag mpsdev_init;

static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device, bool macOS13Plus) {
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
// MPS Advanced Indexing needs at least Metal 2.0 (support for Argument Buffers and function constants)
// host_name attribute needs at least Metal 2.2 and ulong needs Metal 2.3 (supported on MacOS 11+
MTLLanguageVersion languageVersion = MTLLanguageVersion2_3;
#if defined(__MAC_13_0)
if (macOS13Plus) {
languageVersion = MTLLanguageVersion3_0;
}
#endif

TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
return languageVersion;
return MTLLanguageVersion3_0;
}

MPSDevice* MPSDevice::getInstance() {
Expand All @@ -36,7 +29,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
NSError* error = nil;
if (!_mtl_indexing_library) {
MTLCompileOptions* options = [MTLCompileOptions new];
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
[options setLanguageVersion:getMetalLanguageVersion(_mtl_device)];
[options setFastMathEnabled:YES];
_mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
encoding:NSASCIIStringEncoding]
Expand Down Expand Up @@ -75,13 +68,12 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
}

MPSDevice::MPSDevice() : _mtl_device(nil), _mtl_indexing_library(nil) {
// Check that MacOS 12.3+ version of MPS framework is available
// Create the MPSGraph and check method introduced in 12.3+
// Check that MacOS 13.0+ version of MPS framework is available
// Create the MPSGraph and check method introduced in 13.0
// which is used by MPS backend.
id mpsCD = NSClassFromString(@"MPSGraph");

if ([mpsCD instancesRespondToSelector:@selector
(LSTMWithSourceTensor:recurrentWeight:inputWeight:bias:initState:initCell:descriptor:name:)] == NO) {
if ([mpsCD instancesRespondToSelector:@selector(cumulativeSumWithTensor:axis:name:)] == NO) {
return;
}

Expand Down Expand Up @@ -112,7 +104,6 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
isOperatingSystemAtLeastVersion:{.majorVersion = major, .minorVersion = minor, .patchVersion = 0}];
}
};
static bool _macos_13_0_plus = is_os_version_at_least(13, 0);
static bool _macos_13_1_plus = is_os_version_at_least(13, 1);
static bool _macos_13_2_plus = is_os_version_at_least(13, 2);
static bool _macos_13_3_plus = is_os_version_at_least(13, 3);
Expand All @@ -121,8 +112,6 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
static bool _macos_15_0_plus = is_os_version_at_least(15, 0);

switch (version) {
case MacOSVersion::MACOS_VER_13_0_PLUS:
return _macos_13_0_plus;
case MacOSVersion::MACOS_VER_13_1_PLUS:
return _macos_13_1_plus;
case MacOSVersion::MACOS_VER_13_2_PLUS:
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSHooks.mm
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
TORCH_CHECK(major == 13, "Trying to check for unexpected MacOS major ", major);
switch (minor) {
case 0:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS);
return true;
case 1:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_1_PLUS);
case 2:
Expand Down
6 changes: 1 addition & 5 deletions aten/src/ATen/native/RNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1433,15 +1433,11 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
}
#ifdef USE_MPS
if (_input.is_mps() && (mps::is_macos_13_or_newer() || num_layers == 1)) {
if (_input.is_mps()) {
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> output = at::_lstm_mps(_input, hx, _params, has_biases,
num_layers, dropout_p, train, bidirectional, batch_first);
std::tuple<Tensor, Tensor, Tensor> return_values = std::make_tuple(std::get<0>(output), std::get<1>(output), std::get<2>(output));
return return_values;
} else if (_input.is_mps()) {
TORCH_WARN_ONCE("Native multi-layer LSTM support in MPS available only on MacOS 13 onwards.",
" Falling back to LSTMCell iteration.",
" This may have performance implications.");
}
#endif
// if cells are of different size, that means projections are used
Expand Down
15 changes: 0 additions & 15 deletions aten/src/ATen/native/mps/operations/BinaryOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ static void binaryOpTensor(const Tensor& self,
const Tensor& output_,
std::string op_name,
BinaryOpBlock binaryBlock) {
TORCH_CHECK(!(!is_macos_13_or_newer() && self.scalar_type() == ScalarType::Byte),
"MPS support binary op with uint8 natively starting from macOS 13.0");
TORCH_CHECK(!(op_name == "power" && !is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS) &&
(self.scalar_type() == ScalarType::Long ||
(other.scalar_type() == ScalarType::Long &&
Expand Down Expand Up @@ -105,19 +103,6 @@ static void binaryOpTensor(const Tensor& self,
auto inputDataType = self.scalar_type();
auto otherDataType = other.scalar_type();
auto outputDataType = output_.scalar_type();
if (!is_macos_13_or_newer()) {
// workaround for signed vs. unsigned comparison issue in MacOS 12
if (outputDataType == kBool && (inputDataType == kByte || otherDataType == kByte)) {
inputDataType = otherDataType = kByte;
} else {
if (inputDataType == kBool || inputDataType == kByte) {
inputDataType = kChar;
}
if (otherDataType == kBool || otherDataType == kByte) {
otherDataType = kChar;
}
}
}

@autoreleasepool {
string key = op_name + getTensorsStringKey({self, other, output_});
Expand Down
11 changes: 0 additions & 11 deletions aten/src/ATen/native/mps/operations/Distributions.mm
Original file line number Diff line number Diff line change
Expand Up @@ -406,17 +406,6 @@ Tensor normal_mps(const Tensor& mean, const Tensor& std, std::optional<Generator
}

Tensor& randperm_out_mps(int64_t n, std::optional<Generator> generator, Tensor& result) {
if (!is_macos_13_or_newer()) {
TORCH_WARN_ONCE("MPS: randperm op is supported natively starting from macOS 13.0. ",
"Falling back on CPU. This may have performance implications.");

auto result_cpu = result.to("cpu");
at::randperm_out(result_cpu, n);
result.resize_as_(result_cpu);
result.copy_(result_cpu);
return result;
}

TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()),
"Expected a '",
Expand Down
Loading

0 comments on commit 07c73a9

Please sign in to comment.