Skip to content

Commit

Permalink
Make flatten work on tables too
Browse files Browse the repository at this point in the history
Add a test for flatten

Fix eltype from table
  • Loading branch information
asinghvi17 committed Sep 22, 2024
1 parent c4c3a29 commit ffc87f1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/primitives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,13 @@ flatten(f, ::Type{Target}, geom) where {Target<:GI.AbstractTrait} = _flatten(f,

_flatten(f, ::Type{Target}, geom) where Target = _flatten(f, Target, GI.trait(geom), geom)
# Try to flatten over iterables
_flatten(f, ::Type{Target}, ::Nothing, iterable) where Target =
Iterators.flatten(Iterators.map(x -> _flatten(f, Target, x), iterable))
function _flatten(f, ::Type{Target}, ::Nothing, iterable) where Target
if Tables.istable(iterable)
Iterators.flatten(Iterators.map(x -> _flatten(f, Target, x), Tables.getcolumn(iterable, first(GI.geometrycolumns(iterable)))))

Check warning on line 478 in src/primitives.jl

View check run for this annotation

Codecov / codecov/patch

src/primitives.jl#L478

Added line #L478 was not covered by tests
else
Iterators.flatten(Iterators.map(x -> _flatten(f, Target, x), iterable))
end
end
# Flatten feature collections
function _flatten(f, ::Type{Target}, ::GI.FeatureCollectionTrait, fc) where Target
Iterators.map(GI.getfeature(fc)) do feature
Expand Down
19 changes: 19 additions & 0 deletions test/primitives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,25 @@ end
@test GO._tuple_point.(GO.flatten(GI.PointTrait, very_wrapped)) == vcat(pv1, pv2)
@test collect(GO.flatten(GI.AbstractCurveTrait, [poly])) == [lr1, lr2]
@test collect(GO.flatten(GI.x, GI.PointTrait, very_wrapped)) == first.(vcat(pv1, pv2))
@testset "flatten with tables" begin
# Construct a simple table with a geometry column
geom_column = [GI.Point(1.0,1.0), GI.Point(2.0,2.0), GI.Point(3.0,3.0)]
table = (geometry = geom_column, id = [1, 2, 3])

# Test flatten on the table
flattened = collect(GO.flatten(GI.PointTrait, table))

@test length(flattened) == 3
@test all(p isa GI.Point for p in flattened)
@test flattened == geom_column

# Test flatten with a function
flattened_coords = collect(GO.flatten(p -> (GI.x(p), GI.y(p)), GI.PointTrait, table))

@test length(flattened_coords) == 3
@test all(c isa Tuple{Float64,Float64} for c in flattened_coords)
@test flattened_coords == [(1.0,1.0), (2.0,2.0), (3.0,3.0)]
end
end

@testset "reconstruct" begin
Expand Down

0 comments on commit ffc87f1

Please sign in to comment.