Skip to content

Commit

Permalink
CODESTYLE: fix alignments and minor comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
samnordmann committed Dec 30, 2024
1 parent 0bf47f8 commit 6cdea8d
Showing 1 changed file with 38 additions and 43 deletions.
81 changes: 38 additions & 43 deletions src/components/tl/mlx5/alltoall/alltoall_coll.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -96,15 +96,17 @@ static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team,
{
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
int seq_index = task->alltoall.seq_index;
int npolls = UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls;
int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix;
int vrank = a2a->node.sbgp->group_rank - a2a->node.asr_rank;
int *dist = &a2a->node.fanin_dist;
int size = a2a->node.sbgp->group_size;
int seq_num = task->alltoall.seq_num;
int c_flag = 0;
int polls;
int peer, vpeer, pos, i;
int npolls =
UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls;
int radix =
UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix;
int vrank =
a2a->node.sbgp->group_rank - a2a->node.asr_rank;
int *dist = &a2a->node.fanin_dist;
int size = a2a->node.sbgp->group_size;
int seq_num = task->alltoall.seq_num;
int c_flag = 0;
int polls, peer, vpeer, pos, i;
ucc_tl_mlx5_alltoall_ctrl_t *ctrl_v;

while (*dist <= a2a->node.fanin_max_dist) {
Expand Down Expand Up @@ -282,12 +284,10 @@ static ucc_status_t ucc_tl_mlx5_fanout_start(ucc_coll_task_t *coll_task)

tl_debug(UCC_TASK_LIB(task), "fanout start");
/* start task if completion event received */
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0);
if (team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank) {
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(
task, "mlx5_alltoall_wait-on-data_start", 0);
} else {
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start",
0);
}
/* Start fanout */
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task);
Expand Down Expand Up @@ -455,34 +455,32 @@ ucc_tl_mlx5_a2a_wait_for_dm_chunk(ucc_tl_mlx5_schedule_t *task)
// add polling mechanism for blocks in order to maintain const qp tx rx
static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task);
ucc_base_lib_t *lib = UCC_TASK_LIB(task);
ucc_tl_mlx5_team_t *team = TASK_TEAM(&task->super);
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
int node_size = a2a->node.sbgp->group_size;
int net_size = a2a->net.sbgp->group_size;
int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) *
a2a->max_num_of_columns;
int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size;
int block_h = task->alltoall.block_height;
int block_w = task->alltoall.block_width;
int col_msgsize = task->alltoall.msg_size * block_w * node_size;
int line_msgsize = task->alltoall.msg_size * block_h * node_size;
int block_msgsize = block_h * block_w * task->alltoall.msg_size;
ucc_status_t status = UCC_OK;
int node_grid_w = node_size / block_w;
int node_nbr_blocks = (node_size * node_size) / (block_h * block_w);
int seq_index = task->alltoall.seq_index;
int block_row = 0, block_col = 0;
uint64_t remote_addr = 0;
ucc_tl_mlx5_dm_chunk_t *dm = NULL;
int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size;
int nbr_serialized_batches =
UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches;
int nbr_batches_per_passage =
UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage;
int i, j, k, send_to_self, block_idx, rank, dest_rank, cyc_rank, node_idx;
uint64_t src_addr;
ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task);
ucc_base_lib_t *lib = UCC_TASK_LIB(task);
ucc_tl_mlx5_team_t *team = TASK_TEAM(&task->super);
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
int node_size = a2a->node.sbgp->group_size;
int net_size = a2a->net.sbgp->group_size;
int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * a2a->max_num_of_columns;
int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size;
int block_h = task->alltoall.block_height;
int block_w = task->alltoall.block_width;
int col_msgsize = task->alltoall.msg_size * block_w * node_size;
int line_msgsize = task->alltoall.msg_size * block_h * node_size;
int block_msgsize = block_h * block_w * task->alltoall.msg_size;
ucc_status_t status = UCC_OK;
int node_grid_w = node_size / block_w;
int node_nbr_blocks = (node_size * node_size) / (block_h * block_w);
int seq_index = task->alltoall.seq_index;
int block_row = 0;
int block_col = 0;
uint64_t remote_addr = 0;
ucc_tl_mlx5_dm_chunk_t *dm = NULL;
int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size;
int nbr_serialized_batches = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches;
int nbr_batches_per_passage = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage;
int i, j, k, send_to_self, block_idx, rank, dest_rank, cyc_rank, node_idx;
uint64_t src_addr;

coll_task->status = UCC_INPROGRESS;
coll_task->super.status = UCC_INPROGRESS;
Expand All @@ -507,9 +505,6 @@ static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task)
if (!send_to_self &&
task->alltoall.op->blocks_sent[cyc_rank] < node_nbr_blocks) {
dm = ucc_tl_mlx5_a2a_wait_for_dm_chunk(task);
if (status != UCC_OK) {
return status;
}
}
send_start(team, cyc_rank);
for (i = 0; i < nbr_serialized_batches; i++) {
Expand Down

0 comments on commit 6cdea8d

Please sign in to comment.