diff --git a/skyrl-train/skyrl_train/env_vars.py b/skyrl-train/skyrl_train/env_vars.py new file mode 100644 index 000000000..9d7eb6b2f --- /dev/null +++ b/skyrl-train/skyrl_train/env_vars.py @@ -0,0 +1,26 @@ +import os + + +SKYRL_VLLM_DP_PORT_OFFSET = int(os.environ.get("SKYRL_VLLM_DP_PORT_OFFSET", 500)) +""" +Offset for the data parallel port of the vLLM server. +""" +SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S = int( + os.environ.get("SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S", 600) +) +""" +Timeout for waiting until the inference server is healthy. +""" +SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV = str( + os.environ.get("SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV", "False") +).lower() in ( + "true", + "1", + "yes", +) +""" +Whether to include the PYTHONPATH environment variable in the runtime +environment. In case of using ray nightly, this will be needed to avoid +dependencies issues by setting it to the local path where ray nightly is +installed. +""" diff --git a/skyrl-train/skyrl_train/inference_engines/base.py b/skyrl-train/skyrl_train/inference_engines/base.py index 392e2100e..ddb893182 100644 --- a/skyrl-train/skyrl_train/inference_engines/base.py +++ b/skyrl-train/skyrl_train/inference_engines/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Dict, TypedDict, Any, Optional, Hashable, TYPE_CHECKING +from typing import List, Dict, TypedDict, Any, Optional, Hashable, Protocol, runtime_checkable, TYPE_CHECKING if TYPE_CHECKING: from skyrl_train.weight_sync.transfer_strategy import WeightSyncInitInfo @@ -31,6 +31,36 @@ class InferenceEngineOutput(TypedDict): response_logprobs: Optional[List[List[float]]] +@runtime_checkable +class InferenceClientProtocol(Protocol): + """ + Structural protocol for inference clients. + + Both InferenceEngineInterface and RemoteInferenceClient satisfy this protocol, + enabling code to work with either implementation via duck typing. + + This is the minimal common interface for: + - Data plane: generate() + - Lifecycle: sleep(), wake_up(), teardown() + """ + + async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput: + """Generate completions for a batch of inputs.""" + ... + + async def sleep(self, *args: Any, **kwargs: Any) -> Any: + """Put the engine into sleep/low-power mode.""" + ... + + async def wake_up(self, *args: Any, **kwargs: Any) -> Any: + """Wake the engine from sleep mode.""" + ... + + async def teardown(self) -> None: + """Clean up resources.""" + ... 
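+
+    # Illustrative usage (not part of this interface): code that only needs the
+    # data-plane/lifecycle surface can be typed against this protocol and accept
+    # either an InferenceEngineInterface or a RemoteInferenceClient, e.g.:
+    #
+    #   async def sleep_cycle(client: InferenceClientProtocol) -> None:
+    #       await client.sleep()
+    #       await client.wake_up()
+    #
+    # Because the protocol is @runtime_checkable, isinstance(obj,
+    # InferenceClientProtocol) also works, but it only checks that the methods
+    # exist, not their signatures.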
+ + class InferenceEngineInterface(ABC): @abstractmethod diff --git a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py index f519a07c1..7c6933a3a 100644 --- a/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py +++ b/skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from http import HTTPStatus import ray -import torch import asyncio import vllm from types import SimpleNamespace @@ -24,7 +23,6 @@ ) from vllm.lora.request import LoRARequest from uuid import uuid4 -import warnings from skyrl_train.inference_engines.base import ( InferenceEngineInterface, InferenceEngineInput, @@ -66,70 +64,10 @@ def setup_envvars_for_vllm(kwargs, bundle_indices): logger.info(f"creating LLM with bundle_indices={bundle_indices}") -class WorkerWrap: - def test_rpc(self, *args, **kwargs): - """Test RPC call to worker""" - return args, kwargs - - def init_weight_update_communicator(self, init_info: bytes): - """Init weight update communicator from init info. - - Args: - init_info: Pickled bytes of WeightSyncInitInfo from the sender. - """ - import pickle - - assert torch.distributed.is_initialized(), "default torch process group must be initialized" - - # Unpickle init_info to restore the original object type - assert isinstance(init_info, bytes), f"Expected bytes, got {type(init_info).__name__}" - init_info = pickle.loads(init_info) - - strategy_cls = init_info.strategy_type() - - if hasattr(self, "_weight_receiver") and self._weight_receiver is not None: - # TODO(haochen): we should get rid of this flag and override existing receiver. - if init_info.override_existing_receiver: - self._weight_receiver.teardown() - self._weight_receiver = None - else: - warnings.warn( - "Detected an existing weight receiver. " - "For overriding, use `generator.override_existing_update_group=enable`" - ) - return - - self._weight_receiver = strategy_cls.create_receiver(init_info) - - def load_weights(self, request: bytes) -> None: - """Load weights using the receiver. - - This method is called via collective_rpc from VLLMWeightLoader. - - Args: - request: Pickled bytes of WeightUpdateRequest. - """ - import pickle - - # Unpickle request to restore the original object type - assert isinstance(request, bytes), f"Expected bytes, got {type(request).__name__}" - request = pickle.loads(request) - - weight_list = [] - for name, tensor in self._weight_receiver.receive_weights(request): - weight_list.append((name, tensor)) - - self.model_runner.model.load_weights(weights=weight_list) - - for weight in weight_list: - del weight - - # TODO (sumanthrh): Add destroy process group RPC as a atexit handler to Trainer code. - def teardown_weight_receiver(self): - if not hasattr(self, "_weight_receiver") or self._weight_receiver is None: - warnings.warn("No weight receiver to teardown") - return - self._weight_receiver.teardown() +# Backward compatibility: WorkerWrap has moved to inference_servers.vllm_worker +# This alias preserves the old import path for existing scripts/configs. +# TODO (Kourosh): Remove this alias once all references are updated. 
+from skyrl_train.inference_servers.vllm_worker import WorkerWrap # noqa: F401, E402 class BaseVLLMInferenceEngine(InferenceEngineInterface): diff --git a/skyrl-train/skyrl_train/inference_servers/__init__.py b/skyrl-train/skyrl_train/inference_servers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/skyrl-train/skyrl_train/inference_servers/common.py b/skyrl-train/skyrl_train/inference_servers/common.py new file mode 100644 index 000000000..17ae4bb36 --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/common.py @@ -0,0 +1,74 @@ +""" +Common utilities for inference servers. + +Uses Ray's public network utilities for consistency with Ray's cluster management. +""" + +import logging +import socket +from dataclasses import dataclass + +import ray + +logger = logging.getLogger(__name__) + + +@dataclass +class ServerInfo: + """Information about a running inference server.""" + + ip: str + port: int + + @property + def url(self) -> str: + return f"http://{self.ip}:{self.port}" + + +def get_node_ip() -> str: + """ + Get the IP address of the current node. + + Returns the node IP from Ray's global worker if Ray is initialized + """ + return ray.util.get_node_ip_address() + + +def get_open_port(start_port: int | None = None) -> int: + """ + Get an available port. + + Args: + start_port: If provided, search for an available port starting from this value. + If None, let the OS assign a free port. + + Returns: + An available port number. + """ + if start_port is not None: + # Search for available port starting from start_port + port = start_port + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + return port + except OSError: + port += 1 + if port > 65535: + raise RuntimeError(f"No available port found starting from {start_port}") + + # Let OS assign a free port + # Try IPv4 first + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + pass + + # Try IPv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] diff --git a/skyrl-train/skyrl_train/inference_servers/protocols.py b/skyrl-train/skyrl_train/inference_servers/protocols.py new file mode 100644 index 000000000..828d83609 --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/protocols.py @@ -0,0 +1,105 @@ +""" +Protocols for inference server components. + +These define the interfaces that server implementations must follow. +""" + +from argparse import Namespace +from typing import Optional, Protocol, Tuple, runtime_checkable + +from skyrl_train.inference_servers.common import ServerInfo + + +@runtime_checkable +class ServerActorProtocol(Protocol): + """ + Protocol defining the interface for server actor classes. + + Any server actor class (vLLM, SGLang, etc.) must implement this interface + to be usable with ServerGroup. + + Example: + class MyServerActor(ServerActorProtocol): + @staticmethod + def compute_num_gpus_per_server(cli_args: Namespace) -> int: + return cli_args.tensor_parallel_size + + def __init__(self, cli_args, start_port, server_idx, ...): + ... + + async def start(self) -> ServerInfo: + ... + """ + + @staticmethod + def compute_num_gpus_per_server(cli_args: Namespace) -> int: + """ + Compute the number of GPUs needed per server instance. + + This is called before actor creation to determine placement group size. 
+ + Args: + cli_args: Engine-specific CLI arguments. + + Returns: + Number of GPUs required per server (e.g., TP * PP for vLLM). + """ + ... + + def __init__( + self, + cli_args: Namespace, + start_port: int, + server_idx: int, + start_bundle_idx: int, + dp_size: int, + dp_master_address: Optional[str], + dp_rpc_port: Optional[int], + enable_pd: bool, + nixl_side_channel_base: int, + ) -> None: + """ + Initialize the server actor. + + Args: + cli_args: Engine-specific CLI arguments. + start_port: Base port to search for available port. + server_idx: Index of this server in the group (0-indexed). + start_bundle_idx: Starting bundle index in placement group for this server's workers. + dp_size: Data parallel size (-1 to disable DP). + dp_master_address: DP master address (for non-rank-0 servers). + dp_rpc_port: DP RPC port (for non-rank-0 servers). + enable_pd: Enable prefill-decode disaggregation. + nixl_side_channel_base: Base port for NIXL side channels. + """ + ... + + def get_server_info(self) -> ServerInfo: + """Get the server's IP and port info.""" + ... + + def get_dp_info(self) -> Tuple[str, int]: + """ + Get the DP master address and RPC port. + + Only called on server_idx=0 when DP is enabled. + + Returns: + Tuple of (master_address, rpc_port). + """ + ... + + async def start(self) -> ServerInfo: + """ + Start the server. + + This should block until the server is healthy and ready to serve requests. + + Returns: + ServerInfo with the server's IP and port. + """ + ... + + async def shutdown(self) -> None: + """Gracefully shutdown the server.""" + ... diff --git a/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py b/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py new file mode 100644 index 000000000..ee06c72df --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/remote_inference_client.py @@ -0,0 +1,661 @@ +""" +RemoteInferenceClient - Serializable HTTP client for inference. + +This is a lightweight, fully serializable HTTP client that wraps the inference +server HTTP API. It replaces the old InferenceEngineInterface for HTTP-based +inference servers. + +Key features: +- Serializable: Can be pickled and passed between processes +- Two URL types: + - proxy_url: Single URL for data plane operations (routed requests) + - server_urls: List of backend URLs for control plane operations (fan-out) +- Lazy world_size fetching from /get_server_info +- Pure HTTP: Uses /tokenize endpoint if token IDs are needed +- Built-in retry on abort for in-flight weight updates + +Usage: + # Full proxy mode (router handles both data and control plane) + client = RemoteInferenceClient( + proxy_url="http://router:8080", + server_urls=["http://router:8080"], + ) + +Comparison with existing code: +- Replaces: InferenceEngineClient + RemoteInferenceEngine (for remote-only usage) +- Key difference: Talks directly to router via HTTP, no Ray actor wrapping +- The router handles session-aware routing; this client is simpler + +TODO: Data Plane Operations - Future Deprecation +------------------------------------------------ +All data plane operations (generate, chat_completion, completion, tokenize, detokenize) +and the retry-on-abort logic will eventually be removed from this client. 
+ +When vLLM RFC #32103 lands with PauseMode.KEEP: +- The retry logic in generate() will be deleted +- pause() will use mode="keep" which preserves KV cache and scheduler state +- Requests resume seamlessly after unpause with zero client changes + +The generator code will transition to: +1. OpenAI-compatible endpoints (/v1/chat/completions) for text-based interaction +2. Tinker sample API for token-in-token-out workflows: + - Input: ModelInput.from_ints(tokens=input_ids) + - Output: sequences[0].tokens, sequences[0].logprobs + - Internally maps to /v1/completions with token-in-token-out + - May become a native vLLM API in the future + +This client will then primarily handle control plane operations only. +""" + +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING +from dataclasses import asdict + + +import aiohttp + +from skyrl_train.inference_engines.base import InferenceEngineInput, InferenceEngineOutput + +if TYPE_CHECKING: + from skyrl_train.weight_sync import BroadcastInitInfo, BroadcastWeightUpdateRequest + +logger = logging.getLogger(__name__) + + +class PauseMode(Enum): + """ + Pause mode for inference servers. + + This enum mirrors the pause modes that will be available in vLLM RFC #32103. + For now, we map these to the existing `wait_for_inflight_request` parameter. + + Modes: + ABORT: Abort in-flight requests immediately. Clients receive partial + tokens and must retry with accumulated context. + Maps to: wait_for_inflight_request=False + + FINISH: Wait for in-flight requests to complete before pausing. + New requests are blocked. No retry needed. + Maps to: wait_for_inflight_request=True + """ + + ABORT = "abort" + FINISH = "finish" + + +@dataclass +class RemoteInferenceClient: + """ + Serializable HTTP client for inference. Replaces InferenceEngineInterface. + + This class maintains two URL types: + - proxy_url: Single URL for data plane operations (routed requests) + - server_urls: List of backend URLs for control plane operations (fan-out) + + This separation allows using external routers (vllm-router, sglang-router) + that only handle data plane, while still being able to call control plane + endpoints directly on backends. 
+ + For a "full proxy" setup where the router handles both data and control plane, + set proxy_url and server_urls to the same value: + proxy_url = "http://router:8080" + server_urls = ["http://router:8080"] + + For external routers that only handle data plane: + proxy_url = "http://vllm-router:8080" + server_urls = ["http://backend1:8000", "http://backend2:8000"] + """ + + proxy_url: str + """Data plane URL (single endpoint - router or direct server).""" + + server_urls: List[str] + """Control plane URLs (list of backend servers for fan-out).""" + + model_name: str = "default" + """Model name for OpenAI-compatible API calls.""" + + # Private fields excluded from repr for cleaner output + _session: Optional[aiohttp.ClientSession] = field(default=None, repr=False) + _world_size: Optional[int] = field(default=None, repr=False) + + # --------------------------- + # Session Management + # --------------------------- + + async def _get_session(self) -> aiohttp.ClientSession: + """Get or create the aiohttp session.""" + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=None)) + return self._session + + # --------------------------- + # Data Plane + # --------------------------- + + async def generate( + self, + input_batch: InferenceEngineInput, + ) -> InferenceEngineOutput: + """ + Generate completions via /v1/completions. + + This is the interface for token-in-token-out workflows. Input will have + token ids, and the output is token ids as well. + + Each prompt is sent as a separate request to allow the router to route + based on session_id. All requests are made in parallel. + + Args: + input_batch: Contains prompt_token_ids, sampling_params, and optional session_ids. + + Returns: + InferenceEngineOutput with responses, response_ids, and stop_reasons. + """ + + prompt_token_ids = input_batch.get("prompt_token_ids") + if prompt_token_ids is None: + raise ValueError("RemoteInferenceClient only accepts `prompt_token_ids`, not `prompts`.") + + sampling_params = input_batch.get("sampling_params") or {} + if sampling_params.get("n", 1) > 1: + raise ValueError("n > 1 is not supported. Use `config.generator.n_samples_per_prompt` instead.") + + session_ids = input_batch.get("session_ids") + + # Create parallel tasks for all prompts + # Each task handles its own retry on abort + tasks = [ + self._generate_single( + prompt_token_ids=prompt_token_ids[idx], + sampling_params=sampling_params, + session_id=session_ids[idx] if session_ids and idx < len(session_ids) else None, + ) + for idx in range(len(prompt_token_ids)) + ] + + # Run all in parallel - retries happen within each task + results = await asyncio.gather(*tasks) + + return InferenceEngineOutput( + responses=[r["response"] for r in results], + stop_reasons=[r["stop_reason"] for r in results], + response_ids=[r["response_ids"] for r in results], + response_logprobs=None, + ) + + # TODO: Delete retry logic when vLLM RFC #32103 lands with PauseMode.KEEP + async def _generate_single( + self, + prompt_token_ids: List[int], + sampling_params: Dict[str, Any], + session_id: Optional[Any], + ) -> Dict[str, Any]: + """ + Generate completion for a single prompt with built-in retry on abort. + + When pause(mode=ABORT) is called, running requests return partial tokens + with stop_reason="abort". This method retries with accumulated tokens + until generation completes with a non-abort stop reason. + + TODO: Retry logic will be deleted when vLLM RFC #32103 lands. 
+ With PauseMode.KEEP, requests resume seamlessly after unpause. + + Returns: + Dict with keys: response, stop_reason, response_ids + """ + session = await self._get_session() + url = f"{self.proxy_url}/v1/completions" + + # Determine max_tokens key and original value + max_key = None + if "max_tokens" in sampling_params: + max_key = "max_tokens" + elif "max_completion_tokens" in sampling_params: + max_key = "max_completion_tokens" + original_max_tokens = sampling_params.get(max_key) if max_key else None + + # Accumulate across retries + accum_text = "" + accum_token_ids: List[int] = [] + stop_reason = "abort" + + while stop_reason == "abort": + # Build payload with accumulated context + cur_params = sampling_params.copy() + if original_max_tokens is not None and max_key: + remaining = original_max_tokens - len(accum_token_ids) + if remaining <= 0: + break + cur_params[max_key] = remaining + + # New prompt = original + accumulated tokens + new_prompt = prompt_token_ids + accum_token_ids + + payload = cur_params.copy() + payload["model"] = self.model_name + payload["prompt"] = new_prompt + + headers = {"Content-Type": "application/json"} + if session_id: + headers["X-Session-ID"] = str(session_id) + + async with session.post(url, json=payload, headers=headers) as resp: + resp.raise_for_status() + response = await resp.json() + + choice = response["choices"][0] + new_text = choice["text"] + stop_reason = choice["finish_reason"] + + # Accumulate text + accum_text += new_text + # Tokenize the new text to get token IDs for next iteration + if stop_reason == "abort" and new_text: + new_token_ids = (await self.tokenize([new_text], add_special_tokens=False))[0] + accum_token_ids.extend(new_token_ids) + + # Final response + # Tokenize full accumulated text for response_ids + final_token_ids = (await self.tokenize([accum_text], add_special_tokens=False))[0] if accum_text else [] + + return { + "response": accum_text, + "stop_reason": stop_reason, + "response_ids": final_token_ids, + } + + async def chat_completion( + self, + request_payload: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Chat completion via /v1/chat/completions. + + Args: + request_payload: Dict with {"json": , "headers": }. + The request body should be OpenAI-compatible chat completion request. + session_id can be included in json for consistent routing. + + Returns: + OpenAI-compatible chat completion response. + """ + body = request_payload.get("json", {}) + + # Extract session_id for routing (same as InferenceEngineClient) + session_id = body.pop("session_id", None) + + headers = {"Content-Type": "application/json"} + if session_id: + headers["X-Session-ID"] = str(session_id) + + session = await self._get_session() + url = f"{self.proxy_url}/v1/chat/completions" + + async with session.post(url, json=body, headers=headers) as resp: + resp.raise_for_status() + return await resp.json() + + async def completion( + self, + request_payload: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Completion via /v1/completions. + + Args: + request_payload: Dict with {"json": , "headers": }. + The request body should be OpenAI-compatible completion request. + session_id can be included in json for consistent routing. + + Returns: + OpenAI-compatible completion response. 
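+
+        Example (illustrative; field values are placeholders):
+
+            resp = await client.completion({
+                "json": {
+                    "model": client.model_name,
+                    "prompt": [1, 2, 3],      # token IDs, or a plain string
+                    "max_tokens": 64,
+                    "session_id": "traj-0",   # popped here, sent as X-Session-ID
+                },
+            })
+            text = resp["choices"][0]["text"]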
+ """ + body = request_payload.get("json", {}) + + # Extract session_id for routing (same as InferenceEngineClient) + session_id = body.pop("session_id", None) + + headers = {"Content-Type": "application/json"} + if session_id: + headers["X-Session-ID"] = str(session_id) + + session = await self._get_session() + url = f"{self.proxy_url}/v1/completions" + + async with session.post(url, json=body, headers=headers) as resp: + resp.raise_for_status() + return await resp.json() + + async def tokenize( + self, + texts: List[str], + add_special_tokens: bool = True, + ) -> List[List[int]]: + """ + Tokenize texts via /tokenize. + + Args: + texts: List of texts to tokenize. + add_special_tokens: Whether to add special tokens. + + Returns: + List of token ID lists. + """ + session = await self._get_session() + url = f"{self.proxy_url}/tokenize" + + # vLLM /tokenize expects individual requests, batch them + results = [] + for text in texts: + payload = { + "model": self.model_name, + "prompt": text, + "add_special_tokens": add_special_tokens, + } + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + result = await resp.json() + results.append(result.get("tokens", [])) + + return results + + async def detokenize( + self, + token_ids: List[List[int]], + ) -> List[str]: + """ + Detokenize token IDs via /detokenize. + + Args: + token_ids: List of token ID lists. + + Returns: + List of decoded texts. + """ + session = await self._get_session() + url = f"{self.proxy_url}/detokenize" + + # vLLM /detokenize expects individual requests, batch them + results = [] + for ids in token_ids: + payload = { + "model": self.model_name, + "tokens": ids, + } + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + result = await resp.json() + results.append(result.get("prompt", "")) + + return results + + # --------------------------- + # Control Plane (fan-out to all server_urls) + # --------------------------- + + async def _call_all_servers( + self, + endpoint: str, + json: Dict[str, Any], + method: str = "POST", + ) -> Dict[str, Any]: + """ + Call endpoint on all server_urls concurrently. + + Args: + endpoint: Endpoint path (e.g., "/pause"). + json: JSON payload to send. + method: HTTP method (default: POST). + + Returns: + Dict mapping server_url to response. + """ + session = await self._get_session() + + async def call_server(server_url: str) -> tuple: + url = f"{server_url}{endpoint}" + try: + async with session.request(method, url, json=json) as resp: + body = await resp.json() if resp.content_length else None + return server_url, {"status": resp.status, "body": body} + except Exception as e: + return server_url, {"status": 500, "error": str(e)} + + results = await asyncio.gather(*[call_server(url) for url in self.server_urls]) + return {url: resp for url, resp in results} + + async def pause(self, mode: Union[PauseMode, str] = PauseMode.ABORT) -> Dict[str, Any]: + """ + Pause generation on all backends. + + Args: + mode: Pause mode determining how in-flight requests are handled. + Can be a PauseMode enum or string ("abort", "finish"). + - ABORT / "abort": Abort in-flight requests immediately. Clients + receive partial tokens and must retry with accumulated context. + New requests are blocked. + - FINISH / "finish": Wait for in-flight requests to complete before + pausing. New requests are blocked. No retry needed. + + Returns: + Dict mapping server_url to response. + + TODO: + When vLLM RFC #32103 lands, we'll use the native mode parameter. 
+ For now, we map modes to wait_for_inflight_request: + - ABORT → wait_for_inflight_request=False + - FINISH → wait_for_inflight_request=True + """ + # Convert string to PauseMode if needed + if isinstance(mode, str): + mode = PauseMode(mode.lower()) + + wait_for_inflight_request = mode == PauseMode.FINISH + + return await self._call_all_servers("/pause", {"wait_for_inflight_request": wait_for_inflight_request}) + + async def resume(self) -> Dict[str, Any]: + """Resume generation on all backends.""" + return await self._call_all_servers("/resume", {}) + + async def sleep(self, level: int = 2) -> Dict[str, Any]: + """ + Put all backends to sleep (offload weights to CPU). + + Args: + level: Sleep level (1 or 2). Level 2 offloads more aggressively. + + Returns: + Dict mapping server_url to response. + """ + return await self._call_all_servers("/sleep", {"level": level}) + + async def wake_up(self) -> Dict[str, Any]: + """Wake up all backends (load weights back to GPU).""" + return await self._call_all_servers("/wake_up", {}) + + async def reset_prefix_cache( + self, + reset_running_requests: bool = False, + ) -> Dict[str, Any]: + """ + Reset KV cache on all backends. + + Args: + reset_running_requests: Whether to reset running requests. + + Returns: + Dict mapping server_url to response. + """ + return await self._call_all_servers("/reset_prefix_cache", {"reset_running_requests": reset_running_requests}) + + # --------------------------- + # Weight Sync (control plane - fan-out) + # --------------------------- + + async def init_weight_transfer( + self, + init_info: "BroadcastInitInfo", + ) -> Dict[str, Any]: + """ + Initialize weight sync process group on all backends. + + Args: + init_info: BroadcastInitInfo containing all args for weight sync setup. + + Returns: + Dict mapping server_url to response. + """ + return await self._call_all_servers("/init_weight_transfer", asdict(init_info)) + + async def update_weights( + self, + request: "BroadcastWeightUpdateRequest", + ) -> Dict[str, Any]: + """ + Update weights on all backends. + + Args: + request: BroadcastWeightUpdateRequest containing weight metadata. + + Returns: + Dict mapping server_url to response. + """ + return await self._call_all_servers("/update_weights", asdict(request)) + + async def finalize_weight_update(self) -> Dict[str, Any]: + """ + Finalize weight update on all backends. + + Called after all update_weights() calls are complete. + Reserved for any post-processing steps that may be needed: + - Cache invalidation + - State synchronization + - Future vLLM requirements + + Returns: + Dict mapping server_url to response. + """ + return await self._call_all_servers("/finalize_weight_update", {}) + + # --------------------------- + # Info + # --------------------------- + + async def get_world_size(self) -> int: + """ + Get total world size across all inference workers. + + Fetches from /get_server_info on each server and sums the world_size values. + Result is cached after first call. 
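+
+        For example (illustrative), two backend servers that each report
+        world_size=4 (e.g. TP=4, PP=1) yield a total world size of 8.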
+ """ + if self._world_size is not None: + return self._world_size + + results = await self._call_all_servers("/get_server_info", {}, method="GET") + + total_world_size = 0 + for server_url, resp in results.items(): + if resp.get("status") != 200: + error = resp.get("error", resp.get("body")) + raise RuntimeError(f"Failed to fetch server info from {server_url}: {error}") + body = resp.get("body", {}) + world_size = body.get("world_size") + if world_size is None: + raise RuntimeError(f"Failed to fetch server info from {server_url}: world_size is missing") + total_world_size += world_size + + self._world_size = total_world_size + return self._world_size + + # --------------------------- + # Lifecycle + # --------------------------- + + async def teardown(self) -> None: + """Close HTTP session.""" + if self._session and not self._session.closed: + await self._session.close() + self._session = None + + async def __aenter__(self) -> "RemoteInferenceClient": + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Async context manager exit.""" + await self.teardown() + + # --------------------------- + # Serialization + # --------------------------- + + def __getstate__(self) -> dict: + """Exclude non-serializable fields from pickle.""" + state = self.__dict__.copy() + state["_session"] = None + return state + + def __setstate__(self, state: dict) -> None: + """Restore state after unpickling.""" + self.__dict__.update(state) + self._session = None + + # --------------------------- + # Compatibility helpers + # --------------------------- + + @classmethod + def from_server_group( + cls, + server_urls: List[str], + router_url: str, + model_name: str = "default", + ) -> "RemoteInferenceClient": + """ + Create client from server group URLs. + + Args: + server_urls: List of backend server URLs. + router_url: Router URL for data plane. + model_name: Model name for API calls. + + Returns: + Configured RemoteInferenceClient. + """ + return cls( + proxy_url=router_url, + server_urls=server_urls, + model_name=model_name, + ) + + @classmethod + def from_router( + cls, + router_url: str, + model_name: str = "default", + ) -> "RemoteInferenceClient": + """ + Create client for a full-feature router (handles both data and control plane). + + This is for routers like InferenceRouter that support all endpoints. + Control plane calls go to the router which fans out to backends. + + Args: + router_url: Router URL. + model_name: Model name for API calls. + + Returns: + Configured RemoteInferenceClient. + """ + return cls( + proxy_url=router_url, + server_urls=[router_url], + model_name=model_name, + ) diff --git a/skyrl-train/skyrl_train/inference_servers/router.py b/skyrl-train/skyrl_train/inference_servers/router.py new file mode 100644 index 000000000..d8863d21d --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/router.py @@ -0,0 +1,302 @@ +""" +Inference Router - HTTP proxy with session-aware routing and control plane fan-out. +""" + +import asyncio +import hashlib +import itertools +import logging +import threading +import time +from typing import List, Optional + +import httpx +import uvicorn +from fastapi import FastAPI, Request, Response + +from skyrl_train.inference_servers.common import get_node_ip +from skyrl_train.env_vars import SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S + +logger = logging.getLogger(__name__) + + +# Control plane routes are broadcast to ALL servers. 
+# Everything else (data plane) is load-balanced to ONE server. +CONTROL_PLANE_ROUTES = [ + # BUILT-IN ROUTES + "/pause", + "/resume", + "/is_paused", + "/sleep", + "/wake_up", + "/reset_prefix_cache", + "/collective_rpc", + # SKYRL-SPECIFIC ROUTES + "/init_weight_transfer", + "/update_weights", + "/finalize_weight_update", +] + + +class InferenceRouter: + """ + HTTP proxy router for multiple vLLM servers. + + Routing behavior: + - Data plane (generation requests): Routes to ONE server. + - If X-Session-ID header present: consistent hash to same backend + - Otherwise: round-robin + - Control plane (sleep, pause, weight sync): Fans out to ALL backends + + Usage: + router = InferenceRouter(server_urls, host="0.0.0.0", port=8080) + router_url = router.start() + # ... use router_url for inference ... + router.shutdown() + """ + + def __init__( + self, + server_urls: List[str], + host: str = "0.0.0.0", + port: int = 8080, + ): + """ + Initialize the router. + + Args: + server_urls: List of backend vLLM server URLs + host: Host to bind router to + port: Port to bind router to + """ + self._server_urls = server_urls + self._host = host + self._port = port + self._server_cycle = itertools.cycle(server_urls) + self._client: Optional[httpx.AsyncClient] = None + self._app: Optional[FastAPI] = None + self._server: Optional[uvicorn.Server] = None + self._server_thread: Optional[threading.Thread] = None + + logger.info(f"InferenceRouter: {len(server_urls)} servers, port={port}") + + def _hash_session_id(self, session_id: str) -> int: + """Hash session ID to get consistent server index.""" + hash_bytes = hashlib.sha256(session_id.encode()).digest() + return int.from_bytes(hash_bytes[:8], "big") + + def _get_server_for_session(self, session_id: str) -> str: + """Get consistent server URL for a session ID.""" + idx = self._hash_session_id(session_id) % len(self._server_urls) + return self._server_urls[idx] + + def _get_server_round_robin(self) -> str: + """Get next server URL in round-robin order.""" + return next(self._server_cycle) + + def _get_server_for_request(self, request: Request) -> str: + """ + Determine server for a request. + + If X-Session-ID header is present, use consistent hashing. + Otherwise, use round-robin. 
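+
+        For example, every request carrying "X-Session-ID: traj-42" is hashed to
+        the same backend, while requests without the header are spread
+        round-robin across servers.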
+ """ + session_id = request.headers.get("X-Session-ID") + if session_id: + return self._get_server_for_session(session_id) + return self._get_server_round_robin() + + def _is_control_plane_route(self, path: str) -> bool: + """Check if path is a control plane route (fan-out to all).""" + return any(path.startswith(route) for route in CONTROL_PLANE_ROUTES) + + def _build_app(self) -> FastAPI: + """Build the FastAPI app with proxy routes.""" + app = FastAPI( + title="SkyRL Inference Router", + docs_url=None, + redoc_url=None, + openapi_url=None, + ) + + @app.get("/health") + async def health(): + """Router health check (doesn't proxy to backends).""" + return {"status": "healthy"} + + @app.get("/servers") + async def list_servers(): + """Return list of server URLs.""" + return {"servers": self._server_urls} + + @app.get("/get_server_info") + async def get_server_info(): + """Fetch server info from all servers, return mapping.""" + return await self._fan_out_get("/get_server_info") + + # Catch-all: proxy everything else to backends + @app.api_route( + "/{path:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"], + ) + async def proxy(request: Request, path: str): + return await self._proxy_request(request, f"/{path}") + + return app + + async def _proxy_request(self, request: Request, path: str) -> Response: + """ + Proxy a request to backend(s). + + Control plane routes go to ALL backends. + Data plane routes go to ONE backend (session-aware or round-robin). + """ + if self._is_control_plane_route(path): + return await self._proxy_to_all(request, path) + else: + return await self._proxy_to_one(request, path) + + def _forward_headers(self, request: Request) -> dict: + """Forward headers (filter out hop-by-hop headers).""" + return { + k: v for k, v in request.headers.items() if k.lower() not in ("host", "content-length", "transfer-encoding") + } + + async def _fan_out_get(self, path: str) -> dict: + """Fan out a GET request to all servers, return mapping of server_url -> response.""" + + async def call_server(server_url: str): + try: + resp = await self._client.get(f"{server_url}{path}", timeout=30.0) + return server_url, resp.json() if resp.content else None + except Exception as e: + return server_url, {"error": str(e)} + + results = await asyncio.gather(*[call_server(url) for url in self._server_urls]) + return {url: response for url, response in results} + + async def _proxy_to_one(self, request: Request, path: str) -> Response: + """Proxy request to one server (data plane).""" + server_url = self._get_server_for_request(request) + url = f"{server_url}{path}" + + # Forward headers (filter out hop-by-hop headers) + headers = self._forward_headers(request) + + response = await self._client.request( + method=request.method, + url=url, + headers=headers, + content=await request.body(), + ) + + return Response( + content=response.content, + status_code=response.status_code, + headers=dict(response.headers), + ) + + async def _proxy_to_all(self, request: Request, path: str) -> Response: + """Proxy request to all servers (control plane), return mapping of responses.""" + import json + + method = request.method + body = await request.body() + + # Forward headers + headers = self._forward_headers(request) + + # Send to all servers concurrently + async def call_server(server_url: str): + url = f"{server_url}{path}" + try: + response = await self._client.request( + method=method, + url=url, + headers=headers, + content=body, + ) + return server_url, { + "status": 
response.status_code, + "body": response.json() if response.content else None, + } + except Exception as e: + return server_url, { + "status": 500, + "error": str(e), + } + + results = await asyncio.gather(*[call_server(url) for url in self._server_urls]) + + # Build mapping from server_url to response + response_map = {url: resp for url, resp in results} + + # Check if all succeeded + all_ok = all(r.get("status") == 200 for r in response_map.values()) + status_code = 200 if all_ok else 207 # Multi-Status on partial failure + + return Response( + content=json.dumps(response_map), + status_code=status_code, + media_type="application/json", + ) + + def start(self) -> str: + """ + Start the router server in background. + + Returns: + Router URL (e.g., "http://192.168.1.1:8080") + """ + if not self._server_urls: + raise ValueError("No servers available") + + # Create HTTP client for proxying + self._client = httpx.AsyncClient(timeout=httpx.Timeout(None)) + + # Build FastAPI app and uvicorn server + self._app = self._build_app() + config = uvicorn.Config( + app=self._app, + host=self._host, + port=self._port, + log_level="warning", + access_log=False, + ) + self._server = uvicorn.Server(config) + + # Start server in background thread + self._server_thread = threading.Thread(target=asyncio.run, args=(self._server.serve(),), daemon=True) + self._server_thread.start() + + ip = get_node_ip() + router_url = f"http://{ip}:{self._port}" + self._wait_until_healthy(router_url) + + logger.info(f"Router started at {router_url}") + logger.info(" GET /servers - list servers") + logger.info(" GET /get_server_info - get parallelism info") + return router_url + + def _wait_until_healthy( + self, router_url: str, timeout: float = SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S + ) -> None: + """Poll health endpoint until server is ready.""" + health_url = f"{router_url}/health" + start_time = time.time() + while time.time() - start_time < timeout: + try: + with httpx.Client() as client: + if client.get(health_url, timeout=1).status_code == 200: + return + except httpx.RequestError: + time.sleep(0.1) + raise RuntimeError(f"Router failed to start within {timeout}s") + + def shutdown(self) -> None: + """Shutdown the router gracefully.""" + logger.info("Shutting down router...") + if self._server: + self._server.should_exit = True + if self._server_thread: + self._server_thread.join(timeout=5) diff --git a/skyrl-train/skyrl_train/inference_servers/server_group.py b/skyrl-train/skyrl_train/inference_servers/server_group.py new file mode 100644 index 000000000..4d0e53e6e --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/server_group.py @@ -0,0 +1,185 @@ +""" +Server Group - manages server actors with placement groups. +""" + +import logging +from argparse import Namespace +from typing import Any, List, Optional, Type + +import ray +from ray.util.placement_group import PlacementGroup, placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from skyrl_train.inference_servers.common import ServerInfo +from skyrl_train.inference_servers.protocols import ServerActorProtocol +from skyrl_train.inference_servers.server_pool import ServerActorPool +from skyrl_train.inference_servers.vllm_server_actor import VLLMServerActor + +logger = logging.getLogger(__name__) + + +class ServerGroup: + """ + Creates and manages a group of server actors. + + This layer handles actor creation with placement group support, + then delegates pool management to ServerActorPool. 
+ + Supports: + - Basic mode: Creates its own placement group + - Colocation mode: Uses external placement group (shared with training) + - Data Parallel: Multiple DP-enabled servers + - PD Disaggregation: Prefill-decode disaggregation with NIXL + """ + + def __init__( + self, + cli_args: Namespace, + num_servers: int, + start_port: int = 8000, + placement_group: Optional[PlacementGroup] = None, + placement_group_bundle_offset: int = 0, + enable_dp: bool = False, + enable_pd: bool = False, + nixl_side_channel_base: int = 5600, + server_actor_cls: Type[ServerActorProtocol] = VLLMServerActor, + ): + """ + Initialize the server group. + + Args: + cli_args: CLI arguments for the server (engine-specific). + num_servers: Number of server instances to create. + start_port: Base port for server ports. + placement_group: External placement group for colocation mode. + If None, creates internal placement group. + placement_group_bundle_offset: Offset for bundle indices when using + external placement group (e.g., if training uses first N + bundles). + enable_dp: Enable data parallelism across servers. + enable_pd: Enable prefill-decode disaggregation. + nixl_side_channel_base: Base port for NIXL side channels. Each + server will be assigned a port of nixl_side_channel_base + + server_idx. + server_actor_cls: Server actor class implementing ServerActorProtocol. + Defaults to VLLMServerActor. + """ + self._server_actor_cls = server_actor_cls + self._cli_args = cli_args + self._num_servers = num_servers + self._start_port = start_port + self._external_pg = placement_group + self._bundle_offset = placement_group_bundle_offset + self._enable_dp = enable_dp + self._enable_pd = enable_pd + self._nixl_side_channel_base = nixl_side_channel_base + self._pool: Optional[ServerActorPool] = None + self._internal_pg: Optional[PlacementGroup] = None + + # Query the actor class for GPU requirements + self._num_gpus_per_server = server_actor_cls.compute_num_gpus_per_server(cli_args) + + logger.info( + f"ServerGroup: actor_cls={server_actor_cls.__name__}, " + f"num_servers={num_servers}, " + f"gpus_per_server={self._num_gpus_per_server}, " + f"enable_dp={enable_dp}, enable_pd={enable_pd}, " + f"external_pg={'yes' if placement_group else 'no'}" + ) + + def _create_placement_group(self) -> PlacementGroup: + """Create internal placement group with bundles for all servers.""" + total_bundles = self._num_servers * self._num_gpus_per_server + logger.info(f"Creating placement group with {total_bundles} bundles...") + pg = placement_group([{"CPU": 1, "GPU": 1} for _ in range(total_bundles)]) + ray.get(pg.ready()) + logger.info("Placement group ready") + return pg + + def _get_placement_group(self) -> PlacementGroup: + """Get the placement group (external or internal).""" + if self._external_pg is not None: + return self._external_pg + if self._internal_pg is None: + self._internal_pg = self._create_placement_group() + return self._internal_pg + + def _create_actor_class(self, pg: PlacementGroup, start_bundle_idx: int) -> Any: + """Create actor class with scheduling constraints for a specific bundle.""" + return ray.remote(self._server_actor_cls).options( + num_gpus=0, # GPU allocation managed by placement group + num_cpus=1, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=start_bundle_idx, + ), + ) + + def _create_actors(self) -> List[Any]: + """Create server actors with GPU resources.""" + pg = self._get_placement_group() 
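+
+        # Bundle layout sketch (illustrative, assuming 2 GPUs per server and a
+        # bundle offset of 4, e.g. colocated training holding bundles 0-3):
+        #   server 0 -> bundles 4-5, server 1 -> bundles 6-7, ...
+        # Each server gets a contiguous block of bundles for its TP/PP workers.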
+ + actors = [] + dp_address, dp_rpc_port = None, None + + for server_idx in range(self._num_servers): + # Calculate bundle index accounting for offset (colocation mode) + start_bundle_idx = self._bundle_offset + server_idx * self._num_gpus_per_server + + ServerActorClass = self._create_actor_class(pg, start_bundle_idx) + + actor = ServerActorClass.remote( + self._cli_args, + self._start_port + server_idx, + server_idx=server_idx, + start_bundle_idx=start_bundle_idx, + dp_size=self._num_servers if self._enable_dp else -1, + dp_master_address=dp_address, + dp_rpc_port=dp_rpc_port, + enable_pd=self._enable_pd, + nixl_side_channel_base=self._nixl_side_channel_base, + ) + + # Get DP info from server 0 which is where DP0 will be + if self._enable_dp and server_idx == 0: + dp_address, dp_rpc_port = ray.get(actor.get_dp_info.remote()) + logger.info(f"DP0 info: address={dp_address}, rpc_port={dp_rpc_port}") + + actors.append(actor) + + return actors + + def start(self) -> List[ServerInfo]: + """Create actors, start the pool, and return endpoints.""" + logger.info(f"Starting {self._num_servers} server(s)...") + actors = self._create_actors() + self._pool = ServerActorPool(actors) + server_infos = self._pool.start() + + for i, info in enumerate(server_infos): + logger.info(f"Server {i}: {info.url}") + + return server_infos + + def get_pool(self) -> Optional[ServerActorPool]: + """Get the underlying actor pool.""" + return self._pool + + def get_server_infos(self) -> List[ServerInfo]: + """Get the list of server endpoints.""" + return self._pool.get_server_infos() if self._pool else [] + + def get_server_urls(self) -> List[str]: + """Get the list of server URLs.""" + return self._pool.get_server_urls() if self._pool else [] + + def get_actors(self) -> List[Any]: + """Get the list of actor handles.""" + return self._pool.get_actors() if self._pool else [] + + def shutdown(self) -> None: + """Shutdown all servers.""" + if self._pool: + logger.info("Shutting down servers...") + self._pool.shutdown() diff --git a/skyrl-train/skyrl_train/inference_servers/server_pool.py b/skyrl-train/skyrl_train/inference_servers/server_pool.py new file mode 100644 index 000000000..620d5db04 --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/server_pool.py @@ -0,0 +1,57 @@ +""" +Generic server actor pool. +""" + +from typing import Any, List + +import ray + +from skyrl_train.inference_servers.common import ServerInfo + + +class ServerActorPool: + """Generic pool that manages a list of server actors. + + This layer provides a generic pool interface which can be extended to + support fault-tolerance, monitoring, etc. for now it's just a simple wrapper around a list of actor handles. + + Actors must implement: + - start() -> ServerInfo + - shutdown() -> None + + This layer is agnostic to the type of server (vLLM, SGLang, etc). + """ + + def __init__(self, actors: List[Any]): + """ + Initialize the pool with pre-constructed actor handles. 
+ + Args: + actors: List of Ray actor handles + """ + self._actors = actors + self._server_infos: List[ServerInfo] = [] + + def start(self) -> List[ServerInfo]: + """Start all actors and collect their server infos.""" + # Start all actors in parallel, wait for all to be ready + start_refs = [actor.start.remote() for actor in self._actors] + self._server_infos = ray.get(start_refs) + return self._server_infos + + def get_server_infos(self) -> List[ServerInfo]: + """Get the list of server endpoints.""" + return self._server_infos + + def get_server_urls(self) -> List[str]: + """Get the list of server URLs.""" + return [info.url for info in self._server_infos] + + def get_actors(self) -> List[Any]: + """Get the list of actor handles.""" + return self._actors + + def shutdown(self) -> None: + """Shutdown all actors.""" + shutdown_refs = [actor.shutdown.remote() for actor in self._actors] + ray.get(shutdown_refs) diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py new file mode 100644 index 000000000..019018705 --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/vllm_server_actor.py @@ -0,0 +1,361 @@ +""" +vLLM Server Actor - Ray actor running a vLLM OpenAI-compatible API server. +""" + +import asyncio +import logging +import os +import pickle +import time +from argparse import Namespace +from typing import Any, Dict, Optional, Tuple + +import httpx +import uvicorn +import vllm.envs as envs +from fastapi import Request +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.api_server import ( + build_app, + create_server_socket, + init_app_state, +) +from vllm.usage.usage_lib import UsageContext +from vllm.utils.system_utils import set_ulimit + +from skyrl_train.env_vars import ( + SKYRL_VLLM_DP_PORT_OFFSET, + SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S, +) +from skyrl_train.inference_servers.common import ServerInfo, get_node_ip, get_open_port +from skyrl_train.inference_servers.protocols import ServerActorProtocol +from skyrl_train.inference_servers.vllm_worker import VLLM_WORKER_EXTENSION_CLS + +logger = logging.getLogger(__name__) + + +class VLLMServerActor(ServerActorProtocol): + """ + Ray actor that runs a vLLM OpenAI-compatible API server. + + Implements ServerActorProtocol for use with ServerGroup. + + The server runs in the actor and exposes an HTTP endpoint that can be + called from anywhere (other actors, driver, external processes). + + Custom endpoints added for SkyRL: + - /get_server_info: Return parallelism info + + - (vLLM RFC: https://github.com/vllm-project/vllm/issues/31848) + - /init_weight_transfer: Initialize weight sync process group + - /update_weights: Update model weights via NCCL broadcast + - /finalize_weight_update: Post-processing after weight sync + """ + + @staticmethod + def compute_num_gpus_per_server(vllm_cli_args: Namespace) -> int: + """Compute the number of GPUs needed per server based on TP * PP. + + This logic might need adjustment if we want to support other + parallelism schemes. If we get to this point, we should add a + vllm-specific utility for it and keep the logic inside the engine. 
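+
+        For example (illustrative), tensor_parallel_size=2 and
+        pipeline_parallel_size=2 require 4 GPUs per server.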
+ """ + return vllm_cli_args.tensor_parallel_size * vllm_cli_args.pipeline_parallel_size + + def __init__( + self, + vllm_cli_args: Namespace, + start_port: int = 8000, + server_idx: int = 0, + start_bundle_idx: int = 0, + dp_size: int = -1, + dp_master_address: Optional[str] = None, + dp_rpc_port: Optional[int] = None, + # PD disaggregation settings + enable_pd: bool = False, + nixl_side_channel_base: int = 5600, + ): + """ + Initialize the vLLM server actor. + + Args: + vllm_cli_args: vLLM CLI arguments. + Required attributes: tensor_parallel_size, pipeline_parallel_size. + Optional: uvicorn_log_level, ssl_*, disable_uvicorn_access_log, kv_transfer_config. + start_port: Base port to start searching for free port + server_idx: Index of this server in the group + start_bundle_idx: Starting bundle index in the placement group for this server's workers + dp_size: Data parallel size (-1 to disable) + dp_master_address: DP master address (for non-rank-0 servers) + dp_rpc_port: DP RPC port (for non-rank-0 servers) + enable_pd: Enable prefill-decode disaggregation + nixl_side_channel_base: Base port for NIXL side channel + """ + self._cli_args = vllm_cli_args + self._ip = get_node_ip() + self._port = get_open_port(start_port) + self._server_idx = server_idx + self._num_gpus_per_server = self.compute_num_gpus_per_server(vllm_cli_args) + + # Ensure SkyRL's custom worker extension is used for weight sync + self._ensure_worker_extension() + + # Ensure Ray executor is used (required for GPU inheritance in placement groups) + self._ensure_ray_executor() + + # Update args with our assigned host/port + self._cli_args.host = "0.0.0.0" + self._cli_args.port = self._port + + # PD disaggregation: setup NIXL side channel for KV transfer + if enable_pd: + self._setup_nixl_side_channel(nixl_side_channel_base) + + # Each engine needs to know its dp_rank and dp_size so DP process groups are formed + if dp_size > 0: + self._cli_args.data_parallel_size = dp_size + self._cli_args.data_parallel_rank = server_idx + + # DP0 will be the master sharing its ip and port with others. + # So if we are not DP0, we need to pass master_ip and port from + # outside. otherwise, we can use the local ip and port. + if server_idx == 0: + dp_master_address, dp_rpc_port = self.get_dp_info() + + if dp_master_address is None or dp_rpc_port is None: + raise ValueError("DP address and RPC port must be set for non-server 0") + + self._cli_args.data_parallel_address = dp_master_address + self._cli_args.data_parallel_rpc_port = dp_rpc_port + logger.info( + f"Server {server_idx}: DP enabled - dp_size={dp_size}, dp_rank={server_idx}, " + f"dp_master_address={dp_master_address}, dp_rpc_port={dp_rpc_port}" + ) + + # Set bundle indices for this server's TP/PP workers in the placement group + bundle_indices = list(range(start_bundle_idx, start_bundle_idx + self._num_gpus_per_server)) + os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) + logger.info(f"Server {server_idx}: using bundle indices {bundle_indices}") + + # Initialized lazily to not block the actor initialization. + self._engine: Optional[AsyncLLMEngine] = None + self._server_task: Optional[asyncio.Task] = None + + def _ensure_worker_extension(self) -> None: + """ + Ensure the SkyRL worker extension is configured. + + The worker extension (WorkerWrap) provides the RPC methods needed for + weight synchronization (init_weight_update_communicator, load_weights). 
+ """ + if not hasattr(self._cli_args, "worker_extension_cls") or not self._cli_args.worker_extension_cls: + self._cli_args.worker_extension_cls = VLLM_WORKER_EXTENSION_CLS + logger.info(f"Using default worker extension: {VLLM_WORKER_EXTENSION_CLS}") + else: + logger.info(f"Using provided worker extension: {self._cli_args.worker_extension_cls}") + + def _ensure_ray_executor(self) -> None: + """ + Ensure Ray is used as the distributed executor backend. + + When running inside a Ray actor, we must use the Ray executor so that + workers are spawned and properly inherit GPU allocation from the + placement group. + """ + if ( + not hasattr(self._cli_args, "distributed_executor_backend") + or self._cli_args.distributed_executor_backend != "ray" + ): + self._cli_args.distributed_executor_backend = "ray" + + def _setup_nixl_side_channel(self, base_port: int) -> None: + """ + Setup NIXL side channel for PD disaggregation. + + Each server instance needs a unique side channel port for KV transfer handshake. + """ + import json + + side_channel_port = base_port + self._server_idx + os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(side_channel_port) + os.environ["VLLM_NIXL_SIDE_CHANNEL_HOST"] = self._ip + + engine_id = f"server-{self._server_idx}-{self._ip}-{side_channel_port}" + + if hasattr(self._cli_args, "kv_transfer_config") and self._cli_args.kv_transfer_config: + try: + kv_config = json.loads(self._cli_args.kv_transfer_config) + except (json.JSONDecodeError, TypeError) as e: + raise ValueError( + f"Invalid kv_transfer_config: expected valid JSON string, " + f"got {type(self._cli_args.kv_transfer_config).__name__}: {e}" + ) from e + kv_config["engine_id"] = engine_id + self._cli_args.kv_transfer_config = json.dumps(kv_config) + + logger.info( + f"Server {self._server_idx}: NIXL side channel configured - " + f"host={self._ip}, port={side_channel_port}, engine_id={engine_id}" + ) + + def get_server_info(self) -> ServerInfo: + """Get the server's IP and port info.""" + return ServerInfo(ip=self._ip, port=self._port) + + def _get_extended_server_info(self) -> Dict[str, Any]: + """Return extended server info including parallelism settings.""" + return { + "ip": self._ip, + "port": self._port, + "url": f"http://{self._ip}:{self._port}", + "server_idx": self._server_idx, + "world_size": self._num_gpus_per_server, + } + + def get_dp_info(self) -> Tuple[str, int]: + """Get the DP master address and RPC port (for server 0 to share with others).""" + dp_rpc_port = self._port + SKYRL_VLLM_DP_PORT_OFFSET + return (self._ip, dp_rpc_port) + + async def start(self) -> ServerInfo: + """Start the vLLM server. 
Blocks until server is healthy.""" + + set_ulimit() + logger.info(f"Starting server on {self._ip}:{self._port}...") + + # Start HTTP server as background asyncio task + self._server_task = asyncio.create_task(self._run_server()) + + # Wait until the server is actually healthy + await self._wait_until_healthy() + + return self.get_server_info() + + async def _wait_until_healthy(self, timeout: float = SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S) -> None: + """Poll the /health endpoint until it responds OK.""" + url = f"http://{self._ip}:{self._port}/health" + start_time = time.time() + + async with httpx.AsyncClient() as client: + while True: + # Check if server task failed + if self._server_task.done(): + exc = self._server_task.exception() + if exc: + raise exc + raise RuntimeError("Server task exited unexpectedly") + + try: + resp = await client.get(url, timeout=5.0) + if resp.status_code == 200: + logger.info(f"Server {self._ip}:{self._port} is healthy") + return + except httpx.RequestError: + pass + + if time.time() - start_time > timeout: + raise TimeoutError(f"Server failed to become healthy within {timeout}s") + + await asyncio.sleep(1.0) + + async def _run_server(self) -> None: + """Internal method to run the HTTP server.""" + sock_addr = (self._cli_args.host, self._cli_args.port) + sock = create_server_socket(sock_addr) + app = build_app(self._cli_args) + + # Initialize the engine (this loads the model - takes time) + engine_args = AsyncEngineArgs.from_cli_args(self._cli_args) + self._engine = AsyncLLMEngine.from_engine_args( + engine_args=engine_args, + usage_context=UsageContext.OPENAI_API_SERVER, + ) + logger.info(f"Engine initialized on {self._ip}:{self._port}, adding custom endpoints...") + + # Add custom SkyRL endpoints + self._add_custom_endpoints(app) + + await init_app_state(self._engine, app.state, self._cli_args) + + # Use uvicorn directly (serve_http tries to add signal handlers which fails in Ray actors) + config = uvicorn.Config( + app, + host=self._cli_args.host, + port=self._cli_args.port, + log_level=self._cli_args.uvicorn_log_level, + timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE, + ssl_keyfile=self._cli_args.ssl_keyfile, + ssl_certfile=self._cli_args.ssl_certfile, + ssl_ca_certs=self._cli_args.ssl_ca_certs, + ssl_cert_reqs=self._cli_args.ssl_cert_reqs, + access_log=not getattr(self._cli_args, "disable_uvicorn_access_log", False), + ) + server = uvicorn.Server(config) + await server.serve(sockets=[sock]) + + def _add_custom_endpoints(self, app) -> None: + """Add custom SkyRL endpoints to the FastAPI app.""" + engine = self._engine + + @app.get("/get_server_info") + async def _get_server_info(): + """Return server parallelism info.""" + return self._get_extended_server_info() + + # TODO (Kourosh): After https://github.com/vllm-project/vllm/pull/ + # 31943/ is merged, use the native API. 
+ @app.post("/init_weight_transfer") + async def _init_weight_transfer(request: Request): + """Initialize weight sync process group.""" + from skyrl_train.weight_sync import BroadcastInitInfo + + data = await request.json() + init_info = BroadcastInitInfo(**data).for_engine( + engine_index=self._server_idx, + tp_size=self._cli_args.tensor_parallel_size, + pp_size=self._cli_args.pipeline_parallel_size, + ) + pickled_init_info = pickle.dumps(init_info) + + await engine.collective_rpc( + "init_weight_update_communicator", + args=(pickled_init_info,), + ) + return {"status": "ok"} + + @app.post("/update_weights") + async def _update_weights(request: Request): + """Update model weights via NCCL broadcast.""" + from skyrl_train.weight_sync import BroadcastWeightUpdateRequest + + data = await request.json() + weight_request = BroadcastWeightUpdateRequest(**data) + pickled_request = pickle.dumps(weight_request) + + await engine.collective_rpc( + "load_weights", + args=(pickled_request,), + ) + return {"status": "ok"} + + @app.post("/finalize_weight_update") + async def _finalize_weight_update(request: Request): + """ + Finalize weight update - post-processing hook. + + Currently a no-op, reserved for future use e.g. Quantization + See https://github.com/vllm-project/vllm/issues/31848 for more + details. + """ + # No-op for now - placeholder for future post-processing + return {"status": "ok"} + + async def shutdown(self) -> None: + """Gracefully shutdown the server.""" + if self._server_task: + self._server_task.cancel() + try: + await self._server_task + except asyncio.CancelledError: + pass diff --git a/skyrl-train/skyrl_train/inference_servers/vllm_worker.py b/skyrl-train/skyrl_train/inference_servers/vllm_worker.py new file mode 100644 index 000000000..8249b30a7 --- /dev/null +++ b/skyrl-train/skyrl_train/inference_servers/vllm_worker.py @@ -0,0 +1,103 @@ +""" +vLLM Worker Extension for SkyRL weight synchronization. + +This module provides WorkerWrap, a vLLM worker extension class that enables +efficient NCCL-based and CUDA IPC-based weight updates from the training +process to inference workers. + +TODO: This will be removed once vLLM natively supports weight sync APIs. +See: https://github.com/vllm-project/vllm/issues/31848 + +Usage: + Pass as --worker-extension-cls to vLLM: + + vllm serve ... --worker-extension-cls skyrl_train.inference_servers.vllm_worker.WorkerWrap +""" + +import warnings + +import torch + +# Path to this worker extension class for use in CLI args (derived from module path) +VLLM_WORKER_EXTENSION_CLS = f"{__name__}.WorkerWrap" + + +class WorkerWrap: + """ + vLLM worker extension for SkyRL weight synchronization. + + This class is injected into vLLM workers via --worker-extension-cls and + provides methods that can be called via engine.collective_rpc() to + coordinate weight updates across all TP/PP workers. + + Methods: + init_weight_update_communicator: Initialize the weight receiver + load_weights: Receive and load weights from trainer + teardown_weight_receiver: Clean up weight receiver resources + """ + + def test_rpc(self, *args, **kwargs): + """Test RPC call to worker.""" + return args, kwargs + + def init_weight_update_communicator(self, init_info: bytes): + """ + Initialize weight update communicator from init info. + + Args: + init_info: Pickled bytes of WeightSyncInitInfo from the sender. 
+ """ + import pickle + + assert torch.distributed.is_initialized(), "default torch process group must be initialized" + + # Unpickle init_info to restore the original object type + assert isinstance(init_info, bytes), f"Expected bytes, got {type(init_info).__name__}" + init_info = pickle.loads(init_info) + + strategy_cls = init_info.strategy_type() + + if hasattr(self, "_weight_receiver") and self._weight_receiver is not None: + # TODO(haochen): we should get rid of this flag and override existing receiver. + if init_info.override_existing_receiver: + self._weight_receiver.teardown() + self._weight_receiver = None + else: + warnings.warn( + "Detected an existing weight receiver. " + "For overriding, use `generator.override_existing_update_group=enable`" + ) + return + + self._weight_receiver = strategy_cls.create_receiver(init_info) + + def load_weights(self, request: bytes) -> None: + """ + Load weights using the receiver. + + This method is called via collective_rpc from the weight loader. + + Args: + request: Pickled bytes of WeightUpdateRequest. + """ + import pickle + + # Unpickle request to restore the original object type + assert isinstance(request, bytes), f"Expected bytes, got {type(request).__name__}" + request = pickle.loads(request) + + weight_list = [] + for name, tensor in self._weight_receiver.receive_weights(request): + weight_list.append((name, tensor)) + + self.model_runner.model.load_weights(weights=weight_list) + + for weight in weight_list: + del weight + + def teardown_weight_receiver(self): + """Clean up weight receiver resources.""" + if not hasattr(self, "_weight_receiver") or self._weight_receiver is None: + warnings.warn("No weight receiver to teardown") + return + self._weight_receiver.teardown() diff --git a/skyrl-train/tests/cpu/inference_servers/test_common.py b/skyrl-train/tests/cpu/inference_servers/test_common.py new file mode 100644 index 000000000..76136368c --- /dev/null +++ b/skyrl-train/tests/cpu/inference_servers/test_common.py @@ -0,0 +1,37 @@ +"""Tests for inference_servers.common module.""" + +import socket + +from skyrl_train.inference_servers.common import ( + get_node_ip, + get_open_port, +) + + +class TestGetIp: + """Tests for get_ip function.""" + + def test_get_ip_returns_string(self): + """Test that get_ip returns a string.""" + ip = get_node_ip() + assert isinstance(ip, str) + assert len(ip) > 0 + assert ip != "" + assert "." 
in ip or ":" in ip + + +class TestGetOpenPort: + """Tests for get_open_port function.""" + + def test_get_open_port_os_assigned(self): + """Test that get_open_port returns an available port (OS assigned).""" + port = get_open_port() + assert isinstance(port, int) + assert 1 <= port <= 65535 + self._verify_port_is_free(port) + + def _verify_port_is_free(self, port: int): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("", port)) + s.listen(1) diff --git a/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py b/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py new file mode 100644 index 000000000..aff303560 --- /dev/null +++ b/skyrl-train/tests/cpu/inference_servers/test_remote_inference_client.py @@ -0,0 +1,483 @@ +"""Tests for RemoteInferenceClient.""" + +import asyncio +import pickle +import threading +import time +from typing import List + +import httpx +import pytest +import pytest_asyncio +import uvicorn +from fastapi import FastAPI, Request + +from skyrl_train.inference_servers.common import get_open_port +from skyrl_train.inference_servers.remote_inference_client import RemoteInferenceClient, PauseMode + + +def create_mock_vllm_server(server_id: int) -> FastAPI: + """Create a mock vLLM server with standard endpoints.""" + app = FastAPI() + + @app.get("/health") + async def health(): + return {"status": "ok"} + + @app.get("/get_server_info") + async def get_server_info(): + return { + "ip": "127.0.0.1", + "port": 8000 + server_id, + "url": f"http://127.0.0.1:{8000 + server_id}", + "server_idx": server_id, + "world_size": 2, # Simulate TP=2 + } + + @app.post("/v1/completions") + async def completions(request: Request): + body = await request.json() + prompts = body.get("prompt", []) + n_prompts = len(prompts) if isinstance(prompts, list) else 1 + return { + "choices": [ + {"index": i, "text": f"Response {i} from server {server_id}", "finish_reason": "stop"} + for i in range(n_prompts) + ] + } + + @app.post("/v1/chat/completions") + async def chat_completions(request: Request): + return {"choices": [{"message": {"content": f"Chat from server {server_id}"}}]} + + @app.post("/tokenize") + async def tokenize(request: Request): + return {"tokens": [1, 2, 3]} + + @app.post("/detokenize") + async def detokenize(request: Request): + return {"prompt": "hello world"} + + # Control plane endpoints + @app.post("/pause") + async def pause(request: Request): + return {"status": "paused", "server_id": server_id} + + @app.post("/resume") + async def resume(): + return {"status": "resumed", "server_id": server_id} + + @app.get("/is_paused") + async def is_paused(): + # Mock always returns not paused for basic tests + return {"is_paused": False} + + @app.post("/sleep") + async def sleep(request: Request): + return {"status": "sleeping", "server_id": server_id} + + @app.post("/wake_up") + async def wake_up(): + return {"status": "awake", "server_id": server_id} + + @app.post("/reset_prefix_cache") + async def reset_prefix_cache(request: Request): + return {"status": "cache_reset", "server_id": server_id} + + @app.post("/init_weight_transfer") + async def init_weight_transfer(request: Request): + return {"status": "ok", "server_id": server_id} + + @app.post("/update_weights") + async def update_weights(request: Request): + return {"status": "ok", "server_id": server_id} + + @app.post("/finalize_weight_update") + async def finalize_weight_update(request: Request): + return {"status": 
"ok", "server_id": server_id} + + return app + + +def start_server(port: int, server_id: int) -> uvicorn.Server: + """Start a mock server, return the server instance.""" + app = create_mock_vllm_server(server_id) + config = uvicorn.Config(app, host="0.0.0.0", port=port, log_level="error") + server = uvicorn.Server(config) + + def run(): + asyncio.run(server.serve()) + + threading.Thread(target=run, daemon=True).start() + return server + + +def wait_ready(url: str, timeout: float = 5.0) -> bool: + """Wait for server to become healthy.""" + start = time.time() + while time.time() - start < timeout: + try: + if httpx.get(f"{url}/health", timeout=1.0).status_code == 200: + return True + except httpx.RequestError: + time.sleep(0.1) + return False + + +@pytest.fixture(scope="module") +def mock_servers(): + """Start mock vLLM servers, return proxy_url and server_urls.""" + servers: List[uvicorn.Server] = [] + ports = [get_open_port(), get_open_port()] + server_urls = [f"http://127.0.0.1:{p}" for p in ports] + + for i, port in enumerate(ports): + servers.append(start_server(port, server_id=i)) + + for url in server_urls: + assert wait_ready(url), f"Server {url} failed to start" + + # proxy_url defaults to first server; can be replaced with router URL later + yield {"proxy_url": server_urls[0], "server_urls": server_urls} + + # Cleanup + for server in servers: + server.should_exit = True + time.sleep(0.3) + + +@pytest_asyncio.fixture +async def client(mock_servers): + """Create a RemoteInferenceClient for data/control plane tests.""" + client = RemoteInferenceClient( + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], + ) + yield client + await client.teardown() + + +class TestRemoteInferenceClientInit: + """Test client initialization and serialization.""" + + def test_serialization(self, mock_servers): + """Client can be pickled and unpickled.""" + client = RemoteInferenceClient( + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], + model_name="test-model", + ) + + # Pickle and unpickle + pickled = pickle.dumps(client) + restored = pickle.loads(pickled) + + assert restored.proxy_url == client.proxy_url + assert restored.server_urls == client.server_urls + assert restored.model_name == client.model_name + # Session should be None after unpickling + assert restored._session is None + + def test_from_server_group(self, mock_servers): + """Test factory method from_server_group.""" + client = RemoteInferenceClient.from_server_group( + server_urls=mock_servers["server_urls"], + router_url=mock_servers["proxy_url"], + model_name="test-model", + ) + assert client.proxy_url == mock_servers["proxy_url"] + assert client.server_urls == mock_servers["server_urls"] + + def test_from_router(self, mock_servers): + """Test factory method from_router.""" + router_url = mock_servers["proxy_url"] + client = RemoteInferenceClient.from_router( + router_url=router_url, + model_name="test-model", + ) + assert client.proxy_url == router_url + assert client.server_urls == [router_url] + + +class TestDataPlane: + """Test data plane methods.""" + + @pytest.mark.asyncio + async def test_generate(self, client): + """Test generate method.""" + input_batch = { + "prompt_token_ids": [[1, 2, 3], [4, 5, 6]], + "sampling_params": {"max_tokens": 100}, + } + result = await client.generate(input_batch) + + assert "responses" in result + assert "stop_reasons" in result + assert len(result["responses"]) == 2 + assert all(r == "stop" for r in result["stop_reasons"]) + # 
response_ids are tokenized from the response + assert len(result["response_ids"]) == 2 + + @pytest.mark.asyncio + async def test_generate_with_session_id(self, client): + """Test generate with session ID for consistent routing.""" + input_batch = { + "prompt_token_ids": [[1, 2, 3]], + "session_ids": ["test-session"], + } + result = await client.generate(input_batch) + assert len(result["responses"]) == 1 + + @pytest.mark.asyncio + async def test_chat_completion(self, client): + """Test chat completion method.""" + request_payload = { + "json": { + "model": "test", + "messages": [{"role": "user", "content": "Hello"}], + }, + "headers": {}, + } + result = await client.chat_completion(request_payload) + assert "choices" in result + + @pytest.mark.asyncio + async def test_completion(self, client): + """Test completion method.""" + request_payload = { + "json": {"model": "test", "prompt": "Hello"}, + "headers": {}, + } + result = await client.completion(request_payload) + assert "choices" in result + + @pytest.mark.asyncio + async def test_tokenize(self, client): + """Test tokenize method.""" + result = await client.tokenize(["hello", "world"]) + assert len(result) == 2 + assert result[0] == [1, 2, 3] # Mock response + + @pytest.mark.asyncio + async def test_detokenize(self, client): + """Test detokenize method.""" + result = await client.detokenize([[1, 2, 3], [4, 5, 6]]) + assert len(result) == 2 + assert result[0] == "hello world" # Mock response + + +class TestControlPlane: + """Test control plane methods (fan-out to all servers).""" + + @pytest.mark.asyncio + async def test_pause_abort_mode(self, client): + """Test pause with ABORT mode (default) fans out to all servers.""" + result = await client.pause(mode=PauseMode.ABORT) + assert len(result) == 2 + for url, response in result.items(): + assert response["status"] == 200 + assert response["body"]["status"] == "paused" + + @pytest.mark.asyncio + async def test_pause_finish_mode(self, client): + """Test pause with FINISH mode fans out to all servers.""" + result = await client.pause(mode=PauseMode.FINISH) + assert len(result) == 2 + for url, response in result.items(): + assert response["status"] == 200 + + @pytest.mark.asyncio + async def test_resume(self, client): + """Test resume fans out to all servers.""" + # Pause first + await client.pause() + + # Resume + result = await client.resume() + assert len(result) == 2 + for url, response in result.items(): + assert response["status"] == 200 + + @pytest.mark.asyncio + async def test_sleep(self, client): + """Test sleep fans out to all servers.""" + result = await client.sleep(level=2) + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_wake_up(self, client): + """Test wake_up fans out to all servers.""" + result = await client.wake_up() + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_reset_prefix_cache(self, client): + """Test reset_prefix_cache fans out to all servers.""" + result = await client.reset_prefix_cache() + assert len(result) == 2 + + +class TestWeightSync: + """Test weight sync methods.""" + + @pytest.mark.asyncio + async def test_init_weight_transfer(self, client): + """Test init_weight_transfer fans out to all servers.""" + from skyrl_train.weight_sync import BroadcastInitInfo + + init_info = BroadcastInitInfo( + master_addr="127.0.0.1", + master_port=29500, + rank_offset=1, + world_size=5, + group_name="test", + backend="nccl", + model_dtype_str="torch.bfloat16", + override_existing_receiver=True, + ) + result = await 
client.init_weight_transfer(init_info) + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_update_weights(self, client): + """Test update_weights fans out to all servers.""" + from skyrl_train.weight_sync import BroadcastWeightUpdateRequest + + request = BroadcastWeightUpdateRequest( + names=["layer.weight"], + dtypes=["torch.bfloat16"], + shapes=[[1024, 1024]], + ) + result = await client.update_weights(request) + assert len(result) == 2 + + @pytest.mark.asyncio + async def test_finalize_weight_update(self, client): + """Test finalize_weight_update fans out to all servers.""" + result = await client.finalize_weight_update() + assert len(result) == 2 + + +class TestServerInfo: + """Test server info and world_size.""" + + @pytest.mark.asyncio + async def test_get_world_size(self, client): + """Test world_size fetching and caching.""" + # First call fetches from all servers and sums + world_size = await client.get_world_size() + # Each mock server reports world_size=2, we have 2 servers = 4 + assert world_size == 4 + + # Second call returns cached value + world_size2 = await client.get_world_size() + assert world_size2 == 4 + + +class TestContextManager: + """Test async context manager.""" + + @pytest.mark.asyncio + async def test_async_context_manager(self, mock_servers): + """Test using client as async context manager.""" + + client = RemoteInferenceClient( + proxy_url=mock_servers["proxy_url"], + server_urls=mock_servers["server_urls"], + ) + + async with client: + result = await client.resume() + assert len(result) == 2 + + # Session should be closed after exiting context + assert client._session is None or client._session.closed + + +class TestRetryOnAbort: + """Test retry on abort functionality.""" + + @pytest.fixture + def abort_mock_server(self): + """Create a mock server that returns abort on first call, then stop.""" + app = FastAPI() + call_count = {"completions": 0} + + @app.get("/health") + async def health(): + return {"status": "ok"} + + @app.post("/v1/completions") + async def completions(request: Request): + call_count["completions"] += 1 + await request.json() # Consume body + + # First call returns abort with partial response + if call_count["completions"] == 1: + return {"choices": [{"index": 0, "text": "Partial ", "finish_reason": "abort"}]} + # Second call returns complete response + else: + return {"choices": [{"index": 0, "text": "response complete", "finish_reason": "stop"}]} + + @app.post("/tokenize") + async def tokenize(request: Request): + body = await request.json() + prompt = body.get("prompt", "") + # Simple tokenization: one token per word + tokens = [hash(word) % 10000 for word in prompt.split()] + return {"tokens": tokens} + + @app.get("/get_server_info") + async def get_server_info(): + return {"world_size": 1} + + @app.get("/is_paused") + async def is_paused(): + # Not paused - allows retry to proceed immediately + return {"is_paused": False} + + # Start server in background thread + port = get_open_port() + config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level="warning") + server = uvicorn.Server(config) + thread = threading.Thread(target=server.run, daemon=True) + thread.start() + + # Wait for server to be ready + for _ in range(100): + try: + httpx.get(f"http://127.0.0.1:{port}/health", timeout=0.1) + break + except Exception: + time.sleep(0.05) + + yield f"http://127.0.0.1:{port}", call_count + + server.should_exit = True + thread.join(timeout=1) + + @pytest.mark.asyncio + async def test_retry_on_abort(self, 
abort_mock_server): + """Test that retry on abort is always active (built-in behavior).""" + url, call_count = abort_mock_server + client = RemoteInferenceClient( + proxy_url=url, + server_urls=[url], + ) + + try: + result = await client.generate( + { + "prompt_token_ids": [[1, 2, 3]], + "sampling_params": {"max_tokens": 100}, + } + ) + + # Should get complete response after retry + assert result["stop_reasons"][0] == "stop" + assert result["responses"][0] == "Partial response complete" + assert call_count["completions"] == 2 + # Should have response_ids from tokenization + assert len(result["response_ids"][0]) > 0 + finally: + await client.teardown() diff --git a/skyrl-train/tests/cpu/inference_servers/test_router.py b/skyrl-train/tests/cpu/inference_servers/test_router.py new file mode 100644 index 000000000..386e72fac --- /dev/null +++ b/skyrl-train/tests/cpu/inference_servers/test_router.py @@ -0,0 +1,118 @@ +"""Tests for InferenceRouter.""" + +import asyncio +import threading +import time +from typing import List + +import httpx +import pytest +import uvicorn +from fastapi import FastAPI + +from skyrl_train.inference_servers.common import get_open_port +from skyrl_train.inference_servers.router import InferenceRouter + + +def create_mock_server(server_id: int) -> FastAPI: + app = FastAPI() + + @app.api_route("/{path:path}", methods=["GET", "POST"]) + async def catch_all(path: str): + return {"server_id": server_id, "path": f"/{path}"} + + return app + + +def start_server(port: int, server_id: int) -> uvicorn.Server: + """Start a mock server, return the server instance for cleanup.""" + app = create_mock_server(server_id) + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="error") + server = uvicorn.Server(config) + + def run(): + asyncio.run(server.serve()) + + threading.Thread(target=run, daemon=True).start() + return server + + +def wait_ready(url: str, timeout: float = 5.0) -> bool: + start = time.time() + while time.time() - start < timeout: + try: + if httpx.get(f"{url}/health", timeout=1.0).status_code == 200: + return True + except httpx.RequestError: + time.sleep(0.1) + return False + + +@pytest.fixture(scope="module") +def env(): + """Start mock servers and router, clean up after tests.""" + servers: List[uvicorn.Server] = [] + + # Start mock servers + ports = [get_open_port(), get_open_port()] + router_port = get_open_port() + urls = [f"http://127.0.0.1:{p}" for p in ports] + + for i, port in enumerate(ports): + servers.append(start_server(port, server_id=i)) + for url in urls: + assert wait_ready(url) + + # Start router + router = InferenceRouter(urls, host="127.0.0.1", port=router_port) + router._client = httpx.AsyncClient(timeout=httpx.Timeout(None)) + router._app = router._build_app() + + router_config = uvicorn.Config(router._app, host="127.0.0.1", port=router_port, log_level="error") + router_server = uvicorn.Server(router_config) + servers.append(router_server) + + def run_router(): + asyncio.run(router_server.serve()) + + threading.Thread(target=run_router, daemon=True).start() + + router_url = f"http://127.0.0.1:{router_port}" + assert wait_ready(router_url) + + yield router_url + + # Cleanup: signal all servers to shutdown + for server in servers: + server.should_exit = True + time.sleep(0.5) # Give servers time to shutdown + + +def test_round_robin(env): + """Requests without session distribute across servers.""" + server_ids = {httpx.get(f"{env}/health").json()["server_id"] for _ in range(4)} + assert len(server_ids) == 2 + + +def 
test_session_affinity(env): + """Same X-Session-ID routes to same server.""" + headers = {"X-Session-ID": "sticky"} + ids = [httpx.get(f"{env}/health", headers=headers).json()["server_id"] for _ in range(3)] + assert len(set(ids)) == 1 + + +def test_control_plane_fanout(env): + """Control plane routes fan out to all servers.""" + resp = httpx.post(f"{env}/sleep", json={}) + assert resp.status_code == 200 + # Response is a mapping of server_url -> {status, body} + response_map = resp.json() + assert len(response_map) == 2 # Both servers received the request + for server_url, result in response_map.items(): + assert result["status"] == 200 + + +def test_list_servers(env): + """/servers returns all server URLs.""" + resp = httpx.get(f"{env}/servers") + assert resp.status_code == 200 and len(resp.json()["servers"]) == 2 diff --git a/skyrl-train/tests/gpu/gpu_ci/conftest.py b/skyrl-train/tests/gpu/gpu_ci/conftest.py index e6f35b11c..ed33b89b5 100644 --- a/skyrl-train/tests/gpu/gpu_ci/conftest.py +++ b/skyrl-train/tests/gpu/gpu_ci/conftest.py @@ -1,8 +1,10 @@ +import os import pytest import ray from loguru import logger from functools import lru_cache from skyrl_train.utils.utils import peer_access_supported +from skyrl_train.env_vars import SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV @lru_cache(5) @@ -11,7 +13,7 @@ def log_once(msg): return None -@pytest.fixture +@pytest.fixture(scope="class") def ray_init_fixture(): if ray.is_initialized(): ray.shutdown() @@ -32,6 +34,9 @@ def ray_init_fixture(): env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" env_vars["NVTE_FUSED_ATTN"] = "0" + if SKYRL_INCLUDE_PYTHONPATH_IN_RUNTIME_ENV: + env_vars["PYTHONPATH"] = os.environ.get("PYTHONPATH") + logger.info(f"Initializing Ray with environment variables: {env_vars}") ray.init(runtime_env={"env_vars": env_vars}) diff --git a/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py b/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py new file mode 100644 index 000000000..53e6f65f1 --- /dev/null +++ b/skyrl-train/tests/gpu/gpu_ci/test_inference_server_group.py @@ -0,0 +1,203 @@ +""" +GPU CI tests for ServerGroup + InferenceRouter. 
+ +Tests: + - 2 vLLM servers with TP=2 (4 GPUs total) + - Router with load balancing and control plane fan-out + - Health, completions, get_server_info, session affinity, pause/resume + +Run: + uv run pytest tests/gpu/gpu_ci/test_inference_server_group.py -v -s +""" + +import asyncio +import time + +import httpx +import pytest +import torch +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from skyrl_train.inference_servers.common import get_open_port +from skyrl_train.inference_servers.router import InferenceRouter +from skyrl_train.inference_servers.server_group import ServerGroup + +MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + +# Skip entire module if not enough GPUs +_gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 +if _gpu_count < 4: + pytest.skip(f"Need 4 GPUs for full test suite, found {_gpu_count}", allow_module_level=True) + + +def make_vllm_cli_args( + model: str, + tp_size: int = 2, + load_format: str = "auto", +): + """Build parsed vLLM CLI args (an argparse Namespace) using the official parser.""" + parser = FlexibleArgumentParser(description="vLLM server") + parser = make_arg_parser(parser) + return parser.parse_args( + [ + "--model", + model, + "--tensor-parallel-size", + str(tp_size), + "--enforce-eager", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "2048", + "--load-format", + load_format, + ] + ) + + +def wait_for_url(url: str, timeout: float = 180.0) -> bool: + """Wait for a URL to become available.""" + start = time.time() + while time.time() - start < timeout: + try: + resp = httpx.get(f"{url}/health", timeout=5.0) + if resp.status_code == 200: + return True + except httpx.RequestError: + time.sleep(2.0) + return False + + +@pytest.fixture(scope="class") +def server_group_and_router(ray_init_fixture): + """Create 2 vLLM servers (TP=2 each) + router.""" + cli_args = make_vllm_cli_args(MODEL, tp_size=2) + start_port = get_open_port() + + # Create server group with 2 servers + group = ServerGroup( + cli_args=cli_args, + num_servers=2, + start_port=start_port, + ) + server_infos = group.start() + server_urls = [info.url for info in server_infos] + + # Wait for servers + for url in server_urls: + assert wait_for_url(url), f"Server {url} failed to start" + + # Create router + router_port = get_open_port() + router = InferenceRouter(server_urls, host="0.0.0.0", port=router_port) + router_url = router.start() + assert wait_for_url(router_url), "Router failed to start" + + yield { + "group": group, + "server_urls": server_urls, + "router": router, + "router_url": router_url, + } + + router.shutdown() + group.shutdown() + + +class TestServerGroupAndRouter: + """Tests for ServerGroup + InferenceRouter with 2 TP=2 servers.""" + + def test_health_check(self, server_group_and_router): + """Health endpoint works through router.""" + router_url = server_group_and_router["router_url"] + resp = httpx.get(f"{router_url}/health", timeout=10.0) + assert resp.status_code == 200 + + def test_list_servers(self, server_group_and_router): + """/servers returns all backends.""" + router_url = server_group_and_router["router_url"] + resp = httpx.get(f"{router_url}/servers", timeout=10.0) + assert resp.status_code == 200 + assert len(resp.json()["servers"]) == 2 + + def test_get_server_info(self, server_group_and_router): + """/get_server_info returns mapping of server_url -> info for all servers.""" + router_url = server_group_and_router["router_url"] + server_urls =
server_group_and_router["server_urls"] + + resp = httpx.get(f"{router_url}/get_server_info", timeout=10.0) + assert resp.status_code == 200 + info_map = resp.json() + print(f"Server info map: {info_map}") + + # Should have info for each server + assert len(info_map) == 2 + for url in server_urls: + assert url in info_map + server_info = info_map[url] + # Each server has TP=2, so per-server world_size=2 + assert server_info["world_size"] == 2 + + def test_completion_request(self, server_group_and_router): + """Completion requests work through router.""" + router_url = server_group_and_router["router_url"] + + payload = { + "model": MODEL, + "prompt": "What is 2 + 2? Answer:", + "max_tokens": 16, + "temperature": 0.0, + } + + resp = httpx.post(f"{router_url}/v1/completions", json=payload, timeout=60.0) + assert resp.status_code == 200 + data = resp.json() + assert "choices" in data + assert len(data["choices"]) > 0 + assert "text" in data["choices"][0] + print(f"Completion: {data['choices'][0]['text']}") + + @pytest.mark.asyncio + async def test_pause_resume(self, server_group_and_router): + """Pause/resume control plane routes work.""" + router_url = server_group_and_router["router_url"] + + async with httpx.AsyncClient() as client: + # Pause + resp = await client.post( + f"{router_url}/pause", + json={"wait_for_inflight_request": False}, + timeout=30.0, + ) + assert resp.status_code == 200 + + # Check is paused + resp = await client.get(f"{router_url}/is_paused", timeout=30.0) + assert resp.status_code == 200 + assert resp.json()["is_paused"] is True + + # Send a request while paused (should block) + async def send_request(): + r = await client.post( + f"{router_url}/v1/completions", + json={"model": MODEL, "prompt": "Test", "max_tokens": 4}, + timeout=60.0, + ) + assert r.status_code == 200 + return r.json() + + task = asyncio.create_task(send_request()) + await asyncio.sleep(1) + + # Task should not be done here (request blocked by pause) + assert not task.done() + + # Resume + resp = await client.post(f"{router_url}/resume", json={}, timeout=30.0) + assert resp.status_code == 200 + + # Verify that after resume, the request is completed + result = await task + assert result["choices"][0]["text"] is not None diff --git a/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py b/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py new file mode 100644 index 000000000..a26649fd5 --- /dev/null +++ b/skyrl-train/tests/gpu/gpu_ci/test_weight_sync.py @@ -0,0 +1,308 @@ +""" +GPU CI tests for weight synchronization from trainer to inference server. 
+ +Tests: + - 1 vLLM server with TP=2 + dummy weights + - Trainer emulation with real weights on separate GPU + - Weight sync via NCCL broadcast through router + +Run: + uv run pytest tests/gpu/gpu_ci/test_weight_sync.py -v -s +""" + +import time + +import httpx +import pytest +import ray +import torch +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from transformers import AutoModelForCausalLM +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.utils.argparse_utils import FlexibleArgumentParser + +from skyrl_train.inference_servers.common import get_node_ip, get_open_port +from skyrl_train.inference_servers.router import InferenceRouter +from skyrl_train.inference_servers.server_group import ServerGroup + +MODEL = "Qwen/Qwen2.5-0.5B-Instruct" + + +# Skip entire module if not enough GPUs (need 3: 1 trainer + 2 for TP=2 server) +_gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 +if _gpu_count < 3: + pytest.skip(f"Need 3 GPUs for weight sync test, found {_gpu_count}", allow_module_level=True) + + +def make_vllm_cli_args( + model: str, + tp_size: int = 2, + load_format: str = "auto", +): + """Build parsed vLLM CLI args (an argparse Namespace) using the official parser.""" + parser = FlexibleArgumentParser(description="vLLM server") + parser = make_arg_parser(parser) + return parser.parse_args( + [ + "--model", + model, + "--tensor-parallel-size", + str(tp_size), + "--enforce-eager", + "--gpu-memory-utilization", + "0.5", + "--max-model-len", + "2048", + "--load-format", + load_format, + ] + ) + + +def wait_for_url(url: str, timeout: float = 180.0) -> bool: + """Wait for a URL to become available.""" + start = time.time() + while time.time() - start < timeout: + try: + resp = httpx.get(f"{url}/health", timeout=5.0) + if resp.status_code == 200: + return True + except httpx.RequestError: + time.sleep(2.0) + return False + + +@ray.remote +class Trainer: + """ + Simple trainer emulator that holds the real model weights. + + This is a simplified version of the trainer side for testing weight sync. + Non-colocated: runs on a separate GPU from the inference server. + """ + + def __init__(self, model_name: str, device: str = "cuda"): + self.device = torch.device(device) + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.bfloat16, + ).to(self.device) + self.pg = None + self.model_name = model_name + + def ready(self): + """Check if the trainer is ready.""" + return True + + def init_weight_sync(self, master_address: str, master_port: int, world_size: int, group_name: str): + """Initialize the weight sync process group as rank 0 (trainer).""" + from skyrl_train.distributed.utils import init_custom_process_group + from skyrl_train.utils import get_tcp_url + + self.pg = init_custom_process_group( + backend="nccl", + init_method=get_tcp_url(master_address, master_port), + world_size=world_size, + rank=0, # Trainer is always rank 0 + group_name=group_name, + ) + return True + + def get_weight_info(self) -> dict: + """ + Get weight metadata (names, dtypes, shapes) without doing NCCL. + + Returns: + dict with names, dtypes, shapes for the weight update request. + """ + names = [] + dtypes = [] + shapes = [] + + for name, param in self.model.named_parameters(): + names.append(name) + dtypes.append(str(param.dtype).split(".")[-1]) # e.g.
"bfloat16" + shapes.append(list(param.shape)) + + return {"names": names, "dtypes": dtypes, "shapes": shapes} + + def broadcast_weights(self): + """ + Broadcast all model weights to inference workers via NCCL. + + This is a blocking operation - server must call receive concurrently. + """ + for name, param in self.model.named_parameters(): + torch.distributed.broadcast(param.data, src=0, group=self.pg) + torch.cuda.synchronize() + + def teardown(self): + """Clean up the process group.""" + if self.pg is not None: + torch.distributed.destroy_process_group(self.pg) + self.pg = None + + +@pytest.fixture(scope="class") +def weight_update_env(ray_init_fixture): + """ + Create environment for weight update testing: + - Trainer with real weights on GPU 0 + - 1 vLLM server with TP=2 and DUMMY weights (uses GPU 1,2) + - Router to proxy requests + """ + # Create server with dummy weights + cli_args = make_vllm_cli_args(MODEL, tp_size=2, load_format="dummy") + start_port = get_open_port() + + pg = placement_group([{"CPU": 1, "GPU": 1} for _ in range(3)]) + ray.get(pg.ready()) + + trainer = Trainer.options( + num_gpus=1, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=pg, + placement_group_bundle_index=0, + ), + ).remote(MODEL) + + ray.get(trainer.ready.remote()) + + group = ServerGroup( + cli_args=cli_args, + num_servers=1, + start_port=start_port, + placement_group=pg, + placement_group_bundle_offset=1, + ) + server_infos = group.start() + server_urls = [info.url for info in server_infos] + + for url in server_urls: + assert wait_for_url(url), f"Server {url} failed to start" + + # Create router + router_port = get_open_port() + router = InferenceRouter(server_urls, host="0.0.0.0", port=router_port) + router_url = router.start() + assert wait_for_url(router_url), "Router failed to start" + + yield { + "group": group, + "server_urls": server_urls, + "router": router, + "router_url": router_url, + "trainer": trainer, + } + + router.shutdown() + group.shutdown() + ray.get(trainer.teardown.remote()) + + +class TestWeightUpdateFlow: + """Tests for weight synchronization from trainer to inference server.""" + + @pytest.mark.asyncio + async def test_update_weights_flow(self, weight_update_env): + """ + Full E2E weight sync test via router: + 1. Query with dummy weights → gibberish + 2. Init weight transfer (both sides concurrently via router) + 3. Broadcast weights from trainer (concurrent with server receive) + 4. Finalize weight update + 5. 
Query again → correct output + """ + router_url = weight_update_env["router_url"] + trainer = weight_update_env["trainer"] + + async with httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as client: + # ===== Step 1: Verify dummy weights produce gibberish ===== + payload = { + "model": MODEL, + "prompt": "What is the capital of France?", + "max_tokens": 32, + "temperature": 0.0, + } + + resp = await client.post(f"{router_url}/v1/completions", json=payload) + assert resp.status_code == 200 + + text_before = resp.json()["choices"][0]["text"] + print(f"[Step 1] Dummy weights output: {text_before!r}") + + # Dummy weights should NOT produce coherent output about Paris + assert "Paris" not in text_before, "Dummy weights unexpectedly produced correct answer" + + # ===== Step 2: Init weight transfer (both sides concurrently) ===== + master_address = get_node_ip() + master_port = get_open_port() + + # Query all servers for world_size (TP * PP) via router + resp = await client.get(f"{router_url}/get_server_info") + assert resp.status_code == 200 + server_info_map = resp.json() + # Sum world_size across all servers + inference_world_size = sum(info["world_size"] for info in server_info_map.values()) + world_size = 1 + inference_world_size # 1 trainer + all inference workers + group_name = f"weight_sync_test_{master_port}" + + print(f"[Step 2] Init weight transfer: master={master_address}:{master_port}, world_size={world_size}") + + init_info = { + "master_addr": master_address, + "master_port": master_port, + "rank_offset": 1, + "world_size": world_size, + "group_name": group_name, + "backend": "nccl", + "model_dtype_str": "bfloat16", + "override_existing_receiver": True, + } + + # Both sides must init concurrently (NCCL blocks until all ranks join) + # Start trainer init (returns immediately, runs in Ray actor) + trainer_init_ref = trainer.init_weight_sync.remote(master_address, master_port, world_size, group_name) + + # Await server init (triggers NCCL join on server side) + server_resp = await client.post(f"{router_url}/init_weight_transfer", json=init_info) + assert server_resp.status_code == 200, f"Server init failed: {server_resp.text}" + + # Trainer should be done now (NCCL group formed) + ray.get(trainer_init_ref) + print("[Step 2] Both sides init complete") + + # ===== Step 3: Broadcast weights (concurrent send/receive) ===== + print("[Step 3] Broadcasting weights from trainer to server...") + + # Get weight metadata first (no NCCL yet) + weight_info = ray.get(trainer.get_weight_info.remote()) + print(f"[Step 3] Weight info: {len(weight_info['names'])} parameters") + + # Start trainer broadcast (returns immediately, runs in Ray actor) + trainer_broadcast_ref = trainer.broadcast_weights.remote() + + # Await server receive (triggers NCCL receive on server side) + server_resp = await client.post(f"{router_url}/update_weights", json=weight_info) + assert server_resp.status_code == 200, f"Update weights failed: {server_resp.text}" + + # Trainer should be done now (NCCL broadcast complete) + ray.get(trainer_broadcast_ref) + print("[Step 3] Weight sync complete") + + # ===== Step 4: Finalize weight update ===== + resp = await client.post(f"{router_url}/finalize_weight_update", json={}) + assert resp.status_code == 200 + print("[Step 4] Weight update finalized") + + # ===== Step 5: Query again - should produce correct output ===== + resp = await client.post(f"{router_url}/v1/completions", json=payload) + assert resp.status_code == 200 + + text_after = resp.json()["choices"][0]["text"] + print(f"[Step 
5] Real weights output: {text_after!r}") + + assert "Paris" in text_after, f"Weight sync failed - expected 'Paris' but got: {text_after!r}" + + print("[SUCCESS] Weight sync test passed!")
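
For reviewers, here is a rough end-to-end usage sketch of how the new pieces compose: ServerGroup launches the vLLM server actors, InferenceRouter load-balances the data plane and fans control-plane routes out to every server, and RemoteInferenceClient is the trainer-side handle over HTTP. This is not part of the diff; it is a minimal sketch that only strings together the constructor signatures and endpoints exercised by the tests above, and it assumes Ray is set up the same way the GPU tests do via their fixture.

# Not part of this diff: a minimal composition sketch based on the signatures
# used in the tests above. Assumes Ray is available and initialized here.
import asyncio

import ray
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.utils.argparse_utils import FlexibleArgumentParser

from skyrl_train.inference_servers.common import get_open_port
from skyrl_train.inference_servers.remote_inference_client import RemoteInferenceClient
from skyrl_train.inference_servers.router import InferenceRouter
from skyrl_train.inference_servers.server_group import ServerGroup

MODEL = "Qwen/Qwen2.5-0.5B-Instruct"


async def main():
    ray.init()

    # Parse vLLM CLI args with the official parser, as the GPU tests do.
    parser = make_arg_parser(FlexibleArgumentParser(description="vLLM server"))
    cli_args = parser.parse_args(["--model", MODEL, "--tensor-parallel-size", "2"])

    # 1. Launch the vLLM server actors.
    group = ServerGroup(cli_args=cli_args, num_servers=2, start_port=get_open_port())
    server_urls = [info.url for info in group.start()]

    # 2. Front the servers with the router: data-plane requests are load balanced,
    #    control-plane routes (/pause, /sleep, ...) fan out to every server.
    router = InferenceRouter(server_urls, host="0.0.0.0", port=get_open_port())
    router_url = router.start()

    # 3. Trainer-side handle over HTTP.
    client = RemoteInferenceClient.from_server_group(
        server_urls=server_urls, router_url=router_url, model_name=MODEL
    )
    try:
        out = await client.generate(
            {"prompt_token_ids": [[1, 2, 3]], "sampling_params": {"max_tokens": 16}}
        )
        print(out["responses"][0], out["stop_reasons"][0])

        await client.pause()  # quiesce generation, e.g. around a weight sync step
        await client.resume()
    finally:
        await client.teardown()
        router.shutdown()
        group.shutdown()
        ray.shutdown()


if __name__ == "__main__":
    asyncio.run(main())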