Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][openmp] Changes for invoking scan Op #123254

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,19 @@ bool ClauseProcessor::processDistSchedule(
return false;
}

bool ClauseProcessor::processExclusive(
mlir::Location currentLocation,
mlir::omp::ExclusiveClauseOps &result) const {
return findRepeatableClause<omp::clause::Exclusive>(
skatrak marked this conversation as resolved.
Show resolved Hide resolved
[&](const omp::clause::Exclusive &clause, const parser::CharBlock &) {
for (const Object &object : clause.v) {
const semantics::Symbol *symbol = object.sym();
mlir::Value symVal = converter.getSymbolAddress(*symbol);
result.exclusiveVars.push_back(symVal);
}
});
}

bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
Expand Down Expand Up @@ -380,6 +393,19 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
return false;
}

bool ClauseProcessor::processInclusive(
mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const {
return findRepeatableClause<omp::clause::Inclusive>(
[&](const omp::clause::Inclusive &clause, const parser::CharBlock &) {
for (const Object &object : clause.v) {
const semantics::Symbol *symbol = object.sym();
mlir::Value symVal = converter.getSymbolAddress(*symbol);
result.inclusiveVars.push_back(symVal);
}
});
}

bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
Expand Down Expand Up @@ -1135,10 +1161,9 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(currentLocation, converter, clause,
reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms);

rp.addDeclareReduction(
currentLocation, converter, clause, reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms, &result.reductionMod);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
Expand Down
4 changes: 4 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class ClauseProcessor {
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processDistSchedule(lower::StatementContext &stmtCtx,
mlir::omp::DistScheduleClauseOps &result) const;
bool processExclusive(mlir::Location currentLocation,
mlir::omp::ExclusiveClauseOps &result) const;
bool processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const;
bool processFinal(lower::StatementContext &stmtCtx,
Expand All @@ -72,6 +74,8 @@ class ClauseProcessor {
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processInclusive(mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(lower::StatementContext &stmtCtx,
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Lower/OpenMP/Clauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,8 +728,8 @@ Enter make(const parser::OmpClause::Enter &inp,

Exclusive make(const parser::OmpClause::Exclusive &inp,
semantics::SemanticsContext &semaCtx) {
// inp -> empty
llvm_unreachable("Empty: exclusive");
// inp.v -> parser::OmpObjectList
return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}

Fail make(const parser::OmpClause::Fail &inp,
Expand Down Expand Up @@ -838,8 +838,8 @@ If make(const parser::OmpClause::If &inp,

Inclusive make(const parser::OmpClause::Inclusive &inp,
semantics::SemanticsContext &semaCtx) {
// inp -> empty
llvm_unreachable("Empty: inclusive");
// inp.v -> parser::OmpObjectList
return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}

Indirect make(const parser::OmpClause::Indirect &inp,
Expand Down
21 changes: 20 additions & 1 deletion flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,15 @@ static void genParallelClauses(
cp.processReduction(loc, clauseOps, reductionSyms);
}

static void genScanClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::ScanOperands &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processInclusive(loc, clauseOps);
cp.processExclusive(loc, clauseOps);
}

static void genSectionsClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
Expand Down Expand Up @@ -1975,6 +1984,16 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return parallelOp;
}

static mlir::omp::ScanOp
genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, mlir::Location loc,
const ConstructQueue &queue, ConstructQueue::const_iterator item) {
mlir::omp::ScanOperands clauseOps;
genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
return converter.getFirOpBuilder().create<mlir::omp::ScanOp>(
converter.getCurrentLocation(), clauseOps);
}

/// This breaks the normal prototype of the gen*Op functions: adding the
/// sectionBlocks argument so that the enclosed section constructs can be
/// lowered here with correct reduction symbol remapping.
Expand Down Expand Up @@ -2978,7 +2997,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_scan:
TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
genScanOp(converter, symTable, semaCtx, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_section:
llvm_unreachable("genOMPDispatch: OMPD_section");
Expand Down
32 changes: 28 additions & 4 deletions flang/lib/Lower/OpenMP/ReductionProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
#include <string>

static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
Expand Down Expand Up @@ -514,18 +515,36 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}

mlir::omp::ReductionModifier
translateReductionModifier(const ReductionModifier &m) {
skatrak marked this conversation as resolved.
Show resolved Hide resolved
switch (m) {
case ReductionModifier::Default:
return mlir::omp::ReductionModifier::defaultmod;
case ReductionModifier::Inscan:
return mlir::omp::ReductionModifier::inscan;
case ReductionModifier::Task:
return mlir::omp::ReductionModifier::task;
}
return mlir::omp::ReductionModifier::defaultmod;
}

void ReductionProcessor::addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr *reductionMod) {
skatrak marked this conversation as resolved.
Show resolved Hide resolved
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
reduction.t))
TODO(currentLocation, "Reduction modifiers are not supported");
auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
if (mod.has_value() && (mod.value() != ReductionModifier::Inscan)) {
std::string modStr = "default";
if (mod.value() == ReductionModifier::Task)
modStr = "task";
TODO(currentLocation, "Reduction modifier " + modStr + " is not supported");
skatrak marked this conversation as resolved.
Show resolved Hide resolved
}
skatrak marked this conversation as resolved.
Show resolved Hide resolved

mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{
Expand Down Expand Up @@ -649,6 +668,11 @@ void ReductionProcessor::addDeclareReduction(
currentLocation, isByRef);
reductionDeclSymbols.push_back(
mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
auto redMod = std::get<std::optional<ReductionModifier>>(reduction.t);
if (redMod.has_value())
*reductionMod = mlir::omp::ReductionModifierAttr::get(
firOpBuilder.getContext(),
translateReductionModifier(redMod.value()));
skatrak marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
6 changes: 5 additions & 1 deletion flang/lib/Lower/OpenMP/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"

Expand Down Expand Up @@ -126,7 +127,8 @@ class ReductionProcessor {
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr *reductionMod);
};

template <typename FloatOp, typename IntegerOp>
Expand Down Expand Up @@ -156,6 +158,8 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
return builder.create<ComplexOp>(loc, op1, op2);
}

using ReductionModifier = omp::clause::Reduction::ReductionModifier;
skatrak marked this conversation as resolved.
Show resolved Hide resolved

} // namespace omp
} // namespace lower
} // namespace Fortran
Expand Down
15 changes: 0 additions & 15 deletions flang/test/Lower/OpenMP/Todo/reduction-inscan.f90

This file was deleted.

14 changes: 0 additions & 14 deletions flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90

This file was deleted.

2 changes: 1 addition & 1 deletion flang/test/Lower/OpenMP/Todo/reduction-task.f90
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s

! CHECK: not yet implemented: Reduction modifiers are not supported
! CHECK: not yet implemented: Reduction modifier task is not supported
subroutine reduction_task()
integer :: i
i = 0
Expand Down
34 changes: 34 additions & 0 deletions flang/test/Lower/OpenMP/scan.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s

subroutine inclusive_scan
implicit none
integer, parameter :: n = 100
integer a(n), b(n)
integer x, k

!CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be good to check that the block argument defined by the omp.wsloop operation for the reduction variable is what's eventually passed to omp.scan. Something like:

! CHECK: omp.wsloop reduction(mod: inscan, %{{.*}} -> %[[RED_ARG:.*]] : {{.*}}) {
! CHECK: %[[RED_DECL:.*]]:2 = hlfir.declare %[[RED_ARG]]
! CHECK: omp.scan inclusive(%[[RED_DECL]]#0 : {{.*}})

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to use the autogenerated checks so these gets checked

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd personally prefer not doing that, because it makes the test easily break for unrelated changes in representation and lowering. But feel free to leave it in if you prefer, I don't feel that strongly about it and there are a few other tests like this in the same directory.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand the issue you are pointing to. I have updated the testcase.

!$omp parallel do reduction(inscan, +: x)
do k = 1, n
x = x + a(k)
!CHECK: omp.scan inclusive({{.*}})
!$omp scan inclusive(x)
b(k) = x
end do
end subroutine inclusive_scan


subroutine exclusive_scan
implicit none
integer, parameter :: n = 100
integer a(n), b(n)
integer x, k

!CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
!$omp parallel do reduction(inscan, +: x)
do k = 1, n
x = x + a(k)
!CHECK: omp.scan exclusive({{.*}})
!$omp scan exclusive(x)
b(k) = x
end do
end subroutine exclusive_scan
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality(
target.addDynamicallyLegalOp<
omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
omp::MapInfoOp, omp::OrderedOp, omp::ScanOp, omp::TargetEnterDataOp,
omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
omp::YieldOp>([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
Expand Down Expand Up @@ -264,6 +264,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
RegionLessOpConversion<omp::CancelOp>,
RegionLessOpConversion<omp::CriticalDeclareOp>,
RegionLessOpConversion<omp::OrderedOp>,
RegionLessOpConversion<omp::ScanOp>,
RegionLessOpConversion<omp::TargetEnterDataOp>,
RegionLessOpConversion<omp::TargetExitDataOp>,
RegionLessOpConversion<omp::TargetUpdateOp>,
Expand Down
Loading