diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 19e3c947b04b..d36a2c9ec802 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -1032,31 +1032,19 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, SmallVector scopeMD = { getDerivativeAliasScope(origptr, idx)}; - if (auto origValI = dyn_cast_or_null(origVal)) - if (auto MD = origValI->getMetadata(LLVMContext::MD_alias_scope)) { - auto MDN = cast(MD); - for (auto &o : MDN->operands()) - scopeMD.push_back(o); - } auto scope = MDNode::get(LI->getContext(), scopeMD); LI->setMetadata(LLVMContext::MD_alias_scope, scope); st->setMetadata(LLVMContext::MD_alias_scope, scope); SmallVector MDs; - for (ssize_t j = -1; j < getWidth(); j++) { + for (ssize_t j = -1; j < (ssize_t)getWidth(); j++) { if (j != (ssize_t)idx) MDs.push_back(getDerivativeAliasScope(origptr, j)); } - if (auto origValI = dyn_cast_or_null(origVal)) - if (auto MD = origValI->getMetadata(LLVMContext::MD_noalias)) { - auto MDN = cast(MD); - for (auto &o : MDN->operands()) - MDs.push_back(o); - } - idx++; - auto noscope = MDNode::get(ptr->getContext(), MDs); + auto noscope = MDNode::get(LI->getContext(), MDs); LI->setMetadata(LLVMContext::MD_noalias, noscope); st->setMetadata(LLVMContext::MD_noalias, noscope); + idx++; if (origVal && isa(origVal) && start == 0 && size == (DL.getTypeSizeInBits(origVal->getType()) + 7) / 8) { diff --git a/enzyme/test/Enzyme/ReverseMode/alias-sret-correctness.ll b/enzyme/test/Enzyme/ReverseMode/alias-sret-correctness.ll new file mode 100644 index 000000000000..5bdc5b7f17f8 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/alias-sret-correctness.ll @@ -0,0 +1,41 @@ +; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,gvn,dse,simplifycfg)" -S | FileCheck %s; fi + +define { double, double } @subcall(double %x) { +entry: + %m1 = fmul double %x, 2.0 + %m2 = fmul double %x, 3.0 + %res1 = insertvalue { double, double } undef, double %m1, 0 + %res2 = insertvalue { double, double } %res1, double %m2, 1 + ret { double, double } %res2 +} + +define double @caller(double %x, double* %p) { +entry: + %call = call { double, double } @subcall(double %x) + %val1 = extractvalue { double, double } %call, 0 + %val2 = extractvalue { double, double } %call, 1 + + ; Load carrying alias.scope + %l = load double, double* %p, align 8, !alias.scope !1 + + %res = fadd double %val1, %val2 + %res2 = fadd double %res, %l + ret double %res2 +} + +!1 = !{!2} +!2 = distinct !{!2, !3, !"domain"} +!3 = distinct !{!3, !"domain"} + +define void @test_diff(double %x, double* %p, double* %dp) { + call void (...) @__enzyme_autodiff(i8* bitcast (double (double, double*)* @caller to i8*), double %x, double* %p, double* %dp) + ret void +} + +declare void @__enzyme_autodiff(...) + +; CHECK: define internal { double } @diffecaller +; CHECK: %[[gep:.+]] = getelementptr inbounds { double, double }, { double, double }* %"call'de", i32 0, i32 1 +; CHECK: %[[load:.+]] = load double, double* %[[gep]], align 8 +; CHECK: %[[add:.+]] = fadd fast double %[[load]], %{{.+}} +; CHECK-NEXT: store double %[[add]], double* %[[gep]], align 8{{$}} diff --git a/enzyme/test/Enzyme/ReverseMode/alias-sret-ptr.ll b/enzyme/test/Enzyme/ReverseMode/alias-sret-ptr.ll new file mode 100644 index 000000000000..26619323efbf --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/alias-sret-ptr.ll @@ -0,0 +1,53 @@ +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,gvn,dse)" -S | FileCheck %s + +; Subcall takes a pointer output to mimic Array return in memory +define void @subcall(double* %out, double %x) { +entry: + %m1 = fmul double %x, 2.0 + %m2 = fmul double %x, 3.0 + store double %m1, double* %out, align 8 + %gep = getelementptr inbounds double, double* %out, i64 1 + store double %m2, double* %gep, align 8 + ret void +} + +define double @caller(double %x, double* %p) { +entry: + %call_out = alloca double, i64 2, align 8 + call void @subcall(double* %call_out, double %x) + + %val1 = load double, double* %call_out, align 8 + %gep = getelementptr inbounds double, double* %call_out, i64 1 + %val2 = load double, double* %gep, align 8 + + ; Load carrying alias.scope + %l = load double, double* %p, align 8, !alias.scope !1 + + %res = fadd double %val1, %val2 + %res2 = fadd double %res, %l + ret double %res2 +} + +!1 = !{!2} +!2 = distinct !{!2, !3, !"domain"} +!3 = distinct !{!3, !"domain"} + +define void @test_diff(double %x, double* %p, double* %dp) { + call void (...) @__enzyme_autodiff(i8* bitcast (double (double, double*)* @caller to i8*), double %x, double* %p, double* %dp) + ret void +} + +declare void @__enzyme_autodiff(...) + +; CHECK: define internal { double } @diffecaller(double %x, double* nocapture readonly %p, double* nocapture %"p'", double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"call_out'ipa" = alloca double, i64 2, align 8 +; CHECK-NEXT: %"gep'ipg" = getelementptr inbounds double, double* %"call_out'ipa", i64 1 +; CHECK-NEXT: %[[loadP:.+]] = load double, double* %"p'", align 8, !alias.scope ![[scopeP:[0-9]+]], !noalias ![[noaliasP:[0-9]+]] +; CHECK-NEXT: %[[addP:.+]] = fadd fast double %[[loadP]], %differeturn +; CHECK-NEXT: store double %[[addP]], double* %"p'", align 8, !alias.scope ![[scopeP]], !noalias ![[noaliasP]] +; CHECK-NEXT: store double %differeturn, double* %"gep'ipg", align 8, !alias.scope ![[scopeS:[0-9]+]], !noalias ![[noaliasS:[0-9]+]] +; CHECK-NEXT: store double %differeturn, double* %"call_out'ipa", align 8, !alias.scope ![[scopeS]], !noalias ![[noaliasS]] + +; CHECK: ![[scopeS]] = !{![[shadow_node:[0-9]+]]} +; CHECK: ![[shadow_node]] = distinct !{![[shadow_node]], !{{[0-9]+}}, !"shadow_0"} diff --git a/enzyme/test/Enzyme/ReverseMode/alias-sret.ll b/enzyme/test/Enzyme/ReverseMode/alias-sret.ll new file mode 100644 index 000000000000..0f4d5884e187 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/alias-sret.ll @@ -0,0 +1,51 @@ +; RUN: if [ %llvmver -ge 12 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg)" -S | FileCheck %s; fi + +define { double, double } @subcall(double %x) { +entry: + %m1 = fmul double %x, 2.0 + %m2 = fmul double %x, 3.0 + %res1 = insertvalue { double, double } undef, double %m1, 0 + %res2 = insertvalue { double, double } %res1, double %m2, 1 + ret { double, double } %res2 +} + +define double @caller(double %x, double* %p) { +entry: + %call = call { double, double } @subcall(double %x) + %val1 = extractvalue { double, double } %call, 0 + %val2 = extractvalue { double, double } %call, 1 + %l = load double, double* %p, !alias.scope !1 + %res = fadd double %val1, %val2 + %res2 = fadd double %res, %l + ret double %res2 +} + +!1 = !{!2} +!2 = distinct !{!2, !3, !"domain"} +!3 = distinct !{!3, !"domain"} + +define void @test_diff(double %x, double* %p, double* %dp) { + call void (...) @__enzyme_autodiff(i8* bitcast (double (double, double*)* @caller to i8*), double %x, double* %p, double* %dp) + ret void +} + +declare void @__enzyme_autodiff(...) + +; CHECK: define internal { double } @diffecaller(double %x, double* nocapture readonly %p, double* nocapture %"p'", double %differeturn) +; CHECK-NEXT: entry: +; CHECK-NEXT: %"call'de" = alloca { double, double }, align 8 +; CHECK-NEXT: store { double, double } zeroinitializer, { double, double }* %"call'de", align 8 +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: +; CHECK-NEXT: %[[add0:.+]] = fadd fast double 0.000000e+00, %differeturn +; CHECK-NEXT: %[[add1:.+]] = fadd fast double 0.000000e+00, %differeturn +; CHECK-NEXT: %[[add2:.+]] = fadd fast double 0.000000e+00, %[[add0]] +; CHECK-NEXT: %[[add3:.+]] = fadd fast double 0.000000e+00, %[[add0]] +; CHECK-NEXT: %[[loadP:.+]] = load double, double* %"p'", align 8, !alias.scope ![[scopeP:[0-9]+]], !noalias ![[noaliasP:[0-9]+]] +; CHECK-NEXT: %[[addP:.+]] = fadd fast double %[[loadP]], %[[add1]] +; CHECK-NEXT: store double %[[addP]], double* %"p'", align 8, !alias.scope ![[scopeP]], !noalias ![[noaliasP]] +; CHECK-NEXT: %[[gepS:.+]] = getelementptr inbounds { double, double }, { double, double }* %"call'de", i32 0, i32 1 +; CHECK-NEXT: %[[loadS:.+]] = load double, double* %[[gepS]], align 8 +; CHECK-NEXT: %[[addS:.+]] = fadd fast double %[[loadS]], %[[add3]] +; CHECK-NEXT: store double %[[addS]], double* %[[gepS]], align 8{{$}}