-
-
Save ngxson/f3e18888e88d87184f785bf0d4458bda to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #include "server-context.h" | |
| #include "arg.h" | |
| #include "common.h" | |
| #include "llama.h" | |
| #include "log.h" | |
| #include <iostream> | |
| #include <atomic> | |
| #include <signal.h> | |
| #include <thread> // for std::thread::hardware_concurrency | |
| #if defined(_WIN32) | |
| #include <windows.h> | |
| #endif | |
| using json = nlohmann::ordered_json; | |
| int main(int argc, char ** argv) { | |
| // own arguments required by this example | |
| common_params params; | |
| if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { | |
| return 1; | |
| } | |
| common_init(); | |
| // struct that contains llama context and inference | |
| server_context ctx_server; | |
| llama_backend_init(); | |
| llama_numa_init(params.numa); | |
| if (!ctx_server.load_model(params)) { | |
| LOG_ERR("%s: exiting due to model loading error\n", __func__); | |
| return 1; | |
| } | |
| ctx_server.init(); | |
| std::thread server_thread([&]() { | |
| ctx_server.start_loop(); | |
| }); | |
| server_routes routes(params, ctx_server); | |
| std::string user_input; | |
| json messages = json::array(); | |
| auto should_stop = []() { return false; }; // TODO: implement stop condition via Ctrl+C | |
| std::cout << "> " << std::flush; | |
| while (std::getline(std::cin, user_input)) { | |
| if (user_input == "exit" || user_input == "quit") { | |
| break; | |
| } | |
| messages.push_back({ | |
| {"role", "user"}, | |
| {"content", user_input} | |
| }); | |
| server_http_req req{ | |
| {}, {}, "", | |
| safe_json_to_str(json{ | |
| {"messages", messages}, | |
| {"stream", true} | |
| }), | |
| should_stop | |
| }; | |
| auto res = routes.post_chat_completions(req); | |
| std::string curr_text; | |
| if (res->is_stream()) { | |
| std::string chunk; | |
| while (res->next(chunk)) { | |
| std::vector<std::string> lines = string_split<std::string>(chunk, '\n'); | |
| for (auto & line : lines) { | |
| if (line.empty()) { | |
| continue; | |
| } | |
| if (line == "[DONE]") { | |
| break; | |
| } | |
| std::string & data = line; | |
| if (string_starts_with(line, "data: ")) { | |
| data = line.substr(6); | |
| } | |
| // std::cout << "parsing: " << data << std::endl; | |
| auto data_json = json::parse(data); | |
| if (data_json.contains("choices") && !data_json["choices"].empty() && | |
| data_json["choices"][0].contains("delta") && | |
| data_json["choices"][0]["delta"].contains("content") && | |
| !data_json["choices"][0]["delta"]["content"].is_null()) { | |
| std::string new_text = data_json["choices"][0]["delta"]["content"].get<std::string>(); | |
| curr_text += new_text; | |
| std::cout << new_text << std::flush; | |
| } | |
| } | |
| } | |
| std::cout << std::endl; | |
| messages.push_back({ | |
| {"role", "assistant"}, | |
| {"content", curr_text} | |
| }); | |
| } else { | |
| std::cout << res->data << std::endl; | |
| } | |
| std::cout << "> " << std::flush; | |
| } | |
| ctx_server.terminate(); | |
| server_thread.join(); | |
| return 0; | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment