-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
Copy pathagent_output.py
194 lines (159 loc) · 6.96 KB
/
agent_output.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import abc
from dataclasses import dataclass
from typing import Any
from pydantic import BaseModel, TypeAdapter
from typing_extensions import TypedDict, get_args, get_origin
from .exceptions import ModelBehaviorError, UserError
from .strict_schema import ensure_strict_json_schema
from .tracing import SpanError
from .util import _error_tracing, _json
# Key under which an output type that is not a BaseModel/dict (e.g. a list or scalar) is nested,
# so the schema handed to the model is an object; `AgentOutputSchema.validate_json` unwraps it.
_WRAPPER_DICT_KEY = "response"
class AgentOutputSchemaBase(abc.ABC):
    """Abstract description of an agent's output type.

    Implementations expose the JSON schema the LLM is asked to follow, and turn the JSON text
    produced by the LLM back into a Python object of the output type.
    """

    @abc.abstractmethod
    def is_plain_text(self) -> bool:
        """Return True when the output is plain text rather than a JSON object."""

    @abc.abstractmethod
    def name(self) -> str:
        """Return a human-readable name for the output type."""

    @abc.abstractmethod
    def json_schema(self) -> dict[str, Any]:
        """Return the JSON schema of the output.

        Callers only invoke this when the output type is not plain text.
        """

    @abc.abstractmethod
    def is_strict_json_schema(self) -> bool:
        """Return True when the JSON schema is in strict mode.

        Strict mode restricts which JSON Schema features may be used, but guarantees valid
        JSON. See https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
        for details.
        """

    @abc.abstractmethod
    def validate_json(self, json_str: str) -> Any:
        """Validate ``json_str`` against the output type and return the validated object.

        Implementations must raise ``ModelBehaviorError`` when the JSON is invalid.
        """
@dataclass(init=False)
class AgentOutputSchema(AgentOutputSchemaBase):
    """An object that captures the JSON schema of the output, as well as validating/parsing JSON
    produced by the LLM into the output type.

    Output types that are neither a pydantic ``BaseModel`` subclass nor a ``dict`` subclass are
    wrapped in a single-key object (key ``_WRAPPER_DICT_KEY``) for schema generation and
    validation, and unwrapped again in ``validate_json``.
    """

    output_type: type[Any]
    """The type of the output."""

    _type_adapter: TypeAdapter[Any]
    """A type adapter that wraps the output type, so that we can validate JSON."""

    _is_wrapped: bool
    """Whether the output type is wrapped in a dictionary. This is generally done if the base
    output type cannot be represented as a JSON Schema object.
    """

    _output_schema: dict[str, Any]
    """The JSON schema of the output."""

    _strict_json_schema: bool
    """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
    as it increases the likelihood of correct JSON input.
    """

    def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
        """
        Args:
            output_type: The type of the output. ``None`` or ``str`` means plain-text output.
            strict_json_schema: Whether the JSON schema is in strict mode. We **strongly** recommend
                setting this to True, as it increases the likelihood of correct JSON input.

        Raises:
            UserError: If ``strict_json_schema`` is True and the generated schema cannot be
                converted to a strict schema.
        """
        self.output_type = output_type
        self._strict_json_schema = strict_json_schema

        if output_type is None or output_type is str:
            # Plain-text output: nothing to wrap, and no strict-schema conversion is applied.
            self._is_wrapped = False
            self._type_adapter = TypeAdapter(output_type)
            self._output_schema = self._type_adapter.json_schema()
            return

        # We should wrap for things that are not plain text, and for things that would definitely
        # not be a JSON Schema object.
        self._is_wrapped = not _is_subclass_of_base_model_or_dict(output_type)
        if self._is_wrapped:
            OutputType = TypedDict(
                "OutputType",
                {
                    _WRAPPER_DICT_KEY: output_type,  # type: ignore
                },
            )
            self._type_adapter = TypeAdapter(OutputType)
        else:
            self._type_adapter = TypeAdapter(output_type)
        # Both branches derive the schema from whichever adapter was chosen above.
        self._output_schema = self._type_adapter.json_schema()

        if self._strict_json_schema:
            try:
                self._output_schema = ensure_strict_json_schema(self._output_schema)
            except UserError as e:
                # Point users at the knob this class actually exposes (`strict_json_schema`).
                raise UserError(
                    "Strict JSON schema is enabled, but the output type is not valid. "
                    "Either make the output type strict, or wrap your type with "
                    "AgentOutputSchema(YourType, strict_json_schema=False)"
                ) from e

    def is_plain_text(self) -> bool:
        """Whether the output type is plain text (versus a JSON object)."""
        return self.output_type is None or self.output_type is str

    def is_strict_json_schema(self) -> bool:
        """Whether the JSON schema is in strict mode."""
        return self._strict_json_schema

    def json_schema(self) -> dict[str, Any]:
        """The JSON schema of the output type.

        Raises:
            UserError: If the output type is plain text.
        """
        if self.is_plain_text():
            raise UserError("Output type is plain text, so no JSON schema is available")
        return self._output_schema

    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type. Returns the validated object, or raises
        a `ModelBehaviorError` if the JSON is invalid.

        For wrapped output types, the value stored under ``_WRAPPER_DICT_KEY`` is returned rather
        than the wrapper object. Unwrapping failures are also attached to the current tracing span.
        """
        validated = _json.validate_json(json_str, self._type_adapter, partial=False)
        if self._is_wrapped:
            if not isinstance(validated, dict):
                _error_tracing.attach_error_to_current_span(
                    SpanError(
                        message="Invalid JSON",
                        data={"details": f"Expected a dict, got {type(validated)}"},
                    )
                )
                raise ModelBehaviorError(
                    f"Expected a dict, got {type(validated)} for JSON: {json_str}"
                )

            if _WRAPPER_DICT_KEY not in validated:
                _error_tracing.attach_error_to_current_span(
                    SpanError(
                        message="Invalid JSON",
                        data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"},
                    )
                )
                raise ModelBehaviorError(
                    f"Could not find key {_WRAPPER_DICT_KEY} in JSON: {json_str}"
                )
            return validated[_WRAPPER_DICT_KEY]
        return validated

    def name(self) -> str:
        """The name of the output type."""
        return _type_to_str(self.output_type)
def _is_subclass_of_base_model_or_dict(t: Any) -> bool:
if not isinstance(t, type):
return False
# If it's a generic alias, 'origin' will be the actual type, e.g. 'list'
origin = get_origin(t)
allowed_types = (BaseModel, dict)
# If it's a generic alias e.g. list[str], then we should check the origin type i.e. list
return issubclass(origin or t, allowed_types)
def _type_to_str(t: type[Any]) -> str:
origin = get_origin(t)
args = get_args(t)
if origin is None:
# It's a simple type like `str`, `int`, etc.
return t.__name__
elif args:
args_str = ", ".join(_type_to_str(arg) for arg in args)
return f"{origin.__name__}[{args_str}]"
else:
return str(t)