Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,11 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
llvm::SmallDenseSet<unsigned> operandPositionsToShadow;
llvm::SmallDenseSet<unsigned> resultPositionsToShadow;

// while these operands are inactive in the op region(s), we may still need to
// create placeholder shadows for them to ensure syntactic correctness for the
// IR
llvm::SmallDenseSet<unsigned> constOperandPositionsToShadow;

SmallVector<RegionSuccessor> entrySuccessors;
regionBranchOp.getEntrySuccessorRegions(
SmallVector<Attribute>(op->getNumOperands(), Attribute()),
Expand All @@ -352,8 +357,44 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
// operands.
for (auto &&[i, regionValue, operand] :
llvm::enumerate(targetValues, operandRange)) {
if (gutils->isConstantValue(regionValue))
continue;

// check if we need to create a shadow for an inactive region value
if (gutils->isConstantValue(regionValue)) {

// if all the possible predecessors for this value are also const, then
// we can skip creating a shadow. Else we need to create a shadow for
// syntactic correctness

SmallVector<Value> possibleActivePreds;
SmallVector<RegionBranchPoint> predecessors;
regionBranchOp.getPredecessors(successor, predecessors);
for (RegionBranchPoint predecessor : predecessors) {
if (predecessor.isParent()) {
// if the predecessor is the parent itself, then it's just
// operand!
possibleActivePreds.push_back(operand);
continue;
}
auto terminator = predecessor.getTerminatorPredecessorOrNull();
auto predecessorOperands = terminator.getSuccessorOperands(successor);
if (i < predecessorOperands.size())
possibleActivePreds.push_back(predecessorOperands[i]);
}

bool skipOpShadow = true;
for (auto pv : possibleActivePreds) {
if (!skipOpShadow)
break;
skipOpShadow = skipOpShadow && gutils->isConstantValue(pv);
};
// if there's any possible active predecessor, we create a shadow for it
if (!skipOpShadow)
constOperandPositionsToShadow.insert(
operandRange.getBeginOperandIndex() + i);
if (skipOpShadow)
continue;
}

operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i);
if (successor.isParent())
resultPositionsToShadow.insert(i);
Expand All @@ -365,13 +406,16 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
resultPositionsToShadow.insert(res.getResultNumber());

return controlFlowForwardHandler(
op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow);
op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow,
constOperandPositionsToShadow);
}

LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
Operation *op, OpBuilder &builder, MGradientUtils *gutils,
const llvm::SmallDenseSet<unsigned> &operandPositionsToShadow,
const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow) {
const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow,
const llvm::SmallDenseSet<unsigned> &constOperandPositionToShadow) {

// For all active results, add shadow types.
// For now, assuming all results are relevant.
Operation *newOp = gutils->getNewFromOriginal(op);
Expand Down Expand Up @@ -423,6 +467,74 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
replacementRegion.takeBody(region);
}

// Re-fix block args for all successor regions
// Even though createWithShadows properly creates the differentiated control
// flow op(accounting for any const args which might have shadows),
// takeBody(...) replaces the successor regions entirely, including the block
// arguments. We fix the block arguments here for the entry successor regions.
if (!constOperandPositionToShadow.empty()) {
auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op);
auto iface = dyn_cast<ControlFlowAutoDiffOpInterface>(op);

SmallVector<RegionSuccessor> entrySuccessors;
regionBranchOp.getEntrySuccessorRegions(
SmallVector<Attribute>(op->getNumOperands(), Attribute()),
entrySuccessors);

for (const RegionSuccessor &successor : entrySuccessors) {

if (successor.isParent())
continue;

OperandRange oldRegionOperands =
iface.getSuccessorOperands(regionBranchOp, successor);
ValueRange oldRegionInputs = regionBranchOp.getSuccessorInputs(successor);

// the new region corresponding to this successor(we want to modify the
// arguments of this region in-place)
auto &newRegion =
replacement->getRegion(successor.getSuccessor()->getRegionNumber());

for (int i = oldRegionInputs.size() - 1; i >= 0; --i) {
unsigned operandPosition = oldRegionOperands.getBeginOperandIndex() + i;
if (!constOperandPositionToShadow.contains(operandPosition))
continue;

auto oldRegionInput = dyn_cast<BlockArgument>(oldRegionInputs[i]);

if (!oldRegionInput ||
gutils->invertedPointers.contains(oldRegionInput))
continue;

auto newRegionInput =
cast<BlockArgument>(gutils->getNewFromOriginal(oldRegionInput));

auto typeIface =
dyn_cast<AutoDiffTypeInterface>(oldRegionInput.getType());

if (!typeIface) {
op->emitError() << " AutoDiffTypeInterface not implemented for "
<< oldRegionInput.getType() << "\n";
return failure();
}

Value newRegionShadow;
if (newRegionInput.getArgNumber() == newRegion.getNumArguments() - 1) {
newRegionShadow = newRegion.addArgument(
typeIface.getShadowType(gutils->width), newRegionInput.getLoc());
} else {
// insert at position i+1
newRegionShadow = newRegion.insertArgument(
newRegion.args_begin() + newRegionInput.getArgNumber() + 1,
typeIface.getShadowType(gutils->width), newRegionInput.getLoc());
}

// update the inverted pointer map
gutils->invertedPointers.map(oldRegionInput, newRegionShadow);
}
}
}

// Inject the mapping for the new results into GradientUtil's shadow
// table.
SmallVector<Value> reps;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder,
LogicalResult controlFlowForwardHandler(
Operation *op, OpBuilder &builder, MGradientUtils *gutils,
const llvm::SmallDenseSet<unsigned> &operandPositionsToShadow,
const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow);
const llvm::SmallDenseSet<unsigned> &resultPositionsToShadow,
const llvm::SmallDenseSet<unsigned> &constOperandPositionToShadow);

// Implements forward-mode differentiation of branching operations.
// Assumes that successive shadows are legal
Expand Down
51 changes: 51 additions & 0 deletions enzyme/test/MLIR/ForwardMode/for3.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// RUN: %eopt --enzyme %s | FileCheck %s

module {
func.func @carry_mismatch_scf(%x : f64) -> f64 {
%zero = arith.constant 0.0 : f64
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
%r = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %zero) -> (f64) {
scf.yield %x : f64
}
return %r : f64
}

func.func @dcarry_mismatch_scf(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @carry_mismatch_scf(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
return %r : f64
}

func.func @carry_mismatch_affine(%x : f64) -> f64 {
%zero = arith.constant 0.0 : f64
%r = affine.for %i = 0 to 10 iter_args(%acc = %zero) -> (f64) {
affine.yield %x : f64
}
return %r : f64
}

func.func @dcarry_mismatch_affine(%x : f64, %dx : f64) -> f64 {
%r = enzyme.fwddiff @carry_mismatch_affine(%x, %dx) { activity=[#enzyme<activity enzyme_dup>], ret_activity=[#enzyme<activity enzyme_dupnoneed>] } : (f64, f64) -> (f64)
return %r : f64
}
}

// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_scf(
// CHECK-DAG: %[[ZERO0:.+]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[ZERO1:.+]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK: %[[LOOP:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[DACC:.+]] = %[[ZERO1]], %[[ACC:.+]] = %[[ZERO0]]) -> (f64, f64) {
// CHECK-NEXT: scf.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64
// CHECK-NEXT: }
// CHECK-NEXT: return %[[LOOP]]#1 : f64

// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_affine(
// CHECK-DAG: %[[AZERO0:.+]] = arith.constant 0.000000e+00 : f64
// CHECK-DAG: %[[AZERO1:.+]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[ALOOP:.+]]:2 = affine.for %[[AIV:.+]] = 0 to 10 iter_args(%[[ADACC:.+]] = %[[AZERO1]], %[[AACC:.+]] = %[[AZERO0]]) -> (f64, f64) {
// CHECK-NEXT: affine.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64
// CHECK-NEXT: }
// CHECK-NEXT: return %[[ALOOP]]#1 : f64
Loading