Skip to content

Commit

Permalink
Try harder not to outgrow preallocated IOBuffer.data (#235)
Browse files Browse the repository at this point in the history
* Try harder not to outgrow preallocated IOBuffer.data
  • Loading branch information
Drvi authored Sep 4, 2023
1 parent 8c751c1 commit bc95469
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 52 deletions.
91 changes: 40 additions & 51 deletions src/codec/encode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,67 +4,56 @@ function encode_tag(io::IO, field_number, wire_type::WireType)
end
encode_tag(e::ProtoEncoder, field_number, wire_type::WireType) = encode_tag(e.io, field_number, wire_type)
# TODO: audit usage and composability of maybe_ensure_room
maybe_ensure_room(io::IOBuffer, n) = Base.ensureroom(io, n)
maybe_ensure_room(::IO, n) = nothing
maybe_ensure_room(io::IOBuffer, n) = Base.ensureroom(io, min(io.maxsize, n))

# When we don't know the lenght beforehand we
# 1. Allocate 5 bytes for the length
# 2. encode data
# 3. come back to beginning and encode the length
# 4. shift the encoded data in case we didn't use all 5 bytes allocated for length
@inline function _with_size(f, io::IOBuffer, sink, x)
if io.seekable
MAX_LENGTH_VARINT_BYTES = 5 # max size of a UInt32 as vbyte
initpos = position(io)
truncate(io, initpos + MAX_LENGTH_VARINT_BYTES) # 1.
seek(io, initpos + MAX_LENGTH_VARINT_BYTES)
f(sink, x) # e.g. _encode(io, x) # 2.
endpos = position(io)
data_len = endpos - initpos - MAX_LENGTH_VARINT_BYTES
seek(io, initpos) # 3.
vbyte_encode(io, UInt32(data_len)) # --||--
lenght_len = position(io) - initpos
unsafe_copyto!(io.data, initpos + lenght_len + 1, io.data, initpos + MAX_LENGTH_VARINT_BYTES + 1, data_len) # 4.
seek(io, initpos + lenght_len + data_len)
truncate(io, initpos + lenght_len + data_len)
else
vbyte_encode(io, UInt32(_encoded_size(x)))
f(sink, x)
end
return nothing
end
maybe_ensure_room(::IO, n) = nothing

@inline function _with_size(f, io::IOBuffer, sink, x, V)
@inline function _with_size(f, io::IOBuffer, sink, x, V...)
if io.seekable
MAX_LENGTH_VARINT_BYTES = 5 # max size of a UInt32 as vbyte
initpos = position(io)
truncate(io, initpos + MAX_LENGTH_VARINT_BYTES) # 1.
seek(io, initpos + MAX_LENGTH_VARINT_BYTES)
f(sink, x, V) # e.g. _encode(io, x, Val{:zigzag}) # 2.
# We need to encode the encoded size of x before we know it. We first preallocate 1
# byte as that is the mininum size of the encoded size.
# If our guess is right, it will save us a copy, but we never want to preallocate too much
# space for the size, because then we risk outgrowing the buffer that was allocated with exact size
# needed to contain the message.
# TODO: make the guess better (e.g. by incorporating maxsize)
encoded_size_len_guess = 1
truncate(io, initpos + encoded_size_len_guess)
seek(io, initpos + encoded_size_len_guess)
# Now we can encode the object itself
f(sink, x, V...) # e.g. _encode(io, x) or _encode(io, x, Val{:zigzag})
endpos = position(io)
data_len = endpos - initpos - MAX_LENGTH_VARINT_BYTES
seek(io, initpos) # 3.
vbyte_encode(io, UInt32(data_len)) # --||--
lenght_len = position(io) - initpos
unsafe_copyto!(io.data, initpos + lenght_len + 1, io.data, initpos + MAX_LENGTH_VARINT_BYTES + 1, data_len) # 4.
seek(io, initpos + lenght_len + data_len)
truncate(io, initpos + lenght_len + data_len)
encoded_size = endpos - initpos - encoded_size_len_guess
encoded_size_len = _encoded_size(UInt32(encoded_size))
@assert (initpos + encoded_size_len + encoded_size) <= io.maxsize
# If our initial guess on encoded size of the size was wrong, then we need to move the encoded data
if encoded_size_len_guess < encoded_size_len
truncate(io, initpos + encoded_size_len + encoded_size)
# Move the data right after the correct size
unsafe_copyto!(
io.data,
initpos + encoded_size_len + 1,
io.data,
initpos + encoded_size_len_guess + 1,
encoded_size
)
end
# Now we can encode the size
seek(io, initpos)
vbyte_encode(io, UInt32(encoded_size))
seek(io, initpos + encoded_size_len + encoded_size)
else
vbyte_encode(io, UInt32(_encoded_size(x, V)))
f(sink, x, V)
# TODO: avoid quadratic behavior when estimating encoded size by providing a scratch buffer
vbyte_encode(io, UInt32(_encoded_size(x, V...)))
f(sink, x, V...)
end
return nothing
end

@inline function _with_size(f, io::IO, sink, x)
vbyte_encode(io, UInt32(_encoded_size(x)))
f(sink, x)
return nothing
end

@inline function _with_size(f, io::IO, sink, x, V)
vbyte_encode(io, UInt32(_encoded_size(x, V)))
f(sink, x, V)
@inline function _with_size(f, io::IO, sink, x, V...)
# TODO: avoid quadratic behavior when estimating encoded size by providing a scratch buffer
vbyte_encode(io, UInt32(_encoded_size(x, V...)))
f(sink, x, V...)
return nothing
end

Expand Down
60 changes: 59 additions & 1 deletion test/test_encode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,19 +349,77 @@ end
Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6], Val{:zigzag})
@test take!(io) == UInt8[6, 2, 4, 6, 8, 10, 12]

io = IOBuffer(zeros(UInt8, 7), maxsize=7, read=false, write=true)
Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6])
@test io.data == UInt8[6, 1, 2, 3, 4, 5, 6]

io = IOBuffer(zeros(UInt8, 7), maxsize=7, read=false, write=true)
Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6], Val{:zigzag})
@test io.data == UInt8[6, 2, 4, 6, 8, 10, 12]

io = IOBuffer(;maxsize=2^14 + 1)
Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6])
@test take!(io) == UInt8[6, 1, 2, 3, 4, 5, 6]

io = IOBuffer(;maxsize=2^14 + 1)
Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6], Val{:zigzag})
@test take!(io) == UInt8[6, 2, 4, 6, 8, 10, 12]

io = IOBuffer(;maxsize=2^21 + 1)
Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6])
@test take!(io) == UInt8[6, 1, 2, 3, 4, 5, 6]

io = IOBuffer(;maxsize=2^21 + 1)
Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6], Val{:zigzag})
@test take!(io) == UInt8[6, 2, 4, 6, 8, 10, 12]

io = PipeBuffer()
Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6])
@test take!(io) == UInt8[6, 1, 2, 3, 4, 5, 6]

io = PipeBuffer()
Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6], Val{:zigzag})
@test take!(io) == UInt8[6, 2, 4, 6, 8, 10, 12]

io = IOBuffer()
Codecs._with_size(Codecs._encode, io, io, collect(1:128))
@test take!(io) == vcat(UInt8(129), UInt8(1), UInt8.(collect(1:127)), UInt8(128), UInt8(1))

io = IOBuffer()
Codecs._with_size(Codecs._encode, io, io, fill(2, 129), Val{:zigzag})
@test take!(io) == vcat(UInt8(129), UInt8(1), fill(UInt8(4), 129))

io = IOBuffer(;maxsize=2^14 + 1)
Codecs._with_size(Codecs._encode, io, io, collect(1:128))
@test take!(io) == vcat(UInt8(129), UInt8(1), UInt8.(collect(1:127)), UInt8(128), UInt8(1))

io = IOBuffer(;maxsize=2^14 + 1)
Codecs._with_size(Codecs._encode, io, io, fill(2, 129), Val{:zigzag})
@test take!(io) == vcat(UInt8(129), UInt8(1), fill(UInt8(4), 129))

io = IOBuffer(;maxsize=2^21 + 1)
Codecs._with_size(Codecs._encode, io, io, collect(1:128))
@test take!(io) == vcat(UInt8(129), UInt8(1), UInt8.(collect(1:127)), UInt8(128), UInt8(1))

io = IOBuffer(;maxsize=2^21 + 1)
Codecs._with_size(Codecs._encode, io, io, fill(2, 129), Val{:zigzag})
@test take!(io) == vcat(UInt8(129), UInt8(1), fill(UInt8(4), 129))

io = PipeBuffer()
Codecs._with_size(Codecs._encode, io, io, collect(1:128))
@test take!(io) == vcat(UInt8(129), UInt8(1), UInt8.(collect(1:127)), UInt8(128), UInt8(1))

io = PipeBuffer()
Codecs._with_size(Codecs._encode, io, io, fill(2, 129), Val{:zigzag})
@test take!(io) == vcat(UInt8(129), UInt8(1), fill(UInt8(4), 129))
end

@testset "_encoded_size" begin
@test _encoded_size(nothing) == 0
@test _encoded_size(UInt8[0xff]) == 1
@test _encoded_size(UInt8[]) == 0
@test _encoded_size("S") == 1
@test _encoded_size("") == 0
@test _encoded_size(typemax(UInt32)) == 5
@test _encoded_size(typemax(UInt64)) == 10
@test _encoded_size(typemax(Int32)) == 5
Expand Down Expand Up @@ -506,4 +564,4 @@ end
@test _encoded_size(Dict("K" => typemax(Float64))) == ((1 + 1 + 1) + (1 + 8)) + 1
@test _encoded_size(Dict("K" => typemax(Float64)), 1) == ((1 + 1 + 1) + (1 + 8)) + 2
end
end # module
end # module

0 comments on commit bc95469

Please sign in to comment.