ggml_extend.hpp: 29 changes (24 additions, 5 deletions)
@@ -188,6 +188,11 @@ __STATIC_INLINE__ ggml_fp16_t ggml_ext_tensor_get_f16(const ggml_tensor* tensor,
     return *(ggml_fp16_t*)((char*)(tensor->data) + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0]);
 }
 
+__STATIC_INLINE__ ggml_bf16_t ggml_ext_tensor_get_bf16(const ggml_tensor* tensor, int i0, int i1 = 0, int i2 = 0, int i3 = 0) {
+    GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
+    return *(ggml_bf16_t*)((char*)(tensor->data) + i3 * tensor->nb[3] + i2 * tensor->nb[2] + i1 * tensor->nb[1] + i0 * tensor->nb[0]);
+}
+
 __STATIC_INLINE__ float sd_image_get_f32(sd_image_t image, int iw, int ih, int ic, bool scale = true) {
     float value = *(image.data + ih * image.width * image.channel + iw * image.channel + ic);
     if (scale) {
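
Note: bf16 stores the top 16 bits of an IEEE-754 float32 (1 sign bit, 8 exponent bits, 7 mantissa bits), so the new accessor returns the raw ggml_bf16_t payload, which must still be widened before doing float arithmetic. A minimal usage sketch, not part of the patch; it assumes `t` is a contiguous GGML_TYPE_BF16 tensor resident in host memory:

    // Read element (i0=2, i1=1) and widen it to float via ggml's helper.
    ggml_bf16_t raw = ggml_ext_tensor_get_bf16(t, 2, 1);
    float f32_val   = ggml_bf16_to_fp32(raw);
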
@@ -231,6 +236,8 @@ __STATIC_INLINE__ void print_ggml_tensor(struct ggml_tensor* tensor, bool shape_
             printf(" [%d, %d, %d, %d] = %f\n", i3, i2, i1, i0, ggml_ext_tensor_get_f32(tensor, i0, i1, i2, i3));
         } else if (tensor->type == GGML_TYPE_F16) {
             printf(" [%d, %d, %d, %d] = %f\n", i3, i2, i1, i0, ggml_fp16_to_fp32(ggml_ext_tensor_get_f16(tensor, i0, i1, i2, i3)));
+        } else if (tensor->type == GGML_TYPE_BF16) {
+            printf(" [%d, %d, %d, %d] = %f\n", i3, i2, i1, i0, ggml_bf16_to_fp32(ggml_ext_tensor_get_bf16(tensor, i0, i1, i2, i3)));
         } else if (tensor->type == GGML_TYPE_I32) {
             printf(" [%d, %d, %d, %d] = %d\n", i3, i2, i1, i0, ggml_ext_tensor_get_i32(tensor, i0, i1, i2, i3));
         }
@@ -1344,14 +1351,18 @@ __STATIC_INLINE__ void ggml_ext_backend_tensor_get_and_sync(ggml_backend_t backend,
 }
 
 __STATIC_INLINE__ float ggml_ext_backend_tensor_get_f32(ggml_tensor* tensor) {
-    GGML_ASSERT(tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_I32);
+    GGML_ASSERT(tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16 || tensor->type == GGML_TYPE_I32);
     float value;
     if (tensor->type == GGML_TYPE_F32) {
         ggml_backend_tensor_get(tensor, &value, 0, sizeof(value));
     } else if (tensor->type == GGML_TYPE_F16) {
         ggml_fp16_t f16_value;
         ggml_backend_tensor_get(tensor, &f16_value, 0, sizeof(f16_value));
         value = ggml_fp16_to_fp32(f16_value);
+    } else if (tensor->type == GGML_TYPE_BF16) {
+        ggml_bf16_t bf16_value;
+        ggml_backend_tensor_get(tensor, &bf16_value, 0, sizeof(bf16_value));
+        value = ggml_bf16_to_fp32(bf16_value);
     } else { // GGML_TYPE_I32
         int int32_value;
         ggml_backend_tensor_get(tensor, &int32_value, 0, sizeof(int32_value));
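
Note: the BF16 branch mirrors the F16 one: copy the 2-byte payload out of backend (possibly device) memory with ggml_backend_tensor_get, then widen on the host. For intuition, the widening amounts to a 16-bit shift, since bf16 is the high half of a float32. A self-contained illustration of that bit-level view (my sketch of the same conversion; ggml_bf16_to_fp32 is implemented equivalently via a union):

    #include <cstdint>
    #include <cstring>

    // Place the 16 stored bf16 bits in the top half of a float32 word.
    static float bf16_bits_to_f32(uint16_t bits) {
        uint32_t u = (uint32_t)bits << 16;
        float f;
        std::memcpy(&f, &u, sizeof(f));  // well-defined type pun
        return f;
    }
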
@@ -2011,11 +2022,19 @@ class Linear : public UnaryBlock {
 };
 
 __STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
-    std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
-    if (allow_types.find(wtype) != allow_types.end()) {
-        return true;
-    }
-    return false;
+    switch (wtype) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_BF16:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q4_0:
+            return true;
+        default:
+            return false;
+    }
 }
 
 class Embedding : public UnaryBlock {
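
Note: support_get_rows reports which weight types the get_rows op can index directly; F32 and BF16 are newly allowed, and the per-call std::set construction is replaced by a plain switch. A hypothetical caller-side sketch of how such a predicate is typically used to pick a fallback path (embed_tokens is an invented name for illustration; ggml_cast and ggml_get_rows are existing ggml ops):

    // Fall back to an f32 copy of the table when get_rows cannot index wtype.
    static struct ggml_tensor* embed_tokens(struct ggml_context* ctx,
                                            struct ggml_tensor* weight,      // [n_embd, n_vocab]
                                            struct ggml_tensor* token_ids) { // GGML_TYPE_I32 indices
        if (!support_get_rows(weight->type)) {
            weight = ggml_cast(ctx, weight, GGML_TYPE_F32);  // convert first
        }
        return ggml_get_rows(ctx, weight, token_ids);        // gather rows by token id
    }
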