Skip to content

Commit

Permalink
[java] Fp16 fix for android/react native (microsoft#16832)
Browse files Browse the repository at this point in the history
### Description
This PR splits out the FP16 conversions into a separate package we can
override in the android build with a version which works on old versions
of Android.

I'm not sure the android build system changes are correct as I haven't
got an android build environment configured on my workstation.
@YUNQIUGUO if the CI build fails we should follow up offline to get my
environment configured so I can iterate on it.

### Motivation and Context
Fixes the CI failure after microsoft#16703.
  • Loading branch information
Craigacp authored Jul 25, 2023
1 parent e01365f commit a1bb670
Show file tree
Hide file tree
Showing 10 changed files with 577 additions and 284 deletions.
3 changes: 3 additions & 0 deletions java/build-android.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ android {
sourceSets {
main {
jniLibs.srcDirs = [jniLibsDir]
java {
srcDirs = ['src/main/java', 'src/main/android']
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ compileTestJava {
}
}

sourceSets.main.java {
srcDirs = ['src/main/java', 'src/main/jvm']
}

sourceSets.test {
// add test resource files
resources.srcDirs += [
Expand Down
237 changes: 237 additions & 0 deletions java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.platform;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.ShortBuffer;
import java.util.logging.Level;
import java.util.logging.Logger;

/** * Conversions between fp16, bfloat16 and fp32. */
public final class Fp16Conversions {
private static final Logger logger = Logger.getLogger(Fp16Conversions.class.getName());

/**
* Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java).
*
* <p>Respects the position and limit of the input buffer.
*
* @param buf The buffer of floats.
* @return A buffer of fp16 values stored as shorts.
*/
public static ShortBuffer convertFloatBufferToFp16Buffer(FloatBuffer buf) {
int pos = buf.position();
int remaining = buf.remaining();
ShortBuffer output =
ByteBuffer.allocateDirect(remaining * 2).order(ByteOrder.nativeOrder()).asShortBuffer();
for (int i = 0; i < remaining; i++) {
output.put(i, floatToFp16(buf.get(i + pos)));
}
return output;
}

/**
* Casts a buffer of fp16 values stored as shorts into a buffer of floats.
*
* <p>Respects the position and limit of the input buffer.
*
* @param buf The buffer of fp16 values stored as shorts.
* @return A buffer of float values.
*/
public static FloatBuffer convertFp16BufferToFloatBuffer(ShortBuffer buf) {
int pos = buf.position();
int remaining = buf.remaining();
FloatBuffer output =
ByteBuffer.allocateDirect(remaining * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
for (int i = 0; i < remaining; i++) {
output.put(i, fp16ToFloat(buf.get(i + pos)));
}
return output;
}

/**
* Rounds a buffer of floats into a buffer containing bf16 values (stored as shorts in Java).
*
* <p>Respects the position and limit of the input buffer.
*
* @param buf The buffer of floats.
* @return A buffer of bf16 values stored as shorts.
*/
public static ShortBuffer convertFloatBufferToBf16Buffer(FloatBuffer buf) {
int pos = buf.position();
int remaining = buf.remaining();
ShortBuffer output =
ByteBuffer.allocateDirect(remaining * 2).order(ByteOrder.nativeOrder()).asShortBuffer();
for (int i = 0; i < remaining; i++) {
output.put(i, floatToBf16(buf.get(i + pos)));
}
return output;
}

/**
* Casts a buffer of bf16 values stored as shorts into a buffer of floats.
*
* <p>Respects the position and limit of the input buffer.
*
* @param buf The buffer of bf16 values stored as shorts.
* @return A buffer of float values.
*/
public static FloatBuffer convertBf16BufferToFloatBuffer(ShortBuffer buf) {
int pos = buf.position();
int remaining = buf.remaining();
FloatBuffer output =
ByteBuffer.allocateDirect(remaining * 4).order(ByteOrder.nativeOrder()).asFloatBuffer();
for (int i = 0; i < remaining; i++) {
output.put(i, bf16ToFloat(buf.get(i + pos)));
}
return output;
}

/**
* Converts a fp16 value stored in a short into a float value.
*
* <p>On Android this is an alias for {@link #mlasFp16ToFloat(short)}.
*
* @param input The fp16 value.
* @return The float value.
*/
public static float fp16ToFloat(short input) {
return mlasFp16ToFloat(input);
}

/**
* Converts a float value into a fp16 value stored in a short.
*
* <p>On Android this is an alias for {@link #mlasFloatToFp16(float)}.
*
* @param input The float value.
* @return The fp16 value.
*/
public static short floatToFp16(float input) {
return mlasFloatToFp16(input);
}

/**
* Upcasts a fp16 value to a float. Mirrors the conversion in MLAS.
*
* @param input A uint16_t representing an IEEE half precision float.
* @return A float.
*/
public static float mlasFp16ToFloat(short input) {
// Port of MLAS_Half2Float from onnxruntime/core/mlas/inc/mlas_float16.h
final int MAGIC = 113 << 23;
// exponent mask after shift
final int SHIFTED_EXP = 0x7c00 << 13;

// exponent/mantissa bits
int bits = (input & 0x7fff) << 13;
// just the exponent
final int exp = SHIFTED_EXP & bits;
// exponent adjust
bits += (127 - 15) << 23;

// handle exponent special cases
if (exp == SHIFTED_EXP) {
// Inf/NaN?
// extra exp adjust
bits += (128 - 16) << 23;
} else if (exp == 0) {
// Zero/Denormal?
// extra exp adjust
bits += (1 << 23);
// renormalize
float tmp = Float.intBitsToFloat(bits) - Float.intBitsToFloat(MAGIC);
bits = Float.floatToIntBits(tmp);
}

// sign bit
bits |= (input & 0x8000) << 16;

return Float.intBitsToFloat(bits);
}

/**
* Rounds a float value to fp16. Mirrors the conversion in MLAS.
*
* @param input A float value.
* @return The value rounded to an IEEE half precision value.
*/
public static short mlasFloatToFp16(float input) {
// Port of MLAS_Float2Half from onnxruntime/core/mlas/inc/mlas_float16.h
int bits = Float.floatToIntBits(input);
final int F32_INFINITY = Float.floatToIntBits(Float.POSITIVE_INFINITY);
final int F16_MAX = (127 + 16) << 23;
final int DENORM_MAGIC = ((127 - 15) + (23 - 10) + 1) << 23;
final int SIGN_MASK = 0x80000000;
final int ROUNDING_CONST = ((15 - 127) << 23) + 0xfff;

int sign = bits & SIGN_MASK;
// mask out sign bit
bits ^= sign;

short output;
if (bits >= F16_MAX) {
// Inf or NaN (all exponent bits set)
output = (bits > F32_INFINITY) ? (short) 0x7e00 : (short) 0x7c00;
} else {
if (bits < (113 << 23)) {
// Subnormal or zero
// use a magic value to align our 10 mantissa bits at the bottom of
// the float. as long as FP addition is round-to-nearest-even this
// just works.
float tmp = Float.intBitsToFloat(bits) + Float.intBitsToFloat(DENORM_MAGIC);

// and one integer subtract of the bias later, we have our final float!
output = (short) (Float.floatToIntBits(tmp) - DENORM_MAGIC);
} else {
int mant_odd = (bits >> 13) & 1; // resulting mantissa is odd

// update exponent, rounding bias part 1
bits += ROUNDING_CONST;
// rounding bias part 2
bits += mant_odd;
// take the bits!
output = (short) (bits >> 13);
}
}

// Add the sign back in
output = (short) (output | ((short) (sign >> 16)));

return output;
}

/**
* Converts a bf16 value stored in a short into a float value.
*
* @param input A uint16_t representing a bfloat16 value.
* @return A float.
*/
public static float bf16ToFloat(short input) {
int bits = input << 16;
return Float.intBitsToFloat(bits);
}

/**
* Converts a float into bf16. May not produce correct values for subnormal floats.
*
* <p>Rounds to nearest even.
*
* @param input The float input.
* @return A bfloat16 value which is closest to the float.
*/
public static short floatToBf16(float input) {
int bits = Float.floatToIntBits(input);
int lsb = (bits >> 16) & 1;
int roundingBias = 0x7fff + lsb;
bits += roundingBias;
return (short) (bits >> 16);
}
}
13 changes: 13 additions & 0 deletions java/src/main/android/ai/onnxruntime/platform/package-info.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/*
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/

/**
* A package of platform specific code, used to swap out Java implementations which don't run on Android.
*
* <p>Classes in this package should always have the same public methods.
*
* <p>This is the Android version of the package.
*/
package ai.onnxruntime.platform;
5 changes: 3 additions & 2 deletions java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package ai.onnxruntime;

import ai.onnxruntime.platform.Fp16Conversions;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
Expand Down Expand Up @@ -323,12 +324,12 @@ public Buffer getValuesBuffer() {
case FLOAT16:
{
ShortBuffer shortBuffer = buffer.asShortBuffer();
return OrtUtil.convertFp16BufferToFloatBuffer(shortBuffer);
return Fp16Conversions.convertFp16BufferToFloatBuffer(shortBuffer);
}
case BFLOAT16:
{
ShortBuffer shortBuffer = buffer.asShortBuffer();
return OrtUtil.convertBf16BufferToFloatBuffer(shortBuffer);
return Fp16Conversions.convertBf16BufferToFloatBuffer(shortBuffer);
}
case DOUBLE:
{
Expand Down
9 changes: 5 additions & 4 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package ai.onnxruntime;

import ai.onnxruntime.platform.Fp16Conversions;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
Expand Down Expand Up @@ -72,10 +73,10 @@ public Object getValue() throws OrtException {
case STRING:
return getString(OnnxRuntime.ortApiHandle, nativeHandle);
case FLOAT16:
return OrtUtil.fp16ToFloat(
return Fp16Conversions.fp16ToFloat(
getShort(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value));
case BFLOAT16:
return OrtUtil.bf16ToFloat(
return Fp16Conversions.bf16ToFloat(
getShort(OnnxRuntime.ortApiHandle, nativeHandle, info.onnxType.value));
case UNKNOWN:
default:
Expand Down Expand Up @@ -149,12 +150,12 @@ public FloatBuffer getFloatBuffer() {
// if it's fp16 we need to copy it out by hand.
ByteBuffer buf = getBuffer();
ShortBuffer buffer = buf.asShortBuffer();
return OrtUtil.convertFp16BufferToFloatBuffer(buffer);
return Fp16Conversions.convertFp16BufferToFloatBuffer(buffer);
} else if (info.type == OnnxJavaType.BFLOAT16) {
// if it's bf16 we need to copy it out by hand.
ByteBuffer buf = getBuffer();
ShortBuffer buffer = buf.asShortBuffer();
return OrtUtil.convertBf16BufferToFloatBuffer(buffer);
return Fp16Conversions.convertBf16BufferToFloatBuffer(buffer);
} else {
return null;
}
Expand Down
Loading

0 comments on commit a1bb670

Please sign in to comment.