diff --git a/.gitignore b/.gitignore index f687069..963aa6c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.DS_Store *.jl.*.cov *.jl.cov *.jl.mem diff --git a/Project.toml b/Project.toml index ca624cd..947003a 100644 --- a/Project.toml +++ b/Project.toml @@ -6,14 +6,16 @@ version = "0.6.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +DifferentiableExpectations = "fc55d66b-b2a8-4ccc-9d64-c0c2166ceb36" DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" [weakdeps] DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d" @@ -24,14 +26,16 @@ InferOptFrankWolfeExt = "DifferentiableFrankWolfe" [compat] ChainRulesCore = "1" DensityInterface = "0.4.0" +DifferentiableExpectations = "0.2" DifferentiableFrankWolfe = "0.2" -LinearAlgebra = "1" -Random = "1" +Distributions = "0.25" +DocStringExtensions = "0.9.3" +LinearAlgebra = "<0.0.1,1" +Random = "<0.0.1,1" RequiredInterfaces = "0.1.3" Statistics = "1" StatsBase = "0.33, 0.34" StatsFuns = "1.3" -ThreadsX = "0.1.11" julia = "1.10" [extras] diff --git a/docs/make.jl b/docs/make.jl index 7d81abb..811f75e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -6,11 +6,20 @@ DocMeta.setdocmeta!(InferOpt, :DocTestSetup, :(using InferOpt); recursive=true) # Copy README.md into docs/src/index.md (overwriting) -cp( - joinpath(dirname(@__DIR__), "README.md"), - joinpath(@__DIR__, "src", "index.md"); - force=true, -) +open(joinpath(@__DIR__, "src", "index.md"), "w") do io + println( + io, + """ + ```@meta + EditURL = "https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl/blob/main/README.md" + ``` + """, + ) + # Write the contents out below the meta bloc + for line in eachline(joinpath(dirname(@__DIR__), "README.md")) + println(io, line) + end +end # Parse test/tutorial.jl into docs/src/tutorial.md (overwriting) @@ -21,8 +30,14 @@ Literate.markdown(tuto_jl_file, tuto_md_dir; documenter=true, execute=false) makedocs(; modules=[InferOpt], authors="Guillaume Dalle, Léo Baty, Louis Bouvier, Axel Parmentier", + repo="https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl/blob/{commit}{path}#{line}", sitename="InferOpt.jl", - format=Documenter.HTML(), + format=Documenter.HTML(; + prettyurls=get(ENV, "CI", "false") == "true", + canonical="https://juliadecisionfocusedlearning.github.io/InferOpt.jl", + assets=String[], + repolink="https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl", + ), pages=[ "Home" => "index.md", "Background" => "background.md", diff --git a/ext/InferOptFrankWolfeExt.jl b/ext/InferOptFrankWolfeExt.jl index d96a449..3e2a43e 100644 --- a/ext/InferOptFrankWolfeExt.jl +++ b/ext/InferOptFrankWolfeExt.jl @@ -1,10 +1,11 @@ module InferOptFrankWolfeExt +using DifferentiableExpectations: + DifferentiableExpectations, FixedAtomsProbabilityDistribution using DifferentiableFrankWolfe: DifferentiableFrankWolfe, DiffFW using DifferentiableFrankWolfe: LinearMinimizationOracle # from FrankWolfe using DifferentiableFrankWolfe: IterativeLinearSolver # from 
ImplicitDifferentiation -using InferOpt: InferOpt, RegularizedFrankWolfe, FixedAtomsProbabilityDistribution -using InferOpt: compute_expectation, compute_probability_distribution +using InferOpt: InferOpt, RegularizedFrankWolfe using LinearAlgebra: dot """ diff --git a/src/InferOpt.jl b/src/InferOpt.jl index 4d62096..70aa6e2 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -10,71 +10,78 @@ module InferOpt using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, Tangent, ZeroTangent using ChainRulesCore: rrule, rrule_via_ad, unthunk using DensityInterface: logdensityof +using DifferentiableExpectations: + DifferentiableExpectations, Reinforce, empirical_predistribution, empirical_distribution +using Distributions: + Distributions, + ContinuousUnivariateDistribution, + LogNormal, + Normal, + product_distribution, + logpdf +using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES using LinearAlgebra: dot -using Random: AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed! +using Random: Random, AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed! using Statistics: mean using StatsBase: StatsBase, sample using StatsFuns: logaddexp, softmax -using ThreadsX: ThreadsX using RequiredInterfaces include("interface.jl") +include("utils/utils.jl") include("utils/some_functions.jl") -include("utils/probability_distribution.jl") include("utils/pushforward.jl") -include("utils/generalized_maximizer.jl") +include("utils/linear_maximizer.jl") include("utils/isotonic_regression/isotonic_l2.jl") include("utils/isotonic_regression/isotonic_kl.jl") include("utils/isotonic_regression/projection.jl") -include("simple/interpolation.jl") -include("simple/identity.jl") +# Layers +include("layers/simple/interpolation.jl") +include("layers/simple/identity.jl") -include("regularized/abstract_regularized.jl") -include("regularized/soft_argmax.jl") -include("regularized/sparse_argmax.jl") -include("regularized/soft_rank.jl") -include("regularized/regularized_frank_wolfe.jl") +include("layers/perturbed/utils.jl") +include("layers/perturbed/perturbation.jl") +include("layers/perturbed/perturbed.jl") -include("perturbed/abstract_perturbed.jl") -include("perturbed/additive.jl") -include("perturbed/multiplicative.jl") -include("perturbed/perturbed_oracle.jl") - -include("imitation/spoplus_loss.jl") -include("imitation/ssvm_loss.jl") -include("imitation/fenchel_young_loss.jl") -include("imitation/imitation_loss.jl") -include("imitation/zero_one_loss.jl") +include("layers/regularized/abstract_regularized.jl") +include("layers/regularized/soft_argmax.jl") +include("layers/regularized/sparse_argmax.jl") +include("layers/regularized/soft_rank.jl") +include("layers/regularized/regularized_frank_wolfe.jl") if !isdefined(Base, :get_extension) include("../ext/InferOptFrankWolfeExt.jl") end +# Losses +include("losses/fenchel_young_loss.jl") +include("losses/spoplus_loss.jl") +include("losses/ssvm_loss.jl") +include("losses/zero_one_loss.jl") +include("losses/imitation_loss.jl") + export half_square_norm export shannon_entropy, negative_shannon_entropy export one_hot_argmax, ranking -export GeneralizedMaximizer, objective_value +export LinearMaximizer, apply_g, objective_value -export FixedAtomsProbabilityDistribution -export compute_expectation -export compute_probability_distribution export Pushforward export IdentityRelaxation export Interpolation -export AbstractRegularized, AbstractRegularizedGeneralizedMaximizer +export AbstractRegularized export SoftArgmax, soft_argmax export SparseArgmax, sparse_argmax export 
SoftRank, soft_rank, soft_rank_l2, soft_rank_kl export SoftSort, soft_sort, soft_sort_l2, soft_sort_kl export RegularizedFrankWolfe +export PerturbedOracle export PerturbedAdditive export PerturbedMultiplicative -export PerturbedOracle export FenchelYoungLoss export StructuredSVMLoss diff --git a/src/imitation/fenchel_young_loss.jl b/src/imitation/fenchel_young_loss.jl deleted file mode 100644 index d82c104..0000000 --- a/src/imitation/fenchel_young_loss.jl +++ /dev/null @@ -1,169 +0,0 @@ -""" - FenchelYoungLoss <: AbstractLossLayer - -Fenchel-Young loss associated with a given optimization layer. -``` -L(θ, y_true) = (Ω(y_true) - θᵀy_true) - (Ω(ŷ) - θᵀŷ) -``` - -Reference: - -# Fields - -- `optimization_layer::AbstractOptimizationLayer`: optimization layer that can be formulated as `ŷ(θ) = argmax {θᵀy - Ω(y)}` (either regularized or perturbed) -""" -struct FenchelYoungLoss{O<:AbstractOptimizationLayer} <: AbstractLossLayer - optimization_layer::O -end - -function Base.show(io::IO, fyl::FenchelYoungLoss) - (; optimization_layer) = fyl - return print(io, "FenchelYoungLoss($optimization_layer)") -end - -## Forward pass - -""" - (fyl::FenchelYoungLoss)(θ, y_true[; kwargs...]) -""" -function (fyl::FenchelYoungLoss)(θ::AbstractArray, y_true::AbstractArray; kwargs...) - l, _ = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...) - return l -end - -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {O<:AbstractRegularized} - (; optimization_layer) = fyl - ŷ = optimization_layer(θ; kwargs...) - Ωy_true = compute_regularization(optimization_layer, y_true) - Ωŷ = compute_regularization(optimization_layer, ŷ) - l = (Ωy_true - dot(θ, y_true)) - (Ωŷ - dot(θ, ŷ)) - g = ŷ - y_true - return l, g -end - -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {O<:AbstractRegularizedGeneralizedMaximizer} - (; optimization_layer) = fyl - ŷ = optimization_layer(θ; kwargs...) - Ωy_true = compute_regularization(optimization_layer, y_true) - Ωŷ = compute_regularization(optimization_layer, ŷ) - maximizer = get_maximizer(optimization_layer) - l = - (Ωy_true - objective_value(maximizer, θ, y_true; kwargs...)) - - (Ωŷ - objective_value(maximizer, θ, ŷ; kwargs...)) - g = maximizer.g(ŷ; kwargs...) - maximizer.g(y_true; kwargs...) - return l, g -end - -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {O<:AbstractPerturbed} - (; optimization_layer) = fyl - F, almost_ŷ = fenchel_young_F_and_first_part_of_grad(optimization_layer, θ; kwargs...) - l = F - dot(θ, y_true) - g = almost_ŷ - y_true - return l, g -end - -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {P<:AbstractPerturbed{<:GeneralizedMaximizer}} - (; optimization_layer) = fyl - F, almost_g_of_ŷ = fenchel_young_F_and_first_part_of_grad( - optimization_layer, θ; kwargs... - ) - l = F - objective_value(optimization_layer.oracle, θ, y_true; kwargs...) - g = almost_g_of_ŷ - optimization_layer.oracle.g(y_true; kwargs...) - return l, g -end - -## Backward pass - -function ChainRulesCore.rrule( - fyl::FenchelYoungLoss, θ::AbstractArray, y_true::AbstractArray; kwargs... -) - l, g = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...) 
- fyl_pullback(dl) = NoTangent(), dl * g, NoTangent() - return l, fyl_pullback -end - -## Specific overrides for perturbed layers - -function compute_F_and_y_samples( - perturbed::AbstractPerturbed{O,false}, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) where {O} - F_and_y_samples = [ - fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...) for - Z in Z_samples - ] - return F_and_y_samples -end - -function compute_F_and_y_samples( - perturbed::AbstractPerturbed{O,true}, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) where {O} - return ThreadsX.map( - Z -> fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...), Z_samples - ) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::AbstractPerturbed, θ::AbstractArray; kwargs... -) - Z_samples = sample_perturbations(perturbed, θ) - F_and_y_samples = compute_F_and_y_samples(perturbed, θ, Z_samples; kwargs...) - return mean(first, F_and_y_samples), mean(last, F_and_y_samples) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedAdditive, θ::AbstractArray, Z::AbstractArray; kwargs... -) - (; oracle, ε) = perturbed - η = θ .+ ε .* Z - y = oracle(η; kwargs...) - F = dot(η, y) - return F, y -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedAdditive{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... -) where {P,G,O<:GeneralizedMaximizer} - (; oracle, ε) = perturbed - η = θ .+ ε .* Z - y = oracle(η; kwargs...) - F = objective_value(oracle, η, y; kwargs...) - return F, oracle.g(y; kwargs...) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedMultiplicative, θ::AbstractArray, Z::AbstractArray; kwargs... -) - (; oracle, ε) = perturbed - eZ = exp.(ε .* Z .- ε^2 ./ 2) - η = θ .* eZ - y = oracle(η; kwargs...) - F = dot(η, y) - y_scaled = y .* eZ - return F, y_scaled -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedMultiplicative{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... -) where {P,G,O<:GeneralizedMaximizer} - (; oracle, ε) = perturbed - eZ = exp.(ε .* Z .- ε^2) - η = θ .* eZ - y = oracle(η; kwargs...) - F = objective_value(oracle, η, y; kwargs...) - y_scaled = y .* eZ - return F, oracle.g(y_scaled; kwargs...) -end diff --git a/src/interface.jl b/src/interface.jl index cb4ac75..aeafa6e 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -18,11 +18,13 @@ abstract type AbstractLayer end Supertype for all the optimization layers defined in InferOpt. # Interface -- `(layer::AbstractOptimizationLayer)(θ; kwargs...)` +- `(layer::AbstractOptimizationLayer)(θ::AbstractArray; kwargs...)` - `compute_probability_distribution(layer, θ; kwargs...)` (only if the layer is probabilistic) """ abstract type AbstractOptimizationLayer <: AbstractLayer end +get_maximizer(layer::AbstractOptimizationLayer) = nothing + ## Losses """ @@ -45,6 +47,6 @@ abstract type AbstractLossLayer <: AbstractLayer end """ compute_probability_distribution(layer, θ; kwargs...) -Apply a probabilistic optimization layer to an objective direction `θ` in order to generate a [`FixedAtomsProbabilityDistribution`](@ref) on the vertices of a polytope. +Apply a probabilistic optimization layer to an objective direction `θ` in order to generate a `FixedAtomsProbabilityDistribution` on the vertices of a polytope. 
""" function compute_probability_distribution end diff --git a/src/layers/perturbed/perturbation.jl b/src/layers/perturbed/perturbation.jl new file mode 100644 index 0000000..1c42877 --- /dev/null +++ b/src/layers/perturbed/perturbation.jl @@ -0,0 +1,89 @@ +""" +$TYPEDEF + +Abstract type for a perturbation. +It's a function that takes a parameter `θ` and returns a perturbed parameter by a distribution `perturbation_dist`. + +All subtypes should have a `perturbation_dist` + +# Existing implementations +- [`AdditivePerturbation`](@ref) +- [`MultiplicativePerturbation`](@ref) +""" +abstract type AbstractPerturbation <: ContinuousUnivariateDistribution end + +""" +$TYPEDSIGNATURES +""" +function Random.rand(rng::AbstractRNG, perturbation::AbstractPerturbation) + return rand(rng, perturbation.perturbation_dist) +end + +""" +$TYPEDEF + +Additive perturbation: θ ↦ θ + εZ, where Z is a random variable following `perturbation_dist`. + +# Fields +$TYPEDFIELDS +""" +struct AdditivePerturbation{F} + "base distribution for the perturbation" + perturbation_dist::F + "perturbation size" + ε::Float64 +end + +""" +$TYPEDSIGNATURES + +Apply the additive perturbation to the parameter `θ`. +""" +function (pdc::AdditivePerturbation)(θ::AbstractArray) + (; perturbation_dist, ε) = pdc + return product_distribution(θ .+ ε * perturbation_dist) +end + +""" +$TYPEDSIGNATURES + +Compute the gradient of the logdensity of η = θ + εZ w.r.t. θ., with Z ∼ N(0, 1). +""" +function normal_additive_grad_logdensity(ε, η, θ) + return ((η .- θ) ./ ε^2,) +end + +""" +$TYPEDEF + +Multiplicative perturbation: θ ↦ θ ⊙ exp(εZ - ε²/2) + +# Fields +$TYPEDFIELDS +""" +struct MultiplicativePerturbation{F} + "base distribution for the perturbation" + perturbation_dist::F + "perturbation size" + ε::Float64 +end + +""" +$TYPEDSIGNATURES + +Apply the multiplicative perturbation to the parameter `θ`. +""" +function (pdc::MultiplicativePerturbation)(θ::AbstractArray) + (; perturbation_dist, ε) = pdc + return product_distribution(θ .* ExponentialOf(ε * perturbation_dist - ε^2 / 2)) +end +""" +$TYPEDSIGNATURES + +Compute the gradient of the logdensity of η = θ ⊙ exp(εZ - ε²/2) w.r.t. θ., with Z ∼ N(0, 1). +!!! warning + η should be a realization of θ, i.e. should be of the same sign. +""" +function normal_multiplicative_grad_logdensity(ε, η, θ) + return (inv.(ε^2 .* θ) .* (log.(abs.(η)) - log.(abs.(θ)) .+ (ε^2 / 2)),) +end diff --git a/src/layers/perturbed/perturbed.jl b/src/layers/perturbed/perturbed.jl new file mode 100644 index 0000000..2bbb3c1 --- /dev/null +++ b/src/layers/perturbed/perturbed.jl @@ -0,0 +1,145 @@ +""" +$TYPEDEF + +Differentiable perturbation of a black box optimizer of type `F`, with perturbation of type `D`. + +This struct is as wrapper around `Reinforce` from DifferentiableExpectations.jl. + +There are three different available constructors that behave differently in the package: +- [`PerturbedOracle`](@ref) +- [`PerturbedAdditive`](@ref) +- [`PerturbedMultiplicative`](@ref) +""" +struct PerturbedOracle{D,F,t,variance_reduction,G,R,S} <: AbstractOptimizationLayer + reinforce::Reinforce{t,variance_reduction,F,D,G,R,S} +end + +""" +$TYPEDSIGNATURES + +Forward pass of the perturbed optimizer. +""" +function (perturbed::PerturbedOracle)(θ::AbstractArray; kwargs...) + return perturbed.reinforce(θ; kwargs...) +end + +function get_maximizer(perturbed::PerturbedOracle) + return perturbed.reinforce.f +end + +function compute_probability_distribution( + perturbed::PerturbedOracle, θ::AbstractArray; kwargs... 
+) + return empirical_distribution(perturbed.reinforce, θ; kwargs...) +end + +function Base.show(io::IO, perturbed::PerturbedOracle{<:AbstractPerturbation}) + (; reinforce) = perturbed + nb_samples = reinforce.nb_samples + ε = reinforce.dist_constructor.ε + seed = reinforce.seed + rng = reinforce.rng + perturbation = reinforce.dist_constructor.perturbation_dist + f = reinforce.f + return print( + io, + "PerturbedOracle($f, ε=$ε, nb_samples=$nb_samples, perturbation=$perturbation, rng=$(typeof(rng)), seed=$seed)", + ) +end + +""" +$TYPEDSIGNATURES + +Constructor for [`PerturbedOracle`](@ref). +""" +function PerturbedOracle( + maximizer, + dist_constructor; + dist_logdensity_grad=nothing, + nb_samples=1, + variance_reduction=true, + threaded=false, + seed=nothing, + rng=Random.default_rng(), + kwargs..., +) + return PerturbedOracle( + Reinforce( + maximizer, + dist_constructor, + dist_logdensity_grad; + nb_samples, + variance_reduction, + threaded, + seed, + rng, + kwargs..., + ), + ) +end + +""" +$TYPEDSIGNATURES + +Constructor for [`PerturbedOracle`](@ref) with an additive perturbation. +""" +function PerturbedAdditive( + maximizer; + ε=1.0, + perturbation_dist=Normal(0, 1), + nb_samples=1, + variance_reduction=true, + seed=nothing, + threaded=false, + rng=Random.default_rng(), + dist_logdensity_grad=if (perturbation_dist == Normal(0, 1)) + FixFirst(normal_additive_grad_logdensity, ε) + else + nothing + end, +) + dist_constructor = AdditivePerturbation(perturbation_dist, float(ε)) + return PerturbedOracle( + maximizer, + dist_constructor; + dist_logdensity_grad, + nb_samples, + variance_reduction, + seed, + threaded, + rng, + ) +end + +""" +$TYPEDSIGNATURES + +Constructor for [`PerturbedOracle`](@ref) with a multiplicative perturbation. +""" +function PerturbedMultiplicative( + maximizer; + ε=1.0, + perturbation_dist=Normal(0, 1), + nb_samples=1, + variance_reduction=true, + seed=nothing, + threaded=false, + rng=Random.default_rng(), + dist_logdensity_grad=if (perturbation_dist == Normal(0, 1)) + FixFirst(normal_multiplicative_grad_logdensity, ε) + else + nothing + end, +) + dist_constructor = MultiplicativePerturbation(perturbation_dist, float(ε)) + return PerturbedOracle( + maximizer, + dist_constructor; + dist_logdensity_grad, + nb_samples, + variance_reduction, + seed, + threaded, + rng, + ) +end diff --git a/src/layers/perturbed/utils.jl b/src/layers/perturbed/utils.jl new file mode 100644 index 0000000..509f080 --- /dev/null +++ b/src/layers/perturbed/utils.jl @@ -0,0 +1,26 @@ +""" +$TYPEDSIGNATURES + +Data structure modeling the exponential of a continuous univariate random variable. +""" +struct ExponentialOf{D<:ContinuousUnivariateDistribution} <: + ContinuousUnivariateDistribution + dist::D +end + +""" +$TYPEDSIGNATURES +""" +function Random.rand(rng::AbstractRNG, d::ExponentialOf) + return exp(rand(rng, d.dist)) +end + +""" +$TYPEDSIGNATURES + +Return the log-density of the [`ExponentialOf`](@ref) distribution at `x`. 
+It is equal to ``logpdf(d, log(x)) - log(x)`` +""" +function Distributions.logpdf(d::ExponentialOf, x::Real) + return logpdf(d.dist, log(x)) - log(x) +end diff --git a/src/regularized/abstract_regularized.jl b/src/layers/regularized/abstract_regularized.jl similarity index 52% rename from src/regularized/abstract_regularized.jl rename to src/layers/regularized/abstract_regularized.jl index 9a2a903..27a7e80 100644 --- a/src/regularized/abstract_regularized.jl +++ b/src/layers/regularized/abstract_regularized.jl @@ -1,60 +1,33 @@ """ AbstractRegularized <: AbstractOptimizationLayer -Convex regularization perturbation of a black box linear optimizer +Convex regularization perturbation of a black box linear (in θ) optimizer ``` -ŷ(θ) = argmax_{y ∈ C} {θᵀy - Ω(y)} +ŷ(θ) = argmax_{y ∈ C} {θᵀg(y) + h(y) - Ω(y)} ``` +with g and h functions of y. # Interface - - `(regularized::AbstractRegularized)(θ; kwargs...)`: return `ŷ(θ)` -- `compute_regularization(regularized, y)`: return `Ω(y)` +- `compute_regularization(regularized, y)`: return `Ω(y) +- `get_maximizer(regularized)`: return the associated `GeneralizedMaximizer` optimizer # Available implementations - - [`SoftArgmax`](@ref) - [`SparseArgmax`](@ref) +- [`SoftRank`](@ref) - [`RegularizedFrankWolfe`](@ref) """ abstract type AbstractRegularized <: AbstractOptimizationLayer end """ - AbstractRegularizedGeneralizedMaximizer <: AbstractRegularized - -Convex regularization perturbation of a black box **generalized** optimizer -``` -ŷ(θ) = argmax_{y ∈ C} {θᵀg(y) + h(y) - Ω(y)} -with g and h functions of y. -``` - -# Interface - -- `(regularized::AbstractRegularized)(θ; kwargs...)`: return `ŷ(θ)` -- `compute_regularization(regularized, y)`: return `Ω(y)` -- `get_maximizer(regularized)`: return the associated `GeneralizedMaximizer` optimizer -""" -abstract type AbstractRegularizedGeneralizedMaximizer <: AbstractRegularized end - -""" - compute_regularization(regularized, y) + compute_regularization(regularized::AbstractRegularized, y) Return the convex penalty `Ω(y)` associated with an `AbstractRegularized` layer. """ function compute_regularization end -""" - get_maximizer(regularized) - -Return the associated optimizer. -""" -function get_maximizer end - @required AbstractRegularized begin # (regularized::AbstractRegularized)(θ::AbstractArray; kwargs...) # waiting for RequiredInterfaces to support this (see https://github.com/Seelengrab/RequiredInterfaces.jl/issues/11) compute_regularization(::AbstractRegularized, y) end - -@required AbstractRegularizedGeneralizedMaximizer begin - get_maximizer(::AbstractRegularizedGeneralizedMaximizer) -end diff --git a/src/regularized/regularized_frank_wolfe.jl b/src/layers/regularized/regularized_frank_wolfe.jl similarity index 98% rename from src/regularized/regularized_frank_wolfe.jl rename to src/layers/regularized/regularized_frank_wolfe.jl index 8f68887..bd21ef2 100644 --- a/src/regularized/regularized_frank_wolfe.jl +++ b/src/layers/regularized/regularized_frank_wolfe.jl @@ -61,5 +61,5 @@ Apply `compute_probability_distribution(regularized, θ; kwargs...)` and return """ function (regularized::RegularizedFrankWolfe)(θ::AbstractArray; kwargs...) probadist = compute_probability_distribution(regularized, θ; kwargs...) 
- return compute_expectation(probadist) + return mean(probadist) end diff --git a/src/regularized/soft_argmax.jl b/src/layers/regularized/soft_argmax.jl similarity index 100% rename from src/regularized/soft_argmax.jl rename to src/layers/regularized/soft_argmax.jl diff --git a/src/regularized/soft_rank.jl b/src/layers/regularized/soft_rank.jl similarity index 100% rename from src/regularized/soft_rank.jl rename to src/layers/regularized/soft_rank.jl diff --git a/src/regularized/sparse_argmax.jl b/src/layers/regularized/sparse_argmax.jl similarity index 100% rename from src/regularized/sparse_argmax.jl rename to src/layers/regularized/sparse_argmax.jl diff --git a/src/simple/identity.jl b/src/layers/simple/identity.jl similarity index 100% rename from src/simple/identity.jl rename to src/layers/simple/identity.jl diff --git a/src/simple/interpolation.jl b/src/layers/simple/interpolation.jl similarity index 100% rename from src/simple/interpolation.jl rename to src/layers/simple/interpolation.jl diff --git a/src/losses/fenchel_young_loss.jl b/src/losses/fenchel_young_loss.jl new file mode 100644 index 0000000..c66e1e9 --- /dev/null +++ b/src/losses/fenchel_young_loss.jl @@ -0,0 +1,109 @@ +""" +$TYPEDEF + +Fenchel-Young loss associated with a given optimization layer. +``` +L(θ, y_true) = (Ω(y_true) - θᵀy_true) - (Ω(ŷ) - θᵀŷ) +``` + +Reference: + +# Fields +- `optimization_layer::AbstractOptimizationLayer`: optimization layer that can be formulated as `ŷ(θ) = argmax {θᵀy - Ω(y)}` (either regularized or perturbed) +""" +struct FenchelYoungLoss{O<:AbstractOptimizationLayer} <: AbstractLossLayer + optimization_layer::O +end + +function Base.show(io::IO, fyl::FenchelYoungLoss) + (; optimization_layer) = fyl + return print(io, "FenchelYoungLoss($optimization_layer)") +end + +## Forward pass + +""" +$TYPEDSIGNATURES + +Compute L(θ, y_true). +""" +function (fyl::FenchelYoungLoss)(θ::AbstractArray, y_true::AbstractArray; kwargs...) + l, _ = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...) + return l +end + +function fenchel_young_loss_and_grad( + fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... +) where {O<:AbstractRegularized} + (; optimization_layer) = fyl + ŷ = optimization_layer(θ; kwargs...) + Ωy_true = compute_regularization(optimization_layer, y_true) + Ωŷ = compute_regularization(optimization_layer, ŷ) + maximizer = get_maximizer(optimization_layer) + l = + (Ωy_true - objective_value(maximizer, θ, y_true; kwargs...)) - + (Ωŷ - objective_value(maximizer, θ, ŷ; kwargs...)) + grad = apply_g(maximizer, ŷ; kwargs...) - apply_g(maximizer, y_true; kwargs...) + return l, grad +end + +function fenchel_young_loss_and_grad( + fyl::FenchelYoungLoss{<:PerturbedOracle}, + θ::AbstractArray, + y_true::AbstractArray; + kwargs..., +) + (; optimization_layer) = fyl + maximizer = get_maximizer(optimization_layer) + F, almost_ŷ = fenchel_young_F_and_first_part_of_grad(optimization_layer, θ; kwargs...) + l = F - objective_value(maximizer, θ, y_true; kwargs...) + g = almost_ŷ - apply_g(maximizer, y_true; kwargs...) + return l, g +end + +## Backward pass + +function ChainRulesCore.rrule( + fyl::FenchelYoungLoss, θ::AbstractArray, y_true::AbstractArray; kwargs... +) + l, g = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...) 
+ fyl_pullback(dl) = NoTangent(), dl * g, NoTangent() + return l, fyl_pullback +end + +## Specific overrides for perturbed layers + +function fenchel_young_F_and_first_part_of_grad( + perturbed::PerturbedOracle{<:AdditivePerturbation}, θ::AbstractArray; kwargs... +) + (; reinforce) = perturbed + maximizer = get_maximizer(perturbed) + η_dist = empirical_predistribution(reinforce, θ) + fk = FixKwargs(maximizer, kwargs) + gk = Fix1Kwargs(apply_g, maximizer, kwargs) + y_dist = map(fk, η_dist) + F = mean( + objective_value(maximizer, η, y; kwargs...) for + (η, y) in zip(η_dist.atoms, y_dist.atoms) + ) + ŷ = mean(gk, y_dist) + return F, ŷ +end + +function fenchel_young_F_and_first_part_of_grad( + perturbed::PerturbedOracle{<:MultiplicativePerturbation}, θ::AbstractArray; kwargs... +) + (; reinforce) = perturbed + maximizer = get_maximizer(perturbed) + η_dist = empirical_predistribution(reinforce, θ) + fk = FixKwargs(reinforce.f, kwargs) + gk = Fix1Kwargs(apply_g, maximizer, kwargs) + y_dist = map(fk, η_dist) + eZ_dist = map(Base.Fix2(./, θ), η_dist) + F = mean( + objective_value(maximizer, η, y; kwargs...) for + (η, y) in zip(η_dist.atoms, y_dist.atoms) + ) + almost_ŷ = mean(gk.(map(.*, eZ_dist.atoms, y_dist.atoms))) + return F, almost_ŷ +end diff --git a/src/imitation/imitation_loss.jl b/src/losses/imitation_loss.jl similarity index 100% rename from src/imitation/imitation_loss.jl rename to src/losses/imitation_loss.jl diff --git a/src/imitation/spoplus_loss.jl b/src/losses/spoplus_loss.jl similarity index 65% rename from src/imitation/spoplus_loss.jl rename to src/losses/spoplus_loss.jl index c572842..46f0cc3 100644 --- a/src/imitation/spoplus_loss.jl +++ b/src/losses/spoplus_loss.jl @@ -1,16 +1,17 @@ """ - SPOPlusLoss <: AbstractLossLayer +$TYPEDEF Convex surrogate of the Smart "Predict-then-Optimize" loss. # Fields -- `maximizer`: linear maximizer function of the form `θ -> ŷ(θ) = argmax θᵀy` -- `α::Float64`: convexification parameter, default = 2.0 +$TYPEDFIELDS Reference: """ struct SPOPlusLoss{F} <: AbstractLossLayer + "linear maximizer function of the form `θ -> ŷ(θ) = argmax θᵀy`" maximizer::F + "convexification parameter, default = 2.0" α::Float64 end @@ -20,14 +21,20 @@ function Base.show(io::IO, spol::SPOPlusLoss) end """ - SPOPlusLoss(maximizer; α=2.0) +$TYPEDSIGNATURES + +Constructor for [`SPOPlusLoss`](@ref). """ SPOPlusLoss(maximizer; α=2.0) = SPOPlusLoss(maximizer, float(α)) ## Forward pass """ - (spol::SPOPlusLoss)(θ, θ_true, y_true; kwargs...) +$TYPEDSIGNATURES + +Forward pass of the SPO+ loss with given target `θ_true` and `y_true`. +The third argument `y_true` is optional, as it can be computed from `θ_true`. +However, providing it directly can save computation time. """ function (spol::SPOPlusLoss)( θ::AbstractArray, θ_true::AbstractArray, y_true::AbstractArray; kwargs... @@ -35,17 +42,7 @@ function (spol::SPOPlusLoss)( (; maximizer, α) = spol θ_α = α * θ - θ_true y_α = maximizer(θ_α; kwargs...) - l = dot(θ_α, y_α) - dot(θ_α, y_true) - return l -end - -function (spol::SPOPlusLoss{<:GeneralizedMaximizer})( - θ::AbstractArray, θ_true::AbstractArray, y_true::AbstractArray; kwargs... -) - (; maximizer, α) = spol - θ_α = α * θ - θ_true - y_α = maximizer(θ_α; kwargs...) - # This only works in theory if α = 2 + # In theory, in the general case with a LinearMaximizer, this only works if α = 2 l = objective_value(maximizer, θ_α, y_α; kwargs...) - objective_value(maximizer, θ_α, y_true; kwargs...) 
@@ -53,7 +50,10 @@ function (spol::SPOPlusLoss{<:GeneralizedMaximizer})( end """ - (spol::SPOPlusLoss)(θ, θ_true; kwargs...) +$TYPEDSIGNATURES + +Forward pass of the SPO+ loss with given target `θ_true`. +For better performance, you can also provide `y_true` directly as a third argument. """ function (spol::SPOPlusLoss)(θ::AbstractArray, θ_true::AbstractArray; kwargs...) y_true = spol.maximizer(θ_true; kwargs...) @@ -68,20 +68,6 @@ function compute_loss_and_gradient( θ_true::AbstractArray, y_true::AbstractArray; kwargs..., -) - (; maximizer, α) = spol - θ_α = α * θ - θ_true - y_α = maximizer(θ_α; kwargs...) - l = dot(θ_α, y_α) - dot(θ_α, y_true) - return l, α .* (y_α .- y_true) -end - -function compute_loss_and_gradient( - spol::SPOPlusLoss{<:GeneralizedMaximizer}, - θ::AbstractArray, - θ_true::AbstractArray, - y_true::AbstractArray; - kwargs..., ) (; maximizer, α) = spol θ_α = α * θ - θ_true @@ -89,7 +75,7 @@ function compute_loss_and_gradient( l = objective_value(maximizer, θ_α, y_α; kwargs...) - objective_value(maximizer, θ_α, y_true; kwargs...) - g = α .* (maximizer.g(y_α; kwargs...) - maximizer.g(y_true; kwargs...)) + g = α .* (apply_g(maximizer, y_α; kwargs...) - apply_g(maximizer, y_true; kwargs...)) return l, g end diff --git a/src/imitation/ssvm_loss.jl b/src/losses/ssvm_loss.jl similarity index 100% rename from src/imitation/ssvm_loss.jl rename to src/losses/ssvm_loss.jl diff --git a/src/imitation/zero_one_loss.jl b/src/losses/zero_one_loss.jl similarity index 100% rename from src/imitation/zero_one_loss.jl rename to src/losses/zero_one_loss.jl diff --git a/src/perturbed/abstract_perturbed.jl b/src/perturbed/abstract_perturbed.jl deleted file mode 100644 index 5ac197f..0000000 --- a/src/perturbed/abstract_perturbed.jl +++ /dev/null @@ -1,153 +0,0 @@ -""" - AbstractPerturbed{F,parallel} <: AbstractOptimizationLayer - -Differentiable perturbation of a black box optimizer of type `F`. - -The parameter `parallel` is a boolean value indicating if the perturbations are run in parallel. -This is particularly useful if your black box optimizer running time is high. - -# Available implementations: -- [`PerturbedAdditive`](@ref) -- [`PerturbedMultiplicative`](@ref) -- [`PerturbedOracle`](@ref) - -# These three subtypes share the following fields: -- `oracle`: black box (optimizer) -- `perturbation::P`: perturbation distribution of the input θ -- `grad_logdensity::G`: gradient of the log density `perturbation` w.r.t. input θ -- `nb_samples::Int`: number of perturbation samples drawn at each forward pass -- `seed::Union{Nothing,Int}`: seed of the perturbation. - It is reset each time the forward pass is called, - making it deterministic by always drawing the same perturbations. - If you do not want this behaviour, set this field to `nothing`. -- `rng::AbstractRNG`: random number generator using the `seed`. - -!!! warning - The `perturbation` field does not mean the same thing for a [`PerturbedOracle`](@ref) - than for a [`PerturbedAdditive`](@ref)/[`PerturbedMultiplicative`](@ref). See their respective docs. -""" -abstract type AbstractPerturbed{O,parallel} <: AbstractOptimizationLayer end - -# Non parallelized version -function compute_atoms( - perturbed::AbstractPerturbed{O,false}, η_samples::Vector{<:AbstractArray}; kwargs... -) where {O} - return [perturbed.oracle(η; kwargs...) for η in η_samples] -end - -# Parallelized version -function compute_atoms( - perturbed::AbstractPerturbed{O,true}, η_samples::Vector{<:AbstractArray}; kwargs... 
-) where {O} - return ThreadsX.map(η -> perturbed.oracle(η; kwargs...), η_samples) -end - -""" - sample_perturbations(perturbed::AbstractPerturbed, θ::AbstractArray) - -Draw `nb_samples` random perturbations from the `perturbation` distribution. -""" -function sample_perturbations end - -""" - perturbation_grad_logdensity( - ::RuleConfig, - ::AbstractPerturbed, - θ::AbstractArray, - sample::AbstractArray, - ) - -Compute de gradient w.r.t to the input `θ` of the logdensity of the perturbed input -distribution evaluated in the observed perturbation sample `η`. -""" -function perturbation_grad_logdensity end - -""" - compute_probability_distribution_from_samples( - ::AbstractPerturbed, - θ::AbstractArray, - samples::Vector{<:AbstractArray}; - kwargs..., - ) - -Create a probability distributions from `samples` drawn from `perturbation`. -""" -function compute_probability_distribution_from_samples end - -""" - compute_probability_distribution(perturbed::AbstractPerturbed, θ; kwargs...) - -Turn random perturbations of `θ` into a distribution on polytope vertices. - -Keyword arguments are passed to the underlying linear maximizer. -""" -function compute_probability_distribution( - perturbed::AbstractPerturbed, - θ::AbstractArray; - autodiff_variance_reduction::Bool=false, - kwargs..., -) - η_samples = sample_perturbations(perturbed, θ) - return compute_probability_distribution_from_samples(perturbed, θ, η_samples; kwargs...) -end - -# Forward pass - -""" - (perturbed::AbstractPerturbed)(θ; kwargs...) - -Forward pass. Compute the expectation of the underlying distribution. -""" -function (perturbed::AbstractPerturbed)( - θ::AbstractArray; autodiff_variance_reduction::Bool=true, kwargs... -) - probadist = compute_probability_distribution( - perturbed, θ; autodiff_variance_reduction, kwargs... - ) - return compute_expectation(probadist) -end - -function perturbation_grad_logdensity( - ::RuleConfig, perturbed::AbstractPerturbed, θ::AbstractArray, η::AbstractArray -) - return perturbed.grad_logdensity(θ, η) -end - -# Backward pass - -function ChainRulesCore.rrule( - rc::RuleConfig, - ::typeof(compute_probability_distribution), - perturbed::AbstractPerturbed, - θ::AbstractArray; - autodiff_variance_reduction::Bool=true, - kwargs..., -) - η_samples = sample_perturbations(perturbed, θ) - y_dist = compute_probability_distribution_from_samples( - perturbed, θ, η_samples; kwargs... - ) - - ∇logp_samples = [perturbation_grad_logdensity(rc, perturbed, θ, η) for η in η_samples] - - M = perturbed.nb_samples - function perturbed_oracle_dist_pullback(δy_dist) - weights = y_dist.weights - δy_weights = δy_dist.weights - δy_sum = sum(δy_weights) - δθ = sum( - map(1:M) do i - δyᵢ, ∇logpᵢ, w = δy_weights[i], ∇logp_samples[i], weights[i] - if autodiff_variance_reduction - bᵢ = M == 1 ? 0 * δy_sum : (δy_sum - δyᵢ) / (M - 1) - return w * (δyᵢ - bᵢ) * ∇logpᵢ - else - return w * δyᵢ * ∇logpᵢ - end - end, - ) - return NoTangent(), NoTangent(), δθ - end - - return y_dist, perturbed_oracle_dist_pullback -end diff --git a/src/perturbed/additive.jl b/src/perturbed/additive.jl deleted file mode 100644 index e9923d5..0000000 --- a/src/perturbed/additive.jl +++ /dev/null @@ -1,128 +0,0 @@ -""" - PerturbedAdditive{P,G,O,R,S,parallel} <: AbstractPerturbed{parallel} - -Differentiable normal perturbation of a black-box maximizer: the input undergoes `θ -> θ + εZ` where `Z ∼ N(0, I)`. 
- -This [`AbstractOptimizationLayer`](@ref) is compatible with [`FenchelYoungLoss`](@ref), -if the oracle is an optimization maximizer with a linear objective. - -Reference: - -See [`AbstractPerturbed`](@ref) for more details. - -# Specific field -- `ε:Float64`: size of the perturbation -""" -struct PerturbedAdditive{P,G,O,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <: - AbstractPerturbed{O,parallel} - perturbation::P - grad_logdensity::G - oracle::O - rng::R - seed::S - nb_samples::Int - ε::Float64 -end - -function Base.show(io::IO, perturbed::PerturbedAdditive) - (; oracle, ε, rng, seed, nb_samples, perturbation) = perturbed - perturb = isnothing(perturbation) ? "Normal(0, 1)" : "$perturbation" - return print( - io, "PerturbedAdditive($oracle, $ε, $nb_samples, $(typeof(rng)), $seed, $perturb)" - ) -end - -""" - PerturbedAdditive(oracle[; ε, nb_samples, seed, is_parallel, perturbation, grad_logdensity, rng]) - -[`PerturbedAdditive`](@ref) constructor. - -# Arguments -- `oracle`: the black-box oracle we want to differentiate through. - It should be a linear maximizer if you want to use it inside a [`FenchelYoungLoss`](@ref). - -# Keyword arguments (optional) -- `ε=1.0`: size of the perturbation. -- `nb_samples::Int=1`: number of perturbation samples drawn at each forward pass. -- `perturbation=nothing`: nothing by default. If you want to use a different distribution than a - `Normal` for the perturbation `z`, give it here as a distribution-like object implementing - the `rand` method. It should also implement `logdensityof` if `grad_logdensity` is not given. -- `grad_logdensity=nothing`: gradient function of `perturbation` w.r.t. `θ`. - If set to nothing (default), it's computed using automatic differentiation. -- `seed::Union{Nothing,Int}=nothing`: seed of the perturbation. - It is reset each time the forward pass is called, - making it deterministic by always drawing the same perturbations. - If you do not want this behaviour, set this field to `nothing`. -- `rng::AbstractRNG`=MersenneTwister(0): random number generator using the `seed`. -""" -function PerturbedAdditive( - oracle::O; - ε=1.0, - nb_samples=1, - seed::S=nothing, - is_parallel::Bool=false, - perturbation::P=nothing, - grad_logdensity::G=nothing, - rng::R=MersenneTwister(0), -) where {P,G,O,R<:AbstractRNG,S<:Union{Int,Nothing}} - return PerturbedAdditive{P,G,O,R,S,is_parallel}( - perturbation, grad_logdensity, oracle, rng, seed, nb_samples, float(ε) - ) -end - -function sample_perturbations(perturbed::PerturbedAdditive, θ::AbstractArray) - (; rng, seed, nb_samples, perturbation) = perturbed - seed!(rng, seed) - return [rand(rng, perturbation, size(θ)) for _ in 1:nb_samples] -end - -function sample_perturbations(perturbed::PerturbedAdditive{Nothing}, θ::AbstractArray) - (; rng, seed, nb_samples) = perturbed - seed!(rng, seed) - return [randn(rng, size(θ)) for _ in 1:nb_samples] -end - -function compute_probability_distribution_from_samples( - perturbed::PerturbedAdditive, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) - (; ε) = perturbed - η_samples = [θ .+ ε .* Z for Z in Z_samples] - atoms = compute_atoms(perturbed, η_samples; kwargs...) 
- weights = ones(length(atoms)) ./ length(atoms) - probadist = FixedAtomsProbabilityDistribution(atoms, weights) - return probadist -end - -function perturbation_grad_logdensity( - ::RuleConfig, - perturbed::PerturbedAdditive{Nothing,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) - (; ε) = perturbed - return Z ./ ε -end - -function _perturbation_logdensity( - perturbed::PerturbedAdditive, θ::AbstractArray, η::AbstractArray -) - (; ε, perturbation) = perturbed - Z = (η .- θ) ./ ε - return sum(logdensityof(perturbation, z) for z in Z) -end - -function perturbation_grad_logdensity( - rc::RuleConfig, - perturbed::PerturbedAdditive{P,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) where {P} - (; ε) = perturbed - η = θ .+ ε .* Z - l, logdensity_pullback = rrule_via_ad(rc, _perturbation_logdensity, perturbed, θ, η) - δperturbation_logdensity, δperturbed, δθ, δη = logdensity_pullback(one(l)) - return δθ -end diff --git a/src/perturbed/multiplicative.jl b/src/perturbed/multiplicative.jl deleted file mode 100644 index f96a333..0000000 --- a/src/perturbed/multiplicative.jl +++ /dev/null @@ -1,130 +0,0 @@ -""" - PerturbedMultiplicative{P,G,O,R,S,parallel} <: AbstractPerturbed{parallel} - -Differentiable multiplicative perturbation of a black-box oracle: -the input undergoes `θ -> θ ⊙ exp[εZ - ε²/2]` where `Z ∼ perturbation`. - -This [`AbstractOptimizationLayer`](@ref) is compatible with [`FenchelYoungLoss`](@ref), -if the oracle is an optimization maximizer with a linear objective. - -Reference: - -See [`AbstractPerturbed`](@ref) for more details. - -# Specific field -- `ε:Float64`: size of the perturbation -""" -struct PerturbedMultiplicative{P,G,O,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <: - AbstractPerturbed{O,parallel} - perturbation::P - grad_logdensity::G - oracle::O - rng::R - seed::S - nb_samples::Int - ε::Float64 -end - -function Base.show(io::IO, perturbed::PerturbedMultiplicative) - (; oracle, ε, rng, seed, nb_samples, perturbation) = perturbed - perturb = isnothing(perturbation) ? "Normal(0, 1)" : "$perturbation" - return print( - io, - "PerturbedMultiplicative($oracle, $ε, $nb_samples, $(typeof(rng)), $seed, $perturb)", - ) -end - -""" - PerturbedMultiplicative(oracle[; ε, nb_samples, seed, is_parallel, perturbation, grad_logdensity, rng]) - -[`PerturbedMultiplicative`](@ref) constructor. - -# Arguments -- `oracle`: the black-box oracle we want to differentiate through. - It should be a linear maximizer if you want to use it inside a [`FenchelYoungLoss`]. - -# Keyword arguments (optional) -- `ε=1.0`: size of the perturbation. -- `nb_samples::Int=1`: number of perturbation samples drawn at each forward pass. -- `perturbation=nothing`: nothing by default. If you want to use a different distribution than a - `Normal` for the perturbation `z`, give it here as a distribution-like object implementing - the `rand` method. It should also implement `logdensityof` if `grad_logdensity` is not given. -- `grad_logdensity=nothing`: gradient function of `perturbation` w.r.t. `θ`. - If set to nothing (default), it's computed using automatic differentiation. -- `seed::Union{Nothing,Int}=nothing`: seed of the perturbation. - It is reset each time the forward pass is called, - making it deterministic by always drawing the same perturbations. - If you do not want this behaviour, set this field to `nothing`. -- `rng::AbstractRNG`=MersenneTwister(0): random number generator using the `seed`. 
-""" -function PerturbedMultiplicative( - oracle::O; - ε=1.0, - nb_samples=1, - seed::S=nothing, - is_parallel=false, - perturbation::P=nothing, - grad_logdensity::G=nothing, - rng::R=MersenneTwister(0), -) where {P,G,O,R<:AbstractRNG,S<:Union{Int,Nothing}} - return PerturbedMultiplicative{P,G,O,R,S,is_parallel}( - perturbation, grad_logdensity, oracle, rng, seed, nb_samples, float(ε) - ) -end - -function sample_perturbations(perturbed::PerturbedMultiplicative, θ::AbstractArray) - (; rng, seed, nb_samples, perturbation) = perturbed - seed!(rng, seed) - return [rand(rng, perturbation, size(θ)) for _ in 1:nb_samples] -end - -function sample_perturbations(perturbed::PerturbedMultiplicative{Nothing}, θ::AbstractArray) - (; rng, seed, nb_samples) = perturbed - seed!(rng, seed) - return [randn(rng, size(θ)) for _ in 1:nb_samples] -end - -function compute_probability_distribution_from_samples( - perturbed::PerturbedMultiplicative, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) - (; ε) = perturbed - η_samples = [θ .* exp.(ε .* Z .- ε^2 / 2) for Z in Z_samples] - atoms = compute_atoms(perturbed, η_samples; kwargs...) - weights = ones(length(atoms)) ./ length(atoms) - probadist = FixedAtomsProbabilityDistribution(atoms, weights) - return probadist -end - -function perturbation_grad_logdensity( - ::RuleConfig, - perturbed::PerturbedMultiplicative{Nothing,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) - (; ε) = perturbed - return inv.(ε .* θ) .* Z -end - -function _perturbation_logdensity( - perturbed::PerturbedMultiplicative, θ::AbstractArray, η::AbstractArray -) - (; ε, perturbation) = perturbed - Z = (log.(η) .- log.(θ)) ./ ε .+ ε / 2 - return sum(logdensityof(perturbation, z) for z in Z) -end - -function perturbation_grad_logdensity( - rc::RuleConfig, - perturbed::PerturbedMultiplicative{P,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) where {P} - (; ε) = perturbed - η = θ .* exp.(ε .* Z .- ε^2 / 2) - l, logdensity_pullback = rrule_via_ad(rc, _perturbation_logdensity, perturbed, θ, η) - δperturbation_logdensity, δperturbed, δθ, δη = logdensity_pullback(one(l)) - return δθ -end diff --git a/src/perturbed/perturbed_oracle.jl b/src/perturbed/perturbed_oracle.jl deleted file mode 100644 index 9d7cd04..0000000 --- a/src/perturbed/perturbed_oracle.jl +++ /dev/null @@ -1,92 +0,0 @@ -""" - PerturbedOracle{P,G,O,R,S,parallel} <: AbstractPerturbed{parallel} - -Differentiable perturbed black-box oracle. The `oracle` input `θ` is perturbed as `η ∼ perturbation(⋅|θ)`. -[`PerturbedAdditive`](@ref) is a special case of `PerturbedOracle` with `perturbation(θ) = MvNormal(θ, ε * I)`. -[`PerturbedMultiplicative`] is also a special case of `PerturbedOracle`. - -See [`AbstractPerturbed`](@ref) for more details about its fields. -""" -struct PerturbedOracle{P,G,O,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <: - AbstractPerturbed{O,parallel} - perturbation::P - grad_logdensity::G - oracle::O - rng::R - seed::S - nb_samples::Int -end - -""" - PerturbedOracle(perturbation, oracle[; grad_logdensity, rng, seed, is_parallel, nb_samples]) - -[`PerturbedOracle`](@ref) constructor. - -# Arguments -- `oracle`: the black-box oracle we want to differentiate through -- `perturbation`: should be a callable such that `perturbation(θ)` is a distribution-like - object that can be sampled with `rand`. - It should also implement `logdensityof` if `grad_logdensity` is not given. - -# Keyword arguments (optional) -- `grad_logdensity=nothing`: gradient function of `perturbation` w.r.t. `θ`. 
- If set to nothing (default), it's computed using automatic differentiation. -- `nb_samples::Int=1`: number of perturbation samples drawn at each forward pass -- `seed::Union{Nothing,Int}=nothing`: seed of the perturbation. - It is reset each time the forward pass is called, - making it deterministic by always drawing the same perturbations. - If you do not want this behaviour, set this field to `nothing`. -- `rng::AbstractRNG`=MersenneTwister(0): random number generator using the `seed`. - -!!! info - If you have access to the analytical expression of `grad_logdensity` it is recommended to - give it, as it will be computationally faster. -""" -function PerturbedOracle( - oracle::O, - perturbation::P; - grad_logdensity::G=nothing, - nb_samples::Int=1, - seed::S=nothing, - is_parallel::Bool=false, - rng::R=MersenneTwister(0), -) where {P,G,O,R<:AbstractRNG,S<:Union{Int,Nothing}} - return PerturbedOracle{P,G,O,R,S,is_parallel}( - perturbation, grad_logdensity, oracle, rng, seed, nb_samples - ) -end - -function Base.show(io::IO, po::PerturbedOracle) - (; oracle, perturbation, rng, seed, nb_samples) = po - return print( - io, "PerturbedOracle($perturbation, $oracle, $nb_samples, $(typeof(rng)), $seed)" - ) -end - -function sample_perturbations(po::PerturbedOracle, θ::AbstractArray) - (; rng, seed, perturbation, nb_samples) = po - seed!(rng, seed) - η_samples = [rand(rng, perturbation(θ)) for _ in 1:nb_samples] - return η_samples -end - -function compute_probability_distribution_from_samples( - perturbed::PerturbedOracle, θ, η_samples::Vector{<:AbstractArray}; kwargs... -) - atoms = compute_atoms(perturbed, η_samples; kwargs...) - weights = ones(length(atoms)) ./ length(atoms) - probadist = FixedAtomsProbabilityDistribution(atoms, weights) - return probadist -end - -function _perturbation_logdensity(po::PerturbedOracle, θ::AbstractArray, η::AbstractArray) - return logdensityof(po.perturbation(θ), η) -end - -function perturbation_grad_logdensity( - rc::RuleConfig, po::PerturbedOracle{P,Nothing}, θ::AbstractArray, η::AbstractArray -) where {P} - l, logdensity_pullback = rrule_via_ad(rc, _perturbation_logdensity, po, θ, η) - δperturbation_logdensity, δpo, δθ, δη = logdensity_pullback(one(l)) - return δθ -end diff --git a/src/utils/generalized_maximizer.jl b/src/utils/generalized_maximizer.jl deleted file mode 100644 index e1a4384..0000000 --- a/src/utils/generalized_maximizer.jl +++ /dev/null @@ -1,35 +0,0 @@ -""" - GeneralizedMaximizer{F,G,H} - -Wrapper for generalized maximizers `maximizer` of the form argmax_y θᵀg(y) + h(y). -It is compatible with the following layers -- [`PerturbedAdditive`](@ref) (with or without [`FenchelYoungLoss`](@ref)) -- [`PerturbedMultiplicative`](@ref) (with or without [`FenchelYoungLoss`](@ref)) -- [`SPOPlusLoss`](@ref) -""" -struct GeneralizedMaximizer{F,G,H} - maximizer::F - g::G - h::H -end - -GeneralizedMaximizer(f; g=identity_kw, h=zero ∘ eltype_kw) = GeneralizedMaximizer(f, g, h) - -function Base.show(io::IO, f::GeneralizedMaximizer) - (; maximizer, g, h) = f - return print(io, "GeneralizedMaximizer($maximizer, $g, $h)") -end - -# Callable calls the wrapped maximizer -function (f::GeneralizedMaximizer)(θ::AbstractArray{<:Real}; kwargs...) - return f.maximizer(θ; kwargs...) -end - -""" - objective_value(f, θ, y, kwargs...) - -Computes the objective value of given GeneralizedMaximizer `f`, knowing weights `θ` and solution `y`. -""" -function objective_value(f::GeneralizedMaximizer, θ, y; kwargs...) - return dot(θ, f.g(y; kwargs...)) .+ f.h(y; kwargs...) 
-end diff --git a/src/utils/linear_maximizer.jl b/src/utils/linear_maximizer.jl new file mode 100644 index 0000000..dd849d3 --- /dev/null +++ b/src/utils/linear_maximizer.jl @@ -0,0 +1,66 @@ +""" +$TYPEDEF + +Wrapper for generic minear maximizers of the form argmax_y θᵀg(y) + h(y). +It is compatible with the following layers +- [`PerturbedAdditive`](@ref) (with or without a [`FenchelYoungLoss`](@ref)) +- [`PerturbedMultiplicative`](@ref) (with or without a [`FenchelYoungLoss`](@ref)) +- [`SPOPlusLoss`](@ref) + +# Fields +$TYPEDFIELDS +""" +@kwdef struct LinearMaximizer{F,G,H} + "function θ ⟼ argmax_y θᵀg(y) + h(y)" + maximizer::F + "function g(y) used in the objective" + g::G = identity_kw + "function h(y) used in the objective" + h::H = zero ∘ eltype_kw +end + +function Base.show(io::IO, f::LinearMaximizer) + (; maximizer, g, h) = f + return print(io, "LinearMaximizer($maximizer, $g, $h)") +end + +""" +$TYPEDSIGNATURES + +Constructor for [`LinearMaximizer`](@ref). +""" +function LinearMaximizer(maximizer; g=identity_kw, h=zero ∘ eltype_kw) + return LinearMaximizer(maximizer, g, h) +end + +""" +$TYPEDSIGNATURES + +Calls the wrapped maximizer. +""" +function (f::LinearMaximizer)(θ::AbstractArray; kwargs...) + return f.maximizer(θ; kwargs...) +end + +# default is oracles of the form argmax_y θᵀy +objective_value(::Any, θ, y; kwargs...) = dot(θ, y) +apply_g(::Any, y; kwargs...) = y + +""" +$TYPEDSIGNATURES + +Computes the objective value of given LinearMaximizer `f`, knowing weights `θ` and solution `y`. +i.e. θᵀg(y) + h(y) +""" +function objective_value(f::LinearMaximizer, θ, y; kwargs...) + return dot(θ, f.g(y; kwargs...)) .+ f.h(y; kwargs...) +end + +""" +$TYPEDSIGNATURES + +Applies the function `g` of the LinearMaximizer `f` to `y`. +""" +function apply_g(f::LinearMaximizer, y; kwargs...) + return f.g(y; kwargs...) +end diff --git a/src/utils/probability_distribution.jl b/src/utils/probability_distribution.jl deleted file mode 100644 index 0fdc21a..0000000 --- a/src/utils/probability_distribution.jl +++ /dev/null @@ -1,85 +0,0 @@ -""" - FixedAtomsProbabilityDistribution{A,W} - -Encodes a probability distribution with finite support and fixed atoms. - -See [`compute_expectation`](@ref) to understand the name of this struct. - -# Fields -- `atoms::Vector{A}`: elements of the support -- `weights::Vector{W}`: probability values for each atom (must sum to 1) -""" -struct FixedAtomsProbabilityDistribution{A,W} - atoms::Vector{A} - weights::Vector{W} - - function FixedAtomsProbabilityDistribution( - atoms::Vector{A}, weights::Vector{W} - ) where {A,W} - @assert length(atoms) == length(weights) > 0 - @assert isapprox(sum(weights), one(W); atol=1e-4) - return new{A,W}(atoms, weights) - end -end - -Base.length(probadist::FixedAtomsProbabilityDistribution) = length(probadist.atoms) - -""" - rand([rng,] probadist) - -Sample from the atoms of `probadist` according to their weights. -""" -function Base.rand(rng::AbstractRNG, probadist::FixedAtomsProbabilityDistribution) - (; atoms, weights) = probadist - return sample(rng, atoms, StatsBase.Weights(weights)) -end - -Base.rand(probadist::FixedAtomsProbabilityDistribution) = rand(GLOBAL_RNG, probadist) - -""" - apply_on_atoms(post_processing, probadist) - -Create a new distribution by applying the function `post_processing` to each atom of `probadist` (the weights remain the same). -""" -function apply_on_atoms( - post_processing, probadist::FixedAtomsProbabilityDistribution; kwargs... 
-) - (; atoms, weights) = probadist - post_processed_atoms = [post_processing(a; kwargs...) for a in atoms] - return FixedAtomsProbabilityDistribution(post_processed_atoms, weights) -end - -""" - compute_expectation(probadist[, post_processing=identity]) - -Compute the expectation of `post_processing(X)` where `X` is a random variable distributed according to `probadist`. - -This operation is made differentiable thanks to a custom reverse rule, even when `post_processing` itself is not a differentiable function. - -!!! warning "Warning" - Derivatives are computed with respect to `probadist.weights` only, assuming that `probadist.atoms` doesn't change (hence the name [`FixedAtomsProbabilityDistribution`](@ref)). -""" -function compute_expectation( - probadist::FixedAtomsProbabilityDistribution, post_processing=identity; kwargs... -) - (; atoms, weights) = probadist - return sum(w * post_processing(a; kwargs...) for (w, a) in zip(weights, atoms)) -end - -function ChainRulesCore.rrule( - ::typeof(compute_expectation), - probadist::FixedAtomsProbabilityDistribution, - post_processing=identity; - kwargs..., -) - e = compute_expectation(probadist, post_processing; kwargs...) - function expectation_pullback(de) - d_atoms = NoTangent() - d_weights = [dot(de, post_processing(a; kwargs...)) for a in probadist.atoms] - d_probadist = Tangent{FixedAtomsProbabilityDistribution}(; - atoms=d_atoms, weights=d_weights - ) - return NoTangent(), d_probadist, NoTangent() - end - return e, expectation_pullback -end diff --git a/src/utils/pushforward.jl b/src/utils/pushforward.jl index ddd915b..d1d600d 100644 --- a/src/utils/pushforward.jl +++ b/src/utils/pushforward.jl @@ -1,18 +1,17 @@ """ - Pushforward <: AbstractLayer +$TYPEDEF Differentiable pushforward of a probabilistic optimization layer with an arbitrary function post-processing function. `Pushforward` can be used for direct regret minimization (aka learning by experience) when the post-processing returns a cost. # Fields -- `optimization_layer::AbstractOptimizationLayer`: probabilistic optimization layer -- `post_processing`: callable - -See also: [`FixedAtomsProbabilityDistribution`](@ref). +$TYPEDFIELDS """ struct Pushforward{O<:AbstractOptimizationLayer,P} <: AbstractLayer + "probabilistic optimization layer" optimization_layer::O + "callable" post_processing::P end @@ -22,32 +21,15 @@ function Base.show(io::IO, pushforward::Pushforward) end """ - compute_probability_distribution(pushforward, θ) - -Output the distribution of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.optimization_layer` applied to `θ`. - -This function is not differentiable if `pushforward.post_processing` isn't. - -See also: [`apply_on_atoms`](@ref). -""" -function compute_probability_distribution(pushforward::Pushforward, θ; kwargs...) - (; optimization_layer, post_processing) = pushforward - probadist = compute_probability_distribution(optimization_layer, θ; kwargs...) - post_processed_probadist = apply_on_atoms(post_processing, probadist; kwargs...) - return post_processed_probadist -end - -""" - (pushforward::Pushforward)(θ; kwargs...) +$TYPEDSIGNATURES Output the expectation of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.optimization_layer` applied to `θ`. -Unlike [`compute_probability_distribution(pushforward, θ)`](@ref), this function is differentiable, even if `pushforward.post_processing` isn't. - -See also: [`compute_expectation`](@ref). 
+This function is differentiable, even if `pushforward.post_processing` isn't. """ function (pushforward::Pushforward)(θ::AbstractArray; kwargs...) (; optimization_layer, post_processing) = pushforward probadist = compute_probability_distribution(optimization_layer, θ; kwargs...) - return compute_expectation(probadist, post_processing; kwargs...) + post_processing_kw = FixKwargs(post_processing, kwargs) + return mean(post_processing_kw, probadist) end diff --git a/src/utils/some_functions.jl b/src/utils/some_functions.jl index 4512ab2..7672a65 100644 --- a/src/utils/some_functions.jl +++ b/src/utils/some_functions.jl @@ -1,26 +1,26 @@ """ - positive_part(x) +$TYPEDSIGNATURES -Compute `max(x,0)`. +Compute `max(x, 0)`. """ positive_part(x) = x >= zero(x) ? x : zero(x) """ - isproba(x) +$TYPEDSIGNATURES Check whether `x ∈ [0,1]`. """ isproba(x::Real) = zero(x) <= x <= one(x) """ - isprobadist(p) +$TYPEDSIGNATURES Check whether the elements of `p` are nonnegative and sum to 1. """ isprobadist(p::AbstractVector{R}) where {R<:Real} = all(isproba, p) && sum(p) ≈ one(R) """ - half_square_norm(x) +$TYPEDSIGNATURES Compute the squared Euclidean norm of `x` and divide it by 2. """ @@ -29,7 +29,7 @@ function half_square_norm(x::AbstractArray) end """ - shannon_entropy(p) +$TYPEDSIGNATURES Compute the Shannon entropy of a probability distribution: `H(p) = -∑ pᵢlog(pᵢ)`. """ @@ -46,7 +46,7 @@ end negative_shannon_entropy(p::AbstractVector) = -shannon_entropy(p) """ - one_hot_argmax(z) +$TYPEDSIGNATURES One-hot encoding of the argmax function. """ @@ -57,7 +57,7 @@ function one_hot_argmax(z::AbstractVector{R}; kwargs...) where {R<:Real} end """ - ranking(θ[; rev]) +$TYPEDSIGNATURES Compute the vector `r` such that `rᵢ` is the rank of `θᵢ` in `θ`. """ diff --git a/src/utils/utils.jl b/src/utils/utils.jl new file mode 100644 index 0000000..36b0192 --- /dev/null +++ b/src/utils/utils.jl @@ -0,0 +1,48 @@ +""" +$TYPEDEF + +Callable struct that fixes the keyword arguments of `f` to `kwargs...`, and only accepts positional arguments. + +# Fields +$TYPEDFIELDS +""" +struct FixKwargs{F,K} + "function" + f::F + "fixed keyword arguments" + kwargs::K +end + +(fk::FixKwargs)(args...) = fk.f(args...; fk.kwargs...) + +""" +$TYPEDEF + +Callable struct that fixes the first argument of `f` to `x`, and the keyword arguments to `kwargs...`. + +# Fields +$TYPEDFIELDS +""" +struct Fix1Kwargs{F,K,T} <: Function + "function" + f::F + "fixed first argument" + x::T + "fixed keyword arguments" + kwargs::K +end + +(fk::Fix1Kwargs)(args...) = fk.f(fk.x, args...; fk.kwargs...) + +""" +$TYPEDEF + +Callable struct that fixes the first argument of `f` to `x`. +Compared to Base.Fix1, works on functions with more than two arguments. +""" +struct FixFirst{F,T} + f::F + x::T +end + +(fk::FixFirst)(args...) = fk.f(fk.x, args...) 
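The new `LinearMaximizer` wrapper (replacing `GeneralizedMaximizer`) and the keyword-fixing helpers above are used throughout the updated layers and tests. A minimal usage sketch follows; the `g`, `h`, and `cost` functions are toy examples rather than package code, and `FixKwargs` is accessed through the module prefix because it is not exported:

```julia
using InferOpt

# Toy linear part g and offset h; both accept keyword arguments because the
# wrapper forwards kwargs to them.
g(y; kwargs...) = 2 .* y
h(y; kwargs...) = sum(y)

maximizer = LinearMaximizer(one_hot_argmax; g, h)

θ = [1.0, 3.0, 2.0]
y = maximizer(θ)                  # forwards to one_hot_argmax(θ)
objective_value(maximizer, θ, y)  # = θᵀg(y) + h(y)
apply_g(maximizer, y)             # = g(y)

# FixKwargs freezes keyword arguments so a kwarg-taking function can be used
# where only positional arguments are accepted.
cost(y; scale) = scale * sum(y)
fixed_cost = InferOpt.FixKwargs(cost, (; scale=2.0))
fixed_cost(y)                     # same as cost(y; scale=2.0)
```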
diff --git a/test/interface.jl b/test/abstract_regularized_interface.jl similarity index 84% rename from test/interface.jl rename to test/abstract_regularized_interface.jl index 0e6b8b6..a6d1687 100644 --- a/test/interface.jl +++ b/test/abstract_regularized_interface.jl @@ -5,4 +5,5 @@ @test RI.check_interface_implemented(AbstractRegularized, RegularizedFrankWolfe) @test RI.check_interface_implemented(AbstractRegularized, SoftArgmax) @test RI.check_interface_implemented(AbstractRegularized, SparseArgmax) + @test RI.check_interface_implemented(AbstractRegularized, SoftRank) end diff --git a/test/jacobian_approx.jl b/test/jacobian_approx.jl deleted file mode 100644 index abd80ed..0000000 --- a/test/jacobian_approx.jl +++ /dev/null @@ -1,42 +0,0 @@ -@testitem "Jacobian approx" begin - using LinearAlgebra - using Random - using Test - using Zygote - - θ = [3, 5, 4, 2] - - perturbed1 = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=1_000, seed=0) - perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=10_000, seed=0) - perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=0.5, nb_samples=1_000, seed=0) - perturbed2_big = PerturbedMultiplicative( - one_hot_argmax; ε=0.5, nb_samples=10_000, seed=0 - ) - - @testset "PerturbedAdditive" begin - # Compute jacobian with reverse mode - jac1 = Zygote.jacobian(θ -> perturbed1(θ; autodiff_variance_reduction=false), θ)[1] - jac1_big = Zygote.jacobian( - θ -> perturbed1_big(θ; autodiff_variance_reduction=false), θ - )[1] - # Only diagonal should be positive - @test all(diag(jac1) .>= 0) - @test all(jac1 - Diagonal(jac1) .<= 0) - # Order of diagonal coefficients should follow order of θ - @test sortperm(diag(jac1)) == sortperm(θ) - # No scaling with nb of samples - @test norm(jac1) ≈ norm(jac1_big) rtol = 1e-2 - end - - @testset "PerturbedMultiplicative" begin - jac2 = Zygote.jacobian(θ -> perturbed2(θ; autodiff_variance_reduction=false), θ)[1] - jac2_big = Zygote.jacobian( - θ -> perturbed2_big(θ; autodiff_variance_reduction=false), θ - )[1] - @test all(diag(jac2) .>= 0) - @test all(jac2 - Diagonal(jac2) .<= 0) - @test sortperm(diag(jac2)) != sortperm(θ) - # This is not equal because the diagonal coefficient for θ₃ = 4 is often larger than the one for θ₂ = 5. It happens because θ₃ has the opportunity to *become* the argmax (and hence switch from 0 to 1), whereas θ₂ already *is* the argmax. 
- @test norm(jac2) ≈ norm(jac2_big) rtol = 1e-2 - end -end diff --git a/test/generalized_maximizer.jl b/test/learning_generalized_maximizer.jl similarity index 80% rename from test/generalized_maximizer.jl rename to test/learning_generalized_maximizer.jl index 93f28a9..037a638 100644 --- a/test/generalized_maximizer.jl +++ b/test/learning_generalized_maximizer.jl @@ -13,7 +13,7 @@ @test y == [1 0 1; 0 1 0; 1 1 1] - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + generalized_maximizer = LinearMaximizer(max_pricing; g, h) @test generalized_maximizer(θ; instance) == y @@ -29,16 +29,17 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + maximizer = LinearMaximizer(max_pricing; g, h) + perturbed = PerturbedAdditive(maximizer; ε=1.0, nb_samples=10) function cost(y; instance) - return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) + return -objective_value(maximizer, true_encoder(instance), y; instance) end test_pipeline!( PipelineLossImitation(); instance_dim=5, true_maximizer=max_pricing, - maximizer=PerturbedAdditive(generalized_maximizer; ε=1.0, nb_samples=10), + maximizer=perturbed, loss=mse_kw, error_function=hamming_distance, cost, @@ -54,16 +55,17 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + maximizer = LinearMaximizer(max_pricing; g, h) + perturbed = PerturbedMultiplicative(maximizer; ε=1.0, nb_samples=10) function cost(y; instance) - return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) + return -objective_value(maximizer, true_encoder(instance), y; instance) end test_pipeline!( PipelineLossImitation(); instance_dim=5, true_maximizer=max_pricing, - maximizer=PerturbedMultiplicative(generalized_maximizer; ε=1.0, nb_samples=10), + maximizer=perturbed, loss=mse_kw, error_function=hamming_distance, cost, @@ -78,9 +80,12 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + maximizer = LinearMaximizer(max_pricing; g, h) + @info maximizer g h + perturbed = PerturbedAdditive(maximizer; ε=1.0, nb_samples=10) + @info perturbed function cost(y; instance) - return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) + return -objective_value(maximizer, true_encoder(instance), y; instance) end test_pipeline!( @@ -88,9 +93,7 @@ end instance_dim=5, true_maximizer=max_pricing, maximizer=identity_kw, - loss=FenchelYoungLoss( - PerturbedAdditive(generalized_maximizer; ε=1.0, nb_samples=5) - ), + loss=FenchelYoungLoss(perturbed), error_function=hamming_distance, cost, true_encoder, @@ -105,9 +108,10 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + maximizer = LinearMaximizer(max_pricing; g, h) + perturbed = PerturbedMultiplicative(maximizer; ε=0.1, nb_samples=10) function cost(y; instance) - return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) + return -objective_value(maximizer, true_encoder(instance), y; instance) end test_pipeline!( @@ -115,9 +119,7 @@ end instance_dim=5, true_maximizer=max_pricing, maximizer=identity_kw, - loss=FenchelYoungLoss( - PerturbedMultiplicative(generalized_maximizer; ε=0.1, nb_samples=5) - ), + loss=FenchelYoungLoss(perturbed), error_function=hamming_distance, cost, true_encoder, @@ -131,7 +133,7 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + 
generalized_maximizer = LinearMaximizer(max_pricing; g, h) function cost(y; instance) return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) end @@ -155,7 +157,7 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + generalized_maximizer = LinearMaximizer(max_pricing; g, h) function cost(y; instance) return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) end @@ -180,7 +182,7 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + generalized_maximizer = LinearMaximizer(max_pricing; g, h) function cost(y; instance) return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) end @@ -207,7 +209,7 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + generalized_maximizer = LinearMaximizer(max_pricing; g, h) function cost(y; instance) return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) end @@ -232,7 +234,7 @@ end const RI = RequiredInterfaces Random.seed!(63) - struct MyRegularized{M<:GeneralizedMaximizer} <: AbstractRegularizedGeneralizedMaximizer + struct MyRegularized{M<:LinearMaximizer} <: AbstractRegularized # GeneralizedMaximizer maximizer::M end @@ -244,7 +246,7 @@ end @test RI.check_interface_implemented(AbstractRegularized, MyRegularized) - regularized = MyRegularized(GeneralizedMaximizer(sparse_argmax)) + regularized = MyRegularized(LinearMaximizer(sparse_argmax)) test_pipeline!( PipelineLossImitation(); diff --git a/test/ranking.jl b/test/learning_ranking.jl similarity index 64% rename from test/ranking.jl rename to test/learning_ranking.jl index 63954a5..ee7fe82 100644 --- a/test/ranking.jl +++ b/test/learning_ranking.jl @@ -1,4 +1,4 @@ -@testitem "Ranking - imit - SPO+ (θ)" default_imports = false begin +@testitem "imit - SPO+ (θ)" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -13,7 +13,7 @@ ) end -@testitem "Ranking - imit - SPO+ (θ & y)" default_imports = false begin +@testitem "imit - SPO+ (θ & y)" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -28,7 +28,7 @@ end ) end -@testitem "Ranking - imit - MSE IdentityRelaxation" default_imports = false begin +@testitem "imit - MSE IdentityRelaxation" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random Random.seed!(63) @@ -43,7 +43,7 @@ end ) end -@testitem "Ranking - imit - MSE Interpolation" default_imports = false begin +@testitem "imit - MSE Interpolation" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -58,7 +58,7 @@ end ) end -@testitem "Ranking - imit - MSE PerturbedAdditive" default_imports = false begin +@testitem "imit - MSE PerturbedAdditive" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -73,7 +73,7 @@ end ) end -@testitem "Ranking - imit - MSE PerturbedMultiplicative" default_imports = false begin +@testitem "imit - MSE PerturbedMultiplicative" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random 
Random.seed!(63) @@ -88,7 +88,7 @@ end ) end -@testitem "Ranking - imit - MSE RegularizedFrankWolfe" default_imports = false begin +@testitem "imit - MSE RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -108,7 +108,7 @@ end ) end -@testitem "Ranking - imit - FYL PerturbedAdditive" default_imports = false begin +@testitem "imit - FYL PerturbedAdditive" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -123,7 +123,7 @@ end ) end -@testitem "Ranking - imit - FYL PerturbedMultiplicative" default_imports = false begin +@testitem "imit - FYL PerturbedMultiplicative" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -138,7 +138,7 @@ end ) end -@testitem "Ranking - imit - FYL PerturbedAdditive{LogNormal}" default_imports = false begin +@testitem "imit - FYL PerturbedAdditive{LogNormal}" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random, Distributions, LinearAlgebra Random.seed!(63) @@ -149,13 +149,15 @@ end true_maximizer=ranking, maximizer=identity_kw, loss=FenchelYoungLoss( - PerturbedAdditive(ranking; ε=1.0, nb_samples=5, perturbation=LogNormal(0, 1)) + PerturbedAdditive( + ranking; ε=1.0, nb_samples=5, perturbation_dist=LogNormal(0, 1) + ), ), error_function=hamming_distance, ) end -@testitem "Ranking - imit - FYL RegularizedFrankWolfe" default_imports = false begin +@testitem "imit - FYL RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -177,94 +179,91 @@ end ) end -@testitem "Ranking - exp - Pushforward PerturbedAdditive" default_imports = false begin +@testitem "exp - Pushforward PerturbedAdditive" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random + using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Statistics Random.seed!(63) true_encoder = encoder_factory() cost(y; instance) = dot(y, -true_encoder(instance)) + f(θ; kwargs...) = cost(ranking(θ; kwargs...); kwargs...) test_pipeline!( PipelineLossExperience(); instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=Pushforward(PerturbedAdditive(ranking; ε=1.0, nb_samples=10), cost), + loss=PerturbedAdditive(f; ε=1.0, nb_samples=10), error_function=hamming_distance, true_encoder=true_encoder, cost=cost, ) end -@testitem "Ranking - exp - Pushforward PerturbedMultiplicative" default_imports = false begin +@testitem "exp - Pushforward PerturbedMultiplicative" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random Random.seed!(63) true_encoder = encoder_factory() cost(y; instance) = dot(y, -true_encoder(instance)) + f(θ; kwargs...) = cost(ranking(θ; kwargs...); kwargs...) 
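The `f(θ; kwargs...) = cost(ranking(θ; kwargs...); kwargs...)` definition just above illustrates the new learning-by-experience pattern: rather than wrapping a perturbed layer in a `Pushforward` together with a separate `cost`, the cost is composed with the maximizer and the composition is perturbed directly. A minimal sketch outside the test harness, with a toy cost in place of `encoder_factory` and `test_pipeline!`:

```julia
using InferOpt, LinearAlgebra, Zygote

θ_true = [3.0, 1.0, 4.0, 1.0, 5.0]  # stand-in for true_encoder(instance)
cost(y) = dot(y, -θ_true)           # lower is better
f(θ) = cost(ranking(θ))             # cost of the decoded solution

# Perturbing the composition yields a differentiable expected cost
loss = PerturbedAdditive(f; ε=1.0, nb_samples=10)
θ = randn(5)
loss(θ)                             # ≈ E[cost(ranking(θ + εZ))]
Zygote.gradient(loss, θ)[1]
```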
test_pipeline!( PipelineLossExperience(); instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=Pushforward(PerturbedMultiplicative(ranking; ε=1.0, nb_samples=10), cost), + loss=PerturbedMultiplicative(f; ε=1.0, nb_samples=10), error_function=hamming_distance, true_encoder=true_encoder, cost=cost, ) end -@testitem "Ranking - exp - Pushforward PerturbedAdditive{LogNormal}" default_imports = false begin +@testitem "exp - Pushforward PerturbedAdditive{LogNormal}" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Distributions Random.seed!(63) true_encoder = encoder_factory() cost(y; instance) = dot(y, -true_encoder(instance)) + f(θ; kwargs...) = cost(ranking(θ; kwargs...); kwargs...) test_pipeline!( PipelineLossExperience(); instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=Pushforward( - PerturbedAdditive(ranking; ε=1.0, nb_samples=10, perturbation=LogNormal(0, 1)), - cost, - ), + loss=PerturbedAdditive(f; ε=1.0, nb_samples=10, perturbation_dist=LogNormal(0, 1)), error_function=hamming_distance, true_encoder=true_encoder, cost=cost, - epochs=500, + epochs=100, ) end -@testitem "Ranking - exp - Pushforward PerturbedMultiplicative{LogNormal}" default_imports = - false begin +@testitem "exp - Pushforward PerturbedMultiplicative{LogNormal}" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Distributions Random.seed!(63) true_encoder = encoder_factory() cost(y; instance) = dot(y, -true_encoder(instance)) + f(θ; kwargs...) = cost(ranking(θ; kwargs...); kwargs...) test_pipeline!( PipelineLossExperience(); instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=Pushforward( - PerturbedMultiplicative( - ranking; ε=1.0, nb_samples=10, perturbation=LogNormal(0, 1) - ), - cost, + loss=PerturbedMultiplicative( + f; ε=1.0, nb_samples=10, perturbation_dist=LogNormal(0, 1) ), error_function=hamming_distance, true_encoder=true_encoder, cost=cost, - epochs=500, + epochs=100, ) end -@testitem "Ranking - exp - Pushforward PerturbedOracle{LogNormal}" default_imports = false begin +@testitem "exp - Pushforward PerturbedOracle{LogNormal}" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Distributions Random.seed!(63) @@ -273,19 +272,20 @@ end true_encoder = encoder_factory() cost(y; instance) = dot(y, -true_encoder(instance)) + f(y; instance) = cost(ranking(y; instance); instance) test_pipeline!( PipelineLossExperience(); instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=Pushforward(PerturbedOracle(ranking, p; nb_samples=10), cost), + loss=PerturbedOracle(f, p; nb_samples=10), error_function=hamming_distance, true_encoder=true_encoder, cost=cost, ) end -@testitem "Ranking - exp - Pushforward RegularizedFrankWolfe" default_imports = false begin +@testitem "exp - Pushforward RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, LinearAlgebra, Random @@ -312,3 +312,83 @@ end cost=cost, ) end + +@testitem "exp - soft rank" default_imports = false begin + include("InferOptTestUtils/src/InferOptTestUtils.jl") + using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Test + Random.seed!(63) + + true_encoder = encoder_factory() + cost(y; instance) = dot(y,
-true_encoder(instance)) + + Random.seed!(67) + soft_rank_l2_results = test_pipeline!( + PipelineLossExperience(); + instance_dim=5, + true_maximizer=ranking, + maximizer=SoftRank(), + loss=cost, + error_function=hamming_distance, + true_encoder=true_encoder, + cost=cost, + epochs=50, + ) + + Random.seed!(67) + soft_rank_kl_results = test_pipeline!( + PipelineLossExperience(); + instance_dim=5, + true_maximizer=ranking, + maximizer=SoftRank(; regularization="kl"), + loss=cost, + error_function=hamming_distance, + true_encoder=true_encoder, + cost=cost, + epochs=50, + ) + + Random.seed!(67) + perturbed_results = test_pipeline!( + PipelineLossExperience(); + instance_dim=5, + true_maximizer=ranking, + maximizer=identity_kw, + loss=Pushforward(PerturbedAdditive(ranking; ε=1.0, nb_samples=10), cost), + error_function=hamming_distance, + true_encoder=true_encoder, + cost=cost, + epochs=50, + ) + + # Check that we achieve better performance than the reinforce trick + @test soft_rank_l2_results.test_cost_gaps[end] < perturbed_results.test_cost_gaps[end] + @test soft_rank_kl_results.test_cost_gaps[end] < perturbed_results.test_cost_gaps[end] +end + +@testitem "imit - FYL - soft rank" default_imports = false begin + include("InferOptTestUtils/src/InferOptTestUtils.jl") + using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Test + Random.seed!(63) + + true_encoder = encoder_factory() + + test_pipeline!( + PipelineLossImitation(); + instance_dim=5, + true_maximizer=ranking, + maximizer=identity_kw, + loss=FenchelYoungLoss(SoftRank()), + error_function=hamming_distance, + true_encoder=true_encoder, + ) + + test_pipeline!( + PipelineLossImitation(); + instance_dim=5, + true_maximizer=ranking, + maximizer=identity_kw, + loss=FenchelYoungLoss(SoftRank(; regularization="kl", ε=10.0)), + error_function=hamming_distance, + true_encoder=true_encoder, + ) +end diff --git a/test/paths.jl b/test/paths.jl index 0681e8a..fb6a32f 100644 --- a/test/paths.jl +++ b/test/paths.jl @@ -155,7 +155,10 @@ end maximizer=identity_kw, loss=FenchelYoungLoss( PerturbedAdditive( - shortest_path_maximizer; ε=1.0, nb_samples=5, perturbation=LogNormal(0, 1) + shortest_path_maximizer; + ε=1.0, + nb_samples=5, + perturbation_dist=LogNormal(0, 1), ), ), error_function=mse_kw, diff --git a/test/perturbed.jl b/test/perturbed.jl new file mode 100644 index 0000000..2579828 --- /dev/null +++ b/test/perturbed.jl @@ -0,0 +1,114 @@ +@testitem "Jacobian approx" begin + using LinearAlgebra + using Random + using Test + using Zygote + + θ = [3, 5, 4, 2] + + perturbed1 = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=1e4, seed=0) + perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=1e6, seed=0) + + perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=1e4, seed=0) + perturbed2_big = PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=1e6, seed=0) + + @testset "PerturbedAdditive" begin + # Compute jacobian with reverse mode + jac1 = Zygote.jacobian(perturbed1, θ)[1] + jac1_big = Zygote.jacobian(perturbed1_big, θ)[1] + # Only diagonal should be positive + @test all(diag(jac1) .>= 0) + @test all(jac1 - Diagonal(jac1) .<= 0) + # Order of diagonal coefficients should follow order of θ + @test sortperm(diag(jac1_big)) == sortperm(θ) + # No scaling with nb of samples + @test norm(jac1) ≈ norm(jac1_big) rtol = 5e-2 + end + + @testset "PerturbedMultiplicative" begin + jac2 = Zygote.jacobian(perturbed2, θ)[1] + jac2_big = Zygote.jacobian(perturbed2_big, θ)[1] + @test all(diag(jac2_big) .>= 0) + @test 
all(jac2_big - Diagonal(jac2_big) .<= 0) + @test_broken sortperm(diag(jac2_big)) == sortperm(θ) + @test norm(jac2) ≈ norm(jac2_big) rtol = 5e-2 + end +end + +@testitem "PerturbedOracle vs PerturbedAdditive" default_imports = false begin + include("InferOptTestUtils/src/InferOptTestUtils.jl") + using InferOpt, .InferOptTestUtils, Random, Test + using LinearAlgebra, Zygote, Distributions + Random.seed!(63) + + ε = 1.0 + p(θ) = MvNormal(θ, ε^2 * I) + oracle(η) = η + + po = PerturbedOracle(oracle, p; nb_samples=1_000, seed=0) + pa = PerturbedAdditive(oracle; ε, nb_samples=1_000, seed=0) + + θ = randn(10) + @test po(θ) ≈ pa(θ) rtol = 0.001 + @test all(isapprox.(jacobian(po, θ), jacobian(pa, θ), rtol=0.001)) +end + +@testitem "Variance reduction" default_imports = false begin + include("InferOptTestUtils/src/InferOptTestUtils.jl") + using InferOpt, .InferOptTestUtils, Random, Test + using LinearAlgebra, Zygote + Random.seed!(63) + + ε = 1.0 + oracle(η) = η + + pa = PerturbedAdditive(oracle; ε, nb_samples=100, seed=0, variance_reduction=true) + pa_no_variance_reduction = PerturbedAdditive( + oracle; ε, nb_samples=100, seed=0, variance_reduction=false + ) + pm = PerturbedMultiplicative(oracle; ε, nb_samples=100, seed=0, variance_reduction=true) + pm_no_variance_reduction = PerturbedMultiplicative( + oracle; ε, nb_samples=100, seed=0, variance_reduction=false + ) + + n = 10 + θ = randn(10) + + Ja = jacobian(pa_no_variance_reduction, θ)[1] + Ja_reduced_variance = jacobian(pa, θ)[1] + + Jm = jacobian(pm_no_variance_reduction, θ)[1] + Jm_reduced_variance = jacobian(pm, θ)[1] + + J_true = Matrix(I, n, n) # exact jacobian is the identity matrix + + @test normalized_mape(Ja, J_true) > normalized_mape(Ja_reduced_variance, J_true) + @test normalized_mape(Jm, J_true) > normalized_mape(Jm_reduced_variance, J_true) +end + +@testitem "Perturbed - small ε convergence" default_imports = false begin + include("InferOptTestUtils/src/InferOptTestUtils.jl") + using InferOpt, .InferOptTestUtils, Random, Test + using LinearAlgebra, Zygote + Random.seed!(63) + + ε = 1e-12 + + already_differentiable(θ) = 2 ./ exp.(θ) .* θ .^ 2 .+ sum(θ) + pa = PerturbedAdditive(already_differentiable; ε, nb_samples=1e6, seed=0) + pm = PerturbedMultiplicative(already_differentiable; ε, nb_samples=1e6, seed=0) + + θ = [1.0, 2.0, 3.0, 4.0, 5.0] + + fz = already_differentiable(θ) + fa = pa(θ) + fm = pm(θ) + @test fz ≈ fa rtol = 0.01 + @test fz ≈ fm rtol = 0.01 + + Jz = jacobian(already_differentiable, θ)[1] + Ja = jacobian(pa, θ)[1] + Jm = jacobian(pm, θ)[1] + @test Ja ≈ Jz rtol = 0.01 + @test Jm ≈ Jz rtol = 0.01 +end diff --git a/test/perturbed_oracle.jl b/test/perturbed_oracle.jl deleted file mode 100644 index ef1bccb..0000000 --- a/test/perturbed_oracle.jl +++ /dev/null @@ -1,44 +0,0 @@ -@testitem "PerturbedOracle vs PerturbedAdditive" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random, Test - using LinearAlgebra, Zygote, Distributions - Random.seed!(63) - - ε = 1.0 - p(θ) = MvNormal(θ, ε^2 * I) - oracle(η) = η - - po = PerturbedOracle(oracle, p; nb_samples=1_000, seed=0) - pa = PerturbedAdditive(oracle; ε, nb_samples=1_000, seed=0) - - θ = randn(10) - @test po(θ) ≈ pa(θ) rtol = 0.001 - @test all(isapprox.(jacobian(po, θ), jacobian(pa, θ), rtol=0.001)) -end - -@testitem "Variance reduction" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random, Test - using LinearAlgebra, Zygote - 
Random.seed!(63) - - ε = 1.0 - oracle(η) = η - - pa = PerturbedAdditive(oracle; ε, nb_samples=100, seed=0) - pm = PerturbedAdditive(oracle; ε, nb_samples=100, seed=0) - - n = 10 - θ = randn(10) - - Ja = jacobian(θ -> pa(θ; autodiff_variance_reduction=false), θ)[1] - Ja_reduced_variance = jacobian(pa, θ)[1] - - Jm = jacobian(x -> pm(x; autodiff_variance_reduction=false), θ)[1] - Jm_reduced_variance = jacobian(pm, θ)[1] - - J_true = Matrix(I, n, n) # exact jacobian is the identity matrix - - @test normalized_mape(Ja, J_true) > normalized_mape(Ja_reduced_variance, J_true) - @test normalized_mape(Jm, J_true) > normalized_mape(Jm_reduced_variance, J_true) -end diff --git a/test/soft_rank.jl b/test/soft_rank.jl index dc36c81..7e51508 100644 --- a/test/soft_rank.jl +++ b/test/soft_rank.jl @@ -15,7 +15,7 @@ end using InferOpt, .InferOptTestUtils, Random, HiGHS, JuMP, Test Random.seed!(63) - function isotonic_custom(y) + function isotonic_jump(y) model = Model(HiGHS.Optimizer) set_silent(model) @@ -29,7 +29,7 @@ end for _ in 1:100 y = randn(1000) - x = isotonic_custom(y) + x = isotonic_jump(y) x2 = InferOpt.isotonic_l2(y) @test all(isapprox.(x, x2, atol=1e-2)) end @@ -68,90 +68,3 @@ end @test all(isapprox.(rank_jac, rank_jac_fd, atol=1e-4)) end end - -@testitem "Learn by experience soft rank" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Test - Random.seed!(63) - - true_encoder = encoder_factory() - cost(y; instance) = dot(y, -true_encoder(instance)) - - Random.seed!(67) - soft_rank_l2_results = test_pipeline!( - PipelineLossExperience(); - instance_dim=5, - true_maximizer=ranking, - maximizer=SoftRank(), - loss=cost, - error_function=hamming_distance, - true_encoder=true_encoder, - cost=cost, - epochs=50, - ) - - Random.seed!(67) - soft_rank_kl_results = test_pipeline!( - PipelineLossExperience(); - instance_dim=5, - true_maximizer=ranking, - maximizer=SoftRank(; regularization="kl"), - loss=cost, - error_function=hamming_distance, - true_encoder=true_encoder, - cost=cost, - epochs=50, - ) - - Random.seed!(67) - perturbed_results = test_pipeline!( - PipelineLossExperience(); - instance_dim=5, - true_maximizer=ranking, - maximizer=identity_kw, - loss=Pushforward(PerturbedAdditive(ranking; ε=1.0, nb_samples=10), cost), - error_function=hamming_distance, - true_encoder=true_encoder, - cost=cost, - epochs=50, - ) - - # Check that we achieve better performance than the reinforce trick - @test soft_rank_l2_results.test_cost_gaps[end] < perturbed_results.test_cost_gaps[end] - @test soft_rank_kl_results.test_cost_gaps[end] < perturbed_results.test_cost_gaps[end] -end - -@testitem "Fenchel-Young loss soft rank L2" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Test - Random.seed!(63) - - true_encoder = encoder_factory() - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=ranking, - maximizer=identity_kw, - loss=FenchelYoungLoss(SoftRank()), - error_function=hamming_distance, - true_encoder=true_encoder, - ) -end - -@testitem "Fenchel-Young loss soft rank kl" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Test - Random.seed!(63) - - true_encoder = encoder_factory() - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=ranking, - maximizer=identity_kw, - 
loss=FenchelYoungLoss(SoftRank(; regularization="kl", ε=10.0)), - error_function=hamming_distance, - true_encoder=true_encoder, - ) -end
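For reference, the perturbed-layer keywords exercised by the new `test/perturbed.jl` and the renamed ranking tests (`perturbation_dist` replacing `perturbation`, and a construction-time `variance_reduction` flag replacing the call-site `autodiff_variance_reduction` keyword) fit together as in the sketch below. The `one_hot_argmax` oracle is just a convenient toy, and mixing these keywords outside the exact combinations shown in the tests is assumed rather than demonstrated:

```julia
using InferOpt, Distributions, LinearAlgebra, Zygote

θ = [3.0, 5.0, 4.0, 2.0]

# Additive perturbation with a non-Gaussian noise distribution
pa = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=100, perturbation_dist=LogNormal(0, 1))
pa(θ)                      # smoothed one-hot argmax
Zygote.jacobian(pa, θ)[1]  # differentiable despite the discrete oracle

# Variance reduction is now toggled when the layer is built
pm = PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=100, seed=0, variance_reduction=false)

# Fully custom perturbation via PerturbedOracle
σ = 1.0
p(η) = MvNormal(η, σ^2 * I)
po = PerturbedOracle(one_hot_argmax, p; nb_samples=100, seed=0)
po(θ)
```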