Skip to content

Commit

Permalink
ggml : fix YARN + add tests + add asserts (#7617)
Browse files Browse the repository at this point in the history
* tests : add rope tests

ggml-ci

* ggml : fixes (hopefully)

ggml-ci

* tests : add non-cont tests

ggml-ci

* cuda : add asserts for rope/norm + fix DS2

ggml-ci

* ggml : assert contiguousness

* tests : reduce RoPE tests

ggml-ci
  • Loading branch information
ggerganov authored May 29, 2024
1 parent cce3dcf commit fb76ec3
Show file tree
Hide file tree
Showing 12 changed files with 168 additions and 105 deletions.
4 changes: 3 additions & 1 deletion ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1870,7 +1870,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
}
}
#else
if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
CUBLAS_CHECK(
Expand Down Expand Up @@ -2886,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
return true;
case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
Expand Down
6 changes: 6 additions & 0 deletions ggml-cuda/norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

Expand All @@ -188,6 +190,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

Expand All @@ -202,6 +206,8 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);

Expand Down
18 changes: 8 additions & 10 deletions ggml-cuda/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ static __global__ void rope(
template<typename T, bool has_pos, bool has_freq_facs>
static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

Expand All @@ -85,15 +85,13 @@ static __global__ void rope_neox(
const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;

float cur_rot = inv_ndims * ic - ib;

const int p = has_pos ? pos[i2] : 0;
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;

const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;

float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
Expand Down Expand Up @@ -174,30 +172,29 @@ static void rope_neox_cuda(
const dim3 block_nums(nrows, num_blocks_x, 1);

const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.0f / n_dims;

if (pos == nullptr) {
if (freq_factors == nullptr) {
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
theta_scale, freq_factors
);
} else {
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
theta_scale, freq_factors
);
}
} else {
if (freq_factors == nullptr) {
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
theta_scale, freq_factors
);
} else {
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims, freq_factors
theta_scale, freq_factors
);
}
}
Expand Down Expand Up @@ -254,6 +251,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
Expand Down
4 changes: 3 additions & 1 deletion ggml-kompute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,9 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
{
GGML_ASSERT(ne00 == ne10);

// TODO: assert that dim2 and dim3 are contiguous
ggml_is_contiguous_2(src0);
ggml_is_contiguous_2(src1);

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);

Expand Down
8 changes: 7 additions & 1 deletion ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -1519,7 +1519,9 @@ static enum ggml_status ggml_metal_graph_compute(
{
GGML_ASSERT(ne00 == ne10);

// TODO: assert that dim2 and dim3 are contiguous
ggml_is_contiguous_2(src0);
ggml_is_contiguous_2(src1);

GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);

Expand Down Expand Up @@ -2187,6 +2189,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_OP_RMS_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(src0));

float eps;
memcpy(&eps, dst->op_params, sizeof(float));
Expand Down Expand Up @@ -2214,6 +2217,7 @@ static enum ggml_status ggml_metal_graph_compute(
case GGML_OP_GROUP_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
GGML_ASSERT(ggml_is_contiguous(src0));

//float eps;
//memcpy(&eps, dst->op_params, sizeof(float));
Expand Down Expand Up @@ -2247,6 +2251,8 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_NORM:
{
GGML_ASSERT(ggml_is_contiguous_1(src0));

float eps;
memcpy(&eps, dst->op_params, sizeof(float));

Expand Down
16 changes: 6 additions & 10 deletions ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -1767,13 +1767,13 @@ kernel void kernel_rope(

const int64_t p = pos[i2];

const float theta_0 = (float)p;
const float theta_base = (float)p;
const float inv_ndims = -1.f/n_dims;

if (!is_neox) {
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
const float theta = theta_base * pow(freq_base, inv_ndims*i0);

const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

Expand All @@ -1789,18 +1789,14 @@ kernel void kernel_rope(
} else {
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
if (ic < n_dims) {
const int64_t ib = 0;
const int64_t i0 = ic/2;

// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;

const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
const float theta = theta_base * pow(freq_base, inv_ndims*ic);

float cos_theta, sin_theta;
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);

const int64_t i0 = ib*n_dims + ic/2;
rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);

device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
Expand Down
2 changes: 1 addition & 1 deletion ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15183,7 +15183,7 @@ static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;

if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
*g_sycl_handles[g_main_device], oneapi::mkl::transpose::trans,
Expand Down
74 changes: 36 additions & 38 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -3221,7 +3221,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}

static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
// Alias of ggml_is_contiguous(): forwards the tensor and returns that result unchanged.
GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
    const bool is_contig = ggml_is_contiguous(tensor);

    return is_contig;
}

GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

return
Expand All @@ -3230,6 +3234,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}

// Returns true when both of the following hold:
//   - the dim-0 stride nb[0] equals ggml_type_size() of the tensor's type
//   - the dim-3 stride nb[3] equals nb[2]*ne[2]
// The strides nb[1] and nb[2] themselves are not constrained by this check.
GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

    // the first stride must match the type size
    if (tensor->nb[0] != ggml_type_size(tensor->type)) {
        return false;
    }

    // the last stride must follow directly from dim 2
    return tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}

GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");

Expand Down Expand Up @@ -11420,8 +11432,8 @@ static void ggml_compute_forward_gelu_f32(

const struct ggml_tensor * src0 = dst->src[0];

GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
Expand Down Expand Up @@ -11483,8 +11495,8 @@ static void ggml_compute_forward_gelu_quick_f32(

const struct ggml_tensor * src0 = dst->src[0];

GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
Expand Down Expand Up @@ -11546,8 +11558,8 @@ static void ggml_compute_forward_silu_f32(

const struct ggml_tensor * src0 = dst->src[0];

GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));

if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
Expand Down Expand Up @@ -11658,9 +11670,9 @@ static void ggml_compute_forward_silu_back_f32(
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * grad = dst->src[1];

GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_is_contiguous_1(grad));
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src0, grad));

Expand Down Expand Up @@ -14358,7 +14370,7 @@ static void ggml_compute_forward_rope_f32(
int ir = 0;

const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;

float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

Expand Down Expand Up @@ -14407,7 +14419,7 @@ static void ggml_compute_forward_rope_f32(
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta) * sin_sign;

theta_base *= theta_scale;
theta_base *= theta_scale;
block_theta *= theta_scale;

const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
Expand Down Expand Up @@ -14442,29 +14454,22 @@ static void ggml_compute_forward_rope_f32(
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;
const int64_t i0 = ic/2;

// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;

float cos_theta, sin_theta;
rope_yarn(
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;

sin_theta *= sin_sign;
theta_base *= theta_scale;

const int64_t i0 = ib*n_dims + ic/2;

const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

Expand Down Expand Up @@ -14543,7 +14548,7 @@ static void ggml_compute_forward_rope_f16(
int ir = 0;

const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;

float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

Expand Down Expand Up @@ -14592,7 +14597,7 @@ static void ggml_compute_forward_rope_f16(
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta) * sin_sign;

theta_base *= theta_scale;
theta_base *= theta_scale;
block_theta *= theta_scale;

const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
Expand Down Expand Up @@ -14623,29 +14628,22 @@ static void ggml_compute_forward_rope_f16(
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;
const int64_t i0 = ic/2;

// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;

float cos_theta, sin_theta;
rope_yarn(
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;

sin_theta *= sin_sign;
theta_base *= theta_scale;

const int64_t i0 = ib*n_dims + ic/2;

const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

Expand Down
Loading

0 comments on commit fb76ec3

Please sign in to comment.