From dd2c031d954112ed15114ba582827cf630d385de Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 22 Nov 2022 13:29:24 -0500 Subject: [PATCH] [java] Sparse tensor support (#10653) **Description**: Adds support for creating and receiving sparse tensors in the ORT Java API. CSRC and COO tensors as inputs are tested, but there is no op which accepts a block sparse tensor to test. COO tensors are tested as outputs, but there is no op which emits a CSRC or block sparse tensor to test. **Motivation and Context** - Why is this change required? What problem does it solve? Request to expose ORT sparse tensor support in Java. cc @yuslepukhin --- .../java/ai/onnxruntime/OnnxSparseTensor.java | 920 ++++++++++++++++++ .../main/java/ai/onnxruntime/OnnxTensor.java | 99 +- .../java/ai/onnxruntime/OnnxTensorLike.java | 59 ++ .../main/java/ai/onnxruntime/OnnxValue.java | 10 +- .../main/java/ai/onnxruntime/OrtSession.java | 17 +- .../src/main/java/ai/onnxruntime/OrtUtil.java | 93 +- .../main/java/ai/onnxruntime/TensorInfo.java | 33 + java/src/main/native/OrtJniUtil.c | 83 +- java/src/main/native/OrtJniUtil.h | 6 + .../native/ai_onnxruntime_OnnxSparseTensor.c | 534 ++++++++++ .../java/ai/onnxruntime/InferenceTest.java | 4 + .../java/ai/onnxruntime/SparseTensorTest.java | 450 +++++++++ .../generic_sparse_to_dense_matmul.onnx | 16 + 13 files changed, 2218 insertions(+), 106 deletions(-) create mode 100644 java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java create mode 100644 java/src/main/java/ai/onnxruntime/OnnxTensorLike.java create mode 100644 java/src/main/native/ai_onnxruntime_OnnxSparseTensor.c create mode 100644 java/src/test/java/ai/onnxruntime/SparseTensorTest.java create mode 100644 java/testdata/generic_sparse_to_dense_matmul.onnx diff --git a/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java new file mode 100644 index 0000000000000..668e6e07ceccd --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java @@ -0,0 +1,920 @@ +/* + * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import static ai.onnxruntime.OnnxTensor.fp16ToFloat; + +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.nio.ShortBuffer; +import java.util.Arrays; + +/** + * A Java object wrapping an OnnxSparseTensor. + * + *
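+ * <p>A minimal usage sketch (assuming an {@code OrtEnvironment} named {@code env}; the shapes
+ * and values here are illustrative, not taken from this change):
+ *
+ * <pre>{@code
+ * // Two non-zero entries of a 2x3 dense matrix, in COO format.
+ * LongBuffer indices = LongBuffer.wrap(new long[] {0, 1, 1, 2}); // shape [numNonZero, dims]
+ * FloatBuffer values = FloatBuffer.wrap(new float[] {1.0f, 2.0f});
+ * OnnxSparseTensor.COOTensor coo =
+ *     new OnnxSparseTensor.COOTensor(
+ *         indices, new long[] {2, 2}, values, new long[] {2, 3}, OnnxJavaType.FLOAT, 2);
+ * try (OnnxSparseTensor sparse = OnnxSparseTensor.createSparseTensor(env, coo)) {
+ *   // supply "sparse" as a model input via OrtSession.run
+ * }
+ * }</pre>
+ *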
Sparse tensors support a variety of formats, and the {@link #getValue} method returns a + * different static inner class representing each type. + */ +public final class OnnxSparseTensor extends OnnxTensorLike { + private final SparseTensorType sparseTensorType; + + // Held to prevent deallocation while used in native code. + private final Buffer indices; + private final LongBuffer innerIndices; + private final Buffer values; + + /** + * Construct a sparse tensor from JNI. + * + * @param nativeHandle The tensor native handle. + * @param allocatorHandle The allocator handle. + * @param sparseType The sparsity type. + * @param info The tensor info. + */ + OnnxSparseTensor(long nativeHandle, long allocatorHandle, int sparseType, TensorInfo info) { + this( + nativeHandle, + allocatorHandle, + SparseTensorType.mapFromInt(sparseType), + info, + null, + null, + null); + } + + /** + * Construct a COO or block sparse tensor. + * + * @param nativeHandle The tensor native handle. + * @param allocatorHandle The allocator handle. + * @param sparseType The sparsity type. + * @param info The tensor info. + * @param indices The indices buffer. + * @param values The data buffer. + */ + OnnxSparseTensor( + long nativeHandle, + long allocatorHandle, + SparseTensorType sparseType, + TensorInfo info, + Buffer indices, + Buffer values) { + this(nativeHandle, allocatorHandle, sparseType, info, indices, null, values); + } + + /** + * Construct a sparse tensor. + * + *
If the tensor is COO or block sparse then innerIndices may be null. + * + * @param nativeHandle The tensor native handle. + * @param allocatorHandle The allocator handle. + * @param sparseType The sparsity type. + * @param info The tensor info. + * @param indices The indices buffer. + * @param innerIndices The inner indices buffer. + * @param values The data buffer. + */ + OnnxSparseTensor( + long nativeHandle, + long allocatorHandle, + SparseTensorType sparseType, + TensorInfo info, + Buffer indices, + LongBuffer innerIndices, + Buffer values) { + super(nativeHandle, allocatorHandle, info); + this.sparseTensorType = sparseType; + this.indices = indices; + this.innerIndices = innerIndices; + this.values = values; + } + + /** + * Creates a Sparse Tensor in ORT from the Java side representation. + * + * @param env The OrtEnvironment. + * @param tensor The Java side representation. + * @param The buffer type. + * @return The sparse tensor in ORT. + * @throws OrtException If the tensor could not be created or was invalid. + */ + public static OnnxSparseTensor createSparseTensor( + OrtEnvironment env, SparseTensor tensor) throws OrtException { + return createSparseTensor(env, env.defaultAllocator, tensor); + } + + static OnnxSparseTensor createSparseTensor( + OrtEnvironment env, OrtAllocator allocator, SparseTensor tensor) throws OrtException { + if (!allocator.isClosed()) { + TensorInfo info = TensorInfo.constructFromSparseTensor(tensor); + OnnxJavaType indicesType = tensor.getIndicesType(); + OrtUtil.BufferTuple indicesTuple = OrtUtil.prepareBuffer(tensor.getIndices(), indicesType); + OrtUtil.BufferTuple valuesTuple = OrtUtil.prepareBuffer(tensor.getValues(), info.type); + if (!((indicesTuple.data instanceof LongBuffer) + || (indicesTuple.data instanceof IntBuffer))) { + throw new IllegalStateException( + "Unexpected type of indices buffer, found " + + indicesTuple.data.getClass() + + ", expected IntBuffer or LongBuffer"); + } + // Replace with a type switch when using JDK 17+. 
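+      // COO and block sparse tensors share one native creation path (an indices buffer plus a
+      // values buffer); CSRC additionally passes the inner indices buffer down to native code.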
+ switch (tensor.getSparsityType()) { + case COO: + case BLOCK_SPARSE: + return new OnnxSparseTensor( + createSparseTensorFromBuffer( + OnnxRuntime.ortApiHandle, + allocator.handle, + indicesTuple.data, + indicesTuple.pos, + indicesTuple.size, + valuesTuple.data, + valuesTuple.pos, + info.shape, + tensor.getIndicesShape(), + tensor.getValuesShape(), + info.onnxType.value, + tensor.getSparsityType().value), + allocator.handle, + tensor.getSparsityType(), + info, + indicesTuple.data, + valuesTuple.data); + case CSRC: + OrtUtil.BufferTuple innerIndicesTuple = + OrtUtil.prepareBuffer(((CSRCTensor) tensor).getInnerIndices(), indicesType); + return new OnnxSparseTensor( + createCSRCSparseTensorFromBuffer( + OnnxRuntime.ortApiHandle, + allocator.handle, + indicesTuple.data, + indicesTuple.pos, + indicesTuple.size, + innerIndicesTuple.data, + innerIndicesTuple.pos, + innerIndicesTuple.size, + valuesTuple.data, + valuesTuple.pos, + info.shape, + tensor.getValuesShape(), + info.onnxType.value), + allocator.handle, + tensor.getSparsityType(), + info, + indicesTuple.data, + (LongBuffer) innerIndicesTuple.data, + valuesTuple.data); + case UNDEFINED: + default: + throw new IllegalArgumentException("Cannot create an UNDEFINED sparse tensor."); + } + } else { + throw new IllegalStateException( + "Trying to create an OnnxSparseTensor on a closed OrtAllocator."); + } + } + + @Override + public OnnxValueType getType() { + return OnnxValueType.ONNX_TYPE_SPARSETENSOR; + } + + @Override + public SparseTensor getValue() throws OrtException { + Buffer buffer = getValuesBuffer(); + long[] indicesShape = getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); + switch (sparseTensorType) { + case COO: + return new COOTensor( + (LongBuffer) getIndicesBuffer(), + indicesShape, + buffer, + info.shape, + info.type, + buffer.remaining()); + case CSRC: + return new CSRCTensor( + (LongBuffer) getIndicesBuffer(), + getInnerIndicesBuffer(), + buffer, + info.shape, + info.type, + buffer.remaining()); + case BLOCK_SPARSE: + long[] valuesShape = getValuesShape(OnnxRuntime.ortApiHandle, nativeHandle); + return new BlockSparseTensor( + (IntBuffer) getIndicesBuffer(), + indicesShape, + buffer, + valuesShape, + info.shape, + info.type, + buffer.remaining()); + case UNDEFINED: + default: + throw new IllegalStateException("Undefined sparsity type in this sparse tensor."); + } + } + + @Override + public void close() { + close(OnnxRuntime.ortApiHandle, nativeHandle); + } + + /** + * Returns the type of this OnnxSparseTensor. + * + * @return The sparsity type. + */ + public SparseTensorType getSparseTensorType() { + return sparseTensorType; + } + + /** + * Gets a copy of the indices. + * + *
These are the outer indices if it's a CSRC sparse tensor. + * + *
It's a {@link LongBuffer} if COO or CSRC, and {@link IntBuffer} if Block Sparse. + * + * @return The indices. + */ + public Buffer getIndicesBuffer() { + switch (sparseTensorType) { + case COO: + case CSRC: + { + LongBuffer longBuf = + getIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle) + .order(ByteOrder.nativeOrder()) + .asLongBuffer(); + LongBuffer output = LongBuffer.allocate(longBuf.capacity()); + output.put(longBuf); + output.rewind(); + return output; + } + case BLOCK_SPARSE: + { + IntBuffer intBuf = + getIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle) + .order(ByteOrder.nativeOrder()) + .asIntBuffer(); + IntBuffer output = IntBuffer.allocate(intBuf.capacity()); + output.put(intBuf); + output.rewind(); + return output; + } + case UNDEFINED: + default: + throw new IllegalStateException("UNDEFINED sparse tensor type."); + } + } + + /** + * Gets a copy of the inner indices in a CSRC sparse tensor. + * + *
Throws {@link IllegalStateException} if called on a different sparse tensor type. + * + * @return The inner indices. + */ + public LongBuffer getInnerIndicesBuffer() { + if (sparseTensorType == SparseTensorType.CSRC) { + LongBuffer buf = + getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle) + .order(ByteOrder.nativeOrder()) + .asLongBuffer(); + LongBuffer output = LongBuffer.allocate(buf.capacity()); + output.put(buf); + output.rewind(); + return output; + } else { + throw new IllegalStateException( + "Inner indices are only available for CSRC sparse tensors, this sparse tensor is " + + sparseTensorType); + } + } + + /** + * Gets a copy of the data buffer. + * + *
As with {@link OnnxTensor} fp16 values are upcast into fp32 and returned as a {@link + * FloatBuffer}. + * + * @return The data buffer. + */ + public Buffer getValuesBuffer() { + ByteBuffer buffer = + getValuesBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder()); + switch (info.type) { + case FLOAT: + if (info.onnxType == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) { + ShortBuffer shortBuffer = buffer.asShortBuffer(); + int bufferCap = shortBuffer.capacity(); + FloatBuffer output = FloatBuffer.allocate(bufferCap); + for (int i = 0; i < bufferCap; i++) { + output.put(fp16ToFloat(shortBuffer.get(i))); + } + output.rewind(); + return output; + } else if (info.onnxType + == TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) { + throw new IllegalArgumentException("BFloat16 is not supported."); + } else { + // regular fp32 + FloatBuffer floatBuf = buffer.asFloatBuffer(); + FloatBuffer output = FloatBuffer.allocate(floatBuf.capacity()); + output.put(floatBuf); + output.rewind(); + return output; + } + case DOUBLE: + { + DoubleBuffer doubleBuf = buffer.asDoubleBuffer(); + DoubleBuffer output = DoubleBuffer.allocate(doubleBuf.capacity()); + output.put(doubleBuf); + output.rewind(); + return output; + } + case INT16: + { + ShortBuffer shortBuf = buffer.asShortBuffer(); + ShortBuffer output = ShortBuffer.allocate(shortBuf.capacity()); + output.put(shortBuf); + output.rewind(); + return output; + } + case INT32: + { + IntBuffer intBuf = buffer.asIntBuffer(); + IntBuffer output = IntBuffer.allocate(intBuf.capacity()); + output.put(intBuf); + output.rewind(); + return output; + } + case INT64: + { + LongBuffer longBuf = buffer.asLongBuffer(); + LongBuffer output = LongBuffer.allocate(longBuf.capacity()); + output.put(longBuf); + output.rewind(); + return output; + } + case BOOL: + case INT8: + case UINT8: + { + ByteBuffer output = ByteBuffer.allocate(buffer.capacity()); + output.put(buffer); + output.rewind(); + return output; + } + case STRING: + throw new IllegalStateException("Unsupported data type String"); + case UNKNOWN: + default: + throw new IllegalStateException("Unsupported data type"); + } + } + + /** + * Gets the shape of the (outer) indices. + * + * @return The indices shape. + */ + public long[] getIndicesShape() { + return getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); + } + + /** + * Gets the shape of the inner indices in a CSRC sparse tensor. + * + * @return The indices shape. + */ + public long[] getInnerIndicesShape() { + if (sparseTensorType == SparseTensorType.CSRC) { + return getInnerIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); + } else { + throw new IllegalStateException( + "Inner indices are only available for CSRC sparse tensors, this sparse tensor is " + + sparseTensorType); + } + } + + /** + * Gets the shape of the values. + * + * @return The values shape. + */ + public long[] getValuesShape() { + return getValuesShape(OnnxRuntime.ortApiHandle, nativeHandle); + } + + /** + * Gets the shape of the (outer) indices. + * + * @param apiHandle The OrtApi pointer. + * @param nativeHandle The OrtSparseTensor pointer. + * @return The indices shape. + */ + private native long[] getIndicesShape(long apiHandle, long nativeHandle); + + /** + * Gets the shape of the inner indices. + * + * @param apiHandle The OrtApi pointer. + * @param nativeHandle The OrtSparseTensor pointer. + * @return The inner indices shape. 
+ */ + private native long[] getInnerIndicesShape(long apiHandle, long nativeHandle); + + /** + * Gets the shape of the values. + * + * @param apiHandle The OrtApi pointer. + * @param nativeHandle The OrtSparseTensor pointer. + * @return The values shape. + */ + private native long[] getValuesShape(long apiHandle, long nativeHandle); + + /** + * Wraps the indices in a direct byte buffer. + * + * @param apiHandle The OrtApi pointer. + * @param nativeHandle The OrtSparseTensor pointer. + * @return A ByteBuffer wrapping the indices. + */ + private native ByteBuffer getIndicesBuffer(long apiHandle, long nativeHandle); + + /** + * Wraps the inner indices in a direct byte buffer. + * + * @param apiHandle The OrtApi pointer. + * @param nativeHandle The OrtSparseTensor pointer. + * @return A ByteBuffer wrapping the inner indices. + */ + private native ByteBuffer getInnerIndicesBuffer(long apiHandle, long nativeHandle); + + /** + * Wraps the data in a direct byte buffer. + * + * @param apiHandle The OrtApi pointer. + * @param nativeHandle The OrtSparseTensor pointer. + * @return A ByteBuffer wrapping the indices. + */ + private native ByteBuffer getValuesBuffer(long apiHandle, long nativeHandle); + + /** + * Closes the sparse tensor. + * + * @param apiHandle The OrtApi pointer. + * @param nativeHandle The OrtSparseTensor pointer. + */ + private native void close(long apiHandle, long nativeHandle); + + /** + * Creates a sparse CSRC sparse tensor. + * + *
The buffers must be kept alive for the lifetime of the ORT sparse tensor object. + * + * @param apiHandle The ORT API pointer. + * @param allocatorHandle The allocator pointer. + * @param indicesData The outer indices. + * @param indicesBufferPos The outer indices position in bytes. + * @param indicesBufferSize The outer indices buffer size in longs. + * @param innerIndicesData The inner indices. + * @param innerIndicesBufferPos The inner indices position in bytes. + * @param innerIndicesBufferSize The inner indices buffer size in longs. + * @param values The data. + * @param bufferPos The data position in bytes. + * @param denseShape The dense shape of the tensor. + * @param valuesShape The shape of the values (should be a vector). + * @param onnxType The type of the values. + * @return A pointer to an ORT sparse tensor value. + * @throws OrtException If the tensor could not be created. + */ + private static native long createCSRCSparseTensorFromBuffer( + long apiHandle, + long allocatorHandle, + Buffer indicesData, + int indicesBufferPos, + long indicesBufferSize, + Buffer innerIndicesData, + int innerIndicesBufferPos, + long innerIndicesBufferSize, + Buffer values, + int bufferPos, + long[] denseShape, + long[] valuesShape, + int onnxType) + throws OrtException; + + /** + * Creates a sparse COO or block sparse tensor. + * + *
The buffers must be kept alive for the lifetime of the ORT sparse tensor object. + * + * @param apiHandle The ORT API pointer. + * @param allocatorHandle The allocator pointer. + * @param indicesData The indices. + * @param indicesBufferPos The indices position in bytes. + * @param indicesBufferSize The indices buffer size in longs. + * @param values The data. + * @param bufferPos The data position in bytes. + * @param denseShape The dense shape of the tensor. + * @param indicesShape The shape of the indices (a vector or matrix for COO, and a matrix for + * block sparse). + * @param valuesShape The shape of the values (a vector for COO, and a block shape for block + * sparse). + * @param onnxType The type of the values. + * @param sparsityType The sparsity type. + * @return A pointer to an ORT sparse tensor value. + * @throws OrtException If the tensor could not be created. + */ + private static native long createSparseTensorFromBuffer( + long apiHandle, + long allocatorHandle, + Buffer indicesData, + int indicesBufferPos, + long indicesBufferSize, + Buffer values, + int bufferPos, + long[] denseShape, + long[] indicesShape, + long[] valuesShape, + int onnxType, + int sparsityType) + throws OrtException; + + /** + * The type of the sparse tensor. + * + *
Should be synchronized with OrtSparseFormat in the C API. + */ + public enum SparseTensorType { + /** Undefined sparse tensor. */ + UNDEFINED(0), + /** COO sparse tensor. */ + COO(1), + /** CSR or CSC sparse tensor. */ + CSRC(2), + /** Block sparse tensor. */ + BLOCK_SPARSE(4); + + /** The int value mirroring OrtSparseFormat. */ + public final int value; + + private static final SparseTensorType[] values = new SparseTensorType[5]; + + static { + values[0] = UNDEFINED; + values[1] = COO; + values[2] = CSRC; + values[3] = UNDEFINED; + values[4] = BLOCK_SPARSE; + } + + SparseTensorType(int value) { + this.value = value; + } + + /** + * Maps from an int in native land into a SparseTensorType instance. + * + * @param value The value to lookup. + * @return The enum instance. + */ + public static SparseTensorType mapFromInt(int value) { + if ((value > 0) && (value < values.length)) { + return values[value]; + } else { + return UNDEFINED; + } + } + } + + /** + * Abstract base class for Java sparse tensors + * + *
Will be sealed to {@link COOTensor}, {@link CSRCTensor} and {@link BlockSparseTensor} one + * day. + */ + public abstract static class SparseTensor { + private final long[] indicesShape; + private final long[] valuesShape; + private final long[] denseShape; + private final OnnxJavaType type; + private final long numNonZero; + + final T indices; + final Buffer values; + + SparseTensor( + T indices, + long[] indicesShape, + Buffer values, + long[] valuesShape, + long[] denseShape, + OnnxJavaType type, + long numNonZero) { + this.indices = indices; + this.indicesShape = indicesShape; + this.values = values; + this.valuesShape = valuesShape; + this.denseShape = denseShape; + this.type = type; + this.numNonZero = numNonZero; + if (values.remaining() != numNonZero) { + throw new IllegalArgumentException( + "Expected numNonZero and data.remaining to be equal, found " + + numNonZero + + " and " + + values.remaining() + + " respectively"); + } + if (type == OnnxJavaType.STRING) { + throw new IllegalArgumentException("String SparseTensors are not supported."); + } + } + + /** + * Gets the dense shape of the sparse tensor. + * + * @return The sparse tensor shape. + */ + public long[] getDenseShape() { + return denseShape; + } + + /** + * The data type of the sparse tensor. + * + * @return The sparse tensor data type. + */ + public OnnxJavaType getType() { + return type; + } + + /** + * The number of non-zero elements. + * + * @return The number of non-zero elements. + */ + public long getNumNonZeroElements() { + return numNonZero; + } + + /** + * Get the indices buffer. + * + * @return The indices buffer. + */ + public T getIndices() { + return indices; + } + + /** + * Get the value buffer. + * + * @return The value buffer. + */ + public Buffer getValues() { + return values; + } + + /** + * Gets the shape of the values of the sparse tensor. + * + * @return The sparse tensor value shape. + */ + public long[] getValuesShape() { + return valuesShape; + } + + /** + * Gets the shape of the indices of the sparse tensor. + * + * @return The sparse tensor indices shape. + */ + public long[] getIndicesShape() { + return indicesShape; + } + + /** + * The sparsity type of the sparse tensor. + * + * @return The sparse tensor sparsity type. + */ + public abstract SparseTensorType getSparsityType(); + + /** + * The indices type of the sparse tensor. + * + *
Only {@link OnnxJavaType#INT32} and {@link OnnxJavaType#INT64} are supported. + * + * @return The sparse tensor indices type. + */ + public abstract OnnxJavaType getIndicesType(); + } + + /** The Java side representation of a COO sparse tensor. */ + public static final class COOTensor extends SparseTensor { + /** + * Creates a COO sparse tensor suitable for constructing an ORT Sparse Tensor. + * + * @param indices The indices. Should be a 1d vector, or a 2d vector. + * @param indicesShape The shape of the indices. + * @param values The data. + * @param denseShape The dense shape. + * @param type The data type. + * @param numNonZero The number of non-zero elements. + */ + public COOTensor( + LongBuffer indices, + long[] indicesShape, + Buffer values, + long[] denseShape, + OnnxJavaType type, + long numNonZero) { + super(indices, indicesShape, values, new long[] {numNonZero}, denseShape, type, numNonZero); + if ((indicesShape.length > 2) + || (indicesShape.length == 0) + || (indicesShape[0] != numNonZero)) { + throw new IllegalArgumentException( + "Invalid indices shape, expected [numNonZero, dimension] or [numNonZero] found " + + Arrays.toString(indicesShape)); + } + long elementCount = OrtUtil.elementCount(indicesShape); + if (elementCount != indices.remaining()) { + throw new IllegalArgumentException( + "Unexpected number of indices found in buffer, expected " + + elementCount + + " found " + + indices.remaining()); + } + if (values.remaining() != numNonZero) { + throw new IllegalArgumentException( + "Expected data.remaining() - " + + values.remaining() + + " to equal numNonZero - " + + numNonZero); + } + } + + @Override + public OnnxJavaType getIndicesType() { + return OnnxJavaType.INT64; + } + + @Override + public SparseTensorType getSparsityType() { + return SparseTensorType.COO; + } + } + + /** The Java side representation of a CSRC sparse tensor. */ + public static final class CSRCTensor extends SparseTensor { + private final LongBuffer innerIndices; + + /** + * Creates a CSRC sparse tensor suitable for constructing an ORT Sparse Tensor. + * + * @param outerIndices The outer indices. + * @param innerIndices The inner indices. + * @param values The data. + * @param denseShape The dense shape. + * @param type The data type. + * @param numNonZero The number of non-zero elements. + */ + public CSRCTensor( + LongBuffer outerIndices, + LongBuffer innerIndices, + Buffer values, + long[] denseShape, + OnnxJavaType type, + long numNonZero) { + super( + outerIndices, + new long[] {outerIndices.remaining()}, + values, + new long[] {numNonZero}, + denseShape, + type, + numNonZero); + this.innerIndices = innerIndices; + long expectedRows = denseShape[0] + 1; + if (outerIndices.remaining() != expectedRows) { + throw new IllegalArgumentException( + "Outer indices should be equal to the number of rows + 1 in the dense shape, found " + + outerIndices.remaining() + + ", expected " + + expectedRows); + } + if (innerIndices.remaining() != numNonZero) { + throw new IllegalArgumentException( + "Inner indices should be equal to the number of non-zero elements, found " + + innerIndices.remaining() + + ", expected " + + numNonZero); + } + } + + /** + * Gets the shape of the inner indices. + * + * @return The inner indices shape. + */ + public long[] getInnerIndicesShape() { + return new long[] {innerIndices.remaining()}; + } + + /** + * Gets the inner indices buffer. + * + * @return The inner indices buffer. 
+ */ + public LongBuffer getInnerIndices() { + return innerIndices; + } + + @Override + public OnnxJavaType getIndicesType() { + return OnnxJavaType.INT64; + } + + @Override + public SparseTensorType getSparsityType() { + return SparseTensorType.CSRC; + } + } + + /** The Java side representation of a block sparse tensor. */ + public static final class BlockSparseTensor extends SparseTensor { + /** + * Construct a block sparse tensor. + * + * @param indices The indices. + * @param indicesShape The shape of the indices. + * @param values The data. + * @param valuesShape The shape of the data. + * @param denseShape The dense shape. + * @param type The data type. + * @param numNonZero The number of non-zero elements. + */ + public BlockSparseTensor( + IntBuffer indices, + long[] indicesShape, + Buffer values, + long[] valuesShape, + long[] denseShape, + OnnxJavaType type, + long numNonZero) { + super(indices, indicesShape, values, valuesShape, denseShape, type, numNonZero); + if (OrtUtil.elementCount(valuesShape) != numNonZero) { + throw new IllegalArgumentException( + "Expected " + + numNonZero + + " entries in the data shape, found " + + Arrays.toString(valuesShape)); + } + if (numNonZero != values.remaining()) { + throw new IllegalArgumentException( + "Expected " + numNonZero + " elements in the data buffer, found " + values.remaining()); + } + if (OrtUtil.elementCount(indicesShape) != indices.remaining()) { + throw new IllegalArgumentException( + "Expected " + + OrtUtil.elementCount(indicesShape) + + " elements in the indices buffer, found " + + indices.remaining()); + } + if (valuesShape.length < 3) { + throw new IllegalArgumentException( + "Expected [numBlocks, blockSize, blockSize] or larger, but data shape was " + + Arrays.toString(valuesShape)); + } + if (indicesShape.length < 2) { + throw new IllegalArgumentException( + "Expected [numBlocks, co-ordinates] or larger, but indices shape was " + + Arrays.toString(indicesShape)); + } + } + + @Override + public OnnxJavaType getIndicesType() { + return OnnxJavaType.INT32; + } + + @Override + public SparseTensorType getSparsityType() { + return SparseTensorType.BLOCK_SPARSE; + } + } +} diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 653a3c9be91bd..c5e60a3dbaf51 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -1,10 +1,9 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; -import java.io.IOException; import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -18,21 +17,7 @@ * A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be * returned as outputs. */ -public class OnnxTensor implements OnnxValue { - static { - try { - OnnxRuntime.init(); - } catch (IOException e) { - throw new RuntimeException("Failed to load onnx-runtime library", e); - } - } - - private final long nativeHandle; - - private final long allocatorHandle; - - private final TensorInfo info; - +public class OnnxTensor extends OnnxTensorLike { /** * This reference is held for OnnxTensors backed by a Java nio buffer to ensure the buffer does * not go out of scope while the OnnxTensor exists. 
@@ -44,9 +29,7 @@ public class OnnxTensor implements OnnxValue { } OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer) { - this.nativeHandle = nativeHandle; - this.allocatorHandle = allocatorHandle; - this.info = info; + super(nativeHandle, allocatorHandle, info); this.buffer = buffer; } @@ -55,10 +38,6 @@ public OnnxValueType getType() { return OnnxValueType.ONNX_TYPE_TENSOR; } - long getNativeHandle() { - return nativeHandle; - } - /** * Either returns a boxed primitive if the Tensor is a scalar, or a multidimensional array of * primitives if it has multiple dimensions. @@ -108,11 +87,6 @@ public Object getValue() throws OrtException { } } - @Override - public TensorInfo getInfo() { - return info; - } - @Override public String toString() { return "OnnxTensor(info=" + info.toString() + ")"; @@ -300,7 +274,7 @@ private native void getArray(long apiHandle, long nativeHandle, Object carrier) * @param input A uint16_t representing an IEEE half precision float. * @return A float. */ - private static float fp16ToFloat(short input) { + static float fp16ToFloat(short input) { int output = ((input & 0x8000) << 16) | (((input & 0x7c00) + 0x1C000) << 13) | ((input & 0x03FF) << 13); return Float.intBitsToFloat(output); @@ -715,73 +689,20 @@ static OnnxTensor createTensor( */ private static OnnxTensor createTensor( OnnxJavaType type, OrtAllocator allocator, Buffer data, long[] shape) throws OrtException { - int bufferPos; - long bufferSizeLong = data.remaining() * (long) type.size; - if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) { - // The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending - // on the JVM, so we check for something 8 elements below the maximum size which - // should be allocatable (assuming there is enough memory) on all 64-bit JVMs. - throw new IllegalStateException( - "Cannot allocate a direct buffer of the requested size and type, size " - + data.remaining() - + ", type = " - + type); - } - // Now we know we're in range - int bufferSize = data.remaining() * type.size; - Buffer tmp; - if (data.isDirect()) { - tmp = data; - bufferPos = data.position() * type.size; - } else { - // Copy the data to a new direct buffer, then restore the state of the input. - int origPosition = data.position(); - ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder()); - switch (type) { - case FLOAT: - tmp = buffer.asFloatBuffer().put((FloatBuffer) data); - break; - case DOUBLE: - tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data); - break; - case UINT8: - case INT8: - // buffer is already a ByteBuffer, no cast needed. 
- tmp = buffer.put((ByteBuffer) data); - break; - case INT16: - tmp = buffer.asShortBuffer().put((ShortBuffer) data); - break; - case INT32: - tmp = buffer.asIntBuffer().put((IntBuffer) data); - break; - case INT64: - tmp = buffer.asLongBuffer().put((LongBuffer) data); - break; - case BOOL: - case STRING: - case UNKNOWN: - default: - throw new IllegalStateException( - "Impossible to reach here, managed to cast a buffer as an incorrect type"); - } - data.position(origPosition); - tmp.rewind(); - bufferPos = 0; - } - TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type); + OrtUtil.BufferTuple tuple = OrtUtil.prepareBuffer(data, type); + TensorInfo info = TensorInfo.constructFromBuffer(tuple.data, shape, type); return new OnnxTensor( createTensorFromBuffer( OnnxRuntime.ortApiHandle, allocator.handle, - tmp, - bufferPos, - bufferSize, + tuple.data, + tuple.pos, + tuple.byteSize, shape, info.onnxType.value), allocator.handle, info, - tmp); + tuple.data); } private static native long createTensor( diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java b/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java new file mode 100644 index 0000000000000..d0c874c4a1e15 --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.io.IOException; + +/** + * Currently implemented by {@link OnnxTensor}, {@link OnnxSparseTensor}. Will be sealed to these + * types one day. + */ +public abstract class OnnxTensorLike implements OnnxValue { + static { + try { + OnnxRuntime.init(); + } catch (IOException e) { + throw new RuntimeException("Failed to load onnx-runtime library", e); + } + } + + protected final long nativeHandle; + + protected final long allocatorHandle; + + protected final TensorInfo info; + + /** + * Constructs a tensor-like (the base class of OnnxTensor and OnnxSparseTensor). + * + * @param nativeHandle The pointer to the tensor. + * @param allocatorHandle The pointer to the memory allocator. + * @param info The tensor info. + */ + OnnxTensorLike(long nativeHandle, long allocatorHandle, TensorInfo info) { + this.nativeHandle = nativeHandle; + this.allocatorHandle = allocatorHandle; + this.info = info; + } + + /** + * Returns the native pointer. + * + * @return The native pointer. + */ + long getNativeHandle() { + return nativeHandle; + } + + /** + * Returns a {@link TensorInfo} for this tensor. + * + * @return The tensor info. + */ + @Override + public TensorInfo getInfo() { + return info; + } +} diff --git a/java/src/main/java/ai/onnxruntime/OnnxValue.java b/java/src/main/java/ai/onnxruntime/OnnxValue.java index db25550d24ed5..8de6b9c58b1f5 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxValue.java +++ b/java/src/main/java/ai/onnxruntime/OnnxValue.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -8,9 +8,8 @@ /** * Top interface for input and output values from ONNX models. Currently implemented by {@link - * OnnxTensor}, {@link OnnxSequence} and {@link OnnxMap}. Will be sealed to these types one day. - * - *
Does not support sparse tensors. + * OnnxTensor}, {@link OnnxSparseTensor}, {@link OnnxSequence} and {@link OnnxMap}. Will be sealed + * to these types one day. */ public interface OnnxValue extends AutoCloseable { @@ -21,7 +20,8 @@ public enum OnnxValueType { ONNX_TYPE_SEQUENCE(2), ONNX_TYPE_MAP(3), ONNX_TYPE_OPAQUE(4), - ONNX_TYPE_SPARSETENSOR(5); + ONNX_TYPE_SPARSETENSOR(5), + ONNX_TYPE_OPTIONAL(6); /** The id number of this type in the C API. */ public final int value; diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 3ae8ca77cbffa..865c42ab14b91 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -203,7 +203,7 @@ public Map getOutputInfo() throws OrtException { * @throws OrtException If there was an error in native code, the input names are invalid, or if * there are zero or too many inputs. */ - public Result run(Map inputs) throws OrtException { + public Result run(Map inputs) throws OrtException { return run(inputs, outputNames); } @@ -218,7 +218,8 @@ public Result run(Map inputs) throws OrtException { * @throws OrtException If there was an error in native code, the input names are invalid, or if * there are zero or too many inputs. */ - public Result run(Map inputs, RunOptions runOptions) throws OrtException { + public Result run(Map inputs, RunOptions runOptions) + throws OrtException { return run(inputs, outputNames, runOptions); } @@ -233,7 +234,7 @@ public Result run(Map inputs, RunOptions runOptions) throws * @throws OrtException If there was an error in native code, the input or output names are * invalid, or if there are zero or too many inputs or outputs. */ - public Result run(Map inputs, Set requestedOutputs) + public Result run(Map inputs, Set requestedOutputs) throws OrtException { return run(inputs, requestedOutputs, null); } @@ -241,7 +242,7 @@ public Result run(Map inputs, Set requestedOutputs) /** * Scores an input feed dict, returning the map of requested inferred outputs. * - *
The outputs are sorted based on the supplied set traveral order. + *
The outputs are sorted based on the supplied set traversal order. * * @param inputs The inputs to score. * @param requestedOutputs The requested outputs. @@ -251,10 +252,12 @@ public Result run(Map inputs, Set requestedOutputs) * invalid, or if there are zero or too many inputs or outputs. */ public Result run( - Map inputs, Set requestedOutputs, RunOptions runOptions) + Map inputs, + Set requestedOutputs, + RunOptions runOptions) throws OrtException { if (!closed) { - if (inputs.isEmpty() || (inputs.size() > numInputs)) { + if ((inputs.isEmpty() && (numInputs != 0)) || (inputs.size() > numInputs)) { throw new OrtException( "Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size()); } @@ -268,7 +271,7 @@ public Result run( String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; int i = 0; - for (Map.Entry t : inputs.entrySet()) { + for (Map.Entry t : inputs.entrySet()) { if (inputNames.contains(t.getKey())) { inputNamesArray[i] = t.getKey(); inputHandles[i] = t.getValue().getNativeHandle(); diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index c13c84df57dba..ca340676e247d 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -1,10 +1,18 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; import java.lang.reflect.Array; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import java.nio.ShortBuffer; import java.util.ArrayList; import java.util.Arrays; @@ -472,4 +480,87 @@ static int capacityFromSize(int size) { // 0.75 is the default JDK load factor return (int) (size / 0.75 + 1); } + + /** + * Prepares a buffer, either copying it if it's not direct, or computing it's size and position if + * it is. + * + * @param data The buffer to prepare. + * @param type The Java-side type. + * @return The prepared buffer tuple. + */ + static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) { + int bufferPos; + long bufferSizeLong = data.remaining() * (long) type.size; + if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) { + // The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending + // on the JVM, so we check for something 8 elements below the maximum size which + // should be allocatable (assuming there is enough memory) on all 64-bit JVMs. + throw new IllegalStateException( + "Cannot allocate a direct buffer of the requested size and type, size " + + data.remaining() + + ", type = " + + type); + } + // Now we know we're in range + int bufferSize = data.remaining() * type.size; + Buffer tmp; + if (data.isDirect()) { + tmp = data; + bufferPos = data.position() * type.size; + } else { + // Copy the data to a new direct buffer, then restore the state of the input. + int origPosition = data.position(); + ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder()); + switch (type) { + case FLOAT: + tmp = buffer.asFloatBuffer().put((FloatBuffer) data); + break; + case DOUBLE: + tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data); + break; + case UINT8: + case INT8: + // buffer is already a ByteBuffer, no cast needed. 
+ tmp = buffer.put((ByteBuffer) data); + break; + case INT16: + tmp = buffer.asShortBuffer().put((ShortBuffer) data); + break; + case INT32: + tmp = buffer.asIntBuffer().put((IntBuffer) data); + break; + case INT64: + tmp = buffer.asLongBuffer().put((LongBuffer) data); + break; + case BOOL: + case STRING: + case UNKNOWN: + default: + throw new IllegalStateException( + "Impossible to reach here, managed to cast a buffer as an incorrect type"); + } + data.position(origPosition); + tmp.rewind(); + bufferPos = 0; + } + + return new BufferTuple(tmp, bufferPos, bufferSize, data.remaining(), tmp != data); + } + + static final class BufferTuple { + final Buffer data; + final int pos; + final long byteSize; + final long size; + final boolean isCopy; + + BufferTuple(Buffer data, int pos, long byteSize, long size, boolean isCopy) { + this.data = data; + this.pos = pos; + this.byteSize = byteSize; + this.size = size; + this.isCopy = isCopy; + } + } } diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 4a7a3b833bc2b..b9b7835da2ee5 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -297,6 +297,39 @@ public static TensorInfo constructFromBuffer(Buffer buffer, long[] shape, OnnxJa Arrays.copyOf(shape, shape.length), type, OnnxTensorType.mapFromJavaType(type)); } + /** + * Constructs a TensorInfo from the supplied {@link OnnxSparseTensor.SparseTensor}. + * + * @param tensor The sparse tensor. + * @param The buffer type. + * @return A TensorInfo for a sparse tensor. + * @throws OrtException If the supplied tensor has too many elements for it's shape. + */ + public static TensorInfo constructFromSparseTensor( + OnnxSparseTensor.SparseTensor tensor) throws OrtException { + long[] shape = tensor.getDenseShape(); + + long elementCount = OrtUtil.elementCount(shape); + + long bufferRemaining = tensor.getValues().remaining(); + + if (elementCount < bufferRemaining) { + throw new OrtException( + "Shape " + + Arrays.toString(shape) + + ", has at most " + + elementCount + + " elements but the buffer has " + + bufferRemaining + + " elements."); + } + + return new TensorInfo( + Arrays.copyOf(shape, shape.length), + tensor.getType(), + OnnxTensorType.mapFromJavaType(tensor.getType())); + } + /** * Extracts the shape from a multidimensional array. Checks to see if the array is ragged or not. * diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index a670179a0eb25..165e0e96c696d 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -68,6 +68,45 @@ ExecutionMode convertExecutionMode(jint mode) { } } +/** + * Must be kept in sync with OrtSparseFormat and OnnxSparseTensor.SparseTensorType + * @param format The Java int. + * @return The enum. + */ +OrtSparseFormat convertToOrtSparseFormat(jint format) { + switch (format) { + case 0: + return ORT_SPARSE_UNDEFINED; + case 1: + return ORT_SPARSE_COO; + case 2: + return ORT_SPARSE_CSRC; + case 4: + return ORT_SPARSE_BLOCK_SPARSE; + default: + return ORT_SPARSE_UNDEFINED; + } +} + +/** + * Must be kept in sync with OrtSparseFormat and OnnxSparseTensor.SparseTensorType + * @param format The enum. + * @return The Java int. 
+ */ +jint convertFromOrtSparseFormat(OrtSparseFormat format) { + switch (format) { + case ORT_SPARSE_COO: + return 1; + case ORT_SPARSE_CSRC: + return 2; + case ORT_SPARSE_BLOCK_SPARSE: + return 4; + case ORT_SPARSE_UNDEFINED: + default: + return 0; + } +} + /** * Must be kept in sync with convertToONNXDataFormat */ @@ -228,7 +267,8 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo } switch (type) { - case ONNX_TYPE_TENSOR: { + case ONNX_TYPE_TENSOR: + case ONNX_TYPE_SPARSETENSOR: { const OrtTensorTypeAndShapeInfo* tensorInfo = NULL; code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToTensorInfo(info, &tensorInfo)); if (code == ORT_OK) { @@ -257,7 +297,6 @@ jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo } case ONNX_TYPE_UNKNOWN: case ONNX_TYPE_OPAQUE: - case ONNX_TYPE_SPARSETENSOR: default: { throwOrtException(jniEnv,convertErrorCode(ORT_NOT_IMPLEMENTED),"Invalid ONNXType found."); return NULL; @@ -869,6 +908,40 @@ jobject createJavaTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocato return javaTensor; } +jobject createJavaSparseTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) { + // Extract the type information + OrtTensorTypeAndShapeInfo* info; + OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor, &info)); + if (code != ORT_OK) { + return NULL; + } + + // Construct the TensorInfo object + jobject tensorInfo = convertToTensorInfo(jniEnv, api, info); + + // Release the info object + api->ReleaseTensorTypeAndShapeInfo(info); + if (tensorInfo == NULL) { + return NULL; + } + + // Lookup the sparse tensor type enum + OrtSparseFormat format; + code = checkOrtStatus(jniEnv,api,api->GetSparseTensorFormat(tensor, &format)); + if (code != ORT_OK) { + return NULL; + } + jint sparseTensorInt = convertFromOrtSparseFormat(format); + + // Construct the ONNXTensor object + char *tensorClassName = "ai/onnxruntime/OnnxSparseTensor"; + jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorClassName); + jmethodID tensorConstructor = (*jniEnv)->GetMethodID(jniEnv, clazz, "", "(JJILai/onnxruntime/TensorInfo;)V"); + jobject javaSparseTensor = (*jniEnv)->NewObject(jniEnv, clazz, tensorConstructor, (jlong) tensor, (jlong) allocator, sparseTensorInt, tensorInfo); + + return javaSparseTensor; +} + jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* sequence) { // Get the sequence info class static const char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo"; @@ -1026,12 +1099,14 @@ jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca case ONNX_TYPE_MAP: { return createJavaMapFromONNX(jniEnv, api, allocator, onnxValue); } + case ONNX_TYPE_SPARSETENSOR: { + return createJavaSparseTensorFromONNX(jniEnv, api, allocator, onnxValue); + } case ONNX_TYPE_UNKNOWN: case ONNX_TYPE_OPAQUE: case ONNX_TYPE_OPTIONAL: - case ONNX_TYPE_SPARSETENSOR: default: { - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_SPARSETENSOR."); + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_OPTIONAL."); return NULL; } } diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 616a20503ad42..10f19c4d250c7 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ 
-28,6 +28,10 @@ GraphOptimizationLevel convertOptimizationLevel(jint level); ExecutionMode convertExecutionMode(jint mode); +OrtSparseFormat convertToOrtSparseFormat(jint format); + +jint convertFromOrtSparseFormat(OrtSparseFormat format); + jint convertFromONNXDataFormat(ONNXTensorElementDataType type); ONNXTensorElementDataType convertToONNXDataFormat(jint type); @@ -68,6 +72,8 @@ jdoubleArray createDoubleArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, Ort jobject createJavaTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor); +jobject createJavaSparseTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor); + jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* sequence); jobject createJavaMapFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* map); diff --git a/java/src/main/native/ai_onnxruntime_OnnxSparseTensor.c b/java/src/main/native/ai_onnxruntime_OnnxSparseTensor.c new file mode 100644 index 0000000000000..30c6b39affe5e --- /dev/null +++ b/java/src/main/native/ai_onnxruntime_OnnxSparseTensor.c @@ -0,0 +1,534 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include +#include +#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "OrtJniUtil.h" +#include "ai_onnxruntime_OnnxSparseTensor.h" + +/* + * Class: ai_onnxruntime_OnnxSparseTensor + * Method: getIndicesBuffer + * Signature: (JJ)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getIndicesBuffer + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. 
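+  // Map the tensor's sparse format onto the matching indices format, then wrap the native
+  // indices storage in a direct ByteBuffer without copying; the Java caller copies it out.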
+ const OrtApi* api = (const OrtApi*) apiHandle; + const OrtValue* ortValue = (const OrtValue*) handle; + OrtSparseFormat format; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format)); + if (code != ORT_OK) { + return NULL; + } + enum OrtSparseIndicesFormat indicesFormat; + switch (format) { + case ORT_SPARSE_COO: + indicesFormat = ORT_SPARSE_COO_INDICES; + break; + case ORT_SPARSE_CSRC: + indicesFormat = ORT_SPARSE_CSR_OUTER_INDICES; + break; + case ORT_SPARSE_BLOCK_SPARSE: + indicesFormat = ORT_SPARSE_BLOCK_SPARSE_INDICES; + break; + case ORT_SPARSE_UNDEFINED: + default: { + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Sparse format is ORT_SPARSE_UNDEFINED, cannot get indices"); + return NULL; + } + } + + OrtTensorTypeAndShapeInfo* info = NULL; + code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(ortValue, indicesFormat, &info)); + if (code != ORT_OK) { + return NULL; + } + size_t arrSize = 0; + code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize)); + if (code != ORT_OK) { + api->ReleaseTensorTypeAndShapeInfo(info); + return NULL; + } + ONNXTensorElementDataType onnxTypeEnum; + code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum)); + api->ReleaseTensorTypeAndShapeInfo(info); + if (code != ORT_OK) { + return NULL; + } + + size_t typeSize = onnxTypeSize(onnxTypeEnum); + size_t sizeBytes = arrSize * typeSize; + + uint8_t* arr = NULL; + size_t indices_size = 0; + code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndices(ortValue, indicesFormat, &indices_size, (const void**)&arr)); + if (code != ORT_OK) { + return NULL; + } + + if (indices_size != arrSize) { + throwOrtException(jniEnv, convertErrorCode(ORT_RUNTIME_EXCEPTION), "Unexpected size"); + return NULL; + } else { + return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes); + } +} + +/* + * Class: ai_onnxruntime_OnnxSparseTensor + * Method: getInnerIndicesBuffer + * Signature: (JJ)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getInnerIndicesBuffer + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. 
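+  // Inner indices only exist for CSRC tensors; other formats throw. The returned direct
+  // buffer points at memory owned by the OrtValue.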
+ const OrtApi* api = (const OrtApi*) apiHandle; + const OrtValue* ortValue = (const OrtValue*) handle; + OrtSparseFormat format; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format)); + if (code != ORT_OK) { + return NULL; + } + enum OrtSparseIndicesFormat indicesFormat; + switch (format) { + case ORT_SPARSE_CSRC: + indicesFormat = ORT_SPARSE_CSR_INNER_INDICES; + break; + case ORT_SPARSE_COO: + case ORT_SPARSE_BLOCK_SPARSE: + case ORT_SPARSE_UNDEFINED: + default: { + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), + "Sparse format is ORT_SPARSE_COO, ORT_SPARSE_BLOCK_SPARSE, or ORT_SPARSE_UNDEFINED, inner indices are not defined."); + return NULL; + } + } + + OrtTensorTypeAndShapeInfo* info = NULL; + code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(ortValue, indicesFormat, &info)); + if (code != ORT_OK) { + return NULL; + } + size_t arrSize = 0; + code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize)); + if (code != ORT_OK) { + api->ReleaseTensorTypeAndShapeInfo(info); + return NULL; + } + ONNXTensorElementDataType onnxTypeEnum; + code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum)); + api->ReleaseTensorTypeAndShapeInfo(info); + if (code != ORT_OK) { + return NULL; + } + + size_t typeSize = onnxTypeSize(onnxTypeEnum); + size_t sizeBytes = arrSize * typeSize; + + uint8_t* arr; + size_t indices_size; + code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndices(ortValue, indicesFormat, &indices_size, (const void**)&arr)); + if (code != ORT_OK) { + return NULL; + } + + if (indices_size != arrSize) { + throwOrtException(jniEnv, convertErrorCode(ORT_RUNTIME_EXCEPTION), "Unexpected size"); + return NULL; + } else { + return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes); + } +} + +/* + * Class: ai_onnxruntime_OnnxSparseTensor + * Method: getValuesBuffer + * Signature: (JJ)Ljava/nio/ByteBuffer; + */ +JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getValuesBuffer + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. 
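+  // Wrap the native values storage in a direct ByteBuffer sized element count * element size;
+  // valid for COO, CSRC and block sparse tensors, UNDEFINED throws.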
+ const OrtApi* api = (const OrtApi*) apiHandle; + const OrtValue* ortValue = (const OrtValue*) handle; + OrtSparseFormat format; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(ortValue, &format)); + if (code != ORT_OK) { + return NULL; + } + switch (format) { + case ORT_SPARSE_COO: + case ORT_SPARSE_CSRC: + case ORT_SPARSE_BLOCK_SPARSE: { + OrtTensorTypeAndShapeInfo* info = NULL; + checkOrtStatus(jniEnv, api, api->GetSparseTensorValuesTypeAndShape(ortValue, &info)); + if (code != ORT_OK) { + return NULL; + } + size_t arrSize = 0; + code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(info, &arrSize)); + if (code != ORT_OK) { + api->ReleaseTensorTypeAndShapeInfo(info); + return NULL; + } + ONNXTensorElementDataType onnxTypeEnum; + code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxTypeEnum)); + api->ReleaseTensorTypeAndShapeInfo(info); + if (code != ORT_OK) { + return NULL; + } + + size_t typeSize = onnxTypeSize(onnxTypeEnum); + size_t sizeBytes = arrSize * typeSize; + + uint8_t* arr = NULL; + checkOrtStatus(jniEnv, api, api->GetSparseTensorValues(ortValue, (const void**)&arr)); + + return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, sizeBytes); + } + case ORT_SPARSE_UNDEFINED: + default: { + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), + "Sparse format is ORT_SPARSE_UNDEFINED, cannot get data"); + return NULL; + } + } +} + +/* + * Class: ai_onnxruntime_OnnxSparseTensor + * Method: getInnerIndicesShape + * Signature: (JJ)[J; + */ +JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getInnerIndicesShape + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const OrtValue* value = (const OrtValue*) handle; + + // Extract the info + OrtTensorTypeAndShapeInfo* info; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(value, ORT_SPARSE_CSR_INNER_INDICES, &info)); + if (code != ORT_OK) { + return NULL; + } + + // Extract the shape + size_t numDim = 0; + code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &numDim)); + if (code != ORT_OK) { + api->ReleaseTensorTypeAndShapeInfo(info); + return NULL; + } + int64_t* dimensions = malloc(sizeof(int64_t) * numDim); + if (dimensions == NULL) { + throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array"); + api->ReleaseTensorTypeAndShapeInfo(info); + return NULL; + } + code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim)); + // Free the info + api->ReleaseTensorTypeAndShapeInfo(info); + if (code != ORT_OK) { + free((void*)dimensions); + return NULL; + } + + // Create the long array for the shape. + jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim)); + (*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions); + + // Free the dimensions array + free((void*)dimensions); + + return shape; +} + +/* + * Class: ai_onnxruntime_OnnxSparseTensor + * Method: getIndicesShape + * Signature: (JJ)[J; + */ +JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getIndicesShape + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. 
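+  // Resolve the indices format for this tensor, read its dimensions, and copy them into a
+  // new jlongArray.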
+ const OrtApi* api = (const OrtApi*) apiHandle; + const OrtValue* value = (const OrtValue*) handle; + + // Get the indices format + OrtSparseFormat format; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSparseTensorFormat(value, &format)); + if (code != ORT_OK) { + return NULL; + } + enum OrtSparseIndicesFormat indicesFormat; + switch (format) { + case ORT_SPARSE_CSRC: + indicesFormat = ORT_SPARSE_CSR_OUTER_INDICES; + break; + case ORT_SPARSE_COO: + indicesFormat = ORT_SPARSE_COO_INDICES; + break; + case ORT_SPARSE_BLOCK_SPARSE: + indicesFormat = ORT_SPARSE_BLOCK_SPARSE_INDICES; + break; + case ORT_SPARSE_UNDEFINED: + default: { + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), + "Sparse format is ORT_SPARSE_UNDEFINED, indices are not defined."); + return NULL; + } + } + + // Extract the info + OrtTensorTypeAndShapeInfo* info; + code = checkOrtStatus(jniEnv, api, api->GetSparseTensorIndicesTypeShape(value, indicesFormat, &info)); + if (code != ORT_OK) { + return NULL; + } + + // Extract the shape + size_t numDim = 0; + code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &numDim)); + if (code != ORT_OK) { + api->ReleaseTensorTypeAndShapeInfo(info); + return NULL; + } + int64_t* dimensions = malloc(sizeof(int64_t) * numDim); + if (dimensions == NULL) { + throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array"); + api->ReleaseTensorTypeAndShapeInfo(info); + return NULL; + } + code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim)); + // Free the info + api->ReleaseTensorTypeAndShapeInfo(info); + if (code != ORT_OK) { + free((void*)dimensions); + return NULL; + } + + // Create the long array for the shape. + jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim)); + (*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions); + // Free the dimensions array + free((void*)dimensions); + + return shape; +} + +/* + * Class: ai_onnxruntime_OnnxSparseTensor + * Method: getValuesShape + * Signature: (JJ)[J; + */ +JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxSparseTensor_getValuesShape + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const OrtValue* value = (const OrtValue*) handle; + + // Extract the info + OrtTensorTypeAndShapeInfo* info; + OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetSparseTensorValuesTypeAndShape(value,&info)); + if (code != ORT_OK) { + return NULL; + } + + // Extract the shape + size_t numDim = 0; + code = checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&numDim)); + if (code != ORT_OK) { + api->ReleaseTensorTypeAndShapeInfo(info); + return NULL; + } + int64_t* dimensions = malloc(sizeof(int64_t)*numDim); + if (dimensions == NULL) { + throwOrtException(jniEnv, convertErrorCode(ORT_FAIL), "Out of memory when trying to allocate dimensions array"); + api->ReleaseTensorTypeAndShapeInfo(info); + return NULL; + } + code = checkOrtStatus(jniEnv,api,api->GetDimensions(info, dimensions, numDim)); + // Free the info + api->ReleaseTensorTypeAndShapeInfo(info); + if (code != ORT_OK) { + free((void*)dimensions); + return NULL; + } + + // Create the long array for the shape. 
+ jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim)); + (*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions); + + // Free the dimensions array + free((void*)dimensions); + + return shape; +} + +/* + * Class: ai_onnxruntime_OnnxSparseTensor + * Method: close + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxSparseTensor_close(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void) jniEnv; (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + api->ReleaseValue((OrtValue*)handle); +} + +/* + * Class: ai_onnxruntime_OnnxSparseTensor + * Method: createCSRCSparseTensorFromBuffer + * Signature: (JJLjava/nio/Buffer;IJLjava/nio/Buffer;IJLjava/nio/Buffer;IJ[J[JI)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxSparseTensor_createCSRCSparseTensorFromBuffer + (JNIEnv * jniEnv, jclass cls, jlong apiHandle, jlong allocatorHandle, + jobject indicesBuffer, jint indicesBufferPos, jlong indicesBufferSize, + jobject innerIndicesBuffer, jint innerIndicesBufferPos, jlong innerIndicesBufferSize, + jobject dataBuffer, jint dataBufferPos, + jlongArray denseShape, jlongArray valuesShape, + jint onnxTypeJava) { + (void) cls; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; + const OrtMemoryInfo* allocatorInfo; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo)); + if (code != ORT_OK) { + return 0; + } + + // Convert types to ONNX C enums + ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava); + + // Extract the buffers + char* indicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, indicesBuffer); + char* innerIndicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, innerIndicesBuffer); + char* dataBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, dataBuffer); + // Increment by bufferPos bytes + indicesBufferArr = indicesBufferArr + indicesBufferPos; + innerIndicesBufferArr = innerIndicesBufferArr + innerIndicesBufferPos; + dataBufferArr = dataBufferArr + dataBufferPos; + + // Extract the dense shape information + jboolean mkCopy; + jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, denseShape, &mkCopy); + jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, denseShape); + + // Extract the value shape + jlong* valuesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, valuesShape, &mkCopy); + jsize valuesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, valuesShape); + + // Create the OrtValue + OrtValue* ortValue = NULL; + code = checkOrtStatus(jniEnv, api, api->CreateSparseTensorWithValuesAsOrtValue(allocatorInfo, dataBufferArr, + (int64_t*) shapeArr, shapeLen, (int64_t*) valuesShapeArr, valuesShapeLen, onnxType, &ortValue)); + // Release shapes + (*jniEnv)->ReleaseLongArrayElements(jniEnv, denseShape, shapeArr, JNI_ABORT); + (*jniEnv)->ReleaseLongArrayElements(jniEnv, valuesShape, valuesShapeArr, JNI_ABORT); + if (code != ORT_OK) { + return 0; + } + + // Fill it with indices + code = checkOrtStatus(jniEnv, api, api->UseCsrIndices(ortValue, + (int64_t *) innerIndicesBufferArr, innerIndicesBufferSize, + (int64_t *) indicesBufferArr, indicesBufferSize)); + if (code != ORT_OK) { + return 0; + } else { + // Return the pointer to 
the OrtValue + return (jlong) ortValue; + } +} + +/* + * Class: ai_onnxruntime_OnnxSparseTensor + * Method: createSparseTensorFromBuffer + * Signature: (JJLjava/nio/Buffer;IJLjava/nio/Buffer;IJ[J[J[JII)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxSparseTensor_createSparseTensorFromBuffer + (JNIEnv * jniEnv, jclass cls, jlong apiHandle, jlong allocatorHandle, + jobject indicesBuffer, jint indicesBufferPos, jlong indicesBufferSize, + jobject dataBuffer, jint dataBufferPos, + jlongArray denseShape, jlongArray indicesShape, jlongArray valuesShape, + jint onnxTypeJava, jint sparsityTypeJava) { + (void) cls; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; + const OrtMemoryInfo* allocatorInfo; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo)); + if (code != ORT_OK) { + return 0; + } + + // Convert types to ONNX C enums + ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava); + OrtSparseFormat sparsityType = convertToOrtSparseFormat(sparsityTypeJava); + + // Extract the buffers + char* indicesBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, indicesBuffer); + char* dataBufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, dataBuffer); + // Increment by bufferPos bytes + indicesBufferArr = indicesBufferArr + indicesBufferPos; + dataBufferArr = dataBufferArr + dataBufferPos; + + // Extract the dense shape information + jboolean mkCopy; + jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, denseShape, &mkCopy); + jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, denseShape); + + // Extract the value shape + jlong* valuesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, valuesShape, &mkCopy); + jsize valuesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, valuesShape); + + // Create the OrtValue + OrtValue* ortValue = NULL; + code = checkOrtStatus(jniEnv, api, api->CreateSparseTensorWithValuesAsOrtValue(allocatorInfo, dataBufferArr, + (int64_t*) shapeArr, shapeLen, (int64_t*) valuesShapeArr, valuesShapeLen, onnxType, &ortValue)); + + // Release shapes + (*jniEnv)->ReleaseLongArrayElements(jniEnv, denseShape, shapeArr, JNI_ABORT); + (*jniEnv)->ReleaseLongArrayElements(jniEnv, valuesShape, valuesShapeArr, JNI_ABORT); + if (code != ORT_OK) { + return 0; + } + + // Fill it with indices + switch (sparsityType) { + case ORT_SPARSE_COO: { + // The cast is because we compute the offset in bytes in Java. + code = checkOrtStatus(jniEnv, api, api->UseCooIndices(ortValue, (int64_t *) indicesBufferArr, + indicesBufferSize)); + break; + } + case ORT_SPARSE_BLOCK_SPARSE: { + // Extract the indices shape + jlong* indicesShapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, indicesShape, &mkCopy); + jsize indicesShapeLen = (*jniEnv)->GetArrayLength(jniEnv, indicesShape); + + // The cast is because we compute the offset in bytes in Java. 
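+      // Unlike the int64 COO and CSR index buffers above, block sparse indices are 32-bit ints, hence the int32_t cast.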
+ code = checkOrtStatus(jniEnv, api, api->UseBlockSparseIndices(ortValue, (int64_t *) indicesShapeArr, + indicesShapeLen, (int32_t *) indicesBufferArr)); + + // Release the indices shape + (*jniEnv)->ReleaseLongArrayElements(jniEnv, indicesShape, indicesShapeArr, JNI_ABORT); + break; + } + case ORT_SPARSE_CSRC: + case ORT_SPARSE_UNDEFINED: { + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), + "These types are unsupported by this method - ORT_SPARSE_CSRC, ORT_SPARSE_UNDEFINED"); + code = ORT_NOT_IMPLEMENTED; + } + } + if (code != ORT_OK) { + return 0; + } else { + // Return the pointer to the OrtValue + return (jlong) ortValue; + } +} diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index d6576548655c0..2c9fca894ce6c 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -59,6 +59,10 @@ public class InferenceTest { private static final OrtEnvironment env = OrtEnvironment.getEnvironment(); + public static Path getResourcePath(String path) { + return new File(InferenceTest.class.getResource(path).getFile()).toPath(); + } + @Test public void environmentTest() { // Checks that the environment instance is the same. diff --git a/java/src/test/java/ai/onnxruntime/SparseTensorTest.java b/java/src/test/java/ai/onnxruntime/SparseTensorTest.java new file mode 100644 index 0000000000000..7a3abb7bf78ff --- /dev/null +++ b/java/src/test/java/ai/onnxruntime/SparseTensorTest.java @@ -0,0 +1,450 @@ +/* + * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import static ai.onnxruntime.InferenceTest.getResourcePath; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.nio.LongBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; + +public class SparseTensorTest { + + @Test + public void testCSRC() throws OrtException { + String modelPath = getResourcePath("/generic_sparse_to_dense_matmul.onnx").toString(); + try (OrtEnvironment env = OrtEnvironment.getEnvironment(); + OrtSession.SessionOptions options = new OrtSession.SessionOptions()) { + try (OrtSession session = env.createSession(modelPath, options)) { + Map inputMap = new HashMap<>(); + + OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3); + long[] shape = new long[] {3, 3}; + /* + * Sparse matrix: + * [ + * 0 1 0 + * 1 0 1 + * 4 0 6 + * ] + */ + LongBuffer outerIndices = + ByteBuffer.allocateDirect(4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer(); + outerIndices.put(0); + outerIndices.put(1); + outerIndices.put(3); + outerIndices.put(5); + outerIndices.rewind(); + LongBuffer innerIndices = + ByteBuffer.allocateDirect(5 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer(); + innerIndices.put(1); + innerIndices.put(0); + innerIndices.put(2); + innerIndices.put(0); + innerIndices.put(2); + innerIndices.rewind(); + + FloatBuffer data = + ByteBuffer.allocateDirect(5 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); + data.put(1); + data.put(1); + 
data.put(1); + data.put(4); + data.put(6); + data.rewind(); + + OnnxSparseTensor.CSRCTensor csrcTensor = + new OnnxSparseTensor.CSRCTensor( + outerIndices, innerIndices, data, shape, OnnxJavaType.FLOAT, 5); + OnnxSparseTensor tensor = OnnxSparseTensor.createSparseTensor(env, csrcTensor); + + inputMap.put("sparse_A", tensor); + inputMap.put("dense_B", denseIdMatrix); + + OrtSession.Result result = session.run(inputMap); + + OnnxTensor outputTensor = (OnnxTensor) result.get(0); + assertArrayEquals(shape, outputTensor.getInfo().getShape()); + float[] output = outputTensor.getFloatBuffer().array(); + float[] expected = new float[] {0, 1, 0, 1, 0, 1, 4, 0, 6}; + assertArrayEquals(expected, output, 1e-6f); + result.close(); + inputMap.clear(); + + // check that the get methods return new buffers which exist past the tensor lifetime. + Buffer valuesOne = tensor.getValuesBuffer(); + Buffer valuesTwo = tensor.getValuesBuffer(); + Buffer indicesOne = tensor.getIndicesBuffer(); + Buffer indicesTwo = tensor.getIndicesBuffer(); + Buffer innerIndicesOne = tensor.getInnerIndicesBuffer(); + Buffer innerIndicesTwo = tensor.getInnerIndicesBuffer(); + tensor.close(); + assertEquals(valuesOne, valuesTwo); + assertFalse(valuesOne == valuesTwo); + assertEquals(indicesOne, indicesTwo); + assertFalse(indicesOne == indicesTwo); + assertEquals(innerIndicesOne, innerIndicesTwo); + assertFalse(innerIndicesOne == innerIndicesTwo); + + long[] rectangularShape = new long[] {2, 3}; + /* + * Sparse matrix: + * [ + * 1 0 3 + * 0 5 6 + * ] + */ + outerIndices = + ByteBuffer.allocateDirect(3 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer(); + outerIndices.put(0); + outerIndices.put(2); + outerIndices.put(4); + outerIndices.rewind(); + innerIndices = + ByteBuffer.allocateDirect(4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer(); + innerIndices.put(0); + innerIndices.put(2); + innerIndices.put(1); + innerIndices.put(2); + innerIndices.rewind(); + + data = ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); + data.put(1); + data.put(3); + data.put(5); + data.put(6); + data.rewind(); + + csrcTensor = + new OnnxSparseTensor.CSRCTensor( + outerIndices, innerIndices, data, rectangularShape, OnnxJavaType.FLOAT, 4); + tensor = OnnxSparseTensor.createSparseTensor(env, csrcTensor); + + assertArrayEquals(new long[] {3}, tensor.getIndicesShape()); + assertArrayEquals(new long[] {4}, tensor.getInnerIndicesShape()); + assertArrayEquals(new long[] {4}, tensor.getValuesShape()); + + inputMap.put("sparse_A", tensor); + inputMap.put("dense_B", denseIdMatrix); + + result = session.run(inputMap); + + outputTensor = (OnnxTensor) result.get(0); + assertArrayEquals(rectangularShape, outputTensor.getInfo().getShape()); + output = outputTensor.getFloatBuffer().array(); + expected = new float[] {1, 0, 3, 0, 5, 6}; + assertArrayEquals(expected, output, 1e-6f); + result.close(); + tensor.close(); + inputMap.clear(); + denseIdMatrix.close(); + + denseIdMatrix = makeIdentityMatrix(env, 4); + long[] vectorShape = new long[] {1, 4}; + /* + * Sparse matrix: + * [ + * 1 0 0 4 + * ] + */ + outerIndices = + ByteBuffer.allocateDirect(2 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer(); + outerIndices.put(0); + outerIndices.put(2); + outerIndices.rewind(); + innerIndices = + ByteBuffer.allocateDirect(2 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer(); + innerIndices.put(0); + innerIndices.put(3); + innerIndices.rewind(); + + data = ByteBuffer.allocateDirect(2 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); + data.put(1); + 
data.put(4); + data.rewind(); + + csrcTensor = + new OnnxSparseTensor.CSRCTensor( + outerIndices, innerIndices, data, vectorShape, OnnxJavaType.FLOAT, 2); + tensor = OnnxSparseTensor.createSparseTensor(env, csrcTensor); + + assertArrayEquals(new long[] {2}, tensor.getIndicesShape()); + assertArrayEquals(new long[] {2}, tensor.getInnerIndicesShape()); + assertArrayEquals(new long[] {2}, tensor.getValuesShape()); + + inputMap.put("sparse_A", tensor); + inputMap.put("dense_B", denseIdMatrix); + + result = session.run(inputMap); + + outputTensor = (OnnxTensor) result.get(0); + assertArrayEquals(vectorShape, outputTensor.getInfo().getShape()); + output = outputTensor.getFloatBuffer().array(); + expected = new float[] {1, 0, 0, 4}; + assertArrayEquals(expected, output, 1e-6f); + result.close(); + tensor.close(); + inputMap.clear(); + denseIdMatrix.close(); + } + } + } + + @Test + public void testCOO() throws OrtException { + String modelPath = getResourcePath("/generic_sparse_to_dense_matmul.onnx").toString(); + try (OrtEnvironment env = OrtEnvironment.getEnvironment(); + OrtSession.SessionOptions options = new OrtSession.SessionOptions()) { + try (OrtSession session = env.createSession(modelPath, options)) { + Map inputMap = new HashMap<>(); + + OnnxTensor denseIdMatrix = makeIdentityMatrix(env, 3); + long[] shape = new long[] {3, 3}; + /* + * Sparse matrix: + * [ + * 0 1 0 + * 1 0 1 + * 4 0 6 + * ] + */ + LongBuffer indices = + ByteBuffer.allocateDirect(2 * 5 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer(); + indices.put(0); + indices.put(1); + indices.put(1); + indices.put(0); + indices.put(1); + indices.put(2); + indices.put(2); + indices.put(0); + indices.put(2); + indices.put(2); + indices.rewind(); + + FloatBuffer data = + ByteBuffer.allocateDirect(5 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); + data.put(1); + data.put(1); + data.put(1); + data.put(4); + data.put(6); + data.rewind(); + + OnnxSparseTensor.COOTensor cooTensor = + new OnnxSparseTensor.COOTensor( + indices, new long[] {5, 2}, data, shape, OnnxJavaType.FLOAT, 5); + OnnxSparseTensor tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor); + + inputMap.put("sparse_A", tensor); + inputMap.put("dense_B", denseIdMatrix); + + OrtSession.Result result = session.run(inputMap); + + OnnxTensor outputTensor = (OnnxTensor) result.get(0); + assertArrayEquals(shape, outputTensor.getInfo().getShape()); + float[] output = outputTensor.getFloatBuffer().array(); + float[] expected = new float[] {0, 1, 0, 1, 0, 1, 4, 0, 6}; + assertArrayEquals(expected, output, 1e-6f); + result.close(); + tensor.close(); + inputMap.clear(); + + /* disabled as sparse_dense_matmul doesn't support COO tensors with 1d indices + // Run the same tensor through, but using 1d indexing rather than 2d indexing + indices = ByteBuffer.allocateDirect(5 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer(); + indices.put(1); + indices.put(3); + indices.put(5); + indices.put(6); + indices.put(8); + indices.rewind(); + + cooTensor = new OnnxSparseTensor.COOTensor(indices, new long[]{5}, data, shape, OnnxJavaType.FLOAT, 5); + tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor); + + inputMap.put("sparse_A", tensor); + inputMap.put("dense_B", denseIdMatrix); + + result = session.run(inputMap); + + outputTensor = (OnnxTensor) result.get(0); + assertArrayEquals(shape, outputTensor.getInfo().getShape()); + output = outputTensor.getFloatBuffer().array(); + assertArrayEquals(expected, output, 1e-6f); + result.close(); + tensor.close(); + inputMap.clear(); + */ + 
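+        // As in the square case above, the 2-d COO indices are (row, col) pairs stored as an [nnz, 2] tensor.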
+ long[] rectangularShape = new long[] {2, 3}; + /* + * Sparse matrix: + * [ + * 1 0 3 + * 0 5 6 + * ] + */ + indices = + ByteBuffer.allocateDirect(2 * 4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer(); + indices.put(0); + indices.put(0); + indices.put(0); + indices.put(2); + indices.put(1); + indices.put(1); + indices.put(1); + indices.put(2); + indices.rewind(); + + data = ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); + data.put(1); + data.put(3); + data.put(5); + data.put(6); + data.rewind(); + + cooTensor = + new OnnxSparseTensor.COOTensor( + indices, new long[] {4, 2}, data, rectangularShape, OnnxJavaType.FLOAT, 4); + tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor); + + assertArrayEquals(new long[] {4, 2}, tensor.getIndicesShape()); + assertArrayEquals(new long[] {4}, tensor.getValuesShape()); + + inputMap.put("sparse_A", tensor); + inputMap.put("dense_B", denseIdMatrix); + + result = session.run(inputMap); + + outputTensor = (OnnxTensor) result.get(0); + assertArrayEquals(rectangularShape, outputTensor.getInfo().getShape()); + output = outputTensor.getFloatBuffer().array(); + expected = new float[] {1, 0, 3, 0, 5, 6}; + assertArrayEquals(expected, output, 1e-6f); + result.close(); + tensor.close(); + inputMap.clear(); + denseIdMatrix.close(); + + denseIdMatrix = makeIdentityMatrix(env, 4); + long[] vectorShape = new long[] {1, 4}; + /* + * Sparse matrix: + * [ + * 1 + * 0 + * 0 + * 4 + * ] + */ + indices = ByteBuffer.allocateDirect(4 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer(); + indices.put(0); + indices.put(0); + indices.put(0); + indices.put(3); + indices.rewind(); + + data = ByteBuffer.allocateDirect(2 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); + data.put(1); + data.put(4); + data.rewind(); + + cooTensor = + new OnnxSparseTensor.COOTensor( + indices, new long[] {2, 2}, data, vectorShape, OnnxJavaType.FLOAT, 2); + tensor = OnnxSparseTensor.createSparseTensor(env, cooTensor); + + assertArrayEquals(new long[] {2, 2}, tensor.getIndicesShape()); + assertArrayEquals(new long[] {2}, tensor.getValuesShape()); + + inputMap.put("sparse_A", tensor); + inputMap.put("dense_B", denseIdMatrix); + + result = session.run(inputMap); + + outputTensor = (OnnxTensor) result.get(0); + assertArrayEquals(vectorShape, outputTensor.getInfo().getShape()); + output = outputTensor.getFloatBuffer().array(); + expected = new float[] {1, 0, 0, 4}; + assertArrayEquals(expected, output, 1e-6f); + result.close(); + tensor.close(); + inputMap.clear(); + denseIdMatrix.close(); + } + } + } + + @Test + public void testCOOOutput() throws OrtException { + String modelPath = getResourcePath("/sparse_initializer_as_output.onnx").toString(); + try (OrtEnvironment env = OrtEnvironment.getEnvironment(); + OrtSession.SessionOptions options = new OrtSession.SessionOptions()) { + try (OrtSession session = env.createSession(modelPath, options)) { + Map outputs = session.getOutputInfo(); + assertEquals(1, outputs.size()); + + NodeInfo info = outputs.get("values"); + assertNotNull(info); + assertTrue(info.getInfo() instanceof TensorInfo); + + TensorInfo outputInfo = (TensorInfo) info.getInfo(); + assertArrayEquals(new long[] {3, 3}, outputInfo.getShape()); + assertEquals( + TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, outputInfo.onnxType); + assertEquals(OnnxJavaType.FLOAT, outputInfo.type); + + OrtSession.Result result = session.run(Collections.emptyMap()); + OnnxValue output = result.get("values").get(); + + assertTrue(output instanceof 
OnnxSparseTensor); + + OnnxSparseTensor sparseTensor = (OnnxSparseTensor) output; + + assertEquals(OnnxSparseTensor.SparseTensorType.COO, sparseTensor.getSparseTensorType()); + + assertArrayEquals(new long[] {3}, sparseTensor.getIndicesShape()); + assertArrayEquals(new long[] {3}, sparseTensor.getValuesShape()); + assertArrayEquals(new long[] {3, 3}, sparseTensor.getInfo().getShape()); + + OnnxSparseTensor.SparseTensor javaTensor = sparseTensor.getValue(); + + assertTrue(javaTensor instanceof OnnxSparseTensor.COOTensor); + + OnnxSparseTensor.COOTensor cooTensor = (OnnxSparseTensor.COOTensor) javaTensor; + + long[] indices = new long[3]; + cooTensor.getIndices().get(indices); + float[] data = new float[3]; + ((FloatBuffer) cooTensor.getValues()).get(data); + + assertArrayEquals(new long[] {2, 3, 5}, indices); + assertArrayEquals( + new float[] {1.764052391052246f, 0.40015721321105957f, 0.978738009929657f}, data); + } + } + } + + private static OnnxTensor makeIdentityMatrix(OrtEnvironment env, int size) throws OrtException { + float[][] values = new float[size][size]; + for (int i = 0; i < values.length; i++) { + values[i][i] = 1.0f; + } + + return OnnxTensor.createTensor(env, values); + } +} diff --git a/java/testdata/generic_sparse_to_dense_matmul.onnx b/java/testdata/generic_sparse_to_dense_matmul.onnx new file mode 100644 index 0000000000000..3de7973016f79 --- /dev/null +++ b/java/testdata/generic_sparse_to_dense_matmul.onnx @@ -0,0 +1,16 @@ +dmitrism:Î +F +sparse_A +dense_Bdense_YSpMM"SparseToDenseMatMul: com.microsoftSpMMZ* +sparse_AB + A_dim_1 +  inner_dimZ) +dense_B + +  inner_dim + B_dim_2b' +dense_Y + + A_dim_1 + B_dim_2B + com.microsoft \ No newline at end of file
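Illustrative usage sketch (not part of the change set above), condensed from SparseTensorTest: it builds a COO sparse input, pairs it with a dense identity matrix, and runs the SparseToDenseMatMul test model. The model path is a placeholder, the input names "sparse_A"/"dense_B" come from the test model added in this patch, and typing the input map as Map<String, OnnxTensorLike> assumes the OrtSession.run signature this change introduces.

import ai.onnxruntime.*;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class SparseUsageSketch {
  public static void main(String[] args) throws OrtException {
    try (OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        OrtSession session =
            env.createSession("java/testdata/generic_sparse_to_dense_matmul.onnx", options)) {
      // Two non-zeros of a 3x3 matrix in COO form: 1.0f at (0,1) and 4.0f at (2,0).
      LongBuffer indices =
          ByteBuffer.allocateDirect(2 * 2 * 8).order(ByteOrder.LITTLE_ENDIAN).asLongBuffer();
      indices.put(0).put(1);
      indices.put(2).put(0);
      indices.rewind();
      FloatBuffer values =
          ByteBuffer.allocateDirect(2 * 4).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
      values.put(1.0f).put(4.0f);
      values.rewind();

      // Indices shape is [nnz, 2], the dense shape is [3, 3], and there are 2 non-zero values.
      OnnxSparseTensor.COOTensor coo =
          new OnnxSparseTensor.COOTensor(
              indices, new long[] {2, 2}, values, new long[] {3, 3}, OnnxJavaType.FLOAT, 2);
      OnnxSparseTensor sparse = OnnxSparseTensor.createSparseTensor(env, coo);
      OnnxTensor denseId =
          OnnxTensor.createTensor(env, new float[][] {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}});

      Map<String, OnnxTensorLike> inputs = new HashMap<>();
      inputs.put("sparse_A", sparse);
      inputs.put("dense_B", denseId);

      OrtSession.Result result = session.run(inputs);
      OnnxTensor output = (OnnxTensor) result.get(0);
      // Multiplying by the identity densifies the sparse input: [0, 1, 0, 0, 0, 0, 4, 0, 0].
      System.out.println(Arrays.toString(output.getFloatBuffer().array()));

      result.close();
      denseId.close();
      sparse.close();
    }
  }
}

A CSRC input is built analogously via OnnxSparseTensor.CSRCTensor with separate outer and inner index buffers, as the first test in SparseTensorTest shows.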