From 7fd035e366f22484c8eab5e60fa13618b03487a9 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Wed, 24 Apr 2024 12:13:39 +0200 Subject: [PATCH 01/26] separate losses from layers --- src/InferOpt.jl | 40 ++++++++++--------- .../perturbed/abstract_perturbed.jl | 0 src/{ => layers}/perturbed/additive.jl | 0 src/{ => layers}/perturbed/multiplicative.jl | 0 .../perturbed/perturbed_oracle.jl | 0 .../regularized/abstract_regularized.jl | 0 .../regularized/regularized_frank_wolfe.jl | 0 src/{ => layers}/regularized/soft_argmax.jl | 0 src/{ => layers}/regularized/soft_rank.jl | 0 src/{ => layers}/regularized/sparse_argmax.jl | 0 src/{ => layers}/simple/identity.jl | 0 src/{ => layers}/simple/interpolation.jl | 0 .../fenchel_young_loss.jl | 0 src/{imitation => losses}/imitation_loss.jl | 0 src/{imitation => losses}/spoplus_loss.jl | 0 src/{imitation => losses}/ssvm_loss.jl | 0 src/{imitation => losses}/zero_one_loss.jl | 0 17 files changed, 21 insertions(+), 19 deletions(-) rename src/{ => layers}/perturbed/abstract_perturbed.jl (100%) rename src/{ => layers}/perturbed/additive.jl (100%) rename src/{ => layers}/perturbed/multiplicative.jl (100%) rename src/{ => layers}/perturbed/perturbed_oracle.jl (100%) rename src/{ => layers}/regularized/abstract_regularized.jl (100%) rename src/{ => layers}/regularized/regularized_frank_wolfe.jl (100%) rename src/{ => layers}/regularized/soft_argmax.jl (100%) rename src/{ => layers}/regularized/soft_rank.jl (100%) rename src/{ => layers}/regularized/sparse_argmax.jl (100%) rename src/{ => layers}/simple/identity.jl (100%) rename src/{ => layers}/simple/interpolation.jl (100%) rename src/{imitation => losses}/fenchel_young_loss.jl (100%) rename src/{imitation => losses}/imitation_loss.jl (100%) rename src/{imitation => losses}/spoplus_loss.jl (100%) rename src/{imitation => losses}/ssvm_loss.jl (100%) rename src/{imitation => losses}/zero_one_loss.jl (100%) diff --git a/src/InferOpt.jl b/src/InferOpt.jl index 4d62096..c7f3e1a 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -28,30 +28,32 @@ include("utils/isotonic_regression/isotonic_l2.jl") include("utils/isotonic_regression/isotonic_kl.jl") include("utils/isotonic_regression/projection.jl") -include("simple/interpolation.jl") -include("simple/identity.jl") - -include("regularized/abstract_regularized.jl") -include("regularized/soft_argmax.jl") -include("regularized/sparse_argmax.jl") -include("regularized/soft_rank.jl") -include("regularized/regularized_frank_wolfe.jl") - -include("perturbed/abstract_perturbed.jl") -include("perturbed/additive.jl") -include("perturbed/multiplicative.jl") -include("perturbed/perturbed_oracle.jl") - -include("imitation/spoplus_loss.jl") -include("imitation/ssvm_loss.jl") -include("imitation/fenchel_young_loss.jl") -include("imitation/imitation_loss.jl") -include("imitation/zero_one_loss.jl") +# Layers +include("layers/simple/interpolation.jl") +include("layers/simple/identity.jl") + +include("layers/perturbed/abstract_perturbed.jl") +include("layers/perturbed/additive.jl") +include("layers/perturbed/multiplicative.jl") +include("layers/perturbed/perturbed_oracle.jl") + +include("layers/regularized/abstract_regularized.jl") +include("layers/regularized/soft_argmax.jl") +include("layers/regularized/sparse_argmax.jl") +include("layers/regularized/soft_rank.jl") +include("layers/regularized/regularized_frank_wolfe.jl") if !isdefined(Base, :get_extension) include("../ext/InferOptFrankWolfeExt.jl") end +# Losses +include("losses/fenchel_young_loss.jl") 
+include("losses/spoplus_loss.jl") +include("losses/ssvm_loss.jl") +include("losses/zero_one_loss.jl") +include("losses/imitation_loss.jl") + export half_square_norm export shannon_entropy, negative_shannon_entropy export one_hot_argmax, ranking diff --git a/src/perturbed/abstract_perturbed.jl b/src/layers/perturbed/abstract_perturbed.jl similarity index 100% rename from src/perturbed/abstract_perturbed.jl rename to src/layers/perturbed/abstract_perturbed.jl diff --git a/src/perturbed/additive.jl b/src/layers/perturbed/additive.jl similarity index 100% rename from src/perturbed/additive.jl rename to src/layers/perturbed/additive.jl diff --git a/src/perturbed/multiplicative.jl b/src/layers/perturbed/multiplicative.jl similarity index 100% rename from src/perturbed/multiplicative.jl rename to src/layers/perturbed/multiplicative.jl diff --git a/src/perturbed/perturbed_oracle.jl b/src/layers/perturbed/perturbed_oracle.jl similarity index 100% rename from src/perturbed/perturbed_oracle.jl rename to src/layers/perturbed/perturbed_oracle.jl diff --git a/src/regularized/abstract_regularized.jl b/src/layers/regularized/abstract_regularized.jl similarity index 100% rename from src/regularized/abstract_regularized.jl rename to src/layers/regularized/abstract_regularized.jl diff --git a/src/regularized/regularized_frank_wolfe.jl b/src/layers/regularized/regularized_frank_wolfe.jl similarity index 100% rename from src/regularized/regularized_frank_wolfe.jl rename to src/layers/regularized/regularized_frank_wolfe.jl diff --git a/src/regularized/soft_argmax.jl b/src/layers/regularized/soft_argmax.jl similarity index 100% rename from src/regularized/soft_argmax.jl rename to src/layers/regularized/soft_argmax.jl diff --git a/src/regularized/soft_rank.jl b/src/layers/regularized/soft_rank.jl similarity index 100% rename from src/regularized/soft_rank.jl rename to src/layers/regularized/soft_rank.jl diff --git a/src/regularized/sparse_argmax.jl b/src/layers/regularized/sparse_argmax.jl similarity index 100% rename from src/regularized/sparse_argmax.jl rename to src/layers/regularized/sparse_argmax.jl diff --git a/src/simple/identity.jl b/src/layers/simple/identity.jl similarity index 100% rename from src/simple/identity.jl rename to src/layers/simple/identity.jl diff --git a/src/simple/interpolation.jl b/src/layers/simple/interpolation.jl similarity index 100% rename from src/simple/interpolation.jl rename to src/layers/simple/interpolation.jl diff --git a/src/imitation/fenchel_young_loss.jl b/src/losses/fenchel_young_loss.jl similarity index 100% rename from src/imitation/fenchel_young_loss.jl rename to src/losses/fenchel_young_loss.jl diff --git a/src/imitation/imitation_loss.jl b/src/losses/imitation_loss.jl similarity index 100% rename from src/imitation/imitation_loss.jl rename to src/losses/imitation_loss.jl diff --git a/src/imitation/spoplus_loss.jl b/src/losses/spoplus_loss.jl similarity index 100% rename from src/imitation/spoplus_loss.jl rename to src/losses/spoplus_loss.jl diff --git a/src/imitation/ssvm_loss.jl b/src/losses/ssvm_loss.jl similarity index 100% rename from src/imitation/ssvm_loss.jl rename to src/losses/ssvm_loss.jl diff --git a/src/imitation/zero_one_loss.jl b/src/losses/zero_one_loss.jl similarity index 100% rename from src/imitation/zero_one_loss.jl rename to src/losses/zero_one_loss.jl From 2139efe796371163a2c4d51627fb946d83007e02 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Wed, 24 Apr 2024 12:32:46 +0200 Subject: [PATCH 02/26] remove unnecessary trainings in tests --- 
test/argmax.jl | 274 ----------------------- test/paths.jl | 259 --------------------- test/{ranking.jl => ranking_learning.jl} | 115 ++++++++-- test/soft_rank.jl | 91 +------- 4 files changed, 99 insertions(+), 640 deletions(-) delete mode 100644 test/argmax.jl delete mode 100644 test/paths.jl rename test/{ranking.jl => ranking_learning.jl} (70%) diff --git a/test/argmax.jl b/test/argmax.jl deleted file mode 100644 index 4075c51..0000000 --- a/test/argmax.jl +++ /dev/null @@ -1,274 +0,0 @@ -@testitem "Argmax - imit - SPO+ (θ)" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitationθ(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=identity_kw, - loss=SPOPlusLoss(one_hot_argmax), - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - SPO+ (θ & y)" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitationθy(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=identity_kw, - loss=SPOPlusLoss(one_hot_argmax), - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - SSVM" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=identity_kw, - loss=InferOpt.ZeroOneStructuredSVMLoss(), - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - MSE SparseArgmax" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=SparseArgmax(), - loss=mse_kw, - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - MSE SoftArgmax" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=SoftArgmax(), - loss=mse_kw, - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - MSE PerturbedAdditive" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=10), - loss=mse_kw, - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - MSE PerturbedMultiplicative" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=10), - loss=mse_kw, - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - MSE RegularizedFrankWolfe" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random - 
Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=RegularizedFrankWolfe( - one_hot_argmax; - Ω=half_square_norm, - Ω_grad=identity_kw, - frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), - ), - loss=mse_kw, - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - FYL SparseArgmax" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=identity_kw, - loss=FenchelYoungLoss(SparseArgmax()), - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - FYL SoftArgmax" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=identity_kw, - loss=FenchelYoungLoss(SoftArgmax()), - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - FYL PerturbedAdditive" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=identity_kw, - loss=FenchelYoungLoss(PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=5)), - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - FYL PerturbedMultiplicative" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=identity_kw, - loss=FenchelYoungLoss(PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=5)), - error_function=hamming_distance, - ) -end - -@testitem "Argmax - imit - FYL RegularizedFrankWolfe" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=identity_kw, - loss=FenchelYoungLoss( - RegularizedFrankWolfe( - one_hot_argmax; - Ω=half_square_norm, - Ω_grad=identity_kw, - frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), - ), - ), - error_function=hamming_distance, - ) -end - -@testitem "Argmax - exp - Pushforward PerturbedAdditive" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random - Random.seed!(63) - - true_encoder = encoder_factory() - cost(y; instance) = dot(y, -true_encoder(instance)) - test_pipeline!( - PipelineLossExperience(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=identity_kw, - loss=Pushforward(PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=10), cost), - error_function=hamming_distance, - true_encoder=true_encoder, - cost=cost, - ) -end - -@testitem "Argmax - exp - Pushforward PerturbedMultiplicative" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random - Random.seed!(63) - - true_encoder = 
encoder_factory() - cost(y; instance) = dot(y, -true_encoder(instance)) - test_pipeline!( - PipelineLossExperience(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=identity_kw, - loss=Pushforward( - PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=10), cost - ), - error_function=hamming_distance, - true_encoder=true_encoder, - cost=cost, - ) -end - -@testitem "Argmax - exp - Pushforward RegularizedFrankWolfe" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using DifferentiableFrankWolfe, - FrankWolfe, InferOpt, .InferOptTestUtils, LinearAlgebra, Random - Random.seed!(63) - - true_encoder = encoder_factory() - cost(y; instance) = dot(y, -true_encoder(instance)) - test_pipeline!( - PipelineLossExperience(); - instance_dim=5, - true_maximizer=one_hot_argmax, - maximizer=identity_kw, - loss=Pushforward( - RegularizedFrankWolfe( - one_hot_argmax; - Ω=half_square_norm, - Ω_grad=identity_kw, - frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), - ), - cost, - ), - error_function=hamming_distance, - true_encoder=true_encoder, - cost=cost, - ) -end diff --git a/test/paths.jl b/test/paths.jl deleted file mode 100644 index c38176b..0000000 --- a/test/paths.jl +++ /dev/null @@ -1,259 +0,0 @@ -@testitem "Paths - imit - SPO+ (θ)" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitationθ(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=identity_kw, - loss=SPOPlusLoss(shortest_path_maximizer), - error_function=mse_kw, - ) -end - -@testitem "Paths - imit - SPO+ (θ & y)" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitationθy(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=identity_kw, - loss=SPOPlusLoss(shortest_path_maximizer), - error_function=mse_kw, - ) -end - -@testitem "Paths - imit - MSE IdentityRelaxation" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=normalize ∘ IdentityRelaxation(shortest_path_maximizer), - loss=mse_kw, - error_function=mse_kw, - ) -end - -# @testitem "Paths - imit - MSE Interpolation" default_imports = false begin -# include("InferOptTestUtils/src/InferOptTestUtils.jl") -# using InferOpt, .InferOptTestUtils, Random -# Random.seed!(63) - -# test_pipeline!( -# PipelineLossImitation; -# instance_dim=(5, 5), -# true_maximizer=shortest_path_maximizer, -# maximizer=Interpolation(shortest_path_maximizer; λ=5.0), -# loss=mse_kw, -# error_function=mse_kw, -# ) -# end # TODO: make it work (doesn't seem to depend on λ) - -@testitem "Paths - imit - MSE PerturbedAdditive" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=PerturbedAdditive(shortest_path_maximizer; ε=1.0, nb_samples=10), - loss=mse_kw, - error_function=mse_kw, - ) -end - -@testitem "Paths - imit - MSE PerturbedMultiplicative" default_imports = 
false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=PerturbedMultiplicative(shortest_path_maximizer; ε=1.0, nb_samples=10), - loss=mse_kw, - error_function=mse_kw, - ) -end - -@testitem "Paths - imit - MSE RegularizedFrankWolfe" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=RegularizedFrankWolfe( - shortest_path_maximizer; - Ω=half_square_norm, - Ω_grad=identity_kw, - frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), - ), - loss=mse_kw, - error_function=mse_kw, - ) -end - -@testitem "Paths - imit - FYL PerturbedAdditive" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=identity_kw, - loss=FenchelYoungLoss( - PerturbedAdditive(shortest_path_maximizer; ε=1.0, nb_samples=5) - ), - error_function=mse_kw, - ) -end - -@testitem "Paths - imit - FYL PerturbedMultiplicative" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=identity_kw, - loss=FenchelYoungLoss( - PerturbedMultiplicative(shortest_path_maximizer; ε=1.0, nb_samples=5) - ), - error_function=mse_kw, - epochs=100, - ) -end - -@testitem "Paths - imit - FYL PerturbedAdditive{LogNormal}" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random, Distributions - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=identity_kw, - loss=FenchelYoungLoss( - PerturbedAdditive( - shortest_path_maximizer; ε=1.0, nb_samples=5, perturbation=LogNormal(0, 1) - ), - ), - error_function=mse_kw, - ) -end - -@testitem "Paths - imit - FYL RegularizedFrankWolfe" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random - Random.seed!(63) - - test_pipeline!( - PipelineLossImitation(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=identity_kw, - loss=FenchelYoungLoss( - RegularizedFrankWolfe( - shortest_path_maximizer; - Ω=half_square_norm, - Ω_grad=identity_kw, - frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), - ), - ), - error_function=mse_kw, - epochs=100, - ) -end - -@testitem "Paths - exp - Pushforward PerturbedAdditive" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random - Random.seed!(63) - - true_encoder = encoder_factory() - cost(y; instance) = dot(y, -true_encoder(instance)) - test_pipeline!( - PipelineLossExperience(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - 
maximizer=identity_kw, - loss=Pushforward( - PerturbedAdditive(shortest_path_maximizer; ε=1.0, nb_samples=10), cost - ), - error_function=mse_kw, - true_encoder=true_encoder, - cost=cost, - epochs=500, - ) -end - -@testitem "Paths - exp - Pushforward PerturbedMultiplicative" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random - Random.seed!(63) - - true_encoder = encoder_factory() - cost(y; instance) = dot(y, -true_encoder(instance)) - test_pipeline!( - PipelineLossExperience(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=identity_kw, - loss=Pushforward( - PerturbedMultiplicative(shortest_path_maximizer; ε=1.0, nb_samples=10), cost - ), - error_function=mse_kw, - true_encoder=true_encoder, - cost=cost, - epochs=500, - ) -end - -@testitem "Paths - exp - Pushforward RegularizedFrankWolfe" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using DifferentiableFrankWolfe, - FrankWolfe, InferOpt, .InferOptTestUtils, LinearAlgebra, Random - Random.seed!(63) - - true_encoder = encoder_factory() - cost(y; instance) = dot(y, -true_encoder(instance)) - test_pipeline!( - PipelineLossExperience(); - instance_dim=(5, 5), - true_maximizer=shortest_path_maximizer, - maximizer=identity_kw, - loss=Pushforward( - RegularizedFrankWolfe( - shortest_path_maximizer; - Ω=half_square_norm, - Ω_grad=identity_kw, - frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()), - ), - cost, - ), - error_function=mse_kw, - true_encoder=true_encoder, - cost=cost, - epochs=200, - ) -end diff --git a/test/ranking.jl b/test/ranking_learning.jl similarity index 70% rename from test/ranking.jl rename to test/ranking_learning.jl index 122ed99..a6773f9 100644 --- a/test/ranking.jl +++ b/test/ranking_learning.jl @@ -1,4 +1,4 @@ -@testitem "Ranking - imit - SPO+ (θ)" default_imports = false begin +@testitem "imit - SPO+ (θ)" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -13,7 +13,7 @@ ) end -@testitem "Ranking - imit - SPO+ (θ & y)" default_imports = false begin +@testitem "imit - SPO+ (θ & y)" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -28,7 +28,7 @@ end ) end -@testitem "Ranking - imit - MSE IdentityRelaxation" default_imports = false begin +@testitem "imit - MSE IdentityRelaxation" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random Random.seed!(63) @@ -43,7 +43,7 @@ end ) end -@testitem "Ranking - imit - MSE Interpolation" default_imports = false begin +@testitem "imit - MSE Interpolation" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -58,7 +58,7 @@ end ) end -@testitem "Ranking - imit - MSE PerturbedAdditive" default_imports = false begin +@testitem "imit - MSE PerturbedAdditive" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -73,7 +73,7 @@ end ) end -@testitem "Ranking - imit - MSE PerturbedMultiplicative" default_imports = false begin +@testitem "imit - MSE PerturbedMultiplicative" default_imports = false begin 
include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -88,7 +88,7 @@ end ) end -@testitem "Ranking - imit - MSE RegularizedFrankWolfe" default_imports = false begin +@testitem "imit - MSE RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -108,7 +108,7 @@ end ) end -@testitem "Ranking - imit - FYL PerturbedAdditive" default_imports = false begin +@testitem "imit - FYL PerturbedAdditive" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -123,7 +123,7 @@ end ) end -@testitem "Ranking - imit - FYL PerturbedMultiplicative" default_imports = false begin +@testitem "imit - FYL PerturbedMultiplicative" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -138,7 +138,7 @@ end ) end -@testitem "Ranking - imit - FYL PerturbedAdditive{LogNormal}" default_imports = false begin +@testitem "imit - FYL PerturbedAdditive{LogNormal}" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, Random, Distributions, LinearAlgebra Random.seed!(63) @@ -155,7 +155,7 @@ end ) end -@testitem "Ranking - imit - FYL RegularizedFrankWolfe" default_imports = false begin +@testitem "imit - FYL RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, Random Random.seed!(63) @@ -177,7 +177,7 @@ end ) end -@testitem "Ranking - exp - Pushforward PerturbedAdditive" default_imports = false begin +@testitem "exp - Pushforward PerturbedAdditive" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random Random.seed!(63) @@ -196,7 +196,7 @@ end ) end -@testitem "Ranking - exp - Pushforward PerturbedMultiplicative" default_imports = false begin +@testitem "exp - Pushforward PerturbedMultiplicative" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random Random.seed!(63) @@ -215,7 +215,7 @@ end ) end -@testitem "Ranking - exp - Pushforward PerturbedAdditive{LogNormal}" default_imports = false begin +@testitem "exp - Pushforward PerturbedAdditive{LogNormal}" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Distributions Random.seed!(63) @@ -238,8 +238,7 @@ end ) end -@testitem "Ranking - exp - Pushforward PerturbedMultiplicative{LogNormal}" default_imports = - false begin +@testitem "exp - Pushforward PerturbedMultiplicative{LogNormal}" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Distributions Random.seed!(63) @@ -264,7 +263,7 @@ end ) end -@testitem "Ranking - exp - Pushforward PerturbedOracle{LogNormal}" default_imports = false begin +@testitem "exp - Pushforward PerturbedOracle{LogNormal}" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Distributions Random.seed!(63) @@ -285,7 +284,7 @@ end ) 
end -@testitem "Ranking - exp - Pushforward RegularizedFrankWolfe" default_imports = false begin +@testitem "exp - Pushforward RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") using DifferentiableFrankWolfe, FrankWolfe, InferOpt, .InferOptTestUtils, LinearAlgebra, Random @@ -312,3 +311,83 @@ end cost=cost, ) end + +@testitem "exp - soft rank" default_imports = false begin + include("InferOptTestUtils/src/InferOptTestUtils.jl") + using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Test + Random.seed!(63) + + true_encoder = encoder_factory() + cost(y; instance) = dot(y, -true_encoder(instance)) + + Random.seed!(67) + soft_rank_l2_results = test_pipeline!( + PipelineLossExperience(); + instance_dim=5, + true_maximizer=ranking, + maximizer=SoftRank(), + loss=cost, + error_function=hamming_distance, + true_encoder=true_encoder, + cost=cost, + epochs=50, + ) + + Random.seed!(67) + soft_rank_kl_results = test_pipeline!( + PipelineLossExperience(); + instance_dim=5, + true_maximizer=ranking, + maximizer=SoftRank(; regularization="kl"), + loss=cost, + error_function=hamming_distance, + true_encoder=true_encoder, + cost=cost, + epochs=50, + ) + + Random.seed!(67) + perturbed_results = test_pipeline!( + PipelineLossExperience(); + instance_dim=5, + true_maximizer=ranking, + maximizer=identity_kw, + loss=Pushforward(PerturbedAdditive(ranking; ε=1.0, nb_samples=10), cost), + error_function=hamming_distance, + true_encoder=true_encoder, + cost=cost, + epochs=50, + ) + + # Check that we achieve better performance than the reinforce trick + @test soft_rank_l2_results.test_cost_gaps[end] < perturbed_results.test_cost_gaps[end] + @test soft_rank_kl_results.test_cost_gaps[end] < perturbed_results.test_cost_gaps[end] +end + +@testitem "imit - FYL - soft rank" default_imports = false begin + include("InferOptTestUtils/src/InferOptTestUtils.jl") + using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Test + Random.seed!(63) + + true_encoder = encoder_factory() + + test_pipeline!( + PipelineLossImitation(); + instance_dim=5, + true_maximizer=ranking, + maximizer=identity_kw, + loss=FenchelYoungLoss(SoftRank()), + error_function=hamming_distance, + true_encoder=true_encoder, + ) + + test_pipeline!( + PipelineLossImitation(); + instance_dim=5, + true_maximizer=ranking, + maximizer=identity_kw, + loss=FenchelYoungLoss(SoftRank(; regularization="kl", ε=10.0)), + error_function=hamming_distance, + true_encoder=true_encoder, + ) +end diff --git a/test/soft_rank.jl b/test/soft_rank.jl index dc36c81..7e51508 100644 --- a/test/soft_rank.jl +++ b/test/soft_rank.jl @@ -15,7 +15,7 @@ end using InferOpt, .InferOptTestUtils, Random, HiGHS, JuMP, Test Random.seed!(63) - function isotonic_custom(y) + function isotonic_jump(y) model = Model(HiGHS.Optimizer) set_silent(model) @@ -29,7 +29,7 @@ end for _ in 1:100 y = randn(1000) - x = isotonic_custom(y) + x = isotonic_jump(y) x2 = InferOpt.isotonic_l2(y) @test all(isapprox.(x, x2, atol=1e-2)) end @@ -68,90 +68,3 @@ end @test all(isapprox.(rank_jac, rank_jac_fd, atol=1e-4)) end end - -@testitem "Learn by experience soft rank" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Test - Random.seed!(63) - - true_encoder = encoder_factory() - cost(y; instance) = dot(y, -true_encoder(instance)) - - Random.seed!(67) - soft_rank_l2_results = test_pipeline!( - PipelineLossExperience(); - instance_dim=5, - 
true_maximizer=ranking, - maximizer=SoftRank(), - loss=cost, - error_function=hamming_distance, - true_encoder=true_encoder, - cost=cost, - epochs=50, - ) - - Random.seed!(67) - soft_rank_kl_results = test_pipeline!( - PipelineLossExperience(); - instance_dim=5, - true_maximizer=ranking, - maximizer=SoftRank(; regularization="kl"), - loss=cost, - error_function=hamming_distance, - true_encoder=true_encoder, - cost=cost, - epochs=50, - ) - - Random.seed!(67) - perturbed_results = test_pipeline!( - PipelineLossExperience(); - instance_dim=5, - true_maximizer=ranking, - maximizer=identity_kw, - loss=Pushforward(PerturbedAdditive(ranking; ε=1.0, nb_samples=10), cost), - error_function=hamming_distance, - true_encoder=true_encoder, - cost=cost, - epochs=50, - ) - - # Check that we achieve better performance than the reinforce trick - @test soft_rank_l2_results.test_cost_gaps[end] < perturbed_results.test_cost_gaps[end] - @test soft_rank_kl_results.test_cost_gaps[end] < perturbed_results.test_cost_gaps[end] -end - -@testitem "Fenchel-Young loss soft rank L2" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Test - Random.seed!(63) - - true_encoder = encoder_factory() - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=ranking, - maximizer=identity_kw, - loss=FenchelYoungLoss(SoftRank()), - error_function=hamming_distance, - true_encoder=true_encoder, - ) -end - -@testitem "Fenchel-Young loss soft rank kl" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Test - Random.seed!(63) - - true_encoder = encoder_factory() - - test_pipeline!( - PipelineLossImitation(); - instance_dim=5, - true_maximizer=ranking, - maximizer=identity_kw, - loss=FenchelYoungLoss(SoftRank(; regularization="kl", ε=10.0)), - error_function=hamming_distance, - true_encoder=true_encoder, - ) -end From 34eed2fd20e20f498e9b05726a1b342f2929d352 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Fri, 26 Apr 2024 16:03:58 +0200 Subject: [PATCH 03/26] reorganize tests --- ...e.jl => abstract_regularized_interface.jl} | 1 + test/jacobian_approx.jl | 42 ------- ...r.jl => learning_generalized_maximizer.jl} | 0 ...anking_learning.jl => learning_ranking.jl} | 0 test/perturbed.jl | 111 ++++++++++++++++++ test/perturbed_oracle.jl | 44 ------- 6 files changed, 112 insertions(+), 86 deletions(-) rename test/{interface.jl => abstract_regularized_interface.jl} (84%) delete mode 100644 test/jacobian_approx.jl rename test/{generalized_maximizer.jl => learning_generalized_maximizer.jl} (100%) rename test/{ranking_learning.jl => learning_ranking.jl} (100%) create mode 100644 test/perturbed.jl delete mode 100644 test/perturbed_oracle.jl diff --git a/test/interface.jl b/test/abstract_regularized_interface.jl similarity index 84% rename from test/interface.jl rename to test/abstract_regularized_interface.jl index 0e6b8b6..a6d1687 100644 --- a/test/interface.jl +++ b/test/abstract_regularized_interface.jl @@ -5,4 +5,5 @@ @test RI.check_interface_implemented(AbstractRegularized, RegularizedFrankWolfe) @test RI.check_interface_implemented(AbstractRegularized, SoftArgmax) @test RI.check_interface_implemented(AbstractRegularized, SparseArgmax) + @test RI.check_interface_implemented(AbstractRegularized, SoftRank) end diff --git a/test/jacobian_approx.jl b/test/jacobian_approx.jl deleted file mode 100644 index abd80ed..0000000 --- 
a/test/jacobian_approx.jl +++ /dev/null @@ -1,42 +0,0 @@ -@testitem "Jacobian approx" begin - using LinearAlgebra - using Random - using Test - using Zygote - - θ = [3, 5, 4, 2] - - perturbed1 = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=1_000, seed=0) - perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=10_000, seed=0) - perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=0.5, nb_samples=1_000, seed=0) - perturbed2_big = PerturbedMultiplicative( - one_hot_argmax; ε=0.5, nb_samples=10_000, seed=0 - ) - - @testset "PerturbedAdditive" begin - # Compute jacobian with reverse mode - jac1 = Zygote.jacobian(θ -> perturbed1(θ; autodiff_variance_reduction=false), θ)[1] - jac1_big = Zygote.jacobian( - θ -> perturbed1_big(θ; autodiff_variance_reduction=false), θ - )[1] - # Only diagonal should be positive - @test all(diag(jac1) .>= 0) - @test all(jac1 - Diagonal(jac1) .<= 0) - # Order of diagonal coefficients should follow order of θ - @test sortperm(diag(jac1)) == sortperm(θ) - # No scaling with nb of samples - @test norm(jac1) ≈ norm(jac1_big) rtol = 1e-2 - end - - @testset "PerturbedMultiplicative" begin - jac2 = Zygote.jacobian(θ -> perturbed2(θ; autodiff_variance_reduction=false), θ)[1] - jac2_big = Zygote.jacobian( - θ -> perturbed2_big(θ; autodiff_variance_reduction=false), θ - )[1] - @test all(diag(jac2) .>= 0) - @test all(jac2 - Diagonal(jac2) .<= 0) - @test sortperm(diag(jac2)) != sortperm(θ) - # This is not equal because the diagonal coefficient for θ₃ = 4 is often larger than the one for θ₂ = 5. It happens because θ₃ has the opportunity to *become* the argmax (and hence switch from 0 to 1), whereas θ₂ already *is* the argmax. - @test norm(jac2) ≈ norm(jac2_big) rtol = 1e-2 - end -end diff --git a/test/generalized_maximizer.jl b/test/learning_generalized_maximizer.jl similarity index 100% rename from test/generalized_maximizer.jl rename to test/learning_generalized_maximizer.jl diff --git a/test/ranking_learning.jl b/test/learning_ranking.jl similarity index 100% rename from test/ranking_learning.jl rename to test/learning_ranking.jl diff --git a/test/perturbed.jl b/test/perturbed.jl new file mode 100644 index 0000000..7a09e16 --- /dev/null +++ b/test/perturbed.jl @@ -0,0 +1,111 @@ +@testitem "Jacobian approx" begin + using LinearAlgebra + using Random + using Test + using Zygote + + θ = [3, 5, 4, 2] + + perturbed1 = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=1_000, seed=0) + perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=10_000, seed=0) + perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=0.5, nb_samples=1_000, seed=0) + perturbed2_big = PerturbedMultiplicative( + one_hot_argmax; ε=0.5, nb_samples=10_000, seed=0 + ) + + @testset "PerturbedAdditive" begin + # Compute jacobian with reverse mode + jac1 = Zygote.jacobian(θ -> perturbed1(θ; autodiff_variance_reduction=false), θ)[1] + jac1_big = Zygote.jacobian( + θ -> perturbed1_big(θ; autodiff_variance_reduction=false), θ + )[1] + # Only diagonal should be positive + @test all(diag(jac1) .>= 0) + @test all(jac1 - Diagonal(jac1) .<= 0) + # Order of diagonal coefficients should follow order of θ + @test sortperm(diag(jac1)) == sortperm(θ) + # No scaling with nb of samples + @test norm(jac1) ≈ norm(jac1_big) rtol = 1e-2 + end + + @testset "PerturbedMultiplicative" begin + jac2 = Zygote.jacobian(θ -> perturbed2(θ; autodiff_variance_reduction=false), θ)[1] + jac2_big = Zygote.jacobian( + θ -> perturbed2_big(θ; autodiff_variance_reduction=false), θ + )[1] + @test all(diag(jac2) .>= 0) + 
@test all(jac2 - Diagonal(jac2) .<= 0) + @test sortperm(diag(jac2)) != sortperm(θ) + # This is not equal because the diagonal coefficient for θ₃ = 4 is often larger than the one for θ₂ = 5. It happens because θ₃ has the opportunity to *become* the argmax (and hence switch from 0 to 1), whereas θ₂ already *is* the argmax. + @test norm(jac2) ≈ norm(jac2_big) rtol = 1e-2 + end +end + +@testitem "PerturbedOracle vs PerturbedAdditive" default_imports = false begin + include("InferOptTestUtils/src/InferOptTestUtils.jl") + using InferOpt, .InferOptTestUtils, Random, Test + using LinearAlgebra, Zygote, Distributions + Random.seed!(63) + + ε = 1.0 + p(θ) = MvNormal(θ, ε^2 * I) + oracle(η) = η + + po = PerturbedOracle(oracle, p; nb_samples=1_000, seed=0) + pa = PerturbedAdditive(oracle; ε, nb_samples=1_000, seed=0) + + θ = randn(10) + @test po(θ) ≈ pa(θ) rtol = 0.001 + @test all(isapprox.(jacobian(po, θ), jacobian(pa, θ), rtol=0.001)) +end + +@testitem "Variance reduction" default_imports = false begin + include("InferOptTestUtils/src/InferOptTestUtils.jl") + using InferOpt, .InferOptTestUtils, Random, Test + using LinearAlgebra, Zygote + Random.seed!(63) + + ε = 1.0 + oracle(η) = η + + pa = PerturbedAdditive(oracle; ε, nb_samples=100, seed=0) + pm = PerturbedMultiplicative(oracle; ε, nb_samples=100, seed=0) + + n = 10 + θ = randn(10) + + Ja = jacobian(θ -> pa(θ; autodiff_variance_reduction=false), θ)[1] + Ja_reduced_variance = jacobian(pa, θ)[1] + + Jm = jacobian(x -> pm(x; autodiff_variance_reduction=false), θ)[1] + Jm_reduced_variance = jacobian(pm, θ)[1] + + J_true = Matrix(I, n, n) # exact jacobian is the identity matrix + + @test normalized_mape(Ja, J_true) > normalized_mape(Ja_reduced_variance, J_true) + @test normalized_mape(Jm, J_true) > normalized_mape(Jm_reduced_variance, J_true) +end + +@testitem "Perturbed - small ε convergence" default_imports = false begin + include("InferOptTestUtils/src/InferOptTestUtils.jl") + using InferOpt, .InferOptTestUtils, Random, Test + using LinearAlgebra, Zygote + Random.seed!(63) + + ε = 1e-12 + + function already_differentiable(θ) + return 2 ./ exp.(θ) .* θ .^ 2 + end + + θ = randn(5) + Jz = jacobian(already_differentiable, θ)[1] + + pa = PerturbedAdditive(already_differentiable; ε, nb_samples=1e6, seed=0) + Ja = jacobian(pa, θ)[1] + @test_broken all(isapprox.(Ja, Jz, rtol=0.01)) + + pm = PerturbedMultiplicative(already_differentiable; ε, nb_samples=1e6, seed=0) + Jm = jacobian(pm, θ)[1] + @test_broken all(isapprox.(Jm, Jz, rtol=0.01)) +end diff --git a/test/perturbed_oracle.jl b/test/perturbed_oracle.jl deleted file mode 100644 index ef1bccb..0000000 --- a/test/perturbed_oracle.jl +++ /dev/null @@ -1,44 +0,0 @@ -@testitem "PerturbedOracle vs PerturbedAdditive" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random, Test - using LinearAlgebra, Zygote, Distributions - Random.seed!(63) - - ε = 1.0 - p(θ) = MvNormal(θ, ε^2 * I) - oracle(η) = η - - po = PerturbedOracle(oracle, p; nb_samples=1_000, seed=0) - pa = PerturbedAdditive(oracle; ε, nb_samples=1_000, seed=0) - - θ = randn(10) - @test po(θ) ≈ pa(θ) rtol = 0.001 - @test all(isapprox.(jacobian(po, θ), jacobian(pa, θ), rtol=0.001)) -end - -@testitem "Variance reduction" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, Random, Test - using LinearAlgebra, Zygote - Random.seed!(63) - - ε = 1.0 - oracle(η) = η - - pa = PerturbedAdditive(oracle; ε, 
nb_samples=100, seed=0) - pm = PerturbedAdditive(oracle; ε, nb_samples=100, seed=0) - - n = 10 - θ = randn(10) - - Ja = jacobian(θ -> pa(θ; autodiff_variance_reduction=false), θ)[1] - Ja_reduced_variance = jacobian(pa, θ)[1] - - Jm = jacobian(x -> pm(x; autodiff_variance_reduction=false), θ)[1] - Jm_reduced_variance = jacobian(pm, θ)[1] - - J_true = Matrix(I, n, n) # exact jacobian is the identity matrix - - @test normalized_mape(Ja, J_true) > normalized_mape(Ja_reduced_variance, J_true) - @test normalized_mape(Jm, J_true) > normalized_mape(Jm_reduced_variance, J_true) -end From 09cfe8dad8e1885f6f4c45561def85dae25a5717 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Tue, 30 Jul 2024 17:29:47 +0200 Subject: [PATCH 04/26] wip --- Project.toml | 3 + src/InferOpt.jl | 9 +- ...ct_perturbed.jl => _abstract_perturbed.jl} | 0 src/layers/perturbed/_additive.jl | 128 +++++++++++++ src/layers/perturbed/_multiplicative.jl | 130 ++++++++++++++ ...rturbed_oracle.jl => _perturbed_oracle.jl} | 0 src/layers/perturbed/additive.jl | 128 +------------ src/layers/perturbed/multiplicative.jl | 135 ++------------ src/layers/perturbed/perturbed.jl | 65 +++++++ src/losses/_fenchel_young_loss.jl | 169 ++++++++++++++++++ src/losses/fenchel_young_loss.jl | 166 ++++++++--------- 11 files changed, 606 insertions(+), 327 deletions(-) rename src/layers/perturbed/{abstract_perturbed.jl => _abstract_perturbed.jl} (100%) create mode 100644 src/layers/perturbed/_additive.jl create mode 100644 src/layers/perturbed/_multiplicative.jl rename src/layers/perturbed/{perturbed_oracle.jl => _perturbed_oracle.jl} (100%) create mode 100644 src/layers/perturbed/perturbed.jl create mode 100644 src/losses/_fenchel_young_loss.jl diff --git a/Project.toml b/Project.toml index a11eaf2..55fc110 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,9 @@ version = "0.6.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +DifferentiableExpectations = "fc55d66b-b2a8-4ccc-9d64-c0c2166ceb36" DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6" @@ -24,6 +26,7 @@ InferOptFrankWolfeExt = "DifferentiableFrankWolfe" [compat] ChainRulesCore = "1" DensityInterface = "0.4.0" +DifferentiableExpectations = "0.1" DifferentiableFrankWolfe = "0.2" LinearAlgebra = "<0.0.1,1" Random = "<0.0.1,1" diff --git a/src/InferOpt.jl b/src/InferOpt.jl index c7f3e1a..44607b6 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -10,8 +10,11 @@ module InferOpt using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, Tangent, ZeroTangent using ChainRulesCore: rrule, rrule_via_ad, unthunk using DensityInterface: logdensityof +using DifferentiableExpectations: Reinforce, empirical_predistribution, FixKwargs +using Distributions: + Distributions, ContinuousUnivariateDistribution, LogNormal, Normal, product_distribution using LinearAlgebra: dot -using Random: AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed! +using Random: Random, AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed! 
 using Statistics: mean
 using StatsBase: StatsBase, sample
 using StatsFuns: logaddexp, softmax
@@ -32,10 +35,10 @@ include("utils/isotonic_regression/projection.jl")
 include("layers/simple/interpolation.jl")
 include("layers/simple/identity.jl")
 
-include("layers/perturbed/abstract_perturbed.jl")
+# include("layers/perturbed/abstract_perturbed.jl")
 include("layers/perturbed/additive.jl")
 include("layers/perturbed/multiplicative.jl")
-include("layers/perturbed/perturbed_oracle.jl")
+include("layers/perturbed/perturbed.jl")
 
 include("layers/regularized/abstract_regularized.jl")
 include("layers/regularized/soft_argmax.jl")
diff --git a/src/layers/perturbed/abstract_perturbed.jl b/src/layers/perturbed/_abstract_perturbed.jl
similarity index 100%
rename from src/layers/perturbed/abstract_perturbed.jl
rename to src/layers/perturbed/_abstract_perturbed.jl
diff --git a/src/layers/perturbed/_additive.jl b/src/layers/perturbed/_additive.jl
new file mode 100644
index 0000000..e9923d5
--- /dev/null
+++ b/src/layers/perturbed/_additive.jl
@@ -0,0 +1,128 @@
+"""
+    PerturbedAdditive{P,G,O,R,S,parallel} <: AbstractPerturbed{O,parallel}
+
+Differentiable normal perturbation of a black-box maximizer: the input undergoes `θ -> θ + εZ` where `Z ∼ N(0, I)`.
+
+This [`AbstractOptimizationLayer`](@ref) is compatible with [`FenchelYoungLoss`](@ref),
+if the oracle is an optimization maximizer with a linear objective.
+
+Reference:
+
+See [`AbstractPerturbed`](@ref) for more details.
+
+# Specific field
+- `ε::Float64`: size of the perturbation
+"""
+struct PerturbedAdditive{P,G,O,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <:
+       AbstractPerturbed{O,parallel}
+    perturbation::P
+    grad_logdensity::G
+    oracle::O
+    rng::R
+    seed::S
+    nb_samples::Int
+    ε::Float64
+end
+
+function Base.show(io::IO, perturbed::PerturbedAdditive)
+    (; oracle, ε, rng, seed, nb_samples, perturbation) = perturbed
+    perturb = isnothing(perturbation) ? "Normal(0, 1)" : "$perturbation"
+    return print(
+        io, "PerturbedAdditive($oracle, $ε, $nb_samples, $(typeof(rng)), $seed, $perturb)"
+    )
+end
+
+"""
+    PerturbedAdditive(oracle[; ε, nb_samples, seed, is_parallel, perturbation, grad_logdensity, rng])
+
+[`PerturbedAdditive`](@ref) constructor.
+
+# Arguments
+- `oracle`: the black-box oracle we want to differentiate through.
+    It should be a linear maximizer if you want to use it inside a [`FenchelYoungLoss`](@ref).
+
+# Keyword arguments (optional)
+- `ε=1.0`: size of the perturbation.
+- `nb_samples::Int=1`: number of perturbation samples drawn at each forward pass.
+- `perturbation=nothing`: nothing by default. If you want to use a different distribution than a
+    `Normal` for the perturbation `z`, give it here as a distribution-like object implementing
+    the `rand` method. It should also implement `logdensityof` if `grad_logdensity` is not given.
+- `grad_logdensity=nothing`: gradient function of `perturbation` w.r.t. `θ`.
+    If set to nothing (default), it's computed using automatic differentiation.
+- `seed::Union{Nothing,Int}=nothing`: seed of the perturbation.
+    It is reset each time the forward pass is called,
+    making it deterministic by always drawing the same perturbations.
+    If you do not want this behaviour, set this field to `nothing`.
+- `rng::AbstractRNG=MersenneTwister(0)`: random number generator using the `seed`.
+""" +function PerturbedAdditive( + oracle::O; + ε=1.0, + nb_samples=1, + seed::S=nothing, + is_parallel::Bool=false, + perturbation::P=nothing, + grad_logdensity::G=nothing, + rng::R=MersenneTwister(0), +) where {P,G,O,R<:AbstractRNG,S<:Union{Int,Nothing}} + return PerturbedAdditive{P,G,O,R,S,is_parallel}( + perturbation, grad_logdensity, oracle, rng, seed, nb_samples, float(ε) + ) +end + +function sample_perturbations(perturbed::PerturbedAdditive, θ::AbstractArray) + (; rng, seed, nb_samples, perturbation) = perturbed + seed!(rng, seed) + return [rand(rng, perturbation, size(θ)) for _ in 1:nb_samples] +end + +function sample_perturbations(perturbed::PerturbedAdditive{Nothing}, θ::AbstractArray) + (; rng, seed, nb_samples) = perturbed + seed!(rng, seed) + return [randn(rng, size(θ)) for _ in 1:nb_samples] +end + +function compute_probability_distribution_from_samples( + perturbed::PerturbedAdditive, + θ::AbstractArray, + Z_samples::Vector{<:AbstractArray}; + kwargs..., +) + (; ε) = perturbed + η_samples = [θ .+ ε .* Z for Z in Z_samples] + atoms = compute_atoms(perturbed, η_samples; kwargs...) + weights = ones(length(atoms)) ./ length(atoms) + probadist = FixedAtomsProbabilityDistribution(atoms, weights) + return probadist +end + +function perturbation_grad_logdensity( + ::RuleConfig, + perturbed::PerturbedAdditive{Nothing,Nothing}, + θ::AbstractArray, + Z::AbstractArray, +) + (; ε) = perturbed + return Z ./ ε +end + +function _perturbation_logdensity( + perturbed::PerturbedAdditive, θ::AbstractArray, η::AbstractArray +) + (; ε, perturbation) = perturbed + Z = (η .- θ) ./ ε + return sum(logdensityof(perturbation, z) for z in Z) +end + +function perturbation_grad_logdensity( + rc::RuleConfig, + perturbed::PerturbedAdditive{P,Nothing}, + θ::AbstractArray, + Z::AbstractArray, +) where {P} + (; ε) = perturbed + η = θ .+ ε .* Z + l, logdensity_pullback = rrule_via_ad(rc, _perturbation_logdensity, perturbed, θ, η) + δperturbation_logdensity, δperturbed, δθ, δη = logdensity_pullback(one(l)) + return δθ +end diff --git a/src/layers/perturbed/_multiplicative.jl b/src/layers/perturbed/_multiplicative.jl new file mode 100644 index 0000000..f96a333 --- /dev/null +++ b/src/layers/perturbed/_multiplicative.jl @@ -0,0 +1,130 @@ +""" + PerturbedMultiplicative{P,G,O,R,S,parallel} <: AbstractPerturbed{parallel} + +Differentiable multiplicative perturbation of a black-box oracle: +the input undergoes `θ -> θ ⊙ exp[εZ - ε²/2]` where `Z ∼ perturbation`. + +This [`AbstractOptimizationLayer`](@ref) is compatible with [`FenchelYoungLoss`](@ref), +if the oracle is an optimization maximizer with a linear objective. + +Reference: + +See [`AbstractPerturbed`](@ref) for more details. + +# Specific field +- `ε:Float64`: size of the perturbation +""" +struct PerturbedMultiplicative{P,G,O,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <: + AbstractPerturbed{O,parallel} + perturbation::P + grad_logdensity::G + oracle::O + rng::R + seed::S + nb_samples::Int + ε::Float64 +end + +function Base.show(io::IO, perturbed::PerturbedMultiplicative) + (; oracle, ε, rng, seed, nb_samples, perturbation) = perturbed + perturb = isnothing(perturbation) ? "Normal(0, 1)" : "$perturbation" + return print( + io, + "PerturbedMultiplicative($oracle, $ε, $nb_samples, $(typeof(rng)), $seed, $perturb)", + ) +end + +""" + PerturbedMultiplicative(oracle[; ε, nb_samples, seed, is_parallel, perturbation, grad_logdensity, rng]) + +[`PerturbedMultiplicative`](@ref) constructor. 
+
+# Arguments
+- `oracle`: the black-box oracle we want to differentiate through.
+    It should be a linear maximizer if you want to use it inside a [`FenchelYoungLoss`](@ref).
+
+# Keyword arguments (optional)
+- `ε=1.0`: size of the perturbation.
+- `nb_samples::Int=1`: number of perturbation samples drawn at each forward pass.
+- `perturbation=nothing`: nothing by default. If you want to use a different distribution than a
+    `Normal` for the perturbation `z`, give it here as a distribution-like object implementing
+    the `rand` method. It should also implement `logdensityof` if `grad_logdensity` is not given.
+- `grad_logdensity=nothing`: gradient function of `perturbation` w.r.t. `θ`.
+    If set to nothing (default), it's computed using automatic differentiation.
+- `seed::Union{Nothing,Int}=nothing`: seed of the perturbation.
+    It is reset each time the forward pass is called,
+    making it deterministic by always drawing the same perturbations.
+    If you do not want this behaviour, set this field to `nothing`.
+- `rng::AbstractRNG=MersenneTwister(0)`: random number generator using the `seed`.
+"""
+function PerturbedMultiplicative(
+    oracle::O;
+    ε=1.0,
+    nb_samples=1,
+    seed::S=nothing,
+    is_parallel=false,
+    perturbation::P=nothing,
+    grad_logdensity::G=nothing,
+    rng::R=MersenneTwister(0),
+) where {P,G,O,R<:AbstractRNG,S<:Union{Int,Nothing}}
+    return PerturbedMultiplicative{P,G,O,R,S,is_parallel}(
+        perturbation, grad_logdensity, oracle, rng, seed, nb_samples, float(ε)
+    )
+end
+
+function sample_perturbations(perturbed::PerturbedMultiplicative, θ::AbstractArray)
+    (; rng, seed, nb_samples, perturbation) = perturbed
+    seed!(rng, seed)
+    return [rand(rng, perturbation, size(θ)) for _ in 1:nb_samples]
+end
+
+function sample_perturbations(perturbed::PerturbedMultiplicative{Nothing}, θ::AbstractArray)
+    (; rng, seed, nb_samples) = perturbed
+    seed!(rng, seed)
+    return [randn(rng, size(θ)) for _ in 1:nb_samples]
+end
+
+function compute_probability_distribution_from_samples(
+    perturbed::PerturbedMultiplicative,
+    θ::AbstractArray,
+    Z_samples::Vector{<:AbstractArray};
+    kwargs...,
+)
+    (; ε) = perturbed
+    η_samples = [θ .* exp.(ε .* Z .- ε^2 / 2) for Z in Z_samples]
+    atoms = compute_atoms(perturbed, η_samples; kwargs...)
+ weights = ones(length(atoms)) ./ length(atoms) + probadist = FixedAtomsProbabilityDistribution(atoms, weights) + return probadist +end + +function perturbation_grad_logdensity( + ::RuleConfig, + perturbed::PerturbedMultiplicative{Nothing,Nothing}, + θ::AbstractArray, + Z::AbstractArray, +) + (; ε) = perturbed + return inv.(ε .* θ) .* Z +end + +function _perturbation_logdensity( + perturbed::PerturbedMultiplicative, θ::AbstractArray, η::AbstractArray +) + (; ε, perturbation) = perturbed + Z = (log.(η) .- log.(θ)) ./ ε .+ ε / 2 + return sum(logdensityof(perturbation, z) for z in Z) +end + +function perturbation_grad_logdensity( + rc::RuleConfig, + perturbed::PerturbedMultiplicative{P,Nothing}, + θ::AbstractArray, + Z::AbstractArray, +) where {P} + (; ε) = perturbed + η = θ .* exp.(ε .* Z .- ε^2 / 2) + l, logdensity_pullback = rrule_via_ad(rc, _perturbation_logdensity, perturbed, θ, η) + δperturbation_logdensity, δperturbed, δθ, δη = logdensity_pullback(one(l)) + return δθ +end diff --git a/src/layers/perturbed/perturbed_oracle.jl b/src/layers/perturbed/_perturbed_oracle.jl similarity index 100% rename from src/layers/perturbed/perturbed_oracle.jl rename to src/layers/perturbed/_perturbed_oracle.jl diff --git a/src/layers/perturbed/additive.jl b/src/layers/perturbed/additive.jl index e9923d5..c79d0b7 100644 --- a/src/layers/perturbed/additive.jl +++ b/src/layers/perturbed/additive.jl @@ -1,128 +1,12 @@ -""" - PerturbedAdditive{P,G,O,R,S,parallel} <: AbstractPerturbed{parallel} - -Differentiable normal perturbation of a black-box maximizer: the input undergoes `θ -> θ + εZ` where `Z ∼ N(0, I)`. - -This [`AbstractOptimizationLayer`](@ref) is compatible with [`FenchelYoungLoss`](@ref), -if the oracle is an optimization maximizer with a linear objective. - -Reference: - -See [`AbstractPerturbed`](@ref) for more details. - -# Specific field -- `ε:Float64`: size of the perturbation -""" -struct PerturbedAdditive{P,G,O,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <: - AbstractPerturbed{O,parallel} - perturbation::P - grad_logdensity::G - oracle::O - rng::R - seed::S - nb_samples::Int +struct AdditivePerturbation{F} + perturbation_dist::F ε::Float64 end -function Base.show(io::IO, perturbed::PerturbedAdditive) - (; oracle, ε, rng, seed, nb_samples, perturbation) = perturbed - perturb = isnothing(perturbation) ? "Normal(0, 1)" : "$perturbation" - return print( - io, "PerturbedAdditive($oracle, $ε, $nb_samples, $(typeof(rng)), $seed, $perturb)" - ) -end - """ - PerturbedAdditive(oracle[; ε, nb_samples, seed, is_parallel, perturbation, grad_logdensity, rng]) - -[`PerturbedAdditive`](@ref) constructor. - -# Arguments -- `oracle`: the black-box oracle we want to differentiate through. - It should be a linear maximizer if you want to use it inside a [`FenchelYoungLoss`](@ref). - -# Keyword arguments (optional) -- `ε=1.0`: size of the perturbation. -- `nb_samples::Int=1`: number of perturbation samples drawn at each forward pass. -- `perturbation=nothing`: nothing by default. If you want to use a different distribution than a - `Normal` for the perturbation `z`, give it here as a distribution-like object implementing - the `rand` method. It should also implement `logdensityof` if `grad_logdensity` is not given. -- `grad_logdensity=nothing`: gradient function of `perturbation` w.r.t. `θ`. - If set to nothing (default), it's computed using automatic differentiation. -- `seed::Union{Nothing,Int}=nothing`: seed of the perturbation. 
- It is reset each time the forward pass is called, - making it deterministic by always drawing the same perturbations. - If you do not want this behaviour, set this field to `nothing`. -- `rng::AbstractRNG`=MersenneTwister(0): random number generator using the `seed`. +θ + εZ """ -function PerturbedAdditive( - oracle::O; - ε=1.0, - nb_samples=1, - seed::S=nothing, - is_parallel::Bool=false, - perturbation::P=nothing, - grad_logdensity::G=nothing, - rng::R=MersenneTwister(0), -) where {P,G,O,R<:AbstractRNG,S<:Union{Int,Nothing}} - return PerturbedAdditive{P,G,O,R,S,is_parallel}( - perturbation, grad_logdensity, oracle, rng, seed, nb_samples, float(ε) - ) -end - -function sample_perturbations(perturbed::PerturbedAdditive, θ::AbstractArray) - (; rng, seed, nb_samples, perturbation) = perturbed - seed!(rng, seed) - return [rand(rng, perturbation, size(θ)) for _ in 1:nb_samples] -end - -function sample_perturbations(perturbed::PerturbedAdditive{Nothing}, θ::AbstractArray) - (; rng, seed, nb_samples) = perturbed - seed!(rng, seed) - return [randn(rng, size(θ)) for _ in 1:nb_samples] -end - -function compute_probability_distribution_from_samples( - perturbed::PerturbedAdditive, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) - (; ε) = perturbed - η_samples = [θ .+ ε .* Z for Z in Z_samples] - atoms = compute_atoms(perturbed, η_samples; kwargs...) - weights = ones(length(atoms)) ./ length(atoms) - probadist = FixedAtomsProbabilityDistribution(atoms, weights) - return probadist -end - -function perturbation_grad_logdensity( - ::RuleConfig, - perturbed::PerturbedAdditive{Nothing,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) - (; ε) = perturbed - return Z ./ ε -end - -function _perturbation_logdensity( - perturbed::PerturbedAdditive, θ::AbstractArray, η::AbstractArray -) - (; ε, perturbation) = perturbed - Z = (η .- θ) ./ ε - return sum(logdensityof(perturbation, z) for z in Z) -end - -function perturbation_grad_logdensity( - rc::RuleConfig, - perturbed::PerturbedAdditive{P,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) where {P} - (; ε) = perturbed - η = θ .+ ε .* Z - l, logdensity_pullback = rrule_via_ad(rc, _perturbation_logdensity, perturbed, θ, η) - δperturbation_logdensity, δperturbed, δθ, δη = logdensity_pullback(one(l)) - return δθ +function (pdc::AdditivePerturbation)(θ::AbstractArray) + (; perturbation_dist, ε) = pdc + return product_distribution(θ .+ ε * perturbation_dist) end diff --git a/src/layers/perturbed/multiplicative.jl b/src/layers/perturbed/multiplicative.jl index f96a333..8fe9b69 100644 --- a/src/layers/perturbed/multiplicative.jl +++ b/src/layers/perturbed/multiplicative.jl @@ -1,130 +1,25 @@ -""" - PerturbedMultiplicative{P,G,O,R,S,parallel} <: AbstractPerturbed{parallel} - -Differentiable multiplicative perturbation of a black-box oracle: -the input undergoes `θ -> θ ⊙ exp[εZ - ε²/2]` where `Z ∼ perturbation`. - -This [`AbstractOptimizationLayer`](@ref) is compatible with [`FenchelYoungLoss`](@ref), -if the oracle is an optimization maximizer with a linear objective. - -Reference: - -See [`AbstractPerturbed`](@ref) for more details. 
+struct ExponentialOf{D<:ContinuousUnivariateDistribution} <: + ContinuousUnivariateDistribution + dist::D +end -# Specific field -- `ε:Float64`: size of the perturbation -""" -struct PerturbedMultiplicative{P,G,O,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <: - AbstractPerturbed{O,parallel} - perturbation::P - grad_logdensity::G - oracle::O - rng::R - seed::S - nb_samples::Int - ε::Float64 +function Random.rand(rng::AbstractRNG, d::ExponentialOf) + return exp(rand(rng, d.dist)) end -function Base.show(io::IO, perturbed::PerturbedMultiplicative) - (; oracle, ε, rng, seed, nb_samples, perturbation) = perturbed - perturb = isnothing(perturbation) ? "Normal(0, 1)" : "$perturbation" - return print( - io, - "PerturbedMultiplicative($oracle, $ε, $nb_samples, $(typeof(rng)), $seed, $perturb)", - ) +function Distributions.logpdf(d::ExponentialOf, x::Real) + return logpdf(d.dist, log(x)) - log(x) end """ - PerturbedMultiplicative(oracle[; ε, nb_samples, seed, is_parallel, perturbation, grad_logdensity, rng]) - -[`PerturbedMultiplicative`](@ref) constructor. - -# Arguments -- `oracle`: the black-box oracle we want to differentiate through. - It should be a linear maximizer if you want to use it inside a [`FenchelYoungLoss`]. - -# Keyword arguments (optional) -- `ε=1.0`: size of the perturbation. -- `nb_samples::Int=1`: number of perturbation samples drawn at each forward pass. -- `perturbation=nothing`: nothing by default. If you want to use a different distribution than a - `Normal` for the perturbation `z`, give it here as a distribution-like object implementing - the `rand` method. It should also implement `logdensityof` if `grad_logdensity` is not given. -- `grad_logdensity=nothing`: gradient function of `perturbation` w.r.t. `θ`. - If set to nothing (default), it's computed using automatic differentiation. -- `seed::Union{Nothing,Int}=nothing`: seed of the perturbation. - It is reset each time the forward pass is called, - making it deterministic by always drawing the same perturbations. - If you do not want this behaviour, set this field to `nothing`. -- `rng::AbstractRNG`=MersenneTwister(0): random number generator using the `seed`. +θ ⊙ exp(εZ - ε²/2) """ -function PerturbedMultiplicative( - oracle::O; - ε=1.0, - nb_samples=1, - seed::S=nothing, - is_parallel=false, - perturbation::P=nothing, - grad_logdensity::G=nothing, - rng::R=MersenneTwister(0), -) where {P,G,O,R<:AbstractRNG,S<:Union{Int,Nothing}} - return PerturbedMultiplicative{P,G,O,R,S,is_parallel}( - perturbation, grad_logdensity, oracle, rng, seed, nb_samples, float(ε) - ) -end - -function sample_perturbations(perturbed::PerturbedMultiplicative, θ::AbstractArray) - (; rng, seed, nb_samples, perturbation) = perturbed - seed!(rng, seed) - return [rand(rng, perturbation, size(θ)) for _ in 1:nb_samples] -end - -function sample_perturbations(perturbed::PerturbedMultiplicative{Nothing}, θ::AbstractArray) - (; rng, seed, nb_samples) = perturbed - seed!(rng, seed) - return [randn(rng, size(θ)) for _ in 1:nb_samples] -end - -function compute_probability_distribution_from_samples( - perturbed::PerturbedMultiplicative, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) - (; ε) = perturbed - η_samples = [θ .* exp.(ε .* Z .- ε^2 / 2) for Z in Z_samples] - atoms = compute_atoms(perturbed, η_samples; kwargs...) 
- weights = ones(length(atoms)) ./ length(atoms) - probadist = FixedAtomsProbabilityDistribution(atoms, weights) - return probadist -end - -function perturbation_grad_logdensity( - ::RuleConfig, - perturbed::PerturbedMultiplicative{Nothing,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) - (; ε) = perturbed - return inv.(ε .* θ) .* Z -end - -function _perturbation_logdensity( - perturbed::PerturbedMultiplicative, θ::AbstractArray, η::AbstractArray -) - (; ε, perturbation) = perturbed - Z = (log.(η) .- log.(θ)) ./ ε .+ ε / 2 - return sum(logdensityof(perturbation, z) for z in Z) +struct MultiplicativePerturbation{F} + perturbation_dist::F + ε::Float64 end -function perturbation_grad_logdensity( - rc::RuleConfig, - perturbed::PerturbedMultiplicative{P,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) where {P} - (; ε) = perturbed - η = θ .* exp.(ε .* Z .- ε^2 / 2) - l, logdensity_pullback = rrule_via_ad(rc, _perturbation_logdensity, perturbed, θ, η) - δperturbation_logdensity, δperturbed, δθ, δη = logdensity_pullback(one(l)) - return δθ +function (pdc::MultiplicativePerturbation)(θ::AbstractArray) + (; perturbation_dist, ε) = pdc + return product_distribution(θ .* ExponentialOf(ε * perturbation_dist - ε^2 / 2)) end diff --git a/src/layers/perturbed/perturbed.jl b/src/layers/perturbed/perturbed.jl new file mode 100644 index 0000000..e063e4a --- /dev/null +++ b/src/layers/perturbed/perturbed.jl @@ -0,0 +1,65 @@ +struct Perturbed{R<:Reinforce} <: AbstractOptimizationLayer + reinforce::R +end + +function (perturbed::Perturbed)(θ::AbstractArray) + return perturbed.reinforce(θ) +end + +function is_additive(perturbed::Perturbed) + return isa(perturbed.reinforce.dist_constructor, AdditivePerturbation) +end + +function is_multiplicative(perturbed::Perturbed) + return isa(perturbed.reinforce.dist_constructor, MultiplicativePerturbation) +end + +function Base.show(io::IO, perturbed::Perturbed) + (; reinforce) = perturbed + nb_samples = reinforce.nb_samples + ε = reinforce.dist_constructor.ε + seed = reinforce.seed + rng = reinforce.rng + perturbation = reinforce.dist_constructor.perturbation_dist + f = reinforce.f + return print( + io, + "Perturbed($f, ε=$ε, nb_samples=$nb_samples, perturbation=$perturbation, rng=$(typeof(rng)), seed=$seed)", + ) +end + +function PerturbedAdditive( + maximizer; + ε=1.0, + perturbation_dist=Normal(0, 1), + nb_samples=1, + variance_reduction=true, + seed=nothing, + threaded=false, + rng=Random.default_rng(), +) + dist_constructor = AdditivePerturbation(perturbation_dist, float(ε)) + return Perturbed( + Reinforce( + maximizer, dist_constructor; variance_reduction, seed, threaded, rng, nb_samples + ), + ) +end + +function PerturbedMultiplicative( + maximizer; + ε=1.0, + perturbation_dist=Normal(0, 1), + nb_samples=1, + variance_reduction=true, + seed=nothing, + threaded=false, + rng=Random.default_rng(), +) + dist_constructor = MultiplicativePerturbation(perturbation_dist, float(ε)) + return Perturbed( + Reinforce( + maximizer, dist_constructor; variance_reduction, seed, threaded, rng, nb_samples + ), + ) +end diff --git a/src/losses/_fenchel_young_loss.jl b/src/losses/_fenchel_young_loss.jl new file mode 100644 index 0000000..d82c104 --- /dev/null +++ b/src/losses/_fenchel_young_loss.jl @@ -0,0 +1,169 @@ +""" + FenchelYoungLoss <: AbstractLossLayer + +Fenchel-Young loss associated with a given optimization layer. 
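+Denoting by `ŷ(θ) = argmax_y {θᵀy - Ω(y)}` the output of the optimization layer, the loss is defined as: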
+```
+L(θ, y_true) = (Ω(y_true) - θᵀy_true) - (Ω(ŷ) - θᵀŷ)
+```
+
+Reference: <https://arxiv.org/abs/1901.02324>
+
+# Fields
+
+- `optimization_layer::AbstractOptimizationLayer`: optimization layer that can be formulated as `ŷ(θ) = argmax {θᵀy - Ω(y)}` (either regularized or perturbed)
+"""
+struct FenchelYoungLoss{O<:AbstractOptimizationLayer} <: AbstractLossLayer
+    optimization_layer::O
+end
+
+function Base.show(io::IO, fyl::FenchelYoungLoss)
+    (; optimization_layer) = fyl
+    return print(io, "FenchelYoungLoss($optimization_layer)")
+end
+
+## Forward pass
+
+"""
+    (fyl::FenchelYoungLoss)(θ, y_true[; kwargs...])
+
+Compute the Fenchel-Young loss value `L(θ, y_true)`.
+"""
+function (fyl::FenchelYoungLoss)(θ::AbstractArray, y_true::AbstractArray; kwargs...)
+    l, _ = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...)
+    return l
+end
+
+function fenchel_young_loss_and_grad(
+    fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs...
+) where {O<:AbstractRegularized}
+    (; optimization_layer) = fyl
+    ŷ = optimization_layer(θ; kwargs...)
+    Ωy_true = compute_regularization(optimization_layer, y_true)
+    Ωŷ = compute_regularization(optimization_layer, ŷ)
+    l = (Ωy_true - dot(θ, y_true)) - (Ωŷ - dot(θ, ŷ))
+    g = ŷ - y_true
+    return l, g
+end
+
+function fenchel_young_loss_and_grad(
+    fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs...
+) where {O<:AbstractRegularizedGeneralizedMaximizer}
+    (; optimization_layer) = fyl
+    ŷ = optimization_layer(θ; kwargs...)
+    Ωy_true = compute_regularization(optimization_layer, y_true)
+    Ωŷ = compute_regularization(optimization_layer, ŷ)
+    maximizer = get_maximizer(optimization_layer)
+    l =
+        (Ωy_true - objective_value(maximizer, θ, y_true; kwargs...)) -
+        (Ωŷ - objective_value(maximizer, θ, ŷ; kwargs...))
+    g = maximizer.g(ŷ; kwargs...) - maximizer.g(y_true; kwargs...)
+    return l, g
+end
+
+function fenchel_young_loss_and_grad(
+    fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs...
+) where {O<:AbstractPerturbed}
+    (; optimization_layer) = fyl
+    F, almost_ŷ = fenchel_young_F_and_first_part_of_grad(optimization_layer, θ; kwargs...)
+    l = F - dot(θ, y_true)
+    g = almost_ŷ - y_true
+    return l, g
+end
+
+function fenchel_young_loss_and_grad(
+    fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs...
+) where {P<:AbstractPerturbed{<:GeneralizedMaximizer}}
+    (; optimization_layer) = fyl
+    F, almost_g_of_ŷ = fenchel_young_F_and_first_part_of_grad(
+        optimization_layer, θ; kwargs...
+    )
+    l = F - objective_value(optimization_layer.oracle, θ, y_true; kwargs...)
+    g = almost_g_of_ŷ - optimization_layer.oracle.g(y_true; kwargs...)
+    return l, g
+end
+
+## Backward pass
+
+function ChainRulesCore.rrule(
+    fyl::FenchelYoungLoss, θ::AbstractArray, y_true::AbstractArray; kwargs...
+)
+    l, g = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...)
+    fyl_pullback(dl) = NoTangent(), dl * g, NoTangent()
+    return l, fyl_pullback
+end
+
+## Specific overrides for perturbed layers
+
+function compute_F_and_y_samples(
+    perturbed::AbstractPerturbed{O,false},
+    θ::AbstractArray,
+    Z_samples::Vector{<:AbstractArray};
+    kwargs...,
+) where {O}
+    F_and_y_samples = [
+        fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...)
for + Z in Z_samples + ] + return F_and_y_samples +end + +function compute_F_and_y_samples( + perturbed::AbstractPerturbed{O,true}, + θ::AbstractArray, + Z_samples::Vector{<:AbstractArray}; + kwargs..., +) where {O} + return ThreadsX.map( + Z -> fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...), Z_samples + ) +end + +function fenchel_young_F_and_first_part_of_grad( + perturbed::AbstractPerturbed, θ::AbstractArray; kwargs... +) + Z_samples = sample_perturbations(perturbed, θ) + F_and_y_samples = compute_F_and_y_samples(perturbed, θ, Z_samples; kwargs...) + return mean(first, F_and_y_samples), mean(last, F_and_y_samples) +end + +function fenchel_young_F_and_first_part_of_grad( + perturbed::PerturbedAdditive, θ::AbstractArray, Z::AbstractArray; kwargs... +) + (; oracle, ε) = perturbed + η = θ .+ ε .* Z + y = oracle(η; kwargs...) + F = dot(η, y) + return F, y +end + +function fenchel_young_F_and_first_part_of_grad( + perturbed::PerturbedAdditive{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... +) where {P,G,O<:GeneralizedMaximizer} + (; oracle, ε) = perturbed + η = θ .+ ε .* Z + y = oracle(η; kwargs...) + F = objective_value(oracle, η, y; kwargs...) + return F, oracle.g(y; kwargs...) +end + +function fenchel_young_F_and_first_part_of_grad( + perturbed::PerturbedMultiplicative, θ::AbstractArray, Z::AbstractArray; kwargs... +) + (; oracle, ε) = perturbed + eZ = exp.(ε .* Z .- ε^2 ./ 2) + η = θ .* eZ + y = oracle(η; kwargs...) + F = dot(η, y) + y_scaled = y .* eZ + return F, y_scaled +end + +function fenchel_young_F_and_first_part_of_grad( + perturbed::PerturbedMultiplicative{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... +) where {P,G,O<:GeneralizedMaximizer} + (; oracle, ε) = perturbed + eZ = exp.(ε .* Z .- ε^2) + η = θ .* eZ + y = oracle(η; kwargs...) + F = objective_value(oracle, η, y; kwargs...) + y_scaled = y .* eZ + return F, oracle.g(y_scaled; kwargs...) +end diff --git a/src/losses/fenchel_young_loss.jl b/src/losses/fenchel_young_loss.jl index d82c104..c381b71 100644 --- a/src/losses/fenchel_young_loss.jl +++ b/src/losses/fenchel_young_loss.jl @@ -60,7 +60,7 @@ end function fenchel_young_loss_and_grad( fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {O<:AbstractPerturbed} +) where {O<:Perturbed} (; optimization_layer) = fyl F, almost_ŷ = fenchel_young_F_and_first_part_of_grad(optimization_layer, θ; kwargs...) l = F - dot(θ, y_true) @@ -68,17 +68,17 @@ function fenchel_young_loss_and_grad( return l, g end -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {P<:AbstractPerturbed{<:GeneralizedMaximizer}} - (; optimization_layer) = fyl - F, almost_g_of_ŷ = fenchel_young_F_and_first_part_of_grad( - optimization_layer, θ; kwargs... - ) - l = F - objective_value(optimization_layer.oracle, θ, y_true; kwargs...) - g = almost_g_of_ŷ - optimization_layer.oracle.g(y_true; kwargs...) - return l, g -end +# function fenchel_young_loss_and_grad( +# fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs... +# ) where {P<:AbstractPerturbed{<:GeneralizedMaximizer}} +# (; optimization_layer) = fyl +# F, almost_g_of_ŷ = fenchel_young_F_and_first_part_of_grad( +# optimization_layer, θ; kwargs... +# ) +# l = F - objective_value(optimization_layer.oracle, θ, y_true; kwargs...) +# g = almost_g_of_ŷ - optimization_layer.oracle.g(y_true; kwargs...) 
+# return l, g +# end ## Backward pass @@ -92,78 +92,80 @@ end ## Specific overrides for perturbed layers -function compute_F_and_y_samples( - perturbed::AbstractPerturbed{O,false}, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) where {O} - F_and_y_samples = [ - fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...) for - Z in Z_samples - ] - return F_and_y_samples -end - -function compute_F_and_y_samples( - perturbed::AbstractPerturbed{O,true}, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) where {O} - return ThreadsX.map( - Z -> fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...), Z_samples - ) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::AbstractPerturbed, θ::AbstractArray; kwargs... -) - Z_samples = sample_perturbations(perturbed, θ) - F_and_y_samples = compute_F_and_y_samples(perturbed, θ, Z_samples; kwargs...) - return mean(first, F_and_y_samples), mean(last, F_and_y_samples) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedAdditive, θ::AbstractArray, Z::AbstractArray; kwargs... -) - (; oracle, ε) = perturbed - η = θ .+ ε .* Z - y = oracle(η; kwargs...) - F = dot(η, y) - return F, y -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedAdditive{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... -) where {P,G,O<:GeneralizedMaximizer} - (; oracle, ε) = perturbed - η = θ .+ ε .* Z - y = oracle(η; kwargs...) - F = objective_value(oracle, η, y; kwargs...) - return F, oracle.g(y; kwargs...) -end +# function compute_F_and_y_samples( +# perturbed::AbstractPerturbed{O,false}, +# θ::AbstractArray, +# Z_samples::Vector{<:AbstractArray}; +# kwargs..., +# ) where {O} +# F_and_y_samples = [ +# fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...) for +# Z in Z_samples +# ] +# return F_and_y_samples +# end + +# function compute_F_and_y_samples( +# perturbed::AbstractPerturbed{O,true}, +# θ::AbstractArray, +# Z_samples::Vector{<:AbstractArray}; +# kwargs..., +# ) where {O} +# return ThreadsX.map( +# Z -> fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...), Z_samples +# ) +# end function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedMultiplicative, θ::AbstractArray, Z::AbstractArray; kwargs... + perturbed::Perturbed, θ::AbstractArray; kwargs... ) - (; oracle, ε) = perturbed - eZ = exp.(ε .* Z .- ε^2 ./ 2) - η = θ .* eZ - y = oracle(η; kwargs...) - F = dot(η, y) - y_scaled = y .* eZ - return F, y_scaled + (; reinforce) = perturbed + η_dist = empirical_predistribution(reinforce, θ) + fk = FixKwargs(reinforce.f, kwargs) + y_dist = map(fk, η_dist) + return mean(dot(η, y) for (η, y) in zip(η_dist.atoms, y_dist.atoms)), mean(y_dist) end -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedMultiplicative{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... -) where {P,G,O<:GeneralizedMaximizer} - (; oracle, ε) = perturbed - eZ = exp.(ε .* Z .- ε^2) - η = θ .* eZ - y = oracle(η; kwargs...) - F = objective_value(oracle, η, y; kwargs...) - y_scaled = y .* eZ - return F, oracle.g(y_scaled; kwargs...) -end +# function fenchel_young_F_and_first_part_of_grad( +# perturbed::PerturbedAdditive, θ::AbstractArray, Z::AbstractArray; kwargs... +# ) +# (; oracle, ε) = perturbed +# η = θ .+ ε .* Z +# y = oracle(η; kwargs...) 
+# F = dot(η, y) +# return F, y +# end + +# function fenchel_young_F_and_first_part_of_grad( +# perturbed::PerturbedAdditive{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... +# ) where {P,G,O<:GeneralizedMaximizer} +# (; oracle, ε) = perturbed +# η = θ .+ ε .* Z +# y = oracle(η; kwargs...) +# F = objective_value(oracle, η, y; kwargs...) +# return F, oracle.g(y; kwargs...) +# end + +# function fenchel_young_F_and_first_part_of_grad( +# perturbed::PerturbedMultiplicative, θ::AbstractArray, Z::AbstractArray; kwargs... +# ) +# (; oracle, ε) = perturbed +# eZ = exp.(ε .* Z .- ε^2 ./ 2) +# η = θ .* eZ +# y = oracle(η; kwargs...) +# F = dot(η, y) +# y_scaled = y .* eZ +# return F, y_scaled +# end + +# function fenchel_young_F_and_first_part_of_grad( +# perturbed::PerturbedMultiplicative{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... +# ) where {P,G,O<:GeneralizedMaximizer} +# (; oracle, ε) = perturbed +# eZ = exp.(ε .* Z .- ε^2) +# η = θ .* eZ +# y = oracle(η; kwargs...) +# F = objective_value(oracle, η, y; kwargs...) +# y_scaled = y .* eZ +# return F, oracle.g(y_scaled; kwargs...) +# end From e7ff6bc8e3fca5946b2e787a9e477cacbb71ed2f Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Thu, 1 Aug 2024 21:49:41 +0200 Subject: [PATCH 05/26] Most things seem to work again --- .gitignore | 1 + Project.toml | 1 + ext/InferOptFrankWolfeExt.jl | 7 +- src/InferOpt.jl | 30 ++++-- src/layers/perturbed/multiplicative.jl | 18 +--- src/layers/perturbed/perturbation.jl | 32 ++++++ src/layers/perturbed/perturbed.jl | 30 +++--- src/layers/perturbed/utils.jl | 12 +++ .../regularized/regularized_frank_wolfe.jl | 4 +- src/losses/fenchel_young_loss.jl | 102 ++++++++++-------- src/utils/probability_distribution.jl | 85 --------------- src/utils/pushforward.jl | 23 +--- test/learning_generalized_maximizer.jl | 4 +- test/learning_ranking.jl | 66 ++++++------ 14 files changed, 190 insertions(+), 225 deletions(-) create mode 100644 src/layers/perturbed/perturbation.jl create mode 100644 src/layers/perturbed/utils.jl delete mode 100644 src/utils/probability_distribution.jl diff --git a/.gitignore b/.gitignore index 11feb24..e745a93 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.DS_Store *.jl.*.cov *.jl.cov *.jl.mem diff --git a/Project.toml b/Project.toml index 55fc110..c8fd56b 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,7 @@ ChainRulesCore = "1" DensityInterface = "0.4.0" DifferentiableExpectations = "0.1" DifferentiableFrankWolfe = "0.2" +Distributions = "0.25" LinearAlgebra = "<0.0.1,1" Random = "<0.0.1,1" RequiredInterfaces = "0.1.3" diff --git a/ext/InferOptFrankWolfeExt.jl b/ext/InferOptFrankWolfeExt.jl index d96a449..cc2c3dc 100644 --- a/ext/InferOptFrankWolfeExt.jl +++ b/ext/InferOptFrankWolfeExt.jl @@ -1,10 +1,11 @@ module InferOptFrankWolfeExt +using DifferentiableExpectations: + DifferentiableExpectations, FixedAtomsProbabilityDistribution using DifferentiableFrankWolfe: DifferentiableFrankWolfe, DiffFW using DifferentiableFrankWolfe: LinearMinimizationOracle # from FrankWolfe using DifferentiableFrankWolfe: IterativeLinearSolver # from ImplicitDifferentiation -using InferOpt: InferOpt, RegularizedFrankWolfe, FixedAtomsProbabilityDistribution -using InferOpt: compute_expectation, compute_probability_distribution +using InferOpt: InferOpt, RegularizedFrankWolfe using LinearAlgebra: dot """ @@ -37,7 +38,7 @@ Construct a `DifferentiableFrankWolfe.DiffFW` struct and call `compute_probabili Keyword arguments are passed to the underlying linear maximizer. 
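+
+The result is a probability distribution with finite support; the `RegularizedFrankWolfe` layer returns its `mean`.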
""" -function InferOpt.compute_probability_distribution( +function DifferentiableExpectations.empirical_distribution( regularized::RegularizedFrankWolfe, θ::AbstractArray; kwargs... ) (; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized diff --git a/src/InferOpt.jl b/src/InferOpt.jl index 44607b6..767b012 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -10,9 +10,19 @@ module InferOpt using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, Tangent, ZeroTangent using ChainRulesCore: rrule, rrule_via_ad, unthunk using DensityInterface: logdensityof -using DifferentiableExpectations: Reinforce, empirical_predistribution, FixKwargs +using DifferentiableExpectations: + DifferentiableExpectations, + Reinforce, + empirical_predistribution, + empirical_distribution, + FixKwargs using Distributions: - Distributions, ContinuousUnivariateDistribution, LogNormal, Normal, product_distribution + Distributions, + ContinuousUnivariateDistribution, + LogNormal, + Normal, + product_distribution, + logpdf using LinearAlgebra: dot using Random: Random, AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed! using Statistics: mean @@ -24,7 +34,6 @@ using RequiredInterfaces include("interface.jl") include("utils/some_functions.jl") -include("utils/probability_distribution.jl") include("utils/pushforward.jl") include("utils/generalized_maximizer.jl") include("utils/isotonic_regression/isotonic_l2.jl") @@ -35,9 +44,10 @@ include("utils/isotonic_regression/projection.jl") include("layers/simple/interpolation.jl") include("layers/simple/identity.jl") -# include("layers/perturbed/abstract_perturbed.jl") -include("layers/perturbed/additive.jl") -include("layers/perturbed/multiplicative.jl") +include("layers/perturbed/utils.jl") +include("layers/perturbed/perturbation.jl") +# include("layers/perturbed/additive.jl") +# include("layers/perturbed/multiplicative.jl") include("layers/perturbed/perturbed.jl") include("layers/regularized/abstract_regularized.jl") @@ -62,9 +72,9 @@ export shannon_entropy, negative_shannon_entropy export one_hot_argmax, ranking export GeneralizedMaximizer, objective_value -export FixedAtomsProbabilityDistribution -export compute_expectation -export compute_probability_distribution +# export FixedAtomsProbabilityDistribution +# export compute_expectation +# export compute_probability_distribution export Pushforward export IdentityRelaxation @@ -79,7 +89,7 @@ export RegularizedFrankWolfe export PerturbedAdditive export PerturbedMultiplicative -export PerturbedOracle +# export PerturbedOracle export FenchelYoungLoss export StructuredSVMLoss diff --git a/src/layers/perturbed/multiplicative.jl b/src/layers/perturbed/multiplicative.jl index 8fe9b69..4b139fb 100644 --- a/src/layers/perturbed/multiplicative.jl +++ b/src/layers/perturbed/multiplicative.jl @@ -1,16 +1,3 @@ -struct ExponentialOf{D<:ContinuousUnivariateDistribution} <: - ContinuousUnivariateDistribution - dist::D -end - -function Random.rand(rng::AbstractRNG, d::ExponentialOf) - return exp(rand(rng, d.dist)) -end - -function Distributions.logpdf(d::ExponentialOf, x::Real) - return logpdf(d.dist, log(x)) - log(x) -end - """ θ ⊙ exp(εZ - ε²/2) """ @@ -23,3 +10,8 @@ function (pdc::MultiplicativePerturbation)(θ::AbstractArray) (; perturbation_dist, ε) = pdc return product_distribution(θ .* ExponentialOf(ε * perturbation_dist - ε^2 / 2)) end + +function Random.rand(rng::AbstractRNG, perturbation::MultiplicativePerturbation) + (; perturbation_dist, ε) = perturbation + return rand(rng, perturbation_dist) +end diff --git 
a/src/layers/perturbed/perturbation.jl b/src/layers/perturbed/perturbation.jl new file mode 100644 index 0000000..d3779a9 --- /dev/null +++ b/src/layers/perturbed/perturbation.jl @@ -0,0 +1,32 @@ +abstract type AbstractPerturbation end + +struct AdditivePerturbation{F} + perturbation_dist::F + ε::Float64 +end + +""" +θ + εZ +""" +function (pdc::AdditivePerturbation)(θ::AbstractArray) + (; perturbation_dist, ε) = pdc + return product_distribution(θ .+ ε * perturbation_dist) +end + +""" +θ ⊙ exp(εZ - ε²/2) +""" +struct MultiplicativePerturbation{F} + perturbation_dist::F + ε::Float64 +end + +function (pdc::MultiplicativePerturbation)(θ::AbstractArray) + (; perturbation_dist, ε) = pdc + return product_distribution(θ .* ExponentialOf(ε * perturbation_dist - ε^2 / 2)) +end + +function Random.rand(rng::AbstractRNG, perturbation::MultiplicativePerturbation) + (; perturbation_dist, ε) = perturbation + return rand(rng, perturbation_dist) +end diff --git a/src/layers/perturbed/perturbed.jl b/src/layers/perturbed/perturbed.jl index e063e4a..ac877a8 100644 --- a/src/layers/perturbed/perturbed.jl +++ b/src/layers/perturbed/perturbed.jl @@ -1,17 +1,15 @@ -struct Perturbed{R<:Reinforce} <: AbstractOptimizationLayer - reinforce::R +struct Perturbed{F,D,t,variance_reduction,G,R,S} <: AbstractOptimizationLayer + reinforce::Reinforce{t,variance_reduction,F,D,G,R,S} end -function (perturbed::Perturbed)(θ::AbstractArray) - return perturbed.reinforce(θ) +function (perturbed::Perturbed)(θ::AbstractArray; kwargs...) + return perturbed.reinforce(θ; kwargs...) end -function is_additive(perturbed::Perturbed) - return isa(perturbed.reinforce.dist_constructor, AdditivePerturbation) -end - -function is_multiplicative(perturbed::Perturbed) - return isa(perturbed.reinforce.dist_constructor, MultiplicativePerturbation) +function DifferentiableExpectations.empirical_distribution( + perturbed::Perturbed, θ::AbstractArray; kwargs... +) + return empirical_distribution(perturbed.reinforce, θ; kwargs...) end function Base.show(io::IO, perturbed::Perturbed) @@ -28,6 +26,10 @@ function Base.show(io::IO, perturbed::Perturbed) ) end +function Perturbed(maximizer, dist_constructor; kwargs...) 
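+    # convenience constructor: keyword arguments (e.g. `nb_samples`,
+    # `variance_reduction`, `seed`, `threaded`, `rng`) are forwarded as is
+    # to the underlying `Reinforce` estimator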
+ return Perturbed(Reinforce(maximizer, dist_constructor; kwargs...)) +end + function PerturbedAdditive( maximizer; ε=1.0, @@ -40,9 +42,7 @@ function PerturbedAdditive( ) dist_constructor = AdditivePerturbation(perturbation_dist, float(ε)) return Perturbed( - Reinforce( - maximizer, dist_constructor; variance_reduction, seed, threaded, rng, nb_samples - ), + maximizer, dist_constructor; nb_samples, variance_reduction, seed, threaded, rng ) end @@ -58,8 +58,6 @@ function PerturbedMultiplicative( ) dist_constructor = MultiplicativePerturbation(perturbation_dist, float(ε)) return Perturbed( - Reinforce( - maximizer, dist_constructor; variance_reduction, seed, threaded, rng, nb_samples - ), + maximizer, dist_constructor; nb_samples, variance_reduction, seed, threaded, rng ) end diff --git a/src/layers/perturbed/utils.jl b/src/layers/perturbed/utils.jl new file mode 100644 index 0000000..3f6dc38 --- /dev/null +++ b/src/layers/perturbed/utils.jl @@ -0,0 +1,12 @@ +struct ExponentialOf{D<:ContinuousUnivariateDistribution} <: + ContinuousUnivariateDistribution + dist::D +end + +function Random.rand(rng::AbstractRNG, d::ExponentialOf) + return exp(rand(rng, d.dist)) +end + +function Distributions.logpdf(d::ExponentialOf, x::Real) + return logpdf(d.dist, log(x)) - log(x) +end diff --git a/src/layers/regularized/regularized_frank_wolfe.jl b/src/layers/regularized/regularized_frank_wolfe.jl index 8f68887..e7d05bf 100644 --- a/src/layers/regularized/regularized_frank_wolfe.jl +++ b/src/layers/regularized/regularized_frank_wolfe.jl @@ -60,6 +60,6 @@ end Apply `compute_probability_distribution(regularized, θ; kwargs...)` and return the expectation. """ function (regularized::RegularizedFrankWolfe)(θ::AbstractArray; kwargs...) - probadist = compute_probability_distribution(regularized, θ; kwargs...) - return compute_expectation(probadist) + probadist = empirical_distribution(regularized, θ; kwargs...) + return mean(probadist) end diff --git a/src/losses/fenchel_young_loss.jl b/src/losses/fenchel_young_loss.jl index c381b71..fea01a8 100644 --- a/src/losses/fenchel_young_loss.jl +++ b/src/losses/fenchel_young_loss.jl @@ -68,17 +68,20 @@ function fenchel_young_loss_and_grad( return l, g end -# function fenchel_young_loss_and_grad( -# fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs... -# ) where {P<:AbstractPerturbed{<:GeneralizedMaximizer}} -# (; optimization_layer) = fyl -# F, almost_g_of_ŷ = fenchel_young_F_and_first_part_of_grad( -# optimization_layer, θ; kwargs... -# ) -# l = F - objective_value(optimization_layer.oracle, θ, y_true; kwargs...) -# g = almost_g_of_ŷ - optimization_layer.oracle.g(y_true; kwargs...) -# return l, g -# end +function fenchel_young_loss_and_grad( + fyl::FenchelYoungLoss{<:Perturbed{<:GeneralizedMaximizer}}, + θ::AbstractArray, + y_true::AbstractArray; + kwargs..., +) + (; optimization_layer) = fyl + F, almost_g_of_ŷ = fenchel_young_F_and_first_part_of_grad( + optimization_layer, θ; kwargs... + ) + l = F - objective_value(optimization_layer.reinforce.f, θ, y_true; kwargs...) + g = almost_g_of_ŷ - optimization_layer.reinforce.f.g(y_true; kwargs...) + return l, g +end ## Backward pass @@ -117,8 +120,8 @@ end # end function fenchel_young_F_and_first_part_of_grad( - perturbed::Perturbed, θ::AbstractArray; kwargs... -) + perturbed::Perturbed{F,<:AdditivePerturbation}, θ::AbstractArray; kwargs... 
+) where {F} (; reinforce) = perturbed η_dist = empirical_predistribution(reinforce, θ) fk = FixKwargs(reinforce.f, kwargs) @@ -126,37 +129,52 @@ function fenchel_young_F_and_first_part_of_grad( return mean(dot(η, y) for (η, y) in zip(η_dist.atoms, y_dist.atoms)), mean(y_dist) end -# function fenchel_young_F_and_first_part_of_grad( -# perturbed::PerturbedAdditive, θ::AbstractArray, Z::AbstractArray; kwargs... -# ) -# (; oracle, ε) = perturbed -# η = θ .+ ε .* Z -# y = oracle(η; kwargs...) -# F = dot(η, y) -# return F, y -# end +function fenchel_young_F_and_first_part_of_grad( + perturbed::Perturbed{F,<:MultiplicativePerturbation}, θ::AbstractArray; kwargs... +) where {F} + (; reinforce) = perturbed + η_dist = empirical_predistribution(reinforce, θ) + fk = FixKwargs(reinforce.f, kwargs) + y_dist = map(fk, η_dist) + eZ_dist = map(Base.Fix2(./, θ), η_dist) + return mean(dot(η, y) for (η, y) in zip(η_dist.atoms, y_dist.atoms)), + mean(map(.*, eZ_dist.atoms, y_dist.atoms)) +end -# function fenchel_young_F_and_first_part_of_grad( -# perturbed::PerturbedAdditive{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... -# ) where {P,G,O<:GeneralizedMaximizer} -# (; oracle, ε) = perturbed -# η = θ .+ ε .* Z -# y = oracle(η; kwargs...) -# F = objective_value(oracle, η, y; kwargs...) -# return F, oracle.g(y; kwargs...) -# end +function fenchel_young_F_and_first_part_of_grad( + perturbed::Perturbed{<:GeneralizedMaximizer,<:AdditivePerturbation}, + θ::AbstractArray; + kwargs..., +) + (; reinforce) = perturbed + η_dist = empirical_predistribution(reinforce, θ) + fk = FixKwargs(reinforce.f, kwargs) + gk = FixKwargs(reinforce.f.g, kwargs) + y_dist = map(fk, η_dist) + return mean( + objective_value(reinforce.f, η, y; kwargs...) for + (η, y) in zip(η_dist.atoms, y_dist.atoms) + ), + mean(gk, y_dist) +end -# function fenchel_young_F_and_first_part_of_grad( -# perturbed::PerturbedMultiplicative, θ::AbstractArray, Z::AbstractArray; kwargs... -# ) -# (; oracle, ε) = perturbed -# eZ = exp.(ε .* Z .- ε^2 ./ 2) -# η = θ .* eZ -# y = oracle(η; kwargs...) -# F = dot(η, y) -# y_scaled = y .* eZ -# return F, y_scaled -# end +function fenchel_young_F_and_first_part_of_grad( + perturbed::Perturbed{<:GeneralizedMaximizer,<:MultiplicativePerturbation}, + θ::AbstractArray; + kwargs..., +) + (; reinforce) = perturbed + η_dist = empirical_predistribution(reinforce, θ) + eZ_dist = map(Base.Fix2(./, θ), η_dist) + fk = FixKwargs(reinforce.f, kwargs) + gk = FixKwargs(reinforce.f.g, kwargs) + y_dist = map(fk, η_dist) + return mean( + objective_value(reinforce.f, η, y; kwargs...) for + (η, y) in zip(η_dist.atoms, y_dist.atoms) + ), + mean(gk.(map(.*, eZ_dist.atoms, y_dist.atoms))) +end # function fenchel_young_F_and_first_part_of_grad( # perturbed::PerturbedMultiplicative{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... diff --git a/src/utils/probability_distribution.jl b/src/utils/probability_distribution.jl deleted file mode 100644 index 0fdc21a..0000000 --- a/src/utils/probability_distribution.jl +++ /dev/null @@ -1,85 +0,0 @@ -""" - FixedAtomsProbabilityDistribution{A,W} - -Encodes a probability distribution with finite support and fixed atoms. - -See [`compute_expectation`](@ref) to understand the name of this struct. 
- -# Fields -- `atoms::Vector{A}`: elements of the support -- `weights::Vector{W}`: probability values for each atom (must sum to 1) -""" -struct FixedAtomsProbabilityDistribution{A,W} - atoms::Vector{A} - weights::Vector{W} - - function FixedAtomsProbabilityDistribution( - atoms::Vector{A}, weights::Vector{W} - ) where {A,W} - @assert length(atoms) == length(weights) > 0 - @assert isapprox(sum(weights), one(W); atol=1e-4) - return new{A,W}(atoms, weights) - end -end - -Base.length(probadist::FixedAtomsProbabilityDistribution) = length(probadist.atoms) - -""" - rand([rng,] probadist) - -Sample from the atoms of `probadist` according to their weights. -""" -function Base.rand(rng::AbstractRNG, probadist::FixedAtomsProbabilityDistribution) - (; atoms, weights) = probadist - return sample(rng, atoms, StatsBase.Weights(weights)) -end - -Base.rand(probadist::FixedAtomsProbabilityDistribution) = rand(GLOBAL_RNG, probadist) - -""" - apply_on_atoms(post_processing, probadist) - -Create a new distribution by applying the function `post_processing` to each atom of `probadist` (the weights remain the same). -""" -function apply_on_atoms( - post_processing, probadist::FixedAtomsProbabilityDistribution; kwargs... -) - (; atoms, weights) = probadist - post_processed_atoms = [post_processing(a; kwargs...) for a in atoms] - return FixedAtomsProbabilityDistribution(post_processed_atoms, weights) -end - -""" - compute_expectation(probadist[, post_processing=identity]) - -Compute the expectation of `post_processing(X)` where `X` is a random variable distributed according to `probadist`. - -This operation is made differentiable thanks to a custom reverse rule, even when `post_processing` itself is not a differentiable function. - -!!! warning "Warning" - Derivatives are computed with respect to `probadist.weights` only, assuming that `probadist.atoms` doesn't change (hence the name [`FixedAtomsProbabilityDistribution`](@ref)). -""" -function compute_expectation( - probadist::FixedAtomsProbabilityDistribution, post_processing=identity; kwargs... -) - (; atoms, weights) = probadist - return sum(w * post_processing(a; kwargs...) for (w, a) in zip(weights, atoms)) -end - -function ChainRulesCore.rrule( - ::typeof(compute_expectation), - probadist::FixedAtomsProbabilityDistribution, - post_processing=identity; - kwargs..., -) - e = compute_expectation(probadist, post_processing; kwargs...) - function expectation_pullback(de) - d_atoms = NoTangent() - d_weights = [dot(de, post_processing(a; kwargs...)) for a in probadist.atoms] - d_probadist = Tangent{FixedAtomsProbabilityDistribution}(; - atoms=d_atoms, weights=d_weights - ) - return NoTangent(), d_probadist, NoTangent() - end - return e, expectation_pullback -end diff --git a/src/utils/pushforward.jl b/src/utils/pushforward.jl index ddd915b..d27ac0b 100644 --- a/src/utils/pushforward.jl +++ b/src/utils/pushforward.jl @@ -21,33 +21,18 @@ function Base.show(io::IO, pushforward::Pushforward) return print(io, "Pushforward($optimization_layer, $post_processing)") end -""" - compute_probability_distribution(pushforward, θ) - -Output the distribution of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.optimization_layer` applied to `θ`. - -This function is not differentiable if `pushforward.post_processing` isn't. - -See also: [`apply_on_atoms`](@ref). -""" -function compute_probability_distribution(pushforward::Pushforward, θ; kwargs...) 
- (; optimization_layer, post_processing) = pushforward - probadist = compute_probability_distribution(optimization_layer, θ; kwargs...) - post_processed_probadist = apply_on_atoms(post_processing, probadist; kwargs...) - return post_processed_probadist -end - """ (pushforward::Pushforward)(θ; kwargs...) Output the expectation of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.optimization_layer` applied to `θ`. -Unlike [`compute_probability_distribution(pushforward, θ)`](@ref), this function is differentiable, even if `pushforward.post_processing` isn't. +Unlike [`empirical_distribution(pushforward, θ)`](@ref), this function is differentiable, even if `pushforward.post_processing` isn't. See also: [`compute_expectation`](@ref). """ function (pushforward::Pushforward)(θ::AbstractArray; kwargs...) (; optimization_layer, post_processing) = pushforward - probadist = compute_probability_distribution(optimization_layer, θ; kwargs...) - return compute_expectation(probadist, post_processing; kwargs...) + probadist = empirical_distribution(optimization_layer, θ; kwargs...) + post_processing_kw = FixKwargs(post_processing, kwargs) + return mean(post_processing_kw, probadist) end diff --git a/test/learning_generalized_maximizer.jl b/test/learning_generalized_maximizer.jl index 93f28a9..1c84044 100644 --- a/test/learning_generalized_maximizer.jl +++ b/test/learning_generalized_maximizer.jl @@ -89,7 +89,7 @@ end true_maximizer=max_pricing, maximizer=identity_kw, loss=FenchelYoungLoss( - PerturbedAdditive(generalized_maximizer; ε=1.0, nb_samples=5) + PerturbedAdditive(generalized_maximizer; ε=1.0, nb_samples=10) ), error_function=hamming_distance, cost, @@ -116,7 +116,7 @@ end true_maximizer=max_pricing, maximizer=identity_kw, loss=FenchelYoungLoss( - PerturbedMultiplicative(generalized_maximizer; ε=0.1, nb_samples=5) + PerturbedMultiplicative(generalized_maximizer; ε=0.1, nb_samples=10) ), error_function=hamming_distance, cost, diff --git a/test/learning_ranking.jl b/test/learning_ranking.jl index a6773f9..3f017d7 100644 --- a/test/learning_ranking.jl +++ b/test/learning_ranking.jl @@ -149,7 +149,9 @@ end true_maximizer=ranking, maximizer=identity_kw, loss=FenchelYoungLoss( - PerturbedAdditive(ranking; ε=1.0, nb_samples=5, perturbation=LogNormal(0, 1)) + PerturbedAdditive( + ranking; ε=1.0, nb_samples=5, perturbation_dist=LogNormal(0, 1) + ), ), error_function=hamming_distance, ) @@ -179,17 +181,18 @@ end @testitem "exp - Pushforward PerturbedAdditive" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random + using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Statistics Random.seed!(63) true_encoder = encoder_factory() cost(y; instance) = dot(y, -true_encoder(instance)) + f(θ; kwargs...) = cost(ranking(θ; kwargs...); kwargs...) test_pipeline!( PipelineLossExperience(); instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=Pushforward(PerturbedAdditive(ranking; ε=1.0, nb_samples=10), cost), + loss=PerturbedAdditive(f; ε=1.0, nb_samples=10), error_function=hamming_distance, true_encoder=true_encoder, cost=cost, @@ -203,12 +206,13 @@ end true_encoder = encoder_factory() cost(y; instance) = dot(y, -true_encoder(instance)) + f(θ; kwargs...) = cost(ranking(θ; kwargs...); kwargs...) 
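+    # perturbing the composition f = cost ∘ ranking directly plays the role
+    # of the former Pushforward(perturbed_layer, cost) construction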
test_pipeline!( PipelineLossExperience(); instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=Pushforward(PerturbedMultiplicative(ranking; ε=1.0, nb_samples=10), cost), + loss=PerturbedAdditive(f; ε=1.0, nb_samples=10), error_function=hamming_distance, true_encoder=true_encoder, cost=cost, @@ -222,15 +226,13 @@ end true_encoder = encoder_factory() cost(y; instance) = dot(y, -true_encoder(instance)) + f(θ; kwargs...) = cost(ranking(θ; kwargs...); kwargs...) test_pipeline!( PipelineLossExperience(); instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=Pushforward( - PerturbedAdditive(ranking; ε=1.0, nb_samples=10, perturbation=LogNormal(0, 1)), - cost, - ), + loss=PerturbedAdditive(f; ε=1.0, nb_samples=10, perturbation_dist=LogNormal(0, 1)), error_function=hamming_distance, true_encoder=true_encoder, cost=cost, @@ -245,16 +247,14 @@ end true_encoder = encoder_factory() cost(y; instance) = dot(y, -true_encoder(instance)) + f(θ; kwargs...) = cost(ranking(θ; kwargs...); kwargs...) test_pipeline!( PipelineLossExperience(); instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=Pushforward( - PerturbedMultiplicative( - ranking; ε=1.0, nb_samples=10, perturbation=LogNormal(0, 1) - ), - cost, + loss=PerturbedMultiplicative( + f; ε=1.0, nb_samples=10, perturbation_dist=LogNormal(0, 1) ), error_function=hamming_distance, true_encoder=true_encoder, @@ -263,26 +263,26 @@ end ) end -@testitem "exp - Pushforward PerturbedOracle{LogNormal}" default_imports = false begin - include("InferOptTestUtils/src/InferOptTestUtils.jl") - using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Distributions - Random.seed!(63) - - p(θ) = MvLogNormal(θ, I) - - true_encoder = encoder_factory() - cost(y; instance) = dot(y, -true_encoder(instance)) - test_pipeline!( - PipelineLossExperience(); - instance_dim=5, - true_maximizer=ranking, - maximizer=identity_kw, - loss=Pushforward(PerturbedOracle(ranking, p; nb_samples=10), cost), - error_function=hamming_distance, - true_encoder=true_encoder, - cost=cost, - ) -end +# @testitem "exp - Pushforward PerturbedOracle{LogNormal}" default_imports = false begin +# include("InferOptTestUtils/src/InferOptTestUtils.jl") +# using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Distributions +# Random.seed!(63) + +# p(θ) = MvLogNormal(θ, I) + +# true_encoder = encoder_factory() +# cost(y; instance) = dot(y, -true_encoder(instance)) +# test_pipeline!( +# PipelineLossExperience(); +# instance_dim=5, +# true_maximizer=ranking, +# maximizer=identity_kw, +# loss=Pushforward(PerturbedOracle(ranking, p; nb_samples=10), cost), +# error_function=hamming_distance, +# true_encoder=true_encoder, +# cost=cost, +# ) +# end @testitem "exp - Pushforward RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") From d7608bb18b1f939eb7cb1268d8ff20494fa78c23 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Fri, 2 Aug 2024 11:30:17 +0200 Subject: [PATCH 06/26] linear maximizer --- ext/InferOptFrankWolfeExt.jl | 2 +- src/InferOpt.jl | 4 +- src/layers/perturbed/additive.jl | 12 -- src/layers/perturbed/multiplicative.jl | 17 --- src/layers/perturbed/perturbation.jl | 11 +- src/layers/perturbed/perturbed.jl | 57 +++++++-- .../regularized/regularized_frank_wolfe.jl | 2 +- src/losses/fenchel_young_loss.jl | 116 ++++++++++-------- src/utils/generalized_maximizer.jl | 13 ++ src/utils/linear_maximizer.jl | 46 +++++++ src/utils/pushforward.jl | 2 +- test/learning_generalized_maximizer.jl | 34 ++--- 
test/learning_ranking.jl | 44 +++---- 13 files changed, 223 insertions(+), 137 deletions(-) delete mode 100644 src/layers/perturbed/additive.jl delete mode 100644 src/layers/perturbed/multiplicative.jl create mode 100644 src/utils/linear_maximizer.jl diff --git a/ext/InferOptFrankWolfeExt.jl b/ext/InferOptFrankWolfeExt.jl index cc2c3dc..3e2a43e 100644 --- a/ext/InferOptFrankWolfeExt.jl +++ b/ext/InferOptFrankWolfeExt.jl @@ -38,7 +38,7 @@ Construct a `DifferentiableFrankWolfe.DiffFW` struct and call `compute_probabili Keyword arguments are passed to the underlying linear maximizer. """ -function DifferentiableExpectations.empirical_distribution( +function InferOpt.compute_probability_distribution( regularized::RegularizedFrankWolfe, θ::AbstractArray; kwargs... ) (; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized diff --git a/src/InferOpt.jl b/src/InferOpt.jl index 767b012..cc70fe3 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -36,6 +36,7 @@ include("interface.jl") include("utils/some_functions.jl") include("utils/pushforward.jl") include("utils/generalized_maximizer.jl") +include("utils/linear_maximizer.jl") include("utils/isotonic_regression/isotonic_l2.jl") include("utils/isotonic_regression/isotonic_kl.jl") include("utils/isotonic_regression/projection.jl") @@ -70,7 +71,8 @@ include("losses/imitation_loss.jl") export half_square_norm export shannon_entropy, negative_shannon_entropy export one_hot_argmax, ranking -export GeneralizedMaximizer, objective_value +export GeneralizedMaximizer +export LinearMaximizer, apply_g, apply_h, objective_value # export FixedAtomsProbabilityDistribution # export compute_expectation diff --git a/src/layers/perturbed/additive.jl b/src/layers/perturbed/additive.jl deleted file mode 100644 index c79d0b7..0000000 --- a/src/layers/perturbed/additive.jl +++ /dev/null @@ -1,12 +0,0 @@ -struct AdditivePerturbation{F} - perturbation_dist::F - ε::Float64 -end - -""" -θ + εZ -""" -function (pdc::AdditivePerturbation)(θ::AbstractArray) - (; perturbation_dist, ε) = pdc - return product_distribution(θ .+ ε * perturbation_dist) -end diff --git a/src/layers/perturbed/multiplicative.jl b/src/layers/perturbed/multiplicative.jl deleted file mode 100644 index 4b139fb..0000000 --- a/src/layers/perturbed/multiplicative.jl +++ /dev/null @@ -1,17 +0,0 @@ -""" -θ ⊙ exp(εZ - ε²/2) -""" -struct MultiplicativePerturbation{F} - perturbation_dist::F - ε::Float64 -end - -function (pdc::MultiplicativePerturbation)(θ::AbstractArray) - (; perturbation_dist, ε) = pdc - return product_distribution(θ .* ExponentialOf(ε * perturbation_dist - ε^2 / 2)) -end - -function Random.rand(rng::AbstractRNG, perturbation::MultiplicativePerturbation) - (; perturbation_dist, ε) = perturbation - return rand(rng, perturbation_dist) -end diff --git a/src/layers/perturbed/perturbation.jl b/src/layers/perturbed/perturbation.jl index d3779a9..831cb27 100644 --- a/src/layers/perturbed/perturbation.jl +++ b/src/layers/perturbed/perturbation.jl @@ -1,4 +1,8 @@ -abstract type AbstractPerturbation end +abstract type AbstractPerturbation <: ContinuousUnivariateDistribution end + +function Random.rand(rng::AbstractRNG, perturbation::AbstractPerturbation) + return rand(rng, perturbation.perturbation_dist) +end struct AdditivePerturbation{F} perturbation_dist::F @@ -25,8 +29,3 @@ function (pdc::MultiplicativePerturbation)(θ::AbstractArray) (; perturbation_dist, ε) = pdc return product_distribution(θ .* ExponentialOf(ε * perturbation_dist - ε^2 / 2)) end - -function Random.rand(rng::AbstractRNG, 
perturbation::MultiplicativePerturbation) - (; perturbation_dist, ε) = perturbation - return rand(rng, perturbation_dist) -end diff --git a/src/layers/perturbed/perturbed.jl b/src/layers/perturbed/perturbed.jl index ac877a8..c0901d5 100644 --- a/src/layers/perturbed/perturbed.jl +++ b/src/layers/perturbed/perturbed.jl @@ -6,9 +6,11 @@ function (perturbed::Perturbed)(θ::AbstractArray; kwargs...) return perturbed.reinforce(θ; kwargs...) end -function DifferentiableExpectations.empirical_distribution( - perturbed::Perturbed, θ::AbstractArray; kwargs... -) +function get_maximizer(perturbed::Perturbed) + return perturbed.reinforce.f +end + +function compute_probability_distribution(perturbed::Perturbed, θ::AbstractArray; kwargs...) return empirical_distribution(perturbed.reinforce, θ; kwargs...) end @@ -26,8 +28,29 @@ function Base.show(io::IO, perturbed::Perturbed) ) end -function Perturbed(maximizer, dist_constructor; kwargs...) - return Perturbed(Reinforce(maximizer, dist_constructor; kwargs...)) +function Perturbed( + maximizer, + dist_constructor; + nb_samples=1, + variance_reduction=true, + seed=nothing, + threaded=false, + rng=Random.default_rng(), + g=nothing, + h=nothing, +) + linear_maximizer = LinearMaximizer(; maximizer, g, h) + return Perturbed( + Reinforce( + linear_maximizer, + dist_constructor; + nb_samples, + variance_reduction, + seed, + threaded, + rng, + ), + ) end function PerturbedAdditive( @@ -39,10 +62,20 @@ function PerturbedAdditive( seed=nothing, threaded=false, rng=Random.default_rng(), + g=nothing, + h=nothing, ) dist_constructor = AdditivePerturbation(perturbation_dist, float(ε)) return Perturbed( - maximizer, dist_constructor; nb_samples, variance_reduction, seed, threaded, rng + maximizer, + dist_constructor; + nb_samples, + variance_reduction, + seed, + threaded, + rng, + g, + h, ) end @@ -55,9 +88,19 @@ function PerturbedMultiplicative( seed=nothing, threaded=false, rng=Random.default_rng(), + g=nothing, + h=nothing, ) dist_constructor = MultiplicativePerturbation(perturbation_dist, float(ε)) return Perturbed( - maximizer, dist_constructor; nb_samples, variance_reduction, seed, threaded, rng + maximizer, + dist_constructor; + nb_samples, + variance_reduction, + seed, + threaded, + rng, + g, + h, ) end diff --git a/src/layers/regularized/regularized_frank_wolfe.jl b/src/layers/regularized/regularized_frank_wolfe.jl index e7d05bf..bd21ef2 100644 --- a/src/layers/regularized/regularized_frank_wolfe.jl +++ b/src/layers/regularized/regularized_frank_wolfe.jl @@ -60,6 +60,6 @@ end Apply `compute_probability_distribution(regularized, θ; kwargs...)` and return the expectation. """ function (regularized::RegularizedFrankWolfe)(θ::AbstractArray; kwargs...) - probadist = empirical_distribution(regularized, θ; kwargs...) + probadist = compute_probability_distribution(regularized, θ; kwargs...) return mean(probadist) end diff --git a/src/losses/fenchel_young_loss.jl b/src/losses/fenchel_young_loss.jl index fea01a8..92be7f4 100644 --- a/src/losses/fenchel_young_loss.jl +++ b/src/losses/fenchel_young_loss.jl @@ -59,29 +59,30 @@ function fenchel_young_loss_and_grad( end function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {O<:Perturbed} + fyl::FenchelYoungLoss{<:Perturbed}, θ::AbstractArray, y_true::AbstractArray; kwargs... +) (; optimization_layer) = fyl + maximizer = get_maximizer(optimization_layer) F, almost_ŷ = fenchel_young_F_and_first_part_of_grad(optimization_layer, θ; kwargs...) 
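+    # F is a Monte Carlo estimate of the perturbed maximum E[max_y ηᵀg(y) + h(y)]
+    # (with g = identity and h = 0 for a plain maximizer), and almost_ŷ is the
+    # matching first part of its gradient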
- l = F - dot(θ, y_true) - g = almost_ŷ - y_true + l = F - objective_value(maximizer, θ, y_true; kwargs...) # dot(θ, y_true) + g = almost_ŷ - apply_g(maximizer, y_true; kwargs...) return l, g end -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{<:Perturbed{<:GeneralizedMaximizer}}, - θ::AbstractArray, - y_true::AbstractArray; - kwargs..., -) - (; optimization_layer) = fyl - F, almost_g_of_ŷ = fenchel_young_F_and_first_part_of_grad( - optimization_layer, θ; kwargs... - ) - l = F - objective_value(optimization_layer.reinforce.f, θ, y_true; kwargs...) - g = almost_g_of_ŷ - optimization_layer.reinforce.f.g(y_true; kwargs...) - return l, g -end +# function fenchel_young_loss_and_grad( +# fyl::FenchelYoungLoss{<:Perturbed{<:GeneralizedMaximizer}}, +# θ::AbstractArray, +# y_true::AbstractArray; +# kwargs..., +# ) +# (; optimization_layer) = fyl +# F, almost_g_of_ŷ = fenchel_young_F_and_first_part_of_grad( +# optimization_layer, θ; kwargs... +# ) +# l = F - objective_value(optimization_layer.reinforce.f, θ, y_true; kwargs...) +# g = almost_g_of_ŷ - optimization_layer.reinforce.f.g(y_true; kwargs...) +# return l, g +# end ## Backward pass @@ -123,59 +124,70 @@ function fenchel_young_F_and_first_part_of_grad( perturbed::Perturbed{F,<:AdditivePerturbation}, θ::AbstractArray; kwargs... ) where {F} (; reinforce) = perturbed + maximizer = get_maximizer(perturbed) η_dist = empirical_predistribution(reinforce, θ) - fk = FixKwargs(reinforce.f, kwargs) - y_dist = map(fk, η_dist) - return mean(dot(η, y) for (η, y) in zip(η_dist.atoms, y_dist.atoms)), mean(y_dist) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::Perturbed{F,<:MultiplicativePerturbation}, θ::AbstractArray; kwargs... -) where {F} - (; reinforce) = perturbed - η_dist = empirical_predistribution(reinforce, θ) - fk = FixKwargs(reinforce.f, kwargs) - y_dist = map(fk, η_dist) - eZ_dist = map(Base.Fix2(./, θ), η_dist) - return mean(dot(η, y) for (η, y) in zip(η_dist.atoms, y_dist.atoms)), - mean(map(.*, eZ_dist.atoms, y_dist.atoms)) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::Perturbed{<:GeneralizedMaximizer,<:AdditivePerturbation}, - θ::AbstractArray; - kwargs..., -) - (; reinforce) = perturbed - η_dist = empirical_predistribution(reinforce, θ) - fk = FixKwargs(reinforce.f, kwargs) - gk = FixKwargs(reinforce.f.g, kwargs) + fk = FixKwargs(maximizer, kwargs) + gk = FixKwargs((y; kwargs...) -> apply_g(maximizer, y; kwargs...), kwargs) y_dist = map(fk, η_dist) return mean( - objective_value(reinforce.f, η, y; kwargs...) for + objective_value(maximizer, η, y; kwargs...) for (η, y) in zip(η_dist.atoms, y_dist.atoms) ), mean(gk, y_dist) end +# function fenchel_young_F_and_first_part_of_grad( +# perturbed::Perturbed{<:GeneralizedMaximizer,<:AdditivePerturbation}, +# θ::AbstractArray; +# kwargs..., +# ) +# (; reinforce) = perturbed +# η_dist = empirical_predistribution(reinforce, θ) +# fk = FixKwargs(reinforce.f, kwargs) +# gk = FixKwargs(reinforce.f.g, kwargs) +# y_dist = map(fk, η_dist) +# return mean( +# objective_value(reinforce.f, η, y; kwargs...) for +# (η, y) in zip(η_dist.atoms, y_dist.atoms) +# ), +# mean(gk, y_dist) +# end + function fenchel_young_F_and_first_part_of_grad( - perturbed::Perturbed{<:GeneralizedMaximizer,<:MultiplicativePerturbation}, - θ::AbstractArray; - kwargs..., -) + perturbed::Perturbed{F,<:MultiplicativePerturbation}, θ::AbstractArray; kwargs... 
+) where {F} (; reinforce) = perturbed + maximizer = get_maximizer(perturbed) η_dist = empirical_predistribution(reinforce, θ) - eZ_dist = map(Base.Fix2(./, θ), η_dist) fk = FixKwargs(reinforce.f, kwargs) - gk = FixKwargs(reinforce.f.g, kwargs) + gk = FixKwargs((y; kwargs...) -> apply_g(maximizer, y; kwargs...), kwargs) y_dist = map(fk, η_dist) + eZ_dist = map(Base.Fix2(./, θ), η_dist) return mean( - objective_value(reinforce.f, η, y; kwargs...) for + objective_value(maximizer, η, y; kwargs...) for (η, y) in zip(η_dist.atoms, y_dist.atoms) ), mean(gk.(map(.*, eZ_dist.atoms, y_dist.atoms))) end +# function fenchel_young_F_and_first_part_of_grad( +# perturbed::Perturbed{<:GeneralizedMaximizer,<:MultiplicativePerturbation}, +# θ::AbstractArray; +# kwargs..., +# ) +# (; reinforce) = perturbed +# η_dist = empirical_predistribution(reinforce, θ) +# eZ_dist = map(Base.Fix2(./, θ), η_dist) +# fk = FixKwargs(reinforce.f, kwargs) +# gk = FixKwargs(reinforce.f.g, kwargs) +# y_dist = map(fk, η_dist) +# return mean( +# objective_value(reinforce.f, η, y; kwargs...) for +# (η, y) in zip(η_dist.atoms, y_dist.atoms) +# ), +# mean(gk.(map(.*, eZ_dist.atoms, y_dist.atoms))) +# end + # function fenchel_young_F_and_first_part_of_grad( # perturbed::PerturbedMultiplicative{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... # ) where {P,G,O<:GeneralizedMaximizer} diff --git a/src/utils/generalized_maximizer.jl b/src/utils/generalized_maximizer.jl index e1a4384..b9154e4 100644 --- a/src/utils/generalized_maximizer.jl +++ b/src/utils/generalized_maximizer.jl @@ -25,6 +25,8 @@ function (f::GeneralizedMaximizer)(θ::AbstractArray{<:Real}; kwargs...) return f.maximizer(θ; kwargs...) end +objective_value(::Any, θ, y; kwargs...) = dot(θ, y) + """ objective_value(f, θ, y, kwargs...) @@ -33,3 +35,14 @@ Computes the objective value of given GeneralizedMaximizer `f`, knowing weights function objective_value(f::GeneralizedMaximizer, θ, y; kwargs...) return dot(θ, f.g(y; kwargs...)) .+ f.h(y; kwargs...) end + +apply_g(::Any, y; kwargs...) = y +apply_h(::Any, y; kwargs...) = zero(eltype(y)) + +function apply_g(f::GeneralizedMaximizer, y; kwargs...) + return f.g(y; kwargs...) +end + +function apply_h(f::GeneralizedMaximizer, y; kwargs...) + return f.h(y; kwargs...) +end diff --git a/src/utils/linear_maximizer.jl b/src/utils/linear_maximizer.jl new file mode 100644 index 0000000..74a1aae --- /dev/null +++ b/src/utils/linear_maximizer.jl @@ -0,0 +1,46 @@ +""" + LinearMaximizer{F,G,H} + +Wrapper for generalized maximizers `maximizer` of the form argmax_y θᵀg(y) + h(y). +It is compatible with the following layers +- [`PerturbedAdditive`](@ref) (with or without [`FenchelYoungLoss`](@ref)) +- [`PerturbedMultiplicative`](@ref) (with or without [`FenchelYoungLoss`](@ref)) +- [`SPOPlusLoss`](@ref) +""" +@kwdef struct LinearMaximizer{G,H,F} + maximizer::F + g::G = nothing + h::H = nothing +end + +function Base.show(io::IO, f::LinearMaximizer) + (; maximizer, g, h) = f + return print(io, "LinearMaximizer($maximizer, $g, $h)") +end + +# Callable calls the wrapped maximizer +function (f::LinearMaximizer)(θ::AbstractArray; kwargs...) + return f.maximizer(θ; kwargs...) +end + +objective_value(::LinearMaximizer{Nothing,Nothing}, θ, y; kwargs...) = dot(θ, y) + +""" + objective_value(f, θ, y, kwargs...) + +Computes the objective value of given LinearMaximizer `f`, knowing weights `θ` and solution `y`. +""" +function objective_value(f::LinearMaximizer, θ, y; kwargs...) + return dot(θ, f.g(y; kwargs...)) .+ f.h(y; kwargs...) 
+end + +apply_g(::LinearMaximizer{Nothing,Nothing}, y; kwargs...) = y +apply_h(::LinearMaximizer{Nothing,Nothing}, y; kwargs...) = zero(eltype(y)) + +function apply_g(f::LinearMaximizer, y; kwargs...) + return f.g(y; kwargs...) +end + +function apply_h(f::LinearMaximizer, y; kwargs...) + return f.h(y; kwargs...) +end diff --git a/src/utils/pushforward.jl b/src/utils/pushforward.jl index d27ac0b..eaab802 100644 --- a/src/utils/pushforward.jl +++ b/src/utils/pushforward.jl @@ -32,7 +32,7 @@ See also: [`compute_expectation`](@ref). """ function (pushforward::Pushforward)(θ::AbstractArray; kwargs...) (; optimization_layer, post_processing) = pushforward - probadist = empirical_distribution(optimization_layer, θ; kwargs...) + probadist = compute_probability_distribution(optimization_layer, θ; kwargs...) post_processing_kw = FixKwargs(post_processing, kwargs) return mean(post_processing_kw, probadist) end diff --git a/test/learning_generalized_maximizer.jl b/test/learning_generalized_maximizer.jl index 1c84044..1d34cde 100644 --- a/test/learning_generalized_maximizer.jl +++ b/test/learning_generalized_maximizer.jl @@ -13,7 +13,7 @@ @test y == [1 0 1; 0 1 0; 1 1 1] - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + generalized_maximizer = LinearMaximizer(; maximizer=max_pricing, g, h) @test generalized_maximizer(θ; instance) == y @@ -29,16 +29,17 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + perturbed = PerturbedAdditive(max_pricing; ε=1.0, nb_samples=10, g, h) + maximizer = InferOpt.get_maximizer(perturbed) function cost(y; instance) - return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) + return -objective_value(maximizer, true_encoder(instance), y; instance) end test_pipeline!( PipelineLossImitation(); instance_dim=5, true_maximizer=max_pricing, - maximizer=PerturbedAdditive(generalized_maximizer; ε=1.0, nb_samples=10), + maximizer=perturbed, loss=mse_kw, error_function=hamming_distance, cost, @@ -54,16 +55,17 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + perturbed = PerturbedMultiplicative(max_pricing; ε=1.0, nb_samples=10, g, h) + maximizer = InferOpt.get_maximizer(perturbed) function cost(y; instance) - return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) + return -objective_value(maximizer, true_encoder(instance), y; instance) end test_pipeline!( PipelineLossImitation(); instance_dim=5, true_maximizer=max_pricing, - maximizer=PerturbedMultiplicative(generalized_maximizer; ε=1.0, nb_samples=10), + maximizer=perturbed, loss=mse_kw, error_function=hamming_distance, cost, @@ -78,9 +80,10 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + perturbed = PerturbedAdditive(max_pricing; ε=1.0, nb_samples=10, g, h) + maximizer = InferOpt.get_maximizer(perturbed) function cost(y; instance) - return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) + return -objective_value(maximizer, true_encoder(instance), y; instance) end test_pipeline!( @@ -88,9 +91,7 @@ end instance_dim=5, true_maximizer=max_pricing, maximizer=identity_kw, - loss=FenchelYoungLoss( - PerturbedAdditive(generalized_maximizer; ε=1.0, nb_samples=10) - ), + loss=FenchelYoungLoss(perturbed), error_function=hamming_distance, cost, true_encoder, @@ -105,9 +106,10 @@ end true_encoder = encoder_factory() - generalized_maximizer = 
GeneralizedMaximizer(max_pricing; g, h) + perturbed = PerturbedMultiplicative(max_pricing; ε=0.1, nb_samples=10, g, h) + maximizer = InferOpt.get_maximizer(perturbed) function cost(y; instance) - return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) + return -objective_value(maximizer, true_encoder(instance), y; instance) end test_pipeline!( @@ -115,9 +117,7 @@ end instance_dim=5, true_maximizer=max_pricing, maximizer=identity_kw, - loss=FenchelYoungLoss( - PerturbedMultiplicative(generalized_maximizer; ε=0.1, nb_samples=10) - ), + loss=FenchelYoungLoss(perturbed), error_function=hamming_distance, cost, true_encoder, diff --git a/test/learning_ranking.jl b/test/learning_ranking.jl index 3f017d7..5c2e095 100644 --- a/test/learning_ranking.jl +++ b/test/learning_ranking.jl @@ -236,7 +236,7 @@ end error_function=hamming_distance, true_encoder=true_encoder, cost=cost, - epochs=500, + epochs=100, ) end @@ -259,30 +259,30 @@ end error_function=hamming_distance, true_encoder=true_encoder, cost=cost, - epochs=500, + epochs=100, ) end -# @testitem "exp - Pushforward PerturbedOracle{LogNormal}" default_imports = false begin -# include("InferOptTestUtils/src/InferOptTestUtils.jl") -# using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Distributions -# Random.seed!(63) - -# p(θ) = MvLogNormal(θ, I) - -# true_encoder = encoder_factory() -# cost(y; instance) = dot(y, -true_encoder(instance)) -# test_pipeline!( -# PipelineLossExperience(); -# instance_dim=5, -# true_maximizer=ranking, -# maximizer=identity_kw, -# loss=Pushforward(PerturbedOracle(ranking, p; nb_samples=10), cost), -# error_function=hamming_distance, -# true_encoder=true_encoder, -# cost=cost, -# ) -# end +@testitem "exp - Pushforward PerturbedOracle{LogNormal}" default_imports = false begin + include("InferOptTestUtils/src/InferOptTestUtils.jl") + using InferOpt, .InferOptTestUtils, LinearAlgebra, Random, Distributions + Random.seed!(63) + + p(θ) = MvLogNormal(θ, I) + + true_encoder = encoder_factory() + cost(y; instance) = dot(y, -true_encoder(instance)) + test_pipeline!( + PipelineLossExperience(); + instance_dim=5, + true_maximizer=ranking, + maximizer=identity_kw, + loss=Pushforward(Perturbed(ranking, p; nb_samples=10), cost), + error_function=hamming_distance, + true_encoder=true_encoder, + cost=cost, + ) +end @testitem "exp - Pushforward RegularizedFrankWolfe" default_imports = false begin include("InferOptTestUtils/src/InferOptTestUtils.jl") From bfe0cfff7214473205de307f2a3b97b9114dffb7 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Fri, 2 Aug 2024 12:31:28 +0200 Subject: [PATCH 07/26] cleanup --- src/InferOpt.jl | 1 + src/layers/perturbed/perturbed.jl | 39 ++++++------- src/losses/fenchel_young_loss.jl | 94 ++----------------------------- test/learning_ranking.jl | 3 +- 4 files changed, 27 insertions(+), 110 deletions(-) diff --git a/src/InferOpt.jl b/src/InferOpt.jl index cc70fe3..8a4dfa1 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -89,6 +89,7 @@ export SoftRank, soft_rank, soft_rank_l2, soft_rank_kl export SoftSort, soft_sort, soft_sort_l2, soft_sort_kl export RegularizedFrankWolfe +export Perturbed export PerturbedAdditive export PerturbedMultiplicative # export PerturbedOracle diff --git a/src/layers/perturbed/perturbed.jl b/src/layers/perturbed/perturbed.jl index c0901d5..94d97ea 100644 --- a/src/layers/perturbed/perturbed.jl +++ b/src/layers/perturbed/perturbed.jl @@ -1,4 +1,4 @@ -struct Perturbed{F,D,t,variance_reduction,G,R,S} <: AbstractOptimizationLayer +struct 
Perturbed{D,F,t,variance_reduction,G,R,S} <: AbstractOptimizationLayer reinforce::Reinforce{t,variance_reduction,F,D,G,R,S} end @@ -14,7 +14,7 @@ function compute_probability_distribution(perturbed::Perturbed, θ::AbstractArra return empirical_distribution(perturbed.reinforce, θ; kwargs...) end -function Base.show(io::IO, perturbed::Perturbed) +function Base.show(io::IO, perturbed::Perturbed{<:AbstractPerturbation}) (; reinforce) = perturbed nb_samples = reinforce.nb_samples ε = reinforce.dist_constructor.ε @@ -30,26 +30,15 @@ end function Perturbed( maximizer, - dist_constructor; - nb_samples=1, - variance_reduction=true, - seed=nothing, - threaded=false, - rng=Random.default_rng(), + dist_constructor, + dist_logdensity_grad=nothing; g=nothing, h=nothing, + kwargs..., ) linear_maximizer = LinearMaximizer(; maximizer, g, h) return Perturbed( - Reinforce( - linear_maximizer, - dist_constructor; - nb_samples, - variance_reduction, - seed, - threaded, - rng, - ), + Reinforce(linear_maximizer, dist_constructor, dist_logdensity_grad; kwargs...) ) end @@ -64,11 +53,17 @@ function PerturbedAdditive( rng=Random.default_rng(), g=nothing, h=nothing, + dist_logdensity_grad=if (perturbation_dist == Normal(0, 1)) + (η, θ) -> ((η .- θ) ./ ε^2,) + else + nothing + end, ) dist_constructor = AdditivePerturbation(perturbation_dist, float(ε)) return Perturbed( maximizer, - dist_constructor; + dist_constructor, + dist_logdensity_grad; nb_samples, variance_reduction, seed, @@ -90,11 +85,17 @@ function PerturbedMultiplicative( rng=Random.default_rng(), g=nothing, h=nothing, + dist_logdensity_grad=if (perturbation_dist == Normal(0, 1)) + (η, θ) -> (inv.(ε^2 .* θ) .* (η .- θ),) + else + nothing + end, ) dist_constructor = MultiplicativePerturbation(perturbation_dist, float(ε)) return Perturbed( maximizer, - dist_constructor; + dist_constructor, + dist_logdensity_grad; nb_samples, variance_reduction, seed, diff --git a/src/losses/fenchel_young_loss.jl b/src/losses/fenchel_young_loss.jl index 92be7f4..3343170 100644 --- a/src/losses/fenchel_young_loss.jl +++ b/src/losses/fenchel_young_loss.jl @@ -69,21 +69,6 @@ function fenchel_young_loss_and_grad( return l, g end -# function fenchel_young_loss_and_grad( -# fyl::FenchelYoungLoss{<:Perturbed{<:GeneralizedMaximizer}}, -# θ::AbstractArray, -# y_true::AbstractArray; -# kwargs..., -# ) -# (; optimization_layer) = fyl -# F, almost_g_of_ŷ = fenchel_young_F_and_first_part_of_grad( -# optimization_layer, θ; kwargs... -# ) -# l = F - objective_value(optimization_layer.reinforce.f, θ, y_true; kwargs...) -# g = almost_g_of_ŷ - optimization_layer.reinforce.f.g(y_true; kwargs...) -# return l, g -# end - ## Backward pass function ChainRulesCore.rrule( @@ -96,33 +81,9 @@ end ## Specific overrides for perturbed layers -# function compute_F_and_y_samples( -# perturbed::AbstractPerturbed{O,false}, -# θ::AbstractArray, -# Z_samples::Vector{<:AbstractArray}; -# kwargs..., -# ) where {O} -# F_and_y_samples = [ -# fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...) for -# Z in Z_samples -# ] -# return F_and_y_samples -# end - -# function compute_F_and_y_samples( -# perturbed::AbstractPerturbed{O,true}, -# θ::AbstractArray, -# Z_samples::Vector{<:AbstractArray}; -# kwargs..., -# ) where {O} -# return ThreadsX.map( -# Z -> fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...), Z_samples -# ) -# end - function fenchel_young_F_and_first_part_of_grad( - perturbed::Perturbed{F,<:AdditivePerturbation}, θ::AbstractArray; kwargs... 
-) where {F} + perturbed::Perturbed{<:AdditivePerturbation}, θ::AbstractArray; kwargs... +) (; reinforce) = perturbed maximizer = get_maximizer(perturbed) η_dist = empirical_predistribution(reinforce, θ) @@ -136,26 +97,9 @@ function fenchel_young_F_and_first_part_of_grad( mean(gk, y_dist) end -# function fenchel_young_F_and_first_part_of_grad( -# perturbed::Perturbed{<:GeneralizedMaximizer,<:AdditivePerturbation}, -# θ::AbstractArray; -# kwargs..., -# ) -# (; reinforce) = perturbed -# η_dist = empirical_predistribution(reinforce, θ) -# fk = FixKwargs(reinforce.f, kwargs) -# gk = FixKwargs(reinforce.f.g, kwargs) -# y_dist = map(fk, η_dist) -# return mean( -# objective_value(reinforce.f, η, y; kwargs...) for -# (η, y) in zip(η_dist.atoms, y_dist.atoms) -# ), -# mean(gk, y_dist) -# end - function fenchel_young_F_and_first_part_of_grad( - perturbed::Perturbed{F,<:MultiplicativePerturbation}, θ::AbstractArray; kwargs... -) where {F} + perturbed::Perturbed{<:MultiplicativePerturbation}, θ::AbstractArray; kwargs... +) (; reinforce) = perturbed maximizer = get_maximizer(perturbed) η_dist = empirical_predistribution(reinforce, θ) @@ -169,33 +113,3 @@ function fenchel_young_F_and_first_part_of_grad( ), mean(gk.(map(.*, eZ_dist.atoms, y_dist.atoms))) end - -# function fenchel_young_F_and_first_part_of_grad( -# perturbed::Perturbed{<:GeneralizedMaximizer,<:MultiplicativePerturbation}, -# θ::AbstractArray; -# kwargs..., -# ) -# (; reinforce) = perturbed -# η_dist = empirical_predistribution(reinforce, θ) -# eZ_dist = map(Base.Fix2(./, θ), η_dist) -# fk = FixKwargs(reinforce.f, kwargs) -# gk = FixKwargs(reinforce.f.g, kwargs) -# y_dist = map(fk, η_dist) -# return mean( -# objective_value(reinforce.f, η, y; kwargs...) for -# (η, y) in zip(η_dist.atoms, y_dist.atoms) -# ), -# mean(gk.(map(.*, eZ_dist.atoms, y_dist.atoms))) -# end - -# function fenchel_young_F_and_first_part_of_grad( -# perturbed::PerturbedMultiplicative{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... -# ) where {P,G,O<:GeneralizedMaximizer} -# (; oracle, ε) = perturbed -# eZ = exp.(ε .* Z .- ε^2) -# η = θ .* eZ -# y = oracle(η; kwargs...) -# F = objective_value(oracle, η, y; kwargs...) -# y_scaled = y .* eZ -# return F, oracle.g(y_scaled; kwargs...) 
-# end diff --git a/test/learning_ranking.jl b/test/learning_ranking.jl index 5c2e095..3827e4f 100644 --- a/test/learning_ranking.jl +++ b/test/learning_ranking.jl @@ -272,12 +272,13 @@ end true_encoder = encoder_factory() cost(y; instance) = dot(y, -true_encoder(instance)) + f(y; instance) = cost(ranking(y; instance); instance) test_pipeline!( PipelineLossExperience(); instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=Pushforward(Perturbed(ranking, p; nb_samples=10), cost), + loss=Perturbed(f, p; nb_samples=10), error_function=hamming_distance, true_encoder=true_encoder, cost=cost, From da29725271a7baa0506b2b159e350e6b3a75159e Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Fri, 2 Aug 2024 15:57:00 +0200 Subject: [PATCH 08/26] update --- docs/Manifest.toml | 300 ++++++++++++++++++------------ docs/make.jl | 8 +- src/interface.jl | 4 +- src/layers/perturbed/perturbed.jl | 35 +++- src/utils/linear_maximizer.jl | 12 +- src/utils/pushforward.jl | 6 +- test/learning_ranking.jl | 2 +- 7 files changed, 228 insertions(+), 139 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 3ebfcfb..34d5b48 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,6 +1,6 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.2" +julia_version = "1.10.4" manifest_format = "2.0" project_hash = "62c6c2d871b6216a78a27ec0b4dd2c53f1dc4f4f" @@ -27,9 +27,9 @@ version = "0.4.5" [[deps.Accessors]] deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown", "Test"] -git-tree-sha1 = "c0d491ef0b135fd7d63cbc6404286bc633329425" +git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.36" +version = "0.1.37" [deps.Accessors.extensions] AccessorsAxisKeysExt = "AxisKeys" @@ -48,14 +48,20 @@ version = "0.1.36" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "cea4ac3f5b4bc4b3000aa55afb6e5626518948fa" +git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.3" +version = "4.0.4" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] AdaptStaticArraysExt = "StaticArrays" +[[deps.AliasTables]] +deps = ["PtrArrays", "Random"] +git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" +uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" +version = "1.1.3" + [[deps.ArgCheck]] git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -67,9 +73,9 @@ version = "1.1.1" [[deps.ArnoldiMethod]] deps = ["LinearAlgebra", "Random", "StaticArrays"] -git-tree-sha1 = "62e51b39331de8911e4a7ff6f5aaf38a5f4cc0ae" +git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" uuid = "ec485272-7323-5ecc-a04f-4719b315124d" -version = "0.2.0" +version = "0.4.0" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -81,16 +87,17 @@ uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" version = "0.1.0" [[deps.BangBang]] -deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"] -git-tree-sha1 = "7aa7ad1682f3d5754e3491bb59b8103cae28e3a3" +deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] +git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.3.40" +version = "0.4.3" [deps.BangBang.extensions] BangBangChainRulesCoreExt = "ChainRulesCore" BangBangDataFramesExt = "DataFrames" 
BangBangStaticArraysExt = "StaticArrays" BangBangStructArraysExt = "StructArrays" + BangBangTablesExt = "Tables" BangBangTypedTablesExt = "TypedTables" [deps.BangBang.weakdeps] @@ -98,6 +105,7 @@ version = "0.3.40" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" [[deps.Base64]] @@ -121,20 +129,26 @@ version = "0.5.1" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "4e42872be98fa3343c4f8458cbda8c5c6a6fa97c" +git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.63.0" +version = "1.69.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "575cd02e080939a33b6df6c5853d14924c08e35b" +git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.23.0" +version = "1.24.0" weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] ChainRulesCoreSparseArraysExt = "SparseArrays" +[[deps.ChunkSplitters]] +deps = ["Compat", "TestItems"] +git-tree-sha1 = "783507c1f2371c8f2d321f41c3057ecd42cafa83" +uuid = "ae650224-84b6-46f8-82ea-d812ca08434e" +version = "2.4.5" + [[deps.CodeTracking]] deps = ["InteractiveUtils", "UUIDs"] git-tree-sha1 = "c0216e792f518b39b22212127d4a84dc31e4e386" @@ -143,21 +157,21 @@ version = "1.3.5" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "59939d8a997469ee05c4b4944560a820f9ba0d73" +git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.4" +version = "0.7.5" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "67c1f244b991cad9b0aa4b7540fb758c2488b129" +git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.24.0" +version = "3.26.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "eb7f0f8307f71fac7c606984ea5fb2817275d6e4" +git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.4" +version = "0.11.5" [[deps.ColorVectorSpace]] deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] @@ -171,9 +185,9 @@ weakdeps = ["SpecialFunctions"] [[deps.Colors]] deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" +git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.10" +version = "0.12.11" [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] @@ -183,9 +197,9 @@ version = "0.3.0" [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "c955881e3c981181362ae4088b35995446298b80" +git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.14.0" +version = "4.15.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -194,7 +208,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid 
= "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.0+0" +version = "1.1.1+0" [[deps.CompositionsBase]] git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" @@ -207,9 +221,9 @@ weakdeps = ["InverseFunctions"] [[deps.ConstructionBase]] deps = ["LinearAlgebra"] -git-tree-sha1 = "c53fc348ca4d40d7b371e71fd52251839080cbc9" +git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.4" +version = "1.5.6" [deps.ConstructionBase.extensions] ConstructionBaseIntervalSetsExt = "IntervalSets" @@ -226,9 +240,9 @@ uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" version = "0.1.3" [[deps.Contour]] -git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" +git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" -version = "0.6.2" +version = "0.6.3" [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" @@ -242,9 +256,9 @@ version = "1.16.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "0f4b5d62a88d8f59003e43c25a8a90de9eb76317" +git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.18" +version = "0.18.20" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -284,10 +298,28 @@ git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" version = "1.15.1" +[[deps.DifferentiableExpectations]] +deps = ["ChainRulesCore", "Compat", "DensityInterface", "Distributions", "DocStringExtensions", "LinearAlgebra", "OhMyThreads", "Random", "Statistics", "StatsBase"] +path = "../../DifferentiableExpectations.jl" +uuid = "fc55d66b-b2a8-4ccc-9d64-c0c2166ceb36" +version = "0.1.0" + [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +[[deps.Distributions]] +deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] +git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.25.109" +weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] + + [deps.Distributions.extensions] + DistributionsChainRulesCoreExt = "ChainRulesCore" + DistributionsDensityInterfaceExt = "DensityInterface" + DistributionsTestExt = "Test" + [[deps.DocStringExtensions]] deps = ["LibGit2"] git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" @@ -296,9 +328,9 @@ version = "0.9.3" [[deps.Documenter]] deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"] -git-tree-sha1 = "4a40af50e8b24333b9ec6892546d9ca5724228eb" +git-tree-sha1 = "76deb8c15f37a3853f13ea2226b8f2577652de05" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "1.3.0" +version = "1.5.0" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -313,15 +345,15 @@ version = "0.6.8" [[deps.Expat_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "4558ab818dcceaab612d1bb8c19cee87eda2b83c" +git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.5.0+0" +version = 
"2.6.2+0" [[deps.FLoops]] deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "ffb97765602e3cbe59a0589d237bf07f245a8576" +git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.1" +version = "0.2.2" [[deps.FLoopsBase]] deps = ["ContextVariablesX"] @@ -340,15 +372,15 @@ version = "0.13.11" [[deps.FixedPointNumbers]] deps = ["Statistics"] -git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +git-tree-sha1 = "05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.4" +version = "0.8.5" [[deps.Flux]] deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "5a626d6ef24ae0a8590c22dc12096fb65eb66325" +git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.13" +version = "0.14.16" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" @@ -374,9 +406,9 @@ weakdeps = ["StaticArrays"] [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "8ae30e786837ce0a24f5e2186938bf3251ab94b2" +git-tree-sha1 = "64d8e93700c7a3f28f717d265382d52fac9fa1c1" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.8" +version = "0.4.12" [[deps.Future]] deps = ["Random"] @@ -384,9 +416,9 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "47e4686ec18a9620850bad110b79966132f14283" +git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.0.2" +version = "10.3.0" [[deps.GPUArraysCore]] deps = ["Adapt"] @@ -402,15 +434,15 @@ version = "1.3.1" [[deps.Git_jll]] deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "12945451c5d0e2d0dca0724c3a8d6448b46bbdf9" +git-tree-sha1 = "d18fb8a1f3609361ebda9bf029b60fd0f120c809" uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb" -version = "2.44.0+1" +version = "2.44.0+2" [[deps.Graphs]] deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "899050ace26649433ef1af25bc17a815b3db52b7" +git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.9.0" +version = "1.11.2" [[deps.GridGraphs]] deps = ["DataStructures", "FillArrays", "Graphs", "SparseArrays"] @@ -426,18 +458,18 @@ version = "0.3.23" [[deps.IOCapture]] deps = ["Logging", "Random"] -git-tree-sha1 = "8b72179abc660bfab5e28472e019392b97d0985c" +git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.4" +version = "0.2.5" [[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "5d8c5713f38f7bc029e26627b687710ba406d0dd" +deps = ["InteractiveUtils", "MacroTools"] +git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.12" +version = "0.4.14" [[deps.InferOpt]] -deps = ["ChainRulesCore", 
"DensityInterface", "LinearAlgebra", "Random", "RequiredInterfaces", "Statistics", "StatsBase", "StatsFuns", "ThreadsX"] +deps = ["ChainRulesCore", "DensityInterface", "DifferentiableExpectations", "Distributions", "DocStringExtensions", "LinearAlgebra", "Random", "RequiredInterfaces", "Statistics", "StatsBase", "StatsFuns", "ThreadsX"] path = ".." uuid = "4846b161-c94e-4150-8dac-c7ae193c601f" version = "0.6.1" @@ -449,9 +481,9 @@ version = "0.6.1" DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d" [[deps.Inflate]] -git-tree-sha1 = "ea8031dea4aff6bd41f1df8f2fdfb25b33626381" +git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.4" +version = "0.1.5" [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" @@ -464,9 +496,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" [[deps.InverseFunctions]] deps = ["Test"] -git-tree-sha1 = "896385798a8d49a255c398bd49162062e4a4c435" +git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.13" +version = "0.1.15" weakdeps = ["Dates"] [deps.InverseFunctions.extensions] @@ -496,9 +528,9 @@ version = "0.21.4" [[deps.JuliaInterpreter]] deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] -git-tree-sha1 = "7b762d81887160169ddfc93a47e5fd7a6a3e78ef" +git-tree-sha1 = "5d3a5a206297af3868151bb4a2cf27ebce46f16d" uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" -version = "0.9.29" +version = "0.9.33" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -508,9 +540,9 @@ version = "0.2.4" [[deps.KernelAbstractions]] deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "ed7167240f40e62d97c1f5f7735dea6de3cc5c49" +git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.18" +version = "0.9.22" [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" @@ -520,9 +552,9 @@ version = "0.9.18" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "7c6650580b4c3169d9905858160db895bff6d2e2" +git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.6.1" +version = "8.0.0" [deps.LLVM.extensions] BFloat16sExt = "BFloat16s" @@ -532,9 +564,9 @@ version = "6.6.1" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "88b916503aac4fb7f701bb625cd84ca5dd1677bc" +git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.29+0" +version = "0.0.30+0" [[deps.LazilyInitializedFields]] git-tree-sha1 = "8f7f3cabab0fd1800699663533b6d5cb3fc0e612" @@ -584,15 +616,15 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [[deps.Literate]] deps = ["Base64", "IOCapture", "JSON", "REPL"] -git-tree-sha1 = "bad26f1ccd99c553886ec0725e99a509589dcd11" +git-tree-sha1 = "eef2e1fc1dc38af90a18eb16e519e06d1fd10c2a" uuid = "98b081ad-f1c9-55d3-8b20-4c87d4299306" -version = "2.16.1" +version = "2.19.0" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "18144f3e9cbe9b15b070288eef858f71b291ce37" +git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = 
"0.3.27" +version = "0.3.28" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -609,9 +641,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.LoweredCodeUtils]] deps = ["JuliaInterpreter"] -git-tree-sha1 = "31e27f0b0bf0df3e3e951bfcc43fe8c730a219f6" +git-tree-sha1 = "1ce1834f9644a8f7c011eb0592b7fd6c42c90653" uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" -version = "2.4.5" +version = "3.0.1" [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" @@ -652,16 +684,16 @@ uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" version = "2.28.2+1" [[deps.MicroCollections]] -deps = ["BangBang", "InitialValues", "Setfield"] -git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" +deps = ["Accessors", "BangBang", "InitialValues"] +git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f" uuid = "128add7d-3638-4c79-886c-908ea0c25c34" -version = "0.1.4" +version = "0.2.0" [[deps.Missings]] deps = ["DataAPI"] -git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" +git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.1.0" +version = "1.2.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" @@ -672,20 +704,22 @@ version = "2023.1.10" [[deps.NNlib]] deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "877f15c331337d54cf24c797d5bcb2e48ce21221" +git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.12" +version = "0.9.21" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" NNlibCUDACUDNNExt = ["CUDA", "cuDNN"] NNlibCUDAExt = "CUDA" NNlibEnzymeCoreExt = "EnzymeCore" + NNlibFFTWExt = "FFTW" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NaNMath]] @@ -704,6 +738,12 @@ version = "0.1.5" uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" version = "1.2.0" +[[deps.OhMyThreads]] +deps = ["BangBang", "ChunkSplitters", "StableTasks", "TaskLocalValues"] +git-tree-sha1 = "881876fc70ab53ad60671ad4a1af25c920aee0eb" +uuid = "67456a42-1dca-4109-a031-0a68de7e3ad5" +version = "0.5.3" + [[deps.OneHotArrays]] deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" @@ -722,9 +762,9 @@ version = "0.8.1+2" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "60e3045590bd104a16fefb12836c00c0ef8c7f8c" +git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.13+0" +version = "3.0.14+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -734,9 +774,9 @@ version = "0.5.5+0" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "264b061c1903bc0fe9be77cb9050ebacff66bb63" +git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.2" +version = "0.3.3" [[deps.OrderedCollections]] git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" @@ -748,6 +788,12 @@ deps = ["Artifacts", "Libdl"] uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" version = 
"10.42.0+1" +[[deps.PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.11.31" + [[deps.Parsers]] deps = ["Dates", "PrecompileTools", "UUIDs"] git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" @@ -788,9 +834,20 @@ version = "0.1.4" [[deps.ProgressMeter]] deps = ["Distributed", "Printf"] -git-tree-sha1 = "763a8ceb07833dd51bb9e3bbca372de32c0605ad" +git-tree-sha1 = "8f6bc219586aef8baf0ff9a5fe16ee9c70cb65e4" uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.10.0" +version = "1.10.2" + +[[deps.PtrArrays]] +git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" +uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" +version = "1.2.0" + +[[deps.QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "e237232771fdafbae3db5c31275303e056afaa9f" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.10.1" [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] @@ -825,9 +882,9 @@ version = "0.1.0" [[deps.RequiredInterfaces]] deps = ["InteractiveUtils", "Logging", "Test"] -git-tree-sha1 = "e7eb973af4753abf5d866941268ec6ea2aec5556" +git-tree-sha1 = "c3250333ea2894237ed015baf7d5fcb8a1ea3169" uuid = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6" -version = "0.1.5" +version = "0.1.6" [[deps.Requires]] deps = ["UUIDs"] @@ -836,10 +893,10 @@ uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "1.3.0" [[deps.Revise]] -deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "Pkg", "REPL", "Requires", "UUIDs", "Unicode"] -git-tree-sha1 = "12aa2d7593df490c407a3bbd8b86b8b515017f3e" +deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "REPL", "Requires", "UUIDs", "Unicode"] +git-tree-sha1 = "7b7850bb94f75762d567834d7e9802fc22d62f9c" uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" -version = "3.5.14" +version = "3.5.18" [[deps.Rmath]] deps = ["Random", "Rmath_jll"] @@ -848,10 +905,10 @@ uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" version = "0.7.1" [[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.0+0" +version = "0.4.2+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -903,9 +960,9 @@ version = "0.1.2" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" +git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.1" +version = "2.4.0" weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] @@ -917,11 +974,16 @@ git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" uuid = "171d559e-b47b-412a-8079-5efa626c420e" version = "0.1.15" +[[deps.StableTasks]] +git-tree-sha1 = "073d5c20d44129b20fe954720b97069579fa403b" +uuid = "91464d47-22a1-43fe-8b7f-2d57ee82463f" +version = "0.1.5" + [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "bf074c045d3d5ffd956fa0a461da38a44685d6b2" +git-tree-sha1 = "eeafab08ae20c62c44c8399ccb9354a04b80db50" uuid = "90137ffa-7385-5640-81b9-e52037218182" 
-version = "1.9.3" +version = "1.9.7" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -929,9 +991,9 @@ weakdeps = ["ChainRulesCore", "Statistics"] StaticArraysStatisticsExt = "Statistics" [[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" +version = "1.4.3" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -946,9 +1008,9 @@ version = "1.7.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "1d77abd07f617c4868c33d4f5b9e1dbb2643c9cf" +git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.2" +version = "0.34.3" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] @@ -995,16 +1057,21 @@ uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" version = "1.0.1" [[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "cb76cf677714c095e535e3501ac7954732aeea2d" +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.11.1" +version = "1.12.0" [[deps.Tar]] deps = ["ArgTools", "SHA"] uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" +[[deps.TaskLocalValues]] +git-tree-sha1 = "eb0b8d147eb907a9ad3fd952da7c6a053b29ae28" +uuid = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34" +version = "0.1.1" + [[deps.TensorCore]] deps = ["LinearAlgebra"] git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" @@ -1015,6 +1082,11 @@ version = "0.1.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[deps.TestItems]] +git-tree-sha1 = "42fd9023fef18b9b78c8343a4e2f3813ffbcefcb" +uuid = "1c621080-faea-4a02-84b6-bbd5e436b8fe" +version = "1.0.0" + [[deps.ThreadsX]] deps = ["Accessors", "ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "SplittablesBase", "Transducers"] git-tree-sha1 = "70bd8244f4834d46c3d68bd09e7792d8f571ef04" @@ -1022,19 +1094,19 @@ uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" version = "0.1.12" [[deps.TranscodingStreams]] -git-tree-sha1 = "3caa21522e7efac1ba21834a03734c57b4611c7e" +git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.10.4" +version = "0.11.1" weakdeps = ["Random", "Test"] [deps.TranscodingStreams.extensions] TestExt = ["Test", "Random"] [[deps.Transducers]] -deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] -git-tree-sha1 = "3064e780dbb8a9296ebb3af8f440f787bb5332af" +deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = 
"5215a069867476fc8e3469602006b9670e68da23" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.80" +version = "0.4.82" [deps.Transducers.extensions] TransducersBlockArraysExt = "BlockArrays" @@ -1085,9 +1157,9 @@ version = "0.2.1" [[deps.UnsafeAtomicsLLVM]] deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" +git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.3" +version = "0.1.5" [[deps.Zlib_jll]] deps = ["Libdl"] @@ -1096,9 +1168,9 @@ version = "1.2.13+1" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "4ddb4470e47b0094c93055a3bcae799165cc68f1" +git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.69" +version = "0.6.70" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" diff --git a/docs/make.jl b/docs/make.jl index 83c3208..0046b0b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -11,7 +11,7 @@ open(joinpath(@__DIR__, "src", "index.md"), "w") do io io, """ ```@meta - EditURL = "https://github.com/axelparmentier/InferOpt.jl/blob/main/README.md" + EditURL = "https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl/blob/main/README.md" ``` """, ) @@ -30,13 +30,13 @@ Literate.markdown(tuto_jl_file, tuto_md_dir; documenter=true, execute=false) makedocs(; modules=[InferOpt], authors="Guillaume Dalle, Léo Baty, Louis Bouvier, Axel Parmentier", - repo="https://github.com/axelparmentier/InferOpt.jl/blob/{commit}{path}#{line}", + repo="https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl/blob/{commit}{path}#{line}", sitename="InferOpt.jl", format=Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", - canonical="https://axelparmentier.github.io/InferOpt.jl", + canonical="https://juliadecisionfocusedlearning.github.io/InferOpt.jl", assets=String[], - repolink="https://github.com/axelparmentier/InferOpt.jl", + repolink="https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl", ), pages=[ "Home" => "index.md", diff --git a/src/interface.jl b/src/interface.jl index cb4ac75..2fa67a8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -18,7 +18,7 @@ abstract type AbstractLayer end Supertype for all the optimization layers defined in InferOpt. # Interface -- `(layer::AbstractOptimizationLayer)(θ; kwargs...)` +- `(layer::AbstractOptimizationLayer)(θ::AbstractArray; kwargs...)` - `compute_probability_distribution(layer, θ; kwargs...)` (only if the layer is probabilistic) """ abstract type AbstractOptimizationLayer <: AbstractLayer end @@ -45,6 +45,6 @@ abstract type AbstractLossLayer <: AbstractLayer end """ compute_probability_distribution(layer, θ; kwargs...) -Apply a probabilistic optimization layer to an objective direction `θ` in order to generate a [`FixedAtomsProbabilityDistribution`](@ref) on the vertices of a polytope. +Apply a probabilistic optimization layer to an objective direction `θ` in order to generate a `FixedAtomsProbabilityDistribution` on the vertices of a polytope. 
""" function compute_probability_distribution end diff --git a/src/layers/perturbed/perturbed.jl b/src/layers/perturbed/perturbed.jl index 94d97ea..06dafb0 100644 --- a/src/layers/perturbed/perturbed.jl +++ b/src/layers/perturbed/perturbed.jl @@ -1,3 +1,15 @@ +""" + Perturbed{D,F} <: AbstractOptimizationLayer + +Differentiable perturbation of a black box optimizer of type `F`, with perturbation of type `D`. + +This struct is as wrapper around `Reinforce` from DifferentiableExpectations.jl. + +There are three different available constructors that behave differently in the package: +- [`LinearPerturbed`](@ref) +- [`PerturbedAdditive`](@ref) +- [`PerturbedMultiplicative`](@ref) +""" struct Perturbed{D,F,t,variance_reduction,G,R,S} <: AbstractOptimizationLayer reinforce::Reinforce{t,variance_reduction,F,D,G,R,S} end @@ -28,7 +40,10 @@ function Base.show(io::IO, perturbed::Perturbed{<:AbstractPerturbation}) ) end -function Perturbed( +""" +doc +""" +function LinearPerturbed( maximizer, dist_constructor, dist_logdensity_grad=nothing; @@ -42,6 +57,9 @@ function Perturbed( ) end +""" + PerturbedAdditive(maximizer; kwargs...) +""" function PerturbedAdditive( maximizer; ε=1.0, @@ -51,8 +69,8 @@ function PerturbedAdditive( seed=nothing, threaded=false, rng=Random.default_rng(), - g=nothing, - h=nothing, + g=identity_kw, + h=zero ∘ eltype_kw, dist_logdensity_grad=if (perturbation_dist == Normal(0, 1)) (η, θ) -> ((η .- θ) ./ ε^2,) else @@ -60,7 +78,7 @@ function PerturbedAdditive( end, ) dist_constructor = AdditivePerturbation(perturbation_dist, float(ε)) - return Perturbed( + return LinearPerturbed( maximizer, dist_constructor, dist_logdensity_grad; @@ -74,6 +92,9 @@ function PerturbedAdditive( ) end +""" +doc +""" function PerturbedMultiplicative( maximizer; ε=1.0, @@ -83,8 +104,8 @@ function PerturbedMultiplicative( seed=nothing, threaded=false, rng=Random.default_rng(), - g=nothing, - h=nothing, + g=identity_kw, + h=zero ∘ eltype_kw, dist_logdensity_grad=if (perturbation_dist == Normal(0, 1)) (η, θ) -> (inv.(ε^2 .* θ) .* (η .- θ),) else @@ -92,7 +113,7 @@ function PerturbedMultiplicative( end, ) dist_constructor = MultiplicativePerturbation(perturbation_dist, float(ε)) - return Perturbed( + return LinearPerturbed( maximizer, dist_constructor, dist_logdensity_grad; diff --git a/src/utils/linear_maximizer.jl b/src/utils/linear_maximizer.jl index 74a1aae..c2136c0 100644 --- a/src/utils/linear_maximizer.jl +++ b/src/utils/linear_maximizer.jl @@ -7,10 +7,10 @@ It is compatible with the following layers - [`PerturbedMultiplicative`](@ref) (with or without [`FenchelYoungLoss`](@ref)) - [`SPOPlusLoss`](@ref) """ -@kwdef struct LinearMaximizer{G,H,F} +@kwdef struct LinearMaximizer{F,G,H} maximizer::F - g::G = nothing - h::H = nothing + g::G = identity_kw + h::H = zero ∘ eltype_kw end function Base.show(io::IO, f::LinearMaximizer) @@ -23,8 +23,6 @@ function (f::LinearMaximizer)(θ::AbstractArray; kwargs...) return f.maximizer(θ; kwargs...) end -objective_value(::LinearMaximizer{Nothing,Nothing}, θ, y; kwargs...) = dot(θ, y) - """ objective_value(f, θ, y, kwargs...) @@ -34,13 +32,11 @@ function objective_value(f::LinearMaximizer, θ, y; kwargs...) return dot(θ, f.g(y; kwargs...)) .+ f.h(y; kwargs...) end -apply_g(::LinearMaximizer{Nothing,Nothing}, y; kwargs...) = y -apply_h(::LinearMaximizer{Nothing,Nothing}, y; kwargs...) = zero(eltype(y)) - function apply_g(f::LinearMaximizer, y; kwargs...) return f.g(y; kwargs...) end +# Might not be needed function apply_h(f::LinearMaximizer, y; kwargs...) 
return f.h(y; kwargs...) end diff --git a/src/utils/pushforward.jl b/src/utils/pushforward.jl index eaab802..8248f08 100644 --- a/src/utils/pushforward.jl +++ b/src/utils/pushforward.jl @@ -9,7 +9,7 @@ Differentiable pushforward of a probabilistic optimization layer with an arbitra - `optimization_layer::AbstractOptimizationLayer`: probabilistic optimization layer - `post_processing`: callable -See also: [`FixedAtomsProbabilityDistribution`](@ref). +See also: `FixedAtomsProbabilityDistribution`. """ struct Pushforward{O<:AbstractOptimizationLayer,P} <: AbstractLayer optimization_layer::O @@ -26,9 +26,9 @@ end Output the expectation of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.optimization_layer` applied to `θ`. -Unlike [`empirical_distribution(pushforward, θ)`](@ref), this function is differentiable, even if `pushforward.post_processing` isn't. +This function is differentiable, even if `pushforward.post_processing` isn't. -See also: [`compute_expectation`](@ref). +See also: `compute_expectation`. """ function (pushforward::Pushforward)(θ::AbstractArray; kwargs...) (; optimization_layer, post_processing) = pushforward diff --git a/test/learning_ranking.jl b/test/learning_ranking.jl index 3827e4f..b8152e0 100644 --- a/test/learning_ranking.jl +++ b/test/learning_ranking.jl @@ -278,7 +278,7 @@ end instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=Perturbed(f, p; nb_samples=10), + loss=LinearPerturbed(f, p; nb_samples=10), error_function=hamming_distance, true_encoder=true_encoder, cost=cost, From 955b2c55f68dbad6caa935429d077650791d6a66 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Fri, 9 Aug 2024 17:09:22 +0200 Subject: [PATCH 09/26] bump differentiable expectation compat --- Project.toml | 2 +- src/InferOpt.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index c8fd56b..aa639c8 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ InferOptFrankWolfeExt = "DifferentiableFrankWolfe" [compat] ChainRulesCore = "1" DensityInterface = "0.4.0" -DifferentiableExpectations = "0.1" +DifferentiableExpectations = "0.2" DifferentiableFrankWolfe = "0.2" Distributions = "0.25" LinearAlgebra = "<0.0.1,1" diff --git a/src/InferOpt.jl b/src/InferOpt.jl index 8a4dfa1..df851ea 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -92,7 +92,7 @@ export RegularizedFrankWolfe export Perturbed export PerturbedAdditive export PerturbedMultiplicative -# export PerturbedOracle +export LinearPerturbed export FenchelYoungLoss export StructuredSVMLoss From ab1b51a011777264e08847b2d1880a96b53ce7e8 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Mon, 23 Dec 2024 14:37:16 +0100 Subject: [PATCH 10/26] fix tests --- test/perturbed.jl | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/test/perturbed.jl b/test/perturbed.jl index 7a09e16..a83aa69 100644 --- a/test/perturbed.jl +++ b/test/perturbed.jl @@ -25,7 +25,7 @@ # Order of diagonal coefficients should follow order of θ @test sortperm(diag(jac1)) == sortperm(θ) # No scaling with nb of samples - @test norm(jac1) ≈ norm(jac1_big) rtol = 1e-2 + @test norm(jac1) ≈ norm(jac1_big) rtol = 5e-2 end @testset "PerturbedMultiplicative" begin @@ -36,8 +36,7 @@ @test all(diag(jac2) .>= 0) @test all(jac2 - Diagonal(jac2) .<= 0) @test sortperm(diag(jac2)) != sortperm(θ) - # This is not equal because the diagonal coefficient for θ₃ = 4 is often larger than the one for θ₂ = 5. 
It happens because θ₃ has the opportunity to *become* the argmax (and hence switch from 0 to 1), whereas θ₂ already *is* the argmax. - @test norm(jac2) ≈ norm(jac2_big) rtol = 1e-2 + @test norm(jac2) ≈ norm(jac2_big) rtol = 5e-2 end end @@ -51,7 +50,7 @@ end p(θ) = MvNormal(θ, ε^2 * I) oracle(η) = η - po = PerturbedOracle(oracle, p; nb_samples=1_000, seed=0) + po = PerturbedOracle(oracle, p; nb_samples=1_000, seed=0) # TODO: fix this pa = PerturbedAdditive(oracle; ε, nb_samples=1_000, seed=0) θ = randn(10) @@ -68,16 +67,22 @@ end ε = 1.0 oracle(η) = η - pa = PerturbedAdditive(oracle; ε, nb_samples=100, seed=0) - pm = PerturbedMultiplicative(oracle; ε, nb_samples=100, seed=0) + pa = PerturbedAdditive(oracle; ε, nb_samples=100, seed=0, variance_reduction=true) + pa_no_variance_reduction = PerturbedAdditive( + oracle; ε, nb_samples=100, seed=0, variance_reduction=false + ) + pm = PerturbedMultiplicative(oracle; ε, nb_samples=100, seed=0, variance_reduction=true) + pm_no_variance_reduction = PerturbedMultiplicative( + oracle; ε, nb_samples=100, seed=0, variance_reduction=false + ) n = 10 θ = randn(10) - Ja = jacobian(θ -> pa(θ; autodiff_variance_reduction=false), θ)[1] + Ja = jacobian(pa_no_variance_reduction, θ)[1] Ja_reduced_variance = jacobian(pa, θ)[1] - Jm = jacobian(x -> pm(x; autodiff_variance_reduction=false), θ)[1] + Jm = jacobian(pm_no_variance_reduction, θ)[1] Jm_reduced_variance = jacobian(pm, θ)[1] J_true = Matrix(I, n, n) # exact jacobian is the identity matrix From d21480259570efa84bd0e31c5247a1cb26468855 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Mon, 23 Dec 2024 14:39:28 +0100 Subject: [PATCH 11/26] fix docs --- docs/make.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index dd4eaff..811f75e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -6,7 +6,6 @@ DocMeta.setdocmeta!(InferOpt, :DocTestSetup, :(using InferOpt); recursive=true) # Copy README.md into docs/src/index.md (overwriting) -<<<<<<< HEAD open(joinpath(@__DIR__, "src", "index.md"), "w") do io println( io, @@ -21,13 +20,6 @@ open(joinpath(@__DIR__, "src", "index.md"), "w") do io println(io, line) end end -======= -cp( - joinpath(dirname(@__DIR__), "README.md"), - joinpath(@__DIR__, "src", "index.md"); - force=true, -) ->>>>>>> main # Parse test/tutorial.jl into docs/src/tutorial.md (overwriting) From f084090923343485eeed4616642b3203b84d4ca6 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Mon, 23 Dec 2024 15:43:45 +0100 Subject: [PATCH 12/26] Fix tests --- docs/Manifest.toml | 484 +++++++++++++++---------- src/InferOpt.jl | 2 +- src/layers/perturbed/perturbed.jl | 64 ++-- src/losses/fenchel_young_loss.jl | 9 +- src/utils/linear_maximizer.jl | 4 + test/learning_generalized_maximizer.jl | 24 +- test/paths.jl | 5 +- test/perturbed.jl | 2 +- 8 files changed, 363 insertions(+), 231 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 34d5b48..4bc27fd 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.4" +julia_version = "1.11.1" manifest_format = "2.0" -project_hash = "62c6c2d871b6216a78a27ec0b4dd2c53f1dc4f4f" +project_hash = "6e00f168de8676c54ad2b043cfc74314d50a8935" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -26,31 +26,35 @@ uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" version = "0.4.5" [[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", 
"MacroTools", "Markdown", "Test"] -git-tree-sha1 = "f61b15be1d76846c0ce31d3fcfac5380ae53db6a" +deps = ["CompositionsBase", "ConstructionBase", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown"] +git-tree-sha1 = "96bed9b1b57cf750cca50c311a197e306816a1cc" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.37" +version = "0.1.39" [deps.Accessors.extensions] AccessorsAxisKeysExt = "AxisKeys" + AccessorsDatesExt = "Dates" AccessorsIntervalSetsExt = "IntervalSets" AccessorsStaticArraysExt = "StaticArrays" AccessorsStructArraysExt = "StructArrays" + AccessorsTestExt = "Test" AccessorsUnitfulExt = "Unitful" [deps.Accessors.weakdeps] AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" + Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" Requires = "ae029012-a4dd-5104-9daa-d747884805df" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" + Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "6a55b747d1812e699320963ffde36f1ebdda4099" +git-tree-sha1 = "50c3c56a52972d78e8be9fd135bfb91c9574c140" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.0.4" +version = "4.1.1" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -63,13 +67,13 @@ uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" version = "1.1.3" [[deps.ArgCheck]] -git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +git-tree-sha1 = "680b3b8759bd4c54052ada14e52355ab69e07876" uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.3.0" +version = "2.4.0" [[deps.ArgTools]] uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" +version = "1.1.2" [[deps.ArnoldiMethod]] deps = ["LinearAlgebra", "Random", "StaticArrays"] @@ -79,12 +83,23 @@ version = "0.4.0" [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.11.0" [[deps.Atomix]] deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" +git-tree-sha1 = "c3b238aa28c1bebd4b5ea4988bebf27e9a01b72b" uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" +version = "1.0.1" + + [deps.Atomix.extensions] + AtomixCUDAExt = "CUDA" + AtomixMetalExt = "Metal" + AtomixoneAPIExt = "oneAPI" + + [deps.Atomix.weakdeps] + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [[deps.BangBang]] deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] @@ -110,6 +125,7 @@ version = "0.4.3" [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +version = "1.11.0" [[deps.Baselet]] git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" @@ -121,51 +137,45 @@ git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.5.0" -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" - [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "227985d885b4dbce5e18a96f9326ea1e836e5a03" +git-tree-sha1 = "bcffdcaed50d3453673b852f3522404a94b50fad" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.69.0" 
+version = "1.72.1" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "71acdbf594aab5bbb2cec89b208c41b4c411e49f" +git-tree-sha1 = "3e4b134270b372f2ed4d4d0e936aabaefc1802bc" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.24.0" +version = "1.25.0" weakdeps = ["SparseArrays"] [deps.ChainRulesCore.extensions] ChainRulesCoreSparseArraysExt = "SparseArrays" [[deps.ChunkSplitters]] -deps = ["Compat", "TestItems"] -git-tree-sha1 = "783507c1f2371c8f2d321f41c3057ecd42cafa83" +deps = ["TestItems"] +git-tree-sha1 = "01d5db8756afc4022b1cf267cfede13245226c72" uuid = "ae650224-84b6-46f8-82ea-d812ca08434e" -version = "2.4.5" +version = "2.6.0" [[deps.CodeTracking]] deps = ["InteractiveUtils", "UUIDs"] -git-tree-sha1 = "c0216e792f518b39b22212127d4a84dc31e4e386" +git-tree-sha1 = "7eee164f122511d3e4e1ebadb7956939ea7e1c77" uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" -version = "1.3.5" +version = "1.3.6" [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "b8fe8546d52ca154ac556809e10c75e6e7430ac8" +git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.5" +version = "0.7.6" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "b5278586822443594ff615963b0c09755771b3e0" +git-tree-sha1 = "c785dfb1b3bfddd1da557e861b919819b82bbe5b" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.26.0" +version = "3.27.1" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -190,16 +200,16 @@ uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" version = "0.12.11" [[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +deps = ["MacroTools"] +git-tree-sha1 = "cda2cfaebb4be89c9084adaca7dd7333369715c5" uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" +version = "0.3.1" [[deps.Compat]] deps = ["TOML", "UUIDs"] -git-tree-sha1 = "b1c55339b7c6c350ee89f2c1604299660525b248" +git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.15.0" +version = "4.16.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -220,17 +230,18 @@ weakdeps = ["InverseFunctions"] CompositionsBaseInverseFunctionsExt = "InverseFunctions" [[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d8a9c0b6ac2d9081bf76324b39c78ca3ce4f0c98" +git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.6" +version = "1.5.8" [deps.ConstructionBase.extensions] ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseLinearAlgebraExt = "LinearAlgebra" ConstructionBaseStaticArraysExt = "StaticArrays" [deps.ConstructionBase.weakdeps] IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" + LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.ContextVariablesX]] @@ -268,6 +279,7 @@ version = "1.0.0" [[deps.Dates]] deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +version = "1.11.0" [[deps.DefineSingletons]] git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" @@ -300,19 +312,20 @@ version = "1.15.1" [[deps.DifferentiableExpectations]] deps = ["ChainRulesCore", "Compat", "DensityInterface", "Distributions", "DocStringExtensions", "LinearAlgebra", "OhMyThreads", "Random", "Statistics", "StatsBase"] -path = 
"../../DifferentiableExpectations.jl" +git-tree-sha1 = "829dd95b32a41526923f44799ce0762fcd9a3a37" uuid = "fc55d66b-b2a8-4ccc-9d64-c0c2166ceb36" -version = "0.1.0" +version = "0.2.0" [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" +version = "1.11.0" [[deps.Distributions]] deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "9c405847cc7ecda2dc921ccf18b47ca150d7317e" +git-tree-sha1 = "4b138e4643b577ccf355377c2bc70fa975af25de" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.109" +version = "0.25.115" weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] [deps.Distributions.extensions] @@ -328,26 +341,29 @@ version = "0.9.3" [[deps.Documenter]] deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"] -git-tree-sha1 = "76deb8c15f37a3853f13ea2226b8f2577652de05" +git-tree-sha1 = "d0ea2c044963ed6f37703cead7e29f70cba13d7e" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "1.5.0" +version = "1.8.0" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" +[[deps.EnzymeCore]] +git-tree-sha1 = "0cdb7af5c39e92d78a0ee8d0a447d32f7593137e" +uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" +version = "0.8.8" +weakdeps = ["Adapt"] + + [deps.EnzymeCore.extensions] + AdaptExt = "Adapt" [[deps.Expat_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" +git-tree-sha1 = "e51db81749b0777b2147fbe7b783ee79045b8e99" uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.6.2+0" +version = "2.6.4+1" [[deps.FLoops]] deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] @@ -363,12 +379,19 @@ version = "0.1.1" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" +version = "1.11.0" [[deps.FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "7072f1e3e5a8be51d525d64f63d3ec1287ff2790" +deps = ["LinearAlgebra"] +git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.13.11" +version = "1.13.0" +weakdeps = ["PDMats", "SparseArrays", "Statistics"] + + [deps.FillArrays.extensions] + FillArraysPDMatsExt = "PDMats" + FillArraysSparseArraysExt = "SparseArrays" + FillArraysStatisticsExt = "Statistics" [[deps.FixedPointNumbers]] deps = ["Statistics"] @@ -377,54 +400,59 @@ uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" version = "0.8.5" [[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Functors", "LinearAlgebra", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "edacf029ed6276301e455e34d7ceeba8cc34078a" +deps = ["Adapt", "ChainRulesCore", "Compat", "EnzymeCore", "Functors", 
"LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] +git-tree-sha1 = "86729467baa309581eb0e648b9ede0aeb40016be" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.14.16" +version = "0.16.0" [deps.Flux.extensions] FluxAMDGPUExt = "AMDGPU" FluxCUDAExt = "CUDA" FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] - FluxMetalExt = "Metal" + FluxEnzymeExt = "Enzyme" + FluxMPIExt = "MPI" + FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"] [deps.Flux.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" + NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +git-tree-sha1 = "a2df1b776752e3f344e5116c06d75a10436ab853" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" +version = "0.10.38" weakdeps = ["StaticArrays"] [deps.ForwardDiff.extensions] ForwardDiffStaticArraysExt = "StaticArrays" [[deps.Functors]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "64d8e93700c7a3f28f717d265382d52fac9fa1c1" +deps = ["Compat", "ConstructionBase", "LinearAlgebra", "Random"] +git-tree-sha1 = "60a0339f28a233601cb74468032b5c302d5067de" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.12" +version = "0.5.2" [[deps.Future]] deps = ["Random"] uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" +version = "1.11.0" [[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "a74c3f1cf56a3dfcdef0605f8cdb7015926aae30" +deps = ["Adapt", "GPUArraysCore", "KernelAbstractions", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] +git-tree-sha1 = "4ec797b1b2ee964de0db96f10cce05b81f23e108" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "10.3.0" +version = "11.1.0" [[deps.GPUArraysCore]] deps = ["Adapt"] -git-tree-sha1 = "ec632f177c0d990e64d955ccc1b8c04c485a0950" +git-tree-sha1 = "83cf05ab16a73219e5f6bd1bdfa9848fa24ac627" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.6" +version = "0.2.0" [[deps.Git]] deps = ["Git_jll"] @@ -434,27 +462,27 @@ version = "1.3.1" [[deps.Git_jll]] deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "d18fb8a1f3609361ebda9bf029b60fd0f120c809" +git-tree-sha1 = "399f4a308c804b446ae4c91eeafadb2fe2c54ff9" uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb" -version = "2.44.0+2" +version = "2.47.1+0" [[deps.Graphs]] deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "ebd18c326fa6cee1efb7da9a3b45cf69da2ed4d9" +git-tree-sha1 = "1dc470db8b1131cfc7fb4c115de89fe391b9e780" uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.11.2" +version = "1.12.0" [[deps.GridGraphs]] deps = ["DataStructures", "FillArrays", "Graphs", "SparseArrays"] -git-tree-sha1 = 
"84145cbcffc84c0d60099c6ac7c989885139bf09" +git-tree-sha1 = "219584c79a649f5ae4301ff0924f235e92cc60d4" uuid = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb" -version = "0.10.0" +version = "0.10.1" [[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" +deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "b1c2585431c382e3fe5805874bda6aea90a95de9" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.23" +version = "0.3.25" [[deps.IOCapture]] deps = ["Logging", "Random"] @@ -469,7 +497,7 @@ uuid = "7869d1d1-7146-5819-86e3-90919afe41df" version = "0.4.14" [[deps.InferOpt]] -deps = ["ChainRulesCore", "DensityInterface", "DifferentiableExpectations", "Distributions", "DocStringExtensions", "LinearAlgebra", "Random", "RequiredInterfaces", "Statistics", "StatsBase", "StatsFuns", "ThreadsX"] +deps = ["ChainRulesCore", "DensityInterface", "DifferentiableExpectations", "Distributions", "LinearAlgebra", "Random", "RequiredInterfaces", "Statistics", "StatsBase", "StatsFuns", "ThreadsX"] path = ".." uuid = "4846b161-c94e-4150-8dac-c7ae193c601f" version = "0.6.1" @@ -493,16 +521,17 @@ version = "0.3.1" [[deps.InteractiveUtils]] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +version = "1.11.0" [[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "18c59411ece4838b18cd7f537e56cf5e41ce5bfd" +git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb" uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.15" -weakdeps = ["Dates"] +version = "0.1.17" +weakdeps = ["Dates", "Test"] [deps.InverseFunctions.extensions] - DatesExt = "Dates" + InverseFunctionsDatesExt = "Dates" + InverseFunctionsTestExt = "Test" [[deps.IrrationalConstants]] git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" @@ -516,9 +545,9 @@ version = "1.0.0" [[deps.JLLWrappers]] deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b" uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" +version = "1.6.1" [[deps.JSON]] deps = ["Dates", "Mmap", "Parsers", "Unicode"] @@ -528,9 +557,9 @@ version = "0.21.4" [[deps.JuliaInterpreter]] deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] -git-tree-sha1 = "5d3a5a206297af3868151bb4a2cf27ebce46f16d" +git-tree-sha1 = "10da5154188682e5c0726823c2b5125957ec3778" uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" -version = "0.9.33" +version = "0.9.38" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -539,22 +568,22 @@ uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" version = "0.2.4" [[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "d0448cebd5919e06ca5edc7a264631790de810ec" +deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs"] +git-tree-sha1 = "b9a838cd3028785ac23822cded5126b3da394d1a" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.22" +version = "0.9.31" +weakdeps = ["EnzymeCore", "LinearAlgebra", "SparseArrays"] [deps.KernelAbstractions.extensions] EnzymeExt = "EnzymeCore" - - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + LinearAlgebraExt = "LinearAlgebra" + SparseArraysExt = "SparseArrays" 
[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Requires", "Unicode"] -git-tree-sha1 = "020abd49586480c1be84f57da0017b5d3db73f7c" +deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Unicode"] +git-tree-sha1 = "d422dfd9707bec6617335dc2ea3c5172a87d5908" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "8.0.0" +version = "9.1.3" [deps.LLVM.extensions] BFloat16sExt = "BFloat16s" @@ -564,18 +593,19 @@ version = "8.0.0" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c2636c264861edc6d305e6b4d528f09566d24c5e" +git-tree-sha1 = "05a8bd5a42309a9ec82f700876903abce1017dd3" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.30+0" +version = "0.0.34+0" [[deps.LazilyInitializedFields]] -git-tree-sha1 = "8f7f3cabab0fd1800699663533b6d5cb3fc0e612" +git-tree-sha1 = "0f2da712350b020bc3957f269c9caad516383ee0" uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf" -version = "1.2.2" +version = "1.3.0" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" +version = "1.11.0" [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] @@ -585,16 +615,17 @@ version = "0.6.4" [[deps.LibCURL_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" +version = "8.6.0+0" [[deps.LibGit2]] deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +version = "1.11.0" [[deps.LibGit2_jll]] deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" +version = "1.7.2+0" [[deps.LibSSH2_jll]] deps = ["Artifacts", "Libdl", "MbedTLS_jll"] @@ -603,28 +634,30 @@ version = "1.11.0+1" [[deps.Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +version = "1.11.0" [[deps.Libiconv_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" +git-tree-sha1 = "61dfdba58e585066d8bce214c5a51eaa0539f269" uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.17.0+0" +version = "1.17.0+1" [[deps.LinearAlgebra]] deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +version = "1.11.0" [[deps.Literate]] deps = ["Base64", "IOCapture", "JSON", "REPL"] -git-tree-sha1 = "eef2e1fc1dc38af90a18eb16e519e06d1fd10c2a" +git-tree-sha1 = "da046be6d63304f7ba9c1bb04820fb306ba1ab12" uuid = "98b081ad-f1c9-55d3-8b20-4c87d4299306" -version = "2.19.0" +version = "2.20.1" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "a2d09619db4e765091ee5c6ffe8872849de0feea" +git-tree-sha1 = "13ca9e2586b89836fd20cccf56e57e2b9ae7f38f" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.28" +version = "0.3.29" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -638,12 +671,59 @@ version = "0.3.28" [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" +version = "1.11.0" [[deps.LoweredCodeUtils]] deps = ["JuliaInterpreter"] -git-tree-sha1 = "1ce1834f9644a8f7c011eb0592b7fd6c42c90653" +git-tree-sha1 = "688d6d9e098109051ae33d126fcfc88c4ce4a021" uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" -version = "3.0.1" +version = "3.1.0" + +[[deps.MLDataDevices]] +deps = ["Adapt", "Compat", "Functors", "Preferences", "Random"] +git-tree-sha1 = 
"80eb04ae663507d9303473d26710a4c62efa0f3c" +uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" +version = "1.6.5" + + [deps.MLDataDevices.extensions] + MLDataDevicesAMDGPUExt = "AMDGPU" + MLDataDevicesCUDAExt = "CUDA" + MLDataDevicesChainRulesCoreExt = "ChainRulesCore" + MLDataDevicesChainRulesExt = "ChainRules" + MLDataDevicesComponentArraysExt = "ComponentArrays" + MLDataDevicesFillArraysExt = "FillArrays" + MLDataDevicesGPUArraysExt = "GPUArrays" + MLDataDevicesMLUtilsExt = "MLUtils" + MLDataDevicesMetalExt = ["GPUArrays", "Metal"] + MLDataDevicesOneHotArraysExt = "OneHotArrays" + MLDataDevicesReactantExt = "Reactant" + MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" + MLDataDevicesReverseDiffExt = "ReverseDiff" + MLDataDevicesSparseArraysExt = "SparseArrays" + MLDataDevicesTrackerExt = "Tracker" + MLDataDevicesZygoteExt = "Zygote" + MLDataDevicescuDNNExt = ["CUDA", "cuDNN"] + MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"] + + [deps.MLDataDevices.weakdeps] + AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" + CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" + ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" + FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" + GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" + MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" + Metal = "dde4c033-4e86-420c-a63e-0dd931031962" + OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" + Reactant = "3c362404-f566-11ee-1572-e11a4b42c853" + RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" + ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" + SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" + oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [[deps.MLStyle]] git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8" @@ -664,13 +744,14 @@ version = "0.5.13" [[deps.MarchingCubes]] deps = ["PrecompileTools", "StaticArrays"] -git-tree-sha1 = "27d162f37cc29de047b527dab11a826dd3a650ad" +git-tree-sha1 = "301345b808264ae42e60d10a519e55c5d992969b" uuid = "299715c1-40a9-479a-aaf9-4a633d36f717" -version = "0.1.9" +version = "0.1.10" [[deps.Markdown]] deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +version = "1.11.0" [[deps.MarkdownAST]] deps = ["AbstractTrees", "Markdown"] @@ -681,7 +762,7 @@ version = "0.1.2" [[deps.MbedTLS_jll]] deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" +version = "2.28.6+0" [[deps.MicroCollections]] deps = ["Accessors", "BangBang", "InitialValues"] @@ -697,16 +778,17 @@ version = "1.2.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +version = "1.11.0" [[deps.MozillaCACerts_jll]] uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" +version = "2023.12.12" [[deps.NNlib]] -deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Pkg", "Random", "Requires", "Statistics"] -git-tree-sha1 = "190dcada8cf9520198058c4544862b1f88c6c577" +deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "1177f161cda2083543b9967d7ca2a3e24e721e13" uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.9.21" +version = "0.9.26" [deps.NNlib.extensions] NNlibAMDGPUExt = "AMDGPU" @@ -714,12 +796,14 @@ 
version = "0.9.21" NNlibCUDAExt = "CUDA" NNlibEnzymeCoreExt = "EnzymeCore" NNlibFFTWExt = "FFTW" + NNlibForwardDiffExt = "ForwardDiff" [deps.NNlib.weakdeps] AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" + ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" [[deps.NaNMath]] @@ -746,14 +830,14 @@ version = "0.5.3" [[deps.OneHotArrays]] deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = "963a3f28a2e65bb87a68033ea4a616002406037d" +git-tree-sha1 = "c8c7f6bfabe581dc40b580313a75f1ecce087e27" uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.5" +version = "0.2.6" [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.23+4" +version = "0.3.27+1" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -762,9 +846,9 @@ version = "0.8.1+2" [[deps.OpenSSL_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" +git-tree-sha1 = "7493f61f55a6cce7325f197443aa80d32554ba10" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" +version = "3.0.15+1" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -774,14 +858,19 @@ version = "0.5.5+0" [[deps.Optimisers]] deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "6572fe0c5b74431aaeb0b18a4aa5ef03c84678be" +git-tree-sha1 = "c5feff34a5cf6bdc6ca06de0c5b7d6847199f1c0" uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.3.3" +version = "0.4.2" +weakdeps = ["Adapt", "EnzymeCore"] + + [deps.Optimisers.extensions] + OptimisersAdaptExt = ["Adapt"] + OptimisersEnzymeCoreExt = "EnzymeCore" [[deps.OrderedCollections]] -git-tree-sha1 = "dfdf5519f235516220579f949664f1bf44e741c5" +git-tree-sha1 = "12f1439c4f986bb868acda6ea33ebc78e19b95ad" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.3" +version = "1.7.0" [[deps.PCRE2_jll]] deps = ["Artifacts", "Libdl"] @@ -801,9 +890,13 @@ uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "2.8.1" [[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" +version = "1.11.0" +weakdeps = ["REPL"] + + [deps.Pkg.extensions] + REPLExt = "REPL" [[deps.PrecompileTools]] deps = ["Preferences"] @@ -825,6 +918,7 @@ version = "0.2.0" [[deps.Printf]] deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +version = "1.11.0" [[deps.ProgressLogging]] deps = ["Logging", "SHA", "UUIDs"] @@ -839,23 +933,31 @@ uuid = "92933f4c-e287-5a05-a399-4b506db050ca" version = "1.10.2" [[deps.PtrArrays]] -git-tree-sha1 = "f011fbb92c4d401059b2212c05c0601b70f8b759" +git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f" uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" -version = "1.2.0" +version = "1.2.1" [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "e237232771fdafbae3db5c31275303e056afaa9f" +git-tree-sha1 = 
"cda3b045cf9ef07a08ad46731f5a3165e56cf3da" uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.10.1" +version = "2.11.1" + + [deps.QuadGK.extensions] + QuadGKEnzymeExt = "Enzyme" + + [deps.QuadGK.weakdeps] + Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +deps = ["InteractiveUtils", "Markdown", "Sockets", "StyledStrings", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" +version = "1.11.0" [[deps.Random]] deps = ["SHA"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +version = "1.11.0" [[deps.RealDot]] deps = ["LinearAlgebra"] @@ -882,9 +984,9 @@ version = "0.1.0" [[deps.RequiredInterfaces]] deps = ["InteractiveUtils", "Logging", "Test"] -git-tree-sha1 = "c3250333ea2894237ed015baf7d5fcb8a1ea3169" +git-tree-sha1 = "f4e7fec4fa52d0919f18fec552d2fabf9e94811d" uuid = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6" -version = "0.1.6" +version = "0.1.7" [[deps.Requires]] deps = ["UUIDs"] @@ -894,21 +996,21 @@ version = "1.3.0" [[deps.Revise]] deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "REPL", "Requires", "UUIDs", "Unicode"] -git-tree-sha1 = "7b7850bb94f75762d567834d7e9802fc22d62f9c" +git-tree-sha1 = "470f48c9c4ea2170fd4d0f8eb5118327aada22f5" uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" -version = "3.5.18" +version = "3.6.4" [[deps.Rmath]] deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4" uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" +version = "0.8.0" [[deps.Rmath_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "d483cd324ce5cf5d61b77930f0bbd6cb61927d21" +git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.2+0" +version = "0.5.1+0" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" @@ -916,6 +1018,7 @@ version = "0.7.0" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +version = "1.11.0" [[deps.Setfield]] deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] @@ -926,6 +1029,7 @@ version = "1.1.1" [[deps.SharedArrays]] deps = ["Distributed", "Mmap", "Random", "Serialization"] uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" +version = "1.11.0" [[deps.ShowCases]] git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" @@ -940,6 +1044,7 @@ version = "0.9.4" [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +version = "1.11.0" [[deps.SortingAlgorithms]] deps = ["DataStructures"] @@ -950,7 +1055,7 @@ version = "1.2.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.10.0" +version = "1.11.0" [[deps.SparseInverseSubset]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] @@ -960,9 +1065,9 @@ version = "0.1.2" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "2f5d4697f21388cbe1ff299430dd169ef97d7e14" +git-tree-sha1 = "64cca0c26b4f31ba18f13f6c12af7c85f478cfde" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.4.0" +version = "2.5.0" weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] @@ -981,9 +1086,9 @@ version = "0.1.5" [[deps.StaticArrays]] deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = 
"eeafab08ae20c62c44c8399ccb9354a04b80db50" +git-tree-sha1 = "7c01731da8ab6d3094c4d44c9057b00932459255" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.7" +version = "1.9.9" weakdeps = ["ChainRulesCore", "Statistics"] [deps.StaticArrays.extensions] @@ -996,9 +1101,14 @@ uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" version = "1.4.3" [[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] +deps = ["LinearAlgebra"] +git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.10.0" +version = "1.11.1" +weakdeps = ["SparseArrays"] + + [deps.Statistics.extensions] + SparseArraysExt = ["SparseArrays"] [[deps.StatsAPI]] deps = ["LinearAlgebra"] @@ -1007,16 +1117,16 @@ uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" version = "1.7.0" [[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "5cf7606d6cef84b543b483848d4ae08ad9832b21" +deps = ["AliasTables", "DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "29321314c920c26684834965ec2ce0dacc9cf8e5" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.3" +version = "0.34.4" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "cef0472124fab0695b58ca35a77c6fb942fdab8a" +git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46" uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.1" +version = "1.3.2" weakdeps = ["ChainRulesCore", "InverseFunctions"] [deps.StatsFuns.extensions] @@ -1025,17 +1135,22 @@ weakdeps = ["ChainRulesCore", "InverseFunctions"] [[deps.StructArrays]] deps = ["ConstructionBase", "DataAPI", "Tables"] -git-tree-sha1 = "f4dc295e983502292c4c3f951dbb4e985e35b3be" +git-tree-sha1 = "9537ef82c42cdd8c5d443cbc359110cbb36bae10" uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.18" -weakdeps = ["Adapt", "GPUArraysCore", "SparseArrays", "StaticArrays"] +version = "0.6.21" +weakdeps = ["Adapt", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "SparseArrays", "StaticArrays"] [deps.StructArrays.extensions] StructArraysAdaptExt = "Adapt" - StructArraysGPUArraysCoreExt = "GPUArraysCore" + StructArraysGPUArraysCoreExt = ["GPUArraysCore", "KernelAbstractions"] + StructArraysLinearAlgebraExt = "LinearAlgebra" StructArraysSparseArraysExt = "SparseArrays" StructArraysStaticArraysExt = "StaticArrays" +[[deps.StyledStrings]] +uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b" +version = "1.11.0" + [[deps.SuiteSparse]] deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" @@ -1043,7 +1158,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.2.1+1" +version = "7.7.0+0" [[deps.TOML]] deps = ["Dates"] @@ -1068,9 +1183,9 @@ uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" version = "1.10.0" [[deps.TaskLocalValues]] -git-tree-sha1 = "eb0b8d147eb907a9ad3fd952da7c6a053b29ae28" +git-tree-sha1 = "d155450e6dff2a8bc2fcb81dcb194bd98b0aeb46" uuid = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34" -version = "0.1.1" +version = "0.1.2" [[deps.TensorCore]] deps = ["LinearAlgebra"] @@ -1081,6 +1196,7 @@ version = 
"0.1.1" [[deps.Test]] deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +version = "1.11.0" [[deps.TestItems]] git-tree-sha1 = "42fd9023fef18b9b78c8343a4e2f3813ffbcefcb" @@ -1094,21 +1210,18 @@ uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" version = "0.1.12" [[deps.TranscodingStreams]] -git-tree-sha1 = "96612ac5365777520c3c5396314c8cf7408f436a" +git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.1" -weakdeps = ["Random", "Test"] - - [deps.TranscodingStreams.extensions] - TestExt = ["Test", "Random"] +version = "0.11.3" [[deps.Transducers]] -deps = ["Accessors", "Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "5215a069867476fc8e3469602006b9670e68da23" +deps = ["Accessors", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] +git-tree-sha1 = "7deeab4ff96b85c5f72c824cae53a1398da3d1cb" uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.82" +version = "0.4.84" [deps.Transducers.extensions] + TransducersAdaptExt = "Adapt" TransducersBlockArraysExt = "BlockArrays" TransducersDataFramesExt = "DataFrames" TransducersLazyArraysExt = "LazyArrays" @@ -1116,6 +1229,7 @@ version = "0.4.82" TransducersReferenceablesExt = "Referenceables" [deps.Transducers.weakdeps] + Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" @@ -1125,15 +1239,17 @@ version = "0.4.82" [[deps.UUIDs]] deps = ["Random", "SHA"] uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +version = "1.11.0" [[deps.Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" +version = "1.11.0" [[deps.UnicodePlots]] -deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "PrecompileTools", "Printf", "Requires", "SparseArrays", "StaticArrays", "StatsBase"] -git-tree-sha1 = "30646456e889c18fb3c23e58b2fc5da23644f752" +deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "PrecompileTools", "Printf", "SparseArrays", "StaticArrays", "StatsBase"] +git-tree-sha1 = "f18128aa9e5cf059426a91bdc750b1f63a2fdcd9" uuid = "b8865327-cd53-5732-bb35-84acbb429228" -version = "3.6.4" +version = "3.7.1" [deps.UnicodePlots.extensions] FreeTypeExt = ["FileIO", "FreeType"] @@ -1151,15 +1267,13 @@ version = "3.6.4" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" +git-tree-sha1 = "b13c4edda90890e5b04ba24e20a310fbe6f249ff" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" +version = "0.3.0" +weakdeps = ["LLVM"] -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "bf2c553f25e954a9b38c9c0593a59bb13113f9e5" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.5" + [deps.UnsafeAtomics.extensions] + UnsafeAtomicsLLVM = ["LLVM"] [[deps.Zlib_jll]] deps = ["Libdl"] @@ -1168,9 +1282,9 @@ version = "1.2.13+1" [[deps.Zygote]] deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", 
"FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "19c586905e78a26f7e4e97f81716057bd6b1bc54" +git-tree-sha1 = "c7dc3148a64d1cd3768c29b3db5972d1c302661b" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.70" +version = "0.6.73" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -1191,12 +1305,12 @@ version = "0.2.5" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+1" +version = "5.11.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" +version = "1.59.0+0" [[deps.p7zip_jll]] deps = ["Artifacts", "Libdl"] diff --git a/src/InferOpt.jl b/src/InferOpt.jl index df851ea..4ca32d3 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -89,7 +89,7 @@ export SoftRank, soft_rank, soft_rank_l2, soft_rank_kl export SoftSort, soft_sort, soft_sort_l2, soft_sort_kl export RegularizedFrankWolfe -export Perturbed +export PerturbedOracle export PerturbedAdditive export PerturbedMultiplicative export LinearPerturbed diff --git a/src/layers/perturbed/perturbed.jl b/src/layers/perturbed/perturbed.jl index 06dafb0..b801118 100644 --- a/src/layers/perturbed/perturbed.jl +++ b/src/layers/perturbed/perturbed.jl @@ -1,5 +1,5 @@ """ - Perturbed{D,F} <: AbstractOptimizationLayer + PerturbedOracle{D,F} <: AbstractOptimizationLayer Differentiable perturbation of a black box optimizer of type `F`, with perturbation of type `D`. @@ -10,23 +10,25 @@ There are three different available constructors that behave differently in the - [`PerturbedAdditive`](@ref) - [`PerturbedMultiplicative`](@ref) """ -struct Perturbed{D,F,t,variance_reduction,G,R,S} <: AbstractOptimizationLayer +struct PerturbedOracle{D,F,t,variance_reduction,G,R,S} <: AbstractOptimizationLayer reinforce::Reinforce{t,variance_reduction,F,D,G,R,S} end -function (perturbed::Perturbed)(θ::AbstractArray; kwargs...) +function (perturbed::PerturbedOracle)(θ::AbstractArray; kwargs...) return perturbed.reinforce(θ; kwargs...) end -function get_maximizer(perturbed::Perturbed) +function get_maximizer(perturbed::PerturbedOracle) return perturbed.reinforce.f end -function compute_probability_distribution(perturbed::Perturbed, θ::AbstractArray; kwargs...) +function compute_probability_distribution( + perturbed::PerturbedOracle, θ::AbstractArray; kwargs... +) return empirical_distribution(perturbed.reinforce, θ; kwargs...) 
end -function Base.show(io::IO, perturbed::Perturbed{<:AbstractPerturbation}) +function Base.show(io::IO, perturbed::PerturbedOracle{<:AbstractPerturbation}) (; reinforce) = perturbed nb_samples = reinforce.nb_samples ε = reinforce.dist_constructor.ε @@ -36,24 +38,36 @@ function Base.show(io::IO, perturbed::Perturbed{<:AbstractPerturbation}) f = reinforce.f return print( io, - "Perturbed($f, ε=$ε, nb_samples=$nb_samples, perturbation=$perturbation, rng=$(typeof(rng)), seed=$seed)", + "PerturbedOracle($f, ε=$ε, nb_samples=$nb_samples, perturbation=$perturbation, rng=$(typeof(rng)), seed=$seed)", ) end """ doc """ -function LinearPerturbed( +function PerturbedOracle( maximizer, - dist_constructor, - dist_logdensity_grad=nothing; - g=nothing, - h=nothing, + dist_constructor; + dist_logdensity_grad=nothing, + nb_samples=1, + variance_reduction=true, + threaded=false, + seed=nothing, + rng=Random.default_rng(), kwargs..., ) - linear_maximizer = LinearMaximizer(; maximizer, g, h) - return Perturbed( - Reinforce(linear_maximizer, dist_constructor, dist_logdensity_grad; kwargs...) + return PerturbedOracle( + Reinforce( + maximizer, + dist_constructor, + dist_logdensity_grad; + nb_samples, + variance_reduction, + threaded, + seed, + rng, + kwargs..., + ), ) end @@ -69,8 +83,6 @@ function PerturbedAdditive( seed=nothing, threaded=false, rng=Random.default_rng(), - g=identity_kw, - h=zero ∘ eltype_kw, dist_logdensity_grad=if (perturbation_dist == Normal(0, 1)) (η, θ) -> ((η .- θ) ./ ε^2,) else @@ -78,17 +90,15 @@ function PerturbedAdditive( end, ) dist_constructor = AdditivePerturbation(perturbation_dist, float(ε)) - return LinearPerturbed( + return PerturbedOracle( maximizer, - dist_constructor, - dist_logdensity_grad; + dist_constructor; + dist_logdensity_grad, nb_samples, variance_reduction, seed, threaded, rng, - g, - h, ) end @@ -104,8 +114,6 @@ function PerturbedMultiplicative( seed=nothing, threaded=false, rng=Random.default_rng(), - g=identity_kw, - h=zero ∘ eltype_kw, dist_logdensity_grad=if (perturbation_dist == Normal(0, 1)) (η, θ) -> (inv.(ε^2 .* θ) .* (η .- θ),) else @@ -113,16 +121,14 @@ function PerturbedMultiplicative( end, ) dist_constructor = MultiplicativePerturbation(perturbation_dist, float(ε)) - return LinearPerturbed( + return PerturbedOracle( maximizer, - dist_constructor, - dist_logdensity_grad; + dist_constructor; + dist_logdensity_grad, nb_samples, variance_reduction, seed, threaded, rng, - g, - h, ) end diff --git a/src/losses/fenchel_young_loss.jl b/src/losses/fenchel_young_loss.jl index 3343170..683f1b2 100644 --- a/src/losses/fenchel_young_loss.jl +++ b/src/losses/fenchel_young_loss.jl @@ -59,7 +59,10 @@ function fenchel_young_loss_and_grad( end function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{<:Perturbed}, θ::AbstractArray, y_true::AbstractArray; kwargs... + fyl::FenchelYoungLoss{<:PerturbedOracle}, + θ::AbstractArray, + y_true::AbstractArray; + kwargs..., ) (; optimization_layer) = fyl maximizer = get_maximizer(optimization_layer) @@ -82,7 +85,7 @@ end ## Specific overrides for perturbed layers function fenchel_young_F_and_first_part_of_grad( - perturbed::Perturbed{<:AdditivePerturbation}, θ::AbstractArray; kwargs... + perturbed::PerturbedOracle{<:AdditivePerturbation}, θ::AbstractArray; kwargs... 
) (; reinforce) = perturbed maximizer = get_maximizer(perturbed) @@ -98,7 +101,7 @@ function fenchel_young_F_and_first_part_of_grad( end function fenchel_young_F_and_first_part_of_grad( - perturbed::Perturbed{<:MultiplicativePerturbation}, θ::AbstractArray; kwargs... + perturbed::PerturbedOracle{<:MultiplicativePerturbation}, θ::AbstractArray; kwargs... ) (; reinforce) = perturbed maximizer = get_maximizer(perturbed) diff --git a/src/utils/linear_maximizer.jl b/src/utils/linear_maximizer.jl index c2136c0..684edc2 100644 --- a/src/utils/linear_maximizer.jl +++ b/src/utils/linear_maximizer.jl @@ -18,6 +18,10 @@ function Base.show(io::IO, f::LinearMaximizer) return print(io, "LinearMaximizer($maximizer, $g, $h)") end +function LinearMaximizer(maximizer; g=identity_kw, h=zero ∘ eltype_kw) + return LinearMaximizer(maximizer, g, h) +end + # Callable calls the wrapped maximizer function (f::LinearMaximizer)(θ::AbstractArray; kwargs...) return f.maximizer(θ; kwargs...) diff --git a/test/learning_generalized_maximizer.jl b/test/learning_generalized_maximizer.jl index 1d34cde..3b37ba8 100644 --- a/test/learning_generalized_maximizer.jl +++ b/test/learning_generalized_maximizer.jl @@ -13,7 +13,7 @@ @test y == [1 0 1; 0 1 0; 1 1 1] - generalized_maximizer = LinearMaximizer(; maximizer=max_pricing, g, h) + generalized_maximizer = LinearMaximizer(max_pricing; g, h) @test generalized_maximizer(θ; instance) == y @@ -29,8 +29,8 @@ end true_encoder = encoder_factory() - perturbed = PerturbedAdditive(max_pricing; ε=1.0, nb_samples=10, g, h) - maximizer = InferOpt.get_maximizer(perturbed) + maximizer = LinearMaximizer(max_pricing; g, h) + perturbed = PerturbedAdditive(maximizer; ε=1.0, nb_samples=10) function cost(y; instance) return -objective_value(maximizer, true_encoder(instance), y; instance) end @@ -55,8 +55,8 @@ end true_encoder = encoder_factory() - perturbed = PerturbedMultiplicative(max_pricing; ε=1.0, nb_samples=10, g, h) - maximizer = InferOpt.get_maximizer(perturbed) + maximizer = LinearMaximizer(max_pricing; g, h) + perturbed = PerturbedMultiplicative(maximizer; ε=1.0, nb_samples=10) function cost(y; instance) return -objective_value(maximizer, true_encoder(instance), y; instance) end @@ -80,8 +80,10 @@ end true_encoder = encoder_factory() - perturbed = PerturbedAdditive(max_pricing; ε=1.0, nb_samples=10, g, h) - maximizer = InferOpt.get_maximizer(perturbed) + maximizer = LinearMaximizer(max_pricing; g, h) + @info maximizer g h + perturbed = PerturbedAdditive(maximizer; ε=1.0, nb_samples=10) + @info perturbed function cost(y; instance) return -objective_value(maximizer, true_encoder(instance), y; instance) end @@ -106,8 +108,8 @@ end true_encoder = encoder_factory() - perturbed = PerturbedMultiplicative(max_pricing; ε=0.1, nb_samples=10, g, h) - maximizer = InferOpt.get_maximizer(perturbed) + maximizer = LinearMaximizer(max_pricing; g, h) + perturbed = PerturbedMultiplicative(maximizer; ε=0.1, nb_samples=10) function cost(y; instance) return -objective_value(maximizer, true_encoder(instance), y; instance) end @@ -180,7 +182,7 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + generalized_maximizer = LinearMaximizer(max_pricing; g, h) function cost(y; instance) return -objective_value(generalized_maximizer, true_encoder(instance), y; instance) end @@ -207,7 +209,7 @@ end true_encoder = encoder_factory() - generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h) + generalized_maximizer = LinearMaximizer(max_pricing; g, h) function 
cost(y; instance)
        return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
    end

diff --git a/test/paths.jl b/test/paths.jl
index 0681e8a..fb6a32f 100644
--- a/test/paths.jl
+++ b/test/paths.jl
@@ -155,7 +155,10 @@ end
         maximizer=identity_kw,
         loss=FenchelYoungLoss(
             PerturbedAdditive(
-                shortest_path_maximizer; ε=1.0, nb_samples=5, perturbation=LogNormal(0, 1)
+                shortest_path_maximizer;
+                ε=1.0,
+                nb_samples=5,
+                perturbation_dist=LogNormal(0, 1),
             ),
         ),
         error_function=mse_kw,
diff --git a/test/perturbed.jl b/test/perturbed.jl
index a83aa69..61d5399 100644
--- a/test/perturbed.jl
+++ b/test/perturbed.jl
@@ -50,7 +50,7 @@ end
     p(θ) = MvNormal(θ, ε^2 * I)
     oracle(η) = η

-    po = PerturbedOracle(oracle, p; nb_samples=1_000, seed=0) # TODO: fix this
+    po = PerturbedOracle(oracle, p; nb_samples=1_000, seed=0)
     pa = PerturbedAdditive(oracle; ε, nb_samples=1_000, seed=0)

     θ = randn(10)

From 9469bf0fa9cef53c855d665657d32b82deefda97 Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Mon, 23 Dec 2024 15:46:42 +0100
Subject: [PATCH 13/26] fix exports

---
 src/InferOpt.jl | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/InferOpt.jl b/src/InferOpt.jl
index 4ca32d3..6886147 100644
--- a/src/InferOpt.jl
+++ b/src/InferOpt.jl
@@ -92,7 +92,6 @@ export RegularizedFrankWolfe
 export PerturbedOracle
 export PerturbedAdditive
 export PerturbedMultiplicative
-export LinearPerturbed

 export FenchelYoungLoss
 export StructuredSVMLoss

From f0fa8cc211bcce0fce929219106f319dae849e0e Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Mon, 23 Dec 2024 15:49:48 +0100
Subject: [PATCH 14/26] fix docstring

---
 src/layers/perturbed/perturbed.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/layers/perturbed/perturbed.jl b/src/layers/perturbed/perturbed.jl
index b801118..6fff4d9 100644
--- a/src/layers/perturbed/perturbed.jl
+++ b/src/layers/perturbed/perturbed.jl
@@ -6,7 +6,7 @@ Differentiable perturbation of a black box optimizer of type `F`, with perturbat
 This struct is a wrapper around `Reinforce` from DifferentiableExpectations.jl.
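For intuition, here is a minimal usage sketch of the layer documented above (illustrative only: it assumes the constructors introduced in these patches, and uses `one_hot_argmax`, which InferOpt exports, as a stand-in maximizer):

```julia
using InferOpt: PerturbedAdditive, one_hot_argmax

# The forward pass averages the wrapped maximizer's outputs over `nb_samples`
# Gaussian perturbations θ + εZ, giving a smooth, differentiable surrogate.
layer = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=100, seed=0)
θ = [1.0, 0.5, -0.2]
ŷ = layer(θ)  # entries in [0, 1], concentrated on argmax(θ)
```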
There are three different available constructors that behave differently in the package:
-- [`LinearPerturbed`](@ref)
+- [`PerturbedOracle`](@ref)
 - [`PerturbedAdditive`](@ref)
 - [`PerturbedMultiplicative`](@ref)
 """

From a17c51fdb33444d38f7dfb174834a76096df770f Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Mon, 23 Dec 2024 16:28:56 +0100
Subject: [PATCH 15/26] Fix tests and add DocStringExtensions as dependency

---
 Project.toml                         |  2 ++
 src/InferOpt.jl                      |  1 +
 src/layers/perturbed/perturbation.jl | 43 ++++++++++++++++++++++++++--
 src/layers/perturbed/perturbed.jl    | 19 +++++++++---
 src/layers/perturbed/utils.jl        | 14 +++++++++
 src/losses/fenchel_young_loss.jl     |  9 +++---
 src/utils/linear_maximizer.jl        |  4 +++
 test/learning_ranking.jl             |  2 +-
 8 files changed, 83 insertions(+), 11 deletions(-)

diff --git a/Project.toml b/Project.toml
index a642bb5..b1c0058 100644
--- a/Project.toml
+++ b/Project.toml
@@ -9,6 +9,7 @@ DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
 DifferentiableExpectations = "fc55d66b-b2a8-4ccc-9d64-c0c2166ceb36"
 DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
@@ -29,6 +30,7 @@ DensityInterface = "0.4.0"
 DifferentiableExpectations = "0.2"
 DifferentiableFrankWolfe = "0.2"
 Distributions = "0.25"
+DocStringExtensions = "0.9.3"
 LinearAlgebra = "<0.0.1,1"
 Random = "<0.0.1,1"
 RequiredInterfaces = "0.1.3"
diff --git a/src/InferOpt.jl b/src/InferOpt.jl
index 6886147..883aecb 100644
--- a/src/InferOpt.jl
+++ b/src/InferOpt.jl
@@ -23,6 +23,7 @@ using Distributions:
     Normal,
     product_distribution,
     logpdf
+using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
 using LinearAlgebra: dot
 using Random: Random, AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed!
 using Statistics: mean
diff --git a/src/layers/perturbed/perturbation.jl b/src/layers/perturbed/perturbation.jl
index 831cb27..df64426 100644
--- a/src/layers/perturbed/perturbation.jl
+++ b/src/layers/perturbed/perturbation.jl
@@ -1,16 +1,43 @@
+"""
+$TYPEDEF
+
+Abstract type for a perturbation.
+It is a callable that takes a parameter `θ` and returns the distribution of a perturbed parameter, built from `perturbation_dist`.
+
+All subtypes should have a `perturbation_dist` field.
+
+# Existing implementations
+- [`AdditivePerturbation`](@ref)
+- [`MultiplicativePerturbation`](@ref)
+"""
 abstract type AbstractPerturbation <: ContinuousUnivariateDistribution end

+"""
+$TYPEDSIGNATURES
+"""
 function Random.rand(rng::AbstractRNG, perturbation::AbstractPerturbation)
     return rand(rng, perturbation.perturbation_dist)
 end

+"""
+$TYPEDEF
+
+Additive perturbation: θ ↦ θ + εZ, where Z is a random variable following `perturbation_dist`.
+
+# Fields
+$TYPEDFIELDS
+"""
 struct AdditivePerturbation{F}
+    "base distribution for the perturbation"
     perturbation_dist::F
+    "perturbation size"
     ε::Float64
 end

 """
-θ + εZ
+$TYPEDSIGNATURES
+
+Apply the additive perturbation to the parameter `θ`.
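To see what the additive perturbation produces, a standalone sketch (hypothetical: `AdditivePerturbation` is an internal type, hence the qualified name):

```julia
using Distributions: Normal
using Random: MersenneTwister
using InferOpt

# θ ↦ θ + εZ: calling the perturbation on θ returns a product distribution
# whose samples are perturbed copies of θ.
pdc = InferOpt.AdditivePerturbation(Normal(0.0, 1.0), 0.5)  # ε = 0.5
θ = [2.0, -1.0]
dist = pdc(θ)                       # distribution of θ .+ 0.5 .* Z
η = rand(MersenneTwister(0), dist)  # one perturbed sample of θ
```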
""" function (pdc::AdditivePerturbation)(θ::AbstractArray) (; perturbation_dist, ε) = pdc @@ -18,13 +45,25 @@ function (pdc::AdditivePerturbation)(θ::AbstractArray) end """ -θ ⊙ exp(εZ - ε²/2) +$TYPEDEF + +Multiplicative perturbation: θ ↦ θ ⊙ exp(εZ - ε²/2) + +# Fields +$TYPEDFIELDS """ struct MultiplicativePerturbation{F} + "base distribution for the perturbation" perturbation_dist::F + "perturbation size" ε::Float64 end +""" +$TYPEDSIGNATURES + +Apply the multiplicative perturbation to the parameter `θ`. +""" function (pdc::MultiplicativePerturbation)(θ::AbstractArray) (; perturbation_dist, ε) = pdc return product_distribution(θ .* ExponentialOf(ε * perturbation_dist - ε^2 / 2)) diff --git a/src/layers/perturbed/perturbed.jl b/src/layers/perturbed/perturbed.jl index 6fff4d9..1fd4702 100644 --- a/src/layers/perturbed/perturbed.jl +++ b/src/layers/perturbed/perturbed.jl @@ -1,5 +1,5 @@ """ - PerturbedOracle{D,F} <: AbstractOptimizationLayer +$TYPEDEF Differentiable perturbation of a black box optimizer of type `F`, with perturbation of type `D`. @@ -14,6 +14,11 @@ struct PerturbedOracle{D,F,t,variance_reduction,G,R,S} <: AbstractOptimizationLa reinforce::Reinforce{t,variance_reduction,F,D,G,R,S} end +""" +$TYPEDSIGNATURES + +Forward pass of the perturbed optimizer. +""" function (perturbed::PerturbedOracle)(θ::AbstractArray; kwargs...) return perturbed.reinforce(θ; kwargs...) end @@ -43,7 +48,9 @@ function Base.show(io::IO, perturbed::PerturbedOracle{<:AbstractPerturbation}) end """ -doc +$TYPEDSIGNATURES + +Constructor for [`PerturbedOracle`](@ref). """ function PerturbedOracle( maximizer, @@ -72,7 +79,9 @@ function PerturbedOracle( end """ - PerturbedAdditive(maximizer; kwargs...) +$TYPEDSIGNATURES + +Constructor for [`PerturbedOracle`](@ref) with an additive perturbation. """ function PerturbedAdditive( maximizer; @@ -103,7 +112,9 @@ function PerturbedAdditive( end """ -doc +$TYPEDSIGNATURES + +Constructor for [`PerturbedOracle`](@ref) with a multiplicative perturbation. """ function PerturbedMultiplicative( maximizer; diff --git a/src/layers/perturbed/utils.jl b/src/layers/perturbed/utils.jl index 3f6dc38..509f080 100644 --- a/src/layers/perturbed/utils.jl +++ b/src/layers/perturbed/utils.jl @@ -1,12 +1,26 @@ +""" +$TYPEDSIGNATURES + +Data structure modeling the exponential of a continuous univariate random variable. +""" struct ExponentialOf{D<:ContinuousUnivariateDistribution} <: ContinuousUnivariateDistribution dist::D end +""" +$TYPEDSIGNATURES +""" function Random.rand(rng::AbstractRNG, d::ExponentialOf) return exp(rand(rng, d.dist)) end +""" +$TYPEDSIGNATURES + +Return the log-density of the [`ExponentialOf`](@ref) distribution at `x`. +It is equal to ``logpdf(d, log(x)) - log(x)`` +""" function Distributions.logpdf(d::ExponentialOf, x::Real) return logpdf(d.dist, log(x)) - log(x) end diff --git a/src/losses/fenchel_young_loss.jl b/src/losses/fenchel_young_loss.jl index 683f1b2..e038b72 100644 --- a/src/losses/fenchel_young_loss.jl +++ b/src/losses/fenchel_young_loss.jl @@ -1,5 +1,5 @@ """ - FenchelYoungLoss <: AbstractLossLayer +$TYPEDEF Fenchel-Young loss associated with a given optimization layer. 
``` @@ -9,7 +9,6 @@ L(θ, y_true) = (Ω(y_true) - θᵀy_true) - (Ω(ŷ) - θᵀŷ) Reference: # Fields - - `optimization_layer::AbstractOptimizationLayer`: optimization layer that can be formulated as `ŷ(θ) = argmax {θᵀy - Ω(y)}` (either regularized or perturbed) """ struct FenchelYoungLoss{O<:AbstractOptimizationLayer} <: AbstractLossLayer @@ -24,7 +23,9 @@ end ## Forward pass """ - (fyl::FenchelYoungLoss)(θ, y_true[; kwargs...]) +$TYPEDSIGNATURES + +Compute L(θ, y_true). """ function (fyl::FenchelYoungLoss)(θ::AbstractArray, y_true::AbstractArray; kwargs...) l, _ = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...) @@ -67,7 +68,7 @@ function fenchel_young_loss_and_grad( (; optimization_layer) = fyl maximizer = get_maximizer(optimization_layer) F, almost_ŷ = fenchel_young_F_and_first_part_of_grad(optimization_layer, θ; kwargs...) - l = F - objective_value(maximizer, θ, y_true; kwargs...) # dot(θ, y_true) + l = F - objective_value(maximizer, θ, y_true; kwargs...) g = almost_ŷ - apply_g(maximizer, y_true; kwargs...) return l, g end diff --git a/src/utils/linear_maximizer.jl b/src/utils/linear_maximizer.jl index 684edc2..f769de7 100644 --- a/src/utils/linear_maximizer.jl +++ b/src/utils/linear_maximizer.jl @@ -27,6 +27,10 @@ function (f::LinearMaximizer)(θ::AbstractArray; kwargs...) return f.maximizer(θ; kwargs...) end +objective_value(::Any, θ, y; kwargs...) = dot(θ, y) +apply_g(::Any, y; kwargs...) = y +apply_h(::Any, y; kwargs...) = zero(eltype(y)) + """ objective_value(f, θ, y, kwargs...) diff --git a/test/learning_ranking.jl b/test/learning_ranking.jl index 891ea6e..ee7fe82 100644 --- a/test/learning_ranking.jl +++ b/test/learning_ranking.jl @@ -278,7 +278,7 @@ end instance_dim=5, true_maximizer=ranking, maximizer=identity_kw, - loss=LinearPerturbed(f, p; nb_samples=10), + loss=PerturbedOracle(f, p; nb_samples=10), error_function=hamming_distance, true_encoder=true_encoder, cost=cost, From 40680719fc5cf7bfa106f9991637b90781b293f5 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Mon, 23 Dec 2024 17:53:49 +0100 Subject: [PATCH 16/26] cleanup and update regularized --- src/InferOpt.jl | 16 +- src/interface.jl | 2 + src/layers/perturbed/_abstract_perturbed.jl | 153 ---------------- src/layers/perturbed/_additive.jl | 128 ------------- src/layers/perturbed/_multiplicative.jl | 130 -------------- src/layers/perturbed/_perturbed_oracle.jl | 92 ---------- .../regularized/abstract_regularized.jl | 41 +---- src/losses/_fenchel_young_loss.jl | 169 ------------------ src/losses/fenchel_young_loss.jl | 34 ++-- src/utils/generalized_maximizer.jl | 48 ----- src/utils/linear_maximizer.jl | 42 ++++- src/utils/utils.jl | 35 ++++ test/learning_generalized_maximizer.jl | 4 +- 13 files changed, 95 insertions(+), 799 deletions(-) delete mode 100644 src/layers/perturbed/_abstract_perturbed.jl delete mode 100644 src/layers/perturbed/_additive.jl delete mode 100644 src/layers/perturbed/_multiplicative.jl delete mode 100644 src/layers/perturbed/_perturbed_oracle.jl delete mode 100644 src/losses/_fenchel_young_loss.jl delete mode 100644 src/utils/generalized_maximizer.jl create mode 100644 src/utils/utils.jl diff --git a/src/InferOpt.jl b/src/InferOpt.jl index 883aecb..6b4b653 100644 --- a/src/InferOpt.jl +++ b/src/InferOpt.jl @@ -11,11 +11,7 @@ using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, Tangent, ZeroTangen using ChainRulesCore: rrule, rrule_via_ad, unthunk using DensityInterface: logdensityof using DifferentiableExpectations: - DifferentiableExpectations, - Reinforce, - empirical_predistribution, - 
empirical_distribution, - FixKwargs + DifferentiableExpectations, Reinforce, empirical_predistribution, empirical_distribution using Distributions: Distributions, ContinuousUnivariateDistribution, @@ -34,9 +30,9 @@ using RequiredInterfaces include("interface.jl") +include("utils/utils.jl") include("utils/some_functions.jl") include("utils/pushforward.jl") -include("utils/generalized_maximizer.jl") include("utils/linear_maximizer.jl") include("utils/isotonic_regression/isotonic_l2.jl") include("utils/isotonic_regression/isotonic_kl.jl") @@ -48,8 +44,6 @@ include("layers/simple/identity.jl") include("layers/perturbed/utils.jl") include("layers/perturbed/perturbation.jl") -# include("layers/perturbed/additive.jl") -# include("layers/perturbed/multiplicative.jl") include("layers/perturbed/perturbed.jl") include("layers/regularized/abstract_regularized.jl") @@ -72,18 +66,14 @@ include("losses/imitation_loss.jl") export half_square_norm export shannon_entropy, negative_shannon_entropy export one_hot_argmax, ranking -export GeneralizedMaximizer export LinearMaximizer, apply_g, apply_h, objective_value -# export FixedAtomsProbabilityDistribution -# export compute_expectation -# export compute_probability_distribution export Pushforward export IdentityRelaxation export Interpolation -export AbstractRegularized, AbstractRegularizedGeneralizedMaximizer +export AbstractRegularized export SoftArgmax, soft_argmax export SparseArgmax, sparse_argmax export SoftRank, soft_rank, soft_rank_l2, soft_rank_kl diff --git a/src/interface.jl b/src/interface.jl index 2fa67a8..aeafa6e 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -23,6 +23,8 @@ Supertype for all the optimization layers defined in InferOpt. """ abstract type AbstractOptimizationLayer <: AbstractLayer end +get_maximizer(layer::AbstractOptimizationLayer) = nothing + ## Losses """ diff --git a/src/layers/perturbed/_abstract_perturbed.jl b/src/layers/perturbed/_abstract_perturbed.jl deleted file mode 100644 index 5ac197f..0000000 --- a/src/layers/perturbed/_abstract_perturbed.jl +++ /dev/null @@ -1,153 +0,0 @@ -""" - AbstractPerturbed{F,parallel} <: AbstractOptimizationLayer - -Differentiable perturbation of a black box optimizer of type `F`. - -The parameter `parallel` is a boolean value indicating if the perturbations are run in parallel. -This is particularly useful if your black box optimizer running time is high. - -# Available implementations: -- [`PerturbedAdditive`](@ref) -- [`PerturbedMultiplicative`](@ref) -- [`PerturbedOracle`](@ref) - -# These three subtypes share the following fields: -- `oracle`: black box (optimizer) -- `perturbation::P`: perturbation distribution of the input θ -- `grad_logdensity::G`: gradient of the log density `perturbation` w.r.t. input θ -- `nb_samples::Int`: number of perturbation samples drawn at each forward pass -- `seed::Union{Nothing,Int}`: seed of the perturbation. - It is reset each time the forward pass is called, - making it deterministic by always drawing the same perturbations. - If you do not want this behaviour, set this field to `nothing`. -- `rng::AbstractRNG`: random number generator using the `seed`. - -!!! warning - The `perturbation` field does not mean the same thing for a [`PerturbedOracle`](@ref) - than for a [`PerturbedAdditive`](@ref)/[`PerturbedMultiplicative`](@ref). See their respective docs. 
-""" -abstract type AbstractPerturbed{O,parallel} <: AbstractOptimizationLayer end - -# Non parallelized version -function compute_atoms( - perturbed::AbstractPerturbed{O,false}, η_samples::Vector{<:AbstractArray}; kwargs... -) where {O} - return [perturbed.oracle(η; kwargs...) for η in η_samples] -end - -# Parallelized version -function compute_atoms( - perturbed::AbstractPerturbed{O,true}, η_samples::Vector{<:AbstractArray}; kwargs... -) where {O} - return ThreadsX.map(η -> perturbed.oracle(η; kwargs...), η_samples) -end - -""" - sample_perturbations(perturbed::AbstractPerturbed, θ::AbstractArray) - -Draw `nb_samples` random perturbations from the `perturbation` distribution. -""" -function sample_perturbations end - -""" - perturbation_grad_logdensity( - ::RuleConfig, - ::AbstractPerturbed, - θ::AbstractArray, - sample::AbstractArray, - ) - -Compute de gradient w.r.t to the input `θ` of the logdensity of the perturbed input -distribution evaluated in the observed perturbation sample `η`. -""" -function perturbation_grad_logdensity end - -""" - compute_probability_distribution_from_samples( - ::AbstractPerturbed, - θ::AbstractArray, - samples::Vector{<:AbstractArray}; - kwargs..., - ) - -Create a probability distributions from `samples` drawn from `perturbation`. -""" -function compute_probability_distribution_from_samples end - -""" - compute_probability_distribution(perturbed::AbstractPerturbed, θ; kwargs...) - -Turn random perturbations of `θ` into a distribution on polytope vertices. - -Keyword arguments are passed to the underlying linear maximizer. -""" -function compute_probability_distribution( - perturbed::AbstractPerturbed, - θ::AbstractArray; - autodiff_variance_reduction::Bool=false, - kwargs..., -) - η_samples = sample_perturbations(perturbed, θ) - return compute_probability_distribution_from_samples(perturbed, θ, η_samples; kwargs...) -end - -# Forward pass - -""" - (perturbed::AbstractPerturbed)(θ; kwargs...) - -Forward pass. Compute the expectation of the underlying distribution. -""" -function (perturbed::AbstractPerturbed)( - θ::AbstractArray; autodiff_variance_reduction::Bool=true, kwargs... -) - probadist = compute_probability_distribution( - perturbed, θ; autodiff_variance_reduction, kwargs... - ) - return compute_expectation(probadist) -end - -function perturbation_grad_logdensity( - ::RuleConfig, perturbed::AbstractPerturbed, θ::AbstractArray, η::AbstractArray -) - return perturbed.grad_logdensity(θ, η) -end - -# Backward pass - -function ChainRulesCore.rrule( - rc::RuleConfig, - ::typeof(compute_probability_distribution), - perturbed::AbstractPerturbed, - θ::AbstractArray; - autodiff_variance_reduction::Bool=true, - kwargs..., -) - η_samples = sample_perturbations(perturbed, θ) - y_dist = compute_probability_distribution_from_samples( - perturbed, θ, η_samples; kwargs... - ) - - ∇logp_samples = [perturbation_grad_logdensity(rc, perturbed, θ, η) for η in η_samples] - - M = perturbed.nb_samples - function perturbed_oracle_dist_pullback(δy_dist) - weights = y_dist.weights - δy_weights = δy_dist.weights - δy_sum = sum(δy_weights) - δθ = sum( - map(1:M) do i - δyᵢ, ∇logpᵢ, w = δy_weights[i], ∇logp_samples[i], weights[i] - if autodiff_variance_reduction - bᵢ = M == 1 ? 
0 * δy_sum : (δy_sum - δyᵢ) / (M - 1) - return w * (δyᵢ - bᵢ) * ∇logpᵢ - else - return w * δyᵢ * ∇logpᵢ - end - end, - ) - return NoTangent(), NoTangent(), δθ - end - - return y_dist, perturbed_oracle_dist_pullback -end diff --git a/src/layers/perturbed/_additive.jl b/src/layers/perturbed/_additive.jl deleted file mode 100644 index e9923d5..0000000 --- a/src/layers/perturbed/_additive.jl +++ /dev/null @@ -1,128 +0,0 @@ -""" - PerturbedAdditive{P,G,O,R,S,parallel} <: AbstractPerturbed{parallel} - -Differentiable normal perturbation of a black-box maximizer: the input undergoes `θ -> θ + εZ` where `Z ∼ N(0, I)`. - -This [`AbstractOptimizationLayer`](@ref) is compatible with [`FenchelYoungLoss`](@ref), -if the oracle is an optimization maximizer with a linear objective. - -Reference: - -See [`AbstractPerturbed`](@ref) for more details. - -# Specific field -- `ε:Float64`: size of the perturbation -""" -struct PerturbedAdditive{P,G,O,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <: - AbstractPerturbed{O,parallel} - perturbation::P - grad_logdensity::G - oracle::O - rng::R - seed::S - nb_samples::Int - ε::Float64 -end - -function Base.show(io::IO, perturbed::PerturbedAdditive) - (; oracle, ε, rng, seed, nb_samples, perturbation) = perturbed - perturb = isnothing(perturbation) ? "Normal(0, 1)" : "$perturbation" - return print( - io, "PerturbedAdditive($oracle, $ε, $nb_samples, $(typeof(rng)), $seed, $perturb)" - ) -end - -""" - PerturbedAdditive(oracle[; ε, nb_samples, seed, is_parallel, perturbation, grad_logdensity, rng]) - -[`PerturbedAdditive`](@ref) constructor. - -# Arguments -- `oracle`: the black-box oracle we want to differentiate through. - It should be a linear maximizer if you want to use it inside a [`FenchelYoungLoss`](@ref). - -# Keyword arguments (optional) -- `ε=1.0`: size of the perturbation. -- `nb_samples::Int=1`: number of perturbation samples drawn at each forward pass. -- `perturbation=nothing`: nothing by default. If you want to use a different distribution than a - `Normal` for the perturbation `z`, give it here as a distribution-like object implementing - the `rand` method. It should also implement `logdensityof` if `grad_logdensity` is not given. -- `grad_logdensity=nothing`: gradient function of `perturbation` w.r.t. `θ`. - If set to nothing (default), it's computed using automatic differentiation. -- `seed::Union{Nothing,Int}=nothing`: seed of the perturbation. - It is reset each time the forward pass is called, - making it deterministic by always drawing the same perturbations. - If you do not want this behaviour, set this field to `nothing`. -- `rng::AbstractRNG`=MersenneTwister(0): random number generator using the `seed`. 
-""" -function PerturbedAdditive( - oracle::O; - ε=1.0, - nb_samples=1, - seed::S=nothing, - is_parallel::Bool=false, - perturbation::P=nothing, - grad_logdensity::G=nothing, - rng::R=MersenneTwister(0), -) where {P,G,O,R<:AbstractRNG,S<:Union{Int,Nothing}} - return PerturbedAdditive{P,G,O,R,S,is_parallel}( - perturbation, grad_logdensity, oracle, rng, seed, nb_samples, float(ε) - ) -end - -function sample_perturbations(perturbed::PerturbedAdditive, θ::AbstractArray) - (; rng, seed, nb_samples, perturbation) = perturbed - seed!(rng, seed) - return [rand(rng, perturbation, size(θ)) for _ in 1:nb_samples] -end - -function sample_perturbations(perturbed::PerturbedAdditive{Nothing}, θ::AbstractArray) - (; rng, seed, nb_samples) = perturbed - seed!(rng, seed) - return [randn(rng, size(θ)) for _ in 1:nb_samples] -end - -function compute_probability_distribution_from_samples( - perturbed::PerturbedAdditive, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) - (; ε) = perturbed - η_samples = [θ .+ ε .* Z for Z in Z_samples] - atoms = compute_atoms(perturbed, η_samples; kwargs...) - weights = ones(length(atoms)) ./ length(atoms) - probadist = FixedAtomsProbabilityDistribution(atoms, weights) - return probadist -end - -function perturbation_grad_logdensity( - ::RuleConfig, - perturbed::PerturbedAdditive{Nothing,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) - (; ε) = perturbed - return Z ./ ε -end - -function _perturbation_logdensity( - perturbed::PerturbedAdditive, θ::AbstractArray, η::AbstractArray -) - (; ε, perturbation) = perturbed - Z = (η .- θ) ./ ε - return sum(logdensityof(perturbation, z) for z in Z) -end - -function perturbation_grad_logdensity( - rc::RuleConfig, - perturbed::PerturbedAdditive{P,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) where {P} - (; ε) = perturbed - η = θ .+ ε .* Z - l, logdensity_pullback = rrule_via_ad(rc, _perturbation_logdensity, perturbed, θ, η) - δperturbation_logdensity, δperturbed, δθ, δη = logdensity_pullback(one(l)) - return δθ -end diff --git a/src/layers/perturbed/_multiplicative.jl b/src/layers/perturbed/_multiplicative.jl deleted file mode 100644 index f96a333..0000000 --- a/src/layers/perturbed/_multiplicative.jl +++ /dev/null @@ -1,130 +0,0 @@ -""" - PerturbedMultiplicative{P,G,O,R,S,parallel} <: AbstractPerturbed{parallel} - -Differentiable multiplicative perturbation of a black-box oracle: -the input undergoes `θ -> θ ⊙ exp[εZ - ε²/2]` where `Z ∼ perturbation`. - -This [`AbstractOptimizationLayer`](@ref) is compatible with [`FenchelYoungLoss`](@ref), -if the oracle is an optimization maximizer with a linear objective. - -Reference: - -See [`AbstractPerturbed`](@ref) for more details. - -# Specific field -- `ε:Float64`: size of the perturbation -""" -struct PerturbedMultiplicative{P,G,O,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <: - AbstractPerturbed{O,parallel} - perturbation::P - grad_logdensity::G - oracle::O - rng::R - seed::S - nb_samples::Int - ε::Float64 -end - -function Base.show(io::IO, perturbed::PerturbedMultiplicative) - (; oracle, ε, rng, seed, nb_samples, perturbation) = perturbed - perturb = isnothing(perturbation) ? "Normal(0, 1)" : "$perturbation" - return print( - io, - "PerturbedMultiplicative($oracle, $ε, $nb_samples, $(typeof(rng)), $seed, $perturb)", - ) -end - -""" - PerturbedMultiplicative(oracle[; ε, nb_samples, seed, is_parallel, perturbation, grad_logdensity, rng]) - -[`PerturbedMultiplicative`](@ref) constructor. 
- -# Arguments -- `oracle`: the black-box oracle we want to differentiate through. - It should be a linear maximizer if you want to use it inside a [`FenchelYoungLoss`]. - -# Keyword arguments (optional) -- `ε=1.0`: size of the perturbation. -- `nb_samples::Int=1`: number of perturbation samples drawn at each forward pass. -- `perturbation=nothing`: nothing by default. If you want to use a different distribution than a - `Normal` for the perturbation `z`, give it here as a distribution-like object implementing - the `rand` method. It should also implement `logdensityof` if `grad_logdensity` is not given. -- `grad_logdensity=nothing`: gradient function of `perturbation` w.r.t. `θ`. - If set to nothing (default), it's computed using automatic differentiation. -- `seed::Union{Nothing,Int}=nothing`: seed of the perturbation. - It is reset each time the forward pass is called, - making it deterministic by always drawing the same perturbations. - If you do not want this behaviour, set this field to `nothing`. -- `rng::AbstractRNG`=MersenneTwister(0): random number generator using the `seed`. -""" -function PerturbedMultiplicative( - oracle::O; - ε=1.0, - nb_samples=1, - seed::S=nothing, - is_parallel=false, - perturbation::P=nothing, - grad_logdensity::G=nothing, - rng::R=MersenneTwister(0), -) where {P,G,O,R<:AbstractRNG,S<:Union{Int,Nothing}} - return PerturbedMultiplicative{P,G,O,R,S,is_parallel}( - perturbation, grad_logdensity, oracle, rng, seed, nb_samples, float(ε) - ) -end - -function sample_perturbations(perturbed::PerturbedMultiplicative, θ::AbstractArray) - (; rng, seed, nb_samples, perturbation) = perturbed - seed!(rng, seed) - return [rand(rng, perturbation, size(θ)) for _ in 1:nb_samples] -end - -function sample_perturbations(perturbed::PerturbedMultiplicative{Nothing}, θ::AbstractArray) - (; rng, seed, nb_samples) = perturbed - seed!(rng, seed) - return [randn(rng, size(θ)) for _ in 1:nb_samples] -end - -function compute_probability_distribution_from_samples( - perturbed::PerturbedMultiplicative, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) - (; ε) = perturbed - η_samples = [θ .* exp.(ε .* Z .- ε^2 / 2) for Z in Z_samples] - atoms = compute_atoms(perturbed, η_samples; kwargs...) 
- weights = ones(length(atoms)) ./ length(atoms) - probadist = FixedAtomsProbabilityDistribution(atoms, weights) - return probadist -end - -function perturbation_grad_logdensity( - ::RuleConfig, - perturbed::PerturbedMultiplicative{Nothing,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) - (; ε) = perturbed - return inv.(ε .* θ) .* Z -end - -function _perturbation_logdensity( - perturbed::PerturbedMultiplicative, θ::AbstractArray, η::AbstractArray -) - (; ε, perturbation) = perturbed - Z = (log.(η) .- log.(θ)) ./ ε .+ ε / 2 - return sum(logdensityof(perturbation, z) for z in Z) -end - -function perturbation_grad_logdensity( - rc::RuleConfig, - perturbed::PerturbedMultiplicative{P,Nothing}, - θ::AbstractArray, - Z::AbstractArray, -) where {P} - (; ε) = perturbed - η = θ .* exp.(ε .* Z .- ε^2 / 2) - l, logdensity_pullback = rrule_via_ad(rc, _perturbation_logdensity, perturbed, θ, η) - δperturbation_logdensity, δperturbed, δθ, δη = logdensity_pullback(one(l)) - return δθ -end diff --git a/src/layers/perturbed/_perturbed_oracle.jl b/src/layers/perturbed/_perturbed_oracle.jl deleted file mode 100644 index 9d7cd04..0000000 --- a/src/layers/perturbed/_perturbed_oracle.jl +++ /dev/null @@ -1,92 +0,0 @@ -""" - PerturbedOracle{P,G,O,R,S,parallel} <: AbstractPerturbed{parallel} - -Differentiable perturbed black-box oracle. The `oracle` input `θ` is perturbed as `η ∼ perturbation(⋅|θ)`. -[`PerturbedAdditive`](@ref) is a special case of `PerturbedOracle` with `perturbation(θ) = MvNormal(θ, ε * I)`. -[`PerturbedMultiplicative`] is also a special case of `PerturbedOracle`. - -See [`AbstractPerturbed`](@ref) for more details about its fields. -""" -struct PerturbedOracle{P,G,O,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <: - AbstractPerturbed{O,parallel} - perturbation::P - grad_logdensity::G - oracle::O - rng::R - seed::S - nb_samples::Int -end - -""" - PerturbedOracle(perturbation, oracle[; grad_logdensity, rng, seed, is_parallel, nb_samples]) - -[`PerturbedOracle`](@ref) constructor. - -# Arguments -- `oracle`: the black-box oracle we want to differentiate through -- `perturbation`: should be a callable such that `perturbation(θ)` is a distribution-like - object that can be sampled with `rand`. - It should also implement `logdensityof` if `grad_logdensity` is not given. - -# Keyword arguments (optional) -- `grad_logdensity=nothing`: gradient function of `perturbation` w.r.t. `θ`. - If set to nothing (default), it's computed using automatic differentiation. -- `nb_samples::Int=1`: number of perturbation samples drawn at each forward pass -- `seed::Union{Nothing,Int}=nothing`: seed of the perturbation. - It is reset each time the forward pass is called, - making it deterministic by always drawing the same perturbations. - If you do not want this behaviour, set this field to `nothing`. -- `rng::AbstractRNG`=MersenneTwister(0): random number generator using the `seed`. - -!!! info - If you have access to the analytical expression of `grad_logdensity` it is recommended to - give it, as it will be computationally faster. 
-"""
-function PerturbedOracle(
-    oracle::O,
-    perturbation::P;
-    grad_logdensity::G=nothing,
-    nb_samples::Int=1,
-    seed::S=nothing,
-    is_parallel::Bool=false,
-    rng::R=MersenneTwister(0),
-) where {P,G,O,R<:AbstractRNG,S<:Union{Int,Nothing}}
-    return PerturbedOracle{P,G,O,R,S,is_parallel}(
-        perturbation, grad_logdensity, oracle, rng, seed, nb_samples
-    )
-end
-
-function Base.show(io::IO, po::PerturbedOracle)
-    (; oracle, perturbation, rng, seed, nb_samples) = po
-    return print(
-        io, "PerturbedOracle($perturbation, $oracle, $nb_samples, $(typeof(rng)), $seed)"
-    )
-end
-
-function sample_perturbations(po::PerturbedOracle, θ::AbstractArray)
-    (; rng, seed, perturbation, nb_samples) = po
-    seed!(rng, seed)
-    η_samples = [rand(rng, perturbation(θ)) for _ in 1:nb_samples]
-    return η_samples
-end
-
-function compute_probability_distribution_from_samples(
-    perturbed::PerturbedOracle, θ, η_samples::Vector{<:AbstractArray}; kwargs...
-)
-    atoms = compute_atoms(perturbed, η_samples; kwargs...)
-    weights = ones(length(atoms)) ./ length(atoms)
-    probadist = FixedAtomsProbabilityDistribution(atoms, weights)
-    return probadist
-end
-
-function _perturbation_logdensity(po::PerturbedOracle, θ::AbstractArray, η::AbstractArray)
-    return logdensityof(po.perturbation(θ), η)
-end
-
-function perturbation_grad_logdensity(
-    rc::RuleConfig, po::PerturbedOracle{P,Nothing}, θ::AbstractArray, η::AbstractArray
-) where {P}
-    l, logdensity_pullback = rrule_via_ad(rc, _perturbation_logdensity, po, θ, η)
-    δperturbation_logdensity, δpo, δθ, δη = logdensity_pullback(one(l))
-    return δθ
-end
diff --git a/src/layers/regularized/abstract_regularized.jl b/src/layers/regularized/abstract_regularized.jl
index 9a2a903..27a7e80 100644
--- a/src/layers/regularized/abstract_regularized.jl
+++ b/src/layers/regularized/abstract_regularized.jl
@@ -1,60 +1,33 @@
 """
     AbstractRegularized <: AbstractOptimizationLayer
 
-Convex regularization perturbation of a black box linear optimizer
+Convex regularization perturbation of a black box linear (in θ) optimizer
 ```
-ŷ(θ) = argmax_{y ∈ C} {θᵀy - Ω(y)}
+ŷ(θ) = argmax_{y ∈ C} {θᵀg(y) + h(y) - Ω(y)}
 ```
+with g and h functions of y.
 
 # Interface
-
 - `(regularized::AbstractRegularized)(θ; kwargs...)`: return `ŷ(θ)`
 - `compute_regularization(regularized, y)`: return `Ω(y)`
+- `get_maximizer(regularized)`: return the associated `LinearMaximizer` optimizer
 
 # Available implementations
-
 - [`SoftArgmax`](@ref)
 - [`SparseArgmax`](@ref)
+- [`SoftRank`](@ref)
 - [`RegularizedFrankWolfe`](@ref)
 """
 abstract type AbstractRegularized <: AbstractOptimizationLayer end
 
 """
-    AbstractRegularizedGeneralizedMaximizer <: AbstractRegularized
-
-Convex regularization perturbation of a black box **generalized** optimizer
-```
-ŷ(θ) = argmax_{y ∈ C} {θᵀg(y) + h(y) - Ω(y)}
-with g and h functions of y.
-```
-
-# Interface
-
-- `(regularized::AbstractRegularized)(θ; kwargs...)`: return `ŷ(θ)`
-- `compute_regularization(regularized, y)`: return `Ω(y)`
-- `get_maximizer(regularized)`: return the associated `GeneralizedMaximizer` optimizer
-"""
-abstract type AbstractRegularizedGeneralizedMaximizer <: AbstractRegularized end
-
-"""
-    compute_regularization(regularized, y)
+    compute_regularization(regularized::AbstractRegularized, y)
 
 Return the convex penalty `Ω(y)` associated with an `AbstractRegularized` layer.
 """
 function compute_regularization end
 
-"""
-    get_maximizer(regularized)
-
-Return the associated optimizer.
-""" -function get_maximizer end - @required AbstractRegularized begin # (regularized::AbstractRegularized)(θ::AbstractArray; kwargs...) # waiting for RequiredInterfaces to support this (see https://github.com/Seelengrab/RequiredInterfaces.jl/issues/11) compute_regularization(::AbstractRegularized, y) end - -@required AbstractRegularizedGeneralizedMaximizer begin - get_maximizer(::AbstractRegularizedGeneralizedMaximizer) -end diff --git a/src/losses/_fenchel_young_loss.jl b/src/losses/_fenchel_young_loss.jl deleted file mode 100644 index d82c104..0000000 --- a/src/losses/_fenchel_young_loss.jl +++ /dev/null @@ -1,169 +0,0 @@ -""" - FenchelYoungLoss <: AbstractLossLayer - -Fenchel-Young loss associated with a given optimization layer. -``` -L(θ, y_true) = (Ω(y_true) - θᵀy_true) - (Ω(ŷ) - θᵀŷ) -``` - -Reference: - -# Fields - -- `optimization_layer::AbstractOptimizationLayer`: optimization layer that can be formulated as `ŷ(θ) = argmax {θᵀy - Ω(y)}` (either regularized or perturbed) -""" -struct FenchelYoungLoss{O<:AbstractOptimizationLayer} <: AbstractLossLayer - optimization_layer::O -end - -function Base.show(io::IO, fyl::FenchelYoungLoss) - (; optimization_layer) = fyl - return print(io, "FenchelYoungLoss($optimization_layer)") -end - -## Forward pass - -""" - (fyl::FenchelYoungLoss)(θ, y_true[; kwargs...]) -""" -function (fyl::FenchelYoungLoss)(θ::AbstractArray, y_true::AbstractArray; kwargs...) - l, _ = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...) - return l -end - -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {O<:AbstractRegularized} - (; optimization_layer) = fyl - ŷ = optimization_layer(θ; kwargs...) - Ωy_true = compute_regularization(optimization_layer, y_true) - Ωŷ = compute_regularization(optimization_layer, ŷ) - l = (Ωy_true - dot(θ, y_true)) - (Ωŷ - dot(θ, ŷ)) - g = ŷ - y_true - return l, g -end - -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {O<:AbstractRegularizedGeneralizedMaximizer} - (; optimization_layer) = fyl - ŷ = optimization_layer(θ; kwargs...) - Ωy_true = compute_regularization(optimization_layer, y_true) - Ωŷ = compute_regularization(optimization_layer, ŷ) - maximizer = get_maximizer(optimization_layer) - l = - (Ωy_true - objective_value(maximizer, θ, y_true; kwargs...)) - - (Ωŷ - objective_value(maximizer, θ, ŷ; kwargs...)) - g = maximizer.g(ŷ; kwargs...) - maximizer.g(y_true; kwargs...) - return l, g -end - -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {O<:AbstractPerturbed} - (; optimization_layer) = fyl - F, almost_ŷ = fenchel_young_F_and_first_part_of_grad(optimization_layer, θ; kwargs...) - l = F - dot(θ, y_true) - g = almost_ŷ - y_true - return l, g -end - -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {P<:AbstractPerturbed{<:GeneralizedMaximizer}} - (; optimization_layer) = fyl - F, almost_g_of_ŷ = fenchel_young_F_and_first_part_of_grad( - optimization_layer, θ; kwargs... - ) - l = F - objective_value(optimization_layer.oracle, θ, y_true; kwargs...) - g = almost_g_of_ŷ - optimization_layer.oracle.g(y_true; kwargs...) - return l, g -end - -## Backward pass - -function ChainRulesCore.rrule( - fyl::FenchelYoungLoss, θ::AbstractArray, y_true::AbstractArray; kwargs... -) - l, g = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...) 
- fyl_pullback(dl) = NoTangent(), dl * g, NoTangent() - return l, fyl_pullback -end - -## Specific overrides for perturbed layers - -function compute_F_and_y_samples( - perturbed::AbstractPerturbed{O,false}, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) where {O} - F_and_y_samples = [ - fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...) for - Z in Z_samples - ] - return F_and_y_samples -end - -function compute_F_and_y_samples( - perturbed::AbstractPerturbed{O,true}, - θ::AbstractArray, - Z_samples::Vector{<:AbstractArray}; - kwargs..., -) where {O} - return ThreadsX.map( - Z -> fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...), Z_samples - ) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::AbstractPerturbed, θ::AbstractArray; kwargs... -) - Z_samples = sample_perturbations(perturbed, θ) - F_and_y_samples = compute_F_and_y_samples(perturbed, θ, Z_samples; kwargs...) - return mean(first, F_and_y_samples), mean(last, F_and_y_samples) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedAdditive, θ::AbstractArray, Z::AbstractArray; kwargs... -) - (; oracle, ε) = perturbed - η = θ .+ ε .* Z - y = oracle(η; kwargs...) - F = dot(η, y) - return F, y -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedAdditive{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... -) where {P,G,O<:GeneralizedMaximizer} - (; oracle, ε) = perturbed - η = θ .+ ε .* Z - y = oracle(η; kwargs...) - F = objective_value(oracle, η, y; kwargs...) - return F, oracle.g(y; kwargs...) -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedMultiplicative, θ::AbstractArray, Z::AbstractArray; kwargs... -) - (; oracle, ε) = perturbed - eZ = exp.(ε .* Z .- ε^2 ./ 2) - η = θ .* eZ - y = oracle(η; kwargs...) - F = dot(η, y) - y_scaled = y .* eZ - return F, y_scaled -end - -function fenchel_young_F_and_first_part_of_grad( - perturbed::PerturbedMultiplicative{P,G,O}, θ::AbstractArray, Z::AbstractArray; kwargs... -) where {P,G,O<:GeneralizedMaximizer} - (; oracle, ε) = perturbed - eZ = exp.(ε .* Z .- ε^2) - η = θ .* eZ - y = oracle(η; kwargs...) - F = objective_value(oracle, η, y; kwargs...) - y_scaled = y .* eZ - return F, oracle.g(y_scaled; kwargs...) -end diff --git a/src/losses/fenchel_young_loss.jl b/src/losses/fenchel_young_loss.jl index e038b72..c66e1e9 100644 --- a/src/losses/fenchel_young_loss.jl +++ b/src/losses/fenchel_young_loss.jl @@ -35,18 +35,6 @@ end function fenchel_young_loss_and_grad( fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... ) where {O<:AbstractRegularized} - (; optimization_layer) = fyl - ŷ = optimization_layer(θ; kwargs...) - Ωy_true = compute_regularization(optimization_layer, y_true) - Ωŷ = compute_regularization(optimization_layer, ŷ) - l = (Ωy_true - dot(θ, y_true)) - (Ωŷ - dot(θ, ŷ)) - g = ŷ - y_true - return l, g -end - -function fenchel_young_loss_and_grad( - fyl::FenchelYoungLoss{O}, θ::AbstractArray, y_true::AbstractArray; kwargs... -) where {O<:AbstractRegularizedGeneralizedMaximizer} (; optimization_layer) = fyl ŷ = optimization_layer(θ; kwargs...) Ωy_true = compute_regularization(optimization_layer, y_true) @@ -55,8 +43,8 @@ function fenchel_young_loss_and_grad( l = (Ωy_true - objective_value(maximizer, θ, y_true; kwargs...)) - (Ωŷ - objective_value(maximizer, θ, ŷ; kwargs...)) - g = maximizer.g(ŷ; kwargs...) - maximizer.g(y_true; kwargs...) - return l, g + grad = apply_g(maximizer, ŷ; kwargs...) 
- apply_g(maximizer, y_true; kwargs...)
+    return l, grad
 end
 
 function fenchel_young_loss_and_grad(
@@ -92,13 +80,14 @@ function fenchel_young_F_and_first_part_of_grad(
     maximizer = get_maximizer(perturbed)
     η_dist = empirical_predistribution(reinforce, θ)
     fk = FixKwargs(maximizer, kwargs)
-    gk = FixKwargs((y; kwargs...) -> apply_g(maximizer, y; kwargs...), kwargs)
+    gk = Fix1Kwargs(apply_g, maximizer, kwargs)
     y_dist = map(fk, η_dist)
-    return mean(
+    F = mean(
         objective_value(maximizer, η, y; kwargs...) for
         (η, y) in zip(η_dist.atoms, y_dist.atoms)
-    ),
-    mean(gk, y_dist)
+    )
+    ŷ = mean(gk, y_dist)
+    return F, ŷ
 end
 
 function fenchel_young_F_and_first_part_of_grad(
@@ -108,12 +97,13 @@ function fenchel_young_F_and_first_part_of_grad(
     maximizer = get_maximizer(perturbed)
     η_dist = empirical_predistribution(reinforce, θ)
     fk = FixKwargs(reinforce.f, kwargs)
-    gk = FixKwargs((y; kwargs...) -> apply_g(maximizer, y; kwargs...), kwargs)
+    gk = Fix1Kwargs(apply_g, maximizer, kwargs)
     y_dist = map(fk, η_dist)
     eZ_dist = map(Base.Fix2(./, θ), η_dist)
-    return mean(
+    F = mean(
         objective_value(maximizer, η, y; kwargs...) for
         (η, y) in zip(η_dist.atoms, y_dist.atoms)
-    ),
-    mean(gk.(map(.*, eZ_dist.atoms, y_dist.atoms)))
+    )
+    almost_ŷ = mean(gk.(map(.*, eZ_dist.atoms, y_dist.atoms)))
+    return F, almost_ŷ
 end
diff --git a/src/utils/generalized_maximizer.jl b/src/utils/generalized_maximizer.jl
deleted file mode 100644
index b9154e4..0000000
--- a/src/utils/generalized_maximizer.jl
+++ /dev/null
@@ -1,48 +0,0 @@
-"""
-    GeneralizedMaximizer{F,G,H}
-
-Wrapper for generalized maximizers `maximizer` of the form argmax_y θᵀg(y) + h(y).
-It is compatible with the following layers
-- [`PerturbedAdditive`](@ref) (with or without [`FenchelYoungLoss`](@ref))
-- [`PerturbedMultiplicative`](@ref) (with or without [`FenchelYoungLoss`](@ref))
-- [`SPOPlusLoss`](@ref)
-"""
-struct GeneralizedMaximizer{F,G,H}
-    maximizer::F
-    g::G
-    h::H
-end
-
-GeneralizedMaximizer(f; g=identity_kw, h=zero ∘ eltype_kw) = GeneralizedMaximizer(f, g, h)
-
-function Base.show(io::IO, f::GeneralizedMaximizer)
-    (; maximizer, g, h) = f
-    return print(io, "GeneralizedMaximizer($maximizer, $g, $h)")
-end
-
-# Callable calls the wrapped maximizer
-function (f::GeneralizedMaximizer)(θ::AbstractArray{<:Real}; kwargs...)
-    return f.maximizer(θ; kwargs...)
-end
-
-objective_value(::Any, θ, y; kwargs...) = dot(θ, y)
-
-"""
-    objective_value(f, θ, y, kwargs...)
-
-Computes the objective value of given GeneralizedMaximizer `f`, knowing weights `θ` and solution `y`.
-"""
-function objective_value(f::GeneralizedMaximizer, θ, y; kwargs...)
-    return dot(θ, f.g(y; kwargs...)) .+ f.h(y; kwargs...)
-end
-
-apply_g(::Any, y; kwargs...) = y
-apply_h(::Any, y; kwargs...) = zero(eltype(y))
-
-function apply_g(f::GeneralizedMaximizer, y; kwargs...)
-    return f.g(y; kwargs...)
-end
-
-function apply_h(f::GeneralizedMaximizer, y; kwargs...)
-    return f.h(y; kwargs...)
-end
diff --git a/src/utils/linear_maximizer.jl b/src/utils/linear_maximizer.jl
index f769de7..2d4f538 100644
--- a/src/utils/linear_maximizer.jl
+++ b/src/utils/linear_maximizer.jl
@@ -1,15 +1,21 @@
 """
-    LinearMaximizer{F,G,H}
+$TYPEDEF
 
-Wrapper for generalized maximizers `maximizer` of the form argmax_y θᵀg(y) + h(y).
+Wrapper for generic linear maximizers of the form argmax_y θᵀg(y) + h(y).
It is compatible with the following layers -- [`PerturbedAdditive`](@ref) (with or without [`FenchelYoungLoss`](@ref)) -- [`PerturbedMultiplicative`](@ref) (with or without [`FenchelYoungLoss`](@ref)) +- [`PerturbedAdditive`](@ref) (with or without a [`FenchelYoungLoss`](@ref)) +- [`PerturbedMultiplicative`](@ref) (with or without a [`FenchelYoungLoss`](@ref)) - [`SPOPlusLoss`](@ref) + +# Fields +$TYPEDFIELDS """ @kwdef struct LinearMaximizer{F,G,H} + "function θ ⟼ argmax_y θᵀg(y) + h(y)" maximizer::F + "function g(y) used in the objective" g::G = identity_kw + "function h(y) used in the objective" h::H = zero ∘ eltype_kw end @@ -18,33 +24,53 @@ function Base.show(io::IO, f::LinearMaximizer) return print(io, "LinearMaximizer($maximizer, $g, $h)") end +""" +$TYPEDSIGNATURES + +Constructor for [`LinearMaximizer`](@ref). +""" function LinearMaximizer(maximizer; g=identity_kw, h=zero ∘ eltype_kw) return LinearMaximizer(maximizer, g, h) end -# Callable calls the wrapped maximizer +""" +$TYPEDSIGNATURES + +Calls the wrapped maximizer. +""" function (f::LinearMaximizer)(θ::AbstractArray; kwargs...) return f.maximizer(θ; kwargs...) end +# default is oracles of the form argmax_y θᵀy objective_value(::Any, θ, y; kwargs...) = dot(θ, y) apply_g(::Any, y; kwargs...) = y -apply_h(::Any, y; kwargs...) = zero(eltype(y)) +# apply_h(::Any, y; kwargs...) = zero(eltype(y)) is not needed """ - objective_value(f, θ, y, kwargs...) +$TYPEDSIGNATURES Computes the objective value of given LinearMaximizer `f`, knowing weights `θ` and solution `y`. +i.e. θᵀg(y) + h(y) """ function objective_value(f::LinearMaximizer, θ, y; kwargs...) return dot(θ, f.g(y; kwargs...)) .+ f.h(y; kwargs...) end +""" +$TYPEDSIGNATURES + +Applies the function `g` of the LinearMaximizer `f` to `y`. +""" function apply_g(f::LinearMaximizer, y; kwargs...) return f.g(y; kwargs...) end -# Might not be needed +""" +$TYPEDSIGNATURES + +Applies the function `h` of the LinearMaximizer `f` to `y`. +""" function apply_h(f::LinearMaximizer, y; kwargs...) return f.h(y; kwargs...) end diff --git a/src/utils/utils.jl b/src/utils/utils.jl new file mode 100644 index 0000000..3a125d2 --- /dev/null +++ b/src/utils/utils.jl @@ -0,0 +1,35 @@ +""" +$TYPEDEF + +Callable struct that fixes the keyword arguments of `f` to `kwargs...`, and only accepts positional arguments. + +# Fields +$TYPEDFIELDS +""" +struct FixKwargs{F,K} + "function" + f::F + "fixed keyword arguments" + kwargs::K +end + +(fk::FixKwargs)(args...) = fk.f(args...; fk.kwargs...) + +""" +$TYPEDEF + +Callable struct that fixes the first argument of `f` to `x`, and the keyword arguments to `kwargs...`. + +# Fields +$TYPEDFIELDS +""" +struct Fix1Kwargs{F,K,T} <: Function + "function" + f::F + "fixed first argument" + x::T + "fixed keyword arguments" + kwargs::K +end + +(fk::Fix1Kwargs)(args...) = fk.f(fk.x, args...; fk.kwargs...) 
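
To see how the `LinearMaximizer` wrapper and the new `FixKwargs`/`Fix1Kwargs` helpers from `utils.jl` fit together, here is a minimal sketch. The toy maximizer and its objective are made up for illustration, and `Fix1Kwargs` is an internal helper (not exported), hence the `InferOpt.` prefix:

```julia
using InferOpt

# Hypothetical toy problem: maximize θᵀy - sum(y) over y ∈ {0,1}^n.
# The exact maximizer sets yᵢ = 1 whenever θᵢ ≥ 1.
toy_maximizer(θ; kwargs...) = float.(θ .>= 1)

# g defaults to the identity; h captures the extra -sum(y) term in the objective
f = LinearMaximizer(toy_maximizer; h=(y; kwargs...) -> -sum(y))

θ = [2.0, -1.0, 0.5]
y = f(θ)                      # [1.0, 0.0, 0.0]: calls the wrapped maximizer
v = objective_value(f, θ, y)  # θᵀg(y) + h(y) = 2.0 - 1.0 = 1.0

# Fix1Kwargs fixes the first argument (and the keyword arguments) of a function,
# turning the two-argument apply_g into a one-argument callable:
gk = InferOpt.Fix1Kwargs(apply_g, f, (;))
@assert gk(y) == apply_g(f, y) == y
```

This mirrors the usage in `fenchel_young_F_and_first_part_of_grad` above, where `Fix1Kwargs(apply_g, maximizer, kwargs)` produces a unary function that `mean` can map over the atoms of a probability distribution.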
diff --git a/test/learning_generalized_maximizer.jl b/test/learning_generalized_maximizer.jl index 3b37ba8..ed017eb 100644 --- a/test/learning_generalized_maximizer.jl +++ b/test/learning_generalized_maximizer.jl @@ -234,7 +234,7 @@ end const RI = RequiredInterfaces Random.seed!(63) - struct MyRegularized{M<:GeneralizedMaximizer} <: AbstractRegularizedGeneralizedMaximizer + struct MyRegularized{M<:LinearMaximizer} <: AbstractRegularized # GeneralizedMaximizer maximizer::M end @@ -246,7 +246,7 @@ end @test RI.check_interface_implemented(AbstractRegularized, MyRegularized) - regularized = MyRegularized(GeneralizedMaximizer(sparse_argmax)) + regularized = MyRegularized(LinearMaximizer(sparse_argmax)) test_pipeline!( PipelineLossImitation(); From 908482da71bbdecc2522e2c9bf18e68553ab9e12 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Mon, 23 Dec 2024 18:00:50 +0100 Subject: [PATCH 17/26] Fix and cleanup SPO+ --- src/losses/spoplus_loss.jl | 39 ++++++-------------------- test/learning_generalized_maximizer.jl | 4 +-- 2 files changed, 11 insertions(+), 32 deletions(-) diff --git a/src/losses/spoplus_loss.jl b/src/losses/spoplus_loss.jl index c572842..d0aadb1 100644 --- a/src/losses/spoplus_loss.jl +++ b/src/losses/spoplus_loss.jl @@ -1,16 +1,17 @@ """ - SPOPlusLoss <: AbstractLossLayer +$TYPEDEF Convex surrogate of the Smart "Predict-then-Optimize" loss. # Fields -- `maximizer`: linear maximizer function of the form `θ -> ŷ(θ) = argmax θᵀy` -- `α::Float64`: convexification parameter, default = 2.0 +$TYPEDFIELDS Reference: """ struct SPOPlusLoss{F} <: AbstractLossLayer + "linear maximizer function of the form `θ -> ŷ(θ) = argmax θᵀy`" maximizer::F + "convexification parameter, default = 2.0" α::Float64 end @@ -20,7 +21,9 @@ function Base.show(io::IO, spol::SPOPlusLoss) end """ - SPOPlusLoss(maximizer; α=2.0) +$TYPEDSIGNATURES + +Constructor for [`SPOPlusLoss`](@ref). """ SPOPlusLoss(maximizer; α=2.0) = SPOPlusLoss(maximizer, float(α)) @@ -35,17 +38,7 @@ function (spol::SPOPlusLoss)( (; maximizer, α) = spol θ_α = α * θ - θ_true y_α = maximizer(θ_α; kwargs...) - l = dot(θ_α, y_α) - dot(θ_α, y_true) - return l -end - -function (spol::SPOPlusLoss{<:GeneralizedMaximizer})( - θ::AbstractArray, θ_true::AbstractArray, y_true::AbstractArray; kwargs... -) - (; maximizer, α) = spol - θ_α = α * θ - θ_true - y_α = maximizer(θ_α; kwargs...) - # This only works in theory if α = 2 + # In theory, in the general case with a LinearMaximizer, this only works if α = 2 l = objective_value(maximizer, θ_α, y_α; kwargs...) - objective_value(maximizer, θ_α, y_true; kwargs...) @@ -68,20 +61,6 @@ function compute_loss_and_gradient( θ_true::AbstractArray, y_true::AbstractArray; kwargs..., -) - (; maximizer, α) = spol - θ_α = α * θ - θ_true - y_α = maximizer(θ_α; kwargs...) - l = dot(θ_α, y_α) - dot(θ_α, y_true) - return l, α .* (y_α .- y_true) -end - -function compute_loss_and_gradient( - spol::SPOPlusLoss{<:GeneralizedMaximizer}, - θ::AbstractArray, - θ_true::AbstractArray, - y_true::AbstractArray; - kwargs..., ) (; maximizer, α) = spol θ_α = α * θ - θ_true @@ -89,7 +68,7 @@ function compute_loss_and_gradient( l = objective_value(maximizer, θ_α, y_α; kwargs...) - objective_value(maximizer, θ_α, y_true; kwargs...) - g = α .* (maximizer.g(y_α; kwargs...) - maximizer.g(y_true; kwargs...)) + g = α .* (apply_g(maximizer, y_α; kwargs...) 
- apply_g(maximizer, y_true; kwargs...))
     return l, g
 end
diff --git a/test/learning_generalized_maximizer.jl b/test/learning_generalized_maximizer.jl
index ed017eb..037a638 100644
--- a/test/learning_generalized_maximizer.jl
+++ b/test/learning_generalized_maximizer.jl
@@ -133,7 +133,7 @@ end
 
     true_encoder = encoder_factory()
 
-    generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
+    generalized_maximizer = LinearMaximizer(max_pricing; g, h)
     function cost(y; instance)
         return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
     end
@@ -157,7 +157,7 @@ end
 
     true_encoder = encoder_factory()
 
-    generalized_maximizer = GeneralizedMaximizer(max_pricing; g, h)
+    generalized_maximizer = LinearMaximizer(max_pricing; g, h)
     function cost(y; instance)
         return -objective_value(generalized_maximizer, true_encoder(instance), y; instance)
     end

From 1a24016b9e9b3c2cbeee737734cfb15fcec8dec3 Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Mon, 23 Dec 2024 18:06:55 +0100
Subject: [PATCH 18/26] cleanup docstrings

---
 src/losses/spoplus_loss.jl | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/src/losses/spoplus_loss.jl b/src/losses/spoplus_loss.jl
index d0aadb1..46f0cc3 100644
--- a/src/losses/spoplus_loss.jl
+++ b/src/losses/spoplus_loss.jl
@@ -30,7 +30,11 @@ SPOPlusLoss(maximizer; α=2.0) = SPOPlusLoss(maximizer, float(α))
 ## Forward pass
 
 """
-    (spol::SPOPlusLoss)(θ, θ_true, y_true; kwargs...)
+$TYPEDSIGNATURES
+
+Forward pass of the SPO+ loss, given both the target costs `θ_true` and the target solution `y_true`.
+Although `y_true` can be recomputed from `θ_true` (as the two-argument method below does),
+providing it directly saves computation time.
 """
 function (spol::SPOPlusLoss)(
     θ::AbstractArray, θ_true::AbstractArray, y_true::AbstractArray; kwargs...
@@ -46,7 +50,10 @@ function (spol::SPOPlusLoss)(
 end
 
 """
-    (spol::SPOPlusLoss)(θ, θ_true; kwargs...)
+$TYPEDSIGNATURES
+
+Forward pass of the SPO+ loss with given target `θ_true`.
+For better performance, you can also provide `y_true` directly as a third argument.
 """
 function (spol::SPOPlusLoss)(θ::AbstractArray, θ_true::AbstractArray; kwargs...)
     y_true = spol.maximizer(θ_true; kwargs...)

From 21c732bf38366a13198b6bcbee428d9c90fa40c5 Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Mon, 23 Dec 2024 18:10:24 +0100
Subject: [PATCH 19/26] ThreadsX not needed anymore

---
 Project.toml    | 2 --
 src/InferOpt.jl | 1 -
 2 files changed, 3 deletions(-)

diff --git a/Project.toml b/Project.toml
index b1c0058..947003a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -16,7 +16,6 @@ RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
-ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
 
 [weakdeps]
 DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d"
@@ -37,7 +36,6 @@ RequiredInterfaces = "0.1.3"
 Statistics = "1"
 StatsBase = "0.33, 0.34"
 StatsFuns = "1.3"
-ThreadsX = "0.1.11"
 julia = "1.10"
 
 [extras]
diff --git a/src/InferOpt.jl b/src/InferOpt.jl
index 6b4b653..9fe32a9 100644
--- a/src/InferOpt.jl
+++ b/src/InferOpt.jl
@@ -25,7 +25,6 @@
 using Random: Random, AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed!
using Statistics: mean using StatsBase: StatsBase, sample using StatsFuns: logaddexp, softmax -using ThreadsX: ThreadsX using RequiredInterfaces include("interface.jl") From 126cc27ec62853383993749365d4bcfab6f79e4d Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Mon, 23 Dec 2024 18:40:54 +0100 Subject: [PATCH 20/26] fix tests --- docs/Manifest.toml | 1318 -------------------------------------------- test/perturbed.jl | 8 +- 2 files changed, 4 insertions(+), 1322 deletions(-) delete mode 100644 docs/Manifest.toml diff --git a/docs/Manifest.toml b/docs/Manifest.toml deleted file mode 100644 index 4bc27fd..0000000 --- a/docs/Manifest.toml +++ /dev/null @@ -1,1318 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.11.1" -manifest_format = "2.0" -project_hash = "6e00f168de8676c54ad2b043cfc74314d50a8935" - -[[deps.ANSIColoredPrinters]] -git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" -uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" -version = "0.0.1" - -[[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" -weakdeps = ["ChainRulesCore", "Test"] - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" - -[[deps.AbstractTrees]] -git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.4.5" - -[[deps.Accessors]] -deps = ["CompositionsBase", "ConstructionBase", "InverseFunctions", "LinearAlgebra", "MacroTools", "Markdown"] -git-tree-sha1 = "96bed9b1b57cf750cca50c311a197e306816a1cc" -uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.39" - - [deps.Accessors.extensions] - AccessorsAxisKeysExt = "AxisKeys" - AccessorsDatesExt = "Dates" - AccessorsIntervalSetsExt = "IntervalSets" - AccessorsStaticArraysExt = "StaticArrays" - AccessorsStructArraysExt = "StructArrays" - AccessorsTestExt = "Test" - AccessorsUnitfulExt = "Unitful" - - [deps.Accessors.weakdeps] - AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5" - Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - Requires = "ae029012-a4dd-5104-9daa-d747884805df" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "50c3c56a52972d78e8be9fd135bfb91c9574c140" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "4.1.1" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.AliasTables]] -deps = ["PtrArrays", "Random"] -git-tree-sha1 = "9876e1e164b144ca45e9e3198d0b689cadfed9ff" -uuid = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" -version = "1.1.3" - -[[deps.ArgCheck]] -git-tree-sha1 = "680b3b8759bd4c54052ada14e52355ab69e07876" -uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" -version = "2.4.0" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.2" - -[[deps.ArnoldiMethod]] -deps = ["LinearAlgebra", "Random", "StaticArrays"] -git-tree-sha1 = "d57bd3762d308bded22c3b82d033bff85f6195c6" -uuid = "ec485272-7323-5ecc-a04f-4719b315124d" -version = "0.4.0" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" -version = "1.11.0" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = 
"c3b238aa28c1bebd4b5ea4988bebf27e9a01b72b" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "1.0.1" - - [deps.Atomix.extensions] - AtomixCUDAExt = "CUDA" - AtomixMetalExt = "Metal" - AtomixoneAPIExt = "oneAPI" - - [deps.Atomix.weakdeps] - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Metal = "dde4c033-4e86-420c-a63e-0dd931031962" - oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" - -[[deps.BangBang]] -deps = ["Accessors", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires"] -git-tree-sha1 = "e2144b631226d9eeab2d746ca8880b7ccff504ae" -uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" -version = "0.4.3" - - [deps.BangBang.extensions] - BangBangChainRulesCoreExt = "ChainRulesCore" - BangBangDataFramesExt = "DataFrames" - BangBangStaticArraysExt = "StaticArrays" - BangBangStructArraysExt = "StructArrays" - BangBangTablesExt = "Tables" - BangBangTypedTablesExt = "TypedTables" - - [deps.BangBang.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" - Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" - TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" -version = "1.11.0" - -[[deps.Baselet]] -git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" -uuid = "9718e550-a3fa-408a-8086-8db961cd8217" -version = "0.1.1" - -[[deps.CEnum]] -git-tree-sha1 = "389ad5c84de1ae7cf0e28e381131c98ea87d54fc" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.5.0" - -[[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "SparseInverseSubset", "Statistics", "StructArrays", "SuiteSparse"] -git-tree-sha1 = "bcffdcaed50d3453673b852f3522404a94b50fad" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.72.1" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra"] -git-tree-sha1 = "3e4b134270b372f2ed4d4d0e936aabaefc1802bc" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.25.0" -weakdeps = ["SparseArrays"] - - [deps.ChainRulesCore.extensions] - ChainRulesCoreSparseArraysExt = "SparseArrays" - -[[deps.ChunkSplitters]] -deps = ["TestItems"] -git-tree-sha1 = "01d5db8756afc4022b1cf267cfede13245226c72" -uuid = "ae650224-84b6-46f8-82ea-d812ca08434e" -version = "2.6.0" - -[[deps.CodeTracking]] -deps = ["InteractiveUtils", "UUIDs"] -git-tree-sha1 = "7eee164f122511d3e4e1ebadb7956939ea7e1c77" -uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2" -version = "1.3.6" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.6" - -[[deps.ColorSchemes]] -deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "c785dfb1b3bfddd1da557e861b919819b82bbe5b" -uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.27.1" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "b10d0b65641d57b8b4d5e234446582de5047050d" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.5" - -[[deps.ColorVectorSpace]] -deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] -git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" -uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" 
-version = "0.10.0" -weakdeps = ["SpecialFunctions"] - - [deps.ColorVectorSpace.extensions] - SpecialFunctionsExt = "SpecialFunctions" - -[[deps.Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "362a287c3aa50601b0bc359053d5c2468f0e7ce0" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.11" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools"] -git-tree-sha1 = "cda2cfaebb4be89c9084adaca7dd7333369715c5" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.1" - -[[deps.Compat]] -deps = ["TOML", "UUIDs"] -git-tree-sha1 = "8ae8d32e09f0dcf42a36b90d4e17f5dd2e4c4215" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.16.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.1.1+0" - -[[deps.CompositionsBase]] -git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" -uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" -version = "0.1.2" -weakdeps = ["InverseFunctions"] - - [deps.CompositionsBase.extensions] - CompositionsBaseInverseFunctionsExt = "InverseFunctions" - -[[deps.ConstructionBase]] -git-tree-sha1 = "76219f1ed5771adbb096743bff43fb5fdd4c1157" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.8" - - [deps.ConstructionBase.extensions] - ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseLinearAlgebraExt = "LinearAlgebra" - ConstructionBaseStaticArraysExt = "StaticArrays" - - [deps.ConstructionBase.weakdeps] - IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" - LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.ContextVariablesX]] -deps = ["Compat", "Logging", "UUIDs"] -git-tree-sha1 = "25cc3803f1030ab855e383129dcd3dc294e322cc" -uuid = "6add18c4-b38d-439d-96f6-d6bc489c04c5" -version = "0.1.3" - -[[deps.Contour]] -git-tree-sha1 = "439e35b0b36e2e5881738abc8857bd92ad6ff9a8" -uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" -version = "0.6.3" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DataAPI]] -git-tree-sha1 = "abe83f3a2f1b857aac70ef8b269080af17764bbe" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.16.0" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "1d0a14036acb104d9e89698bd408f63ab58cdc82" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.20" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" -version = "1.11.0" - -[[deps.DefineSingletons]] -git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" -uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" -version = "0.1.2" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -version = "1.9.1" - -[[deps.DensityInterface]] -deps = ["InverseFunctions", "Test"] -git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" -uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" -version = "0.4.0" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = 
"163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.DifferentiableExpectations]] -deps = ["ChainRulesCore", "Compat", "DensityInterface", "Distributions", "DocStringExtensions", "LinearAlgebra", "OhMyThreads", "Random", "Statistics", "StatsBase"] -git-tree-sha1 = "829dd95b32a41526923f44799ce0762fcd9a3a37" -uuid = "fc55d66b-b2a8-4ccc-9d64-c0c2166ceb36" -version = "0.2.0" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" -version = "1.11.0" - -[[deps.Distributions]] -deps = ["AliasTables", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns"] -git-tree-sha1 = "4b138e4643b577ccf355377c2bc70fa975af25de" -uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.115" -weakdeps = ["ChainRulesCore", "DensityInterface", "Test"] - - [deps.Distributions.extensions] - DistributionsChainRulesCoreExt = "ChainRulesCore" - DistributionsDensityInterfaceExt = "DensityInterface" - DistributionsTestExt = "Test" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Documenter]] -deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"] -git-tree-sha1 = "d0ea2c044963ed6f37703cead7e29f70cba13d7e" -uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "1.8.0" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.EnzymeCore]] -git-tree-sha1 = "0cdb7af5c39e92d78a0ee8d0a447d32f7593137e" -uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" -version = "0.8.8" -weakdeps = ["Adapt"] - - [deps.EnzymeCore.extensions] - AdaptExt = "Adapt" - -[[deps.Expat_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e51db81749b0777b2147fbe7b783ee79045b8e99" -uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.6.4+1" - -[[deps.FLoops]] -deps = ["BangBang", "Compat", "FLoopsBase", "InitialValues", "JuliaVariables", "MLStyle", "Serialization", "Setfield", "Transducers"] -git-tree-sha1 = "0a2e5873e9a5f54abb06418d57a8df689336a660" -uuid = "cc61a311-1640-44b5-9fba-1b764f453329" -version = "0.2.2" - -[[deps.FLoopsBase]] -deps = ["ContextVariablesX"] -git-tree-sha1 = "656f7a6859be8673bf1f35da5670246b923964f7" -uuid = "b9860ae5-e623-471e-878b-f6a53c775ea6" -version = "0.1.1" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" -version = "1.11.0" - -[[deps.FillArrays]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "6a70198746448456524cb442b8af316927ff3e1a" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.13.0" -weakdeps = ["PDMats", "SparseArrays", "Statistics"] - - [deps.FillArrays.extensions] - FillArraysPDMatsExt = "PDMats" - FillArraysSparseArraysExt = "SparseArrays" - FillArraysStatisticsExt = "Statistics" - -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = 
"05882d6995ae5c12bb5f36dd2ed3f61c98cbb172" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.5" - -[[deps.Flux]] -deps = ["Adapt", "ChainRulesCore", "Compat", "EnzymeCore", "Functors", "LinearAlgebra", "MLDataDevices", "MLUtils", "MacroTools", "NNlib", "OneHotArrays", "Optimisers", "Preferences", "ProgressLogging", "Random", "Reexport", "Setfield", "SparseArrays", "SpecialFunctions", "Statistics", "Zygote"] -git-tree-sha1 = "86729467baa309581eb0e648b9ede0aeb40016be" -uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" -version = "0.16.0" - - [deps.Flux.extensions] - FluxAMDGPUExt = "AMDGPU" - FluxCUDAExt = "CUDA" - FluxCUDAcuDNNExt = ["CUDA", "cuDNN"] - FluxEnzymeExt = "Enzyme" - FluxMPIExt = "MPI" - FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"] - - [deps.Flux.weakdeps] - AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" - MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" - NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "a2df1b776752e3f344e5116c06d75a10436ab853" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.38" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.Functors]] -deps = ["Compat", "ConstructionBase", "LinearAlgebra", "Random"] -git-tree-sha1 = "60a0339f28a233601cb74468032b5c302d5067de" -uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.5.2" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" -version = "1.11.0" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "KernelAbstractions", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "4ec797b1b2ee964de0db96f10cce05b81f23e108" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "11.1.0" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "83cf05ab16a73219e5f6bd1bdfa9848fa24ac627" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.2.0" - -[[deps.Git]] -deps = ["Git_jll"] -git-tree-sha1 = "04eff47b1354d702c3a85e8ab23d539bb7d5957e" -uuid = "d7ba0133-e1db-5d97-8f8c-041e4b3a1eb2" -version = "1.3.1" - -[[deps.Git_jll]] -deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "399f4a308c804b446ae4c91eeafadb2fe2c54ff9" -uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb" -version = "2.47.1+0" - -[[deps.Graphs]] -deps = ["ArnoldiMethod", "Compat", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"] -git-tree-sha1 = "1dc470db8b1131cfc7fb4c115de89fe391b9e780" -uuid = "86223c79-3864-5bf0-83f7-82e725a168b6" -version = "1.12.0" - -[[deps.GridGraphs]] -deps = ["DataStructures", "FillArrays", "Graphs", "SparseArrays"] -git-tree-sha1 = "219584c79a649f5ae4301ff0924f235e92cc60d4" -uuid = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb" -version = "0.10.1" - -[[deps.HypergeometricFunctions]] -deps = ["LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "b1c2585431c382e3fe5805874bda6aea90a95de9" -uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.25" - -[[deps.IOCapture]] -deps = ["Logging", "Random"] -git-tree-sha1 = 
"b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770" -uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.5" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "950c3717af761bc3ff906c2e8e52bd83390b6ec2" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.14" - -[[deps.InferOpt]] -deps = ["ChainRulesCore", "DensityInterface", "DifferentiableExpectations", "Distributions", "LinearAlgebra", "Random", "RequiredInterfaces", "Statistics", "StatsBase", "StatsFuns", "ThreadsX"] -path = ".." -uuid = "4846b161-c94e-4150-8dac-c7ae193c601f" -version = "0.6.1" - - [deps.InferOpt.extensions] - InferOptFrankWolfeExt = "DifferentiableFrankWolfe" - - [deps.InferOpt.weakdeps] - DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d" - -[[deps.Inflate]] -git-tree-sha1 = "d1b1b796e47d94588b3757fe84fbf65a5ec4a80d" -uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.5" - -[[deps.InitialValues]] -git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" -uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" -version = "0.3.1" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -version = "1.11.0" - -[[deps.InverseFunctions]] -git-tree-sha1 = "a779299d77cd080bf77b97535acecd73e1c5e5cb" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.17" -weakdeps = ["Dates", "Test"] - - [deps.InverseFunctions.extensions] - InverseFunctionsDatesExt = "Dates" - InverseFunctionsTestExt = "Test" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "be3dc50a92e5a386872a493a10050136d4703f9b" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.6.1" - -[[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" - -[[deps.JuliaInterpreter]] -deps = ["CodeTracking", "InteractiveUtils", "Random", "UUIDs"] -git-tree-sha1 = "10da5154188682e5c0726823c2b5125957ec3778" -uuid = "aa1ae85d-cabe-5617-a682-6adf51b2e16a" -version = "0.9.38" - -[[deps.JuliaVariables]] -deps = ["MLStyle", "NameResolution"] -git-tree-sha1 = "49fb3cb53362ddadb4415e9b73926d6b40709e70" -uuid = "b14d175d-62b4-44ba-8fb7-3064adc8c3ec" -version = "0.2.4" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "MacroTools", "PrecompileTools", "Requires", "StaticArrays", "UUIDs"] -git-tree-sha1 = "b9a838cd3028785ac23822cded5126b3da394d1a" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.31" -weakdeps = ["EnzymeCore", "LinearAlgebra", "SparseArrays"] - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - LinearAlgebraExt = "LinearAlgebra" - SparseArraysExt = "SparseArrays" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Preferences", "Printf", "Unicode"] -git-tree-sha1 = "d422dfd9707bec6617335dc2ea3c5172a87d5908" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "9.1.3" - - [deps.LLVM.extensions] - BFloat16sExt = "BFloat16s" - - [deps.LLVM.weakdeps] - BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = 
"05a8bd5a42309a9ec82f700876903abce1017dd3" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.34+0" - -[[deps.LazilyInitializedFields]] -git-tree-sha1 = "0f2da712350b020bc3957f269c9caad516383ee0" -uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf" -version = "1.3.0" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" -version = "1.11.0" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.6.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" -version = "1.11.0" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.7.2+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" -version = "1.11.0" - -[[deps.Libiconv_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "61dfdba58e585066d8bce214c5a51eaa0539f269" -uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.17.0+1" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -version = "1.11.0" - -[[deps.Literate]] -deps = ["Base64", "IOCapture", "JSON", "REPL"] -git-tree-sha1 = "da046be6d63304f7ba9c1bb04820fb306ba1ab12" -uuid = "98b081ad-f1c9-55d3-8b20-4c87d4299306" -version = "2.20.1" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "13ca9e2586b89836fd20cccf56e57e2b9ae7f38f" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.29" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" -version = "1.11.0" - -[[deps.LoweredCodeUtils]] -deps = ["JuliaInterpreter"] -git-tree-sha1 = "688d6d9e098109051ae33d126fcfc88c4ce4a021" -uuid = "6f1432cf-f94c-5a45-995e-cdbf5db27b0b" -version = "3.1.0" - -[[deps.MLDataDevices]] -deps = ["Adapt", "Compat", "Functors", "Preferences", "Random"] -git-tree-sha1 = "80eb04ae663507d9303473d26710a4c62efa0f3c" -uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40" -version = "1.6.5" - - [deps.MLDataDevices.extensions] - MLDataDevicesAMDGPUExt = "AMDGPU" - MLDataDevicesCUDAExt = "CUDA" - MLDataDevicesChainRulesCoreExt = "ChainRulesCore" - MLDataDevicesChainRulesExt = "ChainRules" - MLDataDevicesComponentArraysExt = "ComponentArrays" - MLDataDevicesFillArraysExt = "FillArrays" - MLDataDevicesGPUArraysExt = "GPUArrays" - MLDataDevicesMLUtilsExt = "MLUtils" - MLDataDevicesMetalExt = ["GPUArrays", "Metal"] - MLDataDevicesOneHotArraysExt = "OneHotArrays" - MLDataDevicesReactantExt = "Reactant" - MLDataDevicesRecursiveArrayToolsExt = "RecursiveArrayTools" - MLDataDevicesReverseDiffExt = "ReverseDiff" 
-    MLDataDevicesSparseArraysExt = "SparseArrays"
-    MLDataDevicesTrackerExt = "Tracker"
-    MLDataDevicesZygoteExt = "Zygote"
-    MLDataDevicescuDNNExt = ["CUDA", "cuDNN"]
-    MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
-
-    [deps.MLDataDevices.weakdeps]
-    AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
-    CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
-    ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
-    ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-    ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
-    FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
-    GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-    MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
-    Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
-    OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
-    Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
-    RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
-    ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
-    SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
-    Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-    Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
-    cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
-    oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
-
-[[deps.MLStyle]]
-git-tree-sha1 = "bc38dff0548128765760c79eb7388a4b37fae2c8"
-uuid = "d8e11817-5142-5d16-987a-aa16d5891078"
-version = "0.4.17"
-
-[[deps.MLUtils]]
-deps = ["ChainRulesCore", "Compat", "DataAPI", "DelimitedFiles", "FLoops", "NNlib", "Random", "ShowCases", "SimpleTraits", "Statistics", "StatsBase", "Tables", "Transducers"]
-git-tree-sha1 = "b45738c2e3d0d402dffa32b2c1654759a2ac35a4"
-uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
-version = "0.4.4"
-
-[[deps.MacroTools]]
-deps = ["Markdown", "Random"]
-git-tree-sha1 = "2fa9ee3e63fd3a4f7a9a4f4744a52f4856de82df"
-uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-version = "0.5.13"
-
-[[deps.MarchingCubes]]
-deps = ["PrecompileTools", "StaticArrays"]
-git-tree-sha1 = "301345b808264ae42e60d10a519e55c5d992969b"
-uuid = "299715c1-40a9-479a-aaf9-4a633d36f717"
-version = "0.1.10"
-
-[[deps.Markdown]]
-deps = ["Base64"]
-uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
-version = "1.11.0"
-
-[[deps.MarkdownAST]]
-deps = ["AbstractTrees", "Markdown"]
-git-tree-sha1 = "465a70f0fc7d443a00dcdc3267a497397b8a3899"
-uuid = "d0879d2d-cac2-40c8-9cee-1863dc0c7391"
-version = "0.1.2"
-
-[[deps.MbedTLS_jll]]
-deps = ["Artifacts", "Libdl"]
-uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
-version = "2.28.6+0"
-
-[[deps.MicroCollections]]
-deps = ["Accessors", "BangBang", "InitialValues"]
-git-tree-sha1 = "44d32db644e84c75dab479f1bc15ee76a1a3618f"
-uuid = "128add7d-3638-4c79-886c-908ea0c25c34"
-version = "0.2.0"
-
-[[deps.Missings]]
-deps = ["DataAPI"]
-git-tree-sha1 = "ec4f7fbeab05d7747bdf98eb74d130a2a2ed298d"
-uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
-version = "1.2.0"
-
-[[deps.Mmap]]
-uuid = "a63ad114-7e13-5084-954f-fe012c677804"
-version = "1.11.0"
-
-[[deps.MozillaCACerts_jll]]
-uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
-version = "2023.12.12"
-
-[[deps.NNlib]]
-deps = ["Adapt", "Atomix", "ChainRulesCore", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "1177f161cda2083543b9967d7ca2a3e24e721e13"
-uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.9.26"
-
-    [deps.NNlib.extensions]
-    NNlibAMDGPUExt = "AMDGPU"
-    NNlibCUDACUDNNExt = ["CUDA", "cuDNN"]
-    NNlibCUDAExt = "CUDA"
-    NNlibEnzymeCoreExt = "EnzymeCore"
-    NNlibFFTWExt = "FFTW"
-    NNlibForwardDiffExt = "ForwardDiff"
-
-    [deps.NNlib.weakdeps]
-    AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NameResolution]] -deps = ["PrettyPrint"] -git-tree-sha1 = "1a0fa0e9613f46c9b8c11eee38ebb4f590013c5e" -uuid = "71a1bf82-56d0-4bbc-8a3c-48b961074391" -version = "0.1.5" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OhMyThreads]] -deps = ["BangBang", "ChunkSplitters", "StableTasks", "TaskLocalValues"] -git-tree-sha1 = "881876fc70ab53ad60671ad4a1af25c920aee0eb" -uuid = "67456a42-1dca-4109-a031-0a68de7e3ad5" -version = "0.5.3" - -[[deps.OneHotArrays]] -deps = ["Adapt", "ChainRulesCore", "Compat", "GPUArraysCore", "LinearAlgebra", "NNlib"] -git-tree-sha1 = "c8c7f6bfabe581dc40b580313a75f1ecce087e27" -uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.6" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.27+1" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+2" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "7493f61f55a6cce7325f197443aa80d32554ba10" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.15+1" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optimisers]] -deps = ["ChainRulesCore", "Functors", "LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = "c5feff34a5cf6bdc6ca06de0c5b7d6847199f1c0" -uuid = "3bd65402-5787-11e9-1adc-39752487f4e2" -version = "0.4.2" -weakdeps = ["Adapt", "EnzymeCore"] - - [deps.Optimisers.extensions] - OptimisersAdaptExt = ["Adapt"] - OptimisersEnzymeCoreExt = "EnzymeCore" - -[[deps.OrderedCollections]] -git-tree-sha1 = "12f1439c4f986bb868acda6ea33ebc78e19b95ad" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.7.0" - -[[deps.PCRE2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" -version = "10.42.0+1" - -[[deps.PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "949347156c25054de2db3b166c52ac4728cbad65" -uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.31" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.1" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "Random", "SHA", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.11.0" -weakdeps = ["REPL"] - - [deps.Pkg.extensions] - REPLExt = "REPL" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = 
"9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.PrettyPrint]] -git-tree-sha1 = "632eb4abab3449ab30c5e1afaa874f0b98b586e4" -uuid = "8162dcfd-2161-5ef2-ae6c-7681170c5f98" -version = "0.2.0" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" -version = "1.11.0" - -[[deps.ProgressLogging]] -deps = ["Logging", "SHA", "UUIDs"] -git-tree-sha1 = "80d919dee55b9c50e8d9e2da5eeafff3fe58b539" -uuid = "33c8b6b6-d38a-422a-b730-caa89a2f386c" -version = "0.1.4" - -[[deps.ProgressMeter]] -deps = ["Distributed", "Printf"] -git-tree-sha1 = "8f6bc219586aef8baf0ff9a5fe16ee9c70cb65e4" -uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.10.2" - -[[deps.PtrArrays]] -git-tree-sha1 = "77a42d78b6a92df47ab37e177b2deac405e1c88f" -uuid = "43287f4e-b6f4-7ad1-bb20-aadabca52c3d" -version = "1.2.1" - -[[deps.QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "cda3b045cf9ef07a08ad46731f5a3165e56cf3da" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.11.1" - - [deps.QuadGK.extensions] - QuadGKEnzymeExt = "Enzyme" - - [deps.QuadGK.weakdeps] - Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "StyledStrings", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" -version = "1.11.0" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -version = "1.11.0" - -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.Referenceables]] -deps = ["Adapt"] -git-tree-sha1 = "02d31ad62838181c1a3a5fd23a1ce5914a643601" -uuid = "42d2dcc6-99eb-4e98-b66c-637b7d73030e" -version = "0.1.3" - -[[deps.RegistryInstances]] -deps = ["LazilyInitializedFields", "Pkg", "TOML", "Tar"] -git-tree-sha1 = "ffd19052caf598b8653b99404058fce14828be51" -uuid = "2792f1a3-b283-48e8-9a74-f99dce5104f3" -version = "0.1.0" - -[[deps.RequiredInterfaces]] -deps = ["InteractiveUtils", "Logging", "Test"] -git-tree-sha1 = "f4e7fec4fa52d0919f18fec552d2fabf9e94811d" -uuid = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6" -version = "0.1.7" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.Revise]] -deps = ["CodeTracking", "Distributed", "FileWatching", "JuliaInterpreter", "LibGit2", "LoweredCodeUtils", "OrderedCollections", "REPL", "Requires", "UUIDs", "Unicode"] -git-tree-sha1 = "470f48c9c4ea2170fd4d0f8eb5118327aada22f5" -uuid = "295af30f-e4ad-537b-8983-00126c2a3abe" -version = "3.6.4" - -[[deps.Rmath]] -deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "852bd0f55565a9e973fcfee83a84413270224dc4" -uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.8.0" - -[[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "58cdd8fb2201a6267e1db87ff148dd6c1dbd8ad8" -uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.5.1+0" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" -version = "1.11.0" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = 
"e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" -version = "1.11.0" - -[[deps.ShowCases]] -git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" -uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" -version = "0.1.0" - -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" -version = "1.11.0" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "66e0a8e672a0bdfca2c3f5937efb8538b9ddc085" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.2.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -version = "1.11.0" - -[[deps.SparseInverseSubset]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "52962839426b75b3021296f7df242e40ecfc0852" -uuid = "dc90abb0-5640-4711-901d-7e5b23a2fada" -version = "0.1.2" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "64cca0c26b4f31ba18f13f6c12af7c85f478cfde" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.5.0" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.SplittablesBase]] -deps = ["Setfield", "Test"] -git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" -uuid = "171d559e-b47b-412a-8079-5efa626c420e" -version = "0.1.15" - -[[deps.StableTasks]] -git-tree-sha1 = "073d5c20d44129b20fe954720b97069579fa403b" -uuid = "91464d47-22a1-43fe-8b7f-2d57ee82463f" -version = "0.1.5" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "PrecompileTools", "Random", "StaticArraysCore"] -git-tree-sha1 = "7c01731da8ab6d3094c4d44c9057b00932459255" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.9.9" -weakdeps = ["ChainRulesCore", "Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysChainRulesCoreExt = "ChainRulesCore" - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "192954ef1208c7019899fbf8049e717f92959682" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.3" - -[[deps.Statistics]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "ae3bb1eb3bba077cd276bc5cfc337cc65c3075c0" -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.11.1" -weakdeps = ["SparseArrays"] - - [deps.Statistics.extensions] - SparseArraysExt = ["SparseArrays"] - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.7.0" - -[[deps.StatsBase]] -deps = ["AliasTables", "DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "29321314c920c26684834965ec2ce0dacc9cf8e5" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.4" - -[[deps.StatsFuns]] -deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "b423576adc27097764a90e163157bcfc9acf0f46" -uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = 
"1.3.2" -weakdeps = ["ChainRulesCore", "InverseFunctions"] - - [deps.StatsFuns.extensions] - StatsFunsChainRulesCoreExt = "ChainRulesCore" - StatsFunsInverseFunctionsExt = "InverseFunctions" - -[[deps.StructArrays]] -deps = ["ConstructionBase", "DataAPI", "Tables"] -git-tree-sha1 = "9537ef82c42cdd8c5d443cbc359110cbb36bae10" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.21" -weakdeps = ["Adapt", "GPUArraysCore", "KernelAbstractions", "LinearAlgebra", "SparseArrays", "StaticArrays"] - - [deps.StructArrays.extensions] - StructArraysAdaptExt = "Adapt" - StructArraysGPUArraysCoreExt = ["GPUArraysCore", "KernelAbstractions"] - StructArraysLinearAlgebraExt = "LinearAlgebra" - StructArraysSparseArraysExt = "SparseArrays" - StructArraysStaticArraysExt = "StaticArrays" - -[[deps.StyledStrings]] -uuid = "f489334b-da3d-4c2e-b8f0-e476e12c162b" -version = "1.11.0" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.7.0+0" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "OrderedCollections", "TableTraits"] -git-tree-sha1 = "598cd7c1f68d1e205689b1c2fe65a9f85846f297" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.12.0" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.TaskLocalValues]] -git-tree-sha1 = "d155450e6dff2a8bc2fcb81dcb194bd98b0aeb46" -uuid = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34" -version = "0.1.2" - -[[deps.TensorCore]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" -uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" -version = "0.1.1" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -version = "1.11.0" - -[[deps.TestItems]] -git-tree-sha1 = "42fd9023fef18b9b78c8343a4e2f3813ffbcefcb" -uuid = "1c621080-faea-4a02-84b6-bbd5e436b8fe" -version = "1.0.0" - -[[deps.ThreadsX]] -deps = ["Accessors", "ArgCheck", "BangBang", "ConstructionBase", "InitialValues", "MicroCollections", "Referenceables", "SplittablesBase", "Transducers"] -git-tree-sha1 = "70bd8244f4834d46c3d68bd09e7792d8f571ef04" -uuid = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d" -version = "0.1.12" - -[[deps.TranscodingStreams]] -git-tree-sha1 = "0c45878dcfdcfa8480052b6ab162cdd138781742" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.3" - -[[deps.Transducers]] -deps = ["Accessors", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "SplittablesBase", "Tables"] -git-tree-sha1 = "7deeab4ff96b85c5f72c824cae53a1398da3d1cb" -uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" -version = "0.4.84" - - [deps.Transducers.extensions] - TransducersAdaptExt = "Adapt" - TransducersBlockArraysExt = "BlockArrays" - TransducersDataFramesExt = "DataFrames" - TransducersLazyArraysExt = "LazyArrays" - TransducersOnlineStatsBaseExt = "OnlineStatsBase" - 
-
-    [deps.Transducers.weakdeps]
-    Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-    BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
-    DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
-    LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
-    OnlineStatsBase = "925886fa-5bf2-5e8e-b522-a9147a512338"
-    Referenceables = "42d2dcc6-99eb-4e98-b66c-637b7d73030e"
-
-[[deps.UUIDs]]
-deps = ["Random", "SHA"]
-uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
-version = "1.11.0"
-
-[[deps.Unicode]]
-uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
-version = "1.11.0"
-
-[[deps.UnicodePlots]]
-deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "PrecompileTools", "Printf", "SparseArrays", "StaticArrays", "StatsBase"]
-git-tree-sha1 = "f18128aa9e5cf059426a91bdc750b1f63a2fdcd9"
-uuid = "b8865327-cd53-5732-bb35-84acbb429228"
-version = "3.7.1"
-
-    [deps.UnicodePlots.extensions]
-    FreeTypeExt = ["FileIO", "FreeType"]
-    ImageInTerminalExt = "ImageInTerminal"
-    IntervalSetsExt = "IntervalSets"
-    TermExt = "Term"
-    UnitfulExt = "Unitful"
-
-    [deps.UnicodePlots.weakdeps]
-    FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
-    FreeType = "b38be410-82b0-50bf-ab77-7b57e271db43"
-    ImageInTerminal = "d8c32880-2388-543b-8c61-d9f865259254"
-    IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
-    Term = "22787eb5-b846-44ae-b979-8e399b8463ab"
-    Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
-
-[[deps.UnsafeAtomics]]
-git-tree-sha1 = "b13c4edda90890e5b04ba24e20a310fbe6f249ff"
-uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
-version = "0.3.0"
-weakdeps = ["LLVM"]
-
-    [deps.UnsafeAtomics.extensions]
-    UnsafeAtomicsLLVM = ["LLVM"]
-
-[[deps.Zlib_jll]]
-deps = ["Libdl"]
-uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
-version = "1.2.13+1"
-
-[[deps.Zygote]]
-deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"]
-git-tree-sha1 = "c7dc3148a64d1cd3768c29b3db5972d1c302661b"
-uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.73"
-
-    [deps.Zygote.extensions]
-    ZygoteColorsExt = "Colors"
-    ZygoteDistancesExt = "Distances"
-    ZygoteTrackerExt = "Tracker"
-
-    [deps.Zygote.weakdeps]
-    Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
-    Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
-    Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-
-[[deps.ZygoteRules]]
-deps = ["ChainRulesCore", "MacroTools"]
-git-tree-sha1 = "27798139afc0a2afa7b1824c206d5e87ea587a00"
-uuid = "700de1a5-db45-46bc-99cf-38207098b444"
-version = "0.2.5"
-
-[[deps.libblastrampoline_jll]]
-deps = ["Artifacts", "Libdl"]
-uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
-version = "5.11.0+0"
-
-[[deps.nghttp2_jll]]
-deps = ["Artifacts", "Libdl"]
-uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
-version = "1.59.0+0"
-
-[[deps.p7zip_jll]]
-deps = ["Artifacts", "Libdl"]
-uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
-version = "17.4.0+2"
diff --git a/test/perturbed.jl b/test/perturbed.jl
index 61d5399..1d27201 100644
--- a/test/perturbed.jl
+++ b/test/perturbed.jl
@@ -23,7 +23,7 @@
         @test all(diag(jac1) .>= 0)
        @test all(jac1 - Diagonal(jac1) .<= 0)
         # Order of diagonal coefficients should follow order of θ
-        @test sortperm(diag(jac1)) == sortperm(θ)
+        @test sortperm(diag(jac1_big)) == sortperm(θ)
         # No scaling with nb of samples
         @test norm(jac1) ≈ norm(jac1_big) rtol = 5e-2
     end
@@ -33,9 +33,9 @@
         jac2_big = Zygote.jacobian(
             θ -> perturbed2_big(θ; autodiff_variance_reduction=false), θ
         )[1]
-        @test all(diag(jac2) .>= 0)
-        @test all(jac2 - Diagonal(jac2) .<= 0)
-        @test sortperm(diag(jac2)) != sortperm(θ)
+        @test all(diag(jac2_big) .>= 0)
+        @test all(jac2_big - Diagonal(jac2_big) .<= 0)
+        @test sortperm(diag(jac2_big)) == sortperm(θ)
         @test norm(jac2) ≈ norm(jac2_big) rtol = 5e-2
     end
 end

From 43d98ebd487a3a32dc10556536075e990d319930 Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Mon, 23 Dec 2024 18:57:28 +0100
Subject: [PATCH 21/26] apply_h not needed

---
 src/utils/linear_maximizer.jl | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/utils/linear_maximizer.jl b/src/utils/linear_maximizer.jl
index 2d4f538..c422d02 100644
--- a/src/utils/linear_maximizer.jl
+++ b/src/utils/linear_maximizer.jl
@@ -66,11 +66,11 @@ function apply_g(f::LinearMaximizer, y; kwargs...)
     return f.g(y; kwargs...)
 end
 
-"""
-$TYPEDSIGNATURES
+# """
+# $TYPEDSIGNATURES
 
-Applies the function `h` of the LinearMaximizer `f` to `y`.
-"""
-function apply_h(f::LinearMaximizer, y; kwargs...)
-    return f.h(y; kwargs...)
-end
+# Applies the function `h` of the LinearMaximizer `f` to `y`.
+# """
+# function apply_h(f::LinearMaximizer, y; kwargs...)
+#     return f.h(y; kwargs...)
+# end

From b6cc8951672d4073ae55886a1c9d51facb384d9f Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Mon, 23 Dec 2024 19:12:02 +0100
Subject: [PATCH 22/26] fix exports

---
 src/InferOpt.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/InferOpt.jl b/src/InferOpt.jl
index 9fe32a9..70aa6e2 100644
--- a/src/InferOpt.jl
+++ b/src/InferOpt.jl
@@ -65,7 +65,7 @@ include("losses/imitation_loss.jl")
 
 export half_square_norm
 export shannon_entropy, negative_shannon_entropy
 export one_hot_argmax, ranking
-export LinearMaximizer, apply_g, apply_h, objective_value
+export LinearMaximizer, apply_g, objective_value
 
 export Pushforward

From 1c0510f8af5e7e70f0f3a428461f3d86cf16f258 Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Tue, 24 Dec 2024 12:02:02 +0100
Subject: [PATCH 23/26] fix bug in perturbed multiplicative

---
 src/layers/perturbed/perturbation.jl | 19 +++++++++++
 src/layers/perturbed/perturbed.jl    |  4 +--
 src/utils/utils.jl                   | 13 ++++++++
 test/perturbed.jl                    | 49 ++++++++++++++--------------
 4 files changed, 58 insertions(+), 27 deletions(-)

diff --git a/src/layers/perturbed/perturbation.jl b/src/layers/perturbed/perturbation.jl
index df64426..28cbad0 100644
--- a/src/layers/perturbed/perturbation.jl
+++ b/src/layers/perturbed/perturbation.jl
@@ -44,6 +44,15 @@ function (pdc::AdditivePerturbation)(θ::AbstractArray)
     return product_distribution(θ .+ ε * perturbation_dist)
 end
 
+"""
+$TYPEDSIGNATURES
+
+Compute the gradient of the logdensity of η = θ + εZ w.r.t. θ, with Z ∼ N(0, 1).
+"""
+function normal_additive_grad_logdensity(ε, η, θ)
+    return ((η .- θ) ./ ε^2,)
+end
+
 """
 $TYPEDEF
 
@@ -68,3 +77,13 @@ function (pdc::MultiplicativePerturbation)(θ::AbstractArray)
     (; perturbation_dist, ε) = pdc
     return product_distribution(θ .* ExponentialOf(ε * perturbation_dist - ε^2 / 2))
 end
+"""
+$TYPEDSIGNATURES
+
+Compute the gradient of the logdensity of η = θ ⊙ exp(εZ - ε²/2) w.r.t. θ, with Z ∼ N(0, 1).
+!!! warning
+    η should be a relization of θ, i.e. should be of the same sign.
+""" +function normal_multiplicative_grad_logdensity(ε, η, θ) + return (inv.(ε^2 .* θ) .* (log.(abs.(η)) - log.(abs.(θ)) .+ (ε^2 / 2)),) +end diff --git a/src/layers/perturbed/perturbed.jl b/src/layers/perturbed/perturbed.jl index 1fd4702..2bbb3c1 100644 --- a/src/layers/perturbed/perturbed.jl +++ b/src/layers/perturbed/perturbed.jl @@ -93,7 +93,7 @@ function PerturbedAdditive( threaded=false, rng=Random.default_rng(), dist_logdensity_grad=if (perturbation_dist == Normal(0, 1)) - (η, θ) -> ((η .- θ) ./ ε^2,) + FixFirst(normal_additive_grad_logdensity, ε) else nothing end, @@ -126,7 +126,7 @@ function PerturbedMultiplicative( threaded=false, rng=Random.default_rng(), dist_logdensity_grad=if (perturbation_dist == Normal(0, 1)) - (η, θ) -> (inv.(ε^2 .* θ) .* (η .- θ),) + FixFirst(normal_multiplicative_grad_logdensity, ε) else nothing end, diff --git a/src/utils/utils.jl b/src/utils/utils.jl index 3a125d2..36b0192 100644 --- a/src/utils/utils.jl +++ b/src/utils/utils.jl @@ -33,3 +33,16 @@ struct Fix1Kwargs{F,K,T} <: Function end (fk::Fix1Kwargs)(args...) = fk.f(fk.x, args...; fk.kwargs...) + +""" +$TYPEDEF + +Callable struct that fixes the first argument of `f` to `x`. +Compared to Base.Fix1, works on functions with more than two arguments. +""" +struct FixFirst{F,T} + f::F + x::T +end + +(fk::FixFirst)(args...) = fk.f(fk.x, args...) diff --git a/test/perturbed.jl b/test/perturbed.jl index 1d27201..ddd5b48 100644 --- a/test/perturbed.jl +++ b/test/perturbed.jl @@ -6,19 +6,16 @@ θ = [3, 5, 4, 2] - perturbed1 = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=1_000, seed=0) - perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=10_000, seed=0) - perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=0.5, nb_samples=1_000, seed=0) - perturbed2_big = PerturbedMultiplicative( - one_hot_argmax; ε=0.5, nb_samples=10_000, seed=0 - ) + perturbed1 = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=1e4, seed=0) + perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=1.0, nb_samples=1e6, seed=0) + + perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=1e4, seed=0) + perturbed2_big = PerturbedMultiplicative(one_hot_argmax; ε=1.0, nb_samples=1e6, seed=0) @testset "PerturbedAdditive" begin # Compute jacobian with reverse mode - jac1 = Zygote.jacobian(θ -> perturbed1(θ; autodiff_variance_reduction=false), θ)[1] - jac1_big = Zygote.jacobian( - θ -> perturbed1_big(θ; autodiff_variance_reduction=false), θ - )[1] + jac1 = Zygote.jacobian(perturbed1, θ)[1] + jac1_big = Zygote.jacobian(perturbed1_big, θ)[1] # Only diagonal should be positive @test all(diag(jac1) .>= 0) @test all(jac1 - Diagonal(jac1) .<= 0) @@ -29,13 +26,12 @@ end @testset "PerturbedMultiplicative" begin - jac2 = Zygote.jacobian(θ -> perturbed2(θ; autodiff_variance_reduction=false), θ)[1] - jac2_big = Zygote.jacobian( - θ -> perturbed2_big(θ; autodiff_variance_reduction=false), θ - )[1] + jac2 = Zygote.jacobian(perturbed2, θ)[1] + jac2_big = Zygote.jacobian(perturbed2_big, θ)[1] @test all(diag(jac2_big) .>= 0) @test all(jac2_big - Diagonal(jac2_big) .<= 0) - @test sortperm(diag(jac2_big)) == sortperm(θ) + @info diag(jac2_big) + @test_broken sortperm(diag(jac2_big)) == sortperm(θ) @test norm(jac2) ≈ norm(jac2_big) rtol = 5e-2 end end @@ -99,18 +95,21 @@ end ε = 1e-12 - function already_differentiable(θ) - return 2 ./ exp.(θ) .* θ .^ 2 - end + already_differentiable(θ) = 2 ./ exp.(θ) .* θ .^ 2 .+ sum(θ) + pa = PerturbedAdditive(already_differentiable; ε, nb_samples=1e6, seed=0) + pm = 
 
-    θ = randn(5)
-    Jz = jacobian(already_differentiable, θ)[1]
+    θ = [1.0, 2.0, 3.0, 4.0, 5.0]
 
-    pa = PerturbedAdditive(already_differentiable; ε, nb_samples=1e6, seed=0)
-    Ja = jacobian(pa, θ)[1]
-    @test_broken all(isapprox.(Ja, Jz, rtol=0.01))
+    fz = already_differentiable(θ)
+    fa = pa(θ)
+    fm = pm(θ)
+    @test fz ≈ fa rtol = 0.01
+    @test fz ≈ fm rtol = 0.01
 
-    pm = PerturbedMultiplicative(already_differentiable; ε, nb_samples=1e6, seed=0)
+    Jz = jacobian(already_differentiable, θ)[1]
+    Ja = jacobian(pa, θ)[1]
     Jm = jacobian(pm, θ)[1]
-    @test_broken all(isapprox.(Jm, Jz, rtol=0.01))
+    @test Ja ≈ Jz rtol = 0.01
+    @test Jm ≈ Jz rtol = 0.01
 end

From c5c1bfdd35db89ff8bb80de5ced9dcb9f0277677 Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Tue, 24 Dec 2024 12:21:55 +0100
Subject: [PATCH 24/26] bump DifferentiableFrankWolfe compat

---
 Project.toml                         | 2 +-
 src/layers/perturbed/perturbation.jl | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/Project.toml b/Project.toml
index 947003a..636745e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -27,7 +27,7 @@ InferOptFrankWolfeExt = "DifferentiableFrankWolfe"
 ChainRulesCore = "1"
 DensityInterface = "0.4.0"
 DifferentiableExpectations = "0.2"
-DifferentiableFrankWolfe = "0.2"
+DifferentiableFrankWolfe = "0.3"
 Distributions = "0.25"
 DocStringExtensions = "0.9.3"
 LinearAlgebra = "<0.0.1,1"
diff --git a/src/layers/perturbed/perturbation.jl b/src/layers/perturbed/perturbation.jl
index 28cbad0..1c42877 100644
--- a/src/layers/perturbed/perturbation.jl
+++ b/src/layers/perturbed/perturbation.jl
@@ -82,7 +82,7 @@ $TYPEDSIGNATURES
 
 Compute the gradient of the logdensity of η = θ ⊙ exp(εZ - ε²/2) w.r.t. θ, with Z ∼ N(0, 1).
 !!! warning
-    η should be a relization of θ, i.e. should be of the same sign.
+    η should be a realization of θ, i.e. should be of the same sign.
 """
 function normal_multiplicative_grad_logdensity(ε, η, θ)
     return (inv.(ε^2 .* θ) .* (log.(abs.(η)) - log.(abs.(θ)) .+ (ε^2 / 2)),)

From c2615cde95e65aa2897b82ac80d12b91e888b6b1 Mon Sep 17 00:00:00 2001
From: BatyLeo
Date: Tue, 24 Dec 2024 12:42:32 +0100
Subject: [PATCH 25/26] revert to 0.2 for now

---
 Project.toml                  |  2 +-
 src/utils/linear_maximizer.jl | 10 ----------
 src/utils/pushforward.jl      | 13 +++++--------
 src/utils/some_functions.jl   | 16 ++++++++--------
 4 files changed, 14 insertions(+), 27 deletions(-)

diff --git a/Project.toml b/Project.toml
index 636745e..947003a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -27,7 +27,7 @@ InferOptFrankWolfeExt = "DifferentiableFrankWolfe"
 ChainRulesCore = "1"
 DensityInterface = "0.4.0"
 DifferentiableExpectations = "0.2"
-DifferentiableFrankWolfe = "0.3"
+DifferentiableFrankWolfe = "0.2"
 Distributions = "0.25"
 DocStringExtensions = "0.9.3"
 LinearAlgebra = "<0.0.1,1"
diff --git a/src/utils/linear_maximizer.jl b/src/utils/linear_maximizer.jl
index c422d02..dd849d3 100644
--- a/src/utils/linear_maximizer.jl
+++ b/src/utils/linear_maximizer.jl
@@ -45,7 +45,6 @@ end
 # default is oracles of the form argmax_y θᵀy
 objective_value(::Any, θ, y; kwargs...) = dot(θ, y)
 apply_g(::Any, y; kwargs...) = y
-# apply_h(::Any, y; kwargs...) = zero(eltype(y)) is not needed
 
 """
 $TYPEDSIGNATURES
@@ -65,12 +64,3 @@ Applies the function `g` of the LinearMaximizer `f` to `y`.
 function apply_g(f::LinearMaximizer, y; kwargs...)
     return f.g(y; kwargs...)
 end
-
-# """
-# $TYPEDSIGNATURES
-
-# Applies the function `h` of the LinearMaximizer `f` to `y`.
-# """ -# function apply_h(f::LinearMaximizer, y; kwargs...) -# return f.h(y; kwargs...) -# end diff --git a/src/utils/pushforward.jl b/src/utils/pushforward.jl index 8248f08..d1d600d 100644 --- a/src/utils/pushforward.jl +++ b/src/utils/pushforward.jl @@ -1,18 +1,17 @@ """ - Pushforward <: AbstractLayer +$TYPEDEF Differentiable pushforward of a probabilistic optimization layer with an arbitrary function post-processing function. `Pushforward` can be used for direct regret minimization (aka learning by experience) when the post-processing returns a cost. # Fields -- `optimization_layer::AbstractOptimizationLayer`: probabilistic optimization layer -- `post_processing`: callable - -See also: `FixedAtomsProbabilityDistribution`. +$TYPEDFIELDS """ struct Pushforward{O<:AbstractOptimizationLayer,P} <: AbstractLayer + "probabilistic optimization layer" optimization_layer::O + "callable" post_processing::P end @@ -22,13 +21,11 @@ function Base.show(io::IO, pushforward::Pushforward) end """ - (pushforward::Pushforward)(θ; kwargs...) +$TYPEDSIGNATURES Output the expectation of `pushforward.post_processing(X)`, where `X` follows the distribution defined by `pushforward.optimization_layer` applied to `θ`. This function is differentiable, even if `pushforward.post_processing` isn't. - -See also: `compute_expectation`. """ function (pushforward::Pushforward)(θ::AbstractArray; kwargs...) (; optimization_layer, post_processing) = pushforward diff --git a/src/utils/some_functions.jl b/src/utils/some_functions.jl index 4512ab2..7672a65 100644 --- a/src/utils/some_functions.jl +++ b/src/utils/some_functions.jl @@ -1,26 +1,26 @@ """ - positive_part(x) +$TYPEDSIGNATURES -Compute `max(x,0)`. +Compute `max(x, 0)`. """ positive_part(x) = x >= zero(x) ? x : zero(x) """ - isproba(x) +$TYPEDSIGNATURES Check whether `x ∈ [0,1]`. """ isproba(x::Real) = zero(x) <= x <= one(x) """ - isprobadist(p) +$TYPEDSIGNATURES Check whether the elements of `p` are nonnegative and sum to 1. """ isprobadist(p::AbstractVector{R}) where {R<:Real} = all(isproba, p) && sum(p) ≈ one(R) """ - half_square_norm(x) +$TYPEDSIGNATURES Compute the squared Euclidean norm of `x` and divide it by 2. """ @@ -29,7 +29,7 @@ function half_square_norm(x::AbstractArray) end """ - shannon_entropy(p) +$TYPEDSIGNATURES Compute the Shannon entropy of a probability distribution: `H(p) = -∑ pᵢlog(pᵢ)`. """ @@ -46,7 +46,7 @@ end negative_shannon_entropy(p::AbstractVector) = -shannon_entropy(p) """ - one_hot_argmax(z) +$TYPEDSIGNATURES One-hot encoding of the argmax function. """ @@ -57,7 +57,7 @@ function one_hot_argmax(z::AbstractVector{R}; kwargs...) where {R<:Real} end """ - ranking(θ[; rev]) +$TYPEDSIGNATURES Compute the vector `r` such that `rᵢ` is the rank of `θᵢ` in `θ`. """ From f292ea012763dff2ce18096ec67e6fb0b8852324 Mon Sep 17 00:00:00 2001 From: BatyLeo Date: Wed, 25 Dec 2024 17:38:16 +0100 Subject: [PATCH 26/26] clenup --- test/perturbed.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/perturbed.jl b/test/perturbed.jl index ddd5b48..2579828 100644 --- a/test/perturbed.jl +++ b/test/perturbed.jl @@ -30,7 +30,6 @@ jac2_big = Zygote.jacobian(perturbed2_big, θ)[1] @test all(diag(jac2_big) .>= 0) @test all(jac2_big - Diagonal(jac2_big) .<= 0) - @info diag(jac2_big) @test_broken sortperm(diag(jac2_big)) == sortperm(θ) @test norm(jac2) ≈ norm(jac2_big) rtol = 5e-2 end