Skip to content
77 changes: 36 additions & 41 deletions ggml/src/ggml-hexagon/htp/rope-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,15 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
return (1 - MIN(1, MAX(0, y)));
}

static void rope_cache_init(const float theta_base,
float freq_scale,
const float * freq_factors,
float * corr_dims,
uint32_t ne0,
float ext_factor,
float mscale,
float * cache,
float theta_scale) {
static void rope_cache_init(const float theta_base,
const float freq_scale,
const float * freq_factors,
float * corr_dims,
const uint32_t ne0,
const float ext_factor,
const float mscale,
float * cache,
const float theta_scale) {
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
float theta = theta_base;

Expand All @@ -92,18 +92,19 @@ static void rope_cache_init(const float theta_base,

// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta2 = theta_interp;
float theta_final = theta_interp;
float mscale_final = mscale;

if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
theta2 = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;

// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}

cache[i0 + 0] = cosf(theta2) * mscale;
cache[i0 + 1] = sinf(theta2) * mscale;
cache[i0 + 0] = cosf(theta_final) * mscale_final;
cache[i0 + 1] = sinf(theta_final) * mscale_final;

theta *= theta_scale;
}
Expand Down Expand Up @@ -151,9 +152,9 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
}

static void hvx_calc_rope_neox_f32(const float * restrict src0,
float * restrict dst,
const int num_elems,
const float * restrict theta_cache) {
float * restrict dst,
const int num_elems,
const float * restrict theta_cache) {
// for (int i = 0; i < num_elems; i += 2) {
//const float cos_theta = theta_cache[i + 0];
//const float sin_theta = theta_cache[i + 1];
Expand Down Expand Up @@ -192,7 +193,7 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);

*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);

src0_curr += VLEN;
Expand Down Expand Up @@ -259,16 +260,16 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
const uint32_t ir1,
int nth,
int ith,
int opt_path) {
const int opt_path) {
struct htp_ops_context * octx = rope_ctx->octx;

const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;

const int32_t mode = rope_ctx->mode;
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
const int32_t mode = rope_ctx->mode;
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;

htp_rope_preamble;

Expand All @@ -281,23 +282,16 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
freq_factors = (const float *) src2->data;
}

int ir = 0;

const uint32_t i1_end = MIN(ir1, ne1);
const int32_t half_dims = rope_ctx->n_dims / 2;
for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch
for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len
const int32_t p = pos[i2];

rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor,
rope_ctx->attn_factor, wp0, rope_ctx->theta_scale);

for (uint32_t i1 = 0; i1 < ne1; i1++) { // attn-heads
if (ir++ < ir0) {
continue;
}
if (ir > ir1) {
break;
}
Copy link
Contributor Author

@chraac chraac Nov 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those two inner if statements can be merged into the for loop's condition.


for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads
const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01);
float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1);

Expand All @@ -310,17 +304,20 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
} else {
hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0);
}

src_loc += rope_ctx->n_dims;
dst_data_loc += rope_ctx->n_dims;
} else {
for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) {
const float cos_theta = wp0[i0 + 0];
const float sin_theta = wp0[i0 + 1];

if (is_neox) {
const float x0 = src_loc[0];
const float x1 = src_loc[rope_ctx->n_dims/2];
const float x1 = src_loc[half_dims];

dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta;

src_loc += 1;
dst_data_loc += 1;
Expand All @@ -335,15 +332,13 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
dst_data_loc += 2;
}
}
}

for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
dst_data_loc[0] = src_loc[0];
dst_data_loc[1] = src_loc[1];

src_loc += 2;
dst_data_loc += 2;
src_loc += (is_neox ? half_dims : 0);
dst_data_loc += (is_neox ? half_dims : 0);
}

// TODO: use simd to speed up the remaining elements copy
memcpy(dst_data_loc, src_loc, (ne0 - rope_ctx->n_dims) * sizeof(float));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: did we have simd acceleration in the memcpy?

}
}
}
Expand Down
Loading