Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add transaction support to write. #72

Merged
merged 4 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Options for the JuliaFormatter auto syntax formatting tool.
# https://domluna.github.io/JuliaFormatter.jl/stable/

whitespace_ops_in_indices = true
remove_extra_newlines = true
always_for_in = true
whitespace_typedefs = true

# And add other options we like:
separate_kwargs_with_semicolon = true
149 changes: 114 additions & 35 deletions src/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function find_driver(fn::AbstractString)
AG.extensiondriver(fn)
end

const lookup_type = Dict{Tuple{DataType,Int},AG.OGRwkbGeometryType}(
const lookup_type = Dict{Tuple{DataType, Int}, AG.OGRwkbGeometryType}(
(AG.GeoInterface.PointTrait, 2) => AG.wkbPoint,
(AG.GeoInterface.PointTrait, 3) => AG.wkbPoint25D,
(AG.GeoInterface.PointTrait, 4) => AG.wkbPointZM,
Expand All @@ -42,7 +42,6 @@ const lookup_type = Dict{Tuple{DataType,Int},AG.OGRwkbGeometryType}(
(AG.GeoInterface.MultiPolygonTrait, 4) => AG.wkbMultiPolygonZM,
)


"""
read(fn::AbstractString; kwargs...)
read(fn::AbstractString, layer::Union{Integer,AbstractString}; kwargs...)
Expand All @@ -61,7 +60,7 @@ function read(fn::AbstractString; kwargs...)
return t
end

function read(fn::AbstractString, layer::Union{Integer,AbstractString}; kwargs...)
function read(fn::AbstractString, layer::Union{Integer, AbstractString}; kwargs...)
startswith(fn, "/vsi") || occursin(":", fn) || isfile(fn) || error("File not found.")
t = AG.read(fn; kwargs...) do ds
return read(ds, layer)
Expand All @@ -72,7 +71,11 @@ end
function read(ds, layer)
df, gnames, sr = AG.getlayer(ds, layer) do table
if table.ptr == C_NULL
throw(ArgumentError("Given layer id/name doesn't exist. For reference this is the dataset:\n$ds"))
throw(
ArgumentError(
"Given layer id/name doesn't exist. For reference this is the dataset:\n$ds",
),
)
end
names, _ = AG.schema_names(AG.getfeaturedefn(first(table)))
sr = AG.getspatialref(table)
Expand All @@ -84,12 +87,12 @@ function read(ds, layer)
end
crs = sr.ptr == C_NULL ? nothing : GFT.WellKnownText(GFT.CRS(), AG.toWKT(sr))
geometrycolumns = Tuple(gnames)
metadata!(df, "crs", crs, style=:default)
metadata!(df, "geometrycolumns", geometrycolumns, style=:default)
metadata!(df, "crs", crs; style = :default)
metadata!(df, "geometrycolumns", geometrycolumns; style = :default)

# Also add the GEOINTERFACE:property as a namespaced thing
metadata!(df, "GEOINTERFACE:crs", crs, style=:default)
metadata!(df, "GEOINTERFACE:geometrycolumns", geometrycolumns, style=:default)
metadata!(df, "GEOINTERFACE:crs", crs; style = :default)
metadata!(df, "GEOINTERFACE:geometrycolumns", geometrycolumns; style = :default)
return df
end

Expand All @@ -98,12 +101,24 @@ end

Write the provided `table` to `fn`. The `geom_column` is expected to hold ArchGDAL geometries.
"""
function write(fn::AbstractString, table; layer_name::AbstractString="data", crs::Union{GFT.GeoFormat,Nothing}=getcrs(table), driver::Union{Nothing,AbstractString}=nothing, options::Dict{String,String}=Dict{String,String}(), geom_columns=getgeometrycolumns(table), kwargs...)
function write(
fn::AbstractString,
table;
layer_name::AbstractString = "data",
crs::Union{GFT.GeoFormat, Nothing} = getcrs(table),
driver::Union{Nothing, AbstractString} = nothing,
options::Dict{String, String} = Dict{String, String}(),
geom_columns = getgeometrycolumns(table),
chunksize = 20_000,
kwargs...,
)
rows = Tables.rows(table)
sch = Tables.schema(rows)

# Determine geometry columns
isnothing(geom_columns) && error("Please set `geom_columns` kw or define `GeoInterface.geometrycolumns` for $(typeof(table))")
isnothing(geom_columns) && error(
"Please set `geom_columns` kw or define `GeoInterface.geometrycolumns` for $(typeof(table))",
)
if :geom_column in keys(kwargs) # backwards compatible
geom_columns = (kwargs[:geom_column],)
end
Expand All @@ -113,7 +128,11 @@ function write(fn::AbstractString, table; layer_name::AbstractString="data", crs
trait = AG.GeoInterface.geomtrait(getproperty(first(rows), geom_column))
ndim = AG.GeoInterface.ncoord(getproperty(first(rows), geom_column))
geom_type = get(lookup_type, (typeof(trait), ndim), nothing)
isnothing(geom_type) && throw(ArgumentError("Can't convert $trait with $ndim dimensions of column $geom_column to ArchGDAL yet."))
isnothing(geom_type) && throw(
ArgumentError(
"Can't convert $trait with $ndim dimensions of column $geom_column to ArchGDAL yet.",
),
)
push!(geom_types, geom_type)
end

Expand All @@ -130,10 +149,12 @@ function write(fn::AbstractString, table; layer_name::AbstractString="data", crs
end

# Figure out attributes
fields = Vector{Tuple{Symbol,DataType}}()
fields = Vector{Tuple{Symbol, DataType}}()
for (name, type) in zip(sch.names, sch.types)
if !(name in geom_columns)
AG.GeoInterface.isgeometry(type) && error("Did you mean to use the `geom_columns` argument to specify $name is a geometry?")
AG.GeoInterface.isgeometry(type) && error(
"Did you mean to use the `geom_columns` argument to specify $name is a geometry?",
)
types = Base.uniontypes(type)
if length(types) == 1
push!(fields, (Symbol(name), type))
Expand All @@ -144,47 +165,105 @@ function write(fn::AbstractString, table; layer_name::AbstractString="data", crs
end
end
end
AG.create(
fn,
driver=driver
) do ds
AG.create(fn; driver = driver) do ds
AG.newspatialref() do spatialref
crs !== nothing && AG.importCRS!(spatialref, crs)
AG.createlayer(
name=layer_name,
geom=first(geom_types), # how to set the name though?
spatialref=spatialref,
options=stringlist(options)

can_create_layer = AG.testcapability(ds, "CreateLayer")
can_use_transaction = AG.testcapability(ds, "Transactions")

AG.createlayer(;
name = layer_name,
dataset = can_create_layer ? ds : AG.create(AG.getdriver("Memory")),
geom = first(geom_types), # how to set the name though?
spatialref = spatialref,
options = stringlist(options),
) do layer
for (i, (geom_column, geom_type)) in enumerate(zip(geom_columns, geom_types))
for (i, (geom_column, geom_type)) in
enumerate(zip(geom_columns, geom_types))
if i > 1
AG.writegeomdefn!(layer, string(geom_column), geom_type)
end
end
fieldindices = Int[]
for (name, type) in fields
AG.createfielddefn(String(name), convert(AG.OGRFieldType, type)) do fd
AG.setsubtype!(fd, convert(AG.OGRFieldSubType, type))
AG.addfielddefn!(layer, fd)
end
push!(fieldindices, AG.findfieldindex(layer, name, false))
end
for row in rows
AG.createfeature(layer) do feature
for (i, (geom_column)) in enumerate(geom_columns)
AG.setgeom!(feature, i - 1, GeoInterface.convert(AG.IGeometry, getproperty(row, geom_column)))
end
for (name, _) in fields
field = getproperty(row, name)
if !ismissing(field)
AG.setfield!(feature, AG.findfieldindex(feature, name), getproperty(row, name))
else
AG.GDAL.ogr_f_setfieldnull(feature.ptr, AG.findfieldindex(feature, name))

for chunk in Iterators.partition(rows, chunksize)
can_use_transaction &&
AG.GDAL.gdaldatasetstarttransaction(ds.ptr, false)

for row in chunk
AG.addfeature(layer) do feature
for (i, (geom_column)) in enumerate(geom_columns)
AG.GDAL.ogr_f_setgeomfielddirectly(
feature.ptr,
i - 1,
_convert(
AG.Geometry,
Tables.getcolumn(row, geom_column),
),
)
end
for (i, (name, _)) in zip(fieldindices, fields)
field = Tables.getcolumn(row, name)
if !ismissing(field)
AG.setfield!(feature, i, field)
else
AG.GDAL.ogr_f_setfieldnull(feature.ptr, i)
end
end
end
end
if can_use_transaction
try
AG.GDAL.gdaldatasetcommittransaction(ds.ptr)
catch e
e isa AG.GDAL.GDALError &&
AG.GDAL.gdaldatasetrollbacktransaction(ds.ptr)
rethrow(e)
end
end
end
if !can_create_layer
@warn "Can't create layers in this format, copying from memory instead."
AG.copy(
layer;
dataset = ds,
name = layer_name,
options = stringlist(options),
)
end
AG.copy(layer, dataset=ds, name=layer_name, options=stringlist(options))
end
end
end
fn
end

# Maps a GeoInterface geometry trait type to the ArchGDAL `unsafe_create*`
# constructor that `_convert` calls with the geometry's coordinates.
# This should be upstreamed to ArchGDAL
const lookup_method = Dict{DataType, Function}((
    GeoInterface.PointTrait => AG.unsafe_createpoint,
    GeoInterface.MultiPointTrait => AG.unsafe_createmultipoint,
    GeoInterface.LineStringTrait => AG.unsafe_createlinestring,
    GeoInterface.LinearRingTrait => AG.unsafe_createlinearring,
    GeoInterface.MultiLineStringTrait => AG.unsafe_createmultilinestring,
    GeoInterface.PolygonTrait => AG.unsafe_createpolygon,
    GeoInterface.MultiPolygonTrait => AG.unsafe_createmultipolygon,
))

"""
    _convert(::Type{T}, geom) where {T <: AG.Geometry}

Build an ArchGDAL geometry from any object exposing a GeoInterface geometry trait.

The trait type of `geom` is looked up in `lookup_method`; the matching
`unsafe_create*` constructor is then called with `GeoInterface.coordinates(geom)`.
Throws an error (via `error`) when no constructor is registered for the trait.

NOTE(review): the constructors are the `unsafe_*` variants, so the result is
presumably not managed by Julia's GC and the caller must hand ownership to GDAL
(as `write` does via `ogr_f_setgeomfielddirectly`) — confirm against ArchGDAL.
"""
function _convert(::Type{T}, geom) where {T <: AG.Geometry}
    # Bind the trait once so it can be used for both the lookup and the message.
    trait = GeoInterface.geomtrait(geom)
    constructor = get(lookup_method, typeof(trait), nothing)
    # The message previously interpolated an undefined `type` variable, which
    # would have raised `UndefVarError` instead of this explanation.
    isnothing(constructor) && error(
        "Cannot convert an object of $(typeof(geom)) with the $(typeof(trait)) trait (yet). Please report an issue.",
    )
    return constructor(GeoInterface.coordinates(geom))
end

# Geometries that are already ArchGDAL `IGeometry` objects skip the coordinate
# round-trip and are duplicated with `AG.unsafe_clone` instead.
_convert(::Type{T}, geom::AG.IGeometry) where {T <: AG.Geometry} =
    AG.unsafe_clone(geom)
5 changes: 3 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function stringlist(dict::Dict{String,String})
function stringlist(dict::Dict{String, String})
sv = Vector{String}()
for (k, v) in pairs(dict)
push!(sv, uppercase(string(k)) * "=" * string(v))
Expand Down Expand Up @@ -81,5 +81,6 @@ end

# Since `DataFrameRow` is simply a view of a DataFrame, we can reach back
# to the original DataFrame to get the metadata.
GeoInterface.geometrycolumns(row::DataFrameRow) = GeoInterface.geometrycolumns(getfield(row, :df)) # get the parent of the row view
# Look up the geometry columns on the parent DataFrame held in the row view's
# `df` field.
function GeoInterface.geometrycolumns(row::DataFrameRow)
    parent_df = getfield(row, :df)
    return GeoInterface.geometrycolumns(parent_df)
end
# Look up the CRS on the parent DataFrame held in the row view's `df` field.
function GeoInterface.crs(row::DataFrameRow)
    parent_df = getfield(row, :df)
    return GeoInterface.crs(parent_df)
end
Loading
Loading