Skip to content

Commit

Permalink
Add jlarrays test (EnzymeAD#2222)
Browse files Browse the repository at this point in the history
* Add jlarrays test

* Update Project.toml

* Update EnzymeGPUArraysCoreExt.jl

* Update sugar.jl

* Update EnzymeGPUArraysCoreExt.jl

* Update EnzymeGPUArraysCoreExt.jl

* Update jlarrays.jl

* Update jlarrays.jl
  • Loading branch information
wsmoses authored Dec 24, 2024
1 parent d014f66 commit dacc208
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 9 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ EnzymeCore = "0.8.8"
Enzyme_jll = "0.0.170"
GPUArraysCore = "0.1.6, 0.2"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1"
GPUArraysCore = "0.1.6, 0.2"
LLVM = "6.1, 7, 8, 9"
LogExpFunctions = "0.3"
ObjectFile = "0.4"
Expand Down
29 changes: 21 additions & 8 deletions ext/EnzymeGPUArraysCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,6 @@ module EnzymeGPUArraysCoreExt
using GPUArraysCore
using Enzyme

@inline function Enzyme.onehot(x::AbstractGPUArray)
Enzyme.onehot_internal(zerosetfn, x, 0, length(x))
end

@inline function Enzyme.onehot(x::AbstractGPUArray, start::Int, endl::Int)
Enzyme.onehot_internal(zerosetfn, x, start-1, endl-start+1)
end

function Enzyme.zerosetfn(x::AbstractGPUArray, i::Int)
res = zero(x)
@allowscalar @inbounds res[i] = 1
Expand All @@ -22,4 +14,25 @@ function Enzyme.zerosetfn!(x::AbstractGPUArray, i::Int, val)
return
end

@inline function Enzyme.onehot(x::AbstractGPUArray)
# Enzyme.onehot_internal(Enzyme.zerosetfn, x, 0, length(x))
N = length(x)
ntuple(Val(N)) do i
Base.@_inline_meta
res = zero(x)
@allowscalar @inbounds res[i] = 1
return res
end
end

@inline function onehot(x::AbstractArray, start::Int, endl::Int)
# Enzyme.onehot_internal(Enzyme.zerosetfn, x, start-1, endl-start+1)
ntuple(Val(endl - start + 1)) do i
Base.@_inline_meta
res = zero(x)
@allowscalar @inbounds res[i + start - 1] = 1
return res
end
end

end # module
2 changes: 1 addition & 1 deletion src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function zerosetfn!(x, i::Int, val)
nothing
end

@generated function onehot_internal(fn::F, x::T, startv::Int, lengthv::Int) where {F, T<:Array}
@generated function onehot_internal(fn::F, x::T, startv::Int, lengthv::Int) where {F, T<:AbstractArray}
ir = GPUCompiler.JuliaContext() do ctx
Base.@_inline_meta

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
11 changes: 11 additions & 0 deletions test/ext/jlarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using Enzyme, Test, JLArrays

function jlres(x)
2 * collect(x)
end

@testset "JLArrays" begin
# TODO fix activity of jlarray
# Enzyme.jacobian(Forward, jlres, JLArray([3.0, 5.0]))
# Enzyme.jacobian(Reverse, jlres, JLArray([3.0, 5.0]))
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3783,5 +3783,6 @@ include("ext/logexpfunctions.jl")
include("ext/bfloat16s.jl")
end

include("ext/jlarrays.jl")
include("ext/sparsearrays.jl")
include("ext/staticarrays.jl")

0 comments on commit dacc208

Please sign in to comment.