diff --git a/.github/workflows/build_test_release_eudsl.yml b/.github/workflows/build_test_release_eudsl.yml index c1090e8e..1b6cb4e6 100644 --- a/.github/workflows/build_test_release_eudsl.yml +++ b/.github/workflows/build_test_release_eudsl.yml @@ -128,7 +128,8 @@ jobs: # prevent OOM on free GHA echo "EUDSLPY_DISABLE_COMPILE_OPT=${{ inputs.eudslpy_disable_compile_opt == 'true' && 'ON' || 'OFF' }}" >> $GITHUB_ENV - pip install cibuildwheel + $python3_command -m pip install cibuildwheel + $python3_command -m pip install -r requirements-dev.txt - name: "Build eudsl-tblgen" run: | @@ -141,7 +142,7 @@ jobs: else export CMAKE_PREFIX_PATH=$PWD/llvm-install export PIP_FIND_LINKS=$PWD/wheelhouse - $python3_command -m pip wheel "$PWD/projects/eudsl-tblgen" -w wheelhouse -v + $python3_command -m pip wheel "$PWD/projects/eudsl-tblgen" -w wheelhouse -v --no-build-isolation fi - name: "Build eudsl-llvmpy" @@ -156,7 +157,8 @@ jobs: else export CMAKE_PREFIX_PATH=$PWD/llvm-install export PIP_FIND_LINKS=$PWD/wheelhouse - $python3_command -m pip wheel "$PWD/projects/eudsl-llvmpy" -w wheelhouse -v + $python3_command -m pip install eudsl-tblgen -f wheelhouse + $python3_command -m pip wheel "$PWD/projects/eudsl-llvmpy" -w wheelhouse -v --no-build-isolation fi - name: "Build eudsl-nbgen" @@ -170,7 +172,7 @@ jobs: else export CMAKE_PREFIX_PATH=$PWD/llvm-install export PIP_FIND_LINKS=$PWD/wheelhouse - $python3_command -m pip wheel "$PWD/projects/eudsl-nbgen" -w wheelhouse -v + $python3_command -m pip wheel "$PWD/projects/eudsl-nbgen" -w wheelhouse -v --no-build-isolation fi - name: "Build eudsl-py" @@ -185,14 +187,15 @@ jobs: else export CMAKE_PREFIX_PATH=$PWD/llvm-install export PIP_FIND_LINKS=$PWD/wheelhouse - $python3_command -m pip wheel "$PWD/projects/eudsl-py" -w wheelhouse -v + $python3_command -m pip install eudsl-nbgen -f wheelhouse + $python3_command -m pip wheel "$PWD/projects/eudsl-py" -w wheelhouse -v --no-build-isolation fi # just to/make sure total build continues to work - name: "Build all of eudsl" + if: ${{ github.event_name == 'pull_request' }} run: | - $python3_command -m pip install -r requirements.txt $python3_command -m pip install eudsl-tblgen -f wheelhouse cmake -B $PWD/eudsl-build -S $PWD \ -DCMAKE_PREFIX_PATH=$PWD/llvm-install \ diff --git a/CMakeLists.txt b/CMakeLists.txt index 4f423581..8dc815a2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,7 @@ if(EUDSL_STANDALONE_BUILD) set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + set(MLIR_INCLUDE_DIR ${MLIR_INCLUDE_DIRS}) list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") @@ -40,8 +41,6 @@ if(EUDSL_STANDALONE_BUILD) include(AddLLVM) include(AddMLIR) include(AddClang) - - set(MLIR_INCLUDE_DIR ${MLIR_INCLUDE_DIRS}) else() # turning LLVM -DLLVM_OPTIMIZED_TABLEGEN=ON builds some stuff in the NATIVE dir # but not everything so LLVM_BINARY_DIR isn't correct diff --git a/build_tools/cmake/llvm_cache.cmake b/build_tools/cmake/llvm_cache.cmake index 2e64b58e..882ef937 100644 --- a/build_tools/cmake/llvm_cache.cmake +++ b/build_tools/cmake/llvm_cache.cmake @@ -15,6 +15,7 @@ set(LLVM_BUILD_TOOLS ON CACHE BOOL "") set(LLVM_BUILD_UTILS ON CACHE BOOL "") set(LLVM_INCLUDE_TOOLS ON CACHE BOOL "") set(LLVM_INSTALL_UTILS ON CACHE BOOL "") +set(LLVM_ENABLE_DUMP ON CACHE BOOL "") set(LLVM_BUILD_LLVM_DYLIB ON CACHE BOOL "") # All the tools will use libllvm shared library @@ -75,26 +76,7 @@ set(LLVM_INSTALL_TOOLCHAIN_ONLY OFF CACHE BOOL "") set(LLVM_DISTRIBUTIONS MlirDevelopment CACHE STRING "") set(LLVM_MlirDevelopment_DISTRIBUTION_COMPONENTS - clangAPINotes - clangAST - clangASTMatchers - clangAnalysis - clangBasic - clangDriver - clangDriver - clangEdit - clangFormat - clangFrontend - clangLex - clangParse - clangRewrite - clangSema - clangSerialization - clangSupport - clangTooling - clangToolingCore - clangToolingInclusions - + clang-libraries clang-headers # triggers ClangConfig.cmake and etc clang-cmake-exports diff --git a/projects/CMakeLists.txt b/projects/CMakeLists.txt index 102c8801..53f951f1 100644 --- a/projects/CMakeLists.txt +++ b/projects/CMakeLists.txt @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Copyright (c) 2024. +include_directories(common) + if(NOT WIN32) add_subdirectory(eudsl-py) add_subdirectory(eudsl-llvmpy) diff --git a/projects/eudsl-py/src/bind_vec_like.h b/projects/common/eudsl/bind_vec_like.h similarity index 94% rename from projects/eudsl-py/src/bind_vec_like.h rename to projects/common/eudsl/bind_vec_like.h index 5a733f69..46f2e226 100644 --- a/projects/eudsl-py/src/bind_vec_like.h +++ b/projects/common/eudsl/bind_vec_like.h @@ -1,11 +1,12 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Copyright (c) 2024. +// Copyright (c) 2024-2025. #pragma once #include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/TypeName.h" #include #include @@ -13,7 +14,9 @@ #include #include #include +#include +namespace eudsl { struct _ArrayRef {}; struct _MutableArrayRef {}; struct _SmallVector {}; @@ -283,3 +286,18 @@ nanobind::class_ bind_iter_range(nanobind::handle scope, return cl; } + +inline void bind_array_ref_smallvector(nanobind::handle scope) { + scope.attr("T") = nanobind::type_var("T"); + arrayRef = + nanobind::class_<_ArrayRef>(scope, "ArrayRef", nanobind::is_generic(), + nanobind::sig("class ArrayRef[T]")); + mutableArrayRef = nanobind::class_<_MutableArrayRef>( + scope, "MutableArrayRef", nanobind::is_generic(), + nanobind::sig("class MutableArrayRef[T]")); + smallVector = nanobind::class_<_SmallVector>( + scope, "SmallVector", nanobind::is_generic(), + nanobind::sig("class SmallVector[T]")); +} + +} // namespace eudsl diff --git a/projects/eudsl-py/src/type_casters.h b/projects/common/eudsl/type_casters.h similarity index 89% rename from projects/eudsl-py/src/type_casters.h rename to projects/common/eudsl/type_casters.h index 66cc31d4..4b8e7d91 100644 --- a/projects/eudsl-py/src/type_casters.h +++ b/projects/common/eudsl/type_casters.h @@ -1,14 +1,20 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Copyright (c) 2024. +// Copyright (c) 2024-2025. #pragma once #include -#include +// ReSharper disable once CppUnusedIncludeDirective #include +#include +// ReSharper disable once CppUnusedIncludeDirective #include +// ReSharper disable once CppUnusedIncludeDirective +#include +// ReSharper disable once CppUnusedIncludeDirective +#include "eudsl/bind_vec_like.h" template <> struct nanobind::detail::type_caster { diff --git a/projects/common/eudsl/util.h b/projects/common/eudsl/util.h new file mode 100644 index 00000000..2b32dba2 --- /dev/null +++ b/projects/common/eudsl/util.h @@ -0,0 +1,103 @@ +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Copyright (c) 2025. + +#pragma once + +#include + +namespace eudsl { +template +struct non_copying_non_moving_class_ : nanobind::class_ { + template + NB_INLINE non_copying_non_moving_class_(nanobind::handle scope, + const char *name, + const Extra &...extra) { + nanobind::detail::type_init_data d; + + d.flags = 0; + d.align = (uint8_t)alignof(typename nanobind::class_::Alias); + d.size = (uint32_t)sizeof(typename nanobind::class_::Alias); + d.name = name; + d.scope = scope.ptr(); + d.type = &typeid(T); + + if constexpr (!std::is_same_v::Base, + T>) { + d.base = &typeid(typename nanobind::class_::Base); + d.flags |= (uint32_t)nanobind::detail::type_init_flags::has_base; + } + + if constexpr (std::is_destructible_v) { + d.flags |= (uint32_t)nanobind::detail::type_flags::is_destructible; + + if constexpr (!std::is_trivially_destructible_v) { + d.flags |= (uint32_t)nanobind::detail::type_flags::has_destruct; + d.destruct = nanobind::detail::wrap_destruct; + } + } + + if constexpr (nanobind::detail::has_shared_from_this_v) { + d.flags |= (uint32_t)nanobind::detail::type_flags::has_shared_from_this; + d.keep_shared_from_this_alive = [](PyObject *self) noexcept { + if (auto sp = nanobind::inst_ptr(self)->weak_from_this().lock()) { + nanobind::detail::keep_alive( + self, new auto(std::move(sp)), + [](void *p) noexcept { delete (decltype(sp) *)p; }); + return true; + } + return false; + }; + } + + (nanobind::detail::type_extra_apply(d, extra), ...); + + this->m_ptr = nanobind::detail::nb_type_new(&d); + } +}; + +template +constexpr auto coerceReturn(Return (*pf)(Args...)) noexcept { + return [&pf](Args &&...args) -> NewReturn { + return pf(std::forward(args)...); + }; +} + +template +constexpr auto coerceReturn(Return (Class::*pmf)(Args...), + std::false_type = {}) noexcept { + return [&pmf](Class *cls, Args &&...args) -> NewReturn { + return (cls->*pmf)(std::forward(args)...); + }; +} + +/* + * If you get + * ``` + * Called object type 'void(MyClass::*)(vector&,int)' is not a function or + * function pointer + * ``` + * it's because you're calling a member function without + * passing the `this` pointer as the first arg + */ +template +constexpr auto coerceReturn(Return (Class::*pmf)(Args...) const, + std::true_type) noexcept { + // copy the *pmf, not capture by ref + return [pmf](const Class &cls, Args &&...args) -> NewReturn { + return (cls.*pmf)(std::forward(args)...); + }; +} + +inline size_t wrap(Py_ssize_t i, size_t n) { + if (i < 0) + i += (Py_ssize_t)n; + + if (i < 0 || (size_t)i >= n) + throw nanobind::index_error(); + + return (size_t)i; +} + +} // namespace eudsl diff --git a/projects/eudsl-llvmpy/eudsl-llvmpy-generate.py b/projects/eudsl-llvmpy/eudsl-llvmpy-generate.py index 47f8b720..bfa7b9cb 100644 --- a/projects/eudsl-llvmpy/eudsl-llvmpy-generate.py +++ b/projects/eudsl-llvmpy/eudsl-llvmpy-generate.py @@ -278,19 +278,24 @@ class LLVMMatchType(Generic[_T]): int_regex = re.compile(r"_i(\d+)") fp_regex = re.compile(r"_f(\d+)") - for d in intrins.defs: - intr = intrins.defs[d] - if intr.name.startswith("int_amdgcn") and intr.type.as_string != "ClangBuiltin": + defs = intrins.get_defs() + for d in defs: + intr = defs[d] + if ( + intr.get_name().startswith("int_amdgcn") + and intr.get_type().get_as_string() != "ClangBuiltin" + ): arg_types = [] ret_types = [] - for p in intr.values.ParamTypes.value: - p_s = p.as_string + for p in intr.get_values().ParamTypes.get_value(): + p_s = p.get_as_string() if p_s.startswith("anon"): - p_s = p.type.as_string + p_s = p.get_type().get_as_string() + pdv = p.get_def().get_values() if p_s == "LLVMMatchType": - p_s += f"[{p.def_.values.Number.value.value}]" + p_s += f"[{pdv.Number.get_value()}]" elif p_s == "LLVMQualPointerType": - _, addr_space = p.def_.values.Sig.value.values + kind, addr_space = pdv.Sig.get_value() p_s += f"[{addr_space}]" else: raise NotImplemented(f"unsupported {p_s=}") @@ -303,8 +308,8 @@ class LLVMMatchType(Generic[_T]): p_s = "pointer" arg_types.append(p_s) - for p in intr.values.RetTypes.value: - ret_types.append(p.as_string) + for p in intr.get_values().RetTypes.get_value(): + ret_types.append(p.get_as_string()) ret_str = "" if len(ret_types): diff --git a/projects/eudsl-nbgen/CMakeLists.txt b/projects/eudsl-nbgen/CMakeLists.txt index a63b9287..0e1d0201 100644 --- a/projects/eudsl-nbgen/CMakeLists.txt +++ b/projects/eudsl-nbgen/CMakeLists.txt @@ -31,6 +31,8 @@ if(EUDSL_NBGEN_STANDALONE_BUILD) include(AddLLVM) include(AddClang) include(HandleLLVMOptions) + + include_directories(${CMAKE_CURRENT_LIST_DIR}/../common) endif() include_directories(${LLVM_INCLUDE_DIRS}) diff --git a/projects/eudsl-nbgen/cmake/eudsl_nbgen-config.cmake b/projects/eudsl-nbgen/cmake/eudsl_nbgen-config.cmake index e5d78b4f..0e1622ee 100644 --- a/projects/eudsl-nbgen/cmake/eudsl_nbgen-config.cmake +++ b/projects/eudsl-nbgen/cmake/eudsl_nbgen-config.cmake @@ -5,6 +5,19 @@ # copy-pasta from AddMLIR.cmake/AddLLVM.cmake/TableGen.cmake +set(EUDSL_NBGEN_NANOBIND_OPTIONS + -Wno-cast-qual + -Wno-deprecated-literal-operator + -Wno-covered-switch-default + -Wno-nested-anon-types + -Wno-zero-length-array + -Wno-c++98-compat-extra-semi + -Wno-c++20-extensions + $<$:-fexceptions -frtti> + $<$:-fexceptions -frtti> + $<$:/EHsc /GR> +) + function(eudsl_nbgen target input_file) set(EUDSL_NBGEN_TARGET_DEFINITIONS ${input_file}) cmake_parse_arguments(ARG "" "" "LINK_LIBS;EXTRA_INCLUDES;NAMESPACES" ${ARGN}) @@ -89,6 +102,7 @@ function(eudsl_nbgen target input_file) WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS ${EUDSL_NBGEN_EXE} ${global_tds} DEPFILE ${_depfile} + DEPENDS ${EUDSL_NBGEN_EXE} COMMENT "eudsl-nbgen: Generating ${_full_gen_file}..." ) # epic hack to specify all shards that will be generated even though we don't know them before hand @@ -137,6 +151,7 @@ function(eudsl_nbgen target input_file) endif() add_library(${target} STATIC "${_full_gen_file}.sharded.cpp" ${_shards}) + target_compile_options(${target} PUBLIC ${EUDSL_NBGEN_NANOBIND_OPTIONS}) execute_process( COMMAND "${Python_EXECUTABLE}" -m nanobind --include_dir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_include_dir diff --git a/projects/eudsl-nbgen/src/eudsl-nbgen.cpp b/projects/eudsl-nbgen/src/eudsl-nbgen.cpp index f8ba7d65..2c357921 100644 --- a/projects/eudsl-nbgen/src/eudsl-nbgen.cpp +++ b/projects/eudsl-nbgen/src/eudsl-nbgen.cpp @@ -93,6 +93,7 @@ static std::string getPyClassName(const std::string &qualifiedNameAsString) { s = std::regex_replace(s, std::regex(R"(\*)"), ""); s = std::regex_replace(s, std::regex("<"), "["); s = std::regex_replace(s, std::regex(">"), "]"); + s = std::regex_replace(s, std::regex("::"), "."); return s; } @@ -100,7 +101,6 @@ static std::string snakeCase(const std::string &name) { std::string s = name; s = std::regex_replace(s, std::regex(R"(([A-Z]+)([A-Z][a-z]))"), "$1_$2"); s = std::regex_replace(s, std::regex(R"(([a-z\d])([A-Z]))"), "$1_$2"); - s = std::regex_replace(s, std::regex("-"), "_"); std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); return s; @@ -143,20 +143,26 @@ static llvm::SmallPtrSet findOverloads(clang::FunctionDecl *decl, return results; } -// TODO(max): split this into two functions (one for names and one for types) -static std::string sanitizeNameOrType(std::string nameOrType, - int emptyIdx = 0) { - if (nameOrType == "from") - nameOrType = "from_"; - else if (nameOrType == "except") - nameOrType = "except_"; - else if (nameOrType == "") - nameOrType = std::string(emptyIdx + 1, '_'); - else if (nameOrType.rfind("ArrayRef", 0) == 0) - nameOrType = "llvm::" + nameOrType; - if (std::regex_search(nameOrType, std::regex(R"(std::__1)"))) - nameOrType = std::regex_replace(nameOrType, std::regex("std::__1"), "std"); - return nameOrType; +static std::string sanitizeName(std::string name, int emptyIdx = 0) { + if (name == "def") + name = "def_"; + if (name == "from") + name = "from_"; + else if (name == "except") + name = "except_"; + else if (name == "") + name = std::string(emptyIdx + 1, '_'); + return name; +} + +static std::string sanitizeType(std::string type) { + if (type.rfind("ArrayRef", 0) == 0) + type = "llvm::" + type; + if (std::regex_search(type, std::regex(R"(std::__1)"))) + type = std::regex_replace(type, std::regex("std::__1"), "std"); + if (std::regex_search(type, std::regex(R"(std::__cxx11)"))) + type = std::regex_replace(type, std::regex("std::__cxx11"), "std"); + return type; } // emit a lambda body to disambiguate/break ties amongst overloads @@ -187,12 +193,21 @@ std::string emitNBLambdaBody(clang::FunctionDecl *decl, n = llvm::formatv("std::move({0})", n); } std::string newParamNamesStr = llvm::join(newParamNames, ", "); + std::string return_; + auto returnTypeT = decl->getReturnType(); + if (!returnTypeT->isVoidType()) + return_ = "return"; + + bool canonical = true; + if (std::regex_search(returnTypeT.getAsString(), std::regex(R"(_t\b)"))) + canonical = false; + std::string returnType = + sanitizeType(returnTypeT.getAsString(getPrintingPolicy(canonical))); + std::string funcRef; - std::string returnType = sanitizeNameOrType( - decl->getReturnType().getAsString(getPrintingPolicy())); if (decl->isStatic() || !decl->isCXXClassMember()) { - funcRef = llvm::formatv("\n []({0}) -> {1} {{\n return {2}({3});\n }", - typedParamsStr, returnType, + funcRef = llvm::formatv("\n []({0}) -> {1} {{\n {2} {3}({4});\n }", + typedParamsStr, returnType, return_, decl->getQualifiedNameAsString(), newParamNamesStr); } else { assert(decl->isCXXClassMember() && "expected class member"); @@ -200,16 +215,76 @@ std::string emitNBLambdaBody(clang::FunctionDecl *decl, typedParamsStr = llvm::formatv("self, {0}", typedParamsStr); else typedParamsStr = "self"; + std::string methName = decl->getNameAsString(); + if (llvm::isa(decl)) + methName = "operator " + returnType; const clang::CXXRecordDecl *parentRecord = llvm::cast(decl->getParent()); - funcRef = llvm::formatv( - "\n []({0}& {1}) -> {2} {{\n return self.{3}({4});\n }", - parentRecord->getQualifiedNameAsString(), typedParamsStr, returnType, - decl->getNameAsString(), newParamNamesStr); + funcRef = + llvm::formatv("\n []({0}& {1}) -> {2} {{\n {3} self.{4}({5});\n }", + parentRecord->getQualifiedNameAsString(), typedParamsStr, + returnType, return_, methName, newParamNamesStr); } return funcRef; } +static std::string +processOperator(const std::string &fnName, + const llvm::SmallVector &argNames) { + std::string newFnName = fnName; + if (fnName == "operator!=") { + newFnName = "__ne__"; + } else if (fnName == "operator==") { + newFnName = "__eq__"; + } else if (fnName == "operator-") { + newFnName = "__neg__"; + } else if (fnName == "operator[]") { + newFnName = "__getitem__"; + } else if (fnName == "operator<") { + newFnName = "__lt__"; + } else if (fnName == "operator<=") { + newFnName = "__le__"; + } else if (fnName == "operator>") { + newFnName = "__gt__"; + } else if (fnName == "operator>=") { + newFnName = "__ge__"; + } else if (fnName == "operator%") { + newFnName = "__mod__"; + } else if (fnName == "operator*") { + if (!argNames.empty()) { + newFnName = "__mul__"; + } else { + // operator* not supported + } + } else if (fnName == "operator+" && !argNames.empty()) { + newFnName = "__add__"; + } else if (fnName == "operator->") { + // operator-> not supported + } else if (fnName == "operator!") { + // operator! not supported + } else if (fnName == "operator<<") { + // operator<< not supported + } + + return newFnName; +} + +static std::string getCppClass(const clang::CXXRecordDecl *decl) { + std::string className; + if (const clang::ClassTemplateSpecializationDecl *t = + llvm::dyn_cast(decl)) { + // TODO(max): this emits unnecessary default template args, like + // mlir::detail::TypeIDResolver + // auto td = t->getTypeForDecl(); + className = t->getTypeForDecl()->getCanonicalTypeInternal().getAsString( + getPrintingPolicy()); + } else { + className = decl->getQualifiedNameAsString(); + } + + return sanitizeType(className); +} + static bool emitClassMethodOrFunction(clang::FunctionDecl *decl, clang::CompilerInstance &ci, @@ -227,8 +302,8 @@ emitClassMethodOrFunction(clang::FunctionDecl *decl, if (std::regex_search(t.getAsString(), std::regex(R"(_t\b)"))) canonical = false; std::string paramType = t.getAsString(getPrintingPolicy(canonical)); - paramTypes.push_back(sanitizeNameOrType(paramType)); - paramNames.push_back(sanitizeNameOrType(name, i)); + paramTypes.push_back(sanitizeType(paramType)); + paramNames.push_back(sanitizeName(name, i)); } llvm::SmallPtrSet funcOverloads = @@ -237,21 +312,20 @@ emitClassMethodOrFunction(clang::FunctionDecl *decl, findOverloads(decl, ci.getSema()); std::string funcRef, nbFnName; - if (auto ctor = llvm::dyn_cast(decl)) { + if (clang::CXXConstructorDecl *ctor = + llvm::dyn_cast(decl)) { if (ctor->isDeleted()) return false; funcRef = llvm::formatv("nb::init<{0}>()", llvm::join(paramTypes, ", ")); } else { - if (funcOverloads.size() == 1 && funcTemplOverloads.empty()) { + if (funcOverloads.size() == 1 && funcTemplOverloads.empty()) funcRef = llvm::formatv("&{0}", decl->getQualifiedNameAsString()); - } else { + else funcRef = emitNBLambdaBody(decl, paramNames, paramTypes); - } - nbFnName = snakeCase(decl->getNameAsString()); - if (decl->isOverloadedOperator()) { - // TODO(max): handle overloaded operators - // nbFnName = nbFnName; + if (decl->isOverloadedOperator() || + llvm::isa(decl)) { + nbFnName = processOperator(nbFnName, paramNames); } else if (decl->isStatic() && funcOverloads.size() > 1 && llvm::any_of(funcOverloads, [](clang::FunctionDecl *m) { return !m->isStatic(); @@ -299,7 +373,7 @@ emitClassMethodOrFunction(clang::FunctionDecl *decl, if (decl->isCXXClassMember()) { const clang::CXXRecordDecl *parentRecord = llvm::cast(decl->getParent()); - scope = getNBBindClassName(parentRecord->getQualifiedNameAsString()); + scope = getNBBindClassName(getCppClass(parentRecord)); } outputFile->os() << llvm::formatv("{0}.{1}({2}{3}{4}{5}{6});\n", scope, @@ -309,13 +383,12 @@ emitClassMethodOrFunction(clang::FunctionDecl *decl, return true; } -std::string getNBScope(clang::TagDecl *decl) { +static std::string getNBScope(clang::TagDecl *decl) { std::string scope = "m"; const clang::DeclContext *declContext = decl->getDeclContext(); - if (declContext->isRecord()) { - const clang::CXXRecordDecl *ctx = - llvm::cast(declContext); - scope = getNBBindClassName(ctx->getQualifiedNameAsString()); + if (const clang::CXXRecordDecl *ctx = + llvm::dyn_cast(declContext)) { + scope = getNBBindClassName(getCppClass(ctx)); } return scope; } @@ -330,9 +403,9 @@ static bool emitClass(clang::CXXRecordDecl *decl, clang::CompilerInstance &ci, return false; } - std::string scope = getNBScope(decl); std::string additional = ""; - std::string className = decl->getQualifiedNameAsString(); + std::string cppClass = getCppClass(decl); + std::string autoVar = llvm::formatv("auto {0}", getNBBindClassName(cppClass)); if (decl->getNumBases() > 1) { clang::DiagnosticBuilder builder = ci.getDiagnostics().Report( decl->getLocation(), ci.getDiagnostics().getCustomDiagID( @@ -341,25 +414,27 @@ static bool emitClass(clang::CXXRecordDecl *decl, clang::CompilerInstance &ci, } else if (decl->getNumBases() == 1) { // handle some known bases that we've already found a wap to bind clang::CXXBaseSpecifier baseClass = *decl->bases_begin(); - std::string baseName = baseClass.getType().getAsString(getPrintingPolicy()); + clang::QualType baseType = baseClass.getType(); + std::string baseName = getCppClass(baseType->getAsCXXRecordDecl()); // TODO(max): these could be lookups on the corresponding recorddecls using // sema... if (baseName.rfind("mlir::Op<", 0) == 0) { - className = llvm::formatv("{0}, mlir::OpState", className); + cppClass = llvm::formatv("{0}, mlir::OpState", cppClass); } else if (baseName.rfind("mlir::detail::StorageUserBase<", 0) == 0) { llvm::SmallVector templParams; llvm::StringRef{baseName}.split(templParams, ","); - className = llvm::formatv("{0}, {1}", className, templParams[1]); + // TODO(max): this needs to use getCppClass not templParams[1], which is a + // string + cppClass = llvm::formatv("{0}, {1}", cppClass, templParams[1]); } else if (baseName.rfind("mlir::Dialect", 0) == 0 && - className.rfind("mlir::ExtensibleDialect") == - std::string::npos) { + cppClass.rfind("mlir::ExtensibleDialect") == std::string::npos) { // clang-format off - additional += llvm::formatv("\n .def_static(\"insert_into_registry\", [](mlir::DialectRegistry ®istry) {{ registry.insert<{0}>(); })", className); - additional += llvm::formatv("\n .def_static(\"load_into_context\", [](mlir::MLIRContext &context) {{ return context.getOrLoadDialect<{0}>(); })", className); + additional += llvm::formatv("\n .def_static(\"insert_into_registry\", [](mlir::DialectRegistry ®istry) {{ registry.insert<{0}>(); })", cppClass); + additional += llvm::formatv("\n .def_static(\"load_into_context\", [](mlir::MLIRContext &context) {{ return context.getOrLoadDialect<{0}>(); })", cppClass); // clang-format on } else if (!llvm::isa( baseClass.getType()->getAsCXXRecordDecl())) { - className = llvm::formatv("{0}, {1}", className, baseName); + cppClass = llvm::formatv("{0}, {1}", cppClass, baseName); } else { assert(llvm::isa( baseClass.getType()->getAsCXXRecordDecl()) && @@ -372,12 +447,11 @@ static bool emitClass(clang::CXXRecordDecl *decl, clang::CompilerInstance &ci, } } - std::string autoVar = llvm::formatv( - "auto {0}", getNBBindClassName(decl->getQualifiedNameAsString())); - + std::string scope = getNBScope(decl); + std::string pyClassName = getPyClassName(decl->getNameAsString()); outputFile->os() << llvm::formatv( - "\n{0} = nb::class_<{1}>({2}, \"{3}\"){4};\n", autoVar, className, scope, - getPyClassName(decl->getNameAsString()), additional); + "\n{0} = nb::class_<{1}>({2}, \"{3}\"){4};\n", autoVar, cppClass, scope, + pyClassName, additional); return true; } @@ -396,10 +470,9 @@ static bool emitEnum(clang::EnumDecl *decl, clang::CompilerInstance &ci, cstDecl->getQualifiedNameAsString()); if (i++ < nDecls - 1) outputFile->os() << "\n"; - else - outputFile->os() << ";\n"; } - outputFile->os() << "\n"; + + outputFile->os() << ";\n"; return true; } @@ -422,8 +495,7 @@ static bool emitField(clang::DeclaratorDecl *field, clang::CompilerInstance &ci, if (field->getType()->hasPointerRepresentation()) refInternal = ", nb::rv_policy::reference_internal"; - std::string scope = - getNBBindClassName(parentRecord->getQualifiedNameAsString()); + std::string scope = getNBBindClassName(getCppClass(parentRecord)); std::string nbFnName = llvm::formatv("\"{0}\"", snakeCase(field->getNameAsString())); outputFile->os() << llvm::formatv("{0}.{1}({2}, &{3}{4});\n", scope, defStr, @@ -433,7 +505,7 @@ static bool emitField(clang::DeclaratorDecl *field, clang::CompilerInstance &ci, } template -static bool shouldSkip(T *decl) { +static bool shouldSkip(T *decl, clang::CompilerInstance &ci) { auto *encl = llvm::dyn_cast( decl->getEnclosingNamespaceContext()); if (!encl) @@ -442,6 +514,8 @@ static bool shouldSkip(T *decl) { // bind std:: if (encl->isStdNamespace() || encl->isInStdNamespace()) return true; + if (ci.getSema().getSourceManager().isInSystemHeader(decl->getLocation())) + return true; if (!filterInNamespace(encl->getQualifiedNameAsString())) return true; if constexpr (std::is_same_v || @@ -474,7 +548,9 @@ struct BindingsVisitor ci.getDiagnostics())) {} bool VisitCXXRecordDecl(clang::CXXRecordDecl *decl) { - if (shouldSkip(decl)) + if (shouldSkip(decl, ci)) + return true; + if (decl->isAbstract()) return true; if (decl->isClass() || decl->isStruct()) { if (emitClass(decl, ci, outputFile)) @@ -537,8 +613,10 @@ struct BindingsVisitor return true; } + // TODO(max): skip definitions somehow? like FloatType::getFloat4E2M1FN which + // has both the decl and the impl in a header? bool VisitCXXMethodDecl(clang::CXXMethodDecl *decl) { - if (shouldSkip(decl) || llvm::isa(decl) || + if (shouldSkip(decl, ci) || llvm::isa(decl) || !visitedRecords.contains(decl->getParent())) return true; if (decl->isOverloadedOperator() && @@ -560,12 +638,18 @@ struct BindingsVisitor "friend functions not supported")); return true; } + if (decl->isCopyAssignmentOperator() || decl->isMoveAssignmentOperator()) + return true; + if (decl->isDeleted()) + return true; + emitClassMethodOrFunction(decl, ci, outputFile); + return true; } bool VisitFunctionDecl(clang::FunctionDecl *decl) { - if (shouldSkip(decl) || decl->isCXXClassMember()) + if (shouldSkip(decl, ci) || decl->isCXXClassMember()) return true; // clang-format off // this @@ -583,12 +667,19 @@ struct BindingsVisitor "template functions not supported yet")); return true; } + if (decl->getFriendObjectKind()) { + clang::DiagnosticBuilder builder = ci.getDiagnostics().Report( + decl->getLocation(), ci.getDiagnostics().getCustomDiagID( + clang::DiagnosticsEngine::Note, + "friend functions not supported")); + return true; + } emitClassMethodOrFunction(decl, ci, outputFile); return true; } bool VisitEnumDecl(clang::EnumDecl *decl) { - if (shouldSkip(decl)) + if (shouldSkip(decl, ci)) return true; if (decl->getQualifiedNameAsString().rfind("unnamed enum") != std::string::npos) @@ -613,6 +704,8 @@ struct BindingsVisitor // TODO(max): this is a hack and not stable bool VisitDecl(clang::Decl *decl) { + if (ci.getSema().getSourceManager().isInSystemHeader(decl->getLocation())) + return true; const clang::DeclContext *declContext = decl->getDeclContext(); HackDeclContext *ctx = static_cast(decl->getDeclContext()); @@ -802,13 +895,15 @@ namespace nb = nanobind; using namespace nb::literals; using namespace mlir; using namespace llvm; -#include "type_casters.h" +#include "eudsl/type_casters.h" +namespace eudsl { void populate)" << finalTarget << i << R"(Module(nb::module_ &m) { )"; // clang-format on shardFile << shards[i] << std::endl; shardFile << "}" << std::endl; + shardFile << "}" << std::endl; shardFile.flush(); shardFile.close(); } @@ -850,6 +945,7 @@ void populate)" << finalTarget << i << R"(Module(nb::module_ &m) { namespace nb = nanobind; using namespace nb::literals; +namespace eudsl { void populate)" << finalTarget << R"(Module(nb::module_ &m) { )"; // clang-format on @@ -862,6 +958,7 @@ void populate)" << finalTarget << R"(Module(nb::module_ &m) { finalShardedFile << "populate" << finalTarget << i << "Module(m);" << std::endl; + finalShardedFile << "}" << std::endl; finalShardedFile << "}" << std::endl; finalShardedFile.flush(); finalShardedFile.close(); diff --git a/projects/eudsl-py/CMakeLists.txt b/projects/eudsl-py/CMakeLists.txt index 7444726c..4fad14a1 100644 --- a/projects/eudsl-py/CMakeLists.txt +++ b/projects/eudsl-py/CMakeLists.txt @@ -53,6 +53,8 @@ if(EUDSLPY_STANDALONE_BUILD) # for out-of-tree MLIR_INCLUDE_DIR points to the build dir by default # and MLIR_INCLUDE_DIRS points to the correct place set(MLIR_INCLUDE_DIR ${MLIR_INCLUDE_DIRS}) + + include_directories(${CMAKE_CURRENT_LIST_DIR}/../common) endif() include_directories(${LLVM_INCLUDE_DIRS}) diff --git a/projects/eudsl-py/src/eudslpy_ext.cpp b/projects/eudsl-py/src/eudslpy_ext.cpp index 0b3fc7dc..c66fff48 100644 --- a/projects/eudsl-py/src/eudslpy_ext.cpp +++ b/projects/eudsl-py/src/eudslpy_ext.cpp @@ -3,12 +3,6 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // Copyright (c) 2024. -#include -#include -#include -#include -#include - #include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" @@ -57,12 +51,16 @@ #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ThreadPool.h" -#include "bind_vec_like.h" -#include "type_casters.h" +#include + +// ReSharper disable once CppUnusedIncludeDirective +#include "eudsl/type_casters.h" +#include "eudsl/util.h" namespace nb = nanobind; using namespace nb::literals; +namespace eudsl { class FakeDialect : public mlir::Dialect { public: FakeDialect(llvm::StringRef name, mlir::MLIRContext *context, mlir::TypeID id) @@ -73,63 +71,7 @@ nb::class_<_SmallVector> smallVector; nb::class_<_ArrayRef> arrayRef; nb::class_<_MutableArrayRef> mutableArrayRef; -void bind_array_ref_smallvector(nb::handle scope) { - scope.attr("T") = nb::type_var("T"); - arrayRef = nb::class_<_ArrayRef>(scope, "ArrayRef", nb::is_generic(), - nb::sig("class ArrayRef[T]")); - mutableArrayRef = - nb::class_<_MutableArrayRef>(scope, "MutableArrayRef", nb::is_generic(), - nb::sig("class MutableArrayRef[T]")); - smallVector = nb::class_<_SmallVector>(scope, "SmallVector", nb::is_generic(), - nb::sig("class SmallVector[T]")); -} - -template -struct non_copying_non_moving_class_ : nb::class_ { - template - NB_INLINE non_copying_non_moving_class_(nb::handle scope, const char *name, - const Extra &...extra) { - nb::detail::type_init_data d; - - d.flags = 0; - d.align = (uint8_t)alignof(typename nb::class_::Alias); - d.size = (uint32_t)sizeof(typename nb::class_::Alias); - d.name = name; - d.scope = scope.ptr(); - d.type = &typeid(T); - - if constexpr (!std::is_same_v::Base, T>) { - d.base = &typeid(typename nb::class_::Base); - d.flags |= (uint32_t)nb::detail::type_init_flags::has_base; - } - - if constexpr (std::is_destructible_v) { - d.flags |= (uint32_t)nb::detail::type_flags::is_destructible; - - if constexpr (!std::is_trivially_destructible_v) { - d.flags |= (uint32_t)nb::detail::type_flags::has_destruct; - d.destruct = nb::detail::wrap_destruct; - } - } - - if constexpr (nb::detail::has_shared_from_this_v) { - d.flags |= (uint32_t)nb::detail::type_flags::has_shared_from_this; - d.keep_shared_from_this_alive = [](PyObject *self) noexcept { - if (auto sp = nb::inst_ptr(self)->weak_from_this().lock()) { - nb::detail::keep_alive( - self, new auto(std::move(sp)), - [](void *p) noexcept { delete (decltype(sp) *)p; }); - return true; - } - return false; - }; - } - - (nb::detail::type_extra_apply(d, extra), ...); - - this->m_ptr = nb::detail::nb_type_new(&d); - } -}; +extern void populateEUDSLGen_IR0Module(nb::module_ &m); void populateIRModule(nb::module_ &m) { using namespace mlir; @@ -358,8 +300,10 @@ extern void populateEUDSLGen_x86vectorModule(nb::module_ &m); // extern void populateEUDSLGen_xegpuModule(nb::module_ &m); +} // namespace eudsl + NB_MODULE(eudslpy_ext, m) { - bind_array_ref_smallvector(m); + eudsl::bind_array_ref_smallvector(m); nb::class_(m, "APFloat"); nb::class_(m, "APInt"); @@ -381,9 +325,6 @@ NB_MODULE(eudslpy_ext, m) { nb::class_(m, "TypeID"); nb::class_(m, "InterfaceMap"); - auto irModule = m.def_submodule("ir"); - populateIRModule(irModule); - nb::class_>(m, "FailureOr[bool]"); nb::class_>(m, "FailureOr[StringAttr]"); nb::class_>( @@ -429,64 +370,64 @@ NB_MODULE(eudslpy_ext, m) { nb::class_(m, "BitVector"); auto [smallVectorOfBool, arrayRefOfBool, mutableArrayRefOfBool] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfFloat, arrayRefOfFloat, mutableArrayRefOfFloat] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfInt, arrayRefOfInt, mutableArrayRefOfInt] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfChar, arrayRefOfChar, mutableArrayRefOfChar] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfDouble, arrayRefOfDouble, mutableArrayRefOfDouble] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfInt16, arrayRefOfInt16, mutableArrayRefOfInt16] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfInt32, arrayRefOfInt32, mutableArrayRefOfInt32] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfInt64, arrayRefOfInt64, mutableArrayRefOfInt64] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfUInt16, arrayRefOfUInt16, mutableArrayRefOfUInt16] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfUInt32, arrayRefOfUInt32, mutableArrayRefOfUInt32] = - bind_array_ref(m); + eudsl::bind_array_ref(m); auto [smallVectorOfUInt64, arrayRefOfUInt64, mutableArrayRefOfUInt64] = - bind_array_ref(m); + eudsl::bind_array_ref(m); // these have to precede... - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - - bind_array_ref(m); - - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - - bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - - bind_array_ref(m); - bind_array_ref(m); - // bind_array_ref(m); - bind_array_ref(m); - bind_array_ref(m); - - smallVector.def_static( + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + + eudsl::bind_array_ref(m); + + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + // eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + eudsl::bind_array_ref(m); + + eudsl::smallVector.def_static( "__class_getitem__", [smallVectorOfBool, smallVectorOfInt, smallVectorOfFloat, smallVectorOfInt16, smallVectorOfInt32, smallVectorOfInt64, @@ -533,7 +474,7 @@ NB_MODULE(eudslpy_ext, m) { throw std::runtime_error(errMsg); }); - smallVector.def_static( + eudsl::smallVector.def_static( "__class_getitem__", [smallVectorOfFloat, smallVectorOfInt16, smallVectorOfInt32, smallVectorOfInt64, smallVectorOfUInt16, smallVectorOfUInt32, @@ -574,17 +515,21 @@ NB_MODULE(eudslpy_ext, m) { nb::class_>( m, "iterator_range[ResultRange.UseIterator]"); - bind_iter_range, mlir::Type>( + eudsl::bind_iter_range, mlir::Type>( m, "ValueTypeRange[ValueRange]"); - bind_iter_range, mlir::Type>( + eudsl::bind_iter_range, mlir::Type>( m, "ValueTypeRange[OperandRange]"); - bind_iter_range, mlir::Type>( + eudsl::bind_iter_range, mlir::Type>( m, "ValueTypeRange[ResultRange]"); - bind_iter_like, nb::rv_policy::reference_internal>( - m, "iplist[Block]"); - bind_iter_like, - nb::rv_policy::reference_internal>(m, "iplist[Operation]"); + eudsl::bind_iter_like, + nb::rv_policy::reference_internal>(m, "iplist[Block]"); + eudsl::bind_iter_like, + nb::rv_policy::reference_internal>(m, + "iplist[Operation]"); + + auto irModule = m.def_submodule("ir"); + eudsl::populateIRModule(irModule); auto dialectsModule = m.def_submodule("dialects"); @@ -592,134 +537,134 @@ NB_MODULE(eudslpy_ext, m) { // populateEUDSLGen_accModule(accModule); auto affineModule = dialectsModule.def_submodule("affine"); - populateEUDSLGen_affineModule(affineModule); + eudsl::populateEUDSLGen_affineModule(affineModule); auto amdgpuModule = dialectsModule.def_submodule("amdgpu"); - populateEUDSLGen_amdgpuModule(amdgpuModule); + eudsl::populateEUDSLGen_amdgpuModule(amdgpuModule); // auto amxModule = dialectsModule.def_submodule("amx"); - // populateEUDSLGen_amxModule(amxModule); + // eudsl::populateEUDSLGen_amxModule(amxModule); auto arithModule = dialectsModule.def_submodule("arith"); - populateEUDSLGen_arithModule(arithModule); + eudsl::populateEUDSLGen_arithModule(arithModule); // auto arm_neonModule = dialectsModule.def_submodule("arm_neon"); - // populateEUDSLGen_arm_neonModule(arm_neonModule); + // eudsl::populateEUDSLGen_arm_neonModule(arm_neonModule); // auto arm_smeModule = dialectsModule.def_submodule("arm_sme"); - // populateEUDSLGen_arm_smeModule(arm_smeModule); + // eudsl::populateEUDSLGen_arm_smeModule(arm_smeModule); // auto arm_sveModule = dialectsModule.def_submodule("arm_sve"); - // populateEUDSLGen_arm_sveModule(arm_sveModule); + // eudsl::populateEUDSLGen_arm_sveModule(arm_sveModule); auto asyncModule = dialectsModule.def_submodule("async"); - populateEUDSLGen_asyncModule(asyncModule); + eudsl::populateEUDSLGen_asyncModule(asyncModule); auto bufferizationModule = dialectsModule.def_submodule("bufferization"); - populateEUDSLGen_bufferizationModule(bufferizationModule); + eudsl::populateEUDSLGen_bufferizationModule(bufferizationModule); auto cfModule = dialectsModule.def_submodule("cf"); - populateEUDSLGen_cfModule(cfModule); + eudsl::populateEUDSLGen_cfModule(cfModule); auto complexModule = dialectsModule.def_submodule("complex"); - populateEUDSLGen_complexModule(complexModule); + eudsl::populateEUDSLGen_complexModule(complexModule); // auto DLTIDialectModule = dialectsModule.def_submodule("DLTIDialect"); - // populateEUDSLGen_DLTIDialectModule(DLTIDialectModule); + // eudsl::populateEUDSLGen_DLTIDialectModule(DLTIDialectModule); auto emitcModule = dialectsModule.def_submodule("emitc"); - populateEUDSLGen_emitcModule(emitcModule); + eudsl::populateEUDSLGen_emitcModule(emitcModule); auto funcModule = dialectsModule.def_submodule("func"); - populateEUDSLGen_funcModule(funcModule); + eudsl::populateEUDSLGen_funcModule(funcModule); auto gpuModule = dialectsModule.def_submodule("gpu"); - populateEUDSLGen_gpuModule(gpuModule); + eudsl::populateEUDSLGen_gpuModule(gpuModule); auto indexModule = dialectsModule.def_submodule("index"); - populateEUDSLGen_indexModule(indexModule); + eudsl::populateEUDSLGen_indexModule(indexModule); // auto irdlModule = dialectsModule.def_submodule("irdl"); - // populateEUDSLGen_irdlModule(irdlModule); + // eudsl::populateEUDSLGen_irdlModule(irdlModule); auto linalgModule = dialectsModule.def_submodule("linalg"); - populateEUDSLGen_linalgModule(linalgModule); + eudsl::populateEUDSLGen_linalgModule(linalgModule); auto LLVMModule = dialectsModule.def_submodule("LLVM"); - populateEUDSLGen_LLVMModule(LLVMModule); + eudsl::populateEUDSLGen_LLVMModule(LLVMModule); auto mathModule = dialectsModule.def_submodule("math"); - populateEUDSLGen_mathModule(mathModule); + eudsl::populateEUDSLGen_mathModule(mathModule); auto memrefModule = dialectsModule.def_submodule("memref"); - populateEUDSLGen_memrefModule(memrefModule); + eudsl::populateEUDSLGen_memrefModule(memrefModule); // auto meshModule = dialectsModule.def_submodule("mesh"); - // populateEUDSLGen_meshModule(meshModule); + // eudsl::populateEUDSLGen_meshModule(meshModule); // auto ml_programModule = dialectsModule.def_submodule("ml_program"); - // populateEUDSLGen_ml_programModule(ml_programModule); + // eudsl::populateEUDSLGen_ml_programModule(ml_programModule); // auto mpiModule = dialectsModule.def_submodule("mpi"); - // populateEUDSLGen_mpiModule(mpiModule); + // eudsl::populateEUDSLGen_mpiModule(mpiModule); auto nvgpuModule = dialectsModule.def_submodule("nvgpu"); - populateEUDSLGen_nvgpuModule(nvgpuModule); + eudsl::populateEUDSLGen_nvgpuModule(nvgpuModule); auto NVVMModule = dialectsModule.def_submodule("NVVM"); - populateEUDSLGen_NVVMModule(NVVMModule); + eudsl::populateEUDSLGen_NVVMModule(NVVMModule); // auto ompModule = dialectsModule.def_submodule("omp"); - // populateEUDSLGen_ompModule(ompModule); + // eudsl::populateEUDSLGen_ompModule(ompModule); auto pdlModule = dialectsModule.def_submodule("pdl"); - populateEUDSLGen_pdlModule(pdlModule); + eudsl::populateEUDSLGen_pdlModule(pdlModule); auto pdl_interpModule = dialectsModule.def_submodule("pdl_interp"); - populateEUDSLGen_pdl_interpModule(pdl_interpModule); + eudsl::populateEUDSLGen_pdl_interpModule(pdl_interpModule); auto polynomialModule = dialectsModule.def_submodule("polynomial"); - populateEUDSLGen_polynomialModule(polynomialModule); + eudsl::populateEUDSLGen_polynomialModule(polynomialModule); // auto ptrModule = dialectsModule.def_submodule("ptr"); - // populateEUDSLGen_ptrModule(ptrModule); + // eudsl::populateEUDSLGen_ptrModule(ptrModule); // auto quantModule = dialectsModule.def_submodule("quant"); - // populateEUDSLGen_quantModule(quantModule); + // eudsl::populateEUDSLGen_quantModule(quantModule); auto ROCDLModule = dialectsModule.def_submodule("ROCDL"); - populateEUDSLGen_ROCDLModule(ROCDLModule); + eudsl::populateEUDSLGen_ROCDLModule(ROCDLModule); auto scfModule = dialectsModule.def_submodule("scf"); - populateEUDSLGen_scfModule(scfModule); + eudsl::populateEUDSLGen_scfModule(scfModule); auto shapeModule = dialectsModule.def_submodule("shape"); - populateEUDSLGen_shapeModule(shapeModule); + eudsl::populateEUDSLGen_shapeModule(shapeModule); // auto sparse_tensorModule = dialectsModule.def_submodule("sparse_tensor"); - // populateEUDSLGen_sparse_tensorModule(sparse_tensorModule); + // eudsl::populateEUDSLGen_sparse_tensorModule(sparse_tensorModule); // auto spirvModule = dialectsModule.def_submodule("spirv"); - // populateEUDSLGen_spirvModule(spirvModule); + // eudsl::populateEUDSLGen_spirvModule(spirvModule); auto tensorModule = dialectsModule.def_submodule("tensor"); - populateEUDSLGen_tensorModule(tensorModule); + eudsl::populateEUDSLGen_tensorModule(tensorModule); auto tosaModule = dialectsModule.def_submodule("tosa"); - populateEUDSLGen_tosaModule(tosaModule); + eudsl::populateEUDSLGen_tosaModule(tosaModule); // auto transformModule = dialectsModule.def_submodule("transform"); - // populateEUDSLGen_transformModule(transformModule); + // eudsl::populateEUDSLGen_transformModule(transformModule); // auto ubModule = dialectsModule.def_submodule("ub"); - // populateEUDSLGen_ubModule(ubModule); + // eudsl::populateEUDSLGen_ubModule(ubModule); // auto vectorModule = dialectsModule.def_submodule("vector"); - // populateEUDSLGen_vectorModule(vectorModule); + // eudsl::populateEUDSLGen_vectorModule(vectorModule); // auto x86vectorModule = dialectsModule.def_submodule("x86vector"); - // populateEUDSLGen_x86vectorModule(x86vectorModule); + // eudsl::populateEUDSLGen_x86vectorModule(x86vectorModule); // auto xegpuModule = dialectsModule.def_submodule("xegpu"); - // populateEUDSLGen_xegpuModule(xegpuModule); + // eudsl::populateEUDSLGen_xegpuModule(xegpuModule); } diff --git a/projects/eudsl-tblgen/CMakeLists.txt b/projects/eudsl-tblgen/CMakeLists.txt index 94e8a5bc..d05fbc24 100644 --- a/projects/eudsl-tblgen/CMakeLists.txt +++ b/projects/eudsl-tblgen/CMakeLists.txt @@ -15,24 +15,26 @@ if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_LIST_DIR) message("Building ${LLVM_SUBPROJECT_TITLE} as a standalone project.") project(${LLVM_SUBPROJECT_TITLE} CXX C) find_package(LLVM REQUIRED CONFIG) + find_package(MLIR REQUIRED CONFIG) message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") + list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") include(TableGen) include(AddLLVM) - # TODO(max): probably don't need this anymore after landing the nanobind fix? - # technically we need this on windows too but our LLVM is compiled without exception handling - # and that breaks windows - if(NOT WIN32) - include(HandleLLVMOptions) - endif() + include(AddMLIR) + include(HandleLLVMOptions) + + include_directories(${CMAKE_CURRENT_LIST_DIR}/../common) endif() +include_directories(${MLIR_INCLUDE_DIRS}) include_directories(${LLVM_INCLUDE_DIRS}) link_directories(${LLVM_BUILD_LIBRARY_DIR}) add_definitions(${LLVM_DEFINITIONS}) @@ -53,19 +55,27 @@ find_package(nanobind CONFIG REQUIRED) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${EUDSL_TBLGEN_SRC_DIR}/eudsl_tblgen) set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) -nanobind_add_module(eudsl_tblgen_ext NB_STATIC STABLE_ABI +nanobind_add_module(eudsl_tblgen_ext NB_STATIC src/eudsl_tblgen_ext.cpp src/TGParser.cpp src/TGLexer.cpp ) -target_link_libraries(eudsl_tblgen_ext - PRIVATE LLVMTableGenCommon LLVMTableGen LLVMCore) -target_compile_options(eudsl_tblgen_ext - PUBLIC +set_property(TARGET eudsl_tblgen_ext PROPERTY POSITION_INDEPENDENT_CODE ON) +target_link_libraries(eudsl_tblgen_ext PRIVATE + LLVMTableGenCommon LLVMTableGen LLVMCore MLIRTableGen) +set(nanobind_options -Wno-cast-qual + -Wno-deprecated-literal-operator + -Wno-covered-switch-default + -Wno-nested-anon-types + -Wno-zero-length-array + -Wno-c++98-compat-extra-semi $<$:-fexceptions -frtti> $<$:-fexceptions -frtti> - $<$:/EHsc /GR>) + $<$:/EHsc /GR> +) +target_compile_options(eudsl_tblgen_ext PRIVATE ${nanobind_options}) +target_compile_options(nanobind-static PRIVATE ${nanobind_options}) # note WORKING_DIRECTORY set(NB_STUBGEN_CMD "${Python_EXECUTABLE}" "-m" "nanobind.stubgen" diff --git a/projects/eudsl-tblgen/pyproject.toml b/projects/eudsl-tblgen/pyproject.toml index 48245c7e..ed4171ed 100644 --- a/projects/eudsl-tblgen/pyproject.toml +++ b/projects/eudsl-tblgen/pyproject.toml @@ -23,6 +23,7 @@ wheel.packages = ["src/eudsl_tblgen"] [tool.scikit-build.cmake.define] LLVM_DIR = { env = "LLVM_DIR", default = "EMPTY" } +MLIR_DIR = { env = "MLIR_DIR", default = "EMPTY" } CMAKE_CXX_VISIBILITY_PRESET = "hidden" CMAKE_C_COMPILER_LAUNCHER = { env = "CMAKE_C_COMPILER_LAUNCHER", default = "" } CMAKE_CXX_COMPILER_LAUNCHER = { env = "CMAKE_CXX_COMPILER_LAUNCHER", default = "" } @@ -35,6 +36,7 @@ archs = ["auto64"] manylinux-x86_64-image = "manylinux_2_28" environment-pass = [ "LLVM_DIR", + "MLIR_DIR", "CMAKE_GENERATOR", "CMAKE_PREFIX_PATH", "CC", diff --git a/projects/eudsl-tblgen/src/TGLexer.cpp b/projects/eudsl-tblgen/src/TGLexer.cpp index ff1c73a4..961831f3 100644 --- a/projects/eudsl-tblgen/src/TGLexer.cpp +++ b/projects/eudsl-tblgen/src/TGLexer.cpp @@ -15,14 +15,11 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/Twine.h" -#include "llvm/Config/llvm-config.h" // for strtoull()/strtoll() define #include "llvm/Support/Compiler.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" #include "llvm/TableGen/Error.h" -#include #include -#include #include #include #include diff --git a/projects/eudsl-tblgen/src/eudsl_tblgen/__init__.py b/projects/eudsl-tblgen/src/eudsl_tblgen/__init__.py index e21275a5..c8aa205c 100644 --- a/projects/eudsl-tblgen/src/eudsl_tblgen/__init__.py +++ b/projects/eudsl-tblgen/src/eudsl_tblgen/__init__.py @@ -2,5 +2,94 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Copyright (c) 2024. +from typing import List, Optional from .eudsl_tblgen_ext import * + + +import re + + +def get_operation_name(def_record): + prefix = def_record.get_value_as_def("opDialect").get_value_as_string("name") + op_name = def_record.get_value_as_string("opName") + + if not prefix: + return op_name + return f"{prefix}.{op_name}" + + +def get_requested_op_definitions(records, op_inc_filter=None, op_exc_filter=None): + class_def = records.get_class("Op") + if not class_def: + raise RuntimeError("ERROR: Couldn't find the 'Op' class!") + + if op_inc_filter: + include_regex = re.compile(op_inc_filter) + if op_exc_filter: + exclude_regex = re.compile(op_exc_filter) + defs = [] + + for def_name in records.get_defs(): + def_record = records.get_defs()[def_name] + if not def_record.is_sub_class_of(class_def): + continue + # Include if no include filter or include filter matches. + if op_inc_filter and not include_regex.match(get_operation_name(def_record)): + continue + # Unless there is an exclude filter and it matches. + if op_exc_filter and exclude_regex.match(get_operation_name(def_record)): + continue + defs.append(def_record) + + return defs + + +def collect_all_defs( + record_keeper: RecordKeeper, + selected_dialect: Optional[str] = None, +) -> List[AttrOrTypeDef]: + records = record_keeper.get_defs() + records = [records[d] for d in records] + # Nothing to do if no defs were found. + if not records: + return [] + + defs = [ + AttrOrTypeDef(rec) + for rec in records + if rec.get_value("builders") and rec.get_value("parameters") + ] + result_defs = [] + + if not selected_dialect: + # If a dialect was not specified, ensure that all found defs belong to the same dialect. + dialects = {definition.get_dialect().get_name() for definition in defs} + if len(dialects) > 1: + raise RuntimeError( + "Defs belong to more than one dialect. Must select one via '--(attr|type)defs-dialect'" + ) + result_defs.extend(defs) + else: + # Otherwise, generate the defs that belong to the selected dialect. + dialect_defs = [ + definition + for definition in defs + if definition.get_dialect().get_name() == selected_dialect + ] + result_defs.extend(dialect_defs) + + return result_defs + + +def get_all_type_constraints(records: RecordKeeper) -> List[Constraint]: + result = [] + for record in records.get_all_derived_definitions_if_defined("TypeConstraint"): + # Ignore constraints defined outside of the top-level file. + constr = Constraint(record) + # Generate C++ function only if "cppFunctionName" is set. + if not constr.get_cpp_function_name(): + continue + result.append(constr) + return result + diff --git a/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp b/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp index 07b4d3b4..9e5354d9 100644 --- a/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp +++ b/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp @@ -4,6 +4,29 @@ // Copyright (c) 2024. #include "TGParser.h" +#include "mlir/TableGen/Argument.h" +#include "mlir/TableGen/AttrOrTypeDef.h" +#include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Builder.h" +#include "mlir/TableGen/Class.h" +#include "mlir/TableGen/CodeGenHelpers.h" +#include "mlir/TableGen/Constraint.h" +#include "mlir/TableGen/Dialect.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/GenNameParser.h" +#include "mlir/TableGen/Interfaces.h" +#include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/Pass.h" +#include "mlir/TableGen/Pattern.h" +#include "mlir/TableGen/Predicate.h" +#include "mlir/TableGen/Property.h" +#include "mlir/TableGen/Region.h" +#include "mlir/TableGen/SideEffects.h" +#include "mlir/TableGen/Successor.h" +#include "mlir/TableGen/Trait.h" +#include "mlir/TableGen/Type.h" +#include "llvm/ADT/Hashing.h" #include "llvm/IR/Intrinsics.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SourceMgr.h" @@ -11,97 +34,125 @@ #include #include -#include +// ReSharper disable once CppUnusedIncludeDirective #include #include +// ReSharper disable once CppUnusedIncludeDirective #include -using namespace llvm; +#include "eudsl/util.h" +// ReSharper disable once CppUnusedIncludeDirective +#include "eudsl/type_casters.h" namespace nb = nanobind; using namespace nb::literals; -template -constexpr auto coerceReturn(Return (*pf)(Args...)) noexcept { - return [&pf](Args &&...args) -> NewReturn { - return pf(std::forward(args)...); - }; -} - -template -constexpr auto coerceReturn(Return (Class::*pmf)(Args...), - std::false_type = {}) noexcept { - return [&pmf](Class *cls, Args &&...args) -> NewReturn { - return (cls->*pmf)(std::forward(args)...); - }; -} - -/* - * If you get - * ``` - * Called object type 'void(MyClass::*)(vector&,int)' is not a function or - * function pointer - * ``` - * it's because you're calling a member function without - * passing the `this` pointer as the first arg - */ -template -constexpr auto coerceReturn(Return (Class::*pmf)(Args...) const, - std::true_type) noexcept { - // copy the *pmf, not capture by ref - return [pmf](const Class &cls, Args &&...args) -> NewReturn { - return (cls.*pmf)(std::forward(args)...); - }; -} - -template <> -struct nb::detail::type_caster { - NB_TYPE_CASTER(StringRef, const_name("str")) - - bool from_python(handle src, uint8_t, cleanup_list *) noexcept { - Py_ssize_t size; - const char *str = PyUnicode_AsUTF8AndSize(src.ptr(), &size); - if (!str) { - PyErr_Clear(); - return false; - } - value = StringRef(str, (size_t)size); - return true; - } - - static handle from_cpp(StringRef value, rv_policy, cleanup_list *) noexcept { - return PyUnicode_FromStringAndSize(value.data(), value.size()); - } -}; - // hack to expose protected Init::InitKind -struct HackInit : public Init { +struct HackInit : public llvm::Init { using InitKind = Init::InitKind; }; NB_MODULE(eudsl_tblgen_ext, m) { - auto recty = nb::class_(m, "RecTy"); - - nb::enum_(m, "RecTyKind") - .value("BitRecTyKind", RecTy::RecTyKind::BitRecTyKind) - .value("BitsRecTyKind", RecTy::RecTyKind::BitsRecTyKind) - .value("IntRecTyKind", RecTy::RecTyKind::IntRecTyKind) - .value("StringRecTyKind", RecTy::RecTyKind::StringRecTyKind) - .value("ListRecTyKind", RecTy::RecTyKind::ListRecTyKind) - .value("DagRecTyKind", RecTy::RecTyKind::DagRecTyKind) - .value("RecordRecTyKind", RecTy::RecTyKind::RecordRecTyKind); - - recty.def_prop_ro("rec_ty_kind", &RecTy::getRecTyKind) - .def_prop_ro("record_keeper", &RecTy::getRecordKeeper) - .def_prop_ro("as_string", &RecTy::getAsString) - .def("__str__", &RecTy::getAsString) - .def("type_is_a", &RecTy::typeIsA, "rhs"_a) - .def("type_is_convertible_to", &RecTy::typeIsConvertibleTo, "rhs"_a); - - nb::class_(m, "RecordRecTy") - .def_prop_ro("classes", coerceReturn>( - &RecordRecTy::getClasses, nb::const_)) - .def("is_sub_class_of", &RecordRecTy::isSubClassOf, "class_"_a); + auto recty = nb::class_(m, "RecTy"); + + nb::enum_(m, "RecTyKind") + .value("BitRecTyKind", llvm::RecTy::RecTyKind::BitRecTyKind) + .value("BitsRecTyKind", llvm::RecTy::RecTyKind::BitsRecTyKind) + .value("IntRecTyKind", llvm::RecTy::RecTyKind::IntRecTyKind) + .value("StringRecTyKind", llvm::RecTy::RecTyKind::StringRecTyKind) + .value("ListRecTyKind", llvm::RecTy::RecTyKind::ListRecTyKind) + .value("DagRecTyKind", llvm::RecTy::RecTyKind::DagRecTyKind) + .value("RecordRecTyKind", llvm::RecTy::RecTyKind::RecordRecTyKind); + + recty.def("get_rec_ty_kind", &llvm::RecTy::getRecTyKind) + .def("get_record_keeper", &llvm::RecTy::getRecordKeeper, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::RecTy::getAsString) + .def("__str__", &llvm::RecTy::getAsString) + .def("print", &llvm::RecTy::print, "os"_a) + .def("dump", &llvm::RecTy::dump) + .def("type_is_convertible_to", &llvm::RecTy::typeIsConvertibleTo, "rhs"_a) + .def("type_is_a", &llvm::RecTy::typeIsA, "rhs"_a) + .def("get_list_ty", &llvm::RecTy::getListTy, + nb::rv_policy::reference_internal); + + nb::class_(m, "BitRecTy") + .def_static("classof", &llvm::BitRecTy::classof, "rt"_a) + .def_static("get", &llvm::BitRecTy::get, "rk"_a, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::BitRecTy::getAsString) + .def("__str__", &llvm::BitRecTy::getAsString) + .def("type_is_convertible_to", &llvm::BitRecTy::typeIsConvertibleTo, + "rhs"_a); + + nb::class_(m, "IntRecTy") + .def_static("classof", &llvm::IntRecTy::classof, "rt"_a) + .def_static("get", &llvm::IntRecTy::get, "rk"_a, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::IntRecTy::getAsString) + .def("__str__", &llvm::IntRecTy::getAsString) + .def("type_is_convertible_to", &llvm::IntRecTy::typeIsConvertibleTo, + "rhs"_a); + + nb::class_(m, "StringRecTy") + .def_static("classof", &llvm::StringRecTy::classof, "rt"_a) + .def_static("get", &llvm::StringRecTy::get, "rk"_a, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::StringRecTy::getAsString) + .def("__init__", &llvm::StringRecTy::getAsString) + .def("type_is_convertible_to", &llvm::StringRecTy::typeIsConvertibleTo, + "rhs"_a); + + nb::class_(m, "ListRecTy") + .def_static("classof", &llvm::ListRecTy::classof, "rt"_a) + .def_static("get", &llvm::ListRecTy::get, "t"_a, + nb::rv_policy::reference_internal) + .def("get_element_type", &llvm::ListRecTy::getElementType, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::ListRecTy::getAsString) + .def("__init__", &llvm::ListRecTy::getAsString) + .def("type_is_convertible_to", &llvm::ListRecTy::typeIsConvertibleTo, + "rhs"_a) + .def("type_is_a", &llvm::ListRecTy::typeIsA, "rhs"_a); + + nb::class_(m, "DagRecTy") + .def_static("classof", &llvm::DagRecTy::classof, "rt"_a) + .def_static("get", &llvm::DagRecTy::get, "rk"_a, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::DagRecTy::getAsString) + .def("__init__", &llvm::DagRecTy::getAsString); + + nb::class_(m, "RecordRecTy") + .def_static("classof", &llvm::RecordRecTy::classof, "rt"_a) + .def_static( + "get", + [](llvm::RecordKeeper &RK, + llvm::ArrayRef Classes) + -> const llvm::RecordRecTy * { + return llvm::RecordRecTy::get(RK, Classes); + }, + "rk"_a, "classes"_a, nb::rv_policy::reference_internal) + .def_static( + "get", + [](const llvm::Record *Class) -> const llvm::RecordRecTy * { + return llvm::RecordRecTy::get(Class); + }, + "class"_a, nb::rv_policy::reference_internal) + .def("profile", &llvm::RecordRecTy::Profile, "id"_a) + .def("get_classes", + eudsl::coerceReturn>( + &llvm::RecordRecTy::getClasses, nb::const_), + nb::rv_policy::reference_internal) + .def("classes_begin", &llvm::RecordRecTy::classes_begin, + nb::rv_policy::reference_internal) + .def("classes_end", &llvm::RecordRecTy::classes_end, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::RecordRecTy::getAsString) + .def("__str__", &llvm::RecordRecTy::getAsString) + .def("is_sub_class_of", &llvm::RecordRecTy::isSubClassOf, "class"_a) + .def("type_is_convertible_to", &llvm::RecordRecTy::typeIsConvertibleTo, + "rhs"_a) + .def("type_is_a", &llvm::RecordRecTy::typeIsA, "rhs"_a); nb::enum_(m, "InitKind") .value("IK_FirstTypedInit", HackInit::InitKind::IK_FirstTypedInit) @@ -130,303 +181,715 @@ NB_MODULE(eudsl_tblgen_ext, m) { .value("IK_UnsetInit", HackInit::InitKind::IK_UnsetInit) .value("IK_ArgumentInit", HackInit::InitKind::IK_ArgumentInit); - nb::class_(m, "Init") - .def_prop_ro("kind", &Init::getKind) - .def_prop_ro("as_string", &Init::getAsUnquotedString) - .def("__str__", &Init::getAsUnquotedString) - .def("is_complete", &Init::isComplete) - .def("is_concrete", &Init::isConcrete) - .def("get_field_type", &Init::getFieldType, "field_name"_a, + nb::class_(m, "Init") + .def("get_kind", &llvm::Init::getKind) + .def("get_record_keeper", &llvm::Init::getRecordKeeper, + nb::rv_policy::reference_internal) + .def("is_complete", &llvm::Init::isComplete) + .def("is_concrete", &llvm::Init::isConcrete) + .def("print", &llvm::Init::print, "os"_a) + .def("get_as_string", &llvm::Init::getAsString) + .def("__str__", &llvm::Init::getAsUnquotedString) + .def("get_as_unquoted_string", &llvm::Init::getAsUnquotedString) + .def("dump", &llvm::Init::dump) + .def("get_cast_to", &llvm::Init::getCastTo, "ty"_a, + nb::rv_policy::reference_internal) + .def("convert_initializer_to", &llvm::Init::convertInitializerTo, "ty"_a, + nb::rv_policy::reference_internal) + .def("convert_initializer_bit_range", + &llvm::Init::convertInitializerBitRange, "bits"_a, nb::rv_policy::reference_internal) - .def("get_bit", &Init::getBit, "bit"_a, + .def("get_field_type", &llvm::Init::getFieldType, "field_name"_a, + nb::rv_policy::reference_internal) + .def("resolve_references", &llvm::Init::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("get_bit", &llvm::Init::getBit, "bit"_a, nb::rv_policy::reference_internal); - nb::class_(m, "TypedInit") - .def_prop_ro("record_keeper", &TypedInit::getRecordKeeper) - .def_prop_ro("type", &TypedInit::getType); + nb::class_(m, "TypedInit") + .def_static("classof", &llvm::TypedInit::classof, "i"_a) + .def("get_type", &llvm::TypedInit::getType, + nb::rv_policy::reference_internal) + .def("get_record_keeper", &llvm::TypedInit::getRecordKeeper, + nb::rv_policy::reference_internal) + .def("get_cast_to", &llvm::TypedInit::getCastTo, "ty"_a, + nb::rv_policy::reference_internal) + .def("convert_initializer_to", &llvm::TypedInit::convertInitializerTo, + "ty"_a, nb::rv_policy::reference_internal) + .def("convert_initializer_bit_range", + &llvm::TypedInit::convertInitializerBitRange, "bits"_a, + nb::rv_policy::reference_internal) + .def("get_field_type", &llvm::TypedInit::getFieldType, "field_name"_a, + nb::rv_policy::reference_internal); + + nb::class_(m, "UnsetInit") + .def_static("classof", &llvm::UnsetInit::classof, "i"_a) + .def_static("get", &llvm::UnsetInit::get, "rk"_a, + nb::rv_policy::reference_internal) + .def("get_record_keeper", &llvm::UnsetInit::getRecordKeeper, + nb::rv_policy::reference_internal) + .def("get_cast_to", &llvm::UnsetInit::getCastTo, "ty"_a, + nb::rv_policy::reference_internal) + .def("convert_initializer_to", &llvm::UnsetInit::convertInitializerTo, + "ty"_a, nb::rv_policy::reference_internal) + .def("get_bit", &llvm::UnsetInit::getBit, "bit"_a, + nb::rv_policy::reference_internal) + .def("is_complete", &llvm::UnsetInit::isComplete) + .def("is_concrete", &llvm::UnsetInit::isConcrete) + .def("get_as_string", &llvm::UnsetInit::getAsString) + .def("__str__", &llvm::UnsetInit::getAsString); + + auto llvm_ArgumentInit = nb::class_(m, "ArgumentInit"); + + nb::enum_(llvm_ArgumentInit, "Kind") + .value("Positional", llvm::ArgumentInit::Positional) + .value("Named", llvm::ArgumentInit::Named); + + llvm_ArgumentInit.def_static("classof", &llvm::ArgumentInit::classof, "i"_a) + .def("get_record_keeper", &llvm::ArgumentInit::getRecordKeeper, + nb::rv_policy::reference_internal) + .def_static("get", &llvm::ArgumentInit::get, "value"_a, "aux"_a, + nb::rv_policy::reference_internal) + .def("is_positional", &llvm::ArgumentInit::isPositional) + .def("is_named", &llvm::ArgumentInit::isNamed) + .def("get_value", &llvm::ArgumentInit::getValue, + nb::rv_policy::reference_internal) + .def("get_index", &llvm::ArgumentInit::getIndex) + .def("get_name", &llvm::ArgumentInit::getName, + nb::rv_policy::reference_internal) + .def("clone_with_value", &llvm::ArgumentInit::cloneWithValue, "value"_a, + nb::rv_policy::reference_internal) + .def("profile", &llvm::ArgumentInit::Profile, "id"_a) + .def("resolve_references", &llvm::ArgumentInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::ArgumentInit::getAsString) + .def("__str__", &llvm::ArgumentInit::getAsString) + .def("is_complete", &llvm::ArgumentInit::isComplete) + .def("is_concrete", &llvm::ArgumentInit::isConcrete) + .def("get_bit", &llvm::ArgumentInit::getBit, "bit"_a, + nb::rv_policy::reference_internal) + .def("get_cast_to", &llvm::ArgumentInit::getCastTo, "ty"_a, + nb::rv_policy::reference_internal) + .def("convert_initializer_to", &llvm::ArgumentInit::convertInitializerTo, + "ty"_a, nb::rv_policy::reference_internal); - nb::class_(m, "UnsetInit"); + nb::class_(m, "BitInit") + .def_static("classof", &llvm::BitInit::classof, "i"_a) + .def_static("get", &llvm::BitInit::get, "rk"_a, "v"_a, + nb::rv_policy::reference_internal) + .def("get_value", &llvm::BitInit::getValue) + .def("__bool__", &llvm::BitInit::getValue) + .def("convert_initializer_to", &llvm::BitInit::convertInitializerTo, + "ty"_a, nb::rv_policy::reference_internal) + .def("get_bit", &llvm::BitInit::getBit, "bit"_a, + nb::rv_policy::reference_internal) + .def("is_concrete", &llvm::BitInit::isConcrete) + .def("get_as_string", &llvm::BitInit::getAsString) + .def("__str__", &llvm::BitInit::getAsString); - nb::class_(m, "ArgumentInit") - .def("is_positional", &ArgumentInit::isPositional) - .def("is_named", &ArgumentInit::isNamed) - .def_prop_ro("value", &ArgumentInit::getValue) - .def_prop_ro("index", &ArgumentInit::getIndex) - .def_prop_ro("name", &ArgumentInit::getName); + nb::class_(m, "BitsInit") + .def_static("classof", &llvm::BitsInit::classof, "i"_a) + .def_static("get", &llvm::BitsInit::get, "rk"_a, "range"_a, + nb::rv_policy::reference_internal) + .def("profile", &llvm::BitsInit::Profile, "id"_a) + .def("get_num_bits", &llvm::BitsInit::getNumBits) + .def("convert_initializer_to", &llvm::BitsInit::convertInitializerTo, + "ty"_a, nb::rv_policy::reference_internal) + .def("convert_initializer_bit_range", + &llvm::BitsInit::convertInitializerBitRange, "bits"_a, + nb::rv_policy::reference_internal) + .def("convert_initializer_to_int", + &llvm::BitsInit::convertInitializerToInt) + .def("is_complete", &llvm::BitsInit::isComplete) + .def("all_in_complete", &llvm::BitsInit::allInComplete) + .def("is_concrete", &llvm::BitsInit::isConcrete) + .def("get_as_string", &llvm::BitsInit::getAsString) + .def("__str__", &llvm::BitsInit::getAsString) + .def("resolve_references", &llvm::BitsInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("get_bit", &llvm::BitsInit::getBit, "bit"_a, + nb::rv_policy::reference_internal); - nb::class_(m, "BitInit") - .def_prop_ro("value", &BitInit::getValue) - .def("__bool__", &BitInit::getValue); + nb::class_(m, "IntInit") + .def_static("classof", &llvm::IntInit::classof, "i"_a) + .def_static("get", &llvm::IntInit::get, "rk"_a, "v"_a, + nb::rv_policy::reference_internal) + .def("get_value", &llvm::IntInit::getValue) + .def("__int__", &llvm::IntInit::getValue) + .def("convert_initializer_to", &llvm::IntInit::convertInitializerTo, + "ty"_a, nb::rv_policy::reference_internal) + .def("convert_initializer_bit_range", + &llvm::IntInit::convertInitializerBitRange, "bits"_a, + nb::rv_policy::reference_internal) + .def("is_concrete", &llvm::IntInit::isConcrete) + .def("get_as_string", &llvm::IntInit::getAsString) + .def("__str__", &llvm::IntInit::getAsString) + .def("get_bit", &llvm::IntInit::getBit, "bit"_a, + nb::rv_policy::reference_internal); - nb::class_(m, "BitsInit") - .def_prop_ro("num_bits", &BitsInit::getNumBits) - .def("all_incomplete", &BitsInit::allInComplete); + nb::class_(m, "AnonymousNameInit") + .def_static("classof", &llvm::AnonymousNameInit::classof, "i"_a) + .def_static("get", &llvm::AnonymousNameInit::get, "rk"_a, "__"_a, + nb::rv_policy::reference_internal) + .def("get_value", &llvm::AnonymousNameInit::getValue) + .def("get_name_init", &llvm::AnonymousNameInit::getNameInit, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::AnonymousNameInit::getAsString) + .def("__str__", &llvm::AnonymousNameInit::getAsString) + .def("resolve_references", &llvm::AnonymousNameInit::resolveReferences, + "r"_a, nb::rv_policy::reference_internal) + .def("get_bit", &llvm::AnonymousNameInit::getBit, "bit"_a, + nb::rv_policy::reference_internal); - nb::class_(m, "IntInit") - .def_prop_ro("value", &IntInit::getValue); + auto llvm_StringInit = + nb::class_(m, "StringInit"); + nb::enum_(llvm_StringInit, "StringFormat") + .value("SF_String", llvm::StringInit::SF_String) + .value("SF_Code", llvm::StringInit::SF_Code); - nb::class_(m, "AnonymousNameInit") - .def_prop_ro("value", &AnonymousNameInit::getValue) - .def_prop_ro("name_init", &AnonymousNameInit::getNameInit); + llvm_StringInit.def_static("classof", &llvm::StringInit::classof, "i"_a) + .def_static("get", &llvm::StringInit::get, "rk"_a, "__"_a, "fmt"_a, + nb::rv_policy::reference_internal) + .def_static("determine_format", &llvm::StringInit::determineFormat, + "fmt1"_a, "fmt2"_a) + .def("get_value", &llvm::StringInit::getValue) + .def("get_format", &llvm::StringInit::getFormat) + .def("has_code_format", &llvm::StringInit::hasCodeFormat) + .def("convert_initializer_to", &llvm::StringInit::convertInitializerTo, + "ty"_a, nb::rv_policy::reference_internal) + .def("is_concrete", &llvm::StringInit::isConcrete) + .def("get_as_string", &llvm::StringInit::getAsString) + .def("__str__", &llvm::StringInit::getAsUnquotedString) + .def("get_as_unquoted_string", &llvm::StringInit::getAsUnquotedString) + .def("get_bit", &llvm::StringInit::getBit, "bit"_a, + nb::rv_policy::reference_internal); - auto stringInit = nb::class_(m, "StringInit"); + auto llvm_ListInit = + nb::class_(m, "ListInit") + .def_static("classof", &llvm::ListInit::classof, "i"_a) + .def_static("get", &llvm::ListInit::get, "range"_a, "elt_ty"_a, + nb::rv_policy::reference_internal) + .def("profile", &llvm::ListInit::Profile, "id"_a) + .def("get_element", &llvm::ListInit::getElement, "i"_a, + nb::rv_policy::reference_internal) + .def("get_element_type", &llvm::ListInit::getElementType, + nb::rv_policy::reference_internal) + .def("get_element_as_record", &llvm::ListInit::getElementAsRecord, + "i"_a, nb::rv_policy::reference_internal) + .def("convert_initializer_to", &llvm::ListInit::convertInitializerTo, + "ty"_a, nb::rv_policy::reference_internal) + .def("resolve_references", &llvm::ListInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("is_complete", &llvm::ListInit::isComplete) + .def("is_concrete", &llvm::ListInit::isConcrete) + .def("get_as_string", &llvm::ListInit::getAsString) + .def("begin", &llvm::ListInit::begin, + nb::rv_policy::reference_internal) + .def("end", &llvm::ListInit::end, nb::rv_policy::reference_internal) + .def("size", &llvm::ListInit::size) + .def("empty", &llvm::ListInit::empty) + .def("get_bit", &llvm::ListInit::getBit, "bit"_a, + nb::rv_policy::reference_internal) + .def("__len__", [](const llvm::ListInit &v) { return v.size(); }) + .def("__bool__", [](const llvm::ListInit &v) { return !v.empty(); }) + .def( + "__iter__", + [](llvm::ListInit &v) { + return nb::make_iterator( + nb::type(), "Iterator", v.begin(), v.end()); + }, + nb::rv_policy::reference_internal) + .def( + "__getitem__", + [](llvm::ListInit &v, Py_ssize_t i) { + return v.getElement(eudsl::wrap(i, v.size())); + }, + nb::rv_policy::reference_internal) + .def("get_values", + eudsl::coerceReturn>( + &llvm::ListInit::getValues, nb::const_)); - nb::enum_(m, "StringFormat") - .value("SF_String", StringInit::StringFormat::SF_String) - .value("SF_Code", StringInit::StringFormat::SF_Code); + auto llvm_OpInit = nb::class_(m, "OpInit") + .def_static("classof", &llvm::OpInit::classof, "i"_a) + .def("get_bit", &llvm::OpInit::getBit, "bit"_a, + nb::rv_policy::reference_internal); - stringInit.def_prop_ro("value", &StringInit::getValue) - .def_prop_ro("format", &StringInit::getFormat) - .def("has_code_format", &StringInit::hasCodeFormat); + auto unaryOpInit = nb::class_(m, "UnOpInit"); + nb::enum_(m, "UnaryOp") + .value("TOLOWER", llvm::UnOpInit::UnaryOp::TOLOWER) + .value("TOUPPER", llvm::UnOpInit::UnaryOp::TOUPPER) + .value("CAST", llvm::UnOpInit::UnaryOp::CAST) + .value("NOT", llvm::UnOpInit::UnaryOp::NOT) + .value("HEAD", llvm::UnOpInit::UnaryOp::HEAD) + .value("TAIL", llvm::UnOpInit::UnaryOp::TAIL) + .value("SIZE", llvm::UnOpInit::UnaryOp::SIZE) + .value("EMPTY", llvm::UnOpInit::UnaryOp::EMPTY) + .value("GETDAGOP", llvm::UnOpInit::UnaryOp::GETDAGOP) + .value("LOG2", llvm::UnOpInit::UnaryOp::LOG2) + .value("REPR", llvm::UnOpInit::UnaryOp::REPR) + .value("LISTFLATTEN", llvm::UnOpInit::UnaryOp::LISTFLATTEN); - nb::class_(m, "ListInit") - .def("__len__", [](const ListInit &v) { return v.size(); }) - .def("__bool__", [](const ListInit &v) { return !v.empty(); }) + unaryOpInit.def_static("classof", &llvm::UnOpInit::classof, "i"_a) + .def_static("get", &llvm::UnOpInit::get, "opc"_a, "lhs"_a, "type"_a, + nb::rv_policy::reference_internal) + .def("profile", &llvm::UnOpInit::Profile, "id"_a) .def( - "__iter__", - [](ListInit &v) { - return nb::make_iterator( - nb::type(), "Iterator", v.begin(), v.end()); + "get_operand", + [](llvm::UnOpInit &self) -> const llvm::Init * { + return self.getOperand(); }, - nb::keep_alive<0, 1>()) + nb::rv_policy::reference_internal) + .def("get_opcode", &llvm::UnOpInit::getOpcode) .def( - "__getitem__", - [](ListInit &v, Py_ssize_t i) { - return v.getElement(nb::detail::wrap(i, v.size())); + "get_operand", + [](llvm::UnOpInit &self) -> const llvm::Init * { + return self.getOperand(); }, nb::rv_policy::reference_internal) - .def_prop_ro("element_type", &ListInit::getElementType) - .def("get_element_as_record", &ListInit::getElementAsRecord, "i"_a, - nb::rv_policy::reference_internal) - .def_prop_ro("values", coerceReturn>( - &ListInit::getValues, nb::const_)); - - nb::class_(m, "OpInit") - .def("bit", &OpInit::getBit, "bit"_a, nb::rv_policy::reference_internal); - - auto unaryOpInit = nb::class_(m, "UnOpInit"); - nb::enum_(m, "UnaryOp") - .value("TOLOWER", UnOpInit::UnaryOp::TOLOWER) - .value("TOUPPER", UnOpInit::UnaryOp::TOUPPER) - .value("CAST", UnOpInit::UnaryOp::CAST) - .value("NOT", UnOpInit::UnaryOp::NOT) - .value("HEAD", UnOpInit::UnaryOp::HEAD) - .value("TAIL", UnOpInit::UnaryOp::TAIL) - .value("SIZE", UnOpInit::UnaryOp::SIZE) - .value("EMPTY", UnOpInit::UnaryOp::EMPTY) - .value("GETDAGOP", UnOpInit::UnaryOp::GETDAGOP) - .value("LOG2", UnOpInit::UnaryOp::LOG2) - .value("REPR", UnOpInit::UnaryOp::REPR) - .value("LISTFLATTEN", UnOpInit::UnaryOp::LISTFLATTEN); - unaryOpInit.def_prop_ro("opcode", &UnOpInit::getOpcode); - - auto binaryOpInit = nb::class_(m, "BinOpInit"); - nb::enum_(m, "BinaryOp") - .value("ADD", BinOpInit::BinaryOp::ADD) - .value("SUB", BinOpInit::BinaryOp::SUB) - .value("MUL", BinOpInit::BinaryOp::MUL) - .value("DIV", BinOpInit::BinaryOp::DIV) - .value("AND", BinOpInit::BinaryOp::AND) - .value("OR", BinOpInit::BinaryOp::OR) - .value("XOR", BinOpInit::BinaryOp::XOR) - .value("SHL", BinOpInit::BinaryOp::SHL) - .value("SRA", BinOpInit::BinaryOp::SRA) - .value("SRL", BinOpInit::BinaryOp::SRL) - .value("LISTCONCAT", BinOpInit::BinaryOp::LISTCONCAT) - .value("LISTSPLAT", BinOpInit::BinaryOp::LISTSPLAT) - .value("LISTREMOVE", BinOpInit::BinaryOp::LISTREMOVE) - .value("LISTELEM", BinOpInit::BinaryOp::LISTELEM) - .value("LISTSLICE", BinOpInit::BinaryOp::LISTSLICE) - .value("RANGEC", BinOpInit::BinaryOp::RANGEC) - .value("STRCONCAT", BinOpInit::BinaryOp::STRCONCAT) - .value("INTERLEAVE", BinOpInit::BinaryOp::INTERLEAVE) - .value("CONCAT", BinOpInit::BinaryOp::CONCAT) - .value("EQ", BinOpInit::BinaryOp::EQ) - .value("NE", BinOpInit::BinaryOp::NE) - .value("LE", BinOpInit::BinaryOp::LE) - .value("LT", BinOpInit::BinaryOp::LT) - .value("GE", BinOpInit::BinaryOp::GE) - .value("GT", BinOpInit::BinaryOp::GT) - .value("GETDAGARG", BinOpInit::BinaryOp::GETDAGARG) - .value("GETDAGNAME", BinOpInit::BinaryOp::GETDAGNAME) - .value("SETDAGOP", BinOpInit::BinaryOp::SETDAGOP); - binaryOpInit.def_prop_ro("opcode", &BinOpInit::getOpcode) - .def_prop_ro("lhs", &BinOpInit::getLHS) - .def_prop_ro("rhs", &BinOpInit::getRHS); - - auto ternaryOpInit = nb::class_(m, "TernOpInit"); - nb::enum_(m, "TernaryOp") - .value("SUBST", TernOpInit::TernaryOp::SUBST) - .value("FOREACH", TernOpInit::TernaryOp::FOREACH) - .value("FILTER", TernOpInit::TernaryOp::FILTER) - .value("IF", TernOpInit::TernaryOp::IF) - .value("DAG", TernOpInit::TernaryOp::DAG) - .value("RANGE", TernOpInit::TernaryOp::RANGE) - .value("SUBSTR", TernOpInit::TernaryOp::SUBSTR) - .value("FIND", TernOpInit::TernaryOp::FIND) - .value("SETDAGARG", TernOpInit::TernaryOp::SETDAGARG) - .value("SETDAGNAME", TernOpInit::TernaryOp::SETDAGNAME); - ternaryOpInit.def_prop_ro("opcode", &TernOpInit::getOpcode) - .def_prop_ro("lhs", &TernOpInit::getLHS) - .def_prop_ro("mhs", &TernOpInit::getMHS) - .def_prop_ro("rhs", &TernOpInit::getRHS); - - nb::class_(m, "CondOpInit"); - nb::class_(m, "FoldOpInit"); - nb::class_(m, "IsAOpInit"); - nb::class_(m, "ExistsOpInit"); - - nb::class_(m, "VarInit") - .def_prop_ro("name", &VarInit::getName) - .def_prop_ro("name_init", &VarInit::getNameInit) - .def_prop_ro("name_init_as_string", &VarInit::getNameInitAsString); - - nb::class_(m, "VarBitInit") - .def_prop_ro("bit_var", &VarBitInit::getBitVar) - .def_prop_ro("bit_num", &VarBitInit::getBitNum); - - nb::class_(m, "DefInit") - .def_prop_ro("def_", &DefInit::getDef); - - nb::class_(m, "VarDefInit") - .def("get_arg", &VarDefInit::getArg, "i"_a, - nb::rv_policy::reference_internal) - .def_prop_ro("args", coerceReturn>( - &VarDefInit::args, nb::const_)) - .def("__len__", [](const VarDefInit &v) { return v.args_size(); }) - .def("__bool__", [](const VarDefInit &v) { return !v.args_empty(); }) + .def("fold", &llvm::UnOpInit::Fold, "cur_rec"_a, "is_final"_a, + nb::rv_policy::reference_internal) + .def("resolve_references", &llvm::UnOpInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::UnOpInit::getAsString) + .def("__str__", &llvm::UnOpInit::getAsUnquotedString); + + auto binaryOpInit = nb::class_(m, "BinOpInit"); + nb::enum_(m, "BinaryOp") + .value("ADD", llvm::BinOpInit::BinaryOp::ADD) + .value("SUB", llvm::BinOpInit::BinaryOp::SUB) + .value("MUL", llvm::BinOpInit::BinaryOp::MUL) + .value("DIV", llvm::BinOpInit::BinaryOp::DIV) + .value("AND", llvm::BinOpInit::BinaryOp::AND) + .value("OR", llvm::BinOpInit::BinaryOp::OR) + .value("XOR", llvm::BinOpInit::BinaryOp::XOR) + .value("SHL", llvm::BinOpInit::BinaryOp::SHL) + .value("SRA", llvm::BinOpInit::BinaryOp::SRA) + .value("SRL", llvm::BinOpInit::BinaryOp::SRL) + .value("LISTCONCAT", llvm::BinOpInit::BinaryOp::LISTCONCAT) + .value("LISTSPLAT", llvm::BinOpInit::BinaryOp::LISTSPLAT) + .value("LISTREMOVE", llvm::BinOpInit::BinaryOp::LISTREMOVE) + .value("LISTELEM", llvm::BinOpInit::BinaryOp::LISTELEM) + .value("LISTSLICE", llvm::BinOpInit::BinaryOp::LISTSLICE) + .value("RANGEC", llvm::BinOpInit::BinaryOp::RANGEC) + .value("STRCONCAT", llvm::BinOpInit::BinaryOp::STRCONCAT) + .value("INTERLEAVE", llvm::BinOpInit::BinaryOp::INTERLEAVE) + .value("CONCAT", llvm::BinOpInit::BinaryOp::CONCAT) + .value("EQ", llvm::BinOpInit::BinaryOp::EQ) + .value("NE", llvm::BinOpInit::BinaryOp::NE) + .value("LE", llvm::BinOpInit::BinaryOp::LE) + .value("LT", llvm::BinOpInit::BinaryOp::LT) + .value("GE", llvm::BinOpInit::BinaryOp::GE) + .value("GT", llvm::BinOpInit::BinaryOp::GT) + .value("GETDAGARG", llvm::BinOpInit::BinaryOp::GETDAGARG) + .value("GETDAGNAME", llvm::BinOpInit::BinaryOp::GETDAGNAME) + .value("SETDAGOP", llvm::BinOpInit::BinaryOp::SETDAGOP); + + binaryOpInit.def_static("classof", &llvm::BinOpInit::classof, "i"_a) + .def_static("get", &llvm::BinOpInit::get, "opc"_a, "lhs"_a, "rhs"_a, + "type"_a, nb::rv_policy::reference_internal) + .def_static("get_str_concat", &llvm::BinOpInit::getStrConcat, "lhs"_a, + "rhs"_a, nb::rv_policy::reference_internal) + .def_static("get_list_concat", &llvm::BinOpInit::getListConcat, "lhs"_a, + "rhs"_a, nb::rv_policy::reference_internal) + .def("profile", &llvm::BinOpInit::Profile, "id"_a) + .def("get_opcode", &llvm::BinOpInit::getOpcode) + .def("get_lhs", &llvm::BinOpInit::getLHS, + nb::rv_policy::reference_internal) + .def("get_rhs", &llvm::BinOpInit::getRHS, + nb::rv_policy::reference_internal) + .def("compare_init", &llvm::BinOpInit::CompareInit, "opc"_a, "lhs"_a, + "rhs"_a) + .def("fold", &llvm::BinOpInit::Fold, "cur_rec"_a, + nb::rv_policy::reference_internal) + .def("resolve_references", &llvm::BinOpInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::BinOpInit::getAsString) + .def("__str__", &llvm::BinOpInit::getAsUnquotedString); + + auto ternaryOpInit = + nb::class_(m, "TernOpInit"); + nb::enum_(m, "TernaryOp") + .value("SUBST", llvm::TernOpInit::TernaryOp::SUBST) + .value("FOREACH", llvm::TernOpInit::TernaryOp::FOREACH) + .value("FILTER", llvm::TernOpInit::TernaryOp::FILTER) + .value("IF", llvm::TernOpInit::TernaryOp::IF) + .value("DAG", llvm::TernOpInit::TernaryOp::DAG) + .value("RANGE", llvm::TernOpInit::TernaryOp::RANGE) + .value("SUBSTR", llvm::TernOpInit::TernaryOp::SUBSTR) + .value("FIND", llvm::TernOpInit::TernaryOp::FIND) + .value("SETDAGARG", llvm::TernOpInit::TernaryOp::SETDAGARG) + .value("SETDAGNAME", llvm::TernOpInit::TernaryOp::SETDAGNAME); + + ternaryOpInit.def_static("classof", &llvm::TernOpInit::classof, "i"_a) + .def_static("get", &llvm::TernOpInit::get, "opc"_a, "lhs"_a, "mhs"_a, + "rhs"_a, "type"_a, nb::rv_policy::reference_internal) + .def("profile", &llvm::TernOpInit::Profile, "id"_a) + .def("get_opcode", &llvm::TernOpInit::getOpcode) + .def("get_lhs", &llvm::TernOpInit::getLHS, + nb::rv_policy::reference_internal) + .def("get_mhs", &llvm::TernOpInit::getMHS, + nb::rv_policy::reference_internal) + .def("get_rhs", &llvm::TernOpInit::getRHS, + nb::rv_policy::reference_internal) + .def("fold", &llvm::TernOpInit::Fold, "cur_rec"_a, + nb::rv_policy::reference_internal) + .def("is_complete", &llvm::TernOpInit::isComplete) + .def("resolve_references", &llvm::TernOpInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::TernOpInit::getAsString) + .def("__str__", &llvm::TernOpInit::getAsUnquotedString); + + nb::class_(m, "CondOpInit") + .def_static("classof", &llvm::CondOpInit::classof, "i"_a) + .def_static("get", &llvm::CondOpInit::get, "c"_a, "v"_a, "type"_a, + nb::rv_policy::reference_internal) + .def("profile", &llvm::CondOpInit::Profile, "id"_a) + .def("get_val_type", &llvm::CondOpInit::getValType, + nb::rv_policy::reference_internal) + .def("get_num_conds", &llvm::CondOpInit::getNumConds) + .def("get_cond", &llvm::CondOpInit::getCond, "num"_a, + nb::rv_policy::reference_internal) + .def("get_val", &llvm::CondOpInit::getVal, "num"_a, + nb::rv_policy::reference_internal) + .def("get_conds", &llvm::CondOpInit::getConds) + .def("get_vals", &llvm::CondOpInit::getVals) + .def("fold", &llvm::CondOpInit::Fold, "cur_rec"_a, + nb::rv_policy::reference_internal) + .def("resolve_references", &llvm::CondOpInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("is_concrete", &llvm::CondOpInit::isConcrete) + .def("is_complete", &llvm::CondOpInit::isComplete) + .def("get_as_string", &llvm::CondOpInit::getAsString) + .def("__str__", &llvm::CondOpInit::getAsUnquotedString) + .def("arg_begin", &llvm::CondOpInit::arg_begin, + nb::rv_policy::reference_internal) + .def("arg_end", &llvm::CondOpInit::arg_end, + nb::rv_policy::reference_internal) + .def("case_size", &llvm::CondOpInit::case_size) + .def("case_empty", &llvm::CondOpInit::case_empty) + .def("name_begin", &llvm::CondOpInit::name_begin, + nb::rv_policy::reference_internal) + .def("name_end", &llvm::CondOpInit::name_end, + nb::rv_policy::reference_internal) + .def("val_size", &llvm::CondOpInit::val_size) + .def("val_empty", &llvm::CondOpInit::val_empty) + .def("get_bit", &llvm::CondOpInit::getBit, "bit"_a, + nb::rv_policy::reference_internal); + + nb::class_(m, "FoldOpInit") + .def_static("classof", &llvm::FoldOpInit::classof, "i"_a) + .def_static("get", &llvm::FoldOpInit::get, "start"_a, "list"_a, "a"_a, + "b"_a, "expr"_a, "type"_a, nb::rv_policy::reference_internal) + .def("profile", &llvm::FoldOpInit::Profile, "id"_a) + .def("fold", &llvm::FoldOpInit::Fold, "cur_rec"_a, + nb::rv_policy::reference_internal) + .def("is_complete", &llvm::FoldOpInit::isComplete) + .def("resolve_references", &llvm::FoldOpInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("get_bit", &llvm::FoldOpInit::getBit, "bit"_a, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::FoldOpInit::getAsString) + .def("__str__", &llvm::FoldOpInit::getAsString); + + nb::class_(m, "IsAOpInit") + .def_static("classof", &llvm::IsAOpInit::classof, "i"_a) + .def_static("get", &llvm::IsAOpInit::get, "check_type"_a, "expr"_a, + nb::rv_policy::reference_internal) + .def("profile", &llvm::IsAOpInit::Profile, "id"_a) + .def("fold", &llvm::IsAOpInit::Fold, nb::rv_policy::reference_internal) + .def("is_complete", &llvm::IsAOpInit::isComplete) + .def("resolve_references", &llvm::IsAOpInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("get_bit", &llvm::IsAOpInit::getBit, "bit"_a, + nb::rv_policy::reference_internal) + .def("__str__", &llvm::IsAOpInit::getAsString) + .def("get_as_string", &llvm::IsAOpInit::getAsString); + + nb::class_(m, "ExistsOpInit") + .def_static("classof", &llvm::ExistsOpInit::classof, "i"_a) + .def_static("get", &llvm::ExistsOpInit::get, "check_type"_a, "expr"_a, + nb::rv_policy::reference_internal) + .def("profile", &llvm::ExistsOpInit::Profile, "id"_a) + .def("fold", &llvm::ExistsOpInit::Fold, "cur_rec"_a, "is_final"_a, + nb::rv_policy::reference_internal) + .def("is_complete", &llvm::ExistsOpInit::isComplete) + .def("resolve_references", &llvm::ExistsOpInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("get_bit", &llvm::ExistsOpInit::getBit, "bit"_a, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::ExistsOpInit::getAsString) + .def("__str__", &llvm::ExistsOpInit::getAsUnquotedString); + + nb::class_(m, "VarInit") + .def_static("classof", &llvm::VarInit::classof, "i"_a) + .def_static( + "get", + [](llvm::StringRef VN, const llvm::RecTy *T) + -> const llvm::VarInit * { return llvm::VarInit::get(VN, T); }, + "vn"_a, "t"_a, nb::rv_policy::reference_internal) + .def_static( + "get", + [](const llvm::Init *VN, const llvm::RecTy *T) + -> const llvm::VarInit * { return llvm::VarInit::get(VN, T); }, + "vn"_a, "t"_a, nb::rv_policy::reference_internal) + .def("get_name", &llvm::VarInit::getName) + .def("get_name_init", &llvm::VarInit::getNameInit, + nb::rv_policy::reference_internal) + .def("get_name_init_as_string", &llvm::VarInit::getNameInitAsString) + .def("resolve_references", &llvm::VarInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("get_bit", &llvm::VarInit::getBit, "bit"_a, + nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::VarInit::getAsString) + .def("__str__", &llvm::VarInit::getAsUnquotedString); + + nb::class_(m, "VarBitInit") + .def_static("classof", &llvm::VarBitInit::classof, "i"_a) + .def_static("get", &llvm::VarBitInit::get, "t"_a, "b"_a, + nb::rv_policy::reference_internal) + .def("get_bit_var", &llvm::VarBitInit::getBitVar, + nb::rv_policy::reference_internal) + .def("get_bit_num", &llvm::VarBitInit::getBitNum) + .def("get_as_string", &llvm::VarBitInit::getAsString) + .def("__str__", &llvm::VarBitInit::getAsUnquotedString) + .def("resolve_references", &llvm::VarBitInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("get_bit", &llvm::VarBitInit::getBit, "b"_a, + nb::rv_policy::reference_internal); + + nb::class_(m, "DefInit") + .def_static("classof", &llvm::DefInit::classof, "i"_a) + .def("convert_initializer_to", &llvm::DefInit::convertInitializerTo, + "ty"_a, nb::rv_policy::reference_internal) + .def("get_def", &llvm::DefInit::getDef, nb::rv_policy::reference_internal) + .def("get_field_type", &llvm::DefInit::getFieldType, "field_name"_a, + nb::rv_policy::reference_internal) + .def("is_concrete", &llvm::DefInit::isConcrete) + .def("get_as_string", &llvm::DefInit::getAsString) + .def("__str__", &llvm::DefInit::getAsUnquotedString) + .def("get_bit", &llvm::DefInit::getBit, "bit"_a, + nb::rv_policy::reference_internal); + + nb::class_(m, "VarDefInit") + .def_static("classof", &llvm::VarDefInit::classof, "i"_a) + .def_static("get", &llvm::VarDefInit::get, "loc"_a, "class"_a, "args"_a, + nb::rv_policy::reference_internal) + .def("profile", &llvm::VarDefInit::Profile, "id"_a) + .def("resolve_references", &llvm::VarDefInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("fold", &llvm::VarDefInit::Fold, nb::rv_policy::reference_internal) + .def("get_as_string", &llvm::VarDefInit::getAsString) + .def("__str__", &llvm::VarDefInit::getAsUnquotedString) + .def("get_arg", &llvm::VarDefInit::getArg, "i"_a, + nb::rv_policy::reference_internal) + .def("args_begin", &llvm::VarDefInit::args_begin, + nb::rv_policy::reference_internal) + .def("args_end", &llvm::VarDefInit::args_end, + nb::rv_policy::reference_internal) + .def("args_size", &llvm::VarDefInit::args_size) + .def("args_empty", &llvm::VarDefInit::args_empty) + .def("get_bit", &llvm::VarDefInit::getBit, "bit"_a, + nb::rv_policy::reference_internal) + .def("args", + eudsl::coerceReturn>( + &llvm::VarDefInit::args, nb::const_), + nb::rv_policy::reference_internal) + .def("__len__", [](const llvm::VarDefInit &v) { return v.args_size(); }) + .def("__bool__", + [](const llvm::VarDefInit &v) { return !v.args_empty(); }) .def( "__iter__", - [](VarDefInit &v) { + [](llvm::VarDefInit &v) { return nb::make_iterator( - nb::type(), "Iterator", v.args_begin(), + nb::type(), "Iterator", v.args_begin(), v.args_end()); }, - nb::keep_alive<0, 1>()) + nb::rv_policy::reference_internal) .def( "__getitem__", - [](VarDefInit &v, Py_ssize_t i) { - return v.getArg(nb::detail::wrap(i, v.args_size())); + [](llvm::VarDefInit &v, Py_ssize_t i) { + return v.getArg(eudsl::wrap(i, v.args_size())); }, nb::rv_policy::reference_internal); - nb::class_(m, "FieldInit") - .def_prop_ro("record", &FieldInit::getRecord) - .def_prop_ro("field_name", &FieldInit::getFieldName); + nb::class_(m, "FieldInit") + .def_static("classof", &llvm::FieldInit::classof, "i"_a) + .def_static("get", &llvm::FieldInit::get, "r"_a, "fn"_a, + nb::rv_policy::reference_internal) + .def("get_record", &llvm::FieldInit::getRecord, + nb::rv_policy::reference_internal) + .def("get_field_name", &llvm::FieldInit::getFieldName, + nb::rv_policy::reference_internal) + .def("get_bit", &llvm::FieldInit::getBit, "bit"_a, + nb::rv_policy::reference_internal) + .def("resolve_references", &llvm::FieldInit::resolveReferences, "r"_a, + nb::rv_policy::reference_internal) + .def("fold", &llvm::FieldInit::Fold, "cur_rec"_a, + nb::rv_policy::reference_internal) + .def("is_concrete", &llvm::FieldInit::isConcrete) + .def("get_as_string", &llvm::FieldInit::getAsString) + .def("__str__", &llvm::FieldInit::getAsUnquotedString); - nb::class_(m, "DagInit") - .def_prop_ro("operator", &DagInit::getOperator) - .def_prop_ro("name_init", &DagInit::getName) - .def_prop_ro("name_str", &DagInit::getNameStr) - .def_prop_ro("num_args", &DagInit::getNumArgs) - .def("get_arg", &DagInit::getArg, "num"_a, + nb::class_(m, "DagInit") + .def("profile", &llvm::DagInit::Profile, "id"_a) + .def("get_operator", &llvm::DagInit::getOperator, + nb::rv_policy::reference_internal) + .def("get_operator_as_def", &llvm::DagInit::getOperatorAsDef, "loc"_a, + nb::rv_policy::reference_internal) + .def("get_name", &llvm::DagInit::getName, + nb::rv_policy::reference_internal) + .def("get_name_str", &llvm::DagInit::getNameStr) + .def("get_num_args", &llvm::DagInit::getNumArgs) + .def("get_arg", &llvm::DagInit::getArg, "num"_a, + nb::rv_policy::reference_internal) + .def("get_arg_no", &llvm::DagInit::getArgNo, "name"_a) + .def("get_arg_name", &llvm::DagInit::getArgName, "num"_a, nb::rv_policy::reference_internal) - .def("get_arg_no", &DagInit::getArgNo, "name"_a) - .def("get_arg_name_init", &DagInit::getArgName, "num"_a, + .def("get_arg_name_str", &llvm::DagInit::getArgNameStr, "num"_a) + .def("resolve_references", &llvm::DagInit::resolveReferences, "r"_a, nb::rv_policy::reference_internal) - .def("get_arg_name_str", &DagInit::getArgNameStr, "num"_a) - .def("get_arg_name_inits", - coerceReturn>(&DagInit::getArgNames, - nb::const_), + .def("is_concrete", &llvm::DagInit::isConcrete) + .def("get_as_string", &llvm::DagInit::getAsString) + .def("arg_begin", &llvm::DagInit::arg_begin, + nb::rv_policy::reference_internal) + .def("arg_end", &llvm::DagInit::arg_end, + nb::rv_policy::reference_internal) + .def("arg_size", &llvm::DagInit::arg_size) + .def("arg_empty", &llvm::DagInit::arg_empty) + .def("name_begin", &llvm::DagInit::name_begin, + nb::rv_policy::reference_internal) + .def("name_end", &llvm::DagInit::name_end, + nb::rv_policy::reference_internal) + .def("name_size", &llvm::DagInit::name_size) + .def("name_empty", &llvm::DagInit::name_empty) + .def("get_bit", &llvm::DagInit::getBit, "bit"_a, + nb::rv_policy::reference_internal) + .def("get_arg_names", + eudsl::coerceReturn>( + &llvm::DagInit::getArgNames, nb::const_), nb::rv_policy::reference_internal) .def("get_args", - coerceReturn>(&DagInit::getArgs, - nb::const_), + eudsl::coerceReturn>( + &llvm::DagInit::getArgs, nb::const_), nb::rv_policy::reference_internal) - .def("__len__", [](const DagInit &v) { return v.arg_size(); }) - .def("__bool__", [](const DagInit &v) { return !v.arg_empty(); }) + .def("__len__", [](const llvm::DagInit &v) { return v.arg_size(); }) + .def("__bool__", [](const llvm::DagInit &v) { return !v.arg_empty(); }) .def( "__iter__", - [](DagInit &v) { + [](llvm::DagInit &v) { return nb::make_iterator( - nb::type(), "Iterator", v.arg_begin(), v.arg_end()); + nb::type(), "Iterator", v.arg_begin(), + v.arg_end()); }, - nb::keep_alive<0, 1>()) + nb::rv_policy::reference_internal) .def( "__getitem__", - [](DagInit &v, Py_ssize_t i) { - return v.getArg(nb::detail::wrap(i, v.arg_size())); + [](llvm::DagInit &v, Py_ssize_t i) { + return v.getArg(eudsl::wrap(i, v.arg_size())); }, nb::rv_policy::reference_internal); - nb::class_(m, "RecordVal") - .def("dump", &RecordVal::dump) - .def_prop_ro("name", &RecordVal::getName) - .def_prop_ro("name_init_as_string", &RecordVal::getNameInitAsString) - .def_prop_ro("print_type", &RecordVal::getPrintType) - .def_prop_ro("record_keeper", &RecordVal::getRecordKeeper) - .def_prop_ro("type", &RecordVal::getType) - .def_prop_ro("is_nonconcrete_ok", &RecordVal::isNonconcreteOK) - .def_prop_ro("is_template_arg", &RecordVal::isTemplateArg) - .def_prop_ro("value", &RecordVal::getValue) + auto llvm_RecordVal = nb::class_(m, "RecordVal"); + nb::enum_(llvm_RecordVal, "FieldKind") + .value("FK_Normal", llvm::RecordVal::FK_Normal) + .value("FK_NonconcreteOK", llvm::RecordVal::FK_NonconcreteOK) + .value("FK_TemplateArg", llvm::RecordVal::FK_TemplateArg); + + llvm_RecordVal + .def("get_record_keeper", &llvm::RecordVal::getRecordKeeper, + nb::rv_policy::reference_internal) + .def("get_name", &llvm::RecordVal::getName) + .def("get_name_init", &llvm::RecordVal::getNameInit, + nb::rv_policy::reference_internal) + .def("get_name_init_as_string", &llvm::RecordVal::getNameInitAsString) + .def("get_loc", &llvm::RecordVal::getLoc, + nb::rv_policy::reference_internal) + .def("is_nonconcrete_ok", &llvm::RecordVal::isNonconcreteOK) + .def("is_template_arg", &llvm::RecordVal::isTemplateArg) + .def("get_type", &llvm::RecordVal::getType, + nb::rv_policy::reference_internal) + .def("get_print_type", &llvm::RecordVal::getPrintType) + .def("get_value", &llvm::RecordVal::getValue, + nb::rv_policy::reference_internal) + .def( + "set_value", + [](llvm::RecordVal &self, const llvm::Init *V) -> bool { + return self.setValue(V); + }, + "v"_a) + .def( + "set_value", + [](llvm::RecordVal &self, const llvm::Init *V, + llvm::SMLoc NewLoc) -> bool { return self.setValue(V, NewLoc); }, + "v"_a, "new_loc"_a) + .def("add_reference_loc", &llvm::RecordVal::addReferenceLoc, "loc"_a) + .def("get_reference_locs", &llvm::RecordVal::getReferenceLocs) + .def("set_used", &llvm::RecordVal::setUsed, "used"_a) + .def("is_used", &llvm::RecordVal::isUsed) + .def("dump", &llvm::RecordVal::dump) + .def("print", &llvm::RecordVal::print, "os"_a, "print_sem"_a) .def("__str__", - [](const RecordVal &self) { - return self.getValue()->getAsUnquotedString(); + [](const llvm::RecordVal &self) { + return self.getValue() ? self.getValue()->getAsUnquotedString() + : "<>"; }) - .def_prop_ro("is_used", &RecordVal::isUsed); + .def("is_used", &llvm::RecordVal::isUsed); struct RecordValues {}; nb::class_(m, "RecordValues", nb::dynamic_attr()) - .def("__repr__", [](const nb::object &self) { - nb::str s{"RecordValues("}; - auto dic = nb::cast(nb::getattr(self, "__dict__")); - int i = 0; - for (auto [key, value] : dic) { - s += key + nb::str("=") + - nb::str(nb::cast(value) - .getValue() - ->getAsUnquotedString() - .c_str()); - if (i < dic.size() - 1) - s += nb::str(", "); - ++i; - } - s += nb::str(")"); - return s; - }); - - nb::class_(m, "Record") - .def_prop_ro("direct_super_classes", - [](const Record &self) -> std::vector { - SmallVector Classes; - self.getDirectSuperClasses(Classes); - return {Classes.begin(), Classes.end()}; - }) - .def_prop_ro("id", &Record::getID) - .def_prop_ro("name", &Record::getName) - .def_prop_ro("name_init_as_string", &Record::getNameInitAsString) - .def_prop_ro("records", &Record::getRecords) - .def_prop_ro("type", &Record::getType) - .def("get_value", nb::overload_cast(&Record::getValue), - "name"_a, nb::rv_policy::reference_internal) - .def("get_value_as_bit", &Record::getValueAsBit, "field_name"_a) - .def("get_value_as_def", &Record::getValueAsDef, "field_name"_a) - .def("get_value_as_int", &Record::getValueAsInt, "field_name"_a) - .def("get_value_as_list_of_defs", &Record::getValueAsListOfDefs, - "field_name"_a, nb::rv_policy::reference_internal) - .def("get_value_as_list_of_ints", &Record::getValueAsListOfInts, - "field_name"_a) - .def("get_value_as_list_of_strings", &Record::getValueAsListOfStrings, - "field_name"_a) - .def("get_value_as_optional_def", &Record::getValueAsOptionalDef, - "field_name"_a, nb::rv_policy::reference_internal) - .def("get_value_as_optional_string", &Record::getValueAsOptionalString, - nb::sig("def get_value_as_optional_string(self, field_name: str, /) " - "-> Optional[str]")) - .def("get_value_as_string", &Record::getValueAsString, "field_name"_a) - .def("get_value_as_bit_or_unset", &Record::getValueAsBitOrUnset, - "field_name"_a, "unset"_a) - .def("get_value_as_bits_init", &Record::getValueAsBitsInit, - "field_name"_a, nb::rv_policy::reference_internal) - .def("get_value_as_dag", &Record::getValueAsDag, "field_name"_a, - nb::rv_policy::reference_internal) - .def("get_value_as_list_init", &Record::getValueAsListInit, - "field_name"_a, nb::rv_policy::reference_internal) - .def("get_value_init", &Record::getValueInit, "field_name"_a, - nb::rv_policy::reference_internal) - .def_prop_ro( + .def("__repr__", + [](const nb::object &self) { + nb::str s{"RecordValues("}; + auto dic = nb::cast(nb::getattr(self, "__dict__")); + int i = 0; + for (auto [key, value] : dic) { + s += key + nb::str("=") + + nb::str(nb::cast(value) + .getValue() + ->getAsUnquotedString() + .c_str()); + if (i < dic.size() - 1) + s += nb::str(", "); + ++i; + } + s += nb::str(")"); + return s; + }) + .def( + "__iter__", + [](const nb::object &self) { + return nb::iter(getattr(self, "__dict__")); + }, + nb::rv_policy::reference_internal) + .def( + "keys", + [](const nb::object &self) { + return getattr(getattr(self, "__dict__"), "keys")(); + }, + nb::rv_policy::reference_internal) + .def( "values", - [](Record &self) { + [](const nb::object &self) { + return getattr(getattr(self, "__dict__"), "values")(); + }, + nb::rv_policy::reference_internal) + .def( + "items", + [](const nb::object &self) { + return getattr(getattr(self, "__dict__"), "items")(); + }, + nb::rv_policy::reference_internal); + + nb::class_(m, "Record") + .def("get_direct_super_classes", + [](const llvm::Record &self) -> std::vector { + llvm::SmallVector Classes; + self.getDirectSuperClasses(Classes); + return {Classes.begin(), Classes.end()}; + }) + .def( + "get_values", + [](llvm::Record &self) { // you can't just call the class_->operator() nb::handle recordValsInstTy = nb::type(); assert(recordValsInstTy.is_valid() && @@ -436,33 +899,152 @@ NB_MODULE(eudsl_tblgen_ext, m) { recordValsInst.type().is(recordValsInstTy) && !nb::inst_ready(recordValsInst)); - std::vector values = self.getValues(); - for (const RecordVal &recordVal : values) { + std::vector values = self.getValues(); + for (const llvm::RecordVal &recordVal : values) { nb::setattr(recordValsInst, recordVal.getName().str().c_str(), nb::borrow(nb::cast(recordVal))); } return recordValsInst; - }) - .def("has_direct_super_class", &Record::hasDirectSuperClass, + }, + nb::rv_policy::reference_internal) + .def("get_template_args", + eudsl::coerceReturn>( + &llvm::Record::getTemplateArgs, nb::const_), + nb::rv_policy::reference_internal) + .def_static("get_new_uid", &llvm::Record::getNewUID, "rk"_a) + .def("get_id", &llvm::Record::getID) + .def("get_name", &llvm::Record::getName) + .def("get_name_init", &llvm::Record::getNameInit, + nb::rv_policy::reference_internal) + .def("get_name_init_as_string", &llvm::Record::getNameInitAsString) + .def("set_name", &llvm::Record::setName, "name"_a) + .def("get_loc", eudsl::coerceReturn>( + &llvm::Record::getLoc, nb::const_)) + .def("append_loc", &llvm::Record::appendLoc, "loc"_a) + .def("get_forward_declaration_locs", + &llvm::Record::getForwardDeclarationLocs) + .def("append_reference_loc", &llvm::Record::appendReferenceLoc, "loc"_a) + .def("get_reference_locs", &llvm::Record::getReferenceLocs) + .def("update_class_loc", &llvm::Record::updateClassLoc, "loc"_a) + .def("get_type", &llvm::Record::getType, + nb::rv_policy::reference_internal) + .def("get_def_init", &llvm::Record::getDefInit, + nb::rv_policy::reference_internal) + .def("is_class", &llvm::Record::isClass) + .def("is_multi_class", &llvm::Record::isMultiClass) + .def("is_anonymous", &llvm::Record::isAnonymous) + .def("get_template_args", &llvm::Record::getTemplateArgs) + .def("get_assertions", &llvm::Record::getAssertions) + .def("get_dumps", &llvm::Record::getDumps) + .def("get_super_classes", &llvm::Record::getSuperClasses) + .def("has_direct_super_class", &llvm::Record::hasDirectSuperClass, "super_class"_a) - .def_prop_ro("is_anonymous", &Record::isAnonymous) - .def_prop_ro("is_class", &Record::isClass) - .def_prop_ro("is_multi_class", &Record::isMultiClass) - .def("is_sub_class_of", - nb::overload_cast(&Record::isSubClassOf, nb::const_), - "r"_a) - .def("is_sub_class_of", - nb::overload_cast(&Record::isSubClassOf, nb::const_), - "name"_a) - .def("is_value_unset", &Record::isValueUnset, "field_name"_a) - .def_prop_ro("def_init", &Record::getDefInit) - .def_prop_ro("name_init", &Record::getNameInit) - .def_prop_ro("template_args", coerceReturn>( - &Record::getTemplateArgs, nb::const_)) - .def("is_template_arg", &Record::isTemplateArg, "name"_a); - - using RecordMap = std::map, std::less<>>; - using GlobalMap = std::map>; + .def("is_template_arg", &llvm::Record::isTemplateArg, "name"_a) + .def( + "get_value", + [](llvm::Record &self, const llvm::Init *Name) + -> const llvm::RecordVal * { return self.getValue(Name); }, + "name"_a, nb::rv_policy::reference_internal) + .def( + "get_value", + [](llvm::Record &self, llvm::StringRef Name) + -> const llvm::RecordVal * { return self.getValue(Name); }, + "name"_a, nb::rv_policy::reference_internal) + .def( + "get_value", + [](llvm::Record &self, const llvm::Init *Name) -> llvm::RecordVal * { + return self.getValue(Name); + }, + "name"_a, nb::rv_policy::reference_internal) + .def( + "get_value", + [](llvm::Record &self, llvm::StringRef Name) -> llvm::RecordVal * { + return self.getValue(Name); + }, + "name"_a, nb::rv_policy::reference_internal) + .def("add_template_arg", &llvm::Record::addTemplateArg, "name"_a) + .def("add_value", &llvm::Record::addValue, "rv"_a) + .def( + "remove_value", + [](llvm::Record &self, const llvm::Init *Name) -> void { + return self.removeValue(Name); + }, + "name"_a) + .def( + "remove_value", + [](llvm::Record &self, llvm::StringRef Name) -> void { + return self.removeValue(Name); + }, + "name"_a) + .def("add_assertion", &llvm::Record::addAssertion, "loc"_a, "condition"_a, + "message"_a) + .def("add_dump", &llvm::Record::addDump, "loc"_a, "message"_a) + .def("append_assertions", &llvm::Record::appendAssertions, "rec"_a) + .def("append_dumps", &llvm::Record::appendDumps, "rec"_a) + .def("check_record_assertions", &llvm::Record::checkRecordAssertions) + .def("emit_record_dumps", &llvm::Record::emitRecordDumps) + .def("check_unused_template_args", &llvm::Record::checkUnusedTemplateArgs) + .def( + "is_sub_class_of", + [](llvm::Record &self, const llvm::Record *R) -> bool { + return self.isSubClassOf(R); + }, + "r"_a) + .def( + "is_sub_class_of", + [](llvm::Record &self, llvm::StringRef Name) -> bool { + return self.isSubClassOf(Name); + }, + "name"_a) + .def("add_super_class", &llvm::Record::addSuperClass, "r"_a, "range"_a) + .def( + "resolve_references", + [](llvm::Record &self, const llvm::Init *NewName) -> void { + return self.resolveReferences(NewName); + }, + "new_name"_a) + .def( + "resolve_references", + [](llvm::Record &self, llvm::Resolver &R, + const llvm::RecordVal *SkipVal) -> void { + return self.resolveReferences(R, SkipVal); + }, + "r"_a, "skip_val"_a) + .def("get_records", &llvm::Record::getRecords, + nb::rv_policy::reference_internal) + .def("dump", [](llvm::Record &self) { self.dump(); }) + .def("get_field_loc", &llvm::Record::getFieldLoc, "field_name"_a) + .def("get_value_init", &llvm::Record::getValueInit, "field_name"_a, + nb::rv_policy::reference_internal) + .def("is_value_unset", &llvm::Record::isValueUnset, "field_name"_a) + .def("get_value_as_string", &llvm::Record::getValueAsString, + "field_name"_a) + .def("get_value_as_optional_string", + &llvm::Record::getValueAsOptionalString, "field_name"_a) + .def("get_value_as_bits_init", &llvm::Record::getValueAsBitsInit, + "field_name"_a, nb::rv_policy::reference_internal) + .def("get_value_as_list_init", &llvm::Record::getValueAsListInit, + "field_name"_a, nb::rv_policy::reference_internal) + .def("get_value_as_list_of_defs", &llvm::Record::getValueAsListOfDefs, + "field_name"_a) + .def("get_value_as_list_of_ints", &llvm::Record::getValueAsListOfInts, + "field_name"_a) + .def("get_value_as_list_of_strings", + &llvm::Record::getValueAsListOfStrings, "field_name"_a) + .def("get_value_as_def", &llvm::Record::getValueAsDef, "field_name"_a, + nb::rv_policy::reference_internal) + .def("get_value_as_optional_def", &llvm::Record::getValueAsOptionalDef, + "field_name"_a, nb::rv_policy::reference_internal) + .def("get_value_as_bit", &llvm::Record::getValueAsBit, "field_name"_a) + .def("get_value_as_bit_or_unset", &llvm::Record::getValueAsBitOrUnset, + "field_name"_a, "unset"_a) + .def("get_value_as_int", &llvm::Record::getValueAsInt, "field_name"_a) + .def("get_value_as_dag", &llvm::Record::getValueAsDag, "field_name"_a, + nb::rv_policy::reference_internal); + + using RecordMap = + std::map, std::less<>>; + using GlobalMap = std::map>; nb::bind_map(m, "GlobalMap"); nb::class_(m, "RecordMap") @@ -479,14 +1061,14 @@ NB_MODULE(eudsl_tblgen_ext, m) { return nb::make_key_iterator( nb::type(), "KeyIterator", m.begin(), m.end()); }, - nb::keep_alive<0, 1>()) + nb::rv_policy::reference_internal) .def( "keys", [](RecordMap &m) { return nb::make_key_iterator( nb::type(), "KeyIterator", m.begin(), m.end()); }, - nb::keep_alive<0, 1>()) + nb::rv_policy::reference_internal) .def( "__getitem__", [](RecordMap &m, const std::string &k) { @@ -498,26 +1080,27 @@ NB_MODULE(eudsl_tblgen_ext, m) { }, nb::rv_policy::reference_internal); - nb::class_(m, "RecordKeeper") + nb::class_(m, "RecordKeeper") .def(nb::init<>()) .def( "parse_td", - [](RecordKeeper &self, const std::string &inputFilename, + [](llvm::RecordKeeper &self, const std::string &inputFilename, const std::vector &includeDirs, const std::vector ¯oNames, bool noWarnOnUnusedTemplateArgs) { - ErrorOr> fileOrErr = - MemoryBuffer::getFileOrSTDIN(inputFilename, /*IsText=*/true); + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFile(inputFilename, + /*IsText=*/true); if (std::error_code EC = fileOrErr.getError()) throw std::runtime_error("Could not open input file '" + inputFilename + "': " + EC.message() + "\n"); self.saveInputFilename(inputFilename); - SourceMgr srcMgr; + llvm::SourceMgr srcMgr; srcMgr.setIncludeDirs(includeDirs); - srcMgr.AddNewSourceBuffer(std::move(*fileOrErr), SMLoc()); - TGParser tgParser(srcMgr, macroNames, self, - noWarnOnUnusedTemplateArgs); + srcMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); + llvm::TGParser tgParser(srcMgr, macroNames, self, + noWarnOnUnusedTemplateArgs); if (tgParser.ParseFile()) throw std::runtime_error("Could not parse file '" + inputFilename); @@ -526,15 +1109,1663 @@ NB_MODULE(eudsl_tblgen_ext, m) { "input_filename"_a, "include_dirs"_a = nb::list(), "macro_names"_a = nb::list(), "no_warn_on_unused_template_args"_a = true) - .def_prop_ro("input_filename", &RecordKeeper::getInputFilename) - .def_prop_ro("classes", &RecordKeeper::getClasses) - .def_prop_ro("defs", &RecordKeeper::getDefs) - .def_prop_ro("globals", &RecordKeeper::getGlobals) - .def("get_all_derived_definitions", - coerceReturn, ArrayRef>( - &RecordKeeper::getAllDerivedDefinitions, nb::const_), - "class_name"_a, nb::rv_policy::reference_internal); - - m.def("lookup_intrinsic_id", Intrinsic::lookupIntrinsicID, nb::arg("name")); - m.def("intrinsic_is_overloaded", Intrinsic::isOverloaded, nb::arg("id")); + .def("get_input_filename", &llvm::RecordKeeper::getInputFilename) + .def("get_classes", &llvm::RecordKeeper::getClasses, + nb::rv_policy::reference_internal) + .def("get_defs", &llvm::RecordKeeper::getDefs, + nb::rv_policy::reference_internal) + .def("get_globals", &llvm::RecordKeeper::getGlobals, + nb::rv_policy::reference_internal) + .def("get_class", &llvm::RecordKeeper::getClass, "name"_a, + nb::rv_policy::reference_internal) + .def("get_def", &llvm::RecordKeeper::getDef, "name"_a, + nb::rv_policy::reference_internal) + .def("get_global", &llvm::RecordKeeper::getGlobal, "name"_a, + nb::rv_policy::reference_internal) + .def("save_input_filename", &llvm::RecordKeeper::saveInputFilename, + "filename"_a) + .def("add_class", &llvm::RecordKeeper::addClass, "r"_a) + .def("add_def", &llvm::RecordKeeper::addDef, "r"_a) + .def("add_extra_global", &llvm::RecordKeeper::addExtraGlobal, "name"_a, + "i"_a) + .def("get_new_anonymous_name", &llvm::RecordKeeper::getNewAnonymousName, + nb::rv_policy::reference_internal) + .def( + "get_all_derived_definitions", + [](llvm::RecordKeeper &self, + const llvm::ArrayRef ClassNames) + -> std::vector> { + return self.getAllDerivedDefinitions(ClassNames); + }, + "class_names"_a) + .def("dump", &llvm::RecordKeeper::dump) + .def( + "get_all_derived_definitions", + [](llvm::RecordKeeper &self, const std::string &className) + -> std::vector { + return self.getAllDerivedDefinitions(className); + }, + "class_name"_a, nb::rv_policy::reference_internal) + .def( + "get_all_derived_definitions_if_defined", + [](llvm::RecordKeeper &self, const std::string &className) + -> std::vector { + return self.getAllDerivedDefinitionsIfDefined(className); + }, + "class_name"_a, nb::rv_policy::reference_internal); + + nb::class_(m, "raw_ostream"); + + auto mlir_tblgen_Pred = + nb::class_(m, "Pred") + .def(nb::init<>()) + .def(nb::init(), "record"_a) + .def(nb::init(), "init"_a) + .def("is_null", &mlir::tblgen::Pred::isNull) + .def("get_condition", &mlir::tblgen::Pred::getCondition) + .def("is_combined", &mlir::tblgen::Pred::isCombined) + .def("get_loc", &mlir::tblgen::Pred::getLoc) + .def( + "operator==", + [](mlir::tblgen::Pred &self, const mlir::tblgen::Pred &other) + -> bool { return self.operator==(other); }, + "other"_a) + .def("operator bool", + [](mlir::tblgen::Pred &self) -> bool { + return self.operator bool(); + }) + .def("get_def", &mlir::tblgen::Pred::getDef, + nb::rv_policy::reference_internal); + + auto mlir_tblgen_CPred = + nb::class_(m, "CPred") + .def(nb::init(), "record"_a) + .def(nb::init(), "init"_a); + + mlir_tblgen_CPred.def("get_condition_impl", + &mlir::tblgen::CPred::getConditionImpl); + + auto mlir_tblgen_CombinedPred = + nb::class_(m, + "CombinedPred") + .def(nb::init(), "record"_a) + .def(nb::init(), "init"_a) + .def("get_condition_impl", + &mlir::tblgen::CombinedPred::getConditionImpl) + .def("get_combiner_def", &mlir::tblgen::CombinedPred::getCombinerDef, + nb::rv_policy::reference_internal); + + mlir_tblgen_CombinedPred.def("get_children", + &mlir::tblgen::CombinedPred::getChildren); + + auto mlir_tblgen_SubstLeavesPred = + nb::class_( + m, "SubstLeavesPred") + .def("get_pattern", &mlir::tblgen::SubstLeavesPred::getPattern); + + mlir_tblgen_SubstLeavesPred.def( + "get_replacement", &mlir::tblgen::SubstLeavesPred::getReplacement); + + auto mlir_tblgen_ConcatPred = + nb::class_( + m, "ConcatPred") + .def("get_prefix", &mlir::tblgen::ConcatPred::getPrefix); + + mlir_tblgen_ConcatPred.def("get_suffix", + &mlir::tblgen::ConcatPred::getSuffix); + + auto mlir_tblgen_Constraint = + nb::class_(m, "Constraint") + .def(nb::init(), + "record"_a, "kind"_a) + .def(nb::init(), "record"_a) + .def( + "operator==", + [](mlir::tblgen::Constraint &self, + const mlir::tblgen::Constraint &that) -> bool { + return self.operator==(that); + }, + "that"_a) + .def( + "operator!=", + [](mlir::tblgen::Constraint &self, + const mlir::tblgen::Constraint &that) -> bool { + return self.operator!=(that); + }, + "that"_a) + .def("get_predicate", &mlir::tblgen::Constraint::getPredicate) + .def("get_condition_template", + &mlir::tblgen::Constraint::getConditionTemplate) + .def("get_summary", &mlir::tblgen::Constraint::getSummary) + .def("get_description", &mlir::tblgen::Constraint::getDescription) + .def("get_def_name", &mlir::tblgen::Constraint::getDefName) + .def("get_unique_def_name", + &mlir::tblgen::Constraint::getUniqueDefName) + .def("get_cpp_function_name", + &mlir::tblgen::Constraint::getCppFunctionName) + .def("get_kind", &mlir::tblgen::Constraint::getKind) + .def("get_def", &mlir::tblgen::Constraint::getDef, + nb::rv_policy::reference_internal); + + nb::enum_(mlir_tblgen_Constraint, "Kind") + .value("CK_Attr", mlir::tblgen::Constraint::CK_Attr) + .value("CK_Region", mlir::tblgen::Constraint::CK_Region) + .value("CK_Successor", mlir::tblgen::Constraint::CK_Successor) + .value("CK_Type", mlir::tblgen::Constraint::CK_Type) + .value("CK_Uncategorized", mlir::tblgen::Constraint::CK_Uncategorized); + + auto mlir_tblgen_AppliedConstraint = + nb::class_(m, "AppliedConstraint") + .def(nb::init< + mlir::tblgen::Constraint &&, llvm::StringRef, + std::vector> &&>(), + "constraint"_a, "self"_a, "entities"_a) + .def_rw("constraint", &mlir::tblgen::AppliedConstraint::constraint) + .def_rw("self", &mlir::tblgen::AppliedConstraint::self) + .def_rw("entities", &mlir::tblgen::AppliedConstraint::entities); + + auto mlir_tblgen_AttrConstraint = + nb::class_( + m, "AttrConstraint") + .def_static("classof", &mlir::tblgen::AttrConstraint::classof, "c"_a) + .def("is_sub_class_of", &mlir::tblgen::AttrConstraint::isSubClassOf, + "class_name"_a); + + auto mlir_tblgen_Attribute = + nb::class_( + m, "Attribute") + .def(nb::init(), "record"_a) + .def(nb::init(), "init"_a) + .def("get_storage_type", &mlir::tblgen::Attribute::getStorageType) + .def("get_return_type", &mlir::tblgen::Attribute::getReturnType) + .def("get_value_type", &mlir::tblgen::Attribute::getValueType) + .def("get_convert_from_storage_call", + &mlir::tblgen::Attribute::getConvertFromStorageCall) + .def("is_const_buildable", &mlir::tblgen::Attribute::isConstBuildable) + .def("get_const_builder_template", + &mlir::tblgen::Attribute::getConstBuilderTemplate) + .def("get_base_attr", &mlir::tblgen::Attribute::getBaseAttr) + .def("has_default_value", &mlir::tblgen::Attribute::hasDefaultValue) + .def("get_default_value", &mlir::tblgen::Attribute::getDefaultValue) + .def("is_optional", &mlir::tblgen::Attribute::isOptional) + .def("is_derived_attr", &mlir::tblgen::Attribute::isDerivedAttr) + .def("is_type_attr", &mlir::tblgen::Attribute::isTypeAttr) + .def("is_symbol_ref_attr", &mlir::tblgen::Attribute::isSymbolRefAttr) + .def("is_enum_attr", &mlir::tblgen::Attribute::isEnumAttr) + .def("get_attr_def_name", &mlir::tblgen::Attribute::getAttrDefName) + .def("get_derived_code_body", + &mlir::tblgen::Attribute::getDerivedCodeBody) + .def("get_dialect", &mlir::tblgen::Attribute::getDialect) + .def("get_def", &mlir::tblgen::Attribute::getDef, + nb::rv_policy::reference_internal); + + auto mlir_tblgen_ConstantAttr = + nb::class_(m, "ConstantAttr") + .def(nb::init(), "init"_a) + .def("get_attribute", &mlir::tblgen::ConstantAttr::getAttribute) + .def("get_constant_value", + &mlir::tblgen::ConstantAttr::getConstantValue); + + auto mlir_tblgen_EnumAttrCase = + nb::class_( + m, "EnumAttrCase") + .def(nb::init(), "record"_a) + .def(nb::init(), "init"_a) + .def("get_symbol", &mlir::tblgen::EnumAttrCase::getSymbol) + .def("get_str", &mlir::tblgen::EnumAttrCase::getStr) + .def("get_value", &mlir::tblgen::EnumAttrCase::getValue); + + mlir_tblgen_EnumAttrCase.def("get_def", &mlir::tblgen::EnumAttrCase::getDef, + nb::rv_policy::reference_internal); + + auto mlir_tblgen_EnumAttr = + nb::class_(m, "EnumAttr") + .def(nb::init(), "record"_a) + .def(nb::init(), "record"_a) + .def(nb::init(), "init"_a) + .def_static("classof", &mlir::tblgen::EnumAttr::classof, "attr"_a) + .def("is_bit_enum", &mlir::tblgen::EnumAttr::isBitEnum) + .def("get_enum_class_name", &mlir::tblgen::EnumAttr::getEnumClassName) + .def("get_cpp_namespace", &mlir::tblgen::EnumAttr::getCppNamespace) + .def("get_underlying_type", + &mlir::tblgen::EnumAttr::getUnderlyingType) + .def("get_underlying_to_symbol_fn_name", + &mlir::tblgen::EnumAttr::getUnderlyingToSymbolFnName) + .def("get_string_to_symbol_fn_name", + &mlir::tblgen::EnumAttr::getStringToSymbolFnName) + .def("get_symbol_to_string_fn_name", + &mlir::tblgen::EnumAttr::getSymbolToStringFnName) + .def("get_symbol_to_string_fn_ret_type", + &mlir::tblgen::EnumAttr::getSymbolToStringFnRetType) + .def("get_max_enum_val_fn_name", + &mlir::tblgen::EnumAttr::getMaxEnumValFnName) + .def("get_all_cases", &mlir::tblgen::EnumAttr::getAllCases) + .def("gen_specialized_attr", + &mlir::tblgen::EnumAttr::genSpecializedAttr) + .def("get_base_attr_class", &mlir::tblgen::EnumAttr::getBaseAttrClass, + nb::rv_policy::reference_internal) + .def("get_specialized_attr_class_name", + &mlir::tblgen::EnumAttr::getSpecializedAttrClassName); + + mlir_tblgen_EnumAttr.def("print_bit_enum_primary_groups", + &mlir::tblgen::EnumAttr::printBitEnumPrimaryGroups); + + auto mlir_tblgen_Property = + nb::class_(m, "Property") + .def(nb::init(), "record"_a) + .def(nb::init(), "init"_a) + .def(nb::init(), + "summary"_a, "description"_a, "storage_type"_a, + "interface_type"_a, "convert_from_storage_call"_a, + "assign_to_storage_call"_a, "convert_to_attribute_call"_a, + "convert_from_attribute_call"_a, "parser_call"_a, + "optional_parser_call"_a, "printer_call"_a, + "read_from_mlir_bytecode_call"_a, + "write_to_mlir_bytecode_call"_a, "hash_property_call"_a, + "default_value"_a, "storage_type_value_override"_a) + .def("get_summary", &mlir::tblgen::Property::getSummary) + .def("get_description", &mlir::tblgen::Property::getDescription) + .def("get_storage_type", &mlir::tblgen::Property::getStorageType) + .def("get_interface_type", &mlir::tblgen::Property::getInterfaceType) + .def("get_convert_from_storage_call", + &mlir::tblgen::Property::getConvertFromStorageCall) + .def("get_assign_to_storage_call", + &mlir::tblgen::Property::getAssignToStorageCall) + .def("get_convert_to_attribute_call", + &mlir::tblgen::Property::getConvertToAttributeCall) + .def("get_convert_from_attribute_call", + &mlir::tblgen::Property::getConvertFromAttributeCall) + .def("get_predicate", &mlir::tblgen::Property::getPredicate) + .def("get_parser_call", &mlir::tblgen::Property::getParserCall) + .def("has_optional_parser", + &mlir::tblgen::Property::hasOptionalParser) + .def("get_optional_parser_call", + &mlir::tblgen::Property::getOptionalParserCall) + .def("get_printer_call", &mlir::tblgen::Property::getPrinterCall) + .def("get_read_from_mlir_bytecode_call", + &mlir::tblgen::Property::getReadFromMlirBytecodeCall) + .def("get_write_to_mlir_bytecode_call", + &mlir::tblgen::Property::getWriteToMlirBytecodeCall) + .def("get_hash_property_call", + &mlir::tblgen::Property::getHashPropertyCall) + .def("has_default_value", &mlir::tblgen::Property::hasDefaultValue) + .def("get_default_value", &mlir::tblgen::Property::getDefaultValue) + .def("has_storage_type_value_override", + &mlir::tblgen::Property::hasStorageTypeValueOverride) + .def("get_storage_type_value_override", + &mlir::tblgen::Property::getStorageTypeValueOverride) + .def("get_property_def_name", + &mlir::tblgen::Property::getPropertyDefName) + .def("get_base_property", &mlir::tblgen::Property::getBaseProperty) + .def("get_def", &mlir::tblgen::Property::getDef, + nb::rv_policy::reference_internal); + + auto mlir_tblgen_NamedProperty = + nb::class_(m, "NamedProperty") + .def_rw("name", &mlir::tblgen::NamedProperty::name) + .def_rw("prop", &mlir::tblgen::NamedProperty::prop); + + auto mlir_tblgen_Dialect = + nb::class_(m, "Dialect") + .def(nb::init(), "def_"_a) + .def("get_name", &mlir::tblgen::Dialect::getName) + .def("get_cpp_namespace", &mlir::tblgen::Dialect::getCppNamespace) + .def("get_cpp_class_name", &mlir::tblgen::Dialect::getCppClassName) + .def("get_summary", &mlir::tblgen::Dialect::getSummary) + .def("get_description", &mlir::tblgen::Dialect::getDescription) + .def("get_dependent_dialects", + &mlir::tblgen::Dialect::getDependentDialects) + .def("get_extra_class_declaration", + &mlir::tblgen::Dialect::getExtraClassDeclaration) + .def("has_canonicalizer", &mlir::tblgen::Dialect::hasCanonicalizer) + .def("has_constant_materializer", + &mlir::tblgen::Dialect::hasConstantMaterializer) + .def("has_non_default_destructor", + &mlir::tblgen::Dialect::hasNonDefaultDestructor) + .def("has_operation_attr_verify", + &mlir::tblgen::Dialect::hasOperationAttrVerify) + .def("has_region_arg_attr_verify", + &mlir::tblgen::Dialect::hasRegionArgAttrVerify) + .def("has_region_result_attr_verify", + &mlir::tblgen::Dialect::hasRegionResultAttrVerify) + .def("has_operation_interface_fallback", + &mlir::tblgen::Dialect::hasOperationInterfaceFallback) + .def("use_default_attribute_printer_parser", + &mlir::tblgen::Dialect::useDefaultAttributePrinterParser) + .def("use_default_type_printer_parser", + &mlir::tblgen::Dialect::useDefaultTypePrinterParser) + .def("is_extensible", &mlir::tblgen::Dialect::isExtensible) + .def("use_properties_for_attributes", + &mlir::tblgen::Dialect::usePropertiesForAttributes) + .def("get_discardable_attributes", + &mlir::tblgen::Dialect::getDiscardableAttributes, + nb::rv_policy::reference_internal) + .def("get_def", &mlir::tblgen::Dialect::getDef, + nb::rv_policy::reference_internal) + .def( + "operator==", + [](mlir::tblgen::Dialect &self, + const mlir::tblgen::Dialect &other) -> bool { + return self.operator==(other); + }, + "other"_a) + .def( + "operator!=", + [](mlir::tblgen::Dialect &self, + const mlir::tblgen::Dialect &other) -> bool { + return self.operator!=(other); + }, + "other"_a) + .def( + "operator<", + [](mlir::tblgen::Dialect &self, + const mlir::tblgen::Dialect &other) -> bool { + return self.operator<(other); + }, + "other"_a) + .def("operator bool", [](mlir::tblgen::Dialect &self) -> bool { + return self.operator bool(); + }); + + auto mlir_tblgen_TypeConstraint = + nb::class_( + m, "TypeConstraint") + .def(nb::init(), "record"_a) + .def_static("classof", &mlir::tblgen::TypeConstraint::classof, "c"_a) + .def("is_optional", &mlir::tblgen::TypeConstraint::isOptional) + .def("is_variadic", &mlir::tblgen::TypeConstraint::isVariadic) + .def("is_variadic_of_variadic", + &mlir::tblgen::TypeConstraint::isVariadicOfVariadic) + .def("get_variadic_of_variadic_segment_size_attr", + &mlir::tblgen::TypeConstraint:: + getVariadicOfVariadicSegmentSizeAttr) + .def("is_variable_length", + &mlir::tblgen::TypeConstraint::isVariableLength) + .def("get_builder_call", + &mlir::tblgen::TypeConstraint::getBuilderCall) + .def("get_cpp_type", &mlir::tblgen::TypeConstraint::getCppType); + + auto mlir_tblgen_Type = + nb::class_(m, "Type") + .def(nb::init(), "record"_a); + + mlir_tblgen_Type.def("get_dialect", &mlir::tblgen::Type::getDialect); + + auto mlir_tblgen_NamedAttribute = + nb::class_(m, "NamedAttribute") + .def_rw("name", &mlir::tblgen::NamedAttribute::name) + .def_rw("attr", &mlir::tblgen::NamedAttribute::attr); + + auto mlir_tblgen_NamedTypeConstraint = + nb::class_(m, "NamedTypeConstraint") + .def("has_predicate", + &mlir::tblgen::NamedTypeConstraint::hasPredicate) + .def("is_optional", &mlir::tblgen::NamedTypeConstraint::isOptional) + .def("is_variadic", &mlir::tblgen::NamedTypeConstraint::isVariadic) + .def("is_variadic_of_variadic", + &mlir::tblgen::NamedTypeConstraint::isVariadicOfVariadic) + .def("is_variable_length", + &mlir::tblgen::NamedTypeConstraint::isVariableLength) + .def_rw("name", &mlir::tblgen::NamedTypeConstraint::name) + .def_rw("constraint", &mlir::tblgen::NamedTypeConstraint::constraint); + + auto mlir_tblgen_Builder = nb::class_(m, "Builder"); + + auto mlir_tblgen_Builder_Parameter = + nb::class_(mlir_tblgen_Builder, + "Parameter") + .def("get_cpp_type", &mlir::tblgen::Builder::Parameter::getCppType) + .def("get_name", &mlir::tblgen::Builder::Parameter::getName) + .def("get_default_value", + &mlir::tblgen::Builder::Parameter::getDefaultValue); + + mlir_tblgen_Builder + .def(nb::init>(), + "record"_a, "loc"_a) + .def("get_parameters", &mlir::tblgen::Builder::getParameters) + .def("get_body", &mlir::tblgen::Builder::getBody) + .def("get_deprecated_message", + &mlir::tblgen::Builder::getDeprecatedMessage); + + auto mlir_tblgen_Trait = + nb::class_(m, "Trait") + .def(nb::init(), + "kind"_a, "def_"_a) + .def_static("create", &mlir::tblgen::Trait::create, "init"_a) + .def("get_kind", &mlir::tblgen::Trait::getKind) + .def("get_def", &mlir::tblgen::Trait::getDef, + nb::rv_policy::reference_internal); + nb::enum_(mlir_tblgen_Trait, "Kind") + .value("Native", mlir::tblgen::Trait::Kind::Native) + .value("Pred", mlir::tblgen::Trait::Kind::Pred) + .value("Internal", mlir::tblgen::Trait::Kind::Internal) + .value("Interface", mlir::tblgen::Trait::Kind::Interface); + + auto mlir_tblgen_NativeTrait = + nb::class_(m, + "NativeTrait") + .def("get_fully_qualified_trait_name", + &mlir::tblgen::NativeTrait::getFullyQualifiedTraitName) + .def("is_structural_op_trait", + &mlir::tblgen::NativeTrait::isStructuralOpTrait) + .def("get_extra_concrete_class_declaration", + &mlir::tblgen::NativeTrait::getExtraConcreteClassDeclaration) + .def("get_extra_concrete_class_definition", + &mlir::tblgen::NativeTrait::getExtraConcreteClassDefinition) + .def_static("classof", &mlir::tblgen::NativeTrait::classof, "t"_a); + + auto mlir_tblgen_PredTrait = + nb::class_(m, "PredTrait") + .def("get_pred_template", &mlir::tblgen::PredTrait::getPredTemplate) + .def("get_summary", &mlir::tblgen::PredTrait::getSummary) + .def_static("classof", &mlir::tblgen::PredTrait::classof, "t"_a); + + auto mlir_tblgen_InternalTrait = + nb::class_( + m, "InternalTrait") + .def("get_fully_qualified_trait_name", + &mlir::tblgen::InternalTrait::getFullyQualifiedTraitName) + .def_static("classof", &mlir::tblgen::InternalTrait::classof, "t"_a); + + auto mlir_tblgen_InterfaceTrait = + nb::class_( + m, "InterfaceTrait") + .def("get_interface", &mlir::tblgen::InterfaceTrait::getInterface) + .def("get_fully_qualified_trait_name", + &mlir::tblgen::InterfaceTrait::getFullyQualifiedTraitName) + .def_static("classof", &mlir::tblgen::InterfaceTrait::classof, "t"_a) + .def("should_declare_methods", + &mlir::tblgen::InterfaceTrait::shouldDeclareMethods) + .def("get_always_declared_methods", + &mlir::tblgen::InterfaceTrait::getAlwaysDeclaredMethods); + + auto mlir_tblgen_AttrOrTypeBuilder = + nb::class_( + m, "AttrOrTypeBuilder") + .def("get_return_type", + &mlir::tblgen::AttrOrTypeBuilder::getReturnType) + .def("has_inferred_context_parameter", + &mlir::tblgen::AttrOrTypeBuilder::hasInferredContextParameter); + + auto mlir_tblgen_AttrOrTypeParameter = + nb::class_(m, "AttrOrTypeParameter") + .def(nb::init(), "def_"_a, + "index"_a) + .def("is_anonymous", &mlir::tblgen::AttrOrTypeParameter::isAnonymous) + .def("get_name", &mlir::tblgen::AttrOrTypeParameter::getName) + .def("get_accessor_name", + &mlir::tblgen::AttrOrTypeParameter::getAccessorName) + .def("get_allocator", + &mlir::tblgen::AttrOrTypeParameter::getAllocator) + .def("get_comparator", + &mlir::tblgen::AttrOrTypeParameter::getComparator) + .def("get_cpp_type", &mlir::tblgen::AttrOrTypeParameter::getCppType) + .def("get_cpp_accessor_type", + &mlir::tblgen::AttrOrTypeParameter::getCppAccessorType) + .def("get_cpp_storage_type", + &mlir::tblgen::AttrOrTypeParameter::getCppStorageType) + .def("get_convert_from_storage", + &mlir::tblgen::AttrOrTypeParameter::getConvertFromStorage) + .def("get_parser", &mlir::tblgen::AttrOrTypeParameter::getParser) + .def("get_constraint", + &mlir::tblgen::AttrOrTypeParameter::getConstraint) + .def("get_printer", &mlir::tblgen::AttrOrTypeParameter::getPrinter) + .def("get_summary", &mlir::tblgen::AttrOrTypeParameter::getSummary) + .def("get_syntax", &mlir::tblgen::AttrOrTypeParameter::getSyntax) + .def("is_optional", &mlir::tblgen::AttrOrTypeParameter::isOptional) + .def("get_default_value", + &mlir::tblgen::AttrOrTypeParameter::getDefaultValue) + .def("get_def", &mlir::tblgen::AttrOrTypeParameter::getDef, + nb::rv_policy::reference_internal) + .def( + "operator==", + [](mlir::tblgen::AttrOrTypeParameter &self, + const mlir::tblgen::AttrOrTypeParameter &other) -> bool { + return self.operator==(other); + }, + "other"_a) + .def( + "operator!=", + [](mlir::tblgen::AttrOrTypeParameter &self, + const mlir::tblgen::AttrOrTypeParameter &other) -> bool { + return self.operator!=(other); + }, + "other"_a); + + auto mlir_tblgen_AttributeSelfTypeParameter = + nb::class_( + m, "AttributeSelfTypeParameter") + .def_static("classof", + &mlir::tblgen::AttributeSelfTypeParameter::classof, + "param"_a); + + auto mlir_tblgen_AttrOrTypeDef = + nb::class_(m, "AttrOrTypeDef") + .def(nb::init(), "def_"_a) + .def("get_dialect", &mlir::tblgen::AttrOrTypeDef::getDialect) + .def("get_name", &mlir::tblgen::AttrOrTypeDef::getName) + .def("has_description", &mlir::tblgen::AttrOrTypeDef::hasDescription) + .def("get_description", &mlir::tblgen::AttrOrTypeDef::getDescription) + .def("has_summary", &mlir::tblgen::AttrOrTypeDef::hasSummary) + .def("get_summary", &mlir::tblgen::AttrOrTypeDef::getSummary) + .def("get_cpp_class_name", + &mlir::tblgen::AttrOrTypeDef::getCppClassName) + .def("get_cpp_base_class_name", + &mlir::tblgen::AttrOrTypeDef::getCppBaseClassName) + .def("get_storage_class_name", + &mlir::tblgen::AttrOrTypeDef::getStorageClassName) + .def("get_storage_namespace", + &mlir::tblgen::AttrOrTypeDef::getStorageNamespace) + .def("gen_storage_class", + &mlir::tblgen::AttrOrTypeDef::genStorageClass) + .def("has_storage_custom_constructor", + &mlir::tblgen::AttrOrTypeDef::hasStorageCustomConstructor) + .def("get_parameters", &mlir::tblgen::AttrOrTypeDef::getParameters) + .def("get_num_parameters", + &mlir::tblgen::AttrOrTypeDef::getNumParameters) + .def("get_mnemonic", &mlir::tblgen::AttrOrTypeDef::getMnemonic) + .def("has_custom_assembly_format", + &mlir::tblgen::AttrOrTypeDef::hasCustomAssemblyFormat) + .def("get_assembly_format", + &mlir::tblgen::AttrOrTypeDef::getAssemblyFormat) + .def("gen_accessors", &mlir::tblgen::AttrOrTypeDef::genAccessors) + .def("gen_verify_decl", &mlir::tblgen::AttrOrTypeDef::genVerifyDecl) + .def("gen_verify_invariants_impl", + &mlir::tblgen::AttrOrTypeDef::genVerifyInvariantsImpl) + .def("get_extra_decls", &mlir::tblgen::AttrOrTypeDef::getExtraDecls) + .def("get_extra_defs", &mlir::tblgen::AttrOrTypeDef::getExtraDefs) + .def("get_loc", &mlir::tblgen::AttrOrTypeDef::getLoc) + .def("skip_default_builders", + &mlir::tblgen::AttrOrTypeDef::skipDefaultBuilders) + .def("get_builders", &mlir::tblgen::AttrOrTypeDef::getBuilders) + .def("get_traits", &mlir::tblgen::AttrOrTypeDef::getTraits) + .def( + "operator==", + [](mlir::tblgen::AttrOrTypeDef &self, + const mlir::tblgen::AttrOrTypeDef &other) -> bool { + return self.operator==(other); + }, + "other"_a) + .def( + "operator<", + [](mlir::tblgen::AttrOrTypeDef &self, + const mlir::tblgen::AttrOrTypeDef &other) -> bool { + return self.operator<(other); + }, + "other"_a) + .def("operator bool", + [](mlir::tblgen::AttrOrTypeDef &self) -> bool { + return self.operator bool(); + }) + .def("get_def", &mlir::tblgen::AttrOrTypeDef::getDef, + nb::rv_policy::reference_internal); + + auto mlir_tblgen_AttrDef = + nb::class_(m, + "AttrDef") + .def("get_type_builder", &mlir::tblgen::AttrDef::getTypeBuilder) + .def_static("classof", &mlir::tblgen::AttrDef::classof, "def_"_a) + .def("get_attr_name", &mlir::tblgen::AttrDef::getAttrName); + + auto mlir_tblgen_TypeDef = + nb::class_(m, + "TypeDef") + .def_static("classof", &mlir::tblgen::TypeDef::classof, "def_"_a) + .def("get_type_name", &mlir::tblgen::TypeDef::getTypeName); + + auto mlir_raw_indented_ostream = + nb::class_( + m, "raw_indented_ostream") + .def(nb::init(), "os"_a) + .def("get_o_stream", &mlir::raw_indented_ostream::getOStream, + nb::rv_policy::reference_internal) + .def("scope", &mlir::raw_indented_ostream::scope, "open"_a, "close"_a, + "indent"_a) + .def("print_reindented", &mlir::raw_indented_ostream::printReindented, + "str"_a, "extra_prefix"_a, nb::rv_policy::reference_internal) + .def( + "indent", + [](mlir::raw_indented_ostream &self) + -> mlir::raw_indented_ostream & { return self.indent(); }, + nb::rv_policy::reference_internal) + .def("unindent", &mlir::raw_indented_ostream::unindent, + nb::rv_policy::reference_internal) + .def( + "indent", + [](mlir::raw_indented_ostream &self, int with) + -> mlir::raw_indented_ostream & { return self.indent(with); }, + "with"_a, nb::rv_policy::reference_internal) + .def("print_reindented", &mlir::raw_indented_ostream::printReindented, + "str"_a, "extra_prefix"_a, nb::rv_policy::reference_internal); + + auto mlir_raw_indented_ostream_DelimitedScope = + nb::class_( + mlir_raw_indented_ostream, "DelimitedScope"); + mlir_raw_indented_ostream_DelimitedScope.def( + nb::init(), + + "os"_a, "open"_a, "close"_a, "indent"_a); + auto mlir_tblgen_FmtContext = + nb::class_(m, "FmtContext") + .def(nb::init<>()) + .def(nb::init>>(), + "subs"_a) + .def("add_subst", &mlir::tblgen::FmtContext::addSubst, + "placeholder"_a, "subst"_a, nb::rv_policy::reference_internal) + .def("with_builder", &mlir::tblgen::FmtContext::withBuilder, + "subst"_a, nb::rv_policy::reference_internal) + .def("with_self", &mlir::tblgen::FmtContext::withSelf, "subst"_a, + nb::rv_policy::reference_internal) + .def( + "get_subst_for", + [](mlir::tblgen::FmtContext &self, + mlir::tblgen::FmtContext::PHKind placeholder) + -> std::optional { + return self.getSubstFor(placeholder); + }, + "placeholder"_a) + .def( + "get_subst_for", + [](mlir::tblgen::FmtContext &self, llvm::StringRef placeholder) + -> std::optional { + return self.getSubstFor(placeholder); + }, + "placeholder"_a) + .def_static("get_place_holder_kind", + &mlir::tblgen::FmtContext::getPlaceHolderKind, "str"_a); + + nb::enum_(mlir_tblgen_FmtContext, "PHKind") + .value("None", mlir::tblgen::FmtContext::PHKind::None) + .value("Custom", mlir::tblgen::FmtContext::PHKind::Custom) + .value("Builder", mlir::tblgen::FmtContext::PHKind::Builder) + .value("Self", mlir::tblgen::FmtContext::PHKind::Self); + + auto mlir_tblgen_FmtReplacement = + nb::class_(m, "FmtReplacement") + .def(nb::init<>()) + .def(nb::init(), "literal"_a) + .def(nb::init(), "spec"_a, "index"_a) + .def(nb::init(), "spec"_a, "index"_a, + "end"_a) + .def(nb::init(), + "spec"_a, "placeholder"_a) + .def_rw("type", &mlir::tblgen::FmtReplacement::type) + .def_rw("spec", &mlir::tblgen::FmtReplacement::spec) + .def_rw("index", &mlir::tblgen::FmtReplacement::index) + .def_rw("end", &mlir::tblgen::FmtReplacement::end) + .def_rw("placeholder", &mlir::tblgen::FmtReplacement::placeholder); + nb::enum_(mlir_tblgen_FmtReplacement, + "Type") + .value("Empty", mlir::tblgen::FmtReplacement::Type::Empty) + .value("Literal", mlir::tblgen::FmtReplacement::Type::Literal) + .value("PositionalPH", mlir::tblgen::FmtReplacement::Type::PositionalPH) + .value("PositionalRangePH", + mlir::tblgen::FmtReplacement::Type::PositionalRangePH) + .value("SpecialPH", mlir::tblgen::FmtReplacement::Type::SpecialPH); + + mlir_tblgen_FmtReplacement.def_ro_static( + "k_unset", &mlir::tblgen::FmtReplacement::kUnset); + + auto mlir_tblgen_FmtObjectBase = + nb::class_(m, "FmtObjectBase") + .def(nb::init(), + "fmt"_a, "ctx"_a, "num_params"_a) + .def(nb::init(), "that"_a) + .def("format", &mlir::tblgen::FmtObjectBase::format, "s"_a) + .def("str", &mlir::tblgen::FmtObjectBase::str) + .def("__str__", [](mlir::tblgen::FmtObjectBase &self) -> std::string { + return self.str(); + }); + + auto mlir_tblgen_FmtStrVecObject = + nb::class_( + m, "FmtStrVecObject") + .def(nb::init>(), + "fmt"_a, "ctx"_a, "params"_a) + .def(nb::init(), "that"_a); + + m.def( + "tgfmt", + [](llvm::StringRef fmt, const mlir::tblgen::FmtContext *ctx, + llvm::ArrayRef params) -> mlir::tblgen::FmtStrVecObject { + return mlir::tblgen::tgfmt(fmt, ctx, params); + }, + "fmt"_a, "ctx"_a, "params"_a); + + auto mlir_tblgen_IfDefScope = + nb::class_(m, "IfDefScope") + .def(nb::init(), "name"_a, + "os"_a); + + auto mlir_tblgen_NamespaceEmitter = + nb::class_(m, "NamespaceEmitter") + .def(nb::init(), + "os"_a, "dialect"_a) + .def(nb::init(), "os"_a, + "cpp_namespace"_a); + + auto mlir_tblgen_StaticVerifierFunctionEmitter = + nb::class_( + m, "StaticVerifierFunctionEmitter") + .def(nb::init(), + "os"_a, "records"_a, "tag"_a) + .def("collect_op_constraints", + &mlir::tblgen::StaticVerifierFunctionEmitter:: + collectOpConstraints, + "op_defs"_a) + .def("emit_op_constraints", + &mlir::tblgen::StaticVerifierFunctionEmitter::emitOpConstraints, + "op_defs"_a) + .def( + "emit_pattern_constraints", + [](mlir::tblgen::StaticVerifierFunctionEmitter &self, + const llvm::ArrayRef constraints) + -> void { return self.emitPatternConstraints(constraints); }, + "constraints"_a) + .def( + "get_type_constraint_fn", + &mlir::tblgen::StaticVerifierFunctionEmitter::getTypeConstraintFn, + "constraint"_a) + .def( + "get_attr_constraint_fn", + &mlir::tblgen::StaticVerifierFunctionEmitter::getAttrConstraintFn, + "constraint"_a) + .def("get_successor_constraint_fn", + &mlir::tblgen::StaticVerifierFunctionEmitter:: + getSuccessorConstraintFn, + "constraint"_a) + .def("get_region_constraint_fn", + &mlir::tblgen::StaticVerifierFunctionEmitter:: + getRegionConstraintFn, + "constraint"_a); + + m.def("escape_string", &mlir::tblgen::escapeString, "value"_a); + + auto mlir_tblgen_MethodParameter = + nb::class_(m, "MethodParameter") + .def("write_decl_to", &mlir::tblgen::MethodParameter::writeDeclTo, + "os"_a) + .def("write_def_to", &mlir::tblgen::MethodParameter::writeDefTo, + "os"_a) + .def("get_type", &mlir::tblgen::MethodParameter::getType) + .def("get_name", &mlir::tblgen::MethodParameter::getName) + .def("has_default_value", + &mlir::tblgen::MethodParameter::hasDefaultValue); + + auto mlir_tblgen_MethodParameters = + nb::class_(m, "MethodParameters") + .def(nb::init>(), + "parameters"_a) + .def(nb::init>(), + "parameters"_a) + .def("write_decl_to", &mlir::tblgen::MethodParameters::writeDeclTo, + "os"_a) + .def("write_def_to", &mlir::tblgen::MethodParameters::writeDefTo, + "os"_a) + .def("subsumes", &mlir::tblgen::MethodParameters::subsumes, "other"_a) + .def("get_num_parameters", + &mlir::tblgen::MethodParameters::getNumParameters); + + auto mlir_tblgen_MethodSignature = + nb::class_(m, "MethodSignature") + .def("makes_redundant", + &mlir::tblgen::MethodSignature::makesRedundant, "other"_a) + .def("get_name", &mlir::tblgen::MethodSignature::getName) + .def("get_return_type", &mlir::tblgen::MethodSignature::getReturnType) + .def("get_num_parameters", + &mlir::tblgen::MethodSignature::getNumParameters) + .def("write_decl_to", &mlir::tblgen::MethodSignature::writeDeclTo, + "os"_a) + .def("write_def_to", &mlir::tblgen::MethodSignature::writeDefTo, + "os"_a, "name_prefix"_a) + .def("write_template_params_to", + &mlir::tblgen::MethodSignature::writeTemplateParamsTo, "os"_a); + + auto mlir_tblgen_MethodBody = + nb::class_(m, "MethodBody") + .def(nb::init(), "decl_only"_a) + .def(nb::init(), "other"_a) + .def( + "operator=", + [](mlir::tblgen::MethodBody &self, + mlir::tblgen::MethodBody &&body) + -> mlir::tblgen::MethodBody & { + return self.operator=(std::move(body)); + }, + "body"_a, nb::rv_policy::reference_internal) + .def("write_to", &mlir::tblgen::MethodBody::writeTo, "os"_a) + .def("indent", &mlir::tblgen::MethodBody::indent, + nb::rv_policy::reference_internal) + .def("unindent", &mlir::tblgen::MethodBody::unindent, + nb::rv_policy::reference_internal) + .def("scope", &mlir::tblgen::MethodBody::scope, "open"_a, "close"_a, + "indent"_a) + .def("get_stream", &mlir::tblgen::MethodBody::getStream, + nb::rv_policy::reference_internal); + + auto mlir_tblgen_ClassDeclaration = + nb::class_(m, "ClassDeclaration"); + nb::enum_(mlir_tblgen_ClassDeclaration, + "Kind") + .value("Method", mlir::tblgen::ClassDeclaration::Method) + .value("UsingDeclaration", + mlir::tblgen::ClassDeclaration::UsingDeclaration) + .value("VisibilityDeclaration", + mlir::tblgen::ClassDeclaration::VisibilityDeclaration) + .value("Field", mlir::tblgen::ClassDeclaration::Field) + .value("ExtraClassDeclaration", + mlir::tblgen::ClassDeclaration::ExtraClassDeclaration); + + auto mlir_tblgen_Method = + nb::class_(m, "Method") + .def(nb::init>(), + "ret_type"_a, "name"_a, "properties"_a, "params"_a) + .def(nb::init(), "_"_a) + .def( + "operator=", + [](mlir::tblgen::Method &self, + mlir::tblgen::Method &&_) -> mlir::tblgen::Method & { + return self.operator=(std::move(_)); + }, + "_"_a, nb::rv_policy::reference_internal) + .def("body", &mlir::tblgen::Method::body, + nb::rv_policy::reference_internal) + .def("set_deprecated", &mlir::tblgen::Method::setDeprecated, + "message"_a) + .def("is_static", &mlir::tblgen::Method::isStatic) + .def("is_private", &mlir::tblgen::Method::isPrivate) + .def("is_inline", &mlir::tblgen::Method::isInline) + .def("is_constructor", &mlir::tblgen::Method::isConstructor) + .def("is_const", &mlir::tblgen::Method::isConst) + .def("get_name", &mlir::tblgen::Method::getName) + .def("get_return_type", &mlir::tblgen::Method::getReturnType) + .def("makes_redundant", &mlir::tblgen::Method::makesRedundant, + "other"_a) + .def("write_decl_to", &mlir::tblgen::Method::writeDeclTo, "os"_a) + .def("write_def_to", &mlir::tblgen::Method::writeDefTo, "os"_a, + "name_prefix"_a); + + nb::enum_(mlir_tblgen_Method, "Properties") + .value("None", mlir::tblgen::Method::None) + .value("Static", mlir::tblgen::Method::Static) + .value("Constructor", mlir::tblgen::Method::Constructor) + .value("Private", mlir::tblgen::Method::Private) + .value("Declaration", mlir::tblgen::Method::Declaration) + .value("Inline", mlir::tblgen::Method::Inline) + .value("ConstexprValue", mlir::tblgen::Method::ConstexprValue) + .value("Const", mlir::tblgen::Method::Const) + .value("Constexpr", mlir::tblgen::Method::Constexpr) + .value("StaticDeclaration", mlir::tblgen::Method::StaticDeclaration) + .value("StaticInline", mlir::tblgen::Method::StaticInline) + .value("ConstInline", mlir::tblgen::Method::ConstInline) + .value("ConstDeclaration", mlir::tblgen::Method::ConstDeclaration); + + nb::enum_(m, "Visibility") + .value("Public", mlir::tblgen::Visibility::Public) + .value("Protected", mlir::tblgen::Visibility::Protected) + .value("Private", mlir::tblgen::Visibility::Private); + + m.def( + "operator<<", + [](llvm::raw_ostream &os, + mlir::tblgen::Visibility visibility) -> llvm::raw_ostream & { + return mlir::tblgen::operator<<(os, visibility); + }, + "os"_a, "visibility"_a); + + auto mlir_tblgen_Constructor = + nb::class_(m, + "Constructor") + .def("write_decl_to", &mlir::tblgen::Constructor::writeDeclTo, "os"_a) + .def("write_def_to", &mlir::tblgen::Constructor::writeDefTo, "os"_a, + "name_prefix"_a) + .def_static("classof", &mlir::tblgen::Constructor::classof, + "other"_a); + + auto mlir_tblgen_Constructor_MemberInitializer = + nb::class_( + mlir_tblgen_Constructor, "MemberInitializer") + .def(nb::init(), "name"_a, "value"_a) + .def("write_to", + &mlir::tblgen::Constructor::MemberInitializer::writeTo, "os"_a); + + auto mlir_tblgen_ParentClass = + nb::class_(m, "ParentClass") + .def("write_to", &mlir::tblgen::ParentClass::writeTo, "os"_a); + + auto mlir_tblgen_UsingDeclaration = + nb::class_(m, "UsingDeclaration") + .def("write_decl_to", &mlir::tblgen::UsingDeclaration::writeDeclTo, + "os"_a); + + auto mlir_tblgen_Field = + nb::class_(m, "Field") + .def("write_decl_to", &mlir::tblgen::Field::writeDeclTo, "os"_a); + + auto mlir_tblgen_VisibilityDeclaration = + nb::class_(m, + "VisibilityDeclaration") + .def(nb::init(), "visibility"_a) + .def("get_visibility", + &mlir::tblgen::VisibilityDeclaration::getVisibility) + .def("write_decl_to", + &mlir::tblgen::VisibilityDeclaration::writeDeclTo, "os"_a); + + auto mlir_tblgen_ExtraClassDeclaration = + nb::class_(m, + "ExtraClassDeclaration") + .def(nb::init(), + "extra_class_declaration"_a, "extra_class_definition"_a) + .def(nb::init(), + "extra_class_declaration"_a, "extra_class_definition"_a) + .def("write_decl_to", + &mlir::tblgen::ExtraClassDeclaration::writeDeclTo, "os"_a) + .def("write_def_to", &mlir::tblgen::ExtraClassDeclaration::writeDefTo, + "os"_a, "name_prefix"_a); + + // "add_parent", + // [](mlir::tblgen::Class &self, mlir::tblgen::ParentClass parent) + // -> mlir::tblgen::ParentClass & { return self.addParent(parent); }, + // "parent"_a, nb::rv_policy::reference_internal) + auto mlir_tblgen_Class = + nb::class_(m, "Class") + .def("get_class_name", &mlir::tblgen::Class::getClassName) + .def( + "write_decl_to", + [](mlir::tblgen::Class &self, llvm::raw_ostream &rawOs) -> void { + return self.writeDeclTo(rawOs); + }, + "raw_os"_a) + .def( + "write_def_to", + [](mlir::tblgen::Class &self, llvm::raw_ostream &rawOs) -> void { + return self.writeDefTo(rawOs); + }, + "raw_os"_a) + .def( + "write_decl_to", + [](mlir::tblgen::Class &self, mlir::raw_indented_ostream &os) + -> void { return self.writeDeclTo(os); }, + "os"_a) + .def( + "write_def_to", + [](mlir::tblgen::Class &self, mlir::raw_indented_ostream &os) + -> void { return self.writeDefTo(os); }, + "os"_a) + .def("finalize", &mlir::tblgen::Class::finalize); + + auto mlir_GenInfo = + nb::class_(m, "GenInfo") + .def(nb::init>(), + "arg"_a, "description"_a, "generator"_a) + .def("invoke", &mlir::GenInfo::invoke, "records"_a, "os"_a) + .def("get_gen_argument", &mlir::GenInfo::getGenArgument) + .def("get_gen_description", &mlir::GenInfo::getGenDescription); + + auto mlir_GenRegistration = + nb::class_(m, "GenRegistration"); + + mlir_GenRegistration.def( + nb::init &>(), + "arg"_a, "description"_a, "function"_a); + + auto mlir_GenNameParser = + nb::class_(m, "GenNameParser") + .def(nb::init(), "opt"_a) + .def("print_option_info", &mlir::GenNameParser::printOptionInfo, + "o"_a, "global_width"_a); + + auto mlir_tblgen_InterfaceMethod = + nb::class_(m, "InterfaceMethod"); + + auto mlir_tblgen_InterfaceMethod_Argument = + nb::class_( + mlir_tblgen_InterfaceMethod, "Argument") + .def_rw("type", &mlir::tblgen::InterfaceMethod::Argument::type) + .def_rw("name", &mlir::tblgen::InterfaceMethod::Argument::name); + + mlir_tblgen_InterfaceMethod.def(nb::init(), "def_"_a) + .def("get_return_type", &mlir::tblgen::InterfaceMethod::getReturnType) + .def("get_name", &mlir::tblgen::InterfaceMethod::getName) + .def("is_static", &mlir::tblgen::InterfaceMethod::isStatic) + .def("get_body", &mlir::tblgen::InterfaceMethod::getBody) + .def("get_default_implementation", + &mlir::tblgen::InterfaceMethod::getDefaultImplementation) + .def("get_description", &mlir::tblgen::InterfaceMethod::getDescription) + .def("get_arguments", &mlir::tblgen::InterfaceMethod::getArguments) + .def("arg_empty", &mlir::tblgen::InterfaceMethod::arg_empty); + + auto mlir_tblgen_Interface = + nb::class_(m, "Interface") + .def(nb::init(), "def_"_a) + .def(nb::init(), "rhs"_a) + .def("get_name", &mlir::tblgen::Interface::getName) + .def("get_fully_qualified_name", + &mlir::tblgen::Interface::getFullyQualifiedName) + .def("get_cpp_namespace", &mlir::tblgen::Interface::getCppNamespace) + .def("get_methods", &mlir::tblgen::Interface::getMethods) + .def("get_description", &mlir::tblgen::Interface::getDescription) + .def("get_extra_class_declaration", + &mlir::tblgen::Interface::getExtraClassDeclaration) + .def("get_extra_trait_class_declaration", + &mlir::tblgen::Interface::getExtraTraitClassDeclaration) + .def("get_extra_shared_class_declaration", + &mlir::tblgen::Interface::getExtraSharedClassDeclaration) + .def("get_extra_class_of", &mlir::tblgen::Interface::getExtraClassOf) + .def("get_verify", &mlir::tblgen::Interface::getVerify) + .def("get_base_interfaces", + &mlir::tblgen::Interface::getBaseInterfaces) + .def("verify_with_regions", + &mlir::tblgen::Interface::verifyWithRegions) + .def("get_def", &mlir::tblgen::Interface::getDef, + nb::rv_policy::reference_internal); + + auto mlir_tblgen_AttrInterface = + nb::class_( + m, "AttrInterface") + .def_static("classof", &mlir::tblgen::AttrInterface::classof, + "interface"_a); + + auto mlir_tblgen_OpInterface = + nb::class_( + m, "OpInterface") + .def_static("classof", &mlir::tblgen::OpInterface::classof, + "interface"_a); + + auto mlir_tblgen_TypeInterface = + nb::class_( + m, "TypeInterface") + .def_static("classof", &mlir::tblgen::TypeInterface::classof, + "interface"_a); + + auto mlir_tblgen_Region = + nb::class_(m, "Region") + .def_static("classof", &mlir::tblgen::Region::classof, "c"_a) + .def("is_variadic", &mlir::tblgen::Region::isVariadic); + + auto mlir_tblgen_NamedRegion = + nb::class_(m, "NamedRegion") + .def("is_variadic", &mlir::tblgen::NamedRegion::isVariadic) + .def_rw("name", &mlir::tblgen::NamedRegion::name) + .def_rw("constraint", &mlir::tblgen::NamedRegion::constraint); + + auto mlir_tblgen_Successor = + nb::class_(m, + "Successor") + .def_static("classof", &mlir::tblgen::Successor::classof, "c"_a) + .def("is_variadic", &mlir::tblgen::Successor::isVariadic); + + auto mlir_tblgen_NamedSuccessor = + nb::class_(m, "NamedSuccessor") + .def("is_variadic", &mlir::tblgen::NamedSuccessor::isVariadic) + .def_rw("name", &mlir::tblgen::NamedSuccessor::name) + .def_rw("constraint", &mlir::tblgen::NamedSuccessor::constraint); + + auto mlir_tblgen_InferredResultType = + nb::class_(m, "InferredResultType") + .def(nb::init(), "index"_a, "transformer"_a) + .def("is_arg", &mlir::tblgen::InferredResultType::isArg) + .def("get_index", &mlir::tblgen::InferredResultType::getIndex) + .def("get_result_index", + &mlir::tblgen::InferredResultType::getResultIndex) + .def_static("map_result_index", + &mlir::tblgen::InferredResultType::mapResultIndex, "i"_a) + .def_static("unmap_result_index", + &mlir::tblgen::InferredResultType::unmapResultIndex, + "i"_a) + .def_static("is_result_index", + &mlir::tblgen::InferredResultType::isResultIndex, "i"_a) + .def_static("is_arg_index", + &mlir::tblgen::InferredResultType::isArgIndex, "i"_a) + .def("get_transformer", + &mlir::tblgen::InferredResultType::getTransformer); + + auto mlir_tblgen_Operator = + nb::class_(m, "Operator") + .def(nb::init(), "def_"_a) + .def(nb::init(), "def_"_a) + .def("get_dialect_name", &mlir::tblgen::Operator::getDialectName) + .def("get_operation_name", &mlir::tblgen::Operator::getOperationName) + .def("get_cpp_class_name", &mlir::tblgen::Operator::getCppClassName) + .def("get_qual_cpp_class_name", + &mlir::tblgen::Operator::getQualCppClassName) + .def("get_cpp_namespace", &mlir::tblgen::Operator::getCppNamespace) + .def("get_adaptor_name", &mlir::tblgen::Operator::getAdaptorName) + .def("get_generic_adaptor_name", + &mlir::tblgen::Operator::getGenericAdaptorName) + .def("assert_invariants", &mlir::tblgen::Operator::assertInvariants) + + .def("is_variadic", &mlir::tblgen::Operator::isVariadic) + .def("skip_default_builders", + &mlir::tblgen::Operator::skipDefaultBuilders) + .def("result_begin", &mlir::tblgen::Operator::result_begin, + nb::rv_policy::reference_internal) + .def("result_end", &mlir::tblgen::Operator::result_end, + nb::rv_policy::reference_internal) + .def("get_results", &mlir::tblgen::Operator::getResults) + .def("get_num_results", &mlir::tblgen::Operator::getNumResults) + .def( + "get_result", + [](mlir::tblgen::Operator &self, + int index) -> mlir::tblgen::NamedTypeConstraint & { + return self.getResult(index); + }, + "index"_a, nb::rv_policy::reference_internal) + .def( + "get_result", + [](mlir::tblgen::Operator &self, + int index) -> const mlir::tblgen::NamedTypeConstraint & { + return self.getResult(index); + }, + "index"_a, nb::rv_policy::reference_internal) + .def("get_result_type_constraint", + &mlir::tblgen::Operator::getResultTypeConstraint, "index"_a) + .def("get_result_name", &mlir::tblgen::Operator::getResultName, + "index"_a) + .def("get_result_decorators", + &mlir::tblgen::Operator::getResultDecorators, "index"_a) + .def("get_num_variable_length_results", + &mlir::tblgen::Operator::getNumVariableLengthResults) + .def( + "attribute_begin", + [](mlir::tblgen::Operator &self) + -> const mlir::tblgen::NamedAttribute * { + return self.attribute_begin(); + }, + nb::rv_policy::reference_internal) + .def( + "attribute_end", + [](mlir::tblgen::Operator &self) + -> const mlir::tblgen::NamedAttribute * { + return self.attribute_end(); + }, + nb::rv_policy::reference_internal) + .def("get_attributes", + [](mlir::tblgen::Operator &self) + -> llvm::iterator_range< + const mlir::tblgen::NamedAttribute *> { + return self.getAttributes(); + }) + .def( + "attribute_begin", + [](mlir::tblgen::Operator &self) + -> mlir::tblgen::NamedAttribute * { + return self.attribute_begin(); + }, + nb::rv_policy::reference_internal) + .def( + "attribute_end", + [](mlir::tblgen::Operator &self) + -> mlir::tblgen::NamedAttribute * { + return self.attribute_end(); + }, + nb::rv_policy::reference_internal) + .def("get_attributes", + [](mlir::tblgen::Operator &self) + -> llvm::iterator_range { + return self.getAttributes(); + }) + .def("get_num_attributes", &mlir::tblgen::Operator::getNumAttributes) + .def("get_num_native_attributes", + &mlir::tblgen::Operator::getNumNativeAttributes) + .def( + "get_attribute", + [](mlir::tblgen::Operator &self, + int index) -> mlir::tblgen::NamedAttribute & { + return self.getAttribute(index); + }, + "index"_a, nb::rv_policy::reference_internal) + .def( + "get_attribute", + [](mlir::tblgen::Operator &self, + int index) -> const mlir::tblgen::NamedAttribute & { + return self.getAttribute(index); + }, + "index"_a, nb::rv_policy::reference_internal) + .def("operand_begin", &mlir::tblgen::Operator::operand_begin, + nb::rv_policy::reference_internal) + .def("operand_end", &mlir::tblgen::Operator::operand_end, + nb::rv_policy::reference_internal) + .def("get_operands", &mlir::tblgen::Operator::getOperands) + .def( + "properties_begin", + [](mlir::tblgen::Operator &self) + -> const mlir::tblgen::NamedProperty * { + return self.properties_begin(); + }, + nb::rv_policy::reference_internal) + .def( + "properties_end", + [](mlir::tblgen::Operator &self) + -> const mlir::tblgen::NamedProperty * { + return self.properties_end(); + }, + nb::rv_policy::reference_internal) + .def( + "get_properties", + [](mlir::tblgen::Operator &self) + -> llvm::iterator_range { + return self.getProperties(); + }) + .def( + "properties_begin", + [](mlir::tblgen::Operator &self) + -> mlir::tblgen::NamedProperty * { + return self.properties_begin(); + }, + nb::rv_policy::reference_internal) + .def( + "properties_end", + [](mlir::tblgen::Operator &self) + -> mlir::tblgen::NamedProperty * { + return self.properties_end(); + }, + nb::rv_policy::reference_internal) + .def("get_properties", + [](mlir::tblgen::Operator &self) + -> llvm::iterator_range { + return self.getProperties(); + }) + .def("get_num_core_attributes", + &mlir::tblgen::Operator::getNumCoreAttributes) + .def( + "get_property", + [](mlir::tblgen::Operator &self, + int index) -> mlir::tblgen::NamedProperty & { + return self.getProperty(index); + }, + "index"_a, nb::rv_policy::reference_internal) + .def( + "get_property", + [](mlir::tblgen::Operator &self, + int index) -> const mlir::tblgen::NamedProperty & { + return self.getProperty(index); + }, + "index"_a, nb::rv_policy::reference_internal) + .def("get_num_operands", &mlir::tblgen::Operator::getNumOperands) + .def( + "get_operand", + [](mlir::tblgen::Operator &self, + int index) -> mlir::tblgen::NamedTypeConstraint & { + return self.getOperand(index); + }, + "index"_a, nb::rv_policy::reference_internal) + .def( + "get_operand", + [](mlir::tblgen::Operator &self, + int index) -> const mlir::tblgen::NamedTypeConstraint & { + return self.getOperand(index); + }, + "index"_a, nb::rv_policy::reference_internal) + .def("get_num_variable_length_operands", + &mlir::tblgen::Operator::getNumVariableLengthOperands) + .def("get_num_args", &mlir::tblgen::Operator::getNumArgs) + .def("has_single_variadic_arg", + &mlir::tblgen::Operator::hasSingleVariadicArg) + .def("has_single_variadic_result", + &mlir::tblgen::Operator::hasSingleVariadicResult) + .def("has_no_variadic_regions", + &mlir::tblgen::Operator::hasNoVariadicRegions) + .def("arg_begin", &mlir::tblgen::Operator::arg_begin, + nb::rv_policy::reference_internal) + .def("arg_end", &mlir::tblgen::Operator::arg_end, + nb::rv_policy::reference_internal) + .def("get_args", &mlir::tblgen::Operator::getArgs) + .def("get_arg", &mlir::tblgen::Operator::getArg, "index"_a) + .def("get_arg_name", &mlir::tblgen::Operator::getArgName, "index"_a) + .def("get_arg_decorators", &mlir::tblgen::Operator::getArgDecorators, + "index"_a) + .def("get_trait", &mlir::tblgen::Operator::getTrait, "trait"_a, + nb::rv_policy::reference_internal) + .def("region_begin", &mlir::tblgen::Operator::region_begin, + nb::rv_policy::reference_internal) + .def("region_end", &mlir::tblgen::Operator::region_end, + nb::rv_policy::reference_internal) + .def("get_regions", &mlir::tblgen::Operator::getRegions) + .def("get_num_regions", &mlir::tblgen::Operator::getNumRegions) + .def("get_region", &mlir::tblgen::Operator::getRegion, "index"_a, + nb::rv_policy::reference_internal) + .def("get_num_variadic_regions", + &mlir::tblgen::Operator::getNumVariadicRegions) + .def("successor_begin", &mlir::tblgen::Operator::successor_begin, + nb::rv_policy::reference_internal) + .def("successor_end", &mlir::tblgen::Operator::successor_end, + nb::rv_policy::reference_internal) + .def("get_successors", &mlir::tblgen::Operator::getSuccessors) + .def("get_num_successors", &mlir::tblgen::Operator::getNumSuccessors) + .def("get_successor", &mlir::tblgen::Operator::getSuccessor, + "index"_a, nb::rv_policy::reference_internal) + .def("get_num_variadic_successors", + &mlir::tblgen::Operator::getNumVariadicSuccessors) + .def("trait_begin", &mlir::tblgen::Operator::trait_begin, + nb::rv_policy::reference_internal) + .def("trait_end", &mlir::tblgen::Operator::trait_end, + nb::rv_policy::reference_internal) + .def("get_traits", &mlir::tblgen::Operator::getTraits) + .def("get_loc", &mlir::tblgen::Operator::getLoc) + .def("has_description", &mlir::tblgen::Operator::hasDescription) + .def("get_description", &mlir::tblgen::Operator::getDescription) + .def("has_summary", &mlir::tblgen::Operator::hasSummary) + .def("get_summary", &mlir::tblgen::Operator::getSummary) + .def("has_assembly_format", + &mlir::tblgen::Operator::hasAssemblyFormat) + .def("get_assembly_format", + &mlir::tblgen::Operator::getAssemblyFormat) + .def("get_extra_class_declaration", + &mlir::tblgen::Operator::getExtraClassDeclaration) + .def("get_extra_class_definition", + &mlir::tblgen::Operator::getExtraClassDefinition) + .def("get_def", &mlir::tblgen::Operator::getDef, + nb::rv_policy::reference_internal) + .def("get_dialect", &mlir::tblgen::Operator::getDialect, + nb::rv_policy::reference_internal) + .def("print", &mlir::tblgen::Operator::print, "os"_a) + .def("all_result_types_known", + &mlir::tblgen::Operator::allResultTypesKnown) + .def("get_inferred_result_type", + &mlir::tblgen::Operator::getInferredResultType, "index"_a, + nb::rv_policy::reference_internal); + + auto mlir_tblgen_Operator_VariableDecorator = + nb::class_( + mlir_tblgen_Operator, "VariableDecorator") + .def(nb::init(), "def_"_a) + .def("get_def", &mlir::tblgen::Operator::VariableDecorator::getDef, + nb::rv_policy::reference_internal); + + auto mlir_tblgen_Operator_OperandOrAttribute = + nb::class_( + mlir_tblgen_Operator, "OperandOrAttribute") + .def( + nb::init(), + "kind"_a, "index"_a) + .def("operand_or_attribute_index", + &mlir::tblgen::Operator::OperandOrAttribute:: + operandOrAttributeIndex) + .def("kind", &mlir::tblgen::Operator::OperandOrAttribute::kind); + + nb::enum_( + mlir_tblgen_Operator_OperandOrAttribute, "Kind") + .value("Operand", + mlir::tblgen::Operator::OperandOrAttribute::Kind::Operand) + .value("Attribute", + mlir::tblgen::Operator::OperandOrAttribute::Kind::Attribute); + + mlir_tblgen_Operator + .def("get_arg_to_operand_or_attribute", + &mlir::tblgen::Operator::getArgToOperandOrAttribute, "index"_a) + .def("get_builders", &mlir::tblgen::Operator::getBuilders) + .def("get_getter_name", &mlir::tblgen::Operator::getGetterName, "name"_a) + .def("get_setter_name", &mlir::tblgen::Operator::getSetterName, "name"_a) + .def("get_remover_name", &mlir::tblgen::Operator::getRemoverName, + "name"_a) + .def("has_folder", &mlir::tblgen::Operator::hasFolder) + .def("use_custom_properties_encoding", + &mlir::tblgen::Operator::useCustomPropertiesEncoding); + + auto mlir_tblgen_PassOption = + nb::class_(m, "PassOption") + .def(nb::init(), "def_"_a) + .def("get_cpp_variable_name", + &mlir::tblgen::PassOption::getCppVariableName) + .def("get_argument", &mlir::tblgen::PassOption::getArgument) + .def("get_type", &mlir::tblgen::PassOption::getType) + .def("get_default_value", &mlir::tblgen::PassOption::getDefaultValue) + .def("get_description", &mlir::tblgen::PassOption::getDescription) + .def("get_additional_flags", + &mlir::tblgen::PassOption::getAdditionalFlags) + .def("is_list_option", &mlir::tblgen::PassOption::isListOption); + + auto mlir_tblgen_PassStatistic = + nb::class_(m, "PassStatistic") + .def(nb::init(), "def_"_a) + .def("get_cpp_variable_name", + &mlir::tblgen::PassStatistic::getCppVariableName) + .def("get_name", &mlir::tblgen::PassStatistic::getName) + .def("get_description", &mlir::tblgen::PassStatistic::getDescription); + + auto mlir_tblgen_Pass = + nb::class_(m, "Pass") + .def(nb::init(), "def_"_a) + .def("get_argument", &mlir::tblgen::Pass::getArgument) + .def("get_base_class", &mlir::tblgen::Pass::getBaseClass) + .def("get_summary", &mlir::tblgen::Pass::getSummary) + .def("get_description", &mlir::tblgen::Pass::getDescription) + .def("get_constructor", &mlir::tblgen::Pass::getConstructor) + .def("get_dependent_dialects", + &mlir::tblgen::Pass::getDependentDialects) + .def("get_options", &mlir::tblgen::Pass::getOptions) + .def("get_statistics", &mlir::tblgen::Pass::getStatistics) + .def("get_def", &mlir::tblgen::Pass::getDef, + nb::rv_policy::reference_internal); + + auto mlir_tblgen_DagLeaf = + nb::class_(m, "DagLeaf") + .def(nb::init(), "def_"_a) + .def("is_unspecified", &mlir::tblgen::DagLeaf::isUnspecified) + .def("is_operand_matcher", &mlir::tblgen::DagLeaf::isOperandMatcher) + .def("is_attr_matcher", &mlir::tblgen::DagLeaf::isAttrMatcher) + .def("is_native_code_call", &mlir::tblgen::DagLeaf::isNativeCodeCall) + .def("is_constant_attr", &mlir::tblgen::DagLeaf::isConstantAttr) + .def("is_enum_attr_case", &mlir::tblgen::DagLeaf::isEnumAttrCase) + .def("is_string_attr", &mlir::tblgen::DagLeaf::isStringAttr) + .def("get_as_constraint", &mlir::tblgen::DagLeaf::getAsConstraint) + .def("get_as_constant_attr", + &mlir::tblgen::DagLeaf::getAsConstantAttr) + .def("get_as_enum_attr_case", + &mlir::tblgen::DagLeaf::getAsEnumAttrCase) + .def("get_condition_template", + &mlir::tblgen::DagLeaf::getConditionTemplate) + .def("get_native_code_template", + &mlir::tblgen::DagLeaf::getNativeCodeTemplate) + .def("get_num_returns_of_native_code", + &mlir::tblgen::DagLeaf::getNumReturnsOfNativeCode) + .def("get_string_attr", &mlir::tblgen::DagLeaf::getStringAttr) + .def("print", &mlir::tblgen::DagLeaf::print, "os"_a); + + auto mlir_tblgen_DagNode = + nb::class_(m, "DagNode") + .def(nb::init(), "node"_a) + .def("operator bool", + [](mlir::tblgen::DagNode &self) -> bool { + return self.operator bool(); + }) + .def("get_symbol", &mlir::tblgen::DagNode::getSymbol) + .def("get_dialect_op", &mlir::tblgen::DagNode::getDialectOp, + "mapper"_a, nb::rv_policy::reference_internal) + .def("get_num_ops", &mlir::tblgen::DagNode::getNumOps) + .def("get_num_args", &mlir::tblgen::DagNode::getNumArgs) + .def("is_nested_dag_arg", &mlir::tblgen::DagNode::isNestedDagArg, + "index"_a) + .def("get_arg_as_nested_dag", + &mlir::tblgen::DagNode::getArgAsNestedDag, "index"_a) + .def("get_arg_as_leaf", &mlir::tblgen::DagNode::getArgAsLeaf, + "index"_a) + .def("get_arg_name", &mlir::tblgen::DagNode::getArgName, "index"_a) + .def("is_replace_with_value", + &mlir::tblgen::DagNode::isReplaceWithValue) + .def("is_location_directive", + &mlir::tblgen::DagNode::isLocationDirective) + .def("is_return_type_directive", + &mlir::tblgen::DagNode::isReturnTypeDirective) + .def("is_native_code_call", &mlir::tblgen::DagNode::isNativeCodeCall) + .def("is_either", &mlir::tblgen::DagNode::isEither) + .def("is_variadic", &mlir::tblgen::DagNode::isVariadic) + .def("is_operation", &mlir::tblgen::DagNode::isOperation) + .def("get_native_code_template", + &mlir::tblgen::DagNode::getNativeCodeTemplate) + .def("get_num_returns_of_native_code", + &mlir::tblgen::DagNode::getNumReturnsOfNativeCode) + .def("print", &mlir::tblgen::DagNode::print, "os"_a); + + auto mlir_tblgen_SymbolInfoMap = + nb::class_(m, "SymbolInfoMap") + .def(nb::init>(), "loc"_a); + + auto mlir_tblgen_SymbolInfoMap_SymbolInfo = + nb::class_( + mlir_tblgen_SymbolInfoMap, "SymbolInfo") + .def("get_var_type_str", + &mlir::tblgen::SymbolInfoMap::SymbolInfo::getVarTypeStr, + "name"_a) + .def("get_var_decl", + &mlir::tblgen::SymbolInfoMap::SymbolInfo::getVarDecl, "name"_a) + .def("get_arg_decl", + &mlir::tblgen::SymbolInfoMap::SymbolInfo::getArgDecl, "name"_a) + .def("get_var_name", + &mlir::tblgen::SymbolInfoMap::SymbolInfo::getVarName, "name"_a); + + using SymbolInfoMapBaseT = + std::unordered_multimap; + mlir_tblgen_SymbolInfoMap + .def("begin", + [](mlir::tblgen::SymbolInfoMap &self) + -> SymbolInfoMapBaseT::iterator { return self.begin(); }) + .def("end", + [](mlir::tblgen::SymbolInfoMap &self) + -> SymbolInfoMapBaseT::iterator { return self.end(); }) + .def("cbegin", + [](mlir::tblgen::SymbolInfoMap &self) + -> SymbolInfoMapBaseT::const_iterator { return self.begin(); }) + .def("cend", + [](mlir::tblgen::SymbolInfoMap &self) + -> SymbolInfoMapBaseT::const_iterator { return self.end(); }) + .def("bind_op_argument", &mlir::tblgen::SymbolInfoMap::bindOpArgument, + "node"_a, "symbol"_a, "op"_a, "arg_index"_a, "variadic_sub_index"_a) + .def("bind_op_result", &mlir::tblgen::SymbolInfoMap::bindOpResult, + "symbol"_a, "op"_a) + .def("bind_values", &mlir::tblgen::SymbolInfoMap::bindValues, "symbol"_a, + "num_values"_a) + .def("bind_value", &mlir::tblgen::SymbolInfoMap::bindValue, "symbol"_a) + .def("bind_multiple_values", + &mlir::tblgen::SymbolInfoMap::bindMultipleValues, "symbol"_a, + "num_values"_a) + .def("bind_attr", &mlir::tblgen::SymbolInfoMap::bindAttr, "symbol"_a) + .def("contains", &mlir::tblgen::SymbolInfoMap::contains, "symbol"_a) + .def("find", &mlir::tblgen::SymbolInfoMap::find, "key"_a) + .def( + "find_bound_symbol", + [](mlir::tblgen::SymbolInfoMap &self, llvm::StringRef key, + mlir::tblgen::DagNode node, const mlir::tblgen::Operator &op, + int argIndex, std::optional variadicSubIndex) + -> SymbolInfoMapBaseT::const_iterator { + return self.findBoundSymbol(key, node, op, argIndex, + variadicSubIndex); + }, + "key"_a, "node"_a, "op"_a, "arg_index"_a, "variadic_sub_index"_a) + .def( + "find_bound_symbol", + [](mlir::tblgen::SymbolInfoMap &self, llvm::StringRef key, + const mlir::tblgen::SymbolInfoMap::SymbolInfo &symbolInfo) + -> SymbolInfoMapBaseT::const_iterator { + return self.findBoundSymbol(key, symbolInfo); + }, + "key"_a, "symbol_info"_a) + .def("get_range_of_equal_elements", + &mlir::tblgen::SymbolInfoMap::getRangeOfEqualElements, "key"_a) + .def("count", &mlir::tblgen::SymbolInfoMap::count, "key"_a) + .def("get_static_value_count", + &mlir::tblgen::SymbolInfoMap::getStaticValueCount, "symbol"_a) + .def("get_value_and_range_use", + &mlir::tblgen::SymbolInfoMap::getValueAndRangeUse, "symbol"_a, + "fmt"_a, "separator"_a) + .def("get_all_range_use", &mlir::tblgen::SymbolInfoMap::getAllRangeUse, + "symbol"_a, "fmt"_a, "separator"_a) + .def("assign_unique_alternative_names", + &mlir::tblgen::SymbolInfoMap::assignUniqueAlternativeNames) + .def_static("get_value_pack_name", + &mlir::tblgen::SymbolInfoMap::getValuePackName, "symbol"_a, + "index"_a); + + auto mlir_tblgen_Pattern = + nb::class_(m, "Pattern") + .def(nb::init< + const llvm::Record *, + llvm::DenseMap< + const llvm::Record *, + std::unique_ptr< + mlir::tblgen::Operator, + std::default_delete>, + llvm::DenseMapInfo, + llvm::detail::DenseMapPair< + const llvm::Record *, + std::unique_ptr>>> *>(), + "def_"_a, "mapper"_a) + .def("get_source_pattern", &mlir::tblgen::Pattern::getSourcePattern) + .def("get_num_result_patterns", + &mlir::tblgen::Pattern::getNumResultPatterns) + .def("get_result_pattern", &mlir::tblgen::Pattern::getResultPattern, + "index"_a) + .def("collect_source_pattern_bound_symbols", + &mlir::tblgen::Pattern::collectSourcePatternBoundSymbols, + "info_map"_a) + .def("collect_result_pattern_bound_symbols", + &mlir::tblgen::Pattern::collectResultPatternBoundSymbols, + "info_map"_a) + .def("get_source_root_op", &mlir::tblgen::Pattern::getSourceRootOp, + nb::rv_policy::reference_internal) + .def("get_dialect_op", &mlir::tblgen::Pattern::getDialectOp, "node"_a, + nb::rv_policy::reference_internal) + .def("get_constraints", &mlir::tblgen::Pattern::getConstraints) + .def("get_num_supplemental_patterns", + &mlir::tblgen::Pattern::getNumSupplementalPatterns) + .def("get_supplemental_pattern", + &mlir::tblgen::Pattern::getSupplementalPattern, "index"_a) + .def("get_benefit", &mlir::tblgen::Pattern::getBenefit) + .def("get_location", &mlir::tblgen::Pattern::getLocation) + .def("collect_bound_symbols", + &mlir::tblgen::Pattern::collectBoundSymbols, "tree"_a, + "info_map"_a, "is_src_pattern"_a); + + auto mlir_tblgen_SideEffect = + nb::class_(m, "SideEffect") + .def("get_name", &mlir::tblgen::SideEffect::getName) + .def("get_base_effect_name", + &mlir::tblgen::SideEffect::getBaseEffectName) + .def("get_interface_trait", + &mlir::tblgen::SideEffect::getInterfaceTrait) + .def("get_resource", &mlir::tblgen::SideEffect::getResource) + .def("get_stage", &mlir::tblgen::SideEffect::getStage) + .def("get_effect_onfull_region", + &mlir::tblgen::SideEffect::getEffectOnfullRegion); + + mlir_tblgen_SideEffect.def_static( + "classof", &mlir::tblgen::SideEffect::classof, "var"_a); + + auto mlir_tblgen_SideEffectTrait = + nb::class_( + m, "SideEffectTrait") + .def("get_effects", &mlir::tblgen::SideEffectTrait::getEffects) + .def("get_base_effect_name", + &mlir::tblgen::SideEffectTrait::getBaseEffectName); + + mlir_tblgen_SideEffectTrait.def_static( + "classof", &mlir::tblgen::SideEffectTrait::classof, "t"_a); + + m.def("lookup_intrinsic_id", llvm::Intrinsic::lookupIntrinsicID, + nb::arg("name")); + m.def("intrinsic_is_overloaded", llvm::Intrinsic::isOverloaded, + nb::arg("id")); } diff --git a/projects/eudsl-tblgen/tests/td/CommonTypeConstraints.td b/projects/eudsl-tblgen/tests/td/CommonTypeConstraints.td index cd90e377..64546088 100644 --- a/projects/eudsl-tblgen/tests/td/CommonTypeConstraints.td +++ b/projects/eudsl-tblgen/tests/td/CommonTypeConstraints.td @@ -918,4 +918,8 @@ def SignlessIntegerOrFloatLike : TypeConstraint, "signless-integer-like or floating-point-like">; +def DummyConstraint : AnyTypeOf<[AnyInteger, Index, AnyFloat]> { + let cppFunctionName = "isValidDummy"; +} + #endif // COMMON_TYPE_CONSTRAINTS_TD diff --git a/projects/eudsl-tblgen/tests/test_bindings.py b/projects/eudsl-tblgen/tests/test_bindings.py index 8874b45a..7f3c688a 100644 --- a/projects/eudsl-tblgen/tests/test_bindings.py +++ b/projects/eudsl-tblgen/tests/test_bindings.py @@ -6,7 +6,12 @@ from pathlib import Path import pytest -from eudsl_tblgen import RecordKeeper +from eudsl_tblgen import ( + RecordKeeper, + get_requested_op_definitions, + get_all_type_constraints, + collect_all_defs, +) @pytest.fixture(scope="function") @@ -15,18 +20,18 @@ def json_record_keeper(): def test_json_record_keeper(json_record_keeper): - assert json_record_keeper.input_filename == str( + assert json_record_keeper.get_input_filename() == str( Path(__file__).parent / "td" / "JSON.td" ) - assert set(json_record_keeper.classes) == { + assert set(json_record_keeper.get_classes()) == { "Base", "Derived", "Intermediate", "Variables", } - assert set(json_record_keeper.defs.keys()) == { + assert set(json_record_keeper.get_defs().keys()) == { "D", "ExampleDagOp", "FieldKeywordTest", @@ -41,46 +46,49 @@ def test_json_record_keeper(json_record_keeper): assert len(json_record_keeper.get_all_derived_definitions("Intermediate")) == 1 assert len(json_record_keeper.get_all_derived_definitions("Derived")) == 0 - assert json_record_keeper.get_all_derived_definitions("Base")[0].name == "D" - assert json_record_keeper.get_all_derived_definitions("Intermediate")[0].name == "D" + assert json_record_keeper.get_all_derived_definitions("Base")[0].get_name() == "D" + assert ( + json_record_keeper.get_all_derived_definitions("Intermediate")[0].get_name() + == "D" + ) def test_record(json_record_keeper): - assert json_record_keeper.classes["Base"] - assert json_record_keeper.classes["Intermediate"] - assert json_record_keeper.classes["Derived"] - assert json_record_keeper.classes["Variables"] - - base_cl = json_record_keeper.classes["Base"] - interm_cl = json_record_keeper.classes["Intermediate"] - deriv_cl = json_record_keeper.classes["Derived"] - variab_cl = json_record_keeper.classes["Variables"] - - assert len(base_cl.direct_super_classes) == 0 - assert len(interm_cl.direct_super_classes) == 1 - assert len(deriv_cl.direct_super_classes) == 1 - assert len(variab_cl.direct_super_classes) == 0 - - assert interm_cl.direct_super_classes[0].name == "Base" - assert deriv_cl.direct_super_classes[0].name == "Intermediate" - - assert base_cl.name == "Base" - assert base_cl.name_init_as_string == "Base" - assert base_cl.records is json_record_keeper - assert base_cl.type - - assert repr(base_cl.values) == "RecordValues()" + assert json_record_keeper.get_classes()["Base"] + assert json_record_keeper.get_classes()["Intermediate"] + assert json_record_keeper.get_classes()["Derived"] + assert json_record_keeper.get_classes()["Variables"] + + base_cl = json_record_keeper.get_classes()["Base"] + interm_cl = json_record_keeper.get_classes()["Intermediate"] + deriv_cl = json_record_keeper.get_classes()["Derived"] + variab_cl = json_record_keeper.get_classes()["Variables"] + + assert len(base_cl.get_direct_super_classes()) == 0 + assert len(interm_cl.get_direct_super_classes()) == 1 + assert len(deriv_cl.get_direct_super_classes()) == 1 + assert len(variab_cl.get_direct_super_classes()) == 0 + + assert interm_cl.get_direct_super_classes()[0].get_name() == "Base" + assert deriv_cl.get_direct_super_classes()[0].get_name() == "Intermediate" + + assert base_cl.get_name() == "Base" + assert base_cl.get_name_init_as_string() == "Base" + assert base_cl.get_records() is json_record_keeper + assert base_cl.get_type() + + assert repr(base_cl.get_values()) == "RecordValues()" assert ( - repr(variab_cl.values) + repr(variab_cl.get_values()) == "RecordValues(i=?, s=?, b=?, bs={ ?, ?, ?, ?, ?, ?, ?, ? }, c=?, li=?, base=?, d=?)" ) - assert interm_cl.has_direct_super_class(interm_cl.direct_super_classes[0]) + assert interm_cl.has_direct_super_class(interm_cl.get_direct_super_classes()[0]) assert interm_cl.has_direct_super_class(base_cl) - assert base_cl.is_anonymous is False - assert base_cl.is_class is True - assert base_cl.is_multi_class is False + assert base_cl.is_anonymous() is False + assert base_cl.is_class() is True + assert base_cl.is_multi_class() is False assert interm_cl.is_sub_class_of(base_cl) assert not interm_cl.is_sub_class_of(variab_cl) @@ -88,20 +96,20 @@ def test_record(json_record_keeper): def test_record_val_classes(json_record_keeper): - variab_cl = json_record_keeper.classes["Variables"] + variab_cl = json_record_keeper.get_classes()["Variables"] assert variab_cl.get_value("i") i_val = variab_cl.get_value("i") - assert i_val.name == "i" - assert i_val.name_init_as_string == "i" - assert i_val.print_type == "int" - assert i_val.record_keeper is json_record_keeper - assert i_val.is_nonconcrete_ok is False - assert i_val.is_template_arg is False - assert i_val.is_used is False + assert i_val.get_name() == "i" + assert i_val.get_name_init_as_string() == "i" + assert i_val.get_print_type() == "int" + assert i_val.get_record_keeper() is json_record_keeper + assert i_val.is_nonconcrete_ok() is False + assert i_val.is_template_arg() is False + assert i_val.is_used() is False def test_record_val_defs(json_record_keeper): - var_prim_def = json_record_keeper.defs["VarPrim"] + var_prim_def = json_record_keeper.get_defs()["VarPrim"] assert var_prim_def.get_value_as_int("i") == 3 assert var_prim_def.get_value_as_int("enormous_pos") == 9123456789123456789 assert var_prim_def.get_value_as_int("enormous_neg") == -9123456789123456789 @@ -112,31 +120,31 @@ def test_record_val_defs(json_record_keeper): def test_init(json_record_keeper): - variab_cl = json_record_keeper.classes["Variables"] + variab_cl = json_record_keeper.get_classes()["Variables"] assert variab_cl.get_value("i") - assert variab_cl.get_value("i").value - i_val_init = variab_cl.get_value("i").value + assert variab_cl.get_value("i").get_value() + i_val_init = variab_cl.get_value("i").get_value() assert str(i_val_init) == "?" - assert i_val_init.as_string == "?" + assert i_val_init.get_as_string() == "?" assert i_val_init.is_complete() is False assert i_val_init.is_concrete() is True def test_record_rec_ty(json_record_keeper): - base_cl = json_record_keeper.classes["Base"] - interm_cl = json_record_keeper.classes["Intermediate"] - deriv_cl = json_record_keeper.classes["Derived"] + base_cl = json_record_keeper.get_classes()["Base"] + interm_cl = json_record_keeper.get_classes()["Intermediate"] + deriv_cl = json_record_keeper.get_classes()["Derived"] - assert not base_cl.type.classes - assert interm_cl.type.classes - assert deriv_cl.type.classes - assert len(interm_cl.type.classes) == 1 - assert len(deriv_cl.type.classes) == 1 - assert interm_cl.type.classes[0].name == "Base" - assert deriv_cl.type.classes[0].name == "Intermediate" + assert not base_cl.get_type().get_classes() + assert interm_cl.get_type().get_classes() + assert deriv_cl.get_type().get_classes() + assert len(interm_cl.get_type().get_classes()) == 1 + assert len(deriv_cl.get_type().get_classes()) == 1 + assert interm_cl.get_type().get_classes()[0].get_name() == "Base" + assert deriv_cl.get_type().get_classes()[0].get_name() == "Intermediate" - assert interm_cl.type.is_sub_class_of(base_cl) - assert deriv_cl.type.is_sub_class_of(interm_cl) + assert interm_cl.get_type().is_sub_class_of(base_cl) + assert deriv_cl.get_type().is_sub_class_of(interm_cl) @pytest.fixture(scope="function") @@ -148,76 +156,92 @@ def record_keeper_test_dialect(): def test_init_complex(record_keeper_test_dialect): - op = record_keeper_test_dialect.defs["Test_TypesOp"] - assert str(op.values.opName) == "types" - assert str(op.values.cppNamespace) == "test" - assert str(op.values.opDocGroup) == "?" - assert str(op.values.results) == "(outs)" - assert str(op.values.regions) == "(region)" - assert str(op.values.successors) == "(successor)" - assert str(op.values.builders) == "?" - assert bool(op.values.skipDefaultBuilders.value) is False - assert str(op.values.assemblyFormat) == "?" - assert bool(op.values.hasCustomAssemblyFormat.value) is False - assert bool(op.values.hasVerifier.value) is False - assert bool(op.values.hasRegionVerifier.value) is False - assert bool(op.values.hasCanonicalizer.value) is False - assert bool(op.values.hasCanonicalizeMethod.value) is False - assert bool(op.values.hasFolder.value) is False - assert bool(op.values.useCustomPropertiesEncoding.value) is False - assert len(op.values.traits.value) == 0 - assert str(op.values.extraClassDeclaration) == "?" - assert str(op.values.extraClassDefinition) == "?" + op = record_keeper_test_dialect.get_defs()["Test_TypesOp"] + assert str(op.get_values().opName) == "types" + assert str(op.get_values().cppNamespace) == "test" + assert str(op.get_values().opDocGroup) == "?" + assert str(op.get_values().results) == "(outs)" + assert str(op.get_values().regions) == "(region)" + assert str(op.get_values().successors) == "(successor)" + assert str(op.get_values().builders) == "?" + assert bool(op.get_values().skipDefaultBuilders.get_value()) is False + assert str(op.get_values().assemblyFormat) == "?" + assert bool(op.get_values().hasCustomAssemblyFormat.get_value()) is False + assert bool(op.get_values().hasVerifier.get_value()) is False + assert bool(op.get_values().hasRegionVerifier.get_value()) is False + assert bool(op.get_values().hasCanonicalizer.get_value()) is False + assert bool(op.get_values().hasCanonicalizeMethod.get_value()) is False + assert bool(op.get_values().hasFolder.get_value()) is False + assert bool(op.get_values().useCustomPropertiesEncoding.get_value()) is False + assert len(op.get_values().traits.get_value()) == 0 + assert str(op.get_values().extraClassDeclaration) == "?" + assert str(op.get_values().extraClassDefinition) == "?" assert ( - repr(op.values) - == "RecordValues(opDialect=Test_Dialect, opName=types, cppNamespace=test, summary=, description=, opDocGroup=?, arguments=(ins I32:$a, SI64:$b, UI8:$c, Index:$d, F32:$e, NoneType:$f, anonymous_347), results=(outs), regions=(region), successors=(successor), builders=?, skipDefaultBuilders=0, assemblyFormat=?, hasCustomAssemblyFormat=0, hasVerifier=0, hasRegionVerifier=0, hasCanonicalizer=0, hasCanonicalizeMethod=0, hasFolder=0, useCustomPropertiesEncoding=0, traits=[], extraClassDeclaration=?, extraClassDefinition=?)" + repr(op.get_values()) + == "RecordValues(opDialect=Test_Dialect, opName=types, cppNamespace=test, summary=, description=, opDocGroup=?, arguments=(ins I32:$a, SI64:$b, UI8:$c, Index:$d, F32:$e, NoneType:$f, anonymous_348), results=(outs), regions=(region), successors=(successor), builders=?, skipDefaultBuilders=0, assemblyFormat=?, hasCustomAssemblyFormat=0, hasVerifier=0, hasRegionVerifier=0, hasCanonicalizer=0, hasCanonicalizeMethod=0, hasFolder=0, useCustomPropertiesEncoding=0, traits=[], extraClassDeclaration=?, extraClassDefinition=?)" ) - arguments = op.values.arguments - assert arguments.value.get_arg_name_str(0) == "a" - assert arguments.value.get_arg_name_str(1) == "b" - assert arguments.value.get_arg_name_str(2) == "c" - assert arguments.value.get_arg_name_str(3) == "d" - assert arguments.value.get_arg_name_str(4) == "e" - assert arguments.value.get_arg_name_str(5) == "f" - - assert str(arguments.value[0]) == "I32" - assert str(arguments.value[1]) == "SI64" - assert str(arguments.value[2]) == "UI8" - assert str(arguments.value[3]) == "Index" - assert str(arguments.value[4]) == "F32" - assert str(arguments.value[5]) == "NoneType" - - attr = record_keeper_test_dialect.defs["Test_TestAttr"] - assert str(attr.values.predicate) == "anonymous_334" - assert str(attr.values.storageType) == "test::TestAttr" - assert str(attr.values.returnType) == "test::TestAttr" + arguments = op.get_values().arguments + assert arguments.get_value().get_arg_name_str(0) == "a" + assert arguments.get_value().get_arg_name_str(1) == "b" + assert arguments.get_value().get_arg_name_str(2) == "c" + assert arguments.get_value().get_arg_name_str(3) == "d" + assert arguments.get_value().get_arg_name_str(4) == "e" + assert arguments.get_value().get_arg_name_str(5) == "f" + + assert str(arguments.get_value()[0]) == "I32" + assert str(arguments.get_value()[1]) == "SI64" + assert str(arguments.get_value()[2]) == "UI8" + assert str(arguments.get_value()[3]) == "Index" + assert str(arguments.get_value()[4]) == "F32" + assert str(arguments.get_value()[5]) == "NoneType" + + attr = record_keeper_test_dialect.get_defs()["Test_TestAttr"] + assert str(attr.get_values().predicate) == "anonymous_335" + assert str(attr.get_values().storageType) == "test::TestAttr" + assert str(attr.get_values().returnType) == "test::TestAttr" assert ( - str(attr.values.convertFromStorage.value) + str(attr.get_values().convertFromStorage.get_value()) == "::llvm::cast($_self)" ) - assert str(attr.values.constBuilderCall) == "?" - assert str(attr.values.defaultValue) == "?" - assert str(attr.values.valueType) == "?" - assert bool(attr.values.isOptional.value) is False - assert str(attr.values.baseAttr) == "?" - assert str(attr.values.cppNamespace) == "test" - assert str(attr.values.dialect) == "Test_Dialect" - assert str(attr.values.cppBaseClassName.value) == "::mlir::Attribute" - assert str(attr.values.storageClass) == "TestAttrStorage" - assert str(attr.values.storageNamespace) == "detail" - assert bool(attr.values.genStorageClass.value) is True - assert bool(attr.values.hasStorageCustomConstructor.value) is False - assert str(attr.values.parameters) == "(ins)" - assert str(attr.values.builders) == "?" - assert len(attr.values.traits.value) == 0 - assert str(attr.values.mnemonic) == "test" - assert str(attr.values.assemblyFormat) == "?" - assert bool(attr.values.hasCustomAssemblyFormat.value) is False - assert bool(attr.values.genAccessors.value) is True - assert bool(attr.values.skipDefaultBuilders.value) is False - assert bool(attr.values.genVerifyDecl.value) is False - assert str(attr.values.cppClassName) == "TestAttr" - assert str(attr.values.cppType) == "test::TestAttr" - assert str(attr.values.attrName) == "test.test" + assert str(attr.get_values().constBuilderCall) == "?" + assert str(attr.get_values().defaultValue) == "?" + assert str(attr.get_values().valueType) == "?" + assert bool(attr.get_values().isOptional.get_value()) is False + assert str(attr.get_values().baseAttr) == "?" + assert str(attr.get_values().cppNamespace) == "test" + assert str(attr.get_values().dialect) == "Test_Dialect" + assert str(attr.get_values().cppBaseClassName.get_value()) == "::mlir::Attribute" + assert str(attr.get_values().storageClass) == "TestAttrStorage" + assert str(attr.get_values().storageNamespace) == "detail" + assert bool(attr.get_values().genStorageClass.get_value()) is True + assert bool(attr.get_values().hasStorageCustomConstructor.get_value()) is False + assert str(attr.get_values().parameters) == "(ins)" + assert str(attr.get_values().builders) == "?" + assert len(attr.get_values().traits.get_value()) == 0 + assert str(attr.get_values().mnemonic) == "test" + assert str(attr.get_values().assemblyFormat) == "?" + assert bool(attr.get_values().hasCustomAssemblyFormat.get_value()) is False + assert bool(attr.get_values().genAccessors.get_value()) is True + assert bool(attr.get_values().skipDefaultBuilders.get_value()) is False + assert bool(attr.get_values().genVerifyDecl.get_value()) is False + assert str(attr.get_values().cppClassName) == "TestAttr" + assert str(attr.get_values().cppType) == "test::TestAttr" + assert str(attr.get_values().attrName) == "test.test" + + +def test_mlir_tblgen(record_keeper_test_dialect): + for op in get_requested_op_definitions(record_keeper_test_dialect): + print(op.get_name()) + for constraint in get_all_type_constraints(record_keeper_test_dialect): + print(constraint.get_def_name()) + print(constraint.get_summary()) + + all_defs = collect_all_defs(record_keeper_test_dialect) + for d in all_defs: + print(d.get_name()) + + all_defs = collect_all_defs(record_keeper_test_dialect, "test") + for d in all_defs: + print(d.get_name()) diff --git a/requirements.txt b/requirements-dev.txt similarity index 81% rename from requirements.txt rename to requirements-dev.txt index bf532c32..de1d0061 100644 --- a/requirements.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ +litgen @ git+https://github.com/pthom/litgen@f5d154c6f7679e755baa1047563d7c340309bc00 nanobind==2.4.0 numpy==2.0.2 -litgen @ git+https://github.com/pthom/litgen@f5d154c6f7679e755baa1047563d7c340309bc00 +scikit-build-core==0.10.7