From 3c7a847328c443b0f5df733bc38da9d7c147fe91 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Mar 2026 14:28:05 +0000 Subject: [PATCH 1/4] Fix custom forward pass where primal return wasn't needed --- .../Enzyme/ForwardMode/customfwd_double.ll | 37 +++++++++++++++++++ .../test/Enzyme/ForwardMode/customfwd_int.ll | 37 +++++++++++++++++++ 2 files changed, 74 insertions(+) create mode 100644 enzyme/test/Enzyme/ForwardMode/customfwd_double.ll create mode 100644 enzyme/test/Enzyme/ForwardMode/customfwd_int.ll diff --git a/enzyme/test/Enzyme/ForwardMode/customfwd_double.ll b/enzyme/test/Enzyme/ForwardMode/customfwd_double.ll new file mode 100644 index 000000000000..4d4df58892fd --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/customfwd_double.ll @@ -0,0 +1,37 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -preserve-nvvm -enzyme -enzyme-preopt=false -early-cse -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="preserve-nvvm,enzyme,function(early-cse)" -enzyme-preopt=false -S | FileCheck %s + +source_filename = "customfwd_double.c" +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__enzyme_register_derivative_square = dso_local local_unnamed_addr global [2 x i8*] [i8* bitcast (double (double)* @square to i8*), i8* bitcast (double (double, double)* @derivative_square to i8*)], align 16 + +; Function Attrs: norecurse nounwind readnone uwtable willreturn +define double @square(double %x) #0 { +entry: + %mul = fmul double %x, %x + ret double %mul +} + +define double @derivative_square(double %x, double %dx) #0 { +entry: + ret double 100.000000e+00 +} + +; Function Attrs: nounwind uwtable +define double @caller(double %x, double %dx) { +entry: + %call = call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double)* @square to i8*), metadata !"enzyme_dup", double %x, double %dx) + ret double %call +} + +declare dso_local double @__enzyme_fwddiff(i8*, ...) + +attributes #0 = { norecurse nounwind readnone } + +; CHECK: define internal double @fwddiffesquare(double %x, double %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call double @derivative_square(double %x, double %"x'") +; CHECK-NEXT: ret double %0 +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/ForwardMode/customfwd_int.ll b/enzyme/test/Enzyme/ForwardMode/customfwd_int.ll new file mode 100644 index 000000000000..e1eed4f9732c --- /dev/null +++ b/enzyme/test/Enzyme/ForwardMode/customfwd_int.ll @@ -0,0 +1,37 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -preserve-nvvm -enzyme -enzyme-preopt=false -early-cse -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="preserve-nvvm,enzyme,function(early-cse)" -enzyme-preopt=false -S | FileCheck %s + +source_filename = "customfwd_int.c" +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__enzyme_register_derivative_square = dso_local local_unnamed_addr global [2 x i8*] [i8* bitcast (i32 (i32)* @square to i8*), i8* bitcast (i32 (i32, i32)* @derivative_square to i8*)], align 16 + +; Function Attrs: norecurse nounwind readnone uwtable willreturn +define i32 @square(i32 %x) #0 { +entry: + %mul = mul i32 %x, %x + ret i32 %mul +} + +define i32 @derivative_square(i32 %x, i32 %dx) #0 { +entry: + ret i32 100 +} + +; Function Attrs: nounwind uwtable +define i32 @caller(i32 %x, i32 %dx) { +entry: + %call = call i32 (i8*, ...) @__enzyme_fwddiff(i8* bitcast (i32 (i32)* @square to i8*), metadata !"enzyme_dup", i32 %x, i32 %dx) + ret i32 %call +} + +declare dso_local i32 @__enzyme_fwddiff(i8*, ...) + +attributes #0 = { norecurse nounwind readnone } + +; CHECK: define internal i32 @fwddiffesquare(i32 %x, i32 %"x'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call i32 @derivative_square(i32 %x, i32 %"x'") +; CHECK-NEXT: ret i32 %0 +; CHECK-NEXT: } From e785abe03062a1c433f43d212f52839cc6436453 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Thu, 12 Mar 2026 15:49:26 +0000 Subject: [PATCH 2/4] Fix EnzymeLogic.cpp CreateForwardDiff to support custom forward scalar derivatives properly --- enzyme/Enzyme/EnzymeLogic.cpp | 74 +++++++++++++++++------------------ 1 file changed, 35 insertions(+), 39 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index ae7839d6dcd8..da64d8d24e3f 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4762,34 +4762,6 @@ Function *EnzymeLogic::CreateForwardDiff( !hasconstant && returnUsed) return foundcalled; - if (!foundcalled->getReturnType()->isVoidTy() && !hasconstant) { - if (returnUsed && retType == DIFFE_TYPE::CONSTANT) { - } - if (!returnUsed && retType != DIFFE_TYPE::CONSTANT && !hasconstant) { - FunctionType *FTy = FunctionType::get( - todiff->getReturnType(), foundcalled->getFunctionType()->params(), - foundcalled->getFunctionType()->isVarArg()); - Function *NewF = Function::Create( - FTy, Function::LinkageTypes::InternalLinkage, - "fixderivative_" + todiff->getName(), todiff->getParent()); - for (auto pair : llvm::zip(NewF->args(), foundcalled->args())) { - std::get<0>(pair).setName(std::get<1>(pair).getName()); - } - - BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF); - IRBuilder<> bb(BB); - SmallVector args; - for (auto &a : NewF->args()) - args.push_back(&a); - auto cal = bb.CreateCall(foundcalled, args); - cal->setCallingConv(foundcalled->getCallingConv()); - - bb.CreateRet(bb.CreateExtractValue(cal, 1)); - return ForwardCachedFunctions[tup] = NewF; - } - assert(returnUsed); - } - SmallVector curTypes; bool legal = true; SmallVector nextConstantArgs; @@ -4819,20 +4791,14 @@ Function *EnzymeLogic::CreateForwardDiff( curTypes.push_back(additionalArg); } if (legal) { - Type *RT = todiff->getReturnType(); - if (returnUsed && retType != DIFFE_TYPE::CONSTANT) { - RT = StructType::get(RT->getContext(), {RT, RT}); - } - if (!returnUsed && retType == DIFFE_TYPE::CONSTANT) { - RT = Type::getVoidTy(RT->getContext()); - } + Type *RT = foundcalled->getReturnType(); FunctionType *FTy = FunctionType::get( RT, curTypes, todiff->getFunctionType()->isVarArg()); Function *NewF = Function::Create( FTy, Function::LinkageTypes::InternalLinkage, - "fixderivative_" + todiff->getName(), todiff->getParent()); + "fwddiffe" + todiff->getName(), todiff->getParent()); auto foundArg = NewF->arg_begin(); SmallVector nextArgs; @@ -4874,12 +4840,42 @@ Function *EnzymeLogic::CreateForwardDiff( auto cal = bb.CreateCall(foundcalled, nextArgs); cal->setCallingConv(foundcalled->getCallingConv()); - if (returnUsed && retType != DIFFE_TYPE::CONSTANT) { + if (RT->isVoidTy()) { + bb.CreateRetVoid(); + } else if (cal->getType() == RT) { bb.CreateRet(cal); + } else if (returnUsed && retType != DIFFE_TYPE::CONSTANT) { + if (cal->getType()->isStructTy()) { + bb.CreateRet(cal); + } else { + SmallVector primalArgs; + auto argIt = NewF->arg_begin(); + for (auto tup : llvm::zip(todiff->args(), constant_args)) { + primalArgs.push_back(argIt); + argIt++; + if (std::get<1>(tup) != DIFFE_TYPE::CONSTANT) { + argIt++; + } + } + auto primalCal = bb.CreateCall(todiff, primalArgs); + primalCal->setCallingConv(todiff->getCallingConv()); + Value *str = UndefValue::get(RT); + str = bb.CreateInsertValue(str, primalCal, 0); + str = bb.CreateInsertValue(str, cal, 1); + bb.CreateRet(str); + } } else if (returnUsed) { - bb.CreateRet(bb.CreateExtractValue(cal, 0)); + if (cal->getType()->isStructTy()) { + bb.CreateRet(bb.CreateExtractValue(cal, 0)); + } else { + bb.CreateRet(cal); + } } else if (retType != DIFFE_TYPE::CONSTANT) { - bb.CreateRet(bb.CreateExtractValue(cal, 1)); + if (cal->getType()->isStructTy()) { + bb.CreateRet(bb.CreateExtractValue(cal, 1)); + } else { + bb.CreateRet(cal); + } } else { bb.CreateRetVoid(); } From 209486436b2439860b45ac50a250156044bf3193 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sat, 14 Mar 2026 16:54:46 +0000 Subject: [PATCH 3/4] fix --- enzyme/Enzyme/EnzymeLogic.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index da64d8d24e3f..7df498d92995 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4792,13 +4792,22 @@ Function *EnzymeLogic::CreateForwardDiff( } if (legal) { Type *RT = foundcalled->getReturnType(); + if (!returnUsed && retType != DIFFE_TYPE::CONSTANT) { + if (RT->isStructTy()) { + RT = RT->getStructElementType(1); + } + } else if (returnUsed && retType == DIFFE_TYPE::CONSTANT) { + if (RT->isStructTy()) { + RT = RT->getStructElementType(0); + } + } FunctionType *FTy = FunctionType::get( RT, curTypes, todiff->getFunctionType()->isVarArg()); Function *NewF = Function::Create( FTy, Function::LinkageTypes::InternalLinkage, - "fwddiffe" + todiff->getName(), todiff->getParent()); + "fixderivative_" + todiff->getName(), todiff->getParent()); auto foundArg = NewF->arg_begin(); SmallVector nextArgs; From 263d4e70ee378f19d01622fdec1f28baed567d77 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sun, 15 Mar 2026 01:27:47 +0000 Subject: [PATCH 4/4] fix --- enzyme/Enzyme/EnzymeLogic.cpp | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 7df498d92995..e78d308b9058 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -30,8 +30,12 @@ #include "EnzymeLogic.h" #include "ActivityAnalysis.h" #include "AdjointGenerator.h" -#include "EnzymeLogic.h" #include "TypeAnalysis/TypeAnalysis.h" +#include +#include +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallString.h" +#include #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/GlobalValue.h" @@ -4810,14 +4814,25 @@ Function *EnzymeLogic::CreateForwardDiff( "fixderivative_" + todiff->getName(), todiff->getParent()); auto foundArg = NewF->arg_begin(); + auto fcArg = foundcalled->arg_begin(); SmallVector nextArgs; for (auto tup : llvm::zip(todiff->args(), constant_args)) { nextArgs.push_back(foundArg); auto &arg = std::get<0>(tup); - foundArg->setName(arg.getName()); + if (fcArg != foundcalled->arg_end()) { + foundArg->setName(fcArg->getName()); + fcArg++; + } else { + foundArg->setName(arg.getName()); + } foundArg++; if (std::get<1>(tup) != DIFFE_TYPE::CONSTANT) { - foundArg->setName(arg.getName() + "'"); + if (fcArg != foundcalled->arg_end()) { + foundArg->setName(fcArg->getName()); + fcArg++; + } else { + foundArg->setName(arg.getName() + "'"); + } nextConstantArgs.push_back(std::get<1>(tup)); nextArgs.push_back(foundArg); foundArg++;