From 519a8253eebfe815c716fc089c32907f5718c2d7 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 15:08:17 +0000 Subject: [PATCH 01/43] sampler: turn lazy grammar trigger words to regexes --- common/sampling.cpp | 32 +++++++++++++++++++++++------ examples/server/server.cpp | 3 ++- include/llama.h | 8 ++++---- src/llama-grammar.cpp | 41 +++++++++++++++++++------------------- src/llama-grammar.h | 10 +++++++--- src/llama-sampling.cpp | 20 ++++++++++--------- 6 files changed, 70 insertions(+), 44 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 37a0d9c85ae30..9eea0f749f3be 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -3,6 +3,7 @@ #include "common.h" #include +#include #include // the ring buffer works similarly to std::deque, but with a fixed capacity @@ -144,6 +145,11 @@ std::string common_params_sampling::print() const { return std::string(result); } +inline std::string regex_escape(const std::string & literal) { + static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); + return std::regex_replace(literal, special_chars, "\\$0"); +} + struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { const llama_vocab * vocab = llama_model_get_vocab(model); @@ -159,15 +165,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE } else { - std::vector trigger_words; - trigger_words.reserve(params.grammar_trigger_words.size()); - for (const auto & str : params.grammar_trigger_words) { - trigger_words.push_back(str.word.c_str()); + std::vector escaped_triggers_at_start; + std::vector escaped_triggers_anywhere; + for (const auto & trigger : params.grammar_trigger_words) { + (trigger.at_start ? escaped_triggers_at_start : escaped_triggers_anywhere) + .push_back(regex_escape(trigger.word)); + } + + std::vector trigger_regexes; + if (!escaped_triggers_at_start.empty()) { + trigger_regexes.push_back("^(" + string_join(escaped_triggers_at_start, "|") + ")[\\s\\S]*"); + } + if (!escaped_triggers_anywhere.empty()) { + trigger_regexes.push_back("^[\\s\\S]*?(" + string_join(escaped_triggers_anywhere, "|") + ")[\\s\\S]*"); + } + + std::vector trigger_regexes_c; + trigger_regexes_c.reserve(trigger_regexes.size()); + for (const auto & regex : trigger_regexes) { + trigger_regexes_c.push_back(regex.c_str()); } grmr = params.grammar_lazy ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", - trigger_words.data(), trigger_words.size(), + trigger_regexes_c.data(), trigger_regexes_c.size(), params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()) : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); } @@ -202,7 +223,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co for (const auto & str : params.dry_sequence_breakers) { c_breakers.push_back(str.c_str()); } - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 809bfe0e36cd7..bdea828aeba3f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -170,6 +170,7 @@ struct slot_params { {"n_probs", sampling.n_probs}, {"min_keep", sampling.min_keep}, {"grammar", sampling.grammar}, + {"grammar_lazy", sampling.grammar_lazy}, {"grammar_trigger_words", grammar_trigger_words}, {"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, {"preserved_tokens", sampling.preserved_tokens}, @@ -2045,7 +2046,7 @@ struct server_context { if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? - SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict); + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict); slot.params.n_predict = slot.n_predict; } diff --git a/include/llama.h b/include/llama.h index b0726cbe63ea6..b26f2b05e91f8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1204,14 +1204,14 @@ extern "C" { const char * grammar_root); /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 - /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future. - /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. + /// @param trigger_regexes A list of (full-string) regexes that will trigger the grammar sampler. Grammar sampler will be fed content starting from the first match group. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included. LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_regexes, + size_t num_trigger_regexes, const llama_token * trigger_tokens, size_t num_trigger_tokens); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 46e27a96ed728..06c3d188d4c2e 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -969,7 +969,7 @@ struct llama_grammar * llama_grammar_init_impl( /* .awaiting_trigger = */ false, /* .trigger_buffer = */ "", /* .trigger_tokens = */ {}, - /* .trigger_words = */ {}, + /* .trigger_regexes = */ {}, }; } @@ -978,19 +978,15 @@ struct llama_grammar * llama_grammar_init_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_regexes, + size_t num_trigger_regexes, const llama_token * trigger_tokens, size_t num_trigger_tokens) { llama_grammar_parser parser; // if there is a grammar, parse it - if (!parser.parse(grammar_str)) { - return nullptr; - } - - // will be empty (default) if there are parse errors - if (parser.rules.empty()) { + // rules will be empty (default) if there are parse errors + if (!parser.parse(grammar_str) || parser.rules.empty()) { fprintf(stderr, "%s: failed to parse grammar\n", __func__); return nullptr; } @@ -1054,14 +1050,14 @@ struct llama_grammar * llama_grammar_init_impl( } while (true); std::vector vec_trigger_tokens; - std::vector vec_trigger_words; + std::vector> vec_trigger_regexes; for (size_t i = 0; i < num_trigger_tokens; i++) { GGML_ASSERT(trigger_tokens != nullptr); vec_trigger_tokens.push_back(trigger_tokens[i]); } - for (size_t i = 0; i < num_trigger_words; i++) { - GGML_ASSERT(trigger_words != nullptr); - vec_trigger_words.push_back(trigger_words[i]); + for (size_t i = 0; i < num_trigger_regexes; i++) { + GGML_ASSERT(trigger_regexes != nullptr); + vec_trigger_regexes.emplace_back(trigger_regexes[i], trigger_regexes[i]); } // Important: vec_rules has to be moved here, not copied, because stacks contains @@ -1076,7 +1072,7 @@ struct llama_grammar * llama_grammar_init_impl( /* .awaiting_trigger = */ lazy, /* .trigger_buffer = */ "", std::move(vec_trigger_tokens), - std::move(vec_trigger_words), + std::move(vec_trigger_regexes), }; } @@ -1098,7 +1094,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.awaiting_trigger, grammar.trigger_buffer, grammar.trigger_tokens, - grammar.trigger_words, + grammar.trigger_regexes, }; // redirect elements in stacks to point to new rules @@ -1173,19 +1169,22 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); return; } else { - // TODO: consider a smarter incremental substring search algorithm (store last position to search from). grammar.trigger_buffer += piece; - for (const auto & word : grammar.trigger_words) { - auto pos = grammar.trigger_buffer.find(word); - if (pos != std::string::npos) { + + std::smatch match; + for (const auto & [_, regex] : grammar.trigger_regexes) { + if (std::regex_match(grammar.trigger_buffer, match, regex)) { grammar.awaiting_trigger = false; - auto constrained_str = grammar.trigger_buffer.substr(pos); + // get from the first match to the end of the string + auto constrained_str = grammar.trigger_buffer.substr(match.position(1)); + // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); grammar.trigger_buffer.clear(); llama_grammar_accept_str(grammar, constrained_str); - LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str()); + LLAMA_LOG_DEBUG("Grammar triggered on regex: %s", constrained_str.c_str()); return; } } + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); return; } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index b143d834cfabe..8d9b1a81dfd1c 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include #include #include @@ -122,7 +123,10 @@ struct llama_grammar { bool awaiting_trigger = false; // Initialized to true for lazy grammars only std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). - std::vector trigger_words; + std::vector> + trigger_regexes; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated + // string, and the grammar will be given the string from the first match group onwards. + }; // @@ -141,8 +145,8 @@ struct llama_grammar * llama_grammar_init_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_regexes, + size_t num_trigger_regexes, const llama_token * trigger_tokens, size_t num_trigger_tokens); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index f40bf2db83a80..7d5f9e86584c1 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1446,8 +1446,8 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_regexes, + size_t num_trigger_regexes, const llama_token * trigger_tokens, size_t num_trigger_tokens); @@ -1457,12 +1457,14 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { return; } - std::vector trigger_words; - for (auto & word : ctx->grammar->trigger_words) { - trigger_words.push_back(word.c_str()); + std::vector trigger_regexes_c; + trigger_regexes_c.reserve(ctx->grammar->trigger_regexes.size()); + for (auto & [pattern, _] : ctx->grammar->trigger_regexes) { + trigger_regexes_c.push_back(pattern.c_str()); } + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), - ctx->grammar->lazy, trigger_words.data(), trigger_words.size(), + ctx->grammar->lazy, trigger_regexes_c.data(), trigger_regexes_c.size(), ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); llama_grammar_free_impl(ctx->grammar); @@ -1513,8 +1515,8 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_regexes, + size_t num_trigger_regexes, const llama_token * trigger_tokens, size_t num_trigger_tokens) { auto * ctx = new llama_sampler_grammar; @@ -1524,7 +1526,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( /* .vocab = */ vocab, /* .grammar_str = */ grammar_str, /* .grammar_root = */ grammar_root, - /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens), + /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_regexes, num_trigger_regexes, trigger_tokens, num_trigger_tokens), }; } else { *ctx = { From b57690af025c9233706823e9ad1249fea7c6d644 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 15:13:32 +0000 Subject: [PATCH 02/43] tool-call: add `script/tool_bench.py` --- examples/server/tests/unit/test_tool_call.py | 121 ++++++----- scripts/plot_tool_call_tests.py | 190 +++++++++++++++++ scripts/tool_bench.py | 203 +++++++++++++++++++ 3 files changed, 467 insertions(+), 47 deletions(-) create mode 100644 scripts/plot_tool_call_tests.py create mode 100755 scripts/tool_bench.py diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index a91a2f3333ca3..4467de2a2e755 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -1,4 +1,12 @@ +#!/usr/bin/env python import pytest + +# ensure grandparent path is in sys.path +from pathlib import Path +import sys +path = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(path)) + from utils import * server: ServerProcess @@ -66,15 +74,8 @@ def create_server(): } -def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None): - global server - n_predict = 512 - # server = ServerPreset.stories15m_moe() - server.jinja = True - server.n_predict = n_predict - server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ +def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs): + res = server.make_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ {"role": "system", "content": "You are a coding assistant."}, @@ -83,16 +84,14 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a "tool_choice": "required", "tools": [tool], "parallel_tool_calls": False, - "temperature": 0.0, - "top_k": 1, - "top_p": 1.0, + **kwargs, }) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] - assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] assert expected_function_name == tool_call["function"]["name"] actual_arguments = tool_call["function"]["arguments"] @@ -108,7 +107,14 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), ]) def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None): - do_test_completion_with_required_tool_tiny(template_name, tool, argument_key) + global server + n_predict = 512 + # server = ServerPreset.stories15m_moe() + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0) @pytest.mark.slow @@ -133,7 +139,14 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"), ]) def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None): - do_test_completion_with_required_tool_tiny(template_name, tool, argument_key) + global server + n_predict = 512 + # server = ServerPreset.stories15m_moe() + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict) @pytest.mark.slow @@ -197,7 +210,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ + res = server.make_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ {"role": "system", "content": "You are a coding assistant."}, @@ -215,7 +228,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] - assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] assert expected_function_name == tool_call["function"]["name"] actual_arguments = tool_call["function"]["arguments"] @@ -225,13 +238,8 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" -def do_test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): - global server - server.jinja = True - server.n_predict = n_predict - server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ +def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs): + res = server.make_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ {"role": "system", "content": "You are a coding assistant."}, @@ -239,9 +247,7 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too ], "tools": tools if tools else None, "tool_choice": tool_choice, - "temperature": 0.0, - "top_k": 1, - "top_p": 1.0, + **kwargs, }, timeout=TIMEOUT_HTTP_REQUEST) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] @@ -254,7 +260,12 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'), ]) def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): - do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice) + global server + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_without_tool_call(server, n_predict, tools, tool_choice) @pytest.mark.slow @@ -270,7 +281,12 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'), ]) def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): - do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice) + global server + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_without_tool_call(server, n_predict, tools, tool_choice) @pytest.mark.slow @@ -324,26 +340,30 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ - "max_tokens": n_predict, + do_test_weather(server, max_tokens=n_predict) + + +def do_test_weather(server: ServerProcess, **kwargs): + res = server.make_request("POST", "/v1/chat/completions", data={ "messages": [ {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."}, {"role": "user", "content": "What is the weather in Istanbul?"}, ], "tools": [WEATHER_TOOL], + **kwargs, }, timeout=TIMEOUT_HTTP_REQUEST) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] - assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"] actual_arguments = json.loads(tool_call["function"]["arguments"]) assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" location = actual_arguments["location"] assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}" - assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' + assert re.match('^Istanbul(, ?(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' @pytest.mark.slow @@ -379,10 +399,14 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ + do_test_calc_result(server, result_override, n_predict) + + +def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs): + res = server.make_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ - {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things, and provide very concise answers. Do not explain your reasoning to the user. Provide any numerical values back to the user with at most two decimals."}, + {"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."}, {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"}, { "role": "assistant", @@ -423,7 +447,8 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, } } } - ] + ], + **kwargs, }, timeout=TIMEOUT_HTTP_REQUEST) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] @@ -434,7 +459,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, if result_override is not None: assert re.match(result_override, content), f'Expected {result_override}, got {content}' else: - assert re.match('^[\\s\\S]*?The (y[ -])?coordinate [\\s\\S]*?is (approximately )?0\\.56\\b|^0\\.56$', content), \ + assert re.match('^[\\s\\S]*?((That\'s|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)', content), \ f'Expected something like "The y coordinate is 0.56.", got {content}' @@ -464,7 +489,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ + res = server.make_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ {"role": "user", "content": "What's the sum of 102 and 7?"}, @@ -476,7 +501,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] content = choice["message"].get("content") if expect_content is None: - assert content is None, f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' else: assert re.match(expect_content, content), f'Expected {expect_content}, got {content}' @@ -528,6 +553,7 @@ def test_hello_world(expected_arguments_override: str | None, hf_repo: str, temp server.jinja = True server.n_ctx = 8192 server.n_predict = 512 # High because of DeepSeek R1 + server.n_predict = n_predict server.model_hf_repo = hf_repo server.model_hf_file = None if isinstance(template_override, tuple): @@ -537,24 +563,25 @@ def test_hello_world(expected_arguments_override: str | None, hf_repo: str, temp elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ - "max_tokens": 256, + + do_test_hello_world(server, expected_arguments_override, max_tokens=n_predict) + + +def do_test_hello_world(server: ServerProcess, expected_arguments_override, **kwargs): + res = server.make_request("POST", "/v1/chat/completions", data={ "messages": [ - {"role": "system", "content": "You are a coding assistant."}, + {"role": "system", "content": "You are a tool-calling agent."}, {"role": "user", "content": "say hello world with python"}, ], "tools": [PYTHON_TOOL], - # Note: without these greedy params, Functionary v3.2 writes `def hello_world():\n print("Hello, World!")\nhello_world()` which is correct but a pain to test. - "temperature": 0.0, - "top_k": 1, - "top_p": 1.0, + **kwargs, }, timeout=TIMEOUT_HTTP_REQUEST) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] - assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] actual_arguments = tool_call["function"]["arguments"] if expected_arguments_override is not None: diff --git a/scripts/plot_tool_call_tests.py b/scripts/plot_tool_call_tests.py new file mode 100644 index 0000000000000..b54aecce526dd --- /dev/null +++ b/scripts/plot_tool_call_tests.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +""" +Model Performance Analysis and Visualization Tool + +This script analyzes JSON performance data for different model implementations and tests, +creating a heatmap visualization of success ratios. It handles multiple input files and +supports various model configurations. + +Usage: + python script.py input_file1.json [input_file2.json ...] +""" + +import json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +import sys +from typing import Dict, List, Tuple, Set, Any +from pathlib import Path +import logging + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +class ModelAnalyzer: + def __init__(self): + self.lines: List[Dict] = [] + self.data_dict: Dict[Tuple, float] = {} + self.models: List[str] = [] + self.temps: Set[float] = set() + self.tests: Set[str] = set() + self.impls: Set[str] = set() + + self.column_groups = [ + ("llama-server", []), # Tests will be populated dynamically + ("llama-server (no grammar)", []), + ("ollama", []) + ] + + def read_files(self, files: List[str]) -> None: + """Read and parse JSON data from input files.""" + for file in files: + path = Path(file) + if not path.exists(): + logger.error(f"File not found: {file}") + continue + + try: + with path.open() as f: + raw_data = f.read() + logger.info(f"Reading {file} ({len(raw_data)} bytes)") + + for line_num, line in enumerate(raw_data.split('\n'), 1): + line = line.strip() + if not line: + continue + try: + record = json.loads(line) + self.lines.append(record) + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON at {file}:{line_num} - {e}") + except Exception as e: + logger.error(f"Error processing {file}: {e}") + + def process_data(self) -> None: + """Process the loaded data and organize it for visualization.""" + for rec in self.lines: + try: + model = rec["model"] + temp = rec["temp"] + impl = rec["implementation"] + test = rec["test"] + success = rec["success_ratio"] + + self.data_dict[(model, temp, impl, test)] = success + + if model not in self.models: + self.models.append(model) + self.temps.add(temp) + self.tests.add(test) + self.impls.add(impl) + + except KeyError as e: + logger.warning(f"Missing required field in record: {e}") + + # Sort the collected values + self.temps = sorted(list(self.temps), key=lambda x: x if x is not None else -1) + self.tests = sorted(list(self.tests)) + + # Update column groups with actual tests + self.column_groups = [ + (impl, list(self.tests)) for impl, _ in self.column_groups + if impl in self.impls + ] + + def create_matrix(self) -> pd.DataFrame: + """Create a matrix for visualization.""" + all_cols = [ + (impl, test) + for impl, tests in self.column_groups + for test in tests + ] + + matrix = [] + index = [] + + for model in self.models: + for temp in self.temps: + index.append(f"{model} @ {temp}") + row_vals = [ + self.data_dict.get((model, temp, impl, test), np.nan) + for impl, test in all_cols + ] + matrix.append(row_vals) + + # Create column labels + col_labels = [f"{impl}\n({test})" for impl, test in all_cols] + + return pd.DataFrame(matrix, index=index, columns=col_labels) + + def plot_heatmap(self, df: pd.DataFrame, output_file: str = None) -> None: + """Create and display/save the heatmap visualization.""" + plt.figure(figsize=(12, 6)) + + sns.heatmap( + df, + annot=True, + cmap="RdYlGn", + vmin=0.0, + vmax=1.0, + cbar=True, + fmt=".2f", + center=0.5, + square=True, + linewidths=0.5, + cbar_kws={"label": "Success Ratio"} + ) + + plt.title("Model Performance Analysis\nSuccess Ratios by Implementation & Test", + pad=20) + plt.xlabel("Implementation and Test", labelpad=10) + plt.ylabel("Model @ Temperature", labelpad=10) + + plt.xticks(rotation=45, ha='right') + plt.yticks(rotation=0) + + plt.tight_layout() + + if output_file: + plt.savefig(output_file, dpi=300, bbox_inches='tight') + logger.info(f"Plot saved to {output_file}") + else: + plt.show() + +def main(): + if len(sys.argv) < 2: + logger.error("Please provide at least one input file") + sys.exit(1) + + analyzer = ModelAnalyzer() + + # Process input files + analyzer.read_files(sys.argv[1:]) + + if not analyzer.lines: + logger.error("No valid data was loaded") + sys.exit(1) + + # Process the data + analyzer.process_data() + + # Log summary statistics + logger.info(f"Processed {len(analyzer.lines)} lines") + logger.info(f"Found {len(analyzer.data_dict)} valid data points") + logger.info(f"Models: {analyzer.models}") + logger.info(f"Temperatures: {analyzer.temps}") + logger.info(f"Tests: {analyzer.tests}") + logger.info(f"Implementations: {analyzer.impls}") + + # Create and plot the visualization + df = analyzer.create_matrix() + # analyzer.plot_heatmap(df, "model_analysis.png") + analyzer.plot_heatmap(df) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py new file mode 100755 index 0000000000000..826112ff54588 --- /dev/null +++ b/scripts/tool_bench.py @@ -0,0 +1,203 @@ +#!/usr/bin/env uv run +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "pytest", +# "numpy", +# "pandas", +# "matplotlib", +# "seaborn", +# "requests", +# "wget", +# ] +# /// +''' + cmake --build build -j && ( \ + export LLAMA_CACHE=$HOME/Library/Caches/llama.cpp ; + export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server ; + export ARGS=( --n=10 --temps=0,0.5,0.75,1,1.5,2,5, --append=all.jsonl ) ; + ./scripts/tool_bench.py ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 1.5B Q4_K_M" --hf bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF --ollama deepseek-r1:1.5b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M" --hf bartowski/Qwen2.5-7B-Instruct-GGUF ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --hf bartowski/Llama-3.2-1B-Instruct-GGUF --ollama llama3.2:1b-instruct-q4_K_M ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --hf bartowski/Llama-3.2-3B-Instruct-GGUF --ollama llama3.1:3b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M" --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF --ollama llama3.1:8b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.3 Instruct 70B Q4_K_M" --hf bartowski/Llama-3.3-70B-Instruct-GGUF ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Mistral Nemo 2407 Q4_K_M" --hf bartowski/Mistral-Nemo-Instruct-2407-GGUF --ollama mistral-nemo:12b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Functionary Small v3.2 Q4_K_M" --hf bartowski/functionary-small-v3.2-GGUF ; + ) + +''' +import argparse +from contextlib import contextmanager +from statistics import mean, median +import pytest + +# ensure grandparent path is in sys.path +from pathlib import Path +import sys + +sys.path.insert(0, Path(__file__).parent.parent.as_posix()) +print(sys.path) +from examples.server.tests.unit.test_tool_call import * + + +@contextmanager +def scoped_server(sp: ServerProcess): + global server + server = sp + + import atexit + def stop(): + global server + nonlocal sp + if sp is not None: + sp.stop() + sp = None # type: ignore + server = None # type: ignore + atexit.register(stop) + + yield sp + + stop() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Run tests for the chat server.') + parser.add_argument('--model', type=str, help='Name of the model to test (implementation agnostic)', required=True) + parser.add_argument('--hf', type=str, help='GGUF huggingface model repo id (+ optional quant) to test w/ llama-server') + parser.add_argument('--chat-template', type=str, help='Chat template override for llama-server') + parser.add_argument('--ollama', type=str, help='Ollama model tag to test') + parser.add_argument('--n', type=int, help='Number of times to run each test', default=30) + parser.add_argument('--temps', type=str, help='Comma-separated list of temperatures') + parser.add_argument('--top-p', type=float, help='top_p') + parser.add_argument('--top-k', type=int, help='top_k') + parser.add_argument('--seed', type=int, help='Random seed') + parser.add_argument('--port', type=int, help='llama-server port') + parser.add_argument('--output', type=str, help='Output JSON file') + parser.add_argument('--append', type=str, help='Output JSON file') + + + args = parser.parse_args() + + # Check only one of output and append + assert (args.output is None) != (args.append is None), "Exactly one of --output and --append must be specified" + + # chat_template = args.chat_template + n = args.n + + n_predict = 512 + + with open(args.output or args.append, 'w' if args.output else 'a') as output_file: + + def run(server: ServerProcess, *, implementation: str, model_id: str, temp: float | None = None, output_kwargs={}, request_kwargs={}): + request_kwargs = {**request_kwargs} + if temp is not None: + request_kwargs['temperature'] = temp + if args.top_p is not None: + request_kwargs['top_p'] = args.top_p + if args.top_k is not None: + request_kwargs['top_k'] = args.top_k + if args.seed is not None: + request_kwargs['seed'] = args.seed + + request_kwargs['cache_prompt'] = False + + tests = { + "hello world": lambda server: do_test_hello_world(server, **request_kwargs), + "weather": lambda server: do_test_weather(server, **request_kwargs), + "calc result": lambda server: do_test_calc_result(server, None, 512, **request_kwargs), + } + for test_name, test in tests.items(): + success_count = 0 + failure_count = 0 + failures = [] + success_times = [] + failure_times = [] + print(f"Running {test_name} ({implementation}, {args.model}): ", file=sys.stderr, flush=True) + for i in range(n): + start_time = time.time() + def elapsed(): + return time.time() - start_time + try: + test(server) + success_times.append(elapsed()) + success_count += 1 + print('.', end='', file=sys.stderr, flush=True) + except Exception as e: + print('!', end='', file=sys.stderr, flush=True) + if failure_count == 0: + print(f" ({e}) ", end='', file=sys.stderr, flush=True) + failure_count += 1 + failure_times.append(elapsed()) + failures.append(str(e)) + print('\n', file=sys.stderr, flush=True) + output_file.write(json.dumps({**output_kwargs, **dict( + model=args.model, + implementation=implementation, + model_id=model_id, + test=test_name, + temp=temp, + top_p=args.top_p, + top_k=args.top_k, + success_ratio=float(success_count) / n, + avg_time=mean(success_times + failure_times), + median_time=median(success_times + failure_times), + success_count=success_count, + success_times=success_times, + failure_count=failure_count, + failure_times=failure_times, + failures=list(set(failures)), + )}) + '\n') + output_file.flush() + + temps = [float(temp) if temp != "" else None for temp in args.temps.split(',')] if args.temps is not None else [None] + for temp in temps: + if args.hf is not None: + server = ServerProcess() + server.n_slots = 1 + server.jinja = True + server.n_predict = 512 # High because of DeepSeek R1 + server.model_hf_repo = args.hf + server.model_hf_file = None + server.chat_template = args.chat_template + if args.port is not None: + server.server_port = args.port + # server.debug = True + + with scoped_server(server): + server.start(timeout_seconds=TIMEOUT_SERVER_START) + for ignore_chat_grammar in [False, True]: + run( + server, + implementation="llama-server" + (" (no grammar)" if ignore_chat_grammar else ""), + model_id=args.hf, + temp=temp, + output_kwargs=dict( + chat_template=args.chat_template, + ), + request_kwargs=dict( + ignore_chat_grammar=ignore_chat_grammar, + ), + ) + + if args.ollama is not None: + server = ServerProcess() + server.server_port = 11434 + server.server_host = "localhost" + subprocess.check_call(["ollama", "pull", args.ollama]) + + with scoped_server(server): + run( + server, + implementation="ollama", + model_id=args.ollama, + temp=temp, + output_kwargs=dict( + chat_template=None, + ), + request_kwargs=dict( + model=args.ollama, + ), + ) From 84d0ff508b1a6a92fdb78cdcc737c770e2a069e7 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 15:26:24 +0000 Subject: [PATCH 03/43] support RETRIES=N in server test utils --- examples/server/tests/utils.py | 45 ++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index a82504235ff54..2e850b216b1b3 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -181,7 +181,7 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(["--chat-template-file", self.chat_template_file]) args = [str(arg) for arg in [server_path, *server_args]] - print(f"bench: starting server with: {' '.join(args)}") + print(f"tests: starting server with: {' '.join(args)}") flags = 0 if "nt" == os.name: @@ -233,23 +233,32 @@ def make_request( timeout: float | None = None, ) -> ServerResponse: url = f"http://{self.server_host}:{self.server_port}{path}" - parse_body = False - if method == "GET": - response = requests.get(url, headers=headers, timeout=timeout) - parse_body = True - elif method == "POST": - response = requests.post(url, headers=headers, json=data, timeout=timeout) - parse_body = True - elif method == "OPTIONS": - response = requests.options(url, headers=headers, timeout=timeout) - else: - raise ValueError(f"Unimplemented method: {method}") - result = ServerResponse() - result.headers = dict(response.headers) - result.status_code = response.status_code - result.body = response.json() if parse_body else None - print("Response from server", json.dumps(result.body, indent=2)) - return result + retries = int(os.environ.get('RETRIES', '1')) + for remaining_attempts in range(retries, 0, -1): + # print(f"#\ncurl {url} -d '{json.dumps(data, indent=2)}'\n") + parse_body = False + if method == "GET": + response = requests.get(url, headers=headers, timeout=timeout) + parse_body = True + elif method == "POST": + response = requests.post(url, headers=headers, json=data, timeout=timeout) + parse_body = True + elif method == "OPTIONS": + response = requests.options(url, headers=headers, timeout=timeout) + else: + raise ValueError(f"Unimplemented method: {method}") + + if (response is None or response.status_code != 200) and remaining_attempts > 0: + continue + result = ServerResponse() + result.headers = dict(response.headers) + result.status_code = response.status_code + result.body = response.json() if parse_body else None + # print("Response from server", json.dumps(result.body, indent=2)) + return result + + raise RuntimeError(f"Failed to make request to {url} after {retries} attempts") + def make_stream_request( self, From 85283781039cfadd45a44536d8b7ef48cab12c7b Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 15:26:46 +0000 Subject: [PATCH 04/43] server: detect premature llama-server death in e2e tests --- examples/server/tests/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 2e850b216b1b3..bb97178dcb4cb 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -212,6 +212,10 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: return # server is ready except Exception as e: pass + # Check if process died + if self.process.poll() is not None: + raise RuntimeError(f"Server process died with return code {self.process.returncode}") + print(f"Waiting for server to start...") time.sleep(0.5) raise TimeoutError(f"Server did not start within {timeout_seconds} seconds") From d617bb535ca140213cbb48ac3c4ed71774a00447 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 15:30:02 +0000 Subject: [PATCH 05/43] tool-call: improve Qwen 2.5 & Llama 3.x w/ more triggers, constrained python code strings and revamped parsers --- common/chat.cpp | 441 ++++++++++++++----- examples/server/tests/unit/test_tool_call.py | 91 ++-- tests/test-chat.cpp | 80 ++++ 3 files changed, 456 insertions(+), 156 deletions(-) mode change 100644 => 100755 examples/server/tests/unit/test_tool_call.py diff --git a/common/chat.cpp b/common/chat.cpp index 9ebe4c5784cbc..bf1269e1ac61f 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -455,7 +455,7 @@ const common_grammar_options grammar_options { // /* .compact_spaces = */ true, }; -static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { +static std::optional parse_json(std::string::const_iterator & it, const std::string::const_iterator & end) { // // https://json.nlohmann.me/features/parsing/sax_interface/ struct json_error_locator : public nlohmann::json_sax { std::size_t position; @@ -492,14 +492,42 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons } std::string json_sub {it, temptative_end}; try { - out = json::parse(json_sub); + auto out = json::parse(json_sub); it = temptative_end; - return true; + return out; } catch (const std::exception &) { - return false; + return std::nullopt; } } +static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) { + auto expected_it = expected.begin(); + auto tmp_it = it; + while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) { + ++tmp_it; + ++expected_it; + } + if (expected_it == expected.end()) { + it = tmp_it; + return true; + } + return false; +} + +static std::optional parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) { + std::smatch match; + if (std::regex_match(it, end, match, expected)) { + it = match.suffix().first; + return match; + } + return std::nullopt; +} + +static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) { + while (it != end && std::isspace(*it)) { + ++it; + } +} /** * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. @@ -509,7 +537,8 @@ static common_chat_msg parse_json_tool_calls( const std::string& input, const std::optional & trigger_opt, const std::regex & function_regex, - const std::regex & close_regex) { + const std::regex & close_regex, + bool allow_raw_python = false) { std::smatch match; common_chat_msg result; @@ -539,15 +568,19 @@ static common_chat_msg parse_json_tool_calls( result.content += std::string(it, rit->prefix().second); it = rit->suffix().first; - json arguments; - if (!parse_json(it, end, arguments)) { + if (auto arguments = parse_json(it, end)) { + if (!std::regex_search(it, end, match, close_regex)) { + throw std::runtime_error("Malformed input, missing closing pattern: " + input); + } + it = match.suffix().first; + result.tool_calls.push_back({name, arguments->is_string() ? arguments->get() : arguments->dump(), /* id= */ ""}); + } else { + if (allow_raw_python && name == "python") { + result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""}); + break; + } throw std::runtime_error("Failed to parse json tool call arguments: " + input); } - if (!std::regex_search(it, end, match, close_regex)) { - throw std::runtime_error("Malformed input, missing closing pattern: " + input); - } - it = match.suffix().first; - result.tool_calls.push_back({name, arguments.is_string() ? arguments.get() : arguments.dump(), /* id= */ ""}); } if (!result.tool_calls.empty()) { @@ -559,29 +592,29 @@ static common_chat_msg parse_json_tool_calls( return result; } +static common_chat_tool_call process_tool_call(const json & tool_call) { + const auto & arguments = tool_call.at("arguments"); + return { + /* .name = */ tool_call.at("name"), + /* .arguments = */ arguments.is_string() ? arguments.get() : arguments.dump(), + /* .id = */ tool_call.contains("id") ? tool_call.at("id") : "", + }; +} static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) { auto content_end = input.find(prefix); size_t tc_start = std::string::npos; common_chat_msg result; result.role = "assistant"; - const auto process_tool_calls = [&](const json & tool_calls) { - for (const auto & tool_call : tool_calls) { - const auto & arguments = tool_call.at("arguments"); - result.tool_calls.push_back({ - tool_call.at("name"), - arguments.is_string() ? arguments.get() : arguments.dump(), - tool_call.contains("id") ? tool_call.at("id") : "", - }); - } - }; if (content_end == std::string::npos) { result.content = input; } else { tc_start = content_end + prefix.size() - rstrip_prefix; result.content = input.substr(0, content_end); auto tool_calls = json::parse(input.substr(tc_start)); - process_tool_calls(tool_calls); + for (const auto & tool_call : tool_calls) { + result.tool_calls.emplace_back(process_tool_call(tool_call)); + } } return result; } @@ -840,9 +873,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ return data; } static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) { - static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)"); - static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>"); - static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>"); + static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)"); + static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>"); + static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>"); std::smatch match; @@ -898,6 +931,78 @@ static void expect_tool_parameters(const std::string & name, const json & parame } } +/* + Adds a GBNF rule that matches a Python code string when escaped inside a JSON string (without surrounding double quotes) + + If this sounds meta, well, it is: + - Most tool call style pass tool arguments as JSON objects, e.g. {"arg1": , ...} + - When the tool is `python` and the argument is `code`, the value is JSON-escaped Python code. + Some models (Llama 3.x series) tend to close the code string itself when the nested code + tries to open a double quoted string. So when the model wants to write the code `print("Hey")`, + it only goes so far as `{"code": "print("` and the general JSON constraints of the python tool arguments call it a day. + - This rule (when wrapped in double quotes) can be used instead of a JSON string + to match a structured soup of Python tokens that has the following properties: + - All open brackets / braces / parentheses are closed + - All strings (single or double quoted) are closed + - All double quotes are escaped + + This should prevent an entire class of invalid Python programs to be generated by the model, + but any bugs / omissions may also disallow some valid Python syntax. Current limitations: + + - No f strings + - No multiline strings + + Examples: + + - OK + {"code": "print('Hey')"} + {"code": "print(\"Hey\")"} + {"code": "# in \" comments...\nx = \"Hey\""} + - NOT OK + {"code": "print("} + {"code": "print(\""} + {"code": "print('"} +*/ +static std::string add_escaped_python_code_soup_rule(const common_grammar_builder & builder) { + return builder.add_rule("json-escaped-code-soup", + // Allow comments w/ (escaped) newline + R"( ( [#] ( ( [^\\\t\r\n\uff00-\uffef] | [\\] [^n\n] )* [\\] [n] )? | )" + // Allow (escaped) double quoted strings and their nested (double) escapes + R"( [\\] ["] ( [^"\\\t\r\n\uff00-\uffef] | [\\] [\\] ["] | [\\] [trnu] )* [\\] ["] | )" + // Allow single quoted strings and their nested (double) escapes + R"( ['] ( [^"'\\\t\r\n\uff00-\uffef] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] )* ['] | )" + // Soup wrapped in parentheses, curly braces or square brackets + R"( [(] json-escaped-code-soup [)] | )" + R"( [{] json-escaped-code-soup [}] | )" + R"( "[" json-escaped-code-soup "]" | )" + // Allow escapes + R"( [\\] [\\trnu] | )" + // Allow other characters, minus code blocks for halfwidth & fullwidth forms (U+FF00 - U+FFEF) + // (special tokens can use these to avoid prompt injections, as they will have to be unicode-escaped w/ \uXXXX + // and won't be able to interfere w/ parsing) + R"( [^#{}"'\[\]\\()\t\r\n\uff00-\uffef]+ )" + // After any repetition of the previous, allow trailing comment w/o newline + R"( )* ( [#] ( [^\\] | [\\] [^n] )* )? )" + ); +} + +static std::string add_python_code_arguments_rule(const std::string & name, const common_grammar_builder & builder) { + return builder.add_rule( + name, + "\"{\" space \"\\\"code\\\": \\\"\" " + + add_escaped_python_code_soup_rule(builder) + + " \"\\\"\" space \"}\" space "); +} + +static std::string add_json_tool_args_rule(const std::string & name, const json & parameters, const common_grammar_builder & builder) { + if (name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { + return add_python_code_arguments_rule(name + "-code-args", builder); + } else { + return builder.add_schema(name + "-args", parameters); + } +} + + static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) { auto builtin_tools = json::array(); common_chat_params data; @@ -919,7 +1024,11 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com std::vector kvs; for (const auto & [key, value] : parameters.at("properties").items()) { - kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT + if (name == "python" && key == "code") { + kvs.push_back("\"" + key + "=\\\"\" " + add_escaped_python_code_soup_rule(builder) + " \"\\\"\""); // NOLINT + } else { + kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT + } } tool_rules.push_back( @@ -947,7 +1056,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com "\"{\" space " "( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? " "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + - builder.add_schema(name + "-args", parameters) + + add_json_tool_args_rule(name, parameters, builder) + " \"}\"")); data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); }); @@ -974,33 +1083,33 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com } static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) { // TODO: tighten & simplify the parser, don't accept leading text context. - static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": "); + static std::regex function_regex( + "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*|\\s*)\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\": "); static std::regex close_regex("\\}"); - static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)"); + static std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)"); if (with_builtin_tools) { std::smatch match; if (std::regex_match(input, match, builtin_call_regex)) { - auto name = match[1].str(); - auto raw_args = match[2].str(); - - // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing. - auto it_eq = raw_args.find('='); - auto arg_name = raw_args.substr(0, it_eq); - auto arg_value_str = raw_args.substr(it_eq + 1); - auto arg_value = json::parse(arg_value_str); - - common_chat_msg msg; - msg.role = "assistant"; - msg.content = match.prefix().str(); - msg.tool_calls.push_back({ - /* .name = */ name, - /* .arguments = */ (json { - {arg_name, arg_value}, - }).dump(), - /* .id = */ "", - }); - return msg; + try { + auto name = match[1].str(); + auto arg_name = match[2].str(); + auto arg_value_str = match[3].str(); + auto arg_value = json::parse(arg_value_str); + + common_chat_msg msg; + msg.role = "assistant"; + msg.tool_calls.push_back({ + /* .name = */ name, + /* .arguments = */ (json { + {arg_name, arg_value}, + }).dump(), + /* .id = */ "", + }); + return msg; + } catch (const std::exception & e) { + LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str()); + } } } return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); @@ -1017,10 +1126,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ std::string name = function.at("name"); auto parameters = function.at("parameters"); builder.resolve_refs(parameters); - auto args_rule = builder.add_schema(name + "-args", parameters); tool_rules.push_back(builder.add_rule(name + "-call", "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n" - "```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); + "```json\\n\" " + add_json_tool_args_rule(name, parameters, builder) + " " + "\"```<|tool▁call▁end|>\"")); }); // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, // so we accept common variants (then it's all constrained) @@ -1158,11 +1267,16 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ auto parameters = function.at("parameters"); builder.resolve_refs(parameters); auto args_rule = builder.add_schema(name + "-args", parameters); - first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); + first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); data.grammar_triggers.push_back({name, /* .at_start = */ true}); + data.grammar_triggers.push_back({"assistant<|end_header_id|>\n" + name, /* .at_start = */ true}); data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false}); + data.grammar_triggers.push_back({">>>assistant<|end_header_id|>\n" + name, /* .at_start = */ false}); }); + data.preserved_tokens = { + "<|end_header_id|>", + }; auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; if (inputs.parallel_tool_calls) { auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; @@ -1176,29 +1290,15 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ return data; } -static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) { - auto expected_it = expected.begin(); - auto tmp_it = it; - while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) { - ++tmp_it; - ++expected_it; - } - if (expected_it == expected.end()) { - it = tmp_it; - return true; - } - return false; -} - static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) { - static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); + static std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)"); static std::regex close_regex(R"($|(?=>>>))"); std::string content; auto it = input.begin(); const auto end = input.end(); - if (consume(it, end, "all\n")) { + if (parse_literal(it, end, "all\n")) { std::smatch match; if (std::regex_search(it, end, match, function_regex)) { auto fun_it = match.prefix().second; @@ -1213,7 +1313,7 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in } // TODO: tighten & simplify. try { - auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex); + auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true); res.content = content + res.content; return res; } catch (const std::exception & e) { @@ -1306,70 +1406,185 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; + std::vector tool_call_alts; foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); std::string name = function.at("name"); auto parameters = function.at("parameters"); builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_schema(name + "-call", { - {"type", "object"}, - {"properties", json { - {"name", json {{"const", name}}}, - {"arguments", parameters}, - }}, - {"required", json::array({"name", "arguments"})}, - })); + if (name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { + tool_rules.push_back(builder.add_rule(name + "-call", + "\"{\" space " + "\"\\\"name\\\":\" space \"\\\"" + name + "\\\"\" space \",\" space " + "\"\\\"arguments\\\":\" space " + add_python_code_arguments_rule(name + "-code-arguments", builder) + " " + "\"}\" space ")); + } else { + tool_rules.push_back(builder.add_schema(name + "-call", { + {"type", "object"}, + {"properties", json { + {"name", json {{"const", name}}}, + {"arguments", parameters}, + }}, + {"required", json::array({"name", "arguments"})}, + })); + } + tool_call_alts.push_back(builder.add_rule( + name + "-function-tag", + "\"\" space " + + builder.add_schema(name + "-args", parameters) + " " + "\"\" space")); }); - auto tool_call = "\"\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"\" space"; + auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space"); + std::vector alt_tags { + any_tool_call, + "\"\" space " + any_tool_call + " \"\"", + // The rest is just to accommodate common "good bad" outputs. + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + }; + auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space"); + tool_call_alts.push_back(wrappable_tool_call); + tool_call_alts.push_back( + "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space "); + auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | ")); builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); data.grammar_triggers.push_back({"", /* .at_start = */ false}); - data.preserved_tokens = { "" }; + data.grammar_triggers.push_back({"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```\n{\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```\n {\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```\n{\n \"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```json\n{\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```json\n {\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```json\n{\n \"name\": \"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```xml\n{\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```xml\n {\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```xml\n{\n \"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```xml\n\n {\"name\":", /* .at_start = */ true}); + data.preserved_tokens = { + "", + "", + "", + "", + "", + "", + "```", + "```json", + "```xml", + }; }, grammar_options); data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; return data; } -static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) { - try { - std::regex start_pattern(R"([\n\s]*)"); - std::regex middle_pattern(R"([\n\s]*[\n\s]*)"); - std::regex end_pattern(R"([\n\s]*[\n\s]*$)"); +static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input) { + const static std::regex open_regex( + "(?:" + "(```(?:xml|json)?\\n)?" // match 1 (block_start) + "(" // match 2 (open_tag) + "|" + "|" + "|" + "|" + "|" + "|" + "|" + ")?" + "(\\s*\\{\\s*\"name\":[\\s\\S]*)" // match 3 (named tool call + rest) + ")" + "|" + "(?:]+)>" // match 4 (function name) + "|)" // match 5 (function name again) + "([\\s\\S]*)" // match 6 (function arguments + rest)})" + ); + try { + common_chat_msg msg; msg.role = "assistant"; - auto end = input.end(); - std::sregex_iterator rend; - std::sregex_iterator rit(input.begin(), end, start_pattern); - if (rit == rend) { - msg.content = input; - return msg; - } - - msg.content = rit->prefix(); + std::string::const_iterator it = input.begin(); + const std::string::const_iterator end = input.end(); + std::smatch match; - auto it = rit->suffix().first; while (it != end) { - json call; - if (!parse_json(it, end, call)) { - throw std::runtime_error("Failed to parse json tool call"); - } - const auto & arguments = call.at("arguments"); - msg.tool_calls.push_back({ - call.at("name"), - arguments.dump(), - // arguments.is_string() ? arguments.get() : arguments.dump(), - /* id= */ "", - }); - rit = {it, end, middle_pattern}; - if (rit != rend) { - it = rit->suffix().first; - } else { - rit = {it, end, end_pattern}; - if (rit == rend) { - throw std::runtime_error("Malformed input, missing "); + if (std::regex_search(it, end, match, open_regex)) { + // Add content before the match + msg.content += std::string(it, match[0].first); + + auto block_start = match[1].str(); + std::string block_end = block_start.empty() ? "" : "```"; + + auto open_tag = match[2].str(); + std::string close_tag; + + if (match[3].matched) { + close_tag = open_tag.empty() ? "" : "contains("name") && tool_call->contains("arguments")) { + + msg.tool_calls.emplace_back(process_tool_call(*tool_call)); + it = json_it; // Move iterator past parsed JSON + + // Handle close tags + consume_spaces(it, end); + if (!close_tag.empty() && !parse_literal(it, end, close_tag)) { + throw std::runtime_error("Failed to parse closing tag"); + } + consume_spaces(it, end); + if (!block_end.empty() && !parse_literal(it, end, block_end)) { + throw std::runtime_error("Failed to parse block end"); + } + } else { + // Not a valid tool call, treat as content + msg.content += std::string(match[0].first, match[0].second); + it = match[0].second; + } + } else { + auto function_name = match[4].str(); + if (function_name.empty()) { + function_name = match[5].str(); + } + GGML_ASSERT(!function_name.empty()); + + close_tag = ""; + // Start parsing from after the opening tags + auto json_it = match[6].first; + if (auto arguments = parse_json(json_it, end)) { + msg.tool_calls.emplace_back(process_tool_call({ + {"name", function_name}, + {"arguments", *arguments}, + })); + it = json_it; // Move iterator past parsed JSON + + // Handle close tags + consume_spaces(it, end); + if (!close_tag.empty() && !parse_literal(it, end, close_tag)) { + throw std::runtime_error("Failed to parse closing tag"); + } + consume_spaces(it, end); + if (!block_end.empty() && !parse_literal(it, end, block_end)) { + throw std::runtime_error("Failed to parse block end"); + } + } else { + // Not a valid tool call, treat as content + msg.content += std::string(match[0].first, match[0].second); + it = match[0].second; + } } + } else { + // Add remaining content + msg.content += std::string(it, end); break; } } diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py old mode 100644 new mode 100755 index 4467de2a2e755..e02cb83876890 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -155,25 +155,29 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), - # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - # (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), - # (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), @@ -189,10 +193,10 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), - # (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), - # TODO: fix these - # (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - # (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server @@ -297,6 +301,9 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), @@ -513,46 +520,48 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] @pytest.mark.slow -@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ - (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"), +@pytest.mark.parametrize("hf_repo,template_override", [ + ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), - (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), - ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - (None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), - (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), - ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), - (None, "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), - (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), - (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), - # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. - (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"), ]) -def test_hello_world(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): +def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server + n_predict = 512 # High because of DeepSeek R1 server.n_slots = 1 server.jinja = True server.n_ctx = 8192 - server.n_predict = 512 # High because of DeepSeek R1 server.n_predict = n_predict server.model_hf_repo = hf_repo server.model_hf_file = None @@ -564,10 +573,10 @@ def test_hello_world(expected_arguments_override: str | None, hf_repo: str, temp server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - do_test_hello_world(server, expected_arguments_override, max_tokens=n_predict) + do_test_hello_world(server, max_tokens=n_predict) -def do_test_hello_world(server: ServerProcess, expected_arguments_override, **kwargs): +def do_test_hello_world(server: ServerProcess, **kwargs): res = server.make_request("POST", "/v1/chat/completions", data={ "messages": [ {"role": "system", "content": "You are a tool-calling agent."}, @@ -583,12 +592,8 @@ def do_test_hello_world(server: ServerProcess, expected_arguments_override, **kw tool_call = tool_calls[0] assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] - actual_arguments = tool_call["function"]["arguments"] - if expected_arguments_override is not None: - assert actual_arguments == expected_arguments_override - else: - actual_arguments = json.loads(actual_arguments) - assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}" - code = actual_arguments["code"] - assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" - assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}' + actual_arguments = json.loads(tool_call["function"]["arguments"]) + assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}" + code = actual_arguments["code"] + assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" + assert re.match(r'''((#.*)?\n)*print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}' diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 6435923054859..2b1226292177b 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -640,6 +640,86 @@ static void test_template_output_parsers() { inputs_tools) .format); + // Test parsing + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "{\"arg1\": 1}", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + "{\"arg1\": 1}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "```xml\n" + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "\n" + "```", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "```xml\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "```\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "```\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "```json\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\n" + " \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n" + " }\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "\n" From ae8747ed3d21231c2e442a9fd9f61ccccfa32ec3 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 15:08:17 +0000 Subject: [PATCH 06/43] sampler: turn lazy grammar trigger words to regexes --- common/sampling.cpp | 32 +++++++++++++++++++++++------ examples/server/server.cpp | 3 ++- include/llama.h | 8 ++++---- src/llama-grammar.cpp | 41 +++++++++++++++++++------------------- src/llama-grammar.h | 10 +++++++--- src/llama-sampling.cpp | 20 ++++++++++--------- 6 files changed, 70 insertions(+), 44 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 37a0d9c85ae30..9eea0f749f3be 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -3,6 +3,7 @@ #include "common.h" #include +#include #include // the ring buffer works similarly to std::deque, but with a fixed capacity @@ -144,6 +145,11 @@ std::string common_params_sampling::print() const { return std::string(result); } +inline std::string regex_escape(const std::string & literal) { + static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); + return std::regex_replace(literal, special_chars, "\\$0"); +} + struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { const llama_vocab * vocab = llama_model_get_vocab(model); @@ -159,15 +165,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE } else { - std::vector trigger_words; - trigger_words.reserve(params.grammar_trigger_words.size()); - for (const auto & str : params.grammar_trigger_words) { - trigger_words.push_back(str.word.c_str()); + std::vector escaped_triggers_at_start; + std::vector escaped_triggers_anywhere; + for (const auto & trigger : params.grammar_trigger_words) { + (trigger.at_start ? escaped_triggers_at_start : escaped_triggers_anywhere) + .push_back(regex_escape(trigger.word)); + } + + std::vector trigger_regexes; + if (!escaped_triggers_at_start.empty()) { + trigger_regexes.push_back("^(" + string_join(escaped_triggers_at_start, "|") + ")[\\s\\S]*"); + } + if (!escaped_triggers_anywhere.empty()) { + trigger_regexes.push_back("^[\\s\\S]*?(" + string_join(escaped_triggers_anywhere, "|") + ")[\\s\\S]*"); + } + + std::vector trigger_regexes_c; + trigger_regexes_c.reserve(trigger_regexes.size()); + for (const auto & regex : trigger_regexes) { + trigger_regexes_c.push_back(regex.c_str()); } grmr = params.grammar_lazy ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", - trigger_words.data(), trigger_words.size(), + trigger_regexes_c.data(), trigger_regexes_c.size(), params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()) : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); } @@ -202,7 +223,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co for (const auto & str : params.dry_sequence_breakers) { c_breakers.push_back(str.c_str()); } - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 809bfe0e36cd7..bdea828aeba3f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -170,6 +170,7 @@ struct slot_params { {"n_probs", sampling.n_probs}, {"min_keep", sampling.min_keep}, {"grammar", sampling.grammar}, + {"grammar_lazy", sampling.grammar_lazy}, {"grammar_trigger_words", grammar_trigger_words}, {"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, {"preserved_tokens", sampling.preserved_tokens}, @@ -2045,7 +2046,7 @@ struct server_context { if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { // Might be better to reject the request with a 400 ? - SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.params.n_predict, slot.n_predict); + SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d\n", slot.params.n_predict, slot.n_predict); slot.params.n_predict = slot.n_predict; } diff --git a/include/llama.h b/include/llama.h index b0726cbe63ea6..b26f2b05e91f8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1204,14 +1204,14 @@ extern "C" { const char * grammar_root); /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 - /// @param trigger_words A list of words that will trigger the grammar sampler. This may be updated to a loose regex syntax (w/ ^) in a near future. - /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. + /// @param trigger_regexes A list of (full-string) regexes that will trigger the grammar sampler. Grammar sampler will be fed content starting from the first match group. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included. LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_regexes, + size_t num_trigger_regexes, const llama_token * trigger_tokens, size_t num_trigger_tokens); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 46e27a96ed728..06c3d188d4c2e 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -969,7 +969,7 @@ struct llama_grammar * llama_grammar_init_impl( /* .awaiting_trigger = */ false, /* .trigger_buffer = */ "", /* .trigger_tokens = */ {}, - /* .trigger_words = */ {}, + /* .trigger_regexes = */ {}, }; } @@ -978,19 +978,15 @@ struct llama_grammar * llama_grammar_init_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_regexes, + size_t num_trigger_regexes, const llama_token * trigger_tokens, size_t num_trigger_tokens) { llama_grammar_parser parser; // if there is a grammar, parse it - if (!parser.parse(grammar_str)) { - return nullptr; - } - - // will be empty (default) if there are parse errors - if (parser.rules.empty()) { + // rules will be empty (default) if there are parse errors + if (!parser.parse(grammar_str) || parser.rules.empty()) { fprintf(stderr, "%s: failed to parse grammar\n", __func__); return nullptr; } @@ -1054,14 +1050,14 @@ struct llama_grammar * llama_grammar_init_impl( } while (true); std::vector vec_trigger_tokens; - std::vector vec_trigger_words; + std::vector> vec_trigger_regexes; for (size_t i = 0; i < num_trigger_tokens; i++) { GGML_ASSERT(trigger_tokens != nullptr); vec_trigger_tokens.push_back(trigger_tokens[i]); } - for (size_t i = 0; i < num_trigger_words; i++) { - GGML_ASSERT(trigger_words != nullptr); - vec_trigger_words.push_back(trigger_words[i]); + for (size_t i = 0; i < num_trigger_regexes; i++) { + GGML_ASSERT(trigger_regexes != nullptr); + vec_trigger_regexes.emplace_back(trigger_regexes[i], trigger_regexes[i]); } // Important: vec_rules has to be moved here, not copied, because stacks contains @@ -1076,7 +1072,7 @@ struct llama_grammar * llama_grammar_init_impl( /* .awaiting_trigger = */ lazy, /* .trigger_buffer = */ "", std::move(vec_trigger_tokens), - std::move(vec_trigger_words), + std::move(vec_trigger_regexes), }; } @@ -1098,7 +1094,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.awaiting_trigger, grammar.trigger_buffer, grammar.trigger_tokens, - grammar.trigger_words, + grammar.trigger_regexes, }; // redirect elements in stacks to point to new rules @@ -1173,19 +1169,22 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str()); return; } else { - // TODO: consider a smarter incremental substring search algorithm (store last position to search from). grammar.trigger_buffer += piece; - for (const auto & word : grammar.trigger_words) { - auto pos = grammar.trigger_buffer.find(word); - if (pos != std::string::npos) { + + std::smatch match; + for (const auto & [_, regex] : grammar.trigger_regexes) { + if (std::regex_match(grammar.trigger_buffer, match, regex)) { grammar.awaiting_trigger = false; - auto constrained_str = grammar.trigger_buffer.substr(pos); + // get from the first match to the end of the string + auto constrained_str = grammar.trigger_buffer.substr(match.position(1)); + // std::string constrained_str(match[1].first, grammar.trigger_buffer.end()); grammar.trigger_buffer.clear(); llama_grammar_accept_str(grammar, constrained_str); - LLAMA_LOG_DEBUG("Grammar triggered on word `%s`", word.c_str()); + LLAMA_LOG_DEBUG("Grammar triggered on regex: %s", constrained_str.c_str()); return; } } + LLAMA_LOG_DEBUG("Grammar still awaiting trigger after token %d (`%s`)\n", token, piece.c_str()); return; } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index b143d834cfabe..8d9b1a81dfd1c 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include #include #include @@ -122,7 +123,10 @@ struct llama_grammar { bool awaiting_trigger = false; // Initialized to true for lazy grammars only std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). - std::vector trigger_words; + std::vector> + trigger_regexes; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated + // string, and the grammar will be given the string from the first match group onwards. + }; // @@ -141,8 +145,8 @@ struct llama_grammar * llama_grammar_init_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_regexes, + size_t num_trigger_regexes, const llama_token * trigger_tokens, size_t num_trigger_tokens); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index f40bf2db83a80..7d5f9e86584c1 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1446,8 +1446,8 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_regexes, + size_t num_trigger_regexes, const llama_token * trigger_tokens, size_t num_trigger_tokens); @@ -1457,12 +1457,14 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { return; } - std::vector trigger_words; - for (auto & word : ctx->grammar->trigger_words) { - trigger_words.push_back(word.c_str()); + std::vector trigger_regexes_c; + trigger_regexes_c.reserve(ctx->grammar->trigger_regexes.size()); + for (auto & [pattern, _] : ctx->grammar->trigger_regexes) { + trigger_regexes_c.push_back(pattern.c_str()); } + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), - ctx->grammar->lazy, trigger_words.data(), trigger_words.size(), + ctx->grammar->lazy, trigger_regexes_c.data(), trigger_regexes_c.size(), ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); llama_grammar_free_impl(ctx->grammar); @@ -1513,8 +1515,8 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_words, - size_t num_trigger_words, + const char ** trigger_regexes, + size_t num_trigger_regexes, const llama_token * trigger_tokens, size_t num_trigger_tokens) { auto * ctx = new llama_sampler_grammar; @@ -1524,7 +1526,7 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( /* .vocab = */ vocab, /* .grammar_str = */ grammar_str, /* .grammar_root = */ grammar_root, - /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens), + /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_regexes, num_trigger_regexes, trigger_tokens, num_trigger_tokens), }; } else { *ctx = { From 6703c3955b8bfe867f3c2bd96baecfa2f5ef812d Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 15:13:32 +0000 Subject: [PATCH 07/43] tool-call: add `script/tool_bench.py` --- examples/server/tests/unit/test_tool_call.py | 121 ++++++----- scripts/plot_tool_call_tests.py | 190 +++++++++++++++++ scripts/tool_bench.py | 203 +++++++++++++++++++ 3 files changed, 467 insertions(+), 47 deletions(-) create mode 100644 scripts/plot_tool_call_tests.py create mode 100755 scripts/tool_bench.py diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index a91a2f3333ca3..4467de2a2e755 100644 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -1,4 +1,12 @@ +#!/usr/bin/env python import pytest + +# ensure grandparent path is in sys.path +from pathlib import Path +import sys +path = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(path)) + from utils import * server: ServerProcess @@ -66,15 +74,8 @@ def create_server(): } -def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, argument_key: str | None): - global server - n_predict = 512 - # server = ServerPreset.stories15m_moe() - server.jinja = True - server.n_predict = n_predict - server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ +def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs): + res = server.make_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ {"role": "system", "content": "You are a coding assistant."}, @@ -83,16 +84,14 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a "tool_choice": "required", "tools": [tool], "parallel_tool_calls": False, - "temperature": 0.0, - "top_k": 1, - "top_p": 1.0, + **kwargs, }) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] - assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] assert expected_function_name == tool_call["function"]["name"] actual_arguments = tool_call["function"]["arguments"] @@ -108,7 +107,14 @@ def do_test_completion_with_required_tool_tiny(template_name: str, tool: dict, a ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), ]) def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None): - do_test_completion_with_required_tool_tiny(template_name, tool, argument_key) + global server + n_predict = 512 + # server = ServerPreset.stories15m_moe() + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0) @pytest.mark.slow @@ -133,7 +139,14 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"), ]) def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None): - do_test_completion_with_required_tool_tiny(template_name, tool, argument_key) + global server + n_predict = 512 + # server = ServerPreset.stories15m_moe() + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict) @pytest.mark.slow @@ -197,7 +210,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ + res = server.make_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ {"role": "system", "content": "You are a coding assistant."}, @@ -215,7 +228,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] - assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] assert expected_function_name == tool_call["function"]["name"] actual_arguments = tool_call["function"]["arguments"] @@ -225,13 +238,8 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" -def do_test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): - global server - server.jinja = True - server.n_predict = n_predict - server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ +def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs): + res = server.make_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ {"role": "system", "content": "You are a coding assistant."}, @@ -239,9 +247,7 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too ], "tools": tools if tools else None, "tool_choice": tool_choice, - "temperature": 0.0, - "top_k": 1, - "top_p": 1.0, + **kwargs, }, timeout=TIMEOUT_HTTP_REQUEST) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] @@ -254,7 +260,12 @@ def do_test_completion_without_tool_call(template_name: str, n_predict: int, too ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'), ]) def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): - do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice) + global server + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_without_tool_call(server, n_predict, tools, tool_choice) @pytest.mark.slow @@ -270,7 +281,12 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'), ]) def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None): - do_test_completion_without_tool_call(template_name, n_predict, tools, tool_choice) + global server + server.jinja = True + server.n_predict = n_predict + server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.start(timeout_seconds=TIMEOUT_SERVER_START) + do_test_completion_without_tool_call(server, n_predict, tools, tool_choice) @pytest.mark.slow @@ -324,26 +340,30 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ - "max_tokens": n_predict, + do_test_weather(server, max_tokens=n_predict) + + +def do_test_weather(server: ServerProcess, **kwargs): + res = server.make_request("POST", "/v1/chat/completions", data={ "messages": [ {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."}, {"role": "user", "content": "What is the weather in Istanbul?"}, ], "tools": [WEATHER_TOOL], + **kwargs, }, timeout=TIMEOUT_HTTP_REQUEST) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] - assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"] actual_arguments = json.loads(tool_call["function"]["arguments"]) assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" location = actual_arguments["location"] assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}" - assert re.match('^Istanbul(, (TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' + assert re.match('^Istanbul(, ?(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' @pytest.mark.slow @@ -379,10 +399,14 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ + do_test_calc_result(server, result_override, n_predict) + + +def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs): + res = server.make_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ - {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things, and provide very concise answers. Do not explain your reasoning to the user. Provide any numerical values back to the user with at most two decimals."}, + {"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."}, {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"}, { "role": "assistant", @@ -423,7 +447,8 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, } } } - ] + ], + **kwargs, }, timeout=TIMEOUT_HTTP_REQUEST) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] @@ -434,7 +459,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, if result_override is not None: assert re.match(result_override, content), f'Expected {result_override}, got {content}' else: - assert re.match('^[\\s\\S]*?The (y[ -])?coordinate [\\s\\S]*?is (approximately )?0\\.56\\b|^0\\.56$', content), \ + assert re.match('^[\\s\\S]*?((That\'s|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)', content), \ f'Expected something like "The y coordinate is 0.56.", got {content}' @@ -464,7 +489,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ + res = server.make_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ {"role": "user", "content": "What's the sum of 102 and 7?"}, @@ -476,7 +501,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] content = choice["message"].get("content") if expect_content is None: - assert content is None, f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' else: assert re.match(expect_content, content), f'Expected {expect_content}, got {content}' @@ -528,6 +553,7 @@ def test_hello_world(expected_arguments_override: str | None, hf_repo: str, temp server.jinja = True server.n_ctx = 8192 server.n_predict = 512 # High because of DeepSeek R1 + server.n_predict = n_predict server.model_hf_repo = hf_repo server.model_hf_file = None if isinstance(template_override, tuple): @@ -537,24 +563,25 @@ def test_hello_world(expected_arguments_override: str | None, hf_repo: str, temp elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/chat/completions", data={ - "max_tokens": 256, + + do_test_hello_world(server, expected_arguments_override, max_tokens=n_predict) + + +def do_test_hello_world(server: ServerProcess, expected_arguments_override, **kwargs): + res = server.make_request("POST", "/v1/chat/completions", data={ "messages": [ - {"role": "system", "content": "You are a coding assistant."}, + {"role": "system", "content": "You are a tool-calling agent."}, {"role": "user", "content": "say hello world with python"}, ], "tools": [PYTHON_TOOL], - # Note: without these greedy params, Functionary v3.2 writes `def hello_world():\n print("Hello, World!")\nhello_world()` which is correct but a pain to test. - "temperature": 0.0, - "top_k": 1, - "top_p": 1.0, + **kwargs, }, timeout=TIMEOUT_HTTP_REQUEST) assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = res.body["choices"][0] tool_calls = choice["message"].get("tool_calls") assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] - assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] actual_arguments = tool_call["function"]["arguments"] if expected_arguments_override is not None: diff --git a/scripts/plot_tool_call_tests.py b/scripts/plot_tool_call_tests.py new file mode 100644 index 0000000000000..b54aecce526dd --- /dev/null +++ b/scripts/plot_tool_call_tests.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python3 +""" +Model Performance Analysis and Visualization Tool + +This script analyzes JSON performance data for different model implementations and tests, +creating a heatmap visualization of success ratios. It handles multiple input files and +supports various model configurations. + +Usage: + python script.py input_file1.json [input_file2.json ...] +""" + +import json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +import sys +from typing import Dict, List, Tuple, Set, Any +from pathlib import Path +import logging + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +class ModelAnalyzer: + def __init__(self): + self.lines: List[Dict] = [] + self.data_dict: Dict[Tuple, float] = {} + self.models: List[str] = [] + self.temps: Set[float] = set() + self.tests: Set[str] = set() + self.impls: Set[str] = set() + + self.column_groups = [ + ("llama-server", []), # Tests will be populated dynamically + ("llama-server (no grammar)", []), + ("ollama", []) + ] + + def read_files(self, files: List[str]) -> None: + """Read and parse JSON data from input files.""" + for file in files: + path = Path(file) + if not path.exists(): + logger.error(f"File not found: {file}") + continue + + try: + with path.open() as f: + raw_data = f.read() + logger.info(f"Reading {file} ({len(raw_data)} bytes)") + + for line_num, line in enumerate(raw_data.split('\n'), 1): + line = line.strip() + if not line: + continue + try: + record = json.loads(line) + self.lines.append(record) + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON at {file}:{line_num} - {e}") + except Exception as e: + logger.error(f"Error processing {file}: {e}") + + def process_data(self) -> None: + """Process the loaded data and organize it for visualization.""" + for rec in self.lines: + try: + model = rec["model"] + temp = rec["temp"] + impl = rec["implementation"] + test = rec["test"] + success = rec["success_ratio"] + + self.data_dict[(model, temp, impl, test)] = success + + if model not in self.models: + self.models.append(model) + self.temps.add(temp) + self.tests.add(test) + self.impls.add(impl) + + except KeyError as e: + logger.warning(f"Missing required field in record: {e}") + + # Sort the collected values + self.temps = sorted(list(self.temps), key=lambda x: x if x is not None else -1) + self.tests = sorted(list(self.tests)) + + # Update column groups with actual tests + self.column_groups = [ + (impl, list(self.tests)) for impl, _ in self.column_groups + if impl in self.impls + ] + + def create_matrix(self) -> pd.DataFrame: + """Create a matrix for visualization.""" + all_cols = [ + (impl, test) + for impl, tests in self.column_groups + for test in tests + ] + + matrix = [] + index = [] + + for model in self.models: + for temp in self.temps: + index.append(f"{model} @ {temp}") + row_vals = [ + self.data_dict.get((model, temp, impl, test), np.nan) + for impl, test in all_cols + ] + matrix.append(row_vals) + + # Create column labels + col_labels = [f"{impl}\n({test})" for impl, test in all_cols] + + return pd.DataFrame(matrix, index=index, columns=col_labels) + + def plot_heatmap(self, df: pd.DataFrame, output_file: str = None) -> None: + """Create and display/save the heatmap visualization.""" + plt.figure(figsize=(12, 6)) + + sns.heatmap( + df, + annot=True, + cmap="RdYlGn", + vmin=0.0, + vmax=1.0, + cbar=True, + fmt=".2f", + center=0.5, + square=True, + linewidths=0.5, + cbar_kws={"label": "Success Ratio"} + ) + + plt.title("Model Performance Analysis\nSuccess Ratios by Implementation & Test", + pad=20) + plt.xlabel("Implementation and Test", labelpad=10) + plt.ylabel("Model @ Temperature", labelpad=10) + + plt.xticks(rotation=45, ha='right') + plt.yticks(rotation=0) + + plt.tight_layout() + + if output_file: + plt.savefig(output_file, dpi=300, bbox_inches='tight') + logger.info(f"Plot saved to {output_file}") + else: + plt.show() + +def main(): + if len(sys.argv) < 2: + logger.error("Please provide at least one input file") + sys.exit(1) + + analyzer = ModelAnalyzer() + + # Process input files + analyzer.read_files(sys.argv[1:]) + + if not analyzer.lines: + logger.error("No valid data was loaded") + sys.exit(1) + + # Process the data + analyzer.process_data() + + # Log summary statistics + logger.info(f"Processed {len(analyzer.lines)} lines") + logger.info(f"Found {len(analyzer.data_dict)} valid data points") + logger.info(f"Models: {analyzer.models}") + logger.info(f"Temperatures: {analyzer.temps}") + logger.info(f"Tests: {analyzer.tests}") + logger.info(f"Implementations: {analyzer.impls}") + + # Create and plot the visualization + df = analyzer.create_matrix() + # analyzer.plot_heatmap(df, "model_analysis.png") + analyzer.plot_heatmap(df) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py new file mode 100755 index 0000000000000..826112ff54588 --- /dev/null +++ b/scripts/tool_bench.py @@ -0,0 +1,203 @@ +#!/usr/bin/env uv run +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "pytest", +# "numpy", +# "pandas", +# "matplotlib", +# "seaborn", +# "requests", +# "wget", +# ] +# /// +''' + cmake --build build -j && ( \ + export LLAMA_CACHE=$HOME/Library/Caches/llama.cpp ; + export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server ; + export ARGS=( --n=10 --temps=0,0.5,0.75,1,1.5,2,5, --append=all.jsonl ) ; + ./scripts/tool_bench.py ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 1.5B Q4_K_M" --hf bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF --ollama deepseek-r1:1.5b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M" --hf bartowski/Qwen2.5-7B-Instruct-GGUF ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --hf bartowski/Llama-3.2-1B-Instruct-GGUF --ollama llama3.2:1b-instruct-q4_K_M ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --hf bartowski/Llama-3.2-3B-Instruct-GGUF --ollama llama3.1:3b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M" --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF --ollama llama3.1:8b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.3 Instruct 70B Q4_K_M" --hf bartowski/Llama-3.3-70B-Instruct-GGUF ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Mistral Nemo 2407 Q4_K_M" --hf bartowski/Mistral-Nemo-Instruct-2407-GGUF --ollama mistral-nemo:12b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Functionary Small v3.2 Q4_K_M" --hf bartowski/functionary-small-v3.2-GGUF ; + ) + +''' +import argparse +from contextlib import contextmanager +from statistics import mean, median +import pytest + +# ensure grandparent path is in sys.path +from pathlib import Path +import sys + +sys.path.insert(0, Path(__file__).parent.parent.as_posix()) +print(sys.path) +from examples.server.tests.unit.test_tool_call import * + + +@contextmanager +def scoped_server(sp: ServerProcess): + global server + server = sp + + import atexit + def stop(): + global server + nonlocal sp + if sp is not None: + sp.stop() + sp = None # type: ignore + server = None # type: ignore + atexit.register(stop) + + yield sp + + stop() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Run tests for the chat server.') + parser.add_argument('--model', type=str, help='Name of the model to test (implementation agnostic)', required=True) + parser.add_argument('--hf', type=str, help='GGUF huggingface model repo id (+ optional quant) to test w/ llama-server') + parser.add_argument('--chat-template', type=str, help='Chat template override for llama-server') + parser.add_argument('--ollama', type=str, help='Ollama model tag to test') + parser.add_argument('--n', type=int, help='Number of times to run each test', default=30) + parser.add_argument('--temps', type=str, help='Comma-separated list of temperatures') + parser.add_argument('--top-p', type=float, help='top_p') + parser.add_argument('--top-k', type=int, help='top_k') + parser.add_argument('--seed', type=int, help='Random seed') + parser.add_argument('--port', type=int, help='llama-server port') + parser.add_argument('--output', type=str, help='Output JSON file') + parser.add_argument('--append', type=str, help='Output JSON file') + + + args = parser.parse_args() + + # Check only one of output and append + assert (args.output is None) != (args.append is None), "Exactly one of --output and --append must be specified" + + # chat_template = args.chat_template + n = args.n + + n_predict = 512 + + with open(args.output or args.append, 'w' if args.output else 'a') as output_file: + + def run(server: ServerProcess, *, implementation: str, model_id: str, temp: float | None = None, output_kwargs={}, request_kwargs={}): + request_kwargs = {**request_kwargs} + if temp is not None: + request_kwargs['temperature'] = temp + if args.top_p is not None: + request_kwargs['top_p'] = args.top_p + if args.top_k is not None: + request_kwargs['top_k'] = args.top_k + if args.seed is not None: + request_kwargs['seed'] = args.seed + + request_kwargs['cache_prompt'] = False + + tests = { + "hello world": lambda server: do_test_hello_world(server, **request_kwargs), + "weather": lambda server: do_test_weather(server, **request_kwargs), + "calc result": lambda server: do_test_calc_result(server, None, 512, **request_kwargs), + } + for test_name, test in tests.items(): + success_count = 0 + failure_count = 0 + failures = [] + success_times = [] + failure_times = [] + print(f"Running {test_name} ({implementation}, {args.model}): ", file=sys.stderr, flush=True) + for i in range(n): + start_time = time.time() + def elapsed(): + return time.time() - start_time + try: + test(server) + success_times.append(elapsed()) + success_count += 1 + print('.', end='', file=sys.stderr, flush=True) + except Exception as e: + print('!', end='', file=sys.stderr, flush=True) + if failure_count == 0: + print(f" ({e}) ", end='', file=sys.stderr, flush=True) + failure_count += 1 + failure_times.append(elapsed()) + failures.append(str(e)) + print('\n', file=sys.stderr, flush=True) + output_file.write(json.dumps({**output_kwargs, **dict( + model=args.model, + implementation=implementation, + model_id=model_id, + test=test_name, + temp=temp, + top_p=args.top_p, + top_k=args.top_k, + success_ratio=float(success_count) / n, + avg_time=mean(success_times + failure_times), + median_time=median(success_times + failure_times), + success_count=success_count, + success_times=success_times, + failure_count=failure_count, + failure_times=failure_times, + failures=list(set(failures)), + )}) + '\n') + output_file.flush() + + temps = [float(temp) if temp != "" else None for temp in args.temps.split(',')] if args.temps is not None else [None] + for temp in temps: + if args.hf is not None: + server = ServerProcess() + server.n_slots = 1 + server.jinja = True + server.n_predict = 512 # High because of DeepSeek R1 + server.model_hf_repo = args.hf + server.model_hf_file = None + server.chat_template = args.chat_template + if args.port is not None: + server.server_port = args.port + # server.debug = True + + with scoped_server(server): + server.start(timeout_seconds=TIMEOUT_SERVER_START) + for ignore_chat_grammar in [False, True]: + run( + server, + implementation="llama-server" + (" (no grammar)" if ignore_chat_grammar else ""), + model_id=args.hf, + temp=temp, + output_kwargs=dict( + chat_template=args.chat_template, + ), + request_kwargs=dict( + ignore_chat_grammar=ignore_chat_grammar, + ), + ) + + if args.ollama is not None: + server = ServerProcess() + server.server_port = 11434 + server.server_host = "localhost" + subprocess.check_call(["ollama", "pull", args.ollama]) + + with scoped_server(server): + run( + server, + implementation="ollama", + model_id=args.ollama, + temp=temp, + output_kwargs=dict( + chat_template=None, + ), + request_kwargs=dict( + model=args.ollama, + ), + ) From d5aff5afdb450f702a5519344b66c02062916980 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 15:26:24 +0000 Subject: [PATCH 08/43] support RETRIES=N in server test utils --- examples/server/tests/utils.py | 45 ++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index a82504235ff54..2e850b216b1b3 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -181,7 +181,7 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: server_args.extend(["--chat-template-file", self.chat_template_file]) args = [str(arg) for arg in [server_path, *server_args]] - print(f"bench: starting server with: {' '.join(args)}") + print(f"tests: starting server with: {' '.join(args)}") flags = 0 if "nt" == os.name: @@ -233,23 +233,32 @@ def make_request( timeout: float | None = None, ) -> ServerResponse: url = f"http://{self.server_host}:{self.server_port}{path}" - parse_body = False - if method == "GET": - response = requests.get(url, headers=headers, timeout=timeout) - parse_body = True - elif method == "POST": - response = requests.post(url, headers=headers, json=data, timeout=timeout) - parse_body = True - elif method == "OPTIONS": - response = requests.options(url, headers=headers, timeout=timeout) - else: - raise ValueError(f"Unimplemented method: {method}") - result = ServerResponse() - result.headers = dict(response.headers) - result.status_code = response.status_code - result.body = response.json() if parse_body else None - print("Response from server", json.dumps(result.body, indent=2)) - return result + retries = int(os.environ.get('RETRIES', '1')) + for remaining_attempts in range(retries, 0, -1): + # print(f"#\ncurl {url} -d '{json.dumps(data, indent=2)}'\n") + parse_body = False + if method == "GET": + response = requests.get(url, headers=headers, timeout=timeout) + parse_body = True + elif method == "POST": + response = requests.post(url, headers=headers, json=data, timeout=timeout) + parse_body = True + elif method == "OPTIONS": + response = requests.options(url, headers=headers, timeout=timeout) + else: + raise ValueError(f"Unimplemented method: {method}") + + if (response is None or response.status_code != 200) and remaining_attempts > 0: + continue + result = ServerResponse() + result.headers = dict(response.headers) + result.status_code = response.status_code + result.body = response.json() if parse_body else None + # print("Response from server", json.dumps(result.body, indent=2)) + return result + + raise RuntimeError(f"Failed to make request to {url} after {retries} attempts") + def make_stream_request( self, From a06821cac6e7ca193f5061d6496614ebfbdd47dc Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 15:26:46 +0000 Subject: [PATCH 09/43] server: detect premature llama-server death in e2e tests --- examples/server/tests/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 2e850b216b1b3..bb97178dcb4cb 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -212,6 +212,10 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: return # server is ready except Exception as e: pass + # Check if process died + if self.process.poll() is not None: + raise RuntimeError(f"Server process died with return code {self.process.returncode}") + print(f"Waiting for server to start...") time.sleep(0.5) raise TimeoutError(f"Server did not start within {timeout_seconds} seconds") From d73c89430fd72b05a95f9f3190ac8f7d65d19796 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 15:30:02 +0000 Subject: [PATCH 10/43] tool-call: improve Qwen 2.5 & Llama 3.x w/ more triggers, constrained python code strings and revamped parsers --- common/chat.cpp | 441 ++++++++++++++----- examples/server/tests/unit/test_tool_call.py | 86 ++-- tests/test-chat.cpp | 80 ++++ 3 files changed, 454 insertions(+), 153 deletions(-) mode change 100644 => 100755 examples/server/tests/unit/test_tool_call.py diff --git a/common/chat.cpp b/common/chat.cpp index 9ebe4c5784cbc..bf1269e1ac61f 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -455,7 +455,7 @@ const common_grammar_options grammar_options { // /* .compact_spaces = */ true, }; -static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { +static std::optional parse_json(std::string::const_iterator & it, const std::string::const_iterator & end) { // // https://json.nlohmann.me/features/parsing/sax_interface/ struct json_error_locator : public nlohmann::json_sax { std::size_t position; @@ -492,14 +492,42 @@ static bool parse_json(std::string::const_iterator & it, const std::string::cons } std::string json_sub {it, temptative_end}; try { - out = json::parse(json_sub); + auto out = json::parse(json_sub); it = temptative_end; - return true; + return out; } catch (const std::exception &) { - return false; + return std::nullopt; } } +static bool parse_literal(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) { + auto expected_it = expected.begin(); + auto tmp_it = it; + while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) { + ++tmp_it; + ++expected_it; + } + if (expected_it == expected.end()) { + it = tmp_it; + return true; + } + return false; +} + +static std::optional parse_pattern(std::string::const_iterator & it, const std::string::const_iterator & end, const std::regex & expected) { + std::smatch match; + if (std::regex_match(it, end, match, expected)) { + it = match.suffix().first; + return match; + } + return std::nullopt; +} + +static void consume_spaces(std::string::const_iterator & it, const std::string::const_iterator & end) { + while (it != end && std::isspace(*it)) { + ++it; + } +} /** * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. @@ -509,7 +537,8 @@ static common_chat_msg parse_json_tool_calls( const std::string& input, const std::optional & trigger_opt, const std::regex & function_regex, - const std::regex & close_regex) { + const std::regex & close_regex, + bool allow_raw_python = false) { std::smatch match; common_chat_msg result; @@ -539,15 +568,19 @@ static common_chat_msg parse_json_tool_calls( result.content += std::string(it, rit->prefix().second); it = rit->suffix().first; - json arguments; - if (!parse_json(it, end, arguments)) { + if (auto arguments = parse_json(it, end)) { + if (!std::regex_search(it, end, match, close_regex)) { + throw std::runtime_error("Malformed input, missing closing pattern: " + input); + } + it = match.suffix().first; + result.tool_calls.push_back({name, arguments->is_string() ? arguments->get() : arguments->dump(), /* id= */ ""}); + } else { + if (allow_raw_python && name == "python") { + result.tool_calls.push_back({name, json({{"code", std::string(it, end)}}).dump(), /* id= */ ""}); + break; + } throw std::runtime_error("Failed to parse json tool call arguments: " + input); } - if (!std::regex_search(it, end, match, close_regex)) { - throw std::runtime_error("Malformed input, missing closing pattern: " + input); - } - it = match.suffix().first; - result.tool_calls.push_back({name, arguments.is_string() ? arguments.get() : arguments.dump(), /* id= */ ""}); } if (!result.tool_calls.empty()) { @@ -559,29 +592,29 @@ static common_chat_msg parse_json_tool_calls( return result; } +static common_chat_tool_call process_tool_call(const json & tool_call) { + const auto & arguments = tool_call.at("arguments"); + return { + /* .name = */ tool_call.at("name"), + /* .arguments = */ arguments.is_string() ? arguments.get() : arguments.dump(), + /* .id = */ tool_call.contains("id") ? tool_call.at("id") : "", + }; +} static common_chat_msg parse_prefixed_json_tool_call_array(const std::string& input, const std::string & prefix, size_t rstrip_prefix = 0) { auto content_end = input.find(prefix); size_t tc_start = std::string::npos; common_chat_msg result; result.role = "assistant"; - const auto process_tool_calls = [&](const json & tool_calls) { - for (const auto & tool_call : tool_calls) { - const auto & arguments = tool_call.at("arguments"); - result.tool_calls.push_back({ - tool_call.at("name"), - arguments.is_string() ? arguments.get() : arguments.dump(), - tool_call.contains("id") ? tool_call.at("id") : "", - }); - } - }; if (content_end == std::string::npos) { result.content = input; } else { tc_start = content_end + prefix.size() - rstrip_prefix; result.content = input.substr(0, content_end); auto tool_calls = json::parse(input.substr(tc_start)); - process_tool_calls(tool_calls); + for (const auto & tool_call : tool_calls) { + result.tool_calls.emplace_back(process_tool_call(tool_call)); + } } return result; } @@ -840,9 +873,9 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ return data; } static common_chat_msg common_chat_parse_command_r7b(const std::string & input, bool extract_reasoning) { - static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S\\n\\r]*?)<\\|END_THINKING\\|>)([\\s\\S\\n\\r]*)"); - static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S\\n\\r]*?)<\\|END_ACTION\\|>"); - static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S\\n\\r]*?)<\\|END_RESPONSE\\|>"); + static std::regex thought_regex("(<\\|START_THINKING\\|>([\\s\\S]*?)<\\|END_THINKING\\|>)([\\s\\S]*)"); + static std::regex action_regex("<\\|START_ACTION\\|>([\\s\\S]*?)<\\|END_ACTION\\|>"); + static std::regex response_regex("(?:<\\|START_RESPONSE\\|>)?([\\s\\S]*?)<\\|END_RESPONSE\\|>"); std::smatch match; @@ -898,6 +931,78 @@ static void expect_tool_parameters(const std::string & name, const json & parame } } +/* + Adds a GBNF rule that matches a Python code string when escaped inside a JSON string (without surrounding double quotes) + + If this sounds meta, well, it is: + - Most tool call style pass tool arguments as JSON objects, e.g. {"arg1": , ...} + - When the tool is `python` and the argument is `code`, the value is JSON-escaped Python code. + Some models (Llama 3.x series) tend to close the code string itself when the nested code + tries to open a double quoted string. So when the model wants to write the code `print("Hey")`, + it only goes so far as `{"code": "print("` and the general JSON constraints of the python tool arguments call it a day. + - This rule (when wrapped in double quotes) can be used instead of a JSON string + to match a structured soup of Python tokens that has the following properties: + - All open brackets / braces / parentheses are closed + - All strings (single or double quoted) are closed + - All double quotes are escaped + + This should prevent an entire class of invalid Python programs to be generated by the model, + but any bugs / omissions may also disallow some valid Python syntax. Current limitations: + + - No f strings + - No multiline strings + + Examples: + + - OK + {"code": "print('Hey')"} + {"code": "print(\"Hey\")"} + {"code": "# in \" comments...\nx = \"Hey\""} + - NOT OK + {"code": "print("} + {"code": "print(\""} + {"code": "print('"} +*/ +static std::string add_escaped_python_code_soup_rule(const common_grammar_builder & builder) { + return builder.add_rule("json-escaped-code-soup", + // Allow comments w/ (escaped) newline + R"( ( [#] ( ( [^\\\t\r\n\uff00-\uffef] | [\\] [^n\n] )* [\\] [n] )? | )" + // Allow (escaped) double quoted strings and their nested (double) escapes + R"( [\\] ["] ( [^"\\\t\r\n\uff00-\uffef] | [\\] [\\] ["] | [\\] [trnu] )* [\\] ["] | )" + // Allow single quoted strings and their nested (double) escapes + R"( ['] ( [^"'\\\t\r\n\uff00-\uffef] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] )* ['] | )" + // Soup wrapped in parentheses, curly braces or square brackets + R"( [(] json-escaped-code-soup [)] | )" + R"( [{] json-escaped-code-soup [}] | )" + R"( "[" json-escaped-code-soup "]" | )" + // Allow escapes + R"( [\\] [\\trnu] | )" + // Allow other characters, minus code blocks for halfwidth & fullwidth forms (U+FF00 - U+FFEF) + // (special tokens can use these to avoid prompt injections, as they will have to be unicode-escaped w/ \uXXXX + // and won't be able to interfere w/ parsing) + R"( [^#{}"'\[\]\\()\t\r\n\uff00-\uffef]+ )" + // After any repetition of the previous, allow trailing comment w/o newline + R"( )* ( [#] ( [^\\] | [\\] [^n] )* )? )" + ); +} + +static std::string add_python_code_arguments_rule(const std::string & name, const common_grammar_builder & builder) { + return builder.add_rule( + name, + "\"{\" space \"\\\"code\\\": \\\"\" " + + add_escaped_python_code_soup_rule(builder) + + " \"\\\"\" space \"}\" space "); +} + +static std::string add_json_tool_args_rule(const std::string & name, const json & parameters, const common_grammar_builder & builder) { + if (name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { + return add_python_code_arguments_rule(name + "-code-args", builder); + } else { + return builder.add_schema(name + "-args", parameters); + } +} + + static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) { auto builtin_tools = json::array(); common_chat_params data; @@ -919,7 +1024,11 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com std::vector kvs; for (const auto & [key, value] : parameters.at("properties").items()) { - kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT + if (name == "python" && key == "code") { + kvs.push_back("\"" + key + "=\\\"\" " + add_escaped_python_code_soup_rule(builder) + " \"\\\"\""); // NOLINT + } else { + kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT + } } tool_rules.push_back( @@ -947,7 +1056,7 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com "\"{\" space " "( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? " "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + - builder.add_schema(name + "-args", parameters) + + add_json_tool_args_rule(name, parameters, builder) + " \"}\"")); data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); }); @@ -974,33 +1083,33 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com } static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) { // TODO: tighten & simplify the parser, don't accept leading text context. - static std::regex function_regex("\\{[\\s\\n\\r]*(?:\"type\"[\\s\\n\\r]*:[\\s\\n\\r]*\"function\"[\\s\\n\\r]*,[\\s\\n\\r]*|[\\s\\n\\r]*)\"name\"[\\s\\n\\r]*:[\\s\\n\\r]*\"([^\"]+)\"[\\s\\n\\r]*,[\\s\\n\\r]*\"parameters\": "); + static std::regex function_regex( + "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*|\\s*)\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\": "); static std::regex close_regex("\\}"); - static std::regex builtin_call_regex("<\\|python_tag\\|>([^.(]+)\\.call\\((.*)\\)"); + static std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)"); if (with_builtin_tools) { std::smatch match; if (std::regex_match(input, match, builtin_call_regex)) { - auto name = match[1].str(); - auto raw_args = match[2].str(); - - // TODO: if/when builtin tools start accepting more than 1 argument, use parse_json for real parsing. - auto it_eq = raw_args.find('='); - auto arg_name = raw_args.substr(0, it_eq); - auto arg_value_str = raw_args.substr(it_eq + 1); - auto arg_value = json::parse(arg_value_str); - - common_chat_msg msg; - msg.role = "assistant"; - msg.content = match.prefix().str(); - msg.tool_calls.push_back({ - /* .name = */ name, - /* .arguments = */ (json { - {arg_name, arg_value}, - }).dump(), - /* .id = */ "", - }); - return msg; + try { + auto name = match[1].str(); + auto arg_name = match[2].str(); + auto arg_value_str = match[3].str(); + auto arg_value = json::parse(arg_value_str); + + common_chat_msg msg; + msg.role = "assistant"; + msg.tool_calls.push_back({ + /* .name = */ name, + /* .arguments = */ (json { + {arg_name, arg_value}, + }).dump(), + /* .id = */ "", + }); + return msg; + } catch (const std::exception & e) { + LOG_WRN("Failed to parse builtin tool call arguments (%s): %s", e.what(), input.c_str()); + } } } return parse_json_tool_calls(input, std::nullopt, function_regex, close_regex); @@ -1017,10 +1126,10 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ std::string name = function.at("name"); auto parameters = function.at("parameters"); builder.resolve_refs(parameters); - auto args_rule = builder.add_schema(name + "-args", parameters); tool_rules.push_back(builder.add_rule(name + "-call", "\"<|tool▁call▁begin|>function<|tool▁sep|>" + name + "\\n" - "```json\\n\" " + args_rule + " \"```<|tool▁call▁end|>\"")); + "```json\\n\" " + add_json_tool_args_rule(name, parameters, builder) + " " + "\"```<|tool▁call▁end|>\"")); }); // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, // so we accept common variants (then it's all constrained) @@ -1158,11 +1267,16 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ auto parameters = function.at("parameters"); builder.resolve_refs(parameters); auto args_rule = builder.add_schema(name + "-args", parameters); - first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); + first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); data.grammar_triggers.push_back({name, /* .at_start = */ true}); + data.grammar_triggers.push_back({"assistant<|end_header_id|>\n" + name, /* .at_start = */ true}); data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false}); + data.grammar_triggers.push_back({">>>assistant<|end_header_id|>\n" + name, /* .at_start = */ false}); }); + data.preserved_tokens = { + "<|end_header_id|>", + }; auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; if (inputs.parallel_tool_calls) { auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; @@ -1176,29 +1290,15 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ return data; } -static bool consume(std::string::const_iterator & it, const std::string::const_iterator & end, const std::string & expected) { - auto expected_it = expected.begin(); - auto tmp_it = it; - while (tmp_it != end && expected_it != expected.end() && *tmp_it == *expected_it) { - ++tmp_it; - ++expected_it; - } - if (expected_it == expected.end()) { - it = tmp_it; - return true; - } - return false; -} - static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & input) { - static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); + static std::regex function_regex(R"((?:>>>)?(?:assistant<|end_header_id|>\n)?(\w+)\n)"); static std::regex close_regex(R"($|(?=>>>))"); std::string content; auto it = input.begin(); const auto end = input.end(); - if (consume(it, end, "all\n")) { + if (parse_literal(it, end, "all\n")) { std::smatch match; if (std::regex_search(it, end, match, function_regex)) { auto fun_it = match.prefix().second; @@ -1213,7 +1313,7 @@ static common_chat_msg common_chat_parse_functionary_v3_2(const std::string & in } // TODO: tighten & simplify. try { - auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex); + auto res = parse_json_tool_calls(std::string(it, end), std::nullopt, function_regex, close_regex, /* allow_raw_python= */ true); res.content = content + res.content; return res; } catch (const std::exception & e) { @@ -1306,70 +1406,185 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { std::vector tool_rules; + std::vector tool_call_alts; foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); std::string name = function.at("name"); auto parameters = function.at("parameters"); builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_schema(name + "-call", { - {"type", "object"}, - {"properties", json { - {"name", json {{"const", name}}}, - {"arguments", parameters}, - }}, - {"required", json::array({"name", "arguments"})}, - })); + if (name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { + tool_rules.push_back(builder.add_rule(name + "-call", + "\"{\" space " + "\"\\\"name\\\":\" space \"\\\"" + name + "\\\"\" space \",\" space " + "\"\\\"arguments\\\":\" space " + add_python_code_arguments_rule(name + "-code-arguments", builder) + " " + "\"}\" space ")); + } else { + tool_rules.push_back(builder.add_schema(name + "-call", { + {"type", "object"}, + {"properties", json { + {"name", json {{"const", name}}}, + {"arguments", parameters}, + }}, + {"required", json::array({"name", "arguments"})}, + })); + } + tool_call_alts.push_back(builder.add_rule( + name + "-function-tag", + "\"\" space " + + builder.add_schema(name + "-args", parameters) + " " + "\"\" space")); }); - auto tool_call = "\"\" space " + builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " \"\" space"; + auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space"); + std::vector alt_tags { + any_tool_call, + "\"\" space " + any_tool_call + " \"\"", + // The rest is just to accommodate common "good bad" outputs. + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + "\"\" space " + any_tool_call + " \"\"", + }; + auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space"); + tool_call_alts.push_back(wrappable_tool_call); + tool_call_alts.push_back( + "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space "); + auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | ")); builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); data.grammar_triggers.push_back({"", /* .at_start = */ false}); - data.preserved_tokens = { "" }; + data.grammar_triggers.push_back({"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```\n{\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```\n {\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```\n{\n \"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```json\n{\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```json\n {\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```json\n{\n \"name\": \"", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```xml\n{\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```xml\n {\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```xml\n{\n \"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({"```xml\n\n {\"name\":", /* .at_start = */ true}); + data.preserved_tokens = { + "", + "", + "", + "", + "", + "", + "```", + "```json", + "```xml", + }; }, grammar_options); data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_HERMES_2_PRO; return data; } -static common_chat_msg common_chat_parse_hermes_2_pro(const std::string & input) { - try { - std::regex start_pattern(R"([\n\s]*)"); - std::regex middle_pattern(R"([\n\s]*[\n\s]*)"); - std::regex end_pattern(R"([\n\s]*[\n\s]*$)"); +static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input) { + const static std::regex open_regex( + "(?:" + "(```(?:xml|json)?\\n)?" // match 1 (block_start) + "(" // match 2 (open_tag) + "|" + "|" + "|" + "|" + "|" + "|" + "|" + ")?" + "(\\s*\\{\\s*\"name\":[\\s\\S]*)" // match 3 (named tool call + rest) + ")" + "|" + "(?:]+)>" // match 4 (function name) + "|)" // match 5 (function name again) + "([\\s\\S]*)" // match 6 (function arguments + rest)})" + ); + try { + common_chat_msg msg; msg.role = "assistant"; - auto end = input.end(); - std::sregex_iterator rend; - std::sregex_iterator rit(input.begin(), end, start_pattern); - if (rit == rend) { - msg.content = input; - return msg; - } - - msg.content = rit->prefix(); + std::string::const_iterator it = input.begin(); + const std::string::const_iterator end = input.end(); + std::smatch match; - auto it = rit->suffix().first; while (it != end) { - json call; - if (!parse_json(it, end, call)) { - throw std::runtime_error("Failed to parse json tool call"); - } - const auto & arguments = call.at("arguments"); - msg.tool_calls.push_back({ - call.at("name"), - arguments.dump(), - // arguments.is_string() ? arguments.get() : arguments.dump(), - /* id= */ "", - }); - rit = {it, end, middle_pattern}; - if (rit != rend) { - it = rit->suffix().first; - } else { - rit = {it, end, end_pattern}; - if (rit == rend) { - throw std::runtime_error("Malformed input, missing "); + if (std::regex_search(it, end, match, open_regex)) { + // Add content before the match + msg.content += std::string(it, match[0].first); + + auto block_start = match[1].str(); + std::string block_end = block_start.empty() ? "" : "```"; + + auto open_tag = match[2].str(); + std::string close_tag; + + if (match[3].matched) { + close_tag = open_tag.empty() ? "" : "contains("name") && tool_call->contains("arguments")) { + + msg.tool_calls.emplace_back(process_tool_call(*tool_call)); + it = json_it; // Move iterator past parsed JSON + + // Handle close tags + consume_spaces(it, end); + if (!close_tag.empty() && !parse_literal(it, end, close_tag)) { + throw std::runtime_error("Failed to parse closing tag"); + } + consume_spaces(it, end); + if (!block_end.empty() && !parse_literal(it, end, block_end)) { + throw std::runtime_error("Failed to parse block end"); + } + } else { + // Not a valid tool call, treat as content + msg.content += std::string(match[0].first, match[0].second); + it = match[0].second; + } + } else { + auto function_name = match[4].str(); + if (function_name.empty()) { + function_name = match[5].str(); + } + GGML_ASSERT(!function_name.empty()); + + close_tag = ""; + // Start parsing from after the opening tags + auto json_it = match[6].first; + if (auto arguments = parse_json(json_it, end)) { + msg.tool_calls.emplace_back(process_tool_call({ + {"name", function_name}, + {"arguments", *arguments}, + })); + it = json_it; // Move iterator past parsed JSON + + // Handle close tags + consume_spaces(it, end); + if (!close_tag.empty() && !parse_literal(it, end, close_tag)) { + throw std::runtime_error("Failed to parse closing tag"); + } + consume_spaces(it, end); + if (!block_end.empty() && !parse_literal(it, end, block_end)) { + throw std::runtime_error("Failed to parse block end"); + } + } else { + // Not a valid tool call, treat as content + msg.content += std::string(match[0].first, match[0].second); + it = match[0].second; + } } + } else { + // Add remaining content + msg.content += std::string(it, end); break; } } diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py old mode 100644 new mode 100755 index a91a2f3333ca3..903867201af76 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -142,25 +142,29 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), - # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - # (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), - # (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), @@ -176,10 +180,10 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), - # (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), - # TODO: fix these - # (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - # (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), + + (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), ]) def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server @@ -281,6 +285,9 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), @@ -488,42 +495,45 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] @pytest.mark.slow -@pytest.mark.parametrize("expected_arguments_override,hf_repo,template_override", [ - (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - # (None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"), +@pytest.mark.parametrize("hf_repo,template_override", [ + ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - (None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), - (None, "bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), - ('{"code":"print("}', "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - (None, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), - (None, "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), - ('{"code":"print("}', "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), - (None, "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), - (None, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - (None, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), - (None, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), - (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - (None, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), - # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. - (None, "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"), ]) -def test_hello_world(expected_arguments_override: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None): +def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None): global server + n_predict = 512 # High because of DeepSeek R1 server.n_slots = 1 server.jinja = True server.n_ctx = 8192 @@ -556,12 +566,8 @@ def test_hello_world(expected_arguments_override: str | None, hf_repo: str, temp tool_call = tool_calls[0] assert choice["message"].get("content") is None, f'Expected no content in {choice["message"]}' assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] - actual_arguments = tool_call["function"]["arguments"] - if expected_arguments_override is not None: - assert actual_arguments == expected_arguments_override - else: - actual_arguments = json.loads(actual_arguments) - assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}" - code = actual_arguments["code"] - assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" - assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}' + actual_arguments = json.loads(tool_call["function"]["arguments"]) + assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}" + code = actual_arguments["code"] + assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" + assert re.match(r'''((#.*)?\n)*print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}' diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 6435923054859..2b1226292177b 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -640,6 +640,86 @@ static void test_template_output_parsers() { inputs_tools) .format); + // Test parsing + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "{\"arg1\": 1}", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + "{\"arg1\": 1}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "```xml\n" + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "\n" + "```", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "```xml\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "```\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "```\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "```json\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "```", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\n" + " \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n" + " }\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); + test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); test_templates(tmpls.get(), end_tokens, message_assist_call, tools, "\n" From 44607981ad389c320012a96a8088fc38707bc0f0 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 17:24:20 +0000 Subject: [PATCH 11/43] migrate to trigger patterns, more lenient parser wrt/ spaces introduce trigger types llama 3.x: Allow leading spaces more lenient parsing --- common/chat.cpp | 133 +++++++++++++++++++++++-------------- common/common.cpp | 5 ++ common/common.h | 16 +++-- common/sampling.cpp | 61 +++++++++++------ examples/server/server.cpp | 80 +++++++++++++++------- examples/server/utils.hpp | 6 +- include/llama.h | 18 ++++- src/llama-grammar.cpp | 20 +++--- src/llama-grammar.h | 6 +- src/llama-sampling.cpp | 61 ++++++++++++----- tests/test-chat.cpp | 53 ++++++++++++--- 11 files changed, 315 insertions(+), 144 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index bf1269e1ac61f..fd3552c91ce57 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -804,7 +804,10 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat } builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); }, grammar_options); - data.grammar_triggers.push_back({"[TOOL_CALLS]", /* .at_start = */ true}); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"}); + data.preserved_tokens = { + "[TOOL_CALLS]", + }; data.prompt = apply(tmpl, inputs.messages, inputs.tools.empty() ? json() : inputs.tools, inputs.add_generation_prompt); data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; return data; @@ -847,13 +850,17 @@ static common_chat_params common_chat_params_init_command_r7b(const common_chat_ } builder.add_rule("root", "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\""); }, grammar_options); - data.grammar_triggers.push_back({"<|START_ACTION|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + "<|START_ACTION|>", + }); data.preserved_tokens = { + "<|START_ACTION|>", + "<|END_ACTION|>", "<|START_RESPONSE|>", "<|END_RESPONSE|>", "<|START_THINKING|>", "<|END_THINKING|>", - "<|END_ACTION|>", }; auto adjusted_messages = json::array(); for (const auto & msg : inputs.messages) { @@ -1054,21 +1061,20 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com builder.add_rule( name + "-call", "\"{\" space " - "( \"\\\"type\\\":\" space \"\\\"function\\\",\" space )? " - "\"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + - add_json_tool_args_rule(name, parameters, builder) + - " \"}\"")); - data.grammar_triggers.push_back({"{\"name\": \"" + name + "\"", /* .at_start = */ true}); + "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? " + " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space " + " \"\\\"parameters\\\"\" space \":\" space " + add_json_tool_args_rule(name, parameters, builder) + " " + "\"}\" space")); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, + "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"" + name + "\"[\\s\\S]*", + }); }); - data.grammar_triggers.push_back({"{\"name\":", /* .at_start = */ true}); - data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true}); - data.grammar_triggers.push_back({"{\n \"name\":", /* .at_start = */ true}); - data.grammar_triggers.push_back({"{\"type\": \"function\"", /* .at_start = */ true}); - data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true}); - data.grammar_triggers.push_back({"{\n \"type\": \"function\"", /* .at_start = */ true}); if (!builtin_tools.empty()) { - data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); + data.preserved_tokens.push_back("<|python_tag|>"); } + // Allow a few empty lines on top of the usual constrained json schema space rule. builder.add_rule("root", string_join(tool_rules, " | ")); }, grammar_options); data.additional_stops.push_back("<|eom_id|>"); @@ -1084,8 +1090,8 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com static common_chat_msg common_chat_parse_llama_3_1(const std::string & input, bool with_builtin_tools = false) { // TODO: tighten & simplify the parser, don't accept leading text context. static std::regex function_regex( - "\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*|\\s*)\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\": "); - static std::regex close_regex("\\}"); + "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); + static std::regex close_regex("\\}\\s*"); static std::regex builtin_call_regex("<\\|python_tag\\|>\\s*([^.(]+)\\s*\\.\\s*call\\s*\\(\\s*([\\w]+)\\s*=\\s*([\\s\\S]*?)\\)"); if (with_builtin_tools) { @@ -1138,16 +1144,18 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_ "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " "\"<|tool▁calls▁end|>\"" " space"); - data.grammar_triggers.push_back({"<|tool▁calls▁begin|>", /* .at_start = */ false}); - data.grammar_triggers.push_back({"<|tool_calls_begin|>", /* .at_start = */ false}); - data.grammar_triggers.push_back({"<|tool calls begin|>", /* .at_start = */ false}); - data.grammar_triggers.push_back({"<|tool\\_calls\\_begin|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool▁calls▁begin|>"}); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls_begin|>"}); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool calls begin|>"}); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool\\_calls\\_begin|>"}); data.preserved_tokens = { "", "", + "<|tool▁calls▁begin|>", + "<|tool▁call▁begin|>", "<|tool▁sep|>", - "<|tool▁calls▁end|", "<|tool▁call▁end|>", + "<|tool▁calls▁end|", }; }, grammar_options); } @@ -1239,7 +1247,10 @@ static common_chat_params common_chat_params_init_firefunction_v2(const common_c } builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); }, grammar_options); - data.grammar_triggers.push_back({" functools[", /* .at_start = */ false}); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["}); + data.preserved_tokens = { + " functools[", + }; data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2; } else { data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; @@ -1269,10 +1280,22 @@ static common_chat_params common_chat_params_init_functionary_v3_2(const common_ auto args_rule = builder.add_schema(name + "-args", parameters); first_tool_rules.push_back(builder.add_rule(name + "-call", "( \"assistant<|end_header_id|>\\n\" )? \"" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); - data.grammar_triggers.push_back({name, /* .at_start = */ true}); - data.grammar_triggers.push_back({"assistant<|end_header_id|>\n" + name, /* .at_start = */ true}); - data.grammar_triggers.push_back({">>>" + name, /* .at_start = */ false}); - data.grammar_triggers.push_back({">>>assistant<|end_header_id|>\n" + name, /* .at_start = */ false}); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, + regex_escape(name + "\n"), + }); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, + regex_escape("assistant<|end_header_id|>\n" + name + "\n"), + }); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + regex_escape(">>>" + name + "\n"), + }); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + ">>>assistant<|end_header_id|>\n" + name, + }); }); data.preserved_tokens = { "<|end_header_id|>", @@ -1366,11 +1389,12 @@ static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(con }); if (has_raw_python) { tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); - data.grammar_triggers.push_back({"<|python_tag|>", /* .at_start = */ false}); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); + data.preserved_tokens.push_back("<|python_tag|>"); } auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); - data.grammar_triggers.push_back({"\" space " + builder.add_schema(name + "-args", parameters) + " " "\"\" space")); + + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + "", + }); + auto escaped_name = regex_escape(name); + data.grammar_triggers.push_back({ + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, + "|||)?(?:```(?:json|xml)?\n)?\\s*\\{\\s*\"name\"\\s*:\\s*\"" + escaped_name + "\"", + }); }); auto any_tool_call = builder.add_rule("any_tool_call", "( " + string_join(tool_rules, " | ") + " ) space"); std::vector alt_tags { @@ -1452,29 +1491,21 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space "); auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | ")); builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); - data.grammar_triggers.push_back({"", /* .at_start = */ false}); - data.grammar_triggers.push_back({"", /* .at_start = */ true}); - data.grammar_triggers.push_back({"", /* .at_start = */ true}); - data.grammar_triggers.push_back({"", /* .at_start = */ true}); - data.grammar_triggers.push_back({"", /* .at_start = */ true}); - data.grammar_triggers.push_back({"```\n{\"name\":", /* .at_start = */ true}); - data.grammar_triggers.push_back({"```\n {\"name\":", /* .at_start = */ true}); - data.grammar_triggers.push_back({"```\n{\n \"name\":", /* .at_start = */ true}); - data.grammar_triggers.push_back({"```json\n{\"name\":", /* .at_start = */ true}); - data.grammar_triggers.push_back({"```json\n {\"name\":", /* .at_start = */ true}); - data.grammar_triggers.push_back({"```json\n{\n \"name\": \"", /* .at_start = */ true}); - data.grammar_triggers.push_back({"```xml\n{\"name\":", /* .at_start = */ true}); - data.grammar_triggers.push_back({"```xml\n {\"name\":", /* .at_start = */ true}); - data.grammar_triggers.push_back({"```xml\n{\n \"name\":", /* .at_start = */ true}); - data.grammar_triggers.push_back({"```xml\n\n {\"name\":", /* .at_start = */ true}); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, ""}); + data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "", "", + "", "", + "", "", + "", "", + "", "", + "", "", "```", "```json", @@ -1489,7 +1520,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input) { const static std::regex open_regex( "(?:" - "(```(?:xml|json)?\\n)?" // match 1 (block_start) + "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) "(" // match 2 (open_tag) "|" "|" @@ -1499,7 +1530,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input) "|" "|" ")?" - "(\\s*\\{\\s*\"name\":[\\s\\S]*)" // match 3 (named tool call + rest) + "(\\s*\\{\\s*\"name\"\\s*:[\\s\\S]*)" // match 3 (named tool call + rest) ")" "|" "(?:]+)>" // match 4 (function name) @@ -1545,6 +1576,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input) if (!block_end.empty() && !parse_literal(it, end, block_end)) { throw std::runtime_error("Failed to parse block end"); } + consume_spaces(it, end); } else { // Not a valid tool call, treat as content msg.content += std::string(match[0].first, match[0].second); @@ -1576,6 +1608,7 @@ static common_chat_msg common_chat_parse_hermes_2_pro(const std::string& input) if (!block_end.empty() && !parse_literal(it, end, block_end)) { throw std::runtime_error("Failed to parse block end"); } + consume_spaces(it, end); } else { // Not a valid tool call, treat as content msg.content += std::string(match[0].first, match[0].second); diff --git a/common/common.cpp b/common/common.cpp index d2b0d50e3ee39..7f0d18f2da67d 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -483,6 +483,11 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } +std::string regex_escape(const std::string & s) { + static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); + return std::regex_replace(s, special_chars, "\\$0"); +} + std::string string_join(const std::vector & values, const std::string & separator) { std::ostringstream result; for (size_t i = 0; i < values.size(); ++i) { diff --git a/common/common.h b/common/common.h index 10bcc10d51bb5..cbd485d77fcf8 100644 --- a/common/common.h +++ b/common/common.h @@ -110,9 +110,16 @@ enum common_conversation_mode { COMMON_CONVERSATION_MODE_AUTO = 2, }; +enum common_grammar_trigger_type { + COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, + COMMON_GRAMMAR_TRIGGER_TYPE_WORD, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START, +}; + struct common_grammar_trigger { - std::string word; - bool at_start; + common_grammar_trigger_type type; + std::variant value; }; // sampling parameters @@ -163,8 +170,7 @@ struct common_params_sampling { std::string grammar; // optional BNF-like grammar to constrain sampling bool grammar_lazy = false; - std::vector grammar_trigger_words; // optional trigger words to trigger lazy grammar - std::vector grammar_trigger_tokens; // optional trigger tokens to trigger lazy grammar and print trigger special tokens. + std::vector grammar_triggers; // optional trigger words to trigger lazy grammar std::set preserved_tokens; std::vector logit_bias; // logit biases to apply @@ -453,6 +459,8 @@ std::string string_repeat(const std::string & str, size_t n); void string_replace_all(std::string & s, const std::string & search, const std::string & replace); +std::string regex_escape(const std::string & s); + template static std::vector string_split(const std::string & str, char delim) { static_assert(!std::is_same::value, "Please use the specialized version for std::string"); diff --git a/common/sampling.cpp b/common/sampling.cpp index 9eea0f749f3be..20bf20bb2578a 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -145,11 +145,6 @@ std::string common_params_sampling::print() const { return std::string(result); } -inline std::string regex_escape(const std::string & literal) { - static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); - return std::regex_replace(literal, special_chars, "\\$0"); -} - struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { const llama_vocab * vocab = llama_model_get_vocab(model); @@ -165,31 +160,53 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); #endif // LLAMA_USE_LLGUIDANCE } else { - std::vector escaped_triggers_at_start; - std::vector escaped_triggers_anywhere; - for (const auto & trigger : params.grammar_trigger_words) { - (trigger.at_start ? escaped_triggers_at_start : escaped_triggers_anywhere) - .push_back(regex_escape(trigger.word)); + std::vector patterns_at_start; + std::vector patterns_anywhere; + std::vector trigger_tokens; + for (const auto & trigger : params.grammar_triggers) { + switch (trigger.type) { + case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: + { + const auto & word = std::get(trigger.value); + patterns_anywhere.push_back(regex_escape(word)); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START: + { + const auto & pattern = std::get(trigger.value); + (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: + { + const auto & token = std::get(trigger.value); + trigger_tokens.push_back(token); + break; + } + default: + GGML_ASSERT(false && "unknown trigger type"); + } } - std::vector trigger_regexes; - if (!escaped_triggers_at_start.empty()) { - trigger_regexes.push_back("^(" + string_join(escaped_triggers_at_start, "|") + ")[\\s\\S]*"); + std::vector trigger_patterns; + if (!patterns_at_start.empty()) { + trigger_patterns.push_back("^(" + string_join(patterns_at_start, "|") + ")[\\s\\S]*"); } - if (!escaped_triggers_anywhere.empty()) { - trigger_regexes.push_back("^[\\s\\S]*?(" + string_join(escaped_triggers_anywhere, "|") + ")[\\s\\S]*"); + if (!patterns_anywhere.empty()) { + trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); } - std::vector trigger_regexes_c; - trigger_regexes_c.reserve(trigger_regexes.size()); - for (const auto & regex : trigger_regexes) { - trigger_regexes_c.push_back(regex.c_str()); + std::vector trigger_patterns_c; + trigger_patterns_c.reserve(trigger_patterns.size()); + for (const auto & regex : trigger_patterns) { + trigger_patterns_c.push_back(regex.c_str()); } grmr = params.grammar_lazy - ? llama_sampler_init_grammar_lazy(vocab, params.grammar.c_str(), "root", - trigger_regexes_c.data(), trigger_regexes_c.size(), - params.grammar_trigger_tokens.data(), params.grammar_trigger_tokens.size()) + ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", + trigger_patterns_c.data(), trigger_patterns_c.size(), + trigger_tokens.data(), trigger_tokens.size()) : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index bdea828aeba3f..ffad6425fc229 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -131,9 +131,22 @@ struct slot_params { lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); } - std::vector grammar_trigger_words; - for (const auto & trigger : sampling.grammar_trigger_words) { - grammar_trigger_words.push_back(trigger.word); + auto grammar_triggers = json::array(); + for (const auto & trigger : sampling.grammar_triggers) { + switch (trigger.type) { + case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: + grammar_triggers.push_back({{"word", std::get(trigger.value)}}); + break; + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: + grammar_triggers.push_back({{"pattern", std::get(trigger.value)}}); + break; + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START: + grammar_triggers.push_back({{"pattern_start", std::get(trigger.value)}}); + break; + case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: + grammar_triggers.push_back({{"token", std::get(trigger.value)}}); + break; + } } return json { @@ -171,8 +184,7 @@ struct slot_params { {"min_keep", sampling.min_keep}, {"grammar", sampling.grammar}, {"grammar_lazy", sampling.grammar_lazy}, - {"grammar_trigger_words", grammar_trigger_words}, - {"grammar_trigger_tokens", sampling.grammar_trigger_tokens}, + {"grammar_triggers", grammar_triggers}, {"preserved_tokens", sampling.preserved_tokens}, {"chat_format", common_chat_format_name(oaicompat_chat_format)}, {"samplers", samplers}, @@ -357,24 +369,6 @@ struct server_task { } { - const auto grammar_triggers = data.find("grammar_triggers"); - if (grammar_triggers != data.end()) { - for (const auto & t : *grammar_triggers) { - common_grammar_trigger trigger; - trigger.word = t.at("word"); - trigger.at_start = t.at("at_start"); - - auto ids = common_tokenize(vocab, trigger.word, /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - SRV_DBG("Grammar trigger token: %d (`%s`)\n", ids[0], trigger.word.c_str()); - params.sampling.grammar_trigger_tokens.push_back(ids[0]); - params.sampling.preserved_tokens.insert(ids[0]); - continue; - } - SRV_DBG("Grammar trigger word: `%s`\n", trigger.word.c_str()); - params.sampling.grammar_trigger_words.push_back(trigger); - } - } const auto preserved_tokens = data.find("preserved_tokens"); if (preserved_tokens != data.end()) { for (const auto & t : *preserved_tokens) { @@ -384,12 +378,48 @@ struct server_task { params.sampling.preserved_tokens.insert(ids[0]); } else { // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. - SRV_WRN("Not preserved because more than 1 token (wrong chat template override?): %s\n", t.get().c_str()); + SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); + } + } + } + const auto grammar_triggers = data.find("grammar_triggers"); + if (grammar_triggers != data.end()) { + for (const auto & t : *grammar_triggers) { + auto type = static_cast(t.at("type")); + switch (type) { + case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: + { + const std::string & word = t.at("value"); + auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); + if (ids.size() == 1) { + auto token = ids[0]; + if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), token) == params.sampling.preserved_tokens.end()) { + throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); + } + SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); + params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, token}); + } else { + SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); + params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); + } + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START: + { + const std::string & pattern = t.at("value"); + params.sampling.grammar_triggers.push_back({type, pattern}); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: + throw std::runtime_error("Unespected token trigger"); + default: + throw std::runtime_error("Unknown trigger type"); } } } if (params.sampling.grammar_lazy) { - GGML_ASSERT(params.sampling.grammar_trigger_tokens.size() > 0 || params.sampling.grammar_trigger_words.size() > 0); + GGML_ASSERT(params.sampling.grammar_triggers.size() > 0); } } diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 6f8ab2b93aac7..12a67a54e32ae 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -614,8 +614,10 @@ static json oaicompat_completion_params_parse( auto grammar_triggers = json::array(); for (const auto & trigger : chat_params.grammar_triggers) { grammar_triggers.push_back({ - {"word", trigger.word}, - {"at_start", trigger.at_start}, + {"type", (int) trigger.type}, + {"value", trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN + ? json((int) std::get(trigger.value)) + : json(std::get(trigger.value))}, }); } llama_params["grammar_triggers"] = grammar_triggers; diff --git a/include/llama.h b/include/llama.h index b26f2b05e91f8..be538dd6297f4 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1210,11 +1210,25 @@ extern "C" { const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, - const char ** trigger_regexes, - size_t num_trigger_regexes, + const char ** trigger_words, + size_t num_trigger_words, const llama_token * trigger_tokens, size_t num_trigger_tokens); + + /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 + /// @param trigger_regexes A list of (full-string) regexes that will trigger the grammar sampler. Grammar sampler will be fed content starting from the first match group. + /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included. + LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_regexes, + size_t num_trigger_regexes, + const llama_token * trigger_tokens, + size_t num_trigger_tokens); + + /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first. LLAMA_API struct llama_sampler * llama_sampler_init_penalties( int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 06c3d188d4c2e..f20ec355ce4fa 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -969,7 +969,7 @@ struct llama_grammar * llama_grammar_init_impl( /* .awaiting_trigger = */ false, /* .trigger_buffer = */ "", /* .trigger_tokens = */ {}, - /* .trigger_regexes = */ {}, + /* .trigger_patterns = */ {}, }; } @@ -978,8 +978,8 @@ struct llama_grammar * llama_grammar_init_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_regexes, - size_t num_trigger_regexes, + const char ** trigger_patterns, + size_t num_trigger_patterns, const llama_token * trigger_tokens, size_t num_trigger_tokens) { llama_grammar_parser parser; @@ -1050,14 +1050,14 @@ struct llama_grammar * llama_grammar_init_impl( } while (true); std::vector vec_trigger_tokens; - std::vector> vec_trigger_regexes; + std::vector> vec_trigger_patterns; for (size_t i = 0; i < num_trigger_tokens; i++) { GGML_ASSERT(trigger_tokens != nullptr); vec_trigger_tokens.push_back(trigger_tokens[i]); } - for (size_t i = 0; i < num_trigger_regexes; i++) { - GGML_ASSERT(trigger_regexes != nullptr); - vec_trigger_regexes.emplace_back(trigger_regexes[i], trigger_regexes[i]); + for (size_t i = 0; i < num_trigger_patterns; i++) { + GGML_ASSERT(trigger_patterns != nullptr); + vec_trigger_patterns.emplace_back(trigger_patterns[i], trigger_patterns[i]); } // Important: vec_rules has to be moved here, not copied, because stacks contains @@ -1072,7 +1072,7 @@ struct llama_grammar * llama_grammar_init_impl( /* .awaiting_trigger = */ lazy, /* .trigger_buffer = */ "", std::move(vec_trigger_tokens), - std::move(vec_trigger_regexes), + std::move(vec_trigger_patterns), }; } @@ -1094,7 +1094,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.awaiting_trigger, grammar.trigger_buffer, grammar.trigger_tokens, - grammar.trigger_regexes, + grammar.trigger_patterns, }; // redirect elements in stacks to point to new rules @@ -1172,7 +1172,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token grammar.trigger_buffer += piece; std::smatch match; - for (const auto & [_, regex] : grammar.trigger_regexes) { + for (const auto & [_, regex] : grammar.trigger_patterns) { if (std::regex_match(grammar.trigger_buffer, match, regex)) { grammar.awaiting_trigger = false; // get from the first match to the end of the string diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 8d9b1a81dfd1c..a9b6f99ec34f5 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -124,7 +124,7 @@ struct llama_grammar { std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). std::vector> - trigger_regexes; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated + trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated // string, and the grammar will be given the string from the first match group onwards. }; @@ -145,8 +145,8 @@ struct llama_grammar * llama_grammar_init_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_regexes, - size_t num_trigger_regexes, + const char ** trigger_patterns, + size_t num_trigger_patterns, const llama_token * trigger_tokens, size_t num_trigger_tokens); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 7d5f9e86584c1..388998d949d29 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1446,10 +1446,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_regexes, - size_t num_trigger_regexes, + const char ** trigger_words, + size_t num_trigger_words, const llama_token * trigger_tokens, - size_t num_trigger_tokens); + size_t num_trigger_tokens, + const char ** trigger_patterns, + size_t num_trigger_patterns); static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_grammar *) smpl->ctx; @@ -1457,14 +1459,14 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { return; } - std::vector trigger_regexes_c; - trigger_regexes_c.reserve(ctx->grammar->trigger_regexes.size()); - for (auto & [pattern, _] : ctx->grammar->trigger_regexes) { - trigger_regexes_c.push_back(pattern.c_str()); + std::vector trigger_patterns_c; + trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size()); + for (auto & [pattern, _] : ctx->grammar->trigger_patterns) { + trigger_patterns_c.push_back(pattern.c_str()); } auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(), - ctx->grammar->lazy, trigger_regexes_c.data(), trigger_regexes_c.size(), + ctx->grammar->lazy, trigger_patterns_c.data(), trigger_patterns_c.size(), ctx->grammar->trigger_tokens.data(), ctx->grammar->trigger_tokens.size()); llama_grammar_free_impl(ctx->grammar); @@ -1474,7 +1476,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; - auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0); + auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0); // copy the state { @@ -1515,18 +1517,36 @@ static struct llama_sampler * llama_sampler_init_grammar_impl( const char * grammar_str, const char * grammar_root, bool lazy, - const char ** trigger_regexes, - size_t num_trigger_regexes, + const char ** trigger_words, + size_t num_trigger_words, const llama_token * trigger_tokens, - size_t num_trigger_tokens) { + size_t num_trigger_tokens, + const char ** trigger_patterns, + size_t num_trigger_patterns) { auto * ctx = new llama_sampler_grammar; if (grammar_str != nullptr && grammar_str[0] != '\0') { + // TODO: remove trigger_words support. + if (trigger_words != nullptr && num_trigger_words > 0) { + GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0); + std::string trigger_pattern("[\\s\\S]*?("); + for (size_t i = 0; i < num_trigger_words; ++i) { + static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); + if (i > 0) { + trigger_pattern += "|"; + } + trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0"); + } + trigger_pattern += ")[\\s\\S]*"; + auto trigger_pattern_c = trigger_pattern.c_str(); + trigger_patterns = &trigger_pattern_c; + num_trigger_patterns = 1; + } *ctx = { /* .vocab = */ vocab, /* .grammar_str = */ grammar_str, /* .grammar_root = */ grammar_root, - /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_regexes, num_trigger_regexes, trigger_tokens, num_trigger_tokens), + /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens), }; } else { *ctx = { @@ -1547,7 +1567,7 @@ struct llama_sampler * llama_sampler_init_grammar( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { - return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0); + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ false, nullptr, 0, nullptr, 0, nullptr, 0); } struct llama_sampler * llama_sampler_init_grammar_lazy( @@ -1558,7 +1578,18 @@ struct llama_sampler * llama_sampler_init_grammar_lazy( size_t num_trigger_words, const llama_token * trigger_tokens, size_t num_trigger_tokens) { - return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens); + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, trigger_words, num_trigger_words, trigger_tokens, num_trigger_tokens, nullptr, 0); +} + +struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( + const struct llama_vocab * vocab, + const char * grammar_str, + const char * grammar_root, + const char ** trigger_patterns, + size_t num_trigger_patterns, + const llama_token * trigger_tokens, + size_t num_trigger_tokens) { + return llama_sampler_init_grammar_impl(vocab, grammar_str, grammar_root, /* lazy= */ true, nullptr, 0, trigger_tokens, num_trigger_tokens, trigger_patterns, num_trigger_patterns); } // penalties diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 2b1226292177b..4fb8198374e71 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -237,12 +237,35 @@ static void test_templates(const struct common_chat_templates * tmpls, const std auto earliest_trigger_pos = std::string::npos; auto constrained = data.delta; for (const auto & trigger : data.params.grammar_triggers) { - auto pos = constrained.find(trigger.word); - if (pos == std::string::npos) { - continue; + size_t pos = std::string::npos; + std::smatch match; + switch (trigger.type) { + case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: + { + const auto & word = std::get(trigger.value); + pos = constrained.find(word); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: + { + const auto & pattern = std::get(trigger.value); + if (std::regex_search(constrained, match, std::regex(pattern))) { + pos = match.position(); + } + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START: + { + const auto & pattern = std::get(trigger.value); + if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) { + pos = 0; + } + break; + } + default: + throw std::runtime_error("Unknown trigger type"); } - if (pos > 0 && trigger.at_start) { - fprintf(stderr, "Trigger %s not at start of message, skipping:\n\n%s\n\n", trigger.word.c_str(), constrained.c_str()); + if (pos == std::string::npos) { continue; } if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) { @@ -260,7 +283,8 @@ static void test_templates(const struct common_chat_templates * tmpls, const std if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) { throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta + - "\n\nGrammar: " + data.params.grammar); + "\n\nConstrained: " + constrained + + "\n\nGrammar: " + data.params.grammar); } } } @@ -696,6 +720,13 @@ static void test_template_output_parsers() { " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" "```", COMMON_CHAT_FORMAT_HERMES_2_PRO)); + assert_msg_equals(message_assist_call, common_chat_parse( + "```json\n" + "\n" + " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n" + " \n" + "``` ", + COMMON_CHAT_FORMAT_HERMES_2_PRO)); assert_msg_equals(message_assist_call, common_chat_parse( "\n" " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" @@ -869,7 +900,7 @@ static void test_template_output_parsers() { } int main(int argc, char ** argv) { - try { + // try { #ifndef _WIN32 if (argc > 1) { common_chat_templates_inputs inputs; @@ -907,8 +938,8 @@ int main(int argc, char ** argv) { std::cout << "\n[chat] All tests passed!" << '\n'; } return 0; - } catch (const std::exception & e) { - std::cerr << "Error: " << e.what() << '\n'; - return 1; - } + // } catch (const std::exception & e) { + // std::cerr << "Error: " << e.what() << '\n'; + // return 1; + // } } From 2d882d2eff56bd58cb27363158ce6417d2f29fb9 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 21:38:49 +0000 Subject: [PATCH 12/43] deprecate llama_sampler_init_grammar_lazy --- include/llama.h | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/include/llama.h b/include/llama.h index be538dd6297f4..0b6522923ffeb 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1203,28 +1203,26 @@ extern "C" { const char * grammar_str, const char * grammar_root); - /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 - /// @param trigger_regexes A list of (full-string) regexes that will trigger the grammar sampler. Grammar sampler will be fed content starting from the first match group. - /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included. - LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( + DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, const char ** trigger_words, size_t num_trigger_words, const llama_token * trigger_tokens, - size_t num_trigger_tokens); + size_t num_trigger_tokens), + "use llama_sampler_init_grammar_lazy_patterns instead"); /// @details Lazy grammar sampler, introduced in https://github.com/ggml-org/llama.cpp/pull/9639 - /// @param trigger_regexes A list of (full-string) regexes that will trigger the grammar sampler. Grammar sampler will be fed content starting from the first match group. + /// @param trigger_patterns A list of patterns that will trigger the grammar sampler. Pattern will be matched from the start of the generation output, and grammar sampler will be fed content starting from its first match group. /// @param trigger_tokens A list of tokens that will trigger the grammar sampler. Grammar sampler will be fed content starting from the trigger token included. LLAMA_API struct llama_sampler * llama_sampler_init_grammar_lazy_patterns( const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root, - const char ** trigger_regexes, - size_t num_trigger_regexes, + const char ** trigger_patterns, + size_t num_trigger_patterns, const llama_token * trigger_tokens, size_t num_trigger_tokens); From 7e562f80cebc473ba5358d04388d06477408232a Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 21:39:02 +0000 Subject: [PATCH 13/43] fix tool_bench.py imports --- scripts/tool_bench.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index 826112ff54588..a199a41a28885 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -15,32 +15,36 @@ cmake --build build -j && ( \ export LLAMA_CACHE=$HOME/Library/Caches/llama.cpp ; export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server ; - export ARGS=( --n=10 --temps=0,0.5,0.75,1,1.5,2,5, --append=all.jsonl ) ; - ./scripts/tool_bench.py ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 1.5B Q4_K_M" --hf bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF --ollama deepseek-r1:1.5b ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M" --hf bartowski/Qwen2.5-7B-Instruct-GGUF ; + export ARGS=( --n=10 --temps=0,0.5,0.75,1,1.5,2,5, --output=../new.jsonl ) ; ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --hf bartowski/Llama-3.2-1B-Instruct-GGUF --ollama llama3.2:1b-instruct-q4_K_M ; ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --hf bartowski/Llama-3.2-3B-Instruct-GGUF --ollama llama3.1:3b ; ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M" --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF --ollama llama3.1:8b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ; + ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M" --hf bartowski/Qwen2.5-7B-Instruct-GGUF ; ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.3 Instruct 70B Q4_K_M" --hf bartowski/Llama-3.3-70B-Instruct-GGUF ; ./scripts/tool_bench.py ${ARGS[@]} --model "Mistral Nemo 2407 Q4_K_M" --hf bartowski/Mistral-Nemo-Instruct-2407-GGUF --ollama mistral-nemo:12b ; ./scripts/tool_bench.py ${ARGS[@]} --model "Functionary Small v3.2 Q4_K_M" --hf bartowski/functionary-small-v3.2-GGUF ; + ./scripts/tool_bench.py ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 1.5B Q4_K_M" --hf bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF --ollama deepseek-r1:1.5b ; ) ''' import argparse from contextlib import contextmanager from statistics import mean, median +import subprocess import pytest # ensure grandparent path is in sys.path from pathlib import Path import sys +import time + sys.path.insert(0, Path(__file__).parent.parent.as_posix()) print(sys.path) -from examples.server.tests.unit.test_tool_call import * +from examples.server.tests.utils import ServerProcess +from examples.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather @contextmanager From ea588ce5e8d71b76d342a80cd2061787a35bd4bf Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 21:57:43 +0000 Subject: [PATCH 14/43] Update tool_bench.py --- scripts/tool_bench.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index a199a41a28885..2edac6952d12e 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -13,6 +13,7 @@ # /// ''' cmake --build build -j && ( \ + export RETRIES=3 ; export LLAMA_CACHE=$HOME/Library/Caches/llama.cpp ; export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server ; export ARGS=( --n=10 --temps=0,0.5,0.75,1,1.5,2,5, --output=../new.jsonl ) ; @@ -33,7 +34,7 @@ from contextlib import contextmanager from statistics import mean, median import subprocess -import pytest +import json # ensure grandparent path is in sys.path from pathlib import Path @@ -73,7 +74,7 @@ def stop(): parser.add_argument('--hf', type=str, help='GGUF huggingface model repo id (+ optional quant) to test w/ llama-server') parser.add_argument('--chat-template', type=str, help='Chat template override for llama-server') parser.add_argument('--ollama', type=str, help='Ollama model tag to test') - parser.add_argument('--n', type=int, help='Number of times to run each test', default=30) + parser.add_argument('--n', type=int, help='Number of times to run each test', default=10) parser.add_argument('--temps', type=str, help='Comma-separated list of temperatures') parser.add_argument('--top-p', type=float, help='top_p') parser.add_argument('--top-k', type=int, help='top_k') @@ -172,7 +173,7 @@ def elapsed(): with scoped_server(server): server.start(timeout_seconds=TIMEOUT_SERVER_START) - for ignore_chat_grammar in [False, True]: + for ignore_chat_grammar in [False]: run( server, implementation="llama-server" + (" (no grammar)" if ignore_chat_grammar else ""), From bbb2af734a58dc42e10a7b2ce1c135ff019a2e07 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 22:39:23 +0000 Subject: [PATCH 15/43] update tool_Bench --- scripts/plot_tool_call_tests.py | 190 ----------------------- scripts/tool_bench.py | 258 ++++++++++++++++++++++++-------- 2 files changed, 193 insertions(+), 255 deletions(-) delete mode 100644 scripts/plot_tool_call_tests.py diff --git a/scripts/plot_tool_call_tests.py b/scripts/plot_tool_call_tests.py deleted file mode 100644 index b54aecce526dd..0000000000000 --- a/scripts/plot_tool_call_tests.py +++ /dev/null @@ -1,190 +0,0 @@ -#!/usr/bin/env python3 -""" -Model Performance Analysis and Visualization Tool - -This script analyzes JSON performance data for different model implementations and tests, -creating a heatmap visualization of success ratios. It handles multiple input files and -supports various model configurations. - -Usage: - python script.py input_file1.json [input_file2.json ...] -""" - -import json -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns -import sys -from typing import Dict, List, Tuple, Set, Any -from pathlib import Path -import logging - -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s' -) -logger = logging.getLogger(__name__) - -class ModelAnalyzer: - def __init__(self): - self.lines: List[Dict] = [] - self.data_dict: Dict[Tuple, float] = {} - self.models: List[str] = [] - self.temps: Set[float] = set() - self.tests: Set[str] = set() - self.impls: Set[str] = set() - - self.column_groups = [ - ("llama-server", []), # Tests will be populated dynamically - ("llama-server (no grammar)", []), - ("ollama", []) - ] - - def read_files(self, files: List[str]) -> None: - """Read and parse JSON data from input files.""" - for file in files: - path = Path(file) - if not path.exists(): - logger.error(f"File not found: {file}") - continue - - try: - with path.open() as f: - raw_data = f.read() - logger.info(f"Reading {file} ({len(raw_data)} bytes)") - - for line_num, line in enumerate(raw_data.split('\n'), 1): - line = line.strip() - if not line: - continue - try: - record = json.loads(line) - self.lines.append(record) - except json.JSONDecodeError as e: - logger.warning(f"Invalid JSON at {file}:{line_num} - {e}") - except Exception as e: - logger.error(f"Error processing {file}: {e}") - - def process_data(self) -> None: - """Process the loaded data and organize it for visualization.""" - for rec in self.lines: - try: - model = rec["model"] - temp = rec["temp"] - impl = rec["implementation"] - test = rec["test"] - success = rec["success_ratio"] - - self.data_dict[(model, temp, impl, test)] = success - - if model not in self.models: - self.models.append(model) - self.temps.add(temp) - self.tests.add(test) - self.impls.add(impl) - - except KeyError as e: - logger.warning(f"Missing required field in record: {e}") - - # Sort the collected values - self.temps = sorted(list(self.temps), key=lambda x: x if x is not None else -1) - self.tests = sorted(list(self.tests)) - - # Update column groups with actual tests - self.column_groups = [ - (impl, list(self.tests)) for impl, _ in self.column_groups - if impl in self.impls - ] - - def create_matrix(self) -> pd.DataFrame: - """Create a matrix for visualization.""" - all_cols = [ - (impl, test) - for impl, tests in self.column_groups - for test in tests - ] - - matrix = [] - index = [] - - for model in self.models: - for temp in self.temps: - index.append(f"{model} @ {temp}") - row_vals = [ - self.data_dict.get((model, temp, impl, test), np.nan) - for impl, test in all_cols - ] - matrix.append(row_vals) - - # Create column labels - col_labels = [f"{impl}\n({test})" for impl, test in all_cols] - - return pd.DataFrame(matrix, index=index, columns=col_labels) - - def plot_heatmap(self, df: pd.DataFrame, output_file: str = None) -> None: - """Create and display/save the heatmap visualization.""" - plt.figure(figsize=(12, 6)) - - sns.heatmap( - df, - annot=True, - cmap="RdYlGn", - vmin=0.0, - vmax=1.0, - cbar=True, - fmt=".2f", - center=0.5, - square=True, - linewidths=0.5, - cbar_kws={"label": "Success Ratio"} - ) - - plt.title("Model Performance Analysis\nSuccess Ratios by Implementation & Test", - pad=20) - plt.xlabel("Implementation and Test", labelpad=10) - plt.ylabel("Model @ Temperature", labelpad=10) - - plt.xticks(rotation=45, ha='right') - plt.yticks(rotation=0) - - plt.tight_layout() - - if output_file: - plt.savefig(output_file, dpi=300, bbox_inches='tight') - logger.info(f"Plot saved to {output_file}") - else: - plt.show() - -def main(): - if len(sys.argv) < 2: - logger.error("Please provide at least one input file") - sys.exit(1) - - analyzer = ModelAnalyzer() - - # Process input files - analyzer.read_files(sys.argv[1:]) - - if not analyzer.lines: - logger.error("No valid data was loaded") - sys.exit(1) - - # Process the data - analyzer.process_data() - - # Log summary statistics - logger.info(f"Processed {len(analyzer.lines)} lines") - logger.info(f"Found {len(analyzer.data_dict)} valid data points") - logger.info(f"Models: {analyzer.models}") - logger.info(f"Temperatures: {analyzer.temps}") - logger.info(f"Tests: {analyzer.tests}") - logger.info(f"Implementations: {analyzer.impls}") - - # Create and plot the visualization - df = analyzer.create_matrix() - # analyzer.plot_heatmap(df, "model_analysis.png") - analyzer.plot_heatmap(df) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index 2edac6952d12e..8dded389ef6dc 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -9,6 +9,7 @@ # "seaborn", # "requests", # "wget", +# "typer", # ] # /// ''' @@ -16,32 +17,39 @@ export RETRIES=3 ; export LLAMA_CACHE=$HOME/Library/Caches/llama.cpp ; export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server ; - export ARGS=( --n=10 --temps=0,0.5,0.75,1,1.5,2,5, --output=../new.jsonl ) ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --hf bartowski/Llama-3.2-1B-Instruct-GGUF --ollama llama3.2:1b-instruct-q4_K_M ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --hf bartowski/Llama-3.2-3B-Instruct-GGUF --ollama llama3.1:3b ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M" --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF --ollama llama3.1:8b ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M" --hf bartowski/Qwen2.5-7B-Instruct-GGUF ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Llama 3.3 Instruct 70B Q4_K_M" --hf bartowski/Llama-3.3-70B-Instruct-GGUF ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Mistral Nemo 2407 Q4_K_M" --hf bartowski/Mistral-Nemo-Instruct-2407-GGUF --ollama mistral-nemo:12b ; - ./scripts/tool_bench.py ${ARGS[@]} --model "Functionary Small v3.2 Q4_K_M" --hf bartowski/functionary-small-v3.2-GGUF ; - ./scripts/tool_bench.py ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 1.5B Q4_K_M" --hf bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF --ollama deepseek-r1:1.5b ; + export ARGS=( --n 10 --temp -1 --temp 0 --temp 0.5 --temp 0.75 --temp 1 --temp 1.5 --temp 2 --temp 5 --output ../qw.jsonl ) ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --hf bartowski/Llama-3.2-1B-Instruct-GGUF --ollama llama3.2:1b-instruct-q4_K_M ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --hf bartowski/Llama-3.2-3B-Instruct-GGUF --ollama llama3.1:3b ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M" --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF --ollama llama3.1:8b ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M" --hf bartowski/Qwen2.5-7B-Instruct-GGUF ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.3 Instruct 70B Q4_K_M" --hf bartowski/Llama-3.3-70B-Instruct-GGUF ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Mistral Nemo 2407 Q4_K_M" --hf bartowski/Mistral-Nemo-Instruct-2407-GGUF --ollama mistral-nemo:12b ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Functionary Small v3.2 Q4_K_M" --hf bartowski/functionary-small-v3.2-GGUF ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 1.5B Q4_K_M" --hf bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF --ollama deepseek-r1:1.5b ; ) + + ./scripts/tool_bench.py plot ../qw.jsonl --output ../qw.png ''' -import argparse + from contextlib import contextmanager +from pathlib import Path +from pathlib import Path from statistics import mean, median -import subprocess +from typing import Annotated, List, Optional +from typing import Dict, List, Tuple, Set, Any import json - -# ensure grandparent path is in sys.path -from pathlib import Path +import logging +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import subprocess import sys import time - sys.path.insert(0, Path(__file__).parent.parent.as_posix()) print(sys.path) from examples.server.tests.utils import ServerProcess @@ -68,44 +76,162 @@ def stop(): stop() -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Run tests for the chat server.') - parser.add_argument('--model', type=str, help='Name of the model to test (implementation agnostic)', required=True) - parser.add_argument('--hf', type=str, help='GGUF huggingface model repo id (+ optional quant) to test w/ llama-server') - parser.add_argument('--chat-template', type=str, help='Chat template override for llama-server') - parser.add_argument('--ollama', type=str, help='Ollama model tag to test') - parser.add_argument('--n', type=int, help='Number of times to run each test', default=10) - parser.add_argument('--temps', type=str, help='Comma-separated list of temperatures') - parser.add_argument('--top-p', type=float, help='top_p') - parser.add_argument('--top-k', type=int, help='top_k') - parser.add_argument('--seed', type=int, help='Random seed') - parser.add_argument('--port', type=int, help='llama-server port') - parser.add_argument('--output', type=str, help='Output JSON file') - parser.add_argument('--append', type=str, help='Output JSON file') - - - args = parser.parse_args() +import typer - # Check only one of output and append - assert (args.output is None) != (args.append is None), "Exactly one of --output and --append must be specified" +app = typer.Typer() + + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) - # chat_template = args.chat_template - n = args.n +@app.command() +def plot(files: List[Path], output: Optional[Path] = None): + + lines: List[Dict] = [] + for file in files: + if not file.exists(): + logger.error(f"File not found: {file}") + continue + + try: + with file.open() as f: + raw_data = f.read() + logger.info(f"Reading {file} ({len(raw_data)} bytes)") + + for line_num, line in enumerate(raw_data.split('\n'), 1): + line = line.strip() + if not line: + continue + try: + record = json.loads(line) + lines.append(record) + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON at {file}:{line_num} - {e}") + except Exception as e: + logger.error(f"Error processing {file}: {e}") + if not lines: + raise Exception("No valid data was loaded") + + data_dict: Dict[Tuple, float] = {} + models: List[str] = [] + temps = set() + tests = set() + impls = set() + for rec in lines: + try: + model = rec["model"] + temp = rec["temp"] + impl = rec["implementation"] + test = rec["test"] + success = rec["success_ratio"] + + + data_dict[(model, temp, impl, test)] = success + + if model not in models: + models.append(model) + temps.add(temp) + tests.add(test) + impls.add(impl) + + except KeyError as e: + logger.warning(f"Missing required field in record: {e}") + + # Sort the collected values + temps = list(sorted(temps, key=lambda x: x if x is not None else -1)) + tests = list(sorted(tests)) + impls = list(sorted(impls)) + + + logger.info(f"Processed {len(lines)} lines") + logger.info(f"Found {len(data_dict)} valid data points") + logger.info(f"Models: {models}") + logger.info(f"Temperatures: {temps}") + logger.info(f"Tests: {tests}") + logger.info(f"Implementations: {impls}") + + + matrix = [] + index = [] + + all_cols = [ + (impl, test) + for impl in impls + for test in tests + ] + for model in models: + for temp in temps: + index.append(f"{model} @ {temp}") + row_vals = [ + data_dict.get((model, temp, impl, test), np.nan) + for impl, test in all_cols + ] + matrix.append(row_vals) + + columns = [f"{impl}\n({test})" for impl, test in all_cols] + + df = pd.DataFrame(matrix, index=index, columns=columns) + + plt.figure(figsize=(12, 6)) + + sns.heatmap( + df, annot=True, cmap="RdYlGn", vmin=0.0, vmax=1.0, cbar=True, fmt=".2f", center=0.5, square=True, linewidths=0.5, + cbar_kws={"label": "Success Ratio"}, + ) + + plt.title("Tool Call Bench\nSuccess Ratios by Implementation & Test", pad=20) + plt.xlabel("Implementation and Test", labelpad=10) + plt.ylabel("Model @ Temperature", labelpad=10) + + plt.xticks(rotation=45, ha='right') + plt.yticks(rotation=0) + + plt.tight_layout() + + if output: + plt.savefig(output, dpi=300, bbox_inches='tight') + logger.info(f"Plot saved to {output}") + else: + plt.show() + +@app.command() +def run( + output: Annotated[Path, typer.Option(help="Output JSON file")], + model: Annotated[Optional[str], typer.Option(help="Name of the model to test (implementation agnostic)")] = None, + hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None, + chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None, + ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None, + n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10, + temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test")] = None, + top_p: Annotated[Optional[float], typer.Option(help="top_p")] = None, + top_k: Annotated[Optional[int], typer.Option(help="top_k")] = None, + seed: Annotated[Optional[int], typer.Option(help="Random seed")] = None, + port: Annotated[int, typer.Option(help="llama-server port")] = 8084, + force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False, + append: Annotated[bool, typer.Option(help="Append to output file")] = False, +): + # Check only one of output and append + n_predict = 512 - with open(args.output or args.append, 'w' if args.output else 'a') as output_file: + assert force or not output.exists(), f"Output file already exists: {output}; use --force to overwrite" + + with output.open('a' if append else 'w') as output_file: def run(server: ServerProcess, *, implementation: str, model_id: str, temp: float | None = None, output_kwargs={}, request_kwargs={}): request_kwargs = {**request_kwargs} if temp is not None: request_kwargs['temperature'] = temp - if args.top_p is not None: - request_kwargs['top_p'] = args.top_p - if args.top_k is not None: - request_kwargs['top_k'] = args.top_k - if args.seed is not None: - request_kwargs['seed'] = args.seed + if top_p is not None: + request_kwargs['top_p'] = top_p + if top_k is not None: + request_kwargs['top_k'] = top_k + if seed is not None: + request_kwargs['seed'] = seed request_kwargs['cache_prompt'] = False @@ -120,7 +246,7 @@ def run(server: ServerProcess, *, implementation: str, model_id: str, temp: floa failures = [] success_times = [] failure_times = [] - print(f"Running {test_name} ({implementation}, {args.model}): ", file=sys.stderr, flush=True) + print(f"Running {test_name} ({implementation}, {model}): ", file=sys.stderr, flush=True) for i in range(n): start_time = time.time() def elapsed(): @@ -139,13 +265,13 @@ def elapsed(): failures.append(str(e)) print('\n', file=sys.stderr, flush=True) output_file.write(json.dumps({**output_kwargs, **dict( - model=args.model, + model=model, implementation=implementation, model_id=model_id, test=test_name, - temp=temp, - top_p=args.top_p, - top_k=args.top_k, + temp=t, + top_p=top_p, + top_k=top_k, success_ratio=float(success_count) / n, avg_time=mean(success_times + failure_times), median_time=median(success_times + failure_times), @@ -157,18 +283,17 @@ def elapsed(): )}) + '\n') output_file.flush() - temps = [float(temp) if temp != "" else None for temp in args.temps.split(',')] if args.temps is not None else [None] - for temp in temps: - if args.hf is not None: + for t in [None] if temp is None else [t if t >= 0 else None for t in temp]: + if hf is not None: server = ServerProcess() server.n_slots = 1 server.jinja = True server.n_predict = 512 # High because of DeepSeek R1 - server.model_hf_repo = args.hf + server.model_hf_repo = hf server.model_hf_file = None - server.chat_template = args.chat_template - if args.port is not None: - server.server_port = args.port + server.chat_template = chat_template + if port is not None: + server.server_port = port # server.debug = True with scoped_server(server): @@ -177,32 +302,35 @@ def elapsed(): run( server, implementation="llama-server" + (" (no grammar)" if ignore_chat_grammar else ""), - model_id=args.hf, - temp=temp, + model_id=hf, + temp=t, output_kwargs=dict( - chat_template=args.chat_template, + chat_template=chat_template, ), request_kwargs=dict( ignore_chat_grammar=ignore_chat_grammar, ), ) - if args.ollama is not None: + if ollama is not None: server = ServerProcess() server.server_port = 11434 server.server_host = "localhost" - subprocess.check_call(["ollama", "pull", args.ollama]) + subprocess.check_call(["ollama", "pull", ollama]) with scoped_server(server): run( server, implementation="ollama", - model_id=args.ollama, - temp=temp, + model_id=ollama, + temp=t, output_kwargs=dict( chat_template=None, ), request_kwargs=dict( - model=args.ollama, + model=ollama, ), ) + +if __name__ == "__main__": + app() \ No newline at end of file From 520b6237c9b833000e850dae397a4fe66ed1189e Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 22:46:43 +0000 Subject: [PATCH 16/43] Update tool_bench.py --- scripts/tool_bench.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index 8dded389ef6dc..3cad333a89323 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -17,17 +17,17 @@ export RETRIES=3 ; export LLAMA_CACHE=$HOME/Library/Caches/llama.cpp ; export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server ; - export ARGS=( --n 10 --temp -1 --temp 0 --temp 0.5 --temp 0.75 --temp 1 --temp 1.5 --temp 2 --temp 5 --output ../qw.jsonl ) ; - ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ; - ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --hf bartowski/Llama-3.2-1B-Instruct-GGUF --ollama llama3.2:1b-instruct-q4_K_M ; - ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --hf bartowski/Llama-3.2-3B-Instruct-GGUF --ollama llama3.1:3b ; - ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M" --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF --ollama llama3.1:8b ; - ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ; - ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M" --hf bartowski/Qwen2.5-7B-Instruct-GGUF ; - ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.3 Instruct 70B Q4_K_M" --hf bartowski/Llama-3.3-70B-Instruct-GGUF ; - ./scripts/tool_bench.py run ${ARGS[@]} --model "Mistral Nemo 2407 Q4_K_M" --hf bartowski/Mistral-Nemo-Instruct-2407-GGUF --ollama mistral-nemo:12b ; - ./scripts/tool_bench.py run ${ARGS[@]} --model "Functionary Small v3.2 Q4_K_M" --hf bartowski/functionary-small-v3.2-GGUF ; - ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 1.5B Q4_K_M" --hf bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF --ollama deepseek-r1:1.5b ; + export ARGS=( --n 10 --temp -1 --temp 0 --temp 0.5 --temp 0.75 --temp 1 --temp 1.5 --temp 2 --temp 5 ) ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --output llama1b.jsonl --hf bartowski/Llama-3.2-1B-Instruct-GGUF --ollama llama3.2:1b-instruct-q4_K_M ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --output llama3b.jsonl --hf bartowski/Llama-3.2-3B-Instruct-GGUF --ollama llama3.1:3b ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M" --output llama8b.jsonl --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF --ollama llama3.1:8b ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M" --output qwen7b.jsonl --hf bartowski/Qwen2.5-7B-Instruct-GGUF ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.3 Instruct 70B Q4_K_M" --output llama70b.jsonl --hf bartowski/Llama-3.3-70B-Instruct-GGUF ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Mistral Nemo 2407 Q4_K_M" --output nemo.jsonl --hf bartowski/Mistral-Nemo-Instruct-2407-GGUF --ollama mistral-nemo:12b ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Functionary Small v3.2 Q4_K_M" --output funcsmall.jsonl --hf bartowski/functionary-small-v3.2-GGUF ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 1.5B Q4_K_M" --output dsqw1.5b.jsonl --hf bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF --ollama deepseek-r1:1.5b ; ) ./scripts/tool_bench.py plot ../qw.jsonl --output ../qw.png From 9276a53835e5b09622b97b7a55079e19f99e3500 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 23:26:40 +0000 Subject: [PATCH 17/43] Update tool_bench.py --- scripts/tool_bench.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index 3cad333a89323..17bcdcb41b124 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -30,7 +30,12 @@ ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 1.5B Q4_K_M" --output dsqw1.5b.jsonl --hf bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF --ollama deepseek-r1:1.5b ; ) - ./scripts/tool_bench.py plot ../qw.jsonl --output ../qw.png + ./scripts/tool_bench.py plot qwen1.5b.jsonl + ./scripts/tool_bench.py plot *.jsonl --output all.png + + for f in *.jsonl; do + ./scripts/tool_bench.py plot $f --output ${f%.jsonl}.png + done ''' From e2641de5dc88aa5caf1e360e1de2cba3c07cceb8 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 23:27:08 +0000 Subject: [PATCH 18/43] disable python code constraints (causes hangs) --- common/chat.cpp | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index fd3552c91ce57..a71628fa3df7f 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1002,9 +1002,10 @@ static std::string add_python_code_arguments_rule(const std::string & name, cons } static std::string add_json_tool_args_rule(const std::string & name, const json & parameters, const common_grammar_builder & builder) { - if (name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { - return add_python_code_arguments_rule(name + "-code-args", builder); - } else { + // if (name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { + // return add_python_code_arguments_rule(name + "-code-args", builder); + // } else + { return builder.add_schema(name + "-args", parameters); } } @@ -1031,9 +1032,10 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com std::vector kvs; for (const auto & [key, value] : parameters.at("properties").items()) { - if (name == "python" && key == "code") { - kvs.push_back("\"" + key + "=\\\"\" " + add_escaped_python_code_soup_rule(builder) + " \"\\\"\""); // NOLINT - } else { + // if (name == "python" && key == "code") { + // kvs.push_back("\"" + key + "=\\\"\" " + add_escaped_python_code_soup_rule(builder) + " \"\\\"\""); // NOLINT + // } else + { kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT } } @@ -1436,13 +1438,14 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat std::string name = function.at("name"); auto parameters = function.at("parameters"); builder.resolve_refs(parameters); - if (name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { - tool_rules.push_back(builder.add_rule(name + "-call", - "\"{\" space " - "\"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space " - "\"\\\"arguments\\\"\" space \":\" space " + add_python_code_arguments_rule(name + "-code-arguments", builder) + " " - "\"}\" space ")); - } else { + // if (name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { + // tool_rules.push_back(builder.add_rule(name + "-call", + // "\"{\" space " + // "\"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space " + // "\"\\\"arguments\\\"\" space \":\" space " + add_python_code_arguments_rule(name + "-code-arguments", builder) + " " + // "\"}\" space ")); + // } else + { tool_rules.push_back(builder.add_schema(name + "-call", { {"type", "object"}, {"properties", json { From 03fe156a4ea5b39884250e06fbabb74b69413778 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 23:29:22 +0000 Subject: [PATCH 19/43] Update fetch_server_test_models.py --- scripts/fetch_server_test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fetch_server_test_models.py b/scripts/fetch_server_test_models.py index 05690b1385468..e6775bfc5867c 100755 --- a/scripts/fetch_server_test_models.py +++ b/scripts/fetch_server_test_models.py @@ -75,7 +75,7 @@ def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, N logging.info(f' - {m.hf_repo} / {m.hf_file}') cli_path = os.environ.get( - 'LLAMA_SERVER_BIN_PATH', + 'LLAMA_CLI_BIN_PATH', os.path.join( os.path.dirname(__file__), '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli')) From b9f61203f3ead6ab352329c4ae6bcad4b7b58901 Mon Sep 17 00:00:00 2001 From: ochafik Date: Wed, 19 Feb 2025 23:54:35 +0000 Subject: [PATCH 20/43] Update tool_bench.py --- scripts/tool_bench.py | 63 +++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index 17bcdcb41b124..42a9095ffcde8 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -17,7 +17,7 @@ export RETRIES=3 ; export LLAMA_CACHE=$HOME/Library/Caches/llama.cpp ; export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server ; - export ARGS=( --n 10 --temp -1 --temp 0 --temp 0.5 --temp 0.75 --temp 1 --temp 1.5 --temp 2 --temp 5 ) ; + export ARGS=( --n 30 --temp -1 --temp 0 --temp 0.5 --temp 0.75 --temp 1 --temp 1.5 --temp 2 --temp 5 ) ; ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ; ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --output llama1b.jsonl --hf bartowski/Llama-3.2-1B-Instruct-GGUF --ollama llama3.2:1b-instruct-q4_K_M ; ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --output llama3b.jsonl --hf bartowski/Llama-3.2-3B-Instruct-GGUF --ollama llama3.1:3b ; @@ -29,10 +29,10 @@ ./scripts/tool_bench.py run ${ARGS[@]} --model "Functionary Small v3.2 Q4_K_M" --output funcsmall.jsonl --hf bartowski/functionary-small-v3.2-GGUF ; ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 1.5B Q4_K_M" --output dsqw1.5b.jsonl --hf bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF --ollama deepseek-r1:1.5b ; ) - + ./scripts/tool_bench.py plot qwen1.5b.jsonl ./scripts/tool_bench.py plot *.jsonl --output all.png - + for f in *.jsonl; do ./scripts/tool_bench.py plot $f --output ${f%.jsonl}.png done @@ -45,6 +45,7 @@ from statistics import mean, median from typing import Annotated, List, Optional from typing import Dict, List, Tuple, Set, Any +import atexit import json import logging import matplotlib.pyplot as plt @@ -54,6 +55,7 @@ import subprocess import sys import time +import typer sys.path.insert(0, Path(__file__).parent.parent.as_posix()) print(sys.path) @@ -63,49 +65,38 @@ @contextmanager def scoped_server(sp: ServerProcess): - global server - server = sp - - import atexit def stop(): - global server nonlocal sp if sp is not None: sp.stop() sp = None # type: ignore - server = None # type: ignore atexit.register(stop) - yield sp - stop() -import typer - -app = typer.Typer() - - logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) +app = typer.Typer() + @app.command() def plot(files: List[Path], output: Optional[Path] = None): - + lines: List[Dict] = [] for file in files: if not file.exists(): logger.error(f"File not found: {file}") continue - + try: with file.open() as f: raw_data = f.read() logger.info(f"Reading {file} ({len(raw_data)} bytes)") - + for line_num, line in enumerate(raw_data.split('\n'), 1): line = line.strip() if not line: @@ -120,7 +111,7 @@ def plot(files: List[Path], output: Optional[Path] = None): if not lines: raise Exception("No valid data was loaded") - + data_dict: Dict[Tuple, float] = {} models: List[str] = [] temps = set() @@ -133,16 +124,16 @@ def plot(files: List[Path], output: Optional[Path] = None): impl = rec["implementation"] test = rec["test"] success = rec["success_ratio"] - - + + data_dict[(model, temp, impl, test)] = success - + if model not in models: models.append(model) temps.add(temp) tests.add(test) impls.add(impl) - + except KeyError as e: logger.warning(f"Missing required field in record: {e}") @@ -151,7 +142,7 @@ def plot(files: List[Path], output: Optional[Path] = None): tests = list(sorted(tests)) impls = list(sorted(impls)) - + logger.info(f"Processed {len(lines)} lines") logger.info(f"Found {len(data_dict)} valid data points") logger.info(f"Models: {models}") @@ -159,10 +150,10 @@ def plot(files: List[Path], output: Optional[Path] = None): logger.info(f"Tests: {tests}") logger.info(f"Implementations: {impls}") - + matrix = [] index = [] - + all_cols = [ (impl, test) for impl in impls @@ -176,33 +167,33 @@ def plot(files: List[Path], output: Optional[Path] = None): for impl, test in all_cols ] matrix.append(row_vals) - + columns = [f"{impl}\n({test})" for impl, test in all_cols] df = pd.DataFrame(matrix, index=index, columns=columns) plt.figure(figsize=(12, 6)) - + sns.heatmap( df, annot=True, cmap="RdYlGn", vmin=0.0, vmax=1.0, cbar=True, fmt=".2f", center=0.5, square=True, linewidths=0.5, cbar_kws={"label": "Success Ratio"}, ) - + plt.title("Tool Call Bench\nSuccess Ratios by Implementation & Test", pad=20) plt.xlabel("Implementation and Test", labelpad=10) plt.ylabel("Model @ Temperature", labelpad=10) - + plt.xticks(rotation=45, ha='right') plt.yticks(rotation=0) - + plt.tight_layout() - + if output: plt.savefig(output, dpi=300, bbox_inches='tight') logger.info(f"Plot saved to {output}") else: plt.show() - + @app.command() def run( output: Annotated[Path, typer.Option(help="Output JSON file")], @@ -220,11 +211,11 @@ def run( append: Annotated[bool, typer.Option(help="Append to output file")] = False, ): # Check only one of output and append - + n_predict = 512 assert force or not output.exists(), f"Output file already exists: {output}; use --force to overwrite" - + with output.open('a' if append else 'w') as output_file: def run(server: ServerProcess, *, implementation: str, model_id: str, temp: float | None = None, output_kwargs={}, request_kwargs={}): From 52fe049aa2837c3adafd59427f09529e8b3a638c Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 20 Feb 2025 00:12:04 +0000 Subject: [PATCH 21/43] allow f-strings in pseudo python soups --- common/chat.cpp | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index a71628fa3df7f..70c54288d6259 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -938,6 +938,8 @@ static void expect_tool_parameters(const std::string & name, const json & parame } } +const bool constrain_python_tool_code = !getenv("CONSTRAIN_PYTHON_TOOL_CODE") || std::string(getenv("CONSTRAIN_PYTHON_TOOL_CODE")) != "0"; + /* Adds a GBNF rule that matches a Python code string when escaped inside a JSON string (without surrounding double quotes) @@ -956,7 +958,6 @@ static void expect_tool_parameters(const std::string & name, const json & parame This should prevent an entire class of invalid Python programs to be generated by the model, but any bugs / omissions may also disallow some valid Python syntax. Current limitations: - - No f strings - No multiline strings Examples: @@ -974,10 +975,11 @@ static std::string add_escaped_python_code_soup_rule(const common_grammar_builde return builder.add_rule("json-escaped-code-soup", // Allow comments w/ (escaped) newline R"( ( [#] ( ( [^\\\t\r\n\uff00-\uffef] | [\\] [^n\n] )* [\\] [n] )? | )" - // Allow (escaped) double quoted strings and their nested (double) escapes - R"( [\\] ["] ( [^"\\\t\r\n\uff00-\uffef] | [\\] [\\] ["] | [\\] [trnu] )* [\\] ["] | )" - // Allow single quoted strings and their nested (double) escapes - R"( ['] ( [^"'\\\t\r\n\uff00-\uffef] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] )* ['] | )" + // Allow (escaped) double quoted strings, single quoted strings and their nested (double) escapes + f-string versions w/ nested expressions. + R"( [\\] ["] ( [^"\\\t\r\n\uff00-\uffef] | [\\] [\\] [\\] ["] | [\\] [trnu] )* [\\] ["] | )" + R"( [f][\\] ["] ( [^"\\\t\r\n\uff00-\uffef{}] | [\\] [\\] [\\] ["] | [\\] [trnu] | [{] json-escaped-code-soup [}] )* [\\] ["] | )" + R"( ['] ( [^"'\\\t\r\n\uff00-\uffef] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] )* ['] | )" + R"( [f]['] ( [^"'\\\t\r\n\uff00-\uffef{}] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] | [{] json-escaped-code-soup [}] )* ['] | )" // Soup wrapped in parentheses, curly braces or square brackets R"( [(] json-escaped-code-soup [)] | )" R"( [{] json-escaped-code-soup [}] | )" @@ -1002,9 +1004,9 @@ static std::string add_python_code_arguments_rule(const std::string & name, cons } static std::string add_json_tool_args_rule(const std::string & name, const json & parameters, const common_grammar_builder & builder) { - // if (name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { - // return add_python_code_arguments_rule(name + "-code-args", builder); - // } else + if (constrain_python_tool_code && name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { + return add_python_code_arguments_rule(name + "-code-args", builder); + } else { return builder.add_schema(name + "-args", parameters); } @@ -1032,10 +1034,9 @@ static common_chat_params common_chat_params_init_llama_3_1_tool_calls(const com std::vector kvs; for (const auto & [key, value] : parameters.at("properties").items()) { - // if (name == "python" && key == "code") { - // kvs.push_back("\"" + key + "=\\\"\" " + add_escaped_python_code_soup_rule(builder) + " \"\\\"\""); // NOLINT - // } else - { + if (constrain_python_tool_code && name == "python" && key == "code") { + kvs.push_back("\"" + key + "=\\\"\" " + add_escaped_python_code_soup_rule(builder) + " \"\\\"\""); // NOLINT + } else { kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT } } @@ -1438,14 +1439,13 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat std::string name = function.at("name"); auto parameters = function.at("parameters"); builder.resolve_refs(parameters); - // if (name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { - // tool_rules.push_back(builder.add_rule(name + "-call", - // "\"{\" space " - // "\"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space " - // "\"\\\"arguments\\\"\" space \":\" space " + add_python_code_arguments_rule(name + "-code-arguments", builder) + " " - // "\"}\" space ")); - // } else - { + if (constrain_python_tool_code && name == "python" && parameters.contains("properties") && parameters.at("properties").contains("code") && parameters.at("properties").size() == 1) { + tool_rules.push_back(builder.add_rule(name + "-call", + "\"{\" space " + "\"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space " + "\"\\\"arguments\\\"\" space \":\" space " + add_python_code_arguments_rule(name + "-code-arguments", builder) + " " + "\"}\" space ")); + } else { tool_rules.push_back(builder.add_schema(name + "-call", { {"type", "object"}, {"properties", json { From 1e78b680db922b8a1813725c20963897420d9a42 Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 20 Feb 2025 00:20:47 +0000 Subject: [PATCH 22/43] add multiline strings to python code constraints --- common/chat.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 70c54288d6259..0b8c4ad870705 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -976,10 +976,14 @@ static std::string add_escaped_python_code_soup_rule(const common_grammar_builde // Allow comments w/ (escaped) newline R"( ( [#] ( ( [^\\\t\r\n\uff00-\uffef] | [\\] [^n\n] )* [\\] [n] )? | )" // Allow (escaped) double quoted strings, single quoted strings and their nested (double) escapes + f-string versions w/ nested expressions. - R"( [\\] ["] ( [^"\\\t\r\n\uff00-\uffef] | [\\] [\\] [\\] ["] | [\\] [trnu] )* [\\] ["] | )" - R"( [f][\\] ["] ( [^"\\\t\r\n\uff00-\uffef{}] | [\\] [\\] [\\] ["] | [\\] [trnu] | [{] json-escaped-code-soup [}] )* [\\] ["] | )" - R"( ['] ( [^"'\\\t\r\n\uff00-\uffef] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] )* ['] | )" - R"( [f]['] ( [^"'\\\t\r\n\uff00-\uffef{}] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] | [{] json-escaped-code-soup [}] )* ['] | )" + R"( "\\\"" ( [^"\\\t\r\n\uff00-\uffef] | [\\] [\\] [\\] ["] | [\\] [trnu] )* [\\] ["] | )" + R"( "f\\\"" ( [^"\\\t\r\n\uff00-\uffef{}] | [\\] [\\] [\\] ["] | [\\] [trnu] | [{] json-escaped-code-soup [}] )* [\\] ["] | )" + R"( "'" ( [^"'\\\t\r\n\uff00-\uffef] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] )* ['] | )" + R"( "f'" ( [^"'\\\t\r\n\uff00-\uffef{}] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] | [{] json-escaped-code-soup [}] )* ['] | )" + R"( "\\\"\\\"\\\"" ( [^"\\\t\r\n\uff00-\uffef] | [\\] [\\] [\\] ["] | [\\] [trnu] )* "\\\"\\\"\\\"" | )" + R"( "f\\\"\\\"\\\"" ( [^"\\\t\r\n\uff00-\uffef{}] | [\\] [\\] [\\] ["] | [\\] [trnu] | [{] json-escaped-code-soup [}] )* "\\\"\\\"\\\"" | )" + R"( "'''" ( [^"'\\\t\r\n\uff00-\uffef] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] )* "'''" | )" + R"( "f'''" ( [^"'\\\t\r\n\uff00-\uffef{}] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] | [{] json-escaped-code-soup [}] )* "'''" | )" // Soup wrapped in parentheses, curly braces or square brackets R"( [(] json-escaped-code-soup [)] | )" R"( [{] json-escaped-code-soup [}] | )" From 6cc6c5e585591523550e670a98fb92001f11ded8 Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 20 Feb 2025 02:00:06 +0000 Subject: [PATCH 23/43] Update chat.cpp --- common/chat.cpp | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/common/chat.cpp b/common/chat.cpp index 0b8c4ad870705..1793245803cc3 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -974,16 +974,16 @@ const bool constrain_python_tool_code = !getenv("CONSTRAIN_PYTHON_TOOL_CODE") || static std::string add_escaped_python_code_soup_rule(const common_grammar_builder & builder) { return builder.add_rule("json-escaped-code-soup", // Allow comments w/ (escaped) newline - R"( ( [#] ( ( [^\\\t\r\n\uff00-\uffef] | [\\] [^n\n] )* [\\] [n] )? | )" + R"( ( [#] ( ( [^\\\t\r\n\uff00-\uffef\u0000-\u001F] | [\\] [^n\n] )* [\\] [n] )? | )" // Allow (escaped) double quoted strings, single quoted strings and their nested (double) escapes + f-string versions w/ nested expressions. - R"( "\\\"" ( [^"\\\t\r\n\uff00-\uffef] | [\\] [\\] [\\] ["] | [\\] [trnu] )* [\\] ["] | )" - R"( "f\\\"" ( [^"\\\t\r\n\uff00-\uffef{}] | [\\] [\\] [\\] ["] | [\\] [trnu] | [{] json-escaped-code-soup [}] )* [\\] ["] | )" - R"( "'" ( [^"'\\\t\r\n\uff00-\uffef] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] )* ['] | )" - R"( "f'" ( [^"'\\\t\r\n\uff00-\uffef{}] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] | [{] json-escaped-code-soup [}] )* ['] | )" - R"( "\\\"\\\"\\\"" ( [^"\\\t\r\n\uff00-\uffef] | [\\] [\\] [\\] ["] | [\\] [trnu] )* "\\\"\\\"\\\"" | )" - R"( "f\\\"\\\"\\\"" ( [^"\\\t\r\n\uff00-\uffef{}] | [\\] [\\] [\\] ["] | [\\] [trnu] | [{] json-escaped-code-soup [}] )* "\\\"\\\"\\\"" | )" - R"( "'''" ( [^"'\\\t\r\n\uff00-\uffef] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] )* "'''" | )" - R"( "f'''" ( [^"'\\\t\r\n\uff00-\uffef{}] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef] | [{] json-escaped-code-soup [}] )* "'''" | )" + R"( "\\\"" ( [^"\\\t\r\n\uff00-\uffef\u0000-\u001F] | [\\] [\\] [\\] ["] | [\\] [trnu] )* [\\] ["] | )" + R"( "f\\\"" ( [^"\\\t\r\n\uff00-\uffef\u0000-\u001F{}] | [\\] [\\] [\\] ["] | [\\] [trnu] | [{] json-escaped-code-soup [}] )* [\\] ["] | )" + R"( "'" ( [^"'\\\t\r\n\uff00-\uffef\u0000-\u001F] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef\u0000-\u001F] )* ['] | )" + R"( "f'" ( [^"'\\\t\r\n\uff00-\uffef\u0000-\u001F{}] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef\u0000-\u001F] | [{] json-escaped-code-soup [}] )* ['] | )" + R"( "\\\"\\\"\\\"" ( [^"\\\t\r\n\uff00-\uffef\u0000-\u001F] | [\\] [\\] [\\] ["] | [\\] [trnu] )* "\\\"\\\"\\\"" | )" + R"( "f\\\"\\\"\\\"" ( [^"\\\t\r\n\uff00-\uffef\u0000-\u001F{}] | [\\] [\\] [\\] ["] | [\\] [trnu] | [{] json-escaped-code-soup [}] )* "\\\"\\\"\\\"" | )" + R"( "'''" ( [^"'\\\t\r\n\uff00-\uffef\u0000-\u001F] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef\u0000-\u001F] )* "'''" | )" + R"( "f'''" ( [^"'\\\t\r\n\uff00-\uffef\u0000-\u001F{}] | [\\] [\\] ['] | [\\] [^'\t\r\n\uff00-\uffef\u0000-\u001F] | [{] json-escaped-code-soup [}] )* "'''" | )" // Soup wrapped in parentheses, curly braces or square brackets R"( [(] json-escaped-code-soup [)] | )" R"( [{] json-escaped-code-soup [}] | )" @@ -993,7 +993,8 @@ static std::string add_escaped_python_code_soup_rule(const common_grammar_builde // Allow other characters, minus code blocks for halfwidth & fullwidth forms (U+FF00 - U+FFEF) // (special tokens can use these to avoid prompt injections, as they will have to be unicode-escaped w/ \uXXXX // and won't be able to interfere w/ parsing) - R"( [^#{}"'\[\]\\()\t\r\n\uff00-\uffef]+ )" + R"( [^f#{}"'\[\]\\()\t\r\n\uff00-\uffef]+ | )" + R"( [f] [^'"#{}"'\[\]\\()\t\r\n\uff00-\uffef]+ )" // After any repetition of the previous, allow trailing comment w/o newline R"( )* ( [#] ( [^\\] | [\\] [^n] )* )? )" ); From deda56baa48a280fbc746401b6016626c440dda4 Mon Sep 17 00:00:00 2001 From: ochafik Date: Mon, 17 Feb 2025 00:48:20 +0000 Subject: [PATCH 24/43] add -jf / --json-schema-file flag --- common/arg.cpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/common/arg.cpp b/common/arg.cpp index eb8beccac2ee7..5ce3946567666 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1272,6 +1272,23 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.grammar = json_schema_to_grammar(json::parse(value)); } ).set_sparam()); + add_opt(common_arg( + {"-jf", "--json-schema-file"}, "FILE", + "File containing a JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object\nFor schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead", + [](common_params & params, const std::string & value) { + std::ifstream file(value); + if (!file) { + throw std::runtime_error(string_format("error: failed to open file '%s'\n", value.c_str())); + } + std::string schema; + std::copy( + std::istreambuf_iterator(file), + std::istreambuf_iterator(), + std::back_inserter(schema) + ); + params.sampling.grammar = json_schema_to_grammar(json::parse(schema)); + } + ).set_sparam()); add_opt(common_arg( {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", From 4a48898c9e1ce08e6b9302a0792646f4ba255dcd Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 20 Feb 2025 16:30:24 +0000 Subject: [PATCH 25/43] add token masking to grammar sampler rm dead code add LLAMA_MASK env add debug info --- src/llama-grammar.cpp | 242 ++++++++++++++++++++++++++++++++++++++++-- src/llama-grammar.h | 157 ++++++++++++++++++++++++++- 2 files changed, 390 insertions(+), 9 deletions(-) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index f20ec355ce4fa..711bdea7ba0bb 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -7,6 +7,7 @@ #include #include #include +#include // // helpers @@ -603,6 +604,150 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) } } +void token_ranges::fetch_pieces_for_debug(const std::vector & sorted_tokens) { + std::set pieces; + for (const auto & rng : allowed_token_ranges) { + for (size_t i = rng.from_sorted_index; i <= rng.to_sorted_index; i++) { + const auto & token = sorted_tokens[i]; + pieces.insert(token.piece); + } + } + allowed_pieces.clear(); + allowed_pieces.insert(allowed_pieces.end(), pieces.begin(), pieces.end()); +} + +static const token_ranges & llama_grammar_match_tokens( + struct llama_grammar & grammar, + const llama_grammar_element * pos) { + // const std::vector & sorted_tokens, + // std::unordered_map & allowed_tokens) { + + const auto & sorted_tokens = grammar.sorted_tokens; + + auto it = grammar.allowed_tokens.find(pos); + if (it != grammar.allowed_tokens.end()) { + return it->second; + } + + std::function explore_rng = [&]( + const token_range & rng, + const llama_grammar_element * pos, + size_t char_offset, + token_ranges & out) + { + auto sorted_begin = sorted_tokens.begin() + rng.from_sorted_index; + auto sorted_end = sorted_tokens.begin() + rng.to_sorted_index + 1; + auto find_lower = [&](uint32_t chr) { + // return std::lower_bound(sorted_begin, sorted_end, chr, [&](const std::vector & token_codepoints, uint32_t chr) { + // if (token_codepoints.size() <= char_offset) { + // return true; + // } + // return token_codepoints.codepoints.first[char_offset] < chr; + // }); + for (auto it = sorted_begin; it != sorted_end; it++) { + if ((*it).codepoints.first[char_offset] >= chr) { + return it; + } + } + return sorted_tokens.end(); + }; + auto find_upper = [&](uint32_t chr) { + for (auto it = sorted_begin; it != sorted_end; it++) { + if ((*it).codepoints.first[char_offset] > chr) { + return it; + } + } + return sorted_tokens.end(); + }; + + token_ranges res; + bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; + do { + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + auto low_chr = pos->value; + auto high_chr = pos[1].value; + auto low = find_lower(low_chr); + if (low != sorted_tokens.end()) { + auto high = find_upper(high_chr); + res += { + static_cast(low - sorted_tokens.begin()), + high == sorted_tokens.end() + ? sorted_tokens.size() - 1 + : static_cast(high - sorted_tokens.begin() - 1), + }; + } + pos += 2; + } else if (pos->type == LLAMA_GRETYPE_CHAR_ANY) { + // Any character matches "." + res += { 0, sorted_tokens.size() - 1 }; + pos += 1; + } else { + // exact char match, e.g. [a] or "a" + auto chr = pos->value; + auto low = find_lower(chr); + if (low != sorted_tokens.end()) { + auto high = find_upper(chr); + res += { + static_cast(low - sorted_tokens.begin()), + high == sorted_tokens.end() + ? sorted_tokens.size() - 1 + : static_cast(high - sorted_tokens.begin() - 1), + }; + } + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + if (!is_positive_char) { + res.invert(sorted_tokens.size()); + } + + res.fetch_pieces_for_debug(sorted_tokens); + + if (llama_grammar_is_end_of_sequence(pos) || pos->type == LLAMA_GRETYPE_RULE_REF) { + // Any matches are plausible guesses. We don't know if they're atually good without looking into the alternatives or the calling context we return to. + // TODO: get the rule's allowed ranges, avoiding infinite recursion. + out += res; + } else { + // For each range in res: if a token ends at size offset + 1, add it to the output. + // Otherwise, recurse on the next offset.s + auto next_offset = char_offset + 1; + + for (const auto & rng : res.allowed_token_ranges) { + if (sorted_tokens[rng.from_sorted_index].codepoints.first.size() == next_offset) { + out += rng.from_sorted_index; + if (rng.from_sorted_index != rng.to_sorted_index) { + explore_rng({rng.from_sorted_index + 1, rng.to_sorted_index}, pos, next_offset, out); + } + } else { + explore_rng(rng, pos, next_offset, out); + } + } + } + }; + + auto & rngs = grammar.allowed_tokens[pos]; + + explore_rng({0, sorted_tokens.size() - 1}, pos, 0, rngs); + + // Skip to end or alternative + while (!llama_grammar_is_end_of_sequence(pos)) { + pos++; + } + + // Merge allowed tokens from alternative(s) + if (pos->type == LLAMA_GRETYPE_ALT) { + auto & alt_matches = llama_grammar_match_tokens(grammar, pos + 1); + rngs += alt_matches; + } + + rngs.fetch_pieces_for_debug(sorted_tokens); + + return rngs; +} +//*/ + // returns true iff chr satisfies the char range at pos (regular or inverse range) // asserts that pos is pointing to a char range element static std::pair llama_grammar_match_char( @@ -957,13 +1102,40 @@ struct llama_grammar * llama_grammar_init_impl( } } while (true); + std::vector sorted_tokens; + std::vector sorted_tokens_indices; + + const bool mask = getenv("LLAMA_MASK") != nullptr && std::string(getenv("LLAMA_MASK")) == "1"; + if (mask && vocab) { + printf("Masking %d tokens\n", llama_vocab_n_tokens(vocab)); + for (size_t i = 0, n = llama_vocab_n_tokens(vocab); i < n; i++) { + auto & piece = vocab->token_to_piece(i); + sorted_tokens.push_back({ + (llama_token) i, + piece, + decode_utf8(piece, {}), + }); + } + + std::sort(sorted_tokens.begin(), sorted_tokens.end(), [](const llama_grammar_token & a, const llama_grammar_token & b) { + return a.codepoints.first < b.codepoints.first; + }); + sorted_tokens_indices.resize(sorted_tokens.size()); + for (size_t i = 0; i < sorted_tokens.size(); i++) { + sorted_tokens_indices[sorted_tokens[i].token] = i; + } + } + // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { + auto grammar = new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), + std::move(sorted_tokens), + std::move(sorted_tokens_indices), + /* .allowed_tokens = */ {}, /* .partial_utf8 = */ {}, /* .lazy =*/ false, /* .awaiting_trigger = */ false, @@ -971,6 +1143,13 @@ struct llama_grammar * llama_grammar_init_impl( /* .trigger_tokens = */ {}, /* .trigger_patterns = */ {}, }; + // Prime allowed_tokens for each rule in the grammar + for (const auto & rule : vec_rules) { + for (const auto & elem : rule) { + llama_grammar_match_tokens(*grammar, &elem); + } + } + return grammar; } struct llama_grammar * llama_grammar_init_impl( @@ -1028,6 +1207,7 @@ struct llama_grammar * llama_grammar_init_impl( } // loop over alternates of start rule to build initial stacks + // TODO: lazy, rely on llama_grammar_match_tokens for initial tokens, only build relevant stacks once we accept first token. llama_grammar_stacks stacks; pos = vec_rules[start_rule_index].data(); do { @@ -1060,13 +1240,40 @@ struct llama_grammar * llama_grammar_init_impl( vec_trigger_patterns.emplace_back(trigger_patterns[i], trigger_patterns[i]); } + std::vector sorted_tokens; + std::vector sorted_tokens_indices; + + const bool mask = getenv("LLAMA_MASK") != nullptr && std::string(getenv("LLAMA_MASK")) == "1"; + if (mask && vocab) { + printf("Masking %d tokens\n", llama_vocab_n_tokens(vocab)); + for (size_t i = 0, n = llama_vocab_n_tokens(vocab); i < n; i++) { + auto & piece = vocab->token_to_piece(i); + sorted_tokens.push_back({ + (llama_token) i, + piece, + decode_utf8(piece, {}), + }); + } + + std::sort(sorted_tokens.begin(), sorted_tokens.end(), [](const llama_grammar_token & a, const llama_grammar_token & b) { + return a.codepoints.first < b.codepoints.first; + }); + sorted_tokens_indices.resize(sorted_tokens.size()); + for (size_t i = 0; i < sorted_tokens.size(); i++) { + sorted_tokens_indices[sorted_tokens[i].token] = i; + } + } + // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar { + auto grammar = new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), + std::move(sorted_tokens), + std::move(sorted_tokens_indices), + /* .allowed_tokens = */ {}, /* .partial_utf8 = */ {}, /* .lazy = */ lazy, /* .awaiting_trigger = */ lazy, @@ -1074,6 +1281,13 @@ struct llama_grammar * llama_grammar_init_impl( std::move(vec_trigger_tokens), std::move(vec_trigger_patterns), }; + // Prime allowed_tokens for each rule in the grammar + for (const auto & rule : vec_rules) { + for (const auto & elem : rule) { + llama_grammar_match_tokens(*grammar, &elem); + } + } + return grammar; } void llama_grammar_free_impl(struct llama_grammar * grammar) { @@ -1089,6 +1303,9 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra grammar.vocab, grammar.rules, grammar.stacks, + grammar.sorted_tokens, + grammar.sorted_tokens_indices, + grammar.allowed_tokens, grammar.partial_utf8, grammar.lazy, grammar.awaiting_trigger, @@ -1113,7 +1330,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra return result; } -void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { +void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_array * cur_p) { GGML_ASSERT(grammar.vocab != nullptr); if (grammar.awaiting_trigger) { @@ -1121,10 +1338,16 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ } bool allow_eog = false; - for (const auto & stack : grammar.stacks) { - if (stack.empty()) { - allow_eog = true; - break; + token_ranges accepted_ranges; + + if (!grammar.sorted_tokens_indices.empty()) { + for (const auto & stack : grammar.stacks) { + if (stack.empty()) { + allow_eog = true; + // break; + } else { + accepted_ranges += llama_grammar_match_tokens(grammar, stack.back()); + } } } @@ -1136,14 +1359,17 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; + auto idx = grammar.sorted_tokens_indices.empty() ? std::string::npos : grammar.sorted_tokens_indices[id]; const std::string & piece = grammar.vocab->token_to_piece(id); - + if (grammar.vocab->is_eog(id)) { if (!allow_eog) { cur_p->data[i].logit = -INFINITY; } } else if (piece.empty() || piece[0] == 0) { cur_p->data[i].logit = -INFINITY; + } else if (idx != std::string::npos && grammar.partial_utf8.n_remain == 0 && !accepted_ranges.contains(idx)) { + cur_p->data[i].logit = -INFINITY; } else { candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); diff --git a/src/llama-grammar.h b/src/llama-grammar.h index a9b6f99ec34f5..99f16a8b073d4 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -6,6 +6,7 @@ #include #include #include +#include struct llama_vocab; @@ -54,6 +55,150 @@ struct llama_grammar_candidate { llama_partial_utf8 partial_utf8; }; +//* +struct token_range { + size_t from_sorted_index; + size_t to_sorted_index; +}; + +struct token_ranges { + std::vector allowed_token_ranges; + std::vector allowed_pieces; + + void fetch_pieces_for_debug(const std::vector & sorted_tokens); + + void invert(size_t size) { + // Go from positive matches to negative matches + // [[10, 20]] w/ size 30 -> [[0, 9], [21, 29]] + if (allowed_token_ranges.empty()) { + allowed_token_ranges.push_back({0, size - 1}); + return; + } + std::vector new_ranges; + if (allowed_token_ranges.front().from_sorted_index > 0) { + new_ranges.push_back({0, allowed_token_ranges.front().from_sorted_index - 1}); + } + for (size_t i = 1; i < allowed_token_ranges.size(); i++) { + new_ranges.push_back({allowed_token_ranges[i - 1].to_sorted_index + 1, allowed_token_ranges[i].from_sorted_index - 1}); + } + if (allowed_token_ranges.back().to_sorted_index < size - 1) { + new_ranges.push_back({allowed_token_ranges.back().to_sorted_index + 1, size - 1}); + } + allowed_token_ranges.swap(new_ranges); + } + + token_ranges & operator+=(const token_range & other) { + if (allowed_token_ranges.empty()) { + allowed_token_ranges.push_back(other); + return *this; + } + if (allowed_token_ranges.back().to_sorted_index + 1 == other.from_sorted_index) { + allowed_token_ranges.back().to_sorted_index = other.to_sorted_index; + return *this; + } + // find nearest + auto it = std::lower_bound(allowed_token_ranges.begin(), allowed_token_ranges.end(), other.from_sorted_index, + [](const token_range & range, size_t idx) { + return range.to_sorted_index < idx; + }); + if (it != allowed_token_ranges.end() && it->from_sorted_index <= other.from_sorted_index && other.to_sorted_index <= it->to_sorted_index) { + return *this; + } + // Insert a new range and fuse it with the previous one and/or followin if possible + auto new_range = other; + if (it != allowed_token_ranges.begin() && it[-1].to_sorted_index + 1 == other.from_sorted_index) { + it[-1].to_sorted_index = other.to_sorted_index; + new_range.from_sorted_index = it[-1].from_sorted_index; + it = allowed_token_ranges.erase(it); + } + if (it != allowed_token_ranges.end() && it->from_sorted_index == other.to_sorted_index + 1) { + new_range.to_sorted_index = it->to_sorted_index; + it = allowed_token_ranges.erase(it); + } + allowed_token_ranges.insert(it, new_range); + return *this; + } + token_ranges & operator+=(const token_ranges & other) { + if (allowed_token_ranges.empty()) { + allowed_token_ranges = other.allowed_token_ranges; + return *this; + } + else if (other.allowed_token_ranges.empty()) { + return *this; + } + auto it1 = allowed_token_ranges.begin(); + auto it2 = other.allowed_token_ranges.begin(); + + std::vector result; + // Merge the two ranges, fusing [from,to] pairs that overlap + while (it1 != allowed_token_ranges.end() && it2 != other.allowed_token_ranges.end()) { + if (it1->to_sorted_index < it2->from_sorted_index) { + result.push_back(*it1); + it1++; + } + else if (it2->to_sorted_index < it1->from_sorted_index) { + result.push_back(*it2); + it2++; + } + else { + result.push_back({std::min(it1->from_sorted_index, it2->from_sorted_index), std::max(it1->to_sorted_index, it2->to_sorted_index)}); + it1++; + it2++; + } + } + while (it1 != allowed_token_ranges.end()) { + result.push_back(*it1); + it1++; + } + while (it2 != other.allowed_token_ranges.end()) { + result.push_back(*it2); + it2++; + } + allowed_token_ranges = result; + return *this; + } + + token_ranges & operator+=(size_t idx) { + if (allowed_token_ranges.empty()) { + allowed_token_ranges.push_back({idx, idx}); + return *this; + } + if (allowed_token_ranges.back().to_sorted_index + 1 == idx) { + allowed_token_ranges.back().to_sorted_index = idx; + return *this; + } + // Find the range that contains the token + auto it = std::lower_bound(allowed_token_ranges.begin(), allowed_token_ranges.end(), idx, [](const token_range & range, size_t idx) { + return range.to_sorted_index < idx; + }); + if (it != allowed_token_ranges.end() && it->from_sorted_index <= idx && idx <= it->to_sorted_index) { + return *this; + } + // Insert a new range and fuse it with the previous one and/or followin if possible + token_range new_range { idx, idx }; + if (it != allowed_token_ranges.begin() && it[-1].to_sorted_index + 1 == idx) { + it[-1].to_sorted_index = idx; + new_range.from_sorted_index = it[-1].from_sorted_index; + it = allowed_token_ranges.erase(it); + } + if (it != allowed_token_ranges.end() && it->from_sorted_index == idx + 1) { + new_range.to_sorted_index = it->to_sorted_index; + it = allowed_token_ranges.erase(it); + } + allowed_token_ranges.insert(it, new_range); + return *this; + } + + bool contains(size_t idx) const { + // find (sorted) + auto it = std::lower_bound(allowed_token_ranges.begin(), allowed_token_ranges.end(), idx, [](const token_range & range, size_t idx) { + return range.to_sorted_index < idx; + }); + return it != allowed_token_ranges.end() && it->from_sorted_index <= idx && idx <= it->to_sorted_index; + } +}; +//*/ + using llama_grammar_rule = std::vector< llama_grammar_element>; using llama_grammar_stack = std::vector; @@ -106,6 +251,12 @@ struct llama_grammar_parser { void print(FILE * file); }; +struct llama_grammar_token { + llama_token token; + std::string piece; + std::pair, llama_partial_utf8> codepoints; +}; + struct llama_grammar { // note: allow null vocab for testing (not great) const llama_vocab * vocab; @@ -113,6 +264,10 @@ struct llama_grammar { const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; + std::vector sorted_tokens; + std::vector sorted_tokens_indices; // llama_token -> idx in sorted_token + std::unordered_map allowed_tokens; + // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; @@ -156,7 +311,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra // TODO: move the API below as member functions of llama_grammar void llama_grammar_apply_impl( - const struct llama_grammar & grammar, + struct llama_grammar & grammar, llama_token_data_array * cur_p); void llama_grammar_accept_impl( From e923222aa042b6db6c398ea3b87f7e01ed895699 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 20 Feb 2025 17:55:06 +0000 Subject: [PATCH 26/43] more logs --- src/llama-grammar.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 711bdea7ba0bb..02f614970ffbd 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1369,8 +1369,10 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a } else if (piece.empty() || piece[0] == 0) { cur_p->data[i].logit = -INFINITY; } else if (idx != std::string::npos && grammar.partial_utf8.n_remain == 0 && !accepted_ranges.contains(idx)) { + LLAMA_LOG_DEBUG("Rejecting masked token %u (`%s`)\n", id, piece.c_str()); cur_p->data[i].logit = -INFINITY; } else { + LLAMA_LOG_DEBUG("Considering unmasked token %u (`%s`)\n", id, piece.c_str()); candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } From 97a8ab2c2f29a145487653fbada0044f73f6e4fe Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 20 Feb 2025 18:03:48 +0000 Subject: [PATCH 27/43] Update llama-grammar.cpp --- src/llama-grammar.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 02f614970ffbd..cf40bfc311a3b 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1357,6 +1357,8 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a llama_grammar_candidates candidates_grammar; candidates_grammar.reserve(cur_p->size); + LLAMA_LOG_DEBUG("# Grammar sampling %zu tokens\n", cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; auto idx = grammar.sorted_tokens_indices.empty() ? std::string::npos : grammar.sorted_tokens_indices[id]; @@ -1369,10 +1371,10 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a } else if (piece.empty() || piece[0] == 0) { cur_p->data[i].logit = -INFINITY; } else if (idx != std::string::npos && grammar.partial_utf8.n_remain == 0 && !accepted_ranges.contains(idx)) { - LLAMA_LOG_DEBUG("Rejecting masked token %u (`%s`)\n", id, piece.c_str()); + LLAMA_LOG_DEBUG("- Rejecting masked token %u (`%s`)\n", id, piece.c_str()); cur_p->data[i].logit = -INFINITY; } else { - LLAMA_LOG_DEBUG("Considering unmasked token %u (`%s`)\n", id, piece.c_str()); + LLAMA_LOG_DEBUG("- Considering unmasked token %u (`%s`)\n", id, piece.c_str()); candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } @@ -1380,6 +1382,8 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { + auto & token = cur_p->data[reject.index]; + LLAMA_LOG_DEBUG("- Rejecting token %u (`%s`)\n", token.id, grammar.vocab->token_to_piece(token.id).c_str()); cur_p->data[reject.index].logit = -INFINITY; } } From a597293293ab03d1264c3885b9afa36322eb8756 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 20 Feb 2025 18:03:56 +0000 Subject: [PATCH 28/43] Update sampling.cpp --- common/sampling.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/common/sampling.cpp b/common/sampling.cpp index 20bf20bb2578a..e00d49f2c7ecb 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,6 +1,7 @@ #include "sampling.h" #include "common.h" +#include "llama-impl.h" #include #include @@ -352,6 +353,9 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co const llama_token id = cur_p.data[cur_p.selected].id; if (grammar_first) { + LLAMA_LOG_DEBUG("sampled token %u (`%s`)\n", id, common_token_to_piece(ctx, id).c_str()); + fflush(stdout); + fflush(stderr); return id; } From 1abeb0ce3fa2cd964d2c1dd547752de8896bd0e4 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 20 Feb 2025 18:04:03 +0000 Subject: [PATCH 29/43] Update sampling.cpp --- common/sampling.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/common/sampling.cpp b/common/sampling.cpp index e00d49f2c7ecb..ee9e9b35f442e 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -342,6 +342,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co auto & chain = gsmpl->chain; auto & cur_p = gsmpl->cur_p; // initialized by set_logits + grammar_first = true; if (grammar_first) { llama_sampler_apply(grmr, &cur_p); } From 458d78c004b3898d83087875b9726eed7ba8e240 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Thu, 20 Feb 2025 18:09:58 +0000 Subject: [PATCH 30/43] logs --- common/sampling.cpp | 4 ++-- src/llama-grammar.cpp | 23 +++++++++++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index ee9e9b35f442e..83c0e7dbcfc9f 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -354,8 +354,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co const llama_token id = cur_p.data[cur_p.selected].id; if (grammar_first) { - LLAMA_LOG_DEBUG("sampled token %u (`%s`)\n", id, common_token_to_piece(ctx, id).c_str()); - fflush(stdout); + fprintf(stderr, "sampled token %u (`%s`)\n", id, common_token_to_piece(ctx, id).c_str()); + // fflush(stdout); fflush(stderr); return id; } diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index cf40bfc311a3b..bab5141c8a74d 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1357,7 +1357,8 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a llama_grammar_candidates candidates_grammar; candidates_grammar.reserve(cur_p->size); - LLAMA_LOG_DEBUG("# Grammar sampling %zu tokens\n", cur_p->size); + fprintf(stderr, "# Grammar sampling %zu tokens\n", cur_p->size); + fflush(stderr); for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; @@ -1371,10 +1372,12 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a } else if (piece.empty() || piece[0] == 0) { cur_p->data[i].logit = -INFINITY; } else if (idx != std::string::npos && grammar.partial_utf8.n_remain == 0 && !accepted_ranges.contains(idx)) { - LLAMA_LOG_DEBUG("- Rejecting masked token %u (`%s`)\n", id, piece.c_str()); + fprintf(stderr, "- Rejecting masked token %u (`%s`)\n", id, piece.c_str()); + fflush(stderr); cur_p->data[i].logit = -INFINITY; } else { - LLAMA_LOG_DEBUG("- Considering unmasked token %u (`%s`)\n", id, piece.c_str()); + fprintf(stderr, "- Considering unmasked token %u (`%s`)\n", id, piece.c_str()); + fflush(stderr); candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } @@ -1383,9 +1386,21 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { auto & token = cur_p->data[reject.index]; - LLAMA_LOG_DEBUG("- Rejecting token %u (`%s`)\n", token.id, grammar.vocab->token_to_piece(token.id).c_str()); + fprintf(stderr, "- Rejecting token %u (`%s`)\n", token.id, grammar.vocab->token_to_piece(token.id).c_str()); + fflush(stderr); cur_p->data[reject.index].logit = -INFINITY; } + + // Find non-rjects + + for (size_t i = 0; i < cur_p->size; ++i) { + auto & token = cur_p->data[i]; + if (token.logit == -INFINITY) { + continue; + } + fprintf(stderr, "- Accepted token %u (`%s`)\n", token.id, grammar.vocab->token_to_piece(token.id).c_str()); + fflush(stderr); + } } void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { From 178a36d70ffaaa7544f0f76ac3acc986d72cc9ca Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 20 Feb 2025 19:47:52 +0000 Subject: [PATCH 31/43] test token ranges --- src/llama-grammar.h | 4 +-- tests/test-grammar-integration.cpp | 48 ++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 99f16a8b073d4..5f992e22c16b5 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -106,12 +106,12 @@ struct token_ranges { } // Insert a new range and fuse it with the previous one and/or followin if possible auto new_range = other; - if (it != allowed_token_ranges.begin() && it[-1].to_sorted_index + 1 == other.from_sorted_index) { + if (it != allowed_token_ranges.begin() && it[-1].to_sorted_index + 1 >= other.from_sorted_index) { it[-1].to_sorted_index = other.to_sorted_index; new_range.from_sorted_index = it[-1].from_sorted_index; it = allowed_token_ranges.erase(it); } - if (it != allowed_token_ranges.end() && it->from_sorted_index == other.to_sorted_index + 1) { + if (it != allowed_token_ranges.end() && it->from_sorted_index <= other.to_sorted_index + 1) { new_range.to_sorted_index = it->to_sorted_index; it = allowed_token_ranges.erase(it); } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 89060864894a4..8c7aa07435380 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -9,6 +9,7 @@ #include #include #include +#include using json = nlohmann::ordered_json; @@ -1292,8 +1293,55 @@ static void test_json_schema() { ); } + +template static void assert_equals(const T & expected, const T & actual) { + if (expected != actual) { + std::cerr << "Expected: " << expected << std::endl; + std::cerr << "Actual: " << actual << std::endl; + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } +} + +static void test_token_ranges() { + { + token_ranges rngs; + rngs += 1; + assert_equals(true, rngs.contains(1)); + assert_equals(false, rngs.contains(2)); + + rngs += 10; + assert_equals(true, rngs.contains(1)); + assert_equals(false, rngs.contains(2)); + assert_equals(false, rngs.contains(9)); + assert_equals(true, rngs.contains(10)); + assert_equals(false, rngs.contains(11)); + } + { + token_ranges rngs; + rngs += {10, 20}; + assert_equals(1, rngs.allowed_token_ranges.size()); + assert_equals(false, rngs.contains(9)); + assert_equals(true, rngs.contains(10)); + assert_equals(true, rngs.contains(11)); + assert_equals(true, rngs.contains(20)); + assert_equals(false, rngs.contains(21)); + } + { + token_ranges rngs; + rngs += {10, 20}; + rngs += {15, 25}; + assert_equals(1, rngs.allowed_token_ranges.size()); + assert_equals(false, rngs.contains(9)); + assert_equals(true, rngs.contains(10)); + assert_equals(true, rngs.contains(25)); + assert_equals(false, rngs.contains(26)); + } +} + int main() { fprintf(stdout, "Running grammar integration tests...\n"); + test_token_ranges(); test_simple_grammar(); test_complex_grammar(); test_special_chars(); From 2f43139949eb266bf4d9cd4db39ffc503be8ce8c Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 20 Feb 2025 19:51:13 +0000 Subject: [PATCH 32/43] fix ranges --- src/llama-grammar.h | 140 ++++++++++------------------- tests/test-grammar-integration.cpp | 11 +++ 2 files changed, 58 insertions(+), 93 deletions(-) diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 5f992e22c16b5..d33400ca5e197 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -87,117 +87,71 @@ struct token_ranges { allowed_token_ranges.swap(new_ranges); } - token_ranges & operator+=(const token_range & other) { + // Helper function to merge overlapping or adjacent ranges + void merge_ranges() { if (allowed_token_ranges.empty()) { - allowed_token_ranges.push_back(other); - return *this; - } - if (allowed_token_ranges.back().to_sorted_index + 1 == other.from_sorted_index) { - allowed_token_ranges.back().to_sorted_index = other.to_sorted_index; - return *this; + return; } - // find nearest - auto it = std::lower_bound(allowed_token_ranges.begin(), allowed_token_ranges.end(), other.from_sorted_index, - [](const token_range & range, size_t idx) { - return range.to_sorted_index < idx; + + std::sort(allowed_token_ranges.begin(), allowed_token_ranges.end(), + [](const token_range& a, const token_range& b) { + return a.from_sorted_index < b.from_sorted_index; }); - if (it != allowed_token_ranges.end() && it->from_sorted_index <= other.from_sorted_index && other.to_sorted_index <= it->to_sorted_index) { - return *this; - } - // Insert a new range and fuse it with the previous one and/or followin if possible - auto new_range = other; - if (it != allowed_token_ranges.begin() && it[-1].to_sorted_index + 1 >= other.from_sorted_index) { - it[-1].to_sorted_index = other.to_sorted_index; - new_range.from_sorted_index = it[-1].from_sorted_index; - it = allowed_token_ranges.erase(it); - } - if (it != allowed_token_ranges.end() && it->from_sorted_index <= other.to_sorted_index + 1) { - new_range.to_sorted_index = it->to_sorted_index; - it = allowed_token_ranges.erase(it); - } - allowed_token_ranges.insert(it, new_range); - return *this; - } - token_ranges & operator+=(const token_ranges & other) { - if (allowed_token_ranges.empty()) { - allowed_token_ranges = other.allowed_token_ranges; - return *this; - } - else if (other.allowed_token_ranges.empty()) { - return *this; - } - auto it1 = allowed_token_ranges.begin(); - auto it2 = other.allowed_token_ranges.begin(); - - std::vector result; - // Merge the two ranges, fusing [from,to] pairs that overlap - while (it1 != allowed_token_ranges.end() && it2 != other.allowed_token_ranges.end()) { - if (it1->to_sorted_index < it2->from_sorted_index) { - result.push_back(*it1); - it1++; - } - else if (it2->to_sorted_index < it1->from_sorted_index) { - result.push_back(*it2); - it2++; - } - else { - result.push_back({std::min(it1->from_sorted_index, it2->from_sorted_index), std::max(it1->to_sorted_index, it2->to_sorted_index)}); - it1++; - it2++; + + std::vector merged; + merged.push_back(allowed_token_ranges[0]); + + for (size_t i = 1; i < allowed_token_ranges.size(); i++) { + auto& current = allowed_token_ranges[i]; + auto& last = merged.back(); + + // Check if ranges overlap or are adjacent + if (current.from_sorted_index <= last.to_sorted_index + 1) { + // Merge the ranges + last.to_sorted_index = std::max(last.to_sorted_index, current.to_sorted_index); + } else { + // Add new range + merged.push_back(current); } } - while (it1 != allowed_token_ranges.end()) { - result.push_back(*it1); - it1++; - } - while (it2 != other.allowed_token_ranges.end()) { - result.push_back(*it2); - it2++; - } - allowed_token_ranges = result; + + allowed_token_ranges.swap(merged); + } + + token_ranges& operator+=(const token_range& other) { + allowed_token_ranges.push_back(other); + merge_ranges(); return *this; } - token_ranges & operator+=(size_t idx) { - if (allowed_token_ranges.empty()) { - allowed_token_ranges.push_back({idx, idx}); - return *this; - } - if (allowed_token_ranges.back().to_sorted_index + 1 == idx) { - allowed_token_ranges.back().to_sorted_index = idx; - return *this; - } - // Find the range that contains the token - auto it = std::lower_bound(allowed_token_ranges.begin(), allowed_token_ranges.end(), idx, [](const token_range & range, size_t idx) { - return range.to_sorted_index < idx; - }); - if (it != allowed_token_ranges.end() && it->from_sorted_index <= idx && idx <= it->to_sorted_index) { + token_ranges& operator+=(const token_ranges& other) { + if (other.allowed_token_ranges.empty()) { return *this; } - // Insert a new range and fuse it with the previous one and/or followin if possible - token_range new_range { idx, idx }; - if (it != allowed_token_ranges.begin() && it[-1].to_sorted_index + 1 == idx) { - it[-1].to_sorted_index = idx; - new_range.from_sorted_index = it[-1].from_sorted_index; - it = allowed_token_ranges.erase(it); - } - if (it != allowed_token_ranges.end() && it->from_sorted_index == idx + 1) { - new_range.to_sorted_index = it->to_sorted_index; - it = allowed_token_ranges.erase(it); - } - allowed_token_ranges.insert(it, new_range); + + allowed_token_ranges.insert( + allowed_token_ranges.end(), + other.allowed_token_ranges.begin(), + other.allowed_token_ranges.end() + ); + + merge_ranges(); return *this; } + token_ranges& operator+=(size_t idx) { + return operator+=({idx, idx}); + } + bool contains(size_t idx) const { // find (sorted) - auto it = std::lower_bound(allowed_token_ranges.begin(), allowed_token_ranges.end(), idx, [](const token_range & range, size_t idx) { - return range.to_sorted_index < idx; - }); + auto it = std::lower_bound(allowed_token_ranges.begin(), allowed_token_ranges.end(), idx, + [](const token_range& range, size_t idx) { + return range.to_sorted_index < idx; + }); return it != allowed_token_ranges.end() && it->from_sorted_index <= idx && idx <= it->to_sorted_index; } }; -//*/ using llama_grammar_rule = std::vector< llama_grammar_element>; using llama_grammar_stack = std::vector; diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 8c7aa07435380..2d1dd920ace9e 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -1337,6 +1337,17 @@ static void test_token_ranges() { assert_equals(true, rngs.contains(25)); assert_equals(false, rngs.contains(26)); } + { + token_ranges rngs; + rngs += {10, 20}; + rngs += {30, 40}; + assert_equals(2, rngs.allowed_token_ranges.size()); + assert_equals(false, rngs.contains(9)); + assert_equals(true, rngs.contains(10)); + assert_equals(false, rngs.contains(29)); + assert_equals(true, rngs.contains(30)); + assert_equals(false, rngs.contains(41)); + } } int main() { From a9c0256cbca436b76d5e5a1d9b71b6e35b5107a7 Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 20 Feb 2025 19:56:36 +0000 Subject: [PATCH 33/43] mute logs --- common/sampling.cpp | 5 ++--- src/llama-grammar.cpp | 20 ++++++++++---------- tests/test-grammar-integration.cpp | 21 ++++++++++++++++++--- 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 83c0e7dbcfc9f..f1855104f167c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -354,9 +354,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co const llama_token id = cur_p.data[cur_p.selected].id; if (grammar_first) { - fprintf(stderr, "sampled token %u (`%s`)\n", id, common_token_to_piece(ctx, id).c_str()); - // fflush(stdout); - fflush(stderr); + // fprintf(stderr, "sampled token %u (`%s`)\n", id, common_token_to_piece(ctx, id).c_str()); + // fflush(stderr); return id; } diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index bab5141c8a74d..ec0f2f5bce838 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1357,8 +1357,8 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a llama_grammar_candidates candidates_grammar; candidates_grammar.reserve(cur_p->size); - fprintf(stderr, "# Grammar sampling %zu tokens\n", cur_p->size); - fflush(stderr); + // fprintf(stderr, "# Grammar sampling %zu tokens\n", cur_p->size); + // fflush(stderr); for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; @@ -1372,12 +1372,12 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a } else if (piece.empty() || piece[0] == 0) { cur_p->data[i].logit = -INFINITY; } else if (idx != std::string::npos && grammar.partial_utf8.n_remain == 0 && !accepted_ranges.contains(idx)) { - fprintf(stderr, "- Rejecting masked token %u (`%s`)\n", id, piece.c_str()); - fflush(stderr); + // fprintf(stderr, "- Rejecting masked token %u (`%s`)\n", id, piece.c_str()); + // fflush(stderr); cur_p->data[i].logit = -INFINITY; } else { - fprintf(stderr, "- Considering unmasked token %u (`%s`)\n", id, piece.c_str()); - fflush(stderr); + // fprintf(stderr, "- Considering unmasked token %u (`%s`)\n", id, piece.c_str()); + // fflush(stderr); candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } @@ -1386,8 +1386,8 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { auto & token = cur_p->data[reject.index]; - fprintf(stderr, "- Rejecting token %u (`%s`)\n", token.id, grammar.vocab->token_to_piece(token.id).c_str()); - fflush(stderr); + // fprintf(stderr, "- Rejecting token %u (`%s`)\n", token.id, grammar.vocab->token_to_piece(token.id).c_str()); + // fflush(stderr); cur_p->data[reject.index].logit = -INFINITY; } @@ -1398,8 +1398,8 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a if (token.logit == -INFINITY) { continue; } - fprintf(stderr, "- Accepted token %u (`%s`)\n", token.id, grammar.vocab->token_to_piece(token.id).c_str()); - fflush(stderr); + // fprintf(stderr, "- Accepted token %u (`%s`)\n", token.id, grammar.vocab->token_to_piece(token.id).c_str()); + // fflush(stderr); } } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 2d1dd920ace9e..6d5b62b753f74 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -1320,33 +1320,48 @@ static void test_token_ranges() { { token_ranges rngs; rngs += {10, 20}; - assert_equals(1, rngs.allowed_token_ranges.size()); assert_equals(false, rngs.contains(9)); assert_equals(true, rngs.contains(10)); assert_equals(true, rngs.contains(11)); assert_equals(true, rngs.contains(20)); assert_equals(false, rngs.contains(21)); + + assert_equals(1, rngs.allowed_token_ranges.size()); } { token_ranges rngs; rngs += {10, 20}; rngs += {15, 25}; - assert_equals(1, rngs.allowed_token_ranges.size()); assert_equals(false, rngs.contains(9)); assert_equals(true, rngs.contains(10)); assert_equals(true, rngs.contains(25)); assert_equals(false, rngs.contains(26)); + + assert_equals(1, rngs.allowed_token_ranges.size()); } { token_ranges rngs; rngs += {10, 20}; rngs += {30, 40}; - assert_equals(2, rngs.allowed_token_ranges.size()); assert_equals(false, rngs.contains(9)); assert_equals(true, rngs.contains(10)); assert_equals(false, rngs.contains(29)); assert_equals(true, rngs.contains(30)); assert_equals(false, rngs.contains(41)); + + assert_equals(2, rngs.allowed_token_ranges.size()); + } + { + token_ranges rngs; + rngs += 10; + rngs.invert(100); + assert_equals(true, rngs.contains(9)); + assert_equals(false, rngs.contains(10)); + assert_equals(true, rngs.contains(11)); + + assert_equals(2, rngs.allowed_token_ranges.size()); + assert_equals(0, rngs.allowed_token_ranges[0].from_sorted_index); + assert_equals(99, rngs.allowed_token_ranges[1].to_sorted_index); } } From fbefc2c437343d36dd4290bd7a6dd3528ba8e699 Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 20 Feb 2025 20:21:09 +0000 Subject: [PATCH 34/43] skip special tokens --- common/sampling.cpp | 2 +- src/llama-grammar.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index f1855104f167c..ce978166cc042 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -342,7 +342,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co auto & chain = gsmpl->chain; auto & cur_p = gsmpl->cur_p; // initialized by set_logits - grammar_first = true; + // grammar_first = true; if (grammar_first) { llama_sampler_apply(grmr, &cur_p); } diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index ec0f2f5bce838..510a29801b078 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1371,7 +1371,7 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a } } else if (piece.empty() || piece[0] == 0) { cur_p->data[i].logit = -INFINITY; - } else if (idx != std::string::npos && grammar.partial_utf8.n_remain == 0 && !accepted_ranges.contains(idx)) { + } else if (idx != std::string::npos && grammar.partial_utf8.n_remain == 0 && !accepted_ranges.contains(idx) && grammar.vocab->is_normal(id)) { // fprintf(stderr, "- Rejecting masked token %u (`%s`)\n", id, piece.c_str()); // fflush(stderr); cur_p->data[i].logit = -INFINITY; From b0a614b41b0092f558f5d0c6c598a64da374f3f4 Mon Sep 17 00:00:00 2001 From: ochafik Date: Fri, 21 Feb 2025 14:43:33 +0000 Subject: [PATCH 35/43] some fixes --- src/llama-grammar.cpp | 49 ++++++++++++++++++++++++------ src/llama-grammar.h | 70 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 109 insertions(+), 10 deletions(-) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 510a29801b078..3c3cb2ed74fac 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -622,6 +622,7 @@ static const token_ranges & llama_grammar_match_tokens( // const std::vector & sorted_tokens, // std::unordered_map & allowed_tokens) { + const auto orig_pos = pos; const auto & sorted_tokens = grammar.sorted_tokens; auto it = grammar.allowed_tokens.find(pos); @@ -635,6 +636,8 @@ static const token_ranges & llama_grammar_match_tokens( size_t char_offset, token_ranges & out) { + GGML_ASSERT(rng.from_sorted_index <= rng.to_sorted_index); + auto sorted_begin = sorted_tokens.begin() + rng.from_sorted_index; auto sorted_end = sorted_tokens.begin() + rng.to_sorted_index + 1; auto find_lower = [&](uint32_t chr) { @@ -645,14 +648,21 @@ static const token_ranges & llama_grammar_match_tokens( // return token_codepoints.codepoints.first[char_offset] < chr; // }); for (auto it = sorted_begin; it != sorted_end; it++) { - if ((*it).codepoints.first[char_offset] >= chr) { + if ((*it).codepoints.first.size() <= char_offset || !(*it).codepoints.first[char_offset]) { + // return it; + continue; + } + if ((*it).codepoints.first[char_offset] <= chr) { return it; } } return sorted_tokens.end(); }; - auto find_upper = [&](uint32_t chr) { - for (auto it = sorted_begin; it != sorted_end; it++) { + auto find_upper = [&](uint32_t chr, const std::vector::const_iterator & from) { + for (auto it = from; it != sorted_end; it++) { + if ((*it).codepoints.first.size() <= char_offset || !(*it).codepoints.first[char_offset]) { + continue; + } if ((*it).codepoints.first[char_offset] > chr) { return it; } @@ -669,7 +679,7 @@ static const token_ranges & llama_grammar_match_tokens( auto high_chr = pos[1].value; auto low = find_lower(low_chr); if (low != sorted_tokens.end()) { - auto high = find_upper(high_chr); + auto high = find_upper(high_chr, low); res += { static_cast(low - sorted_tokens.begin()), high == sorted_tokens.end() @@ -687,7 +697,7 @@ static const token_ranges & llama_grammar_match_tokens( auto chr = pos->value; auto low = find_lower(chr); if (low != sorted_tokens.end()) { - auto high = find_upper(chr); + auto high = find_upper(chr, low); res += { static_cast(low - sorted_tokens.begin()), high == sorted_tokens.end() @@ -715,7 +725,8 @@ static const token_ranges & llama_grammar_match_tokens( auto next_offset = char_offset + 1; for (const auto & rng : res.allowed_token_ranges) { - if (sorted_tokens[rng.from_sorted_index].codepoints.first.size() == next_offset) { + auto & first = sorted_tokens[rng.from_sorted_index]; + if (first.codepoints.first.size() == next_offset + 1) { // extra 1 for \0 out += rng.from_sorted_index; if (rng.from_sorted_index != rng.to_sorted_index) { explore_rng({rng.from_sorted_index + 1, rng.to_sorted_index}, pos, next_offset, out); @@ -727,7 +738,7 @@ static const token_ranges & llama_grammar_match_tokens( } }; - auto & rngs = grammar.allowed_tokens[pos]; + auto & rngs = grammar.allowed_tokens[orig_pos]; explore_rng({0, sorted_tokens.size() - 1}, pos, 0, rngs); @@ -1119,6 +1130,7 @@ struct llama_grammar * llama_grammar_init_impl( std::sort(sorted_tokens.begin(), sorted_tokens.end(), [](const llama_grammar_token & a, const llama_grammar_token & b) { return a.codepoints.first < b.codepoints.first; + // return a.piece < b.piece; }); sorted_tokens_indices.resize(sorted_tokens.size()); for (size_t i = 0; i < sorted_tokens.size(); i++) { @@ -1341,14 +1353,32 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a token_ranges accepted_ranges; if (!grammar.sorted_tokens_indices.empty()) { + std::vector ranges; for (const auto & stack : grammar.stacks) { if (stack.empty()) { allow_eog = true; // break; } else { - accepted_ranges += llama_grammar_match_tokens(grammar, stack.back()); + ranges.push_back(&llama_grammar_match_tokens(grammar, stack.back())); + } + } + accepted_ranges.union_all(ranges); + if (accepted_ranges.empty()) { + LLAMA_LOG_ERROR("No accepted ranges\n"); + for (const auto & stack : grammar.stacks) { + if (stack.empty()) { + LLAMA_LOG_ERROR(" EOG\n"); + } else { + LLAMA_LOG_ERROR(" %d\n", stack.back()->type); + auto & res = llama_grammar_match_tokens(grammar, stack.back()); + LLAMA_LOG_ERROR("res: %d\n", res.allowed_token_ranges.size()); + } + // for (const auto & elem : stack) { + // LLAMA_LOG_ERROR(" %d", elem->type); + // } } } + // GGML_ASSERT(!accepted_ranges.empty()); } std::vector, llama_partial_utf8>> candidates_decoded; @@ -1360,6 +1390,7 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a // fprintf(stderr, "# Grammar sampling %zu tokens\n", cur_p->size); // fflush(stderr); + // TODO: iterate on accepted_ranges & cur_p in lockstep sorted order for (size_t i = 0; i < cur_p->size; ++i) { const llama_token id = cur_p->data[i].id; auto idx = grammar.sorted_tokens_indices.empty() ? std::string::npos : grammar.sorted_tokens_indices[id]; @@ -1371,7 +1402,7 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a } } else if (piece.empty() || piece[0] == 0) { cur_p->data[i].logit = -INFINITY; - } else if (idx != std::string::npos && grammar.partial_utf8.n_remain == 0 && !accepted_ranges.contains(idx) && grammar.vocab->is_normal(id)) { + } else if (idx != std::string::npos && grammar.partial_utf8.n_remain == 0 && !accepted_ranges.empty() && !accepted_ranges.contains(idx) && grammar.vocab->is_normal(id)) { // fprintf(stderr, "- Rejecting masked token %u (`%s`)\n", id, piece.c_str()); // fflush(stderr); cur_p->data[i].logit = -INFINITY; diff --git a/src/llama-grammar.h b/src/llama-grammar.h index d33400ca5e197..85f819d39cded 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include struct llama_vocab; @@ -55,16 +57,19 @@ struct llama_grammar_candidate { llama_partial_utf8 partial_utf8; }; -//* struct token_range { size_t from_sorted_index; size_t to_sorted_index; }; + struct token_ranges { std::vector allowed_token_ranges; std::vector allowed_pieces; + bool empty() const { + return allowed_token_ranges.empty(); + } void fetch_pieces_for_debug(const std::vector & sorted_tokens); void invert(size_t size) { @@ -87,6 +92,68 @@ struct token_ranges { allowed_token_ranges.swap(new_ranges); } + void union_all(const std::vector & ranges) { + if (!ranges.empty()) { + if (ranges.size() == 1) { + *this += *ranges.front(); + } else { + // Priority queue to merge ranges + std::vector merged_ranges; + struct queue_item { + token_range range; + std::vector::const_iterator next; + std::vector::const_iterator end; + + bool operator>(const queue_item & other) const { + return range.from_sorted_index > other.range.from_sorted_index; + } + }; + std::priority_queue, std::greater> pq; + + // Initialize priority queue with first range from each input + for (const auto* r : ranges) { + if (!r->allowed_token_ranges.empty()) { + pq.push({r->allowed_token_ranges.front(), + r->allowed_token_ranges.begin() + 1, + r->allowed_token_ranges.end()}); + } + } + + // Merge ranges on the fly + while (!pq.empty()) { + auto top = pq.top(); + pq.pop(); + + // Merge with previous range if possible + if (!merged_ranges.empty() && + merged_ranges.back().to_sorted_index + 1 >= top.range.from_sorted_index) { + merged_ranges.back().to_sorted_index = + std::max(merged_ranges.back().to_sorted_index, top.range.to_sorted_index); + } else { + merged_ranges.push_back(top.range); + } + + // Add next range from the same input if available + if (top.next != top.end) { + pq.push({*top.next, top.next + 1, top.end}); + } + } + + allowed_token_ranges = std::move(merged_ranges); + } + + // union all debug pieces + std::set pieces(allowed_pieces.begin(), allowed_pieces.end()); + for (const auto & rng : ranges) { + for (const auto & piece : rng->allowed_pieces) { + pieces.insert(piece); + } + } + allowed_pieces.clear(); + allowed_pieces.insert(allowed_pieces.end(), pieces.begin(), pieces.end()); + } + } + // Helper function to merge overlapping or adjacent ranges void merge_ranges() { if (allowed_token_ranges.empty()) { @@ -119,6 +186,7 @@ struct token_ranges { } token_ranges& operator+=(const token_range& other) { + GGML_ASSERT(other.from_sorted_index <= other.to_sorted_index); allowed_token_ranges.push_back(other); merge_ranges(); return *this; From a13bbecc79ef62099e9610299490da0bca455371 Mon Sep 17 00:00:00 2001 From: ochafik Date: Fri, 21 Feb 2025 14:44:06 +0000 Subject: [PATCH 36/43] logs --- common/sampling.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index ce978166cc042..b22a53dffd71f 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -354,8 +354,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co const llama_token id = cur_p.data[cur_p.selected].id; if (grammar_first) { - // fprintf(stderr, "sampled token %u (`%s`)\n", id, common_token_to_piece(ctx, id).c_str()); - // fflush(stderr); + // fprintf(stderr, "Sampled token %u (`%s`)\n", id, common_token_to_piece(ctx, id).c_str()); return id; } @@ -381,6 +380,11 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration"); + { + auto id = cur_p.data[cur_p.selected].id; + fprintf(stderr, "Sampled token %u (`%s`)\n", id, common_token_to_piece(ctx, id).c_str()); + } + return cur_p.data[cur_p.selected].id; } From edbc561a502093e9792d6dfae4a7406d2f746035 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Fri, 21 Feb 2025 15:34:49 +0000 Subject: [PATCH 37/43] more lenient test_weather (space instead of comma) --- examples/server/tests/unit/test_tool_call.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/tests/unit/test_tool_call.py b/examples/server/tests/unit/test_tool_call.py index e02cb83876890..307cf6f0e33ac 100755 --- a/examples/server/tests/unit/test_tool_call.py +++ b/examples/server/tests/unit/test_tool_call.py @@ -370,7 +370,7 @@ def do_test_weather(server: ServerProcess, **kwargs): assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" location = actual_arguments["location"] assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}" - assert re.match('^Istanbul(, ?(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' + assert re.match('^Istanbul(( |, ?)(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' @pytest.mark.slow From 230b0097d632e00d25e347c2d0ef18cfc39191c2 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Fri, 21 Feb 2025 15:35:13 +0000 Subject: [PATCH 38/43] tool_bench: script + baseline support --- examples/server/tests/utils.py | 5 ++- scripts/tool_bench.py | 74 ++++++++++++++++++++-------------- scripts/tool_bench.sh | 28 +++++++++++++ 3 files changed, 75 insertions(+), 32 deletions(-) create mode 100644 scripts/tool_bench.sh diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index bb97178dcb4cb..d7eec81e9aa8b 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -81,6 +81,7 @@ class ServerProcess: reasoning_format: Literal['deepseek', 'none'] | None = None chat_template: str | None = None chat_template_file: str | None = None + server_path: str | None = None # session variables process: subprocess.Popen | None = None @@ -94,7 +95,9 @@ def __init__(self): self.server_port = int(os.environ["PORT"]) def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: - if "LLAMA_SERVER_BIN_PATH" in os.environ: + if self.server_path is not None: + server_path = self.server_path + elif "LLAMA_SERVER_BIN_PATH" in os.environ: server_path = os.environ["LLAMA_SERVER_BIN_PATH"] elif os.name == "nt": server_path = "../../../build/bin/Release/llama-server.exe" diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index 42a9095ffcde8..bc8e90fbe34ce 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -3,7 +3,6 @@ # requires-python = ">=3.10" # dependencies = [ # "pytest", -# "numpy", # "pandas", # "matplotlib", # "seaborn", @@ -13,16 +12,19 @@ # ] # /// ''' + See ./scripts/tool_bench.sh for example usage. + cmake --build build -j && ( \ export RETRIES=3 ; + export CONSTRAIN_PYTHON_TOOL_CODE=0 ; export LLAMA_CACHE=$HOME/Library/Caches/llama.cpp ; export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server ; - export ARGS=( --n 30 --temp -1 --temp 0 --temp 0.5 --temp 0.75 --temp 1 --temp 1.5 --temp 2 --temp 5 ) ; + export ARGS=( --llama-baseline=/opt/homebrew/bin/llama-server --n 30 --temp -1 --temp 0 --temp 0.5 --temp 0.75 --temp 1 --temp 1.5 --temp 2 --temp 5 ) ; ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ; + ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ; ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --output llama1b.jsonl --hf bartowski/Llama-3.2-1B-Instruct-GGUF --ollama llama3.2:1b-instruct-q4_K_M ; ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --output llama3b.jsonl --hf bartowski/Llama-3.2-3B-Instruct-GGUF --ollama llama3.1:3b ; ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M" --output llama8b.jsonl --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF --ollama llama3.1:8b ; - ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ; ./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 7B Q4_K_M" --output qwen7b.jsonl --hf bartowski/Qwen2.5-7B-Instruct-GGUF ; ./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.3 Instruct 70B Q4_K_M" --output llama70b.jsonl --hf bartowski/Llama-3.3-70B-Instruct-GGUF ; ./scripts/tool_bench.py run ${ARGS[@]} --model "Mistral Nemo 2407 Q4_K_M" --output nemo.jsonl --hf bartowski/Mistral-Nemo-Instruct-2407-GGUF --ollama mistral-nemo:12b ; @@ -117,6 +119,7 @@ def plot(files: List[Path], output: Optional[Path] = None): temps = set() tests = set() impls = set() + total_counts = set() for rec in lines: try: model = rec["model"] @@ -124,7 +127,10 @@ def plot(files: List[Path], output: Optional[Path] = None): impl = rec["implementation"] test = rec["test"] success = rec["success_ratio"] - + success_count = rec["success_count"] + failure_count = rec["failure_count"] + total_count = success_count + failure_count + total_counts.add(total_count) data_dict[(model, temp, impl, test)] = success @@ -137,6 +143,9 @@ def plot(files: List[Path], output: Optional[Path] = None): except KeyError as e: logger.warning(f"Missing required field in record: {e}") + if len(total_counts) > 1: + logger.warning(f"Total counts are not consistent: {total_counts}") + # Sort the collected values temps = list(sorted(temps, key=lambda x: x if x is not None else -1)) tests = list(sorted(tests)) @@ -179,7 +188,7 @@ def plot(files: List[Path], output: Optional[Path] = None): cbar_kws={"label": "Success Ratio"}, ) - plt.title("Tool Call Bench\nSuccess Ratios by Implementation & Test", pad=20) + plt.title(f"Tool Call Bench (n = {str(min(total_counts)) if len(total_counts) == 1 else f'{min(total_counts)}-{max(total_counts)}'})\nSuccess Ratios by Implementation & Test", pad=20) plt.xlabel("Implementation and Test", labelpad=10) plt.ylabel("Model @ Temperature", labelpad=10) @@ -201,6 +210,7 @@ def run( hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None, chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None, ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None, + llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None, n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10, temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test")] = None, top_p: Annotated[Optional[float], typer.Option(help="top_p")] = None, @@ -281,32 +291,34 @@ def elapsed(): for t in [None] if temp is None else [t if t >= 0 else None for t in temp]: if hf is not None: - server = ServerProcess() - server.n_slots = 1 - server.jinja = True - server.n_predict = 512 # High because of DeepSeek R1 - server.model_hf_repo = hf - server.model_hf_file = None - server.chat_template = chat_template - if port is not None: - server.server_port = port - # server.debug = True - - with scoped_server(server): - server.start(timeout_seconds=TIMEOUT_SERVER_START) - for ignore_chat_grammar in [False]: - run( - server, - implementation="llama-server" + (" (no grammar)" if ignore_chat_grammar else ""), - model_id=hf, - temp=t, - output_kwargs=dict( - chat_template=chat_template, - ), - request_kwargs=dict( - ignore_chat_grammar=ignore_chat_grammar, - ), - ) + for implementation, server_path in [('llama-server', None)] if llama_baseline is None else [('llama-server (baseline)', llama_baseline), ('llama-server', None)]: + server = ServerProcess() + server.n_slots = 1 + server.jinja = True + server.n_predict = 512 # High because of DeepSeek R1 + server.model_hf_repo = hf + server.model_hf_file = None + server.chat_template = chat_template + server.server_path = server_path + if port is not None: + server.server_port = port + # server.debug = True + + with scoped_server(server): + server.start(timeout_seconds=TIMEOUT_SERVER_START) + for ignore_chat_grammar in [False]: + run( + server, + implementation=implementation, + model_id=hf, + temp=t, + output_kwargs=dict( + chat_template=chat_template, + ), + request_kwargs=dict( + ignore_chat_grammar=ignore_chat_grammar, + ), + ) if ollama is not None: server = ServerProcess() diff --git a/scripts/tool_bench.sh b/scripts/tool_bench.sh new file mode 100644 index 0000000000000..b833e16cad211 --- /dev/null +++ b/scripts/tool_bench.sh @@ -0,0 +1,28 @@ +#!/bin/bash +set -euo pipefail + +cmake --build build -j + +export RETRIES=3 +export CONSTRAIN_PYTHON_TOOL_CODE=0 +export LLAMA_MASK=1 +export LLAMA_CACHE=$HOME/Library/Caches/llama.cpp +export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server +export ARGS=( + --llama-baseline=/opt/homebrew/bin/llama-server + --n 30 + --temp -1 + --temp 0 + --temp 0.5 + --temp 0.75 + --temp 1 + --temp 1.5 + --temp 2 + --temp 5 +) + +./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 1B Q4_K_M" --output llama1b.jsonl --hf bartowski/Llama-3.2-1B-Instruct-GGUF --ollama llama3.2:1b-instruct-q4_K_M ; +./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.2 Instruct 3B Q4_K_M" --output llama3b.jsonl --hf bartowski/Llama-3.2-3B-Instruct-GGUF --ollama llama3.1:3b ; +./scripts/tool_bench.py run ${ARGS[@]} --model "Llama 3.1 Instruct 8B Q4_K_M" --output llama8b.jsonl --hf bartowski/Meta-Llama-3.1-8B-Instruct-GGUF --ollama llama3.1:8b ; +./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b ; +./scripts/tool_bench.py run ${ARGS[@]} --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M ; From 6ace7dd42bc14e544d1aa9e524bbcccd8fb4f4fc Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Fri, 21 Feb 2025 17:00:42 +0000 Subject: [PATCH 39/43] Update tool_bench.py --- scripts/tool_bench.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index bc8e90fbe34ce..88de6c5bed58f 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -224,7 +224,7 @@ def run( n_predict = 512 - assert force or not output.exists(), f"Output file already exists: {output}; use --force to overwrite" + assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite" with output.open('a' if append else 'w') as output_file: From f499e637555822c6a86a1ef90e5362c04b4963dd Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Fri, 21 Feb 2025 17:40:17 +0000 Subject: [PATCH 40/43] skip calc result test by default (not exactly tool call) --- scripts/tool_bench.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index 88de6c5bed58f..e952e21ac637a 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -32,7 +32,7 @@ ./scripts/tool_bench.py run ${ARGS[@]} --model "DeepSeek R1 Distill Qwen 1.5B Q4_K_M" --output dsqw1.5b.jsonl --hf bartowski/DeepSeek-R1-Distill-Qwen-1.5B-GGUF --ollama deepseek-r1:1.5b ; ) - ./scripts/tool_bench.py plot qwen1.5b.jsonl + ./scripts/tool_bench.py plot qwen1.5b.jsonl --test-regex 'hello|weather' ./scripts/tool_bench.py plot *.jsonl --output all.png for f in *.jsonl; do @@ -44,6 +44,7 @@ from contextlib import contextmanager from pathlib import Path from pathlib import Path +import re from statistics import mean, median from typing import Annotated, List, Optional from typing import Dict, List, Tuple, Set, Any @@ -86,7 +87,7 @@ def stop(): app = typer.Typer() @app.command() -def plot(files: List[Path], output: Optional[Path] = None): +def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None): lines: List[Dict] = [] for file in files: @@ -132,6 +133,9 @@ def plot(files: List[Path], output: Optional[Path] = None): total_count = success_count + failure_count total_counts.add(total_count) + if test_regex and not re.match(test_regex, test): + continue + data_dict[(model, temp, impl, test)] = success if model not in models: @@ -219,6 +223,10 @@ def run( port: Annotated[int, typer.Option(help="llama-server port")] = 8084, force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False, append: Annotated[bool, typer.Option(help="Append to output file")] = False, + + test_hello_world: Annotated[bool, typer.Option(help="Whether to run the hello world test")] = True, + test_weather: Annotated[bool, typer.Option(help="Whether to run the weather test")] = True, + test_calc_result: Annotated[bool, typer.Option(help="Whether to run the calc result test")] = False, ): # Check only one of output and append @@ -241,11 +249,14 @@ def run(server: ServerProcess, *, implementation: str, model_id: str, temp: floa request_kwargs['cache_prompt'] = False - tests = { - "hello world": lambda server: do_test_hello_world(server, **request_kwargs), - "weather": lambda server: do_test_weather(server, **request_kwargs), - "calc result": lambda server: do_test_calc_result(server, None, 512, **request_kwargs), - } + tests = {} + if test_hello_world: + tests["hello world"] = lambda server: do_test_hello_world(server, **request_kwargs) + if test_weather: + tests["weather"] = lambda server: do_test_weather(server, **request_kwargs) + if test_calc_result: + tests["calc result"] = lambda server: do_test_calc_result(server, None, 512, **request_kwargs) + for test_name, test in tests.items(): success_count = 0 failure_count = 0 From 6f870b69dedf300911778ec050f80550bfa30dac Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Fri, 21 Feb 2025 18:17:04 +0000 Subject: [PATCH 41/43] add --impl-regex filter flag to plot --- scripts/tool_bench.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/tool_bench.py b/scripts/tool_bench.py index e952e21ac637a..88233dd3537fd 100755 --- a/scripts/tool_bench.py +++ b/scripts/tool_bench.py @@ -33,6 +33,7 @@ ) ./scripts/tool_bench.py plot qwen1.5b.jsonl --test-regex 'hello|weather' + ./scripts/tool_bench.py plot llama*.jsonl --test-regex 'hello|weather' --impl-regex '^(llama-server|ollama)$' ./scripts/tool_bench.py plot *.jsonl --output all.png for f in *.jsonl; do @@ -87,7 +88,7 @@ def stop(): app = typer.Typer() @app.command() -def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None): +def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None, impl_regex: Optional[str] = None): lines: List[Dict] = [] for file in files: @@ -133,7 +134,10 @@ def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[ total_count = success_count + failure_count total_counts.add(total_count) - if test_regex and not re.match(test_regex, test): + if test_regex and not re.search(test_regex, test): + continue + + if impl_regex and not re.search(impl_regex, impl): continue data_dict[(model, temp, impl, test)] = success From e82728204df3a479c63ad8365371ccd71da7ea1d Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Tue, 4 Mar 2025 14:13:21 +0000 Subject: [PATCH 42/43] common_grammar_trigger: always use string value (+ optional token) instead of. variant --- common/common.h | 3 ++- common/sampling.cpp | 6 +++--- examples/server/server.cpp | 8 ++++---- examples/server/utils.hpp | 4 +--- tests/test-chat.cpp | 6 +++--- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/common/common.h b/common/common.h index f1fbd51e77702..5d5b81716acd7 100644 --- a/common/common.h +++ b/common/common.h @@ -119,7 +119,8 @@ enum common_grammar_trigger_type { struct common_grammar_trigger { common_grammar_trigger_type type; - std::variant value; + std::string value; + llama_token token = LLAMA_TOKEN_NULL; }; // sampling parameters diff --git a/common/sampling.cpp b/common/sampling.cpp index b22a53dffd71f..ee51d7b8d0780 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -168,20 +168,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co switch (trigger.type) { case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: { - const auto & word = std::get(trigger.value); + const auto & word = trigger.value; patterns_anywhere.push_back(regex_escape(word)); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START: { - const auto & pattern = std::get(trigger.value); + const auto & pattern = trigger.value; (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START ? patterns_at_start : patterns_anywhere).push_back(pattern); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: { - const auto & token = std::get(trigger.value); + const auto token = trigger.token; trigger_tokens.push_back(token); break; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6d62959c00823..ef00975d0cd6c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -135,16 +135,16 @@ struct slot_params { for (const auto & trigger : sampling.grammar_triggers) { switch (trigger.type) { case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: - grammar_triggers.push_back({{"word", std::get(trigger.value)}}); + grammar_triggers.push_back({{"word", trigger.value}}); break; case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: - grammar_triggers.push_back({{"pattern", std::get(trigger.value)}}); + grammar_triggers.push_back({{"pattern", trigger.value}}); break; case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START: - grammar_triggers.push_back({{"pattern_start", std::get(trigger.value)}}); + grammar_triggers.push_back({{"pattern_start", trigger.value}}); break; case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: - grammar_triggers.push_back({{"token", std::get(trigger.value)}}); + grammar_triggers.push_back({{"token", trigger.token}}); break; } } diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index e6a61c17ab0c2..0c46ee6bc966a 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -622,9 +622,7 @@ static json oaicompat_completion_params_parse( for (const auto & trigger : chat_params.grammar_triggers) { grammar_triggers.push_back({ {"type", (int) trigger.type}, - {"value", trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN - ? json((int) std::get(trigger.value)) - : json(std::get(trigger.value))}, + {"value", trigger.token}, }); } llama_params["grammar_triggers"] = grammar_triggers; diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 4fb8198374e71..35a307c632e68 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -242,13 +242,13 @@ static void test_templates(const struct common_chat_templates * tmpls, const std switch (trigger.type) { case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: { - const auto & word = std::get(trigger.value); + const auto & word = trigger.value; pos = constrained.find(word); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: { - const auto & pattern = std::get(trigger.value); + const auto & pattern = trigger.value; if (std::regex_search(constrained, match, std::regex(pattern))) { pos = match.position(); } @@ -256,7 +256,7 @@ static void test_templates(const struct common_chat_templates * tmpls, const std } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_START: { - const auto & pattern = std::get(trigger.value); + const auto & pattern = trigger.value; if (std::regex_search(constrained, match, std::regex(pattern)) && match.position() == 0) { pos = 0; } From ba7b185a35204b00f43b72a68b13860cb7e84592 Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Tue, 4 Mar 2025 19:00:01 +0000 Subject: [PATCH 43/43] add llama_grammar_trigger_pattern --- examples/server/server.cpp | 2 +- src/llama-grammar.cpp | 25 ++++++++++++------------- src/llama-grammar.h | 7 ++++++- src/llama-sampling.cpp | 4 ++-- 4 files changed, 21 insertions(+), 17 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ef00975d0cd6c..fb041b0749924 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -397,7 +397,7 @@ struct server_task { throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); } SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); - params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, token}); + params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN, word, token}); } else { SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 1f56dae2aec91..a66c4f1002d21 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1141,7 +1141,7 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - auto grammar = new llama_grammar { + auto * grammar = new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), @@ -1242,14 +1242,16 @@ struct llama_grammar * llama_grammar_init_impl( } while (true); std::vector vec_trigger_tokens; - std::vector> vec_trigger_patterns; + std::vector vec_trigger_patterns; for (size_t i = 0; i < num_trigger_tokens; i++) { GGML_ASSERT(trigger_tokens != nullptr); vec_trigger_tokens.push_back(trigger_tokens[i]); } for (size_t i = 0; i < num_trigger_patterns; i++) { GGML_ASSERT(trigger_patterns != nullptr); - vec_trigger_patterns.emplace_back(trigger_patterns[i], trigger_patterns[i]); + auto & trigger = vec_trigger_patterns.back(); + trigger.pattern = trigger_patterns[i]; + trigger.regex = std::regex(trigger.pattern); } std::vector sorted_tokens; @@ -1259,7 +1261,7 @@ struct llama_grammar * llama_grammar_init_impl( if (mask && vocab) { printf("Masking %d tokens\n", llama_vocab_n_tokens(vocab)); for (size_t i = 0, n = llama_vocab_n_tokens(vocab); i < n; i++) { - auto & piece = vocab->token_to_piece(i); + const auto & piece = vocab->token_to_piece(i); sorted_tokens.push_back({ (llama_token) i, piece, @@ -1279,7 +1281,7 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - auto grammar = new llama_grammar { + auto * grammar = new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), @@ -1311,7 +1313,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { } struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { - llama_grammar * result = new llama_grammar { + auto * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, @@ -1370,8 +1372,8 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a LLAMA_LOG_ERROR(" EOG\n"); } else { LLAMA_LOG_ERROR(" %d\n", stack.back()->type); - auto & res = llama_grammar_match_tokens(grammar, stack.back()); - LLAMA_LOG_ERROR("res: %d\n", res.allowed_token_ranges.size()); + const auto & res = llama_grammar_match_tokens(grammar, stack.back()); + LLAMA_LOG_ERROR("res: %zu\n", res.allowed_token_ranges.size()); } // for (const auto & elem : stack) { // LLAMA_LOG_ERROR(" %d", elem->type); @@ -1416,9 +1418,6 @@ void llama_grammar_apply_impl(struct llama_grammar & grammar, llama_token_data_a const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { - auto & token = cur_p->data[reject.index]; - // fprintf(stderr, "- Rejecting token %u (`%s`)\n", token.id, grammar.vocab->token_to_piece(token.id).c_str()); - // fflush(stderr); cur_p->data[reject.index].logit = -INFINITY; } @@ -1450,8 +1449,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token grammar.trigger_buffer += piece; std::smatch match; - for (const auto & [_, regex] : grammar.trigger_patterns) { - if (std::regex_match(grammar.trigger_buffer, match, regex)) { + for (const auto & trigger_pattern : grammar.trigger_patterns) { + if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { grammar.awaiting_trigger = false; // get from the first match to the end of the string auto constrained_str = grammar.trigger_buffer.substr(match.position(1)); diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 85f819d39cded..297bcef42e615 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -279,6 +279,11 @@ struct llama_grammar_token { std::pair, llama_partial_utf8> codepoints; }; +struct llama_grammar_trigger_pattern { + std::string pattern; + std::regex regex; +}; + struct llama_grammar { // note: allow null vocab for testing (not great) const llama_vocab * vocab; @@ -300,7 +305,7 @@ struct llama_grammar { bool awaiting_trigger = false; // Initialized to true for lazy grammars only std::string trigger_buffer; // Output buffered by lazy grammar. Will be cleared once trigger is found. std::vector trigger_tokens; // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special). - std::vector> + std::vector trigger_patterns; // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated // string, and the grammar will be given the string from the first match group onwards. diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 388998d949d29..c25977ca3a35c 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1461,8 +1461,8 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { std::vector trigger_patterns_c; trigger_patterns_c.reserve(ctx->grammar->trigger_patterns.size()); - for (auto & [pattern, _] : ctx->grammar->trigger_patterns) { - trigger_patterns_c.push_back(pattern.c_str()); + for (auto & trigger_pattern : ctx->grammar->trigger_patterns) { + trigger_patterns_c.push_back(trigger_pattern.pattern.c_str()); } auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str(),