Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix test failure
  • Loading branch information
chraac committed Nov 27, 2025
commit 407b408981abc4180f64b99f89b003168bf7b962
22 changes: 12 additions & 10 deletions ggml/src/ggml-hexagon/htp/rope-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context
}

static void hvx_calc_rope_neox_f32(const float * restrict src0,
float * restrict dst,
const int num_elems,
const float * restrict theta_cache) {
float * restrict dst,
const int num_elems,
const float * restrict theta_cache) {
// for (int i = 0; i < num_elems; i += 2) {
//const float cos_theta = theta_cache[i + 0];
//const float sin_theta = theta_cache[i + 1];
Expand Down Expand Up @@ -192,7 +192,7 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0,
HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s);
HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c);

*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
*(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4);
*(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5);

src0_curr += VLEN;
Expand Down Expand Up @@ -259,16 +259,16 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
const uint32_t ir1,
int nth,
int ith,
int opt_path) {
const int opt_path) {
struct htp_ops_context * octx = rope_ctx->octx;

const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * src1 = &octx->src1;
const struct htp_tensor * src2 = &octx->src2;
struct htp_tensor * dst = &octx->dst;

const int32_t mode = rope_ctx->mode;
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;
const int32_t mode = rope_ctx->mode;
const bool is_neox = mode & HTP_ROPE_TYPE_NEOX;

htp_rope_preamble;

Expand Down Expand Up @@ -317,10 +317,10 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,

if (is_neox) {
const float x0 = src_loc[0];
const float x1 = src_loc[rope_ctx->n_dims/2];
const float x1 = src_loc[rope_ctx->n_dims / 2];

dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[rope_ctx->n_dims/2] = x0 * sin_theta + x1 * cos_theta;
dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta;
dst_data_loc[rope_ctx->n_dims / 2] = x0 * sin_theta + x1 * cos_theta;

src_loc += 1;
dst_data_loc += 1;
Expand All @@ -337,6 +337,8 @@ static void rope_hex_f32(struct rope_th_ctx * rope_ctx,
}
}

src_loc += is_neox ? (rope_ctx->n_dims / 2) : 0;
dst_data_loc += is_neox ? (rope_ctx->n_dims / 2) : 0;
for (uint32_t i0 = rope_ctx->n_dims; i0 < ne0; i0 += 2) {
dst_data_loc[0] = src_loc[0];
dst_data_loc[1] = src_loc[1];
Expand Down