@@ -758,7 +758,8 @@ float * llama_context::get_embeddings_ith(int32_t i) {
758758 throw std::runtime_error (format (" corrupt output buffer (j=%" PRId64 " , n_outputs=%d)" , j, n_outputs));
759759 }
760760
761- return embd + j*model.hparams .n_embd ;
761+ const uint32_t n_embd_out = model.hparams .get_n_embd_out ();
762+ return embd + j*n_embd_out;
762763 } catch (const std::exception & err) {
763764 LLAMA_LOG_ERROR (" %s: invalid embeddings id %d, reason: %s\n " , __func__, i, err.what ());
764765#ifndef NDEBUG
@@ -1194,9 +1195,10 @@ int llama_context::encode(const llama_batch & batch_inp) {
11941195 {
11951196 // extract token embeddings
11961197 GGML_ASSERT (embd != nullptr );
1198+ const uint32_t n_embd_out = hparams.get_n_embd_out ();
11971199
1198- GGML_ASSERT (n_tokens*n_embd <= (int64_t ) embd_size);
1199- ggml_backend_tensor_get_async (backend_embd, t_embd, embd, 0 , n_tokens*n_embd *sizeof (float ));
1200+ GGML_ASSERT (n_tokens*n_embd_out <= (int64_t ) embd_size);
1201+ ggml_backend_tensor_get_async (backend_embd, t_embd, embd, 0 , n_tokens*n_embd_out *sizeof (float ));
12001202 } break ;
12011203 case LLAMA_POOLING_TYPE_MEAN :
12021204 case LLAMA_POOLING_TYPE_CLS :
@@ -1215,17 +1217,17 @@ int llama_context::encode(const llama_batch & batch_inp) {
12151217 } break ;
12161218 case LLAMA_POOLING_TYPE_RANK :
12171219 {
1218- // extract the rerank score - n_cls_out floats per sequence
1220+ // extract the rerank score - n_embd_out floats per sequence
12191221 auto & embd_seq_out = embd_seq;
12201222
1221- const uint32_t n_cls_out = hparams.n_cls_out ;
1223+ const uint32_t n_embd_out = hparams.get_n_embd_out () ;
12221224
12231225 for (uint32_t s = 0 ; s < ubatch.n_seqs_unq ; ++s) {
12241226 const llama_seq_id seq_id = ubatch.seq_id_unq [s];
12251227 const int32_t seq_idx = ubatch.seq_idx [seq_id];
12261228
1227- embd_seq_out[seq_id].resize (n_cls_out );
1228- ggml_backend_tensor_get_async (backend_embd, t_embd, embd_seq_out[seq_id].data (), (n_cls_out *seq_idx)*sizeof (float ), n_cls_out *sizeof (float ));
1229+ embd_seq_out[seq_id].resize (n_embd_out );
1230+ ggml_backend_tensor_get_async (backend_embd, t_embd, embd_seq_out[seq_id].data (), (n_embd_out *seq_idx)*sizeof (float ), n_embd_out *sizeof (float ));
12291231 }
12301232 } break ;
12311233 case LLAMA_POOLING_TYPE_UNSPECIFIED :
@@ -1600,12 +1602,13 @@ int llama_context::decode(const llama_batch & batch_inp) {
16001602 {
16011603 // extract token embeddings
16021604 GGML_ASSERT (embd != nullptr );
1603- float * embd_out = embd + n_outputs_prev*n_embd;
1605+ const uint32_t n_embd_out = hparams.get_n_embd_out ();
1606+ float * embd_out = embd + n_outputs_prev*n_embd_out;
16041607
16051608 if (n_outputs) {
16061609 GGML_ASSERT ( n_outputs_prev + n_outputs <= n_outputs_all);
1607- GGML_ASSERT ((n_outputs_prev + n_outputs)*n_embd <= (int64_t ) embd_size);
1608- ggml_backend_tensor_get_async (backend_embd, t_embd, embd_out, 0 , n_outputs*n_embd *sizeof (float ));
1610+ GGML_ASSERT ((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t ) embd_size);
1611+ ggml_backend_tensor_get_async (backend_embd, t_embd, embd_out, 0 , n_outputs*n_embd_out *sizeof (float ));
16091612 }
16101613 } break ;
16111614 case LLAMA_POOLING_TYPE_MEAN :
@@ -1625,17 +1628,17 @@ int llama_context::decode(const llama_batch & batch_inp) {
16251628 } break ;
16261629 case LLAMA_POOLING_TYPE_RANK :
16271630 {
1628- // extract the rerank score - n_cls_out floats per sequence
1631+ // extract the rerank score - n_embd_out floats per sequence
16291632 auto & embd_seq_out = embd_seq;
16301633
1631- const uint32_t n_cls_out = hparams.n_cls_out ;
1634+ const uint32_t n_embd_out = hparams.get_n_embd_out () ;
16321635
16331636 for (uint32_t s = 0 ; s < ubatch.n_seqs_unq ; ++s) {
16341637 const llama_seq_id seq_id = ubatch.seq_id_unq [s];
16351638 const int32_t seq_idx = ubatch.seq_idx [seq_id];
16361639
1637- embd_seq_out[seq_id].resize (n_cls_out );
1638- ggml_backend_tensor_get_async (backend_embd, t_embd, embd_seq_out[seq_id].data (), (n_cls_out *seq_idx)*sizeof (float ), n_cls_out *sizeof (float ));
1640+ embd_seq_out[seq_id].resize (n_embd_out );
1641+ ggml_backend_tensor_get_async (backend_embd, t_embd, embd_seq_out[seq_id].data (), (n_embd_out *seq_idx)*sizeof (float ), n_embd_out *sizeof (float ));
16391642 }
16401643 } break ;
16411644 case LLAMA_POOLING_TYPE_UNSPECIFIED :
@@ -1730,9 +1733,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
17301733
17311734 const int64_t n_outputs_max = std::max<int64_t >(n_outputs, n_seq_max ());
17321735
1733- const auto n_batch = cparams.n_batch ;
1734- const auto n_vocab = vocab.n_tokens ();
1735- const auto n_embd = hparams.n_embd ;
1736+ const auto n_batch = cparams.n_batch ;
1737+ const auto n_vocab = vocab.n_tokens ();
1738+ const auto n_embd_out = hparams.get_n_embd_out () ;
17361739
17371740 bool has_logits = true ;
17381741 bool has_embd = cparams.embeddings ;
@@ -1773,7 +1776,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
17731776
17741777 // Allocate CPU logits buffer only if needed by sequences in this batch
17751778 logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0 ;
1776- embd_size = has_embd ? n_embd *n_outputs_max : 0 ;
1779+ embd_size = has_embd ? n_embd_out *n_outputs_max : 0 ;
17771780
17781781 // TODO: avoid this branching by working with the worst-case
17791782 if (!has_sampling) {
0 commit comments