Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/internal_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
without deprecation.

```@autodocs
Modules = [Enzyme.Compiler]
Modules = [Enzyme.Compiler, Enzyme.Compiler.RecursiveMaps]
Order = [:module, :type, :constant, :macro, :function]
Filter = t -> !(t === Enzyme.Compiler.CheckNan)
```
49 changes: 5 additions & 44 deletions ext/EnzymeStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,50 +32,11 @@ end
end
end

# Zeroed counterpart of an `SArray` whose elements are floats or complex floats.
# Delegates to `Base.zero` and asserts that the result has the same type as `prev`.
@inline function Enzyme.EnzymeCore.make_zero(
    prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T}}
    zeroed = Base.zero(prev)
    return zeroed::FT
end
# Zeroed counterpart of an `MArray` whose elements are floats or complex floats.
# Delegates to `Base.zero` and asserts that the result has the same type as `prev`.
@inline function Enzyme.EnzymeCore.make_zero(
    prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
    zeroed = Base.zero(prev)
    return zeroed::FT
end

# Four-argument `make_zero` method for `SArray`s of floats or complex floats.
# Neither `seen` nor `copy_if_inactive` is consulted here: the result is always
# `zero(prev)`, asserted to be of type `FT`.
@inline function Enzyme.EnzymeCore.make_zero(
    ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false)
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:SArray{S,T},copy_if_inactive}
    zeroed = Base.zero(prev)
    return zeroed::FT
end
# Four-argument `make_zero` method for `MArray`s of floats or complex floats.
# `seen` memoizes results per input object: if `prev` already has an entry, that
# entry is returned; otherwise a fresh `zero(prev)` is created, recorded under
# `prev`, and returned.
@inline function Enzyme.EnzymeCore.make_zero(
    ::Type{FT}, seen::IdDict, prev::FT, ::Val{copy_if_inactive} = Val(false)
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T},copy_if_inactive}
    # Reuse the previously created zero for this exact object, if any.
    haskey(seen, prev) && return seen[prev]
    zeroed = Base.zero(prev)::FT
    seen[prev] = zeroed
    return zeroed
end

# In-place zeroing of an `MArray` of floats or complex floats.
# When `seen` is not `nothing`, it is used as a visited collection: an array that
# is already in `seen` is left untouched, otherwise it is added before zeroing.
# Always returns `nothing`.
@inline function Enzyme.EnzymeCore.make_zero!(
    prev::FT, seen
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
    if seen !== nothing
        # Already handled: do not zero it a second time.
        prev in seen && return nothing
        push!(seen, prev)
    end
    fill!(prev, zero(T))
    return nothing
end
@inline function Enzyme.EnzymeCore.make_zero!(
prev::FT
) where {S,T<:Union{AbstractFloat,Complex{<:AbstractFloat}},FT<:MArray{S,T}}
Enzyme.EnzymeCore.make_zero!(prev, nothing)
return nothing
# SArrays and MArrays don't need special treatment for `make_zero(!)` to work or be correct,
# but in case their dedicated `zero` and `fill!` methods are more efficient than
# `make_zero(!)`'s recursion, we opt into treating them as leaves.
# A `StaticArray` type is reported as a vector type exactly when its element type
# `T` is a bits type and `isscalartype(T)` holds.
@inline function Enzyme.EnzymeCore.isvectortype(::Type{<:StaticArray{S, T}}) where {S, T}
    isbitstype(T) || return false
    return Enzyme.EnzymeCore.isscalartype(T)
end

end
125 changes: 114 additions & 11 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -506,28 +506,131 @@ function autodiff_thunk end
function autodiff_deferred_thunk end

"""
make_zero(::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}=Val(false))::T

Recursively make a zero'd copy of the value `prev` of type `T`. The argument `copy_if_inactive` specifies
what to do if the type `T` is guaranteed to be inactive, use the primal (the default) or still copy the value.
make_zero(prev::T; copy_if_inactive = Val(false), runtime_inactive = Val(false))::T
make_zero(prev::T, ::Val{copy_if_inactive}[, ::Val{runtime_inactive}])::T
make_zero(
::Type{T}, seen::IdDict, prev::T;
copy_if_inactive = Val(false), runtime_inactive = Val(false),
)::T
make_zero(
::Type{T}, seen::IdDict, prev::T, ::Val{copy_if_inactive}[, ::Val{runtime_inactive}]
)::T

Recursively make a copy of the value `prev::T` in which all differentiable values are zeroed.

The argument `copy_if_inactive` specifies what to do if the type `T` or any
of its constituent parts is guaranteed to be inactive (non-differentiable): reuse `prev`'s
instance (if `Val(false)`, the default) or make a copy (if `Val(true)`).

The argument `runtime_inactive` specifies whether this function should respect runtime
semantics when determining if a type is guaranteed inactive. If `Val(false)`, only the
methods of `EnzymeRules.inactive_type` that were defined at the time of precompiling
`Enzyme` will be taken into account when determining a type's activity. If `Val(true)`, new
or changed methods of `EnzymeRules.inactive_type` will be taken into account as per usual
Julia semantics.

`copy_if_inactive` and `runtime_inactive` may be provided as either positional or keyword
arguments, but not a combination.

Extending this method for custom types is rarely needed. If you implement a new type, such
as a GPU array type, for which `make_zero` should directly invoke `zero` for scalar eltypes,
it is sufficient to implement `Base.zero` and make sure your type subtypes `DenseArray`. (If
subtyping `DenseArray` is not appropriate, extend [`EnzymeCore.isvectortype`](@ref)
instead.)
"""
function make_zero end

"""
make_zero!(val::T, seen::IdSet{Any}=IdSet())::Nothing
make_zero!(val::T, [seen::IdDict]; runtime_inactive = Val(false))::Nothing
make_zero!(val::T, [seen::IdDict], ::Val{runtime_inactive})::Nothing

Recursively set a variable's differentiable values to zero. Only applicable for types `T`
that are mutable or hold all differentiable values in mutable storage (e.g.,
`Tuple{Vector{Float64}}` qualifies but `Tuple{Float64}` does not). The recursion skips over
parts of `val` that are guaranteed to be inactive.

The argument `runtime_inactive` specifies whether this function should respect runtime
semantics when determining if a type is guaranteed inactive. If `Val(false)`, only the
methods of `EnzymeRules.inactive_type` that were defined at the time of precompiling
`Enzyme` will be taken into account when determining a type's activity. If `Val(true)`, new
or changed methods of `EnzymeRules.inactive_type` will be taken into account as per usual
Julia semantics.

`runtime_inactive` may be given as either a positional or a keyword argument.

Recursively set a variable's differentiable fields to zero. Only applicable for mutable types `T`.
Extending this method for custom types is rarely needed. If you implement a new mutable
type, such as a GPU array type, for which `make_zero!` should directly invoke
`fill!(x, false)` for scalar eltypes, it is sufficient to implement `Base.zero`,
`Base.fill!`, and make sure your type subtypes `DenseArray`. (If subtyping `DenseArray` is
not appropriate, extend [`EnzymeCore.isvectortype`](@ref) instead.)
"""
function make_zero! end

"""
make_zero(prev::T)
isvectortype(::Type{T})::Bool

Helper function to recursively make zero.
"""
@inline function make_zero(prev::T, ::Val{copy_if_inactive}=Val(false)) where {T, copy_if_inactive}
make_zero(Core.Typeof(prev), IdDict(), prev, Val(copy_if_inactive))
Trait defining types whose values should be considered leaf nodes when [`make_zero`](@ref)
and [`make_zero!`](@ref) recurse through an object.

By default, `isvectortype(T) == true` when `isscalartype(T) == true` or when
`T <: DenseArray{U}` where `U` is a bitstype and `isscalartype(U) == true`.

A new vector type, such as a GPU array type, should normally subtype `DenseArray` and
inherit `isvectortype` that way. However, if this is not appropriate, `isvectortype` may be
extended directly as follows:

```julia
@inline function EnzymeCore.isvectortype(::Type{T}) where {T<:NewArray}
U = eltype(T)
return isbitstype(U) && EnzymeCore.isscalartype(U)
end
```

In either case, the type should implement `Base.zero` and, if mutable, `Base.fill!`.

Extending `isvectortype` is mostly relevant for the lowest level of abstraction of memory at
which vector space operations like addition and scalar multiplication are supported, the
prototypical case being `Array`. Regular Julia structs with vector space-like semantics
should normally not extend `isvectortype`; `make_zero(!)` will recurse into them and act
directly on their backing arrays, just like how Enzyme treats them when differentiating. For
example, structured matrix wrappers and sparse array types that are backed by `Array` should
not extend `isvectortype`.

See also [`isscalartype`](@ref).
"""
function isvectortype end

"""
isscalartype(::Type{T})::Bool

Trait defining a subset of [`isvectortype`](@ref) types that should not be considered
composite, such that even if the type is mutable, [`make_zero!`](@ref) will not try to zero
values of the type in-place. For example, `BigFloat` is a mutable type but does not support
in-place mutation through any Julia API, and `isscalartype(BigFloat) == true` ensures that
`make_zero!` will not try to mutate `BigFloat` values.[^BigFloat]

By default, `isscalartype(T) == true` and `isscalartype(Complex{T}) == true` for concrete
types where `T <: AbstractFloat`.

A hypothetical new real number type with Enzyme support should usually subtype
`AbstractFloat` and inherit the `isscalartype` trait that way. If this is not appropriate,
the function can be extended as follows:

```julia
@inline EnzymeCore.isscalartype(::Type{NewReal}) = true
@inline EnzymeCore.isscalartype(::Type{Complex{NewReal}}) = true
```

In either case, the type should implement `Base.zero`.

See also [`isvectortype`](@ref).

[^BigFloat]: Enzyme does not support differentiating `BigFloat` as of this writing; it is
mentioned here only to demonstrate that it would be inappropriate to use traits like
`ismutable` or `isbitstype` to choose between in-place and out-of-place zeroing, showing the
need for a dedicated `isscalartype` trait.
"""
function isscalartype end

function tape_type end

Expand Down
8 changes: 2 additions & 6 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -463,12 +463,8 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
# compute the correct complex derivative in reverse mode by propagating the conjugate return values
# then subtracting twice the imaginary component to get the correct result

for (k, v) in seen
Compiler.recursive_accumulate(k, v, refn_seed)
end
for (k, v) in seen2
Compiler.recursive_accumulate(k, v, imfn_seed)
end
Compiler.accumulate_seen!(refn_seed, seen)
Compiler.accumulate_seen!(imfn_seed, seen2)

fused = fuse_complex_results(results, args...)

Expand Down
5 changes: 5 additions & 0 deletions src/analyses/activity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,11 @@
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
end

Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world)::Bool where {T}
rt = Enzyme.Compiler.active_reg_inner(T, (), world)
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState

Check warning on line 432 in src/analyses/activity.jl

View check run for this annotation

Codecov / codecov/patch

src/analyses/activity.jl#L430-L432

Added lines #L430 - L432 were not covered by tests
end

"""
Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode)

Expand Down
2 changes: 1 addition & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ const JuliaGlobalNameMap = Dict{String,Any}(
include("absint.jl")
include("llvm/transforms.jl")
include("llvm/passes.jl")
include("typeutils/make_zero.jl")
include("typeutils/recursive_maps.jl")

function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt)
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world)
Expand Down
50 changes: 1 addition & 49 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,47 +253,6 @@
return EnzymeRules.AugmentedReturn(primal, shadow, shadow)
end


# Accumulate the array `from` into the array `into`, element by element.
# Returns a `(into, from)` pair. Types reported as guaranteed-const are passed
# through unchanged. `seen` memoizes the result per `into` object; the entry is
# recorded before visiting the elements, so an array that is reached again while
# it is being processed is not traversed a second time.
@inline function accumulate_into(
    into::RT,
    seen::IdDict,
    from::RT,
)::Tuple{RT,RT} where {RT<:Array}
    Enzyme.Compiler.guaranteed_const(RT) && return (into, from)
    if !haskey(seen, into)
        seen[into] = (into, from)
        for idx in eachindex(from)
            # Each element yields an (accumulated, remainder) pair that is written
            # back into the corresponding slots of `into` and `from`.
            (acc, rest) = accumulate_into(into[idx], seen, from[idx])
            @inbounds into[idx] = acc
            @inbounds from[idx] = rest
        end
    end
    return seen[into]
end

# Scalar case of `accumulate_into`: returns `(into + from, RT(0))`.
# The pair is memoized in `seen` under the key `into`, so a later call with the
# same `into` value returns the stored pair regardless of `from`.
@inline function accumulate_into(
    into::RT,
    seen::IdDict,
    from::RT,
)::Tuple{RT,RT} where {RT<:AbstractFloat}
    if haskey(seen, into)
        return seen[into]
    end
    total = into + from
    seen[into] = (total, RT(0))
    return seen[into]
end

# Fallback `accumulate_into` method for types without a more specific method.
# Types reported as guaranteed-const are passed through as `(into, from)`.
# Otherwise a result is only available if `seen` already has an entry for
# `into`; if it does not, an `AssertionError` is thrown.
@inline function accumulate_into(into::RT, seen::IdDict, from::RT)::Tuple{RT,RT} where {RT}
    Enzyme.Compiler.guaranteed_const(RT) && return (into, from)
    haskey(seen, into) ||
        throw(AssertionError("Unknown type to accumulate into: $RT"))
    return seen[into]
end

function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::Const{typeof(Base.deepcopy)},
Expand All @@ -302,15 +261,8 @@
x::Annotation{Ty},
) where {RT,Ty}
if EnzymeRules.needs_shadow(config)
if EnzymeRules.width(config) == 1
accumulate_into(x.dval, IdDict(), shadow)
else
for i = 1:EnzymeRules.width(config)
accumulate_into(x.dval[i], IdDict(), shadow[i])
end
end
Compiler.accumulate_into!(x.dval, shadow)

Check warning on line 264 in src/internal_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/internal_rules.jl#L264

Added line #L264 was not covered by tests
end

return (nothing,)
end

Expand Down
Loading
Loading