Skip to content

Apply LoRA during model conversion #709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions clip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

/*================================================== CLIPTokenizer ===================================================*/

std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
static inline std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
std::regex re("<lora:([^:]+):([^>]+)>");
std::smatch matches;
std::unordered_map<std::string, float> filename2multiplier;
Expand All @@ -31,7 +31,7 @@ std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remov
return std::make_pair(filename2multiplier, text);
}

std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
static inline std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
std::set<int> byte_set;
for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
Expand Down
4 changes: 2 additions & 2 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ int main(int argc, const char* argv[]) {
}

if (params.mode == CONVERT) {
bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype);
bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype, params.prompt.c_str(), params.lora_model_dir.c_str());
if (!success) {
fprintf(stderr,
"convert '%s'/'%s' to '%s' failed\n",
Expand Down Expand Up @@ -1218,4 +1218,4 @@ int main(int argc, const char* argv[]) {
free(input_image_buffer);

return 0;
}
}
2 changes: 1 addition & 1 deletion lora.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ struct LoraModel : public GGMLRunner {
std::set<std::string> applied_lora_tensors;
for (auto it : model_tensors) {
std::string k_tensor = it.first;
struct ggml_tensor* weight = model_tensors[it.first];
struct ggml_tensor* weight = it.second;

std::vector<std::string> keys = to_lora_keys(k_tensor, version);
if (keys.size() == 0)
Expand Down
80 changes: 71 additions & 9 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "stable-diffusion.h"
#include "util.h"
#include "vocab.hpp"
#include "clip.hpp"
#include "lora.hpp"

#include "ggml-alloc.h"
#include "ggml-backend.h"
Expand Down Expand Up @@ -1977,7 +1979,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
return false;
}

bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) {
bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::unordered_map<std::string, float>& loras) {
auto backend = ggml_backend_cpu_init();
size_t mem_size = 1 * 1024 * 1024; // for padding
mem_size += tensor_storages.size() * ggml_tensor_overhead();
Expand All @@ -1987,6 +1989,9 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type

gguf_context* gguf_ctx = gguf_init_empty();

// lora lookup table
std::map<std::string, struct ggml_tensor*> tensors;

auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
const std::string& name = tensor_storage.name;

Expand All @@ -2012,19 +2017,44 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type

gguf_add_tensor(gguf_ctx, tensor);

tensors[name] = tensor;

return true;
};

bool success = load_tensors(on_new_tensor_cb, backend);
ggml_backend_free(backend);
if (!load_tensors(on_new_tensor_cb, backend)) {
ggml_backend_free(backend);
ggml_free(ggml_ctx);
gguf_free(gguf_ctx);
return false;
}

LOG_INFO("load tensors done");
LOG_INFO("trying to save tensors to %s", file_path.c_str());
if (success) {
gguf_write_to_file(gguf_ctx, file_path.c_str(), false);

for (const auto& [lora_path, lora_scale] : loras) {
LoraModel lora(backend, lora_path);
if (!lora.load_from_file()) {
LOG_WARN("load lora tensors from '%s' failed", lora_path.c_str());
ggml_backend_free(backend);
ggml_free(ggml_ctx);
gguf_free(gguf_ctx);
return false;
}

lora.multiplier = lora_scale;
lora.apply(tensors, get_sd_version(), 4);
lora.free_params_buffer();
LOG_INFO("applied '%s':%f", lora_path.c_str(), lora_scale);
}

ggml_backend_free(backend);

LOG_INFO("trying to save tensors to %s", file_path.c_str());
gguf_write_to_file(gguf_ctx, file_path.c_str(), false);

ggml_free(ggml_ctx);
gguf_free(gguf_ctx);
return success;
return true;
}

int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) {
Expand All @@ -2051,7 +2081,7 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
return mem_size;
}

bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) {
bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type, const char* prompt, const char* lora_model_dir) {
ModelLoader model_loader;

if (!model_loader.init_from_file(input_path)) {
Expand All @@ -2065,6 +2095,38 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
return false;
}
}
bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);

// process prompt for loras
std::unordered_map<std::string, float> loras;
if (prompt != nullptr && lora_model_dir != nullptr) {
auto result_pair = extract_and_remove_lora(prompt);
std::unordered_map<std::string, float> extracted_loras = result_pair.first;

for (auto& kv : extracted_loras) {
LOG_INFO("lora %s:%.2f", kv.first.c_str(), kv.second);

// save_to_gguf_file expects file paths
std::string st_file_path = path_join(lora_model_dir, kv.first + ".safetensors");
std::string ckpt_file_path = path_join(lora_model_dir, kv.first + ".ckpt");
std::string file_path;
if (file_exists(st_file_path)) {
file_path = st_file_path;
} else if (file_exists(ckpt_file_path)) {
file_path = ckpt_file_path;
} else {
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), kv.first.c_str());
continue;
}

LOG_INFO("found at '%s'", file_path.c_str());
loras[file_path] = kv.second;
}

if (result_pair.second != "") {
LOG_WARN("unused prompt after lora extraction: '%s'", result_pair.second.c_str());
}
}

bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, loras);
return success;
}
2 changes: 1 addition & 1 deletion model.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class ModelLoader {
ggml_backend_t backend,
std::set<std::string> ignore_tensors = {});

bool save_to_gguf_file(const std::string& file_path, ggml_type type);
bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::unordered_map<std::string, float>& loras = {});
bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type);
int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT);
~ModelLoader() = default;
Expand Down
2 changes: 1 addition & 1 deletion stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);

SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor);

SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type);
SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type, const char* prompt, const char* lora_model_dir);

SD_API uint8_t* preprocess_canny(uint8_t* img,
int width,
Expand Down
Loading