Skip to content

Commit

Permalink
Decomposition updates for quake.control types
Browse files Browse the repository at this point in the history
  • Loading branch information
bmhowe23 committed Sep 3, 2024
1 parent 995e111 commit 2bb9369
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 19 deletions.
94 changes: 75 additions & 19 deletions lib/Optimizer/Transforms/DecompositionPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "cudaq/Optimizer/Dialect/CC/CCOps.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"

Expand Down Expand Up @@ -40,10 +41,12 @@ inline Value createDivF(Location loc, Value numerator, double denominator,
return rewriter.create<arith::DivFOp>(loc, numerator, denominatorValue);
}

/// @brief Returns true if \p op contains any `ControlType` operands.
inline bool containsControlTypes(quake::OperatorInterface op) {
return llvm::any_of(op.getControls(), [](const Value &v) {
return v.getType().isa<quake::ControlType>();
/// @brief Returns true if \p op contains any `ControlType` operands that are
/// used outside of the block where \p op resides.
inline bool containsControlsUsedOutsideBlock(quake::OperatorInterface op) {
return llvm::any_of(op.getControls(), [op](Value v) {
return v.getType().isa<quake::ControlType>() &&
v.isUsedOutsideOfBlock(op->getBlock());
});
}

Expand Down Expand Up @@ -82,6 +85,37 @@ class QuakeOperatorCreator {
quake::WireType::get(rewriter.getContext()));
}

/// @brief Promote all `quake.control` types in \p controls to `quake.wire`
/// types.
void promoteControls(Location loc, MutableArrayRef<Value> controls) {
origControls.assign(controls.begin(), controls.end());
for (auto &c : controls)
if (c.getType().isa<quake::ControlType>())
c = rewriter.create<quake::FromControlOp>(
loc, quake::WireType::get(rewriter.getContext()), c);
}

/// @brief Perform necessary conversion of `quake.wire` values to
/// `quake.control` values (complementing `promoteControls`). This also
/// replaces downstream uses of the original control with the new control in
/// case the decomposition modified the control.
void demoteWiresToControlsAndReplaceUses(Operation *op,
MutableArrayRef<Value> controls) {
auto ctrlTy = quake::ControlType::get(rewriter.getContext());
auto loc = op->getLoc();
DominanceInfo domInfo(op->getParentOfType<func::FuncOp>());
for (auto &&[oc, c] : llvm::zip_equal(origControls, controls))
if (oc.getType().isa<quake::ControlType>()) {
c = rewriter.create<quake::ToControlOp>(loc, ctrlTy, c);
// This is like oc.replaceAllUsesWith(c) except it checks for
// proper dominance before doing the replacement. This is important
// because the rewriter tends to work from the bottom up.
for (auto &use : llvm::make_early_inc_range(oc.getUses()))
if (domInfo.properlyDominates(c, use.getOwner()))
use.set(c);
}
}

/// Pluck out the values from \p newValues whose type is `WireType` and
/// replace all the \p op uses with those values.
void selectWiresAndReplaceUses(Operation *op, ValueRange newValues) {
Expand Down Expand Up @@ -224,6 +258,9 @@ class QuakeOperatorCreator {

private:
PatternRewriter &rewriter;
/// The original control values before some may have been promoted from
/// quake.control to quake.wire.
SmallVector<Value> origControls;
};

/// Check whether the operation has the correct number of controls.
Expand Down Expand Up @@ -661,11 +698,6 @@ struct CXToCZ : public OpRewritePattern<quake::XOp> {
PatternRewriter &rewriter) const override {
if (failed(checkNumControls(op, 1)))
return failure();
// This decomposition does not support `quake.control` types because the
// input controls are used as targets during this transformation.
if (containsControlTypes(op))
return failure();

// Op info
Location loc = op->getLoc();
Value target = op.getTarget();
Expand All @@ -675,7 +707,15 @@ struct CXToCZ : public OpRewritePattern<quake::XOp> {
if (negatedControls)
negControl = (*negatedControls)[0];

// TODO - Update this pattern to support threading modified controls
// throughout multiple blocks. qRewriter.demoteWiresToControlsAndReplaceUses
// does not currently handle that.
if (negControl && containsControlsUsedOutsideBlock(op))
return failure();

QuakeOperatorCreator qRewriter(rewriter);
if (negControl)
qRewriter.promoteControls(loc, controls);
qRewriter.create<quake::HOp>(loc, target);
if (negControl)
qRewriter.create<quake::XOp>(loc, controls);
Expand All @@ -684,6 +724,8 @@ struct CXToCZ : public OpRewritePattern<quake::XOp> {
qRewriter.create<quake::XOp>(loc, controls);
qRewriter.create<quake::HOp>(loc, target);

if (negControl)
qRewriter.demoteWiresToControlsAndReplaceUses(op, controls);
qRewriter.selectWiresAndReplaceUses(op, controls, target);
rewriter.eraseOp(op);
return success();
Expand Down Expand Up @@ -813,11 +855,6 @@ struct CCZToCX : public OpRewritePattern<quake::ZOp> {

LogicalResult matchAndRewrite(quake::ZOp op,
PatternRewriter &rewriter) const override {
// This decomposition does not support `quake.control` types because the
// input controls are used as targets during this transformation.
if (containsControlTypes(op))
return failure();

SmallVector<Value, 2> controls(2);
if (failed(checkAndExtractControls(op, controls, rewriter)))
return failure();
Expand All @@ -843,7 +880,14 @@ struct CCZToCX : public OpRewritePattern<quake::ZOp> {
}
}

// TODO - Update this pattern to support threading modified controls
// throughout multiple blocks. qRewriter.demoteWiresToControlsAndReplaceUses
// does not currently handle that.
if (containsControlsUsedOutsideBlock(op))
return failure();

QuakeOperatorCreator qRewriter(rewriter);
qRewriter.promoteControls(loc, controls);
qRewriter.create<quake::XOp>(loc, controls[1], target);
qRewriter.create<quake::TOp>(loc, /*isAdj=*/!negC0, target);
qRewriter.create<quake::XOp>(loc, controls[0], target);
Expand All @@ -860,6 +904,7 @@ struct CCZToCX : public OpRewritePattern<quake::ZOp> {

qRewriter.create<quake::TOp>(loc, /*isAdj=*/negC1, controls[0]);

qRewriter.demoteWiresToControlsAndReplaceUses(op, controls);
qRewriter.selectWiresAndReplaceUses(op, controls, target);
rewriter.eraseOp(op);
return success();
Expand All @@ -878,10 +923,6 @@ struct CZToCX : public OpRewritePattern<quake::ZOp> {

LogicalResult matchAndRewrite(quake::ZOp op,
PatternRewriter &rewriter) const override {
// This decomposition does not support `quake.control` types because the
// input controls are used as targets during this transformation.
if (containsControlTypes(op))
return failure();
if (failed(checkNumControls(op, 1)))
return failure();

Expand All @@ -894,7 +935,15 @@ struct CZToCX : public OpRewritePattern<quake::ZOp> {
if (negatedControls)
negControl = (*negatedControls)[0];

// TODO - Update this pattern to support threading modified controls
// throughout multiple blocks. qRewriter.demoteWiresToControlsAndReplaceUses
// does not currently handle that.
if (negControl && containsControlsUsedOutsideBlock(op))
return failure();

QuakeOperatorCreator qRewriter(rewriter);
if (negControl)
qRewriter.promoteControls(loc, controls);
qRewriter.create<quake::HOp>(loc, target);
if (negControl)
qRewriter.create<quake::XOp>(loc, controls);
Expand All @@ -903,6 +952,8 @@ struct CZToCX : public OpRewritePattern<quake::ZOp> {
qRewriter.create<quake::XOp>(loc, controls);
qRewriter.create<quake::HOp>(loc, target);

if (negControl)
qRewriter.demoteWiresToControlsAndReplaceUses(op, controls);
qRewriter.selectWiresAndReplaceUses(op, controls, target);
rewriter.eraseOp(op);
return success();
Expand Down Expand Up @@ -969,7 +1020,10 @@ struct CR1ToCX : public OpRewritePattern<quake::R1Op> {

LogicalResult matchAndRewrite(quake::R1Op op,
PatternRewriter &rewriter) const override {
if (containsControlTypes(op))
// TODO - Update this pattern to support threading modified controls
// throughout multiple blocks. qRewriter.demoteWiresToControlsAndReplaceUses
// does not currently handle that.
if (containsControlsUsedOutsideBlock(op))
return failure();

Value control;
Expand All @@ -994,6 +1048,7 @@ struct CR1ToCX : public OpRewritePattern<quake::R1Op> {
Value negHalfAngle = rewriter.create<arith::NegFOp>(loc, halfAngle);

QuakeOperatorCreator qRewriter(rewriter);
qRewriter.promoteControls(loc, control);
qRewriter.create<quake::R1Op>(loc, /*isAdj*/ negControl, halfAngle,
noControls, control);
qRewriter.create<quake::XOp>(loc, control, target);
Expand All @@ -1002,6 +1057,7 @@ struct CR1ToCX : public OpRewritePattern<quake::R1Op> {
qRewriter.create<quake::XOp>(loc, control, target);
qRewriter.create<quake::R1Op>(loc, halfAngle, noControls, target);

qRewriter.demoteWiresToControlsAndReplaceUses(op, control);
qRewriter.selectWiresAndReplaceUses(op, ValueRange{control, target});
rewriter.eraseOp(op);
return success();
Expand Down
1 change: 1 addition & 0 deletions test/Transforms/DecompositionPatterns/CCZToCX.qke
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// RUN: cudaq-opt -pass-pipeline='builtin.module(decomposition{enable-patterns=CCZToCX})' %s | FileCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(decomposition{enable-patterns=CCZToCX})' %s | CircuitCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CCZToCX})' %s | FileCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg,pruned-ctrl-form),decomposition{enable-patterns=CCZToCX})' %s | FileCheck %s

// Test the decomposition pattern with different control types. The FileCheck
// part of this test only cares about the sequence of operations. Correcteness
Expand Down
1 change: 1 addition & 0 deletions test/Transforms/DecompositionPatterns/CR1ToCX.qke
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// RUN: cudaq-opt -pass-pipeline='builtin.module(decomposition{enable-patterns=CR1ToCX})' %s | CircuitCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CR1ToCX})' %s | FileCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CR1ToCX})' %s | CircuitCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg,pruned-ctrl-form),decomposition{enable-patterns=CR1ToCX})' %s | FileCheck %s

// Test the decomposition pattern with different control types. The FileCheck
// part of this test only cares about the sequence of operations. Correcteness
Expand Down
1 change: 1 addition & 0 deletions test/Transforms/DecompositionPatterns/CXToCZ.qke
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// RUN: cudaq-opt -pass-pipeline='builtin.module(decomposition{enable-patterns=CXToCZ})' %s | CircuitCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CXToCZ})' %s | FileCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CXToCZ})' %s | CircuitCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg,pruned-ctrl-form),decomposition{enable-patterns=CXToCZ})' %s | FileCheck %s

// Test the decomposition pattern with different control types. The FileCheck
// part of this test only cares about the sequence of operations. Correcteness
Expand Down
1 change: 1 addition & 0 deletions test/Transforms/DecompositionPatterns/CZToCX.qke
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// RUN: cudaq-opt -pass-pipeline='builtin.module(decomposition{enable-patterns=CZToCX})' %s | CircuitCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CZToCX})' %s | FileCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg),decomposition{enable-patterns=CZToCX})' %s | CircuitCheck %s
// RUN: cudaq-opt -pass-pipeline='builtin.module(func.func(expand-control-veqs,memtoreg,pruned-ctrl-form),decomposition{enable-patterns=CZToCX})' %s | FileCheck %s

// Test the decomposition pattern with different control types. The FileCheck
// part of this test only cares about the sequence of operations. Correcteness
Expand Down

0 comments on commit 2bb9369

Please sign in to comment.