@@ -48,9 +48,16 @@ int main(int argc, char ** argv) {
48
48
// tokenize prompt
49
49
auto tokens = common_tokenize (ctx, params.prompt , true );
50
50
51
+ // prepare the batch
52
+ llama_batch batch = llama_batch_init (tokens.size (), 0 , 1 );
53
+ for (size_t i = 0 ; i < tokens.size (); i++) {
54
+ common_batch_add (batch, tokens[i], i, {0 }, false );
55
+ }
56
+ batch.logits [batch.n_tokens - 1 ] = true ; // generate next token
57
+
51
58
// evaluate prompt
52
- llama_decode (ctx, llama_batch_get_one (tokens. data (), tokens. size ()) );
53
- n_past += tokens. size () ;
59
+ llama_decode (ctx, batch );
60
+ n_past += batch. n_tokens ;
54
61
55
62
// save state (rng, logits, embedding and kv_cache) to file
56
63
{
@@ -77,8 +84,12 @@ int main(int argc, char ** argv) {
77
84
printf (" %s" , next_token_str.c_str ());
78
85
result0 += next_token_str;
79
86
80
- if (llama_decode (ctx, llama_batch_get_one (&next_token, 1 ))) {
87
+ common_batch_clear (batch);
88
+ common_batch_add (batch, next_token, n_past, {0 }, true );
89
+
90
+ if (llama_decode (ctx, batch)) {
81
91
fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
92
+ llama_batch_free (batch);
82
93
llama_free (ctx);
83
94
llama_free_model (model);
84
95
return 1 ;
@@ -133,8 +144,12 @@ int main(int argc, char ** argv) {
133
144
printf (" %s" , next_token_str.c_str ());
134
145
result1 += next_token_str;
135
146
136
- if (llama_decode (ctx2, llama_batch_get_one (&next_token, 1 ))) {
147
+ common_batch_clear (batch);
148
+ common_batch_add (batch, next_token, n_past, {0 }, true );
149
+
150
+ if (llama_decode (ctx2, batch)) {
137
151
fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
152
+ llama_batch_free (batch);
138
153
llama_free (ctx2);
139
154
llama_free_model (model);
140
155
return 1 ;
@@ -221,8 +236,12 @@ int main(int argc, char ** argv) {
221
236
printf (" %s" , next_token_str.c_str ());
222
237
result2 += next_token_str;
223
238
224
- if (llama_decode (ctx3, llama_batch_get_one (&next_token, 1 ))) {
239
+ common_batch_clear (batch);
240
+ common_batch_add (batch, next_token, n_past, {1 }, true );
241
+
242
+ if (llama_decode (ctx3, batch)) {
225
243
fprintf (stderr, " \n %s : failed to evaluate\n " , __func__);
244
+ llama_batch_free (batch);
226
245
llama_free (ctx3);
227
246
llama_free_model (model);
228
247
return 1 ;
@@ -236,6 +255,7 @@ int main(int argc, char ** argv) {
236
255
llama_sampler_free (smpl2);
237
256
llama_sampler_free (smpl3);
238
257
258
+ llama_batch_free (batch);
239
259
llama_free (ctx3);
240
260
llama_free_model (model);
241
261
0 commit comments