3 files changed, +10 −5 lines changed
examples/server/server.cpp

@@ -1738,7 +1738,8 @@ struct server_context {
             }
 
             // process in chunks of params.n_batch
-            int32_t n_batch = params.n_batch;
+            int32_t n_batch  = llama_n_batch(ctx);
+            int32_t n_ubatch = llama_n_ubatch(ctx);
 
             // next, batch any pending prompts without exceeding n_batch
             if (params.cont_batching || batch.n_tokens == 0) {
@@ -1811,7 +1812,7 @@ struct server_context {
                 if (slot.embedding) {
                     // this prompt is too large to process - discard it
-                    if (slot.n_prompt_tokens > n_batch) {
+                    if (slot.n_prompt_tokens > n_ubatch) {
                         slot.state   = SLOT_STATE_PROCESSING;
                         slot.command = SLOT_COMMAND_NONE;
                         slot.release();
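The server previously validated embedding prompts against the logical batch size `n_batch`, but pooled embeddings are produced per micro-batch, so the real limit is `n_ubatch`: a non-causal prompt that does not fit in a single micro-batch cannot be processed at all. A minimal caller-side sketch of the same check (the helper name is hypothetical):

```cpp
#include <cstdio>

#include "llama.h"

// Sketch only, mirroring the server check above: an embedding prompt that
// does not fit in a single micro-batch cannot be pooled and must be rejected.
static bool prompt_fits_ubatch(llama_context * ctx, int32_t n_prompt_tokens) {
    const int32_t n_ubatch = (int32_t) llama_n_ubatch(ctx);
    if (n_prompt_tokens > n_ubatch) {
        fprintf(stderr, "embedding prompt too long: %d > %d\n", n_prompt_tokens, n_ubatch);
        return false;
    }
    return true;
}
```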
llama.cpp

@@ -8774,6 +8774,8 @@ static int llama_decode_internal(
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
+    GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
+
     if (lctx.t_compute_start_us == 0) {
         lctx.t_compute_start_us = ggml_time_us();
     }
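This assertion is what makes the server-side check sufficient: with causal attention a batch of up to `n_batch` tokens can be decoded in several micro-batches of `n_ubatch` tokens, but non-causal attention attends over the whole input at once, so the entire batch must fit into a single micro-batch. One way to honor the invariant when configuring a context for embeddings, sketched with the `embeddings` field name assumed for this branch:

```cpp
#include "llama.h"

// Sketch: for an embedding (non-causal) model, size the micro-batch to the
// logical batch so the assertion above can never fire
// (n_ubatch >= any n_tokens <= n_batch).
static llama_context * make_embedding_ctx(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.n_batch    = 2048;
    cparams.n_ubatch   = cparams.n_batch; // non-causal: one micro-batch covers the whole batch
    cparams.embeddings = true;            // field name assumed for this branch
    return llama_new_context_with_model(model, cparams);
}
```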
@@ -9011,9 +9013,6 @@ static int llama_decode_internal(
         case LLAMA_POOLING_TYPE_CLS:
         case LLAMA_POOLING_TYPE_MEAN:
             {
-                // FIXME: this may not work if the sequences are split into different batches
-                GGML_ASSERT(n_tokens_all == n_tokens);
-
                 GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
 
                 // extract sequence embeddings
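The removed FIXME guarded exactly the case the new assertion now rules out: if a sequence were split across micro-batches, CLS/MEAN pooling would only see part of it. Since a non-causal batch is now guaranteed to fit in one micro-batch, pooling over the whole sequence is safe. For illustration, mean pooling over per-token embeddings, as a sketch assuming a row-major [n_tokens][n_embd] layout:

```cpp
// Sketch: MEAN pooling over per-token embeddings laid out row-major as
// [n_tokens][n_embd]. This is only correct if every token of the sequence
// is present, which is what the n_ubatch guarantee above ensures.
static void mean_pool(const float * embd, int n_tokens, int n_embd, float * out) {
    for (int i = 0; i < n_embd; i++) {
        out[i] = 0.0f;
    }
    for (int t = 0; t < n_tokens; t++) {
        for (int i = 0; i < n_embd; i++) {
            out[i] += embd[t * n_embd + i];
        }
    }
    for (int i = 0; i < n_embd; i++) {
        out[i] /= (float) n_tokens;
    }
}
```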
@@ -13076,6 +13075,10 @@ uint32_t llama_n_batch(const struct llama_context * ctx) {
     return ctx->cparams.n_batch;
 }
 
+uint32_t llama_n_ubatch(const struct llama_context * ctx) {
+    return ctx->cparams.n_ubatch;
+}
+
 uint32_t llama_n_seq_max(const struct llama_context * ctx) {
     return ctx->kv_self.size;
 }
llama.h

@@ -378,6 +378,7 @@ extern "C" {
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
+    LLAMA_API uint32_t llama_n_ubatch   (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_seq_max  (const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
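With the declaration in place, callers can query the effective sizes from a live context instead of re-deriving them from the parameters they passed in. A small usage sketch (the helper name is hypothetical):

```cpp
#include <cstdio>

#include "llama.h"

// Sketch: read back the effective context, batch, and micro-batch sizes.
static void print_ctx_sizes(const llama_context * ctx) {
    printf("n_ctx = %u, n_batch = %u, n_ubatch = %u\n",
           llama_n_ctx(ctx), llama_n_batch(ctx), llama_n_ubatch(ctx));
}
```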