From b6eccc6a8036dd9384df8d69f166e753c365d0b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Tue, 27 Feb 2024 13:21:21 +0100 Subject: [PATCH] API: Add fill_value config argument to similar --- src/tensors/fibers.jl | 8 +++- src/tensors/levels/atomiclevels.jl | 3 +- src/tensors/levels/denselevels.jl | 5 ++- src/tensors/levels/denserlelevels.jl | 4 +- src/tensors/levels/elementlevels.jl | 5 ++- src/tensors/levels/patternlevels.jl | 2 +- src/tensors/levels/repeatrlelevels.jl | 3 +- src/tensors/levels/separatelevels.jl | 10 ++++- src/tensors/levels/singlelistlevels.jl | 4 +- src/tensors/levels/singlerlelevels.jl | 6 +-- src/tensors/levels/sparsebandlevels.jl | 4 +- src/tensors/levels/sparsebytemaplevels.jl | 4 +- src/tensors/levels/sparsecoolevels.jl | 4 +- src/tensors/levels/sparsehashlevels.jl | 4 +- src/tensors/levels/sparselevels.jl | 4 +- src/tensors/levels/sparselistlevels.jl | 4 +- src/tensors/levels/sparserlelevels.jl | 4 +- src/tensors/levels/sparsetrianglelevels.jl | 4 +- src/tensors/levels/sparsevbllevels.jl | 4 +- test/test_constructors.jl | 52 ++++++++++++++++++++++ 20 files changed, 101 insertions(+), 37 deletions(-) diff --git a/src/tensors/fibers.jl b/src/tensors/fibers.jl index c6cd05d5c..3c40df9d3 100644 --- a/src/tensors/fibers.jl +++ b/src/tensors/fibers.jl @@ -326,7 +326,11 @@ end Base.summary(fbr::Tensor) = "$(join(size(fbr), "×")) Tensor($(summary(fbr.lvl)))" Base.summary(fbr::SubFiber) = "$(join(size(fbr), "×")) SubFiber($(summary(fbr.lvl)))" -Base.similar(fbr::AbstractFiber) = Tensor(similar_level(fbr.lvl)) -Base.similar(fbr::AbstractFiber, dims::Tuple) = Tensor(similar_level(fbr.lvl, dims...)) +Base.similar(fbr::AbstractFiber) = similar(fbr, default(fbr), eltype(fbr), size(fbr)) +Base.similar(fbr::AbstractFiber, eltype::Type) = similar(fbr, default(fbr), eltype, size(fbr)) +Base.similar(fbr::AbstractFiber, fill_value, eltype::Type) = similar(fbr, fill_value, eltype, size(fbr)) +Base.similar(fbr::AbstractFiber, dims::Tuple) = similar(fbr, default(fbr), eltype(fbr), dims) +Base.similar(fbr::AbstractFiber, eltype::Type, dims::Tuple) = similar(fbr, default(fbr), eltype, dims) +Base.similar(fbr::AbstractFiber, fill_value, eltype::Type, dims::Tuple) = Tensor(similar_level(fbr.lvl, fill_value, eltype, dims...)) moveto(tns::Tensor, device) = Tensor(moveto(tns.lvl, device)) \ No newline at end of file diff --git a/src/tensors/levels/atomiclevels.jl b/src/tensors/levels/atomiclevels.jl index 871e77682..8f704f8f1 100644 --- a/src/tensors/levels/atomiclevels.jl +++ b/src/tensors/levels/atomiclevels.jl @@ -25,7 +25,8 @@ AtomicLevel(lvl::Lvl) where {Lvl} = AtomicLevel{Vector{Base.Threads.SpinLock}, L # AtomicLevel{AVal, Lvl}(atomics::AVal, lvl::Lvl) where {Lvl, AVal} = AtomicLevel{AVal, Lvl}(lvl, atomics) Base.summary(::AtomicLevel{AVal, Lvl}) where {Lvl, AVal} = "AtomicLevel($(AVal), $(Lvl))" -similar_level(lvl::Atomic{AVal, Lvl}) where {Lvl, AVal} = AtomicLevel{AVal, Lvl}(similar_level(lvl.lvl)) +similar_level(lvl::Atomic{AVal, Lvl}, fill_value, eltype::Type, dims...) where {Lvl, AVal} = + AtomicLevel(similar_level(lvl.lvl, fill_value, eltype, dims...)) postype(::Type{<:AtomicLevel{AVal, Lvl}}) where {Lvl, AVal} = postype(Lvl) diff --git a/src/tensors/levels/denselevels.jl b/src/tensors/levels/denselevels.jl index 65f43c224..fd5abbdc9 100644 --- a/src/tensors/levels/denselevels.jl +++ b/src/tensors/levels/denselevels.jl @@ -34,8 +34,9 @@ DenseLevel{Ti}(lvl::Lvl, shape) where {Ti, Lvl} = DenseLevel{Ti, Lvl}(lvl, shape const Dense = DenseLevel Base.summary(lvl::Dense) = "Dense($(summary(lvl.lvl)))" -similar_level(lvl::DenseLevel) = Dense(similar_level(lvl.lvl)) -similar_level(lvl::DenseLevel, dims...) = Dense(similar_level(lvl.lvl, dims[1:end-1]...), dims[end]) + +similar_level(lvl::DenseLevel, fill_value, eltype::Type, dims...) = + Dense(similar_level(lvl.lvl, fill_value, eltype, dims[1:end-1]...), dims[end]) function postype(::Type{DenseLevel{Ti, Lvl}}) where {Ti, Lvl} return postype(Lvl) diff --git a/src/tensors/levels/denserlelevels.jl b/src/tensors/levels/denserlelevels.jl index fd8980f00..b603608aa 100644 --- a/src/tensors/levels/denserlelevels.jl +++ b/src/tensors/levels/denserlelevels.jl @@ -41,8 +41,8 @@ DenseRLELevel{Ti}(lvl::Lvl, shape, ptr::Ptr, right::Right, buf::Lvl) where {Ti, DenseRLELevel{Ti, Ptr, Right, Lvl}(lvl, Ti(shape), ptr, right, buf) Base.summary(lvl::DenseRLELevel) = "DenseRLE($(summary(lvl.lvl)))" -similar_level(lvl::DenseRLELevel) = DenseRLE(similar_level(lvl.lvl)) -similar_level(lvl::DenseRLELevel, dim, tail...) = DenseRLE(similar_level(lvl.lvl, tail...), dim) +similar_level(lvl::DenseRLELevel, fill_value, eltype::Type, dim, tail...) = + DenseRLE(similar_level(lvl.lvl, fill_value, eltype, tail...), dim) function postype(::Type{DenseRLELevel{Ti, Ptr, Right, Lvl}}) where {Ti, Ptr, Right, Lvl} return postype(Lvl) diff --git a/src/tensors/levels/elementlevels.jl b/src/tensors/levels/elementlevels.jl index f2ceb2cad..b0b4fc649 100644 --- a/src/tensors/levels/elementlevels.jl +++ b/src/tensors/levels/elementlevels.jl @@ -5,7 +5,7 @@ A subfiber of an element level is a scalar of type `Tv`, initialized to `D`. `D` may optionally be given as the first argument. The data is stored in a vector -of type `Val` with `eltype(Val) = Tv`. The type `Ti` is the index type used to +of type `Val` with `eltype(Val) = Tv`. The type `Tp` is the index type used to access Val. ```jldoctest @@ -34,7 +34,8 @@ ElementLevel{D, Tv, Tp}(val::Val) where {D, Tv, Tp, Val} = ElementLevel{D, Tv, T Base.summary(::Element{D}) where {D} = "Element($(D))" -similar_level(::ElementLevel{D, Tv, Tp}) where {D, Tv, Tp} = ElementLevel{D, Tv, Tp}() +similar_level(::ElementLevel{D, Tv, Tp}, fill_value, eltype::Type, ::Vararg) where {D, Tv, Tp} = + ElementLevel{fill_value, eltype, Tp}() postype(::Type{<:ElementLevel{D, Tv, Tp}}) where {D, Tv, Tp} = Tp diff --git a/src/tensors/levels/patternlevels.jl b/src/tensors/levels/patternlevels.jl index 9c7519a25..c10bad349 100644 --- a/src/tensors/levels/patternlevels.jl +++ b/src/tensors/levels/patternlevels.jl @@ -19,7 +19,7 @@ const Pattern = PatternLevel PatternLevel() = PatternLevel{Int}() Base.summary(::Pattern) = "Pattern()" -similar_level(::PatternLevel) = PatternLevel() +similar_level(::PatternLevel, ::Any, ::Type, ::Vararg) = PatternLevel() countstored_level(lvl::PatternLevel, pos) = pos diff --git a/src/tensors/levels/repeatrlelevels.jl b/src/tensors/levels/repeatrlelevels.jl index 17869c867..815439b62 100644 --- a/src/tensors/levels/repeatrlelevels.jl +++ b/src/tensors/levels/repeatrlelevels.jl @@ -43,8 +43,7 @@ RepeatRLELevel{D, Ti, Tp, Tv}(shape, ptr::Ptr, idx::Idx, val::Val) where {D, Ti, RepeatRLELevel{D, Ti, Tp, Tv, Ptr, Idx, Val}(shape, ptr, idx, val) Base.summary(::RepeatRLE{D}) where {D} = "RepeatRLE($(D))" -similar_level(::RepeatRLELevel{D}) where {D} = RepeatRLE{D}() -similar_level(::RepeatRLELevel{D}, dim, tail...) where {D} = RepeatRLE{D}(dim) +similar_level(::RepeatRLELevel{D}, ::Any, ::Type, dim, tail...) where {D} = RepeatRLE{D}(dim) data_rep_level(::Type{<:RepeatRLELevel{D, Ti, Tp, Tv, Ptr, Idx, Val}}) where {D, Ti, Tp, Tv, Ptr, Idx, Val} = RepeatData(D, Tv) function postype(::Type{RepeatRLELevel{D, Ti, Tp, Tv, Ptr, Idx, Val}}) where {D, Ti, Tp, Tv, Ptr, Idx, Val} diff --git a/src/tensors/levels/separatelevels.jl b/src/tensors/levels/separatelevels.jl index 908fc7d9f..aea9c25d6 100644 --- a/src/tensors/levels/separatelevels.jl +++ b/src/tensors/levels/separatelevels.jl @@ -27,7 +27,8 @@ SeparateLevel(lvl::Lvl) where {Lvl} = SeparateLevel(lvl, [lvl]) SeparateLevel{Lvl, Val}(lvl::Lvl) where {Lvl, Val} = SeparateLevel{Lvl, Val}(lvl, [lvl]) Base.summary(::Separate{Lvl, Val}) where {Lvl, Val} = "Separate($(Lvl))" -similar_level(lvl::Separate{Lvl, Val}) where {Lvl, Val} = SeparateLevel{Lvl, Val}(similar_level(lvl.lvl)) +similar_level(lvl::Separate{Lvl, Val}, fill_value, eltype::Type, dims...) where {Lvl, Val} = + SeparateLevel(similar_level(lvl.lvl, fill_value, eltype, dims...)) postype(::Type{<:Separate{Lvl, Val}}) where {Lvl, Val} = postype(Lvl) @@ -128,7 +129,12 @@ function assemble_level!(lvl::VirtualSeparateLevel, ctx, pos_start, pos_stop) push!(ctx.code.preamble, quote Finch.resize_if_smaller!($(lvl.ex).val, $(ctx(pos_stop))) for $pos in $(ctx(pos_start)):$(ctx(pos_stop)) - $sym = similar_level($(lvl.ex).lvl) + $sym = similar_level( + $(lvl.ex).lvl, + level_default(typeof($(lvl.ex).lvl)), + level_eltype(typeof($(lvl.ex).lvl)), + level_size($(lvl.ex).lvl)... + ) $(contain(ctx) do ctx_2 lvl_2 = virtualize(sym, lvl.Lvl, ctx_2.code, sym) lvl_2 = declare_level!(lvl_2, ctx_2, literal(0), literal(virtual_level_default(lvl_2))) diff --git a/src/tensors/levels/singlelistlevels.jl b/src/tensors/levels/singlelistlevels.jl index c98afe14d..34aa234c9 100644 --- a/src/tensors/levels/singlelistlevels.jl +++ b/src/tensors/levels/singlelistlevels.jl @@ -51,8 +51,8 @@ SingleListLevel{Ti}(lvl::Lvl, shape, ptr::Ptr, idx::Idx) where {Ti, Lvl, Ptr, Id SingleListLevel{Ti, Ptr, Idx, Lvl}(lvl, shape, ptr, idx) Base.summary(lvl::SingleListLevel) = "SingleList($(summary(lvl.lvl)))" -similar_level(lvl::SingleListLevel) = SingleList(similar_level(lvl.lvl)) -similar_level(lvl::SingleListLevel, dim, tail...) = SingleList(similar_level(lvl.lvl, tail...), dim) +similar_level(lvl::SingleListLevel, fill_value, eltype::Type, dim, tail...) = + SingleList(similar_level(lvl.lvl, fill_value, eltype, tail...), dim) function postype(::Type{SingleListLevel{Ti, Ptr, Idx, Lvl}}) where {Ti, Ptr, Idx, Lvl} return postype(Lvl) diff --git a/src/tensors/levels/singlerlelevels.jl b/src/tensors/levels/singlerlelevels.jl index 68e0b086c..c04b69b20 100644 --- a/src/tensors/levels/singlerlelevels.jl +++ b/src/tensors/levels/singlerlelevels.jl @@ -42,10 +42,10 @@ SingleRLELevel{Ti}(lvl, shape) where {Ti} = SingleRLELevel{Ti}(lvl, shape, posty SingleRLELevel{Ti}(lvl::Lvl, shape, ptr::Ptr, left::Left, right::Right) where {Ti, Lvl, Ptr, Left, Right} = SingleRLELevel{Ti, Ptr, Left, Right, Lvl}(lvl, shape, ptr, left, right) - + Base.summary(lvl::SingleRLELevel) = "SingleRLE($(summary(lvl.lvl)))" -similar_level(lvl::SingleRLELevel) = SingleRLE(similar_level(lvl.lvl)) -similar_level(lvl::SingleRLELevel, dim, tail...) = SingleRLE(similar_level(lvl.lvl, tail...), dim) +similar_level(lvl::SingleRLELevel, fill_value, eltype::Type, dim, tail...) = + SingleRLE(similar_level(lvl.lvl, fill_value, eltype, tail...), dim) function memtype(::Type{SingleRLELevel{Ti, Ptr, Left, Right, Lvl}}) where {Ti, Ptr, Left, Right, Lvl} return Ti diff --git a/src/tensors/levels/sparsebandlevels.jl b/src/tensors/levels/sparsebandlevels.jl index 02e891f35..d8e521aff 100644 --- a/src/tensors/levels/sparsebandlevels.jl +++ b/src/tensors/levels/sparsebandlevels.jl @@ -43,8 +43,8 @@ function moveto(lvl::SparseBandLevel{Ti}, device) where {Ti} end Base.summary(lvl::SparseBandLevel) = "SparseBand($(summary(lvl.lvl)))" -similar_level(lvl::SparseBandLevel) = SparseBand(similar_level(lvl.lvl)) -similar_level(lvl::SparseBandLevel, dim, tail...) = SparseBand(similar_level(lvl.lvl, tail...), dim) +similar_level(lvl::SparseBandLevel, fill_value, eltype::Type, dim, tail...) = + SparseBand(similar_level(lvl.lvl, fill_value, eltype, tail...), dim) pattern!(lvl::SparseBandLevel{Ti}) where {Ti} = SparseBandLevel{Ti}(pattern!(lvl.lvl), lvl.shape, lvl.ptr, lvl.idx, lvl.ofs) diff --git a/src/tensors/levels/sparsebytemaplevels.jl b/src/tensors/levels/sparsebytemaplevels.jl index 1f8250dd7..6b8023adc 100644 --- a/src/tensors/levels/sparsebytemaplevels.jl +++ b/src/tensors/levels/sparsebytemaplevels.jl @@ -43,8 +43,8 @@ SparseByteMapLevel{Ti}(lvl::Lvl, shape, ptr::Ptr, tbl::Tbl, srt::Srt) where {Ti, SparseByteMapLevel{Ti, Ptr, Tbl, Srt, Lvl}(lvl, shape, ptr, tbl, srt) Base.summary(lvl::SparseByteMapLevel) = "SparseByteMap($(summary(lvl.lvl)))" -similar_level(lvl::SparseByteMapLevel) = SparseByteMap(similar_level(lvl.lvl)) -similar_level(lvl::SparseByteMapLevel, dims...) = SparseByteMap(similar_level(lvl.lvl, dims[1:end-1]...), dims[end]) +similar_level(lvl::SparseByteMapLevel, fill_value, eltype::Type, dims...) = + SparseByteMap(similar_level(lvl.lvl, fill_value, eltype, dims[1:end-1]...), dims[end]) function postype(::Type{SparseByteMapLevel{Ti, Ptr, Tbl, Srt, Lvl}}) where {Ti, Ptr, Tbl, Srt, Lvl} return postype(Lvl) diff --git a/src/tensors/levels/sparsecoolevels.jl b/src/tensors/levels/sparsecoolevels.jl index a9bd882bc..fb855bd75 100644 --- a/src/tensors/levels/sparsecoolevels.jl +++ b/src/tensors/levels/sparsecoolevels.jl @@ -56,8 +56,8 @@ SparseCOOLevel{N, TI}(lvl::Lvl, shape, ptr::Ptr, tbl::Tbl) where {N, TI, Lvl, Pt SparseCOOLevel{N, TI, Ptr, Tbl, Lvl}(lvl, TI(shape), ptr, tbl) Base.summary(lvl::SparseCOOLevel{N}) where {N} = "SparseCOO{$N}($(summary(lvl.lvl)))" -similar_level(lvl::SparseCOOLevel{N}) where {N} = SparseCOOLevel{N}(similar_level(lvl.lvl)) -similar_level(lvl::SparseCOOLevel{N}, tail...) where {N} = SparseCOOLevel{N}(similar_level(lvl.lvl, tail[1:end-N]...), (tail[end-N+1:end]...,)) +similar_level(lvl::SparseCOOLevel{N}, fill_value, eltype::Type, tail...) where {N} = + SparseCOOLevel{N}(similar_level(lvl.lvl, fill_value, eltype, tail[1:end-N]...), (tail[end-N+1:end]...,)) function postype(::Type{SparseCOOLevel{N, TI, Ptr, Tbl, Lvl}}) where {N, TI, Ptr, Tbl, Lvl} return postype(Lvl) diff --git a/src/tensors/levels/sparsehashlevels.jl b/src/tensors/levels/sparsehashlevels.jl index 170775ba9..2cbf01a90 100644 --- a/src/tensors/levels/sparsehashlevels.jl +++ b/src/tensors/levels/sparsehashlevels.jl @@ -60,8 +60,8 @@ SparseHashLevel{N, TI}(lvl::Lvl, shape, ptr::Ptr, tbl::Tbl, srt::Srt) where {N, SparseHashLevel{N, TI, Ptr, Tbl, Srt, Lvl}(lvl, shape, ptr, tbl, srt) Base.summary(lvl::SparseHashLevel{N}) where {N} = "SparseHash{$N}($(summary(lvl.lvl)))" -similar_level(lvl::SparseHashLevel{N}) where {N} = SparseHashLevel{N}(similar_level(lvl.lvl)) -similar_level(lvl::SparseHashLevel{N}, tail...) where {N} = SparseHashLevel{N}(similar_level(lvl.lvl, tail[1:end-N]...), (tail[end-N+1:end]...,)) +similar_level(lvl::SparseHashLevel{N}, fill_value, eltype::Type, tail...) where {N} = + SparseHashLevel{N}(similar_level(lvl.lvl, fill_value, eltype, tail[1:end-N]...), (tail[end-N+1:end]...,)) function postype(::Type{SparseHashLevel{N, TI, Ptr, Tbl, Srt, Lvl}}) where {N, TI, Ptr, Tbl, Srt, Lvl} return postype(Lvl) diff --git a/src/tensors/levels/sparselevels.jl b/src/tensors/levels/sparselevels.jl index 38b21a6b0..e8e44affa 100644 --- a/src/tensors/levels/sparselevels.jl +++ b/src/tensors/levels/sparselevels.jl @@ -158,8 +158,8 @@ SparseLevel{Ti}(lvl::Lvl, shape, tbl::Tbl) where {Ti, Lvl, Tbl} = SparseLevel{Ti, Tbl, Lvl}(lvl, shape, tbl) Base.summary(lvl::SparseLevel) = "Sparse($(summary(lvl.lvl)))" -similar_level(lvl::SparseLevel) = Sparse(similar_level(lvl.lvl)) -similar_level(lvl::SparseLevel, dim, tail...) = Sparse(similar_level(lvl.lvl, tail...), dim) +similar_level(lvl::SparseLevel, fill_value, eltype::Type, dim, tail...) = + Sparse(similar_level(lvl.lvl, fill_value, eltype, tail...), dim) function postype(::Type{SparseLevel{Ti, Tbl, Lvl}}) where {Ti, Tbl, Lvl} return postype(Lvl) diff --git a/src/tensors/levels/sparselistlevels.jl b/src/tensors/levels/sparselistlevels.jl index a195f55c5..bd2471db5 100644 --- a/src/tensors/levels/sparselistlevels.jl +++ b/src/tensors/levels/sparselistlevels.jl @@ -48,8 +48,8 @@ SparseListLevel{Ti}(lvl::Lvl, shape, ptr::Ptr, idx::Idx) where {Ti, Lvl, Ptr, Id SparseListLevel{Ti, Ptr, Idx, Lvl}(lvl, shape, ptr, idx) Base.summary(lvl::SparseListLevel) = "SparseList($(summary(lvl.lvl)))" -similar_level(lvl::SparseListLevel) = SparseList(similar_level(lvl.lvl)) -similar_level(lvl::SparseListLevel, dim, tail...) = SparseList(similar_level(lvl.lvl, tail...), dim) +similar_level(lvl::SparseListLevel, fill_value, eltype::Type, dim, tail...) = + SparseList(similar_level(lvl.lvl, fill_value, eltype, tail...), dim) function postype(::Type{SparseListLevel{Ti, Ptr, Idx, Lvl}}) where {Ti, Ptr, Idx, Lvl} return postype(Lvl) diff --git a/src/tensors/levels/sparserlelevels.jl b/src/tensors/levels/sparserlelevels.jl index d3b39c12f..3fe0975ab 100644 --- a/src/tensors/levels/sparserlelevels.jl +++ b/src/tensors/levels/sparserlelevels.jl @@ -39,8 +39,8 @@ SparseRLELevel{Ti}(lvl::Lvl, shape, ptr::Ptr, left::Left, right::Right, buf::Lvl SparseRLELevel{Ti, Ptr, Left, Right, Lvl}(lvl, Ti(shape), ptr, left, right, buf) Base.summary(lvl::SparseRLELevel) = "SparseRLE($(summary(lvl.lvl)))" -similar_level(lvl::SparseRLELevel) = SparseRLE(similar_level(lvl.lvl)) -similar_level(lvl::SparseRLELevel, dim, tail...) = SparseRLE(similar_level(lvl.lvl, tail...), dim) +similar_level(lvl::SparseRLELevel, fill_value, eltype::Type, dim, tail...) = + SparseRLE(similar_level(lvl.lvl, fill_value, eltype, tail...), dim) function postype(::Type{SparseRLELevel{Ti, Ptr, Left, Right, Lvl}}) where {Ti, Ptr, Left, Right, Lvl} return postype(Lvl) diff --git a/src/tensors/levels/sparsetrianglelevels.jl b/src/tensors/levels/sparsetrianglelevels.jl index 1951373de..e395c08c2 100644 --- a/src/tensors/levels/sparsetrianglelevels.jl +++ b/src/tensors/levels/sparsetrianglelevels.jl @@ -34,8 +34,8 @@ SparseTriangleLevel{N, Ti, Lvl}(lvl) where {N, Ti, Lvl} = SparseTriangleLevel{N, const SparseTriangle = SparseTriangleLevel Base.summary(lvl::SparseTriangle{N}) where {N} = "SparseTriangle{$N}($(summary(lvl.lvl)))" -similar_level(lvl::SparseTriangle{N}) where {N} = SparseTriangle(similar_level(lvl.lvl)) -similar_level(lvl::SparseTriangle{N}, dims...) where {N} = SparseTriangle(similar_level(lvl.lvl, dims[1:end-1]...), dims[end]) +similar_level(lvl::SparseTriangle{N}, fill_value, eltype::Type, dims...) where {N} = + SparseTriangle(similar_level(lvl.lvl, fill_value, eltype, dims[1:end-1]...), dims[end]) function postype(::Type{SparseTriangleLevel{N, Ti, Lvl}}) where {N, Ti, Lvl} return postype(Lvl) diff --git a/src/tensors/levels/sparsevbllevels.jl b/src/tensors/levels/sparsevbllevels.jl index 6163f0c0b..68360536b 100644 --- a/src/tensors/levels/sparsevbllevels.jl +++ b/src/tensors/levels/sparsevbllevels.jl @@ -52,8 +52,8 @@ function moveto(lvl::SparseVBLLevel{Ti}, device) where {Ti} end Base.summary(lvl::SparseVBLLevel) = "SparseVBL($(summary(lvl.lvl)))" -similar_level(lvl::SparseVBLLevel) = SparseVBL(similar_level(lvl.lvl)) -similar_level(lvl::SparseVBLLevel, dim, tail...) = SparseVBL(similar_level(lvl.lvl, tail...), dim) +similar_level(lvl::SparseVBLLevel, fill_value, eltype::Type, dim, tail...) = + SparseVBL(similar_level(lvl.lvl, fill_value, eltype, tail...), dim) pattern!(lvl::SparseVBLLevel{Ti}) where {Ti} = SparseVBLLevel{Ti}(pattern!(lvl.lvl), lvl.shape, lvl.ptr, lvl.idx, lvl.ofs) diff --git a/test/test_constructors.jl b/test/test_constructors.jl index f481c0bf4..22f6e8155 100644 --- a/test/test_constructors.jl +++ b/test/test_constructors.jl @@ -62,6 +62,42 @@ @test Structure(fbr) == Structure(Tensor(Lvl{Int16}(Element(0.0)))) @test Structure(fbr) == Structure(Tensor(Lvl(Element(0.0), Int16(0)))) @test Structure(fbr) == Structure(Tensor(Lvl{Int16}(Element(0.0), 0))) + + if key == "SingleList" || key == "SingleRLE" + continue # don't test similar for Single* + end + if key == "SparseBand" + continue # https://github.com/willow-ahrens/Finch.jl/issues/443 + end + + fbr = Tensor(Dense(Lvl(Element(Int64(0)))), Matrix(reshape(1:25, (5, 5)))) + res = similar(fbr) + @test size(res) == size(fbr) + @test default(res) == 0 && eltype(res) == Int64 + + res = similar(fbr, (10, 5)) + @test size(res) == (10, 5) + @test default(res) == 0 && eltype(res) == Int64 + + res = similar(fbr, Float64) + @test size(res) == size(fbr) + @test default(res) == 0 && eltype(res) == Float64 + + res = similar(fbr, 1, Float64) + @test size(res) == size(fbr) + @test default(res) == 1 && eltype(res) == Float64 + + res = similar(fbr, ComplexF32, (10, 5)) + @test size(res) == (10, 5) + @test default(res) == 0 && eltype(res) == ComplexF32 + + res = similar(fbr, 2, ComplexF64, (10, 5)) + @test size(res) == (10, 5) + @test default(res) == 2 && eltype(res) == ComplexF64 + + res = copyto!(similar(fbr, -1, Float64), fbr) + @test res == fbr + @test default(res) == -1 && eltype(res) == Float64 end @test check_output("constructors/format_$key.txt", String(take!(io))) @@ -136,6 +172,11 @@ @test Structure(fbr) == Structure(Tensor(Lvl{N}(Element(0.0), Int16.(zerodim)))) @test Structure(fbr) == Structure(Tensor(Lvl{N, NTuple{N, Int16}}(Element(0.0), Int16.(zerodim)))) + fbr = Tensor(Lvl{2}(Element(0)), Matrix(reshape(1:25, (5, 5)))) + res = copyto!(similar(fbr, -1, Float64), fbr) + @test res == fbr + @test default(res) == -1 && eltype(res) == Float64 + end @test check_output("constructors/format_$(key).txt", String(take!(io))) end @@ -174,6 +215,11 @@ println(io, "empty tensor: ", fbr) @test Structure(fbr) == Structure(Tensor(Dense(Separate(Dense(Element(0)))))) + fbr = Tensor(Dense(Separate(Dense(Element(0)))), Matrix(reshape(1:25, (5, 5)))) + res = copyto!(similar(fbr, -1, Float64), fbr) + @test res == fbr + @test default(res) == -1 && eltype(res) == Float64 + @test check_output("constructors/format_d_p_d_e.txt", String(take!(io))) end @@ -198,6 +244,11 @@ println(io, "empty tensor: ", fbr) @test Structure(fbr) == Structure(Tensor(Dense(Atomic(Dense(Element(0)))))) + fbr = Tensor(Dense(Atomic(Dense(Element(0)))), Matrix(reshape(1:25, (5, 5)))) + res = copyto!(similar(fbr, -1, Float64), fbr) + @test res == fbr + @test default(res) == -1 && eltype(res) == Float64 + @test check_output("constructors/format_d_a_d_e.txt", String(take!(io))) end @@ -230,4 +281,5 @@ @test obov == [val, val, val, 4] && obov.data == [val-1, val-1, val-1, 3] end + end