test: remove Zygote type inference testing
avik-pal committed Jan 24, 2025
1 parent 27b0cba commit 06fe321
Showing 16 changed files with 8 additions and 118 deletions.
2 changes: 1 addition & 1 deletion lib/LuxLib/src/impl/conv.jl
@@ -179,7 +179,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(fused_conv),
z, ∇bias_activation = CRC.rrule_via_ad(cfg, bias_activation, act, y, bias)
∇fused_conv_cached = @closure Δ -> begin
old_threads = maybe_reduce_BLAS_threads(weight)
Δ = NNlib.colmajor(Δ)
Δ = NNlib.colmajor(recursive_unthunk(Δ))
_, _, ∂y, ∂b = ∇bias_activation(Δ)
∂w, ∂x, _ = ∇conv_bias(∂y, ∂b, weight, x, bias, cdims)
reset_BLAS_threads(old_threads)
11 changes: 3 additions & 8 deletions lib/LuxLib/src/utils.jl
@@ -7,7 +7,7 @@ using Functors: Functors
using ForwardDiff: ForwardDiff
using KernelAbstractions: KernelAbstractions
using LinearAlgebra: LinearAlgebra, BLAS
using MLDataDevices: get_device_type, CPUDevice
using MLDataDevices: MLDataDevices, get_device_type, CPUDevice
using NNlib: NNlib
using Static: Static, StaticBool, False, True, static
using StaticArraysCore: SVector, SMatrix
@@ -207,9 +207,8 @@ expand_batchdim(x::AbstractVector) = reshape(x, :, 1)
expand_batchdim(x::SVector{L, T}) where {L, T} = SMatrix{L, 1, T}(x)

function CRC.rrule(::typeof(expand_batchdim), x::AbstractMatrix)
proj_x = CRC.ProjectTo(x)
∇expand_batchdim = @closure Δ -> begin
return ∂∅, proj_x(view(Δ, :, :, 1))
return ∂∅, CRC.@thunk(CRC.ProjectTo(x)(view(recursive_unthunk(Δ), :, :, 1)))
end
return expand_batchdim(x), ∇expand_batchdim
end
@@ -347,10 +346,6 @@ CRC.@non_differentiable can_loopvec_args_check(::Any...)

EnzymeRules.inactive_noinl(::typeof(can_loopvec_args_check), ::Any...) = nothing

unthunk_leaf(x) = Functors.isleaf(x)
unthunk_leaf(x::CRC.AbstractThunk) = true
unthunk_leaf(::CRC.AbstractTangent) = true

recursive_unthunk(x) = Functors.fmap(CRC.unthunk, x; exclude=unthunk_leaf)
recursive_unthunk(x) = Functors.fmap(CRC.unthunk, x; exclude=MLDataDevices.isleaf)

end
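For orientation, a minimal, self-contained sketch (not part of this commit; the value Δ below is hypothetical) of what the reworked recursive_unthunk does once the MLDataDevices ChainRulesCore extension further down is loaded: thunked cotangents are treated as leaves and materialized in a single pass.

    using ChainRulesCore, Functors, MLDataDevices

    # Same definition as in the diff above, written with fully qualified names:
    # thunks and tangent types are leaves (via MLDataDevices.isleaf), so fmap
    # stops there and unthunk materializes them.
    recursive_unthunk(x) = Functors.fmap(ChainRulesCore.unthunk, x; exclude=MLDataDevices.isleaf)

    # A lazily-computed cotangent, as Zygote/ChainRules rules may produce.
    Δ = ChainRulesCore.@thunk(2.0 .* ones(3))

    recursive_unthunk(Δ)   # materializes the thunk -> [2.0, 2.0, 2.0]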
6 changes: 0 additions & 6 deletions lib/LuxLib/test/common_ops/activation_tests.jl
@@ -37,12 +37,6 @@
@jet apply_act_fast(f, x)
@jet apply_act_fast2(f, x)

@test @inferred(Zygote.gradient(apply_act, f, x)) isa Any
if f !== lisht
@test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any
end
@test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any

@test_gradients(apply_act, f, x; atol, rtol)
@test_gradients(apply_act_fast, f, x; atol, rtol, skip_backends=[AutoEnzyme()])
@test_gradients(apply_act_fast2, f, x; atol, rtol)
5 changes: 0 additions & 5 deletions lib/LuxLib/test/common_ops/bias_act_tests.jl
@@ -44,11 +44,6 @@
@jet bias_act_loss2(act, x, b)
@jet bias_act_loss3(act, x, b)

if act !== lisht && T != Float16
@test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any
@test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any
end

@test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol,
soft_fail=fp16 ? [AutoFiniteDiff()] : [])
@test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol,
14 changes: 0 additions & 14 deletions lib/LuxLib/test/common_ops/conv_tests.jl
@@ -46,20 +46,6 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,
@test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any
@jet fused_conv_bias_activation(activation, weight, x, bias, cdims)

if mode != "amdgpu" && activation !== anonact
@test @inferred(Zygote.gradient(
sumabs2conv, activation, weight, x, bias, cdims
)) isa Any
else
try
@inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims))
@test true
catch e
e isa ErrorException || rethrow()
@test_broken false
end
end

skip_backends = []
mp = Tx != Tw
mp && push!(skip_backends, AutoReverseDiff())
4 changes: 0 additions & 4 deletions lib/LuxLib/test/common_ops/dense_tests.jl
@@ -33,10 +33,6 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu
atol = 1.0f-3
rtol = 1.0f-3

if activation !== anonact
@test @inferred(Zygote.gradient(sumabs2dense, activation, w, x, bias)) isa Any
end

skip_backends = []
Tw != Tx && push!(skip_backends, AutoReverseDiff())

14 changes: 0 additions & 14 deletions lib/LuxLib/test/common_ops/dropout_tests.jl
@@ -23,9 +23,6 @@
@jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims)))
@test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any

__f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims)))
@test @inferred(Zygote.gradient(__f, x)) isa Any

@test_gradients(sumabs2first,
dropout, rng, x, T(0.5), Val(true), T(2), dims; atol=1.0f-3, rtol=1.0f-3)

@@ -67,10 +64,6 @@ end
@test rng != rng_
@test mask != mask_

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any

@test_gradients(sumabs2first,
dropout, rng, x, LuxTestUtils.Constant(mask), T(0.5), Val(true), Val(true),
T(2), :; atol=1.0f-3, rtol=1.0f-3)
@@ -92,10 +85,6 @@ end
@test rng == rng_
@test mask == mask_

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any

@test_gradients(sumabs2first,
dropout, rng, x, LuxTestUtils.Constant(mask),
T(0.5), Val(true), Val(false), T(2), :;
@@ -145,9 +134,6 @@ end

@test_broken std(y) ≈ std(x) atol=1.0f-2 rtol=1.0f-2

__f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true))))
@test @inferred(Zygote.gradient(__f, x)) isa Any

@test_gradients(sumabs2first,
alpha_dropout, rng, x, T(0.5), Val(true); atol=1.0f-3, rtol=1.0f-3)

7 changes: 0 additions & 7 deletions lib/LuxLib/test/normalization/batchnorm_tests.jl
@@ -87,13 +87,6 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act,
soft_fail=[AutoFiniteDiff()],
skip_backends=[AutoEnzyme()], enzyme_set_runtime_activity=true)
end

if anonact !== act
lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm(
x, sc, b, rm, rv, tr, act, ϵ)))
@test @inferred(Zygote.gradient(
lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any
end
end

const ALL_TEST_CONFIGS = Iterators.product(
5 changes: 0 additions & 5 deletions lib/LuxLib/test/normalization/groupnorm_tests.jl
@@ -61,11 +61,6 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu)
@test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any
@jet groupnorm(x, scale, bias, groups, act, epsilon)

if anonact !== act
lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ))
@test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any
end

@test y isa aType{T, length(sz)}
@test size(y) == sz

12 changes: 0 additions & 12 deletions lib/LuxLib/test/normalization/instancenorm_tests.jl
@@ -27,11 +27,6 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)
@test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any
@jet instancenorm(x, scale, bias, training, act, epsilon)

if anonact !== act && is_training(training)
lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ)))
@test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any
end

@test y isa aType{T, length(sz)}
@test size(y) == sz

@@ -50,13 +45,6 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)
x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa Any
@jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)

if anonact !== act && is_training(training)
lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm(
x, sc, b, rm, rv, Val(true), act, m, ϵ)))
@test @inferred(Zygote.gradient(
lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)) isa Any
end

@test y isa aType{T, length(sz)}
@test size(y) == sz

5 changes: 0 additions & 5 deletions lib/LuxLib/test/normalization/layernorm_tests.jl
@@ -58,11 +58,6 @@ function run_layernorm_testing_core(

@test_gradients(sumabs2layernorm, x, scale, bias, act, dims, epsilon; atol, rtol,
soft_fail=[AutoFiniteDiff()])

if anonact !== act
lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ))
@test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any
end
end

anonact = x -> x^3
4 changes: 2 additions & 2 deletions lib/MLDataDevices/ext/MLDataDevicesChainRulesCoreExt.jl
@@ -29,7 +29,7 @@ function ChainRulesCore.rrule(
end
end

MLDataDevices.isleaf(::ChainRulesCore.ZeroTangent) = true
MLDataDevices.isleaf(::ChainRulesCore.NoTangent) = true
MLDataDevices.isleaf(::ChainRulesCore.AbstractTangent) = true
MLDataDevices.isleaf(::ChainRulesCore.AbstractThunk) = true

end
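As a rough illustration (an assumption-laden sketch, not part of the diff): with both packages loaded, the extension above marks every ChainRulesCore tangent and thunk as a leaf, so Functors-based traversals such as recursive_unthunk stop at them instead of recursing into their fields.

    using ChainRulesCore, MLDataDevices   # loading both activates the extension

    MLDataDevices.isleaf(NoTangent())             # true — AbstractTangent method
    MLDataDevices.isleaf(ZeroTangent())           # true — AbstractTangent method
    MLDataDevices.isleaf(@thunk(2 .* ones(3)))    # true — AbstractThunk method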
7 changes: 2 additions & 5 deletions src/utils.jl
@@ -12,6 +12,7 @@ using Static: Static, StaticBool, StaticInteger, StaticSymbol
using StaticArraysCore: SMatrix, SVector

using LuxCore: LuxCore, AbstractLuxLayer
using MLDataDevices: MLDataDevices
using NNlib: NNlib

const CRC = ChainRulesCore
@@ -231,11 +232,7 @@ end
calculate_gain(::typeof(NNlib.leakyrelu), x) = typeof(x)(√(2 / (1 + x^2)))
calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4

unthunk_leaf(x) = Functors.isleaf(x)
unthunk_leaf(x::CRC.AbstractThunk) = true
unthunk_leaf(::CRC.AbstractTangent) = true

recursive_unthunk(x) = Functors.fmap(CRC.unthunk, x; exclude=unthunk_leaf)
recursive_unthunk(x) = Functors.fmap(CRC.unthunk, x; exclude=MLDataDevices.isleaf)

end

3 changes: 0 additions & 3 deletions test/helpers/compact_tests.jl
@@ -350,9 +350,6 @@
@test st_new.incr == 100

@test @inferred(model(x, ps, st)) isa Any

__f = (m, x, ps, st) -> sum(abs2, first(m(x, ps, st)))
@test @inferred(Zygote.gradient(__f, model, x, ps, st)) isa Any
end

@testset "Multiple @return" begin
21 changes: 0 additions & 21 deletions test/helpers/loss_tests.jl
@@ -74,8 +74,6 @@ end
@test loss_sum(ŷ, y) ≈ loss_res * 4
@test loss_sum2(ŷ, y) ≈ loss_res * 4

@test @inferred(Zygote.gradient(loss_mean, ŷ, y)) isa Any

@jet loss_mean(ŷ, y)
@jet loss_sum(ŷ, y)

@@ -90,8 +88,6 @@

@jet MSLELoss()(ŷ, y)

@test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any broken=ongpu

@test_gradients(Base.Fix2(MSLELoss(), y), ŷ; atol=1.0f-3, rtol=1.0f-3)
end
end
@@ -148,8 +144,6 @@
@jet celoss(ŷ, y)
@jet celoss_smooth(ŷ, y)

@test @inferred(Zygote.gradient(celoss, ŷ, y)) isa Any

@test_gradients(Base.Fix2(celoss, y), ŷ; atol=1.0f-3,
rtol=1.0f-3, skip_backends=VERSION ≥ v"1.11-" ? [AutoEnzyme()] : [])
end
@@ -171,8 +165,6 @@
@jet logitceloss(logŷ, y)
@jet logitceloss_smooth(logŷ, y)

@test @inferred(Zygote.gradient(logitceloss, logŷ, y)) isa Any

@test_gradients(Base.Fix2(logitceloss, y), logŷ; atol=1.0f-3,
rtol=1.0f-3, skip_backends=VERSION ≥ v"1.11-" ? [AutoEnzyme()] : [])
end
@@ -199,8 +191,6 @@
@jet bceloss(σ.(logŷ), y)
@jet bceloss_smooth(σ.(logŷ), y)

@test @inferred(Zygote.gradient(bceloss, σ.(logŷ), y)) isa Any

@test_gradients(Base.Fix2(bceloss, y), σ.(logŷ); atol=1.0f-3, rtol=1.0f-3,
enzyme_set_runtime_activity=true)
end
@@ -220,8 +210,6 @@
@jet logitbceloss(logŷ, y)
@jet logitbceloss_smooth(logŷ, y)

@test @inferred(Zygote.gradient(logitbceloss, logŷ, y)) isa Any

@test_gradients(Base.Fix2(logitbceloss, y), logŷ; atol=1.0f-3, rtol=1.0f-3,
enzyme_set_runtime_activity=true)
end
@@ -243,8 +231,6 @@

@jet BinaryFocalLoss()(ŷ, y)

@test @inferred(Zygote.gradient(BinaryFocalLoss(), ŷ, y)) isa Any broken=ongpu

@test_gradients(Base.Fix2(BinaryFocalLoss(), y), ŷ; atol=1.0f-3, rtol=1.0f-3)
end

@@ -266,8 +252,6 @@

@jet FocalLoss()(ŷ, y)

@test @inferred(Zygote.gradient(FocalLoss(), ŷ, y)) isa Any broken=ongpu

__f = Base.Fix2(FocalLoss(), y)
# FD will lead to out of domain errors
broken_backends = if VERSION ≥ v"1.11-"
@@ -297,7 +281,6 @@ end
@test KLDivergenceLoss()(y, y) ≈ 0

@jet KLDivergenceLoss()(ŷ, y)
@test @inferred(Zygote.gradient(KLDivergenceLoss(), ŷ, y)) isa Any

@test_gradients(Base.Fix2(KLDivergenceLoss(), y), ŷ; atol=1.0f-3,
rtol=1.0f-3, skip_backends=VERSION ≥ v"1.11-" ? [AutoEnzyme()] : [])
@@ -311,7 +294,6 @@ end
@test Lux.HingeLoss()(y, 0.5 .* y) ≈ 0.125

@jet Lux.HingeLoss()(ŷ, y)
@test @inferred(Zygote.gradient(Lux.HingeLoss(), ŷ, y)) isa Any

__f = Base.Fix2(Lux.HingeLoss(), y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
@@ -325,7 +307,6 @@ end
@test SquaredHingeLoss()(y, 0.5 .* y) ≈ 0.0625

@jet SquaredHingeLoss()(ŷ, y)
@inferred Zygote.gradient(SquaredHingeLoss(), ŷ, y)

__f = Base.Fix2(SquaredHingeLoss(), y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
@@ -339,7 +320,6 @@ end
@test Lux.PoissonLoss()(y, y) ≈ 0.5044459776946685

@jet Lux.PoissonLoss()(ŷ, y)
@test @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y) isa Any

__f = Base.Fix2(Lux.PoissonLoss(), y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3)
@@ -353,7 +333,6 @@ end
@test DiceCoeffLoss()(y, y) ≈ 0.0

@jet DiceCoeffLoss()(ŷ, y)
@test @inferred(Zygote.gradient(DiceCoeffLoss(), ŷ, y)) isa Any broken=true

__f = Base.Fix2(DiceCoeffLoss(), y)
@test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3,
6 changes: 0 additions & 6 deletions test/zygote_type_stability.jl
@@ -77,12 +77,6 @@ include("setup_modes.jl")

@test @inferred(model(x, ps, Lux.testmode(st))) isa Any
@test @inferred(loss_function(model, x, ps, Lux.testmode(st))) isa Number
if mode == "amdgpu" && model isa Conv
@test_broken @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa
Any
else
@test @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa Any
end
end
end
end
