Skip to content

Commit b54791c

Browse files
fix: increase recursion limit, cache compilation tree hashes (#184)
* fix: increase recursion limit, cache compilation tree hashes * don't decrease recursion limit * add comment explaining _node_hash method
1 parent 034f71f commit b54791c

File tree

3 files changed

+83
-3
lines changed

3 files changed

+83
-3
lines changed

bigframes/core/nodes.py

+72-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from dataclasses import dataclass, field
17+
from dataclasses import dataclass, field, fields
1818
import functools
1919
import typing
2020
from typing import Optional, Tuple
@@ -66,6 +66,13 @@ def session(self):
6666
return sessions[0]
6767
return None
6868

69+
# BigFrameNode trees can be very deep so it's important to avoid recalculating the hash from scratch
70+
# Each subclass of BigFrameNode should use this property to implement __hash__
71+
# The default dataclass-generated __hash__ method is not cached
72+
@functools.cached_property
73+
def _node_hash(self):
74+
return hash(tuple(hash(getattr(self, field.name)) for field in fields(self)))
75+
6976

7077
@dataclass(frozen=True)
7178
class UnaryNode(BigFrameNode):
@@ -95,6 +102,9 @@ class JoinNode(BigFrameNode):
95102
def child_nodes(self) -> typing.Sequence[BigFrameNode]:
96103
return (self.left_child, self.right_child)
97104

105+
def __hash__(self):
106+
return self._node_hash
107+
98108

99109
@dataclass(frozen=True)
100110
class ConcatNode(BigFrameNode):
@@ -104,13 +114,19 @@ class ConcatNode(BigFrameNode):
104114
def child_nodes(self) -> typing.Sequence[BigFrameNode]:
105115
return self.children
106116

117+
def __hash__(self):
118+
return self._node_hash
119+
107120

108121
# Input Nodex
109122
@dataclass(frozen=True)
110123
class ReadLocalNode(BigFrameNode):
111124
feather_bytes: bytes
112125
column_ids: typing.Tuple[str, ...]
113126

127+
def __hash__(self):
128+
return self._node_hash
129+
114130

115131
# TODO: Refactor to take raw gbq object reference
116132
@dataclass(frozen=True)
@@ -125,45 +141,70 @@ class ReadGbqNode(BigFrameNode):
125141
def session(self):
126142
return (self.table_session,)
127143

144+
def __hash__(self):
145+
return self._node_hash
146+
128147

129148
# Unary nodes
130149
@dataclass(frozen=True)
131150
class DropColumnsNode(UnaryNode):
132151
columns: Tuple[str, ...]
133152

153+
def __hash__(self):
154+
return self._node_hash
155+
134156

135157
@dataclass(frozen=True)
136158
class PromoteOffsetsNode(UnaryNode):
137159
col_id: str
138160

161+
def __hash__(self):
162+
return self._node_hash
163+
139164

140165
@dataclass(frozen=True)
141166
class FilterNode(UnaryNode):
142167
predicate_id: str
143168
keep_null: bool = False
144169

170+
def __hash__(self):
171+
return self._node_hash
172+
145173

146174
@dataclass(frozen=True)
147175
class OrderByNode(UnaryNode):
148176
by: Tuple[OrderingColumnReference, ...]
149177

178+
def __hash__(self):
179+
return self._node_hash
180+
150181

151182
@dataclass(frozen=True)
152183
class ReversedNode(UnaryNode):
153-
pass
184+
# useless field to make sure it has a distinct hash
185+
reversed: bool = True
186+
187+
def __hash__(self):
188+
return self._node_hash
154189

155190

156191
@dataclass(frozen=True)
157192
class SelectNode(UnaryNode):
158193
column_ids: typing.Tuple[str, ...]
159194

195+
def __hash__(self):
196+
return self._node_hash
197+
160198

161199
@dataclass(frozen=True)
162200
class ProjectUnaryOpNode(UnaryNode):
163201
input_id: str
164202
op: ops.UnaryOp
165203
output_id: Optional[str] = None
166204

205+
def __hash__(self):
206+
return self._node_hash
207+
167208

168209
@dataclass(frozen=True)
169210
class ProjectBinaryOpNode(UnaryNode):
@@ -172,6 +213,9 @@ class ProjectBinaryOpNode(UnaryNode):
172213
op: ops.BinaryOp
173214
output_id: str
174215

216+
def __hash__(self):
217+
return self._node_hash
218+
175219

176220
@dataclass(frozen=True)
177221
class ProjectTernaryOpNode(UnaryNode):
@@ -181,19 +225,28 @@ class ProjectTernaryOpNode(UnaryNode):
181225
op: ops.TernaryOp
182226
output_id: str
183227

228+
def __hash__(self):
229+
return self._node_hash
230+
184231

185232
@dataclass(frozen=True)
186233
class AggregateNode(UnaryNode):
187234
aggregations: typing.Tuple[typing.Tuple[str, agg_ops.AggregateOp, str], ...]
188235
by_column_ids: typing.Tuple[str, ...] = tuple([])
189236
dropna: bool = True
190237

238+
def __hash__(self):
239+
return self._node_hash
240+
191241

192242
# TODO: Unify into aggregate
193243
@dataclass(frozen=True)
194244
class CorrNode(UnaryNode):
195245
corr_aggregations: typing.Tuple[typing.Tuple[str, str, str], ...]
196246

247+
def __hash__(self):
248+
return self._node_hash
249+
197250

198251
@dataclass(frozen=True)
199252
class WindowOpNode(UnaryNode):
@@ -204,10 +257,14 @@ class WindowOpNode(UnaryNode):
204257
never_skip_nulls: bool = False
205258
skip_reproject_unsafe: bool = False
206259

260+
def __hash__(self):
261+
return self._node_hash
262+
207263

208264
@dataclass(frozen=True)
209265
class ReprojectOpNode(UnaryNode):
210-
pass
266+
def __hash__(self):
267+
return self._node_hash
211268

212269

213270
@dataclass(frozen=True)
@@ -223,19 +280,28 @@ class UnpivotNode(UnaryNode):
223280
] = (pandas.Float64Dtype(),)
224281
how: typing.Literal["left", "right"] = "left"
225282

283+
def __hash__(self):
284+
return self._node_hash
285+
226286

227287
@dataclass(frozen=True)
228288
class AssignNode(UnaryNode):
229289
source_id: str
230290
destination_id: str
231291

292+
def __hash__(self):
293+
return self._node_hash
294+
232295

233296
@dataclass(frozen=True)
234297
class AssignConstantNode(UnaryNode):
235298
destination_id: str
236299
value: typing.Hashable
237300
dtype: typing.Optional[bigframes.dtypes.Dtype]
238301

302+
def __hash__(self):
303+
return self._node_hash
304+
239305

240306
@dataclass(frozen=True)
241307
class RandomSampleNode(UnaryNode):
@@ -244,3 +310,6 @@ class RandomSampleNode(UnaryNode):
244310
@property
245311
def deterministic(self) -> bool:
246312
return False
313+
314+
def __hash__(self):
315+
return self._node_hash

bigframes/pandas/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from collections import namedtuple
2020
import inspect
21+
import sys
2122
import typing
2223
from typing import (
2324
Any,
@@ -657,6 +658,9 @@ def read_gbq_function(function_name: str):
657658
close_session = global_session.close_session
658659
reset_session = global_session.close_session
659660

661+
# SQL Compilation uses recursive algorithms on deep trees
662+
# 10M tree depth should be sufficient to generate any SQL that is under the BigQuery limit
663+
sys.setrecursionlimit(max(10000000, sys.getrecursionlimit()))
660664

661665
# Use __all__ to let type checkers know what is part of the public API.
662666
__all___ = [

tests/system/small/test_dataframe.py

+7
Original file line numberDiff line numberDiff line change
@@ -3667,6 +3667,13 @@ def test_df_dot_operator_series(
36673667
)
36683668

36693669

3670+
def test_recursion_limit(scalars_df_index):
3671+
scalars_df_index = scalars_df_index[["int64_too", "int64_col", "float64_col"]]
3672+
for i in range(400):
3673+
scalars_df_index = scalars_df_index + 4
3674+
scalars_df_index.to_pandas()
3675+
3676+
36703677
def test_to_pandas_downsampling_option_override(session):
36713678
df = session.read_gbq("bigframes-dev.bigframes_tests_sys.batting")
36723679
download_size = 1

0 commit comments

Comments
 (0)