Skip to content

Commit

Permalink
Changes for invoking scan Op
Browse files Browse the repository at this point in the history
  • Loading branch information
anchuraj committed Jan 22, 2025
1 parent afcbcae commit fccf0c6
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 45 deletions.
39 changes: 35 additions & 4 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,22 @@ bool ClauseProcessor::processDistSchedule(
return false;
}

bool ClauseProcessor::processExclusive(
mlir::Location currentLocation,
mlir::omp::ExclusiveClauseOps &result) const {
return findRepeatableClause<omp::clause::Exclusive>(
[&](const omp::clause::Exclusive &clause, const parser::CharBlock &) {
for (const Object &object : clause.v) {
semantics::Symbol *sym = object.sym();
mlir::Value symVal = converter.getSymbolAddress(*sym);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
symVal = declOp.getBase();
}
result.exclusiveVars.push_back(symVal);
}
});
}

bool ClauseProcessor::processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Filter>()) {
Expand Down Expand Up @@ -380,6 +396,22 @@ bool ClauseProcessor::processHint(mlir::omp::HintClauseOps &result) const {
return false;
}

bool ClauseProcessor::processInclusive(
mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const {
return findRepeatableClause<omp::clause::Inclusive>(
[&](const omp::clause::Inclusive &clause, const parser::CharBlock &) {
for (const Object &object : clause.v) {
const semantics::Symbol *symbol = object.sym();
mlir::Value symVal = converter.getSymbolAddress(*symbol);
// if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
// symVal = declOp.getBase();
// }
result.inclusiveVars.push_back(symVal);
}
});
}

bool ClauseProcessor::processMergeable(
mlir::omp::MergeableClauseOps &result) const {
return markClauseOccurrence<omp::clause::Mergeable>(result.mergeable);
Expand Down Expand Up @@ -1135,10 +1167,9 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(currentLocation, converter, clause,
reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms);

rp.addDeclareReduction(
currentLocation, converter, clause, reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms, &result.reductionMod);
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
Expand Down
4 changes: 4 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class ClauseProcessor {
bool processDeviceType(mlir::omp::DeviceTypeClauseOps &result) const;
bool processDistSchedule(lower::StatementContext &stmtCtx,
mlir::omp::DistScheduleClauseOps &result) const;
bool processExclusive(mlir::Location currentLocation,
mlir::omp::ExclusiveClauseOps &result) const;
bool processFilter(lower::StatementContext &stmtCtx,
mlir::omp::FilterClauseOps &result) const;
bool processFinal(lower::StatementContext &stmtCtx,
Expand All @@ -72,6 +74,8 @@ class ClauseProcessor {
mlir::omp::HasDeviceAddrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
bool processHint(mlir::omp::HintClauseOps &result) const;
bool processInclusive(mlir::Location currentLocation,
mlir::omp::InclusiveClauseOps &result) const;
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
bool processNumTeams(lower::StatementContext &stmtCtx,
Expand Down
8 changes: 4 additions & 4 deletions flang/lib/Lower/OpenMP/Clauses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -728,8 +728,8 @@ Enter make(const parser::OmpClause::Enter &inp,

Exclusive make(const parser::OmpClause::Exclusive &inp,
semantics::SemanticsContext &semaCtx) {
// inp -> empty
llvm_unreachable("Empty: exclusive");
// inp.v -> parser::OmpObjectList
return Exclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}

Fail make(const parser::OmpClause::Fail &inp,
Expand Down Expand Up @@ -838,8 +838,8 @@ If make(const parser::OmpClause::If &inp,

Inclusive make(const parser::OmpClause::Inclusive &inp,
semantics::SemanticsContext &semaCtx) {
// inp -> empty
llvm_unreachable("Empty: inclusive");
// inp.v -> parser::OmpObjectList
return Inclusive{makeObjects(/*List=*/inp.v, semaCtx)};
}

Indirect make(const parser::OmpClause::Indirect &inp,
Expand Down
22 changes: 21 additions & 1 deletion flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,15 @@ static void genParallelClauses(
cp.processReduction(loc, clauseOps, reductionSyms);
}

static void genScanClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::ScanOperands &clauseOps) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processInclusive(loc, clauseOps);
cp.processExclusive(loc, clauseOps);
}

static void genSectionsClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
Expand Down Expand Up @@ -1975,6 +1984,17 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return parallelOp;
}

static mlir::omp::ScanOp
genScanOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, mlir::Location loc,
const ConstructQueue &queue, ConstructQueue::const_iterator item) {

mlir::omp::ScanOperands clauseOps;
genScanClauses(converter, semaCtx, item->clauses, loc, clauseOps);
return converter.getFirOpBuilder().create<mlir::omp::ScanOp>(
converter.getCurrentLocation(), clauseOps);
}

/// This breaks the normal prototype of the gen*Op functions: adding the
/// sectionBlocks argument so that the enclosed section constructs can be
/// lowered here with correct reduction symbol remapping.
Expand Down Expand Up @@ -2978,7 +2998,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
genStandaloneParallel(converter, symTable, semaCtx, eval, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_scan:
TODO(loc, "Unhandled directive " + llvm::omp::getOpenMPDirectiveName(dir));
genScanOp(converter, symTable, semaCtx, loc, queue, item);
break;
case llvm::omp::Directive::OMPD_section:
llvm_unreachable("genOMPDispatch: OMPD_section");
Expand Down
32 changes: 28 additions & 4 deletions flang/lib/Lower/OpenMP/ReductionProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
#include <string>

static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
Expand Down Expand Up @@ -514,18 +515,36 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}

mlir::omp::ReductionModifier
translateReductionModifier(const ReductionModifier &m) {
switch (m) {
case ReductionModifier::Default:
return mlir::omp::ReductionModifier::defaultmod;
case ReductionModifier::Inscan:
return mlir::omp::ReductionModifier::inscan;
case ReductionModifier::Task:
return mlir::omp::ReductionModifier::task;
}
return mlir::omp::ReductionModifier::defaultmod;
}

void ReductionProcessor::addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr *reductionMod) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
reduction.t))
TODO(currentLocation, "Reduction modifiers are not supported");
auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
if (mod.has_value() && (mod.value() != ReductionModifier::Inscan)) {
std::string modStr = "default";
if (mod.value() == ReductionModifier::Task)
modStr = "task";
TODO(currentLocation, "Reduction modifier " + modStr + " is not supported");
}

mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{
Expand Down Expand Up @@ -649,6 +668,11 @@ void ReductionProcessor::addDeclareReduction(
currentLocation, isByRef);
reductionDeclSymbols.push_back(
mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
auto redMod = std::get<std::optional<ReductionModifier>>(reduction.t);
if (redMod.has_value())
*reductionMod = mlir::omp::ReductionModifierAttr::get(
firOpBuilder.getContext(),
translateReductionModifier(redMod.value()));
}
}

Expand Down
6 changes: 5 additions & 1 deletion flang/lib/Lower/OpenMP/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "flang/Parser/parse-tree.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/type.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"

Expand Down Expand Up @@ -126,7 +127,8 @@ class ReductionProcessor {
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
mlir::omp::ReductionModifierAttr *reductionMod);
};

template <typename FloatOp, typename IntegerOp>
Expand Down Expand Up @@ -156,6 +158,8 @@ ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
return builder.create<ComplexOp>(loc, op1, op2);
}

using ReductionModifier = omp::clause::Reduction::ReductionModifier;

} // namespace omp
} // namespace lower
} // namespace Fortran
Expand Down
15 changes: 0 additions & 15 deletions flang/test/Lower/OpenMP/Todo/reduction-inscan.f90

This file was deleted.

14 changes: 0 additions & 14 deletions flang/test/Lower/OpenMP/Todo/reduction-modifiers.f90

This file was deleted.

2 changes: 1 addition & 1 deletion flang/test/Lower/OpenMP/Todo/reduction-task.f90
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
! RUN: %not_todo_cmd bbc -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s
! RUN: %not_todo_cmd %flang_fc1 -emit-fir -fopenmp -o - %s 2>&1 | FileCheck %s

! CHECK: not yet implemented: Reduction modifiers are not supported
! CHECK: not yet implemented: Reduction modifier task is not supported
subroutine reduction_task()
integer :: i
i = 0
Expand Down
34 changes: 34 additions & 0 deletions flang/test/Lower/OpenMP/scan.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s

subroutine inclusive_scan
implicit none
integer, parameter :: n = 100
integer a(n), b(n)
integer x, k

!CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
!$omp parallel do reduction(inscan, +: x)
do k = 1, n
x = x + a(k)
!CHECK: omp.scan inclusive({{.*}})
!$omp scan inclusive(x)
b(k) = x
end do
end subroutine inclusive_scan


subroutine exclusive_scan
implicit none
integer, parameter :: n = 100
integer a(n), b(n)
integer x, k

!CHECK: omp.wsloop reduction(mod: inscan, {{.*}}) {
!$omp parallel do reduction(inscan, +: x)
do k = 1, n
x = x + a(k)
!CHECK: omp.scan exclusive({{.*}})
!$omp scan exclusive(x)
b(k) = x
end do
end subroutine exclusive_scan
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ void mlir::configureOpenMPToLLVMConversionLegality(
target.addDynamicallyLegalOp<
omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
omp::CancelOp, omp::CriticalDeclareOp, omp::FlushOp, omp::MapBoundsOp,
omp::MapInfoOp, omp::OrderedOp, omp::TargetEnterDataOp,
omp::MapInfoOp, omp::ScanOp, omp::OrderedOp, omp::TargetEnterDataOp,
omp::TargetExitDataOp, omp::TargetUpdateOp, omp::ThreadprivateOp,
omp::YieldOp>([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
Expand Down Expand Up @@ -264,6 +264,7 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
RegionLessOpConversion<omp::CancelOp>,
RegionLessOpConversion<omp::CriticalDeclareOp>,
RegionLessOpConversion<omp::OrderedOp>,
RegionLessOpConversion<omp::ScanOp>,
RegionLessOpConversion<omp::TargetEnterDataOp>,
RegionLessOpConversion<omp::TargetExitDataOp>,
RegionLessOpConversion<omp::TargetUpdateOp>,
Expand Down

0 comments on commit fccf0c6

Please sign in to comment.