Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Matrix lookups #899

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
6 changes: 3 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
*.jl.*.cov
*.jl.cov
*.jl.mem
*.cov
*.cov
*.mem
.DS_Store

/Manifest.toml
Expand Down
34 changes: 33 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
DiskArrays = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

Expand All @@ -36,6 +37,7 @@ DimensionalDataAlgebraOfGraphicsExt = "AlgebraOfGraphics"
DimensionalDataCategoricalArraysExt = "CategoricalArrays"
DimensionalDataDiskArraysExt = "DiskArrays"
DimensionalDataMakie = "Makie"
DimensionalDataNearestNeighborsExt = "NearestNeighbors"
DimensionalDataPythonCall = "PythonCall"
DimensionalDataStatsBase = "StatsBase"

Expand Down Expand Up @@ -68,6 +70,7 @@ IteratorInterfaceExtensions = "1"
JLArrays = "0.1"
LinearAlgebra = "1"
Makie = "0.20, 0.21, 0.22"
NearestNeighbors = "0.4"
OffsetArrays = "1"
Plots = "1"
PrecompileTools = "1"
Expand Down Expand Up @@ -115,4 +118,33 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["AlgebraOfGraphics", "Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "DiskArrays", "Distributions", "Documenter", "GPUArrays", "ImageFiltering", "ImageTransformations", "JLArrays", "CairoMakie", "OffsetArrays", "Plots", "PythonCall", "Random", "SafeTestsets", "StatsBase", "StatsPlots", "Test", "Unitful"]
test = [
"AlgebraOfGraphics",
"Aqua",
"ArrayInterface",
"BenchmarkTools",
"CairoMakie",
"CategoricalArrays",
"ColorTypes",
"Combinatorics",
"CoordinateTransformations",
"DataFrames",
"DiskArrays",
"Distributions",
"Documenter",
"GPUArrays",
"ImageFiltering",
"ImageTransformations",
"JLArrays",
"NearestNeighbors",
"OffsetArrays",
"Plots",
"PythonCall",
"Random",
"SafeTestsets",
"StatsBase",
"StatsPlots",
"Test",
"Unitful"
]

45 changes: 45 additions & 0 deletions ext/DimensionalDataNearestNeighborsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
module DimensionalDataNearestNeighborsExt

using DimensionalData
using NearestNeighbors
using NearestNeighbors.StaticArrays
using DimensionalData.Lookups
using DimensionalData.Dimensions

using DimensionalData.Lookups: ArrayLookup, matrix, atol

const DD = DimensionalData
const NN = NearestNeighbors

# Select the array position whose lookup values are nearest to the selector values.
#
# The tree and the one-element result buffers are taken from the first lookup of
# the group. A single nearest-neighbour query is run for the point made from the
# selector values; every `At` selector must then match the value found for its
# lookup, while `Near` selectors accept whatever was found.
#
# Returns the position as a `Tuple` of indices into the first lookup's matrix.
# Throws an `ArgumentError` when an `At` selector does not match.
function DD.Lookups.select_array_lookups(
    lookups::Tuple{<:ArrayLookup,<:ArrayLookup,Vararg{ArrayLookup}},
    selectors::Tuple{<:Union{At,Near},<:Union{Near,At},Vararg{Union{Near,At}}}
)
    lead = first(lookups)
    query = SVector(map(val, selectors))
    kdtree = Lookups.tree(lead)
    # The result is read back from the index buffer stored on the lookup.
    knn!(lead.idxvec, lead.distvec, kdtree, query, 1)
    nearest = lead.idxvec[1]
    candidate = Tuple(kdtree.data[nearest])
    matches = map(selectors, candidate) do s, t
        s isa At ? Lookups._is_at(s, t) : true
    end
    all(matches) || throw(ArgumentError("$(selectors) not found in lookup"))
    return Tuple(CartesianIndices(matrix(lead))[nearest])
end

# Format a group of `ArrayLookup` dimensions together.
#
# The lookup matrices are zipped element-wise into a vector of `SVector` points
# and a single `KDTree` (Euclidean metric, `reorder=false`) is built from them.
# Every lookup in the group is rebuilt with the same tree and the same
# one-element index/distance buffers, its own axis as `data`, and the base
# dimension(s) it belongs to. Returns the dims rebuilt around the new lookups.
function DD.Dimensions.format_unaligned(
    lookups::Tuple{<:ArrayLookup,<:ArrayLookup,Vararg{ArrayLookup}}, dims::DD.DimTuple, axes,
)
    matrices = map(matrix, lookups)
    points = vec(SVector.(zip(matrices...)))
    # One-element buffers; the same objects are stored on every lookup of the group.
    idxvec = Vector{Int}(undef, 1)
    distvec = Vector{NN.get_T(eltype(points))}(undef, 1)
    tree = NN.KDTree(points, NN.Euclidean(); reorder=false)
    groupdims = basedims(dims)
    return map(lookups, dims, axes) do l, d, a
        formatted = rebuild(l;
            data=a, tree=tree, dim=basedims(d), dims=groupdims,
            idxvec=idxvec, distvec=distvec,
        )
        rebuild(d, formatted)
    end
end

end
24 changes: 19 additions & 5 deletions src/Dimensions/format.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and any fields holding `Auto-` objects are filled with guessed objects.
If a [`Lookup`](@ref) hasn't been specified, a lookup is chosen
based on the type and element type of the values.
"""
format(dims, A::AbstractArray) = format((dims,), A)
format(dims::DimOrDimType, A::AbstractVector) = format((dims,), A)
function format(dims::NamedTuple, A::AbstractArray)
dims = map(keys(dims), values(dims)) do k, v
rebuild(name2dim(k), v)
Expand All @@ -29,17 +29,28 @@ end
# Make a dummy array that assumes the dims are the correct length and don't hold `Colon`s
function format(dims::DimTuple)
ax = map(parent ∘ first ∘ axes, dims)
A = CartesianIndices(ax)
return format(dims, A)
return format(dims, ax)
end
format(dims::Tuple{Vararg{Any,N}}, A::AbstractArray{<:Any,N}) where N = format(dims, axes(A))
format(dims::Tuple{Vararg{Any,N}}, A::AbstractArray{<:Any,N}) where N =
format(dims, axes(A))
@noinline format(dims::Tuple{Vararg{Any,M}}, A::AbstractArray{<:Any,N}) where {N,M} =
throw(DimensionMismatch("Array A has $N axes, while the number of dims is $M: $(map(basetypeof, dims))"))
format(dims::Tuple{Vararg{Any,N}}, axes::Tuple{Vararg{Any,N}}) where N = map(_format, dims, axes)
# Format a tuple of dims against a tuple of axes of the same length `N`.
function format(dims::Tuple{Vararg{Any,N}}, axes::Tuple{Vararg{Any,N}}) where N
    # Each dim is formatted against its own axis first ...
    formatted = map(_format, dims, axes)
    # ... then unaligned dims are passed to `format_unaligned` as a group.
    # Aligned dims are returned as they are: `first ∘ tuple` keeps the dim and
    # drops the axis.
    keepdim = first ∘ tuple
    return split_alignments(keepdim, format_unaligned, formatted, axes)
end
format(d::Dimension{<:AbstractArray}) = _format(d, axes(val(d), 1))
format(d::Dimension, axis::AbstractRange) = _format(d, axis)
format(d::Type{<:Dimension}, axis::AbstractRange) = _format(d, axis)
format(l::Lookup) = parent(format(AnonDim(l)))

# Fallback formatting for grouped unaligned dims.
# Other methods can be added for specific lookup tuples.
function format_unaligned end

# Entry point: dispatch on the values held by the dims.
function format_unaligned(dims::DimTuple, axes)
    return format_unaligned(val(dims), dims, axes)
end

# Default: nothing group-specific to do, so format each dim against its own axis.
function format_unaligned(::Tuple, dims::DimTuple, axes)
    return map(format, dims, axes)
end

_format(dimname::Symbol, axis::AbstractRange) = Dim{dimname}(NoLookup(axes(axis, 1)))
_format(::Type{D}, axis::AbstractRange) where D<:Dimension = D(NoLookup(axes(axis, 1)))
_format(dim::Dimension{Colon}, axis::AbstractRange) = rebuild(dim, NoLookup(axes(axis, 1)))
Expand All @@ -54,6 +65,9 @@ format(m::Lookup, D::Type, axis::AbstractRange) = format(m, D, parent(m), axis)
format(v::AutoVal, D::Type, axis::AbstractRange) = _valformaterror(val(v), D)
format(v, D::Type, axis::AbstractRange) = _valformaterror(v, D)

# An `ArrayLookup` with `AutoValues` data takes the axis as its data and an
# instance of the dimension type `D` as its `dim`.
function format(m::Lookups.ArrayLookup, D::Type, ::AutoValues, axis::AbstractRange)
    return rebuild(m; dim=D(), data=axis)
end

# Format Lookups
# No more identification required for NoLookup
format(m::Lookups.Length1NoLookup, D::Type, values, axis::AbstractRange) = m
Expand Down
38 changes: 23 additions & 15 deletions src/Dimensions/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,29 @@ Convert a `Dimension` or `Selector` `I` to indices of `Int`, `AbstractArray` or
@inline function dims2indices(dims::DimTuple, I::DimTuple)
extradims = otherdims(I, dims)
length(extradims) > 0 && _extradimswarn(extradims)
_dims2indices(lookup(dims), dims, sortdims(I, dims))
Isorted = Dimensions.sortdims(I, dims)
return split_alignments(dims2indices, unalligned_dims2indices, dims, Isorted)
end
@inline dims2indices(dims::Tuple{}, ::Tuple{}) = ()

# Handle tuples with @generated
@inline _dims2indices(::Tuple{}, dims::Tuple{}, ::Tuple{}) = ()
@generated function _dims2indices(lookups::Tuple, dims::Tuple, I::Tuple)
# Indices for a group of unaligned dims when `sel` is not made up entirely of
# selectors. Any `Selector` or `Interval` in the mix is an error, `nothing`
# becomes `Colon()`, and everything else is passed through unchanged.
@inline function unalligned_dims2indices(dims::DimTuple, sel::Tuple)
    return map(sel) do s
        if s isa Union{Selector,Interval}
            _unalligned_all_selector_error(dims)
        end
        s === nothing ? Colon() : s
    end
end
# All entries are selectors: resolve them together against the dims' lookups.
@inline unalligned_dims2indices(dims::DimTuple, sel::Tuple{Selector,Vararg{Selector}}) =
    Lookups.select_unalligned_indices(lookup(dims), sel)

# Run `fa` on each aligned dimension `dims[n]` with its matching `I[n]`,
# and `fu` once on all unaligned dimensions and their `I` values as a group.
# The results are merged back into a single tuple in the original dimension order.
# Convenience entry point: take the values held by `dims` as the lookups to
# split on, then forward to the five-argument method.
function split_alignments(fa, fu, dims::Tuple, I::Tuple)
    return split_alignments(fa, fu, val(dims), dims, I)
end
@generated function split_alignments(
fa, fu, lookups::Tuple, dims::Tuple, I::Tuple
)
# We separate out Aligned and Unaligned lookups as
# Unaligned must be selected in groups e.g. X and Y together.
unalligned = Expr(:tuple)
Expand All @@ -68,7 +85,7 @@ end
push!(dimmerge.args, :(uadims[$ua_count]))
else
a_count += 1
push!(alligned.args, :(_dims2indices(dims[$i], I[$i])))
push!(alligned.args, :(fa(dims[$i], I[$i])))
# Update the merged tuple
push!(dimmerge.args, :(adims[$a_count]))
end
Expand All @@ -80,23 +97,14 @@ end
quote
adims = $alligned
# Unaligned dims have to be run together as a set
uadims = unalligned_dims2indices($unalligned, map(_unwrapdim, $uaI))
uadims = fu($unalligned, map(_unwrapdim, $uaI))
$dimmerge
end
else
alligned
end
end

@inline function unalligned_dims2indices(dims::DimTuple, sel::Tuple)
map(sel) do s
s isa Union{Selector,Interval} && _unalligned_all_selector_error(dims)
isnothing(s) ? Colon() : s
end
end
@inline function unalligned_dims2indices(dims::DimTuple, sel::Tuple{Selector,Vararg{Selector}})
Lookups.select_unalligned_indices(lookup(dims), sel)
end

# Throw an `ArgumentError` naming the unaligned dims that were given a mix of
# selectors and non-selectors.
function _unalligned_all_selector_error(dims)
    throw(ArgumentError("Unalligned dims: use selectors for all $(join(map(name, dims), ", ")) dims, or none of them"))
end
Expand Down
2 changes: 1 addition & 1 deletion src/Lookups/Lookups.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export AutoStep, AutoBounds, AutoValues
export Lookup
export AutoLookup, AbstractNoLookup, NoLookup
export Aligned, AbstractSampled, Sampled, AbstractCyclic, Cyclic, AbstractCategorical, Categorical
export Unaligned, Transformed
export Unaligned, Transformed, ArrayLookup

# Deprecated
export LookupArray
Expand Down
18 changes: 16 additions & 2 deletions src/Lookups/lookup_arrays.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
Lookup

Expand Down Expand Up @@ -612,6 +611,21 @@ transformfunc(lookup::Transformed) = lookup.f
Base.:(==)(l1::Transformed, l2::Transformed) = typeof(l1) == typeof(l2) && f(l1) == f(l2)

# TODO Transformed bounds
# `Unaligned` lookup whose values are held in an array (`matrix`) with element type `T`.
#
# A user-constructed `ArrayLookup` only carries `matrix` and `metadata`; the
# remaining fields start as placeholders and are filled in when the dimensions
# are formatted (the NearestNeighbors extension's `format_unaligned` method
# supplies `tree`, `idxvec` and `distvec`).
struct ArrayLookup{T,A,D,Ds,Ma<:AbstractArray{T},Tr,IV,DV,Me} <: Unaligned{T,1}
data::A # `AutoValues()` until formatted, then the axis of this dimension
dim::D # `AutoDim()` until formatted, then the dimension this lookup belongs to
dims::Ds # `AutoDim()` until formatted, then the dims of the whole lookup group
matrix::Ma # the array of lookup values passed to the constructor
tree::Tr # `nothing` until a search tree is attached during formatting
idxvec::IV # `nothing` until formatted; index buffer used for nearest-neighbour queries
distvec::DV # `nothing` until formatted; distance buffer used for nearest-neighbour queries
metadata::Me
end
# Construct an unformatted `ArrayLookup` from an array of lookup values.
# `metadata` defaults to `NoMetadata()`.
ArrayLookup(matrix; metadata=NoMetadata()) =
ArrayLookup(AutoValues(), AutoDim(), AutoDim(), matrix, nothing, nothing, nothing, metadata)
# Field accessors.
dim(lookup::ArrayLookup) = lookup.dim
matrix(l::ArrayLookup) = l.matrix
tree(l::ArrayLookup) = l.tree

# Shared methods

Expand Down Expand Up @@ -973,4 +987,4 @@ function promote_first(a1::AbstractArray, as::AbstractArray...)
end

return convert(C, a1)
end
end
30 changes: 25 additions & 5 deletions src/Lookups/selector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,11 @@ end
function at(lookup::NoLookup, sel::At; err=_True(), kw...)
v = val(sel)
r = round(Int, v)
at = atol(sel)
if isnothing(at)
if isnothing(atol(sel))
v == r || _selnotfound_or_nothing(err, lookup, v)
else
at >= 0.5 && error("atol must be small than 0.5 for NoLookup")
isapprox(v, r; atol=at) || _selnotfound_or_nothing(err, lookup, v)
isapprox(v, r; atol=atol(sel)) || _selnotfound_or_nothing(err, lookup, v)
end
if r in lookup
return r
Expand Down Expand Up @@ -220,6 +219,8 @@ function at(::Order, ::Span, lookup::Lookup, selval, atol, rtol::Nothing; err=_T
end
end


# Check `v` against an `At` selector by unwrapping the selector's value and
# tolerance and forwarding to the three-argument `_is_at` methods.
function _is_at(sel::At, v)
    return _is_at(val(sel), v, atol(sel))
end
# Generic method: the tolerance is ignored and the values must be equal.
@inline function _is_at(x, y, atol)
    return x == y
end
# Time values with a `Period` tolerance: `x` must lie within `atol` of `y`,
# bounds included.
@inline function _is_at(x::Dates.AbstractTime, y::Dates.AbstractTime, atol::Dates.Period)
    return y - atol <= x <= y + atol
end
Expand Down Expand Up @@ -1103,13 +1104,32 @@ end

# We use the transformation from the first unaligned dim.
# In practice the others could be empty.
function select_unalligned_indices(lookups::LookupTuple, sel::Tuple{IntSelector,Vararg{IntSelector}})
# Apply the transform function of the first lookup to the selector values,
# then convert each transformed value to an index for its lookup.
function select_unalligned_indices(
    lookups::LookupTuple, sel::Tuple{IntSelector,Vararg{IntSelector}}
)
    transform = transformfunc(lookups[1])
    transformed = transform(map(val, sel))
    return map(_transform2int, lookups, sel, transformed)
end
function select_unalligned_indices(lookups::LookupTuple, sel::Tuple{Selector,Vararg{Selector}})
# Catch-all for selector tuples not handled by a more specific method:
# always throws an `ArgumentError`.
function select_unalligned_indices(
    lookups::LookupTuple, sel::Tuple{Selector,Vararg{Selector}}
)
    msg = "only `Near`, `At` or `Contains` selectors currently work on `Unalligned` lookups"
    throw(ArgumentError(msg))
end
# Two or more `ArrayLookup`s with matching `IntSelector`s are resolved as a
# group by `select_array_lookups`.
select_unalligned_indices(
    lookups::Tuple{<:ArrayLookup,<:ArrayLookup,Vararg{ArrayLookup}},
    selectors::Tuple{<:IntSelector,<:IntSelector,Vararg{IntSelector}}
) = select_array_lookups(lookups, selectors)

# Fallback used when no more specific `select_array_lookups` method applies.
#
# There is no built-in search over `ArrayLookup`s here: this method only throws.
# The NearestNeighbors.jl extension adds the method that resolves `At` and
# `Near` selectors, so reaching this fallback means the extension is not loaded
# or the selectors are of a kind that method does not accept.
#
# Throws an `ArgumentError`.
function select_array_lookups(
    lookups::Tuple{<:ArrayLookup,<:ArrayLookup,Vararg{ArrayLookup}},
    selectors::Tuple
)
    throw(ArgumentError("Load NearestNeighbors.jl to use `At` or `Near` on `ArrayLookup`s"))
end

# `Near`: round to the nearest integer, then limit the result to the index
# range of the lookup.
function _transform2int(lookup::AbstractArray, ::Near, x)
    rounded = round(Int, x)
    lo = firstindex(lookup)
    hi = lastindex(lookup)
    return min(max(rounded, lo), hi)
end
# `Contains`: round to the nearest integer; no limiting is applied here.
function _transform2int(lookup::AbstractArray, ::Contains, x)
    return round(Int, x)
end
Expand Down
12 changes: 11 additions & 1 deletion src/Lookups/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,20 @@ function Base.show(io::IO, mime::MIME"text/plain", lookup::Transformed)
show_compact(io, mime, lookup)
show(io, mime, lookup.f)
print(io, " ")
ctx = IOContext(io, :compact=>true)
ctx = IOContext(io, :compact => true)
show(ctx, mime, dim(lookup))
end

# `text/plain` display for an `ArrayLookup`.
#
# The compact summary is always printed. Unless `io` is a compact context, the
# lookup matrix and the lookup's dimension follow, each shown in a compact
# context and each starting on its own line.
function Base.show(io::IO, mime::MIME"text/plain", lookup::ArrayLookup)
    show_compact(io, mime, lookup)
    # Compact contexts get the summary only.
    get(io, :compact, false) && return nothing
    println(io)
    ctx = IOContext(io, :compact => true)
    show(ctx, mime, lookup.matrix)
    # Without this line break the dimension would be printed straight onto the
    # end of the matrix output.
    println(io)
    show(ctx, mime, dim(lookup))
    return nothing
end

function Base.show(io::IO, mime::MIME"text/plain", lookup::Lookup)
show_compact(io, mime, lookup)
get(io, :compact, false) && print_index(io, mime, parent(lookup))
Expand Down
14 changes: 13 additions & 1 deletion src/array/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,8 @@ for f in (:fill, :rand)
end
end
# AbstractRNG rand DimArray creation methods
Base.rand(r::AbstractRNG, x, d1::Dimension, dims::Dimension...; kw...) = rand(r, x, (d1, dims...); kw...)
# Collect dimensions given as separate arguments into a tuple and forward to
# the tuple-based `rand` method.
function Base.rand(r::AbstractRNG, x, d1::Dimension, dims::Dimension...; kw...)
    return rand(r, x, (d1, dims...); kw...)
end
function Base.rand(r::AbstractRNG, x, dims::DimTuple; kw...)
C = dimconstructor(dims)
C(rand(r, x, _dimlength(dims)), _maybestripval(dims); kw...)
Expand All @@ -739,6 +740,17 @@ function Base.rand(r::AbstractRNG, ::Type{T}, dims::DimTuple; kw...) where T
C(rand(r, T, _dimlength(dims)), _maybestripval(dims); kw...)
end

# Size for a group of `ArrayLookup` dimensions: the size of the lookup matrix,
# which must be the same for every lookup in the group.
# Throws an `ArgumentError` when a matrix size differs from that of the first lookup.
function _dimlength(
    dims::Tuple{<:Dimension{<:Lookups.ArrayLookup},Vararg{Dimension{<:Lookups.ArrayLookup}}}
)
    lookups = lookup(dims)
    sz1 = size(first(lookups).matrix)
    for l in lookups
        sz = size(l.matrix)
        if sz1 != sz
            throw(ArgumentError("ArrayLookup matrix sizes must match. Got $sz1 and $sz"))
        end
    end
    return sz1
end
# A tuple of dims: one length per dim.
function _dimlength(dims::Tuple)
    return map(_dimlength, dims)
end
# A dim wrapping an array: the length of the dim.
function _dimlength(dim::Dimension{<:AbstractArray})
    return length(dim)
end
# A dim wrapping `Val{Keys}`: the length of `Keys`.
function _dimlength(dim::Dimension{<:Val{Keys}}) where Keys
    return length(Keys)
end
Expand Down
Loading
Loading