revert: activations from SLEEFPirates
avik-pal committed Jul 22, 2024
1 parent f2df920 commit 232b48e
Showing 6 changed files with 32 additions and 63 deletions.
2 changes: 0 additions & 2 deletions Project.toml
@@ -18,7 +18,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
-SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"

@@ -62,7 +61,6 @@ Random = "1.10"
 ReTestItems = "1.23.1"
 Reexport = "1"
 ReverseDiff = "1.15"
-SLEEFPirates = "0.6.43"
 StableRNGs = "1"
 Statistics = "1.10"
 Test = "1.10"
1 change: 0 additions & 1 deletion src/LuxLib.jl
@@ -17,7 +17,6 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ,
 using Random: Random, AbstractRNG, rand!
 using Reexport: @reexport
 using Statistics: Statistics, mean, var
-using SLEEFPirates: SLEEFPirates
 using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter
 
 @reexport using NNlib
7 changes: 0 additions & 7 deletions src/api/activation.jl
@@ -10,13 +10,6 @@ generic implementation.
 This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be
 done by the user if needed.
 
-!!! tip
-
-    Certain activation functions are replaced with specialized implementations from
-    [SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl). This might lead to
-    faster performance but can cause slight decrease in accuracy (in the floating point
-    limit).
-
 ## Arguments
 
 - `σ`: Activation function
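The docstring above leaves the `NNlib.fast_act(σ, ...)` substitution to the caller. A minimal usage sketch of that workflow (assuming LuxLib's exported `fast_activation!!` and NNlib's `fast_act`; this example is illustrative and not part of the commit):

```julia
using LuxLib, NNlib

x = randn(Float32, 8, 4)

# Opt into the fast variant explicitly, e.g. tanh -> NNlib.tanh_fast for this array type.
σ_fast = NNlib.fast_act(tanh, x)

# fast_activation!! may reuse the memory of `x` when that is safe, so treat `x` as stale afterwards.
y = fast_activation!!(σ_fast, x)
```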
33 changes: 3 additions & 30 deletions src/impl/activation.jl
@@ -24,11 +24,10 @@ _fast_activation(::typeof(identity), x::AbstractArray) = x
 
 @stable default_mode="disable" function _fast_activation(σ::F, x::AbstractArray) where {F}
     if internal_operation_mode(x) isa LoopedArrayOp
-        σ_sleef = __sleefpirates_activation(σ)
-        RT = Core.Compiler._return_type(σ_sleef, Tuple{eltype(x)})
+        RT = Core.Compiler._return_type(σ, Tuple{eltype(x)})
         y = similar(x, RT)
         @simd ivdep for I in eachindex(y, x)
-            @inbounds y[I] = σ_sleef(x[I])
+            @inbounds y[I] = σ(x[I])
         end
         return y
     end
@@ -44,9 +43,8 @@ _fast_activation!(::typeof(identity), x::AbstractArray) = x
 
 @stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F}
     if internal_operation_mode(x) isa LoopedArrayOp
-        σ_sleef = __sleefpirates_activation(σ)
         @simd ivdep for I in eachindex(x)
-            @inbounds x[I] = σ_sleef(x[I])
+            @inbounds x[I] = σ(x[I])
         end
         return x
     end
@@ -83,28 +81,3 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!
 
     return CRC.rrule_via_ad(cfg, _fast_activation, σ, x)
 end
-
-# Specialized functions that use SLEEFPirates.jl to speed up the activation functions
-sigmoid_fast_sleefpirates(x) = SLEEFPirates.sigmoid_fast(x)
-softplus_sleefpirates(x) = SLEEFPirates.softplus(x)
-logsigmoid_sleefpirates(x) = -softplus_sleefpirates(-x)
-elu_sleefpirates(x, α=1) = SLEEFPirates.Elu(α)(x)
-gelu_sleefpirates(x) = SLEEFPirates.gelu(x)
-swish_sleefpirates(x) = Base.FastMath.mul_fast(x, sigmoid_fast_sleefpirates(x))
-lisht_sleefpirates(x) = Base.FastMath.mul_fast(x, tanh_fast_sleefpirates(x))
-tanh_sleefpirates(x) = SLEEFPirates.tanh(x)
-tanh_fast_sleefpirates(x) = SLEEFPirates.tanh_fast(x)
-
-# Convert to SLEEFPirates.jl
-__sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f
-__sleefpirates_activation(f::F, ::Type{Float32}) where {F} = __sleefpirates_activation(f)
-__sleefpirates_activation(f::F, ::Type{Float64}) where {F} = __sleefpirates_activation(f)
-
-for (fbase, ffast) in ((NNlib.sigmoid_fast, sigmoid_fast_sleefpirates),
-    (NNlib.softplus, softplus_sleefpirates), (NNlib.logsigmoid, logsigmoid_sleefpirates),
-    (NNlib.elu, elu_sleefpirates), (NNlib.gelu, gelu_sleefpirates),
-    (NNlib.swish, swish_sleefpirates), (NNlib.lisht, lisht_sleefpirates),
-    (Base.tanh, tanh_sleefpirates), (NNlib.tanh_fast, tanh_fast_sleefpirates))
-    @eval __sleefpirates_activation(::typeof($fbase)) = $ffast
-end
-__sleefpirates_activation(f::F) where {F} = f
44 changes: 24 additions & 20 deletions src/impl/affine_normalize.jl
@@ -58,11 +58,11 @@ end
 
 function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F,
         x::AbstractArray{<:Number, 4}, μ, σ², ::Nothing, ::Nothing, ϵ::Real) where {F}
-    @simd ivdep for L in axes(y, 4)
-        for K in axes(y, 3), J in axes(y, 2)
-            @inbounds _sc = inv(sqrt(σ²[1, 1, K, L] + ϵ))
-            @inbounds _bc = -μ[1, 1, K, L] * _sc
-            for I in axes(y, 1)
+    for L in axes(y, 4), K in axes(y, 3)
+        @inbounds _sc = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ))
+        @inbounds _bc = -μ[1, 1, K, L] * _sc
+        for J in axes(y, 2)
+            @simd ivdep for I in axes(y, 1)
                 @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc)
             end
         end
@@ -73,11 +73,12 @@
 function __affine_normalize_gn_impl!(::LoopedArrayOp, y::AbstractArray{<:Number, 4}, f::F,
         x::AbstractArray{<:Number, 4}, μ, σ², scale::AbstractArray{<:Number, 4},
         bias::AbstractArray{<:Number, 4}, ϵ::Real) where {F}
-    @simd ivdep for L in axes(y, 4)
-        for K in axes(y, 3), J in axes(y, 2)
-            @inbounds _sc = scale[1, J, K, 1] / sqrt(σ²[1, 1, K, L] + ϵ)
+    for L in axes(y, 4), K in axes(y, 3)
+        @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ))
+        for J in axes(y, 2)
+            @inbounds _sc = scale[1, J, K, 1] * idenom
             @inbounds _bc = muladd(-μ[1, 1, K, L], _sc, bias[1, J, K, 1])
-            for I in axes(y, 1)
+            @simd ivdep for I in axes(y, 1)
                 @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], _sc, _bc)
             end
         end
@@ -181,11 +182,11 @@ function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², ::Nothi
     ∂x, ∂μ, ∂σ² = similar(x), zero.(μ), zero.(σ²)
     half = eltype(∂σ²)(0.5)
 
-    @simd ivdep for L in axes(∂y, 4)
-        for K in axes(∂y, 3)
-            @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ))
-            idenom² = idenom^2
-            for J in axes(∂y, 2), I in axes(∂y, 1)
+    for L in axes(∂y, 4), K in axes(∂y, 3)
+        @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ))
+        idenom² = idenom^2
+        for J in axes(∂y, 2)
+            @simd for I in axes(∂y, 1)
                 @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L]
 
                 @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * idenom
@@ -194,20 +195,23 @@
             end
         end
     end
 
     return ∂x, ∂μ, ∂σ², ∂∅, ∂∅
 end
 
 function ∇affine_normalize_gn_impl(::LoopedArrayOp, ∂y, x, μ, σ², scale, bias, ϵ)
     ∂x, ∂μ, ∂σ², ∂sc, ∂b = similar(x), zero.(μ), zero.(σ²), zero.(scale), zero.(bias)
     half = eltype(∂σ²)(0.5)
 
-    @simd ivdep for L in axes(∂y, 4)
-        for K in axes(∂y, 3)
-            @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ))
-            idenom² = idenom^2
-            for J in axes(∂y, 2), I in axes(∂y, 1)
+    for L in axes(∂y, 4), K in axes(∂y, 3)
+        @inbounds idenom = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ))
+        idenom² = idenom^2
+        for J in axes(∂y, 2)
+            @inbounds _sc = scale[1, J, K, 1] * idenom
+            @simd for I in axes(∂y, 1)
                 @inbounds xμ = x[I, J, K, L] - μ[1, 1, K, L]
 
-                @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * scale[1, J, K, 1] * idenom
+                @inbounds ∂x[I, J, K, L] = ∂y[I, J, K, L] * _sc
                 @inbounds ∂μ[1, 1, K, L] -= ∂x[I, J, K, L]
                 @inbounds ∂σ²[1, 1, K, L] -= ∂x[I, J, K, L] * xμ * half * idenom²
                 @inbounds ∂sc[1, J, K, 1] += ∂y[I, J, K, L] * xμ * idenom
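The affine_normalize changes above share one pattern: hoist `inv(sqrt(σ² + ϵ))` out of the inner loops with `@fastmath`, and apply `@simd ivdep` only to the innermost, unit-stride dimension. A self-contained sketch of that loop shape, using illustrative names rather than the package's internal functions:

```julia
# Group-wise normalization of a 4D array; statistics μ and σ² are shaped (1, 1, K, L),
# matching the unscaled path in the diff above.
function normalize_groups!(y::AbstractArray{<:Real, 4}, x::AbstractArray{<:Real, 4},
        μ::AbstractArray{<:Real, 4}, σ²::AbstractArray{<:Real, 4}, ϵ::Real)
    for L in axes(y, 4), K in axes(y, 3)
        @inbounds sc = @fastmath inv(sqrt(σ²[1, 1, K, L] + ϵ))  # one inverse sqrt per group
        @inbounds bc = -μ[1, 1, K, L] * sc
        for J in axes(y, 2)
            @simd ivdep for I in axes(y, 1)  # vectorize the contiguous first dimension
                @inbounds y[I, J, K, L] = muladd(x[I, J, K, L], sc, bc)
            end
        end
    end
    return y
end
```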
8 changes: 5 additions & 3 deletions src/impl/bias_activation.jl
@@ -119,8 +119,7 @@ function __bias_activation_impl!(
     opmode = internal_operation_mode((y, x, bias))
     bias_ = __reshape_bias_into_xdims(x, bias)
     if opmode isa LoopedArrayOp
-        σ_sleef = __sleefpirates_activation(σ)
-        bc = Broadcast.instantiate(Broadcast.broadcasted(σ_sleef ∘ +, x, bias_))
+        bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_))
         @simd ivdep for I in eachindex(bc)
             @inbounds y[I] = bc[I]
         end
@@ -143,7 +142,10 @@ function __apply_bias_activation_cached!!(
     if can_setindex(x)
         opmode = internal_operation_mode((x, bias))
         if opmode isa LoopedArrayOp
-            y = broadcast(+, x, bias_)
+            bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_))
+            @simd ivdep for I in eachindex(bc)
+                @inbounds x[I] = bc[I]
+            end
             return _fast_activation(σ, x), x
         end
         broadcast!(+, x, x, bias_)
