Skip to content

Commit

Permalink
parametrization
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jan 28, 2025
1 parent 56f29dc commit d3c9cb8
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 19 deletions.
53 changes: 40 additions & 13 deletions src/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ julia> first(DataLoader(["a", "b", "c", "d"], batchsize=2, collate=collate_fn))
"ab"
```
"""
struct DataLoader{T<:Union{ObsView,BatchView},B,C,R<:AbstractRNG}
struct DataLoader{T<:Union{ObsView,BatchView},B,P,C,R<:AbstractRNG}
data::T
batchsize::Int
buffer::B # boolean, or external buffer
Expand Down Expand Up @@ -170,27 +170,54 @@ function DataLoader(

if buffer == true
buffer = _create_buffer(data)
end
end
P = parallel ? :parallel : :serial
# for buffer == false and external buffer, we keep as is

return DataLoader(data, batchsize, buffer, partial, shuffle, parallel, collate, rng)
T, B, C, R = typeof(data), typeof(buffer), typeof(collate), typeof(rng)
return DataLoader{T,B,P,C,R}(data, batchsize, buffer,
partial, shuffle, parallel, collate, rng)
end

function Base.iterate(d::DataLoader)


# buffered - serial case
function Base.iterate(d::DataLoader{T,B,:serial}) where {T,B}
@assert d.buffer != false
data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data
if d.parallel
iter = eachobsparallel(data; d.buffer)
else
if d.buffer == false
iter = (getobs(data, i) for i in 1:numobs(data))
else
iter = (getobs!(d.buffer, data, i) for i in 1:numobs(data))
end
end
iter = (getobs!(d.buffer, data, i) for i in 1:numobs(data))
obs, state = iterate(iter)
return obs, (iter, state)
end

# buffered - parallel case
function Base.iterate(d::DataLoader{T,B,:parallel}) where {T,B}
@assert d.buffer != false
data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data
iter = _eachobsparallel_buffered(d.buffer, data)
obs, state = iterate(iter)
return obs, (iter, state)
end

# unbuffered - serial case
function Base.iterate(d::DataLoader{T,Bool,:serial}) where {T}
@assert d.buffer == false
data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data
iter = (getobs(data, i) for i in 1:numobs(data))
obs, state = iterate(iter)
return obs, (iter, state)
end

# unbuffered - parallel case
function Base.iterate(d::DataLoader{T,Bool,:parallel}) where {T}
@assert d.buffer == false
data = d.shuffle ? _shuffledata(d.rng, d.data) : d.data
iter = _eachobsparallel_unbuffered(data)
obs, state = iterate(iter)
return obs, (iter, state)
end

## next iterations
function Base.iterate(::DataLoader, (iter, state))
ret = iterate(iter, state)
isnothing(ret) && return
Expand Down
11 changes: 7 additions & 4 deletions src/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
function eachobsparallel(
data;
executor::Executor = _default_executor(),
buffer = false,
buffer::Bool = false,
channelsize = Threads.nthreads())
if buffer == false
return _eachobsparallel_unbuffered(data, executor; channelsize)
Expand All @@ -38,11 +38,10 @@ function eachobsparallel(
end
end


function _eachobsparallel_buffered(
buffer,
data,
executor;
executor = _default_executor();
channelsize=Threads.nthreads())
buffers = [buffer]
foreach(_ -> push!(buffers, deepcopy(buffer)), 1:channelsize)
Expand All @@ -61,7 +60,11 @@ function _eachobsparallel_buffered(
end
end

function _eachobsparallel_unbuffered(data, executor; channelsize=Threads.nthreads())
function _eachobsparallel_unbuffered(data,
executor = _default_executor();
channelsize=Threads.nthreads()
)

return Loader(1:numobs(data); executor, channelsize) do ch, i
obs = getobs(data, i)
put!(ch, obs)
Expand Down
4 changes: 2 additions & 2 deletions test/parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ end
end
end

@testset "`eachobsparallel(buffer = true)`" begin
iter = eachobsparallel(collect(1:10), buffer=true)
@testset "`DataLoader(buffer = true, parallel=true)`" begin
iter = DataLoader(collect(1:10), buffer=true, batchsize=-1, parallel=true)
@test_nowarn for i in iter end
X_ = collect(iter)
@test all(x 1:10 for x in X_)
Expand Down

0 comments on commit d3c9cb8

Please sign in to comment.