Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/MLX5: a2a various optimizations #1067

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
335 changes: 207 additions & 128 deletions src/components/tl/mlx5/alltoall/alltoall_coll.c

Large diffs are not rendered by default.

25 changes: 16 additions & 9 deletions src/components/tl/mlx5/alltoall/alltoall_mkeys.c
Original file line number Diff line number Diff line change
Expand Up @@ -291,14 +291,15 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
int nbc = req->alltoall.num_of_blocks_columns;
int seq_index = req->alltoall.seq_index;
int repeat_count = nbc ? a2a->net.sbgp->group_size
: UCC_TL_TEAM_SIZE(team) / req->alltoall.block_size;
int n_mkeys = nbc ? nbc : 1;
int repeat_count;
int i;
ucc_status_t status;

if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag &
UCC_MLX5_NEED_SEND_MKEY_UPDATE) {
repeat_count = nbc ? a2a->net.sbgp->group_size
: UCC_TL_TEAM_SIZE(team) / req->alltoall.block_width;
for (i = 0; i < n_mkeys; i++) {
status = populate_strided_mkey(a2a, send_mem_access_flags,
node->ops[seq_index].send_mkeys[i],
Expand All @@ -313,6 +314,9 @@ ucc_status_t ucc_tl_mlx5_populate_send_recv_mkeys(ucc_tl_mlx5_team_t * team,
}
if (ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag &
UCC_MLX5_NEED_RECV_MKEY_UPDATE) {
repeat_count =
nbc ? a2a->net.sbgp->group_size
: UCC_TL_TEAM_SIZE(team) / req->alltoall.block_height;
for (i = 0; i < n_mkeys; i++) {
status = populate_strided_mkey(a2a, recv_mem_access_flags,
node->ops[seq_index].recv_mkeys[i],
Expand All @@ -332,7 +336,8 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a,
ucc_tl_mlx5_schedule_t *req, int direction_send)
{
ucc_tl_mlx5_alltoall_node_t *node = &a2a->node;
int block_size = req->alltoall.block_size;
int block_height = req->alltoall.block_height;
int block_width = req->alltoall.block_width;
size_t msg_size = req->alltoall.msg_size;
int nbc = req->alltoall.num_of_blocks_columns;
struct ibv_mr *buff = direction_send
Expand All @@ -345,26 +350,28 @@ static void update_mkey_entry(ucc_tl_mlx5_alltoall_t *a2a,
mkey_entry = (umr_t *)(direction_send ? MY_SEND_UMR_DATA(req, a2a, 0)
: MY_RECV_UMR_DATA(req, a2a, 0));
mkey_entry->addr = (uintptr_t)buff->addr;
mkey_entry->bytes_count = block_size * msg_size;
mkey_entry->bytes_count =
(direction_send ? block_width : block_height) * msg_size;
mkey_entry->bytes_skip = 0;
mkey_entry->lkey = direction_send ? buff->lkey : buff->rkey;
} else {
for (i = 0; i < nbc; i++) {
ucc_assert(block_height == block_width);
mkey_entry =
(umr_t *)(direction_send ? MY_SEND_UMR_DATA(req, a2a, i)
: MY_RECV_UMR_DATA(req, a2a, i));
mkey_entry->addr =
(uintptr_t)buff->addr + i * (block_size * msg_size);
(uintptr_t)buff->addr + i * (block_height * msg_size);
mkey_entry->bytes_count =
(i == (nbc - 1))
? ((node->sbgp->group_size % block_size) * msg_size)
: (block_size * msg_size);
? ((node->sbgp->group_size % block_height) * msg_size)
: (block_height * msg_size);
mkey_entry->bytes_skip =
(i == (nbc - 1))
? ((node->sbgp->group_size -
(node->sbgp->group_size % block_size)) *
(node->sbgp->group_size % block_height)) *
msg_size)
: ((node->sbgp->group_size - block_size) * msg_size);
: ((node->sbgp->group_size - block_height) * msg_size);
mkey_entry->lkey = direction_send ? buff->lkey : buff->rkey;
}
}
Expand Down
31 changes: 31 additions & 0 deletions src/components/tl/mlx5/tl_mlx5.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_lib_config_t, dm_buf_num),
UCC_CONFIG_TYPE_ULUNITS},

{"FORCE_REGULAR", "y",
"Force the regular case where the block dimensions "
"divide ppn. Requires BLOCK_SIZE=0",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_regular),
UCC_CONFIG_TYPE_BOOL},

{"FORCE_LONGER", "y", "Force the blocks to have more height than width",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_longer),
UCC_CONFIG_TYPE_BOOL},

{"FORCE_WIDER", "n", "Force the blocks to have more width than height",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, force_wider), UCC_CONFIG_TYPE_BOOL},

{"BLOCK_SIZE", "0",
"Size of the blocks that are sent using blocked AlltoAll Algorithm",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_size), UCC_CONFIG_TYPE_UINT},
Expand Down Expand Up @@ -104,6 +117,24 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable),
UCC_CONFIG_TYPE_BOOL},

{"SEND_BATCH_SIZE", "2",
"number of blocks that are transposed "
"on the NIC before being sent as a batch to a remote peer",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, block_batch_size),
UCC_CONFIG_TYPE_UINT},

{"NBR_SERIALIZED_BATCHES", "4",
"number of block batches "
"(within the set of blocks to be sent to a given remote peer) "
"serialized on the same device memory chunk",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_serialized_batches),
UCC_CONFIG_TYPE_UINT},

{"NBR_BATCHES_PER_PASSAGE", "1",
"number of batches of blocks sent to one remote node before enqueing",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, nbr_batches_per_passage),
UCC_CONFIG_TYPE_UINT},

{NULL}};

static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
Expand Down
15 changes: 12 additions & 3 deletions src/components/tl/mlx5/tl_mlx5.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ typedef struct ucc_tl_mlx5_lib_config {
int dm_host;
ucc_tl_mlx5_ib_qp_conf_t qp_conf;
ucc_tl_mlx5_mcast_coll_comm_init_spec_t mcast_conf;
int nbr_serialized_batches;
int nbr_batches_per_passage;
int block_batch_size;
int force_regular;
int force_longer;
int force_wider;
} ucc_tl_mlx5_lib_config_t;

typedef struct ucc_tl_mlx5_context_config {
Expand Down Expand Up @@ -93,10 +99,13 @@ UCC_CLASS_DECLARE(ucc_tl_mlx5_context_t, const ucc_base_context_params_t*,

typedef struct ucc_tl_mlx5_task ucc_tl_mlx5_task_t;
typedef struct ucc_tl_mlx5_schedule ucc_tl_mlx5_schedule_t;
typedef struct ucc_tl_mlx5_dm_chunk {
ptrdiff_t offset; /* 0 based offset from the beginning of
memic_mr (obtained with ibv_reg_dm_mr) */
typedef struct ucc_tl_mlx5_dm_chunk_t {
uintptr_t addr; // 0 based offset from the beginning of
// memic_mr (obtained with ibv_reg_dm_mr)
ucc_tl_mlx5_schedule_t *task;
int posted_sends;
int posted_all;
int completed_sends;
} ucc_tl_mlx5_dm_chunk_t;

typedef struct ucc_tl_mlx5_alltoall ucc_tl_mlx5_alltoall_t;
Expand Down
3 changes: 2 additions & 1 deletion src/components/tl/mlx5/tl_mlx5_coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ typedef struct ucc_tl_mlx5_schedule {
int seq_num;
int seq_index;
int num_of_blocks_columns;
int block_size;
int block_height;
int block_width;
int started;
int send_blocks_enqueued;
int blocks_sent;
Expand Down
22 changes: 15 additions & 7 deletions src/components/tl/mlx5/tl_mlx5_dm.c
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,15 @@ static void ucc_tl_mlx5_dm_chunk_init(ucc_mpool_t *mp, //NOLINT
ucc_tl_mlx5_team_t *team =
ucc_container_of(mp, ucc_tl_mlx5_team_t, dm_pool);

c->offset = (ptrdiff_t)team->dm_offset;
team->dm_offset = PTR_OFFSET(team->dm_offset,
UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size);
c->addr = (uintptr_t)PTR_OFFSET(
(UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_host) ? team->dm_ptr : NULL,
team->dm_offset);
c->posted_sends = 0;
c->posted_all = 0;
c->completed_sends = 0;
team->dm_offset = PTR_OFFSET(
team->dm_offset, UCC_TL_MLX5_TEAM_LIB(team)->cfg.dm_buf_size *
UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size);
}

static ucc_mpool_ops_t ucc_tl_mlx5_dm_ops = {
Expand Down Expand Up @@ -219,13 +225,15 @@ ucc_status_t ucc_tl_mlx5_dm_init(ucc_tl_mlx5_team_t *team)
}

status = ucc_tl_mlx5_dm_alloc_reg(
ctx->shared_ctx, ctx->shared_pd, cfg->dm_host, cfg->dm_buf_size,
&cfg->dm_buf_num, &team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team));
ctx->shared_ctx, ctx->shared_pd, cfg->dm_host,
cfg->dm_buf_size * cfg->block_batch_size, &cfg->dm_buf_num,
&team->dm_ptr, &team->dm_mr, UCC_TL_TEAM_LIB(team));
if (status != UCC_OK) {
goto err_dm_alloc;
}
team->dm_offset = NULL;

team->dm_offset = 0;
// TODO: fix/check the case dm_host=true
ucc_assert(!cfg->dm_host);
status = ucc_mpool_init(
&team->dm_pool, 0, sizeof(ucc_tl_mlx5_dm_chunk_t), 0,
UCC_CACHE_LINE_SIZE, 1, cfg->dm_buf_num, &ucc_tl_mlx5_dm_ops,
Expand Down
3 changes: 2 additions & 1 deletion src/components/tl/mlx5/tl_mlx5_wqe.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ static inline uint8_t get_umr_mr_flags(uint32_t acc)

typedef struct transpose_seg {
__be32 element_size; /* 8 bit value */
__be16 num_rows; /* 7 bit value */
//From PRM we should have the rows first and then the colls. This is probably a naming error
__be16 num_cols; /* 7 bit value */
__be16 num_rows; /* 7 bit value */
__be64 padding;
} transpose_seg_t;

Expand Down
2 changes: 1 addition & 1 deletion test/gtest/tl/mlx5/test_tl_mlx5_wqe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ UCC_TEST_P(test_tl_mlx5_transpose, transposeWqe)

ibv_wr_start(qp.qp_ex);
post_transpose(qp.qp, src_mr->lkey, dst_mr->rkey, (uintptr_t)src,
(uintptr_t)dst, elem_size, nrows, ncols, IBV_SEND_SIGNALED);
(uintptr_t)dst, elem_size, ncols, nrows, IBV_SEND_SIGNALED);
GTEST_ASSERT_EQ(ibv_wr_complete(qp.qp_ex), 0);

while (!completions_num) {
Expand Down
Loading