diff --git a/CMakeLists.txt b/CMakeLists.txt
index 847465e6..ec2b84a6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -23,7 +23,7 @@ FetchContent_MakeAvailable(json)
 FetchContent_Declare(
     llama.cpp
     GIT_REPOSITORY https://github.com/ggerganov/llama.cpp.git
-    GIT_TAG        b3534
+    GIT_TAG        b3751
 )
 FetchContent_MakeAvailable(llama.cpp)
 
diff --git a/src/main/cpp/jllama.cpp b/src/main/cpp/jllama.cpp
index d59f3b77..07eef014 100644
--- a/src/main/cpp/jllama.cpp
+++ b/src/main/cpp/jllama.cpp
@@ -386,8 +386,8 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
     LOG_INFO("build info", {{"build", LLAMA_BUILD_NUMBER}, {"commit", LLAMA_COMMIT}});
 
     LOG_INFO("system info", {
-                                {"n_threads", params.n_threads},
-                                {"n_threads_batch", params.n_threads_batch},
+                                {"n_threads", params.cpuparams.n_threads},
+                                {"n_threads_batch", params.cpuparams_batch.n_threads},
                                 {"total_threads", std::thread::hardware_concurrency()},
                                 {"system_info", llama_print_system_info()},
                             });
@@ -445,14 +445,10 @@ JNIEXPORT void JNICALL Java_de_kherud_llama_LlamaModel_loadModel(JNIEnv *env, jo
         });
     }
 
-    ctx_server->queue_tasks.on_new_task(
-        std::bind(&server_context::process_single_task, ctx_server, std::placeholders::_1));
-    ctx_server->queue_tasks.on_finish_multitask(
-        std::bind(&server_context::on_finish_multitask, ctx_server, std::placeholders::_1));
-    ctx_server->queue_tasks.on_update_slots(std::bind(&server_context::update_slots, ctx_server));
-    ctx_server->queue_results.on_multitask_update(std::bind(&server_queue::update_multitask, &ctx_server->queue_tasks,
-                                                            std::placeholders::_1, std::placeholders::_2,
-                                                            std::placeholders::_3));
+    ctx_server->queue_tasks.on_new_task(std::bind(
+        &server_context::process_single_task, ctx_server, std::placeholders::_1));
+    ctx_server->queue_tasks.on_update_slots(std::bind(
+        &server_context::update_slots, ctx_server));
 
     std::thread t([ctx_server]() {
         JNIEnv *env;
@@ -479,7 +475,11 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv
     std::string c_params = parse_jstring(env, jparams);
     json json_params = json::parse(c_params);
-    const bool infill = json_params.contains("input_prefix") || json_params.contains("input_suffix");
+
+    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    if (json_params.contains("input_prefix") || json_params.contains("input_suffix")) {
+        cmpl_type = SERVER_TASK_CMPL_TYPE_INFILL;
+    }
 
     if (json_params.value("use_chat_template", false))
     {
@@ -489,11 +489,18 @@ JNIEXPORT jint JNICALL Java_de_kherud_llama_LlamaModel_requestCompletion(JNIEnv
         json_params["prompt"] = format_chat(ctx_server->model, ctx_server->params.chat_template, chat);
     }
 
-    const int id_task = ctx_server->queue_tasks.get_new_id();
-    ctx_server->queue_results.add_waiting_task_id(id_task);
-    ctx_server->request_completion(id_task, -1, json_params, infill, false);
+    std::vector<server_task> tasks = ctx_server->create_tasks_cmpl(json_params, cmpl_type);
+    ctx_server->queue_results.add_waiting_tasks(tasks);
+    ctx_server->queue_tasks.post(tasks);
+
+    const auto task_ids = server_task::get_list_id(tasks);
+
+    if (task_ids.size() != 1) {
+        env->ThrowNew(c_llama_error, "multitasking currently not supported");
+        return 0;
+    }
 
-    return id_task;
+    return *task_ids.begin();
 }
 
 JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_receiveCompletion(JNIEnv *env, jobject obj, jint id_task)
@@ -555,20 +562,36 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
     const std::string prompt = parse_jstring(env, jprompt);
 
-    const int id_task = ctx_server->queue_tasks.get_new_id();
-    ctx_server->queue_results.add_waiting_task_id(id_task);
-    ctx_server->request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
+    std::vector<server_task> tasks = ctx_server->create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+    ctx_server->queue_results.add_waiting_tasks(tasks);
+    ctx_server->queue_tasks.post(tasks);
 
-    server_task_result result = ctx_server->queue_results.recv(id_task);
-    ctx_server->queue_results.remove_waiting_task_id(id_task);
-    if (result.error)
+    std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+    json responses = json::array();
+
+    json error = nullptr;
+    ctx_server->receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+        for (const auto & res : results) {
+            responses.push_back(res.data);
+        }
+    }, [&](const json& error_data) {
+        error = error_data;
+    });
+
+    if (error != nullptr) {
-        std::string response = result.data["message"].get<std::string>();
+        std::string response = error["message"].get<std::string>();
         env->ThrowNew(c_llama_error, response.c_str());
         return nullptr;
     }
 
-    std::vector<float> embedding = result.data["embedding"].get<std::vector<float>>();
+    if (responses.size() != 1) {
+        env->ThrowNew(c_llama_error, "could not compute embedding");
+        return nullptr;
+    }
+
+    std::vector<float> embedding = responses[0]["embedding"].get<std::vector<float>>();
 
     jsize embedding_size = embedding.size(); // NOLINT(*-narrowing-conversions)
     jfloatArray j_embedding = env->NewFloatArray(embedding_size);
diff --git a/src/main/cpp/server.hpp b/src/main/cpp/server.hpp
index 0601dac4..029721c1 100644
--- a/src/main/cpp/server.hpp
+++ b/src/main/cpp/server.hpp
@@ -1,6 +1,7 @@
 #include "utils.hpp"
 
 #include "common.h"
+#include "sampling.h"
 #include "grammar-parser.h"
 #include "llama.h"
 
@@ -10,11 +11,13 @@
 #include
 #include
 #include
-#include
 #include
-#include
-#include
 #include
+#include
+#include
+#include
+#include
+#include
 
 using json = nlohmann::ordered_json;
 
@@ -24,24 +27,18 @@ enum stop_type
     STOP_TYPE_PARTIAL,
 };
 
-enum slot_state
-{
+// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
+enum slot_state {
     SLOT_STATE_IDLE,
-    SLOT_STATE_PROCESSING,
-};
-
-enum slot_command
-{
-    SLOT_COMMAND_NONE,
-    SLOT_COMMAND_LOAD_PROMPT,
-    SLOT_COMMAND_RELEASE,
+    SLOT_STATE_PROCESSING_PROMPT,
+    SLOT_STATE_DONE_PROMPT,
+    SLOT_STATE_GENERATING,
 };
 
 enum server_state
 {
     SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
     SERVER_STATE_READY,         // Server is ready and model is loaded
-    SERVER_STATE_ERROR          // An error occurred, load_model failed
 };
 
 enum server_task_type
@@ -53,25 +50,41 @@ enum server_task_type
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
+    SERVER_TASK_TYPE_SET_LORA,
+};
+
+enum server_task_cmpl_type {
+    SERVER_TASK_CMPL_TYPE_NORMAL,
+    SERVER_TASK_CMPL_TYPE_EMBEDDING,
+    SERVER_TASK_CMPL_TYPE_INFILL,
 };
 
 struct server_task
 {
     int id = -1; // to be filled by server_queue
-    int id_multi = -1;
-    int id_target = -1;
+    int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
 
     server_task_type type;
     json data;
 
     bool infill = false;
     bool embedding = false;
+
+    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
+    // utility function
+    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
+        std::unordered_set<int> ids(tasks.size());
+        for (size_t i = 0; i < tasks.size(); i++) {
+            ids.insert(tasks[i].id);
+        }
+        return ids;
+    }
 };
 
 struct server_task_result
 {
     int id = -1;
-    int id_multi = -1;
 
     json data;
 
@@ -79,14 +92,6 @@ struct server_task_result
     bool error;
 };
 
-struct server_task_multi
-{
-    int id = -1;
-
-    std::set<int> subtasks_remaining;
-    std::vector<server_task_result> results;
-};
-
 struct slot_params
 {
     bool stream = true;
@@ -107,12 +112,13 @@ struct server_slot
 {
     int id;
     int id_task = -1;
-    int id_multi = -1;
+
+    // the index relative to completion multi-task request
+    size_t index = 0;
 
     struct slot_params params;
 
     slot_state state = SLOT_STATE_IDLE;
-    slot_command command = SLOT_COMMAND_NONE;
 
     // used to determine the slot that has been used the longest
     int64_t t_last_used = -1;
@@ -137,6 +143,7 @@ struct server_slot
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
+    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
     bool infill = false;
     bool embedding = false;
     bool has_next_token = true;
@@ -151,11 +158,13 @@ struct server_slot
     std::string stopping_word;
 
     // sampling
-    llama_token sampled;
-    struct llama_sampling_params sparams;
-    llama_sampling_context *ctx_sampling = nullptr;
     json json_schema;
+
+    struct gpt_sampler_params sparams;
+    struct gpt_sampler * smpl = nullptr;
+
+    llama_token sampled;
+
     int32_t ga_i = 0;   // group-attention state
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
@@ -172,6 +181,8 @@ struct server_slot
     double t_prompt_processing; // ms
     double t_token_generation;  // ms
 
+    std::function<void(int)> callback_on_release;
+
     void reset()
     {
         n_prompt_tokens = 0;
@@ -184,7 +195,7 @@ struct server_slot
         n_past = 0;
         n_sent_text = 0;
         n_sent_token_probs = 0;
-        infill = false;
+        cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
         ga_i = 0;
         n_past_se = 0;
 
@@ -212,20 +223,13 @@ struct server_slot
         return n_remaining > 0; // no budget
     }
 
-    bool available() const
-    {
-        return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
-    }
-
-    bool is_processing() const
-    {
-        return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING;
+    bool is_processing() const {
+        return state != SLOT_STATE_IDLE;
     }
 
     void add_token_string(const completion_token_output &token)
     {
-        if (command == SLOT_COMMAND_RELEASE)
-        {
+        if (!is_processing()) {
            return;
        }
        generated_token_probs.push_back(token);
@@ -233,10 +237,16 @@ struct server_slot
 
     void release()
     {
-        if (state == SLOT_STATE_PROCESSING)
-        {
+        if (is_processing()) {
             t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
-            command = SLOT_COMMAND_RELEASE;
+            state = SLOT_STATE_IDLE;
+            LOG_INFO("slot released", {
+                {"id_slot", id},
+                {"id_task", id_task},
+                {"n_past", n_past},
+                {"truncated", truncated},
+            });
+            callback_on_release(id);
         }
     }
 
@@ -353,6 +363,9 @@ struct server_metrics
     uint64_t n_tokens_predicted = 0;
     uint64_t t_tokens_generation = 0;
 
+    uint64_t n_decode_total = 0;
+    uint64_t n_busy_slots_total = 0;
+
     void init()
     {
         t_start = ggml_time_us();
@@ -374,8 +387,16 @@ struct server_metrics
         t_tokens_generation_total += slot.t_token_generation;
     }
 
-    void reset_bucket()
-    {
+    void on_decoded(const std::vector<server_slot> & slots) {
+        n_decode_total++;
+        for (const auto & slot : slots) {
+            if (slot.is_processing()) {
+                n_busy_slots_total++;
+            }
+        }
+    }
+
+    void reset_bucket() {
         n_prompt_tokens_processed = 0;
         t_prompt_processing = 0;
         n_tokens_predicted = 0;
@@ -389,38 +410,57 @@ struct server_queue
     bool running;
 
     // queues
-    std::vector<server_task> queue_tasks;
-    std::vector<server_task> queue_tasks_deferred;
-
-    std::vector<server_task_multi> queue_multitasks;
+    std::deque<server_task> queue_tasks;
+    std::deque<server_task> queue_tasks_deferred;
 
     std::mutex mutex_tasks;
     std::condition_variable condition_tasks;
 
     // callback functions
     std::function<void(server_task &)> callback_new_task;
-    std::function<void(server_task_multi &)> callback_finish_multitask;
     std::function<void(void)> callback_update_slots;
 
     // Add a new task to the end of the queue
-    int post(server_task task)
-    {
+    int post(server_task task, bool front = false) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         if (task.id == -1)
         {
             task.id = id++;
             LOG_VERBOSE("new task id", {{"new_id", task.id}});
         }
+        if (front) {
+            queue_tasks.push_front(std::move(task));
+        } else {
             queue_tasks.push_back(std::move(task));
+        }
         condition_tasks.notify_one();
         return task.id;
     }
 
+    // multi-task version of post()
+    int post(std::vector<server_task> & tasks, bool front = false) {
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        for (auto & task : tasks) {
+            if (task.id == -1) {
+                task.id = id++;
+                LOG_VERBOSE("new task id", {{"new_id", task.id}});
+            }
+            if (front) {
+                queue_tasks.push_front(std::move(task));
+            } else {
+                queue_tasks.push_back(std::move(task));
+            }
+        }
+        condition_tasks.notify_one();
+        return 0;
+    }
+
     // Add a new task, but defer until one slot is available
     void defer(server_task task)
     {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         queue_tasks_deferred.push_back(std::move(task));
+        condition_tasks.notify_one();
     }
 
     // Get the next id for creating a new task
@@ -438,28 +478,20 @@
         callback_new_task = std::move(callback);
     }
 
-    // Register function to process a multitask when it is finished
-    void on_finish_multitask(std::function<void(server_task_multi &)> callback)
-    {
-        callback_finish_multitask = std::move(callback);
-    }
-
     // Register the function to be called when all slots data is ready to be processed
     void on_update_slots(std::function<void(void)> callback)
     {
         callback_update_slots = std::move(callback);
     }
 
-    // Call when the state of one slot is changed
-    void notify_slot_changed()
-    {
-        // move deferred tasks back to main loop
+    // Call when the state of one slot is changed, it will move one task from deferred to main queue
+    void pop_deferred_task() {
         std::unique_lock<std::mutex> lock(mutex_tasks);
-        for (auto &task : queue_tasks_deferred)
-        {
-            queue_tasks.push_back(std::move(task));
+        if (!queue_tasks_deferred.empty()) {
+            queue_tasks.emplace_back(std::move(queue_tasks_deferred.front()));
+            queue_tasks_deferred.pop_front();
         }
-        queue_tasks_deferred.clear();
+        condition_tasks.notify_one();
     }
 
     // end the start_loop routine
@@ -494,32 +526,12 @@
                     break;
                 }
                 server_task task = queue_tasks.front();
-                queue_tasks.erase(queue_tasks.begin());
+                queue_tasks.pop_front();
                 lock.unlock();
                 LOG_VERBOSE("callback_new_task", {{"id_task", task.id}});
                 callback_new_task(task);
             }
 
-            LOG_VERBOSE("update_multitasks", {});
-
-            // check if we have any finished multitasks
-            auto queue_iterator = queue_multitasks.begin();
-            while (queue_iterator != queue_multitasks.end())
-            {
-                if (queue_iterator->subtasks_remaining.empty())
-                {
-                    // all subtasks done == multitask is done
-                    server_task_multi current_multitask = *queue_iterator;
-                    callback_finish_multitask(current_multitask);
-                    // remove this multitask
-                    queue_iterator = queue_multitasks.erase(queue_iterator);
-                }
-                else
-                {
-                    ++queue_iterator;
-                }
-            }
-
             // all tasks in the current loop is processed, slots data is now ready
             LOG_VERBOSE("callback_update_slots", {});
 
@@ -535,49 +547,18 @@
                     LOG_VERBOSE("ending start_loop", {});
                     return;
                 }
-                condition_tasks.wait(lock, [&] { return (!queue_tasks.empty() || !running); });
-            }
-        }
-    }
+                condition_tasks.wait(lock, [&]{
+                    return (!queue_tasks.empty() || !running);
+                });
             }
-
-    //
-    // functions to manage multitasks
-    //
-
-    // add a multitask by specifying the id of all subtask (subtask is a server_task)
-    void add_multitask(int id_multi, std::vector<int> &sub_ids)
-    {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        server_task_multi multi;
-        multi.id = id_multi;
-        std::copy(sub_ids.begin(), sub_ids.end(),
-                  std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
-        queue_multitasks.push_back(multi);
-    }
-
-    // update the remaining subtasks, while appending results to multitask
-    void update_multitask(int id_multi, int id_sub, server_task_result &result)
-    {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        for (auto &multitask : queue_multitasks)
-        {
-            if (multitask.id == id_multi)
-            {
-                multitask.subtasks_remaining.erase(id_sub);
-                multitask.results.push_back(result);
             }
         }
     }
 };
 
-struct server_response
-{
-    typedef std::function<void(int, int, server_task_result &)> callback_multitask_t;
-    callback_multitask_t callback_update_multitask;
-
+struct server_response {
     // for keeping track of all tasks waiting for the result
-    std::set<int> waiting_task_ids;
+    std::unordered_set<int> waiting_task_ids;
 
     // the main result queue
     std::vector<server_task_result> queue_results;
@@ -594,6 +575,12 @@
         waiting_task_ids.insert(id_task);
     }
 
+    void add_waiting_tasks(const std::vector<server_task> & tasks) {
+        for (const auto & t : tasks) {
+            add_waiting_task_id(t.id);
+        }
+    }
+
     // when the request is finished, we can remove task associated with it
     void remove_waiting_task_id(int id_task)
     {
@@ -603,9 +590,8 @@
         waiting_task_ids.erase(id_task);
     }
 
-    // This function blocks the thread until there is a response for this id_task
-    server_task_result recv(int id_task)
-    {
+    // This function blocks the thread until there is a response for one of the id_tasks
+    server_task_result recv(const std::unordered_set<int> & id_tasks) {
         while (true)
         {
             std::unique_lock<std::mutex> lock(mutex_results);
@@ -613,9 +599,7 @@
 
             for (int i = 0; i < (int)queue_results.size(); i++)
             {
-                if (queue_results[i].id == id_task)
-                {
-                    assert(queue_results[i].id_multi == -1);
+                if (id_tasks.find(queue_results[i].id) != id_tasks.end()) {
                     server_task_result res = queue_results[i];
                     queue_results.erase(queue_results.begin() + i);
                     return res;
@@ -626,33 +610,22 @@
         // should never reach here
     }
 
-    // Register the function to update multitask
-    void on_multitask_update(callback_multitask_t callback)
-    {
-        callback_update_multitask = std::move(callback);
+    // single-task version of recv()
+    server_task_result recv(int id_task) {
+        std::unordered_set<int> id_tasks = {id_task};
+        return recv(id_tasks);
     }
 
     // Send a new result to a waiting id_task
-    void send(server_task_result result)
-    {
+    void send(server_task_result & result) {
         LOG_VERBOSE("send new result", {{"id_task", result.id}});
 
         std::unique_lock<std::mutex> lock(mutex_results);
-        for (const auto &id_task : waiting_task_ids)
-        {
-            // LOG_TEE("waiting task id %i \n", id_task);
-            // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
-            if (result.id_multi == id_task)
-            {
-                LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}});
-                callback_update_multitask(id_task, result.id, result);
-                continue;
-            }
-
+        for (const auto & id_task : waiting_task_ids) {
             if (result.id == id_task)
             {
                 LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}});
-                queue_results.push_back(result);
+                queue_results.push_back(std::move(result));
                 condition_results.notify_all();
                 return;
             }
@@ -664,13 +637,15 @@ struct server_context
 {
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
+    std::vector<llama_lora_adapter_container> lora_adapters;
 
     gpt_params params;
 
-    llama_batch batch;
+    llama_batch batch = {};
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
+    bool has_eos_token = false;
 
     int32_t n_ctx; // total context for all clients / slots
 
@@ -709,9 +684,8 @@
         // Clear any sampling context
         for (server_slot &slot : slots)
         {
-            if (slot.ctx_sampling != nullptr)
-            {
-                llama_sampling_free(slot.ctx_sampling);
+            if (slot.smpl != nullptr) {
+                gpt_sampler_free(slot.smpl);
             }
         }
 
@@ -729,6 +703,7 @@
 
         model = llama_init.model;
         ctx = llama_init.context;
+        lora_adapters = llama_init.lora_adapters;
         params.n_parallel -= 1; // but be sneaky about it
         if (model == nullptr)
         {
@@ -738,8 +713,8 @@
 
         n_ctx = llama_n_ctx(ctx);
 
-        add_bos_token = llama_should_add_bos_token(model);
-        GGML_ASSERT(llama_add_eos_token(model) != 1);
+        add_bos_token = llama_add_bos_token(model);
+        has_eos_token = !llama_add_eos_token(model);
 
         return true;
     }
 
@@ -788,6 +763,10 @@
 
             slot.sparams = params.sparams;
 
+            slot.callback_on_release = [this](int) {
+                queue_tasks.pop_deferred_task();
+            };
+
             slot.reset();
 
             slots.push_back(slot);
@@ -796,14 +775,13 @@
         default_generation_settings_for_props = get_formated_generation(slots.front());
         default_generation_settings_for_props["seed"] = -1;
 
-        // the update_slots() logic will always submit a maximum of n_batch tokens
-        // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not
-        // used)
+        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
+        // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
         {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(n_batch, 0, 1);
+            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
         }
 
         metrics.init();
@@ -812,9 +790,8 @@
     std::vector<llama_token> tokenize(const json &json_prompt, bool add_special) const
     {
         // TODO: currently, we tokenize using special tokens by default
-        // this is not always correct (see
-        // https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216) but it's better compared to
-        // completely ignoring ChatML and other chat templates
+        // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
+        // but it's better compared to completely ignoring ChatML and other chat templates
         const bool TMP_FORCE_SPECIAL = true;
 
         // If `add_bos` is true, we only add BOS, when json_prompt is a string,
@@ -889,8 +866,7 @@
         for (server_slot &slot : slots)
         {
             // skip the slot if it is not available
-            if (!slot.available())
-            {
+            if (slot.is_processing()) {
                 continue;
             }
 
@@ -937,8 +913,7 @@
         for (server_slot &slot : slots)
        {
             // skip the slot if it is not available
-            if (!slot.available())
-            {
+            if (slot.is_processing()) {
                 continue;
             }
 
@@ -965,10 +940,9 @@
     bool launch_slot_with_task(server_slot &slot, const server_task &task)
     {
         slot_params default_params;
-        // Sampling parameter defaults are loaded from the global server context (but individual requests can still
-        // override them)
-        llama_sampling_params default_sparams = params.sparams;
-        auto &data = task.data;
+        // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
+        auto default_sparams = params.sparams;
+        const auto & data = task.data;
 
         slot.oaicompat = false;
         slot.oaicompat_model = "";
 
@@ -980,7 +954,7 @@
         slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
         slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
         slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
-        slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
+        slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
         slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
         slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
         slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
@@ -1021,8 +995,7 @@
         slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
 
         // get prompt
-        if (!task.infill)
-        {
+        if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
             const auto &prompt = data.find("prompt");
             if (prompt == data.end())
             {
@@ -1034,9 +1007,9 @@
                 (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer()))
             {
                 slot.prompt = *prompt;
-            }
-            else
-            {
+            } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
+                slot.prompt = prompt->at(0);
+            } else {
                 send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
                 return false;
             }
@@ -1098,9 +1071,8 @@
         {
             slot.sparams.logit_bias.clear();
 
-            if (json_value(data, "ignore_eos", false))
-            {
-                slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+            if (json_value(data, "ignore_eos", false) && has_eos_token) {
+                slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
             }
 
             const auto &logit_bias = data.find("logit_bias");
@@ -1131,7 +1103,7 @@
                         llama_token tok = el[0].get<llama_token>();
                         if (tok >= 0 && tok < n_vocab)
                         {
-                            slot.sparams.logit_bias[tok] = bias;
+                            slot.sparams.logit_bias.push_back({tok, bias});
                         }
                     }
                     else if (el[0].is_string())
@@ -1139,7 +1111,7 @@
                         auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
                         for (auto tok : toks)
                        {
-                            slot.sparams.logit_bias[tok] = bias;
+                            slot.sparams.logit_bias.push_back({tok, bias});
                        }
                    }
                }
@@ -1164,40 +1136,34 @@
        }
 
        {
-            const auto &samplers_sequence = data.find("samplers");
-            if (samplers_sequence != data.end() && samplers_sequence->is_array())
-            {
+            const auto & samplers = data.find("samplers");
+            if (samplers != data.end() && samplers->is_array()) {
                std::vector<std::string> sampler_names;
-                for (const auto &sampler_name : *samplers_sequence)
-                {
-                    if (sampler_name.is_string())
-                    {
-                        sampler_names.emplace_back(sampler_name);
-                    }
+                for (const auto & name : *samplers) {
+                    if (name.is_string()) {
+                        sampler_names.emplace_back(name);
                    }
-                slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
                }
-            else
-            {
-                slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
+                slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
+            } else {
+                slot.sparams.samplers = default_sparams.samplers;
            }
        }
 
        {
-            if (slot.ctx_sampling != nullptr)
-            {
-                llama_sampling_free(slot.ctx_sampling);
+            if (slot.smpl != nullptr) {
+                gpt_sampler_free(slot.smpl);
            }
-            slot.ctx_sampling = llama_sampling_init(slot.sparams);
-            if (slot.ctx_sampling == nullptr)
-            {
+
+            slot.smpl = gpt_sampler_init(model, slot.sparams);
+            if (slot.smpl == nullptr) {
                // for now, the only error that may happen here is invalid grammar
                send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
"Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); return false; } } - slot.command = SLOT_COMMAND_LOAD_PROMPT; + slot.state = SLOT_STATE_PROCESSING_PROMPT; slot.prompt_tokens.clear(); LOG_INFO("slot is processing task", { @@ -1230,33 +1196,19 @@ struct server_context { system_tokens = ::llama_tokenize(ctx, system_prompt, true); - llama_batch_clear(batch); + const int32_t n_batch = llama_n_batch(ctx); + const int32_t n_tokens_prompt = system_tokens.size(); - for (int i = 0; i < (int)system_tokens.size(); ++i) - { - llama_batch_add(batch, system_tokens[i], i, {0}, false); - } + for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i); - const int32_t n_batch = llama_n_batch(ctx); + llama_batch_clear(batch); - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) - { - const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, - 0, - 0, // unused - }; + for (int32_t j = 0; j < n_tokens; ++j) { + llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false); + } - if (llama_decode(ctx, batch_view) != 0) - { + if (llama_decode(ctx, batch) != 0) { LOG_ERROR("llama_decode() failed", {}); return; } @@ -1300,12 +1252,6 @@ struct server_context slot.generated_text += token_str; slot.has_next_token = true; - if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) - { - // we can change penalty_prompt_tokens because it is always created from scratch each request - slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); - } - // check if there is incomplete UTF-8 character at the end bool incomplete = false; for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) @@ -1399,9 +1345,8 @@ struct server_context } auto n_ctx_train = llama_n_ctx_train(model); - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && - slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) - { + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 + && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { LOG_WARNING("n_predict is not set and self-context extend is disabled." " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", { @@ -1438,23 +1383,18 @@ struct server_context return slot.has_next_token; // continue } - json get_formated_generation(const server_slot &slot) const - { - const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); - const bool ignore_eos = - eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); - - std::vector samplers_sequence; - samplers_sequence.reserve(slot.sparams.samplers_sequence.size()); - for (const auto &sampler_type : slot.sparams.samplers_sequence) - { - samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type)); + json get_formated_generation(const server_slot & slot) const { + std::vector samplers; + samplers.reserve(slot.sparams.samplers.size()); + for (const auto & sampler : slot.sparams.samplers) { + samplers.emplace_back(gpt_sampler_type_to_str(sampler)); } return json{{"n_ctx", slot.n_ctx}, - {"n_predict", slot.n_predict}, + {"n_predict", slot.n_predict}, // Server configured n_predict {"model", params.model_alias}, {"seed", slot.sparams.seed}, + {"seed_cur", slot.smpl ? 
                    {"temperature", slot.sparams.temp},
                    {"dynatemp_range", slot.sparams.dynatemp_range},
                    {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
@@ -1462,52 +1402,47 @@
                    {"top_p", slot.sparams.top_p},
                    {"min_p", slot.sparams.min_p},
                    {"tfs_z", slot.sparams.tfs_z},
-                    {"typical_p", slot.sparams.typical_p},
+                    {"typical_p", slot.sparams.typ_p},
                    {"repeat_last_n", slot.sparams.penalty_last_n},
                    {"repeat_penalty", slot.sparams.penalty_repeat},
                    {"presence_penalty", slot.sparams.penalty_present},
                    {"frequency_penalty", slot.sparams.penalty_freq},
-                    {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
-                    {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
                    {"mirostat", slot.sparams.mirostat},
                    {"mirostat_tau", slot.sparams.mirostat_tau},
                    {"mirostat_eta", slot.sparams.mirostat_eta},
                    {"penalize_nl", slot.sparams.penalize_nl},
                    {"stop", slot.params.antiprompt},
-                    {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
+                    {"max_tokens", slot.params.n_predict}, // User configured n_predict
                    {"n_keep", slot.params.n_keep},
                    {"n_discard", slot.params.n_discard},
-                    {"ignore_eos", ignore_eos},
+                    {"ignore_eos", slot.sparams.ignore_eos},
                    {"stream", slot.params.stream},
-                    {"logit_bias", slot.sparams.logit_bias},
+                    //{"logit_bias", slot.sparams.logit_bias},
                    {"n_probs", slot.sparams.n_probs},
                    {"min_keep", slot.sparams.min_keep},
                    {"grammar", slot.sparams.grammar},
-                    {"samplers", samplers_sequence}};
+                    {"samplers", samplers},
+        };
    }
 
    void send_error(const server_task &task, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER)
    {
-        send_error(task.id, task.id_multi, error, type);
+        send_error(task.id, error, type);
    }
 
    void send_error(const server_slot &slot, const std::string &error, const enum error_type type = ERROR_TYPE_SERVER)
    {
-        send_error(slot.id_task, slot.id_multi, error, type);
+        send_error(slot.id_task, error, type);
    }
 
-    void send_error(const int id_task, const int id_multi, const std::string &error,
-                    const enum error_type type = ERROR_TYPE_SERVER)
-    {
+    void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
        LOG_ERROR("task error", {
-                                    {"id_multi", id_multi},
                                    {"id_task", id_task},
                                    {"error", error},
                                });
 
        server_task_result res;
        res.id = id_task;
-        res.id_multi = id_multi;
        res.stop = false;
        res.error = true;
        res.data = format_error_response(error, type);
@@ -1519,10 +1454,15 @@
    {
        server_task_result res;
        res.id = slot.id_task;
-        res.id_multi = slot.id_multi;
        res.error = false;
        res.stop = false;
-        res.data = json{{"content", tkn.text_to_send}, {"stop", false}, {"id_slot", slot.id}, {"multimodal", false}};
+        res.data = json {
+            {"content", tkn.text_to_send},
+            {"stop", false},
+            {"id_slot", slot.id},
+            {"multimodal", false},
+            {"index", slot.index},
+        };
 
        if (slot.sparams.n_probs > 0)
        {
@@ -1556,7 +1496,6 @@
    {
        server_task_result res;
        res.id = slot.id_task;
-        res.id_multi = slot.id_multi;
        res.error = false;
        res.stop = true;
        res.data = json{{"content", !slot.params.stream ? slot.generated_text : ""},
@@ -1573,7 +1512,9 @@
                        {"stopped_limit", slot.stopped_limit},
                        {"stopping_word", slot.stopping_word},
                        {"tokens_cached", slot.n_past},
-                        {"timings", slot.get_formated_timings()}};
+                        {"timings", slot.get_formated_timings()},
+                        {"index", slot.index},
+        };
 
        if (slot.sparams.n_probs > 0)
        {
@@ -1608,7 +1549,6 @@
    {
        server_task_result res;
        res.id = slot.id_task;
-        res.id_multi = slot.id_multi;
        res.error = false;
        res.stop = true;
 
@@ -1644,103 +1584,132 @@
        res.data = json{
            {"embedding", embd_res},
+            {"index", slot.index},
        };
    }
 
        queue_results.send(res);
    }
 
-    void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding)
-    {
+    //
+    // Functions to create new task(s) and receive result(s)
+    //
+
+    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+        std::vector<server_task> tasks;
+        auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
            server_task task;
-        task.id = id_task;
-        task.id_multi = id_multi;
-        task.id_target = 0;
-        task.data = std::move(data);
-        task.infill = infill;
-        task.embedding = embedding;
+            task.id        = queue_tasks.get_new_id();
+            task.cmpl_type = cmpl_type;
            task.type = SERVER_TASK_TYPE_COMPLETION;
+            if (replace_prompt) {
+                task.data = task_data;
+                task.data["prompt"] = prompt;
+            } else {
+                task.data = std::move(task_data);
+            }
+            tasks.push_back(std::move(task));
+        };
 
-        // when a completion task's prompt array is not a singleton, we split it into multiple requests
-        // otherwise, it's a single-prompt task, we actually queue it
-        // if there's numbers in the prompt array it will be treated as an array of tokens
-        if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1)
-        {
-            bool numbers = false;
-            for (const auto &e : task.data.at("prompt"))
-            {
-                if (e.is_number())
-                {
-                    numbers = true;
-                    break;
+        static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
+        if (!data.contains("prompt")) {
+            throw std::runtime_error(error_msg);
                }
-            }
 
-            // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
-            // it will completely stall the server. I don't know where the bug for this is.
-            //
-            // if there are numbers, it needs to be treated like a single prompt,
-            // queue_tasks handles a mix of strings and numbers just fine.
-            if (numbers)
-            {
-                queue_tasks.post(task);
-            }
-            else
-            {
-                split_multiprompt_task(id_task, task);
+        json prompt = data.at("prompt");
+
+        // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
+        if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
+            data["index"] = 0;
+            create_task(data, false, nullptr);
+        }
+        // otherwise, it's a multiple-prompt task, we break it into smaller tasks
+        else if (prompt.is_array()) {
+            std::vector<json> prompts = prompt;
+            for (size_t i = 0; i < prompts.size(); i++) {
+                const auto & e = prompts[i];
+                if (e.is_string() || json_is_array_of_numbers(e)) {
+                    data["index"] = i;
+                    create_task(data, true, e);
+                } else {
+                    throw std::runtime_error(error_msg);
+                }
            }
        }
-        else
-        {
-            queue_tasks.post(task);
-        }
+        // invalid case
+        else {
+            throw std::runtime_error(error_msg);
+        }
+
+        return tasks;
    }
 
-    void request_cancel(int id_task)
-    {
+    void cancel_tasks(const std::unordered_set<int> & id_tasks) {
+        std::vector<server_task> cancel_tasks;
+        cancel_tasks.reserve(id_tasks.size());
+        for (const auto & id_task : id_tasks) {
+            LOG_VERBOSE("cancel task", {{"id_task", id_task}});
            server_task task;
            task.type = SERVER_TASK_TYPE_CANCEL;
            task.id_target = id_task;
-
-        queue_tasks.post(task);
+            cancel_tasks.push_back(task);
+            queue_results.remove_waiting_task_id(id_task);
+        }
+        // push to beginning of the queue, so it has highest priority
+        queue_tasks.post(cancel_tasks, true);
    }
 
-    void split_multiprompt_task(int id_multi, const server_task &multiprompt_task)
-    {
-        const int prompt_count = multiprompt_task.data.at("prompt").size();
-        if (prompt_count <= 1)
-        {
-            send_error(multiprompt_task, "error while handling multiple prompts");
-            return;
-        }
+    // receive the results from task(s) created by create_tasks_cmpl
+    void receive_cmpl_results(const std::unordered_set<int> & id_tasks, std::function<void(std::vector<server_task_result>&)> result_handler, std::function<void(json)> error_handler) {
+        // TODO: currently, there is no way to detect the client has cancelled the request
+        std::vector<server_task_result> results(id_tasks.size());
+        for (size_t i = 0; i < id_tasks.size(); i++) {
+            server_task_result result = queue_results.recv(id_tasks);
 
-        // generate all the ID for subtask
-        std::vector<int> subtask_ids(prompt_count);
-        for (int i = 0; i < prompt_count; i++)
-        {
-            subtask_ids[i] = queue_tasks.get_new_id();
+            if (result.error) {
+                error_handler(result.data);
+                cancel_tasks(id_tasks);
+                break;
+            }
+
+            size_t idx = result.data["index"];
+            results[idx] = result;
        }
+        result_handler(results);
+    }
 
-        // queue up the multitask so we can track its subtask progression
-        queue_tasks.add_multitask(id_multi, subtask_ids);
+    // receive the results from task(s) created by create_tasks_cmpl, in stream mode
+    void receive_cmpl_results_stream(const std::unordered_set<int> & id_tasks, std::function<bool(server_task_result&)> result_handler, std::function<void(json)> error_handler) {
+        size_t n_finished = 0;
+        while (true) {
+            server_task_result result = queue_results.recv(id_tasks);
+            if (!result_handler(result)) {
+                cancel_tasks(id_tasks);
+                break;
+            }
 
-        // add subtasks
-        for (int i = 0; i < prompt_count; i++)
-        {
-            json subtask_data = multiprompt_task.data;
-            subtask_data["prompt"] = subtask_data.at("prompt")[i];
+            if (result.error) {
+                error_handler(result.data);
+                cancel_tasks(id_tasks);
+                break;
+            }
 
-            // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill,
-                               multiprompt_task.embedding);
+            if (result.stop) {
+                if (++n_finished == id_tasks.size()) {
+                    break;
+                }
+            }
        }
    }
 
-    void process_single_task(const server_task &task)
+    //
+    // Functions to process the task
+    //
+
+    void process_single_task(const server_task & task)
    {
+        switch (task.type) {
+        case SERVER_TASK_TYPE_COMPLETION:
+            {
-        switch (task.type)
-        {
-        case SERVER_TASK_TYPE_COMPLETION: {
            const int id_slot = json_value(task.data, "id_slot", -1);
 
            server_slot *slot;
@@ -1767,8 +1736,7 @@
                queue_tasks.defer(task);
                break;
            }
-            if (!slot->available())
-            {
+            if (slot->is_processing()) {
                // if requested slot is unavailable, we defer this task for processing later
                LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
                queue_tasks.defer(task);
@@ -1790,9 +1758,8 @@
            slot->reset();
 
            slot->id_task = task.id;
-            slot->id_multi = task.id_multi;
-            slot->infill = task.infill;
-            slot->embedding = task.embedding;
+            slot->cmpl_type = task.cmpl_type;
+            slot->index = json_value(task.data, "index", 0);
 
            if (!launch_slot_with_task(*slot, task))
            {
@@ -1848,9 +1815,11 @@
                slots_data.push_back(slot_data);
            }
-            LOG_INFO(
-                "slot data",
-                {{"id_task", task.id}, {"n_idle_slots", n_idle_slots}, {"n_processing_slots", n_processing_slots}});
+            LOG_INFO("slot data", {
+                {"id_task", task.id},
+                {"n_idle_slots", n_idle_slots},
+                {"n_processing_slots", n_processing_slots}
+            });
 
            LOG_VERBOSE("slot data", {{"id_task", task.id},
                                      {"n_idle_slots", n_idle_slots},
@@ -1859,7 +1828,6 @@
 
            server_task_result res;
            res.id = task.id;
-            res.id_multi = task.id_multi;
            res.stop = true;
            res.error = false;
            res.data = {
@@ -1878,6 +1846,9 @@
                {"n_tokens_predicted", metrics.n_tokens_predicted},
                {"t_tokens_generation", metrics.t_tokens_generation},
 
+                { "n_decode_total",     metrics.n_decode_total},
+                { "n_busy_slots_total", metrics.n_busy_slots_total},
+
                {"kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
                {"kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},
 
@@ -1899,8 +1870,7 @@
                send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                break;
            }
-            if (!slot->available())
-            {
+            if (slot->is_processing()) {
                // if requested slot is unavailable, we defer this task for processing later
                LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
                queue_tasks.defer(task);
@@ -1939,8 +1909,7 @@
                send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                break;
            }
-            if (!slot->available())
-            {
+            if (slot->is_processing()) {
                // if requested slot is unavailable, we defer this task for processing later
                LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
                queue_tasks.defer(task);
@@ -1988,8 +1957,7 @@
                send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                break;
            }
-            if (!slot->available())
-            {
+            if (slot->is_processing()) {
                // if requested slot is unavailable, we defer this task for processing later
                LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
                queue_tasks.defer(task);
@@ -2009,27 +1977,17 @@
            queue_results.send(result);
        }
        break;
-        }
-    }
-
-    void on_finish_multitask(const server_task_multi &multitask)
+        case SERVER_TASK_TYPE_SET_LORA:
        {
-        // all subtasks done == multitask is done
+            llama_lora_adapters_apply(ctx, lora_adapters);
            server_task_result result;
-            result.id = multitask.id;
+            result.id = task.id;
            result.stop = true;
            result.error = false;
-
-        // collect json results into one json result
-        std::vector<json> result_jsons;
-        for (const auto &subres : multitask.results)
-        {
-            result_jsons.push_back(subres.data);
-            result.error = result.error && subres.error;
-        }
-        result.data = json{{"results", result_jsons}};
-
+            result.data = json{{ "success", true }};
            queue_results.send(result);
+        } break;
+        }
    }
 
    void update_slots()
@@ -2039,35 +1997,13 @@
            system_prompt_update();
        }
 
-        // release slots
-        for (auto &slot : slots)
-        {
-            if (slot.command == SLOT_COMMAND_RELEASE)
-            {
-                slot.state = SLOT_STATE_IDLE;
-                slot.command = SLOT_COMMAND_NONE;
-                slot.t_last_used = ggml_time_us();
-
-                LOG_INFO("slot released", {{"id_slot", slot.id},
-                                           {"id_task", slot.id_task},
-                                           {"n_ctx", n_ctx},
-                                           {"n_past", slot.n_past},
-                                           {"n_system_tokens", system_tokens.size()},
-                                           {"n_cache_tokens", slot.cache_tokens.size()},
-                                           {"truncated", slot.truncated}});
-
-                queue_tasks.notify_slot_changed();
-            }
-        }
-
        // check if all slots are idle
        {
            bool all_idle = true;
 
            for (auto &slot : slots)
            {
-                if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE)
-                {
+                if (slot.is_processing()) {
                    all_idle = false;
                    break;
                }
@@ -2145,8 +2081,7 @@
        // frist, add sampled tokens from any ongoing sequences
        for (auto &slot : slots)
        {
-            if (slot.state == SLOT_STATE_IDLE)
-            {
+            if (slot.state != SLOT_STATE_GENERATING) {
                continue;
            }
 
@@ -2189,8 +2124,7 @@
        for (auto &slot : slots)
        {
            // this slot still has a prompt to be processed
-            if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT)
-            {
+            if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
                auto &prompt_tokens = slot.prompt_tokens;
 
                // we haven't tokenized the prompt yet - do it now:
@@ -2201,9 +2135,8 @@
                    slot.t_start_process_prompt = ggml_time_us();
                    slot.t_start_generation = 0;
 
-                    if (slot.infill)
-                    {
-                        const bool add_bos = llama_should_add_bos_token(model);
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
+                        const bool add_bos = llama_add_bos_token(model);
                        bool suff_rm_leading_spc = true;
                        if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1)
                        {
@@ -2261,24 +2194,20 @@
                    // empty prompt passed -> release the slot and send empty response
                    if (prompt_tokens.empty())
                    {
-                        LOG_INFO("empty prompt - releasing slot",
-                                 {{"id_slot", slot.id}, {"id_task", slot.id_task}});
+                        LOG_INFO("empty prompt - releasing slot", {
+                            {"id_slot", slot.id},
+                            {"id_task", slot.id_task}
+                        });
 
-                        slot.state = SLOT_STATE_PROCESSING;
-                        slot.command = SLOT_COMMAND_NONE;
                        slot.release();
                        slot.print_timings();
                        send_final_response(slot);
                        continue;
                    }
 
-                    if (slot.embedding)
-                    {
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
                        // this prompt is too large to process - discard it
-                        if (slot.n_prompt_tokens > n_ubatch)
-                        {
-                            slot.state = SLOT_STATE_PROCESSING;
-                            slot.command = SLOT_COMMAND_NONE;
+                        if (slot.n_prompt_tokens > n_ubatch) {
                            slot.release();
                            send_error(slot, "input is too large to process. increase the physical batch size",
                                       ERROR_TYPE_SERVER);
@@ -2330,7 +2259,7 @@
                        GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                    }
 
-                    llama_sampling_reset(slot.ctx_sampling);
+                    gpt_sampler_reset(slot.smpl);
 
                    if (!slot.params.cache_prompt)
                    {
@@ -2345,9 +2274,8 @@
                        slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
                        // push the prompt into the sampling context (do not apply grammar)
-                        for (int i = 0; i < slot.n_past; ++i)
-                        {
-                            llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
+                        for (int i = 0; i < slot.n_past; ++i) {
+                            gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
                        }
                    }
                }
@@ -2355,8 +2283,10 @@
                if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0)
                {
                    // we have to evaluate at least 1 token to generate logits.
-                    LOG_INFO("we have to evaluate at least 1 token to generate logits",
-                             {{"id_slot", slot.id}, {"id_task", slot.id_task}});
+                    LOG_INFO("we have to evaluate at least 1 token to generate logits", {
+                        { "id_slot", slot.id },
+                        { "id_task", slot.id_task }
+                    });
 
                    slot.n_past--;
                    if (slot.ga_i > 0)
@@ -2368,8 +2298,7 @@
                    slot.n_prompt_tokens_processed = 0;
                }
 
-                if (slot.embedding)
-                {
+                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
                    // cannot fit the prompt in the current batch - will try next iter
                    if (batch.n_tokens + slot.n_prompt_tokens > n_batch)
                    {
@@ -2378,7 +2307,7 @@
                }
 
                // check that we are in the right batch_type, if not defer the slot
-                bool slot_type = slot.embedding ? 1 : 0;
+                bool slot_type = slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ? 1 : 0;
                if (batch_type == -1)
                {
                    batch_type = slot_type;
@@ -2407,7 +2336,7 @@
                    slot.n_past_se = 0;
                    slot.ga_i = 0;
                    // TODO: is the system prompt ever in the sampling context?
-                    llama_sampling_reset(slot.ctx_sampling);
+                    gpt_sampler_reset(slot.smpl);
                }
 
                // remove the non-common part from the cache
@@ -2456,11 +2385,9 @@
                    {"progress", (float)slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
                });
 
-                // entire prompt has been processed - start decoding new tokens
-                if (slot.n_past == slot.n_prompt_tokens)
-                {
-                    slot.state = SLOT_STATE_PROCESSING;
-                    slot.command = SLOT_COMMAND_NONE;
+                // entire prompt has been processed
+                if (slot.n_past == slot.n_prompt_tokens) {
+                    slot.state = SLOT_STATE_DONE_PROMPT;
 
                    GGML_ASSERT(batch.n_tokens > 0);
 
@@ -2558,6 +2485,7 @@
            };
 
            const int ret = llama_decode(ctx, batch_view);
+            metrics.on_decoded(slots);
 
            if (ret != 0)
            {
@@ -2567,13 +2495,10 @@
                    LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size",
                              {
                                  {"i", i},
-                                  {"n_batch", ret},
+                                  {"n_batch", n_batch},
                                  {"ret", ret},
                              });
-                    for (auto &slot : slots)
-                    {
-                        slot.state = SLOT_STATE_PROCESSING;
-                        slot.command = SLOT_COMMAND_NONE;
+                    for (auto & slot : slots) {
                        slot.release();
                        send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
                    }
                    break; // break loop of n_batch
                }
@@ -2584,9 +2509,7 @@
                n_batch /= 2;
                i -= n_batch;
 
-                LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try "
-                            "increasing it via the context size or enable defragmentation",
-                            {
+                LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", {
                                {"i", i},
                                {"n_batch", n_batch},
                                {"ret", ret},
@@ -2597,24 +2520,29 @@
            for (auto &slot : slots)
            {
-                if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int)i || slot.i_batch >= (int)(i + n_tokens))
-                {
+                if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                    continue; // continue loop of slots
                }
 
+                if (slot.state == SLOT_STATE_DONE_PROMPT) {
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
                        // prompt evaluated for embedding
-                if (slot.embedding)
-                {
                        send_embedding(slot, batch_view);
                        slot.release();
                        slot.i_batch = -1;
                        continue; // continue loop of slots
                    }
 
+                    // prompt evaluated for next-token prediction
+                    slot.state = SLOT_STATE_GENERATING;
+                } else if (slot.state != SLOT_STATE_GENERATING) {
+                    continue; // continue loop of slots
+                }
+
                completion_token_output result;
-                const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+                const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 
-                llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
+                gpt_sampler_accept(slot.smpl, id, true);
 
                slot.n_decoded += 1;
                if (slot.n_decoded == 1)
                {
@@ -2624,44 +2552,19 @@
                    metrics.on_prompt_eval(slot);
                }
 
-                llama_token_data_array cur_p = {slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false};
                result.tok = id;
 
-                const size_t n_probs = std::min(cur_p.size, (size_t)slot.sparams.n_probs);
-                if (n_probs > 0)
-                {
-                    const size_t n_valid = slot.ctx_sampling->n_valid;
-
-                    // Make sure at least n_probs top tokens are at the front of the vector:
-                    if (slot.sparams.temp == 0.0f && n_probs > n_valid)
-                    {
-                        llama_sample_top_k(ctx, &cur_p, n_probs, 0);
-                    }
+                const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
 
-                    if (slot.sparams.temp == 0.0f)
-                    {
-                        // With greedy sampling the probabilities have possibly not been calculated.
-                        for (size_t i = 0; i < n_probs; ++i)
-                        {
-                            result.probs.push_back({cur_p.data[i].id, i == 0 ? 1.0f : 0.0f});
-                        }
-                    }
-                    else
-                    {
-                        for (size_t i = 0; i < n_probs; ++i)
-                        {
+                for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
                    result.probs.push_back({
-                            cur_p.data[i].id,
-                            i >= n_valid
-                                ? 0.0f
-                                : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
+                        cur_p->data[i].id,
+                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
                    });
                }
-                    }
-                }
 
-                if (!process_token(result, slot))
-                {
+                if (!process_token(result, slot)) {
+                    // release slot because of stop condition
                    slot.release();
                    slot.print_timings();
                    send_final_response(slot);
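
For context, the jllama.cpp changes above all follow the same submit/receive pattern introduced by this refactor in server.hpp. A minimal sketch of that flow, assuming the server_context, server_task, and queue types defined in this diff (JNI plumbing and error handling omitted; "ctx_server" is a hypothetical pointer to a loaded server_context):

// Sketch: submit a completion request and collect its result with the new task API.
json data = {{"prompt", "Hello"}};

std::vector<server_task> tasks = ctx_server->create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
ctx_server->queue_results.add_waiting_tasks(tasks); // register result waiters before posting
ctx_server->queue_tasks.post(tasks);                // cancel tasks can be posted with front=true for priority

const auto task_ids = server_task::get_list_id(tasks);

json result;
json error = nullptr;
ctx_server->receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
    // the "index" stored in each task's data maps results back to their prompt
    result = results[0].data;
}, [&](const json & err) {
    error = err;
});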