diff --git a/src/batchview.jl b/src/batchview.jl index a26bd5b..2529206 100644 --- a/src/batchview.jl +++ b/src/batchview.jl @@ -95,7 +95,7 @@ struct BatchView{TElem,TData,TCollate} <: AbstractDataContainer batchsize::Int count::Int partial::Bool - collate::TCollate + collate::TCollate # either Val(nothing), Val(false), or a function end function BatchView(data::T; batchsize::Int=1, partial::Bool=true, collate=Val(nothing)) where {T} diff --git a/src/eachobs.jl b/src/eachobs.jl index b8e908d..3c582e2 100644 --- a/src/eachobs.jl +++ b/src/eachobs.jl @@ -60,7 +60,7 @@ The original data is preserved in the `data` field of the DataLoader. containing `batchsize` observations. Default `1`. - **`buffer`**: If `buffer=true` and supported by the type of `data`, a buffer will be allocated and reused for memory efficiency. - You can also pass a preallocated object to `buffer`. Default `false`. + May want to set `partial=false` to avoid size mismatch. Default `false`. - **`collate`**: Defines the batching behavior. Default `nothing`. - If `nothing` , a batch is `getobs(data, indices)`. - If `false`, each batch is `[getobs(data, i) for i in indices]`. @@ -147,15 +147,14 @@ end function DataLoader( data; - buffer = false, - parallel = false, - shuffle = false, + buffer::Bool = false, + parallel::Bool = false, + shuffle::Bool = false, batchsize::Int = 1, partial::Bool = true, collate = Val(nothing), rng::AbstractRNG = Random.default_rng()) - buffer = buffer isa Bool ? buffer : true if collate isa Bool || collate === nothing collate = Val(collate) end