diff --git a/android/src/main/java/com/rnwhisper/RNWhisperModule.java b/android/src/main/java/com/rnwhisper/RNWhisperModule.java index 2ab687c..2da4808 100644 --- a/android/src/main/java/com/rnwhisper/RNWhisperModule.java +++ b/android/src/main/java/com/rnwhisper/RNWhisperModule.java @@ -12,6 +12,7 @@ import com.facebook.react.bridge.ReactMethod; import com.facebook.react.bridge.LifecycleEventListener; import com.facebook.react.bridge.ReadableMap; +import com.facebook.react.bridge.WritableMap; import com.facebook.react.module.annotations.ReactModule; import java.util.HashMap; @@ -72,11 +73,11 @@ protected void onPostExecute(Integer id) { @ReactMethod public void transcribe(int id, String filePath, ReadableMap options, Promise promise) { - new AsyncTask() { + new AsyncTask() { private Exception exception; @Override - protected String doInBackground(Void... voids) { + protected WritableMap doInBackground(Void... voids) { try { WhisperContext context = contexts.get(id); if (context == null) { @@ -90,12 +91,12 @@ protected String doInBackground(Void... voids) { } @Override - protected void onPostExecute(String result) { + protected void onPostExecute(WritableMap data) { if (exception != null) { promise.reject(exception); return; } - promise.resolve(result); + promise.resolve(data); } }.execute(); } diff --git a/android/src/main/java/com/rnwhisper/WhisperContext.java b/android/src/main/java/com/rnwhisper/WhisperContext.java index 2930ae4..4c26891 100644 --- a/android/src/main/java/com/rnwhisper/WhisperContext.java +++ b/android/src/main/java/com/rnwhisper/WhisperContext.java @@ -1,5 +1,8 @@ package com.rnwhisper; +import com.facebook.react.bridge.Arguments; +import com.facebook.react.bridge.WritableArray; +import com.facebook.react.bridge.WritableMap; import com.facebook.react.bridge.ReadableMap; import android.util.Log; @@ -29,7 +32,7 @@ public WhisperContext(long context) { this.context = context; } - public String transcribe(final String filePath, final ReadableMap options) throws IOException, Exception { + public WritableMap transcribe(final String filePath, final ReadableMap options) throws IOException, Exception { int code = fullTranscribe( context, decodeWaveFile(new File(filePath)), @@ -37,14 +40,18 @@ public String transcribe(final String filePath, final ReadableMap options) throw options.hasKey("maxThreads") ? options.getInt("maxThreads") : -1, // jint max_context, options.hasKey("maxContext") ? options.getInt("maxContext") : -1, + + // jint word_thold, + options.hasKey("wordThold") ? options.getInt("wordThold") : -1, // jint max_len, options.hasKey("maxLen") ? options.getInt("maxLen") : -1, + // jboolean token_timestamps, + options.hasKey("tokenTimestamps") ? options.getBoolean("tokenTimestamps") : false, + // jint offset, options.hasKey("offset") ? options.getInt("offset") : -1, // jint duration, options.hasKey("duration") ? options.getInt("duration") : -1, - // jint word_thold, - options.hasKey("wordThold") ? options.getInt("wordThold") : -1, // jfloat temperature, options.hasKey("temperature") ? (float) options.getDouble("temperature") : -1.0f, // jfloat temperature_inc, @@ -67,10 +74,22 @@ public String transcribe(final String filePath, final ReadableMap options) throw } Integer count = getTextSegmentCount(context); StringBuilder builder = new StringBuilder(); + + WritableMap data = Arguments.createMap(); + WritableArray segments = Arguments.createArray(); for (int i = 0; i < count; i++) { - builder.append(getTextSegment(context, i)); + String text = getTextSegment(context, i); + builder.append(text); + + WritableMap segment = Arguments.createMap(); + segment.putString("text", text); + segment.putInt("t0", getTextSegmentT0(context, i)); + segment.putInt("t1", getTextSegmentT1(context, i)); + segments.pushMap(segment); } - return builder.toString(); + data.putString("result", builder.toString()); + data.putArray("segments", segments); + return data; } public void release() { @@ -170,10 +189,11 @@ protected static native int fullTranscribe( float[] audio_data, int n_threads, int max_context, + int word_thold, int max_len, + boolean token_timestamps, int offset, int duration, - int word_thold, float temperature, float temperature_inc, int beam_size, @@ -185,5 +205,7 @@ protected static native int fullTranscribe( ); protected static native int getTextSegmentCount(long context); protected static native String getTextSegment(long context, int index); + protected static native int getTextSegmentT0(long context, int index); + protected static native int getTextSegmentT1(long context, int index); protected static native void freeContext(long contextPtr); } \ No newline at end of file diff --git a/android/src/main/jni/whisper/jni.cpp b/android/src/main/jni/whisper/jni.cpp index 5ece6e6..67c672d 100644 --- a/android/src/main/jni/whisper/jni.cpp +++ b/android/src/main/jni/whisper/jni.cpp @@ -44,10 +44,11 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( jfloatArray audio_data, jint n_threads, jint max_context, - jint max_len, + int word_thold, + int max_len, + jboolean token_timestamps, jint offset, jint duration, - jint word_thold, jfloat temperature, jfloat temperature_inc, jint beam_size, @@ -86,15 +87,17 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe( params.no_context = true; params.single_segment = false; + if (max_len > -1) { + params.max_len = max_len; + } + params.token_timestamps = token_timestamps; + if (best_of > -1) { params.greedy.best_of = best_of; } if (max_context > -1) { params.n_max_text_ctx = max_context; } - if (max_len > -1) { - params.max_len = max_len; - } if (offset > -1) { params.offset_ms = offset; } @@ -150,6 +153,24 @@ Java_com_rnwhisper_WhisperContext_getTextSegment( return string; } +JNIEXPORT jint JNICALL +Java_com_rnwhisper_WhisperContext_getTextSegmentT0( + JNIEnv *env, jobject thiz, jlong context_ptr, jint index) { + UNUSED(env); + UNUSED(thiz); + struct whisper_context *context = reinterpret_cast(context_ptr); + return whisper_full_get_segment_t0(context, index); +} + +JNIEXPORT jint JNICALL +Java_com_rnwhisper_WhisperContext_getTextSegmentT1( + JNIEnv *env, jobject thiz, jlong context_ptr, jint index) { + UNUSED(env); + UNUSED(thiz); + struct whisper_context *context = reinterpret_cast(context_ptr); + return whisper_full_get_segment_t1(context, index); +} + JNIEXPORT void JNICALL Java_com_rnwhisper_WhisperContext_freeContext( JNIEnv *env, jobject thiz, jlong context_ptr) { diff --git a/example/src/App.js b/example/src/App.js index f8493f7..aa3832f 100644 --- a/example/src/App.js +++ b/example/src/App.js @@ -46,6 +46,27 @@ console.log('[App] fileDir', fileDir) const modelFilePath = `${fileDir}/base.en` const sampleFilePath = `${fileDir}/jfk.wav` +function toTimestamp(t, comma = false) { + let msec = t * 10 + const hr = Math.floor(msec / (1000 * 60 * 60)) + msec -= hr * (1000 * 60 * 60) + const min = Math.floor(msec / (1000 * 60)) + msec -= min * (1000 * 60) + const sec = Math.floor(msec / 1000) + msec -= sec * 1000 + + const separator = comma ? ',' : '.' + const timestamp = `${String(hr).padStart(2, '0')}:${String(min).padStart( + 2, + '0', + )}:${String(sec).padStart(2, '0')}${separator}${String(msec).padStart( + 3, + '0', + )}` + + return timestamp +} + const filterPath = (path) => path.replace(RNFS.DocumentDirectoryPath, '') @@ -74,9 +95,7 @@ export default function App() { return ( - + { + log( + `[${toTimestamp(segment.t0)} --> ${toTimestamp( + segment.t1, + )}] ${segment.text}`, + ) + }) }} > Transcribe diff --git a/ios/RNWhisper.mm b/ios/RNWhisper.mm index 00ad6a4..44b3310 100644 --- a/ios/RNWhisper.mm +++ b/ios/RNWhisper.mm @@ -94,15 +94,18 @@ @implementation RNWhisper params.no_context = true; params.single_segment = false; + if (options[@"maxLen"] != nil) { + params.max_len = [options[@"maxLen"] intValue]; + } + params.token_timestamps = options[@"tokenTimestamps"] != nil ? [options[@"tokenTimestamps"] boolValue] : false; + if (options[@"bestOf"] != nil) { params.greedy.best_of = [options[@"bestOf"] intValue]; } if (options[@"maxContext"] != nil) { params.n_max_text_ctx = [options[@"maxContext"] intValue]; } - if (options[@"maxLen"] != nil) { - params.max_len = [options[@"maxLen"] intValue]; - } + if (options[@"offset"] != nil) { params.offset_ms = [options[@"offset"] intValue]; } @@ -118,7 +121,7 @@ @implementation RNWhisper if (options[@"temperatureInc"] != nil) { params.temperature_inc = [options[@"temperature_inc"] floatValue]; } - + if (options[@"prompt"] != nil) { std::string *prompt = new std::string([options[@"prompt"] UTF8String]); rn_whisper_convert_prompt( @@ -142,11 +145,25 @@ @implementation RNWhisper NSString *result = @""; int n_segments = whisper_full_n_segments(context.ctx); + + NSMutableArray *segments = [[NSMutableArray alloc] init]; for (int i = 0; i < n_segments; i++) { const char * text_cur = whisper_full_get_segment_text(context.ctx, i); result = [result stringByAppendingString:[NSString stringWithUTF8String:text_cur]]; - } - resolve(result); + + const int64_t t0 = whisper_full_get_segment_t0(context.ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(context.ctx, i); + NSDictionary *segment = @{ + @"text": [NSString stringWithUTF8String:text_cur], + @"t0": [NSNumber numberWithLongLong:t0], + @"t1": [NSNumber numberWithLongLong:t1] + }; + [segments addObject:segment]; + } + resolve(@{ + @"result": result, + @"segments": segments + }); } RCT_REMAP_METHOD(releaseContext, diff --git a/jest/mock.js b/jest/mock.js index 664f1fb..9edea9a 100644 --- a/jest/mock.js +++ b/jest/mock.js @@ -3,7 +3,10 @@ const { NativeModules } = require('react-native') if (!NativeModules.RNWhisper) { NativeModules.RNWhisper = { initContext: jest.fn(() => Promise.resolve(1)), - transcribe: jest.fn(() => Promise.resolve('TEST')), + transcribe: jest.fn(() => Promise.resolve({ + result: ' Test', + segments: [{ text: ' Test', t0: 0, t1: 33 }], + })), releaseContext: jest.fn(() => Promise.resolve()), releaseAllContexts: jest.fn(() => Promise.resolve()), } diff --git a/src/__tests__/index.test.tsx b/src/__tests__/index.test.tsx index 255b6cf..ca87c32 100644 --- a/src/__tests__/index.test.tsx +++ b/src/__tests__/index.test.tsx @@ -5,7 +5,10 @@ jest.mock('..', () => require('../../jest/mock')) test('Mock', async () => { const context = await initWhisper() expect(context.id).toBe(1) - expect(await context.transcribe('test.wav')).toEqual({ result: 'TEST' }) + expect(await context.transcribe('test.wav')).toEqual({ + result: ' Test', + segments: [{ text: ' Test', t0: 0, t1: 33 }], + }) await context.release() await releaseAllWhisper() }) diff --git a/src/index.tsx b/src/index.tsx index 5ad2ae4..b63f381 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -21,6 +21,7 @@ export type TranscribeOptions = { maxThreads?: number, maxContext?: number, maxLen?: number, + tokenTimestamps?: boolean, offset?: number, duration?: number, wordThold?: number, @@ -34,6 +35,11 @@ export type TranscribeOptions = { export type TranscribeResult = { result: string, + segments: Array<{ + text: string, + t0: number, + t1: number, + }>, } class WhisperContext { @@ -44,9 +50,7 @@ class WhisperContext { } async transcribe(path: string, options: TranscribeOptions = {}): Promise { - return RNWhisper.transcribe(this.id, path, options).then((result: string) => ({ - result - })) + return RNWhisper.transcribe(this.id, path, options) } async release() {