chat.py
import json
import os

from elasticsearch_client import (
    elasticsearch_client,
    get_elasticsearch_chat_message_history,
)
from flask import current_app, render_template, stream_with_context
from functools import cache
from langchain_elasticsearch import (
    ElasticsearchStore,
    SparseVectorStrategy,
)
from llm_integrations import get_llm

# Index of ingested workplace documents and index used to persist per-session
# chat history, both overridable via environment variables.
INDEX = os.getenv("ES_INDEX", "workplace-app-docs")
INDEX_CHAT_HISTORY = os.getenv(
    "ES_INDEX_CHAT_HISTORY", "workplace-app-docs-chat-history"
)
ELSER_MODEL = os.getenv("ELSER_MODEL", ".elser_model_2")

# Markers emitted in the server-sent event stream so the client can tell the
# session ID, source documents, and end of stream apart from answer tokens.
SESSION_ID_TAG = "[SESSION_ID]"
SOURCE_TAG = "[SOURCE]"
DONE_TAG = "[DONE]"

# Sparse-vector (ELSER) retrieval over the workplace documents index.
store = ElasticsearchStore(
    es_connection=elasticsearch_client,
    index_name=INDEX,
    strategy=SparseVectorStrategy(model_id=ELSER_MODEL),
)


@cache
def get_lazy_llm():
    # Instantiate the configured LLM on first use and reuse it afterwards.
    return get_llm()


@stream_with_context
def ask_question(question, session_id):
    llm = get_lazy_llm()

    # Tell the client which session this stream belongs to.
    yield f"data: {SESSION_ID_TAG} {session_id}\n\n"
    current_app.logger.debug("Chat session ID: %s", session_id)

    chat_history = get_elasticsearch_chat_message_history(
        INDEX_CHAT_HISTORY, session_id
    )

    if len(chat_history.messages) > 0:
        # Create a condensed, standalone question from the chat history.
        condense_question_prompt = render_template(
            "condense_question_prompt.txt",
            question=question,
            chat_history=chat_history.messages,
        )
        condensed_question = llm.invoke(condense_question_prompt).content
    else:
        condensed_question = question

    current_app.logger.debug("Condensed question: %s", condensed_question)
    current_app.logger.debug("Question: %s", question)

    # Retrieve relevant passages and stream each one to the client as a source.
    docs = store.as_retriever().invoke(condensed_question)
    for doc in docs:
        doc_source = {**doc.metadata, "page_content": doc.page_content}
        current_app.logger.debug(
            "Retrieved document passage from: %s", doc.metadata["name"]
        )
        yield f"data: {SOURCE_TAG} {json.dumps(doc_source)}\n\n"

    qa_prompt = render_template(
        "rag_prompt.txt",
        question=question,
        docs=docs,
        chat_history=chat_history.messages,
    )

    # Stream the answer token by token, then signal completion.
    answer = ""
    for chunk in llm.stream(qa_prompt):
        content = chunk.content.replace(
            "\n", " "
        )  # the stream can get messed up with newlines
        yield f"data: {content}\n\n"
        answer += chunk.content

    yield f"data: {DONE_TAG}\n\n"
    current_app.logger.debug("Answer: %s", answer)

    # Persist the exchange so follow-up questions can be condensed against it.
    chat_history.add_user_message(question)
    chat_history.add_ai_message(answer)
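

# Usage sketch (assumption): this generator is intended to be wired to a Flask
# view in the surrounding app. The route path, request shape, and module name
# below are illustrative, not taken from this file.
#
#     from uuid import uuid4
#     from flask import Flask, Response, request
#     from chat import ask_question
#
#     app = Flask(__name__)
#
#     @app.post("/api/chat")
#     def api_chat():
#         question = request.json["question"]
#         session_id = request.args.get("session_id", str(uuid4()))
#         # ask_question is a generator wrapped by stream_with_context, so it
#         # can be handed straight to Response as a server-sent event stream.
#         return Response(
#             ask_question(question, session_id), mimetype="text/event-stream"
#         )
#
# A client reading that stream first receives a "[SESSION_ID] ..." frame, then
# one "[SOURCE] {...}" frame per retrieved document, then the answer tokens,
# and finally "[DONE]".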