Skip to content

Commit f0b133d

Browse files
committed
Add PagedAttention support (experimental, CUDA only)
Implement the PagedAttention algorithm for memory-efficient KV cache management. This feature reduces memory fragmentation by storing the KV cache in fixed-size blocks (similar to virtual memory paging) and enables efficient memory sharing between sequences through copy-on-write semantics. The implementation is experimental and disabled by default. Enable with the --pagedattention flag. Signed-off-by: Eric Curtin <[email protected]>
1 parent 03914c7 commit f0b133d

19 files changed

+1979
-3
lines changed

common/arg.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
10171017
string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
10181018
}
10191019
}).set_env("LLAMA_ARG_FLASH_ATTN"));
1020+
add_opt(common_arg(
1021+
{"--pagedattention"},
1022+
"enable PagedAttention for KV cache (experimental, requires CUDA)",
1023+
[](common_params & params) {
1024+
params.use_paged_attention = true;
1025+
}
1026+
));
10201027
add_opt(common_arg(
10211028
{"-p", "--prompt"}, "PROMPT",
10221029
"prompt to start generation with; for system message, use -sys",

common/common.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,6 +1275,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
12751275
cparams.op_offload = !params.no_op_offload;
12761276
cparams.swa_full = params.swa_full;
12771277
cparams.kv_unified = params.kv_unified;
1278+
cparams.use_paged_attention = params.use_paged_attention;
12781279

12791280
cparams.type_k = params.cache_type_k;
12801281
cparams.type_v = params.cache_type_v;

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ struct common_params {
406406
bool ctx_shift = false; // context shift on infinite text generation
407407
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
408408
bool kv_unified = false; // enable unified KV cache
409+
bool use_paged_attention = false; // enable PagedAttention (experimental, requires CUDA)
409410

410411
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
411412
bool use_mmap = true; // use mmap for faster loads

ggml/include/ggml.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ extern "C" {
537537

538538
GGML_OP_FLASH_ATTN_EXT,
539539
GGML_OP_FLASH_ATTN_BACK,
540+
GGML_OP_PAGED_ATTENTION,
540541
GGML_OP_SSM_CONV,
541542
GGML_OP_SSM_SCAN,
542543
GGML_OP_WIN_PART,
@@ -2312,6 +2313,22 @@ extern "C" {
23122313
struct ggml_tensor * a,
23132314
struct ggml_tensor * sinks);
23142315

2316+
// PagedAttention (paged KV cache attention)
2317+
// q: [n_tokens, n_heads, head_size]
2318+
// k_cache: [num_blocks, block_size, n_kv_heads, head_size] (paged)
2319+
// v_cache: [num_blocks, block_size, n_kv_heads, head_size] (paged)
2320+
// block_tables: [n_seqs, max_blocks_per_seq] (int32)
2321+
// seq_lens: [n_seqs] (int32)
2322+
GGML_API struct ggml_tensor * ggml_paged_attention(
2323+
struct ggml_context * ctx,
2324+
struct ggml_tensor * q,
2325+
struct ggml_tensor * k_cache,
2326+
struct ggml_tensor * v_cache,
2327+
struct ggml_tensor * block_tables,
2328+
struct ggml_tensor * seq_lens,
2329+
int32_t block_size,
2330+
float scale);
2331+
23152332
// TODO: needs to be adapted to ggml_flash_attn_ext
23162333
GGML_API struct ggml_tensor * ggml_flash_attn_back(
23172334
struct ggml_context * ctx,

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2062,6 +2062,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
20622062
{
20632063
// nop
20642064
} break;
2065+
case GGML_OP_PAGED_ATTENTION:
2066+
{
2067+
// nop (CUDA-only operation)
2068+
} break;
20652069
case GGML_OP_COUNT:
20662070
{
20672071
GGML_ABORT("fatal error");
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
/**
2+
* GGML CUDA Backend for PagedAttention
3+
*
4+
* This file provides the CUDA backend implementation for the GGML_OP_PAGED_ATTENTION operation.
5+
* It bridges GGML's operation framework with the PagedAttention CUDA kernels.
6+
*
7+
* NOTE: PagedAttention is currently experimental and only supported on CUDA.
8+
* MUSA support is disabled due to compiler compatibility issues.
9+
*/
10+
11+
// PagedAttention is not yet supported on MUSA
12+
#ifndef GGML_USE_MUSA
13+
14+
#include <algorithm>
#include <cstdint>
#include <vector>

#include "common.cuh"
#include "paged-attention.cuh"
16+
17+
// Unpack the operation parameters stored on the destination tensor.
// op_params layout: [0] = softmax scale (float), [1] = KV block size.
// NOTE(review): this assumes ggml_paged_attention stored block_size as a
// float in op_params[1] — confirm against the setter in ggml.c (ggml often
// memcpy's raw int32 bits into op_params instead).
static void ggml_cuda_op_paged_attention_get_params(
        const ggml_tensor * dst,
        float             * scale,
        int32_t           * block_size) {
    const float * op_params = (const float *) dst->op_params;

    *scale      = op_params[0];
    *block_size = (int32_t) op_params[1];
}
27+
28+
// Main CUDA backend entry point for GGML_OP_PAGED_ATTENTION.
//
// dst->src[]: [0] q, [1] k_cache (paged), [2] v_cache (paged),
//             [3] block_tables (int32), [4] seq_lens (int32).
// op_params carry the softmax scale and the KV block size.
// Launches either the V1 or V2 PagedAttention kernel on ctx.stream().
void ggml_cuda_op_paged_attention(
    ggml_backend_cuda_context & ctx,
    ggml_tensor * dst) {

    const ggml_tensor * q            = dst->src[0]; // query
    const ggml_tensor * k_cache     = dst->src[1]; // key cache (paged)
    const ggml_tensor * v_cache     = dst->src[2]; // value cache (paged)
    const ggml_tensor * block_tables = dst->src[3]; // block tables
    const ggml_tensor * seq_lens     = dst->src[4]; // sequence lengths

    // Extract scale / block_size from op_params.
    float   scale;
    int32_t block_size;
    ggml_cuda_op_paged_attention_get_params(dst, &scale, &block_size);

    // Tensor dimensions (ggml ne[] order: innermost first).
    // Assumes q is laid out as [head_size, n_heads, n_tokens, n_seqs] —
    // TODO(review): confirm against the graph-building code.
    const int64_t head_size = q->ne[0];
    const int64_t n_heads   = q->ne[1];
    const int64_t n_seqs    = q->ne[3];

    const int64_t n_kv_heads         = k_cache->ne[2];
    const int64_t max_blocks_per_seq = block_tables->ne[0];

    // Raw data pointers. These live in CUDA backend buffers, i.e. they are
    // DEVICE pointers and must not be dereferenced on the host.
    void          * out_ptr          = dst->data;
    const void    * q_ptr            = q->data;
    const void    * k_cache_ptr      = k_cache->data;
    const void    * v_cache_ptr      = v_cache->data;
    const int32_t * block_tables_ptr = (const int32_t *) block_tables->data;
    const int32_t * seq_lens_ptr     = (const int32_t *) seq_lens->data;

    cudaStream_t stream = ctx.stream();

    // The V1/V2 heuristic needs the maximum sequence length on the host.
    // seq_lens_ptr is a device pointer (it is passed to the kernel launchers
    // below as such), so it must be copied to the host first — the previous
    // code read seq_lens_ptr[i] directly on the host, which is invalid and
    // would crash or read garbage.
    std::vector<int32_t> seq_lens_host(n_seqs);
    CUDA_CHECK(cudaMemcpyAsync(seq_lens_host.data(), seq_lens_ptr,
        n_seqs * sizeof(int32_t), cudaMemcpyDeviceToHost, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));

    int max_seq_len = 0;
    for (int64_t i = 0; i < n_seqs; ++i) {
        max_seq_len = std::max(max_seq_len, (int) seq_lens_host[i]);
    }

    // Decide whether to use V1 or V2 (V2 partitions long sequences for
    // better parallelism; the exact heuristic lives in should_use_v1).
    const bool use_v1 = ggml_cuda_paged_attention::should_use_v1(
        max_seq_len, n_seqs, n_heads);

    // Launch the appropriate kernel.
    if (use_v1) {
        ggml_cuda_paged_attention::paged_attention_v1_launcher(
            out_ptr,
            q_ptr,
            k_cache_ptr,
            v_cache_ptr,
            n_seqs,
            n_heads,
            n_kv_heads,
            head_size,
            block_size,
            max_blocks_per_seq,
            block_tables_ptr,
            seq_lens_ptr,
            max_seq_len,
            scale,
            nullptr, // alibi_slopes (TODO: add support if needed)
            q->type,
            k_cache->type,
            stream);
    } else {
        ggml_cuda_paged_attention::paged_attention_v2_launcher(
            out_ptr,
            q_ptr,
            k_cache_ptr,
            v_cache_ptr,
            n_seqs,
            n_heads,
            n_kv_heads,
            head_size,
            block_size,
            max_blocks_per_seq,
            block_tables_ptr,
            seq_lens_ptr,
            max_seq_len,
            scale,
            nullptr, // alibi_slopes
            q->type,
            k_cache->type,
            stream);
    }

    // Surface any launch-configuration errors immediately; async execution
    // errors will be reported at the next synchronizing call.
    CUDA_CHECK(cudaGetLastError());
}
128+
129+
// Check whether the PagedAttention CUDA path supports the configuration
// described by dst: data types, head size, and KV block size.
bool ggml_cuda_can_paged_attention(const ggml_tensor * dst) {
    const ggml_tensor * q       = dst->src[0];
    const ggml_tensor * k_cache = dst->src[1];

    // Only F16/F32 queries and key caches are handled by the kernels.
    const bool q_type_ok = q->type == GGML_TYPE_F16 || q->type == GGML_TYPE_F32;
    const bool k_type_ok = k_cache->type == GGML_TYPE_F16 || k_cache->type == GGML_TYPE_F32;
    if (!q_type_ok || !k_type_ok) {
        return false;
    }

    // The kernels are only instantiated for this fixed set of head sizes.
    static const int64_t supported_head_sizes[] = {32, 64, 80, 96, 112, 120, 128, 192, 256};
    const int64_t head_size = q->ne[0];

    bool head_size_ok = false;
    for (size_t i = 0; i < sizeof(supported_head_sizes)/sizeof(supported_head_sizes[0]); ++i) {
        if (supported_head_sizes[i] == head_size) {
            head_size_ok = true;
            break;
        }
    }
    if (!head_size_ok) {
        return false;
    }

    // Only power-of-two block sizes 8/16/32 have kernel instantiations.
    float   scale;
    int32_t block_size;
    ggml_cuda_op_paged_attention_get_params(dst, &scale, &block_size);

    switch (block_size) {
        case 8:
        case 16:
        case 32:
            return true;
        default:
            return false;
    }
}
170+
171+
#else // GGML_USE_MUSA

// Stub implementations for MUSA (PagedAttention not yet supported).
#include "common.cuh"

// Aborts: graph construction should never schedule this op on MUSA.
void ggml_cuda_op_paged_attention(
    ggml_backend_cuda_context & ctx,
    ggml_tensor * dst) {
    GGML_UNUSED(ctx);
    GGML_UNUSED(dst);
    GGML_ABORT("PagedAttention is not yet supported on MUSA");
}

// Always reports "unsupported" on MUSA.
// FIX: renamed from ggml_cuda_supports_paged_attention to match the
// non-MUSA definition ggml_cuda_can_paged_attention above — the mismatched
// name would leave ggml_cuda_can_paged_attention undefined on MUSA builds
// (link error for any caller).
bool ggml_cuda_can_paged_attention(const ggml_tensor * dst) {
    GGML_UNUSED(dst);
    return false;
}

#endif // GGML_USE_MUSA

0 commit comments

Comments
 (0)