From bc9546935085f1db414e792d21f29e85e11fe223 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Drvo=C5=A1t=C4=9Bp?= Date: Mon, 4 Sep 2023 13:56:07 +0200 Subject: [PATCH] Try harder not to outgrow preallocated IOBuffer.data (#235) * Try harder not to outgrow preallocated IOBuffer.data --- src/codec/encode.jl | 91 ++++++++++++++++++++------------------------- test/test_encode.jl | 60 +++++++++++++++++++++++++++++- 2 files changed, 99 insertions(+), 52 deletions(-) diff --git a/src/codec/encode.jl b/src/codec/encode.jl index 2505971..d1faa50 100644 --- a/src/codec/encode.jl +++ b/src/codec/encode.jl @@ -4,67 +4,56 @@ function encode_tag(io::IO, field_number, wire_type::WireType) end encode_tag(e::ProtoEncoder, field_number, wire_type::WireType) = encode_tag(e.io, field_number, wire_type) # TODO: audit usage and composability of maybe_ensure_room -maybe_ensure_room(io::IOBuffer, n) = Base.ensureroom(io, n) -maybe_ensure_room(::IO, n) = nothing +maybe_ensure_room(io::IOBuffer, n) = Base.ensureroom(io, min(io.maxsize, n)) -# When we don't know the lenght beforehand we -# 1. Allocate 5 bytes for the length -# 2. encode data -# 3. come back to beginning and encode the length -# 4. shift the encoded data in case we didn't use all 5 bytes allocated for length -@inline function _with_size(f, io::IOBuffer, sink, x) - if io.seekable - MAX_LENGTH_VARINT_BYTES = 5 # max size of a UInt32 as vbyte - initpos = position(io) - truncate(io, initpos + MAX_LENGTH_VARINT_BYTES) # 1. - seek(io, initpos + MAX_LENGTH_VARINT_BYTES) - f(sink, x) # e.g. _encode(io, x) # 2. - endpos = position(io) - data_len = endpos - initpos - MAX_LENGTH_VARINT_BYTES - seek(io, initpos) # 3. - vbyte_encode(io, UInt32(data_len)) # --||-- - lenght_len = position(io) - initpos - unsafe_copyto!(io.data, initpos + lenght_len + 1, io.data, initpos + MAX_LENGTH_VARINT_BYTES + 1, data_len) # 4. - seek(io, initpos + lenght_len + data_len) - truncate(io, initpos + lenght_len + data_len) - else - vbyte_encode(io, UInt32(_encoded_size(x))) - f(sink, x) - end - return nothing -end +maybe_ensure_room(::IO, n) = nothing -@inline function _with_size(f, io::IOBuffer, sink, x, V) +@inline function _with_size(f, io::IOBuffer, sink, x, V...) if io.seekable - MAX_LENGTH_VARINT_BYTES = 5 # max size of a UInt32 as vbyte initpos = position(io) - truncate(io, initpos + MAX_LENGTH_VARINT_BYTES) # 1. - seek(io, initpos + MAX_LENGTH_VARINT_BYTES) - f(sink, x, V) # e.g. _encode(io, x, Val{:zigzag}) # 2. + # We need to encode the encoded size of x before we know it. We first preallocate 1 + # byte as that is the mininum size of the encoded size. + # If our guess is right, it will save us a copy, but we never want to preallocate too much + # space for the size, because then we risk outgrowing the buffer that was allocated with exact size + # needed to contain the message. + # TODO: make the guess better (e.g. by incorporating maxsize) + encoded_size_len_guess = 1 + truncate(io, initpos + encoded_size_len_guess) + seek(io, initpos + encoded_size_len_guess) + # Now we can encode the object itself + f(sink, x, V...) # e.g. _encode(io, x) or _encode(io, x, Val{:zigzag}) endpos = position(io) - data_len = endpos - initpos - MAX_LENGTH_VARINT_BYTES - seek(io, initpos) # 3. - vbyte_encode(io, UInt32(data_len)) # --||-- - lenght_len = position(io) - initpos - unsafe_copyto!(io.data, initpos + lenght_len + 1, io.data, initpos + MAX_LENGTH_VARINT_BYTES + 1, data_len) # 4. - seek(io, initpos + lenght_len + data_len) - truncate(io, initpos + lenght_len + data_len) + encoded_size = endpos - initpos - encoded_size_len_guess + encoded_size_len = _encoded_size(UInt32(encoded_size)) + @assert (initpos + encoded_size_len + encoded_size) <= io.maxsize + # If our initial guess on encoded size of the size was wrong, then we need to move the encoded data + if encoded_size_len_guess < encoded_size_len + truncate(io, initpos + encoded_size_len + encoded_size) + # Move the data right after the correct size + unsafe_copyto!( + io.data, + initpos + encoded_size_len + 1, + io.data, + initpos + encoded_size_len_guess + 1, + encoded_size + ) + end + # Now we can encode the size + seek(io, initpos) + vbyte_encode(io, UInt32(encoded_size)) + seek(io, initpos + encoded_size_len + encoded_size) else - vbyte_encode(io, UInt32(_encoded_size(x, V))) - f(sink, x, V) + # TODO: avoid quadratic behavior when estimating encoded size by providing a scratch buffer + vbyte_encode(io, UInt32(_encoded_size(x, V...))) + f(sink, x, V...) end return nothing end -@inline function _with_size(f, io::IO, sink, x) - vbyte_encode(io, UInt32(_encoded_size(x))) - f(sink, x) - return nothing -end - -@inline function _with_size(f, io::IO, sink, x, V) - vbyte_encode(io, UInt32(_encoded_size(x, V))) - f(sink, x, V) +@inline function _with_size(f, io::IO, sink, x, V...) + # TODO: avoid quadratic behavior when estimating encoded size by providing a scratch buffer + vbyte_encode(io, UInt32(_encoded_size(x, V...))) + f(sink, x, V...) return nothing end diff --git a/test/test_encode.jl b/test/test_encode.jl index 939395d..edf4f08 100644 --- a/test/test_encode.jl +++ b/test/test_encode.jl @@ -349,6 +349,30 @@ end Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6], Val{:zigzag}) @test take!(io) == UInt8[6, 2, 4, 6, 8, 10, 12] + io = IOBuffer(zeros(UInt8, 7), maxsize=7, read=false, write=true) + Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6]) + @test io.data == UInt8[6, 1, 2, 3, 4, 5, 6] + + io = IOBuffer(zeros(UInt8, 7), maxsize=7, read=false, write=true) + Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6], Val{:zigzag}) + @test io.data == UInt8[6, 2, 4, 6, 8, 10, 12] + + io = IOBuffer(;maxsize=2^14 + 1) + Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6]) + @test take!(io) == UInt8[6, 1, 2, 3, 4, 5, 6] + + io = IOBuffer(;maxsize=2^14 + 1) + Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6], Val{:zigzag}) + @test take!(io) == UInt8[6, 2, 4, 6, 8, 10, 12] + + io = IOBuffer(;maxsize=2^21 + 1) + Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6]) + @test take!(io) == UInt8[6, 1, 2, 3, 4, 5, 6] + + io = IOBuffer(;maxsize=2^21 + 1) + Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6], Val{:zigzag}) + @test take!(io) == UInt8[6, 2, 4, 6, 8, 10, 12] + io = PipeBuffer() Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6]) @test take!(io) == UInt8[6, 1, 2, 3, 4, 5, 6] @@ -356,12 +380,46 @@ end io = PipeBuffer() Codecs._with_size(Codecs._encode, io, io, [1, 2, 3, 4, 5, 6], Val{:zigzag}) @test take!(io) == UInt8[6, 2, 4, 6, 8, 10, 12] + + io = IOBuffer() + Codecs._with_size(Codecs._encode, io, io, collect(1:128)) + @test take!(io) == vcat(UInt8(129), UInt8(1), UInt8.(collect(1:127)), UInt8(128), UInt8(1)) + + io = IOBuffer() + Codecs._with_size(Codecs._encode, io, io, fill(2, 129), Val{:zigzag}) + @test take!(io) == vcat(UInt8(129), UInt8(1), fill(UInt8(4), 129)) + + io = IOBuffer(;maxsize=2^14 + 1) + Codecs._with_size(Codecs._encode, io, io, collect(1:128)) + @test take!(io) == vcat(UInt8(129), UInt8(1), UInt8.(collect(1:127)), UInt8(128), UInt8(1)) + + io = IOBuffer(;maxsize=2^14 + 1) + Codecs._with_size(Codecs._encode, io, io, fill(2, 129), Val{:zigzag}) + @test take!(io) == vcat(UInt8(129), UInt8(1), fill(UInt8(4), 129)) + + io = IOBuffer(;maxsize=2^21 + 1) + Codecs._with_size(Codecs._encode, io, io, collect(1:128)) + @test take!(io) == vcat(UInt8(129), UInt8(1), UInt8.(collect(1:127)), UInt8(128), UInt8(1)) + + io = IOBuffer(;maxsize=2^21 + 1) + Codecs._with_size(Codecs._encode, io, io, fill(2, 129), Val{:zigzag}) + @test take!(io) == vcat(UInt8(129), UInt8(1), fill(UInt8(4), 129)) + + io = PipeBuffer() + Codecs._with_size(Codecs._encode, io, io, collect(1:128)) + @test take!(io) == vcat(UInt8(129), UInt8(1), UInt8.(collect(1:127)), UInt8(128), UInt8(1)) + + io = PipeBuffer() + Codecs._with_size(Codecs._encode, io, io, fill(2, 129), Val{:zigzag}) + @test take!(io) == vcat(UInt8(129), UInt8(1), fill(UInt8(4), 129)) end @testset "_encoded_size" begin @test _encoded_size(nothing) == 0 @test _encoded_size(UInt8[0xff]) == 1 + @test _encoded_size(UInt8[]) == 0 @test _encoded_size("S") == 1 + @test _encoded_size("") == 0 @test _encoded_size(typemax(UInt32)) == 5 @test _encoded_size(typemax(UInt64)) == 10 @test _encoded_size(typemax(Int32)) == 5 @@ -506,4 +564,4 @@ end @test _encoded_size(Dict("K" => typemax(Float64))) == ((1 + 1 + 1) + (1 + 8)) + 1 @test _encoded_size(Dict("K" => typemax(Float64)), 1) == ((1 + 1 + 1) + (1 + 8)) + 2 end -end # module \ No newline at end of file +end # module