33#include "ggml-cpu.h"
44#include "ggml-impl.h"
55#include "binary-ops.h"
6+ #include "simd-gemm.h"
67#include "ggml.h"
78#include "unary-ops.h"
89#include "vec.h"
@@ -8389,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
83898390 GGML_ASSERT (k->type == v->type );
83908391 const ggml_type kv_type = k->type ;
83918392
8392- const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu (kv_type);
8393- const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float ;
8394- const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot ;
8395- const size_t kv_type_size = ggml_type_size (kv_type);
83968393
83978394 // broadcast factors
83988395 const int64_t rk2 = neq2/nek2;
@@ -8424,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
84248421 static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
84258422 static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
84268423
8427- GGML_ASSERT (nek1 % KV_TILE_SZ == 0 && " KV sequence length must be divisible by KV_TILE_SZ" );
8428-
84298424 int ir = ir0;
84308425 while (ir < ir1) {
84318426 // q indices for the start of this tile
@@ -8452,18 +8447,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
84528447 }
84538448
84548449 // Per-thread scratch layout:
8455- // Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
8450+ // Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
84568451 // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
84578452 // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
84588453 // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
8459- // V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f16 conversion)
8460- float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2 *Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
8454+ // V32: KV_TILE_SZ * DV (F32 buffer for V tile)
8455+ // K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
8456+ float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2 *Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
84618457
84628458 void * Q_q = base;
84638459 float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof (float ));
84648460 float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
84658461 float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
8466- float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
8462+ float * V32 = VKQ32 + Q_TILE_SZ * DV;
8463+ float * K_f32 = V32 + KV_TILE_SZ * DV;
84678464
84688465 memset (VKQ32, 0 , Q_TILE_SZ * DV * sizeof (float ));
84698466 memset (mask32, 0 , Q_TILE_SZ * KV_TILE_SZ * sizeof (float ));
@@ -8476,42 +8473,71 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
84768473 const int iv3 = iq3 / rv3;
84778474 const int iv2 = iq2 / rv2;
84788475
8479- for (int tq = 0 ; tq < tile_rows; tq++) {
8480- const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
8481- kv_from_float (pq, (char *)Q_q + tq * DK * kv_type_size, DK);
8482- }
8483- // Zero-pad remaining rows
8484- for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
8485- memset ((char *)Q_q + tq * DK * kv_type_size, 0 , DK * kv_type_size);
8476+ {
8477+ float * Q_f32 = (float *)Q_q;
8478+ for (int tq = 0 ; tq < tile_rows; tq++) {
8479+ const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
8480+ memcpy (Q_f32 + tq * DK, pq, DK * sizeof (float ));
8481+ }
8482+ for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
8483+ memset (Q_f32 + tq * DK, 0 , DK * sizeof (float ));
8484+ }
84868485 }
84878486
8487+ memset (K_f32, 0 , DK * KV_TILE_SZ * sizeof (float ));
8488+ memset (V32, 0 , KV_TILE_SZ * DV * sizeof (float ));
8489+
84888490 for (int64_t ic = 0 ; ic < nek1; ic += KV_TILE_SZ) {
8491+ const int kv_tile = (int )std::min ((int64_t )KV_TILE_SZ, nek1 - ic);
84898492
84908493 // skip the tile entirely if all the masks are -inf
84918494 if (mask) {
84928495 bool can_skip = true ;
84938496 for (int tq = 0 ; tq < tile_rows; tq++) {
84948497 const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb [1 ] + (iq2%mask->ne [2 ])*mask->nb [2 ] + (iq3%mask->ne [3 ])*mask->nb [3 ]);
8495- for (int tk = 0 ; tk < KV_TILE_SZ ; tk++) {
8498+ for (int tk = 0 ; tk < kv_tile ; tk++) {
84968499 mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32 (mp_row[ic + tk]);
84978500 if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
84988501 can_skip = false ;
84998502 }
85008503 }
8504+ // Pad remaining mask entries with -inf
8505+ for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
8506+ mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
8507+ }
85018508 }
85028509
85038510 if (can_skip) {
85048511 continue ;
85058512 }
85068513 }
85078514
8508- for (int tq = 0 ; tq < Q_TILE_SZ; tq++) {
8509- const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
8510- for (int tk = 0 ; tk < KV_TILE_SZ; tk++) {
8511- const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
8512- float s;
8513- kv_vec_dot (DK, &s, 0 , k_row, 0 , q_row, 0 , 1 );
8514- KQ[tq * KV_TILE_SZ + tk] = s * scale;
8515+ // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
8516+ // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
8517+ for (int tk = 0 ; tk < kv_tile; tk++) {
8518+ const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
8519+ if (kv_type == GGML_TYPE_F16) {
8520+ const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
8521+ for (int64_t dk = 0 ; dk < DK; dk++) {
8522+ K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32 (k_f16[dk]);
8523+ }
8524+ } else {
8525+ const float * k_f32_src = (const float *)k_data;
8526+ for (int64_t dk = 0 ; dk < DK; dk++) {
8527+ K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
8528+ }
8529+ }
8530+ }
8531+ memset (KQ, 0 , Q_TILE_SZ * KV_TILE_SZ * sizeof (float ));
8532+ simd_gemm (KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
8533+ ggml_vec_scale_f32 (Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
8534+
8535+ // Set padded KQ entries to -inf so softmax gives them zero weight
8536+ if (kv_tile < KV_TILE_SZ) {
8537+ for (int tq = 0 ; tq < Q_TILE_SZ; tq++) {
8538+ for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
8539+ KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
8540+ }
85158541 }
85168542 }
85178543
@@ -8551,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
85518577 S[tq] += ggml_vec_soft_max_f32 (KV_TILE_SZ, kq_row, kq_row, Mnew);
85528578 }
85538579
8554- // Convert V tile to F32 first (if F16), then do MAD
8555- // On x86, ggml_vec_mad_f16 internally converts F16<->F32 on every load/store, so pre-converting is faster.
8556- // TODO: on ARM, native f16 should be faster
8557- if (kv_type == GGML_TYPE_F16) {
8558- for (int tk = 0 ; tk < KV_TILE_SZ; tk++) {
8559- const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
8560- ggml_fp16_to_fp32_row (v_row, V32 + tk * DV, DV);
8561- }
8562- for (int tq = 0 ; tq < Q_TILE_SZ; tq++) {
8563- if (skip[tq]) continue ;
8564- float * vkq_row = VKQ32 + tq * DV;
8565- for (int tk = 0 ; tk < KV_TILE_SZ; tk++) {
8566- const float p = KQ[tq * KV_TILE_SZ + tk];
8567- ggml_vec_mad_f32 (DV, vkq_row, V32 + tk * DV, p);
8568- }
8580+ // V accumulation: VKQ32 += softmax(KQ) * V
8581+ // Pack V tile to contiguous F32, zero-padded
8582+ for (int tk = 0 ; tk < kv_tile; tk++) {
8583+ const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
8584+ if (kv_type == GGML_TYPE_F16) {
8585+ ggml_fp16_to_fp32_row ((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
8586+ } else {
8587+ memcpy (V32 + tk * DV, v_data, DV * sizeof (float ));
85698588 }
8570- } else {
8571- for (int tq = 0 ; tq < Q_TILE_SZ; tq++) {
8572- if (skip[tq]) continue ;
8573- float * vkq_row = VKQ32 + tq * DV;
8574- for (int tk = 0 ; tk < KV_TILE_SZ; tk++) {
8575- const float p = KQ[tq * KV_TILE_SZ + tk];
8576- const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
8577- ggml_vec_mad_f32 (DV, vkq_row, v_row, p);
8578- }
8589+ }
8590+ for (int tq = 0 ; tq < Q_TILE_SZ; tq++) {
8591+ if (skip[tq]) {
8592+ memset (KQ + tq * KV_TILE_SZ, 0 , KV_TILE_SZ * sizeof (float ));
85798593 }
85808594 }
8595+ simd_gemm (VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
85818596 }
85828597
85838598 // sinks (apply only to valid rows in the tile)
@@ -8794,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
87948809
87958810 const int64_t dr = (nr + nchunk - 1 ) / nchunk;
87968811
8797- static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
87988812 static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
8799- const bool use_tiled = !use_ref &&
8813+ bool use_tiled = !use_ref &&
88008814 (q->type == GGML_TYPE_F32 &&
88018815 kv_is_f32_or_f16 &&
88028816 k->type == v->type &&
8803- nek1 % KV_TILE_SZ == 0 &&
88048817 neq1 >= Q_TILE_SZ);
8805-
8818+ #ifdef GGML_SIMD
8819+ use_tiled &= (DV % GGML_F32_EPR == 0 );
8820+ #endif
88068821 int current_chunk = ith;
88078822
88088823 while (current_chunk < nchunk) {
0 commit comments