Skip to content

Commit e8a0336

Browse files
committed
Use n_cls_out for pooling rank
1 parent: 7ba071a; commit: e8a0336

2 files changed

Lines changed: 8 additions & 9 deletions

File tree

src/llama-context.cpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1217,17 +1217,17 @@ int llama_context::encode(const llama_batch & batch_inp) {
12171217
} break;
12181218
case LLAMA_POOLING_TYPE_RANK:
12191219
{
1220-
// extract the rerank score - n_embd_out floats per sequence
1220+
// extract the rerank score - n_cls_out floats per sequence
12211221
auto & embd_seq_out = embd_seq;
12221222

1223-
const uint32_t n_embd_out = hparams.get_n_embd_out();
1223+
const uint32_t n_cls_out = hparams.n_cls_out;
12241224

12251225
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
12261226
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
12271227
const int32_t seq_idx = ubatch.seq_idx[seq_id];
12281228

1229-
embd_seq_out[seq_id].resize(n_embd_out);
1230-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float));
1229+
embd_seq_out[seq_id].resize(n_cls_out);
1230+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
12311231
}
12321232
} break;
12331233
case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1628,17 +1628,17 @@ int llama_context::decode(const llama_batch & batch_inp) {
16281628
} break;
16291629
case LLAMA_POOLING_TYPE_RANK:
16301630
{
1631-
// extract the rerank score - n_embd_out floats per sequence
1631+
// extract the rerank score - n_cls_out floats per sequence
16321632
auto & embd_seq_out = embd_seq;
16331633

1634-
const uint32_t n_embd_out = hparams.get_n_embd_out();
1634+
const uint32_t n_cls_out = hparams.n_cls_out;
16351635

16361636
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
16371637
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
16381638
const int32_t seq_idx = ubatch.seq_idx[seq_id];
16391639

1640-
embd_seq_out[seq_id].resize(n_embd_out);
1641-
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float));
1640+
embd_seq_out[seq_id].resize(n_cls_out);
1641+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
16421642
}
16431643
} break;
16441644
case LLAMA_POOLING_TYPE_UNSPECIFIED:

src/llama-model.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -628,7 +628,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
628628
ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false);
629629
if (!classifier_labels.empty()) {
630630
hparams.n_cls_out = classifier_labels.size();
631-
hparams.n_embd_out = classifier_labels.size();
632631
}
633632

634633
// arch-specific KVs

0 commit comments

Comments
 (0)