General overhaul #120

Open · wants to merge 27 commits into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -1,3 +1,4 @@
.DS_Store
*.jl.*.cov
*.jl.cov
*.jl.mem
12 changes: 8 additions & 4 deletions Project.toml
@@ -6,14 +6,16 @@ version = "0.6.1"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
DifferentiableExpectations = "fc55d66b-b2a8-4ccc-9d64-c0c2166ceb36"
DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"

[weakdeps]
DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d"
@@ -24,14 +26,16 @@ InferOptFrankWolfeExt = "DifferentiableFrankWolfe"
[compat]
ChainRulesCore = "1"
DensityInterface = "0.4.0"
DifferentiableExpectations = "0.2"
DifferentiableFrankWolfe = "0.2"
LinearAlgebra = "1"
Random = "1"
Distributions = "0.25"
DocStringExtensions = "0.9.3"
LinearAlgebra = "<0.0.1,1"
Random = "<0.0.1,1"
RequiredInterfaces = "0.1.3"
Statistics = "1"
StatsBase = "0.33, 0.34"
StatsFuns = "1.3"
ThreadsX = "0.1.11"
julia = "1.10"

[extras]
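For context, the `[weakdeps]` and `[extensions]` entries in this file mean the Frank-Wolfe glue code only loads when `DifferentiableFrankWolfe` is installed alongside InferOpt. A minimal sketch of what that looks like from the user side (the `Base.get_extension` check is illustrative, not part of this PR):

```julia
using InferOpt                  # loads the core package only
using DifferentiableFrankWolfe  # its presence triggers loading of InferOptFrankWolfeExt

# On Julia ≥ 1.9 the extension module can be retrieved explicitly:
ext = Base.get_extension(InferOpt, :InferOptFrankWolfeExt)
@assert ext !== nothing
```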
27 changes: 21 additions & 6 deletions docs/make.jl
@@ -6,11 +6,20 @@ DocMeta.setdocmeta!(InferOpt, :DocTestSetup, :(using InferOpt); recursive=true)

# Copy README.md into docs/src/index.md (overwriting)

cp(
joinpath(dirname(@__DIR__), "README.md"),
joinpath(@__DIR__, "src", "index.md");
force=true,
)
open(joinpath(@__DIR__, "src", "index.md"), "w") do io
println(
io,
"""
```@meta
EditURL = "https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl/blob/main/README.md"
```
""",
)
# Write the contents out below the meta block
for line in eachline(joinpath(dirname(@__DIR__), "README.md"))
println(io, line)
end
end

# Parse test/tutorial.jl into docs/src/tutorial.md (overwriting)

@@ -21,8 +30,14 @@ Literate.markdown(tuto_jl_file, tuto_md_dir; documenter=true, execute=false)
makedocs(;
modules=[InferOpt],
authors="Guillaume Dalle, Léo Baty, Louis Bouvier, Axel Parmentier",
repo="https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl/blob/{commit}{path}#{line}",
sitename="InferOpt.jl",
format=Documenter.HTML(),
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://juliadecisionfocusedlearning.github.io/InferOpt.jl",
assets=String[],
repolink="https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl",
),
pages=[
"Home" => "index.md",
"Background" => "background.md",
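The `prettyurls` setting added above follows the usual Documenter pattern: pretty directory-style URLs on CI, plain `.html` links for local builds. A quick way to check what a given environment will produce (this snippet is illustrative only):

```julia
# Same expression as in make.jl above; GitHub Actions sets ENV["CI"] = "true".
prettyurls = get(ENV, "CI", "false") == "true"
@show prettyurls  # false for a local build, true on CI
```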
5 changes: 3 additions & 2 deletions ext/InferOptFrankWolfeExt.jl
@@ -1,10 +1,11 @@
module InferOptFrankWolfeExt

using DifferentiableExpectations:
DifferentiableExpectations, FixedAtomsProbabilityDistribution
using DifferentiableFrankWolfe: DifferentiableFrankWolfe, DiffFW
using DifferentiableFrankWolfe: LinearMinimizationOracle # from FrankWolfe
using DifferentiableFrankWolfe: IterativeLinearSolver # from ImplicitDifferentiation
using InferOpt: InferOpt, RegularizedFrankWolfe, FixedAtomsProbabilityDistribution
using InferOpt: compute_expectation, compute_probability_distribution
using InferOpt: InferOpt, RegularizedFrankWolfe
using LinearAlgebra: dot

"""
61 changes: 34 additions & 27 deletions src/InferOpt.jl
@@ -10,71 +10,78 @@ module InferOpt
using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, Tangent, ZeroTangent
using ChainRulesCore: rrule, rrule_via_ad, unthunk
using DensityInterface: logdensityof
using DifferentiableExpectations:
DifferentiableExpectations, Reinforce, empirical_predistribution, empirical_distribution
using Distributions:
Distributions,
ContinuousUnivariateDistribution,
LogNormal,
Normal,
product_distribution,
logpdf
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
using LinearAlgebra: dot
using Random: AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed!
using Random: Random, AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed!
using Statistics: mean
using StatsBase: StatsBase, sample
using StatsFuns: logaddexp, softmax
using ThreadsX: ThreadsX
using RequiredInterfaces

include("interface.jl")

include("utils/utils.jl")
include("utils/some_functions.jl")
include("utils/probability_distribution.jl")
include("utils/pushforward.jl")
include("utils/generalized_maximizer.jl")
include("utils/linear_maximizer.jl")
include("utils/isotonic_regression/isotonic_l2.jl")
include("utils/isotonic_regression/isotonic_kl.jl")
include("utils/isotonic_regression/projection.jl")

include("simple/interpolation.jl")
include("simple/identity.jl")
# Layers
include("layers/simple/interpolation.jl")
include("layers/simple/identity.jl")

include("regularized/abstract_regularized.jl")
include("regularized/soft_argmax.jl")
include("regularized/sparse_argmax.jl")
include("regularized/soft_rank.jl")
include("regularized/regularized_frank_wolfe.jl")
include("layers/perturbed/utils.jl")
include("layers/perturbed/perturbation.jl")
include("layers/perturbed/perturbed.jl")

include("perturbed/abstract_perturbed.jl")
include("perturbed/additive.jl")
include("perturbed/multiplicative.jl")
include("perturbed/perturbed_oracle.jl")

include("imitation/spoplus_loss.jl")
include("imitation/ssvm_loss.jl")
include("imitation/fenchel_young_loss.jl")
include("imitation/imitation_loss.jl")
include("imitation/zero_one_loss.jl")
include("layers/regularized/abstract_regularized.jl")
include("layers/regularized/soft_argmax.jl")
include("layers/regularized/sparse_argmax.jl")
include("layers/regularized/soft_rank.jl")
include("layers/regularized/regularized_frank_wolfe.jl")

if !isdefined(Base, :get_extension)
include("../ext/InferOptFrankWolfeExt.jl")
end

# Losses
include("losses/fenchel_young_loss.jl")
include("losses/spoplus_loss.jl")
include("losses/ssvm_loss.jl")
include("losses/zero_one_loss.jl")
include("losses/imitation_loss.jl")

export half_square_norm
export shannon_entropy, negative_shannon_entropy
export one_hot_argmax, ranking
export GeneralizedMaximizer, objective_value
export LinearMaximizer, apply_g, objective_value

export FixedAtomsProbabilityDistribution
export compute_expectation
export compute_probability_distribution
export Pushforward

export IdentityRelaxation
export Interpolation

export AbstractRegularized, AbstractRegularizedGeneralizedMaximizer
export AbstractRegularized
export SoftArgmax, soft_argmax
export SparseArgmax, sparse_argmax
export SoftRank, soft_rank, soft_rank_l2, soft_rank_kl
export SoftSort, soft_sort, soft_sort_l2, soft_sort_kl
export RegularizedFrankWolfe

export PerturbedOracle
export PerturbedAdditive
export PerturbedMultiplicative
export PerturbedOracle

export FenchelYoungLoss
export StructuredSVMLoss
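Taken together, the reorganized exports keep the usual InferOpt pipeline of wrapping a maximizer in a probabilistic layer and a loss. A rough usage sketch (constructor keywords are assumptions for illustration, not taken from this diff):

```julia
using InferOpt

# Toy maximizer: one-hot argmax over θ (exported above as `one_hot_argmax`).
maximizer = one_hot_argmax

# Keyword names below are illustrative; check the package docs for the exact API.
layer = PerturbedAdditive(maximizer; ε=1.0, nb_samples=10)
loss = FenchelYoungLoss(layer)

θ = randn(5)
y_true = one_hot_argmax(θ)
loss(θ, y_true)  # scalar loss, differentiable with respect to θ
```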
169 changes: 0 additions & 169 deletions src/imitation/fenchel_young_loss.jl

This file was deleted.

6 changes: 4 additions & 2 deletions src/interface.jl
@@ -18,11 +18,13 @@ abstract type AbstractLayer end
Supertype for all the optimization layers defined in InferOpt.

# Interface
- `(layer::AbstractOptimizationLayer)(θ; kwargs...)`
- `(layer::AbstractOptimizationLayer)(θ::AbstractArray; kwargs...)`
- `compute_probability_distribution(layer, θ; kwargs...)` (only if the layer is probabilistic)
"""
abstract type AbstractOptimizationLayer <: AbstractLayer end

get_maximizer(layer::AbstractOptimizationLayer) = nothing

## Losses

"""
@@ -45,6 +47,6 @@ abstract type AbstractLossLayer <: AbstractLayer end
"""
compute_probability_distribution(layer, θ; kwargs...)

Apply a probabilistic optimization layer to an objective direction `θ` in order to generate a [`FixedAtomsProbabilityDistribution`](@ref) on the vertices of a polytope.
Apply a probabilistic optimization layer to an objective direction `θ` in order to generate a `FixedAtomsProbabilityDistribution` on the vertices of a polytope.
"""
function compute_probability_distribution end
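For readers new to the interface, a minimal custom layer satisfying the `AbstractOptimizationLayer` contract documented above could look like the following sketch (the struct and its behaviour are hypothetical, not part of this PR):

```julia
using InferOpt: AbstractOptimizationLayer

# Hypothetical deterministic layer: maps θ to the one-hot argmax vertex.
struct ArgmaxLayer <: AbstractOptimizationLayer end

# Required callable from the interface: θ::AbstractArray -> feasible solution.
function (layer::ArgmaxLayer)(θ::AbstractArray; kwargs...)
    y = zero(θ)
    y[argmax(θ)] = one(eltype(θ))
    return y
end

ArgmaxLayer()([0.2, 1.5, -0.3])  # → [0.0, 1.0, 0.0]
```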