Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] master from ggerganov:master #197

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 74 additions & 35 deletions examples/run/run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@ class HttpClient {
public:
int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
const bool progress, std::string * response_str = nullptr) {
if (std::filesystem::exists(output_file)) {
return 0;
}

std::string output_file_partial;
curl = curl_easy_init();
if (!curl) {
Expand Down Expand Up @@ -558,13 +562,14 @@ class LlamaData {
}

sampler = initialize_sampler(opt);

return 0;
}

private:
#ifdef LLAMA_USE_CURL
    // Download `url` via HttpClient. When `output_file` is non-empty the body is
    // written to that file; when `response_str` is non-null the body is captured
    // into it instead (used for manifest lookups with an empty `output_file`).
    // `headers` are extra HTTP request headers; `progress` enables a progress bar.
    // Returns 0 on success, 1 on failure.
    int download(const std::string & url, const std::string & output_file, const bool progress,
                 const std::vector<std::string> & headers = {}, std::string * response_str = nullptr) {
        HttpClient http;
        if (http.init(url, headers, output_file, progress, response_str)) {
            return 1;
        }

        return 0;
    }
#else
    // Stub for builds without libcurl: downloading is unsupported, always fails.
    int download(const std::string &, const std::string &, const bool, const std::vector<std::string> & = {},
                 std::string * = nullptr) {
        printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__);

        return 1;
    }
#endif

int huggingface_dl(const std::string & model, const std::vector<std::string> headers, const std::string & bn) {
// Helper function to handle model tag extraction and URL construction
std::pair<std::string, std::string> extract_model_and_tag(std::string & model, const std::string & base_url) {
std::string model_tag = "latest";
const size_t colon_pos = model.find(':');
if (colon_pos != std::string::npos) {
model_tag = model.substr(colon_pos + 1);
model = model.substr(0, colon_pos);
}

std::string url = base_url + model + "/manifests/" + model_tag;

return { model, url };
}

// Helper function to download and parse the manifest
int download_and_parse_manifest(const std::string & url, const std::vector<std::string> & headers,
nlohmann::json & manifest) {
std::string manifest_str;
int ret = download(url, "", false, headers, &manifest_str);
if (ret) {
return ret;
}

manifest = nlohmann::json::parse(manifest_str);

return 0;
}

int huggingface_dl(std::string & model, const std::string & bn) {
// Find the second occurrence of '/' after protocol string
size_t pos = model.find('/');
pos = model.find('/', pos + 1);
std::string hfr, hff;
std::vector<std::string> headers = { "User-Agent: llama-cpp", "Accept: application/json" };
std::string url;

if (pos == std::string::npos) {
return 1;
auto [model_name, manifest_url] = extract_model_and_tag(model, "https://huggingface.co/v2/");
hfr = model_name;

nlohmann::json manifest;
int ret = download_and_parse_manifest(manifest_url, headers, manifest);
if (ret) {
return ret;
}

hff = manifest["ggufFile"]["rfilename"];
} else {
hfr = model.substr(0, pos);
hff = model.substr(pos + 1);
}

const std::string hfr = model.substr(0, pos);
const std::string hff = model.substr(pos + 1);
const std::string url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;
return download(url, headers, bn, true);
url = "https://huggingface.co/" + hfr + "/resolve/main/" + hff;

return download(url, bn, true, headers);
}

int ollama_dl(std::string & model, const std::vector<std::string> headers, const std::string & bn) {
int ollama_dl(std::string & model, const std::string & bn) {
const std::vector<std::string> headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" };
if (model.find('/') == std::string::npos) {
model = "library/" + model;
}

std::string model_tag = "latest";
size_t colon_pos = model.find(':');
if (colon_pos != std::string::npos) {
model_tag = model.substr(colon_pos + 1);
model = model.substr(0, colon_pos);
}

std::string manifest_url = "https://registry.ollama.ai/v2/" + model + "/manifests/" + model_tag;
std::string manifest_str;
const int ret = download(manifest_url, headers, "", false, &manifest_str);
auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/");
nlohmann::json manifest;
int ret = download_and_parse_manifest(manifest_url, {}, manifest);
if (ret) {
return ret;
}

nlohmann::json manifest = nlohmann::json::parse(manifest_str);
std::string layer;
std::string layer;
for (const auto & l : manifest["layers"]) {
if (l["mediaType"] == "application/vnd.ollama.image.model") {
layer = l["digest"];
break;
}
}

std::string blob_url = "https://registry.ollama.ai/v2/" + model + "/blobs/" + layer;
return download(blob_url, headers, bn, true);
std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer;

return download(blob_url, bn, true, headers);
}

std::string basename(const std::string & path) {
Expand Down Expand Up @@ -653,22 +696,18 @@ class LlamaData {
return ret;
}

const std::string bn = basename(model_);
const std::vector<std::string> headers = { "--header",
"Accept: application/vnd.docker.distribution.manifest.v2+json" };
const std::string bn = basename(model_);
if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://")) {
rm_until_substring(model_, "://");
ret = huggingface_dl(model_, headers, bn);
ret = huggingface_dl(model_, bn);
} else if (string_starts_with(model_, "hf.co/")) {
rm_until_substring(model_, "hf.co/");
ret = huggingface_dl(model_, headers, bn);
} else if (string_starts_with(model_, "ollama://")) {
rm_until_substring(model_, "://");
ret = ollama_dl(model_, headers, bn);
ret = huggingface_dl(model_, bn);
} else if (string_starts_with(model_, "https://")) {
ret = download(model_, headers, bn, true);
} else {
ret = ollama_dl(model_, headers, bn);
ret = download(model_, bn, true);
} else { // ollama:// or nothing
rm_until_substring(model_, "://");
ret = ollama_dl(model_, bn);
}

model_ = bn;
Expand Down
Loading