Skip to content

Commit

Permalink
feat: support word-level timestamp on transcribe (#9)
Browse files Browse the repository at this point in the history
* feat(ios): add token_timestamps option & put segments in result

* feat(android): add token_timestamps option & put segments in result

* feat(jest): update mock

* feat(example): format log segments
  • Loading branch information
jhen0409 authored Mar 21, 2023
1 parent dc2d064 commit 6ac4b4e
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 30 deletions.
9 changes: 5 additions & 4 deletions android/src/main/java/com/rnwhisper/RNWhisperModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.facebook.react.bridge.ReactMethod;
import com.facebook.react.bridge.LifecycleEventListener;
import com.facebook.react.bridge.ReadableMap;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.module.annotations.ReactModule;

import java.util.HashMap;
Expand Down Expand Up @@ -72,11 +73,11 @@ protected void onPostExecute(Integer id) {

@ReactMethod
public void transcribe(int id, String filePath, ReadableMap options, Promise promise) {
new AsyncTask<Void, Void, String>() {
new AsyncTask<Void, Void, WritableMap>() {
private Exception exception;

@Override
protected String doInBackground(Void... voids) {
protected WritableMap doInBackground(Void... voids) {
try {
WhisperContext context = contexts.get(id);
if (context == null) {
Expand All @@ -90,12 +91,12 @@ protected String doInBackground(Void... voids) {
}

@Override
protected void onPostExecute(String result) {
protected void onPostExecute(WritableMap data) {
if (exception != null) {
promise.reject(exception);
return;
}
promise.resolve(result);
promise.resolve(data);
}
}.execute();
}
Expand Down
34 changes: 28 additions & 6 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package com.rnwhisper;

import com.facebook.react.bridge.Arguments;
import com.facebook.react.bridge.WritableArray;
import com.facebook.react.bridge.WritableMap;
import com.facebook.react.bridge.ReadableMap;

import android.util.Log;
Expand Down Expand Up @@ -29,22 +32,26 @@ public WhisperContext(long context) {
this.context = context;
}

public String transcribe(final String filePath, final ReadableMap options) throws IOException, Exception {
public WritableMap transcribe(final String filePath, final ReadableMap options) throws IOException, Exception {
int code = fullTranscribe(
context,
decodeWaveFile(new File(filePath)),
// jint n_threads,
options.hasKey("maxThreads") ? options.getInt("maxThreads") : -1,
// jint max_context,
options.hasKey("maxContext") ? options.getInt("maxContext") : -1,

// jint word_thold,
options.hasKey("wordThold") ? options.getInt("wordThold") : -1,
// jint max_len,
options.hasKey("maxLen") ? options.getInt("maxLen") : -1,
// jboolean token_timestamps,
options.hasKey("tokenTimestamps") ? options.getBoolean("tokenTimestamps") : false,

// jint offset,
options.hasKey("offset") ? options.getInt("offset") : -1,
// jint duration,
options.hasKey("duration") ? options.getInt("duration") : -1,
// jint word_thold,
options.hasKey("wordThold") ? options.getInt("wordThold") : -1,
// jfloat temperature,
options.hasKey("temperature") ? (float) options.getDouble("temperature") : -1.0f,
// jfloat temperature_inc,
Expand All @@ -67,10 +74,22 @@ public String transcribe(final String filePath, final ReadableMap options) throw
}
Integer count = getTextSegmentCount(context);
StringBuilder builder = new StringBuilder();

WritableMap data = Arguments.createMap();
WritableArray segments = Arguments.createArray();
for (int i = 0; i < count; i++) {
builder.append(getTextSegment(context, i));
String text = getTextSegment(context, i);
builder.append(text);

WritableMap segment = Arguments.createMap();
segment.putString("text", text);
segment.putInt("t0", getTextSegmentT0(context, i));
segment.putInt("t1", getTextSegmentT1(context, i));
segments.pushMap(segment);
}
return builder.toString();
data.putString("result", builder.toString());
data.putArray("segments", segments);
return data;
}

public void release() {
Expand Down Expand Up @@ -170,10 +189,11 @@ protected static native int fullTranscribe(
float[] audio_data,
int n_threads,
int max_context,
int word_thold,
int max_len,
boolean token_timestamps,
int offset,
int duration,
int word_thold,
float temperature,
float temperature_inc,
int beam_size,
Expand All @@ -185,5 +205,7 @@ protected static native int fullTranscribe(
);
protected static native int getTextSegmentCount(long context);
protected static native String getTextSegment(long context, int index);
protected static native int getTextSegmentT0(long context, int index);
protected static native int getTextSegmentT1(long context, int index);
protected static native void freeContext(long contextPtr);
}
31 changes: 26 additions & 5 deletions android/src/main/jni/whisper/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
jfloatArray audio_data,
jint n_threads,
jint max_context,
jint max_len,
int word_thold,
int max_len,
jboolean token_timestamps,
jint offset,
jint duration,
jint word_thold,
jfloat temperature,
jfloat temperature_inc,
jint beam_size,
Expand Down Expand Up @@ -86,15 +87,17 @@ Java_com_rnwhisper_WhisperContext_fullTranscribe(
params.no_context = true;
params.single_segment = false;

if (max_len > -1) {
params.max_len = max_len;
}
params.token_timestamps = token_timestamps;

if (best_of > -1) {
params.greedy.best_of = best_of;
}
if (max_context > -1) {
params.n_max_text_ctx = max_context;
}
if (max_len > -1) {
params.max_len = max_len;
}
if (offset > -1) {
params.offset_ms = offset;
}
Expand Down Expand Up @@ -150,6 +153,24 @@ Java_com_rnwhisper_WhisperContext_getTextSegment(
return string;
}

JNIEXPORT jint JNICALL
Java_com_rnwhisper_WhisperContext_getTextSegmentT0(
JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
UNUSED(env);
UNUSED(thiz);
struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
return whisper_full_get_segment_t0(context, index);
}

JNIEXPORT jint JNICALL
Java_com_rnwhisper_WhisperContext_getTextSegmentT1(
JNIEnv *env, jobject thiz, jlong context_ptr, jint index) {
UNUSED(env);
UNUSED(thiz);
struct whisper_context *context = reinterpret_cast<struct whisper_context *>(context_ptr);
return whisper_full_get_segment_t1(context, index);
}

JNIEXPORT void JNICALL
Java_com_rnwhisper_WhisperContext_freeContext(
JNIEnv *env, jobject thiz, jlong context_ptr) {
Expand Down
37 changes: 33 additions & 4 deletions example/src/App.js
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,27 @@ console.log('[App] fileDir', fileDir)
const modelFilePath = `${fileDir}/base.en`
const sampleFilePath = `${fileDir}/jfk.wav`

function toTimestamp(t, comma = false) {
let msec = t * 10
const hr = Math.floor(msec / (1000 * 60 * 60))
msec -= hr * (1000 * 60 * 60)
const min = Math.floor(msec / (1000 * 60))
msec -= min * (1000 * 60)
const sec = Math.floor(msec / 1000)
msec -= sec * 1000

const separator = comma ? ',' : '.'
const timestamp = `${String(hr).padStart(2, '0')}:${String(min).padStart(
2,
'0',
)}:${String(sec).padStart(2, '0')}${separator}${String(msec).padStart(
3,
'0',
)}`

return timestamp
}

const filterPath = (path) =>
path.replace(RNFS.DocumentDirectoryPath, '<DocumentDir>')

Expand Down Expand Up @@ -74,9 +95,7 @@ export default function App() {

return (
<SafeAreaView style={styles.container}>
<ScrollView
contentContainerStyle={styles.content}
>
<ScrollView contentContainerStyle={styles.content}>
<View style={styles.buttons}>
<TouchableOpacity
style={styles.button}
Expand Down Expand Up @@ -142,15 +161,25 @@ export default function App() {
}
log('Start transcribing...')
const startTime = Date.now()
const { result } = await whisperContext.transcribe(
const { result, segments } = await whisperContext.transcribe(
sampleFilePath,
{
language: 'en',
maxLen: 1,
tokenTimestamps: true,
},
)
const endTime = Date.now()
log('Transcribed result:', result)
log('Transcribed in', endTime - startTime, `ms in ${mode} mode`)
log('Segments:')
segments.forEach((segment) => {
log(
`[${toTimestamp(segment.t0)} --> ${toTimestamp(
segment.t1,
)}] ${segment.text}`,
)
})
}}
>
<Text style={styles.buttonText}>Transcribe</Text>
Expand Down
29 changes: 23 additions & 6 deletions ios/RNWhisper.mm
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,18 @@ @implementation RNWhisper
params.no_context = true;
params.single_segment = false;

if (options[@"maxLen"] != nil) {
params.max_len = [options[@"maxLen"] intValue];
}
params.token_timestamps = options[@"tokenTimestamps"] != nil ? [options[@"tokenTimestamps"] boolValue] : false;

if (options[@"bestOf"] != nil) {
params.greedy.best_of = [options[@"bestOf"] intValue];
}
if (options[@"maxContext"] != nil) {
params.n_max_text_ctx = [options[@"maxContext"] intValue];
}
if (options[@"maxLen"] != nil) {
params.max_len = [options[@"maxLen"] intValue];
}

if (options[@"offset"] != nil) {
params.offset_ms = [options[@"offset"] intValue];
}
Expand All @@ -118,7 +121,7 @@ @implementation RNWhisper
if (options[@"temperatureInc"] != nil) {
params.temperature_inc = [options[@"temperature_inc"] floatValue];
}

if (options[@"prompt"] != nil) {
std::string *prompt = new std::string([options[@"prompt"] UTF8String]);
rn_whisper_convert_prompt(
Expand All @@ -142,11 +145,25 @@ @implementation RNWhisper

NSString *result = @"";
int n_segments = whisper_full_n_segments(context.ctx);

NSMutableArray *segments = [[NSMutableArray alloc] init];
for (int i = 0; i < n_segments; i++) {
const char * text_cur = whisper_full_get_segment_text(context.ctx, i);
result = [result stringByAppendingString:[NSString stringWithUTF8String:text_cur]];
}
resolve(result);

const int64_t t0 = whisper_full_get_segment_t0(context.ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(context.ctx, i);
NSDictionary *segment = @{
@"text": [NSString stringWithUTF8String:text_cur],
@"t0": [NSNumber numberWithLongLong:t0],
@"t1": [NSNumber numberWithLongLong:t1]
};
[segments addObject:segment];
}
resolve(@{
@"result": result,
@"segments": segments
});
}

RCT_REMAP_METHOD(releaseContext,
Expand Down
5 changes: 4 additions & 1 deletion jest/mock.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ const { NativeModules } = require('react-native')
if (!NativeModules.RNWhisper) {
NativeModules.RNWhisper = {
initContext: jest.fn(() => Promise.resolve(1)),
transcribe: jest.fn(() => Promise.resolve('TEST')),
transcribe: jest.fn(() => Promise.resolve({
result: ' Test',
segments: [{ text: ' Test', t0: 0, t1: 33 }],
})),
releaseContext: jest.fn(() => Promise.resolve()),
releaseAllContexts: jest.fn(() => Promise.resolve()),
}
Expand Down
5 changes: 4 additions & 1 deletion src/__tests__/index.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ jest.mock('..', () => require('../../jest/mock'))
test('Mock', async () => {
const context = await initWhisper()
expect(context.id).toBe(1)
expect(await context.transcribe('test.wav')).toEqual({ result: 'TEST' })
expect(await context.transcribe('test.wav')).toEqual({
result: ' Test',
segments: [{ text: ' Test', t0: 0, t1: 33 }],
})
await context.release()
await releaseAllWhisper()
})
10 changes: 7 additions & 3 deletions src/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export type TranscribeOptions = {
maxThreads?: number,
maxContext?: number,
maxLen?: number,
tokenTimestamps?: boolean,
offset?: number,
duration?: number,
wordThold?: number,
Expand All @@ -34,6 +35,11 @@ export type TranscribeOptions = {

export type TranscribeResult = {
result: string,
segments: Array<{
text: string,
t0: number,
t1: number,
}>,
}

class WhisperContext {
Expand All @@ -44,9 +50,7 @@ class WhisperContext {
}

async transcribe(path: string, options: TranscribeOptions = {}): Promise<TranscribeResult> {
return RNWhisper.transcribe(this.id, path, options).then((result: string) => ({
result
}))
return RNWhisper.transcribe(this.id, path, options)
}

async release() {
Expand Down

0 comments on commit 6ac4b4e

Please sign in to comment.