Skip to content

Commit c95ca68

Browse files
committed
addressing broken tests
1 parent ae5dc6a commit c95ca68

File tree

2 files changed

+35
-41
lines changed

2 files changed

+35
-41
lines changed

google/cloud/bigtable/data/_async/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,8 @@ async def _manage_channel(
361361
start_timestamp = time.monotonic()
362362
# prepare new channel for use
363363
old_channel = self.transport.grpc_channel
364-
new_channel = self.transport.grpc_channel._create_channel()
365-
await self._ping_and_warm_instances(new_channel)
364+
new_channel = self.transport.create_channel()
365+
await self._ping_and_warm_instances(channel=new_channel)
366366
# cycle channel out of use, with long grace window before closure
367367
self.transport._grpc_channel = new_channel
368368
await old_channel.close(grace_period)

tests/unit/data/_async/test_client.py

Lines changed: 33 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ async def test_ctor_super_inits(self):
106106
client_options = {"api_endpoint": "foo.bar:1234"}
107107
options_parsed = client_options_lib.from_dict(client_options)
108108
asyncio_portion = "-async" if CrossSync.is_async else ""
109-
transport_str = f"bt-{bigtable_version}-data{asyncio_portion}-{pool_size}"
110109
with mock.patch.object(
111110
CrossSync.GapicClient, "__init__"
112111
) as bigtable_client_init:
@@ -503,7 +502,7 @@ async def test__manage_channel_refresh(self, num_cycles):
503502
grpc_lib = grpc.aio if CrossSync.is_async else grpc
504503
new_channel = grpc_lib.insecure_channel("localhost:8080")
505504

506-
with mock.patch.object(asyncio, "sleep") as sleep:
505+
with mock.patch.object(CrossSync, "event_wait") as sleep:
507506
sleep.side_effect = [None for i in range(num_cycles)] + [
508507
asyncio.CancelledError
509508
]
@@ -602,7 +601,7 @@ async def test__register_instance_duplicate(self):
602601
instance_owners = {}
603602
client_mock._active_instances = active_instances
604603
client_mock._instance_owners = instance_owners
605-
client_mock._channel_refresh_tasks = [object()]
604+
client_mock._channel_refresh_task = object()
606605
mock_channels = [mock.Mock()]
607606
client_mock.transport.channels = mock_channels
608607
client_mock._ping_and_warm_instances = CrossSync.Mock()
@@ -659,12 +658,7 @@ async def test__register_instance_state(
659658
instance_owners = {}
660659
client_mock._active_instances = active_instances
661660
client_mock._instance_owners = instance_owners
662-
client_mock._channel_refresh_tasks = []
663-
client_mock._start_background_channel_refresh.side_effect = (
664-
lambda: client_mock._channel_refresh_tasks.append(mock.Mock)
665-
)
666-
mock_channels = [mock.Mock() for i in range(5)]
667-
client_mock.transport.channels = mock_channels
661+
client_mock._channel_refresh_task = None
668662
client_mock._ping_and_warm_instances = CrossSync.Mock()
669663
table_mock = mock.Mock()
670664
# register instances
@@ -951,7 +945,6 @@ async def test_close(self):
951945
async def test_close_with_timeout(self):
952946
expected_timeout = 19
953947
client = self._make_client(project="project-id", use_emulator=False)
954-
tasks = list(client._channel_refresh_tasks)
955948
with mock.patch.object(CrossSync, "wait", CrossSync.Mock()) as wait_for_mock:
956949
await client.close(timeout=expected_timeout)
957950
wait_for_mock.assert_called_once()
@@ -1275,36 +1268,37 @@ async def test_customizable_retryable_errors(
12751268
@CrossSync.pytest
12761269
@CrossSync.convert
12771270
async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn):
1278-
"""check that all requests attach proper metadata headers"""
1279-
profile = "profile" if include_app_profile else None
1280-
with mock.patch.object(
1281-
CrossSync.GapicClient, gapic_fn, CrossSync.Mock()
1282-
) as gapic_mock:
1283-
gapic_mock.side_effect = RuntimeError("stop early")
1284-
async with self._make_client() as client:
1285-
table = self._get_target_class()(
1286-
client, "instance-id", "table-id", profile
1287-
)
1288-
try:
1289-
test_fn = table.__getattribute__(fn_name)
1290-
maybe_stream = await test_fn(*fn_args)
1291-
[i async for i in maybe_stream]
1292-
except Exception:
1293-
# we expect an exception from attempting to call the mock
1294-
pass
1295-
kwargs = gapic_mock.call_args_list[0][1]
1296-
metadata = kwargs["metadata"]
1297-
goog_metadata = None
1298-
for key, value in metadata:
1299-
if key == "x-goog-request-params":
1300-
goog_metadata = value
1301-
assert goog_metadata is not None, "x-goog-request-params not found"
1302-
assert "table_name=" + table.table_name in goog_metadata
1303-
if include_app_profile:
1304-
assert "app_profile_id=profile" in goog_metadata
1305-
else:
1306-
assert "app_profile_id=" not in goog_metadata
1271+
from google.cloud.bigtable.data import TableAsync
13071272

1273+
profile = "profile" if include_app_profile else None
1274+
client = self._make_client()
1275+
# create mock for rpc stub
1276+
transport_mock = mock.MagicMock()
1277+
rpc_mock = mock.AsyncMock()
1278+
transport_mock._wrapped_methods.__getitem__.return_value = rpc_mock
1279+
client._gapic_client._client._transport = transport_mock
1280+
client._gapic_client._client._is_universe_domain_valid = True
1281+
table = self._get_target_class()(client, "instance-id", "table-id", profile)
1282+
try:
1283+
test_fn = table.__getattribute__(fn_name)
1284+
maybe_stream = await test_fn(*fn_args)
1285+
[i async for i in maybe_stream]
1286+
except Exception:
1287+
# we expect an exception from attempting to call the mock
1288+
pass
1289+
assert rpc_mock.call_count == 1
1290+
kwargs = rpc_mock.call_args_list[0].kwargs
1291+
metadata = kwargs["metadata"]
1292+
# expect single metadata entry
1293+
assert len(metadata) == 1
1294+
# expect x-goog-request-params tag
1295+
assert metadata[0][0] == "x-goog-request-params"
1296+
routing_str = metadata[0][1]
1297+
assert "table_name=" + table.table_name in routing_str
1298+
if include_app_profile:
1299+
assert "app_profile_id=profile" in routing_str
1300+
else:
1301+
assert "app_profile_id=" not in routing_str
13081302

13091303
@CrossSync.convert_class(
13101304
"TestReadRows",

0 commit comments

Comments
 (0)