Skip to content

Commit

Permalink
imatrix : keep intermediate imatrix results (ggerganov#5077)
Browse files Browse the repository at this point in the history
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
  • Loading branch information
ikawrakow and Kawrakow authored Jan 22, 2024
1 parent d6bd4d4 commit 15bceec
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion examples/imatrix/imatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ struct StatParams {
std::string ofile = "imatrix.dat";
int n_output_frequency = 10;
int verbosity = 1;
int keep_every = 0;
bool collect_output_weight = false;
};

Expand All @@ -42,6 +43,9 @@ class IMatrixCollector {
int m_last_call = 0;
std::vector<float> m_src1_data;
std::vector<int> m_ids; // the expert ids from ggml_mul_mat_id
//
void save_imatrix(const char * file_name) const;
void keep_imatrix(int ncall) const;
};

bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
Expand Down Expand Up @@ -117,6 +121,9 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
if (m_last_call % m_params.n_output_frequency == 0) {
save_imatrix();
}
if (m_params.keep_every > 0 && m_last_call%m_params.keep_every == 0) {
keep_imatrix(m_last_call);
}
}
}
} else {
Expand All @@ -143,14 +150,28 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
if (m_last_call % m_params.n_output_frequency == 0) {
save_imatrix();
}
if (m_params.keep_every > 0 && m_last_call%m_params.keep_every == 0) {
keep_imatrix(m_last_call);
}
}
}

return true;
}

void IMatrixCollector::save_imatrix() const {
const char * fname = m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str();
save_imatrix(m_params.ofile.empty() ? "imatrix.dat" : m_params.ofile.c_str());
}

void IMatrixCollector::keep_imatrix(int ncall) const {
auto file_name = m_params.ofile;
if (file_name.empty()) file_name = "imatrix.dat";
file_name += ".at_";
file_name += std::to_string(ncall);
save_imatrix(file_name.c_str());
}

void IMatrixCollector::save_imatrix(const char * fname) const {
std::ofstream out(fname, std::ios::binary);
int n_entries = m_stats.size();
out.write((const char*)&n_entries, sizeof(n_entries));
Expand Down Expand Up @@ -400,6 +421,8 @@ int main(int argc, char ** argv) {
sparams.verbosity = std::stoi(argv[++iarg]);
} else if (arg == "--no-ppl") {
compute_ppl = false;
} else if (arg == "--keep-imatrix") {
sparams.keep_every = std::stoi(argv[++iarg]);
} else {
args.push_back(argv[iarg]);
}
Expand Down

0 comments on commit 15bceec

Please sign in to comment.