-
Notifications
You must be signed in to change notification settings - Fork 6.6k
/
Copy pathapp_team_user_proxy.py
140 lines (123 loc) · 5.46 KB
/
app_team_user_proxy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from typing import List, cast
import chainlit as cl
import yaml
from autogen_agentchat.agents import AssistantAgent, UserProxyAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.messages import ModelClientStreamingChunkEvent, TextMessage
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_core import CancellationToken
from autogen_core.models import ChatCompletionClient
async def user_input_func(prompt: str, cancellation_token: CancellationToken | None = None) -> str:
    """Get user input from the UI for the user proxy agent."""
    try:
        reply = await cl.AskUserMessage(content=prompt).send()
    except TimeoutError:
        return "User did not provide any input within the time limit."
    # An empty reply means the user dismissed the prompt without answering.
    if not reply:
        return "User did not provide any input."
    return reply["output"]  # type: ignore
async def user_action_func(prompt: str, cancellation_token: CancellationToken | None = None) -> str:
    """Get user action from the UI for the user proxy agent."""
    try:
        result = await cl.AskActionMessage(
            content="Pick an action",
            actions=[
                cl.Action(name="approve", label="Approve", payload={"value": "approve"}),
                cl.Action(name="reject", label="Reject", payload={"value": "reject"}),
            ],
        ).send()
    except TimeoutError:
        return "User did not provide any input within the time limit."
    # No response or no payload: the user never picked an action.
    if not (result and result.get("payload")):  # type: ignore
        return "User did not provide any input."
    chosen = result.get("payload").get("value")  # type: ignore
    # "APPROVE." matches the text-mention termination condition and ends the run.
    return "APPROVE." if chosen == "approve" else "REJECT."
@cl.on_chat_start  # type: ignore
async def start_chat() -> None:
    """Build the agent team for a new chat session and stash it in the session."""
    # Build the model client straight from the YAML configuration file.
    with open("model_config.yaml", "r") as config_file:
        model_client = ChatCompletionClient.load_component(yaml.safe_load(config_file))

    # Assistant agent: produces the main responses, streamed token by token.
    assistant = AssistantAgent(
        name="assistant",
        model_client=model_client,
        system_message="You are a helpful assistant.",
        model_client_stream=True,  # Enable model client streaming.
    )

    # Critic agent: reviews the assistant's output and can approve it.
    critic = AssistantAgent(
        name="critic",
        model_client=model_client,
        system_message="You are a critic. Provide constructive feedback. "
        "Respond with 'APPROVE' if your feedback has been addressed.",
        model_client_stream=True,  # Enable model client streaming.
    )

    # User proxy: asks the human for an approve/reject action in the UI.
    user = UserProxyAgent(
        name="user",
        # input_func=user_input_func, # Uncomment this line to use user input as text.
        input_func=user_action_func,  # Uncomment this line to use user input as action.
    )

    # Stop once the human (the "user" source) says "APPROVE".
    termination = TextMentionTermination("APPROVE", sources=["user"])
    team = RoundRobinGroupChat([assistant, critic, user], termination_condition=termination)

    # Persist session state for the message handler to pick up.
    cl.user_session.set("prompt_history", "")  # type: ignore
    cl.user_session.set("team", team)  # type: ignore
@cl.set_starters  # type: ignore
async def set_starts() -> List[cl.Starter]:
    """Provide the suggested conversation starters shown in the UI."""
    # (label, message) pairs rendered as clickable starters.
    suggestions = [
        ("Poem Writing", "Write a poem about the ocean."),
        ("Story Writing", "Write a story about a detective solving a mystery."),
        ("Write Code", "Write a function that merge two list of numbers into single sorted list."),
    ]
    return [cl.Starter(label=label, message=message) for label, message in suggestions]
@cl.on_message  # type: ignore
async def chat(message: cl.Message) -> None:
    """Run the team on the incoming message and stream its output to the UI."""
    # Retrieve the team stored for this session at chat start.
    team = cast(RoundRobinGroupChat, cl.user_session.get("team"))  # type: ignore

    # Holds the in-progress streamed UI message, if any.
    streaming_response: cl.Message | None = None

    # NOTE: branch order matters below — any non-chunk event that arrives while
    # a stream is open is consumed by the flush branch, never the later ones.
    async for event in team.run_stream(
        task=[TextMessage(content=message.content, source="user")],
        cancellation_token=CancellationToken(),
    ):
        if isinstance(event, ModelClientStreamingChunkEvent):
            # Append the chunk to the current streamed message, creating the
            # UI message lazily on the first chunk.
            if streaming_response is None:
                streaming_response = cl.Message(content="", author=event.source)
            await streaming_response.stream_token(event.content)
        elif streaming_response is not None:
            # First non-chunk event after a stream: the stream is complete.
            # Send the accumulated message and skip this event, since it is
            # just the full copy of what was already streamed. Resetting to
            # None keeps us out of this branch until the next stream finishes.
            await streaming_response.send()
            streaming_response = None
        elif isinstance(event, TaskResult):
            # The run is over; report why it stopped.
            final_message = "Task terminated. "
            if event.stop_reason:
                final_message += event.stop_reason
            await cl.Message(content=final_message).send()
        else:
            # Skip all other message types.
            pass