Skip to content

Commit

Permalink
fix for buffer DataLoader
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jan 25, 2025
1 parent 76b29d4 commit 32e2063
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/batchview.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ struct BatchView{TElem,TData,TCollate} <: AbstractDataContainer
batchsize::Int
count::Int
partial::Bool
collate::TCollate
collate::TCollate # either Val(nothing), Val(false), or a function
end

function BatchView(data::T; batchsize::Int=1, partial::Bool=true, collate=Val(nothing)) where {T}
Expand Down
9 changes: 4 additions & 5 deletions src/eachobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ The original data is preserved in the `data` field of the DataLoader.
containing `batchsize` observations. Default `1`.
- **`buffer`**: If `buffer=true` and supported by the type of `data`,
a buffer will be allocated and reused for memory efficiency.
You can also pass a preallocated object to `buffer`. Default `false`.
May want to set `partial=false` to avoid size mismatch. Default `false`.
- **`collate`**: Defines the batching behavior. Default `nothing`.
- If `nothing` , a batch is `getobs(data, indices)`.
- If `false`, each batch is `[getobs(data, i) for i in indices]`.
Expand Down Expand Up @@ -147,15 +147,14 @@ end

function DataLoader(
data;
buffer = false,
parallel = false,
shuffle = false,
buffer::Bool = false,
parallel::Bool = false,
shuffle::Bool = false,
batchsize::Int = 1,
partial::Bool = true,
collate = Val(nothing),
rng::AbstractRNG = Random.default_rng())

buffer = buffer isa Bool ? buffer : true
if collate isa Bool || collate === nothing
collate = Val(collate)
end
Expand Down

0 comments on commit 32e2063

Please sign in to comment.