
Commit

Merge pull request #450 from willow-ahrens/wma/fix_countstored
Wma/fix countstored
willow-ahrens authored Mar 6, 2024
2 parents 08dfbba + be68a34 commit e6544d8
Showing 131 changed files with 1,258 additions and 356 deletions.
14 changes: 12 additions & 2 deletions docs/examples/bfs.jl
@@ -12,14 +12,15 @@ function bfs(edges, source=5)
     F = Tensor(SparseByteMap(Pattern()), n)
     _F = Tensor(SparseByteMap(Pattern()), n)
     @finch F[source] = true
+    F_nnz = 1

     V = Tensor(Dense(Element(false)), n)
     @finch V[source] = true

     P = Tensor(Dense(Element(0)), n)
     @finch P[source] = source

-    while countstored(F) > 0
+    while F_nnz > 0
         @finch begin
             _F .= false
             for j=_, k=_
@@ -29,8 +30,17 @@
                 end
             end
         end
-        @finch for k=_; V[k] |= _F[k] end
+        c = Scalar(0)
+        @finch begin
+            for k=_
+                let _f = _F[k]
+                    V[k] |= _f
+                    c[] += _f
+                end
+            end
+        end
         (F, _F) = (_F, F)
+        F_nnz = c[]
     end
     return P
 end
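Note: countstored on the byte-map frontier is no longer used as the loop condition; with the sparsebytemaplevels.jl change below it reports allocated slots (pos * shape), so the loop now tallies the frontier size itself with a Scalar. A minimal sketch of that counting pattern, using only the constructors already shown in this example; the sizes are illustrative:

    using Finch
    # Count the entries of a boolean frontier with a Scalar accumulator,
    # mirroring the updated loop body above.
    n, source = 10, 5
    F = Tensor(SparseByteMap(Pattern()), n)
    @finch F[source] = true
    c = Scalar(0)
    @finch for k=_
        c[] += F[k]
    end
    c[]   # 1 entry in the frontier, the value the loop stores in F_nnz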
2 changes: 1 addition & 1 deletion src/looplets/steppers.jl
@@ -100,7 +100,7 @@ function lower(root::FinchNode, ctx::AbstractCompiler, style::StepperStyle)
         push!(ctx.code.preamble, stepper_seek(node.val, ctx, root.ext))
     end

-    if style.count == 1
+    if style.count == 1 && !query(call(==, measure(root.ext.val), get_smallest_measure(root.ext.val)), ctx)
         body_2 = contain(ctx) do ctx_2
             push!(ctx_2.code.preamble, :($i0 = $i))
             i1 = freshen(ctx_2.code, i)
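Note: the added guard skips the count == 1 stepper specialization when the loop extent is provably a single point. A Finch-independent illustration of the condition; the names below are stand-ins, not compiler API:

    # A single-point extent gains nothing from a "take one bounded step, then
    # continue" split, so the specialized lowering is skipped for it.
    extent = 4:4                            # hypothetical single-point extent
    is_single_point = length(extent) == 1   # stands in for measure(ext) == get_smallest_measure(ext)
    apply_stepper_split = !is_single_point  # false here, so the plain body is used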
4 changes: 2 additions & 2 deletions src/symbolic/analyze_bounds.jl
@@ -68,11 +68,11 @@ function get_bounds_rules(alg, shash)
         end),

         (@rule call(max, ~a1..., call(min, ~a2...), ~a3..., call(min, ~a4...), ~a5...) => if !(isdisjoint(a2, a4))
-            call(max, a1..., call(min, intersect(a2, a4)..., call(max, call(min, setdiff(a2, a4)...), call(min, setdiff(a4, a2)...)), a3..., a5...))
+            call(max, a1..., call(min, intersect(a2, a4)..., call(max, call(min, setdiff(a2, a4)...), call(min, setdiff(a4, a2)...))), a3..., a5...)
         end),

         (@rule call(min, ~a1..., call(max, ~a2...), ~a3..., call(max, ~a4...), ~a5...) => if !(isdisjoint(a2, a4))
-            call(min, a1..., call(max, intersect(a2, a4)..., call(min, call(max, setdiff(a2, a4)...), call(max, setdiff(a4, a2)...)), a3..., a5...))
+            call(min, a1..., call(max, intersect(a2, a4)..., call(min, call(max, setdiff(a2, a4)...), call(max, setdiff(a4, a2)...))), a3..., a5...)
         end),

         (@rule call(min, ~a1..., call(max), ~a2...) => call(min, a1..., a2...)),
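Note: the fix moves the trailing arguments (a3..., a5...) back outside the rebuilt inner min/max call, so each rule still encodes the distributive lattice identity it relies on. A quick numeric check of that identity in plain Julia, with arbitrary values:

    # max distributes over min, and dually min over max:
    # max(min(a, b), min(a, c)) == min(a, max(b, c))
    # min(max(a, b), max(a, c)) == max(a, min(b, c))
    for (a, b, c) in ((3, 5, 1), (-2, 0, 7), (4, 4, 4))
        @assert max(min(a, b), min(a, c)) == min(a, max(b, c))
        @assert min(max(a, b), max(a, c)) == max(a, min(b, c))
    end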
2 changes: 1 addition & 1 deletion src/tensors/levels/sparsebytemaplevels.jl
@@ -69,7 +69,7 @@ Base.resize!(lvl::SparseByteMapLevel{Ti}, dims...) where {Ti} =
     SparseByteMapLevel{Ti}(resize!(lvl.lvl, dims[1:end-1]...), dims[end], lvl.ptr, lvl.tbl, lvl.srt)

 function countstored_level(lvl::SparseByteMapLevel, pos)
-    countstored_level(lvl.lvl, lvl.ptr[pos + 1] - 1)
+    countstored_level(lvl.lvl, pos * lvl.shape)
 end

 function Base.show(io::IO, lvl::SparseByteMapLevel{Ti, Ptr, Tbl, Srt, Lvl},) where {Ti, Ptr, Tbl, Srt, Lvl}
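Note: a byte-map's child level holds a slot for every index of every position, so the stored count through position pos is now taken directly as pos * shape instead of being derived from ptr. A small arithmetic sketch with hypothetical sizes:

    # A bytemap level of shape 5 holding 3 positions owns shape slots per
    # position in its child level, whether or not the entries have been set.
    shape, pos = 5, 3
    pos * shape   # 15 child positions back the bytemap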
16 changes: 14 additions & 2 deletions src/tensors/levels/sparselevels.jl
@@ -86,7 +86,7 @@ function subtable_seek(tbl, subtbl, state, i, j)
 end

 function subtable_seek(tbl::DictTable, (p, start, stop), q, i, j)
-    q = Finch.scansearch(tbl.idx, j, q, stop)
+    q = Finch.scansearch(tbl.idx, j, q, stop - 1)
     return (tbl.idx[q], q)
 end

@@ -175,7 +175,19 @@ function moveto(lvl::SparseLevel{Ti, Tbl, Lvl}, Tm) where {Ti, Tbl, Lvl}
 end

 function countstored_level(lvl::SparseLevel, pos)
-    countstored_level(lvl.lvl, lvl.ptr[pos + 1] - 1)
+    pos == 0 && return countstored_level(lvl.lvl, pos)
+    subtbl = table_query(lvl.tbl, pos)
+    start, stop, state = subtable_init(lvl.tbl, subtbl)
+    if start <= stop
+        i, qos = subtable_get(lvl.tbl, subtbl, state)
+        if i < stop
+            i, state = subtable_seek(lvl.tbl, subtbl, state, start, stop)
+            i, qos = subtable_get(lvl.tbl, subtbl, state)
+        end
+        countstored_level(lvl.lvl, qos)
+    else
+        0
+    end
 end

 pattern!(lvl::SparseLevel{Ti}) where {Ti} =
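Note: the rewritten countstored_level walks the sub-table for pos and hands the child position (qos) of its last stored entry to the child level's countstored, returning 0 when the sub-table is empty. A toy model of that idea in plain Julia; the arrays are hypothetical, not Finch's DictTable:

    # One sub-table: `idx` holds the stored indices in order, `qos` the child
    # position of each entry. Seeking to the last entry yields the child
    # position that bounds everything stored so far.
    idx = [2, 5, 9]                          # hypothetical stored indices
    qos = [4, 5, 6]                          # hypothetical child positions
    last_qos = isempty(idx) ? 0 : qos[end]   # 6, passed on to the child level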
2 changes: 1 addition & 1 deletion src/tensors/levels/sparserlelevels.jl
@@ -64,7 +64,7 @@ pattern!(lvl::SparseRLELevel{Ti}) where {Ti} =
     SparseRLELevel{Ti}(pattern!(lvl.lvl), lvl.shape, lvl.ptr, lvl.left, lvl.right, pattern!(lvl.buf); merge = getmerge(lvl))

 function countstored_level(lvl::SparseRLELevel, pos)
-    countstored_level(lvl.lvl, lvl.left[lvl.ptr[pos + 1]]-1)
+    countstored_level(lvl.lvl, lvl.ptr[pos + 1]-1)
 end

 redefault!(lvl::SparseRLELevel{Ti}, init) where {Ti} =
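Note: for a run-length-encoded level, ptr[pos + 1] - 1 is the total number of runs recorded through position pos, and each run stores one child value, so it is the right argument for the child's countstored; left holds run start coordinates, not a count. A toy illustration with hypothetical arrays:

    ptr  = [1, 3, 6]        # CSR-style: 2 runs at position 1, 3 runs at position 2
    left = [1, 4, 1, 3, 5]  # run start coordinates, not a count of children
    pos = 2
    ptr[pos + 1] - 1        # 5 runs stored through position 2, hence 5 children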
4 changes: 3 additions & 1 deletion src/transforms/scopes.jl
@@ -51,10 +51,12 @@ function (ctx::ScopeVisitor)(node::FinchNode)
         if node.lhs.kind != variable
             throw(ScopeError("cannot define a non-variable $node.lhs"))
         end
+        #TODO why not just freshen variables?
+        rhs = ctx(node.rhs)
         var = node.lhs
         haskey(ctx.vars, var) && throw(ScopeError("In node $(node) variable $(var) is already bound."))
         ctx.vars[var] = node.rhs
-        define(node.lhs, node.rhs, open_scope(node.body, ctx))
+        define(node.lhs, rhs, open_scope(node.body, ctx))
     elseif istree(node)
         return similarterm(node, operation(node), map(ctx, arguments(node)))
     else
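Note: the branch now runs the ScopeVisitor over the right-hand side and puts the visited rhs into the rebuilt define node, so names inside the rhs are resolved in the enclosing scope instead of being passed through raw. A toy, Finch-independent sketch of the visit-before-bind idea; define! and the Dict environment are hypothetical stand-ins:

    # Resolve the rhs against the current environment *before* recording the binding.
    function define!(env::Dict{Symbol,Any}, lhs::Symbol, rhs)
        resolved = get(env, rhs, rhs)   # "visit" the rhs in the current scope
        haskey(env, lhs) && error("variable $lhs is already bound")
        env[lhs] = resolved
        return resolved
    end

    env = Dict{Symbol,Any}(:y => 42)
    define!(env, :x, :y)   # x binds to 42, not to the unresolved symbol :y
    env[:x]                # 42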
12 changes: 12 additions & 0 deletions test/reference32/representation/DenseRLELazy_representation.txt

Large diffs are not rendered by default.

@@ -2,16 +2,23 @@ DenseRLELazy{Dense} representation:

5x5_falses: Bool[0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0]
tensor: Tensor(DenseRLE{Int32}(Dense{Int32}(Element{false, Bool, Int32}(Bool[0, 0, 0, 0, 0]), 5), 5, [1, 2], [5], Dense{Int32}(Element{false, Bool, Int32}(Bool[]), 5); merge = false))
countstored: 5
5x5_trues: Bool[1 1 1 1 1; 1 1 1 1 1; 1 1 1 1 1; 1 1 1 1 1; 1 1 1 1 1]
tensor: Tensor(DenseRLE{Int32}(Dense{Int32}(Element{false, Bool, Int32}(Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 5), 5, [1, 6], [1, 2, 3, 4, 5], Dense{Int32}(Element{false, Bool, Int32}(Bool[]), 5); merge = false))
countstored: 25
4x4_one_bool: Bool[0 0 0 1; 0 0 0 0; 1 0 0 0; 0 1 0 0]
tensor: Tensor(DenseRLE{Int32}(Dense{Int32}(Element{false, Bool, Int32}(Bool[0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]), 4), 4, [1, 5], [1, 2, 3, 4], Dense{Int32}(Element{false, Bool, Int32}(Bool[]), 4); merge = false))
countstored: 16
4x4_bool_mix: Bool[0 1 0 1; 0 0 0 0; 1 1 1 1; 0 1 0 1]
tensor: Tensor(DenseRLE{Int32}(Dense{Int32}(Element{false, Bool, Int32}(Bool[0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1]), 4), 4, [1, 5], [1, 2, 3, 4], Dense{Int32}(Element{false, Bool, Int32}(Bool[]), 4); merge = false))
countstored: 16
5x5_zeros: [0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0]
tensor: Tensor(DenseRLE{Int32}(Dense{Int32}(Element{0.0, Float64, Int32}([0.0, 0.0, 0.0, 0.0, 0.0]), 5), 5, [1, 2], [5], Dense{Int32}(Element{0.0, Float64, Int32}(Float64[]), 5); merge = false))
countstored: 5
5x5_ones: [1.0 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0 1.0]
tensor: Tensor(DenseRLE{Int32}(Dense{Int32}(Element{0.0, Float64, Int32}([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]), 5), 5, [1, 6], [1, 2, 3, 4, 5], Dense{Int32}(Element{0.0, Float64, Int32}(Float64[]), 5); merge = false))
countstored: 25
5x5_float_mix: [0.0 1.0 2.0 2.0 3.0; 0.0 0.0 0.0 0.0 0.0; 1.0 1.0 2.0 0.0 0.0; 0.0 0.0 0.0 3.0 0.0; 0.0 0.0 0.0 0.0 0.0]
tensor: Tensor(DenseRLE{Int32}(Dense{Int32}(Element{0.0, Float64, Int32}([0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 2.0, 0.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0]), 5), 5, [1, 6], [1, 2, 3, 4, 5], Dense{Int32}(Element{0.0, Float64, Int32}(Float64[]), 5); merge = false))
countstored: 25
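Note: the added countstored lines for a Dense child follow from the printed run structure: a Dense child stores one slot per index in every run, so countstored = (number of runs) * (child length). Worked from the ptr arrays shown above:

    child_len = 5
    runs_falses = 1           # ptr = [1, 2] -> one run covering the whole axis
    runs_trues  = 5           # ptr = [1, 6] -> five runs
    runs_falses * child_len   # 5,  matching "countstored: 5"
    runs_trues  * child_len   # 25, matching "countstored: 25"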

@@ -2,16 +2,23 @@ DenseRLELazy{SparseList} representation:

5x5_falses: Bool[0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0]
tensor: Tensor(DenseRLE{Int32}(SparseList{Int32}(Element{false, Bool, Int32}(Bool[]), 5, [1, 1], Int32[]), 5, [1, 2], [5], SparseList{Int32}(Element{false, Bool, Int32}(Bool[]), 5, [1], Int32[]); merge = false))
countstored: 0
5x5_trues: Bool[1 1 1 1 1; 1 1 1 1 1; 1 1 1 1 1; 1 1 1 1 1; 1 1 1 1 1]
tensor: Tensor(DenseRLE{Int32}(SparseList{Int32}(Element{false, Bool, Int32}(Bool[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]), 5, [1, 6, 11, 16, 21, 26], [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5]), 5, [1, 6], [1, 2, 3, 4, 5], SparseList{Int32}(Element{false, Bool, Int32}(Bool[]), 5, [1], Int32[]); merge = false))
countstored: 25
4x4_one_bool: Bool[0 0 0 1; 0 0 0 0; 1 0 0 0; 0 1 0 0]
tensor: Tensor(DenseRLE{Int32}(SparseList{Int32}(Element{false, Bool, Int32}(Bool[1, 1, 1]), 4, [1, 2, 3, 3, 4], [3, 4, 1]), 4, [1, 5], [1, 2, 3, 4], SparseList{Int32}(Element{false, Bool, Int32}(Bool[]), 4, [1], Int32[]); merge = false))
countstored: 3
4x4_bool_mix: Bool[0 1 0 1; 0 0 0 0; 1 1 1 1; 0 1 0 1]
tensor: Tensor(DenseRLE{Int32}(SparseList{Int32}(Element{false, Bool, Int32}(Bool[1, 1, 1, 1, 1, 1, 1, 1]), 4, [1, 2, 5, 6, 9], [3, 1, 3, 4, 3, 1, 3, 4]), 4, [1, 5], [1, 2, 3, 4], SparseList{Int32}(Element{false, Bool, Int32}(Bool[]), 4, [1], Int32[]); merge = false))
countstored: 8
5x5_zeros: [0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0]
tensor: Tensor(DenseRLE{Int32}(SparseList{Int32}(Element{0.0, Float64, Int32}(Float64[]), 5, [1, 1], Int32[]), 5, [1, 2], [5], SparseList{Int32}(Element{0.0, Float64, Int32}(Float64[]), 5, [1], Int32[]); merge = false))
countstored: 0
5x5_ones: [1.0 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0 1.0; 1.0 1.0 1.0 1.0 1.0]
tensor: Tensor(DenseRLE{Int32}(SparseList{Int32}(Element{0.0, Float64, Int32}([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]), 5, [1, 6, 11, 16, 21, 26], [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5]), 5, [1, 6], [1, 2, 3, 4, 5], SparseList{Int32}(Element{0.0, Float64, Int32}(Float64[]), 5, [1], Int32[]); merge = false))
countstored: 25
5x5_float_mix: [0.0 1.0 2.0 2.0 3.0; 0.0 0.0 0.0 0.0 0.0; 1.0 1.0 2.0 0.0 0.0; 0.0 0.0 0.0 3.0 0.0; 0.0 0.0 0.0 0.0 0.0]
tensor: Tensor(DenseRLE{Int32}(SparseList{Int32}(Element{0.0, Float64, Int32}([1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0]), 5, [1, 2, 4, 6, 8, 9], [3, 1, 3, 1, 3, 1, 4, 1]), 5, [1, 6], [1, 2, 3, 4, 5], SparseList{Int32}(Element{0.0, Float64, Int32}(Float64[]), 5, [1], Int32[]); merge = false))
countstored: 8
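Note: with a SparseList child, the added countstored values are simply the number of explicit entries across all runs, i.e. the child SparseList's ptr[end] - 1 in the printouts above. For example:

    ptr_float_mix = [1, 2, 4, 6, 8, 9]   # child SparseList ptr for 5x5_float_mix
    ptr_float_mix[end] - 1               # 8, matching "countstored: 8"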
