Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
server : improve prompt caching logic
  • Loading branch information
ggerganov committed Oct 8, 2025
commit 677b10dda1523b7b159ad8df271b6a9c713d34bf
80 changes: 45 additions & 35 deletions tools/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "sampling.h"
#include "speculative.h"
#include "mtmd.h"
#include "mtmd-helper.h"

// mime type for sending response
#define MIMETYPE_JSON "application/json; charset=utf-8"
Expand Down Expand Up @@ -1439,6 +1438,9 @@ struct server_prompt_cache {
// in bytes, 0 = no limit
size_t limit_size = 2ull*1024*1024*1024;

// in tokens, 0 = no limit
size_t limit_tokens = 0;

size_t size() const {
size_t res = 0;

Expand All @@ -1449,15 +1451,51 @@ struct server_prompt_cache {
return res;
}

int n_tokens() const {
int res = 0;
size_t n_tokens() const {
size_t res = 0;

for (const auto & state : states) {
res += state.n_tokens();
}

return res;
}

// enforce the cache limits (bytes and tokens) by evicting the oldest entries
// note: always keeps at least one cached state, regardless of the limits
void update() {
    if (states.size() > 1) {
        if (limit_size > 0) {
            // evict oldest entries until the total byte size fits,
            // but never drop the last remaining state
            while (states.size() > 1 && size() > limit_size) {
                SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));

                states.pop_front();
            }
        }

        if (limit_tokens > 0) {
            // evict oldest entries until the total token count fits,
            // but never drop the last remaining state
            while (states.size() > 1 && n_tokens() > limit_tokens) {
                SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));

                states.pop_front();
            }
        }
    }

    SRV_WRN(" - cache state: %zu prompts, %.3f MiB, limits: %.3f MiB, %zu tokens\n",
        states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens);

    for (const auto & state : states) {
        SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
    }
}
};

struct server_slot {
Expand Down Expand Up @@ -1805,7 +1843,7 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
const int len = cached_prompt.get_common_prefix(prompt.tokens);

if (len == (int) cached_prompt.size()) {
SRV_WRN(" - removing cached prompt with length %d\n", len);
SRV_WRN(" - removing obsolete cached prompt with length %d\n", len);

it = states.erase(it);
} else {
Expand All @@ -1815,33 +1853,9 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {

const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);

SRV_WRN(" - saving prompt with length %d, total cache size = %.3f MiB\n",
SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n",
(int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0));

// if there is a limit, remove the oldest entries to make room
if (prompt_cache.limit_size > 0) {
while (prompt_cache.size() + cur_size > prompt_cache.limit_size) {
if (states.empty()) {
break;
}

SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));

states.pop_front();
}
} else {
// else, make sure the number of cached tokens doesn't exceed the context size of the slot
while (prompt_cache.n_tokens() + (int) prompt.tokens.size() > n_ctx) {
if (states.empty()) {
break;
}

SRV_WRN(" - cache token limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0));

states.pop_front();
}
}

// TODO: for some reason we can't copy server_tokens, so we have to do this workaround
auto & cur = states.emplace_back();
cur = {
Expand All @@ -1851,12 +1865,6 @@ void server_slot::prompt_save(server_prompt_cache & prompt_cache) {
};

llama_state_seq_get_data_ext(ctx, cur.data.data(), cur_size, id, 0);

SRV_WRN(" - cache state: %zu prompts, %.3f MiB\n", states.size(), prompt_cache.size() / (1024.0 * 1024.0));

for (const auto & state : states) {
SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %.3f MiB\n", (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0));
}
}

void server_slot::prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) {
Expand Down Expand Up @@ -2611,6 +2619,8 @@ struct server_context {
ret->prompt_save(prompt_cache);
ret->prompt_load(prompt_cache, task.tokens);

prompt_cache.update();

SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0);
}
}
Expand Down