Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
server : add SWA checkpoints
ggml-ci
  • Loading branch information
ggerganov committed Aug 13, 2025
commit 96db966b1e1ca6c5bf39404890954fc54ec165e9
24 changes: 24 additions & 0 deletions src/llama-kv-cache-unified.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1957,6 +1957,10 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_
for (const auto & layer : layers) {
const uint32_t il = layer.il;

if (!hparams.is_swa(il)) {
continue;
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Temporary hack to store just the SWA data

const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

auto * k = layer.k_stream[cr.strm];
Expand All @@ -1981,6 +1985,10 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_
for (const auto & layer : layers) {
const uint32_t il = layer.il;

if (!hparams.is_swa(il)) {
continue;
}

const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

auto * v = layer.v_stream[cr.strm];
Expand All @@ -2007,6 +2015,10 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_
for (const auto & layer : layers) {
const uint32_t il = layer.il;

if (!hparams.is_swa(il)) {
continue;
}

const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

auto * v = layer.v_stream[cr.strm];
Expand Down Expand Up @@ -2162,6 +2174,10 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm
for (const auto & layer : layers) {
const uint32_t il = layer.il;

if (!hparams.is_swa(il)) {
continue;
}

const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

auto * k = layer.k_stream[strm];
Expand Down Expand Up @@ -2194,6 +2210,10 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm
for (const auto & layer : layers) {
const uint32_t il = layer.il;

if (!hparams.is_swa(il)) {
continue;
}

const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

auto * v = layer.v_stream[strm];
Expand Down Expand Up @@ -2226,6 +2246,10 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm
for (const auto & layer : layers) {
const uint32_t il = layer.il;

if (!hparams.is_swa(il)) {
continue;
}

const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

auto * v = layer.v_stream[strm];
Expand Down
68 changes: 64 additions & 4 deletions tools/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,13 @@ struct completion_token_output {
}
};

// Snapshot of a slot's sequence KV-cache state, taken so that SWA
// (sliding-window attention) generation can be rolled back to this point
// instead of forcing a full prompt re-process.
struct swa_checkpoint {
// serialized per-sequence state (produced by llama_state_seq_get_data,
// restored with llama_state_seq_set_data)
std::vector<uint8_t> data;

// position range covered by the cached sequence at capture time
// (from llama_memory_seq_pos_min / llama_memory_seq_pos_max)
llama_pos pos_min;
llama_pos pos_max;
};

struct server_task_result_cmpl_final : server_task_result {
int index = 0;

Expand Down Expand Up @@ -1336,6 +1343,8 @@ struct server_slot {

std::vector<completion_token_output> generated_token_probs;

std::vector<swa_checkpoint> swa_checkpoints;

bool has_next_token = true;
bool has_new_line = false;
bool truncated = false;
Expand Down Expand Up @@ -3300,10 +3309,42 @@ struct server_context {

const auto n_swa = llama_model_n_swa(model);
if (pos_min > std::max(0, slot.n_past - n_swa)) {
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
slot.n_past = 0;
// search for a SWA checkpoint
int ic = -1;
int np = std::numeric_limits<int>::max();
for (int i = 0; i < (int) slot.swa_checkpoints.size(); i++) {
const auto & cur = slot.swa_checkpoints[i];
if (cur.pos_min <= std::max(0, slot.n_past - n_swa)) {
const int p = std::max(0, slot.n_past - cur.pos_max);

if (p < np) {
ic = i;
np = p;
}
}
}

if (ic == -1) {
SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
slot.n_past = 0;

slot.swa_checkpoints.clear();
} else {
// erase all checkpoints after the one we are using
slot.swa_checkpoints.erase(slot.swa_checkpoints.begin() + ic + 1, slot.swa_checkpoints.end());

// restore the checkpoint
const auto & cur = slot.swa_checkpoints[ic];

const size_t swa_size = cur.data.size();
llama_state_seq_set_data(ctx, cur.data.data(), swa_size, slot.id);

slot.n_past = std::min(slot.n_past, cur.pos_max);

SLT_WRN(slot, "prompt swa checkpoint restored, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024);
}
}
}
}
Expand Down Expand Up @@ -3517,6 +3558,25 @@ struct server_context {

// prompt evaluated for next-token prediction
slot.state = SLOT_STATE_GENERATING;

// make a checkpoint
if (llama_model_n_swa(model) > 0) {
if (slot.swa_checkpoints.size() > 8) {
slot.swa_checkpoints.erase(slot.swa_checkpoints.begin());
}

auto & cur = slot.swa_checkpoints.emplace_back();

cur.pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
cur.pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);

const size_t swa_size = llama_state_seq_get_size(ctx, slot.id);
cur.data.resize(swa_size);

llama_state_seq_get_data(ctx, cur.data.data(), swa_size, slot.id);

SLT_WRN(slot, "prompt swa checkpoint, pos_min = %d, pos_max = %d, size = %f MB\n", cur.pos_min, cur.pos_max, (float) swa_size / 1024 / 1024);
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots
}
Expand Down
Loading