standardise chunking with chunks keyword
rafaqz committed Apr 13, 2024
1 parent e2fa25b commit 4d7363f
Showing 10 changed files with 160 additions and 68 deletions.
48 changes: 28 additions & 20 deletions ext/RastersArchGDALExt/gdal_source.jl
@@ -2,6 +2,8 @@ const AG = ArchGDAL

const GDAL_LOCUS = Start()

+ const GDAL_DIM_ORDER = (X(), Y(), Band())

# drivers supporting the gdal Create() method to directly write to disk
const GDAL_DRIVERS_SUPPORTING_CREATE = ("GTiff", "HDF4", "KEA", "netCDF", "PCIDSK", "Zarr", "MEM"#=...=#)

@@ -350,7 +352,7 @@ function _create_with_driver(f, filename, dims::Tuple, T, missingval;
options=Dict{String,String}(),
driver="",
_block_template=nothing,
- chunks=true,
+ chunks=nokw,
kw...
)
_gdal_validate(dims)
@@ -363,7 +365,7 @@ function _create_with_driver(f, filename, dims::Tuple, T, missingval;
nbands = hasdim(dims, Band) ? length(DD.dims(dims, Band())) : 1

driver = _check_driver(filename, driver)
- options_vec = _process_options(driver, options; _block_template)
+ options_vec = _process_options(driver, options; _block_template, chunks)
gdaldriver = driver isa String ? AG.getdriver(driver) : driver

create_kw = (; width=length(x), height=length(y), nbands, dtype=T,)
@@ -400,7 +402,10 @@ end
end

# Convert a Dict of options to a Vector{String} for GDAL
- function _process_options(driver::String, options::Dict; _block_template=nothing)
+ function _process_options(driver::String, options::Dict;
+ chunks=nokw,
+ _block_template=nothing
+ )
options_str = Dict(string(k)=>string(v) for (k,v) in options)
# Get the GDAL driver object
gdaldriver = AG.getdriver(driver)
@@ -413,19 +418,21 @@ function _process_options(driver::String, options::Dict; _block_template=nothing
# the goal is to set write block sizes that correspond to eventually blocked reads
# creation options are driver dependent

- if !isnothing(_block_template) && DA.haschunks(_block_template) == DA.Chunked()
- block_x, block_y = DA.max_chunksize(DA.eachchunk(_block_template))
-
- # GDAL default is line-by-line compression without tiling.
- # Here, tiling is enabled if the source chunk size is viable for GTiff,
- # i.e. when the chunk size is divisible by 16.
- if (block_x % 16 == 0) && (block_y % 16 == 0)
- options_str["TILED"] = "YES"
- end
+ chunk_pattern = RA._chunks_to_tuple(_block_template, (X(), Y(), Band()), chunks)
+ if !isnothing(chunk_pattern)
+ xchunksize, ychunksize = chunk_pattern

- block_x, block_y = string.((block_x, block_y))
+ block_x, block_y = string.((xchunksize, ychunksize))

if driver == "GTiff"
+ # GDAL default is line-by-line compression without tiling.
+ # Here, tiling is enabled if the source chunk size is viable for GTiff,
+ # i.e. when the chunk size is divisible by 16.
+ if (xchunksize % 16 == 0) && (ychunksize % 16 == 0)
+ options_str["TILED"] = "YES"
+ else
+ xchunksize == 1 || @warn "X and Y chunk size do not match. Columns are used and X size $xchunksize is ignored"
+ end
# dont overwrite user specified values
if !("BLOCKXSIZE" in keys(options_str))
options_str["BLOCKXSIZE"] = block_x
Expand All @@ -435,10 +442,11 @@ function _process_options(driver::String, options::Dict; _block_template=nothing
end
elseif driver == "COG"
if !("BLOCKSIZE" in keys(options_str))
- # cog only supports square blocks
- # if the source already has square blocks, use them
- # otherwise use the driver default
- options_str["BLOCKSIZE"] = block_x == block_y ? block_x : "512"
+ if xchunksize == ychunksize
+ options_str["BLOCKSIZE"] = block_x
+ else
+ @warn "Writing COG X and Y chunks do not match: $block_x, $block_y. Default of 512, 512 used."
+ end
end
end
end
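Taken together, the two hunks above first normalise whatever `chunks` value was passed into a per-dimension size tuple, and only then translate it into driver-specific creation options. A minimal standalone sketch of that second step, using a hypothetical `sketch_gdal_options` helper that mirrors, but does not call, the package internals:

function sketch_gdal_options(driver::String, chunk_pattern)
    options_str = Dict{String,String}()
    isnothing(chunk_pattern) && return options_str  # no chunking requested
    xchunksize, ychunksize = chunk_pattern          # trailing dims (e.g. Band) are ignored here
    if driver == "GTiff"
        # GTiff tile dimensions must be multiples of 16
        if xchunksize % 16 == 0 && ychunksize % 16 == 0
            options_str["TILED"] = "YES"
            options_str["BLOCKXSIZE"] = string(xchunksize)
            options_str["BLOCKYSIZE"] = string(ychunksize)
        end
    elseif driver == "COG"
        # COG blocks are square, so mismatched X/Y sizes fall back to the driver default
        xchunksize == ychunksize && (options_str["BLOCKSIZE"] = string(xchunksize))
    end
    return options_str
end

sketch_gdal_options("GTiff", (256, 256, 1))
# Dict("TILED" => "YES", "BLOCKXSIZE" => "256", "BLOCKYSIZE" => "256")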
@@ -470,9 +478,9 @@ function _bandnames(rds::AG.RasterDataset, nbands=AG.nraster(rds))
end
end

- function _gdalmetadata(dataset::AG.Dataset, key)
+ function _gdalmetadata(dataset::AG.Dataset, name)
meta = AG.metadata(dataset)
regex = Regex("$key=(.*)")
regex = Regex("$name=(.*)")
i = findfirst(f -> occursin(regex, f), meta)
if i isa Nothing
return ""
@@ -560,7 +568,7 @@ function _extensiondriver(filename::AbstractString)
end

# Permute dims unless they match the normal GDAL dimension order
- _maybe_permute_to_gdal(A) = _maybe_permute_to_gdal(A, DD.dims(A, (X, Y, Band)))
+ _maybe_permute_to_gdal(A) = _maybe_permute_to_gdal(A, DD.dims(A, GDAL_DIM_ORDER))
_maybe_permute_to_gdal(A, dims::Tuple) = A
_maybe_permute_to_gdal(A, dims::Tuple{<:XDim,<:YDim,<:Band}) = permutedims(A, dims)
_maybe_permute_to_gdal(A, dims::Tuple{<:XDim,<:YDim}) = permutedims(A, dims)
8 changes: 2 additions & 6 deletions ext/RastersNCDatasetsExt/ncdatasets_source.jl
@@ -4,10 +4,6 @@ const UNNAMED_NCD_FILE_KEY = "unnamed"

const NCDAllowedType = Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,Float32,Float64,Char,String}

- ## Keywords
-
- $NCD_WRITE_KEYWORDS
-
function Base.write(filename::AbstractString, ::NCDsource, A::AbstractRaster;
append=false,
force=false,
@@ -71,7 +67,7 @@ function _writevar!(ds::AbstractDataset, A::AbstractRaster{T,N};
verbose=true,
missingval=nokw,
chunks=nokw,
- chunksizes=chunks,
+ chunksizes=RA._chunks_to_tuple(A, dims(A), chunks),
kw...
) where {T,N}
missingval = missingval isa NoKW ? Rasters.missingval(A) : missingval
Expand Down Expand Up @@ -105,7 +101,7 @@ function _writevar!(ds::AbstractDataset, A::AbstractRaster{T,N};
end

dimnames = lowercase.(string.(map(RA.name, dims(A))))
- var = NCD.defVar(ds, key, eltyp, dimnames; attrib=attrib, kw...) |> RA.CFDiskArray
+ var = NCD.defVar(ds, key, eltyp, dimnames; attrib=attrib, chunksizes, kw...) |> RA.CFDiskArray

# Write with a DiskArays.jl broadcast
var .= A
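Because `chunksizes` is now resolved through `RA._chunks_to_tuple` before reaching `NCD.defVar`, `write` should accept the same `chunks` forms for NetCDF as for the other backends. A sketch of the intended call patterns per this commit (file names are illustrative):

using Rasters, NCDatasets

r = Raster(rand(Float32, 256, 128), (X, Y))
write("chunked_a.nc", r; chunks=(64, 64))        # plain tuple, in dimension order
write("chunked_b.nc", r; chunks=(X(64), Y(64)))  # Dimension-wrapped sizes
write("chunked_c.nc", r; chunks=(X=64, Y=64))    # NamedTuple form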
6 changes: 3 additions & 3 deletions src/array.jl
@@ -319,13 +319,13 @@ function Raster(ds, filename::AbstractString;
)::Raster
name1 = filekey(ds, name)
source = _sourcetrait(filename, source)
- data1, dims1, metadata1, missingval1 = _open(source, ds; key=name1) do var
+ data1, dims1, metadata1, missingval1 = _open(source, ds; name=name1) do var
metadata1 = metadata isa NoKW ? _metadata(var) : metadata
missingval1 = _check_missingval(var, missingval)
replace_missing1 = replace_missing && !isnothing(missingval1)
missingval2 = replace_missing1 ? missing : missingval1
data = if lazy
- A = FileArray{typeof(source)}(var, filename; key=name1, write)
+ A = FileArray{typeof(source)}(var, filename; name=name1, write)
replace_missing1 ? _replace_missing(A, missingval1) : A
else
_checkmem(var)
@@ -354,7 +354,7 @@ function _replace_missing(A::AbstractArray{T}, missingval) where T
return repmissing.(A)
end

- filekey(ds, key) = key
+ filekey(ds, name) = name
filekey(filename::String) = Symbol(splitext(basename(filename))[1])

DD.dimconstructor(::Tuple{<:Dimension{<:AbstractProjected},Vararg{<:Dimension}}) = Raster
10 changes: 5 additions & 5 deletions src/sources/commondatamodel.jl
@@ -101,23 +101,23 @@ end

function FileStack{source}(
ds::AbstractDataset, filename::AbstractString;
- write::Bool=false, keys::NTuple{N,Symbol}, vars
+ write::Bool=false, name::NTuple{N,Symbol}, vars
) where {source<:CDMsource,N}
layertypes = map(var -> Union{Missing,eltype(var)}, vars)
layersizes = map(size, vars)
eachchunk = map(_get_eachchunk, vars)
haschunks = map(_get_haschunks, vars)
- return FileStack{source,keys}(filename, layertypes, layersizes, eachchunk, haschunks, write)
+ return FileStack{source,name}(filename, layertypes, layersizes, eachchunk, haschunks, write)
end

function Base.open(f::Function, A::FileArray{source}; write=A.write, kw...) where source<:CDMsource
- _open(source(), filename(A); key=key(A), write, kw...) do var
+ _open(source(), filename(A); name=name(A), write, kw...) do var
f(var)
end
end

- function _open(f, ::CDMsource, ds::AbstractDataset; key=nokw, kw...)
- x = key isa NoKW ? ds : CFDiskArray(ds[_firstname(ds, key)])
+ function _open(f, ::CDMsource, ds::AbstractDataset; name=nokw, kw...)
+ x = name isa NoKW ? ds : CFDiskArray(ds[_firstname(ds, name)])
cleanreturn(f(x))
end
_open(f, ::CDMsource, var::CFDiskArray; kw...) = cleanreturn(f(var))
8 changes: 4 additions & 4 deletions src/stack.jl
@@ -464,13 +464,13 @@ function _layer_stack(filename;
dims = _sort_by_layerdims(dims isa NoKW ? _dims(ds, dimdict) : dims, layerdims)
layermetadata = layermetadata isa NoKW ? _layermetadata(ds; layers) : layermetadata
missingval = missingval isa NoKW ? Rasters.missingval(ds) : missingval
- tuplekeys = Tuple(map(Symbol, layers.keys))
+ names = Tuple(map(Symbol, layers.names))
data = if lazy
- FileStack{typeof(source)}(ds, filename; keys=tuplekeys, vars=Tuple(layers.vars))
+ FileStack{typeof(source)}(ds, filename; name=names, vars=Tuple(layers.vars))
else
- NamedTuple{tuplekeys}(map(Array, layers.vars))
+ NamedTuple{names}(map(Array, layers.vars))
end
- data, (; dims, refdims, layerdims=NamedTuple{tuplekeys}(layerdims), metadata, layermetadata=NamedTuple{tuplekeys}(layermetadata), missingval)
+ data, (; dims, refdims, layerdims=NamedTuple{names}(layerdims), metadata, layermetadata=NamedTuple{names}(layermetadata), missingval)
end
return RasterStack(data; field_kw..., kw...)
end
41 changes: 41 additions & 0 deletions src/utils.jl
@@ -221,3 +221,44 @@ function _run(f, range::OrdinalRange, threaded::Bool, progress::Bool, desc::Stri
end
end
end

+ # NoKW means true
+ @inline function _chunks_to_tuple(template, dims, chunks::Bool)
+ if chunks == true
+ if template isa AbstractArray && DA.haschunks(template) == DA.Chunked()
+ # Get chunks from the template
+ DA.max_chunksize(DA.eachchunk(template))
+ else
+ # Use defaults
+ _chunks_to_tuple(template, dims, (X(512), Y(512)))
+ end
+ else
+ nothing
+ end
+ end
+ @inline function _chunks_to_tuple(template, dimorder, chunks::NTuple{N,Integer}) where N
+ n = length(dimorder)
+ if n < N
+ throw(ArgumentError("Length $n tuple needed for `chunks`, got $N"))
+ elseif n > N
+ (chunks..., ntuple(_ -> 1, Val{n-N}())...)
+ else
+ chunks
+ end
+ end
+ @inline function _chunks_to_tuple(template, dimorder, chunks::DimTuple)
+ size_one_chunk_axes = map(d -> rebuild(d, 1), otherdims(dimorder, chunks))
+ alldims = (chunks..., size_one_chunk_axes...)
+ int_chunks = map(val, dims(alldims, dimorder))
+ if !isnothing(template)
+ if !all(map(>=, size(template), int_chunks))
+ @warn "Chunks $int_chunks larger than array size $(size(template)). Using defaults."
+ return nothing
+ end
+ end
+ return int_chunks
+ end
+ @inline _chunks_to_tuple(template, dimorder, chunks::NamedTuple) =
+ _chunks_to_tuple(template, dimorder, DD.kw2dims(chunks))
+ @inline _chunks_to_tuple(template, dimorder, chunks::Nothing) = nothing
+ @inline _chunks_to_tuple(template, dims, chunks::NoKW) = nothing
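Together these dispatches funnel every accepted `chunks` form into either `nothing` (meaning no chunking) or an integer tuple in the template's dimension order. Illustrative calls with results inferred from the code above, not from test output; `_chunks_to_tuple` is internal, and `X`, `Y`, `Band` come from DimensionalData:

template = rand(512, 512, 3)   # any AbstractArray can serve as the template
dimorder = (X(), Y(), Band())

_chunks_to_tuple(template, dimorder, false)         # -> nothing: chunking disabled
_chunks_to_tuple(template, dimorder, (128, 128))    # -> (128, 128, 1): short tuples pad with 1s
_chunks_to_tuple(template, dimorder, (X=16, Y=10))  # -> (16, 10, 1): unnamed dims get size 1
_chunks_to_tuple(template, dimorder, nothing)       # -> nothing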
5 changes: 3 additions & 2 deletions src/write.jl
@@ -1,4 +1,5 @@
const CHUNS_KEYWORD = """

const CHUNKS_KEYWORD = """
- `chunks`: an `NTuple{N,Int}` specifying the chunk size for each dimension.
To specify only specific dimensions, a Tuple of `Dimension` wrapping `Int`
or a `NamedTuple` of `Int` can be used. Other dimensions will have a chunk
Expand All @@ -25,7 +26,7 @@ Other keyword arguments are passed to the `write` method for the backend.
## GDAL Keywords
- $(RA.FORCE_KEYWORD)
+ $FORCE_KEYWORD
- `driver`: A GDAL driver name `String` or a GDAL driver retrieved via `ArchGDAL.getdriver(drivername)`.
By default `driver` is guessed from the filename extension.
- `options::Dict{String,String}`: A dictionary containing the dataset creation options passed to the driver.
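Combined with the `chunks` entry documented above, a GDAL write could then look like this sketch (driver, chunk sizes, and the compression option are illustrative choices, not defaults):

using Rasters, ArchGDAL

r = Raster(rand(UInt8, 512, 512), (X, Y))
# `chunks` fills in BLOCKXSIZE/BLOCKYSIZE unless `options` already sets them
write("out.tif", r; driver="GTiff", chunks=(256, 256),
    options=Dict("COMPRESS" => "DEFLATE"), force=true)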
11 changes: 11 additions & 0 deletions test/sources/gdal.jl
@@ -373,6 +373,17 @@ gdalpath = maybedownload(url)
@test_throws ArgumentError write(filename_gtiff2, gdalarray; driver="GTiff", options=Dict("COMPRESS"=>"FOOBAR"))
end

@testset "chunks" begin
filename = tempname() * ".tiff"
write(filename, gdalarray; chunks=(128, 128, 1))
gdalarray2 = Raster(filename; lazy=true)
@test DiskArrays.eachchunk(gdalarray2)[1] == (1:128, 1:128)
filename = tempname() * ".tiff"
@test_warn "X and Y chunks do not match" write(filename, gdalarray; chunks=(128, 256, 1), driver="COG")
gdalarray2 = Raster(filename; lazy=true)
@test DiskArrays.eachchunk(gdalarray2)[1] == (1:512, 1:512)
end

@testset "resave current" begin
filename = tempname() * ".rst"
write(filename, gdalarray)
67 changes: 40 additions & 27 deletions test/sources/ncdatasets.jl
@@ -1,6 +1,7 @@
using Rasters, DimensionalData, Test, Statistics, Dates, CFTime, Plots

using Rasters.Lookups, Rasters.Dimensions
+ using Rasters.DiskArrays
import ArchGDAL, NCDatasets
using Rasters: FileArray, FileStack, NCDsource, crs, bounds, name, trim
testdir = realpath(joinpath(dirname(pathof(Rasters)), "../test"))
@@ -51,11 +52,11 @@ end
@test_throws ArgumentError Raster("notafile.nc")

@testset "lazyness" begin
- @time read(Raster(ncsingle));
# Eager is the default
@test parent(ncarray) isa Array
@test parent(lazyarray) isa FileArray
@test parent(eagerarray) isa Array
+ @time read(lazyarray);
end

@testset "from url" begin
@@ -237,17 +238,14 @@

@testset "write" begin
@testset "to netcdf" begin
- # TODO save and load subset
geoA = read(ncarray)
@test size(geoA) == size(ncarray)
filename = tempname() * ".nc"
- write(filename, geoA)
+ write(filename, ncarray)
@testset "CF attributes" begin
@test NCDatasets.Dataset(filename)[:x].attrib["axis"] == "X"
@test NCDatasets.Dataset(filename)[:x].attrib["bounds"] == "x_bnds"
# TODO better units and standard name handling
end
- saved = read(Raster(filename))
+ saved = Raster(filename)
@test size(saved) == size(geoA)
@test refdims(saved) == refdims(geoA)
@test missingval(saved) === missingval(geoA)
@@ -269,27 +267,42 @@
@test saved isa typeof(geoA)
# TODO test crs

- # test for nc `kw...`
- geoA = read(ncarray)
- write("tos.nc", geoA; force=true) # default `deflatelevel = 0`
- write("tos_small.nc", geoA; deflatelevel=2)
- @test filesize("tos_small.nc") * 1.5 < filesize("tos.nc") # compress ratio >= 1.5
- isfile("tos.nc") && rm("tos.nc")
- isfile("tos_small.nc") && rm("tos_small.nc")
-
- # test for nc `append`
- n = 100
- x = rand(n, n)
- r1 = Raster(x, (X, Y); name = "v1")
- r2 = Raster(x, (X, Y); name = "v2")
- fn = "test.nc"
- isfile(fn) && rm(fn)
- write(fn, r1, append=false)
- size1 = filesize(fn)
- write(fn, r2; append=true)
- size2 = filesize(fn)
- @test size2 > size1*1.8 # two variable
- isfile(fn) && rm(fn)
@testset "chunks" begin
filename = tempname() * ".nc"
write(filename, ncarray; chunks=(64, 64))
@test DiskArrays.eachchunk(Raster(filename; lazy=true))[1] == (1:64, 1:64, 1:1)
filename = tempname() * ".nc"
write(filename, ncarray; chunks=(X=16, Y=10, Ti=8))
@test DiskArrays.eachchunk(Raster(filename; lazy=true))[1] == (1:16, 1:10, 1:8)
filename = tempname() * ".nc"
@test_warn "larger than array size" write(filename, ncarray; chunks=(X=1000, Y=10, Ti=8))
# No chunks
@test DiskArrays.haschunks(Raster(filename; lazy=true)) isa DiskArrays.Unchunked
@test DiskArrays.eachchunk(Raster(filename; lazy=true))[1] == map(Base.OneTo, size(ncarray))
end

@testset "deflatelevel" begin
write("tos.nc", ncarray; force=true) # default `deflatelevel = 0`
write("tos_small.nc", ncarray; deflatelevel=2)
@test filesize("tos_small.nc") * 1.5 < filesize("tos.nc") # compress ratio >= 1.5
isfile("tos.nc") && rm("tos.nc")
isfile("tos_small.nc") && rm("tos_small.nc")
end

@testset "append" begin
n = 100
x = rand(n, n)
r1 = Raster(x, (X, Y); name = "v1")
r2 = Raster(x, (X, Y); name = "v2")
fn = "test.nc"
isfile(fn) && rm(fn)
write(fn, r1, append=false)
size1 = filesize(fn)
write(fn, r2; append=true)
size2 = filesize(fn)
@test size2 > size1*1.8 # two variable
isfile(fn) && rm(fn)
end

@testset "non allowed values" begin
# TODO return this test when the changes in NCDatasets.jl settle