From 4dfe67f64a32934a9f04907597bc3fd8af994b5f Mon Sep 17 00:00:00 2001 From: Min Xu Date: Mon, 23 Feb 2026 20:28:28 -0500 Subject: [PATCH] fix the nested predicate bug --- enzyme/Enzyme/CacheUtility.cpp | 8 ++- enzyme/Enzyme/GradientUtils.cpp | 103 ++++++++++++-------------------- 2 files changed, 46 insertions(+), 65 deletions(-) diff --git a/enzyme/Enzyme/CacheUtility.cpp b/enzyme/Enzyme/CacheUtility.cpp index caa065f9c233..67af390a64b3 100644 --- a/enzyme/Enzyme/CacheUtility.cpp +++ b/enzyme/Enzyme/CacheUtility.cpp @@ -838,7 +838,13 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T, alloc->setAlignment(Align(align)); } if (sublimits.size() == 0) { - auto val = getUndefinedValueForType(*newFunc->getParent(), types.back()); + // For i1 predicate caches, "not executed" must behave like false. + // Otherwise reverse CFG reconstruction may branch on undef/poison (issue + // #2629). + bool forceZero = + T->isIntegerTy() && cast(T)->getBitWidth() == 1; + auto val = getUndefinedValueForType(*newFunc->getParent(), types.back(), + /*forceZero*/ forceZero); if (!isa(val)) scopeInstructions[alloc].push_back(entryBuilder.CreateStore(val, alloc)); } diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 7a98e2fb02d1..8662eee4ec91 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -7868,7 +7868,12 @@ void GradientUtils::branchToCorrespondingTarget( } // llvm::errs() << "\n"; - if (targetToPreds.size() == 3) { + // NOTE: the 3-target "reuse two branch predicates" optimization constructs + // a synthetic staging block to evaluate the second predicate only under the + // first split. That requires actual control-flow. When replacePHIs != nullptr + // we must not use this fast-path (it would otherwise eagerly evaluate the + // inner predicate unconditionally and reproduce #2629 at -O0/-O1). + if (replacePHIs == nullptr && targetToPreds.size() == 3) { // Try `block` as a potential first split point. for (auto block : blocks) { { @@ -7986,73 +7991,43 @@ void GradientUtils::branchToCorrespondingTarget( // the remainder of foundTargets. auto cond1 = lookupM(bi1->getCondition(), BuilderM); - // Condition cond2 splits off the two blocks in - // (foundTargets-uniqueTargets) from each other. - auto cond2 = lookupM(bi2->getCondition(), BuilderM); + // Create a staging block so the second predicate is only evaluated + // on the path where the first split is taken (fixes #2629). + BasicBlock *staging = + BasicBlock::Create(oldFunc->getContext(), "staging", newFunc); + + // `lookupM` requires reverse-only blocks to have an entry in + // reverseBlockToPrimal. Since `staging` is synthetic, map it to the + // same forward/primal block as the current insertion block. + BasicBlock *stagingFwd = BuilderM.GetInsertBlock(); + if (!isOriginalBlock(*stagingFwd)) { + auto it = reverseBlockToPrimal.find(stagingFwd); + assert(it != reverseBlockToPrimal.end()); + stagingFwd = it->second; + } + reverseBlockToPrimal[staging] = stagingFwd; - if (replacePHIs == nullptr) { - BasicBlock *staging = - BasicBlock::Create(oldFunc->getContext(), "staging", newFunc); - auto stagingIfNeeded = [&](BasicBlock *B) { - auto edge = std::make_pair(block, B); - if (done[edge].size() == 1) { - return *done[edge].begin(); - } else { - assert(done[edge].size() == 2); - return staging; - } - }; - BuilderM.CreateCondBr(cond1, stagingIfNeeded(bi1->getSuccessor(0)), - stagingIfNeeded(bi1->getSuccessor(1))); - BuilderM.SetInsertPoint(staging); - BuilderM.CreateCondBr( - cond2, - *done[std::make_pair(subblock, bi2->getSuccessor(0))].begin(), - *done[std::make_pair(subblock, bi2->getSuccessor(1))].begin()); - } else { - Value *otherBranch = nullptr; - for (unsigned i = 0; i < 2; ++i) { - Value *val = cond1; - if (i == 1) - val = BuilderM.CreateNot(val, "anot1_"); - auto edge = std::make_pair(block, bi1->getSuccessor(i)); - if (done[edge].size() == 1) { - auto found = replacePHIs->find(*done[edge].begin()); - if (found == replacePHIs->end()) - continue; - if (&*BuilderM.GetInsertPoint() == found->second) { - if (found->second->getNextNode()) - BuilderM.SetInsertPoint(found->second->getNextNode()); - else - BuilderM.SetInsertPoint(found->second->getParent()); - } - found->second->replaceAllUsesWith(val); - found->second->eraseFromParent(); - } else { - otherBranch = val; - } + auto stagingIfNeeded = [&](BasicBlock *B) { + auto edge = std::make_pair(block, B); + if (done[edge].size() == 1) { + return *done[edge].begin(); + } else { + assert(done[edge].size() == 2); + return staging; } + }; - for (unsigned i = 0; i < 2; ++i) { - auto edge = std::make_pair(subblock, bi2->getSuccessor(i)); - auto found = replacePHIs->find(*done[edge].begin()); - if (found == replacePHIs->end()) - continue; + BuilderM.CreateCondBr(cond1, stagingIfNeeded(bi1->getSuccessor(0)), + stagingIfNeeded(bi1->getSuccessor(1))); + BuilderM.SetInsertPoint(staging); - Value *val = cond2; - if (i == 1) - val = BuilderM.CreateNot(val, "bnot1_"); - val = BuilderM.CreateAnd(val, otherBranch, "andVal" + Twine(i)); - if (&*BuilderM.GetInsertPoint() == found->second) { - if (found->second->getNextNode()) - BuilderM.SetInsertPoint(found->second->getNextNode()); - else - BuilderM.SetInsertPoint(found->second->getParent()); - } - found->second->replaceAllUsesWith(val); - found->second->eraseFromParent(); - } - } + // IMPORTANT: materialize cond2 *in staging* (so it is not executed + // when the outer guard path wasn't taken). + auto cond2 = lookupM(bi2->getCondition(), BuilderM); + BuilderM.CreateCondBr( + cond2, + *done[std::make_pair(subblock, bi2->getSuccessor(0))].begin(), + *done[std::make_pair(subblock, bi2->getSuccessor(1))].begin()); return; }