Skip to content

Commit 4a14540

Browse files
committed
model-conversion : add detect_pooling script
This commit adds a Python script to automatically detect the pooling configuration from a sentence-transformers model directory. The motivation for this change is that I made a mistake when adding the sentence-transformers support: I incorrectly assumed that if an embedding model uses sentence-transformers, it always uses pooling. With the recent addition of support for late interaction models, which can have a down-projection but do not use pooling (like LFM2-ColBert-350M), that assumption no longer holds. This commit builds upon #18464, which needs to be merged first. Refs: #18607 (comment)
1 parent 193ee38 commit 4a14540

3 files changed

Lines changed: 56 additions & 18 deletions

File tree

examples/model-conversion/Makefile

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,17 @@ embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1
136136
embedding-run-original-model-st: embedding-run-original-model
137137

138138
embedding-run-converted-model:
139-
@./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
140-
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \
141-
$(if $(USE_POOLING),--pooling)
142-
143-
embedding-run-converted-model-st: USE_POOLING=1
144-
embedding-run-converted-model-st: embedding-run-converted-model
139+
@POOLING_FLAG=$$(./scripts/utils/detect_pooling.py $(EMBEDDING_MODEL_PATH)); \
140+
echo "pooling: $$POOLING_FLAG"; \
141+
./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \
142+
--pooling "$$POOLING_FLAG" \
143+
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
145144

146145
embedding-verify-logits: embedding-run-original-model embedding-run-converted-model
147146
@./scripts/embedding/compare-embeddings-logits.sh \
148147
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
149148

150-
embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model-st
149+
embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model
151150
@./scripts/embedding/compare-embeddings-logits.sh \
152151
$(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)")
153152

examples/model-conversion/scripts/embedding/run-converted-model.sh

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
#!/usr/bin/env bash
22

3-
set -e
3+
set -ex
44

55
# Parse command line arguments
66
CONVERTED_MODEL=""
77
PROMPTS_FILE=""
8-
USE_POOLING=""
8+
POOLING=""
99

1010
while [[ $# -gt 0 ]]; do
1111
case $1 in
@@ -14,8 +14,8 @@ while [[ $# -gt 0 ]]; do
1414
shift 2
1515
;;
1616
--pooling)
17-
USE_POOLING="1"
18-
shift
17+
POOLING="$2"
18+
shift 2
1919
;;
2020
*)
2121
if [ -z "$CONVERTED_MODEL" ]; then
@@ -50,10 +50,5 @@ fi
5050

5151
echo $CONVERTED_MODEL
5252

53-
cmake --build ../../build --target llama-logits -j8
54-
# TODO: update logits.cpp to accept a --file/-f option for the prompt
55-
if [ -n "$USE_POOLING" ]; then
56-
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
57-
else
58-
../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
59-
fi
53+
cmake --build ../../build --target llama-debug -j8
54+
../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding $POOLING -p "$PROMPT" --save-logits
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Detect pooling configuration from sentence-transformers model.
4+
Usage: detect_pooling.py <model_dir>
5+
Outputs: pooling flag for llama-cli (e.g., "--pooling mean") or "--pooling none"
6+
"""
7+
8+
import sys
9+
import json
10+
from pathlib import Path
11+
12+
def detect_pooling(model_dir: str) -> str:
    """Return the pooling command-line flag for a sentence-transformers model.

    Searches ``model_dir`` for ``*_Pooling/config.json`` files and maps the
    first enabled ``pooling_mode_*`` key of the selected config to a flag.

    Args:
        model_dir: Path to the model directory to inspect.

    Returns:
        ``"--pooling mean"``, ``"--pooling cls"`` or ``"--pooling last"`` when
        the corresponding mode is enabled (checked in that order);
        ``"--pooling none"`` when no pooling config exists or no supported
        mode is enabled (a warning is printed to stderr in the latter case);
        ``""`` when the config cannot be read or parsed (the error is printed
        to stderr).
    """
    model_path = Path(model_dir)

    # Path.glob() yields entries in filesystem order, which is not guaranteed
    # to be stable. Sort so the same config is chosen on every run when more
    # than one *_Pooling directory is present.
    pooling_configs = sorted(model_path.glob("*_Pooling/config.json"))

    if not pooling_configs:
        return "--pooling none"

    # Ordered (config key, flag) pairs; the first enabled key wins.
    supported_modes = (
        ("pooling_mode_mean_tokens", "--pooling mean"),
        ("pooling_mode_cls_token", "--pooling cls"),
        ("pooling_mode_lasttoken", "--pooling last"),
    )

    config_path = pooling_configs[0]
    try:
        # JSON files are UTF-8; do not depend on the locale's default encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        for key, flag in supported_modes:
            if config.get(key, False):
                return flag

        print(f"Warning: Unsupported pooling mode in {config_path}", file=sys.stderr)
        return "--pooling none"

    except Exception as e:
        # Best effort: report the problem and emit no flag rather than crash
        # the calling shell pipeline.
        print(f"Error reading pooling config: {e}", file=sys.stderr)
        return ""
38+
39+
if __name__ == "__main__":
    # Script entry point: expects exactly one argument, the model directory.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        print("Usage: detect_pooling.py <model_dir>", file=sys.stderr)
        sys.exit(1)

    # Emit the detected flag on stdout so a caller can capture it.
    (model_dir_arg,) = cli_args
    print(detect_pooling(model_dir_arg))

0 commit comments

Comments
 (0)