-
Notifications
You must be signed in to change notification settings - Fork 240
/
Copy pathserver.py
executable file
· 182 lines (147 loc) · 6.08 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#!/usr/bin/env python
"""Example of a chat server with persistence handled on the backend.
For simplicity, we're using file storage here -- to avoid the need to set up
a database. This is obviously not a good idea for a production environment,
but will help us to demonstrate the RunnableWithMessageHistory interface.
We'll use cookies to identify the user. This will help illustrate how to
fetch configuration from the request.
"""
import re
from pathlib import Path
from typing import Any, Callable, Dict, Union
from fastapi import FastAPI, HTTPException, Request
from langchain_community.chat_message_histories import FileChatMessageHistory
from langchain_core import __version__
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import ConfigurableFieldSpec
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_openai import ChatOpenAI
from typing_extensions import TypedDict
from langserve import add_routes
# Define the minimum required version as (0, 1, 0)
# Earlier versions did not allow specifying custom config fields in
# RunnableWithMessageHistory.
MIN_VERSION_LANGCHAIN_CORE = (0, 1, 0)


def _parse_version(version: str) -> tuple:
    """Parse the leading numeric components of a version string.

    A plain ``int()`` on every dot-separated piece raises ``ValueError`` for
    pre-release/dev versions such as ``"0.1.0rc1"`` or ``"0.2.0.dev1"``.
    Instead, keep the leading digits of each piece and stop at the first
    piece that has none. For example, ``"0.1.0rc1"`` -> ``(0, 1, 0)`` and
    ``"0.2.0.dev1"`` -> ``(0, 2, 0)``.

    Args:
        version: A version string, e.g. ``langchain_core.__version__``.

    Returns:
        A tuple of ints suitable for lexicographic comparison.
    """
    parts = []
    for piece in version.split("."):
        leading_digits = re.match(r"\d+", piece)
        if leading_digits is None:
            break
        parts.append(int(leading_digits.group()))
    return tuple(parts)


# Parse the installed version into a comparable tuple of integers.
# Note: pre-release suffixes are ignored, so e.g. "0.1.0rc1" counts as 0.1.0.
LANGCHAIN_CORE_VERSION = _parse_version(__version__)
if LANGCHAIN_CORE_VERSION < MIN_VERSION_LANGCHAIN_CORE:
    raise RuntimeError(
        f"Minimum required version of langchain-core is {MIN_VERSION_LANGCHAIN_CORE}, "
        f"but found {LANGCHAIN_CORE_VERSION}"
    )
def _is_valid_identifier(value: str) -> bool:
"""Check if the value is a valid identifier."""
# Use a regular expression to match the allowed characters
valid_characters = re.compile(r"^[a-zA-Z0-9-_]+$")
return bool(valid_characters.match(value))
def create_session_factory(
    base_dir: Union[str, Path],
) -> Callable[[str, str], BaseChatMessageHistory]:
    """Create a factory that can retrieve chat histories.

    The chat histories are keyed by user ID and conversation ID. Each
    conversation is stored in its own JSON file at
    ``<base_dir>/<user_id>/<conversation_id>.json``.

    Args:
        base_dir: Base directory to use for storing the chat histories. It is
            created (including missing parents) if it does not exist yet.

    Returns:
        A factory taking ``(user_id, conversation_id)`` and returning the
        matching ``FileChatMessageHistory``.
    """
    base_dir_ = Path(base_dir)
    # exist_ok=True instead of an exists() check avoids a race when several
    # processes/requests try to create the directory at the same time.
    base_dir_.mkdir(parents=True, exist_ok=True)

    def get_chat_history(user_id: str, conversation_id: str) -> FileChatMessageHistory:
        """Get a chat history from a user id and conversation id.

        Both ids are validated before being used as path components, so they
        cannot escape ``base_dir`` (e.g. via ".." or "/").

        Raises:
            ValueError: If either id is not a valid identifier as defined by
                ``_is_valid_identifier``.
        """
        if not _is_valid_identifier(user_id):
            raise ValueError(
                f"User ID {user_id} is not in a valid format. "
                "User ID must only contain alphanumeric characters, "
                "hyphens, and underscores. "
                "Please include a valid cookie in the request headers "
                "called 'user_id'."
            )
        if not _is_valid_identifier(conversation_id):
            raise ValueError(
                f"Conversation ID {conversation_id} is not in a valid format. "
                "Conversation ID must only contain alphanumeric characters, "
                "hyphens, and underscores. Please provide a valid conversation id "
                "via config. For example, "
                "chain.invoke(.., {'configurable': {'conversation_id': '123'}})"
            )

        user_dir = base_dir_ / user_id
        user_dir.mkdir(parents=True, exist_ok=True)
        file_path = user_dir / f"{conversation_id}.json"
        return FileChatMessageHistory(str(file_path))

    return get_chat_history
# The FastAPI application that the LangServe routes are mounted onto below.
app = FastAPI(
    description="Spin up a simple api server using Langchain's Runnable interfaces",
    title="LangChain Server",
    version="1.0",
)
def _per_request_config_modifier(
    config: Dict[str, Any], request: Request
) -> Dict[str, Any]:
    """Update the config with the user id taken from the request's cookies.

    The value of the ``user_id`` cookie is written to
    ``config["configurable"]["user_id"]``, which is the configurable field the
    chat history wrapper uses to select the user's storage.

    Args:
        config: The config sent with the request. It is not mutated; a
            modified copy is returned.
        request: The incoming HTTP request.

    Returns:
        A new config dict whose ``configurable`` section contains ``user_id``.

    Raises:
        HTTPException: 400 if the ``user_id`` cookie is missing or does not
            contain only alphanumeric characters, hyphens, and underscores.
    """
    config = config.copy()
    # config.copy() is shallow: copy the nested dict too, otherwise writing
    # "user_id" below would mutate the caller's "configurable" dict.
    # ``or {}`` also handles an explicit ``"configurable": None``.
    configurable = dict(config.get("configurable") or {})
    # Look for a cookie named "user_id"
    user_id = request.cookies.get("user_id", None)
    if user_id is None:
        raise HTTPException(
            status_code=400,
            detail="No user id found. Please set a cookie named 'user_id'.",
        )
    # Reject malformed (or empty) ids here with a client error. Otherwise
    # they would only fail later, as the ValueError raised by the history
    # factory.
    if not _is_valid_identifier(user_id):
        raise HTTPException(
            status_code=400,
            detail=(
                "Invalid user id. The 'user_id' cookie must only contain "
                "alphanumeric characters, hyphens, and underscores."
            ),
        )
    configurable["user_id"] = user_id
    config["configurable"] = configurable
    return config
# Declare a chain. The prompt has three parts: the assistant persona, the
# stored conversation (filled in under the "history" key), and the newest
# human message (the "human_input" key).
_PROMPT_MESSAGES = [
    ("system", "You're an assistant by the name of Bob."),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{human_input}"),
]
prompt = ChatPromptTemplate.from_messages(_PROMPT_MESSAGES)
chain = prompt | ChatOpenAI()
# Input schema for the chat endpoint; attached to the chain via
# .with_types(input_type=InputChat) further down.
class InputChat(TypedDict):
    """Input for the chat endpoint."""

    # The latest human message. It fills the "{human_input}" slot of the
    # prompt and is the input_messages_key of the history wrapper.
    human_input: str
    """Human input"""
# History factory backed by files under "chat_histories".
# create_session_factory creates that directory if it does not exist.
_session_factory = create_session_factory("chat_histories")

# Configurable fields exposed by the history wrapper. Their ids mirror the
# parameters of the factory's get_chat_history(user_id, conversation_id).
# Presumably the values are forwarded to it by keyword; TODO confirm against
# the RunnableWithMessageHistory docs. "user_id" is filled in from the
# request cookie by _per_request_config_modifier. "conversation_id" must be
# supplied by the client in its config.
_history_fields = [
    ConfigurableFieldSpec(
        id="user_id",
        annotation=str,
        name="User ID",
        description="Unique identifier for the user.",
        default="",
        is_shared=True,
    ),
    ConfigurableFieldSpec(
        id="conversation_id",
        annotation=str,
        name="Conversation ID",
        description="Unique identifier for the conversation.",
        default="",
        is_shared=True,
    ),
]

# Wrap the chain so past messages are loaded into "history" and new ones saved.
_history_runnable = RunnableWithMessageHistory(
    chain,
    _session_factory,
    input_messages_key="human_input",
    history_messages_key="history",
    history_factory_config=_history_fields,
)
chain_with_history = _history_runnable.with_types(input_type=InputChat)
# Mount chain_with_history on the app. Every request passes through
# _per_request_config_modifier, which injects the cookie's user id.
# Two endpoints are disabled:
#   - playground: the user id travels in a cookie header, which the playground
#     does not support sending right now.
#   - batch: it works if every input is given its own config. Without
#     validation, however, clients are likely to forget that, and batching
#     makes little sense for a chatbot in any case.
add_routes(
    app,
    chain_with_history,
    disabled_endpoints=["playground", "batch"],
    per_req_config_modifier=_per_request_config_modifier,
)
if __name__ == "__main__":
    # Imported here so uvicorn is only required when the file runs as a script.
    from uvicorn import run

    run(app, port=8000, host="localhost")