Skip to content

Commit 06254d1

Browse files
committed
Add PagedAttention support (experimental, CUDA only)
Implement PagedAttention algorithm from for memory-efficient KV cache management. This feature reduces memory fragmentation by storing KV cache in fixed-size blocks (similar to virtual memory paging) and enables efficient memory sharing between sequences through copy-on-write semantics. The implementation is experimental and disabled by default. Enable with the --pagedattention flag Signed-off-by: Eric Curtin <[email protected]>
1 parent 03914c7 commit 06254d1

19 files changed

+1939
-3
lines changed

common/arg.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
10171017
string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
10181018
}
10191019
}).set_env("LLAMA_ARG_FLASH_ATTN"));
1020+
add_opt(common_arg(
1021+
{"--pagedattention"},
1022+
"enable PagedAttention for KV cache (experimental, requires CUDA)",
1023+
[](common_params & params) {
1024+
params.use_paged_attention = true;
1025+
}
1026+
));
10201027
add_opt(common_arg(
10211028
{"-p", "--prompt"}, "PROMPT",
10221029
"prompt to start generation with; for system message, use -sys",

common/common.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,6 +1275,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
12751275
cparams.op_offload = !params.no_op_offload;
12761276
cparams.swa_full = params.swa_full;
12771277
cparams.kv_unified = params.kv_unified;
1278+
cparams.use_paged_attention = params.use_paged_attention;
12781279

12791280
cparams.type_k = params.cache_type_k;
12801281
cparams.type_v = params.cache_type_v;

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ struct common_params {
406406
bool ctx_shift = false; // context shift on infinite text generation
407407
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
408408
bool kv_unified = false; // enable unified KV cache
409+
bool use_paged_attention = false; // enable PagedAttention (experimental, requires CUDA)
409410

410411
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
411412
bool use_mmap = true; // use mmap for faster loads

ggml/include/ggml.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ extern "C" {
537537

538538
GGML_OP_FLASH_ATTN_EXT,
539539
GGML_OP_FLASH_ATTN_BACK,
540+
GGML_OP_PAGED_ATTENTION,
540541
GGML_OP_SSM_CONV,
541542
GGML_OP_SSM_SCAN,
542543
GGML_OP_WIN_PART,
@@ -2312,6 +2313,22 @@ extern "C" {
23122313
struct ggml_tensor * a,
23132314
struct ggml_tensor * sinks);
23142315

2316+
// PagedAttention (paged KV cache attention)
2317+
// q: [n_tokens, n_heads, head_size]
2318+
// k_cache: [num_blocks, block_size, n_kv_heads, head_size] (paged)
2319+
// v_cache: [num_blocks, block_size, n_kv_heads, head_size] (paged)
2320+
// block_tables: [n_seqs, max_blocks_per_seq] (int32)
2321+
// seq_lens: [n_seqs] (int32)
2322+
GGML_API struct ggml_tensor * ggml_paged_attention(
2323+
struct ggml_context * ctx,
2324+
struct ggml_tensor * q,
2325+
struct ggml_tensor * k_cache,
2326+
struct ggml_tensor * v_cache,
2327+
struct ggml_tensor * block_tables,
2328+
struct ggml_tensor * seq_lens,
2329+
int32_t block_size,
2330+
float scale);
2331+
23152332
// TODO: needs to be adapted to ggml_flash_attn_ext
23162333
GGML_API struct ggml_tensor * ggml_flash_attn_back(
23172334
struct ggml_context * ctx,

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2062,6 +2062,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
20622062
{
20632063
// nop
20642064
} break;
2065+
case GGML_OP_PAGED_ATTENTION:
2066+
{
2067+
// nop (CUDA-only operation)
2068+
} break;
20652069
case GGML_OP_COUNT:
20662070
{
20672071
GGML_ABORT("fatal error");
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
/**
2+
* GGML CUDA Backend for PagedAttention
3+
*
4+
* This file provides the CUDA backend implementation for the GGML_OP_PAGED_ATTENTION operation.
5+
* It bridges GGML's operation framework with the PagedAttention CUDA kernels.
6+
*/
7+
8+
#include "common.cuh"
9+
#include "paged-attention.cuh"
10+
11+
// Extract parameters from GGML tensor
12+
static void ggml_cuda_op_paged_attention_get_params(
13+
const ggml_tensor * dst,
14+
float * scale,
15+
int32_t * block_size) {
16+
17+
const float * params = (const float *)dst->op_params;
18+
*scale = params[0];
19+
*block_size = (int32_t)params[1];
20+
}
21+
22+
// Main CUDA backend function for PagedAttention
23+
void ggml_cuda_op_paged_attention(
24+
ggml_backend_cuda_context & ctx,
25+
ggml_tensor * dst) {
26+
27+
const ggml_tensor * q = dst->src[0]; // query
28+
const ggml_tensor * k_cache = dst->src[1]; // key cache (paged)
29+
const ggml_tensor * v_cache = dst->src[2]; // value cache (paged)
30+
const ggml_tensor * block_tables = dst->src[3]; // block tables
31+
const ggml_tensor * seq_lens = dst->src[4]; // sequence lengths
32+
33+
// Extract parameters
34+
float scale;
35+
int32_t block_size;
36+
ggml_cuda_op_paged_attention_get_params(dst, &scale, &block_size);
37+
38+
// Get tensor dimensions
39+
const int64_t head_size = q->ne[0];
40+
const int64_t n_heads = q->ne[1];
41+
const int64_t n_tokens = q->ne[2];
42+
const int64_t n_seqs = q->ne[3];
43+
44+
const int64_t n_kv_heads = k_cache->ne[2];
45+
const int64_t num_blocks = k_cache->ne[0];
46+
47+
const int64_t max_blocks_per_seq = block_tables->ne[0];
48+
49+
// Get pointers
50+
void * out_ptr = dst->data;
51+
const void * q_ptr = q->data;
52+
const void * k_cache_ptr = k_cache->data;
53+
const void * v_cache_ptr = v_cache->data;
54+
const int32_t * block_tables_ptr = (const int32_t *)block_tables->data;
55+
const int32_t * seq_lens_ptr = (const int32_t *)seq_lens->data;
56+
57+
// Calculate max sequence length (needed to decide V1 vs V2)
58+
int max_seq_len = 0;
59+
for (int i = 0; i < n_seqs; i++) {
60+
if (seq_lens_ptr[i] > max_seq_len) {
61+
max_seq_len = seq_lens_ptr[i];
62+
}
63+
}
64+
65+
// Get CUDA stream
66+
cudaStream_t stream = ctx.stream();
67+
68+
// Decide whether to use V1 or V2
69+
const bool use_v1 = ggml_cuda_paged_attention::should_use_v1(
70+
max_seq_len, n_seqs, n_heads);
71+
72+
// Launch appropriate kernel
73+
if (use_v1) {
74+
ggml_cuda_paged_attention::paged_attention_v1_launcher(
75+
out_ptr,
76+
q_ptr,
77+
k_cache_ptr,
78+
v_cache_ptr,
79+
n_seqs,
80+
n_heads,
81+
n_kv_heads,
82+
head_size,
83+
block_size,
84+
max_blocks_per_seq,
85+
block_tables_ptr,
86+
seq_lens_ptr,
87+
max_seq_len,
88+
scale,
89+
nullptr, // alibi_slopes (TODO: add support if needed)
90+
q->type,
91+
k_cache->type,
92+
stream);
93+
} else {
94+
ggml_cuda_paged_attention::paged_attention_v2_launcher(
95+
out_ptr,
96+
q_ptr,
97+
k_cache_ptr,
98+
v_cache_ptr,
99+
n_seqs,
100+
n_heads,
101+
n_kv_heads,
102+
head_size,
103+
block_size,
104+
max_blocks_per_seq,
105+
block_tables_ptr,
106+
seq_lens_ptr,
107+
max_seq_len,
108+
scale,
109+
nullptr, // alibi_slopes
110+
q->type,
111+
k_cache->type,
112+
stream);
113+
}
114+
115+
// Check for errors
116+
CUDA_CHECK(cudaGetLastError());
117+
}
118+
119+
// Check if PagedAttention is supported for given configuration
120+
bool ggml_cuda_can_paged_attention(const ggml_tensor * dst) {
121+
const ggml_tensor * q = dst->src[0];
122+
const ggml_tensor * k_cache = dst->src[1];
123+
124+
// Check data types
125+
if (q->type != GGML_TYPE_F16 && q->type != GGML_TYPE_F32) {
126+
return false;
127+
}
128+
129+
if (k_cache->type != GGML_TYPE_F16 && k_cache->type != GGML_TYPE_F32) {
130+
return false;
131+
}
132+
133+
// Check head size is supported
134+
const int64_t head_size = q->ne[0];
135+
const int supported_head_sizes[] = {32, 64, 80, 96, 112, 120, 128, 192, 256};
136+
bool head_size_supported = false;
137+
138+
for (int hs : supported_head_sizes) {
139+
if (head_size == hs) {
140+
head_size_supported = true;
141+
break;
142+
}
143+
}
144+
145+
if (!head_size_supported) {
146+
return false;
147+
}
148+
149+
// Extract block size and check it's supported
150+
float scale;
151+
int32_t block_size;
152+
ggml_cuda_op_paged_attention_get_params(dst, &scale, &block_size);
153+
154+
if (block_size != 8 && block_size != 16 && block_size != 32) {
155+
return false;
156+
}
157+
158+
return true;
159+
}

0 commit comments

Comments
 (0)