Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang][openmp] Changes for invoking scan Op #123254

Merged
merged 3 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,20 @@ bool ClauseProcessor::processDistSchedule(
return false;
}

bool ClauseProcessor::processExclusive(
mlir::Location currentLocation,
mlir::omp::ExclusiveClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Exclusive>()) {
for (const Object &object : clause->v) {
const semantics::Symbol *symbol = object.sym();
mlir::Value symVal = converter.getSymbolAddress(*symbol);
result.exclusiveVars.push_back(symVal);
}
return true;
}
return false;
}

bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
Expand Down Expand Up @@ -380,6 +394,20 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
return false;
}

bool ClauseProcessor::processInclusive(
mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Inclusive>()) {
for (const Object &object : clause->v) {
const semantics::Symbol *symbol = object.sym();
mlir::Value symVal = converter.getSymbolAddress(*symbol);
result.inclusiveVars.push_back(symVal);
}
return true;
}
return false;
}

bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
Expand Down Expand Up @@ -1135,10 +1163,9 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(currentLocation, converter, clause,
reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms);

rp.processReductionArguments(
currentLocation, converter, clause, reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms, result.reductionMod);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
Expand Down
4 changes: 4 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class ClauseProcessor {
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processDistSchedule(lower::StatementContext &stmtCtx,
mlir::omp::DistScheduleClauseOps &result) const;
bool processExclusive(mlir::Location currentLocation,
mlir::omp::ExclusiveClauseOps &result) const;
bool processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const;
bool processFinal(lower::StatementContext &stmtCtx,
Expand All @@ -72,6 +74,8 @@ class ClauseProcessor {
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processInclusive(mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(lower::StatementContext &stmtCtx,
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Lower/OpenMP/Clauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,8 +728,8 @@ Enter make(const parser::OmpClause::Enter &inp,

Exclusive make(const parser::OmpClause::Exclusive &inp,
semantics::SemanticsContext &semaCtx) {
// inp -> empty
llvm_unreachable("Empty: exclusive");
// inp.v -> parser::OmpObjectList
return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}

Fail make(const parser::OmpClause::Fail &inp,
Expand Down Expand Up @@ -838,8 +838,8 @@ If make(const parser::OmpClause::If &inp,

Inclusive make(const parser::OmpClause::Inclusive &inp,
semantics::SemanticsContext &semaCtx) {
// inp -> empty
llvm_unreachable("Empty: inclusive");
// inp.v -> parser::OmpObjectList
return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}

Indirect make(const parser::OmpClause::Indirect &inp,
Expand Down
21 changes: 20 additions & 1 deletion flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,15 @@ static void genParallelClauses(
cp.processReduction(loc, clauseOps, reductionSyms);
}

static void genScanClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::ScanOperands &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processInclusive(loc, clauseOps);
cp.processExclusive(loc, clauseOps);
}

static void genSectionsClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
Expand Down Expand Up @@ -1975,6 +1984,16 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return parallelOp;
}

static mlir::omp::ScanOp
genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, mlir::Location loc,
const ConstructQueue &queue, ConstructQueue::const_iterator item) {
mlir::omp::ScanOperands clauseOps;
genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
return converter.getFirOpBuilder().create<mlir::omp::ScanOp>(
converter.getCurrentLocation(), clauseOps);
}

/// This breaks the normal prototype of the gen*Op functions: adding the
/// sectionBlocks argument so that the enclosed section constructs can be
/// lowered here with correct reduction symbol remapping.
Expand Down Expand Up @@ -2978,7 +2997,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_scan:
TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
genScanOp(converter, symTable, semaCtx, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_section:
llvm_unreachable("genOMPDispatch: OMPD_section");
Expand Down
31 changes: 26 additions & 5 deletions flang/lib/Lower/OpenMP/ReductionProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ static llvm::cl::opt<bool> forceByrefReduction(
llvm::cl::desc("Pass all reduction arguments by reference"),
llvm::cl::Hidden);

using ReductionModifier =
Fortran::lower::omp::clause::Reduction::ReductionModifier;

namespace Fortran {
namespace lower {
namespace omp {
Expand Down Expand Up @@ -514,18 +517,36 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}

void ReductionProcessor::addDeclareReduction(
mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
switch (mod) {
case ReductionModifier::Default:
return mlir::omp::ReductionModifier::defaultmod;
case ReductionModifier::Inscan:
return mlir::omp::ReductionModifier::inscan;
case ReductionModifier::Task:
return mlir::omp::ReductionModifier::task;
}
return mlir::omp::ReductionModifier::defaultmod;
}

void ReductionProcessor::processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr &reductionMod) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
reduction.t))
TODO(currentLocation, "Reduction modifiers are not supported");
auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
if (mod.has_value()) {
if (mod.value() == ReductionModifier::Task)
TODO(currentLocation, "Reduction modifier `task` is not supported");
else
reductionMod = mlir::omp::ReductionModifierAttr::get(
firOpBuilder.getContext(), translateReductionModifier(mod.value()));
}
tblah marked this conversation as resolved.
Show resolved Hide resolved

mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{
Expand Down
6 changes: 4 additions & 2 deletions flang/lib/Lower/OpenMP/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"

Expand Down Expand Up @@ -120,13 +121,14 @@ class ReductionProcessor {

/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
static void addDeclareReduction(
static void processReductionArguments(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr &reductionMod);
};

template <typename FloatOp, typename IntegerOp>
Expand Down
15 changes: 0 additions & 15 deletions flang/test/Lower/OpenMP/Todo/reduction-inscan.f90

This file was deleted.

14 changes: 0 additions & 14 deletions flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90

This file was deleted.

2 changes: 1 addition & 1 deletion flang/test/Lower/OpenMP/Todo/reduction-task.f90
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s

! CHECK: not yet implemented: Reduction modifiers are not supported
! CHECK: not yet implemented: Reduction modifier `task` is not supported
subroutine reduction_task()
integer :: i
i = 0
Expand Down
36 changes: 36 additions & 0 deletions flang/test/Lower/OpenMP/scan.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
! RUN: bbc -emit-hlfir -fopenmp %s -o - | FileCheck %s
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s

! CHECK: omp.wsloop reduction(mod: inscan, @add_reduction_i32 %{{.*}} -> %[[RED_ARG_1:.*]] : {{.*}}) {
! CHECK: %[[RED_DECL_1:.*]]:2 = hlfir.declare %[[RED_ARG_1]]
! CHECK: omp.scan inclusive(%[[RED_DECL_1]]#1 : {{.*}})

subroutine inclusive_scan(a, b, n)
implicit none
integer a(:), b(:)
integer x, k, n

!$omp parallel do reduction(inscan, +: x)
do k = 1, n
x = x + a(k)
!$omp scan inclusive(x)
b(k) = x
end do
end subroutine inclusive_scan


! CHECK: omp.wsloop reduction(mod: inscan, @add_reduction_i32 %{{.*}} -> %[[RED_ARG_2:.*]] : {{.*}}) {
! CHECK: %[[RED_DECL_2:.*]]:2 = hlfir.declare %[[RED_ARG_2]]
! CHECK: omp.scan exclusive(%[[RED_DECL_2]]#1 : {{.*}})
subroutine exclusive_scan(a, b, n)
implicit none
integer a(:), b(:)
integer x, k, n

!$omp parallel do reduction(inscan, +: x)
do k = 1, n
x = x + a(k)
!$omp scan exclusive(x)
b(k) = x
end do
end subroutine exclusive_scan
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality(
target.addDynamicallyLegalOp<
omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
omp::MapInfoOp, omp::OrderedOp, omp::ScanOp, omp::TargetEnterDataOp,
omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
omp::YieldOp>([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
Expand Down Expand Up @@ -264,6 +264,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
RegionLessOpConversion<omp::CancelOp>,
RegionLessOpConversion<omp::CriticalDeclareOp>,
RegionLessOpConversion<omp::OrderedOp>,
RegionLessOpConversion<omp::ScanOp>,
RegionLessOpConversion<omp::TargetEnterDataOp>,
RegionLessOpConversion<omp::TargetExitDataOp>,
RegionLessOpConversion<omp::TargetUpdateOp>,
Expand Down