Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,16 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
return B.CreateLoad(rankTy, alloc);
}

/// Emits derivative code for a call to `__nv_sincos` / `__nv_sincosf`
/// (argument 0 is the input value, arguments 1 and 2 are the pointers that
/// receive the sine and cosine results). Defined out of line in
/// CallDerivatives.cpp.
void handleNVSincos(llvm::CallInst &call);
void visitInstruction(llvm::Instruction &inst) {
if (auto *CI = llvm::dyn_cast<llvm::CallInst>(&inst)) {
if (auto *F = CI->getCalledFunction()) {
if (F->getName() == "__nv_sincos" || F->getName() == "__nv_sincosf") {
handleNVSincos(*CI);
return;
}
}
}
using namespace llvm;

// TODO explicitly handle all instructions rather than using the catch all
Expand Down
71 changes: 71 additions & 0 deletions enzyme/Enzyme/CallDerivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4276,3 +4276,74 @@ bool AdjointGenerator::handleKnownCallDerivatives(

return false;
}

/// Differentiates a call `__nv_sincos(x, s_ptr, c_ptr)` (or `__nv_sincosf`),
/// which writes sin(x) through `s_ptr` and cos(x) through `c_ptr`.
///
/// \param call The call instruction in the original (un-differentiated)
///             function.
///
/// Forward modes: stores `dx * cos(x)` into the shadow of `s_ptr` and
/// `dx * -sin(x)` into the shadow of `c_ptr` (zero when `x` is inactive).
///
/// Reverse modes (gradient / combined): loads the adjoints held in the two
/// shadows, resets those shadow slots to zero, and accumulates
/// `d_s * cos(x) - d_c * sin(x)` into the adjoint of `x`.
///
/// Any other mode, or a call that activity analysis marks constant, emits
/// nothing.
///
/// NOTE(review): this assumes a derivative width of 1 (shadows are loaded and
/// stored as a single scalar of x's type) -- confirm whether vector mode needs
/// to go through the chain-rule helpers. Also confirm whether the usual
/// post-visit cleanup of the primal call is required on this path.
void AdjointGenerator::handleNVSincos(llvm::CallInst &call) {
  if (gutils->isConstantInstruction(&call))
    return;

  Value *op = call.getArgOperand(0);
  Value *s_ptr = call.getArgOperand(1);
  Value *c_ptr = call.getArgOperand(2);
  Type *ElTy = op->getType();

  // Activity is queried per operand: the call as a whole can be active while
  // the input or one of the outputs is not, and requesting a derivative or a
  // shadow for an inactive operand is not meaningful.
  const bool opActive = !gutils->isConstantValue(op);
  const bool sActive = !gutils->isConstantValue(s_ptr);
  const bool cActive = !gutils->isConstantValue(c_ptr);

  if (Mode == DerivativeMode::ForwardMode ||
      Mode == DerivativeMode::ForwardModeSplit) {
    IRBuilder<> Builder2(&call);
    getForwardBuilder(Builder2);

    // With an inactive input both output tangents are exactly zero; the
    // shadows still have to be overwritten because the primal call
    // overwrites the outputs.
    Value *d_s = Constant::getNullValue(ElTy);
    Value *d_c = d_s;

    if (opActive) {
      Value *d_x = diffe(op, Builder2);
      // `op` belongs to the original function; emitted code must refer to
      // its counterpart in the function being generated.
      Value *X = gutils->getNewFromOriginal(op);

      Value *SinX = Builder2.CreateUnaryIntrinsic(Intrinsic::sin, X);
      Value *CosX = Builder2.CreateUnaryIntrinsic(Intrinsic::cos, X);

      d_s = Builder2.CreateFMul(d_x, CosX);
      d_c = Builder2.CreateFMul(d_x, Builder2.CreateFNeg(SinX));
    }

    if (sActive)
      Builder2.CreateStore(d_s, gutils->invertPointerM(s_ptr, Builder2));
    if (cActive)
      Builder2.CreateStore(d_c, gutils->invertPointerM(c_ptr, Builder2));
    return;
  }

  if (Mode != DerivativeMode::ReverseModeGradient &&
      Mode != DerivativeMode::ReverseModeCombined)
    return;

  IRBuilder<> Builder2(&call);
  getReverseBuilder(Builder2);

  // Shadow pointers are produced for the forward sweep; `lookup` makes them
  // available at the current point of the reverse sweep.
  Value *s_shadow =
      sActive ? lookup(gutils->invertPointerM(s_ptr, Builder2), Builder2)
              : nullptr;
  Value *c_shadow =
      cActive ? lookup(gutils->invertPointerM(c_ptr, Builder2), Builder2)
              : nullptr;

  // Read both incoming adjoints before clearing either slot.
  Value *d_s = s_shadow ? Builder2.CreateLoad(ElTy, s_shadow) : nullptr;
  Value *d_c = c_shadow ? Builder2.CreateLoad(ElTy, c_shadow) : nullptr;

  // The primal call overwrote *s_ptr and *c_ptr, so whatever adjoint they
  // carried has now been consumed.
  Value *Zero = Constant::getNullValue(ElTy);
  if (s_shadow)
    Builder2.CreateStore(Zero, s_shadow);
  if (c_shadow)
    Builder2.CreateStore(Zero, c_shadow);

  // Nothing to propagate if x is inactive or neither output carries an
  // adjoint.
  if (!opActive || (!d_s && !d_c))
    return;

  Value *X = lookup(gutils->getNewFromOriginal(op), Builder2);

  // Only materialise the trig values that are actually needed. Named
  // temporaries keep the emission order deterministic (sin, cos, then the
  // products).
  Value *SinX =
      d_c ? Builder2.CreateUnaryIntrinsic(Intrinsic::sin, X) : nullptr;
  Value *CosX =
      d_s ? Builder2.CreateUnaryIntrinsic(Intrinsic::cos, X) : nullptr;

  Value *Term1 = d_s ? Builder2.CreateFMul(d_s, CosX) : nullptr;
  Value *Term2 = d_c ? Builder2.CreateFMul(d_c, SinX) : nullptr;

  // dx += d_s * cos(x) - d_c * sin(x)
  Value *Diff = nullptr;
  if (Term1 && Term2)
    Diff = Builder2.CreateFSub(Term1, Term2);
  else if (Term1)
    Diff = Term1;
  else
    Diff = Builder2.CreateFNeg(Term2);

  gutils->addToDiffe(op, Diff, Builder2, ElTy);
}
76 changes: 76 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/nv_sincos.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -sroa -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,sroa,instsimplify,%simplifycfg)" -S | FileCheck %s

; External sincos routines under test (double and float variants). Each takes
; the input value followed by two pointer arguments.
declare void @__nv_sincos(double, double*, double*)
declare void @__nv_sincosf(float, float*, float*)

; Primal function handed to Enzyme: forwards its three arguments unchanged to
; @__nv_sincos. The argument names (%x, %s_out, %c_out) are matched literally
; by the FileCheck patterns at the end of this file.
define void @tester(double %x, double* %s_out, double* %c_out) {
entry:
call void @__nv_sincos(double %x, double* %s_out, double* %c_out)
ret void
}

; Driver for the double variant: allocates the two outputs and their shadows,
; seeds the shadows with 1.0 (sine output) and 2.0 (cosine output), and
; requests the derivative of @tester with respect to %x.
define double @test_derivative(double %x) {
entry:
  ; Primal output slots.
  %sin.out = alloca double
  %cos.out = alloca double

  ; Shadow slots carrying the seed for each output.
  %sin.shadow = alloca double
  %cos.shadow = alloca double
  store double 1.0, double* %sin.shadow
  store double 2.0, double* %cos.shadow

  %grad = call double (...) @__enzyme_autodiff(void (double, double*, double*)* @tester, double %x, double* %sin.out, double* %sin.shadow, double* %cos.out, double* %cos.shadow)

  ret double %grad
}

; Float primal, mirroring @tester. It returns void: returning the active %x
; would add a differential-return parameter to the generated gradient, which
; the expected diffetester_f signature at the end of this file does not have.
define void @tester_f(float %x, float* %s_out, float* %c_out) {
entry:
  call void @__nv_sincosf(float %x, float* %s_out, float* %c_out)
  ret void
}

; Driver for the float variant: seeds the output shadows with 1.0 and 2.0 and
; requests the derivative of @tester_f with respect to %x.
define float @test_derivative_f(float %x) {
entry:
  %s = alloca float
  %c = alloca float
  %ds = alloca float
  %dc = alloca float

  store float 1.0, float* %ds
  store float 2.0, float* %dc

  %0 = call float (...) @__enzyme_autodiff_f32(void (float, float*, float*)* @tester_f, float %x, float* %s, float* %ds, float* %c, float* %dc)

  ret float %0
}

; Float-returning autodiff entry point. The double-returning declaration
; cannot be called with a `float (...)` function type under typed pointers.
; NOTE(review): relies on Enzyme recognising any callee whose name contains
; "__enzyme_autodiff" -- confirm.
declare float @__enzyme_autodiff_f32(...)

; Varargs autodiff entry point called by @test_derivative; its first argument
; is the function to differentiate.
declare double @__enzyme_autodiff(...)

; Expected gradients, listed in the order the drivers request them (double
; first, then float) so the label patterns appear in output order.
; NOTE(review): assumes generated functions are emitted in request order --
; confirm against actual pass output.

; CHECK-LABEL: define internal { double } @diffetester(double %x, double* %s_out, double* %"s_out'", double* %c_out, double* %"c_out'")
; CHECK: call void @__nv_sincos(double %x, double* %s_out, double* %c_out)
; CHECK: %[[ds:.+]] = load double, double* %"s_out'"
; CHECK: %[[dc:.+]] = load double, double* %"c_out'"
; CHECK: store double 0.000000e+00, double* %"s_out'"
; CHECK: store double 0.000000e+00, double* %"c_out'"
; CHECK: %[[sin:.+]] = call fast double @llvm.sin.f64(double %x)
; CHECK: %[[cos:.+]] = call fast double @llvm.cos.f64(double %x)
; CHECK: %[[term1:.+]] = fmul fast double %[[ds]], %[[cos]]
; CHECK: %[[term2:.+]] = fmul fast double %[[dc]], %[[sin]]
; CHECK: %[[diff:.+]] = fsub fast double %[[term1]], %[[term2]]
; CHECK: ret { double }

; CHECK-LABEL: define internal { float } @diffetester_f(float %x, float* %s_out, float* %"s_out'", float* %c_out, float* %"c_out'")
; CHECK: call void @__nv_sincosf(float %x, float* %s_out, float* %c_out)
; CHECK: %[[dsf:.+]] = load float, float* %"s_out'"
; CHECK: %[[dcf:.+]] = load float, float* %"c_out'"
; CHECK: store float 0.000000e+00, float* %"s_out'"
; CHECK: store float 0.000000e+00, float* %"c_out'"
; CHECK: %[[sinf:.+]] = call fast float @llvm.sin.f32(float %x)
; CHECK: %[[cosf:.+]] = call fast float @llvm.cos.f32(float %x)
; CHECK: %[[term1f:.+]] = fmul fast float %[[dsf]], %[[cosf]]
; CHECK: %[[term2f:.+]] = fmul fast float %[[dcf]], %[[sinf]]
; CHECK: %[[difff:.+]] = fsub fast float %[[term1f]], %[[term2f]]
; CHECK: ret { float }