Skip to content

Commit

Permalink
Merge pull request huggingface#175 from argmaxinc/download
Browse files Browse the repository at this point in the history
Add resumable download support with tests
  • Loading branch information
FL33TW00D authored Feb 18, 2025
2 parents abf5b16 + a808140 commit 4f97f98
Show file tree
Hide file tree
Showing 4 changed files with 324 additions and 8 deletions.
163 changes: 158 additions & 5 deletions Sources/Hub/Downloader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import Combine
class Downloader: NSObject, ObservableObject {
private(set) var destination: URL

private let chunkSize = 10 * 1024 * 1024 // 10MB

enum DownloadState {
case notStarted
case downloading(Double)
Expand All @@ -29,7 +31,17 @@ class Downloader: NSObject, ObservableObject {

private var urlSession: URLSession? = nil

init(from url: URL, to destination: URL, using authToken: String? = nil, inBackground: Bool = false) {
init(
from url: URL,
to destination: URL,
using authToken: String? = nil,
inBackground: Bool = false,
resumeSize: Int = 0,
headers: [String: String]? = nil,
expectedSize: Int? = nil,
timeout: TimeInterval = 10,
numRetries: Int = 5
) {
self.destination = destination
super.init()
let sessionIdentifier = "swift-transformers.hub.downloader"
Expand All @@ -43,10 +55,28 @@ class Downloader: NSObject, ObservableObject {

self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)

setupDownload(from: url, with: authToken)
setupDownload(from: url, with: authToken, resumeSize: resumeSize, headers: headers, expectedSize: expectedSize, timeout: timeout, numRetries: numRetries)
}

private func setupDownload(from url: URL, with authToken: String?) {
/// Sets up and initiates a file download operation
///
/// - Parameters:
/// - url: Source URL to download from
/// - authToken: Bearer token for authentication with Hugging Face
/// - resumeSize: Number of bytes already downloaded for resuming interrupted downloads
/// - headers: Additional HTTP headers to include in the request
/// - expectedSize: Expected file size in bytes for validation
/// - timeout: Time interval before the request times out
/// - numRetries: Number of retry attempts for failed downloads
private func setupDownload(
from url: URL,
with authToken: String?,
resumeSize: Int,
headers: [String: String]?,
expectedSize: Int?,
timeout: TimeInterval,
numRetries: Int
) {
downloadState.value = .downloading(0)
urlSession?.getAllTasks { tasks in
// If there's an existing pending background task with the same URL, let it proceed.
Expand All @@ -71,14 +101,137 @@ class Downloader: NSObject, ObservableObject {
}
}
var request = URLRequest(url: url)

// Use headers from argument else create an empty header dictionary
var requestHeaders = headers ?? [:]

// Populate header auth and range fields
if let authToken = authToken {
request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization")
requestHeaders["Authorization"] = "Bearer \(authToken)"
}
if resumeSize > 0 {
requestHeaders["Range"] = "bytes=\(resumeSize)-"
}


request.timeoutInterval = timeout
request.allHTTPHeaderFields = requestHeaders

self.urlSession?.downloadTask(with: request).resume()
Task {
do {
// Create a temp file to write
let tempURL = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
FileManager.default.createFile(atPath: tempURL.path, contents: nil)
let tempFile = try FileHandle(forWritingTo: tempURL)

defer { tempFile.closeFile() }
try await self.httpGet(request: request, tempFile: tempFile, resumeSize: resumeSize, numRetries: numRetries, expectedSize: expectedSize)

// Clean up and move the completed download to its final destination
tempFile.closeFile()
try FileManager.default.moveDownloadedFile(from: tempURL, to: self.destination)

self.downloadState.value = .completed(self.destination)
} catch {
self.downloadState.value = .failed(error)
}
}
}
}

/// Downloads a file from given URL using chunked transfer and handles retries.
///
/// Reference: https://github.com/huggingface/huggingface_hub/blob/418a6ffce7881f5c571b2362ed1c23ef8e4d7d20/src/huggingface_hub/file_download.py#L306
///
/// - Parameters:
/// - request: The URLRequest for the file to download
/// - resumeSize: The number of bytes already downloaded. If set to 0 (default), the whole file is download. If set to a positive number, the download will resume at the given position
/// - numRetries: The number of retry attempts remaining for failed downloads
/// - expectedSize: The expected size of the file to download. If set, the download will raise an error if the size of the received content is different from the expected one.
/// - Throws: `DownloadError.unexpectedError` if the response is invalid or file size mismatch occurs
/// `URLError` if the download fails after all retries are exhausted
private func httpGet(
request: URLRequest,
tempFile: FileHandle,
resumeSize: Int,
numRetries: Int,
expectedSize: Int?
) async throws {
guard let session = self.urlSession else {
throw DownloadError.unexpectedError
}

// Create a new request with Range header for resuming
var newRequest = request
if resumeSize > 0 {
newRequest.setValue("bytes=\(resumeSize)-", forHTTPHeaderField: "Range")
}

// Start the download and get the byte stream
let (asyncBytes, response) = try await session.bytes(for: newRequest)

guard let response = response as? HTTPURLResponse else {
throw DownloadError.unexpectedError
}

guard (200..<300).contains(response.statusCode) else {
throw DownloadError.unexpectedError
}

var downloadedSize = resumeSize

// Create a buffer to collect bytes before writing to disk
var buffer = Data(capacity: chunkSize)

var newNumRetries = numRetries
do {
for try await byte in asyncBytes {
buffer.append(byte)
// When buffer is full, write to disk
if buffer.count == chunkSize {
if !buffer.isEmpty { // Filter out keep-alive chunks
try tempFile.write(contentsOf: buffer)
buffer.removeAll(keepingCapacity: true)
downloadedSize += chunkSize
newNumRetries = 5
guard let expectedSize = expectedSize else { continue }
let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0
downloadState.value = .downloading(progress)
}
}
}

if !buffer.isEmpty {
try tempFile.write(contentsOf: buffer)
downloadedSize += buffer.count
buffer.removeAll(keepingCapacity: true)
newNumRetries = 5
}
} catch let error as URLError {
if newNumRetries <= 0 {
throw error
}
try await Task.sleep(nanoseconds: 1_000_000_000)

let config = URLSessionConfiguration.default
self.urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)

try await httpGet(
request: request,
tempFile: tempFile,
resumeSize: downloadedSize,
numRetries: newNumRetries - 1,
expectedSize: expectedSize
)
}

// Verify the downloaded file size matches the expected size
let actualSize = try tempFile.seekToEnd()
if let expectedSize = expectedSize, expectedSize != actualSize {
throw DownloadError.unexpectedError
}
}

@discardableResult
func waitUntilDone() throws -> URL {
// It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
Expand Down
3 changes: 2 additions & 1 deletion Sources/Hub/HubApi.swift
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ public extension HubApi {
// From now on, etag, commit_hash, url and size are not empty
guard let remoteCommitHash = remoteMetadata.commitHash,
let remoteEtag = remoteMetadata.etag,
let remoteSize = remoteMetadata.size,
remoteMetadata.location != "" else {
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server")
}
Expand Down Expand Up @@ -396,7 +397,7 @@ public extension HubApi {
try prepareDestination()
try prepareMetadataDestination()

let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession)
let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession, expectedSize: remoteSize)
let downloadSubscriber = downloader.downloadState.sink { state in
if case .downloading(let progress) = state {
progressHandler(progress)
Expand Down
162 changes: 162 additions & 0 deletions Tests/HubTests/DownloaderTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
//
// DownloaderTests.swift
// swift-transformers
//
// Created by Arda Atahan Ibis on 1/28/25.
//

import XCTest
import Combine
@testable import Hub

/// Errors that can occur during the download process
enum DownloadError: Error {
case invalidDownloadLocation
case unexpectedError
}

final class DownloaderTests: XCTestCase {
var tempDir: URL!

override func setUp() {
super.setUp()
tempDir = FileManager.default.temporaryDirectory.appendingPathComponent(UUID().uuidString)
try? FileManager.default.createDirectory(at: tempDir, withIntermediateDirectories: true)
}

override func tearDown() {
try? FileManager.default.removeItem(at: tempDir)
super.tearDown()
}

/// This test downloads a known config file, verifies the download completes, checks the content matches expected value
func testSuccessfulDownload() async throws {
// Create a test file
let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")!
let destination = tempDir.appendingPathComponent("config.json")
let fileContent = """
{
"architectures": [
"LlamaForCausalLM"
],
"bos_token_id": 1,
"eos_token_id": 2,
"model_type": "llama",
"pad_token_id": 0,
"vocab_size": 32000
}
"""

let downloader = Downloader(
from: url,
to: destination
)

// Store subscriber outside the continuation to maintain its lifecycle
var subscriber: AnyCancellable?

try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
subscriber = downloader.downloadState.sink { state in
switch state {
case .completed:
continuation.resume()
case .failed(let error):
continuation.resume(throwing: error)
case .downloading:
break
case .notStarted:
break
}
}
}

// Cancel subscription after continuation completes
subscriber?.cancel()

// Verify download completed successfully
XCTAssertTrue(FileManager.default.fileExists(atPath: destination.path))
XCTAssertEqual(try String(contentsOf: destination, encoding: .utf8), fileContent)
}

/// This test attempts to download with incorrect expected file, verifies the download fails, ensures no partial file is left behind
func testDownloadFailsWithIncorrectSize() async throws {
let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/config.json")!
let destination = tempDir.appendingPathComponent("config.json")

// Create downloader with incorrect expected size
let downloader = Downloader(
from: url,
to: destination,
expectedSize: 999999 // Incorrect size
)

do {
try downloader.waitUntilDone()
XCTFail("Download should have failed due to size mismatch")
} catch {

}

// Verify no file was created at destination
XCTAssertFalse(FileManager.default.fileExists(atPath: destination.path))
}

/// This test downloads an LFS file, interrupts the download at 50% and 75% progress,
/// verifies the download can resume and complete successfully, checks the final file exists and has content
func testSuccessfulInterruptedDownload() async throws {
let url = URL(string: "https://huggingface.co/coreml-projects/sam-2-studio/resolve/main/SAM%202%20Studio%201.1.zip")!
let destination = tempDir.appendingPathComponent("SAM%202%20Studio%201.1.zip")

// Create parent directory if it doesn't exist
try FileManager.default.createDirectory(at: destination.deletingLastPathComponent(),
withIntermediateDirectories: true)

let downloader = Downloader(
from: url,
to: destination,
expectedSize: 73194001 // Correct size for verification
)

// First interruption point at 50%
var threshold = 0.5

var subscriber: AnyCancellable?

do {
// Monitor download progress and interrupt at thresholds to test if
// download continues from where it left off
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, Error>) in
subscriber = downloader.downloadState.sink { state in
switch state {
case .downloading(let progress):
if threshold != 1.0 && progress >= threshold {
// Move to next threshold and interrupt
threshold = threshold == 0.5 ? 0.75 : 1.0
downloader.cancel()
}
case .completed:
continuation.resume()
case .failed(let error):
continuation.resume(throwing: error)
case .notStarted:
break
}
}
}

subscriber?.cancel()

// Verify the file exists and is complete
if FileManager.default.fileExists(atPath: destination.path) {
let attributes = try FileManager.default.attributesOfItem(atPath: destination.path)
let finalSize = attributes[.size] as! Int64
XCTAssertGreaterThan(finalSize, 0, "File should not be empty")
} else {
XCTFail("File was not created at destination")
}
} catch {
throw error
}
}
}
4 changes: 2 additions & 2 deletions Tests/TokenizersTests/ChatTemplateTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ class ChatTemplateTests: XCTestCase {
func testDeepSeekQwenChatTemplate() async throws {
let tokenizer = try await AutoTokenizer.from(pretrained: "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
let encoded = try tokenizer.applyChatTemplate(messages: messages)
let encodedTarget = [151646, 151644, 74785, 279, 23670, 15473, 4128, 13, 151645]
let encodedTarget = [151646, 151644, 74785, 279, 23670, 15473, 4128, 13, 151645, 151648, 198]
XCTAssertEqual(encoded, encodedTarget)

let decoded = tokenizer.decode(tokens: encoded)
let decodedTarget = "<|begin▁of▁sentence|><|User|>Describe the Swift programming language.<|Assistant|>"
let decodedTarget = "<|begin▁of▁sentence|><|User|>Describe the Swift programming language.<|Assistant|><think>\n"
XCTAssertEqual(decoded, decodedTarget)
}

Expand Down

0 comments on commit 4f97f98

Please sign in to comment.