Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using Zygote: Zygote
using ForwardDiff: ForwardDiff
using ReverseDiff: ReverseDiff
using FiniteDifferences: FiniteDifferences
using Enzyme: Enzyme
using Compat: only

using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils
Expand Down
25 changes: 22 additions & 3 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ gradient(f, s::Symbol, args) = gradient(f, Val(s), args)
function gradient(f, ::Val{:Zygote}, args)
g = only(Zygote.gradient(f, args))
if isnothing(g)
# To respect the same output as other ADs
if args isa AbstractArray{<:Real}
return zeros(size(args)) # To respect the same output as other ADs
return zeros(size(args))
else
return zeros.(size.(args))
end
Expand All @@ -57,6 +58,24 @@ function gradient(f, ::Val{:Zygote}, args)
end
end

function gradient(f, ::Val{:EnzymeForward}, args)
# shape = size(args)
# f_prime(flatargs) = f(reshape(flatargs, shape...))
# return Enzyme.gradient(Enzyme.Forward, f_prime, reshape(args, prod(shape)))
d_args = zero(args)
Enzyme.autodiff(Enzyme.Forward, f, Enzyme.Active, Enzyme.Duplicated(args, d_args))
return d_args
end

function gradient(f, ::Val{:EnzymeReverse}, args)
# shape = size(args)
# f_prime(flatargs) = f(reshape(flatargs, shape...))
# return Enzyme.gradient(Enzyme.Reverse, f_prime, reshape(args, prod(shape)))
d_args = zero(args)
Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, Enzyme.Duplicated(args, d_args))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you need the Enzyme.Active but not sure...

return d_args
end

function gradient(f, ::Val{:ForwardDiff}, args)
return ForwardDiff.gradient(f, args)
end
Expand Down Expand Up @@ -90,7 +109,7 @@ testdiagfunction(k::MOKernel, A) = sum(kernelmatrix_diag(k, A))
testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B))

function test_ADs(
kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3]
kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse, :EnzymeForward], dims=[3, 3]
)
test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims)
if !test_fd.anynonpass
Expand All @@ -108,7 +127,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context())
end

function test_ADs(
k::MOKernel; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=(in=3, out=2, obs=3)
k::MOKernel; ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse, :EnzymeForward], dims=(in=3, out=2, obs=3)
)
test_fd = test_FiniteDiff(k, dims)
if !test_fd.anynonpass
Expand Down