From 573b690e16bf4c2b50512294b4b659687c9146bc Mon Sep 17 00:00:00 2001 From: Danny Daemonic Date: Sat, 27 May 2023 03:20:03 -0700 Subject: [PATCH 1/2] Work around for recalculating logits in cached prompts --- examples/main/main.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c7c591537419c..e73c5356e5440 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -360,6 +360,12 @@ int main(int argc, char ** argv) { } } if (i > 0) { + // check if we've used up all the prompt but not all cached tokens + if (embd.size() == i && n_session_consumed < session_tokens.size()) { + // force revaluation of the last token to recalculate logits + i--; + n_past--; + } embd.erase(embd.begin(), embd.begin() + i); } } From 6d47258e4106b4ed01b1da29d529dae622ffd8d4 Mon Sep 17 00:00:00 2001 From: Danny Daemonic Date: Sat, 27 May 2023 03:36:44 -0700 Subject: [PATCH 2/2] n_session_consumed should just be size_t, but the cache code casts to (int) --- examples/main/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e73c5356e5440..6131f5b467304 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -361,7 +361,7 @@ int main(int argc, char ** argv) { } if (i > 0) { // check if we've used up all the prompt but not all cached tokens - if (embd.size() == i && n_session_consumed < session_tokens.size()) { + if (embd.size() == i && n_session_consumed < (int) session_tokens.size()) { // force revaluation of the last token to recalculate logits i--; n_past--;