diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index aadd067159..0853a5b837 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -396,78 +396,119 @@ struct server_slot {
     }
 };

-void server_slots_t::clear() {
-    for (auto & slot_ptr : data) {
-        delete slot_ptr;
-    }
-    data.clear();
-}
-
-server_slot & server_slots_t::create() {
-    auto instance = new server_slot();
-    data.push_back(instance);
-    return *instance;
-}
-
-server_slots_t::~server_slots_t() {
-    clear();
-}
-
 //
 // server_metrics
 //

-void server_metrics::init() {
-    t_start = ggml_time_us();
-}
+struct server_metrics {
+    int64_t t_start = 0;

-void server_metrics::on_prompt_eval(const server_slot & slot) {
-    n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
-    n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
-    t_prompt_processing += slot.t_prompt_processing;
-    t_prompt_processing_total += slot.t_prompt_processing;
+    uint64_t n_prompt_tokens_processed_total = 0;
+    uint64_t t_prompt_processing_total = 0;
+    uint64_t n_tokens_predicted_total = 0;
+    uint64_t t_tokens_generation_total = 0;

-    n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
-}
+    uint64_t n_tokens_max = 0;

-void server_metrics::on_prediction(const server_slot & slot) {
-    n_tokens_predicted_total += slot.n_decoded;
-    n_tokens_predicted += slot.n_decoded;
-    t_tokens_generation += slot.t_token_generation;
-    t_tokens_generation_total += slot.t_token_generation;
-}
+    uint64_t n_prompt_tokens_processed = 0;
+    uint64_t t_prompt_processing = 0;
+
+    uint64_t n_tokens_predicted = 0;
+    uint64_t t_tokens_generation = 0;
+
+    uint64_t n_decode_total = 0;
+    uint64_t n_busy_slots_total = 0;
+
+    void init() {
+        t_start = ggml_time_us();
+    }
+
+    void on_prompt_eval(const server_slot & slot) {
+        n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
+        n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
+        t_prompt_processing += slot.t_prompt_processing;
+        t_prompt_processing_total += slot.t_prompt_processing;

-void server_metrics::on_decoded(const server_slots_t & slots) {
-    n_decode_total++;
-    for (size_t i = 0; i < slots.size(); i++) {
-        const auto & slot = slots[i];
-        if (slot.is_processing()) {
-            n_busy_slots_total++;
-        }
         n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
     }
-}

-void server_metrics::reset_bucket() {
-    n_prompt_tokens_processed = 0;
-    t_prompt_processing = 0;
-    n_tokens_predicted = 0;
-    t_tokens_generation = 0;
-}
+    void on_prediction(const server_slot & slot) {
+        n_tokens_predicted_total += slot.n_decoded;
+        n_tokens_predicted += slot.n_decoded;
+        t_tokens_generation += slot.t_token_generation;
+        t_tokens_generation_total += slot.t_token_generation;
+    }
+
+    void on_decoded(const std::vector<server_slot> & slots) {
+        n_decode_total++;
+        for (const auto & slot : slots) {
+            if (slot.is_processing()) {
+                n_busy_slots_total++;
+            }
+            n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens());
+        }
+    }
+
+    void reset_bucket() {
+        n_prompt_tokens_processed = 0;
+        t_prompt_processing = 0;
+        n_tokens_predicted = 0;
+        t_tokens_generation = 0;
+    }
+};

 //
-// server_context
+// server_context_impl (private implementation)
 //

-// TODO @ngxson : the only purpose of this extern "C" is to keep the first indentation level
-// this was done to avoid massive changes in while doing the recent refactoring, avoiding merge conflicts
-// we can remove this once things are more stable
+struct server_context_impl {
+    common_params params_base;
+
+    // note: keep these alive - they determine the lifetime of the model, context, etc.
+    common_init_result llama_init;
+    common_init_result llama_init_dft;
+
+    llama_model * model = nullptr;
+    llama_context * ctx = nullptr;
+
+    // multimodal
+    mtmd_context * mctx = nullptr;
+
+    const llama_vocab * vocab = nullptr;
+    bool vocab_dft_compatible = true;
+
+    llama_model * model_dft = nullptr;
+
+    llama_context_params cparams_dft;

-extern "C" {
-    server_context::~server_context() {
+    llama_batch batch {};
+
+    bool add_bos_token = true;
+
+    int32_t n_ctx; // total context for all clients / slots
+
+    // slots / clients
+    std::vector<server_slot> slots;
+
+    int slots_debug = 0;
+
+    server_queue queue_tasks;
+    server_response queue_results;
+
+    std::unique_ptr<server_prompt_cache> prompt_cache;
+
+    server_metrics metrics;
+
+    // Necessary similarity of prompt for slot selection
+    float slot_prompt_similarity = 0.0f;
+
+    common_chat_templates_ptr chat_templates;
+    oaicompat_parser_options oai_parser_opt;
+
+    ~server_context_impl() {
         mtmd_free(mctx);

         // Clear any sampling context
@@ -488,7 +529,7 @@ extern "C" {
     }

     // load the model and initialize llama_context
-    bool server_context::load_model(const common_params & params) {
+    bool load_model(const common_params & params) {
         SRV_INF("loading model '%s'\n", params.model.path.c_str());

         params_base = params;
@@ -608,7 +649,19 @@ extern "C" {
     }

     // initialize slots and server-related data
-    void server_context::init() {
+    void init() {
+        // wiring up server queues
+        queue_tasks.on_new_task([this](server_task && task) {
+            process_single_task(std::move(task));
+        });
+        queue_tasks.on_update_slots([this]() {
+            update_slots();
+        });
+
+        // Necessary similarity of prompt for slot selection
+        slot_prompt_similarity = params_base.slot_prompt_similarity;
+
+        // setup slots
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);

         const int n_ctx_train = llama_model_n_ctx_train(model);
@@ -620,7 +673,7 @@ extern "C" {
         }

         for (int i = 0; i < params_base.n_parallel; i++) {
-            server_slot & slot = slots.create();
+            server_slot slot;

             slot.id = i;
             slot.ctx = ctx;
@@ -655,6 +708,8 @@ extern "C" {
             };

             slot.reset();
+
+            slots.push_back(std::move(slot));
         }

         {
@@ -712,7 +767,7 @@ extern "C" {
                 common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str());
     }

-    server_slot * server_context::get_slot_by_id(int id) {
+    server_slot * get_slot_by_id(int id) {
         for (server_slot & slot : slots) {
             if (slot.id == id) {
                 return &slot;
@@ -722,7 +777,7 @@ extern "C" {
         return nullptr;
     }

-    server_slot * server_context::get_available_slot(const server_task & task) {
+    server_slot * get_available_slot(const server_task & task) {
         server_slot * ret = nullptr;

         bool update_cache = false;
@@ -826,7 +881,7 @@ extern "C" {
         return ret;
     }

-    void server_context::clear_slot(server_slot & slot) const {
+    void clear_slot(server_slot & slot) const {
         GGML_ASSERT(!slot.is_processing());

         SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
@@ -840,7 +895,7 @@ extern "C" {
     //       - smarter decision which slot to clear (LRU or longest prompt?)
     //       - move slot to level 2 cache instead of removing?
     //       - instead of purging, try to store and resume later?
-    bool server_context::try_clear_idle_slots() {
+    bool try_clear_idle_slots() {
         bool res = false;

         if (!params_base.kv_unified) {
@@ -867,7 +922,7 @@ extern "C" {
         return res;
     }

-    bool server_context::launch_slot_with_task(server_slot & slot, server_task && task) {
+    bool launch_slot_with_task(server_slot & slot, server_task && task) {
         slot.reset();

         if (!are_lora_equal(task.params.lora, slot.lora)) {
@@ -969,7 +1024,7 @@ extern "C" {
         return true;
     }

-    bool server_context::process_token(completion_token_output & result, server_slot & slot) {
+    bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = result.text_to_send;
         slot.sampled = result.tok;
@@ -1100,7 +1155,7 @@ extern "C" {
         return slot.has_next_token; // continue
     }

-    void server_context::populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
         size_t n_probs = slot.task->params.sampling.n_probs;
         size_t n_vocab = llama_vocab_n_tokens(vocab);
@@ -1150,11 +1205,15 @@ extern "C" {
         }
     }

-    void server_context::send_error(const server_slot & slot, const std::string & error, const enum error_type type) {
+    void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
+        send_error(task.id, error, type);
+    }
+
+    void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx);
     }

-    void server_context::send_error(const int id_task, const std::string & error, const enum error_type type, const int32_t n_prompt_tokens, const int32_t n_ctx) {
+    void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
         SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());

         if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
@@ -1172,7 +1231,7 @@ extern "C" {
     }

     // if multimodal is enabled, send an error and return false
-    bool server_context::check_no_mtmd(const int id_task) {
+    bool check_no_mtmd(const int id_task) {
         if (mctx) {
             send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED);
             return false;
@@ -1180,7 +1239,7 @@ extern "C" {
         return true;
     }

-    void server_context::send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
+    void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
         auto res = std::make_unique<server_task_result_cmpl_partial>();

         res->id = slot.task->id;
@@ -1221,7 +1280,7 @@ extern "C" {
         queue_results.send(std::move(res));
     }

-    void server_context::send_final_response(server_slot & slot) {
+    void send_final_response(server_slot & slot) {
         auto res = std::make_unique<server_task_result_cmpl_final>();

         res->id = slot.task->id;
@@ -1272,7 +1331,7 @@ extern "C" {
         queue_results.send(std::move(res));
     }

-    void server_context::send_embedding(const server_slot & slot, const llama_batch & batch) {
+    void send_embedding(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
         res->id = slot.task->id;
         res->index = slot.task->index;
@@ -1317,7 +1376,7 @@ extern "C" {
         queue_results.send(std::move(res));
     }

-    void server_context::send_rerank(const server_slot & slot, const llama_batch & batch) {
+    void send_rerank(const server_slot & slot, const llama_batch & batch) {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id = slot.task->id;
         res->index = slot.task->index;
@@ -1352,7 +1411,7 @@ extern "C" {
     // Functions to process the task
     //

-    void server_context::process_single_task(server_task && task) {
+    void process_single_task(server_task && task) {
         switch (task.type) {
             case SERVER_TASK_TYPE_COMPLETION:
             case SERVER_TASK_TYPE_INFILL:
@@ -1572,7 +1631,7 @@ extern "C" {
         }
     }

-    void server_context::update_slots() {
+    void update_slots() {
         // check if all slots are idle
         {
             bool all_idle = true;
@@ -2421,24 +2480,65 @@ extern "C" {
         SRV_DBG("%s", "run slots completed\n");
     }

-    //
-    // Utility functions
-    //
+    json model_meta() const {
+        return json {
+            {"vocab_type",  llama_vocab_type       (vocab)},
+            {"n_vocab",     llama_vocab_n_tokens   (vocab)},
+            {"n_ctx_train", llama_model_n_ctx_train(model)},
+            {"n_embd",      llama_model_n_embd     (model)},
+            {"n_params",    llama_model_n_params   (model)},
+            {"size",        llama_model_size       (model)},
+        };
+    }

-    int server_context::get_slot_n_ctx() const {
-        return slots[0].n_ctx;
+    int get_slot_n_ctx() {
+        return slots.back().n_ctx;
+    }
+};
+
+void server_context_impl_deleter::operator()(server_context_impl * ptr) const {
+    if (ptr) {
+        delete ptr;
+    }
+}

+//
+// server_context (public API)
+//
+
+server_context::server_context() : impl(new server_context_impl()) {}
+server_context::~server_context() = default;
+
+void server_context::init() {
+    impl->init();
+}
+
+bool server_context::load_model(const common_params & params) {
+    return impl->load_model(params);
+}
+
+void server_context::start_loop() {
+    impl->queue_tasks.start_loop();
+}
+
+void server_context::terminate() {
+    impl->queue_tasks.terminate();
+}
+
+llama_context * server_context::get_llama_context() const {
+    return impl->ctx;
+}
+

 // generator-like API for server responses, support pooling connection state and aggregating results
 struct server_response_reader {
     std::unordered_set<int> id_tasks;
-    server_context & ctx_server;
+    server_context_impl & ctx_server;
     size_t received_count = 0;
     bool cancelled = false;

-    server_response_reader(server_context & ctx_server) : ctx_server(ctx_server) {}
+    server_response_reader(server_context_impl & ctx_server) : ctx_server(ctx_server) {}
     ~server_response_reader() {
         stop();
     }
@@ -2532,7 +2632,7 @@ struct server_response_reader {
 // generator-like API for HTTP response generation
 struct server_res_generator : server_http_res {
     server_response_reader rd;
-    server_res_generator(server_context & ctx_server_) : rd(ctx_server_) {}
+    server_res_generator(server_context_impl & ctx_server_) : rd(ctx_server_) {}
     void ok(const json & response_data) {
         status = 200;
         data = safe_json_to_str(response_data);
diff --git a/tools/server/server-context.h b/tools/server/server-context.h
index 5c48f8e808..b799754145 100644
--- a/tools/server/server-context.h
+++ b/tools/server/server-context.h
@@ -47,172 +47,32 @@ enum server_state {
     SERVER_STATE_READY,  // Server is ready and model is loaded
 };

-// forward declarations
-struct server_slot;
-
-// proxy for std::vector to allow forward declaration of server_slot
-struct server_slots_t {
-    ~server_slots_t();
-    std::vector<server_slot *> data;
-    size_t size() const { return data.size(); }
-    server_slot & operator[](size_t idx) { return *(data[idx]); }
-    server_slot & operator[](size_t idx) const { return *(data[idx]); }
-    void clear();
-    server_slot & create();
-    struct iterator {
-        typename std::vector<server_slot *>::iterator it;
-        iterator(typename std::vector<server_slot *>::iterator i) : it(i) {}
-        server_slot & operator*() { return **it; }
-        iterator & operator++() { ++it; return *this; }
-        bool operator!=(const iterator & other) const { return it != other.it; }
-    };
-    iterator begin() { return iterator(data.begin()); }
-    iterator end() { return iterator(data.end()); }
-};
-
-struct server_metrics {
-    int64_t t_start = 0;
-
-    uint64_t n_prompt_tokens_processed_total = 0;
-    uint64_t t_prompt_processing_total = 0;
-    uint64_t n_tokens_predicted_total = 0;
-    uint64_t t_tokens_generation_total = 0;
-
-    uint64_t n_tokens_max = 0;
-
-    uint64_t n_prompt_tokens_processed = 0;
-    uint64_t t_prompt_processing = 0;
-
-    uint64_t n_tokens_predicted = 0;
-    uint64_t t_tokens_generation = 0;
-
-    uint64_t n_decode_total = 0;
-    uint64_t n_busy_slots_total = 0;
-
-    void init();
-    void on_prompt_eval(const server_slot & slot);
-    void on_prediction(const server_slot & slot);
-    void on_decoded(const server_slots_t & slots);
-    void reset_bucket();
+struct server_context_impl; // private implementation
+struct server_context_impl_deleter {
+    void operator()(server_context_impl * p) const;
 };

 struct server_context {
-public:
-    common_params params_base;
-
-    // note: keep these alive - they determine the lifetime of the model, context, etc.
-    common_init_result llama_init;
-    common_init_result llama_init_dft;
-
-    llama_model * model = nullptr;
-    llama_context * ctx = nullptr;
-
-    const llama_vocab * vocab = nullptr;
-    bool vocab_dft_compatible = true;
-
-    // multimodal
-    mtmd_context * mctx = nullptr;
-
-    server_queue queue_tasks;
-    server_response queue_results;
-
-    common_chat_templates_ptr chat_templates;
-    oaicompat_parser_options oai_parser_opt;
-
-    // Necessary similarity of prompt for slot selection
-    float slot_prompt_similarity = 0.0f;
-
-private:
-    llama_model * model_dft = nullptr;
-
-    llama_context_params cparams_dft;
-
-    llama_batch batch {};
-
-    bool add_bos_token = true;
-
-    int32_t n_ctx; // total context for all clients / slots
-
-    // slots / clients
-    server_slots_t slots;
-
-    int slots_debug = 0;
-
-    std::unique_ptr<server_prompt_cache> prompt_cache;
+    std::unique_ptr<server_context_impl, server_context_impl_deleter> impl;

-    server_metrics metrics;
-
-public:
+    server_context();
     ~server_context();

-    // load the model and initialize llama_context
-    bool load_model(const common_params & params);
-
     // initialize slots and server-related data
     void init();

-    server_slot * get_slot_by_id(int id);
-
-    server_slot * get_available_slot(const server_task & task);
-
-    void clear_slot(server_slot & slot) const;
-
-    // return true if at least one slot has been cleared
-    // TODO: improve logic
-    //       - smarter decision which slot to clear (LRU or longest prompt?)
-    //       - move slot to level 2 cache instead of removing?
-    //       - instead of purging, try to store and resume later?
-    bool try_clear_idle_slots();
-
-    bool launch_slot_with_task(server_slot & slot, server_task && task);
-
-    bool process_token(completion_token_output & result, server_slot & slot);
-
-    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const;
-
-    void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
-        send_error(task.id, error, type);
-    }
-
-    void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER);
-
-    void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0);
-
-    // if multimodal is enabled, send an error and return false
-    bool check_no_mtmd(const int id_task);
-
-    void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress);
-
-    void send_final_response(server_slot & slot);
-
-    void send_embedding(const server_slot & slot, const llama_batch & batch);
-
-    void send_rerank(const server_slot & slot, const llama_batch & batch);
-
-    //
-    // Functions to process the task
-    //
-
-    void process_single_task(server_task && task);
-
-    void update_slots();
+    // load the model and initialize llama_context
+    // returns true on success
+    bool load_model(const common_params & params);

-    //
-    // Utility functions
-    //
+    // this function will block main thread until termination
+    void start_loop();

-    int get_slot_n_ctx() const;
+    // terminate main loop (will unblock start_loop)
+    void terminate();

-    json model_meta() const {
-        return json {
-            {"vocab_type",  llama_vocab_type       (vocab)},
-            {"n_vocab",     llama_vocab_n_tokens   (vocab)},
-            {"n_ctx_train", llama_model_n_ctx_train(model)},
-            {"n_embd",      llama_model_n_embd     (model)},
-            {"n_params",    llama_model_n_params   (model)},
-            {"size",        llama_model_size       (model)},
-        };
-    }
+    // get the underlying llama_context
+    llama_context * get_llama_context() const;
 };


@@ -220,10 +80,10 @@ struct server_res_generator;
 struct server_routes {
     const common_params & params;
-    server_context & ctx_server;
+    server_context_impl & ctx_server;
     server_http_context & ctx_http; // for reading is_ready
     server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http)
-        : params(params), ctx_server(ctx_server), ctx_http(ctx_http) {
+        : params(params), ctx_server(*ctx_server.impl.get()), ctx_http(ctx_http) {
         init_routes();
     }
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index b9a9335ff2..d6603a5299 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -94,9 +94,6 @@ int main(int argc, char ** argv) {
     // struct that contains llama context and inference
     server_context ctx_server;

-    // Necessary similarity of prompt for slot selection
-    ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
-
     llama_backend_init();
     llama_numa_init(params.numa);

@@ -161,7 +158,7 @@ int main(int argc, char ** argv) {
     auto clean_up = [&ctx_http, &ctx_server]() {
         SRV_INF("%s: cleaning up before exit...\n", __func__);
         ctx_http.stop();
-        ctx_server.queue_results.terminate();
+        ctx_server.terminate();
         llama_backend_free();
     };

@@ -189,17 +186,9 @@ int main(int argc, char ** argv) {

     LOG_INF("%s: model loaded\n", __func__);

-    ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
-        ctx_server.process_single_task(std::move(task));
-    });
-
-    ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
-        ctx_server.update_slots();
-    });
-
     shutdown_handler = [&](int) {
         // this will unblock start_loop()
-        ctx_server.queue_tasks.terminate();
+        ctx_server.terminate();
     };

     // TODO: refactor in common/console
@@ -219,14 +208,14 @@ int main(int argc, char ** argv) {
     LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
     LOG_INF("%s: starting the main loop...\n", __func__);

-    // this call blocks the main thread until queue_tasks.terminate() is called
-    ctx_server.queue_tasks.start_loop();
+    // this call blocks the main thread until ctx_server.terminate() is called
+    ctx_server.start_loop();

     clean_up();
     if (ctx_http.thread.joinable()) {
         ctx_http.thread.join();
     }
-    llama_memory_breakdown_print(ctx_server.ctx);
+    llama_memory_breakdown_print(ctx_server.get_llama_context());

     return 0;
 }
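
Note on the pimpl shape this patch adopts: std::unique_ptr's default_delete must see a complete type wherever the destructor is instantiated, which is why the forward-declared server_context_impl is paired with an out-of-line server_context_impl_deleter. Below is a minimal sketch of the same shape under illustrative names (widget / widget_impl are placeholders, not part of the patch):

// widget.h -- the header keeps only a forward declaration of the impl type
#include <memory>

struct widget_impl;                          // incomplete here, by design
struct widget_impl_deleter {
    void operator()(widget_impl * p) const;  // defined where the type is complete
};

struct widget {
    std::unique_ptr<widget_impl, widget_impl_deleter> impl;
    widget();
    ~widget();                               // defaulted in the .cpp
    int value() const;
};

// widget.cpp -- the only translation unit that sees the impl definition
struct widget_impl {
    int v = 42;                              // private state hidden from the header
};

void widget_impl_deleter::operator()(widget_impl * p) const {
    delete p;                                // widget_impl is complete here
}

widget::widget() : impl(new widget_impl()) {}
widget::~widget() = default;

int widget::value() const {
    return impl->v;
}

int main() {
    widget w;
    return w.value() == 42 ? 0 : 1;          // state reachable only through the API
}

Because operator() is defined only in the .cpp, code that includes the header never needs the impl definition, so the destructor can be defaulted out of line exactly as server_context::~server_context() is in this diff.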
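A second rough sketch, for the callback wiring that moved from main() into server_context_impl::init(): binding [this] inside init() lets the queue call private members, shrinking the public surface to load_model/init/start_loop/terminate. Names here are assumed (task_queue, engine_impl, drain are illustrative; only on_new_task mirrors the server_queue API used above):

#include <functional>
#include <queue>
#include <utility>

struct task { int id; };

struct task_queue {
    std::function<void(task &&)> new_task_cb;
    std::queue<task> pending;
    void on_new_task(std::function<void(task &&)> cb) { new_task_cb = std::move(cb); }
    void post(task && t) { pending.push(std::move(t)); }
    void drain() {                           // stand-in for the blocking start_loop()
        while (!pending.empty()) {
            task t = std::move(pending.front());
            pending.pop();
            new_task_cb(std::move(t));       // dispatch to whatever init() wired up
        }
    }
};

struct engine_impl {
    task_queue queue_tasks;
    void init() {
        // wire the queue to a private member, as server_context_impl::init() does
        queue_tasks.on_new_task([this](task && t) { process(std::move(t)); });
    }
    void process(task && t) { (void) t; /* handle one task */ }
};

int main() {
    engine_impl e;
    e.init();
    e.queue_tasks.post(task{1});
    e.queue_tasks.drain();
    return 0;
}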