From 0c158046942230cbc26ae59cea6e498a7f89c609 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Mon, 22 Apr 2024 21:42:37 -0700 Subject: [PATCH 01/29] Initial mlx integration --- Package.resolved | 18 ++++++++++ Package.swift | 23 +++++++++++-- Sources/WhisperKit/MLX/MLXAudioEncoder.swift | 36 ++++++++++++++++++++ Tests/WhisperKitTests/MLXTests.swift | 22 ++++++++++++ 4 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 Sources/WhisperKit/MLX/MLXAudioEncoder.swift create mode 100644 Tests/WhisperKitTests/MLXTests.swift diff --git a/Package.resolved b/Package.resolved index 6cccf25..51c8fee 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,5 +1,14 @@ { "pins" : [ + { + "identity" : "mlx-swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ml-explore/mlx-swift", + "state" : { + "revision" : "7838f8cd93499f3d9e9a35b87cb74fe2664b325f", + "version" : "0.10.0" + } + }, { "identity" : "swift-argument-parser", "kind" : "remoteSourceControl", @@ -9,6 +18,15 @@ "version" : "1.3.0" } }, + { + "identity" : "swift-numerics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-numerics", + "state" : { + "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", + "version" : "1.0.2" + } + }, { "identity" : "swift-transformers", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index f3f111e..5fdfd12 100644 --- a/Package.swift +++ b/Package.swift @@ -7,13 +7,17 @@ let package = Package( name: "whisperkit", platforms: [ .iOS(.v16), - .macOS(.v13), + .macOS("13.3"), ], products: [ .library( name: "WhisperKit", targets: ["WhisperKit"] ), + .library( + name: "WhisperKitMLX", + targets: ["WhisperKitMLX"] + ), .executable( name: "whisperkit-cli", targets: ["WhisperKitCLI"] @@ -21,6 +25,7 @@ let package = Package( ], dependencies: [ .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"), + .package(url: "https://github.com/ml-explore/mlx-swift", exact: "0.10.0"), .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), ], targets: [ @@ -28,7 +33,20 @@ let package = Package( name: "WhisperKit", dependencies: [ .product(name: "Transformers", package: "swift-transformers"), - ] + ], + path: "Sources/WhisperKit/Core" + ), + .target( + name: "WhisperKitMLX", + dependencies: [ + "WhisperKit", + .product(name: "MLX", package: "mlx-swift"), + .product(name: "MLXRandom", package: "mlx-swift"), + .product(name: "MLXNN", package: "mlx-swift"), + .product(name: "MLXOptimizers", package: "mlx-swift"), + .product(name: "MLXFFT", package: "mlx-swift") + ], + path: "Sources/WhisperKit/MLX" ), .executableTarget( name: "WhisperKitCLI", @@ -41,6 +59,7 @@ let package = Package( name: "WhisperKitTests", dependencies: [ "WhisperKit", + "WhisperKitMLX", .product(name: "Transformers", package: "swift-transformers"), ], path: ".", diff --git a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift new file mode 100644 index 0000000..5e57a01 --- /dev/null +++ b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift @@ -0,0 +1,36 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import CoreML +import MLX +import WhisperKit + +@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) +public class MLXAudioEncoder: AudioEncoding { + + public var embedSize: Int? { + return 1234 + } + + public var sequenceLength: Int? { + return 1234 + } + + public init() {} + + public func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray? { + // Make sure features is shape MultiArray (Float32 1 × {80,128} × 3000) +// guard let model else { +// throw WhisperError.modelsUnavailable() +// } + try Task.checkCancellation() + +// let interval = Logging.beginSignpost("EncodeAudio", signposter: Logging.AudioEncoding.signposter) +// defer { Logging.endSignpost("EncodeAudio", interval: interval, signposter: Logging.AudioEncoding.signposter) } + + let modelInputs = MLXArray() + let outputFeatures = MLXArray() + let output = MLMultiArray() + return output + } +} diff --git a/Tests/WhisperKitTests/MLXTests.swift b/Tests/WhisperKitTests/MLXTests.swift new file mode 100644 index 0000000..15e6bf1 --- /dev/null +++ b/Tests/WhisperKitTests/MLXTests.swift @@ -0,0 +1,22 @@ +// +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import XCTest +@testable import WhisperKit +@testable import WhisperKitMLX + +final class MLXTests: XCTestCase { + + override func setUpWithError() throws { + // Put setup code here. This method is called before the invocation of each test method in the class. + } + + override func tearDownWithError() throws { + // Put teardown code here. This method is called after the invocation of each test method in the class. + } + + func testExample() throws { + XCTAssertNotNil(MLXAudioEncoder()) + } +} From 211c8346a44721ab51551386b18c6847bb5e228b Mon Sep 17 00:00:00 2001 From: Jan Krukowski Date: Fri, 3 May 2024 06:47:08 +0200 Subject: [PATCH 02/29] Added MLX feature extractor implementation (#129) * Added MLX feature extractor implementation * CI fix * added better multiarray conversion * CI fix * CI fix * fixed `asMLMultiArray` implementation, fixed CI * update xcode, trigger pr when targeting not main branch * check if vision os builds * update watch os version * conditional watchos compilation * conditional package.swift * conditional package.swift * ci fix * ci fix * ci fix * ci fix * ci fix * ci fix * ci fix * ci fix * ci fix * ci fix * ci fix * add other tests targest back * package.swift cleanup * general cleanup * revert to xcode 15.2 --- .github/workflows/development-tests.yml | 1 - .github/workflows/unit-tests.yml | 22 ++- Package.resolved | 4 +- Package.swift | 132 ++++++++++++----- Sources/WhisperKit/MLX/MLXAudioEncoder.swift | 16 +- .../WhisperKit/MLX/MLXFeatureExtractor.swift | 140 ++++++++++++++++++ Sources/WhisperKit/MLX/MLXModels.swift | 9 ++ Sources/WhisperKit/MLX/MLXUtils.swift | 45 ++++++ .../MLX/Resources/mel_filters_128.npy | Bin 0 -> 103040 bytes .../MLX/Resources/mel_filters_80.npy | Bin 0 -> 64448 bytes Tests/WhisperKitMLXTests/MLXUnitTests.swift | 69 +++++++++ Tests/WhisperKitTests/MLXTests.swift | 22 --- 12 files changed, 380 insertions(+), 80 deletions(-) create mode 100644 Sources/WhisperKit/MLX/MLXFeatureExtractor.swift create mode 100644 Sources/WhisperKit/MLX/MLXModels.swift create mode 100644 Sources/WhisperKit/MLX/MLXUtils.swift create mode 100644 Sources/WhisperKit/MLX/Resources/mel_filters_128.npy create mode 100644 Sources/WhisperKit/MLX/Resources/mel_filters_80.npy create mode 100644 Tests/WhisperKitMLXTests/MLXUnitTests.swift delete mode 100644 Tests/WhisperKitTests/MLXTests.swift diff --git a/.github/workflows/development-tests.yml b/.github/workflows/development-tests.yml index 160a13d..2ea2345 100644 --- a/.github/workflows/development-tests.yml +++ b/.github/workflows/development-tests.yml @@ -2,7 +2,6 @@ name: Development Tests on: pull_request: - branches: ["main"] pull_request_review: types: [submitted] workflow_dispatch: diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 9287f61..c5faf3b 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -22,24 +22,32 @@ jobs: condition: true, clean-destination: "generic/platform=macOS", test-destination: "platform=macOS,arch=arm64", + mlx-disabled: "0", + scheme: "whisperkit-Package", } - { name: "iOS", condition: true, clean-destination: "generic/platform=iOS", test-destination: "platform=iOS Simulator,OS=${{ inputs.ios-version }},name=iPhone 15", + mlx-disabled: "0", + scheme: "whisperkit-Package", } - { name: "watchOS", condition: "${{ inputs.macos-runner == 'macos-14' }}", clean-destination: "generic/platform=watchOS", - test-destination: "platform=watchOS Simulator,OS=10.2,name=Apple Watch Ultra 2 (49mm)", + test-destination: "platform=watchOS Simulator,OS=10.5,name=Apple Watch Ultra 2 (49mm)", + mlx-disabled: "1", + scheme: "whisperkit", } - { name: "visionOS", condition: "${{ inputs.macos-runner == 'macos-14' }}", clean-destination: "generic/platform=visionOS", test-destination: "platform=visionOS Simulator,name=Apple Vision Pro", + mlx-disabled: "0", + scheme: "whisperkit-Package", } timeout-minutes: 20 steps: @@ -59,13 +67,19 @@ jobs: if: steps.model-cache.outputs.cache-hit != 'true' run: make download-model MODEL=tiny - name: Install and discover destinations + env: + MLX_DISABLED: ${{ matrix.run-config['mlx-disabled'] }} run: | + echo "Available schemes:" + xcodebuild -list xcodebuild -downloadAllPlatforms echo "Destinations for testing:" - xcodebuild test-without-building -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -showdestinations + export ${{ matrix.run-config['compiler-flags'] }} && xcodebuild test-without-building -only-testing WhisperKitTests/UnitTests -scheme ${{ matrix.run-config['scheme'] }} -showdestinations -skipPackagePluginValidation - name: Build and Test - ${{ matrix.run-config['name'] }} + env: + MLX_DISABLED: ${{ matrix.run-config['mlx-disabled'] }} if: ${{ matrix.run-config['condition'] == true }} run: | set -o pipefail - xcodebuild clean build-for-testing -scheme whisperkit-Package -destination '${{ matrix.run-config['clean-destination'] }}' | xcpretty - xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}' | xcpretty + xcodebuild clean build-for-testing -scheme ${{ matrix.run-config['scheme'] }} -destination "${{ matrix.run-config['clean-destination'] }}" -skipPackagePluginValidation | xcpretty + xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme ${{ matrix.run-config['scheme'] }} -destination "${{ matrix.run-config['test-destination'] }}" -skipPackagePluginValidation | xcpretty diff --git a/Package.resolved b/Package.resolved index 51c8fee..5c38ee0 100644 --- a/Package.resolved +++ b/Package.resolved @@ -5,8 +5,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "revision" : "7838f8cd93499f3d9e9a35b87cb74fe2664b325f", - "version" : "0.10.0" + "branch" : "main", + "revision" : "b43bdff8b6a413eb75e88eafd4a3995971a406fd" } }, { diff --git a/Package.swift b/Package.swift index 5fdfd12..460ce32 100644 --- a/Package.swift +++ b/Package.swift @@ -2,33 +2,69 @@ // The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription +import Foundation +// NOTE: `MLX` doesn't support `watchOS` yet, that's why we control the build using the `MLX_DISABLED` environment variable. +// To manualy build for `watchOS` use: +// `export MLX_DISABLED=1 && xcodebuild clean build-for-testing -scheme whisperkit -sdk watchos10.4 -destination 'platform=watchOS Simulator' -skipPackagePluginValidation` let package = Package( name: "whisperkit", platforms: [ .iOS(.v16), - .macOS("13.3"), + .macOS("13.3") ], - products: [ + products: products() + mlxProducts(), + dependencies: dependencies() + mlxDependencies(), + targets: targets() + mlxTargets() +) + +func products() -> [PackageDescription.Product] { + return [ .library( name: "WhisperKit", targets: ["WhisperKit"] - ), - .library( - name: "WhisperKitMLX", - targets: ["WhisperKitMLX"] - ), - .executable( - name: "whisperkit-cli", - targets: ["WhisperKitCLI"] - ), - ], - dependencies: [ + ) + ] +} + +func mlxProducts() -> [PackageDescription.Product] { + let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1" + if isMLXDisabled { + return [] + } else { + return [ + .library( + name: "WhisperKitMLX", + targets: ["WhisperKitMLX"] + ), + .executable( + name: "whisperkit-cli", + targets: ["WhisperKitCLI"] + ), + ] + } +} + +func dependencies() -> [PackageDescription.Package.Dependency] { + return [ .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"), - .package(url: "https://github.com/ml-explore/mlx-swift", exact: "0.10.0"), .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), - ], - targets: [ + ] +} + +func mlxDependencies() -> [PackageDescription.Package.Dependency] { + let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1" + if isMLXDisabled { + return [] + } else { + return [ + .package(url: "https://github.com/ml-explore/mlx-swift", branch: "main"), + ] + } +} + +func targets() -> [PackageDescription.Target] { + return [ .target( name: "WhisperKit", dependencies: [ @@ -36,30 +72,10 @@ let package = Package( ], path: "Sources/WhisperKit/Core" ), - .target( - name: "WhisperKitMLX", - dependencies: [ - "WhisperKit", - .product(name: "MLX", package: "mlx-swift"), - .product(name: "MLXRandom", package: "mlx-swift"), - .product(name: "MLXNN", package: "mlx-swift"), - .product(name: "MLXOptimizers", package: "mlx-swift"), - .product(name: "MLXFFT", package: "mlx-swift") - ], - path: "Sources/WhisperKit/MLX" - ), - .executableTarget( - name: "WhisperKitCLI", - dependencies: [ - "WhisperKit", - .product(name: "ArgumentParser", package: "swift-argument-parser"), - ] - ), .testTarget( name: "WhisperKitTests", dependencies: [ "WhisperKit", - "WhisperKitMLX", .product(name: "Transformers", package: "swift-transformers"), ], path: ".", @@ -70,11 +86,51 @@ let package = Package( "README.md", "LICENSE", "CONTRIBUTING.md", + "Tests/WhisperKitMLXTests" ], resources: [ .process("Tests/WhisperKitTests/Resources"), .copy("Models/whisperkit-coreml"), ] - ), + ) ] -) +} + +func mlxTargets() -> [PackageDescription.Target] { + let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1" + if isMLXDisabled { + return [] + } else { + return [ + .executableTarget( + name: "WhisperKitCLI", + dependencies: [ + "WhisperKit", + "WhisperKitMLX", + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ] + ), + .target( + name: "WhisperKitMLX", + dependencies: [ + "WhisperKit", + .product(name: "MLX", package: "mlx-swift"), + .product(name: "MLXFFT", package: "mlx-swift") + ], + path: "Sources/WhisperKit/MLX", + resources: [ + .copy("Resources/mel_filters_80.npy"), + .copy("Resources/mel_filters_128.npy") + ] + ), + .testTarget( + name: "WhisperKitMLXTests", + dependencies: [ + "WhisperKit", + "WhisperKitMLX", + .product(name: "Transformers", package: "swift-transformers"), + ] + ) + ] + } +} diff --git a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift index 5e57a01..eee2d0b 100644 --- a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift +++ b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift @@ -9,28 +9,18 @@ import WhisperKit public class MLXAudioEncoder: AudioEncoding { public var embedSize: Int? { - return 1234 + fatalError("Not implemented") } public var sequenceLength: Int? { - return 1234 + fatalError("Not implemented") } public init() {} public func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray? { // Make sure features is shape MultiArray (Float32 1 × {80,128} × 3000) -// guard let model else { -// throw WhisperError.modelsUnavailable() -// } try Task.checkCancellation() - -// let interval = Logging.beginSignpost("EncodeAudio", signposter: Logging.AudioEncoding.signposter) -// defer { Logging.endSignpost("EncodeAudio", interval: interval, signposter: Logging.AudioEncoding.signposter) } - - let modelInputs = MLXArray() - let outputFeatures = MLXArray() - let output = MLMultiArray() - return output + fatalError("Not implemented") } } diff --git a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift new file mode 100644 index 0000000..2f47cd8 --- /dev/null +++ b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift @@ -0,0 +1,140 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import Foundation +import MLX +import MLXFFT +import CoreML +import WhisperKit + +@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) +open class MLXFeatureExtractor: FeatureExtracting { + public let melCount: Int? + private let nFFT: Int + private let hopLength: Int + private let filters: MLXArray + + public init( + melCount: Int = 80, + nFFT: Int = 400, + hopLength: Int = 160 + ) { + self.melCount = melCount + self.nFFT = nFFT + self.hopLength = hopLength + self.filters = MLXFeatureExtractor.loadMelFilters(nMels: melCount) + } + + public func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? { + try Task.checkCancellation() + let input = inputAudio.withUnsafeBytes { ptr in + MLXArray(ptr, inputAudio.shape.map { $0.intValue }, type: Float.self) + } + let logMelSpectrogram = MLXFeatureExtractor.logMelSpectrogram( + audio: input, + filters: filters, + nMels: melCount ?? 80, + nFFT: nFFT, + hopLength: hopLength + ) + return try logMelSpectrogram.asMLMultiArray() + } +} + +extension MLXFeatureExtractor { + /// Return the Hanning window. + /// Taken from [numpy](https://numpy.org/doc/stable/reference/generated/numpy.hanning.html) implementation + public static func hanningNumpy(_ size: Int) -> MLXArray { + if size < 1 { + return MLXArray([Float]()) + } + if size == 1 { + return MLXArray([1.0] as [Float]) + } + let n = MLXArray(Array(stride(from: 1 - size, to: size, by: 2))) + return 0.5 + 0.5 * MLX.cos(MLXArray(.pi) * n / Float(size - 1)) + } + + public static func hanning(_ size: Int) -> MLXArray { + hanningNumpy(size + 1)[..<(-1)] + } + + public static func pad( + _ x: MLXArray, + padding: Int, + padMode: PadMode = .constant + ) -> MLXArray { + switch padMode { + case .constant: + return MLX.padded(x, widths: [IntOrPair((padding, padding))]) + case .reflect: + let prefix = x[1 ..< padding + 1][.stride(by: -1)] + let suffix = x[-(padding + 1) ..< -1][.stride(by: -1)] + return MLX.concatenated([prefix, x, suffix]) + } + } + + public static func stft( + _ x: MLXArray, + window: MLXArray, + nPerSeg: Int = 256, + nOverlap: Int? = nil, + nFFT: Int? = nil, + axis: Int = -1, + padMode: PadMode = .reflect + ) -> MLXArray { + let nFFT = nFFT ?? nPerSeg + let nOverlap = nOverlap ?? nFFT / 4 + + let padding = nPerSeg / 2 + let x = pad(x, padding: padding, padMode: padMode) + + let strides = [nOverlap, 1] + let t = (x.count - nPerSeg + nOverlap) / nOverlap + let shape = [t, nFFT] + return MLXFFT.rfft(MLX.asStrided(x, shape, strides: strides) * window) + } + + /// Compute the log mel spectrogram of audio + /// Taken from [MLX](https://github.com/ml-explore/mlx-examples/blob/c012eb173f0f632e369ec71f08be777df3aede08/whisper/whisper/audio.py#L130) implementation + public static func logMelSpectrogram( + audio: MLXArray, + filters: MLXArray, + nMels: Int = 80, + padding: Int = 0, + nFFT: Int = 400, + hopLength: Int = 160 + ) -> MLXArray { + let device = MLX.Device.defaultDevice() + MLX.Device.setDefault(device: .cpu) + defer { MLX.Device.setDefault(device: device) } + let window = hanning(nFFT) + let freqs = stft(audio, window: window, nPerSeg: nFFT, nOverlap: hopLength) + let magnitudes = freqs[..<(-1)].abs().square() + + let melSpec = magnitudes.matmul(filters.T) + + var logSpec = MLX.maximum(melSpec, 1e-10).log10() + logSpec = MLX.maximum(logSpec, logSpec.max() - 8.0) + logSpec = (logSpec + 4.0) / 4.0 + return logSpec + } + + /// Load the mel filterbank matrix for projecting STFT into a Mel spectrogram. + /// Allows decoupling librosa dependency. + /// + /// Saved using: + /// ```python + /// n80 = librosa.filters.mel(sr=16000, n_fft=400, n_mels=80) + /// n128 = librosa.filters.mel(sr=16000, n_fft=400, n_mels=128) + /// with open('mel_filters_80.npy', 'wb') as f: + /// np.save(f, n80) + /// with open('mel_filters_128.npy', 'wb') as f: + /// np.save(f, n128) + /// ``` + public static func loadMelFilters(nMels: Int) -> MLXArray { + precondition(nMels == 80 || nMels == 128, "Unsupported nMels: \(nMels)") + let fileUrl = Bundle.module.url(forResource: "mel_filters_\(nMels)", withExtension: "npy")! + return try! MLX.loadArray(url: fileUrl) + } +} diff --git a/Sources/WhisperKit/MLX/MLXModels.swift b/Sources/WhisperKit/MLX/MLXModels.swift new file mode 100644 index 0000000..e0534cc --- /dev/null +++ b/Sources/WhisperKit/MLX/MLXModels.swift @@ -0,0 +1,9 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import Foundation + +public enum PadMode { + case constant + case reflect +} diff --git a/Sources/WhisperKit/MLX/MLXUtils.swift b/Sources/WhisperKit/MLX/MLXUtils.swift new file mode 100644 index 0000000..2567c2c --- /dev/null +++ b/Sources/WhisperKit/MLX/MLXUtils.swift @@ -0,0 +1,45 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import Foundation +import MLX +import CoreML + +// MARK: - Extensions + +extension MLXArray { + func asMLMultiArray() throws -> MLMultiArray { + let dataType = multiArrayDataType() + // a buffer to be passed to CoreML + let buffer = UnsafeMutableRawPointer.allocate(byteCount: nbytes, alignment: 8) + // copy the data from the MLXArray backing into buffer + asData(noCopy: true).withUnsafeBytes { ptr in + let destination = UnsafeMutableRawBufferPointer(start: buffer, count: nbytes) + ptr.copyBytes(to: destination) + } + return try MLMultiArray( + dataPointer: buffer, + shape: shape.map { NSNumber(value: $0) }, + dataType: dataType, + strides: strides.map { NSNumber(value: $0) }, + deallocator: { $0.deallocate() } + ) + } +} + +extension MLXArray { + func multiArrayDataType() -> MLMultiArrayDataType { + switch dtype { + case .bool, .bfloat16, .complex64, + .uint8, .uint16, .uint32, .uint64, + .int8, .int16, .int64: + fatalError("Unsupported type: \(dtype)") + case .int32: + return .int32 + case .float16: + return .float16 + case .float32: + return .float32 + } + } +} diff --git a/Sources/WhisperKit/MLX/Resources/mel_filters_128.npy b/Sources/WhisperKit/MLX/Resources/mel_filters_128.npy new file mode 100644 index 0000000000000000000000000000000000000000..a17544243c1adae357bdf6e8d3722df5e8ed73b2 GIT binary patch literal 103040 zcmeI*dsNNo8VB%grgZFyD0M_xWkoTiB-Q>sFN-cj+HPYIF^x*OSB7TDP)v!^Q906W za=&yHO)f1OHE0Y;WejRu!WpK+8b!mHQ~t14tG(Y}>-_tE`$zkIwm+WF`abJ@-~8;7 z)H8ji&grIDtypUv5E#BN%v!RrW9R!Nf zjIft<5Wx!wbQUOFpN_jpXBMnL;JtwR%`(i$dkuU80nGv-O$wa3G}B-e0`CMoTU^<1 z_6ZU=AOHdF0_#oT9ap(5snW^5NH<|mvtR;f+#?M zBXGPV4l}~>K?wrw0)^|%aDPV>AixpuUYtVadJ60(93zw<;55ttdkJCy0U3b;|NYoY zWQ#)%W*dmI|ArlMxbuwBg}}55y}u~`-}+831R>NjzDqkM$8Du2PFu!3phNp z!2KOjfB;7z`^-Wb@KYb`CmbV`AaJ}yiM<3dfPjoZzDFMR5*Z_u#TJ-gRDc-~I|Asw z3j)E{df?2}1t+{kTmkRJ#;ktAOgzgaZUoSI7X%`2SYj{ff)m~%wt%_b2HZ`=jsUus z5%_s(J!VA42xSP232eliAO;ZN2rO?5#f)%#P=Y|aK!N{Y+}{xe2yg^^bj|3cjWzZY zjuA=_&|TXNdkJCy0U3cGO)g=>YX83CtI6P&qZj8KL^ z&YpcZb0G#0;0XNpSRDIG8iQvZIYuZ!AR@Lq&RmEA1Y`sT=AOfuOU4Lgu>}flJF$S% z_Lvi~V}R~u1d0co!i>lmp$vf{XC2&45CaHs1WaoGaFtw6V?W^-p#*^(AANRd*%z1- zL;(UE0l(qv*?_2zFee-%lptXDIES6<=8Jz{L=+&v5vV>`g}sF1gAxSV1%_Dsia9|P zAixpG$hMViz0PA!I7TQzpfT5<%o}nsCx`+BI08i}`7~!x7UqOwgc1bQ6*Xiwe<$Vy zQGftPV6Vk}+Wd3}=7eK}5(GkR@6ZvaWXuVo00E9bvs6Kahf^^p93zwuEnwaJZ)wibBdjWJ6Q1u9I|}GtMqtn7?QBH< z$CwitCzQn&2u}}Ug$_e;M-e*)=w3!3Wa}`d=;Mz$k#RyAiC_sQCQ1zKUs~Z){-uT7h?8R|H zNo;`+Rz}d`HQU(6kc&8bi5&%WFC(BlGL{Wo;*7JGj1$UY3l#rZO;*Lxtn6_d_7t(B zfbKa0Wu6(*)^5{SlJ!9BDI6!1AfR{P5jkF3#eCIUu%{pn5YQqJ5dVV|Smwdj`C4L6 z(PD&U2(%0MG+&@S>a}b}X)*Q`L;?a@1WxoVQ<)AO#Iiq>Ski6J%!_hlGp<3L^mmBoKf5FMv<85D14s1*l9rb zJON)F6|ujDv!sOnOlkfI&wlZ+P!n1p{?>L9KS1=72(qW0Br(MJ}g^Yr0& zcI-yuWx8hd0CyHq!+^eJ1!~IU=}J!nHje#G#YY2h4wDtb8$=a|Z?&Ncy?S~vv_Gpg z$fKV2S=d`d4Fme-3hc`Fl#DN=k=arM_U+LpR2M!Adka?#MUe$e%IvA3?f2AAn`zAY zoh=||^)2izA_oG!YZW-W%0!B+-a;dnD46@w7V;buMcZ`lsGQhs$l zbSxn$Opmn}L?w_OSgu#!LI8UE7sq$&{7aw6>_++RGDo~pFn@a7y zgcfGsrZpLFsor=u8LH}VcM&uXXqzukXS`W@uqc|Q$2HRBUcb}tZt>)N`vm5N&xWd? z0^t|)q=Hdf$o=F)x@C2r4&92O>l2S*UIfhp+U5&y6RO+H7)mOrC}mDmb}2c2$~19%@^K zLL;PfYWT80em=%$Lsd|L^+}deRhk=p7LrE_t5;N4pF^`px|88PN9-|zW&&+%7x+7R zs;h&u3E8A1ldC~Jy$&v;Ve`Z3RLup`5IM_hv6x@cq@|h zlJv-BdK}JWoyo8wxPZ5Dob-I7D`mb(A&-))WU4$wANwt!$&>G3j}bf@XkW9yln1A! zf>v)zS7gw*?$xB6o=%6!hawJ~$INIZ!>YgnO|P3>w#L;;J2Gd`Gs-0Y#g}N^=$&L| zIiAWB&SGW+&IVf7EO0Tam#W*ro04wnRPvayn>?OZkh5J1sjI!{^W-biCbtrN9+_r3 ztO_b{BkR2Ds(qEx>KG52Q~xh&i72B@3A-rD_7i%wrb#N8TaBLq2$~JFtzF>$nIlra zeh${H8dg&c}!`#5? ze>=K6JK4HAyL5N5U1t0Ly&9WTzf$&GkZ=FyivkpQw}5f6DH0|hVNc%8!>bhdTfn$D z604eJuqVs|1r!2%vQOjRC$_LB3L}mvQ2pE`XU>gfPZ$LXCMqFE&IMK+S zXrBZEs03C%%tA%k2H7y$o$papP@HKkV9Y!(mY$84&AL^xCt4?g%hx856m6M2ZrSS7RX#_>$iItg4}B{1T}XJREr%lnOOc&4bJIMZIh^-(E|Q;pKS zG@fUQ_GutMO#P zAM)kAJNVwJAI}zD!@#|36-dk-i@dHC_-m?#TskKSbIRuOY^ha^H|QwPVE=Zp>E(8e zJ8L12&%Onhdp4n)7?PNV(v$(H8s^9|Mz=6<-&zGGzi{<<<I6i7-2EJh<&M5k2mNjP-S*hwED~d->Cg)_o5i}v%W$}u`5p6k7j3d z3VQ=x&Ys6<@M5N<9vB6c}*DQ@pruimoNTaGbFLx#5}U zlduXy$9dvY^IIY^<`tjMbPNV}u2Eo@+XRoB=d#6)z78lhn-8;)ZE&5w8EvixqoBMK zp1s~JDqUUJn;Hc<&`}^}Ll05fqEzJe@IZ{iayTDO#@fo&sGK(zr>mO5v3Y_>7?$g< zqYZbaKurP-8%-0wd6lBo8Q^8gGW?Od6~z~#(bIJpGA=g7nD4fTpbpb`-qZxh>sks( z|0&{N-(N*pk_-O)W)7BLPJm}-G;XGi!hlLEh$RQbUpA)vTtLfoaOs)^_BMAG31$T% z-P9V5B7KoPDJ*g9p7h_>IwXSJFTd0jh!yd60ng~K+9$G^W3 zF+=;{tE6D~EQ&*5*nB)~=Y`KtG)LcyIU>u|iD!;>>EOaDfpF*E;%3Bu#j&iWh;tZ# z$%(UJZ5M+RsRj&s&kcD24X{2YL-dR9$nK~hInzoY*6Ff)%JSi&WZ6lv-`WC;syrY| zg5Z525^=#3u}1X9x~2_ae{7!^liN>3rrPqEO{;`(*_s5#tt@hnzBf>eipvuJ@u`o# zx1De_)DK>-BJlLsB=kSk2aC@%MOb`+xZA@hF3sD?&)90h zwq~})zV|=C>C^xe-w(orTOZZ#mw?l$7>}*x zV}#F(_BcEJLjEUycjxFj*CwGf?L zLU{gYmlQ7iE&=I%(qsDciDLbzbnztgzR2*l#H3rj6H*JNP-sNJ= Qg1us<_pIta<$w3*e>A Date: Wed, 22 May 2024 08:50:31 +0200 Subject: [PATCH 03/29] Added MLX Audio Encoder (#139) * added mlx audio encoder * fixed model protocols * removed not needed --- .github/workflows/unit-tests.yml | 4 +- .gitignore | 3 +- Makefile | 26 ++++- Package.resolved | 2 +- Package.swift | 29 +++--- Sources/WhisperKit/Core/Models.swift | 11 +- Sources/WhisperKit/MLX/Attention.swift | 95 ++++++++++++++++++ Sources/WhisperKit/MLX/MLXAudioEncoder.swift | 90 +++++++++++++++-- Sources/WhisperKit/MLX/MLXModels.swift | 21 ++++ Sources/WhisperKit/MLX/MLXUtils.swift | 29 ++++++ .../Resources/es_test_clip.wav | Bin .../Resources/ja_test_clip.wav | Bin .../WhisperKitTestsUtils}/Resources/jfk.wav | Bin .../WhisperKitTestsUtils}/TestUtils.swift | 64 +++++++++--- Tests/WhisperKitMLXTests/MLXUnitTests.swift | 46 ++++++++- Tests/WhisperKitTests/FunctionalTests.swift | 27 ++--- Tests/WhisperKitTests/RegressionTests.swift | 1 + Tests/WhisperKitTests/UnitTests.swift | 23 ++--- 18 files changed, 400 insertions(+), 71 deletions(-) create mode 100644 Sources/WhisperKit/MLX/Attention.swift rename {Tests/WhisperKitTests => Sources/WhisperKitTestsUtils}/Resources/es_test_clip.wav (100%) rename {Tests/WhisperKitTests => Sources/WhisperKitTestsUtils}/Resources/ja_test_clip.wav (100%) rename {Tests/WhisperKitTests => Sources/WhisperKitTestsUtils}/Resources/jfk.wav (100%) rename {Tests/WhisperKitTests => Sources/WhisperKitTestsUtils}/TestUtils.swift (82%) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index c5faf3b..d92f219 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -61,11 +61,11 @@ jobs: id: model-cache uses: actions/cache@v4 with: - path: Models + path: Sources/WhisperKitTestsUtils/Models key: ${{ runner.os }}-models - name: Download Models if: steps.model-cache.outputs.cache-hit != 'true' - run: make download-model MODEL=tiny + run: make download-model MODEL=tiny && make download-mlx-model - name: Install and discover destinations env: MLX_DISABLED: ${{ matrix.run-config['mlx-disabled'] }} diff --git a/.gitignore b/.gitignore index 9ce9f61..fd725dc 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ Models **/*.mp3 **/*.m4a **/*.flac +!Sources/WhisperKitTestsUtils/Resources/*.* ## Xcode # Build generated @@ -62,4 +63,4 @@ fastlane/test_output !*.xcodeproj/project.pbxproj !*.xcodeproj/xcshareddata/ !*.xcworkspace/contents.xcworkspacedata -/*.gcno \ No newline at end of file +/*.gcno diff --git a/Makefile b/Makefile index 8b5af9e..6f33117 100644 --- a/Makefile +++ b/Makefile @@ -5,8 +5,10 @@ PYTHON_COMMAND := python3 # Define model repository and directories MODEL_REPO := argmaxinc/whisperkit-coreml -MODEL_REPO_DIR := ./Models/whisperkit-coreml -BASE_COMPILED_DIR := ./Models +MLX_MODEL_REPO := jkrukowski/whisper-tiny-mlx-safetensors +MLX_MODEL_REPO_DIR := ./Sources/WhisperKitTestsUtils/Models/mlx/whisper-tiny-mlx +MODEL_REPO_DIR := ./Sources/WhisperKitTestsUtils/Models/whisperkit-coreml +BASE_COMPILED_DIR := ./Sources/WhisperKitTestsUtils/Models setup: @@ -56,6 +58,19 @@ setup-model-repo: git clone https://huggingface.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \ fi +setup-mlx-model-repo: + @echo "Setting up mlx repository..." + @mkdir -p $(BASE_COMPILED_DIR) + @if [ -d "$(MLX_MODEL_REPO_DIR)/.git" ]; then \ + echo "Repository exists, resetting..."; \ + export GIT_LFS_SKIP_SMUDGE=1; \ + cd $(MLX_MODEL_REPO_DIR) && git fetch --all && git reset --hard origin/main && git clean -fdx; \ + else \ + echo "Repository not found, initializing..."; \ + export GIT_LFS_SKIP_SMUDGE=1; \ + git clone https://huggingface.co/$(MLX_MODEL_REPO) $(MLX_MODEL_REPO_DIR); \ + fi + # Download all models download-models: setup-model-repo @echo "Downloading all models..." @@ -74,6 +89,13 @@ download-model: @cd $(MODEL_REPO_DIR) && \ git lfs pull --include="openai_whisper-$(MODEL)/*" +download-mlx-model: + @echo "Downloading mlx model $(MODEL)..." + @$(MAKE) setup-mlx-model-repo + @echo "Fetching mlx model $(MODEL)..." + @cd $(MLX_MODEL_REPO_DIR) && \ + git lfs pull + build: @echo "Building WhisperKit..." @swift build -v diff --git a/Package.resolved b/Package.resolved index 5c38ee0..d6aa442 100644 --- a/Package.resolved +++ b/Package.resolved @@ -6,7 +6,7 @@ "location" : "https://github.com/ml-explore/mlx-swift", "state" : { "branch" : "main", - "revision" : "b43bdff8b6a413eb75e88eafd4a3995971a406fd" + "revision" : "3c802c808d281c191d5f26f37a4f93135d8ca119" } }, { diff --git a/Package.swift b/Package.swift index 460ce32..dd76229 100644 --- a/Package.swift +++ b/Package.swift @@ -72,25 +72,24 @@ func targets() -> [PackageDescription.Target] { ], path: "Sources/WhisperKit/Core" ), - .testTarget( - name: "WhisperKitTests", + .target( + name: "WhisperKitTestsUtils", dependencies: [ "WhisperKit", .product(name: "Transformers", package: "swift-transformers"), ], - path: ".", - exclude: [ - "Examples", - "Sources", - "Makefile", - "README.md", - "LICENSE", - "CONTRIBUTING.md", - "Tests/WhisperKitMLXTests" - ], resources: [ - .process("Tests/WhisperKitTests/Resources"), .copy("Models/whisperkit-coreml"), + .copy("Models/mlx"), + .process("Resources"), + ] + ), + .testTarget( + name: "WhisperKitTests", + dependencies: [ + "WhisperKit", + "WhisperKitTestsUtils", + .product(name: "Transformers", package: "swift-transformers"), ] ) ] @@ -115,7 +114,8 @@ func mlxTargets() -> [PackageDescription.Target] { dependencies: [ "WhisperKit", .product(name: "MLX", package: "mlx-swift"), - .product(name: "MLXFFT", package: "mlx-swift") + .product(name: "MLXFFT", package: "mlx-swift"), + .product(name: "MLXNN", package: "mlx-swift") ], path: "Sources/WhisperKit/MLX", resources: [ @@ -128,6 +128,7 @@ func mlxTargets() -> [PackageDescription.Target] { dependencies: [ "WhisperKit", "WhisperKitMLX", + "WhisperKitTestsUtils", .product(name: "Transformers", package: "swift-transformers"), ] ) diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index d11a20b..58a7f29 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -20,10 +20,17 @@ extension Float16: MLShapedArrayScalar {} // MARK: - CoreML -public protocol WhisperMLModel { +public protocol WhisperModel { + mutating func unloadModel() +} + +public protocol WhisperMLModel: WhisperModel { var model: MLModel? { get set } mutating func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws - mutating func unloadModel() +} + +public protocol WhisperMLXModel: WhisperModel { + mutating func loadModel(at modelPath: URL) async throws } public extension WhisperMLModel { diff --git a/Sources/WhisperKit/MLX/Attention.swift b/Sources/WhisperKit/MLX/Attention.swift new file mode 100644 index 0000000..c38c3fa --- /dev/null +++ b/Sources/WhisperKit/MLX/Attention.swift @@ -0,0 +1,95 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import Foundation +import MLX +import MLXNN + +final class MultiHeadAttention: Module { + let nHead: Int + let query: Linear + let key: Linear + let value: Linear + let out: Linear + + init(nState: Int, nHead: Int) { + self.nHead = nHead + self.query = Linear(nState, nState) + self.key = Linear(nState, nState, bias: false) + self.value = Linear(nState, nState) + self.out = Linear(nState, nState) + } + + func callAsFunction( + _ x: MLXArray, + xa: MLXArray? = nil, + mask: MLXArray? = nil, + kvCache: MLXArray? = nil + ) -> (MLXArray, (MLXArray, MLXArray), MLXArray) { + let q = query(x) + + var k: MLXArray + var v: MLXArray + if let xa { + k = key(xa) + v = value(xa) + } else { + k = key(x) + v = value(x) + if let kvCache { + k = MLX.concatenated([kvCache[0], k], axis: 1) + v = MLX.concatenated([kvCache[1], v], axis: 1) + } + } + + let (wv, qk) = qkvAttention(q, k, v, mask) + return (out(wv), (k, v), qk) + } + + private func qkvAttention(_ q: MLXArray, _ k: MLXArray, _ v: MLXArray, _ mask: MLXArray?) -> (MLXArray, MLXArray) { + let (nBatch, nCtx, nState) = (q.shape[0], q.shape[1], q.shape[2]) + let scale = pow(Float(nState / nHead), -0.25) + let q = q.reshaped([q.shape[0], q.shape[1], nHead, -1]).transposed(0, 2, 1, 3) * scale + let k = k.reshaped([k.shape[0], k.shape[1], nHead, -1]).transposed(0, 2, 3, 1) * scale + let v = v.reshaped([v.shape[0], v.shape[1], nHead, -1]).transposed(0, 2, 1, 3) + var qk = q.matmul(k) + if let mask { + qk = qk + mask[0.. ResidualAttentionBlockResult { + let (kvCache, crossKv) = kvCache ?? (nil, nil) + let (y, kv, _) = attn(attn_ln(x), mask: mask, kvCache: kvCache) + var x = x + y + x = x + mlp2(gelu(mlp1(mlp_ln(x)))) + return ResidualAttentionBlockResult(x: x, kv: kv, crossKv: crossKv, crossQk: nil) + } +} diff --git a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift index eee2d0b..f95609e 100644 --- a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift +++ b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift @@ -4,23 +4,97 @@ import CoreML import MLX import WhisperKit +import MLXNN @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public class MLXAudioEncoder: AudioEncoding { - public var embedSize: Int? { - fatalError("Not implemented") - } - - public var sequenceLength: Int? { - fatalError("Not implemented") + encoder?.nState } + private var encoder: AudioEncoder? public init() {} public func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray? { - // Make sure features is shape MultiArray (Float32 1 × {80,128} × 3000) + guard let encoder else { + throw WhisperError.modelsUnavailable() + } try Task.checkCancellation() - fatalError("Not implemented") + let input = features.withUnsafeBytes { ptr in + MLXArray(ptr, features.shape.map { $0.intValue }, type: FloatType.self) + } + let ouput = encoder(input) + return try ouput.asMLMultiArray() + } +} + +extension MLXAudioEncoder: WhisperMLXModel { + public func loadModel(at modelPath: URL) async throws { + let parameters = try loadParameters(at: modelPath.appending(path: "weights.safetensors"), forKey: "encoder") + let config = try loadConfig(at: modelPath.appending(path: "config.json")) + let encoder = AudioEncoder( + nMels: config.nMels, + nCtx: config.nAudioCtx, + nState: config.nAudioState, + nHead: config.nAudioHead, + nLayer: config.nAudioLayer, + dType: .float16 + ) + let loadedEncoder = try encoder.update(parameters: parameters, verify: [.noUnusedKeys]) + MLX.eval(loadedEncoder) + self.encoder = encoder + } + + public func unloadModel() { + encoder = nil + } +} + +final class AudioEncoder: Module { + let nMels: Int + let nCtx: Int + let nState: Int + let nHead: Int + let nLayer: Int + let dType: MLX.DType + + private let conv1: Conv1d + private let conv2: Conv1d + private let positionalEmbedding: MLXArray + private let blocks: [ResidualAttentionBlock] + private let ln_post: LayerNorm + + init( + nMels: Int, + nCtx: Int, + nState: Int, + nHead: Int, + nLayer: Int, + dType: MLX.DType = .float16 + ) { + self.nMels = nMels + self.nCtx = nCtx + self.nState = nState + self.nHead = nHead + self.nLayer = nLayer + self.dType = dType + + self.conv1 = Conv1d(inputChannels: nMels, outputChannels: nState, kernelSize: 3, padding: 1) + self.conv2 = Conv1d(inputChannels: nState, outputChannels: nState, kernelSize: 3, stride: 2, padding: 1) + self.positionalEmbedding = sinusoids(length: nCtx, channels: nState).asType(dType) + self.blocks = (0.. MLXArray { + var x = MLXNN.gelu(conv1(x)) + x = MLXNN.gelu(conv2(x)) + assert(Array(x.shape[1...]) == positionalEmbedding.shape, "incorrect audio shape") + x = x + positionalEmbedding + for block in blocks { + x = block(x).x + } + x = ln_post(x) + return x } } diff --git a/Sources/WhisperKit/MLX/MLXModels.swift b/Sources/WhisperKit/MLX/MLXModels.swift index e0534cc..cdd8a64 100644 --- a/Sources/WhisperKit/MLX/MLXModels.swift +++ b/Sources/WhisperKit/MLX/MLXModels.swift @@ -2,8 +2,29 @@ // Copyright © 2024 Argmax, Inc. All rights reserved. import Foundation +import MLX public enum PadMode { case constant case reflect } + +struct ModelConfig: Codable { + let nMels: Int + let nAudioCtx: Int + let nAudioState: Int + let nAudioHead: Int + let nAudioLayer: Int + let nVocab: Int + let nTextCtx: Int + let nTextState: Int + let nTextHead: Int + let nTextLayer: Int +} + +struct ResidualAttentionBlockResult { + var x: MLXArray + var kv: (MLXArray, MLXArray) + var crossKv: MLXArray? + var crossQk: MLXArray? +} diff --git a/Sources/WhisperKit/MLX/MLXUtils.swift b/Sources/WhisperKit/MLX/MLXUtils.swift index 2567c2c..6849643 100644 --- a/Sources/WhisperKit/MLX/MLXUtils.swift +++ b/Sources/WhisperKit/MLX/MLXUtils.swift @@ -3,6 +3,7 @@ import Foundation import MLX +import MLXNN import CoreML // MARK: - Extensions @@ -43,3 +44,31 @@ extension MLXArray { } } } + +// MARK: - Functions + +func sinusoids(length: Int, channels: Int, maxTimescale: Int = 10000) -> MLXArray { + assert(channels % 2 == 0) + let logTimescaleIncrement = log(Float(maxTimescale)) / Float(channels / 2 - 1) + let invTimescales = MLX.exp(-logTimescaleIncrement * MLXArray(Array(0..<(channels / 2)))) + let scaledTime = MLXArray(Array(0.. NestedDictionary { + let arrays = try MLX.loadArrays(url: url) + let params = ModuleParameters.unflattened(arrays) + guard let key else { + return params + } + guard let keyParams = params[key] else { + throw CocoaError.error(.coderValueNotFound) + } + return NestedDictionary(item: keyParams) +} + +func loadConfig(at url: URL) throws -> ModelConfig { + let configDecoder = JSONDecoder() + configDecoder.keyDecodingStrategy = .convertFromSnakeCase + return try configDecoder.decode(ModelConfig.self, from: Data(contentsOf: url)) +} diff --git a/Tests/WhisperKitTests/Resources/es_test_clip.wav b/Sources/WhisperKitTestsUtils/Resources/es_test_clip.wav similarity index 100% rename from Tests/WhisperKitTests/Resources/es_test_clip.wav rename to Sources/WhisperKitTestsUtils/Resources/es_test_clip.wav diff --git a/Tests/WhisperKitTests/Resources/ja_test_clip.wav b/Sources/WhisperKitTestsUtils/Resources/ja_test_clip.wav similarity index 100% rename from Tests/WhisperKitTests/Resources/ja_test_clip.wav rename to Sources/WhisperKitTestsUtils/Resources/ja_test_clip.wav diff --git a/Tests/WhisperKitTests/Resources/jfk.wav b/Sources/WhisperKitTestsUtils/Resources/jfk.wav similarity index 100% rename from Tests/WhisperKitTests/Resources/jfk.wav rename to Sources/WhisperKitTestsUtils/Resources/jfk.wav diff --git a/Tests/WhisperKitTests/TestUtils.swift b/Sources/WhisperKitTestsUtils/TestUtils.swift similarity index 82% rename from Tests/WhisperKitTests/TestUtils.swift rename to Sources/WhisperKitTestsUtils/TestUtils.swift index 62aeb4c..67381a1 100644 --- a/Tests/WhisperKitTests/TestUtils.swift +++ b/Sources/WhisperKitTestsUtils/TestUtils.swift @@ -3,13 +3,43 @@ import Foundation @testable import WhisperKit import XCTest -enum TestError: Error { +public enum TestError: Error { case missingFile(String) case missingDirectory(String) } +public enum TestResource { + public static func path(forResource resource: String?, ofType type: String?) -> String? { + Bundle.module.path(forResource: resource, ofType: type) + } + + public static func url(forResource resource: String?, withExtension ext: String?) -> URL? { + Bundle.module.url(forResource: resource, withExtension: ext) + } +} + +public func XCTAssertEqual( + _ expression1: @autoclosure () throws -> [T], + _ expression2: @autoclosure () throws -> [T], + accuracy: T, + _ message: @autoclosure () -> String = "", + file: StaticString = #filePath, + line: UInt = #line +) { + do { + let lhsEvaluated = try expression1() + let rhsEvaluated = try expression2() + XCTAssertEqual(lhsEvaluated.count, rhsEvaluated.count, file: file, line: line) + for (lhs, rhs) in zip(lhsEvaluated, rhsEvaluated) { + XCTAssertEqual(lhs, rhs, accuracy: accuracy, file: file, line: line) + } + } catch { + XCTFail("Unexpected error: \(error)", file: file, line: line) + } +} + @discardableResult -func XCTUnwrapAsync( +public func XCTUnwrapAsync( _ expression: @autoclosure () async throws -> T, _ message: @autoclosure () -> String = "", file: StaticString = #filePath, @@ -20,7 +50,7 @@ func XCTUnwrapAsync( } @discardableResult -func XCTUnwrapAsync( +public func XCTUnwrapAsync( _ expression: @autoclosure () async throws -> T?, _ message: @autoclosure () -> String = "", file: StaticString = #filePath, @@ -30,7 +60,7 @@ func XCTUnwrapAsync( return try XCTUnwrap(evaluated, message(), file: file, line: line) } -func XCTAssertNoThrowAsync( +public func XCTAssertNoThrowAsync( _ expression: @autoclosure () async throws -> T, _ message: @autoclosure () -> String = "", file: StaticString = #filePath, @@ -43,7 +73,7 @@ func XCTAssertNoThrowAsync( } } -func XCTAssertNoThrowAsync( +public func XCTAssertNoThrowAsync( _ expression: @autoclosure () async throws -> T?, _ message: @autoclosure () -> String = "", file: StaticString = #filePath, @@ -56,7 +86,7 @@ func XCTAssertNoThrowAsync( } } -func XCTAssertNoThrowAsync( +public func XCTAssertNoThrowAsync( _ expression: @autoclosure () async throws -> Void, _ message: @autoclosure () -> String = "", file: StaticString = #filePath, @@ -72,7 +102,7 @@ func XCTAssertNoThrowAsync( // MARK: Helpers @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -extension MLMultiArray { +public extension MLMultiArray { /// Create `MLMultiArray` of shape [1, 1, arr.count] and fill up the last /// dimension with with values from arr. static func logits(_ arr: [FloatType]) throws -> MLMultiArray { @@ -100,7 +130,7 @@ extension MLMultiArray { } @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -extension XCTestCase { +public extension XCTestCase { func transcribe( with variant: ModelVariant, options: DecodingOptions, @@ -139,6 +169,14 @@ extension XCTestCase { return modelPath } + func tinyMLXModelPath() throws -> String { + let modelDir = "mlx/whisper-tiny-mlx" + guard let modelPath = Bundle.module.urls(forResourcesWithExtension: "safetensors", subdirectory: modelDir)?.first?.deletingLastPathComponent().path else { + throw TestError.missingFile("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-models`") + } + return modelPath + } + func largev3ModelPath() throws -> String { let modelDir = "whisperkit-coreml/openai_whisper-large-v3" // use faster to compile model for tests guard let modelPath = Bundle.module.urls(forResourcesWithExtension: "mlmodelc", subdirectory: modelDir)?.first?.deletingLastPathComponent().path else { @@ -206,7 +244,7 @@ extension XCTestCase { } } -extension SpecialTokens { +public extension SpecialTokens { static func `default`( endToken: Int = 0, englishToken: Int = 0, @@ -236,7 +274,7 @@ extension SpecialTokens { } } -extension Result { +public extension Result { var isSuccess: Bool { switch self { case .success: @@ -256,19 +294,19 @@ extension Result { } } -extension Result where Success == [TranscriptionResult] { +public extension Result where Success == [TranscriptionResult] { func normalizedText(prefix: Int) throws -> String { try get().text.normalized.split(separator: " ").prefix(prefix).joined(separator: " ") } } -extension Collection where Element == TranscriptionResult { +public extension Collection where Element == TranscriptionResult { var text: String { map(\.text).joined(separator: " ") } } -extension Collection where Element == TranscriptionResult { +public extension Collection where Element == TranscriptionResult { var segments: [TranscriptionSegment] { flatMap(\.segments) } diff --git a/Tests/WhisperKitMLXTests/MLXUnitTests.swift b/Tests/WhisperKitMLXTests/MLXUnitTests.swift index 2d9c356..80e9a68 100644 --- a/Tests/WhisperKitMLXTests/MLXUnitTests.swift +++ b/Tests/WhisperKitMLXTests/MLXUnitTests.swift @@ -3,11 +3,15 @@ import XCTest import MLX +import WhisperKitTestsUtils +import CoreML @testable import WhisperKit @testable import WhisperKitMLX final class MLXUnitTests: XCTestCase { + private let accuracy: Float = 0.00001 + // MARK: - Feature Extractor Tests func testLogmelOutput() async throws { @@ -25,6 +29,21 @@ final class MLXUnitTests: XCTestCase { XCTAssertEqual(melSpectrogram.shape, expectedShape, "Mel spectrogram shape is not as expected") } + // MARK: - Encoder Tests + + func testEncoderOutput() async throws { + let audioEncoder = MLXAudioEncoder() + let modelPath = try URL(filePath: tinyMLXModelPath()) + try await audioEncoder.loadModel(at: modelPath) + + let encoderInput = try MLMultiArray(shape: [1, 3000, 80], dataType: .float16) + let expectedShape: [NSNumber] = [1, 1500, 384] + + let encoderOutput = try await audioEncoder.encodeFeatures(encoderInput) + XCTAssertNotNil(encoderOutput, "Failed to encode features") + XCTAssertEqual(encoderOutput?.shape, expectedShape, "Encoder output shape is not as expected") + } + // MARK: - Utils Tests func testAsMLMultiArray() throws { @@ -37,7 +56,7 @@ final class MLXUnitTests: XCTestCase { for col in 0.. Date: Tue, 28 May 2024 16:15:25 -0400 Subject: [PATCH 04/29] Updates for merge --- Tests/WhisperKitTests/FunctionalTests.swift | 29 --------------------- Tests/WhisperKitTests/RegressionTests.swift | 29 +++++++++++++++++++++ Tests/WhisperKitTests/UnitTests.swift | 8 +++--- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/Tests/WhisperKitTests/FunctionalTests.swift b/Tests/WhisperKitTests/FunctionalTests.swift index 4482e94..45eadae 100644 --- a/Tests/WhisperKitTests/FunctionalTests.swift +++ b/Tests/WhisperKitTests/FunctionalTests.swift @@ -14,35 +14,6 @@ final class FunctionalTests: XCTestCase { ) } - func testOutputAll() async throws { - let modelPaths = try allModelPaths() - - for modelPath in modelPaths { - let modelName = modelPath.split(separator: "/").last! - print("[Integration] Testing model \(modelName)") - let audioFilePath = try XCTUnwrap( - TestResource.path(forResource: "jfk", ofType: "wav"), - "Audio file not found" - ) - - let whisperKit = try await WhisperKit( - modelFolder: modelPath, - verbose: true, - logLevel: .debug - ) - - let transcriptionResult: [TranscriptionResult] = try await whisperKit.transcribe(audioPath: audioFilePath) - let transcriptionResultText = transcriptionResult.text - - print("[Integration] \(transcriptionResultText)") - XCTAssertEqual( - transcriptionResultText.normalized, - " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.".normalized, - "Transcription result does not match expected result for model \(modelName)" - ) - } - } - func testRealTimeFactorTiny() async throws { let modelPath = try tinyModelPath() diff --git a/Tests/WhisperKitTests/RegressionTests.swift b/Tests/WhisperKitTests/RegressionTests.swift index 1c8bdb6..1cc1ad2 100644 --- a/Tests/WhisperKitTests/RegressionTests.swift +++ b/Tests/WhisperKitTests/RegressionTests.swift @@ -24,6 +24,35 @@ final class RegressionTests: XCTestCase { wait(for: [expectation], timeout: 30) } } + + func testOutputAll() async throws { + let modelPaths = try allModelPaths() + + for modelPath in modelPaths { + let modelName = modelPath.split(separator: "/").last! + print("[Integration] Testing model \(modelName)") + let audioFilePath = try XCTUnwrap( + TestResource.path(forResource: "jfk", ofType: "wav"), + "Audio file not found" + ) + + let whisperKit = try await WhisperKit( + modelFolder: modelPath, + verbose: true, + logLevel: .debug + ) + + let transcriptionResult: [TranscriptionResult] = try await whisperKit.transcribe(audioPath: audioFilePath) + let transcriptionResultText = transcriptionResult.text + + print("[Integration] \(transcriptionResultText)") + XCTAssertEqual( + transcriptionResultText.normalized, + " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.".normalized, + "Transcription result does not match expected result for model \(modelName)" + ) + } + } func downloadTestAudio(completion: @escaping (Bool) -> Void) { Task { diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index ce3c280..2b0fa16 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -657,7 +657,7 @@ final class UnitTests: XCTestCase { for language in targetLanguages { let audioFilePath = try XCTUnwrap( - Bundle.module.path(forResource: "\(language)_test_clip", ofType: "wav"), + TestResource.path(forResource: "\(language)_test_clip", ofType: "wav"), "Audio file not found" ) @@ -1016,7 +1016,7 @@ final class UnitTests: XCTestCase { XCTAssertTrue(vad.voiceActivity(in: []).isEmpty) let audioFilePath = try XCTUnwrap( - Bundle.module.path(forResource: "jfk", ofType: "wav"), + TestResource.path(forResource: "jfk", ofType: "wav"), "Audio file not found" ) let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioFilePath) @@ -1121,7 +1121,7 @@ final class UnitTests: XCTestCase { Logging.shared.logLevel = .debug let singleChunkPath = try XCTUnwrap( - Bundle.module.path(forResource: "jfk", ofType: "wav"), + TestResource.path(forResource: "jfk", ofType: "wav"), "Audio file not found" ) var audioBuffer = try AudioProcessor.loadAudio(fromPath: singleChunkPath) @@ -1136,7 +1136,7 @@ final class UnitTests: XCTestCase { XCTAssertEqual(audioChunks.count, 1) let multiChunkPath = try XCTUnwrap( - Bundle.module.path(forResource: "ted_60", ofType: "m4a"), + TestResource.path(forResource: "ted_60", ofType: "m4a"), "Audio file not found" ) audioBuffer = try AudioProcessor.loadAudio(fromPath: multiChunkPath) From 470e227dc06e8fa18069be3c8806bbb983eaba88 Mon Sep 17 00:00:00 2001 From: Jan Krukowski Date: Tue, 4 Jun 2024 20:33:31 +0200 Subject: [PATCH 05/29] Allow MLX and CoreML to coexist (#156) * fixes for mlx models * fixed asMLXArray * fixed tests, mlx doesn't run on simulators * fix --- .github/workflows/unit-tests.yml | 14 ++++++--- Sources/WhisperKit/Core/AudioEncoder.swift | 2 +- Sources/WhisperKit/Core/WhisperKit.swift | 12 +++++++ Sources/WhisperKit/MLX/MLXAudioEncoder.swift | 9 +++--- .../WhisperKit/MLX/MLXFeatureExtractor.swift | 8 ++--- Sources/WhisperKit/MLX/MLXModels.swift | 2 +- Sources/WhisperKit/MLX/MLXUtils.swift | 31 +++++++++++++++++-- Sources/WhisperKitCLI/TranscribeCLI.swift | 1 + Tests/WhisperKitMLXTests/MLXUnitTests.swift | 22 +++++++++++-- 9 files changed, 79 insertions(+), 22 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index d92f219..970fedc 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -22,6 +22,7 @@ jobs: condition: true, clean-destination: "generic/platform=macOS", test-destination: "platform=macOS,arch=arm64", + test-cases: "-only-testing WhisperKitTests/UnitTests -only-testing WhisperKitMLXTests/MLXUnitTests", mlx-disabled: "0", scheme: "whisperkit-Package", } @@ -30,14 +31,16 @@ jobs: condition: true, clean-destination: "generic/platform=iOS", test-destination: "platform=iOS Simulator,OS=${{ inputs.ios-version }},name=iPhone 15", - mlx-disabled: "0", - scheme: "whisperkit-Package", + test-cases: "-only-testing WhisperKitTests/UnitTests", + mlx-disabled: "1", + scheme: "whisperkit", } - { name: "watchOS", condition: "${{ inputs.macos-runner == 'macos-14' }}", clean-destination: "generic/platform=watchOS", test-destination: "platform=watchOS Simulator,OS=10.5,name=Apple Watch Ultra 2 (49mm)", + test-cases: "-only-testing WhisperKitTests/UnitTests", mlx-disabled: "1", scheme: "whisperkit", } @@ -46,8 +49,9 @@ jobs: condition: "${{ inputs.macos-runner == 'macos-14' }}", clean-destination: "generic/platform=visionOS", test-destination: "platform=visionOS Simulator,name=Apple Vision Pro", - mlx-disabled: "0", - scheme: "whisperkit-Package", + test-cases: "-only-testing WhisperKitTests/UnitTests", + mlx-disabled: "1", + scheme: "whisperkit", } timeout-minutes: 20 steps: @@ -82,4 +86,4 @@ jobs: run: | set -o pipefail xcodebuild clean build-for-testing -scheme ${{ matrix.run-config['scheme'] }} -destination "${{ matrix.run-config['clean-destination'] }}" -skipPackagePluginValidation | xcpretty - xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme ${{ matrix.run-config['scheme'] }} -destination "${{ matrix.run-config['test-destination'] }}" -skipPackagePluginValidation | xcpretty + xcodebuild test ${{ matrix.run-config['test-cases'] }} -scheme ${{ matrix.run-config['scheme'] }} -destination "${{ matrix.run-config['test-destination'] }}" -skipPackagePluginValidation | xcpretty diff --git a/Sources/WhisperKit/Core/AudioEncoder.swift b/Sources/WhisperKit/Core/AudioEncoder.swift index 06337cd..b5e6d4c 100644 --- a/Sources/WhisperKit/Core/AudioEncoder.swift +++ b/Sources/WhisperKit/Core/AudioEncoder.swift @@ -37,7 +37,7 @@ public class AudioEncoder: AudioEncoding, WhisperMLModel { public init() {} public func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray? { - // Make sure features is shape MultiArray (Float32 1 × {80,128} × 3000) + // Make sure features is shape MultiArray (Float16 1 × {80,128} x 1 × 3000) guard let model else { throw WhisperError.modelsUnavailable() } diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index d3d6c5b..c6fac5b 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -280,6 +280,12 @@ open class WhisperKit { prewarmMode: prewarmMode ) Logging.debug("Loaded feature extractor") + } else if let featureExtractor = featureExtractor as? WhisperMLXModel { + Logging.debug("Loading MLX feature extractor") + try await featureExtractor.loadModel( + at: path + ) + Logging.debug("Loaded MLX feature extractor") } if let audioEncoder = audioEncoder as? WhisperMLModel { @@ -290,6 +296,12 @@ open class WhisperKit { prewarmMode: prewarmMode ) Logging.debug("Loaded audio encoder") + } else if let audioEncoder = audioEncoder as? WhisperMLXModel { + Logging.debug("Loading MLX audio encoder") + try await audioEncoder.loadModel( + at: path + ) + Logging.debug("Loaded MLX audio encoder") } if let textDecoder = textDecoder as? WhisperMLModel { diff --git a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift index f95609e..9b074a0 100644 --- a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift +++ b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift @@ -20,11 +20,10 @@ public class MLXAudioEncoder: AudioEncoding { throw WhisperError.modelsUnavailable() } try Task.checkCancellation() - let input = features.withUnsafeBytes { ptr in - MLXArray(ptr, features.shape.map { $0.intValue }, type: FloatType.self) - } - let ouput = encoder(input) - return try ouput.asMLMultiArray() + let inputArray = features.asMLXArray(FloatType.self) + let input = inputArray.asMLXInput() + let output = encoder(input[.newAxis]) + return try output.asMLXOutput().asMLMultiArray() } } diff --git a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift index 2f47cd8..0294217 100644 --- a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift +++ b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift @@ -27,17 +27,15 @@ open class MLXFeatureExtractor: FeatureExtracting { public func logMelSpectrogram(fromAudio inputAudio: MLMultiArray) async throws -> MLMultiArray? { try Task.checkCancellation() - let input = inputAudio.withUnsafeBytes { ptr in - MLXArray(ptr, inputAudio.shape.map { $0.intValue }, type: Float.self) - } - let logMelSpectrogram = MLXFeatureExtractor.logMelSpectrogram( + let input = inputAudio.asMLXArray(Float.self) + let output = MLXFeatureExtractor.logMelSpectrogram( audio: input, filters: filters, nMels: melCount ?? 80, nFFT: nFFT, hopLength: hopLength ) - return try logMelSpectrogram.asMLMultiArray() + return try output.asType(FloatType.self).asMLXOutput().asMLMultiArray() } } diff --git a/Sources/WhisperKit/MLX/MLXModels.swift b/Sources/WhisperKit/MLX/MLXModels.swift index cdd8a64..5ab5f2b 100644 --- a/Sources/WhisperKit/MLX/MLXModels.swift +++ b/Sources/WhisperKit/MLX/MLXModels.swift @@ -9,7 +9,7 @@ public enum PadMode { case reflect } -struct ModelConfig: Codable { +struct MLXModelConfig: Codable { let nMels: Int let nAudioCtx: Int let nAudioState: Int diff --git a/Sources/WhisperKit/MLX/MLXUtils.swift b/Sources/WhisperKit/MLX/MLXUtils.swift index 6849643..ca3e030 100644 --- a/Sources/WhisperKit/MLX/MLXUtils.swift +++ b/Sources/WhisperKit/MLX/MLXUtils.swift @@ -8,6 +8,33 @@ import CoreML // MARK: - Extensions +extension MLMultiArray { + func asMLXArray(_ type: T.Type) -> MLXArray { + let shape = shape.map(\.intValue) + let strides = strides.map(\.intValue) + return withUnsafeBufferPointer(ofType: T.self) { ptr in + let buffer = UnsafeBufferPointer(start: ptr.baseAddress, count: shape.reduce(1, *)) + return asStrided(MLXArray(buffer, shape), shape, strides: strides) + } + } +} + +extension MLXArray { + /// Adapts the shape of the output array so MLX is compatible with CoreML + /// + /// Remove empty dimensions, swap axes and add empty dimensions, result: [1, n, 1, m] + func asMLXOutput() -> MLXArray { + squeezed().swappedAxes(0, 1).expandedDimensions(axes: [0, 2]) + } + + /// Adapts the shape of the input array so MLX is compatible with CoreML + /// + /// Remove empty dimensions, swap axes, result: [n, m] + func asMLXInput() -> MLXArray { + squeezed().swappedAxes(0, 1) + } +} + extension MLXArray { func asMLMultiArray() throws -> MLMultiArray { let dataType = multiArrayDataType() @@ -67,8 +94,8 @@ func loadParameters(at url: URL, forKey key: String? = nil) throws -> NestedDict return NestedDictionary(item: keyParams) } -func loadConfig(at url: URL) throws -> ModelConfig { +func loadConfig(at url: URL) throws -> MLXModelConfig { let configDecoder = JSONDecoder() configDecoder.keyDecodingStrategy = .convertFromSnakeCase - return try configDecoder.decode(ModelConfig.self, from: Data(contentsOf: url)) + return try configDecoder.decode(MLXModelConfig.self, from: Data(contentsOf: url)) } diff --git a/Sources/WhisperKitCLI/TranscribeCLI.swift b/Sources/WhisperKitCLI/TranscribeCLI.swift index 6e851f1..fb106da 100644 --- a/Sources/WhisperKitCLI/TranscribeCLI.swift +++ b/Sources/WhisperKitCLI/TranscribeCLI.swift @@ -5,6 +5,7 @@ import ArgumentParser import CoreML import Foundation import WhisperKit +import WhisperKitMLX @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) struct TranscribeCLI: AsyncParsableCommand { diff --git a/Tests/WhisperKitMLXTests/MLXUnitTests.swift b/Tests/WhisperKitMLXTests/MLXUnitTests.swift index 80e9a68..1fb1f2d 100644 --- a/Tests/WhisperKitMLXTests/MLXUnitTests.swift +++ b/Tests/WhisperKitMLXTests/MLXUnitTests.swift @@ -24,7 +24,7 @@ final class MLXUnitTests: XCTestCase { let extractedFeature = try await featureExtractor.logMelSpectrogram(fromAudio: paddedSamples) let melSpectrogram = try XCTUnwrap(extractedFeature, "Failed to produce Mel spectrogram from audio samples") - let expectedShape: [NSNumber] = [3000, 80] + let expectedShape: [NSNumber] = [1, 80, 1, 3000] XCTAssertNotNil(melSpectrogram, "Failed to produce Mel spectrogram from audio samples") XCTAssertEqual(melSpectrogram.shape, expectedShape, "Mel spectrogram shape is not as expected") } @@ -36,8 +36,8 @@ final class MLXUnitTests: XCTestCase { let modelPath = try URL(filePath: tinyMLXModelPath()) try await audioEncoder.loadModel(at: modelPath) - let encoderInput = try MLMultiArray(shape: [1, 3000, 80], dataType: .float16) - let expectedShape: [NSNumber] = [1, 1500, 384] + let encoderInput = try MLMultiArray(shape: [1, 80, 1, 3000], dataType: .float16) + let expectedShape: [NSNumber] = [1, 384, 1, 1500] let encoderOutput = try await audioEncoder.encodeFeatures(encoderInput) XCTAssertNotNil(encoderOutput, "Failed to encode features") @@ -46,6 +46,22 @@ final class MLXUnitTests: XCTestCase { // MARK: - Utils Tests + func testArrayConversion() throws { + let count = 16 + let arr1 = MLXArray(0.. Date: Wed, 12 Jun 2024 13:55:05 +0200 Subject: [PATCH 06/29] added MLX text decoder (#161) --- Sources/WhisperKit/Core/Models.swift | 109 +++- Sources/WhisperKit/Core/TextDecoder.swift | 295 +++++----- Sources/WhisperKit/Core/TokenSampler.swift | 6 + Sources/WhisperKit/Core/Utils.swift | 6 +- Sources/WhisperKit/Core/WhisperKit.swift | 30 +- Sources/WhisperKit/MLX/Attention.swift | 49 +- Sources/WhisperKit/MLX/MLXAudioEncoder.swift | 2 +- Sources/WhisperKit/MLX/MLXModels.swift | 20 +- Sources/WhisperKit/MLX/MLXTextDecoder.swift | 543 +++++++++++++++++++ Sources/WhisperKit/MLX/MLXUtils.swift | 16 +- Tests/WhisperKitMLXTests/MLXUnitTests.swift | 71 +++ 11 files changed, 963 insertions(+), 184 deletions(-) create mode 100644 Sources/WhisperKit/MLX/MLXTextDecoder.swift diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index c0d0463..c291380 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -195,18 +195,42 @@ public enum DecodingTask: CustomStringConvertible, CaseIterable { } public struct DecodingInputs { - var initialPrompt: [Int] - var inputIds: MLMultiArray - var cacheLength: MLMultiArray - var keyCache: MLMultiArray - var valueCache: MLMultiArray - var alignmentWeights: MLMultiArray - var kvCacheUpdateMask: MLMultiArray - var decoderKeyPaddingMask: MLMultiArray - var prefillKeyCache: MLMultiArray - var prefillValueCache: MLMultiArray - - func reset(prefilledCacheSize: Int, maxTokenContext: Int) { + public var initialPrompt: [Int] + public var inputIds: MLMultiArray + public var cacheLength: MLMultiArray + public var keyCache: MLMultiArray? + public var valueCache: MLMultiArray? + public var alignmentWeights: MLMultiArray + public var kvCacheUpdateMask: MLMultiArray + public var decoderKeyPaddingMask: MLMultiArray + public var prefillKeyCache: MLMultiArray + public var prefillValueCache: MLMultiArray + + public init( + initialPrompt: [Int], + inputIds: MLMultiArray, + cacheLength: MLMultiArray, + keyCache: MLMultiArray?, + valueCache: MLMultiArray?, + alignmentWeights: MLMultiArray, + kvCacheUpdateMask: MLMultiArray, + decoderKeyPaddingMask: MLMultiArray, + prefillKeyCache: MLMultiArray, + prefillValueCache: MLMultiArray + ) { + self.initialPrompt = initialPrompt + self.inputIds = inputIds + self.cacheLength = cacheLength + self.keyCache = keyCache + self.valueCache = valueCache + self.alignmentWeights = alignmentWeights + self.kvCacheUpdateMask = kvCacheUpdateMask + self.decoderKeyPaddingMask = decoderKeyPaddingMask + self.prefillKeyCache = prefillKeyCache + self.prefillValueCache = prefillValueCache + } + + public func reset(prefilledCacheSize: Int, maxTokenContext: Int) { // NOTE: Because we have a mask on the kvcache, // we can simply shift the masks without touching the data, // it will be overwritten by the new data without impact on the output @@ -230,9 +254,19 @@ public struct DecodingInputs { } public struct DecodingCache { - var keyCache: MLMultiArray? - var valueCache: MLMultiArray? - var alignmentWeights: MLMultiArray? + public var keyCache: MLMultiArray? + public var valueCache: MLMultiArray? + public var alignmentWeights: MLMultiArray? + + public init( + keyCache: MLMultiArray?, + valueCache: MLMultiArray?, + alignmentWeights: MLMultiArray? + ) { + self.keyCache = keyCache + self.valueCache = valueCache + self.alignmentWeights = alignmentWeights + } } public enum ChunkingStrategy: String, CaseIterable { @@ -410,6 +444,34 @@ public struct DecodingResult { public var timings: TranscriptionTimings? public var fallback: DecodingFallback? + public init( + language: String, + languageProbs: [String : Float], + tokens: [Int], + tokenLogProbs: [[Int : Float]], + text: String, + avgLogProb: Float, + noSpeechProb: Float, + temperature: Float, + compressionRatio: Float, + cache: DecodingCache?, + timings: TranscriptionTimings?, + fallback: DecodingFallback? + ) { + self.language = language + self.languageProbs = languageProbs + self.tokens = tokens + self.tokenLogProbs = tokenLogProbs + self.text = text + self.avgLogProb = avgLogProb + self.noSpeechProb = noSpeechProb + self.temperature = temperature + self.compressionRatio = compressionRatio + self.cache = cache + self.timings = timings + self.fallback = fallback + } + public static var emptyResults: DecodingResult { return DecodingResult(language: "", languageProbs: [:], @@ -596,6 +658,23 @@ public struct TranscriptionProgress { public var avgLogprob: Float? public var compressionRatio: Float? public var windowId: Int = 0 + + public init( + timings: TranscriptionTimings, + text: String, + tokens: [Int], + temperature: Float?, + avgLogprob: Float?, + compressionRatio: Float?, + windowId: Int = 0 + ) { + self.timings = timings + self.text = text + self.tokens = tokens + self.temperature = temperature + self.avgLogprob = avgLogprob + self.compressionRatio = compressionRatio + } } /// Callback to receive progress updates during transcription. diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index 96e83ad..86bd3e8 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -17,11 +17,15 @@ public protocol TextDecoding { var windowSize: Int? { get } var embedSize: Int? { get } + func prepareDecoderInputs( + withPrompt initialPrompt: [Int] + ) throws -> DecodingInputs + func predictLogits( inputIds: MLMultiArray, cacheLength: MLMultiArray, - keyCache: MLMultiArray, - valueCache: MLMultiArray, + keyCache: MLMultiArray?, + valueCache: MLMultiArray?, kvCacheUpdateMask: MLMultiArray, encoderOutputEmbeds: MLMultiArray, decoderKeyPaddingMask: MLMultiArray @@ -152,7 +156,6 @@ public extension TextDecoding { let decoderKeyPaddingMask = initMLMultiArray(shape: [1, kvCacheMaxSequenceLengthValue], dataType: .float16, initialValue: FloatType(-10000)) let prefillKeyCache = try! MLMultiArray(shape: [1, kvCacheEmbedDimValue, 1, kvCacheMaxSequenceLengthValue], dataType: .float16) let prefillValueCache = try! MLMultiArray(shape: [1, kvCacheEmbedDimValue, 1, kvCacheMaxSequenceLengthValue], dataType: .float16) - let decoderInputs = DecodingInputs( initialPrompt: initialPrompt, inputIds: inputIds, @@ -165,7 +168,6 @@ public extension TextDecoding { prefillKeyCache: prefillKeyCache, prefillValueCache: prefillValueCache ) - return decoderInputs } @@ -231,15 +233,18 @@ public extension TextDecoding { } // Prefill kv cache - prefilledDecoderInputs.prefillKeyCache = prefillOutput.keyCache! - prefilledDecoderInputs.prefillValueCache = prefillOutput.valueCache! - - TextDecoder.updateKVCache(keyTensor: prefilledDecoderInputs.keyCache, - keySlice: prefilledDecoderInputs.prefillKeyCache, - valueTensor: prefilledDecoderInputs.valueCache, - valueSlice: prefilledDecoderInputs.prefillValueCache, - insertAtIndex: prefillTokens.firstIndex(of: tokenizer.specialTokens.startOfTranscriptToken) ?? 0) - prefilledDecoderInputs.cacheLength[0] = prefilledDecoderInputs.prefillKeyCache.shape[3] + if let keyCache = prefillOutput.keyCache, let valueCache = prefillOutput.valueCache { + prefilledDecoderInputs.prefillKeyCache = keyCache + prefilledDecoderInputs.prefillValueCache = valueCache + TextDecoder.updateKVCache( + keyTensor: keyCache, + keySlice: prefilledDecoderInputs.prefillKeyCache, + valueTensor: valueCache, + valueSlice: prefilledDecoderInputs.prefillValueCache, + insertAtIndex: prefillTokens.firstIndex(of: tokenizer.specialTokens.startOfTranscriptToken) ?? 0 + ) + prefilledDecoderInputs.cacheLength[0] = prefilledDecoderInputs.prefillKeyCache.shape[3] + } } return prefilledDecoderInputs @@ -322,102 +327,22 @@ public extension TextDecoding { Logging.debug("Key Cache | Val Cache | Align Cache | Update Mask | Decoder Mask | Position") for i in 0.. (logits: MLMultiArray?, cache: DecodingCache?)? { - let modelInputs = TextDecoderInput( - input_ids: inputIds, - cache_length: cacheLength, - key_cache: keyCache, - value_cache: valueCache, - kv_cache_update_mask: kvCacheUpdateMask, - encoder_output_embeds: encoderOutputEmbeds, - decoder_key_padding_mask: decoderKeyPaddingMask - ) - - guard let model = model else { - return nil - } - - try Task.checkCancellation() - - let outputFeatures = try await model.asyncPrediction(from: modelInputs, options: MLPredictionOptions()) - - let output = TextDecoderOutput(features: outputFeatures) - - let logits = output.logits - let cache = DecodingCache( - keyCache: output.key_cache_updates, - valueCache: output.value_cache_updates, - alignmentWeights: output.alignment_heads_weights - ) - - return (logits, cache) - } - - public func detectLanguage( + static func detectLanguage( + textDecoder: any TextDecoding, + languageLogitsFilter: any LogitsFiltering, from encoderOutput: MLMultiArray, using decoderInputs: DecodingInputs, sampler tokenSampler: TokenSampling, @@ -427,27 +352,16 @@ open class TextDecoder: TextDecoding, WhisperMLModel { // Predict logits for 1 iteration with sot // 1. LanguageLogitsFilter for only language tokens // 2. GreedyTokenSampler for most likely language - guard let tokenizer = tokenizer else { + guard let tokenizer = textDecoder.tokenizer else { // Tokenizer required for decoding throw WhisperError.tokenizerUnavailable() } - guard let logitsSize = logitsSize else { - throw WhisperError.modelsUnavailable("Failed to read logits size from model") - } var timings = TranscriptionTimings() let prefilledIndex = 0 let currentTokens: [Int] = [tokenizer.specialTokens.startOfTranscriptToken] var logProbs: [Float] = Array(repeating: 0, count: prefilledIndex + 1) - // Logits filters - let languageLogitsFilter = self.languageLogitsFilter ?? LanguageLogitsFilter( - allLanguageTokens: tokenizer.allLanguageTokens, - logitsDim: logitsSize, - sampleBegin: prefilledIndex - ) - self.languageLogitsFilter = languageLogitsFilter - let tokenIndex = 0 let prefillToken = currentTokens[tokenIndex] var nextToken = prefillToken @@ -462,7 +376,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel { let inferenceTime = Date() Logging.debug("Detecting language...") - let predictedLogits = try await self.predictLogits( + let predictedLogits = try await textDecoder.predictLogits( inputIds: decoderInputs.inputIds, cacheLength: decoderInputs.cacheLength, keyCache: decoderInputs.keyCache, @@ -529,6 +443,119 @@ open class TextDecoder: TextDecoding, WhisperMLModel { fallback: nil ) } +} + +public class TextDecoderContextPrefill: WhisperMLModel { + public var model: MLModel? +} + +@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) +open class TextDecoder: TextDecoding, WhisperMLModel { + public var model: MLModel? + public var tokenizer: WhisperTokenizer? + public var prefillData: WhisperMLModel? + public var isModelMultilingual: Bool = false + public var shouldEarlyStop: Bool = false + private var languageLogitsFilter: LanguageLogitsFilter? + + public var supportsWordTimestamps: Bool { + return getModelOutputDimention(model, named: "alignment_heads_weights", position: 0) != nil + } + + public var logitsSize: Int? { + return getModelOutputDimention(model, named: "logits", position: 2) + } + + public var kvCacheEmbedDim: Int? { + return getModelInputDimention(model, named: "key_cache", position: 1) + } + + public var kvCacheMaxSequenceLength: Int? { + return getModelInputDimention(model, named: "key_cache", position: 3) + } + + public var windowSize: Int? { + return getModelInputDimention(model, named: "encoder_output_embeds", position: 3) + } + + public var embedSize: Int? { + return getModelInputDimention(model, named: "encoder_output_embeds", position: 1) + } + + /// Override default so we an unload the prefill data as well + public func unloadModel() { + model = nil + prefillData = nil + languageLogitsFilter = nil + } + + public func predictLogits( + inputIds: MLMultiArray, + cacheLength: MLMultiArray, + keyCache: MLMultiArray?, + valueCache: MLMultiArray?, + kvCacheUpdateMask: MLMultiArray, + encoderOutputEmbeds: MLMultiArray, + decoderKeyPaddingMask: MLMultiArray + ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? { + guard let model, let keyCache, let valueCache else { + return nil + } + let modelInputs = TextDecoderInput( + input_ids: inputIds, + cache_length: cacheLength, + key_cache: keyCache, + value_cache: valueCache, + kv_cache_update_mask: kvCacheUpdateMask, + encoder_output_embeds: encoderOutputEmbeds, + decoder_key_padding_mask: decoderKeyPaddingMask + ) + + try Task.checkCancellation() + + let outputFeatures = try await model.asyncPrediction(from: modelInputs, options: MLPredictionOptions()) + + let output = TextDecoderOutput(features: outputFeatures) + + let logits = output.logits + let cache = DecodingCache( + keyCache: output.key_cache_updates, + valueCache: output.value_cache_updates, + alignmentWeights: output.alignment_heads_weights + ) + + return (logits, cache) + } + + public func detectLanguage( + from encoderOutput: MLMultiArray, + using decoderInputs: DecodingInputs, + sampler tokenSampler: TokenSampling, + options: DecodingOptions, + temperature: FloatType + ) async throws -> DecodingResult { + guard let tokenizer else { + throw WhisperError.tokenizerUnavailable() + } + guard let logitsSize else { + throw WhisperError.modelsUnavailable("Failed to read logits size from model") + } + let languageLogitsFilter = self.languageLogitsFilter ?? LanguageLogitsFilter( + allLanguageTokens: tokenizer.allLanguageTokens, + logitsDim: logitsSize, + sampleBegin: 0 + ) + self.languageLogitsFilter = languageLogitsFilter + return try await TextDecoder.detectLanguage( + textDecoder: self, + languageLogitsFilter: languageLogitsFilter, + from: encoderOutput, + using: decoderInputs, + sampler: tokenSampler, + options: options, + temperature: temperature + ) + } public func decodeText( from encoderOutput: MLMultiArray, @@ -693,18 +720,21 @@ open class TextDecoder: TextDecoding, WhisperMLModel { } // tensor: [1, kvCacheEmbedDim, 1, kvCacheMaxSequenceLength], slice: [1, kvCacheEmbedDim, 1, 1] - let kvStartTime = Date() - TextDecoder.updateKVCache(keyTensor: decoderInputs.keyCache, - keySlice: newKeyCache, - valueTensor: decoderInputs.valueCache, - valueSlice: newValueCache, - insertAtIndex: tokenIndex) - let kvTime = Date().timeIntervalSince(kvStartTime) - timings.decodingKvCaching += kvTime - timings.totalKVUpdateRuns += 1 + if let keyCache = decoderInputs.keyCache, let valueCache = decoderInputs.valueCache { + let kvStartTime = Date() + TextDecoder.updateKVCache( + keyTensor: keyCache, + keySlice: newKeyCache, + valueTensor: valueCache, + valueSlice: newValueCache, + insertAtIndex: tokenIndex + ) + let kvTime = Date().timeIntervalSince(kvStartTime) + timings.decodingKvCaching += kvTime + timings.totalKVUpdateRuns += 1 + } decoderInputs.decoderKeyPaddingMask[tokenIndex + 1] = 0 - decoderInputs.kvCacheUpdateMask[tokenIndex] = 0 decoderInputs.kvCacheUpdateMask[tokenIndex + 1] = 1 @@ -725,7 +755,14 @@ open class TextDecoder: TextDecoding, WhisperMLModel { let averageLogProb = logProbs.reduce(0, +) / Float(logProbs.count) let compressionRatio = compressionRatio(of: currentTokens) - let result = TranscriptionProgress(timings: timings, text: currentTranscript, tokens: currentTokens, avgLogprob: averageLogProb, compressionRatio: compressionRatio) + let result = TranscriptionProgress( + timings: timings, + text: currentTranscript, + tokens: currentTokens, + temperature: nil, + avgLogprob: averageLogProb, + compressionRatio: compressionRatio + ) // Call the callback if it is provided on a background thread to avoid blocking the decoding loop if let callback = callback { diff --git a/Sources/WhisperKit/Core/TokenSampler.swift b/Sources/WhisperKit/Core/TokenSampler.swift index ce15cd5..9a8c611 100644 --- a/Sources/WhisperKit/Core/TokenSampler.swift +++ b/Sources/WhisperKit/Core/TokenSampler.swift @@ -14,6 +14,12 @@ public struct SamplingResult { public var tokens: [Int] public var logProbs: [Float] public var completed: Bool + + public init(tokens: [Int], logProbs: [Float], completed: Bool) { + self.tokens = tokens + self.logProbs = logProbs + self.completed = completed + } } @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index a46a65c..cc5ec21 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -149,14 +149,14 @@ public extension WhisperKit { } } -extension Float { +public extension Float { func rounded(_ decimalPlaces: Int) -> Float { let divisor = pow(10.0, Float(decimalPlaces)) return (self * divisor).rounded() / divisor } } -extension String { +public extension String { var normalized: String { // Trim whitespace and newlines let trimmedString = self.trimmingCharacters(in: .whitespacesAndNewlines) @@ -206,7 +206,7 @@ func prepareSeekClips(contentFrames: Int, decodeOptions: DecodingOptions?) -> [( } @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -func initMLMultiArray(shape: [NSNumber], dataType: MLMultiArrayDataType, initialValue: Any) -> MLMultiArray { +public func initMLMultiArray(shape: [NSNumber], dataType: MLMultiArrayDataType, initialValue: Any) -> MLMultiArray { var multiArray: MLMultiArray switch dataType { case .float16: diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index c6fac5b..bbecf4f 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -261,59 +261,49 @@ open class WhisperKit { Logging.debug("Loading models from \(path.path) with prewarmMode: \(prewarmMode)") - let logmelUrl = path.appending(path: "MelSpectrogram.mlmodelc") - let encoderUrl = path.appending(path: "AudioEncoder.mlmodelc") - let decoderUrl = path.appending(path: "TextDecoder.mlmodelc") - let decoderPrefillUrl = path.appending(path: "TextDecoderContextPrefill.mlmodelc") - - for item in [logmelUrl, encoderUrl, decoderUrl] { - if !FileManager.default.fileExists(atPath: item.path) { - throw WhisperError.modelsUnavailable("Model file not found at \(item.path)") - } - } - if let featureExtractor = featureExtractor as? WhisperMLModel { Logging.debug("Loading feature extractor") try await featureExtractor.loadModel( - at: logmelUrl, + at: path.appending(path: "MelSpectrogram.mlmodelc"), computeUnits: modelCompute.melCompute, // hardcoded to use GPU prewarmMode: prewarmMode ) Logging.debug("Loaded feature extractor") } else if let featureExtractor = featureExtractor as? WhisperMLXModel { Logging.debug("Loading MLX feature extractor") - try await featureExtractor.loadModel( - at: path - ) + try await featureExtractor.loadModel(at: path) Logging.debug("Loaded MLX feature extractor") } if let audioEncoder = audioEncoder as? WhisperMLModel { Logging.debug("Loading audio encoder") try await audioEncoder.loadModel( - at: encoderUrl, + at: path.appending(path: "AudioEncoder.mlmodelc"), computeUnits: modelCompute.audioEncoderCompute, prewarmMode: prewarmMode ) Logging.debug("Loaded audio encoder") } else if let audioEncoder = audioEncoder as? WhisperMLXModel { Logging.debug("Loading MLX audio encoder") - try await audioEncoder.loadModel( - at: path - ) + try await audioEncoder.loadModel(at: path) Logging.debug("Loaded MLX audio encoder") } if let textDecoder = textDecoder as? WhisperMLModel { Logging.debug("Loading text decoder") try await textDecoder.loadModel( - at: decoderUrl, + at: path.appending(path: "TextDecoder.mlmodelc"), computeUnits: modelCompute.textDecoderCompute, prewarmMode: prewarmMode ) Logging.debug("Loaded text decoder") + } else if let textDecoder = textDecoder as? WhisperMLXModel { + Logging.debug("Loading MLX text decoder") + try await textDecoder.loadModel(at: path) + Logging.debug("Loaded MLX text decoder") } + let decoderPrefillUrl = path.appending(path: "TextDecoderContextPrefill.mlmodelc") if FileManager.default.fileExists(atPath: decoderPrefillUrl.path) { Logging.debug("Loading text decoder prefill data") textDecoder.prefillData = TextDecoderContextPrefill() diff --git a/Sources/WhisperKit/MLX/Attention.swift b/Sources/WhisperKit/MLX/Attention.swift index c38c3fa..5a43cd2 100644 --- a/Sources/WhisperKit/MLX/Attention.swift +++ b/Sources/WhisperKit/MLX/Attention.swift @@ -24,8 +24,8 @@ final class MultiHeadAttention: Module { _ x: MLXArray, xa: MLXArray? = nil, mask: MLXArray? = nil, - kvCache: MLXArray? = nil - ) -> (MLXArray, (MLXArray, MLXArray), MLXArray) { + kvCache: KV? = nil + ) -> MultiHeadAttentionResult { let q = query(x) var k: MLXArray @@ -37,13 +37,17 @@ final class MultiHeadAttention: Module { k = key(x) v = value(x) if let kvCache { - k = MLX.concatenated([kvCache[0], k], axis: 1) - v = MLX.concatenated([kvCache[1], v], axis: 1) + k = MLX.concatenated([kvCache.k, k], axis: 1) + v = MLX.concatenated([kvCache.v, v], axis: 1) } } let (wv, qk) = qkvAttention(q, k, v, mask) - return (out(wv), (k, v), qk) + return MultiHeadAttentionResult( + x: out(wv), + kv: KV(k: k, v: v), + qk: qk + ) } private func qkvAttention(_ q: MLXArray, _ k: MLXArray, _ v: MLXArray, _ mask: MLXArray?) -> (MLXArray, MLXArray) { @@ -70,10 +74,14 @@ final class ResidualAttentionBlock: Module { let mlp1: Linear let mlp2: Linear let mlp_ln: LayerNorm + let cross_attn: MultiHeadAttention? + let cross_attn_ln: LayerNorm? - init(nState: Int, nHead: Int) { + init(nState: Int, nHead: Int, crossAttention: Bool = false) { self.attn = MultiHeadAttention(nState: nState, nHead: nHead) self.attn_ln = LayerNorm(dimensions: nState) + self.cross_attn = crossAttention ? MultiHeadAttention(nState: nState, nHead: nHead) : nil + self.cross_attn_ln = crossAttention ? LayerNorm(dimensions: nState) : nil let nMlp = nState * 4 self.mlp1 = Linear(nState, nMlp) self.mlp2 = Linear(nMlp, nState) @@ -84,12 +92,29 @@ final class ResidualAttentionBlock: Module { _ x: MLXArray, xa: MLXArray? = nil, mask: MLXArray? = nil, - kvCache: (MLXArray, MLXArray)? = nil + kvCache: KV? = nil, + crossKvCache: KV? = nil ) -> ResidualAttentionBlockResult { - let (kvCache, crossKv) = kvCache ?? (nil, nil) - let (y, kv, _) = attn(attn_ln(x), mask: mask, kvCache: kvCache) - var x = x + y - x = x + mlp2(gelu(mlp1(mlp_ln(x)))) - return ResidualAttentionBlockResult(x: x, kv: kv, crossKv: crossKv, crossQk: nil) + let attnResult = attn(attn_ln(x), mask: mask, kvCache: kvCache) + var x = x + attnResult.x + if let cross_attn, let cross_attn_ln { + let crossAttnResult = cross_attn(cross_attn_ln(x), xa: xa, kvCache: crossKvCache) + x = x + crossAttnResult.x + x = x + mlp2(gelu(mlp1(mlp_ln(x)))) + return ResidualAttentionBlockResult( + x: x, + kv: attnResult.kv, + crossKv: crossAttnResult.kv, + crossQk: crossAttnResult.qk + ) + } else { + x = x + mlp2(gelu(mlp1(mlp_ln(x)))) + return ResidualAttentionBlockResult( + x: x, + kv: attnResult.kv, + crossKv: crossKvCache, + crossQk: nil + ) + } } } diff --git a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift index 9b074a0..4000dd8 100644 --- a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift +++ b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift @@ -22,7 +22,7 @@ public class MLXAudioEncoder: AudioEncoding { try Task.checkCancellation() let inputArray = features.asMLXArray(FloatType.self) let input = inputArray.asMLXInput() - let output = encoder(input[.newAxis]) + let output = encoder(input) return try output.asMLXOutput().asMLMultiArray() } } diff --git a/Sources/WhisperKit/MLX/MLXModels.swift b/Sources/WhisperKit/MLX/MLXModels.swift index 5ab5f2b..207e884 100644 --- a/Sources/WhisperKit/MLX/MLXModels.swift +++ b/Sources/WhisperKit/MLX/MLXModels.swift @@ -22,9 +22,25 @@ struct MLXModelConfig: Codable { let nTextLayer: Int } +struct KV { + var k: MLXArray + var v: MLXArray +} + +struct TextDecoderResult { + var logits: MLXArray + var kvCache: [KV] +} + struct ResidualAttentionBlockResult { var x: MLXArray - var kv: (MLXArray, MLXArray) - var crossKv: MLXArray? + var kv: KV + var crossKv: KV? var crossQk: MLXArray? } + +struct MultiHeadAttentionResult { + var x: MLXArray + var kv: KV + var qk: MLXArray +} diff --git a/Sources/WhisperKit/MLX/MLXTextDecoder.swift b/Sources/WhisperKit/MLX/MLXTextDecoder.swift new file mode 100644 index 0000000..eb30fd2 --- /dev/null +++ b/Sources/WhisperKit/MLX/MLXTextDecoder.swift @@ -0,0 +1,543 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import CoreML +import MLX +import MLXNN +import WhisperKit + +@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) +public final class MLXTextDecoder: TextDecoding { + public var tokenizer: (any WhisperTokenizer)? + public var prefillData: (any WhisperMLModel)? + public var isModelMultilingual: Bool = false + public let supportsWordTimestamps: Bool = false + public var logitsSize: Int? { + decoder?.nState + } + public var kvCacheEmbedDim: Int? { + guard let config else { return nil } + return config.nTextState * config.nTextLayer + } + public var kvCacheMaxSequenceLength: Int? { + guard let config else { return nil } + return config.nTextCtx / 2 + } + public var windowSize: Int? { + guard let config else { return nil } + return config.nAudioCtx + } + public var embedSize: Int? { + guard let config else { return nil } + return config.nTextState + } + private var decoder: TextDecoder? + private var config: MLXModelConfig? + private var languageLogitsFilter: LanguageLogitsFilter? + + public init() {} + + private static func toKvCache(keyCache: MLMultiArray?, valueCache: MLMultiArray?) -> [KV]? { + guard let keyCache, let valueCache else { + return nil + } + let keyCacheMlx = keyCache.asMLXArray(FloatType.self) + let valueCacheMlx = valueCache.asMLXArray(FloatType.self) + assert(keyCacheMlx.shape == valueCacheMlx.shape) + var result = [KV]() + for index in 0.. DecodingInputs { + let tokenShape = [NSNumber(value: 1), NSNumber(value: initialPrompt.count)] + + // Initialize MLMultiArray for tokens + let tokenMultiArray = try MLMultiArray(shape: tokenShape, dataType: .int32) + + // Assign token values to the MLMultiArray + for (index, token) in initialPrompt.enumerated() { + tokenMultiArray[index] = NSNumber(value: token) + } + + guard let kvCacheEmbedDim = self.kvCacheEmbedDim else { + throw WhisperError.prepareDecoderInputsFailed("Unable to determine kvCacheEmbedDim") + } + + guard let kvCacheMaxSequenceLength = self.kvCacheMaxSequenceLength else { + throw WhisperError.prepareDecoderInputsFailed("Unable to determine kvCacheMaxSequenceLength") + } + + guard let encoderOutputDim = self.windowSize else { + throw WhisperError.prepareDecoderInputsFailed("Unable to determine encoderOutputDim") + } + + // Initialize each MLMultiArray + let kvCacheEmbedDimValue = NSNumber(value: kvCacheEmbedDim) + let kvCacheMaxSequenceLengthValue = NSNumber(value: kvCacheMaxSequenceLength) + let encoderOutputDimValue = NSNumber(value: encoderOutputDim) + + let inputIds = initMLMultiArray(shape: [1], dataType: .int32, initialValue: Int32(0)) + let cacheLength = initMLMultiArray(shape: [1], dataType: .int32, initialValue: Int32(0)) + let alignmentWeights = initMLMultiArray(shape: [kvCacheMaxSequenceLengthValue, encoderOutputDimValue], dataType: .float16, initialValue: FloatType(0)) + let kvCacheUpdateMask = initMLMultiArray(shape: [1, kvCacheMaxSequenceLengthValue], dataType: .int32, initialValue: Int32(0)) + let decoderKeyPaddingMask = initMLMultiArray(shape: [1, kvCacheMaxSequenceLengthValue], dataType: .float16, initialValue: FloatType(-10000)) + let prefillKeyCache = try! MLMultiArray(shape: [1, kvCacheEmbedDimValue, 1, kvCacheMaxSequenceLengthValue], dataType: .float16) + let prefillValueCache = try! MLMultiArray(shape: [1, kvCacheEmbedDimValue, 1, kvCacheMaxSequenceLengthValue], dataType: .float16) + let decoderInputs = DecodingInputs( + initialPrompt: initialPrompt, + inputIds: inputIds, + cacheLength: cacheLength, + keyCache: nil, + valueCache: nil, + alignmentWeights: alignmentWeights, + kvCacheUpdateMask: kvCacheUpdateMask, + decoderKeyPaddingMask: decoderKeyPaddingMask, + prefillKeyCache: prefillKeyCache, + prefillValueCache: prefillValueCache + ) + return decoderInputs + } + + public func predictLogits( + inputIds: MLMultiArray, + cacheLength: MLMultiArray, + keyCache: MLMultiArray?, + valueCache: MLMultiArray?, + kvCacheUpdateMask: MLMultiArray, + encoderOutputEmbeds: MLMultiArray, + decoderKeyPaddingMask: MLMultiArray + ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? { + guard let decoder else { + return nil + } + let tokens = inputIds.asMLXArray(Int32.self) + let audioFeatures = encoderOutputEmbeds.asMLXArray(FloatType.self).asMLXInput() + let result = decoder( + tokens, + xa: audioFeatures, + kvCache: Self.toKvCache(keyCache: keyCache, valueCache: valueCache) + ) + let keyCache = try MLX.stacked(result.kvCache.map(\.k)).asMLMultiArray() + let valueCache = try MLX.stacked(result.kvCache.map(\.v)).asMLMultiArray() + let decodingCache = DecodingCache( + keyCache: keyCache, + valueCache: valueCache, + alignmentWeights: nil + ) + return try (result.logits.asMLMultiArray(), decodingCache) + } + + public func decodeText( + from encoderOutput: MLMultiArray, + using decoderInputs: DecodingInputs, + sampler tokenSampler: TokenSampling, + options: DecodingOptions, + callback: TranscriptionCallback = nil + ) async throws -> DecodingResult { + guard let tokenizer else { + // Tokenizer required for decoding + throw WhisperError.tokenizerUnavailable() + } + + // Single loop variables + var timings = TranscriptionTimings() + let prefilledIndex = decoderInputs.cacheLength[0].intValue + let intialPromptIndex = decoderInputs.initialPrompt.count + var currentTokens: [Int] = decoderInputs.initialPrompt + var nextToken: Int = decoderInputs.initialPrompt.last! + var logProbs: [Float] = Array(repeating: 0, count: currentTokens.count) + + // Logits filters + var logitsFilters: [any LogitsFiltering] = [] + if options.suppressBlank { + logitsFilters.append( + SuppressBlankFilter( + specialTokens: tokenizer.specialTokens, + sampleBegin: prefilledIndex + ) + ) + } + + if !options.supressTokens.isEmpty { + logitsFilters.append(SuppressTokensFilter(suppressTokens: options.supressTokens)) + } + + if !options.withoutTimestamps { + let maxInitialTimestampIndex: Int? = + if let maxInitialTimestamp = options.maxInitialTimestamp { + Int(maxInitialTimestamp / WhisperKit.secondsPerTimeToken) + } else { + nil + } + logitsFilters.append( + TimestampRulesFilter( + specialTokens: tokenizer.specialTokens, + sampleBegin: intialPromptIndex, + maxInitialTimestampIndex: maxInitialTimestampIndex, + isModelMultilingual: isModelMultilingual + ) + ) + } + + // MARK: Main loop + + let loopCount = min(options.sampleLength, Constants.maxTokenContext - 1) + Logging.debug("Running main loop for a maximum of \(loopCount) iterations, starting at index \(prefilledIndex)") + var hasAlignment = false + var isFirstTokenLogProbTooLow = false + var keyCache = decoderInputs.keyCache + var valueCache = decoderInputs.valueCache + for tokenIndex in prefilledIndex..= Constants.maxTokenContext - 1 || + isFirstTokenLogProbTooLow + + if isSegmentCompleted { + // Completed segment, stop the loop + timings.decodingNonPrediction += Date().timeIntervalSince(nonInferenceStartTime) + timings.decodingLoop += Date().timeIntervalSince(loopStart) + timings.totalDecodingLoops += 1 + break + } else { + // MARK: KV Caching + + if !isPrefill { + // Found the next token, store it + currentTokens.append(nextToken) + logProbs.append(nextTokenLogProb) + } + + decoderInputs.decoderKeyPaddingMask[tokenIndex + 1] = 0 + decoderInputs.kvCacheUpdateMask[tokenIndex] = 0 + decoderInputs.kvCacheUpdateMask[tokenIndex + 1] = 1 + + // Update alignment weights for token if present + if let newAlignmentWeights = decoderOutput.cache?.alignmentWeights { + hasAlignment = true + for column in 0.. DecodingResult { + guard let tokenizer else { + throw WhisperError.tokenizerUnavailable() + } + guard let logitsSize else { + throw WhisperError.modelsUnavailable("Failed to read logits size from model") + } + let languageLogitsFilter = self.languageLogitsFilter ?? LanguageLogitsFilter( + allLanguageTokens: tokenizer.allLanguageTokens, + logitsDim: logitsSize, + sampleBegin: 0 + ) + self.languageLogitsFilter = languageLogitsFilter + return try await MLXTextDecoder.detectLanguage( + textDecoder: self, + languageLogitsFilter: languageLogitsFilter, + from: encoderOutput, + using: decoderInputs, + sampler: tokenSampler, + options: options, + temperature: temperature + ) + } +} + +extension MLXTextDecoder: WhisperMLXModel { + public func loadModel(at modelPath: URL) async throws { + let parameters = try loadParameters(at: modelPath.appending(path: "weights.safetensors"), forKey: "decoder") + let config = try loadConfig(at: modelPath.appending(path: "config.json")) + let decoder = TextDecoder( + nVocab: config.nVocab, + nCtx: config.nTextCtx, + nState: config.nTextState, + nHead: config.nTextHead, + nLayer: config.nTextLayer, + dtype: .float16 + ) + let loadedDecoder = try decoder.update(parameters: parameters, verify: [.noUnusedKeys]) + MLX.eval(loadedDecoder) + self.decoder = loadedDecoder + self.config = config + } + + public func unloadModel() { + decoder = nil + config = nil + prefillData = nil + languageLogitsFilter = nil + } +} + +final class TextDecoder: Module { + let nVocab: Int + let nCtx: Int + let nState: Int + let nHead: Int + let nLayer: Int + let dtype: MLX.DType + + private let token_embedding: Embedding + private let positional_embedding: MLXArray + private let blocks: [ResidualAttentionBlock] + private let ln: LayerNorm + private let _mask: MLXArray + + init( + nVocab: Int, + nCtx: Int, + nState: Int, + nHead: Int, + nLayer: Int, + dtype: MLX.DType = .float16 + ) { + self.nVocab = nVocab + self.nCtx = nCtx + self.nState = nState + self.nHead = nHead + self.nLayer = nLayer + self.dtype = dtype + + self.token_embedding = Embedding(embeddingCount: nVocab, dimensions: nState) + self.positional_embedding = MLX.zeros([nCtx, nState]) + self.blocks = (0.. TextDecoderResult { + let offset = kvCache?.first??.k.shape[1] ?? 0 + var x = x[.newAxis, .ellipsis] + x = token_embedding(x) + positional_embedding[offset.. MLXArray { - squeezed().swappedAxes(0, 1) + squeezed().swappedAxes(0, 1).expandedDimensions(axes: [0]) } } @@ -72,8 +72,20 @@ extension MLXArray { } } +extension Embedding { + func asLinear(_ x: MLXArray) -> MLXArray { + x.matmul(weight.T) + } +} + // MARK: - Functions +func additiveCausalMask(_ n: Int, dType: MLX.DType = .float32) -> MLXArray { + let indices = MLXArray(Array(0.. MLXArray { assert(channels % 2 == 0) let logTimescaleIncrement = log(Float(maxTimescale)) / Float(channels / 2 - 1) diff --git a/Tests/WhisperKitMLXTests/MLXUnitTests.swift b/Tests/WhisperKitMLXTests/MLXUnitTests.swift index 1fb1f2d..c9f6dbf 100644 --- a/Tests/WhisperKitMLXTests/MLXUnitTests.swift +++ b/Tests/WhisperKitMLXTests/MLXUnitTests.swift @@ -44,6 +44,40 @@ final class MLXUnitTests: XCTestCase { XCTAssertEqual(encoderOutput?.shape, expectedShape, "Encoder output shape is not as expected") } + // MARK: - Decoder Tests + + func testDecoderOutput() async throws { + let textDecoder = MLXTextDecoder() + let decodingOptions = DecodingOptions() + let modelPath = try URL(filePath: tinyMLXModelPath()) + await XCTAssertNoThrowAsync( + try await textDecoder.loadModel(at: modelPath), + "Failed to load the model" + ) + textDecoder.tokenizer = try await XCTUnwrapAsync( + await loadTokenizer(for: .tiny), + "Failed to load the tokenizer" + ) + + let tokenSampler = GreedyTokenSampler( + temperature: 0, + eotToken: textDecoder.tokenizer!.specialTokens.endToken, + decodingOptions: decodingOptions + ) + + let encoderInput = try MLMultiArray(shape: [1, 384, 1, 1500], dataType: .float16) + let inputs = try textDecoder.prepareDecoderInputs(withPrompt: [textDecoder.tokenizer!.specialTokens.startOfTranscriptToken]) + + await XCTAssertNoThrowAsync( + try await textDecoder.decodeText( + from: encoderInput, + using: inputs, + sampler: tokenSampler, + options: decodingOptions + ) + ) + } + // MARK: - Utils Tests func testArrayConversion() throws { @@ -100,6 +134,22 @@ final class MLXUnitTests: XCTestCase { } } } + + let arr4 = MLXArray(input, [2, 3, 2, 2]) + let multiArray4 = try arr4.asMLMultiArray() + + XCTAssertEqual(arr4.shape, multiArray4.shape.map { $0.intValue }) + for dim1 in 0..<2 { + for dim2 in 0..<3 { + for dim3 in 0..<2 { + for dim4 in 0..<2 { + let v1 = multiArray4[[dim1, dim2, dim3, dim4] as [NSNumber]].floatValue + let v2 = arr4[dim1, dim2, dim3, dim4] + XCTAssertEqual(v1, v2.item(Float.self), accuracy: accuracy) + } + } + } + } } func testSinusoids() { @@ -122,4 +172,25 @@ final class MLXUnitTests: XCTestCase { XCTAssertEqual(result3[2].asArray(Float.self), [0.909297, 0.0926985, 0.00430886, 0.0002, -0.416147, 0.995694, 0.999991, 1.0], accuracy: accuracy) XCTAssertEqual(result3[3].asArray(Float.self), [0.14112, 0.138798, 0.00646326, 0.0003, -0.989992, 0.990321, 0.999979, 1.0], accuracy: accuracy) } + + func testAdditiveCausalMask() { + let result1 = additiveCausalMask(0) + XCTAssertEqual(result1.shape, [0 ,0]) + XCTAssertEqual(result1.dtype, .float32) + + let result2 = additiveCausalMask(3) + XCTAssertEqual(result2.shape, [3 ,3]) + XCTAssertEqual(result2.dtype, .float32) + XCTAssertEqual(result2[0].asArray(Float.self), [0.0, -1e9, -1e9], accuracy: accuracy) + XCTAssertEqual(result2[1].asArray(Float.self), [0.0, 0.0, -1e9], accuracy: accuracy) + XCTAssertEqual(result2[2].asArray(Float.self), [0.0, 0.0, 0.0], accuracy: accuracy) + + let result3 = additiveCausalMask(4) + XCTAssertEqual(result3.shape, [4 ,4]) + XCTAssertEqual(result3.dtype, .float32) + XCTAssertEqual(result3[0].asArray(Float.self), [0.0, -1e9, -1e9, -1e9], accuracy: accuracy) + XCTAssertEqual(result3[1].asArray(Float.self), [0.0, 0.0, -1e9, -1e9], accuracy: accuracy) + XCTAssertEqual(result3[2].asArray(Float.self), [0.0, 0.0, 0.0, -1e9], accuracy: accuracy) + XCTAssertEqual(result3[3].asArray(Float.self), [0.0, 0.0, 0.0, 0.0], accuracy: accuracy) + } } From ce604928d98084c0675667e53d11138e782c9aa0 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 15 Jun 2024 10:58:55 -0700 Subject: [PATCH 07/29] Fix merge --- Sources/WhisperKit/Core/TextDecoder.swift | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index 6f6775c..582fd81 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -451,7 +451,7 @@ public class TextDecoderContextPrefill: WhisperMLModel { @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) open class TextDecoder: TextDecoding, WhisperMLModel { - public var model: MLModel? + public var model: MLModel? public var tokenizer: WhisperTokenizer? public var prefillData: WhisperMLModel? public var isModelMultilingual: Bool = false @@ -492,12 +492,16 @@ open class TextDecoder: TextDecoding, WhisperMLModel { public func predictLogits( inputIds: MLMultiArray, cacheLength: MLMultiArray, - keyCache: MLMultiArray, - valueCache: MLMultiArray, + keyCache: MLMultiArray?, + valueCache: MLMultiArray?, kvCacheUpdateMask: MLMultiArray, encoderOutputEmbeds: MLMultiArray, decoderKeyPaddingMask: MLMultiArray ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? { + guard let model, let keyCache, let valueCache else { + return nil + } + let modelInputs = TextDecoderInput( input_ids: inputIds, cache_length: cacheLength, @@ -508,10 +512,6 @@ open class TextDecoder: TextDecoding, WhisperMLModel { decoder_key_padding_mask: decoderKeyPaddingMask ) - guard let model = model else { - return nil - } - try Task.checkCancellation() let outputFeatures = try await model.asyncPrediction(from: modelInputs, options: MLPredictionOptions()) From 1e12fe289d30da4e7c2e71ea4523201524d37be4 Mon Sep 17 00:00:00 2001 From: Jan Krukowski Date: Wed, 19 Jun 2024 22:01:05 +0200 Subject: [PATCH 08/29] Cleanup and more tests for MLX (#169) * Added more tests for MLX, cleanup * bumped timeout * fixed tests * reverted cache id --- .github/workflows/unit-tests.yml | 2 +- Sources/WhisperKit/Core/Models.swift | 2 +- Sources/WhisperKit/Core/WhisperKit.swift | 12 +- Sources/WhisperKit/MLX/Attention.swift | 54 ++--- Sources/WhisperKit/MLX/MLXAudioEncoder.swift | 32 +-- Sources/WhisperKit/MLX/MLXTextDecoder.swift | 26 +- Sources/WhisperKit/MLX/MLXUtils.swift | 11 +- Sources/WhisperKitTestsUtils/TestUtils.swift | 22 +- Tests/WhisperKitMLXTests/MLXUnitTests.swift | 242 ++++++++++++++++++- Tests/WhisperKitTests/UnitTests.swift | 93 +++---- 10 files changed, 368 insertions(+), 128 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 970fedc..16200de 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -53,7 +53,7 @@ jobs: mlx-disabled: "1", scheme: "whisperkit", } - timeout-minutes: 20 + timeout-minutes: 30 steps: - uses: actions/checkout@v4 - uses: maxim-lobanov/setup-xcode@v1 diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index c291380..26241de 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -30,7 +30,7 @@ public protocol WhisperMLModel: WhisperModel { } public protocol WhisperMLXModel: WhisperModel { - func loadModel(at modelPath: URL) async throws + func loadModel(at modelPath: URL, configPath: URL) async throws } public extension WhisperMLModel { diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index bbecf4f..c527337 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -271,7 +271,7 @@ open class WhisperKit { Logging.debug("Loaded feature extractor") } else if let featureExtractor = featureExtractor as? WhisperMLXModel { Logging.debug("Loading MLX feature extractor") - try await featureExtractor.loadModel(at: path) + try await featureExtractor.loadModel(at: path, configPath: path) Logging.debug("Loaded MLX feature extractor") } @@ -285,7 +285,10 @@ open class WhisperKit { Logging.debug("Loaded audio encoder") } else if let audioEncoder = audioEncoder as? WhisperMLXModel { Logging.debug("Loading MLX audio encoder") - try await audioEncoder.loadModel(at: path) + try await audioEncoder.loadModel( + at: path.appending(path: "encoder.safetensors"), + configPath: path.appending(path: "config.json") + ) Logging.debug("Loaded MLX audio encoder") } @@ -299,7 +302,10 @@ open class WhisperKit { Logging.debug("Loaded text decoder") } else if let textDecoder = textDecoder as? WhisperMLXModel { Logging.debug("Loading MLX text decoder") - try await textDecoder.loadModel(at: path) + try await textDecoder.loadModel( + at: path.appending(path: "decoder.safetensors"), + configPath: path.appending(path: "config.json") + ) Logging.debug("Loaded MLX text decoder") } diff --git a/Sources/WhisperKit/MLX/Attention.swift b/Sources/WhisperKit/MLX/Attention.swift index 5a43cd2..53d46d9 100644 --- a/Sources/WhisperKit/MLX/Attention.swift +++ b/Sources/WhisperKit/MLX/Attention.swift @@ -7,17 +7,17 @@ import MLXNN final class MultiHeadAttention: Module { let nHead: Int - let query: Linear - let key: Linear - let value: Linear - let out: Linear + @ModuleInfo(key: "query") private var query: Linear + @ModuleInfo(key: "key") private var key: Linear + @ModuleInfo(key: "value") private var value: Linear + @ModuleInfo(key: "out") private var out: Linear init(nState: Int, nHead: Int) { self.nHead = nHead - self.query = Linear(nState, nState) - self.key = Linear(nState, nState, bias: false) - self.value = Linear(nState, nState) - self.out = Linear(nState, nState) + self._query.wrappedValue = Linear(nState, nState) + self._key.wrappedValue = Linear(nState, nState, bias: false) + self._value.wrappedValue = Linear(nState, nState) + self._out.wrappedValue = Linear(nState, nState) } func callAsFunction( @@ -69,23 +69,23 @@ final class MultiHeadAttention: Module { } final class ResidualAttentionBlock: Module { - let attn: MultiHeadAttention - let attn_ln: LayerNorm - let mlp1: Linear - let mlp2: Linear - let mlp_ln: LayerNorm - let cross_attn: MultiHeadAttention? - let cross_attn_ln: LayerNorm? + @ModuleInfo(key: "attn") private var attn: MultiHeadAttention + @ModuleInfo(key: "attn_ln") private var attnLn: LayerNorm + @ModuleInfo(key: "mlp1") private var mlp1: Linear + @ModuleInfo(key: "mlp2") private var mlp2: Linear + @ModuleInfo(key: "mlp_ln") private var mlpLn: LayerNorm + @ModuleInfo(key: "cross_attn") private var crossAttn: MultiHeadAttention? + @ModuleInfo(key: "cross_attn_ln") private var crossAttnLn: LayerNorm? init(nState: Int, nHead: Int, crossAttention: Bool = false) { - self.attn = MultiHeadAttention(nState: nState, nHead: nHead) - self.attn_ln = LayerNorm(dimensions: nState) - self.cross_attn = crossAttention ? MultiHeadAttention(nState: nState, nHead: nHead) : nil - self.cross_attn_ln = crossAttention ? LayerNorm(dimensions: nState) : nil let nMlp = nState * 4 - self.mlp1 = Linear(nState, nMlp) - self.mlp2 = Linear(nMlp, nState) - self.mlp_ln = LayerNorm(dimensions: nState) + self._attn.wrappedValue = MultiHeadAttention(nState: nState, nHead: nHead) + self._attnLn.wrappedValue = LayerNorm(dimensions: nState) + self._crossAttn.wrappedValue = crossAttention ? MultiHeadAttention(nState: nState, nHead: nHead) : nil + self._crossAttnLn.wrappedValue = crossAttention ? LayerNorm(dimensions: nState) : nil + self._mlp1.wrappedValue = Linear(nState, nMlp) + self._mlp2.wrappedValue = Linear(nMlp, nState) + self._mlpLn.wrappedValue = LayerNorm(dimensions: nState) } func callAsFunction( @@ -95,12 +95,12 @@ final class ResidualAttentionBlock: Module { kvCache: KV? = nil, crossKvCache: KV? = nil ) -> ResidualAttentionBlockResult { - let attnResult = attn(attn_ln(x), mask: mask, kvCache: kvCache) + let attnResult = attn(attnLn(x), mask: mask, kvCache: kvCache) var x = x + attnResult.x - if let cross_attn, let cross_attn_ln { - let crossAttnResult = cross_attn(cross_attn_ln(x), xa: xa, kvCache: crossKvCache) + if let crossAttn, let crossAttnLn { + let crossAttnResult = crossAttn(crossAttnLn(x), xa: xa, kvCache: crossKvCache) x = x + crossAttnResult.x - x = x + mlp2(gelu(mlp1(mlp_ln(x)))) + x = x + mlp2(gelu(mlp1(mlpLn(x)))) return ResidualAttentionBlockResult( x: x, kv: attnResult.kv, @@ -108,7 +108,7 @@ final class ResidualAttentionBlock: Module { crossQk: crossAttnResult.qk ) } else { - x = x + mlp2(gelu(mlp1(mlp_ln(x)))) + x = x + mlp2(gelu(mlp1(mlpLn(x)))) return ResidualAttentionBlockResult( x: x, kv: attnResult.kv, diff --git a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift index 4000dd8..ab472ad 100644 --- a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift +++ b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift @@ -28,9 +28,9 @@ public class MLXAudioEncoder: AudioEncoding { } extension MLXAudioEncoder: WhisperMLXModel { - public func loadModel(at modelPath: URL) async throws { - let parameters = try loadParameters(at: modelPath.appending(path: "weights.safetensors"), forKey: "encoder") - let config = try loadConfig(at: modelPath.appending(path: "config.json")) + public func loadModel(at modelPath: URL, configPath: URL) async throws { + let parameters = try loadParameters(at: modelPath) + let config = try loadConfig(at: configPath) let encoder = AudioEncoder( nMels: config.nMels, nCtx: config.nAudioCtx, @@ -57,11 +57,11 @@ final class AudioEncoder: Module { let nLayer: Int let dType: MLX.DType - private let conv1: Conv1d - private let conv2: Conv1d - private let positionalEmbedding: MLXArray - private let blocks: [ResidualAttentionBlock] - private let ln_post: LayerNorm + @ModuleInfo(key: "conv1") private var conv1: Conv1d + @ModuleInfo(key: "conv2") private var conv2: Conv1d + @ModuleInfo(key: "blocks") private var blocks: [ResidualAttentionBlock] + @ModuleInfo(key: "ln_post") private var lnPost: LayerNorm + private let _positionalEmbedding: MLXArray init( nMels: Int, @@ -78,22 +78,22 @@ final class AudioEncoder: Module { self.nLayer = nLayer self.dType = dType - self.conv1 = Conv1d(inputChannels: nMels, outputChannels: nState, kernelSize: 3, padding: 1) - self.conv2 = Conv1d(inputChannels: nState, outputChannels: nState, kernelSize: 3, stride: 2, padding: 1) - self.positionalEmbedding = sinusoids(length: nCtx, channels: nState).asType(dType) - self.blocks = (0.. MLXArray { var x = MLXNN.gelu(conv1(x)) x = MLXNN.gelu(conv2(x)) - assert(Array(x.shape[1...]) == positionalEmbedding.shape, "incorrect audio shape") - x = x + positionalEmbedding + assert(Array(x.shape[1...]) == _positionalEmbedding.shape, "incorrect audio shape") + x = x + _positionalEmbedding for block in blocks { x = block(x).x } - x = ln_post(x) + x = lnPost(x) return x } } diff --git a/Sources/WhisperKit/MLX/MLXTextDecoder.swift b/Sources/WhisperKit/MLX/MLXTextDecoder.swift index eb30fd2..f46b533 100644 --- a/Sources/WhisperKit/MLX/MLXTextDecoder.swift +++ b/Sources/WhisperKit/MLX/MLXTextDecoder.swift @@ -457,9 +457,9 @@ public final class MLXTextDecoder: TextDecoding { } extension MLXTextDecoder: WhisperMLXModel { - public func loadModel(at modelPath: URL) async throws { - let parameters = try loadParameters(at: modelPath.appending(path: "weights.safetensors"), forKey: "decoder") - let config = try loadConfig(at: modelPath.appending(path: "config.json")) + public func loadModel(at modelPath: URL, configPath: URL) async throws { + let parameters = try loadParameters(at: modelPath) + let config = try loadConfig(at: configPath) let decoder = TextDecoder( nVocab: config.nVocab, nCtx: config.nTextCtx, @@ -490,10 +490,10 @@ final class TextDecoder: Module { let nLayer: Int let dtype: MLX.DType - private let token_embedding: Embedding - private let positional_embedding: MLXArray - private let blocks: [ResidualAttentionBlock] - private let ln: LayerNorm + @ModuleInfo(key: "token_embedding") private var tokenEmbedding: Embedding + @ModuleInfo(key: "positional_embedding") private var positionalEmbedding: MLXArray + @ModuleInfo(key: "blocks") private var blocks: [ResidualAttentionBlock] + @ModuleInfo(key: "ln") private var ln: LayerNorm private let _mask: MLXArray init( @@ -511,12 +511,12 @@ final class TextDecoder: Module { self.nLayer = nLayer self.dtype = dtype - self.token_embedding = Embedding(embeddingCount: nVocab, dimensions: nState) - self.positional_embedding = MLX.zeros([nCtx, nState]) - self.blocks = (0.. TextDecoderResult { let offset = kvCache?.first??.k.shape[1] ?? 0 var x = x[.newAxis, .ellipsis] - x = token_embedding(x) + positional_embedding[offset.. MLXArra return MLX.concatenated([MLX.sin(scaledTime), MLX.cos(scaledTime)], axis: 1) } -func loadParameters(at url: URL, forKey key: String? = nil) throws -> NestedDictionary { +func loadParameters(at url: URL) throws -> NestedDictionary { let arrays = try MLX.loadArrays(url: url) - let params = ModuleParameters.unflattened(arrays) - guard let key else { - return params - } - guard let keyParams = params[key] else { - throw CocoaError.error(.coderValueNotFound) - } - return NestedDictionary(item: keyParams) + return ModuleParameters.unflattened(arrays) } func loadConfig(at url: URL) throws -> MLXModelConfig { diff --git a/Sources/WhisperKitTestsUtils/TestUtils.swift b/Sources/WhisperKitTestsUtils/TestUtils.swift index 0b69450..df57ac7 100644 --- a/Sources/WhisperKitTestsUtils/TestUtils.swift +++ b/Sources/WhisperKitTestsUtils/TestUtils.swift @@ -132,27 +132,31 @@ public extension MLMultiArray { @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public extension XCTestCase { func transcribe( - with variant: ModelVariant, + modelPath: String, options: DecodingOptions, callback: TranscriptionCallback = nil, audioFile: String = "jfk.wav", + featureExtractor: (any FeatureExtracting)? = nil, + audioEncoder: (any AudioEncoding)? = nil, + textDecoder: (any TextDecoding)? = nil, file: StaticString = #file, line: UInt = #line ) async throws -> [TranscriptionResult] { - let modelPath: String - switch variant { - case .largev3: - modelPath = try largev3ModelPath() - default: - modelPath = try tinyModelPath() - } let computeOptions = ModelComputeOptions( melCompute: .cpuOnly, audioEncoderCompute: .cpuOnly, textDecoderCompute: .cpuOnly, prefillCompute: .cpuOnly ) - let whisperKit = try await WhisperKit(modelFolder: modelPath, computeOptions: computeOptions, verbose: true, logLevel: .debug) + let whisperKit = try await WhisperKit( + modelFolder: modelPath, + computeOptions: computeOptions, + featureExtractor: featureExtractor, + audioEncoder: audioEncoder, + textDecoder: textDecoder, + verbose: true, + logLevel: .debug + ) trackForMemoryLeaks(on: whisperKit, file: file, line: line) let audioComponents = audioFile.components(separatedBy: ".") diff --git a/Tests/WhisperKitMLXTests/MLXUnitTests.swift b/Tests/WhisperKitMLXTests/MLXUnitTests.swift index c9f6dbf..2368b3b 100644 --- a/Tests/WhisperKitMLXTests/MLXUnitTests.swift +++ b/Tests/WhisperKitMLXTests/MLXUnitTests.swift @@ -5,13 +5,19 @@ import XCTest import MLX import WhisperKitTestsUtils import CoreML +import NaturalLanguage @testable import WhisperKit @testable import WhisperKitMLX final class MLXUnitTests: XCTestCase { - + private var tinyModelPath: String! private let accuracy: Float = 0.00001 + override func setUp() async throws { + try await super.setUp() + self.tinyModelPath = try tinyMLXModelPath() + } + // MARK: - Feature Extractor Tests func testLogmelOutput() async throws { @@ -33,8 +39,11 @@ final class MLXUnitTests: XCTestCase { func testEncoderOutput() async throws { let audioEncoder = MLXAudioEncoder() - let modelPath = try URL(filePath: tinyMLXModelPath()) - try await audioEncoder.loadModel(at: modelPath) + let modelPath = URL(filePath: tinyModelPath) + try await audioEncoder.loadModel( + at: modelPath.appending(path: "encoder.safetensors"), + configPath: modelPath.appending(path: "config.json") + ) let encoderInput = try MLMultiArray(shape: [1, 80, 1, 3000], dataType: .float16) let expectedShape: [NSNumber] = [1, 384, 1, 1500] @@ -49,9 +58,12 @@ final class MLXUnitTests: XCTestCase { func testDecoderOutput() async throws { let textDecoder = MLXTextDecoder() let decodingOptions = DecodingOptions() - let modelPath = try URL(filePath: tinyMLXModelPath()) + let modelPath = URL(filePath: tinyModelPath) await XCTAssertNoThrowAsync( - try await textDecoder.loadModel(at: modelPath), + try await textDecoder.loadModel( + at: modelPath.appending(path: "decoder.safetensors"), + configPath: modelPath.appending(path: "config.json") + ), "Failed to load the model" ) textDecoder.tokenizer = try await XCTUnwrapAsync( @@ -78,6 +90,226 @@ final class MLXUnitTests: XCTestCase { ) } + func testDecoderLogProbThresholdDecodingFallback() async throws { + let decodingOptions = DecodingOptions( + withoutTimestamps: true, + compressionRatioThreshold: nil, + logProbThreshold: 1000.0, + firstTokenLogProbThreshold: nil, + noSpeechThreshold: nil + ) + let textDecoder = MLXTextDecoder() + let modelPath = URL(filePath: tinyModelPath) + try await textDecoder.loadModel( + at: modelPath.appending(path: "decoder.safetensors"), + configPath: modelPath.appending(path: "config.json") + ) + textDecoder.tokenizer = try await loadTokenizer(for: .tiny) + + let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions) + + let encoderInput = initMLMultiArray(shape: [1, 384, 1, 1500], dataType: .float16, initialValue: FloatType(0)) + let inputs = try textDecoder.prepareDecoderInputs(withPrompt: [textDecoder.tokenizer!.specialTokens.startOfTranscriptToken]) + let decoderOutput = try await textDecoder.decodeText(from: encoderInput, using: inputs, sampler: tokenSampler, options: decodingOptions) + + let fallback = try XCTUnwrap(decoderOutput.fallback, "Fallback should not be `nil`") + XCTAssertEqual(fallback.fallbackReason, "logProbThreshold") + XCTAssertTrue(fallback.needsFallback) + } + + func testDecoderFirstTokenLogProbThresholdDecodingFallback() async throws { + let decodingOptions = DecodingOptions( + withoutTimestamps: true, + compressionRatioThreshold: nil, + logProbThreshold: nil, + firstTokenLogProbThreshold: 1000.0, + noSpeechThreshold: nil + ) + let textDecoder = MLXTextDecoder() + let modelPath = URL(filePath: tinyModelPath) + try await textDecoder.loadModel( + at: modelPath.appending(path: "decoder.safetensors"), + configPath: modelPath.appending(path: "config.json") + ) + textDecoder.tokenizer = try await loadTokenizer(for: .tiny) + + let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions) + + let encoderInput = initMLMultiArray(shape: [1, 384, 1, 1500], dataType: .float16, initialValue: FloatType(0)) + let inputs = try textDecoder.prepareDecoderInputs(withPrompt: [textDecoder.tokenizer!.specialTokens.startOfTranscriptToken]) + let decoderOutput = try await textDecoder.decodeText(from: encoderInput, using: inputs, sampler: tokenSampler, options: decodingOptions) + + let fallback = try XCTUnwrap(decoderOutput.fallback, "Fallback should not be `nil`") + XCTAssertEqual(fallback.fallbackReason, "firstTokenLogProbThreshold") + XCTAssertTrue(fallback.needsFallback) + } + + // MARK: - Options Tests + + /// Multilingual Tests + /// NOTE: These are purely for consistency checks and do not reflect the ground truth translations + func testTranslateSpanish() async throws { + let targetLanguage = "es" + let options = DecodingOptions(task: .translate, language: targetLanguage, temperatureFallbackCount: 0) + + let result = try await XCTUnwrapAsync( + try await transcribe( + modelPath: tinyModelPath, + options: options, + audioFile: "es_test_clip.wav", + featureExtractor: MLXFeatureExtractor(), + audioEncoder: MLXAudioEncoder(), + textDecoder: MLXTextDecoder() + ), + "Failed to transcribe" + ) + + XCTAssertEqual(result.text.split(separator: " ").prefix(2).joined(separator: " "), "This is") + } + + func testTranscribeSpanish() async throws { + let sourceLanguage = "es" + let options = DecodingOptions(task: .transcribe, language: sourceLanguage, temperatureFallbackCount: 0) + + let result = try await XCTUnwrapAsync( + try await transcribe( + modelPath: tinyModelPath, + options: options, + audioFile: "es_test_clip.wav", + featureExtractor: MLXFeatureExtractor(), + audioEncoder: MLXAudioEncoder(), + textDecoder: MLXTextDecoder() + ), + "Failed to transcribe" + ) + + XCTAssertEqual(result.text.split(separator: " ").prefix(4).joined(separator: " "), "Esta es una grabación") + } + + func testDetectSpanish() async throws { + let targetLanguage = "es" + let whisperKit = try await WhisperKit( + modelFolder: tinyModelPath, + featureExtractor: MLXFeatureExtractor(), + audioEncoder: MLXAudioEncoder(), + textDecoder: MLXTextDecoder(), + verbose: true, + logLevel: .debug + ) + + let audioFilePath = try XCTUnwrap( + TestResource.path(forResource: "es_test_clip", ofType: "wav"), + "Audio file not found" + ) + + // To detect language only, set `sampleLength` to 1 and no prefill prompt + let optionsDetectOnly = DecodingOptions(task: .transcribe, temperatureFallbackCount: 0, sampleLength: 1, detectLanguage: true) + let resultNoPrefill: [TranscriptionResult] = try await whisperKit.transcribe(audioPath: audioFilePath, decodeOptions: optionsDetectOnly) + + XCTAssertEqual(resultNoPrefill.first?.language, targetLanguage) + } + + func testTranslateJapaneseOptions() async throws { + let targetLanguage = "ja" + let options = DecodingOptions(task: .translate, language: targetLanguage, temperatureFallbackCount: 0) + + let result = try await XCTUnwrapAsync( + try await transcribe( + modelPath: tinyModelPath, + options: options, + audioFile: "ja_test_clip.wav", + featureExtractor: MLXFeatureExtractor(), + audioEncoder: MLXAudioEncoder(), + textDecoder: MLXTextDecoder() + ), + "Failed to transcribe" + ) + + XCTAssertEqual(result.text.split(separator: " ").first, "Tokyo") + } + + func testTranscribeJapanese() async throws { + let sourceLanguage = "ja" + let options = DecodingOptions(task: .transcribe, language: sourceLanguage, temperatureFallbackCount: 0) + + let result = try await XCTUnwrapAsync( + try await transcribe( + modelPath: tinyModelPath, + options: options, + audioFile: "ja_test_clip.wav", + featureExtractor: MLXFeatureExtractor(), + audioEncoder: MLXAudioEncoder(), + textDecoder: MLXTextDecoder() + ), + "Failed to transcribe" + ) + + XCTAssertEqual(result.text.prefix(3), "東京は") + } + + func testDetectJapanese() async throws { + let targetLanguage = "ja" + let whisperKit = try await WhisperKit( + modelFolder: tinyModelPath, + featureExtractor: MLXFeatureExtractor(), + audioEncoder: MLXAudioEncoder(), + textDecoder: MLXTextDecoder(), + verbose: true, + logLevel: .debug + ) + + let audioFilePath = try XCTUnwrap( + TestResource.path(forResource: "ja_test_clip", ofType: "wav"), + "Audio file not found" + ) + + // To detect language only, set `sampleLength` to 1 and no prefill prompt + let optionsDetectOnly = DecodingOptions(task: .transcribe, temperatureFallbackCount: 0, sampleLength: 1, detectLanguage: true) + let result: [TranscriptionResult] = try await whisperKit.transcribe(audioPath: audioFilePath, decodeOptions: optionsDetectOnly) + + XCTAssertEqual(result.first?.language, targetLanguage) + } + + func testDetectJapaneseOptions() async throws { + let optionsPairs: [(options: DecodingOptions, language: String)] = [ + (DecodingOptions(task: .transcribe, temperatureFallbackCount: 0, usePrefillPrompt: true, detectLanguage: true), "ja"), // recommended usage for transcribing unknown language + (DecodingOptions(task: .transcribe, temperatureFallbackCount: 0, usePrefillPrompt: true, detectLanguage: false), "en"), // en is the default prompt language + (DecodingOptions(task: .transcribe, temperatureFallbackCount: 0, usePrefillPrompt: true, detectLanguage: nil), "en"), // en is the default prompt language + (DecodingOptions(task: .transcribe, temperatureFallbackCount: 0, usePrefillPrompt: false, detectLanguage: true), "ja"), // Unecessary combination, but can be useful if used with low `sampleLength` values to purely detect language and not decode (see above) + (DecodingOptions(task: .transcribe, temperatureFallbackCount: 0, usePrefillPrompt: false, detectLanguage: false), "ja"), // no prefill, model will detect language naturally + (DecodingOptions(task: .transcribe, temperatureFallbackCount: 0, usePrefillPrompt: false, detectLanguage: nil), "ja"), // no prefill, model will detect language naturally + ] + + for (i, option) in optionsPairs.enumerated() { + let result = try await XCTUnwrapAsync( + try await transcribe( + modelPath: tinyModelPath, + options: option.options, + audioFile: "ja_test_clip.wav", + featureExtractor: MLXFeatureExtractor(), + audioEncoder: MLXAudioEncoder(), + textDecoder: MLXTextDecoder() + ), + "Failed to transcribe" + ) + + let recognizer = NLLanguageRecognizer() + recognizer.processString(result.text) + let languageCode = recognizer.dominantLanguage!.rawValue + + XCTAssertEqual( + languageCode, + option.language, + "Text language \"\(languageCode)\" at index \(i) did not match expected language \"\(option.language)\"" + ) + XCTAssertEqual( + result.first?.language, + option.language, + "Result language \"\(String(describing: result.first?.language))\" at index \(i) did not match expected language \"\(option.language)\"" + ) + } + } + // MARK: - Utils Tests func testArrayConversion() throws { diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index 2b0fa16..aaf39ad 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -12,6 +12,13 @@ import XCTest @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) final class UnitTests: XCTestCase { + private var tinyModelPath: String! + + override func setUp() async throws { + try await super.setUp() + self.tinyModelPath = try tinyModelPath() + } + // MARK: - Model Loading Test func testInit() async throws { @@ -23,7 +30,7 @@ final class UnitTests: XCTestCase { func testInitTiny() async throws { try await XCTUnwrapAsync( - await WhisperKit(modelFolder: tinyModelPath(), logLevel: .error), + await WhisperKit(modelFolder: tinyModelPath, logLevel: .error), "Failed to init WhisperKit" ) } @@ -98,7 +105,7 @@ final class UnitTests: XCTestCase { "Failed to pad audio samples" ) let featureExtractor = FeatureExtractor() - let modelPath = try URL(filePath: tinyModelPath()).appending(path: "MelSpectrogram.mlmodelc") + let modelPath = URL(filePath: tinyModelPath).appending(path: "MelSpectrogram.mlmodelc") try await featureExtractor.loadModel(at: modelPath, computeUnits: ModelComputeOptions().melCompute) let melSpectrogram = try await XCTUnwrapAsync( await featureExtractor.logMelSpectrogram(fromAudio: paddedSamples), @@ -137,7 +144,7 @@ final class UnitTests: XCTestCase { func testEncoderOutput() async throws { let audioEncoder = AudioEncoder() - let modelPath = try URL(filePath: tinyModelPath()).appending(path: "AudioEncoder.mlmodelc") + let modelPath = URL(filePath: tinyModelPath).appending(path: "AudioEncoder.mlmodelc") try? await audioEncoder.loadModel(at: modelPath, computeUnits: ModelComputeOptions().audioEncoderCompute) let encoderInput = try MLMultiArray(shape: [1, 80, 1, 3000], dataType: .float16) @@ -153,7 +160,7 @@ final class UnitTests: XCTestCase { func testDecoderOutput() async throws { let textDecoder = TextDecoder() let decodingOptions = DecodingOptions() - let modelPath = try URL(filePath: tinyModelPath()).appending(path: "TextDecoder.mlmodelc") + let modelPath = URL(filePath: tinyModelPath).appending(path: "TextDecoder.mlmodelc") await XCTAssertNoThrowAsync( try await textDecoder.loadModel(at: modelPath, computeUnits: ModelComputeOptions().textDecoderCompute), "Failed to load the model" @@ -191,7 +198,7 @@ final class UnitTests: XCTestCase { noSpeechThreshold: nil ) let textDecoder = TextDecoder() - let modelPath = try URL(filePath: tinyModelPath()).appending(path: "TextDecoder.mlmodelc") + let modelPath = URL(filePath: tinyModelPath).appending(path: "TextDecoder.mlmodelc") try await textDecoder.loadModel(at: modelPath, computeUnits: ModelComputeOptions().textDecoderCompute) textDecoder.tokenizer = try await loadTokenizer(for: .tiny) @@ -215,7 +222,7 @@ final class UnitTests: XCTestCase { noSpeechThreshold: nil ) let textDecoder = TextDecoder() - let modelPath = try URL(filePath: tinyModelPath()).appending(path: "TextDecoder.mlmodelc") + let modelPath = URL(filePath: tinyModelPath).appending(path: "TextDecoder.mlmodelc") try await textDecoder.loadModel(at: modelPath, computeUnits: ModelComputeOptions().textDecoderCompute) textDecoder.tokenizer = try await loadTokenizer(for: .tiny) @@ -301,7 +308,7 @@ final class UnitTests: XCTestCase { } let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options, callback: continuationCallback).first!, + try await transcribe(modelPath: tinyModelPath, options: options, callback: continuationCallback).first!, "Failed to transcribe" ) @@ -316,7 +323,7 @@ final class UnitTests: XCTestCase { } let resultWithWait = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options, callback: continuationCallbackWithWait).first!, + try await transcribe(modelPath: tinyModelPath, options: options, callback: continuationCallbackWithWait).first!, "Failed to transcribe" ) @@ -388,7 +395,7 @@ final class UnitTests: XCTestCase { melCompute: .cpuOnly ) let whisperKit = try await WhisperKit( - modelFolder: tinyModelPath(), + modelFolder: tinyModelPath, computeOptions: computeOptions, verbose: true, logLevel: .debug @@ -483,7 +490,7 @@ final class UnitTests: XCTestCase { for option in options { let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: option), + try await transcribe(modelPath: tinyModelPath, options: option), "Failed to transcribe" ) XCTAssertEqual(result.segments.first?.tokens.count, targetTokenCount) @@ -497,7 +504,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(task: .translate, language: targetLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options, audioFile: "es_test_clip.wav"), + try await transcribe(modelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav"), "Failed to transcribe" ) @@ -509,7 +516,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(task: .transcribe, language: sourceLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options, audioFile: "es_test_clip.wav"), + try await transcribe(modelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav"), "Failed to transcribe" ) @@ -519,7 +526,7 @@ final class UnitTests: XCTestCase { func testDetectSpanish() async throws { let targetLanguage = "es" let whisperKit = try await WhisperKit( - modelFolder: tinyModelPath(), + modelFolder: tinyModelPath, verbose: true, logLevel: .debug ) @@ -549,7 +556,7 @@ final class UnitTests: XCTestCase { for (i, option) in optionsPairs.enumerated() { let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: option.options, audioFile: "es_test_clip.wav"), + try await transcribe(modelPath: tinyModelPath, options: option.options, audioFile: "es_test_clip.wav"), "Failed to transcribe" ) @@ -575,7 +582,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(task: .translate, language: targetLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options, audioFile: "ja_test_clip.wav"), + try await transcribe(modelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav"), "Failed to transcribe" ) @@ -587,7 +594,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(task: .transcribe, language: sourceLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options, audioFile: "ja_test_clip.wav"), + try await transcribe(modelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav"), "Failed to transcribe" ) @@ -597,7 +604,7 @@ final class UnitTests: XCTestCase { func testDetectJapanese() async throws { let targetLanguage = "ja" let whisperKit = try await WhisperKit( - modelFolder: tinyModelPath(), + modelFolder: tinyModelPath, verbose: true, logLevel: .debug ) @@ -626,7 +633,7 @@ final class UnitTests: XCTestCase { for (i, option) in optionsPairs.enumerated() { let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: option.options, audioFile: "ja_test_clip.wav"), + try await transcribe(modelPath: tinyModelPath, options: option.options, audioFile: "ja_test_clip.wav"), "Failed to transcribe" ) @@ -650,7 +657,7 @@ final class UnitTests: XCTestCase { func testDetectLanguageHelperMethod() async throws { let targetLanguages = ["es", "ja"] let whisperKit = try await WhisperKit( - modelFolder: tinyModelPath(), + modelFolder: tinyModelPath, verbose: true, logLevel: .debug ) @@ -672,7 +679,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(withoutTimestamps: true) let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options), + try await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -683,7 +690,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(skipSpecialTokens: true, withoutTimestamps: true) let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options), + try await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -694,7 +701,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(usePrefillPrompt: true) try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options), + try await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) } @@ -703,7 +710,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(usePrefillPrompt: false) let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options), + try await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -711,7 +718,7 @@ final class UnitTests: XCTestCase { } func testSilence() async throws { - let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug) + let whisperKit = try await WhisperKit(modelFolder: tinyModelPath, verbose: true, logLevel: .debug) let audioSamples = [Float](repeating: 0.0, count: 30 * 16000) let options = DecodingOptions(usePrefillPrompt: false, skipSpecialTokens: false) @@ -723,7 +730,7 @@ final class UnitTests: XCTestCase { } func testTemperatureIncrement() async throws { - let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug) + let whisperKit = try await WhisperKit(modelFolder: tinyModelPath, verbose: true, logLevel: .debug) // Generate random audio samples let audioSamples = (0..<(30 * 16000)).map { _ in Float.random(in: -0.7...0.7) } @@ -750,11 +757,11 @@ final class UnitTests: XCTestCase { func testTopK() async throws { let result10000 = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: DecodingOptions(temperature: 0.5, topK: 10000)).first, + try await transcribe(modelPath: tinyModelPath, options: DecodingOptions(temperature: 0.5, topK: 10000)).first, "Failed to transcribe" ) let result5 = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: DecodingOptions(temperature: 0.5)).first, + try await transcribe(modelPath: tinyModelPath, options: DecodingOptions(temperature: 0.5)).first, "Failed to transcribe" ) @@ -765,7 +772,7 @@ final class UnitTests: XCTestCase { var options = DecodingOptions(withoutTimestamps: true, clipTimestamps: [0]) let resultFull = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options), + try await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -773,7 +780,7 @@ final class UnitTests: XCTestCase { options = DecodingOptions(withoutTimestamps: true, clipTimestamps: [seekTime]) let resultSeek = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options), + try await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -786,14 +793,14 @@ final class UnitTests: XCTestCase { } func testPromptTokens() async throws { - let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug) + let whisperKit = try await WhisperKit(modelFolder: tinyModelPath, verbose: true, logLevel: .debug) let promptText = " prompt to encourage output without any punctuation and without capitalizing americans as if it was already normalized" let tokenizer = try XCTUnwrap(whisperKit.tokenizer) let promptTokens = tokenizer.encode(text: promptText).filter { $0 < tokenizer.specialTokens.specialTokenBegin } let options = DecodingOptions(skipSpecialTokens: true, promptTokens: promptTokens) let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options), + try await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -801,7 +808,7 @@ final class UnitTests: XCTestCase { } func testPrefixTokens() async throws { - let whisperKit = try await WhisperKit(modelFolder: tinyModelPath(), verbose: true, logLevel: .debug) + let whisperKit = try await WhisperKit(modelFolder: tinyModelPath, verbose: true, logLevel: .debug) // Prefix to encourage output without any punctuation and without capitalizing americans as if it was already normalized let prefixText = " and so my fellow americans" let tokenizer = try XCTUnwrap(whisperKit.tokenizer) @@ -809,7 +816,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(skipSpecialTokens: true, prefixTokens: prefixTokens) let result = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options), + try await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -1153,14 +1160,14 @@ final class UnitTests: XCTestCase { func testVADAudioChunkerAccuracy() async throws { let testResult = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: DecodingOptions(), audioFile: "ted_60.m4a"), + try await transcribe(modelPath: tinyModelPath, options: DecodingOptions(), audioFile: "ted_60.m4a"), "Failed to transcribe" ) let options = DecodingOptions(chunkingStrategy: .vad) let chunkedResult = try await XCTUnwrapAsync( - await transcribe(with: .tiny, options: options, audioFile: "ted_60.m4a"), + try await transcribe(modelPath: tinyModelPath, options: options, audioFile: "ted_60.m4a"), "Failed to transcribe" ) @@ -1444,14 +1451,10 @@ final class UnitTests: XCTestCase { } } - func testWordTimestampCorrectness() async { + func testWordTimestampCorrectness() async throws { let options = DecodingOptions(wordTimestamps: true) - guard let result = try? await transcribe(with: .tiny, options: options) else { - XCTFail("Failed to transcribe") - return - } - + let result = try await transcribe(modelPath: tinyModelPath, options: options) let wordTimings = result.segments.compactMap { $0.words }.flatMap { $0 } let expectedWordTimings = [ @@ -1496,9 +1499,11 @@ final class UnitTests: XCTestCase { func testStreamingTimestamps() async throws { let options = DecodingOptions(usePrefillPrompt: true, wordTimestamps: true) - let modelPath = try tinyModelPath() - - let whisperKit = try await WhisperKit(modelFolder: modelPath, /* computeOptions: computeOptions,*/ verbose: true, logLevel: .debug) + let whisperKit = try await WhisperKit( + modelFolder: tinyModelPath, + verbose: true, + logLevel: .debug + ) let audioFilePath = try XCTUnwrap( TestResource.path(forResource: "jfk", ofType: "wav"), From 4d24e43734974788297e78e9aff898faa9bcc82d Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Fri, 28 Jun 2024 08:42:15 -0700 Subject: [PATCH 09/29] Formatting --- Makefile | 2 +- Sources/WhisperKit/MLX/MLXAudioEncoder.swift | 3 +- .../WhisperKit/MLX/MLXFeatureExtractor.swift | 28 +++++++++---------- Sources/WhisperKit/MLX/MLXTextDecoder.swift | 5 ++++ Sources/WhisperKit/MLX/MLXUtils.swift | 22 +++++++-------- 5 files changed, 33 insertions(+), 27 deletions(-) diff --git a/Makefile b/Makefile index 6f33117..3c51901 100644 --- a/Makefile +++ b/Makefile @@ -89,7 +89,7 @@ download-model: @cd $(MODEL_REPO_DIR) && \ git lfs pull --include="openai_whisper-$(MODEL)/*" -download-mlx-model: +download-mlx-models: @echo "Downloading mlx model $(MODEL)..." @$(MAKE) setup-mlx-model-repo @echo "Fetching mlx model $(MODEL)..." diff --git a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift index ab472ad..f6dcedd 100644 --- a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift +++ b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift @@ -3,14 +3,15 @@ import CoreML import MLX -import WhisperKit import MLXNN +import WhisperKit @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public class MLXAudioEncoder: AudioEncoding { public var embedSize: Int? { encoder?.nState } + private var encoder: AudioEncoder? public init() {} diff --git a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift index 0294217..29aecc9 100644 --- a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift +++ b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift @@ -1,10 +1,10 @@ // For licensing see accompanying LICENSE.md file. // Copyright © 2024 Argmax, Inc. All rights reserved. +import CoreML import Foundation import MLX import MLXFFT -import CoreML import WhisperKit @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) @@ -39,10 +39,10 @@ open class MLXFeatureExtractor: FeatureExtracting { } } -extension MLXFeatureExtractor { +public extension MLXFeatureExtractor { /// Return the Hanning window. /// Taken from [numpy](https://numpy.org/doc/stable/reference/generated/numpy.hanning.html) implementation - public static func hanningNumpy(_ size: Int) -> MLXArray { + static func hanningNumpy(_ size: Int) -> MLXArray { if size < 1 { return MLXArray([Float]()) } @@ -53,26 +53,26 @@ extension MLXFeatureExtractor { return 0.5 + 0.5 * MLX.cos(MLXArray(.pi) * n / Float(size - 1)) } - public static func hanning(_ size: Int) -> MLXArray { + static func hanning(_ size: Int) -> MLXArray { hanningNumpy(size + 1)[..<(-1)] } - public static func pad( + static func pad( _ x: MLXArray, padding: Int, padMode: PadMode = .constant ) -> MLXArray { switch padMode { - case .constant: - return MLX.padded(x, widths: [IntOrPair((padding, padding))]) - case .reflect: - let prefix = x[1 ..< padding + 1][.stride(by: -1)] - let suffix = x[-(padding + 1) ..< -1][.stride(by: -1)] - return MLX.concatenated([prefix, x, suffix]) + case .constant: + return MLX.padded(x, widths: [IntOrPair((padding, padding))]) + case .reflect: + let prefix = x[1.. MLXArray { + static func loadMelFilters(nMels: Int) -> MLXArray { precondition(nMels == 80 || nMels == 128, "Unsupported nMels: \(nMels)") let fileUrl = Bundle.module.url(forResource: "mel_filters_\(nMels)", withExtension: "npy")! return try! MLX.loadArray(url: fileUrl) diff --git a/Sources/WhisperKit/MLX/MLXTextDecoder.swift b/Sources/WhisperKit/MLX/MLXTextDecoder.swift index f46b533..2661c40 100644 --- a/Sources/WhisperKit/MLX/MLXTextDecoder.swift +++ b/Sources/WhisperKit/MLX/MLXTextDecoder.swift @@ -15,22 +15,27 @@ public final class MLXTextDecoder: TextDecoding { public var logitsSize: Int? { decoder?.nState } + public var kvCacheEmbedDim: Int? { guard let config else { return nil } return config.nTextState * config.nTextLayer } + public var kvCacheMaxSequenceLength: Int? { guard let config else { return nil } return config.nTextCtx / 2 } + public var windowSize: Int? { guard let config else { return nil } return config.nAudioCtx } + public var embedSize: Int? { guard let config else { return nil } return config.nTextState } + private var decoder: TextDecoder? private var config: MLXModelConfig? private var languageLogitsFilter: LanguageLogitsFilter? diff --git a/Sources/WhisperKit/MLX/MLXUtils.swift b/Sources/WhisperKit/MLX/MLXUtils.swift index 95d6863..5059ac1 100644 --- a/Sources/WhisperKit/MLX/MLXUtils.swift +++ b/Sources/WhisperKit/MLX/MLXUtils.swift @@ -1,10 +1,10 @@ // For licensing see accompanying LICENSE.md file. // Copyright © 2024 Argmax, Inc. All rights reserved. +import CoreML import Foundation import MLX import MLXNN -import CoreML // MARK: - Extensions @@ -58,16 +58,16 @@ extension MLXArray { extension MLXArray { func multiArrayDataType() -> MLMultiArrayDataType { switch dtype { - case .bool, .bfloat16, .complex64, - .uint8, .uint16, .uint32, .uint64, - .int8, .int16, .int64: - fatalError("Unsupported type: \(dtype)") - case .int32: - return .int32 - case .float16: - return .float16 - case .float32: - return .float32 + case .bool, .bfloat16, .complex64, + .uint8, .uint16, .uint32, .uint64, + .int8, .int16, .int64: + fatalError("Unsupported type: \(dtype)") + case .int32: + return .int32 + case .float16: + return .float16 + case .float32: + return .float32 } } } From 20549e1c030bcb09fd93d3b8458601aba33f16c5 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Fri, 12 Jul 2024 15:39:24 -0700 Subject: [PATCH 10/29] Fix merge for makefile function --- .github/workflows/unit-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 4142aab..5dff176 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -69,7 +69,7 @@ jobs: key: ${{ runner.os }}-models - name: Download Models if: steps.model-cache.outputs.cache-hit != 'true' - run: make download-model MODEL=tiny && make download-mlx-model + run: make download-model MODEL=tiny && make download-mlx-models - name: Install and discover destinations env: MLX_DISABLED: ${{ matrix.run-config['mlx-disabled'] }} From 9e9e13a8da846a3e8eda677ac5fd832daf755c69 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 13 Jul 2024 15:39:41 -0700 Subject: [PATCH 11/29] Skip plugin validation in CI --- .github/workflows/unit-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 5dff176..2dd3607 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -94,5 +94,5 @@ jobs: if: ${{ matrix.run-config['condition'] == true }} run: | set -o pipefail - xcodebuild clean build-for-testing -scheme whisperkit-Package -destination '${{ matrix.run-config['clean-destination'] }}' | xcpretty - xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}' + xcodebuild clean build-for-testing -scheme whisperkit-Package -destination '${{ matrix.run-config['clean-destination'] }}' -skipPackagePluginValidation | xcpretty + xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}' -skipPackagePluginValidation From ca462148d31f8a4c33fbb78a9bf0d51871082977 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 13 Jul 2024 15:58:41 -0700 Subject: [PATCH 12/29] Fix tests from merge --- Tests/WhisperKitTests/RegressionTests.swift | 31 +-------------------- Tests/WhisperKitTests/UnitTests.swift | 4 +-- 2 files changed, 3 insertions(+), 32 deletions(-) diff --git a/Tests/WhisperKitTests/RegressionTests.swift b/Tests/WhisperKitTests/RegressionTests.swift index 7848bdb..57e4cec 100644 --- a/Tests/WhisperKitTests/RegressionTests.swift +++ b/Tests/WhisperKitTests/RegressionTests.swift @@ -25,35 +25,6 @@ final class RegressionTests: XCTestCase { } } - func testOutputAll() async throws { - let modelPaths = try allModelPaths() - - for modelPath in modelPaths { - let modelName = modelPath.split(separator: "/").last! - print("[Integration] Testing model \(modelName)") - let audioFilePath = try XCTUnwrap( - TestResource.path(forResource: "jfk", ofType: "wav"), - "Audio file not found" - ) - - let whisperKit = try await WhisperKit( - modelFolder: modelPath, - verbose: true, - logLevel: .debug - ) - - let transcriptionResult: [TranscriptionResult] = try await whisperKit.transcribe(audioPath: audioFilePath) - let transcriptionResultText = transcriptionResult.text - - print("[Integration] \(transcriptionResultText)") - XCTAssertEqual( - transcriptionResultText.normalized, - " And so my fellow Americans ask not what your country can do for you, ask what you can do for your country.".normalized, - "Transcription result does not match expected result for model \(modelName)" - ) - } - } - func downloadTestAudio(completion: @escaping (Bool) -> Void) { Task { do { @@ -153,7 +124,7 @@ final class RegressionTests: XCTestCase { let modelName = modelPath.split(separator: "/").last! print("[Integration] Testing model \(modelName)") let audioFilePath = try XCTUnwrap( - Bundle.module.path(forResource: "jfk", ofType: "wav"), + TestResource.path(forResource: "jfk", ofType: "wav"), "Audio file not found" ) diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index c5735fe..788342e 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -96,7 +96,7 @@ final class UnitTests: XCTestCase { Logging.shared.logLevel = .debug let audioFileURL = try XCTUnwrap( - Bundle.module.url(forResource: "jfk", withExtension: "wav"), + TestResource.url(forResource: "jfk", withExtension: "wav"), "Audio file not found" ) let audioFile = try AVAudioFile(forReading: audioFileURL) @@ -1269,7 +1269,7 @@ final class UnitTests: XCTestCase { } } _ = try await pipe.transcribe( - audioPath: Bundle.module.path(forResource: "ted_60", ofType: "m4a")!, + audioPath: TestResource.path(forResource: "ted_60", ofType: "m4a")!, decodeOptions: .init(chunkingStrategy: .vad) ) cancellable?.cancel() From 1615d69a5c18bbd8d5684035dc7fdb23e5eab014 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 13 Jul 2024 17:31:52 -0700 Subject: [PATCH 13/29] Update model paths --- Makefile | 11 +++++++---- Package.swift | 15 +++++++++++++-- Sources/WhisperKitTestsUtils/TestUtils.swift | 4 ++-- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index 3c51901..58af4a3 100644 --- a/Makefile +++ b/Makefile @@ -5,10 +5,10 @@ PYTHON_COMMAND := python3 # Define model repository and directories MODEL_REPO := argmaxinc/whisperkit-coreml -MLX_MODEL_REPO := jkrukowski/whisper-tiny-mlx-safetensors -MLX_MODEL_REPO_DIR := ./Sources/WhisperKitTestsUtils/Models/mlx/whisper-tiny-mlx -MODEL_REPO_DIR := ./Sources/WhisperKitTestsUtils/Models/whisperkit-coreml -BASE_COMPILED_DIR := ./Sources/WhisperKitTestsUtils/Models +MODEL_REPO_DIR := ./Models/whisperkit-coreml +MLX_MODEL_REPO := argmaxinc/whisperkit-mlx +MLX_MODEL_REPO_DIR := ./Models/whisperkit-mlx +BASE_COMPILED_DIR := ./Models setup: @@ -76,6 +76,7 @@ download-models: setup-model-repo @echo "Downloading all models..." @cd $(MODEL_REPO_DIR) && \ git lfs pull + @echo "CoreML models downloaded to $(MODEL_REPO_DIR)" # Download a specific model download-model: @@ -88,6 +89,7 @@ download-model: @echo "Fetching model $(MODEL)..." @cd $(MODEL_REPO_DIR) && \ git lfs pull --include="openai_whisper-$(MODEL)/*" + @echo "CoreML model $(MODEL) downloaded to $(MODEL_REPO_DIR)/openai_whisper-$(MODEL)" download-mlx-models: @echo "Downloading mlx model $(MODEL)..." @@ -95,6 +97,7 @@ download-mlx-models: @echo "Fetching mlx model $(MODEL)..." @cd $(MLX_MODEL_REPO_DIR) && \ git lfs pull + @echo "MLX models downloaded to $(MLX_MODEL_REPO_DIR)" build: @echo "Building WhisperKit..." diff --git a/Package.swift b/Package.swift index dd76229..09b33f8 100644 --- a/Package.swift +++ b/Package.swift @@ -78,10 +78,21 @@ func targets() -> [PackageDescription.Target] { "WhisperKit", .product(name: "Transformers", package: "swift-transformers"), ], + path: ".", + exclude: [ + "Examples", + "Sources/WhisperKit", + "Sources/WhisperKitCLI", + "Tests", + "Makefile", + "README.md", + "LICENSE", + "CONTRIBUTING.md", + ], resources: [ .copy("Models/whisperkit-coreml"), - .copy("Models/mlx"), - .process("Resources"), + .copy("Models/whisperkit-mlx"), + .process("Sources/WhisperKitTestsUtils/Resources") ] ), .testTarget( diff --git a/Sources/WhisperKitTestsUtils/TestUtils.swift b/Sources/WhisperKitTestsUtils/TestUtils.swift index aa83581..7bfae95 100644 --- a/Sources/WhisperKitTestsUtils/TestUtils.swift +++ b/Sources/WhisperKitTestsUtils/TestUtils.swift @@ -176,9 +176,9 @@ public extension XCTestCase { } func tinyMLXModelPath() throws -> String { - let modelDir = "mlx/whisper-tiny-mlx" + let modelDir = "whisperkit-mlx/openai_whisper-tiny" guard let modelPath = Bundle.module.urls(forResourcesWithExtension: "safetensors", subdirectory: modelDir)?.first?.deletingLastPathComponent().path else { - throw TestError.missingFile("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-models`") + throw TestError.missingFile("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-mlx-models`") } return modelPath } From 08eb93e835982ca59cbfb54d9be9d4574bd86aad Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 13 Jul 2024 18:10:42 -0700 Subject: [PATCH 14/29] Fix model downloads --- .github/workflows/unit-tests.yml | 5 ++++- Makefile | 35 +++++++++++++++++++++----------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 2dd3607..c7f6f87 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -69,7 +69,10 @@ jobs: key: ${{ runner.os }}-models - name: Download Models if: steps.model-cache.outputs.cache-hit != 'true' - run: make download-model MODEL=tiny && make download-mlx-models + run: | + make setup-huggingface-cli + make download-model MODEL=tiny + make download-mlx-model MODEL=tiny - name: Install and discover destinations env: MLX_DISABLED: ${{ matrix.run-config['mlx-disabled'] }} diff --git a/Makefile b/Makefile index 58af4a3..2190c7b 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,15 @@ -.PHONY: setup setup-huggingface-cli setup-model-repo download-models download-model build build-cli test clean-package-caches +.PHONY: setup setup-huggingface-cli setup-model-repo download-models download-model download-mlx-models download-mlx-model build build-cli test clean-package-caches PIP_COMMAND := pip3 PYTHON_COMMAND := python3 # Define model repository and directories MODEL_REPO := argmaxinc/whisperkit-coreml -MODEL_REPO_DIR := ./Models/whisperkit-coreml MLX_MODEL_REPO := argmaxinc/whisperkit-mlx + +MODEL_REPO_DIR := ./Models/whisperkit-coreml MLX_MODEL_REPO_DIR := ./Models/whisperkit-mlx -BASE_COMPILED_DIR := ./Models +BASE_MODEL_DIR := ./Models setup: @@ -47,7 +48,7 @@ setup-huggingface-cli: setup-model-repo: @echo "Setting up repository..." - @mkdir -p $(BASE_COMPILED_DIR) + @mkdir -p $(BASE_MODEL_DIR) @if [ -d "$(MODEL_REPO_DIR)/.git" ]; then \ echo "Repository exists, resetting..."; \ export GIT_LFS_SKIP_SMUDGE=1; \ @@ -60,7 +61,7 @@ setup-model-repo: setup-mlx-model-repo: @echo "Setting up mlx repository..." - @mkdir -p $(BASE_COMPILED_DIR) + @mkdir -p $(BASE_MODEL_DIR) @if [ -d "$(MLX_MODEL_REPO_DIR)/.git" ]; then \ echo "Repository exists, resetting..."; \ export GIT_LFS_SKIP_SMUDGE=1; \ @@ -71,6 +72,7 @@ setup-mlx-model-repo: git clone https://huggingface.co/$(MLX_MODEL_REPO) $(MLX_MODEL_REPO_DIR); \ fi + # Download all models download-models: setup-model-repo @echo "Downloading all models..." @@ -78,27 +80,36 @@ download-models: setup-model-repo git lfs pull @echo "CoreML models downloaded to $(MODEL_REPO_DIR)" + # Download a specific model -download-model: +download-model: setup-model-repo @if [ -z "$(MODEL)" ]; then \ echo "Error: MODEL is not set. Usage: make download-model MODEL=base"; \ exit 1; \ fi @echo "Downloading model $(MODEL)..." - @$(MAKE) setup-model-repo - @echo "Fetching model $(MODEL)..." @cd $(MODEL_REPO_DIR) && \ git lfs pull --include="openai_whisper-$(MODEL)/*" @echo "CoreML model $(MODEL) downloaded to $(MODEL_REPO_DIR)/openai_whisper-$(MODEL)" -download-mlx-models: - @echo "Downloading mlx model $(MODEL)..." - @$(MAKE) setup-mlx-model-repo - @echo "Fetching mlx model $(MODEL)..." + +download-mlx-models: setup-mlx-model-repo + @echo "Downloading all mlx models..." @cd $(MLX_MODEL_REPO_DIR) && \ git lfs pull @echo "MLX models downloaded to $(MLX_MODEL_REPO_DIR)" + +download-mlx-model: setup-mlx-model-repo + @if [ -z "$(MODEL)" ]; then \ + echo "Error: MODEL is not set. Usage: make download-mlx-model MODEL=base"; \ + exit 1; \ + fi + @echo "Downloading mlx model $(MODEL)..." + @cd $(MLX_MODEL_REPO_DIR) && \ + git lfs pull --include="openai_whisper-$(MODEL)/*" + @echo "MLX model $(MODEL) downloaded to $(MLX_MODEL_REPO_DIR)/openai_whisper-mlx-$(MODEL)" + build: @echo "Building WhisperKit..." @swift build -v From 5432a8fa1c56ac66430820d8e068e9685f42e929 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 13 Jul 2024 18:31:44 -0700 Subject: [PATCH 15/29] Fix HF auth --- .github/workflows/unit-tests.yml | 2 ++ Makefile | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index c7f6f87..a8a9585 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -14,6 +14,8 @@ jobs: unit-tests: name: "${{ matrix.run-config['name'] }} on ${{ inputs.macos-runner }}" runs-on: ${{ inputs.macos-runner }} + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} strategy: matrix: run-config: diff --git a/Makefile b/Makefile index 2190c7b..c352607 100644 --- a/Makefile +++ b/Makefile @@ -34,6 +34,7 @@ setup: setup-huggingface-cli: @if huggingface-cli whoami; then \ echo "Already logged in to Hugging Face."; \ + huggingface-cli whoami \ else \ echo "Not logged in to Hugging Face."; \ if [ -z "$$HF_TOKEN" ]; then \ @@ -84,7 +85,7 @@ download-models: setup-model-repo # Download a specific model download-model: setup-model-repo @if [ -z "$(MODEL)" ]; then \ - echo "Error: MODEL is not set. Usage: make download-model MODEL=base"; \ + echo "Error: MODEL is not set. Usage: make download-model MODEL=tiny"; \ exit 1; \ fi @echo "Downloading model $(MODEL)..." @@ -102,7 +103,7 @@ download-mlx-models: setup-mlx-model-repo download-mlx-model: setup-mlx-model-repo @if [ -z "$(MODEL)" ]; then \ - echo "Error: MODEL is not set. Usage: make download-mlx-model MODEL=base"; \ + echo "Error: MODEL is not set. Usage: make download-mlx-model MODEL=tiny"; \ exit 1; \ fi @echo "Downloading mlx model $(MODEL)..." From 48cf8ffa626f9eabec8a2e2fd3df88bd39b7007d Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 13 Jul 2024 18:38:14 -0700 Subject: [PATCH 16/29] Fix HF login script --- Makefile | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index c352607..309b74a 100644 --- a/Makefile +++ b/Makefile @@ -32,10 +32,7 @@ setup: setup-huggingface-cli: - @if huggingface-cli whoami; then \ - echo "Already logged in to Hugging Face."; \ - huggingface-cli whoami \ - else \ + @if huggingface-cli whoami 2>&1 | grep -q "Not logged in"; then \ echo "Not logged in to Hugging Face."; \ if [ -z "$$HF_TOKEN" ]; then \ echo "Environment variable HF_TOKEN is not set. Running normal login."; \ @@ -44,6 +41,9 @@ setup-huggingface-cli: echo "Using HF_TOKEN from environment variable."; \ huggingface-cli login --token $$HF_TOKEN; \ fi; \ + else \ + echo "Already logged in to Hugging Face."; \ + huggingface-cli whoami; \ fi From b274e2fb0f036aa663a5e2d8973683cd86679f9e Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 13 Jul 2024 18:46:53 -0700 Subject: [PATCH 17/29] Include hf token in download step --- .github/workflows/unit-tests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index a8a9585..adc75cc 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -71,6 +71,8 @@ jobs: key: ${{ runner.os }}-models - name: Download Models if: steps.model-cache.outputs.cache-hit != 'true' + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | make setup-huggingface-cli make download-model MODEL=tiny From 4b3595280b39c52e532c063f70d185eb0b7f3bcb Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 13 Jul 2024 18:55:08 -0700 Subject: [PATCH 18/29] Remove hf login in favor up update model repo permissions --- .github/workflows/unit-tests.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index adc75cc..73e4132 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -14,8 +14,6 @@ jobs: unit-tests: name: "${{ matrix.run-config['name'] }} on ${{ inputs.macos-runner }}" runs-on: ${{ inputs.macos-runner }} - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} strategy: matrix: run-config: @@ -71,10 +69,7 @@ jobs: key: ${{ runner.os }}-models - name: Download Models if: steps.model-cache.outputs.cache-hit != 'true' - env: - HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | - make setup-huggingface-cli make download-model MODEL=tiny make download-mlx-model MODEL=tiny - name: Install and discover destinations From a139839d9b3427256f3428ccd750ceaf80dbcf39 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 13 Jul 2024 19:01:03 -0700 Subject: [PATCH 19/29] Use scheme from run config --- .github/workflows/unit-tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 73e4132..255f357 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -96,5 +96,5 @@ jobs: if: ${{ matrix.run-config['condition'] == true }} run: | set -o pipefail - xcodebuild clean build-for-testing -scheme whisperkit-Package -destination '${{ matrix.run-config['clean-destination'] }}' -skipPackagePluginValidation | xcpretty - xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme whisperkit-Package -destination '${{ matrix.run-config['test-destination'] }}' -skipPackagePluginValidation + xcodebuild clean build-for-testing -scheme ${{ matrix.run-config['scheme'] }} -destination '${{ matrix.run-config['clean-destination'] }}' -skipPackagePluginValidation | xcpretty + xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme ${{ matrix.run-config['scheme'] }} -destination '${{ matrix.run-config['test-destination'] }}' -skipPackagePluginValidation From d7cf7a65c160ae0639c0bc58f8e1d69d22f2ee50 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Mon, 15 Jul 2024 08:24:55 -0700 Subject: [PATCH 20/29] Use fixed mlx-swift version --- Package.resolved | 4 ++-- Package.swift | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Package.resolved b/Package.resolved index d6aa442..7e2e0e7 100644 --- a/Package.resolved +++ b/Package.resolved @@ -5,8 +5,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "branch" : "main", - "revision" : "3c802c808d281c191d5f26f37a4f93135d8ca119" + "revision" : "084597e7ec38d97a92e50855b04de005c7dad1df", + "version" : "0.15.2" } }, { diff --git a/Package.swift b/Package.swift index 09b33f8..1b90f72 100644 --- a/Package.swift +++ b/Package.swift @@ -58,7 +58,7 @@ func mlxDependencies() -> [PackageDescription.Package.Dependency] { return [] } else { return [ - .package(url: "https://github.com/ml-explore/mlx-swift", branch: "main"), + .package(url: "https://github.com/ml-explore/mlx-swift", exact: "0.15.2"), ] } } From 2ea84262df68ab3d8cfa509e0ef3b3a9db311f28 Mon Sep 17 00:00:00 2001 From: Jan Krukowski Date: Sun, 11 Aug 2024 00:41:48 +0200 Subject: [PATCH 21/29] MLX Cleanup (#187) * update mlx-swift * - reverted mlx version - updated readme - updated makefile * reversed * fixed tests * updated mlx-swift * updated makefile * remove device change, fft can run on gpu now * updated readme, added tests * updated readme * review changes * review changes * CI model cache path fix * tests failed * Update package.swift * Keep setupModels with adjustments * Test CI skip cache * Test CI package name change * Test CI optional CLI settings * Use correct logits size for MLX --------- Co-authored-by: ZachNagengast --- .github/workflows/unit-tests.yml | 13 +- .gitignore | 1 + .swiftpm/configuration/Package.resolved | 41 ++++ .../xcshareddata/swiftpm/Package.resolved | 6 +- Makefile | 28 ++- Package.resolved | 4 +- Package.swift | 182 +++++++++--------- README.md | 84 ++++++-- Sources/WhisperKit/Core/TextDecoder.swift | 2 + Sources/WhisperKit/Core/Utils.swift | 2 +- Sources/WhisperKit/Core/WhisperKit.swift | 110 +++++++---- .../WhisperKit/MLX/MLXFeatureExtractor.swift | 8 +- Sources/WhisperKit/MLX/MLXTextDecoder.swift | 2 +- Sources/WhisperKit/MLX/MLXUtils.swift | 16 +- Sources/WhisperKitCLI/CLIArguments.swift | 23 +++ Sources/WhisperKitCLI/TranscribeCLI.swift | 46 +++++ Sources/WhisperKitTestsUtils/TestUtils.swift | 10 +- Tests/WhisperKitMLXTests/MLXUnitTests.swift | 42 ++-- 18 files changed, 441 insertions(+), 179 deletions(-) create mode 100644 .swiftpm/configuration/Package.resolved diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 255f357..b7951fd 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -20,7 +20,6 @@ jobs: - { name: "macOS", condition: true, - clean-destination: "generic/platform=macOS", test-destination: "platform=macOS,arch=arm64", test-cases: "-only-testing WhisperKitTests/UnitTests -only-testing WhisperKitMLXTests/MLXUnitTests", mlx-disabled: "0", @@ -29,7 +28,6 @@ jobs: - { name: "iOS", condition: true, - clean-destination: "generic/platform=iOS", test-destination: "platform=iOS Simulator,OS=${{ inputs.ios-version }},name=iPhone 15", test-cases: "-only-testing WhisperKitTests/UnitTests", mlx-disabled: "1", @@ -38,7 +36,6 @@ jobs: - { name: "watchOS", condition: "${{ inputs.macos-runner == 'macos-14' }}", - clean-destination: "generic/platform=watchOS", test-destination: "platform=watchOS Simulator,OS=10.5,name=Apple Watch Ultra 2 (49mm)", test-cases: "-only-testing WhisperKitTests/UnitTests", mlx-disabled: "1", @@ -47,7 +44,6 @@ jobs: - { name: "visionOS", condition: "${{ inputs.macos-runner == 'macos-14' }}", - clean-destination: "generic/platform=visionOS", test-destination: "platform=visionOS Simulator,name=Apple Vision Pro", test-cases: "-only-testing WhisperKitTests/UnitTests", mlx-disabled: "1", @@ -65,7 +61,7 @@ jobs: id: model-cache uses: actions/cache@v4 with: - path: Sources/WhisperKitTestsUtils/Models + path: Models key: ${{ runner.os }}-models - name: Download Models if: steps.model-cache.outputs.cache-hit != 'true' @@ -96,5 +92,8 @@ jobs: if: ${{ matrix.run-config['condition'] == true }} run: | set -o pipefail - xcodebuild clean build-for-testing -scheme ${{ matrix.run-config['scheme'] }} -destination '${{ matrix.run-config['clean-destination'] }}' -skipPackagePluginValidation | xcpretty - xcodebuild test -only-testing WhisperKitTests/UnitTests -scheme ${{ matrix.run-config['scheme'] }} -destination '${{ matrix.run-config['test-destination'] }}' -skipPackagePluginValidation + xcodebuild clean build-for-testing test \ + ${{ matrix.run-config['test-cases'] }} \ + -scheme ${{ matrix.run-config['scheme'] }} \ + -destination '${{ matrix.run-config['test-destination'] }}' \ + -skipPackagePluginValidation diff --git a/.gitignore b/.gitignore index fd725dc..bb8893b 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ DerivedData/ **/*.xcscheme .netrc .env +/.vscode # Core ML Model Files Models diff --git a/.swiftpm/configuration/Package.resolved b/.swiftpm/configuration/Package.resolved new file mode 100644 index 0000000..bb2ef99 --- /dev/null +++ b/.swiftpm/configuration/Package.resolved @@ -0,0 +1,41 @@ +{ + "pins" : [ + { + "identity" : "mlx-swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/ml-explore/mlx-swift", + "state" : { + "revision" : "597aaa5f465b4b9a17c8646b751053f84e37925b", + "version" : "0.16.0" + } + }, + { + "identity" : "swift-argument-parser", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-argument-parser.git", + "state" : { + "revision" : "c8ed701b513cf5177118a175d85fbbbcd707ab41", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-numerics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-numerics", + "state" : { + "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", + "version" : "1.0.2" + } + }, + { + "identity" : "swift-transformers", + "kind" : "remoteSourceControl", + "location" : "https://github.com/huggingface/swift-transformers.git", + "state" : { + "revision" : "74b94211bdc741694ed7e700a1104c72e5ba68fe", + "version" : "0.1.7" + } + } + ], + "version" : 2 +} diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 6d2ac24..7ae7da8 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,13 +1,13 @@ { - "originHash" : "cd17206b47bb810af9459722192530e3838d8e6629a970988e32a432aaa05f6e", + "originHash" : "829222b514832cb61fe0002e0eebda98f23a75169c63f7d6ed7a320d57d5318f", "pins" : [ { "identity" : "mlx-swift", "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "branch" : "main", - "revision" : "c11212bff42a1b88aea83811210d42a5f99440ad" + "revision" : "597aaa5f465b4b9a17c8646b751053f84e37925b", + "version" : "0.16.0" } }, { diff --git a/Makefile b/Makefile index 309b74a..a1314dd 100644 --- a/Makefile +++ b/Makefile @@ -60,6 +60,7 @@ setup-model-repo: git clone https://huggingface.co/$(MODEL_REPO) $(MODEL_REPO_DIR); \ fi + setup-mlx-model-repo: @echo "Setting up mlx repository..." @mkdir -p $(BASE_MODEL_DIR) @@ -109,21 +110,40 @@ download-mlx-model: setup-mlx-model-repo @echo "Downloading mlx model $(MODEL)..." @cd $(MLX_MODEL_REPO_DIR) && \ git lfs pull --include="openai_whisper-$(MODEL)/*" - @echo "MLX model $(MODEL) downloaded to $(MLX_MODEL_REPO_DIR)/openai_whisper-mlx-$(MODEL)" + @echo "MLX model $(MODEL) downloaded to $(MLX_MODEL_REPO_DIR)/openai_whisper-$(MODEL)" + build: @echo "Building WhisperKit..." - @swift build -v + @xcodebuild CLANG_ENABLE_CODE_COVERAGE=NO VALID_ARCHS=arm64 clean build \ + -configuration Release \ + -scheme whisperkit-Package \ + -destination generic/platform=macOS \ + -derivedDataPath .build/.xcodebuild/ \ + -clonedSourcePackagesDirPath .build/ \ + -skipPackagePluginValidation build-cli: @echo "Building WhisperKit CLI..." - @swift build -c release --product whisperkit-cli + @xcodebuild CLANG_ENABLE_CODE_COVERAGE=NO VALID_ARCHS=arm64 clean build \ + -configuration Release \ + -scheme whisperkit-cli \ + -destination generic/platform=macOS \ + -derivedDataPath .build/.xcodebuild/ \ + -clonedSourcePackagesDirPath .build/ \ + -skipPackagePluginValidation test: @echo "Running tests..." - @swift test -v + @xcodebuild clean build-for-testing test \ + -scheme whisperkit-Package \ + -only-testing WhisperKitMLXTests/MLXUnitTests \ + -only-testing WhisperKitTests/UnitTests \ + -destination 'platform=macOS,arch=arm64' \ + -skipPackagePluginValidation + clean-package-caches: @trash ~/Library/Caches/org.swift.swiftpm/repositories diff --git a/Package.resolved b/Package.resolved index d6aa442..bb2ef99 100644 --- a/Package.resolved +++ b/Package.resolved @@ -5,8 +5,8 @@ "kind" : "remoteSourceControl", "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "branch" : "main", - "revision" : "3c802c808d281c191d5f26f37a4f93135d8ca119" + "revision" : "597aaa5f465b4b9a17c8646b751053f84e37925b", + "version" : "0.16.0" } }, { diff --git a/Package.swift b/Package.swift index 09b33f8..824bb0f 100644 --- a/Package.swift +++ b/Package.swift @@ -4,67 +4,26 @@ import PackageDescription import Foundation -// NOTE: `MLX` doesn't support `watchOS` yet, that's why we control the build using the `MLX_DISABLED` environment variable. -// To manualy build for `watchOS` use: -// `export MLX_DISABLED=1 && xcodebuild clean build-for-testing -scheme whisperkit -sdk watchos10.4 -destination 'platform=watchOS Simulator' -skipPackagePluginValidation` let package = Package( name: "whisperkit", platforms: [ .iOS(.v16), - .macOS("13.3") + .macOS("13.3"), + .watchOS(.v10) ], - products: products() + mlxProducts(), - dependencies: dependencies() + mlxDependencies(), - targets: targets() + mlxTargets() -) - -func products() -> [PackageDescription.Product] { - return [ + products: [ .library( name: "WhisperKit", targets: ["WhisperKit"] - ) - ] -} - -func mlxProducts() -> [PackageDescription.Product] { - let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1" - if isMLXDisabled { - return [] - } else { - return [ - .library( - name: "WhisperKitMLX", - targets: ["WhisperKitMLX"] - ), - .executable( - name: "whisperkit-cli", - targets: ["WhisperKitCLI"] - ), - ] - } -} - -func dependencies() -> [PackageDescription.Package.Dependency] { - return [ + ), + ] + + cliProducts() + + mlxProducts(), + dependencies: [ .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"), .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), - ] -} - -func mlxDependencies() -> [PackageDescription.Package.Dependency] { - let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1" - if isMLXDisabled { - return [] - } else { - return [ - .package(url: "https://github.com/ml-explore/mlx-swift", branch: "main"), - ] - } -} - -func targets() -> [PackageDescription.Target] { - return [ + ] + mlxDependencies(), + targets: [ .target( name: "WhisperKit", dependencies: [ @@ -103,46 +62,89 @@ func targets() -> [PackageDescription.Target] { .product(name: "Transformers", package: "swift-transformers"), ] ) + ] + + cliTargets() + + mlxTargets() +) + +// MARK: - MLX Helper Functions + +// CLI +func cliProducts() -> [Product] { + guard !isMLXDisabled() else { return [] } + return [ + .executable( + name: "whisperkit-cli", + targets: ["WhisperKitCLI"] + ), + ] +} + +func cliTargets() -> [Target] { + guard !isMLXDisabled() else { return [] } + return [ + .executableTarget( + name: "WhisperKitCLI", + dependencies: [ + "WhisperKit", + "WhisperKitMLX", + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ] + ), ] } -func mlxTargets() -> [PackageDescription.Target] { - let isMLXDisabled = ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1" - if isMLXDisabled { - return [] - } else { - return [ - .executableTarget( - name: "WhisperKitCLI", - dependencies: [ - "WhisperKit", - "WhisperKitMLX", - .product(name: "ArgumentParser", package: "swift-argument-parser"), - ] - ), - .target( - name: "WhisperKitMLX", - dependencies: [ - "WhisperKit", - .product(name: "MLX", package: "mlx-swift"), - .product(name: "MLXFFT", package: "mlx-swift"), - .product(name: "MLXNN", package: "mlx-swift") - ], - path: "Sources/WhisperKit/MLX", - resources: [ - .copy("Resources/mel_filters_80.npy"), - .copy("Resources/mel_filters_128.npy") - ] - ), - .testTarget( - name: "WhisperKitMLXTests", - dependencies: [ - "WhisperKit", - "WhisperKitMLX", - "WhisperKitTestsUtils", - .product(name: "Transformers", package: "swift-transformers"), - ] - ) - ] - } +// MLX +func mlxProducts() -> [Product] { + guard !isMLXDisabled() else { return [] } + return [ + .library( + name: "WhisperKitMLX", + targets: ["WhisperKitMLX"] + ), + ] +} + +func mlxDependencies() -> [Package.Dependency] { + guard !isMLXDisabled() else { return [] } + return [ + .package(url: "https://github.com/ml-explore/mlx-swift", exact: "0.16.0"), + ] +} + +func mlxTargets() -> [Target] { + guard !isMLXDisabled() else { return [] } + return [ + .target( + name: "WhisperKitMLX", + dependencies: [ + "WhisperKit", + .product(name: "MLX", package: "mlx-swift"), + .product(name: "MLXFFT", package: "mlx-swift"), + .product(name: "MLXNN", package: "mlx-swift") + ], + path: "Sources/WhisperKit/MLX", + resources: [ + .copy("Resources/mel_filters_80.npy"), + .copy("Resources/mel_filters_128.npy") + ] + ), + .testTarget( + name: "WhisperKitMLXTests", + dependencies: [ + "WhisperKit", + "WhisperKitMLX", + "WhisperKitTestsUtils", + .product(name: "Transformers", package: "swift-transformers"), + ] + ) + ] +} + +// NOTE: `MLX` doesn't support `watchOS` yet, that's why we control the build using the `MLX_DISABLED` environment variable. +// To manualy build for `watchOS` use: +// `export MLX_DISABLED=1 && xcodebuild clean build-for-testing -scheme whisperkit -sdk watchos10.4 -destination 'platform=watchOS Simulator,OS=10.5,name=Apple Watch Ultra 2 (49mm)' -skipPackagePluginValidation` + +func isMLXDisabled() -> Bool { + ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1" } diff --git a/README.md b/README.md index a4d5f5c..242087c 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,8 @@ Check out the demo app on [TestFlight](https://testflight.apple.com/join/LPVOyJZ - [Model Selection](#model-selection) - [Generating Models](#generating-models) - [Swift CLI](#swift-cli) + - [Backend Selection](#backend-selection) + - [Testing](#testing) - [Contributing \& Roadmap](#contributing--roadmap) - [License](#license) - [Citation](#citation) @@ -66,7 +68,7 @@ You can install `WhisperKit` command line app using [Homebrew](https://brew.sh) ```bash brew install whisperkit-cli -``` +``` ## Getting Started @@ -79,38 +81,51 @@ This example demonstrates how to transcribe a local audio file: ```swift import WhisperKit -// Initialize WhisperKit with default settings -Task { - let pipe = try? await WhisperKit() - let transcription = try? await pipe!.transcribe(audioPath: "path/to/your/audio.{wav,mp3,m4a,flac}")?.text - print(transcription) -} +// Initialize WhisperKit by passing the model name (WhisperKit will automatically download it): +let pipe = try await WhisperKit(model: "tiny") +// Transcribe the audio file +let transcription = try await pipe.transcribe(audioPath: "path/to/your/audio.{wav,mp3,m4a,flac}")?.text +// Print the transcription +print(transcription) ``` ### Model Selection -WhisperKit automatically downloads the recommended model for the device if not specified. You can also select a specific model by passing in the model name: +You have to specify the model by passing the model name: ```swift -let pipe = try? await WhisperKit(model: "large-v3") +let pipe = try await WhisperKit(model: "large-v3") ``` This method also supports glob search, so you can use wildcards to select a model: ```swift -let pipe = try? await WhisperKit(model: "distil*large-v3") +let pipe = try await WhisperKit(model: "distil*large-v3") ``` Note that the model search must return a single model from the source repo, otherwise an error will be thrown. For a list of available models, see our [HuggingFace repo](https://huggingface.co/argmaxinc/whisperkit-coreml). +For MLX models, see [here](https://huggingface.co/argmaxinc/whisperkit-mlx). + +If you want to get the recommended model for your device, you can use the following method: + +```swift +print(WhisperKit.recommendedModels()) +``` + +it should print the default and a list of disabled models, e.g.: + +```bash +(default: "openai_whisper-base", disabled: ["openai_whisper-large-v2_turbo", "openai_whisper-large-v2_turbo_955MB", "openai_whisper-large-v3_turbo", "openai_whisper-large-v3_turbo_954MB", "distil-whisper_distil-large-v3_turbo_600MB", "distil-whisper_distil-large-v3_turbo"]) +``` ### Generating Models WhisperKit also comes with the supporting repo [`whisperkittools`](https://github.com/argmaxinc/whisperkittools) which lets you create and deploy your own fine tuned versions of Whisper in CoreML format to HuggingFace. Once generated, they can be loaded by simply changing the repo name to the one used to upload the model: ```swift -let pipe = try? await WhisperKit(model: "large-v3", modelRepo: "username/your-model-repo") +let pipe = try await WhisperKit(model: "large-v3", modelRepo: "username/your-model-repo") ``` ### Swift CLI @@ -152,6 +167,53 @@ Which should print a transcription of the audio file. If you would like to strea swift run whisperkit-cli transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --stream ``` +### Backend Selection + +WhisperKit supports both CoreML and MLX backends. By default, it uses CoreML, but you can switch some or all pipeline components to MLX. +Available pipeline components are: +- `featureExtractor`, `FeatureExtractor` is used by default, use `MLXFeatureExtractor` to switch to MLX +- `audioEncoder`, `AudioEncoder` is used by default, use `MLXAudioEncoder` to switch to MLX +- `textDecoder`, `TextDecoder` is used by default, use `MLXTextDecoder` to switch to MLX + +Here is an example of how to switch the `featureExtractor` and `audioEncoder` to MLX and keep the `textDecoder` as CoreML: + +```swift +let pipe = try await WhisperKit( + model: "tiny", + mlxModel: "tiny", + featureExtractor: MLXFeatureExtractor(), + audioEncoder: MLXAudioEncoder() +) +``` + +**Note**: + +`swift run` and `swift test` commands won't work when the `mlx` backend is selected. +SwiftPM (command line) cannot build the Metal shaders so the ultimate build has to be done via Xcode. + +### Testing + +If you want to run the unit tests locally, first clone the repo: + +```bash +git clone https://github.com/argmaxinc/whisperkit.git +cd whisperkit +``` + +download the required models: + +```bash +make setup +make download-model MODEL=tiny +make download-mlx-model MODEL=tiny +``` + +and then run the tests: + +```bash +make test +``` + ## Contributing & Roadmap Our goal is to make WhisperKit better and better over time and we'd love your help! Just search the code for "TODO" for a variety of features that are yet to be built. Please refer to our [contribution guidelines](CONTRIBUTING.md) for submitting issues, pull requests, and coding standards, where we also have a public roadmap of features we are looking forward to building in the future. diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index f2bd27d..250f895 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -482,6 +482,8 @@ open class TextDecoder: TextDecoding, WhisperMLModel { return getModelInputDimention(model, named: "encoder_output_embeds", position: 1) } + public init() {} + /// Override default so we an unload the prefill data as well public func unloadModel() { model = nil diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index f4c7633..d8d759d 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -44,7 +44,7 @@ extension MLMultiArray { /// - index: The index of the element /// - strides: The precomputed strides of the multi-array, if not provided, it will be computed. It's a performance optimization to avoid recomputing the strides every time when accessing the multi-array with multiple indexes. @inline(__always) - func linearOffset(for index: [NSNumber], strides strideInts: [Int]? = nil) -> Int { + public func linearOffset(for index: [NSNumber], strides strideInts: [Int]? = nil) -> Int { var linearOffset = 0 let strideInts = strideInts ?? strides.map { $0.intValue } for (dimension, stride) in zip(index, strideInts) { diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 0841c02..a9bb407 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -38,14 +38,18 @@ open class WhisperKit { /// Configuration public var modelFolder: URL? + public var mlxModelFolder: URL? public var tokenizerFolder: URL? public let useBackgroundDownloadSession: Bool public init( model: String? = nil, + mlxModel: String? = nil, downloadBase: URL? = nil, - modelRepo: String? = nil, + modelRepo: String = "argmaxinc/whisperkit-coreml", modelFolder: String? = nil, + mlxModelRepo: String = "argmaxinc/whisperkit-mlx", + mlxModelFolder: String? = nil, tokenizerFolder: URL? = nil, computeOptions: ModelComputeOptions? = nil, audioProcessor: (any AudioProcessing)? = nil, @@ -75,9 +79,12 @@ open class WhisperKit { try await setupModels( model: model, + mlxModel: mlxModel, downloadBase: downloadBase, modelRepo: modelRepo, modelFolder: modelFolder, + mlxModelRepo: mlxModelRepo, + mlxModelFolder: mlxModelFolder, download: download ) @@ -214,30 +221,56 @@ open class WhisperKit { /// Sets up the model folder either from a local path or by downloading from a repository. public func setupModels( model: String?, + mlxModel: String? = nil, downloadBase: URL? = nil, - modelRepo: String?, - modelFolder: String?, + modelRepo: String = "argmaxinc/whisperkit-coreml", + modelFolder: String? = nil, + mlxModelRepo: String = "argmaxinc/whisperkit-mlx", + mlxModelFolder: String? = nil, download: Bool ) async throws { - // Determine the model variant to use - let modelVariant = model ?? WhisperKit.recommendedModels().default + // If no model is provided, use the recommended model + var modelVariant = model + if model == nil, mlxModel == nil, mlxModelFolder == nil { + // Determine the model variant to use by default + modelVariant = WhisperKit.recommendedModels().default + } // If a local model folder is provided, use it; otherwise, download the model - if let folder = modelFolder { - self.modelFolder = URL(fileURLWithPath: folder) - } else if download { - let repo = modelRepo ?? "argmaxinc/whisperkit-coreml" + if let modelFolder { + self.modelFolder = URL(fileURLWithPath: modelFolder) + } else if download, let modelVariant { do { self.modelFolder = try await Self.download( variant: modelVariant, downloadBase: downloadBase, useBackgroundSession: useBackgroundDownloadSession, - from: repo + from: modelRepo + ) + } catch { + // Handle errors related to model downloading + throw WhisperError.modelsUnavailable(""" + CoreML Model not found. Please check the model or repo name and try again. + Error: \(error) + """) + } + } + + // Same for MLX + if let mlxModelFolder { + self.mlxModelFolder = URL(fileURLWithPath: mlxModelFolder) + } else if download, let mlxModel { + do { + self.mlxModelFolder = try await Self.download( + variant: mlxModel, + downloadBase: downloadBase, + useBackgroundSession: useBackgroundDownloadSession, + from: mlxModelRepo ) } catch { // Handle errors related to model downloading throw WhisperError.modelsUnavailable(""" - Model not found. Please check the model or repo name and try again. + MLX Model not found. Please check the model or repo name and try again. Error: \(error) """) } @@ -251,40 +284,37 @@ open class WhisperKit { public func loadModels( prewarmMode: Bool = false ) async throws { - modelState = prewarmMode ? .prewarming : .loading + assert(modelFolder != nil || mlxModelFolder != nil, "Please specify `modelFolder` or `mlxModelFolder`") + modelState = prewarmMode ? .prewarming : .loading let modelLoadStart = CFAbsoluteTimeGetCurrent() - guard let path = modelFolder else { - throw WhisperError.modelsUnavailable("Model folder is not set.") - } + Logging.debug("Loading models with prewarmMode: \(prewarmMode)") - Logging.debug("Loading models from \(path.path) with prewarmMode: \(prewarmMode)") - - if let featureExtractor = featureExtractor as? WhisperMLModel { - Logging.debug("Loading feature extractor") + if let path = modelFolder, let featureExtractor = featureExtractor as? WhisperMLModel { + Logging.debug("Loading feature extractor from \(path.path)") try await featureExtractor.loadModel( at: path.appending(path: "MelSpectrogram.mlmodelc"), computeUnits: modelCompute.melCompute, // hardcoded to use GPU prewarmMode: prewarmMode ) Logging.debug("Loaded feature extractor") - } else if let featureExtractor = featureExtractor as? WhisperMLXModel { - Logging.debug("Loading MLX feature extractor") + } else if let path = mlxModelFolder, let featureExtractor = featureExtractor as? WhisperMLXModel { + Logging.debug("Loading MLX feature extractor from \(path.path)") try await featureExtractor.loadModel(at: path, configPath: path) Logging.debug("Loaded MLX feature extractor") } - if let audioEncoder = audioEncoder as? WhisperMLModel { - Logging.debug("Loading audio encoder") + if let path = modelFolder, let audioEncoder = audioEncoder as? WhisperMLModel { + Logging.debug("Loading audio encoder from \(path.path)") try await audioEncoder.loadModel( at: path.appending(path: "AudioEncoder.mlmodelc"), computeUnits: modelCompute.audioEncoderCompute, prewarmMode: prewarmMode ) Logging.debug("Loaded audio encoder") - } else if let audioEncoder = audioEncoder as? WhisperMLXModel { - Logging.debug("Loading MLX audio encoder") + } else if let path = mlxModelFolder, let audioEncoder = audioEncoder as? WhisperMLXModel { + Logging.debug("Loading MLX audio encoder from \(path.path)") try await audioEncoder.loadModel( at: path.appending(path: "encoder.safetensors"), configPath: path.appending(path: "config.json") @@ -292,16 +322,16 @@ open class WhisperKit { Logging.debug("Loaded MLX audio encoder") } - if let textDecoder = textDecoder as? WhisperMLModel { - Logging.debug("Loading text decoder") + if let path = modelFolder, let textDecoder = textDecoder as? WhisperMLModel { + Logging.debug("Loading text decoder from \(path.path)") try await textDecoder.loadModel( at: path.appending(path: "TextDecoder.mlmodelc"), computeUnits: modelCompute.textDecoderCompute, prewarmMode: prewarmMode ) Logging.debug("Loaded text decoder") - } else if let textDecoder = textDecoder as? WhisperMLXModel { - Logging.debug("Loading MLX text decoder") + } else if let path = mlxModelFolder, let textDecoder = textDecoder as? WhisperMLXModel { + Logging.debug("Loading MLX text decoder from \(path.path)") try await textDecoder.loadModel( at: path.appending(path: "decoder.safetensors"), configPath: path.appending(path: "config.json") @@ -309,16 +339,18 @@ open class WhisperKit { Logging.debug("Loaded MLX text decoder") } - let decoderPrefillUrl = path.appending(path: "TextDecoderContextPrefill.mlmodelc") - if FileManager.default.fileExists(atPath: decoderPrefillUrl.path) { - Logging.debug("Loading text decoder prefill data") - textDecoder.prefillData = TextDecoderContextPrefill() - try await textDecoder.prefillData?.loadModel( - at: decoderPrefillUrl, - computeUnits: modelCompute.prefillCompute, - prewarmMode: prewarmMode - ) - Logging.debug("Loaded text decoder prefill data") + if let path = modelFolder { + let decoderPrefillUrl = path.appending(path: "TextDecoderContextPrefill.mlmodelc") + if FileManager.default.fileExists(atPath: decoderPrefillUrl.path) { + Logging.debug("Loading text decoder prefill data") + textDecoder.prefillData = TextDecoderContextPrefill() + try await textDecoder.prefillData?.loadModel( + at: decoderPrefillUrl, + computeUnits: modelCompute.prefillCompute, + prewarmMode: prewarmMode + ) + Logging.debug("Loaded text decoder prefill data") + } } if prewarmMode { diff --git a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift index 29aecc9..935bbf0 100644 --- a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift +++ b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift @@ -39,6 +39,11 @@ open class MLXFeatureExtractor: FeatureExtracting { } } +extension MLXFeatureExtractor: WhisperMLXModel { + public func loadModel(at modelPath: URL, configPath: URL) async throws {} + public func unloadModel() {} +} + public extension MLXFeatureExtractor { /// Return the Hanning window. /// Taken from [numpy](https://numpy.org/doc/stable/reference/generated/numpy.hanning.html) implementation @@ -103,9 +108,6 @@ public extension MLXFeatureExtractor { nFFT: Int = 400, hopLength: Int = 160 ) -> MLXArray { - let device = MLX.Device.defaultDevice() - MLX.Device.setDefault(device: .cpu) - defer { MLX.Device.setDefault(device: device) } let window = hanning(nFFT) let freqs = stft(audio, window: window, nPerSeg: nFFT, nOverlap: hopLength) let magnitudes = freqs[..<(-1)].abs().square() diff --git a/Sources/WhisperKit/MLX/MLXTextDecoder.swift b/Sources/WhisperKit/MLX/MLXTextDecoder.swift index 2661c40..d98b9a1 100644 --- a/Sources/WhisperKit/MLX/MLXTextDecoder.swift +++ b/Sources/WhisperKit/MLX/MLXTextDecoder.swift @@ -13,7 +13,7 @@ public final class MLXTextDecoder: TextDecoding { public var isModelMultilingual: Bool = false public let supportsWordTimestamps: Bool = false public var logitsSize: Int? { - decoder?.nState + decoder?.nVocab } public var kvCacheEmbedDim: Int? { diff --git a/Sources/WhisperKit/MLX/MLXUtils.swift b/Sources/WhisperKit/MLX/MLXUtils.swift index 5059ac1..52a33db 100644 --- a/Sources/WhisperKit/MLX/MLXUtils.swift +++ b/Sources/WhisperKit/MLX/MLXUtils.swift @@ -35,6 +35,19 @@ extension MLXArray { } } +extension MLXArray { + var contiguousStrides: [Int] { + var contiguousStrides = [1] + var stride = 1 + for dimension in shape.dropFirst().reversed() { + stride = stride * dimension + contiguousStrides.append(stride) + } + contiguousStrides.reverse() + return contiguousStrides + } +} + extension MLXArray { func asMLMultiArray() throws -> MLMultiArray { let dataType = multiArrayDataType() @@ -45,11 +58,12 @@ extension MLXArray { let destination = UnsafeMutableRawBufferPointer(start: buffer, count: nbytes) ptr.copyBytes(to: destination) } + // `contiguousStrides` has to used, see the [discussion](https://github.com/ml-explore/mlx-swift/issues/117) return try MLMultiArray( dataPointer: buffer, shape: shape.map { NSNumber(value: $0) }, dataType: dataType, - strides: strides.map { NSNumber(value: $0) }, + strides: contiguousStrides.map { NSNumber(value: $0) }, deallocator: { $0.deallocate() } ) } diff --git a/Sources/WhisperKitCLI/CLIArguments.swift b/Sources/WhisperKitCLI/CLIArguments.swift index b76439b..1ca5e0e 100644 --- a/Sources/WhisperKitCLI/CLIArguments.swift +++ b/Sources/WhisperKitCLI/CLIArguments.swift @@ -3,6 +3,11 @@ import ArgumentParser +enum ModelType: String, Decodable, ExpressibleByArgument { + case coreML = "coreml" + case mlx = "mlx" +} + struct CLIArguments: ParsableArguments { @Option(help: "Paths to audio files") var audioPath = [String]() @@ -16,15 +21,33 @@ struct CLIArguments: ParsableArguments { @Option(help: "Model to download if no modelPath is provided") var model: String? + @Option(help: "Path of MLX model files") + var mlxModelPath: String? + + @Option(help: "MLX Model to download if no mlxModelPath is provided") + var mlxModel: String? + @Option(help: "Text to add in front of the model name to specify between different types of the same variant (values: \"openai\", \"distil\")") var modelPrefix: String = "openai" + @Option(help: "Text to add in front of the mlx model name to specify between different types of the same variant (values: \"openai\")") + var mlxModelPrefix: String = "openai" + @Option(help: "Path to save the downloaded model") var downloadModelPath: String? @Option(help: "Path to save the downloaded tokenizer files") var downloadTokenizerPath: String? + @Option(help: "Which feature extractor to use (supported: `coreml` and `mlx`)") + var featureExtractorType: ModelType = .coreML + + @Option(help: "Which audio encoder to use (supported: `coreml` and `mlx`)") + var audioEncoderType: ModelType = .coreML + + @Option(help: "Which text decoder to use (supported: `coreml` and `mlx`)") + var textDecoderType: ModelType = .coreML + @Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}") var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine diff --git a/Sources/WhisperKitCLI/TranscribeCLI.swift b/Sources/WhisperKitCLI/TranscribeCLI.swift index 172e423..cae3c33 100644 --- a/Sources/WhisperKitCLI/TranscribeCLI.swift +++ b/Sources/WhisperKitCLI/TranscribeCLI.swift @@ -305,12 +305,58 @@ struct TranscribeCLI: AsyncParsableCommand { nil } + let mlxModelName: String? = + if let modelVariant = cliArguments.mlxModel { + cliArguments.mlxModelPrefix + "*" + modelVariant + } else { + nil + } + + var featureExtractorType = cliArguments.featureExtractorType + var audioEncoderType = cliArguments.featureExtractorType + var textDecoderType = cliArguments.featureExtractorType + + if modelName == nil, mlxModelName != nil { + // CoreML model not provided, default to MLX + featureExtractorType = .mlx + audioEncoderType = .mlx + textDecoderType = .mlx + } + + let featureExtractor: FeatureExtracting = + switch featureExtractorType { + case .coreML: + FeatureExtractor() + case .mlx: + MLXFeatureExtractor() + } + + let audioEncoder: AudioEncoding = + switch audioEncoderType { + case .coreML: + AudioEncoder() + case .mlx: + MLXAudioEncoder() + } + + let textDecoder: TextDecoding = + switch textDecoderType { + case .coreML: + TextDecoder() + case .mlx: + MLXTextDecoder() + } + return try await WhisperKit( model: modelName, + mlxModel: mlxModelName, downloadBase: downloadModelFolder, modelFolder: cliArguments.modelPath, tokenizerFolder: downloadTokenizerFolder, computeOptions: computeOptions, + featureExtractor: featureExtractor, + audioEncoder: audioEncoder, + textDecoder: textDecoder, verbose: cliArguments.verbose, logLevel: .debug, load: true, diff --git a/Sources/WhisperKitTestsUtils/TestUtils.swift b/Sources/WhisperKitTestsUtils/TestUtils.swift index 7bfae95..4271313 100644 --- a/Sources/WhisperKitTestsUtils/TestUtils.swift +++ b/Sources/WhisperKitTestsUtils/TestUtils.swift @@ -1,7 +1,7 @@ import CoreML import Combine import Foundation -@testable import WhisperKit +import WhisperKit import XCTest public enum TestError: Error { @@ -133,7 +133,8 @@ public extension MLMultiArray { @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public extension XCTestCase { func transcribe( - modelPath: String, + modelPath: String? = nil, + mlxModelPath: String? = nil, options: DecodingOptions, callback: TranscriptionCallback = nil, audioFile: String = "jfk.wav", @@ -151,6 +152,7 @@ public extension XCTestCase { ) let whisperKit = try await WhisperKit( modelFolder: modelPath, + mlxModelFolder: mlxModelPath, computeOptions: computeOptions, featureExtractor: featureExtractor, audioEncoder: audioEncoder, @@ -170,7 +172,7 @@ public extension XCTestCase { func tinyModelPath() throws -> String { let modelDir = "whisperkit-coreml/openai_whisper-tiny" guard let modelPath = Bundle.module.urls(forResourcesWithExtension: "mlmodelc", subdirectory: modelDir)?.first?.deletingLastPathComponent().path else { - throw TestError.missingFile("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-models`") + throw TestError.missingFile("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-model MODEL=tiny`") } return modelPath } @@ -178,7 +180,7 @@ public extension XCTestCase { func tinyMLXModelPath() throws -> String { let modelDir = "whisperkit-mlx/openai_whisper-tiny" guard let modelPath = Bundle.module.urls(forResourcesWithExtension: "safetensors", subdirectory: modelDir)?.first?.deletingLastPathComponent().path else { - throw TestError.missingFile("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-mlx-models`") + throw TestError.missingFile("Failed to load model, ensure \"Models/\(modelDir)\" exists via Makefile command: `make download-mlx-model MODEL=tiny`") } return modelPath } diff --git a/Tests/WhisperKitMLXTests/MLXUnitTests.swift b/Tests/WhisperKitMLXTests/MLXUnitTests.swift index 2368b3b..c51a03e 100644 --- a/Tests/WhisperKitMLXTests/MLXUnitTests.swift +++ b/Tests/WhisperKitMLXTests/MLXUnitTests.swift @@ -154,7 +154,7 @@ final class MLXUnitTests: XCTestCase { let result = try await XCTUnwrapAsync( try await transcribe( - modelPath: tinyModelPath, + mlxModelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav", featureExtractor: MLXFeatureExtractor(), @@ -173,7 +173,7 @@ final class MLXUnitTests: XCTestCase { let result = try await XCTUnwrapAsync( try await transcribe( - modelPath: tinyModelPath, + mlxModelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav", featureExtractor: MLXFeatureExtractor(), @@ -189,7 +189,7 @@ final class MLXUnitTests: XCTestCase { func testDetectSpanish() async throws { let targetLanguage = "es" let whisperKit = try await WhisperKit( - modelFolder: tinyModelPath, + mlxModelFolder: tinyModelPath, featureExtractor: MLXFeatureExtractor(), audioEncoder: MLXAudioEncoder(), textDecoder: MLXTextDecoder(), @@ -215,7 +215,7 @@ final class MLXUnitTests: XCTestCase { let result = try await XCTUnwrapAsync( try await transcribe( - modelPath: tinyModelPath, + mlxModelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav", featureExtractor: MLXFeatureExtractor(), @@ -234,7 +234,7 @@ final class MLXUnitTests: XCTestCase { let result = try await XCTUnwrapAsync( try await transcribe( - modelPath: tinyModelPath, + mlxModelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav", featureExtractor: MLXFeatureExtractor(), @@ -250,7 +250,7 @@ final class MLXUnitTests: XCTestCase { func testDetectJapanese() async throws { let targetLanguage = "ja" let whisperKit = try await WhisperKit( - modelFolder: tinyModelPath, + mlxModelFolder: tinyModelPath, featureExtractor: MLXFeatureExtractor(), audioEncoder: MLXAudioEncoder(), textDecoder: MLXTextDecoder(), @@ -283,7 +283,7 @@ final class MLXUnitTests: XCTestCase { for (i, option) in optionsPairs.enumerated() { let result = try await XCTUnwrapAsync( try await transcribe( - modelPath: tinyModelPath, + mlxModelPath: tinyModelPath, options: option.options, audioFile: "ja_test_clip.wav", featureExtractor: MLXFeatureExtractor(), @@ -312,19 +312,35 @@ final class MLXUnitTests: XCTestCase { // MARK: - Utils Tests + func testContiguousStrides() { + let count = 24 + let arr1 = MLXArray(0.. Date: Tue, 3 Sep 2024 22:13:49 -0700 Subject: [PATCH 22/29] Refactor protocols for app support --- .../WhisperAX.xcodeproj/project.pbxproj | 21 +- .../xcshareddata/swiftpm/Package.resolved | 2 +- .../WhisperAX/Views/ContentView.swift | 344 ++++++++++++------ Sources/WhisperKit/Core/Models.swift | 103 ++++-- Sources/WhisperKit/Core/Utils.swift | 15 + Sources/WhisperKit/Core/WhisperKit.swift | 7 +- Sources/WhisperKit/MLX/MLXAudioEncoder.swift | 28 +- .../WhisperKit/MLX/MLXFeatureExtractor.swift | 10 +- Sources/WhisperKit/MLX/MLXModels.swift | 2 + Sources/WhisperKit/MLX/MLXTextDecoder.swift | 18 +- Sources/WhisperKit/MLX/MLXUtils.swift | 6 +- Sources/WhisperKitCLI/CLIArguments.swift | 8 +- 12 files changed, 379 insertions(+), 185 deletions(-) diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj index bfb9069..9010f74 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj @@ -8,6 +8,8 @@ /* Begin PBXBuildFile section */ 161136102B3F6C68003C20F6 /* WhisperKit in Frameworks */ = {isa = PBXBuildFile; productRef = 1611360F2B3F6C68003C20F6 /* WhisperKit */; }; + 1612D45D2C87D9ED009BB384 /* WhisperKit in Frameworks */ = {isa = PBXBuildFile; productRef = 1612D45C2C87D9ED009BB384 /* WhisperKit */; }; + 1612D45F2C87D9ED009BB384 /* WhisperKitMLX in Frameworks */ = {isa = PBXBuildFile; productRef = 1612D45E2C87D9ED009BB384 /* WhisperKitMLX */; }; 1677AFC22B57618A008C61C0 /* WhisperAXApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1677AFAB2B57618A008C61C0 /* WhisperAXApp.swift */; }; 1677AFC42B57618A008C61C0 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 1677AFAE2B57618A008C61C0 /* Preview Assets.xcassets */; }; 1677AFC92B57618A008C61C0 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 1677AFB42B57618A008C61C0 /* Assets.xcassets */; }; @@ -25,7 +27,6 @@ 1683EFEE2B9FACFE002448CD /* WhisperAX Watch App.app in Embed Watch Content */ = {isa = PBXBuildFile; fileRef = 161135DE2B3F66DA003C20F6 /* WhisperAX Watch App.app */; platformFilter = ios; settings = {ATTRIBUTES = (RemoveHeadersOnCopy, ); }; }; 1683EFEF2B9FADFE002448CD /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 1677AFB42B57618A008C61C0 /* Assets.xcassets */; }; 1683EFF02B9FADFE002448CD /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 1677AFB62B57618A008C61C0 /* Preview Assets.xcassets */; }; - 16EA36CF2B59E550006CA7CF /* WhisperKit in Frameworks */ = {isa = PBXBuildFile; productRef = 16EA36CE2B59E550006CA7CF /* WhisperKit */; }; /* End PBXBuildFile section */ /* Begin PBXContainerItemProxy section */ @@ -132,7 +133,8 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - 16EA36CF2B59E550006CA7CF /* WhisperKit in Frameworks */, + 1612D45D2C87D9ED009BB384 /* WhisperKit in Frameworks */, + 1612D45F2C87D9ED009BB384 /* WhisperKitMLX in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -352,7 +354,8 @@ ); name = WhisperAX; packageProductDependencies = ( - 16EA36CE2B59E550006CA7CF /* WhisperKit */, + 1612D45C2C87D9ED009BB384 /* WhisperKit */, + 1612D45E2C87D9ED009BB384 /* WhisperKitMLX */, ); productName = BasicExample; productReference = 167B345E2B05431E0076F261 /* WhisperAX.app */; @@ -438,8 +441,8 @@ ); mainGroup = 167B34552B05431E0076F261; packageReferences = ( - 161135D62B3F66A6003C20F6 /* XCLocalSwiftPackageReference "../.." */, 16D581062B4F7DCE000C0AB0 /* XCRemoteSwiftPackageReference "swift-markdown-ui" */, + 1612D45B2C87D9ED009BB384 /* XCLocalSwiftPackageReference "../../../WhisperKit" */, ); productRefGroup = 167B345F2B05431E0076F261 /* Products */; projectDirPath = ""; @@ -1112,9 +1115,9 @@ /* End XCConfigurationList section */ /* Begin XCLocalSwiftPackageReference section */ - 161135D62B3F66A6003C20F6 /* XCLocalSwiftPackageReference "../.." */ = { + 1612D45B2C87D9ED009BB384 /* XCLocalSwiftPackageReference "../../../WhisperKit" */ = { isa = XCLocalSwiftPackageReference; - relativePath = ../..; + relativePath = ../../../WhisperKit; }; /* End XCLocalSwiftPackageReference section */ @@ -1130,13 +1133,13 @@ /* End XCRemoteSwiftPackageReference section */ /* Begin XCSwiftPackageProductDependency section */ - 1611360F2B3F6C68003C20F6 /* WhisperKit */ = { + 1612D45C2C87D9ED009BB384 /* WhisperKit */ = { isa = XCSwiftPackageProductDependency; productName = WhisperKit; }; - 16EA36CE2B59E550006CA7CF /* WhisperKit */ = { + 1612D45E2C87D9ED009BB384 /* WhisperKitMLX */ = { isa = XCSwiftPackageProductDependency; - productName = WhisperKit; + productName = WhisperKitMLX; }; /* End XCSwiftPackageProductDependency section */ }; diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 7ae7da8..6e7f8f9 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,5 +1,5 @@ { - "originHash" : "829222b514832cb61fe0002e0eebda98f23a75169c63f7d6ed7a320d57d5318f", + "originHash" : "cd17206b47bb810af9459722192530e3838d8e6629a970988e32a432aaa05f6e", "pins" : [ { "identity" : "mlx-swift", diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift index 1fa7096..8291d47 100644 --- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift +++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift @@ -3,6 +3,7 @@ import SwiftUI import WhisperKit +import WhisperKitMLX #if canImport(UIKit) import UIKit #elseif canImport(AppKit) @@ -21,14 +22,15 @@ struct ContentView: View { @State var currentText: String = "" @State var currentChunks: [Int: (chunkText: [String], fallbacks: Int)] = [:] // TODO: Make this configurable in the UI - @State var modelStorage: String = "huggingface/models/argmaxinc/whisperkit-coreml" + @State var modelStorage: String = "huggingface/models" // MARK: Model management @State private var modelState: ModelState = .unloaded - @State private var localModels: [String] = [] - @State private var localModelPath: String = "" + @State private var localModels: [ModelInfo] = [] @State private var availableModels: [String] = [] + @State private var localModelPath: String = "" + @State private var localMLXModelPath: String = "" @State private var availableLanguages: [String] = [] @State private var disabledModels: [String] = WhisperKit.recommendedModels().disabled @@ -38,6 +40,7 @@ struct ContentView: View { @AppStorage("selectedTask") private var selectedTask: String = "transcribe" @AppStorage("selectedLanguage") private var selectedLanguage: String = "english" @AppStorage("repoName") private var repoName: String = "argmaxinc/whisperkit-coreml" + @AppStorage("mlxRepoName") private var mlxRepoName: String = "argmaxinc/whisperkit-mlx" @AppStorage("enableTimestamps") private var enableTimestamps: Bool = true @AppStorage("enablePromptPrefill") private var enablePromptPrefill: Bool = true @AppStorage("enableCachePrefill") private var enableCachePrefill: Bool = true @@ -353,13 +356,15 @@ struct ContentView: View { Picker("", selection: $selectedModel) { ForEach(availableModels, id: \.self) { model in HStack { - let modelIcon = localModels.contains { $0 == model.description } ? "checkmark.circle" : "arrow.down.circle.dotted" - Text("\(Image(systemName: modelIcon)) \(model.description.components(separatedBy: "_").dropFirst().joined(separator: " "))").tag(model.description) + let modelIcon = localModels.contains(where: { $0.name == model.description }) ? "checkmark.circle" : "arrow.down.circle.dotted" + let modelName: String = model.description.components(separatedBy: "_").dropFirst().joined(separator: " ") + Text("\(Image(systemName: modelIcon)) \(modelName)").tag(model.description) } } } .pickerStyle(MenuPickerStyle()) - .onChange(of: selectedModel, initial: false) { _, _ in + .onChange(of: selectedModel, initial: false) { _, newValue in + print("Selected model: \(newValue)") modelState = .unloaded } } else { @@ -376,13 +381,15 @@ struct ContentView: View { .help("Delete model") .buttonStyle(BorderlessButtonStyle()) .disabled(localModels.count == 0) - .disabled(!localModels.contains(selectedModel)) + .disabled(!localModels.contains(where: { $0.name == selectedModel })) #if os(macOS) Button(action: { - let folderURL = whisperKit?.modelFolder ?? (localModels.contains(selectedModel) ? URL(fileURLWithPath: localModelPath) : nil) - if let folder = folderURL { - NSWorkspace.shared.open(folder) + if let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first { + let modelPath = documents.appendingPathComponent(modelStorage) + if modelPath.hasDirectoryPath { + NSWorkspace.shared.open(modelPath) + } } }, label: { Image(systemName: "folder") @@ -442,7 +449,7 @@ struct ContentView: View { VStack(alignment: .leading) { HStack { Image(systemName: "circle.fill") - .foregroundStyle((whisperKit?.audioEncoder as? WhisperMLModel)?.modelState == .loaded ? .green : (modelState == .unloaded ? .red : .yellow)) + .foregroundStyle((whisperKit?.audioEncoder as? WhisperModel)?.modelState == .loaded ? .green : (modelState == .unloaded ? .red : .yellow)) .symbolEffect(.variableColor, isActive: modelState != .loaded && modelState != .unloaded) Text("Audio Encoder") Spacer() @@ -450,6 +457,7 @@ struct ContentView: View { Text("CPU").tag(MLComputeUnits.cpuOnly) Text("GPU").tag(MLComputeUnits.cpuAndGPU) Text("Neural Engine").tag(MLComputeUnits.cpuAndNeuralEngine) + Text("MLX").tag(MLComputeUnits.mlx) } .onChange(of: encoderComputeUnits, initial: false) { _, _ in loadModel(selectedModel) @@ -459,7 +467,7 @@ struct ContentView: View { } HStack { Image(systemName: "circle.fill") - .foregroundStyle((whisperKit?.textDecoder as? WhisperMLModel)?.modelState == .loaded ? .green : (modelState == .unloaded ? .red : .yellow)) + .foregroundStyle((whisperKit?.textDecoder as? WhisperModel)?.modelState == .loaded ? .green : (modelState == .unloaded ? .red : .yellow)) .symbolEffect(.variableColor, isActive: modelState != .loaded && modelState != .unloaded) Text("Text Decoder") Spacer() @@ -467,6 +475,7 @@ struct ContentView: View { Text("CPU").tag(MLComputeUnits.cpuOnly) Text("GPU").tag(MLComputeUnits.cpuAndGPU) Text("Neural Engine").tag(MLComputeUnits.cpuAndNeuralEngine) + Text("MLX").tag(MLComputeUnits.mlx) } .onChange(of: decoderComputeUnits, initial: false) { _, _ in loadModel(selectedModel) @@ -520,7 +529,6 @@ struct ContentView: View { var controlsView: some View { VStack { basicSettingsView - if let selectedCategoryId, let item = menu.first(where: { $0.id == selectedCategoryId }) { switch item.name { case "Transcribe": @@ -934,28 +942,33 @@ struct ContentView: View { // MARK: - Logic func fetchModels() { - availableModels = [selectedModel] - // First check what's already downloaded + availableModels.removeAll() + localModels.removeAll() + if let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first { - let modelPath = documents.appendingPathComponent(modelStorage).path + let modelPath = documents.appendingPathComponent(modelStorage) + let subdirectories = [repoName, mlxRepoName] + + for subdirectory in subdirectories { + let localRepoPath = modelPath.appendingPathComponent(subdirectory) + let modelType: ModelEngine = subdirectory == repoName ? .coreML : .mlx - // Check if the directory exists - if FileManager.default.fileExists(atPath: modelPath) { - localModelPath = modelPath do { - let downloadedModels = try FileManager.default.contentsOfDirectory(atPath: modelPath) - for model in downloadedModels where !localModels.contains(model) { - localModels.append(model) + let downloadedModels = try FileManager.default.contentsOfDirectory(at: localRepoPath, includingPropertiesForKeys: nil) + for modelURL in downloadedModels where modelURL.hasDirectoryPath { + let modelName = modelURL.lastPathComponent + let modelInfo = ModelInfo(name: modelName, engine: modelType, url: modelURL) + localModels.append(modelInfo) } } catch { - print("Error enumerating files at \(modelPath): \(error.localizedDescription)") + print("Error enumerating files at \(localRepoPath): \(error.localizedDescription)") } } } - localModels = WhisperKit.formatModelFiles(localModels) - for model in localModels { + let formattedLocalModels = WhisperKit.formatModelFiles(localModels.map { $0.name }) + for model in formattedLocalModels { if !availableModels.contains(model), !disabledModels.contains(model) { @@ -963,22 +976,53 @@ struct ContentView: View { } } - print("Found locally: \(localModels)") + print("Found locally: \(localModels.map { $0.name })") print("Previously selected model: \(selectedModel)") + // Fetch remote models and add them to availableModels if they're not already local Task { let remoteModels = try await WhisperKit.fetchAvailableModels(from: repoName) for model in remoteModels { if !availableModels.contains(model), - !disabledModels.contains(model) - { + !disabledModels.contains(model) { + availableModels.append(model) + } + } + + let remoteMLXModels = try await WhisperKit.fetchAvailableModels(from: mlxRepoName) + for model in remoteMLXModels { + if !availableModels.contains(model), + !disabledModels.contains(model) { availableModels.append(model) } } + + print("Available models: \(availableModels)") + print("Selected model: \(selectedModel)") } } func loadModel(_ model: String, redownload: Bool = false) { + // Print selected model and compute options for debugging + printModelInfo() + + // Reset the WhisperKit instance + resetWhisperKit() + + Task { + do { + try await initializeWhisperKit() + let (folder, mlxFolder) = try await downloadModelFolders(model: model, redownload: redownload) + setupPipelineComponents(folder: folder, mlxFolder: mlxFolder) + try await prewarmAndLoadModels() + await updateLocalModels(model: model, folder: folder, mlxFolder: mlxFolder) + } catch { + await handleModelLoadError(error: error, model: model, redownload: redownload) + } + } + } + + private func printModelInfo() { print("Selected Model: \(UserDefaults.standard.string(forKey: "selectedModel") ?? "nil")") print(""" Computing Options: @@ -987,110 +1031,194 @@ struct ContentView: View { - Text Decoder: \(getComputeOptions().textDecoderCompute.description) - Prefill Data: \(getComputeOptions().prefillCompute.description) """) + } + private func resetWhisperKit() { + // Reset the current WhisperKit instance whisperKit = nil - Task { - whisperKit = try await WhisperKit( - computeOptions: getComputeOptions(), - verbose: true, - logLevel: .debug, - prewarm: false, - load: false, - download: false - ) - guard let whisperKit = whisperKit else { - return - } + } - var folder: URL? + private func initializeWhisperKit() async throws { + // Initialize a new WhisperKit instance with the current compute options + whisperKit = try await WhisperKit( + computeOptions: getComputeOptions(), + verbose: true, + logLevel: .debug, + prewarm: false, + load: false, + download: false + ) + } - // Check if the model is available locally - if localModels.contains(model) && !redownload { - // Get local model folder URL from localModels - // TODO: Make this configurable in the UI - folder = URL(fileURLWithPath: localModelPath).appendingPathComponent(model) - } else { - // Download the model - folder = try await WhisperKit.download(variant: model, from: repoName, progressCallback: { progress in - DispatchQueue.main.async { - loadingProgressValue = Float(progress.fractionCompleted) * specializationProgressRatio - modelState = .downloading - } - }) - } + private func downloadModelFolders(model: String, redownload: Bool) async throws -> (URL?, URL?) { + guard whisperKit != nil else { return (nil, nil) } - await MainActor.run { - loadingProgressValue = specializationProgressRatio - modelState = .downloaded - } + var folder: URL? + var mlxFolder: URL? - if let modelFolder = folder { - whisperKit.modelFolder = modelFolder + // Check if the model is available locally + let needsCoreMLModel = encoderComputeUnits != .mlx || decoderComputeUnits != .mlx + if needsCoreMLModel { + folder = try await downloadCoreMLModelIfNeeded(model: model, redownload: redownload) + } - await MainActor.run { - // Set the loading progress to 90% of the way after prewarm - loadingProgressValue = specializationProgressRatio - modelState = .prewarming - } + // Check if MLX model is needed based on compute units + let needsMLXModel = encoderComputeUnits == .mlx || decoderComputeUnits == .mlx + if needsMLXModel { + mlxFolder = try await downloadMLXModelIfNeeded(model: model, redownload: redownload) + } - let progressBarTask = Task { - await updateProgressBar(targetProgress: 0.9, maxTime: 240) - } + await MainActor.run { + loadingProgressValue = specializationProgressRatio + modelState = .downloaded + } - // Prewarm models - do { - try await whisperKit.prewarmModels() - progressBarTask.cancel() - } catch { - print("Error prewarming models, retrying: \(error.localizedDescription)") - progressBarTask.cancel() - if !redownload { - loadModel(model, redownload: true) - return - } else { - // Redownloading failed, error out - modelState = .unloaded - return - } - } + return (folder, mlxFolder) + } - await MainActor.run { - // Set the loading progress to 90% of the way after prewarm - loadingProgressValue = specializationProgressRatio + 0.9 * (1 - specializationProgressRatio) - modelState = .loading - } + private func downloadCoreMLModelIfNeeded(model: String, redownload: Bool) async throws -> URL? { + if localModels.contains(where: { $0.name == model && $0.engine == .coreML }) && !redownload { + // Get local model folder URL from localModels + // TODO: Make this configurable in the UI + return localModels.first(where: { $0.name == model && $0.engine == .coreML })?.url + } else { + // Download the model + return try await WhisperKit.download(variant: model, from: repoName, progressCallback: updateDownloadProgress) + } + } - try await whisperKit.loadModels() + private func downloadMLXModelIfNeeded(model: String, redownload: Bool) async throws -> URL? { + if localModels.contains(where: { $0.name == model && $0.engine == .mlx }) && !redownload { + return localModels.first(where: { $0.name == model && $0.engine == .mlx })?.url + } else { + return try await WhisperKit.download(variant: model, from: mlxRepoName, progressCallback: updateDownloadProgress) + } + } - await MainActor.run { - if !localModels.contains(model) { - localModels.append(model) - } + private func updateDownloadProgress(_ progress: Progress) { + DispatchQueue.main.async { + loadingProgressValue = Float(progress.fractionCompleted) * specializationProgressRatio + modelState = .downloading + } + } - availableLanguages = Constants.languages.map { $0.key }.sorted() - loadingProgressValue = 1.0 - modelState = whisperKit.modelState - } + private func setupPipelineComponents(folder: URL?, mlxFolder: URL?) { + guard let whisperKit = whisperKit else { return } + + // Set up CoreML components if needed + if let modelFolder = folder { + whisperKit.modelFolder = modelFolder + if encoderComputeUnits != .mlx || decoderComputeUnits != .mlx { + whisperKit.featureExtractor = FeatureExtractor() + } + if encoderComputeUnits != .mlx { + whisperKit.audioEncoder = AudioEncoder() + } + if decoderComputeUnits != .mlx { + whisperKit.textDecoder = TextDecoder() } } - } - func deleteModel() { - if localModels.contains(selectedModel) { - let modelFolder = URL(fileURLWithPath: localModelPath).appendingPathComponent(selectedModel) + // Set up MLX components if needed + if let mlxModelFolder = mlxFolder { + whisperKit.mlxModelFolder = mlxModelFolder + if encoderComputeUnits == .mlx || decoderComputeUnits == .mlx { + whisperKit.featureExtractor = MLXFeatureExtractor() + } + if encoderComputeUnits == .mlx { + whisperKit.audioEncoder = MLXAudioEncoder() + } + if decoderComputeUnits == .mlx { + whisperKit.textDecoder = MLXTextDecoder() + } + } + } - do { - try FileManager.default.removeItem(at: modelFolder) + private func prewarmAndLoadModels() async throws { + guard let whisperKit = whisperKit else { return } + + await MainActor.run { + loadingProgressValue = specializationProgressRatio + modelState = .prewarming + } + + let progressBarTask = Task { + await updateProgressBar(targetProgress: 0.9, maxTime: 240) + } + + // Prewarm models + do { + try await whisperKit.prewarmModels() + progressBarTask.cancel() + } catch { + print("Error prewarming models, retrying: \(error.localizedDescription)") + progressBarTask.cancel() + throw error + } + + await MainActor.run { + loadingProgressValue = specializationProgressRatio + 0.9 * (1 - specializationProgressRatio) + modelState = .loading + } + + try await whisperKit.loadModels() + } - if let index = localModels.firstIndex(of: selectedModel) { - localModels.remove(at: index) + private func updateLocalModels(model: String, folder: URL?, mlxFolder: URL?) async { + await MainActor.run { + // Add newly downloaded models to localModels if not already present + if !localModels.contains(where: { $0.name == model }) { + if let folder = folder { + localModels.append(ModelInfo(name: model, engine: .coreML, url: folder)) + } + if let mlxFolder = mlxFolder { + localModels.append(ModelInfo(name: model, engine: .mlx, url: mlxFolder)) } + } + + availableLanguages = Constants.languages.map { $0.key }.sorted() + loadingProgressValue = 1.0 + modelState = whisperKit?.modelState ?? .unloaded + } + } + private func handleModelLoadError(error: Error, model: String, redownload: Bool) async { + print("Error loading model: \(error.localizedDescription)") + if !redownload { + // Attempt to redownload and load the model if prewarming fails + loadModel(model, redownload: true) + } else { + await MainActor.run { modelState = .unloaded + } + } + } + + func deleteModel() { + let modelsToDelete = localModels.filter { $0.name == selectedModel } + + guard !modelsToDelete.isEmpty else { + print("Model not found locally") + return + } + + for modelToDelete in modelsToDelete { + do { + try FileManager.default.removeItem(at: modelToDelete.url) + print("Deleted model at: \(modelToDelete.url)") } catch { - print("Error deleting model: \(error)") + print("Error deleting model at \(modelToDelete.url): \(error)") } } + + localModels.removeAll { $0.name == selectedModel } + + modelState = .unloaded + + // Reset selected model if it was deleted + if selectedModel == modelsToDelete.first?.name { + selectedModel = availableModels.first ?? "" + } } func updateProgressBar(targetProgress: Float, maxTime: TimeInterval) async { diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index 4e86e2b..1939cda 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -18,41 +18,6 @@ extension Float16: BNNSScalar {} extension Float16: MLShapedArrayScalar {} #endif -// MARK: - CoreML - -public protocol WhisperModel: AnyObject { - func unloadModel() -} - -public protocol WhisperMLModel: WhisperModel { - var model: MLModel? { get set } - func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws -} - -public protocol WhisperMLXModel: WhisperModel { - func loadModel(at modelPath: URL, configPath: URL) async throws -} - -public extension WhisperMLModel { - func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws { - let loadedModel = try await Task { - let modelConfig = MLModelConfiguration() - modelConfig.computeUnits = computeUnits - return try await MLModel.load(contentsOf: modelPath, configuration: modelConfig) - }.value - - model = prewarmMode ? nil : loadedModel - } - - func unloadModel() { - model = nil - } - - var modelState: ModelState { - return model == nil ? .unloaded : .loaded - } -} - // MARK: - Whisper Models public enum ModelVariant: CustomStringConvertible, CaseIterable { @@ -171,6 +136,74 @@ public struct ModelComputeOptions { } } +public struct ModelInfo: Identifiable, Hashable { + public let id = UUID() + public let name: String + public let engine: ModelEngine + public let url: URL + + public init(name: String, engine: ModelEngine, url: URL) { + self.name = name + self.engine = engine + self.url = url + } +} + +public enum ModelEngine: String, Codable { + case coreML = "coreml" + case mlx = "mlx" +} + +public protocol WhisperModel: AnyObject { + func unloadModel() + var modelState: ModelState { get } +} + +// MARK: - CoreML + +public protocol WhisperMLModel: WhisperModel { + var model: MLModel? { get set } + func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool) async throws +} + +public extension WhisperMLModel { + func loadModel(at modelPath: URL, computeUnits: MLComputeUnits, prewarmMode: Bool = false) async throws { + let loadedModel = try await Task { + let modelConfig = MLModelConfiguration() + modelConfig.computeUnits = computeUnits + return try await MLModel.load(contentsOf: modelPath, configuration: modelConfig) + }.value + + model = prewarmMode ? nil : loadedModel + } + + func unloadModel() { + model = nil + } + + var modelState: ModelState { + return model == nil ? .unloaded : .loaded + } +} + +// MARK: MLX + +public protocol WhisperMLXModel: WhisperModel { + associatedtype MLXModuleType + var model: MLXModuleType? { get set } + func loadModel(at modelPath: URL, configPath: URL?) async throws +} + +public extension WhisperMLXModel { + func unloadModel() { + model = nil + } + + var modelState: ModelState { + return model == nil ? .unloaded : .loaded + } +} + // MARK: - Chunking public struct AudioChunk { diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index d8d759d..45b930c 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -107,6 +107,19 @@ extension MLModel { } public extension MLComputeUnits { + /// Compute unit for MLX-based models. + /// + /// This compute unit is specifically designed for use with MLX-based models in WhisperKit. + /// + /// - Important: This is a custom compute unit and will not be recognized by MLModel instances. + static let mlx: MLComputeUnits = { + guard let unit = MLComputeUnits(rawValue: 99) else { // Prevent overlap with future values + Logging.error("Failed to create MLComputeUnits for MLX. Defaulting to .cpuAndGPU.") + return .cpuAndGPU + } + return unit + }() + var description: String { switch self { case .cpuOnly: @@ -117,6 +130,8 @@ public extension MLComputeUnits { return "all" case .cpuAndNeuralEngine: return "cpuAndNeuralEngine" + case .mlx: + return "mlx" @unknown default: return "unknown" } diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index a9bb407..9ade93b 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -165,6 +165,7 @@ open class WhisperKit { return sortedModels } + @MainActor public static func download( variant: String, downloadBase: URL? = nil, @@ -299,7 +300,7 @@ open class WhisperKit { prewarmMode: prewarmMode ) Logging.debug("Loaded feature extractor") - } else if let path = mlxModelFolder, let featureExtractor = featureExtractor as? WhisperMLXModel { + } else if let path = mlxModelFolder, let featureExtractor = featureExtractor as? (any WhisperMLXModel) { Logging.debug("Loading MLX feature extractor from \(path.path)") try await featureExtractor.loadModel(at: path, configPath: path) Logging.debug("Loaded MLX feature extractor") @@ -313,7 +314,7 @@ open class WhisperKit { prewarmMode: prewarmMode ) Logging.debug("Loaded audio encoder") - } else if let path = mlxModelFolder, let audioEncoder = audioEncoder as? WhisperMLXModel { + } else if let path = mlxModelFolder, let audioEncoder = audioEncoder as? (any WhisperMLXModel) { Logging.debug("Loading MLX audio encoder from \(path.path)") try await audioEncoder.loadModel( at: path.appending(path: "encoder.safetensors"), @@ -330,7 +331,7 @@ open class WhisperKit { prewarmMode: prewarmMode ) Logging.debug("Loaded text decoder") - } else if let path = mlxModelFolder, let textDecoder = textDecoder as? WhisperMLXModel { + } else if let path = mlxModelFolder, let textDecoder = textDecoder as? (any WhisperMLXModel) { Logging.debug("Loading MLX text decoder from \(path.path)") try await textDecoder.loadModel( at: path.appending(path: "decoder.safetensors"), diff --git a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift index f6dcedd..4de2b7a 100644 --- a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift +++ b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift @@ -7,32 +7,32 @@ import MLXNN import WhisperKit @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -public class MLXAudioEncoder: AudioEncoding { +public class MLXAudioEncoder: AudioEncoding, WhisperMLXModel { + public var model: AudioEncoderModule? + public var embedSize: Int? { - encoder?.nState + model?.nState } - private var encoder: AudioEncoder? - public init() {} public func encodeFeatures(_ features: MLMultiArray) async throws -> MLMultiArray? { - guard let encoder else { + guard let model else { throw WhisperError.modelsUnavailable() } try Task.checkCancellation() let inputArray = features.asMLXArray(FloatType.self) let input = inputArray.asMLXInput() - let output = encoder(input) + let output = model(input) return try output.asMLXOutput().asMLMultiArray() } } -extension MLXAudioEncoder: WhisperMLXModel { - public func loadModel(at modelPath: URL, configPath: URL) async throws { +extension MLXAudioEncoder { + public func loadModel(at modelPath: URL, configPath: URL?) async throws { let parameters = try loadParameters(at: modelPath) let config = try loadConfig(at: configPath) - let encoder = AudioEncoder( + let encoder = AudioEncoderModule( nMels: config.nMels, nCtx: config.nAudioCtx, nState: config.nAudioState, @@ -42,15 +42,19 @@ extension MLXAudioEncoder: WhisperMLXModel { ) let loadedEncoder = try encoder.update(parameters: parameters, verify: [.noUnusedKeys]) MLX.eval(loadedEncoder) - self.encoder = encoder + self.model = encoder } public func unloadModel() { - encoder = nil + model = nil + } + + public var modelState: ModelState { + return model == nil ? .unloaded : .loaded } } -final class AudioEncoder: Module { +public class AudioEncoderModule: MLXNN.Module { let nMels: Int let nCtx: Int let nState: Int diff --git a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift index 935bbf0..fe95078 100644 --- a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift +++ b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift @@ -8,7 +8,7 @@ import MLXFFT import WhisperKit @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -open class MLXFeatureExtractor: FeatureExtracting { +open class MLXFeatureExtractor: FeatureExtracting, WhisperMLXModel { public let melCount: Int? private let nFFT: Int private let hopLength: Int @@ -37,10 +37,12 @@ open class MLXFeatureExtractor: FeatureExtracting { ) return try output.asType(FloatType.self).asMLXOutput().asMLMultiArray() } -} -extension MLXFeatureExtractor: WhisperMLXModel { - public func loadModel(at modelPath: URL, configPath: URL) async throws {} + // Stubs for WhisperMLXModel protocol, not needed + public typealias MLXModuleType = NSObject + public var model: NSObject? + + public func loadModel(at modelPath: URL, configPath: URL?) async throws {} public func unloadModel() {} } diff --git a/Sources/WhisperKit/MLX/MLXModels.swift b/Sources/WhisperKit/MLX/MLXModels.swift index 207e884..4322ff6 100644 --- a/Sources/WhisperKit/MLX/MLXModels.swift +++ b/Sources/WhisperKit/MLX/MLXModels.swift @@ -3,6 +3,8 @@ import Foundation import MLX +import MLXNN +import WhisperKit public enum PadMode { case constant diff --git a/Sources/WhisperKit/MLX/MLXTextDecoder.swift b/Sources/WhisperKit/MLX/MLXTextDecoder.swift index d98b9a1..d6e2b5f 100644 --- a/Sources/WhisperKit/MLX/MLXTextDecoder.swift +++ b/Sources/WhisperKit/MLX/MLXTextDecoder.swift @@ -8,12 +8,13 @@ import WhisperKit @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public final class MLXTextDecoder: TextDecoding { + public var model: TextDecoderModule? public var tokenizer: (any WhisperTokenizer)? public var prefillData: (any WhisperMLModel)? public var isModelMultilingual: Bool = false public let supportsWordTimestamps: Bool = false public var logitsSize: Int? { - decoder?.nVocab + model?.nVocab } public var kvCacheEmbedDim: Int? { @@ -36,7 +37,6 @@ public final class MLXTextDecoder: TextDecoding { return config.nTextState } - private var decoder: TextDecoder? private var config: MLXModelConfig? private var languageLogitsFilter: LanguageLogitsFilter? @@ -117,12 +117,12 @@ public final class MLXTextDecoder: TextDecoding { encoderOutputEmbeds: MLMultiArray, decoderKeyPaddingMask: MLMultiArray ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? { - guard let decoder else { + guard let model else { return nil } let tokens = inputIds.asMLXArray(Int32.self) let audioFeatures = encoderOutputEmbeds.asMLXArray(FloatType.self).asMLXInput() - let result = decoder( + let result = model( tokens, xa: audioFeatures, kvCache: Self.toKvCache(keyCache: keyCache, valueCache: valueCache) @@ -462,10 +462,10 @@ public final class MLXTextDecoder: TextDecoding { } extension MLXTextDecoder: WhisperMLXModel { - public func loadModel(at modelPath: URL, configPath: URL) async throws { + public func loadModel(at modelPath: URL, configPath: URL?) async throws { let parameters = try loadParameters(at: modelPath) let config = try loadConfig(at: configPath) - let decoder = TextDecoder( + let decoder = TextDecoderModule( nVocab: config.nVocab, nCtx: config.nTextCtx, nState: config.nTextState, @@ -475,19 +475,19 @@ extension MLXTextDecoder: WhisperMLXModel { ) let loadedDecoder = try decoder.update(parameters: parameters, verify: [.noUnusedKeys]) MLX.eval(loadedDecoder) - self.decoder = loadedDecoder + self.model = loadedDecoder self.config = config } public func unloadModel() { - decoder = nil + model = nil config = nil prefillData = nil languageLogitsFilter = nil } } -final class TextDecoder: Module { +public class TextDecoderModule: Module { let nVocab: Int let nCtx: Int let nState: Int diff --git a/Sources/WhisperKit/MLX/MLXUtils.swift b/Sources/WhisperKit/MLX/MLXUtils.swift index 52a33db..d6af28b 100644 --- a/Sources/WhisperKit/MLX/MLXUtils.swift +++ b/Sources/WhisperKit/MLX/MLXUtils.swift @@ -5,6 +5,7 @@ import CoreML import Foundation import MLX import MLXNN +import WhisperKit // MARK: - Extensions @@ -113,7 +114,10 @@ func loadParameters(at url: URL) throws -> NestedDictionary { return ModuleParameters.unflattened(arrays) } -func loadConfig(at url: URL) throws -> MLXModelConfig { +func loadConfig(at configPath: URL?) throws -> MLXModelConfig { + guard let url = configPath else { + throw WhisperError.modelsUnavailable("Config path must be specified for MLX models") + } let configDecoder = JSONDecoder() configDecoder.keyDecodingStrategy = .convertFromSnakeCase return try configDecoder.decode(MLXModelConfig.self, from: Data(contentsOf: url)) diff --git a/Sources/WhisperKitCLI/CLIArguments.swift b/Sources/WhisperKitCLI/CLIArguments.swift index 1ca5e0e..a948ed2 100644 --- a/Sources/WhisperKitCLI/CLIArguments.swift +++ b/Sources/WhisperKitCLI/CLIArguments.swift @@ -2,10 +2,12 @@ // Copyright © 2024 Argmax, Inc. All rights reserved. import ArgumentParser +import WhisperKitMLX -enum ModelType: String, Decodable, ExpressibleByArgument { - case coreML = "coreml" - case mlx = "mlx" +extension ModelType: ExpressibleByArgument { + public init?(argument: String) { + self.init(rawValue: argument.lowercased()) + } } struct CLIArguments: ParsableArguments { From cc10a2333c37916798e4821a754c19cfc58244f1 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Wed, 4 Sep 2024 09:18:11 -0700 Subject: [PATCH 23/29] WIP perf improvements --- .../WhisperAX.xcodeproj/project.pbxproj | 23 +-- .../xcshareddata/swiftpm/Package.resolved | 25 +-- Package.swift | 2 +- Sources/WhisperKit/Core/TextDecoder.swift | 128 ++++++------ Sources/WhisperKit/MLX/MLXModels.swift | 15 +- Sources/WhisperKit/MLX/MLXTextDecoder.swift | 100 ++++++---- Sources/WhisperKit/MLX/MLXTokenSampler.swift | 186 ++++++++++++++++++ Sources/WhisperKit/MLX/MLXUtils.swift | 36 ++-- Tests/WhisperKitMLXTests/MLXUnitTests.swift | 6 +- Tests/WhisperKitTests/UnitTests.swift | 6 +- 10 files changed, 368 insertions(+), 159 deletions(-) create mode 100644 Sources/WhisperKit/MLX/MLXTokenSampler.swift diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj index 9010f74..d25d5e0 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.pbxproj @@ -7,9 +7,9 @@ objects = { /* Begin PBXBuildFile section */ - 161136102B3F6C68003C20F6 /* WhisperKit in Frameworks */ = {isa = PBXBuildFile; productRef = 1611360F2B3F6C68003C20F6 /* WhisperKit */; }; 1612D45D2C87D9ED009BB384 /* WhisperKit in Frameworks */ = {isa = PBXBuildFile; productRef = 1612D45C2C87D9ED009BB384 /* WhisperKit */; }; 1612D45F2C87D9ED009BB384 /* WhisperKitMLX in Frameworks */ = {isa = PBXBuildFile; productRef = 1612D45E2C87D9ED009BB384 /* WhisperKitMLX */; }; + 167209782C88BF3F0010BE5F /* WhisperKit in Frameworks */ = {isa = PBXBuildFile; productRef = 167209772C88BF3F0010BE5F /* WhisperKit */; }; 1677AFC22B57618A008C61C0 /* WhisperAXApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1677AFAB2B57618A008C61C0 /* WhisperAXApp.swift */; }; 1677AFC42B57618A008C61C0 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 1677AFAE2B57618A008C61C0 /* Preview Assets.xcassets */; }; 1677AFC92B57618A008C61C0 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 1677AFB42B57618A008C61C0 /* Assets.xcassets */; }; @@ -111,7 +111,7 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - 161136102B3F6C68003C20F6 /* WhisperKit in Frameworks */, + 167209782C88BF3F0010BE5F /* WhisperKit in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -296,7 +296,7 @@ ); name = "WhisperAX Watch App"; packageProductDependencies = ( - 1611360F2B3F6C68003C20F6 /* WhisperKit */, + 167209772C88BF3F0010BE5F /* WhisperKit */, ); productName = "Basic Watch App"; productReference = 161135DE2B3F66DA003C20F6 /* WhisperAX Watch App.app */; @@ -441,7 +441,6 @@ ); mainGroup = 167B34552B05431E0076F261; packageReferences = ( - 16D581062B4F7DCE000C0AB0 /* XCRemoteSwiftPackageReference "swift-markdown-ui" */, 1612D45B2C87D9ED009BB384 /* XCLocalSwiftPackageReference "../../../WhisperKit" */, ); productRefGroup = 167B345F2B05431E0076F261 /* Products */; @@ -1121,17 +1120,6 @@ }; /* End XCLocalSwiftPackageReference section */ -/* Begin XCRemoteSwiftPackageReference section */ - 16D581062B4F7DCE000C0AB0 /* XCRemoteSwiftPackageReference "swift-markdown-ui" */ = { - isa = XCRemoteSwiftPackageReference; - repositoryURL = "https://github.com/gonzalezreal/swift-markdown-ui.git"; - requirement = { - kind = upToNextMajorVersion; - minimumVersion = 2.3.0; - }; - }; -/* End XCRemoteSwiftPackageReference section */ - /* Begin XCSwiftPackageProductDependency section */ 1612D45C2C87D9ED009BB384 /* WhisperKit */ = { isa = XCSwiftPackageProductDependency; @@ -1141,6 +1129,11 @@ isa = XCSwiftPackageProductDependency; productName = WhisperKitMLX; }; + 167209772C88BF3F0010BE5F /* WhisperKit */ = { + isa = XCSwiftPackageProductDependency; + package = 1612D45B2C87D9ED009BB384 /* XCLocalSwiftPackageReference "../../../WhisperKit" */; + productName = WhisperKit; + }; /* End XCSwiftPackageProductDependency section */ }; rootObject = 167B34562B05431E0076F261 /* Project object */; diff --git a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved index 6e7f8f9..2c69b2d 100644 --- a/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved +++ b/Examples/WhisperAX/WhisperAX.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved @@ -1,22 +1,12 @@ { - "originHash" : "cd17206b47bb810af9459722192530e3838d8e6629a970988e32a432aaa05f6e", + "originHash" : "831ad63194a5262b2549d58e383a520f9cbbc80b4a75660fbbcc56d65edfdab4", "pins" : [ { "identity" : "mlx-swift", "kind" : "remoteSourceControl", - "location" : "https://github.com/ml-explore/mlx-swift", + "location" : "https://github.com/davidkoski/mlx-swift", "state" : { - "revision" : "597aaa5f465b4b9a17c8646b751053f84e37925b", - "version" : "0.16.0" - } - }, - { - "identity" : "networkimage", - "kind" : "remoteSourceControl", - "location" : "https://github.com/gonzalezreal/NetworkImage", - "state" : { - "revision" : "7aff8d1b31148d32c5933d75557d42f6323ee3d1", - "version" : "6.0.0" + "revision" : "3314bc684f0ccab1793be54acddaea16c0501d3c" } }, { @@ -28,15 +18,6 @@ "version" : "1.3.0" } }, - { - "identity" : "swift-markdown-ui", - "kind" : "remoteSourceControl", - "location" : "https://github.com/gonzalezreal/swift-markdown-ui.git", - "state" : { - "revision" : "9a8119b37e09a770367eeb26e05267c75d854053", - "version" : "2.3.1" - } - }, { "identity" : "swift-numerics", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index 824bb0f..47a24d6 100644 --- a/Package.swift +++ b/Package.swift @@ -108,7 +108,7 @@ func mlxProducts() -> [Product] { func mlxDependencies() -> [Package.Dependency] { guard !isMLXDisabled() else { return [] } return [ - .package(url: "https://github.com/ml-explore/mlx-swift", exact: "0.16.0"), + .package(url: "https://github.com/davidkoski/mlx-swift", revision: "3314bc684f0ccab1793be54acddaea16c0501d3c"), ] } diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index 250f895..032e04e 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -21,15 +21,15 @@ public protocol TextDecoding { withPrompt initialPrompt: [Int] ) throws -> DecodingInputs - func predictLogits( - inputIds: MLMultiArray, - cacheLength: MLMultiArray, - keyCache: MLMultiArray?, - valueCache: MLMultiArray?, - kvCacheUpdateMask: MLMultiArray, - encoderOutputEmbeds: MLMultiArray, - decoderKeyPaddingMask: MLMultiArray - ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? +// func predictLogits( +// inputIds: MLMultiArray, +// cacheLength: MLMultiArray, +// keyCache: MLMultiArray?, +// valueCache: MLMultiArray?, +// kvCacheUpdateMask: MLMultiArray, +// encoderOutputEmbeds: MLMultiArray, +// decoderKeyPaddingMask: MLMultiArray +// ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? func prefillKVCache( withTask task: MLMultiArray, @@ -376,61 +376,63 @@ public extension TextDecoding { let inferenceTime = Date() Logging.debug("Detecting language...") - let predictedLogits = try await textDecoder.predictLogits( - inputIds: decoderInputs.inputIds, - cacheLength: decoderInputs.cacheLength, - keyCache: decoderInputs.keyCache, - valueCache: decoderInputs.valueCache, - kvCacheUpdateMask: decoderInputs.kvCacheUpdateMask, - encoderOutputEmbeds: encoderOutput, - decoderKeyPaddingMask: decoderInputs.decoderKeyPaddingMask - ) - - guard let decoderOutput = predictedLogits else { - Logging.error("Unable to decode logits") - throw WhisperError.decodingLogitsFailed() - } - - let decodingInferenceTime = Date().timeIntervalSince(inferenceTime) - timings.decodingPredictions += decodingInferenceTime - - // MARK: Non-inference - - // Update predicted token as current - let logits = languageLogitsFilter.filterLogits(decoderOutput.logits!, withTokens: currentTokens) - - // MARK: Sampling - - let samplingStartTime = Date() - - let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) - - nextToken = sampleResult.tokens.last! - logProbs = sampleResult.logProbs - - let samplingTime = Date().timeIntervalSince(samplingStartTime) - timings.decodingSampling += samplingTime - - var languageProbs = [String: Float]() - for (tokenIndex, token) in sampleResult.tokens.enumerated() { - if tokenizer.allLanguageTokens.contains(token) { - let language = tokenizer.decode(tokens: [token]).trimmingSpecialTokenCharacters() - languageProbs[language] = sampleResult.logProbs[tokenIndex] - } - } - - let sampledLanguage = tokenizer.decode(tokens: [nextToken]).trimmingSpecialTokenCharacters() - let detectedLanguage: String - if Constants.languageCodes.contains(sampledLanguage) { - detectedLanguage = sampledLanguage - Logging.debug("Detected language: \(sampledLanguage)") - } else { - detectedLanguage = Constants.defaultLanguageCode - Logging.error("Detected language \(sampledLanguage) is not supported, defaulting to \(Constants.defaultLanguageCode)") - } +// let predictedLogits = try await textDecoder.predictLogits( +// inputIds: decoderInputs.inputIds, +// cacheLength: decoderInputs.cacheLength, +// keyCache: decoderInputs.keyCache, +// valueCache: decoderInputs.valueCache, +// kvCacheUpdateMask: decoderInputs.kvCacheUpdateMask, +// encoderOutputEmbeds: encoderOutput, +// decoderKeyPaddingMask: decoderInputs.decoderKeyPaddingMask +// ) +// +// guard let decoderOutput = predictedLogits else { +// Logging.error("Unable to decode logits") +// throw WhisperError.decodingLogitsFailed() +// } +// +// let decodingInferenceTime = Date().timeIntervalSince(inferenceTime) +// timings.decodingPredictions += decodingInferenceTime +// +// // MARK: Non-inference +// +// // Update predicted token as current +// let logits = languageLogitsFilter.filterLogits(decoderOutput.logits!, withTokens: currentTokens) +// +// // MARK: Sampling +// +// let samplingStartTime = Date() +// +// let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) +// +// nextToken = sampleResult.tokens.last! +// logProbs = sampleResult.logProbs +// +// let samplingTime = Date().timeIntervalSince(samplingStartTime) +// timings.decodingSampling += samplingTime +// +// var languageProbs = [String: Float]() +// for (tokenIndex, token) in sampleResult.tokens.enumerated() { +// if tokenizer.allLanguageTokens.contains(token) { +// let language = tokenizer.decode(tokens: [token]).trimmingSpecialTokenCharacters() +// languageProbs[language] = sampleResult.logProbs[tokenIndex] +// } +// } +// +// let sampledLanguage = tokenizer.decode(tokens: [nextToken]).trimmingSpecialTokenCharacters() +// let detectedLanguage: String +// if Constants.languageCodes.contains(sampledLanguage) { +// detectedLanguage = sampledLanguage +// Logging.debug("Detected language: \(sampledLanguage)") +// } else { +// detectedLanguage = Constants.defaultLanguageCode +// Logging.error("Detected language \(sampledLanguage) is not supported, defaulting to \(Constants.defaultLanguageCode)") +// } return DecodingResult( - language: detectedLanguage, - languageProbs: languageProbs, +// language: detectedLanguage, +// languageProbs: languageProbs, + language: Constants.defaultLanguageCode, + languageProbs: [:], tokens: [], tokenLogProbs: [], text: "", diff --git a/Sources/WhisperKit/MLX/MLXModels.swift b/Sources/WhisperKit/MLX/MLXModels.swift index 4322ff6..e29b103 100644 --- a/Sources/WhisperKit/MLX/MLXModels.swift +++ b/Sources/WhisperKit/MLX/MLXModels.swift @@ -24,11 +24,24 @@ struct MLXModelConfig: Codable { let nTextLayer: Int } -struct KV { +public struct KV { var k: MLXArray var v: MLXArray } +public struct MLXDecodingCache { + public var kvCache: [KV] + public var alignmentWeights: MLXArray? + + public init( + kvCache: [KV], + alignmentWeights: MLXArray? + ) { + self.kvCache = kvCache + self.alignmentWeights = alignmentWeights + } +} + struct TextDecoderResult { var logits: MLXArray var kvCache: [KV] diff --git a/Sources/WhisperKit/MLX/MLXTextDecoder.swift b/Sources/WhisperKit/MLX/MLXTextDecoder.swift index d6e2b5f..63201d9 100644 --- a/Sources/WhisperKit/MLX/MLXTextDecoder.swift +++ b/Sources/WhisperKit/MLX/MLXTextDecoder.swift @@ -5,9 +5,10 @@ import CoreML import MLX import MLXNN import WhisperKit +import MLXFast @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -public final class MLXTextDecoder: TextDecoding { +public final class MLXTextDecoder: TextDecoding { public var model: TextDecoderModule? public var tokenizer: (any WhisperTokenizer)? public var prefillData: (any WhisperMLModel)? @@ -42,17 +43,16 @@ public final class MLXTextDecoder: TextDecoding { public init() {} - private static func toKvCache(keyCache: MLMultiArray?, valueCache: MLMultiArray?) -> [KV]? { + private static func toKvCache(keyCache: MLXArray?, valueCache: MLXArray?) -> [KV]? { guard let keyCache, let valueCache else { return nil } - let keyCacheMlx = keyCache.asMLXArray(FloatType.self) - let valueCacheMlx = valueCache.asMLXArray(FloatType.self) - assert(keyCacheMlx.shape == valueCacheMlx.shape) + assert(keyCache.shape == valueCache.shape) + var result = [KV]() - for index in 0.. (logits: MLMultiArray?, cache: DecodingCache?)? { + ) async throws -> (logits: MLXArray?, cache: MLXDecodingCache?)? { + let time3 = Date() + guard let model else { return nil } let tokens = inputIds.asMLXArray(Int32.self) let audioFeatures = encoderOutputEmbeds.asMLXArray(FloatType.self).asMLXInput() + Logging.debug("Time to prepare input time: \(Date().timeIntervalSince(time3))") + + let time = Date() let result = model( tokens, xa: audioFeatures, kvCache: Self.toKvCache(keyCache: keyCache, valueCache: valueCache) ) - let keyCache = try MLX.stacked(result.kvCache.map(\.k)).asMLMultiArray() - let valueCache = try MLX.stacked(result.kvCache.map(\.v)).asMLMultiArray() - let decodingCache = DecodingCache( - keyCache: keyCache, - valueCache: valueCache, + MLX.eval(result.logits) + + Logging.debug("Time to Inference time: \(Date().timeIntervalSince(time))") + + let time2 = Date() + + let decodingCache = MLXDecodingCache( + kvCache: result.kvCache, alignmentWeights: nil ) - return try (result.logits.asMLMultiArray(), decodingCache) + Logging.debug("Time to cache time: \(Date().timeIntervalSince(time2))") + + return try (result.logits, decodingCache) } public func decodeText( @@ -149,6 +159,9 @@ public final class MLXTextDecoder: TextDecoding { throw WhisperError.tokenizerUnavailable() } + let tokenSampler = MLXGreedyTokenSampler(temperature: Float(options.temperature), eotToken: tokenizer.specialTokens.endToken, decodingOptions: options) + + // Single loop variables var timings = TranscriptionTimings() let prefilledIndex = decoderInputs.cacheLength[0].intValue @@ -195,8 +208,8 @@ public final class MLXTextDecoder: TextDecoding { Logging.debug("Running main loop for a maximum of \(loopCount) iterations, starting at index \(prefilledIndex)") var hasAlignment = false var isFirstTokenLogProbTooLow = false - var keyCache = decoderInputs.keyCache - var valueCache = decoderInputs.valueCache + var keyCache = decoderInputs.keyCache?.asMLXArray(FloatType.self) + var valueCache = decoderInputs.valueCache?.asMLXArray(FloatType.self) for tokenIndex in prefilledIndex.. MLXArray { +// MLXFast.layerNorm(x, weight: weight, bias: bias, eps: 1e-5) +// } +// } diff --git a/Sources/WhisperKit/MLX/MLXTokenSampler.swift b/Sources/WhisperKit/MLX/MLXTokenSampler.swift new file mode 100644 index 0000000..9ac1f03 --- /dev/null +++ b/Sources/WhisperKit/MLX/MLXTokenSampler.swift @@ -0,0 +1,186 @@ +// For licensing see accompanying LICENSE.md file. +// Copyright © 2024 Argmax, Inc. All rights reserved. + +import MLX +import MLXNN +import MLXRandom +import WhisperKit +import Foundation + +public protocol MLXTokenSampling { + func update(tokens: [Int], logits: MLXArray, logProbs: [Float]) -> SamplingResult + func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult +} + +public struct SamplingResult { + public var tokens: [Int] + public var logProbs: [Float] + public var completed: Bool + + public init(tokens: [Int], logProbs: [Float], completed: Bool) { + self.tokens = tokens + self.logProbs = logProbs + self.completed = completed + } +} + +open class MLXGreedyTokenSampler: MLXTokenSampling { + public var temperature: Float + public var eotToken: Int + public var decodingOptions: DecodingOptions + + public init(temperature: Float, eotToken: Int, decodingOptions: DecodingOptions) { + self.temperature = temperature + self.eotToken = eotToken + self.decodingOptions = decodingOptions + } + + public func update(tokens: [Int], logits: MLXArray, logProbs: [Float]) -> SamplingResult { + let startTime = CFAbsoluteTimeGetCurrent() + + print("Input shapes:") + print("logits shape:", logits.shape) + print("logits strides:", logits.strides) + +// let flattenStartTime = CFAbsoluteTimeGetCurrent() +// let logitArray = logits.flattened() +// let flattenEndTime = CFAbsoluteTimeGetCurrent() +// print("Flattening time: \(flattenEndTime - flattenStartTime) seconds") +// +// print("Flattened logits shape:", logitArray.shape) +// print("Flattened logits strides:", logitArray.strides) + + let scaleStartTime = CFAbsoluteTimeGetCurrent() + // Scale logits by temperature if > 0 + +// let scaledLogits = temperature != 0.0 ? logitArray / MLXArray(temperature) : logitArray + let scaleEndTime = CFAbsoluteTimeGetCurrent() + print("Scaling time: \(scaleEndTime - scaleStartTime) seconds") + + let softmaxStartTime = CFAbsoluteTimeGetCurrent() + // Apply softmax +// let probs = MLX.softmax(scaledLogits) + let probs: MLXArray + if temperature != 0.0 { + probs = softmax(logits / temperature, axis: -1) + } else { + probs = logits + } + +// let sortedIndices = argSort(probs, axis: -1) +// +// let sortedProbs = take(probs, sortedIndices, axis: -1).squeezed(axis: 0) +// ---- Transcription Timings ---- +// Audio Load: 0.00 ms / 1 runs ( 0.00 ms/run) 0.00% +// Audio Processing: 1.51 ms / 3 runs ( 0.50 ms/run) 0.03% +// Mels: 100.73 ms / 3 runs ( 33.58 ms/run) 2.19% +// Encoding: 420.60 ms / 3 runs ( 140.20 ms/run) 9.13% +// Matrices Init: 1.01 ms / 1 runs ( 1.01 ms/run) 0.02% +// Prefill: 0.05 ms / 1 runs ( 0.05 ms/run) 0.00% +// Decoding: 3790.05 ms / 248 runs ( 15.28 ms/run) 82.29% +// Non-inference: 231.55 ms / 248 runs ( 0.93 ms/run) 5.03% +// - Logit Filtering: 0.01 ms / 248 runs ( 0.00 ms/run) 0.00% +// - Sampling: 156.38 ms / 248 runs ( 0.63 ms/run) 3.40% +// - Kv Caching: 16.10 ms / 248 runs ( 0.06 ms/run) 0.35% +// - Word Timestamps: 0.00 ms / 0 runs ( 0.00 ms/run) 0.00% +// - Windowing: 1.31 ms / 3 runs ( 0.44 ms/run) 0.03% +// Fallbacks: 0.00 ms / 0 runs ( 0.00 ms/run) 0.00% +// Decoding Full Loop: 4604.25 ms / 248 runs ( 18.57 ms/run) 99.97% +// ------------------------------- +// Model Load Time: 0.54 seconds +// Inference Duration (Global): 4.61 seconds +// - Decoding Loop (Avg/window): 1.53 seconds +// - Audio Windows: 3.00 +// Time to first token: 0.28 seconds +// Total Tokens: 247 +// Tokens per Second: 53.85 tok/s +// Real Time Factor: 0.077 +// Fallbacks: 0.0 + let softmaxEndTime = CFAbsoluteTimeGetCurrent() + print("Softmax time: \(softmaxEndTime - softmaxStartTime) seconds") + // if temperature != 0.0 { + // // Top-k multinomial sampling + // let k = decodingOptions.topK + // let test = MLX.top([1, 2, 3], k: 2) + // let topKValues = MLX.argSort().top(probs, k: k) + // + // // Multinomial sample from top-k + // let sumOfTopKValues = topKValues.sum().item() + // let rnd = MLXRandom.uniform(Float.self, low: 0, high: sumOfTopKValues) + // let cumulativeProbs = MLX.cumsum(topKValues) + // let chosenIndex = MLX.argMax(cumulativeProbs .>= rnd).item() + // + // nextToken = topKIndices[chosenIndex].item() + // nextLogprob = MLX.log(topKValues[chosenIndex]).item() + // } else { + // Argmax sampling +// nextLogprob = probs.take(nextToken) + // } + var nextToken: MLXArray + var nextLogprob: MLXArray + +// nextToken = MLX.argMax(probs, axis: -1) + + + let samplingStartTime = CFAbsoluteTimeGetCurrent() + // Argmax sampling +// nextToken = compiledArgmax(probs) +// measure(noncompiledArgmax, probs) +// measure(compiledArgmax, probs) +// nextLogprob = probs.take(nextToken) + let token: Int = compiledArgmax(probs).item() + let logprob: Float = 0.05//nextLogprob.item() + let samplingEndTime = CFAbsoluteTimeGetCurrent() + print("Sampling time: \(samplingEndTime - samplingStartTime) seconds") + + let postProcessStartTime = CFAbsoluteTimeGetCurrent() + let nextTokens = tokens + [token] + let nextLogprobs: [Float] = logProbs + [logprob] + let completed = token == eotToken + let postProcessEndTime = CFAbsoluteTimeGetCurrent() + print("Post-processing time: \(postProcessEndTime - postProcessStartTime) seconds") + + let endTime = CFAbsoluteTimeGetCurrent() + print("Total update time: \(endTime - startTime) seconds") + + return SamplingResult(tokens: nextTokens, logProbs: nextLogprobs, completed: completed) + } + + + private let compiledArgmax: (MLXArray) -> MLXArray = compile { logits in + MLX.argMax(logits, axis: -1) + } + + private func noncompiledArgmax(_ logits: MLXArray) -> MLXArray { + return MLX.argMax(logits, axis: -1) + } + + func measure(_ f: (MLXArray) -> MLXArray, _ x: MLXArray) { + // warm up + for _ in 0..<10 { + eval(f(x)) + } + + let start = Date.timeIntervalSinceReferenceDate + let iterations = 100 + for _ in 0.. SamplingResult { + var finalTokens = tokens + var finalLogProbs = logProbs + if tokens.last != eotToken { + finalTokens.append(eotToken) + finalLogProbs.append(0) + } + + return SamplingResult(tokens: finalTokens, logProbs: finalLogProbs, completed: true) + } +} diff --git a/Sources/WhisperKit/MLX/MLXUtils.swift b/Sources/WhisperKit/MLX/MLXUtils.swift index d6af28b..ea3226a 100644 --- a/Sources/WhisperKit/MLX/MLXUtils.swift +++ b/Sources/WhisperKit/MLX/MLXUtils.swift @@ -36,37 +36,43 @@ extension MLXArray { } } -extension MLXArray { - var contiguousStrides: [Int] { - var contiguousStrides = [1] - var stride = 1 - for dimension in shape.dropFirst().reversed() { - stride = stride * dimension - contiguousStrides.append(stride) - } - contiguousStrides.reverse() - return contiguousStrides - } -} +//extension MLXArray { +// var contiguousStrides: [Int] { +// var contiguousStrides = [1] +// var stride = 1 +// for dimension in shape.dropFirst().reversed() { +// stride = stride * dimension +// contiguousStrides.append(stride) +// } +// contiguousStrides.reverse() +// return contiguousStrides +// } +//} extension MLXArray { func asMLMultiArray() throws -> MLMultiArray { + let dataType = multiArrayDataType() // a buffer to be passed to CoreML let buffer = UnsafeMutableRawPointer.allocate(byteCount: nbytes, alignment: 8) // copy the data from the MLXArray backing into buffer - asData(noCopy: true).withUnsafeBytes { ptr in + let dataStartTime = CFAbsoluteTimeGetCurrent() + asData(access: .noCopy).data.withUnsafeBytes { ptr in let destination = UnsafeMutableRawBufferPointer(start: buffer, count: nbytes) ptr.copyBytes(to: destination) } // `contiguousStrides` has to used, see the [discussion](https://github.com/ml-explore/mlx-swift/issues/117) - return try MLMultiArray( + let time = Date() + let outputArray = try MLMultiArray( dataPointer: buffer, shape: shape.map { NSNumber(value: $0) }, dataType: dataType, - strides: contiguousStrides.map { NSNumber(value: $0) }, + strides: strides.map { NSNumber(value: $0) }, deallocator: { $0.deallocate() } ) + Logging.debug("Time to convert to multi array: \(Date().timeIntervalSince(time))") + + return outputArray } } diff --git a/Tests/WhisperKitMLXTests/MLXUnitTests.swift b/Tests/WhisperKitMLXTests/MLXUnitTests.swift index c51a03e..94bad93 100644 --- a/Tests/WhisperKitMLXTests/MLXUnitTests.swift +++ b/Tests/WhisperKitMLXTests/MLXUnitTests.swift @@ -71,7 +71,7 @@ final class MLXUnitTests: XCTestCase { "Failed to load the tokenizer" ) - let tokenSampler = GreedyTokenSampler( + let tokenSampler = MLXGreedyTokenSampler( temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions @@ -106,7 +106,7 @@ final class MLXUnitTests: XCTestCase { ) textDecoder.tokenizer = try await loadTokenizer(for: .tiny) - let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions) + let tokenSampler = MLXGreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions) let encoderInput = initMLMultiArray(shape: [1, 384, 1, 1500], dataType: .float16, initialValue: FloatType(0)) let inputs = try textDecoder.prepareDecoderInputs(withPrompt: [textDecoder.tokenizer!.specialTokens.startOfTranscriptToken]) @@ -133,7 +133,7 @@ final class MLXUnitTests: XCTestCase { ) textDecoder.tokenizer = try await loadTokenizer(for: .tiny) - let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions) + let tokenSampler = MLXGreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions) let encoderInput = initMLMultiArray(shape: [1, 384, 1, 1500], dataType: .float16, initialValue: FloatType(0)) let inputs = try textDecoder.prepareDecoderInputs(withPrompt: [textDecoder.tokenizer!.specialTokens.startOfTranscriptToken]) diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index 788342e..3a9bacd 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -237,7 +237,7 @@ final class UnitTests: XCTestCase { "Failed to load the tokenizer" ) - let tokenSampler = GreedyTokenSampler( + let tokenSampler = MLXGreedyTokenSampler( temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions @@ -269,7 +269,7 @@ final class UnitTests: XCTestCase { try await textDecoder.loadModel(at: modelPath, computeUnits: ModelComputeOptions().textDecoderCompute) textDecoder.tokenizer = try await loadTokenizer(for: .tiny) - let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions) + let tokenSampler = MLXGreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions) let encoderInput = initMLMultiArray(shape: [1, 384, 1, 1500], dataType: .float16, initialValue: FloatType(0)) let inputs = try textDecoder.prepareDecoderInputs(withPrompt: [textDecoder.tokenizer!.specialTokens.startOfTranscriptToken]) @@ -293,7 +293,7 @@ final class UnitTests: XCTestCase { try await textDecoder.loadModel(at: modelPath, computeUnits: ModelComputeOptions().textDecoderCompute) textDecoder.tokenizer = try await loadTokenizer(for: .tiny) - let tokenSampler = GreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions) + let tokenSampler = MLXGreedyTokenSampler(temperature: 0, eotToken: textDecoder.tokenizer!.specialTokens.endToken, decodingOptions: decodingOptions) let encoderInput = initMLMultiArray(shape: [1, 384, 1, 1500], dataType: .float16, initialValue: FloatType(0)) let inputs = try textDecoder.prepareDecoderInputs(withPrompt: [textDecoder.tokenizer!.specialTokens.startOfTranscriptToken]) From 76d14db03e2cd435d39c3abcc307cd20f2176cc9 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 7 Sep 2024 00:16:57 -0700 Subject: [PATCH 24/29] Restructure package.swift --- Package.resolved | 18 -------- Package.swift | 109 ++++++++++++++++------------------------------- 2 files changed, 37 insertions(+), 90 deletions(-) diff --git a/Package.resolved b/Package.resolved index bb2ef99..6cccf25 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,14 +1,5 @@ { "pins" : [ - { - "identity" : "mlx-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/ml-explore/mlx-swift", - "state" : { - "revision" : "597aaa5f465b4b9a17c8646b751053f84e37925b", - "version" : "0.16.0" - } - }, { "identity" : "swift-argument-parser", "kind" : "remoteSourceControl", @@ -18,15 +9,6 @@ "version" : "1.3.0" } }, - { - "identity" : "swift-numerics", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-numerics", - "state" : { - "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", - "version" : "1.0.2" - } - }, { "identity" : "swift-transformers", "kind" : "remoteSourceControl", diff --git a/Package.swift b/Package.swift index 47a24d6..8084927 100644 --- a/Package.swift +++ b/Package.swift @@ -1,28 +1,37 @@ // swift-tools-version: 5.9 // The swift-tools-version declares the minimum version of Swift required to build this package. -import PackageDescription import Foundation +import PackageDescription let package = Package( name: "whisperkit", platforms: [ .iOS(.v16), .macOS("13.3"), - .watchOS(.v10) + .watchOS(.v10), ], products: [ .library( name: "WhisperKit", targets: ["WhisperKit"] ), - ] - + cliProducts() - + mlxProducts(), + ] + (!isMLXDisabled() ? [ + .executable( + name: "whisperkit-cli", + targets: ["WhisperKitCLI"] + ), + .library( + name: "WhisperKitMLX", + targets: ["WhisperKitMLX"] + ), + ] : []), dependencies: [ .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"), .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), - ] + mlxDependencies(), + ] + (!isMLXDisabled() ? [ + .package(url: "https://github.com/davidkoski/mlx-swift.git", revision: "3314bc684f0ccab1793be54acddaea16c0501d3c"), + ] : []), targets: [ .target( name: "WhisperKit", @@ -31,6 +40,14 @@ let package = Package( ], path: "Sources/WhisperKit/Core" ), + .testTarget( + name: "WhisperKitTests", + dependencies: [ + "WhisperKit", + "WhisperKitTestsUtils", + .product(name: "Transformers", package: "swift-transformers"), + ] + ), .target( name: "WhisperKitTestsUtils", dependencies: [ @@ -54,79 +71,19 @@ let package = Package( .process("Sources/WhisperKitTestsUtils/Resources") ] ), - .testTarget( - name: "WhisperKitTests", - dependencies: [ - "WhisperKit", - "WhisperKitTestsUtils", - .product(name: "Transformers", package: "swift-transformers"), - ] - ) - ] - + cliTargets() - + mlxTargets() -) - -// MARK: - MLX Helper Functions - -// CLI -func cliProducts() -> [Product] { - guard !isMLXDisabled() else { return [] } - return [ - .executable( - name: "whisperkit-cli", - targets: ["WhisperKitCLI"] - ), - ] -} - -func cliTargets() -> [Target] { - guard !isMLXDisabled() else { return [] } - return [ - .executableTarget( - name: "WhisperKitCLI", - dependencies: [ - "WhisperKit", - "WhisperKitMLX", - .product(name: "ArgumentParser", package: "swift-argument-parser"), - ] - ), - ] -} - -// MLX -func mlxProducts() -> [Product] { - guard !isMLXDisabled() else { return [] } - return [ - .library( - name: "WhisperKitMLX", - targets: ["WhisperKitMLX"] - ), - ] -} - -func mlxDependencies() -> [Package.Dependency] { - guard !isMLXDisabled() else { return [] } - return [ - .package(url: "https://github.com/davidkoski/mlx-swift", revision: "3314bc684f0ccab1793be54acddaea16c0501d3c"), - ] -} - -func mlxTargets() -> [Target] { - guard !isMLXDisabled() else { return [] } - return [ + ] + (!isMLXDisabled() ? [ .target( name: "WhisperKitMLX", dependencies: [ "WhisperKit", .product(name: "MLX", package: "mlx-swift"), .product(name: "MLXFFT", package: "mlx-swift"), - .product(name: "MLXNN", package: "mlx-swift") + .product(name: "MLXNN", package: "mlx-swift"), ], path: "Sources/WhisperKit/MLX", resources: [ .copy("Resources/mel_filters_80.npy"), - .copy("Resources/mel_filters_128.npy") + .copy("Resources/mel_filters_128.npy"), ] ), .testTarget( @@ -137,9 +94,17 @@ func mlxTargets() -> [Target] { "WhisperKitTestsUtils", .product(name: "Transformers", package: "swift-transformers"), ] - ) - ] -} + ), + .executableTarget( + name: "WhisperKitCLI", + dependencies: [ + "WhisperKit", + "WhisperKitMLX", + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ] + ), + ] : []) +) // NOTE: `MLX` doesn't support `watchOS` yet, that's why we control the build using the `MLX_DISABLED` environment variable. // To manualy build for `watchOS` use: From 292cb1660985b96028623edfff4748eeb025f958 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 7 Sep 2024 00:18:06 -0700 Subject: [PATCH 25/29] Complete MLXTokenSampling impl --- .../WhisperKit/MLX/MLXFeatureExtractor.swift | 2 +- Sources/WhisperKit/MLX/MLXModels.swift | 5 +- Sources/WhisperKit/MLX/MLXTextDecoder.swift | 21 ++- Sources/WhisperKit/MLX/MLXTokenSampler.swift | 153 ++++-------------- 4 files changed, 47 insertions(+), 134 deletions(-) diff --git a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift index fe95078..453798e 100644 --- a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift +++ b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift @@ -78,7 +78,7 @@ public extension MLXFeatureExtractor { return MLX.concatenated([prefix, x, suffix]) } } - + static func stft( _ x: MLXArray, window: MLXArray, diff --git a/Sources/WhisperKit/MLX/MLXModels.swift b/Sources/WhisperKit/MLX/MLXModels.swift index e29b103..50196ab 100644 --- a/Sources/WhisperKit/MLX/MLXModels.swift +++ b/Sources/WhisperKit/MLX/MLXModels.swift @@ -31,11 +31,11 @@ public struct KV { public struct MLXDecodingCache { public var kvCache: [KV] - public var alignmentWeights: MLXArray? + public var alignmentWeights: [MLXArray?] public init( kvCache: [KV], - alignmentWeights: MLXArray? + alignmentWeights: [MLXArray?] ) { self.kvCache = kvCache self.alignmentWeights = alignmentWeights @@ -45,6 +45,7 @@ public struct MLXDecodingCache { struct TextDecoderResult { var logits: MLXArray var kvCache: [KV] + var alignmentWeights: [MLXArray?] } struct ResidualAttentionBlockResult { diff --git a/Sources/WhisperKit/MLX/MLXTextDecoder.swift b/Sources/WhisperKit/MLX/MLXTextDecoder.swift index 63201d9..736f3dc 100644 --- a/Sources/WhisperKit/MLX/MLXTextDecoder.swift +++ b/Sources/WhisperKit/MLX/MLXTextDecoder.swift @@ -140,7 +140,7 @@ public final class MLXTextDecoder: TextDecoding { let decodingCache = MLXDecodingCache( kvCache: result.kvCache, - alignmentWeights: nil + alignmentWeights: result.alignmentWeights ) Logging.debug("Time to cache time: \(Date().timeIntervalSince(time2))") @@ -320,7 +320,8 @@ public final class MLXTextDecoder: TextDecoding { decoderInputs.kvCacheUpdateMask[tokenIndex + 1] = 1 // Update alignment weights for token if present -// if let newAlignmentWeights = decoderOutput.cache?.alignmentWeights { + // TODO: use correct alignment heads +// if let newAlignmentWeights = try decoderOutput.cache?.alignmentWeights { // hasAlignment = true // for column in 0.. MLXArray { -// MLXFast.layerNorm(x, weight: weight, bias: bias, eps: 1e-5) -// } -// } +class FastLayerNorm: LayerNorm { + override func callAsFunction(_ x: MLXArray) -> MLXArray { + MLXFast.layerNorm(x, weight: weight, bias: bias, eps: 1e-5) + } +} diff --git a/Sources/WhisperKit/MLX/MLXTokenSampler.swift b/Sources/WhisperKit/MLX/MLXTokenSampler.swift index 9ac1f03..572e569 100644 --- a/Sources/WhisperKit/MLX/MLXTokenSampler.swift +++ b/Sources/WhisperKit/MLX/MLXTokenSampler.swift @@ -36,141 +36,48 @@ open class MLXGreedyTokenSampler: MLXTokenSampling { } public func update(tokens: [Int], logits: MLXArray, logProbs: [Float]) -> SamplingResult { - let startTime = CFAbsoluteTimeGetCurrent() - - print("Input shapes:") - print("logits shape:", logits.shape) - print("logits strides:", logits.strides) - -// let flattenStartTime = CFAbsoluteTimeGetCurrent() -// let logitArray = logits.flattened() -// let flattenEndTime = CFAbsoluteTimeGetCurrent() -// print("Flattening time: \(flattenEndTime - flattenStartTime) seconds") -// -// print("Flattened logits shape:", logitArray.shape) -// print("Flattened logits strides:", logitArray.strides) - - let scaleStartTime = CFAbsoluteTimeGetCurrent() + let softmaxOutput: MLXArray + var sampledToken: MLXArray + // Scale logits by temperature if > 0 + let scaledLogits = temperature != 0.0 ? logits / MLXArray(temperature) : logits -// let scaledLogits = temperature != 0.0 ? logitArray / MLXArray(temperature) : logitArray - let scaleEndTime = CFAbsoluteTimeGetCurrent() - print("Scaling time: \(scaleEndTime - scaleStartTime) seconds") + // Always apply softmax + softmaxOutput = softmax(scaledLogits, axis: -1) - let softmaxStartTime = CFAbsoluteTimeGetCurrent() - // Apply softmax -// let probs = MLX.softmax(scaledLogits) - let probs: MLXArray if temperature != 0.0 { - probs = softmax(logits / temperature, axis: -1) - } else { - probs = logits - } - -// let sortedIndices = argSort(probs, axis: -1) -// -// let sortedProbs = take(probs, sortedIndices, axis: -1).squeezed(axis: 0) -// ---- Transcription Timings ---- -// Audio Load: 0.00 ms / 1 runs ( 0.00 ms/run) 0.00% -// Audio Processing: 1.51 ms / 3 runs ( 0.50 ms/run) 0.03% -// Mels: 100.73 ms / 3 runs ( 33.58 ms/run) 2.19% -// Encoding: 420.60 ms / 3 runs ( 140.20 ms/run) 9.13% -// Matrices Init: 1.01 ms / 1 runs ( 1.01 ms/run) 0.02% -// Prefill: 0.05 ms / 1 runs ( 0.05 ms/run) 0.00% -// Decoding: 3790.05 ms / 248 runs ( 15.28 ms/run) 82.29% -// Non-inference: 231.55 ms / 248 runs ( 0.93 ms/run) 5.03% -// - Logit Filtering: 0.01 ms / 248 runs ( 0.00 ms/run) 0.00% -// - Sampling: 156.38 ms / 248 runs ( 0.63 ms/run) 3.40% -// - Kv Caching: 16.10 ms / 248 runs ( 0.06 ms/run) 0.35% -// - Word Timestamps: 0.00 ms / 0 runs ( 0.00 ms/run) 0.00% -// - Windowing: 1.31 ms / 3 runs ( 0.44 ms/run) 0.03% -// Fallbacks: 0.00 ms / 0 runs ( 0.00 ms/run) 0.00% -// Decoding Full Loop: 4604.25 ms / 248 runs ( 18.57 ms/run) 99.97% -// ------------------------------- -// Model Load Time: 0.54 seconds -// Inference Duration (Global): 4.61 seconds -// - Decoding Loop (Avg/window): 1.53 seconds -// - Audio Windows: 3.00 -// Time to first token: 0.28 seconds -// Total Tokens: 247 -// Tokens per Second: 53.85 tok/s -// Real Time Factor: 0.077 -// Fallbacks: 0.0 - let softmaxEndTime = CFAbsoluteTimeGetCurrent() - print("Softmax time: \(softmaxEndTime - softmaxStartTime) seconds") - // if temperature != 0.0 { - // // Top-k multinomial sampling - // let k = decodingOptions.topK - // let test = MLX.top([1, 2, 3], k: 2) - // let topKValues = MLX.argSort().top(probs, k: k) - // - // // Multinomial sample from top-k - // let sumOfTopKValues = topKValues.sum().item() - // let rnd = MLXRandom.uniform(Float.self, low: 0, high: sumOfTopKValues) - // let cumulativeProbs = MLX.cumsum(topKValues) - // let chosenIndex = MLX.argMax(cumulativeProbs .>= rnd).item() - // - // nextToken = topKIndices[chosenIndex].item() - // nextLogprob = MLX.log(topKValues[chosenIndex]).item() - // } else { - // Argmax sampling -// nextLogprob = probs.take(nextToken) - // } - var nextToken: MLXArray - var nextLogprob: MLXArray - -// nextToken = MLX.argMax(probs, axis: -1) - - - let samplingStartTime = CFAbsoluteTimeGetCurrent() - // Argmax sampling -// nextToken = compiledArgmax(probs) -// measure(noncompiledArgmax, probs) -// measure(compiledArgmax, probs) -// nextLogprob = probs.take(nextToken) - let token: Int = compiledArgmax(probs).item() - let logprob: Float = 0.05//nextLogprob.item() - let samplingEndTime = CFAbsoluteTimeGetCurrent() - print("Sampling time: \(samplingEndTime - samplingStartTime) seconds") - - let postProcessStartTime = CFAbsoluteTimeGetCurrent() - let nextTokens = tokens + [token] - let nextLogprobs: [Float] = logProbs + [logprob] - let completed = token == eotToken - let postProcessEndTime = CFAbsoluteTimeGetCurrent() - print("Post-processing time: \(postProcessEndTime - postProcessStartTime) seconds") - - let endTime = CFAbsoluteTimeGetCurrent() - print("Total update time: \(endTime - startTime) seconds") - - return SamplingResult(tokens: nextTokens, logProbs: nextLogprobs, completed: completed) - } + // Top-k multinomial sampling + let sortedIndices = MLX.argSort(softmaxOutput, axis: -1) + // Implement top-k selection (argSort is ascending) + let topKIndices = MLXArray(-decodingOptions.topK ..< 0) + let sortedProbs = take(softmaxOutput, sortedIndices, axis: -1) + let bestValues = sortedProbs.take(topKIndices, axis: -1) + let bestIndices = sortedIndices.take(topKIndices, axis: -1) - private let compiledArgmax: (MLXArray) -> MLXArray = compile { logits in - MLX.argMax(logits, axis: -1) - } + // multinomial sample from top-k + let sumOfbestIndicesResult = bestValues.sum() + let rnd = MLXRandom.uniform(low: 0.0, high: sumOfbestIndicesResult) + let cumulativeProbs = cumsum(bestValues, axis: -1) - private func noncompiledArgmax(_ logits: MLXArray) -> MLXArray { - return MLX.argMax(logits, axis: -1) - } + let chosenIndex = MLX.argMax(cumulativeProbs .>= rnd) - func measure(_ f: (MLXArray) -> MLXArray, _ x: MLXArray) { - // warm up - for _ in 0..<10 { - eval(f(x)) + sampledToken = bestIndices.take(chosenIndex) + } else { + // Argmax sampling + sampledToken = MLX.argMax(softmaxOutput, axis: -1) } + let nextToken = sampledToken.item(Int.self) - let start = Date.timeIntervalSinceReferenceDate - let iterations = 100 - for _ in 0.. SamplingResult { From 649d139e873140356a05bb2d92bbd04877d9e64e Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Sat, 7 Sep 2024 02:15:50 -0700 Subject: [PATCH 26/29] Fix tests --- Package.resolved | 17 +++ Sources/WhisperKit/Core/TextDecoder.swift | 128 +++++++++---------- Sources/WhisperKit/MLX/MLXTextDecoder.swift | 92 ++++++++----- Sources/WhisperKit/MLX/MLXTokenSampler.swift | 24 +--- Sources/WhisperKit/MLX/MLXUtils.swift | 5 +- Sources/WhisperKitCLI/CLIArguments.swift | 9 +- Sources/WhisperKitCLI/TranscribeCLI.swift | 4 +- Tests/WhisperKitMLXTests/MLXUnitTests.swift | 16 +-- Tests/WhisperKitTests/UnitTests.swift | 6 +- 9 files changed, 167 insertions(+), 134 deletions(-) diff --git a/Package.resolved b/Package.resolved index 6cccf25..241c17f 100644 --- a/Package.resolved +++ b/Package.resolved @@ -1,5 +1,13 @@ { "pins" : [ + { + "identity" : "mlx-swift", + "kind" : "remoteSourceControl", + "location" : "https://github.com/davidkoski/mlx-swift.git", + "state" : { + "revision" : "3314bc684f0ccab1793be54acddaea16c0501d3c" + } + }, { "identity" : "swift-argument-parser", "kind" : "remoteSourceControl", @@ -9,6 +17,15 @@ "version" : "1.3.0" } }, + { + "identity" : "swift-numerics", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-numerics", + "state" : { + "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", + "version" : "1.0.2" + } + }, { "identity" : "swift-transformers", "kind" : "remoteSourceControl", diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index 032e04e..14575fc 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -21,15 +21,15 @@ public protocol TextDecoding { withPrompt initialPrompt: [Int] ) throws -> DecodingInputs -// func predictLogits( -// inputIds: MLMultiArray, -// cacheLength: MLMultiArray, -// keyCache: MLMultiArray?, -// valueCache: MLMultiArray?, -// kvCacheUpdateMask: MLMultiArray, -// encoderOutputEmbeds: MLMultiArray, -// decoderKeyPaddingMask: MLMultiArray -// ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? + func predictLogits( + inputIds: MLMultiArray, + cacheLength: MLMultiArray, + keyCache: MLMultiArray?, + valueCache: MLMultiArray?, + kvCacheUpdateMask: MLMultiArray, + encoderOutputEmbeds: MLMultiArray, + decoderKeyPaddingMask: MLMultiArray + ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? func prefillKVCache( withTask task: MLMultiArray, @@ -376,63 +376,61 @@ public extension TextDecoding { let inferenceTime = Date() Logging.debug("Detecting language...") -// let predictedLogits = try await textDecoder.predictLogits( -// inputIds: decoderInputs.inputIds, -// cacheLength: decoderInputs.cacheLength, -// keyCache: decoderInputs.keyCache, -// valueCache: decoderInputs.valueCache, -// kvCacheUpdateMask: decoderInputs.kvCacheUpdateMask, -// encoderOutputEmbeds: encoderOutput, -// decoderKeyPaddingMask: decoderInputs.decoderKeyPaddingMask -// ) -// -// guard let decoderOutput = predictedLogits else { -// Logging.error("Unable to decode logits") -// throw WhisperError.decodingLogitsFailed() -// } -// -// let decodingInferenceTime = Date().timeIntervalSince(inferenceTime) -// timings.decodingPredictions += decodingInferenceTime -// -// // MARK: Non-inference -// -// // Update predicted token as current -// let logits = languageLogitsFilter.filterLogits(decoderOutput.logits!, withTokens: currentTokens) -// -// // MARK: Sampling -// -// let samplingStartTime = Date() -// -// let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) -// -// nextToken = sampleResult.tokens.last! -// logProbs = sampleResult.logProbs -// -// let samplingTime = Date().timeIntervalSince(samplingStartTime) -// timings.decodingSampling += samplingTime -// -// var languageProbs = [String: Float]() -// for (tokenIndex, token) in sampleResult.tokens.enumerated() { -// if tokenizer.allLanguageTokens.contains(token) { -// let language = tokenizer.decode(tokens: [token]).trimmingSpecialTokenCharacters() -// languageProbs[language] = sampleResult.logProbs[tokenIndex] -// } -// } -// -// let sampledLanguage = tokenizer.decode(tokens: [nextToken]).trimmingSpecialTokenCharacters() -// let detectedLanguage: String -// if Constants.languageCodes.contains(sampledLanguage) { -// detectedLanguage = sampledLanguage -// Logging.debug("Detected language: \(sampledLanguage)") -// } else { -// detectedLanguage = Constants.defaultLanguageCode -// Logging.error("Detected language \(sampledLanguage) is not supported, defaulting to \(Constants.defaultLanguageCode)") -// } + let predictedLogits = try await textDecoder.predictLogits( + inputIds: decoderInputs.inputIds, + cacheLength: decoderInputs.cacheLength, + keyCache: decoderInputs.keyCache, + valueCache: decoderInputs.valueCache, + kvCacheUpdateMask: decoderInputs.kvCacheUpdateMask, + encoderOutputEmbeds: encoderOutput, + decoderKeyPaddingMask: decoderInputs.decoderKeyPaddingMask + ) + + guard let decoderOutput = predictedLogits else { + Logging.error("Unable to decode logits") + throw WhisperError.decodingLogitsFailed() + } + + let decodingInferenceTime = Date().timeIntervalSince(inferenceTime) + timings.decodingPredictions += decodingInferenceTime + + // MARK: Non-inference + + // Update predicted token as current + let logits = languageLogitsFilter.filterLogits(decoderOutput.logits!, withTokens: currentTokens) + + // MARK: Sampling + + let samplingStartTime = Date() + + let sampleResult = tokenSampler.update(tokens: currentTokens, logits: logits, logProbs: logProbs) + + nextToken = sampleResult.tokens.last! + logProbs = sampleResult.logProbs + + let samplingTime = Date().timeIntervalSince(samplingStartTime) + timings.decodingSampling += samplingTime + + var languageProbs = [String: Float]() + for (tokenIndex, token) in sampleResult.tokens.enumerated() { + if tokenizer.allLanguageTokens.contains(token) { + let language = tokenizer.decode(tokens: [token]).trimmingSpecialTokenCharacters() + languageProbs[language] = sampleResult.logProbs[tokenIndex] + } + } + + let sampledLanguage = tokenizer.decode(tokens: [nextToken]).trimmingSpecialTokenCharacters() + let detectedLanguage: String + if Constants.languageCodes.contains(sampledLanguage) { + detectedLanguage = sampledLanguage + Logging.debug("Detected language: \(sampledLanguage)") + } else { + detectedLanguage = Constants.defaultLanguageCode + Logging.error("Detected language \(sampledLanguage) is not supported, defaulting to \(Constants.defaultLanguageCode)") + } return DecodingResult( -// language: detectedLanguage, -// languageProbs: languageProbs, - language: Constants.defaultLanguageCode, - languageProbs: [:], + language: detectedLanguage, + languageProbs: languageProbs, tokens: [], tokenLogProbs: [], text: "", diff --git a/Sources/WhisperKit/MLX/MLXTextDecoder.swift b/Sources/WhisperKit/MLX/MLXTextDecoder.swift index 736f3dc..60f5478 100644 --- a/Sources/WhisperKit/MLX/MLXTextDecoder.swift +++ b/Sources/WhisperKit/MLX/MLXTextDecoder.swift @@ -8,7 +8,7 @@ import WhisperKit import MLXFast @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) -public final class MLXTextDecoder: TextDecoding { +public final class MLXTextDecoder: TextDecoding { public var model: TextDecoderModule? public var tokenizer: (any WhisperTokenizer)? public var prefillData: (any WhisperMLModel)? @@ -43,6 +43,16 @@ public final class MLXTextDecoder: TextDecoding { public init() {} + private static func toKvCache(keyCache: MLMultiArray?, valueCache: MLMultiArray?) -> [KV]? { + guard let keyCache, let valueCache else { + return nil + } + let keyCacheMlx = keyCache.asMLXArray(FloatType.self) + let valueCacheMlx = valueCache.asMLXArray(FloatType.self) + + return toKvCache(keyCache: keyCacheMlx, valueCache: valueCacheMlx) + } + private static func toKvCache(keyCache: MLXArray?, valueCache: MLXArray?) -> [KV]? { guard let keyCache, let valueCache else { return nil @@ -108,43 +118,68 @@ public final class MLXTextDecoder: TextDecoding { return decoderInputs } + public func predictLogits( + inputIds: MLMultiArray, + cacheLength: MLMultiArray, + keyCache: MLMultiArray?, + valueCache: MLMultiArray?, + kvCacheUpdateMask: MLMultiArray, + encoderOutputEmbeds: MLMultiArray, + decoderKeyPaddingMask: MLMultiArray + ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? { + let result = try await predictLogits( + inputIds: inputIds, + cacheLength: cacheLength, + keyCache: keyCache?.asMLXArray(FloatType.self), + valueCache: valueCache?.asMLXArray(FloatType.self), + kvCacheUpdateMask: kvCacheUpdateMask, + encoderOutputEmbeds: encoderOutputEmbeds.asMLXArray(FloatType.self).asMLXInput(), + decoderKeyPaddingMask: decoderKeyPaddingMask + ) + + guard let result = result, + let keyCacheResult = result.cache?.kvCache.map(\.k), + let valueCacheResult = result.cache?.kvCache.map(\.v) + else { return nil } + + let keyCache = try? MLX.stacked(keyCacheResult).asMLMultiArray() + let valueCache = try? MLX.stacked(valueCacheResult).asMLMultiArray() + let decodingCache = DecodingCache( + keyCache: keyCache, + valueCache: valueCache, + alignmentWeights: nil + ) + + let logits = try? result.logits?.asMLMultiArray() + + return (logits, decodingCache) + } + public func predictLogits( inputIds: MLMultiArray, cacheLength: MLMultiArray, keyCache: MLXArray?, valueCache: MLXArray?, kvCacheUpdateMask: MLMultiArray, - encoderOutputEmbeds: MLMultiArray, + encoderOutputEmbeds: MLXArray, decoderKeyPaddingMask: MLMultiArray ) async throws -> (logits: MLXArray?, cache: MLXDecodingCache?)? { - let time3 = Date() - guard let model else { return nil } let tokens = inputIds.asMLXArray(Int32.self) - let audioFeatures = encoderOutputEmbeds.asMLXArray(FloatType.self).asMLXInput() - Logging.debug("Time to prepare input time: \(Date().timeIntervalSince(time3))") - - let time = Date() let result = model( tokens, - xa: audioFeatures, + xa: encoderOutputEmbeds, kvCache: Self.toKvCache(keyCache: keyCache, valueCache: valueCache) ) - MLX.eval(result.logits) - - Logging.debug("Time to Inference time: \(Date().timeIntervalSince(time))") - - let time2 = Date() let decodingCache = MLXDecodingCache( kvCache: result.kvCache, alignmentWeights: result.alignmentWeights ) - Logging.debug("Time to cache time: \(Date().timeIntervalSince(time2))") - return try (result.logits, decodingCache) + return (result.logits, decodingCache) } public func decodeText( @@ -161,7 +196,6 @@ public final class MLXTextDecoder: TextDecoding { let tokenSampler = MLXGreedyTokenSampler(temperature: Float(options.temperature), eotToken: tokenizer.specialTokens.endToken, decodingOptions: options) - // Single loop variables var timings = TranscriptionTimings() let prefilledIndex = decoderInputs.cacheLength[0].intValue @@ -210,6 +244,7 @@ public final class MLXTextDecoder: TextDecoding { var isFirstTokenLogProbTooLow = false var keyCache = decoderInputs.keyCache?.asMLXArray(FloatType.self) var valueCache = decoderInputs.valueCache?.asMLXArray(FloatType.self) + let encoderOutput = encoderOutput.asMLXArray(FloatType.self).asMLXInput() for tokenIndex in prefilledIndex.. SamplingResult - func finalize(tokens: [Int], logProbs: [Float]) -> SamplingResult -} - -public struct SamplingResult { - public var tokens: [Int] - public var logProbs: [Float] - public var completed: Bool - - public init(tokens: [Int], logProbs: [Float], completed: Bool) { - self.tokens = tokens - self.logProbs = logProbs - self.completed = completed - } -} - -open class MLXGreedyTokenSampler: MLXTokenSampling { +open class MLXGreedyTokenSampler: TokenSampling { public var temperature: Float public var eotToken: Int public var decodingOptions: DecodingOptions @@ -35,6 +19,10 @@ open class MLXGreedyTokenSampler: MLXTokenSampling { self.decodingOptions = decodingOptions } + public func update(tokens: [Int], logits: MLMultiArray, logProbs: [Float]) -> SamplingResult { + return update(tokens: tokens, logits: logits.asMLXArray(FloatType.self), logProbs: logProbs) + } + public func update(tokens: [Int], logits: MLXArray, logProbs: [Float]) -> SamplingResult { let softmaxOutput: MLXArray var sampledToken: MLXArray diff --git a/Sources/WhisperKit/MLX/MLXUtils.swift b/Sources/WhisperKit/MLX/MLXUtils.swift index ea3226a..1d163cb 100644 --- a/Sources/WhisperKit/MLX/MLXUtils.swift +++ b/Sources/WhisperKit/MLX/MLXUtils.swift @@ -51,7 +51,6 @@ extension MLXArray { extension MLXArray { func asMLMultiArray() throws -> MLMultiArray { - let dataType = multiArrayDataType() // a buffer to be passed to CoreML let buffer = UnsafeMutableRawPointer.allocate(byteCount: nbytes, alignment: 8) @@ -61,7 +60,6 @@ extension MLXArray { let destination = UnsafeMutableRawBufferPointer(start: buffer, count: nbytes) ptr.copyBytes(to: destination) } - // `contiguousStrides` has to used, see the [discussion](https://github.com/ml-explore/mlx-swift/issues/117) let time = Date() let outputArray = try MLMultiArray( dataPointer: buffer, @@ -70,13 +68,12 @@ extension MLXArray { strides: strides.map { NSNumber(value: $0) }, deallocator: { $0.deallocate() } ) - Logging.debug("Time to convert to multi array: \(Date().timeIntervalSince(time))") return outputArray } } -extension MLXArray { +public extension MLXArray { func multiArrayDataType() -> MLMultiArrayDataType { switch dtype { case .bool, .bfloat16, .complex64, diff --git a/Sources/WhisperKitCLI/CLIArguments.swift b/Sources/WhisperKitCLI/CLIArguments.swift index a948ed2..28ce0e4 100644 --- a/Sources/WhisperKitCLI/CLIArguments.swift +++ b/Sources/WhisperKitCLI/CLIArguments.swift @@ -2,9 +2,10 @@ // Copyright © 2024 Argmax, Inc. All rights reserved. import ArgumentParser +import WhisperKit import WhisperKitMLX -extension ModelType: ExpressibleByArgument { +extension ModelEngine: ExpressibleByArgument { public init?(argument: String) { self.init(rawValue: argument.lowercased()) } @@ -42,13 +43,13 @@ struct CLIArguments: ParsableArguments { var downloadTokenizerPath: String? @Option(help: "Which feature extractor to use (supported: `coreml` and `mlx`)") - var featureExtractorType: ModelType = .coreML + var featureExtractorType: ModelEngine = .coreML @Option(help: "Which audio encoder to use (supported: `coreml` and `mlx`)") - var audioEncoderType: ModelType = .coreML + var audioEncoderType: ModelEngine = .coreML @Option(help: "Which text decoder to use (supported: `coreml` and `mlx`)") - var textDecoderType: ModelType = .coreML + var textDecoderType: ModelEngine = .coreML @Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}") var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine diff --git a/Sources/WhisperKitCLI/TranscribeCLI.swift b/Sources/WhisperKitCLI/TranscribeCLI.swift index cae3c33..99ae4ff 100644 --- a/Sources/WhisperKitCLI/TranscribeCLI.swift +++ b/Sources/WhisperKitCLI/TranscribeCLI.swift @@ -313,8 +313,8 @@ struct TranscribeCLI: AsyncParsableCommand { } var featureExtractorType = cliArguments.featureExtractorType - var audioEncoderType = cliArguments.featureExtractorType - var textDecoderType = cliArguments.featureExtractorType + var audioEncoderType = cliArguments.audioEncoderType + var textDecoderType = cliArguments.textDecoderType if modelName == nil, mlxModelName != nil { // CoreML model not provided, default to MLX diff --git a/Tests/WhisperKitMLXTests/MLXUnitTests.swift b/Tests/WhisperKitMLXTests/MLXUnitTests.swift index 94bad93..440968d 100644 --- a/Tests/WhisperKitMLXTests/MLXUnitTests.swift +++ b/Tests/WhisperKitMLXTests/MLXUnitTests.swift @@ -312,35 +312,35 @@ final class MLXUnitTests: XCTestCase { // MARK: - Utils Tests - func testContiguousStrides() { + func testStrides() { let count = 24 let arr1 = MLXArray(0.. Date: Sat, 7 Sep 2024 02:19:50 -0700 Subject: [PATCH 27/29] Formatting --- Sources/WhisperKit/Core/AudioProcessor.swift | 5 +- Sources/WhisperKit/Core/Models.swift | 6 +-- Sources/WhisperKit/Core/TextDecoder.swift | 22 ++++---- Sources/WhisperKit/Core/Utils.swift | 8 +-- Sources/WhisperKit/Core/WhisperKit.swift | 2 +- Sources/WhisperKit/MLX/MLXAudioEncoder.swift | 8 +-- .../WhisperKit/MLX/MLXFeatureExtractor.swift | 4 +- Sources/WhisperKit/MLX/MLXTextDecoder.swift | 7 --- Sources/WhisperKit/MLX/MLXTokenSampler.swift | 6 +-- Sources/WhisperKit/MLX/MLXUtils.swift | 12 ----- Tests/WhisperKitMLXTests/MLXUnitTests.swift | 22 ++++---- Tests/WhisperKitTests/UnitTests.swift | 50 +++++++++---------- 12 files changed, 67 insertions(+), 85 deletions(-) diff --git a/Sources/WhisperKit/Core/AudioProcessor.swift b/Sources/WhisperKit/Core/AudioProcessor.swift index 41fe096..4b5fcdf 100644 --- a/Sources/WhisperKit/Core/AudioProcessor.swift +++ b/Sources/WhisperKit/Core/AudioProcessor.swift @@ -80,7 +80,7 @@ public extension AudioProcessing { @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) static func loadAudioAsync(fromPath audioFilePath: String) async throws -> AVAudioPCMBuffer { return try await Task { - return try AudioProcessor.loadAudio(fromPath: audioFilePath) + try AudioProcessor.loadAudio(fromPath: audioFilePath) }.value } @@ -305,7 +305,8 @@ public class AudioProcessor: NSObject, AudioProcessing { try audioFile.read(into: inputBuffer, frameCount: framesToRead) guard let resampledChunk = resampleAudio(fromBuffer: inputBuffer, toSampleRate: outputFormat.sampleRate, - channelCount: outputFormat.channelCount) else { + channelCount: outputFormat.channelCount) + else { Logging.error("Failed to resample audio chunk") return nil } diff --git a/Sources/WhisperKit/Core/Models.swift b/Sources/WhisperKit/Core/Models.swift index 1939cda..c43cfab 100644 --- a/Sources/WhisperKit/Core/Models.swift +++ b/Sources/WhisperKit/Core/Models.swift @@ -151,7 +151,7 @@ public struct ModelInfo: Identifiable, Hashable { public enum ModelEngine: String, Codable { case coreML = "coreml" - case mlx = "mlx" + case mlx } public protocol WhisperModel: AnyObject { @@ -479,9 +479,9 @@ public struct DecodingResult { public init( language: String, - languageProbs: [String : Float], + languageProbs: [String: Float], tokens: [Int], - tokenLogProbs: [[Int : Float]], + tokenLogProbs: [[Int: Float]], text: String, avgLogProb: Float, noSpeechProb: Float, diff --git a/Sources/WhisperKit/Core/TextDecoder.swift b/Sources/WhisperKit/Core/TextDecoder.swift index 14575fc..3c37b47 100644 --- a/Sources/WhisperKit/Core/TextDecoder.swift +++ b/Sources/WhisperKit/Core/TextDecoder.swift @@ -21,15 +21,15 @@ public protocol TextDecoding { withPrompt initialPrompt: [Int] ) throws -> DecodingInputs - func predictLogits( - inputIds: MLMultiArray, - cacheLength: MLMultiArray, - keyCache: MLMultiArray?, - valueCache: MLMultiArray?, - kvCacheUpdateMask: MLMultiArray, - encoderOutputEmbeds: MLMultiArray, - decoderKeyPaddingMask: MLMultiArray - ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? + func predictLogits( + inputIds: MLMultiArray, + cacheLength: MLMultiArray, + keyCache: MLMultiArray?, + valueCache: MLMultiArray?, + kvCacheUpdateMask: MLMultiArray, + encoderOutputEmbeds: MLMultiArray, + decoderKeyPaddingMask: MLMultiArray + ) async throws -> (logits: MLMultiArray?, cache: DecodingCache?)? func prefillKVCache( withTask task: MLMultiArray, @@ -503,7 +503,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel { guard let model, let keyCache, let valueCache else { return nil } - + let modelInputs = TextDecoderInput( input_ids: inputIds, cache_length: cacheLength, @@ -800,7 +800,7 @@ open class TextDecoder: TextDecoding, WhisperMLModel { break } } - + // Cleanup the early stop flag after loop completion if shouldEarlyStop.removeValue(forKey: windowUUID) == nil { Logging.error("Early stop flag not found for window: \(windowUUID)") diff --git a/Sources/WhisperKit/Core/Utils.swift b/Sources/WhisperKit/Core/Utils.swift index 45b930c..f70b24f 100644 --- a/Sources/WhisperKit/Core/Utils.swift +++ b/Sources/WhisperKit/Core/Utils.swift @@ -203,12 +203,12 @@ public extension String { } extension AVAudioPCMBuffer { - // Appends the contents of another buffer to the current buffer + /// Appends the contents of another buffer to the current buffer func appendContents(of buffer: AVAudioPCMBuffer) -> Bool { return appendContents(of: buffer, startingFrame: 0, frameCount: buffer.frameLength) } - // Appends a specific range of frames from another buffer to the current buffer + /// Appends a specific range of frames from another buffer to the current buffer func appendContents(of buffer: AVAudioPCMBuffer, startingFrame: AVAudioFramePosition, frameCount: AVAudioFrameCount) -> Bool { guard format == buffer.format else { Logging.debug("Format mismatch") @@ -240,7 +240,7 @@ extension AVAudioPCMBuffer { return true } - // Convenience initializer to concatenate multiple buffers into one + /// Convenience initializer to concatenate multiple buffers into one convenience init?(concatenating buffers: [AVAudioPCMBuffer]) { guard !buffers.isEmpty else { Logging.debug("Buffers array should not be empty") @@ -264,7 +264,7 @@ extension AVAudioPCMBuffer { } } - // Computed property to determine the stride for float channel data + /// Computed property to determine the stride for float channel data private var stride: Int { return Int(format.streamDescription.pointee.mBytesPerFrame) / MemoryLayout.size } diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 9ade93b..24fd318 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -440,7 +440,7 @@ open class WhisperKit { guard textDecoder.isModelMultilingual else { throw WhisperError.decodingFailed("Language detection not supported for this model") } - + // Tokenizer required for decoding guard let tokenizer else { throw WhisperError.tokenizerUnavailable() diff --git a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift index 4de2b7a..992c9be 100644 --- a/Sources/WhisperKit/MLX/MLXAudioEncoder.swift +++ b/Sources/WhisperKit/MLX/MLXAudioEncoder.swift @@ -28,8 +28,8 @@ public class MLXAudioEncoder: AudioEncoding, WhisperMLXModel { } } -extension MLXAudioEncoder { - public func loadModel(at modelPath: URL, configPath: URL?) async throws { +public extension MLXAudioEncoder { + func loadModel(at modelPath: URL, configPath: URL?) async throws { let parameters = try loadParameters(at: modelPath) let config = try loadConfig(at: configPath) let encoder = AudioEncoderModule( @@ -45,11 +45,11 @@ extension MLXAudioEncoder { self.model = encoder } - public func unloadModel() { + func unloadModel() { model = nil } - public var modelState: ModelState { + var modelState: ModelState { return model == nil ? .unloaded : .loaded } } diff --git a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift index 453798e..12de9ea 100644 --- a/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift +++ b/Sources/WhisperKit/MLX/MLXFeatureExtractor.swift @@ -38,7 +38,7 @@ open class MLXFeatureExtractor: FeatureExtracting, WhisperMLXModel { return try output.asType(FloatType.self).asMLXOutput().asMLMultiArray() } - // Stubs for WhisperMLXModel protocol, not needed + /// Stubs for WhisperMLXModel protocol, not needed public typealias MLXModuleType = NSObject public var model: NSObject? @@ -78,7 +78,7 @@ public extension MLXFeatureExtractor { return MLX.concatenated([prefix, x, suffix]) } } - + static func stft( _ x: MLXArray, window: MLXArray, diff --git a/Sources/WhisperKit/MLX/MLXTextDecoder.swift b/Sources/WhisperKit/MLX/MLXTextDecoder.swift index 60f5478..8853094 100644 --- a/Sources/WhisperKit/MLX/MLXTextDecoder.swift +++ b/Sources/WhisperKit/MLX/MLXTextDecoder.swift @@ -5,7 +5,6 @@ import CoreML import MLX import MLXNN import WhisperKit -import MLXFast @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *) public final class MLXTextDecoder: TextDecoding { @@ -605,9 +604,3 @@ public class TextDecoderModule: Module { ) } } - -class FastLayerNorm: LayerNorm { - override func callAsFunction(_ x: MLXArray) -> MLXArray { - MLXFast.layerNorm(x, weight: weight, bias: bias, eps: 1e-5) - } -} diff --git a/Sources/WhisperKit/MLX/MLXTokenSampler.swift b/Sources/WhisperKit/MLX/MLXTokenSampler.swift index 9d38235..32da8cd 100644 --- a/Sources/WhisperKit/MLX/MLXTokenSampler.swift +++ b/Sources/WhisperKit/MLX/MLXTokenSampler.swift @@ -1,12 +1,12 @@ // For licensing see accompanying LICENSE.md file. // Copyright © 2024 Argmax, Inc. All rights reserved. +import CoreML +import Foundation import MLX import MLXNN import MLXRandom import WhisperKit -import Foundation -import CoreML open class MLXGreedyTokenSampler: TokenSampling { public var temperature: Float @@ -38,7 +38,7 @@ open class MLXGreedyTokenSampler: TokenSampling { let sortedIndices = MLX.argSort(softmaxOutput, axis: -1) // Implement top-k selection (argSort is ascending) - let topKIndices = MLXArray(-decodingOptions.topK ..< 0) + let topKIndices = MLXArray(-decodingOptions.topK..<0) let sortedProbs = take(softmaxOutput, sortedIndices, axis: -1) let bestValues = sortedProbs.take(topKIndices, axis: -1) let bestIndices = sortedIndices.take(topKIndices, axis: -1) diff --git a/Sources/WhisperKit/MLX/MLXUtils.swift b/Sources/WhisperKit/MLX/MLXUtils.swift index 1d163cb..bb93418 100644 --- a/Sources/WhisperKit/MLX/MLXUtils.swift +++ b/Sources/WhisperKit/MLX/MLXUtils.swift @@ -36,18 +36,6 @@ extension MLXArray { } } -//extension MLXArray { -// var contiguousStrides: [Int] { -// var contiguousStrides = [1] -// var stride = 1 -// for dimension in shape.dropFirst().reversed() { -// stride = stride * dimension -// contiguousStrides.append(stride) -// } -// contiguousStrides.reverse() -// return contiguousStrides -// } -//} extension MLXArray { func asMLMultiArray() throws -> MLMultiArray { diff --git a/Tests/WhisperKitMLXTests/MLXUnitTests.swift b/Tests/WhisperKitMLXTests/MLXUnitTests.swift index 440968d..e65a017 100644 --- a/Tests/WhisperKitMLXTests/MLXUnitTests.swift +++ b/Tests/WhisperKitMLXTests/MLXUnitTests.swift @@ -1,13 +1,13 @@ // For licensing see accompanying LICENSE.md file. // Copyright © 2024 Argmax, Inc. All rights reserved. -import XCTest -import MLX -import WhisperKitTestsUtils import CoreML +import MLX import NaturalLanguage @testable import WhisperKit @testable import WhisperKitMLX +import WhisperKitTestsUtils +import XCTest final class MLXUnitTests: XCTestCase { private var tinyModelPath: String! @@ -153,7 +153,7 @@ final class MLXUnitTests: XCTestCase { let options = DecodingOptions(task: .translate, language: targetLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - try await transcribe( + await transcribe( mlxModelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav", @@ -172,7 +172,7 @@ final class MLXUnitTests: XCTestCase { let options = DecodingOptions(task: .transcribe, language: sourceLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - try await transcribe( + await transcribe( mlxModelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav", @@ -214,7 +214,7 @@ final class MLXUnitTests: XCTestCase { let options = DecodingOptions(task: .translate, language: targetLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - try await transcribe( + await transcribe( mlxModelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav", @@ -233,7 +233,7 @@ final class MLXUnitTests: XCTestCase { let options = DecodingOptions(task: .transcribe, language: sourceLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - try await transcribe( + await transcribe( mlxModelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav", @@ -282,7 +282,7 @@ final class MLXUnitTests: XCTestCase { for (i, option) in optionsPairs.enumerated() { let result = try await XCTUnwrapAsync( - try await transcribe( + await transcribe( mlxModelPath: tinyModelPath, options: option.options, audioFile: "ja_test_clip.wav", @@ -423,18 +423,18 @@ final class MLXUnitTests: XCTestCase { func testAdditiveCausalMask() { let result1 = additiveCausalMask(0) - XCTAssertEqual(result1.shape, [0 ,0], "Array shape should be [0, 0]") + XCTAssertEqual(result1.shape, [0, 0], "Array shape should be [0, 0]") XCTAssertEqual(result1.dtype, .float32, "Array type should be .float32") let result2 = additiveCausalMask(3) - XCTAssertEqual(result2.shape, [3 ,3], "Array shape should be [3, 3]") + XCTAssertEqual(result2.shape, [3, 3], "Array shape should be [3, 3]") XCTAssertEqual(result2.dtype, .float32, "Array type should be .float32") XCTAssertEqual(result2[0].asArray(Float.self), [0.0, -1e9, -1e9], accuracy: accuracy) XCTAssertEqual(result2[1].asArray(Float.self), [0.0, 0.0, -1e9], accuracy: accuracy) XCTAssertEqual(result2[2].asArray(Float.self), [0.0, 0.0, 0.0], accuracy: accuracy) let result3 = additiveCausalMask(4) - XCTAssertEqual(result3.shape, [4 ,4], "Array shape should be [4, 4]") + XCTAssertEqual(result3.shape, [4, 4], "Array shape should be [4, 4]") XCTAssertEqual(result3.dtype, .float32, "Array type should be .float32") XCTAssertEqual(result3[0].asArray(Float.self), [0.0, -1e9, -1e9, -1e9], accuracy: accuracy) XCTAssertEqual(result3[1].asArray(Float.self), [0.0, 0.0, -1e9, -1e9], accuracy: accuracy) diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index 788342e..1f631db 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -1,8 +1,8 @@ // For licensing see accompanying LICENSE.md file. // Copyright © 2024 Argmax, Inc. All rights reserved. -import Combine import AVFoundation +import Combine import CoreML import Hub import NaturalLanguage @@ -47,11 +47,11 @@ final class UnitTests: XCTestCase { XCTAssertNotNil(audioBuffer, "Failed to load audio file at path: \(audioFilePath)") XCTAssertEqual(audioBuffer.format.sampleRate, 16000) XCTAssertEqual(audioBuffer.format.channelCount, 1) - XCTAssertEqual(audioBuffer.frameLength, 176000) + XCTAssertEqual(audioBuffer.frameLength, 176_000) XCTAssertEqual(audioBuffer.frameLength, 11 * 16000) let audioBufferWithStartTime = try AudioProcessor.loadAudio(fromPath: audioFilePath, startTime: 1.2) - XCTAssertEqual(audioBufferWithStartTime.frameLength, AVAudioFrameCount(156800)) + XCTAssertEqual(audioBufferWithStartTime.frameLength, AVAudioFrameCount(156_800)) XCTAssertEqual(audioBufferWithStartTime.frameLength, AVAudioFrameCount(16000 * (11 - 1.2))) let audioBufferWithStartTimeAndEndTime = try AudioProcessor.loadAudio(fromPath: audioFilePath, startTime: 1.2, endTime: 3.4) @@ -103,7 +103,7 @@ final class UnitTests: XCTestCase { let targetSampleRate = 16000.0 let targetChannelCount: AVAudioChannelCount = 1 - let smallMaxReadFrameSize: AVAudioFrameCount = 10_000 // Small chunk size to test chunking logic + let smallMaxReadFrameSize: AVAudioFrameCount = 10000 // Small chunk size to test chunking logic let resampledAudio = AudioProcessor.resampleAudio( fromFile: audioFile, @@ -375,7 +375,7 @@ final class UnitTests: XCTestCase { } let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options, callback: continuationCallback).first!, + await transcribe(modelPath: tinyModelPath, options: options, callback: continuationCallback).first!, "Failed to transcribe" ) @@ -390,7 +390,7 @@ final class UnitTests: XCTestCase { } let resultWithWait = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options, callback: continuationCallbackWithWait).first!, + await transcribe(modelPath: tinyModelPath, options: options, callback: continuationCallbackWithWait).first!, "Failed to transcribe" ) @@ -557,7 +557,7 @@ final class UnitTests: XCTestCase { for option in options { let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: option), + await transcribe(modelPath: tinyModelPath, options: option), "Failed to transcribe" ) XCTAssertEqual(result.segments.first?.tokens.count, targetTokenCount) @@ -571,7 +571,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(task: .translate, language: targetLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav"), + await transcribe(modelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav"), "Failed to transcribe" ) @@ -583,7 +583,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(task: .transcribe, language: sourceLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav"), + await transcribe(modelPath: tinyModelPath, options: options, audioFile: "es_test_clip.wav"), "Failed to transcribe" ) @@ -623,7 +623,7 @@ final class UnitTests: XCTestCase { for (i, option) in optionsPairs.enumerated() { let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: option.options, audioFile: "es_test_clip.wav"), + await transcribe(modelPath: tinyModelPath, options: option.options, audioFile: "es_test_clip.wav"), "Failed to transcribe" ) @@ -649,7 +649,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(task: .translate, language: targetLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav"), + await transcribe(modelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav"), "Failed to transcribe" ) @@ -661,7 +661,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(task: .transcribe, language: sourceLanguage, temperatureFallbackCount: 0) let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav"), + await transcribe(modelPath: tinyModelPath, options: options, audioFile: "ja_test_clip.wav"), "Failed to transcribe" ) @@ -700,7 +700,7 @@ final class UnitTests: XCTestCase { for (i, option) in optionsPairs.enumerated() { let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: option.options, audioFile: "ja_test_clip.wav"), + await transcribe(modelPath: tinyModelPath, options: option.options, audioFile: "ja_test_clip.wav"), "Failed to transcribe" ) @@ -746,7 +746,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(withoutTimestamps: true) let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options), + await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -757,7 +757,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(skipSpecialTokens: true, withoutTimestamps: true) let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options), + await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -768,7 +768,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(usePrefillPrompt: true) try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options), + await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) } @@ -777,7 +777,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(usePrefillPrompt: false) let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options), + await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -824,11 +824,11 @@ final class UnitTests: XCTestCase { func testTopK() async throws { let result10000 = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: DecodingOptions(temperature: 0.5, topK: 10000)).first, + await transcribe(modelPath: tinyModelPath, options: DecodingOptions(temperature: 0.5, topK: 10000)).first, "Failed to transcribe" ) let result5 = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: DecodingOptions(temperature: 0.5)).first, + await transcribe(modelPath: tinyModelPath, options: DecodingOptions(temperature: 0.5)).first, "Failed to transcribe" ) @@ -839,7 +839,7 @@ final class UnitTests: XCTestCase { var options = DecodingOptions(withoutTimestamps: true, clipTimestamps: [0]) let resultFull = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options), + await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -847,7 +847,7 @@ final class UnitTests: XCTestCase { options = DecodingOptions(withoutTimestamps: true, clipTimestamps: [seekTime]) let resultSeek = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options), + await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -867,7 +867,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(skipSpecialTokens: true, promptTokens: promptTokens) let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options), + await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -884,7 +884,7 @@ final class UnitTests: XCTestCase { let options = DecodingOptions(skipSpecialTokens: true, prefixTokens: prefixTokens) let result = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options), + await transcribe(modelPath: tinyModelPath, options: options), "Failed to transcribe" ) @@ -1230,14 +1230,14 @@ final class UnitTests: XCTestCase { func testVADAudioChunkerAccuracy() async throws { let testResult = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: DecodingOptions(), audioFile: "ted_60.m4a"), + await transcribe(modelPath: tinyModelPath, options: DecodingOptions(), audioFile: "ted_60.m4a"), "Failed to transcribe" ) let options = DecodingOptions(chunkingStrategy: .vad) let chunkedResult = try await XCTUnwrapAsync( - try await transcribe(modelPath: tinyModelPath, options: options, audioFile: "ted_60.m4a"), + await transcribe(modelPath: tinyModelPath, options: options, audioFile: "ted_60.m4a"), "Failed to transcribe" ) From 9fff900005c0a6802a18418689fe0a37f9e8a216 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Tue, 10 Sep 2024 09:53:05 -0700 Subject: [PATCH 28/29] Code review --- Sources/WhisperKit/Core/WhisperKit.swift | 3 ++- Sources/WhisperKit/MLX/MLXTextDecoder.swift | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift index 86067e8..e0ebbb9 100644 --- a/Sources/WhisperKit/Core/WhisperKit.swift +++ b/Sources/WhisperKit/Core/WhisperKit.swift @@ -170,10 +170,11 @@ open class WhisperKit { variant: String, downloadBase: URL? = nil, useBackgroundSession: Bool = false, + hfToken: String? = nil, from repo: String = "argmaxinc/whisperkit-coreml", progressCallback: ((Progress) -> Void)? = nil ) async throws -> URL { - let hubApi = HubApi(downloadBase: downloadBase, useBackgroundSession: useBackgroundSession) + let hubApi = HubApi(downloadBase: downloadBase, hfToken: hfToken, useBackgroundSession: useBackgroundSession) let repo = Hub.Repo(id: repo, type: .models) let modelSearchPath = "*\(variant.description)/*" do { diff --git a/Sources/WhisperKit/MLX/MLXTextDecoder.swift b/Sources/WhisperKit/MLX/MLXTextDecoder.swift index 8853094..06327e2 100644 --- a/Sources/WhisperKit/MLX/MLXTextDecoder.swift +++ b/Sources/WhisperKit/MLX/MLXTextDecoder.swift @@ -141,15 +141,15 @@ public final class MLXTextDecoder: TextDecoding { let valueCacheResult = result.cache?.kvCache.map(\.v) else { return nil } - let keyCache = try? MLX.stacked(keyCacheResult).asMLMultiArray() - let valueCache = try? MLX.stacked(valueCacheResult).asMLMultiArray() + let keyCache = try MLX.stacked(keyCacheResult).asMLMultiArray() + let valueCache = try MLX.stacked(valueCacheResult).asMLMultiArray() let decodingCache = DecodingCache( keyCache: keyCache, valueCache: valueCache, alignmentWeights: nil ) - let logits = try? result.logits?.asMLMultiArray() + let logits = try result.logits?.asMLMultiArray() return (logits, decodingCache) } From ce11d9b03296f2c2a912d3d3c64ddb6d6bdbdc29 Mon Sep 17 00:00:00 2001 From: ZachNagengast Date: Thu, 19 Sep 2024 23:39:06 -0700 Subject: [PATCH 29/29] Update to latest mlx version --- .swiftpm/configuration/Package.resolved | 24 ------------------------ Makefile | 3 ++- Package.resolved | 5 +++-- Package.swift | 21 ++++++++++++++------- Tests/WhisperKitTests/UnitTests.swift | 2 +- 5 files changed, 20 insertions(+), 35 deletions(-) diff --git a/.swiftpm/configuration/Package.resolved b/.swiftpm/configuration/Package.resolved index 0e61882..6cccf25 100644 --- a/.swiftpm/configuration/Package.resolved +++ b/.swiftpm/configuration/Package.resolved @@ -1,18 +1,6 @@ { "pins" : [ { -<<<<<<< HEAD - "identity" : "mlx-swift", - "kind" : "remoteSourceControl", - "location" : "https://github.com/ml-explore/mlx-swift", - "state" : { - "revision" : "597aaa5f465b4b9a17c8646b751053f84e37925b", - "version" : "0.16.0" - } - }, - { -======= ->>>>>>> main "identity" : "swift-argument-parser", "kind" : "remoteSourceControl", "location" : "https://github.com/apple/swift-argument-parser.git", @@ -22,18 +10,6 @@ } }, { -<<<<<<< HEAD - "identity" : "swift-numerics", - "kind" : "remoteSourceControl", - "location" : "https://github.com/apple/swift-numerics", - "state" : { - "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b", - "version" : "1.0.2" - } - }, - { -======= ->>>>>>> main "identity" : "swift-transformers", "kind" : "remoteSourceControl", "location" : "https://github.com/huggingface/swift-transformers.git", diff --git a/Makefile b/Makefile index a1314dd..3d49129 100644 --- a/Makefile +++ b/Makefile @@ -146,5 +146,6 @@ test: clean-package-caches: - @trash ~/Library/Caches/org.swift.swiftpm/repositories @trash ~/Library/Developer/Xcode/DerivedData + @swift package purge-cache + @swift package reset diff --git a/Package.resolved b/Package.resolved index 241c17f..f12ed9c 100644 --- a/Package.resolved +++ b/Package.resolved @@ -3,9 +3,10 @@ { "identity" : "mlx-swift", "kind" : "remoteSourceControl", - "location" : "https://github.com/davidkoski/mlx-swift.git", + "location" : "https://github.com/ml-explore/mlx-swift", "state" : { - "revision" : "3314bc684f0ccab1793be54acddaea16c0501d3c" + "revision" : "f27763bef455d76f9455e9dfc6704a6b2859fa26", + "version" : "0.16.2" } }, { diff --git a/Package.swift b/Package.swift index 8084927..bfcdb67 100644 --- a/Package.swift +++ b/Package.swift @@ -16,7 +16,7 @@ let package = Package( name: "WhisperKit", targets: ["WhisperKit"] ), - ] + (!isMLXDisabled() ? [ + ] + (isMLXEnabled() ? [ .executable( name: "whisperkit-cli", targets: ["WhisperKitCLI"] @@ -29,8 +29,8 @@ let package = Package( dependencies: [ .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.7"), .package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"), - ] + (!isMLXDisabled() ? [ - .package(url: "https://github.com/davidkoski/mlx-swift.git", revision: "3314bc684f0ccab1793be54acddaea16c0501d3c"), + ] + (isMLXEnabled() ? [ + .package(url: "https://github.com/ml-explore/mlx-swift", exact: "0.16.2"), ] : []), targets: [ .target( @@ -71,7 +71,7 @@ let package = Package( .process("Sources/WhisperKitTestsUtils/Resources") ] ), - ] + (!isMLXDisabled() ? [ + ] + (isMLXEnabled() ? [ .target( name: "WhisperKitMLX", dependencies: [ @@ -108,8 +108,15 @@ let package = Package( // NOTE: `MLX` doesn't support `watchOS` yet, that's why we control the build using the `MLX_DISABLED` environment variable. // To manualy build for `watchOS` use: -// `export MLX_DISABLED=1 && xcodebuild clean build-for-testing -scheme whisperkit -sdk watchos10.4 -destination 'platform=watchOS Simulator,OS=10.5,name=Apple Watch Ultra 2 (49mm)' -skipPackagePluginValidation` +// MLX_DISABLED=1 xcodebuild clean build-for-testing -scheme whisperkit -sdk watchos -destination 'platform=watchOS Simulator,name=Apple Watch Ultra 2 (49mm)' -skipPackagePluginValidation +// or with swift build: +// MLX_DISABLED=1 swift build -c release -func isMLXDisabled() -> Bool { - ProcessInfo.processInfo.environment["MLX_DISABLED"] == "1" +func isMLXEnabled() -> Bool { + if let disabledValue = ProcessInfo.processInfo.environment["MLX_DISABLED"] { + return disabledValue.lowercased() == "true" || disabledValue == "1" + } + + // Default enabled + return true } diff --git a/Tests/WhisperKitTests/UnitTests.swift b/Tests/WhisperKitTests/UnitTests.swift index d768a9b..f44af55 100644 --- a/Tests/WhisperKitTests/UnitTests.swift +++ b/Tests/WhisperKitTests/UnitTests.swift @@ -62,7 +62,7 @@ final class UnitTests: XCTestCase { func testAudioFileLoadingWithResampling() throws { let audioFilePath = try XCTUnwrap( - Bundle.module.path(forResource: "jfk_441khz", ofType: "m4a"), + TestResource.path(forResource: "jfk_441khz", ofType: "m4a"), "Audio file not found" ) let audioBuffer = try AudioProcessor.loadAudio(fromPath: audioFilePath)