Skip to content

Commit 6395174

Browse files
committed
fix save-load-state example
1 parent 7264596 commit 6395174

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

examples/save-load-state/save-load-state.cpp

+25 -5
Original file line number | Diff line number | Diff line change
@@ -48,9 +48,16 @@ int main(int argc, char ** argv) {
4848
// tokenize prompt
4949
auto tokens = common_tokenize(ctx, params.prompt, true);
5050

51+
// prepare the batch
52+
llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
53+
for (size_t i = 0; i < tokens.size(); i++) {
54+
common_batch_add(batch, tokens[i], i, {0}, false);
55+
}
56+
batch.logits[batch.n_tokens - 1] = true; // generate next token
57+
5158
// evaluate prompt
52-
llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
53-
n_past += tokens.size();
59+
llama_decode(ctx, batch);
60+
n_past += batch.n_tokens;
5461

5562
// save state (rng, logits, embedding and kv_cache) to file
5663
{
@@ -77,8 +84,12 @@ int main(int argc, char ** argv) {
7784
printf("%s", next_token_str.c_str());
7885
result0 += next_token_str;
7986

80-
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1))) {
87+
common_batch_clear(batch);
88+
common_batch_add(batch, next_token, n_past, {0}, true);
89+
90+
if (llama_decode(ctx, batch)) {
8191
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
92+
llama_batch_free(batch);
8293
llama_free(ctx);
8394
llama_free_model(model);
8495
return 1;
@@ -133,8 +144,12 @@ int main(int argc, char ** argv) {
133144
printf("%s", next_token_str.c_str());
134145
result1 += next_token_str;
135146

136-
if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1))) {
147+
common_batch_clear(batch);
148+
common_batch_add(batch, next_token, n_past, {0}, true);
149+
150+
if (llama_decode(ctx2, batch)) {
137151
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
152+
llama_batch_free(batch);
138153
llama_free(ctx2);
139154
llama_free_model(model);
140155
return 1;
@@ -221,8 +236,12 @@ int main(int argc, char ** argv) {
221236
printf("%s", next_token_str.c_str());
222237
result2 += next_token_str;
223238

224-
if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) {
239+
common_batch_clear(batch);
240+
common_batch_add(batch, next_token, n_past, {1}, true);
241+
242+
if (llama_decode(ctx3, batch)) {
225243
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
244+
llama_batch_free(batch);
226245
llama_free(ctx3);
227246
llama_free_model(model);
228247
return 1;
@@ -236,6 +255,7 @@ int main(int argc, char ** argv) {
236255
llama_sampler_free(smpl2);
237256
llama_sampler_free(smpl3);
238257

258+
llama_batch_free(batch);
239259
llama_free(ctx3);
240260
llama_free_model(model);
241261

0 commit comments

Comments
 (0)