support tensor parallel #1599

Merged 4 commits on Mar 5, 2024

Changes from all commits
36 changes: 34 additions & 2 deletions CMakeLists.txt
@@ -20,6 +20,7 @@ option(ENABLE_PROFILING "Compile with profiling support" OFF)
option(BUILD_CLI "Compile the clients" ON)
option(BUILD_TESTS "Compile the tests" OFF)
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)

if(ENABLE_PROFILING)
message(STATUS "Enable profiling support")
@@ -179,6 +180,8 @@ set(SOURCES
src/ops/topp_mask.cc
src/ops/topp_mask_cpu.cc
src/ops/transpose.cc
src/ops/nccl_ops.cc
src/ops/nccl_ops_cpu.cc
src/padder.cc
src/profiler.cc
src/random.cc
@@ -191,7 +194,7 @@
src/utils.cc
src/vocabulary.cc
src/vocabulary_map.cc
)
)
set(LIBRARIES
${CMAKE_THREAD_LIBS_INIT}
spdlog::spdlog_header_only
@@ -419,6 +422,24 @@ endif()

if (WITH_CUDA)
find_package(CUDA 11.0 REQUIRED)
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
if (WITH_TENSOR_PARALLEL)
find_package(MPI REQUIRED)
find_package(NCCL REQUIRED)
include_directories(${NCCL_INCLUDE_DIR})
include_directories(${MPI_INCLUDE_PATH})
if(CUDA_DYNAMIC_LOADING)
list(APPEND SOURCES src/cuda/mpi_stub.cc)
list(APPEND SOURCES src/cuda/nccl_stub.cc)
add_definitions(-DCT2_WITH_CUDA_DYNAMIC_LOADING)
else ()
list(APPEND LIBRARIES ${NCCL_LIBRARY})
list(APPEND LIBRARIES ${MPI_LIBRARIES})
endif ()
add_definitions(-DCT2_WITH_TENSOR_PARALLEL)
endif ()
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include)

add_definitions(-DCT2_WITH_CUDA)
if(MSVC)
if(BUILD_SHARED_LIBS)
@@ -522,7 +543,8 @@ if (WITH_CUDA)
src/ops/topk_gpu.cu
src/ops/topp_mask_gpu.cu
src/ops/quantize_gpu.cu
)
src/ops/nccl_ops_gpu.cu
)
elseif(WITH_CUDNN)
message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
else()
@@ -546,6 +568,10 @@ target_include_directories(${PROJECT_NAME} BEFORE
PRIVATE ${PRIVATE_INCLUDE_DIRECTORIES}
)

if (WITH_TENSOR_PARALLEL AND CUDA_DYNAMIC_LOADING)
target_compile_options(${PROJECT_NAME} PRIVATE -DOMPI_SKIP_MPICXX)
endif()

if(BUILD_TESTS)
add_subdirectory(tests)
endif()
@@ -587,6 +613,11 @@ configure_file(cmake/${PROJECT_NAME}Config.cmake
COPYONLY
)

configure_file(cmake/FindNCCL.cmake
"${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}/FindNCCL.cmake"
COPYONLY
)

set(ConfigPackageLocation ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME})

if(BUILD_SHARED_LIBS)
@@ -603,6 +634,7 @@ endif()
install(
FILES
cmake/${PROJECT_NAME}Config.cmake
cmake/FindNCCL.cmake
"${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}/${PROJECT_NAME}ConfigVersion.cmake"
DESTINATION
${ConfigPackageLocation}
1 change: 1 addition & 0 deletions README.md
@@ -34,6 +34,7 @@ The project is production-oriented and comes with [backward compatibility guaran
* **Lightweight on disk**<br/>Quantization can make the models 4 times smaller on disk with minimal accuracy loss.
* **Simple integration**<br/>The project has few dependencies and exposes simple APIs in [Python](https://opennmt.net/CTranslate2/python/overview.html) and C++ to cover most integration needs.
* **Configurable and interactive decoding**<br/>[Advanced decoding features](https://opennmt.net/CTranslate2/decoding.html) allow autocompleting a partial sequence and returning alternatives at a specific location in the sequence.
* **Tensor parallelism for distributed inference**<br/>Large models can be split across multiple GPUs when they do not fit on a single device.

Some of these features are difficult to achieve with standard deep learning frameworks and are the motivation for this project.

28 changes: 28 additions & 0 deletions cmake/FindNCCL.cmake
@@ -0,0 +1,28 @@
# Find the NCCL libraries
#
# The following variables are optionally searched for defaults
# NCCL_ROOT_DIR: Base directory where all NCCL components are found
#
# The following are set after configuration is done:
# NCCL_FOUND
# NCCL_INCLUDE_DIR
# NCCL_LIBRARY

find_path(NCCL_INCLUDE_DIR NAMES nccl.h
PATHS ${NCCL_ROOT_DIR}/include
)

find_library(NCCL_LIBRARY NAMES nccl
PATHS ${NCCL_ROOT_DIR}/lib ${NCCL_ROOT_DIR}/lib64)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR
NCCL_LIBRARY)

if (NCCL_FOUND)
message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIR}, library: ${NCCL_LIBRARY})")
mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY)
set(NCCL_VERSION "${NCCL_MAJOR}.${NCCL_MINOR}.${NCCL_PATCH}")

endif ()
20 changes: 17 additions & 3 deletions docker/Dockerfile
@@ -35,6 +35,16 @@ RUN wget -q https://github.com/oneapi-src/oneDNN/archive/refs/tags/v${ONEDNN_VER
cd .. && \
rm -r oneDNN-*

ENV OPENMPI_VERSION=4.1.6
RUN wget -q https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-${OPENMPI_VERSION}.tar.bz2 && \
tar xf *.tar.bz2 && \
rm *.tar.bz2 && \
cd openmpi-* && \
./configure && \
make -j$(nproc) install && \
cd .. && \
rm -r openmpi-*

COPY third_party third_party
COPY cli cli
COPY include include
@@ -50,13 +60,14 @@ ENV CUDA_NVCC_FLAGS=${CUDA_NVCC_FLAGS:-"-Xfatbin=-compress-all"}
ARG CUDA_ARCH_LIST
ENV CUDA_ARCH_LIST=${CUDA_ARCH_LIST:-"Common"}
ENV CTRANSLATE2_ROOT=/opt/ctranslate2
ENV LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH}

RUN mkdir build && \
cd build && \
RUN mkdir build_tmp && \
cd build_tmp && \
cmake -DCMAKE_INSTALL_PREFIX=${CTRANSLATE2_ROOT} \
-DWITH_CUDA=ON -DWITH_CUDNN=ON -DWITH_MKL=ON -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP \
-DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="${CXX_FLAGS}" \
-DCUDA_NVCC_FLAGS="${CUDA_NVCC_FLAGS}" -DCUDA_ARCH_LIST="${CUDA_ARCH_LIST}" .. && \
-DCUDA_NVCC_FLAGS="${CUDA_NVCC_FLAGS}" -DCUDA_ARCH_LIST="${CUDA_ARCH_LIST}" -DWITH_TENSOR_PARALLEL=ON .. && \
VERBOSE=1 make -j$(nproc) install

ENV LANG=en_US.UTF-8
@@ -74,6 +85,9 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends \
libcublas-12-2 \
libcudnn8=8.9.7.29-1+cuda12.2 \
libnccl2=2.19.3-1+cuda12.2 \
libopenmpi3=4.0.3-0ubuntu1 \
openmpi-bin \
libgomp1 \
python3-pip \
&& \
44 changes: 43 additions & 1 deletion docs/parallel.md
@@ -42,8 +42,50 @@ Parallelization with multiple Python threads is possible because all computation
```

## Model and tensor parallelism
Models used with [`Translator`](python/ctranslate2.Translator.rst) and [`Generator`](python/ctranslate2.Generator.rst) can be split across multiple GPUs.
This is useful when the model is too big to fit on a single GPU.

These types of parallelism are not yet implemented in CTranslate2.
```python
translator = ctranslate2.Translator(model_path, device="cuda", tensor_parallel=True)
```

Set up the environment:
* Install [Open MPI](https://www.open-mpi.org/)
* Configure Open MPI by creating a host file, for example ``hostfile``:
```bash
[ip address or hostname] slots=nbGPU1
[other ip address or hostname] slots=nbGPU2
```

Run:
* Run the application with multiple processes to use tensor parallelism:
```bash
mpirun -np nbGPUExpected -hostfile hostfile python3 script
```

To use tensor parallelism across multiple machines, some additional configuration is needed:
* Make sure the master and the worker hosts can connect to each other over SSH with public-key authentication
* Export all necessary environment variables from the master to the workers, as in the example below:
```bash
mpirun -x VIRTUAL_ENV_PROMPT -x PATH -x VIRTUAL_ENV -x _ -x LD_LIBRARY_PATH -np nbGPUExpected -hostfile hostfile python3 script
```
See the [Open MPI documentation](https://www.open-mpi.org/doc/) for more information.

* In this mode, the application runs as multiple processes. Output such as printing results can be restricted to the master process with:
```python
if ctranslate2.MpiInfo.getCurRank() == 0:
print(...)
```

```{note}
Running the model in tensor parallel mode on a single machine can boost performance, but sharing the model
across multiple machines may be slower because of network latency.
```

```{note}
In tensor parallel mode, `inter_threads` is still supported to run multiple workers. However, `device_index` no longer has any effect
because tensor parallel mode only considers the GPUs available on the system and the number of GPUs you want to use.
```

## Asynchronous execution

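For reference, here is a minimal end-to-end sketch of the tensor-parallel workflow documented in docs/parallel.md above. It is not part of this pull request; the model directory, input tokens, and script name are placeholders.

```python
# A minimal sketch, assuming a converted model directory "ende_ctranslate2/"
# and pre-tokenized input; launch with one process per GPU, e.g.:
#   mpirun -np 2 -hostfile hostfile python3 tp_translate.py
import ctranslate2

translator = ctranslate2.Translator(
    "ende_ctranslate2/",   # placeholder: path to a converted CTranslate2 model
    device="cuda",
    tensor_parallel=True,  # shard the model weights across the visible GPUs
)

results = translator.translate_batch([["▁Hello", "▁world", "!"]])

# Every rank runs the same script; restrict printing to the master process.
if ctranslate2.MpiInfo.getCurRank() == 0:
    print(results[0].hypotheses[0])
```

In this sketch the number of model shards follows the number of processes launched by `mpirun`, which is why `device_index` has no effect in this mode, as the note above explains.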
30 changes: 30 additions & 0 deletions include/ctranslate2/devices.h
@@ -2,6 +2,10 @@

#include <stdexcept>
#include <string>
#include <vector>
#ifdef CT2_WITH_TENSOR_PARALLEL
# include <nccl.h>
#endif

namespace ctranslate2 {

@@ -45,4 +49,30 @@ namespace ctranslate2 {
int _new_index;
};

extern int my_rank;
extern int local_rank;
extern int n_ranks;

class ScopedMPISetter {
public:
ScopedMPISetter();
~ScopedMPISetter();

static int getNRanks();
static int getCurRank();
static int getLocalRank();

#ifdef CT2_WITH_TENSOR_PARALLEL
static ncclComm_t getNcclComm();
#endif

static void finalize();

private:
#ifdef CT2_WITH_TENSOR_PARALLEL
static uint64_t getHostHash(const char *string);
static void getHostName(char *hostname, int maxlen);
static std::vector<ncclComm_t*> _nccl_comms;
#endif
};
}
4 changes: 3 additions & 1 deletion include/ctranslate2/layers/attention.h
@@ -43,7 +43,7 @@ namespace ctranslate2 {
}

bool multi_query() const {
return _num_heads_kv == 1;
return _multi_query;
}

static StorageView prepare_length_mask(const StorageView& lengths,
@@ -53,6 +53,7 @@
const bool multi_query = false);

private:
const bool _tensor_parallel;
const dim_t _num_heads;
const bool _self_attention;
const bool _is_decoder;
@@ -68,6 +69,7 @@
const StorageView* _relative_position_values;
dim_t _maximum_relative_position;
const float _queries_scale;
const bool _multi_query;
const dim_t _num_heads_kv;
const bool _merge_time_and_head_dims;
const dim_t _cache_time_dim;
4 changes: 3 additions & 1 deletion include/ctranslate2/layers/common.h
@@ -127,7 +127,8 @@ namespace ctranslate2 {
public:
Dense(const models::Model& model,
const std::string& scope,
const ops::ActivationType* activation_type = nullptr);
const ops::ActivationType* activation_type = nullptr,
const bool is_layer_out = false);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& input, StorageView& output) const;
@@ -147,6 +148,7 @@
const ops::Gemm _gemm_op;
const ops::Quantize _quantize_op;
const ops::Dequantize _dequantize_op;
const bool _is_layer_out;
};

class LayerNorm : public Layer
3 changes: 3 additions & 0 deletions include/ctranslate2/layers/transformer.h
@@ -34,6 +34,7 @@ namespace ctranslate2 {
const Dense _ff1;
const std::unique_ptr<const Dense> _ff1_noact;
const Dense _ff2;
const bool _tensor_parallel;
};

class TransformerEncoderLayer : public Layer
@@ -149,6 +150,7 @@
const std::unique_ptr<const LayerNorm> _output_norm;
const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
const std::unique_ptr<PositionEncoder> _position_encoder;
const bool _tensor_parallel;
};

class TransformerDecoder : public Decoder
@@ -211,6 +213,7 @@
bool _average_alignment_heads;
Dense _proj;
const dim_t _sliding_window;
const bool _tensor_parallel;
};

}
12 changes: 10 additions & 2 deletions include/ctranslate2/models/model.h
@@ -26,11 +26,13 @@ namespace ctranslate2 {
static std::shared_ptr<const Model> load(const std::string& path,
Device device = Device::CPU,
int device_index = 0,
ComputeType compute_type = ComputeType::DEFAULT);
ComputeType compute_type = ComputeType::DEFAULT,
bool tensor_parallel = false);
static std::shared_ptr<const Model> load(ModelReader& model_reader,
Device device = Device::CPU,
int device_index = 0,
ComputeType compute_type = ComputeType::DEFAULT);
ComputeType compute_type = ComputeType::DEFAULT,
bool tensor_parallel = false);

virtual std::unique_ptr<SequenceToSequenceReplica> as_sequence_to_sequence() const;
virtual std::unique_ptr<SequenceGeneratorReplica> as_sequence_generator() const;
@@ -78,6 +80,10 @@ namespace ctranslate2 {
return _binary_version >= 5;
}

bool tensor_parallel() const {
return _tensor_parallel;
}

virtual bool use_global_int16_scale() const {
return true;
}
@@ -163,6 +169,7 @@
ComputeType _effective_compute_type = ComputeType::DEFAULT;
dim_t _preferred_size_multiple = 1;
std::unordered_map<std::string, std::shared_ptr<StorageView>> _variable_index;
bool _tensor_parallel = false;
};

template<>
@@ -191,6 +198,7 @@
std::vector<int> device_indices = {0};
size_t num_replicas_per_device = 1;
ComputeType compute_type = ComputeType::DEFAULT;
bool tensor_parallel = false;
};

// Base class for replicas.