Skip to content
2 changes: 1 addition & 1 deletion pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2825,7 +2825,7 @@ async def run(self) -> T:
if self._last_error is None:
self._last_error = exc

if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
if self._server is not None:
self._deprioritized_servers.append(self._server)

def _is_not_eligible_for_retry(self) -> bool:
Expand Down
37 changes: 20 additions & 17 deletions pymongo/asynchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ async def select_servers(
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
operation_id: Optional[int] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> list[Server]:
"""Return a list of Servers matching selector, or time out.

Expand Down Expand Up @@ -292,7 +293,12 @@ async def select_servers(

async with self._lock:
server_descriptions = await self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
selector,
server_timeout,
operation,
operation_id,
address,
deprioritized_servers=deprioritized_servers,
)

return [
Expand All @@ -306,6 +312,7 @@ async def _select_servers_loop(
operation: str,
operation_id: Optional[int],
address: Optional[_Address],
deprioritized_servers: Optional[list[Server]] = None,
) -> list[ServerDescription]:
"""select_servers() guts. Hold the lock when calling this."""
now = time.monotonic()
Expand All @@ -324,7 +331,12 @@ async def _select_servers_loop(
)

server_descriptions = self._description.apply_selector(
selector, address, custom_selector=self._settings.server_selector
selector,
address,
custom_selector=self._settings.server_selector,
deprioritized_servers=[server.description for server in deprioritized_servers]
if deprioritized_servers
else None,
)

while not server_descriptions:
Expand Down Expand Up @@ -385,9 +397,13 @@ async def _select_server(
operation_id: Optional[int] = None,
) -> Server:
servers = await self.select_servers(
selector, operation, server_selection_timeout, address, operation_id
selector,
operation,
server_selection_timeout,
address,
operation_id,
deprioritized_servers,
)
servers = _filter_servers(servers, deprioritized_servers)
if len(servers) == 1:
return servers[0]
server1, server2 = random.sample(servers, 2)
Expand Down Expand Up @@ -1112,16 +1128,3 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDe
if current_tv["processId"] != new_tv["processId"]:
return False
return current_tv["counter"] > new_tv["counter"]


def _filter_servers(
candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
) -> list[Server]:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
return candidates

filtered = [server for server in candidates if server not in deprioritized_servers]

# If not possible to pick a prioritized server, return the original list
return filtered or candidates
6 changes: 3 additions & 3 deletions pymongo/server_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@ class Selection:

@classmethod
def from_topology_description(cls, topology_description: TopologyDescription) -> Selection:
known_servers = topology_description.known_servers
candidate_servers = topology_description.candidate_servers
primary = None
for sd in known_servers:
for sd in candidate_servers:
if sd.server_type == SERVER_TYPE.RSPrimary:
primary = sd
break

return Selection(
topology_description,
topology_description.known_servers,
topology_description.candidate_servers,
topology_description.common_wire_version,
primary,
)
Expand Down
2 changes: 1 addition & 1 deletion pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2815,7 +2815,7 @@ def run(self) -> T:
if self._last_error is None:
self._last_error = exc

if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
if self._server is not None:
self._deprioritized_servers.append(self._server)

def _is_not_eligible_for_retry(self) -> bool:
Expand Down
37 changes: 20 additions & 17 deletions pymongo/synchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def select_servers(
server_selection_timeout: Optional[float] = None,
address: Optional[_Address] = None,
operation_id: Optional[int] = None,
deprioritized_servers: Optional[list[Server]] = None,
) -> list[Server]:
"""Return a list of Servers matching selector, or time out.

Expand Down Expand Up @@ -292,7 +293,12 @@ def select_servers(

with self._lock:
server_descriptions = self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
selector,
server_timeout,
operation,
operation_id,
address,
deprioritized_servers=deprioritized_servers,
)

return [
Expand All @@ -306,6 +312,7 @@ def _select_servers_loop(
operation: str,
operation_id: Optional[int],
address: Optional[_Address],
deprioritized_servers: Optional[list[Server]] = None,
) -> list[ServerDescription]:
"""select_servers() guts. Hold the lock when calling this."""
now = time.monotonic()
Expand All @@ -324,7 +331,12 @@ def _select_servers_loop(
)

server_descriptions = self._description.apply_selector(
selector, address, custom_selector=self._settings.server_selector
selector,
address,
custom_selector=self._settings.server_selector,
deprioritized_servers=[server.description for server in deprioritized_servers]
if deprioritized_servers
else None,
)

while not server_descriptions:
Expand Down Expand Up @@ -385,9 +397,13 @@ def _select_server(
operation_id: Optional[int] = None,
) -> Server:
servers = self.select_servers(
selector, operation, server_selection_timeout, address, operation_id
selector,
operation,
server_selection_timeout,
address,
operation_id,
deprioritized_servers,
)
servers = _filter_servers(servers, deprioritized_servers)
if len(servers) == 1:
return servers[0]
server1, server2 = random.sample(servers, 2)
Expand Down Expand Up @@ -1110,16 +1126,3 @@ def _is_stale_server_description(current_sd: ServerDescription, new_sd: ServerDe
if current_tv["processId"] != new_tv["processId"]:
return False
return current_tv["counter"] > new_tv["counter"]


def _filter_servers(
candidates: list[Server], deprioritized_servers: Optional[list[Server]] = None
) -> list[Server]:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
return candidates

filtered = [server for server in candidates if server not in deprioritized_servers]

# If not possible to pick a prioritized server, return the original list
return filtered or candidates
38 changes: 37 additions & 1 deletion pymongo/topology_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
self._server_descriptions = server_descriptions
self._max_set_version = max_set_version
self._max_election_id = max_election_id
self._candidate_servers = list(self._server_descriptions.values())

# The heartbeat_frequency is used in staleness estimates.
self._topology_settings = topology_settings
Expand Down Expand Up @@ -248,6 +249,11 @@ def readable_servers(self) -> list[ServerDescription]:
"""List of readable Servers."""
return [s for s in self._server_descriptions.values() if s.is_readable]

@property
def candidate_servers(self) -> list[ServerDescription]:
"""List of Servers excluding deprioritized servers."""
return self._candidate_servers

@property
def common_wire_version(self) -> Optional[int]:
"""Minimum of all servers' max wire versions, or None."""
Expand Down Expand Up @@ -283,11 +289,27 @@ def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerD
if (cast(float, s.round_trip_time) - fastest) <= threshold
]

def _filter_servers(
self, deprioritized_servers: Optional[list[ServerDescription]] = None
) -> None:
"""Filter out deprioritized servers from a list of server candidates."""
if not deprioritized_servers:
self._candidate_servers = self.known_servers
else:
deprioritized_addresses = {sd.address for sd in deprioritized_servers}
filtered = [
server
for server in self.known_servers
if server.address not in deprioritized_addresses
]
self._candidate_servers = filtered or self.known_servers

def apply_selector(
self,
selector: Any,
address: Optional[_Address] = None,
custom_selector: Optional[_ServerSelector] = None,
deprioritized_servers: Optional[list[ServerDescription]] = None,
) -> list[ServerDescription]:
"""List of servers matching the provided selector(s).

Expand Down Expand Up @@ -324,21 +346,35 @@ def apply_selector(
description = self.server_descriptions().get(address)
return [description] if description and description.is_server_type_known else []

self._filter_servers(deprioritized_servers)
# Primary selection fast path.
if self.topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary and type(selector) is Primary:
for sd in self._server_descriptions.values():
for sd in self._candidate_servers:
if sd.server_type == SERVER_TYPE.RSPrimary:
sds = [sd]
if custom_selector:
sds = custom_selector(sds)
return sds
# All primaries are deprioritized
if deprioritized_servers:
for sd in deprioritized_servers:
if sd.server_type == SERVER_TYPE.RSPrimary:
sds = [sd]
if custom_selector:
sds = custom_selector(sds)
return sds
# No primary found, return an empty list.
return []

selection = Selection.from_topology_description(self)
# Ignore read preference for sharded clusters.
if self.topology_type != TOPOLOGY_TYPE.Sharded:
selection = selector(selection)
# No suitable servers found, apply preference again but include deprioritized servers.
if not selection and deprioritized_servers:
self._filter_servers(None)
selection = Selection.from_topology_description(self)
selection = selector(selection)

# Apply custom selector followed by localThresholdMS.
if custom_selector is not None and selection:
Expand Down
39 changes: 32 additions & 7 deletions test/asynchronous/utils_selection_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from bson import json_util
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology
from pymongo.common import HEARTBEAT_FREQUENCY
from pymongo.common import HEARTBEAT_FREQUENCY, clean_node
from pymongo.errors import AutoReconnect, ConfigurationError
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
Expand Down Expand Up @@ -95,12 +95,21 @@ async def run_scenario(self):
# "Eligible servers" is defined in the server selection spec as
# the set of servers matching both the ReadPreference's mode
# and tag sets.
top_latency = await create_topology(scenario_def)
top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000)

# "In latency window" is defined in the server selection
# spec as the subset of suitable_servers that falls within the
# allowable latency window.
top_suitable = await create_topology(scenario_def, local_threshold_ms=1000000)
top_latency = await create_topology(scenario_def)

top_suitable_deprioritized_servers = [
top_suitable.get_server_by_address(clean_node(server["address"]))
for server in scenario_def.get("deprioritized_servers", [])
]
top_latency_deprioritized_servers = [
top_latency.get_server_by_address(clean_node(server["address"]))
for server in scenario_def.get("deprioritized_servers", [])
]

# Create server selector.
if scenario_def.get("operation") == "write":
Expand All @@ -120,21 +129,37 @@ async def run_scenario(self):
# Select servers.
if not scenario_def.get("suitable_servers"):
with self.assertRaises(AutoReconnect):
await top_suitable.select_server(pref, _Op.TEST, server_selection_timeout=0)
await top_suitable.select_server(
pref,
_Op.TEST,
server_selection_timeout=0,
deprioritized_servers=top_suitable_deprioritized_servers,
)

return

if not scenario_def["in_latency_window"]:
with self.assertRaises(AutoReconnect):
await top_latency.select_server(pref, _Op.TEST, server_selection_timeout=0)
await top_latency.select_server(
pref,
_Op.TEST,
server_selection_timeout=0,
deprioritized_servers=top_latency_deprioritized_servers,
)

return

actual_suitable_s = await top_suitable.select_servers(
pref, _Op.TEST, server_selection_timeout=0
pref,
_Op.TEST,
server_selection_timeout=0,
deprioritized_servers=top_suitable_deprioritized_servers,
)
actual_latency_s = await top_latency.select_servers(
pref, _Op.TEST, server_selection_timeout=0
pref,
_Op.TEST,
server_selection_timeout=0,
deprioritized_servers=top_latency_deprioritized_servers,
)

expected_suitable_servers = {}
Expand Down
Loading
Loading