Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Extend MPI dialect #123255

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions mlir/include/mlir/Dialect/MPI/IR/MPI.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,43 @@ def MPI_ErrorClassAttr : EnumAttr<MPI_Dialect, MPI_ErrorClassEnum, "errclass"> {
let assemblyFormat = "`<` $value `>`";
}

def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">;
def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">;
def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">;
def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">;
def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">;
def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">;
def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">;
def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">;
def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">;
def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">;
def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">;
def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">;
def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">;
def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">;

def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [
MPI_OpNull,
MPI_OpMax,
MPI_OpMin,
MPI_OpSum,
MPI_OpProd,
MPI_OpLand,
MPI_OpBand,
MPI_OpLor,
MPI_OpBor,
MPI_OpLxor,
MPI_OpBxor,
MPI_OpMinloc,
MPI_OpMaxloc,
MPI_OpReplace
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::mpi";
}

def MPI_OpClassAttr : EnumAttr<MPI_Dialect, MPI_OpClassEnum, "opclass"> {
let assemblyFormat = "`<` $value `>`";
}

#endif // MLIR_DIALECT_MPI_IR_MPI_TD
242 changes: 225 additions & 17 deletions mlir/include/mlir/Dialect/MPI/IR/MPIOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,73 @@ def MPI_InitOp : MPI_Op<"init", []> {

def MPI_CommRankOp : MPI_Op<"comm_rank", []> {
let summary = "Get the current rank, equivalent to "
"`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`";
"`MPI_Comm_rank(comm, &rank)`";
let description = [{
Communicators other than `MPI_COMM_WORLD` are not supported for now.
If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins Optional<MPI_Comm> : $comm);

let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $rank
);

let assemblyFormat = "attr-dict `:` type(results)";
let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
// CommSizeOp
//===----------------------------------------------------------------------===//

def MPI_CommSizeOp : MPI_Op<"comm_size", []> {
let summary = "Get the size of the group associated to the communicator, "
"equivalent to `MPI_Comm_size(comm, &size)`";
let description = [{
If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins Optional<MPI_Comm> : $comm);

let results = (
outs Optional<MPI_Retval> : $retval,
I32 : $size
);

let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
// CommSplitOp
//===----------------------------------------------------------------------===//

def MPI_CommSplit : MPI_Op<"comm_split", []> {
let summary = "Partition the group associated to the given communicator into "
"disjoint subgroups";
let description = [{
This operation splits the communicator into multiple sub-communicators.
The color value determines the group of processes that will be part of the
new communicator. The key value determines the rank of the calling process
in the new communicator.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key);

let results = (
outs Optional<MPI_Retval> : $retval,
MPI_Comm : $newcomm
);

let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` type(results)";
}

//===----------------------------------------------------------------------===//
Expand All @@ -65,59 +118,217 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> {

def MPI_SendOp : MPI_Op<"send", []> {
let summary =
"Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`";
"Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
`dest`. The `tag` value and communicator enables the library to determine
the matching of multiple sends and receives between the same ranks.

Communicators other than `MPI_COMM_WORLD` are not supprted for now.
If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank,
Optional<MPI_Comm> : $comm
);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)? `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($rank)"
"(`->` type($retval)^)?";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// ISendOp
//===----------------------------------------------------------------------===//

def MPI_ISendOp : MPI_Op<"isend", []> {
let summary =
"Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`";
let description = [{
MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to
rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank,
Optional<MPI_Comm> : $comm
);

let results = (outs Optional<MPI_Retval>:$retval, MPI_Request : $req);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank) "
"(`,` type($comm))? `->` (type($retval) `,` ^)? type($req)";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// RecvOp
//===----------------------------------------------------------------------===//

def MPI_RecvOp : MPI_Op<"recv", []> {
let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, "
"MPI_COMM_WORLD, MPI_STATUS_IGNORE)`";
"comm, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Recv performs a blocking receive of `size` elements of type `dtype`
from rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

Communicators other than `MPI_COMM_WORLD` are not supprted for now.
If communicator is not specified, `MPI_COMM_WORLD` is used by default.
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank);
let arguments = (
ins AnyMemRef : $ref,
I32 : $tag, I32 : $rank,
Optional<MPI_Comm> : $comm
);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` "
"type($ref) `,` type($tag) `,` type($rank)"
"(`->` type($retval)^)?";
let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank) "
"(`,` type($comm))? (`->` type($retval)^)?";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// IRecvOp
//===----------------------------------------------------------------------===//

def MPI_IRecvOp : MPI_Op<"irecv", []> {
let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, "
"comm, &req)`";
let description = [{
MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype`
from rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $ref,
I32 : $tag,
I32 : $rank,
Optional<MPI_Comm> : $comm
);

let results = (outs Optional<MPI_Retval>:$retval, MPI_Request : $req);

let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict "
"`:` type($ref) `,` type($tag) `,` type($rank)"
"(`,` type($comm))? `->` (type($retval) `,` ^)? type($req)";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

def MPI_AllReduceOp : MPI_Op<"allreduce", []> {
let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`";
let description = [{
MPI_Allreduce performs a reduction operation on the values in the sendbuf
array and stores the result in the recvbuf array. The operation is
performed across all processes in the communicator.

The `op` attribute specifies the reduction operation to be performed.
Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are
supported.

If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (
ins AnyMemRef : $sendbuf,
AnyMemRef : $recvbuf,
MPI_OpClassAttr : $op,
Optional<MPI_Comm> : $comm
);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm)?`)` "
"attr-dict `:` type($sendbuf) `,` type($recvbuf) `,` "
"type($op) (`,` type($comm))? (`->` type($retval)^)?";
let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//

def MPI_Barrier : MPI_Op<"barrier", []> {
let summary = "Equivalent to `MPI_Barrier(comm)`";
let description = [{
MPI_Barrier blocks execution until all processes in the communicator have
reached this routine.

If communicator is not specified, `MPI_COMM_WORLD` is used by default.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins Optional<MPI_Comm> : $comm);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type($retval)^";
}

//===----------------------------------------------------------------------===//
// WaitOp
//===----------------------------------------------------------------------===//

def MPI_Wait : MPI_Op<"wait", []> {
let summary = "Equivalent to `MPI_Wait(req, MPI_STATUS_IGNORE)`";
let description = [{
MPI_Wait blocks execution until the request has completed.

The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
}];

let arguments = (ins MPI_Request : $req);

let results = (outs Optional<MPI_Retval>:$retval);

let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) `->` type($retval)^";
}

//===----------------------------------------------------------------------===//
// FinalizeOp
Expand All @@ -139,7 +350,6 @@ def MPI_FinalizeOp : MPI_Op<"finalize", []> {
let assemblyFormat = "attr-dict (`:` type($retval)^)?";
}


//===----------------------------------------------------------------------===//
// RetvalCheckOp
//===----------------------------------------------------------------------===//
Expand All @@ -163,10 +373,8 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> {
let assemblyFormat = "$val `=` $errclass attr-dict `:` type($res)";
}



//===----------------------------------------------------------------------===//
// RetvalCheckOp
// ErrorClassOp
//===----------------------------------------------------------------------===//

def MPI_ErrorClassOp : MPI_Op<"error_class", []> {
Expand Down
Loading