diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index f88d212e8c76..a1b046f0e8ab 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -334,6 +334,11 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( llvm::SmallDenseSet operandPositionsToShadow; llvm::SmallDenseSet resultPositionsToShadow; + // while these operands are inactive in the op region(s), we may still need to + // create placeholder shadows for them to ensure syntactic correctness for the + // IR + llvm::SmallDenseSet constOperandPositionsToShadow; + SmallVector entrySuccessors; regionBranchOp.getEntrySuccessorRegions( SmallVector(op->getNumOperands(), Attribute()), @@ -352,8 +357,44 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( // operands. for (auto &&[i, regionValue, operand] : llvm::enumerate(targetValues, operandRange)) { - if (gutils->isConstantValue(regionValue)) - continue; + + // check if we need to create a shadow for an inactive region value + if (gutils->isConstantValue(regionValue)) { + + // if all the possible predecessors for this value are also const, then + // we can skip creating a shadow. Else we need to create a shadow for + // syntactic correctness + + SmallVector possibleActivePreds; + SmallVector predecessors; + regionBranchOp.getPredecessors(successor, predecessors); + for (RegionBranchPoint predecessor : predecessors) { + if (predecessor.isParent()) { + // if the predecessor is the parent itself, then it's just + // operand! + possibleActivePreds.push_back(operand); + continue; + } + auto terminator = predecessor.getTerminatorPredecessorOrNull(); + auto predecessorOperands = terminator.getSuccessorOperands(successor); + if (i < predecessorOperands.size()) + possibleActivePreds.push_back(predecessorOperands[i]); + } + + bool skipOpShadow = true; + for (auto pv : possibleActivePreds) { + if (!skipOpShadow) + break; + skipOpShadow = skipOpShadow && gutils->isConstantValue(pv); + }; + // if there's any possible active predecessor, we create a shadow for it + if (!skipOpShadow) + constOperandPositionsToShadow.insert( + operandRange.getBeginOperandIndex() + i); + if (skipOpShadow) + continue; + } + operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i); if (successor.isParent()) resultPositionsToShadow.insert(i); @@ -365,13 +406,16 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( resultPositionsToShadow.insert(res.getResultNumber()); return controlFlowForwardHandler( - op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow); + op, builder, gutils, operandPositionsToShadow, resultPositionsToShadow, + constOperandPositionsToShadow); } LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( Operation *op, OpBuilder &builder, MGradientUtils *gutils, const llvm::SmallDenseSet &operandPositionsToShadow, - const llvm::SmallDenseSet &resultPositionsToShadow) { + const llvm::SmallDenseSet &resultPositionsToShadow, + const llvm::SmallDenseSet &constOperandPositionToShadow) { + // For all active results, add shadow types. // For now, assuming all results are relevant. Operation *newOp = gutils->getNewFromOriginal(op); @@ -423,6 +467,74 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( replacementRegion.takeBody(region); } + // Re-fix block args for all successor regions + // Even though createWithShadows properly creates the differentiated control + // flow op(accounting for any const args which might have shadows), + // takeBody(...) replaces the successor regions entirely, including the block + // arguments. We fix the block arguments here for the entry successor regions. + if (!constOperandPositionToShadow.empty()) { + auto regionBranchOp = dyn_cast(op); + auto iface = dyn_cast(op); + + SmallVector entrySuccessors; + regionBranchOp.getEntrySuccessorRegions( + SmallVector(op->getNumOperands(), Attribute()), + entrySuccessors); + + for (const RegionSuccessor &successor : entrySuccessors) { + + if (successor.isParent()) + continue; + + OperandRange oldRegionOperands = + iface.getSuccessorOperands(regionBranchOp, successor); + ValueRange oldRegionInputs = regionBranchOp.getSuccessorInputs(successor); + + // the new region corresponding to this successor(we want to modify the + // arguments of this region in-place) + auto &newRegion = + replacement->getRegion(successor.getSuccessor()->getRegionNumber()); + + for (int i = oldRegionInputs.size() - 1; i >= 0; --i) { + unsigned operandPosition = oldRegionOperands.getBeginOperandIndex() + i; + if (!constOperandPositionToShadow.contains(operandPosition)) + continue; + + auto oldRegionInput = dyn_cast(oldRegionInputs[i]); + + if (!oldRegionInput || + gutils->invertedPointers.contains(oldRegionInput)) + continue; + + auto newRegionInput = + cast(gutils->getNewFromOriginal(oldRegionInput)); + + auto typeIface = + dyn_cast(oldRegionInput.getType()); + + if (!typeIface) { + op->emitError() << " AutoDiffTypeInterface not implemented for " + << oldRegionInput.getType() << "\n"; + return failure(); + } + + Value newRegionShadow; + if (newRegionInput.getArgNumber() == newRegion.getNumArguments() - 1) { + newRegionShadow = newRegion.addArgument( + typeIface.getShadowType(gutils->width), newRegionInput.getLoc()); + } else { + // insert at position i+1 + newRegionShadow = newRegion.insertArgument( + newRegion.args_begin() + newRegionInput.getArgNumber() + 1, + typeIface.getShadowType(gutils->width), newRegionInput.getLoc()); + } + + // update the inverted pointer map + gutils->invertedPointers.map(oldRegionInput, newRegionShadow); + } + } + } + // Inject the mapping for the new results into GradientUtil's shadow // table. SmallVector reps; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index dabce9ecd520..b4df88118bd2 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -40,7 +40,8 @@ LogicalResult controlFlowForwardHandler(Operation *op, OpBuilder &builder, LogicalResult controlFlowForwardHandler( Operation *op, OpBuilder &builder, MGradientUtils *gutils, const llvm::SmallDenseSet &operandPositionsToShadow, - const llvm::SmallDenseSet &resultPositionsToShadow); + const llvm::SmallDenseSet &resultPositionsToShadow, + const llvm::SmallDenseSet &constOperandPositionToShadow); // Implements forward-mode differentiation of branching operations. // Assumes that successive shadows are legal diff --git a/enzyme/test/MLIR/ForwardMode/for3.mlir b/enzyme/test/MLIR/ForwardMode/for3.mlir new file mode 100644 index 000000000000..f67775a80c4c --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/for3.mlir @@ -0,0 +1,51 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @carry_mismatch_scf(%x : f64) -> f64 { + %zero = arith.constant 0.0 : f64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %r = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %zero) -> (f64) { + scf.yield %x : f64 + } + return %r : f64 + } + + func.func @dcarry_mismatch_scf(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @carry_mismatch_scf(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } + + func.func @carry_mismatch_affine(%x : f64) -> f64 { + %zero = arith.constant 0.0 : f64 + %r = affine.for %i = 0 to 10 iter_args(%acc = %zero) -> (f64) { + affine.yield %x : f64 + } + return %r : f64 + } + + func.func @dcarry_mismatch_affine(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @carry_mismatch_affine(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } +} + +// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_scf( +// CHECK-DAG: %[[ZERO0:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[ZERO1:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index +// CHECK: %[[LOOP:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[C10]] step %[[C1]] iter_args(%[[DACC:.+]] = %[[ZERO1]], %[[ACC:.+]] = %[[ZERO0]]) -> (f64, f64) { +// CHECK-NEXT: scf.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[LOOP]]#1 : f64 + +// CHECK-LABEL: func.func private @fwddiffecarry_mismatch_affine( +// CHECK-DAG: %[[AZERO0:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[AZERO1:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK: %[[ALOOP:.+]]:2 = affine.for %[[AIV:.+]] = 0 to 10 iter_args(%[[ADACC:.+]] = %[[AZERO1]], %[[AACC:.+]] = %[[AZERO0]]) -> (f64, f64) { +// CHECK-NEXT: affine.yield %[[ARG0:.+]], %[[ARG1:.+]] : f64, f64 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[ALOOP]]#1 : f64