Updates the kvcache func in TextDecoder to work with Swift 6 #289

Open
wants to merge 1 commit into base: main
Sources/WhisperKit/Core/TextDecoder.swift: 82 changes (43 additions, 39 deletions)
@@ -308,45 +308,49 @@ public extension TextDecoding {
        return kvCache
    }

-    static func updateKVCache(keyTensor: MLMultiArray, keySlice: MLMultiArray,
-                              valueTensor: MLMultiArray, valueSlice: MLMultiArray,
-                              insertAtIndex index: Int)
-    {
-        let tensorShape = keyTensor.shape.map { $0.intValue }
-        let sliceShape = keySlice.shape.map { $0.intValue }
-        let sliceStrides = keySlice.strides.map { $0.intValue } // same for val
-        let bytesPerSample = MemoryLayout<FloatType>.size
-
-        keyTensor.withUnsafeMutableBytes { keyTensorPointer, keyTargetStrides in
-            keySlice.withUnsafeBytes { keySlicePointer in
-                valueTensor.withUnsafeMutableBytes { valueTensorPointer, valueTargetStrides in
-                    valueSlice.withUnsafeBytes { valueSlicePointer in
-                        // Assuming batch size is always 1
-                        DispatchQueue.concurrentPerform(iterations: tensorShape[1]) { j in
-                            // Slice size is 3 for prefill and 1 for decode loops
-                            for k in 0..<sliceShape[3] {
-                                // Equivalent to:
-                                // `tensor[0, j, 0, k + index] = slice[0, j, 0, k + index]`
-                                let keyDestIndex = j * keyTargetStrides[1] + (index + k) * keyTargetStrides[3]
-                                let keyDest = keyTensorPointer.baseAddress! + keyDestIndex * bytesPerSample
-
-                                let keySliceIndex = j * sliceStrides[1] + k * sliceStrides[3]
-                                let keySlice = keySlicePointer.baseAddress! + keySliceIndex * bytesPerSample
-                                memcpy(keyDest, keySlice, bytesPerSample)
-
-                                let valDestIndex = j * valueTargetStrides[1] + (index + k) * valueTargetStrides[3]
-                                let valDest = valueTensorPointer.baseAddress! + valDestIndex * bytesPerSample
-
-                                let valSliceIndex = j * sliceStrides[1] + k * sliceStrides[3]
-                                let valSlice = valueSlicePointer.baseAddress! + valSliceIndex * bytesPerSample
-                                memcpy(valDest, valSlice, bytesPerSample)
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
+    static func updateKVCache(keyTensor: MLMultiArray, keySlice: MLMultiArray,
+                              valueTensor: MLMultiArray, valueSlice: MLMultiArray,
+                              insertAtIndex index: Int)
+    {
+        let tensorShape = keyTensor.shape.map { $0.intValue }
+        let sliceShape = keySlice.shape.map { $0.intValue }
+
+        // Create flat arrays for safe concurrent access
+        var keyData = [FloatType](repeating: 0, count: keyTensor.count)
+        var valueData = [FloatType](repeating: 0, count: valueTensor.count)
+
+        // Get current tensor data
+        memcpy(&keyData, keyTensor.dataPointer, keyTensor.count * MemoryLayout<FloatType>.size)
+        memcpy(&valueData, valueTensor.dataPointer, valueTensor.count * MemoryLayout<FloatType>.size)
+
+        // Calculate dimensions for index mapping
+        let seqLength = tensorShape[3]
+        let hiddenDim = tensorShape[1]
+
+        // Concurrent processing across hidden dimension
+        DispatchQueue.concurrentPerform(iterations: hiddenDim) { j in
+            for k in 0..<sliceShape[3] {
+                // Calculate linear indices
+                let targetSeqPos = index + k
+                guard targetSeqPos < seqLength else { continue }
+
+                // Map 4D indices [0, j, 0, index+k] to linear index
+                let flatKeyIndex = j * seqLength + targetSeqPos
+                let flatSliceIndex = j * sliceShape[3] + k
+
+                // Copy from slice to tensor
+                let sliceKeyPtr = keySlice.dataPointer.assumingMemoryBound(to: FloatType.self)
+                let sliceValuePtr = valueSlice.dataPointer.assumingMemoryBound(to: FloatType.self)
+
+                keyData[flatKeyIndex] = sliceKeyPtr[flatSliceIndex]
+                valueData[flatKeyIndex] = sliceValuePtr[flatSliceIndex]
+            }
+        }
+
+        // Copy data back to tensors
+        memcpy(keyTensor.dataPointer, &keyData, keyTensor.count * MemoryLayout<FloatType>.size)
+        memcpy(valueTensor.dataPointer, &valueData, valueTensor.count * MemoryLayout<FloatType>.size)
+    }

    static func updateAlignmentWeights(
        alignmentTensor: MLMultiArray,
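
For context, here is a standalone sketch (not part of the PR) of the flat index mapping the new implementation relies on. It assumes a contiguous cache tensor of shape [1, hiddenDim, 1, seqLength], so element [0, j, 0, pos] sits at linear index j * seqLength + pos; the dimension values, the Float element type, and the serial outer loop are illustrative only (the PR parallelizes that loop with DispatchQueue.concurrentPerform and uses WhisperKit's FloatType).

```swift
// Standalone sketch: a [1, hiddenDim, 1, seqLength] tensor flattened row-major,
// with a [1, hiddenDim, 1, sliceLength] slice written in at sequence position `index`.
let hiddenDim = 2
let seqLength = 6
let sliceLength = 3
let index = 2 // insertion position along the sequence dimension

var cache = [Float](repeating: 0, count: hiddenDim * seqLength)
let slice: [Float] = [1, 2, 3, 4, 5, 6] // hiddenDim * sliceLength values

// The PR runs this outer loop concurrently via DispatchQueue.concurrentPerform(iterations: hiddenDim)
for j in 0..<hiddenDim {
    for k in 0..<sliceLength {
        let targetSeqPos = index + k
        guard targetSeqPos < seqLength else { continue }
        // tensor[0, j, 0, index + k] = slice[0, j, 0, k], expressed with flat indices
        cache[j * seqLength + targetSeqPos] = slice[j * sliceLength + k]
    }
}

for j in 0..<hiddenDim {
    print(Array(cache[(j * seqLength)..<((j + 1) * seqLength)]))
}
// [0.0, 0.0, 1.0, 2.0, 3.0, 0.0]
// [0.0, 0.0, 4.0, 5.0, 6.0, 0.0]
```

Running the sketch writes the slice values into positions index..<(index + sliceLength) of each row and leaves the rest of the cache untouched, which is the same effect the stride-based memcpy version above achieves in place on the MLMultiArray.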