diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 299d9d438f1156d..98028f581e38716 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -344,6 +344,22 @@ bool ClauseProcessor::processDistSchedule( return false; } +bool ClauseProcessor::processExclusive( + mlir::Location currentLocation, + mlir::omp::ExclusiveClauseOps &result) const { + return findRepeatableClause( + [&](const omp::clause::Exclusive &clause, const parser::CharBlock &) { + for (const Object &object : clause.v) { + semantics::Symbol *sym = object.sym(); + mlir::Value symVal = converter.getSymbolAddress(*sym); + if (auto declOp = symVal.getDefiningOp()) { + symVal = declOp.getBase(); + } + result.exclusiveVars.push_back(symVal); + } + }); +} + bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx, mlir::omp::FilterClauseOps &result) const { if (auto *clause = findUniqueClause()) { @@ -380,6 +396,22 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const { return false; } +bool ClauseProcessor::processInclusive( + mlir::Location currentLocation, + mlir::omp::InclusiveClauseOps &result) const { + return findRepeatableClause( + [&](const omp::clause::Inclusive &clause, const parser::CharBlock &) { + for (const Object &object : clause.v) { + const semantics::Symbol *symbol = object.sym(); + mlir::Value symVal = converter.getSymbolAddress(*symbol); + // if (auto declOp = symVal.getDefiningOp()) { + // symVal = declOp.getBase(); + // } + result.inclusiveVars.push_back(symVal); + } + }); +} + bool ClauseProcessor::processMergeable( mlir::omp::MergeableClauseOps &result) const { return markClauseOccurrence(result.mergeable); @@ -1135,10 +1167,9 @@ bool ClauseProcessor::processReduction( llvm::SmallVector reductionDeclSymbols; llvm::SmallVector reductionSyms; ReductionProcessor rp; - rp.addDeclareReduction(currentLocation, converter, clause, - reductionVars, reduceVarByRef, - reductionDeclSymbols, reductionSyms); - + rp.addDeclareReduction( + currentLocation, converter, clause, reductionVars, reduceVarByRef, + reductionDeclSymbols, reductionSyms, &result.reductionMod); // Copy local lists into the output. llvm::copy(reductionVars, std::back_inserter(result.reductionVars)); llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref)); diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 7b047d4a7567adf..e05f66c7666844f 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -64,6 +64,8 @@ class ClauseProcessor { bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const; bool processDistSchedule(lower::StatementContext &stmtCtx, mlir::omp::DistScheduleClauseOps &result) const; + bool processExclusive(mlir::Location currentLocation, + mlir::omp::ExclusiveClauseOps &result) const; bool processFilter(lower::StatementContext &stmtCtx, mlir::omp::FilterClauseOps &result) const; bool processFinal(lower::StatementContext &stmtCtx, @@ -72,6 +74,8 @@ class ClauseProcessor { mlir::omp::HasDeviceAddrClauseOps &result, llvm::SmallVectorImpl &isDeviceSyms) const; bool processHint(mlir::omp::HintClauseOps &result) const; + bool processInclusive(mlir::Location currentLocation, + mlir::omp::InclusiveClauseOps &result) const; bool processMergeable(mlir::omp::MergeableClauseOps &result) const; bool processNowait(mlir::omp::NowaitClauseOps &result) const; bool processNumTeams(lower::StatementContext &stmtCtx, diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp index b424e209d56da93..a26bdcdf343e137 100644 --- a/flang/lib/Lower/OpenMP/Clauses.cpp +++ b/flang/lib/Lower/OpenMP/Clauses.cpp @@ -728,8 +728,8 @@ Enter make(const parser::OmpClause::Enter &inp, Exclusive make(const parser::OmpClause::Exclusive &inp, semantics::SemanticsContext &semaCtx) { - // inp -> empty - llvm_unreachable("Empty: exclusive"); + // inp.v -> parser::OmpObjectList + return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)}; } Fail make(const parser::OmpClause::Fail &inp, @@ -838,8 +838,8 @@ If make(const parser::OmpClause::If &inp, Inclusive make(const parser::OmpClause::Inclusive &inp, semantics::SemanticsContext &semaCtx) { - // inp -> empty - llvm_unreachable("Empty: inclusive"); + // inp.v -> parser::OmpObjectList + return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)}; } Indirect make(const parser::OmpClause::Indirect &inp, diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 1434bcd6330e023..0e0e9a57287c3da 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1578,6 +1578,15 @@ static void genParallelClauses( cp.processReduction(loc, clauseOps, reductionSyms); } +static void genScanClauses(lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, + const List &clauses, mlir::Location loc, + mlir::omp::ScanOperands &clauseOps) { + ClauseProcessor cp(converter, semaCtx, clauses); + cp.processInclusive(loc, clauseOps); + cp.processExclusive(loc, clauseOps); +} + static void genSectionsClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, const List &clauses, mlir::Location loc, @@ -1975,6 +1984,17 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable, return parallelOp; } +static mlir::omp::ScanOp +genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, mlir::Location loc, + const ConstructQueue &queue, ConstructQueue::const_iterator item) { + + mlir::omp::ScanOperands clauseOps; + genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps); + return converter.getFirOpBuilder().create( + converter.getCurrentLocation(), clauseOps); +} + /// This breaks the normal prototype of the gen*Op functions: adding the /// sectionBlocks argument so that the enclosed section constructs can be /// lowered here with correct reduction symbol remapping. @@ -2978,7 +2998,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter, genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item); break; case llvm::omp::Directive::OMPD_scan: - TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir)); + genScanOp(converter, symTable, semaCtx, loc, queue, item); break; case llvm::omp::Directive::OMPD_section: llvm_unreachable("genOMPDispatch: OMPD_section"); diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp index 2cd21107a916e4e..e6ca5d5073c336a 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp @@ -25,6 +25,7 @@ #include "flang/Parser/tools.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "llvm/Support/CommandLine.h" +#include static llvm::cl::opt forceByrefReduction( "force-byref-reduction", @@ -514,18 +515,36 @@ static bool doReductionByRef(mlir::Value reductionVar) { return false; } +mlir::omp::ReductionModifier +translateReductionModifier(const ReductionModifier &m) { + switch (m) { + case ReductionModifier::Default: + return mlir::omp::ReductionModifier::defaultmod; + case ReductionModifier::Inscan: + return mlir::omp::ReductionModifier::inscan; + case ReductionModifier::Task: + return mlir::omp::ReductionModifier::task; + } + return mlir::omp::ReductionModifier::defaultmod; +} + void ReductionProcessor::addDeclareReduction( mlir::Location currentLocation, lower::AbstractConverter &converter, const omp::clause::Reduction &reduction, llvm::SmallVectorImpl &reductionVars, llvm::SmallVectorImpl &reduceVarByRef, llvm::SmallVectorImpl &reductionDeclSymbols, - llvm::SmallVectorImpl &reductionSymbols) { + llvm::SmallVectorImpl &reductionSymbols, + mlir::omp::ReductionModifierAttr *reductionMod) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - if (std::get>( - reduction.t)) - TODO(currentLocation, "Reduction modifiers are not supported"); + auto mod = std::get>(reduction.t); + if (mod.has_value() && (mod.value() != ReductionModifier::Inscan)) { + std::string modStr = "default"; + if (mod.value() == ReductionModifier::Task) + modStr = "task"; + TODO(currentLocation, "Reduction modifier " + modStr + " is not supported"); + } mlir::omp::DeclareReductionOp decl; const auto &redOperatorList{ @@ -649,6 +668,11 @@ void ReductionProcessor::addDeclareReduction( currentLocation, isByRef); reductionDeclSymbols.push_back( mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName())); + auto redMod = std::get>(reduction.t); + if (redMod.has_value()) + *reductionMod = mlir::omp::ReductionModifierAttr::get( + firOpBuilder.getContext(), + translateReductionModifier(redMod.value())); } } diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.h b/flang/lib/Lower/OpenMP/ReductionProcessor.h index 5f4d742b62cb107..44ab67979d5db9f 100644 --- a/flang/lib/Lower/OpenMP/ReductionProcessor.h +++ b/flang/lib/Lower/OpenMP/ReductionProcessor.h @@ -19,6 +19,7 @@ #include "flang/Parser/parse-tree.h" #include "flang/Semantics/symbol.h" #include "flang/Semantics/type.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/Types.h" @@ -126,7 +127,8 @@ class ReductionProcessor { llvm::SmallVectorImpl &reductionVars, llvm::SmallVectorImpl &reduceVarByRef, llvm::SmallVectorImpl &reductionDeclSymbols, - llvm::SmallVectorImpl &reductionSymbols); + llvm::SmallVectorImpl &reductionSymbols, + mlir::omp::ReductionModifierAttr *reductionMod); }; template @@ -156,6 +158,8 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, return builder.create(loc, op1, op2); } +using ReductionModifier = omp::clause::Reduction::ReductionModifier; + } // namespace omp } // namespace lower } // namespace Fortran diff --git a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90 b/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90 deleted file mode 100644 index 152d91a16f80fe8..000000000000000 --- a/flang/test/Lower/OpenMP/Todo/reduction-inscan.f90 +++ /dev/null @@ -1,15 +0,0 @@ -! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s -! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s - -! CHECK: not yet implemented: Reduction modifiers are not supported -subroutine reduction_inscan() - integer :: i,j - i = 0 - - !$omp do reduction(inscan, +:i) - do j=1,10 - !$omp scan inclusive(i) - i = i + 1 - end do - !$omp end do -end subroutine reduction_inscan diff --git a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90 b/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90 deleted file mode 100644 index 82625ed8c5f31c1..000000000000000 --- a/flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90 +++ /dev/null @@ -1,14 +0,0 @@ -! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s -! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s - -! CHECK: not yet implemented: Reduction modifiers are not supported - -subroutine foo() - integer :: i, j - j = 0 - !$omp do reduction (inscan, *: j) - do i = 1, 10 - !$omp scan inclusive(j) - j = j + 1 - end do -end subroutine diff --git a/flang/test/Lower/OpenMP/Todo/reduction-task.f90 b/flang/test/Lower/OpenMP/Todo/reduction-task.f90 index 6707f65e1a4cc31..b746872e9e7edf5 100644 --- a/flang/test/Lower/OpenMP/Todo/reduction-task.f90 +++ b/flang/test/Lower/OpenMP/Todo/reduction-task.f90 @@ -1,7 +1,7 @@ ! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s ! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s -! CHECK: not yet implemented: Reduction modifiers are not supported +! CHECK: not yet implemented: Reduction modifier task is not supported subroutine reduction_task() integer :: i i = 0 diff --git a/flang/test/Lower/OpenMP/scan.f90 b/flang/test/Lower/OpenMP/scan.f90 new file mode 100644 index 000000000000000..9cf2174a7f3314c --- /dev/null +++ b/flang/test/Lower/OpenMP/scan.f90 @@ -0,0 +1,34 @@ +!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +subroutine inclusive_scan + implicit none + integer, parameter :: n = 100 + integer a(n), b(n) + integer x, k + + !CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) { + !$omp parallel do reduction(inscan, +: x) + do k = 1, n + x = x + a(k) + !CHECK: omp.scan inclusive({{.*}}) + !$omp scan inclusive(x) + b(k) = x + end do +end subroutine inclusive_scan + + +subroutine exclusive_scan + implicit none + integer, parameter :: n = 100 + integer a(n), b(n) + integer x, k + + !CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) { + !$omp parallel do reduction(inscan, +: x) + do k = 1, n + x = x + a(k) + !CHECK: omp.scan exclusive({{.*}}) + !$omp scan exclusive(x) + b(k) = x + end do +end subroutine exclusive_scan diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 5d0003911bca879..056b7f989128a44 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality( target.addDynamicallyLegalOp< omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp, omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp, - omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp, + omp::MapInfoOp, omp::ScanOp, omp::OrderedOp, omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp, omp::YieldOp>([&](Operation *op) { return typeConverter.isLegal(op->getOperandTypes()) && @@ -264,6 +264,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter, RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, + RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion, RegionLessOpConversion,