diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 17735c7867..2a1af759b5 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -96,17 +96,14 @@ static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team, { ucc_tl_mlx5_alltoall_t *a2a = team->a2a; int seq_index = task->alltoall.seq_index; - int npolls = - UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls; - int radix = - UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; - int vrank = - a2a->node.sbgp->group_rank - a2a->node.asr_rank; - int *dist = &a2a->node.fanin_dist; - int size = a2a->node.sbgp->group_size; - int seq_num = task->alltoall.seq_num; - int c_flag = 0; - int polls, peer, vpeer, pos, i; + int npolls = UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls; + int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix; + int vrank = a2a->node.sbgp->group_rank - a2a->node.asr_rank; + int *dist = &a2a->node.fanin_dist; + int size = a2a->node.sbgp->group_size; + int seq_num = task->alltoall.seq_num; + int c_flag = 0; + int polls, peer, vpeer, pos, i; ucc_tl_mlx5_alltoall_ctrl_t *ctrl_v; while (*dist <= a2a->node.fanin_max_dist) { @@ -455,31 +452,34 @@ ucc_tl_mlx5_a2a_wait_for_dm_chunk(ucc_tl_mlx5_schedule_t *task) // add polling mechanism for blocks in order to maintain const qp tx rx static ucc_status_t ucc_tl_mlx5_send_blocks_start(ucc_coll_task_t *coll_task) { - ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); - ucc_base_lib_t *lib = UCC_TASK_LIB(task); - ucc_tl_mlx5_team_t *team = TASK_TEAM(&task->super); - ucc_tl_mlx5_alltoall_t *a2a = team->a2a; - int node_size = a2a->node.sbgp->group_size; - int net_size = a2a->net.sbgp->group_size; - int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * a2a->max_num_of_columns; - int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; - int block_h = task->alltoall.block_height; - int block_w = task->alltoall.block_width; - int col_msgsize = task->alltoall.msg_size * block_w * node_size; - int line_msgsize = task->alltoall.msg_size * block_h * node_size; - int block_msgsize = block_h * block_w * task->alltoall.msg_size; - ucc_status_t status = UCC_OK; - int node_grid_w = node_size / block_w; - int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); - int seq_index = task->alltoall.seq_index; - int block_row = 0; - int block_col = 0; - uint64_t remote_addr = 0; - ucc_tl_mlx5_dm_chunk_t *dm = NULL; - int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; - int nbr_serialized_batches = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; - int nbr_batches_per_passage = UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; - int i, j, k, send_to_self, block_idx, rank, dest_rank, cyc_rank, node_idx; + ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task); + ucc_base_lib_t * lib = UCC_TASK_LIB(task); + ucc_tl_mlx5_team_t * team = TASK_TEAM(&task->super); + ucc_tl_mlx5_alltoall_t *a2a = team->a2a; + int node_size = a2a->node.sbgp->group_size; + int net_size = a2a->net.sbgp->group_size; + int op_msgsize = node_size * a2a->max_msg_size * UCC_TL_TEAM_SIZE(team) * + a2a->max_num_of_columns; + int node_msgsize = SQUARED(node_size) * task->alltoall.msg_size; + int block_h = task->alltoall.block_height; + int block_w = task->alltoall.block_width; + int col_msgsize = task->alltoall.msg_size * block_w * node_size; + int line_msgsize = task->alltoall.msg_size * block_h * node_size; + int block_msgsize = block_h * block_w * task->alltoall.msg_size; + ucc_status_t status = UCC_OK; + int node_grid_w = node_size / block_w; + int node_nbr_blocks = (node_size * node_size) / (block_h * block_w); + int seq_index = task->alltoall.seq_index; + int block_row = 0; + int block_col = 0; + uint64_t remote_addr = 0; + ucc_tl_mlx5_dm_chunk_t *dm = NULL; + int batch_size = UCC_TL_MLX5_TEAM_LIB(team)->cfg.block_batch_size; + int nbr_serialized_batches = + UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_serialized_batches; + int nbr_batches_per_passage = + UCC_TL_MLX5_TEAM_LIB(team)->cfg.nbr_batches_per_passage; + int i, j, k, send_to_self, block_idx, rank, dest_rank, cyc_rank, node_idx; uint64_t src_addr; coll_task->status = UCC_INPROGRESS;