From 987b20af8b85d68c03163b7bee0a91f32d9ec535 Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Wed, 27 Dec 2023 04:51:20 -0500 Subject: [PATCH 1/3] support Elements() for Dicts and Sets --- src/getsetall.jl | 4 ++++ src/optics.jl | 7 ++++--- test/test_getsetall.jl | 4 ++++ test/test_optics.jl | 6 ++++-- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/getsetall.jl b/src/getsetall.jl index 3ddb2f46..ef7a565d 100644 --- a/src/getsetall.jl +++ b/src/getsetall.jl @@ -54,6 +54,8 @@ function setall end getall(obj::Union{Tuple, AbstractVector}, ::Elements) = obj getall(obj::Union{NamedTuple}, ::Elements) = values(obj) getall(obj::AbstractArray, ::Elements) = vec(obj) +getall(obj::AbstractSet, ::Elements) = collect(obj) +getall(obj::AbstractDict, ::Elements) = collect(obj) getall(obj::Number, ::Elements) = (obj,) getall(obj::AbstractString, ::Elements) = collect(obj) getall(obj, ::Elements) = error("Elements() not supported for $(typeof(obj))") @@ -70,6 +72,8 @@ setall(obj::NamedTuple{NS}, ::Elements, vs) where {NS} = NamedTuple{NS}(NTuple{l setall(obj::NTuple{N, Any}, ::Elements, vs) where {N} = (@assert length(vs) == N; NTuple{N}(vs)) setall(obj::AbstractArray, ::Elements, vs::AbstractArray) = (@assert length(obj) == length(vs); reshape(vs, size(obj))) setall(obj::AbstractArray, ::Elements, vs) = setall(obj, Elements(), collect(vs)) +setall(obj::Set, ::Elements, vs) = Set(vs) +setall(obj::Dict, ::Elements, vs) = Dict(vs) setall(obj, ::Elements, vs) = error("Elements() not supported for $(typeof(obj))") function setall(obj, o::If, vs) if o.modify_condition(obj) diff --git a/src/optics.jl b/src/optics.jl index 18ddec7f..2f8f6d82 100644 --- a/src/optics.jl +++ b/src/optics.jl @@ -227,9 +227,10 @@ $EXPERIMENTAL struct Elements end OpticStyle(::Type{<:Elements}) = ModifyBased() -function modify(f, obj, ::Elements) - map(f, obj) -end +modify(f, obj, ::Elements) = map(f, obj) +# sets and dicts don't support map(), but still have the concept of elements: +modify(f, obj::Set, ::Elements) = Set(f(p) for p in obj) +modify(f, obj::Dict, ::Elements) = Dict(f(p)::Pair for p in obj) """ If(modify_condition) diff --git a/test/test_getsetall.jl b/test/test_getsetall.jl index 6518d122..ba99fad8 100644 --- a/test/test_getsetall.jl +++ b/test/test_getsetall.jl @@ -41,6 +41,8 @@ end @test (2, 5, 10, 17, 26, 37) === @inferred getall(obj, @optic _ |> _[:] |> Elements() |> Elements() |> _[:] |> Elements() |> Elements() |> _[1]^2 + 1 |> only) # trickier types for Elements(): + @test issetequal(["x", "y"], @inferred getall(Set(["x", "y"]), Elements())) + @test issetequal([1 => "x", 2 => "y"], @inferred getall(Dict(1 => "x", 2 => "y"), Elements())) obj = (a=("ab", "c"), b=([1 2; 3 4],), c=(SVector(1.), SVector(2, 3))) @test ['b', 'c', 'd'] == @inferred getall(obj, @optic _.a |> Elements() |> Elements() |> _ + 1) @test [2, 4, 3, 5] == @inferred getall(obj, @optic _.b |> Elements() |> Elements() |> _ + 1) @@ -90,6 +92,8 @@ end @test [2, 3] == @inferred setall([1, "2"], Elements(), (2, 3)) @test [2, "3"] == @inferred setall([1, "2"], Elements(), (2, "3")) @test [2, 3] == @inferred setall([1, "2"], Elements(), [2, 3]) + @test Set([2, 3]) == @inferred setall(Set(["1", "2"]), Elements(), (2, 3)) + @test Dict(1 => 2, 3 => 4) == @inferred setall(Dict(:a => :b, :c => :d), Elements(), [1 => 2, 3 => 4]) @test_throws ErrorException setall("abc", Elements(), [2, 3]) @test 2 === @inferred setall(1, If(>(0)), (2,)) diff --git a/test/test_optics.jl b/test/test_optics.jl index 5dcda4b7..1364030a 100644 --- a/test/test_optics.jl +++ b/test/test_optics.jl @@ -55,18 +55,20 @@ end end @testset "Elements" begin - @test [0,0,0] == @set 1:3 |> Elements() = 0 arr = 1:3 @test 2:4 == (@set arr |> Elements() += 1) @test map(cos, arr) == modify(cos, arr, Elements()) - @test modify(cos, (), Elements()) === () @inferred modify(cos, arr, Elements()) @inferred modify(cos, (), Elements()) + + @test Set([2,3,4]) == @inferred modify(x->x+1, Set([1,2,3]), Elements()) + # not @inferred because Tuple(::Pair) is type unstable: + @test Dict(1 => 2, 2 => 3, 3 => 4) == modify(x->x+1, Dict(1 => 1, 2 => 2, 3 => 3), last ∘ Elements()) end @testset "Recursive" begin From 48debb8a8d1eee1df0595985bc08bb018095b839 Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Thu, 4 Jan 2024 12:32:39 -0500 Subject: [PATCH 2/3] set(diag) --- src/functionlenses.jl | 4 +++- test/test_functionlenses.jl | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/functionlenses.jl b/src/functionlenses.jl index 1f5911b7..424cd13f 100644 --- a/src/functionlenses.jl +++ b/src/functionlenses.jl @@ -1,4 +1,4 @@ -using LinearAlgebra: norm, normalize +using LinearAlgebra: norm, normalize, diag, diagind using Dates # first and last on general indexable collections @@ -100,6 +100,8 @@ modify(f, obj, o::typeof(skipmissing)) = @modify(f, obj |> filter(!ismissing, _) set(obj, ::typeof(sort), val) = @set obj[sortperm(obj)] = val modify(f, obj, ::typeof(sort)) = @modify(f, obj[sortperm(obj)]) +set(A, ::typeof(diag), val) = @set A[diagind(A)] = val + ################################################################################ ##### os ################################################################################ diff --git a/test/test_functionlenses.jl b/test/test_functionlenses.jl index abaa06fd..b4af6882 100644 --- a/test/test_functionlenses.jl +++ b/test/test_functionlenses.jl @@ -2,7 +2,7 @@ module TestFunctionLenses using Test using Dates using Unitful -using LinearAlgebra: norm +using LinearAlgebra: norm, diag using InverseFunctions: inverse using Accessors: test_getset_laws, test_modify_law using Accessors @@ -183,6 +183,9 @@ end test_modify_law(reverse, @optic(filter(>(0), _)), [1, -2, 3, -4, 5, -6]) test_getset_laws(skipmissing, [1, missing, 3], [0, 1], [5, 6]; cmp=(x,y) -> isequal(collect(x), collect(y))) test_modify_law(cumsum, sort, [1, -2, 3, -4, 5, -6]) + + test_getset_laws(diag, [1 2; 3 4], [1., 2.5], [0, 1]) + test_getset_laws(diag, [1 2 3; 4 5 6], [1., 2.5], [0, 1]) end @testset "math" begin From 5517fb4c2e3438a116e59fb776642eac1f3d71ea Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Thu, 4 Jan 2024 13:04:13 -0500 Subject: [PATCH 3/3] add optic negation same as how functions work --- src/optics.jl | 2 ++ test/test_functionlenses.jl | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/optics.jl b/src/optics.jl index 2f8f6d82..1c517c2d 100644 --- a/src/optics.jl +++ b/src/optics.jl @@ -461,6 +461,8 @@ Broadcast.broadcastable( o::Union{PropertyLens,IndexLens,DynamicIndexLens,Elements,Properties,If,Recursive} ) = Ref(o) +Base.:(!)(f::Union{PropertyLens,IndexLens,DynamicIndexLens}) = (!) ∘ f + function make_salt(s64::UInt64)::UInt # used for faster hashes. See https://github.com/jw3126/Setfield.jl/pull/162 diff --git a/test/test_functionlenses.jl b/test/test_functionlenses.jl index b4af6882..69ab750c 100644 --- a/test/test_functionlenses.jl +++ b/test/test_functionlenses.jl @@ -254,6 +254,10 @@ end @test set((a=3, b=4), norm, 10) === (a=6., b=8.) test_getset_laws(norm, (3, 4), 10, 12) test_getset_laws(Base.splat(hypot), (3, 4), 10, 12) + + test_getset_laws(!(@optic _.a), (a=true,), false, true) + test_getset_laws(!(@optic _[1]), (a=true,), false, true) + test_getset_laws(!(@optic _[end]), (a=true,), false, true) end @testset "dates" begin