@@ -549,6 +549,74 @@ def test_get_table_column_schema(self):
549
549
)
550
550
self .assertEqual (result , expected )
551
551
552
+ def test_peek_iterator_aborted (self ):
553
+ """
554
+ Checking that an Aborted exception is retried in case it happened
555
+ while streaming the first element with a PeekIterator.
556
+ """
557
+ from google .api_core .exceptions import Aborted
558
+ from google .cloud .spanner_dbapi .connection import connect
559
+
560
+ with mock .patch (
561
+ "google.cloud.spanner_v1.instance.Instance.exists" , return_value = True ,
562
+ ):
563
+ with mock .patch (
564
+ "google.cloud.spanner_v1.database.Database.exists" , return_value = True ,
565
+ ):
566
+ connection = connect ("test-instance" , "test-database" )
567
+
568
+ cursor = connection .cursor ()
569
+ with mock .patch (
570
+ "google.cloud.spanner_dbapi.utils.PeekIterator.__init__" ,
571
+ side_effect = (Aborted ("Aborted" ), None ),
572
+ ):
573
+ with mock .patch (
574
+ "google.cloud.spanner_dbapi.connection.Connection.retry_transaction"
575
+ ) as retry_mock :
576
+ with mock .patch (
577
+ "google.cloud.spanner_dbapi.connection.Connection.run_statement" ,
578
+ return_value = ((1 , 2 , 3 ), None ),
579
+ ):
580
+ cursor .execute ("SELECT * FROM table_name" )
581
+
582
+ retry_mock .assert_called_with ()
583
+
584
+ def test_peek_iterator_aborted_autocommit (self ):
585
+ """
586
+ Checking that an Aborted exception is retried in case it happened while
587
+ streaming the first element with a PeekIterator in autocommit mode.
588
+ """
589
+ from google .api_core .exceptions import Aborted
590
+ from google .cloud .spanner_dbapi .connection import connect
591
+
592
+ with mock .patch (
593
+ "google.cloud.spanner_v1.instance.Instance.exists" , return_value = True ,
594
+ ):
595
+ with mock .patch (
596
+ "google.cloud.spanner_v1.database.Database.exists" , return_value = True ,
597
+ ):
598
+ connection = connect ("test-instance" , "test-database" )
599
+
600
+ connection .autocommit = True
601
+ cursor = connection .cursor ()
602
+ with mock .patch (
603
+ "google.cloud.spanner_dbapi.utils.PeekIterator.__init__" ,
604
+ side_effect = (Aborted ("Aborted" ), None ),
605
+ ):
606
+ with mock .patch (
607
+ "google.cloud.spanner_dbapi.connection.Connection.retry_transaction"
608
+ ) as retry_mock :
609
+ with mock .patch (
610
+ "google.cloud.spanner_dbapi.connection.Connection.run_statement" ,
611
+ return_value = ((1 , 2 , 3 ), None ),
612
+ ):
613
+ with mock .patch (
614
+ "google.cloud.spanner_v1.database.Database.snapshot"
615
+ ):
616
+ cursor .execute ("SELECT * FROM table_name" )
617
+
618
+ retry_mock .assert_called_with ()
619
+
552
620
def test_fetchone_retry_aborted (self ):
553
621
"""Check that aborted fetch re-executing transaction."""
554
622
from google .api_core .exceptions import Aborted
0 commit comments