From 56f29dc0da829bd7d0eb5cb1f81da7a6eba82fff Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 28 Jan 2025 15:35:38 +0100 Subject: [PATCH 1/4] rework dataloader --- src/MLUtils.jl | 10 +-- src/dataloader.jl | 204 +++++++++++++++++++++----------------------- src/obstransform.jl | 29 ++++--- src/parallel.jl | 10 +-- 4 files changed, 122 insertions(+), 131 deletions(-) diff --git a/src/MLUtils.jl b/src/MLUtils.jl index 4c344b5..ec7bba4 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -36,8 +36,10 @@ export mapobs, shuffleobs include("batchview.jl") -export batchsize, - BatchView +export batchsize, BatchView + +include("obsview.jl") +export obsview, ObsView include("dataloader.jl") export eachobs, DataLoader @@ -48,10 +50,6 @@ include("folds.jl") export kfolds, leavepout -include("obsview.jl") -export obsview, - ObsView - include("randobs.jl") export randobs diff --git a/src/dataloader.jl b/src/dataloader.jl index 0107835..b4a6a60 100644 --- a/src/dataloader.jl +++ b/src/dataloader.jl @@ -136,7 +136,7 @@ julia> first(DataLoader(["a", "b", "c", "d"], batchsize=2, collate=collate_fn)) "ab" ``` """ -struct DataLoader{T,B,C,R<:AbstractRNG} +struct DataLoader{T<:Union{ObsView,BatchView},B,C,R<:AbstractRNG} data::T batchsize::Int buffer::B # boolean, or external buffer @@ -157,54 +157,40 @@ function DataLoader( collate = Val(nothing), rng::AbstractRNG = Random.default_rng()) - if !(buffer isa Bool) && parallel - throw(ArgumentError("If `parallel=true`, `buffer` must be a boolean.")) - end - if collate isa Bool || collate === nothing collate = Val(collate) end - return DataLoader(data, batchsize, buffer, partial, shuffle, parallel, collate, rng) -end - -function Base.iterate(d::DataLoader) - # TODO move ObsView and BatchWView wrapping to the constructor, so that - # we can parametrize the DataLoader with ObsView and BatchView and define specialized methods. 
# Wrapping with ObsView in order to work around - # issue https://github.com/FluxML/Flux.jl/issues/1935 - data = ObsView(d.data) + # issue https://github.com/FluxML/Flux.jl/issues/1935 + data = ObsView(data) + if batchsize > 0 + data = BatchView(data; batchsize, partial, collate) + end + + if buffer == true + buffer = _create_buffer(data) + end + # for buffer == false and external buffer, we keep as is - data = d.shuffle ? shuffleobs(d.rng, data) : data - data = d.batchsize > 0 ? BatchView(data; d.batchsize, d.partial, d.collate) : data + return DataLoader(data, batchsize, buffer, partial, shuffle, parallel, collate, rng) +end +function Base.iterate(d::DataLoader) + data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data if d.parallel iter = eachobsparallel(data; d.buffer) else if d.buffer == false iter = (getobs(data, i) for i in 1:numobs(data)) - elseif d.buffer == true - buf = create_buffer(data) - iter = (getobs!(buf, data, i) for i in 1:numobs(data)) - else # external buffer - buf = d.buffer - iter = (getobs!(buf, data, i) for i in 1:numobs(data)) + else + iter = (getobs!(d.buffer, data, i) for i in 1:numobs(data)) end end obs, state = iterate(iter) return obs, (iter, state) end -create_buffer(x) = getobs(x, 1) -function create_buffer(x::BatchView) - obsindices = _batchrange(x, 1) - return [getobs(A.data, idx) for idx in enumerate(obsindices)] -end -function create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TData} - obsindices = _batchrange(x, 1) - return getobs(x.data, obsindices) -end - function Base.iterate(::DataLoader, (iter, state)) ret = iterate(iter, state) isnothing(ret) && return @@ -212,19 +198,31 @@ function Base.iterate(::DataLoader, (iter, state)) return obs, (iter, state) end +# recursively unwraps ObsView and BatchView +_unwrapdata(data::BatchView) = _unwrapdata(data.data) +_unwrapdata(data::ObsView) = _unwrapdata(data.data) +_unwrapdata(data) = data -function Base.length(d::DataLoader) - if d.batchsize > 0 - return 
numobs(BatchView(d.data; d.batchsize, d.partial)) - else - return numobs(d.data) - end -end +_shuffledata(rng, data::ObsView) = shuffleobs(rng, data) + +_shuffledata(rng, data::BatchView) = + BatchView(shuffleobs(rng, data.data); data.batchsize, data.partial, data.collate) -Base.size(e::DataLoader) = (length(e),) +_create_buffer(x) = getobs(x, 1) + +function _create_buffer(x::BatchView) + obsindices = _batchrange(x, 1) + return [getobs(x.data, idx) for idx in obsindices] # one entry per index from `_batchrange(x, 1)` +end +function _create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TData} + obsindices = _batchrange(x, 1) + return getobs(x.data, obsindices) +end -Base.IteratorEltype(::DataLoader) = Base.EltypeUnknown() +Base.length(d::DataLoader) = numobs(d.data) +Base.size(d::DataLoader) = (length(d),) +Base.IteratorEltype(d::DataLoader) = Base.EltypeUnknown() ## This causes error in some cases of `collect(loader)` # function Base.eltype(e::DataLoader) @@ -288,42 +286,71 @@ function mapobs(f, d::DataLoader) collate = f ∘ d.collate end - DataLoader(d.data, - batchsize=d.batchsize, - buffer=d.buffer, - partial=d.partial, - shuffle=d.shuffle, - parallel=d.parallel, - collate=collate, - rng=d.rng) + return DataLoader(_unwrapdata(d.data); + batchsize=d.batchsize, + buffer=d.buffer, + partial=d.partial, + shuffle=d.shuffle, + parallel=d.parallel, + collate=collate, + rng=d.rng) end -@inline function _dataloader_foldl1(rf, val, e::DataLoader, data) - if e.shuffle - _dataloader_foldl2(rf, val, e, shuffleobs(e.rng, data)) - else - _dataloader_foldl2(rf, val, e, data) - end +# Base uses this function for composable array printing, e.g. 
adjoint(view(::Matrix))) +function Base.showarg(io::IO, d::DataLoader, toplevel) + print(io, "DataLoader(") + Base.showarg(io, _unwrapdata(d.data), false) + d.buffer == false || print(io, ", buffer=", d.buffer) + d.parallel == false || print(io, ", parallel=", d.parallel) + d.shuffle == false || print(io, ", shuffle=", d.shuffle) + d.batchsize == 1 || print(io, ", batchsize=", d.batchsize) + d.partial == true || print(io, ", partial=", d.partial) + d.collate === Val(nothing) || print(io, ", collate=", d.collate) + d.rng == Random.default_rng() || print(io, ", rng=", d.rng) + print(io, ")") end -@inline function _dataloader_foldl2(rf, val, e::DataLoader, data) - if e.batchsize > 0 - _dataloader_foldl3(rf, val, e, BatchView(data; e.batchsize, e.partial)) +Base.show(io::IO, e::DataLoader) = Base.showarg(io, e, false) + +function Base.show(io::IO, m::MIME"text/plain", d::DataLoader) + print(io, length(d), "-element ") + Base.showarg(io, d, false) + print(io, "\n with first element:") + print(io, "\n ", _expanded_summary(first(d))) +end + +_expanded_summary(x) = summary(x) +function _expanded_summary(xs::Tuple) + parts = [_expanded_summary(x) for x in xs] + "(" * join(parts, ", ") * ",)" +end +function _expanded_summary(xs::NamedTuple) + parts = ["$k = "*_expanded_summary(x) for (k,x) in zip(keys(xs), xs)] + "(; " * join(parts, ", ") * ")" +end + + +### TRANSDUCERS IMPLEMENTATION ############################# + + +@inline function _dataloader_foldl1(rf, val, d::DataLoader, data) + if d.shuffle + return _dataloader_foldl2(rf, val, d, _shuffledata(d.rng, data)) else - _dataloader_foldl3(rf, val, e, data) + return _dataloader_foldl2(rf, val, d, data) end end -@inline function _dataloader_foldl3(rf, val, e::DataLoader, data) - if e.buffer > 0 - _dataloader_foldl4_buffered(rf, val, data) +@inline function _dataloader_foldl2(rf, val, d::DataLoader, data) + if d.buffer == false + return _dataloader_foldl3(rf, val, data) else - _dataloader_foldl4(rf, val, data) + return 
_dataloader_foldl3_buffered(rf, val, data, d.buffer) end end -@inline function _dataloader_foldl4(rf, val, data) +@inline function _dataloader_foldl3(rf, val, data) for i in 1:numobs(data) @inbounds x = getobs(data, i) # TODO: in 1.8 we could @inline this at the callsite, @@ -331,57 +358,18 @@ end # quite brittle in its capacity to keep this type stable val = Transducers.@next(rf, val, x) end - Transducers.complete(rf, val) + return Transducers.complete(rf, val) end -@inline function _dataloader_foldl4_buffered(rf, val, data) - buf = getobs(data, 1) +@inline function _dataloader_foldl3_buffered(rf, val, data, buf) for i in 1:numobs(data) @inbounds x = getobs!(buf, data, i) val = Transducers.@next(rf, val, x) end - Transducers.complete(rf, val) + return Transducers.complete(rf, val) end -@inline function Transducers.__foldl__(rf, val, e::DataLoader) - e.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads")) - _dataloader_foldl1(rf, val, e, ObsView(e.data)) -end - -# Base uses this function for composable array printing, e.g. 
adjoint(view(::Matrix))) -function Base.showarg(io::IO, e::DataLoader, toplevel) - print(io, "DataLoader(") - Base.showarg(io, e.data, false) - e.buffer == false || print(io, ", buffer=", e.buffer) - e.parallel == false || print(io, ", parallel=", e.parallel) - e.shuffle == false || print(io, ", shuffle=", e.shuffle) - e.batchsize == 1 || print(io, ", batchsize=", e.batchsize) - e.partial == true || print(io, ", partial=", e.partial) - e.collate === Val(nothing) || print(io, ", collate=", e.collate) - e.rng == Random.default_rng() || print(io, ", rng=", e.rng) - print(io, ")") +@inline function Transducers.__foldl__(rf, val, d::DataLoader) + d.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads")) + return _dataloader_foldl1(rf, val, d, d.data) end - -Base.show(io::IO, e::DataLoader) = Base.showarg(io, e, false) - -function Base.show(io::IO, m::MIME"text/plain", e::DataLoader) - if Base.haslength(e) - print(io, length(e), "-element ") - else - print(io, "Unknown-length ") - end - Base.showarg(io, e, false) - print(io, "\n with first element:") - print(io, "\n ", _expanded_summary(first(e))) -end - -_expanded_summary(x) = summary(x) -function _expanded_summary(xs::Tuple) - parts = [_expanded_summary(x) for x in xs] - "(" * join(parts, ", ") * ",)" -end -function _expanded_summary(xs::NamedTuple) - parts = ["$k = "*_expanded_summary(x) for (k,x) in zip(keys(xs), xs)] - "(; " * join(parts, ", ") * ")" -end - diff --git a/src/obstransform.jl b/src/obstransform.jl index 8698240..e5c6168 100644 --- a/src/obstransform.jl +++ b/src/obstransform.jl @@ -199,14 +199,28 @@ joinobs(datas...) = JoinedData(datas) """ shuffleobs([rng], data) -Return a "subset" of `data` that spans all observations, but -has the order of the observations shuffled. +Return a version of the dataset `data` that contains all the +original observations in a random reordering. The values of `data` itself are not copied. Instead only the indices are shuffled. 
This function calls [`obsview`](@ref) to accomplish that, which means that the return value is likely of a different type than `data`. +Optionally, a random number generator `rng` can be passed as the +first argument. + +The optional parameter `rng` allows one to specify the +random number generator used for shuffling. This is useful when +reproducible results are desired. + +For this function to work, the type of `data` must implement +[`numobs`](@ref) and [`getobs`](@ref). + +See also [`obsview`](@ref). + +# Examples + ```julia # For Arrays the subset will be of type SubArray @assert typeof(shuffleobs(rand(4,10))) <: SubArray @@ -216,18 +230,9 @@ for x in eachobs(shuffleobs(X)) ... end ``` - -The optional parameter `rng` allows one to specify the -random number generator used for shuffling. This is useful when -reproducible results are desired. By default, uses the global RNG. -See `Random` in Julia's standard library for more info. - -For this function to work, the type of `data` must implement -[`numobs`](@ref) and [`getobs`](@ref). See [`ObsView`](@ref) -for more information. 
""" shuffleobs(data) = shuffleobs(Random.default_rng(), data) function shuffleobs(rng::AbstractRNG, data) - obsview(data, randperm(rng, numobs(data))) + return obsview(data, randperm(rng, numobs(data))) end diff --git a/src/parallel.jl b/src/parallel.jl index d68b834..f1a7549 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -29,21 +29,21 @@ function eachobsparallel( data; executor::Executor = _default_executor(), - buffer::Bool = false, + buffer = false, channelsize = Threads.nthreads()) - if buffer - return _eachobsparallel_buffered(data, executor; channelsize) - else + if buffer == false return _eachobsparallel_unbuffered(data, executor; channelsize) + else + return _eachobsparallel_buffered(buffer, data, executor; channelsize) end end function _eachobsparallel_buffered( + buffer, data, executor; channelsize=Threads.nthreads()) - buffer = getobs(data, 1) buffers = [buffer] foreach(_ -> push!(buffers, deepcopy(buffer)), 1:channelsize) From d3c9cb8e071d9f297ef9de543e49f1a5a30131b3 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 28 Jan 2025 16:49:02 +0100 Subject: [PATCH 2/4] parametrization --- src/dataloader.jl | 53 +++++++++++++++++++++++++++++++++++------------ src/parallel.jl | 11 ++++++---- test/parallel.jl | 4 ++-- 3 files changed, 49 insertions(+), 19 deletions(-) diff --git a/src/dataloader.jl b/src/dataloader.jl index b4a6a60..601d35d 100644 --- a/src/dataloader.jl +++ b/src/dataloader.jl @@ -136,7 +136,7 @@ julia> first(DataLoader(["a", "b", "c", "d"], batchsize=2, collate=collate_fn)) "ab" ``` """ -struct DataLoader{T<:Union{ObsView,BatchView},B,C,R<:AbstractRNG} +struct DataLoader{T<:Union{ObsView,BatchView},B,P,C,R<:AbstractRNG} data::T batchsize::Int buffer::B # boolean, or external buffer @@ -170,27 +170,54 @@ function DataLoader( if buffer == true buffer = _create_buffer(data) - end + end + P = parallel ? 
:parallel : :serial # for buffer == false and external buffer, we keep as is - return DataLoader(data, batchsize, buffer, partial, shuffle, parallel, collate, rng) + T, B, C, R = typeof(data), typeof(buffer), typeof(collate), typeof(rng) + return DataLoader{T,B,P,C,R}(data, batchsize, buffer, + partial, shuffle, parallel, collate, rng) end -function Base.iterate(d::DataLoader) + + +# buffered - serial case +function Base.iterate(d::DataLoader{T,B,:serial}) where {T,B} + @assert d.buffer != false data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data - if d.parallel - iter = eachobsparallel(data; d.buffer) - else - if d.buffer == false - iter = (getobs(data, i) for i in 1:numobs(data)) - else - iter = (getobs!(d.buffer, data, i) for i in 1:numobs(data)) - end - end + iter = (getobs!(d.buffer, data, i) for i in 1:numobs(data)) + obs, state = iterate(iter) + return obs, (iter, state) +end + +# buffered - parallel case +function Base.iterate(d::DataLoader{T,B,:parallel}) where {T,B} + @assert d.buffer != false + data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data + iter = _eachobsparallel_buffered(d.buffer, data) + obs, state = iterate(iter) + return obs, (iter, state) +end + +# unbuffered - serial case +function Base.iterate(d::DataLoader{T,Bool,:serial}) where {T} + @assert d.buffer == false + data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data + iter = (getobs(data, i) for i in 1:numobs(data)) + obs, state = iterate(iter) + return obs, (iter, state) +end + +# unbuffered - parallel case +function Base.iterate(d::DataLoader{T,Bool,:parallel}) where {T} + @assert d.buffer == false + data = d.shuffle ? 
_shuffledata(d.rng, d.data) : d.data + iter = _eachobsparallel_unbuffered(data) obs, state = iterate(iter) return obs, (iter, state) end +## next iterations function Base.iterate(::DataLoader, (iter, state)) ret = iterate(iter, state) isnothing(ret) && return diff --git a/src/parallel.jl b/src/parallel.jl index f1a7549..90706f5 100644 --- a/src/parallel.jl +++ b/src/parallel.jl @@ -29,7 +29,7 @@ function eachobsparallel( data; executor::Executor = _default_executor(), - buffer = false, + buffer::Bool = false, channelsize = Threads.nthreads()) if buffer == false return _eachobsparallel_unbuffered(data, executor; channelsize) @@ -38,11 +38,10 @@ function eachobsparallel( end end - function _eachobsparallel_buffered( buffer, data, - executor; + executor = _default_executor(); channelsize=Threads.nthreads()) buffers = [buffer] foreach(_ -> push!(buffers, deepcopy(buffer)), 1:channelsize) @@ -61,7 +60,11 @@ function _eachobsparallel_buffered( end end -function _eachobsparallel_unbuffered(data, executor; channelsize=Threads.nthreads()) +function _eachobsparallel_unbuffered(data, + executor = _default_executor(); + channelsize=Threads.nthreads() + ) + return Loader(1:numobs(data); executor, channelsize) do ch, i obs = getobs(data, i) put!(ch, obs) diff --git a/test/parallel.jl b/test/parallel.jl index 928b45f..84a808f 100644 --- a/test/parallel.jl +++ b/test/parallel.jl @@ -47,8 +47,8 @@ end end end -@testset "`eachobsparallel(buffer = true)`" begin - iter = eachobsparallel(collect(1:10), buffer=true) +@testset "`DataLoader(buffer = true, parallel=true)`" begin + iter = DataLoader(collect(1:10), buffer=true, batchsize=-1, parallel=true) @test_nowarn for i in iter end X_ = collect(iter) @test all(x ∈ 1:10 for x in X_) From dd30ce8d1cc8b9a977d78e8282beda3028a08b0e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 28 Jan 2025 17:08:44 +0100 Subject: [PATCH 3/4] cleanup --- src/dataloader.jl | 26 +++++++++++--------------- test/Project.toml | 2 ++ 2 files changed, 13 
insertions(+), 15 deletions(-) diff --git a/src/dataloader.jl b/src/dataloader.jl index 601d35d..556d0a0 100644 --- a/src/dataloader.jl +++ b/src/dataloader.jl @@ -136,8 +136,9 @@ julia> first(DataLoader(["a", "b", "c", "d"], batchsize=2, collate=collate_fn)) "ab" ``` """ -struct DataLoader{T<:Union{ObsView,BatchView},B,P,C,R<:AbstractRNG} - data::T +struct DataLoader{T<:Union{ObsView,BatchView},B,P,C,O,R<:AbstractRNG} + data::O # original data + _data::T # data wrapped in ObsView / BatchView batchsize::Int buffer::B # boolean, or external buffer partial::Bool @@ -184,7 +185,7 @@ end # buffered - serial case function Base.iterate(d::DataLoader{T,B,:serial}) where {T,B} @assert d.buffer != false - data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data + data = d.shuffle ? _shuffledata(d.rng, d._data) : d._data iter = (getobs!(d.buffer, data, i) for i in 1:numobs(data)) obs, state = iterate(iter) return obs, (iter, state) @@ -193,7 +194,7 @@ end # buffered - parallel case function Base.iterate(d::DataLoader{T,B,:parallel}) where {T,B} @assert d.buffer != false - data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data + data = d.shuffle ? _shuffledata(d.rng, d._data) : d._data iter = _eachobsparallel_buffered(d.buffer, data) obs, state = iterate(iter) return obs, (iter, state) @@ -202,7 +203,7 @@ end # unbuffered - serial case function Base.iterate(d::DataLoader{T,Bool,:serial}) where {T} @assert d.buffer == false - data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data + data = d.shuffle ? _shuffledata(d.rng, d._data) : d._data iter = (getobs(data, i) for i in 1:numobs(data)) obs, state = iterate(iter) return obs, (iter, state) @@ -211,7 +212,7 @@ end # unbuffered - parallel case function Base.iterate(d::DataLoader{T,Bool,:parallel}) where {T} @assert d.buffer == false - data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data + data = d.shuffle ? 
_shuffledata(d.rng, d._data) : d._data iter = _eachobsparallel_unbuffered(data) obs, state = iterate(iter) return obs, (iter, state) @@ -225,11 +226,6 @@ function Base.iterate(::DataLoader, (iter, state)) return obs, (iter, state) end -# recursively unwraps ObsView and BatchView -_unwrapdata(data::BatchView) = _unwrapdata(data.data) -_unwrapdata(data::ObsView) = _unwrapdata(data.data) -_unwrapdata(data) = data - _shuffledata(rng, data::ObsView) = shuffleobs(rng, data) _shuffledata(rng, data::BatchView) = @@ -247,7 +243,7 @@ function _create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TDa return getobs(x.data, obsindices) end -Base.length(d::DataLoader) = numobs(d.data) +Base.length(d::DataLoader) = numobs(d._data) Base.size(d::DataLoader) = (length(d),) Base.IteratorEltype(d::DataLoader) = Base.EltypeUnknown() @@ -313,7 +309,7 @@ function mapobs(f, d::DataLoader) collate = f ∘ d.collate end - return DataLoader(_unwrapdata(d.data); + return DataLoader(d.data; batchsize=d.batchsize, buffer=d.buffer, partial=d.partial, @@ -327,7 +323,7 @@ end # Base uses this function for composable array printing, e.g. 
adjoint(view(::Matrix))) function Base.showarg(io::IO, d::DataLoader, toplevel) print(io, "DataLoader(") - Base.showarg(io, _unwrapdata(d.data), false) + Base.showarg(io, d.data, false) d.buffer == false || print(io, ", buffer=", d.buffer) d.parallel == false || print(io, ", parallel=", d.parallel) d.shuffle == false || print(io, ", shuffle=", d.shuffle) @@ -398,5 +394,5 @@ end @inline function Transducers.__foldl__(rf, val, d::DataLoader) d.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads")) - return _dataloader_foldl1(rf, val, d, d.data) + return _dataloader_foldl1(rf, val, d, d._data) end diff --git a/test/Project.toml b/test/Project.toml index b83c121..72b1e61 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,10 @@ [deps] +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" From 0dfddb83f908aaca61c5ba550d0bb94d40a9e29d Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 1 Feb 2025 09:50:21 +0100 Subject: [PATCH 4/4] fix --- src/dataloader.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/dataloader.jl b/src/dataloader.jl index 556d0a0..26f6f9a 100644 --- a/src/dataloader.jl +++ b/src/dataloader.jl @@ -164,19 +164,19 @@ function DataLoader( # Wrapping with ObsView in order to work around # issue https://github.com/FluxML/Flux.jl/issues/1935 - data = ObsView(data) + _data = ObsView(data) if batchsize > 0 - data = BatchView(data; batchsize, partial, collate) + _data = BatchView(_data; batchsize, partial, collate) end if buffer == true - buffer = 
_create_buffer(data) + buffer = _create_buffer(_data) end P = parallel ? :parallel : :serial # for buffer == false and external buffer, we keep as is - T, B, C, R = typeof(data), typeof(buffer), typeof(collate), typeof(rng) - return DataLoader{T,B,P,C,R}(data, batchsize, buffer, + T, O, B, C, R = typeof(_data), typeof(data), typeof(buffer), typeof(collate), typeof(rng) + return DataLoader{T,B,P,C,O,R}(data, _data, batchsize, buffer, partial, shuffle, parallel, collate, rng) end