From 40e538bac1b3ddadae5dc1b62a58f302c4035215 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Sun, 18 Jan 2026 21:23:39 -0800 Subject: [PATCH 01/20] v0 Signed-off-by: Kourosh Hakhamaneshi --- .../skyrl_train/inference_servers/__init__.py | 25 ++ .../skyrl_train/inference_servers/common.py | 39 ++ .../skyrl_train/inference_servers/router.py | 348 ++++++++++++++++ .../inference_servers/server_group.py | 178 +++++++++ .../inference_servers/server_pool.py | 55 +++ .../inference_servers/vllm_server_actor.py | 371 ++++++++++++++++++ .../tests/inference_servers/__init__.py | 1 + .../tests/inference_servers/test_common.py | 75 ++++ .../tests/inference_servers/test_router.py | 173 ++++++++ 9 files changed, 1265 insertions(+) create mode 100644 skyrl-train/skyrl_train/inference_servers/__init__.py create mode 100644 skyrl-train/skyrl_train/inference_servers/common.py create mode 100644 skyrl-train/skyrl_train/inference_servers/router.py create mode 100644 skyrl-train/skyrl_train/inference_servers/server_group.py create mode 100644 skyrl-train/skyrl_train/inference_servers/server_pool.py create mode 100644 skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py create mode 100644 skyrl-train/tests/inference_servers/__init__.py create mode 100644 skyrl-train/tests/inference_servers/test_common.py create mode 100644 skyrl-train/tests/inference_servers/test_router.py diff --git a/skyrl-train/skyrl_train/inference_servers/__init__.py b/skyrl-train/skyrl_train/inference_servers/__init__.py new file mode 100644 index 000000000..06ee5c84d --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/__init__.py @@ -0,0 +1,25 @@ +""" +SkyRL Inference Servers Module. + +This module provides HTTP-based inference server infrastructure: +- VLLMServerActor: Ray actor running vLLM OpenAI-compatible server +- ServerActorPool: Generic pool managing server actors +- VLLMServerGroup: vLLM-specific server group with placement group support +- InferenceRouter: HTTP proxy router with session-aware routing +""" + +from skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_free_port +from skyrl_train.inference_servers.server_pool import ServerActorPool +from skyrl_train.inference_servers.vllm_server_actor import VLLMServerActor +from skyrl_train.inference_servers.server_group import VLLMServerGroup +from skyrl_train.inference_servers.router import InferenceRouter + +__all__ = [ + "ServerInfo", + "get_node_ip", + "get_free_port", + "ServerActorPool", + "VLLMServerActor", + "VLLMServerGroup", + "InferenceRouter", +] diff --git a/skyrl-train/skyrl_train/inference_servers/common.py b/skyrl-train/skyrl_train/inference_servers/common.py new file mode 100644 index 000000000..0a34265ad --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/common.py @@ -0,0 +1,39 @@ +""" +Common utilities for inference servers. 
+""" + +import socket +from dataclasses import dataclass + +import ray + + +@dataclass +class ServerInfo: + """Information about a running inference server.""" + + ip: str + port: int + + @property + def url(self) -> str: + return f"http://{self.ip}:{self.port}" + + +def get_node_ip() -> str: + """Get the IP address of the current node.""" + return ray._private.services.get_node_ip_address().strip("[]") + + +def get_free_port(start_port: int = 8000) -> int: + """Find an available port starting from start_port.""" + port = start_port + while True: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + s.listen(1) + return port + except OSError: + port += 1 diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py new file mode 100644 index 000000000..8dcf5d81a --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -0,0 +1,348 @@ +""" +Inference Router - HTTP proxy with session-aware routing and control plane fan-out. +""" + +import asyncio +import hashlib +import itertools +import logging +from typing import List, Optional + +import httpx +import uvicorn +from fastapi import FastAPI, Request, Response +from fastapi.responses import StreamingResponse + +from skyrl_train.inference_servers.common import ServerInfo, get_node_ip + +logger = logging.getLogger(__name__) + + +# Routes that go to ONE backend (data plane) +DATA_PLANE_ROUTES = [ + "/v1/completions", + "/v1/chat/completions", + "/tokenize", + "/detokenize", + "/health", + "/models", + "/version", +] + +# Routes that go to ALL backends (control plane) +CONTROL_PLANE_ROUTES = [ + "/pause", + "/resume", + "/sleep", + "/wake_up", + "/wakeup", + "/reset_prefix_cache", + "/collective_rpc", + "/init_weight_update_communicator", + "/update_weights", + "/finalize_weight_update", + "/destroy_weights_update_group", +] + + +class InferenceRouter: + """ + HTTP proxy router for multiple vLLM backends. + + Routing behavior: + - Data plane (generation requests): Routes to ONE backend + - If X-Session-ID header present: consistent hash to same backend + - Otherwise: round-robin + - Control plane (sleep, pause, weight sync): Fans out to ALL backends + + Usage: + router = InferenceRouter(server_urls, host="0.0.0.0", port=8080) + router_url = router.start() + # ... use router_url for inference ... + router.shutdown() + """ + + def __init__( + self, + server_urls: List[str], + host: str = "0.0.0.0", + port: int = 8080, + ): + """ + Initialize the router. 
+ + Args: + server_urls: List of backend vLLM server URLs + host: Host to bind router to + port: Port to bind router to + """ + self._server_urls = server_urls + self._host = host + self._port = port + self._server_cycle = itertools.cycle(server_urls) + self._client: Optional[httpx.AsyncClient] = None + self._app: Optional[FastAPI] = None + self._server_task: Optional[asyncio.Task] = None + self._shutdown_event: Optional[asyncio.Event] = None + + logger.info(f"InferenceRouter: {len(server_urls)} backends, port={port}") + + def _hash_session_id(self, session_id: str) -> int: + """Hash session ID to get consistent backend index.""" + hash_bytes = hashlib.sha256(session_id.encode()).digest() + return int.from_bytes(hash_bytes[:8], "big") + + def _get_backend_for_session(self, session_id: str) -> str: + """Get consistent backend URL for a session ID.""" + idx = self._hash_session_id(session_id) % len(self._server_urls) + return self._server_urls[idx] + + def _get_backend_round_robin(self) -> str: + """Get next backend URL in round-robin order.""" + return next(self._server_cycle) + + def _get_backend_for_request(self, request: Request) -> str: + """ + Determine backend for a request. + + If X-Session-ID header is present, use consistent hashing. + Otherwise, use round-robin. + """ + session_id = request.headers.get("X-Session-ID") + if session_id: + return self._get_backend_for_session(session_id) + return self._get_backend_round_robin() + + def _is_control_plane_route(self, path: str) -> bool: + """Check if path is a control plane route (fan-out to all).""" + return any(path.startswith(route) for route in CONTROL_PLANE_ROUTES) + + def _build_app(self) -> FastAPI: + """Build the FastAPI app with proxy routes.""" + app = FastAPI( + title="SkyRL Inference Router", + docs_url=None, + redoc_url=None, + openapi_url=None, + ) + + @app.get("/servers") + async def list_servers(): + """Return list of backend URLs.""" + return {"servers": self._server_urls} + + @app.get("/get_server_info") + async def get_server_info(): + """Fetch server info from any backend (all should return same).""" + backend = self._server_urls[0] + try: + resp = await self._client.get(f"{backend}/get_server_info", timeout=10.0) + return resp.json() + except Exception as e: + return {"error": str(e)} + + # Catch-all: proxy everything else to backends + @app.api_route( + "/{path:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"], + ) + async def proxy(request: Request, path: str): + return await self._proxy_request(request, f"/{path}") + + return app + + async def _proxy_request(self, request: Request, path: str) -> Response: + """ + Proxy a request to backend(s). + + Control plane routes go to ALL backends. + Data plane routes go to ONE backend (session-aware or round-robin). 
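
        Illustrative sketch (router_url, the session id, and the model name are
        placeholders, not part of this patch):

            import httpx

            # Data plane: the same X-Session-ID always hashes to the same backend.
            httpx.post(
                f"{router_url}/v1/chat/completions",
                headers={"X-Session-ID": "conv-42"},
                json={"model": "my-model", "messages": [{"role": "user", "content": "hi"}]},
            )

            # Control plane: /reset_prefix_cache is fanned out to every backend.
            httpx.post(f"{router_url}/reset_prefix_cache", json={})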
+ """ + if self._is_control_plane_route(path): + return await self._proxy_to_all(request, path) + else: + return await self._proxy_to_one(request, path) + + async def _proxy_to_one(self, request: Request, path: str) -> Response: + """Proxy request to one backend (data plane).""" + backend_url = self._get_backend_for_request(request) + url = f"{backend_url}{path}" + method = request.method + + body = await request.body() + + # Check if streaming is requested + is_streaming = False + if body: + try: + import json + + data = json.loads(body) + is_streaming = data.get("stream", False) + except (json.JSONDecodeError, UnicodeDecodeError): + pass + + # Forward headers (filter out hop-by-hop headers) + headers = { + k: v + for k, v in request.headers.items() + if k.lower() not in ("host", "content-length", "transfer-encoding") + } + + if is_streaming: + return await self._proxy_streaming(url, method, headers, body) + else: + return await self._proxy_regular(url, method, headers, body) + + async def _proxy_regular( + self, url: str, method: str, headers: dict, body: bytes + ) -> Response: + """Proxy a regular (non-streaming) request.""" + response = await self._client.request( + method=method, + url=url, + headers=headers, + content=body, + ) + + return Response( + content=response.content, + status_code=response.status_code, + headers=dict(response.headers), + ) + + async def _proxy_streaming( + self, url: str, method: str, headers: dict, body: bytes + ) -> StreamingResponse: + """Proxy a streaming request.""" + + async def stream_generator(): + async with self._client.stream( + method=method, + url=url, + headers=headers, + content=body, + ) as response: + async for chunk in response.aiter_bytes(): + yield chunk + + return StreamingResponse( + stream_generator(), + media_type="text/event-stream", + ) + + async def _proxy_to_all(self, request: Request, path: str) -> Response: + """Proxy request to all backends (control plane), aggregate responses.""" + method = request.method + body = await request.body() + + # Forward headers + headers = { + k: v + for k, v in request.headers.items() + if k.lower() not in ("host", "content-length", "transfer-encoding") + } + + # Send to all backends concurrently + async def call_backend(backend_url: str): + url = f"{backend_url}{path}" + try: + response = await self._client.request( + method=method, + url=url, + headers=headers, + content=body, + timeout=300.0, # Long timeout for weight sync + ) + return { + "url": backend_url, + "status": response.status_code, + "body": response.json() if response.content else None, + } + except Exception as e: + return { + "url": backend_url, + "status": 500, + "error": str(e), + } + + results = await asyncio.gather( + *[call_backend(url) for url in self._server_urls] + ) + + # Check if all succeeded + all_ok = all(r.get("status") == 200 for r in results) + + if all_ok: + return Response( + content='{"status": "ok"}', + status_code=200, + media_type="application/json", + ) + else: + import json + + return Response( + content=json.dumps({"status": "partial_failure", "results": results}), + status_code=207, # Multi-Status + media_type="application/json", + ) + + def start(self) -> str: + """ + Start the router server in background. 
+ + Returns: + Router URL (e.g., "http://192.168.1.1:8080") + """ + if not self._server_urls: + raise ValueError("No backend servers available") + + # Create HTTP client for proxying + self._client = httpx.AsyncClient(timeout=httpx.Timeout(None)) + + # Build FastAPI app + self._app = self._build_app() + + # Create shutdown event + self._shutdown_event = asyncio.Event() + + # Start server in background thread (since we're not in async context) + import threading + + def run_server(): + asyncio.run(self._run_server()) + + self._server_thread = threading.Thread(target=run_server, daemon=True) + self._server_thread.start() + + # Wait a bit for server to start + import time + + time.sleep(1) + + ip = get_node_ip() + router_url = f"http://{ip}:{self._port}" + logger.info(f"Router started at {router_url}") + logger.info(f" GET /servers - list backend servers") + logger.info(f" GET /get_server_info - get parallelism info") + return router_url + + async def _run_server(self) -> None: + """Run the uvicorn server.""" + config = uvicorn.Config( + app=self._app, + host=self._host, + port=self._port, + log_level="warning", + access_log=False, + ) + server = uvicorn.Server(config) + await server.serve() + + def shutdown(self) -> None: + """Shutdown the router.""" + logger.info("Shutting down router...") + if self._shutdown_event: + self._shutdown_event.set() + # Note: Thread will exit when uvicorn server stops diff --git a/skyrl-train/skyrl_train/inference_servers/server_group.py b/skyrl-train/skyrl_train/inference_servers/server_group.py new file mode 100644 index 000000000..cdbfc3676 --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/server_group.py @@ -0,0 +1,178 @@ +""" +vLLM Server Group - manages vLLM server actors with placement groups. +""" + +import logging +from argparse import Namespace +from typing import Any, List, Optional + +import ray +from ray.util.placement_group import PlacementGroup, placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from skyrl_train.inference_servers.common import ServerInfo +from skyrl_train.inference_servers.server_pool import ServerActorPool +from skyrl_train.inference_servers.vllm_server_actor import VLLMServerActor + +logger = logging.getLogger(__name__) + + +class VLLMServerGroup: + """ + Creates and manages a group of vLLM server actors. + + This layer handles vLLM-specific actor creation with placement group support, + then delegates pool management to ServerActorPool. + + Supports: + - Basic mode: Creates its own placement group + - Colocation mode: Uses external placement group (shared with training) + - Data Parallel: Multiple DP-enabled servers + - PD Disaggregation: Prefill-decode disaggregation with NIXL + """ + + def __init__( + self, + engine_args: Namespace, + num_servers: int, + start_port: int = 8000, + placement_group: Optional[PlacementGroup] = None, + placement_group_bundle_offset: int = 0, + enable_dp: bool = False, + enable_pd: bool = False, + nixl_side_channel_base: int = 5600, + ): + """ + Initialize the vLLM server group. + + Args: + engine_args: vLLM engine configuration (Namespace from make_arg_parser). + Required attributes: tensor_parallel_size, pipeline_parallel_size, model. + num_servers: Number of vLLM server instances to create + start_port: Base port for server ports + placement_group: External placement group for colocation mode. + If None, creates internal placement group. 
+ placement_group_bundle_offset: Offset for bundle indices when using + external placement group (e.g., if training uses first N bundles) + enable_dp: Enable data parallelism across servers + enable_pd: Enable prefill-decode disaggregation + nixl_side_channel_base: Base port for NIXL side channels + """ + self._engine_args = engine_args + self._num_servers = num_servers + self._start_port = start_port + self._external_pg = placement_group + self._bundle_offset = placement_group_bundle_offset + self._enable_dp = enable_dp + self._enable_pd = enable_pd + self._nixl_side_channel_base = nixl_side_channel_base + self._pool: Optional[ServerActorPool] = None + self._internal_pg: Optional[PlacementGroup] = None + + # Query the actor class for GPU requirements (single source of truth) + self._num_gpus_per_server = VLLMServerActor.compute_num_gpus_per_server(engine_args) + + logger.info( + f"VLLMServerGroup: num_servers={num_servers}, " + f"gpus_per_server={self._num_gpus_per_server}, " + f"enable_dp={enable_dp}, enable_pd={enable_pd}, " + f"external_pg={'yes' if placement_group else 'no'}" + ) + + def _create_placement_group(self) -> PlacementGroup: + """Create internal placement group with bundles for all servers.""" + total_bundles = self._num_servers * self._num_gpus_per_server + logger.info(f"Creating placement group with {total_bundles} bundles...") + pg = placement_group([{"CPU": 1, "GPU": 1} for _ in range(total_bundles)]) + ray.get(pg.ready()) + logger.info("Placement group ready") + return pg + + def _get_placement_group(self) -> PlacementGroup: + """Get the placement group (external or internal).""" + if self._external_pg is not None: + return self._external_pg + if self._internal_pg is None: + self._internal_pg = self._create_placement_group() + return self._internal_pg + + def _create_actor_class(self, pg: PlacementGroup, start_bundle_idx: int) -> Any: + """Create actor class with scheduling constraints for a specific bundle.""" + return ray.remote(VLLMServerActor).options( + num_gpus=0, # GPU allocation managed by placement group + num_cpus=1, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=start_bundle_idx, + ), + ) + + def _create_actors(self) -> List[Any]: + """Create vLLM server actors with GPU resources.""" + pg = self._get_placement_group() + + actors = [] + dp_address, dp_rpc_port = None, None + + for server_idx in range(self._num_servers): + # Calculate bundle index accounting for offset (colocation mode) + start_bundle_idx = ( + self._bundle_offset + server_idx * self._num_gpus_per_server + ) + + ServerActorClass = self._create_actor_class(pg, start_bundle_idx) + + actor = ServerActorClass.remote( + self._engine_args, + self._start_port + server_idx, + server_idx=server_idx, + dp_size=self._num_servers if self._enable_dp else -1, + dp_master_address=dp_address, + dp_rpc_port=dp_rpc_port, + enable_pd=self._enable_pd, + nixl_side_channel_base=self._nixl_side_channel_base, + ) + + # Get DP info from server 0 which is where DP0 will be + if self._enable_dp and server_idx == 0: + dp_address, dp_rpc_port = ray.get(actor.get_dp_info.remote()) + logger.info(f"DP0 info: address={dp_address}, rpc_port={dp_rpc_port}") + + actors.append(actor) + + return actors + + def start(self) -> List[ServerInfo]: + """Create actors, start the pool, and return endpoints.""" + logger.info(f"Starting {self._num_servers} vLLM server(s)...") + actors = self._create_actors() + self._pool = 
ServerActorPool(actors) + server_infos = self._pool.start() + + for i, info in enumerate(server_infos): + logger.info(f"Server {i}: {info.url}") + + return server_infos + + def get_pool(self) -> Optional[ServerActorPool]: + """Get the underlying actor pool.""" + return self._pool + + def get_server_infos(self) -> List[ServerInfo]: + """Get the list of server endpoints.""" + return self._pool.get_server_infos() if self._pool else [] + + def get_server_urls(self) -> List[str]: + """Get the list of server URLs.""" + return self._pool.get_server_urls() if self._pool else [] + + def get_actors(self) -> List[Any]: + """Get the list of actor handles.""" + return self._pool.get_actors() if self._pool else [] + + def shutdown(self) -> None: + """Shutdown all servers.""" + if self._pool: + logger.info("Shutting down vLLM servers...") + self._pool.shutdown() diff --git a/skyrl-train/skyrl_train/inference_servers/server_pool.py b/skyrl-train/skyrl_train/inference_servers/server_pool.py new file mode 100644 index 000000000..ffbd72ea5 --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/server_pool.py @@ -0,0 +1,55 @@ +""" +Generic server actor pool. +""" + +from typing import Any, List + +import ray + +from skyrl_train.inference_servers.common import ServerInfo + + +class ServerActorPool: + """ + Generic pool that manages a list of server actors. + + Actors must implement: + - start() -> ServerInfo + - shutdown() -> None + + This layer is agnostic to the type of server (vLLM, SGLang, etc). + """ + + def __init__(self, actors: List[Any]): + """ + Initialize the pool with pre-constructed actor handles. + + Args: + actors: List of Ray actor handles + """ + self._actors = actors + self._server_infos: List[ServerInfo] = [] + + def start(self) -> List[ServerInfo]: + """Start all actors and collect their server infos.""" + # Start all actors in parallel, wait for all to be ready + start_refs = [actor.start.remote() for actor in self._actors] + self._server_infos = ray.get(start_refs) + return self._server_infos + + def get_server_infos(self) -> List[ServerInfo]: + """Get the list of server endpoints.""" + return self._server_infos + + def get_server_urls(self) -> List[str]: + """Get the list of server URLs.""" + return [info.url for info in self._server_infos] + + def get_actors(self) -> List[Any]: + """Get the list of actor handles.""" + return self._actors + + def shutdown(self) -> None: + """Shutdown all actors.""" + shutdown_refs = [actor.shutdown.remote() for actor in self._actors] + ray.get(shutdown_refs) diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py new file mode 100644 index 000000000..5538e1ca3 --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py @@ -0,0 +1,371 @@ +""" +vLLM Server Actor - Ray actor running a vLLM OpenAI-compatible API server. +""" + +import asyncio +import logging +import os +import pickle +import time +from argparse import Namespace +from typing import Any, Dict, Optional + +import httpx +import uvicorn +from fastapi import Request + +from skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_free_port + +logger = logging.getLogger(__name__) + + +class VLLMServerActor: + """ + Ray actor that runs a vLLM OpenAI-compatible API server. + + The server runs in the actor and exposes an HTTP endpoint that can be + called from anywhere (other actors, driver, external processes). 
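
    For example, once start() has returned, any process that can reach the node can
    talk to the server over plain HTTP (illustrative sketch; info stands for the
    ServerInfo returned by start()):

        import httpx

        extended = httpx.get(f"{info.url}/get_server_info", timeout=10.0).json()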
+ + Custom endpoints added for SkyRL: + - /init_weight_update_communicator: Initialize weight sync process group + - /update_weights: Update model weights via NCCL broadcast + - /finalize_weight_update: Post-processing after weight sync + - /destroy_weights_update_group: Teardown weight sync + - /sleep: Offload weights to CPU + - /wake_up: Load weights back to GPU + - /reset_prefix_cache: Clear KV cache + - /get_server_info: Return parallelism info + """ + + @staticmethod + def compute_num_gpus_per_server(engine_args: Namespace) -> int: + """Compute the number of GPUs needed per server based on TP * PP.""" + return engine_args.tensor_parallel_size * engine_args.pipeline_parallel_size + + def __init__( + self, + engine_args: Namespace, + start_port: int = 8000, + server_idx: int = 0, + dp_size: int = -1, + dp_master_address: Optional[str] = None, + dp_rpc_port: Optional[int] = None, + # PD disaggregation settings + enable_pd: bool = False, + nixl_side_channel_base: int = 5600, + ): + """ + Initialize the vLLM server actor. + + Args: + engine_args: vLLM engine configuration (Namespace from make_arg_parser). + Required attributes: tensor_parallel_size, pipeline_parallel_size, model. + Optional: uvicorn_log_level, ssl_*, disable_uvicorn_access_log, kv_transfer_config. + start_port: Base port to start searching for free port + server_idx: Index of this server in the group + dp_size: Data parallel size (-1 to disable) + dp_master_address: DP master address (for non-rank-0 servers) + dp_rpc_port: DP RPC port (for non-rank-0 servers) + enable_pd: Enable prefill-decode disaggregation + nixl_side_channel_base: Base port for NIXL side channel + """ + self._engine_args = engine_args + self._ip = get_node_ip() + self._port = get_free_port(start_port) + self._server_idx = server_idx + self._num_gpus_per_server = self.compute_num_gpus_per_server(engine_args) + + # Update args with our assigned host/port + self._engine_args.host = "0.0.0.0" + self._engine_args.port = self._port + + # PD disaggregation: setup NIXL side channel for KV transfer + if enable_pd: + self._setup_nixl_side_channel(nixl_side_channel_base) + + # Each engine needs to know its dp_rank and dp_size so DP process groups are formed + if dp_size > 0: + self._engine_args.data_parallel_size = dp_size + self._engine_args.data_parallel_rank = server_idx + # All DP ranks need to know the master's address and RPC port for handshake + if server_idx == 0: + dp_master_address, dp_rpc_port = self.get_dp_info() + + if dp_master_address is None or dp_rpc_port is None: + raise ValueError("DP address and RPC port must be set for non-server 0") + + self._engine_args.data_parallel_address = dp_master_address + self._engine_args.data_parallel_rpc_port = dp_rpc_port + logger.info( + f"Server {server_idx}: DP enabled - dp_size={dp_size}, dp_rank={server_idx}, " + f"dp_master_address={dp_master_address}, dp_rpc_port={dp_rpc_port}" + ) + + # Compute bundle indices for this server's TP/PP workers + # Each server uses a contiguous slice of bundles in the placement group + start_bundle = server_idx * self._num_gpus_per_server + bundle_indices = list(range(start_bundle, start_bundle + self._num_gpus_per_server)) + os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) + logger.info(f"Server {server_idx}: using bundle indices {bundle_indices}") + + self._engine = None + self._server_task = None + + def _setup_nixl_side_channel(self, base_port: int) -> None: + """ + Setup NIXL side channel for PD disaggregation. 
+ + Each server instance needs a unique side channel port for KV transfer handshake. + """ + import json + + side_channel_port = base_port + self._server_idx + os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(side_channel_port) + os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = self._ip + + engine_id = f"server-{self._server_idx}-{self._ip}-{side_channel_port}" + + if hasattr(self._engine_args, "kv_transfer_config") and self._engine_args.kv_transfer_config: + try: + kv_config = json.loads(self._engine_args.kv_transfer_config) + kv_config["engine_id"] = engine_id + self._engine_args.kv_transfer_config = json.dumps(kv_config) + except (json.JSONDecodeError, TypeError): + pass + + logger.info( + f"Server {self._server_idx}: NIXL side channel configured - " + f"host={self._ip}, port={side_channel_port}, engine_id={engine_id}" + ) + + def get_server_info(self) -> ServerInfo: + """Get the server's IP and port info.""" + return ServerInfo(ip=self._ip, port=self._port) + + def get_extended_server_info(self) -> Dict[str, Any]: + """Return extended server info including parallelism settings.""" + return { + "ip": self._ip, + "port": self._port, + "url": f"http://{self._ip}:{self._port}", + "server_idx": self._server_idx, + "tp_size": self._engine_args.tensor_parallel_size, + "pp_size": self._engine_args.pipeline_parallel_size, + "dp_size": getattr(self._engine_args, "data_parallel_size", 1), + "world_size": self._num_gpus_per_server, + } + + def get_dp_info(self) -> tuple: + """Get the DP master address and RPC port (for server 0 to share with others).""" + dp_rpc_port = self._port + 500 + return (self._ip, dp_rpc_port) + + async def start(self) -> ServerInfo: + """Start the vLLM server. Blocks until server is healthy.""" + from vllm.utils.system_utils import set_ulimit + + set_ulimit() + logger.info(f"Starting server on {self._ip}:{self._port}...") + + # Start HTTP server as background asyncio task + self._server_task = asyncio.create_task(self._run_server()) + + # Wait until the server is actually healthy + await self._wait_until_healthy() + + return self.get_server_info() + + async def _wait_until_healthy(self, timeout: float = 600, interval: float = 1.0) -> None: + """Poll the /health endpoint until it responds OK.""" + url = f"http://{self._ip}:{self._port}/health" + start_time = time.time() + + async with httpx.AsyncClient() as client: + while True: + # Check if server task failed + if self._server_task.done(): + exc = self._server_task.exception() + if exc: + raise exc + raise RuntimeError("Server task exited unexpectedly") + + try: + resp = await client.get(url, timeout=5.0) + if resp.status_code == 200: + logger.info(f"Server {self._ip}:{self._port} is healthy") + return + except httpx.RequestError: + pass + + if time.time() - start_time > timeout: + raise TimeoutError(f"Server failed to become healthy within {timeout}s") + + await asyncio.sleep(interval) + + async def _run_server(self) -> None: + """Internal method to run the HTTP server.""" + from vllm import AsyncLLMEngine + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.entrypoints.openai.api_server import build_app, create_server_socket, init_app_state + from vllm.usage.usage_lib import UsageContext + import vllm.envs as envs + + sock_addr = (self._engine_args.host, self._engine_args.port) + sock = create_server_socket(sock_addr) + app = build_app(self._engine_args) + + # Initialize the engine (this loads the model - takes time) + engine_args = AsyncEngineArgs.from_cli_args(self._engine_args) + self._engine = 
AsyncLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.OPENAI_API_SERVER, + ) + logger.info(f"Engine initialized on {self._ip}:{self._port}, adding custom endpoints...") + + # Add custom SkyRL endpoints + self._add_custom_endpoints(app) + + await init_app_state(self._engine, app.state, self._engine_args) + + # Use uvicorn directly (serve_http tries to add signal handlers which fails in Ray actors) + config = uvicorn.Config( + app, + host=self._engine_args.host, + port=self._engine_args.port, + log_level=self._engine_args.uvicorn_log_level, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=self._engine_args.ssl_keyfile, + ssl_certfile=self._engine_args.ssl_certfile, + ssl_ca_certs=self._engine_args.ssl_ca_certs, + ssl_cert_reqs=self._engine_args.ssl_cert_reqs, + access_log=not getattr(self._engine_args, "disable_uvicorn_access_log", False), + ) + server = uvicorn.Server(config) + await server.serve(sockets=[sock]) + + def _add_custom_endpoints(self, app) -> None: + """Add custom SkyRL endpoints to the FastAPI app.""" + engine = self._engine + + @app.get("/get_server_info") + async def _get_server_info(): + """Return server parallelism info.""" + return self.get_extended_server_info() + + @app.post("/init_weight_update_communicator") + async def _init_weight_update_communicator(request: Request): + """Initialize weight sync process group.""" + from skyrl_train.weight_sync import BroadcastInitInfo + + data = await request.json() + init_info = BroadcastInitInfo(**data) + pickled_init_info = pickle.dumps(init_info) + + await engine.collective_rpc( + "init_weight_update_communicator", + args=(pickled_init_info,), + ) + return {"status": "ok"} + + @app.post("/update_weights") + async def _update_weights(request: Request): + """Update model weights via NCCL broadcast.""" + from skyrl_train.weight_sync import BroadcastWeightUpdateRequest + + data = await request.json() + weight_request = BroadcastWeightUpdateRequest(**data) + pickled_request = pickle.dumps(weight_request) + + await engine.collective_rpc( + "load_weights", + args=(pickled_request,), + ) + return {"status": "ok"} + + @app.post("/finalize_weight_update") + async def _finalize_weight_update(request: Request): + """ + Finalize weight update - post-processing hook. + + Currently a no-op, reserved for future use (cache invalidation, etc). 
+ """ + # No-op for now - placeholder for future post-processing + return {"status": "ok"} + + @app.post("/destroy_weights_update_group") + async def _destroy_weights_update_group(request: Request): + """Teardown weight sync process group.""" + await engine.collective_rpc( + "teardown_weight_receiver", + args=(), + ) + return {"status": "ok"} + + @app.post("/sleep") + async def _sleep(request: Request): + """Offload weights to CPU.""" + data = await request.json() + level = data.get("level", 1) + + # Reset prefix cache before sleep to avoid gibberish on wake + # See: https://github.com/vllm-project/vllm/issues/17103 + await engine.reset_prefix_cache() + await engine.sleep(level) + return {"status": "ok"} + + @app.post("/wake_up") + async def _wake_up(request: Request): + """Load weights back to GPU.""" + data = await request.json() + tags = data.get("tags") + await engine.wake_up(tags) + return {"status": "ok"} + + @app.post("/wakeup") + async def _wakeup(request: Request): + """Alias for /wake_up.""" + data = await request.json() + tags = data.get("tags") + await engine.wake_up(tags) + return {"status": "ok"} + + @app.post("/reset_prefix_cache") + async def _reset_prefix_cache(request: Request): + """Clear KV cache.""" + data = await request.json() + reset_running = data.get("reset_running_requests", False) + if reset_running: + # If reset_running_requests is True, we need to abort first + await engine.abort_all_requests() + await engine.reset_prefix_cache() + return {"status": "ok"} + + @app.post("/pause") + async def _pause(request: Request): + """Pause generation.""" + data = await request.json() + wait_for_inflight = data.get("wait_for_inflight_request", False) + # vLLM's pause API - implementation depends on vLLM version + if hasattr(engine, "pause"): + await engine.pause(wait_for_inflight_request=wait_for_inflight) + else: + # Fallback: abort all if pause not available + if not wait_for_inflight: + await engine.abort_all_requests() + return {"status": "ok"} + + @app.post("/resume") + async def _resume(request: Request): + """Resume generation.""" + if hasattr(engine, "resume"): + await engine.resume() + return {"status": "ok"} + + async def shutdown(self) -> None: + """Gracefully shutdown the server.""" + if self._server_task: + self._server_task.cancel() + try: + await self._server_task + except asyncio.CancelledError: + pass diff --git a/skyrl-train/tests/inference_servers/__init__.py b/skyrl-train/tests/inference_servers/__init__.py new file mode 100644 index 000000000..55f4e2b47 --- /dev/null +++ b/skyrl-train/tests/inference_servers/__init__.py @@ -0,0 +1 @@ +# Tests for inference_servers module diff --git a/skyrl-train/tests/inference_servers/test_common.py b/skyrl-train/tests/inference_servers/test_common.py new file mode 100644 index 000000000..6a3b8acf9 --- /dev/null +++ b/skyrl-train/tests/inference_servers/test_common.py @@ -0,0 +1,75 @@ +"""Tests for inference_servers.common module.""" + +import pytest +import socket + +from skyrl_train.inference_servers.common import ServerInfo, get_free_port + + +class TestServerInfo: + """Tests for ServerInfo dataclass.""" + + def test_server_info_url(self): + """Test URL property.""" + info = ServerInfo(ip="192.168.1.1", port=8000) + assert info.url == "http://192.168.1.1:8000" + + def test_server_info_url_localhost(self): + """Test URL with localhost.""" + info = ServerInfo(ip="127.0.0.1", port=30000) + assert info.url == "http://127.0.0.1:30000" + + def test_server_info_fields(self): + """Test dataclass fields.""" + info = 
ServerInfo(ip="10.0.0.1", port=9000) + assert info.ip == "10.0.0.1" + assert info.port == 9000 + + +class TestGetFreePort: + """Tests for get_free_port function.""" + + def test_get_free_port_returns_available(self): + """Test that get_free_port returns an available port.""" + port = get_free_port(start_port=50000) + assert port >= 50000 + + # Verify the port is actually free by binding to it + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + s.listen(1) + + def test_get_free_port_skips_occupied(self): + """Test that get_free_port skips occupied ports.""" + # Occupy a port + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", 51000)) + s.listen(1) + + # Should return a different port + port = get_free_port(start_port=51000) + assert port >= 51000 + + def test_get_free_port_sequential_calls(self): + """Test that sequential calls return different ports when ports are occupied.""" + ports = [] + sockets = [] + + try: + # Get multiple ports and keep them occupied + for i in range(3): + port = get_free_port(start_port=52000) + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + s.listen(1) + ports.append(port) + sockets.append(s) + + # All ports should be unique + assert len(set(ports)) == 3 + finally: + for s in sockets: + s.close() diff --git a/skyrl-train/tests/inference_servers/test_router.py b/skyrl-train/tests/inference_servers/test_router.py new file mode 100644 index 000000000..eb9a809e8 --- /dev/null +++ b/skyrl-train/tests/inference_servers/test_router.py @@ -0,0 +1,173 @@ +"""Tests for inference_servers.router module.""" + +import pytest +from unittest.mock import MagicMock, AsyncMock + +from skyrl_train.inference_servers.router import ( + InferenceRouter, + DATA_PLANE_ROUTES, + CONTROL_PLANE_ROUTES, +) + + +class TestRouterRoutingLogic: + """Tests for router routing logic (no actual HTTP calls).""" + + @pytest.fixture + def router(self): + """Create a router with mock backends.""" + server_urls = [ + "http://backend1:8000", + "http://backend2:8000", + "http://backend3:8000", + ] + return InferenceRouter(server_urls, host="0.0.0.0", port=9999) + + def test_session_hash_consistency(self, router): + """Test that same session ID always maps to same backend.""" + session_id = "user-123-session-456" + + # Multiple calls should return the same backend + backend1 = router._get_backend_for_session(session_id) + backend2 = router._get_backend_for_session(session_id) + backend3 = router._get_backend_for_session(session_id) + + assert backend1 == backend2 == backend3 + + def test_different_sessions_distribute(self, router): + """Test that different session IDs distribute across backends.""" + # With enough session IDs, we should hit multiple backends + backends = set() + for i in range(100): + session_id = f"session-{i}" + backend = router._get_backend_for_session(session_id) + backends.add(backend) + + # Should hit multiple backends (not all requests to one) + assert len(backends) >= 2 + + def test_round_robin_cycles(self, router): + """Test that round-robin cycles through all backends.""" + backends = [] + for _ in range(6): # 2 full cycles + backend = router._get_backend_round_robin() + backends.append(backend) + + # First 3 should be unique + assert len(set(backends[:3])) == 3 + + # Should repeat the pattern + assert backends[0] == 
backends[3] + assert backends[1] == backends[4] + assert backends[2] == backends[5] + + def test_control_plane_route_detection(self, router): + """Test control plane route detection.""" + # Control plane routes + assert router._is_control_plane_route("/pause") is True + assert router._is_control_plane_route("/resume") is True + assert router._is_control_plane_route("/sleep") is True + assert router._is_control_plane_route("/wake_up") is True + assert router._is_control_plane_route("/wakeup") is True + assert router._is_control_plane_route("/reset_prefix_cache") is True + assert router._is_control_plane_route("/init_weight_update_communicator") is True + assert router._is_control_plane_route("/update_weights") is True + assert router._is_control_plane_route("/finalize_weight_update") is True + + # Data plane routes should NOT be control plane + assert router._is_control_plane_route("/v1/completions") is False + assert router._is_control_plane_route("/v1/chat/completions") is False + assert router._is_control_plane_route("/health") is False + assert router._is_control_plane_route("/models") is False + assert router._is_control_plane_route("/tokenize") is False + + def test_data_plane_routes_list(self): + """Test that data plane routes list is correct.""" + expected = [ + "/v1/completions", + "/v1/chat/completions", + "/tokenize", + "/detokenize", + "/health", + "/models", + "/version", + ] + assert DATA_PLANE_ROUTES == expected + + def test_control_plane_routes_list(self): + """Test that control plane routes list is correct.""" + expected = [ + "/pause", + "/resume", + "/sleep", + "/wake_up", + "/wakeup", + "/reset_prefix_cache", + "/collective_rpc", + "/init_weight_update_communicator", + "/update_weights", + "/finalize_weight_update", + "/destroy_weights_update_group", + ] + assert CONTROL_PLANE_ROUTES == expected + + +class TestRouterRequestRouting: + """Tests for request routing based on headers.""" + + @pytest.fixture + def router(self): + """Create a router with mock backends.""" + server_urls = [ + "http://backend1:8000", + "http://backend2:8000", + ] + return InferenceRouter(server_urls, host="0.0.0.0", port=9999) + + def test_request_with_session_id_header(self, router): + """Test that X-Session-ID header triggers session-aware routing.""" + # Create mock request with session header + request = MagicMock() + request.headers = {"X-Session-ID": "test-session-123"} + + backend1 = router._get_backend_for_request(request) + backend2 = router._get_backend_for_request(request) + + # Same session should get same backend + assert backend1 == backend2 + + def test_request_without_session_id_header(self, router): + """Test that missing X-Session-ID header triggers round-robin.""" + # Create mock request without session header + request = MagicMock() + request.headers = {} + + backends = [] + for _ in range(4): + backend = router._get_backend_for_request(request) + backends.append(backend) + + # Should alternate between backends (round-robin) + assert backends[0] == backends[2] + assert backends[1] == backends[3] + assert backends[0] != backends[1] + + +class TestRouterInitialization: + """Tests for router initialization.""" + + def test_router_init_with_servers(self): + """Test router initialization with server list.""" + urls = ["http://a:8000", "http://b:8000"] + router = InferenceRouter(urls, host="127.0.0.1", port=8080) + + assert router._server_urls == urls + assert router._host == "127.0.0.1" + assert router._port == 8080 + + def test_router_start_fails_without_servers(self): + """Test 
that start fails with empty server list.""" + router = InferenceRouter([], host="0.0.0.0", port=8080) + + with pytest.raises(ValueError, match="No backend servers"): + router.start() From a52b0dc7c3acf8f9e23d8544132675680a3154ee Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 12:23:11 -0800 Subject: [PATCH 02/20] common Signed-off-by: Kourosh Hakhamaneshi --- .../skyrl_train/inference_servers/__init__.py | 25 ------ .../skyrl_train/inference_servers/common.py | 58 ++++++++++--- .../tests/inference_servers/test_common.py | 85 ++++++------------- 3 files changed, 71 insertions(+), 97 deletions(-) diff --git a/skyrl-train/skyrl_train/inference_servers/__init__.py b/skyrl-train/skyrl_train/inference_servers/__init__.py index 06ee5c84d..e69de29bb 100644 --- a/skyrl-train/skyrl_train/inference_servers/__init__.py +++ b/skyrl-train/skyrl_train/inference_servers/__init__.py @@ -1,25 +0,0 @@ -""" -SkyRL Inference Servers Module. - -This module provides HTTP-based inference server infrastructure: -- VLLMServerActor: Ray actor running vLLM OpenAI-compatible server -- ServerActorPool: Generic pool managing server actors -- VLLMServerGroup: vLLM-specific server group with placement group support -- InferenceRouter: HTTP proxy router with session-aware routing -""" - -from skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_free_port -from skyrl_train.inference_servers.server_pool import ServerActorPool -from skyrl_train.inference_servers.vllm_server_actor import VLLMServerActor -from skyrl_train.inference_servers.server_group import VLLMServerGroup -from skyrl_train.inference_servers.router import InferenceRouter - -__all__ = [ - "ServerInfo", - "get_node_ip", - "get_free_port", - "ServerActorPool", - "VLLMServerActor", - "VLLMServerGroup", - "InferenceRouter", -] diff --git a/skyrl-train/skyrl_train/inference_servers/common.py b/skyrl-train/skyrl_train/inference_servers/common.py index 0a34265ad..cc1205911 100644 --- a/skyrl-train/skyrl_train/inference_servers/common.py +++ b/skyrl-train/skyrl_train/inference_servers/common.py @@ -1,12 +1,17 @@ """ Common utilities for inference servers. + +Uses Ray's public network utilities for consistency with Ray's cluster management. """ +import logging import socket from dataclasses import dataclass import ray +logger = logging.getLogger(__name__) + @dataclass class ServerInfo: @@ -21,19 +26,50 @@ def url(self) -> str: def get_node_ip() -> str: - """Get the IP address of the current node.""" - return ray._private.services.get_node_ip_address().strip("[]") + """ + Get the IP address of the current node. + Returns the node IP from Ray's global worker if Ray is initialized + """ + return ray.util.get_node_ip_address() -def get_free_port(start_port: int = 8000) -> int: - """Find an available port starting from start_port.""" - port = start_port - while True: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + + +def get_open_port(start_port: int | None = None) -> int: + """ + Get an available port. + + Args: + start_port: If provided, search for an available port starting from this value. + If None, let the OS assign a free port. + + Returns: + An available port number. 
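
    Example (illustrative):
        get_open_port()        # let the OS pick any free port
        get_open_port(8000)    # first free port at or above 8000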
+ """ + if start_port is not None: + # Search for available port starting from start_port + port = start_port + while True: try: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(("", port)) - s.listen(1) - return port + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + return port except OSError: port += 1 + if port > 65535: + raise RuntimeError(f"No available port found starting from {start_port}") + + # Let OS assign a free port + # Try IPv4 first + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + pass + + # Try IPv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] \ No newline at end of file diff --git a/skyrl-train/tests/inference_servers/test_common.py b/skyrl-train/tests/inference_servers/test_common.py index 6a3b8acf9..e41fac370 100644 --- a/skyrl-train/tests/inference_servers/test_common.py +++ b/skyrl-train/tests/inference_servers/test_common.py @@ -1,75 +1,38 @@ """Tests for inference_servers.common module.""" -import pytest import socket -from skyrl_train.inference_servers.common import ServerInfo, get_free_port +from skyrl_train.inference_servers.common import ( + get_node_ip, + get_open_port, +) -class TestServerInfo: - """Tests for ServerInfo dataclass.""" +class TestGetIp: + """Tests for get_ip function.""" - def test_server_info_url(self): - """Test URL property.""" - info = ServerInfo(ip="192.168.1.1", port=8000) - assert info.url == "http://192.168.1.1:8000" + def test_get_ip_returns_string(self): + """Test that get_ip returns a string.""" + ip = get_node_ip() + assert isinstance(ip, str) + assert len(ip) > 0 + assert ip != "" + assert "." 
in ip or ":" in ip - def test_server_info_url_localhost(self): - """Test URL with localhost.""" - info = ServerInfo(ip="127.0.0.1", port=30000) - assert info.url == "http://127.0.0.1:30000" - def test_server_info_fields(self): - """Test dataclass fields.""" - info = ServerInfo(ip="10.0.0.1", port=9000) - assert info.ip == "10.0.0.1" - assert info.port == 9000 +class TestGetOpenPort: + """Tests for get_open_port function.""" - -class TestGetFreePort: - """Tests for get_free_port function.""" - - def test_get_free_port_returns_available(self): - """Test that get_free_port returns an available port.""" - port = get_free_port(start_port=50000) - assert port >= 50000 - - # Verify the port is actually free by binding to it + def test_get_open_port_os_assigned(self): + """Test that get_open_port returns an available port (OS assigned).""" + port = get_open_port() + assert isinstance(port, int) + assert 1 <= port <= 65535 + self._verify_port_is_free(port) + + + def _verify_port_is_free(self, port: int): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind(("", port)) s.listen(1) - - def test_get_free_port_skips_occupied(self): - """Test that get_free_port skips occupied ports.""" - # Occupy a port - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(("", 51000)) - s.listen(1) - - # Should return a different port - port = get_free_port(start_port=51000) - assert port >= 51000 - - def test_get_free_port_sequential_calls(self): - """Test that sequential calls return different ports when ports are occupied.""" - ports = [] - sockets = [] - - try: - # Get multiple ports and keep them occupied - for i in range(3): - port = get_free_port(start_port=52000) - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.bind(("", port)) - s.listen(1) - ports.append(port) - sockets.append(s) - - # All ports should be unique - assert len(set(ports)) == 3 - finally: - for s in sockets: - s.close() From 6d68e2f93249fd9a24c75a6287ccb356ba89c03a Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 13:22:28 -0800 Subject: [PATCH 03/20] vllm_server_actor Signed-off-by: Kourosh Hakhamaneshi --- skyrl-train/skyrl_train/env_vars.py | 8 + .../skyrl_train/inference_servers/__init__.py | 29 +++ .../skyrl_train/inference_servers/router.py | 3 +- .../inference_servers/vllm_server_actor.py | 219 +++++++----------- .../tests/inference_servers/test_router.py | 5 +- 5 files changed, 125 insertions(+), 139 deletions(-) create mode 100644 skyrl-train/skyrl_train/env_vars.py diff --git a/skyrl-train/skyrl_train/env_vars.py b/skyrl-train/skyrl_train/env_vars.py new file mode 100644 index 000000000..a677ea98f --- /dev/null +++ b/skyrl-train/skyrl_train/env_vars.py @@ -0,0 +1,8 @@ + + + +import os + + +SKYRL_VLLM_DP_PORT_OFFSET = int(os.environ.get("SKYRL_VLLM_DP_PORT_OFFSET", 500)) +SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S = int(os.environ.get("SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S", 600)) \ No newline at end of file diff --git a/skyrl-train/skyrl_train/inference_servers/__init__.py b/skyrl-train/skyrl_train/inference_servers/__init__.py index e69de29bb..96a716591 100644 --- a/skyrl-train/skyrl_train/inference_servers/__init__.py +++ b/skyrl-train/skyrl_train/inference_servers/__init__.py @@ -0,0 +1,29 @@ +""" +SkyRL Inference Servers Module. 
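
A typical end-to-end flow looks roughly like this (illustrative sketch; engine_args
stands for the Namespace produced by vLLM's CLI argument parser):

    from skyrl_train.inference_servers import VLLMServerGroup, InferenceRouter

    group = VLLMServerGroup(engine_args, num_servers=2)
    server_infos = group.start()  # blocks until all servers are healthy
    router = InferenceRouter([s.url for s in server_infos], port=8080)
    router_url = router.start()   # single HTTP entry point for all backends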
+ +This module provides HTTP-based inference server infrastructure: +- VLLMServerActor: Ray actor running vLLM OpenAI-compatible server +- ServerActorPool: Generic pool managing server actors +- VLLMServerGroup: vLLM-specific server group with placement group support +- InferenceRouter: HTTP proxy router with session-aware routing +""" + +from skyrl_train.inference_servers.common import ( + ServerInfo, + get_node_ip, + get_open_port, +) +from skyrl_train.inference_servers.server_pool import ServerActorPool +from skyrl_train.inference_servers.vllm_server_actor import VLLMServerActor +from skyrl_train.inference_servers.server_group import VLLMServerGroup +from skyrl_train.inference_servers.router import InferenceRouter + +__all__ = [ + "ServerInfo", + "get_node_ip", + "get_open_port", + "ServerActorPool", + "VLLMServerActor", + "VLLMServerGroup", + "InferenceRouter", +] diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py index 8dcf5d81a..862858159 100644 --- a/skyrl-train/skyrl_train/inference_servers/router.py +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -38,10 +38,9 @@ "/wakeup", "/reset_prefix_cache", "/collective_rpc", - "/init_weight_update_communicator", + "/init_weight_transfer", "/update_weights", "/finalize_weight_update", - "/destroy_weights_update_group", ] diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py index 5538e1ca3..220d72d02 100644 --- a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py +++ b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py @@ -7,14 +7,26 @@ import os import pickle import time -from argparse import Namespace from typing import Any, Dict, Optional +from argparse import Namespace import httpx import uvicorn from fastapi import Request -from skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_free_port + +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.entrypoints.openai.api_server import build_app, create_server_socket, init_app_state +from vllm.usage.usage_lib import UsageContext +import vllm.envs as envs +from vllm.utils.system_utils import set_ulimit + + +from skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_open_port +from skyrl_train.env_vars import ( + SKYRL_VLLM_DP_PORT_OFFSET, SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S, +) logger = logging.getLogger(__name__) @@ -27,24 +39,27 @@ class VLLMServerActor: called from anywhere (other actors, driver, external processes). 
Custom endpoints added for SkyRL: - - /init_weight_update_communicator: Initialize weight sync process group + - /get_server_info: Return parallelism info + + - (vLLM RFC: https://github.com/vllm-project/vllm/issues/31848) + - /init_weight_transfer: Initialize weight sync process group - /update_weights: Update model weights via NCCL broadcast - /finalize_weight_update: Post-processing after weight sync - - /destroy_weights_update_group: Teardown weight sync - - /sleep: Offload weights to CPU - - /wake_up: Load weights back to GPU - - /reset_prefix_cache: Clear KV cache - - /get_server_info: Return parallelism info """ @staticmethod - def compute_num_gpus_per_server(engine_args: Namespace) -> int: - """Compute the number of GPUs needed per server based on TP * PP.""" - return engine_args.tensor_parallel_size * engine_args.pipeline_parallel_size + def compute_num_gpus_per_server(vllm_cli_args: Namespace) -> int: + """Compute the number of GPUs needed per server based on TP * PP. + + This logic might need adjustment if we want to support other + parallelism schemes. If we get to this point, we should add a + vllm-specific utility for it and keep the logic inside the engine. + """ + return vllm_cli_args.tensor_parallel_size * vllm_cli_args.pipeline_parallel_size def __init__( self, - engine_args: Namespace, + vllm_cli_args: Namespace, start_port: int = 8000, server_idx: int = 0, dp_size: int = -1, @@ -58,8 +73,8 @@ def __init__( Initialize the vLLM server actor. Args: - engine_args: vLLM engine configuration (Namespace from make_arg_parser). - Required attributes: tensor_parallel_size, pipeline_parallel_size, model. + vllm_cli_args: vLLM CLI arguments. + Required attributes: tensor_parallel_size, pipeline_parallel_size. Optional: uvicorn_log_level, ssl_*, disable_uvicorn_access_log, kv_transfer_config. start_port: Base port to start searching for free port server_idx: Index of this server in the group @@ -69,15 +84,15 @@ def __init__( enable_pd: Enable prefill-decode disaggregation nixl_side_channel_base: Base port for NIXL side channel """ - self._engine_args = engine_args + self._cli_args = vllm_cli_args self._ip = get_node_ip() - self._port = get_free_port(start_port) + self._port = get_open_port(start_port) self._server_idx = server_idx - self._num_gpus_per_server = self.compute_num_gpus_per_server(engine_args) + self._num_gpus_per_server = self.compute_num_gpus_per_server(vllm_cli_args) # Update args with our assigned host/port - self._engine_args.host = "0.0.0.0" - self._engine_args.port = self._port + self._cli_args.host = "0.0.0.0" + self._cli_args.port = self._port # PD disaggregation: setup NIXL side channel for KV transfer if enable_pd: @@ -85,17 +100,20 @@ def __init__( # Each engine needs to know its dp_rank and dp_size so DP process groups are formed if dp_size > 0: - self._engine_args.data_parallel_size = dp_size - self._engine_args.data_parallel_rank = server_idx - # All DP ranks need to know the master's address and RPC port for handshake + self._cli_args.data_parallel_size = dp_size + self._cli_args.data_parallel_rank = server_idx + + # DP0 will be the master sharing its ip and port with others. + # So if we are not DP0, we need to pass master_ip and port from + # outside. otherwise, we can use the local ip and port. 
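            # (VLLMServerGroup._create_actors does this handshake: it creates server 0
            # first, reads (address, rpc_port) via ray.get(actor.get_dp_info.remote()),
            # and passes them to servers 1..dp_size-1.)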
if server_idx == 0: dp_master_address, dp_rpc_port = self.get_dp_info() if dp_master_address is None or dp_rpc_port is None: raise ValueError("DP address and RPC port must be set for non-server 0") - self._engine_args.data_parallel_address = dp_master_address - self._engine_args.data_parallel_rpc_port = dp_rpc_port + self._cli_args.data_parallel_address = dp_master_address + self._cli_args.data_parallel_rpc_port = dp_rpc_port logger.info( f"Server {server_idx}: DP enabled - dp_size={dp_size}, dp_rank={server_idx}, " f"dp_master_address={dp_master_address}, dp_rpc_port={dp_rpc_port}" @@ -108,8 +126,9 @@ def __init__( os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) logger.info(f"Server {server_idx}: using bundle indices {bundle_indices}") - self._engine = None - self._server_task = None + # Initialized lazily to not block the actor initialization. + self._engine: Optional[AsyncLLMEngine] = None + self._server_task: Optional[asyncio.Task] = None def _setup_nixl_side_channel(self, base_port: int) -> None: """ @@ -125,13 +144,16 @@ def _setup_nixl_side_channel(self, base_port: int) -> None: engine_id = f"server-{self._server_idx}-{self._ip}-{side_channel_port}" - if hasattr(self._engine_args, "kv_transfer_config") and self._engine_args.kv_transfer_config: + if hasattr(self._cli_args, "kv_transfer_config") and self._cli_args.kv_transfer_config: try: - kv_config = json.loads(self._engine_args.kv_transfer_config) - kv_config["engine_id"] = engine_id - self._engine_args.kv_transfer_config = json.dumps(kv_config) - except (json.JSONDecodeError, TypeError): - pass + kv_config = json.loads(self._cli_args.kv_transfer_config) + except (json.JSONDecodeError, TypeError) as e: + raise ValueError( + f"Invalid kv_transfer_config: expected valid JSON string, " + f"got {type(self._cli_args.kv_transfer_config).__name__}: {e}" + ) from e + kv_config["engine_id"] = engine_id + self._cli_args.kv_transfer_config = json.dumps(kv_config) logger.info( f"Server {self._server_idx}: NIXL side channel configured - " @@ -142,27 +164,23 @@ def get_server_info(self) -> ServerInfo: """Get the server's IP and port info.""" return ServerInfo(ip=self._ip, port=self._port) - def get_extended_server_info(self) -> Dict[str, Any]: + def _get_extended_server_info(self) -> Dict[str, Any]: """Return extended server info including parallelism settings.""" return { "ip": self._ip, "port": self._port, "url": f"http://{self._ip}:{self._port}", "server_idx": self._server_idx, - "tp_size": self._engine_args.tensor_parallel_size, - "pp_size": self._engine_args.pipeline_parallel_size, - "dp_size": getattr(self._engine_args, "data_parallel_size", 1), "world_size": self._num_gpus_per_server, } def get_dp_info(self) -> tuple: """Get the DP master address and RPC port (for server 0 to share with others).""" - dp_rpc_port = self._port + 500 + dp_rpc_port = self._port + SKYRL_VLLM_DP_PORT_OFFSET return (self._ip, dp_rpc_port) async def start(self) -> ServerInfo: """Start the vLLM server. 
Blocks until server is healthy.""" - from vllm.utils.system_utils import set_ulimit set_ulimit() logger.info(f"Starting server on {self._ip}:{self._port}...") @@ -175,7 +193,7 @@ async def start(self) -> ServerInfo: return self.get_server_info() - async def _wait_until_healthy(self, timeout: float = 600, interval: float = 1.0) -> None: + async def _wait_until_healthy(self, timeout: float = SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S) -> None: """Poll the /health endpoint until it responds OK.""" url = f"http://{self._ip}:{self._port}/health" start_time = time.time() @@ -200,23 +218,17 @@ async def _wait_until_healthy(self, timeout: float = 600, interval: float = 1.0) if time.time() - start_time > timeout: raise TimeoutError(f"Server failed to become healthy within {timeout}s") - await asyncio.sleep(interval) + await asyncio.sleep(1.0) async def _run_server(self) -> None: """Internal method to run the HTTP server.""" - from vllm import AsyncLLMEngine - from vllm.engine.arg_utils import AsyncEngineArgs - from vllm.entrypoints.openai.api_server import build_app, create_server_socket, init_app_state - from vllm.usage.usage_lib import UsageContext - import vllm.envs as envs - - sock_addr = (self._engine_args.host, self._engine_args.port) + sock_addr = (self._cli_args.host, self._cli_args.port) sock = create_server_socket(sock_addr) - app = build_app(self._engine_args) + app = build_app(self._cli_args) # Initialize the engine (this loads the model - takes time) - engine_args = AsyncEngineArgs.from_cli_args(self._engine_args) - self._engine = AsyncLLMEngine.from_engine_args( + engine_args = AsyncEngineArgs.from_cli_args(self._cli_args) + self._engine = AsyncLLMEngine.from_cli_args( engine_args=engine_args, usage_context=UsageContext.OPENAI_API_SERVER, ) @@ -225,20 +237,20 @@ async def _run_server(self) -> None: # Add custom SkyRL endpoints self._add_custom_endpoints(app) - await init_app_state(self._engine, app.state, self._engine_args) + await init_app_state(self._engine, app.state, self._cli_args) # Use uvicorn directly (serve_http tries to add signal handlers which fails in Ray actors) config = uvicorn.Config( app, - host=self._engine_args.host, - port=self._engine_args.port, - log_level=self._engine_args.uvicorn_log_level, + host=self._cli_args.host, + port=self._cli_args.port, + log_level=self._cli_args.uvicorn_log_level, timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, - ssl_keyfile=self._engine_args.ssl_keyfile, - ssl_certfile=self._engine_args.ssl_certfile, - ssl_ca_certs=self._engine_args.ssl_ca_certs, - ssl_cert_reqs=self._engine_args.ssl_cert_reqs, - access_log=not getattr(self._engine_args, "disable_uvicorn_access_log", False), + ssl_keyfile=self._cli_args.ssl_keyfile, + ssl_certfile=self._cli_args.ssl_certfile, + ssl_ca_certs=self._cli_args.ssl_ca_certs, + ssl_cert_reqs=self._cli_args.ssl_cert_reqs, + access_log=not getattr(self._cli_args, "disable_uvicorn_access_log", False), ) server = uvicorn.Server(config) await server.serve(sockets=[sock]) @@ -250,19 +262,25 @@ def _add_custom_endpoints(self, app) -> None: @app.get("/get_server_info") async def _get_server_info(): """Return server parallelism info.""" - return self.get_extended_server_info() + return self._get_extended_server_info() - @app.post("/init_weight_update_communicator") - async def _init_weight_update_communicator(request: Request): + # TODO (Kourosh): After https://github.com/vllm-project/vllm/pull/ + # 31943/ is merged, use the native API. 
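        # Illustrative trainer-side sequence against the endpoints defined below
        # (payload dicts follow BroadcastInitInfo / WeightUpdateRequest and are
        # assumed here; the router broadcasts each POST to every server):
        #   httpx.post(f"{router_url}/init_weight_transfer", json=init_info_dict)
        #   for update in weight_update_requests:
        #       httpx.post(f"{router_url}/update_weights", json=update)
        #   httpx.post(f"{router_url}/finalize_weight_update", json={})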
+ @app.post("/init_weight_transfer") + async def _init_weight_transfer(request: Request): """Initialize weight sync process group.""" from skyrl_train.weight_sync import BroadcastInitInfo data = await request.json() - init_info = BroadcastInitInfo(**data) + init_info = BroadcastInitInfo(**data).for_engine( + engine_index=self._server_idx, + tp_size=self._cli_args.tensor_parallel_size, + pp_size=self._cli_args.pipeline_parallel_size + ) pickled_init_info = pickle.dumps(init_info) await engine.collective_rpc( - "init_weight_update_communicator", + "init_weight_transfer", args=(pickled_init_info,), ) return {"status": "ok"} @@ -287,80 +305,13 @@ async def _finalize_weight_update(request: Request): """ Finalize weight update - post-processing hook. - Currently a no-op, reserved for future use (cache invalidation, etc). + Currently a no-op, reserved for future use e.g. Quantization + See https://github.com/vllm-project/vllm/issues/31848 for more + details. """ # No-op for now - placeholder for future post-processing return {"status": "ok"} - @app.post("/destroy_weights_update_group") - async def _destroy_weights_update_group(request: Request): - """Teardown weight sync process group.""" - await engine.collective_rpc( - "teardown_weight_receiver", - args=(), - ) - return {"status": "ok"} - - @app.post("/sleep") - async def _sleep(request: Request): - """Offload weights to CPU.""" - data = await request.json() - level = data.get("level", 1) - - # Reset prefix cache before sleep to avoid gibberish on wake - # See: https://github.com/vllm-project/vllm/issues/17103 - await engine.reset_prefix_cache() - await engine.sleep(level) - return {"status": "ok"} - - @app.post("/wake_up") - async def _wake_up(request: Request): - """Load weights back to GPU.""" - data = await request.json() - tags = data.get("tags") - await engine.wake_up(tags) - return {"status": "ok"} - - @app.post("/wakeup") - async def _wakeup(request: Request): - """Alias for /wake_up.""" - data = await request.json() - tags = data.get("tags") - await engine.wake_up(tags) - return {"status": "ok"} - - @app.post("/reset_prefix_cache") - async def _reset_prefix_cache(request: Request): - """Clear KV cache.""" - data = await request.json() - reset_running = data.get("reset_running_requests", False) - if reset_running: - # If reset_running_requests is True, we need to abort first - await engine.abort_all_requests() - await engine.reset_prefix_cache() - return {"status": "ok"} - - @app.post("/pause") - async def _pause(request: Request): - """Pause generation.""" - data = await request.json() - wait_for_inflight = data.get("wait_for_inflight_request", False) - # vLLM's pause API - implementation depends on vLLM version - if hasattr(engine, "pause"): - await engine.pause(wait_for_inflight_request=wait_for_inflight) - else: - # Fallback: abort all if pause not available - if not wait_for_inflight: - await engine.abort_all_requests() - return {"status": "ok"} - - @app.post("/resume") - async def _resume(request: Request): - """Resume generation.""" - if hasattr(engine, "resume"): - await engine.resume() - return {"status": "ok"} - async def shutdown(self) -> None: """Gracefully shutdown the server.""" if self._server_task: diff --git a/skyrl-train/tests/inference_servers/test_router.py b/skyrl-train/tests/inference_servers/test_router.py index eb9a809e8..6d220efbc 100644 --- a/skyrl-train/tests/inference_servers/test_router.py +++ b/skyrl-train/tests/inference_servers/test_router.py @@ -70,7 +70,7 @@ def test_control_plane_route_detection(self, 
router): assert router._is_control_plane_route("/wake_up") is True assert router._is_control_plane_route("/wakeup") is True assert router._is_control_plane_route("/reset_prefix_cache") is True - assert router._is_control_plane_route("/init_weight_update_communicator") is True + assert router._is_control_plane_route("/init_weight_transfer") is True assert router._is_control_plane_route("/update_weights") is True assert router._is_control_plane_route("/finalize_weight_update") is True @@ -104,10 +104,9 @@ def test_control_plane_routes_list(self): "/wakeup", "/reset_prefix_cache", "/collective_rpc", - "/init_weight_update_communicator", + "/init_weight_transfer", "/update_weights", "/finalize_weight_update", - "/destroy_weights_update_group", ] assert CONTROL_PLANE_ROUTES == expected From d0d2990eaf2cd4b1ba28e90cee80c4d0d9dc65a2 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 13:25:18 -0800 Subject: [PATCH 04/20] pool Signed-off-by: Kourosh Hakhamaneshi --- skyrl-train/skyrl_train/inference_servers/server_pool.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/skyrl-train/skyrl_train/inference_servers/server_pool.py b/skyrl-train/skyrl_train/inference_servers/server_pool.py index ffbd72ea5..e9964cf03 100644 --- a/skyrl-train/skyrl_train/inference_servers/server_pool.py +++ b/skyrl-train/skyrl_train/inference_servers/server_pool.py @@ -10,8 +10,7 @@ class ServerActorPool: - """ - Generic pool that manages a list of server actors. + """Generic pool that manages a list of server actors. Actors must implement: - start() -> ServerInfo From d20b4bd62f8ab08d53086f28430062d06858837f Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 13:34:20 -0800 Subject: [PATCH 05/20] wip Signed-off-by: Kourosh Hakhamaneshi --- .../inference_servers/server_group.py | 22 ++++++++++--------- .../inference_servers/server_pool.py | 3 +++ 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/skyrl-train/skyrl_train/inference_servers/server_group.py b/skyrl-train/skyrl_train/inference_servers/server_group.py index cdbfc3676..b9bdc4a52 100644 --- a/skyrl-train/skyrl_train/inference_servers/server_group.py +++ b/skyrl-train/skyrl_train/inference_servers/server_group.py @@ -33,7 +33,7 @@ class VLLMServerGroup: def __init__( self, - engine_args: Namespace, + vllm_cli_args: Namespace, num_servers: int, start_port: int = 8000, placement_group: Optional[PlacementGroup] = None, @@ -46,19 +46,21 @@ def __init__( Initialize the vLLM server group. Args: - engine_args: vLLM engine configuration (Namespace from make_arg_parser). - Required attributes: tensor_parallel_size, pipeline_parallel_size, model. + vllm_cli_args: vLLM CLI arguments. num_servers: Number of vLLM server instances to create start_port: Base port for server ports placement_group: External placement group for colocation mode. - If None, creates internal placement group. + If None, creates internal placement group. placement_group_bundle_offset: Offset for bundle indices when using - external placement group (e.g., if training uses first N bundles) + external placement group (e.g., if training uses first N + bundles). enable_dp: Enable data parallelism across servers enable_pd: Enable prefill-decode disaggregation - nixl_side_channel_base: Base port for NIXL side channels + nixl_side_channel_base: Base port for NIXL side channels. Each + server will be assigned a port of nixl_side_channel_base + + server_idx. 
""" - self._engine_args = engine_args + self._vllm_cli_args = vllm_cli_args self._num_servers = num_servers self._start_port = start_port self._external_pg = placement_group @@ -69,8 +71,8 @@ def __init__( self._pool: Optional[ServerActorPool] = None self._internal_pg: Optional[PlacementGroup] = None - # Query the actor class for GPU requirements (single source of truth) - self._num_gpus_per_server = VLLMServerActor.compute_num_gpus_per_server(engine_args) + # Query the actor class for GPU requirements + self._num_gpus_per_server = VLLMServerActor.compute_num_gpus_per_server(vllm_cli_args) logger.info( f"VLLMServerGroup: num_servers={num_servers}, " @@ -124,7 +126,7 @@ def _create_actors(self) -> List[Any]: ServerActorClass = self._create_actor_class(pg, start_bundle_idx) actor = ServerActorClass.remote( - self._engine_args, + self._vllm_cli_args, self._start_port + server_idx, server_idx=server_idx, dp_size=self._num_servers if self._enable_dp else -1, diff --git a/skyrl-train/skyrl_train/inference_servers/server_pool.py b/skyrl-train/skyrl_train/inference_servers/server_pool.py index e9964cf03..1a6a467b2 100644 --- a/skyrl-train/skyrl_train/inference_servers/server_pool.py +++ b/skyrl-train/skyrl_train/inference_servers/server_pool.py @@ -11,6 +11,9 @@ class ServerActorPool: """Generic pool that manages a list of server actors. + + This layer provides a generic pool interface which can be extended to + support fault-tolerance, monitoring, etc. for now it's just a simple wrapper around a list of actor handles. Actors must implement: - start() -> ServerInfo From 07f3d9fdf1f4ac71a8b9f427eaa5899993a342f9 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 13:42:22 -0800 Subject: [PATCH 06/20] group Signed-off-by: Kourosh Hakhamaneshi --- .../inference_servers/protocols.py | 103 ++++++++++++++++++ .../inference_servers/server_group.py | 46 ++++---- .../inference_servers/vllm_server_actor.py | 9 +- 3 files changed, 135 insertions(+), 23 deletions(-) create mode 100644 skyrl-train/skyrl_train/inference_servers/protocols.py diff --git a/skyrl-train/skyrl_train/inference_servers/protocols.py b/skyrl-train/skyrl_train/inference_servers/protocols.py new file mode 100644 index 000000000..324cc5f00 --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/protocols.py @@ -0,0 +1,103 @@ +""" +Protocols for inference server components. + +These define the interfaces that server implementations must follow. +""" + +from argparse import Namespace +from typing import Optional, Protocol, Tuple, runtime_checkable + +from skyrl_train.inference_servers.common import ServerInfo + + +@runtime_checkable +class ServerActorProtocol(Protocol): + """ + Protocol defining the interface for server actor classes. + + Any server actor class (vLLM, SGLang, etc.) must implement this interface + to be usable with ServerGroup. + + Example: + class MyServerActor(ServerActorProtocol): + @staticmethod + def compute_num_gpus_per_server(cli_args: Namespace) -> int: + return cli_args.tensor_parallel_size + + def __init__(self, cli_args, start_port, server_idx, ...): + ... + + async def start(self) -> ServerInfo: + ... + """ + + @staticmethod + def compute_num_gpus_per_server(cli_args: Namespace) -> int: + """ + Compute the number of GPUs needed per server instance. + + This is called before actor creation to determine placement group size. + + Args: + cli_args: Engine-specific CLI arguments. + + Returns: + Number of GPUs required per server (e.g., TP * PP for vLLM). + """ + ... 
+ + def __init__( + self, + cli_args: Namespace, + start_port: int, + server_idx: int, + dp_size: int, + dp_master_address: Optional[str], + dp_rpc_port: Optional[int], + enable_pd: bool, + nixl_side_channel_base: int, + ) -> None: + """ + Initialize the server actor. + + Args: + cli_args: Engine-specific CLI arguments. + start_port: Base port to search for available port. + server_idx: Index of this server in the group (0-indexed). + dp_size: Data parallel size (-1 to disable DP). + dp_master_address: DP master address (for non-rank-0 servers). + dp_rpc_port: DP RPC port (for non-rank-0 servers). + enable_pd: Enable prefill-decode disaggregation. + nixl_side_channel_base: Base port for NIXL side channels. + """ + ... + + def get_server_info(self) -> ServerInfo: + """Get the server's IP and port info.""" + ... + + def get_dp_info(self) -> Tuple[str, int]: + """ + Get the DP master address and RPC port. + + Only called on server_idx=0 when DP is enabled. + + Returns: + Tuple of (master_address, rpc_port). + """ + ... + + async def start(self) -> ServerInfo: + """ + Start the server. + + This should block until the server is healthy and ready to serve requests. + + Returns: + ServerInfo with the server's IP and port. + """ + ... + + async def shutdown(self) -> None: + """Gracefully shutdown the server.""" + ... diff --git a/skyrl-train/skyrl_train/inference_servers/server_group.py b/skyrl-train/skyrl_train/inference_servers/server_group.py index b9bdc4a52..33e700580 100644 --- a/skyrl-train/skyrl_train/inference_servers/server_group.py +++ b/skyrl-train/skyrl_train/inference_servers/server_group.py @@ -1,27 +1,28 @@ """ -vLLM Server Group - manages vLLM server actors with placement groups. +Server Group - manages server actors with placement groups. """ import logging from argparse import Namespace -from typing import Any, List, Optional +from typing import Any, List, Optional, Type import ray from ray.util.placement_group import PlacementGroup, placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from skyrl_train.inference_servers.common import ServerInfo +from skyrl_train.inference_servers.protocols import ServerActorProtocol from skyrl_train.inference_servers.server_pool import ServerActorPool from skyrl_train.inference_servers.vllm_server_actor import VLLMServerActor logger = logging.getLogger(__name__) -class VLLMServerGroup: +class ServerGroup: """ - Creates and manages a group of vLLM server actors. + Creates and manages a group of server actors. - This layer handles vLLM-specific actor creation with placement group support, + This layer handles actor creation with placement group support, then delegates pool management to ServerActorPool. Supports: @@ -33,7 +34,7 @@ class VLLMServerGroup: def __init__( self, - vllm_cli_args: Namespace, + cli_args: Namespace, num_servers: int, start_port: int = 8000, placement_group: Optional[PlacementGroup] = None, @@ -41,26 +42,30 @@ def __init__( enable_dp: bool = False, enable_pd: bool = False, nixl_side_channel_base: int = 5600, + server_actor_cls: Type[ServerActorProtocol] = VLLMServerActor, ): """ - Initialize the vLLM server group. + Initialize the server group. Args: - vllm_cli_args: vLLM CLI arguments. - num_servers: Number of vLLM server instances to create - start_port: Base port for server ports + cli_args: CLI arguments for the server (engine-specific). + num_servers: Number of server instances to create. + start_port: Base port for server ports. 
placement_group: External placement group for colocation mode. If None, creates internal placement group. placement_group_bundle_offset: Offset for bundle indices when using external placement group (e.g., if training uses first N bundles). - enable_dp: Enable data parallelism across servers - enable_pd: Enable prefill-decode disaggregation + enable_dp: Enable data parallelism across servers. + enable_pd: Enable prefill-decode disaggregation. nixl_side_channel_base: Base port for NIXL side channels. Each server will be assigned a port of nixl_side_channel_base + server_idx. + server_actor_cls: Server actor class implementing ServerActorProtocol. + Defaults to VLLMServerActor. """ - self._vllm_cli_args = vllm_cli_args + self._server_actor_cls = server_actor_cls + self._cli_args = cli_args self._num_servers = num_servers self._start_port = start_port self._external_pg = placement_group @@ -72,10 +77,11 @@ def __init__( self._internal_pg: Optional[PlacementGroup] = None # Query the actor class for GPU requirements - self._num_gpus_per_server = VLLMServerActor.compute_num_gpus_per_server(vllm_cli_args) + self._num_gpus_per_server = server_actor_cls.compute_num_gpus_per_server(cli_args) logger.info( - f"VLLMServerGroup: num_servers={num_servers}, " + f"ServerGroup: actor_cls={server_actor_cls.__name__}, " + f"num_servers={num_servers}, " f"gpus_per_server={self._num_gpus_per_server}, " f"enable_dp={enable_dp}, enable_pd={enable_pd}, " f"external_pg={'yes' if placement_group else 'no'}" @@ -100,7 +106,7 @@ def _get_placement_group(self) -> PlacementGroup: def _create_actor_class(self, pg: PlacementGroup, start_bundle_idx: int) -> Any: """Create actor class with scheduling constraints for a specific bundle.""" - return ray.remote(VLLMServerActor).options( + return ray.remote(self._server_actor_cls).options( num_gpus=0, # GPU allocation managed by placement group num_cpus=1, scheduling_strategy=PlacementGroupSchedulingStrategy( @@ -111,7 +117,7 @@ def _create_actor_class(self, pg: PlacementGroup, start_bundle_idx: int) -> Any: ) def _create_actors(self) -> List[Any]: - """Create vLLM server actors with GPU resources.""" + """Create server actors with GPU resources.""" pg = self._get_placement_group() actors = [] @@ -126,7 +132,7 @@ def _create_actors(self) -> List[Any]: ServerActorClass = self._create_actor_class(pg, start_bundle_idx) actor = ServerActorClass.remote( - self._vllm_cli_args, + self._cli_args, self._start_port + server_idx, server_idx=server_idx, dp_size=self._num_servers if self._enable_dp else -1, @@ -147,7 +153,7 @@ def _create_actors(self) -> List[Any]: def start(self) -> List[ServerInfo]: """Create actors, start the pool, and return endpoints.""" - logger.info(f"Starting {self._num_servers} vLLM server(s)...") + logger.info(f"Starting {self._num_servers} server(s)...") actors = self._create_actors() self._pool = ServerActorPool(actors) server_infos = self._pool.start() @@ -176,5 +182,5 @@ def get_actors(self) -> List[Any]: def shutdown(self) -> None: """Shutdown all servers.""" if self._pool: - logger.info("Shutting down vLLM servers...") + logger.info("Shutting down servers...") self._pool.shutdown() diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py index 220d72d02..68d9c1283 100644 --- a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py +++ b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py @@ -7,7 +7,7 @@ import os import pickle import time -from typing import 
Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from argparse import Namespace import httpx @@ -27,13 +27,16 @@ from skyrl_train.env_vars import ( SKYRL_VLLM_DP_PORT_OFFSET, SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S, ) +from skyrl_train.inference_servers.protocols import ServerActorProtocol logger = logging.getLogger(__name__) -class VLLMServerActor: +class VLLMServerActor(ServerActorProtocol): """ Ray actor that runs a vLLM OpenAI-compatible API server. + + Implements ServerActorProtocol for use with ServerGroup. The server runs in the actor and exposes an HTTP endpoint that can be called from anywhere (other actors, driver, external processes). @@ -174,7 +177,7 @@ def _get_extended_server_info(self) -> Dict[str, Any]: "world_size": self._num_gpus_per_server, } - def get_dp_info(self) -> tuple: + def get_dp_info(self) -> Tuple[str, int]: """Get the DP master address and RPC port (for server 0 to share with others).""" dp_rpc_port = self._port + SKYRL_VLLM_DP_PORT_OFFSET return (self._ip, dp_rpc_port) From 1a48e613beddd8545316caffdcce673da51c80d4 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 14:00:56 -0800 Subject: [PATCH 07/20] wip Signed-off-by: Kourosh Hakhamaneshi --- .../inference_engines/vllm/vllm_engine.py | 68 +----------- .../inference_servers/vllm_server_actor.py | 21 +++- .../inference_servers/vllm_worker.py | 104 ++++++++++++++++++ 3 files changed, 127 insertions(+), 66 deletions(-) create mode 100644 skyrl-train/skyrl_train/inference_servers/vllm_worker.py diff --git a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py index f519a07c1..e80630e72 100644 --- a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -66,70 +66,10 @@ def setup_envvars_for_vllm(kwargs, bundle_indices): logger.info(f"creating LLM with bundle_indices={bundle_indices}") -class WorkerWrap: - def test_rpc(self, *args, **kwargs): - """Test RPC call to worker""" - return args, kwargs - - def init_weight_update_communicator(self, init_info: bytes): - """Init weight update communicator from init info. - - Args: - init_info: Pickled bytes of WeightSyncInitInfo from the sender. - """ - import pickle - - assert torch.distributed.is_initialized(), "default torch process group must be initialized" - - # Unpickle init_info to restore the original object type - assert isinstance(init_info, bytes), f"Expected bytes, got {type(init_info).__name__}" - init_info = pickle.loads(init_info) - - strategy_cls = init_info.strategy_type() - - if hasattr(self, "_weight_receiver") and self._weight_receiver is not None: - # TODO(haochen): we should get rid of this flag and override existing receiver. - if init_info.override_existing_receiver: - self._weight_receiver.teardown() - self._weight_receiver = None - else: - warnings.warn( - "Detected an existing weight receiver. " - "For overriding, use `generator.override_existing_update_group=enable`" - ) - return - - self._weight_receiver = strategy_cls.create_receiver(init_info) - - def load_weights(self, request: bytes) -> None: - """Load weights using the receiver. - - This method is called via collective_rpc from VLLMWeightLoader. - - Args: - request: Pickled bytes of WeightUpdateRequest. 
- """ - import pickle - - # Unpickle request to restore the original object type - assert isinstance(request, bytes), f"Expected bytes, got {type(request).__name__}" - request = pickle.loads(request) - - weight_list = [] - for name, tensor in self._weight_receiver.receive_weights(request): - weight_list.append((name, tensor)) - - self.model_runner.model.load_weights(weights=weight_list) - - for weight in weight_list: - del weight - - # TODO (sumanthrh): Add destroy process group RPC as a atexit handler to Trainer code. - def teardown_weight_receiver(self): - if not hasattr(self, "_weight_receiver") or self._weight_receiver is None: - warnings.warn("No weight receiver to teardown") - return - self._weight_receiver.teardown() +# Backward compatibility: WorkerWrap has moved to inference_servers.vllm_worker +# This alias preserves the old import path for existing scripts/configs. +# TODO: Remove this alias once all references are updated. +from skyrl_train.inference_servers.vllm_worker import WorkerWrap # noqa: F401 class BaseVLLMInferenceEngine(InferenceEngineInterface): diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py index 68d9c1283..e6ab3e445 100644 --- a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py +++ b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py @@ -24,10 +24,11 @@ from skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_open_port +from skyrl_train.inference_servers.protocols import ServerActorProtocol +from skyrl_train.inference_servers.vllm_worker import VLLM_WORKER_EXTENSION_CLS from skyrl_train.env_vars import ( SKYRL_VLLM_DP_PORT_OFFSET, SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S, ) -from skyrl_train.inference_servers.protocols import ServerActorProtocol logger = logging.getLogger(__name__) @@ -92,6 +93,9 @@ def __init__( self._port = get_open_port(start_port) self._server_idx = server_idx self._num_gpus_per_server = self.compute_num_gpus_per_server(vllm_cli_args) + + # Ensure SkyRL's custom worker extension is used for weight sync + self._ensure_worker_extension() # Update args with our assigned host/port self._cli_args.host = "0.0.0.0" @@ -133,6 +137,19 @@ def __init__( self._engine: Optional[AsyncLLMEngine] = None self._server_task: Optional[asyncio.Task] = None + def _ensure_worker_extension(self) -> None: + """ + Ensure the SkyRL worker extension is configured. + + The worker extension (WorkerWrap) provides the RPC methods needed for + weight synchronization (init_weight_update_communicator, load_weights). + """ + if not hasattr(self._cli_args, "worker_extension_cls") or not self._cli_args.worker_extension_cls: + self._cli_args.worker_extension_cls = VLLM_WORKER_EXTENSION_CLS + logger.info(f"Using default worker extension: {VLLM_WORKER_EXTENSION_CLS}") + else: + logger.info(f"Using provided worker extension: {self._cli_args.worker_extension_cls}") + def _setup_nixl_side_channel(self, base_port: int) -> None: """ Setup NIXL side channel for PD disaggregation. 
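For orientation, a minimal sketch of how this worker extension is exercised end to end; the helper name sync_weights_once is hypothetical, and only the CLI flag and the collective_rpc call shapes mirror what the patch itself uses:

    # The extension is injected into every vLLM worker at startup:
    #   vllm serve ... --worker-extension-cls skyrl_train.inference_servers.vllm_worker.WorkerWrap
    import pickle

    async def sync_weights_once(engine, init_info, update_request):
        # Each TP/PP worker builds its weight receiver
        # (WorkerWrap.init_weight_update_communicator).
        await engine.collective_rpc("init_weight_update_communicator", args=(pickle.dumps(init_info),))
        # Each worker then receives tensors and calls model.load_weights
        # (WorkerWrap.load_weights).
        await engine.collective_rpc("load_weights", args=(pickle.dumps(update_request),))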
@@ -283,7 +300,7 @@ async def _init_weight_transfer(request: Request): pickled_init_info = pickle.dumps(init_info) await engine.collective_rpc( - "init_weight_transfer", + "init_weight_update_communicator", args=(pickled_init_info,), ) return {"status": "ok"} diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_worker.py b/skyrl-train/skyrl_train/inference_servers/vllm_worker.py new file mode 100644 index 000000000..00edf43f0 --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/vllm_worker.py @@ -0,0 +1,104 @@ +""" +vLLM Worker Extension for SkyRL weight synchronization. + +This module provides WorkerWrap, a vLLM worker extension class that enables +efficient NCCL-based and CUDA IPC-based weight updates from the training +process to inference workers. + +TODO: This will be removed once vLLM natively supports weight sync APIs. +See: https://github.com/vllm-project/vllm/issues/31848 + +Usage: + Pass as --worker-extension-cls to vLLM: + + vllm serve ... --worker-extension-cls skyrl_train.inference_servers.vllm_worker.WorkerWrap +""" + +import warnings + +import torch + + +# Path to this worker extension class for use in CLI args (derived from module path) +VLLM_WORKER_EXTENSION_CLS = f"{__name__}.WorkerWrap" + + +class WorkerWrap: + """ + vLLM worker extension for SkyRL weight synchronization. + + This class is injected into vLLM workers via --worker-extension-cls and + provides methods that can be called via engine.collective_rpc() to + coordinate weight updates across all TP/PP workers. + + Methods: + init_weight_update_communicator: Initialize the weight receiver + load_weights: Receive and load weights from trainer + teardown_weight_receiver: Clean up weight receiver resources + """ + + def test_rpc(self, *args, **kwargs): + """Test RPC call to worker.""" + return args, kwargs + + def init_weight_update_communicator(self, init_info: bytes): + """ + Initialize weight update communicator from init info. + + Args: + init_info: Pickled bytes of WeightSyncInitInfo from the sender. + """ + import pickle + + assert torch.distributed.is_initialized(), "default torch process group must be initialized" + + # Unpickle init_info to restore the original object type + assert isinstance(init_info, bytes), f"Expected bytes, got {type(init_info).__name__}" + init_info = pickle.loads(init_info) + + strategy_cls = init_info.strategy_type() + + if hasattr(self, "_weight_receiver") and self._weight_receiver is not None: + # TODO(haochen): we should get rid of this flag and override existing receiver. + if init_info.override_existing_receiver: + self._weight_receiver.teardown() + self._weight_receiver = None + else: + warnings.warn( + "Detected an existing weight receiver. " + "For overriding, use `generator.override_existing_update_group=enable`" + ) + return + + self._weight_receiver = strategy_cls.create_receiver(init_info) + + def load_weights(self, request: bytes) -> None: + """ + Load weights using the receiver. + + This method is called via collective_rpc from the weight loader. + + Args: + request: Pickled bytes of WeightUpdateRequest. 
+ """ + import pickle + + # Unpickle request to restore the original object type + assert isinstance(request, bytes), f"Expected bytes, got {type(request).__name__}" + request = pickle.loads(request) + + weight_list = [] + for name, tensor in self._weight_receiver.receive_weights(request): + weight_list.append((name, tensor)) + + self.model_runner.model.load_weights(weights=weight_list) + + for weight in weight_list: + del weight + + def teardown_weight_receiver(self): + """Clean up weight receiver resources.""" + if not hasattr(self, "_weight_receiver") or self._weight_receiver is None: + warnings.warn("No weight receiver to teardown") + return + self._weight_receiver.teardown() From e290f4bc1a0561ac827a228224627b4cc69421f6 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 14:24:16 -0800 Subject: [PATCH 08/20] wip Signed-off-by: Kourosh Hakhamaneshi --- .../skyrl_train/inference_servers/router.py | 112 ++++++------------ .../tests/inference_servers/test_router.py | 66 +++++------ 2 files changed, 67 insertions(+), 111 deletions(-) diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py index 862858159..c78b47ad6 100644 --- a/skyrl-train/skyrl_train/inference_servers/router.py +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -11,14 +11,13 @@ import httpx import uvicorn from fastapi import FastAPI, Request, Response -from fastapi.responses import StreamingResponse from skyrl_train.inference_servers.common import ServerInfo, get_node_ip logger = logging.getLogger(__name__) -# Routes that go to ONE backend (data plane) +# Routes that are loaded balanced (data plane) DATA_PLANE_ROUTES = [ "/v1/completions", "/v1/chat/completions", @@ -29,15 +28,16 @@ "/version", ] -# Routes that go to ALL backends (control plane) +# Routes that go to ALL backends via a broadcast (control plane) CONTROL_PLANE_ROUTES = [ + # BUILT-IN ROUTES "/pause", "/resume", "/sleep", "/wake_up", - "/wakeup", "/reset_prefix_cache", "/collective_rpc", + # SKYRL-SPECIFIC ROUTES "/init_weight_transfer", "/update_weights", "/finalize_weight_update", @@ -46,10 +46,10 @@ class InferenceRouter: """ - HTTP proxy router for multiple vLLM backends. + HTTP proxy router for multiple vLLM servers. Routing behavior: - - Data plane (generation requests): Routes to ONE backend + - Data plane (generation requests): Routes to ONE server. 
- If X-Session-ID header present: consistent hash to same backend - Otherwise: round-robin - Control plane (sleep, pause, weight sync): Fans out to ALL backends @@ -84,33 +84,33 @@ def __init__( self._server_task: Optional[asyncio.Task] = None self._shutdown_event: Optional[asyncio.Event] = None - logger.info(f"InferenceRouter: {len(server_urls)} backends, port={port}") + logger.info(f"InferenceRouter: {len(server_urls)} servers, port={port}") def _hash_session_id(self, session_id: str) -> int: - """Hash session ID to get consistent backend index.""" + """Hash session ID to get consistent server index.""" hash_bytes = hashlib.sha256(session_id.encode()).digest() return int.from_bytes(hash_bytes[:8], "big") - def _get_backend_for_session(self, session_id: str) -> str: - """Get consistent backend URL for a session ID.""" + def _get_server_for_session(self, session_id: str) -> str: + """Get consistent server URL for a session ID.""" idx = self._hash_session_id(session_id) % len(self._server_urls) return self._server_urls[idx] - def _get_backend_round_robin(self) -> str: - """Get next backend URL in round-robin order.""" + def _get_server_round_robin(self) -> str: + """Get next server URL in round-robin order.""" return next(self._server_cycle) - def _get_backend_for_request(self, request: Request) -> str: + def _get_server_for_request(self, request: Request) -> str: """ - Determine backend for a request. + Determine server for a request. If X-Session-ID header is present, use consistent hashing. Otherwise, use round-robin. """ session_id = request.headers.get("X-Session-ID") if session_id: - return self._get_backend_for_session(session_id) - return self._get_backend_round_robin() + return self._get_server_for_session(session_id) + return self._get_server_round_robin() def _is_control_plane_route(self, path: str) -> bool: """Check if path is a control plane route (fan-out to all).""" @@ -127,15 +127,15 @@ def _build_app(self) -> FastAPI: @app.get("/servers") async def list_servers(): - """Return list of backend URLs.""" + """Return list of server URLs.""" return {"servers": self._server_urls} @app.get("/get_server_info") async def get_server_info(): - """Fetch server info from any backend (all should return same).""" - backend = self._server_urls[0] + """Fetch server info from first server (all should return same).""" + server_url = self._server_urls[0] try: - resp = await self._client.get(f"{backend}/get_server_info", timeout=10.0) + resp = await self._client.get(f"{server_url}/get_server_info", timeout=10.0) return resp.json() except Exception as e: return {"error": str(e)} @@ -163,23 +163,9 @@ async def _proxy_request(self, request: Request, path: str) -> Response: return await self._proxy_to_one(request, path) async def _proxy_to_one(self, request: Request, path: str) -> Response: - """Proxy request to one backend (data plane).""" - backend_url = self._get_backend_for_request(request) - url = f"{backend_url}{path}" - method = request.method - - body = await request.body() - - # Check if streaming is requested - is_streaming = False - if body: - try: - import json - - data = json.loads(body) - is_streaming = data.get("stream", False) - except (json.JSONDecodeError, UnicodeDecodeError): - pass + """Proxy request to one server (data plane).""" + server_url = self._get_server_for_request(request) + url = f"{server_url}{path}" # Forward headers (filter out hop-by-hop headers) headers = { @@ -188,20 +174,11 @@ async def _proxy_to_one(self, request: Request, path: str) -> Response: if k.lower() 
not in ("host", "content-length", "transfer-encoding") } - if is_streaming: - return await self._proxy_streaming(url, method, headers, body) - else: - return await self._proxy_regular(url, method, headers, body) - - async def _proxy_regular( - self, url: str, method: str, headers: dict, body: bytes - ) -> Response: - """Proxy a regular (non-streaming) request.""" response = await self._client.request( - method=method, + method=request.method, url=url, headers=headers, - content=body, + content=await request.body(), ) return Response( @@ -210,28 +187,8 @@ async def _proxy_regular( headers=dict(response.headers), ) - async def _proxy_streaming( - self, url: str, method: str, headers: dict, body: bytes - ) -> StreamingResponse: - """Proxy a streaming request.""" - - async def stream_generator(): - async with self._client.stream( - method=method, - url=url, - headers=headers, - content=body, - ) as response: - async for chunk in response.aiter_bytes(): - yield chunk - - return StreamingResponse( - stream_generator(), - media_type="text/event-stream", - ) - async def _proxy_to_all(self, request: Request, path: str) -> Response: - """Proxy request to all backends (control plane), aggregate responses.""" + """Proxy request to all servers (control plane), aggregate responses.""" method = request.method body = await request.body() @@ -242,31 +199,30 @@ async def _proxy_to_all(self, request: Request, path: str) -> Response: if k.lower() not in ("host", "content-length", "transfer-encoding") } - # Send to all backends concurrently - async def call_backend(backend_url: str): - url = f"{backend_url}{path}" + # Send to all servers concurrently + async def call_server(server_url: str): + url = f"{server_url}{path}" try: response = await self._client.request( method=method, url=url, headers=headers, content=body, - timeout=300.0, # Long timeout for weight sync ) return { - "url": backend_url, + "url": server_url, "status": response.status_code, "body": response.json() if response.content else None, } except Exception as e: return { - "url": backend_url, + "url": server_url, "status": 500, "error": str(e), } results = await asyncio.gather( - *[call_backend(url) for url in self._server_urls] + *[call_server(url) for url in self._server_urls] ) # Check if all succeeded @@ -295,7 +251,7 @@ def start(self) -> str: Router URL (e.g., "http://192.168.1.1:8080") """ if not self._server_urls: - raise ValueError("No backend servers available") + raise ValueError("No servers available") # Create HTTP client for proxying self._client = httpx.AsyncClient(timeout=httpx.Timeout(None)) @@ -323,7 +279,7 @@ def run_server(): ip = get_node_ip() router_url = f"http://{ip}:{self._port}" logger.info(f"Router started at {router_url}") - logger.info(f" GET /servers - list backend servers") + logger.info(f" GET /servers - list servers") logger.info(f" GET /get_server_info - get parallelism info") return router_url diff --git a/skyrl-train/tests/inference_servers/test_router.py b/skyrl-train/tests/inference_servers/test_router.py index 6d220efbc..5ea1ced9a 100644 --- a/skyrl-train/tests/inference_servers/test_router.py +++ b/skyrl-train/tests/inference_servers/test_router.py @@ -24,42 +24,42 @@ def router(self): return InferenceRouter(server_urls, host="0.0.0.0", port=9999) def test_session_hash_consistency(self, router): - """Test that same session ID always maps to same backend.""" + """Test that same session ID always maps to same server.""" session_id = "user-123-session-456" - # Multiple calls should return the same backend - 
backend1 = router._get_backend_for_session(session_id) - backend2 = router._get_backend_for_session(session_id) - backend3 = router._get_backend_for_session(session_id) + # Multiple calls should return the same server + server1 = router._get_server_for_session(session_id) + server2 = router._get_server_for_session(session_id) + server3 = router._get_server_for_session(session_id) - assert backend1 == backend2 == backend3 + assert server1 == server2 == server3 def test_different_sessions_distribute(self, router): - """Test that different session IDs distribute across backends.""" - # With enough session IDs, we should hit multiple backends - backends = set() + """Test that different session IDs distribute across servers.""" + # With enough session IDs, we should hit multiple servers + servers = set() for i in range(100): session_id = f"session-{i}" - backend = router._get_backend_for_session(session_id) - backends.add(backend) + server = router._get_server_for_session(session_id) + servers.add(server) - # Should hit multiple backends (not all requests to one) - assert len(backends) >= 2 + # Should hit multiple servers (not all requests to one) + assert len(servers) >= 2 def test_round_robin_cycles(self, router): - """Test that round-robin cycles through all backends.""" - backends = [] + """Test that round-robin cycles through all servers.""" + servers = [] for _ in range(6): # 2 full cycles - backend = router._get_backend_round_robin() - backends.append(backend) + server = router._get_server_round_robin() + servers.append(server) # First 3 should be unique - assert len(set(backends[:3])) == 3 + assert len(set(servers[:3])) == 3 # Should repeat the pattern - assert backends[0] == backends[3] - assert backends[1] == backends[4] - assert backends[2] == backends[5] + assert servers[0] == servers[3] + assert servers[1] == servers[4] + assert servers[2] == servers[5] def test_control_plane_route_detection(self, router): """Test control plane route detection.""" @@ -129,11 +129,11 @@ def test_request_with_session_id_header(self, router): request = MagicMock() request.headers = {"X-Session-ID": "test-session-123"} - backend1 = router._get_backend_for_request(request) - backend2 = router._get_backend_for_request(request) + server1 = router._get_server_for_request(request) + server2 = router._get_server_for_request(request) - # Same session should get same backend - assert backend1 == backend2 + # Same session should get same server + assert server1 == server2 def test_request_without_session_id_header(self, router): """Test that missing X-Session-ID header triggers round-robin.""" @@ -141,15 +141,15 @@ def test_request_without_session_id_header(self, router): request = MagicMock() request.headers = {} - backends = [] + servers = [] for _ in range(4): - backend = router._get_backend_for_request(request) - backends.append(backend) + server = router._get_server_for_request(request) + servers.append(server) - # Should alternate between backends (round-robin) - assert backends[0] == backends[2] - assert backends[1] == backends[3] - assert backends[0] != backends[1] + # Should alternate between servers (round-robin) + assert servers[0] == servers[2] + assert servers[1] == servers[3] + assert servers[0] != servers[1] class TestRouterInitialization: @@ -168,5 +168,5 @@ def test_router_start_fails_without_servers(self): """Test that start fails with empty server list.""" router = InferenceRouter([], host="0.0.0.0", port=8080) - with pytest.raises(ValueError, match="No backend servers"): + with 
pytest.raises(ValueError, match="No servers"): router.start() From 509538f1753aa532ff9454bb8944f99940acdbca Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 14:47:09 -0800 Subject: [PATCH 09/20] tests Signed-off-by: Kourosh Hakhamaneshi --- .../inference_engines/vllm/vllm_engine.py | 2 +- .../skyrl_train/inference_servers/__init__.py | 29 -- .../skyrl_train/inference_servers/router.py | 20 +- .../tests/inference_servers/test_router.py | 265 +++++++----------- 4 files changed, 106 insertions(+), 210 deletions(-) diff --git a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py index e80630e72..80259d5a3 100644 --- a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -68,7 +68,7 @@ def setup_envvars_for_vllm(kwargs, bundle_indices): # Backward compatibility: WorkerWrap has moved to inference_servers.vllm_worker # This alias preserves the old import path for existing scripts/configs. -# TODO: Remove this alias once all references are updated. +# TODO (Kourosh): Remove this alias once all references are updated. from skyrl_train.inference_servers.vllm_worker import WorkerWrap # noqa: F401 diff --git a/skyrl-train/skyrl_train/inference_servers/__init__.py b/skyrl-train/skyrl_train/inference_servers/__init__.py index 96a716591..e69de29bb 100644 --- a/skyrl-train/skyrl_train/inference_servers/__init__.py +++ b/skyrl-train/skyrl_train/inference_servers/__init__.py @@ -1,29 +0,0 @@ -""" -SkyRL Inference Servers Module. - -This module provides HTTP-based inference server infrastructure: -- VLLMServerActor: Ray actor running vLLM OpenAI-compatible server -- ServerActorPool: Generic pool managing server actors -- VLLMServerGroup: vLLM-specific server group with placement group support -- InferenceRouter: HTTP proxy router with session-aware routing -""" - -from skyrl_train.inference_servers.common import ( - ServerInfo, - get_node_ip, - get_open_port, -) -from skyrl_train.inference_servers.server_pool import ServerActorPool -from skyrl_train.inference_servers.vllm_server_actor import VLLMServerActor -from skyrl_train.inference_servers.server_group import VLLMServerGroup -from skyrl_train.inference_servers.router import InferenceRouter - -__all__ = [ - "ServerInfo", - "get_node_ip", - "get_open_port", - "ServerActorPool", - "VLLMServerActor", - "VLLMServerGroup", - "InferenceRouter", -] diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py index c78b47ad6..ba4f4a66c 100644 --- a/skyrl-train/skyrl_train/inference_servers/router.py +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -162,17 +162,21 @@ async def _proxy_request(self, request: Request, path: str) -> Response: else: return await self._proxy_to_one(request, path) + def _forward_headers(self, request: Request) -> dict: + """Forward headers (filter out hop-by-hop headers).""" + return { + k: v + for k, v in request.headers.items() + if k.lower() not in ("host", "content-length", "transfer-encoding") + } + async def _proxy_to_one(self, request: Request, path: str) -> Response: """Proxy request to one server (data plane).""" server_url = self._get_server_for_request(request) url = f"{server_url}{path}" # Forward headers (filter out hop-by-hop headers) - headers = { - k: v - for k, v in request.headers.items() - if k.lower() not in ("host", "content-length", "transfer-encoding") - } + headers = 
self._forward_headers(request) response = await self._client.request( method=request.method, @@ -193,11 +197,7 @@ async def _proxy_to_all(self, request: Request, path: str) -> Response: body = await request.body() # Forward headers - headers = { - k: v - for k, v in request.headers.items() - if k.lower() not in ("host", "content-length", "transfer-encoding") - } + headers = self._forward_headers(request) # Send to all servers concurrently async def call_server(server_url: str): diff --git a/skyrl-train/tests/inference_servers/test_router.py b/skyrl-train/tests/inference_servers/test_router.py index 5ea1ced9a..7bccb568b 100644 --- a/skyrl-train/tests/inference_servers/test_router.py +++ b/skyrl-train/tests/inference_servers/test_router.py @@ -1,172 +1,97 @@ -"""Tests for inference_servers.router module.""" +"""Tests for InferenceRouter.""" +import asyncio +import socket +import threading +import time + +import httpx import pytest -from unittest.mock import MagicMock, AsyncMock - -from skyrl_train.inference_servers.router import ( - InferenceRouter, - DATA_PLANE_ROUTES, - CONTROL_PLANE_ROUTES, -) - - -class TestRouterRoutingLogic: - """Tests for router routing logic (no actual HTTP calls).""" - - @pytest.fixture - def router(self): - """Create a router with mock backends.""" - server_urls = [ - "http://backend1:8000", - "http://backend2:8000", - "http://backend3:8000", - ] - return InferenceRouter(server_urls, host="0.0.0.0", port=9999) - - def test_session_hash_consistency(self, router): - """Test that same session ID always maps to same server.""" - session_id = "user-123-session-456" - - # Multiple calls should return the same server - server1 = router._get_server_for_session(session_id) - server2 = router._get_server_for_session(session_id) - server3 = router._get_server_for_session(session_id) - - assert server1 == server2 == server3 - - def test_different_sessions_distribute(self, router): - """Test that different session IDs distribute across servers.""" - # With enough session IDs, we should hit multiple servers - servers = set() - for i in range(100): - session_id = f"session-{i}" - server = router._get_server_for_session(session_id) - servers.add(server) - - # Should hit multiple servers (not all requests to one) - assert len(servers) >= 2 - - def test_round_robin_cycles(self, router): - """Test that round-robin cycles through all servers.""" - servers = [] - for _ in range(6): # 2 full cycles - server = router._get_server_round_robin() - servers.append(server) - - # First 3 should be unique - assert len(set(servers[:3])) == 3 - - # Should repeat the pattern - assert servers[0] == servers[3] - assert servers[1] == servers[4] - assert servers[2] == servers[5] - - def test_control_plane_route_detection(self, router): - """Test control plane route detection.""" - # Control plane routes - assert router._is_control_plane_route("/pause") is True - assert router._is_control_plane_route("/resume") is True - assert router._is_control_plane_route("/sleep") is True - assert router._is_control_plane_route("/wake_up") is True - assert router._is_control_plane_route("/wakeup") is True - assert router._is_control_plane_route("/reset_prefix_cache") is True - assert router._is_control_plane_route("/init_weight_transfer") is True - assert router._is_control_plane_route("/update_weights") is True - assert router._is_control_plane_route("/finalize_weight_update") is True - - # Data plane routes should NOT be control plane - assert router._is_control_plane_route("/v1/completions") is False - assert 
router._is_control_plane_route("/v1/chat/completions") is False - assert router._is_control_plane_route("/health") is False - assert router._is_control_plane_route("/models") is False - assert router._is_control_plane_route("/tokenize") is False - - def test_data_plane_routes_list(self): - """Test that data plane routes list is correct.""" - expected = [ - "/v1/completions", - "/v1/chat/completions", - "/tokenize", - "/detokenize", - "/health", - "/models", - "/version", - ] - assert DATA_PLANE_ROUTES == expected - - def test_control_plane_routes_list(self): - """Test that control plane routes list is correct.""" - expected = [ - "/pause", - "/resume", - "/sleep", - "/wake_up", - "/wakeup", - "/reset_prefix_cache", - "/collective_rpc", - "/init_weight_transfer", - "/update_weights", - "/finalize_weight_update", - ] - assert CONTROL_PLANE_ROUTES == expected - - -class TestRouterRequestRouting: - """Tests for request routing based on headers.""" - - @pytest.fixture - def router(self): - """Create a router with mock backends.""" - server_urls = [ - "http://backend1:8000", - "http://backend2:8000", - ] - return InferenceRouter(server_urls, host="0.0.0.0", port=9999) - - def test_request_with_session_id_header(self, router): - """Test that X-Session-ID header triggers session-aware routing.""" - # Create mock request with session header - request = MagicMock() - request.headers = {"X-Session-ID": "test-session-123"} - - server1 = router._get_server_for_request(request) - server2 = router._get_server_for_request(request) - - # Same session should get same server - assert server1 == server2 - - def test_request_without_session_id_header(self, router): - """Test that missing X-Session-ID header triggers round-robin.""" - # Create mock request without session header - request = MagicMock() - request.headers = {} - - servers = [] - for _ in range(4): - server = router._get_server_for_request(request) - servers.append(server) - - # Should alternate between servers (round-robin) - assert servers[0] == servers[2] - assert servers[1] == servers[3] - assert servers[0] != servers[1] - - -class TestRouterInitialization: - """Tests for router initialization.""" - - def test_router_init_with_servers(self): - """Test router initialization with server list.""" - urls = ["http://a:8000", "http://b:8000"] - router = InferenceRouter(urls, host="127.0.0.1", port=8080) - - assert router._server_urls == urls - assert router._host == "127.0.0.1" - assert router._port == 8080 - - def test_router_start_fails_without_servers(self): - """Test that start fails with empty server list.""" - router = InferenceRouter([], host="0.0.0.0", port=8080) - - with pytest.raises(ValueError, match="No servers"): - router.start() +import uvicorn +from fastapi import FastAPI + +from skyrl_train.inference_servers.router import InferenceRouter + + +def get_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + s.listen(1) + return s.getsockname()[1] + + +def create_mock_server(server_id: int) -> FastAPI: + app = FastAPI() + + @app.api_route("/{path:path}", methods=["GET", "POST"]) + async def catch_all(path: str): + return {"server_id": server_id, "path": f"/{path}"} + + return app + + +def start_server(port: int, server_id: int) -> None: + app = create_mock_server(server_id) + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") + server = uvicorn.Server(config) + threading.Thread(target=lambda: asyncio.run(server.serve()), daemon=True).start() + + +def 
wait_ready(url: str, timeout: float = 5.0) -> bool: + start = time.time() + while time.time() - start < timeout: + try: + if httpx.get(f"{url}/health", timeout=1.0).status_code == 200: + return True + except httpx.RequestError: + time.sleep(0.1) + return False + + +@pytest.fixture(scope="module") +def env(): + """Start mock servers and router once for all tests.""" + ports = [get_free_port(), get_free_port()] + router_port = get_free_port() + urls = [f"http://127.0.0.1:{p}" for p in ports] + + for i, port in enumerate(ports): + start_server(port, server_id=i) + for url in urls: + assert wait_ready(url) + + router = InferenceRouter(urls, host="127.0.0.1", port=router_port) + router._client = httpx.AsyncClient(timeout=httpx.Timeout(None)) + router._app = router._build_app() + threading.Thread( + target=lambda: asyncio.run(router._run_server()), daemon=True + ).start() + + router_url = f"http://127.0.0.1:{router_port}" + assert wait_ready(router_url) + yield router_url + + +def test_round_robin(env): + """Requests without session distribute across servers.""" + server_ids = {httpx.get(f"{env}/health").json()["server_id"] for _ in range(4)} + assert len(server_ids) == 2 + + +def test_session_affinity(env): + """Same X-Session-ID routes to same server.""" + headers = {"X-Session-ID": "sticky"} + ids = [httpx.get(f"{env}/health", headers=headers).json()["server_id"] for _ in range(3)] + assert len(set(ids)) == 1 + + +def test_control_plane_fanout(env): + """Control plane routes fan out to all servers.""" + resp = httpx.post(f"{env}/sleep", json={}) + assert resp.status_code == 200 and resp.json()["status"] == "ok" + + +def test_list_servers(env): + """/servers returns all server URLs.""" + resp = httpx.get(f"{env}/servers") + assert resp.status_code == 200 and len(resp.json()["servers"]) == 2 From 555082bf57f1f2aaa29ba94eafac7d612d6d2ed8 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 14:48:49 -0800 Subject: [PATCH 10/20] Wip Signed-off-by: Kourosh Hakhamaneshi --- .../tests/inference_servers/test_router.py | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/skyrl-train/tests/inference_servers/test_router.py b/skyrl-train/tests/inference_servers/test_router.py index 7bccb568b..617c5ec26 100644 --- a/skyrl-train/tests/inference_servers/test_router.py +++ b/skyrl-train/tests/inference_servers/test_router.py @@ -1,25 +1,19 @@ """Tests for InferenceRouter.""" import asyncio -import socket import threading import time +from typing import List import httpx import pytest import uvicorn from fastapi import FastAPI +from skyrl_train.inference_servers.common import get_open_port from skyrl_train.inference_servers.router import InferenceRouter -def get_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) - s.listen(1) - return s.getsockname()[1] - - def create_mock_server(server_id: int) -> FastAPI: app = FastAPI() @@ -30,11 +24,17 @@ async def catch_all(path: str): return app -def start_server(port: int, server_id: int) -> None: +def start_server(port: int, server_id: int) -> uvicorn.Server: + """Start a mock server, return the server instance for cleanup.""" app = create_mock_server(server_id) config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") server = uvicorn.Server(config) - threading.Thread(target=lambda: asyncio.run(server.serve()), daemon=True).start() + + def run(): + asyncio.run(server.serve()) + + threading.Thread(target=run, daemon=True).start() + 
return server def wait_ready(url: str, timeout: float = 5.0) -> bool: @@ -50,27 +50,45 @@ def wait_ready(url: str, timeout: float = 5.0) -> bool: @pytest.fixture(scope="module") def env(): - """Start mock servers and router once for all tests.""" - ports = [get_free_port(), get_free_port()] - router_port = get_free_port() + """Start mock servers and router, clean up after tests.""" + servers: List[uvicorn.Server] = [] + + # Start mock servers + ports = [get_open_port(), get_open_port()] + router_port = get_open_port() urls = [f"http://127.0.0.1:{p}" for p in ports] for i, port in enumerate(ports): - start_server(port, server_id=i) + servers.append(start_server(port, server_id=i)) for url in urls: assert wait_ready(url) + # Start router router = InferenceRouter(urls, host="127.0.0.1", port=router_port) router._client = httpx.AsyncClient(timeout=httpx.Timeout(None)) router._app = router._build_app() - threading.Thread( - target=lambda: asyncio.run(router._run_server()), daemon=True - ).start() + + router_config = uvicorn.Config( + router._app, host="127.0.0.1", port=router_port, log_level="error" + ) + router_server = uvicorn.Server(router_config) + servers.append(router_server) + + def run_router(): + asyncio.run(router_server.serve()) + + threading.Thread(target=run_router, daemon=True).start() router_url = f"http://127.0.0.1:{router_port}" assert wait_ready(router_url) + yield router_url + # Cleanup: signal all servers to shutdown + for server in servers: + server.should_exit = True + time.sleep(0.5) # Give servers time to shutdown + def test_round_robin(env): """Requests without session distribute across servers.""" From afcc8decb6e4fe7a9f1dd07eb37d91e5850926a1 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 15:26:43 -0800 Subject: [PATCH 11/20] wip Signed-off-by: Kourosh Hakhamaneshi --- skyrl-train/tests/{ => cpu}/inference_servers/test_common.py | 0 skyrl-train/tests/{ => cpu}/inference_servers/test_router.py | 0 skyrl-train/tests/inference_servers/__init__.py | 1 - 3 files changed, 1 deletion(-) rename skyrl-train/tests/{ => cpu}/inference_servers/test_common.py (100%) rename skyrl-train/tests/{ => cpu}/inference_servers/test_router.py (100%) delete mode 100644 skyrl-train/tests/inference_servers/__init__.py diff --git a/skyrl-train/tests/inference_servers/test_common.py b/skyrl-train/tests/cpu/inference_servers/test_common.py similarity index 100% rename from skyrl-train/tests/inference_servers/test_common.py rename to skyrl-train/tests/cpu/inference_servers/test_common.py diff --git a/skyrl-train/tests/inference_servers/test_router.py b/skyrl-train/tests/cpu/inference_servers/test_router.py similarity index 100% rename from skyrl-train/tests/inference_servers/test_router.py rename to skyrl-train/tests/cpu/inference_servers/test_router.py diff --git a/skyrl-train/tests/inference_servers/__init__.py b/skyrl-train/tests/inference_servers/__init__.py deleted file mode 100644 index 55f4e2b47..000000000 --- a/skyrl-train/tests/inference_servers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Tests for inference_servers module From 058cb95a6aef950ed52b3839c80e857d9c713440 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 19:02:35 -0800 Subject: [PATCH 12/20] Wip Signed-off-by: Kourosh Hakhamaneshi --- skyrl-train/skyrl_train/env_vars.py | 19 +- .../inference_servers/protocols.py | 2 + .../skyrl_train/inference_servers/router.py | 70 ++- .../inference_servers/server_group.py | 1 + .../inference_servers/vllm_server_actor.py | 24 +- 
skyrl-train/tests/gpu/gpu_ci/conftest.py | 7 +- .../gpu/gpu_ci/test_inference_server_group.py | 452 ++++++++++++++++++ 7 files changed, 529 insertions(+), 46 deletions(-) create mode 100644 skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py diff --git a/skyrl-train/skyrl_train/env_vars.py b/skyrl-train/skyrl_train/env_vars.py index a677ea98f..09ae5356f 100644 --- a/skyrl-train/skyrl_train/env_vars.py +++ b/skyrl-train/skyrl_train/env_vars.py @@ -5,4 +5,21 @@ SKYRL_VLLM_DP_PORT_OFFSET = int(os.environ.get("SKYRL_VLLM_DP_PORT_OFFSET", 500)) -SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S = int(os.environ.get("SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S", 600)) \ No newline at end of file +""" +Offset for the data parallel port of the vLLM server. +""" +SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S = int(os.environ.get("SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S", 600)) +""" +Timeout for waiting until the inference server is healthy. +""" +SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV = str(os.environ.get("SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV", "False")).lower() in ( + "true", + "1", + "yes", +) +""" +Whether to include the PYTHONPATH environment variable in the runtime +environment. In case of using ray nightly, this will be needed to avoid +dependencies issues by setting it to the local path where ray nightly is +installed. +""" \ No newline at end of file diff --git a/skyrl-train/skyrl_train/inference_servers/protocols.py b/skyrl-train/skyrl_train/inference_servers/protocols.py index 324cc5f00..ced4ff481 100644 --- a/skyrl-train/skyrl_train/inference_servers/protocols.py +++ b/skyrl-train/skyrl_train/inference_servers/protocols.py @@ -51,6 +51,7 @@ def __init__( cli_args: Namespace, start_port: int, server_idx: int, + start_bundle_idx: int, dp_size: int, dp_master_address: Optional[str], dp_rpc_port: Optional[int], @@ -64,6 +65,7 @@ def __init__( cli_args: Engine-specific CLI arguments. start_port: Base port to search for available port. server_idx: Index of this server in the group (0-indexed). + start_bundle_idx: Starting bundle index in placement group for this server's workers. dp_size: Data parallel size (-1 to disable DP). dp_master_address: DP master address (for non-rank-0 servers). dp_rpc_port: DP RPC port (for non-rank-0 servers). diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py index ba4f4a66c..db90ed5f4 100644 --- a/skyrl-train/skyrl_train/inference_servers/router.py +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -17,18 +17,8 @@ logger = logging.getLogger(__name__) -# Routes that are loaded balanced (data plane) -DATA_PLANE_ROUTES = [ - "/v1/completions", - "/v1/chat/completions", - "/tokenize", - "/detokenize", - "/health", - "/models", - "/version", -] - -# Routes that go to ALL backends via a broadcast (control plane) +# Control plane routes are broadcast to ALL servers. +# Everything else (data plane) is load-balanced to ONE server. 
CONTROL_PLANE_ROUTES = [ # BUILT-IN ROUTES "/pause", @@ -132,13 +122,8 @@ async def list_servers(): @app.get("/get_server_info") async def get_server_info(): - """Fetch server info from first server (all should return same).""" - server_url = self._server_urls[0] - try: - resp = await self._client.get(f"{server_url}/get_server_info", timeout=10.0) - return resp.json() - except Exception as e: - return {"error": str(e)} + """Fetch server info from all servers, return mapping.""" + return await self._fan_out_get("/get_server_info") # Catch-all: proxy everything else to backends @app.api_route( @@ -170,6 +155,18 @@ def _forward_headers(self, request: Request) -> dict: if k.lower() not in ("host", "content-length", "transfer-encoding") } + async def _fan_out_get(self, path: str) -> dict: + """Fan out a GET request to all servers, return mapping of server_url -> response.""" + async def call_server(server_url: str): + try: + resp = await self._client.get(f"{server_url}{path}", timeout=30.0) + return server_url, resp.json() if resp.content else None + except Exception as e: + return server_url, {"error": str(e)} + + results = await asyncio.gather(*[call_server(url) for url in self._server_urls]) + return {url: response for url, response in results} + async def _proxy_to_one(self, request: Request, path: str) -> Response: """Proxy request to one server (data plane).""" server_url = self._get_server_for_request(request) @@ -192,7 +189,9 @@ async def _proxy_to_one(self, request: Request, path: str) -> Response: ) async def _proxy_to_all(self, request: Request, path: str) -> Response: - """Proxy request to all servers (control plane), aggregate responses.""" + """Proxy request to all servers (control plane), return mapping of responses.""" + import json + method = request.method body = await request.body() @@ -209,14 +208,12 @@ async def call_server(server_url: str): headers=headers, content=body, ) - return { - "url": server_url, + return server_url, { "status": response.status_code, "body": response.json() if response.content else None, } except Exception as e: - return { - "url": server_url, + return server_url, { "status": 500, "error": str(e), } @@ -225,23 +222,18 @@ async def call_server(server_url: str): *[call_server(url) for url in self._server_urls] ) + # Build mapping from server_url to response + response_map = {url: resp for url, resp in results} + # Check if all succeeded - all_ok = all(r.get("status") == 200 for r in results) - - if all_ok: - return Response( - content='{"status": "ok"}', - status_code=200, - media_type="application/json", - ) - else: - import json + all_ok = all(r.get("status") == 200 for r in response_map.values()) + status_code = 200 if all_ok else 207 # Multi-Status on partial failure - return Response( - content=json.dumps({"status": "partial_failure", "results": results}), - status_code=207, # Multi-Status - media_type="application/json", - ) + return Response( + content=json.dumps(response_map), + status_code=status_code, + media_type="application/json", + ) def start(self) -> str: """ diff --git a/skyrl-train/skyrl_train/inference_servers/server_group.py b/skyrl-train/skyrl_train/inference_servers/server_group.py index 33e700580..cd25bc6af 100644 --- a/skyrl-train/skyrl_train/inference_servers/server_group.py +++ b/skyrl-train/skyrl_train/inference_servers/server_group.py @@ -135,6 +135,7 @@ def _create_actors(self) -> List[Any]: self._cli_args, self._start_port + server_idx, server_idx=server_idx, + start_bundle_idx=start_bundle_idx, dp_size=self._num_servers 
if self._enable_dp else -1, dp_master_address=dp_address, dp_rpc_port=dp_rpc_port, diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py index e6ab3e445..a2e53a18b 100644 --- a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py +++ b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py @@ -66,6 +66,7 @@ def __init__( vllm_cli_args: Namespace, start_port: int = 8000, server_idx: int = 0, + start_bundle_idx: int = 0, dp_size: int = -1, dp_master_address: Optional[str] = None, dp_rpc_port: Optional[int] = None, @@ -82,6 +83,7 @@ def __init__( Optional: uvicorn_log_level, ssl_*, disable_uvicorn_access_log, kv_transfer_config. start_port: Base port to start searching for free port server_idx: Index of this server in the group + start_bundle_idx: Starting bundle index in the placement group for this server's workers dp_size: Data parallel size (-1 to disable) dp_master_address: DP master address (for non-rank-0 servers) dp_rpc_port: DP RPC port (for non-rank-0 servers) @@ -97,6 +99,9 @@ def __init__( # Ensure SkyRL's custom worker extension is used for weight sync self._ensure_worker_extension() + # Ensure Ray executor is used (required for GPU inheritance in placement groups) + self._ensure_ray_executor() + # Update args with our assigned host/port self._cli_args.host = "0.0.0.0" self._cli_args.port = self._port @@ -126,10 +131,8 @@ def __init__( f"dp_master_address={dp_master_address}, dp_rpc_port={dp_rpc_port}" ) - # Compute bundle indices for this server's TP/PP workers - # Each server uses a contiguous slice of bundles in the placement group - start_bundle = server_idx * self._num_gpus_per_server - bundle_indices = list(range(start_bundle, start_bundle + self._num_gpus_per_server)) + # Set bundle indices for this server's TP/PP workers in the placement group + bundle_indices = list(range(start_bundle_idx, start_bundle_idx + self._num_gpus_per_server)) os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) logger.info(f"Server {server_idx}: using bundle indices {bundle_indices}") @@ -150,6 +153,17 @@ def _ensure_worker_extension(self) -> None: else: logger.info(f"Using provided worker extension: {self._cli_args.worker_extension_cls}") + def _ensure_ray_executor(self) -> None: + """ + Ensure Ray is used as the distributed executor backend. + + When running inside a Ray actor, we must use the Ray executor so that + workers are spawned and properly inherit GPU allocation from the + placement group. + """ + if not hasattr(self._cli_args, "distributed_executor_backend") or self._cli_args.distributed_executor_backend != "ray": + self._cli_args.distributed_executor_backend = "ray" + def _setup_nixl_side_channel(self, base_port: int) -> None: """ Setup NIXL side channel for PD disaggregation. 
@@ -248,7 +262,7 @@ async def _run_server(self) -> None: # Initialize the engine (this loads the model - takes time) engine_args = AsyncEngineArgs.from_cli_args(self._cli_args) - self._engine = AsyncLLMEngine.from_cli_args( + self._engine = AsyncLLMEngine.from_engine_args( engine_args=engine_args, usage_context=UsageContext.OPENAI_API_SERVER, ) diff --git a/skyrl-train/tests/gpu/gpu_ci/conftest.py b/skyrl-train/tests/gpu/gpu_ci/conftest.py index e6f35b11c..6e60ab6a2 100644 --- a/skyrl-train/tests/gpu/gpu_ci/conftest.py +++ b/skyrl-train/tests/gpu/gpu_ci/conftest.py @@ -1,8 +1,10 @@ +import os import pytest import ray from loguru import logger from functools import lru_cache from skyrl_train.utils.utils import peer_access_supported +from skyrl_train.env_vars import SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV @lru_cache(5) @@ -11,7 +13,7 @@ def log_once(msg): return None -@pytest.fixture +@pytest.fixture(scope="class") def ray_init_fixture(): if ray.is_initialized(): ray.shutdown() @@ -31,6 +33,9 @@ def ray_init_fixture(): # needed for megatron tests env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env_vars["NVTE_FUSED_ATTN"] = "0" + + if SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV: + env_vars["PYTHONPATH"] = os.environ.get("PYTHONPATH") logger.info(f"Initializing Ray with environment variables: {env_vars}") ray.init(runtime_env={"env_vars": env_vars}) diff --git a/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py b/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py new file mode 100644 index 000000000..bb87fce40 --- /dev/null +++ b/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py @@ -0,0 +1,452 @@ +""" +GPU CI tests for inference server infrastructure. + +Test Suite 1: ServerGroup + InferenceRouter + - 2 vLLM servers with TP=2 (4 GPUs total) + - Router with load balancing and control plane fan-out + - Tests: health, completions, get_server_info, session affinity + +Test Suite 2: Weight Update Flow + - 1 vLLM server with TP=2 + dummy weights + - Trainer emulation with real weights + - Tests: pause/resume, init_weight_transfer, weight sync effectiveness + +Run: + uv run pytest tests/gpu/gpu_ci/test_inference_server_group.py -v -s +""" + +import asyncio +import time + +import httpx +import pytest +import ray +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +import torch +from transformers import AutoModelForCausalLM +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from skyrl_train.inference_servers.router import InferenceRouter +from skyrl_train.inference_servers.server_group import ServerGroup +from skyrl_train.inference_servers.common import get_open_port, get_node_ip + +MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + +# ============================================================================= +# Utility: Skip tests if not enough GPUs +# ============================================================================= + +# Skip entire module if not enough GPUs +_gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 +if _gpu_count < 4: + pytest.skip(f"Need 4 GPUs for full test suite, found {_gpu_count}", allow_module_level=True) + + + +def make_vllm_cli_args( + model: str, + tp_size: int = 2, + load_format: str = "auto", +) -> FlexibleArgumentParser: + """Create CLI args for vLLM server using official parser.""" + parser = FlexibleArgumentParser(description="vLLM server") + parser = 
make_arg_parser(parser) + return parser.parse_args([ + "--model", model, + "--tensor-parallel-size", str(tp_size), + "--enforce-eager", + "--gpu-memory-utilization", "0.5", + "--max-model-len", "2048", + "--load-format", load_format, + ]) + + +def wait_for_url(url: str, timeout: float = 180.0) -> bool: + """Wait for a URL to become available.""" + start = time.time() + while time.time() - start < timeout: + try: + resp = httpx.get(f"{url}/health", timeout=5.0) + if resp.status_code == 200: + return True + except httpx.RequestError: + time.sleep(2.0) + return False + + +# ============================================================================= +# Test Suite 1: ServerGroup + Router (2 servers, TP=2 each, 4 GPUs) +# ============================================================================= + +@pytest.fixture(scope="class") +def server_group_and_router(ray_init_fixture): + """Create 2 vLLM servers (TP=2 each) + router.""" + cli_args = make_vllm_cli_args(MODEL, tp_size=2) + start_port = get_open_port() + + # Create server group with 2 servers + group = ServerGroup( + cli_args=cli_args, + num_servers=2, + start_port=start_port, + ) + server_infos = group.start() + server_urls = [info.url for info in server_infos] + + # Wait for servers + for url in server_urls: + assert wait_for_url(url), f"Server {url} failed to start" + + # Create router + router_port = get_open_port() + router = InferenceRouter(server_urls, host="0.0.0.0", port=router_port) + router_url = router.start() + assert wait_for_url(router_url), "Router failed to start" + + yield { + "group": group, + "server_urls": server_urls, + "router": router, + "router_url": router_url, + } + + router.shutdown() + group.shutdown() + del group + del router + + +class TestServerGroupAndRouter: + """Tests for ServerGroup + InferenceRouter with 2 TP=2 servers.""" + + def test_health_check(self, server_group_and_router): + """Health endpoint works through router.""" + router_url = server_group_and_router["router_url"] + resp = httpx.get(f"{router_url}/health", timeout=10.0) + assert resp.status_code == 200 + + def test_list_servers(self, server_group_and_router): + """/servers returns all backends.""" + router_url = server_group_and_router["router_url"] + resp = httpx.get(f"{router_url}/servers", timeout=10.0) + assert resp.status_code == 200 + assert len(resp.json()["servers"]) == 2 + + def test_get_server_info(self, server_group_and_router): + """/get_server_info returns mapping of server_url -> info for all servers.""" + router_url = server_group_and_router["router_url"] + server_urls = server_group_and_router["server_urls"] + + resp = httpx.get(f"{router_url}/get_server_info", timeout=10.0) + assert resp.status_code == 200 + info_map = resp.json() + print(f"Server info map: {info_map}") + + # Should have info for each server + assert len(info_map) == 2 + for url in server_urls: + assert url in info_map + server_info = info_map[url] + # Each server has TP=2, so per-server world_size=2 + assert server_info["world_size"] == 2 + + def test_completion_request(self, server_group_and_router): + """Completion requests work through router.""" + router_url = server_group_and_router["router_url"] + + payload = { + "model": MODEL, + "prompt": "What is 2 + 2? 
Answer:", + "max_tokens": 16, + "temperature": 0.0, + } + + resp = httpx.post(f"{router_url}/v1/completions", json=payload, timeout=60.0) + assert resp.status_code == 200 + data = resp.json() + assert "choices" in data + assert len(data["choices"]) > 0 + assert "text" in data["choices"][0] + print(f"Completion: {data['choices'][0]['text']}") + + @pytest.mark.asyncio + async def test_pause_resume(self, server_group_and_router): + """Pause/resume control plane routes work.""" + router_url = server_group_and_router["router_url"] + + async with httpx.AsyncClient() as client: + # Pause + resp = await client.post(f"{router_url}/pause", json={"wait_for_inflight_request": False}, timeout=30.0) + assert resp.status_code == 200 + + # Check is paused + resp = await client.get(f"{router_url}/is_paused", timeout=30.0) + assert resp.status_code == 200 + assert resp.json()["is_paused"] == True + + # Send a request while paused (should block) + async def send_request(): + r = await client.post(f"{router_url}/v1/completions", json={"model": MODEL, "prompt": "Test", "max_tokens": 4}, timeout=60.0) + assert r.status_code == 200 + return r.json() + + task = asyncio.create_task(send_request()) + await asyncio.sleep(1) + + # Task should not be done here (request blocked by pause) + assert not task.done() + + # Resume + resp = await client.post(f"{router_url}/resume", json={}, timeout=30.0) + assert resp.status_code == 200 + + # Verify that after resume, the request is completed + result = await task + assert result["choices"][0]["text"] is not None + + + +# ============================================================================= +# Test Suite 2: Weight Update Flow (1 server TP=2 + trainer emulation) +# ============================================================================= + +class Trainer: + """ + Simple trainer emulator that holds the real model weights. + + This is a simplified version of the trainer side for testing weight sync. + Non-colocated: runs on a separate GPU from the inference server. + """ + + def __init__(self, model_name: str, device: str = "cuda"): + self.device = torch.device(device) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + ).to(self.device) + self.pg = None + self.model_name = model_name + + def ready(self): + """Check if the trainer is ready.""" + return True + + def init_weight_sync(self, master_address: str, master_port: int, world_size: int, group_name: str): + """Initialize the weight sync process group as rank 0 (trainer).""" + from skyrl_train.distributed.utils import init_custom_process_group + from skyrl_train.utils import get_tcp_url + + self.pg = init_custom_process_group( + backend="nccl", + init_method=get_tcp_url(master_address, master_port), + world_size=world_size, + rank=0, # Trainer is always rank 0 + group_name=group_name, + ) + return True + + def get_weight_info(self) -> dict: + """ + Get weight metadata (names, dtypes, shapes) without doing NCCL. + + Returns: + dict with names, dtypes, shapes for the weight update request. + """ + names = [] + dtypes = [] + shapes = [] + + for name, param in self.model.named_parameters(): + names.append(name) + dtypes.append(str(param.dtype).split(".")[-1]) # e.g. "bfloat16" + shapes.append(list(param.shape)) + + return {"names": names, "dtypes": dtypes, "shapes": shapes} + + def broadcast_weights(self): + """ + Broadcast all model weights to inference workers via NCCL. + + This is a blocking operation - server must call receive concurrently. 
+ """ + for name, param in self.model.named_parameters(): + torch.distributed.broadcast(param.data, src=0, group=self.pg) + torch.cuda.synchronize() + + def teardown(self): + """Clean up the process group.""" + if self.pg is not None: + torch.distributed.destroy_process_group(self.pg) + self.pg = None + + +@pytest.fixture(scope="class") +def weight_update_env(ray_init_fixture): + """ + Create environment for weight update testing: + - Trainer with real weights on GPU 0 + - 1 vLLM server with TP=2 and DUMMY weights (uses GPU 1,2) + - Router to proxy requests + """ + # Create server with dummy weights + cli_args = make_vllm_cli_args(MODEL, tp_size=2, load_format="dummy") + start_port = get_open_port() + + pg = placement_group([{"CPU": 1, "GPU": 1} for _ in range(3)]) + ray.get(pg.ready()) + + trainer = ray.remote(Trainer).options( + num_gpus=1, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=0, + ), + ).remote(MODEL) + + ray.get(trainer.ready.remote()) + + group = ServerGroup( + cli_args=cli_args, + num_servers=1, + start_port=start_port, + placement_group=pg, + placement_group_bundle_offset=1, + ) + server_infos = group.start() + server_urls = [info.url for info in server_infos] + + for url in server_urls: + assert wait_for_url(url), f"Server {url} failed to start" + + # Create router + router_port = get_open_port() + router = InferenceRouter(server_urls, host="0.0.0.0", port=router_port) + router_url = router.start() + assert wait_for_url(router_url), "Router failed to start" + + yield { + "group": group, + "server_urls": server_urls, + "router": router, + "router_url": router_url, + "trainer": trainer, + } + + router.shutdown() + group.shutdown() + ray.get(trainer.teardown.remote()) + del router + del group + del trainer + del pg + +class TestWeightUpdateFlow: + """Tests for weight synchronization from trainer to inference server.""" + + @pytest.mark.asyncio + async def test_update_weights_flow(self, weight_update_env): + """ + Full E2E weight sync test via router: + 1. Query with dummy weights → gibberish + 2. Init weight transfer (both sides concurrently via router) + 3. Broadcast weights from trainer (concurrent with server receive) + 4. Finalize weight update + 5. 
Query again → correct output + """ + router_url = weight_update_env["router_url"] + trainer = weight_update_env["trainer"] + + async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as client: + # ===== Step 1: Verify dummy weights produce gibberish ===== + payload = { + "model": MODEL, + "prompt": "What is the capital of France?", + "max_tokens": 32, + "temperature": 0.0, + } + + resp = await client.post(f"{router_url}/v1/completions", json=payload) + assert resp.status_code == 200 + + text_before = resp.json()["choices"][0]["text"] + print(f"[Step 1] Dummy weights output: {text_before!r}") + + # Dummy weights should NOT produce coherent output about Paris + assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" + + # ===== Step 2: Init weight transfer (both sides concurrently) ===== + master_address = get_node_ip() + master_port = get_open_port() + + # Query all servers for world_size (TP * PP) via router + resp = await client.get(f"{router_url}/get_server_info") + assert resp.status_code == 200 + server_info_map = resp.json() + # Sum world_size across all servers + inference_world_size = sum(info["world_size"] for info in server_info_map.values()) + world_size = 1 + inference_world_size # 1 trainer + all inference workers + group_name = f"weight_sync_test_{master_port}" + + print(f"[Step 2] Init weight transfer: master={master_address}:{master_port}, world_size={world_size}") + + init_info = { + "master_addr": master_address, + "master_port": master_port, + "rank_offset": 1, + "world_size": world_size, + "group_name": group_name, + "backend": "nccl", + "model_dtype_str": "bfloat16", + "override_existing_receiver": True, + } + + # Both sides must init concurrently (NCCL blocks until all ranks join) + # Start trainer init (returns immediately, runs in Ray actor) + trainer_init_ref = trainer.init_weight_sync.remote(master_address, master_port, world_size, group_name) + + # Await server init (triggers NCCL join on server side) + server_resp = await client.post(f"{router_url}/init_weight_transfer", json=init_info) + assert server_resp.status_code == 200, f"Server init failed: {server_resp.text}" + + # Trainer should be done now (NCCL group formed) + ray.get(trainer_init_ref) + print("[Step 2] Both sides init complete") + + # ===== Step 3: Broadcast weights (concurrent send/receive) ===== + print("[Step 3] Broadcasting weights from trainer to server...") + + # Get weight metadata first (no NCCL yet) + weight_info = ray.get(trainer.get_weight_info.remote()) + print(f"[Step 3] Weight info: {len(weight_info['names'])} parameters") + + # Start trainer broadcast (returns immediately, runs in Ray actor) + trainer_broadcast_ref = trainer.broadcast_weights.remote() + + # Await server receive (triggers NCCL receive on server side) + server_resp = await client.post(f"{router_url}/update_weights", json=weight_info) + assert server_resp.status_code == 200, f"Update weights failed: {server_resp.text}" + + # Trainer should be done now (NCCL broadcast complete) + ray.get(trainer_broadcast_ref) + print("[Step 3] Weight sync complete") + + # ===== Step 4: Finalize weight update ===== + resp = await client.post(f"{router_url}/finalize_weight_update", json={}) + assert resp.status_code == 200 + print("[Step 4] Weight update finalized") + + # ===== Step 5: Query again - should produce correct output ===== + resp = await client.post(f"{router_url}/v1/completions", json=payload) + assert resp.status_code == 200 + + text_after = resp.json()["choices"][0]["text"] + print(f"[Step 
5] Real weights output: {text_after!r}") + + assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" + + print("[SUCCESS] Weight sync test passed!") + From 7c8fc0bcfabfa2946e3b050909ecc6fbbe296ee8 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 19:14:17 -0800 Subject: [PATCH 13/20] wip Signed-off-by: Kourosh Hakhamaneshi --- .../gpu/gpu_ci/test_inference_server_group.py | 272 +--------------- .../tests/gpu/gpu_ci/test_weight_sync.py | 301 ++++++++++++++++++ 2 files changed, 305 insertions(+), 268 deletions(-) create mode 100644 skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py diff --git a/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py b/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py index bb87fce40..e4a7bd186 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py @@ -1,15 +1,10 @@ """ -GPU CI tests for inference server infrastructure. +GPU CI tests for ServerGroup + InferenceRouter. -Test Suite 1: ServerGroup + InferenceRouter +Tests: - 2 vLLM servers with TP=2 (4 GPUs total) - Router with load balancing and control plane fan-out - - Tests: health, completions, get_server_info, session affinity - -Test Suite 2: Weight Update Flow - - 1 vLLM server with TP=2 + dummy weights - - Trainer emulation with real weights - - Tests: pause/resume, init_weight_transfer, weight sync effectiveness + - Health, completions, get_server_info, session affinity, pause/resume Run: uv run pytest tests/gpu/gpu_ci/test_inference_server_group.py -v -s @@ -20,33 +15,23 @@ import httpx import pytest -import ray -from ray.util.placement_group import placement_group -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - import torch -from transformers import AutoModelForCausalLM from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.utils.argparse_utils import FlexibleArgumentParser from skyrl_train.inference_servers.router import InferenceRouter from skyrl_train.inference_servers.server_group import ServerGroup -from skyrl_train.inference_servers.common import get_open_port, get_node_ip +from skyrl_train.inference_servers.common import get_open_port MODEL = "Qwen/Qwen2.5-0.5B-Instruct" -# ============================================================================= -# Utility: Skip tests if not enough GPUs -# ============================================================================= - # Skip entire module if not enough GPUs _gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 if _gpu_count < 4: pytest.skip(f"Need 4 GPUs for full test suite, found {_gpu_count}", allow_module_level=True) - def make_vllm_cli_args( model: str, tp_size: int = 2, @@ -78,10 +63,6 @@ def wait_for_url(url: str, timeout: float = 180.0) -> bool: return False -# ============================================================================= -# Test Suite 1: ServerGroup + Router (2 servers, TP=2 each, 4 GPUs) -# ============================================================================= - @pytest.fixture(scope="class") def server_group_and_router(ray_init_fixture): """Create 2 vLLM servers (TP=2 each) + router.""" @@ -116,8 +97,6 @@ def server_group_and_router(ray_init_fixture): router.shutdown() group.shutdown() - del group - del router class TestServerGroupAndRouter: @@ -207,246 +186,3 @@ async def send_request(): # Verify that after resume, the request is completed result = await task assert 
result["choices"][0]["text"] is not None - - - -# ============================================================================= -# Test Suite 2: Weight Update Flow (1 server TP=2 + trainer emulation) -# ============================================================================= - -class Trainer: - """ - Simple trainer emulator that holds the real model weights. - - This is a simplified version of the trainer side for testing weight sync. - Non-colocated: runs on a separate GPU from the inference server. - """ - - def __init__(self, model_name: str, device: str = "cuda"): - self.device = torch.device(device) - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch.bfloat16, - ).to(self.device) - self.pg = None - self.model_name = model_name - - def ready(self): - """Check if the trainer is ready.""" - return True - - def init_weight_sync(self, master_address: str, master_port: int, world_size: int, group_name: str): - """Initialize the weight sync process group as rank 0 (trainer).""" - from skyrl_train.distributed.utils import init_custom_process_group - from skyrl_train.utils import get_tcp_url - - self.pg = init_custom_process_group( - backend="nccl", - init_method=get_tcp_url(master_address, master_port), - world_size=world_size, - rank=0, # Trainer is always rank 0 - group_name=group_name, - ) - return True - - def get_weight_info(self) -> dict: - """ - Get weight metadata (names, dtypes, shapes) without doing NCCL. - - Returns: - dict with names, dtypes, shapes for the weight update request. - """ - names = [] - dtypes = [] - shapes = [] - - for name, param in self.model.named_parameters(): - names.append(name) - dtypes.append(str(param.dtype).split(".")[-1]) # e.g. "bfloat16" - shapes.append(list(param.shape)) - - return {"names": names, "dtypes": dtypes, "shapes": shapes} - - def broadcast_weights(self): - """ - Broadcast all model weights to inference workers via NCCL. - - This is a blocking operation - server must call receive concurrently. 
- """ - for name, param in self.model.named_parameters(): - torch.distributed.broadcast(param.data, src=0, group=self.pg) - torch.cuda.synchronize() - - def teardown(self): - """Clean up the process group.""" - if self.pg is not None: - torch.distributed.destroy_process_group(self.pg) - self.pg = None - - -@pytest.fixture(scope="class") -def weight_update_env(ray_init_fixture): - """ - Create environment for weight update testing: - - Trainer with real weights on GPU 0 - - 1 vLLM server with TP=2 and DUMMY weights (uses GPU 1,2) - - Router to proxy requests - """ - # Create server with dummy weights - cli_args = make_vllm_cli_args(MODEL, tp_size=2, load_format="dummy") - start_port = get_open_port() - - pg = placement_group([{"CPU": 1, "GPU": 1} for _ in range(3)]) - ray.get(pg.ready()) - - trainer = ray.remote(Trainer).options( - num_gpus=1, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=pg, - placement_group_bundle_index=0, - ), - ).remote(MODEL) - - ray.get(trainer.ready.remote()) - - group = ServerGroup( - cli_args=cli_args, - num_servers=1, - start_port=start_port, - placement_group=pg, - placement_group_bundle_offset=1, - ) - server_infos = group.start() - server_urls = [info.url for info in server_infos] - - for url in server_urls: - assert wait_for_url(url), f"Server {url} failed to start" - - # Create router - router_port = get_open_port() - router = InferenceRouter(server_urls, host="0.0.0.0", port=router_port) - router_url = router.start() - assert wait_for_url(router_url), "Router failed to start" - - yield { - "group": group, - "server_urls": server_urls, - "router": router, - "router_url": router_url, - "trainer": trainer, - } - - router.shutdown() - group.shutdown() - ray.get(trainer.teardown.remote()) - del router - del group - del trainer - del pg - -class TestWeightUpdateFlow: - """Tests for weight synchronization from trainer to inference server.""" - - @pytest.mark.asyncio - async def test_update_weights_flow(self, weight_update_env): - """ - Full E2E weight sync test via router: - 1. Query with dummy weights → gibberish - 2. Init weight transfer (both sides concurrently via router) - 3. Broadcast weights from trainer (concurrent with server receive) - 4. Finalize weight update - 5. 
Query again → correct output - """ - router_url = weight_update_env["router_url"] - trainer = weight_update_env["trainer"] - - async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as client: - # ===== Step 1: Verify dummy weights produce gibberish ===== - payload = { - "model": MODEL, - "prompt": "What is the capital of France?", - "max_tokens": 32, - "temperature": 0.0, - } - - resp = await client.post(f"{router_url}/v1/completions", json=payload) - assert resp.status_code == 200 - - text_before = resp.json()["choices"][0]["text"] - print(f"[Step 1] Dummy weights output: {text_before!r}") - - # Dummy weights should NOT produce coherent output about Paris - assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" - - # ===== Step 2: Init weight transfer (both sides concurrently) ===== - master_address = get_node_ip() - master_port = get_open_port() - - # Query all servers for world_size (TP * PP) via router - resp = await client.get(f"{router_url}/get_server_info") - assert resp.status_code == 200 - server_info_map = resp.json() - # Sum world_size across all servers - inference_world_size = sum(info["world_size"] for info in server_info_map.values()) - world_size = 1 + inference_world_size # 1 trainer + all inference workers - group_name = f"weight_sync_test_{master_port}" - - print(f"[Step 2] Init weight transfer: master={master_address}:{master_port}, world_size={world_size}") - - init_info = { - "master_addr": master_address, - "master_port": master_port, - "rank_offset": 1, - "world_size": world_size, - "group_name": group_name, - "backend": "nccl", - "model_dtype_str": "bfloat16", - "override_existing_receiver": True, - } - - # Both sides must init concurrently (NCCL blocks until all ranks join) - # Start trainer init (returns immediately, runs in Ray actor) - trainer_init_ref = trainer.init_weight_sync.remote(master_address, master_port, world_size, group_name) - - # Await server init (triggers NCCL join on server side) - server_resp = await client.post(f"{router_url}/init_weight_transfer", json=init_info) - assert server_resp.status_code == 200, f"Server init failed: {server_resp.text}" - - # Trainer should be done now (NCCL group formed) - ray.get(trainer_init_ref) - print("[Step 2] Both sides init complete") - - # ===== Step 3: Broadcast weights (concurrent send/receive) ===== - print("[Step 3] Broadcasting weights from trainer to server...") - - # Get weight metadata first (no NCCL yet) - weight_info = ray.get(trainer.get_weight_info.remote()) - print(f"[Step 3] Weight info: {len(weight_info['names'])} parameters") - - # Start trainer broadcast (returns immediately, runs in Ray actor) - trainer_broadcast_ref = trainer.broadcast_weights.remote() - - # Await server receive (triggers NCCL receive on server side) - server_resp = await client.post(f"{router_url}/update_weights", json=weight_info) - assert server_resp.status_code == 200, f"Update weights failed: {server_resp.text}" - - # Trainer should be done now (NCCL broadcast complete) - ray.get(trainer_broadcast_ref) - print("[Step 3] Weight sync complete") - - # ===== Step 4: Finalize weight update ===== - resp = await client.post(f"{router_url}/finalize_weight_update", json={}) - assert resp.status_code == 200 - print("[Step 4] Weight update finalized") - - # ===== Step 5: Query again - should produce correct output ===== - resp = await client.post(f"{router_url}/v1/completions", json=payload) - assert resp.status_code == 200 - - text_after = resp.json()["choices"][0]["text"] - print(f"[Step 
5] Real weights output: {text_after!r}") - - assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" - - print("[SUCCESS] Weight sync test passed!") - diff --git a/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py b/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py new file mode 100644 index 000000000..0c5774c93 --- /dev/null +++ b/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py @@ -0,0 +1,301 @@ +""" +GPU CI tests for weight synchronization from trainer to inference server. + +Tests: + - 1 vLLM server with TP=2 + dummy weights + - Trainer emulation with real weights on separate GPU + - Weight sync via NCCL broadcast through router + +Run: + uv run pytest tests/gpu/gpu_ci/test_weight_sync.py -v -s +""" + +import time + +import httpx +import pytest +import ray +import torch +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from transformers import AutoModelForCausalLM +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from skyrl_train.inference_servers.router import InferenceRouter +from skyrl_train.inference_servers.server_group import ServerGroup +from skyrl_train.inference_servers.common import get_open_port, get_node_ip + +MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + +# Skip entire module if not enough GPUs (need 3: 1 trainer + 2 for TP=2 server) +_gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 +if _gpu_count < 3: + pytest.skip(f"Need 3 GPUs for weight sync test, found {_gpu_count}", allow_module_level=True) + + +def make_vllm_cli_args( + model: str, + tp_size: int = 2, + load_format: str = "auto", +) -> FlexibleArgumentParser: + """Create CLI args for vLLM server using official parser.""" + parser = FlexibleArgumentParser(description="vLLM server") + parser = make_arg_parser(parser) + return parser.parse_args([ + "--model", model, + "--tensor-parallel-size", str(tp_size), + "--enforce-eager", + "--gpu-memory-utilization", "0.5", + "--max-model-len", "2048", + "--load-format", load_format, + ]) + + +def wait_for_url(url: str, timeout: float = 180.0) -> bool: + """Wait for a URL to become available.""" + start = time.time() + while time.time() - start < timeout: + try: + resp = httpx.get(f"{url}/health", timeout=5.0) + if resp.status_code == 200: + return True + except httpx.RequestError: + time.sleep(2.0) + return False + + +@ray.remote +class Trainer: + """ + Simple trainer emulator that holds the real model weights. + + This is a simplified version of the trainer side for testing weight sync. + Non-colocated: runs on a separate GPU from the inference server. 
+ """ + + def __init__(self, model_name: str, device: str = "cuda"): + self.device = torch.device(device) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + ).to(self.device) + self.pg = None + self.model_name = model_name + + def ready(self): + """Check if the trainer is ready.""" + return True + + def init_weight_sync(self, master_address: str, master_port: int, world_size: int, group_name: str): + """Initialize the weight sync process group as rank 0 (trainer).""" + from skyrl_train.distributed.utils import init_custom_process_group + from skyrl_train.utils import get_tcp_url + + self.pg = init_custom_process_group( + backend="nccl", + init_method=get_tcp_url(master_address, master_port), + world_size=world_size, + rank=0, # Trainer is always rank 0 + group_name=group_name, + ) + return True + + def get_weight_info(self) -> dict: + """ + Get weight metadata (names, dtypes, shapes) without doing NCCL. + + Returns: + dict with names, dtypes, shapes for the weight update request. + """ + names = [] + dtypes = [] + shapes = [] + + for name, param in self.model.named_parameters(): + names.append(name) + dtypes.append(str(param.dtype).split(".")[-1]) # e.g. "bfloat16" + shapes.append(list(param.shape)) + + return {"names": names, "dtypes": dtypes, "shapes": shapes} + + def broadcast_weights(self): + """ + Broadcast all model weights to inference workers via NCCL. + + This is a blocking operation - server must call receive concurrently. + """ + for name, param in self.model.named_parameters(): + torch.distributed.broadcast(param.data, src=0, group=self.pg) + torch.cuda.synchronize() + + def teardown(self): + """Clean up the process group.""" + if self.pg is not None: + torch.distributed.destroy_process_group(self.pg) + self.pg = None + + +@pytest.fixture(scope="class") +def weight_update_env(ray_init_fixture): + """ + Create environment for weight update testing: + - Trainer with real weights on GPU 0 + - 1 vLLM server with TP=2 and DUMMY weights (uses GPU 1,2) + - Router to proxy requests + """ + # Create server with dummy weights + cli_args = make_vllm_cli_args(MODEL, tp_size=2, load_format="dummy") + start_port = get_open_port() + + pg = placement_group([{"CPU": 1, "GPU": 1} for _ in range(3)]) + ray.get(pg.ready()) + + trainer = Trainer.options( + num_gpus=1, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=0, + ), + ).remote(MODEL) + + ray.get(trainer.ready.remote()) + + group = ServerGroup( + cli_args=cli_args, + num_servers=1, + start_port=start_port, + placement_group=pg, + placement_group_bundle_offset=1, + ) + server_infos = group.start() + server_urls = [info.url for info in server_infos] + + for url in server_urls: + assert wait_for_url(url), f"Server {url} failed to start" + + # Create router + router_port = get_open_port() + router = InferenceRouter(server_urls, host="0.0.0.0", port=router_port) + router_url = router.start() + assert wait_for_url(router_url), "Router failed to start" + + yield { + "group": group, + "server_urls": server_urls, + "router": router, + "router_url": router_url, + "trainer": trainer, + } + + router.shutdown() + group.shutdown() + ray.get(trainer.teardown.remote()) + + +class TestWeightUpdateFlow: + """Tests for weight synchronization from trainer to inference server.""" + + @pytest.mark.asyncio + async def test_update_weights_flow(self, weight_update_env): + """ + Full E2E weight sync test via router: + 1. 
Query with dummy weights → gibberish + 2. Init weight transfer (both sides concurrently via router) + 3. Broadcast weights from trainer (concurrent with server receive) + 4. Finalize weight update + 5. Query again → correct output + """ + router_url = weight_update_env["router_url"] + trainer = weight_update_env["trainer"] + + async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as client: + # ===== Step 1: Verify dummy weights produce gibberish ===== + payload = { + "model": MODEL, + "prompt": "What is the capital of France?", + "max_tokens": 32, + "temperature": 0.0, + } + + resp = await client.post(f"{router_url}/v1/completions", json=payload) + assert resp.status_code == 200 + + text_before = resp.json()["choices"][0]["text"] + print(f"[Step 1] Dummy weights output: {text_before!r}") + + # Dummy weights should NOT produce coherent output about Paris + assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" + + # ===== Step 2: Init weight transfer (both sides concurrently) ===== + master_address = get_node_ip() + master_port = get_open_port() + + # Query all servers for world_size (TP * PP) via router + resp = await client.get(f"{router_url}/get_server_info") + assert resp.status_code == 200 + server_info_map = resp.json() + # Sum world_size across all servers + inference_world_size = sum(info["world_size"] for info in server_info_map.values()) + world_size = 1 + inference_world_size # 1 trainer + all inference workers + group_name = f"weight_sync_test_{master_port}" + + print(f"[Step 2] Init weight transfer: master={master_address}:{master_port}, world_size={world_size}") + + init_info = { + "master_addr": master_address, + "master_port": master_port, + "rank_offset": 1, + "world_size": world_size, + "group_name": group_name, + "backend": "nccl", + "model_dtype_str": "bfloat16", + "override_existing_receiver": True, + } + + # Both sides must init concurrently (NCCL blocks until all ranks join) + # Start trainer init (returns immediately, runs in Ray actor) + trainer_init_ref = trainer.init_weight_sync.remote(master_address, master_port, world_size, group_name) + + # Await server init (triggers NCCL join on server side) + server_resp = await client.post(f"{router_url}/init_weight_transfer", json=init_info) + assert server_resp.status_code == 200, f"Server init failed: {server_resp.text}" + + # Trainer should be done now (NCCL group formed) + ray.get(trainer_init_ref) + print("[Step 2] Both sides init complete") + + # ===== Step 3: Broadcast weights (concurrent send/receive) ===== + print("[Step 3] Broadcasting weights from trainer to server...") + + # Get weight metadata first (no NCCL yet) + weight_info = ray.get(trainer.get_weight_info.remote()) + print(f"[Step 3] Weight info: {len(weight_info['names'])} parameters") + + # Start trainer broadcast (returns immediately, runs in Ray actor) + trainer_broadcast_ref = trainer.broadcast_weights.remote() + + # Await server receive (triggers NCCL receive on server side) + server_resp = await client.post(f"{router_url}/update_weights", json=weight_info) + assert server_resp.status_code == 200, f"Update weights failed: {server_resp.text}" + + # Trainer should be done now (NCCL broadcast complete) + ray.get(trainer_broadcast_ref) + print("[Step 3] Weight sync complete") + + # ===== Step 4: Finalize weight update ===== + resp = await client.post(f"{router_url}/finalize_weight_update", json={}) + assert resp.status_code == 200 + print("[Step 4] Weight update finalized") + + # ===== Step 5: Query again - should 
produce correct output ===== + resp = await client.post(f"{router_url}/v1/completions", json=payload) + assert resp.status_code == 200 + + text_after = resp.json()["choices"][0]["text"] + print(f"[Step 5] Real weights output: {text_after!r}") + + assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" + + print("[SUCCESS] Weight sync test passed!") From 68dc4ed42e891cc92106e6b81cb0d8e53d768524 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 19:16:08 -0800 Subject: [PATCH 14/20] lint Signed-off-by: Kourosh Hakhamaneshi --- .../skyrl_train/inference_servers/common.py | 7 +- .../inference_servers/protocols.py | 36 +++--- .../skyrl_train/inference_servers/router.py | 11 +- .../inference_servers/server_group.py | 10 +- .../inference_servers/server_pool.py | 4 +- .../inference_servers/vllm_server_actor.py | 72 +++++++---- .../inference_servers/vllm_worker.py | 22 ++-- .../gpu/gpu_ci/test_inference_server_group.py | 57 +++++---- .../tests/gpu/gpu_ci/test_weight_sync.py | 113 +++++++++++------- 9 files changed, 205 insertions(+), 127 deletions(-) diff --git a/skyrl-train/skyrl_train/inference_servers/common.py b/skyrl-train/skyrl_train/inference_servers/common.py index cc1205911..6599ecf0b 100644 --- a/skyrl-train/skyrl_train/inference_servers/common.py +++ b/skyrl-train/skyrl_train/inference_servers/common.py @@ -34,7 +34,6 @@ def get_node_ip() -> str: return ray.util.get_node_ip_address() - def get_open_port(start_port: int | None = None) -> int: """ Get an available port. @@ -58,7 +57,9 @@ def get_open_port(start_port: int | None = None) -> int: except OSError: port += 1 if port > 65535: - raise RuntimeError(f"No available port found starting from {start_port}") + raise RuntimeError( + f"No available port found starting from {start_port}" + ) # Let OS assign a free port # Try IPv4 first @@ -72,4 +73,4 @@ def get_open_port(start_port: int | None = None) -> int: # Try IPv6 with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: s.bind(("", 0)) - return s.getsockname()[1] \ No newline at end of file + return s.getsockname()[1] diff --git a/skyrl-train/skyrl_train/inference_servers/protocols.py b/skyrl-train/skyrl_train/inference_servers/protocols.py index ced4ff481..828d83609 100644 --- a/skyrl-train/skyrl_train/inference_servers/protocols.py +++ b/skyrl-train/skyrl_train/inference_servers/protocols.py @@ -14,38 +14,38 @@ class ServerActorProtocol(Protocol): """ Protocol defining the interface for server actor classes. - + Any server actor class (vLLM, SGLang, etc.) must implement this interface to be usable with ServerGroup. - + Example: class MyServerActor(ServerActorProtocol): @staticmethod def compute_num_gpus_per_server(cli_args: Namespace) -> int: return cli_args.tensor_parallel_size - + def __init__(self, cli_args, start_port, server_idx, ...): ... - + async def start(self) -> ServerInfo: ... """ - + @staticmethod def compute_num_gpus_per_server(cli_args: Namespace) -> int: """ Compute the number of GPUs needed per server instance. - + This is called before actor creation to determine placement group size. - + Args: cli_args: Engine-specific CLI arguments. - + Returns: Number of GPUs required per server (e.g., TP * PP for vLLM). """ ... - + def __init__( self, cli_args: Namespace, @@ -60,7 +60,7 @@ def __init__( ) -> None: """ Initialize the server actor. - + Args: cli_args: Engine-specific CLI arguments. start_port: Base port to search for available port. 
@@ -73,33 +73,33 @@ def __init__( nixl_side_channel_base: Base port for NIXL side channels. """ ... - + def get_server_info(self) -> ServerInfo: """Get the server's IP and port info.""" ... - + def get_dp_info(self) -> Tuple[str, int]: """ Get the DP master address and RPC port. - + Only called on server_idx=0 when DP is enabled. - + Returns: Tuple of (master_address, rpc_port). """ ... - + async def start(self) -> ServerInfo: """ Start the server. - + This should block until the server is healthy and ready to serve requests. - + Returns: ServerInfo with the server's IP and port. """ ... - + async def shutdown(self) -> None: """Gracefully shutdown the server.""" ... diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py index db90ed5f4..a4074febb 100644 --- a/skyrl-train/skyrl_train/inference_servers/router.py +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -12,7 +12,7 @@ import uvicorn from fastapi import FastAPI, Request, Response -from skyrl_train.inference_servers.common import ServerInfo, get_node_ip +from skyrl_train.inference_servers.common import get_node_ip logger = logging.getLogger(__name__) @@ -157,6 +157,7 @@ def _forward_headers(self, request: Request) -> dict: async def _fan_out_get(self, path: str) -> dict: """Fan out a GET request to all servers, return mapping of server_url -> response.""" + async def call_server(server_url: str): try: resp = await self._client.get(f"{server_url}{path}", timeout=30.0) @@ -218,9 +219,7 @@ async def call_server(server_url: str): "error": str(e), } - results = await asyncio.gather( - *[call_server(url) for url in self._server_urls] - ) + results = await asyncio.gather(*[call_server(url) for url in self._server_urls]) # Build mapping from server_url to response response_map = {url: resp for url, resp in results} @@ -271,8 +270,8 @@ def run_server(): ip = get_node_ip() router_url = f"http://{ip}:{self._port}" logger.info(f"Router started at {router_url}") - logger.info(f" GET /servers - list servers") - logger.info(f" GET /get_server_info - get parallelism info") + logger.info(" GET /servers - list servers") + logger.info(" GET /get_server_info - get parallelism info") return router_url async def _run_server(self) -> None: diff --git a/skyrl-train/skyrl_train/inference_servers/server_group.py b/skyrl-train/skyrl_train/inference_servers/server_group.py index cd25bc6af..a4a857504 100644 --- a/skyrl-train/skyrl_train/inference_servers/server_group.py +++ b/skyrl-train/skyrl_train/inference_servers/server_group.py @@ -54,12 +54,12 @@ def __init__( placement_group: External placement group for colocation mode. If None, creates internal placement group. placement_group_bundle_offset: Offset for bundle indices when using - external placement group (e.g., if training uses first N + external placement group (e.g., if training uses first N bundles). enable_dp: Enable data parallelism across servers. enable_pd: Enable prefill-decode disaggregation. - nixl_side_channel_base: Base port for NIXL side channels. Each - server will be assigned a port of nixl_side_channel_base + + nixl_side_channel_base: Base port for NIXL side channels. Each + server will be assigned a port of nixl_side_channel_base + server_idx. server_actor_cls: Server actor class implementing ServerActorProtocol. Defaults to VLLMServerActor. 
@@ -77,7 +77,9 @@ def __init__( self._internal_pg: Optional[PlacementGroup] = None # Query the actor class for GPU requirements - self._num_gpus_per_server = server_actor_cls.compute_num_gpus_per_server(cli_args) + self._num_gpus_per_server = server_actor_cls.compute_num_gpus_per_server( + cli_args + ) logger.info( f"ServerGroup: actor_cls={server_actor_cls.__name__}, " diff --git a/skyrl-train/skyrl_train/inference_servers/server_pool.py b/skyrl-train/skyrl_train/inference_servers/server_pool.py index 1a6a467b2..620d5db04 100644 --- a/skyrl-train/skyrl_train/inference_servers/server_pool.py +++ b/skyrl-train/skyrl_train/inference_servers/server_pool.py @@ -11,8 +11,8 @@ class ServerActorPool: """Generic pool that manages a list of server actors. - - This layer provides a generic pool interface which can be extended to + + This layer provides a generic pool interface which can be extended to support fault-tolerance, monitoring, etc. for now it's just a simple wrapper around a list of actor handles. Actors must implement: diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py index a2e53a18b..ce4ec5d13 100644 --- a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py +++ b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py @@ -17,7 +17,11 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.openai.api_server import build_app, create_server_socket, init_app_state +from vllm.entrypoints.openai.api_server import ( + build_app, + create_server_socket, + init_app_state, +) from vllm.usage.usage_lib import UsageContext import vllm.envs as envs from vllm.utils.system_utils import set_ulimit @@ -27,7 +31,8 @@ from skyrl_train.inference_servers.protocols import ServerActorProtocol from skyrl_train.inference_servers.vllm_worker import VLLM_WORKER_EXTENSION_CLS from skyrl_train.env_vars import ( - SKYRL_VLLM_DP_PORT_OFFSET, SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S, + SKYRL_VLLM_DP_PORT_OFFSET, + SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S, ) logger = logging.getLogger(__name__) @@ -36,7 +41,7 @@ class VLLMServerActor(ServerActorProtocol): """ Ray actor that runs a vLLM OpenAI-compatible API server. - + Implements ServerActorProtocol for use with ServerGroup. The server runs in the actor and exposes an HTTP endpoint that can be @@ -44,7 +49,7 @@ class VLLMServerActor(ServerActorProtocol): Custom endpoints added for SkyRL: - /get_server_info: Return parallelism info - + - (vLLM RFC: https://github.com/vllm-project/vllm/issues/31848) - /init_weight_transfer: Initialize weight sync process group - /update_weights: Update model weights via NCCL broadcast @@ -54,9 +59,9 @@ class VLLMServerActor(ServerActorProtocol): @staticmethod def compute_num_gpus_per_server(vllm_cli_args: Namespace) -> int: """Compute the number of GPUs needed per server based on TP * PP. - - This logic might need adjustment if we want to support other - parallelism schemes. If we get to this point, we should add a + + This logic might need adjustment if we want to support other + parallelism schemes. If we get to this point, we should add a vllm-specific utility for it and keep the logic inside the engine. 
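# A hedged sketch of hitting the /get_server_info endpoint listed above through
# the router: the router fans the call out and returns a mapping of backend URL
# to parallelism info, which the tests later in this series reduce over
# "world_size". router_url is a placeholder for a running deployment.
import httpx

router_url = "http://127.0.0.1:8080"
info_map = httpx.get(f"{router_url}/get_server_info", timeout=10.0).json()
inference_world_size = sum(info["world_size"] for info in info_map.values())
print(f"{len(info_map)} servers, total inference world size {inference_world_size}")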
""" return vllm_cli_args.tensor_parallel_size * vllm_cli_args.pipeline_parallel_size @@ -95,7 +100,7 @@ def __init__( self._port = get_open_port(start_port) self._server_idx = server_idx self._num_gpus_per_server = self.compute_num_gpus_per_server(vllm_cli_args) - + # Ensure SkyRL's custom worker extension is used for weight sync self._ensure_worker_extension() @@ -116,7 +121,7 @@ def __init__( self._cli_args.data_parallel_rank = server_idx # DP0 will be the master sharing its ip and port with others. - # So if we are not DP0, we need to pass master_ip and port from + # So if we are not DP0, we need to pass master_ip and port from # outside. otherwise, we can use the local ip and port. if server_idx == 0: dp_master_address, dp_rpc_port = self.get_dp_info() @@ -132,7 +137,9 @@ def __init__( ) # Set bundle indices for this server's TP/PP workers in the placement group - bundle_indices = list(range(start_bundle_idx, start_bundle_idx + self._num_gpus_per_server)) + bundle_indices = list( + range(start_bundle_idx, start_bundle_idx + self._num_gpus_per_server) + ) os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) logger.info(f"Server {server_idx}: using bundle indices {bundle_indices}") @@ -143,25 +150,33 @@ def __init__( def _ensure_worker_extension(self) -> None: """ Ensure the SkyRL worker extension is configured. - + The worker extension (WorkerWrap) provides the RPC methods needed for weight synchronization (init_weight_update_communicator, load_weights). """ - if not hasattr(self._cli_args, "worker_extension_cls") or not self._cli_args.worker_extension_cls: + if ( + not hasattr(self._cli_args, "worker_extension_cls") + or not self._cli_args.worker_extension_cls + ): self._cli_args.worker_extension_cls = VLLM_WORKER_EXTENSION_CLS logger.info(f"Using default worker extension: {VLLM_WORKER_EXTENSION_CLS}") else: - logger.info(f"Using provided worker extension: {self._cli_args.worker_extension_cls}") + logger.info( + f"Using provided worker extension: {self._cli_args.worker_extension_cls}" + ) def _ensure_ray_executor(self) -> None: """ Ensure Ray is used as the distributed executor backend. - + When running inside a Ray actor, we must use the Ray executor so that - workers are spawned and properly inherit GPU allocation from the + workers are spawned and properly inherit GPU allocation from the placement group. 
""" - if not hasattr(self._cli_args, "distributed_executor_backend") or self._cli_args.distributed_executor_backend != "ray": + if ( + not hasattr(self._cli_args, "distributed_executor_backend") + or self._cli_args.distributed_executor_backend != "ray" + ): self._cli_args.distributed_executor_backend = "ray" def _setup_nixl_side_channel(self, base_port: int) -> None: @@ -178,7 +193,10 @@ def _setup_nixl_side_channel(self, base_port: int) -> None: engine_id = f"server-{self._server_idx}-{self._ip}-{side_channel_port}" - if hasattr(self._cli_args, "kv_transfer_config") and self._cli_args.kv_transfer_config: + if ( + hasattr(self._cli_args, "kv_transfer_config") + and self._cli_args.kv_transfer_config + ): try: kv_config = json.loads(self._cli_args.kv_transfer_config) except (json.JSONDecodeError, TypeError) as e: @@ -227,7 +245,9 @@ async def start(self) -> ServerInfo: return self.get_server_info() - async def _wait_until_healthy(self, timeout: float = SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S) -> None: + async def _wait_until_healthy( + self, timeout: float = SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S + ) -> None: """Poll the /health endpoint until it responds OK.""" url = f"http://{self._ip}:{self._port}/health" start_time = time.time() @@ -250,7 +270,9 @@ async def _wait_until_healthy(self, timeout: float = SKYRL_WAIT_UNTIL_INFERENCE_ pass if time.time() - start_time > timeout: - raise TimeoutError(f"Server failed to become healthy within {timeout}s") + raise TimeoutError( + f"Server failed to become healthy within {timeout}s" + ) await asyncio.sleep(1.0) @@ -266,7 +288,9 @@ async def _run_server(self) -> None: engine_args=engine_args, usage_context=UsageContext.OPENAI_API_SERVER, ) - logger.info(f"Engine initialized on {self._ip}:{self._port}, adding custom endpoints...") + logger.info( + f"Engine initialized on {self._ip}:{self._port}, adding custom endpoints..." + ) # Add custom SkyRL endpoints self._add_custom_endpoints(app) @@ -307,9 +331,9 @@ async def _init_weight_transfer(request: Request): data = await request.json() init_info = BroadcastInitInfo(**data).for_engine( - engine_index=self._server_idx, - tp_size=self._cli_args.tensor_parallel_size, - pp_size=self._cli_args.pipeline_parallel_size + engine_index=self._server_idx, + tp_size=self._cli_args.tensor_parallel_size, + pp_size=self._cli_args.pipeline_parallel_size, ) pickled_init_info = pickle.dumps(init_info) @@ -340,7 +364,7 @@ async def _finalize_weight_update(request: Request): Finalize weight update - post-processing hook. Currently a no-op, reserved for future use e.g. Quantization - See https://github.com/vllm-project/vllm/issues/31848 for more + See https://github.com/vllm-project/vllm/issues/31848 for more details. """ # No-op for now - placeholder for future post-processing diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_worker.py b/skyrl-train/skyrl_train/inference_servers/vllm_worker.py index 00edf43f0..98df697ce 100644 --- a/skyrl-train/skyrl_train/inference_servers/vllm_worker.py +++ b/skyrl-train/skyrl_train/inference_servers/vllm_worker.py @@ -2,7 +2,7 @@ vLLM Worker Extension for SkyRL weight synchronization. This module provides WorkerWrap, a vLLM worker extension class that enables -efficient NCCL-based and CUDA IPC-based weight updates from the training +efficient NCCL-based and CUDA IPC-based weight updates from the training process to inference workers. TODO: This will be removed once vLLM natively supports weight sync APIs. 
@@ -10,7 +10,7 @@ Usage: Pass as --worker-extension-cls to vLLM: - + vllm serve ... --worker-extension-cls skyrl_train.inference_servers.vllm_worker.WorkerWrap """ @@ -26,17 +26,17 @@ class WorkerWrap: """ vLLM worker extension for SkyRL weight synchronization. - + This class is injected into vLLM workers via --worker-extension-cls and provides methods that can be called via engine.collective_rpc() to coordinate weight updates across all TP/PP workers. - + Methods: init_weight_update_communicator: Initialize the weight receiver load_weights: Receive and load weights from trainer teardown_weight_receiver: Clean up weight receiver resources """ - + def test_rpc(self, *args, **kwargs): """Test RPC call to worker.""" return args, kwargs @@ -50,10 +50,14 @@ def init_weight_update_communicator(self, init_info: bytes): """ import pickle - assert torch.distributed.is_initialized(), "default torch process group must be initialized" + assert torch.distributed.is_initialized(), ( + "default torch process group must be initialized" + ) # Unpickle init_info to restore the original object type - assert isinstance(init_info, bytes), f"Expected bytes, got {type(init_info).__name__}" + assert isinstance(init_info, bytes), ( + f"Expected bytes, got {type(init_info).__name__}" + ) init_info = pickle.loads(init_info) strategy_cls = init_info.strategy_type() @@ -84,7 +88,9 @@ def load_weights(self, request: bytes) -> None: import pickle # Unpickle request to restore the original object type - assert isinstance(request, bytes), f"Expected bytes, got {type(request).__name__}" + assert isinstance(request, bytes), ( + f"Expected bytes, got {type(request).__name__}" + ) request = pickle.loads(request) weight_list = [] diff --git a/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py b/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py index e4a7bd186..62907b4ec 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py @@ -29,7 +29,9 @@ # Skip entire module if not enough GPUs _gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 if _gpu_count < 4: - pytest.skip(f"Need 4 GPUs for full test suite, found {_gpu_count}", allow_module_level=True) + pytest.skip( + f"Need 4 GPUs for full test suite, found {_gpu_count}", allow_module_level=True + ) def make_vllm_cli_args( @@ -40,14 +42,21 @@ def make_vllm_cli_args( """Create CLI args for vLLM server using official parser.""" parser = FlexibleArgumentParser(description="vLLM server") parser = make_arg_parser(parser) - return parser.parse_args([ - "--model", model, - "--tensor-parallel-size", str(tp_size), - "--enforce-eager", - "--gpu-memory-utilization", "0.5", - "--max-model-len", "2048", - "--load-format", load_format, - ]) + return parser.parse_args( + [ + "--model", + model, + "--tensor-parallel-size", + str(tp_size), + "--enforce-eager", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "2048", + "--load-format", + load_format, + ] + ) def wait_for_url(url: str, timeout: float = 180.0) -> bool: @@ -119,12 +128,12 @@ def test_get_server_info(self, server_group_and_router): """/get_server_info returns mapping of server_url -> info for all servers.""" router_url = server_group_and_router["router_url"] server_urls = server_group_and_router["server_urls"] - + resp = httpx.get(f"{router_url}/get_server_info", timeout=10.0) assert resp.status_code == 200 info_map = resp.json() print(f"Server info map: {info_map}") - + # Should have info for each 
server assert len(info_map) == 2 for url in server_urls: @@ -159,30 +168,38 @@ async def test_pause_resume(self, server_group_and_router): async with httpx.AsyncClient() as client: # Pause - resp = await client.post(f"{router_url}/pause", json={"wait_for_inflight_request": False}, timeout=30.0) + resp = await client.post( + f"{router_url}/pause", + json={"wait_for_inflight_request": False}, + timeout=30.0, + ) assert resp.status_code == 200 - + # Check is paused resp = await client.get(f"{router_url}/is_paused", timeout=30.0) assert resp.status_code == 200 - assert resp.json()["is_paused"] == True - + assert resp.json()["is_paused"] is True + # Send a request while paused (should block) async def send_request(): - r = await client.post(f"{router_url}/v1/completions", json={"model": MODEL, "prompt": "Test", "max_tokens": 4}, timeout=60.0) + r = await client.post( + f"{router_url}/v1/completions", + json={"model": MODEL, "prompt": "Test", "max_tokens": 4}, + timeout=60.0, + ) assert r.status_code == 200 return r.json() - + task = asyncio.create_task(send_request()) await asyncio.sleep(1) - + # Task should not be done here (request blocked by pause) assert not task.done() - + # Resume resp = await client.post(f"{router_url}/resume", json={}, timeout=30.0) assert resp.status_code == 200 - + # Verify that after resume, the request is completed result = await task assert result["choices"][0]["text"] is not None diff --git a/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py b/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py index 0c5774c93..79506908f 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py @@ -32,7 +32,9 @@ # Skip entire module if not enough GPUs (need 3: 1 trainer + 2 for TP=2 server) _gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 if _gpu_count < 3: - pytest.skip(f"Need 3 GPUs for weight sync test, found {_gpu_count}", allow_module_level=True) + pytest.skip( + f"Need 3 GPUs for weight sync test, found {_gpu_count}", allow_module_level=True + ) def make_vllm_cli_args( @@ -43,14 +45,21 @@ def make_vllm_cli_args( """Create CLI args for vLLM server using official parser.""" parser = FlexibleArgumentParser(description="vLLM server") parser = make_arg_parser(parser) - return parser.parse_args([ - "--model", model, - "--tensor-parallel-size", str(tp_size), - "--enforce-eager", - "--gpu-memory-utilization", "0.5", - "--max-model-len", "2048", - "--load-format", load_format, - ]) + return parser.parse_args( + [ + "--model", + model, + "--tensor-parallel-size", + str(tp_size), + "--enforce-eager", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "2048", + "--load-format", + load_format, + ] + ) def wait_for_url(url: str, timeout: float = 180.0) -> bool: @@ -70,29 +79,31 @@ def wait_for_url(url: str, timeout: float = 180.0) -> bool: class Trainer: """ Simple trainer emulator that holds the real model weights. - + This is a simplified version of the trainer side for testing weight sync. Non-colocated: runs on a separate GPU from the inference server. 
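# A hedged async sketch of the pause -> is_paused -> resume flow exercised by
# the test above, issued against a running router; the URL is a placeholder.
import asyncio

import httpx


async def pause_and_resume(router_url: str) -> None:
    async with httpx.AsyncClient() as client:
        # Abort-style pause: do not wait for in-flight requests.
        await client.post(f"{router_url}/pause", json={"wait_for_inflight_request": False}, timeout=30.0)

        status = await client.get(f"{router_url}/is_paused", timeout=30.0)
        assert status.json()["is_paused"] is True

        # New generation requests now block until /resume is called.
        await client.post(f"{router_url}/resume", json={}, timeout=30.0)


asyncio.run(pause_and_resume("http://127.0.0.1:8080"))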
""" - + def __init__(self, model_name: str, device: str = "cuda"): self.device = torch.device(device) self.model = AutoModelForCausalLM.from_pretrained( - model_name, + model_name, torch_dtype=torch.bfloat16, ).to(self.device) self.pg = None self.model_name = model_name - + def ready(self): """Check if the trainer is ready.""" return True - - def init_weight_sync(self, master_address: str, master_port: int, world_size: int, group_name: str): + + def init_weight_sync( + self, master_address: str, master_port: int, world_size: int, group_name: str + ): """Initialize the weight sync process group as rank 0 (trainer).""" from skyrl_train.distributed.utils import init_custom_process_group from skyrl_train.utils import get_tcp_url - + self.pg = init_custom_process_group( backend="nccl", init_method=get_tcp_url(master_address, master_port), @@ -101,35 +112,35 @@ def init_weight_sync(self, master_address: str, master_port: int, world_size: in group_name=group_name, ) return True - + def get_weight_info(self) -> dict: """ Get weight metadata (names, dtypes, shapes) without doing NCCL. - + Returns: dict with names, dtypes, shapes for the weight update request. """ names = [] dtypes = [] shapes = [] - + for name, param in self.model.named_parameters(): names.append(name) dtypes.append(str(param.dtype).split(".")[-1]) # e.g. "bfloat16" shapes.append(list(param.shape)) - + return {"names": names, "dtypes": dtypes, "shapes": shapes} - + def broadcast_weights(self): """ Broadcast all model weights to inference workers via NCCL. - + This is a blocking operation - server must call receive concurrently. """ for name, param in self.model.named_parameters(): torch.distributed.broadcast(param.data, src=0, group=self.pg) torch.cuda.synchronize() - + def teardown(self): """Clean up the process group.""" if self.pg is not None: @@ -148,10 +159,10 @@ def weight_update_env(ray_init_fixture): # Create server with dummy weights cli_args = make_vllm_cli_args(MODEL, tp_size=2, load_format="dummy") start_port = get_open_port() - + pg = placement_group([{"CPU": 1, "GPU": 1} for _ in range(3)]) ray.get(pg.ready()) - + trainer = Trainer.options( num_gpus=1, scheduling_strategy=PlacementGroupSchedulingStrategy( @@ -159,7 +170,7 @@ def weight_update_env(ray_init_fixture): placement_group_bundle_index=0, ), ).remote(MODEL) - + ray.get(trainer.ready.remote()) group = ServerGroup( @@ -226,22 +237,28 @@ async def test_update_weights_flow(self, weight_update_env): print(f"[Step 1] Dummy weights output: {text_before!r}") # Dummy weights should NOT produce coherent output about Paris - assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" + assert "Paris" not in text_before, ( + "Dummy weights unexpectedly produced correct answer" + ) # ===== Step 2: Init weight transfer (both sides concurrently) ===== master_address = get_node_ip() master_port = get_open_port() - + # Query all servers for world_size (TP * PP) via router resp = await client.get(f"{router_url}/get_server_info") assert resp.status_code == 200 server_info_map = resp.json() # Sum world_size across all servers - inference_world_size = sum(info["world_size"] for info in server_info_map.values()) + inference_world_size = sum( + info["world_size"] for info in server_info_map.values() + ) world_size = 1 + inference_world_size # 1 trainer + all inference workers group_name = f"weight_sync_test_{master_port}" - print(f"[Step 2] Init weight transfer: master={master_address}:{master_port}, world_size={world_size}") + print( + f"[Step 2] Init 
weight transfer: master={master_address}:{master_port}, world_size={world_size}" + ) init_info = { "master_addr": master_address, @@ -256,30 +273,40 @@ async def test_update_weights_flow(self, weight_update_env): # Both sides must init concurrently (NCCL blocks until all ranks join) # Start trainer init (returns immediately, runs in Ray actor) - trainer_init_ref = trainer.init_weight_sync.remote(master_address, master_port, world_size, group_name) - + trainer_init_ref = trainer.init_weight_sync.remote( + master_address, master_port, world_size, group_name + ) + # Await server init (triggers NCCL join on server side) - server_resp = await client.post(f"{router_url}/init_weight_transfer", json=init_info) - assert server_resp.status_code == 200, f"Server init failed: {server_resp.text}" - + server_resp = await client.post( + f"{router_url}/init_weight_transfer", json=init_info + ) + assert server_resp.status_code == 200, ( + f"Server init failed: {server_resp.text}" + ) + # Trainer should be done now (NCCL group formed) ray.get(trainer_init_ref) print("[Step 2] Both sides init complete") # ===== Step 3: Broadcast weights (concurrent send/receive) ===== print("[Step 3] Broadcasting weights from trainer to server...") - + # Get weight metadata first (no NCCL yet) weight_info = ray.get(trainer.get_weight_info.remote()) print(f"[Step 3] Weight info: {len(weight_info['names'])} parameters") # Start trainer broadcast (returns immediately, runs in Ray actor) trainer_broadcast_ref = trainer.broadcast_weights.remote() - + # Await server receive (triggers NCCL receive on server side) - server_resp = await client.post(f"{router_url}/update_weights", json=weight_info) - assert server_resp.status_code == 200, f"Update weights failed: {server_resp.text}" - + server_resp = await client.post( + f"{router_url}/update_weights", json=weight_info + ) + assert server_resp.status_code == 200, ( + f"Update weights failed: {server_resp.text}" + ) + # Trainer should be done now (NCCL broadcast complete) ray.get(trainer_broadcast_ref) print("[Step 3] Weight sync complete") @@ -296,6 +323,8 @@ async def test_update_weights_flow(self, weight_update_env): text_after = resp.json()["choices"][0]["text"] print(f"[Step 5] Real weights output: {text_after!r}") - assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" - + assert "Paris" in text_after, ( + f"Weight sync failed - expected 'Paris' but got: {text_after!r}" + ) + print("[SUCCESS] Weight sync test passed!") From dce17d2ec25c6762238bbfb670c56bbd75660c9e Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 19:17:28 -0800 Subject: [PATCH 15/20] lint Signed-off-by: Kourosh Hakhamaneshi --- .../skyrl_train/inference_servers/common.py | 4 +- .../skyrl_train/inference_servers/router.py | 4 +- .../inference_servers/server_group.py | 8 +--- .../inference_servers/vllm_server_actor.py | 45 ++++++------------ .../inference_servers/vllm_worker.py | 13 ++---- skyrl-train/tests/gpu/gpu_ci/conftest.py | 2 +- .../gpu/gpu_ci/test_inference_server_group.py | 6 +-- .../tests/gpu/gpu_ci/test_weight_sync.py | 46 +++++-------------- 8 files changed, 35 insertions(+), 93 deletions(-) diff --git a/skyrl-train/skyrl_train/inference_servers/common.py b/skyrl-train/skyrl_train/inference_servers/common.py index 6599ecf0b..17ae4bb36 100644 --- a/skyrl-train/skyrl_train/inference_servers/common.py +++ b/skyrl-train/skyrl_train/inference_servers/common.py @@ -57,9 +57,7 @@ def get_open_port(start_port: int | None = None) -> int: except 
OSError: port += 1 if port > 65535: - raise RuntimeError( - f"No available port found starting from {start_port}" - ) + raise RuntimeError(f"No available port found starting from {start_port}") # Let OS assign a free port # Try IPv4 first diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py index a4074febb..a62d586ef 100644 --- a/skyrl-train/skyrl_train/inference_servers/router.py +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -150,9 +150,7 @@ async def _proxy_request(self, request: Request, path: str) -> Response: def _forward_headers(self, request: Request) -> dict: """Forward headers (filter out hop-by-hop headers).""" return { - k: v - for k, v in request.headers.items() - if k.lower() not in ("host", "content-length", "transfer-encoding") + k: v for k, v in request.headers.items() if k.lower() not in ("host", "content-length", "transfer-encoding") } async def _fan_out_get(self, path: str) -> dict: diff --git a/skyrl-train/skyrl_train/inference_servers/server_group.py b/skyrl-train/skyrl_train/inference_servers/server_group.py index a4a857504..4d0e53e6e 100644 --- a/skyrl-train/skyrl_train/inference_servers/server_group.py +++ b/skyrl-train/skyrl_train/inference_servers/server_group.py @@ -77,9 +77,7 @@ def __init__( self._internal_pg: Optional[PlacementGroup] = None # Query the actor class for GPU requirements - self._num_gpus_per_server = server_actor_cls.compute_num_gpus_per_server( - cli_args - ) + self._num_gpus_per_server = server_actor_cls.compute_num_gpus_per_server(cli_args) logger.info( f"ServerGroup: actor_cls={server_actor_cls.__name__}, " @@ -127,9 +125,7 @@ def _create_actors(self) -> List[Any]: for server_idx in range(self._num_servers): # Calculate bundle index accounting for offset (colocation mode) - start_bundle_idx = ( - self._bundle_offset + server_idx * self._num_gpus_per_server - ) + start_bundle_idx = self._bundle_offset + server_idx * self._num_gpus_per_server ServerActorClass = self._create_actor_class(pg, start_bundle_idx) diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py index ce4ec5d13..019018705 100644 --- a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py +++ b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py @@ -7,33 +7,30 @@ import os import pickle import time +from argparse import Namespace from typing import Any, Dict, Optional, Tuple -from argparse import Namespace import httpx import uvicorn +import vllm.envs as envs from fastapi import Request - - -from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.api_server import ( build_app, create_server_socket, init_app_state, ) from vllm.usage.usage_lib import UsageContext -import vllm.envs as envs from vllm.utils.system_utils import set_ulimit - -from skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_open_port -from skyrl_train.inference_servers.protocols import ServerActorProtocol -from skyrl_train.inference_servers.vllm_worker import VLLM_WORKER_EXTENSION_CLS from skyrl_train.env_vars import ( SKYRL_VLLM_DP_PORT_OFFSET, SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S, ) +from skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_open_port +from skyrl_train.inference_servers.protocols import ServerActorProtocol +from 
skyrl_train.inference_servers.vllm_worker import VLLM_WORKER_EXTENSION_CLS logger = logging.getLogger(__name__) @@ -137,9 +134,7 @@ def __init__( ) # Set bundle indices for this server's TP/PP workers in the placement group - bundle_indices = list( - range(start_bundle_idx, start_bundle_idx + self._num_gpus_per_server) - ) + bundle_indices = list(range(start_bundle_idx, start_bundle_idx + self._num_gpus_per_server)) os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) logger.info(f"Server {server_idx}: using bundle indices {bundle_indices}") @@ -154,16 +149,11 @@ def _ensure_worker_extension(self) -> None: The worker extension (WorkerWrap) provides the RPC methods needed for weight synchronization (init_weight_update_communicator, load_weights). """ - if ( - not hasattr(self._cli_args, "worker_extension_cls") - or not self._cli_args.worker_extension_cls - ): + if not hasattr(self._cli_args, "worker_extension_cls") or not self._cli_args.worker_extension_cls: self._cli_args.worker_extension_cls = VLLM_WORKER_EXTENSION_CLS logger.info(f"Using default worker extension: {VLLM_WORKER_EXTENSION_CLS}") else: - logger.info( - f"Using provided worker extension: {self._cli_args.worker_extension_cls}" - ) + logger.info(f"Using provided worker extension: {self._cli_args.worker_extension_cls}") def _ensure_ray_executor(self) -> None: """ @@ -193,10 +183,7 @@ def _setup_nixl_side_channel(self, base_port: int) -> None: engine_id = f"server-{self._server_idx}-{self._ip}-{side_channel_port}" - if ( - hasattr(self._cli_args, "kv_transfer_config") - and self._cli_args.kv_transfer_config - ): + if hasattr(self._cli_args, "kv_transfer_config") and self._cli_args.kv_transfer_config: try: kv_config = json.loads(self._cli_args.kv_transfer_config) except (json.JSONDecodeError, TypeError) as e: @@ -245,9 +232,7 @@ async def start(self) -> ServerInfo: return self.get_server_info() - async def _wait_until_healthy( - self, timeout: float = SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S - ) -> None: + async def _wait_until_healthy(self, timeout: float = SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S) -> None: """Poll the /health endpoint until it responds OK.""" url = f"http://{self._ip}:{self._port}/health" start_time = time.time() @@ -270,9 +255,7 @@ async def _wait_until_healthy( pass if time.time() - start_time > timeout: - raise TimeoutError( - f"Server failed to become healthy within {timeout}s" - ) + raise TimeoutError(f"Server failed to become healthy within {timeout}s") await asyncio.sleep(1.0) @@ -288,9 +271,7 @@ async def _run_server(self) -> None: engine_args=engine_args, usage_context=UsageContext.OPENAI_API_SERVER, ) - logger.info( - f"Engine initialized on {self._ip}:{self._port}, adding custom endpoints..." 
- ) + logger.info(f"Engine initialized on {self._ip}:{self._port}, adding custom endpoints...") # Add custom SkyRL endpoints self._add_custom_endpoints(app) diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_worker.py b/skyrl-train/skyrl_train/inference_servers/vllm_worker.py index 98df697ce..8249b30a7 100644 --- a/skyrl-train/skyrl_train/inference_servers/vllm_worker.py +++ b/skyrl-train/skyrl_train/inference_servers/vllm_worker.py @@ -18,7 +18,6 @@ import torch - # Path to this worker extension class for use in CLI args (derived from module path) VLLM_WORKER_EXTENSION_CLS = f"{__name__}.WorkerWrap" @@ -50,14 +49,10 @@ def init_weight_update_communicator(self, init_info: bytes): """ import pickle - assert torch.distributed.is_initialized(), ( - "default torch process group must be initialized" - ) + assert torch.distributed.is_initialized(), "default torch process group must be initialized" # Unpickle init_info to restore the original object type - assert isinstance(init_info, bytes), ( - f"Expected bytes, got {type(init_info).__name__}" - ) + assert isinstance(init_info, bytes), f"Expected bytes, got {type(init_info).__name__}" init_info = pickle.loads(init_info) strategy_cls = init_info.strategy_type() @@ -88,9 +83,7 @@ def load_weights(self, request: bytes) -> None: import pickle # Unpickle request to restore the original object type - assert isinstance(request, bytes), ( - f"Expected bytes, got {type(request).__name__}" - ) + assert isinstance(request, bytes), f"Expected bytes, got {type(request).__name__}" request = pickle.loads(request) weight_list = [] diff --git a/skyrl-train/tests/gpu/gpu_ci/conftest.py b/skyrl-train/tests/gpu/gpu_ci/conftest.py index 6e60ab6a2..ed33b89b5 100644 --- a/skyrl-train/tests/gpu/gpu_ci/conftest.py +++ b/skyrl-train/tests/gpu/gpu_ci/conftest.py @@ -33,7 +33,7 @@ def ray_init_fixture(): # needed for megatron tests env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env_vars["NVTE_FUSED_ATTN"] = "0" - + if SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV: env_vars["PYTHONPATH"] = os.environ.get("PYTHONPATH") diff --git a/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py b/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py index 62907b4ec..53e6f65f1 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py @@ -19,9 +19,9 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.utils.argparse_utils import FlexibleArgumentParser +from skyrl_train.inference_servers.common import get_open_port from skyrl_train.inference_servers.router import InferenceRouter from skyrl_train.inference_servers.server_group import ServerGroup -from skyrl_train.inference_servers.common import get_open_port MODEL = "Qwen/Qwen2.5-0.5B-Instruct" @@ -29,9 +29,7 @@ # Skip entire module if not enough GPUs _gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 if _gpu_count < 4: - pytest.skip( - f"Need 4 GPUs for full test suite, found {_gpu_count}", allow_module_level=True - ) + pytest.skip(f"Need 4 GPUs for full test suite, found {_gpu_count}", allow_module_level=True) def make_vllm_cli_args( diff --git a/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py b/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py index 79506908f..a26649fd5 100644 --- a/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py +++ b/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py @@ -22,9 +22,9 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.utils.argparse_utils 
import FlexibleArgumentParser +from skyrl_train.inference_servers.common import get_node_ip, get_open_port from skyrl_train.inference_servers.router import InferenceRouter from skyrl_train.inference_servers.server_group import ServerGroup -from skyrl_train.inference_servers.common import get_open_port, get_node_ip MODEL = "Qwen/Qwen2.5-0.5B-Instruct" @@ -32,9 +32,7 @@ # Skip entire module if not enough GPUs (need 3: 1 trainer + 2 for TP=2 server) _gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 if _gpu_count < 3: - pytest.skip( - f"Need 3 GPUs for weight sync test, found {_gpu_count}", allow_module_level=True - ) + pytest.skip(f"Need 3 GPUs for weight sync test, found {_gpu_count}", allow_module_level=True) def make_vllm_cli_args( @@ -97,9 +95,7 @@ def ready(self): """Check if the trainer is ready.""" return True - def init_weight_sync( - self, master_address: str, master_port: int, world_size: int, group_name: str - ): + def init_weight_sync(self, master_address: str, master_port: int, world_size: int, group_name: str): """Initialize the weight sync process group as rank 0 (trainer).""" from skyrl_train.distributed.utils import init_custom_process_group from skyrl_train.utils import get_tcp_url @@ -237,9 +233,7 @@ async def test_update_weights_flow(self, weight_update_env): print(f"[Step 1] Dummy weights output: {text_before!r}") # Dummy weights should NOT produce coherent output about Paris - assert "Paris" not in text_before, ( - "Dummy weights unexpectedly produced correct answer" - ) + assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" # ===== Step 2: Init weight transfer (both sides concurrently) ===== master_address = get_node_ip() @@ -250,15 +244,11 @@ async def test_update_weights_flow(self, weight_update_env): assert resp.status_code == 200 server_info_map = resp.json() # Sum world_size across all servers - inference_world_size = sum( - info["world_size"] for info in server_info_map.values() - ) + inference_world_size = sum(info["world_size"] for info in server_info_map.values()) world_size = 1 + inference_world_size # 1 trainer + all inference workers group_name = f"weight_sync_test_{master_port}" - print( - f"[Step 2] Init weight transfer: master={master_address}:{master_port}, world_size={world_size}" - ) + print(f"[Step 2] Init weight transfer: master={master_address}:{master_port}, world_size={world_size}") init_info = { "master_addr": master_address, @@ -273,17 +263,11 @@ async def test_update_weights_flow(self, weight_update_env): # Both sides must init concurrently (NCCL blocks until all ranks join) # Start trainer init (returns immediately, runs in Ray actor) - trainer_init_ref = trainer.init_weight_sync.remote( - master_address, master_port, world_size, group_name - ) + trainer_init_ref = trainer.init_weight_sync.remote(master_address, master_port, world_size, group_name) # Await server init (triggers NCCL join on server side) - server_resp = await client.post( - f"{router_url}/init_weight_transfer", json=init_info - ) - assert server_resp.status_code == 200, ( - f"Server init failed: {server_resp.text}" - ) + server_resp = await client.post(f"{router_url}/init_weight_transfer", json=init_info) + assert server_resp.status_code == 200, f"Server init failed: {server_resp.text}" # Trainer should be done now (NCCL group formed) ray.get(trainer_init_ref) @@ -300,12 +284,8 @@ async def test_update_weights_flow(self, weight_update_env): trainer_broadcast_ref = trainer.broadcast_weights.remote() # Await server 
receive (triggers NCCL receive on server side) - server_resp = await client.post( - f"{router_url}/update_weights", json=weight_info - ) - assert server_resp.status_code == 200, ( - f"Update weights failed: {server_resp.text}" - ) + server_resp = await client.post(f"{router_url}/update_weights", json=weight_info) + assert server_resp.status_code == 200, f"Update weights failed: {server_resp.text}" # Trainer should be done now (NCCL broadcast complete) ray.get(trainer_broadcast_ref) @@ -323,8 +303,6 @@ async def test_update_weights_flow(self, weight_update_env): text_after = resp.json()["choices"][0]["text"] print(f"[Step 5] Real weights output: {text_after!r}") - assert "Paris" in text_after, ( - f"Weight sync failed - expected 'Paris' but got: {text_after!r}" - ) + assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" print("[SUCCESS] Weight sync test passed!") From 22c12ade022f1629c14f678b70538220d1863567 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 22:19:22 -0800 Subject: [PATCH 16/20] gemini fback Signed-off-by: Kourosh Hakhamaneshi --- skyrl-train/skyrl_train/env_vars.py | 13 ++-- .../inference_engines/vllm/vllm_engine.py | 4 +- .../skyrl_train/inference_servers/router.py | 71 ++++++++++--------- .../cpu/inference_servers/test_common.py | 3 +- .../cpu/inference_servers/test_router.py | 11 +-- 5 files changed, 53 insertions(+), 49 deletions(-) diff --git a/skyrl-train/skyrl_train/env_vars.py b/skyrl-train/skyrl_train/env_vars.py index 09ae5356f..9d7eb6b2f 100644 --- a/skyrl-train/skyrl_train/env_vars.py +++ b/skyrl-train/skyrl_train/env_vars.py @@ -1,6 +1,3 @@ - - - import os @@ -8,11 +5,15 @@ """ Offset for the data parallel port of the vLLM server. """ -SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S = int(os.environ.get("SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S", 600)) +SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S = int( + os.environ.get("SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S", 600) +) """ Timeout for waiting until the inference server is healthy. """ -SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV = str(os.environ.get("SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV", "False")).lower() in ( +SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV = str( + os.environ.get("SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV", "False") +).lower() in ( "true", "1", "yes", @@ -22,4 +23,4 @@ environment. In case of using ray nightly, this will be needed to avoid dependencies issues by setting it to the local path where ray nightly is installed. -""" \ No newline at end of file +""" diff --git a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py index 80259d5a3..7c6933a3a 100644 --- a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from http import HTTPStatus import ray -import torch import asyncio import vllm from types import SimpleNamespace @@ -24,7 +23,6 @@ ) from vllm.lora.request import LoRARequest from uuid import uuid4 -import warnings from skyrl_train.inference_engines.base import ( InferenceEngineInterface, InferenceEngineInput, @@ -69,7 +67,7 @@ def setup_envvars_for_vllm(kwargs, bundle_indices): # Backward compatibility: WorkerWrap has moved to inference_servers.vllm_worker # This alias preserves the old import path for existing scripts/configs. 
# TODO (Kourosh): Remove this alias once all references are updated. -from skyrl_train.inference_servers.vllm_worker import WorkerWrap # noqa: F401 +from skyrl_train.inference_servers.vllm_worker import WorkerWrap # noqa: F401, E402 class BaseVLLMInferenceEngine(InferenceEngineInterface): diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py index a62d586ef..f1826ecb4 100644 --- a/skyrl-train/skyrl_train/inference_servers/router.py +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -6,13 +6,15 @@ import hashlib import itertools import logging +import threading +import time from typing import List, Optional import httpx import uvicorn from fastapi import FastAPI, Request, Response -from skyrl_train.inference_servers.common import get_node_ip +from skyrl_train.inference_servers.common import get_node_ip, SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S logger = logging.getLogger(__name__) @@ -71,8 +73,8 @@ def __init__( self._server_cycle = itertools.cycle(server_urls) self._client: Optional[httpx.AsyncClient] = None self._app: Optional[FastAPI] = None - self._server_task: Optional[asyncio.Task] = None - self._shutdown_event: Optional[asyncio.Event] = None + self._server: Optional[uvicorn.Server] = None + self._server_thread: Optional[threading.Thread] = None logger.info(f"InferenceRouter: {len(server_urls)} servers, port={port}") @@ -245,48 +247,49 @@ def start(self) -> str: # Create HTTP client for proxying self._client = httpx.AsyncClient(timeout=httpx.Timeout(None)) - # Build FastAPI app + # Build FastAPI app and uvicorn server self._app = self._build_app() + config = uvicorn.Config( + app=self._app, + host=self._host, + port=self._port, + log_level="warning", + access_log=False, + ) + self._server = uvicorn.Server(config) - # Create shutdown event - self._shutdown_event = asyncio.Event() - - # Start server in background thread (since we're not in async context) - import threading - - def run_server(): - asyncio.run(self._run_server()) - - self._server_thread = threading.Thread(target=run_server, daemon=True) + # Start server in background thread + self._server_thread = threading.Thread(target=asyncio.run, args=(self._server.serve(),), daemon=True) self._server_thread.start() - # Wait a bit for server to start - import time - - time.sleep(1) - ip = get_node_ip() router_url = f"http://{ip}:{self._port}" + self._wait_until_healthy(router_url) + logger.info(f"Router started at {router_url}") logger.info(" GET /servers - list servers") logger.info(" GET /get_server_info - get parallelism info") return router_url - async def _run_server(self) -> None: - """Run the uvicorn server.""" - config = uvicorn.Config( - app=self._app, - host=self._host, - port=self._port, - log_level="warning", - access_log=False, - ) - server = uvicorn.Server(config) - await server.serve() + def _wait_until_healthy( + self, router_url: str, timeout: float = SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S + ) -> None: + """Poll health endpoint until server is ready.""" + health_url = f"{router_url}/health" + start_time = time.time() + while time.time() - start_time < timeout: + try: + with httpx.Client() as client: + if client.get(health_url, timeout=1).status_code == 200: + return + except httpx.RequestError: + time.sleep(0.1) + raise RuntimeError(f"Router failed to start within {timeout}s") def shutdown(self) -> None: - """Shutdown the router.""" + """Shutdown the router gracefully.""" logger.info("Shutting down router...") - if 
self._shutdown_event: - self._shutdown_event.set() - # Note: Thread will exit when uvicorn server stops + if self._server: + self._server.should_exit = True + if self._server_thread: + self._server_thread.join(timeout=5) diff --git a/skyrl-train/tests/cpu/inference_servers/test_common.py b/skyrl-train/tests/cpu/inference_servers/test_common.py index e41fac370..76136368c 100644 --- a/skyrl-train/tests/cpu/inference_servers/test_common.py +++ b/skyrl-train/tests/cpu/inference_servers/test_common.py @@ -29,8 +29,7 @@ def test_get_open_port_os_assigned(self): assert isinstance(port, int) assert 1 <= port <= 65535 self._verify_port_is_free(port) - - + def _verify_port_is_free(self, port: int): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) diff --git a/skyrl-train/tests/cpu/inference_servers/test_router.py b/skyrl-train/tests/cpu/inference_servers/test_router.py index 617c5ec26..386e72fac 100644 --- a/skyrl-train/tests/cpu/inference_servers/test_router.py +++ b/skyrl-train/tests/cpu/inference_servers/test_router.py @@ -68,9 +68,7 @@ def env(): router._client = httpx.AsyncClient(timeout=httpx.Timeout(None)) router._app = router._build_app() - router_config = uvicorn.Config( - router._app, host="127.0.0.1", port=router_port, log_level="error" - ) + router_config = uvicorn.Config(router._app, host="127.0.0.1", port=router_port, log_level="error") router_server = uvicorn.Server(router_config) servers.append(router_server) @@ -106,7 +104,12 @@ def test_session_affinity(env): def test_control_plane_fanout(env): """Control plane routes fan out to all servers.""" resp = httpx.post(f"{env}/sleep", json={}) - assert resp.status_code == 200 and resp.json()["status"] == "ok" + assert resp.status_code == 200 + # Response is a mapping of server_url -> {status, body} + response_map = resp.json() + assert len(response_map) == 2 # Both servers received the request + for server_url, result in response_map.items(): + assert result["status"] == 200 def test_list_servers(env): From 05bfc923d68a5b4a51f8e68716a8e167bea0dd82 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 22:21:44 -0800 Subject: [PATCH 17/20] wip Signed-off-by: Kourosh Hakhamaneshi --- skyrl-train/skyrl_train/inference_servers/router.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py index f1826ecb4..4c66f5af2 100644 --- a/skyrl-train/skyrl_train/inference_servers/router.py +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -14,7 +14,8 @@ import uvicorn from fastapi import FastAPI, Request, Response -from skyrl_train.inference_servers.common import get_node_ip, SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S +from skyrl_train.inference_servers.common import get_node_ip +from skyrl_train.env_vars import SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S logger = logging.getLogger(__name__) From eca0e3dc3e1e4376c8aff06a9f77f25f22f31bf0 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Mon, 19 Jan 2026 23:50:25 -0800 Subject: [PATCH 18/20] wip Signed-off-by: Kourosh Hakhamaneshi --- .../remote_inference_client.py | 719 ++++++++++++++++++ .../skyrl_train/inference_servers/router.py | 1 + .../test_remote_inference_client.py | 633 +++++++++++++++ 3 files changed, 1353 insertions(+) create mode 100644 skyrl-train/skyrl_train/inference_servers/remote_inference_client.py create mode 100644 
skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py diff --git a/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py b/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py new file mode 100644 index 000000000..652d6b94c --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py @@ -0,0 +1,719 @@ +""" +RemoteInferenceClient - Serializable HTTP client for inference. + +This is a lightweight, fully serializable HTTP client that wraps the inference +server HTTP API. It replaces the old InferenceEngineInterface for HTTP-based +inference servers. + +Key features: +- Serializable: Can be pickled and passed between processes +- Two URL types: + - proxy_url: Single URL for data plane operations (routed requests) + - server_urls: List of backend URLs for control plane operations (fan-out) +- Lazy world_size fetching from /get_server_info +- Pure HTTP: Uses /tokenize endpoint if token IDs are needed +- Built-in retry on abort for in-flight weight updates + +Usage: + # Full proxy mode (router handles both data and control plane) + client = RemoteInferenceClient( + proxy_url="http://router:8080", + server_urls=["http://router:8080"], + ) + +Comparison with existing code: +- Replaces: InferenceEngineClient + RemoteInferenceEngine (for remote-only usage) +- Key difference: Talks directly to router via HTTP, no Ray actor wrapping +- The router handles session-aware routing; this client is simpler + +TODO: Data Plane Operations - Future Deprecation +------------------------------------------------ +All data plane operations (generate, chat_completion, completion, tokenize, detokenize) +and the retry-on-abort logic will eventually be removed from this client. + +When vLLM RFC #32103 lands with PauseMode.KEEP: +- The retry logic in generate() will be deleted +- pause() will use mode="keep" which preserves KV cache and scheduler state +- Requests resume seamlessly after unpause with zero client changes + +The generator code will transition to: +1. OpenAI-compatible endpoints (/v1/chat/completions) for text-based interaction +2. Tinker sample API for token-in-token-out workflows: + - Input: ModelInput.from_ints(tokens=input_ids) + - Output: sequences[0].tokens, sequences[0].logprobs + - Internally maps to /v1/completions with token-in-token-out + - May become a native vLLM API in the future + +This client will then primarily handle control plane operations only. +""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +import aiohttp + +if TYPE_CHECKING: + from skyrl_train.inference_engines.base import InferenceEngineInput, InferenceEngineOutput + from skyrl_train.weight_sync import BroadcastInitInfo, BroadcastWeightUpdateRequest + +logger = logging.getLogger(__name__) + + +class PauseMode(Enum): + """ + Pause mode for inference servers. + + This enum mirrors the pause modes that will be available in vLLM RFC #32103. + For now, we map these to the existing `wait_for_inflight_request` parameter. + + Modes: + ABORT: Abort in-flight requests immediately. Clients receive partial + tokens and must retry with accumulated context. + Maps to: wait_for_inflight_request=False + + FINISH: Wait for in-flight requests to complete before pausing. + New requests are blocked. No retry needed. 
+ Maps to: wait_for_inflight_request=True + + KEEP: (Future - vLLM RFC #32103) Preserve KV cache and scheduler state. + Requests resume seamlessly after unpause. Zero client changes needed. + NOT YET SUPPORTED - raises NotImplementedError. + """ + + ABORT = "abort" + FINISH = "finish" + KEEP = "keep" + + +@dataclass +class RemoteInferenceClient: + """ + Serializable HTTP client for inference. Replaces InferenceEngineInterface. + + This class maintains two URL types: + - proxy_url: Single URL for data plane operations (routed requests) + - server_urls: List of backend URLs for control plane operations (fan-out) + + This separation allows using external routers (vllm-router, sglang-router) + that only handle data plane, while still being able to call control plane + endpoints directly on backends. + + For a "full proxy" setup where the router handles both data and control plane, + set proxy_url and server_urls to the same value: + proxy_url = "http://router:8080" + server_urls = ["http://router:8080"] + + For external routers that only handle data plane: + proxy_url = "http://vllm-router:8080" + server_urls = ["http://backend1:8000", "http://backend2:8000"] + """ + + proxy_url: str + """Data plane URL (single endpoint - router or direct server).""" + + server_urls: List[str] + """Control plane URLs (list of backend servers for fan-out).""" + + model_name: str = "default" + """Model name for OpenAI-compatible API calls.""" + + # Private fields excluded from repr for cleaner output + _session: Optional[aiohttp.ClientSession] = field(default=None, repr=False) + _world_size: Optional[int] = field(default=None, repr=False) + + # --------------------------- + # Session Management + # --------------------------- + + async def _get_session(self) -> aiohttp.ClientSession: + """Get or create the aiohttp session.""" + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) + return self._session + + # --------------------------- + # Data Plane + # --------------------------- + + async def generate( + self, + input_batch: "InferenceEngineInput", + ) -> "InferenceEngineOutput": + """ + Generate completions via /v1/completions. + + This is the interface for token-in-token-out workflows. Input will have + token ids, and the output is token ids as well. + + Each prompt is sent as a separate request to allow the router to route + based on session_id. All requests are made in parallel. + + Args: + input_batch: Contains prompt_token_ids, sampling_params, and optional session_ids. + + Returns: + InferenceEngineOutput with responses, response_ids, and stop_reasons. + + Note: + If retry_on_abort=True and a request is aborted (due to pause), + the client will retry with accumulated tokens until completion. + """ + from skyrl_train.inference_engines.base import InferenceEngineOutput + + prompt_token_ids = input_batch.get("prompt_token_ids") + if prompt_token_ids is None: + raise ValueError("RemoteInferenceClient only accepts `prompt_token_ids`, not `prompts`.") + + sampling_params = input_batch.get("sampling_params") or {} + if sampling_params.get("n", 1) > 1: + raise ValueError( + "n > 1 is not supported. Use `config.generator.n_samples_per_prompt` instead." 
+ ) + + session_ids = input_batch.get("session_ids") + + # Create parallel tasks for all prompts + # Each task handles its own retry on abort + tasks = [ + self._generate_single( + prompt_token_ids=prompt_token_ids[idx], + sampling_params=sampling_params, + session_id=session_ids[idx] if session_ids and idx < len(session_ids) else None, + ) + for idx in range(len(prompt_token_ids)) + ] + + # Run all in parallel - retries happen within each task + results = await asyncio.gather(*tasks) + + return InferenceEngineOutput( + responses=[r["response"] for r in results], + stop_reasons=[r["stop_reason"] for r in results], + response_ids=[r["response_ids"] for r in results], + response_logprobs=None, + ) + + # TODO: Delete retry logic when vLLM RFC #32103 lands with PauseMode.KEEP + async def _generate_single( + self, + prompt_token_ids: List[int], + sampling_params: Dict[str, Any], + session_id: Optional[Any], + ) -> Dict[str, Any]: + """ + Generate completion for a single prompt with built-in retry on abort. + + When pause(mode=ABORT) is called, running requests return partial tokens + with stop_reason="abort". This method retries with accumulated tokens + until generation completes with a non-abort stop reason. + + TODO: Retry logic will be deleted when vLLM RFC #32103 lands. + With PauseMode.KEEP, requests resume seamlessly after unpause. + + Returns: + Dict with keys: response, stop_reason, response_ids + """ + session = await self._get_session() + url = f"{self.proxy_url}/v1/completions" + + # Determine max_tokens key and original value + max_key = None + if "max_tokens" in sampling_params: + max_key = "max_tokens" + elif "max_completion_tokens" in sampling_params: + max_key = "max_completion_tokens" + original_max_tokens = sampling_params.get(max_key) if max_key else None + + # Accumulate across retries + accum_text = "" + accum_token_ids: List[int] = [] + stop_reason = "abort" + + while stop_reason == "abort": + # Wait if generation is paused + await self._wait_for_resume() + + # Build payload with accumulated context + cur_params = sampling_params.copy() + if original_max_tokens is not None and max_key: + remaining = original_max_tokens - len(accum_token_ids) + if remaining <= 0: + break + cur_params[max_key] = remaining + + # New prompt = original + accumulated tokens + new_prompt = prompt_token_ids + accum_token_ids + + payload = cur_params.copy() + payload["model"] = self.model_name + payload["prompt"] = new_prompt + + headers = {"Content-Type": "application/json"} + if session_id: + headers["X-Session-ID"] = str(session_id) + + async with session.post(url, json=payload, headers=headers) as resp: + resp.raise_for_status() + response = await resp.json() + + choice = response["choices"][0] + new_text = choice["text"] + stop_reason = choice["finish_reason"] + + # Accumulate text + accum_text += new_text + # Tokenize the new text to get token IDs for next iteration + if stop_reason == "abort" and new_text: + new_token_ids = (await self.tokenize([new_text], add_special_tokens=False))[0] + accum_token_ids.extend(new_token_ids) + + # Final response + # Tokenize full accumulated text for response_ids + final_token_ids = (await self.tokenize([accum_text], add_special_tokens=False))[0] if accum_text else [] + + return { + "response": accum_text, + "stop_reason": stop_reason, + "response_ids": final_token_ids, + } + + async def _wait_for_resume(self, poll_interval: float = 0.1) -> None: + """ + Wait until ALL servers return is_paused=False. 
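# Numeric sketch of the retry-on-abort bookkeeping implemented in
# _generate_single above: after an abort, the partial completion is tokenized,
# appended to the prompt, and the token budget shrinks by what was already
# produced. Token IDs and counts here are made up.
prompt_token_ids = [101, 102, 103]
original_max_tokens = 16

# First attempt is aborted after producing 6 tokens of partial output.
accum_token_ids = [7, 8, 9, 10, 11, 12]

remaining = original_max_tokens - len(accum_token_ids)  # 10
retry_prompt = prompt_token_ids + accum_token_ids  # original prompt + partial output
retry_payload = {"prompt": retry_prompt, "max_tokens": remaining}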
+ + Sampling can only continue when every server has resumed. + + Args: + poll_interval: Seconds between polling attempts. + """ + session = await self._get_session() + + while True: + # ALL servers must return is_paused=False for sampling to continue + all_resumed = True + for url in self.server_urls: + try: + async with session.get(f"{url}/is_paused") as resp: + if resp.status == 200: + result = await resp.json() + # Default to paused=True for safety if key missing + if result.get("is_paused", True): + all_resumed = False + break + else: + # Non-200 response, assume still paused + all_resumed = False + break + except Exception: + # If we can't reach the server, assume it's still paused + all_resumed = False + break + + if all_resumed: + return + + await asyncio.sleep(poll_interval) + + async def chat_completion( + self, + request_payload: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Chat completion via /v1/chat/completions. + + Args: + request_payload: Dict with {"json": , "headers": }. + The request body should be OpenAI-compatible chat completion request. + session_id can be included in json for consistent routing. + + Returns: + OpenAI-compatible chat completion response. + """ + body = request_payload.get("json", {}) + + # Extract session_id for routing (same as InferenceEngineClient) + session_id = body.pop("session_id", None) + + headers = {"Content-Type": "application/json"} + if session_id: + headers["X-Session-ID"] = str(session_id) + + session = await self._get_session() + url = f"{self.proxy_url}/v1/chat/completions" + + async with session.post(url, json=body, headers=headers) as resp: + resp.raise_for_status() + return await resp.json() + + async def completion( + self, + request_payload: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Completion via /v1/completions. + + Args: + request_payload: Dict with {"json": , "headers": }. + The request body should be OpenAI-compatible completion request. + session_id can be included in json for consistent routing. + + Returns: + OpenAI-compatible completion response. + """ + body = request_payload.get("json", {}) + + # Extract session_id for routing (same as InferenceEngineClient) + session_id = body.pop("session_id", None) + + headers = {"Content-Type": "application/json"} + if session_id: + headers["X-Session-ID"] = str(session_id) + + session = await self._get_session() + url = f"{self.proxy_url}/v1/completions" + + async with session.post(url, json=body, headers=headers) as resp: + resp.raise_for_status() + return await resp.json() + + async def tokenize( + self, + texts: List[str], + add_special_tokens: bool = True, + ) -> List[List[int]]: + """ + Tokenize texts via /tokenize. + + Args: + texts: List of texts to tokenize. + add_special_tokens: Whether to add special tokens. + + Returns: + List of token ID lists. + """ + session = await self._get_session() + url = f"{self.proxy_url}/tokenize" + + # vLLM /tokenize expects individual requests, batch them + results = [] + for text in texts: + payload = { + "model": self.model_name, + "prompt": text, + "add_special_tokens": add_special_tokens, + } + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + result = await resp.json() + results.append(result.get("tokens", [])) + + return results + + async def detokenize( + self, + token_ids: List[List[int]], + ) -> List[str]: + """ + Detokenize token IDs via /detokenize. + + Args: + token_ids: List of token ID lists. + + Returns: + List of decoded texts. 
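
        Example (editor's sketch; the URLs and model name below are
        illustrative placeholders, not values from this patch):

            client = RemoteInferenceClient(
                proxy_url="http://router:8080",
                server_urls=["http://backend-0:8000", "http://backend-1:8000"],
                model_name="my-model",
            )
            ids = await client.tokenize(["hello world"], add_special_tokens=False)
            texts = await client.detokenize(ids)
            # Round-tripping plain text through tokenize/detokenize typically
            # recovers the original string.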
+ """ + session = await self._get_session() + url = f"{self.proxy_url}/detokenize" + + # vLLM /detokenize expects individual requests, batch them + results = [] + for ids in token_ids: + payload = { + "model": self.model_name, + "tokens": ids, + } + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + result = await resp.json() + results.append(result.get("prompt", "")) + + return results + + # --------------------------- + # Control Plane (fan-out to all server_urls) + # --------------------------- + + async def _call_all_servers( + self, + endpoint: str, + json: Dict[str, Any], + method: str = "POST", + ) -> Dict[str, Any]: + """ + Call endpoint on all server_urls concurrently. + + Args: + endpoint: Endpoint path (e.g., "/pause"). + json: JSON payload to send. + method: HTTP method (default: POST). + + Returns: + Dict mapping server_url to response. + """ + session = await self._get_session() + + async def call_server(server_url: str) -> tuple: + url = f"{server_url}{endpoint}" + try: + async with session.request(method, url, json=json) as resp: + body = await resp.json() if resp.content_length else None + return server_url, {"status": resp.status, "body": body} + except Exception as e: + return server_url, {"status": 500, "error": str(e)} + + results = await asyncio.gather(*[call_server(url) for url in self.server_urls]) + return {url: resp for url, resp in results} + + async def pause(self, mode: PauseMode = PauseMode.ABORT) -> Dict[str, Any]: + """ + Pause generation on all backends. + + Args: + mode: Pause mode determining how in-flight requests are handled. + - ABORT: Abort in-flight requests immediately. Clients receive + partial tokens and must retry with accumulated context. + - FINISH: Wait for in-flight requests to complete before pausing. + New requests are blocked. No retry needed. + - KEEP: (Future) Preserve KV cache and scheduler state. + NOT YET SUPPORTED. + + Returns: + Dict mapping server_url to response. + + Note: + When vLLM RFC #32103 lands, we'll use the native mode parameter. + For now, we map modes to wait_for_inflight_request: + - ABORT → wait_for_inflight_request=False + - FINISH → wait_for_inflight_request=True + """ + if mode == PauseMode.KEEP: + raise NotImplementedError( + "PauseMode.KEEP is not yet supported. " + "Waiting for vLLM RFC #32103 to land." + ) + + wait_for_inflight_request = mode == PauseMode.FINISH + + return await self._call_all_servers("/pause", { + "wait_for_inflight_request": wait_for_inflight_request + }) + + async def resume(self) -> Dict[str, Any]: + """Resume generation on all backends.""" + return await self._call_all_servers("/resume", {}) + + async def sleep(self, level: int = 2) -> Dict[str, Any]: + """ + Put all backends to sleep (offload weights to CPU). + + Args: + level: Sleep level (1 or 2). Level 2 offloads more aggressively. + + Returns: + Dict mapping server_url to response. + """ + return await self._call_all_servers("/sleep", {"level": level}) + + async def wake_up(self) -> Dict[str, Any]: + """Wake up all backends (load weights back to GPU).""" + return await self._call_all_servers("/wake_up", {}) + + async def reset_prefix_cache( + self, + reset_running_requests: bool = False, + ) -> Dict[str, Any]: + """ + Reset KV cache on all backends. + + Args: + reset_running_requests: Whether to reset running requests. + + Returns: + Dict mapping server_url to response. 
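
        Example of the fan-out result shape (editor's sketch; the URLs and
        response bodies are illustrative):

            results = await client.reset_prefix_cache()
            # {
            #     "http://backend-0:8000": {"status": 200, "body": {...}},
            #     "http://backend-1:8000": {"status": 200, "body": {...}},
            # }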
+ """ + return await self._call_all_servers("/reset_prefix_cache", { + "reset_running_requests": reset_running_requests + }) + + # --------------------------- + # Weight Sync (control plane - fan-out) + # --------------------------- + + async def init_weight_transfer( + self, + init_info: "BroadcastInitInfo", + ) -> Dict[str, Any]: + """ + Initialize weight sync process group on all backends. + + Args: + init_info: BroadcastInitInfo containing all args for weight sync setup. + + Returns: + Dict mapping server_url to response. + """ + from dataclasses import asdict + return await self._call_all_servers("/init_weight_transfer", asdict(init_info)) + + async def update_weights( + self, + request: "BroadcastWeightUpdateRequest", + ) -> Dict[str, Any]: + """ + Update weights on all backends. + + Args: + request: BroadcastWeightUpdateRequest containing weight metadata. + + Returns: + Dict mapping server_url to response. + """ + from dataclasses import asdict + return await self._call_all_servers("/update_weights", asdict(request)) + + async def finalize_weight_update(self) -> Dict[str, Any]: + """ + Finalize weight update on all backends. + + Called after all update_weights() calls are complete. + Reserved for any post-processing steps that may be needed: + - Cache invalidation + - State synchronization + - Future vLLM requirements + + Returns: + Dict mapping server_url to response. + """ + return await self._call_all_servers("/finalize_weight_update", {}) + + # --------------------------- + # Info + # --------------------------- + + async def get_world_size(self) -> int: + """ + Get total world size across all inference workers. + + Fetches from /get_server_info on each server and sums the world_size values. + Result is cached after first call. + """ + if self._world_size is not None: + return self._world_size + + session = await self._get_session() + total_world_size = 0 + + for server_url in self.server_urls: + try: + url = f"{server_url}/get_server_info" + async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp: + resp.raise_for_status() + info = await resp.json() + total_world_size += info.get("world_size", 1) + except Exception as e: + logger.warning(f"Failed to fetch server info from {server_url}: {e}") + raise + + self._world_size = total_world_size + return self._world_size + + # --------------------------- + # Lifecycle + # --------------------------- + + async def teardown(self) -> None: + """Close HTTP session.""" + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + async def __aenter__(self) -> "RemoteInferenceClient": + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Async context manager exit.""" + await self.teardown() + + # --------------------------- + # Serialization + # --------------------------- + + def __getstate__(self) -> dict: + """Exclude non-serializable fields from pickle.""" + state = self.__dict__.copy() + state["_session"] = None + return state + + def __setstate__(self, state: dict) -> None: + """Restore state after unpickling.""" + self.__dict__.update(state) + self._session = None + + # --------------------------- + # Compatibility helpers + # --------------------------- + + @classmethod + def from_server_group( + cls, + server_urls: List[str], + router_url: str, + model_name: str = "default", + ) -> "RemoteInferenceClient": + """ + Create client from server group URLs. 
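
        A minimal construction sketch (editorial; the URLs are placeholders):

            client = RemoteInferenceClient.from_server_group(
                server_urls=["http://backend-0:8000", "http://backend-1:8000"],
                router_url="http://router:8080",
                model_name="my-model",
            )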
+ + Args: + server_urls: List of backend server URLs. + router_url: Router URL for data plane. + model_name: Model name for API calls. + + Returns: + Configured RemoteInferenceClient. + """ + return cls( + proxy_url=router_url, + server_urls=server_urls, + model_name=model_name, + ) + + @classmethod + def from_router( + cls, + router_url: str, + model_name: str = "default", + ) -> "RemoteInferenceClient": + """ + Create client for a full-feature router (handles both data and control plane). + + This is for routers like InferenceRouter that support all endpoints. + Control plane calls go to the router which fans out to backends. + + Args: + router_url: Router URL. + model_name: Model name for API calls. + + Returns: + Configured RemoteInferenceClient. + """ + return cls( + proxy_url=router_url, + server_urls=[router_url], + model_name=model_name, + ) diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py index 4c66f5af2..466b166e9 100644 --- a/skyrl-train/skyrl_train/inference_servers/router.py +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -26,6 +26,7 @@ # BUILT-IN ROUTES "/pause", "/resume", + "/is_paused", "/sleep", "/wake_up", "/reset_prefix_cache", diff --git a/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py b/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py new file mode 100644 index 000000000..6539f6d2f --- /dev/null +++ b/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py @@ -0,0 +1,633 @@ +"""Tests for RemoteInferenceClient.""" + +import asyncio +import pickle +import threading +import time +from typing import Any, Dict, List + +import httpx +import pytest +import uvicorn +from fastapi import FastAPI, Request + +from skyrl_train.inference_servers.common import get_open_port +from skyrl_train.inference_servers.remote_inference_client import RemoteInferenceClient, PauseMode + + +def create_mock_vllm_server(server_id: int) -> FastAPI: + """Create a mock vLLM server with standard endpoints.""" + app = FastAPI() + + @app.get("/health") + async def health(): + return {"status": "ok"} + + @app.get("/get_server_info") + async def get_server_info(): + return { + "ip": "127.0.0.1", + "port": 8000 + server_id, + "url": f"http://127.0.0.1:{8000 + server_id}", + "server_idx": server_id, + "world_size": 2, # Simulate TP=2 + } + + @app.post("/v1/completions") + async def completions(request: Request): + body = await request.json() + prompts = body.get("prompt", []) + n_prompts = len(prompts) if isinstance(prompts, list) else 1 + return { + "choices": [ + {"index": i, "text": f"Response {i} from server {server_id}", "finish_reason": "stop"} + for i in range(n_prompts) + ] + } + + @app.post("/v1/chat/completions") + async def chat_completions(request: Request): + return {"choices": [{"message": {"content": f"Chat from server {server_id}"}}]} + + @app.post("/tokenize") + async def tokenize(request: Request): + return {"tokens": [1, 2, 3]} + + @app.post("/detokenize") + async def detokenize(request: Request): + return {"prompt": "hello world"} + + # Control plane endpoints + @app.post("/pause") + async def pause(request: Request): + return {"status": "paused", "server_id": server_id} + + @app.post("/resume") + async def resume(): + return {"status": "resumed", "server_id": server_id} + + @app.get("/is_paused") + async def is_paused(): + # Mock always returns not paused for basic tests + return {"is_paused": False} + + @app.post("/sleep") + async def 
sleep(request: Request): + return {"status": "sleeping", "server_id": server_id} + + @app.post("/wake_up") + async def wake_up(): + return {"status": "awake", "server_id": server_id} + + @app.post("/reset_prefix_cache") + async def reset_prefix_cache(request: Request): + return {"status": "cache_reset", "server_id": server_id} + + @app.post("/init_weight_transfer") + async def init_weight_transfer(request: Request): + return {"status": "ok", "server_id": server_id} + + @app.post("/update_weights") + async def update_weights(request: Request): + return {"status": "ok", "server_id": server_id} + + @app.post("/finalize_weight_update") + async def finalize_weight_update(request: Request): + return {"status": "ok", "server_id": server_id} + + return app + + +def start_server(port: int, server_id: int) -> uvicorn.Server: + """Start a mock server, return the server instance.""" + app = create_mock_vllm_server(server_id) + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") + server = uvicorn.Server(config) + + def run(): + asyncio.run(server.serve()) + + threading.Thread(target=run, daemon=True).start() + return server + + +def wait_ready(url: str, timeout: float = 5.0) -> bool: + """Wait for server to become healthy.""" + start = time.time() + while time.time() - start < timeout: + try: + if httpx.get(f"{url}/health", timeout=1.0).status_code == 200: + return True + except httpx.RequestError: + time.sleep(0.1) + return False + + +@pytest.fixture(scope="module") +def mock_servers(): + """Start mock vLLM servers.""" + servers: List[uvicorn.Server] = [] + ports = [get_open_port(), get_open_port()] + urls = [f"http://127.0.0.1:{p}" for p in ports] + + for i, port in enumerate(ports): + servers.append(start_server(port, server_id=i)) + + for url in urls: + assert wait_ready(url), f"Server {url} failed to start" + + yield urls + + # Cleanup + for server in servers: + server.should_exit = True + time.sleep(0.3) + + +class TestRemoteInferenceClientInit: + """Test client initialization and serialization.""" + + def test_init(self, mock_servers): + """Client initializes with proxy_url and server_urls.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + assert client.proxy_url == mock_servers[0] + assert client.server_urls == mock_servers + + def test_serialization(self, mock_servers): + """Client can be pickled and unpickled.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + model_name="test-model", + ) + + # Pickle and unpickle + pickled = pickle.dumps(client) + restored = pickle.loads(pickled) + + assert restored.proxy_url == client.proxy_url + assert restored.server_urls == client.server_urls + assert restored.model_name == client.model_name + # Session should be None after unpickling + assert restored._session is None + + def test_from_server_group(self, mock_servers): + """Test factory method from_server_group.""" + client = RemoteInferenceClient.from_server_group( + server_urls=mock_servers, + model_name="test-model", + ) + assert client.proxy_url == mock_servers[0] # Defaults to first server + assert client.server_urls == mock_servers + + def test_from_server_group_with_router(self, mock_servers): + """Test factory method from_server_group with router URL.""" + router_url = "http://router:8080" + client = RemoteInferenceClient.from_server_group( + server_urls=mock_servers, + router_url=router_url, + model_name="test-model", + ) + assert client.proxy_url == router_url + assert 
client.server_urls == mock_servers + + def test_from_router(self, mock_servers): + """Test factory method from_router.""" + router_url = mock_servers[0] + client = RemoteInferenceClient.from_router( + router_url=router_url, + model_name="test-model", + ) + assert client.proxy_url == router_url + assert client.server_urls == [router_url] + + +class TestDataPlane: + """Test data plane methods.""" + + @pytest.mark.asyncio + async def test_generate(self, mock_servers): + """Test generate method.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + input_batch = { + "prompt_token_ids": [[1, 2, 3], [4, 5, 6]], + "sampling_params": {"max_tokens": 100}, + } + result = await client.generate(input_batch) + + assert "responses" in result + assert "stop_reasons" in result + assert len(result["responses"]) == 2 + assert all(r == "stop" for r in result["stop_reasons"]) + # response_ids are empty (use tokenize() if needed) + assert all(len(ids) == 0 for ids in result["response_ids"]) + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_generate_with_session_id(self, mock_servers): + """Test generate with session ID for consistent routing.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + input_batch = { + "prompt_token_ids": [[1, 2, 3]], + "session_ids": ["test-session"], + } + result = await client.generate(input_batch) + assert len(result["responses"]) == 1 + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_chat_completion(self, mock_servers): + """Test chat completion method.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + request_payload = { + "json": { + "model": "test", + "messages": [{"role": "user", "content": "Hello"}], + }, + "headers": {}, + } + result = await client.chat_completion(request_payload) + assert "choices" in result + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_completion(self, mock_servers): + """Test completion method.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + request_payload = { + "json": {"model": "test", "prompt": "Hello"}, + "headers": {}, + } + result = await client.completion(request_payload) + assert "choices" in result + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_tokenize(self, mock_servers): + """Test tokenize method.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + result = await client.tokenize(["hello", "world"]) + assert len(result) == 2 + assert result[0] == [1, 2, 3] # Mock response + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_detokenize(self, mock_servers): + """Test detokenize method.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + result = await client.detokenize([[1, 2, 3], [4, 5, 6]]) + assert len(result) == 2 + assert result[0] == "hello world" # Mock response + finally: + await client.teardown() + + +class TestControlPlane: + """Test control plane methods (fan-out to all servers).""" + + @pytest.mark.asyncio + async def test_pause_abort_mode(self, mock_servers): + """Test pause with ABORT mode (default) fans out to all servers.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + result = await 
client.pause(mode=PauseMode.ABORT) + assert len(result) == 2 + for url, response in result.items(): + assert response["status"] == 200 + assert response["body"]["status"] == "paused" + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_pause_finish_mode(self, mock_servers): + """Test pause with FINISH mode fans out to all servers.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + result = await client.pause(mode=PauseMode.FINISH) + assert len(result) == 2 + for url, response in result.items(): + assert response["status"] == 200 + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_pause_keep_mode_not_supported(self, mock_servers): + """Test pause with KEEP mode raises NotImplementedError.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + with pytest.raises(NotImplementedError, match="KEEP is not yet supported"): + await client.pause(mode=PauseMode.KEEP) + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_resume(self, mock_servers): + """Test resume fans out to all servers.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + # Pause first + await client.pause() + + # Resume + result = await client.resume() + assert len(result) == 2 + for url, response in result.items(): + assert response["status"] == 200 + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_sleep(self, mock_servers): + """Test sleep fans out to all servers.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + result = await client.sleep(level=2) + assert len(result) == 2 + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_wake_up(self, mock_servers): + """Test wake_up fans out to all servers.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + result = await client.wake_up() + assert len(result) == 2 + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_reset_prefix_cache(self, mock_servers): + """Test reset_prefix_cache fans out to all servers.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + result = await client.reset_prefix_cache() + assert len(result) == 2 + finally: + await client.teardown() + + +class TestWeightSync: + """Test weight sync methods.""" + + @pytest.mark.asyncio + async def test_init_weight_transfer(self, mock_servers): + """Test init_weight_transfer fans out to all servers.""" + from skyrl_train.weight_sync import BroadcastInitInfo + + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + init_info = BroadcastInitInfo( + master_addr="127.0.0.1", + master_port=29500, + rank_offset=1, + world_size=5, + group_name="test", + backend="nccl", + model_dtype_str="torch.bfloat16", + ) + result = await client.init_weight_transfer(init_info) + assert len(result) == 2 + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_update_weights(self, mock_servers): + """Test update_weights fans out to all servers.""" + from skyrl_train.weight_sync import BroadcastWeightUpdateRequest + + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + request = BroadcastWeightUpdateRequest( + names=["layer.weight"], + 
dtypes=["torch.bfloat16"], + shapes=[[1024, 1024]], + ) + result = await client.update_weights(request) + assert len(result) == 2 + finally: + await client.teardown() + + @pytest.mark.asyncio + async def test_finalize_weight_update(self, mock_servers): + """Test finalize_weight_update fans out to all servers.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + result = await client.finalize_weight_update() + assert len(result) == 2 + finally: + await client.teardown() + + +class TestServerInfo: + """Test server info and world_size.""" + + @pytest.mark.asyncio + async def test_get_world_size(self, mock_servers): + """Test world_size fetching and caching.""" + client = RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) + + try: + # First call fetches from all servers and sums + world_size = await client.get_world_size() + # Each mock server reports world_size=2, we have 2 servers = 4 + assert world_size == 4 + + # Second call returns cached value + world_size2 = await client.get_world_size() + assert world_size2 == 4 + finally: + await client.teardown() + + +class TestContextManager: + """Test async context manager.""" + + @pytest.mark.asyncio + async def test_async_context_manager(self, mock_servers): + """Test using client as async context manager.""" + async with RemoteInferenceClient( + proxy_url=mock_servers[0], + server_urls=mock_servers, + ) as client: + result = await client.resume() + assert len(result) == 2 + + # Session should be closed after exiting context + assert client._session is None or client._session.closed + + +class TestRetryOnAbort: + """Test retry on abort functionality.""" + + @pytest.fixture + def abort_mock_server(self): + """Create a mock server that returns abort on first call, then stop.""" + app = FastAPI() + call_count = {"completions": 0} + + @app.get("/health") + async def health(): + return {"status": "ok"} + + @app.post("/v1/completions") + async def completions(request: Request): + call_count["completions"] += 1 + body = await request.json() + + # First call returns abort with partial response + if call_count["completions"] == 1: + return { + "choices": [ + {"index": 0, "text": "Partial ", "finish_reason": "abort"} + ] + } + # Second call returns complete response + else: + return { + "choices": [ + {"index": 0, "text": "response complete", "finish_reason": "stop"} + ] + } + + @app.post("/tokenize") + async def tokenize(request: Request): + body = await request.json() + prompt = body.get("prompt", "") + # Simple tokenization: one token per word + tokens = [hash(word) % 10000 for word in prompt.split()] + return {"tokens": tokens} + + @app.get("/get_server_info") + async def get_server_info(): + return {"world_size": 1} + + @app.get("/is_paused") + async def is_paused(): + # Not paused - allows retry to proceed immediately + return {"is_paused": False} + + # Start server in background thread + port = get_open_port() + config = uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="warning") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + # Wait for server to be ready + for _ in range(100): + try: + httpx.get(f"http://127.0.0.1:{port}/health", timeout=0.1) + break + except Exception: + time.sleep(0.05) + + yield f"http://127.0.0.1:{port}", call_count + + server.should_exit = True + thread.join(timeout=1) + + @pytest.mark.asyncio + async def test_retry_on_abort(self, abort_mock_server): + 
"""Test that retry on abort is always active (built-in behavior).""" + url, call_count = abort_mock_server + client = RemoteInferenceClient( + proxy_url=url, + server_urls=[url], + ) + + try: + result = await client.generate({ + "prompt_token_ids": [[1, 2, 3]], + "sampling_params": {"max_tokens": 100}, + }) + + # Should get complete response after retry + assert result["stop_reasons"][0] == "stop" + assert result["responses"][0] == "Partial response complete" + assert call_count["completions"] == 2 + # Should have response_ids from tokenization + assert len(result["response_ids"][0]) > 0 + finally: + await client.teardown() From bdd1d8ab31520ef9857b133cc0a1406d785fbc74 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Tue, 20 Jan 2026 21:25:15 +0000 Subject: [PATCH 19/20] wip Signed-off-by: Kourosh Hakhamaneshi --- .../skyrl_train/inference_engines/base.py | 32 ++++- .../remote_inference_client.py | 115 +++++------------ .../skyrl_train/inference_servers/router.py | 5 + .../test_remote_inference_client.py | 118 ++++++++---------- 4 files changed, 119 insertions(+), 151 deletions(-) diff --git a/skyrl-train/skyrl_train/inference_engines/base.py b/skyrl-train/skyrl_train/inference_engines/base.py index 392e2100e..ddb893182 100644 --- a/skyrl-train/skyrl_train/inference_engines/base.py +++ b/skyrl-train/skyrl_train/inference_engines/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Dict, TypedDict, Any, Optional, Hashable, TYPE_CHECKING +from typing import List, Dict, TypedDict, Any, Optional, Hashable, Protocol, runtime_checkable, TYPE_CHECKING if TYPE_CHECKING: from skyrl_train.weight_sync.transfer_strategy import WeightSyncInitInfo @@ -31,6 +31,36 @@ class InferenceEngineOutput(TypedDict): response_logprobs: Optional[List[List[float]]] +@runtime_checkable +class InferenceClientProtocol(Protocol): + """ + Structural protocol for inference clients. + + Both InferenceEngineInterface and RemoteInferenceClient satisfy this protocol, + enabling code to work with either implementation via duck typing. + + This is the minimal common interface for: + - Data plane: generate() + - Lifecycle: sleep(), wake_up(), teardown() + """ + + async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: + """Generate completions for a batch of inputs.""" + ... + + async def sleep(self, *args: Any, **kwargs: Any) -> Any: + """Put the engine into sleep/low-power mode.""" + ... + + async def wake_up(self, *args: Any, **kwargs: Any) -> Any: + """Wake the engine from sleep mode.""" + ... + + async def teardown(self) -> None: + """Clean up resources.""" + ... 
+ + class InferenceEngineInterface(ABC): @abstractmethod diff --git a/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py b/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py index 652d6b94c..3868adfc8 100644 --- a/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py @@ -53,12 +53,15 @@ import logging from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING +from dataclasses import asdict + import aiohttp +from skyrl_train.inference_engines.base import InferenceEngineInput, InferenceEngineOutput + if TYPE_CHECKING: - from skyrl_train.inference_engines.base import InferenceEngineInput, InferenceEngineOutput from skyrl_train.weight_sync import BroadcastInitInfo, BroadcastWeightUpdateRequest logger = logging.getLogger(__name__) @@ -79,15 +82,10 @@ class PauseMode(Enum): FINISH: Wait for in-flight requests to complete before pausing. New requests are blocked. No retry needed. Maps to: wait_for_inflight_request=True - - KEEP: (Future - vLLM RFC #32103) Preserve KV cache and scheduler state. - Requests resume seamlessly after unpause. Zero client changes needed. - NOT YET SUPPORTED - raises NotImplementedError. """ ABORT = "abort" FINISH = "finish" - KEEP = "keep" @dataclass @@ -142,8 +140,8 @@ async def _get_session(self) -> aiohttp.ClientSession: async def generate( self, - input_batch: "InferenceEngineInput", - ) -> "InferenceEngineOutput": + input_batch: InferenceEngineInput, + ) -> InferenceEngineOutput: """ Generate completions via /v1/completions. @@ -158,12 +156,7 @@ async def generate( Returns: InferenceEngineOutput with responses, response_ids, and stop_reasons. - - Note: - If retry_on_abort=True and a request is aborted (due to pause), - the client will retry with accumulated tokens until completion. """ - from skyrl_train.inference_engines.base import InferenceEngineOutput prompt_token_ids = input_batch.get("prompt_token_ids") if prompt_token_ids is None: @@ -235,9 +228,6 @@ async def _generate_single( stop_reason = "abort" while stop_reason == "abort": - # Wait if generation is paused - await self._wait_for_resume() - # Build payload with accumulated context cur_params = sampling_params.copy() if original_max_tokens is not None and max_key: @@ -269,7 +259,10 @@ async def _generate_single( accum_text += new_text # Tokenize the new text to get token IDs for next iteration if stop_reason == "abort" and new_text: - new_token_ids = (await self.tokenize([new_text], add_special_tokens=False))[0] + new_token_ids = ( + await self.tokenize( + [new_text], add_special_tokens=False) + )[0] accum_token_ids.extend(new_token_ids) # Final response @@ -282,43 +275,6 @@ async def _generate_single( "response_ids": final_token_ids, } - async def _wait_for_resume(self, poll_interval: float = 0.1) -> None: - """ - Wait until ALL servers return is_paused=False. - - Sampling can only continue when every server has resumed. - - Args: - poll_interval: Seconds between polling attempts. 
- """ - session = await self._get_session() - - while True: - # ALL servers must return is_paused=False for sampling to continue - all_resumed = True - for url in self.server_urls: - try: - async with session.get(f"{url}/is_paused") as resp: - if resp.status == 200: - result = await resp.json() - # Default to paused=True for safety if key missing - if result.get("is_paused", True): - all_resumed = False - break - else: - # Non-200 response, assume still paused - all_resumed = False - break - except Exception: - # If we can't reach the server, assume it's still paused - all_resumed = False - break - - if all_resumed: - return - - await asyncio.sleep(poll_interval) - async def chat_completion( self, request_payload: Dict[str, Any], @@ -479,33 +435,31 @@ async def call_server(server_url: str) -> tuple: results = await asyncio.gather(*[call_server(url) for url in self.server_urls]) return {url: resp for url, resp in results} - async def pause(self, mode: PauseMode = PauseMode.ABORT) -> Dict[str, Any]: + async def pause(self, mode: Union[PauseMode, str] = PauseMode.ABORT) -> Dict[str, Any]: """ Pause generation on all backends. Args: mode: Pause mode determining how in-flight requests are handled. - - ABORT: Abort in-flight requests immediately. Clients receive - partial tokens and must retry with accumulated context. - - FINISH: Wait for in-flight requests to complete before pausing. - New requests are blocked. No retry needed. - - KEEP: (Future) Preserve KV cache and scheduler state. - NOT YET SUPPORTED. + Can be a PauseMode enum or string ("abort", "finish"). + - ABORT / "abort": Abort in-flight requests immediately. Clients + receive partial tokens and must retry with accumulated context. + New requests are blocked. + - FINISH / "finish": Wait for in-flight requests to complete before + pausing. New requests are blocked. No retry needed. Returns: Dict mapping server_url to response. - Note: + TODO: When vLLM RFC #32103 lands, we'll use the native mode parameter. For now, we map modes to wait_for_inflight_request: - ABORT → wait_for_inflight_request=False - FINISH → wait_for_inflight_request=True """ - if mode == PauseMode.KEEP: - raise NotImplementedError( - "PauseMode.KEEP is not yet supported. " - "Waiting for vLLM RFC #32103 to land." - ) + # Convert string to PauseMode if needed + if isinstance(mode, str): + mode = PauseMode(mode.lower()) wait_for_inflight_request = mode == PauseMode.FINISH @@ -567,7 +521,6 @@ async def init_weight_transfer( Returns: Dict mapping server_url to response. """ - from dataclasses import asdict return await self._call_all_servers("/init_weight_transfer", asdict(init_info)) async def update_weights( @@ -583,7 +536,6 @@ async def update_weights( Returns: Dict mapping server_url to response. 
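
        Example (editor's sketch; the field values mirror those used in the
        unit tests and are illustrative):

            from skyrl_train.weight_sync import BroadcastWeightUpdateRequest

            request = BroadcastWeightUpdateRequest(
                names=["layer.weight"],
                dtypes=["torch.bfloat16"],
                shapes=[[1024, 1024]],
            )
            results = await client.update_weights(request)
            # results maps each server_url to {"status": ..., "body": ...}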
""" - from dataclasses import asdict return await self._call_all_servers("/update_weights", asdict(request)) async def finalize_weight_update(self) -> Dict[str, Any]: @@ -615,19 +567,18 @@ async def get_world_size(self) -> int: if self._world_size is not None: return self._world_size - session = await self._get_session() - total_world_size = 0 + results = await self._call_all_servers("/get_server_info", {}, method="GET") - for server_url in self.server_urls: - try: - url = f"{server_url}/get_server_info" - async with session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp: - resp.raise_for_status() - info = await resp.json() - total_world_size += info.get("world_size", 1) - except Exception as e: - logger.warning(f"Failed to fetch server info from {server_url}: {e}") - raise + total_world_size = 0 + for server_url, resp in results.items(): + if resp.get("status") != 200: + error = resp.get("error", resp.get("body")) + raise RuntimeError(f"Failed to fetch server info from {server_url}: {error}") + body = resp.get("body", {}) + world_size = body.get("world_size") + if world_size is None: + raise RuntimeError(f"Failed to fetch server info from {server_url}: world_size is missing") + total_world_size += world_size self._world_size = total_world_size return self._world_size diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py index 466b166e9..d8863d21d 100644 --- a/skyrl-train/skyrl_train/inference_servers/router.py +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -119,6 +119,11 @@ def _build_app(self) -> FastAPI: openapi_url=None, ) + @app.get("/health") + async def health(): + """Router health check (doesn't proxy to backends).""" + return {"status": "healthy"} + @app.get("/servers") async def list_servers(): """Return list of server URLs.""" diff --git a/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py b/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py index 6539f6d2f..7f68a2177 100644 --- a/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py +++ b/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py @@ -101,7 +101,7 @@ async def finalize_weight_update(request: Request): def start_server(port: int, server_id: int) -> uvicorn.Server: """Start a mock server, return the server instance.""" app = create_mock_vllm_server(server_id) - config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") + config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="error") server = uvicorn.Server(config) def run(): @@ -125,18 +125,19 @@ def wait_ready(url: str, timeout: float = 5.0) -> bool: @pytest.fixture(scope="module") def mock_servers(): - """Start mock vLLM servers.""" + """Start mock vLLM servers, return proxy_url and server_urls.""" servers: List[uvicorn.Server] = [] ports = [get_open_port(), get_open_port()] - urls = [f"http://127.0.0.1:{p}" for p in ports] + server_urls = [f"http://127.0.0.1:{p}" for p in ports] for i, port in enumerate(ports): servers.append(start_server(port, server_id=i)) - for url in urls: + for url in server_urls: assert wait_ready(url), f"Server {url} failed to start" - yield urls + # proxy_url defaults to first server; can be replaced with router URL later + yield {"proxy_url": server_urls[0], "server_urls": server_urls} # Cleanup for server in servers: @@ -147,20 +148,11 @@ def mock_servers(): class TestRemoteInferenceClientInit: """Test client initialization and serialization.""" 
- def test_init(self, mock_servers): - """Client initializes with proxy_url and server_urls.""" - client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, - ) - assert client.proxy_url == mock_servers[0] - assert client.server_urls == mock_servers - def test_serialization(self, mock_servers): """Client can be pickled and unpickled.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], model_name="test-model", ) @@ -177,26 +169,16 @@ def test_serialization(self, mock_servers): def test_from_server_group(self, mock_servers): """Test factory method from_server_group.""" client = RemoteInferenceClient.from_server_group( - server_urls=mock_servers, + server_urls=mock_servers["server_urls"], + router_url=mock_servers["proxy_url"], model_name="test-model", ) - assert client.proxy_url == mock_servers[0] # Defaults to first server - assert client.server_urls == mock_servers - - def test_from_server_group_with_router(self, mock_servers): - """Test factory method from_server_group with router URL.""" - router_url = "http://router:8080" - client = RemoteInferenceClient.from_server_group( - server_urls=mock_servers, - router_url=router_url, - model_name="test-model", - ) - assert client.proxy_url == router_url - assert client.server_urls == mock_servers + assert client.proxy_url == mock_servers["proxy_url"] + assert client.server_urls == mock_servers["server_urls"] def test_from_router(self, mock_servers): """Test factory method from_router.""" - router_url = mock_servers[0] + router_url = mock_servers["proxy_url"] client = RemoteInferenceClient.from_router( router_url=router_url, model_name="test-model", @@ -212,8 +194,8 @@ class TestDataPlane: async def test_generate(self, mock_servers): """Test generate method.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -236,8 +218,8 @@ async def test_generate(self, mock_servers): async def test_generate_with_session_id(self, mock_servers): """Test generate with session ID for consistent routing.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -254,8 +236,8 @@ async def test_generate_with_session_id(self, mock_servers): async def test_chat_completion(self, mock_servers): """Test chat completion method.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -275,8 +257,8 @@ async def test_chat_completion(self, mock_servers): async def test_completion(self, mock_servers): """Test completion method.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -293,8 +275,8 @@ async def test_completion(self, mock_servers): async def test_tokenize(self, mock_servers): """Test tokenize method.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -308,8 +290,8 @@ async def test_tokenize(self, mock_servers): async def test_detokenize(self, mock_servers): """Test detokenize 
method.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -327,8 +309,8 @@ class TestControlPlane: async def test_pause_abort_mode(self, mock_servers): """Test pause with ABORT mode (default) fans out to all servers.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -344,8 +326,8 @@ async def test_pause_abort_mode(self, mock_servers): async def test_pause_finish_mode(self, mock_servers): """Test pause with FINISH mode fans out to all servers.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -360,8 +342,8 @@ async def test_pause_finish_mode(self, mock_servers): async def test_pause_keep_mode_not_supported(self, mock_servers): """Test pause with KEEP mode raises NotImplementedError.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -374,8 +356,8 @@ async def test_pause_keep_mode_not_supported(self, mock_servers): async def test_resume(self, mock_servers): """Test resume fans out to all servers.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -394,8 +376,8 @@ async def test_resume(self, mock_servers): async def test_sleep(self, mock_servers): """Test sleep fans out to all servers.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -408,8 +390,8 @@ async def test_sleep(self, mock_servers): async def test_wake_up(self, mock_servers): """Test wake_up fans out to all servers.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -422,8 +404,8 @@ async def test_wake_up(self, mock_servers): async def test_reset_prefix_cache(self, mock_servers): """Test reset_prefix_cache fans out to all servers.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -442,8 +424,8 @@ async def test_init_weight_transfer(self, mock_servers): from skyrl_train.weight_sync import BroadcastInitInfo client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -467,8 +449,8 @@ async def test_update_weights(self, mock_servers): from skyrl_train.weight_sync import BroadcastWeightUpdateRequest client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -486,8 +468,8 @@ async def test_update_weights(self, mock_servers): async def test_finalize_weight_update(self, mock_servers): """Test finalize_weight_update fans out to all servers.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + 
proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -504,8 +486,8 @@ class TestServerInfo: async def test_get_world_size(self, mock_servers): """Test world_size fetching and caching.""" client = RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) try: @@ -528,8 +510,8 @@ class TestContextManager: async def test_async_context_manager(self, mock_servers): """Test using client as async context manager.""" async with RemoteInferenceClient( - proxy_url=mock_servers[0], - server_urls=mock_servers, + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], ) as client: result = await client.resume() assert len(result) == 2 @@ -590,7 +572,7 @@ async def is_paused(): # Start server in background thread port = get_open_port() - config = uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="warning") + config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level="warning") server = uvicorn.Server(config) thread = threading.Thread(target=server.run, daemon=True) thread.start() From 9bf417350735e4d462fed6258cdb38938aa2b585 Mon Sep 17 00:00:00 2001 From: Kourosh Hakhamaneshi Date: Tue, 20 Jan 2026 21:50:41 +0000 Subject: [PATCH 20/20] wip Signed-off-by: Kourosh Hakhamaneshi --- .../remote_inference_client.py | 17 +- .../test_remote_inference_client.py | 390 ++++++------------ 2 files changed, 133 insertions(+), 274 deletions(-) diff --git a/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py b/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py index 3868adfc8..ee06c72df 100644 --- a/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py +++ b/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py @@ -164,9 +164,7 @@ async def generate( sampling_params = input_batch.get("sampling_params") or {} if sampling_params.get("n", 1) > 1: - raise ValueError( - "n > 1 is not supported. Use `config.generator.n_samples_per_prompt` instead." - ) + raise ValueError("n > 1 is not supported. Use `config.generator.n_samples_per_prompt` instead.") session_ids = input_batch.get("session_ids") @@ -259,10 +257,7 @@ async def _generate_single( accum_text += new_text # Tokenize the new text to get token IDs for next iteration if stop_reason == "abort" and new_text: - new_token_ids = ( - await self.tokenize( - [new_text], add_special_tokens=False) - )[0] + new_token_ids = (await self.tokenize([new_text], add_special_tokens=False))[0] accum_token_ids.extend(new_token_ids) # Final response @@ -463,9 +458,7 @@ async def pause(self, mode: Union[PauseMode, str] = PauseMode.ABORT) -> Dict[str wait_for_inflight_request = mode == PauseMode.FINISH - return await self._call_all_servers("/pause", { - "wait_for_inflight_request": wait_for_inflight_request - }) + return await self._call_all_servers("/pause", {"wait_for_inflight_request": wait_for_inflight_request}) async def resume(self) -> Dict[str, Any]: """Resume generation on all backends.""" @@ -500,9 +493,7 @@ async def reset_prefix_cache( Returns: Dict mapping server_url to response. 
""" - return await self._call_all_servers("/reset_prefix_cache", { - "reset_running_requests": reset_running_requests - }) + return await self._call_all_servers("/reset_prefix_cache", {"reset_running_requests": reset_running_requests}) # --------------------------- # Weight Sync (control plane - fan-out) diff --git a/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py b/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py index 7f68a2177..aff303560 100644 --- a/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py +++ b/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py @@ -4,10 +4,11 @@ import pickle import threading import time -from typing import Any, Dict, List +from typing import List import httpx import pytest +import pytest_asyncio import uvicorn from fastapi import FastAPI, Request @@ -145,6 +146,17 @@ def mock_servers(): time.sleep(0.3) +@pytest_asyncio.fixture +async def client(mock_servers): + """Create a RemoteInferenceClient for data/control plane tests.""" + client = RemoteInferenceClient( + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], + ) + yield client + await client.teardown() + + class TestRemoteInferenceClientInit: """Test client initialization and serialization.""" @@ -191,316 +203,175 @@ class TestDataPlane: """Test data plane methods.""" @pytest.mark.asyncio - async def test_generate(self, mock_servers): + async def test_generate(self, client): """Test generate method.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) + input_batch = { + "prompt_token_ids": [[1, 2, 3], [4, 5, 6]], + "sampling_params": {"max_tokens": 100}, + } + result = await client.generate(input_batch) - try: - input_batch = { - "prompt_token_ids": [[1, 2, 3], [4, 5, 6]], - "sampling_params": {"max_tokens": 100}, - } - result = await client.generate(input_batch) - - assert "responses" in result - assert "stop_reasons" in result - assert len(result["responses"]) == 2 - assert all(r == "stop" for r in result["stop_reasons"]) - # response_ids are empty (use tokenize() if needed) - assert all(len(ids) == 0 for ids in result["response_ids"]) - finally: - await client.teardown() + assert "responses" in result + assert "stop_reasons" in result + assert len(result["responses"]) == 2 + assert all(r == "stop" for r in result["stop_reasons"]) + # response_ids are tokenized from the response + assert len(result["response_ids"]) == 2 @pytest.mark.asyncio - async def test_generate_with_session_id(self, mock_servers): + async def test_generate_with_session_id(self, client): """Test generate with session ID for consistent routing.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - input_batch = { - "prompt_token_ids": [[1, 2, 3]], - "session_ids": ["test-session"], - } - result = await client.generate(input_batch) - assert len(result["responses"]) == 1 - finally: - await client.teardown() + input_batch = { + "prompt_token_ids": [[1, 2, 3]], + "session_ids": ["test-session"], + } + result = await client.generate(input_batch) + assert len(result["responses"]) == 1 @pytest.mark.asyncio - async def test_chat_completion(self, mock_servers): + async def test_chat_completion(self, client): """Test chat completion method.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - 
request_payload = { - "json": { - "model": "test", - "messages": [{"role": "user", "content": "Hello"}], - }, - "headers": {}, - } - result = await client.chat_completion(request_payload) - assert "choices" in result - finally: - await client.teardown() + request_payload = { + "json": { + "model": "test", + "messages": [{"role": "user", "content": "Hello"}], + }, + "headers": {}, + } + result = await client.chat_completion(request_payload) + assert "choices" in result @pytest.mark.asyncio - async def test_completion(self, mock_servers): + async def test_completion(self, client): """Test completion method.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - request_payload = { - "json": {"model": "test", "prompt": "Hello"}, - "headers": {}, - } - result = await client.completion(request_payload) - assert "choices" in result - finally: - await client.teardown() + request_payload = { + "json": {"model": "test", "prompt": "Hello"}, + "headers": {}, + } + result = await client.completion(request_payload) + assert "choices" in result @pytest.mark.asyncio - async def test_tokenize(self, mock_servers): + async def test_tokenize(self, client): """Test tokenize method.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - result = await client.tokenize(["hello", "world"]) - assert len(result) == 2 - assert result[0] == [1, 2, 3] # Mock response - finally: - await client.teardown() + result = await client.tokenize(["hello", "world"]) + assert len(result) == 2 + assert result[0] == [1, 2, 3] # Mock response @pytest.mark.asyncio - async def test_detokenize(self, mock_servers): + async def test_detokenize(self, client): """Test detokenize method.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - result = await client.detokenize([[1, 2, 3], [4, 5, 6]]) - assert len(result) == 2 - assert result[0] == "hello world" # Mock response - finally: - await client.teardown() + result = await client.detokenize([[1, 2, 3], [4, 5, 6]]) + assert len(result) == 2 + assert result[0] == "hello world" # Mock response class TestControlPlane: """Test control plane methods (fan-out to all servers).""" @pytest.mark.asyncio - async def test_pause_abort_mode(self, mock_servers): + async def test_pause_abort_mode(self, client): """Test pause with ABORT mode (default) fans out to all servers.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - result = await client.pause(mode=PauseMode.ABORT) - assert len(result) == 2 - for url, response in result.items(): - assert response["status"] == 200 - assert response["body"]["status"] == "paused" - finally: - await client.teardown() + result = await client.pause(mode=PauseMode.ABORT) + assert len(result) == 2 + for url, response in result.items(): + assert response["status"] == 200 + assert response["body"]["status"] == "paused" @pytest.mark.asyncio - async def test_pause_finish_mode(self, mock_servers): + async def test_pause_finish_mode(self, client): """Test pause with FINISH mode fans out to all servers.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - result = await client.pause(mode=PauseMode.FINISH) - assert len(result) == 2 - for url, response in result.items(): - assert 
response["status"] == 200 - finally: - await client.teardown() + result = await client.pause(mode=PauseMode.FINISH) + assert len(result) == 2 + for url, response in result.items(): + assert response["status"] == 200 @pytest.mark.asyncio - async def test_pause_keep_mode_not_supported(self, mock_servers): - """Test pause with KEEP mode raises NotImplementedError.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - with pytest.raises(NotImplementedError, match="KEEP is not yet supported"): - await client.pause(mode=PauseMode.KEEP) - finally: - await client.teardown() - - @pytest.mark.asyncio - async def test_resume(self, mock_servers): + async def test_resume(self, client): """Test resume fans out to all servers.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) + # Pause first + await client.pause() - try: - # Pause first - await client.pause() - - # Resume - result = await client.resume() - assert len(result) == 2 - for url, response in result.items(): - assert response["status"] == 200 - finally: - await client.teardown() + # Resume + result = await client.resume() + assert len(result) == 2 + for url, response in result.items(): + assert response["status"] == 200 @pytest.mark.asyncio - async def test_sleep(self, mock_servers): + async def test_sleep(self, client): """Test sleep fans out to all servers.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - result = await client.sleep(level=2) - assert len(result) == 2 - finally: - await client.teardown() + result = await client.sleep(level=2) + assert len(result) == 2 @pytest.mark.asyncio - async def test_wake_up(self, mock_servers): + async def test_wake_up(self, client): """Test wake_up fans out to all servers.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - result = await client.wake_up() - assert len(result) == 2 - finally: - await client.teardown() + result = await client.wake_up() + assert len(result) == 2 @pytest.mark.asyncio - async def test_reset_prefix_cache(self, mock_servers): + async def test_reset_prefix_cache(self, client): """Test reset_prefix_cache fans out to all servers.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - result = await client.reset_prefix_cache() - assert len(result) == 2 - finally: - await client.teardown() + result = await client.reset_prefix_cache() + assert len(result) == 2 class TestWeightSync: """Test weight sync methods.""" @pytest.mark.asyncio - async def test_init_weight_transfer(self, mock_servers): + async def test_init_weight_transfer(self, client): """Test init_weight_transfer fans out to all servers.""" from skyrl_train.weight_sync import BroadcastInitInfo - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], + init_info = BroadcastInitInfo( + master_addr="127.0.0.1", + master_port=29500, + rank_offset=1, + world_size=5, + group_name="test", + backend="nccl", + model_dtype_str="torch.bfloat16", + override_existing_receiver=True, ) - - try: - init_info = BroadcastInitInfo( - master_addr="127.0.0.1", - master_port=29500, - rank_offset=1, - world_size=5, - group_name="test", - backend="nccl", - model_dtype_str="torch.bfloat16", - ) - 
result = await client.init_weight_transfer(init_info) - assert len(result) == 2 - finally: - await client.teardown() + result = await client.init_weight_transfer(init_info) + assert len(result) == 2 @pytest.mark.asyncio - async def test_update_weights(self, mock_servers): + async def test_update_weights(self, client): """Test update_weights fans out to all servers.""" from skyrl_train.weight_sync import BroadcastWeightUpdateRequest - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], + request = BroadcastWeightUpdateRequest( + names=["layer.weight"], + dtypes=["torch.bfloat16"], + shapes=[[1024, 1024]], ) - - try: - request = BroadcastWeightUpdateRequest( - names=["layer.weight"], - dtypes=["torch.bfloat16"], - shapes=[[1024, 1024]], - ) - result = await client.update_weights(request) - assert len(result) == 2 - finally: - await client.teardown() + result = await client.update_weights(request) + assert len(result) == 2 @pytest.mark.asyncio - async def test_finalize_weight_update(self, mock_servers): + async def test_finalize_weight_update(self, client): """Test finalize_weight_update fans out to all servers.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) - - try: - result = await client.finalize_weight_update() - assert len(result) == 2 - finally: - await client.teardown() + result = await client.finalize_weight_update() + assert len(result) == 2 class TestServerInfo: """Test server info and world_size.""" @pytest.mark.asyncio - async def test_get_world_size(self, mock_servers): + async def test_get_world_size(self, client): """Test world_size fetching and caching.""" - client = RemoteInferenceClient( - proxy_url=mock_servers["proxy_url"], - server_urls=mock_servers["server_urls"], - ) + # First call fetches from all servers and sums + world_size = await client.get_world_size() + # Each mock server reports world_size=2, we have 2 servers = 4 + assert world_size == 4 - try: - # First call fetches from all servers and sums - world_size = await client.get_world_size() - # Each mock server reports world_size=2, we have 2 servers = 4 - assert world_size == 4 - - # Second call returns cached value - world_size2 = await client.get_world_size() - assert world_size2 == 4 - finally: - await client.teardown() + # Second call returns cached value + world_size2 = await client.get_world_size() + assert world_size2 == 4 class TestContextManager: @@ -509,10 +380,13 @@ class TestContextManager: @pytest.mark.asyncio async def test_async_context_manager(self, mock_servers): """Test using client as async context manager.""" - async with RemoteInferenceClient( + + client = RemoteInferenceClient( proxy_url=mock_servers["proxy_url"], server_urls=mock_servers["server_urls"], - ) as client: + ) + + async with client: result = await client.resume() assert len(result) == 2 @@ -536,22 +410,14 @@ async def health(): @app.post("/v1/completions") async def completions(request: Request): call_count["completions"] += 1 - body = await request.json() + await request.json() # Consume body # First call returns abort with partial response if call_count["completions"] == 1: - return { - "choices": [ - {"index": 0, "text": "Partial ", "finish_reason": "abort"} - ] - } + return {"choices": [{"index": 0, "text": "Partial ", "finish_reason": "abort"}]} # Second call returns complete response else: - return { - "choices": [ - {"index": 0, "text": "response complete", "finish_reason": "stop"} - ] 
- } + return {"choices": [{"index": 0, "text": "response complete", "finish_reason": "stop"}]} @app.post("/tokenize") async def tokenize(request: Request): @@ -600,10 +466,12 @@ async def test_retry_on_abort(self, abort_mock_server): ) try: - result = await client.generate({ - "prompt_token_ids": [[1, 2, 3]], - "sampling_params": {"max_tokens": 100}, - }) + result = await client.generate( + { + "prompt_token_ids": [[1, 2, 3]], + "sampling_params": {"max_tokens": 100}, + } + ) # Should get complete response after retry assert result["stop_reasons"][0] == "stop"
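# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the patch): the tests above exercise
# RemoteInferenceClient against a proxy plus two mock backend servers. The
# snippet below shows the same call pattern outside pytest. The import paths,
# the PauseMode location, and all URLs are assumptions for illustration only;
# the method names and payload shapes are taken from the tests in this file.
# ---------------------------------------------------------------------------
import asyncio

# Assumed module paths; adjust to wherever RemoteInferenceClient and PauseMode
# actually live in skyrl_train.
from skyrl_train.inference_servers.remote_inference_client import (  # assumed path
    PauseMode,
    RemoteInferenceClient,
)


async def main() -> None:
    client = RemoteInferenceClient(
        proxy_url="http://127.0.0.1:8080",  # router URL (placeholder)
        server_urls=[  # backend vLLM servers (placeholders)
            "http://10.0.0.1:8000",
            "http://10.0.0.2:8000",
        ],
    )
    try:
        # Data plane: the proxy routes each generation request to one backend.
        out = await client.generate(
            {"prompt_token_ids": [[1, 2, 3]], "sampling_params": {"max_tokens": 16}}
        )
        print(out["responses"], out["stop_reasons"])

        # Control plane: fanned out to every backend; the result contains one
        # entry per server URL.
        await client.pause(mode=PauseMode.FINISH)
        await client.resume()
    finally:
        await client.teardown()


if __name__ == "__main__":
    asyncio.run(main())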