From 1fbae54306aff0c55ec544675bedb961574c2837 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 16 Jan 2025 22:31:59 +0100 Subject: [PATCH 01/42] Add `MPI_Comm`, `MPI_Request`, `MPI_Status`, `MPI_Op` type definitions --- mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 30 ++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td index 87eefa719d45c0..1d96b49d16585b 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -40,4 +40,34 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> { }]; } +// TODO +def MPI_Comm : MPI_Type<"Comm", "comm"> { + let summary = "..." + let description = [{ + This type represents a handler to the MPI communicator. + }] +} + +// TODO +def MPI_Request : MPI_Type<"Request", "request"> { + let summary = "..." + let description = [{ + This type represents a handler to an asynchronous requests. + }] +} + +// TODO +def MPI_Status : MPI_Type<"Status", "status"> { + let summary = ""; + let description = [{ + }]; +} + +// TODO +def MPI_Op : MPI_Type<"Op", "op"> { + let summary = ""; + let description = [{ + }]; +} + #endif // MLIR_DIALECT_MPI_IR_MPITYPES_TD From dc84ca4b87412dcfa9c83c48ae9916663c7ca38e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 16 Jan 2025 23:02:23 +0100 Subject: [PATCH 02/42] Add `MPI_CommSize`, `MPI_ISend`, `MPI_IRecv` ops --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 85 ++++++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 240fac5104c34f..8719b67cd7f5f0 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -59,6 +59,28 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { let assemblyFormat = "attr-dict `:` type(results)"; } +//===----------------------------------------------------------------------===// +// CommSizeOp +//===----------------------------------------------------------------------===// + +def MPI_CommSizeOp : MPI_Op<"comm_size", []> { + let summary = "Get the size of the group associated to the communicator, equivalent to " + "`MPI_Comm_size(MPI_COMM_WORLD, &size)`"; + let description = [{ + Communicators other than `MPI_COMM_WORLD` are not supported for now. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let results = ( + outs Optional : $retval, + I32 : $size + ); + + let assemblyFormat = "attr-dict `:` type(results)"; +} + //===----------------------------------------------------------------------===// // SendOp //===----------------------------------------------------------------------===// @@ -87,6 +109,37 @@ def MPI_SendOp : MPI_Op<"send", []> { let hasCanonicalizer = 1; } +//===----------------------------------------------------------------------===// +// ISendOp +//===----------------------------------------------------------------------===// + +// TODO what about request handler? +// NOTE datatype & count args are implicit by the type of the first argument (i.e. memref of eltype) +// NOTE other communicators not yet supported by the `mpi` dialect +def MPI_ISendOp : MPI_Op<"isend", []> { + let summary = + "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`"; + let description = [{ + MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to rank + `dest`. The `tag` value and communicator enables the library to determine + the matching of multiple sends and receives between the same ranks. + + Communicators other than `MPI_COMM_WORLD` are not supprted for now. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank); + + let results = (outs Optional:$retval); + + let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " + "type($ref) `,` type($tag) `,` type($rank)" + "(`->` type($retval)^)?"; + let hasCanonicalizer = 1; +} + //===----------------------------------------------------------------------===// // RecvOp //===----------------------------------------------------------------------===// @@ -118,6 +171,38 @@ def MPI_RecvOp : MPI_Op<"recv", []> { let hasCanonicalizer = 1; } +//===----------------------------------------------------------------------===// +// IRecvOp +//===----------------------------------------------------------------------===// + +// TODO same as MPI_ISendOp +def MPI_IRecvOp : MPI_Op<"irecv", []> { + let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, " + "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`"; + let description = [{ + MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` + from rank `dest`. The `tag` value and communicator enables the library to + determine the matching of multiple sends and receives between the same + ranks. + + Communicators other than `MPI_COMM_WORLD` are not supprted for now. + The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object + is not yet ported to MLIR. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank); + + let results = (outs Optional:$retval); + + let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " + "type($ref) `,` type($tag) `,` type($rank)" + "(`->` type($retval)^)?"; + let hasCanonicalizer = 1; +} + //===----------------------------------------------------------------------===// // FinalizeOp From 2ee10ab60bb793b9164b0795e6657ceccc4704ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 16 Jan 2025 23:02:34 +0100 Subject: [PATCH 03/42] Fix typo --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 8719b67cd7f5f0..4be5a6dfea7777 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -251,7 +251,7 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> { //===----------------------------------------------------------------------===// -// RetvalCheckOp +// ErrorClassOp //===----------------------------------------------------------------------===// def MPI_ErrorClassOp : MPI_Op<"error_class", []> { From 539bf43b5cf705e64183d7d84f08e7132dac3872 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 25 Jan 2025 20:59:08 +0100 Subject: [PATCH 04/42] Finish types --- mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 24 ++++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td index 1d96b49d16585b..20cde07d9a4b98 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -40,7 +40,10 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> { }]; } -// TODO +//===----------------------------------------------------------------------===// +// mpi::CommType +//===----------------------------------------------------------------------===// + def MPI_Comm : MPI_Type<"Comm", "comm"> { let summary = "..." let description = [{ @@ -48,25 +51,36 @@ def MPI_Comm : MPI_Type<"Comm", "comm"> { }] } -// TODO +//===----------------------------------------------------------------------===// +// mpi::RequestType +//===----------------------------------------------------------------------===// + def MPI_Request : MPI_Type<"Request", "request"> { let summary = "..." let description = [{ - This type represents a handler to an asynchronous requests. + This type represents a handler to an asynchronous request. }] } -// TODO +//===----------------------------------------------------------------------===// +// mpi::StatusType +//===----------------------------------------------------------------------===// + def MPI_Status : MPI_Type<"Status", "status"> { let summary = ""; let description = [{ + This type represents the status of a reception operation. }]; } -// TODO +//===----------------------------------------------------------------------===// +// mpi::OpType +//===----------------------------------------------------------------------===// + def MPI_Op : MPI_Type<"Op", "op"> { let summary = ""; let description = [{ + This type represents a handle to a operation that can be used in MPI reduce and scan routines. }]; } From 662998d610c56eb83dd9897a94b8bd4131e95191 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 26 Jan 2025 11:35:36 +0100 Subject: [PATCH 05/42] Define `MPI_Op` enum & attr --- mlir/include/mlir/Dialect/MPI/IR/MPI.td | 40 +++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td index 643612e1e2ee89..182de03a5a8057 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td @@ -215,4 +215,44 @@ def MPI_ErrorClassAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } +// TODO is it ok to have them as I32? +def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">; +def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">; +def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">; +def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">; +def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">; +def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">; +def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">; +def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">; +def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">; +def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">; +def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">; +def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">; +def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">; +def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">; + +def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [ + MPI_OpNull, + MPI_OpMax, + MPI_OpMin, + MPI_OpSum, + MPI_OpProd, + MPI_OpLand, + MPI_OpBand, + MPI_OpLor, + MPI_OpBor, + MPI_OpLxor, + MPI_OpBxor, + MPI_OpMinloc, + MPI_OpMaxloc, + MPI_OpReplace + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mpi"; +} + +def MPI_OpClassAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + #endif // MLIR_DIALECT_MPI_IR_MPI_TD From c1ec63c24ff9b1af1a4dde393b1e90767605044b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 26 Jan 2025 20:18:36 +0100 Subject: [PATCH 06/42] Add communicator argument to mpi ops as optional input argument --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 38 ++++++++++++---------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 4be5a6dfea7777..1330313c41a8c3 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -43,14 +43,16 @@ def MPI_InitOp : MPI_Op<"init", []> { def MPI_CommRankOp : MPI_Op<"comm_rank", []> { let summary = "Get the current rank, equivalent to " - "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`"; + "`MPI_Comm_rank(comm, &rank)`"; let description = [{ - Communicators other than `MPI_COMM_WORLD` are not supported for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; + let arguments = (ins Optional : $comm); + let results = ( outs Optional : $retval, I32 : $rank @@ -65,14 +67,16 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { def MPI_CommSizeOp : MPI_Op<"comm_size", []> { let summary = "Get the size of the group associated to the communicator, equivalent to " - "`MPI_Comm_size(MPI_COMM_WORLD, &size)`"; + "`MPI_Comm_size(comm, &size)`"; let description = [{ - Communicators other than `MPI_COMM_WORLD` are not supported for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; + let arguments = (ins Optional : $comm); + let results = ( outs Optional : $retval, I32 : $size @@ -87,19 +91,19 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> { def MPI_SendOp : MPI_Op<"send", []> { let summary = - "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`"; + "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`"; let description = [{ MPI_Send performs a blocking send of `size` elements of type `dtype` to rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supprted for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; - let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank); + let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional : $comm); let results = (outs Optional:$retval); @@ -115,22 +119,21 @@ def MPI_SendOp : MPI_Op<"send", []> { // TODO what about request handler? // NOTE datatype & count args are implicit by the type of the first argument (i.e. memref of eltype) -// NOTE other communicators not yet supported by the `mpi` dialect def MPI_ISendOp : MPI_Op<"isend", []> { let summary = - "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`"; + "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`"; let description = [{ MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supprted for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; - let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank); + let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional : $comm); let results = (outs Optional:$retval); @@ -146,14 +149,14 @@ def MPI_ISendOp : MPI_Op<"isend", []> { def MPI_RecvOp : MPI_Op<"recv", []> { let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, " - "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`"; + "comm, MPI_STATUS_IGNORE)`"; let description = [{ MPI_Recv performs a blocking receive of `size` elements of type `dtype` from rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supprted for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object is not yet ported to MLIR. @@ -161,7 +164,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> { to check for errors. }]; - let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank); + let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional : $comm); let results = (outs Optional:$retval); @@ -175,17 +178,16 @@ def MPI_RecvOp : MPI_Op<"recv", []> { // IRecvOp //===----------------------------------------------------------------------===// -// TODO same as MPI_ISendOp def MPI_IRecvOp : MPI_Op<"irecv", []> { let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, " - "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`"; + "comm, MPI_STATUS_IGNORE)`"; let description = [{ MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` from rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - Communicators other than `MPI_COMM_WORLD` are not supprted for now. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object is not yet ported to MLIR. @@ -193,7 +195,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { to check for errors. }]; - let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank); + let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional : $comm); let results = (outs Optional:$retval); From 7eda7915fd48b8b52117a2a8075cfc466b25c905 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 26 Jan 2025 20:21:40 +0100 Subject: [PATCH 07/42] Add summary of new mpi types --- mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td index 20cde07d9a4b98..a7f96c0530883b 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -45,10 +45,10 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> { //===----------------------------------------------------------------------===// def MPI_Comm : MPI_Type<"Comm", "comm"> { - let summary = "..." + let summary = "MPI communicator handler"; let description = [{ This type represents a handler to the MPI communicator. - }] + }]; } //===----------------------------------------------------------------------===// @@ -56,10 +56,10 @@ def MPI_Comm : MPI_Type<"Comm", "comm"> { //===----------------------------------------------------------------------===// def MPI_Request : MPI_Type<"Request", "request"> { - let summary = "..." + let summary = "MPI asynchronous request handler"; let description = [{ This type represents a handler to an asynchronous request. - }] + }]; } //===----------------------------------------------------------------------===// @@ -67,7 +67,7 @@ def MPI_Request : MPI_Type<"Request", "request"> { //===----------------------------------------------------------------------===// def MPI_Status : MPI_Type<"Status", "status"> { - let summary = ""; + let summary = "MPI reception operation status type"; let description = [{ This type represents the status of a reception operation. }]; @@ -78,7 +78,7 @@ def MPI_Status : MPI_Type<"Status", "status"> { //===----------------------------------------------------------------------===// def MPI_Op : MPI_Type<"Op", "op"> { - let summary = ""; + let summary = "MPI operation handler"; let description = [{ This type represents a handle to a operation that can be used in MPI reduce and scan routines. }]; From b97a541bb700df7eaea365f4221d351433bfc3bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 26 Jan 2025 20:26:37 +0100 Subject: [PATCH 08/42] format code --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 1330313c41a8c3..de62536bf80d5f 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -205,7 +205,6 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let hasCanonicalizer = 1; } - //===----------------------------------------------------------------------===// // FinalizeOp //===----------------------------------------------------------------------===// @@ -226,7 +225,6 @@ def MPI_FinalizeOp : MPI_Op<"finalize", []> { let assemblyFormat = "attr-dict (`:` type($retval)^)?"; } - //===----------------------------------------------------------------------===// // RetvalCheckOp //===----------------------------------------------------------------------===// @@ -250,8 +248,6 @@ def MPI_RetvalCheckOp : MPI_Op<"retval_check", []> { let assemblyFormat = "$val `=` $errclass attr-dict `:` type($res)"; } - - //===----------------------------------------------------------------------===// // ErrorClassOp //===----------------------------------------------------------------------===// From d5725a8994ab1f913377abcbe56cbab9e52f0cd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 26 Jan 2025 20:50:49 +0100 Subject: [PATCH 09/42] Add `mpi.comm_split` op --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index de62536bf80d5f..913e7ab9986eff 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -85,6 +85,33 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> { let assemblyFormat = "attr-dict `:` type(results)"; } +//===----------------------------------------------------------------------===// +// CommSplitOp +//===----------------------------------------------------------------------===// + +def MPI_CommSplit : MPI_Op<"comm_split", []> { + let summary = "Partition the group associated to the given communicator into " + "disjoint subgroups"; + let description = [{ + This operation splits the communicator into multiple sub-communicators. + The color value determines the group of processes that will be part of the + new communicator. The key value determines the rank of the calling process + in the new communicator. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = (MPI_Comm : $comm, I32 : $color, I32 : $key); + + let results = ( + outs Optional : $retval, + MPI_Comm : $newcomm + ); + + let assemblyFormat = "`(` $color `,` $key `)` attr-dict `:` type(results)"; +} + //===----------------------------------------------------------------------===// // SendOp //===----------------------------------------------------------------------===// From 1a68b34dc147adabb04507d3ad80628424ea01bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 26 Jan 2025 20:52:10 +0100 Subject: [PATCH 10/42] Add `mpi.barrier` op --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 913e7ab9986eff..78d624bc19f783 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -232,6 +232,25 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let hasCanonicalizer = 1; } +def MPI_Barrier : MPI_Op<"barrier", []> { + let summary = "Equivalent to `MPI_Barrier(comm)`"; + let description = [{ + MPI_Barrier blocks execution until all processes in the communicator have + reached this routine. + + If communicator is not specified, `MPI_COMM_WORLD` is used by default. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = (ins Optional : $comm); + + let results = (outs Optional:$retval); + + let assemblyFormat = "attr-dict `:` type($retval)^"; +} + //===----------------------------------------------------------------------===// // FinalizeOp //===----------------------------------------------------------------------===// From 80a42592bb2dfff31915a33d8b4ec8d7d01527ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 27 Jan 2025 00:05:02 +0100 Subject: [PATCH 11/42] Format code --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 45 +++++++++++++++------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 78d624bc19f783..43128fc1c552c6 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -66,8 +66,8 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { //===----------------------------------------------------------------------===// def MPI_CommSizeOp : MPI_Op<"comm_size", []> { - let summary = "Get the size of the group associated to the communicator, equivalent to " - "`MPI_Comm_size(comm, &size)`"; + let summary = "Get the size of the group associated to the communicator, " + "equivalent to `MPI_Comm_size(comm, &size)`"; let description = [{ If communicator is not specified, `MPI_COMM_WORLD` is used by default. @@ -130,7 +130,12 @@ def MPI_SendOp : MPI_Op<"send", []> { to check for errors. }]; - let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional : $comm); + let arguments = ( + ins AnyMemRef : $ref, + I32 : $tag, + I32 : $rank, + Optional : $comm + ); let results = (outs Optional:$retval); @@ -145,14 +150,14 @@ def MPI_SendOp : MPI_Op<"send", []> { //===----------------------------------------------------------------------===// // TODO what about request handler? -// NOTE datatype & count args are implicit by the type of the first argument (i.e. memref of eltype) def MPI_ISendOp : MPI_Op<"isend", []> { let summary = "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`"; let description = [{ - MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to rank - `dest`. The `tag` value and communicator enables the library to determine - the matching of multiple sends and receives between the same ranks. + MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to + rank `dest`. The `tag` value and communicator enables the library to + determine the matching of multiple sends and receives between the same + ranks. If communicator is not specified, `MPI_COMM_WORLD` is used by default. @@ -160,7 +165,12 @@ def MPI_ISendOp : MPI_Op<"isend", []> { to check for errors. }]; - let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional : $comm); + let arguments = ( + ins AnyMemRef : $ref, + I32 : $tag, + I32 : $rank, + Optional : $comm + ); let results = (outs Optional:$retval); @@ -191,7 +201,11 @@ def MPI_RecvOp : MPI_Op<"recv", []> { to check for errors. }]; - let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional : $comm); + let arguments = ( + ins AnyMemRef : $ref, + I32 : $tag, I32 : $rank, + Optional : $comm + ); let results = (outs Optional:$retval); @@ -207,7 +221,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> { def MPI_IRecvOp : MPI_Op<"irecv", []> { let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, " - "comm, MPI_STATUS_IGNORE)`"; + "comm, &req)`"; let description = [{ MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` from rank `dest`. The `tag` value and communicator enables the library to @@ -215,16 +229,19 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { ranks. If communicator is not specified, `MPI_COMM_WORLD` is used by default. - The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object - is not yet ported to MLIR. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; - let arguments = (ins AnyMemRef : $ref, I32 : $tag, I32 : $rank, Optional : $comm); + let arguments = ( + ins AnyMemRef : $ref, + I32 : $tag, + I32 : $rank, + Optional : $comm + ); - let results = (outs Optional:$retval); + let results = (outs Optional:$retval, MPI_Request : $req); let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " "type($ref) `,` type($tag) `,` type($rank)" From cfb81af015f25f99542885c9eb024040edc37cb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 27 Jan 2025 00:31:26 +0100 Subject: [PATCH 12/42] Fix ops returning `MPI_Request` --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 43128fc1c552c6..f1e7f94dcbb5d1 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -149,7 +149,6 @@ def MPI_SendOp : MPI_Op<"send", []> { // ISendOp //===----------------------------------------------------------------------===// -// TODO what about request handler? def MPI_ISendOp : MPI_Op<"isend", []> { let summary = "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`"; @@ -172,11 +171,11 @@ def MPI_ISendOp : MPI_Op<"isend", []> { Optional : $comm ); - let results = (outs Optional:$retval); + let results = (outs Optional:$retval, MPI_Request : $req); let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " "type($ref) `,` type($tag) `,` type($rank)" - "(`->` type($retval)^)?"; + "`->` (type($retval) `,` ^)? type($req)"; let hasCanonicalizer = 1; } @@ -245,7 +244,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " "type($ref) `,` type($tag) `,` type($rank)" - "(`->` type($retval)^)?"; + "`->` (type($retval) `,` ^)? type($req)"; let hasCanonicalizer = 1; } From 740cf0b6dcaedc8da0f77cc029502ec7302edc2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 27 Jan 2025 00:44:14 +0100 Subject: [PATCH 13/42] Add `mpi.wait` op --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index f1e7f94dcbb5d1..9c5954aff2879d 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -248,6 +248,10 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let hasCanonicalizer = 1; } +//===----------------------------------------------------------------------===// +// BarrierOp +//===----------------------------------------------------------------------===// + def MPI_Barrier : MPI_Op<"barrier", []> { let summary = "Equivalent to `MPI_Barrier(comm)`"; let description = [{ @@ -267,6 +271,29 @@ def MPI_Barrier : MPI_Op<"barrier", []> { let assemblyFormat = "attr-dict `:` type($retval)^"; } +//===----------------------------------------------------------------------===// +// WaitOp +//===----------------------------------------------------------------------===// + +def MPI_Wait : MPI_Op<"wait", []> { + let summary = "Equivalent to `MPI_Wait(req, MPI_STATUS_IGNORE)`"; + let description = [{ + MPI_Wait blocks execution until the request has completed. + + The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object + is not yet ported to MLIR. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = (MPI_Request : $req); + + let results = (outs Optional:$retval); + + let assemblyFormat = "attr-dict `:` type($retval)^"; +} + //===----------------------------------------------------------------------===// // FinalizeOp //===----------------------------------------------------------------------===// From 1af142547f6ee19bc22207f8891c590c1aa78bd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 27 Jan 2025 00:55:10 +0100 Subject: [PATCH 14/42] Add `mpi.allreduce` op --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 32 ++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 9c5954aff2879d..ff48f323fe8c16 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -248,6 +248,38 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let hasCanonicalizer = 1; } +//===----------------------------------------------------------------------===// +// AllReduceOp +//===----------------------------------------------------------------------===// + +def MPI_AllReduceOp : MPI_Op<"allreduce", []> { + let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`"; + let description = [{ + MPI_Allreduce performs a reduction operation on the values in the sendbuf + array and stores the result in the recvbuf array. The operation is + performed across all processes in the communicator. + + If communicator is not specified, `MPI_COMM_WORLD` is used by default. + + This operation can optionally return an `!mpi.retval` value that can be used + to check for errors. + }]; + + let arguments = ( + ins AnyMemRef : $sendbuf, + AnyMemRef : $recvbuf, + MPI_Op : $op, + Optional : $comm + ); + + let results = (outs Optional:$retval); + + let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm)?`)` " + "attr-dict `:` type($sendbuf) `,` type($recvbuf) `,` " + "type($op) (`->` type($retval)^)?"; + let hasCanonicalizer = 1; +} + //===----------------------------------------------------------------------===// // BarrierOp //===----------------------------------------------------------------------===// From c11a60f8419e446f127adafc282eccfe378553dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 27 Jan 2025 01:08:34 +0100 Subject: [PATCH 15/42] Fix assembly formats --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 32 +++++++++++----------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index ff48f323fe8c16..5f4a39966350ea 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -58,7 +58,7 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { I32 : $rank ); - let assemblyFormat = "attr-dict `:` type(results)"; + let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type(results)"; } //===----------------------------------------------------------------------===// @@ -82,7 +82,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> { I32 : $size ); - let assemblyFormat = "attr-dict `:` type(results)"; + let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type(results)"; } //===----------------------------------------------------------------------===// @@ -109,7 +109,7 @@ def MPI_CommSplit : MPI_Op<"comm_split", []> { MPI_Comm : $newcomm ); - let assemblyFormat = "`(` $color `,` $key `)` attr-dict `:` type(results)"; + let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` type(results)"; } //===----------------------------------------------------------------------===// @@ -139,7 +139,7 @@ def MPI_SendOp : MPI_Op<"send", []> { let results = (outs Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)? `)` attr-dict `:` " "type($ref) `,` type($tag) `,` type($rank)" "(`->` type($retval)^)?"; let hasCanonicalizer = 1; @@ -173,9 +173,9 @@ def MPI_ISendOp : MPI_Op<"isend", []> { let results = (outs Optional:$retval, MPI_Request : $req); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " - "type($ref) `,` type($tag) `,` type($rank)" - "`->` (type($retval) `,` ^)? type($req)"; + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict " + "`:` type($ref) `,` type($tag) `,` type($rank) " + "(`,` type($comm))? `->` (type($retval) `,` ^)? type($req)"; let hasCanonicalizer = 1; } @@ -208,9 +208,9 @@ def MPI_RecvOp : MPI_Op<"recv", []> { let results = (outs Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " - "type($ref) `,` type($tag) `,` type($rank)" - "(`->` type($retval)^)?"; + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict " + "`:` type($ref) `,` type($tag) `,` type($rank) " + "(`,` type($comm))? (`->` type($retval)^)?"; let hasCanonicalizer = 1; } @@ -242,9 +242,9 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let results = (outs Optional:$retval, MPI_Request : $req); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " - "type($ref) `,` type($tag) `,` type($rank)" - "`->` (type($retval) `,` ^)? type($req)"; + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict " + "`:` type($ref) `,` type($tag) `,` type($rank)" + "(`,` type($comm))? `->` (type($retval) `,` ^)? type($req)"; let hasCanonicalizer = 1; } @@ -276,7 +276,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm)?`)` " "attr-dict `:` type($sendbuf) `,` type($recvbuf) `,` " - "type($op) (`->` type($retval)^)?"; + "type($op) (`,` type($comm))? (`->` type($retval)^)?"; let hasCanonicalizer = 1; } @@ -300,7 +300,7 @@ def MPI_Barrier : MPI_Op<"barrier", []> { let results = (outs Optional:$retval); - let assemblyFormat = "attr-dict `:` type($retval)^"; + let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type($retval)^"; } //===----------------------------------------------------------------------===// @@ -323,7 +323,7 @@ def MPI_Wait : MPI_Op<"wait", []> { let results = (outs Optional:$retval); - let assemblyFormat = "attr-dict `:` type($retval)^"; + let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) `->` type($retval)^"; } //===----------------------------------------------------------------------===// From d971d8335a7bd43e7652fd5876489a392e9fd6cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 27 Jan 2025 01:13:57 +0100 Subject: [PATCH 16/42] add some tests --- mlir/test/Dialect/MPI/ops.mlir | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir index 8f2421a73396c2..17f2fc9453e464 100644 --- a/mlir/test/Dialect/MPI/ops.mlir +++ b/mlir/test/Dialect/MPI/ops.mlir @@ -9,6 +9,9 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32 %retval, %rank = mpi.comm_rank : !mpi.retval, i32 + // CHECK-NEXT: %retval2, %size = mpi.comm_size : !mpi.retval, i32 + %retval2, %size = mpi.comm_size : !mpi.retval, i32 + // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 @@ -21,6 +24,27 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + // CHECK-NEXT: mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request + %req1 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request + + // CHECK-NEXT: %1 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request + %err2, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request + + // CHECK-NEXT: mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request + %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 + + // CHECK-NEXT: %2 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request + %err3, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + + // CHECK-NEXT: mpi.wait(%req1) : mpi.request + mpi.wait(%req1) : mpi.request + + // CHECK-NEXT: %3 = mpi.wait(%req1) : mpi.request -> !mpi.retval + %err4 = mpi.wait(%req1) : mpi.request -> !mpi.retval + + // CHECK-NEXT: mpi.barrier : !mpi.retval + mpi.barrier : !mpi.retval + // CHECK-NEXT: %3 = mpi.finalize : !mpi.retval %rval = mpi.finalize : !mpi.retval From beb57646d002e049d12c870ee75d50361177d032 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 27 Jan 2025 11:13:09 +0100 Subject: [PATCH 17/42] Fix input specifier --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 5f4a39966350ea..534849f960ca54 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -102,7 +102,7 @@ def MPI_CommSplit : MPI_Op<"comm_split", []> { to check for errors. }]; - let arguments = (MPI_Comm : $comm, I32 : $color, I32 : $key); + let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key); let results = ( outs Optional : $retval, @@ -319,7 +319,7 @@ def MPI_Wait : MPI_Op<"wait", []> { to check for errors. }]; - let arguments = (MPI_Request : $req); + let arguments = (ins MPI_Request : $req); let results = (outs Optional:$retval); From 231799487a5d48da1a04ac40e132d8eed85b732c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 27 Jan 2025 11:27:33 +0100 Subject: [PATCH 18/42] Comment predefined constant MPI_Ops --- mlir/include/mlir/Dialect/MPI/IR/MPI.td | 72 ++++++++++++------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td index 182de03a5a8057..8b45b98a24a866 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td @@ -216,43 +216,43 @@ def MPI_ErrorClassAttr : EnumAttr { } // TODO is it ok to have them as I32? -def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">; -def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">; -def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">; -def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">; -def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">; -def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">; -def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">; -def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">; -def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">; -def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">; -def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">; -def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">; -def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">; -def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">; +// def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">; +// def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">; +// def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">; +// def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">; +// def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">; +// def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">; +// def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">; +// def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">; +// def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">; +// def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">; +// def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">; +// def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">; +// def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">; +// def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">; -def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [ - MPI_OpNull, - MPI_OpMax, - MPI_OpMin, - MPI_OpSum, - MPI_OpProd, - MPI_OpLand, - MPI_OpBand, - MPI_OpLor, - MPI_OpBor, - MPI_OpLxor, - MPI_OpBxor, - MPI_OpMinloc, - MPI_OpMaxloc, - MPI_OpReplace - ]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::mpi"; -} +// def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [ +// MPI_OpNull, +// MPI_OpMax, +// MPI_OpMin, +// MPI_OpSum, +// MPI_OpProd, +// MPI_OpLand, +// MPI_OpBand, +// MPI_OpLor, +// MPI_OpBor, +// MPI_OpLxor, +// MPI_OpBxor, +// MPI_OpMinloc, +// MPI_OpMaxloc, +// MPI_OpReplace +// ]> { +// let genSpecializedAttr = 0; +// let cppNamespace = "::mlir::mpi"; +// } -def MPI_OpClassAttr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} +// def MPI_OpClassAttr : EnumAttr { +// let assemblyFormat = "`<` $value `>`"; +// } #endif // MLIR_DIALECT_MPI_IR_MPI_TD From 63ccc331d86d7a35bb374c32fda58573aeccec7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 27 Jan 2025 12:05:45 +0100 Subject: [PATCH 19/42] Replace `MPI_Op` new type for region --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 3 ++- mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 12 ++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 534849f960ca54..faab2196895539 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -268,10 +268,11 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let arguments = ( ins AnyMemRef : $sendbuf, AnyMemRef : $recvbuf, - MPI_Op : $op, Optional : $comm ); + let regions = (region SizedRegion<1>:$op); + let results = (outs Optional:$retval); let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm)?`)` " diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td index a7f96c0530883b..3a07790a9ebad9 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -77,11 +77,11 @@ def MPI_Status : MPI_Type<"Status", "status"> { // mpi::OpType //===----------------------------------------------------------------------===// -def MPI_Op : MPI_Type<"Op", "op"> { - let summary = "MPI operation handler"; - let description = [{ - This type represents a handle to a operation that can be used in MPI reduce and scan routines. - }]; -} +// def MPI_Op : MPI_Type<"Op", "op"> { +// let summary = "MPI operation handler"; +// let description = [{ +// This type represents a handle to a operation that can be used in MPI reduce and scan routines. +// }]; +// } #endif // MLIR_DIALECT_MPI_IR_MPITYPES_TD From d318c60729cc6ae7f91325c6c1df5528a68dc87a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 28 Jan 2025 13:05:57 +0100 Subject: [PATCH 20/42] Go back to only use predefined MPI_Ops --- mlir/include/mlir/Dialect/MPI/IR/MPI.td | 73 ++++++++++---------- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 7 +- mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 9 ++- 3 files changed, 48 insertions(+), 41 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPI.td b/mlir/include/mlir/Dialect/MPI/IR/MPI.td index 8b45b98a24a866..7c84443e5520d9 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPI.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPI.td @@ -215,44 +215,43 @@ def MPI_ErrorClassAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } -// TODO is it ok to have them as I32? -// def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">; -// def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">; -// def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">; -// def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">; -// def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">; -// def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">; -// def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">; -// def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">; -// def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">; -// def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">; -// def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">; -// def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">; -// def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">; -// def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">; +def MPI_OpNull : I32EnumAttrCase<"MPI_OP_NULL", 0, "MPI_OP_NULL">; +def MPI_OpMax : I32EnumAttrCase<"MPI_MAX", 1, "MPI_MAX">; +def MPI_OpMin : I32EnumAttrCase<"MPI_MIN", 2, "MPI_MIN">; +def MPI_OpSum : I32EnumAttrCase<"MPI_SUM", 3, "MPI_SUM">; +def MPI_OpProd : I32EnumAttrCase<"MPI_PROD", 4, "MPI_PROD">; +def MPI_OpLand : I32EnumAttrCase<"MPI_LAND", 5, "MPI_LAND">; +def MPI_OpBand : I32EnumAttrCase<"MPI_BAND", 6, "MPI_BAND">; +def MPI_OpLor : I32EnumAttrCase<"MPI_LOR", 7, "MPI_LOR">; +def MPI_OpBor : I32EnumAttrCase<"MPI_BOR", 8, "MPI_BOR">; +def MPI_OpLxor : I32EnumAttrCase<"MPI_LXOR", 9, "MPI_LXOR">; +def MPI_OpBxor : I32EnumAttrCase<"MPI_BXOR", 10, "MPI_BXOR">; +def MPI_OpMinloc : I32EnumAttrCase<"MPI_MINLOC", 11, "MPI_MINLOC">; +def MPI_OpMaxloc : I32EnumAttrCase<"MPI_MAXLOC", 12, "MPI_MAXLOC">; +def MPI_OpReplace : I32EnumAttrCase<"MPI_REPLACE", 13, "MPI_REPLACE">; -// def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [ -// MPI_OpNull, -// MPI_OpMax, -// MPI_OpMin, -// MPI_OpSum, -// MPI_OpProd, -// MPI_OpLand, -// MPI_OpBand, -// MPI_OpLor, -// MPI_OpBor, -// MPI_OpLxor, -// MPI_OpBxor, -// MPI_OpMinloc, -// MPI_OpMaxloc, -// MPI_OpReplace -// ]> { -// let genSpecializedAttr = 0; -// let cppNamespace = "::mlir::mpi"; -// } +def MPI_OpClassEnum : I32EnumAttr<"MPI_OpClassEnum", "MPI operation class", [ + MPI_OpNull, + MPI_OpMax, + MPI_OpMin, + MPI_OpSum, + MPI_OpProd, + MPI_OpLand, + MPI_OpBand, + MPI_OpLor, + MPI_OpBor, + MPI_OpLxor, + MPI_OpBxor, + MPI_OpMinloc, + MPI_OpMaxloc, + MPI_OpReplace + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::mpi"; +} -// def MPI_OpClassAttr : EnumAttr { -// let assemblyFormat = "`<` $value `>`"; -// } +def MPI_OpClassAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} #endif // MLIR_DIALECT_MPI_IR_MPI_TD diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index faab2196895539..981c30b5afb2d6 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -259,6 +259,10 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { array and stores the result in the recvbuf array. The operation is performed across all processes in the communicator. + The `op` attribute specifies the reduction operation to be performed. + Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are + supported. + If communicator is not specified, `MPI_COMM_WORLD` is used by default. This operation can optionally return an `!mpi.retval` value that can be used @@ -268,11 +272,10 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let arguments = ( ins AnyMemRef : $sendbuf, AnyMemRef : $recvbuf, + MPI_OpClassAttr : $op, Optional : $comm ); - let regions = (region SizedRegion<1>:$op); - let results = (outs Optional:$retval); let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm)?`)` " diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td index 3a07790a9ebad9..c47081ce77fba1 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -77,10 +77,15 @@ def MPI_Status : MPI_Type<"Status", "status"> { // mpi::OpType //===----------------------------------------------------------------------===// -// def MPI_Op : MPI_Type<"Op", "op"> { +// def MPI_Operation : MPI_Type<"Op", "op"> { // let summary = "MPI operation handler"; // let description = [{ -// This type represents a handle to a operation that can be used in MPI reduce and scan routines. +// This type represents a handle to a operation that can be used in MPI reduce +// and scan routines. + +// The TableGen definition is named `MPI_Operation` instead of the `MPI_Op` +// name of the MPI standard to avoid conflict with the `MPI_Op` definition to +// define MLIR ops. // }]; // } From 8e3aa18dc5162250f0c7ea89b4de97c9d6250c0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 09:16:21 +0100 Subject: [PATCH 21/42] Remove `MPI_Operation` type --- mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td index c47081ce77fba1..868132a62abc4b 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -73,20 +73,4 @@ def MPI_Status : MPI_Type<"Status", "status"> { }]; } -//===----------------------------------------------------------------------===// -// mpi::OpType -//===----------------------------------------------------------------------===// - -// def MPI_Operation : MPI_Type<"Op", "op"> { -// let summary = "MPI operation handler"; -// let description = [{ -// This type represents a handle to a operation that can be used in MPI reduce -// and scan routines. - -// The TableGen definition is named `MPI_Operation` instead of the `MPI_Op` -// name of the MPI standard to avoid conflict with the `MPI_Op` definition to -// define MLIR ops. -// }]; -// } - #endif // MLIR_DIALECT_MPI_IR_MPITYPES_TD From 9c708d466010848cb197a54511882807117fbb31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 09:16:36 +0100 Subject: [PATCH 22/42] Add `mpi.comm_world` op to return `MPI_COMM_WORLD` --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 981c30b5afb2d6..97e36a2e01b0cc 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -37,6 +37,21 @@ def MPI_InitOp : MPI_Op<"init", []> { let assemblyFormat = "attr-dict (`:` type($retval)^)?"; } +//===----------------------------------------------------------------------===// +// CommWorldOp +//===----------------------------------------------------------------------===// + +def MPI_CommWorldOp : MPI_Op<"comm_world", []> { + let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`"; + let description = [{ + This operation returns the predefined MPI_COMM_WORLD communicator. + }]; + + let results = (outs MPI_Comm : $comm); + + let assemblyFormat = "attr-dict `:` type(results)"; +} + //===----------------------------------------------------------------------===// // CommRankOp //===----------------------------------------------------------------------===// From 326b13fa15d6a55bd94b85982ca8f33d6145bb10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 09:27:23 +0100 Subject: [PATCH 23/42] Add tests --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 4 +- mlir/test/Dialect/MPI/ops.mlir | 55 +++++++++++++++++----- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 97e36a2e01b0cc..af324845893c5a 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -294,8 +294,8 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let results = (outs Optional:$retval); let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm)?`)` " - "attr-dict `:` type($sendbuf) `,` type($recvbuf) `,` " - "type($op) (`,` type($comm))? (`->` type($retval)^)?"; + "attr-dict `:` type($sendbuf) `,` type($recvbuf) " + "(`,` type($comm))? (`->` type($retval)^)?"; let hasCanonicalizer = 1; } diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir index 17f2fc9453e464..9fee6e83ceb4c2 100644 --- a/mlir/test/Dialect/MPI/ops.mlir +++ b/mlir/test/Dialect/MPI/ops.mlir @@ -12,46 +12,79 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK-NEXT: %retval2, %size = mpi.comm_size : !mpi.retval, i32 %retval2, %size = mpi.comm_size : !mpi.retval, i32 + // CHECK-NEXT: %comm = mpi.comm_world : mpi.comm + %comm = mpi.comm_world : mpi.comm + + // CHECK-NEXT: %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.retval, i32 + %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : mpi.comm, i32, i32 + // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 // CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval %err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + // CHECK-NEXT: mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm + mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm + // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval - // CHECK-NEXT: mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request + // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm + mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm + + // CHECK-NEXT: %3 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request %req1 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request - // CHECK-NEXT: %1 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request - %err2, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request + // CHECK-NEXT: %4, %5 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request + %err4, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request + + // CHECK-NEXT: %3 = mpi.isend(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request + %req1 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request - // CHECK-NEXT: mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request + // CHECK-NEXT: %6 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 - // CHECK-NEXT: %2 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request - %err3, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + // CHECK-NEXT: %7, %8 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request + %err5, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + + // CHECK-NEXT: %6 = mpi.irecv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request + %req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm // CHECK-NEXT: mpi.wait(%req1) : mpi.request mpi.wait(%req1) : mpi.request - // CHECK-NEXT: %3 = mpi.wait(%req1) : mpi.request -> !mpi.retval - %err4 = mpi.wait(%req1) : mpi.request -> !mpi.retval + // CHECK-NEXT: %9 = mpi.wait(%req1) : mpi.request -> !mpi.retval + %err6 = mpi.wait(%req2) : mpi.request -> !mpi.retval // CHECK-NEXT: mpi.barrier : !mpi.retval mpi.barrier : !mpi.retval - // CHECK-NEXT: %3 = mpi.finalize : !mpi.retval + // CHECK-NEXT: %10 = mpi.barrier : !mpi.retval + %err7 = mpi.barrier : !mpi.retval + + // CHECK-NEXT: mpi.barrier(%comm) : !mpi.retval + mpi.barrier(%comm) : !mpi.retval + + // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> + mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> + + // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval + %err8 = mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> + + // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, mpi.comm + mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, mpi.comm + + // CHECK-NEXT: %11 = mpi.finalize : !mpi.retval %rval = mpi.finalize : !mpi.retval - // CHECK-NEXT: %4 = mpi.retval_check %retval = : i1 + // CHECK-NEXT: %12 = mpi.retval_check %retval = : i1 %res = mpi.retval_check %retval = : i1 - // CHECK-NEXT: %5 = mpi.error_class %0 : !mpi.retval + // CHECK-NEXT: %13 = mpi.error_class %0 : !mpi.retval %errclass = mpi.error_class %err : !mpi.retval // CHECK-NEXT: return From 1fd5578547a5cbed7892c00246424a4c29da3674 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 11:08:15 +0100 Subject: [PATCH 24/42] Fix anchor of assembly format --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index af324845893c5a..f3e80387295db2 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -190,7 +190,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> { let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank) " - "(`,` type($comm))? `->` (type($retval) `,` ^)? type($req)"; + "(`,` ^ type($comm))? `->` (type($retval) `,` ^)? type($req)"; let hasCanonicalizer = 1; } @@ -225,7 +225,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> { let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank) " - "(`,` type($comm))? (`->` type($retval)^)?"; + "(`,` ^ type($comm))? (`->` type($retval)^)?"; let hasCanonicalizer = 1; } @@ -259,7 +259,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank)" - "(`,` type($comm))? `->` (type($retval) `,` ^)? type($req)"; + "(`,` ^ type($comm))? `->` (type($retval) `,` ^)? type($req)"; let hasCanonicalizer = 1; } @@ -295,7 +295,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm)?`)` " "attr-dict `:` type($sendbuf) `,` type($recvbuf) " - "(`,` type($comm))? (`->` type($retval)^)?"; + "(`,` ^ type($comm))? (`->` type($retval)^)?"; let hasCanonicalizer = 1; } From 016b856bfc6829097859ba09ee9c1fa8025b4fd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 11:43:01 +0100 Subject: [PATCH 25/42] Fix more anchors --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index f3e80387295db2..6501d35f1f0955 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -154,7 +154,7 @@ def MPI_SendOp : MPI_Op<"send", []> { let results = (outs Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)? `)` attr-dict `:` " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` ^ $comm)? `)` attr-dict `:` " "type($ref) `,` type($tag) `,` type($rank)" "(`->` type($retval)^)?"; let hasCanonicalizer = 1; @@ -188,7 +188,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> { let results = (outs Optional:$retval, MPI_Request : $req); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` ^ $comm)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank) " "(`,` ^ type($comm))? `->` (type($retval) `,` ^)? type($req)"; let hasCanonicalizer = 1; @@ -223,7 +223,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> { let results = (outs Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` ^ $comm)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank) " "(`,` ^ type($comm))? (`->` type($retval)^)?"; let hasCanonicalizer = 1; @@ -257,7 +257,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let results = (outs Optional:$retval, MPI_Request : $req); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm)?`)` attr-dict " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` ^ $comm)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank)" "(`,` ^ type($comm))? `->` (type($retval) `,` ^)? type($req)"; let hasCanonicalizer = 1; @@ -293,7 +293,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let results = (outs Optional:$retval); - let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm)?`)` " + let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` ^ $comm)?`)` " "attr-dict `:` type($sendbuf) `,` type($recvbuf) " "(`,` ^ type($comm))? (`->` type($retval)^)?"; let hasCanonicalizer = 1; From 1931b8eb6d8be445704df29d5cfd18486e0caa33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 12:05:58 +0100 Subject: [PATCH 26/42] Fix anchors again --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 6501d35f1f0955..fef4475ad2d2db 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -154,7 +154,7 @@ def MPI_SendOp : MPI_Op<"send", []> { let results = (outs Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` ^ $comm)? `)` attr-dict `:` " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)? `)` attr-dict `:` " "type($ref) `,` type($tag) `,` type($rank)" "(`->` type($retval)^)?"; let hasCanonicalizer = 1; @@ -188,9 +188,9 @@ def MPI_ISendOp : MPI_Op<"isend", []> { let results = (outs Optional:$retval, MPI_Request : $req); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` ^ $comm)?`)` attr-dict " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank) " - "(`,` ^ type($comm))? `->` (type($retval) `,` ^)? type($req)"; + "(`,` type($comm) ^)? `->` (type($retval) `,` ^)? type($req)"; let hasCanonicalizer = 1; } @@ -223,9 +223,9 @@ def MPI_RecvOp : MPI_Op<"recv", []> { let results = (outs Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` ^ $comm)?`)` attr-dict " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank) " - "(`,` ^ type($comm))? (`->` type($retval)^)?"; + "(`,` type($comm) ^)? (`->` type($retval)^)?"; let hasCanonicalizer = 1; } @@ -257,9 +257,9 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let results = (outs Optional:$retval, MPI_Request : $req); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` ^ $comm)?`)` attr-dict " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank)" - "(`,` ^ type($comm))? `->` (type($retval) `,` ^)? type($req)"; + "(`,` type($comm) ^)? `->` (type($retval) `,` ^)? type($req)"; let hasCanonicalizer = 1; } @@ -293,9 +293,9 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let results = (outs Optional:$retval); - let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` ^ $comm)?`)` " + let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm ^)?`)` " "attr-dict `:` type($sendbuf) `,` type($recvbuf) " - "(`,` ^ type($comm))? (`->` type($retval)^)?"; + "(`,` type($comm) ^)? (`->` type($retval)^)?"; let hasCanonicalizer = 1; } From aec9fbd20d6f8f0b5239f5a8a22b2d365d228f6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 12:11:52 +0100 Subject: [PATCH 27/42] fix another anchor --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index fef4475ad2d2db..0a35b5d9c8ea59 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -319,7 +319,7 @@ def MPI_Barrier : MPI_Op<"barrier", []> { let results = (outs Optional:$retval); - let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type($retval)^"; + let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type($retval)^"; } //===----------------------------------------------------------------------===// From d4684fbe60673aa0eeeea3dea33ead2703d5415e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 14:09:13 +0100 Subject: [PATCH 28/42] fix optional format of `MPI_BarrierOp` --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 0a35b5d9c8ea59..3e250694b82874 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -319,7 +319,7 @@ def MPI_Barrier : MPI_Op<"barrier", []> { let results = (outs Optional:$retval); - let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type($retval)^"; + let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($retval) ^)?"; } //===----------------------------------------------------------------------===// From 794fa25de3865fd7ae364772aef3079f03b1d6f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 14:23:09 +0100 Subject: [PATCH 29/42] fix more anchors --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 3e250694b82874..3b7fef674f18fb 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -73,7 +73,7 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { I32 : $rank ); - let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type(results)"; + let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)"; } //===----------------------------------------------------------------------===// @@ -97,7 +97,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> { I32 : $size ); - let assemblyFormat = "(`(` $comm `)`)? attr-dict `:` type(results)"; + let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)"; } //===----------------------------------------------------------------------===// @@ -342,7 +342,7 @@ def MPI_Wait : MPI_Op<"wait", []> { let results = (outs Optional:$retval); - let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) `->` type($retval)^"; + let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) (`->` type($retval) ^)?"; } //===----------------------------------------------------------------------===// From f0d0f44a1a9675ef007cf9dcf405c330af270c7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 14:32:43 +0100 Subject: [PATCH 30/42] fix anchors in `MPI_ISendOp` and `MPI_IRecvOp` --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 3b7fef674f18fb..aa88d56dbe7b6c 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -190,7 +190,7 @@ def MPI_ISendOp : MPI_Op<"isend", []> { let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank) " - "(`,` type($comm) ^)? `->` (type($retval) `,` ^)? type($req)"; + "(`,` type($comm) ^)? `->` (type($retval) ^ `,`)? type($req)"; let hasCanonicalizer = 1; } @@ -259,7 +259,7 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank)" - "(`,` type($comm) ^)? `->` (type($retval) `,` ^)? type($req)"; + "(`,` type($comm) ^)? `->` (type($retval) ^ `,`)? type($req)"; let hasCanonicalizer = 1; } From 92f2ccab2fac24af18dcf2dc7f2986cd90fa2788 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 14:51:26 +0100 Subject: [PATCH 31/42] fix format --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index aa88d56dbe7b6c..3f872a221df131 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -186,11 +186,11 @@ def MPI_ISendOp : MPI_Op<"isend", []> { Optional : $comm ); - let results = (outs Optional:$retval, MPI_Request : $req); + let results = (outs MPI_Request : $req, Optional:$retval); let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank) " - "(`,` type($comm) ^)? `->` (type($retval) ^ `,`)? type($req)"; + "(`,` type($comm) ^)? `->` type($req) (`,` type($retval) ^)?"; let hasCanonicalizer = 1; } @@ -255,11 +255,11 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { Optional : $comm ); - let results = (outs Optional:$retval, MPI_Request : $req); + let results = (outs MPI_Request : $req, Optional:$retval); let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank)" - "(`,` type($comm) ^)? `->` (type($retval) ^ `,`)? type($req)"; + "(`,` type($comm) ^)? `->` type($req) (`,` type($retval) ^)?"; let hasCanonicalizer = 1; } From 3688915dfc38419808cb93e2dfb468af5867806a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 15:09:37 +0100 Subject: [PATCH 32/42] Define `getCanonicalizationPatterns` for `ISendOp`, `IRecvOp`, `AllReduceOp` --- mlir/lib/Dialect/MPI/IR/MPIOps.cpp | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp index dcb55d8921364f..320a6800bc9041 100644 --- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp +++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp @@ -53,6 +53,26 @@ void mlir::mpi::RecvOp::getCanonicalizationPatterns( results.add>(context); } +void mlir::mpi::ISendOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &results, mlir::MLIRContext *context) { + results.add>(context); +} + +void mlir::mpi::IRecvOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &results, mlir::MLIRContext *context) { + results.add>(context); +} + +void mlir::mpi::SendOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &results, mlir::MLIRContext *context) { + results.add>(context); +} + +void mlir::mpi::AllReduceOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &results, mlir::MLIRContext *context) { + results.add>(context); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// From 3abe925dd3fd6251363040e3e004d8c583758268 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 15:19:55 +0100 Subject: [PATCH 33/42] remove duplicated `getCanonicalizationPatterns` --- mlir/lib/Dialect/MPI/IR/MPIOps.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp index 320a6800bc9041..86d9d311cadb8c 100644 --- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp +++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp @@ -63,11 +63,6 @@ void mlir::mpi::IRecvOp::getCanonicalizationPatterns( results.add>(context); } -void mlir::mpi::SendOp::getCanonicalizationPatterns( - mlir::RewritePatternSet &results, mlir::MLIRContext *context) { - results.add>(context); -} - void mlir::mpi::AllReduceOp::getCanonicalizationPatterns( mlir::RewritePatternSet &results, mlir::MLIRContext *context) { results.add>(context); From 7a9fa9c12ec2db70bac37b3600bd457e8a20c15e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 16:04:33 +0100 Subject: [PATCH 34/42] Remove canonicalization for `AllReduceOp` --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 1 - mlir/lib/Dialect/MPI/IR/MPIOps.cpp | 5 ----- 2 files changed, 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 3f872a221df131..06f01a3d747964 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -296,7 +296,6 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm ^)?`)` " "attr-dict `:` type($sendbuf) `,` type($recvbuf) " "(`,` type($comm) ^)? (`->` type($retval)^)?"; - let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp index 86d9d311cadb8c..56d8edfbcc0255 100644 --- a/mlir/lib/Dialect/MPI/IR/MPIOps.cpp +++ b/mlir/lib/Dialect/MPI/IR/MPIOps.cpp @@ -63,11 +63,6 @@ void mlir::mpi::IRecvOp::getCanonicalizationPatterns( results.add>(context); } -void mlir::mpi::AllReduceOp::getCanonicalizationPatterns( - mlir::RewritePatternSet &results, mlir::MLIRContext *context) { - results.add>(context); -} - //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// From 1926bda364f32507a56701b6d0695da263904b3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 16:26:33 +0100 Subject: [PATCH 35/42] fix test --- mlir/test/Dialect/MPI/ops.mlir | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir index 9fee6e83ceb4c2..c955f233769d3b 100644 --- a/mlir/test/Dialect/MPI/ops.mlir +++ b/mlir/test/Dialect/MPI/ops.mlir @@ -39,20 +39,20 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK-NEXT: %3 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request %req1 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request - // CHECK-NEXT: %4, %5 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request - %err4, %req2 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request + // CHECK-NEXT: %4, %5 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request, !mpi.retval + %req2, %err4, = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request, !mpi.retval // CHECK-NEXT: %3 = mpi.isend(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request %req1 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request // CHECK-NEXT: %6 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request - %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 + %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request - // CHECK-NEXT: %7, %8 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval, mpi.request - %err5, %req4 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval + // CHECK-NEXT: %7, %8 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request, !mpi.retval + %req4, %err5 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request, !mpi.retval // CHECK-NEXT: %6 = mpi.irecv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request - %req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm + %req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request // CHECK-NEXT: mpi.wait(%req1) : mpi.request mpi.wait(%req1) : mpi.request From 89ec1112a09666df09f122d5e2cda35e4230a00d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 16:40:49 +0100 Subject: [PATCH 36/42] fix some assembly formats --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 18 ++++++++++-------- mlir/test/Dialect/MPI/ops.mlir | 12 ++++++------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 06f01a3d747964..fd7c7a30e92891 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -69,8 +69,8 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { let arguments = (ins Optional : $comm); let results = ( - outs Optional : $retval, - I32 : $rank + outs I32 : $rank, + Optional : $retval, ); let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)"; @@ -93,8 +93,8 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> { let arguments = (ins Optional : $comm); let results = ( - outs Optional : $retval, - I32 : $size + outs I32 : $size, + Optional : $retval, ); let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)"; @@ -120,11 +120,12 @@ def MPI_CommSplit : MPI_Op<"comm_split", []> { let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key); let results = ( - outs Optional : $retval, - MPI_Comm : $newcomm + outs MPI_Comm : $newcomm, + Optional : $retval, ); - let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` type(results)"; + let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` " + "type(results)"; } //===----------------------------------------------------------------------===// @@ -190,7 +191,8 @@ def MPI_ISendOp : MPI_Op<"isend", []> { let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank) " - "(`,` type($comm) ^)? `->` type($req) (`,` type($retval) ^)?"; + "(`,` type($comm) ^)? `->` type($req)" + "(`,` type($retval) ^)?"; let hasCanonicalizer = 1; } diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir index c955f233769d3b..753dcd3cfdf2d8 100644 --- a/mlir/test/Dialect/MPI/ops.mlir +++ b/mlir/test/Dialect/MPI/ops.mlir @@ -6,17 +6,17 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK: %0 = mpi.init : !mpi.retval %err = mpi.init : !mpi.retval - // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32 - %retval, %rank = mpi.comm_rank : !mpi.retval, i32 + // CHECK-NEXT: %retval, %rank = mpi.comm_rank : i32, !mpi.retval + %retval, %rank = mpi.comm_rank : i32, !mpi.retval - // CHECK-NEXT: %retval2, %size = mpi.comm_size : !mpi.retval, i32 - %retval2, %size = mpi.comm_size : !mpi.retval, i32 + // CHECK-NEXT: %retval2, %size = mpi.comm_size : i32, !mpi.retval + %retval2, %size = mpi.comm_size : i32, !mpi.retval // CHECK-NEXT: %comm = mpi.comm_world : mpi.comm %comm = mpi.comm_world : mpi.comm - // CHECK-NEXT: %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : !mpi.retval, i32 - %retval3, %new_comm = mpi.comm_split(%comm, %rank, %rank) : mpi.comm, i32, i32 + // CHECK-NEXT: %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : i32, !mpi.retval + %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : mpi.comm, i32, i32 // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 From 30fb67349ee02b5ba33f7e76abd451b076420e71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 29 Jan 2025 18:08:38 +0100 Subject: [PATCH 37/42] fix syntax --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index fd7c7a30e92891..4bc32cd258f690 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -70,7 +70,7 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { let results = ( outs I32 : $rank, - Optional : $retval, + Optional : $retval ); let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)"; @@ -94,7 +94,7 @@ def MPI_CommSizeOp : MPI_Op<"comm_size", []> { let results = ( outs I32 : $size, - Optional : $retval, + Optional : $retval ); let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)"; @@ -121,7 +121,7 @@ def MPI_CommSplit : MPI_Op<"comm_split", []> { let results = ( outs MPI_Comm : $newcomm, - Optional : $retval, + Optional : $retval ); let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` " From 6abba5a37d5ea73c2b177581db9d476da4a26c91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 30 Jan 2025 00:27:43 +0100 Subject: [PATCH 38/42] Remove MPI_Comm type --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 133 ++++++------------- mlir/include/mlir/Dialect/MPI/IR/MPITypes.td | 11 -- mlir/test/Dialect/MPI/ops.mlir | 24 ---- 3 files changed, 40 insertions(+), 128 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 4bc32cd258f690..f3d21466795282 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -37,43 +37,26 @@ def MPI_InitOp : MPI_Op<"init", []> { let assemblyFormat = "attr-dict (`:` type($retval)^)?"; } -//===----------------------------------------------------------------------===// -// CommWorldOp -//===----------------------------------------------------------------------===// - -def MPI_CommWorldOp : MPI_Op<"comm_world", []> { - let summary = "Get the World communicator, equivalent to `MPI_COMM_WORLD`"; - let description = [{ - This operation returns the predefined MPI_COMM_WORLD communicator. - }]; - - let results = (outs MPI_Comm : $comm); - - let assemblyFormat = "attr-dict `:` type(results)"; -} - //===----------------------------------------------------------------------===// // CommRankOp //===----------------------------------------------------------------------===// def MPI_CommRankOp : MPI_Op<"comm_rank", []> { let summary = "Get the current rank, equivalent to " - "`MPI_Comm_rank(comm, &rank)`"; + "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`"; let description = [{ - If communicator is not specified, `MPI_COMM_WORLD` is used by default. - + Communicators other than `MPI_COMM_WORLD` are not supported for now. + This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; - let arguments = (ins Optional : $comm); - let results = ( outs I32 : $rank, Optional : $retval ); - let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)"; + let assemblyFormat = "attr-dict `:` type(results)"; } //===----------------------------------------------------------------------===// @@ -82,50 +65,20 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { def MPI_CommSizeOp : MPI_Op<"comm_size", []> { let summary = "Get the size of the group associated to the communicator, " - "equivalent to `MPI_Comm_size(comm, &size)`"; + "equivalent to `MPI_Comm_size(MPI_COMM_WORLD, &size)`"; let description = [{ - If communicator is not specified, `MPI_COMM_WORLD` is used by default. + Communicators other than `MPI_COMM_WORLD` are not supported for now. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; - let arguments = (ins Optional : $comm); - let results = ( outs I32 : $size, Optional : $retval ); - let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict `:` type(results)"; -} - -//===----------------------------------------------------------------------===// -// CommSplitOp -//===----------------------------------------------------------------------===// - -def MPI_CommSplit : MPI_Op<"comm_split", []> { - let summary = "Partition the group associated to the given communicator into " - "disjoint subgroups"; - let description = [{ - This operation splits the communicator into multiple sub-communicators. - The color value determines the group of processes that will be part of the - new communicator. The key value determines the rank of the calling process - in the new communicator. - - This operation can optionally return an `!mpi.retval` value that can be used - to check for errors. - }]; - - let arguments = (ins MPI_Comm : $comm, I32 : $color, I32 : $key); - - let results = ( - outs MPI_Comm : $newcomm, - Optional : $retval - ); - - let assemblyFormat = "`(` $comm `,` $color `,` $key `)` attr-dict `:` " - "type(results)"; + let assemblyFormat = "attr-dict `:` type(results)"; } //===----------------------------------------------------------------------===// @@ -134,13 +87,13 @@ def MPI_CommSplit : MPI_Op<"comm_split", []> { def MPI_SendOp : MPI_Op<"send", []> { let summary = - "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, comm)`"; + "Equivalent to `MPI_Send(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`"; let description = [{ MPI_Send performs a blocking send of `size` elements of type `dtype` to rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - If communicator is not specified, `MPI_COMM_WORLD` is used by default. + Communicators other than `MPI_COMM_WORLD` are not supported for now. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. @@ -149,13 +102,12 @@ def MPI_SendOp : MPI_Op<"send", []> { let arguments = ( ins AnyMemRef : $ref, I32 : $tag, - I32 : $rank, - Optional : $comm + I32 : $rank ); let results = (outs Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)? `)` attr-dict `:` " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:` " "type($ref) `,` type($tag) `,` type($rank)" "(`->` type($retval)^)?"; let hasCanonicalizer = 1; @@ -167,14 +119,14 @@ def MPI_SendOp : MPI_Op<"send", []> { def MPI_ISendOp : MPI_Op<"isend", []> { let summary = - "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, comm)`"; + "Equivalent to `MPI_Isend(ptr, size, dtype, dest, tag, MPI_COMM_WORLD)`"; let description = [{ MPI_Isend begins a non-blocking send of `size` elements of type `dtype` to rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - If communicator is not specified, `MPI_COMM_WORLD` is used by default. + Communicators other than `MPI_COMM_WORLD` are not supported for now. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. @@ -183,16 +135,14 @@ def MPI_ISendOp : MPI_Op<"isend", []> { let arguments = ( ins AnyMemRef : $ref, I32 : $tag, - I32 : $rank, - Optional : $comm + I32 : $rank ); let results = (outs MPI_Request : $req, Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " + let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict " "`:` type($ref) `,` type($tag) `,` type($rank) " - "(`,` type($comm) ^)? `->` type($req)" - "(`,` type($retval) ^)?"; + "`->` type($req) (`,` type($retval) ^)?"; let hasCanonicalizer = 1; } @@ -202,14 +152,14 @@ def MPI_ISendOp : MPI_Op<"isend", []> { def MPI_RecvOp : MPI_Op<"recv", []> { let summary = "Equivalent to `MPI_Recv(ptr, size, dtype, dest, tag, " - "comm, MPI_STATUS_IGNORE)`"; + "MPI_COMM_WORLD, MPI_STATUS_IGNORE)`"; let description = [{ MPI_Recv performs a blocking receive of `size` elements of type `dtype` from rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - If communicator is not specified, `MPI_COMM_WORLD` is used by default. + Communicators other than `MPI_COMM_WORLD` are not supported for now. The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object is not yet ported to MLIR. @@ -219,15 +169,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> { let arguments = ( ins AnyMemRef : $ref, - I32 : $tag, I32 : $rank, - Optional : $comm + I32 : $tag, I32 : $rank ); let results = (outs Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " - "`:` type($ref) `,` type($tag) `,` type($rank) " - "(`,` type($comm) ^)? (`->` type($retval)^)?"; + let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`" + "type($ref) `,` type($tag) `,` type($rank) " + "(`->` type($retval)^)?"; let hasCanonicalizer = 1; } @@ -237,14 +186,14 @@ def MPI_RecvOp : MPI_Op<"recv", []> { def MPI_IRecvOp : MPI_Op<"irecv", []> { let summary = "Equivalent to `MPI_Irecv(ptr, size, dtype, dest, tag, " - "comm, &req)`"; + "MPI_COMM_WORLD, &req)`"; let description = [{ MPI_Irecv begins a non-blocking receive of `size` elements of type `dtype` from rank `dest`. The `tag` value and communicator enables the library to determine the matching of multiple sends and receives between the same ranks. - If communicator is not specified, `MPI_COMM_WORLD` is used by default. + Communicators other than `MPI_COMM_WORLD` are not supported for now. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. @@ -253,15 +202,14 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { let arguments = ( ins AnyMemRef : $ref, I32 : $tag, - I32 : $rank, - Optional : $comm + I32 : $rank ); let results = (outs MPI_Request : $req, Optional:$retval); - let assemblyFormat = "`(` $ref `,` $tag `,` $rank (`,` $comm ^)?`)` attr-dict " - "`:` type($ref) `,` type($tag) `,` type($rank)" - "(`,` type($comm) ^)? `->` type($req) (`,` type($retval) ^)?"; + let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`" + "type($ref) `,` type($tag) `,` type($rank) `->`" + "type($req) (`,` type($retval) ^)?"; let hasCanonicalizer = 1; } @@ -270,7 +218,8 @@ def MPI_IRecvOp : MPI_Op<"irecv", []> { //===----------------------------------------------------------------------===// def MPI_AllReduceOp : MPI_Op<"allreduce", []> { - let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, comm)`"; + let summary = "Equivalent to `MPI_Allreduce(sendbuf, recvbuf, op, " + "MPI_COMM_WORLD)`"; let description = [{ MPI_Allreduce performs a reduction operation on the values in the sendbuf array and stores the result in the recvbuf array. The operation is @@ -280,7 +229,7 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { Currently only the `MPI_Op` predefined in the standard (e.g. `MPI_SUM`) are supported. - If communicator is not specified, `MPI_COMM_WORLD` is used by default. + Communicators other than `MPI_COMM_WORLD` are not supported for now. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. @@ -289,15 +238,14 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { let arguments = ( ins AnyMemRef : $sendbuf, AnyMemRef : $recvbuf, - MPI_OpClassAttr : $op, - Optional : $comm + MPI_OpClassAttr : $op ); let results = (outs Optional:$retval); - let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op (`,` $comm ^)?`)` " - "attr-dict `:` type($sendbuf) `,` type($recvbuf) " - "(`,` type($comm) ^)? (`->` type($retval)^)?"; + let assemblyFormat = "`(` $sendbuf `,` $recvbuf `,` $op `)` attr-dict `:`" + "type($sendbuf) `,` type($recvbuf)" + "(`->` type($retval)^)?"; } //===----------------------------------------------------------------------===// @@ -305,22 +253,20 @@ def MPI_AllReduceOp : MPI_Op<"allreduce", []> { //===----------------------------------------------------------------------===// def MPI_Barrier : MPI_Op<"barrier", []> { - let summary = "Equivalent to `MPI_Barrier(comm)`"; + let summary = "Equivalent to `MPI_Barrier(MPI_COMM_WORLD)`"; let description = [{ MPI_Barrier blocks execution until all processes in the communicator have reached this routine. - If communicator is not specified, `MPI_COMM_WORLD` is used by default. + Communicators other than `MPI_COMM_WORLD` are not supported for now. This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }]; - let arguments = (ins Optional : $comm); - let results = (outs Optional:$retval); - let assemblyFormat = "(`(` $comm ^ `)`)? attr-dict (`:` type($retval) ^)?"; + let assemblyFormat = "attr-dict (`:` type($retval) ^)?"; } //===----------------------------------------------------------------------===// @@ -343,7 +289,8 @@ def MPI_Wait : MPI_Op<"wait", []> { let results = (outs Optional:$retval); - let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) (`->` type($retval) ^)?"; + let assemblyFormat = "`(` $req `)` attr-dict `:` type($req) " + "(`->` type($retval) ^)?"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td index 868132a62abc4b..fafea0eac8bb74 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPITypes.td @@ -40,17 +40,6 @@ def MPI_Retval : MPI_Type<"Retval", "retval"> { }]; } -//===----------------------------------------------------------------------===// -// mpi::CommType -//===----------------------------------------------------------------------===// - -def MPI_Comm : MPI_Type<"Comm", "comm"> { - let summary = "MPI communicator handler"; - let description = [{ - This type represents a handler to the MPI communicator. - }]; -} - //===----------------------------------------------------------------------===// // mpi::RequestType //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir index 753dcd3cfdf2d8..37b84e88c7d4cf 100644 --- a/mlir/test/Dialect/MPI/ops.mlir +++ b/mlir/test/Dialect/MPI/ops.mlir @@ -12,48 +12,30 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK-NEXT: %retval2, %size = mpi.comm_size : i32, !mpi.retval %retval2, %size = mpi.comm_size : i32, !mpi.retval - // CHECK-NEXT: %comm = mpi.comm_world : mpi.comm - %comm = mpi.comm_world : mpi.comm - - // CHECK-NEXT: %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : i32, !mpi.retval - %new_comm, %retval3 = mpi.comm_split(%comm, %rank, %rank) : mpi.comm, i32, i32 - // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 // CHECK-NEXT: %1 = mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval %err2 = mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval - // CHECK-NEXT: mpi.send(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm - mpi.send(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm - // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 // CHECK-NEXT: %2 = mpi.recv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval %err3 = mpi.recv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> !mpi.retval - // CHECK-NEXT: mpi.recv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm - mpi.recv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm - // CHECK-NEXT: %3 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request %req1 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request // CHECK-NEXT: %4, %5 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request, !mpi.retval %req2, %err4, = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request, !mpi.retval - // CHECK-NEXT: %3 = mpi.isend(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request - %req1 = mpi.isend(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request - // CHECK-NEXT: %6 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request // CHECK-NEXT: %7, %8 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request, !mpi.retval %req4, %err5 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request, !mpi.retval - // CHECK-NEXT: %6 = mpi.irecv(%arg0, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request - %req3 = mpi.irecv(%ref, %rank, %rank, %comm) : memref<100xf32>, i32, i32, mpi.comm -> mpi.request - // CHECK-NEXT: mpi.wait(%req1) : mpi.request mpi.wait(%req1) : mpi.request @@ -66,18 +48,12 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK-NEXT: %10 = mpi.barrier : !mpi.retval %err7 = mpi.barrier : !mpi.retval - // CHECK-NEXT: mpi.barrier(%comm) : !mpi.retval - mpi.barrier(%comm) : !mpi.retval - // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM) : memref<100xf32>, memref<100xf32> -> !mpi.retval %err8 = mpi.allreduce(%ref, %ref, MPI_SUM) : memref<100xf32>, memref<100xf32> - // CHECK-NEXT: mpi.allreduce(%arg0, %arg0, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, mpi.comm - mpi.allreduce(%ref, %ref, MPI_SUM, %comm) : memref<100xf32>, memref<100xf32>, mpi.comm - // CHECK-NEXT: %11 = mpi.finalize : !mpi.retval %rval = mpi.finalize : !mpi.retval From 452f7602e9d56ce6aae957c4cb8f97b548ecc5b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 30 Jan 2025 00:42:18 +0100 Subject: [PATCH 39/42] fix tests --- mlir/test/Dialect/MPI/ops.mlir | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir index 37b84e88c7d4cf..05305c5afc7316 100644 --- a/mlir/test/Dialect/MPI/ops.mlir +++ b/mlir/test/Dialect/MPI/ops.mlir @@ -6,11 +6,11 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK: %0 = mpi.init : !mpi.retval %err = mpi.init : !mpi.retval - // CHECK-NEXT: %retval, %rank = mpi.comm_rank : i32, !mpi.retval - %retval, %rank = mpi.comm_rank : i32, !mpi.retval + // CHECK-NEXT: %rank, %retval = mpi.comm_rank : i32, !mpi.retval + %rank, %retval = mpi.comm_rank : i32, !mpi.retval - // CHECK-NEXT: %retval2, %size = mpi.comm_size : i32, !mpi.retval - %retval2, %size = mpi.comm_size : i32, !mpi.retval + // CHECK-NEXT: %size, %retval2 = mpi.comm_size : i32, !mpi.retval + %size, %retval2 = mpi.comm_size : i32, !mpi.retval // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 @@ -28,7 +28,7 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { %req1 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request // CHECK-NEXT: %4, %5 = mpi.isend(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request, !mpi.retval - %req2, %err4, = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request, !mpi.retval + %req2, %err4 = mpi.isend(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request, !mpi.retval // CHECK-NEXT: %6 = mpi.irecv(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request %req3 = mpi.irecv(%ref, %rank, %rank) : memref<100xf32>, i32, i32 -> mpi.request From 56868e84adc7fbe5ed3098c8b8fdad8683ab8dc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 30 Jan 2025 00:48:47 +0100 Subject: [PATCH 40/42] change order of results of `MPI_CommRankOp` --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 4 ++-- mlir/test/Dialect/MPI/ops.mlir | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index f3d21466795282..5a4bab982efa8f 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -52,8 +52,8 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { }]; let results = ( - outs I32 : $rank, - Optional : $retval + outs Optional : $retval, + I32 : $rank ); let assemblyFormat = "attr-dict `:` type(results)"; diff --git a/mlir/test/Dialect/MPI/ops.mlir b/mlir/test/Dialect/MPI/ops.mlir index 05305c5afc7316..bfe409da5e9c67 100644 --- a/mlir/test/Dialect/MPI/ops.mlir +++ b/mlir/test/Dialect/MPI/ops.mlir @@ -6,11 +6,11 @@ func.func @mpi_test(%ref : memref<100xf32>) -> () { // CHECK: %0 = mpi.init : !mpi.retval %err = mpi.init : !mpi.retval - // CHECK-NEXT: %rank, %retval = mpi.comm_rank : i32, !mpi.retval - %rank, %retval = mpi.comm_rank : i32, !mpi.retval + // CHECK-NEXT: %retval, %rank = mpi.comm_rank : !mpi.retval, i32 + %retval, %rank = mpi.comm_rank : !mpi.retval, i32 - // CHECK-NEXT: %size, %retval2 = mpi.comm_size : i32, !mpi.retval - %size, %retval2 = mpi.comm_size : i32, !mpi.retval + // CHECK-NEXT: %retval2, %size = mpi.comm_size : !mpi.retval, i32 + %retval2, %size = mpi.comm_size : !mpi.retval, i32 // CHECK-NEXT: mpi.send(%arg0, %rank, %rank) : memref<100xf32>, i32, i32 mpi.send(%ref, %rank, %rank) : memref<100xf32>, i32, i32 From b9988b392c987c4b6cc89e2eef5a9cec96299640 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 30 Jan 2025 00:48:55 +0100 Subject: [PATCH 41/42] format code --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 5a4bab982efa8f..93f9069ca70778 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -175,7 +175,7 @@ def MPI_RecvOp : MPI_Op<"recv", []> { let results = (outs Optional:$retval); let assemblyFormat = "`(` $ref `,` $tag `,` $rank `)` attr-dict `:`" - "type($ref) `,` type($tag) `,` type($rank) " + "type($ref) `,` type($tag) `,` type($rank)" "(`->` type($retval)^)?"; let hasCanonicalizer = 1; } From 2075c02ffa097e183013d7cdf486770f38f0543a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 30 Jan 2025 00:52:34 +0100 Subject: [PATCH 42/42] format code --- mlir/include/mlir/Dialect/MPI/IR/MPIOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td index 93f9069ca70778..8a0d21475f7b25 100644 --- a/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td +++ b/mlir/include/mlir/Dialect/MPI/IR/MPIOps.td @@ -46,7 +46,7 @@ def MPI_CommRankOp : MPI_Op<"comm_rank", []> { "`MPI_Comm_rank(MPI_COMM_WORLD, &rank)`"; let description = [{ Communicators other than `MPI_COMM_WORLD` are not supported for now. - + This operation can optionally return an `!mpi.retval` value that can be used to check for errors. }];