Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support resumable downloads #4

Merged
merged 5 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 158 additions & 5 deletions Sources/Hub/Downloader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import Combine
class Downloader: NSObject, ObservableObject {
private(set) var destination: URL

private let chunkSize = 10 * 1024 * 1024 // 10MB

enum DownloadState {
case notStarted
case downloading(Double)
Expand All @@ -29,7 +31,17 @@ class Downloader: NSObject, ObservableObject {

private var urlSession: URLSession? = nil

init(from url: URL, to destination: URL, using authToken: String? = nil, inBackground: Bool = false) {
init(
from url: URL,
to destination: URL,
using authToken: String? = nil,
inBackground: Bool = false,
resumeSize: Int = 0,
headers: [String: String]? = nil,
expectedSize: Int? = nil,
timeout: TimeInterval = 10,
numRetries: Int = 5
) {
self.destination = destination
super.init()
let sessionIdentifier = "swift-transformers.hub.downloader"
Expand All @@ -43,10 +55,28 @@ class Downloader: NSObject, ObservableObject {

self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)

setupDownload(from: url, with: authToken)
setupDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries)
}

private func setupDownload(from url: URL, with authToken: String?) {
/// Sets up and initiates a file download operation
///
/// - Parameters:
/// - url: Source URL to download from
/// - authToken: Bearer token for authentication with Hugging Face
/// - resumeSize: Number of bytes already downloaded for resuming interrupted downloads
/// - headers: Additional HTTP headers to include in the request
/// - expectedSize: Expected file size in bytes for validation
/// - timeout: Time interval before the request times out
/// - numRetries: Number of retry attempts for failed downloads
private func setupDownload(
from url: URL,
with authToken: String?,
resumeSize: Int,
headers: [String: String]?,
expectedSize: Int?,
timeout: TimeInterval,
numRetries: Int
) {
downloadState.value = .downloading(0)
urlSession?.getAllTasks { tasks in
// If there's an existing pending background task with the same URL, let it proceed.
Expand All @@ -71,14 +101,137 @@ class Downloader: NSObject, ObservableObject {
}
}
var request = URLRequest(url: url)

// Use headers from argument else create an empty header dictionary
var requestHeaders = headers ?? [:]

// Populate header auth and range fields
if let authToken = authToken {
request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization")
requestHeaders["Authorization"] = "Bearer \(authToken)"
}
if resumeSize > 0 {
requestHeaders["Range"] = "bytes=\(resumeSize)-"
}


request.timeoutInterval = timeout
request.allHTTPHeaderFields = requestHeaders

self.urlSession?.downloadTask(with: request).resume()
Task {
do {
// Create a temp file to write
let tempURL = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
FileManager.default.createFile(atPath: tempURL.path, contents: nil)
let tempFile = try FileHandle(forWritingTo: tempURL)

defer { tempFile.closeFile() }
try await self.httpGet(request: request, tempFile: tempFile, resumeSize: resumeSize, numRetries: numRetries, expectedSize: expectedSize)

// Clean up and move the completed download to its final destination
tempFile.closeFile()
try FileManager.default.moveDownloadedFile(from: tempURL, to: self.destination)

self.downloadState.value = .completed(self.destination)
} catch {
self.downloadState.value = .failed(error)
}
}
}
}

/// Streams a file from the given request into `tempFile`, retrying on network errors.
///
/// Reference: https://github.com/huggingface/huggingface_hub/blob/418a6ffce7881f5c571b2362ed1c23ef8e4d7d20/src/huggingface_hub/file_download.py#L306
///
/// - Parameters:
///   - request: The URLRequest for the file to download
///   - tempFile: The file handle the received bytes are written to
///   - resumeSize: The number of bytes already downloaded. If set to 0, the whole file is downloaded. If set to a positive number, the download will resume at the given position
///   - numRetries: The number of retry attempts remaining for failed downloads
///   - expectedSize: The expected size of the file to download. If set, the download will raise an error if the size of the received content is different from the expected one.
///   - maxRetries: The retry budget restored after every successful write. Defaults to the `numRetries` of the outermost call.
/// - Throws: `DownloadError.unexpectedError` if the response is invalid or file size mismatch occurs
///           `URLError` if the download fails after all retries are exhausted
private func httpGet(
    request: URLRequest,
    tempFile: FileHandle,
    resumeSize: Int,
    numRetries: Int,
    expectedSize: Int?,
    maxRetries: Int? = nil
) async throws {
    guard let session = urlSession else {
        throw DownloadError.unexpectedError
    }

    // The budget to restore once data flows again; previously a hard-coded 5
    // that ignored the caller-supplied retry count.
    let retryBudget = maxRetries ?? numRetries

    // Create a new request with Range header for resuming
    var newRequest = request
    if resumeSize > 0 {
        newRequest.setValue("bytes=\(resumeSize)-", forHTTPHeaderField: "Range")
    }

    // Start the download and get the byte stream
    let (asyncBytes, response) = try await session.bytes(for: newRequest)

    guard let response = response as? HTTPURLResponse,
          (200..<300).contains(response.statusCode) else {
        throw DownloadError.unexpectedError
    }

    var downloadedSize = resumeSize

    // A 200 (instead of 206) answer to a ranged request means the server ignored
    // the Range header and is sending the file from byte 0. Appending that to the
    // bytes already on disk would corrupt the file, so start over.
    if resumeSize > 0 && response.statusCode == 200 {
        try tempFile.truncate(atOffset: 0)
        downloadedSize = 0
    }

    // Create a buffer to collect bytes before writing to disk
    var buffer = Data(capacity: chunkSize)

    var newNumRetries = numRetries
    do {
        for try await byte in asyncBytes {
            buffer.append(byte)

            // Keep collecting until a full chunk is available
            guard buffer.count >= chunkSize else { continue }

            try tempFile.write(contentsOf: buffer)
            downloadedSize += buffer.count
            buffer.removeAll(keepingCapacity: true)
            newNumRetries = retryBudget

            if let expectedSize = expectedSize {
                let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0
                downloadState.value = .downloading(progress)
            }
        }

        // Flush the trailing partial chunk and report it, so that files smaller
        // than one chunk also publish progress.
        if !buffer.isEmpty {
            try tempFile.write(contentsOf: buffer)
            downloadedSize += buffer.count
            buffer.removeAll(keepingCapacity: true)
            newNumRetries = retryBudget

            if let expectedSize = expectedSize {
                let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0
                downloadState.value = .downloading(progress)
            }
        }
    } catch let error as URLError {
        guard newNumRetries > 0 else {
            throw error
        }
        try await Task.sleep(nanoseconds: 1_000_000_000)

        let config = URLSessionConfiguration.default
        self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)

        // Bytes still in `buffer` were never written, so resuming from
        // `downloadedSize` re-requests them and keeps the file offset consistent.
        try await httpGet(
            request: request,
            tempFile: tempFile,
            resumeSize: downloadedSize,
            numRetries: newNumRetries - 1,
            expectedSize: expectedSize,
            maxRetries: retryBudget
        )
    }

    // Verify the downloaded file size matches the expected size
    let actualSize = try tempFile.seekToEnd()
    if let expectedSize = expectedSize, expectedSize != actualSize {
        throw DownloadError.unexpectedError
    }
}

@discardableResult
func waitUntilDone() throws -> URL {
// It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
Expand Down
3 changes: 2 additions & 1 deletion Sources/Hub/HubApi.swift
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ public extension HubApi {
// From now on, etag, commit_hash, url and size are not empty
guard let remoteCommitHash = remoteMetadata.commitHash,
let remoteEtag = remoteMetadata.etag,
let remoteSize = remoteMetadata.size,
remoteMetadata.location != "" else {
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server")
}
Expand Down Expand Up @@ -396,7 +397,7 @@ public extension HubApi {
try prepareDestination()
try prepareMetadataDestination()

let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession)
let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession, expectedSize: remoteSize)
let downloadSubscriber = downloader.downloadState.sink { state in
if case .downloading(let progress) = state {
progressHandler(progress)
Expand Down
162 changes: 162 additions & 0 deletions Tests/HubTests/DownloaderTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
//
// DownloaderTests.swift
// swift-transformers
//
// Created by Arda Atahan Ibis on 1/28/25.
//

import XCTest
import Combine
@testable import Hub

/// Failure cases reported while downloading a file.
enum DownloadError: Error {
    case invalidDownloadLocation, unexpectedError
}

/// Integration tests for `Downloader`; each test downloads into a fresh temporary directory.
final class DownloaderTests: XCTestCase {
    var tempDir: URL!

    /// Creates a unique temporary directory for the test; a failure to create it fails the test
    /// instead of being silently ignored.
    override func setUpWithError() throws {
        try super.setUpWithError()
        tempDir = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
        try FileManager.default.createDirectory(at: tempDir, withIntermediateDirectories: true)
    }

    /// Removes the temporary directory (best effort).
    override func tearDown() {
        try? FileManager.default.removeItem(at: tempDir)
        super.tearDown()
    }

    /// This test downloads a known config file, verifies the download completes, checks the content matches expected value
    func testSuccessfulDownload() async throws {
        // Create a test file
        let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")!
        let destination = tempDir.appendingPathComponent("config.json")
        // NOTE(review): the whitespace inside this literal must match the remote file byte-for-byte — confirm against the hosted config.json
        let fileContent = """
        {
        "architectures": [
        "LlamaForCausalLM"
        ],
        "bos_token_id": 1,
        "eos_token_id": 2,
        "model_type": "llama",
        "pad_token_id": 0,
        "vocab_size": 32000
        }
        """

        let downloader = Downloader(
            from: url,
            to: destination
        )

        // Store subscriber outside the continuation to maintain its lifecycle
        var subscriber: AnyCancellable?

        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
            subscriber = downloader.downloadState.sink { state in
                switch state {
                case .completed:
                    continuation.resume()
                case .failed(let error):
                    continuation.resume(throwing: error)
                case .downloading, .notStarted:
                    break
                }
            }
        }

        // Cancel subscription after continuation completes
        subscriber?.cancel()

        // Verify download completed successfully
        XCTAssertTrue(FileManager.default.fileExists(atPath: destination.path))
        XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), fileContent)
    }

    /// This test attempts to download with an incorrect expected file size, verifies the download fails, ensures no partial file is left behind
    func testDownloadFailsWithIncorrectSize() async throws {
        let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")!
        let destination = tempDir.appendingPathComponent("config.json")

        // Create downloader with incorrect expected size
        let downloader = Downloader(
            from: url,
            to: destination,
            expectedSize: 999999 // Incorrect size
        )

        XCTAssertThrowsError(
            try downloader.waitUntilDone(),
            "Download should have failed due to size mismatch"
        )

        // Verify no file was created at destination
        XCTAssertFalse(FileManager.default.fileExists(atPath: destination.path))
    }

    /// This test downloads an LFS file, interrupts the download at 50% and 75% progress,
    /// verifies the download can resume and complete successfully, checks the final file exists and has content
    func testSuccessfulInterruptedDownload() async throws {
        let url = URL(string: "https://huggingface.co/coreml-projects/sam-2-studio/resolve/main/SAM%202%20Studio%201.1.zip")!
        let destination = tempDir.appendingPathComponent("SAM%202%20Studio%201.1.zip")

        // Create parent directory if it doesn't exist
        try FileManager.default.createDirectory(
            at: destination.deletingLastPathComponent(),
            withIntermediateDirectories: true
        )

        let downloader = Downloader(
            from: url,
            to: destination,
            expectedSize: 73194001 // Correct size for verification
        )

        // First interruption point at 50%
        var threshold = 0.5

        var subscriber: AnyCancellable?

        // Monitor download progress and interrupt at thresholds to test if
        // download continues from where it left off
        try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
            subscriber = downloader.downloadState.sink { state in
                switch state {
                case .downloading(let progress):
                    if threshold != 1.0 && progress >= threshold {
                        // Move to next threshold and interrupt
                        threshold = threshold == 0.5 ? 0.75 : 1.0
                        downloader.cancel()
                    }
                case .completed:
                    continuation.resume()
                case .failed(let error):
                    continuation.resume(throwing: error)
                case .notStarted:
                    break
                }
            }
        }

        subscriber?.cancel()

        // Verify the file exists and is complete
        guard FileManager.default.fileExists(atPath: destination.path) else {
            XCTFail("File was not created at destination")
            return
        }

        let attributes = try FileManager.default.attributesOfItem(atPath: destination.path)
        // Unwrap instead of force-casting so a missing/odd attribute fails the test rather than crashing the run
        let finalSize = try XCTUnwrap(attributes[.size] as? Int64, "File size attribute is missing")
        XCTAssertGreaterThan(finalSize, 0, "File should not be empty")
    }
}
4 changes: 2 additions & 2 deletions Tests/TokenizersTests/ChatTemplateTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ class ChatTemplateTests: XCTestCase {
func testDeepSeekQwenChatTemplate() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
let encoded = try tokenizer.applyChatTemplate(messages: messages)
let encodedTarget = [151646, 151644, 74785, 279, 23670, 15473, 4128, 13, 151645]
let encodedTarget = [151646, 151644, 74785, 279, 23670, 15473, 4128, 13, 151645, 151648, 198]
XCTAssertEqual(encoded, encodedTarget)

let decoded = tokenizer.decode(tokens: encoded)
let decodedTarget = "<|begin▁of▁sentence|><|User|>Describe the Swift programming language.<|Assistant|>"
let decodedTarget = "<|begin▁of▁sentence|><|User|>Describe the Swift programming language.<|Assistant|><think>\n"
XCTAssertEqual(decoded, decodedTarget)
}

Expand Down