From a041e536bff1a917ab0da992f0e48c18709c988a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 15:43:39 -0700 Subject: [PATCH] feat: use sleefpirates for activation functions on CPU --- Project.toml | 2 ++ src/LuxLib.jl | 1 + src/api/activation.jl | 7 +++++++ src/api/bias_activation.jl | 4 ++++ src/impl/activation.jl | 33 ++++++++++++++++++++++++++++++--- src/impl/bias_activation.jl | 11 +++++------ src/impl/normalization.jl | 9 +++------ 7 files changed, 52 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 95d604c7..7297d338 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b" @@ -61,6 +62,7 @@ Random = "1.10" ReTestItems = "1.23.1" Reexport = "1" ReverseDiff = "1.15" +SLEEFPirates = "0.6.43" StableRNGs = "1" Statistics = "1.10" Test = "1.10" diff --git a/src/LuxLib.jl b/src/LuxLib.jl index d15fcce6..78f3bc76 100644 --- a/src/LuxLib.jl +++ b/src/LuxLib.jl @@ -17,6 +17,7 @@ using NNlib: NNlib, ConvDims, conv, conv!, relu, gelu, sigmoid_fast, swish, σ, using Random: Random, AbstractRNG, rand! using Reexport: @reexport using Statistics: Statistics, mean, var +using SLEEFPirates: SLEEFPirates using UnrolledUtilities: unrolled_any, unrolled_all, unrolled_filter @reexport using NNlib diff --git a/src/api/activation.jl b/src/api/activation.jl index 5bb791d2..6b06bda0 100644 --- a/src/api/activation.jl +++ b/src/api/activation.jl @@ -10,6 +10,13 @@ generic implementation. This function doesn't replace `σ` with `NNlib.fast_act(σ, ...)`, that needs to be done by the user if needed. +!!! tip + + Certain activation functions are replaced with specialized implementations from + [SLEEFPirates.jl](https://github.com/JuliaSIMD/SLEEFPirates.jl). This might lead to + faster performance but can cause slight decrease in accuracy (in the floating point + limit). + ## Arguments - `σ`: Activation function diff --git a/src/api/bias_activation.jl b/src/api/bias_activation.jl index 68bb5372..73b74c2b 100644 --- a/src/api/bias_activation.jl +++ b/src/api/bias_activation.jl @@ -10,6 +10,8 @@ single last dimension. - `σ`: Activation function - `x`: Input to be transformed - `bias`: Bias to be added. Can be `nothing`. + +See also [`bias_activation!!`](@ref), [`fast_activation!!`](@ref). """ function bias_activation(σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} _bias_act_check(x, bias) @@ -22,6 +24,8 @@ end Same as [`bias_activation`](@ref) but might update `x` in-place if possible. Users should not rely on `x` being mutated, it is recommended to use it like `y = bias_activation!!(σ, x, bias)`. If `x` is updated in-place, `y` aliases `x`. + +See also [`bias_activation`](@ref), [`fast_activation!!`](@ref). """ function bias_activation!!( σ::F, x::AbstractArray, bias::Optional{<:AbstractVector}) where {F} diff --git a/src/impl/activation.jl b/src/impl/activation.jl index 264e30f5..ab966dad 100644 --- a/src/impl/activation.jl +++ b/src/impl/activation.jl @@ -24,10 +24,11 @@ _fast_activation(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp - RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) + σ_sleef = __sleefpirates_activation(σ) + RT = Core.Compiler._return_type(σ_sleef, Tuple{eltype(x)}) y = similar(x, RT) @simd ivdep for I in eachindex(y, x) - @inbounds y[I] = σ(x[I]) + @inbounds y[I] = σ_sleef(x[I]) end return y end @@ -43,8 +44,9 @@ _fast_activation!(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F} if internal_operation_mode(x) isa LoopedArrayOp + σ_sleef = __sleefpirates_activation(σ) @simd ivdep for I in eachindex(x) - @inbounds x[I] = σ(x[I]) + @inbounds x[I] = σ_sleef(x[I]) end return x end @@ -81,3 +83,28 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fast_activation!! return CRC.rrule_via_ad(cfg, _fast_activation, σ, x) end + +# Specialized functions that use SLEEFPirates.jl to speed up the activation functions +sigmoid_fast_sleefpirates(x) = SLEEFPirates.sigmoid_fast(x) +softplus_sleefpirates(x) = SLEEFPirates.softplus(x) +logsigmoid_sleefpirates(x) = -softplus_sleefpirates(-x) +elu_sleefpirates(x, α=1) = SLEEFPirates.Elu(α)(x) +gelu_sleefpirates(x) = SLEEFPirates.gelu(x) +swish_sleefpirates(x) = Base.FastMath.mul_fast(x, sigmoid_fast_sleefpirates(x)) +lisht_sleefpirates(x) = Base.FastMath.mul_fast(x, tanh_fast_sleefpirates(x)) +tanh_sleefpirates(x) = SLEEFPirates.tanh(x) +tanh_fast_sleefpirates(x) = SLEEFPirates.tanh_fast(x) + +# Convert to SLEEFPirates.jl +__sleefpirates_activation(f::F, ::Type{T}) where {F, T} = f +__sleefpirates_activation(f::F, ::Type{Float32}) where {F} = __sleefpirates_activation(f) +__sleefpirates_activation(f::F, ::Type{Float64}) where {F} = __sleefpirates_activation(f) + +for (fbase, ffast) in ((NNlib.sigmoid_fast, sigmoid_fast_sleefpirates), + (NNlib.softplus, softplus_sleefpirates), (NNlib.logsigmoid, logsigmoid_sleefpirates), + (NNlib.elu, elu_sleefpirates), (NNlib.gelu, gelu_sleefpirates), + (NNlib.swish, swish_sleefpirates), (NNlib.lisht, lisht_sleefpirates), + (NNlib.tanh, tanh_sleefpirates), (NNlib.tanh_fast, tanh_fast_sleefpirates)) + @eval __sleefpirates_activation(::typeof($fbase)) = $ffast +end +__sleefpirates_activation(f::F) where {F} = f diff --git a/src/impl/bias_activation.jl b/src/impl/bias_activation.jl index 7009bdac..b711d558 100644 --- a/src/impl/bias_activation.jl +++ b/src/impl/bias_activation.jl @@ -15,15 +15,13 @@ end function __generic_bias_activation( ::typeof(identity), x::AbstractArray{<:Number}, bias::AbstractVector{<:Number}) - bias_ = __reshape_bias_into_xdims(x, bias) - return broadcast(+, x, bias_) + return broadcast(+, x, __reshape_bias_into_xdims(x, bias)) end __generic_bias_activation(::typeof(identity), x::AbstractArray{<:Number}, ::Nothing) = x __generic_bias_activation(σ::F, x::AbstractArray{<:Number}, ::Nothing) where {F} = σ.(x) function __generic_bias_activation( σ::F, x::AbstractArray{<:Number, N}, bias::AbstractVector{<:Number}) where {F, N} - bias_ = __reshape_bias_into_xdims(x, bias) - return broadcast(σ ∘ +, x, bias_) + return broadcast(σ ∘ +, x, __reshape_bias_into_xdims(x, bias)) end # Entry Points to the implementation @@ -121,7 +119,8 @@ function __bias_activation_impl!( opmode = internal_operation_mode((y, x, bias)) bias_ = __reshape_bias_into_xdims(x, bias) if opmode isa LoopedArrayOp - bc = Broadcast.instantiate(Broadcast.broadcasted(σ ∘ +, x, bias_)) + σ_sleef = __sleefpirates_activation(σ) + bc = Broadcast.instantiate(Broadcast.broadcasted(σ_sleef ∘ +, x, bias_)) @simd ivdep for I in eachindex(bc) @inbounds y[I] = bc[I] end @@ -131,7 +130,7 @@ function __bias_activation_impl!( broadcast!(+, y, x, bias_) return y end - broadcast!(σ ∘ +, y, x, bias) + broadcast!(σ ∘ +, y, x, bias_) return y end diff --git a/src/impl/normalization.jl b/src/impl/normalization.jl index 0e34cb83..a603cbed 100644 --- a/src/impl/normalization.jl +++ b/src/impl/normalization.jl @@ -18,9 +18,9 @@ function __update_statistics(opmode, rμ, rσ², μ, σ², m1, m2) return rμ2, rσ²2 end function __update_statistics!(::LoopedArrayOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) - @inbounds @simd ivdep for I in eachindex(rμ2, rσ²2) - rμ2[I] = m3 * rμ[I] + m1 * μ[I] - rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] + @simd ivdep for I in eachindex(rμ2, rσ²2) + @inbounds rμ2[I] = m3 * rμ[I] + m1 * μ[I] + @inbounds rσ²2[I] = m3 * rσ²[I] + m2 * σ²[I] end end function __update_statistics!(::GPUBroadcastOp, rμ2, rσ²2, rμ, rσ², μ, σ², m1, m2, m3) @@ -38,7 +38,6 @@ end end CRC.@non_differentiable __update_statistics(::Any...) -# EnzymeRules.inactive_noinl(::typeof(__update_statistics), ::Any...) = nothing function _update_normalization_statistics( x::AbstractArray{T, N}, rμ::AbstractArray{<:Number, N}, @@ -54,8 +53,6 @@ function _update_normalization_statistics( end CRC.@non_differentiable _update_normalization_statistics(::Any...) -# NOTE: The following leads to mixed activity not sure why -# EnzymeRules.inactive_noinl(::typeof(_update_normalization_statistics), ::Any...) = nothing __accum_size(x, ::Val{dims}) where {dims} = prod(Base.Fix1(size, x), dims)