Skip to content

Commit

Permalink
[java] Sparse tensor support (microsoft#10653)
Browse files Browse the repository at this point in the history
**Description**:

Adds support for creating and receiving sparse tensors in the ORT Java
API.

CSRC and COO tensors are tested as inputs, but there is no op which
accepts a block sparse tensor, so that input path is untested. COO
tensors are tested as outputs, but there is no op which emits a CSRC or
block sparse tensor, so those output paths are untested.

**Motivation and Context**
- Why is this change required? What problem does it solve? Request to
expose ORT sparse tensor support in Java.

cc @yuslepukhin
  • Loading branch information
Craigacp authored Nov 22, 2022
1 parent 8b0e0f4 commit dd2c031
Show file tree
Hide file tree
Showing 13 changed files with 2,218 additions and 106 deletions.
920 changes: 920 additions & 0 deletions java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java

Large diffs are not rendered by default.

99 changes: 10 additions & 89 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
Expand All @@ -18,21 +17,7 @@
* A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be
* returned as outputs.
*/
public class OnnxTensor implements OnnxValue {
static {
try {
OnnxRuntime.init();
} catch (IOException e) {
throw new RuntimeException("Failed to load onnx-runtime library", e);
}
}

private final long nativeHandle;

private final long allocatorHandle;

private final TensorInfo info;

public class OnnxTensor extends OnnxTensorLike {
/**
* This reference is held for OnnxTensors backed by a Java nio buffer to ensure the buffer does
* not go out of scope while the OnnxTensor exists.
Expand All @@ -44,9 +29,7 @@ public class OnnxTensor implements OnnxValue {
}

OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer) {
this.nativeHandle = nativeHandle;
this.allocatorHandle = allocatorHandle;
this.info = info;
super(nativeHandle, allocatorHandle, info);
this.buffer = buffer;
}

Expand All @@ -55,10 +38,6 @@ public OnnxValueType getType() {
return OnnxValueType.ONNX_TYPE_TENSOR;
}

long getNativeHandle() {
return nativeHandle;
}

/**
* Either returns a boxed primitive if the Tensor is a scalar, or a multidimensional array of
* primitives if it has multiple dimensions.
Expand Down Expand Up @@ -108,11 +87,6 @@ public Object getValue() throws OrtException {
}
}

@Override
public TensorInfo getInfo() {
return info;
}

@Override
public String toString() {
return "OnnxTensor(info=" + info.toString() + ")";
Expand Down Expand Up @@ -300,7 +274,7 @@ private native void getArray(long apiHandle, long nativeHandle, Object carrier)
* @param input A uint16_t representing an IEEE half precision float.
* @return A float.
*/
private static float fp16ToFloat(short input) {
static float fp16ToFloat(short input) {
int output =
((input & 0x8000) << 16) | (((input & 0x7c00) + 0x1C000) << 13) | ((input & 0x03FF) << 13);
return Float.intBitsToFloat(output);
Expand Down Expand Up @@ -715,73 +689,20 @@ static OnnxTensor createTensor(
*/
private static OnnxTensor createTensor(
OnnxJavaType type, OrtAllocator allocator, Buffer data, long[] shape) throws OrtException {
int bufferPos;
long bufferSizeLong = data.remaining() * (long) type.size;
if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) {
// The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending
// on the JVM, so we check for something 8 elements below the maximum size which
// should be allocatable (assuming there is enough memory) on all 64-bit JVMs.
throw new IllegalStateException(
"Cannot allocate a direct buffer of the requested size and type, size "
+ data.remaining()
+ ", type = "
+ type);
}
// Now we know we're in range
int bufferSize = data.remaining() * type.size;
Buffer tmp;
if (data.isDirect()) {
tmp = data;
bufferPos = data.position() * type.size;
} else {
// Copy the data to a new direct buffer, then restore the state of the input.
int origPosition = data.position();
ByteBuffer buffer = ByteBuffer.allocateDirect(bufferSize).order(ByteOrder.nativeOrder());
switch (type) {
case FLOAT:
tmp = buffer.asFloatBuffer().put((FloatBuffer) data);
break;
case DOUBLE:
tmp = buffer.asDoubleBuffer().put((DoubleBuffer) data);
break;
case UINT8:
case INT8:
// buffer is already a ByteBuffer, no cast needed.
tmp = buffer.put((ByteBuffer) data);
break;
case INT16:
tmp = buffer.asShortBuffer().put((ShortBuffer) data);
break;
case INT32:
tmp = buffer.asIntBuffer().put((IntBuffer) data);
break;
case INT64:
tmp = buffer.asLongBuffer().put((LongBuffer) data);
break;
case BOOL:
case STRING:
case UNKNOWN:
default:
throw new IllegalStateException(
"Impossible to reach here, managed to cast a buffer as an incorrect type");
}
data.position(origPosition);
tmp.rewind();
bufferPos = 0;
}
TensorInfo info = TensorInfo.constructFromBuffer(tmp, shape, type);
OrtUtil.BufferTuple tuple = OrtUtil.prepareBuffer(data, type);
TensorInfo info = TensorInfo.constructFromBuffer(tuple.data, shape, type);
return new OnnxTensor(
createTensorFromBuffer(
OnnxRuntime.ortApiHandle,
allocator.handle,
tmp,
bufferPos,
bufferSize,
tuple.data,
tuple.pos,
tuple.byteSize,
shape,
info.onnxType.value),
allocator.handle,
info,
tmp);
tuple.data);
}

private static native long createTensor(
Expand Down
59 changes: 59 additions & 0 deletions java/src/main/java/ai/onnxruntime/OnnxTensorLike.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import java.io.IOException;

/**
 * The common base type of {@link OnnxTensor} and {@link OnnxSparseTensor}. May be restricted
 * (sealed) to those two implementations in a future release.
 */
public abstract class OnnxTensorLike implements OnnxValue {
  static {
    try {
      OnnxRuntime.init();
    } catch (IOException e) {
      throw new RuntimeException("Failed to load onnx-runtime library", e);
    }
  }

  /** Native pointer to the value backing this tensor-like object. */
  protected final long nativeHandle;

  /** Native pointer to the memory allocator associated with this object. */
  protected final long allocatorHandle;

  /** The size, shape and type information for this tensor. */
  protected final TensorInfo info;

  /**
   * Initialises the fields shared by OnnxTensor and OnnxSparseTensor.
   *
   * @param nativeHandle Native pointer to the tensor.
   * @param allocatorHandle Native pointer to the memory allocator.
   * @param info The size, shape and type information.
   */
  OnnxTensorLike(long nativeHandle, long allocatorHandle, TensorInfo info) {
    this.info = info;
    this.allocatorHandle = allocatorHandle;
    this.nativeHandle = nativeHandle;
  }

  /**
   * Gets the native pointer backing this value.
   *
   * @return The native pointer.
   */
  long getNativeHandle() {
    return nativeHandle;
  }

  /**
   * Gets the size, shape and type information of this tensor.
   *
   * @return The tensor info.
   */
  @Override
  public TensorInfo getInfo() {
    return info;
  }
}
10 changes: 5 additions & 5 deletions java/src/main/java/ai/onnxruntime/OnnxValue.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2022 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand All @@ -8,9 +8,8 @@

/**
* Top interface for input and output values from ONNX models. Currently implemented by {@link
* OnnxTensor}, {@link OnnxSequence} and {@link OnnxMap}. Will be sealed to these types one day.
*
* <p>Does not support sparse tensors.
* OnnxTensor}, {@link OnnxSparseTensor}, {@link OnnxSequence} and {@link OnnxMap}. Will be sealed
* to these types one day.
*/
public interface OnnxValue extends AutoCloseable {

Expand All @@ -21,7 +20,8 @@ public enum OnnxValueType {
ONNX_TYPE_SEQUENCE(2),
ONNX_TYPE_MAP(3),
ONNX_TYPE_OPAQUE(4),
ONNX_TYPE_SPARSETENSOR(5);
ONNX_TYPE_SPARSETENSOR(5),
ONNX_TYPE_OPTIONAL(6);

/** The id number of this type in the C API. */
public final int value;
Expand Down
17 changes: 10 additions & 7 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ public Map<String, NodeInfo> getOutputInfo() throws OrtException {
* @throws OrtException If there was an error in native code, the input names are invalid, or if
* there are zero or too many inputs.
*/
public Result run(Map<String, OnnxTensor> inputs) throws OrtException {
public Result run(Map<String, ? extends OnnxTensorLike> inputs) throws OrtException {
return run(inputs, outputNames);
}

Expand All @@ -218,7 +218,8 @@ public Result run(Map<String, OnnxTensor> inputs) throws OrtException {
* @throws OrtException If there was an error in native code, the input names are invalid, or if
* there are zero or too many inputs.
*/
public Result run(Map<String, OnnxTensor> inputs, RunOptions runOptions) throws OrtException {
public Result run(Map<String, ? extends OnnxTensorLike> inputs, RunOptions runOptions)
throws OrtException {
return run(inputs, outputNames, runOptions);
}

Expand All @@ -233,15 +234,15 @@ public Result run(Map<String, OnnxTensor> inputs, RunOptions runOptions) throws
* @throws OrtException If there was an error in native code, the input or output names are
* invalid, or if there are zero or too many inputs or outputs.
*/
public Result run(Map<String, OnnxTensor> inputs, Set<String> requestedOutputs)
public Result run(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs)
throws OrtException {
return run(inputs, requestedOutputs, null);
}

/**
* Scores an input feed dict, returning the map of requested inferred outputs.
*
* <p>The outputs are sorted based on the supplied set traveral order.
* <p>The outputs are sorted based on the supplied set traversal order.
*
* @param inputs The inputs to score.
* @param requestedOutputs The requested outputs.
Expand All @@ -251,10 +252,12 @@ public Result run(Map<String, OnnxTensor> inputs, Set<String> requestedOutputs)
* invalid, or if there are zero or too many inputs or outputs.
*/
public Result run(
Map<String, OnnxTensor> inputs, Set<String> requestedOutputs, RunOptions runOptions)
Map<String, ? extends OnnxTensorLike> inputs,
Set<String> requestedOutputs,
RunOptions runOptions)
throws OrtException {
if (!closed) {
if (inputs.isEmpty() || (inputs.size() > numInputs)) {
if ((inputs.isEmpty() && (numInputs != 0)) || (inputs.size() > numInputs)) {
throw new OrtException(
"Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size());
}
Expand All @@ -268,7 +271,7 @@ public Result run(
String[] inputNamesArray = new String[inputs.size()];
long[] inputHandles = new long[inputs.size()];
int i = 0;
for (Map.Entry<String, OnnxTensor> t : inputs.entrySet()) {
for (Map.Entry<String, ? extends OnnxTensorLike> t : inputs.entrySet()) {
if (inputNames.contains(t.getKey())) {
inputNamesArray[i] = t.getKey();
inputHandles[i] = t.getValue().getNativeHandle();
Expand Down
Loading

0 comments on commit dd2c031

Please sign in to comment.