Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion enzyme/Enzyme/CacheUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,13 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
alloc->setAlignment(Align(align));
}
if (sublimits.size() == 0) {
auto val = getUndefinedValueForType(*newFunc->getParent(), types.back());
// For i1 predicate caches, "not executed" must behave like false.
// Otherwise reverse CFG reconstruction may branch on undef/poison (issue
// #2629).
bool forceZero =
T->isIntegerTy() && cast<IntegerType>(T)->getBitWidth() == 1;
auto val = getUndefinedValueForType(*newFunc->getParent(), types.back(),
/*forceZero*/ forceZero);
if (!isa<UndefValue>(val))
scopeInstructions[alloc].push_back(entryBuilder.CreateStore(val, alloc));
}
Expand Down
103 changes: 39 additions & 64 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7868,7 +7868,12 @@ void GradientUtils::branchToCorrespondingTarget(
}
// llvm::errs() << "</DONE>\n";

if (targetToPreds.size() == 3) {
// NOTE: the 3-target "reuse two branch predicates" optimization constructs
// a synthetic staging block to evaluate the second predicate only under the
// first split. That requires actual control-flow. When replacePHIs != nullptr
// we must not use this fast-path (it would otherwise eagerly evaluate the
// inner predicate unconditionally and reproduce #2629 at -O0/-O1).
if (replacePHIs == nullptr && targetToPreds.size() == 3) {
// Try `block` as a potential first split point.
for (auto block : blocks) {
{
Expand Down Expand Up @@ -7986,73 +7991,43 @@ void GradientUtils::branchToCorrespondingTarget(
// the remainder of foundTargets.
auto cond1 = lookupM(bi1->getCondition(), BuilderM);

// Condition cond2 splits off the two blocks in
// (foundTargets-uniqueTargets) from each other.
auto cond2 = lookupM(bi2->getCondition(), BuilderM);
// Create a staging block so the second predicate is only evaluated
// on the path where the first split is taken (fixes #2629).
BasicBlock *staging =
BasicBlock::Create(oldFunc->getContext(), "staging", newFunc);

// `lookupM` requires reverse-only blocks to have an entry in
// reverseBlockToPrimal. Since `staging` is synthetic, map it to the
// same forward/primal block as the current insertion block.
BasicBlock *stagingFwd = BuilderM.GetInsertBlock();
if (!isOriginalBlock(*stagingFwd)) {
auto it = reverseBlockToPrimal.find(stagingFwd);
assert(it != reverseBlockToPrimal.end());
stagingFwd = it->second;
}
reverseBlockToPrimal[staging] = stagingFwd;

if (replacePHIs == nullptr) {
BasicBlock *staging =
BasicBlock::Create(oldFunc->getContext(), "staging", newFunc);
auto stagingIfNeeded = [&](BasicBlock *B) {
auto edge = std::make_pair(block, B);
if (done[edge].size() == 1) {
return *done[edge].begin();
} else {
assert(done[edge].size() == 2);
return staging;
}
};
BuilderM.CreateCondBr(cond1, stagingIfNeeded(bi1->getSuccessor(0)),
stagingIfNeeded(bi1->getSuccessor(1)));
BuilderM.SetInsertPoint(staging);
BuilderM.CreateCondBr(
cond2,
*done[std::make_pair(subblock, bi2->getSuccessor(0))].begin(),
*done[std::make_pair(subblock, bi2->getSuccessor(1))].begin());
} else {
Value *otherBranch = nullptr;
for (unsigned i = 0; i < 2; ++i) {
Value *val = cond1;
if (i == 1)
val = BuilderM.CreateNot(val, "anot1_");
auto edge = std::make_pair(block, bi1->getSuccessor(i));
if (done[edge].size() == 1) {
auto found = replacePHIs->find(*done[edge].begin());
if (found == replacePHIs->end())
continue;
if (&*BuilderM.GetInsertPoint() == found->second) {
if (found->second->getNextNode())
BuilderM.SetInsertPoint(found->second->getNextNode());
else
BuilderM.SetInsertPoint(found->second->getParent());
}
found->second->replaceAllUsesWith(val);
found->second->eraseFromParent();
} else {
otherBranch = val;
}
auto stagingIfNeeded = [&](BasicBlock *B) {
auto edge = std::make_pair(block, B);
if (done[edge].size() == 1) {
return *done[edge].begin();
} else {
assert(done[edge].size() == 2);
return staging;
}
};

for (unsigned i = 0; i < 2; ++i) {
auto edge = std::make_pair(subblock, bi2->getSuccessor(i));
auto found = replacePHIs->find(*done[edge].begin());
if (found == replacePHIs->end())
continue;
BuilderM.CreateCondBr(cond1, stagingIfNeeded(bi1->getSuccessor(0)),
stagingIfNeeded(bi1->getSuccessor(1)));
BuilderM.SetInsertPoint(staging);

Value *val = cond2;
if (i == 1)
val = BuilderM.CreateNot(val, "bnot1_");
val = BuilderM.CreateAnd(val, otherBranch, "andVal" + Twine(i));
if (&*BuilderM.GetInsertPoint() == found->second) {
if (found->second->getNextNode())
BuilderM.SetInsertPoint(found->second->getNextNode());
else
BuilderM.SetInsertPoint(found->second->getParent());
}
found->second->replaceAllUsesWith(val);
found->second->eraseFromParent();
}
}
// IMPORTANT: materialize cond2 *in staging* (so it is not executed
// when the outer guard path wasn't taken).
auto cond2 = lookupM(bi2->getCondition(), BuilderM);
BuilderM.CreateCondBr(
cond2,
*done[std::make_pair(subblock, bi2->getSuccessor(0))].begin(),
*done[std::make_pair(subblock, bi2->getSuccessor(1))].begin());

return;
}
Expand Down