Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions GNNlib/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
GNNlibAMDGPUExt = "AMDGPU"
GNNlibCUDAExt = "CUDA"
GNNlibMetalExt = "Metal"
GNNlibMooncakeExt = "Mooncake"

[compat]
AMDGPU = "1"
Expand All @@ -35,6 +37,7 @@ GNNGraphs = "1.4"
LinearAlgebra = "1"
Metal = "1.0"
MLUtils = "0.4"
Mooncake = "0.5"
NNlib = "0.9"
Random = "1"
Statistics = "1"
Expand Down
50 changes: 50 additions & 0 deletions GNNlib/ext/GNNlibMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
module GNNlibMooncakeExt

using GNNlib: GNNlib, propagate, copy_xj
using GNNGraphs: GNNGraph, adjacency_matrix
using LinearAlgebra: adjoint
using Base: IEEEFloat
import Mooncake
using Mooncake: CoDual, DefaultCtx, NoRData, @is_primitive

# Hand-written Mooncake reverse rule for the sparse message-passing fast path
#
#     propagate(copy_xj, g, +, xi, xj, e) == xj * adjacency_matrix(g)
#
# The adjacency matrix `A` is constant with respect to the differentiated
# inputs, so the pullback only has to accumulate `dy * A'` into the tangent
# of `xj`. Without this rule Mooncake differentiates the generic sparse
# matmul, which is far slower than Zygote.

# Register the fast-path signature (no `xi`, no `e`, floating-point `xj`) as a
# primitive so Mooncake calls the rule below instead of tracing into
# `propagate`.
@is_primitive DefaultCtx Tuple{
    typeof(propagate),
    typeof(copy_xj),
    GNNGraph,
    typeof(+),
    Nothing,
    AbstractMatrix{P},
    Nothing,
} where {P <: IEEEFloat}

# Reverse rule for the primitive registered above.
#
# Forward pass: builds the unweighted adjacency matrix with element type `P`
# and returns `xj * A` wrapped in a CoDual with a zeroed tangent.
# Pullback: adds `dy * A'` in place to the tangent of `xj` and returns one
# rdata entry per argument (function, `copy_xj`, graph, `+`, `xi`, `xj`, `e`).
function Mooncake.rrule!!(
    ::CoDual{typeof(propagate)},
    ::CoDual{typeof(copy_xj)},
    g::CoDual{<:GNNGraph},
    ::CoDual{typeof(+)},
    ::CoDual{Nothing},
    xj::CoDual{<:AbstractMatrix{P}},
    ::CoDual{Nothing},
) where {P <: IEEEFloat}
    graph = Mooncake.primal(g)
    features = Mooncake.primal(xj)
    A = adjacency_matrix(graph, P; weighted = false)
    out = Mooncake.zero_fcodual(features * A)
    # Lazy adjoint wrapper, shared by every invocation of the pullback.
    At = adjoint(A)

    function propagate_copy_xj_add_pullback!!(::NoRData)
        out_tangent = Mooncake.tangent(out)
        xj_tangent = Mooncake.tangent(xj)
        # Accumulate rather than overwrite: the tangent buffer may already
        # hold contributions from other uses of `xj`.
        xj_tangent .+= out_tangent * At
        return (
            NoRData(),
            NoRData(),
            Mooncake.zero_rdata(graph),
            NoRData(),
            NoRData(),
            NoRData(),
            NoRData(),
        )
    end

    return out, propagate_copy_xj_add_pullback!!
end

end # module
1 change: 1 addition & 0 deletions GNNlib/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
8 changes: 4 additions & 4 deletions GNNlib/test/msgpass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,22 +132,22 @@ end
# Gradient check for `propagate(copy_xj, g, +; xj)` on every graph in TEST_GRAPHS.
@testset "copy_xj +" begin
for g in TEST_GRAPHS
# Function under test: sum-aggregate the neighbour features `x`.
f(g, x) = propagate(copy_xj, g, +, xj = x)
# NOTE(review): the two calls below differ only by `test_mooncake=TEST_MOONCAKE`;
# they look like the removed/added pair of a diff hunk — confirm only one is live.
test_gradients(f, g, g.x; test_grad_f=false)
test_gradients(f, g, g.x; test_grad_f=false, test_mooncake=TEST_MOONCAKE)
end
end

# Gradient check for `propagate(copy_xj, g, mean; xj)` on every graph in TEST_GRAPHS.
@testset "copy_xj mean" begin
for g in TEST_GRAPHS
# Function under test: mean-aggregate the neighbour features `x`.
f(g, x) = propagate(copy_xj, g, mean, xj = x)
# NOTE(review): the two calls below differ only by `test_mooncake=TEST_MOONCAKE`;
# they look like the removed/added pair of a diff hunk — confirm only one is live.
test_gradients(f, g, g.x; test_grad_f=false)
test_gradients(f, g, g.x; test_grad_f=false, test_mooncake=TEST_MOONCAKE)
end
end

# Gradient check for `propagate(e_mul_xj, g, +; xj, e)` on every graph in TEST_GRAPHS.
@testset "e_mul_xj +" begin
for g in TEST_GRAPHS
# Random Float32 edge features: one column per edge, same row count as `g.x`.
e = rand(Float32, size(g.x, 1), g.num_edges)
f(g, x, e) = propagate(e_mul_xj, g, +; xj = x, e)
# NOTE(review): the two calls below differ only by `test_mooncake=TEST_MOONCAKE`;
# they look like the removed/added pair of a diff hunk — confirm only one is live.
test_gradients(f, g, g.x, e; test_grad_f=false)
test_gradients(f, g, g.x, e; test_grad_f=false, test_mooncake=TEST_MOONCAKE)
end
end

Expand All @@ -158,7 +158,7 @@ end
g = set_edge_weight(g, w)
return propagate(w_mul_xj, g, +, xj = x)
end
test_gradients(f, g, g.x, w; test_grad_f=false)
test_gradients(f, g, g.x, w; test_grad_f=false, test_mooncake=TEST_MOONCAKE)
end
end
end
Expand Down
30 changes: 26 additions & 4 deletions GNNlib/test/test_module.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,16 @@ using FiniteDifferences: FiniteDifferences
using Zygote: Zygote
using Flux: Flux

# Gate for the Mooncake AD tests: enabled only on Julia >= 1.12, the minimum
# version this test suite assumes for Mooncake.jl.
const TEST_MOONCAKE = VERSION >= v"1.12"
# Import Mooncake only when the gate is on, so older Julia versions never try
# to load it.
if TEST_MOONCAKE
import Mooncake
end

# from this module
export D_IN, D_OUT, GRAPH_TYPES, TEST_GRAPHS,
test_gradients, finitediff_withgradient,
check_equal_leaves, gpu_backend
check_equal_leaves, gpu_backend, TEST_MOONCAKE


const D_IN = 3
Expand Down Expand Up @@ -81,12 +87,13 @@ function test_gradients(
test_grad_f = true,
test_grad_x = true,
compare_finite_diff = true,
test_mooncake = false,
loss = (f, g, xs...) -> mean(f(g, xs...)),
)

if !test_gpu && !compare_finite_diff
error("You should either compare finite diff vs CPU AD \
or CPU AD vs GPU AD.")
if !test_gpu && !compare_finite_diff && !test_mooncake
error("You should either compare finite diff vs CPU AD, \
CPU AD vs GPU AD, or test Mooncake AD.")
end

## Let's make sure first that the forward pass works.
Expand Down Expand Up @@ -115,6 +122,14 @@ function test_gradients(
check_equal_leaves(g, g_fd; rtol, atol)
end

if test_mooncake && !(graph.graph isa AbstractSparseMatrix) # Mooncake friendly tangents currently error on sparse graph internals
# Mooncake gradient with respect to input via Flux integration, compared against Zygote.
loss_mc_x = (xs...) -> loss(f, graph, xs...)
y_mc, g_mc = Flux.withgradient(loss_mc_x, Flux.AutoMooncake(), xs...)
@assert isapprox(y, y_mc; rtol, atol)
check_equal_leaves(g, g_mc; rtol, atol)
end

if test_gpu
# Zygote gradient with respect to input on GPU.
y_gpu, g_gpu = Zygote.withgradient((xs...) -> loss(f_gpu, graph_gpu, xs...), xs_gpu...)
Expand All @@ -138,6 +153,13 @@ function test_gradients(
check_equal_leaves(g, g_fd; rtol, atol)
end

if test_mooncake && !(graph.graph isa AbstractSparseMatrix) # Mooncake friendly tangents currently error on sparse graph internals
# Mooncake gradient with respect to f via Flux integration, compared against Zygote.
y_mc, g_mc = Flux.withgradient(f -> loss(f, graph, xs...), Flux.AutoMooncake(), f)
@assert isapprox(y, y_mc; rtol, atol)
check_equal_leaves(g, g_mc; rtol, atol)
end

if test_gpu
# Zygote gradient with respect to f on GPU.
y_gpu, g_gpu = Zygote.withgradient(f -> loss(f,graph_gpu, xs_gpu...), f_gpu)
Expand Down
Loading