
Commit 734f9e2

use common_batch_add, reuse llama_batch in loop
1 parent b4c9911 commit 734f9e2

2 files changed: +12 −14 lines
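Both files switch from filling the llama_batch fields by hand to the common_batch_clear / common_batch_add helpers. For context, here is a sketch of what those helpers do, matching their declarations in common/common.h; the bodies are reconstructed for illustration, so treat common/common.cpp as the authoritative implementation:

#include "llama.h"
#include <vector>

// Reset the batch so it can be refilled on the next loop iteration.
void common_batch_clear(llama_batch & batch) {
    batch.n_tokens = 0;
}

// Append one token to the batch: its id, its position, the sequence ids
// it belongs to, and whether logits should be computed for it.
void common_batch_add(llama_batch & batch, llama_token id, llama_pos pos,
                      const std::vector<llama_seq_id> & seq_ids, bool logits) {
    batch.token   [batch.n_tokens] = id;
    batch.pos     [batch.n_tokens] = pos;
    batch.n_seq_id[batch.n_tokens] = (int32_t) seq_ids.size();
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
    }
    batch.logits  [batch.n_tokens] = logits;

    batch.n_tokens++;
}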

examples/imatrix/imatrix.cpp (+6 −7)
@@ -496,6 +496,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_cache_clear(ctx);

+        llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
@@ -508,12 +510,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }

-            llama_batch batch = llama_batch_init(batch_size, 0, 1);
+            common_batch_clear(batch);
             for (int i = 0; i < batch_size; i++) {
-                batch.token [i]    = tokens[batch_start + i];
-                batch.pos   [i]    = j*n_batch + i;
-                batch.logits[i]    = true;
-                batch.seq_id[i][0] = 0;
+                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
             }

             if (llama_decode(ctx, batch)) {
@@ -522,8 +521,6 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
                 return false;
             }

-            llama_batch_free(batch);
-
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;

@@ -533,6 +530,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
             }
         }

+        llama_batch_free(batch);
+
         const auto t_end = std::chrono::high_resolution_clock::now();

         if (i == 0) {
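The change in both files is the same pattern: allocate one llama_batch sized for the largest chunk (n_batch) before the loop, refill it each iteration, and free it once afterwards, instead of an init/free pair per iteration. A minimal sketch of that pattern in isolation — decode_in_chunks is a hypothetical helper written for illustration, not code from this commit:

#include "common.h"
#include "llama.h"
#include <algorithm>
#include <vector>

// Decode `tokens` in n_batch-sized chunks, reusing a single batch.
static bool decode_in_chunks(llama_context * ctx, const std::vector<llama_token> & tokens, int n_batch) {
    const int n_tokens    = (int) tokens.size();
    const int num_batches = (n_tokens + n_batch - 1) / n_batch;

    llama_batch batch = llama_batch_init(n_batch, 0, 1); // allocate once, worst-case size

    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = j * n_batch;
        const int batch_size  = std::min(n_tokens - batch_start, n_batch);

        common_batch_clear(batch); // reuse instead of re-allocating
        for (int i = 0; i < batch_size; i++) {
            common_batch_add(batch, tokens[batch_start + i], batch_start + i, {0}, true);
        }

        if (llama_decode(ctx, batch)) {
            llama_batch_free(batch); // don't leak on the error path
            return false;
        }
    }

    llama_batch_free(batch); // free once, after the loop
    return true;
}

Note that the reused batch is sized n_batch (the loop's maximum) rather than the per-iteration batch_size, so every iteration's tokens fit without reallocation.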

examples/perplexity/perplexity.cpp (+6 −7)
@@ -1800,6 +1800,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_cache_clear(ctx);

+        llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
@@ -1812,12 +1814,9 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }

-            llama_batch batch = llama_batch_init(batch_size, 0, 1);
+            common_batch_clear(batch);
             for (int i = 0; i < batch_size; i++) {
-                batch.token [i]    = tokens[batch_start + i];
-                batch.pos   [i]    = j*n_batch + i;
-                batch.logits[i]    = true;
-                batch.seq_id[i][0] = 0;
+                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
             }

             if (llama_decode(ctx, batch)) {
@@ -1826,8 +1825,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
                 return;
             }

-            llama_batch_free(batch);
-
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;

@@ -1837,6 +1834,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
             }
         }

+        llama_batch_free(batch);
+
         const auto t_end = std::chrono::high_resolution_clock::now();

         if (i == 0) {
