Skip to content

Commit

Permalink
imatrix : fix wname for mul_mat_id ops (ggerganov#6271)
Browse files Browse the repository at this point in the history
* imatrix : fix wname for mul_mat_id ops

* also filter tensor names in mul_mat_id ops

---------

Co-authored-by: slaren <slarengh@gmail.com>
  • Loading branch information
2 people authored and hodlen committed Apr 1, 2024
1 parent 3370457 commit adc1fc5
Showing 1 changed file with 21 additions and 18 deletions.
39 changes: 21 additions & 18 deletions examples/imatrix/imatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,29 +50,31 @@ class IMatrixCollector {
void keep_imatrix(int ncall) const;
};

// remove any prefix and suffixes from the name
// CUDA0#blk.0.attn_k.weight#0 => blk.0.attn_k.weight
static std::string filter_tensor_name(const char * name) {
std::string wname;
const char * p = strchr(name, '#');
if (p != NULL) {
p = p + 1;
const char * q = strchr(p, '#');
if (q != NULL) {
wname = std::string(p, q - p);
} else {
wname = p;
}
} else {
wname = name;
}
return wname;
}

bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data) {
GGML_UNUSED(user_data);

const struct ggml_tensor * src0 = t->src[0];
const struct ggml_tensor * src1 = t->src[1];

std::string wname;
{
// remove any prefix and suffixes from the name
// CUDA0#blk.0.attn_k.weight#0 => blk.0.attn_k.weight
const char * p = strchr(src0->name, '#');
if (p != NULL) {
p = p + 1;
const char * q = strchr(p, '#');
if (q != NULL) {
wname = std::string(p, q - p);
} else {
wname = p;
}
} else {
wname = src0->name;
}
}
std::string wname = filter_tensor_name(src0->name);

// when ask is true, the scheduler wants to know if we are interested in data from this tensor
// if we return true, a follow-up call will be made with ask=false in which we can do the actual collection
Expand Down Expand Up @@ -112,6 +114,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
// this is necessary to guarantee equal number of "ncall" for each tensor
for (int ex = 0; ex < n_as; ++ex) {
src0 = t->src[2 + ex];
wname = filter_tensor_name(src0->name);
auto& e = m_stats[wname];
if (e.values.empty()) {
e.values.resize(src1->ne[0], 0);
Expand Down

0 comments on commit adc1fc5

Please sign in to comment.