Skip to content

Commit 0fcfc23

Browse files
authored
fix: correctly set resume token when restarting streams (#314)
* fix: correctly set resume token for restarting streams * style: fix lint * docs: update docstring * test: fix assertion Co-authored-by: larkee <[email protected]>
1 parent 772aa3c commit 0fcfc23

File tree

3 files changed

+76
-48
lines changed

3 files changed

+76
-48
lines changed

google/cloud/spanner_v1/database.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -518,11 +518,11 @@ def execute_pdml():
518518
param_types=param_types,
519519
query_options=query_options,
520520
)
521-
restart = functools.partial(
522-
api.execute_streaming_sql, request=request, metadata=metadata,
521+
method = functools.partial(
522+
api.execute_streaming_sql, metadata=metadata,
523523
)
524524

525-
iterator = _restart_on_unavailable(restart)
525+
iterator = _restart_on_unavailable(method, request)
526526

527527
result_set = StreamedResultSet(iterator)
528528
list(result_set) # consume all partials

google/cloud/spanner_v1/snapshot.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,21 @@
4141
)
4242

4343

44-
def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=None):
44+
def _restart_on_unavailable(
45+
method, request, trace_name=None, session=None, attributes=None
46+
):
4547
"""Restart iteration after :exc:`.ServiceUnavailable`.
4648
47-
:type restart: callable
48-
:param restart: curried function returning iterator
49+
:type method: callable
50+
:param method: function returning iterator
51+
52+
:type request: proto
53+
:param request: request proto to call the method with
4954
"""
5055
resume_token = b""
5156
item_buffer = []
5257
with trace_call(trace_name, session, attributes):
53-
iterator = restart()
58+
iterator = method(request=request)
5459
while True:
5560
try:
5661
for item in iterator:
@@ -61,7 +66,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N
6166
except ServiceUnavailable:
6267
del item_buffer[:]
6368
with trace_call(trace_name, session, attributes):
64-
iterator = restart(resume_token=resume_token)
69+
request.resume_token = resume_token
70+
iterator = method(request=request)
6571
continue
6672
except InternalServerError as exc:
6773
resumable_error = any(
@@ -72,7 +78,8 @@ def _restart_on_unavailable(restart, trace_name=None, session=None, attributes=N
7278
raise
7379
del item_buffer[:]
7480
with trace_call(trace_name, session, attributes):
75-
iterator = restart(resume_token=resume_token)
81+
request.resume_token = resume_token
82+
iterator = method(request=request)
7683
continue
7784

7885
if len(item_buffer) == 0:
@@ -189,7 +196,11 @@ def read(
189196

190197
trace_attributes = {"table_id": table, "columns": columns}
191198
iterator = _restart_on_unavailable(
192-
restart, "CloudSpanner.ReadOnlyTransaction", self._session, trace_attributes
199+
restart,
200+
request,
201+
"CloudSpanner.ReadOnlyTransaction",
202+
self._session,
203+
trace_attributes,
193204
)
194205

195206
self._read_request_count += 1
@@ -302,6 +313,7 @@ def execute_sql(
302313
trace_attributes = {"db.statement": sql}
303314
iterator = _restart_on_unavailable(
304315
restart,
316+
request,
305317
"CloudSpanner.ReadWriteTransaction",
306318
self._session,
307319
trace_attributes,

tests/unit/test_snapshot.py

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,12 @@
4747

4848

4949
class Test_restart_on_unavailable(OpenTelemetryBase):
50-
def _call_fut(self, restart, span_name=None, session=None, attributes=None):
50+
def _call_fut(
51+
self, restart, request, span_name=None, session=None, attributes=None
52+
):
5153
from google.cloud.spanner_v1.snapshot import _restart_on_unavailable
5254

53-
return _restart_on_unavailable(restart, span_name, session, attributes)
55+
return _restart_on_unavailable(restart, request, span_name, session, attributes)
5456

5557
def _make_item(self, value, resume_token=b""):
5658
return mock.Mock(
@@ -59,18 +61,21 @@ def _make_item(self, value, resume_token=b""):
5961

6062
def test_iteration_w_empty_raw(self):
6163
raw = _MockIterator()
64+
request = mock.Mock(test="test", spec=["test", "resume_token"])
6265
restart = mock.Mock(spec=[], return_value=raw)
63-
resumable = self._call_fut(restart)
66+
resumable = self._call_fut(restart, request)
6467
self.assertEqual(list(resumable), [])
68+
restart.assert_called_once_with(request=request)
6569
self.assertNoSpans()
6670

6771
def test_iteration_w_non_empty_raw(self):
6872
ITEMS = (self._make_item(0), self._make_item(1))
6973
raw = _MockIterator(*ITEMS)
74+
request = mock.Mock(test="test", spec=["test", "resume_token"])
7075
restart = mock.Mock(spec=[], return_value=raw)
71-
resumable = self._call_fut(restart)
76+
resumable = self._call_fut(restart, request)
7277
self.assertEqual(list(resumable), list(ITEMS))
73-
restart.assert_called_once_with()
78+
restart.assert_called_once_with(request=request)
7479
self.assertNoSpans()
7580

7681
def test_iteration_w_raw_w_resume_tken(self):
@@ -81,10 +86,11 @@ def test_iteration_w_raw_w_resume_tken(self):
8186
self._make_item(3),
8287
)
8388
raw = _MockIterator(*ITEMS)
89+
request = mock.Mock(test="test", spec=["test", "resume_token"])
8490
restart = mock.Mock(spec=[], return_value=raw)
85-
resumable = self._call_fut(restart)
91+
resumable = self._call_fut(restart, request)
8692
self.assertEqual(list(resumable), list(ITEMS))
87-
restart.assert_called_once_with()
93+
restart.assert_called_once_with(request=request)
8894
self.assertNoSpans()
8995

9096
def test_iteration_w_raw_raising_unavailable_no_token(self):
@@ -97,10 +103,12 @@ def test_iteration_w_raw_raising_unavailable_no_token(self):
97103
)
98104
before = _MockIterator(fail_after=True, error=ServiceUnavailable("testing"))
99105
after = _MockIterator(*ITEMS)
106+
request = mock.Mock(test="test", spec=["test", "resume_token"])
100107
restart = mock.Mock(spec=[], side_effect=[before, after])
101-
resumable = self._call_fut(restart)
108+
resumable = self._call_fut(restart, request)
102109
self.assertEqual(list(resumable), list(ITEMS))
103-
self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")])
110+
self.assertEqual(len(restart.mock_calls), 2)
111+
self.assertEqual(request.resume_token, b"")
104112
self.assertNoSpans()
105113

106114
def test_iteration_w_raw_raising_retryable_internal_error_no_token(self):
@@ -118,10 +126,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_no_token(self):
118126
),
119127
)
120128
after = _MockIterator(*ITEMS)
129+
request = mock.Mock(test="test", spec=["test", "resume_token"])
121130
restart = mock.Mock(spec=[], side_effect=[before, after])
122-
resumable = self._call_fut(restart)
131+
resumable = self._call_fut(restart, request)
123132
self.assertEqual(list(resumable), list(ITEMS))
124-
self.assertEqual(restart.mock_calls, [mock.call(), mock.call(resume_token=b"")])
133+
self.assertEqual(len(restart.mock_calls), 2)
134+
self.assertEqual(request.resume_token, b"")
125135
self.assertNoSpans()
126136

127137
def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self):
@@ -134,11 +144,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_no_token(self):
134144
)
135145
before = _MockIterator(fail_after=True, error=InternalServerError("testing"))
136146
after = _MockIterator(*ITEMS)
147+
request = mock.Mock(spec=["resume_token"])
137148
restart = mock.Mock(spec=[], side_effect=[before, after])
138-
resumable = self._call_fut(restart)
149+
resumable = self._call_fut(restart, request)
139150
with self.assertRaises(InternalServerError):
140151
list(resumable)
141-
self.assertEqual(restart.mock_calls, [mock.call()])
152+
restart.assert_called_once_with(request=request)
142153
self.assertNoSpans()
143154

144155
def test_iteration_w_raw_raising_unavailable(self):
@@ -151,12 +162,12 @@ def test_iteration_w_raw_raising_unavailable(self):
151162
*(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing")
152163
)
153164
after = _MockIterator(*LAST)
165+
request = mock.Mock(test="test", spec=["test", "resume_token"])
154166
restart = mock.Mock(spec=[], side_effect=[before, after])
155-
resumable = self._call_fut(restart)
167+
resumable = self._call_fut(restart, request)
156168
self.assertEqual(list(resumable), list(FIRST + LAST))
157-
self.assertEqual(
158-
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
159-
)
169+
self.assertEqual(len(restart.mock_calls), 2)
170+
self.assertEqual(request.resume_token, RESUME_TOKEN)
160171
self.assertNoSpans()
161172

162173
def test_iteration_w_raw_raising_retryable_internal_error(self):
@@ -173,12 +184,12 @@ def test_iteration_w_raw_raising_retryable_internal_error(self):
173184
)
174185
)
175186
after = _MockIterator(*LAST)
187+
request = mock.Mock(test="test", spec=["test", "resume_token"])
176188
restart = mock.Mock(spec=[], side_effect=[before, after])
177-
resumable = self._call_fut(restart)
189+
resumable = self._call_fut(restart, request)
178190
self.assertEqual(list(resumable), list(FIRST + LAST))
179-
self.assertEqual(
180-
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
181-
)
191+
self.assertEqual(len(restart.mock_calls), 2)
192+
self.assertEqual(request.resume_token, RESUME_TOKEN)
182193
self.assertNoSpans()
183194

184195
def test_iteration_w_raw_raising_non_retryable_internal_error(self):
@@ -191,11 +202,12 @@ def test_iteration_w_raw_raising_non_retryable_internal_error(self):
191202
*(FIRST + SECOND), fail_after=True, error=InternalServerError("testing")
192203
)
193204
after = _MockIterator(*LAST)
205+
request = mock.Mock(test="test", spec=["test", "resume_token"])
194206
restart = mock.Mock(spec=[], side_effect=[before, after])
195-
resumable = self._call_fut(restart)
207+
resumable = self._call_fut(restart, request)
196208
with self.assertRaises(InternalServerError):
197209
list(resumable)
198-
self.assertEqual(restart.mock_calls, [mock.call()])
210+
restart.assert_called_once_with(request=request)
199211
self.assertNoSpans()
200212

201213
def test_iteration_w_raw_raising_unavailable_after_token(self):
@@ -207,12 +219,12 @@ def test_iteration_w_raw_raising_unavailable_after_token(self):
207219
*FIRST, fail_after=True, error=ServiceUnavailable("testing")
208220
)
209221
after = _MockIterator(*SECOND)
222+
request = mock.Mock(test="test", spec=["test", "resume_token"])
210223
restart = mock.Mock(spec=[], side_effect=[before, after])
211-
resumable = self._call_fut(restart)
224+
resumable = self._call_fut(restart, request)
212225
self.assertEqual(list(resumable), list(FIRST + SECOND))
213-
self.assertEqual(
214-
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
215-
)
226+
self.assertEqual(len(restart.mock_calls), 2)
227+
self.assertEqual(request.resume_token, RESUME_TOKEN)
216228
self.assertNoSpans()
217229

218230
def test_iteration_w_raw_raising_retryable_internal_error_after_token(self):
@@ -228,12 +240,12 @@ def test_iteration_w_raw_raising_retryable_internal_error_after_token(self):
228240
)
229241
)
230242
after = _MockIterator(*SECOND)
243+
request = mock.Mock(test="test", spec=["test", "resume_token"])
231244
restart = mock.Mock(spec=[], side_effect=[before, after])
232-
resumable = self._call_fut(restart)
245+
resumable = self._call_fut(restart, request)
233246
self.assertEqual(list(resumable), list(FIRST + SECOND))
234-
self.assertEqual(
235-
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
236-
)
247+
self.assertEqual(len(restart.mock_calls), 2)
248+
self.assertEqual(request.resume_token, RESUME_TOKEN)
237249
self.assertNoSpans()
238250

239251
def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self):
@@ -245,19 +257,23 @@ def test_iteration_w_raw_raising_non_retryable_internal_error_after_token(self):
245257
*FIRST, fail_after=True, error=InternalServerError("testing")
246258
)
247259
after = _MockIterator(*SECOND)
260+
request = mock.Mock(test="test", spec=["test", "resume_token"])
248261
restart = mock.Mock(spec=[], side_effect=[before, after])
249-
resumable = self._call_fut(restart)
262+
resumable = self._call_fut(restart, request)
250263
with self.assertRaises(InternalServerError):
251264
list(resumable)
252-
self.assertEqual(restart.mock_calls, [mock.call()])
265+
restart.assert_called_once_with(request=request)
253266
self.assertNoSpans()
254267

255268
def test_iteration_w_span_creation(self):
256269
name = "TestSpan"
257270
extra_atts = {"test_att": 1}
258271
raw = _MockIterator()
272+
request = mock.Mock(test="test", spec=["test", "resume_token"])
259273
restart = mock.Mock(spec=[], return_value=raw)
260-
resumable = self._call_fut(restart, name, _Session(_Database()), extra_atts)
274+
resumable = self._call_fut(
275+
restart, request, name, _Session(_Database()), extra_atts
276+
)
261277
self.assertEqual(list(resumable), [])
262278
self.assertSpanAttributes(name, attributes=dict(BASE_ATTRIBUTES, test_att=1))
263279

@@ -272,13 +288,13 @@ def test_iteration_w_multiple_span_creation(self):
272288
*(FIRST + SECOND), fail_after=True, error=ServiceUnavailable("testing")
273289
)
274290
after = _MockIterator(*LAST)
291+
request = mock.Mock(test="test", spec=["test", "resume_token"])
275292
restart = mock.Mock(spec=[], side_effect=[before, after])
276293
name = "TestSpan"
277-
resumable = self._call_fut(restart, name, _Session(_Database()))
294+
resumable = self._call_fut(restart, request, name, _Session(_Database()))
278295
self.assertEqual(list(resumable), list(FIRST + LAST))
279-
self.assertEqual(
280-
restart.mock_calls, [mock.call(), mock.call(resume_token=RESUME_TOKEN)]
281-
)
296+
self.assertEqual(len(restart.mock_calls), 2)
297+
self.assertEqual(request.resume_token, RESUME_TOKEN)
282298

283299
span_list = self.memory_exporter.get_finished_spans()
284300
self.assertEqual(len(span_list), 2)

0 commit comments

Comments
 (0)