Skip to content

Commit

Permalink
Random fixes (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Jan 16, 2025
1 parent 4605dfc commit 555adcb
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 47 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
name = "NamedDimsArrays"
uuid = "60cbd0c0-df58-4cb7-918c-6f5607b73fde"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.5"
version = "0.3.6"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -25,8 +26,9 @@ Adapt = "4.1.1"
ArrayLayouts = "1.11.0"
BlockArrays = "1.3.0"
DerivableInterfaces = "0.3.7"
FillArrays = "1.13.0"
LinearAlgebra = "1.10"
MapBroadcast = "0.1.5"
MapBroadcast = "0.1.6"
Random = "1.10"
SimpleTraits = "0.9.4"
TensorAlgebra = "0.1"
Expand Down
16 changes: 6 additions & 10 deletions ext/NamedDimsArraysBlockArraysExt/NamedDimsArraysBlockArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,33 @@ module NamedDimsArraysBlockArraysExt
using ArrayLayouts: ArrayLayouts
using BlockArrays: Block, BlockRange
using NamedDimsArrays:
AbstractNamedDimsArray,
AbstractNamedUnitRange,
named_getindex,
nameddims_getindex,
nameddims_view
AbstractNamedDimsArray, AbstractNamedUnitRange, getindex_named, view_nameddims

function Base.getindex(r::AbstractNamedUnitRange{<:Integer}, I::Block{1})
# TODO: Use `Derive.@interface NamedArrayInterface() r[I]` instead.
return named_getindex(r, I)
return getindex_named(r, I)
end

function Base.getindex(r::AbstractNamedUnitRange{<:Integer}, I::BlockRange{1})
# TODO: Use `Derive.@interface NamedArrayInterface() r[I]` instead.
return named_getindex(r, I)
return getindex_named(r, I)
end

const BlockIndex{N} = Union{Block{N},BlockRange{N},AbstractVector{<:Block{N}}}

function Base.view(a::AbstractNamedDimsArray, I1::Block{1}, Irest::BlockIndex{1}...)
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
return nameddims_view(a, I1, Irest...)
return view_nameddims(a, I1, Irest...)
end

function Base.view(a::AbstractNamedDimsArray, I::Block)
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
return nameddims_view(a, Tuple(I)...)
return view_nameddims(a, Tuple(I)...)
end

function Base.view(a::AbstractNamedDimsArray, I1::BlockIndex{1}, Irest::BlockIndex{1}...)
# TODO: Use `Derive.@interface NamedDimsArrayInterface() r[I]` instead.
return nameddims_view(a, I1, Irest...)
return view_nameddims(a, I1, Irest...)
end

# Fix ambiguity error.
Expand Down
6 changes: 3 additions & 3 deletions src/abstractnamedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,17 @@ function Base.hash(a::AbstractNamedArray, h::UInt)
return hash(name(a), h)
end

named_getindex(a::AbstractArray, I...) = named(getindex(dename(a), I...), name(a))
getindex_named(a::AbstractArray, I...) = named(getindex(dename(a), I...), name(a))

# Array funcionality.
Base.size(a::AbstractNamedArray) = map(s -> named(s, name(a)), size(dename(a)))
Base.axes(a::AbstractNamedArray) = map(s -> named(s, name(a)), axes(dename(a)))
Base.eachindex(a::AbstractNamedArray) = eachindex(dename(a))
function Base.getindex(a::AbstractNamedArray{<:Any,N}, I::Vararg{Int,N}) where {N}
return named_getindex(a, I...)
return getindex_named(a, I...)
end
function Base.getindex(a::AbstractNamedArray, I::Int)
return named_getindex(a, I)
return getindex_named(a, I)
end
Base.isempty(a::AbstractNamedArray) = isempty(dename(a))

Expand Down
107 changes: 81 additions & 26 deletions src/abstractnameddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ function checked_indexin(x::AbstractUnitRange, y::AbstractUnitRange)
end

function Base.copy(a::AbstractNamedDimsArray)
return nameddimsarraytype(a)(copy(dename(a)), nameddimsindices(a))
return constructorof(typeof(a))(copy(dename(a)), nameddimsindices(a))
end

const NamedDimsIndices = Union{
Expand Down Expand Up @@ -181,9 +181,11 @@ Base.values(s::NaiveOrderedSet) = s.values
Base.Tuple(s::NaiveOrderedSet) = Tuple(values(s))
Base.length(s::NaiveOrderedSet) = length(values(s))
Base.axes(s::NaiveOrderedSet) = axes(values(s))
Base.keys(s::NaiveOrderedSet) = Base.OneTo(length(s))
Base.:(==)(s1::NaiveOrderedSet, s2::NaiveOrderedSet) = issetequal(values(s1), values(s2))
Base.iterate(s::NaiveOrderedSet, args...) = iterate(values(s), args...)
Base.getindex(s::NaiveOrderedSet, I::Int) = values(s)[I]
Base.get(s::NaiveOrderedSet, I::Integer, default) = get(values(s), I, default)
Base.invperm(s::NaiveOrderedSet) = NaiveOrderedSet(invperm(values(s)))
Base.Broadcast._axes(::Broadcasted, axes::NaiveOrderedSet) = axes
Base.Broadcast.BroadcastStyle(::Type{<:NaiveOrderedSet}) = Style{NaiveOrderedSet}()
Expand All @@ -210,6 +212,10 @@ function Base.size(a::AbstractNamedDimsArray)
return NaiveOrderedSet(map(named, size(dename(a)), nameddimsindices(a)))
end

function Base.length(a::AbstractNamedDimsArray)
return prod(size(a); init=1)
end

# Circumvent issue when ndims isn't known at compile time.
function Base.axes(a::AbstractNamedDimsArray, d)
return d <= ndims(a) ? axes(a)[d] : OneTo(1)
Expand All @@ -233,17 +239,20 @@ to_nameddimsaxes(dims) = map(to_nameddimsaxis, dims)
to_nameddimsaxis(ax::NamedDimsAxis) = ax
to_nameddimsaxis(I::NamedDimsIndices) = named(dename(only(axes(I))), I)

nameddimsarraytype(a::AbstractNamedDimsArray) = nameddimsarraytype(typeof(a))
nameddimsarraytype(a::Type{<:AbstractNamedDimsArray}) = unspecify_type_parameters(a)
# Interface inspired by [ConstructionBase.constructorof](https://github.com/JuliaObjects/ConstructionBase.jl).
constructorof(type::Type{<:AbstractArray}) = unspecify_type_parameters(type)

constructorof_nameddims(type::Type{<:AbstractNamedDimsArray}) = constructorof(type)
constructorof_nameddims(type::Type{<:AbstractArray}) = NamedDimsArray

function similar_nameddims(a::AbstractNamedDimsArray, elt::Type, inds)
ax = to_nameddimsaxes(inds)
return nameddimsarraytype(a)(similar(dename(a), elt, dename.(Tuple(ax))), name.(ax))
return constructorof(typeof(a))(similar(dename(a), elt, dename.(Tuple(ax))), name.(ax))
end

function similar_nameddims(a::AbstractArray, elt::Type, inds)
ax = to_nameddimsaxes(inds)
return nameddims(similar(a, elt, dename.(Tuple(ax))), name.(ax))
return constructorof_nameddims(typeof(a))(similar(a, elt, dename.(Tuple(ax))), name.(ax))
end

# Base.similar gets the eltype at compile time.
Expand All @@ -262,7 +271,7 @@ function Base.similar(a::AbstractArray, elt::Type, inds::NaiveOrderedSet)
end

function setnameddimsindices(a::AbstractNamedDimsArray, nameddimsindices)
return nameddimsarraytype(a)(dename(a), nameddimsindices)
return constructorof(typeof(a))(dename(a), nameddimsindices)
end
function replacenameddimsindices(f, a::AbstractNamedDimsArray)
return setnameddimsindices(a, replace(f, nameddimsindices(a)))
Expand Down Expand Up @@ -419,10 +428,18 @@ function Base.setindex!(a::AbstractNamedDimsArray, value, I::CartesianIndex)
setindex!(a, value, to_indices(a, (I,))...)
return a
end

function flatten_namedinteger(i::AbstractNamedInteger)
if name(i) isa Union{AbstractNamedUnitRange,AbstractNamedArray}
return name(i)[dename(i)]
end
return i
end

function Base.setindex!(
a::AbstractNamedDimsArray, value, I1::AbstractNamedInteger, Irest::AbstractNamedInteger...
)
I = (I1, Irest...)
I = flatten_namedinteger.((I1, Irest...))
# TODO: Check if this permuation should be inverted.
perm = getperm(name.(nameddimsindices(a)), name.(I))
# TODO: Throw a `NameMismatch` error.
Expand Down Expand Up @@ -510,7 +527,9 @@ function Base.view(a::AbstractNamedDimsArray, I1::NamedViewIndex, Irest::NamedVi
subinds = map(nameddimsindices(a), I) do dimname, i
return checked_indexin(dename(i), dename(dimname))
end
return nameddims(view(dename(a), subinds...), sub_nameddimsindices)
return constructorof_nameddims(typeof(a))(
view(dename(a), subinds...), sub_nameddimsindices
)
end

function Base.getindex(
Expand All @@ -522,22 +541,22 @@ end
# Repeated definition of `Base.ViewIndex`.
const ViewIndex = Union{Real,AbstractArray}

function nameddims_view(a::AbstractArray, I...)
function view_nameddims(a::AbstractArray, I...)
sub_dims = filter(dim -> !(I[dim] isa Real), ntuple(identity, ndims(a)))
sub_nameddimsindices = map(dim -> nameddimsindices(a, dim)[I[dim]], sub_dims)
return nameddims(view(dename(a), I...), sub_nameddimsindices)
return constructorof(typeof(a))(view(dename(a), I...), sub_nameddimsindices)
end

function Base.view(a::AbstractNamedDimsArray, I::ViewIndex...)
return nameddims_view(a, I...)
return view_nameddims(a, I...)
end

function nameddims_getindex(a::AbstractArray, I...)
function getindex_nameddims(a::AbstractArray, I...)
return copy(view(a, I...))
end

function Base.getindex(a::AbstractNamedDimsArray, I::ViewIndex...)
return nameddims_getindex(a, I...)
return getindex_nameddims(a, I...)
end

function Base.setindex!(
Expand All @@ -556,7 +575,7 @@ function Base.setindex!(
Irest::NamedViewIndex...,
)
I = (I1, Irest...)
setindex!(a, nameddimsarraytype(a)(value, I), I...)
setindex!(a, constructorof(typeof(a))(value, I), I...)
return a
end
function Base.setindex!(
Expand All @@ -580,13 +599,13 @@ end
function aligndims(a::AbstractArray, dims)
new_nameddimsindices = to_nameddimsindices(a, dims)
# TODO: Check this permutation is correct (it may be the inverse of what we want).
perm = getperm(nameddimsindices(a), new_nameddimsindices)
perm = Tuple(getperm(nameddimsindices(a), new_nameddimsindices))
isperm(perm) || throw(
NameMismatch(
"Dimension name mismatch $(nameddimsindices(a)), $(new_nameddimsindices)."
),
)
return nameddimsarraytype(a)(permutedims(dename(a), perm), new_nameddimsindices)
return constructorof(typeof(a))(permutedims(dename(a), perm), new_nameddimsindices)
end

function aligneddims(a::AbstractArray, dims)
Expand All @@ -598,7 +617,9 @@ function aligneddims(a::AbstractArray, dims)
"Dimension name mismatch $(nameddimsindices(a)), $(new_nameddimsindices)."
),
)
return nameddimsarraytype(a)(PermutedDimsArray(dename(a), perm), new_nameddimsindices)
return constructorof_nameddims(typeof(a))(
PermutedDimsArray(dename(a), perm), new_nameddimsindices
)
end

# Convenient constructors
Expand Down Expand Up @@ -711,16 +732,17 @@ using Base.Broadcast:
broadcasted,
check_broadcast_shape,
combine_axes
using MapBroadcast: Mapped, mapped
using MapBroadcast: MapBroadcast, Mapped, mapped, tile

abstract type AbstractNamedDimsArrayStyle{N} <: AbstractArrayStyle{N} end

struct NamedDimsArrayStyle{N} <: AbstractNamedDimsArrayStyle{N} end
NamedDimsArrayStyle(::Val{N}) where {N} = NamedDimsArrayStyle{N}()
NamedDimsArrayStyle{M}(::Val{N}) where {M,N} = NamedDimsArrayStyle{N}()
struct NamedDimsArrayStyle{N,NDA} <: AbstractNamedDimsArrayStyle{N} end
NamedDimsArrayStyle(::Val{N}) where {N} = NamedDimsArrayStyle{N,NamedDimsArray}()
NamedDimsArrayStyle{M}(::Val{N}) where {M,N} = NamedDimsArrayStyle{N,NamedDimsArray}()
NamedDimsArrayStyle{M,NDA}(::Val{N}) where {M,N,NDA} = NamedDimsArrayStyle{N,NDA}()

function Broadcast.BroadcastStyle(arraytype::Type{<:AbstractNamedDimsArray})
return NamedDimsArrayStyle{ndims(arraytype)}()
return NamedDimsArrayStyle{ndims(arraytype),constructorof(arraytype)}()
end

function Broadcast.combine_axes(
Expand Down Expand Up @@ -762,6 +784,24 @@ function set_promote_shape(
return named.(ax_promoted, name.(ax1))
end

# Handle operations like `ITensor() + ITensor(i, j)`.
# TODO: Decide if this should be a general definition for `AbstractNamedDimsArray`,
# or just for `AbstractITensor`.
function set_promote_shape(
ax1::Tuple{}, ax2::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange}}
)
return ax2
end

# Handle operations like `ITensor(i, j) + ITensor()`.
# TODO: Decide if this should be a general definition for `AbstractNamedDimsArray`,
# or just for `AbstractITensor`.
function set_promote_shape(
ax1::Tuple{AbstractNamedUnitRange,Vararg{AbstractNamedUnitRange}}, ax2::Tuple{}
)
return ax1
end

function Broadcast.check_broadcast_shape(ax1::NaiveOrderedSet, ax2::NaiveOrderedSet)
return set_check_broadcast_shape(Tuple(ax1), Tuple(ax2))
end
Expand All @@ -775,6 +815,7 @@ function set_check_broadcast_shape(
check_broadcast_shape(dename.(ax1), dename.(ax2_aligned))
return nothing
end
set_check_broadcast_shape(ax1::Tuple{}, ax2::Tuple{}) = nothing

# Dename and lazily permute the arguments using the reference
# dimension names.
Expand All @@ -783,19 +824,33 @@ function denamed(m::Mapped, nameddimsindices)
return mapped(m.f, map(arg -> denamed(arg, nameddimsindices), m.args)...)
end

function nameddimsarraytype(style::NamedDimsArrayStyle{<:Any,NDA}) where {NDA}
return NDA
end

using FillArrays: Fill

function MapBroadcast.tile(a::AbstractNamedDimsArray, ax)
axes(a) == ax && return a
if iszero(ndims(a))
return constructorof(typeof(a))(Fill(a[], dename.(Tuple(ax))), name.(ax))
end
return error("Not implemented.")
end

function Base.similar(bc::Broadcasted{<:AbstractNamedDimsArrayStyle}, elt::Type, ax)
nameddimsindices = name.(ax)
m′ = denamed(Mapped(bc), nameddimsindices)
# TODO: Store the wrapper type in `AbstractNamedDimsArrayStyle` and use that
# wrapper type rather than the generic `nameddims` constructor, which
# can lose information.
# Call it as `nameddimsarraytype(bc.style)`.
return nameddims(similar(m′, elt, dename.(Tuple(ax))), nameddimsindices)
return nameddimsarraytype(bc.style)(
similar(m′, elt, dename.(Tuple(ax))), nameddimsindices
)
end

function Base.copyto!(
dest::AbstractArray{<:Any,N}, bc::Broadcasted{<:AbstractNamedDimsArrayStyle{N}}
) where {N}
function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractNamedDimsArrayStyle})
return copyto!(dest, Mapped(bc))
end

Expand Down
6 changes: 5 additions & 1 deletion src/abstractnamedinteger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ struct FusedNames{Names} <: AbstractName
names::Names
end
fusednames(name1, name2) = FusedNames((name1, name2))
fusednames(name1::FusedNames, name2::FusedNames) = FusedNames(generic_vcat(name1, name2))
function fusednames(name1::FusedNames, name2::FusedNames)
return FusedNames(generic_vcat(name1.names, name2.names))
end
fusednames(name1, name2::FusedNames) = fusednames(FusedNames((name1,)), name2)
fusednames(name1::FusedNames, name2) = fusednames(name1, FusedNames((name2,)))

Expand Down Expand Up @@ -86,6 +88,8 @@ Base.:-(i::AbstractNamedInteger) = setvalue(i, -dename(i))
# TODO: See if we can delete this.
Base.:+(i1::Int, i2::AbstractNamedInteger) = i1 + dename(i2)

Base.:*(i1::Int, i2::AbstractNamedInteger) = named(i1 * dename(i2), name(i2))

Base.zero(i::AbstractNamedInteger) = setvalue(i, zero(dename(i)))
Base.one(i::AbstractNamedInteger) = setvalue(i, one(dename(i)))
Base.signbit(i::AbstractNamedInteger) = signbit(dename(i))
Expand Down
10 changes: 5 additions & 5 deletions src/abstractnamedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ named(r::AbstractUnitRange, name) = namedunitrange(r, name)

# Derived interface.
# TODO: Use `Accessors.@set`?
setname(r::AbstractNamedUnitRange, name) = namedunitrange(dename(r), name)
setname(r::AbstractNamedUnitRange, name) = named(dename(r), name)

# TODO: Use `TypeParameterAccessors`.
denametype(::Type{<:AbstractNamedUnitRange{<:Any,Value}}) where {Value} = Value
Expand All @@ -43,17 +43,17 @@ Base.length(r::AbstractNamedUnitRange) = named(length(dename(r)), name(r))
Base.size(r::AbstractNamedUnitRange) = (named(length(dename(r)), name(r)),)
Base.axes(r::AbstractNamedUnitRange) = (named(only(axes(dename(r))), name(r)),)
Base.step(r::AbstractNamedUnitRange) = named(step(dename(r)), name(r))
Base.getindex(r::AbstractNamedUnitRange, I::Int) = named_getindex(r, I)
Base.getindex(r::AbstractNamedUnitRange, I::Int) = getindex_named(r, I)
# Fix ambiguity error.
function Base.getindex(r::AbstractNamedUnitRange, I::AbstractUnitRange{<:Integer})
return named_getindex(r, I)
return getindex_named(r, I)
end
# Fix ambiguity error.
function Base.getindex(r::AbstractNamedUnitRange, I::Colon)
return named_getindex(r, I)
return getindex_named(r, I)
end
function Base.getindex(r::AbstractNamedUnitRange, I)
return named_getindex(r, I)
return getindex_named(r, I)
end
Base.isempty(r::AbstractNamedUnitRange) = isempty(dename(r))

Expand Down

0 comments on commit 555adcb

Please sign in to comment.