Skip to content

Commit

Permalink
Add stream VByte encoding for shorts, integers and longs
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Jan 20, 2025
1 parent bcbc236 commit 21ffa7c
Show file tree
Hide file tree
Showing 6 changed files with 649 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.execution.buffer.vstream;

import com.google.common.annotations.VisibleForTesting;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;

import static com.google.common.base.Verify.verify;
import static io.airlift.slice.SizeOf.SIZE_OF_INT;

public class IntStreamVByte
{
private static final int GROUP_SIZE = 4;

private IntStreamVByte() {}

public static Slice encode(int[] values, Slice slice)
{
verify(slice.length() >= maxEncodedSize(values.length), "Slice is smaller than maximum encoded size");
SliceOutput output = slice.getOutput();

byte[] controlBytes = new byte[controlBytesTableSize(values.length)];
output.writeBytes(controlBytes);

int dataIndex = 0;
while (dataIndex < values.length) {
int groupSize = Math.min(GROUP_SIZE, values.length - dataIndex);
byte controlByte = 0;

for (int i = 0; i < groupSize; i++) {
byte byteSize = getRequiredBytes(values[dataIndex + i]);
controlByte |= (byte) ((byteSize - 1) << (i * 2));
writeValue(output, values[dataIndex + i], byteSize);
}

controlBytes[dataIndex / GROUP_SIZE] = controlByte;
dataIndex += groupSize;
}

slice.setBytes(0, controlBytes); // write control bytes in a single pass
return output.slice();
}

public static int maxEncodedSize(int size)
{
return size * SIZE_OF_INT + controlBytesTableSize(size);
}

@VisibleForTesting
static int controlBytesTableSize(int size)
{
return (size + GROUP_SIZE - 1) / GROUP_SIZE;
}

public static int[] decode(Slice slice, int size)
{
byte[] controlBytes = new byte[controlBytesTableSize(size)];

SliceInput input = slice.getInput();
input.readBytes(controlBytes);
int[] decoded = new int[size];
int index = 0;
for (byte controlByte : controlBytes) {
for (int i = 0; i < GROUP_SIZE && index < size; i++) {
byte valueSize = (byte) (((controlByte >> (i * 2)) & 0x03) + 1);
decoded[index++] = readValue(input, valueSize);
}
}
return decoded;
}

@VisibleForTesting
static byte getRequiredBytes(int value)
{
if ((value & 0xFFFFFF80) == 0) {
return 1;
}
if ((value & 0xFFFF8000) == 0) {
return 2;
}
if ((value & 0xFF800000) == 0) {
return 3;
}
return 4;
}

private static void writeValue(SliceOutput buffer, int value, byte byteSize)
{
switch (byteSize) {
case 1:
buffer.writeByte(value);
return;
case 2:
buffer.writeShort(value);
return;
case 3:
buffer.writeByte(value);
buffer.writeShort(value >>> 8);
return;
case 4:
buffer.writeInt(value);
return;
}
throw new IllegalArgumentException("Invalid byte size: " + byteSize);
}

private static int readValue(SliceInput input, byte byteSize)
{
return switch (byteSize) {
case 1 -> input.readByte();
case 2 -> input.readShort();
case 3 -> input.readByte() | (input.readShort() << 8);
case 4 -> input.readInt();
default -> throw new IllegalArgumentException("Invalid byte size: " + byteSize);
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.execution.buffer.vstream;

import com.google.common.annotations.VisibleForTesting;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;

import static com.google.common.base.Verify.verify;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;

public class LongStreamVByte
{
private static final int GROUP_SIZE = 16 / SIZE_OF_LONG;

private LongStreamVByte() {}

public static Slice encode(long[] values, Slice slice)
{
verify(slice.length() >= maxEncodedSize(values.length), "Slice is smaller than maximum encoded size");
SliceOutput output = slice.getOutput();

byte[] controlBytes = new byte[controlBytesTableSize(values.length)];
output.writeBytes(controlBytes);

int dataIndex = 0;
while (dataIndex < values.length) {
int groupSize = Math.min(GROUP_SIZE, values.length - dataIndex);
byte controlByte = 0;

for (int i = 0; i < groupSize; i++) {
int byteSize = getRequiredBytes(values[dataIndex + i]);
controlByte |= (byte) ((byteSize - 1) << (i * 4));
writeValue(output, values[dataIndex + i], byteSize);
}
controlBytes[dataIndex / GROUP_SIZE] = controlByte;
dataIndex += groupSize;
}

slice.setBytes(0, controlBytes);
return output.slice();
}

public static int maxEncodedSize(int size)
{
return (size * SIZE_OF_LONG) + controlBytesTableSize(size);
}

@VisibleForTesting
static int controlBytesTableSize(int size)
{
return (size + GROUP_SIZE - 1) / GROUP_SIZE;
}

public static long[] decode(Slice slice, int size)
{
byte[] controlBytes = new byte[controlBytesTableSize(size)];
SliceInput input = slice.getInput();
input.readBytes(controlBytes);

long[] decoded = new long[size];
int index = 0;
for (byte controlByte : controlBytes) {
for (int i = 0; i < GROUP_SIZE && index < size; i++) {
int valueSize = ((controlByte >> (i * 4)) & 0x0F) + 1;
decoded[index++] = readValue(input, valueSize);
}
}
return decoded;
}

@VisibleForTesting
static int getRequiredBytes(long value)
{
if ((value & 0xFFFFFFFFFFFFFF80L) == 0) {
return 1;
}
if ((value & 0xFFFFFFFFFFFF8000L) == 0) {
return 2;
}
if ((value & 0xFFFFFFFFFF800000L) == 0) {
return 3;
}
if ((value & 0xFFFFFFFF80000000L) == 0) {
return 4;
}
if ((value & 0xFFFFFF8000000000L) == 0) {
return 5;
}
if ((value & 0xFFFF800000000000L) == 0) {
return 6;
}
if ((value & 0xFF80000000000000L) == 0) {
return 7;
}
return 8;
}

private static void writeValue(SliceOutput buffer, long value, int byteSize)
{
switch (byteSize) {
case 1:
buffer.writeByte((byte) value);
return;
case 2:
buffer.writeShort((short) value);
return;
case 3:
buffer.writeByte((byte) (value & 0xFF));
buffer.writeShort((short) (value >>> 8));
return;
case 4:
buffer.writeInt((int) value);
return;
case 5:
buffer.writeByte((byte) (value & 0xFF));
buffer.writeInt((int) (value >>> 8));
return;
case 6:
buffer.writeShort((short) value);
buffer.writeInt((int) (value >>> 16));
return;
case 7:
buffer.writeByte((byte) value);
buffer.writeShort((short) (value >>> 8));
buffer.writeInt((int) (value >>> 24));
return;
case 8:
buffer.writeLong(value);
return;
}

throw new IllegalArgumentException("Invalid byte size: " + byteSize);
}

private static long readValue(SliceInput input, int byteSize)
{
return switch (byteSize) {
case 1 -> input.readByte();
case 2 -> input.readShort();
case 3 -> (input.readByte() & 0xFF) | (input.readShort() << 8);
case 4 -> input.readInt();
case 5 -> (input.readByte() & 0xFF) | ((long) input.readInt() << 8);
case 6 -> input.readShort() | ((long) input.readInt() << 16);
case 7 -> (input.readByte() & 0xFF) | (input.readShort() << 8) | ((long) input.readInt() << 24);
case 8 -> input.readLong();
default -> throw new IllegalArgumentException("Invalid byte size: " + byteSize);
};
}
}
Loading

0 comments on commit 21ffa7c

Please sign in to comment.