-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[flang][openmp] Changes for invoking scan Op #123254
base: main
Are you sure you want to change the base?
Conversation
fccf0c6
to
a66c8cb
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-openmp Author: Anchu Rajendran S (anchuraj) ChangesFrontend changes to invoke scan Op. It will build successfully after #114737 is merged. Full diff: https://github.com/llvm/llvm-project/pull/123254.diff 11 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 299d9d438f1156..8bec29c74a1542 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -344,6 +344,19 @@ bool ClauseProcessor::processDistSchedule(
return false;
}
+bool ClauseProcessor::processExclusive(
+ mlir::Location currentLocation,
+ mlir::omp::ExclusiveClauseOps &result) const {
+ return findRepeatableClause<omp::clause::Exclusive>(
+ [&](const omp::clause::Exclusive &clause, const parser::CharBlock &) {
+ for (const Object &object : clause.v) {
+ const semantics::Symbol *symbol = object.sym();
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ result.exclusiveVars.push_back(symVal);
+ }
+ });
+}
+
bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
@@ -380,6 +393,19 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
return false;
}
+bool ClauseProcessor::processInclusive(
+ mlir::Location currentLocation,
+ mlir::omp::InclusiveClauseOps &result) const {
+ return findRepeatableClause<omp::clause::Inclusive>(
+ [&](const omp::clause::Inclusive &clause, const parser::CharBlock &) {
+ for (const Object &object : clause.v) {
+ const semantics::Symbol *symbol = object.sym();
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ result.inclusiveVars.push_back(symVal);
+ }
+ });
+}
+
bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
@@ -1135,10 +1161,9 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
- rp.addDeclareReduction(currentLocation, converter, clause,
- reductionVars, reduceVarByRef,
- reductionDeclSymbols, reductionSyms);
-
+ rp.addDeclareReduction(
+ currentLocation, converter, clause, reductionVars, reduceVarByRef,
+ reductionDeclSymbols, reductionSyms, &result.reductionMod);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7b047d4a7567ad..e05f66c7666844 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -64,6 +64,8 @@ class ClauseProcessor {
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processDistSchedule(lower::StatementContext &stmtCtx,
mlir::omp::DistScheduleClauseOps &result) const;
+ bool processExclusive(mlir::Location currentLocation,
+ mlir::omp::ExclusiveClauseOps &result) const;
bool processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const;
bool processFinal(lower::StatementContext &stmtCtx,
@@ -72,6 +74,8 @@ class ClauseProcessor {
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
+ bool processInclusive(mlir::Location currentLocation,
+ mlir::omp::InclusiveClauseOps &result) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(lower::StatementContext &stmtCtx,
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b424e209d56da9..a26bdcdf343e13 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -728,8 +728,8 @@ Enter make(const parser::OmpClause::Enter &inp,
Exclusive make(const parser::OmpClause::Exclusive &inp,
semantics::SemanticsContext &semaCtx) {
- // inp -> empty
- llvm_unreachable("Empty: exclusive");
+ // inp.v -> parser::OmpObjectList
+ return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}
Fail make(const parser::OmpClause::Fail &inp,
@@ -838,8 +838,8 @@ If make(const parser::OmpClause::If &inp,
Inclusive make(const parser::OmpClause::Inclusive &inp,
semantics::SemanticsContext &semaCtx) {
- // inp -> empty
- llvm_unreachable("Empty: inclusive");
+ // inp.v -> parser::OmpObjectList
+ return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}
Indirect make(const parser::OmpClause::Indirect &inp,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 1434bcd6330e02..0e0e9a57287c3d 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1578,6 +1578,15 @@ static void genParallelClauses(
cp.processReduction(loc, clauseOps, reductionSyms);
}
+static void genScanClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses, mlir::Location loc,
+ mlir::omp::ScanOperands &clauseOps) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processInclusive(loc, clauseOps);
+ cp.processExclusive(loc, clauseOps);
+}
+
static void genSectionsClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
@@ -1975,6 +1984,17 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return parallelOp;
}
+static mlir::omp::ScanOp
+genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, mlir::Location loc,
+ const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+
+ mlir::omp::ScanOperands clauseOps;
+ genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
+ return converter.getFirOpBuilder().create<mlir::omp::ScanOp>(
+ converter.getCurrentLocation(), clauseOps);
+}
+
/// This breaks the normal prototype of the gen*Op functions: adding the
/// sectionBlocks argument so that the enclosed section constructs can be
/// lowered here with correct reduction symbol remapping.
@@ -2978,7 +2998,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_scan:
- TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
+ genScanOp(converter, symTable, semaCtx, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_section:
llvm_unreachable("genOMPDispatch: OMPD_section");
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 2cd21107a916e4..e6ca5d5073c336 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -25,6 +25,7 @@
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
+#include <string>
static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
@@ -514,18 +515,36 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}
+mlir::omp::ReductionModifier
+translateReductionModifier(const ReductionModifier &m) {
+ switch (m) {
+ case ReductionModifier::Default:
+ return mlir::omp::ReductionModifier::defaultmod;
+ case ReductionModifier::Inscan:
+ return mlir::omp::ReductionModifier::inscan;
+ case ReductionModifier::Task:
+ return mlir::omp::ReductionModifier::task;
+ }
+ return mlir::omp::ReductionModifier::defaultmod;
+}
+
void ReductionProcessor::addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
+ mlir::omp::ReductionModifierAttr *reductionMod) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
- reduction.t))
- TODO(currentLocation, "Reduction modifiers are not supported");
+ auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
+ if (mod.has_value() && (mod.value() != ReductionModifier::Inscan)) {
+ std::string modStr = "default";
+ if (mod.value() == ReductionModifier::Task)
+ modStr = "task";
+ TODO(currentLocation, "Reduction modifier " + modStr + " is not supported");
+ }
mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{
@@ -649,6 +668,11 @@ void ReductionProcessor::addDeclareReduction(
currentLocation, isByRef);
reductionDeclSymbols.push_back(
mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
+ auto redMod = std::get<std::optional<ReductionModifier>>(reduction.t);
+ if (redMod.has_value())
+ *reductionMod = mlir::omp::ReductionModifierAttr::get(
+ firOpBuilder.getContext(),
+ translateReductionModifier(redMod.value()));
}
}
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 5f4d742b62cb10..44ab67979d5db9 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -19,6 +19,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
@@ -126,7 +127,8 @@ class ReductionProcessor {
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
+ mlir::omp::ReductionModifierAttr *reductionMod);
};
template <typename FloatOp, typename IntegerOp>
@@ -156,6 +158,8 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
return builder.create<ComplexOp>(loc, op1, op2);
}
+using ReductionModifier = omp::clause::Reduction::ReductionModifier;
+
} // namespace omp
} // namespace lower
} // namespace Fortran
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90 b/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
deleted file mode 100644
index 152d91a16f80fe..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
+++ /dev/null
@@ -1,15 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction modifiers are not supported
-subroutine reduction_inscan()
- integer :: i,j
- i = 0
-
- !$omp do reduction(inscan, +:i)
- do j=1,10
- !$omp scan inclusive(i)
- i = i + 1
- end do
- !$omp end do
-end subroutine reduction_inscan
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90 b/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
deleted file mode 100644
index 82625ed8c5f31c..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
+++ /dev/null
@@ -1,14 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction modifiers are not supported
-
-subroutine foo()
- integer :: i, j
- j = 0
- !$omp do reduction (inscan, *: j)
- do i = 1, 10
- !$omp scan inclusive(j)
- j = j + 1
- end do
-end subroutine
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-task.f90 b/flang/test/Lower/OpenMP/Todo/reduction-task.f90
index 6707f65e1a4cc3..b746872e9e7edf 100644
--- a/flang/test/Lower/OpenMP/Todo/reduction-task.f90
+++ b/flang/test/Lower/OpenMP/Todo/reduction-task.f90
@@ -1,7 +1,7 @@
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! CHECK: not yet implemented: Reduction modifiers are not supported
+! CHECK: not yet implemented: Reduction modifier task is not supported
subroutine reduction_task()
integer :: i
i = 0
diff --git a/flang/test/Lower/OpenMP/scan.f90 b/flang/test/Lower/OpenMP/scan.f90
new file mode 100644
index 00000000000000..9cf2174a7f3314
--- /dev/null
+++ b/flang/test/Lower/OpenMP/scan.f90
@@ -0,0 +1,34 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+subroutine inclusive_scan
+ implicit none
+ integer, parameter :: n = 100
+ integer a(n), b(n)
+ integer x, k
+
+ !CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
+ !$omp parallel do reduction(inscan, +: x)
+ do k = 1, n
+ x = x + a(k)
+ !CHECK: omp.scan inclusive({{.*}})
+ !$omp scan inclusive(x)
+ b(k) = x
+ end do
+end subroutine inclusive_scan
+
+
+subroutine exclusive_scan
+ implicit none
+ integer, parameter :: n = 100
+ integer a(n), b(n)
+ integer x, k
+
+ !CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
+ !$omp parallel do reduction(inscan, +: x)
+ do k = 1, n
+ x = x + a(k)
+ !CHECK: omp.scan exclusive({{.*}})
+ !$omp scan exclusive(x)
+ b(k) = x
+ end do
+end subroutine exclusive_scan
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 5d0003911bca87..056b7f989128a4 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality(
target.addDynamicallyLegalOp<
omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
- omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
+ omp::MapInfoOp, omp::ScanOp, omp::OrderedOp, omp::TargetEnterDataOp,
omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
omp::YieldOp>([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
@@ -264,6 +264,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
RegionLessOpConversion<omp::CancelOp>,
RegionLessOpConversion<omp::CriticalDeclareOp>,
RegionLessOpConversion<omp::OrderedOp>,
+ RegionLessOpConversion<omp::ScanOp>,
RegionLessOpConversion<omp::TargetEnterDataOp>,
RegionLessOpConversion<omp::TargetExitDataOp>,
RegionLessOpConversion<omp::TargetUpdateOp>,
|
@llvm/pr-subscribers-flang-openmp Author: Anchu Rajendran S (anchuraj) ChangesFrontend changes to invoke scan Op. It will build successfully after #114737 is merged. Full diff: https://github.com/llvm/llvm-project/pull/123254.diff 11 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 299d9d438f1156..8bec29c74a1542 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -344,6 +344,19 @@ bool ClauseProcessor::processDistSchedule(
return false;
}
+bool ClauseProcessor::processExclusive(
+ mlir::Location currentLocation,
+ mlir::omp::ExclusiveClauseOps &result) const {
+ return findRepeatableClause<omp::clause::Exclusive>(
+ [&](const omp::clause::Exclusive &clause, const parser::CharBlock &) {
+ for (const Object &object : clause.v) {
+ const semantics::Symbol *symbol = object.sym();
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ result.exclusiveVars.push_back(symVal);
+ }
+ });
+}
+
bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
@@ -380,6 +393,19 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
return false;
}
+bool ClauseProcessor::processInclusive(
+ mlir::Location currentLocation,
+ mlir::omp::InclusiveClauseOps &result) const {
+ return findRepeatableClause<omp::clause::Inclusive>(
+ [&](const omp::clause::Inclusive &clause, const parser::CharBlock &) {
+ for (const Object &object : clause.v) {
+ const semantics::Symbol *symbol = object.sym();
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ result.inclusiveVars.push_back(symVal);
+ }
+ });
+}
+
bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
@@ -1135,10 +1161,9 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
- rp.addDeclareReduction(currentLocation, converter, clause,
- reductionVars, reduceVarByRef,
- reductionDeclSymbols, reductionSyms);
-
+ rp.addDeclareReduction(
+ currentLocation, converter, clause, reductionVars, reduceVarByRef,
+ reductionDeclSymbols, reductionSyms, &result.reductionMod);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7b047d4a7567ad..e05f66c7666844 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -64,6 +64,8 @@ class ClauseProcessor {
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processDistSchedule(lower::StatementContext &stmtCtx,
mlir::omp::DistScheduleClauseOps &result) const;
+ bool processExclusive(mlir::Location currentLocation,
+ mlir::omp::ExclusiveClauseOps &result) const;
bool processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const;
bool processFinal(lower::StatementContext &stmtCtx,
@@ -72,6 +74,8 @@ class ClauseProcessor {
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
+ bool processInclusive(mlir::Location currentLocation,
+ mlir::omp::InclusiveClauseOps &result) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(lower::StatementContext &stmtCtx,
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b424e209d56da9..a26bdcdf343e13 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -728,8 +728,8 @@ Enter make(const parser::OmpClause::Enter &inp,
Exclusive make(const parser::OmpClause::Exclusive &inp,
semantics::SemanticsContext &semaCtx) {
- // inp -> empty
- llvm_unreachable("Empty: exclusive");
+ // inp.v -> parser::OmpObjectList
+ return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}
Fail make(const parser::OmpClause::Fail &inp,
@@ -838,8 +838,8 @@ If make(const parser::OmpClause::If &inp,
Inclusive make(const parser::OmpClause::Inclusive &inp,
semantics::SemanticsContext &semaCtx) {
- // inp -> empty
- llvm_unreachable("Empty: inclusive");
+ // inp.v -> parser::OmpObjectList
+ return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}
Indirect make(const parser::OmpClause::Indirect &inp,
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 1434bcd6330e02..0e0e9a57287c3d 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1578,6 +1578,15 @@ static void genParallelClauses(
cp.processReduction(loc, clauseOps, reductionSyms);
}
+static void genScanClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ const List<Clause> &clauses, mlir::Location loc,
+ mlir::omp::ScanOperands &clauseOps) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processInclusive(loc, clauseOps);
+ cp.processExclusive(loc, clauseOps);
+}
+
static void genSectionsClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
@@ -1975,6 +1984,17 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return parallelOp;
}
+static mlir::omp::ScanOp
+genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, mlir::Location loc,
+ const ConstructQueue &queue, ConstructQueue::const_iterator item) {
+
+ mlir::omp::ScanOperands clauseOps;
+ genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
+ return converter.getFirOpBuilder().create<mlir::omp::ScanOp>(
+ converter.getCurrentLocation(), clauseOps);
+}
+
/// This breaks the normal prototype of the gen*Op functions: adding the
/// sectionBlocks argument so that the enclosed section constructs can be
/// lowered here with correct reduction symbol remapping.
@@ -2978,7 +2998,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_scan:
- TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
+ genScanOp(converter, symTable, semaCtx, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_section:
llvm_unreachable("genOMPDispatch: OMPD_section");
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index 2cd21107a916e4..e6ca5d5073c336 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -25,6 +25,7 @@
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
+#include <string>
static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
@@ -514,18 +515,36 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}
+mlir::omp::ReductionModifier
+translateReductionModifier(const ReductionModifier &m) {
+ switch (m) {
+ case ReductionModifier::Default:
+ return mlir::omp::ReductionModifier::defaultmod;
+ case ReductionModifier::Inscan:
+ return mlir::omp::ReductionModifier::inscan;
+ case ReductionModifier::Task:
+ return mlir::omp::ReductionModifier::task;
+ }
+ return mlir::omp::ReductionModifier::defaultmod;
+}
+
void ReductionProcessor::addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
+ mlir::omp::ReductionModifierAttr *reductionMod) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
- reduction.t))
- TODO(currentLocation, "Reduction modifiers are not supported");
+ auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
+ if (mod.has_value() && (mod.value() != ReductionModifier::Inscan)) {
+ std::string modStr = "default";
+ if (mod.value() == ReductionModifier::Task)
+ modStr = "task";
+ TODO(currentLocation, "Reduction modifier " + modStr + " is not supported");
+ }
mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{
@@ -649,6 +668,11 @@ void ReductionProcessor::addDeclareReduction(
currentLocation, isByRef);
reductionDeclSymbols.push_back(
mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
+ auto redMod = std::get<std::optional<ReductionModifier>>(reduction.t);
+ if (redMod.has_value())
+ *reductionMod = mlir::omp::ReductionModifierAttr::get(
+ firOpBuilder.getContext(),
+ translateReductionModifier(redMod.value()));
}
}
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h
index 5f4d742b62cb10..44ab67979d5db9 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.h
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h
@@ -19,6 +19,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
@@ -126,7 +127,8 @@ class ReductionProcessor {
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
- llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
+ llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
+ mlir::omp::ReductionModifierAttr *reductionMod);
};
template <typename FloatOp, typename IntegerOp>
@@ -156,6 +158,8 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
return builder.create<ComplexOp>(loc, op1, op2);
}
+using ReductionModifier = omp::clause::Reduction::ReductionModifier;
+
} // namespace omp
} // namespace lower
} // namespace Fortran
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90 b/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
deleted file mode 100644
index 152d91a16f80fe..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90
+++ /dev/null
@@ -1,15 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction modifiers are not supported
-subroutine reduction_inscan()
- integer :: i,j
- i = 0
-
- !$omp do reduction(inscan, +:i)
- do j=1,10
- !$omp scan inclusive(i)
- i = i + 1
- end do
- !$omp end do
-end subroutine reduction_inscan
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90 b/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
deleted file mode 100644
index 82625ed8c5f31c..00000000000000
--- a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90
+++ /dev/null
@@ -1,14 +0,0 @@
-! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-
-! CHECK: not yet implemented: Reduction modifiers are not supported
-
-subroutine foo()
- integer :: i, j
- j = 0
- !$omp do reduction (inscan, *: j)
- do i = 1, 10
- !$omp scan inclusive(j)
- j = j + 1
- end do
-end subroutine
diff --git a/flang/test/Lower/OpenMP/Todo/reduction-task.f90 b/flang/test/Lower/OpenMP/Todo/reduction-task.f90
index 6707f65e1a4cc3..b746872e9e7edf 100644
--- a/flang/test/Lower/OpenMP/Todo/reduction-task.f90
+++ b/flang/test/Lower/OpenMP/Todo/reduction-task.f90
@@ -1,7 +1,7 @@
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
-! CHECK: not yet implemented: Reduction modifiers are not supported
+! CHECK: not yet implemented: Reduction modifier task is not supported
subroutine reduction_task()
integer :: i
i = 0
diff --git a/flang/test/Lower/OpenMP/scan.f90 b/flang/test/Lower/OpenMP/scan.f90
new file mode 100644
index 00000000000000..9cf2174a7f3314
--- /dev/null
+++ b/flang/test/Lower/OpenMP/scan.f90
@@ -0,0 +1,34 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+subroutine inclusive_scan
+ implicit none
+ integer, parameter :: n = 100
+ integer a(n), b(n)
+ integer x, k
+
+ !CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
+ !$omp parallel do reduction(inscan, +: x)
+ do k = 1, n
+ x = x + a(k)
+ !CHECK: omp.scan inclusive({{.*}})
+ !$omp scan inclusive(x)
+ b(k) = x
+ end do
+end subroutine inclusive_scan
+
+
+subroutine exclusive_scan
+ implicit none
+ integer, parameter :: n = 100
+ integer a(n), b(n)
+ integer x, k
+
+ !CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
+ !$omp parallel do reduction(inscan, +: x)
+ do k = 1, n
+ x = x + a(k)
+ !CHECK: omp.scan exclusive({{.*}})
+ !$omp scan exclusive(x)
+ b(k) = x
+ end do
+end subroutine exclusive_scan
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 5d0003911bca87..056b7f989128a4 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality(
target.addDynamicallyLegalOp<
omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
- omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
+ omp::MapInfoOp, omp::ScanOp, omp::OrderedOp, omp::TargetEnterDataOp,
omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
omp::YieldOp>([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
@@ -264,6 +264,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
RegionLessOpConversion<omp::CancelOp>,
RegionLessOpConversion<omp::CriticalDeclareOp>,
RegionLessOpConversion<omp::OrderedOp>,
+ RegionLessOpConversion<omp::ScanOp>,
RegionLessOpConversion<omp::TargetEnterDataOp>,
RegionLessOpConversion<omp::TargetExitDataOp>,
RegionLessOpConversion<omp::TargetUpdateOp>,
|
a66c8cb
to
08a8bd1
Compare
08a8bd1
to
d8e27cd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you Anchu, just a few comments from me.
flang/test/Lower/OpenMP/scan.f90
Outdated
integer a(n), b(n) | ||
integer x, k | ||
|
||
!CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be good to check that the block argument defined by the omp.wsloop
operation for the reduction variable is what's eventually passed to omp.scan
. Something like:
! CHECK: omp.wsloop reduction(mod: inscan, %{{.*}} -> %[[RED_ARG:.*]] : {{.*}}) {
! CHECK: %[[RED_DECL:.*]]:2 = hlfir.declare %[[RED_ARG]]
! CHECK: omp.scan inclusive(%[[RED_DECL]]#0 : {{.*}})
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to use the autogenerated checks so these gets checked
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd personally prefer not doing that, because it makes the test easily break for unrelated changes in representation and lowering. But feel free to leave it in if you prefer, I don't feel that strongly about it and there are a few other tests like this in the same directory.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand the issue you are pointing to. I have updated the testcase.
47b8887
to
f1c408b
Compare
f1c408b
to
bb6a40c
Compare
@@ -35,6 +35,8 @@ namespace Fortran { | |||
namespace lower { | |||
namespace omp { | |||
|
|||
using ReductionModifier = omp::clause::Reduction::ReductionModifier; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I'd place this below the #include
s.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think then it would be using ReductionModifier = Fortran::lower::omp::clause::Reduction::ReductionModifier;
I modified it.
flang/test/Lower/OpenMP/scan.f90
Outdated
integer a(n), b(n) | ||
integer x, k | ||
|
||
!CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd personally prefer not doing that, because it makes the test easily break for unrelated changes in representation and lowering. But feel free to leave it in if you prefer, I don't feel that strongly about it and there are a few other tests like this in the same directory.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for working on my feedback, LGTM.
87ecee3
to
e52b03b
Compare
e52b03b
to
656a381
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for taking so long to review this. Just one minor point.
auto mod = std::get<std::optional<ReductionModifier>>(reduction.t); | ||
if (mod.has_value()) { | ||
if (mod.value() == ReductionModifier::Task) | ||
TODO(currentLocation, "Reduction modifier `task` is not supported"); | ||
else | ||
reductionMod = mlir::omp::ReductionModifierAttr::get( | ||
firOpBuilder.getContext(), translateReductionModifier(mod.value())); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you decide to put this in addDeclareReduction
? If I understand correctly, this is not used as part of the declare reduction operation. I wonder if it would be better placed in processReduction
?
Scan
directive specifies the scan computations. This PR adds the Frontend changes to invoke scan Op.##Related PRs
-#102792
-#114737