From e3a2f2cf2516723916174958b4dd1d473ffe8ee2 Mon Sep 17 00:00:00 2001 From: Parvm1102 Date: Sat, 27 Jun 2026 14:38:00 +0530 Subject: [PATCH] Add Mooncake extension for propagate copy_xj fast path and enable tests Signed-off-by: Parvm1102 --- GNNlib/Project.toml | 3 ++ GNNlib/ext/GNNlibMooncakeExt.jl | 50 +++++++++++++++++++++++++++++++++ GNNlib/test/Project.toml | 1 + GNNlib/test/msgpass.jl | 8 +++--- GNNlib/test/test_module.jl | 30 +++++++++++++++++--- 5 files changed, 84 insertions(+), 8 deletions(-) create mode 100644 GNNlib/ext/GNNlibMooncakeExt.jl diff --git a/GNNlib/Project.toml b/GNNlib/Project.toml index f4a1e9d1f..a0e378881 100644 --- a/GNNlib/Project.toml +++ b/GNNlib/Project.toml @@ -20,11 +20,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" [extensions] GNNlibAMDGPUExt = "AMDGPU" GNNlibCUDAExt = "CUDA" GNNlibMetalExt = "Metal" +GNNlibMooncakeExt = "Mooncake" [compat] AMDGPU = "1" @@ -35,6 +37,7 @@ GNNGraphs = "1.4" LinearAlgebra = "1" Metal = "1.0" MLUtils = "0.4" +Mooncake = "0.5" NNlib = "0.9" Random = "1" Statistics = "1" diff --git a/GNNlib/ext/GNNlibMooncakeExt.jl b/GNNlib/ext/GNNlibMooncakeExt.jl new file mode 100644 index 000000000..37aae08aa --- /dev/null +++ b/GNNlib/ext/GNNlibMooncakeExt.jl @@ -0,0 +1,50 @@ +module GNNlibMooncakeExt + +using GNNlib: GNNlib, propagate, copy_xj +using GNNGraphs: GNNGraph, adjacency_matrix +using LinearAlgebra: adjoint +using Base: IEEEFloat +import Mooncake +using Mooncake: CoDual, DefaultCtx, NoRData, @is_primitive + +# Mooncake reverse rule for the sparse message-passing fast path +# propagate(copy_xj, g, +, xi, xj, e) == xj * adjacency_matrix(g) +# The adjacency `A` is constant w.r.t. the inputs, so the pullback is just `dxj = dy * A'`. 
+# Without this rule, Mooncake differentiates the generic sparse matmul, which is far +# slower than Zygote. + +@is_primitive DefaultCtx Tuple{ + typeof(propagate), + typeof(copy_xj), + GNNGraph, + typeof(+), + Nothing, + AbstractMatrix{P}, + Nothing, +} where {P <: IEEEFloat} + +function Mooncake.rrule!!( + ::CoDual{typeof(propagate)}, + ::CoDual{typeof(copy_xj)}, + g::CoDual{<:GNNGraph}, + ::CoDual{typeof(+)}, + ::CoDual{Nothing}, + xj::CoDual{<:AbstractMatrix{P}}, + ::CoDual{Nothing}, +) where {P <: IEEEFloat} + pg = Mooncake.primal(g) + pxj = Mooncake.primal(xj) + A = adjacency_matrix(pg, P; weighted = false) + y = pxj * A + res = Mooncake.zero_fcodual(y) + function propagate_copy_xj_add_pullback!!(::NoRData) + dy = Mooncake.tangent(res) + dxj = Mooncake.tangent(xj) + dxj .+= dy * adjoint(A) + return NoRData(), NoRData(), Mooncake.zero_rdata(pg), + NoRData(), NoRData(), NoRData(), NoRData() + end + return res, propagate_copy_xj_add_pullback!! +end + +end # module diff --git a/GNNlib/test/Project.toml b/GNNlib/test/Project.toml index 2d87c8e99..5ebbf1f27 100644 --- a/GNNlib/test/Project.toml +++ b/GNNlib/test/Project.toml @@ -7,6 +7,7 @@ GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/GNNlib/test/msgpass.jl b/GNNlib/test/msgpass.jl index 44bda260b..31de7f1cc 100644 --- a/GNNlib/test/msgpass.jl +++ b/GNNlib/test/msgpass.jl @@ -132,14 +132,14 @@ end @testset "copy_xj +" begin for g in TEST_GRAPHS f(g, x) = propagate(copy_xj, g, +, xj = x) - test_gradients(f, g, g.x; test_grad_f=false) + test_gradients(f, g, g.x; test_grad_f=false, test_mooncake=TEST_MOONCAKE) end end @testset "copy_xj mean" begin for g in TEST_GRAPHS f(g, 
x) = propagate(copy_xj, g, mean, xj = x) - test_gradients(f, g, g.x; test_grad_f=false) + test_gradients(f, g, g.x; test_grad_f=false, test_mooncake=TEST_MOONCAKE) end end @@ -147,7 +147,7 @@ end for g in TEST_GRAPHS e = rand(Float32, size(g.x, 1), g.num_edges) f(g, x, e) = propagate(e_mul_xj, g, +; xj = x, e) - test_gradients(f, g, g.x, e; test_grad_f=false) + test_gradients(f, g, g.x, e; test_grad_f=false, test_mooncake=TEST_MOONCAKE) end end @@ -158,7 +158,7 @@ end g = set_edge_weight(g, w) return propagate(w_mul_xj, g, +, xj = x) end - test_gradients(f, g, g.x, w; test_grad_f=false) + test_gradients(f, g, g.x, w; test_grad_f=false, test_mooncake=TEST_MOONCAKE) end end end diff --git a/GNNlib/test/test_module.jl b/GNNlib/test/test_module.jl index 075881af8..18fe42dca 100644 --- a/GNNlib/test/test_module.jl +++ b/GNNlib/test/test_module.jl @@ -42,10 +42,16 @@ using FiniteDifferences: FiniteDifferences using Zygote: Zygote using Flux: Flux +# Mooncake.jl requires Julia >= 1.12 +const TEST_MOONCAKE = VERSION >= v"1.12" +if TEST_MOONCAKE + import Mooncake +end + # from this module export D_IN, D_OUT, GRAPH_TYPES, TEST_GRAPHS, test_gradients, finitediff_withgradient, - check_equal_leaves, gpu_backend + check_equal_leaves, gpu_backend, TEST_MOONCAKE const D_IN = 3 @@ -81,12 +87,13 @@ function test_gradients( test_grad_f = true, test_grad_x = true, compare_finite_diff = true, + test_mooncake = false, loss = (f, g, xs...) -> mean(f(g, xs...)), ) - if !test_gpu && !compare_finite_diff - error("You should either compare finite diff vs CPU AD \ - or CPU AD vs GPU AD.") + if !test_gpu && !compare_finite_diff && !test_mooncake + error("You should either compare finite diff vs CPU AD, \ + CPU AD vs GPU AD, or test Mooncake AD.") end ## Let's make sure first that the forward pass works. 
@@ -115,6 +122,14 @@ function test_gradients( check_equal_leaves(g, g_fd; rtol, atol) end + if test_mooncake && !(graph.graph isa AbstractSparseMatrix) # Mooncake friendly tangents currently error on sparse graph internals + # Mooncake gradient with respect to input via Flux integration, compared against Zygote. + loss_mc_x = (xs...) -> loss(f, graph, xs...) + y_mc, g_mc = Flux.withgradient(loss_mc_x, Flux.AutoMooncake(), xs...) + @assert isapprox(y, y_mc; rtol, atol) + check_equal_leaves(g, g_mc; rtol, atol) + end + if test_gpu # Zygote gradient with respect to input on GPU. y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, graph_gpu, xs...), xs_gpu...) @@ -138,6 +153,13 @@ function test_gradients( check_equal_leaves(g, g_fd; rtol, atol) end + if test_mooncake && !(graph.graph isa AbstractSparseMatrix) # Mooncake friendly tangents currently error on sparse graph internals + # Mooncake gradient with respect to f via Flux integration, compared against Zygote. + y_mc, g_mc = Flux.withgradient(f -> loss(f, graph, xs...), Flux.AutoMooncake(), f) + @assert isapprox(y, y_mc; rtol, atol) + check_equal_leaves(g, g_mc; rtol, atol) + end + if test_gpu # Zygote gradient with respect to f on GPU. y_gpu, g_gpu = Zygote.withgradient(f -> loss(f,graph_gpu, xs_gpu...), f_gpu)