From 1da6b314e8fd9babf832e2015f9b6d4285af3d79 Mon Sep 17 00:00:00 2001 From: minhthuc Date: Wed, 10 Jan 2024 12:24:39 +0100 Subject: [PATCH 1/4] tensor parallel support --- CMakeLists.txt | 36 ++- cmake/FindNCCL.cmake | 28 ++ docker/Dockerfile | 20 +- include/ctranslate2/devices.h | 30 +++ include/ctranslate2/layers/attention.h | 4 +- include/ctranslate2/layers/common.h | 4 +- include/ctranslate2/layers/transformer.h | 3 + include/ctranslate2/models/model.h | 12 +- include/ctranslate2/ops/nccl_ops.h | 35 +++ include/ctranslate2/ops/ops.h | 1 + include/ctranslate2/replica_pool.h | 2 + include/ctranslate2/utils.h | 3 + python/cpp/encoder.cc | 6 +- python/cpp/generator.cc | 6 +- python/cpp/module.cc | 1 + python/cpp/module.h | 1 + python/cpp/mpi.cc | 30 +++ python/cpp/replica_pool.h | 6 + python/cpp/translator.cc | 8 +- python/cpp/wav2vec2.cc | 6 +- python/cpp/whisper.cc | 6 +- python/ctranslate2/__init__.py | 1 + python/ctranslate2/specs/model_spec.py | 3 + python/ctranslate2/specs/transformer_spec.py | 13 + .../tools/prepare_build_environment_linux.sh | 14 +- src/cuda/mpi_stub.cc | 94 +++++++ src/cuda/mpi_stub.h | 18 ++ src/cuda/nccl_stub.cc | 93 +++++++ src/cuda/utils.h | 22 ++ src/devices.cc | 99 +++++++ src/layers/attention.cc | 31 ++- src/layers/common.cc | 39 ++- src/layers/transformer.cc | 42 ++- src/models/model.cc | 254 +++++++++++++++++- src/ops/nccl_ops.cc | 23 ++ src/ops/nccl_ops_cpu.cc | 23 ++ src/ops/nccl_ops_gpu.cu | 93 +++++++ src/utils.cc | 1 - tools/benchmark_tensor_parallel/README.md | 18 ++ tools/benchmark_tensor_parallel/benchmark.py | 172 ++++++++++++ .../requirements.txt | 3 + 41 files changed, 1263 insertions(+), 41 deletions(-) create mode 100644 cmake/FindNCCL.cmake create mode 100644 include/ctranslate2/ops/nccl_ops.h create mode 100644 python/cpp/mpi.cc create mode 100644 src/cuda/mpi_stub.cc create mode 100644 src/cuda/mpi_stub.h create mode 100644 src/cuda/nccl_stub.cc create mode 100644 src/ops/nccl_ops.cc create mode 100644 src/ops/nccl_ops_cpu.cc create mode 100644 src/ops/nccl_ops_gpu.cu create mode 100644 tools/benchmark_tensor_parallel/README.md create mode 100644 tools/benchmark_tensor_parallel/benchmark.py create mode 100644 tools/benchmark_tensor_parallel/requirements.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 1089106cc..a32a45fe7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,6 +20,7 @@ option(ENABLE_PROFILING "Compile with profiling support" OFF) option(BUILD_CLI "Compile the clients" ON) option(BUILD_TESTS "Compile the tests" OFF) option(BUILD_SHARED_LIBS "Build shared libraries" ON) +option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF) if(ENABLE_PROFILING) message(STATUS "Enable profiling support") @@ -179,6 +180,8 @@ set(SOURCES src/ops/topp_mask.cc src/ops/topp_mask_cpu.cc src/ops/transpose.cc + src/ops/nccl_ops.cc + src/ops/nccl_ops_cpu.cc src/padder.cc src/profiler.cc src/random.cc @@ -191,7 +194,7 @@ set(SOURCES src/utils.cc src/vocabulary.cc src/vocabulary_map.cc - ) +) set(LIBRARIES ${CMAKE_THREAD_LIBS_INIT} spdlog::spdlog_header_only @@ -419,6 +422,24 @@ endif() if (WITH_CUDA) find_package(CUDA 11.0 REQUIRED) + list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) + if (WITH_TENSOR_PARALLEL) + find_package(MPI REQUIRED) + find_package(NCCL REQUIRED) + include_directories(${NCCL_INCLUDE_DIR}) + include_directories(${MPI_INCLUDE_PATH}) + if(CUDA_DYNAMIC_LOADING) + list(APPEND SOURCES src/cuda/mpi_stub.cc) + list(APPEND SOURCES src/cuda/nccl_stub.cc) + add_definitions(-DCT2_WITH_CUDA_DYNAMIC_LOADING) + 
else () + list(APPEND LIBRARIES ${NCCL_LIBRARY}) + list(APPEND LIBRARIES ${MPI_LIBRARIES}) + endif () + add_definitions(-DCT2_WITH_TENSOR_PARALLEL) + endif () + include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include) + add_definitions(-DCT2_WITH_CUDA) if(MSVC) if(BUILD_SHARED_LIBS) @@ -522,7 +543,8 @@ if (WITH_CUDA) src/ops/topk_gpu.cu src/ops/topp_mask_gpu.cu src/ops/quantize_gpu.cu - ) + src/ops/nccl_ops_gpu.cu + ) elseif(WITH_CUDNN) message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON") else() @@ -546,6 +568,10 @@ target_include_directories(${PROJECT_NAME} BEFORE PRIVATE ${PRIVATE_INCLUDE_DIRECTORIES} ) +if (WITH_TENSOR_PARALLEL AND CUDA_DYNAMIC_LOADING) + target_compile_options(${PROJECT_NAME} PRIVATE -DOMPI_SKIP_MPICXX) +endif() + if(BUILD_TESTS) add_subdirectory(tests) endif() @@ -587,6 +613,11 @@ configure_file(cmake/${PROJECT_NAME}Config.cmake COPYONLY ) +configure_file(cmake/FindNCCL.cmake + "${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}/FindNCCL.cmake" + COPYONLY +) + set(ConfigPackageLocation ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}) if(BUILD_SHARED_LIBS) @@ -603,6 +634,7 @@ endif() install( FILES cmake/${PROJECT_NAME}Config.cmake + cmake/FindNCCL.cmake "${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}/${PROJECT_NAME}ConfigVersion.cmake" DESTINATION ${ConfigPackageLocation} diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake new file mode 100644 index 000000000..c5f0e31e8 --- /dev/null +++ b/cmake/FindNCCL.cmake @@ -0,0 +1,28 @@ +# Find the NCCL libraries +# +# The following variables are optionally searched for defaults +# NCCL_ROOT_DIR: Base directory where all NCCL components are found +# +# The following are set after configuration is done: +# NCCL_FOUND +# NCCL_INCLUDE_DIR +# NCCL_LIBRARY + +find_path(NCCL_INCLUDE_DIR NAMES nccl.h + PATHS ${NCCL_ROOT_DIR}/include +) + +find_library(NCCL_LIBRARY NAMES nccl + PATHS ${NCCL_ROOT_DIR}/lib ${NCCL_ROOT_DIR}/lib64) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR + NCCL_LIBRARY) + +if (NCCL_FOUND) + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIR}, library: + ${NCCL_LIBRARY})") + mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY) + set(NCCL_VERSION "${NCCL_MAJOR}.${NCCL_MINOR}.${NCCL_PATCH}") + +endif () diff --git a/docker/Dockerfile b/docker/Dockerfile index bfc7dfcbf..c1d5a47fb 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -35,6 +35,16 @@ RUN wget -q https://github.com/oneapi-src/oneDNN/archive/refs/tags/v${ONEDNN_VER cd .. && \ rm -r oneDNN-* +ENV OPENMPI_VERSION=4.1.6 +RUN wget -q https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-${OPENMPI_VERSION}.tar.bz2 && \ + tar xf *.tar.bz2 && \ + rm *.tar.bz2 && \ + cd openmpi-* && \ + ./configure && \ + make -j$(nproc) install && \ + cd .. && \ + rm -r openmpi-* + COPY third_party third_party COPY cli cli COPY include include @@ -50,13 +60,14 @@ ENV CUDA_NVCC_FLAGS=${CUDA_NVCC_FLAGS:-"-Xfatbin=-compress-all"} ARG CUDA_ARCH_LIST ENV CUDA_ARCH_LIST=${CUDA_ARCH_LIST:-"Common"} ENV CTRANSLATE2_ROOT=/opt/ctranslate2 +ENV LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH} -RUN mkdir build && \ - cd build && \ +RUN mkdir build_tmp && \ + cd build_tmp && \ cmake -DCMAKE_INSTALL_PREFIX=${CTRANSLATE2_ROOT} \ -DWITH_CUDA=ON -DWITH_CUDNN=ON -DWITH_MKL=ON -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP \ -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="${CXX_FLAGS}" \ - -DCUDA_NVCC_FLAGS="${CUDA_NVCC_FLAGS}" -DCUDA_ARCH_LIST="${CUDA_ARCH_LIST}" .. 
&& \ + -DCUDA_NVCC_FLAGS="${CUDA_NVCC_FLAGS}" -DCUDA_ARCH_LIST="${CUDA_ARCH_LIST}" -DWITH_TENSOR_PARALLEL=ON .. && \ VERBOSE=1 make -j$(nproc) install ENV LANG=en_US.UTF-8 @@ -74,6 +85,9 @@ RUN apt-get update && \ apt-get install -y --no-install-recommends \ libcublas-12-2 \ libcudnn8=8.9.7.29-1+cuda12.2 \ + libnccl2=2.19.3-1+cuda12.2 \ + libopenmpi3=4.0.3-0ubuntu1 \ + openmpi-bin \ libgomp1 \ python3-pip \ && \ diff --git a/include/ctranslate2/devices.h b/include/ctranslate2/devices.h index 2691efc3a..674713b8f 100644 --- a/include/ctranslate2/devices.h +++ b/include/ctranslate2/devices.h @@ -2,6 +2,10 @@ #include #include +#include +#ifdef CT2_WITH_TENSOR_PARALLEL +# include +#endif namespace ctranslate2 { @@ -45,4 +49,30 @@ namespace ctranslate2 { int _new_index; }; + extern int my_rank; + extern int local_rank; + extern int n_ranks; + + class ScopedMPISetter { + public: + ScopedMPISetter(); + ~ScopedMPISetter(); + + static int getNRanks(); + static int getCurRank(); + static int getLocalRank(); + +#ifdef CT2_WITH_TENSOR_PARALLEL + static ncclComm_t getNcclComm(); +#endif + + static void finalize(); + + private: +#ifdef CT2_WITH_TENSOR_PARALLEL + static uint64_t getHostHash(const char *string); + static void getHostName(char *hostname, int maxlen); + static std::vector _nccl_comms; +#endif + }; } diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h index b342f4faa..d2deb5e03 100644 --- a/include/ctranslate2/layers/attention.h +++ b/include/ctranslate2/layers/attention.h @@ -43,7 +43,7 @@ namespace ctranslate2 { } bool multi_query() const { - return _num_heads_kv == 1; + return _multi_query; } static StorageView prepare_length_mask(const StorageView& lengths, @@ -53,6 +53,7 @@ namespace ctranslate2 { const bool multi_query = false); private: + const bool _tensor_parallel; const dim_t _num_heads; const bool _self_attention; const bool _is_decoder; @@ -68,6 +69,7 @@ namespace ctranslate2 { const StorageView* _relative_position_values; dim_t _maximum_relative_position; const float _queries_scale; + const bool _multi_query; const dim_t _num_heads_kv; const bool _merge_time_and_head_dims; const dim_t _cache_time_dim; diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h index cb8586b78..d06c6da6c 100644 --- a/include/ctranslate2/layers/common.h +++ b/include/ctranslate2/layers/common.h @@ -127,7 +127,8 @@ namespace ctranslate2 { public: Dense(const models::Model& model, const std::string& scope, - const ops::ActivationType* activation_type = nullptr); + const ops::ActivationType* activation_type = nullptr, + const bool affected_by_tp = false); DataType output_type() const override; dim_t output_size() const override; void operator()(const StorageView& input, StorageView& output) const; @@ -147,6 +148,7 @@ namespace ctranslate2 { const ops::Gemm _gemm_op; const ops::Quantize _quantize_op; const ops::Dequantize _dequantize_op; + const bool _affected_by_tp; }; class LayerNorm : public Layer diff --git a/include/ctranslate2/layers/transformer.h b/include/ctranslate2/layers/transformer.h index 61b9fae47..a7183a30d 100644 --- a/include/ctranslate2/layers/transformer.h +++ b/include/ctranslate2/layers/transformer.h @@ -34,6 +34,7 @@ namespace ctranslate2 { const Dense _ff1; const std::unique_ptr _ff1_noact; const Dense _ff2; + const bool _tensor_parallel; }; class TransformerEncoderLayer : public Layer @@ -149,6 +150,7 @@ namespace ctranslate2 { const std::unique_ptr _output_norm; const std::vector> _layers; const 
std::unique_ptr _position_encoder; + const bool _tensor_parallel; }; class TransformerDecoder : public Decoder @@ -211,6 +213,7 @@ namespace ctranslate2 { bool _average_alignment_heads; Dense _proj; const dim_t _sliding_window; + const bool _tensor_parallel; }; } diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h index 43a4ea5b9..1bd7a4c14 100644 --- a/include/ctranslate2/models/model.h +++ b/include/ctranslate2/models/model.h @@ -26,11 +26,13 @@ namespace ctranslate2 { static std::shared_ptr load(const std::string& path, Device device = Device::CPU, int device_index = 0, - ComputeType compute_type = ComputeType::DEFAULT); + ComputeType compute_type = ComputeType::DEFAULT, + bool tensor_parallel = false); static std::shared_ptr load(ModelReader& model_reader, Device device = Device::CPU, int device_index = 0, - ComputeType compute_type = ComputeType::DEFAULT); + ComputeType compute_type = ComputeType::DEFAULT, + bool tensor_parallel = false); virtual std::unique_ptr as_sequence_to_sequence() const; virtual std::unique_ptr as_sequence_generator() const; @@ -78,6 +80,10 @@ namespace ctranslate2 { return _binary_version >= 5; } + bool tensor_parallel() const { + return _tensor_parallel; + } + virtual bool use_global_int16_scale() const { return true; } @@ -163,6 +169,7 @@ namespace ctranslate2 { ComputeType _effective_compute_type = ComputeType::DEFAULT; dim_t _preferred_size_multiple = 1; std::unordered_map> _variable_index; + bool _tensor_parallel = false; }; template<> @@ -191,6 +198,7 @@ namespace ctranslate2 { std::vector device_indices = {0}; size_t num_replicas_per_device = 1; ComputeType compute_type = ComputeType::DEFAULT; + bool tensor_parallel = false; }; // Base class for replicas. diff --git a/include/ctranslate2/ops/nccl_ops.h b/include/ctranslate2/ops/nccl_ops.h new file mode 100644 index 000000000..d610d972f --- /dev/null +++ b/include/ctranslate2/ops/nccl_ops.h @@ -0,0 +1,35 @@ +#pragma once + +#include "op.h" + +namespace ctranslate2 { + namespace ops { + class ReduceAll : public Op { + public: + enum class RED_OP { + SUM, + PROD, + MIN, + MAX, + AVG + }; + + explicit ReduceAll(RED_OP op = RED_OP::SUM); + void operator()(const StorageView& input, StorageView& output) const; + private: + RED_OP _reduce_op; + + template + void compute(const StorageView& input, StorageView& output) const; + }; + + class GatherAll : public Op { + public: + explicit GatherAll(); + void operator()(const StorageView& input, StorageView& output) const; + private: + template + void compute(const StorageView& input, StorageView& output) const; + }; + } +} \ No newline at end of file diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h index 051c81acc..f03d0211a 100644 --- a/include/ctranslate2/ops/ops.h +++ b/include/ctranslate2/ops/ops.h @@ -37,3 +37,4 @@ #include "rotary.h" #include "alibi_add.h" #include "slide.h" +#include "nccl_ops.h" diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h index efc9824d1..8c8e15d8e 100644 --- a/include/ctranslate2/replica_pool.h +++ b/include/ctranslate2/replica_pool.h @@ -34,11 +34,13 @@ namespace ctranslate2 { const Device device, const ComputeType compute_type = ComputeType::DEFAULT, const std::vector& device_indices = {0}, + const bool tensor_parallel = false, const ReplicaPoolConfig& config = {}) { models::ModelLoader model_loader(model_path); model_loader.device = device; model_loader.device_indices = device_indices; model_loader.compute_type = compute_type; + 
model_loader.tensor_parallel = tensor_parallel; initialize_pool(model_loader, config); } diff --git a/include/ctranslate2/utils.h b/include/ctranslate2/utils.h index c8e7ef78b..23c58cb82 100644 --- a/include/ctranslate2/utils.h +++ b/include/ctranslate2/utils.h @@ -4,6 +4,7 @@ #include #include #include +#include "ctranslate2/types.h" namespace ctranslate2 { @@ -92,5 +93,7 @@ namespace ctranslate2 { #endif #define THROW_RUNTIME_ERROR(MESSAGE) THROW_EXCEPTION(std::runtime_error, MESSAGE) #define THROW_INVALID_ARGUMENT(MESSAGE) THROW_EXCEPTION(std::invalid_argument, MESSAGE) +#define SAFE_DIVIDE(x, y) ((y != 0 && (x % y == 0)) ? (x / y) : (throw std::runtime_error("Division has a remainder," \ + "Model can't be ran with the tensor parallel mode in " + std::to_string(y) + " nodes"))) } diff --git a/python/cpp/encoder.cc b/python/cpp/encoder.cc index ea8b1a430..9a50923ac 100644 --- a/python/cpp/encoder.cc +++ b/python/cpp/encoder.cc @@ -71,7 +71,7 @@ namespace ctranslate2 { >>> encoder.forward_batch([["▁Hello", "▁world", "!"]]) )pbdoc") - .def(py::init>&, const StringOrMap&, size_t, size_t, long, py::object>(), + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(), py::arg("model_path"), py::arg("device")="cpu", py::kw_only(), @@ -80,6 +80,7 @@ namespace ctranslate2 { py::arg("inter_threads")=1, py::arg("intra_threads")=0, py::arg("max_queued_batches")=0, + py::arg("tensor_parallel")=false, py::arg("files")=py::none(), R"pbdoc( Initializes the encoder. @@ -96,6 +97,7 @@ namespace ctranslate2 { max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, 0 for an automatic value). When the queue is full, future requests will block until a free slot is available. + tensor_parallel: run model with tensor parallel mode files: Load model files from the memory. This argument is a dictionary mapping file names to file contents as file-like or bytes objects. If this is set, :obj:`model_path` acts as an identifier for this model. @@ -111,6 +113,8 @@ namespace ctranslate2 { "Number of encoders backing this instance.") .def_property_readonly("num_queued_batches", &EncoderWrapper::num_queued_batches, "Number of batches waiting to be processed.") + .def_property_readonly("tensor_parallel", &EncoderWrapper::tensor_parallel, + "Run model with tensor parallel mode.") .def_property_readonly("num_active_batches", &EncoderWrapper::num_active_batches, "Number of batches waiting to be processed or currently processed.") diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc index 981c6da68..93b1a229a 100644 --- a/python/cpp/generator.cc +++ b/python/cpp/generator.cc @@ -128,7 +128,7 @@ namespace ctranslate2 { >>> generator.generate_batch([[""]], max_length=50, sampling_topk=20) )pbdoc") - .def(py::init>&, const StringOrMap&, size_t, size_t, long, py::object>(), + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(), py::arg("model_path"), py::arg("device")="cpu", py::kw_only(), @@ -137,6 +137,7 @@ namespace ctranslate2 { py::arg("inter_threads")=1, py::arg("intra_threads")=0, py::arg("max_queued_batches")=0, + py::arg("tensor_parallel")=false, py::arg("files")=py::none(), R"pbdoc( Initializes the generator. @@ -153,6 +154,7 @@ namespace ctranslate2 { max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, 0 for an automatic value). When the queue is full, future requests will block until a free slot is available. + tensor_parallel: run model with tensor parallel mode. files: Load model files from the memory. 
This argument is a dictionary mapping file names to file contents as file-like or bytes objects. If this is set, :obj:`model_path` acts as an identifier for this model. )pbdoc") @@ -168,6 +170,8 @@ "Number of generators backing this instance.") .def_property_readonly("num_queued_batches", &GeneratorWrapper::num_queued_batches, "Number of batches waiting to be processed.") + .def_property_readonly("tensor_parallel", &GeneratorWrapper::tensor_parallel, + "Run model with tensor parallel mode.") .def_property_readonly("num_active_batches", &GeneratorWrapper::num_active_batches, "Number of batches waiting to be processed or currently processed.") diff --git a/python/cpp/module.cc b/python/cpp/module.cc index 4a9e47561..4489d5314 100644 --- a/python/cpp/module.cc +++ b/python/cpp/module.cc @@ -87,4 +87,5 @@ PYBIND11_MODULE(_ext, m) ctranslate2::python::register_encoder(m); ctranslate2::python::register_whisper(m); ctranslate2::python::register_wav2vec2(m); + ctranslate2::python::register_mpi(m); } diff --git a/python/cpp/module.h b/python/cpp/module.h index 01fdbdf59..9c9a9a2ff 100644 --- a/python/cpp/module.h +++ b/python/cpp/module.h @@ -18,6 +18,7 @@ namespace ctranslate2 { void register_translator(py::module& m); void register_whisper(py::module& m); void register_wav2vec2(py::module& m); + void register_mpi(py::module& m); } } diff --git a/python/cpp/mpi.cc b/python/cpp/mpi.cc new file mode 100644 index 000000000..01abf1157 --- /dev/null +++ b/python/cpp/mpi.cc @@ -0,0 +1,30 @@ +#include "module.h" + +#include + +#include "utils.h" + +namespace ctranslate2 { + namespace python { + + void register_mpi(py::module& m) { + py::class_( + m, "MpiInfo", + R"pbdoc( + An object to manage the MPI communication between processes. + It provides information about the MPI connection.
+ )pbdoc") + + .def_static("getNRanks", &ScopedMPISetter::getNRanks, + "Get the number of GPUs running the current model.") + + .def_static("getCurRank", &ScopedMPISetter::getCurRank, + "Get the current rank of the process.") + + .def_static("getLocalRank", &ScopedMPISetter::getLocalRank, + "Get the current GPU id used by the process.") + ; + } + + } +} diff --git a/python/cpp/replica_pool.h b/python/cpp/replica_pool.h index a735ea363..d71bf6b96 100644 --- a/python/cpp/replica_pool.h +++ b/python/cpp/replica_pool.h @@ -44,6 +44,7 @@ namespace ctranslate2 { size_t inter_threads, size_t intra_threads, long max_queued_batches, + bool tensor_parallel, py::object files) : _model_loader(create_model_reader(model_path, files)) { @@ -53,6 +54,7 @@ namespace ctranslate2 { _model_loader.device_indices = std::visit(DeviceIndexResolver(), device_index); _model_loader.compute_type = std::visit(ComputeTypeResolver(device), compute_type); _model_loader.num_replicas_per_device = inter_threads; + _model_loader.tensor_parallel = tensor_parallel; _pool_config.num_threads_per_replica = intra_threads; _pool_config.max_queued_batches = max_queued_batches; @@ -77,6 +79,10 @@ namespace ctranslate2 { return compute_type_to_str(model()->effective_compute_type()); } + bool tensor_parallel() const { + return _model_loader.tensor_parallel; + } + size_t num_replicas() const { return _pool->num_replicas(); } diff --git a/python/cpp/translator.cc b/python/cpp/translator.cc index b46d7ab9e..8e4a8a4be 100644 --- a/python/cpp/translator.cc +++ b/python/cpp/translator.cc @@ -33,6 +33,7 @@ namespace ctranslate2 { size_t inter_threads, size_t intra_threads, long max_queued_batches, + bool tensor_parallel, py::object files) : ReplicaPoolHelper(model_path, device, @@ -41,6 +42,7 @@ namespace ctranslate2 { inter_threads, intra_threads, max_queued_batches, + tensor_parallel, files) , _device(_model_loader.device) , _device_index(_model_loader.device_indices) @@ -378,7 +380,7 @@ namespace ctranslate2 { >>> translator.translate_batch([["▁Hello", "▁world", "!"]]) )pbdoc") - .def(py::init>&, const StringOrMap&, size_t, size_t, long, py::object>(), + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(), py::arg("model_path"), py::arg("device")="cpu", py::kw_only(), @@ -387,6 +389,7 @@ namespace ctranslate2 { py::arg("inter_threads")=1, py::arg("intra_threads")=0, py::arg("max_queued_batches")=0, + py::arg("tensor_parallel")=false, py::arg("files")=py::none(), R"pbdoc( Initializes the translator. @@ -403,6 +406,7 @@ namespace ctranslate2 { max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited, 0 for an automatic value). When the queue is full, future requests will block until a free slot is available. + tensor_parallel: run model with tensor parallel mode files: Load model files from the memory. This argument is a dictionary mapping file names to file contents as file-like or bytes objects. If this is set, :obj:`model_path` acts as an identifier for this model.
@@ -418,6 +422,8 @@ namespace ctranslate2 { "Number of translators backing this instance.") .def_property_readonly("num_queued_batches", &TranslatorWrapper::num_queued_batches, "Number of batches waiting to be processed.") + .def_property_readonly("tensor_parallel", &TranslatorWrapper::tensor_parallel, + "Run model with tensor parallel mode.") .def_property_readonly("num_active_batches", &TranslatorWrapper::num_active_batches, "Number of batches waiting to be processed or currently processed.") diff --git a/python/cpp/wav2vec2.cc b/python/cpp/wav2vec2.cc index ced116cb4..343caa158 100644 --- a/python/cpp/wav2vec2.cc +++ b/python/cpp/wav2vec2.cc @@ -27,7 +27,7 @@ namespace ctranslate2 { https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec )pbdoc") - .def(py::init>&, const StringOrMap&, size_t, size_t, long, py::object>(), + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(), py::arg("model_path"), py::arg("device")="cpu", py::kw_only(), @@ -36,6 +36,7 @@ namespace ctranslate2 { py::arg("inter_threads")=1, py::arg("intra_threads")=0, py::arg("max_queued_batches")=0, + py::arg("tensor_parallel")=false, py::arg("files")=py::none(), R"pbdoc( Initializes a Wav2Vec2 model from a converted model. @@ -52,6 +53,7 @@ namespace ctranslate2 { max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited, 0 for an automatic value). When the queue is full, future requests will block until a free slot is available. + tensor_parallel: run model with tensor parallel mode files: Load model files from the memory. This argument is a dictionary mapping file names to file contents as file-like or bytes objects. If this is set, :obj:`model_path` acts as an identifier for this model. @@ -67,6 +69,8 @@ namespace ctranslate2 { "Number of model workers backing this instance.") .def_property_readonly("num_queued_batches", &Wav2Vec2Wrapper::num_queued_batches, "Number of batches waiting to be processed.") + .def_property_readonly("tensor_parallel", &Wav2Vec2Wrapper::tensor_parallel, + "Run model with tensor parallel mode.") .def_property_readonly("num_active_batches", &Wav2Vec2Wrapper::num_active_batches, "Number of batches waiting to be processed or currently processed.") diff --git a/python/cpp/whisper.cc b/python/cpp/whisper.cc index cb1b45a7d..47be8ece7 100644 --- a/python/cpp/whisper.cc +++ b/python/cpp/whisper.cc @@ -163,7 +163,7 @@ namespace ctranslate2 { .def_property_readonly("num_languages", &WhisperWrapper::num_languages, "Returns the number of languages supported.") - .def(py::init>&, const StringOrMap&, size_t, size_t, long, py::object>(), + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(), py::arg("model_path"), py::arg("device")="cpu", py::kw_only(), @@ -172,6 +172,7 @@ namespace ctranslate2 { py::arg("inter_threads")=1, py::arg("intra_threads")=0, py::arg("max_queued_batches")=0, + py::arg("tensor_parallel")=false, py::arg("files")=py::none(), R"pbdoc( Initializes a Whisper model from a converted model. @@ -188,6 +189,7 @@ namespace ctranslate2 { max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited, 0 for an automatic value). When the queue is full, future requests will block until a free slot is available. + tensor_parallel: run model with tensor parallel mode files: Load model files from the memory. This argument is a dictionary mapping file names to file contents as file-like or bytes objects. If this is set, :obj:`model_path` acts as an identifier for this model. 
@@ -203,6 +205,8 @@ namespace ctranslate2 { "Number of model workers backing this instance.") .def_property_readonly("num_queued_batches", &WhisperWrapper::num_queued_batches, "Number of batches waiting to be processed.") + .def_property_readonly("tensor_parallel", &WhisperWrapper::tensor_parallel, + "Run model with tensor parallel mode.") .def_property_readonly("num_active_batches", &WhisperWrapper::num_active_batches, "Number of batches waiting to be processed or currently processed.") diff --git a/python/ctranslate2/__init__.py b/python/ctranslate2/__init__.py index a80997645..88da68aec 100644 --- a/python/ctranslate2/__init__.py +++ b/python/ctranslate2/__init__.py @@ -30,6 +30,7 @@ GenerationResult, GenerationStepResult, Generator, + MpiInfo, ScoringResult, StorageView, TranslationResult, diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py index 4cb765636..28b2e4f21 100644 --- a/python/ctranslate2/specs/model_spec.py +++ b/python/ctranslate2/specs/model_spec.py @@ -291,6 +291,9 @@ def to_dict(self): if not key.startswith("_") } + def add_attribute(self, key, value): + self.__dict__[key] = value + def save_as_json(self, path): """Saves the configuration as a JSON file.""" with open(path, "w", encoding="utf-8") as config_file: diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index 9de261c58..c3f8d91be 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -45,6 +45,7 @@ def __init__( rms_norm: Use the root mean square layer normalization. multi_query_attention: Use multi-query attention. """ + self.multi_query_attention = multi_query_attention self.num_heads = np.dtype("int16").type(num_heads) self.pre_norm = pre_norm self.activation = np.dtype("int8").type(activation) @@ -207,6 +208,9 @@ def __init__( for _ in range(num_layers) ] self.start_from_zero_embedding = False + self.multi_query_attention = multi_query_attention or ( + num_heads_kv != num_heads + ) if project_in_out: self.project_in = common_spec.LinearSpec() @@ -339,6 +343,9 @@ def __init__( super().__init__() self.encoder = encoder self.decoder = decoder + self._config.add_attribute( + "multi_query_attention", self.encoder.multi_query_attention + ) @classmethod def from_config( @@ -467,6 +474,9 @@ def __init__(self, decoder: TransformerDecoderSpec): super().__init__() self.decoder = decoder + self._config.add_attribute( + "multi_query_attention", self.decoder.multi_query_attention + ) @classmethod def from_config( @@ -608,6 +618,9 @@ def __init__( super().__init__() self.encoder = encoder + self._config.add_attribute( + "multi_query_attention", self.encoder.multi_query_attention + ) if pooling_layer: self.pooler_dense = common_spec.LinearSpec() diff --git a/python/tools/prepare_build_environment_linux.sh b/python/tools/prepare_build_environment_linux.sh index 0350e01e7..89f8293f6 100755 --- a/python/tools/prepare_build_environment_linux.sh +++ b/python/tools/prepare_build_environment_linux.sh @@ -27,7 +27,8 @@ else cuda-cudart-devel-12-2-12.2.140-1 \ libcurand-devel-12-2-10.3.3.141-1 \ libcudnn8-devel-8.9.7.29-1.cuda12.2 \ - libcublas-devel-12-2-12.2.5.6-1 + libcublas-devel-12-2-12.2.5.6-1 \ + libnccl-devel-2.19.3-1+cuda12.2 ln -s cuda-12.2 /usr/local/cuda ONEAPI_VERSION=2023.2.0 @@ -44,6 +45,15 @@ else cd .. 
rm -r oneDNN-* + OPENMPI_VERSION=4.1.6 + curl -L -O https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-${OPENMPI_VERSION}.tar.bz2 + tar xf *.tar.bz2 && rm *.tar.bz2 + cd openmpi-* + ./configure + make -j$(nproc) install + cd .. + rm -r openmpi-* + export LD_LIBRARY_PATH="/usr/local/lib/:$LD_LIBRARY_PATH" fi mkdir build-release && cd build-release @@ -51,7 +61,7 @@ mkdir build-release && cd build-release if [ "$CIBW_ARCHS" == "aarch64" ]; then cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_CLI=OFF -DWITH_MKL=OFF -DOPENMP_RUNTIME=COMP -DCMAKE_PREFIX_PATH="/opt/OpenBLAS" -DWITH_OPENBLAS=ON -DWITH_RUY=ON .. else - cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-msse4.1" -DBUILD_CLI=OFF -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP -DWITH_CUDA=ON -DWITH_CUDNN=ON -DCUDA_DYNAMIC_LOADING=ON -DCUDA_NVCC_FLAGS="-Xfatbin=-compress-all" -DCUDA_ARCH_LIST="Common" .. + cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-msse4.1" -DBUILD_CLI=OFF -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP -DWITH_CUDA=ON -DWITH_CUDNN=ON -DCUDA_DYNAMIC_LOADING=ON -DCUDA_NVCC_FLAGS="-Xfatbin=-compress-all" -DCUDA_ARCH_LIST="Common" -DWITH_TENSOR_PARALLEL=ON .. fi VERBOSE=1 make -j$(nproc) install diff --git a/src/cuda/mpi_stub.cc b/src/cuda/mpi_stub.cc new file mode 100644 index 000000000..a2a69c4da --- /dev/null +++ b/src/cuda/mpi_stub.cc @@ -0,0 +1,94 @@ +#include +#include + +#define STR_HELPER(x) #x +#define STR(x) STR_HELPER(x) + +#include + +#define OPENMPI_LIBNAME "libmpi.so." STR(OMPI_MAJOR_VERSION) STR(0) + +namespace ctranslate2 { + + template + static Signature load_symbol(void* handle, const char* name, const char* library_name) { + void* symbol = dlsym(handle, name); + if (!symbol) + throw std::runtime_error("Cannot load symbol " + std::string(name) + + " from library " + std::string(library_name)); + return reinterpret_cast(symbol); + } + + static void* get_so_handle() { + static auto so_handle = []() { + void* handle = dlopen(OPENMPI_LIBNAME, RTLD_LAZY); + return handle; + }(); + return so_handle; + } + + template + static Signature load_symbol(const char* name) { + void* handle = get_so_handle(); + if (!handle) + throw std::runtime_error("Library " + std::string(OPENMPI_LIBNAME) + + " is not found or cannot be loaded"); + return load_symbol(handle, name, OPENMPI_LIBNAME); + } + + template + static Signature load_symbol_global(const char* name) { + void* handle = get_so_handle(); + if (!handle) + return nullptr; + return load_symbol(handle, name, OPENMPI_LIBNAME); + } +} + +extern "C" { + + int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype, + void *recvbuf, int recvcount, + MPI_Datatype recvtype, MPI_Comm comm) { + using Signature = int(*)(const void *sendbuf, int sendcount, MPI_Datatype sendtype, + void *recvbuf, int recvcount, + MPI_Datatype recvtype, MPI_Comm comm); + static auto func = ctranslate2::load_symbol("MPI_Allgather"); + return func(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm); + } + + int MPI_Bcast(void *buffer, int count, MPI_Datatype datatype, + int root, MPI_Comm comm) { + using Signature = int(*)(void *buffer, int count, MPI_Datatype datatype, + int root, MPI_Comm comm); + static auto func = ctranslate2::load_symbol("MPI_Bcast"); + return func(buffer, count, datatype, root, comm); + } + + int MPI_Init(int *argc, char ***argv) { + using Signature = int(*)(int *argc, char ***argv); + static auto func = ctranslate2::load_symbol("MPI_Init"); + return func(argc, argv); + } + + int MPI_Finalize(void) { + using Signature = int(*)(void); + static auto func = 
ctranslate2::load_symbol("MPI_Finalize"); + return func(); + } + + int MPI_Comm_rank(MPI_Comm comm, int *rank) { + using Signature = int(*)(MPI_Comm comm, int *size); + static auto func = ctranslate2::load_symbol("MPI_Comm_rank"); + return func(comm, rank); + } + + int MPI_Comm_size(MPI_Comm comm, int *size) { + using Signature = int(*)(MPI_Comm comm, int *size); + static auto func = ctranslate2::load_symbol("MPI_Comm_size"); + return func(comm, size); + } +} +struct ompi_predefined_datatype_t* stub_mpi_datatype_null = ctranslate2::load_symbol_global("ompi_mpi_datatype_null"); +struct ompi_predefined_datatype_t* stub_ompi_mpi_byte = ctranslate2::load_symbol_global("ompi_mpi_byte"); +struct ompi_predefined_communicator_t* stub_ompi_mpi_comm_world = ctranslate2::load_symbol_global("ompi_mpi_comm_world"); \ No newline at end of file diff --git a/src/cuda/mpi_stub.h b/src/cuda/mpi_stub.h new file mode 100644 index 000000000..83803900a --- /dev/null +++ b/src/cuda/mpi_stub.h @@ -0,0 +1,18 @@ +#pragma once + +#include + +#ifdef CT2_WITH_CUDA_DYNAMIC_LOADING +extern struct ompi_predefined_datatype_t* stub_mpi_datatype_null; +#define STUB_MPI_DATATYPE_NULL OMPI_PREDEFINED_GLOBAL(MPI_Datatype, *stub_mpi_datatype_null) + +extern struct ompi_predefined_datatype_t* stub_ompi_mpi_byte; +#define STUB_MPI_BYTE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, *stub_ompi_mpi_byte) + +extern struct ompi_predefined_communicator_t* stub_ompi_mpi_comm_world; +#define STUB_MPI_COMM_WORLD OMPI_PREDEFINED_GLOBAL(MPI_Comm, *stub_ompi_mpi_comm_world) +#else +#define STUB_MPI_DATATYPE_NULL MPI_DATATYPE_NULL +#define STUB_MPI_BYTE MPI_BYTE +#define STUB_MPI_COMM_WORLD MPI_COMM_WORLD +#endif \ No newline at end of file diff --git a/src/cuda/nccl_stub.cc b/src/cuda/nccl_stub.cc new file mode 100644 index 000000000..669518cb2 --- /dev/null +++ b/src/cuda/nccl_stub.cc @@ -0,0 +1,93 @@ +#include + +#include + +#define STR_HELPER(x) #x +#define STR(x) STR_HELPER(x) + +#include +#define NCCL_LIBNAME "libnccl.so." 
STR(NCCL_MAJOR) + +#include + +namespace ctranslate2 { + + template + static Signature load_symbol(void* handle, const char* name, const char* library_name) { + void* symbol = dlsym(handle, name); + if (!symbol) + throw std::runtime_error("Cannot load symbol " + std::string(name) + + " from library " + std::string(library_name)); + return reinterpret_cast(symbol); + } + static inline void log_nccl_version(void* handle) { + using Signature = ncclResult_t(*)(int*); + const auto nccl_get_version = load_symbol(handle, + "ncclGetVersion", + NCCL_LIBNAME); + int version = 0; + nccl_get_version(&version); + spdlog::info("Loaded nccl library version {}", version); + } + + static void* get_so_handle() { + static auto so_handle = []() { + void* handle = dlopen(NCCL_LIBNAME, RTLD_LAZY); + if (!handle) + throw std::runtime_error("Library " + std::string(NCCL_LIBNAME) + + " is not found or cannot be loaded"); + log_nccl_version(handle); + return handle; + }(); + return so_handle; + } + + template + static Signature load_symbol(const char* name) { + void* handle = get_so_handle(); + return load_symbol(handle, name, NCCL_LIBNAME); + } + +} + +extern "C" { + ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId) { + using Signature = ncclResult_t(*)(ncclUniqueId* uniqueId); + static auto func = ctranslate2::load_symbol("ncclGetUniqueId"); + return func(uniqueId); + } + + ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) { + using Signature = ncclResult_t(*)(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + static auto func = ctranslate2::load_symbol("ncclCommInitRank"); + return func(comm, nranks, commId, rank); + } + + ncclResult_t ncclCommDestroy(ncclComm_t comm) { + using Signature = ncclResult_t(*)(ncclComm_t comm); + static auto func = ctranslate2::load_symbol("ncclCommDestroy"); + return func(comm); + } + + ncclResult_t ncclCommAbort(ncclComm_t comm) { + using Signature = ncclResult_t(*)(ncclComm_t comm); + static auto func = ctranslate2::load_symbol("ncclCommAbort"); + return func(comm); + } + + ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, + ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) { + using Signature = ncclResult_t(*)(const void* sendbuff, void* recvbuff, size_t count, + ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream); + static auto func = ctranslate2::load_symbol("ncclAllReduce"); + return func(sendbuff, recvbuff, count, datatype, op, comm, stream); + } + + ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, + ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) { + using Signature = ncclResult_t(*)(const void* sendbuff, void* recvbuff, size_t sendcount, + ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream); + static auto func = ctranslate2::load_symbol("ncclAllGather"); + return func(sendbuff, recvbuff, sendcount, datatype, comm, stream); + } +} diff --git a/src/cuda/utils.h b/src/cuda/utils.h index 29bc99a39..8c1c134fe 100644 --- a/src/cuda/utils.h +++ b/src/cuda/utils.h @@ -6,6 +6,10 @@ #include #include +#ifdef CT2_WITH_TENSOR_PARALLEL +# include +# include +#endif #ifdef CT2_WITH_CUDNN # include #endif @@ -16,6 +20,24 @@ namespace ctranslate2 { namespace cuda { +#ifdef CT2_WITH_TENSOR_PARALLEL +#define MPI_CHECK(ans) \ + { \ + int e = ans; \ + if( e != MPI_SUCCESS ) \ + THROW_RUNTIME_ERROR("MPI failed with error " \ + + std::to_string(e)); \ + } + +#define NCCL_CHECK(ans) \ + { \ + 
ncclResult_t r = ans; \ + if( r != ncclSuccess ) \ + THROW_RUNTIME_ERROR("NCCL failed with error " \ + + std::to_string(r)); \ + } +#endif + #define CUDA_CHECK(ans) \ { \ cudaError_t code = (ans); \ diff --git a/src/devices.cc b/src/devices.cc index 3822cc3c3..47582f8be 100644 --- a/src/devices.cc +++ b/src/devices.cc @@ -3,6 +3,9 @@ #ifdef CT2_WITH_CUDA # include "cuda/utils.h" #endif +#ifdef CT2_WITH_TENSOR_PARALLEL +# include +#endif #include "device_dispatch.h" @@ -115,5 +118,101 @@ namespace ctranslate2 { (void)device; #endif } + // Initialize the static member variable +#ifdef CT2_WITH_TENSOR_PARALLEL + std::vector ScopedMPISetter::_nccl_comms; +#endif + int my_rank = 0; + int local_rank = 0; + int n_ranks = 1; + + ScopedMPISetter::ScopedMPISetter() { +#ifdef CT2_WITH_TENSOR_PARALLEL + // initializing MPI + MPI_CHECK(MPI_Init(nullptr, nullptr)); + MPI_CHECK(MPI_Comm_rank(STUB_MPI_COMM_WORLD, &my_rank)); + MPI_CHECK(MPI_Comm_size(STUB_MPI_COMM_WORLD, &n_ranks)); + + uint64_t hostHashs[n_ranks]; + char hostname[1024]; + getHostName(hostname, 1024); + hostHashs[my_rank] = getHostHash(hostname); + MPI_CHECK(MPI_Allgather(MPI_IN_PLACE, 0, STUB_MPI_DATATYPE_NULL, + hostHashs, sizeof(uint64_t), STUB_MPI_BYTE, STUB_MPI_COMM_WORLD)); + for (int p = 0; p < n_ranks; p++) { + if (p == my_rank) { + break; + } + if (hostHashs[p] == hostHashs[my_rank]) { + local_rank++; + } + } + atexit(finalize); +#endif + } + + ScopedMPISetter::~ScopedMPISetter() = default; +#ifdef CT2_WITH_TENSOR_PARALLEL + uint64_t ScopedMPISetter::getHostHash(const char *string) { + // Based on DJB2, result = result * 33 + char + uint64_t result = 5381; + for (int c = 0; string[c] != '\0'; c++) { + result = ((result << 5) + result) + string[c]; + } + return result; + } + + void ScopedMPISetter::getHostName(char *hostname, int maxlen) { + gethostname(hostname, maxlen); + for (int i = 0; i < maxlen; i++) { + if (hostname[i] == '.') { + hostname[i] = '\0'; + return; + } + } + } + + ncclComm_t ScopedMPISetter::getNcclComm() { + static thread_local ncclComm_t comm; + static thread_local ncclUniqueId id; + + if (comm == nullptr) { + int nRanks = ScopedMPISetter::getNRanks(); + int myRank = ScopedMPISetter::getCurRank(); + if (myRank == 0) { + ncclGetUniqueId(&id); + } + MPI_CHECK(MPI_Bcast((void *) &id, sizeof(id), STUB_MPI_BYTE, 0, STUB_MPI_COMM_WORLD)); + NCCL_CHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); + _nccl_comms.push_back(&comm); + } + return comm; + } +#endif + + void ScopedMPISetter::finalize() { +#ifdef CT2_WITH_TENSOR_PARALLEL + for (auto* comm : _nccl_comms) { + //finalizing NCCL + if (*comm) { + NCCL_CHECK(ncclCommAbort(*comm)); + NCCL_CHECK(ncclCommDestroy(*comm)); + } + } + MPI_CHECK(MPI_Finalize()); +#endif + } + + int ScopedMPISetter::getNRanks() { + return n_ranks; + } + + int ScopedMPISetter::getCurRank() { + return my_rank; + } + + int ScopedMPISetter::getLocalRank() { + return local_rank; + } } diff --git a/src/layers/attention.cc b/src/layers/attention.cc index 2a066ba86..cf6074b2a 100644 --- a/src/layers/attention.cc +++ b/src/layers/attention.cc @@ -343,7 +343,10 @@ namespace ctranslate2 { std::vector layers; layers.reserve(num_linear_layers); for (dim_t i = 0; i < num_linear_layers; ++i) - layers.emplace_back(model, scope + "/linear_" + std::to_string(i)); + if (i == (num_linear_layers - 1)) { + layers.emplace_back(model, scope + "/linear_" + std::to_string(i), nullptr, true); + } else + layers.emplace_back(model, scope + "/linear_" + std::to_string(i)); return layers; } @@ -376,11 +379,12 @@ 
namespace ctranslate2 { bool pre_norm, bool is_decoder, Alibi* alibi) - : _num_heads(num_heads) + : _tensor_parallel(model.tensor_parallel()) + , _num_heads(_tensor_parallel ? SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()) : num_heads) , _self_attention(self_attention) , _is_decoder(is_decoder) , _linear(make_linear_layers(model, scope, self_attention)) - , _d_model(_linear.back().output_size()) + , _d_model(_tensor_parallel ? SAFE_DIVIDE(_linear.back().output_size(), ScopedMPISetter::getNRanks()) : _linear.back().output_size()) , _d_head(model.get_attribute_with_default(scope + "/head_dim", _d_model / _num_heads)) , _pre_norm(pre_norm) , _layer_norm(build_optional_layer(model, scope + "/layer_norm")) @@ -392,11 +396,13 @@ namespace ctranslate2 { , _queries_scale(model.get_attribute_with_default( scope + "/queries_scale", 1.f / std::sqrt(static_cast(_d_head)))) - , _num_heads_kv(model.get_flag_with_default(scope + "/multi_query", false) + , _multi_query(model.get_flag_with_default(scope + "/multi_query", false)) + , _num_heads_kv(_multi_query ? 1 - : model.get_attribute_with_default(scope + "/num_heads_kv", - _num_heads)) - , _merge_time_and_head_dims(_num_heads_kv == 1 + : (_tensor_parallel ? model.get_attribute_with_default(scope + "/num_heads_kv", + _num_heads * ScopedMPISetter::getNRanks()) / ScopedMPISetter::getNRanks() + : model.get_attribute_with_default(scope + "/num_heads_kv", _num_heads))) + , _merge_time_and_head_dims(_multi_query && !_relative_attention_bias && !_relative_position_keys && !_relative_position_values) @@ -458,7 +464,7 @@ namespace ctranslate2 { if (cached_keys == nullptr || cached_keys->empty()) { _linear[1](values, fused_proj); - if (_num_heads_kv == 1) { + if (_multi_query) { if (values_padder) values_padder->add_padding(fused_proj); ops::Split(2, {_d_head, _d_head})(fused_proj, keys_proj, values_proj); @@ -476,7 +482,7 @@ namespace ctranslate2 { if (queries_proj.dim(1) == 1 && cached_keys) beam_size = queries_proj.dim(0) / cached_keys->dim(0); - if (_num_heads_kv == 1) { + if (_multi_query) { if (queries_padder) queries_padder->add_padding(queries_proj); queries_proj.reshape({queries_proj.dim(0) / beam_size, -1, _d_head}); @@ -592,6 +598,13 @@ namespace ctranslate2 { _linear.back()(context, output); + if (_tensor_parallel) { + Shape shape = output.shape(); + StorageView tmp(std::move(shape), output.dtype(), output.device()); + ops::ReduceAll ops_reduce_all(ops::ReduceAll::RED_OP::SUM); + ops_reduce_all(output, tmp); + output = std::move(tmp); + } if (_layer_norm) { ops::Add()(queries, output, output); diff --git a/src/layers/common.cc b/src/layers/common.cc index 5f70c4336..92b7e9cf1 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -265,7 +265,8 @@ namespace ctranslate2 { Dense::Dense(const models::Model& model, const std::string& scope, - const ops::ActivationType* activation_type) + const ops::ActivationType* activation_type, + const bool affected_by_tp) : _packed_weight(false) , _weight(get_linear_weight(model, scope, &_packed_weight)) , _bias(model.get_variable_if_exists(scope + "/bias")) @@ -294,6 +295,7 @@ namespace ctranslate2 { /*shift_to_uint8=*/bool(_u8_shift_compensation), /*round_before_cast=*/model.round_before_cast_in_quantization()) , _dequantize_op(activation_type) + , _affected_by_tp(affected_by_tp) { } @@ -344,7 +346,40 @@ namespace ctranslate2 { StorageView qinput(_weight.dtype(), device); StorageView qinput_scale(_qscale->dtype(), device); StorageView qoutput(DataType::INT32, device); - _quantize_op(input, qinput, 
qinput_scale); + const StorageView* pinput = &input; + + if (ScopedMPISetter::getNRanks() > 1 && _affected_by_tp) { + StorageView input_reshaped(input.shape(), input.dtype(), input.device()); + Shape shape = input.shape(); + dim_t batch_size = shape[0]; + dim_t depth = shape[shape.size() - 1]; + dim_t length = shape[shape.size() - 2]; + StorageView input_gather_all({1, depth * ScopedMPISetter::getNRanks(), batch_size * length}, input.dtype(), input.device()); + ops::Transpose transpose_op({0, 2, 1}); + // Transpose input B x L x D -> B x D x L + if (batch_size > 1) { + input_reshaped.shallow_copy(const_cast(input)); + input_reshaped.reshape({1, batch_size * length, depth}); + pinput = &input_reshaped; + } + StorageView input_t(input.dtype(), input.device()); + transpose_op(*pinput, input_t); + ops::GatherAll gather_ops; + gather_ops(input_t, input_gather_all); + input_t.resize({1, batch_size * length, depth * ScopedMPISetter::getNRanks()}); + transpose_op(input_gather_all, input_t); + StorageView qinput_tmp(_weight.dtype(), device); + _quantize_op(input_t, qinput_tmp, qinput_scale); + dim_t index = _weight.dim(-1) * ScopedMPISetter::getCurRank(); + dim_t size = _weight.dim(-1); + ops::Slide(-1, index, size)(qinput_tmp, qinput); + if (batch_size > 1) + qinput.reshape({batch_size, length, depth}); + } + else { + _quantize_op(input, qinput, qinput_scale); + } + _gemm_op(qinput, *weight, qoutput, compensation); _dequantize_op(qoutput, qinput_scale, diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc index 056a01f99..97b5669c1 100644 --- a/src/layers/transformer.cc +++ b/src/layers/transformer.cc @@ -14,7 +14,8 @@ namespace ctranslate2 { , _activation_type(activation_type) , _ff1(model, scope + "/linear_0", &_activation_type) , _ff1_noact(build_optional_layer(model, scope + "/linear_0_noact")) - , _ff2(model, scope + "/linear_1") { + , _ff2(model, scope + "/linear_1", nullptr, true) + , _tensor_parallel(model.tensor_parallel()) { } void FeedForwardNetwork::operator()(const StorageView& input, StorageView& output) const { @@ -29,7 +30,6 @@ namespace ctranslate2 { StorageView inner(dtype, device); _ff1(*x, inner); - if (_ff1_noact) { StorageView linear(dtype, device); (*_ff1_noact)(*x, linear); @@ -38,6 +38,14 @@ namespace ctranslate2 { _ff2(inner, output); + if (_tensor_parallel) { + Shape shape = output.shape(); + StorageView tmp(std::move(shape), output.dtype(), output.device()); + ops::ReduceAll red_op(ops::ReduceAll::RED_OP::SUM); + red_op(output, tmp); + output = std::move(tmp); + } + if (_layer_norm) { ops::Add()(input, output, output); @@ -250,6 +258,7 @@ namespace ctranslate2 { , _position_encoder(_layers.front()->get_self_attention().has_positional_embeddings() ? 
nullptr : build_position_encoder(model, scope + "/position_encodings", _embeddings)) + , _tensor_parallel(model.tensor_parallel()) { } @@ -278,8 +287,12 @@ namespace ctranslate2 { padder->remove_padding(input); } + int num_heads = _num_heads; + if (_tensor_parallel) { + num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()); + } lengths_mask = std::make_unique( - layers::MultiHeadAttention::prepare_length_mask(*lengths, _num_heads, max_time)); + layers::MultiHeadAttention::prepare_length_mask(*lengths, num_heads, max_time)); } StorageView position_bias(output.dtype(), output.device()); @@ -334,7 +347,8 @@ namespace ctranslate2 { : build_position_encoder(model, scope + "/position_encodings", _embeddings)) , _with_encoder_attention(_layers.front()->has_cross_attention()) , _proj(model, scope + "/projection") - , _sliding_window(model.get_attribute_with_default(scope + "/sliding_window", 0)) { + , _sliding_window(model.get_attribute_with_default(scope + "/sliding_window", 0)) + , _tensor_parallel(model.tensor_parallel()) { dim_t alignment_layer = ( model.get_attribute_with_default(scope + "/alignment_layer", -1)); @@ -497,13 +511,19 @@ namespace ctranslate2 { input_padder->remove_padding(layer_in); } + dim_t num_heads = _num_heads; + if (_tensor_parallel) { + num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()); + } + StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask( *lengths, - _num_heads, + num_heads, max_time, /*mask_future=*/true, multi_query); + if (step > 0) ops::Add()(lengths_mask, StorageView(int32_t(step)), lengths_mask); @@ -527,10 +547,14 @@ namespace ctranslate2 { } if (memory_lengths) { + dim_t num_heads = _num_heads; + if (_tensor_parallel) { + num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()); + } const dim_t beam_size = batch_size / memory_lengths->dim(0); memory_lengths_mask = std::make_unique( layers::MultiHeadAttention::prepare_length_mask(*memory_lengths, - _num_heads, + num_heads, beam_size > 1 ? 
beam_size : max_time)); } } @@ -585,9 +609,13 @@ namespace ctranslate2 { if (i > 0) { auto max_tokens = _sliding_window + layer_in_chunk->dim(1); StorageView tmp_lengths = StorageView(Shape{layer_in_chunk->dim(0)}, int32_t(max_tokens), device); + int num_heads = _num_heads; + if (_tensor_parallel) { + num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()); + } StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask( tmp_lengths, - _num_heads, + num_heads, max_tokens, /*mask_future=*/true, multi_query); diff --git a/src/models/model.cc b/src/models/model.cc index 0672494ff..65679f53d 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -5,6 +5,7 @@ #include "ctranslate2/models/model_factory.h" #include "ctranslate2/ops/ops.h" #include "ctranslate2/utils.h" +#include #ifdef CT2_WITH_CUDA # include "cuda/utils.h" @@ -17,6 +18,27 @@ namespace ctranslate2 { static const std::string binary_file = "model.bin"; static const std::string config_file = "config.json"; + enum class VARIABLE_TYPE { + ATTN_LINEAR_0_WEIGHT, + ATTN_LINEAR_0_WEIGHT_SCALE, + ATTN_LINEAR_0_BIAS, + ATTN_LINEAR_1_WEIGHT, + ATTN_LINEAR_1_WEIGHT_SCALE, + ATTN_LINEAR_1_BIAS, + ATTN_LINEAR_2_WEIGHT, + SELF_ATTN_LINEAR_0_WEIGHT, + SELF_ATTN_LINEAR_0_WEIGHT_SCALE, + SELF_ATTN_LINEAR_0_BIAS, + SELF_ATTN_LINEAR_1_WEIGHT, + FFN_LINEAR_0_WEIGHT, + FFN_LINEAR_0_BIAS, + FFN_LINEAR_0_WEIGHT_SCALE, + FFN_LINEAR_0_NOACT_WEIGHT, + FFN_LINEAR_0_NOACT_WEIGHT_SCALE, + FFN_LINEAR_0_NOACT_BIAS, + FFN_LINEAR_1_WEIGHT, + OTHERS, + }; static inline void report_stream_error(const std::streampos position, const size_t read_size, @@ -84,13 +106,13 @@ namespace ctranslate2 { return; // Move variables back to the CPU device. - if (src_device != Device::CPU) { + if (src_device != Device::CPU && dst_device == Device::CPU) { ScopedDeviceSetter scoped_device_setter(src_device, src_device_index); move_variables_to_device(variables, Device::CPU); } // Move variables to the destination device. - if (dst_device != Device::CPU) { + if (src_device == Device::CPU && dst_device != Device::CPU) { ScopedDeviceSetter scoped_device_setter(dst_device, dst_device_index); move_variables_to_device(variables, dst_device); } @@ -389,6 +411,31 @@ namespace ctranslate2 { } } + static void split_variables(StorageView variable, int dim, std::vector& partitions_size, std::vector& outputs) + { + if (variable.rank() < 1 || variable.rank() > 2) + throw std::runtime_error("Unsupported split variables which has the rank of matrix more than 2." 
+ "Current variable has the rank " + std::to_string(variable.rank())); + + //std::vector outputs(num, StorageView(variable.dtype(), variable.device())); + + size_t num = partitions_size.size(); + std::vector p_outputs(num); + + for (int i = 0; i < num; ++i) { + p_outputs[i] = &outputs[i]; + } + ops::Split(dim, partitions_size)(variable, p_outputs); + } + + static bool replace(std::string& str, const std::string& from, const std::string& to) { + size_t start_pos = str.find(from); + if (start_pos == std::string::npos) + return false; + str.replace(start_pos, from.length(), to); + return true; + } + static void check_version(const size_t saved_version, const size_t current_version, const std::string& version_type) { @@ -403,24 +450,114 @@ namespace ctranslate2 { + "(Forward compatibility is not guaranteed.)"); } + static VARIABLE_TYPE classify_variable(const std::string& name) { + std::regex pattern_self_attn("/self_attention/linear_(\\d+)/(\\w+)"); + std::regex pattern_attn("/attention/linear_(\\d+)/(\\w+)"); + std::regex pattern_ffn("/ffn/linear_(\\d+)(\\w*)/(\\w+)"); + + std::smatch match; + + if (std::regex_search(name, match, pattern_self_attn)) { + int layer_number = std::stoi(match[1]); + std::string parameterName = match[2]; + + switch (layer_number) { + case 0: + if (parameterName == "bias") + return VARIABLE_TYPE::SELF_ATTN_LINEAR_0_BIAS; + if (parameterName == "weight") + return VARIABLE_TYPE::SELF_ATTN_LINEAR_0_WEIGHT; + else + return VARIABLE_TYPE::SELF_ATTN_LINEAR_0_WEIGHT_SCALE; + case 1: + if (parameterName == "weight") + return VARIABLE_TYPE::SELF_ATTN_LINEAR_1_WEIGHT; + default: + return VARIABLE_TYPE::OTHERS; + }; + } + else if (std::regex_search(name, match, pattern_attn)) { + int layer_number = std::stoi(match[1]); + std::string parameterName = match[2]; + + switch (layer_number) { + case 0: + if (parameterName == "bias") + return VARIABLE_TYPE::ATTN_LINEAR_0_BIAS; + if (parameterName == "weight") + return VARIABLE_TYPE::ATTN_LINEAR_0_WEIGHT; + return VARIABLE_TYPE::ATTN_LINEAR_0_WEIGHT_SCALE; + case 1: + if (parameterName == "bias") + return VARIABLE_TYPE::ATTN_LINEAR_1_BIAS; + if (parameterName == "weight") + return VARIABLE_TYPE::ATTN_LINEAR_1_WEIGHT; + return VARIABLE_TYPE::ATTN_LINEAR_1_WEIGHT_SCALE; + case 2: + if (parameterName == "weight") + return VARIABLE_TYPE::ATTN_LINEAR_2_WEIGHT; + default: + return VARIABLE_TYPE::OTHERS; + }; + } + else if (std::regex_search(name, match, pattern_ffn)) { + int layer_number = std::stoi(match[1]); + std::string noact = match[2]; + std::string parameterName = match[3]; + + switch (layer_number) { + case 0: + if (noact == "noact" && parameterName == "bias") + return VARIABLE_TYPE::FFN_LINEAR_0_NOACT_BIAS; + if (noact == "noact" && parameterName == "weight") + return VARIABLE_TYPE::FFN_LINEAR_0_NOACT_WEIGHT; + if (noact == "noact") + return VARIABLE_TYPE::FFN_LINEAR_0_NOACT_WEIGHT_SCALE; + if (parameterName == "bias") + return VARIABLE_TYPE::FFN_LINEAR_0_BIAS; + if (parameterName == "weight") + return VARIABLE_TYPE::FFN_LINEAR_0_WEIGHT; + return VARIABLE_TYPE::FFN_LINEAR_0_WEIGHT_SCALE; + case 1: + if (parameterName == "weight") + return VARIABLE_TYPE::FFN_LINEAR_1_WEIGHT; + default: + return VARIABLE_TYPE::OTHERS; + }; + } + + return VARIABLE_TYPE::OTHERS; + } + std::shared_ptr Model::load(const std::string& path, Device device, int device_index, - ComputeType compute_type) { + ComputeType compute_type, + bool tensor_parallel) { ModelFileReader model_reader(path); - return load(model_reader, device, device_index, compute_type); + 
return load(model_reader, device, device_index, compute_type, tensor_parallel); } std::shared_ptr Model::load(ModelReader& model_reader, Device device, int device_index, - ComputeType compute_type) { + ComputeType compute_type, + bool tensor_parallel) { { // Log the system configuration the first time a model is loaded. static std::once_flag log_once; std::call_once(log_once, log_system_config); } + int world_size; + int current_index; + if (tensor_parallel) { + ScopedMPISetter mpi_setter = ScopedMPISetter(); + device_index = ScopedMPISetter::getLocalRank(); + current_index = ScopedMPISetter::getCurRank(); + world_size = ScopedMPISetter::getNRanks(); + } + { // Check that the device and device index are valid. ScopedDeviceSetter(device, device_index); @@ -448,6 +585,7 @@ namespace ctranslate2 { auto model = create_model(spec); model->_binary_version = binary_version; model->_spec_revision = spec_revision; + model->_tensor_parallel = tensor_parallel; check_version(spec_revision, model->current_spec_revision(), "revision"); @@ -460,6 +598,19 @@ namespace ctranslate2 { // Load the variables. const auto num_variables = consume(model_file); model->_variable_index.reserve(num_variables); + + // check config for tensor parallel + bool multi_query_attention = false; + if (tensor_parallel) + { + + if (model->config.contains("multi_query_attention")) + multi_query_attention = model->config["multi_query_attention"]; + else + spdlog::warn("Running model in mode tensor parallel but missing multi_query_attention option in" + " the config.json could lead to error! Try using the latest version of converters"); + } + for (uint32_t i = 0; i < num_variables; ++i) { auto name = consume(model_file); const size_t rank = consume(model_file); @@ -481,6 +632,89 @@ namespace ctranslate2 { StorageView variable(std::move(shape), dtype); consume(model_file, num_bytes, static_cast(variable.buffer())); + if (tensor_parallel) { + int outer_dim = 0; + int inner_dim = 1; + static dim_t model_dim = 0; + static dim_t total_dim = 0; + + auto variable_type = classify_variable(name); + if (variable_type != VARIABLE_TYPE::OTHERS) { + std::vector outputs(world_size, StorageView(variable.dtype(), variable.device())); + switch (variable_type) { + case VARIABLE_TYPE::SELF_ATTN_LINEAR_1_WEIGHT: + case VARIABLE_TYPE::ATTN_LINEAR_2_WEIGHT: + case VARIABLE_TYPE::FFN_LINEAR_1_WEIGHT: + { + dim_t output_per_partition_dim = SAFE_DIVIDE(variable.dim(inner_dim), world_size); + std::vector partitions_size(world_size, output_per_partition_dim); + split_variables(std::move(variable), inner_dim, partitions_size, outputs); + break; + } + case VARIABLE_TYPE::SELF_ATTN_LINEAR_0_WEIGHT: + case VARIABLE_TYPE::SELF_ATTN_LINEAR_0_WEIGHT_SCALE: + case VARIABLE_TYPE::SELF_ATTN_LINEAR_0_BIAS: + { + std::vector partitions_size; + if (multi_query_attention) { + if (model_dim == 0) { + model_dim = variable.dim(-1); + total_dim = variable.dim(outer_dim); + } + dim_t q_dim = SAFE_DIVIDE(model_dim, world_size); + dim_t kv_dim = SAFE_DIVIDE((total_dim - model_dim), (2 * world_size)); + partitions_size = std::vector(world_size, q_dim); + std::vector kv_part(world_size * 2, kv_dim); + partitions_size.insert(partitions_size.end(), kv_part.begin(), kv_part.end()); + } + else { + dim_t dim_per_kqv_per_partition = SAFE_DIVIDE(variable.dim(outer_dim) / 3, world_size); + partitions_size = std::vector(3 * world_size, dim_per_kqv_per_partition); + } + std::vector outputs_tmp = std::vector(partitions_size.size(), + StorageView(variable.dtype(), + variable.device())); + 
split_variables(std::move(variable), outer_dim, partitions_size, outputs_tmp); + for (int i = 0; i < world_size; i++) { + std::vector output_linear = {&outputs_tmp[i], &outputs_tmp[i + world_size], + &outputs_tmp[i + world_size * 2]}; + StorageView tmp(variable.dtype(), variable.device()); + ops::Concat(static_cast(outer_dim))(output_linear, tmp); + outputs[i] = std::move(tmp); + } + break; + } + case VARIABLE_TYPE::ATTN_LINEAR_1_WEIGHT: + case VARIABLE_TYPE::ATTN_LINEAR_1_WEIGHT_SCALE: + case VARIABLE_TYPE::ATTN_LINEAR_1_BIAS: + { + std::vector partitions_size; + dim_t dim_per_kqv_per_partition = SAFE_DIVIDE(variable.dim(outer_dim) / 2, world_size); + partitions_size = std::vector(2 * world_size, dim_per_kqv_per_partition); + std::vector outputs_tmp = std::vector(partitions_size.size(), + StorageView(variable.dtype(), + variable.device())); + split_variables(std::move(variable), outer_dim, partitions_size, outputs_tmp); + for (int i = 0; i < world_size; i++) { + std::vector output_linear = {&outputs_tmp[i], &outputs_tmp[i + world_size]}; + StorageView tmp(variable.dtype(), variable.device()); + ops::Concat(static_cast(outer_dim))(output_linear, tmp); + outputs[i] = std::move(tmp); + } + break; + } + default: + { + dim_t output_per_partition_dim = SAFE_DIVIDE(variable.dim(outer_dim), world_size); + std::vector partitions_size(world_size, output_per_partition_dim); + split_variables(std::move(variable), outer_dim, partitions_size, outputs); + } + }; + if (outputs.size() > current_index && !outputs[current_index].empty()) + variable = std::move(outputs[current_index]); + } + } + model->register_variable(std::move(name), std::move(variable)); } @@ -558,16 +792,24 @@ namespace ctranslate2 { if (device == Device::CUDA && !cuda::have_same_compute_capability(device_indices)) throw std::invalid_argument("Cannot use multiple GPUs with different Compute Capabilities " "for the same model"); + if (tensor_parallel && device != Device::CUDA) { + throw std::invalid_argument("Tensor Parallel mode can run only on cuda"); + } #endif std::vector> models; + if (tensor_parallel && (device_indices.size() > 1)) { + spdlog::warn("Running model in mode tensor parallel does not support" + " running independently a model in each device"); + } + models.reserve(device_indices.size() * num_replicas_per_device); for (const size_t device_index : device_indices) { std::shared_ptr model; if (models.empty()) - model = Model::load(*model_reader, device, device_index, compute_type); + model = Model::load(*model_reader, device, device_index, compute_type, tensor_parallel); else model = models.back()->copy_to(device, device_index); diff --git a/src/ops/nccl_ops.cc b/src/ops/nccl_ops.cc new file mode 100644 index 000000000..756ce0332 --- /dev/null +++ b/src/ops/nccl_ops.cc @@ -0,0 +1,23 @@ +#include "ctranslate2/ops/nccl_ops.h" +#include "dispatch.h" + +namespace ctranslate2 { + namespace ops { + + ReduceAll::ReduceAll(ReduceAll::RED_OP op) + : _reduce_op(op) { + } + + void ReduceAll::operator()(const StorageView& input, StorageView& output) const { + PROFILE("ReduceAll"); + DEVICE_AND_TYPE_DISPATCH(input.device(), input.dtype(), (compute(input, output))); + } + + GatherAll::GatherAll() = default; + + void GatherAll::operator()(const StorageView& input, StorageView& output) const { + PROFILE("ReduceAll"); + DEVICE_AND_TYPE_DISPATCH(input.device(), input.dtype(), (compute(input, output))); + } + } +} diff --git a/src/ops/nccl_ops_cpu.cc b/src/ops/nccl_ops_cpu.cc new file mode 100644 index 000000000..d0f63750b --- /dev/null +++ 
b/src/ops/nccl_ops_cpu.cc @@ -0,0 +1,23 @@ +#include "ctranslate2/ops/nccl_ops.h" +#include "dispatch.h" + +namespace ctranslate2 { + namespace ops { + + template + void ReduceAll::compute(const StorageView& /*input*/, StorageView& /*output*/) const { + throw std::runtime_error("reduce all is not applied for the cpu"); + } + + template + void GatherAll::compute(const StorageView& /*input*/, StorageView& /*output*/) const { + throw std::runtime_error("gather all is not applied for the cpu"); + } + #define DECLARE_IMPL(T) \ + template void ReduceAll::compute(const StorageView&, \ + StorageView&) const; \ + template void GatherAll::compute(const StorageView&, \ + StorageView&) const; + DECLARE_ALL_TYPES(DECLARE_IMPL) + } +} diff --git a/src/ops/nccl_ops_gpu.cu b/src/ops/nccl_ops_gpu.cu new file mode 100644 index 000000000..1b607ef6a --- /dev/null +++ b/src/ops/nccl_ops_gpu.cu @@ -0,0 +1,93 @@ +#include "ctranslate2/ops/nccl_ops.h" +#ifdef CT2_WITH_TENSOR_PARALLEL + #include + #include "cuda/utils.h" +#endif +#include "type_dispatch.h" + +namespace ctranslate2 { + namespace ops { + +#ifdef CT2_WITH_TENSOR_PARALLEL + ncclDataType_t getNcclDataTypeFromDataType(DataType type) { + switch (type) { +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) + case DataType::BFLOAT16: + return ncclBfloat16; +#endif + case DataType::FLOAT16: + return ncclFloat16; + case DataType::FLOAT32: + return ncclFloat32; + case DataType::INT32: + return ncclInt32; + case DataType::INT8: + return ncclInt8; + default: + throw std::invalid_argument("The current datatype " + std::to_string(static_cast(type)) + + " is not supported for the mode tensor parallel "); + } + } + + ncclRedOp_t redop_to_nccl_op(ReduceAll::RED_OP op) { + switch (op) { + case ReduceAll::RED_OP::SUM: + return ncclSum; + case ReduceAll::RED_OP::PROD: + return ncclProd; + case ReduceAll::RED_OP::MAX: + return ncclMax; + case ReduceAll::RED_OP::MIN: + return ncclMin; +#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0) + case ReduceAll::RED_OP::AVG: + return ncclAvg; +#endif + default: + throw std::runtime_error("the current reduce operation " + std::to_string(static_cast(op)) + " is not supported"); + } + } +#endif + + template + void ReduceAll::compute(const StorageView& input, StorageView& output) const { +#ifdef CT2_WITH_TENSOR_PARALLEL + // initializing NCCL + dim_t data_size = input.size(); + ncclComm_t comm = ScopedMPISetter::getNcclComm(); + ncclDataType_t ncclDataType = getNcclDataTypeFromDataType(input.dtype()); + ncclRedOp_t ncclOp = redop_to_nccl_op(_reduce_op); + NCCL_CHECK(ncclAllReduce(input.data(), output.data(), + data_size, ncclDataType, ncclOp, + comm, cuda::get_cuda_stream())); + + cudaStreamSynchronize(cuda::get_cuda_stream()); +#endif + (void)input; + (void)output; + } + + template + void GatherAll::compute(const StorageView& input, StorageView& output) const { +#ifdef CT2_WITH_TENSOR_PARALLEL + // initializing NCCL + dim_t data_size = input.size(); + ncclComm_t comm = ScopedMPISetter::getNcclComm(); + ncclDataType_t ncclDataType = getNcclDataTypeFromDataType(input.dtype()); + NCCL_CHECK(ncclAllGather(input.data(), output.data(), + data_size, ncclDataType, + comm, cuda::get_cuda_stream())); + + cudaStreamSynchronize(cuda::get_cuda_stream()); +#endif + (void)input; + (void)output; + } +#define DECLARE_IMPL(T) \ + template void GatherAll::compute(const StorageView&, \ + StorageView&) const; \ + template void ReduceAll::compute(const StorageView&, \ + StorageView&) const; + DECLARE_ALL_TYPES(DECLARE_IMPL) + } +} diff --git a/src/utils.cc 
b/src/utils.cc index f0eb29509..4f8bde57c 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -189,5 +189,4 @@ namespace ctranslate2 { return features; } - } diff --git a/tools/benchmark_tensor_parallel/README.md b/tools/benchmark_tensor_parallel/README.md new file mode 100644 index 000000000..3ee92f2ca --- /dev/null +++ b/tools/benchmark_tensor_parallel/README.md @@ -0,0 +1,18 @@ +## Benchmark tools + +This directory contains script to test the tensor parallelism mode. + +### Requirements + +* Python 3 +* Following this [doc](../../docs/parallel.md#model-and-tensor-parallelism) to configure the environment. + +```bash +python3 -m pip install -r requirements.txt +``` + +### Usage + +```text +mpirun -np 2 -hostfile hostfile python3 benchmark.py --mode --model_path --src --target --batch_size +``` \ No newline at end of file diff --git a/tools/benchmark_tensor_parallel/benchmark.py b/tools/benchmark_tensor_parallel/benchmark.py new file mode 100644 index 000000000..12a3e35fd --- /dev/null +++ b/tools/benchmark_tensor_parallel/benchmark.py @@ -0,0 +1,172 @@ +import ctranslate2 +import argparse +import os +import collections +import time +import GPUtil +import sentencepiece as spm +import concurrent.futures + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + + +class BenchmarkResult( + collections.namedtuple( + "BenchmarkResult", + ( + "generation_time", + "num_tokens", + "max_gpu_mem", + ), + ) +): + pass + + +def build_prompt(sp, inputs): + prompt_tokens = [] + for question in inputs: + input_tokens = [""] + sp.encode_as_pieces( + f"{B_INST} {question.strip()} {E_INST}" + ) + prompt_tokens.append(input_tokens) + return prompt_tokens + + +def count_tokens(generated_token): + count = 0 + for output in generated_token: + count += len(output) + return count + + +def avg_tokens(generated_token): + return count_tokens(generated_token) / len(generated_token) + + +def process_prompt(generator, max_generation_length, generated_token, prompt): + step_results = generator.generate_tokens( + prompt, + max_length=max_generation_length, + sampling_temperature=0.6, + sampling_topk=20, + sampling_topp=1, + ) + for step_result in step_results: + batch_id = step_result.batch_id + generated_token[batch_id].append(step_result.token) + + +def benchmark_generation(generator, + sp, + prompt_tokens, + generated_file, + mode, + batch_size): + max_generation_length = 512 + generated_token = [[] for _ in range(len(prompt_tokens))] + generated_text = ["" for _ in range(len(prompt_tokens))] + tokens_buffer = [] + elapsed_time = None + num_tokens = 0 + + if mode == "sequence": + start_all = time.time() + for i in range(0, len(prompt_tokens), batch_size): + step_results = generator.generate_tokens( + prompt_tokens[i:i + batch_size], + max_length=max_generation_length, + sampling_temperature=0.6, + sampling_topk=20, + sampling_topp=1, + ) + for step_result in step_results: + batch_id = step_result.batch_id + generated_token[batch_id].append(step_result.token) + end_all = time.time() + elapsed_time = end_all - start_all + num_tokens = count_tokens(generated_token) + elif mode == "parallel": + nb_process = len(prompt_tokens) / batch_size + 1 + start_all = time.time() + with concurrent.futures.ThreadPoolExecutor(max_workers=nb_process) as executor: + futures = [executor.submit(process_prompt, generator, max_generation_length, generated_token, + prompt_tokens[index:index + batch_size]) + for index in range(0, len(prompt_tokens), batch_size)] + num_tokens = count_tokens(generated_token) + end_all = 
time.time() + elapsed_time = end_all - start_all + + memory_gpus = float(GPUtil.getGPUs()[0].memoryUsed) + + # save answer to file + for index in range(0, len(generated_token)): + for token in generated_token[index]: + is_new_word = token.startswith("▁") + if is_new_word and tokens_buffer: + word = sp.decode(tokens_buffer) + if word: + if generated_text[index]: + word = ' ' + word + generated_text[index] += word + tokens_buffer = [] + tokens_buffer.append(token) + if tokens_buffer: + word = sp.decode(tokens_buffer) + if generated_text[index]: + word = ' ' + word + generated_text[index] += word + tokens_buffer = [] + + # write result to target file + target_file = os.path.abspath(generated_file) + if ctranslate2.MpiInfo.getCurRank() == 0: + with open(target_file, 'w') as file: + for index in range(len(generated_text)): + file.write(f"answer{index}: ") + file.write(generated_text[index]) + file.write(f"\n\n") + + return BenchmarkResult(elapsed_time, num_tokens, memory_gpus) + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--mode", + choices=["sequence", "parallel"], + default="sequence", + help="benchmark in parallel or sequence mode", + ) + parser.add_argument("--model_path", type=str, help="model path") + parser.add_argument("--src", type=str, help="source file") + parser.add_argument("--target", type=str, help="target file") + parser.add_argument("--batch_size", type=int, help="batch size") + args = parser.parse_args() + + print("Loading the model...") + generator = ctranslate2.Generator(args.model_path, device="cuda", tensor_parallel=True, inter_threads=2) + sp = spm.SentencePieceProcessor(os.path.join(args.model_path, "tokenizer.model")) + + if not os.path.exists(args.src): + raise Exception("No source file found: " + args.src) + # Open the file in read mode + with open(args.src, 'r') as file: + # Read all lines from the file and create a list + inputs = file.readlines() + + prompt_tokens = build_prompt(sp, inputs) + result = benchmark_generation(generator, sp, prompt_tokens, args.target, args.mode, args.batch_size) + if ctranslate2.MpiInfo.getCurRank() == 0: + print("Benchmark result (%d sample(s)):" % len(prompt_tokens)) + print("- Generation time: %.2f s" % result.generation_time) + print("- Number of tokens: %d" % result.num_tokens) + print("- Throughput: %.1f" % (result.num_tokens / result.generation_time)) + print("- max. GPU memory usage: %dMB" % int(result.max_gpu_mem)) + + +if __name__ == "__main__": + main() diff --git a/tools/benchmark_tensor_parallel/requirements.txt b/tools/benchmark_tensor_parallel/requirements.txt new file mode 100644 index 000000000..533257d10 --- /dev/null +++ b/tools/benchmark_tensor_parallel/requirements.txt @@ -0,0 +1,3 @@ +ctranslate2>=4.1.0 +sentencepiece +GPUtil \ No newline at end of file From ac8f7ae9ced73db74b6d161744056ee70a71b86f Mon Sep 17 00:00:00 2001 From: thucpham Date: Fri, 1 Mar 2024 16:28:40 +0100 Subject: [PATCH 2/4] add docs --- README.md | 1 + docs/parallel.md | 44 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5a9f51abd..53fd07430 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,7 @@ The project is production-oriented and comes with [backward compatibility guaran * **Lightweight on disk**
Quantization can make the models 4 times smaller on disk with minimal accuracy loss. * **Simple integration**
The project has few dependencies and exposes simple APIs in [Python](https://opennmt.net/CTranslate2/python/overview.html) and C++ to cover most integration needs. * **Configurable and interactive decoding**
[Advanced decoding features](https://opennmt.net/CTranslate2/decoding.html) allow autocompleting a partial sequence and returning alternatives at a specific location in the sequence. +* **Support tensor parallelism for distributed inference. Some of these features are difficult to achieve with standard deep learning frameworks and are the motivation for this project. diff --git a/docs/parallel.md b/docs/parallel.md index 604fea122..887a1d744 100644 --- a/docs/parallel.md +++ b/docs/parallel.md @@ -42,8 +42,50 @@ Parallelization with multiple Python threads is possible because all computation ``` ## Model and tensor parallelism +Models as the [`Translator`](python/ctranslate2.Translator.rst) and [`Generator`](python/ctranslate2.Generator.rst) can be split into multiple GPUs different. +This is very helpful when the model is too big to be load in only 1 GPU. -These types of parallelism are not yet implemented in CTranslate2. +```python +translator = ctranslate2.Translator(model_path, device="cuda", tensor_parallel=True) +``` + +Setup environment: +* Install [open-mpi](https://www.open-mpi.org/) +* Configure open-mpi by creating the config file like ``hostfile``: +```bash +[ipaddress or dns] slots=nbGPU1 +[other ipaddress or dns] slots=NbGPU2 +``` + +Run: +* Run the application in multiprocess to using tensor parallel: +```bash +mpirun -np nbGPUExpected -hostfile hostfile python3 script +``` + +If you're trying to run the tensor parallelism in multiple machine, there are additional configuration is needed: +* Make sure Master and Slave can connect to each other as a pair with ssh + pubkey +* Export all necessary environment variables from Master to Slave like the example below: +```bash +mpirun -x VIRTUAL_ENV_PROMPT -x PATH -x VIRTUAL_ENV -x _ -x LD_LIBRARY_PATH -np nbGPUExpected -hostfile hostfile python3 script +``` +Read more [open-mpi docs](https://www.open-mpi.org/doc/) for more information. + +* In this mode, the application will be run in multiprocess. We can filter out the master process by using: +```python +if ctranslate2.MpiInfo.getCurRank() == 0: + print(...) +``` + +```{note} +Running model in tensor parallel mode in one machine can boost the performance but if running the model shared between multiple +machine could be slower because of the latency in the connectivity. +``` + +```{note} +In mode tensor parallel, `inter_threads` is always supported to run multiple workers. Otherwise, `device_index` no longer has any effect +because tensor parallel mode will check only available gpus on the system and number of gpu that you want to use. +``` ## Asynchronous execution From 05a2702c7ee456013723b86c82d6700af56cbc26 Mon Sep 17 00:00:00 2001 From: thucpham Date: Mon, 4 Mar 2024 12:23:35 +0100 Subject: [PATCH 3/4] small fix --- docs/parallel.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/parallel.md b/docs/parallel.md index 887a1d744..ba827d7b2 100644 --- a/docs/parallel.md +++ b/docs/parallel.md @@ -42,8 +42,8 @@ Parallelization with multiple Python threads is possible because all computation ``` ## Model and tensor parallelism -Models as the [`Translator`](python/ctranslate2.Translator.rst) and [`Generator`](python/ctranslate2.Generator.rst) can be split into multiple GPUs different. -This is very helpful when the model is too big to be load in only 1 GPU. +Models used with [`Translator`](python/ctranslate2.Translator.rst) and [`Generator`](python/ctranslate2.Generator.rst) can be split into multiple GPUs. 
+This is very useful when the model is too big to be loaded on a single GPU.
```python
translator = ctranslate2.Translator(model_path, device="cuda", tensor_parallel=True)
@@ -58,12 +58,12 @@ Setup environment:
```
Run:
-* Run the application in multiprocess to using tensor parallel:
+* Run the application with multiple processes to use tensor parallelism:
```bash
mpirun -np nbGPUExpected -hostfile hostfile python3 script
```
-If you're trying to run the tensor parallelism in multiple machine, there are additional configuration is needed:
+If you're trying to use tensor parallelism across multiple machines, some additional configuration is needed:
* Make sure Master and Slave can connect to each other as a pair with ssh + pubkey
* Export all necessary environment variables from Master to Slave like the example below:
```bash
@@ -71,20 +71,20 @@ mpirun -x VIRTUAL_ENV_PROMPT -x PATH -x VIRTUAL_ENV -x _ -x LD_LIBRARY_PATH -np
```
Read more [open-mpi docs](https://www.open-mpi.org/doc/) for more information.
-* In this mode, the application will be run in multiprocess. We can filter out the master process by using:
+* In this mode, the application will run as multiple processes. We can filter out the master process by using:
```python
if ctranslate2.MpiInfo.getCurRank() == 0:
  print(...)
```
```{note}
-Running model in tensor parallel mode in one machine can boost the performance but if running the model shared between multiple
-machine could be slower because of the latency in the connectivity.
+Running a model in tensor parallel mode on a single machine can boost performance, but sharing a model between multiple
+machines can be slower because of the network latency between them.
```
```{note}
In mode tensor parallel, `inter_threads` is always supported to run multiple workers. Otherwise, `device_index` no longer has any effect
-because tensor parallel mode will check only available gpus on the system and number of gpu that you want to use.
+because tensor parallel mode will only consider the GPUs available on the system and the number of GPUs you want to use.
```
## Asynchronous execution
From 4ee128f30482c976970c5694d8d16a69e7a08c1f Mon Sep 17 00:00:00 2001
From: minhthuc
Date: Mon, 4 Mar 2024 17:40:47 +0100
Subject: [PATCH 4/4] fix adding bias multiple times in layer output.
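The layers marked with `is_layer_out` are the row-parallel halves of the blocks: each rank computes a partial product and the partial results are then reduced (summed) across ranks, so the bias must be contributed by exactly one rank. A minimal NumPy sketch of that invariant — the sizes and variable names below are made up for illustration and are not taken from the C++ code:

```python
import numpy as np

# Toy sizes, for illustration only.
world_size, in_dim, out_dim = 4, 8, 6
rng = np.random.default_rng(0)

x = rng.standard_normal((2, in_dim))         # activations entering the output projection
W = rng.standard_normal((in_dim, out_dim))   # full weight of the row-parallel layer
b = rng.standard_normal(out_dim)             # bias, replicated on every rank

# Row-parallel split: each rank holds a slice of the input dimension.
x_parts = np.split(x, world_size, axis=1)
W_parts = np.split(W, world_size, axis=0)

# Each rank computes a partial product; only one rank adds the bias.
partials = [x_parts[r] @ W_parts[r] + (b if r == 0 else 0.0)
            for r in range(world_size)]

# The all-reduce (sum) then reproduces the unsplit layer output exactly once.
assert np.allclose(sum(partials), x @ W + b)
```

If every rank added the bias before the reduction, the summed output would contain `world_size` copies of it.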
--- include/ctranslate2/layers/common.h | 4 ++--
src/layers/common.cc | 12 +++++++++---
2 files changed, 11 insertions(+), 5 deletions(-)
diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h
index d06c6da6c..6c69275b8 100644
--- a/include/ctranslate2/layers/common.h
+++ b/include/ctranslate2/layers/common.h
@@ -128,7 +128,7 @@ namespace ctranslate2 {
Dense(const models::Model& model,
const std::string& scope,
const ops::ActivationType* activation_type = nullptr,
- const bool affected_by_tp = false);
+ const bool is_layer_out = false);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& input, StorageView& output) const;
@@ -148,7 +148,7 @@ namespace ctranslate2 {
const ops::Gemm _gemm_op;
const ops::Quantize _quantize_op;
const ops::Dequantize _dequantize_op;
- const bool _affected_by_tp;
+ const bool _is_layer_out;
};
class LayerNorm : public Layer
diff --git a/src/layers/common.cc b/src/layers/common.cc
index 92b7e9cf1..22b2a55bd 100644
--- a/src/layers/common.cc
+++ b/src/layers/common.cc
@@ -266,7 +266,7 @@ namespace ctranslate2 {
Dense::Dense(const models::Model& model,
const std::string& scope,
const ops::ActivationType* activation_type,
- const bool affected_by_tp)
+ const bool is_layer_out)
: _packed_weight(false)
, _weight(get_linear_weight(model, scope, &_packed_weight))
, _bias(model.get_variable_if_exists(scope + "/bias"))
@@ -295,7 +295,7 @@ namespace ctranslate2 {
/*shift_to_uint8=*/bool(_u8_shift_compensation),
/*round_before_cast=*/model.round_before_cast_in_quantization())
, _dequantize_op(activation_type)
- , _affected_by_tp(affected_by_tp)
+ , _is_layer_out(is_layer_out)
{
}
@@ -341,6 +341,8 @@ namespace ctranslate2 {
const StorageView* compensation = (_partial_u8_shift_compensation.empty()
? _u8_shift_compensation
: &_partial_u8_shift_compensation);
+
+ bool affected_by_tp = ScopedMPISetter::getNRanks() > 1 && _is_layer_out;
if (_quantized_gemm) {
const auto device = input.device();
StorageView qinput(_weight.dtype(), device);
@@ -348,7 +350,7 @@ namespace ctranslate2 {
StorageView qoutput(DataType::INT32, device);
const StorageView* pinput = &input;
- if (ScopedMPISetter::getNRanks() > 1 && _affected_by_tp) {
+ if (affected_by_tp) {
StorageView input_reshaped(input.shape(), input.dtype(), input.device());
Shape shape = input.shape();
dim_t batch_size = shape[0];
@@ -381,6 +383,8 @@ namespace ctranslate2 {
}
_gemm_op(qinput, *weight, qoutput, compensation);
+ if (affected_by_tp && ScopedMPISetter::getCurRank() != 0)
+ bias = nullptr;
_dequantize_op(qoutput,
qinput_scale,
*qscale,
@@ -389,6 +393,8 @@ namespace ctranslate2 {
output,
bias);
} else {
+ if (affected_by_tp && ScopedMPISetter::getCurRank() != 0)
+ bias = nullptr;
_gemm_op(input, *weight, output, nullptr, bias);
}
}
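As a closing illustration of the splitting scheme used throughout this series — input projections split column-wise so each rank keeps its own slice of the output features, and output projections split row-wise so the partial results have to be combined with an all-reduce — the underlying identity can be checked numerically. Below is a small, self-contained NumPy sketch of a tensor-parallel feed-forward block; the dimensions, the ReLU activation and the `rank_forward` helper are assumptions made for the example, not code from CTranslate2:

```python
import numpy as np

# Assumed toy dimensions and a ReLU activation, chosen only for the example.
world_size, d_model, d_ff = 2, 8, 16
rng = np.random.default_rng(1)

x = rng.standard_normal((3, d_model))
W1 = rng.standard_normal((d_model, d_ff))   # first linear: split column-wise (output features)
W2 = rng.standard_normal((d_ff, d_model))   # second linear: split row-wise (input features)
b2 = rng.standard_normal(d_model)

W1_parts = np.split(W1, world_size, axis=1)
W2_parts = np.split(W2, world_size, axis=0)

def rank_forward(r):
    # Each rank applies its own shards; the activation is element-wise,
    # so no communication is needed between the two linears.
    h = np.maximum(x @ W1_parts[r], 0.0)
    return h @ W2_parts[r]

# ReduceAll(sum) over the partial outputs, with the bias added once,
# matches the unsplit feed-forward block.
y_parallel = sum(rank_forward(r) for r in range(world_size)) + b2
y_reference = np.maximum(x @ W1, 0.0) @ W2 + b2
assert np.allclose(y_parallel, y_reference)
```

This is the property behind splitting `ffn/linear_0` and `ffn/linear_1` along different dimensions in the model loader while keeping the final output equivalent to the unsplit model, up to floating-point summation order.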