forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[java] Fp16 fix for android/react native (microsoft#16832)
### Description This PR splits out the FP16 conversions into a separate package we can override in the android build with a version which works on old versions of Android. I'm not sure the android build system changes are correct as I haven't got an android build environment configured on my workstation. @YUNQIUGUO if the CI build fails we should follow up offline to get my environment configured so I can iterate on it. ### Motivation and Context Fixes the CI failure after microsoft#16703.
- Loading branch information
Showing
10 changed files
with
577 additions
and
284 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
237 changes: 237 additions & 0 deletions
237
java/src/main/android/ai/onnxruntime/platform/Fp16Conversions.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
/* | ||
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. | ||
* Licensed under the MIT License. | ||
*/ | ||
package ai.onnxruntime.platform; | ||
|
||
import java.lang.invoke.MethodHandle; | ||
import java.lang.invoke.MethodHandles; | ||
import java.lang.invoke.MethodType; | ||
import java.nio.ByteBuffer; | ||
import java.nio.ByteOrder; | ||
import java.nio.FloatBuffer; | ||
import java.nio.ShortBuffer; | ||
import java.util.logging.Level; | ||
import java.util.logging.Logger; | ||
|
||
/** * Conversions between fp16, bfloat16 and fp32. */ | ||
public final class Fp16Conversions { | ||
private static final Logger logger = Logger.getLogger(Fp16Conversions.class.getName()); | ||
|
||
/** | ||
* Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java). | ||
* | ||
* <p>Respects the position and limit of the input buffer. | ||
* | ||
* @param buf The buffer of floats. | ||
* @return A buffer of fp16 values stored as shorts. | ||
*/ | ||
public static ShortBuffer convertFloatBufferToFp16Buffer(FloatBuffer buf) { | ||
int pos = buf.position(); | ||
int remaining = buf.remaining(); | ||
ShortBuffer output = | ||
ByteBuffer.allocateDirect(remaining * 2).order(ByteOrder.nativeOrder()).asShortBuffer(); | ||
for (int i = 0; i < remaining; i++) { | ||
output.put(i, floatToFp16(buf.get(i + pos))); | ||
} | ||
return output; | ||
} | ||
|
||
/** | ||
* Casts a buffer of fp16 values stored as shorts into a buffer of floats. | ||
* | ||
* <p>Respects the position and limit of the input buffer. | ||
* | ||
* @param buf The buffer of fp16 values stored as shorts. | ||
* @return A buffer of float values. | ||
*/ | ||
public static FloatBuffer convertFp16BufferToFloatBuffer(ShortBuffer buf) { | ||
int pos = buf.position(); | ||
int remaining = buf.remaining(); | ||
FloatBuffer output = | ||
ByteBuffer.allocateDirect(remaining * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); | ||
for (int i = 0; i < remaining; i++) { | ||
output.put(i, fp16ToFloat(buf.get(i + pos))); | ||
} | ||
return output; | ||
} | ||
|
||
/** | ||
* Rounds a buffer of floats into a buffer containing bf16 values (stored as shorts in Java). | ||
* | ||
* <p>Respects the position and limit of the input buffer. | ||
* | ||
* @param buf The buffer of floats. | ||
* @return A buffer of bf16 values stored as shorts. | ||
*/ | ||
public static ShortBuffer convertFloatBufferToBf16Buffer(FloatBuffer buf) { | ||
int pos = buf.position(); | ||
int remaining = buf.remaining(); | ||
ShortBuffer output = | ||
ByteBuffer.allocateDirect(remaining * 2).order(ByteOrder.nativeOrder()).asShortBuffer(); | ||
for (int i = 0; i < remaining; i++) { | ||
output.put(i, floatToBf16(buf.get(i + pos))); | ||
} | ||
return output; | ||
} | ||
|
||
/** | ||
* Casts a buffer of bf16 values stored as shorts into a buffer of floats. | ||
* | ||
* <p>Respects the position and limit of the input buffer. | ||
* | ||
* @param buf The buffer of bf16 values stored as shorts. | ||
* @return A buffer of float values. | ||
*/ | ||
public static FloatBuffer convertBf16BufferToFloatBuffer(ShortBuffer buf) { | ||
int pos = buf.position(); | ||
int remaining = buf.remaining(); | ||
FloatBuffer output = | ||
ByteBuffer.allocateDirect(remaining * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); | ||
for (int i = 0; i < remaining; i++) { | ||
output.put(i, bf16ToFloat(buf.get(i + pos))); | ||
} | ||
return output; | ||
} | ||
|
||
/** | ||
* Converts a fp16 value stored in a short into a float value. | ||
* | ||
* <p>On Android this is an alias for {@link #mlasFp16ToFloat(short)}. | ||
* | ||
* @param input The fp16 value. | ||
* @return The float value. | ||
*/ | ||
public static float fp16ToFloat(short input) { | ||
return mlasFp16ToFloat(input); | ||
} | ||
|
||
/** | ||
* Converts a float value into a fp16 value stored in a short. | ||
* | ||
* <p>On Android this is an alias for {@link #mlasFloatToFp16(float)}. | ||
* | ||
* @param input The float value. | ||
* @return The fp16 value. | ||
*/ | ||
public static short floatToFp16(float input) { | ||
return mlasFloatToFp16(input); | ||
} | ||
|
||
/** | ||
* Upcasts a fp16 value to a float. Mirrors the conversion in MLAS. | ||
* | ||
* @param input A uint16_t representing an IEEE half precision float. | ||
* @return A float. | ||
*/ | ||
public static float mlasFp16ToFloat(short input) { | ||
// Port of MLAS_Half2Float from onnxruntime/core/mlas/inc/mlas_float16.h | ||
final int MAGIC = 113 << 23; | ||
// exponent mask after shift | ||
final int SHIFTED_EXP = 0x7c00 << 13; | ||
|
||
// exponent/mantissa bits | ||
int bits = (input & 0x7fff) << 13; | ||
// just the exponent | ||
final int exp = SHIFTED_EXP & bits; | ||
// exponent adjust | ||
bits += (127 - 15) << 23; | ||
|
||
// handle exponent special cases | ||
if (exp == SHIFTED_EXP) { | ||
// Inf/NaN? | ||
// extra exp adjust | ||
bits += (128 - 16) << 23; | ||
} else if (exp == 0) { | ||
// Zero/Denormal? | ||
// extra exp adjust | ||
bits += (1 << 23); | ||
// renormalize | ||
float tmp = Float.intBitsToFloat(bits) - Float.intBitsToFloat(MAGIC); | ||
bits = Float.floatToIntBits(tmp); | ||
} | ||
|
||
// sign bit | ||
bits |= (input & 0x8000) << 16; | ||
|
||
return Float.intBitsToFloat(bits); | ||
} | ||
|
||
/** | ||
* Rounds a float value to fp16. Mirrors the conversion in MLAS. | ||
* | ||
* @param input A float value. | ||
* @return The value rounded to an IEEE half precision value. | ||
*/ | ||
public static short mlasFloatToFp16(float input) { | ||
// Port of MLAS_Float2Half from onnxruntime/core/mlas/inc/mlas_float16.h | ||
int bits = Float.floatToIntBits(input); | ||
final int F32_INFINITY = Float.floatToIntBits(Float.POSITIVE_INFINITY); | ||
final int F16_MAX = (127 + 16) << 23; | ||
final int DENORM_MAGIC = ((127 - 15) + (23 - 10) + 1) << 23; | ||
final int SIGN_MASK = 0x80000000; | ||
final int ROUNDING_CONST = ((15 - 127) << 23) + 0xfff; | ||
|
||
int sign = bits & SIGN_MASK; | ||
// mask out sign bit | ||
bits ^= sign; | ||
|
||
short output; | ||
if (bits >= F16_MAX) { | ||
// Inf or NaN (all exponent bits set) | ||
output = (bits > F32_INFINITY) ? (short) 0x7e00 : (short) 0x7c00; | ||
} else { | ||
if (bits < (113 << 23)) { | ||
// Subnormal or zero | ||
// use a magic value to align our 10 mantissa bits at the bottom of | ||
// the float. as long as FP addition is round-to-nearest-even this | ||
// just works. | ||
float tmp = Float.intBitsToFloat(bits) + Float.intBitsToFloat(DENORM_MAGIC); | ||
|
||
// and one integer subtract of the bias later, we have our final float! | ||
output = (short) (Float.floatToIntBits(tmp) - DENORM_MAGIC); | ||
} else { | ||
int mant_odd = (bits >> 13) & 1; // resulting mantissa is odd | ||
|
||
// update exponent, rounding bias part 1 | ||
bits += ROUNDING_CONST; | ||
// rounding bias part 2 | ||
bits += mant_odd; | ||
// take the bits! | ||
output = (short) (bits >> 13); | ||
} | ||
} | ||
|
||
// Add the sign back in | ||
output = (short) (output | ((short) (sign >> 16))); | ||
|
||
return output; | ||
} | ||
|
||
/** | ||
* Converts a bf16 value stored in a short into a float value. | ||
* | ||
* @param input A uint16_t representing a bfloat16 value. | ||
* @return A float. | ||
*/ | ||
public static float bf16ToFloat(short input) { | ||
int bits = input << 16; | ||
return Float.intBitsToFloat(bits); | ||
} | ||
|
||
/** | ||
* Converts a float into bf16. May not produce correct values for subnormal floats. | ||
* | ||
* <p>Rounds to nearest even. | ||
* | ||
* @param input The float input. | ||
* @return A bfloat16 value which is closest to the float. | ||
*/ | ||
public static short floatToBf16(float input) { | ||
int bits = Float.floatToIntBits(input); | ||
int lsb = (bits >> 16) & 1; | ||
int roundingBias = 0x7fff + lsb; | ||
bits += roundingBias; | ||
return (short) (bits >> 16); | ||
} | ||
} |
13 changes: 13 additions & 0 deletions
13
java/src/main/android/ai/onnxruntime/platform/package-info.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
/* | ||
* Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. | ||
* Licensed under the MIT License. | ||
*/ | ||
|
||
/** | ||
* A package of platform specific code, used to swap out Java implementations which don't run on Android. | ||
* | ||
* <p>Classes in this package should always have the same public methods. | ||
* | ||
* <p>This is the Android version of the package. | ||
*/ | ||
package ai.onnxruntime.platform; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.