Skip to content

Commit 684b361

Browse files
authored
ggml-cpu: FA add GEMM microkernel (#19422)
* ggml-cpu: FA add GEMM microkernel * add guard for sizeless vector types * fix case where DV % GGML_F32_EPR !=0 * move memset out of the loop * move another memset out of the loop * use RM=4 for arm * simd_gemm: convert everything to int * convert everything to size_t to avoid warnings * fixup * add pragma for ignoring aggressive loop optimizations
1 parent 3a00c98 commit 684b361

4 files changed

Lines changed: 211 additions & 57 deletions

File tree

ggml/src/ggml-cpu/common.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
#include "ggml-impl.h"
77
#include "simd-mappings.h"
88

9-
#define GGML_FA_TILE_Q 32
10-
#define GGML_FA_TILE_KV 16
9+
#define GGML_FA_TILE_Q 64
10+
#define GGML_FA_TILE_KV 64
1111

1212
#ifdef __cplusplus
1313

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan(
28742874
const int64_t DV = node->src[2]->ne[0];
28752875

28762876
// Tiled flash attention scratch (tile sizes defined in common.h)
2877-
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
2878-
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
2877+
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding
2878+
size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks;
28792879

28802880
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
28812881
// Per-thread: VKQ accumulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ

ggml/src/ggml-cpu/ops.cpp

Lines changed: 68 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "ggml-cpu.h"
44
#include "ggml-impl.h"
55
#include "binary-ops.h"
6+
#include "simd-gemm.h"
67
#include "ggml.h"
78
#include "unary-ops.h"
89
#include "vec.h"
@@ -8389,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
83898390
GGML_ASSERT(k->type == v->type);
83908391
const ggml_type kv_type = k->type;
83918392

8392-
const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
8393-
const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
8394-
const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
8395-
const size_t kv_type_size = ggml_type_size(kv_type);
83968393

83978394
// broadcast factors
83988395
const int64_t rk2 = neq2/nek2;
@@ -8424,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
84248421
static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
84258422
static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
84268423

8427-
GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
8428-
84298424
int ir = ir0;
84308425
while (ir < ir1) {
84318426
// q indices for the start of this tile
@@ -8452,18 +8447,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
84528447
}
84538448

84548449
// Per-thread scratch layout:
8455-
// Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
8450+
// Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar)
84568451
// KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
84578452
// mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
84588453
// VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
8459-
// V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion)
8460-
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
8454+
// V32: KV_TILE_SZ * DV (F32 buffer for V tile)
8455+
// K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path)
8456+
float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32);
84618457

84628458
void * Q_q = base;
84638459
float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
84648460
float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
84658461
float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
8466-
float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
8462+
float * V32 = VKQ32 + Q_TILE_SZ * DV;
8463+
float * K_f32 = V32 + KV_TILE_SZ * DV;
84678464

84688465
memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
84698466
memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
@@ -8476,42 +8473,71 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
84768473
const int iv3 = iq3 / rv3;
84778474
const int iv2 = iq2 / rv2;
84788475

8479-
for (int tq = 0; tq < tile_rows; tq++) {
8480-
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
8481-
kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
8482-
}
8483-
// Zero-pad remaining rows
8484-
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
8485-
memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
8476+
{
8477+
float * Q_f32 = (float *)Q_q;
8478+
for (int tq = 0; tq < tile_rows; tq++) {
8479+
const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
8480+
memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float));
8481+
}
8482+
for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
8483+
memset(Q_f32 + tq * DK, 0, DK * sizeof(float));
8484+
}
84868485
}
84878486

8487+
memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float));
8488+
memset(V32, 0, KV_TILE_SZ * DV * sizeof(float));
8489+
84888490
for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
8491+
const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic);
84898492

84908493
// skip the tile entirely if all the masks are -inf
84918494
if (mask) {
84928495
bool can_skip = true;
84938496
for (int tq = 0; tq < tile_rows; tq++) {
84948497
const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
8495-
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
8498+
for (int tk = 0; tk < kv_tile; tk++) {
84968499
mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
84978500
if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
84988501
can_skip = false;
84998502
}
85008503
}
8504+
// Pad remaining mask entries with -inf
8505+
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
8506+
mask32[tq * KV_TILE_SZ + tk] = -INFINITY;
8507+
}
85018508
}
85028509

85038510
if (can_skip) {
85048511
continue;
85058512
}
85068513
}
85078514

8508-
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8509-
const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
8510-
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
8511-
const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
8512-
float s;
8513-
kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
8514-
KQ[tq * KV_TILE_SZ + tk] = s * scale;
8515+
// Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim)
8516+
// Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns
8517+
for (int tk = 0; tk < kv_tile; tk++) {
8518+
const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3;
8519+
if (kv_type == GGML_TYPE_F16) {
8520+
const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data;
8521+
for (int64_t dk = 0; dk < DK; dk++) {
8522+
K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]);
8523+
}
8524+
} else {
8525+
const float * k_f32_src = (const float *)k_data;
8526+
for (int64_t dk = 0; dk < DK; dk++) {
8527+
K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk];
8528+
}
8529+
}
8530+
}
8531+
memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
8532+
simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ);
8533+
ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale);
8534+
8535+
// Set padded KQ entries to -inf so softmax gives them zero weight
8536+
if (kv_tile < KV_TILE_SZ) {
8537+
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8538+
for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) {
8539+
KQ[tq * KV_TILE_SZ + tk] = -INFINITY;
8540+
}
85158541
}
85168542
}
85178543

@@ -8551,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
85518577
S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
85528578
}
85538579

8554-
// Convert V tile to F32 first (if F16), then do MAD
8555-
// On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster.
8556-
// TODO: on ARM, native f16 should be faster
8557-
if (kv_type == GGML_TYPE_F16) {
8558-
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
8559-
const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
8560-
ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
8561-
}
8562-
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8563-
if (skip[tq]) continue;
8564-
float * vkq_row = VKQ32 + tq * DV;
8565-
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
8566-
const float p = KQ[tq * KV_TILE_SZ + tk];
8567-
ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
8568-
}
8580+
// V accumulation: VKQ32 += softmax(KQ) * V
8581+
// Pack V tile to contiguous F32, zero-padded
8582+
for (int tk = 0; tk < kv_tile; tk++) {
8583+
const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3;
8584+
if (kv_type == GGML_TYPE_F16) {
8585+
ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV);
8586+
} else {
8587+
memcpy(V32 + tk * DV, v_data, DV * sizeof(float));
85698588
}
8570-
} else {
8571-
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8572-
if (skip[tq]) continue;
8573-
float * vkq_row = VKQ32 + tq * DV;
8574-
for (int tk = 0; tk < KV_TILE_SZ; tk++) {
8575-
const float p = KQ[tq * KV_TILE_SZ + tk];
8576-
const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
8577-
ggml_vec_mad_f32(DV, vkq_row, v_row, p);
8578-
}
8589+
}
8590+
for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8591+
if (skip[tq]) {
8592+
memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float));
85798593
}
85808594
}
8595+
simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV);
85818596
}
85828597

85838598
// sinks (apply only to valid rows in the tile)
@@ -8794,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16(
87948809

87958810
const int64_t dr = (nr + nchunk - 1) / nchunk;
87968811

8797-
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
87988812
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
8799-
const bool use_tiled = !use_ref &&
8813+
bool use_tiled = !use_ref &&
88008814
(q->type == GGML_TYPE_F32 &&
88018815
kv_is_f32_or_f16 &&
88028816
k->type == v->type &&
8803-
nek1 % KV_TILE_SZ == 0 &&
88048817
neq1 >= Q_TILE_SZ);
8805-
8818+
#ifdef GGML_SIMD
8819+
use_tiled &= (DV % GGML_F32_EPR == 0);
8820+
#endif
88068821
int current_chunk = ith;
88078822

88088823
while (current_chunk < nchunk) {

ggml/src/ggml-cpu/simd-gemm.h

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
#pragma once
2+
3+
// Computes C[M x N] += A[M x K] * B[K x N]
4+
5+
#include "ggml-cpu-impl.h"
6+
#include "vec.h"
7+
#include "common.h"
8+
#include "simd-mappings.h"
9+
10+
// TODO: add support for sizeless vector types
11+
#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
12+
13+
// TODO: untested on avx512
// Microkernel tile sizes, in units of GGML_F32_EPR (one SIMD vector of floats).
// Register budget per kernel call: RM*RN accumulators + RN B vectors + 1 broadcast A value.
// NOTE: check both NEON spellings — the ACLE feature macro is __ARM_NEON;
// __ARM_NEON__ is the legacy (32-bit) spelling and is not guaranteed everywhere.
#if defined(__AVX512F__) || defined(__ARM_NEON) || defined(__ARM_NEON__)
static constexpr int GEMM_RM = 4;
static constexpr int GEMM_RN = 4; // 16+4+1 = 21/32 registers
#elif defined(__AVX2__) || defined(__AVX__)
static constexpr int GEMM_RM = 6;
static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16 registers
#else
static constexpr int GEMM_RM = 2;
static constexpr int GEMM_RN = 2;
#endif
25+
26+
#if defined(__GNUC__) && !defined(__clang__)
27+
#pragma GCC diagnostic push
28+
#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations"
29+
#endif
30+
31+
template <int RM, int RN>
32+
static inline void simd_gemm_ukernel(
33+
float * GGML_RESTRICT C,
34+
const float * GGML_RESTRICT A,
35+
const float * GGML_RESTRICT B,
36+
int64_t K, int64_t N,
37+
int64_t ii, int64_t jj)
38+
{
39+
static constexpr int KN = GGML_F32_EPR;
40+
41+
GGML_F32_VEC acc[RM][RN];
42+
for (int i = 0; i < RM; i++) {
43+
for (int r = 0; r < RN; r++) {
44+
acc[i][r] = GGML_F32_VEC_LOAD(C + (ii + i) * N + jj + r * KN);
45+
}
46+
}
47+
48+
for (int64_t kk = 0; kk < K; kk++) {
49+
GGML_F32_VEC Bv[RN];
50+
for (int r = 0; r < RN; r++) {
51+
Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + jj + r * KN);
52+
}
53+
for (int i = 0; i < RM; i++) {
54+
GGML_F32_VEC p = GGML_F32_VEC_SET1(A[(ii + i) * K + kk]);
55+
for (int r = 0; r < RN; r++) {
56+
acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
57+
}
58+
}
59+
}
60+
61+
for (int i = 0; i < RM; i++) {
62+
for (int r = 0; r < RN; r++) {
63+
GGML_F32_VEC_STORE(C + (ii + i) * N + jj + r * KN, acc[i][r]);
64+
}
65+
}
66+
}
67+
68+
// C[M x N] += A[M x K] * B[K x N]
69+
static void simd_gemm(
70+
float * GGML_RESTRICT C,
71+
const float * GGML_RESTRICT A,
72+
const float * GGML_RESTRICT B,
73+
int64_t M, int64_t K, int64_t N)
74+
{
75+
static constexpr int KN = GGML_F32_EPR;
76+
77+
int64_t ii = 0;
78+
for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
79+
int64_t jj = 0;
80+
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
81+
simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C, A, B, K, N, ii, jj);
82+
}
83+
for (; jj + KN <= N; jj += KN) {
84+
simd_gemm_ukernel<GEMM_RM, 1>(C, A, B, K, N, ii, jj);
85+
}
86+
for (; jj < N; jj++) {
87+
for (int i = 0; i < GEMM_RM; i++) {
88+
float a = C[(ii + i) * N + jj];
89+
for (int64_t kk = 0; kk < K; kk++) {
90+
a += A[(ii + i) * K + kk] * B[kk * N + jj];
91+
}
92+
C[(ii + i) * N + jj] = a;
93+
}
94+
}
95+
}
96+
97+
// Tail rows: one at a time
98+
for (; ii < M; ii++) {
99+
int64_t jj = 0;
100+
for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
101+
simd_gemm_ukernel<1, GEMM_RN>(C, A, B, K, N, ii, jj);
102+
}
103+
for (; jj + KN <= N; jj += KN) {
104+
simd_gemm_ukernel<1, 1>(C, A, B, K, N, ii, jj);
105+
}
106+
for (; jj < N; jj++) {
107+
float a = C[ii * N + jj];
108+
for (int64_t kk = 0; kk < K; kk++) {
109+
a += A[ii * K + kk] * B[kk * N + jj];
110+
}
111+
C[ii * N + jj] = a;
112+
}
113+
}
114+
}
115+
116+
#if defined(__GNUC__) && !defined(__clang__)
117+
#pragma GCC diagnostic pop
118+
#endif
119+
120+
#else // scalar path
121+
122+
static void simd_gemm(
123+
float * GGML_RESTRICT C,
124+
const float * GGML_RESTRICT A,
125+
const float * GGML_RESTRICT B,
126+
int64_t M, int64_t K, int64_t N)
127+
{
128+
for (int64_t i = 0; i < M; i++) {
129+
for (int64_t j = 0; j < N; j++) {
130+
float sum = C[i * N + j];
131+
for (int64_t kk = 0; kk < K; kk++) {
132+
sum += A[i * K + kk] * B[kk * N + j];
133+
}
134+
C[i * N + j] = sum;
135+
}
136+
}
137+
}
138+
139+
#endif // GGML_SIMD

0 commit comments

Comments
 (0)