Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,14 +291,16 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
hf_tag = "default";
}

const bool offline = params.offline;
std::string model_endpoint = get_model_endpoint();
auto preset_url = model_endpoint + hf_repo + "/resolve/main/preset.ini";

// prepare local path for caching
auto preset_fname = clean_file_name(hf_repo + "_preset.ini");
auto preset_path = fs_get_cache_file(preset_fname);
const int status = common_download_file_single(preset_url, preset_path, params.hf_token, offline);
common_download_opts opts;
opts.bearer_token = params.hf_token;
opts.offline = params.offline;
const int status = common_download_file_single(preset_url, preset_path, opts);
const bool has_preset = status >= 200 && status < 400;

// remote preset is optional, so we don't error out if not found
Expand Down Expand Up @@ -341,10 +343,10 @@ static handle_model_result common_params_handle_model(struct common_params_model
model.hf_file = model.path;
model.path = "";
}
common_download_model_opts opts;
opts.download_mmproj = true;
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
auto download_result = common_download_model(model, bearer_token, opts);
auto download_result = common_download_model(model, opts, true);

if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from Hugging Face\n");
Expand All @@ -365,9 +367,10 @@ static handle_model_result common_params_handle_model(struct common_params_model
model.path = fs_get_cache_file(string_split<std::string>(f, '/').back());
}

common_download_model_opts opts;
common_download_opts opts;
opts.bearer_token = bearer_token;
opts.offline = offline;
auto download_result = common_download_model(model, bearer_token, opts);
auto download_result = common_download_model(model, opts);
if (download_result.model_path.empty()) {
LOG_ERR("error: failed to download model from %s\n", model.url.c_str());
exit(1);
Expand Down
146 changes: 89 additions & 57 deletions common/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ std::pair<std::string, std::string> common_download_split_repo_tag(const std::st
return {hf_repo, tag};
}

class ProgressBar {
class ProgressBar : public common_download_callback {
static inline std::mutex mutex;
static inline std::map<const ProgressBar *, int> lines;
static inline int max_line = 0;
Expand All @@ -138,7 +138,11 @@ class ProgressBar {
}

public:
ProgressBar(const std::string & url = "") : filename(url) {
ProgressBar() = default;

void on_start(const common_download_progress & p) override {
filename = p.url;

if (auto pos = filename.rfind('/'); pos != std::string::npos) {
filename = filename.substr(pos + 1);
}
Expand All @@ -156,13 +160,13 @@ class ProgressBar {
}
}

~ProgressBar() {
void on_done(const common_download_progress &, bool) override {
std::lock_guard<std::mutex> lock(mutex);
cleanup(this);
}

void update(size_t current, size_t total) {
if (!total || !is_output_a_tty()) {
void on_update(const common_download_progress & p) override {
if (!p.total || !is_output_a_tty()) {
return;
}

Expand All @@ -175,8 +179,8 @@ class ProgressBar {
int lines_up = max_line - lines[this];

size_t bar = (55 - len) * 2;
size_t pct = (100 * current) / total;
size_t pos = (bar * current) / total;
size_t pct = (100 * p.downloaded) / p.total;
size_t pos = (bar * p.downloaded) / p.total;

if (lines_up > 0) {
std::cout << "\033[" << lines_up << "A";
Expand All @@ -193,7 +197,7 @@ class ProgressBar {
}
std::cout << '\r' << std::flush;

if (current == total) {
if (p.downloaded == p.total) {
cleanup(this);
}
}
Expand All @@ -206,38 +210,36 @@ static bool common_pull_file(httplib::Client & cli,
const std::string & resolve_path,
const std::string & path_tmp,
bool supports_ranges,
size_t existing_size,
size_t & total_size) {
common_download_progress & p,
common_download_callback * callback) {
std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
if (!ofs.is_open()) {
LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
return false;
}

httplib::Headers headers;
if (supports_ranges && existing_size > 0) {
headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
if (supports_ranges && p.downloaded > 0) {
headers.emplace("Range", "bytes=" + std::to_string(p.downloaded) + "-");
}

const char * func = __func__; // avoid __func__ inside a lambda
size_t downloaded = existing_size;
size_t progress_step = 0;
ProgressBar bar(resolve_path);

auto res = cli.Get(resolve_path, headers,
[&](const httplib::Response &response) {
if (existing_size > 0 && response.status != 206) {
if (p.downloaded > 0 && response.status != 206) {
LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
return false;
}
if (existing_size == 0 && response.status != 200) {
if (p.downloaded == 0 && response.status != 200) {
LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
return false;
}
if (total_size == 0 && response.has_header("Content-Length")) {
if (p.total == 0 && response.has_header("Content-Length")) {
try {
size_t content_length = std::stoull(response.get_header_value("Content-Length"));
total_size = existing_size + content_length;
p.total = p.downloaded + content_length;
} catch (const std::exception &e) {
LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
}
Expand All @@ -250,11 +252,13 @@ static bool common_pull_file(httplib::Client & cli,
LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
return false;
}
downloaded += len;
p.downloaded += len;
progress_step += len;

if (progress_step >= total_size / 1000 || downloaded == total_size) {
bar.update(downloaded, total_size);
if (progress_step >= p.total / 1000 || p.downloaded == p.total) {
if (callback) {
callback->on_update(p);
}
progress_step = 0;
}
return true;
Expand All @@ -275,11 +279,10 @@ static bool common_pull_file(httplib::Client & cli,

// download one single file from remote URL to local path
// returns status code or -1 on error
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token,
const common_header_list & custom_headers,
bool skip_etag = false) {
static int common_download_file_single_online(const std::string & url,
const std::string & path,
const common_download_opts & opts,
bool skip_etag) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;

Expand All @@ -293,14 +296,14 @@ static int common_download_file_single_online(const std::string & url,
auto [cli, parts] = common_http_client(url);

httplib::Headers headers;
for (const auto & h : custom_headers) {
for (const auto & h : opts.headers) {
headers.emplace(h.first, h.second);
}
if (headers.find("User-Agent") == headers.end()) {
headers.emplace("User-Agent", "llama-cpp/" + build_info);
}
if (!bearer_token.empty()) {
headers.emplace("Authorization", "Bearer " + bearer_token);
if (!opts.bearer_token.empty()) {
headers.emplace("Authorization", "Bearer " + opts.bearer_token);
}
cli.set_default_headers(headers);

Expand All @@ -326,10 +329,11 @@ static int common_download_file_single_online(const std::string & url,
etag = head->get_header_value("ETag");
}

size_t total_size = 0;
common_download_progress p;
p.url = url;
if (head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
p.total = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
Expand Down Expand Up @@ -357,13 +361,17 @@ static int common_download_file_single_online(const std::string & url,

{ // silent
std::error_code ec;
std::filesystem::path p(path);
std::filesystem::create_directories(p.parent_path(), ec);
std::filesystem::create_directories(std::filesystem::path(path).parent_path(), ec);
}

bool success = false;
const std::string path_temporary = path + ".downloadInProgress";
int delay = retry_delay_seconds;

if (opts.callback) {
opts.callback->on_start(p);
}

for (int i = 0; i < max_attempts; ++i) {
if (i) {
LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
Expand All @@ -378,28 +386,38 @@ static int common_download_file_single_online(const std::string & url,
existing_size = std::filesystem::file_size(path_temporary);
} else if (remove(path_temporary.c_str()) != 0) {
LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
return -1;
break;
}
}

p.downloaded = existing_size;

LOG_DBG("%s: downloading from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(),
path_temporary.c_str(), etag.c_str());

if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, p, opts.callback)) {
if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
return -1;
break;
}
if (!etag.empty() && !skip_etag) {
write_etag(path, etag);
}
return head->status;
success = true;
break;
}
}

LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached
if (opts.callback) {
opts.callback->on_done(p, success);
}
if (!success) {
LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
return -1; // max attempts reached
}

return head->status;
}

std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
Expand Down Expand Up @@ -438,12 +456,15 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string

int common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline,
const common_header_list & headers,
const common_download_opts & opts,
bool skip_etag) {
if (!offline) {
return common_download_file_single_online(url, path, bearer_token, headers, skip_etag);
if (!opts.offline) {
ProgressBar tty_cb;
common_download_opts online_opts = opts;
if (!online_opts.callback) {
online_opts.callback = &tty_cb;
}
return common_download_file_single_online(url, path, online_opts, skip_etag);
}

if (!std::filesystem::exists(path)) {
Expand All @@ -452,6 +473,16 @@ int common_download_file_single(const std::string & url,
}

LOG_DBG("%s: using cached file (offline mode): %s\n", __func__, path.c_str());

// notify the callback that the file was cached
if (opts.callback) {
common_download_progress p;
p.url = url;
p.cached = true;
opts.callback->on_start(p);
opts.callback->on_done(p, true);
}

return 304; // Not Modified - fake cached response
}

Expand Down Expand Up @@ -631,16 +662,16 @@ struct hf_plan {
hf_cache::hf_file mmproj;
};

static hf_plan get_hf_plan(const common_params_model & model,
const std::string & token,
const common_download_model_opts & opts) {
static hf_plan get_hf_plan(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
hf_plan plan;
hf_cache::hf_files all;

auto [repo, tag] = common_download_split_repo_tag(model.hf_repo);

if (!opts.offline) {
all = hf_cache::get_repo_files(repo, token);
all = hf_cache::get_repo_files(repo, opts.bearer_token);
}
if (all.empty()) {
all = hf_cache::get_cached_files(repo);
Expand Down Expand Up @@ -675,7 +706,7 @@ static hf_plan get_hf_plan(const common_params_model & model,
plan.primary = primary;
plan.model_files = get_split_files(all, primary);

if (opts.download_mmproj) {
if (download_mmproj) {
plan.mmproj = find_best_mmproj(all, primary.path);
}

Expand Down Expand Up @@ -710,18 +741,17 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode
return tasks;
}

common_download_model_result common_download_model(const common_params_model & model,
const std::string & bearer_token,
const common_download_model_opts & opts,
const common_header_list & headers) {
common_download_model_result common_download_model(const common_params_model & model,
const common_download_opts & opts,
bool download_mmproj) {
common_download_model_result result;
std::vector<download_task> tasks;
hf_plan hf;

bool is_hf = !model.hf_repo.empty();

if (is_hf) {
hf = get_hf_plan(model, bearer_token, opts);
hf = get_hf_plan(model, opts, download_mmproj);
for (const auto & f : hf.model_files) {
tasks.push_back({f.url, f.local_path});
}
Expand All @@ -742,8 +772,8 @@ common_download_model_result common_download_model(const common_params_model
std::vector<std::future<bool>> futures;
for (const auto & task : tasks) {
futures.push_back(std::async(std::launch::async,
[&task, &bearer_token, offline = opts.offline, &headers, is_hf]() {
int status = common_download_file_single(task.url, task.path, bearer_token, offline, headers, is_hf);
[&task, &opts, is_hf]() {
int status = common_download_file_single(task.url, task.path, opts, is_hf);
return is_http_status_ok(status);
}
));
Expand Down Expand Up @@ -879,7 +909,9 @@ std::string common_docker_resolve_model(const std::string & docker) {
std::string local_path = fs_get_cache_file(model_filename);

const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
common_download_opts opts;
opts.bearer_token = token;
const int http_status = common_download_file_single(blob_url, local_path, opts);
if (!is_http_status_ok(http_status)) {
throw std::runtime_error("Failed to download Docker Model");
}
Expand Down
Loading
Loading