Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rework DataLoader #192

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ export mapobs,
shuffleobs

include("batchview.jl")
export batchsize,
BatchView
export batchsize, BatchView

include("obsview.jl")
export obsview, ObsView

include("dataloader.jl")
export eachobs, DataLoader
Expand All @@ -48,10 +50,6 @@ include("folds.jl")
export kfolds,
leavepout

include("obsview.jl")
export obsview,
ObsView

include("randobs.jl")
export randobs

Expand Down
235 changes: 123 additions & 112 deletions src/dataloader.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,9 @@ julia> first(DataLoader(["a", "b", "c", "d"], batchsize=2, collate=collate_fn))
"ab"
```
"""
struct DataLoader{T,B,C,R<:AbstractRNG}
data::T
struct DataLoader{T<:Union{ObsView,BatchView},B,P,C,O,R<:AbstractRNG}
data::O # original data
_data::T # data wrapped in ObsView / BatchView
batchsize::Int
buffer::B # boolean, or external buffer
partial::Bool
Expand All @@ -157,74 +158,94 @@ function DataLoader(
collate = Val(nothing),
rng::AbstractRNG = Random.default_rng())

if !(buffer isa Bool) && parallel
throw(ArgumentError("If `parallel=true`, `buffer` must be a boolean."))
end

if collate isa Bool || collate === nothing
collate = Val(collate)
end
return DataLoader(data, batchsize, buffer, partial, shuffle, parallel, collate, rng)

# Wrapping with ObsView in order to work around
# issue https://github.com/FluxML/Flux.jl/issues/1935
data = ObsView(data)
if batchsize > 0
data = BatchView(data; batchsize, partial, collate)
end

if buffer == true
buffer = _create_buffer(data)
end
P = parallel ? :parallel : :serial
# for buffer == false and external buffer, we keep as is

T, B, C, R = typeof(data), typeof(buffer), typeof(collate), typeof(rng)
return DataLoader{T,B,P,C,R}(data, batchsize, buffer,
partial, shuffle, parallel, collate, rng)
end

function Base.iterate(d::DataLoader)
# TODO move ObsView and BatchWView wrapping to the constructor, so that
# we can parametrize the DataLoader with ObsView and BatchView and define specialized methods.

# Wrapping with ObsView in order to work around
# issue https://github.com/FluxML/Flux.jl/issues/1935
data = ObsView(d.data)

data = d.shuffle ? shuffleobs(d.rng, data) : data
data = d.batchsize > 0 ? BatchView(data; d.batchsize, d.partial, d.collate) : data
# buffered - serial case
function Base.iterate(d::DataLoader{T,B,:serial}) where {T,B}
@assert d.buffer != false
data = d.shuffle ? _shuffledata(d.rng, d._data) : d._data
iter = (getobs!(d.buffer, data, i) for i in 1:numobs(data))
obs, state = iterate(iter)
return obs, (iter, state)
end

if d.parallel
iter = eachobsparallel(data; d.buffer)
else
if d.buffer == false
iter = (getobs(data, i) for i in 1:numobs(data))
elseif d.buffer == true
buf = create_buffer(data)
iter = (getobs!(buf, data, i) for i in 1:numobs(data))
else # external buffer
buf = d.buffer
iter = (getobs!(buf, data, i) for i in 1:numobs(data))
end
end
# buffered - parallel case
function Base.iterate(d::DataLoader{T,B,:parallel}) where {T,B}
@assert d.buffer != false
data = d.shuffle ? _shuffledata(d.rng, d._data) : d._data
iter = _eachobsparallel_buffered(d.buffer, data)
obs, state = iterate(iter)
return obs, (iter, state)
end

create_buffer(x) = getobs(x, 1)
function create_buffer(x::BatchView)
obsindices = _batchrange(x, 1)
return [getobs(A.data, idx) for idx in enumerate(obsindices)]
# unbuffered - serial case
function Base.iterate(d::DataLoader{T,Bool,:serial}) where {T}
@assert d.buffer == false
data = d.shuffle ? _shuffledata(d.rng, d._data) : d._data
iter = (getobs(data, i) for i in 1:numobs(data))
obs, state = iterate(iter)
return obs, (iter, state)
end
function create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TData}
obsindices = _batchrange(x, 1)
return getobs(x.data, obsindices)

# unbuffered - parallel case
function Base.iterate(d::DataLoader{T,Bool,:parallel}) where {T}
@assert d.buffer == false
data = d.shuffle ? _shuffledata(d.rng, d._data) : d._data
iter = _eachobsparallel_unbuffered(data)
obs, state = iterate(iter)
return obs, (iter, state)
end

## next iterations
function Base.iterate(::DataLoader, (iter, state))
ret = iterate(iter, state)
isnothing(ret) && return
obs, state = ret
return obs, (iter, state)
end

_shuffledata(rng, data::ObsView) = shuffleobs(rng, data)

function Base.length(d::DataLoader)
if d.batchsize > 0
return numobs(BatchView(d.data; d.batchsize, d.partial))
else
return numobs(d.data)
end
end
_shuffledata(rng, data::BatchView) =
BatchView(shuffleobs(rng, data.data); data.batchsize, data.partial, data.collate)

_create_buffer(x) = getobs(x, 1)

Base.size(e::DataLoader) = (length(e),)
function _create_buffer(x::BatchView)
obsindices = _batchrange(x, 1)
return [getobs(A.data, idx) for idx in enumerate(obsindices)]
end

function _create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TData}
obsindices = _batchrange(x, 1)
return getobs(x.data, obsindices)
end

Base.IteratorEltype(::DataLoader) = Base.EltypeUnknown()
Base.length(d::DataLoader) = numobs(d._data)
Base.size(d::DataLoader) = (length(d),)
Base.IteratorEltype(d::DataLoader) = Base.EltypeUnknown()

## This causes error in some cases of `collect(loader)`
# function Base.eltype(e::DataLoader)
Expand Down Expand Up @@ -288,100 +309,90 @@ function mapobs(f, d::DataLoader)
collate = f ∘ d.collate
end

DataLoader(d.data,
batchsize=d.batchsize,
buffer=d.buffer,
partial=d.partial,
shuffle=d.shuffle,
parallel=d.parallel,
collate=collate,
rng=d.rng)
return DataLoader(d.data;
batchsize=d.batchsize,
buffer=d.buffer,
partial=d.partial,
shuffle=d.shuffle,
parallel=d.parallel,
collate=collate,
rng=d.rng)
end


@inline function _dataloader_foldl1(rf, val, e::DataLoader, data)
if e.shuffle
_dataloader_foldl2(rf, val, e, shuffleobs(e.rng, data))
else
_dataloader_foldl2(rf, val, e, data)
end
# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix)))
function Base.showarg(io::IO, d::DataLoader, toplevel)
print(io, "DataLoader(")
Base.showarg(io, d.data, false)
d.buffer == false || print(io, ", buffer=", d.buffer)
d.parallel == false || print(io, ", parallel=", d.parallel)
d.shuffle == false || print(io, ", shuffle=", d.shuffle)
d.batchsize == 1 || print(io, ", batchsize=", d.batchsize)
d.partial == true || print(io, ", partial=", d.partial)
d.collate === Val(nothing) || print(io, ", collate=", d.collate)
d.rng == Random.default_rng() || print(io, ", rng=", d.rng)
print(io, ")")
end

@inline function _dataloader_foldl2(rf, val, e::DataLoader, data)
if e.batchsize > 0
_dataloader_foldl3(rf, val, e, BatchView(data; e.batchsize, e.partial))
Base.show(io::IO, e::DataLoader) = Base.showarg(io, e, false)

function Base.show(io::IO, m::MIME"text/plain", d::DataLoader)
print(io, length(d), "-element ")
Base.showarg(io, d, false)
print(io, "\n with first element:")
print(io, "\n ", _expanded_summary(first(d)))
end

_expanded_summary(x) = summary(x)
function _expanded_summary(xs::Tuple)
parts = [_expanded_summary(x) for x in xs]
"(" * join(parts, ", ") * ",)"
end
function _expanded_summary(xs::NamedTuple)
parts = ["$k = "*_expanded_summary(x) for (k,x) in zip(keys(xs), xs)]
"(; " * join(parts, ", ") * ")"
end


### TRANSDUCERS IMPLEMENTATION #############################


@inline function _dataloader_foldl1(rf, val, d::DataLoader, data)
if d.shuffle
return _dataloader_foldl2(rf, val, d, _shuffledata(d.rng, data))
else
_dataloader_foldl3(rf, val, e, data)
return _dataloader_foldl2(rf, val, d, data)
end
end

@inline function _dataloader_foldl3(rf, val, e::DataLoader, data)
if e.buffer > 0
_dataloader_foldl4_buffered(rf, val, data)
@inline function _dataloader_foldl2(rf, val, d::DataLoader, data)
if d.buffer == false
return _dataloader_foldl3(rf, val, data)
else
_dataloader_foldl4(rf, val, data)
return _dataloader_foldl3_buffered(rf, val, data, d.buffer)
end
end

@inline function _dataloader_foldl4(rf, val, data)
@inline function _dataloader_foldl3(rf, val, data)
for i in 1:numobs(data)
@inbounds x = getobs(data, i)
# TODO: in 1.8 we could @inline this at the callsite,
# optimizer seems to be very sensitive to inlining and
# quite brittle in its capacity to keep this type stable
val = Transducers.@next(rf, val, x)
end
Transducers.complete(rf, val)
return Transducers.complete(rf, val)
end

@inline function _dataloader_foldl4_buffered(rf, val, data)
buf = getobs(data, 1)
@inline function _dataloader_foldl3_buffered(rf, val, data, buf)
for i in 1:numobs(data)
@inbounds x = getobs!(buf, data, i)
val = Transducers.@next(rf, val, x)
end
Transducers.complete(rf, val)
return Transducers.complete(rf, val)
end

@inline function Transducers.__foldl__(rf, val, e::DataLoader)
e.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads"))
_dataloader_foldl1(rf, val, e, ObsView(e.data))
@inline function Transducers.__foldl__(rf, val, d::DataLoader)
d.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads"))
return _dataloader_foldl1(rf, val, d, d._data)
end

# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix)))
function Base.showarg(io::IO, e::DataLoader, toplevel)
print(io, "DataLoader(")
Base.showarg(io, e.data, false)
e.buffer == false || print(io, ", buffer=", e.buffer)
e.parallel == false || print(io, ", parallel=", e.parallel)
e.shuffle == false || print(io, ", shuffle=", e.shuffle)
e.batchsize == 1 || print(io, ", batchsize=", e.batchsize)
e.partial == true || print(io, ", partial=", e.partial)
e.collate === Val(nothing) || print(io, ", collate=", e.collate)
e.rng == Random.default_rng() || print(io, ", rng=", e.rng)
print(io, ")")
end

Base.show(io::IO, e::DataLoader) = Base.showarg(io, e, false)

function Base.show(io::IO, m::MIME"text/plain", e::DataLoader)
if Base.haslength(e)
print(io, length(e), "-element ")
else
print(io, "Unknown-length ")
end
Base.showarg(io, e, false)
print(io, "\n with first element:")
print(io, "\n ", _expanded_summary(first(e)))
end

_expanded_summary(x) = summary(x)
function _expanded_summary(xs::Tuple)
parts = [_expanded_summary(x) for x in xs]
"(" * join(parts, ", ") * ",)"
end
function _expanded_summary(xs::NamedTuple)
parts = ["$k = "*_expanded_summary(x) for (k,x) in zip(keys(xs), xs)]
"(; " * join(parts, ", ") * ")"
end

29 changes: 17 additions & 12 deletions src/obstransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,28 @@ joinobs(datas...) = JoinedData(datas)
"""
shuffleobs([rng], data)

Return a "subset" of `data` that spans all observations, but
has the order of the observations shuffled.
Return a version of the dataset `data` that contains all the
origin observations in a random reordering.

The values of `data` itself are not copied. Instead only the
indices are shuffled. This function calls [`obsview`](@ref) to
accomplish that, which means that the return value is likely of a
different type than `data`.

Optionally, a random number generator `rng` can be passed as the
first argument.

The optional parameter `rng` allows one to specify the
random number generator used for shuffling. This is useful when
reproducible results are desired.

For this function to work, the type of `data` must implement
[`numobs`](@ref) and [`getobs`](@ref).

See also [`obsview`](@ref).

# Examples

```julia
# For Arrays the subset will be of type SubArray
@assert typeof(shuffleobs(rand(4,10))) <: SubArray
Expand All @@ -216,18 +230,9 @@ for x in eachobs(shuffleobs(X))
...
end
```

The optional parameter `rng` allows one to specify the
random number generator used for shuffling. This is useful when
reproducible results are desired. By default, uses the global RNG.
See `Random` in Julia's standard library for more info.

For this function to work, the type of `data` must implement
[`numobs`](@ref) and [`getobs`](@ref). See [`ObsView`](@ref)
for more information.
"""
shuffleobs(data) = shuffleobs(Random.default_rng(), data)

function shuffleobs(rng::AbstractRNG, data)
obsview(data, randperm(rng, numobs(data)))
return obsview(data, randperm(rng, numobs(data)))
end
Loading
Loading