Skip to content

Stop keywords #365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 67 additions & 43 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,13 @@ int main(int argc, char ** argv) {
}
}
}

if(params.stop_keyword.size()) {
for (auto stop_keyword : params.stop_keyword) {
fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
}
}

fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
fprintf(stderr, "\n\n");

Expand Down Expand Up @@ -1165,61 +1172,78 @@ int main(int argc, char ** argv) {
set_console_state(CONSOLE_STATE_DEFAULT);
}

// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= input_consumed) {
// check for reverse prompt
// If we are not processing queued inputs, check for reverse prompt and stop keywords
if((int) embd_inp.size() <= input_consumed) {
// Build the output string
// TODO - Recomputing this whole string every iteration is not efficient
std::string last_output;
for (auto id : last_n_tokens) {
last_output += vocab.id_to_token[id].tok;
}

// Check if each of the reverse prompts appears at the end of the output.
for (std::string antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true;
// Check for stop keywords
bool stop = false;
for (std::string stop_keyword : params.stop_keyword) {
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
stop = true;
break;
}
}
if (is_interacting) {
// potentially set color to indicate we are taking user input
set_console_state(CONSOLE_STATE_USER_INPUT);

if (params.instruct) {
input_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());

printf("\n> ");
if(stop) {
break;
}

// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive) {

// Check if each of the reverse prompts appears at the end of the output.
for (std::string antiprompt : params.antiprompt) {
if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
is_interacting = true;
break;
}
}

std::string buffer;
std::string line;
bool another_line = true;
do {
std::getline(std::cin, line);
if (line.empty() || line.back() != '\\') {
another_line = false;
} else {
line.pop_back(); // Remove the continue character
if (is_interacting) {
// potentially set color to indicate we are taking user input
set_console_state(CONSOLE_STATE_USER_INPUT);

if (params.instruct) {
input_consumed = embd_inp.size();
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());

printf("\n> ");
}
buffer += line + '\n'; // Append the line to the result
} while (another_line);

// done taking input, reset color
set_console_state(CONSOLE_STATE_DEFAULT);

std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());

if (params.instruct) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());

std::string buffer;
std::string line;
bool another_line = true;
do {
std::getline(std::cin, line);
if (line.empty() || line.back() != '\\') {
another_line = false;
} else {
line.pop_back(); // Remove the continue character
}
buffer += line + '\n'; // Append the line to the result
} while (another_line);

// done taking input, reset color
set_console_state(CONSOLE_STATE_DEFAULT);

std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());

if (params.instruct) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
}

remaining_tokens -= line_inp.size();

input_noecho = true; // do not echo this again
}

remaining_tokens -= line_inp.size();

input_noecho = true; // do not echo this again
is_interacting = false;
}
is_interacting = false;
}

// end of text token
Expand Down
4 changes: 4 additions & 0 deletions utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.use_color = true;
} else if (arg == "-r" || arg == "--reverse-prompt") {
params.antiprompt.push_back(argv[++i]);
} else if (arg == "--stop") {
params.stop_keyword.push_back(argv[++i]);
} else if (arg == "--perplexity") {
params.perplexity = true;
} else if (arg == "--ignore-eos") {
Expand Down Expand Up @@ -103,6 +105,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " in interactive mode, poll user input upon seeing PROMPT (can be\n");
fprintf(stderr, " specified more than once for multiple prompts).\n");
fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n");
fprintf(stderr, " (can be specified more than once for multiple keywords).\n");
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
Expand Down
1 change: 1 addition & 0 deletions utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ struct gpt_params {
std::string prompt = "";

std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::vector<std::string> stop_keyword; // string upon seeing which the model will stop

bool memory_f16 = false; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
Expand Down