Feat: refactor STT service (#294)
* add stt hook interface

* fix crypto being exported to the browser

* refactor use-transcribe

* allow using OpenAI STT

* refactor: remove deprecated code

* fix undefined method
an-lee authored Feb 10, 2024
1 parent a716719 commit bc22a5e
Showing 21 changed files with 484 additions and 629 deletions.
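
The first two commit messages describe a new STT hook interface and a refactor of `use-transcribe`; those renderer files are not among the diffs shown below, so the following is only a rough sketch of what such a hook's surface might look like. Every name and type here is an assumption, not the actual code.

```ts
// Hypothetical sketch of the renderer-side STT hook surface (not the actual code).
type SttEngine = "whisper" | "azure" | "cloudflare" | "openai";

interface TranscribeResult {
  engine: SttEngine; // which backend produced the transcript
  model?: string;    // e.g. a whisper model name, if known
  result: unknown;   // grouped transcription segments
}

interface UseTranscribe {
  // Takes recorded audio and resolves with the transcript plus engine metadata,
  // which the caller can then persist via the "transcriptions-update" IPC channel.
  transcribe: (
    blob: { type: string; arrayBuffer: ArrayBuffer },
    options?: { force?: boolean }
  ) => Promise<TranscribeResult>;
}
```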
1 change: 1 addition & 0 deletions enjoy/src/i18n/en.json
@@ -317,6 +317,7 @@
"azureSpeechToTextDescription": "Use Azure AI Speech to transcribe. It is a paid service.",
"cloudflareAi": "Cloudflare AI",
"cloudflareSpeechToTextDescription": "Use Cloudflare AI Worker to transcribe. It is in beta and free for now.",
"openaiSpeechToTextDescription": "Use openAI to transcribe using your own key.",
"checkingWhisper": "Checking whisper status",
"pleaseDownloadWhisperModelFirst": "Please download whisper model first",
"whisperIsWorkingGood": "Whisper is working good",
1 change: 1 addition & 0 deletions enjoy/src/i18n/zh-CN.json
@@ -316,6 +316,7 @@
"azureSpeechToTextDescription": "使用 Azure AI Speech 进行语音转文本,收费服务",
"cloudflareAi": "Cloudflare AI",
"cloudflareSpeechToTextDescription": "使用 Cloudflare AI 进行语音转文本,目前免费",
"openaiSpeechToTextDescription": "使用 OpenAI 进行语音转文本(需要 API 密钥)",
"checkingWhisper": "正在检查 Whisper",
"pleaseDownloadWhisperModelFirst": "请先下载 Whisper 模型",
"whisperIsWorkingGood": "Whisper 正常工作",
2 changes: 1 addition & 1 deletion enjoy/src/main/db/handlers/speeches-handler.ts
@@ -3,7 +3,7 @@ import { Speech } from "@main/db/models";
import fs from "fs-extra";
import path from "path";
import settings from "@main/settings";
import { hashFile } from "@/utils";
import { hashFile } from "@main/utils";

class SpeechesHandler {
private async create(
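
This import (and the same change in several models below) moves `hashFile` from the shared `@/utils` module to a main-process-only `@main/utils`, which lines up with the "fix crypto being exported to the browser" commit message: Node's `crypto` should no longer be pulled into the renderer bundle through a shared utility file. A minimal sketch of what such a main-only helper could look like — the real implementation is not part of this diff:

```ts
// Hypothetical enjoy/src/main/utils.ts — a sketch, not the committed code.
import crypto from "crypto";
import fs from "fs-extra";

// Hashes a file on disk. Lives under @main so Node-only modules like `crypto`
// are only ever imported from the main (Node.js) process.
export async function hashFile(
  filePath: string,
  options: { algo?: string } = {}
): Promise<string> {
  const { algo = "md5" } = options;
  const buffer = await fs.readFile(filePath);
  return crypto.createHash(algo).update(buffer).digest("hex");
}
```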
64 changes: 3 additions & 61 deletions enjoy/src/main/db/handlers/transcriptions-handler.ts
@@ -1,7 +1,6 @@
import { ipcMain, IpcMainEvent } from "electron";
import { Transcription, Audio, Video } from "@main/db/models";
import { WhereOptions, Attributes } from "sequelize";
import { t } from "i18next";
import { Attributes } from "sequelize";
import log from "electron-log/main";

const logger = log.scope("db/handlers/transcriptions-handler");
@@ -44,7 +43,7 @@ class TranscriptionsHandler {
id: string,
params: Attributes<Transcription>
) {
const { result } = params;
const { result, engine, model, state } = params;

return Transcription.findOne({
where: { id },
@@ -53,63 +52,7 @@
if (!transcription) {
throw new Error("models.transcription.notFound");
}
transcription.update({ result });
})
.catch((err) => {
logger.error(err);
event.sender.send("on-notification", {
type: "error",
message: err.message,
});
});
}

private async process(
event: IpcMainEvent,
where: WhereOptions<Attributes<Transcription>>,
options?: {
force?: boolean;
blob: {
type: string;
arrayBuffer: ArrayBuffer;
};
}
) {
const { force = true, blob } = options || {};
return Transcription.findOne({
where: {
...where,
},
})
.then((transcription) => {
if (!transcription) {
throw new Error("models.transcription.notFound");
}

const interval = setInterval(() => {
event.sender.send("on-notification", {
type: "warning",
message: t("stillTranscribing"),
});
}, 1000 * 10);

transcription
.process({
force,
wavFileBlob: blob,
onProgress: (progress: number) => {
event.sender.send("transcription-on-progress", progress);
},
})
.catch((err) => {
event.sender.send("on-notification", {
type: "error",
message: err.message,
});
})
.finally(() => {
clearInterval(interval);
});
transcription.update({ result, engine, model, state });
})
.catch((err) => {
logger.error(err);
@@ -122,7 +65,6 @@

register() {
ipcMain.handle("transcriptions-find-or-create", this.findOrCreate);
ipcMain.handle("transcriptions-process", this.process);
ipcMain.handle("transcriptions-update", this.update);
}
}
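
With the `transcriptions-process` channel gone, the main process no longer runs Whisper itself; the renderer is expected to produce the transcript and persist it through the remaining `transcriptions-update` channel, which now accepts `engine`, `model`, and `state` alongside `result`. A minimal sketch of that call, assuming a plain `ipcRenderer.invoke` bridge (the app's actual preload API may differ):

```ts
import { ipcRenderer } from "electron";

// Hypothetical renderer-side call; the real app likely routes this through a preload bridge.
async function saveTranscription(transcriptionId: string, result: unknown) {
  await ipcRenderer.invoke("transcriptions-update", transcriptionId, {
    result,
    engine: "openai",   // or "whisper" / "azure" / "cloudflare"
    model: "whisper-1", // placeholder model name
    state: "finished",
  });
}
```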
43 changes: 1 addition & 42 deletions enjoy/src/main/db/models/audio.ts
@@ -17,7 +17,7 @@ import {
import { Recording, Speech, Transcription, Video } from "@main/db/models";
import settings from "@main/settings";
import { AudioFormats, VideoFormats, WEB_API_URL } from "@/constants";
import { hashFile } from "@/utils";
import { hashFile } from "@main/utils";
import path from "path";
import fs from "fs-extra";
import { t } from "i18next";
@@ -191,15 +191,6 @@ export class Audio extends Model<Audio> {
}
}

@AfterCreate
static transcribeAsync(audio: Audio) {
if (settings.ffmpegConfig().ready) {
setTimeout(() => {
audio.transcribe();
}, 500);
}
}

@AfterCreate
static autoSync(audio: Audio) {
// auto sync should not block the main thread
@@ -332,38 +323,6 @@ export class Audio extends Model<Audio> {
});
}

// STT using whisper
async transcribe() {
Transcription.findOrCreate({
where: {
targetId: this.id,
targetType: "Audio",
},
defaults: {
targetId: this.id,
targetType: "Audio",
targetMd5: this.md5,
},
})
.then(([transcription, _created]) => {
if (transcription.state === "pending") {
transcription.process();
} else if (transcription.state === "finished") {
transcription.process({ force: true });
} else if (transcription.state === "processing") {
logger.warn(
`[${transcription.getDataValue("id")}]`,
"Transcription is processing."
);
}
})
.catch((err) => {
logger.error(err);

throw err;
});
}

static notify(audio: Audio, action: "create" | "update" | "destroy") {
if (!mainWindow.win) return;

2 changes: 1 addition & 1 deletion enjoy/src/main/db/models/conversation.ts
@@ -29,7 +29,7 @@ import fs from "fs-extra";
import path from "path";
import Ffmpeg from "@main/ffmpeg";
import whisper from "@main/whisper";
import { hashFile } from "@/utils";
import { hashFile } from "@main/utils";
import { WEB_API_URL } from "@/constants";
import proxyAgent from "@main/proxy-agent";

2 changes: 1 addition & 1 deletion enjoy/src/main/db/models/recording.ts
@@ -19,7 +19,7 @@ import { Audio, PronunciationAssessment, Video } from "@main/db/models";
import fs from "fs-extra";
import path from "path";
import settings from "@main/settings";
import { hashFile } from "@/utils";
import { hashFile } from "@main/utils";
import log from "electron-log/main";
import storage from "@main/storage";
import { Client } from "@/api";
2 changes: 1 addition & 1 deletion enjoy/src/main/db/models/speech.ts
@@ -20,7 +20,7 @@ import path from "path";
import settings from "@main/settings";
import OpenAI, { type ClientOptions } from "openai";
import { t } from "i18next";
import { hashFile } from "@/utils";
import { hashFile } from "@main/utils";
import { Audio, Message } from "@main/db/models";
import log from "electron-log/main";
import { WEB_API_URL } from "@/constants";
115 changes: 1 addition & 114 deletions enjoy/src/main/db/models/transcription.ts
@@ -1,5 +1,4 @@
import {
AfterCreate,
AfterUpdate,
AfterDestroy,
AfterFind,
@@ -13,18 +12,13 @@ import {
Unique,
} from "sequelize-typescript";
import { Audio, Video } from "@main/db/models";
import whisper from "@main/whisper";
import mainWindow from "@main/window";
import log from "electron-log/main";
import { Client } from "@/api";
import { WEB_API_URL, PROCESS_TIMEOUT } from "@/constants";
import settings from "@main/settings";
import Ffmpeg from "@main/ffmpeg";
import path from "path";
import fs from "fs-extra";

const logger = log.scope("db/models/transcription");

@Table({
modelName: "Transcription",
tableName: "transcriptions",
@@ -80,120 +74,13 @@ export class Transcription extends Model<Transcription> {
const webApi = new Client({
baseUrl: process.env.WEB_API_URL || WEB_API_URL,
accessToken: settings.getSync("user.accessToken") as string,
logger: log.scope("api/client"),
logger,
});
return webApi.syncTranscription(this.toJSON()).then(() => {
this.update({ syncedAt: new Date() });
});
}

// STT using whisper
async process(
options: {
force?: boolean;
wavFileBlob?: { type: string; arrayBuffer: ArrayBuffer };
onProgress?: (progress: number) => void;
} = {}
) {
if (this.getDataValue("state") === "processing") return;

const { force = false, wavFileBlob, onProgress } = options;

logger.info(`[${this.getDataValue("id")}]`, "Start to transcribe.");

let filePath = "";
if (this.targetType === "Audio") {
filePath = (await Audio.findByPk(this.targetId)).filePath;
} else if (this.targetType === "Video") {
filePath = (await Video.findByPk(this.targetId)).filePath;
}

if (!filePath) {
logger.error(`[${this.getDataValue("id")}]`, "No file path.");
throw new Error("No file path.");
}

let wavFile: string = filePath;

const tmpDir = settings.cachePath();
const outputFile = path.join(
tmpDir,
path.basename(filePath, path.extname(filePath)) + ".wav"
);

if (wavFileBlob) {
const format = wavFileBlob.type.split("/")[1];

if (format !== "wav") {
throw new Error("Only wav format is supported");
}

await fs.outputFile(outputFile, Buffer.from(wavFileBlob.arrayBuffer));
wavFile = outputFile;
} else if (settings.ffmpegConfig().ready) {
const ffmpeg = new Ffmpeg();
try {
wavFile = await ffmpeg.prepareForWhisper(
filePath,
path.join(
tmpDir,
path.basename(filePath, path.extname(filePath)) + ".wav"
)
);
} catch (err) {
logger.error("ffmpeg error", err);
}
}

try {
await this.update({
state: "processing",
});
const {
engine = "whisper",
model,
transcription,
} = await whisper.transcribe(wavFile, {
force,
extra: [
"--split-on-word",
"--max-len",
"1",
"--prompt",
`"Hello! Welcome to listen to this audio."`,
],
onProgress,
});
const result = whisper.groupTranscription(transcription);
this.update({
engine,
model: model?.type,
result,
state: "finished",
}).then(() => this.sync());

logger.info(`[${this.getDataValue("id")}]`, "Transcription finished.");
} catch (err) {
logger.error(
`[${this.getDataValue("id")}]`,
"Transcription not finished.",
err
);
this.update({
state: "pending",
});

throw err;
}
}

@AfterCreate
static startTranscribeAsync(transcription: Transcription) {
setTimeout(() => {
transcription.process();
}, 0);
}

@AfterUpdate
static notifyForUpdate(transcription: Transcription) {
this.notify(transcription, "update");
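
The removed `process()` above was the main-process Whisper path (ffmpeg conversion plus `whisper.transcribe`). Per the "allow using OpenAI STT" commit message, transcription can now also go through OpenAI's hosted speech-to-text; a minimal sketch of such a call with the official `openai` SDK (already used elsewhere in this repo, e.g. `speech.ts`) — the key handling, file path, and options here are assumptions:

```ts
import fs from "fs";
import OpenAI from "openai";

// Hypothetical OpenAI STT call; the key, path, and options are placeholders.
const openai = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

async function transcribeWithOpenAI(wavPath: string) {
  return openai.audio.transcriptions.create({
    file: fs.createReadStream(wavPath),
    model: "whisper-1",
    response_format: "verbose_json", // includes per-segment timestamps
  });
}
```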