From ac0517012a99f7c1da16094ea4c3ec6c6bc9cfc1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 21 Jul 2024 22:55:10 -0700 Subject: [PATCH] refactor: move turbo into single function --- src/api/activation.jl | 5 ++++- src/impl/activation.jl | 35 +++++++++++++++++------------------ 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/api/activation.jl b/src/api/activation.jl index 5bb791d2..0e05e74a 100644 --- a/src/api/activation.jl +++ b/src/api/activation.jl @@ -27,4 +27,7 @@ function _fast_activation!!(::Val{true}, σ::F, x::AbstractArray) where {F} return _fast_activation(σ, x) end -_fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F} = _fast_activation!(σ, x) +function _fast_activation!!(::Val{false}, σ::F, x::AbstractArray) where {F} + _fast_activation!(σ, x) + return x +end diff --git a/src/impl/activation.jl b/src/impl/activation.jl index 65a2eb76..0b83e03f 100644 --- a/src/impl/activation.jl +++ b/src/impl/activation.jl @@ -19,19 +19,24 @@ function __activation_gradient(Δ, out, act::F, x) where {F} return broadcast(only_deriv, Δ, out, x) end +function _fast_activation!( + ::LoopedArrayOp, y::AbstractArray, σ::F, x::AbstractArray) where {F} + @turbo for I in eachindex(y, x) + @inbounds y[I] = σ(x[I]) + end +end +function _fast_activation!(opmode, y::AbstractArray, σ::F, x::AbstractArray) where {F} + broadcast!(σ, y, x) + return +end + # Entry Points to the implementation _fast_activation(::typeof(identity), x::AbstractArray) = x @stable default_mode="disable" function _fast_activation(σ::F, x::AbstractArray) where {F} - if internal_operation_mode(x) isa LoopedArrayOp - RT = Core.Compiler._return_type(σ, Tuple{eltype(x)}) - y = similar(x, RT) - @turbo for I in eachindex(y, x) - @inbounds y[I] = σ(x[I]) - end - return y - end - return broadcast(σ, x) + y = similar(x, Core.Compiler._return_type(σ, Tuple{eltype(x)})) + _fast_activation!(internal_operation_mode(x), y, σ, x) + return y end function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_activation), @@ -39,17 +44,11 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_activation) return CRC.rrule_via_ad(cfg, broadcast, σ, x) end -_fast_activation!(::typeof(identity), x::AbstractArray) = x +_fast_activation!(::typeof(identity), x::AbstractArray) = nothing @stable default_mode="disable" function _fast_activation!(σ::F, x::AbstractArray) where {F} - if internal_operation_mode(x) isa LoopedArrayOp - @turbo for I in eachindex(x) - @inbounds x[I] = σ(x[I]) - end - return x - end - broadcast!(σ, x, x) - return x + _fast_activation!(internal_operation_mode(x), x, σ, x) + return nothing end # Define rrule for `fast_activation!!`