
Commit 202adca

check n_ubatch >= n_tokens with non-causal attention
1 parent 54cdd47 commit 202adca

3 files changed: +10 -5 lines

examples/server/server.cpp

Lines changed: 3 additions & 2 deletions
@@ -1738,7 +1738,8 @@ struct server_context {
         }

         // process in chunks of params.n_batch
-        int32_t n_batch = params.n_batch;
+        int32_t n_batch = llama_n_batch(ctx);
+        int32_t n_ubatch = llama_n_ubatch(ctx);

         // next, batch any pending prompts without exceeding n_batch
         if (params.cont_batching || batch.n_tokens == 0) {
@@ -1811,7 +1812,7 @@ struct server_context {

             if (slot.embedding) {
                 // this prompt is too large to process - discard it
-                if (slot.n_prompt_tokens > n_batch) {
+                if (slot.n_prompt_tokens > n_ubatch) {
                     slot.state = SLOT_STATE_PROCESSING;
                     slot.command = SLOT_COMMAND_NONE;
                     slot.release();
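
The server change swaps the per-prompt size check for embedding slots from the logical batch size to the physical micro-batch size. A minimal sketch of the same check in isolation, assuming ctx is a valid llama_context and n_prompt_tokens is the tokenized prompt length (the helper name is made up for illustration, it is not part of server.cpp):

// Hypothetical helper: an embedding prompt must fit into a single micro-batch,
// because the non-causal attention used for pooled embeddings cannot be split
// across ubatches.
static bool prompt_fits_in_ubatch(const struct llama_context * ctx, int32_t n_prompt_tokens) {
    const int32_t n_ubatch = (int32_t) llama_n_ubatch(ctx); // physical micro-batch size
    return n_prompt_tokens <= n_ubatch;
}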

llama.cpp

Lines changed: 6 additions & 3 deletions
@@ -8774,6 +8774,8 @@ static int llama_decode_internal(

     GGML_ASSERT(n_tokens_all <= cparams.n_batch);

+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
     if (lctx.t_compute_start_us == 0) {
         lctx.t_compute_start_us = ggml_time_us();
     }
@@ -9011,9 +9013,6 @@ static int llama_decode_internal(
                 case LLAMA_POOLING_TYPE_CLS:
                 case LLAMA_POOLING_TYPE_MEAN:
                     {
-                        // FIXME: this may not work if the sequences are split into different batches
-                        GGML_ASSERT(n_tokens_all == n_tokens);
-
                         GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);

                         // extract sequence embeddings
@@ -13076,6 +13075,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
     return ctx->cparams.n_batch;
 }

+uint32_t llama_n_ubatch(const struct llama_context * ctx) {
+    return ctx->cparams.n_ubatch;
+}
+
 uint32_t llama_n_seq_max(const struct llama_context * ctx) {
     return ctx->kv_self.size;
 }
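
The new assert makes the constraint explicit: when the context does not use causal attention (e.g. pooled embeddings), the whole submitted batch must fit in one micro-batch. A rough sketch of how a caller might size the context so the assert holds; the values are placeholders, and the embeddings flag is assumed to be the embeddings field of llama_context_params in this revision:

// Illustrative setup only - values are placeholders, not recommendations.
struct llama_context_params cparams = llama_context_default_params();
cparams.embeddings = true;   // assumed field name for the non-causal embedding path
cparams.n_batch    = 2048;   // logical batch size (upper bound per llama_decode call)
cparams.n_ubatch   = 2048;   // physical micro-batch; must cover the largest embedding prompt

struct llama_context * ctx = llama_new_context_with_model(model, cparams);

// Callers can check the same condition up front instead of tripping the assert:
if (batch.n_tokens > (int32_t) llama_n_ubatch(ctx)) {
    // reject or truncate the request - it cannot be decoded with non-causal attention
}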

llama.h

Lines changed: 1 addition & 0 deletions
@@ -378,6 +378,7 @@ extern "C" {

     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);

     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
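
The header gains a getter that mirrors llama_n_batch. A quick usage fragment, assuming ctx is an already-created llama_context (context creation and includes omitted):

// n_batch is the logical batch size accepted by a llama_decode call;
// n_ubatch is the physical micro-batch size evaluated per graph split.
const uint32_t n_batch  = llama_n_batch(ctx);
const uint32_t n_ubatch = llama_n_ubatch(ctx);
fprintf(stderr, "n_batch = %u, n_ubatch = %u\n", n_batch, n_ubatch);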
