Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: use sleefpirates for activation functions on CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jul 21, 2024
1 parent 8d68102 commit a041e53
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 15 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"

Expand Down Expand Up @@ -61,6 +62,7 @@ Random = "1.10"
ReTestItems = "1.23.1"
Reexport = "1"
ReverseDiff = "1.15"
SLEEFPirates = "0.6.43"
StableRNGs = "1"
Statistics = "1.10"
Test = "1.10"
Expand Down
1 change: 1 addition & 0 deletions src/LuxLib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ,
using Random: Random, AbstractRNG, rand!
using Reexport: @reexport
using Statistics: Statistics, mean, var
using SLEEFPirates: SLEEFPirates
using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter

@reexport using NNlib
Expand Down
7 changes: 7 additions & 0 deletions src/api/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ generic implementation.
This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be
done by the user if needed.
!!! tip
Certain activation functions are replaced with specialized implementations from
[SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl). This might lead to
faster performance but can cause a slight decrease in accuracy (in the floating point
limit).
## Arguments
- `σ`: Activation function
Expand Down
4 changes: 4 additions & 0 deletions src/api/bias_activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ single last dimension.
- `σ`: Activation function
- `x`: Input to be transformed
- `bias`: Bias to be added. Can be `nothing`.
See also [`bias_activation!!`](@ref), [`fast_activation!!`](@ref).
"""
function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F}
_bias_act_check(x, bias)
Expand All @@ -22,6 +24,8 @@ end
Same as [`bias_activation`](@ref) but might update `x` in-place if possible. Users should
not rely on `x` being mutated, it is recommended to use it like
`y = bias_activation!!(σ, x, bias)`. If `x` is updated in-place, `y` aliases `x`.
See also [`bias_activation`](@ref), [`fast_activation!!`](@ref).
"""
function bias_activation!!(
σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F}
Expand Down
33 changes: 30 additions & 3 deletions src/impl/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ _fast_activation(::typeof(identity), x::AbstractArray) = x

@stable default_mode="disable" function _fast_activation(σ::F, x::AbstractArray) where {F}
if internal_operation_mode(x) isa LoopedArrayOp
RT = Core.Compiler._return_type(σ, Tuple{eltype(x)})
σ_sleef = __sleefpirates_activation(σ)
RT = Core.Compiler._return_type(σ_sleef, Tuple{eltype(x)})
y = similar(x, RT)
@simd ivdep for I in eachindex(y, x)
@inbounds y[I] = σ(x[I])
@inbounds y[I] = σ_sleef(x[I])
end
return y
end
Expand All @@ -43,8 +44,9 @@ _fast_activation!(::typeof(identity), x::AbstractArray) = x

@stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F}
if internal_operation_mode(x) isa LoopedArrayOp
σ_sleef = __sleefpirates_activation(σ)
@simd ivdep for I in eachindex(x)
@inbounds x[I] = σ(x[I])
@inbounds x[I] = σ_sleef(x[I])
end
return x
end
Expand Down Expand Up @@ -81,3 +83,28 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!!

return CRC.rrule_via_ad(cfg, _fast_activation, σ, x)
end

# Scalar activation kernels backed by SLEEFPirates.jl. Each wrapper mirrors the
# contract of the corresponding NNlib activation; they exist so the CPU fast paths
# can swap in SLEEF-based implementations (see `__sleefpirates_activation`).
sigmoid_fast_sleefpirates(x) = SLEEFPirates.sigmoid_fast(x)
softplus_sleefpirates(x) = SLEEFPirates.softplus(x)
# log(sigmoid(x)) rewritten as -softplus(-x), matching NNlib's logsigmoid identity
logsigmoid_sleefpirates(x) = -softplus_sleefpirates(-x)
elu_sleefpirates(x, α=1) = SLEEFPirates.Elu(α)(x)
gelu_sleefpirates(x) = SLEEFPirates.gelu(x)
tanh_sleefpirates(x) = SLEEFPirates.tanh(x)
tanh_fast_sleefpirates(x) = SLEEFPirates.tanh_fast(x)
# Composite activations built from the primitives above; `mul_fast` permits
# fastmath-style reassociation of the final multiply
swish_sleefpirates(x) = Base.FastMath.mul_fast(x, sigmoid_fast_sleefpirates(x))
lisht_sleefpirates(x) = Base.FastMath.mul_fast(x, tanh_fast_sleefpirates(x))

# Map an NNlib activation function to its SLEEFPirates-backed wrapper.
#
# Two-argument form gates the substitution on the array element type: only
# `Float32` and `Float64` dispatch into the replacement table; every other
# eltype returns the activation unchanged.
__sleefpirates_activation(f::F, ::Type{Float32}) where {F} = __sleefpirates_activation(f)
__sleefpirates_activation(f::F, ::Type{Float64}) where {F} = __sleefpirates_activation(f)
__sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f

# One-argument form: activations without a SLEEFPirates equivalent fall through
# unchanged; the `@eval` loop below adds a `typeof`-specific method per known pair.
__sleefpirates_activation(f::F) where {F} = f

for (base_fn, sleef_fn) in (
        (NNlib.sigmoid_fast, sigmoid_fast_sleefpirates),
        (NNlib.softplus, softplus_sleefpirates),
        (NNlib.logsigmoid, logsigmoid_sleefpirates),
        (NNlib.elu, elu_sleefpirates),
        (NNlib.gelu, gelu_sleefpirates),
        (NNlib.swish, swish_sleefpirates),
        (NNlib.lisht, lisht_sleefpirates),
        (NNlib.tanh, tanh_sleefpirates),
        (NNlib.tanh_fast, tanh_fast_sleefpirates))
    @eval __sleefpirates_activation(::typeof($base_fn)) = $sleef_fn
end
11 changes: 5 additions & 6 deletions src/impl/bias_activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@ end

function __generic_bias_activation(
::typeof(identity), x::AbstractArray{<:Number}, bias::AbstractVector{<:Number})
bias_ = __reshape_bias_into_xdims(x, bias)
return broadcast(+, x, bias_)
return broadcast(+, x, __reshape_bias_into_xdims(x, bias))
end
__generic_bias_activation(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x
__generic_bias_activation(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} = σ.(x)
function __generic_bias_activation(
σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N}
bias_ = __reshape_bias_into_xdims(x, bias)
    return broadcast(σ ∘ +, x, bias_)
    return broadcast(σ ∘ +, x, __reshape_bias_into_xdims(x, bias))
end

# Entry Points to the implementation
Expand Down Expand Up @@ -121,7 +119,8 @@ function __bias_activation_impl!(
opmode = internal_operation_mode((y, x, bias))
bias_ = __reshape_bias_into_xdims(x, bias)
if opmode isa LoopedArrayOp
        bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_))
σ_sleef = __sleefpirates_activation(σ)
        bc = Broadcast.instantiate(Broadcast.broadcasted(σ_sleef ∘ +, x, bias_))
@simd ivdep for I in eachindex(bc)
@inbounds y[I] = bc[I]
end
Expand All @@ -131,7 +130,7 @@ function __bias_activation_impl!(
broadcast!(+, y, x, bias_)
return y
end
    broadcast!(σ ∘ +, y, x, bias)
    broadcast!(σ ∘ +, y, x, bias_)
return y
end

Expand Down
9 changes: 3 additions & 6 deletions src/impl/normalization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2)
return rμ2, rσ²2
end
function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3)
@inbounds @simd ivdep for I in eachindex(rμ2, rσ²2)
rμ2[I] = m3 * rμ[I] + m1 * μ[I]
rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I]
@simd ivdep for I in eachindex(rμ2, rσ²2)
@inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I]
@inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I]
end
end
function __update_statistics!(::GPUBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3)
Expand All @@ -38,7 +38,6 @@ end
end

CRC.@non_differentiable __update_statistics(::Any...)
# EnzymeRules.inactive_noinl(::typeof(__update_statistics), ::Any...) = nothing

function _update_normalization_statistics(
x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N},
Expand All @@ -54,8 +53,6 @@ function _update_normalization_statistics(
end

CRC.@non_differentiable _update_normalization_statistics(::Any...)
# NOTE: The following leads to mixed activity not sure why
# EnzymeRules.inactive_noinl(::typeof(_update_normalization_statistics), ::Any...) = nothing

__accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims)

Expand Down

0 comments on commit a041e53

Please sign in to comment.