Skip to content
This repository was archived by the owner on Jul 6, 2023. It is now read-only.

Commit 33eb545

Browse files
feat: add context manager support in client (#41)
- [ ] Regenerate this pull request now. chore: fix docstring for first attribute of protos committer: @busunkim96 PiperOrigin-RevId: 401271153 Source-Link: googleapis/googleapis@787f8c9 Source-Link: https://github.com/googleapis/googleapis-gen/commit/81decffe9fc72396a8153e756d1d67a6eecfd620 Copy-Tag: eyJwIjoiLmdpdGh1Yi8uT3dsQm90LnlhbWwiLCJoIjoiODFkZWNmZmU5ZmM3MjM5NmE4MTUzZTc1NmQxZDY3YTZlZWNmZDYyMCJ9
1 parent 10f8927 commit 33eb545

File tree

7 files changed

+99
-11
lines changed

7 files changed

+99
-11
lines changed

google/cloud/shell_v1/services/cloud_shell_service/async_client.py

+6
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,12 @@ async def remove_public_key(
518518
# Done; return the response.
519519
return response
520520

521+
async def __aenter__(self):
522+
return self
523+
524+
async def __aexit__(self, exc_type, exc, tb):
525+
await self.transport.close()
526+
521527

522528
try:
523529
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(

google/cloud/shell_v1/services/cloud_shell_service/client.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -352,10 +352,7 @@ def __init__(
352352
client_cert_source_for_mtls=client_cert_source_func,
353353
quota_project_id=client_options.quota_project_id,
354354
client_info=client_info,
355-
always_use_jwt_access=(
356-
Transport == type(self).get_transport_class("grpc")
357-
or Transport == type(self).get_transport_class("grpc_asyncio")
358-
),
355+
always_use_jwt_access=True,
359356
)
360357

361358
def get_environment(
@@ -699,6 +696,19 @@ def remove_public_key(
699696
# Done; return the response.
700697
return response
701698

699+
def __enter__(self):
700+
return self
701+
702+
def __exit__(self, type, value, traceback):
703+
"""Releases underlying transport's resources.
704+
705+
.. warning::
706+
ONLY use as a context manager if the transport is NOT shared
707+
with other clients! Exiting the with block will CLOSE the transport
708+
and may cause errors in other clients!
709+
"""
710+
self.transport.close()
711+
702712

703713
try:
704714
DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo(

google/cloud/shell_v1/services/cloud_shell_service/transports/base.py

+9
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,15 @@ def _prep_wrapped_messages(self, client_info):
184184
),
185185
}
186186

187+
def close(self):
188+
"""Closes resources associated with the transport.
189+
190+
.. warning::
191+
Only call this method if the transport is NOT shared
192+
with other clients - this may cause errors in other clients!
193+
"""
194+
raise NotImplementedError()
195+
187196
@property
188197
def operations_client(self) -> operations_v1.OperationsClient:
189198
"""Return the client designed to process long-running operations."""

google/cloud/shell_v1/services/cloud_shell_service/transports/grpc.py

+3
Original file line numberDiff line numberDiff line change
@@ -397,5 +397,8 @@ def remove_public_key(
397397
)
398398
return self._stubs["remove_public_key"]
399399

400+
def close(self):
401+
self.grpc_channel.close()
402+
400403

401404
__all__ = ("CloudShellServiceGrpcTransport",)

google/cloud/shell_v1/services/cloud_shell_service/transports/grpc_asyncio.py

+3
Original file line numberDiff line numberDiff line change
@@ -412,5 +412,8 @@ def remove_public_key(
412412
)
413413
return self._stubs["remove_public_key"]
414414

415+
def close(self):
416+
return self.grpc_channel.close()
417+
415418

416419
__all__ = ("CloudShellServiceGrpcAsyncIOTransport",)

google/cloud/shell_v1/types/cloudshell.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,15 @@ class GetEnvironmentRequest(proto.Message):
129129
class CreateEnvironmentMetadata(proto.Message):
130130
r"""Message included in the metadata field of operations returned from
131131
[CreateEnvironment][google.cloud.shell.v1.CloudShellService.CreateEnvironment].
132-
"""
132+
133+
"""
133134

134135

135136
class DeleteEnvironmentMetadata(proto.Message):
136137
r"""Message included in the metadata field of operations returned from
137138
[DeleteEnvironment][google.cloud.shell.v1.CloudShellService.DeleteEnvironment].
138-
"""
139+
140+
"""
139141

140142

141143
class StartEnvironmentRequest(proto.Message):
@@ -195,13 +197,15 @@ class AuthorizeEnvironmentRequest(proto.Message):
195197
class AuthorizeEnvironmentResponse(proto.Message):
196198
r"""Response message for
197199
[AuthorizeEnvironment][google.cloud.shell.v1.CloudShellService.AuthorizeEnvironment].
198-
"""
200+
201+
"""
199202

200203

201204
class AuthorizeEnvironmentMetadata(proto.Message):
202205
r"""Message included in the metadata field of operations returned from
203206
[AuthorizeEnvironment][google.cloud.shell.v1.CloudShellService.AuthorizeEnvironment].
204-
"""
207+
208+
"""
205209

206210

207211
class StartEnvironmentMetadata(proto.Message):
@@ -280,7 +284,8 @@ class AddPublicKeyResponse(proto.Message):
280284
class AddPublicKeyMetadata(proto.Message):
281285
r"""Message included in the metadata field of operations returned from
282286
[AddPublicKey][google.cloud.shell.v1.CloudShellService.AddPublicKey].
283-
"""
287+
288+
"""
284289

285290

286291
class RemovePublicKeyRequest(proto.Message):
@@ -303,13 +308,15 @@ class RemovePublicKeyRequest(proto.Message):
303308
class RemovePublicKeyResponse(proto.Message):
304309
r"""Response message for
305310
[RemovePublicKey][google.cloud.shell.v1.CloudShellService.RemovePublicKey].
306-
"""
311+
312+
"""
307313

308314

309315
class RemovePublicKeyMetadata(proto.Message):
310316
r"""Message included in the metadata field of operations returned from
311317
[RemovePublicKey][google.cloud.shell.v1.CloudShellService.RemovePublicKey].
312-
"""
318+
319+
"""
313320

314321

315322
class CloudShellErrorDetails(proto.Message):

tests/unit/gapic/shell_v1/test_cloud_shell_service.py

+50
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from google.api_core import grpc_helpers_async
3333
from google.api_core import operation_async # type: ignore
3434
from google.api_core import operations_v1
35+
from google.api_core import path_template
3536
from google.auth import credentials as ga_credentials
3637
from google.auth.exceptions import MutualTLSChannelError
3738
from google.cloud.shell_v1.services.cloud_shell_service import (
@@ -1414,6 +1415,9 @@ def test_cloud_shell_service_base_transport():
14141415
with pytest.raises(NotImplementedError):
14151416
getattr(transport, method)(request=object())
14161417

1418+
with pytest.raises(NotImplementedError):
1419+
transport.close()
1420+
14171421
# Additionally, the LRO client (a property) should
14181422
# also raise NotImplementedError
14191423
with pytest.raises(NotImplementedError):
@@ -1927,3 +1931,49 @@ def test_client_withDEFAULT_CLIENT_INFO():
19271931
credentials=ga_credentials.AnonymousCredentials(), client_info=client_info,
19281932
)
19291933
prep.assert_called_once_with(client_info)
1934+
1935+
1936+
@pytest.mark.asyncio
1937+
async def test_transport_close_async():
1938+
client = CloudShellServiceAsyncClient(
1939+
credentials=ga_credentials.AnonymousCredentials(), transport="grpc_asyncio",
1940+
)
1941+
with mock.patch.object(
1942+
type(getattr(client.transport, "grpc_channel")), "close"
1943+
) as close:
1944+
async with client:
1945+
close.assert_not_called()
1946+
close.assert_called_once()
1947+
1948+
1949+
def test_transport_close():
1950+
transports = {
1951+
"grpc": "_grpc_channel",
1952+
}
1953+
1954+
for transport, close_name in transports.items():
1955+
client = CloudShellServiceClient(
1956+
credentials=ga_credentials.AnonymousCredentials(), transport=transport
1957+
)
1958+
with mock.patch.object(
1959+
type(getattr(client.transport, close_name)), "close"
1960+
) as close:
1961+
with client:
1962+
close.assert_not_called()
1963+
close.assert_called_once()
1964+
1965+
1966+
def test_client_ctx():
1967+
transports = [
1968+
"grpc",
1969+
]
1970+
for transport in transports:
1971+
client = CloudShellServiceClient(
1972+
credentials=ga_credentials.AnonymousCredentials(), transport=transport
1973+
)
1974+
# Test client calls underlying transport.
1975+
with mock.patch.object(type(client.transport), "close") as close:
1976+
close.assert_not_called()
1977+
with client:
1978+
pass
1979+
close.assert_called()

0 commit comments

Comments
 (0)